Esempio n. 1
0
 def test_validates_single_number_error_buffer(self, samples_per_insert,
                                               error_buffer, want):
   """Checks whether building a `SampleToInsertRatio` raises as expected.

   The limiter is always built with the middle positional argument fixed at 10.

   Args:
     samples_per_insert: First positional argument for the limiter.
     error_buffer: Third positional argument for the limiter.
     want: Exception type the constructor is expected to raise, or a falsy
       value when construction is expected to succeed.
   """

   def build_limiter():
     return rate_limiters.SampleToInsertRatio(samples_per_insert, 10,
                                              error_buffer)

   if not want:
     # Success case: construction must simply not raise.
     build_limiter()
     return

   with self.assertRaises(want):
     build_limiter()
Esempio n. 2
0
 def test_validates_explicit_range_error_buffer(self, min_size_to_sample,
                                                samples_per_insert,
                                                error_buffer, want):
   """Checks whether building a `SampleToInsertRatio` raises as expected.

   Args:
     min_size_to_sample: Second positional argument for the limiter.
     samples_per_insert: First positional argument for the limiter.
     error_buffer: Third positional argument for the limiter.
     want: Exception type the constructor is expected to raise, or a falsy
       value when construction is expected to succeed.
   """

   def build_limiter():
     return rate_limiters.SampleToInsertRatio(samples_per_insert,
                                              min_size_to_sample,
                                              error_buffer)

   if not want:
     # Success case: construction must simply not raise.
     build_limiter()
     return

   with self.assertRaises(want):
     build_limiter()
Esempio n. 3
0
 def make_replay_tables(
     self,
     environment_spec: specs.EnvironmentSpec,
     policy: actor_core_lib.FeedForwardPolicy,
 ) -> List[reverb.Table]:
     """Builds the replay table that data is inserted into.

     The table samples uniformly and evicts in FIFO order. Its rate limiter is
     a `SampleToInsertRatio` when `samples_per_insert` is set (truthy) in the
     config, and a plain `MinSize` limiter otherwise.

     Args:
       environment_spec: Spec used to derive the table signature.
       policy: Unused; deleted immediately.

     Returns:
       A one-element list holding the configured `reverb.Table`.
     """
     del policy
     config = self._config

     # Pick the rate limiter.
     if not config.samples_per_insert:
         limiter = rate_limiters.MinSize(config.min_replay_size)
     else:
         # The error buffer scales with both the minimum replay size and the
         # tolerated deviation from the target samples-per-insert ratio.
         tolerance = (config.samples_per_insert_tolerance_rate *
                      config.samples_per_insert)
         limiter = rate_limiters.SampleToInsertRatio(
             min_size_to_sample=config.min_replay_size,
             samples_per_insert=config.samples_per_insert,
             error_buffer=config.min_replay_size * tolerance)

     replay_table = reverb.Table(
         name=config.replay_table_name,
         sampler=reverb.selectors.Uniform(),
         remover=reverb.selectors.Fifo(),
         max_size=config.max_replay_size,
         rate_limiter=limiter,
         signature=adders_reverb.NStepTransitionAdder.signature(
             environment_spec))
     return [replay_table]
Esempio n. 4
0
 def make_replay_tables(
     self,
     environment_spec: specs.EnvironmentSpec,
     policy: dqn_actor.EpsilonPolicy,
 ) -> List[reverb.Table]:
     """Builds the prioritized replay table for the algorithm.

     The table samples with `Prioritized(priority_exponent)`, evicts in FIFO
     order, and is rate-limited by a `SampleToInsertRatio` built from the
     config.

     Args:
       environment_spec: Spec used to derive the table signature.
       policy: Unused; deleted immediately.

     Returns:
       A one-element list holding the configured `reverb.Table`.
     """
     del policy
     config = self._config

     # The error buffer scales with both the minimum replay size and the
     # tolerated deviation from the target samples-per-insert ratio.
     tolerance = (config.samples_per_insert_tolerance_rate *
                  config.samples_per_insert)
     limiter = rate_limiters.SampleToInsertRatio(
         min_size_to_sample=config.min_replay_size,
         samples_per_insert=config.samples_per_insert,
         error_buffer=config.min_replay_size * tolerance)

     replay_table = reverb.Table(
         name=config.replay_table_name,
         sampler=reverb.selectors.Prioritized(config.priority_exponent),
         remover=reverb.selectors.Fifo(),
         max_size=config.max_replay_size,
         rate_limiter=limiter,
         signature=adders_reverb.NStepTransitionAdder.signature(
             environment_spec))
     return [replay_table]
Esempio n. 5
0
 def make_replay_tables(
         self,
         environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
     """Builds the uniform-sampling replay table described by the config.

     The table evicts in FIFO order and is rate-limited by a
     `SampleToInsertRatio` built from the config.

     Args:
       environment_spec: Spec used to derive the table signature.

     Returns:
       A one-element list holding the configured `reverb.Table`.
     """
     config = self._config

     # The error buffer scales with both the minimum replay size and the
     # tolerated deviation from the target samples-per-insert ratio.
     tolerance = (config.samples_per_insert_tolerance_rate *
                  config.samples_per_insert)
     limiter = rate_limiters.SampleToInsertRatio(
         min_size_to_sample=config.min_replay_size,
         samples_per_insert=config.samples_per_insert,
         error_buffer=config.min_replay_size * tolerance)

     replay_table = reverb.Table(
         name=config.replay_table_name,
         sampler=reverb.selectors.Uniform(),
         remover=reverb.selectors.Fifo(),
         max_size=config.max_replay_size,
         rate_limiter=limiter,
         signature=adders_reverb.NStepTransitionAdder.signature(
             environment_spec))
     return [replay_table]
Esempio n. 6
0
 def make_replay_tables(
     self,
     environment_spec,
 ):
     """Builds the replay table that data is inserted into.

     The configured minimum and maximum replay sizes are floor-divided by
     `max_episode_steps` before use; the local names suggest this converts
     them into trajectory counts (confirm against the config definition).
     The table samples uniformly, evicts in FIFO order, and uses the
     `EpisodeAdder` signature.

     Args:
       environment_spec: Spec used to derive the table signature.

     Returns:
       A one-element list holding the configured `reverb.Table`.
     """
     config = self._config
     tolerance = (config.samples_per_insert_tolerance_rate *
                  config.samples_per_insert)

     # NOTE(review): floor division yields 0 when a replay size is smaller
     # than max_episode_steps -- confirm the config never allows that.
     episode_len = config.max_episode_steps
     min_replay_traj = config.min_replay_size // episode_len
     max_replay_traj = config.max_replay_size // episode_len

     limiter = rate_limiters.SampleToInsertRatio(
         min_size_to_sample=min_replay_traj,
         samples_per_insert=config.samples_per_insert,
         error_buffer=min_replay_traj * tolerance)

     replay_table = reverb.Table(
         name=config.replay_table_name,
         sampler=reverb.selectors.Uniform(),
         remover=reverb.selectors.Fifo(),
         max_size=max_replay_traj,
         rate_limiter=limiter,
         signature=adders_reverb.EpisodeAdder.signature(
             environment_spec, {}))
     return [replay_table]
Esempio n. 7
0
def _make_rate_limiter_from_rate_limiter_info(
        info) -> rate_limiters.RateLimiter:
    """Builds a `SampleToInsertRatio` limiter from a rate-limiter info object.

    Args:
      info: Object exposing `samples_per_insert`, `min_size_to_sample`,
        `min_diff` and `max_diff` attributes.

    Returns:
      A `SampleToInsertRatio` whose error buffer is the explicit
      `(min_diff, max_diff)` pair taken from `info`.
    """
    limiter_kwargs = {
        'samples_per_insert': info.samples_per_insert,
        'min_size_to_sample': info.min_size_to_sample,
        'error_buffer': (info.min_diff, info.max_diff),
    }
    return rate_limiters.SampleToInsertRatio(**limiter_kwargs)
Esempio n. 8
0
class TableTest(parameterized.TestCase):
    """Tests for `server.Table` info reporting and `Table.replace`."""

    def _check_selector_proto(self, expected_selector, proto_msg):
        """Asserts that `proto_msg` has the field matching the selector type.

        Args:
          expected_selector: Selector instance the table was built with.
          proto_msg: Selector options message to inspect via `HasField`.

        Raises:
          ValueError: If `expected_selector` is of none of the known types.
        """
        # Checked in order; the first matching type decides the field.
        field_by_selector_type = (
            (item_selectors.Uniform, 'uniform'),
            (item_selectors.Prioritized, 'prioritized'),
            (pybind.HeapSelector, 'heap'),
            (item_selectors.Fifo, 'fifo'),
            (item_selectors.Lifo, 'lifo'),
        )
        for selector_type, field_name in field_by_selector_type:
            if isinstance(expected_selector, selector_type):
                self.assertTrue(proto_msg.HasField(field_name))
                return
        raise ValueError(f'Unknown selector: {expected_selector}')

    @parameterized.product(
        sampler_fn=[
            item_selectors.Uniform,
            lambda: item_selectors.Prioritized(1.),
            item_selectors.MinHeap,
            item_selectors.MaxHeap,
            item_selectors.Fifo,
            item_selectors.Lifo,
        ],
        remover_fn=[
            item_selectors.Uniform,
            lambda: item_selectors.Prioritized(1.),
            item_selectors.MinHeap,
            item_selectors.MaxHeap,
            item_selectors.Fifo,
            item_selectors.Lifo,
        ],
        rate_limiter_fn=[
            lambda: rate_limiters.MinSize(10),
            lambda: rate_limiters.Queue(10),
            lambda: rate_limiters.SampleToInsertRatio(1.0, 10, 1.),
            lambda: rate_limiters.Stack(10),
        ],
    )
    def test_table_info(self, sampler_fn, remover_fn, rate_limiter_fn):
        """A fresh table reports its configuration and empty counters."""
        sampler = sampler_fn()
        remover = remover_fn()
        table = server.Table(name='table',
                             sampler=sampler,
                             remover=remover,
                             max_size=100,
                             rate_limiter=rate_limiter_fn())

        info = table.info
        self.assertEqual('table', info.name)
        self.assertEqual(100, info.max_size)
        # Nothing has been inserted yet, so every counter starts at zero.
        for counter in ('current_size', 'num_episodes',
                        'num_deleted_episodes'):
            self.assertEqual(0, getattr(info, counter))
        self.assertIsNone(info.signature)
        self._check_selector_proto(sampler, info.sampler_options)
        self._check_selector_proto(remover, info.remover_options)

    @parameterized.named_parameters(
        ('scalar', tf.TensorSpec([], tf.float32)),
        ('image', tf.TensorSpec([3, 64, 64], tf.uint8)),
        ('nested', (
            tf.TensorSpec([], tf.int32),
            {'a': tf.TensorSpec((1, 1), tf.float64)},
        )),
    )
    def test_table_info_signature(self, signature):
        """The signature given at construction is returned by `table.info`."""
        fifo_table = server.Table(name='table',
                                  sampler=item_selectors.Fifo(),
                                  remover=item_selectors.Fifo(),
                                  max_size=100,
                                  rate_limiter=rate_limiters.MinSize(10),
                                  signature=signature)
        self.assertEqual(signature, fifo_table.info.signature)

    def test_replace(self):
        """`replace` keeps the table name and applies the new rate limiter."""
        original = server.Table(name='table',
                                sampler=item_selectors.Fifo(),
                                remover=item_selectors.Fifo(),
                                max_size=100,
                                rate_limiter=rate_limiters.MinSize(10))

        # Copy the existing limiter settings, changing only the minimum size.
        old_info = original.info.rate_limiter_info
        replacement_limiter = rate_limiters.RateLimiter(
            samples_per_insert=old_info.samples_per_insert,
            min_size_to_sample=1,
            min_diff=old_info.min_diff,
            max_diff=old_info.max_diff)

        replaced = original.replace(rate_limiter=replacement_limiter)
        self.assertEqual(replaced.name, original.name)
        self.assertEqual(replaced.info.rate_limiter_info.min_size_to_sample,
                         1)