Esempio n. 1
0
def test_per_from_state_wrong_batch_size():
    """Restoring an NStepBuffer from a state with an altered batch size must fail.

    A state is taken from a freshly built buffer, its ``batch_size`` is
    overwritten with 5, and ``NStepBuffer.from_state`` is expected to raise
    ``ValueError`` when given that state.
    """
    # Arrange: snapshot a new buffer and overwrite the batch size on the state
    tampered_state = NStepBuffer(n_steps=5, gamma=1.0).get_state()
    tampered_state.batch_size = 5

    # Act / Assert: the restore attempt has to be rejected
    with pytest.raises(ValueError):
        NStepBuffer.from_state(state=tampered_state)
Esempio n. 2
0
def test_per_from_state_wrong_type():
    """Restoring an NStepBuffer from a state with a foreign type tag must fail.

    The ``type`` field of a state taken from a fresh buffer is replaced with
    ``"WrongType"``; ``NStepBuffer.from_state`` is expected to raise
    ``ValueError`` for it.
    """
    # Arrange: snapshot a new buffer and overwrite the type tag on the state
    mislabeled_state = NStepBuffer(n_steps=5, gamma=1.0).get_state()
    mislabeled_state.type = "WrongType"

    # Act / Assert: the restore attempt has to be rejected
    with pytest.raises(ValueError):
        NStepBuffer.from_state(state=mislabeled_state)
Esempio n. 3
0
def test_nstep_buffer_get_state_without_data():
    """The state of a never-filled NStepBuffer carries its settings and no data."""
    # Arrange
    n_steps = 5
    empty_buffer = NStepBuffer(n_steps=n_steps, gamma=1.0)

    # Act
    snapshot = empty_buffer.get_state()

    # Assert: type tag and sizes are reported, the data slot stays unset
    assert snapshot.type == NStepBuffer.type
    assert snapshot.buffer_size == n_steps
    assert snapshot.batch_size == 1
    assert snapshot.data is None
Esempio n. 4
0
def test_nstep_buffer_get_state_with_data():
    """The state of a populated NStepBuffer holds exactly ``buffer_size`` entries.

    The buffer is built with ``n_steps=5`` and populated through
    ``populate_buffer(buffer, 10)`` before its state is taken.
    """
    # Arrange
    n_steps = 5
    filled_buffer = NStepBuffer(n_steps=n_steps, gamma=1.0)
    populate_buffer(filled_buffer, 10)  # fills the buffer in place

    # Act
    snapshot = filled_buffer.get_state()

    # Assert: the amount of stored data matches the reported buffer size
    assert snapshot.type == NStepBuffer.type
    assert snapshot.buffer_size == n_steps
    assert snapshot.batch_size == 1
    assert len(snapshot.data) == snapshot.buffer_size
def test_factory_nstep_buffer_from_state_without_data():
    """BufferFactory rebuilds an empty NStepBuffer from its state.

    The restored buffer must report the NStepBuffer type, the original gamma,
    the same buffer and batch sizes as both the state and the source buffer,
    and must contain no data.
    """
    # Arrange
    n_steps, discount = 5, 0.9
    source_buffer = NStepBuffer(n_steps=n_steps, gamma=discount)
    snapshot = source_buffer.get_state()

    # Act
    restored = BufferFactory.from_state(snapshot)

    # Assert: configuration survives the round trip
    assert restored.type == NStepBuffer.type
    assert restored.gamma == discount
    assert restored.buffer_size == snapshot.buffer_size == source_buffer.n_steps
    assert restored.batch_size == snapshot.batch_size == source_buffer.batch_size == 1
    # Nothing was added to the source buffer, so nothing may be restored
    assert len(restored.data) == 0
def test_factory_nstep_buffer_from_state_with_data():
    """BufferFactory rebuilds an NStepBuffer together with its stored samples.

    The buffer is first populated through ``populate_buffer`` and then fed the
    samples produced by ``generate_sample_SARS(buffer_size, dict_type=True)``.
    After a round trip through ``get_state`` and ``BufferFactory.from_state``
    the restored buffer must report matching sizes and contain every one of
    those last samples as an ``Experience``.
    """
    # Assign
    buffer_size = 5
    buffer = NStepBuffer(n_steps=buffer_size, gamma=1.)
    # populate_buffer fills the buffer in place, so its return value is not needed.
    populate_buffer(buffer, 10)
    # Materialise the samples: they are added to the buffer now and replayed
    # again in the assertions at the end.
    last_samples = list(generate_sample_SARS(buffer_size, dict_type=True))
    for sample in last_samples:
        buffer.add(**sample)
    state = buffer.get_state()

    # Act
    new_buffer = BufferFactory.from_state(state)

    # Assert
    assert new_buffer.type == NStepBuffer.type
    assert new_buffer.buffer_size == state.buffer_size == buffer.n_steps
    assert new_buffer.batch_size == state.batch_size == buffer.batch_size == 1
    assert len(new_buffer.data) == state.buffer_size

    # Every sample added last must be present in the restored buffer.
    for sample in last_samples:
        assert Experience(**sample) in new_buffer.data