def test_buffer_truncate():
    """Exercise ``AgentBuffer.truncate`` with and without a sequence length.

    NOTE(review): a later function in this file is also named
    ``test_buffer_truncate``; that later definition rebinds the name, so
    pytest will not collect this one. Confirm whether this version should be
    renamed or removed.
    """
    processing_buffer = construct_fake_processing_buffer()
    update_buffer = AgentBuffer()

    def fill_update_buffer():
        # Append the fake data for keys 3 then 2 (presumably agent ids —
        # verify against construct_fake_processing_buffer), with
        # training_length=2 and no batch size.
        for key in (3, 2):
            processing_buffer.append_to_update_buffer(
                update_buffer, key, batch_size=None, training_length=2
            )

    fill_update_buffer()
    # Non-LSTM case: truncating to 2 leaves exactly 2 experiences.
    update_buffer.truncate(2)
    assert update_buffer.num_experiences == 2

    fill_update_buffer()
    # LSTM case: truncation should land on a multiple of sequence_length,
    # so asking for 4 with sequence_length=3 leaves 3.
    update_buffer.truncate(4, sequence_length=3)
    assert update_buffer.num_experiences == 3
def test_buffer_truncate():
    """Exercise ``AgentBuffer.truncate`` in plain and sequence-aware modes.

    After truncation, every field still stored in the update buffer must be
    an ``AgentBufferField``.
    """
    # Build agent 1's buffer before agent 2's, then the shared update buffer.
    agent_buffers = [construct_fake_buffer(1), construct_fake_buffer(2)]
    update_buffer = AgentBuffer()

    def refill():
        # Re-sequence each agent's data into chunks of 2 and append it.
        for agent_buffer in agent_buffers:
            agent_buffer.resequence_and_append(
                update_buffer, batch_size=None, training_length=2
            )

    refill()
    # Non-LSTM case: truncating to 2 leaves exactly 2 experiences.
    update_buffer.truncate(2)
    assert update_buffer.num_experiences == 2

    refill()
    # LSTM case: truncation should land on a multiple of sequence_length,
    # so asking for 4 with sequence_length=3 leaves 3.
    update_buffer.truncate(4, sequence_length=3)
    assert update_buffer.num_experiences == 3

    # Truncation must not replace fields with some other container type.
    for field in update_buffer.values():
        assert isinstance(field, AgentBufferField)