def _config_by_algo(algo):
    """Return the configuration dictionary selected by ``algo``.

    :param algo: a string holding either a numeric index or an algo name
    :return: the dictionary produced by ``load_config`` -- called with
        ``algo`` when it is all digits, with no argument otherwise
    :raises ValueError: if ``algo`` is empty or None
    """
    if not algo:
        raise ValueError("please input a specific algo")
    # Only purely numeric identifiers are forwarded to load_config; any other
    # (non-empty) name falls back to load_config's default behaviour.
    return load_config(algo) if algo.isdigit() else load_config()
def train_all(processes=1, device="cpu"):
    """Train every agent package in ``./train_package`` that has no logs yet.

    Each numerically named sub-directory of ``./train_package`` is one agent
    package; other entries (e.g. hidden files) are ignored. A package that
    already contains a ``tensorboard`` (or legacy ``logfile``) directory is
    skipped. Every other package is trained by ``train_one`` in its own
    ``Process``, with at most ``processes`` workers alive at a time. The
    function waits for all started workers to finish before returning.

    :param processes: the maximum number of concurrent training processes.
        If equal to 1, the logging level is debug at file and info at console.
        Otherwise, the logging level is info at file and warning at console.
        Values below 1 are treated as 1 when throttling.
    :param device: device string forwarded unchanged to ``train_one``.
    """
    if processes == 1:
        console_level = logging.INFO
        logfile_level = logging.DEBUG
    else:
        console_level = logging.WARNING
        logfile_level = logging.INFO
    # A limit below 1 would make the throttle loop below wait forever.
    max_workers = max(1, processes)

    train_dir = "train_package"
    root = "./" + train_dir
    # if the directory does not exist, create one
    os.makedirs(root, exist_ok=True)

    pool = []
    for subdir in sorted(os.listdir(root)):
        # Only numeric folders are agent packages. Skip anything else instead
        # of returning: sorted() places names like ".gitkeep" before digits,
        # so an early return would silently train nothing.
        if not subdir.isdigit():
            continue
        package_dir = root + "/" + subdir
        # train only if the log dir does not exist
        # NOTE: logfile is for compatibility reason
        if (os.path.isdir(package_dir + "/tensorboard")
                or os.path.isdir(package_dir + "/logfile")):
            print("Already trained nnagents in " + train_dir
                  + ": generate new or reinitialize agents to start again")
            continue

        worker = Process(
            target=train_one,
            args=(
                package_dir,
                load_config(subdir),
                package_dir + "/tensorboard",
                subdir,
                logfile_level,
                console_level,
                device,
            ),
        )
        worker.start()
        pool.append(worker)

        # suspend if the processes are too many; rebuild the list rather than
        # removing from it while iterating (which skips elements)
        while True:
            time.sleep(5)
            pool = [p for p in pool if p.is_alive()]
            if len(pool) < max_workers:
                break

    # Wait for the remaining workers so the final message is accurate.
    for p in pool:
        p.join()
    print("All the Tasks are Over")
def _parse_algos_and_labels(options):
    """Split ``options.algos`` on commas and derive one display label per algo.

    :param options: parsed command-line options exposing ``algos`` and ``labels``
    :return: ``(algos, labels)``. When ``options.labels`` is set it is split on
        commas after replacing underscores with spaces; otherwise the algo
        names themselves are reused as labels.
    """
    algos = options.algos.split(",")
    if options.labels:
        labels = options.labels.replace("_", " ").split(",")
    else:
        labels = algos
    return algos, labels


def _download_data():
    """Construct ``DataMatrices`` with ``online=True`` from ``./axiom/net_config.json``.

    The ``input`` section of the preprocessed config supplies the date range
    (``YYYY/MM/DD`` strings, converted with ``time.mktime``) and the remaining
    ``DataMatrices`` arguments.
    """
    from axiom.marketdata.datamatrices import DataMatrices
    with open("./axiom/net_config.json", encoding="utf-8") as file:
        config = json.load(file)
    config = preprocess_config(config)
    input_config = config["input"]

    def to_timestamp(date_string):
        return time.mktime(datetime.strptime(date_string, "%Y/%m/%d").timetuple())

    start = to_timestamp(input_config["start_date"])
    end = to_timestamp(input_config["end_date"])
    DataMatrices(
        start=start,
        end=end,
        feature_number=input_config["feature_number"],
        window_size=input_config["window_size"],
        online=True,
        period=input_config["global_period"],
        volume_average_days=input_config["volume_average_days"],
        coin_filter=input_config["coin_number"],
        is_permed=input_config["is_permed"],
        test_portion=input_config["test_portion"],
        portion_reversed=input_config["portion_reversed"],
    )


def main():
    """Parse command-line options and run the selected mode.

    Ensures ``./train_package`` and ``./database`` exist, then dispatches on
    ``options.mode``: ``train``, ``act``, ``serve``, ``generate``,
    ``download_data``, ``backtest``, ``save_test_data``, ``plot`` or
    ``table``. Any other mode does nothing.

    :raises NotImplementedError: for ``train``/``act``/``serve`` when a
        specific ``algo`` is given; only the run-everything path exists.
    """
    parser = build_parser()
    options = parser.parse_args()
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs("./train_package", exist_ok=True)
    os.makedirs("./database", exist_ok=True)

    if options.mode == "train":
        import axiom.autotrain.training
        if options.algo:
            # The old `for folder in options.folder: raise ...` silently did
            # nothing for an empty folder and crashed on a scalar one.
            raise NotImplementedError("training a specific algo is not implemented")
        axiom.autotrain.training.train_all(int(options.processes), options.device)
    elif options.mode == "act":
        import axiom.autotrain.acting
        if options.algo:
            raise NotImplementedError("acting with a specific algo is not implemented")
        axiom.autotrain.acting.act_all(int(options.processes), options.device)
    elif options.mode == "serve":
        import axiom.autotrain.serving
        if options.algo:
            raise NotImplementedError("serving a specific algo is not implemented")
        # NOTE(review): serving module's entry point is named act_all; confirm intended.
        axiom.autotrain.serving.act_all(
            int(options.processes), options.device, options.serving_model)
    elif options.mode == "generate":
        import axiom.autotrain.generate as generate
        print("Generating configuration...")
        logging.basicConfig(level=logging.INFO)
        generate.add_packages(load_config(), int(options.repeat))
    elif options.mode == "download_data":
        _download_data()
    elif options.mode == "backtest":
        config = _config_by_algo(options.algo)
        _set_logging_by_algo(logging.DEBUG, logging.DEBUG, options.algo, "backtestlog")
        execute_backtest(options.algo, config)
    elif options.mode == "save_test_data":
        # This is used to export the test data
        save_test_data(load_config(options.folder))
    elif options.mode == "plot":
        logging.basicConfig(level=logging.INFO)
        algos, labels = _parse_algos_and_labels(options)
        plot.plot_backtest(load_config(), algos, labels)
    elif options.mode == "table":
        algos, labels = _parse_algos_and_labels(options)
        plot.table_backtest(load_config(), algos, labels, format=options.format)