def queue(cls,
          name: str,
          max_size: int,
          extensions: Sequence[TableExtensionBase] = (),
          signature: Optional[reverb_types.SpecNest] = None):
  """Creates a table of class `cls` configured to behave like a queue.

  Both the sampler and the remover are `item_selectors.Fifo()`. The table is
  built with `max_times_sampled=1` and a `rate_limiters.Queue(max_size)` rate
  limiter.

  Args:
    name: Name of the priority table (aka queue).
    max_size: Maximum number of items in the priority table (aka queue).
    extensions: Passed through unchanged to the constructor; see its docs.
    signature: Passed through unchanged to the constructor; see its docs.

  Returns:
    Table which behaves like a queue of size `max_size`.
  """
  fifo_sampler = item_selectors.Fifo()
  fifo_remover = item_selectors.Fifo()
  queue_limiter = rate_limiters.Queue(max_size)
  return cls(
      name=name,
      sampler=fifo_sampler,
      remover=fifo_remover,
      max_size=max_size,
      max_times_sampled=1,
      rate_limiter=queue_limiter,
      extensions=extensions,
      signature=signature)
def queue(cls, name: str, max_size: int):
  """Creates a table of class `cls` configured to behave like a queue.

  Items are selected for sampling and for removal by `item_selectors.Fifo()`.
  The table is built with `max_times_sampled=1` and a
  `rate_limiters.Queue(max_size)` rate limiter.

  Args:
    name: Name of the priority table (aka queue).
    max_size: Maximum number of items in the priority table (aka queue).

  Returns:
    Table which behaves like a queue of size `max_size`.
  """
  fifo_sampler = item_selectors.Fifo()
  fifo_remover = item_selectors.Fifo()
  queue_limiter = rate_limiters.Queue(max_size)
  return cls(
      name=name,
      sampler=fifo_sampler,
      remover=fifo_remover,
      max_size=max_size,
      max_times_sampled=1,
      rate_limiter=queue_limiter)
class TableTest(parameterized.TestCase):
  """Tests for `server.Table` construction, `info` and `replace`."""

  # Every item-selector factory exercised by `test_table_info`. The same list
  # parameterizes both the sampler axis and the remover axis of the product.
  _SELECTOR_FACTORIES = [
      item_selectors.Uniform,
      lambda: item_selectors.Prioritized(1.0),
      item_selectors.MinHeap,
      item_selectors.MaxHeap,
      item_selectors.Fifo,
      item_selectors.Lifo,
  ]

  def _check_selector_proto(self, expected_selector, proto_msg):
    """Asserts `proto_msg` has the oneof field matching `expected_selector`.

    Raises:
      ValueError: If `expected_selector` is none of the known selector types.
    """
    # Checked in order; the first matching type decides which field to expect.
    type_to_field = (
        (item_selectors.Uniform, 'uniform'),
        (item_selectors.Prioritized, 'prioritized'),
        (pybind.HeapSelector, 'heap'),
        (item_selectors.Fifo, 'fifo'),
        (item_selectors.Lifo, 'lifo'),
    )
    for selector_type, field_name in type_to_field:
      if isinstance(expected_selector, selector_type):
        self.assertTrue(proto_msg.HasField(field_name))
        return
    raise ValueError(f'Unknown selector: {expected_selector}')

  @parameterized.product(
      sampler_fn=_SELECTOR_FACTORIES,
      remover_fn=_SELECTOR_FACTORIES,
      rate_limiter_fn=[
          lambda: rate_limiters.MinSize(10),
          lambda: rate_limiters.Queue(10),
          lambda: rate_limiters.SampleToInsertRatio(1.0, 10, 1.0),
          lambda: rate_limiters.Stack(10),
      ],
  )
  def test_table_info(self, sampler_fn, remover_fn, rate_limiter_fn):
    """A freshly built table reports its config and an empty state."""
    sampler = sampler_fn()
    remover = remover_fn()
    limiter = rate_limiter_fn()
    table = server.Table(
        name='table',
        sampler=sampler,
        remover=remover,
        max_size=100,
        rate_limiter=limiter)

    info = table.info
    self.assertEqual('table', info.name)
    self.assertEqual(100, info.max_size)
    self.assertEqual(0, info.current_size)
    self.assertEqual(0, info.num_episodes)
    self.assertEqual(0, info.num_deleted_episodes)
    self.assertIsNone(info.signature)
    self._check_selector_proto(sampler, info.sampler_options)
    self._check_selector_proto(remover, info.remover_options)

  @parameterized.named_parameters(
      ('scalar', tf.TensorSpec([], tf.float32)),
      ('image', tf.TensorSpec([3, 64, 64], tf.uint8)),
      ('nested', (tf.TensorSpec([], tf.int32),
                  {'a': tf.TensorSpec((1, 1), tf.float64)})),
  )
  def test_table_info_signature(self, signature):
    """The signature passed at construction is reported back by `info`."""
    table = server.Table(
        name='table',
        sampler=item_selectors.Fifo(),
        remover=item_selectors.Fifo(),
        max_size=100,
        rate_limiter=rate_limiters.MinSize(10),
        signature=signature)
    self.assertEqual(signature, table.info.signature)

  def test_replace(self):
    """`replace` swaps the rate limiter while keeping the table name."""
    original = server.Table(
        name='table',
        sampler=item_selectors.Fifo(),
        remover=item_selectors.Fifo(),
        max_size=100,
        rate_limiter=rate_limiters.MinSize(10))

    # Copy the current limiter settings, changing only min_size_to_sample.
    limiter_info = original.info.rate_limiter_info
    relaxed_limiter = rate_limiters.RateLimiter(
        samples_per_insert=limiter_info.samples_per_insert,
        min_size_to_sample=1,
        min_diff=limiter_info.min_diff,
        max_diff=limiter_info.max_diff)

    replaced = original.replace(rate_limiter=relaxed_limiter)
    self.assertEqual(replaced.name, original.name)
    self.assertEqual(replaced.info.rate_limiter_info.min_size_to_sample, 1)