def test_replay_buffer(n_episodes, batch_size, maxlen, gamma):
    """Fill a ReplayBuffer with random CartPole rollouts and verify the
    shapes of a sampled mini-batch."""
    env = gym.make('CartPole-v0')
    buffer = ReplayBuffer(maxlen, env, gamma)

    step_count = 0
    for _ in range(n_episodes):
        obs, rew, done = env.reset(), 0.0, False
        while not done:
            act = env.action_space.sample()
            buffer.append(obs, act, rew, done)
            obs, rew, done, _ = env.step(act)
            step_count += 1
        # also store the final (terminal) transition of the episode
        buffer.append(obs, act, rew, done)
        step_count += 1

    assert len(buffer) == maxlen

    obs_shape = env.observation_space.shape
    batch = buffer.sample(batch_size)
    assert len(batch) == batch_size
    assert batch.observations.shape == (batch_size,) + obs_shape
    assert batch.actions.shape == (batch_size, 1)
    assert batch.rewards.shape == (batch_size, 1)
    assert batch.next_observations.shape == (batch_size,) + obs_shape
    assert batch.next_actions.shape == (batch_size, 1)
    assert batch.next_rewards.shape == (batch_size, 1)
    assert batch.terminals.shape == (batch_size, 1)
    assert len(batch.returns) == batch_size
    assert len(batch.consequent_observations) == batch_size
def test_replay_buffer_with_clip_episode(n_episodes, batch_size, maxlen, clip_episode_flag):
    """Verify that appending with ``clip_episode`` cuts the transition
    chain: the clipped transition has no ``next_transition`` and the
    first transition of the following episode has no ``prev_transition``."""
    env = gym.make("CartPole-v0")
    buffer = ReplayBuffer(maxlen, env)

    obs, rew, done = env.reset(), 0.0, False
    clipped = False
    while not clipped:
        act = env.action_space.sample()
        obs, rew, done, _ = env.step(act)
        clipped = done
        # when the flag is set, clip the episode without marking terminal
        if clip_episode_flag and done:
            done = False
        buffer.append(
            observation=obs.astype("f4"),
            action=act,
            reward=rew,
            terminal=done,
            clip_episode=clipped,
        )

    # make a transition for a new episode
    for _ in range(2):
        buffer.append(
            observation=obs.astype("f4"),
            action=act,
            reward=rew,
            terminal=False,
        )

    assert buffer.transitions[-2].terminal != clip_episode_flag
    assert buffer.transitions[-2].next_transition is None
    assert buffer.transitions[-1].prev_transition is None
def test_replay_buffer(n_episodes, batch_size, maxlen):
    """Roll random CartPole episodes into a ReplayBuffer and check the
    sampled batch shapes and numpy array types."""
    env = gym.make("CartPole-v0")
    buffer = ReplayBuffer(maxlen, env)

    n_steps = 0
    for _ in range(n_episodes):
        obs, rew, done = env.reset(), 0.0, False
        while not done:
            act = env.action_space.sample()
            buffer.append(obs.astype("f4"), act, rew, done)
            obs, rew, done, _ = env.step(act)
            n_steps += 1
        # append the terminal step as well
        buffer.append(obs.astype("f4"), act, rew, done)
        n_steps += 1

    assert len(buffer) == maxlen

    obs_shape = env.observation_space.shape
    batch = buffer.sample(batch_size)
    assert len(batch) == batch_size
    assert batch.observations.shape == (batch_size,) + obs_shape
    assert batch.actions.shape == (batch_size,)
    assert batch.rewards.shape == (batch_size, 1)
    assert batch.next_observations.shape == (batch_size,) + obs_shape
    assert batch.next_actions.shape == (batch_size,)
    assert batch.next_rewards.shape == (batch_size, 1)
    assert batch.terminals.shape == (batch_size, 1)
    assert isinstance(batch.observations, np.ndarray)
    assert isinstance(batch.next_observations, np.ndarray)
def test_replay_buffer(n_episodes, batch_size, maxlen, create_mask, mask_size):
    """Exercise ReplayBuffer with mask creation enabled/disabled, its
    static-dataset conversion, and the sampled batch shapes."""
    env = gym.make("CartPole-v0")
    buffer = ReplayBuffer(maxlen, env, create_mask=create_mask, mask_size=mask_size)

    n_steps = 0
    for _ in range(n_episodes):
        obs, rew, done = env.reset(), 0.0, False
        while not done:
            act = env.action_space.sample()
            buffer.append(obs.astype("f4"), act, rew, done)
            obs, rew, done, _ = env.step(act)
            n_steps += 1
        # append the terminal step as well
        buffer.append(obs.astype("f4"), act, rew, done)
        n_steps += 1

    assert len(buffer) == maxlen

    # check static dataset conversion
    dataset = buffer.to_mdp_dataset()
    all_transitions = [t for episode in dataset for t in episode.transitions]
    assert len(all_transitions) >= len(buffer)

    obs_shape = env.observation_space.shape
    batch = buffer.sample(batch_size)
    assert len(batch) == batch_size
    assert batch.observations.shape == (batch_size,) + obs_shape
    assert batch.actions.shape == (batch_size,)
    assert batch.rewards.shape == (batch_size, 1)
    assert batch.next_observations.shape == (batch_size,) + obs_shape
    assert batch.next_actions.shape == (batch_size,)
    assert batch.next_rewards.shape == (batch_size, 1)
    assert batch.terminals.shape == (batch_size, 1)
    assert isinstance(batch.observations, np.ndarray)
    assert isinstance(batch.next_observations, np.ndarray)

    if create_mask:
        assert batch.masks.shape == (mask_size, batch_size, 1)
    else:
        assert batch.masks is None