def _create_replay_buffer_and_insert(env: gym.Env):
    """Roll out a few random-action steps and store each one in a new buffer.

    The environment is seeded with 1, a replay buffer is created from it,
    and the environment is reset. Actions are then sampled from the
    environment's action space and applied one at a time. The rollout ends
    after five steps, or earlier if a step reports a terminal state.

    Args:
        env: Environment to roll out; it is seeded, reset and stepped here.

    Returns:
        A ``(replay_buffer, inserted)`` tuple. ``inserted`` holds one dict
        per executed step, in order, with the keys ``observation`` (the
        observation the action was taken from), ``action``, ``reward`` and
        ``terminal``.
    """
    max_steps = 5
    # The same placeholder log-probability is passed for every transition.
    log_prob = 0.0

    env.seed(1)
    replay_buffer = ReplayBuffer.create_from_env(
        env, replay_memory_size=10, batch_size=1
    )
    insert_transition = make_replay_buffer_inserter(env)

    inserted = []
    obs = env.reset()
    for step_idx in range(max_steps):
        logger.info(f"Iteration: {step_idx}")
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)

        # Keep a plain-Python record of what is handed to the inserter.
        record = {
            "observation": obs,
            "action": action,
            "reward": reward,
            "terminal": terminal,
        }
        inserted.append(record)
        insert_transition(replay_buffer, obs, action, reward, terminal, log_prob)

        obs = next_obs
        if terminal:
            # Episode is over; stop before the step budget is used up.
            break

    return replay_buffer, inserted
# Example no. 2
# 0
 def test_cartpole(self):
     env = gym.make("CartPole-v0")
     replay_buffer = ReplayBuffer.create_from_env(
         env, replay_memory_size=10, batch_size=5
     )
     replay_buffer_inserter = make_replay_buffer_inserter(env)
     obs = env.reset()
     terminal = False
     i = 0
     while not terminal and i < 5:
         action = env.action_space.sample()
         next_obs, reward, terminal, _ = env.step(action)
         replay_buffer_inserter(
             replay_buffer, obs, action, reward, terminal, log_prob=0.0
         )
         obs = next_obs