Пример #1
0
  def test_sample_single_episode(self):
    """Samples fixed-size windows from a buffer that holds a single episode.

    One episode of 100 steps is written to the table. Each of the 10 sampled
    batches must report episode id 0 everywhere and contain consecutive step
    ids along the second axis.
    """
    episode_length = 100
    episodes_to_insert = 1
    sample_batch_size = 10
    window = 5
    expected_batches = 10

    # Write exactly one episode into the replay buffer table.
    counting_env = test_envs.EpisodeCountingEnv(
        steps_per_episode=episode_length)
    self._insert_random_data(
        counting_env,
        num_steps=episodes_to_insert * episode_length,
        sequence_length=episode_length)

    buffer = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=episode_length,
        local_server=self._server)
    batches = buffer.as_dataset(sample_batch_size, num_steps=window)

    batches_seen = 0
    for batches_seen, (batch, _) in enumerate(
        batches.take(expected_batches), start=1):
      episode_ids, step_ids = batch.observation
      # Only episode 0 was inserted, so every entry must carry that id.
      all_zero = tf.constant(
          0, dtype=episode_ids.dtype, shape=episode_ids.shape)
      self.assertAllEqual(all_zero, episode_ids)
      # Steps inside each sampled window must increase by one per position.
      first_step = step_ids[:, 0]
      for offset in range(window):
        self.assertAllEqual(first_step + offset, step_ids[:, offset])

    # Ensure the dataset actually produced all of the requested batches.
    self.assertEqual(expected_batches, batches_seen)
Пример #2
0
  def setUp(self):
    """Builds the trajectory specs and starts a local Reverb server.

    After this runs the test instance exposes `_env`, `_data_spec`,
    `_array_data_spec`, `_table_name`, `_server` and `_py_client`.
    """
    super(ReverbReplayBufferTest, self).setUp()

    # The environment is the source of every spec derived below.
    self._env = test_envs.EpisodeCountingEnv(steps_per_episode=3)
    time_step_spec = tf.nest.map_structure(
        tensor_spec.from_spec, self._env.time_step_spec())
    action_spec = tensor_spec.from_spec(self._env.action_spec())

    self._data_spec = trajectory.Trajectory(
        step_type=time_step_spec.step_type,
        observation=time_step_spec.observation,
        action=action_spec,
        policy_info=(),
        next_step_type=time_step_spec.step_type,
        reward=time_step_spec.reward,
        discount=time_step_spec.discount,
    )

    def _with_leading_dim(spec):
      # Same dtype, with one extra unknown-size dimension in front.
      return tf.TensorSpec(dtype=spec.dtype, shape=(None,) + spec.shape)

    signature = tf.nest.map_structure(_with_leading_dim, self._data_spec)
    self._array_data_spec = tensor_spec.to_nest_array_spec(self._data_spec)

    # Single uniformly-sampled FIFO table, served by an in-process server.
    self._table_name = 'test_table'
    table = reverb.Table(
        self._table_name,
        max_size=100,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=signature,
    )
    self._server = reverb.Server([table])

    # Client connected to the port the server reports.
    address = 'localhost:{}'.format(self._server.port)
    self._py_client = reverb.Client(address)
Пример #3
0
  def test_sequential_ordering(self, num_steps):
    """Checks that sampled windows stay within one episode and are ordered.

    Args:
      num_steps: Forwarded to `as_dataset`. When falsy, the expected window
        width falls back to the sequence length (10).
    """
    episode_length = 10
    sample_batch_size = 5

    counting_env = test_envs.EpisodeCountingEnv(
        steps_per_episode=episode_length)
    self._insert_random_data(
        counting_env,
        num_steps=sample_batch_size * episode_length,
        sequence_length=episode_length)

    buffer = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=episode_length,
        local_server=self._server)
    batches = buffer.as_dataset(sample_batch_size, num_steps=num_steps)

    # Width of the second axis that every sampled tensor should have.
    window = num_steps or episode_length
    expected_shape = (sample_batch_size, window)

    for batch, _ in batches.take(10):
      episode_ids, step_ids = batch.observation
      self.assertEqual(expected_shape, episode_ids.shape)
      self.assertEqual(expected_shape, step_ids.shape)

      first_episode = episode_ids[:, 0]
      first_step = step_ids[:, 0]
      for offset in range(window):
        # Each row must stay within a single episode across the window.
        self.assertAllEqual(first_episode, episode_ids[:, offset])
        # Step ids must advance by exactly one per position.
        self.assertAllEqual(first_step + offset, step_ids[:, offset])
Пример #4
0
  def test_batched_episodes_dataset(self, sequence_length):
    """Samples batches from a buffer created with `sequence_length=None`.

    Args:
      sequence_length: Steps per episode of the environment; also the
        expected size of the second axis of every sampled observation.
    """
    sample_batch_size = 3

    # Insert sample_batch_size * sequence_length steps so that the table
    # holds at least 3 episodes.
    counting_env = test_envs.EpisodeCountingEnv(
        steps_per_episode=sequence_length)
    self._insert_random_data(
        counting_env,
        num_steps=sample_batch_size * sequence_length,
        sequence_length=sequence_length)

    buffer = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=None,
        local_server=self._server)
    batches = buffer.as_dataset(sample_batch_size)

    expected_shape = (sample_batch_size, sequence_length)
    for batch, _ in batches.take(5):
      episode_ids, step_ids = batch.observation
      self.assertEqual(expected_shape, episode_ids.shape)
      self.assertEqual(expected_shape, step_ids.shape)

      first_episode = episode_ids[:, 0]
      first_step = step_ids[:, 0]
      for offset in range(sequence_length):
        # Each row must stay within a single episode.
        self.assertAllEqual(first_episode, episode_ids[:, offset])
        # Step ids within a row must be consecutive.
        self.assertAllEqual(first_step + offset, step_ids[:, offset])
Пример #5
0
  def test_variable_length_episodes_dataset(self):
    """Samples single episodes whose lengths range from 1 to 9 steps.

    One episode of every length is inserted, each from a fresh environment,
    so every stored episode has id 0 and step ids counting up from 0.
    """
    episode_lengths = range(1, 10)

    # Add one episode of each length.
    for episode_length in episode_lengths:
      counting_env = test_envs.EpisodeCountingEnv(
          steps_per_episode=episode_length)
      self._insert_random_data(
          counting_env,
          num_steps=episode_length,
          sequence_length=episode_length)

    buffer = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=None,
        local_server=self._server)

    # Original note: observations are off by 1 because the env observations
    # count transitions.
    batches = buffer.as_dataset(sample_batch_size=1)
    for batch, _ in batches.take(5):
      episode_ids, step_ids = batch.observation
      self.assertIn(episode_ids.shape[1], episode_lengths)
      self.assertIn(step_ids.shape[1], episode_lengths)

      sampled_length = episode_ids.shape[1]
      # Every episode id is 0.
      expected_episode_ids = [[0] * sampled_length]
      self.assertAllEqual(expected_episode_ids, episode_ids)
      # Step ids run sequentially from 0 up to the sampled length.
      expected_step_ids = [list(range(sampled_length))]
      self.assertAllEqual(expected_step_ids, step_ids)
Пример #6
0
  def test_sequential(self):
    """Checks the (episode, step) observation over three 4-step episodes.

    After every reset and every step the observation must equal the current
    episode index paired with the number of steps taken in that episode.
    """
    steps_per_episode = 4
    num_episodes = 3
    env = test_envs.EpisodeCountingEnv(steps_per_episode=steps_per_episode)

    for episode_idx in range(num_episodes):
      time_step = env.reset()
      steps_taken = 0
      self.assertAllEqual((episode_idx, steps_taken), time_step.observation)

      # Keep stepping with action 0 until the env reports the final step.
      while True:
        if time_step.is_last():
          break
        time_step = env.step(0)
        steps_taken += 1
        self.assertAllEqual((episode_idx, steps_taken), time_step.observation)

      # The final observation must report the full episode length.
      self.assertAllEqual(
          (episode_idx, steps_per_episode), time_step.observation)
Пример #7
0
 def test_validate_specs(self):
     """Passes a 15-step EpisodeCountingEnv to validate_py_environment.

     The validation is requested for 10 episodes.
     """
     counting_env = test_envs.EpisodeCountingEnv(steps_per_episode=15)
     env_utils.validate_py_environment(counting_env, episodes=10)