def test_validates_single_number_error_buffer(self, samples_per_insert,
                                              error_buffer, want):
  """Checks construction of `SampleToInsertRatio` with a fixed min size of 10.

  Args:
    samples_per_insert: Forwarded as the first constructor argument.
    error_buffer: Forwarded as the third constructor argument.
    want: Exception type the constructor is expected to raise, or a falsy
      value when construction is expected to succeed.
  """

  def build_limiter():
    return rate_limiters.SampleToInsertRatio(samples_per_insert, 10,
                                             error_buffer)

  if not want:
    # Should not raise any error; an exception here fails the test.
    build_limiter()
    return

  with self.assertRaises(want):
    build_limiter()
def test_validates_explicit_range_error_buffer(self, min_size_to_sample,
                                               samples_per_insert,
                                               error_buffer, want):
  """Checks construction of `SampleToInsertRatio` with a given min size.

  Args:
    min_size_to_sample: Forwarded as the second constructor argument.
    samples_per_insert: Forwarded as the first constructor argument.
    error_buffer: Forwarded as the third constructor argument.
    want: Exception type the constructor is expected to raise, or a falsy
      value when construction is expected to succeed.
  """

  def build_limiter():
    return rate_limiters.SampleToInsertRatio(samples_per_insert,
                                             min_size_to_sample, error_buffer)

  if not want:
    # Should not raise any error; an exception here fails the test.
    build_limiter()
    return

  with self.assertRaises(want):
    build_limiter()
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy: actor_core_lib.FeedForwardPolicy,
) -> List[reverb.Table]:
  """Builds the single replay table that data is inserted into.

  When `samples_per_insert` is set (truthy) in the config, the table uses a
  `SampleToInsertRatio` rate limiter whose error buffer is
  `min_replay_size * samples_per_insert_tolerance_rate * samples_per_insert`.
  Otherwise a `MinSize` limiter based on `min_replay_size` is used.

  Args:
    environment_spec: Spec used to derive the table signature.
    policy: Unused.

  Returns:
    A one-element list holding the configured `reverb.Table`.
  """
  del policy  # Unused.
  config = self._config

  if not config.samples_per_insert:
    rate_limiter = rate_limiters.MinSize(config.min_replay_size)
  else:
    tolerance = (
        config.samples_per_insert_tolerance_rate * config.samples_per_insert)
    rate_limiter = rate_limiters.SampleToInsertRatio(
        min_size_to_sample=config.min_replay_size,
        samples_per_insert=config.samples_per_insert,
        error_buffer=config.min_replay_size * tolerance)

  replay_table = reverb.Table(
      name=config.replay_table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=config.max_replay_size,
      rate_limiter=rate_limiter,
      signature=adders_reverb.NStepTransitionAdder.signature(
          environment_spec))
  return [replay_table]
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy: dqn_actor.EpsilonPolicy,
) -> List[reverb.Table]:
  """Builds the prioritized replay table for the algorithm.

  The table is rate limited by a `SampleToInsertRatio` whose error buffer is
  `min_replay_size * samples_per_insert_tolerance_rate * samples_per_insert`,
  and it samples with `Prioritized(priority_exponent)`.

  Args:
    environment_spec: Spec used to derive the table signature.
    policy: Unused.

  Returns:
    A one-element list holding the configured `reverb.Table`.
  """
  del policy  # Unused.
  config = self._config

  tolerance = (
      config.samples_per_insert_tolerance_rate * config.samples_per_insert)
  rate_limiter = rate_limiters.SampleToInsertRatio(
      min_size_to_sample=config.min_replay_size,
      samples_per_insert=config.samples_per_insert,
      error_buffer=config.min_replay_size * tolerance)

  replay_table = reverb.Table(
      name=config.replay_table_name,
      sampler=reverb.selectors.Prioritized(config.priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=config.max_replay_size,
      rate_limiter=rate_limiter,
      signature=adders_reverb.NStepTransitionAdder.signature(
          environment_spec))
  return [replay_table]
def make_replay_tables(
    self, environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
  """Builds the uniformly sampled replay table.

  The table is rate limited by a `SampleToInsertRatio` whose error buffer is
  `min_replay_size * samples_per_insert_tolerance_rate * samples_per_insert`.

  Args:
    environment_spec: Spec used to derive the table signature.

  Returns:
    A one-element list holding the configured `reverb.Table`.
  """
  config = self._config

  tolerance = (
      config.samples_per_insert_tolerance_rate * config.samples_per_insert)
  rate_limiter = rate_limiters.SampleToInsertRatio(
      min_size_to_sample=config.min_replay_size,
      samples_per_insert=config.samples_per_insert,
      error_buffer=config.min_replay_size * tolerance)

  replay_table = reverb.Table(
      name=config.replay_table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=config.max_replay_size,
      rate_limiter=rate_limiter,
      signature=adders_reverb.NStepTransitionAdder.signature(
          environment_spec))
  return [replay_table]
def make_replay_tables(
    self,
    environment_spec,
):
  """Builds the replay table whose signature comes from `EpisodeAdder`.

  Both `min_replay_size` and `max_replay_size` are integer-divided by
  `max_episode_steps` before use, so the table limits are presumably counted
  in whole episodes rather than steps (TODO confirm against the adder).

  Args:
    environment_spec: Spec used to derive the table signature.

  Returns:
    A one-element list holding the configured `reverb.Table`.
  """
  config = self._config

  tolerance = (
      config.samples_per_insert_tolerance_rate * config.samples_per_insert)
  episode_len = config.max_episode_steps
  min_trajectories = config.min_replay_size // episode_len
  max_trajectories = config.max_replay_size // episode_len

  rate_limiter = rate_limiters.SampleToInsertRatio(
      min_size_to_sample=min_trajectories,
      samples_per_insert=config.samples_per_insert,
      error_buffer=min_trajectories * tolerance)

  replay_table = reverb.Table(
      name=config.replay_table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_trajectories,
      rate_limiter=rate_limiter,
      signature=adders_reverb.EpisodeAdder.signature(environment_spec, {}))
  return [replay_table]
def _make_rate_limiter_from_rate_limiter_info(
    info) -> rate_limiters.RateLimiter:
  """Rebuilds a `SampleToInsertRatio` limiter from a rate limiter info object.

  Args:
    info: Object exposing `samples_per_insert`, `min_size_to_sample`,
      `min_diff` and `max_diff` attributes.

  Returns:
    A `SampleToInsertRatio` using `(min_diff, max_diff)` as its error buffer.
  """
  samples_per_insert = info.samples_per_insert
  min_size_to_sample = info.min_size_to_sample
  # The explicit (min, max) pair is passed rather than a single buffer width.
  diff_range = (info.min_diff, info.max_diff)
  return rate_limiters.SampleToInsertRatio(
      samples_per_insert=samples_per_insert,
      min_size_to_sample=min_size_to_sample,
      error_buffer=diff_range)
class TableTest(parameterized.TestCase):
  """Tests for `server.Table` info, signature, and `replace`."""

  def _check_selector_proto(self, expected_selector, proto_msg):
    """Asserts `proto_msg` has the field that matches the selector's type.

    Args:
      expected_selector: Selector instance that was passed to the table.
      proto_msg: Message whose `HasField` is queried for the selector field.

    Raises:
      ValueError: If `expected_selector` matches none of the known types.
    """
    # Checked in order; the first matching type decides the expected field.
    field_by_selector_type = (
        (item_selectors.Uniform, 'uniform'),
        (item_selectors.Prioritized, 'prioritized'),
        (pybind.HeapSelector, 'heap'),
        (item_selectors.Fifo, 'fifo'),
        (item_selectors.Lifo, 'lifo'),
    )
    for selector_type, field_name in field_by_selector_type:
      if isinstance(expected_selector, selector_type):
        self.assertTrue(proto_msg.HasField(field_name))
        return
    raise ValueError(f'Unknown selector: {expected_selector}')

  @parameterized.product(
      sampler_fn=[
          item_selectors.Uniform,
          lambda: item_selectors.Prioritized(1.),
          item_selectors.MinHeap,
          item_selectors.MaxHeap,
          item_selectors.Fifo,
          item_selectors.Lifo,
      ],
      remover_fn=[
          item_selectors.Uniform,
          lambda: item_selectors.Prioritized(1.),
          item_selectors.MinHeap,
          item_selectors.MaxHeap,
          item_selectors.Fifo,
          item_selectors.Lifo,
      ],
      rate_limiter_fn=[
          lambda: rate_limiters.MinSize(10),
          lambda: rate_limiters.Queue(10),
          lambda: rate_limiters.SampleToInsertRatio(1.0, 10, 1.),
          lambda: rate_limiters.Stack(10),
      ],
  )
  def test_table_info(self, sampler_fn, remover_fn, rate_limiter_fn):
    """A fresh table reports its configuration and zeroed counters."""
    sampler = sampler_fn()
    remover = remover_fn()
    table = server.Table(
        name='table',
        sampler=sampler,
        remover=remover,
        max_size=100,
        rate_limiter=rate_limiter_fn())

    info = table.info

    # Static configuration.
    self.assertEqual('table', info.name)
    self.assertEqual(100, info.max_size)
    # Nothing has been inserted or removed yet.
    self.assertEqual(0, info.current_size)
    self.assertEqual(0, info.num_episodes)
    self.assertEqual(0, info.num_deleted_episodes)
    # No signature was passed to the constructor.
    self.assertIsNone(info.signature)
    self._check_selector_proto(sampler, info.sampler_options)
    self._check_selector_proto(remover, info.remover_options)

  @parameterized.named_parameters(
      ('scalar', tf.TensorSpec([], tf.float32)),
      ('image', tf.TensorSpec([3, 64, 64], tf.uint8)),
      (
          'nested',
          (
              tf.TensorSpec([], tf.int32),
              {'a': tf.TensorSpec((1, 1), tf.float64)},
          ),
      ),
  )
  def test_table_info_signature(self, signature):
    """The signature given at construction is returned by `table.info`."""
    fifo_table = server.Table(
        name='table',
        sampler=item_selectors.Fifo(),
        remover=item_selectors.Fifo(),
        max_size=100,
        rate_limiter=rate_limiters.MinSize(10),
        signature=signature)
    self.assertEqual(signature, fifo_table.info.signature)

  def test_replace(self):
    """`replace` swaps the rate limiter while keeping the table name."""
    original = server.Table(
        name='table',
        sampler=item_selectors.Fifo(),
        remover=item_selectors.Fifo(),
        max_size=100,
        rate_limiter=rate_limiters.MinSize(10))

    # Copy the existing limiter settings, changing only min_size_to_sample.
    old_info = original.info.rate_limiter_info
    replacement_limiter = rate_limiters.RateLimiter(
        samples_per_insert=old_info.samples_per_insert,
        min_size_to_sample=1,
        min_diff=old_info.min_diff,
        max_diff=old_info.max_diff)

    replaced = original.replace(rate_limiter=replacement_limiter)

    self.assertEqual(replaced.name, original.name)
    self.assertEqual(replaced.info.rate_limiter_info.min_size_to_sample, 1)