Example No. 1
0
    def test_sampler_parameter_validation(self, **kwargs):
        """Constructs a TrajectoryDataset from `kwargs`, optionally expecting failure.

        `max_in_flight_samples_per_worker` defaults to 1 when the caller does
        not supply it. If `want_error` is present it is removed from `kwargs`
        and construction is asserted to raise that exception type; otherwise
        construction is simply expected to succeed.
        """
        kwargs.setdefault('max_in_flight_samples_per_worker', 1)

        def _build_dataset():
            # Reads `kwargs` at call time, so `want_error` must already have
            # been popped before this is invoked on the error path.
            return trajectory_dataset.TrajectoryDataset(
                server_address=self._client.server_address,
                table=TABLE,
                dtypes=DTYPES,
                shapes=SHAPES,
                **kwargs)

        if 'want_error' not in kwargs:
            _build_dataset()
            return

        expected_error = kwargs.pop('want_error')
        with self.assertRaises(expected_error):
            _build_dataset()
Example No. 2
0
  def test_sample_variable_length_trajectory(self):
    """Checks that items of every trajectory length from 1 to 10 are sampled.

    Writes 10 items: 'last' references the most recent step and 'all'
    references the whole history written so far. It then draws one sample
    per iteration via `_sample_from` until all 10 distinct 'all' lengths
    have been observed. Each sample must match the expected nested structure.
    """
    with self._client.trajectory_writer(10) as writer:
      for step in range(10):
        writer.append([np.ones([3, 3], np.int32) * step])
        writer.create_item(TABLE, 1.0, {
            'last': writer.history[0][-1],
            'all': writer.history[0][:],
        })

    column_dtypes = {
        'last': tf.int32,
        'all': tf.int32,
    }
    column_shapes = {
        'last': tf.TensorShape([3, 3]),
        # Leading dimension is the (variable) trajectory length.
        'all': tf.TensorShape([None, 3, 3]),
    }
    dataset = trajectory_dataset.TrajectoryDataset(
        tf.constant(self._client.server_address),
        table=tf.constant(TABLE),
        dtypes=column_dtypes,
        shapes=column_shapes,
        max_in_flight_samples_per_worker=1,
        flexible_batch_size=1)

    # Only the nesting matters here; the leaf values are placeholders.
    expected_structure = replay_sample.ReplaySample(
        info=replay_sample.SampleInfo(
            key=1,
            probability=1.0,
            table_size=10,
            priority=0.5,
        ),
        data={'last': None, 'all': None})

    observed_lengths = set()
    # Keep sampling until every trajectory length has shown up at least once.
    while len(observed_lengths) < 10:
      sample = self._sample_from(dataset, 1)[0]
      # The structure must not depend on the trajectory length.
      tree.assert_same_structure(sample, expected_structure)
      observed_lengths.add(sample.data['all'].shape[0])

    self.assertEqual(observed_lengths, set(range(1, 11)))
Example No. 3
0
    def test_sample_fixed_length_trajectory(self):
        """Checks the nested structure of a sample from the populated table.

        The first sample must match a ReplaySample whose `info` carries the
        SampleInfo fields and whose `data` mirrors the structure of SHAPES.
        """
        self._populate_replay()

        dataset = trajectory_dataset.TrajectoryDataset(
            tf.constant(self._client.server_address),
            table=tf.constant(TABLE),
            dtypes=DTYPES,
            shapes=SHAPES,
            max_in_flight_samples_per_worker=1)

        first_sample = self._sample_from(dataset, 1)[0]

        # Only the nesting is compared; the leaf values are placeholders.
        expected_info = replay_sample.SampleInfo(
            key=1,
            probability=1.0,
            table_size=10,
            priority=0.5,
            times_sampled=1,
        )
        expected_structure = replay_sample.ReplaySample(
            info=expected_info, data=SHAPES)

        tree.assert_same_structure(first_sample, expected_structure)

    def test_max_samples(self, num_workers_per_iterator,
                         max_in_flight_samples_per_worker, max_samples):
        """Checks that `max_samples` caps both the output and the server reads.

        Writes `num_items` variable-length trajectories into a fresh queue
        table and reads the dataset to exhaustion. It then asserts two things:
        exactly `max_samples` items were produced, and exactly that many items
        were removed from the queue, i.e. nothing extra was prefetched.

        Args:
          num_workers_per_iterator: Forwarded to TrajectoryDataset.
          max_in_flight_samples_per_worker: Forwarded to TrajectoryDataset.
          max_samples: Number of samples the dataset may fetch. Presumably
            expected to be <= `num_items` (10); TODO confirm with the
            parameterization.
        """
        # Queue capacity, number of written items and the writer argument all
        # share this value; the final size check depends on it as well.
        num_items = 10

        server = reverb_server.Server(
            tables=[reverb_server.Table.queue(_TABLE, num_items)])
        # Shut the server down explicitly when the test ends (pass or fail)
        # instead of relying on garbage collection of `server`.
        self.addCleanup(server.stop)
        client = server.localhost_client()

        with client.trajectory_writer(num_items) as writer:
            for step in range(num_items):
                writer.append([np.ones([3, 3], np.int32) * step])
                writer.create_item(_TABLE, 1.0, {
                    'last': writer.history[0][-1],
                    'all': writer.history[0][:],
                })

        dataset = trajectory_dataset.TrajectoryDataset(
            tf.constant(client.server_address),
            table=tf.constant(_TABLE),
            dtypes={
                'last': tf.int32,
                'all': tf.int32,
            },
            shapes={
                'last': tf.TensorShape([3, 3]),
                'all': tf.TensorShape([None, 3, 3]),
            },
            num_workers_per_iterator=num_workers_per_iterator,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            max_samples=max_samples)

        # Check that it fetches exactly `max_samples` samples.
        it = dataset.as_numpy_iterator()
        self.assertLen(list(it), max_samples)
        # Check that no prefetching happens on the queue: only the consumed
        # items should have been removed from the table.
        self.assertEqual(client.server_info()[_TABLE].current_size,
                         num_items - max_samples)