Ejemplo n.º 1
0
    def __init__(
        self,
        name_id: str,
        task_template: Union[dict, List[dict]],
        rolling_gen: RollingGen,
    ):
        """
        Init RollingStrategy.

        Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.

        Args:
            name_id (str): a unique name or id. Will be also the name of the Experiment.
            task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
            rolling_gen (RollingGen): an instance of RollingGen

        Raises:
            AssertionError: if ``rolling_gen`` is not a ``RollingGen`` instance.
        """
        super().__init__(name_id=name_id)
        # The experiment shares its name with the strategy (see the assumption above).
        self.exp_name = self.name_id
        # Accept a single template for convenience; internally always keep a list.
        if not isinstance(task_template, list):
            task_template = [task_template]
        self.task_template = task_template
        self.rg = rolling_gen
        assert isinstance(self.rg, RollingGen), "The rolling strategy relies on the feature of RollingGen"
        self.tool = OnlineToolR(self.exp_name)
        self.ta = TimeAdjuster()
Ejemplo n.º 2
0
 def __init__(self,
              provider_uri="~/.qlib/qlib_data/cn_data",
              region=REG_CN,
              experiment_name="online_srv",
              task_config=task):
     """Initialise qlib and keep the handles used to train and refresh online predictions.

     Args:
         provider_uri: passed to ``qlib.init`` as ``provider_uri``.
         region: passed to ``qlib.init`` as ``region``.
         experiment_name: experiment name given to ``OnlineToolR``.
         task_config: task configuration stored on the instance.
     """
     qlib.init(region=region, provider_uri=provider_uri)
     self.experiment_name = experiment_name
     # Same value as self.experiment_name, read straight from the argument.
     self.online_tool = OnlineToolR(experiment_name)
     self.task_config = task_config
Ejemplo n.º 3
0
    def test_update_pred(self):
        """
        Check windowed prediction updates.

        First ``update_online_pred`` is called with a ``to_date`` 10 days past the last calendar
        date (the test only proceeds if that does not raise). Then two disjoint date windows of
        ``pred.pkl`` are overwritten with sentinel values, and the update is re-run with an
        explicit ``from_date``/``to_date``. The test expects only the window inside
        ``[from_date, to_date]`` to be restored.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)

        task["record"] = ["qlib.workflow.record_temp.SignalRecord"]

        exp_name = "online_srv_test"

        cal = D.calendar()
        latest_date = cal[-1]

        train_start = latest_date - pd.Timedelta(days=61)
        train_end = latest_date - pd.Timedelta(days=41)
        task["dataset"]["kwargs"]["segments"] = {
            "train": (train_start, train_end),
            "valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
            "test": (latest_date - pd.Timedelta(days=20), latest_date),
        }

        task["dataset"]["kwargs"]["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model

        # to_date is past the end of the calendar on purpose; this must not raise.
        online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))

        # Baseline predictions to compare the partial update against.
        good_pred = rec.load_object("pred.pkl")

        # mod_range coincides with the [from_date, to_date] window of the update below;
        # mod_range2 lies entirely after that window.
        mod_range = slice(latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10))
        mod_range2 = slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2))
        mod_pred = good_pred.copy()

        # Corrupt both windows with distinguishable sentinel values.
        mod_pred.loc[mod_range] = -1
        mod_pred.loc[mod_range2] = -2

        rec.save_objects(**{"pred.pkl": mod_pred})
        online_tool.update_online_pred(
            to_date=latest_date - pd.Timedelta(days=10), from_date=latest_date - pd.Timedelta(days=20)
        )

        updated_pred = rec.load_object("pred.pkl")

        # Inside the update window: re-predicted, so the -1 sentinels are replaced by the baseline.
        self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())
        # Outside the update window: left untouched, so the -2 sentinels remain.
        self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())
Ejemplo n.º 4
0
    def test_update_label(self):
        """
        Check that ``LabelUpdater`` brings ``label.pkl`` up to the latest prediction date.

        The test segment ends ``shift`` calendar entries before the end of the calendar, so
        ``update_online_pred`` has dates to append. After that update the label is expected to
        lag behind the prediction, and after ``LabelUpdater.update`` its last date is expected
        to equal the last prediction date.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)

        task["record"] = {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": {
                "dataset": "<DATASET>",
                "model": "<MODEL>"
            },
        }

        exp_name = "online_srv_test"

        cal = D.calendar()
        shift = 10
        # Leave `shift` calendar entries after the test segment so predictions can be rolled forward.
        latest_date = cal[-1 - shift]

        train_start = latest_date - pd.Timedelta(days=61)
        train_end = latest_date - pd.Timedelta(days=41)
        task["dataset"]["kwargs"]["segments"] = {
            "train": (train_start, train_end),
            "valid": (latest_date - pd.Timedelta(days=40),
                      latest_date - pd.Timedelta(days=21)),
            "test": (latest_date - pd.Timedelta(days=20), latest_date),
        }

        task["dataset"]["kwargs"]["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model
        online_tool.update_online_pred()

        new_pred = rec.load_object("pred.pkl")
        label = rec.load_object("label.pkl")
        label_date = label.dropna().index.get_level_values("datetime").max()
        pred_date = new_pred.dropna().index.get_level_values("datetime").max()

        # The prediction is updated, but the label is not updated.
        self.assertLess(label_date, pred_date)

        # Update label now
        lu = LabelUpdater(rec)
        lu.update()
        new_label = rec.load_object("label.pkl")
        # Only the index is inspected here (no dropna), so rows with NaN labels still count.
        new_label_date = new_label.index.get_level_values("datetime").max()
        # make sure the label is updated now
        self.assertEqual(new_label_date, pred_date)
Ejemplo n.º 5
0
class UpdatePredExample:
    """Train a model once, tag it as online, then refresh its predictions via ``OnlineToolR``."""

    def __init__(self,
                 provider_uri="~/.qlib/qlib_data/cn_data",
                 region=REG_CN,
                 experiment_name="online_srv",
                 task_config=task):
        """
        Args:
            provider_uri: passed to ``qlib.init`` as ``provider_uri``.
            region: passed to ``qlib.init`` as ``region``.
            experiment_name: experiment used for training and by the online tool.
            task_config: task configuration trained by ``first_train``.
        """
        qlib.init(region=region, provider_uri=provider_uri)
        self.experiment_name = experiment_name
        self.online_tool = OnlineToolR(experiment_name)
        self.task_config = task_config

    def first_train(self):
        """Train ``task_config`` in the experiment and tag the resulting recorder as online."""
        trained_recorder = task_train(
            self.task_config,
            experiment_name=self.experiment_name,
        )
        # set to online model
        self.online_tool.reset_online_tag(trained_recorder)

    def update_online_pred(self):
        """Delegate to ``OnlineToolR.update_online_pred`` with its default arguments."""
        self.online_tool.update_online_pred()

    def main(self):
        """Run the whole example: initial training, then a prediction update."""
        self.first_train()
        self.update_online_pred()
Ejemplo n.º 6
0
    def test_update_pred(self):
        """
        Check that ``update_online_pred`` does not raise when ``to_date`` is beyond the calendar.

        ``to_date`` is set 10 days after the last calendar date. The test makes no assertion:
        it passes as long as the update completes without an exception.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)

        task["record"] = {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": {
                "dataset": "<DATASET>",
                "model": "<MODEL>"
            },
        }

        exp_name = "online_srv_test"

        cal = D.calendar()
        latest_date = cal[-1]

        train_start = latest_date - pd.Timedelta(days=61)
        train_end = latest_date - pd.Timedelta(days=41)
        task["dataset"]["kwargs"]["segments"] = {
            "train": (train_start, train_end),
            "valid": (latest_date - pd.Timedelta(days=40),
                      latest_date - pd.Timedelta(days=21)),
            "test": (latest_date - pd.Timedelta(days=20), latest_date),
        }

        task["dataset"]["kwargs"]["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model

        # Deliberately past the end of the calendar; must not raise.
        online_tool.update_online_pred(to_date=latest_date +
                                       pd.Timedelta(days=10))
Ejemplo n.º 7
0
class RollingStrategy(OnlineStrategy):
    """
    This example strategy always uses the latest rolling models as online models.
    """
    def __init__(
        self,
        name_id: str,
        task_template: Union[dict, List[dict]],
        rolling_gen: RollingGen,
    ):
        """
        Init RollingStrategy.

        Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.

        Args:
            name_id (str): a unique name or id. Will be also the name of the Experiment.
            task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
            rolling_gen (RollingGen): an instance of RollingGen

        Raises:
            AssertionError: if ``rolling_gen`` is not a ``RollingGen`` instance.
        """
        super().__init__(name_id=name_id)
        # The experiment shares its name with the strategy (see the assumption above).
        self.exp_name = self.name_id
        # Accept a single template for convenience; internally always keep a list.
        if not isinstance(task_template, list):
            task_template = [task_template]
        self.task_template = task_template
        self.rg = rolling_gen
        assert isinstance(
            self.rg, RollingGen
        ), "The rolling strategy relies on the feature of RollingGen"
        self.tool = OnlineToolR(self.exp_name)
        self.ta = TimeAdjuster()

    def get_collector(self,
                      process_list=None,
                      rec_key_func=None,
                      rec_filter_func=None,
                      artifacts_key=None):
        """
        Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.

        Assumption: the models can be distinguished based on the model name and rolling test segments.
        If you do not want this assumption, please implement your method or use another rec_key_func.

        Args:
            process_list (list, optional): passed to ``RecorderCollector`` as ``process_list``. If None, a fresh ``[RollingGroup()]`` is used.
            rec_key_func (Callable, optional): a function to get the key of a recorder. If None, the key is ``(model class, test segment)`` read from the recorder's saved ``task``.
            rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
            artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.

        Returns:
            RecorderCollector: a collector over this strategy's experiment.
        """
        if process_list is None:
            # Built per call: a list default in the signature would be evaluated once and
            # its single RollingGroup instance shared by every collector.
            process_list = [RollingGroup()]

        def rec_key(recorder):
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        if rec_key_func is None:
            rec_key_func = rec_key

        artifacts_collector = RecorderCollector(
            experiment=self.exp_name,
            process_list=process_list,
            rec_key_func=rec_key_func,
            rec_filter_func=rec_filter_func,
            artifacts_key=artifacts_key,
        )

        return artifacts_collector

    def first_tasks(self) -> List[dict]:
        """
        Use rolling_gen to generate different tasks based on task_template.

        Returns:
            List[dict]: a list of tasks
        """
        return task_generator(
            tasks=self.task_template,
            generators=self.rg,  # generate different date segment
        )

    def prepare_tasks(self, cur_time) -> List[dict]:
        """
        Prepare new tasks based on cur_time (None for the latest).

        Takes the online recorders whose test segment is the latest one and asks
        ``rolling_gen.gen_following_tasks`` for the tasks that follow each of them.

        You can find the last online models by OnlineToolR.online_models.

        Returns:
            List[dict]: a list of new tasks (empty when there are no online recorders).
        """
        # TODO: filtering recorders by the latest test segment is not strictly necessary
        latest_records, max_test = self._list_latest(self.tool.online_models())
        if max_test is None:
            self.logger.warning("No latest online recorders, no new tasks.")
            return []
        calendar_latest = transform_end_date(cur_time)
        self.logger.info(
            f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
        )
        res = []
        for rec in latest_records:
            task = rec.load_object("task")
            res.extend(self.rg.gen_following_tasks(task, calendar_latest))
        return res

    def _list_latest(self, rec_list: List[Recorder]):
        """
        List the latest recorders from rec_list.

        "Latest" means the recorders whose task has the maximum ``segments["test"]`` value
        (compared with ``max``; for ``(start, end)`` tuples the start is compared first).

        Args:
            rec_list (List[Recorder]): a list of Recorder

        Returns:
            List[Recorder], test segment: the latest recorders and their shared test segment;
            ``(rec_list, None)`` when rec_list is empty.
        """
        if not rec_list:
            return rec_list, None
        # Load each recorder's task artifact only once; it is needed for both the max and the filter.
        test_segments = [
            rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"]
            for rec in rec_list
        ]
        max_test = max(test_segments)
        latest_rec = [
            rec for rec, segment in zip(rec_list, test_segments)
            if segment == max_test
        ]
        return latest_rec, max_test