def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
) -> List[reverb.Table]:
    """Create tables to insert data into.

    Builds one uniformly-sampled, FIFO-evicted table whose signature matches
    `NStepTransitionAdder` for the given environment.

    Args:
        environment_spec: environment spec used to derive the table signature.

    Returns:
        A single-element list containing the replay table.
    """
    samples_per_insert = self._config.samples_per_insert
    min_replay_size = self._config.min_replay_size
    if samples_per_insert is None:
        # We will take a samples_per_insert ratio of None to mean that there is
        # no limit, i.e. this only implies a min size limit.
        limiter = reverb.rate_limiters.MinSize(min_replay_size)
    else:
        # Create enough of an error buffer to give a 10% tolerance in rate.
        samples_per_insert_tolerance = 0.1 * samples_per_insert
        error_buffer = min_replay_size * samples_per_insert_tolerance
        # Reverb's SampleToInsertRatio rejects (ValueError) error buffers
        # smaller than max(1.0, samples_per_insert); small min_replay_size
        # values hit that, so clamp to the smallest accepted buffer.
        error_buffer = max(error_buffer, max(1.0, samples_per_insert))
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_replay_size,
            samples_per_insert=samples_per_insert,
            error_buffer=error_buffer)
    replay_table = reverb.Table(
        name=self._config.replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=self._config.max_replay_size,
        rate_limiter=limiter,
        signature=reverb_adders.NStepTransitionAdder.signature(
            environment_spec))
    return [replay_table]
def replay(self) -> List[reverb.Table]:
    """The replay storage.

    Returns a single prioritized, FIFO-evicted table of sequences of length
    burn_in_length + trace_length + 1. Each step carries the network's
    initial recurrent state (batch dimension squeezed) as `core_state` extra.
    """
    # The network is built here only to obtain its initial state for the spec.
    core_state = self._network_factory(
        self._environment_spec.actions).initial_state(1)
    extra_spec = tf2_utils.squeeze_batch_dim({'core_state': core_state})

    if not self._samples_per_insert:
        # No ratio requested: sampling only waits for min_replay_size items.
        limiter = reverb.rate_limiters.MinSize(self._min_replay_size)
    else:
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=self._min_replay_size,
            samples_per_insert=self._samples_per_insert,
            error_buffer=self._batch_size)

    full_sequence_length = self._burn_in_length + self._trace_length + 1
    signature = adders.SequenceAdder.signature(
        self._environment_spec,
        extra_spec,
        sequence_length=full_sequence_length)

    return [
        reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(self._priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=self._max_replay_size,
            rate_limiter=limiter,
            signature=signature),
    ]
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
) -> List[reverb.Table]:
    """Create tables to insert data into.

    Builds one prioritized, FIFO-evicted table whose signature matches
    `SequenceAdder` with `self._extra_spec` as extras.

    Args:
        environment_spec: environment spec used to derive the table signature.

    Returns:
        A single-element list containing the replay table.
    """
    samples_per_insert = self._config.samples_per_insert
    min_replay_size = self._config.min_replay_size
    if samples_per_insert:
        samples_per_insert_tolerance = (
            self._config.samples_per_insert_tolerance_rate *
            samples_per_insert)
        error_buffer = min_replay_size * samples_per_insert_tolerance
        # Reverb's SampleToInsertRatio rejects (ValueError) error buffers
        # smaller than max(1.0, samples_per_insert); clamp so small
        # min_replay_size / tolerance values still build a valid limiter.
        error_buffer = max(error_buffer, max(1.0, samples_per_insert))
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_replay_size,
            samples_per_insert=samples_per_insert,
            error_buffer=error_buffer)
    else:
        # No ratio requested: sampling only waits for min_replay_size items.
        limiter = reverb.rate_limiters.MinSize(min_replay_size)
    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Prioritized(
                self._config.priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_replay_size,
            rate_limiter=limiter,
            signature=adders_reverb.SequenceAdder.signature(
                environment_spec, self._extra_spec))
    ]
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy: dqn_actor.EpsilonPolicy,
) -> List[reverb.Table]:
    """Creates reverb tables for the algorithm.

    Builds one prioritized, FIFO-evicted table of n-step transitions.

    Args:
        environment_spec: environment spec used to derive the table signature.
        policy: unused.

    Returns:
        A single-element list containing the replay table.
    """
    del policy
    samples_per_insert = self._config.samples_per_insert
    min_replay_size = self._config.min_replay_size
    if samples_per_insert:
        samples_per_insert_tolerance = (
            self._config.samples_per_insert_tolerance_rate *
            samples_per_insert)
        error_buffer = min_replay_size * samples_per_insert_tolerance
        # Reverb's SampleToInsertRatio rejects (ValueError) error buffers
        # smaller than max(1.0, samples_per_insert); clamp to that minimum.
        error_buffer = max(error_buffer, max(1.0, samples_per_insert))
        limiter = rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_replay_size,
            samples_per_insert=samples_per_insert,
            error_buffer=error_buffer)
    else:
        # No ratio requested (None/0): previously this crashed; now only a
        # minimum size is enforced, matching the other builders.
        limiter = rate_limiters.MinSize(min_replay_size)
    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Prioritized(
                self._config.priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_replay_size,
            rate_limiter=limiter,
            signature=adders_reverb.NStepTransitionAdder.signature(
                environment_spec))
    ]
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy: r2d2_actor.R2D2Policy,
) -> List[reverb.Table]:
    """Create tables to insert data into.

    The extras spec is taken from the policy's extras for a freshly
    initialised actor state (PRNGKey(0)).

    Args:
        environment_spec: environment spec used to derive the table signature.
        policy: policy whose `init`/`get_extras` define the extras spec.

    Returns:
        A single-element list containing a prioritized sequence table.
    """
    dummy_actor_state = policy.init(jax.random.PRNGKey(0))
    extras_spec = policy.get_extras(dummy_actor_state)
    samples_per_insert = self._config.samples_per_insert
    min_replay_size = self._config.min_replay_size
    if samples_per_insert:
        samples_per_insert_tolerance = (
            self._config.samples_per_insert_tolerance_rate *
            samples_per_insert)
        error_buffer = min_replay_size * samples_per_insert_tolerance
        # Reverb's SampleToInsertRatio rejects (ValueError) error buffers
        # smaller than max(1.0, samples_per_insert); clamp to that minimum.
        error_buffer = max(error_buffer, max(1.0, samples_per_insert))
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_replay_size,
            samples_per_insert=samples_per_insert,
            error_buffer=error_buffer)
    else:
        # No ratio requested: sampling only waits for min_replay_size items.
        limiter = reverb.rate_limiters.MinSize(min_replay_size)
    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Prioritized(
                self._config.priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_replay_size,
            rate_limiter=limiter,
            signature=adders_reverb.SequenceAdder.signature(
                environment_spec, extras_spec))
    ]
def setUp(self):
    super(ReverbReplayBufferTest, self).setUp()

    # Environment with 3-step episodes and the tensor specs derived from it.
    self._env = test_envs.EpisodeCountingEnv(steps_per_episode=3)
    time_step_spec = tf.nest.map_structure(
        tensor_spec.from_spec, self._env.time_step_spec())
    action_spec = tensor_spec.from_spec(self._env.action_spec())
    self._data_spec = trajectory.Trajectory(
        step_type=time_step_spec.step_type,
        observation=time_step_spec.observation,
        action=action_spec,
        policy_info=(),
        next_step_type=time_step_spec.step_type,
        reward=time_step_spec.reward,
        discount=time_step_spec.discount,
    )

    def _prepend_time_axis(spec):
        # Table items are variable-length sequences of the per-step spec.
        return tf.TensorSpec(dtype=spec.dtype, shape=(None,) + spec.shape)

    table_spec = tf.nest.map_structure(_prepend_time_axis, self._data_spec)
    self._array_data_spec = tensor_spec.to_nest_array_spec(self._data_spec)

    # Local Reverb server with one uniform table, plus a Python client to it.
    self._table_name = 'test_table'
    self._server = reverb.Server([
        reverb.Table(
            name=self._table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=100,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=table_spec,
        )
    ])
    self._py_client = reverb.Client(f'localhost:{self._server.port}')
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
    sequence_length: int,
) -> List[reverb.Table]:
    """Create tables to insert data into.

    Builds one uniform, FIFO-evicted table of sequences of length
    `sequence_length + 1`, each step carrying a scalar float32 `log_prob`.

    Args:
        environment_spec: environment spec used to derive the table signature.
        sequence_length: number of steps per item, excluding the extra step.

    Returns:
        A single-element list containing the replay table.
    """
    spi = self._config.samples_per_insert
    if spi is not None:
        # Error buffer of at least 1 so the limiter never fully blocks.
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=self._config.min_replay_size,
            samples_per_insert=spi,
            error_buffer=max(1, spi))
    else:
        # A ratio of None means no rate limit beyond a minimum table size.
        limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size)

    signature = reverb_adders.SequenceAdder.signature(
        environment_spec,
        extras_spec={'log_prob': tf.ones(shape=(), dtype=tf.float32)},
        sequence_length=sequence_length + 1)

    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_replay_size,
            rate_limiter=limiter,
            signature=signature)
    ]
def test_prioritized_table(self):
    table_name = 'test_prioritized_table'
    reverb_server = reverb.Server([
        reverb.Table(
            table_name,
            sampler=reverb.selectors.Prioritized(1.0),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1),
            max_size=3)
    ])
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        tensor_spec.TensorSpec((), dtype=tf.int64),
        table_name,
        sequence_length=1,
        local_server=reverb_server,
        dataset_buffer_size=1)

    # Insert values 0, 1, 2 using the value itself as the priority.
    with replay.py_client.trajectory_writer(1) as writer:
        for value in range(3):
            writer.append(value)
            writer.create_item(
                table_name, trajectory=writer.history[-1:], priority=value)

    iterator = iter(replay.as_dataset(
        sample_batch_size=1, num_steps=None, num_parallel_calls=1))
    counts = [0, 0, 0]
    for _ in range(1000):
        sampled = next(iterator)[0].numpy()  # A 1x1 matrix.
        counts[int(sampled)] += 1

    # With exponent 1.0, sampling frequency is proportional to priority.
    self.assertEqual(counts[0], 0)  # priority 0
    self.assertGreater(counts[1], 250)  # priority 1
    self.assertGreater(counts[2], 600)  # priority 2
def __init__(self, data_spec, batch_size=1, n_steps=2):
    """Sets up a local prioritized Reverb table plus buffer and writer.

    Args:
        data_spec: spec of a single step stored in the buffer.
        batch_size: stored on the instance as `self.batch_size`.
        n_steps: sequence length used by both the buffer and the writer.
    """
    self.data_spec = data_spec
    self.batch_size = batch_size
    self.n_steps = n_steps
    self.replay_buffer_capacity = 100000
    self.name = 'PER'

    prioritized_table = reverb.Table(
        name=self.name,
        max_size=self.replay_buffer_capacity,
        sampler=reverb.selectors.Prioritized(0.8),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self.server = reverb.Server([prioritized_table])

    self.buffer = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec=self.data_spec,
        sequence_length=self.n_steps,
        table_name=self.name,
        local_server=self.server)
    # Observer writing overlapping n-step sequences (stride 1), priority 1.
    self.writer = reverb_utils.ReverbAddTrajectoryObserver(
        self.buffer.py_client,
        table_name=self.name,
        sequence_length=self.n_steps,
        stride_length=1,
        priority=1,
    )
    self.states = None
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy: actor_core_lib.FeedForwardPolicy,
) -> List[reverb.Table]:
    """Create tables to insert data into.

    Builds one uniform, FIFO-evicted table of n-step transitions.

    Args:
        environment_spec: environment spec used to derive the table signature.
        policy: unused.

    Returns:
        A single-element list containing the replay table.
    """
    del policy
    samples_per_insert = self._config.samples_per_insert
    min_replay_size = self._config.min_replay_size
    # Create the rate limiter.
    if samples_per_insert:
        samples_per_insert_tolerance = (
            self._config.samples_per_insert_tolerance_rate *
            samples_per_insert)
        error_buffer = min_replay_size * samples_per_insert_tolerance
        # Reverb's SampleToInsertRatio rejects (ValueError) error buffers
        # smaller than max(1.0, samples_per_insert); clamp to that minimum.
        error_buffer = max(error_buffer, max(1.0, samples_per_insert))
        limiter = rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_replay_size,
            samples_per_insert=samples_per_insert,
            error_buffer=error_buffer)
    else:
        # No ratio requested: sampling only waits for min_replay_size items.
        limiter = rate_limiters.MinSize(min_replay_size)
    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_replay_size,
            rate_limiter=limiter,
            signature=adders_reverb.NStepTransitionAdder.signature(
                environment_spec))
    ]
def _create_reverb_server(table_name: str) -> reverb.Server:
    """Returns a Reverb server hosting a single uniform/FIFO table.

    The table holds up to `_replay_buffer_capacity` items and can be sampled
    as soon as it contains one item.
    """
    uniform_table = reverb.Table(
        name=table_name,
        max_size=_replay_buffer_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    return reverb.Server([uniform_table])
def priority_tables_fn():
    """Returns a single uniform/FIFO table named `_TABLE_NAME`.

    max_size and the MinSize limit are both 100, so sampling is only allowed
    once the table is full.
    """
    capacity = 100
    table = reverb.Table(
        name=_TABLE_NAME,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=capacity,
        rate_limiter=rate_limiters.MinSize(capacity))
    return [table]
def create_reverb_server_for_replay_buffer_and_variable_container(
        collect_policy, train_step, replay_buffer_capacity, port):
    """Sets up one reverb server for replay buffer and variable container.

    Args:
        collect_policy: policy whose variables and collect_data_spec define
            the two table signatures.
        train_step: train-step variable stored alongside the policy weights.
        replay_buffer_capacity: max number of items in the replay table.
        port: port the server listens on.

    Returns:
        The created reverb.Server hosting both tables.
    """

    def _variable_to_spec(variable):
        return tf.TensorSpec(variable.shape, dtype=variable.dtype)

    def _prepend_time_axis(spec):
        # Trajectories are stored with a leading, variable-length time axis.
        return tf.TensorSpec((None,) + spec.shape, spec.dtype, spec.name)

    # Signature of the variable container holding the policy weights.
    variable_container_signature = tf.nest.map_structure(
        _variable_to_spec,
        {
            reverb_variable_container.POLICY_KEY: collect_policy.variables(),
            reverb_variable_container.TRAIN_STEP_KEY: train_step,
        })

    # Signature of the replay buffer holding observed experience.
    replay_buffer_signature = tf.nest.map_structure(
        _prepend_time_axis,
        tensor_spec.from_spec(collect_policy.collect_data_spec))

    experience_table = reverb.Table(
        name=reverb_replay_buffer.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        # TODO(b/159073060): Set rate limiter for SAC properly.
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=replay_buffer_capacity,
        max_times_sampled=0,
        signature=replay_buffer_signature,
    )
    # Holds a single item: the latest policy parameters.
    variable_table = reverb.Table(
        name=reverb_variable_container.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=1,
        max_times_sampled=0,
        signature=variable_container_signature,
    )

    # Create and start the replay buffer and variable container server.
    return reverb.Server(tables=[experience_table, variable_table], port=port)
def _reverb_server():
    """Returns a Reverb server with one 100-slot uniform/FIFO table.

    Sampling blocks until 95 items are present; port=None lets the server
    pick its own port.
    """
    test_table = reverb.Table(
        name='test_table',
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=100,
        rate_limiter=reverb.rate_limiters.MinSize(95))
    return reverb.Server(tables=[test_table], port=None)
def make_replay_tables(
    self,
    environment_spec: specs.MAEnvironmentSpec,
) -> List[reverb.Table]:
    """Create tables to insert data into.

    The signature depends on the executor type: feed-forward executors use
    `ParallelNStepTransitionAdder`, recurrent executors use
    `ParallelSequenceAdder`.

    Args:
        environment_spec (specs.MAEnvironmentSpec): description of the action
            and observation spaces etc. for each agent in the system.

    Raises:
        NotImplementedError: unknown executor type.

    Returns:
        List[reverb.Table]: a list of data tables for inserting data.
    """
    # Select adder
    if issubclass(self._executor_fn, executors.FeedForwardExecutor):
        # Check if we should use fingerprints. Note: this updates
        # self._extra_specs in place, so the entry persists after this call.
        if self._replay_stabiliser_fn is not None:
            self._extra_specs.update({"fingerprint": np.array([1.0, 1.0])})
        adder_sig = reverb_adders.ParallelNStepTransitionAdder.signature(
            environment_spec, self._extra_specs
        )
    elif issubclass(self._executor_fn, executors.RecurrentExecutor):
        adder_sig = reverb_adders.ParallelSequenceAdder.signature(
            environment_spec, self._extra_specs
        )
    else:
        # Format a single message string (passing two args produced a
        # tuple-looking message).
        raise NotImplementedError(f"Unknown executor type: {self._executor_fn}")

    samples_per_insert = self._config.samples_per_insert
    min_replay_size = self._config.min_replay_size
    if samples_per_insert is None:
        # We will take a samples_per_insert ratio of None to mean that there is
        # no limit, i.e. this only implies a min size limit.
        limiter = reverb.rate_limiters.MinSize(min_replay_size)
    else:
        # Create enough of an error buffer to give a 10% tolerance in rate.
        samples_per_insert_tolerance = 0.1 * samples_per_insert
        error_buffer = min_replay_size * samples_per_insert_tolerance
        # Reverb's SampleToInsertRatio rejects (ValueError) error buffers
        # smaller than max(1.0, samples_per_insert); clamp to that minimum.
        error_buffer = max(error_buffer, max(1.0, samples_per_insert))
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_replay_size,
            samples_per_insert=samples_per_insert,
            error_buffer=error_buffer,
        )

    replay_table = reverb.Table(
        name=self._config.replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=self._config.max_replay_size,
        rate_limiter=limiter,
        signature=adder_sig,
    )

    return [replay_table]
def test_dataset_with_variable_sequence_length_truncates(self):
    spec = tf.TensorSpec((), tf.int64)
    # Each table item is a whole episode of variable length.
    server = reverb.Server([
        reverb.Table(
            name=self._table_name,
            sampler=reverb.selectors.Fifo(),
            remover=reverb.selectors.Fifo(),
            max_times_sampled=1,
            max_size=100,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=tf.TensorSpec((None, ), tf.int64),
        )
    ])
    py_client = reverb.Client('localhost:{}'.format(server.port))

    # Insert two episodes: one of length 3 and one of length 5.
    for episode in ([1, 2, 3], [10, 20, 30, 40, 50]):
        with py_client.trajectory_writer(10) as writer:
            for value in episode:
                writer.append(value)
            writer.create_item(self._table_name,
                               trajectory=writer.history[-len(episode):],
                               priority=5)

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        spec,
        self._table_name,
        local_server=server,
        sequence_length=None,
        rate_limiter_timeout_ms=100)
    it = iter(replay.as_dataset(single_deterministic_pass=True, num_steps=2))

    # Episodes are split into 2-step chunks; leftover steps (3, 50) are dropped.
    for expected in ([1, 2], [10, 20], [30, 40]):
        data, _ = next(it)
        self.assertAllEqual(data, expected)
    with self.assertRaises(StopIteration):
        next(it)
def __init__(self):
    """Creates the Reverb server backing this prioritized replay buffer.

    Three tables share capacity `Params.BUFFER_SIZE` and FIFO removal:
    `Params.BUFFER_TYPE` (prioritized sampling, exponent
    `Params.BUFFER_PRIORITY_ALPHA`) and `<BUFFER_TYPE>_max` /
    `<BUFFER_TYPE>_min` (max-/min-heap sampling). A dummy zero item is then
    inserted with priority 1.0 into every table in
    `Params.BUFFER_PRIORITY_TABLE_NAMES`.
    """
    # reverb.Table / MinSize take plain integers; build them directly rather
    # than routing through TF tensors and implicit tensor->int coercion.
    buffer_size = int(Params.BUFFER_SIZE)
    batch_size = int(Params.MINIBATCH_SIZE)
    reverb_server = reverb.Server(tables=[
        reverb.Table(
            name=Params.BUFFER_TYPE,
            sampler=reverb.selectors.Prioritized(
                priority_exponent=Params.BUFFER_PRIORITY_ALPHA),
            remover=reverb.selectors.Fifo(),
            max_size=buffer_size,
            # Sampling waits until at least one minibatch of items exists.
            rate_limiter=reverb.rate_limiters.MinSize(batch_size),
        ),
        reverb.Table(
            name=Params.BUFFER_TYPE + "_max",
            sampler=reverb.selectors.MaxHeap(),
            remover=reverb.selectors.Fifo(),
            max_size=buffer_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
        ),
        reverb.Table(
            name=Params.BUFFER_TYPE + "_min",
            sampler=reverb.selectors.MinHeap(),
            remover=reverb.selectors.Fifo(),
            max_size=buffer_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
        ),
    ], )
    super().__init__(name="ReverbPrioritizedReplayBuffer",
                     reverb_server=reverb_server)

    # Init client for updating priorities
    self.client = self.get_client()

    # Insert a dummy trajectory. NOTE(review): presumably so the MinSize(1)
    # heap tables can be sampled before real data arrives — confirm.
    # One priority per target table (previously hard-coded to three).
    table_names = Params.BUFFER_PRIORITY_TABLE_NAMES
    self.client.insert([
        tf.zeros(spec.shape, dtype=spec.dtype)
        for spec in Params.BUFFER_DATA_SPEC
    ],
                       tables=table_names,
                       priorities=tf.constant([1.0] * len(table_names),
                                              dtype=tf.float64))
def test_make_replay_table_preserves_table_info(self):
    # Round-trip: Table -> TableInfo -> Table must yield identical info.
    original = reverb.Table(
        name='test',
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=10,
        rate_limiter=reverb.rate_limiters.SampleToInsertRatio(
            samples_per_insert=1, min_size_to_sample=2, error_buffer=(0, 10)))
    rebuilt = reverb_utils.make_replay_table_from_info(original.info)
    self.assertEqual(rebuilt.info, original.info)
def setUpClass(cls):
    super().setUpClass()
    # One shared Reverb server (uniform/FIFO, 1000 items) for the test class.
    cls.server = reverb.Server([
        reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=1000,
            rate_limiter=reverb.rate_limiters.MinSize(1),
        )
    ])
def main(argv):
    """Runs a Reverb server with a successor table and an eval table."""
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    num_eval_points = _CONFIG.value.num_eval_points

    # TODO(joshgreaves): Choose an appropriate rate_limiter, max_size.
    successor_table = reverb.Table(
        name='successor_table',
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=1_000_000,
        rate_limiter=reverb.rate_limiters.MinSize(20_000))
    # max_size equals the MinSize limit, so this table can only be sampled
    # once it is full.
    eval_table = reverb.Table(
        name='eval_table',
        sampler=reverb.selectors.Fifo(),
        remover=reverb.selectors.Fifo(),
        max_size=num_eval_points,
        rate_limiter=reverb.rate_limiters.MinSize(num_eval_points))

    server = reverb.Server(tables=[successor_table, eval_table],
                           port=FLAGS.port)
    server.wait()
def make_replay_tables(environment_spec: specs.EnvironmentSpec
                       ) -> Sequence[reverb.Table]:
    """Create tables to insert data into.

    Returns one uniform/FIFO table named 'default' holding up to one million
    n-step transitions, sampleable once it holds a single item.
    """
    transition_signature = adders_reverb.NStepTransitionAdder.signature(
        environment_spec)
    default_table = reverb.Table(
        name='default',
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=1_000_000,
        rate_limiter=rate_limiters.MinSize(1),
        signature=transition_signature)
    return [default_table]
def __init__(self, min_size: int = 64, max_size: int = 40000):
    """Starts a local Reverb server with one prioritized table.

    Args:
        min_size: items required before sampling is allowed.
        max_size: table capacity (converted with int()).
    """
    self._min_size = min_size
    self._table_name = 'priority_table'
    prioritized_table = reverb.Table(
        name=self._table_name,
        sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
        remover=reverb.selectors.Fifo(),
        max_size=int(max_size),
        rate_limiter=reverb.rate_limiters.MinSize(min_size))
    # port=None makes the server pick a port automatically.
    self._server = reverb.Server(tables=[prioritized_table], port=None)
def tables_from_proto(
    configs: Sequence[checkpoint_pb2.PriorityTableCheckpoint]
) -> Sequence[reverb.Table]:
    """Convert protobuf to reverb.Table.

    Builds one table per checkpoint config, copying its name, sampler,
    remover, max_size, rate limiter and max_times_sampled. Tables are
    returned in input order.
    """
    return [
        reverb.Table(
            name=cfg.table_name,
            sampler=selector_from_proto(cfg.sampler),
            remover=selector_from_proto(cfg.remover),
            max_size=cfg.max_size,
            rate_limiter=rate_limiter_from_proto(cfg.rate_limiter),
            max_times_sampled=cfg.max_times_sampled,
        )
        for cfg in configs
    ]
def test_deterministic_dataset_from_heap_sampler_remover(self):
    # Max-heap sampler paired with a min-heap remover.
    max_heap_sampler_min_heap_remover_table = reverb.Table(
        name=self._table_name,
        sampler=reverb.selectors.MaxHeap(),
        remover=reverb.selectors.MinHeap(),
        max_size=100,
        max_times_sampled=0,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    server = reverb.Server([max_heap_sampler_min_heap_remover_table])

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        local_server=server,
        sequence_length=None)
    # Only checks that dataset construction does not raise.
    replay.as_dataset(single_deterministic_pass=True)
    server.stop()
def setUp(self):
    super().setUp()

    # Mock client whose trajectory_writer(...) call returns the writer mock
    # itself.
    self._mock_client = mock.MagicMock()
    writer_mock = mock.MagicMock()
    writer_mock.return_value = writer_mock
    self._mock_writer = writer_mock
    self._mock_client.trajectory_writer = writer_mock

    # Real local Reverb server with one uniform table, plus a client to it.
    self._table_name = 'uniform_table'
    self._table = reverb.Table(
        name=self._table_name,
        max_size=100,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._reverb_server = reverb.Server([self._table], port=None)
    self._reverb_client = self._reverb_server.localhost_client()
def make_replay_tables(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy: impala_networks.IMPALANetworks,
) -> List[reverb.Table]:
    """Creates the single queue table that trajectories are written to.

    Items are sequences of `self._config.sequence_length` steps whose extras
    hold the recurrent core state and per-action logits. They are sampled
    uniformly, evicted FIFO and sampled at most once.

    Args:
        environment_spec: environment spec; its discrete action count sizes
            the logits extra.
        policy: unused.

    Returns:
        A single-element list containing the queue table.

    Raises:
        ValueError: if `samples_per_insert` is set but outside (0, 1].
    """
    del policy
    extra_spec = {
        'core_state': self._core_state_spec,
        'logits': jnp.ones(
            shape=(environment_spec.actions.num_values,), dtype=jnp.float32),
    }
    signature = reverb_adders.SequenceAdder.signature(
        environment_spec,
        extra_spec,
        sequence_length=self._config.sequence_length)

    # A samples_per_insert ratio below the default of 1.0 lets the agent drop
    # data so learner updates use data from the most up-to-date policies.
    spi = self._config.samples_per_insert
    if not spi:
        limiter = reverb.rate_limiters.MinSize(1)
    else:
        if spi > 1.0 or spi <= 0.0:
            raise ValueError(
                'Impala requires a samples_per_insert ratio in the range (0, 1],'
                f' but received {spi}.')
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            samples_per_insert=spi,
            min_size_to_sample=1,
            error_buffer=self._config.batch_size)

    extensions = ([self._table_extension()]
                  if self._table_extension is not None else [])

    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_queue_size,
            max_times_sampled=1,
            rate_limiter=limiter,
            extensions=extensions,
            signature=signature)
    ]
def make_replay_tables(
    self, environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
    """Returns the wrapped agent's tables, plus one more when not sharing.

    If `share_iterator` is false, an extra uniform/FIFO n-step transition
    table named `self._config.replay_table_name` is appended to the wrapped
    agent's list.
    """
    tables = self._rl_agent.make_replay_tables(environment_spec)
    if not self._config.share_iterator:
        tables.append(
            reverb.Table(
                name=self._config.replay_table_name,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=self._config.max_replay_size,
                rate_limiter=rate_limiters.MinSize(
                    self._config.min_replay_size),
                signature=adders_reverb.NStepTransitionAdder.signature(
                    environment_spec)))
    return tables
def test_capacity_set(self):
    # The buffer must report the table's max_size as its capacity.
    table_name = 'test_table'
    capacity = 100
    server = reverb.Server([
        reverb.Table(
            name=table_name,
            max_size=capacity,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(3))
    ])
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        tensor_spec.TensorSpec((), tf.float32),
        table_name,
        local_server=server,
        sequence_length=None)
    self.assertEqual(capacity, replay.capacity)
    server.stop()
def replay(self):
    """The replay storage.

    Returns one prioritized, FIFO-evicted table of n-step transitions for
    `self._env_spec`.
    """
    if not self._samples_per_insert:
        # No ratio requested: sampling only waits for min_replay_size items.
        limiter = reverb.rate_limiters.MinSize(self._min_replay_size)
    else:
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=self._min_replay_size,
            samples_per_insert=self._samples_per_insert,
            error_buffer=self._batch_size)

    transition_signature = adders.NStepTransitionAdder.signature(
        self._env_spec)
    return [
        reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(self._priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=self._max_replay_size,
            rate_limiter=limiter,
            signature=transition_signature)
    ]
def main():
    """Launches perwez, a Reverb replay server, worker processes and trainer.

    Reads `config.yaml` next to this file and starts
    `config["common"]["num_workers"]` daemon worker processes. It then runs
    the trainer in the current process. On exit, normal or exceptional, the
    workers, the Reverb server and the perwez server are shut down.
    """
    mp.set_start_method("spawn", force=True)

    config_path = path.join(path.dirname(__file__), "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # init perwez
    pwz_proc, pwz_config = perwez.start_server()

    # init reverb_server. Keep a reference for the lifetime of main():
    # reverb.Server stops itself in __del__, so an unbound server would be
    # garbage collected and shut down immediately.
    reverb_server = reverb.Server(
        tables=[
            reverb.Table(
                name=TABLE_NAME,
                sampler=reverb.selectors.Prioritized(0.6),
                remover=reverb.selectors.Fifo(),
                max_size=config["replay_buffer"]["capacity"],
                rate_limiter=reverb.rate_limiters.MinSize(1000),
            )
        ],
        port=PORT,
    )

    # worker subprocesses
    worker_processes = []
    num_workers = config["common"]["num_workers"]
    for idx in range(num_workers):
        p = mp.Process(
            name=f"apex-worker-{idx}",
            target=worker_main,
            args=(pwz_config["url"], config, idx),
            daemon=True,
        )
        p.start()
        worker_processes.append(p)

    # trainer process should be the main process
    try:
        trainer_main_tf_dataset(pwz_config["url"], config)
    finally:
        print("exiting...")
        for p in worker_processes:
            p.terminate()
            p.join()
        reverb_server.stop()
        pwz_proc.terminate()
        pwz_proc.join()