def make_server():
  """Creates a Reverb server with two plain tables and one signatured table.

  All three tables share the same prioritized sampler (exponent 1), FIFO
  remover, capacity of 1,000,000 items and a `MinSize(1)` rate limiter. The
  'signatured' table additionally declares a float32 rank-2 signature with
  unknown dimensions. The server is created with `port=None`.

  Returns:
    The constructed `server.Server`.
  """

  def _prioritized_table(table_name, **extra_kwargs):
    # Shared configuration; `extra_kwargs` carries per-table options such as
    # `signature`.
    return server.Table(
        table_name,
        sampler=item_selectors.Prioritized(priority_exponent=1),
        remover=item_selectors.Fifo(),
        max_size=1000000,
        rate_limiter=rate_limiters.MinSize(1),
        **extra_kwargs)

  all_tables = [
      _prioritized_table('dist'),
      _prioritized_table('dist2'),
      _prioritized_table(
          'signatured',
          signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None))),
  ]
  return server.Server(tables=all_tables, port=None)
def make_server():
  """Creates a Reverb server with a plain, a signatured and a bounded table.

  Every table uses a prioritized sampler (exponent 1), a FIFO remover, a
  capacity of 1,000,000 items and a `MinSize(1)` rate limiter. The server is
  created with `port=None`.

  Returns:
    The constructed `reverb_server.Server`.
  """

  def _table(table_name, **extra_kwargs):
    # Common table settings; per-table options arrive via `extra_kwargs`.
    return reverb_server.Table(
        table_name,
        sampler=item_selectors.Prioritized(priority_exponent=1),
        remover=item_selectors.Fifo(),
        max_size=1000000,
        rate_limiter=rate_limiters.MinSize(1),
        **extra_kwargs)

  return reverb_server.Server(
      tables=[
          _table('dist'),
          _table(
              'signatured',
              signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None))),
          _table(
              'bounded_spec_signatured',
              # Currently only the `shape` and `dtype` of the bounded spec
              # is considered during signature check.
              # TODO(b/158033101): Check the boundaries as well.
              signature=tensor_spec.BoundedTensorSpec(
                  dtype=tf.float32,
                  shape=(None, None),
                  minimum=(0.0, 0.0),
                  maximum=(10.0, 10.0))),
      ],
      port=None,
  )
def test_duplicate_priority_table_name(self):
  """Server construction must fail when two tables share the name 'test'."""
  with self.assertRaises(ValueError):
    # Table construction stays inside the context manager so that any
    # ValueError raised while building the tables is also accepted.
    duplicate_tables = [
        server.Table(
            name='test',
            sampler=item_selectors.Prioritized(1),
            remover=item_selectors.Fifo(),
            max_size=100,
            rate_limiter=rate_limiters.MinSize(2)),
        server.Table(
            name='test',
            sampler=item_selectors.Prioritized(2),
            remover=item_selectors.Fifo(),
            max_size=200,
            rate_limiter=rate_limiters.MinSize(1)),
    ]
    server.Server(tables=duplicate_tables)
def make_tables_and_server():
  """Builds the 'dist' and 'dist2' tables and a server that hosts them.

  Returns:
    A `(tables, server)` pair where `tables` is the list passed to the server.
  """
  table_list = [
      server.Table(
          table_name,
          sampler=item_selectors.Prioritized(priority_exponent=1),
          remover=item_selectors.Fifo(),
          max_size=1000000,
          rate_limiter=rate_limiters.MinSize(1))
      for table_name in ('dist', 'dist2')
  ]
  hosting_server = server.Server(tables=table_list)
  return table_list, hosting_server
def test_table_info_signature(self, signature):
  """The signature given to a Table is reported back unchanged by `info`."""
  fifo_table = server.Table(
      name='table',
      sampler=item_selectors.Fifo(),
      remover=item_selectors.Fifo(),
      max_size=100,
      rate_limiter=rate_limiters.MinSize(10),
      signature=signature)
  reported_signature = fifo_table.info.signature
  self.assertEqual(signature, reported_signature)
def make_server():
  """Creates a Reverb server exposing a single prioritized table `TABLE`.

  Returns:
    The constructed `reverb_server.Server`.
  """
  only_table = reverb_server.Table(
      name=TABLE,
      sampler=item_selectors.Prioritized(priority_exponent=1),
      remover=item_selectors.Fifo(),
      max_size=1000,
      rate_limiter=rate_limiters.MinSize(1))
  return reverb_server.Server(tables=[only_table])
def test_in_process_client(self):
  """A client obtained from the server can reset its table."""
  reset_table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=100,
      rate_limiter=rate_limiters.MinSize(2))
  srv = server.Server(tables=[reset_table])
  local_client = srv.localhost_client()
  local_client.reset(TABLE_NAME)
  # Drop the client reference before stopping the server.
  del local_client
  srv.stop()
def setUpClass(cls):
  """Starts a shared server with one int64-scalar table and a client for it."""
  super().setUpClass()
  scalar_table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=1000,
      rate_limiter=rate_limiters.MinSize(3),
      signature=tf.TensorSpec(dtype=tf.int64, shape=()),
  )
  cls.server = server.Server(tables=[scalar_table], port=None)
  # The client connects to whatever port the server ended up listening on.
  target_address = f'localhost:{cls.server.port}'
  cls.client = client.Client(target_address)
def test_replace(self):
  """`Table.replace` swaps the rate limiter while keeping the table name."""
  original_table = server.Table(
      name='table',
      sampler=item_selectors.Fifo(),
      remover=item_selectors.Fifo(),
      max_size=100,
      rate_limiter=rate_limiters.MinSize(10))
  limits = original_table.info.rate_limiter_info
  # Same limits as before except for a lowered minimum size to sample.
  relaxed_limiter = rate_limiters.RateLimiter(
      samples_per_insert=limits.samples_per_insert,
      min_size_to_sample=1,
      min_diff=limits.min_diff,
      max_diff=limits.max_diff)
  replaced_table = original_table.replace(rate_limiter=relaxed_limiter)
  self.assertEqual(replaced_table.name, original_table.name)
  self.assertEqual(
      replaced_table.info.rate_limiter_info.min_size_to_sample, 1)
def test_table_info(self, sampler_fn, remover_fn, rate_limiter_fn):
  """A freshly built table reports its configuration and empty state."""
  sampler, remover, rate_limiter = sampler_fn(), remover_fn(), rate_limiter_fn()
  info = server.Table(
      name='table',
      sampler=sampler,
      remover=remover,
      max_size=100,
      rate_limiter=rate_limiter).info

  # (attribute name, expected value), checked in order and stopping at the
  # first mismatch.
  expected_fields = (
      ('name', 'table'),
      ('max_size', 100),
      ('current_size', 0),
      ('num_episodes', 0),
      ('num_deleted_episodes', 0),
  )
  for field_name, expected_value in expected_fields:
    self.assertEqual(expected_value, getattr(info, field_name))
  self.assertIsNone(info.signature)
  self._check_selector_proto(sampler, info.sampler_options)
  self._check_selector_proto(remover, info.remover_options)
def test_can_sample(self):
  """Checks `can_sample`/`can_insert` as items arrive via an in-process client.

  The table uses a `MinSize(2)` rate limiter, so sampling is only allowed once
  two items are present. Before checking `can_sample(2)`, the test polls
  `table.info.current_size` for up to ~1s. NOTE(review): this assumes client
  inserts may not be visible in the table the moment `insert` returns. Without
  the wait, the assertion could race the second insert.
  """
  import time  # Local import: keeps this block self-contained.

  table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=100,
      max_times_sampled=1,
      rate_limiter=rate_limiters.MinSize(2))
  my_server = server.Server(tables=[table], port=None)
  my_client = my_server.in_process_client()
  try:
    self.assertFalse(table.can_sample(1))
    self.assertTrue(table.can_insert(1))
    my_client.insert(1, {TABLE_NAME: 1.0})
    self.assertFalse(table.can_sample(1))
    my_client.insert(1, {TABLE_NAME: 1.0})

    # Wait (at most 100 * 10ms) for both inserted items to be reflected.
    for _ in range(100):
      if table.info.current_size == 2:
        break
      time.sleep(0.01)
    self.assertEqual(table.info.current_size, 2)

    self.assertTrue(table.can_sample(2))
    # TODO(b/153258711): This should return False since max_times_sampled=1.
    self.assertTrue(table.can_sample(3))
  finally:
    # Always release the client and stop the server, even if an assertion
    # above failed, so later tests do not inherit a running server.
    del my_client
    my_server.stop()
def setUpClass(cls):
  """Starts a shared server with a prioritized table and two queue tables."""
  super().setUpClass()
  prioritized_table = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=1000,
      rate_limiter=rate_limiters.MinSize(3),
      signature=tf.TensorSpec(dtype=tf.int64, shape=[]),
  )
  nested_queue = server.Table.queue(
      name=NESTED_SIGNATURE_TABLE_NAME,
      max_size=10,
      signature=QUEUE_SIGNATURE,
  )
  plain_queue = server.Table.queue(SIMPLE_QUEUE_NAME, 10)

  cls.tables = [prioritized_table, nested_queue, plain_queue]
  cls.server = server.Server(tables=cls.tables)
  cls.client = cls.server.localhost_client()
def test_can_sample(self):
  """Checks `can_sample`/`can_insert` as items arrive via a localhost client."""
  tbl = server.Table(
      name=TABLE_NAME,
      sampler=item_selectors.Prioritized(1),
      remover=item_selectors.Fifo(),
      max_size=100,
      max_times_sampled=1,
      rate_limiter=rate_limiters.MinSize(2))
  srv = server.Server(tables=[tbl])
  cli = srv.localhost_client()

  self.assertFalse(tbl.can_sample(1))
  self.assertTrue(tbl.can_insert(1))
  cli.insert(1, {TABLE_NAME: 1.0})
  self.assertFalse(tbl.can_sample(1))
  cli.insert(1, {TABLE_NAME: 1.0})

  # Poll (sleeping 10ms between checks, at most 100 sleeps) until the table
  # reports both inserted items.
  attempts = 0
  while tbl.info.current_size != 2 and attempts < 100:
    time.sleep(0.01)
    attempts += 1
  self.assertEqual(tbl.info.current_size, 2)

  self.assertTrue(tbl.can_sample(2))
  # TODO(b/153258711): This should return False since max_times_sampled=1.
  self.assertTrue(tbl.can_sample(3))
  del cli
  srv.stop()