def test_buffer_truncate():
    """Check that AgentBuffer.truncate caps num_experiences as expected.

    Two fake per-agent buffers are appended into a shared update buffer,
    which is then truncated once without a sequence length and once with
    one. Afterwards every field must still be an AgentBufferField.
    """
    first_agent = construct_fake_buffer(1)
    second_agent = construct_fake_buffer(2)
    update_buffer = AgentBuffer()

    def _append_both_agents():
        # Append order is fixed: agent 1 first, then agent 2.
        for source in (first_agent, second_agent):
            source.resequence_and_append(
                update_buffer, batch_size=None, training_length=2
            )

    # Non-LSTM case: no sequence length is passed to truncate.
    _append_both_agents()
    update_buffer.truncate(2)
    assert update_buffer.num_experiences == 2

    # LSTM case: the kept amount should be a multiple of sequence_length,
    # so asking for 4 with sequence_length=3 is expected to leave 3.
    _append_both_agents()
    update_buffer.truncate(4, sequence_length=3)
    assert update_buffer.num_experiences == 3

    # Truncation must not change the type of the stored fields.
    for field in update_buffer.values():
        assert isinstance(field, AgentBufferField)
def evaluate_batch(self, mini_batch: AgentBuffer) -> RewardSignalResult:
    """
    Returns an all-zero reward for every entry of the given mini batch.
    Intended for evaluating a reward function on data drawn straight
    from a Buffer.

    The number of rewards equals the length of the first field returned
    by mini_batch.values(); the remaining fields are not inspected.

    :param mini_batch: A mapping of field name to per-experience data (the
        format used by our Buffer) as drawn from the update buffer.
    :return: a RewardSignalResult of (scaled intrinsic reward, unscaled
        intrinsic reward), both float32 arrays of zeros. The scaled one is
        multiplied by self.strength.
    """
    # Any field would do for the length; take the first one the batch yields.
    first_field = next(iter(mini_batch.values()))
    num_entries = len(first_field)

    # Each reward gets its own freshly allocated array.
    scaled_reward = self.strength * np.zeros(num_entries, dtype=np.float32)
    unscaled_reward = np.zeros(num_entries, dtype=np.float32)
    return RewardSignalResult(scaled_reward, unscaled_reward)