Example #1
    def setUpClass(cls) -> None:
        cn_data_dir = str(QLIB_DIR.joinpath("cn_data").resolve())
        pit_dir = str(SOURCE_DIR.joinpath("pit").resolve())
        pit_normalized_dir = str(
            SOURCE_DIR.joinpath("pit_normalized").resolve())
        GetData().qlib_data(name="qlib_data_simple",
                            target_dir=cn_data_dir,
                            region="cn",
                            delete_old=False,
                            exists_skip=True)
        GetData().qlib_data(name="qlib_data",
                            target_dir=pit_dir,
                            region="pit",
                            delete_old=False,
                            exists_skip=True)

        # NOTE: The commented-out code below does the same thing as the GetData().qlib_data(..., region="pit") call above, but since baostock is not stable for downloading data, we use the prepared offline data instead.
        # bs.login()
        # Run(
        #     source_dir=pit_dir,
        #     interval="quarterly",
        # ).download_data(start="2000-01-01", end="2020-01-01", symbol_regex="^(600519|000725).*")
        # bs.logout()

        Run(
            source_dir=pit_dir,
            normalize_dir=pit_normalized_dir,
            interval="quarterly",
        ).normalize_data()
        DumpPitData(
            csv_path=pit_normalized_dir,
            qlib_dir=cn_data_dir,
        ).dump(interval="quarterly")
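
Once the quarterly data has been dumped, it can be queried through qlib's point-in-time operator. A minimal sketch, assuming the directories above and a quarterly field such as $$roewa_q from the offline PIT dataset (the field name is an assumption):

import qlib
from qlib.data import D

qlib.init(provider_uri=cn_data_dir, region="cn")
# P(...) resolves a point-in-time field as it was known on each trading day
df = D.features(["sh600519"], ["P($$roewa_q)"], start_time="2019-01-01", end_time="2019-07-19", freq="day")
print(df.head())
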
Example #2
 def _init_qlib(self):
     """initialize qlib"""
     provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
     GetData().qlib_data(target_dir=provider_uri,
                         region=REG_CN,
                         exists_skip=True)
     qlib.init(provider_uri=provider_uri, region=REG_CN)
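
After an initialization like the one above, the data API is ready to use. A minimal sanity-check sketch (the field list is illustrative):

from qlib.data import D

# query daily bars for the CSI300 universe; "$close" and "$volume" are standard qlib fields
df = D.features(D.instruments("csi300"), ["$close", "$volume"], start_time="2020-01-01", end_time="2020-01-31")
print(df.head())
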
Example #3
 def _init_qlib(self):
     """initialize qlib"""
     # use cn_data_1min data
     QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
     provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
     GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
     qlib.init(**QLIB_INIT_CONFIG)
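
HIGH_FREQ_CONFIG ships with qlib and carries the 1min provider_uri and region; SPEC_CONF is a per-class override merged on top of it, so its keys win. A hypothetical SPEC_CONF, assuming keys that mirror qlib.init() keyword arguments:

from qlib.config import HIGH_FREQ_CONFIG

# hypothetical override: disable both caches for this test class
SPEC_CONF = {"expression_cache": None, "dataset_cache": None}
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **SPEC_CONF}  # SPEC_CONF keys take precedence
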
Example #4
 def _init_qlib(self):
     """initialize qlib"""
     provider_uri_day = "~/.qlib/qlib_data/cn_data"  # target_dir
     GetData().qlib_data(target_dir=provider_uri_day,
                         region=REG_CN,
                         version="v2",
                         exists_skip=True)
     provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
     GetData().qlib_data(target_dir=provider_uri_1min,
                         interval="1min",
                         region=REG_CN,
                         version="v2",
                         exists_skip=True)
     provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
     qlib.init(provider_uri=provider_uri_map,
               dataset_cache=None,
               expression_cache=None)
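
With a provider_uri map like the one above, a single init serves both frequencies and each query selects one via the freq argument. A short sketch (instrument and dates are illustrative):

from qlib.data import D

df_day = D.features(["SH600519"], ["$close"], start_time="2020-01-02", end_time="2020-01-10", freq="day")
df_min = D.features(["SH600519"], ["$close"], start_time="2020-01-02", end_time="2020-01-03", freq="1min")
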
Example #5
 def _init_qlib(self):
     """initialize qlib"""
     # use yahoo_cn_1min data
     provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
     if not exists_qlib_data(provider_uri):
         print(f"Qlib data is not found in {provider_uri}")
         GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
     qlib.init(provider_uri=provider_uri, region=REG_CN)
Example #6
 def _init_qlib(self):
     """initialize qlib"""
     # use yahoo_cn_1min data
     QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
     provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
     if not exists_qlib_data(provider_uri):
         print(f"Qlib data is not found in {provider_uri}")
         GetData().qlib_data(target_dir=provider_uri,
                             interval="1min",
                             region=REG_CN)
     qlib.init(**QLIB_INIT_CONFIG)
Example #7
    def test_0_qlib_data(self):

        GetData().qlib_data(name="qlib_data_simple",
                            target_dir=QLIB_DIR,
                            region="cn",
                            interval="1d",
                            delete_old=False,
                            exists_skip=True)
        df = D.features(D.instruments("csi300"), self.FIELDS)
        self.assertListEqual(list(df.columns), self.FIELDS,
                             "get qlib data failed")
        self.assertFalse(df.dropna().empty, "get qlib data failed")
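
self.FIELDS is defined elsewhere on the test class. A plausible hypothetical definition that would satisfy both assertions, since the test only checks that the returned columns match FIELDS and contain non-NaN rows:

FIELDS = ["$open", "$high", "$low", "$close", "$volume"]
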
Example #8
 def _init_qlib(self, exp_folder_name):
     # init qlib
     GetData().qlib_data(exists_skip=True)
     qlib.init(
         exp_manager={
             "class": "MLflowExpManager",
             "module_path": "qlib.workflow.expm",
             "kwargs": {
                 "uri":
                 "file:" +
                 str(Path(os.getcwd()).resolve() / exp_folder_name),
                 "default_exp_name":
                 "Experiment",
             },
         })
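
With MLflowExpManager configured this way, run records land under ./<exp_folder_name>. A minimal usage sketch with qlib's workflow recorder (the logged names are illustrative):

from qlib.workflow import R

with R.start(experiment_name="Experiment"):
    R.log_params(seed=42)    # illustrative parameter
    R.log_metrics(ic=0.05)   # illustrative metric
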
Example #9
                "bagging_fraction":
                trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
                "bagging_freq":
                trial.suggest_int("bagging_freq", 1, 7),
                "min_data_in_leaf":
                trial.suggest_int("min_data_in_leaf", 1, 50),
                "min_child_samples":
                trial.suggest_int("min_child_samples", 5, 100),
            },
        },
    }

    evals_result = dict()
    model = init_instance_by_config(task["model"])
    model.fit(dataset, evals_result=evals_result)
    return min(evals_result["valid"])


if __name__ == "__main__":

    provider_uri = "~/.qlib/qlib_data/cn_data"
    GetData().qlib_data(target_dir=provider_uri,
                        region=REG_CN,
                        exists_skip=True)
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    dataset = init_instance_by_config(DATASET_CONFIG)

    study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
    study.optimize(objective, n_jobs=6)
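
A note on API drift: this snippet targets an older Optuna. Recent releases create or load studies via optuna.create_study, and suggest_float replaces the deprecated suggest_uniform; an equivalent sketch (the trial budget is illustrative):

study = optuna.create_study(study_name="LGBM_360", storage="sqlite:///db.sqlite3",
                            load_if_exists=True, direction="minimize")
study.optimize(objective, n_trials=100, n_jobs=6)  # n_trials bounds the otherwise open-ended search
# inside objective(): trial.suggest_float("bagging_fraction", 0.4, 1.0)
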
Example #10
import qlib
from qlib.config import REG_CN
from qlib.contrib.evaluate import (
    backtest as normal_backtest,
    risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData

if __name__ == "__main__":

    # use default data
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

    qlib.init(provider_uri=provider_uri, region=REG_CN)

    market = "csi300"
    benchmark = "SH000300"

    ###################################
    # train model
    ###################################
    data_handler_config = {
        "start_time": "2008-01-01",
        "end_time": "2020-08-01",
        "fit_start_time": "2008-01-01",
        "fit_end_time": "2014-12-31",
        "instruments": market,
Example #11
    def update_data_to_bin(
        self,
        qlib_data_1d_dir: str,
        trading_date: str = None,
        end_date: str = None,
        check_data_length: int = None,
        delay: float = 1,
    ):
        """update yahoo data to bin

        Parameters
        ----------
        qlib_data_1d_dir: str
            the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data

        trading_date: str
            trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
        end_date: str
            end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
        delay: float
            time.sleep(delay), default 1
        Notes
        -----
            If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day

        Examples
        -------
            $ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
            # get 1m data
        """

        if self.interval.lower() != "1d":
            logger.warning("currently only 1d data updates are supported: --interval 1d")

        # start/end date
        if trading_date is None:
            trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
            logger.warning(f"trading_date is None, using the current date: {trading_date}")

        if end_date is None:
            end_date = (pd.Timestamp(trading_date) +
                        pd.Timedelta(days=1)).strftime("%Y-%m-%d")

        # download qlib 1d data
        qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
        if not exists_qlib_data(qlib_data_1d_dir):
            GetData().qlib_data(target_dir=qlib_data_1d_dir,
                                interval=self.interval,
                                region=self.region)

        # download data from yahoo
        # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
        self.download_data(delay=delay,
                           start=trading_date,
                           end=end_date,
                           check_data_length=check_data_length)
        # NOTE: a larger max_workers setting here would be faster
        self.max_workers = (
            max(multiprocessing.cpu_count() - 2, 1)
            if self.max_workers is None or self.max_workers <= 1
            else self.max_workers
        )
        # normalize data
        self.normalize_data_1d_extend(qlib_data_1d_dir)

        # dump bin
        _dump = DumpDataUpdate(
            csv_path=self.normalize_dir,
            qlib_dir=qlib_data_1d_dir,
            exclude_fields="symbol,date",
            max_workers=self.max_workers,
        )
        _dump.dump()

        # parse index
        _region = self.region.lower()
        if _region not in ["cn", "us"]:
            logger.warning(
                f"Unsupported region: region={_region}, component downloads will be ignored"
            )
            return
        index_list = ["CSI100", "CSI300"] if _region == "cn" else [
            "SP500", "NASDAQ100", "DJIA", "SP400"
        ]
        get_instruments = getattr(
            importlib.import_module(
                f"data_collector.{_region}_index.collector"),
            "get_instruments")
        for _index in index_list:
            get_instruments(str(qlib_data_1d_dir), _index)
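
For reference, a sketch of driving this updater from Python rather than the CLI, assuming the Yahoo collector's Run class with the constructor arguments used in this module (paths and dates are placeholders):

# hypothetical invocation mirroring the CLI example in the docstring
updater = Run(source_dir="~/.qlib/stock_data/source",
              normalize_dir="~/.qlib/stock_data/normalize",
              interval="1d", region="cn")
updater.update_data_to_bin(qlib_data_1d_dir="~/.qlib/qlib_data/cn_data",
                           trading_date="2022-06-01", end_date="2022-06-02")
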
Example #12
                        test_start_time,
                        end_time,
                    ),
                },
            },
        },
    }
    """initialize qlib"""
    # use yahoo_cn_1min data
    QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **SPEC_CONF}
    print(QLIB_INIT_CONFIG)
    provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        GetData().qlib_data(target_dir=provider_uri,
                            interval="1min",
                            region=REG_CN)
    qlib.init(**QLIB_INIT_CONFIG)

    Cal.calendar(freq="1min")
    get_calendar_day(freq="1min")

    # get data
    dataset = init_instance_by_config(task["dataset"])
    xtrain, xtest = dataset.prepare(["train", "test"])
    print(xtrain, xtest)
    xtrain.to_csv("xtrain.csv")

    dataset_backtest = init_instance_by_config(task["dataset_backtest"])
    backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
    print(backtest_train, backtest_test)
Example #13
        Training the tasks generated by meta model
        Then evaluate it
        """
        with self._task_path.open("rb") as f:
            tasks = pickle.load(f)
        rb = RollingBenchmark(rolling_exp="rolling_ds",
                              model_type=self.forecast_model)
        rb.train_rolling_tasks(tasks)
        rb.ens_rolling()
        rb.update_rolling_rec()

    def run_all(self):
        # 1) file: handler_proxy.pkl
        self.dump_data_for_proxy_model()
        # 2)
        # file: internal_data_s20.pkl
        # mlflow: data_sim_s20, models for calculating meta_ipt
        self.dump_meta_ipt()
        # 3) meta model will be stored in `DDG-DA`
        self.train_meta_model()
        # 4) new_tasks are saved in "tasks_s20.pkl" (reweighter is added)
        self.meta_inference()
        # 5) load the saved tasks and train model
        self.train_and_eval_tasks()


if __name__ == "__main__":
    GetData().qlib_data(exists_skip=True)
    auto_init()
    fire.Fire(DDGDA)
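
Because the entry point hands DDGDA to fire.Fire, every public method becomes a sub-command (the file name workflow.py is an assumption):

# from the shell:  python workflow.py run_all
#                  python workflow.py train_meta_model
# programmatic equivalent of the full pipeline:
GetData().qlib_data(exists_skip=True)
auto_init()
DDGDA().run_all()
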
Example #14
 def test_1_csv_data(self):
     GetData().csv_data_cn(SOURCE_DIR)
     stock_name = set(
         map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
     self.assertEqual(len(stock_name), 85, "get csv data failed")