diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 4cef6062c81..af45d1e67f5 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -576,3 +576,18 @@ def __repr__(self) -> str: f"trade_range: {self.trade_range}; " f"order_list[{len(self.order_list)}]" ) + + +class TradeDecisionWithDetails(TradeDecisionWO): + """Decision with detail information. Detail information is used to generate execution reports. + """ + def __init__( + self, + order_list: List[Order], + strategy: BaseStrategy, + trade_range: Optional[Tuple[int, int]] = None, + details: Optional[Any] = None, + ) -> None: + super().__init__(order_list, strategy, trade_range) + + self.details = details diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 709c050dfb7..4d3d3cf4b7b 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -2,11 +2,12 @@ # Licensed under the MIT License. from __future__ import annotations +import argparse import copy import pickle -import sys +from collections import defaultdict from pathlib import Path -from typing import Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd @@ -14,12 +15,13 @@ from joblib import Parallel, delayed from qlib.backtest import collect_data_loop, get_strategy_executor -from qlib.backtest.decision import TradeRangeByTime +from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor from qlib.backtest.high_performance_ds import BaseOrderIndicator from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile from qlib.rl.contrib.utils import read_order_file from qlib.rl.data.integration import init_qlib +from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper @@ -41,7 +43,7 @@ def _get_multi_level_executor_config( } freqs = list(strategy_config.keys()) - freqs.sort(key=lambda x: pd.Timedelta(x)) + freqs.sort(key=pd.Timedelta) for freq in freqs: executor_config = { "class": "NestedExecutor", @@ -73,7 +75,7 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: # HACK: for qlib v0.8 value_dict = value_dict.to_series() try: - value_dict = {k: v for k, v in value_dict.items()} + value_dict = copy.deepcopy(value_dict) if value_dict["ffr"].empty: continue except Exception: @@ -90,32 +92,177 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: return records -def _generate_report(decisions: list, report_dict: dict) -> dict: +# TODO: there should be richer annotation for the input (e.g. report) and the returned report +# TODO: For example, @ dataclass with typed fields and detailed docstrings. +def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict: + """Generate backtest reports + + Parameters + ---------- + decisions: + List of trade decisions. + report_indicators + List of indicator reports. + Returns + ------- + + """ + indicator_dict = defaultdict(list) + indicator_his = defaultdict(list) + for report_indicator in report_indicators: + for key, value in report_indicator.items(): + if key.endswith("_obj"): + indicator_his[key].append(value.order_indicator_his) + else: + indicator_dict[key].append(value) + report = {} - decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")]) - for key in ["1minute", "5minute", "30minute", "1day"]: - if key not in report_dict["indicator"]: + decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")]) + for key in ["1min", "5min", "30min", "1day"]: + if key not in indicator_dict: continue - report[key] = report_dict["indicator"][key] - report[key + "_obj"] = _convert_indicator_to_dataframe( - report_dict["indicator"][key + "_obj"].order_indicator_his - ) - cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"]) + + report[key] = pd.concat(indicator_dict[key]) + report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]]) + + cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"]) if len(cur_details) > 0: cur_details.pop("freq") report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer") - if "1minute" in report_dict["report"]: - report["simulator"] = report_dict["report"]["1minute"][0] + return report -def single( +def single_with_simulator( backtest_config: dict, orders: pd.DataFrame, - split: str = "stock", + split: Literal["stock", "day"] = "stock", cash_limit: float = None, generate_report: bool = False, ) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: + """Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day. + A new simulator will be created and used for every single-day order. + + Parameters + ---------- + backtest_config: + Backtest config + orders: + Orders to be executed. Example format: + datetime instrument amount direction + 0 2020-06-01 INST 600.0 0 + 1 2020-06-02 INST 700.0 1 + ... + split + Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date. + cash_limit + Limitation of cash. + generate_report + Whether to generate reports. + + Returns + ------- + If generate_report is True, return execution records and the generated report. Otherwise, return only records. + """ + if split == "stock": + stock_id = orders.iloc[0].instrument + init_qlib(backtest_config["qlib"], part=stock_id) + else: + day = orders.iloc[0].datetime + init_qlib(backtest_config["qlib"], part=day) + + stocks = orders.instrument.unique().tolist() + + reports = [] + decisions = [] + for _, row in orders.iterrows(): + date = pd.Timestamp(row["datetime"]) + start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day) + end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day) + order = Order( + stock_id=row["instrument"], + amount=row["amount"], + direction=OrderDir(row["direction"]), + start_time=start_time, + end_time=end_time, + ) + + executor_config = _get_multi_level_executor_config( + strategy_config=backtest_config["strategies"], + cash_limit=cash_limit, + generate_report=generate_report, + ) + + exchange_config = copy.deepcopy(backtest_config["exchange"]) + exchange_config.update( + { + "codes": stocks, + "freq": "1min", + } + ) + + simulator = SingleAssetOrderExecution( + order=order, + executor_config=executor_config, + exchange_config=exchange_config, + qlib_config=None, + cash_limit=None, + backtest_mode=True, + ) + + reports.append(simulator.report_dict) + decisions += simulator.decisions + + indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()} + records = _convert_indicator_to_dataframe(indicator) + assert records is None or not np.isnan(records["ffr"]).any() + + if generate_report: + report = _generate_report(decisions, [report["indicator"] for report in reports]) + + if split == "stock": + stock_id = orders.iloc[0].instrument + report = {stock_id: report} + else: + day = orders.iloc[0].datetime + report = {day: report} + + return records, report + else: + return records + + +def single_with_collect_data_loop( + backtest_config: dict, + orders: pd.DataFrame, + split: Literal["stock", "day"] = "stock", + cash_limit: float = None, + generate_report: bool = False, +) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: + """Run backtest in a single thread with collect_data_loop. + + Parameters + ---------- + backtest_config: + Backtest config + orders: + Orders to be executed. Example format: + datetime instrument amount direction + 0 2020-06-01 INST 600.0 0 + 1 2020-06-02 INST 700.0 1 + ... + split + Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date. + cash_limit + Limitation of cash. + generate_report + Whether to generate reports. + + Returns + ------- + If generate_report is True, return execution records and the generated report. Otherwise, return only records. + """ + if split == "stock": stock_id = orders.iloc[0].instrument init_qlib(backtest_config["qlib"], part=stock_id) @@ -127,7 +274,7 @@ def single( trade_end_time = orders["datetime"].max() stocks = orders.instrument.unique().tolist() - top_strategy_config = { + strategy_config = { "class": "FileOrderStrategy", "module_path": "qlib.contrib.strategy.rule_strategy", "kwargs": { @@ -139,14 +286,14 @@ def single( }, } - top_executor_config = _get_multi_level_executor_config( + executor_config = _get_multi_level_executor_config( strategy_config=backtest_config["strategies"], cash_limit=cash_limit, generate_report=generate_report, ) - tmp_backtest_config = copy.deepcopy(backtest_config["exchange"]) - tmp_backtest_config.update( + exchange_config = copy.deepcopy(backtest_config["exchange"]) + exchange_config.update( { "codes": stocks, "freq": "1min", @@ -156,11 +303,11 @@ def single( strategy, executor = get_strategy_executor( start_time=pd.Timestamp(trade_start_time), end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1), - strategy=top_strategy_config, - executor=top_executor_config, + strategy=strategy_config, + executor=executor_config, benchmark=None, account=cash_limit if cash_limit is not None else int(1e12), - exchange_kwargs=tmp_backtest_config, + exchange_kwargs=exchange_config, pos_type="Position" if cash_limit is not None else "InfPosition", ) _set_env_for_all_strategy(executor=executor) @@ -172,7 +319,7 @@ def single( assert records is None or not np.isnan(records["ffr"]).any() if generate_report: - report = _generate_report(decisions, report_dict) + report = _generate_report(decisions, [report_dict["indicator"]]) if split == "stock": stock_id = orders.iloc[0].instrument report = {stock_id: report} @@ -184,7 +331,7 @@ def single( return records -def backtest(backtest_config: dict) -> pd.DataFrame: +def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame: order_df = read_order_file(backtest_config["order_file"]) cash_limit = backtest_config["exchange"].pop("cash_limit") @@ -193,6 +340,7 @@ def backtest(backtest_config: dict) -> pd.DataFrame: stock_pool = order_df["instrument"].unique().tolist() stock_pool.sort() + single = single_with_simulator if with_simulator else single_with_collect_data_loop mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 res = Parallel(**mp_config)( @@ -227,5 +375,12 @@ def backtest(backtest_config: dict) -> pd.DataFrame: warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) - path = sys.argv[1] - backtest(get_backtest_config_fromfile(path)) + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend") + args = parser.parse_args() + + backtest( + backtest_config=get_backtest_config_fromfile(args.config_path), + with_simulator=args.use_simulator, + ) diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index eaf62636cc0..3f3d2eeadc8 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -53,7 +53,8 @@ def parse_backtest_config(path: str) -> dict: del sys.modules[tmp_module_name] else: - config = yaml.safe_load(open(tmp_config_file.name)) + with open(tmp_config_file.name) as input_stream: + config = yaml.safe_load(input_stream) if "_base_" in config: base_file_name = config.pop("_base_") diff --git a/qlib/rl/data/integration.py b/qlib/rl/data/integration.py index d32ce49c822..af5025c843d 100644 --- a/qlib/rl/data/integration.py +++ b/qlib/rl/data/integration.py @@ -81,10 +81,12 @@ def init_qlib(qlib_config: dict, part: str = None) -> None: def _convert_to_path(path: str | Path) -> Path: return path if isinstance(path, Path) else Path(path) - provider_uri_map = { - "day": _convert_to_path(qlib_config["provider_uri_day"]).as_posix(), - "1min": _convert_to_path(qlib_config["provider_uri_1min"]).as_posix(), - } + provider_uri_map = {} + if "provider_uri_day" in qlib_config: + provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix() + if "provider_uri_1min" in qlib_config: + provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix() + qlib.init( region=REG_CN, auto_mount=False, diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index eb612cf64eb..9417534f867 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -9,12 +9,11 @@ from qlib.backtest import Exchange, Order from qlib.backtest.decision import TradeRange, TradeRangeByTime -from qlib.constant import EPS_T, ONE_DAY from qlib.rl.order_execution.utils import get_ticks_slice -from qlib.utils.index_data import IndexData from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider from .integration import fetch_features +from ...data import D class IntradayBacktestData(BaseIntradayBacktestData): @@ -82,18 +81,20 @@ def load_backtest_data( trade_exchange: Exchange, trade_range: TradeRange, ) -> IntradayBacktestData: - data = cast( - IndexData, - trade_exchange.get_deal_price( - stock_id=order.stock_id, - start_time=order.date, - end_time=order.date + ONE_DAY - EPS_T, - direction=order.direction, - method=None, - ), + # TODO: making exchange return data without missing will make it more elegant. Fix this in the future. + tmp_data = D.features( + trade_exchange.codes, + trade_exchange.all_fields, + trade_exchange.start_time, + trade_exchange.end_time, + freq=trade_exchange.freq, + disk_cache=True, ) - ticks_index = pd.DatetimeIndex(data.index) + ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"]) + ticks_index = ticks_index[order.start_time <= ticks_index] + ticks_index = ticks_index[ticks_index <= order.end_time] + if isinstance(trade_range, TradeRangeByTime): ticks_for_order = get_ticks_slice( ticks_index, @@ -122,7 +123,10 @@ def __init__( date: pd.Timestamp, ) -> None: def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame: - return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"]) + df = df.reset_index() + if "instrument" in df.columns: + df = df.drop(columns=["instrument"]) + return df.set_index(["datetime"]) self.today = _drop_stock_id(fetch_features(stock_id, date)) self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True)) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index ed62a4180df..3af1e248396 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -91,7 +91,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData): def __init__( self, - data_dir: Path, + data_dir: Path | str, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", @@ -99,7 +99,7 @@ def __init__( ) -> None: super(SimpleIntradayBacktestData, self).__init__() - backtest = _read_pickle(data_dir / stock_id) + backtest = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]] # No longer need for pandas >= 1.4 @@ -154,13 +154,13 @@ class IntradayProcessedData(BaseIntradayProcessedData): def __init__( self, - data_dir: Path, + data_dir: Path | str, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index, ) -> None: - proc = _read_pickle(data_dir / stock_id) + proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) # We have to infer the names here because, # unfortunately they are not included in the original data. cnames = _infer_processed_data_column_names(feature_dim) diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index cfd3181ca25..7f7a98e9a71 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -163,6 +163,12 @@ def auto_device(module: nn.Module) -> torch.device: def load_weight(policy: nn.Module, path: Path) -> None: assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight." loaded_weight = torch.load(path, map_location="cpu") + + # TODO: this should be handled by whoever calls load_weight. + # TODO: For example, when the outer class receives a weight, it should first unpack it, + # TODO: and send the corresponding part to individual component. + if "vessel" in loaded_weight: + loaded_weight = loaded_weight["vessel"]["policy"] try: policy.load_state_dict(loaded_weight) except RuntimeError: diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 718c2ba5729..c9702b1e48d 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -3,17 +3,18 @@ from __future__ import annotations -from typing import Generator, Optional +from typing import Generator, List, Optional import pandas as pd -from qlib.backtest import collect_data_loop, get_strategy_executor -from qlib.backtest.decision import Order -from qlib.backtest.executor import NestedExecutor -from qlib.rl.simulator import Simulator +from qlib.backtest import collect_data_loop, get_strategy_executor +from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime +from qlib.backtest.executor import BaseExecutor, NestedExecutor from qlib.rl.data.integration import init_qlib +from qlib.rl.simulator import Simulator from .state import SAOEState, SAOEStateAdapter from .strategy import SAOEStrategy +from ..utils.env_wrapper import CollectDataEnvWrapper class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): @@ -23,30 +24,42 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): ---------- order The seed to start an SAOE simulator is an order. - strategy_config - Strategy configuration executor_config Executor configuration exchange_config Exchange configuration qlib_config Configuration used to initialize Qlib. If it is None, Qlib will not be initialized. + cash_limit: + Cash limit. + backtest_mode + Whether the simulator is under backtest mode. """ def __init__( self, order: Order, - strategy_config: dict, executor_config: dict, exchange_config: dict, qlib_config: dict = None, + cash_limit: Optional[float] = None, + backtest_mode: bool = False, ) -> None: super().__init__(initial=order) assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same." + strategy_config = { + "class": "SingleOrderStrategy", + "module_path": "qlib.rl.strategy.single_order", + "kwargs": { + "order": order, + "trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()), + }, + } + self._collect_data_loop: Optional[Generator] = None - self.reset(order, strategy_config, executor_config, exchange_config, qlib_config) + self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit, backtest_mode) def reset( self, @@ -55,6 +68,8 @@ def reset( executor_config: dict, exchange_config: dict, qlib_config: dict = None, + cash_limit: Optional[float] = None, + backtest_mode: bool = False, ) -> None: if qlib_config is not None: init_qlib(qlib_config, part="skip") @@ -65,22 +80,35 @@ def reset( strategy=strategy_config, executor=executor_config, benchmark=order.stock_id, - account=1e12, + account=cash_limit if cash_limit is not None else int(1e12), exchange_kwargs=exchange_config, - pos_type="InfPosition", + pos_type="Position" if cash_limit is not None else "InfPosition", ) assert isinstance(self._executor, NestedExecutor) + self.report_dict: dict = {} + self.decisions: List[BaseTradeDecision] = [] self._collect_data_loop = collect_data_loop( start_time=order.date, end_time=order.date, trade_strategy=strategy, trade_executor=self._executor, + return_value=self.report_dict, ) assert isinstance(self._collect_data_loop, Generator) - self._last_yielded_saoe_strategy = self._iter_strategy(action=None) + # TODO: backtest_mode is not a necessary parameter if we carefully design it. + # TODO: It should disappear with CollectDataEnvWrapper in the future. + if backtest_mode: + executor: BaseExecutor = self._executor + while isinstance(executor, NestedExecutor): + if hasattr(executor.inner_strategy, "set_env"): + executor.inner_strategy.set_env(CollectDataEnvWrapper()) + executor = executor.inner_executor + + # Call `step()` with None action to initialize the internal generator. + self.step(action=None) self._order = order @@ -91,17 +119,19 @@ def _get_adapter(self) -> SAOEStateAdapter: def twap_price(self) -> float: return self._get_adapter().twap_price - def _iter_strategy(self, action: float = None) -> SAOEStrategy: + def _iter_strategy(self, action: Optional[float] = None) -> SAOEStrategy: """Iterate the _collect_data_loop until we get the next yield SAOEStrategy.""" assert self._collect_data_loop is not None - strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) - while not isinstance(strategy, SAOEStrategy): - strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) - assert isinstance(strategy, SAOEStrategy) - return strategy + obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + while not isinstance(obj, SAOEStrategy): + if isinstance(obj, BaseTradeDecision): + self.decisions.append(obj) + obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + assert isinstance(obj, SAOEStrategy) + return obj - def step(self, action: float) -> None: + def step(self, action: Optional[float]) -> None: """Execute one step or SAOE. Parameters diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py index a46928ee89c..f417173e524 100644 --- a/qlib/rl/order_execution/state.py +++ b/qlib/rl/order_execution/state.py @@ -4,7 +4,7 @@ from __future__ import annotations import typing -from typing import cast, NamedTuple, Optional, Tuple +from typing import cast, Callable, List, NamedTuple, Optional, Tuple import numpy as np import pandas as pd @@ -13,6 +13,7 @@ from qlib.constant import EPS, ONE_MIN, REG_CN from qlib.rl.order_execution.utils import dataframe_append, price_advantage from qlib.typehint import TypedDict +from qlib.utils.index_data import IndexData from qlib.utils.time import get_day_min_idx_range if typing.TYPE_CHECKING: @@ -38,6 +39,37 @@ def _get_all_timestamps( return pd.DatetimeIndex(ret) +def fill_missing_data( + original_data: np.ndarray, + total_time_list: List[pd.Timestamp], + found_time_list: List[pd.Timestamp], + fill_method: Callable = np.median, +) -> np.ndarray: + """Fill missing data. We need this function to deal with data that have missing values in some minutes. + + TODO: making exchange return data without missing will make it more elegant. Fix this in the future. + + Parameters + ---------- + original_data + Original data without missing values. + total_time_list + All timestamps that required. + found_time_list + Timestamps found in the original data. + fill_method + Method used to fill the missing data. + + Returns + ------- + The filled data. + """ + assert len(original_data) == len(found_time_list) + tmp = dict(zip(found_time_list, original_data)) + fill_val = fill_method(original_data) + return np.array([tmp.get(t, fill_val) for t in total_time_list]) + + class SAOEStateAdapter: """ Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state @@ -106,16 +138,17 @@ def update( assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large" exec_vol *= self.position / (exec_vol.sum()) - market_volume = np.array( + market_volume = cast( + IndexData, self.exchange.get_volume( self.order.stock_id, pd.Timestamp(start_time), pd.Timestamp(end_time), method=None, ), - ).reshape(-1) - - market_price = np.array( + ) + market_price = cast( + IndexData, self.exchange.get_deal_price( self.order.stock_id, pd.Timestamp(start_time), @@ -123,7 +156,11 @@ def update( method=None, direction=self.order.direction, ), - ).reshape(-1) + ) + found_time_list = [pd.Timestamp(e) for e in list(market_volume.index)] + total_time_list = _get_all_timestamps(start_time, end_time) + market_price = fill_missing_data(np.array(market_price).reshape(-1), total_time_list, found_time_list) + market_volume = fill_missing_data(np.array(market_volume).reshape(-1), total_time_list, found_time_list) assert market_price.shape == market_volume.shape == exec_vol.shape diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index ecc879bf512..663b8e8ff4a 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -5,15 +5,16 @@ import collections from types import GeneratorType -from typing import Any, cast, Dict, Generator, Optional, Union +from typing import Any, cast, Dict, Generator, List, Optional, Union +import numpy as np import pandas as pd import torch from tianshou.data import Batch from tianshou.policy import BasePolicy from qlib.backtest import CommonInfrastructure, Order -from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange +from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange from qlib.backtest.utils import LevelInfrastructure from qlib.constant import ONE_MIN from qlib.rl.data.native import load_backtest_data @@ -235,6 +236,23 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) - if self._backtest: self._env.reset() + def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame: + assert hasattr(self.outer_trade_decision, "order_list") + + trade_details = [] + for a, v, o in zip(act, exec_vols, getattr(self.outer_trade_decision, "order_list")): + trade_details.append( + { + "instrument": o.stock_id, + "datetime": self.trade_calendar.get_step_time()[0], + "freq": self.trade_calendar.get_freq(), + "rl_exec_vol": v, + } + ) + if a is not None: + trade_details[-1]["rl_action"] = a + return pd.DataFrame.from_records(trade_details) + def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: states = [] obs_batch = [] @@ -261,4 +279,8 @@ def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDeci order = cast(Order, decision) order_list.append(oh.create(order.stock_id, exec_vol, order.direction)) - return TradeDecisionWO(order_list=order_list, strategy=self) + return TradeDecisionWithDetails( + order_list=order_list, + strategy=self, + details=self._generate_trade_details(act, exec_vols), + ) diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py index 14bf8b5a112..92ad9c0583e 100644 --- a/tests/rl/test_qlib_simulator.py +++ b/tests/rl/test_qlib_simulator.py @@ -32,16 +32,7 @@ def get_order() -> Order: ) -def get_configs(order: Order) -> Tuple[dict, dict, dict]: - strategy_config = { - "class": "SingleOrderStrategy", - "module_path": "qlib.rl.strategy.single_order", - "kwargs": { - "order": order, - "trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()), - }, - } - +def get_configs(order: Order) -> Tuple[dict, dict]: executor_config = { "class": "NestedExecutor", "module_path": "qlib.backtest.executor", @@ -93,7 +84,7 @@ def get_configs(order: Order) -> Tuple[dict, dict, dict]: "trade_unit": None, } - return strategy_config, executor_config, exchange_config + return executor_config, exchange_config def get_simulator(order: Order) -> SingleAssetOrderExecution: @@ -115,12 +106,11 @@ def get_simulator(order: Order) -> SingleAssetOrderExecution: } # fmt: on - strategy_config, executor_config, exchange_config = get_configs(order) + executor_config, exchange_config = get_configs(order) return SingleAssetOrderExecution( order=order, qlib_config=qlib_config, - strategy_config=strategy_config, executor_config=executor_config, exchange_config=exchange_config, )