Ejemplo n.º 1
0
    def test_update_label(self):
        """Check that ``LabelUpdater`` brings the stored label up to the prediction date.

        A model is trained on a window ending ``shift`` calendar entries before
        the newest one and its prediction is then updated online. The test
        expects the stored label to lag behind the prediction until
        ``LabelUpdater.update`` is called, and to match it afterwards.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)

        task["record"] = {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": {
                "dataset": "<DATASET>",
                "model": "<MODEL>",
            },
        }

        exp_name = "online_srv_test"

        # End the training data `shift` entries before the newest calendar date;
        # presumably this leaves newer dates for the online update to add.
        cal = D.calendar()
        shift = 10
        latest_date = cal[-1 - shift]

        train_start = latest_date - pd.Timedelta(days=61)
        train_end = latest_date - pd.Timedelta(days=41)
        task["dataset"]["kwargs"]["segments"] = {
            "train": (train_start, train_end),
            "valid": (
                latest_date - pd.Timedelta(days=40),
                latest_date - pd.Timedelta(days=21),
            ),
            "test": (latest_date - pd.Timedelta(days=20), latest_date),
        }

        task["dataset"]["kwargs"]["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        # The value is not needed; the call is kept so the test still errors
        # here if the prediction cannot be loaded right after training.
        rec.load_object("pred.pkl")

        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model
        online_tool.update_online_pred()

        new_pred = rec.load_object("pred.pkl")
        label = rec.load_object("label.pkl")
        label_date = label.dropna().index.get_level_values("datetime").max()
        pred_date = new_pred.dropna().index.get_level_values("datetime").max()

        # The prediction is updated, but the label is not updated.
        # assertLess reports both dates on failure, unlike assertTrue(a < b).
        self.assertLess(label_date, pred_date)

        # Update label now
        lu = LabelUpdater(rec)
        lu.update()
        new_label = rec.load_object("label.pkl")
        new_label_date = new_label.index.get_level_values("datetime").max()
        # make sure the label is updated now
        self.assertEqual(new_label_date, pred_date)
Ejemplo n.º 2
0
    def test_update_pred(self):
        """Exercise ``OnlineToolR.update_online_pred`` with explicit date bounds.

        Two things are checked:

        1. Passing a ``to_date`` later than the last calendar date does not raise.
        2. After the stored prediction is overwritten with sentinel values, an
           update restricted by ``from_date``/``to_date`` is expected to restore
           the values inside that window and leave the later sentinels alone.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)

        task["record"] = ["qlib.workflow.record_temp.SignalRecord"]

        exp_name = "online_srv_test"

        cal = D.calendar()
        latest_date = cal[-1]

        def days_before(n):
            # Timestamp `n` days (pd.Timedelta days) before the newest calendar entry.
            return latest_date - pd.Timedelta(days=n)

        train_start = days_before(61)
        train_end = days_before(41)
        task["dataset"]["kwargs"]["segments"] = {
            "train": (train_start, train_end),
            "valid": (days_before(40), days_before(21)),
            "test": (days_before(20), latest_date),
        }

        task["dataset"]["kwargs"]["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        # The value is not needed; the call is kept so the test still errors
        # here if the prediction cannot be loaded right after training.
        rec.load_object("pred.pkl")

        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model

        # `to_date` is 10 days past the last calendar entry; this must not raise.
        online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))

        good_pred = rec.load_object("pred.pkl")

        mod_range = slice(days_before(20), days_before(10))
        mod_range2 = slice(days_before(9), days_before(2))
        mod_pred = good_pred.copy()

        # Overwrite both windows with sentinel values so that a later
        # re-prediction of either window is detectable.
        mod_pred.loc[mod_range] = -1
        mod_pred.loc[mod_range2] = -2

        rec.save_objects(**{"pred.pkl": mod_pred})
        # Request an update bounded to the first tampered window only.
        online_tool.update_online_pred(to_date=days_before(10), from_date=days_before(20))

        updated_pred = rec.load_object("pred.pkl")

        # Inside [from_date, to_date]: the sentinels are expected to be replaced
        # by the original predictions.
        self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())
        # After that window: the sentinels are expected to remain untouched.
        self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())
Ejemplo n.º 3
0
    def test_update_pred(self):
        """Smoke test for ``update_online_pred`` with an out-of-range ``to_date``.

        Trains a model, tags it as the online model and then asks for an update
        up to 10 days past the last calendar date. The test contains no
        assertions: it passes as long as nothing raises.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)

        record_kwargs = {"dataset": "<DATASET>", "model": "<MODEL>"}
        task["record"] = {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": record_kwargs,
        }

        exp_name = "online_srv_test"

        latest_date = D.calendar()[-1]

        def offset(days):
            # Timestamp shifted back from the newest calendar entry by `days` days.
            return latest_date - pd.Timedelta(days=days)

        train_start, train_end = offset(61), offset(41)
        valid_segment = (offset(40), offset(21))
        test_segment = (offset(20), latest_date)

        dataset_kwargs = task["dataset"]["kwargs"]
        dataset_kwargs["segments"] = {
            "train": (train_start, train_end),
            "valid": valid_segment,
            "test": test_segment,
        }
        dataset_kwargs["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        # The loaded prediction is not inspected; only the load itself is kept.
        _ = rec.load_object("pred.pkl")

        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model

        beyond_calendar = latest_date + pd.Timedelta(days=10)
        online_tool.update_online_pred(to_date=beyond_calendar)
Ejemplo n.º 4
0
class UpdatePredExample:
    """Example driver: train a model once, then update its online prediction."""

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region=REG_CN,
        experiment_name="online_srv",
        task_config=task,
    ):
        """Initialise qlib and create the online tool for the experiment.

        Args:
            provider_uri: Passed to ``qlib.init`` as ``provider_uri``.
            region: Passed to ``qlib.init`` as ``region``.
            experiment_name: Name used for both training and the online tool.
            task_config: Task configuration handed to ``task_train``.
        """
        qlib.init(provider_uri=provider_uri, region=region)
        self.experiment_name = experiment_name
        self.task_config = task_config
        self.online_tool = OnlineToolR(experiment_name)

    def first_train(self):
        """Train the configured task and tag the resulting recorder as online."""
        recorder = task_train(
            self.task_config, experiment_name=self.experiment_name
        )
        # set to online model
        self.online_tool.reset_online_tag(recorder)

    def update_online_pred(self):
        """Delegate to the online tool's ``update_online_pred``."""
        self.online_tool.update_online_pred()

    def main(self):
        """Run the whole example: initial training, then the prediction update."""
        self.first_train()
        self.update_online_pred()