def build(self, name='r2d2'):
  """Assemble the launchpad program for the distributed R2D2 agent.

  Args:
    name: label given to the resulting launchpad program.

  Returns:
    An `lp.Program` wiring replay, counter, learner, cacher, evaluator and
    a fleet of actors together.
  """
  program = lp.Program(name=name)

  # Replay buffer service.
  with program.group('replay'):
    replay_handle = program.add_node(lp.ReverbNode(self.replay))

  # Central step counter.
  with program.group('counter'):
    counter_handle = program.add_node(lp.CourierNode(self.counter))

  # Single learner consuming replayed experience.
  with program.group('learner'):
    learner_handle = program.add_node(
        lp.CourierNode(self.learner, replay_handle, counter_handle))

  # Variable cache between the learner and the evaluator/actors.
  with program.group('cacher'):
    cache_handle = program.add_node(
        lp.CacherNode(
            learner_handle, refresh_interval_ms=2000, stale_after_ms=4000))

  with program.group('evaluator'):
    program.add_node(
        lp.CourierNode(self.evaluator, cache_handle, counter_handle))

  # One exploration epsilon per actor: log-spaced (base 0.4, so decreasing),
  # then reversed so the sequence of epsilons is increasing across actors.
  epsilons = np.logspace(1, 8, self._num_actors, base=0.4)[::-1]

  with program.group('actor'):
    for eps in epsilons:
      program.add_node(
          lp.CourierNode(
              self.actor, replay_handle, cache_handle, counter_handle, eps))

  return program
def build(self, name='MCTS'):
  """Lay out the distributed MCTS agent as a launchpad program.

  Args:
    name: label given to the resulting launchpad program.

  Returns:
    An `lp.Program` with replay, counter, learner, evaluator and a single
    actor node.
  """
  program = lp.Program(name=name)

  with program.group('replay'):
    replay_handle = program.add_node(lp.ReverbNode(self.replay),
                                     label='replay')

  with program.group('counter'):
    counter_handle = program.add_node(lp.CourierNode(counting.Counter),
                                      label='counter')

  with program.group('learner'):
    learner_handle = program.add_node(
        lp.CourierNode(self.learner, replay_handle, counter_handle),
        label='learner')

  with program.group('evaluator'):
    program.add_node(
        lp.CourierNode(self.evaluator, learner_handle, counter_handle),
        label='evaluator')

  # A single actor feeds the replay buffer.
  with program.group('actor'):
    program.add_node(
        lp.CourierNode(self.actor, replay_handle, learner_handle,
                       counter_handle),
        label='actor')

  return program
def build(self, name='impala'):
  """Wire up the distributed IMPALA topology as a launchpad program.

  Args:
    name: label given to the resulting launchpad program.

  Returns:
    An `lp.Program` with a reverb queue, counter, learner, evaluator,
    a learner-variable cache and the actor fleet.
  """
  program = lp.Program(name=name)

  with program.group('replay'):
    queue_handle = program.add_node(lp.ReverbNode(self.queue))

  with program.group('counter'):
    counter_handle = program.add_node(lp.CourierNode(self.counter))

  with program.group('learner'):
    learner_handle = program.add_node(
        lp.CourierNode(self.learner, queue_handle, counter_handle))

  with program.group('evaluator'):
    program.add_node(
        lp.CourierNode(self.evaluator, learner_handle, counter_handle))

  # Cache of learner variables shared by every actor.
  with program.group('cacher'):
    cache_handle = program.add_node(
        lp.CacherNode(
            learner_handle, refresh_interval_ms=2000, stale_after_ms=4000))

  with program.group('actor'):
    for _ in range(self._num_actors):
      program.add_node(
          lp.CourierNode(self.actor, queue_handle, cache_handle,
                         counter_handle))

  return program
def build(self, name: str = "madqn") -> Any:
  """Build the distributed system as a graph program.

  Args:
    name (str, optional): system name. Defaults to "madqn".

  Returns:
    Any: graph program for distributed system training.
  """
  program = lp.Program(name=name)

  with program.group("replay"):
    replay_handle = program.add_node(lp.ReverbNode(self.replay))

  with program.group("counter"):
    counter_handle = program.add_node(
        lp.CourierNode(self.counter, self._checkpoint))

  # Optional coordinator that watches the counter for a step budget.
  if self._max_executor_steps:
    with program.group("coordinator"):
      _ = program.add_node(lp.CourierNode(self.coordinator, counter_handle))

  with program.group("trainer"):
    trainer_handle = program.add_node(
        lp.CourierNode(self.trainer, replay_handle, counter_handle))

  with program.group("evaluator"):
    # NOTE(review): the trainer handle is passed twice here, mirroring the
    # original wiring — presumably once as variable source and once as
    # trainer; confirm against the evaluator factory's signature.
    program.add_node(
        lp.CourierNode(self.evaluator, trainer_handle, counter_handle,
                       trainer_handle))

  if not self._num_caches:
    # With no caches the trainer itself serves variables.
    sources = [trainer_handle]
  else:
    # A pool of trainer caches to spread the variable-serving load.
    with program.group("cacher"):
      sources = [
          program.add_node(
              lp.CacherNode(
                  trainer_handle,
                  refresh_interval_ms=2000,
                  stale_after_ms=4000,
              ))
          for _ in range(self._num_caches)
      ]

  with program.group("executor"):
    # Executors pull variables round-robin from the available sources.
    # NOTE(review): '_num_exectors' (sic) matches the attribute spelling
    # defined elsewhere in this class; renaming it here would break lookup.
    for executor_id in range(self._num_exectors):
      program.add_node(
          lp.CourierNode(
              self.executor,
              executor_id,
              replay_handle,
              sources[executor_id % len(sources)],
              counter_handle,
              trainer_handle,
          ))

  return program
def build(self, name='agent', program=None):
  """Build the distributed agent topology.

  Args:
    name: name for the launchpad program (used only when `program` is not
      supplied).
    program: optional existing `lp.Program` to add nodes to; a fresh one is
      created when falsy.

  Returns:
    The launchpad program with replay, counter (plus optional coordinator),
    learner, evaluators and actors added.
  """
  if not program:
    program = lp.Program(name=name)
  # Root PRNG key; split sequentially below so every node gets a unique key.
  key = jax.random.PRNGKey(self._seed)

  # The replay node is built eagerly so that, when learner/reverb colocation
  # is requested, the very same node object can be placed inside the
  # learner's MultiThreadingColocation below.
  replay_node = lp.ReverbNode(self.replay)
  with program.group('replay'):
    if self._multithreading_colocate_learner_and_reverb:
      # Only create a handle here; the node itself is added in the learner
      # group as part of the colocation.
      replay = replay_node.create_handle()
    else:
      replay = program.add_node(replay_node)

  with program.group('counter'):
    counter = program.add_node(lp.CourierNode(self.counter))
    if self._max_number_of_steps is not None:
      # Coordinator terminates the program once the step limit is reached.
      _ = program.add_node(
          lp.CourierNode(self.coordinator, counter,
                         self._max_number_of_steps))

  learner_key, key = jax.random.split(key)
  learner_node = lp.CourierNode(self.learner, learner_key, replay, counter)
  with program.group('learner'):
    if self._multithreading_colocate_learner_and_reverb:
      # Run learner and replay in a single process as colocated threads.
      learner = learner_node.create_handle()
      program.add_node(
          lp.MultiThreadingColocation([learner_node, replay_node]))
    else:
      learner = program.add_node(learner_node)

  def make_actor(random_key, policy_network, variable_source):
    # Thin adapter handed to each evaluator so it can build its own actor.
    return self._builder.make_actor(
        random_key, policy_network, variable_source=variable_source)

  with program.group('evaluator'):
    for evaluator in self._evaluator_factories:
      evaluator_key, key = jax.random.split(key)
      program.add_node(
          lp.CourierNode(evaluator, evaluator_key, learner, counter,
                         make_actor))

  with program.group('actor'):
    for actor_id in range(self._num_actors):
      actor_key, key = jax.random.split(key)
      program.add_node(
          lp.CourierNode(self.actor, actor_key, replay, learner, counter,
                         actor_id))

  return program
def build(self, name='dqn'):
  """Build the distributed agent topology.

  Args:
    name: name of the launchpad program.

  Returns:
    An `lp.Program` connecting replay, counter (plus optional coordinator),
    learner, evaluator, learner caches and the actor fleet.
  """
  program = lp.Program(name=name)

  with program.group('replay'):
    replay = program.add_node(lp.ReverbNode(self.replay))

  with program.group('counter'):
    counter = program.add_node(lp.CourierNode(self.counter))
    if self._max_actor_steps:
      # Coordinator terminates the program once the step budget is reached.
      program.add_node(
          lp.CourierNode(self.coordinator, counter, self._max_actor_steps))

  with program.group('learner'):
    learner = program.add_node(
        lp.CourierNode(self.learner, replay, counter))

  with program.group('evaluator'):
    program.add_node(lp.CourierNode(self.evaluator, learner, counter))

  # Generate an epsilon for each actor.
  epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0)

  if not self._num_caches:
    # Bug fix: previously an empty cache pool left `sources` empty, so
    # `actor_id % len(sources)` below raised ZeroDivisionError. Fall back
    # to the learner as the single variable source (same guard as the dmpo
    # layout in this file).
    sources = [learner]
  else:
    with program.group('cacher'):
      # Create a set of learner caches.
      sources = []
      for _ in range(self._num_caches):
        cacher = program.add_node(
            lp.CacherNode(
                learner, refresh_interval_ms=2000, stale_after_ms=4000))
        sources.append(cacher)

  with program.group('actor'):
    # Add actors which pull round-robin from our variable sources.
    for actor_id, epsilon in enumerate(epsilons):
      source = sources[actor_id % len(sources)]
      program.add_node(
          lp.CourierNode(self.actor, replay, source, counter, epsilon))

  return program
def build(self, name='dmpo'):
  """Construct the launchpad program for the distributed DMPO agent.

  Args:
    name: name of the launchpad program.

  Returns:
    An `lp.Program` with replay, counter (plus optional coordinator),
    learner, evaluator, optional learner caches and the actor fleet.
  """
  program = lp.Program(name=name)

  with program.group('replay'):
    replay_handle = program.add_node(lp.ReverbNode(self.replay))

  with program.group('counter'):
    counter_handle = program.add_node(lp.CourierNode(self.counter))
    if self._max_actor_steps:
      # Coordinator watches the counter and stops the run at the budget.
      _ = program.add_node(
          lp.CourierNode(self.coordinator, counter_handle,
                         self._max_actor_steps))

  with program.group('learner'):
    learner_handle = program.add_node(
        lp.CourierNode(self.learner, replay_handle, counter_handle))

  with program.group('evaluator'):
    program.add_node(
        lp.CourierNode(self.evaluator, learner_handle, counter_handle))

  if not self._num_caches:
    # No caches requested: the learner doubles as the variable source.
    variable_sources = [learner_handle]
  else:
    with program.group('cacher'):
      # One cache node per requested cache, all mirroring the learner.
      variable_sources = [
          program.add_node(
              lp.CacherNode(
                  learner_handle,
                  refresh_interval_ms=2000,
                  stale_after_ms=4000))
          for _ in range(self._num_caches)
      ]

  with program.group('actor'):
    # Distribute actors round-robin across the variable sources.
    for actor_id in range(self._num_actors):
      program.add_node(
          lp.CourierNode(self.actor, replay_handle,
                         variable_sources[actor_id % len(variable_sources)],
                         counter_handle, actor_id))

  return program
def make_distributed_experiment(
    experiment: config.ExperimentConfig,
    num_actors: int,
    *,
    num_learner_nodes: int = 1,
    num_actors_per_node: int = 1,
    multithreading_colocate_learner_and_reverb: bool = False,
    checkpointing_config: Optional[config.CheckpointingConfig] = None,
    make_snapshot_models: Optional[SnapshotModelFactory] = None,
    name='agent',
    program: Optional[lp.Program] = None):
  """Builds distributed agent based on a builder.

  Args:
    experiment: experiment configuration supplying environment/network
      factories, the builder, seed, logger factory and evaluator factories.
    num_actors: number of actor processes to create.
    num_learner_nodes: number of learner nodes; values > 1 add secondary
      learners that load-balance `get_variables` calls.
    num_actors_per_node: actors colocated (as threads) per launchpad node.
    multithreading_colocate_learner_and_reverb: when True, run the learner
      and the replay server in one process as colocated threads. Not
      compatible with `num_learner_nodes > 1`.
    checkpointing_config: checkpointing settings; defaults are used if None.
    make_snapshot_models: optional factory for models to snapshot; when
      given (and checkpointing is configured) a model-saver node is added.
    name: name of the launchpad program (only used when `program` is None).
    program: optional existing program to extend instead of creating one.

  Returns:
    The populated `lp.Program`.

  Raises:
    ValueError: if learner/reverb colocation is requested together with
      multiple learner nodes.
  """
  if multithreading_colocate_learner_and_reverb and num_learner_nodes > 1:
    raise ValueError(
        'Replay and learner colocation is not yet supported when the learner is'
        ' spread across multiple nodes (num_learner_nodes > 1). Please contact'
        ' Acme devs if this is a feature you want. Got:'
        '\tmultithreading_colocate_learner_and_reverb='
        f'{multithreading_colocate_learner_and_reverb}'
        f'\tnum_learner_nodes={num_learner_nodes}.')

  if checkpointing_config is None:
    checkpointing_config = config.CheckpointingConfig()

  def build_replay():
    """The replay storage."""
    # The environment spec is only needed to size the tables; use a fixed
    # dummy seed since no real interaction happens here.
    dummy_seed = 1
    spec = (experiment.environment_spec or specs.make_environment_spec(
        experiment.environment_factory(dummy_seed)))
    network = experiment.network_factory(spec)
    policy = config.make_policy(
        experiment=experiment,
        networks=network,
        environment_spec=spec,
        evaluation=False)
    return experiment.builder.make_replay_tables(spec, policy)

  def build_model_saver(variable_source: core.VariableSource):
    # Builds the snapshotter node body: reconstructs networks from a fresh
    # environment spec and snapshots the models derived from them.
    environment = experiment.environment_factory(0)
    spec = specs.make_environment_spec(environment)
    networks = experiment.network_factory(spec)
    models = make_snapshot_models(networks, spec)
    # TODO(raveman): Decouple checkpointing and snapshotting configs.
    return snapshotter.JAXSnapshotter(
        variable_source=variable_source,
        models=models,
        path=checkpointing_config.directory,
        subdirectory='snapshots',
        add_uid=checkpointing_config.add_uid)

  def build_counter():
    # Counter wrapped in a checkpointing runner so counts survive restarts.
    return savers.CheckpointingRunner(
        counting.Counter(),
        key='counter',
        subdirectory='counter',
        time_delta_minutes=5,
        directory=checkpointing_config.directory,
        add_uid=checkpointing_config.add_uid,
        max_to_keep=checkpointing_config.max_to_keep)

  def build_learner(
      random_key: networks_lib.PRNGKey,
      replay: reverb.Client,
      counter: Optional[counting.Counter] = None,
      primary_learner: Optional[core.Learner] = None,
  ):
    """The Learning part of the agent."""
    dummy_seed = 1
    spec = (experiment.environment_spec or specs.make_environment_spec(
        experiment.environment_factory(dummy_seed)))

    # Creates the networks to optimize (online) and target networks.
    networks = experiment.network_factory(spec)

    iterator = experiment.builder.make_dataset_iterator(replay)
    # make_dataset_iterator is responsible for putting data onto appropriate
    # training devices, so here we apply prefetch, so that data is copied over
    # in the background.
    iterator = utils.prefetch(iterable=iterator, buffer_size=1)
    counter = counting.Counter(counter, 'learner')
    learner = experiment.builder.make_learner(random_key, networks, iterator,
                                              experiment.logger_factory, spec,
                                              replay, counter)

    if primary_learner is None:
      # Only the primary learner checkpoints its state.
      learner = savers.CheckpointingRunner(
          learner,
          key='learner',
          subdirectory='learner',
          time_delta_minutes=5,
          directory=checkpointing_config.directory,
          add_uid=checkpointing_config.add_uid,
          max_to_keep=checkpointing_config.max_to_keep)
    else:
      learner.restore(primary_learner.save())
      # NOTE: This initially synchronizes secondary learner states with the
      # primary one. Further synchronization should be handled by the learner
      # properly doing a pmap/pmean on the loss/gradients, respectively.
    return learner

  def build_actor(
      random_key: networks_lib.PRNGKey,
      replay: reverb.Client,
      variable_source: core.VariableSource,
      counter: counting.Counter,
      actor_id: ActorId,
  ) -> environment_loop.EnvironmentLoop:
    """The actor process."""
    environment_key, actor_key = jax.random.split(random_key)
    # Create environment and policy core.

    # Environments normally require uint32 as a seed.
    environment = experiment.environment_factory(
        utils.sample_uint32(environment_key))
    environment_spec = specs.make_environment_spec(environment)

    networks = experiment.network_factory(environment_spec)
    policy_network = config.make_policy(
        experiment=experiment,
        networks=networks,
        environment_spec=environment_spec,
        evaluation=False)
    adder = experiment.builder.make_adder(replay, environment_spec,
                                          policy_network)
    actor = experiment.builder.make_actor(actor_key, policy_network,
                                          environment_spec, variable_source,
                                          adder)

    # Create logger and counter.
    counter = counting.Counter(counter, 'actor')
    logger = experiment.logger_factory('actor', counter.get_steps_key(),
                                       actor_id)
    # Create the loop to connect environment and agent.
    return environment_loop.EnvironmentLoop(
        environment, actor, counter, logger, observers=experiment.observers)

  if not program:
    program = lp.Program(name=name)

  # Root PRNG key; split sequentially below so every node gets a unique key.
  key = jax.random.PRNGKey(experiment.seed)

  # The replay node is created up front (and only a handle taken) so the
  # same node object can be colocated with the learner when requested.
  replay_node = lp.ReverbNode(
      build_replay,
      checkpoint_time_delta_minutes=(
          checkpointing_config.replay_checkpointing_time_delta_minutes))
  replay = replay_node.create_handle()

  counter = program.add_node(lp.CourierNode(build_counter), label='counter')

  if experiment.max_num_actor_steps is not None:
    # StepsLimiter stops the program once the actor-step budget is reached.
    program.add_node(
        lp.CourierNode(lp_utils.StepsLimiter, counter,
                       experiment.max_num_actor_steps),
        label='counter')

  learner_key, key = jax.random.split(key)
  learner_node = lp.CourierNode(build_learner, learner_key, replay, counter)
  learner = learner_node.create_handle()
  variable_sources = [learner]

  if multithreading_colocate_learner_and_reverb:
    # Learner and replay share one process as colocated threads.
    program.add_node(
        lp.MultiThreadingColocation([learner_node, replay_node]),
        label='learner')
  else:
    program.add_node(replay_node, label='replay')
    with program.group('learner'):
      program.add_node(learner_node)

  # Maybe create secondary learners, necessary when using multi-host
  # accelerators.
  # Warning! If you set num_learner_nodes > 1, make sure the learner class
  # does the appropriate pmap/pmean operations on the loss/gradients,
  # respectively.
  for _ in range(1, num_learner_nodes):
    learner_key, key = jax.random.split(key)
    variable_sources.append(
        program.add_node(
            lp.CourierNode(
                build_learner, learner_key, replay,
                primary_learner=learner)))
    # NOTE: Secondary learners are used to load-balance get_variables calls,
    # which is why they get added to the list of available variable sources.
    # NOTE: Only the primary learner checkpoints.
    # NOTE: Do not pass the counter to the secondary learners to avoid
    # double counting of learner steps.

  with program.group('actor'):
    # Create all actor threads.
    *actor_keys, key = jax.random.split(key, num_actors + 1)
    # Cycle through the variable sources so actors are assigned round-robin.
    variable_sources = itertools.cycle(variable_sources)
    actor_nodes = [
        lp.CourierNode(build_actor, akey, replay, vsource, counter, aid)
        for aid, (akey,
                  vsource) in enumerate(zip(actor_keys, variable_sources))
    ]

    # Create (maybe colocated) actor nodes.
    if num_actors_per_node == 1:
      for actor_node in actor_nodes:
        program.add_node(actor_node)
    else:
      # Group actors into multi-threaded nodes of `num_actors_per_node`.
      for i in range(0, num_actors, num_actors_per_node):
        program.add_node(
            lp.MultiThreadingColocation(
                actor_nodes[i:i + num_actors_per_node]))

  for evaluator in experiment.get_evaluator_factories():
    evaluator_key, key = jax.random.split(key)
    program.add_node(
        lp.CourierNode(evaluator, evaluator_key, learner, counter,
                       experiment.builder.make_actor),
        label='evaluator')

  if make_snapshot_models and checkpointing_config:
    program.add_node(
        lp.CourierNode(build_model_saver, learner), label='model_saver')

  return program