Exemplo n.º 1
0
    def __init__(
        self,
        strategies: Union[OnlineStrategy, List[OnlineStrategy]],
        trainer: Trainer = None,
        begin_time: Union[str, pd.Timestamp] = None,
        freq="day",
    ):
        """
        Build an OnlineManager around one or more online strategies.

        Args:
            strategies (Union[OnlineStrategy, List[OnlineStrategy]]): a single OnlineStrategy or a list of them.
                Anything that is not a ``list`` is wrapped into a one-element list.
            trainer (Trainer): the trainer used to train tasks. When None, a ``TrainerR()`` is created.
            begin_time (Union[str,pd.Timestamp], optional): the time the manager starts from.
                When None, the maximum of ``D.calendar(freq=freq)`` is used.
            freq (str, optional): data frequency. Defaults to "day".
        """
        self.logger = get_module_logger(self.__class__.__name__)

        # Normalise a lone strategy into a list so the attribute always holds a list.
        self.strategies = strategies if isinstance(strategies, list) else [strategies]
        self.freq = freq

        # Without an explicit start, fall back to the maximum value of the calendar for this frequency.
        start = D.calendar(freq=self.freq).max() if begin_time is None else begin_time
        self.begin_time = pd.Timestamp(start)
        self.cur_time = self.begin_time

        # Record of online models over time; intended layout is
        # {pd.Timestamp: {strategy: [online_models]}}. It starts empty here.
        self.history = {}

        self.trainer = TrainerR() if trainer is None else trainer
        self.signals = None
        self.status = self.STATUS_NORMAL
Exemplo n.º 2
0
 def __init__(
     self,
     provider_uri="~/.qlib/qlib_data/cn_data",
     region=REG_CN,
     task_url="mongodb://10.0.0.4:27017/",
     task_db_name="rolling_db",
     experiment_name="rolling_exp",
     task_pool=None,  # pass a pool name (e.g. "rolling_task") to train with TrainerRM
     task_config=None,
     rolling_step=550,
     rolling_type=RollingGen.ROLL_SD,
 ):
     """
     Initialise qlib and prepare the trainer and rolling generator.

     Args:
         provider_uri (str, optional): forwarded to ``qlib.init``. Defaults to "~/.qlib/qlib_data/cn_data".
         region (optional): forwarded to ``qlib.init``. Defaults to ``REG_CN``.
         task_url (str, optional): stored under "task_url" in the mongo config given to ``qlib.init``.
         task_db_name (str, optional): stored under "task_db_name" in the mongo config given to ``qlib.init``.
         experiment_name (str, optional): experiment name handed to the trainer. Defaults to "rolling_exp".
         task_pool (optional): when None a ``TrainerR`` is used; otherwise a ``TrainerRM`` is built
             with this task pool.
         task_config (optional): task configs to roll. When None, the CSI100 XGBoost and LightGBM
             record configs are used.
         rolling_step (int, optional): ``step`` given to ``RollingGen``. Defaults to 550.
         rolling_type (optional): ``rtype`` given to ``RollingGen``. Defaults to ``RollingGen.ROLL_SD``.
     """
     # TaskManager config
     if task_config is None:
         task_config = [
             CSI100_RECORD_XGBOOST_TASK_CONFIG,
             CSI100_RECORD_LGB_TASK_CONFIG,
         ]
     mongo_conf = {
         "task_url": task_url,
         "task_db_name": task_db_name,
     }
     qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
     self.experiment_name = experiment_name
     # Always define the attribute (None when no pool is used) so that reading
     # self.task_pool later never raises AttributeError.
     self.task_pool = task_pool
     if task_pool is None:
         self.trainer = TrainerR(experiment_name=self.experiment_name)
     else:
         self.trainer = TrainerRM(self.experiment_name, self.task_pool)
     self.task_config = task_config
     self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
Exemplo n.º 3
0
    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        exp_name="rolling_exp",
        task_url="mongodb://10.0.0.4:27017/",  # not necessary when using TrainerR or DelayTrainerR
        task_db_name="rolling_db",  # not necessary when using TrainerR or DelayTrainerR
        task_pool="rolling_task",
        rolling_step=80,
        start_time="2018-09-10",
        end_time="2018-10-31",
        tasks=None,
        trainer="TrainerR",
    ):
        """
        Initialise qlib and assemble the rolling online manager used by this example.

        Args:
            provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
            region (str, optional): the stock region. Defaults to "cn".
            exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
            task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
            task_db_name (str, optional): database name. Defaults to "rolling_db".
            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
            rolling_step (int, optional): the step for rolling. Defaults to 80.
            start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
            end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
            tasks (dict or list[dict]): a set of the task config waiting for rolling and training.
                When None, the CSI100 XGBoost and LightGBM online record configs are used.
            trainer (str, optional): "TrainerR" or "TrainerRM". Defaults to "TrainerR".

        Raises:
            NotImplementedError: if ``trainer`` is neither "TrainerRM" nor "TrainerR".
        """
        task_templates = (
            [
                CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE,
                CSI100_RECORD_LGB_TASK_CONFIG_ONLINE,
            ]
            if tasks is None
            else tasks
        )
        self.exp_name, self.task_pool = exp_name, task_pool
        self.start_time, self.end_time = start_time, end_time

        qlib.init(
            provider_uri=provider_uri,
            region=region,
            mongo={"task_url": task_url, "task_db_name": task_db_name},
        )

        # The rolling tasks generator. ds_extra_mod_func is None because we only need to simulate
        # up to 2018-10-31 and needn't change the handler end time.
        self.rolling_gen = RollingGen(
            step=rolling_step,
            rtype=RollingGen.ROLL_SD,
            ds_extra_mod_func=None,
        )

        if trainer not in ("TrainerRM", "TrainerR"):
            # TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR
            raise NotImplementedError("This type of input is not supported")
        self.trainer = (
            TrainerRM(self.exp_name, self.task_pool)
            if trainer == "TrainerRM"
            else TrainerR(self.exp_name)
        )

        rolling_strategy = RollingStrategy(
            exp_name,
            task_template=task_templates,
            rolling_gen=self.rolling_gen,
        )
        self.rolling_online_manager = OnlineManager(
            rolling_strategy,
            trainer=self.trainer,
            begin_time=self.start_time,
        )
        self.tasks = task_templates
Exemplo n.º 4
0
class OnlineSimulationExample:
    """Example that simulates rolling online serving and backtests the resulting signals."""

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        exp_name="rolling_exp",
        task_url="mongodb://10.0.0.4:27017/",  # not necessary when using TrainerR or DelayTrainerR
        task_db_name="rolling_db",  # not necessary when using TrainerR or DelayTrainerR
        task_pool="rolling_task",
        rolling_step=80,
        start_time="2018-09-10",
        end_time="2018-10-31",
        tasks=None,
        trainer="TrainerR",
    ):
        """
        Init OnlineSimulationExample.

        Args:
            provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
            region (str, optional): the stock region. Defaults to "cn".
            exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
            task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
            task_db_name (str, optional): database name. Defaults to "rolling_db".
            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
            rolling_step (int, optional): the step for rolling. Defaults to 80.
            start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
            end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
            tasks (dict or list[dict]): a set of the task config waiting for rolling and training.
                When None, the CSI100 XGBoost and LightGBM online record configs are used.
            trainer (str, optional): which trainer to build, "TrainerR" or "TrainerRM". Defaults to "TrainerR".

        Raises:
            NotImplementedError: if ``trainer`` is neither "TrainerRM" nor "TrainerR".
        """
        if tasks is None:
            tasks = [
                CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE,
                CSI100_RECORD_LGB_TASK_CONFIG_ONLINE,
            ]
        self.exp_name = exp_name
        self.task_pool = task_pool
        self.start_time = start_time
        self.end_time = end_time
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        # The rolling tasks generator. ds_extra_mod_func is None because we only need to simulate
        # up to 2018-10-31 and needn't change the handler end time.
        self.rolling_gen = RollingGen(
            step=rolling_step,
            rtype=RollingGen.ROLL_SD,
            ds_extra_mod_func=None,
        )
        if trainer == "TrainerRM":
            self.trainer = TrainerRM(self.exp_name, self.task_pool)
        elif trainer == "TrainerR":
            self.trainer = TrainerR(self.exp_name)
        else:
            # TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR
            raise NotImplementedError("This type of input is not supported")
        self.rolling_online_manager = OnlineManager(
            RollingStrategy(
                exp_name,
                task_template=tasks,
                rolling_gen=self.rolling_gen,
            ),
            trainer=self.trainer,
            begin_time=self.start_time,
        )
        self.tasks = tasks

    def reset(self):
        """
        Reset everything to the initial state; be careful to save important data first.

        Removes the task pool when the trainer is a TrainerRM, then deletes every
        recorder listed by the experiment named ``self.exp_name``.
        """
        if isinstance(self.trainer, TrainerRM):
            TaskManager(self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.exp_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    def main(self):
        """
        Run the whole workflow: reset, simulate, print collected results and signals,
        then backtest the signals and pretty-print the risk analysis.
        """
        print("========== reset ==========")
        self.reset()
        print("========== simulate ==========")
        self.rolling_online_manager.simulate(end_time=self.end_time)
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        signals = self.rolling_online_manager.get_signals()
        print(signals)

        # Backtesting
        # - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html
        # NOTE(review): an unused local `CSI300_BENCH = "SH000903"` used to be defined here. No
        # benchmark is passed to backtest_daily, so whatever default it has applies. Pass it
        # explicitly if a specific index was intended - TODO confirm.
        strategy_config = {
            "topk": 30,
            "n_drop": 3,
            "signal": signals.to_frame("score"),
        }
        strategy_obj = TopkDropoutStrategy(**strategy_config)
        signal_dates = signals.index.get_level_values("datetime")
        # Only the report is analysed below; the returned positions are not used.
        report_normal, _positions_normal = backtest_daily(
            start_time=signal_dates.min(),
            end_time=signal_dates.max(),
            strategy=strategy_obj,
        )

        excess_return = report_normal["return"] - report_normal["bench"]
        analysis = {
            "excess_return_without_cost": risk_analysis(excess_return),
            "excess_return_with_cost": risk_analysis(excess_return - report_normal["cost"]),
        }

        analysis_df = pd.concat(analysis)  # type: pd.DataFrame
        pprint(analysis_df)

    def worker(self):
        """
        Train tasks from another process or machine (for multiprocessing).

        Only a TrainerRM trainer is supported; for any other trainer a message is printed instead.
        """
        # FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
        print("========== worker ==========")
        if isinstance(self.trainer, TrainerRM):
            self.trainer.worker()
        else:
            print(f"{type(self.trainer)} is not supported for worker.")
0
 def train_rolling_tasks(self, task_l=None):
     """
     Train rolling tasks with a TrainerR bound to ``self.rolling_exp``.

     Args:
         task_l (optional): the tasks to train. When None, they are produced by
             ``self.create_rolling_tasks()``.
     """
     pending_tasks = self.create_rolling_tasks() if task_l is None else task_l
     # The trainer instance is callable; invoking it with the tasks runs the training.
     TrainerR(experiment_name=self.rolling_exp)(pending_tasks)