# NOTE: construct_fake_buffer is a helper from the test module; an
# illustrative sketch of it follows the test below.
from mlagents.trainers.buffer import AgentBuffer, AgentBufferField


def test_buffer_truncate():
    agent_1_buffer = construct_fake_buffer(1)
    agent_2_buffer = construct_fake_buffer(2)
    update_buffer = AgentBuffer()
    agent_1_buffer.resequence_and_append(update_buffer,
                                         batch_size=None,
                                         training_length=2)
    agent_2_buffer.resequence_and_append(update_buffer,
                                         batch_size=None,
                                         training_length=2)
    # Test non-LSTM
    update_buffer.truncate(2)
    assert update_buffer.num_experiences == 2

    agent_1_buffer.resequence_and_append(update_buffer,
                                         batch_size=None,
                                         training_length=2)
    agent_2_buffer.resequence_and_append(update_buffer,
                                         batch_size=None,
                                         training_length=2)
    # Test LSTM: truncate rounds the requested length down to a multiple of
    # sequence_length, so truncating to 4 with sequence_length=3 keeps 3.
    update_buffer.truncate(4, sequence_length=3)
    assert update_buffer.num_experiences == 3
    for buffer_field in update_buffer.values():
        assert isinstance(buffer_field, AgentBufferField)
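

# The test relies on a construct_fake_buffer helper defined in the real test
# module. Below is a minimal, hypothetical sketch of such a helper, assuming a
# recent ML-Agents where AgentBuffer is keyed by BufferKey /
# ObservationKeyPrefix enums; the keys, shapes, and step count here are
# illustrative assumptions, not the actual helper.
import numpy as np

from mlagents.trainers.buffer import BufferKey, ObservationKeyPrefix


def construct_fake_buffer(fake_agent_id: int) -> AgentBuffer:
    """Illustrative stand-in: fill an AgentBuffer with dummy per-agent data."""
    b = AgentBuffer()
    for step in range(15):  # any length above the truncation targets works
        # One observation field, keyed the way recent ML-Agents keys observations.
        b[(ObservationKeyPrefix.OBSERVATION, 0)].append(
            np.full(3, fake_agent_id, dtype=np.float32)
        )
        b[BufferKey.CONTINUOUS_ACTION].append(np.zeros(2, dtype=np.float32))
        b[BufferKey.ENVIRONMENT_REWARDS].append(float(step))
    return b
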
def evaluate_batch(self, mini_batch: AgentBuffer) -> RewardSignalResult:
    """
    Evaluates the reward for the data present in the Dict mini_batch. Use this
    when evaluating a reward function drawn straight from a Buffer.
    :param mini_batch: A Dict of numpy arrays (the format used by our Buffer)
        when drawing from the update buffer.
    :return: a RewardSignalResult of (scaled intrinsic reward, unscaled
        intrinsic reward) provided by the generator.
    """
    # All fields in the mini-batch share the same length, so any one of them
    # gives the number of experiences being evaluated.
    mini_batch_len = len(next(iter(mini_batch.values())))
    # This implementation emits a zero reward for every experience; the scaled
    # reward is the unscaled reward multiplied by self.strength.
    return RewardSignalResult(
        self.strength * np.zeros(mini_batch_len, dtype=np.float32),
        np.zeros(mini_batch_len, dtype=np.float32),
    )
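

# Hypothetical usage sketch: exercising evaluate_batch with a mini-batch drawn
# from an update buffer. sample_mini_batch and the scaled_reward /
# unscaled_reward fields of RewardSignalResult are assumptions about the
# ML-Agents buffer and reward-signal APIs; `reward_signal` stands for an
# instance of the class defining evaluate_batch above.
mini_batch = update_buffer.sample_mini_batch(batch_size=32, sequence_length=1)
result = reward_signal.evaluate_batch(mini_batch)
# For this stub, both reward arrays are all zeros with shape (32,).
assert result.scaled_reward.shape == result.unscaled_reward.shape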