def test_uniform_samples_all_transitions_eventually_one_add():
    """Both transitions from a single add show up across repeated draws."""
    replay = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=10
    )
    replay.add(_TestTransition(np.array([0, 1])))

    # Draw one transition at a time and record which values were seen.
    seen = {
        replay.sample(batch_size=1).test_field.item() for _ in range(100)
    }

    assert seen == {0, 1}
def test_uniform_overwrites_old_transitions_multiple_times():
    """Only the newest `capacity` values survive after wrapping repeatedly.

    Values 0..98 are added three at a time to a capacity-10 buffer, so the
    test expects exactly the last ten values to be available for sampling.
    """
    buf = replay_buffers.UniformReplayBuffer(_test_datapoint_sig, capacity=10)
    for i in range(0, 99, 3):
        buf.add(_TestTransition(np.arange(i, i + 3)))

    # Only 89..98 (inclusive) should remain. The batch is much larger than
    # the capacity so every surviving value is expected to appear in it.
    remaining = set(buf.sample(batch_size=100).test_field)
    assert remaining == set(range(89, 99))
def test_uniform_samples_consistent_transitions():
    """Fields of a sampled transition come from the same added transition."""
    replay = replay_buffers.UniformReplayBuffer(
        _test_pair_datapoint_sig, capacity=10
    )
    # Fifteen pairs (i, i + 1) for even i in 0..28 go into a capacity-10
    # buffer, so more pairs are added than the buffer can hold.
    for even in range(0, 30, 2):
        pair = _TestPairTransition(np.array([even]), np.array([even + 1]))
        replay.add(pair)

    sampled = replay.sample(batch_size=10)

    # Every pair was built with b == a + 1; a mismatch would mean the two
    # fields got mixed up between transitions.
    np.testing.assert_array_equal(sampled.a + 1, sampled.b)
def test_uniform_overwrites_old_transitions():
    """Adding past capacity is expected to drop the oldest values."""
    replay = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=4
    )
    # Six values in total go into a capacity-4 buffer, in two adds of three.
    for chunk in (np.arange(3), np.arange(3, 6)):
        replay.add(_TestTransition(chunk))

    # 0 and 1 should have been overwritten, leaving 2..5.
    surviving = set(replay.sample(batch_size=100).test_field)
    assert surviving == {2, 3, 4, 5}
def test_uniform_oversamples_transitions():
    """A batch larger than the buffer contents can still be sampled."""
    replay = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=10
    )
    replay.add(_TestTransition(np.array([0, 1])))

    # Two stored values, one hundred requested: the batch must reuse them.
    oversampled = replay.sample(batch_size=100).test_field
    assert set(oversampled) == {0, 1}
def test_uniform_samples_different_transitions():
    """A small batch from a full buffer is not one value repeated."""
    replay = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=100
    )
    replay.add(_TestTransition(np.arange(100)))

    drawn = replay.sample(batch_size=3).test_field
    distinct_values = set(drawn)
    assert len(distinct_values) > 1
def test_uniform_raises_when_sampling_from_an_empty_buffer():
    """Sampling before anything was added must raise ValueError."""
    empty_buffer = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=10
    )

    with pytest.raises(ValueError):
        empty_buffer.sample(batch_size=1)
def test_uniform_samples_added_transition():
    """A buffer holding one transition returns it when sampled."""
    replay = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=10
    )
    added = _TestTransition(np.array([123]))
    replay.add(added)

    sampled = replay.sample(batch_size=1)

    # Compared with ==, sampled batch on the left, as the original test does.
    assert sampled == added