def __init__(
    self,
    provider_uri="~/.qlib/qlib_data/cn_data",
    region=REG_CN,
    task_url="mongodb://10.0.0.4:27017/",
    task_db_name="rolling_db",
    experiment_name="rolling_exp",
    task_pool=None,  # pass a pool name such as "rolling_task" to train with TrainerRM
    task_config=None,
    rolling_step=550,
    rolling_type=RollingGen.ROLL_SD,
):
    """
    Initialise qlib and prepare the trainer and the rolling task generator.

    Args:
        provider_uri (str, optional): the provider uri passed to ``qlib.init``.
        region (str, optional): the region passed to ``qlib.init``. Defaults to REG_CN.
        task_url (str, optional): stored as ``task_url`` in the mongo config given to ``qlib.init``.
        task_db_name (str, optional): stored as ``task_db_name`` in the mongo config.
        experiment_name (str, optional): the experiment name handed to the trainer.
        task_pool (str, optional): the task pool name. When None a ``TrainerR`` is
            used, otherwise a ``TrainerRM`` bound to this pool.
        task_config (dict or list[dict], optional): the task templates. Defaults to the
            CSI100 XGBoost and LightGBM record task configs.
        rolling_step (int, optional): the ``step`` given to ``RollingGen``.
        rolling_type (str, optional): the ``rtype`` given to ``RollingGen``.
    """
    # TaskManager config
    if task_config is None:
        task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
    mongo_conf = {
        "task_url": task_url,
        "task_db_name": task_db_name,
    }
    qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)

    self.experiment_name = experiment_name
    # Always define the attribute (possibly None) so that methods reading
    # ``self.task_pool`` never fail with AttributeError when no pool was given.
    self.task_pool = task_pool
    if task_pool is None:
        self.trainer = TrainerR(experiment_name=self.experiment_name)
    else:
        self.trainer = TrainerRM(self.experiment_name, self.task_pool)
    self.task_config = task_config
    self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
def __init__(
    self,
    provider_uri="~/.qlib/qlib_data/cn_data",
    region="cn",
    exp_name="rolling_exp",
    task_url="mongodb://10.0.0.4:27017/",  # not necessary when using TrainerR or DelayTrainerR
    task_db_name="rolling_db",  # not necessary when using TrainerR or DelayTrainerR
    task_pool="rolling_task",
    rolling_step=80,
    start_time="2018-09-10",
    end_time="2018-10-31",
    tasks=None,
):
    """
    Set up qlib, the rolling generator, the trainer and the online manager.

    Args:
        provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
        region (str, optional): the stock region. Defaults to "cn".
        exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
        task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
        task_db_name (str, optional): database name. Defaults to "rolling_db".
        task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB).
            Defaults to "rolling_task".
        rolling_step (int, optional): the step for rolling. Defaults to 80.
        start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
        end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
        tasks (dict or list[dict]): the task configs waiting for rolling and training.
            Defaults to the online CSI100 XGBoost and LightGBM record task configs.
    """
    if tasks is None:
        tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]

    self.exp_name, self.task_pool = exp_name, task_pool
    self.start_time, self.end_time = start_time, end_time

    qlib.init(
        provider_uri=provider_uri,
        region=region,
        mongo={"task_url": task_url, "task_db_name": task_db_name},
    )

    # The rolling tasks generator. ds_extra_mod_func is None because we only
    # simulate up to 2018-10-31 and needn't change the handler end time.
    self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None)

    # TrainerR or DelayTrainerR can be used here as well.
    self.trainer = TrainerRM(self.exp_name, self.task_pool)

    strategy = RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen)
    self.rolling_online_manager = OnlineManager(
        strategy,
        trainer=self.trainer,
        begin_time=self.start_time,
    )
    self.tasks = tasks
def task_training(self, tasks):
    """
    Train the given tasks with the trainer configured on this instance.

    The trainer selected at construction time (``self.trainer``) is reused
    instead of building a new ``TrainerRM`` on every call, so training also
    works when no task pool was configured.

    Args:
        tasks: the task configs to train; passed unchanged to ``self.trainer.train``.
    """
    print("========== task_training ==========")
    self.trainer.train(tasks)
class OnlineSimulationExample:
    """Example that simulates online rolling training with an ``OnlineManager``."""

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        exp_name="rolling_exp",
        task_url="mongodb://10.0.0.4:27017/",  # not necessary when using TrainerR or DelayTrainerR
        task_db_name="rolling_db",  # not necessary when using TrainerR or DelayTrainerR
        task_pool="rolling_task",
        rolling_step=80,
        start_time="2018-09-10",
        end_time="2018-10-31",
        tasks=None,
    ):
        """
        Init OnlineSimulationExample.

        Args:
            provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
            region (str, optional): the stock region. Defaults to "cn".
            exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
            task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
            task_db_name (str, optional): database name. Defaults to "rolling_db".
            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB).
                Defaults to "rolling_task".
            rolling_step (int, optional): the step for rolling. Defaults to 80.
            start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
            end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
            tasks (dict or list[dict]): a set of the task config waiting for rolling and training
        """
        if tasks is None:
            tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
        self.exp_name = exp_name
        self.task_pool = task_pool
        self.start_time = start_time
        self.end_time = end_time
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)

        # The rolling tasks generator. ds_extra_mod_func is None because we only
        # simulate up to 2018-10-31 and needn't change the handler end time.
        self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None)
        # TrainerR or DelayTrainerR can be used here as well.
        self.trainer = TrainerRM(self.exp_name, self.task_pool)
        self.rolling_online_manager = OnlineManager(
            RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
            trainer=self.trainer,
            begin_time=self.start_time,
        )
        self.tasks = tasks

    def reset(self):
        """
        Reset all things to the first status; be careful to save important data.

        Removes the task pool (only when the trainer is a ``TrainerRM``) and
        deletes every recorder of the experiment.
        """
        # Only touch the task pool when the trainer actually uses one, so the
        # example keeps working if the trainer is swapped for a non-TrainerRM one.
        if isinstance(self.trainer, TrainerRM):
            TaskManager(self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.exp_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    def main(self):
        """Run the whole workflow: reset, simulate, then print collected results and signals."""
        print("========== reset ==========")
        self.reset()
        print("========== simulate ==========")
        self.rolling_online_manager.simulate(end_time=self.end_time)
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        print(self.rolling_online_manager.get_signals())

    def worker(self):
        """
        Train tasks from another process or machine for multiprocessing.

        Only supported when the trainer is a ``TrainerRM``; otherwise a
        message is printed and nothing is trained.
        """
        # FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
        print("========== worker ==========")
        if isinstance(self.trainer, TrainerRM):
            self.trainer.worker()
        else:
            print(f"{type(self.trainer)} is not supported for worker.")
class RollingTaskExample:
    """Example that generates rolling tasks, trains them and collects the results."""

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region=REG_CN,
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        experiment_name="rolling_exp",
        task_pool=None,  # pass a pool name such as "rolling_task" to train with TrainerRM
        task_config=None,
        rolling_step=550,
        rolling_type=RollingGen.ROLL_SD,
    ):
        """
        Initialise qlib and prepare the trainer and the rolling task generator.

        Args:
            provider_uri (str, optional): the provider uri passed to ``qlib.init``.
            region (str, optional): the region passed to ``qlib.init``. Defaults to REG_CN.
            task_url (str, optional): stored as ``task_url`` in the mongo config given to ``qlib.init``.
            task_db_name (str, optional): stored as ``task_db_name`` in the mongo config.
            experiment_name (str, optional): the experiment name handed to the trainer.
            task_pool (str, optional): the task pool name. When None a ``TrainerR`` is
                used, otherwise a ``TrainerRM`` bound to this pool.
            task_config (dict or list[dict], optional): the task templates. Defaults to the
                CSI100 XGBoost and LightGBM record task configs.
            rolling_step (int, optional): the ``step`` given to ``RollingGen``.
            rolling_type (str, optional): the ``rtype`` given to ``RollingGen``.
        """
        # TaskManager config
        if task_config is None:
            task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        self.experiment_name = experiment_name
        # Always define the attribute (possibly None) so that later reads of
        # ``self.task_pool`` never fail with AttributeError.
        self.task_pool = task_pool
        if task_pool is None:
            self.trainer = TrainerR(experiment_name=self.experiment_name)
        else:
            self.trainer = TrainerRM(self.experiment_name, self.task_pool)
        self.task_config = task_config
        self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)

    def reset(self):
        """
        Reset all things to the first status; be careful to save important data.

        Removes the task pool (only when the trainer is a ``TrainerRM``) and
        deletes every recorder of the experiment.
        """
        print("========== reset ==========")
        if isinstance(self.trainer, TrainerRM):
            TaskManager(task_pool=self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.experiment_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    def task_generating(self):
        """Generate the tasks from the task templates with the rolling generator, print and return them."""
        print("========== task_generating ==========")
        tasks = task_generator(
            tasks=self.task_config,
            generators=self.rolling_gen,  # generate different date segments
        )
        pprint(tasks)
        return tasks

    def task_training(self, tasks):
        """Train the given tasks with the trainer chosen in ``__init__``."""
        print("========== task_training ==========")
        self.trainer.train(tasks)

    def worker(self):
        """
        Train tasks from another process or machine for multiprocessing.

        NOTE: this is only used for TrainerRM (the original comment states it
        is the same as ``TrainerRM.worker``). With any other trainer there is
        no task pool to consume, so a message is printed and nothing is run.
        """
        print("========== worker ==========")
        if not isinstance(self.trainer, TrainerRM):
            print(f"{type(self.trainer)} is not supported for worker.")
            return
        run_task(task_train, self.task_pool, experiment_name=self.experiment_name)

    def task_collecting(self):
        """Collect the recorders of the experiment, keeping only "LGBModel" results, and print them."""
        print("========== task_collecting ==========")

        def rec_key(recorder):
            # Key every recorder by (model class, test segment) of its saved task config.
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        def my_filter(recorder):
            # only choose the results of "LGBModel"
            model_key, _ = rec_key(recorder)
            return model_key == "LGBModel"

        collector = RecorderCollector(
            experiment=self.experiment_name,
            process_list=RollingGroup(),
            rec_key_func=rec_key,
            rec_filter_func=my_filter,
        )
        print(collector())

    def main(self):
        """Run the whole workflow: reset, generate tasks, train them and collect the results."""
        self.reset()
        tasks = self.task_generating()
        self.task_training(tasks)
        self.task_collecting()