def _init_qlib(self):
    """Initialize qlib with the default CN daily dataset, fetching it on first use."""
    provider_uri = "~/.qlib/qlib_data/cn_data"  # local dataset directory
    if not exists_qlib_data(provider_uri):
        # First run on this machine: download the bundle before initializing.
        print(f"Qlib data is not found in {provider_uri}")
        GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
    qlib.init(provider_uri=provider_uri, region=REG_CN)
def qlib_data(
    self,
    name="qlib_data",
    target_dir="~/.qlib/qlib_data/cn_data",
    version=None,
    interval="1d",
    region="cn",
    delete_old=True,
    exists_skip=False,
):
    """Download a qlib dataset bundle from the remote server.

    Parameters
    ----------
    name: str
        dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
    target_dir: str
        directory the data is saved to
    version: str
        data version, value from [v1, ...], by default None (the script picks the version)
    interval: str
        data frequency, value from [1d], by default 1d
    region: str
        data region, value from [cn, us], by default cn
    delete_old: bool
        delete an existing directory, by default True
    exists_skip: bool
        skip the download when the target already holds qlib data, by default False

    Examples
    --------
    # get 1d data
    python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn

    # get 1min data
    python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --interval 1min --region cn
    """
    # Optionally short-circuit when a dataset is already present locally.
    if exists_skip and exists_qlib_data(target_dir):
        logger.warning(
            f"Data already exists: {target_dir}, the data download will be skipped\n"
            f"\tIf downloading is required: `exists_skip=False` or `change target_dir`"
        )
        return

    # "major.minor" extracted from the installed qlib version string.
    qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))

    def _name_for(ver):
        # The remote file name encodes dataset, region, interval and qlib version.
        return self.QLIB_DATA_NAME.format(
            dataset_name=name,
            region=region.lower(),
            interval=interval.lower(),
            qlib_version=ver,
        )

    candidate = _name_for(qlib_version)
    # No bundle published for this qlib version: fall back to the "latest" bundle.
    if not self.check_dataset(candidate, version):
        candidate = _name_for("latest")
    self._download_data(candidate.lower(), target_dir, delete_old, dataset_version=version)
def setUpClass(cls) -> None:
    """Ensure the simplified CN dataset exists locally, then initialize qlib for the tests."""
    provider_uri = "~/.qlib/qlib_data/cn_data_simple"  # local dataset directory
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        # The downloader lives in the repo's scripts/ directory, which is not an
        # importable package — extend sys.path before importing it.
        scripts_dir = Path(__file__).resolve().parent.parent.parent.joinpath("scripts")
        sys.path.append(str(scripts_dir))
        from get_data import GetData

        GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
    qlib.init(provider_uri=provider_uri, region=REG_CN)
def _init_qlib(self):
    """Initialize qlib for high-frequency (1min) data, downloading the dataset if absent."""
    # Entries in SPEC_CONF override the shared high-frequency defaults.
    init_config = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
    provider_uri = init_config.get("provider_uri")
    if not exists_qlib_data(provider_uri):
        print(f"Qlib data is not found in {provider_uri}")
        GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
    qlib.init(**init_config)
from qlib.contrib.data.handler import Alpha158 from qlib.contrib.strategy.strategy import TopkDropoutStrategy from qlib.contrib.evaluate import ( backtest as normal_backtest, risk_analysis, ) from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData if __name__ == "__main__": # use default data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir if not exists_qlib_data(provider_uri): print(f"Qlib data is not found in {provider_uri}") GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) market = "csi300" benchmark = "SH000300" ################################### # train model ################################### data_handler_config = { "start_time": "2008-01-01", "end_time": "2020-08-01", "fit_start_time": "2008-01-01",
def update_data_to_bin(
    self,
    qlib_data_1d_dir: str,
    trading_date: str = None,
    end_date: str = None,
    check_data_length: int = None,
    delay: float = 1,
):
    """update yahoo data to bin

    Parameters
    ----------
    qlib_data_1d_dir: str
        the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
    trading_date: str
        trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
    end_date: str
        end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
    check_data_length: int
        check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value,
        otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
    delay: float
        time.sleep(delay), default 1

    Notes
    -----
        If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day

    Examples
    -------
        $ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
        # get 1m data
    """
    # Only daily data is supported; warn but continue for other intervals.
    if self.interval.lower() != "1d":
        logger.warning(f"currently supports 1d data updates: --interval 1d")

    # start/end date — default to [today, tomorrow); end is exclusive
    if trading_date is None:
        trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
        logger.warning(f"trading_date is None, use the current date: {trading_date}")

    if end_date is None:
        end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")

    # download qlib 1d data (bootstrap the base dataset if it is missing locally)
    qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
    if not exists_qlib_data(qlib_data_1d_dir):
        GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)

    # download data from yahoo
    # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
    self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
    # NOTE: a larger max_workers setting here would be faster
    self.max_workers = (
        max(multiprocessing.cpu_count() - 2, 1)
        if self.max_workers is None or self.max_workers <= 1
        else self.max_workers
    )
    # normalize data (extends the newly downloaded rows onto the existing 1d data)
    self.normalize_data_1d_extend(qlib_data_1d_dir)

    # dump bin — incrementally writes the normalized CSVs into the qlib binary format
    _dump = DumpDataUpdate(
        csv_path=self.normalize_dir,
        qlib_dir=qlib_data_1d_dir,
        exclude_fields="symbol,date",
        max_workers=self.max_workers,
    )
    _dump.dump()

    # parse index — refresh index constituents; only cn/us have index collectors
    _region = self.region.lower()
    if _region not in ["cn", "us"]:
        logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored")
        return
    index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"]
    # The per-region index collector module is resolved dynamically by region name.
    get_instruments = getattr(
        importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments"
    )
    for _index in index_list:
        get_instruments(str(qlib_data_1d_dir), _index)