コード例 #1
0
ファイル: task_manager_rolling.py プロジェクト: you-n-g/qlib
 def __init__(
     self,
     provider_uri="~/.qlib/qlib_data/cn_data",
     region=REG_CN,
     task_url="mongodb://10.0.0.4:27017/",
     task_db_name="rolling_db",
     experiment_name="rolling_exp",
     task_pool="rolling_task",
     task_config=None,
     rolling_step=550,
     rolling_type=RollingGen.ROLL_SD,
 ):
     """Initialise qlib with a MongoDB task store and prepare a rolling task generator.

     Args:
         provider_uri: qlib data directory passed to ``qlib.init``.
         region: market region passed to ``qlib.init``.
         task_url: MongoDB URL placed in the ``mongo`` config of ``qlib.init``.
         task_db_name: MongoDB database name placed in the same config.
         experiment_name: stored on ``self.experiment_name``.
         task_pool: stored on ``self.task_pool``.
         task_config: list of task configs; ``None`` selects the CSI100 XGBoost
             and LightGBM record task configs.
         rolling_step: ``step`` of the ``RollingGen``.
         rolling_type: ``rtype`` of the ``RollingGen``.
     """
     # A new default list is built per call so instances never share it.
     tasks = (
         [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
         if task_config is None
         else task_config
     )
     qlib.init(
         provider_uri=provider_uri,
         region=region,
         mongo={"task_url": task_url, "task_db_name": task_db_name},
     )
     self.experiment_name = experiment_name
     self.task_pool = task_pool
     self.task_config = tasks
     self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
コード例 #2
0
ファイル: workflow.py プロジェクト: yutiansut/qlib
 def _init_qlib(self):
     """Fetch the daily CN qlib dataset if needed, then initialise qlib on it."""
     target_dir = "~/.qlib/qlib_data/cn_data"
     # exists_skip=True presumably skips the download when data is already there.
     GetData().qlib_data(target_dir=target_dir, region=REG_CN, exists_skip=True)
     qlib.init(provider_uri=target_dir, region=REG_CN)
コード例 #3
0
ファイル: example.py プロジェクト: yutiansut/qlib
 def setUp(self):
     """
     Configure qlib to read 1-minute features from an Arctic store.

     Features come from ``ArcticFeatureProvider`` at 127.0.0.1, expressions are
     evaluated with ``time2idx`` disabled and datasets are built without time
     alignment.
     """
     arctic_features = {
         "class": "ArcticFeatureProvider",
         "module_path": "qlib.contrib.data.data",
         "kwargs": {"uri": "127.0.0.1"},
     }
     # The order book has no fixed frequency, so it cannot be aligned to a
     # shared fixed-frequency calendar.
     unaligned_datasets = {
         "class": "LocalDatasetProvider",
         "kwargs": {"align_time": False},
     }
     qlib.init(
         provider_uri="~/.qlib/qlib_data/yahoo_cn_1min",
         mem_cache_size_limit=2 * 1024**3,
         mem_cache_type="sizeof",
         kernels=1,
         expression_provider={"class": "LocalExpressionProvider", "kwargs": {"time2idx": False}},
         feature_provider=arctic_features,
         dataset_provider=unaligned_datasets,
     )
     self.stocks_list = ["SZ000725"]
コード例 #4
0
ファイル: test_get_data.py プロジェクト: yelianjin/qlib
 def setUpClass(cls) -> None:
     """Initialise qlib on ``QLIB_DIR`` with expression and dataset caches disabled."""
     qlib.init(provider_uri=str(QLIB_DIR.resolve()), expression_cache=None, dataset_cache=None)
コード例 #5
0
 def __init__(
     self,
     provider_uri="~/.qlib/qlib_data/cn_data",
     region="cn",
     trainer=None,  # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
     task_url="mongodb://10.0.0.4:27017/",  # not necessary when using TrainerR or DelayTrainerR
     task_db_name="rolling_db",  # not necessary when using TrainerR or DelayTrainerR
     rolling_step=550,
     tasks=None,
     add_tasks=None,
 ):
     """Set up rolling online training for a list of task configs.

     Args:
         provider_uri: qlib data directory passed to ``qlib.init``.
         region: market region passed to ``qlib.init``.
         trainer: trainer handed to ``OnlineManager``. ``None`` (the default)
             creates a new ``DelayTrainerRM()`` for this instance.
         task_url: MongoDB URL placed in the ``mongo`` config of ``qlib.init``.
         task_db_name: MongoDB database name placed in the same config.
         rolling_step: ``step`` of the ``RollingGen`` used by every strategy.
         tasks: task configs, one ``RollingStrategy`` each. ``None`` selects
             ``[CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]``.
         add_tasks: extra task configs stored on ``self.add_tasks``. ``None``
             selects ``[CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]``.
     """
     if trainer is None:
         # Built per call: a ``DelayTrainerRM()`` default argument would be evaluated
         # once at definition time and the same trainer shared by every instance.
         trainer = DelayTrainerRM()
     if add_tasks is None:
         add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
     if tasks is None:
         tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
     mongo_conf = {
         "task_url": task_url,  # your MongoDB url
         "task_db_name": task_db_name,  # database name
     }
     qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
     self.tasks = tasks
     self.add_tasks = add_tasks
     self.rolling_step = rolling_step
     # NOTE: Assumption: the model class can specify only one strategy, so it is
     # used as the strategy's name_id.
     strategies = [
         RollingStrategy(
             task["model"]["class"],
             task,
             RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
         )
         for task in tasks
     ]
     self.trainer = trainer
     self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
コード例 #6
0
    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        rolling_step=550,
        tasks=None,
        add_tasks=None,
    ):
        """Set up rolling online management for a list of task configs.

        Args:
            provider_uri: qlib data directory passed to ``qlib.init``.
            region: market region passed to ``qlib.init``.
            task_url: MongoDB URL placed in the ``mongo`` config of ``qlib.init``.
            task_db_name: MongoDB database name placed in the same config.
            rolling_step: ``step`` of the ``RollingGen`` used by every strategy.
            tasks: task configs, one ``RollingStrategy`` each. ``None`` selects
                ``[task_xgboost_config]``.
            add_tasks: extra task configs stored on ``self.add_tasks``. ``None``
                selects ``[task_lgb_config]``.
        """
        # Fresh lists per call: list literals in the signature are created once and
        # would be shared (and mutable) across every instance.
        if tasks is None:
            tasks = [task_xgboost_config]
        if add_tasks is None:
            add_tasks = [task_lgb_config]
        mongo_conf = {
            "task_url": task_url,  # your MongoDB url
            "task_db_name": task_db_name,  # database name
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        self.tasks = tasks
        self.add_tasks = add_tasks
        self.rolling_step = rolling_step
        # NOTE: Assumption: the model class can specify only one strategy, so it is
        # used as the strategy's name_id.
        strategies = [
            RollingStrategy(
                task["model"]["class"],
                task,
                RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
            )
            for task in tasks
        ]

        self.rolling_online_manager = OnlineManager(strategies)
コード例 #7
0
def main(seed, config_file="configs/config_alstm.yaml"):
    """Train the model described by a YAML workflow config using the given seed.

    The seed is written into the model kwargs before the model is built. The
    model's ``logdir`` gets an (currently empty) per-seed suffix.
    """
    # Load the workflow configuration.
    with open(config_file) as fp:
        config = yaml.safe_load(fp)

    model_kwargs = config["task"]["model"]["kwargs"]
    # Per-seed log sub-directories are disabled, so the suffix is empty.
    suffix = ""
    model_kwargs.update({"seed": seed, "logdir": model_kwargs["logdir"] + suffix})

    # initialize workflow
    init_conf = config["qlib_init"]
    qlib.init(provider_uri=init_conf["provider_uri"], region=init_conf["region"])
    dataset = init_instance_by_config(config["task"]["dataset"])
    model = init_instance_by_config(config["task"]["model"])

    # train model
    model.fit(dataset)
コード例 #8
0
ファイル: workflow.py プロジェクト: majiajue/qlib
 def _init_qlib(self):
     """Initialise qlib for 1-minute CN data, fetching the data first if needed.

     ``self.SPEC_CONF`` entries override those of ``HIGH_FREQ_CONFIG``.
     """
     init_config = dict(HIGH_FREQ_CONFIG)
     init_config.update(self.SPEC_CONF)
     target_dir = init_config.get("provider_uri")
     GetData().qlib_data(target_dir=target_dir, interval="1min", region=REG_CN, exists_skip=True)
     qlib.init(**init_config)
コード例 #9
0
ファイル: workflow.py プロジェクト: yi6ei2ifd/qlib
 def _init_qlib(self):
     """Initialise qlib on the daily CN dataset, downloading it when it is missing."""
     # Uses the daily cn_data directory (not the yahoo_cn_1min data).
     provider_uri = "~/.qlib/qlib_data/cn_data"
     data_missing = not exists_qlib_data(provider_uri)
     if data_missing:
         print(f"Qlib data is not found in {provider_uri}")
         GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
     qlib.init(provider_uri=provider_uri, region=REG_CN)
コード例 #10
0
 def __init__(self,
              provider_uri="~/.qlib/qlib_data/cn_data",
              region=REG_CN,
              experiment_name="online_srv",
              task_config=task):
     """Initialise qlib and an ``OnlineToolR`` bound to ``experiment_name``.

     ``task_config`` defaults to the module-level ``task`` object itself (not a copy).
     """
     qlib.init(provider_uri=provider_uri, region=region)
     self.experiment_name = experiment_name
     self.online_tool = OnlineToolR(experiment_name)
     self.task_config = task_config
コード例 #11
0
ファイル: collector.py プロジェクト: ycl010203/qlib
    def _get_old_data(self, qlib_data_dir: [str, Path]):
        """Read ``$close/$factor`` and ``$adjclose/$close`` for all instruments.

        qlib is (re)initialised on ``qlib_data_dir`` with caches disabled. The two
        columns are renamed to ``self._ori_close_field`` and
        ``self._first_close_field``.
        """
        import qlib
        from qlib.data import D

        data_root = Path(qlib_data_dir).expanduser().resolve()
        qlib.init(provider_uri=str(data_root), expression_cache=None, dataset_cache=None)
        old_df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
        old_df.columns = [self._ori_close_field, self._first_close_field]
        return old_df
コード例 #12
0
ファイル: collector.py プロジェクト: ycl010203/qlib
    def _get_all_1d_data(self):
        """Load daily paused/volume/factor/close for all instruments as a flat frame.

        The index is moved into columns named after the configured date and symbol
        fields, and the leading ``$`` is dropped from feature column names.
        """
        import qlib
        from qlib.data import D

        qlib.init(provider_uri=self.qlib_data_1d_dir)
        fields = ["$paused", "$volume", "$factor", "$close"]
        daily = D.features(D.instruments("all"), fields, freq="day")
        daily.reset_index(inplace=True)
        index_renames = {"datetime": self._date_field_name, "instrument": self._symbol_field_name}
        daily.rename(columns=index_renames, inplace=True)
        daily.columns = [name[1:] if name.startswith("$") else name for name in daily.columns]
        return daily
コード例 #13
0
    def setUpClass(cls) -> None:
        """Initialise qlib on the simple CN dataset, downloading it when absent.

        The downloader is imported from the repository's ``scripts`` directory,
        three levels above this file.
        """
        # use default data
        provider_uri = "~/.qlib/qlib_data/cn_data_simple"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            scripts_dir = Path(__file__).resolve().parents[2] / "scripts"
            sys.path.append(str(scripts_dir))
            from get_data import GetData

            GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
        qlib.init(provider_uri=provider_uri, region=REG_CN)
コード例 #14
0
ファイル: test_dump_data.py プロジェクト: yelianjin/qlib
 def setUpClass(cls) -> None:
     """Download sample CSVs, prepare a ``DumpData`` and initialise qlib on ``QLIB_DIR``.

     Stock names are the upper-cased file names with their last four characters
     (the extension) removed.
     """
     GetData().csv_data_cn(SOURCE_DIR)
     TestDumpData.DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR)
     TestDumpData.STOCK_NAMES = [csv_file.name[:-4].upper() for csv_file in SOURCE_DIR.iterdir()]
     qlib.init(provider_uri=str(QLIB_DIR.resolve()), expression_cache=None, dataset_cache=None)
コード例 #15
0
ファイル: highfreq_provider.py プロジェクト: yutiansut/qlib
    def _init_qlib(self, qlib_conf):
        """Initialise qlib for the CN region with the high-frequency custom operators.

        ``qlib_conf`` supplies the remaining keyword arguments. Repeating one of the
        fixed keys below raises ``TypeError`` (duplicate keyword argument).
        """
        fixed_kwargs = {
            "region": REG_CN,
            "auto_mount": False,
            "custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut],
            "expression_cache": None,
        }
        qlib.init(**fixed_kwargs, **qlib_conf)
コード例 #16
0
ファイル: check_dump_bin.py プロジェクト: ailabx/ailabx
    def __init__(
        self,
        qlib_dir: str,
        csv_path: str,
        check_fields: str = None,
        freq: str = "day",
        symbol_field_name: str = "symbol",
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        max_workers: int = 16,
    ):
        """Prepare a comparison between source csv files and dumped qlib bin data.

        Parameters
        ----------
        qlib_dir : str
            qlib data directory; every entry of its ``features`` sub-directory is a symbol
        csv_path : str
            a single csv file, or a directory whose ``*<file_suffix>`` files are used
        check_fields : str, optional
            comma-separated field names (a list is also accepted); by default None, in which
            case the fields come from the ``*.bin`` files of the first entry listed under
            ``qlib_dir/features``
        freq : str, optional
            freq, value from ["day", "1m"], by default "day"
        symbol_field_name: str, optional
            symbol field name, by default "symbol"
        date_field_name: str, optional
            date field name, by default "date"
        file_suffix: str, optional
            csv file suffix, by default ".csv"
        max_workers: int, optional
            max workers, by default 16
        """
        self.qlib_dir = Path(qlib_dir).expanduser()
        feature_dirs = list(self.qlib_dir.joinpath("features").iterdir())
        self.qlib_symbols = sorted(entry.name.lower() for entry in feature_dirs)

        resolved_dir = str(self.qlib_dir.resolve())
        qlib.init(
            provider_uri=resolved_dir,
            mount_path=resolved_dir,
            auto_mount=False,
            redis_port=-1,
        )
        source = Path(csv_path).expanduser()
        self.csv_files = sorted(source.glob(f"*{file_suffix}") if source.is_dir() else [source])

        if check_fields is None:
            # A field name is the part of a bin file name before its first ".".
            fields = [bin_file.name.split(".")[0] for bin_file in feature_dirs[0].glob("*.bin")]
        elif isinstance(check_fields, str):
            fields = check_fields.split(",")
        else:
            fields = check_fields
        self.check_fields = [field.strip() for field in fields]
        self.qlib_fields = [f"${field}" for field in self.check_fields]

        self.freq = freq
        self.file_suffix = file_suffix
        self.max_workers = max_workers
        self.symbol_field_name = symbol_field_name
        self.date_field_name = date_field_name
コード例 #17
0
def fill_1min_using_1d(
    data_1min_dir: [str, Path],
    qlib_data_1d_dir: [str, Path],
    max_workers: int = 16,
    date_field_name: str = "date",
    symbol_field_name: str = "symbol",
):
    """Use 1d data to fill in the symbols that are missing from the 1min data.

    Every instrument that appears in the 1d qlib data (within the date range of the
    existing 1min files) but not among the 1min symbols gets a placeholder csv in
    ``data_1min_dir``. Its rows are built from that instrument's 1d dates via
    ``generate_minutes_calendar_from_daily`` (presumably the minute calendar of those
    days). All value columns are left empty except the symbol column and
    ``paused_num`` (set to 0).

    Parameters
    ----------
    data_1min_dir: str
        1min data dir; existing csv files are read from it and the new csv files are
        written into it
    qlib_data_1d_dir: str
        1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format
    max_workers: int
        passed to ``get_date_range`` (documented as ThreadPoolExecutor(max_workers)), by default 16
    date_field_name: str
        date field name, by default date
    symbol_field_name: str
        symbol field name, by default symbol

    """
    data_1min_dir = Path(data_1min_dir).expanduser().resolve()
    qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()

    # Date range covered by the existing 1min files; it bounds the 1d query below.
    min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)
    # NOTE(review): assumes these names are comparable with qlib instrument names
    # (e.g. same upper case) — confirm against get_symbols.
    symbols_1min = get_symbols(data_1min_dir)

    # Side effect: initialises qlib globally on the 1d data directory.
    qlib.init(provider_uri=str(qlib_data_1d_dir))
    data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day")

    # Instruments present in the 1d data but without a 1min file.
    miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min)
    if not miss_symbols:
        # Nothing to fill: every 1d instrument already has 1min data.
        logger.warning("More symbols in 1min than 1d, no padding required")
        return

    logger.info(f"miss_symbols  {len(miss_symbols)}: {miss_symbols}")
    # An arbitrary existing 1min csv is the template for the column list and for the
    # symbol case convention (would raise IndexError if the dir had no csv files).
    tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0])
    columns = tmp_df.columns
    _si = tmp_df[symbol_field_name].first_valid_index()
    is_lower = tmp_df.loc[_si][symbol_field_name].islower()
    for symbol in tqdm(miss_symbols):
        # Match the case of the existing file names / symbol values.
        if is_lower:
            symbol = symbol.lower()
        # 1d dates of this instrument; the 1d data is looked up by the upper-case name.
        index_1d = data_1d.loc(axis=0)[symbol.upper()].index
        index_1min = generate_minutes_calendar_from_daily(index_1d)
        index_1min.name = date_field_name
        # Empty frame with the template's columns; the date column is dropped here and
        # re-added from the index by reset_index, which puts it first.
        _df = pd.DataFrame(columns=columns, index=index_1min)
        if date_field_name in _df.columns:
            del _df[date_field_name]
        _df.reset_index(inplace=True)
        _df[symbol_field_name] = symbol
        _df["paused_num"] = 0
        _df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
コード例 #18
0
 def _init_qlib(self):
     """Initialise qlib for 1-minute CN data, downloading it when it is missing.

     ``self.SPEC_CONF`` entries override those of ``HIGH_FREQ_CONFIG``. The data
     directory is the merged config's ``provider_uri`` (the original comment called
     it yahoo_cn_1min data; the actual path depends on ``HIGH_FREQ_CONFIG``).
     """
     init_config = dict(HIGH_FREQ_CONFIG)
     init_config.update(self.SPEC_CONF)
     provider_uri = init_config.get("provider_uri")
     if exists_qlib_data(provider_uri):
         qlib.init(**init_config)
         return
     print(f"Qlib data is not found in {provider_uri}")
     GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
     qlib.init(**init_config)
コード例 #19
0
    def __init__(self, recorder_id, experiment_id, provider_uri=r"E:\TDX\cjzq_tdx\vipdoc", region=REG_CN,
                 mlruns_uri='file:D:\\Code\\my_qlib\\examples\\mlruns'):
        """Load an MLflow recorder and derive its artifact directories.

        Args:
            recorder_id: id of the recorder to load.
            experiment_id: id of the experiment that owns the recorder.
            provider_uri: qlib data directory passed to ``qlib.init``.
            region: market region passed to ``qlib.init``.
            mlruns_uri: tracking URI passed to ``R.set_uri``. It was previously
                hard-coded; the default keeps the old value. It is expected to be
                a ``file:`` URI because the first five characters of the
                recorder's uri are dropped to obtain a path.
        """
        self.record_id = recorder_id
        self.experiment_id = experiment_id

        qlib.init(provider_uri=provider_uri, region=region)

        R.set_uri(mlruns_uri)
        self.recorder = R.get_recorder(recorder_id=recorder_id, experiment_id=experiment_id)
        # Drop the 5-character "file:" scheme to turn the URI into a filesystem path.
        self.expr_dir = Path(self.recorder.uri[5:]).joinpath(experiment_id).joinpath(recorder_id)
        self.artifacts_dir = self.expr_dir.joinpath('artifacts')
        self.portfolio_dir = self.artifacts_dir.joinpath('portfolio_analysis')
        self.sig_dir = self.artifacts_dir.joinpath('sig_analysis')
コード例 #20
0
ファイル: cli.py プロジェクト: NTUT-SELab/qlib
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
    """Run a whole qlib workflow described by a YAML config file.

    Args:
        config_path: path of the YAML config with ``qlib_init`` and ``task`` sections.
        experiment_name: experiment name passed to ``task_train``.
        uri_folder: folder under the current working directory used as the
            default MLflow ``file:`` URI when ``qlib_init`` configures no
            ``exp_manager``.
    """
    with open(config_path) as fp:
        config = yaml.load(fp, Loader=yaml.SafeLoader)

    # config the `sys` section
    sys_config(config, config_path)

    qlib_init = config.get("qlib_init")
    if "exp_manager" in qlib_init:
        # Use the configured experiment manager. Also passing ``exp_manager=`` as a
        # keyword would raise TypeError (multiple values for one keyword).
        qlib.init(**qlib_init)
    else:
        exp_manager = C["exp_manager"]
        exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
        qlib.init(**qlib_init, exp_manager=exp_manager)

    task_train(config.get("task"), experiment_name=experiment_name)
コード例 #21
0
    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        exp_name="rolling_exp",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        task_pool="rolling_task",
        rolling_step=80,
        start_time="2018-09-10",
        end_time="2018-10-31",
        tasks=None,
    ):
        """
        Init OnlineManagerExample.

        Args:
            provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
            region (str, optional): the stock region. Defaults to "cn".
            exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
            task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
            task_db_name (str, optional): database name. Defaults to "rolling_db".
            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
            rolling_step (int, optional): the step for rolling. Defaults to 80.
            start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
            end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
            tasks (dict or list[dict], optional): a set of the task config waiting for rolling and training.
                Defaults to None, meaning [task_xgboost_config, task_lgb_config].
        """
        # A fresh list per call: a list literal in the signature would be created once
        # and shared (and mutable) across every instance.
        if tasks is None:
            tasks = [task_xgboost_config, task_lgb_config]
        self.exp_name = exp_name
        self.task_pool = task_pool
        self.start_time = start_time
        self.end_time = end_time
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        # ds_extra_mod_func is None because the simulation only runs to 2018-10-31,
        # so the handler end time does not need to change.
        self.rolling_gen = RollingGen(
            step=rolling_step,
            rtype=RollingGen.ROLL_SD,
            ds_extra_mod_func=None
        )
        self.trainer = DelayTrainerRM(
            self.exp_name,
            self.task_pool)  # Also can be TrainerR, TrainerRM, DelayTrainerR
        self.rolling_online_manager = OnlineManager(
            RollingStrategy(exp_name,
                            task_template=tasks,
                            rolling_gen=self.rolling_gen),
            trainer=self.trainer,
            begin_time=self.start_time,
        )
        self.tasks = tasks
コード例 #22
0
 def _init_qlib(self, exp_folder_name):
     """Fetch the default qlib data (existing data is skipped), then initialise qlib.

     An ``MLflowExpManager`` is used whose ``file:`` URI points at
     ``<cwd>/<exp_folder_name>`` and whose default experiment is "Experiment".
     """
     GetData().qlib_data(exists_skip=True)
     tracking_uri = "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name)
     exp_manager = {
         "class": "MLflowExpManager",
         "module_path": "qlib.workflow.expm",
         "kwargs": {"uri": tracking_uri, "default_exp_name": "Experiment"},
     }
     qlib.init(exp_manager=exp_manager)
コード例 #23
0
 def _init_qlib(self):
     """Ensure v2 daily and 1-minute CN data exist, then initialise qlib on both.

     Each frequency maps to its own provider directory, and dataset/expression
     caches are disabled.
     """
     day_dir = "~/.qlib/qlib_data/cn_data"  # target_dir for daily data
     min_dir = HIGH_FREQ_CONFIG.get("provider_uri")
     shared = {"region": REG_CN, "version": "v2", "exists_skip": True}
     GetData().qlib_data(target_dir=day_dir, **shared)
     GetData().qlib_data(target_dir=min_dir, interval="1min", **shared)
     qlib.init(
         provider_uri={"1min": min_dir, "day": day_dir},
         dataset_cache=None,
         expression_cache=None,
     )
コード例 #24
0
def main(xargs, exp_yaml):
    """Run ``xargs.times`` experiments of ``xargs.alg`` on the dataset configured in ``exp_yaml``.

    Args:
        xargs: parsed CLI arguments; ``gpu``, ``alg``, ``times`` and ``save_dir`` are read.
        exp_yaml: path of the YAML experiment config.

    Raises:
        FileNotFoundError: if ``exp_yaml`` does not exist.
    """
    # Explicit check rather than ``assert``: assertions are stripped under ``python -O``.
    if not Path(exp_yaml).exists():
        raise FileNotFoundError("{:} does not exist.".format(exp_yaml))

    with open(exp_yaml) as fp:
        config = yaml.safe_load(fp)
    config = update_gpu(config, xargs.gpu)
    # config = update_market(config, 'csi300')

    qlib.init(**config.get("qlib_init"))
    dataset_config = config.get("task").get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint("args: {:}".format(xargs))
    pprint(dataset_config)
    pprint(dataset)

    for irun in range(xargs.times):
        run_exp(config.get("task"), dataset, xargs.alg,
                "recorder-{:02d}-{:02d}".format(irun,
                                                xargs.times), xargs.save_dir)
コード例 #25
0
def main(alg_name, market, config, times, save_dir, gpu):
    """Run ``times`` experiments of ``alg_name`` on ``market`` using ``config``.

    The config is adapted to the market and then to the GPU before qlib and the
    dataset are initialised. Results go under ``<save_dir>-<market>``.
    """
    pprint("Run {:}".format(alg_name))
    config = update_gpu(update_market(config, market), gpu)

    qlib.init(**config.get("qlib_init"))
    task_config = config.get("task")
    dataset_config = task_config.get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint(dataset_config)
    pprint(dataset)

    save_path = "{:}-{:}".format(save_dir, market)
    for run_index in range(times):
        recorder_name = "recorder-{:02d}-{:02d}".format(run_index, times)
        run_exp(task_config, dataset, alg_name, recorder_name, save_path)
コード例 #26
0
def init(conf, logging_config=None):
    """Merge ``conf`` into the global config ``C`` and start qlib in "server" mode.

    :param conf: a dict-like object; each entry overrides the same key in ``C``
    :param logging_config: logging config forwarded to ``qlib.init``
    """
    for name, value in conf.items():
        C[name] = value
    # Server settings are read back from C after the merge.
    server_keys = (
        "provider_uri",
        "logging_level",
        "dataset_cache_dir_name",
        "features_cache_dir_name",
        "redis_task_db",
        "redis_port",
        "redis_host",
    )
    server_kwargs = {name: C[name] for name in server_keys}
    qlib.init("server", logging_config=logging_config, **server_kwargs)
コード例 #27
0
def main(xargs, config):
    """Run ``xargs.times`` experiments of ``xargs.alg`` on ``xargs.market``.

    The config is adapted to the market and the GPU before qlib and the dataset
    are initialised. Results go under ``<xargs.save_dir>-<xargs.market>``.
    """
    pprint("Run {:}".format(xargs.alg))
    config = update_gpu(update_market(config, xargs.market), xargs.gpu)

    qlib.init(**config.get("qlib_init"))
    task_config = config.get("task")
    dataset_config = task_config.get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint("args: {:}".format(xargs))
    pprint(dataset_config)
    pprint(dataset)

    save_path = "{:}-{:}".format(xargs.save_dir, xargs.market)
    for run_index in range(xargs.times):
        recorder_name = "recorder-{:02d}-{:02d}".format(run_index, xargs.times)
        run_exp(task_config, dataset, xargs.alg, recorder_name, save_path)
コード例 #28
0
ファイル: evaluate.py プロジェクト: newlyedward/qlib
def t_run():
    """Ad-hoc daily backtest of a stored prediction with ``TopkDropoutStrategy``.

    Loads ``pred.pkl`` from a fixed recorder of the "workflow" experiment, keeps its
    first 9000 rows, backtests 2017-01-01..2020-08-01 on the CSI300 market and
    prints a summary. Always returns 0.
    """
    from ..workflow import R
    from ..constant import REG_CN
    from ..tests.config import CSI300_MARKET
    import qlib

    qlib.init(provider_uri=r"E:\TDX\cjzq_tdx\vipdoc", region=REG_CN)
    recorder = R.get_recorder(recorder_id="b1acc57c97bc471d942961636f3e1b0d", experiment_name="workflow")
    signal = recorder.load_object("pred.pkl").iloc[:9000]

    strategy = {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.signal_strategy",
        "kwargs": {"topk": 10, "n_drop": 3, "signal": signal},
    }
    report_df, positions = backtest_daily(
        start_time="2017-01-01",
        end_time="2020-08-01",
        strategy=strategy,
        exchange_kwargs={"codes": CSI300_MARKET},
    )
    print(report_df.head())
    print(positions.keys())
    first_key = list(positions.keys())[0]
    print(positions[first_key])
    return 0
コード例 #29
0
 def __init__(
     self,
     provider_uri="~/.qlib/qlib_data/cn_data",
     region=REG_CN,
     task_url="mongodb://10.0.0.4:27017/",
     task_db_name="rolling_db",
     experiment_name="rolling_exp",
     task_pool="rolling_task",
     task_config=None,
     rolling_step=550,
     rolling_type=RollingGen.ROLL_SD,
 ):
     """Initialise qlib with a MongoDB task store and prepare a rolling task generator.

     Args:
         provider_uri: qlib data directory passed to ``qlib.init``.
         region: market region passed to ``qlib.init``.
         task_url: MongoDB URL placed in the ``mongo`` config of ``qlib.init``.
         task_db_name: MongoDB database name placed in the same config.
         experiment_name: stored on ``self.experiment_name``.
         task_pool: stored on ``self.task_pool``.
         task_config: list of task configs; ``None`` (the default) selects
             ``[task_xgboost_config, task_lgb_config]``.
         rolling_step: ``step`` of the ``RollingGen``.
         rolling_type: ``rtype`` of the ``RollingGen``.
     """
     # A fresh list per call: a list literal in the signature would be created once
     # and shared (and mutable) across every instance.
     if task_config is None:
         task_config = [task_xgboost_config, task_lgb_config]
     # TaskManager config
     mongo_conf = {
         "task_url": task_url,
         "task_db_name": task_db_name,
     }
     qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
     self.experiment_name = experiment_name
     self.task_pool = task_pool
     self.task_config = task_config
     self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
コード例 #30
0
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
    """Qlib CLI entrance: run a whole quant research workflow defined by a config file.

    The code is located in ``qlib/workflow/cli.py``.

    Args:
        config_path: path of the YAML config with ``qlib_init`` and ``task`` sections.
        experiment_name: experiment name for ``task_train``; a top-level
            ``experiment_name`` entry in the config takes precedence.
        uri_folder: folder under the current working directory used as the
            MLflow ``file:`` URI when ``qlib_init`` has no ``exp_manager``.
    """
    with open(config_path) as fp:
        config = yaml.safe_load(fp)

    # config the `sys` section
    sys_config(config, config_path)

    qlib_init = config.get("qlib_init")
    if "exp_manager" not in qlib_init:
        # No experiment manager configured: fall back to the one from C, storing
        # runs under <cwd>/<uri_folder>.
        exp_manager = C["exp_manager"]
        exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
        qlib.init(**qlib_init, exp_manager=exp_manager)
    else:
        qlib.init(**qlib_init)

    experiment_name = config.get("experiment_name", experiment_name)
    recorder = task_train(config.get("task"), experiment_name=experiment_name)
    recorder.save_objects(config=config)