def test_timeout(self):
  dataset_0s = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([3, 3]),),
      rate_limiter_timeout_ms=0,
      max_in_flight_samples_per_worker=100)

  dataset_1s = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([3, 3]),),
      rate_limiter_timeout_ms=1000,
      max_in_flight_samples_per_worker=100)

  dataset_2s = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([3, 3]),),
      rate_limiter_timeout_ms=2000,
      max_in_flight_samples_per_worker=100)

  start_time = time.time()
  with self.assertRaisesWithPredicateMatch(tf.errors.OutOfRangeError,
                                           r'End of sequence'):
    self._sample_from(dataset_0s, 1)
  duration = time.time() - start_time
  self.assertGreaterEqual(duration, 0)
  self.assertLess(duration, 5)

  start_time = time.time()
  with self.assertRaisesWithPredicateMatch(tf.errors.OutOfRangeError,
                                           r'End of sequence'):
    self._sample_from(dataset_1s, 1)
  duration = time.time() - start_time
  self.assertGreaterEqual(duration, 1)
  self.assertLess(duration, 10)

  start_time = time.time()
  with self.assertRaisesWithPredicateMatch(tf.errors.OutOfRangeError,
                                           r'End of sequence'):
    self._sample_from(dataset_2s, 1)
  duration = time.time() - start_time
  self.assertGreaterEqual(duration, 2)
  self.assertLess(duration, 10)

  # If we insert some data, and the rate limiter doesn't force any waiting,
  # then we can ask for a timeout of 0s and still get data back.
  self._populate_replay()
  got = self._sample_from(dataset_0s, 2)
  self.assertLen(got, 2)

def test_sampler_parameter_validation(self, **kwargs):
  dtypes = (tf.float32,)
  shapes = (tf.TensorShape([3, 3]),)

  if 'max_in_flight_samples_per_worker' not in kwargs:
    kwargs['max_in_flight_samples_per_worker'] = 100

  if 'want_error' in kwargs:
    error = kwargs.pop('want_error')
    with self.assertRaises(error):
      timestep_dataset.TimestepDataset(self._client.server_address, 'dist',
                                       dtypes, shapes, **kwargs)
  else:
    timestep_dataset.TimestepDataset(self._client.server_address, 'dist',
                                     dtypes, shapes, **kwargs)

def test_session_is_closed_while_op_pending(self):
  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=tf.float32,
      shapes=tf.TensorShape([]),
      max_in_flight_samples_per_worker=100)
  iterator = dataset.make_initializable_iterator()
  item = iterator.get_next()

  def _session_closer(sess, wait_time_secs):
    def _fn():
      time.sleep(wait_time_secs)
      sess.close()

    return _fn

  with self.session() as sess:
    sess.run(iterator.initializer)
    thread = threading.Thread(target=_session_closer(sess, 3))
    thread.start()
    with self.assertRaises(tf.errors.CancelledError):
      sess.run(item)

def test_max_samples(self, num_workers_per_iterator,
                     max_in_flight_samples_per_worker, max_samples):
  s = reverb_server.Server([reverb_server.Table.queue('q', 10)])
  c = s.localhost_client()

  for i in range(10):
    c.insert(i, {'q': 1})

  ds = timestep_dataset.TimestepDataset(
      server_address=c.server_address,
      table='q',
      dtypes=tf.int64,
      shapes=tf.TensorShape([]),
      max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
      num_workers_per_iterator=num_workers_per_iterator,
      max_samples=max_samples)

  iterator = ds.make_one_shot_iterator()
  item = iterator.get_next()

  # Check that it fetches exactly `max_samples` samples.
  with self.session() as sess:
    for _ in range(max_samples):
      sess.run(item)
    with self.assertRaises(tf.errors.OutOfRangeError):
      sess.run(item)

  # Check that no prefetching happened; `10 - max_samples` items should be
  # left in the queue.
  self.assertEqual(c.server_info()['q'].current_size, 10 - max_samples)
  np.testing.assert_array_equal(
      next(c.sample('q', 1))[0].data[0], np.asarray(max_samples))

def test_max_samples(self, num_workers_per_iterator,
                     max_in_flight_samples_per_worker, max_samples):
  s = reverb_server.Server([reverb_server.Table.queue('q', 10)])
  c = s.localhost_client()

  for i in range(10):
    c.insert(i, {'q': 1})

  ds = timestep_dataset.TimestepDataset(
      server_address=c.server_address,
      table='q',
      dtypes=tf.int64,
      shapes=tf.TensorShape([]),
      max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
      num_workers_per_iterator=num_workers_per_iterator,
      max_samples=max_samples)

  # Check that it fetches exactly `max_samples` samples.
  it = ds.as_numpy_iterator()
  self.assertLen(list(it), max_samples)

  # Check that no prefetching happened in the queue.
  self.assertEqual(c.server_info()['q'].current_size, 10 - max_samples)
  np.testing.assert_array_equal(
      next(c.sample('q', 1))[0].data[0], np.asarray(max_samples))

def test_multiple_iterators(self):
  with self._client.writer(100) as writer:
    for i in range(10):
      writer.append([np.ones((81, 81), dtype=np.float32) * i])
    writer.create_item(table='dist', num_timesteps=10, priority=1)

  trajectory_length = 5
  batch_size = 3

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([81, 81]),),
      max_in_flight_samples_per_worker=100)
  dataset = dataset.batch(trajectory_length)

  iterators = [
      dataset.make_initializable_iterator() for _ in range(batch_size)
  ]
  items = tf.stack(
      [tf.squeeze(iterator.get_next().data) for iterator in iterators])

  with self.session() as session:
    session.run([iterator.initializer for iterator in iterators])
    got = session.run(items)
    self.assertEqual(got.shape, (batch_size, trajectory_length, 81, 81))

    want = np.array(
        [[np.ones([81, 81]) * i for i in range(trajectory_length)]] *
        batch_size)
    np.testing.assert_array_equal(got, want)

def timestep_dataset_fn(i):
  tf.print('Creating dataset for replica; index:', i)
  return timestep_dataset.TimestepDataset(
      self._client.server_address,
      table=tf.constant('dist'),
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([3, 3]),),
      max_in_flight_samples_per_worker=100).take(2)

def test_iterate_nested_and_batched(self):
  with self._client.writer(100) as writer:
    for i in range(1000):
      writer.append({
          'observation': {
              'data': np.zeros((3, 3), dtype=np.float32),
              'extras': [
                  np.int64(10),
                  np.ones([1], dtype=np.int32),
              ],
          },
          'reward': np.zeros((10, 10), dtype=np.float32),
      })
      if i % 5 == 0 and i >= 100:
        writer.create_item(table='dist', num_timesteps=100, priority=1)

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(((tf.float32), (tf.int64, tf.int32)), tf.float32),
      shapes=((tf.TensorShape([3, 3]),
               (tf.TensorShape(None), tf.TensorShape([1]))),
              tf.TensorShape([10, 10])),
      max_in_flight_samples_per_worker=100)
  dataset = dataset.batch(3)

  structure = {
      'observation': {
          'data': tf.TensorSpec([3, 3], tf.float32),
          'extras': [
              tf.TensorSpec([], tf.int64),
              tf.TensorSpec([1], tf.int32),
          ],
      },
      'reward': tf.TensorSpec([], tf.int64),
  }

  got = self._sample_from(dataset, 10)
  self.assertLen(got, 10)
  for sample in got:
    self.assertIsInstance(sample, replay_sample.ReplaySample)
    transition = tree.unflatten_as(structure, tree.flatten(sample.data))
    np.testing.assert_array_equal(transition['observation']['data'],
                                  np.zeros([3, 3, 3], dtype=np.float32))
    np.testing.assert_array_equal(transition['observation']['extras'][0],
                                  np.ones([3], dtype=np.int64) * 10)
    np.testing.assert_array_equal(transition['observation']['extras'][1],
                                  np.ones([3, 1], dtype=np.int32))
    np.testing.assert_array_equal(transition['reward'],
                                  np.zeros([3, 10, 10], dtype=np.float32))

def test_timeout_invalid_arguments(self):
  with self.assertRaisesRegex(ValueError, r'must be an integer >= -1'):
    timestep_dataset.TimestepDataset(
        self._client.server_address,
        table='dist',
        dtypes=(tf.float32,),
        shapes=(tf.TensorShape([3, 3]),),
        rate_limiter_timeout_ms=-2,
        max_in_flight_samples_per_worker=100)

def test_incompatible_dataset_shapes_and_types_without_signature(self):
  self._populate_replay()

  ds_wrong_shape = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([]),),
      max_in_flight_samples_per_worker=100)
  with self.assertRaisesRegex(
      tf.errors.InvalidArgumentError,
      r'Specification has \(dtype, shape\): \(float, \[\]\). '
      r'Tensor has \(dtype, shape\): \(float, \[3,3\]\).'):
    self._sample_from(ds_wrong_shape, 1)

def test_incompatible_signature_dtype(self, table_name):
  self._populate_replay()

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table=table_name,
      dtypes=(tf.int64,),
      shapes=(tf.TensorShape([3, 3]),),
      max_in_flight_samples_per_worker=100)
  with self.assertRaisesWithPredicateMatch(
      tf.errors.InvalidArgumentError,
      r'Requested incompatible tensor at flattened index 0 from table '
      r'\'{}\'. Requested \(dtype, shape\): \(int64, \[3,3\]\). '
      r'Signature \(dtype, shape\): \(float, \[\?,\?\]\)'.format(table_name)):
    self._sample_from(dataset, 10)

def test_inconsistent_signature_size(self, table_name):
  self._populate_replay()

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table=table_name,
      dtypes=(tf.float32, tf.float64),
      shapes=(tf.TensorShape([3, 3]), tf.TensorShape([])),
      max_in_flight_samples_per_worker=100)
  with self.assertRaisesWithPredicateMatch(
      tf.errors.InvalidArgumentError,
      r'Inconsistent number of tensors requested from table \'{}\'. '
      r'Requested 6 tensors, but table signature shows 5 tensors.'.format(
          table_name)):
    self._sample_from(dataset, 10)

def test_iterate(self):
  self._populate_replay()

  dataset = timestep_dataset.TimestepDataset(
      tf.constant(self._client.server_address),
      table=tf.constant('dist'),
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([3, 3]),),
      max_in_flight_samples_per_worker=100)
  got = self._sample_from(dataset, 10)
  for sample in got:
    self.assertIsInstance(sample, replay_sample.ReplaySample)

    # A single sample is returned so the key should be a scalar uint64.
    self.assertIsInstance(sample.info.key, np.uint64)
    np.testing.assert_array_equal(sample.data[0],
                                  np.zeros((3, 3), dtype=np.float32))

def test_iterate_over_blobs(self):
  for _ in range(10):
    self._client.insert((np.ones([3, 3], dtype=np.int32)), {'dist': 1})

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(tf.int32,),
      shapes=(tf.TensorShape([3, 3]),),
      max_in_flight_samples_per_worker=100)
  got = self._sample_from(dataset, 20)
  self.assertLen(got, 20)
  for sample in got:
    self.assertIsInstance(sample, replay_sample.ReplaySample)
    self.assertIsInstance(sample.info.key, np.uint64)
    self.assertIsInstance(sample.info.probability, np.float64)
    np.testing.assert_array_equal(sample.data[0],
                                  np.ones((3, 3), dtype=np.int32))

def test_converts_spec_lists_into_tuples(self):
  for _ in range(10):
    data = [
        (np.ones([1, 1], dtype=np.int32),),
        [
            np.ones([3, 3], dtype=np.int8),
            (np.ones([2, 2], dtype=np.float64),),
        ],
    ]
    self._client.insert(data, {'dist': 1})

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=[
          (tf.int32,),
          [
              tf.int8,
              (tf.float64,),
          ],
      ],
      shapes=[
          (tf.TensorShape([1, 1]),),
          [
              tf.TensorShape([3, 3]),
              (tf.TensorShape([2, 2]),),
          ],
      ],
      max_in_flight_samples_per_worker=100)

  got = self._sample_from(dataset, 10)
  for sample in got:
    self.assertIsInstance(sample, replay_sample.ReplaySample)
    self.assertIsInstance(sample.info.key, np.uint64)
    tree.assert_same_structure(sample.data, (
        (None,),
        (
            None,
            (None,),
        ),
    ))

def test_iterate_batched(self, table_name):
  self._populate_replay()

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table=table_name,
      dtypes=(tf.float32,),
      shapes=(tf.TensorShape([3, 3]),),
      max_in_flight_samples_per_worker=100)
  dataset = dataset.batch(2, drop_remainder=True)

  got = self._sample_from(dataset, 10)
  for sample in got:
    self.assertIsInstance(sample, replay_sample.ReplaySample)

    # The keys should be batched up like the data.
    self.assertEqual(sample.info.key.shape, (2,))

    np.testing.assert_array_equal(sample.data[0],
                                  np.zeros((2, 3, 3), dtype=np.float32))

def test_respects_flexible_batch_size(self, flexible_batch_size):
  for _ in range(10):
    self._client.insert((np.ones([3, 3], dtype=np.int32)), {'dist': 1})

  dataset = timestep_dataset.TimestepDataset(
      self._client.server_address,
      table='dist',
      dtypes=(tf.int32,),
      shapes=(tf.TensorShape([3, 3]),),
      max_in_flight_samples_per_worker=100,
      flexible_batch_size=flexible_batch_size)
  iterator = dataset.make_initializable_iterator()
  dataset_item = iterator.get_next()
  self.evaluate(iterator.initializer)

  for _ in range(100):
    self.evaluate(dataset_item)

    # Check that the buffer is incremented by steps of flexible_batch_size.
    self.assertEqual(self._get_num_samples('dist') % flexible_batch_size, 0)