コード例 #1
0
ファイル: test_buffers.py プロジェクト: fokx/ai-traineree
def test_per_buffer_too_few_samples():
    """Sampling yields ``None`` until the buffer holds a full batch."""
    # Assign
    required = 5
    buffer = PERBuffer(required, 10)

    # Act & Assert
    added = 0
    while added < required:
        # Fewer than `required` experiences stored so far -> no batch yet.
        assert buffer.sample_list() is None
        buffer.add(priority=0.1, reward=0.1)
        added += 1

    # Exactly `required` experiences are now stored, so a batch can be drawn.
    assert buffer.sample_list() is not None
コード例 #2
0
ファイル: test_buffers.py プロジェクト: fokx/ai-traineree
def test_per_buffer_reset_alpha():
    """Resetting alpha changes sample weights but not indices or rewards."""
    # Assign
    per_buffer = PERBuffer(10, 10, alpha=0.1)
    for _ in range(30):
        # Keyword order matters: reward is drawn before priority from the RNG.
        per_buffer.add(reward=np.random.randint(0, 1e5), priority=np.random.random())

    # Act
    before = per_buffer.sample_list()
    per_buffer.reset_alpha(0.5)
    after = per_buffer.sample_list()

    # Assert
    assert before is not None and after is not None

    def by_index(exp):
        return exp.index

    pairs = zip(sorted(after, key=by_index), sorted(before, key=by_index))
    for new_sample, old_sample in pairs:
        assert new_sample.index == old_sample.index
        assert new_sample.weight != old_sample.weight
        assert new_sample.reward == old_sample.reward
コード例 #3
0
ファイル: test_buffers.py プロジェクト: fokx/ai-traineree
def test_per_buffer_add_one_sample_one():
    """A lone stored experience comes back intact with the maximal weight."""
    # Assign
    buffer = PERBuffer(1, 20)

    # Act
    buffer.add(priority=0.5, state=range(5))

    # Assert
    samples = buffer.sample_list()
    assert samples is not None
    first = samples[0]
    assert first.state == range(5)
    # With a single experience its weight is the normalisation maximum.
    assert first.weight == 1.
    assert first.index == 0
コード例 #4
0
ファイル: test_buffers.py プロジェクト: fokx/ai-traineree
def test_per_buffer_priority_update():
    """Updating every priority to the same value makes all sample weights 1."""
    # Assign
    batch_size = 5
    buffer_size = 10
    per_buffer = PERBuffer(batch_size, buffer_size)
    # Add twice the capacity so the whole buffer is filled.
    # NOTE(review): randint(10) can yield priority 0 — confirm PERBuffer
    # tolerates zero priorities.
    for _ in range(2 * buffer_size):
        per_buffer.add(priority=np.random.randint(10),
                       state=np.random.random(10))
    # One clearly dominant priority, so the priorities are not all equal.
    per_buffer.add(priority=100, state=np.random.random(10))

    # Act & Assert
    experiences = per_buffer.sample_list(beta=0.5)
    assert experiences is not None
    # Unequal priorities -> at least some weights below the maximum of 1.
    assert sum(exp.weight for exp in experiences) < batch_size

    per_buffer.priority_update(indices=range(buffer_size),
                               priorities=np.ones(buffer_size))
    experiences = per_buffer.sample_list(beta=0.9)
    assert experiences is not None
    weights = [exp.weight for exp in experiences]
    assert sum(weights) == batch_size
    assert all(w == 1 for w in weights)
コード例 #5
0
ファイル: test_buffers.py プロジェクト: fokx/ai-traineree
def test_per_buffer_add_two_sample_two_beta():
    """Two experiences with unequal priorities receive distinct weights."""
    # Assign
    per_buffer = PERBuffer(2, 20)

    # Act
    per_buffer.add(state=range(5), priority=0.9)
    per_buffer.add(state=range(3, 8), priority=0.1)

    # Assert
    experiences = per_buffer.sample_list(beta=0.6)
    assert experiences is not None
    for experience in experiences:
        if experience.index == 0:
            # Higher-priority experience: weight is scaled below the maximum.
            assert experience.state == range(5)
            assert 0.946 < experience.weight < 0.947
        else:
            # Only two experiences were added, so this must be the second one.
            assert experience.index == 1
            # Lower-priority experience carries the maximal (normalised) weight.
            assert experience.state == range(3, 8)
            assert experience.weight == 1.
コード例 #6
0
ファイル: test_buffers.py プロジェクト: fokx/ai-traineree
def test_per_buffer_sample():
    """Sampling returns one full batch of experiences with the expected fields."""
    # Assign
    # PERBuffer's first positional argument is the batch size (as in
    # `PERBuffer(batch_size, buffer_size)`); capacity is left at its default.
    # NOTE(review): assumes the default capacity is >= batch_size — confirm.
    batch_size = 5
    per_buffer = PERBuffer(batch_size)

    # Act
    for priority in range(batch_size):
        state = np.arange(priority, priority + 10)
        # The +0.01 offset avoids a zero priority for the first experience.
        per_buffer.add(priority=priority + 0.01, state=state)

    # Assert
    experiences = per_buffer.sample_list()
    assert experiences is not None
    assert len(experiences) == batch_size
    states, rewards, weights, indices = zip(
        *((exp.state, exp.reward, exp.weight, exp.index) for exp in experiences)
    )
    assert len(weights) == len(indices) == batch_size
    assert all(s is not None for s in states)
    # No reward was passed to `add`, so every sampled reward is None.
    assert all(r is None for r in rewards)