def check_stock_trading_env():
    if_eval = True  # False

    env = StockTradingEnv(if_eval=if_eval)
    action_dim = env.action_dim

    state = env.reset()
    print('state_dim', len(state))

    from time import time
    timer = time()

    step = 1
    done = False
    reward = None
    while not done:
        action = rd.rand(action_dim) * 2 - 1
        next_state, reward, done, _ = env.step(action)
        # print(';', len(next_state), env.day, reward)
        step += 1

    print(f"| Random action: step {step}, UsedTime {time() - timer:.3f}")
    print(f"| Random action: terminal reward {reward:.3f}")
    print(f"| Random action: episode return {env.episode_return:.3f}")

    '''draw_cumulative_return'''
    from elegantrl.agent import AgentPPO
    from elegantrl.run import Arguments
    args = Arguments(if_on_policy=True)
    args.agent = AgentPPO()
    args.env = StockTradingEnv(if_eval=True)
    args.if_remove = False
    args.cwd = './AgentPPO/StockTradingEnv-v1_0'
    args.init_before_training()

    env.draw_cumulative_return(args, torch)
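
# Usage sketch (assumptions: `rd` is numpy.random, and `torch` and
# `StockTradingEnv` are imported at module level; the check is meant to be
# invoked manually, e.g. from a __main__ guard or an interactive session):
#
#   check_stock_trading_env()
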
def __init__(self, ticker_list, time_interval, drl_lib, agent, cwd, net_dim,
             state_dim, action_dim, API_KEY, API_SECRET, APCA_API_BASE_URL,
             tech_indicator_list, turbulence_thresh=30, max_stock=1e2,
             latency=None):
    # load agent
    self.drl_lib = drl_lib
    if agent == 'ppo':
        if drl_lib == 'elegantrl':
            from elegantrl.agent import AgentPPO
            from elegantrl.run import Arguments, init_agent

            config = {
                'state_dim': state_dim,
                'action_dim': action_dim,
            }
            args = Arguments(agent=AgentPPO, env=StockEnvEmpty(config))
            args.cwd = cwd
            args.net_dim = net_dim
            try:
                agent = init_agent(args, gpu_id=0)
                self.act = agent.act
                self.device = agent.device
            except BaseException:
                raise ValueError("Fail to load agent!")

        elif drl_lib == 'rllib':
            from ray.rllib.agents import ppo
            from ray.rllib.agents.ppo.ppo import PPOTrainer

            config = ppo.DEFAULT_CONFIG.copy()
            config['env'] = StockEnvEmpty
            config["log_level"] = "WARN"
            config['env_config'] = {
                'state_dim': state_dim,
                'action_dim': action_dim,
            }
            trainer = PPOTrainer(env=StockEnvEmpty, config=config)
            try:
                trainer.restore(cwd)
                self.agent = trainer
                print("Restoring from checkpoint path", cwd)
            except:
                raise ValueError('Fail to load agent!')

        elif drl_lib == 'stable_baselines3':
            from stable_baselines3 import PPO

            try:
                # load agent
                self.model = PPO.load(cwd)
                print("Successfully load model", cwd)
            except:
                raise ValueError('Fail to load agent!')

        else:
            raise ValueError(
                'The DRL library input is NOT supported yet. Please check your input.'
            )
    else:
        raise ValueError('Agent input is NOT supported yet.')

    # connect to Alpaca trading API
    try:
        self.alpaca = tradeapi.REST(API_KEY, API_SECRET, APCA_API_BASE_URL, 'v2')
    except:
        raise ValueError(
            'Fail to connect Alpaca. Please check account info and internet connection.'
        )

    # read trading time interval
    if time_interval == '1s':
        self.time_interval = 1
    elif time_interval == '5s':
        self.time_interval = 5
    elif time_interval == '1Min':
        self.time_interval = 60
    elif time_interval == '5Min':
        self.time_interval = 60 * 5
    elif time_interval == '15Min':
        self.time_interval = 60 * 15
    else:
        raise ValueError('Time interval input is NOT supported yet.')

    # read trading settings
    self.tech_indicator_list = tech_indicator_list
    self.turbulence_thresh = turbulence_thresh
    self.max_stock = max_stock

    # initialize account
    self.stocks = np.asarray([0] * len(ticker_list))  # stock holdings
    self.stocks_cd = np.zeros_like(self.stocks)
    self.cash = None  # cash record
    self.stocks_df = pd.DataFrame(self.stocks, columns=['stocks'], index=ticker_list)
    self.asset_list = []
    self.price = np.asarray([0] * len(ticker_list))
    self.stockUniverse = ticker_list
    self.turbulence_bool = 0
    self.equities = []
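
# Hedged construction sketch (assumptions: the enclosing class is named
# AlpacaPaperTrading only for illustration, and DOW_30_TICKER, INDICATORS,
# state_dim, action_dim, and the Alpaca credentials are defined elsewhere):
#
#   paper_trader = AlpacaPaperTrading(
#       ticker_list=DOW_30_TICKER,
#       time_interval='1Min',
#       drl_lib='elegantrl',
#       agent='ppo',
#       cwd='./papertrading_erl',
#       net_dim=512,
#       state_dim=state_dim,
#       action_dim=action_dim,
#       API_KEY=API_KEY,
#       API_SECRET=API_SECRET,
#       APCA_API_BASE_URL='https://paper-api.alpaca.markets',
#       tech_indicator_list=INDICATORS,
#       turbulence_thresh=30,
#       max_stock=1e2,
#   )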