def test_end_of_episode_behavior_set_correctly(
        self, pad_end_of_episode, break_end_of_episode, expected_behavior):
    """Checks how the pad/break flags resolve to an end-of-episode behavior.

    Builds a SequenceAdder on `self.client` with fixed sequence_length=5 and
    period=3. It then asserts that the adder's private
    `_end_of_episode_behavior` attribute equals `expected_behavior`.

    Args:
      pad_end_of_episode: Forwarded to SequenceAdder unchanged.
      break_end_of_episode: Forwarded to SequenceAdder unchanged.
      expected_behavior: Value `_end_of_episode_behavior` must equal.
        (Presumably supplied by a parameterized decorator; confirm.)
    """
    sequence_adder = adders.SequenceAdder(
        self.client,
        sequence_length=5,
        period=3,
        pad_end_of_episode=pad_end_of_episode,
        break_end_of_episode=break_end_of_episode,
    )
    actual_behavior = sequence_adder._end_of_episode_behavior
    self.assertEqual(actual_behavior, expected_behavior)
def test_adder(self, sequence_length: int, period: int, first, steps,
               expected_sequences, pad_end_of_episode: bool = True):
    """Runs the shared `run_test_adder` check against a SequenceAdder.

    Args:
      sequence_length: Forwarded to SequenceAdder.
      period: Forwarded to SequenceAdder.
      first: Passed through to `run_test_adder` as `first`.
      steps: Passed through to `run_test_adder` as `steps`.
      expected_sequences: Passed to `run_test_adder` as `expected_items`.
      pad_end_of_episode: Forwarded to SequenceAdder; defaults to True.
    """
    # Keyword order matches the direct-call form exactly.
    adder_config = {
        'sequence_length': sequence_length,
        'period': period,
        'pad_end_of_episode': pad_end_of_episode,
    }
    sequence_adder = adders.SequenceAdder(self.client, **adder_config)
    super().run_test_adder(
        adder=sequence_adder,
        first=first,
        steps=steps,
        expected_items=expected_sequences,
    )
def test_adder(
        self,
        sequence_length: int,
        period: int,
        first,
        steps,
        expected_sequences,
        end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD,
        repeat_episode_times: int = 1):
    """Runs the shared `run_test_adder` check with an explicit end behavior.

    Builds a SequenceAdder on `self.client` using `end_behavior` as its
    `end_of_episode_behavior`. It then delegates to `run_test_adder`,
    forwarding the episode data, the repeat count, the end behavior and the
    adder's signature.

    Args:
      sequence_length: Forwarded to SequenceAdder.
      period: Forwarded to SequenceAdder.
      first: Passed through to `run_test_adder` as `first`.
      steps: Passed through to `run_test_adder`. `steps[0]` is also used to
        derive the signature, so it must be non-empty.
      expected_sequences: Passed to `run_test_adder` as `expected_items`.
      end_behavior: End-of-episode behavior for the adder. Defaults to
        `adders.EndBehavior.ZERO_PAD`.
      repeat_episode_times: Forwarded to `run_test_adder`. Defaults to 1.
    """
    sequence_adder = adders.SequenceAdder(
        self.client,
        sequence_length=sequence_length,
        period=period,
        end_of_episode_behavior=end_behavior,
    )
    # The signature is built from whatever `get_specs` extracts from the
    # first step.
    first_step_specs = test_utils.get_specs(steps[0])
    adder_signature = sequence_adder.signature(*first_step_specs)
    super().run_test_adder(
        adder=sequence_adder,
        first=first,
        steps=steps,
        expected_items=expected_sequences,
        repeat_episode_times=repeat_episode_times,
        end_behavior=end_behavior,
        signature=adder_signature,
    )
def test_adder(self, sequence_length: int, period: int, first, steps,
               expected_sequences, pad_end_of_episode: bool = True):
    """Drives a SequenceAdder through an episode, checking writers and output.

    Uses a fresh `test_utils.FakeClient`. Verifies four things:
      * a single open writer exists before the final step;
      * that writer is closed after the final step, and no second writer
        exists yet;
      * the sequences recorded on the writer match `expected_sequences`,
        both in count and element-wise;
      * starting a second episode yields a second, open writer.

    Args:
      sequence_length: Forwarded to SequenceAdder.
      period: Forwarded to SequenceAdder.
      first: Passed to `adder.add_first` to start each episode.
      steps: Sequence of argument tuples, each unpacked into `adder.add`.
        The last one ends the episode. Must be non-empty.
      expected_sequences: Sequences expected on the first writer, in order.
      pad_end_of_episode: Forwarded to SequenceAdder; defaults to True.
    """
    client = test_utils.FakeClient()
    adder = adders.SequenceAdder(
        client,
        sequence_length=sequence_length,
        period=period,
        pad_end_of_episode=pad_end_of_episode)

    # Add all the data up to the final step.
    adder.add_first(first)
    for step in steps[:-1]:
        adder.add(*step)

    # Make sure the writer has been created but not closed.
    self.assertLen(client.writers, 1)
    self.assertFalse(client.writers[0].closed)

    # Add the final step.
    adder.add(*steps[-1])

    # Ending the episode should close the writer. No new writer should yet have
    # been created as it is constructed lazily.
    self.assertLen(client.writers, 1)
    self.assertTrue(client.writers[0].closed)

    # Make sure our expected and observed sequences match.
    # NOTE(review): assumes index 1 of each `priorities` entry holds the
    # written sequence; confirm against FakeWriter.
    observed_sequences = [p[1] for p in client.writers[0].priorities]
    # zip() stops at the shorter input, so without this count check a
    # missing or extra sequence (or none at all) would pass silently.
    self.assertEqual(len(expected_sequences), len(observed_sequences))
    for exp, obs in zip(expected_sequences, observed_sequences):
        np.testing.assert_array_equal(exp, obs)

    # Add the start of a second trajectory.
    adder.add_first(first)
    adder.add(*steps[0])

    # Make sure this creates an open writer.
    self.assertLen(client.writers, 2)
    self.assertFalse(client.writers[1].closed)