Exemplo n.º 1
0
  def test_adder(self, n_step, additional_discount, first, steps,
                 expected_transitions):
    """Checks NStepTransitionAdder writes over one episode plus a restart.

    Args:
      n_step: Passed through to `adders.NStepTransitionAdder`.
      additional_discount: Passed through to `adders.NStepTransitionAdder`.
      first: Initial timestep given to `adder.add_first`; reused to start a
        second episode at the end of the test.
      steps: Sequence of argument tuples, each unpacked into `adder.add`. The
        last entry is expected to end the episode (the test asserts the
        writer is closed after adding it).
      expected_transitions: Transitions the adder should emit for the first
        episode, in order.
    """
    # Create a fake client to record our writes and use it in the adder.
    client = test_utils.FakeClient()
    adder = adders.NStepTransitionAdder(client, n_step, additional_discount)

    # Add all the data up to the final step.
    adder.add_first(first)
    for step in steps[:-1]:
      adder.add(*step)

    # Make sure the writer has been created but not closed.
    self.assertLen(client.writers, 1)
    self.assertFalse(client.writers[0].closed)

    # Add the final step.
    adder.add(*steps[-1])

    # Ending the episode should close the writer. No new writer should yet have
    # been created as it is constructed lazily.
    self.assertLen(client.writers, 1)
    self.assertTrue(client.writers[0].closed)

    # Make sure our expected and observed transitions match.
    # NOTE(review): p[1][0] is assumed to be the transition payload of each
    # recorded item; confirm against test_utils.FakeWriter.
    observed_transitions = [p[1][0] for p in client.writers[0].priorities]
    # zip() stops at the shorter input, so without this count check a missing
    # or extra transition (or no output at all) would pass unnoticed.
    self.assertLen(observed_transitions, len(expected_transitions))
    for exp, obs in zip(expected_transitions, observed_transitions):
      tree.map_structure(np.testing.assert_array_almost_equal, exp, obs)

    # Add the start of a second trajectory.
    adder.add_first(first)
    adder.add(*steps[0])

    # Make sure this creates an open writer.
    self.assertLen(client.writers, 2)
    self.assertFalse(client.writers[1].closed)
Exemplo n.º 2
0
 def test_adder(self, n_step, additional_discount, first, steps,
                expected_transitions):
     """Builds an NStepTransitionAdder on `self.client` and runs the harness.

     The adder is constructed from `n_step` and `additional_discount`, then
     handed to the inherited `run_test_adder` together with `first`, `steps`
     and `expected_transitions` (forwarded as `expected_items`).
     """
     nstep_adder = adders.NStepTransitionAdder(
         self.client, n_step, additional_discount)
     run_kwargs = dict(
         adder=nstep_adder,
         first=first,
         steps=steps,
         expected_items=expected_transitions,
     )
     super().run_test_adder(**run_kwargs)
Exemplo n.º 3
0
 def test_adder(self, n_step, additional_discount, first, steps,
                expected_transitions):
     """Runs the inherited adder harness against an NStepTransitionAdder.

     The adder writes through `self.client`. Its signature is derived from
     the specs of the first entry in `steps` and passed to `run_test_adder`
     along with `stack_sequence_fields=False`.
     """
     nstep_adder = adders.NStepTransitionAdder(
         self.client, n_step, additional_discount)
     # Signature is computed from the specs of the first step only.
     first_step_specs = test_utils.get_specs(steps[0])
     item_signature = nstep_adder.signature(*first_step_specs)
     super().run_test_adder(
         adder=nstep_adder,
         first=first,
         steps=steps,
         expected_items=expected_transitions,
         stack_sequence_fields=False,
         signature=item_signature,
     )