Example #1
0
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
    """
    Run the workflow described by a YAML configuration file.

    The file is parsed with the safe YAML loader, its `sys` section is applied
    through ``sys_config``, qlib is initialised from the ``qlib_init`` section
    and the ``task`` section is handed to ``task_train``.

    Parameters
    ----------
    config_path
        Path of the YAML configuration file.
    experiment_name : str
        Experiment name passed to ``task_train``.
    uri_folder : str
        Folder, relative to the current working directory, used to build the
        ``file:`` URI of the default experiment manager. It is ignored when
        the ``qlib_init`` section already provides ``exp_manager``.
    """
    with open(config_path) as fp:
        config = yaml.load(fp, Loader=yaml.SafeLoader)

    # config the `sys` section
    sys_config(config, config_path)

    qlib_init_kwargs = config.get("qlib_init")
    if "exp_manager" in qlib_init_kwargs:
        # The configuration brings its own experiment manager. Passing
        # `exp_manager=` as well would fail with "got multiple values for
        # keyword argument 'exp_manager'", so use the configured one as is.
        qlib.init(**qlib_init_kwargs)
    else:
        # Fall back to the default manager, storing runs under `uri_folder`
        # in the current working directory.
        exp_manager = C["exp_manager"]
        exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
        qlib.init(**qlib_init_kwargs, exp_manager=exp_manager)

    task_train(config.get("task"), experiment_name=experiment_name)
Example #2
0
    def test_update_label(self):
        """
        Check that ``LabelUpdater`` brings the stored label up to the date of
        the latest stored prediction.

        A model is trained on a window that ends ``shift`` calendar entries
        before the end of the calendar, the online prediction is updated, and
        the label is expected to lag behind the prediction until
        ``LabelUpdater.update`` is called.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)

        task["record"] = {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": {
                "dataset": "<DATASET>",
                "model": "<MODEL>"
            },
        }

        exp_name = "online_srv_test"

        cal = D.calendar()
        # Stop `shift` calendar entries before the end of the calendar
        # (presumably so the online update has newer dates to extend to).
        shift = 10
        latest_date = cal[-1 - shift]

        train_start = latest_date - pd.Timedelta(days=61)
        train_end = latest_date - pd.Timedelta(days=41)
        task["dataset"]["kwargs"]["segments"] = {
            "train": (train_start, train_end),
            "valid": (latest_date - pd.Timedelta(days=40),
                      latest_date - pd.Timedelta(days=21)),
            "test": (latest_date - pd.Timedelta(days=20), latest_date),
        }

        task["dataset"]["kwargs"]["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        # The result is not used; the call is kept so the test still fails
        # early if training did not store a prediction.
        rec.load_object("pred.pkl")

        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model
        online_tool.update_online_pred()

        new_pred = rec.load_object("pred.pkl")
        label = rec.load_object("label.pkl")
        label_date = label.dropna().index.get_level_values("datetime").max()
        pred_date = new_pred.dropna().index.get_level_values("datetime").max()

        # The prediction is updated, but the label is not updated.
        # assertLess reports both dates when the check fails.
        self.assertLess(label_date, pred_date)

        # Update label now
        lu = LabelUpdater(rec)
        lu.update()
        new_label = rec.load_object("label.pkl")
        new_label_date = new_label.index.get_level_values("datetime").max()
        # make sure the label is updated now
        self.assertEqual(new_label_date, pred_date)
Example #3
0
    def test_update_pred(self):
        """
        Exercise ``update_online_pred`` in two ways.

        First it is called with a ``to_date`` that lies after the last
        calendar entry. Then the stored prediction is overwritten with
        sentinel values and the update is called again with an explicit
        ``from_date``/``to_date`` window, and the stored prediction is checked
        inside and outside that window.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)
        task["record"] = ["qlib.workflow.record_temp.SignalRecord"]

        exp_name = "online_srv_test"

        latest_date = D.calendar()[-1]

        def days_before(n_days):
            # Date that lies `n_days` calendar days before `latest_date`.
            return latest_date - pd.Timedelta(days=n_days)

        train_start = days_before(61)
        train_end = days_before(41)

        dataset_kwargs = task["dataset"]["kwargs"]
        dataset_kwargs["segments"] = {
            "train": (train_start, train_end),
            "valid": (days_before(40), days_before(21)),
            "test": (days_before(20), latest_date),
        }
        dataset_kwargs["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        # Loaded right after training; the value itself is not inspected.
        _ = rec.load_object("pred.pkl")

        online_tool = OnlineToolR(exp_name)
        # set to online model
        online_tool.reset_online_tag(rec)

        # `to_date` is 10 days past the last calendar entry.
        online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))

        good_pred = rec.load_object("pred.pkl")

        # Two date windows that get overwritten with sentinel values.
        mod_range = slice(days_before(20), days_before(10))
        mod_range2 = slice(days_before(9), days_before(2))

        mod_pred = good_pred.copy()
        mod_pred.loc[mod_range] = -1
        mod_pred.loc[mod_range2] = -2
        rec.save_objects(**{"pred.pkl": mod_pred})

        # Ask for an update restricted to the first window only.
        online_tool.update_online_pred(to_date=days_before(10), from_date=days_before(20))

        updated_pred = rec.load_object("pred.pkl")

        # Inside the requested window the sentinel must be gone: the values
        # are expected to match the earlier good prediction again.
        restored = updated_pred.loc[mod_range] == good_pred.loc[mod_range]
        self.assertTrue(restored.all().item())

        # Outside the requested window the sentinel is expected to remain.
        untouched = updated_pred.loc[mod_range2] == -2
        self.assertTrue(untouched.all().item())
Example #4
0
    def test_update_pred(self):
        """
        Call ``update_online_pred`` with a ``to_date`` beyond the last
        calendar entry.

        The test contains no explicit assertion; it passes as long as no
        exception propagates out of the update call.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)
        task["record"] = {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
        }

        exp_name = "online_srv_test"

        latest_date = D.calendar()[-1]
        one_day = pd.Timedelta(days=1)

        # All windows are expressed as day offsets back from `latest_date`.
        train_start = latest_date - 61 * one_day
        train_end = latest_date - 41 * one_day
        valid_start = latest_date - 40 * one_day
        valid_end = latest_date - 21 * one_day
        test_start = latest_date - 20 * one_day

        dataset_kwargs = task["dataset"]["kwargs"]
        dataset_kwargs["segments"] = {
            "train": (train_start, train_end),
            "valid": (valid_start, valid_end),
            "test": (test_start, latest_date),
        }
        dataset_kwargs["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

        rec = task_train(task, exp_name)

        # Loaded right after training; the value itself is not inspected.
        _ = rec.load_object("pred.pkl")

        online_tool = OnlineToolR(exp_name)
        # set to online model
        online_tool.reset_online_tag(rec)

        # Request predictions up to 10 days past the end of the calendar.
        out_of_range_date = latest_date + 10 * one_day
        online_tool.update_online_pred(to_date=out_of_range_date)
Example #5
0
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
    """
    Qlib CLI entry point.

    Runs the whole quant research workflow defined by a configuration file
    (the code is located in ``qlib/workflow/cli.py``).

    Parameters
    ----------
    config_path
        Path of the YAML configuration file.
    experiment_name : str
        Experiment name handed to ``task_train``; a top-level
        ``experiment_name`` key in the configuration takes precedence.
    uri_folder : str
        Folder, relative to the current working directory, used for the
        ``file:`` URI of the default experiment manager. Not used when the
        ``qlib_init`` section already contains ``exp_manager``.
    """
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)

    # config the `sys` section
    sys_config(config, config_path)

    init_kwargs = config.get("qlib_init")
    if "exp_manager" not in init_kwargs:
        # No manager configured: use the default one and store runs under
        # `uri_folder` in the current working directory.
        default_manager = C["exp_manager"]
        tracking_dir = Path(os.getcwd()).resolve() / uri_folder
        default_manager["kwargs"]["uri"] = "file:" + str(tracking_dir)
        qlib.init(**init_kwargs, exp_manager=default_manager)
    else:
        qlib.init(**init_kwargs)

    # A name given in the configuration overrides the function argument.
    experiment_name = config.get("experiment_name", experiment_name)

    recorder = task_train(config.get("task"), experiment_name=experiment_name)
    # Keep the configuration next to the results of the run.
    recorder.save_objects(config=config)
Example #6
0
 def first_train(self):
     """Train ``self.task_config`` and tag the resulting recorder as online."""
     recorder = task_train(self.task_config, experiment_name=self.experiment_name)
     # set to online model
     self.online_tool.reset_online_tag(recorder)