class RollingOnlineExample:
    """Example workflow driving an ``OnlineManager`` with one ``RollingStrategy`` per task.

    The manager is dumped to ``_ROLLING_MANAGER_PATH`` after every step and
    loaded again at the start of ``routine`` and ``add_strategy``.

    Args:
        provider_uri: Data location forwarded to ``qlib.init``.
        region: Region forwarded to ``qlib.init``.
        task_url: Your MongoDB url, forwarded to ``qlib.init`` in the mongo config.
        task_db_name: Database name, forwarded to ``qlib.init`` in the mongo config.
        rolling_step: ``step`` handed to every ``RollingGen``.
        tasks: Task configs turned into the initial strategies.
            Defaults to ``[task_xgboost_config]``.
        add_tasks: Task configs turned into strategies by ``add_strategy``.
            Defaults to ``[task_lgb_config]``.
    """

    # The OnlineManager is dumped to this file so that it can be loaded again
    # when calling routine() or add_strategy().
    _ROLLING_MANAGER_PATH = ".RollingOnlineExample"

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        rolling_step=550,
        tasks=None,
        add_tasks=None,
    ):
        # Resolve the defaults per call: a list literal in the signature would
        # be created once and shared by every instance.
        if tasks is None:
            tasks = [task_xgboost_config]
        if add_tasks is None:
            add_tasks = [task_lgb_config]

        mongo_conf = {
            "task_url": task_url,  # your MongoDB url
            "task_db_name": task_db_name,  # database name
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)

        # Keep private list copies so that ``self.tasks + self.add_tasks``
        # works for any iterable the caller passed in.
        self.tasks = list(tasks)
        self.add_tasks = list(add_tasks)
        self.rolling_step = rolling_step
        self.rolling_online_manager = OnlineManager(self._build_strategies(self.tasks))

    def _build_strategies(self, tasks):
        """Return one ``RollingStrategy`` per task, named after the task's model class."""
        # NOTE: Assumption: The model class can specify only one strategy
        return [
            RollingStrategy(
                task["model"]["class"],
                task,
                RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
            )
            for task in tasks
        ]

    def reset(self):
        """Reset all things to the first status; be careful to save important data.

        Deletes every recorder of the experiment named after each task's model
        class (for ``tasks`` and ``add_tasks``) and removes the dumped manager file.
        """
        for task in self.tasks + self.add_tasks:
            name_id = task["model"]["class"]
            exp = R.get_exp(experiment_name=name_id)
            for rid in exp.list_recorders():
                exp.delete_recorder(rid)

        try:
            os.remove(self._ROLLING_MANAGER_PATH)
        except FileNotFoundError:
            pass  # nothing has been dumped yet, so there is nothing to clean up

    def first_run(self):
        """Reset, run the manager's first training, print the collected results and dump."""
        print("========== reset ==========")
        self.reset()
        print("========== first_run ==========")
        self.rolling_online_manager.first_train()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def routine(self):
        """Load the dumped manager, run its routine, print results and signals, dump again."""
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== routine ==========")
        self.rolling_online_manager.routine()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        print(self.rolling_online_manager.get_signals())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def add_strategy(self):
        """Load the dumped manager, add the strategies built from ``add_tasks``, dump again."""
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== add strategy ==========")
        strategies = self._build_strategies(self.add_tasks)
        self.rolling_online_manager.add_strategy(strategies=strategies)
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def main(self):
        """Run the whole example: first_run, routine, add_strategy, routine."""
        self.first_run()
        self.routine()
        self.add_strategy()
        self.routine()
class RollingOnlineExample:
    """Example workflow driving an ``OnlineManager`` with one ``RollingStrategy`` per task.

    The manager is dumped to ``_ROLLING_MANAGER_PATH`` after every step and
    loaded again at the start of ``routine`` and ``add_strategy``.

    Args:
        provider_uri: Data location forwarded to ``qlib.init``.
        region: Region forwarded to ``qlib.init``.
        trainer: Trainer handed to the ``OnlineManager``; you can choose from
            TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM. Defaults to a
            new ``DelayTrainerRM()`` per instance.
        task_url: Your MongoDB url; not necessary when using TrainerR or DelayTrainerR.
        task_db_name: Database name; not necessary when using TrainerR or DelayTrainerR.
        rolling_step: ``step`` handed to every ``RollingGen``.
        tasks: Task configs turned into the initial strategies.
            Defaults to ``[CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]``.
        add_tasks: Task configs turned into strategies by ``add_strategy``.
            Defaults to ``[CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]``.
    """

    # The OnlineManager is dumped to this file so that it can be loaded again
    # when calling routine() or add_strategy().
    _ROLLING_MANAGER_PATH = ".RollingOnlineExample"

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        trainer=None,
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        rolling_step=550,
        tasks=None,
        add_tasks=None,
    ):
        # A call in the signature runs once at class-definition time and the
        # resulting trainer would be shared by every instance; build it here.
        if trainer is None:
            trainer = DelayTrainerRM()
        if add_tasks is None:
            add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
        if tasks is None:
            tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]

        mongo_conf = {
            "task_url": task_url,  # your MongoDB url
            "task_db_name": task_db_name,  # database name
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)

        self.tasks = tasks
        self.add_tasks = add_tasks
        self.rolling_step = rolling_step
        self.trainer = trainer
        self.rolling_online_manager = OnlineManager(
            self._build_strategies(tasks), trainer=self.trainer
        )

    def _build_strategies(self, tasks):
        """Return one ``RollingStrategy`` per task, named after the task's model class."""
        # NOTE: Assumption: The model class can specify only one strategy
        return [
            RollingStrategy(
                task["model"]["class"],
                task,
                RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
            )
            for task in tasks
        ]

    def worker(self):
        """Train tasks from another process or machine (for multiprocessing).

        Only trainers that are instances of ``TrainerRM`` are supported; for
        any other trainer a notice is printed and nothing is trained.
        """
        print("========== worker ==========")
        if not isinstance(self.trainer, TrainerRM):
            print(f"{type(self.trainer)} is not supported for worker.")
            return
        for task in self.tasks + self.add_tasks:
            name_id = task["model"]["class"]
            self.trainer.worker(experiment_name=name_id)

    def reset(self):
        """Reset all things to the first status; be careful to save important data.

        For each task in ``tasks`` and ``add_tasks`` this calls
        ``TaskManager(task_pool=<model class>).remove()`` and deletes every
        recorder of the experiment with the same name, then removes the dumped
        manager file.
        """
        for task in self.tasks + self.add_tasks:
            name_id = task["model"]["class"]
            TaskManager(task_pool=name_id).remove()
            exp = R.get_exp(experiment_name=name_id)
            for rid in exp.list_recorders():
                exp.delete_recorder(rid)

        try:
            os.remove(self._ROLLING_MANAGER_PATH)
        except FileNotFoundError:
            pass  # nothing has been dumped yet, so there is nothing to clean up

    def first_run(self):
        """Reset, run the manager's first training, print the collected results and dump."""
        print("========== reset ==========")
        self.reset()
        print("========== first_run ==========")
        self.rolling_online_manager.first_train()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def routine(self):
        """Load the dumped manager, run its routine, print results and signals, dump again."""
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== routine ==========")
        self.rolling_online_manager.routine()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        print(self.rolling_online_manager.get_signals())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def add_strategy(self):
        """Load the dumped manager, add the strategies built from ``add_tasks``, dump again."""
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== add strategy ==========")
        strategies = self._build_strategies(self.add_tasks)
        self.rolling_online_manager.add_strategy(strategies=strategies)
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def main(self):
        """Run the whole example: first_run, routine, add_strategy, routine."""
        self.first_run()
        self.routine()
        self.add_strategy()
        self.routine()