Example #1
    def _init_qlib(self):
        """initialize qlib"""
        # use the default daily cn data
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
        qlib.init(provider_uri=provider_uri, region=REG_CN)
Example #2
File: data.py  Project: you-n-g/qlib
    def qlib_data(
        self,
        name="qlib_data",
        target_dir="~/.qlib/qlib_data/cn_data",
        version=None,
        interval="1d",
        region="cn",
        delete_old=True,
        exists_skip=False,
    ):
        """download cn qlib data from remote

        Parameters
        ----------
        target_dir: str
            data save directory
        name: str
            dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
        version: str
            data version, value from [v1, ...], by default None (use script to specify version)
        interval: str
            data freq, value from [1d], by default 1d
        region: str
            data region, value from [cn, us], by default cn
        delete_old: bool
            delete an existing directory, by default True
        exists_skip: bool
            skip the download if qlib data already exists in target_dir, by default False

        Examples
        ---------
        # get 1d data
        python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn

        # get 1min data
        python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --interval 1min --region cn

        """
        if exists_skip and exists_qlib_data(target_dir):
            logger.warning(
                f"Data already exists: {target_dir}, the data download will be skipped\n"
                f"\tIf downloading is required: `exists_skip=False` or `change target_dir`"
            )
            return

        qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))

        def _get_file_name(v):
            return self.QLIB_DATA_NAME.format(
                dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v
            )

        file_name = _get_file_name(qlib_version)
        if not self.check_dataset(file_name, version):
            file_name = _get_file_name("latest")
        self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version)
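
The same download can also be triggered directly from Python instead of through the get_data.py CLI; a minimal sketch using only the parameters documented in the docstring above (the target directory is a placeholder):

from qlib.tests.data import GetData

# download daily cn data into a local directory, skipping if it is already there
GetData().qlib_data(
    name="qlib_data",
    target_dir="~/.qlib/qlib_data/cn_data",  # placeholder path
    interval="1d",
    region="cn",
    exists_skip=True,
)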
Example #3
    @classmethod
    def setUpClass(cls) -> None:
        # use default data
        provider_uri = "~/.qlib/qlib_data/cn_data_simple"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
            from get_data import GetData

            GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
        qlib.init(provider_uri=provider_uri, region=REG_CN)
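
Example #3 shows only the setUpClass hook; a minimal sketch of a complete test case built around it (the class name and the test method are hypothetical, and the REG_CN import path assumes the same qlib version as Example #5):

import sys
import unittest
from pathlib import Path

import qlib
from qlib.config import REG_CN
from qlib.data import D
from qlib.utils import exists_qlib_data


class TestDataSimple(unittest.TestCase):  # hypothetical test case name
    @classmethod
    def setUpClass(cls) -> None:
        # same bootstrap as Example #3: download qlib_data_simple on first run
        provider_uri = "~/.qlib/qlib_data/cn_data_simple"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
            from get_data import GetData

            GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def test_calendar_not_empty(self):  # hypothetical test
        self.assertTrue(len(D.calendar(start_time="2010-01-01", end_time="2010-12-31")) > 0)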
Example #4
    def _init_qlib(self):
        """initialize qlib"""
        # use yahoo_cn_1min data
        QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
        provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
        qlib.init(**QLIB_INIT_CONFIG)
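
Example #4 builds its init config by merging two dicts before calling qlib.init; a minimal sketch of what that merge does, with illustrative stand-in values (these HIGH_FREQ_CONFIG / SPEC_CONF contents are assumptions for demonstration, not qlib's actual high-frequency defaults):

# illustrative stand-ins for the two config dicts merged in Example #4
HIGH_FREQ_CONFIG = {
    "provider_uri": "~/.qlib/qlib_data/yahoo_cn_1min",  # assumed 1min data directory
    "region": "cn",
}
SPEC_CONF = {
    "expression_cache": None,  # assumed override; later dict wins on overlapping keys
}

# dict unpacking: keys in SPEC_CONF override keys in HIGH_FREQ_CONFIG
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **SPEC_CONF}
# qlib.init(**QLIB_INIT_CONFIG) then receives provider_uri, region, expression_cache, ...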
Example #5
import qlib
from qlib.config import REG_CN
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
    backtest as normal_backtest,
    risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData

if __name__ == "__main__":

    # use default data
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        GetData().qlib_data(target_dir=provider_uri, region=REG_CN)

    qlib.init(provider_uri=provider_uri, region=REG_CN)

    market = "csi300"
    benchmark = "SH000300"

    ###################################
    # train model
    ###################################
    data_handler_config = {
        "start_time": "2008-01-01",
        "end_time": "2020-08-01",
        "fit_start_time": "2008-01-01",
Example #6
    def update_data_to_bin(
        self,
        qlib_data_1d_dir: str,
        trading_date: str = None,
        end_date: str = None,
        check_data_length: int = None,
        delay: float = 1,
    ):
        """update yahoo data to bin

        Parameters
        ----------
        qlib_data_1d_dir: str
            the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data

        trading_date: str
            trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
        end_date: str
            end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval (excluding end)
        check_data_length: int
            check data length; if not None and greater than 0, each symbol is considered complete only if its data length is greater than or equal to this value, otherwise it will be fetched again, up to (max_collector_count) times; by default None
        delay: float
            time.sleep(delay), default 1
        Notes
        -----
            If the data in qlib_data_1d_dir is incomplete, np.nan will be filled from the previous trading day up to trading_date

        Examples
        --------
            $ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
        """

        if self.interval.lower() != "1d":
            logger.warning("currently supports 1d data updates: --interval 1d")

        # start/end date
        if trading_date is None:
            trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
            logger.warning(f"trading_date is None, use the current date: {trading_date}")

        if end_date is None:
            end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")

        # download qlib 1d data
        qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
        if not exists_qlib_data(qlib_data_1d_dir):
            GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)

        # download data from yahoo
        # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
        self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
        # NOTE: a larger max_workers setting here would be faster
        self.max_workers = (
            max(multiprocessing.cpu_count() - 2, 1)
            if self.max_workers is None or self.max_workers <= 1
            else self.max_workers
        )
        # normalize data
        self.normalize_data_1d_extend(qlib_data_1d_dir)

        # dump bin
        _dump = DumpDataUpdate(
            csv_path=self.normalize_dir,
            qlib_dir=qlib_data_1d_dir,
            exclude_fields="symbol,date",
            max_workers=self.max_workers,
        )
        _dump.dump()

        # parse index
        _region = self.region.lower()
        if _region not in ["cn", "us"]:
            logger.warning(
                f"Unsupported region: region={_region}, component downloads will be ignored"
            )
            return
        index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"]
        get_instruments = getattr(
            importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments"
        )
        for _index in index_list:
            get_instruments(str(qlib_data_1d_dir), _index)
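
The CLI form documented in the docstring above can be scripted for a routine daily update; a minimal sketch that shells out to that same command (the data directory is a placeholder, and it assumes collector.py is invoked from the yahoo data_collector directory):

import datetime
import subprocess

# update yesterday's bar: end_date is an open interval per the docstring, so yesterday..today covers one day
trading_date = (datetime.date.today() - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
end_date = datetime.date.today().strftime("%Y-%m-%d")

subprocess.run(
    [
        "python", "collector.py", "update_data_to_bin",
        "--qlib_data_1d_dir", "~/.qlib/qlib_data/cn_data",  # placeholder user data dir
        "--trading_date", trading_date,
        "--end_date", end_date,
    ],
    check=True,
)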