def main(xargs, exp_yaml):
    """Run ``xargs.times`` repetitions of one algorithm configured by a YAML file.

    The dataset described under ``task.dataset`` is instantiated once and
    shared by every repetition; only the recorder name differs between runs.

    Args:
        xargs: parsed command-line namespace; its ``alg``, ``market``, ``gpu``,
            ``times`` and ``save_dir`` attributes are read.
        exp_yaml: path to the experiment YAML file. After loading it must
            provide a ``qlib_init`` mapping and a ``task`` mapping containing
            a ``dataset`` entry.

    Raises:
        FileNotFoundError: if ``exp_yaml`` does not exist.
        ValueError: if ``exp_yaml`` does not parse to a mapping (e.g. it is empty).
    """
    # Explicit raise rather than ``assert``: assertions are removed under ``python -O``.
    if not Path(exp_yaml).exists():
        raise FileNotFoundError("{:} does not exist.".format(exp_yaml))

    pprint("Run {:}".format(xargs.alg))
    # Decode explicitly instead of relying on the platform's locale encoding.
    with open(exp_yaml, encoding="utf-8") as fp:
        config = yaml.safe_load(fp)
    # safe_load returns None for an empty document; fail early with a clear message.
    if not isinstance(config, dict):
        raise ValueError("{:} does not contain a YAML mapping.".format(exp_yaml))
    config = update_market(config, xargs.market)
    config = update_gpu(config, xargs.gpu)

    qlib.init(**config.get("qlib_init"))
    task_config = config.get("task")
    dataset_config = task_config.get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint("args: {:}".format(xargs))
    pprint(dataset_config)
    pprint(dataset)

    # Same output directory for every repetition; computed once.
    save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
    for irun in range(xargs.times):
        run_exp(
            task_config,
            dataset,
            xargs.alg,
            "recorder-{:02d}-{:02d}".format(irun, xargs.times),
            save_dir,
        )
# Example #2
def main(alg_name, market, config, times, save_dir, gpu):
    """Run one algorithm ``times`` times on a single, shared dataset.

    Args:
        alg_name: algorithm name; printed and forwarded to ``run_exp``.
        market: market identifier, applied via ``update_market`` and appended
            to the output directory name.
        config: experiment config mapping with ``qlib_init`` and ``task``
            entries (``task`` holds a ``dataset`` entry).
        times: number of repetitions.
        save_dir: output directory prefix; ``-<market>`` is appended to it.
        gpu: GPU setting, applied via ``update_gpu``.
    """
    pprint(f"Run {alg_name}")
    prepared = update_gpu(update_market(config, market), gpu)

    qlib.init(**prepared.get("qlib_init"))
    task = prepared.get("task")
    data_cfg = task.get("dataset")
    shared_dataset = init_instance_by_config(data_cfg)
    for item in (data_cfg, shared_dataset):
        pprint(item)

    output_dir = f"{save_dir}-{market}"
    for run_index in range(times):
        recorder_name = f"recorder-{run_index:02d}-{times:02d}"
        run_exp(task, shared_dataset, alg_name, recorder_name, output_dir)
# Example #3
def main(xargs, config):
    """Run ``xargs.alg`` for ``xargs.times`` repetitions against one shared dataset.

    Args:
        xargs: parsed command-line namespace; ``alg``, ``market``, ``gpu``,
            ``times`` and ``save_dir`` are read from it.
        config: experiment config mapping holding ``qlib_init`` and ``task``
            (with a ``dataset`` entry under ``task``).
    """
    pprint(f"Run {xargs.alg}")
    cfg = update_market(config, xargs.market)
    cfg = update_gpu(cfg, xargs.gpu)

    qlib.init(**cfg.get("qlib_init"))
    task_cfg = cfg.get("task")
    data_cfg = task_cfg.get("dataset")
    shared_dataset = init_instance_by_config(data_cfg)
    for shown in (f"args: {xargs}", data_cfg, shared_dataset):
        pprint(shown)

    for run in range(xargs.times):
        recorder_name = f"recorder-{run:02d}-{xargs.times:02d}"
        output_dir = f"{xargs.save_dir}-{xargs.market}"
        run_exp(task_cfg, shared_dataset, xargs.alg, recorder_name, output_dir)
# Example #4
    args = parser.parse_args()

    if len(args.alg) == 1:
        main(
            args.alg[0],
            args.market,
            alg2configs[args.alg[0]],
            args.times,
            args.save_dir,
            args.gpu,
        )
    elif len(args.alg) > 1:
        assert args.shared_dataset, "Must allow share dataset"
        pprint(args)
        configs = [
            update_gpu(update_market(alg2configs[name], args.market), args.gpu)
            for name in args.alg
        ]
        qlib.init(**configs[0].get("qlib_init"))
        dataset_config = configs[0].get("task").get("dataset")
        dataset = init_instance_by_config(dataset_config)
        pprint(dataset_config)
        pprint(dataset)
        for alg_name, config in zip(args.alg, configs):
            print("Run {:} over {:}".format(alg_name, args.alg))
            for irun in range(args.times):
                run_exp(
                    config.get("task"),
                    dataset,
                    alg_name,
                    "recorder-{:02d}-{:02d}".format(irun, args.times),
# Example #5
def main(xargs):
    """Train and backtest a ``QuantTransformer`` on Alpha360 features ``xargs.times`` times.

    All configs are built inline. qlib is initialised against the local
    ``cn_data`` provider, and the dataset is instantiated once and shared by
    every run. Each run gets its own deep copy of the model config with the
    GPU set via ``update_gpu``.

    Args:
        xargs: parsed command-line namespace; ``market``, ``save_dir``,
            ``times``, ``gpu`` and ``name`` are read from it.
    """
    import copy

    dataset_config = {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha360",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": {
                    "start_time": "2008-01-01",
                    "end_time": "2020-08-01",
                    # Processor statistics are fitted on the train segment's range.
                    "fit_start_time": "2008-01-01",
                    "fit_end_time": "2014-12-31",
                    "instruments": xargs.market,
                    "infer_processors": [
                        {
                            "class": "RobustZScoreNorm",
                            "kwargs": {"fields_group": "feature", "clip_outlier": True},
                        },
                        {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
                    ],
                    "learn_processors": [
                        {"class": "DropnaLabel"},
                        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
                    ],
                    # qlib expression; presumably the forward return between the
                    # next two closes -- confirm against qlib's Ref semantics.
                    "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
                },
            },
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    }

    model_config = {
        "class": "QuantTransformer",
        "module_path": "trade_models",
        "kwargs": {
            "net_config": None,
            "opt_config": None,
            # Placeholder; overridden per run by update_gpu(..., xargs.gpu).
            "GPU": "0",
            "metric": "loss",
        },
    }

    port_analysis_config = {
        "strategy": {
            "class": "TopkDropoutStrategy",
            "module_path": "qlib.contrib.strategy.strategy",
            "kwargs": {
                "topk": 50,
                "n_drop": 5,
            },
        },
        "backtest": {
            "verbose": False,
            "limit_threshold": 0.095,
            "account": 100000000,
            "benchmark": "SH000300",
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    }

    record_config = [
        {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": {},
        },
        {
            "class": "SigAnaRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": dict(ana_long_short=False, ann_scaler=252),
        },
        {
            "class": "PortAnaRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": dict(config=port_analysis_config),
        },
    ]

    provider_uri = "~/.qlib/qlib_data/cn_data"
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
    dataset = init_instance_by_config(dataset_config)
    for irun in range(xargs.times):
        # dict.copy() is shallow: the nested "kwargs" dict would remain shared
        # with model_config (and hence across runs) if update_gpu edits it in
        # place. A deep copy keeps every run's model config independent.
        xmodel_config = update_gpu(copy.deepcopy(model_config), xargs.gpu)
        task_config = dict(
            model=xmodel_config,
            dataset=dataset_config,
            record=record_config,
        )

        run_exp(
            task_config,
            dataset,
            xargs.name,
            "recorder-{:02d}-{:02d}".format(irun, xargs.times),
            save_dir,
        )