Пример #1
0
    def test_add_transitions_dtype(self):
        """Sampled fields keep the dtypes declared by the env's spaces.

        One transition (the reset observation plus a randomly sampled
        action) is stored, a batch of size 1 is drawn, and the dtype of each
        sampled field is compared with the corresponding space's dtype.
        """
        env = DummyDiscreteEnv()
        initial_obs = env.reset()
        buffer = SimpleReplayBuffer(env_spec=env,
                                    size_in_transitions=3,
                                    time_horizon=1)
        random_action = env.action_space.sample()
        buffer.add_transitions(observation=[initial_obs],
                               action=[random_action])

        batch = buffer.sample(1)
        obs_batch, act_batch = batch['observation'], batch['action']

        # Each sampled field should carry the dtype of the space it came from.
        assert obs_batch.dtype == env.observation_space.dtype
        assert act_batch.dtype == env.action_space.dtype
Пример #2
0
    def test_pickleable(self):
        """A populated SimpleReplayBuffer survives a pickle round-trip.

        100 single-step transitions are added to a buffer created with
        size_in_transitions=100. The buffer is then pickled and unpickled.
        The copy must expose the same buffer keys, and every stored array
        must match the original in both shape and contents. The
        stored-transition count must also be preserved.
        """
        env = DummyDiscreteEnv()
        obs = env.reset()

        replay_buffer = SimpleReplayBuffer(env_spec=env,
                                           size_in_transitions=100,
                                           time_horizon=1)
        for _ in range(100):
            replay_buffer.add_transitions(observation=[obs], action=[1])
        replay_buffer_pickled = pickle.loads(pickle.dumps(replay_buffer))
        assert replay_buffer_pickled._buffer.keys(
        ) == replay_buffer._buffer.keys()
        for k in replay_buffer._buffer:
            original = replay_buffer._buffer[k]
            restored = replay_buffer_pickled._buffer[k]
            assert restored.shape == original.shape
            # Matching shapes alone would not catch data lost or corrupted
            # during serialization, so compare the stored values as well.
            assert np.array_equal(restored, original)
        assert (replay_buffer_pickled.n_transitions_stored ==
                replay_buffer.n_transitions_stored)
Пример #3
0
    def test_eviction_policy(self):
        """Writes past capacity wrap around and overwrite older entries.

        The buffer holds 3 transitions (time_horizon=1). Actions are added
        two at a time. The buffer is not full after the first pair and is
        full after the second. After four pairs, the stored action array
        equals [[7], [8], [6]] and the stored-transition count stays at 3.
        """
        env = DummyDiscreteEnv()
        first_obs = env.reset()

        buffer = SimpleReplayBuffer(env_spec=env,
                                    size_in_transitions=3,
                                    time_horizon=1)

        def push(actions):
            # Pair every action with the same reset observation.
            buffer.add_transitions(observation=[first_obs] * len(actions),
                                   action=actions)

        push([1, 2])
        assert not buffer.full
        push([3, 4])
        assert buffer.full
        # Further writes go past capacity and replace earlier entries.
        for actions in ([5, 6], [7, 8]):
            push(actions)

        expected_actions = [[7], [8], [6]]
        assert np.array_equal(buffer._buffer['action'], expected_actions)
        assert buffer.n_transitions_stored == 3