def make_server():
  """Builds a server with three prioritized/FIFO tables.

  Tables `dist` and `dist2` declare no signature. `signatured` declares a
  float32 spec with shape `(None, None)`. The server is created with
  `port=None`, which presumably lets it pick a free port (confirm against the
  `server.Server` API).
  """

  def _prioritized_fifo_table(table_name, **optional_kwargs):
    # Every table shares the same sampler, remover, capacity and limiter;
    # only the name and optional signature differ.
    return server.Table(
        table_name,
        sampler=item_selectors.Prioritized(priority_exponent=1),
        remover=item_selectors.Fifo(),
        max_size=1000000,
        rate_limiter=rate_limiters.MinSize(1),
        **optional_kwargs)

  all_tables = [
      _prioritized_fifo_table('dist'),
      _prioritized_fifo_table('dist2'),
      _prioritized_fifo_table(
          'signatured',
          signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None))),
  ]
  return server.Server(tables=all_tables, port=None)
def test_max_samples(self, num_workers_per_iterator,
                     max_in_flight_samples_per_worker, max_samples):
  """Checks that the dataset yields exactly `max_samples` items, then stops.

  Uses the graph-mode `make_one_shot_iterator` / `self.session()` API.
  Afterwards, the queue must still hold `10 - max_samples` items. That proves
  no extra samples were prefetched from the server.
  """
  s = reverb_server.Server([reverb_server.Table.queue('q', 10)])
  c = s.localhost_client()
  # Fill the queue with the values 0..9. The mapping is keyed by table name;
  # the value is presumably the priority (verify against `Client.insert`).
  for i in range(10):
    c.insert(i, {'q': 1})
  ds = timestep_dataset.TimestepDataset(
      server_address=c.server_address,
      table='q',
      dtypes=tf.int64,
      shapes=tf.TensorShape([]),
      max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
      num_workers_per_iterator=num_workers_per_iterator,
      max_samples=max_samples)
  iterator = ds.make_one_shot_iterator()
  item = iterator.get_next()
  # Check that it fetches exactly `max_samples` samples: that many `run`
  # calls succeed and the next one signals end of sequence.
  with self.session() as sess:
    for _ in range(max_samples):
      sess.run(item)
    with self.assertRaises(tf.errors.OutOfRangeError):
      sess.run(item)
  # Check that no prefetching happened: `10 - max_samples` items must remain
  # in the queue, and the next item in line must be the value `max_samples`.
  self.assertEqual(c.server_info()['q'].current_size, 10 - max_samples)
  np.testing.assert_array_equal(
      next(c.sample('q', 1))[0].data[0], np.asarray(max_samples))
def test_max_samples(self, num_workers_per_iterator,
                     max_in_flight_samples_per_worker, max_samples):
  """Checks that an eager TimestepDataset stops after `max_samples` items.

  The queue is seeded with 10 items. Once the dataset is exhausted, exactly
  `10 - max_samples` must remain, and the head of the queue must be the value
  `max_samples`. That shows nothing was prefetched.
  """
  queue_server = reverb_server.Server([reverb_server.Table.queue('q', 10)])
  queue_client = queue_server.localhost_client()
  for value in range(10):
    queue_client.insert(value, {'q': 1})

  dataset = timestep_dataset.TimestepDataset(
      server_address=queue_client.server_address,
      table='q',
      dtypes=tf.int64,
      shapes=tf.TensorShape([]),
      max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
      num_workers_per_iterator=num_workers_per_iterator,
      max_samples=max_samples)

  # Draining the iterator must produce exactly `max_samples` elements.
  fetched = list(dataset.as_numpy_iterator())
  self.assertLen(fetched, max_samples)

  # The queue must not have been drained beyond what was consumed.
  remaining = queue_client.server_info()['q'].current_size
  self.assertEqual(remaining, 10 - max_samples)
  head_item = next(queue_client.sample('q', 1))[0]
  np.testing.assert_array_equal(head_item.data[0], np.asarray(max_samples))
def make_server():
  """Builds a server with plain, TensorSpec-signed and BoundedTensorSpec-signed tables.

  All three tables use a prioritized sampler, a FIFO remover, a max size of
  1000000 and `MinSize(1)`. They differ only in name and signature.
  """

  def _fifo_prioritized_table(table_name, **signature_kwargs):
    # Shared configuration; the optional `signature` is forwarded only when
    # the caller supplies one.
    return reverb_server.Table(
        table_name,
        sampler=item_selectors.Prioritized(priority_exponent=1),
        remover=item_selectors.Fifo(),
        max_size=1000000,
        rate_limiter=rate_limiters.MinSize(1),
        **signature_kwargs)

  float_matrix = tf.TensorSpec(dtype=tf.float32, shape=(None, None))
  # Currently only the `shape` and `dtype` of the bounded spec are considered
  # during the signature check.
  # TODO(b/158033101): Check the boundaries as well.
  bounded_float_matrix = tensor_spec.BoundedTensorSpec(
      dtype=tf.float32,
      shape=(None, None),
      minimum=(0.0, 0.0),
      maximum=(10.0, 10.0))

  return reverb_server.Server(
      tables=[
          _fifo_prioritized_table('dist'),
          _fifo_prioritized_table('signatured', signature=float_matrix),
          _fifo_prioritized_table(
              'bounded_spec_signatured', signature=bounded_float_matrix),
      ],
      port=None,
  )
def test_combines_sequence_length_with_signature_if_not_emit_timestamps(
    self):
  """With `emit_timesteps=False`, every signature leaf gains a time axis.

  The expected spec prepends `sequence_length` (5) to each leaf shape.
  """
  table_signature = {
      'a': {
          'b': tf.TensorSpec([3, 3], tf.float32),
          'c': tf.TensorSpec([], tf.int64),
      },
  }
  queue_table = reverb_server.Table.queue(
      'queue', 10, signature=table_signature)
  server = reverb_server.Server([queue_table])

  dataset = reverb_dataset.ReplayDataset.from_table_signature(
      f'localhost:{server.port}',
      'queue',
      100,
      emit_timesteps=False,
      sequence_length=5)

  expected_spec = {
      'a': {
          'b': tf.TensorSpec([5, 3, 3], tf.float32),
          'c': tf.TensorSpec([5], tf.int64),
      },
  }
  self.assertDictEqual(dataset.element_spec.data, expected_spec)
def make_server():
  """Returns a server hosting one prioritized/FIFO table named `TABLE`."""
  only_table = reverb_server.Table(
      name=TABLE,
      sampler=item_selectors.Prioritized(priority_exponent=1),
      remover=item_selectors.Fifo(),
      max_size=1000,
      rate_limiter=rate_limiters.MinSize(1))
  return reverb_server.Server(tables=[only_table])
def test_in_process_client(self):
  """Creates a client for a local server, resets the table, then shuts down."""
  table_config = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=100,
      rate_limiter=rate_limiters.MinSize(2))
  local_server = server.Server(tables=[table_config])
  local_client = local_server.localhost_client()
  local_client.reset(TABLE_NAME)
  # Release the client before stopping the server; presumably this closes
  # its connection first.
  del local_client
  local_server.stop()
def setUpClass(cls):
  """Starts a shared server with one int64 scalar-signed table and a client."""
  super().setUpClass()
  int64_table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=1000,
      rate_limiter=rate_limiters.MinSize(3),
      signature=tf.TensorSpec(dtype=tf.int64, shape=()),
  )
  cls.server = server.Server(tables=[int64_table], port=None)
  # Connect through the port the server ended up listening on.
  server_address = f'localhost:{cls.server.port}'
  cls.client = client.Client(server_address)
def make_tables_and_server():
  """Returns `(tables, server)` where `server` hosts exactly `tables`.

  Both tables (`dist`, `dist2`) share the same prioritized/FIFO configuration.
  """
  tables = [
      server.Table(
          table_name,
          sampler=item_selectors.Prioritized(priority_exponent=1),
          remover=item_selectors.Fifo(),
          max_size=1000000,
          rate_limiter=rate_limiters.MinSize(1))
      for table_name in ('dist', 'dist2')
  ]
  return tables, server.Server(tables=tables)
def test_duplicate_priority_table_name(self):
  """Building a server with two tables named 'test' raises ValueError."""
  with self.assertRaises(ValueError):
    # Both tables are constructed inside the context so that a ValueError
    # raised at any point is accepted.
    same_named_tables = [
        server.Table(
            name='test',
            sampler=item_selectors.Prioritized(1),
            remover=item_selectors.Fifo(),
            max_size=100,
            rate_limiter=rate_limiters.MinSize(2)),
        server.Table(
            name='test',
            sampler=item_selectors.Prioritized(2),
            remover=item_selectors.Fifo(),
            max_size=200,
            rate_limiter=rate_limiters.MinSize(1)),
    ]
    server.Server(tables=same_named_tables)
def test_table_not_found(self):
  """Asking for an unknown table raises ValueError that lists existing tables."""
  # Created out of alphabetical order; the expected message lists them sorted.
  created_names = ('table_a', 'table_c', 'table_b')
  server = reverb_server.Server(
      [reverb_server.Table.queue(name, 10) for name in created_names])
  address = f'localhost:{server.port}'

  expected_message = (
      f'Server at {address} does not contain any table named not_found. '
      f'Found: table_a, table_b, table_c.')
  with self.assertRaisesWithPredicateMatch(ValueError, expected_message):
    reverb_dataset.ReplayDataset.from_table_signature(
        address, 'not_found', 100)
def test_sets_dtypes_from_signature(self):
  """The dataset's element spec equals the (nested) table signature."""
  nested_spec = {
      'a': {
          'b': tf.TensorSpec([3, 3], tf.float32),
          'c': tf.TensorSpec([], tf.int64),
      },
      'x': tf.TensorSpec([None], tf.uint64),
  }
  signed_queue = reverb_server.Table.queue('queue', 10, signature=nested_spec)
  server = reverb_server.Server([signed_queue])
  address = f'localhost:{server.port}'

  dataset = reverb_dataset.ReplayDataset.from_table_signature(
      address, 'queue', 100)
  self.assertDictEqual(dataset.element_spec.data, nested_spec)
def test_episode_steps_reset_on_end_episode(self, clear_buffers: bool):
  """`episode_steps` goes 0 -> 1 -> 0 across one append and `end_episode`."""
  queue_server = server_lib.Server([server_lib.Table.queue('queue', 1)])
  queue_client = client_lib.Client(f'localhost:{queue_server.port}')
  traj_writer = queue_client.trajectory_writer(num_keep_alive_refs=1)

  # A freshly created writer has counted no steps.
  self.assertEqual(traj_writer.episode_steps, 0)

  # A single full append counts as one step.
  traj_writer.append([1])
  self.assertEqual(traj_writer.episode_steps, 1)

  # Ending the episode resets the count for the parameterized `clear_buffers`.
  traj_writer.end_episode(clear_buffers=clear_buffers)
  self.assertEqual(traj_writer.episode_steps, 0)
def test_episode_steps(self):
  """`episode_steps` counts appends and restarts at zero each episode."""
  queue_server = server_lib.Server([server_lib.Table.queue('queue', 1)])
  queue_client = client_lib.Client(f'localhost:{queue_server.port}')
  traj_writer = queue_client.trajectory_writer(num_keep_alive_refs=1)

  num_episodes = 10
  steps_per_episode = 20
  for _ in range(num_episodes):
    # Each episode, the first one included, begins with a zero count.
    self.assertEqual(traj_writer.episode_steps, 0)
    for expected_steps in range(1, steps_per_episode + 1):
      traj_writer.append({'x': 3, 'y': 2})
      self.assertEqual(traj_writer.episode_steps, expected_steps)
    # Closing the episode must bring the counter back to zero.
    traj_writer.end_episode()
def test_timeout_on_end_episode(self):
  """`end_episode` with a 1 ms timeout raises DeadlineExceededError when the
  pending items cannot all be accepted."""
  queue_server = server_lib.Server([server_lib.Table.queue('queue', 1)])
  queue_client = client_lib.Client(f'localhost:{queue_server.port}')
  traj_writer = queue_client.trajectory_writer(num_keep_alive_refs=1)
  traj_writer.append([1])

  # The test assumes the table holds one item and up to 2 more can wait in
  # table worker queues. Four items therefore cannot all be flushed, so
  # `end_episode` is expected to hit its deadline.
  with self.assertRaises(errors.DeadlineExceededError):
    for _ in range(4):
      traj_writer.create_item('queue', 1.0, traj_writer.history[0][:])
    traj_writer.end_episode(clear_buffers=False, timeout_ms=1)

  traj_writer.close()
  queue_server.stop()
def test_can_sample(self):
  """Checks `can_sample` / `can_insert` as items are added through an
  in-process client."""
  table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=100,
      max_times_sampled=1,
      rate_limiter=rate_limiters.MinSize(2))
  local_server = server.Server(tables=[table], port=None)
  in_proc_client = local_server.in_process_client()

  # Empty table: sampling is expected to be refused, inserting allowed.
  self.assertFalse(table.can_sample(1))
  self.assertTrue(table.can_insert(1))

  # After one insert, sampling is still expected to be refused.
  in_proc_client.insert(1, {TABLE_NAME: 1.0})
  self.assertFalse(table.can_sample(1))

  # After the second insert, sampling is expected to be allowed.
  in_proc_client.insert(1, {TABLE_NAME: 1.0})
  self.assertTrue(table.can_sample(2))
  # TODO(b/153258711): This should return False since max_times_sampled=1.
  self.assertTrue(table.can_sample(3))

  del in_proc_client
  local_server.stop()
def setUpClass(cls):
  """Starts a shared server with three tables and a localhost client.

  The tables are an int64 scalar-signed prioritized table, a queue with
  `QUEUE_SIGNATURE`, and an unsigned queue.
  """
  super().setUpClass()
  prioritized_table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=1000,
      rate_limiter=rate_limiters.MinSize(3),
      signature=tf.TensorSpec(dtype=tf.int64, shape=[]),
  )
  nested_queue = server.Table.queue(
      name=NESTED_SIGNATURE_TABLE_NAME,
      max_size=10,
      signature=QUEUE_SIGNATURE,
  )
  plain_queue = server.Table.queue(SIMPLE_QUEUE_NAME, 10)

  cls.tables = [prioritized_table, nested_queue, plain_queue]
  cls.server = server.Server(tables=cls.tables)
  cls.client = cls.server.localhost_client()
def test_sets_dtypes_from_bounded_spec_signature(self):
  """A BoundedTensorSpec signature becomes a plain TensorSpec in the dataset.

  The expected element spec keeps shape and dtype but drops the bounds.
  """
  bounded_signature = {
      'a': {
          'b': tensor_spec.BoundedTensorSpec([3, 3], tf.float32, 0, 3),
          'c': tensor_spec.BoundedTensorSpec([], tf.int64, 0, 5),
      },
  }
  signed_queue = reverb_server.Table.queue(
      'queue', 10, signature=bounded_signature)
  server = reverb_server.Server([signed_queue])
  address = f'localhost:{server.port}'

  dataset = reverb_dataset.ReplayDataset.from_table_signature(
      address, 'queue', 100)

  unbounded_spec = {
      'a': {
          'b': tf.TensorSpec([3, 3], tf.float32),
          'c': tf.TensorSpec([], tf.int64),
      },
  }
  self.assertDictEqual(dataset.element_spec.data, unbounded_spec)
def test_episode_steps_partial_step(self):
  """Partial appends leave `episode_steps` unchanged; a full append bumps it."""
  queue_server = server_lib.Server([server_lib.Table.queue('queue', 1)])
  queue_client = client_lib.Client(f'localhost:{queue_server.port}')
  traj_writer = queue_client.trajectory_writer(num_keep_alive_refs=1)

  for _ in range(3):
    # Each episode, the first one included, begins with a zero count.
    self.assertEqual(traj_writer.episode_steps, 0)
    for completed in range(1, 4):
      # A partial append must not advance the step count.
      traj_writer.append({'x': 3}, partial_step=True)
      self.assertEqual(traj_writer.episode_steps, completed - 1)
      # The following non-partial append completes the step.
      traj_writer.append({'y': 2})
      self.assertEqual(traj_writer.episode_steps, completed)
    # Closing the episode must bring the counter back to zero.
    traj_writer.end_episode()
def test_max_samples(self, num_workers_per_iterator,
                     max_in_flight_samples_per_worker, max_samples):
  """Checks that a TrajectoryDataset stops after `max_samples` trajectories.

  Ten items are written to the queue. Afterwards, exactly
  `10 - max_samples` must remain, which shows no prefetching occurred.
  """
  server = reverb_server.Server(
      tables=[reverb_server.Table.queue(_TABLE, 10)])
  client = server.localhost_client()

  # Each item references the latest step (`last`) and the whole kept
  # history (`all`) of column 0.
  with client.trajectory_writer(10) as writer:
    for step in range(10):
      writer.append([np.ones([3, 3], np.int32) * step])
      writer.create_item(_TABLE, 1.0, {
          'last': writer.history[0][-1],
          'all': writer.history[0][:],
      })

  item_dtypes = {'last': tf.int32, 'all': tf.int32}
  item_shapes = {
      'last': tf.TensorShape([3, 3]),
      'all': tf.TensorShape([None, 3, 3]),
  }
  dataset = trajectory_dataset.TrajectoryDataset(
      tf.constant(client.server_address),
      table=tf.constant(_TABLE),
      dtypes=item_dtypes,
      shapes=item_shapes,
      num_workers_per_iterator=num_workers_per_iterator,
      max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
      max_samples=max_samples)

  # Draining the iterator must yield exactly `max_samples` trajectories.
  samples = list(dataset.as_numpy_iterator())
  self.assertLen(samples, max_samples)

  # Only the consumed items may have left the queue (no prefetching).
  self.assertEqual(client.server_info()[_TABLE].current_size,
                   10 - max_samples)
def test_can_sample(self):
  """Checks `can_sample` / `can_insert` as items are added through a
  localhost client."""
  table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=100,
      max_times_sampled=1,
      rate_limiter=rate_limiters.MinSize(2))
  local_server = server.Server(tables=[table])
  remote_client = local_server.localhost_client()

  # Empty table: sampling is expected to be refused, inserting allowed.
  self.assertFalse(table.can_sample(1))
  self.assertTrue(table.can_insert(1))

  remote_client.insert(1, {TABLE_NAME: 1.0})
  self.assertFalse(table.can_sample(1))

  remote_client.insert(1, {TABLE_NAME: 1.0})
  # Poll (at most 100 checks, 10 ms apart) until `table.info` reports both
  # items. Presumably the insert is not immediately visible on the table side.
  attempts = 0
  while attempts < 100 and table.info.current_size != 2:
    time.sleep(0.01)
    attempts += 1
  self.assertEqual(table.info.current_size, 2)

  self.assertTrue(table.can_sample(2))
  # TODO(b/153258711): This should return False since max_times_sampled=1.
  self.assertTrue(table.can_sample(3))

  del remote_client
  local_server.stop()
def test_no_priority_table_provided(self):
  """Constructing a server with an empty table list raises ValueError."""
  no_tables = []
  with self.assertRaises(ValueError):
    server.Server(tables=no_tables, port=None)
def setUpClass(cls):
  """Starts a shared server with a single queue named 'queue' (size 100)."""
  super().setUpClass()
  shared_queue = server_lib.Table.queue('queue', 100)
  cls._server = server_lib.Server([shared_queue])
def setUpClass(cls):
  """Starts a shared server with one size-100 queue per name in `TABLES`."""
  super().setUpClass()
  queues = []
  for table_name in TABLES:
    queues.append(server_lib.Table.queue(table_name, 100))
  cls._server = server_lib.Server(queues)