Ejemplo n.º 1
0
def test_per_from_state_wrong_batch_size():
    """``NStepBuffer.from_state`` must raise ``ValueError`` for a state whose batch size was set to 5."""
    # Arrange: take a genuine state, then overwrite its batch size.
    # NOTE(review): presumably rejected because an n-step buffer expects batch_size == 1 -- confirm.
    source_buffer = NStepBuffer(n_steps=5, gamma=1.0)
    tampered_state = source_buffer.get_state()
    tampered_state.batch_size = 5

    # Act / Assert: restoring from the tampered state has to fail.
    with pytest.raises(ValueError):
        NStepBuffer.from_state(state=tampered_state)
Ejemplo n.º 2
0
def test_per_from_state_wrong_type():
    """``NStepBuffer.from_state`` must raise ``ValueError`` when the state's ``type`` is not its own."""
    # Arrange: take a genuine state, then relabel it with a bogus type.
    source_buffer = NStepBuffer(n_steps=5, gamma=1.0)
    mislabeled_state = source_buffer.get_state()
    mislabeled_state.type = "WrongType"

    # Act / Assert: restoring from the mislabeled state has to fail.
    with pytest.raises(ValueError):
        NStepBuffer.from_state(state=mislabeled_state)
Ejemplo n.º 3
0
 def from_state(state: BufferState) -> BufferBase:
     """Rebuild a buffer from ``state`` by delegating to the class whose ``type`` matches.

     Args:
         state: Buffer state; its ``type`` attribute selects the buffer class.

     Returns:
         The result of the matching class's own ``from_state(state)``.

     Raises:
         ValueError: If ``state.type`` matches none of the supported buffer classes.
     """
     # Order matters: candidates are checked in sequence and the first match wins.
     supported_buffers = (ReplayBuffer, PERBuffer, NStepBuffer, RolloutBuffer)
     for buffer_cls in supported_buffers:
         if state.type == buffer_cls.type:
             return buffer_cls.from_state(state)
     raise ValueError(f"Buffer state contains unsupported buffer type: '{state.type}'")
Ejemplo n.º 4
0
def test_nstep_buffer_from_state_without_data():
    """An empty ``NStepBuffer`` restored via ``from_state`` keeps its type, gamma and sizes."""
    # Arrange
    n_steps = 5
    discount = 0.9
    original = NStepBuffer(n_steps=n_steps, gamma=discount)
    snapshot = original.get_state()

    # Act
    restored = NStepBuffer.from_state(snapshot)

    # Assert: configuration matches across restored buffer, snapshot and original.
    assert restored.type == NStepBuffer.type
    assert restored.gamma == discount
    assert restored.buffer_size == snapshot.buffer_size == original.n_steps
    assert restored.batch_size == snapshot.batch_size == original.batch_size == 1
    # Nothing was added before the snapshot, so nothing should come back.
    assert len(restored.data) == 0
Ejemplo n.º 5
0
def test_nstep_buffer_from_state_with_data():
    """A populated ``NStepBuffer`` restored via ``from_state`` contains the most recently added samples."""
    # Arrange: pre-fill the buffer, then add one more full window of known samples.
    n_steps = 5
    original = NStepBuffer(n_steps=n_steps, gamma=1.0)
    original = populate_buffer(original, 10)  # in-place
    last_samples = list(generate_sample_SARS(n_steps, dict_type=True))
    for sars in last_samples:
        original.add(**sars)
    snapshot = original.get_state()

    # Act
    restored = NStepBuffer.from_state(snapshot)

    # Assert: configuration matches across restored buffer, snapshot and original.
    assert restored.type == NStepBuffer.type
    assert restored.buffer_size == snapshot.buffer_size == original.n_steps
    assert restored.batch_size == snapshot.batch_size == original.batch_size == 1
    assert len(restored.data) == snapshot.buffer_size

    # Each of the known samples must be present in the restored data.
    for sars in last_samples:
        assert Experience(**sars) in restored.data