예제 #1
0
파일: strategy.py 프로젝트: yutiansut/qlib
    def prepare_tasks(self, cur_time) -> List[dict]:
        """
        Prepare new tasks based on cur_time (None for the latest).

        You can find the last online models by OnlineToolR.online_models.

        The online models reported by ``self.tool.online_models()`` are passed to
        ``self._list_latest``, which yields the latest recorders and their test
        segment. For every such recorder the stored "task" object is loaded and
        handed to ``self.rg.gen_following_tasks`` together with the end date
        obtained from ``transform_end_date(cur_time)``.

        Args:
            cur_time: the reference time; converted with ``transform_end_date`` before use.

        Returns:
            List[dict]: a list of new tasks. Empty when ``self._list_latest``
            reports no latest test segment.
        """
        # TODO: filtering recorders by the latest test segments is not strictly necessary
        latest_records, max_test = self._list_latest(self.tool.online_models())
        if max_test is None:
            # `Logger.warn` is a deprecated alias; `warning` is the supported method.
            self.logger.warning("No latest online recorders, no new tasks.")
            return []

        calendar_latest = transform_end_date(cur_time)
        # Computed up front only to keep the log statement readable; the emitted text is unchanged.
        interval = self.ta.cal_interval(calendar_latest, max_test[0])
        self.logger.info(
            f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} "
            f"is {interval}, the rolling step is {self.rg.step}"
        )

        res = []
        for rec in latest_records:
            # Each recorder carries the task config it was trained from.
            task = rec.load_object("task")
            res.extend(self.rg.gen_following_tasks(task, calendar_latest))
        return res
예제 #2
0
    def generate(self, task: dict) -> List[dict]:
        """
        Expand one task template into a list of rolling tasks.

        The method works on a deep copy of ``task``. The copy's test segment is
        rewritten to run from the aligned test start index to ``self.step - 1``
        positions after it (both resolved through ``self.ta.get``); that copy is
        the first element of the result. The remaining elements come from
        ``self.gen_following_tasks``, called with the first task and the end of
        the original test segment (after ``transform_end_date``).

        Parameters
        ----------
        task: dict
            A dict describing a task. For example.

            .. code-block:: python

                DEFAULT_TASK = {
                    "model": {
                        "class": "LGBModel",
                        "module_path": "qlib.contrib.model.gbdt",
                    },
                    "dataset": {
                        "class": "DatasetH",
                        "module_path": "qlib.data.dataset",
                        "kwargs": {
                            "handler": {
                                "class": "Alpha158",
                                "module_path": "qlib.contrib.data.handler",
                                "kwargs": {
                                    "start_time": "2008-01-01",
                                    "end_time": "2020-08-01",
                                    "fit_start_time": "2008-01-01",
                                    "fit_end_time": "2014-12-31",
                                    "instruments": "csi100",
                                },
                            },
                            "segments": {
                                "train": ("2008-01-01", "2014-12-31"),
                                "valid": ("2015-01-01", "2016-12-20"),  # Please avoid leaking the future test data into validation
                                "test": ("2017-01-01", "2020-08-01"),
                            },
                        },
                    },
                    "record": [
                        {
                            "class": "SignalRecord",
                            "module_path": "qlib.workflow.record_temp",
                        },
                    ]
                }

        Returns
        -------
        List[dict]: a list of tasks
        """
        first_task = copy.deepcopy(task)

        # Align the configured segments, then deep-copy so the edits below stay local.
        aligned = self.ta.align_seg(first_task["dataset"]["kwargs"]["segments"])
        segs: dict = copy.deepcopy(aligned)
        test_key = self.test_key

        # Remember where the whole test range ends before the test segment is overwritten.
        final_test_end = transform_end_date(segs[test_key][1])

        # First rolling window: index `first_idx` through `first_idx + step - 1`.
        first_idx = self.ta.align_idx(segs[test_key][0])
        window_start = self.ta.get(first_idx)
        window_end = self.ta.get(first_idx + self.step - 1)
        segs[test_key] = (window_start, window_end)

        # Write the adjusted segments into the first task.
        self._update_task_segs(first_task, segs)

        # The first task, followed by every subsequent rolling task.
        return [first_task, *self.gen_following_tasks(first_task, final_test_end)]