예제 #1
0
    def test_adder(self, max_sequence_length):
        """Adds one episode and checks it is stored as a single item.

        The trajectory is built from `range(max_sequence_length)` observations.
        The single expected item is `test_utils.make_sequence` over those same
        observations.
        """
        episode_adder = adders.EpisodeAdder(self.client, max_sequence_length)

        # One observation per index. make_trajectory returns the first
        # timestep and the list of subsequent steps.
        obs = range(max_sequence_length)
        first_timestep, later_steps = test_utils.make_trajectory(obs)

        self.run_test_adder(
            adder=episode_adder,
            first=first_timestep,
            steps=later_steps,
            expected_items=[test_utils.make_sequence(obs)],
        )
예제 #2
0
    def test_nonzero_padding(self, max_sequence_length, padding):
        """Checks that a custom `padding_fn` is used for trailing padded steps.

        The adder is built with a maximum length of
        `max_sequence_length + padding`, but only `max_sequence_length`
        observations are added. The expected item is therefore the real
        episode followed by `padding` entries made of -1 values.
        """

        def minus_one_padding(shape, dtype):
            # Zeros shifted by -1, so padded entries are distinguishable
            # from zero-valued data.
            return np.zeros(shape, dtype) - 1

        episode_adder = adders.EpisodeAdder(
            self.client,
            max_sequence_length + padding,
            padding_fn=minus_one_padding)

        obs = range(max_sequence_length)
        first_timestep, later_steps = test_utils.make_trajectory(obs)

        # Real episode first, then one -1 entry per padded slot.
        padded_entry = (-1, -1, -1.0, -1.0, False, ())
        expected_episode = test_utils.make_sequence(obs)
        expected_episode.extend(padded_entry for _ in range(padding))

        spec_signature = episode_adder.signature(
            *test_utils.get_specs(later_steps[0]))
        self.run_test_adder(
            adder=episode_adder,
            first=first_timestep,
            steps=later_steps,
            expected_items=[expected_episode],
            signature=spec_signature)
예제 #3
0
파일: episode_test.py 프로젝트: zzp110/acme
  def test_adder(self, max_sequence_length):
    """Checks what EpisodeAdder writes before and after the terminal step.

    Before the final step is added, no priorities exist and
    `max_sequence_length - 2` timesteps have reached the writer. When
    `max_sequence_length == 2`, no writer has been created at that point.
    Adding the final step closes the writer and inserts one entry into
    `base.DEFAULT_PRIORITY_TABLE`, at priority 1.0, covering the whole
    episode.
    """
    fake_client = test_utils.FakeClient()
    episode_adder = adders.EpisodeAdder(fake_client, max_sequence_length)

    obs = range(max_sequence_length)
    first_timestep, trajectory_steps = test_utils.make_trajectory(obs)

    # Feed everything except the terminal step.
    episode_adder.add_first(first_timestep)
    for act, ts in trajectory_steps[:-1]:
      episode_adder.add(act, ts)

    if max_sequence_length != 2:
      # Nothing prioritized yet, but every timestep except the most recently
      # added one has already been sent to the writer.
      self.assertEmpty(fake_client.writers[0].priorities)
      self.assertLen(fake_client.writers[0].timesteps, max_sequence_length - 2)
    else:
      # Only add_first has run (the loop above was empty), so no writer
      # has been created yet.
      self.assertEmpty(fake_client.writers)

    # The terminal step should close the writer and insert one prioritized
    # entry referencing every timestep, including the padded final one.
    final_action, final_step = trajectory_steps[-1]
    episode_adder.add(final_action, final_step)

    writer = fake_client.writers[0]
    self.assertTrue(writer.closed)
    self.assertLen(writer.timesteps, max_sequence_length)

    expected_entry = (
        base.DEFAULT_PRIORITY_TABLE,
        test_utils.make_sequence(obs),
        1.0,
    )
    self.assertEqual(writer.priorities, [expected_entry])