Example #1
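This test drives random CartPole-v0 rollouts into a replay buffer, then checks that a sampled batch carries the expected shapes, including the return-related fields (returns, consequent_observations) that the gamma argument, presumably the discount factor, enables. All four excerpts use the pre-0.26 gym API, in which env.reset() returns only the observation and env.step() returns a 4-tuple.
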
def test_replay_buffer(n_episodes, batch_size, maxlen, gamma):
    env = gym.make("CartPole-v0")

    buffer = ReplayBuffer(maxlen, env, gamma)

    total_step = 0
    for episode in range(n_episodes):
        observation, reward, terminal = env.reset(), 0.0, False
        while not terminal:
            action = env.action_space.sample()
            buffer.append(observation, action, reward, terminal)
            observation, reward, terminal, _ = env.step(action)
            total_step += 1
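        # flush the episode's final observation so its terminal transition is stored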
        buffer.append(observation, action, reward, terminal)
        total_step += 1

    assert len(buffer) == maxlen

    observation_shape = env.observation_space.shape
    batch = buffer.sample(batch_size)
    assert len(batch) == batch_size
    assert batch.observations.shape == (batch_size,) + observation_shape
    assert batch.actions.shape == (batch_size, 1)
    assert batch.rewards.shape == (batch_size, 1)
    assert batch.next_observations.shape == (batch_size,) + observation_shape
    assert batch.next_actions.shape == (batch_size, 1)
    assert batch.next_rewards.shape == (batch_size, 1)
    assert batch.terminals.shape == (batch_size, 1)
    assert len(batch.returns) == batch_size
    assert len(batch.consequent_observations) == batch_size

Example #2
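This variant exercises explicit episode clipping. When clip_episode_flag is set, the episode boundary is passed through the clip_episode argument while terminal is forced back to False, mimicking a timeout-style truncation; the final assertions check that the transition links are severed at the boundary either way.
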
def test_replay_buffer_with_clip_episode(n_episodes, batch_size, maxlen,
                                         clip_episode_flag):
    env = gym.make("CartPole-v0")

    buffer = ReplayBuffer(maxlen, env)

    observation, reward, terminal = env.reset(), 0.0, False
    clip_episode = False
    while not clip_episode:
        action = env.action_space.sample()
        observation, reward, terminal, _ = env.step(action)
        clip_episode = terminal
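        # with clip_episode_flag, keep the boundary in clip_episode but
        # suppress the terminal flag (timeout-style truncation)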
        if clip_episode_flag and terminal:
            terminal = False
        buffer.append(
            observation=observation.astype("f4"),
            action=action,
            reward=reward,
            terminal=terminal,
            clip_episode=clip_episode,
        )

    # make a transition for a new episode
    for _ in range(2):
        buffer.append(
            observation=observation.astype("f4"),
            action=action,
            reward=reward,
            terminal=False,
        )

    assert buffer.transitions[-2].terminal != clip_episode_flag
    assert buffer.transitions[-2].next_transition is None
    assert buffer.transitions[-1].prev_transition is None

Example #3
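Same rollout loop as Example #1, but without gamma and with observations cast to float32 ("f4") before appending. Actions are stored as scalars here, so batch.actions and batch.next_actions have shape (batch_size,) rather than (batch_size, 1).
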
def test_replay_buffer(n_episodes, batch_size, maxlen):
    env = gym.make("CartPole-v0")

    buffer = ReplayBuffer(maxlen, env)

    total_step = 0
    for episode in range(n_episodes):
        observation, reward, terminal = env.reset(), 0.0, False
        while not terminal:
            action = env.action_space.sample()
            buffer.append(observation.astype("f4"), action, reward, terminal)
            observation, reward, terminal, _ = env.step(action)
            total_step += 1
        buffer.append(observation.astype("f4"), action, reward, terminal)
        total_step += 1

    assert len(buffer) == maxlen

    observation_shape = env.observation_space.shape
    batch = buffer.sample(batch_size)
    assert len(batch) == batch_size
    assert batch.observations.shape == (batch_size,) + observation_shape
    assert batch.actions.shape == (batch_size,)
    assert batch.rewards.shape == (batch_size, 1)
    assert batch.next_observations.shape == (batch_size,) + observation_shape
    assert batch.next_actions.shape == (batch_size,)
    assert batch.next_rewards.shape == (batch_size, 1)
    assert batch.terminals.shape == (batch_size, 1)
    assert isinstance(batch.observations, np.ndarray)
    assert isinstance(batch.next_observations, np.ndarray)

Example #4
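This version adds the create_mask and mask_size options, presumably bootstrap masks for ensemble-style training, and also round-trips the buffer through to_mdp_dataset() to check that the online buffer converts into a static dataset with at least as many transitions.
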
def test_replay_buffer(n_episodes, batch_size, maxlen, create_mask, mask_size):
    env = gym.make("CartPole-v0")

    buffer = ReplayBuffer(maxlen,
                          env,
                          create_mask=create_mask,
                          mask_size=mask_size)

    total_step = 0
    for episode in range(n_episodes):
        observation, reward, terminal = env.reset(), 0.0, False
        while not terminal:
            action = env.action_space.sample()
            buffer.append(observation.astype("f4"), action, reward, terminal)
            observation, reward, terminal, _ = env.step(action)
            total_step += 1
        buffer.append(observation.astype("f4"), action, reward, terminal)
        total_step += 1

    assert len(buffer) == maxlen

    # check static dataset conversion
    dataset = buffer.to_mdp_dataset()
    transitions = []
    for episode in dataset:
        transitions += episode.transitions
    assert len(transitions) >= len(buffer)

    observation_shape = env.observation_space.shape
    batch = buffer.sample(batch_size)
    assert len(batch) == batch_size
    assert batch.observations.shape == (batch_size,) + observation_shape
    assert batch.actions.shape == (batch_size,)
    assert batch.rewards.shape == (batch_size, 1)
    assert batch.next_observations.shape == (batch_size,) + observation_shape
    assert batch.next_actions.shape == (batch_size,)
    assert batch.next_rewards.shape == (batch_size, 1)
    assert batch.terminals.shape == (batch_size, 1)
    assert isinstance(batch.observations, np.ndarray)
    assert isinstance(batch.next_observations, np.ndarray)
    if create_mask:
        assert batch.masks.shape == (mask_size, batch_size, 1)
    else:
        assert batch.masks is None
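
Each test above takes plain arguments (n_episodes, batch_size, maxlen, ...), which pytest would normally supply through @pytest.mark.parametrize decorators that these excerpts omit. Below is a minimal sketch of how Example #1 might be wired up; the parameter values are illustrative assumptions, not taken from the source:

import pytest

# Hypothetical parametrization -- the real suite's values are not shown above.
@pytest.mark.parametrize("n_episodes", [10])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("maxlen", [100])
@pytest.mark.parametrize("gamma", [0.99])
def test_replay_buffer(n_episodes, batch_size, maxlen, gamma):
    ...  # body as in Example #1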