Ejemplo n.º 1
0
def test_rollout_buffer_from_state_wrong_type():
    """Check that `RolloutBuffer.from_state` raises ValueError for a mismatched state type."""
    # Arrange: snapshot a fresh buffer, then overwrite the type tag on the state.
    tampered_state = RolloutBuffer(batch_size=5, buffer_size=20).get_state()
    tampered_state.type = "WrongType"

    # Act / Assert: restoring from the tampered state is expected to fail.
    with pytest.raises(ValueError):
        RolloutBuffer.from_state(state=tampered_state)
Ejemplo n.º 2
0
 def from_state(state: BufferState) -> BufferBase:
     """Rebuild a buffer from `state` by dispatching on `state.type`.

     The candidate buffer classes are checked in a fixed order and the first
     one whose `type` attribute equals `state.type` restores the buffer via
     its own `from_state`.

     Raises:
         ValueError: if `state.type` matches none of the known buffer classes.
     """
     # Order mirrors the lookup priority: Replay, PER, NStep, Rollout.
     known_buffers = (ReplayBuffer, PERBuffer, NStepBuffer, RolloutBuffer)
     for buffer_cls in known_buffers:
         if state.type == buffer_cls.type:
             return buffer_cls.from_state(state)
     raise ValueError(f"Buffer state contains unsupported buffer type: '{state.type}'")
Ejemplo n.º 3
0
def test_rollout_buffer_from_state_without_data():
    """Check that an empty RolloutBuffer round-trips through get_state / from_state."""
    # Arrange: a freshly constructed buffer with nothing added to it.
    original = RolloutBuffer(batch_size=5, buffer_size=20)
    snapshot = original.get_state()

    # Act: rebuild a buffer from the captured state.
    restored = RolloutBuffer.from_state(state=snapshot)

    # Assert: the restored buffer equals the original and mirrors the state's sizes.
    assert restored == original
    assert restored.buffer_size == snapshot.buffer_size
    assert restored.batch_size == snapshot.batch_size
    assert len(restored.data) == 0