def test_timeout(self):
        """Checks that `rate_limiter_timeout_ms` bounds how long sampling waits.

        Before any data is inserted (`_populate_replay` only runs at the end),
        sampling one element is expected to end with `End of sequence` after
        roughly the configured rate-limiter timeout. Once data is present, a
        0ms timeout must still return samples.
        """

        def _make_dataset(timeout_ms):
            # The datasets under test differ only in their rate limiter timeout.
            return timestep_dataset.TimestepDataset(
                self._client.server_address,
                table='dist',
                dtypes=(tf.float32, ),
                shapes=(tf.TensorShape([3, 3]), ),
                rate_limiter_timeout_ms=timeout_ms,
                max_in_flight_samples_per_worker=100)

        dataset_0s = _make_dataset(0)
        dataset_1s = _make_dataset(1000)
        dataset_2s = _make_dataset(2000)

        # (dataset, minimum expected wait in seconds, generous upper bound).
        cases = (
            (dataset_0s, 0, 5),
            (dataset_1s, 1, 10),
            (dataset_2s, 2, 10),
        )
        for dataset, min_secs, max_secs in cases:
            # time.monotonic() is not affected by wall-clock adjustments, unlike
            # time.time(), so it is the correct clock for measuring durations.
            start_time = time.monotonic()
            with self.assertRaisesWithPredicateMatch(tf.errors.OutOfRangeError,
                                                     r'End of sequence'):
                self._sample_from(dataset, 1)
            duration = time.monotonic() - start_time
            self.assertGreaterEqual(duration, min_secs)
            self.assertLess(duration, max_secs)

        # If we insert some data, and the rate limiter doesn't force any waiting,
        # then we can ask for a timeout of 0s and still get data back.
        self._populate_replay()
        got = self._sample_from(dataset_0s, 2)
        self.assertLen(got, 2)
Beispiel #2
0
  def test_sampler_parameter_validation(self, **kwargs):
    """Constructs a TimestepDataset from `kwargs`, optionally expecting an error.

    `max_in_flight_samples_per_worker` falls back to 100 when the caller does
    not provide it. If `want_error` is present, it is removed from the
    constructor arguments and construction must raise that exception type;
    otherwise construction must simply succeed.
    """
    kwargs.setdefault('max_in_flight_samples_per_worker', 100)

    def _construct():
      return timestep_dataset.TimestepDataset(
          self._client.server_address, 'dist', (tf.float32,),
          (tf.TensorShape([3, 3]),), **kwargs)

    if 'want_error' not in kwargs:
      _construct()
      return

    expected_error = kwargs.pop('want_error')
    with self.assertRaises(expected_error):
      _construct()
Beispiel #3
0
  def test_session_is_closed_while_op_pending(self):
    """Closing the session during a pending `sess.run` must cancel it.

    A background thread closes the session after a 3 second delay while the
    main thread runs `get_next`. That call is presumably still blocked at that
    point (it is expected to end with `tf.errors.CancelledError`, not return
    data).
    """
    dataset = timestep_dataset.TimestepDataset(
        self._client.server_address,
        table='dist',
        dtypes=tf.float32,
        shapes=tf.TensorShape([]),
        max_in_flight_samples_per_worker=100)

    iterator = dataset.make_initializable_iterator()
    item = iterator.get_next()

    def _session_closer(sess, wait_time_secs):

      def _fn():
        time.sleep(wait_time_secs)
        sess.close()

      return _fn

    with self.session() as sess:
      sess.run(iterator.initializer)
      thread = threading.Thread(target=_session_closer(sess, 3))
      thread.start()
      try:
        with self.assertRaises(tf.errors.CancelledError):
          sess.run(item)
      finally:
        # Always wait for the closer thread. This way it never outlives the
        # test, even when the assertion fails, and never races the session
        # context manager's own close on exit.
        thread.join()
Beispiel #4
0
  def test_max_samples(self, num_workers_per_iterator,
                       max_in_flight_samples_per_worker, max_samples):
    """The dataset yields exactly `max_samples` items and does not prefetch.

    The integers 0..9 are inserted into a queue table. After `max_samples`
    elements have been consumed, the iterator must be exhausted. The queue
    must still hold the remaining items, the next one being `max_samples`.

    Args:
      num_workers_per_iterator: Forwarded to `TimestepDataset`.
      max_in_flight_samples_per_worker: Forwarded to `TimestepDataset`.
      max_samples: Cap on the number of samples the dataset may return.
        NOTE(review): presumably < 10, otherwise the final `c.sample` call
        would find the queue empty. Confirm against the parameterization.
    """
    s = reverb_server.Server([reverb_server.Table.queue('q', 10)])
    c = s.localhost_client()

    for i in range(10):
      c.insert(i, {'q': 1})

    ds = timestep_dataset.TimestepDataset(
        server_address=c.server_address,
        table='q',
        dtypes=tf.int64,
        shapes=tf.TensorShape([]),
        max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
        num_workers_per_iterator=num_workers_per_iterator,
        max_samples=max_samples)

    iterator = ds.make_one_shot_iterator()
    item = iterator.get_next()
    # Check that it fetches exactly `max_samples` samples, then stops.
    with self.session() as sess:
      for _ in range(max_samples):
        sess.run(item)
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(item)

    # Check that no prefetching happened: exactly `10 - max_samples` items are
    # left in the queue, and the next one is the value `max_samples`.
    self.assertEqual(c.server_info()['q'].current_size, 10 - max_samples)
    np.testing.assert_array_equal(
        next(c.sample('q', 1))[0].data[0], np.asarray(max_samples))
    def test_max_samples(self, num_workers_per_iterator,
                         max_in_flight_samples_per_worker, max_samples):
        """Verifies `max_samples` caps the dataset without over-fetching.

        The integers 0..9 are pushed into a queue table. After the dataset has
        been drained, exactly `max_samples` items must have been produced. The
        queue must still hold `10 - max_samples` items, the next one being
        `max_samples`.
        """
        server = reverb_server.Server([reverb_server.Table.queue('q', 10)])
        client = server.localhost_client()

        for value in range(10):
            client.insert(value, {'q': 1})

        dataset = timestep_dataset.TimestepDataset(
            server_address=client.server_address,
            table='q',
            dtypes=tf.int64,
            shapes=tf.TensorShape([]),
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=num_workers_per_iterator,
            max_samples=max_samples)

        # Drain the dataset; it must stop after exactly `max_samples` items.
        drained = list(dataset.as_numpy_iterator())
        self.assertLen(drained, max_samples)

        # Nothing beyond those samples may have been pulled from the queue.
        remaining = client.server_info()['q'].current_size
        self.assertEqual(remaining, 10 - max_samples)
        next_item = next(client.sample('q', 1))
        np.testing.assert_array_equal(next_item[0].data[0],
                                      np.asarray(max_samples))
    def test_multiple_iterators(self):
        """Several iterators over one dataset each read the same trajectory.

        One 10-step item is written, with step `k` filled with the value `k`.
        `batch_size` independent iterators each pull one batch of
        `trajectory_length` timesteps. Every iterator is expected to see the
        values 0..trajectory_length-1 in order.
        """
        with self._client.writer(100) as writer:
            for step in range(10):
                writer.append([np.ones((81, 81), dtype=np.float32) * step])
            writer.create_item(table='dist', num_timesteps=10, priority=1)

        trajectory_length = 5
        batch_size = 3

        dataset = timestep_dataset.TimestepDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([81, 81]), ),
            max_in_flight_samples_per_worker=100).batch(trajectory_length)

        iterators = [
            dataset.make_initializable_iterator() for _ in range(batch_size)
        ]
        per_iterator_data = [tf.squeeze(it.get_next().data) for it in iterators]
        items = tf.stack(per_iterator_data)

        with self.session() as session:
            session.run([it.initializer for it in iterators])
            got = session.run(items)
            self.assertEqual(got.shape,
                             (batch_size, trajectory_length, 81, 81))

            one_trajectory = [
                np.ones([81, 81]) * step for step in range(trajectory_length)
            ]
            want = np.array([one_trajectory] * batch_size)
            np.testing.assert_array_equal(got, want)
 def timestep_dataset_fn(i):
     """Builds a TimestepDataset over 'dist' limited to its first 2 elements.

     Logs the replica index `i` via `tf.print` before building the dataset.
     NOTE(review): `self` is not a parameter here. It is presumably captured
     from an enclosing test method that is not visible in this chunk.
     """
     tf.print('Creating dataset for replica; index:', i)
     dataset = timestep_dataset.TimestepDataset(
         self._client.server_address,
         table=tf.constant('dist'),
         dtypes=(tf.float32, ),
         shapes=(tf.TensorShape([3, 3]), ),
         max_in_flight_samples_per_worker=100)
     return dataset.take(2)
    def test_iterate_nested_and_batched(self):
        """Nested dtypes/shapes are honoured through sampling and `batch(3)`.

        Timesteps shaped like {'observation': {'data', 'extras': [...]},
        'reward'} are written. They are read back through a TimestepDataset
        with a matching nested spec and batched by 3. Every leaf of every
        sample is then compared against its expected batched value.
        """
        with self._client.writer(100) as writer:
            for i in range(1000):
                writer.append({
                    'observation': {
                        'data': np.zeros((3, 3), dtype=np.float32),
                        'extras': [
                            np.int64(10),
                            np.ones([1], dtype=np.int32),
                        ],
                    },
                    'reward': np.zeros((10, 10), dtype=np.float32),
                })
                # From step 100 on, emit a 100-step item every 5 steps.
                if i % 5 == 0 and i >= 100:
                    writer.create_item(table='dist',
                                       num_timesteps=100,
                                       priority=1)

        dataset = timestep_dataset.TimestepDataset(
            self._client.server_address,
            table='dist',
            dtypes=((tf.float32, (tf.int64, tf.int32)), tf.float32),
            shapes=((tf.TensorShape([3, 3]), (tf.TensorShape(None),
                                              tf.TensorShape([1]))),
                    tf.TensorShape([10, 10])),
            max_in_flight_samples_per_worker=100)
        dataset = dataset.batch(3)

        # Per-timestep (unbatched) specs mirroring what the writer appended.
        # Only the nesting is used below: `tree.unflatten_as` regroups the
        # flat sampled tensors into the written dict layout.
        structure = {
            'observation': {
                'data':
                tf.TensorSpec([3, 3], tf.float32),
                'extras': [
                    tf.TensorSpec([], tf.int64),
                    tf.TensorSpec([1], tf.int32),
                ],
            },
            # float32 [10, 10], matching the reward appended above.
            'reward': tf.TensorSpec([10, 10], tf.float32),
        }

        got = self._sample_from(dataset, 10)
        self.assertLen(got, 10)
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)

            transition = tree.unflatten_as(structure,
                                           tree.flatten(sample.data))
            np.testing.assert_array_equal(
                transition['observation']['data'],
                np.zeros([3, 3, 3], dtype=np.float32))
            np.testing.assert_array_equal(
                transition['observation']['extras'][0],
                np.ones([3], dtype=np.int64) * 10)
            np.testing.assert_array_equal(
                transition['observation']['extras'][1],
                np.ones([3, 1], dtype=np.int32))
            np.testing.assert_array_equal(
                transition['reward'], np.zeros([3, 10, 10], dtype=np.float32))
 def test_timeout_invalid_arguments(self):
     """A rate limiter timeout below -1 must be rejected with a ValueError."""
     dataset_kwargs = dict(
         table='dist',
         dtypes=(tf.float32, ),
         shapes=(tf.TensorShape([3, 3]), ),
         rate_limiter_timeout_ms=-2,
         max_in_flight_samples_per_worker=100)
     with self.assertRaisesRegex(ValueError, r'must be an integer >= -1'):
         timestep_dataset.TimestepDataset(self._client.server_address,
                                          **dataset_kwargs)
Beispiel #10
0
 def test_incompatible_dataset_shapes_and_types_without_signature(self):
     """Requesting a scalar shape for stored [3, 3] tensors must fail.

     Sampling is expected to raise `InvalidArgumentError`, with a message
     contrasting the requested specification against the actual tensor.
     """
     self._populate_replay()
     scalar_shapes = (tf.TensorShape([]), )
     ds_wrong_shape = timestep_dataset.TimestepDataset(
         self._client.server_address,
         table='dist',
         dtypes=(tf.float32, ),
         shapes=scalar_shapes,
         max_in_flight_samples_per_worker=100)
     expected_message = (
         r'Specification has \(dtype, shape\): \(float, \[\]\).  '
         r'Tensor has \(dtype, shape\): \(float, \[3,3\]\).')
     with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                 expected_message):
         self._sample_from(ds_wrong_shape, 1)
Beispiel #11
0
 def test_incompatible_signature_dtype(self, table_name):
   """Requesting int64 where the table signature declares float must fail.

   Args:
     table_name: Table to sample from; it is interpolated into the expected
       error message.
   """
   self._populate_replay()
   dataset = timestep_dataset.TimestepDataset(
       self._client.server_address,
       table=table_name,
       dtypes=(tf.int64,),
       shapes=(tf.TensorShape([3, 3]),),
       max_in_flight_samples_per_worker=100)
   expected_pattern = (
       r'Requested incompatible tensor at flattened index 0 from table '
       r'\'{}\'.  Requested \(dtype, shape\): \(int64, \[3,3\]\).  '
       r'Signature \(dtype, shape\): \(float, \[\?,\?\]\)').format(table_name)
   with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
                                            expected_pattern):
     self._sample_from(dataset, 10)
Beispiel #12
0
    def test_inconsistent_signature_size(self, table_name):
        """Requesting a tensor count that disagrees with the signature fails.

        Two specs are requested (float32 [3, 3] and a float64 scalar). The
        expected error reports 6 requested tensors against 5 in the table
        signature.

        Args:
          table_name: Table to sample from; it is interpolated into the
            expected error message.
        """
        self._populate_replay()

        two_spec_dataset = timestep_dataset.TimestepDataset(
            self._client.server_address,
            table=table_name,
            dtypes=(tf.float32, tf.float64),
            shapes=(tf.TensorShape([3, 3]), tf.TensorShape([])),
            max_in_flight_samples_per_worker=100)
        expected_pattern = (
            r'Inconsistent number of tensors requested from table \'{}\'.  '
            r'Requested 6 tensors, but table signature shows 5 tensors.'
        ).format(table_name)
        with self.assertRaisesWithPredicateMatch(
                tf.errors.InvalidArgumentError, expected_pattern):
            self._sample_from(two_spec_dataset, 10)
Beispiel #13
0
  def test_iterate(self):
    """Unbatched sampling works with tensor-valued address/table arguments.

    Both `server_address` and `table` are passed as `tf.constant`s. Each of
    the 10 samples must be a ReplaySample with a scalar uint64 key and
    all-zero float32 [3, 3] data.
    """
    self._populate_replay()

    dataset = timestep_dataset.TimestepDataset(
        tf.constant(self._client.server_address),
        table=tf.constant('dist'),
        dtypes=(tf.float32,),
        shapes=(tf.TensorShape([3, 3]),),
        max_in_flight_samples_per_worker=100)
    got = self._sample_from(dataset, 10)
    self.assertLen(got, 10)
    for sample in got:
      self.assertIsInstance(sample, replay_sample.ReplaySample)
      # A single (unbatched) sample is returned, so the key is a scalar uint64.
      self.assertIsInstance(sample.info.key, np.uint64)
      np.testing.assert_array_equal(sample.data[0],
                                    np.zeros((3, 3), dtype=np.float32))
Beispiel #14
0
    def test_iterate_over_blobs(self):
        """Items inserted via `client.insert` can be sampled back as int32.

        Ten identical int32 [3, 3] items are inserted and 20 samples are drawn.
        Every sample must carry a uint64 key, a float64 probability, and
        all-ones data.
        """
        for _ in range(10):
            self._client.insert(np.ones([3, 3], dtype=np.int32), {'dist': 1})

        dataset = timestep_dataset.TimestepDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.int32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            max_in_flight_samples_per_worker=100)

        samples = self._sample_from(dataset, 20)
        self.assertLen(samples, 20)
        expected_data = np.ones((3, 3), dtype=np.int32)
        for sample in samples:
            self.assertIsInstance(sample, replay_sample.ReplaySample)
            info = sample.info
            self.assertIsInstance(info.key, np.uint64)
            self.assertIsInstance(info.probability, np.float64)
            np.testing.assert_array_equal(sample.data[0], expected_data)
Beispiel #15
0
    def test_converts_spec_lists_into_tuples(self):
        """List-nested dtypes/shapes specs yield tuple-nested sample data.

        Items mixing list and tuple nesting are inserted, and the dataset is
        built with list-based specs. The `data` of every sample must have the
        all-tuple structure ((x,), (y, (z,))).
        """
        for _ in range(10):
            self._client.insert([
                (np.ones([1, 1], dtype=np.int32), ),
                [
                    np.ones([3, 3], dtype=np.int8),
                    (np.ones([2, 2], dtype=np.float64), ),
                ],
            ], {'dist': 1})

        list_dtypes = [(tf.int32, ), [tf.int8, (tf.float64, )]]
        list_shapes = [
            (tf.TensorShape([1, 1]), ),
            [tf.TensorShape([3, 3]), (tf.TensorShape([2, 2]), )],
        ]
        dataset = timestep_dataset.TimestepDataset(
            self._client.server_address,
            table='dist',
            dtypes=list_dtypes,
            shapes=list_shapes,
            max_in_flight_samples_per_worker=100)

        tuple_structure = ((None, ), (None, (None, )))
        for sample in self._sample_from(dataset, 10):
            self.assertIsInstance(sample, replay_sample.ReplaySample)
            self.assertIsInstance(sample.info.key, np.uint64)
            tree.assert_same_structure(sample.data, tuple_structure)
Beispiel #16
0
    def test_iterate_batched(self, table_name):
        """Batching by 2 stacks both the sampled data and the sample keys.

        Args:
          table_name: Table to sample from.
        """
        self._populate_replay()

        batched = timestep_dataset.TimestepDataset(
            self._client.server_address,
            table=table_name,
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            max_in_flight_samples_per_worker=100).batch(2, drop_remainder=True)

        expected_data = np.zeros((2, 3, 3), dtype=np.float32)
        for sample in self._sample_from(batched, 10):
            self.assertIsInstance(sample, replay_sample.ReplaySample)
            # Keys are batched alongside the data: one key per batch element.
            self.assertEqual(sample.info.key.shape, (2, ))
            np.testing.assert_array_equal(sample.data[0], expected_data)
Beispiel #17
0
    def test_respects_flexible_batch_size(self, flexible_batch_size):
        """Sample counts advance in multiples of `flexible_batch_size`.

        After each of 100 elements is pulled through the iterator, the count
        returned by `self._get_num_samples('dist')` (helper not shown here)
        must be divisible by `flexible_batch_size`.

        Args:
          flexible_batch_size: Forwarded to `TimestepDataset`.
        """
        for _ in range(10):
            self._client.insert(np.ones([3, 3], dtype=np.int32), {'dist': 1})

        dataset = timestep_dataset.TimestepDataset(
            self._client.server_address,
            table='dist',
            dtypes=(tf.int32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            max_in_flight_samples_per_worker=100,
            flexible_batch_size=flexible_batch_size)

        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()
        self.evaluate(iterator.initializer)

        for _ in range(100):
            self.evaluate(next_element)
            # The buffer should only ever grow in steps of flexible_batch_size.
            num_samples = self._get_num_samples('dist')
            self.assertEqual(num_samples % flexible_batch_size, 0)