def test_per_from_state_wrong_batch_size():
    """NStepBuffer.from_state raises ValueError once the state's batch_size is set to 5.

    NOTE(review): the test name says "per" but the body only exercises
    NStepBuffer -- confirm whether the name is a copy-paste leftover.
    """
    # Assign
    tampered_state = NStepBuffer(n_steps=5, gamma=1.0).get_state()
    tampered_state.batch_size = 5

    # Act
    with pytest.raises(ValueError):
        NStepBuffer.from_state(state=tampered_state)
def test_per_from_state_wrong_type():
    """NStepBuffer.from_state raises ValueError once the state's type is set to "WrongType".

    NOTE(review): the test name says "per" but the body only exercises
    NStepBuffer -- confirm whether the name is a copy-paste leftover.
    """
    # Assign
    tampered_state = NStepBuffer(n_steps=5, gamma=1.0).get_state()
    tampered_state.type = "WrongType"

    # Act
    with pytest.raises(ValueError):
        NStepBuffer.from_state(state=tampered_state)
def from_state(state: BufferState) -> BufferBase:
    """Rebuild a buffer from ``state`` by dispatching on ``state.type``.

    The state's type is compared, in order, against the ``type`` attribute of
    ReplayBuffer, PERBuffer, NStepBuffer and RolloutBuffer; the first class
    that matches restores the buffer through its own ``from_state``.

    Args:
        state: Buffer state whose ``type`` selects the buffer class.

    Returns:
        Whatever the matching class's ``from_state(state)`` returns.

    Raises:
        ValueError: If ``state.type`` matches none of the known buffer classes.
    """
    # Order mirrors the precedence in which the types are checked.
    candidates = (ReplayBuffer, PERBuffer, NStepBuffer, RolloutBuffer)
    for buffer_cls in candidates:
        if state.type == buffer_cls.type:
            return buffer_cls.from_state(state)

    raise ValueError(f"Buffer state contains unsupported buffer type: '{state.type}'")
def test_nstep_buffer_from_state_without_data():
    """A state taken from a freshly created NStepBuffer restores to an equivalent empty buffer."""
    # Assign
    n_steps = 5
    discount = 0.9
    source = NStepBuffer(n_steps=n_steps, gamma=discount)
    snapshot = source.get_state()

    # Act
    restored = NStepBuffer.from_state(snapshot)

    # Assert
    assert restored.type == NStepBuffer.type
    assert restored.gamma == discount
    assert restored.buffer_size == snapshot.buffer_size == source.n_steps
    assert restored.batch_size == snapshot.batch_size == source.batch_size == 1
    # Nothing was added to the source buffer, so nothing should be restored.
    assert len(restored.data) == 0
def test_nstep_buffer_from_state_with_data():
    """A state taken from a populated NStepBuffer restores the most recently added samples."""
    # Assign
    n_steps = 5
    source = NStepBuffer(n_steps=n_steps, gamma=1.0)
    source = populate_buffer(source, 10)  # in-place
    # Add n_steps more samples after the initial population and keep them
    # so their presence can be checked on the restored buffer.
    newest = list(generate_sample_SARS(n_steps, dict_type=True))
    for sars in newest:
        source.add(**sars)
    snapshot = source.get_state()

    # Act
    restored = NStepBuffer.from_state(snapshot)

    # Assert
    assert restored.type == NStepBuffer.type
    assert restored.buffer_size == snapshot.buffer_size == source.n_steps
    assert restored.batch_size == snapshot.batch_size == source.batch_size == 1
    assert len(restored.data) == snapshot.buffer_size
    for sars in newest:
        assert Experience(**sars) in restored.data