Пример #1
0
class rl_stock_trader():
    def __init__(self,
                 path_to_symbol_csv,
                 request_symbols=8,
                 tb_outdir=tb_outdir):
        """Download price data for a random symbol subset and build the env.

        Args:
            path_to_symbol_csv: Path of a CSV file; the ticker symbols are
                read from its first column.
            request_symbols: Number of distinct symbols to draw at random.
                Capped at the number of rows in the CSV.
            tb_outdir: Directory passed to ``SummaryWriter``.
        """
        self.writer = SummaryWriter(tb_outdir)

        self.request_symbols = request_symbols
        # monitor_training() writes to TensorBoard every `monitor_freq`
        # episodes.
        self.monitor_freq = 100

        # Subtracted from the final net worth to report profit.
        self.start_budget = 10000.

        index_df = pd.read_csv(path_to_symbol_csv)
        all_symbols = index_df.values[:, 0]

        # Draw WITHOUT replacement. The previous np.random.randint() call
        # sampled with replacement, so the same ticker could be requested
        # several times and fewer distinct symbols than asked for were loaded.
        n_draw = min(self.request_symbols, all_symbols.shape[0])
        picked_rows = np.random.choice(all_symbols.shape[0],
                                       size=n_draw,
                                       replace=False)
        symbol_vec = list(all_symbols[picked_rows])

        self.dataframe, self.num_symbols = self.get_data(symbol_vec)

        self.env = StockTradingEnv(self.dataframe, self.num_symbols)

        # --- TensorBoard bookkeeping (see monitor_training) ---------------
        self.tb_action_type = np.zeros(3)
        self.tb_action_symbol = np.zeros(self.num_symbols)
        self.tb_action_vec = []
        # monitor_training() appends to this list; create it up front so the
        # attribute always exists.
        self.tb_action_symbol_vec = []
        self.tb_action_amount = []

        # [low, avg, high, final] of the most recently finished episode.
        self.tb_balance = np.zeros(4)
        self.tb_net_worth = np.zeros(4)

        # Per-step samples of the episode currently running.
        self.balance_dummy = []
        self.net_worth_dummy = []
        self.tb_reward = 0.

        # Per-episode values collected since the last TensorBoard flush.
        self.tb_cache_reward_vec = []
        self.tb_cache_rollout_vec = []

        self.tb_cache_final_net = []
        self.tb_cache_final_balance = []

        # NOTE(review): the two attributes below (spelled "chache") are not
        # read anywhere in the visible code; kept for compatibility.
        self.tb_chache_balance = np.zeros(4)
        self.tb_chache_net_worth = np.zeros(4)

    def get_data(self,
                 symbols,
                 start=None,
                 end=None,
                 period='5y',
                 interval='1d'):
        """Download OHLCV data via ``yf.download`` and clean it up.

        Args:
            symbols: Ticker symbols forwarded as ``tickers``.
            start, end: Explicit date range. Used only when BOTH are given;
                otherwise ``period`` is used instead.
            period: Look-back period (e.g. '1d', '5d', '1mo', '1y', '5y',
                'max') used when no explicit range is given.
            interval: Sampling interval (e.g. '1m', '1h', '1d', '1wk').

        Returns:
            Tuple ``(data_array, n_loaded)``: the cleaned DataFrame with an
            extra leading running-index column ``'i'`` and with its first and
            last row removed, and the number of symbols counted as loaded.
        """
        df_keys = ['Adj Close', 'Close', 'High', 'Low', 'Open', 'Volume']

        download_kwargs = dict(tickers=symbols,
                               interval=interval,
                               group_by='column',
                               auto_adjust=True,
                               prepost=False,
                               threads=True,
                               proxy=None)

        if start is None or end is None:
            print('\nload S&P 500 data for period: ', period,
                  ' and interval: ', interval, '\n')
            download_kwargs['period'] = period
        else:
            print('\nload S&P 500 data since: ', start, '/ end: ', end,
                  ' and interval: ', interval, '\n')
            download_kwargs['start'] = start
            download_kwargs['end'] = end

        data_array = yf.download(**download_kwargs)

        called_symbols = list(data_array['Volume'].keys())
        # NOTE(review): every symbol that shows up under 'Adj Close' is
        # treated as failed. Presumably, with auto_adjust=True, that column
        # only exists for failed downloads -- confirm against the yfinance
        # version in use.
        try:
            failed_symbols = list(data_array['Adj Close'].keys())
        except KeyError:
            failed_symbols = []

        loaded_symbols = [
            symbol for symbol in called_symbols
            if symbol not in failed_symbols
        ]

        failed_columns = [(str(key), str(symbol))
                          for symbol in failed_symbols for key in df_keys]
        if failed_columns:
            data_array = data_array.drop(columns=failed_columns)

        data_array.insert(0, 'i', np.arange(data_array.shape[0]))

        # Drop the first and the last downloaded row.
        data_index_axis = data_array.index.values
        data_array = data_array.drop(
            index=[data_index_axis[0], data_index_axis[-1]])

        # Forward-fill gaps, then zero whatever is still missing. The old
        # code called fillna()/replace() without keeping the returned frames,
        # so the NaNs were never actually removed.
        fill_keys = ['Open', 'Close', 'High', 'Low', 'Volume']
        top_level = data_array.columns.get_level_values(0)
        for column in data_array.columns[top_level.isin(fill_keys)]:
            data_array[column] = data_array[column].ffill().fillna(0.)

        print(
            '\n------------------------------------\t\t\t\nsuccesfully loaded stock data\nnumber of loaded data points: ',
            data_array.shape[0],
            '\nnumber of loaded symbols: ', len(loaded_symbols), '/',
            len(called_symbols),
            '\n------------------------------------\n\n',
            '\ndataframe:\n', data_array,
            '\n------------------------------------\n\n')

        return data_array, len(loaded_symbols)

    def monitor_training(self, tb_writer, t, i, done, action, monitor_data):
        """Collect per-step statistics and periodically write them to TensorBoard.

        Per finished episode the following is recorded:
            * action type, action amount and action symbol of every step
            * balance   as [low, avg, high, final]
            * net worth as [low, avg, high, final]
            * summed reward and the last step index

        Every ``self.monitor_freq`` episodes (never at episode 0) the values
        cached since the previous flush are written to ``tb_writer`` and the
        cache is cleared.

        Args:
            tb_writer: Object providing ``add_scalar`` and ``add_histogram``.
            t: Step index inside the current episode (0 on the first step).
            i: Episode index.
            done: True when this step ended the episode.
            action: Unused; kept for interface compatibility.
            monitor_data: Mapping with the keys 'action_sym',
                'action_amount', 'action_type', 'reward', 'balance' and
                'net_worth'.
        """
        if t == 0:
            # First step of an episode: clear the per-episode accumulators.
            self.balance_dummy = []
            self.net_worth_dummy = []
            self.tb_reward = 0.

            if i == 0:
                # Very first step of the run: clear the flush window too.
                self._reset_monitor_window()

        self.tb_action_symbol_vec.append(monitor_data['action_sym'])
        self.tb_action_amount.append(monitor_data['action_amount'])
        self.tb_action_vec.append(monitor_data['action_type'])

        self.tb_reward += monitor_data['reward']

        self.balance_dummy.append(monitor_data['balance'])
        self.net_worth_dummy.append(monitor_data['net_worth'])

        if not done:
            return

        self.tb_cache_reward_vec.append(self.tb_reward)
        self.tb_cache_rollout_vec.append(t)

        # Build FRESH summary arrays for every episode. The previous code
        # filled self.tb_balance / self.tb_net_worth in place while the first
        # cached row of a window was a reshape *view* of that same array, so
        # the second episode overwrote the first one's cached statistics.
        self.tb_balance = self._low_avg_high_final(self.balance_dummy)
        self.tb_net_worth = self._low_avg_high_final(self.net_worth_dummy)

        self.tb_cache_final_balance = self._stack_summary_row(
            self.tb_cache_final_balance, self.tb_balance)
        self.tb_cache_final_net = self._stack_summary_row(
            self.tb_cache_final_net, self.tb_net_worth)

        if i % self.monitor_freq != 0 or i == 0:
            return

        final_balance = self.tb_cache_final_balance
        final_net = self.tb_cache_final_net
        profit = final_net[:, 3] - self.start_budget
        column_tags = ('low', 'avg', 'high', 'final')

        tb_writer.add_scalar('training/reward',
                             np.mean(self.tb_cache_reward_vec), i)
        tb_writer.add_scalar('training/rollout',
                             np.mean(self.tb_cache_rollout_vec), i)

        for column, tag in enumerate(column_tags):
            tb_writer.add_scalar('balance/' + tag,
                                 np.mean(final_balance[:, column]), i)
        for column, tag in enumerate(column_tags):
            tb_writer.add_scalar('net_worth/' + tag,
                                 np.mean(final_net[:, column]), i)
        tb_writer.add_scalar('net_worth/profit', np.mean(profit), i)

        tb_writer.add_histogram('training_stats/reward',
                                np.asarray(self.tb_cache_reward_vec), i)
        tb_writer.add_histogram('training_stats/rollout',
                                np.asarray(self.tb_cache_rollout_vec), i)

        tb_writer.add_histogram('performance_stats/final_balance',
                                np.asarray(final_balance[:, -1]), i)
        tb_writer.add_histogram('performance_stats/final_net_worth',
                                np.asarray(final_net[:, -1]), i)
        tb_writer.add_histogram('performance_stats/profit',
                                np.asarray(profit), i)

        tb_writer.add_histogram('action/type',
                                np.asarray(self.tb_action_vec), i)
        tb_writer.add_histogram('action/symbol',
                                np.asarray(self.tb_action_symbol_vec), i)
        tb_writer.add_histogram('action/action_amount',
                                np.asarray(self.tb_action_amount), i)

        self._reset_monitor_window()

    def _reset_monitor_window(self):
        """Clear everything that is accumulated between two TensorBoard flushes."""
        self.tb_cache_reward_vec = []
        self.tb_cache_rollout_vec = []

        # A 1-D array marks "no episode cached yet"; _stack_summary_row()
        # replaces it with a 2-D array on the first finished episode.
        self.tb_cache_final_net = np.zeros(4)
        self.tb_cache_final_balance = np.zeros(4)

        self.tb_action_vec = []
        self.tb_action_symbol_vec = []
        self.tb_action_amount = []

        self.tb_balance = np.zeros(4)
        self.tb_net_worth = np.zeros(4)

    @staticmethod
    def _low_avg_high_final(samples):
        """Return a new array [min, mean, max, last] of ``samples``."""
        stats = np.zeros(4)
        stats[0] = np.amin(samples)
        stats[1] = np.mean(samples)
        stats[2] = np.amax(samples)
        stats[3] = samples[-1]
        return stats

    @staticmethod
    def _stack_summary_row(cache, row):
        """Append a copy of ``row`` to the 2-D ``cache`` (or start a new one)."""
        row = np.array(row, dtype=float).reshape(1, -1)
        if np.ndim(cache) == 1:
            return row
        return np.concatenate((cache, row), axis=0)

    def rl_agent(self, env):
        """Create the policy, the value function and the PPO agent.

        The networks, the combined model, the optimizer and the agent are
        stored on the instance as ``self.policy``, ``self.vf``,
        ``self.model``, ``self.opt`` and ``self.agent``.

        Args:
            env: Environment; only ``env.action_space.low.size`` is read, to
                size the policy output.

        Returns:
            The newly built PPO agent (same object as ``self.agent``).
        """
        n_actions = env.action_space.low.size

        # Policy: two tanh hidden layers (256, 128), a linear output layer and
        # a Gaussian head with a state-independent diagonal covariance.
        policy_layers = [
            L.Linear(None, 256),
            F.tanh,
            L.Linear(None, 128),
            F.tanh,
            L.Linear(None, n_actions),
            chainerrl.policies.GaussianHeadWithStateIndependentCovariance(
                action_size=n_actions,
                var_type='diagonal',
                # Parameterize log std: var = exp(2 * x)
                var_func=lambda x: F.exp(2 * x),
            ),
        ]
        self.policy = chainer.Sequential(*policy_layers)

        # Value function: same hidden layout, scalar output.
        value_layers = [
            L.Linear(None, 256),
            F.tanh,
            L.Linear(None, 128),
            F.tanh,
            L.Linear(None, 1),
        ]
        self.vf = chainer.Sequential(*value_layers)

        # Policy and value function share one model and one optimizer.
        self.model = chainerrl.links.Branched(self.policy, self.vf)

        self.opt = chainer.optimizers.Adam(alpha=3e-4, eps=1e-5)
        self.opt.setup(self.model)

        ppo_settings = dict(
            gpu=-1,
            update_interval=512,
            minibatch_size=8,
            clip_eps_vf=None,
            entropy_coef=0.001,
        )
        self.agent = PPO(self.model, self.opt, **ppo_settings)

        return self.agent

    def train(self):
        """Train a freshly built PPO agent on ``self.env``.

        Runs ``n_episodes + 1`` episodes of at most ``max_episode_len`` steps,
        passes every step to ``monitor_training`` and saves the agent and the
        model every 100 episodes below the module-level ``model_outdir``.
        """
        print('\nstart training loop\n')

        def check_types(values, label):
            # Print a warning when `values` contains NaN or inf.
            if np.isnan(values).any():
                print('----> ', label, ' array contains NaN\n',
                      np.isnan(values).shape, '\n')
            if np.isinf(values).any():
                print('----> ', label, ' array contains inf\n',
                      np.isinf(values).shape, '\n')

        self.agent = self.rl_agent(self.env)

        n_episodes = 1000000
        max_episode_len = 1000

        for i in range(0, n_episodes + 1):

            obs = self.env.reset()

            reward = 0
            done = False
            R = 0  # return (sum of rewards)
            t = 0  # time step

            while not done and t < max_episode_len:

                # Uncomment to watch the behaviour
                # self.env.render()
                action = self.agent.act_and_train(obs, reward)
                check_types(action, 'action')

                obs, reward, done, _, monitor_data = self.env.step(action)
                check_types(obs, 'obs')
                check_types(reward, 'reward')

                self.monitor_training(self.writer, t, i, done, action,
                                      monitor_data)

                R += reward
                t += 1

            if done:
                print(' training at episode ' + str(i), end='\r')

            # Close EVERY episode. Previously this was called only once after
            # the whole loop, so the last reward of each episode and the
            # episode boundary itself were never reported to the agent.
            self.agent.stop_episode_and_train(obs, reward, done)

            if i % 100 == 0 and i > 0:

                self.agent.save(model_outdir)

                serializers.save_npz(model_outdir + 'model.npz', self.model)

        print('Finished.')
Пример #2
0
class rl_stock_trader():
    def __init__(self):
        """Create output folders, download price data and build envs and agent.

        Training data covers ``years_training`` years (plus ``past_horzion``
        days) ending the day before the test window starts. Test data covers
        the last ``years_testing`` years (plus ``past_horzion`` days) up to
        today.
        """
        # Local import: this method did not use `os` before.
        import os

        run_name = 'run_test'
        self.outdir = './results/' + run_name + '/'
        self.outdir_train = self.outdir + 'train/'
        self.outdir_test = self.outdir + 'test/'

        self.training_counter = 0

        # The old code called the non-existent `sys.makedirs` and swallowed
        # the resulting AttributeError, so nothing was ever created here.
        os.makedirs(self.outdir_train, exist_ok=True)
        os.makedirs(self.outdir_test, exist_ok=True)

        self.writer_train = SummaryWriter(self.outdir_train)
        self.writer_test = SummaryWriter(self.outdir_test)

        self.monitor_freq = 100
        self.testing_samples = 100  # validation rollouts per validate() call

        # validate() appends rows [counter, mean, min, max, final equity].
        self.validation_scores = []
        self.training_scores = []

        # The misspelled keys are part of the interface shared with
        # StockTradingEnv and the other methods -- do not rename them here.
        self.settings = {
            'past_horzion': 100,
            'max_steps': 365,
            'inital_account_balance': 1e4,
            'stop_below_balance': 1e3,
            'transation_fee': .1,
            'years_training': 5,
            'years_testing': 1,
        }

        horizon = relativedelta(days=self.settings['past_horzion'])
        testing_end = date.today()
        testing_beginning = testing_end - relativedelta(
            years=self.settings['years_testing']) - horizon
        training_end = testing_beginning - relativedelta(days=1)
        training_beginning = training_end - relativedelta(
            years=self.settings['years_training']) - horizon

        train_sources = {
            'gold': gold_shanghai,
            'copper': copper_shanghai,
            'aluminum': aluminum_shanghai,
        }
        test_sources = {
            'gold': gold_shanghai,
            'copper': copper_shanghai,
            'aluminum': aluminum_shanghai,
            'soybean_oil': soybean_oil,
            'dax_futures': dax_futures,
            'corn': corn,
            'canadian_dollar': canadian_dollar,
        }

        self.data = {}
        for label, code in train_sources.items():
            self.data['train_' + label] = self.get_prices(
                code, 1, training_beginning, training_end)
        for label, code in test_sources.items():
            self.data['test_' + label] = self.get_prices(
                code, 1, testing_beginning, testing_end)

        def make_test_env(label):
            # Each env works on its own copy of the already downloaded frame.
            # The old code downloaded every test series a second time here;
            # copying keeps the env data independent of self.data without
            # the extra request.
            return StockTradingEnv(self.data['test_' + label].copy(),
                                   self.settings,
                                   test=True)

        self.env_test_gold = make_test_env('gold')
        self.env_test_copper = make_test_env('copper')
        self.env_test_aluminum = make_test_env('aluminum')
        self.env_test_soy_bean = make_test_env('soybean_oil')
        self.env_test_dax = make_test_env('dax_futures')
        self.env_test_corn = make_test_env('corn')
        self.env_test_canadian_dollar = make_test_env('canadian_dollar')

        self.env_train = StockTradingEnv(self.data['train_gold'],
                                         self.settings,
                                         test=False)

        # These envs share their frames with self.data (unchanged behaviour).
        self.test_envs = {
            label: StockTradingEnv(self.data['test_' + label],
                                   self.settings,
                                   test=True)
            for label in ('gold', 'copper', 'aluminum')
        }

        self.agent = self.rl_agent(self.env_train)

    def get_prices(self, index, depth, start, end):
        """Fetch one price series from Quandl.

        Args:
            index: Dataset code prefix (string).
            depth: Value appended to ``index`` (via ``str``) to form the
                full dataset code.
            start: Passed to ``quandl.get`` as ``start_date``.
            end: Passed to ``quandl.get`` as ``end_date``.

        Returns:
            The object returned by ``quandl.get`` with its index converted by
            ``pd.to_datetime``.
        """
        dataset_code = index + str(depth)
        prices = quandl.get(dataset_code, start_date=start, end_date=end)

        # Make sure the row labels are datetimes.
        prices.index = pd.to_datetime(prices.index)
        return prices

    def rl_agent(self, env):
        """Create the policy, the value function and the PPO agent.

        The networks, the combined model, the optimizer and the agent are
        stored on the instance as ``self.policy``, ``self.vf``,
        ``self.model``, ``self.opt`` and ``self.agent``.

        Args:
            env: Environment; only ``env.action_space.low.size`` is read, to
                size the policy output.

        Returns:
            The newly built PPO agent (same object as ``self.agent``).
        """
        n_actions = env.action_space.low.size

        # Policy: batch norm, two sigmoid hidden layers (256, 128), a sigmoid
        # output layer and a Gaussian head with a state-independent diagonal
        # covariance.
        policy_layers = [
            L.BatchNormalization(axis=0),
            L.Linear(None, 256),
            F.sigmoid,
            L.Linear(None, 128),
            F.sigmoid,
            L.Linear(None, n_actions),
            F.sigmoid,
            chainerrl.policies.GaussianHeadWithStateIndependentCovariance(
                action_size=n_actions,
                var_type='diagonal',
                # Parameterize log std: var = exp(2 * x)
                var_func=lambda x: F.exp(2 * x),
            ),
        ]
        self.policy = chainer.Sequential(*policy_layers)

        # Value function: same hidden layout; the scalar output also passes
        # through a sigmoid.
        value_layers = [
            L.BatchNormalization(axis=0),
            L.Linear(None, 256),
            F.sigmoid,
            L.Linear(None, 128),
            F.sigmoid,
            L.Linear(None, 1),
            F.sigmoid,
        ]
        self.vf = chainer.Sequential(*value_layers)

        # Policy and value function share one model and one optimizer.
        self.model = chainerrl.links.Branched(self.policy, self.vf)

        self.opt = chainer.optimizers.Adam(alpha=3e-3, eps=1e-5)
        self.opt.setup(self.model)

        ppo_settings = dict(
            gpu=-1,
            update_interval=64,
            minibatch_size=32,
            clip_eps_vf=None,
            entropy_coef=0.001,
        )
        self.agent = PPO(self.model, self.opt, **ppo_settings)

        return self.agent

    def monitor_training(self, tb_writer, t, i, done, action, monitor_data,
                         counter):

        if t == 0 or i == 0:

            self.cash_dummy = []
            self.equity_dummy = []
            self.shares_dummy = []
            self.shares_value_dummy = []
            self.action_dummy = []
            self.action_prob_dummy = []

        self.cash_dummy.append(monitor_data['cash'])
        self.equity_dummy.append(monitor_data['equity'])
        self.shares_dummy.append(monitor_data['shares_held'])
        self.shares_value_dummy.append(monitor_data['value_in_shares'])
        self.action_dummy.append(monitor_data['action'])
        self.action_prob_dummy.append(monitor_data['action_prob'])

        # if done:
        # tb_writer.add_scalar('cash', np.mean(self.cash_dummy), counter)
        # tb_writer.add_scalar('equity', np.mean(self.equity_dummy), counter)
        # tb_writer.add_scalar('shares_held', np.mean(self.shares_dummy), counter)
        # tb_writer.add_scalar('shares_value', np.mean(self.shares_value_dummy), counter)
        # tb_writer.add_scalar('action', np.mean(self.action_dummy), counter)
        # tb_writer.add_histogram('action_prob', np.mean(self.action_prob_dummy), counter)

    def plot_validation_figures(self, index, name, test_data_label, benchmark):
        """Save scatter, percentile-band and box plots of one validation metric.

        The three PDFs are written to ``self.outdir + test_data_label + '/'``
        as ``scatter_<name>_equity.pdf``, ``area_<name>_equity.pdf`` and
        ``box_<name>_equity.pdf``.

        Args:
            index: Column of ``self.validation_scores`` to plot. Column 0
                holds the training counter used to group the rows.
            name: One of 'mean', 'max', 'final' or 'min'; selects the y-axis
                limits and is used in titles and file names.
            test_data_label: Sub-folder name below ``self.outdir``.
            benchmark: Benchmark series; its min/max define the y-axis limits
                for every ``name`` except 'min'.

        Raises:
            ValueError: If ``name`` is not one of the supported values
                (previously this surfaced as an UnboundLocalError).
        """
        if name in ('mean', 'max', 'final'):
            ylimits = [.75 * np.amin(benchmark), 1.5 * np.amax(benchmark)]
        elif name == 'min':
            ylimits = [0., self.settings['inital_account_balance']]
        else:
            raise ValueError(
                "name must be one of 'mean', 'max', 'final' or 'min', got " +
                repr(name))

        plotcolor = 'darkgreen'
        out_prefix = self.outdir + test_data_label + '/'

        # Convert once instead of on every access / loop iteration.
        scores = np.asarray(self.validation_scores)
        counters = scores[:, 0]
        values = scores[:, index]

        plt.figure(figsize=(18, 18))
        plt.scatter(counters, values)
        plt.grid()
        plt.ylim(ylimits[0], ylimits[1])
        plt.title(name + ' equity statistics over 1 year')
        plt.xlabel('trained episodes')
        plt.ylabel('equity [$]')
        plt.savefig(out_prefix + 'scatter_' + name + '_equity.pdf')
        plt.close()

        # One group of values per distinct training counter (ascending).
        box_data = [values[counters == c] for c in np.unique(counters)]
        area_plots = np.asarray(
            [np.percentile(group, [5, 25, 50, 75, 95]) for group in box_data])

        p05 = area_plots[:, 0]
        p25 = area_plots[:, 1]
        p50 = area_plots[:, 2]
        p75 = area_plots[:, 3]
        p95 = area_plots[:, 4]
        x_axis = np.arange(area_plots.shape[0])

        plt.figure(figsize=(18, 18))
        plt.fill_between(x_axis, p05, p95, facecolor=plotcolor, alpha=.3)
        plt.fill_between(x_axis, p25, p75, facecolor=plotcolor, alpha=.8)
        plt.plot(p50, linewidth=3, color='lightblue')
        plt.ylim(ylimits[0], ylimits[1])
        plt.grid()
        plt.title(name + ' equity statistics over 1 year')
        plt.xlabel('trained episodes')
        plt.ylabel('equity [$]')
        plt.savefig(out_prefix + 'area_' + name + '_equity.pdf')
        plt.close()

        plt.figure(figsize=(18, 18))
        plt.boxplot(
            box_data,
            notch=True,
            boxprops=dict(color=plotcolor, linewidth=2),
            capprops=dict(color=plotcolor),
            whiskerprops=dict(color=plotcolor),
            flierprops=dict(color=plotcolor,
                            markeredgecolor=plotcolor,
                            markerfacecolor=plotcolor),
            medianprops=dict(color='lightblue', linewidth=2),
        )
        plt.ylim(ylimits[0], ylimits[1])
        plt.grid()
        plt.title('equity statistics over 1 year')
        plt.xlabel('trained episodes')
        plt.ylabel('equity [$]')
        plt.savefig(out_prefix + 'box_' + name + '_equity.pdf')
        plt.close()

    def validate(self, episode, counter, test_data_label):
        """Run ``self.testing_samples`` test rollouts and plot the results.

        Every rollout's equity curve is drawn together with the benchmark
        (the test series rescaled to start at the initial account balance)
        and saved as ``validation_E<episode>.pdf``. One row
        ``[counter, mean, min, max, final]`` of equity statistics per rollout
        is appended to ``self.validation_scores``, after which the four
        summary figure sets are regenerated.

        Args:
            episode: Training episode number; used in title and file name.
            counter: Value stored in column 0 of the score rows.
            test_data_label: Data set to validate on: 'gold', 'copper',
                'aluminum', 'soybean_oil', 'dax_futures', 'corn' or
                'canadian_dollar'.

        Raises:
            KeyError: If ``test_data_label`` is unknown.
        """
        # Creates missing parent folders as well and tolerates an existing
        # folder, without hiding unrelated errors.
        os.makedirs(self.outdir + test_data_label + '/', exist_ok=True)

        past_horizon = self.settings['past_horzion']
        test_data = self.data['test_' + test_data_label]

        env_attribute = {
            'gold': 'env_test_gold',
            'copper': 'env_test_copper',
            'aluminum': 'env_test_aluminum',
            'soybean_oil': 'env_test_soy_bean',
            'dax_futures': 'env_test_dax',
            'corn': 'env_test_corn',
            'canadian_dollar': 'env_test_canadian_dollar',
        }[test_data_label]
        env = getattr(self, env_attribute)

        # Use 'Close' when the data set has it, 'Settle' otherwise.
        price_column = 'Close' if 'Close' in test_data else 'Settle'
        # Work on a float COPY. The previous in-place `/=` and `*=` operated
        # on the array obtained from `.values`, which can rescale the prices
        # stored in self.data and fails for integer columns.
        prices = np.array(test_data[price_column].values[past_horizon:],
                          dtype=float)
        benchmark = prices / prices[0]
        benchmark = benchmark * self.settings['inital_account_balance']

        # Trades taken with high confidence; collected across all rollouts.
        test_trades_buy = []
        test_trades_sell = []

        plt.figure(figsize=(18, 18))

        for i in range(0, self.testing_samples):

            # Exactly one reset per rollout (the old if-chain reset the corn
            # environment twice).
            obs = env.reset()

            test_equity = []
            done = False
            t = 0

            while not done:

                action = self.agent.act(obs)
                obs, reward, done, _, monitor_data = env.step(action)

                test_equity.append(monitor_data['equity'])

                action_probs = softmax(action)
                action_choice = np.argmax(action_probs)
                if action_probs[action_choice] > .8:
                    if action_choice == 0:
                        test_trades_buy.append([t, monitor_data['equity']])
                    if action_choice == 2:
                        test_trades_sell.append([t, monitor_data['equity']])

                self.monitor_training(self.writer_test, t, i, done, action,
                                      monitor_data, counter)

                t += 1

            # Rollout finished: the last sample is dropped before computing
            # the statistics, and one more for the plotted curve (unchanged
            # from the original code).
            test_equity = test_equity[:-1]
            plt.plot(test_equity[:-1], linewidth=1)

            self.validation_scores.append([
                counter,
                np.mean(test_equity),
                np.amin(test_equity),
                np.amax(test_equity), test_equity[-1]
            ])

            self.agent.stop_episode()

        time_axis = test_data.index[past_horizon:].date
        time_axis_short = time_axis[::10]
        # One tick per label, at the sample position the label belongs to.
        # The old np.linspace call produced one position fewer than labels.
        tick_positions = np.arange(len(time_axis))[::10]

        plt.plot(benchmark, linewidth=3, color='k', label='close')
        plt.ylim(.75 * np.amin(benchmark), 1.5 * np.amax(benchmark))
        plt.xticks(tick_positions, time_axis_short, rotation=90)
        plt.grid()
        plt.title(test_data_label + ' validation runs at episode ' +
                  str(episode))
        plt.xlabel('episode')
        plt.ylabel('equity [$]')
        plt.legend()
        plt.savefig(self.outdir + test_data_label + '/validation_E' +
                    str(episode) + '.pdf')
        plt.close()

        self.plot_validation_figures(1, 'mean', test_data_label, benchmark)
        self.plot_validation_figures(2, 'min', test_data_label, benchmark)
        self.plot_validation_figures(3, 'max', test_data_label, benchmark)
        self.plot_validation_figures(4, 'final', test_data_label, benchmark)

    def train(self, n_episodes=int(1e5)):
        """Run the training loop on ``self.env_train``.

        For every episode the agent is stepped with ``act_and_train`` until
        the environment reports ``done``; each step is forwarded to
        ``self.monitor_training``. Every 10th non-terminal step a snapshot of
        the monitor data and the raw action vector is recorded.

        Every ``self.monitor_freq`` episodes (including episode 0) the agent
        is validated on the 'gold', 'soybean_oil' and 'corn' data sets and
        the files ``reward.pdf``, ``actions.pdf`` and ``actions_plot.pdf``
        are written below ``self.outdir``. Every 10th episode (except
        episode 0) the agent and ``self.model`` are saved to ``self.outdir``.

        Args:
            n_episodes: Index of the last episode; episodes
                ``0 .. n_episodes`` inclusive are run. Defaults to 1e5.

        Raises:
            ValueError: If ``n_episodes`` is negative.
        """
        if n_episodes < 0:
            raise ValueError('n_episodes must be non-negative')

        print('\nstart training loop\n')

        log_data = []
        # Rows of [training_counter, action[0], action[1], action[2]];
        # accumulated over all episodes for the action plots.
        action_log = []

        debug_printing = False

        for i in range(0, n_episodes + 1):

            obs = self.env_train.reset()

            reward = 0
            done = False
            R = 0  # return (sum of rewards)
            t = 0  # time step

            while not done:
                action = self.agent.act_and_train(obs, reward)

                obs, reward, done, _, monitor_data = self.env_train.step(
                    action)

                self.monitor_training(self.writer_train, t, i, done, action,
                                      monitor_data, self.training_counter)

                R += reward
                t += 1

                if t % 10 == 0 and not done:
                    log_data.append({
                        'equity': int(monitor_data['equity']),
                        'shares_held': int(monitor_data['shares_held']),
                        'shares_value': int(monitor_data['value_in_shares']),
                        'cash': int(monitor_data['cash']),
                        't': int(t),
                    })
                    action_log.append([
                        self.training_counter, action[0], action[1], action[2]
                    ])

            # The loop above only exits once the episode is done, so the
            # end-of-episode bookkeeping runs exactly once per episode.
            if i % 10 == 0:
                print('\nrollout ' + str(i) + '\n',
                      pd.DataFrame(log_data).max())
            log_data = []
            self.training_scores.append([i, R])
            self.training_counter += 1

            # NOTE(review): the agent API used here (act_and_train /
            # stop_episode_and_train) looks like ChainerRL's. If so, ending a
            # training episode with stop_episode() presumably means the
            # terminal reward is never shown to the agent -- confirm that
            # this is intended before changing it.
            self.agent.stop_episode()

            if i % self.monitor_freq == 0:

                self.validate(i, self.training_counter, 'gold')
                if debug_printing: print('\n\n****************\nSOY BEANS\n\n')
                self.validate(i, self.training_counter, 'soybean_oil')
                if debug_printing: print('\n\n****************\nCORN\n\n')
                self.validate(i, self.training_counter, 'corn')
                # Validation on 'canadian_dollar' is currently disabled.

                if debug_printing: print('\n****************\n')

                scores = np.asarray(self.training_scores)
                plt.figure()
                plt.scatter(scores[:, 0], scores[:, 1], s=2, label='reward')
                plt.legend()
                plt.title('reward')
                plt.grid()
                plt.savefig(self.outdir + 'reward.pdf')
                plt.close()

                # action_log stays empty until a non-terminal step with
                # t % 10 == 0 has occurred; a 1-D empty array cannot be
                # sliced with [:, 1:], so skip the action plots until then.
                if action_log:
                    actions = np.asarray(action_log)
                    counters = actions[:, 0]
                    act_probs = softmax(actions[:, 1:], axis=1)

                    for draw, filename in ((plt.scatter, 'actions.pdf'),
                                           (plt.plot, 'actions_plot.pdf')):
                        plt.figure()
                        for k in range(3):
                            draw(counters,
                                 act_probs[:, k],
                                 label='action' + str(k))
                        plt.legend()
                        plt.title('actions')
                        plt.grid()
                        plt.savefig(self.outdir + filename)
                        plt.close()

            if i % 10 == 0 and i > 0:

                self.agent.save(self.outdir)

                serializers.save_npz(self.outdir + 'model.npz', self.model)

        self.agent.stop_episode_and_train(obs, reward, done)
        print('Finished.')