示例#1
0
def test_replay_buffer_from_state_wrong_type():
    """Restoring a ReplayBuffer from a state carrying a foreign type tag raises ValueError."""
    # Arrange: snapshot a real ReplayBuffer, then overwrite its type tag.
    state = ReplayBuffer(batch_size=5, buffer_size=20).get_state()
    state.type = "WrongType"

    # Act & Assert: the type mismatch must surface as a ValueError.
    with pytest.raises(ValueError):
        ReplayBuffer.from_state(state=state)
 def from_state(state: BufferState) -> BufferBase:
     """Rebuild a buffer using the concrete class whose ``type`` matches ``state.type``.

     Candidates are checked in a fixed order: ReplayBuffer, PERBuffer,
     NStepBuffer, RolloutBuffer. The first class whose ``type`` attribute
     equals ``state.type`` performs the restore through its own ``from_state``.

     Args:
         state: Serialized buffer state; only ``state.type`` is inspected here.

     Returns:
         Whatever the matching class's ``from_state`` returns.

     Raises:
         ValueError: If none of the known buffer classes matches ``state.type``.
     """
     # NOTE(review): no ``self``/``cls`` parameter -- presumably a @staticmethod
     # whose decorator is outside this view; confirm.
     for buffer_cls in (ReplayBuffer, PERBuffer, NStepBuffer, RolloutBuffer):
         if state.type == buffer_cls.type:
             return buffer_cls.from_state(state)
     raise ValueError(f"Buffer state contains unsupported buffer type: '{state.type}'")
示例#3
0
def test_replay_buffer_from_state_without_data():
    """An empty ReplayBuffer round-trips through get_state/from_state unchanged."""
    # Arrange: nothing is added to the buffer before it is snapshotted.
    original = ReplayBuffer(batch_size=5, buffer_size=20)
    snapshot = original.get_state()

    # Act
    restored = ReplayBuffer.from_state(state=snapshot)

    # Assert: equal to the source, and each field mirrors the snapshot.
    assert restored == original
    assert restored.buffer_size == snapshot.buffer_size
    assert restored.batch_size == snapshot.batch_size
    assert restored.data == []
示例#4
0
def test_replay_buffer_from_state_with_data():
    """A populated ReplayBuffer round-trips through get_state/from_state."""
    # Arrange: populate_buffer(..., 30) presumably adds 30 samples to a buffer
    # whose buffer_size is 20, so the snapshot comes from an over-filled buffer.
    buffer = ReplayBuffer(batch_size=5, buffer_size=20)
    buffer = populate_buffer(buffer, 30)
    state = buffer.get_state()

    # Act
    new_buffer = ReplayBuffer.from_state(state=state)

    # Assert
    assert new_buffer == buffer
    assert new_buffer.buffer_size == state.buffer_size
    assert new_buffer.batch_size == state.batch_size
    assert new_buffer.data == state.data
    # Both the restored buffer and the source must be capped at buffer_size;
    # previously only the source buffer's length was checked.
    assert len(new_buffer.data) == state.buffer_size
    assert len(buffer.data) == state.buffer_size