class Test_env_reset(object):
    """Checks the state a PaperTradingEnvironment exposes right after ``reset()``."""

    @classmethod
    def setup_class(cls):
        """Build the shared environment while the trading module's ``datetime`` is mocked."""
        with mock.patch('cryptotrader.envs.trading.datetime') as dt_mock:
            # ``now()`` is pinned to the fixture timestamp; construction and
            # ``fromtimestamp`` are forwarded to the real datetime class.
            dt_mock.now.return_value = datetime.fromtimestamp(index).astimezone(timezone.utc)
            dt_mock.fromtimestamp = lambda *a, **k: datetime.fromtimestamp(*a, **k)
            dt_mock.side_effect = lambda *a, **k: datetime(*a, **k)

            cls.env = PaperTradingEnvironment(
                period=5,
                obs_steps=10,
                tapi=tapi,
                fiat="USDT",
                name='env_test',
            )

    @classmethod
    def teardown_class(cls):
        """Remove the ``logs`` directory located under the current working directory."""
        working_dir = os.path.abspath(os.path.curdir)
        shutil.rmtree(os.path.join(working_dir, 'logs'))

    @mock.patch.object(
        PaperTradingEnvironment,
        'timestamp',
        floor_datetime(datetime.fromtimestamp(index).astimezone(timezone.utc), 5),
    )
    def test_reset(self):
        """``reset()`` must return a full observation and initialise every log."""
        env = self.env
        first_obs = env.reset()

        # Observation: both the stored frame and the returned one span obs_steps rows
        assert isinstance(env.obs_df, pd.DataFrame)
        assert env.obs_df.shape[0] == env.obs_steps
        assert isinstance(first_obs, pd.DataFrame)
        assert first_obs.shape[0] == env.obs_steps

        # Taxes: one entry per symbol
        assert list(env.tax.keys()) == env.symbols

        # Portfolio log: a single row, one column per symbol plus the portfolio value
        assert isinstance(env.portfolio_df, pd.DataFrame)
        assert env.portfolio_df.shape[0] == 1
        assert list(env.portfolio_df.columns) == list(env.symbols) + ['portval']

        # Action log: a single row, one column per symbol plus the 'online' column
        assert isinstance(env.action_df, pd.DataFrame)
        assert env.action_df.shape[0] == 1
        assert list(env.action_df.columns) == list(env.symbols) + ['online']

        # Balance: one Decimal amount per symbol
        assert list(env.balance.keys()) == list(env.symbols)
        for sym in env.balance:
            assert isinstance(env.balance[sym], Decimal)
def trade(self, env, start_step=0, act_now=False, timeout=None, verbose=False, render=False, db_sock=None,
          email=False, save_dir="./"):
    """
    TRADE REAL ASSETS WITHIN EXCHANGE. USE AT YOUR OWN RISK!

    Resets ``env`` and then loops forever: once ``env.period`` minutes have
    elapsed since the last action, it asks ``self.rebalance`` for an action,
    steps the environment with it and, when the step reports ``done``, sleeps
    until the next candle. The loop ends when ``status['Error']`` is set, when
    an unexpected exception is raised inside the loop, or on KeyboardInterrupt;
    in those cases the data frames are saved through ``self.save_dfs``.

    :param env: Livetrading or Papertrading environment instance
    :param start_step: int: strategy start step
    :param act_now: bool: Whether to act now or at the next bar start
    :param timeout: int: Not implemented yet
    :param verbose: bool: Whether to print the trading report and environment errors
    :param render: bool: Not implemented yet (``env.render()`` is called every iteration when set)
    :param db_sock: Optional object exposing ``send_string``; it receives ``'update'`` after every completed step
    :param email: bool: Whether to send report email or not
    :param save_dir: str: Save directory for logs
    :return: None
    """
    try:
        # Fiat symbol
        self.fiat = env._fiat

        # Reset env and get initial obs
        env.reset_status()
        obs = env.reset()

        # Set flags
        can_act = act_now
        may_report = True
        status = env.status
        self.db_sock = db_sock if db_sock else False

        # Get initial values
        prev_portval = init_portval = env.calc_total_portval()
        init_time = env.timestamp
        last_action_time = floor_datetime(env.timestamp, env.period)
        t0 = time()  # TODO: use datetime

        # Initialize var
        episode_reward = 0
        reward = 0
        # Bound up front so the error handler inside the loop can always print
        # it, even when the failure happens before the first rebalance.
        action = None

        def seconds_to_next_candle():
            """Return the seconds left until the candle after ``last_action_time`` opens."""
            next_open = last_action_time + timedelta(minutes=env.period)
            return datetime.timestamp(next_open) - datetime.timestamp(env.timestamp)

        print(
            "Executing trading with %d min frequency.\nInitial portfolio value: %f fiat units." % (env.period,
                                                                                                    init_portval))

        Logger.info(Agent.trade, "Starting trade routine...")

        # Init step counter
        self.step = start_step

        # Initial report: built whenever it is going to be printed or mailed,
        # mirroring the report handling inside the loop.
        if verbose or email:
            msg = self.make_report(env, obs, reward, episode_reward, t0, init_time,
                                   env.calc_portfolio_vector(), prev_portval, init_portval)

            if verbose:
                print(msg, end="\r", flush=True)

            if email and may_report:
                if hasattr(env, 'email'):
                    env.send_email("Trading report " + self.name, msg)

        while True:
            try:
                # Log action time
                loop_time = env.timestamp

                # Can act?
                if loop_time >= last_action_time + timedelta(minutes=env.period):
                    can_act = True
                    # A new candle started: drop the stale incomplete-trade flag, if any
                    try:
                        del self.log["Trade_incomplete"]
                    except Exception:
                        pass

                # If can act, run strategy and step environment
                if can_act:
                    # Log action time
                    last_action_time = floor_datetime(env.timestamp, env.period)

                    # Ask oracle for a prediction
                    action = self.rebalance(env.get_observation(True).astype(np.float64))

                    # Generate report
                    if verbose or email:
                        msg = self.make_report(env, obs, reward, episode_reward, t0, loop_time, action,
                                               prev_portval, init_portval)

                        if verbose:
                            print(msg, end="\r", flush=True)

                        if email and may_report:
                            if hasattr(env, 'email'):
                                env.send_email("Trading report " + self.name, msg)
                            may_report = False

                    # Save portval for report
                    prev_portval = env.calc_total_portval()

                    # Sample environment
                    obs, reward, done, status = env.step(action)

                    # Accumulate reward
                    episode_reward += reward

                    # If action is complete, increment step counter, log action time and allow report
                    if done:
                        # Increase step counter
                        self.step += 1

                        # You can act just one time per candle
                        can_act = False

                        # If you've acted, report yourself to nerds
                        may_report = True

                        if self.db_sock:
                            self.db_sock.send_string('update')

                    else:
                        self.log["Trade_incomplete"] = "Position change was not fully completed."

                # If can't act, just take an observation and return
                else:
                    obs = env.get_observation(True).astype(np.float64)

                # Not implemented yet
                if render:
                    env.render()

                # If environment return an error, save data frames and break
                if status['Error']:
                    # Get error
                    e = status['Error']

                    # Save data frames for analysis
                    self.save_dfs(env, save_dir, init_time)

                    # Report error
                    if verbose:
                        print("Env error:",
                              type(e).__name__ + ' in line ' + str(e.__traceback__.tb_lineno) + ': ' + str(e))

                    if email:
                        if hasattr(env, 'email'):
                            env.send_email("Trading error: %s" % env.name, env.parse_error(e))

                    # Panic
                    break

                if not can_act:
                    # When everything is done, wait for the next candle.
                    # A scalar random draw keeps the argument a plain float;
                    # sleep() raises ValueError when the total is negative.
                    try:
                        sleep(seconds_to_next_candle() + np.random.random() * 3)
                    except ValueError:
                        sleep(np.random.random() * 3)

            except MaxRetriesException as e:
                # Tell nerds the delay
                Logger.error(Agent.trade, "Retries exhausted. Waiting for connection...")

                # Best-effort notification: a failing mail must not stop the retry loop
                try:
                    env.send_email("Trading error: %s" % env.name, env.parse_error(e))
                except Exception:
                    pass

                # Wait for the next candle
                try:
                    sleep(seconds_to_next_candle() + np.random.random() * 30)
                except ValueError:
                    sleep(1 + int(np.random.random() * 30))

            # Catch exceptions
            except Exception as e:
                print(env.timestamp)
                print(obs)
                print(env.portfolio_df.iloc[-5:])
                print(env.action_df.iloc[-5:])
                print("Action taken:", action)
                print(env.get_reward(prev_portval))

                print("\nAgent Trade Error:",
                      type(e).__name__ + ' in line ' + str(e.__traceback__.tb_lineno) + ': ' + str(e))

                # Save dataframes for analysis
                self.save_dfs(env, save_dir, init_time)

                # Same guard as the other report sites, so an environment
                # without e-mail settings does not raise from this handler
                if email and hasattr(env, 'email'):
                    env.send_email("Trading error: %s" % env.name, env.parse_error(e))

                break

    # If interrupted, save data and quit
    except KeyboardInterrupt:
        # Save dataframes for analysis
        self.save_dfs(env, save_dir, init_time)

        print("\nKeyboard Interrupt: Stoping cryptotrader" +
              "\nElapsed steps: {0}\nUptime: {1}\nInitial Portval: {2:.08f}\nFinal Portval: {3:.08f}\n".format(
                  self.step,
                  str(pd.to_timedelta(time() - t0, unit='s')),
                  init_portval,
                  env.calc_total_portval()))

    # Catch exceptions
    except Exception as e:
        print("\nAgent Trade Error:",
              type(e).__name__ + ' in line ' + str(e.__traceback__.tb_lineno) + ': ' + str(e))
        # Bare raise keeps the original traceback intact
        raise
class Test_env_step(object):
    """Property-based checks for ``simulate_trade`` and ``step`` of PaperTradingEnvironment."""

    # TODO: CHECK THIS TEST
    @classmethod
    def setup_class(cls):
        """Build and reset the shared environment while the trading module's ``datetime`` is mocked."""
        with mock.patch('cryptotrader.envs.trading.datetime') as mock_datetime:
            # ``now()`` is pinned to the fixture timestamp; construction and
            # ``fromtimestamp`` are forwarded to the real datetime class.
            mock_datetime.now.return_value = datetime.fromtimestamp(index).astimezone(timezone.utc)
            mock_datetime.fromtimestamp = lambda *args, **kw: datetime.fromtimestamp(*args, **kw)
            mock_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)

            cls.env = PaperTradingEnvironment(period=5, obs_steps=10, tapi=tapi, name='env_test')
            cls.env.add_pairs("USDT_BTC", "USDT_ETH")
            cls.env.fiat = "USDT"
            cls.env.reset()
            cls.env.fiat = 100
            cls.env.reset_status()

    @classmethod
    def teardown_class(cls):
        """Remove the ``logs`` directory located under the current working directory."""
        shutil.rmtree(os.path.join(os.path.abspath(os.path.curdir), 'logs'))

    @given(
        arrays(dtype=np.float32,
               shape=(3, ),
               elements=st.floats(allow_nan=False, allow_infinity=False, max_value=1e8, min_value=0)))
    @settings(max_examples=50)
    def test_simulate_trade(self, action):
        """``simulate_trade`` must log the requested position and the matching amounts."""
        # Normalize action vector
        action = array_normalize(action, False)
        # NOTE(review): the tolerance checks in this test are one-sided (no abs());
        # kept as written - confirm whether that is intentional.
        assert action.sum() - Decimal('1.00000000') < Decimal('1E-8'), action.sum() - Decimal('1.00000000')

        # Get timestamp
        timestamp = self.env.obs_df.index[-1]

        # Call method
        self.env.simulate_trade(action, timestamp)

        # Assert position
        # ``.at`` replaces ``DataFrame.get_value``, which pandas removed in 1.0
        for i, symbol in enumerate(self.env.symbols):
            assert self.env.action_df.at[timestamp, symbol] - convert_to.decimal(action[i]) <= Decimal('1E-8')

        # Assert amount
        for symbol in self.env.symbols:
            if symbol not in self.env._fiat:
                last_valid = self.env.portfolio_df[symbol].last_valid_index()
                assert self.env.portfolio_df.at[last_valid, symbol] - \
                    self.env.action_df.at[timestamp, symbol] * self.env.calc_total_portval(timestamp) / \
                    self.env.get_open_price(symbol, timestamp) <= convert_to.decimal('1E-4')

    @mock.patch.object(PaperTradingEnvironment, 'timestamp',
                       floor_datetime(datetime.fromtimestamp(index).astimezone(timezone.utc), 5))
    @given(
        arrays(dtype=np.float32,
               shape=(3, ),
               elements=st.floats(allow_nan=False, allow_infinity=False, max_value=1e8, min_value=0)))
    @settings(max_examples=50)
    def test_step(self, action):
        """``step`` must return a full observation, a finite reward, a bool and a clean status."""
        action = array_softmax(action)

        obs, reward, done, status = self.env.step(action)

        # Assert returned obs
        assert isinstance(obs, pd.DataFrame)
        assert obs.shape[0] == self.env.obs_steps
        assert set(obs.columns.levels[0]) == set(list(self.env.pairs) + [self.env._fiat])

        # Assert reward
        assert isinstance(reward, np.float64)
        # A membership test against (np.nan, np.inf) can never detect NaN,
        # because NaN compares unequal to itself; isfinite rejects NaN and +/-inf.
        assert np.isfinite(reward)

        # Assert done
        assert isinstance(done, bool)

        # Assert status
        assert status == self.env.status
        for key in status:
            assert status[key] == False