def test_buffer_sample():
    """Check mini-batch sampling from an update buffer fed by two fake agents.

    Two fake agent buffers are resequenced (training_length=2) into a single
    ``AgentBuffer``. Mini-batches are then drawn with and without sequences,
    and both the key set and the continuous-action shape are verified.

    NOTE(review): another function named ``test_buffer_sample`` is visible in
    this source. If both live in the same module, the later definition rebinds
    the name and only one of them is collected -- confirm which is intended.
    """
    # Build every source buffer before the destination buffer is created.
    agent_buffers = [construct_fake_buffer(1), construct_fake_buffer(2)]
    update_buffer = AgentBuffer()
    for agent_buffer in agent_buffers:
        agent_buffer.resequence_and_append(
            update_buffer, batch_size=None, training_length=2
        )

    # Non-recurrent case: sequence_length=1, four samples requested.
    flat_batch = update_buffer.sample_mini_batch(batch_size=4, sequence_length=1)
    assert flat_batch.keys() == update_buffer.keys()
    flat_actions = np.array(flat_batch[BufferKey.CONTINUOUS_ACTION])
    assert flat_actions.shape == (4, 2)

    # Recurrent case: a long sequence_length is used to make it as likely as
    # possible that a sampled start index would break a sequence.
    sequence_batch = update_buffer.sample_mini_batch(
        batch_size=20, sequence_length=19
    )
    assert sequence_batch.keys() == update_buffer.keys()
    # Only a single 19-step sequence is expected back.
    sequence_actions = np.array(sequence_batch[BufferKey.CONTINUOUS_ACTION])
    assert sequence_actions.shape == (19, 2)
def test_buffer_sample():
    """Check mini-batch sampling from an update buffer fed by a processing buffer.

    Data for agent ids 3 and 2 (in that order) is appended from a fake
    processing buffer into a single ``AgentBuffer`` with training_length=2.
    Mini-batches are then drawn with and without sequences, and both the key
    set and the shape of the ``"action"`` entry are verified.

    NOTE(review): another function named ``test_buffer_sample`` is visible in
    this source. If both live in the same module, the later definition rebinds
    the name and only one of them is collected -- confirm which is intended.
    """
    processing_buffer = construct_fake_processing_buffer()
    update_buffer = AgentBuffer()
    for agent_id in (3, 2):
        processing_buffer.append_to_update_buffer(
            update_buffer, agent_id, batch_size=None, training_length=2
        )

    # Non-recurrent case: sequence_length=1, four samples requested.
    flat_batch = update_buffer.sample_mini_batch(batch_size=4, sequence_length=1)
    assert flat_batch.keys() == update_buffer.keys()
    flat_actions = np.array(flat_batch["action"])
    assert flat_actions.shape == (4, 2)

    # Recurrent case: a long sequence_length is used to make it as likely as
    # possible that a sampled start index would break a sequence.
    sequence_batch = update_buffer.sample_mini_batch(
        batch_size=20, sequence_length=19
    )
    assert sequence_batch.keys() == update_buffer.keys()
    # Only a single 19-step sequence is expected back.
    sequence_actions = np.array(sequence_batch["action"])
    assert sequence_actions.shape == (19, 2)