def session(config, mode):
    """Run one training or backtesting session.

    The settings are unpacked with ``parse_config(config, mode)``. An
    ``Environment`` and the agent selected by the ``framework`` setting
    are then built, and the session either trains for the module-level
    ``epochs`` or hands the agent to ``backtest``.

    Side effects:
        * In ``'train'`` mode every path matching ``saved_network/PG/*``
          is removed before anything else happens.
        * The module-level global ``M`` is set to ``len(codes) + 1``.

    Args:
        config: Passed through unchanged to ``parse_config``.
        mode: ``'train'`` runs the epoch loop, ``'test'`` calls
            ``backtest``. Any other value builds the environment, agent
            and trader and then returns without further work.

    Raises:
        ValueError: If the configured framework is not ``'DDPG'``,
            ``'PPO'`` or ``'PG'``.
    """
    global M

    if mode == 'train':
        # Start from a clean slate: drop whatever a previous run left here.
        # NOTE(review): os.remove raises on directories, so this assumes the
        # glob only ever matches plain files.
        for stale_path in glob.glob('saved_network/PG/*'):
            os.remove(stale_path)

    from data.environment import Environment

    (codes, start_date, end_date, features, agent_config, market, predictor,
     framework, window_length, noise_flag, record_flag, plot_flag,
     reload_flag, trainable, method) = parse_config(config, mode)

    window = int(window_length)
    env = Environment(start_date, end_date, codes, features, window, market,
                      mode)

    M = len(codes) + 1

    # Constructor arguments shared by every agent class; DDPG and PPO
    # additionally take the predictor as their first argument.
    shared_args = (M, window, len(features), '-'.join(agent_config),
                   reload_flag, trainable)

    # Agent modules are imported lazily so only the selected one is loaded.
    if framework == 'DDPG':
        print("*-----------------Loading DDPG Agent---------------------*")
        from agents.ddpg import DDPG
        agent = DDPG(predictor, *shared_args)
    elif framework == 'PPO':
        print("*-----------------Loading PPO Agent---------------------*")
        from agents.ppo import PPO
        agent = PPO(predictor, *shared_args)
    elif framework == 'PG':
        print("*-----------------Loading PG Agent---------------------*")
        from agents.pg import PG
        agent = PG(*shared_args)
    else:
        # Without this branch `agent` stays unbound and the failure shows up
        # later as an UnboundLocalError far from the real cause.
        raise ValueError(
            "Unsupported framework {!r}; expected 'DDPG', 'PPO' or 'PG'"
            .format(framework))

    stocktrader = StockTrader()

    if mode == 'train':
        print("Training with {:d}".format(epochs))
        for epoch in range(epochs):
            print("Now we are at epoch", epoch)
            traversal(stocktrader, agent, env, epoch, noise_flag, framework,
                      method, trainable)

            # The flags are compared as the strings 'True'/'False', not bools.
            if record_flag == 'True':
                stocktrader.write(epoch)

            if plot_flag == 'True':
                stocktrader.plot_result()

            agent.reset_buffer()
            stocktrader.print_result(epoch, agent)
            stocktrader.reset()
    elif mode == 'test':
        backtest(agent, env)
def session(config, args):
    """Train or backtest a DDPG agent for the run identified by ``args['num']``.

    Run artefacts live under ``result/DDPG/<num>/``. That path is stored in
    the module-level global ``PATH_prefix``.

    * ``args['mode'] == 'train'``: if the run directory does not exist yet
      it is created, ``env.get_repo`` supplies the train/test date ranges
      and the codes, and these are written to ``config.json``. If the
      directory already exists, the train dates and codes are read back
      from ``config.json`` instead. The agent is then trained for the
      module-level ``epochs``.
    * ``args['mode'] == 'test'``: the test dates and codes are read from
      ``config.json`` and a freshly built DDPG agent is passed to
      ``backtest``.

    Side effects:
        Sets the module-level globals ``PATH_prefix`` and ``M``; may create
        the run directory and write ``config.json``.

    Args:
        config: Passed through unchanged to ``parse_config``.
        args: Mapping that provides at least ``'num'`` and ``'mode'``; it
            is also forwarded to ``parse_config``.
    """
    global PATH_prefix, M

    (codes, start_date, end_date, features, agent_config, market, predictor,
     framework, window_length, noise_flag, record_flag, plot_flag,
     reload_flag, trainable, method) = parse_config(config, args)
    env = Environment()

    # NOTE(review): `codes` is used arithmetically here but as a sequence
    # (len(codes)) once get_repo / config.json has replaced it; presumably
    # parse_config returns a count. Confirm that M should not be refreshed
    # after `codes` is rebound below.
    M = codes + 1

    stocktrader = StockTrader()
    PATH_prefix = "result/DDPG/" + str(args['num']) + '/'
    config_path = PATH_prefix + 'config.json'
    date_format = '%Y-%m-%d'

    if args['mode'] == 'train':
        if not os.path.exists(PATH_prefix):
            # First run for this id: pick the data split and persist it so
            # later train/test invocations use the same dates and codes.
            os.makedirs(PATH_prefix)
            (train_start_date, train_end_date, test_start_date,
             test_end_date, codes) = env.get_repo(start_date, end_date, codes,
                                                  market)
            env.get_data(train_start_date, train_end_date, features,
                         window_length, market, codes)
            print("Codes:", codes)
            print('Training Time Period:', train_start_date, ' ', train_end_date)
            print('Testing Time Period:', test_start_date, ' ', test_end_date)
            with open(config_path, 'w') as f:
                json.dump(
                    {
                        "train_start_date": train_start_date.strftime(date_format),
                        "train_end_date": train_end_date.strftime(date_format),
                        "test_start_date": test_start_date.strftime(date_format),
                        "test_end_date": test_end_date.strftime(date_format),
                        "codes": codes
                    }, f)
            print("finish writing config")
        else:
            # The run directory already exists: reuse the stored split.
            with open(config_path, 'r') as f:
                dict_data = json.load(f)
            print("successfully load config")
            train_start_date = datetime.datetime.strptime(
                dict_data['train_start_date'], date_format)
            train_end_date = datetime.datetime.strptime(
                dict_data['train_end_date'], date_format)
            codes = dict_data['codes']
            env.get_data(train_start_date, train_end_date, features,
                         window_length, market, codes)

        # This loop deliberately overrides the noise_flag from parse_config.
        for noise_flag in [
                'True'
        ]:  #['False','True'] to train agents with noise and without noise in assets prices
            print("*-----------------Loading DDPG Agent---------------------*")
            agent = DDPG(predictor,
                         len(codes) + 1, int(window_length), len(features),
                         '-'.join(agent_config), reload_flag, trainable)
            try:
                print("Training with {:d}".format(epochs))
                for epoch in range(epochs):
                    print("Now we are at epoch", epoch)
                    traversal(stocktrader, agent, env, epoch, noise_flag,
                              framework, method, trainable)

                    # Flags are the strings 'True'/'False', not booleans.
                    if record_flag == 'True':
                        stocktrader.write(epoch, framework)

                    if plot_flag == 'True':
                        stocktrader.plot_result()

                    agent.reset_buffer()
                    stocktrader.print_result(epoch, agent, noise_flag)
                    stocktrader.reset()
            finally:
                # Close the agent even when an epoch fails, so whatever it
                # holds is not left open by an aborted run.
                agent.close()
            del agent

    elif args['mode'] == 'test':
        with open(config_path, 'r') as f:
            dict_data = json.load(f)
        test_start_date = datetime.datetime.strptime(
            dict_data['test_start_date'], date_format)
        test_end_date = datetime.datetime.strptime(
            dict_data['test_end_date'], date_format)
        codes = dict_data['codes']
        env.get_data(test_start_date, test_end_date, features, window_length,
                     market, codes)
        # The backtest agent is always built with reload "True" and
        # trainable "False", regardless of the configured flags.
        backtest([
            DDPG(predictor,
                 len(codes) + 1, int(window_length), len(features),
                 '-'.join(agent_config), "True", "False")
        ], env)