Ejemplo n.º 1
0
def check_stock_trading_env():
    """Smoke-test ``StockTradingEnv`` with random actions, then draw its return curve.

    Runs a single episode in which every action is sampled uniformly from
    [-1, 1), prints the step counter, elapsed time, final reward and the
    environment's ``episode_return``, and finally builds an ElegantRL PPO
    ``Arguments`` object and hands it to ``env.draw_cumulative_return``.

    Relies on the module-level names ``StockTradingEnv``, ``rd`` and ``torch``.
    """
    from time import time

    if_eval = True  # False

    env = StockTradingEnv(if_eval=if_eval)
    action_dim = env.action_dim

    state = env.reset()
    print('state_dim', len(state))

    timer = time()

    # NOTE(review): the counter starts at 1, so the value printed below is one
    # more than the number of env.step() calls -- confirm this is intended.
    step = 1
    while True:
        # Scale rd.rand's [0, 1) samples into [-1, 1).
        action = rd.rand(action_dim) * 2 - 1
        _, reward, done, _ = env.step(action)
        step += 1
        if done:
            break

    print(f"| Random action: step {step}, UsedTime {time() - timer:.3f}")
    print(f"| Random action: terminal reward {reward:.3f}")
    print(f"| Random action: episode return {env.episode_return:.3f}")

    # draw_cumulative_return
    from elegantrl.agent import AgentPPO
    from elegantrl.run import Arguments

    args = Arguments(if_on_policy=True)
    args.agent = AgentPPO()
    args.env = StockTradingEnv(if_eval=True)
    args.if_remove = False
    args.cwd = './AgentPPO/StockTradingEnv-v1_0'
    args.init_before_training()

    env.draw_cumulative_return(args, torch)
Ejemplo n.º 2
0
    def __init__(self,
                 ticker_list,
                 time_interval,
                 drl_lib,
                 agent,
                 cwd,
                 net_dim,
                 state_dim,
                 action_dim,
                 API_KEY,
                 API_SECRET,
                 APCA_API_BASE_URL,
                 tech_indicator_list,
                 turbulence_thresh=30,
                 max_stock=1e2,
                 latency=None):
        """Load a PPO agent, open the Alpaca REST client and reset account state.

        Args:
            ticker_list: Tickers to trade; also used as the index of
                ``self.stocks_df`` and stored as ``self.stockUniverse``.
            time_interval: One of ``'1s'``, ``'5s'``, ``'1Min'``, ``'5Min'``,
                ``'15Min'``; stored in seconds as ``self.time_interval``.
            drl_lib: ``'elegantrl'``, ``'rllib'`` or ``'stable_baselines3'``.
            agent: Agent name; only ``'ppo'`` is accepted.
            cwd: Location the agent is loaded from (ElegantRL ``args.cwd``,
                RLlib checkpoint path, or Stable-Baselines3 model path).
            net_dim: Network width passed to ElegantRL's ``Arguments``;
                not used by the other libraries.
            state_dim: State size placed in the ``StockEnvEmpty`` config
                (ElegantRL and RLlib only).
            action_dim: Action size placed in the ``StockEnvEmpty`` config
                (ElegantRL and RLlib only).
            API_KEY: Alpaca API key.
            API_SECRET: Alpaca API secret.
            APCA_API_BASE_URL: Alpaca endpoint URL.
            tech_indicator_list: Stored unchanged on the instance.
            turbulence_thresh: Stored unchanged on the instance.
            max_stock: Stored unchanged on the instance.
            latency: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If the agent or DRL library is unsupported, the agent
                cannot be loaded, the Alpaca client cannot be created, or the
                time interval is unsupported.
        """
        # --- load agent -----------------------------------------------------
        self.drl_lib = drl_lib
        if agent != 'ppo':
            raise ValueError('Agent input is NOT supported yet.')

        # Library imports stay local so only the selected backend is required.
        if drl_lib == 'elegantrl':
            from elegantrl.agent import AgentPPO
            from elegantrl.run import Arguments, init_agent

            config = {
                'state_dim': state_dim,
                'action_dim': action_dim,
            }
            args = Arguments(agent=AgentPPO, env=StockEnvEmpty(config))
            args.cwd = cwd
            args.net_dim = net_dim
            try:
                agent = init_agent(args, gpu_id=0)
                self.act = agent.act
                self.device = agent.device
            except Exception as err:
                # Exception (not BaseException) so Ctrl-C / SystemExit are
                # not disguised as a load failure; the cause is chained.
                raise ValueError("Fail to load agent!") from err

        elif drl_lib == 'rllib':
            from ray.rllib.agents import ppo
            from ray.rllib.agents.ppo.ppo import PPOTrainer

            config = ppo.DEFAULT_CONFIG.copy()
            config['env'] = StockEnvEmpty
            config["log_level"] = "WARN"
            config['env_config'] = {
                'state_dim': state_dim,
                'action_dim': action_dim,
            }
            trainer = PPOTrainer(env=StockEnvEmpty, config=config)
            # Restore exactly once, inside the guard, so a bad checkpoint is
            # reported as the documented ValueError.
            try:
                trainer.restore(cwd)
                self.agent = trainer
                print("Restoring from checkpoint path", cwd)
            except Exception as err:
                raise ValueError('Fail to load agent!') from err

        elif drl_lib == 'stable_baselines3':
            from stable_baselines3 import PPO

            try:
                self.model = PPO.load(cwd)
                print("Successfully load model", cwd)
            except Exception as err:
                raise ValueError('Fail to load agent!') from err

        else:
            raise ValueError(
                'The DRL library input is NOT supported yet. Please check your input.'
            )

        # --- connect to Alpaca trading API ----------------------------------
        try:
            self.alpaca = tradeapi.REST(API_KEY, API_SECRET, APCA_API_BASE_URL,
                                        'v2')
        except Exception as err:
            raise ValueError(
                'Fail to connect Alpaca. Please check account info and internet connection.'
            ) from err

        # --- read trading time interval (converted to seconds) --------------
        interval_seconds = {
            '1s': 1,
            '5s': 5,
            '1Min': 60,
            '5Min': 60 * 5,
            '15Min': 60 * 15,
        }
        try:
            self.time_interval = interval_seconds[time_interval]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which the original
            # comparison chain also rejected with this ValueError.
            raise ValueError(
                'Time interval input is NOT supported yet.') from None

        # --- read trading settings ------------------------------------------
        self.tech_indicator_list = tech_indicator_list
        self.turbulence_thresh = turbulence_thresh
        self.max_stock = max_stock

        # --- initialize account ---------------------------------------------
        self.stocks = np.asarray([0] * len(ticker_list))  # stocks holding
        self.stocks_cd = np.zeros_like(self.stocks)
        self.cash = None  # cash record
        self.stocks_df = pd.DataFrame(self.stocks,
                                      columns=['stocks'],
                                      index=ticker_list)
        self.asset_list = []
        self.price = np.asarray([0] * len(ticker_list))
        self.stockUniverse = ticker_list
        self.turbulence_bool = 0
        self.equities = []