Example #1
    def test_incompatible_signature_shape(self, table_name):
        self._populate_replay()

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table=table_name,
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3]), ),
            max_in_flight_samples_per_worker=100)
        with self.assertRaisesWithPredicateMatch(
                tf.errors.InvalidArgumentError,
                r'Requested incompatible tensor at flattened index 4 from table '
                r'\'{}\'.  Requested \(dtype, shape\): \(float, \[3\]\).  '
                r'Signature \(dtype, shape\): \(float, \[\?,\?\]\)'.format(
                    table_name)):
            self._sample_from(dataset, 10)

        dataset_emit_sequences = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table=table_name,
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([None, 3]), ),
            emit_timesteps=False,
            max_in_flight_samples_per_worker=100)
        with self.assertRaisesWithPredicateMatch(
                tf.errors.InvalidArgumentError,
                r'Requested incompatible tensor at flattened index 4 from table '
                r'\'{}\'.  Requested \(dtype, shape\): \(float, \[3\]\).  '
                r'Signature \(dtype, shape\): \(float, \[\?,\?\]\)'.format(
                    table_name)):
            self._sample_from(dataset_emit_sequences, 10)
Example #2
    def test_incompatible_dataset_shapes_and_types_without_signature(self):
        self._populate_replay()
        ds_wrong_shape = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([]), ),
            max_in_flight_samples_per_worker=100)
        with self.assertRaisesRegex(
                tf.errors.InvalidArgumentError,
                r'Specification has \(dtype, shape\): \(float, \[\]\).  '
                r'Tensor has \(dtype, shape\): \(float, \[3,3\]\).'):
            self._sample_from(ds_wrong_shape, 1)

        ds_full_sequences_wrong_shape = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([None]), ),
            emit_timesteps=False,
            max_in_flight_samples_per_worker=100)

        with self.assertRaisesRegex(
                tf.errors.InvalidArgumentError,
                r'Specification has \(dtype, shape\): \(float, \[\]\).  '
                r'Tensor has \(dtype, shape\): \(float, \[3,3\]\).'):
            self._sample_from(ds_full_sequences_wrong_shape, 1)
Example #3
    def test_timeout(self):

        dataset_0s = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            rate_limiter_timeout_ms=0,
            max_in_flight_samples_per_worker=100)

        dataset_1s = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            rate_limiter_timeout_ms=1000,
            max_in_flight_samples_per_worker=100)

        dataset_2s = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            rate_limiter_timeout_ms=2000,
            max_in_flight_samples_per_worker=100)

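        # The table is empty at this point, so each sample call below blocks
        # on the rate limiter until its timeout elapses and the iterator
        # reports end of sequence.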
        start_time = time.time()
        with self.assertRaisesWithPredicateMatch(tf.errors.OutOfRangeError,
                                                 r'End of sequence'):
            self._sample_from(dataset_0s, 1)
        duration = time.time() - start_time
        self.assertGreaterEqual(duration, 0)
        self.assertLess(duration, 5)

        start_time = time.time()
        with self.assertRaisesWithPredicateMatch(tf.errors.OutOfRangeError,
                                                 r'End of sequence'):
            self._sample_from(dataset_1s, 1)
        duration = time.time() - start_time
        self.assertGreaterEqual(duration, 1)
        self.assertLess(duration, 10)

        start_time = time.time()
        with self.assertRaisesWithPredicateMatch(tf.errors.OutOfRangeError,
                                                 r'End of sequence'):
            self._sample_from(dataset_2s, 1)
        duration = time.time() - start_time
        self.assertGreaterEqual(duration, 2)
        self.assertLess(duration, 10)

        # If we insert some data, and the rate limiter doesn't force any waiting,
        # then we can ask for a timeout of 0s and still get data back.
        self._populate_replay()
        got = self._sample_from(dataset_0s, 2)
        self.assertLen(got, 2)
Example #4
    def test_sampler_parameter_validation(self, **kwargs):
        dtypes = (tf.float32, )
        shapes = (tf.TensorShape([3, 3]), )

        if 'max_in_flight_samples_per_worker' not in kwargs:
            kwargs['max_in_flight_samples_per_worker'] = 100

        if 'want_error' in kwargs:
            error = kwargs.pop('want_error')
            with self.assertRaises(error):
                reverb_dataset.ReplayDataset(self._client.server_address,
                                             'dist', dtypes, shapes, **kwargs)
        else:
            reverb_dataset.ReplayDataset(self._client.server_address, 'dist',
                                         dtypes, shapes, **kwargs)
Example #5
    def test_iterate_with_unknown_sequence_length(self, table_name,
                                                  sequence_length):
        self._populate_replay(sequence_length)

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table=table_name,
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([None, 3, 3]), ),
            emit_timesteps=False,
            sequence_length=None,
            max_in_flight_samples_per_worker=100)

        # Check the shape of the items.
        iterator = dataset.make_initializable_iterator()
        dataset_item = iterator.get_next()
        self.assertIsNone(dataset_item.info.key.shape.as_list()[0])
        self.assertIsNone(dataset_item.data[0].shape.as_list()[0])

        # Verify that, once evaluated, the samples have the expected length.
        got = self._sample_from(dataset, 10)
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)

            # The keys and data should be batched up by the sequence length.
            self.assertEqual(sample.info.key.shape, (sequence_length, ))
            np.testing.assert_array_equal(
                sample.data[0],
                np.zeros((sequence_length, 3, 3), dtype=np.float32))
Example #6
    def test_multiple_iterators(self):
        with self._client.writer(100) as writer:
            for i in range(10):
                writer.append([np.ones((81, 81), dtype=np.float32) * i])
            writer.create_item(table='dist', num_timesteps=10, priority=1)

        trajectory_length = 5
        batch_size = 3

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([81, 81]), ),
            max_in_flight_samples_per_worker=100)
        dataset = dataset.batch(trajectory_length)

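        # Create one iterator per batch element; each iterator samples
        # independently from the same table.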
        iterators = [
            dataset.make_initializable_iterator() for _ in range(batch_size)
        ]
        items = tf.stack(
            [tf.squeeze(iterator.get_next().data) for iterator in iterators])

        with self.session() as session:
            session.run([iterator.initializer for iterator in iterators])
            got = session.run(items)
            self.assertEqual(got.shape,
                             (batch_size, trajectory_length, 81, 81))

            want = np.array(
                [[np.ones([81, 81]) * i
                  for i in range(trajectory_length)]] * batch_size)
            np.testing.assert_array_equal(got, want)
Example #7
  def test_timeout_invalid_arguments(self):
    with self.assertRaisesRegex(ValueError, r'must be an integer >= -1'):
      reverb_dataset.ReplayDataset(
          self._client.server_address,
          table='dist',
          dtypes=(tf.float32, ),
          shapes=(tf.TensorShape([3, 3]), ),
          rate_limiter_timeout_ms=-2,
          max_in_flight_samples_per_worker=100)
Example #8
  def reverb_dataset_fn(i):
    # Per-replica dataset factory; `self._client` is captured from the
    # enclosing test method.
    tf.print('Creating dataset for replica; index:', i)
    return reverb_dataset.ReplayDataset(
        self._client.server_address,
        table=tf.constant('dist'),
        dtypes=(tf.float32,),
        shapes=(tf.TensorShape([3, 3]),),
        max_in_flight_samples_per_worker=100).take(2)
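A possible way to consume such a per-replica factory (a sketch only, assuming an existing tf.distribute.Strategy instance named `strategy`; `distribute_datasets_from_function` passes a tf.distribute.InputContext, whose pipeline id is adapted here to the index-based factory above):

  # Sketch only: `strategy` is an assumed tf.distribute.Strategy instance.
  distributed_ds = strategy.distribute_datasets_from_function(
      lambda ctx: reverb_dataset_fn(ctx.input_pipeline_id))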
Example #9
    def test_iterate_nested_and_batched(self):
        with self._client.writer(100) as writer:
            for i in range(1000):
                writer.append({
                    'observation': {
                        'data': np.zeros((3, 3), dtype=np.float32),
                        'extras': [
                            np.int64(10),
                            np.ones([1], dtype=np.int32),
                        ],
                    },
                    'reward': np.zeros((10, 10), dtype=np.float32),
                })
                if i % 5 == 0 and i >= 100:
                    writer.create_item(table='dist',
                                       num_timesteps=100,
                                       priority=1)

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(((tf.float32), (tf.int64, tf.int32)), tf.float32),
            shapes=((tf.TensorShape([3, 3]), (tf.TensorShape(None),
                                              tf.TensorShape([1]))),
                    tf.TensorShape([10, 10])),
            max_in_flight_samples_per_worker=100)
        dataset = dataset.batch(3)

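        # Expected nesting of a single sample; tree.unflatten_as below uses
        # only the structure, not the TensorSpec values.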
        structure = {
            'observation': {
                'data':
                tf.TensorSpec([3, 3], tf.float32),
                'extras': [
                    tf.TensorSpec([], tf.int64),
                    tf.TensorSpec([1], tf.int32),
                ],
            },
            'reward': tf.TensorSpec([10, 10], tf.float32),
        }

        got = self._sample_from(dataset, 10)
        self.assertLen(got, 10)
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)

            transition = tree.unflatten_as(structure,
                                           tree.flatten(sample.data))
            np.testing.assert_array_equal(
                transition['observation']['data'],
                np.zeros([3, 3, 3], dtype=np.float32))
            np.testing.assert_array_equal(
                transition['observation']['extras'][0],
                np.ones([3], dtype=np.int64) * 10)
            np.testing.assert_array_equal(
                transition['observation']['extras'][1],
                np.ones([3, 1], dtype=np.int32))
            np.testing.assert_array_equal(
                transition['reward'], np.zeros([3, 10, 10], dtype=np.float32))
Example #10
  def test_incompatible_shape_when_using_sequence_length(self, sequence_length):
    with self.assertRaises(ValueError):
      reverb_dataset.ReplayDataset(
          self._client.server_address,
          table='dist',
          dtypes=(tf.float32,),
          shapes=(tf.TensorShape([sequence_length + 1, 3, 3]),),
          emit_timesteps=False,
          sequence_length=sequence_length,
          max_in_flight_samples_per_worker=100)
Example #11
  def test_checks_sequence_length_when_timesteps_emitted(
      self, table_name, actual_sequence_length, provided_sequence_length):
    self._populate_replay(actual_sequence_length)

    dataset = reverb_dataset.ReplayDataset(
        self._client.server_address,
        table=table_name,
        dtypes=(tf.float32,),
        shapes=(tf.TensorShape([provided_sequence_length, 3, 3]),),
        emit_timesteps=True,
        sequence_length=provided_sequence_length,
        max_in_flight_samples_per_worker=100)

    with self.assertRaises(tf.errors.InvalidArgumentError):
      self._sample_from(dataset, 10)
Example #12
    def test_inconsistent_signature_size(self, table_name):
        self._populate_replay()

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table=table_name,
            dtypes=(tf.float32, tf.float64),
            shapes=(tf.TensorShape([3, 3]), tf.TensorShape([])),
            max_in_flight_samples_per_worker=100)
        with self.assertRaisesWithPredicateMatch(
                tf.errors.InvalidArgumentError,
                r'Inconsistent number of tensors requested from table \'{}\'.  '
                r'Requested 6 tensors, but table signature shows 5 tensors.'.
                format(table_name)):
            self._sample_from(dataset, 10)
Example #13
    def test_iterate(self):
        self._populate_replay()

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            max_in_flight_samples_per_worker=100)
        got = self._sample_from(dataset, 10)
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)
            # A single sample is returned so the key should be a scalar uint64.
            self.assertIsInstance(sample.info.key, np.uint64)
            np.testing.assert_array_equal(sample.data[0],
                                          np.zeros((3, 3), dtype=np.float32))
Example #14
    def test_iterate_over_blobs(self):
        for _ in range(10):
            self._client.insert((np.ones([3, 3], dtype=np.int32)), {'dist': 1})

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.int32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            max_in_flight_samples_per_worker=100)

        got = self._sample_from(dataset, 20)
        self.assertLen(got, 20)
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)
            self.assertIsInstance(sample.info.key, np.uint64)
            self.assertIsInstance(sample.info.probability, np.float64)
            np.testing.assert_array_equal(sample.data[0],
                                          np.ones((3, 3), dtype=np.int32))
Example #15
    def test_converts_spec_lists_into_tuples(self):
        for _ in range(10):
            data = [
                (np.ones([1, 1], dtype=np.int32), ),
                [
                    np.ones([3, 3], dtype=np.int8),
                    (np.ones([2, 2], dtype=np.float64), )
                ],
            ]
            self._client.insert(data, {'dist': 1})

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=[
                (tf.int32, ),
                [
                    tf.int8,
                    (tf.float64, ),
                ],
            ],
            shapes=[
                (tf.TensorShape([1, 1]), ),
                [
                    tf.TensorShape([3, 3]),
                    (tf.TensorShape([2, 2]), ),
                ],
            ],
            max_in_flight_samples_per_worker=100)

        got = self._sample_from(dataset, 10)

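        # Both list levels in the dtypes/shapes specs come back as tuples in
        # the sampled data.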
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)
            self.assertIsInstance(sample.info.key, np.uint64)
            tree.assert_same_structure(sample.data, (
                (None, ),
                (
                    None,
                    (None, ),
                ),
            ))
Example #16
    def test_iterate_batched(self, table_name):
        self._populate_replay()

        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table=table_name,
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            max_in_flight_samples_per_worker=100)
        dataset = dataset.batch(2, drop_remainder=True)

        got = self._sample_from(dataset, 10)
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)

            # The keys should be batched up like the data.
            self.assertEqual(sample.info.key.shape, (2, ))

            np.testing.assert_array_equal(
                sample.data[0], np.zeros((2, 3, 3), dtype=np.float32))
Example #17
  def test_respects_flexible_batch_size(self, flexible_batch_size):
    for _ in range(10):
      self._client.insert((np.ones([3, 3], dtype=np.int32)), {'dist': 1})

    dataset = reverb_dataset.ReplayDataset(
        self._client.server_address,
        table='dist',
        dtypes=(tf.int32,),
        shapes=(tf.TensorShape([3, 3]),),
        max_in_flight_samples_per_worker=100,
        flexible_batch_size=flexible_batch_size)

    iterator = dataset.make_initializable_iterator()
    dataset_item = iterator.get_next()
    self.evaluate(iterator.initializer)

    for _ in range(100):
      self.evaluate(dataset_item)

      # Check that the buffer is incremented by steps of flexible_batch_size.
      self.assertEqual(self._get_num_samples('dist') % flexible_batch_size, 0)
Example #18
  def test_iterate_with_sequence_length(self, table_name, sequence_length,
                                        max_time_steps):
    # Also ensure we get sequence_length-shaped outputs when
    # writers' max_time_steps != sequence_length.
    self._populate_replay(sequence_length, max_time_steps=max_time_steps)

    dataset = reverb_dataset.ReplayDataset(
        self._client.server_address,
        table=table_name,
        dtypes=(tf.float32,),
        shapes=(tf.TensorShape([sequence_length, 3, 3]),),
        emit_timesteps=False,
        sequence_length=sequence_length,
        max_in_flight_samples_per_worker=100)

    got = self._sample_from(dataset, 10)
    for sample in got:
      self.assertIsInstance(sample, replay_sample.ReplaySample)

      # The keys and data should be batched up by the sequence length.
      self.assertEqual(sample.info.key.shape, (sequence_length,))
      np.testing.assert_array_equal(
          sample.data[0], np.zeros((sequence_length, 3, 3), dtype=np.float32))
Example #19
    def test_session_is_closed_while_op_pending(self):
        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=tf.float32,
            shapes=tf.TensorShape([]),
            max_in_flight_samples_per_worker=100)

        iterator = dataset.make_initializable_iterator()
        item = iterator.get_next()

        def _session_closer(sess, wait_time_secs):
            def _fn():
                time.sleep(wait_time_secs)
                sess.close()

            return _fn

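        # The table is empty, so sess.run(item) blocks until the helper thread
        # closes the session and the pending op is cancelled.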
        with self.session() as sess:
            sess.run(iterator.initializer)
            thread = threading.Thread(target=_session_closer(sess, 3))
            thread.start()
            with self.assertRaises(tf.errors.CancelledError):
                sess.run(item)
Example #20
    def dataset(self,
                table: str,
                dtypes: Sequence[Any],
                shapes: Sequence[Any],
                capacity: int = 100,
                num_workers_per_iterator: int = -1,
                max_samples_per_stream: int = -1,
                sequence_length: Optional[int] = None,
                emit_timesteps: bool = True,
                rate_limiter_timeout_ms: int = -1) -> dataset.ReplayDataset:
        """DEPRECATED, please use dataset.ReplayDataset instead.

    A tf.data.Dataset which samples timesteps from the ReverbService.

    Note: Python lists are converted into tuples, as the nest library used by
    the tf.data API doesn't have good support for lists.

    See dataset.ReplayDataset for detailed documentation.

    Args:
      table: Probability table to sample from.
      dtypes: Dtypes of the data output. Can be nested.
      shapes: Shapes of the data output. Can be nested. When `emit_timesteps`
        is True, this is the shape of a single timestep in the sampled items;
        when it is False, shapes must include `sequence_length`.
      capacity: (Defaults to 100) Maximum number of samples requested by the
        workers with each request. Higher values give higher throughput, but
        excessively large values can result in skewed sampling distributions,
        as a large number of samples is fetched from a single snapshot of the
        replay (followed by a period of lower activity while the samples are
        consumed). A good rule of thumb is to set this value to 2-3x the batch
        size used.
      num_workers_per_iterator: (Defaults to -1, i.e. auto selected) The number
        of worker threads to create per dataset iterator. When the selected
        table uses a FIFO sampler (i.e. a queue), exactly 1 worker must be used
        to avoid races causing invalid ordering of items. For all other
        samplers, this value should be roughly equal to the number of threads
        available on the CPU.
      max_samples_per_stream: (Defaults to -1, i.e. auto selected) The maximum
        number of samples to fetch from a stream before a new call is made.
        Keeping this number low ensures that the data is fetched uniformly from
        all servers.
      sequence_length: (Defaults to None, i.e. unknown) The number of timesteps
        that each sample consists of. If set, the length of samples received
        from the server is validated against this number.
      emit_timesteps: (Defaults to True) If set, timesteps instead of full
        sequences are returned from the dataset. Returning sequences instead of
        timesteps can be more efficient, as the memcopies caused by the
        splitting and batching of tensors can be avoided. Note that if set to
        False, all `shapes` must have dim[0] equal to `sequence_length`.
      rate_limiter_timeout_ms: (Defaults to -1: infinite).  Timeout (in
        milliseconds) to wait on the rate limiter when sampling from the table.
        If `rate_limiter_timeout_ms >= 0`, this is the timeout passed to
        `Table::Sample` describing how long to wait for the rate limiter to
        allow sampling. The first time that a request times out (across any of
        the workers), the Dataset iterator is closed and the sequence is
        considered finished.

    Returns:
      A ReplayDataset with the above specification.
    """
        logging.warning(
            'TFClient.dataset is DEPRECATED! Please use ReplayDataset (see '
            './dataset.py) instead.')
        return dataset.ReplayDataset(
            server_address=self._server_address,
            table=table,
            dtypes=dtypes,
            shapes=shapes,
            max_in_flight_samples_per_worker=capacity,
            num_workers_per_iterator=num_workers_per_iterator,
            max_samples_per_stream=max_samples_per_stream,
            sequence_length=sequence_length,
            emit_timesteps=emit_timesteps,
            rate_limiter_timeout_ms=rate_limiter_timeout_ms)
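For comparison, a minimal sketch of the recommended direct construction (the server address and a 'dist' table holding [3, 3] float32 data are assumptions for illustration):

    # A minimal sketch, assuming a running Reverb server at `server_address`
    # whose 'dist' table stores [3, 3] float32 timesteps.
    ds = dataset.ReplayDataset(
        server_address=server_address,
        table='dist',
        dtypes=(tf.float32, ),
        shapes=(tf.TensorShape([3, 3]), ),
        max_in_flight_samples_per_worker=100)
    # In TF2 eager mode the dataset is directly iterable.
    for sample in ds.take(2):
        print(sample.info.key, sample.data[0].shape)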