Example #1
0
def test_uniform_samples_all_transitions_eventually_one_add():
    """Repeated size-1 draws should eventually hit every stored value.

    A single ``add`` stores the stacked values 0 and 1; 100 independent
    draws of batch size 1 are expected to cover both.
    """
    buffer = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=10)
    buffer.add(_TestTransition(np.array([0, 1])))
    seen_values = {
        buffer.sample(batch_size=1).test_field.item() for _ in range(100)
    }
    assert seen_values == {0, 1}
Example #2
0
def test_uniform_overwrites_old_transitions_multiple_times():
    """Adding many more values than capacity keeps only the newest ones.

    The values 0..98 are added in stacked chunks of three into a buffer of
    capacity 10. The test expects only the last ten values (89..98) to be
    sampleable.
    """
    buf = replay_buffers.UniformReplayBuffer(_test_datapoint_sig, capacity=10)
    for i in range(0, 99, 3):
        buf.add(_TestTransition(np.arange(i, i + 3)))
    # Only 89..98, i.e. range(89, 99), should remain.
    assert set(buf.sample(batch_size=100).test_field) == set(range(89, 99))
Example #3
0
def test_uniform_samples_consistent_transitions():
    """Fields of a sampled batch must stay aligned per transition.

    Every added transition satisfies ``b == a + 1``. If sampling gathered
    ``a`` and ``b`` with different indices, that relation would break.
    """
    buf = replay_buffers.UniformReplayBuffer(_test_pair_datapoint_sig,
                                             capacity=10)
    for i in range(0, 30, 2):
        buf.add(_TestPairTransition(np.array([i]), np.array([i + 1])))
    # 15 transitions were added; with capacity 10 only the last 10
    # (a = 10, 12, ..., 28) should remain.
    batch = buf.sample(batch_size=10)
    # Assert that the two fields are not mixed up between transitions.
    np.testing.assert_array_equal(batch.a + 1, batch.b)
Example #4
0
def test_uniform_overwrites_old_transitions():
    """Adding past capacity evicts the oldest values first."""
    buf = replay_buffers.UniformReplayBuffer(_test_datapoint_sig, capacity=4)
    for chunk in (np.arange(3), np.arange(3, 6)):
        buf.add(_TestTransition(chunk))
    # Six values into four slots: 0 and 1 are expected to be overwritten.
    sampled = buf.sample(batch_size=100).test_field
    assert set(sampled) == {2, 3, 4, 5}
Example #5
0
def test_uniform_oversamples_transitions():
    """A batch larger than the stored count still draws only stored values."""
    buffer = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=10)
    buffer.add(_TestTransition(np.array([0, 1])))
    drawn = set(buffer.sample(batch_size=100).test_field)
    assert drawn == {0, 1}
Example #6
0
def test_uniform_samples_different_transitions():
    """A small batch drawn from 100 distinct values should not be all one value."""
    buf = replay_buffers.UniformReplayBuffer(_test_datapoint_sig, capacity=100)
    buf.add(_TestTransition(np.arange(100)))
    batch = buf.sample(batch_size=3)
    distinct_values = set(batch.test_field)
    assert len(distinct_values) > 1
Example #7
0
def test_uniform_raises_when_sampling_from_an_empty_buffer():
    """Sampling before anything has been added raises ``ValueError``."""
    empty_buf = replay_buffers.UniformReplayBuffer(
        _test_datapoint_sig, capacity=10)
    with pytest.raises(ValueError):
        empty_buf.sample(batch_size=1)
Example #8
0
def test_uniform_samples_added_transition():
    """Sampling one item from a single-item buffer returns that item."""
    buf = replay_buffers.UniformReplayBuffer(_test_datapoint_sig, capacity=10)
    stacked_transitions = _TestTransition(np.array([123]))
    buf.add(stacked_transitions)
    sampled = buf.sample(batch_size=1)
    # Compare the array field explicitly. Plain ``==`` on a transition that
    # holds numpy arrays relies on the truthiness of an elementwise
    # comparison, which only works for single-element arrays and broadcasts.
    np.testing.assert_array_equal(sampled.test_field,
                                  stacked_transitions.test_field)