# Shared imports for the main() variants below. The project-local helpers
# (update_market, update_gpu, run_exp, alg2configs) are assumed to be defined
# elsewhere in this repository.
from pathlib import Path
from pprint import pprint

import yaml

import qlib
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config


def main(xargs, exp_yaml):
    assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml)
    pprint("Run {:}".format(xargs.alg))

    with open(exp_yaml) as fp:
        config = yaml.safe_load(fp)
    config = update_market(config, xargs.market)
    config = update_gpu(config, xargs.gpu)

    qlib.init(**config.get("qlib_init"))
    dataset_config = config.get("task").get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint("args: {:}".format(xargs))
    pprint(dataset_config)
    pprint(dataset)

    for irun in range(xargs.times):
        run_exp(
            config.get("task"),
            dataset,
            xargs.alg,
            "recorder-{:02d}-{:02d}".format(irun, xargs.times),
            "{:}-{:}".format(xargs.save_dir, xargs.market),
        )
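# Minimal sketch of the two project-local helpers used above. The real
# implementations live elsewhere in this repository; the exact config keys
# touched here are assumptions based on how the configs are shaped in this file.
import copy


def update_market(config, market):
    # Point the dataset handler at the requested instrument universe.
    config = copy.deepcopy(config)
    handler_kwargs = config["task"]["dataset"]["kwargs"]["handler"]["kwargs"]
    handler_kwargs["instruments"] = market
    return config


def update_gpu(config, gpu):
    # Set the GPU field of the model config. Accept either a full task config
    # (as in the YAML-driven variants) or a bare model config (as in the
    # self-contained variant at the end of this file).
    config = copy.deepcopy(config)
    model_kwargs = (
        config["task"]["model"]["kwargs"] if "task" in config else config["kwargs"]
    )
    if "GPU" in model_kwargs:
        model_kwargs["GPU"] = gpu
    return config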
def main(alg_name, market, config, times, save_dir, gpu):
    pprint("Run {:}".format(alg_name))
    config = update_market(config, market)
    config = update_gpu(config, gpu)

    qlib.init(**config.get("qlib_init"))
    dataset_config = config.get("task").get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint(dataset_config)
    pprint(dataset)

    for irun in range(times):
        run_exp(
            config.get("task"),
            dataset,
            alg_name,
            "recorder-{:02d}-{:02d}".format(irun, times),
            "{:}-{:}".format(save_dir, market),
        )
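# Minimal sketch of run_exp(), which every variant in this file calls; the real
# implementation is project-local. It follows qlib's standard workflow: build
# the model from the task config, fit it on the shared dataset, and track the
# run under the given experiment/recorder names. How save_dir is wired into the
# tracking backend is an assumption here.
from qlib.workflow import R


def run_exp(task_config, dataset, experiment_name, recorder_name, save_dir):
    model = init_instance_by_config(task_config["model"])
    with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
        model.fit(dataset)
        R.save_objects(**{"model.pkl": model})
        recorder = R.get_recorder()
        # The "record" entries of the task config (SignalRecord, SigAnaRecord,
        # PortAnaRecord, ...) would be instantiated against this recorder and
        # their .generate() called here.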
def main(xargs, config):
    pprint("Run {:}".format(xargs.alg))
    config = update_market(config, xargs.market)
    config = update_gpu(config, xargs.gpu)

    qlib.init(**config.get("qlib_init"))
    dataset_config = config.get("task").get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint("args: {:}".format(xargs))
    pprint(dataset_config)
    pprint(dataset)

    for irun in range(xargs.times):
        run_exp(
            config.get("task"),
            dataset,
            xargs.alg,
            "recorder-{:02d}-{:02d}".format(irun, xargs.times),
            "{:}-{:}".format(xargs.save_dir, xargs.market),
        )
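# The command-line block below uses a parser and an alg2configs mapping that
# are not shown in this snippet. A minimal sketch consistent with the args.*
# fields used below (the exact flag names, defaults, and the contents of
# alg2configs are assumptions):
import argparse

parser = argparse.ArgumentParser("qlib baseline runner")
parser.add_argument("--alg", type=str, nargs="+", help="algorithm name(s), keys of alg2configs")
parser.add_argument("--market", type=str, default="csi300", help="instrument universe")
parser.add_argument("--times", type=int, default=5, help="number of repeated runs")
parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="output prefix")
parser.add_argument("--gpu", type=int, default=0, help="GPU index")
parser.add_argument("--shared_dataset", action="store_true", help="share one dataset across algorithms")

# alg2configs maps an algorithm name to its YAML-loaded experiment config;
# a plausible construction, assuming one YAML file per algorithm in a
# hypothetical configs/ directory:
alg2configs = {}
for _path in Path("configs").glob("*.yaml"):
    with open(_path) as _fp:
        alg2configs[_path.stem] = yaml.safe_load(_fp)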
args = parser.parse_args()

if len(args.alg) == 1:
    # Single algorithm: hand everything to the explicit-argument variant of main().
    main(
        args.alg[0],
        args.market,
        alg2configs[args.alg[0]],
        args.times,
        args.save_dir,
        args.gpu,
    )
elif len(args.alg) > 1:
    # Multiple algorithms: build one dataset and share it across all runs.
    assert args.shared_dataset, "Must share the dataset when running multiple algorithms"
    pprint(args)
    configs = [
        update_gpu(update_market(alg2configs[name], args.market), args.gpu)
        for name in args.alg
    ]
    qlib.init(**configs[0].get("qlib_init"))
    dataset_config = configs[0].get("task").get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint(dataset_config)
    pprint(dataset)

    for alg_name, config in zip(args.alg, configs):
        print("Run {:} over {:}".format(alg_name, args.alg))
        for irun in range(args.times):
            run_exp(
                config.get("task"),
                dataset,
                alg_name,
                "recorder-{:02d}-{:02d}".format(irun, args.times),
                # The original snippet is truncated here; the final save-path
                # argument is restored following the single-algorithm pattern above.
                "{:}-{:}".format(args.save_dir, args.market),
            )
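# Example invocations (hypothetical script and algorithm names; the real keys
# of alg2configs are not shown in this snippet):
#   python run_baselines.py --alg GRU --market csi300 --times 5 --gpu 0
#   python run_baselines.py --alg GRU LSTM --shared_dataset --market csi300 --times 5 --gpu 0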
# Self-contained variant: the dataset / model / backtest configs are defined
# inline instead of being loaded from YAML.
def main(xargs):
    dataset_config = {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha360",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": {
                    "start_time": "2008-01-01",
                    "end_time": "2020-08-01",
                    "fit_start_time": "2008-01-01",
                    "fit_end_time": "2014-12-31",
                    "instruments": xargs.market,
                    "infer_processors": [
                        {
                            "class": "RobustZScoreNorm",
                            "kwargs": {"fields_group": "feature", "clip_outlier": True},
                        },
                        {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
                    ],
                    "learn_processors": [
                        {"class": "DropnaLabel"},
                        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
                    ],
                    "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
                },
            },
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    }

    model_config = {
        "class": "QuantTransformer",
        "module_path": "trade_models",
        "kwargs": {
            "net_config": None,
            "opt_config": None,
            "GPU": "0",
            "metric": "loss",
        },
    }

    # Portfolio strategy and backtest settings used by the PortAnaRecord below.
    port_analysis_config = {
        "strategy": {
            "class": "TopkDropoutStrategy",
            "module_path": "qlib.contrib.strategy.strategy",
            "kwargs": {"topk": 50, "n_drop": 5},
        },
        "backtest": {
            "verbose": False,
            "limit_threshold": 0.095,
            "account": 100000000,
            "benchmark": "SH000300",
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    }

    record_config = [
        {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": dict(),
        },
        {
            "class": "SigAnaRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": dict(ana_long_short=False, ann_scaler=252),
        },
        {
            "class": "PortAnaRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": dict(config=port_analysis_config),
        },
    ]

    provider_uri = "~/.qlib/qlib_data/cn_data"
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
    dataset = init_instance_by_config(dataset_config)

    for irun in range(xargs.times):
        xmodel_config = model_config.copy()
        xmodel_config = update_gpu(xmodel_config, xargs.gpu)
        task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
        run_exp(
            task_config,
            dataset,
            xargs.name,
            "recorder-{:02d}-{:02d}".format(irun, xargs.times),
            save_dir,
        )
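# Usage sketch for the self-contained variant above (hypothetical values): the
# attribute names mirror the xargs.* fields that main() reads; the defaults
# below are illustrative, not taken from the original script.
from types import SimpleNamespace

demo_args = SimpleNamespace(
    market="csi300",                  # instrument universe for the Alpha360 handler
    times=1,                          # number of repeated runs
    gpu=0,                            # forwarded to update_gpu
    name="QuantTransformer",          # experiment name passed to run_exp
    save_dir="./outputs/qlib-quant-transformer",
)
# main(demo_args)  # uncomment to launch the runs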