def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
    """Run the workflow described by a YAML configuration file.

    The file is parsed with the safe YAML loader, its ``sys`` section is
    applied through ``sys_config``, qlib is initialised with the file's
    ``qlib_init`` section, and the ``task`` section is handed to
    ``task_train``.

    Parameters
    ----------
    config_path :
        Path of the YAML configuration file.
    experiment_name :
        Experiment name forwarded to ``task_train``.
    uri_folder :
        Folder, relative to the current working directory, used to build the
        ``file:`` URI stored in the experiment manager settings.
    """
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)

    # config the `sys` section
    sys_config(config, config_path)

    # Point the experiment manager taken from `C` at <cwd>/<uri_folder>.
    exp_manager = C["exp_manager"]
    tracking_dir = Path(os.getcwd()).resolve() / uri_folder
    exp_manager["kwargs"]["uri"] = "file:" + str(tracking_dir)

    qlib_init_kwargs = config.get("qlib_init")
    qlib.init(**qlib_init_kwargs, exp_manager=exp_manager)

    task_config = config.get("task")
    task_train(task_config, experiment_name=experiment_name)
def test_update_label(self):
    """Check that ``LabelUpdater`` brings the stored label up to the prediction date.

    A task is trained on a window that ends ``shift`` calendar entries before
    the last calendar date. After the online prediction is updated, the stored
    label must lag behind the prediction; after ``LabelUpdater.update()`` the
    latest label datetime must equal the latest prediction datetime.
    """
    task = copy.deepcopy(CSI300_GBDT_TASK)
    task["record"] = {
        "class": "SignalRecord",
        "module_path": "qlib.workflow.record_temp",
        "kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
    }
    exp_name = "online_srv_test"

    # Anchor the experiment `shift` calendar entries before the last one.
    shift = 10
    calendar = D.calendar()
    latest_date = calendar[-1 - shift]

    def days_before(n_days):
        """Return the timestamp ``n_days`` calendar days before ``latest_date``."""
        return latest_date - pd.Timedelta(days=n_days)

    def last_datetime(obj):
        """Return the maximum of the ``datetime`` index level of ``obj``."""
        return obj.index.get_level_values("datetime").max()

    train_start, train_end = days_before(61), days_before(41)

    dataset_kwargs = task["dataset"]["kwargs"]
    dataset_kwargs["segments"] = {
        "train": (train_start, train_end),
        "valid": (days_before(40), days_before(21)),
        "test": (days_before(20), latest_date),
    }
    dataset_kwargs["handler"]["kwargs"] = {
        "start_time": train_start,
        "end_time": latest_date,
        "fit_start_time": train_start,
        "fit_end_time": train_end,
        "instruments": "csi300",
    }

    rec = task_train(task, exp_name)
    # The freshly trained prediction is loaded here; its value is not used below.
    rec.load_object("pred.pkl")

    online_tool = OnlineToolR(exp_name)
    online_tool.reset_online_tag(rec)  # set to online model
    online_tool.update_online_pred()

    new_pred = rec.load_object("pred.pkl")
    label = rec.load_object("label.pkl")
    label_date = last_datetime(label.dropna())
    pred_date = last_datetime(new_pred.dropna())

    # The prediction is updated, but the label is not updated.
    self.assertTrue(label_date < pred_date)

    # Update label now
    updater = LabelUpdater(rec)
    updater.update()

    new_label = rec.load_object("label.pkl")
    new_label_date = last_datetime(new_label)
    # make sure the label is updated now
    self.assertTrue(new_label_date == pred_date)
def test_update_pred(self):
    """Exercise ``update_online_pred`` with out-of-range and partial date windows.

    First the online prediction is updated with a ``to_date`` ten days past
    the last calendar date; the call is expected to complete without raising.
    Then two date ranges of the stored prediction are overwritten with
    sentinel values and the prediction is updated again for only the first
    range. The test asserts that the first range matches the earlier
    prediction again and that the second range still holds its sentinel.
    """
    task = copy.deepcopy(CSI300_GBDT_TASK)
    task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
    exp_name = "online_srv_test"

    calendar = D.calendar()
    latest_date = calendar[-1]

    def days_before(n_days):
        """Return the timestamp ``n_days`` calendar days before ``latest_date``."""
        return latest_date - pd.Timedelta(days=n_days)

    train_start, train_end = days_before(61), days_before(41)

    dataset_kwargs = task["dataset"]["kwargs"]
    dataset_kwargs["segments"] = {
        "train": (train_start, train_end),
        "valid": (days_before(40), days_before(21)),
        "test": (days_before(20), latest_date),
    }
    dataset_kwargs["handler"]["kwargs"] = {
        "start_time": train_start,
        "end_time": latest_date,
        "fit_start_time": train_start,
        "fit_end_time": train_end,
        "instruments": "csi300",
    }

    rec = task_train(task, exp_name)
    # The freshly trained prediction is loaded here; its value is not used below.
    rec.load_object("pred.pkl")

    online_tool = OnlineToolR(exp_name)
    online_tool.reset_online_tag(rec)  # set to online model

    # `to_date` lies beyond the last calendar date.
    online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))

    good_pred = rec.load_object("pred.pkl")

    # Overwrite two date ranges of the stored prediction with sentinels.
    update_start, update_end = days_before(20), days_before(10)
    mod_range = slice(update_start, update_end)
    mod_range2 = slice(days_before(9), days_before(2))

    mod_pred = good_pred.copy()
    mod_pred.loc[mod_range] = -1
    mod_pred.loc[mod_range2] = -2
    rec.save_objects(**{"pred.pkl": mod_pred})

    # Update again, restricted to the first overwritten range.
    online_tool.update_online_pred(to_date=update_end, from_date=update_start)

    updated_pred = rec.load_object("pred.pkl")

    # The requested range matches the earlier prediction again.
    first_range_matches = updated_pred.loc[mod_range] == good_pred.loc[mod_range]
    self.assertTrue(first_range_matches.all().item())

    # The range outside the requested window still holds its sentinel.
    second_range_untouched = updated_pred.loc[mod_range2] == -2
    self.assertTrue(second_range_untouched.all().item())
def test_update_pred(self):
    """Check that ``update_online_pred`` accepts a ``to_date`` past the calendar end.

    A task is trained up to the last calendar date, tagged as the online
    model, and then updated with a ``to_date`` ten days after that date.
    There are no assertions; the test passes when the call does not raise.
    """
    task = copy.deepcopy(CSI300_GBDT_TASK)
    task["record"] = {
        "class": "SignalRecord",
        "module_path": "qlib.workflow.record_temp",
        "kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
    }
    exp_name = "online_srv_test"

    calendar = D.calendar()
    latest_date = calendar[-1]

    def days_before(n_days):
        """Return the timestamp ``n_days`` calendar days before ``latest_date``."""
        return latest_date - pd.Timedelta(days=n_days)

    train_start, train_end = days_before(61), days_before(41)

    dataset_kwargs = task["dataset"]["kwargs"]
    dataset_kwargs["segments"] = {
        "train": (train_start, train_end),
        "valid": (days_before(40), days_before(21)),
        "test": (days_before(20), latest_date),
    }
    dataset_kwargs["handler"]["kwargs"] = {
        "start_time": train_start,
        "end_time": latest_date,
        "fit_start_time": train_start,
        "fit_end_time": train_end,
        "instruments": "csi300",
    }

    rec = task_train(task, exp_name)
    # The freshly trained prediction is loaded here; its value is not used below.
    rec.load_object("pred.pkl")

    online_tool = OnlineToolR(exp_name)
    online_tool.reset_online_tag(rec)  # set to online model

    # `to_date` lies beyond the last calendar date.
    out_of_range_date = latest_date + pd.Timedelta(days=10)
    online_tool.update_online_pred(to_date=out_of_range_date)
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
    """
    This is a Qlib CLI entrance.
    User can run the whole Quant research workflow defined by a configure file
    - the code is located here ``qlib/workflow/cli.py``

    Parameters
    ----------
    config_path :
        Path of the YAML configuration file.
    experiment_name :
        Default experiment name; replaced by the ``experiment_name`` entry of
        the configuration file when that entry is present.
    uri_folder :
        Folder, relative to the current working directory, used to build the
        ``file:`` URI of the experiment manager. Only used when the
        ``qlib_init`` section does not supply its own ``exp_manager``.
    """
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)

    # config the `sys` section
    sys_config(config, config_path)

    qlib_init_kwargs = config.get("qlib_init")
    if "exp_manager" not in qlib_init_kwargs:
        # No experiment manager configured: reuse the one from `C` and point
        # it at <cwd>/<uri_folder>.
        exp_manager = C["exp_manager"]
        tracking_dir = Path(os.getcwd()).resolve() / uri_folder
        exp_manager["kwargs"]["uri"] = "file:" + str(tracking_dir)
        qlib.init(**qlib_init_kwargs, exp_manager=exp_manager)
    else:
        qlib.init(**qlib_init_kwargs)

    # An experiment name given in the configuration file wins over the argument.
    experiment_name = config.get("experiment_name", experiment_name)

    recorder = task_train(config.get("task"), experiment_name=experiment_name)
    # Store the loaded configuration with the recorder returned by the training run.
    recorder.save_objects(config=config)
def first_train(self):
    """Train ``self.task_config`` under ``self.experiment_name`` and tag the result as online.

    The object returned by ``task_train`` is passed to
    ``self.online_tool.reset_online_tag``.
    """
    recorder = task_train(self.task_config, experiment_name=self.experiment_name)
    # set to online model
    self.online_tool.reset_online_tag(recorder)