예제 #1
0
파일: a3c_app.py 프로젝트: yt7589/iching
    def make_env(year=2016):
        """Build the training, validation and test trading environments.

        Args:
            year: Year of bar data to load via ``BarData.load_year_data``.
                Defaults to 2016 (the value previously hard-coded here).
                Pass ``None`` to load from ``AppConfig.STOCKS`` instead,
                which may be a single price file or a directory of them.

        Returns:
            Tuple ``(env, env_val, env_tst)``:
            ``env`` is the training env wrapped in a 1000-step
            ``gym.wrappers.TimeLimit``; ``env_val`` is built from
            ``AppConfig.VAL_STOCKS``; ``env_tst`` is an unwrapped env built
            from the same data as ``env``.

        Raises:
            RuntimeError: if ``year`` is None and ``AppConfig.STOCKS`` is
                neither an existing file nor an existing directory.
        """
        run_name = "yt1"
        saves_path = AppConfig.SAVES_DIR / f"simple-{run_name}"
        # Ensure the save directory exists (side effect of building envs).
        saves_path.mkdir(parents=True, exist_ok=True)

        data_path = pathlib.Path(AppConfig.STOCKS)
        val_path = pathlib.Path(AppConfig.VAL_STOCKS)

        if year is not None or data_path.is_file():
            if year is not None:
                print('load stock data...')
                stock_data = BarData.load_year_data(year)
            else:
                stock_data = {"YNDX": BarData.load_relative(data_path)}
            env = MinuteBarEnv(stock_data,
                               bars_count=AppConfig.BARS_COUNT,
                               volumes=True)
            env_tst = MinuteBarEnv(stock_data,
                                   bars_count=AppConfig.BARS_COUNT,
                                   volumes=True)
        elif data_path.is_dir():
            # volumes=True keeps the observation layout identical to the
            # single-file/year branch and to env_val below.
            env = MinuteBarEnv.from_dir(data_path,
                                        bars_count=AppConfig.BARS_COUNT,
                                        volumes=True)
            env_tst = MinuteBarEnv.from_dir(data_path,
                                            bars_count=AppConfig.BARS_COUNT,
                                            volumes=True)
        else:
            raise RuntimeError("No data to train on")

        # Only the training env gets an episode-length cap.
        env = gym.wrappers.TimeLimit(env, max_episode_steps=1000)
        val_data = {"YNDX": BarData.load_relative(val_path)}
        env_val = MinuteBarEnv(val_data,
                               bars_count=AppConfig.BARS_COUNT,
                               volumes=True)
        return env, env_val, env_tst
예제 #2
0
 def test_State_main(self):
     """Drive a ``State`` through buy, hold and sell, printing each step."""
     print('生成环境状态类')
     instrument = 'data\\YNDX_160101_161231.csv'
     bars = BarData.load_year_data(2016)
     print('stock_data: {0};'.format(bars[instrument]))
     state = State(bars_count=10,
                   commission_perc=0.1,
                   reset_on_close=True,
                   reward_on_close=True,
                   volumes=True)
     state.reset(bars[instrument], offset=AppConfig.BARS_COUNT + 1)
     first_obs = state.encode()
     # NOTE(review): the label says "shape" but the whole observation is
     # printed — confirm whether obs.shape was intended.
     print('initial observation: type:{0}; shape:{1};'.format(
         type(first_obs), first_obs))
     # Buy, then hold (Skip), then sell; after each step the new state is
     # encoded and reported together with the current bar offset.
     for act in (AssetActions.Buy, AssetActions.Skip, AssetActions.Sell):
         reward, done = state.step(action=act)
         self._print_State_step_result(
             reward, done, state.encode(),
             {'instrument': 'YNDX', 'offset': state._offset})
     print('^_^')
     self.assertTrue(1 > 0)
예제 #3
0
 def test_MinuteBarEnv_main(self):
     '''
     Explore the market environment class (MinuteBarEnv).

     Resets the environment, plays a short scripted sequence of actions
     (Buy, Sell, Skip), rendering each step, and prints the last
     observation. Stops early if the environment reports ``done``.
     '''
     year = 2016
     instrument = 'data\\YNDX_160101_161231.csv'
     stock_data = BarData.load_year_data(year)
     print('stock_data: {0};'.format(stock_data[instrument]))
     env = MinuteBarEnv(stock_data,
                        bars_count=AppConfig.BARS_COUNT,
                        volumes=True)
     obs = env.reset()
     # Bound the episode to the scripted actions so the test finishes
     # quickly instead of running until the environment ends the episode.
     scripted_actions = (AssetActions.Buy,
                         AssetActions.Sell,
                         AssetActions.Skip)
     for action in scripted_actions:
         obs, reward, done, info = env.step(action)
         if done:
             break
         env.render(mode='human', obs=obs, reward=reward, info=info)
     print('observation: {0};'.format(obs))
예제 #4
0
 def test_exp(self):
     # Builds the pieces of a DQN pipeline (env capped at 1000 steps,
     # SimpleFFDQN network, epsilon-greedy selector, DQNAgent) and resets
     # the env. This snippet appears truncated: the trailing ''' opens a
     # string literal whose contents are not visible here.
     # NOTE(review): hard-codes "cuda:0"; this fails on machines without a
     # CUDA device.
     device = torch.device("cuda:0")
     year = 2016
     stock_data = BarData.load_year_data(year)
     # NOTE(review): unlike sibling tests, volumes=True is not passed here;
     # MinuteBarEnv's default for `volumes` is not visible — confirm intent.
     env = MinuteBarEnv(
             stock_data, bars_count=AppConfig.BARS_COUNT)
     env = gym.wrappers.TimeLimit(env, max_episode_steps=1000)
     # Network input size comes from the observation space's first dimension;
     # output size is the number of discrete actions.
     net = SimpleFFDQN(env.observation_space.shape[0],
                             env.action_space.n).to(device)
     selector = rll.actions.EpsilonGreedyActionSelector(AppConfig.EPS_START)
     agt = DQNAgent(net, selector, device=device)
     obs = env.reset()
     '''
예제 #5
0
    def test_exp(self):
        """Check that the experience source yields a first transition.

        Builds a volume-aware MinuteBarEnv capped at 1000 steps, a
        SimpleFFDQN network, an epsilon-greedy DQNAgent and an
        ExperienceSourceFirstLast, then pulls one item from the source.
        """
        # Prefer the first CUDA device but fall back to CPU so the test can
        # also run on machines without a GPU.
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        year = 2016
        stock_data = BarData.load_year_data(year)
        env = MinuteBarEnv(stock_data,
                           bars_count=AppConfig.BARS_COUNT,
                           volumes=True)
        env = gym.wrappers.TimeLimit(env, max_episode_steps=1000)
        net = SimpleFFDQN(env.observation_space.shape[0],
                          env.action_space.n).to(device)
        selector = rll.actions.EpsilonGreedyActionSelector(AppConfig.EPS_START)
        agent = DQNAgent(net, selector, device=device)

        exp_source = rll.experience.ExperienceSourceFirstLast(
            env, agent, AppConfig.GAMMA, steps_count=AppConfig.REWARD_STEPS)
        src_itr = iter(exp_source)
        # next() raises StopIteration if the source produces nothing.
        v1 = next(src_itr)
        self.assertIsNotNone(v1)
예제 #6
0
 def test_exp(self):
     """Fill an experience replay buffer and print a sampled batch.

     Builds a volume-aware MinuteBarEnv capped at 1000 steps, a SimpleFFDQN
     network, an epsilon-greedy DQNAgent and an ExperienceSourceFirstLast,
     populates an ExperienceReplayBuffer with 1000 steps and samples 16.
     """
     # Prefer the first CUDA device but fall back to CPU so the test can
     # also run on machines without a GPU.
     device = torch.device(
         "cuda:0" if torch.cuda.is_available() else "cpu")
     year = 2016
     stock_data = BarData.load_year_data(year)
     env = MinuteBarEnv(
         stock_data, bars_count=AppConfig.BARS_COUNT, volumes=True)
     env = gym.wrappers.TimeLimit(env, max_episode_steps=1000)
     net = SimpleFFDQN(env.observation_space.shape[0],
                       env.action_space.n).to(device)
     selector = rll.actions.EpsilonGreedyActionSelector(AppConfig.EPS_START)
     agent = DQNAgent(net, selector, device=device)
     exp_source = rll.experience.ExperienceSourceFirstLast(
         env, agent, AppConfig.GAMMA, steps_count=AppConfig.REWARD_STEPS)
     replay_buffer = ExperienceReplayBuffer(
         exp_source, AppConfig.REPLAY_SIZE)
     replay_buffer.populate(1000)
     batch_size = 16
     X = replay_buffer.sample(batch_size)
     print('X: {0}; {1};'.format(type(X), X))
예제 #7
0
 def from_dir(cls, data_dir, **kwargs):
     """Build an environment from every price file found in *data_dir*.

     Args:
         data_dir: Directory handed to ``BarData.price_files`` to enumerate
             the price files to load.
         **kwargs: Forwarded unchanged to the environment constructor
             (e.g. ``bars_count``, ``volumes``).

     Returns:
         An instance of ``cls`` built from a dict that maps each file
         yielded by ``BarData.price_files`` to ``BarData.load_relative(file)``.
     """
     prices = {
         file: BarData.load_relative(file)
         for file in BarData.price_files(data_dir)
     }
     # Instantiate ``cls`` rather than the hard-coded MinuteBarEnv so that
     # calling from_dir on a subclass returns an instance of that subclass.
     return cls(prices, **kwargs)