def session(config,mode):
    """Build the environment and agent described by ``config`` and run ``mode``.

    In ``'train'`` mode every path matched by ``saved_network/PG/*`` is removed
    first, then ``traversal`` is called once per epoch.  In ``'test'`` mode the
    agent is handed to ``backtest``.  Any other mode only constructs the
    environment, agent and trader and then returns.

    The number of epochs is read from the module-level name ``epochs``; it is
    not defined inside this function.

    Args:
        config: Configuration object forwarded unchanged to ``parse_config``.
        mode: ``'train'`` or ``'test'``; also forwarded to ``parse_config``
            and ``Environment``.

    Raises:
        ValueError: If the parsed framework is not ``'DDPG'``, ``'PPO'`` or
            ``'PG'``.  Without this check the function would fail later with
            an ``UnboundLocalError`` on ``agent``.
    """
    if mode == 'train':
        # NOTE(review): the PG directory is cleared whatever framework is
        # configured -- confirm this is intended for DDPG/PPO runs as well.
        for stale_path in glob.glob('saved_network/PG/*'):
            os.remove(stale_path)

    from data.environment import Environment
    (codes, start_date, end_date, features, agent_config, market, predictor,
     framework, window_length, noise_flag, record_flag, plot_flag,
     reload_flag, trainable, method) = parse_config(config, mode)
    env = Environment(start_date, end_date, codes, features,
                      int(window_length), market, mode)

    # M is published at module level: one slot per code plus one extra.
    global M
    M = len(codes) + 1

    # Constructor arguments common to every agent class; DDPG and PPO
    # additionally take ``predictor`` as their first argument.
    agent_args = (len(codes) + 1, int(window_length), len(features),
                  '-'.join(agent_config), reload_flag, trainable)

    # Agent modules are imported lazily so only the selected one is loaded.
    if framework == 'DDPG':
        print("*-----------------Loading DDPG Agent---------------------*")
        from agents.ddpg import DDPG
        agent = DDPG(predictor, *agent_args)

    elif framework == 'PPO':
        print("*-----------------Loading PPO Agent---------------------*")
        from agents.ppo import PPO
        agent = PPO(predictor, *agent_args)

    elif framework == 'PG':
        print("*-----------------Loading PG Agent---------------------*")
        from agents.pg import PG
        agent = PG(*agent_args)

    else:
        # Fail fast with a clear message instead of leaving ``agent`` unbound.
        raise ValueError(
            "Unknown framework {!r}; expected 'DDPG', 'PPO' or 'PG'".format(
                framework))

    stocktrader=StockTrader()

    if mode=='train':
        print("Training with {:d}".format(epochs))
        for epoch in range(epochs):
            print("Now we are at epoch", epoch)
            traversal(stocktrader,agent,env,epoch,noise_flag,framework,method,trainable)

            # The flags are compared as the strings 'True' / 'False'.
            if record_flag=='True':
                stocktrader.write(epoch)

            if plot_flag=='True':
                stocktrader.plot_result()

            # Clear per-epoch state before the next pass.
            agent.reset_buffer()
            stocktrader.print_result(epoch,agent)
            stocktrader.reset()

    elif mode=='test':
        backtest(agent, env)
# Example #2
# 0
def session(config, args):
    """Train or back-test a DDPG agent for the run numbered ``args['num']``.

    Results and the run's ``config.json`` live under
    ``result/DDPG/<num>/``; that path is also published through the
    module-level ``PATH_prefix``.

    * ``args['mode'] == 'train'``: if the run directory does not exist yet it
      is created, the train/test date split and codes are obtained from
      ``env.get_repo`` and written to ``config.json``.  Otherwise the training
      dates and codes are read back from the existing ``config.json``.  The
      agent is then run through ``traversal`` once per epoch.
    * ``args['mode'] == 'test'``: the test dates and codes are read from
      ``config.json`` and a freshly constructed DDPG agent is passed to
      ``backtest``.

    The number of epochs is read from the module-level name ``epochs``; it is
    not defined inside this function.

    Args:
        config: Configuration object forwarded unchanged to ``parse_config``.
        args: Mapping that provides at least ``'num'`` and ``'mode'``; it is
            also forwarded to ``parse_config``.
    """
    global PATH_prefix

    (codes, start_date, end_date, features, agent_config, market, predictor,
     framework, window_length, noise_flag, record_flag, plot_flag,
     reload_flag, trainable, method) = parse_config(config, args)
    env = Environment()

    # NOTE(review): ``codes`` is used arithmetically here, so parse_config
    # presumably returns a count; it is rebound to the codes obtained from
    # get_repo / config.json further down -- confirm.
    global M
    M = codes + 1

    stocktrader = StockTrader()
    PATH_prefix = "result/DDPG/" + str(args['num']) + '/'
    config_path = PATH_prefix + 'config.json'
    date_format = '%Y-%m-%d'

    def parse_date(text):
        # Dates are stored in config.json as YYYY-MM-DD strings.
        return datetime.datetime.strptime(text, date_format)

    if args['mode'] == 'train':
        if not os.path.exists(PATH_prefix):
            # First run with this number: choose the split and persist it so
            # later train/test invocations reuse exactly the same setup.
            os.makedirs(PATH_prefix)
            (train_start_date, train_end_date, test_start_date,
             test_end_date, codes) = env.get_repo(start_date, end_date,
                                                  codes, market)
            env.get_data(train_start_date, train_end_date, features,
                         window_length, market, codes)
            print("Codes:", codes)
            print('Training Time Period:', train_start_date, '   ',
                  train_end_date)
            print('Testing Time Period:', test_start_date, '   ',
                  test_end_date)
            with open(config_path, 'w') as f:
                json.dump(
                    {
                        "train_start_date":
                        train_start_date.strftime(date_format),
                        "train_end_date":
                        train_end_date.strftime(date_format),
                        "test_start_date":
                        test_start_date.strftime(date_format),
                        "test_end_date":
                        test_end_date.strftime(date_format),
                        "codes": codes
                    }, f)
                print("finish writing config")
        else:
            with open(config_path, 'r') as f:
                dict_data = json.load(f)
                print("successfully load config")
            train_start_date = parse_date(dict_data['train_start_date'])
            train_end_date = parse_date(dict_data['train_end_date'])
            codes = dict_data['codes']
            env.get_data(train_start_date, train_end_date, features,
                         window_length, market, codes)

        # The loop variable deliberately overrides the parsed noise_flag.
        for noise_flag in [
                'True'
        ]:  #['False','True'] to train agents with noise and without noise in assets prices

            print("*-----------------Loading DDPG Agent---------------------*")
            agent = DDPG(predictor,
                         len(codes) + 1, int(window_length), len(features),
                         '-'.join(agent_config), reload_flag, trainable)

            # Close the agent even when an epoch raises, so a failed run does
            # not leave it open.
            try:
                print("Training with {:d}".format(epochs))
                for epoch in range(epochs):
                    print("Now we are at epoch", epoch)
                    traversal(stocktrader, agent, env, epoch, noise_flag,
                              framework, method, trainable)

                    # The flags are compared as the strings 'True' / 'False'.
                    if record_flag == 'True':
                        stocktrader.write(epoch, framework)

                    if plot_flag == 'True':
                        stocktrader.plot_result()

                    # Clear per-epoch state before the next pass.
                    agent.reset_buffer()
                    stocktrader.print_result(epoch, agent, noise_flag)
                    stocktrader.reset()
            finally:
                agent.close()
            del agent

    elif args['mode'] == 'test':
        with open(config_path, 'r') as f:
            dict_data = json.load(f)
        test_start_date = parse_date(dict_data['test_start_date'])
        test_end_date = parse_date(dict_data['test_end_date'])
        codes = dict_data['codes']
        env.get_data(test_start_date, test_end_date, features, window_length,
                     market, codes)

        # The literals "True" and "False" fill the positions where training
        # passes reload_flag and trainable; backtest receives a list of agents.
        backtest([
            DDPG(predictor,
                 len(codes) + 1, int(window_length), len(features),
                 '-'.join(agent_config), "True", "False")
        ], env)