def __init__(
    self,
    name_id: str,
    task_template: Union[dict, List[dict]],
    rolling_gen: RollingGen,
):
    """
    Set up a RollingStrategy.

    Assumption: the str of name_id, the experiment name, and the trainer's
    experiment name are the same.

    Args:
        name_id (str): a unique name or id. It is also used as the name of the Experiment.
        task_template (Union[dict, List[dict]]): a single task template or a list of them;
            a single template is wrapped into a one-element list. The templates are the
            input from which rolling_gen generates the rolling tasks.
        rolling_gen (RollingGen): an instance of RollingGen.
    """
    super().__init__(name_id=name_id)

    # The experiment name is taken from the name_id stored by the parent initializer.
    self.exp_name = self.name_id

    # Always keep the templates as a list so later code can treat both cases uniformly.
    self.task_template = task_template if isinstance(task_template, list) else [task_template]

    self.rg = rolling_gen
    generator_cls = self.rg.__class__
    assert issubclass(generator_cls, RollingGen), "The rolling strategy relies on the feature if RollingGen"

    self.tool = OnlineToolR(self.exp_name)
    self.ta = TimeAdjuster()
def __init__(
    self,
    provider_uri="~/.qlib/qlib_data/cn_data",
    region=REG_CN,
    experiment_name="online_srv",
    task_config=task,
):
    """
    Initialise qlib and create the online tool bound to one experiment.

    Args:
        provider_uri (str): forwarded to ``qlib.init`` as ``provider_uri``.
        region: forwarded to ``qlib.init`` as ``region``.
        experiment_name (str): experiment name handed to ``OnlineToolR``.
        task_config: task configuration stored on the instance for later use.
    """
    # qlib is initialised first, before the online tool is constructed.
    qlib.init(provider_uri=provider_uri, region=region)

    self.experiment_name = experiment_name
    self.online_tool = OnlineToolR(experiment_name)
    self.task_config = task_config
def test_update_pred(self):
    """
    Exercise `update_online_pred` in two steps.

    1. Update with a `to_date` that lies 10 days past the last calendar date;
       the call is expected to complete without raising.
    2. Overwrite two date ranges of the stored prediction with marker values,
       re-run the update restricted to `from_date`/`to_date`, and check that
       the range inside that window matches the good prediction again while
       the range outside it still holds its marker value.
    """
    task = copy.deepcopy(CSI300_GBDT_TASK)
    task["record"] = ["qlib.workflow.record_temp.SignalRecord"]

    exp_name = "online_srv_test"

    latest_date = D.calendar()[-1]

    def days_before_latest(n_days):
        # Date `n_days` calendar days before the last date of the calendar.
        return latest_date - pd.Timedelta(days=n_days)

    train_start = days_before_latest(61)
    train_end = days_before_latest(41)

    dataset_kwargs = task["dataset"]["kwargs"]
    dataset_kwargs["segments"] = {
        "train": (train_start, train_end),
        "valid": (days_before_latest(40), days_before_latest(21)),
        "test": (days_before_latest(20), latest_date),
    }
    dataset_kwargs["handler"]["kwargs"] = {
        "start_time": train_start,
        "end_time": latest_date,
        "fit_start_time": train_start,
        "fit_end_time": train_end,
        "instruments": "csi300",
    }

    rec = task_train(task, exp_name)
    # Loaded right after training; the value itself is not used further.
    initial_pred = rec.load_object("pred.pkl")

    online_tool = OnlineToolR(exp_name)
    online_tool.reset_online_tag(rec)  # set to online model
    # `to_date` is deliberately beyond the last calendar date.
    online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))

    good_pred = rec.load_object("pred.pkl")

    # Tamper with two ranges of the stored prediction.
    mod_range = slice(days_before_latest(20), days_before_latest(10))
    mod_range2 = slice(days_before_latest(9), days_before_latest(2))
    mod_pred = good_pred.copy()
    mod_pred.loc[mod_range] = -1
    mod_pred.loc[mod_range2] = -2
    rec.save_objects(**{"pred.pkl": mod_pred})

    # Re-run the update only for the window that coincides with `mod_range`.
    online_tool.update_online_pred(
        to_date=days_before_latest(10),
        from_date=days_before_latest(20),
    )
    updated_pred = rec.load_object("pred.pkl")

    # Inside the updated window the values must equal the good prediction again.
    inside_window_restored = (updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item()
    self.assertTrue(inside_window_restored)

    # Outside the updated window the -2 marker must still be present.
    outside_window_untouched = (updated_pred.loc[mod_range2] == -2).all().item()
    self.assertTrue(outside_window_untouched)
def test_update_label(self):
    """
    Check that `LabelUpdater` brings the stored label up to the prediction date.

    The task is trained on data ending `shift` calendar entries before the
    last calendar date. After `update_online_pred()` the latest non-NaN label
    date is expected to be earlier than the latest non-NaN prediction date;
    after `LabelUpdater(rec).update()` the latest label date is expected to
    equal that prediction date.
    """
    task = copy.deepcopy(CSI300_GBDT_TASK)
    task["record"] = {
        "class": "SignalRecord",
        "module_path": "qlib.workflow.record_temp",
        "kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
    }

    exp_name = "online_srv_test"

    cal = D.calendar()
    shift = 10
    # Stop `shift` entries before the end of the calendar.
    latest_date = cal[-1 - shift]

    def days_before_latest(n_days):
        # Date `n_days` calendar days before `latest_date`.
        return latest_date - pd.Timedelta(days=n_days)

    train_start = days_before_latest(61)
    train_end = days_before_latest(41)

    dataset_kwargs = task["dataset"]["kwargs"]
    dataset_kwargs["segments"] = {
        "train": (train_start, train_end),
        "valid": (days_before_latest(40), days_before_latest(21)),
        "test": (days_before_latest(20), latest_date),
    }
    dataset_kwargs["handler"]["kwargs"] = {
        "start_time": train_start,
        "end_time": latest_date,
        "fit_start_time": train_start,
        "fit_end_time": train_end,
        "instruments": "csi300",
    }

    rec = task_train(task, exp_name)
    # Loaded right after training; the value itself is not used further.
    initial_pred = rec.load_object("pred.pkl")

    online_tool = OnlineToolR(exp_name)
    online_tool.reset_online_tag(rec)  # set to online model
    online_tool.update_online_pred()

    new_pred = rec.load_object("pred.pkl")
    label = rec.load_object("label.pkl")

    label_dates = label.dropna().index.get_level_values("datetime")
    pred_dates = new_pred.dropna().index.get_level_values("datetime")
    label_date = label_dates.max()
    pred_date = pred_dates.max()

    # The prediction is updated, but the label is not updated.
    self.assertTrue(label_date < pred_date)

    # Update label now
    updater = LabelUpdater(rec)
    updater.update()

    new_label = rec.load_object("label.pkl")
    new_label_date = new_label.index.get_level_values("datetime").max()

    # make sure the label is updated now
    self.assertTrue(new_label_date == pred_date)
class UpdatePredExample:
    """
    Small example: train one task, tag its recorder as the online model,
    then update the online predictions through ``OnlineToolR``.
    """

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region=REG_CN,
        experiment_name="online_srv",
        task_config=task,
    ):
        """
        Initialise qlib and create the online tool bound to one experiment.

        Args:
            provider_uri (str): forwarded to ``qlib.init`` as ``provider_uri``.
            region: forwarded to ``qlib.init`` as ``region``.
            experiment_name (str): experiment name handed to ``OnlineToolR`` and
                later to ``task_train``.
            task_config: task configuration that ``first_train`` trains.
        """
        # qlib is initialised first, before the online tool is constructed.
        qlib.init(provider_uri=provider_uri, region=region)

        self.experiment_name = experiment_name
        self.online_tool = OnlineToolR(experiment_name)
        self.task_config = task_config

    def first_train(self):
        """Train the configured task and tag the resulting recorder as online."""
        recorder = task_train(self.task_config, experiment_name=self.experiment_name)
        self.online_tool.reset_online_tag(recorder)  # set to online model

    def update_online_pred(self):
        """Delegate to ``OnlineToolR.update_online_pred`` with its default arguments."""
        self.online_tool.update_online_pred()

    def main(self):
        """Run the full example: initial training followed by a prediction update."""
        self.first_train()
        self.update_online_pred()
def test_update_pred(self):
    """
    Check that `update_online_pred` does not raise when `to_date` lies
    beyond the last date of the calendar.
    """
    task = copy.deepcopy(CSI300_GBDT_TASK)
    task["record"] = {
        "class": "SignalRecord",
        "module_path": "qlib.workflow.record_temp",
        "kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
    }

    exp_name = "online_srv_test"

    latest_date = D.calendar()[-1]

    def days_before_latest(n_days):
        # Date `n_days` calendar days before the last date of the calendar.
        return latest_date - pd.Timedelta(days=n_days)

    train_start = days_before_latest(61)
    train_end = days_before_latest(41)

    dataset_kwargs = task["dataset"]["kwargs"]
    dataset_kwargs["segments"] = {
        "train": (train_start, train_end),
        "valid": (days_before_latest(40), days_before_latest(21)),
        "test": (days_before_latest(20), latest_date),
    }
    dataset_kwargs["handler"]["kwargs"] = {
        "start_time": train_start,
        "end_time": latest_date,
        "fit_start_time": train_start,
        "fit_end_time": train_end,
        "instruments": "csi300",
    }

    rec = task_train(task, exp_name)
    # Loaded right after training; the value itself is not used further.
    initial_pred = rec.load_object("pred.pkl")

    online_tool = OnlineToolR(exp_name)
    online_tool.reset_online_tag(rec)  # set to online model

    # `to_date` is deliberately 10 days past the last calendar date.
    out_of_range_date = latest_date + pd.Timedelta(days=10)
    online_tool.update_online_pred(to_date=out_of_range_date)
class RollingStrategy(OnlineStrategy):
    """
    This example strategy always uses the latest rolling models as online models.
    """

    def __init__(
        self,
        name_id: str,
        task_template: Union[dict, List[dict]],
        rolling_gen: RollingGen,
    ):
        """
        Init RollingStrategy.

        Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.

        Args:
            name_id (str): a unique name or id. Will be also the name of the Experiment.
            task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
            rolling_gen (RollingGen): an instance of RollingGen
        """
        super().__init__(name_id=name_id)
        self.exp_name = self.name_id
        # Normalise a single template into a one-element list.
        if not isinstance(task_template, list):
            task_template = [task_template]
        self.task_template = task_template
        self.rg = rolling_gen
        assert issubclass(
            self.rg.__class__, RollingGen
        ), "The rolling strategy relies on the feature if RollingGen"
        self.tool = OnlineToolR(self.exp_name)
        self.ta = TimeAdjuster()

    def get_collector(self, process_list=None, rec_key_func=None, rec_filter_func=None, artifacts_key=None):
        """
        Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.

        Assumption: the models can be distinguished based on the model name and rolling test segments.
        If you do not want this assumption, please implement your method or use another rec_key_func.

        Args:
            process_list (list, optional): passed to RecorderCollector as ``process_list``.
                If None, a fresh ``[RollingGroup()]`` is created for this call.
            rec_key_func (Callable): a function to get the key of a recorder. If None, a key made of
                the task's model class and its test segment is used.
            rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
            artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.

        Returns:
            RecorderCollector: a collector bound to this strategy's experiment.
        """
        # Build the default per call: a list literal in the signature would be
        # created once and shared (and possibly mutated) across all calls.
        if process_list is None:
            process_list = [RollingGroup()]

        def rec_key(recorder):
            # Default key: (model class name, test segment) of the recorder's task.
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        if rec_key_func is None:
            rec_key_func = rec_key

        artifacts_collector = RecorderCollector(
            experiment=self.exp_name,
            process_list=process_list,
            rec_key_func=rec_key_func,
            rec_filter_func=rec_filter_func,
            artifacts_key=artifacts_key,
        )

        return artifacts_collector

    def first_tasks(self) -> List[dict]:
        """
        Use rolling_gen to generate different tasks based on task_template.

        Returns:
            List[dict]: a list of tasks
        """
        return task_generator(
            tasks=self.task_template,
            generators=self.rg,  # generate different date segment
        )

    def prepare_tasks(self, cur_time) -> List[dict]:
        """
        Prepare new tasks based on cur_time (None for the latest).

        You can find the last online models by OnlineToolR.online_models.

        Args:
            cur_time: the current time, passed to ``transform_end_date``.

        Returns:
            List[dict]: a list of new tasks. Empty when there are no online recorders.
        """
        # TODO: filtering recorders by the latest test segment is not strictly necessary
        latest_records, max_test = self._list_latest(self.tool.online_models())
        if max_test is None:
            # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
            self.logger.warning("No latest online recorders, no new tasks.")
            return []
        calendar_latest = transform_end_date(cur_time)
        self.logger.info(
            f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} "
            f"is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
        )
        res = []
        for rec in latest_records:
            task = rec.load_object("task")
            res.extend(self.rg.gen_following_tasks(task, calendar_latest))
        return res

    def _list_latest(self, rec_list: List[Recorder]):
        """
        List the latest recorders from rec_list.

        "Latest" means the recorder's task has the maximum ``test`` segment
        (``task["dataset"]["kwargs"]["segments"]["test"]``) among all recorders.

        Args:
            rec_list (List[Recorder]): a list of Recorder

        Returns:
            Tuple[List[Recorder], Any]: the latest recorders and their shared test segment
            (the caller indexes it with ``[0]`` for the test begin time), or
            ``(rec_list, None)`` when rec_list is empty.
        """
        if len(rec_list) == 0:
            return rec_list, None
        # Load each recorder's task only once; loading an artifact may be costly.
        test_segments = [rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list]
        max_test = max(test_segments)
        latest_rec = [rec for rec, segment in zip(rec_list, test_segments) if segment == max_test]
        return latest_rec, max_test