示例#1
0
 def test_end_of_episode_behavior_set_correctly(self, pad_end_of_episode,
                                                break_end_of_episode,
                                                expected_behavior):
     """Checks how the pad/break flags resolve into an end-of-episode behavior.

     Args:
       pad_end_of_episode: Forwarded to the SequenceAdder unchanged.
       break_end_of_episode: Forwarded to the SequenceAdder unchanged.
       expected_behavior: Value the adder's private `_end_of_episode_behavior`
         attribute must equal after construction.
     """
     # sequence_length/period are held fixed; only the two flags vary between
     # parameterizations of this test.
     constructor_kwargs = dict(
         sequence_length=5,
         period=3,
         pad_end_of_episode=pad_end_of_episode,
         break_end_of_episode=break_end_of_episode,
     )
     sequence_adder = adders.SequenceAdder(self.client, **constructor_kwargs)
     resolved_behavior = sequence_adder._end_of_episode_behavior
     self.assertEqual(resolved_behavior, expected_behavior)
示例#2
0
 def test_adder(self, sequence_length: int, period: int, first, steps,
                expected_sequences, pad_end_of_episode: bool = True):
   """Builds a SequenceAdder on `self.client` and runs the shared adder test.

   Args:
     sequence_length: Forwarded to the SequenceAdder.
     period: Forwarded to the SequenceAdder.
     first: Passed through to `run_test_adder`.
     steps: Passed through to `run_test_adder`.
     expected_sequences: Passed to `run_test_adder` as `expected_items`.
     pad_end_of_episode: Forwarded to the SequenceAdder (defaults to True).
   """
   adder_under_test = adders.SequenceAdder(
       self.client,
       sequence_length=sequence_length,
       period=period,
       pad_end_of_episode=pad_end_of_episode)
   # The actual checks live in the base class helper.
   run_shared_test = super().run_test_adder
   run_shared_test(
       adder=adder_under_test,
       first=first,
       steps=steps,
       expected_items=expected_sequences)
示例#3
0
 def test_adder(
         self,
         sequence_length: int,
         period: int,
         first,
         steps,
         expected_sequences,
         end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD,
         repeat_episode_times: int = 1):
     """Runs the shared adder test for a SequenceAdder with a given end behavior.

     Args:
       sequence_length: Forwarded to the SequenceAdder.
       period: Forwarded to the SequenceAdder.
       first: Passed through to `run_test_adder`.
       steps: Passed through to `run_test_adder`; `steps[0]` is also used to
         derive the specs from which the adder's signature is built.
       expected_sequences: Passed to `run_test_adder` as `expected_items`.
       end_behavior: Used both as the adder's `end_of_episode_behavior` and as
         the `end_behavior` argument of `run_test_adder`.
       repeat_episode_times: Passed through to `run_test_adder`.
     """
     adder = adders.SequenceAdder(self.client,
                                  sequence_length=sequence_length,
                                  period=period,
                                  end_of_episode_behavior=end_behavior)

     # Build the adder's signature from the specs of the first step.
     first_step_specs = test_utils.get_specs(steps[0])
     adder_signature = adder.signature(*first_step_specs)

     shared_test_kwargs = dict(
         adder=adder,
         first=first,
         steps=steps,
         expected_items=expected_sequences,
         repeat_episode_times=repeat_episode_times,
         end_behavior=end_behavior,
         signature=adder_signature,
     )
     super().run_test_adder(**shared_test_kwargs)
示例#4
0
  def test_adder(self, sequence_length: int, period: int, first, steps,
                 expected_sequences, pad_end_of_episode: bool = True):
    """Checks writer lifecycle and emitted sequences for a SequenceAdder.

    Feeds `first` and `steps` through a SequenceAdder backed by a
    `test_utils.FakeClient` and verifies that:
      * a single writer is created and stays open until the final step,
      * adding the final step closes that writer without creating a new one,
      * the writer recorded exactly `expected_sequences` (count and content),
      * starting a second episode creates a fresh, open writer.

    Args:
      sequence_length: Forwarded to the SequenceAdder.
      period: Forwarded to the SequenceAdder.
      first: Passed to `adder.add_first` at the start of each episode.
      steps: Argument tuples, each unpacked into `adder.add`; adding the last
        one is expected to end the episode.
      expected_sequences: Sequences the first writer should have recorded, in
        order.
      pad_end_of_episode: Forwarded to the SequenceAdder (defaults to True).
    """
    client = test_utils.FakeClient()
    adder = adders.SequenceAdder(
        client,
        sequence_length=sequence_length,
        period=period,
        pad_end_of_episode=pad_end_of_episode)

    # Add all the data up to the final step.
    adder.add_first(first)
    for step in steps[:-1]:
      adder.add(*step)

    # Make sure the writer has been created but not closed.
    self.assertLen(client.writers, 1)
    self.assertFalse(client.writers[0].closed)

    # Add the final step.
    adder.add(*steps[-1])

    # Ending the episode should close the writer. No new writer should yet have
    # been created as it is constructed lazily.
    self.assertLen(client.writers, 1)
    self.assertTrue(client.writers[0].closed)

    # Make sure our expected and observed sequences match.
    # NOTE(review): assumes element 1 of each recorded priority entry holds the
    # item's data — confirm against test_utils.FakeWriter.
    observed_sequences = [p[1] for p in client.writers[0].priorities]
    expected_sequences = list(expected_sequences)
    # `zip` stops at the shorter input, so without this count check missing or
    # extra sequences would go unnoticed.
    self.assertLen(observed_sequences, len(expected_sequences))
    for exp, obs in zip(expected_sequences, observed_sequences):
      np.testing.assert_array_equal(exp, obs)

    # Add the start of a second trajectory.
    adder.add_first(first)
    adder.add(*steps[0])

    # Make sure this creates an open writer.
    self.assertLen(client.writers, 2)
    self.assertFalse(client.writers[1].closed)