Exemple #1
0
def make_server():
    """Create a server hosting the `dist`, `dist2` and `signatured` tables.

    Every table uses prioritized sampling (exponent 1), FIFO removal, a
    capacity of 1000000 items and a `MinSize(1)` rate limiter. Only
    `signatured` additionally declares a float32 signature of shape
    `(None, None)`.

    Returns:
        The `server.Server` instance, constructed with `port=None`.
    """

    def _prioritized_table(name, **extra_kwargs):
        # Fresh selector and limiter objects are created for each table.
        return server.Table(
            name,
            sampler=item_selectors.Prioritized(priority_exponent=1),
            remover=item_selectors.Fifo(),
            max_size=1000000,
            rate_limiter=rate_limiters.MinSize(1),
            **extra_kwargs)

    tables = [
        _prioritized_table('dist'),
        _prioritized_table('dist2'),
        _prioritized_table(
            'signatured',
            signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None))),
    ]
    return server.Server(tables=tables, port=None)
Exemple #2
0
def make_server():
    """Create a server with one plain table and two signature-checked tables.

    The tables are `dist` (no signature), `signatured` (float32
    `tf.TensorSpec` of shape `(None, None)`) and `bounded_spec_signatured`
    (float32 `BoundedTensorSpec` of shape `(None, None)` with bounds
    `(0.0, 0.0)` .. `(10.0, 10.0)`). All of them use prioritized sampling,
    FIFO removal, a capacity of 1000000 and a `MinSize(1)` rate limiter.

    Returns:
        The `reverb_server.Server` instance, constructed with `port=None`.
    """

    def _shared_options():
        # Called once per table so no selector/limiter object is shared.
        return dict(
            sampler=item_selectors.Prioritized(priority_exponent=1),
            remover=item_selectors.Fifo(),
            max_size=1000000,
            rate_limiter=rate_limiters.MinSize(1),
        )

    plain_table = reverb_server.Table('dist', **_shared_options())

    signatured_table = reverb_server.Table(
        'signatured',
        signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None)),
        **_shared_options())

    # Currently only the `shape` and `dtype` of the bounded spec
    # is considered during signature check.
    # TODO(b/158033101): Check the boundaries as well.
    bounded_signature = tensor_spec.BoundedTensorSpec(
        dtype=tf.float32,
        shape=(None, None),
        minimum=(0.0, 0.0),
        maximum=(10.0, 10.0))
    bounded_table = reverb_server.Table(
        'bounded_spec_signatured',
        signature=bounded_signature,
        **_shared_options())

    return reverb_server.Server(
        tables=[plain_table, signatured_table, bounded_table],
        port=None,
    )
Exemple #3
0
def make_tables_and_server():
    """Create the `dist` and `dist2` tables and a server that hosts them.

    Both tables use prioritized sampling (exponent 1), FIFO removal, a
    capacity of 1000000 and a `MinSize(1)` rate limiter.

    Returns:
        A `(tables, server)` tuple where `tables` is the list handed to the
        `server.Server` constructor.
    """
    table_list = [
        server.Table(
            table_name,
            sampler=item_selectors.Prioritized(priority_exponent=1),
            remover=item_selectors.Fifo(),
            max_size=1000000,
            rate_limiter=rate_limiters.MinSize(1),
        )
        for table_name in ('dist', 'dist2')
    ]
    hosting_server = server.Server(tables=table_list)
    return table_list, hosting_server
Exemple #4
0
 def test_duplicate_priority_table_name(self):
     """Expects `ValueError` when two tables named 'test' are given to a server."""
     # Everything stays inside the assertRaises scope so that an error from
     # any of the constructors is accepted, not only one from the server.
     with self.assertRaises(ValueError):
         first_table = server.Table(
             name='test',
             sampler=item_selectors.Prioritized(1),
             remover=item_selectors.Fifo(),
             max_size=100,
             rate_limiter=rate_limiters.MinSize(2))
         clashing_table = server.Table(
             name='test',
             sampler=item_selectors.Prioritized(2),
             remover=item_selectors.Fifo(),
             max_size=200,
             rate_limiter=rate_limiters.MinSize(1))
         server.Server(tables=[first_table, clashing_table])
Exemple #5
0
 def make_replay_tables(
     self,
     environment_spec: specs.EnvironmentSpec,
     policy: actor_core_lib.FeedForwardPolicy,
 ) -> List[reverb.Table]:
     """Build the single replay table described by `self._config`.

     The rate limiter depends on `samples_per_insert`: when it is falsy a
     plain `MinSize` limiter is used, otherwise a `SampleToInsertRatio`
     limiter whose error buffer is
     `min_replay_size * samples_per_insert_tolerance_rate * samples_per_insert`.

     Args:
       environment_spec: Spec used to derive the table signature.
       policy: Unused.

     Returns:
       A one-element list holding a uniform-sampling, FIFO-removal table.
     """
     del policy
     cfg = self._config

     # Select the rate limiter.
     if not cfg.samples_per_insert:
         rate_limiter = rate_limiters.MinSize(cfg.min_replay_size)
     else:
         tolerance = (cfg.samples_per_insert_tolerance_rate *
                      cfg.samples_per_insert)
         rate_limiter = rate_limiters.SampleToInsertRatio(
             min_size_to_sample=cfg.min_replay_size,
             samples_per_insert=cfg.samples_per_insert,
             error_buffer=cfg.min_replay_size * tolerance)

     replay_table = reverb.Table(
         name=cfg.replay_table_name,
         sampler=reverb.selectors.Uniform(),
         remover=reverb.selectors.Fifo(),
         max_size=cfg.max_replay_size,
         rate_limiter=rate_limiter,
         signature=adders_reverb.NStepTransitionAdder.signature(
             environment_spec))
     return [replay_table]
Exemple #6
0
 def test_table_info_signature(self, signature):
     """Checks that `table.info.signature` equals the signature passed in."""
     fifo_table = server.Table(
         name='table',
         sampler=item_selectors.Fifo(),
         remover=item_selectors.Fifo(),
         max_size=100,
         rate_limiter=rate_limiters.MinSize(10),
         signature=signature,
     )
     reported_signature = fifo_table.info.signature
     self.assertEqual(signature, reported_signature)
Exemple #7
0
def priority_tables_fn():
    """Return a one-element list with the `_TABLE_NAME` replay table.

    The table samples uniformly, removes in FIFO order, holds at most 100
    items and uses a `MinSize(100)` rate limiter.
    """
    uniform_table = reverb.Table(
        name=_TABLE_NAME,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=100,
        rate_limiter=rate_limiters.MinSize(100),
    )
    return [uniform_table]
def make_server():
    """Create a server hosting a single prioritized table named `TABLE`.

    The table uses priority exponent 1, FIFO removal, a capacity of 1000
    and a `MinSize(1)` rate limiter.
    """
    prioritized_table = reverb_server.Table(
        name=TABLE,
        sampler=item_selectors.Prioritized(priority_exponent=1),
        remover=item_selectors.Fifo(),
        max_size=1000,
        rate_limiter=rate_limiters.MinSize(1),
    )
    return reverb_server.Server(tables=[prioritized_table])
Exemple #9
0
 def test_in_process_client(self):
     """Resets `TABLE_NAME` through the client returned by `localhost_client()`."""
     table = server.Table(
         name=TABLE_NAME,
         sampler=item_selectors.Prioritized(1),
         remover=item_selectors.Fifo(),
         max_size=100,
         rate_limiter=rate_limiters.MinSize(2),
     )
     test_server = server.Server(tables=[table])
     local_client = test_server.localhost_client()
     local_client.reset(TABLE_NAME)
     # Drop the client reference before stopping the server.
     del local_client
     test_server.stop()
def make_replay_tables(environment_spec: specs.EnvironmentSpec
                      ) -> Sequence[reverb.Table]:
  """Build the `default` replay table for the given environment spec.

  Args:
    environment_spec: Spec used to derive the table signature via
      `NStepTransitionAdder.signature`.

  Returns:
    A one-element list holding a uniform-sampling, FIFO-removal table with
    capacity 1000000 and a `MinSize(1)` rate limiter.
  """
  default_table = reverb.Table(
      name='default',
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=1000000,
      rate_limiter=rate_limiters.MinSize(1),
      signature=adders_reverb.NStepTransitionAdder.signature(
          environment_spec),
  )
  return [default_table]
Exemple #11
0
 @classmethod
 def setUpClass(cls):
     """Start one shared server and client for every test in the class.

     The server hosts a single prioritized table named `TABLE_NAME` with a
     scalar int64 signature. The client connects to
     `localhost:<cls.server.port>`.

     unittest invokes this hook on the class itself, so it must be a
     classmethod; without the decorator the call fails because `cls` is
     never bound.
     """
     super().setUpClass()
     int64_table = server.Table(
         name=TABLE_NAME,
         sampler=item_selectors.Prioritized(1),
         remover=item_selectors.Fifo(),
         max_size=1000,
         rate_limiter=rate_limiters.MinSize(3),
         signature=tf.TensorSpec(dtype=tf.int64, shape=()),
     )
     cls.server = server.Server(tables=[int64_table], port=None)
     cls.client = client.Client(f'localhost:{cls.server.port}')
Exemple #12
0
 def test_replace(self):
     """Checks that `Table.replace` swaps the rate limiter and keeps the name."""
     original = server.Table(
         name='table',
         sampler=item_selectors.Fifo(),
         remover=item_selectors.Fifo(),
         max_size=100,
         rate_limiter=rate_limiters.MinSize(10),
     )
     # Copy every limiter setting except `min_size_to_sample`, which is
     # lowered to 1 so the replacement is observable.
     limiter_info = original.info.rate_limiter_info
     replacement_limiter = rate_limiters.RateLimiter(
         samples_per_insert=limiter_info.samples_per_insert,
         min_size_to_sample=1,
         min_diff=limiter_info.min_diff,
         max_diff=limiter_info.max_diff,
     )
     replaced = original.replace(rate_limiter=replacement_limiter)
     self.assertEqual(replaced.name, original.name)
     replaced_limiter_info = replaced.info.rate_limiter_info
     self.assertEqual(replaced_limiter_info.min_size_to_sample, 1)
Exemple #13
0
 def make_replay_tables(
         self,
         environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
     """Return the wrapped agent's replay tables, plus an extra one if needed.

     When `self._config.share_iterator` is truthy, the tables produced by
     `self._rl_agent.make_replay_tables` are returned unchanged. Otherwise a
     uniform-sampling, FIFO-removal table named
     `self._config.replay_table_name` is added after them.

     Args:
       environment_spec: Spec forwarded to the wrapped agent and used to
         derive the extra table's signature.

     Returns:
       The replay tables to create on the server.
     """
     replay_tables = self._rl_agent.make_replay_tables(environment_spec)
     if self._config.share_iterator:
         return replay_tables

     extra_table = reverb.Table(
         name=self._config.replay_table_name,
         sampler=reverb.selectors.Uniform(),
         remover=reverb.selectors.Fifo(),
         max_size=self._config.max_replay_size,
         rate_limiter=rate_limiters.MinSize(
             self._config.min_replay_size),
         signature=adders_reverb.NStepTransitionAdder.signature(
             environment_spec))
     # Build a new list rather than appending in place: the wrapped agent's
     # return value is not ours to mutate, and it may be any sequence
     # (e.g. a tuple), which has no `append`.
     return [*replay_tables, extra_table]
Exemple #14
0
 def test_can_sample(self):
     """Checks `can_sample` / `can_insert` before and after two inserts.

     The table uses `MinSize(2)`, so the test expects sampling to be refused
     until two items have been inserted.
     """
     table = server.Table(
         name=TABLE_NAME,
         sampler=item_selectors.Prioritized(1),
         remover=item_selectors.Fifo(),
         max_size=100,
         max_times_sampled=1,
         rate_limiter=rate_limiters.MinSize(2),
     )
     test_server = server.Server(tables=[table], port=None)
     test_client = test_server.in_process_client()

     # Empty table: sampling refused, inserting allowed.
     self.assertFalse(table.can_sample(1))
     self.assertTrue(table.can_insert(1))

     # After the first insert sampling is still expected to be refused.
     test_client.insert(1, {TABLE_NAME: 1.0})
     self.assertFalse(table.can_sample(1))

     # After the second insert sampling is expected to be allowed.
     test_client.insert(1, {TABLE_NAME: 1.0})
     self.assertTrue(table.can_sample(2))
     # TODO(b/153258711): This should return False since max_times_sampled=1.
     self.assertTrue(table.can_sample(3))

     del test_client
     test_server.stop()
Exemple #15
0
 @classmethod
 def setUpClass(cls):
   """Start one shared server, its tables and a client for the whole class.

   Three tables are created: a prioritized table `TABLE_NAME` with a scalar
   int64 signature, a queue `NESTED_SIGNATURE_TABLE_NAME` using
   `QUEUE_SIGNATURE`, and a queue `SIMPLE_QUEUE_NAME` without a signature.
   The table list, the server and a `localhost_client()` are stored on the
   class.

   unittest invokes this hook on the class itself, so it must be a
   classmethod; without the decorator the call fails because `cls` is never
   bound.
   """
   super().setUpClass()
   cls.tables = [
       server.Table(
           name=TABLE_NAME,
           sampler=item_selectors.Prioritized(1),
           remover=item_selectors.Fifo(),
           max_size=1000,
           rate_limiter=rate_limiters.MinSize(3),
           signature=tf.TensorSpec(dtype=tf.int64, shape=[]),
       ),
       server.Table.queue(
           name=NESTED_SIGNATURE_TABLE_NAME,
           max_size=10,
           signature=QUEUE_SIGNATURE,
       ),
       server.Table.queue(SIMPLE_QUEUE_NAME, 10),
   ]
   cls.server = server.Server(tables=cls.tables)
   cls.client = cls.server.localhost_client()
Exemple #16
0
 def test_can_sample(self):
     """Checks `can_sample` / `can_insert` before and after two inserts.

     The table uses `MinSize(2)`. After the second insert the test polls
     `table.info.current_size` (up to 100 checks, 10 ms apart) until it
     reports 2 before asserting that sampling is allowed.
     """
     table = server.Table(
         name=TABLE_NAME,
         sampler=item_selectors.Prioritized(1),
         remover=item_selectors.Fifo(),
         max_size=100,
         max_times_sampled=1,
         rate_limiter=rate_limiters.MinSize(2),
     )
     test_server = server.Server(tables=[table])
     test_client = test_server.localhost_client()

     # Empty table: sampling refused, inserting allowed.
     self.assertFalse(table.can_sample(1))
     self.assertTrue(table.can_insert(1))

     test_client.insert(1, {TABLE_NAME: 1.0})
     self.assertFalse(table.can_sample(1))

     test_client.insert(1, {TABLE_NAME: 1.0})
     # Bounded wait for the table to report both items.
     checks_left = 100
     while checks_left and table.info.current_size != 2:
         time.sleep(0.01)
         checks_left -= 1
     self.assertEqual(table.info.current_size, 2)

     self.assertTrue(table.can_sample(2))
     # TODO(b/153258711): This should return False since max_times_sampled=1.
     self.assertTrue(table.can_sample(3))

     del test_client
     test_server.stop()
Exemple #17
0
class TableTest(parameterized.TestCase):
    """Tests for the metadata exposed through `server.Table.info`."""

    def _check_selector_proto(self, expected_selector, proto_msg):
        """Assert that `proto_msg` has the field matching the selector's type.

        Args:
            expected_selector: Selector instance the table was built with.
            proto_msg: Selector options message read from the table info.

        Raises:
            ValueError: If `expected_selector` is none of the handled types.
        """
        # Checked in order; the first matching type decides the field.
        field_for_type = (
            (item_selectors.Uniform, 'uniform'),
            (item_selectors.Prioritized, 'prioritized'),
            (pybind.HeapSelector, 'heap'),
            (item_selectors.Fifo, 'fifo'),
            (item_selectors.Lifo, 'lifo'),
        )
        for selector_type, field_name in field_for_type:
            if isinstance(expected_selector, selector_type):
                self.assertTrue(proto_msg.HasField(field_name))
                return
        raise ValueError(f'Unknown selector: {expected_selector}')

    @parameterized.product(
        sampler_fn=[
            item_selectors.Uniform,
            lambda: item_selectors.Prioritized(1.),
            item_selectors.MinHeap,
            item_selectors.MaxHeap,
            item_selectors.Fifo,
            item_selectors.Lifo,
        ],
        remover_fn=[
            item_selectors.Uniform,
            lambda: item_selectors.Prioritized(1.),
            item_selectors.MinHeap,
            item_selectors.MaxHeap,
            item_selectors.Fifo,
            item_selectors.Lifo,
        ],
        rate_limiter_fn=[
            lambda: rate_limiters.MinSize(10),
            lambda: rate_limiters.Queue(10),
            lambda: rate_limiters.SampleToInsertRatio(1.0, 10, 1.),
            lambda: rate_limiters.Stack(10),
        ],
    )
    def test_table_info(self, sampler_fn, remover_fn, rate_limiter_fn):
        """Checks the info of a fresh table for each selector/limiter combination."""
        sampler = sampler_fn()
        remover = remover_fn()
        limiter = rate_limiter_fn()
        fresh_table = server.Table(
            name='table',
            sampler=sampler,
            remover=remover,
            max_size=100,
            rate_limiter=limiter,
        )

        info = fresh_table.info
        self.assertEqual('table', info.name)
        self.assertEqual(100, info.max_size)
        # Nothing has been inserted, so every counter is expected to be zero.
        for counter_name in ('current_size', 'num_episodes',
                             'num_deleted_episodes'):
            self.assertEqual(0, getattr(info, counter_name))
        self.assertIsNone(info.signature)
        self._check_selector_proto(sampler, info.sampler_options)
        self._check_selector_proto(remover, info.remover_options)

    @parameterized.named_parameters(
        ('scalar', tf.TensorSpec([], tf.float32)),
        ('image', tf.TensorSpec([3, 64, 64], tf.uint8)),
        ('nested', (
            tf.TensorSpec([], tf.int32),
            {'a': tf.TensorSpec((1, 1), tf.float64)},
        )),
    )
    def test_table_info_signature(self, signature):
        """Checks that `table.info.signature` equals the signature passed in."""
        signed_table = server.Table(
            name='table',
            sampler=item_selectors.Fifo(),
            remover=item_selectors.Fifo(),
            max_size=100,
            rate_limiter=rate_limiters.MinSize(10),
            signature=signature,
        )
        reported_signature = signed_table.info.signature
        self.assertEqual(signature, reported_signature)

    def test_replace(self):
        """Checks that `Table.replace` swaps the rate limiter and keeps the name."""
        original = server.Table(
            name='table',
            sampler=item_selectors.Fifo(),
            remover=item_selectors.Fifo(),
            max_size=100,
            rate_limiter=rate_limiters.MinSize(10),
        )
        # Copy every limiter setting except `min_size_to_sample`, which is
        # lowered to 1 so the replacement is observable.
        limiter_info = original.info.rate_limiter_info
        replacement_limiter = rate_limiters.RateLimiter(
            samples_per_insert=limiter_info.samples_per_insert,
            min_size_to_sample=1,
            min_diff=limiter_info.min_diff,
            max_diff=limiter_info.max_diff,
        )
        replaced = original.replace(rate_limiter=replacement_limiter)
        self.assertEqual(replaced.name, original.name)
        replaced_limiter_info = replaced.info.rate_limiter_info
        self.assertEqual(replaced_limiter_info.min_size_to_sample, 1)
Exemple #18
0
 def test_raises_if_min_size_lt_1(self, min_size_to_sample, want_error):
     """Checks whether `MinSize(min_size_to_sample)` raises `ValueError`.

     Args:
       min_size_to_sample: Value passed to `rate_limiters.MinSize`.
       want_error: Whether construction is expected to raise `ValueError`.
     """
     if not want_error:
         # Must construct without raising.
         rate_limiters.MinSize(min_size_to_sample)
         return
     with self.assertRaises(ValueError):
         rate_limiters.MinSize(min_size_to_sample)