示例#1
0
  def stack(cls,
            name: str,
            max_size: int,
            extensions: Sequence[TableExtensionBase] = (),
            signature: Optional[reverb_types.SpecNest] = None):
    """Builds a Table configured to behave like a bounded stack.

    The sampler and the remover are both `item_selectors.Lifo`, the table is
    created with `max_times_sampled=1`, and its rate limiter is
    `rate_limiters.Stack(max_size)`.

    Args:
      name: Name of the priority table (aka stack).
      max_size: Maximum number of items in the priority table (aka stack);
        the same value sizes the stack rate limiter.
      extensions: Forwarded unchanged to the constructor; see its docs.
      signature: Forwarded unchanged to the constructor; see its docs.

    Returns:
      A `cls` instance which behaves like a stack of size `max_size`.
    """
    # NOTE(review): takes `cls`, so presumably decorated with @classmethod on a
    # line outside this snippet -- confirm in the real file.
    # Two separate Lifo instances: one for `sampler`, one for `remover`.
    lifo_sampler = item_selectors.Lifo()
    lifo_remover = item_selectors.Lifo()
    stack_limiter = rate_limiters.Stack(max_size)
    return cls(
        name=name,
        max_size=max_size,
        sampler=lifo_sampler,
        remover=lifo_remover,
        rate_limiter=stack_limiter,
        max_times_sampled=1,
        signature=signature,
        extensions=extensions)
示例#2
0
文件: server.py 项目: tomzhang/reverb
    def stack(cls, name: str, max_size: int):
        """Builds a Table configured to behave like a bounded stack.

        The sampler and the remover are both `item_selectors.Lifo`, the table
        is created with `max_times_sampled=1`, and its rate limiter is
        `rate_limiters.Stack(max_size)`.

        Args:
          name: Name of the priority table (aka stack).
          max_size: Maximum number of items in the priority table (aka stack);
            the same value sizes the stack rate limiter.

        Returns:
          A `cls` instance which behaves like a stack of size `max_size`.
        """
        # Collect the constructor settings first so the stack configuration
        # reads as one table of options; evaluation order matches a direct call.
        stack_config = dict(
            name=name,
            sampler=item_selectors.Lifo(),
            remover=item_selectors.Lifo(),
            max_size=max_size,
            max_times_sampled=1,
            rate_limiter=rate_limiters.Stack(max_size),
        )
        return cls(**stack_config)
示例#3
0
class TableTest(parameterized.TestCase):
    """Tests for `server.Table` construction, `info` and `replace`."""

    def _check_selector_proto(self, expected_selector, proto_msg):
        """Asserts that `proto_msg` has the field matching `expected_selector`.

        Args:
          expected_selector: Selector instance the table was built with.
          proto_msg: Options message reported by `Table.info` for it.

        Raises:
          ValueError: If `expected_selector` is of none of the known types.
        """
        # Checked in order and the first match wins, like an if/elif chain.
        type_to_field = (
            (item_selectors.Uniform, 'uniform'),
            (item_selectors.Prioritized, 'prioritized'),
            (pybind.HeapSelector, 'heap'),
            (item_selectors.Fifo, 'fifo'),
            (item_selectors.Lifo, 'lifo'),
        )
        for selector_type, field_name in type_to_field:
            if isinstance(expected_selector, selector_type):
                self.assertTrue(proto_msg.HasField(field_name))
                return
        raise ValueError(f'Unknown selector: {expected_selector}')

    # Every sampler x remover x rate limiter combination. The list order is
    # kept stable because parameterized test names depend on it.
    @parameterized.product(
        sampler_fn=[
            item_selectors.Uniform,
            lambda: item_selectors.Prioritized(1.),
            item_selectors.MinHeap,
            item_selectors.MaxHeap,
            item_selectors.Fifo,
            item_selectors.Lifo,
        ],
        remover_fn=[
            item_selectors.Uniform,
            lambda: item_selectors.Prioritized(1.),
            item_selectors.MinHeap,
            item_selectors.MaxHeap,
            item_selectors.Fifo,
            item_selectors.Lifo,
        ],
        rate_limiter_fn=[
            lambda: rate_limiters.MinSize(10),
            lambda: rate_limiters.Queue(10),
            lambda: rate_limiters.SampleToInsertRatio(1.0, 10, 1.),
            lambda: rate_limiters.Stack(10),
        ],
    )
    def test_table_info(self, sampler_fn, remover_fn, rate_limiter_fn):
        """A newly built table reports its settings and zeroed counters."""
        chosen_sampler = sampler_fn()
        chosen_remover = remover_fn()
        chosen_limiter = rate_limiter_fn()
        built = server.Table(
            name='table',
            sampler=chosen_sampler,
            remover=chosen_remover,
            max_size=100,
            rate_limiter=chosen_limiter)

        info = built.info
        self.assertEqual('table', info.name)
        self.assertEqual(100, info.max_size)
        self.assertEqual(0, info.current_size)
        self.assertEqual(0, info.num_episodes)
        self.assertEqual(0, info.num_deleted_episodes)
        self.assertIsNone(info.signature)
        self._check_selector_proto(chosen_sampler, info.sampler_options)
        self._check_selector_proto(chosen_remover, info.remover_options)

    @parameterized.named_parameters(
        ('scalar', tf.TensorSpec([], tf.float32)),
        ('image', tf.TensorSpec([3, 64, 64], tf.uint8)),
        ('nested', (tf.TensorSpec([], tf.int32),
                    {'a': tf.TensorSpec((1, 1), tf.float64)})),
    )
    def test_table_info_signature(self, signature):
        """`Table.info.signature` equals the signature given at construction."""
        built = server.Table(
            name='table',
            sampler=item_selectors.Fifo(),
            remover=item_selectors.Fifo(),
            max_size=100,
            rate_limiter=rate_limiters.MinSize(10),
            signature=signature)
        self.assertEqual(signature, built.info.signature)

    def test_replace(self):
        """`replace(rate_limiter=...)` keeps the name and uses the new limiter."""
        original = server.Table(
            name='table',
            sampler=item_selectors.Fifo(),
            remover=item_selectors.Fifo(),
            max_size=100,
            rate_limiter=rate_limiters.MinSize(10))
        old_limits = original.info.rate_limiter_info
        # Copy the existing limits but change min_size_to_sample, so the swap
        # is observable through `info`.
        limiter = rate_limiters.RateLimiter(
            samples_per_insert=old_limits.samples_per_insert,
            min_size_to_sample=1,
            min_diff=old_limits.min_diff,
            max_diff=old_limits.max_diff)
        replaced = original.replace(rate_limiter=limiter)
        self.assertEqual(replaced.name, original.name)
        self.assertEqual(replaced.info.rate_limiter_info.min_size_to_sample,
                         1)