def test_adder(self, max_sequence_length):
    """Run a single full-length episode through an ``EpisodeAdder``.

    Builds a trajectory with one observation per index in
    ``range(max_sequence_length)``, then delegates the actual feeding and
    checking to ``self.run_test_adder``, expecting exactly one item: the
    sequence built from those same observations.

    Args:
        max_sequence_length: Passed to the adder as its maximum sequence
            length and also used as the number of observations generated.
    """
    episode_adder = adders.EpisodeAdder(self.client, max_sequence_length)

    # Observations are simply 0, 1, ..., max_sequence_length - 1.
    obs = range(max_sequence_length)
    initial_step, transitions = test_utils.make_trajectory(obs)

    # The whole episode is expected to come out as one item.
    self.run_test_adder(
        adder=episode_adder,
        first=initial_step,
        steps=transitions,
        expected_items=[test_utils.make_sequence(obs)],
    )
def test_nonzero_padding(self, max_sequence_length, padding):
    """Check that a custom ``padding_fn`` fills the padded steps with -1.

    The adder is created with room for ``max_sequence_length + padding``
    steps, but only ``max_sequence_length`` observations are fed in. The
    expected item is the plain sequence followed by ``padding`` steps
    whose fields are all -1 (``False`` / empty tuple for the last two).

    Args:
        max_sequence_length: Number of observations in the trajectory.
        padding: Number of extra padded steps expected after the episode.
    """

    def minus_ones(shape, dtype):
        # Same value the padding steps are expected to hold below.
        return np.zeros(shape, dtype) - 1

    episode_adder = adders.EpisodeAdder(
        self.client,
        max_sequence_length + padding,
        padding_fn=minus_ones,
    )

    # Observations are simply 0, 1, ..., max_sequence_length - 1.
    obs = range(max_sequence_length)
    initial_step, transitions = test_utils.make_trajectory(obs)

    # Expected output: the real steps, then `padding` copies of the padded step.
    padded_step = (-1, -1, -1.0, -1.0, False, ())
    expected = test_utils.make_sequence(obs)
    for _ in range(padding):
        expected.append(padded_step)

    # Signature is derived from the specs of the first (action, step) pair.
    expected_signature = episode_adder.signature(
        *test_utils.get_specs(transitions[0])
    )

    self.run_test_adder(
        adder=episode_adder,
        first=initial_step,
        steps=transitions,
        expected_items=[expected],
        signature=expected_signature,
    )
def test_adder(self, max_sequence_length):
    """Step an ``EpisodeAdder`` through one episode against a fake client.

    Everything except the last transition is added first and the fake
    client's writer state is inspected; then the final transition is added
    and the test checks that the writer was closed, holds
    ``max_sequence_length`` timesteps, and received a single priority entry
    covering the whole sequence.

    NOTE(review): another ``test_adder`` definition is visible earlier in
    this file. If both live in the same class, this later definition
    shadows the earlier one -- confirm they belong to different classes.

    Args:
        max_sequence_length: Passed to the adder as its maximum sequence
            length and also used as the number of observations generated.
    """
    fake_client = test_utils.FakeClient()
    episode_adder = adders.EpisodeAdder(fake_client, max_sequence_length)

    # Observations are simply 0, 1, ..., max_sequence_length - 1.
    obs = range(max_sequence_length)
    initial_step, transitions = test_utils.make_trajectory(obs)

    # Feed the first step and every transition except the final one.
    episode_adder.add_first(initial_step)
    for action, step in transitions[:-1]:
        episode_adder.add(action, step)

    if max_sequence_length != 2:
        # Nothing has been prioritised yet, but every timestep added so far
        # (all but the last one) should already have reached the writer.
        open_writer = fake_client.writers[0]
        self.assertEmpty(open_writer.priorities)
        self.assertLen(open_writer.timesteps, max_sequence_length - 2)
    else:
        # With only an initial and a final step (a single transition),
        # nothing has been written yet.
        self.assertEmpty(fake_client.writers)

    # The terminal timestep should close the writer and insert one
    # prioritised sample referencing all timesteps (including the padded
    # terminating observation).
    final_action, final_step = transitions[-1]
    episode_adder.add(final_action, final_step)

    # The writer should now be closed and hold max_sequence_length timesteps.
    closed_writer = fake_client.writers[0]
    self.assertTrue(closed_writer.closed)
    self.assertLen(closed_writer.timesteps, max_sequence_length)

    # Exactly one priority-table entry is expected: the full sequence with
    # priority 1.0 in the default table.
    expected_entry = (
        base.DEFAULT_PRIORITY_TABLE,
        test_utils.make_sequence(obs),
        1.0,
    )
    self.assertEqual(closed_writer.priorities, [expected_entry])