def eval_performance(self, model):
        """Evaluate a trained DQN model over every stock file.

        Wraps *model* in a fresh frozen-policy agent, simulates trading
        over the train term and the test term of each file found in
        ``self.target_folder``, and appends the averaged results to the
        instance history lists.

        Args:
            model: trained network object; assigned to ``Agent.DQN.model``.

        Returns:
            tuple: ``(train_ave, test_ave, train_ave_Q, train_ave_reward)``
            -- mean profit ratio over the train term, mean profit ratio
            over the test term, and the agent's mean Q-value and mean
            reward accumulated during the train-term runs.
        """
        print 'start evaluating...'
        # Fresh agent used only for evaluation.
        # NOTE(review): state_dimention=1 here differs from the training
        # scripts below -- confirm this is intended for evaluation.
        Agent = dqn_agent_nature.dqn_agent(
            gpu_id=self.gpu_id,
            state_dimention=1,
            enable_controller=self.enable_controller)
        Agent.agent_init()
        # Inject the model under test and freeze the policy so no
        # learning or exploration happens during the simulated trading.
        Agent.DQN.model = model
        Agent.DQN.model_to_gpu()
        Agent.policyFrozen = True

        profit_list = []
        test_profit_list = []
        files = os.listdir(self.target_folder)

        # Train-term evaluation: trade the training period of every
        # stock file and collect one profit ratio per file.
        print 'start evaluation train term...'
        for f in files:

            stock_agent = env_stockmarket.Stock_agent(Agent,
                                                      self.action_split_number)
            try:
                traindata, trainprice = self.market.get_trainData(
                    f, self.input_num)
            except:
                # Bare except: files without usable train data are
                # silently skipped (deliberate best-effort behaviour).
                continue

            profit_ratio = stock_agent.trading(self.input_num, trainprice,
                                               traindata)
            profit_list.append(profit_ratio)

        train_ave = np.mean(np.array(profit_list))
        # Q-value / reward statistics accumulate inside the agent while
        # trading; read them out after the train-term pass.
        train_ave_Q = Agent.get_average_Q()
        train_ave_reward = Agent.get_average_reward()

        # Test-term evaluation: same procedure on the held-out period.
        print 'start evaluation test term...'
        for f in files:

            stock_agent = env_stockmarket.Stock_agent(Agent,
                                                      self.action_split_number)
            try:
                traindata, trainprice = self.market.get_testData(
                    f, self.input_num)
            except:
                # Same best-effort skip as above for missing test data.
                continue

            profit_ratio = stock_agent.trading(self.input_num, trainprice,
                                               traindata)
            test_profit_list.append(profit_ratio)

        test_ave = np.mean(np.array(test_profit_list))

        # Record this evaluation round on the instance-level histories.
        self.train_ave_profit_list.append(train_ave)
        self.test_ave_profit_list.append(test_ave)
        self.ave_Q_list.append(train_ave_Q)
        self.ave_reward_list.append(train_ave_reward)
        print 'finish evaluation'
        return train_ave, test_ave, train_ave_Q, train_ave_reward
# ---- Esempio n. 2 (scraper artifact: example separator; stray "0" commented out) ----
# Controller configuration: symmetric discrete action set, e.g.
# action_split_number=2 -> [-2, -1, 0, 1, 2].
enable_controller = range(-args.action_split_number,
                          args.action_split_number + 1)
print 'enable_controller:', enable_controller

# Date boundaries (YYYYMMDD integers) splitting train and test terms.
END_TRAIN_DAY = 20081230
START_TEST_DAY = 20090105
#START_TEST_DAY = 20100104

org_model = 0
# Load the model.
# No online update: load the saved model into a frozen-policy agent.
if args.online_update == 0:
    Agent = dqn_agent_nature.dqn_agent(gpu_id=args.gpu,
                                       state_dimention=1,
                                       enable_controller=enable_controller)
    Agent.agent_init()
    Agent.DQN.load_model(args.model)
    Agent.policyFrozen = True

# Online update: use the agent without experience replay and work on a
# deep copy so the original loaded model is never modified.
elif args.online_update == 1:
    Agent = dqn_agent_without_ER.dqn_agent(state_dimention=1,
                                           enable_controller=enable_controller)
    Agent.agent_init()
    # Copy so the original model object is not mutated in place.
    with open(args.model, 'rb') as m:
        print "open " + args.model
        # NOTE(review): pickle.load is unsafe on untrusted files -- only
        # load model files you created yourself.
        org_model = pickle.load(m)
    Agent.DQN.model = copy.deepcopy(org_model)
# ---- Esempio n. 3 (scraper artifact: example separator; stray "0" commented out) ----
# NOTE(review): orphaned `else:` -- the matching `if` (the experiment
# folder existence check, see the sibling fragments) lies outside this
# visible chunk.
else:
    print 'make experiment folder'
    os.makedirs(folder)

# Training hyper-parameters and date boundaries (YYYYMMDD integers).
evaluation_freq = 100000
END_TRAIN_DAY = 20081230
START_TEST_DAY = 20090105
n_epoch = 2000
agent_state = 4
# Controller configuration: symmetric discrete action set.
enable_controller = range( - args.action_split_number,args.action_split_number + 1)
print 'enable_controller:',enable_controller

start_time = time.clock()

# Training agent: state vector is the input window times the number of
# channels plus `agent_state` extra agent features.
Agent = dqn_agent_nature.dqn_agent(gpu_id=args.gpu,enable_controller = enable_controller,state_dimention=args.input_num * args.channel + agent_state,batchsize=args.batchsize,historysize=args.historysize,epsilon_discount_size=args.epsilon_discount_size,arch=args.arch)

Agent.agent_init()

# Market simulator split at the train/test date boundary; u_* flags
# toggle the technical indicators included in the state.
market = env_stockmarket.StockMarket(END_TRAIN_DAY,START_TEST_DAY,u_vol=u_vol,u_ema=u_ema,u_rsi=u_rsi,u_macd=u_macd,u_stoch=u_stoch,u_wil=u_wil)

evaluater = evaluation_performance.Evaluation(args.gpu,market,eval_folder,folder,args.input_num,args.action_split_number,args.arch)



print 'epoch:', n_epoch

# Persist the run settings next to the results for reproducibility.
with open(folder + 'settings.txt', 'wb') as o:
    o.write('epoch:' + str(n_epoch) + '\n')
    o.write('data_folder:' + str(args.data_folder) + '\n')
    o.write('input:' + str(args.input_num) + '\n')
# ---- Esempio n. 4 (scraper artifact: example separator; stray "0" commented out) ----
# Refuse to overwrite a previous experiment: if the folder already
# exists, pause (raw_input) so the user can abort and rename the run.
if os.path.isdir(folder) == True:
    print 'this experiment name is existed'
    print 'please change experiment name'
    raw_input()
else:
    print 'make experiment folder'
    os.makedirs(folder)


# Date boundaries (YYYYMMDD integers) splitting train and test terms.
END_TRAIN_DAY = 20081230
START_TEST_DAY = 20090105
n_epoch = 1000

start_time = time.clock()

# Training agent: state vector is the input window times the number of
# channels plus 2 extra agent features.
Agent = dqn_agent_nature.dqn_agent(gpu_id=args.gpu,state_dimention=args.input_num * args.channel + 2,batchsize=args.batchsize,historysize=args.historysize,epsilon_discount_size=args.epsilon_discount_size,targetFlag = targetFlag)
Agent.agent_init()

# Market simulator split at the train/test boundary; u_* flags toggle
# which technical indicators are part of the state.
market = env_stockmarket.StockMarket(END_TRAIN_DAY,START_TEST_DAY,u_vol=u_vol,u_ema=u_ema,u_rsi=u_rsi,u_macd=u_macd,u_stoch=u_stoch,u_wil=u_wil)

evaluater = evaluation_performance.Evaluation(args.gpu,market,args.data_folder,folder,args.input_num)



print 'epoch:', n_epoch

# Persist the run settings next to the results for reproducibility.
with open(folder + 'settings.txt', 'wb') as o:
    o.write('epoch:' + str(n_epoch) + '\n')
    o.write('data_folder:' + str(args.data_folder) + '\n')
    o.write('input:' + str(args.input_num) + '\n')
    o.write('channel:' + str(args.channel) + '\n')
# ---- Esempio n. 5 (scraper artifact: example separator; stray "0" commented out) ----
# Results folder for this test run; refuse to overwrite an existing
# experiment (pause so the user can abort and rename).
folder = './test_result/' + args.experiment_name + '/'
if os.path.isdir(folder) == True:
    print 'this experiment name is existed'
    print 'please change experiment name'
    raw_input()
else:
    print 'make experiment folder'
    os.makedirs(folder)
    
    
# Date boundaries (YYYYMMDD integers); this script tests from 2010.
END_TRAIN_DAY = 20081230
#START_TEST_DAY = 20090105
START_TEST_DAY = 20100104

# Load the model into a frozen-policy agent (evaluation only, no
# learning or exploration).
Agent = dqn_agent_nature.dqn_agent()
Agent.agent_init()
Agent.DQN.load_model(args.model)
Agent.policyFrozen = True
    
# Market simulator split at the train/test boundary; u_* flags toggle
# which technical indicators are part of the state.
market = env_stockmarket.StockMarket(END_TRAIN_DAY,START_TEST_DAY,u_vol=u_vol,u_ema=u_ema,u_rsi=u_rsi,u_macd=u_macd,u_stoch=u_stoch,u_wil=u_wil)

files = os.listdir("./nikkei100")

# Reset the agent's per-run statistics before evaluating.
Agent.init_max_Q_list()
Agent.init_reward_list()
profit_list = []


# NOTE(review): the loop body continues past the visible end of this
# chunk -- only the filename print is shown here.
for f in files:
    print f