def make_program(num_producers: int) -> lp.Program:
  """Define the distributed program topology."""
  program = lp.Program('consumer_producers')

  # Use `program.group()` to group homogeneous nodes.
  with program.group('producer'):
    # Add a `CourierNode` to the program. `lp.CourierNode()` takes the producer
    # constructor and its arguments, and exposes it as an RPC server.
    # `program.add_node(lp.CourierNode(...))` returns a handle to this server.
    # These handles can then be passed to other nodes.
    producers = [
        program.add_node(lp.CourierNode(Producer))
        for _ in range(num_producers)
    ]

  # Launch a single consumer that connects to the list of producers.
  # Note: the use of `label` here actually creates a group with a single node.
  node = lp.CourierNode(
      Consumer,
      producers=producers,
      stop_fn=lp.make_program_stopper(FLAGS.lp_launch_type))
  program.add_node(node, label='consumer')

  return program
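
# Usage sketch (hedged): how a program like the one above is typically
# launched. `Producer`, `Consumer` and the absl `app`/`FLAGS` setup are assumed
# to be defined elsewhere in the surrounding example, and `num_producers` is a
# hypothetical flag; `lp.launch` is the standard Launchpad entry point.
def main(_):
  program = make_program(num_producers=FLAGS.num_producers)
  # Run all nodes of the program; the backend (threads, processes, ...) is
  # selected via --lp_launch_type.
  lp.launch(program)


if __name__ == '__main__':
  app.run(main)
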
def build(self, name='agent', program: Optional[lp.Program] = None):
  """Build the distributed agent topology."""
  if not program:
    program = lp.Program(name=name)

  key = jax.random.PRNGKey(self._seed)

  replay_node = lp.ReverbNode(
      self.replay,
      checkpoint_time_delta_minutes=(
          self._checkpointing_config.replay_checkpointing_time_delta_minutes))
  with program.group('replay'):
    if self._multithreading_colocate_learner_and_reverb:
      replay = replay_node.create_handle()
    else:
      replay = program.add_node(replay_node)

  with program.group('counter'):
    counter = program.add_node(lp.CourierNode(self.counter))
    if self._max_number_of_steps is not None:
      _ = program.add_node(
          lp.CourierNode(self.coordinator, counter, self._max_number_of_steps))

  learner_key, key = jax.random.split(key)
  learner_node = lp.CourierNode(self.learner, learner_key, replay, counter)
  with program.group('learner'):
    if self._multithreading_colocate_learner_and_reverb:
      learner = learner_node.create_handle()
      program.add_node(
          lp.MultiThreadingColocation([learner_node, replay_node]))
    else:
      learner = program.add_node(learner_node)

  def make_actor(random_key: networks_lib.PRNGKey,
                 policy_network: PolicyNetwork,
                 variable_source: core.VariableSource) -> core.Actor:
    return self._builder.make_actor(
        random_key, policy_network, variable_source=variable_source)

  with program.group('evaluator'):
    for evaluator in self._evaluator_factories:
      evaluator_key, key = jax.random.split(key)
      program.add_node(
          lp.CourierNode(evaluator, evaluator_key, learner, counter,
                         make_actor))

  with program.group('actor'):
    for actor_id in range(self._num_actors):
      actor_key, key = jax.random.split(key)
      program.add_node(
          lp.CourierNode(self.actor, actor_key, replay, learner, counter,
                         actor_id))

  if self._make_snapshot_models and self._checkpointing_config:
    with program.group('model_saver'):
      program.add_node(lp.CourierNode(self.model_saver, learner))

  return program
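
# Usage sketch (hedged): running a program produced by `build()` above.
# `agent` stands in for an instance of the enclosing distributed-agent class
# (its constructor is not shown here); `lp.launch`, `lp.LaunchType` and the
# `terminal` option are standard Launchpad APIs, but the particular settings
# are illustrative only.
program = agent.build(name='my_agent')
lp.launch(
    program,
    launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING,  # one process per node
    terminal='current_terminal')
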
def make_distributed_experiment(
    experiment: config.ExperimentConfig,
    num_actors: int,
    *,
    num_learner_nodes: int = 1,
    num_actors_per_node: int = 1,
    multithreading_colocate_learner_and_reverb: bool = False,
    checkpointing_config: Optional[config.CheckpointingConfig] = None,
    make_snapshot_models: Optional[SnapshotModelFactory] = None,
    name='agent',
    program: Optional[lp.Program] = None):
  """Builds distributed agent based on a builder."""

  if multithreading_colocate_learner_and_reverb and num_learner_nodes > 1:
    raise ValueError(
        'Replay and learner colocation is not yet supported when the learner is'
        ' spread across multiple nodes (num_learner_nodes > 1). Please contact'
        ' Acme devs if this is a feature you want. Got:'
        '\tmultithreading_colocate_learner_and_reverb='
        f'{multithreading_colocate_learner_and_reverb}'
        f'\tnum_learner_nodes={num_learner_nodes}.')

  if checkpointing_config is None:
    checkpointing_config = config.CheckpointingConfig()

  def build_replay():
    """The replay storage."""
    dummy_seed = 1
    spec = (
        experiment.environment_spec or
        specs.make_environment_spec(experiment.environment_factory(dummy_seed)))
    network = experiment.network_factory(spec)
    policy = config.make_policy(
        experiment=experiment,
        networks=network,
        environment_spec=spec,
        evaluation=False)
    return experiment.builder.make_replay_tables(spec, policy)

  def build_model_saver(variable_source: core.VariableSource):
    environment = experiment.environment_factory(0)
    spec = specs.make_environment_spec(environment)
    networks = experiment.network_factory(spec)
    models = make_snapshot_models(networks, spec)
    # TODO(raveman): Decouple checkpointing and snapshotting configs.
    return snapshotter.JAXSnapshotter(
        variable_source=variable_source,
        models=models,
        path=checkpointing_config.directory,
        subdirectory='snapshots',
        add_uid=checkpointing_config.add_uid)

  def build_counter():
    return savers.CheckpointingRunner(
        counting.Counter(),
        key='counter',
        subdirectory='counter',
        time_delta_minutes=5,
        directory=checkpointing_config.directory,
        add_uid=checkpointing_config.add_uid,
        max_to_keep=checkpointing_config.max_to_keep)

  def build_learner(
      random_key: networks_lib.PRNGKey,
      replay: reverb.Client,
      counter: Optional[counting.Counter] = None,
      primary_learner: Optional[core.Learner] = None,
  ):
    """The Learning part of the agent."""

    dummy_seed = 1
    spec = (
        experiment.environment_spec or
        specs.make_environment_spec(experiment.environment_factory(dummy_seed)))

    # Creates the networks to optimize (online) and target networks.
    networks = experiment.network_factory(spec)

    iterator = experiment.builder.make_dataset_iterator(replay)
    # make_dataset_iterator is responsible for putting data onto appropriate
    # training devices, so here we apply prefetch, so that data is copied over
    # in the background.
    iterator = utils.prefetch(iterable=iterator, buffer_size=1)
    counter = counting.Counter(counter, 'learner')
    learner = experiment.builder.make_learner(random_key, networks, iterator,
                                              experiment.logger_factory, spec,
                                              replay, counter)

    if primary_learner is None:
      learner = savers.CheckpointingRunner(
          learner,
          key='learner',
          subdirectory='learner',
          time_delta_minutes=5,
          directory=checkpointing_config.directory,
          add_uid=checkpointing_config.add_uid,
          max_to_keep=checkpointing_config.max_to_keep)
    else:
      learner.restore(primary_learner.save())
      # NOTE: This initially synchronizes secondary learner states with the
      # primary one. Further synchronization should be handled by the learner
      # properly doing a pmap/pmean on the loss/gradients, respectively.

    return learner

  def build_actor(
      random_key: networks_lib.PRNGKey,
      replay: reverb.Client,
      variable_source: core.VariableSource,
      counter: counting.Counter,
      actor_id: ActorId,
  ) -> environment_loop.EnvironmentLoop:
    """The actor process."""
    environment_key, actor_key = jax.random.split(random_key)

    # Create environment and policy core.
    # Environments normally require uint32 as a seed.
    environment = experiment.environment_factory(
        utils.sample_uint32(environment_key))
    environment_spec = specs.make_environment_spec(environment)

    networks = experiment.network_factory(environment_spec)
    policy_network = config.make_policy(
        experiment=experiment,
        networks=networks,
        environment_spec=environment_spec,
        evaluation=False)
    adder = experiment.builder.make_adder(replay, environment_spec,
                                          policy_network)
    actor = experiment.builder.make_actor(actor_key, policy_network,
                                          environment_spec, variable_source,
                                          adder)

    # Create logger and counter.
    counter = counting.Counter(counter, 'actor')
    logger = experiment.logger_factory('actor', counter.get_steps_key(),
                                       actor_id)

    # Create the loop to connect environment and agent.
    return environment_loop.EnvironmentLoop(
        environment, actor, counter, logger, observers=experiment.observers)

  if not program:
    program = lp.Program(name=name)

  key = jax.random.PRNGKey(experiment.seed)

  replay_node = lp.ReverbNode(
      build_replay,
      checkpoint_time_delta_minutes=(
          checkpointing_config.replay_checkpointing_time_delta_minutes))
  replay = replay_node.create_handle()

  counter = program.add_node(lp.CourierNode(build_counter), label='counter')

  if experiment.max_num_actor_steps is not None:
    program.add_node(
        lp.CourierNode(lp_utils.StepsLimiter, counter,
                       experiment.max_num_actor_steps),
        label='counter')

  learner_key, key = jax.random.split(key)
  learner_node = lp.CourierNode(build_learner, learner_key, replay, counter)
  learner = learner_node.create_handle()
  variable_sources = [learner]
  if multithreading_colocate_learner_and_reverb:
    program.add_node(
        lp.MultiThreadingColocation([learner_node, replay_node]),
        label='learner')
  else:
    program.add_node(replay_node, label='replay')
    with program.group('learner'):
      program.add_node(learner_node)

  # Maybe create secondary learners, necessary when using multi-host
  # accelerators.
  # Warning! If you set num_learner_nodes > 1, make sure the learner class
  # does the appropriate pmap/pmean operations on the loss/gradients,
  # respectively.
  for _ in range(1, num_learner_nodes):
    learner_key, key = jax.random.split(key)
    variable_sources.append(
        program.add_node(
            lp.CourierNode(build_learner, learner_key, replay,
                           primary_learner=learner)))
    # NOTE: Secondary learners are used to load-balance get_variables calls,
    # which is why they get added to the list of available variable sources.
    # NOTE: Only the primary learner checkpoints.
    # NOTE: Do not pass the counter to the secondary learners to avoid
    # double counting of learner steps.

  with program.group('actor'):
    # Create all actor threads.
    *actor_keys, key = jax.random.split(key, num_actors + 1)
    variable_sources = itertools.cycle(variable_sources)
    actor_nodes = [
        lp.CourierNode(build_actor, akey, replay, vsource, counter, aid)
        for aid, (akey, vsource) in enumerate(zip(actor_keys, variable_sources))
    ]

    # Create (maybe colocated) actor nodes.
    if num_actors_per_node == 1:
      for actor_node in actor_nodes:
        program.add_node(actor_node)
    else:
      for i in range(0, num_actors, num_actors_per_node):
        program.add_node(
            lp.MultiThreadingColocation(
                actor_nodes[i:i + num_actors_per_node]))

  for evaluator in experiment.get_evaluator_factories():
    evaluator_key, key = jax.random.split(key)
    program.add_node(
        lp.CourierNode(evaluator, evaluator_key, learner, counter,
                       experiment.builder.make_actor),
        label='evaluator')

  if make_snapshot_models and checkpointing_config:
    program.add_node(
        lp.CourierNode(build_model_saver, learner), label='model_saver')

  return program
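
# Usage sketch (hedged) for `make_distributed_experiment`. The config below is
# schematic: `my_builder`, `my_environment_factory` and `my_network_factory`
# are placeholders for project-specific factories; only the function signature
# above and `lp.launch` are taken as given.
experiment_config = config.ExperimentConfig(
    builder=my_builder,                          # placeholder builder
    environment_factory=my_environment_factory,  # placeholder: seed -> env
    network_factory=my_network_factory,          # placeholder: spec -> networks
    seed=0,
    max_num_actor_steps=1_000_000)
program = make_distributed_experiment(
    experiment=experiment_config,
    num_actors=4,
    num_actors_per_node=2)  # colocate pairs of actors in a single process
lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING)
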
def build(self, name: str = "dial") -> Any:
    """Build the distributed system as a graph program.

    Args:
        name (str, optional): system name. Defaults to "dial".

    Returns:
        Any: graph program for distributed system training.
    """
    program = lp.Program(name=name)
    counter = None

    with program.group("replay"):
        replay = program.add_node(lp.ReverbNode(self.replay))

    if self._checkpoint:
        with program.group("counter"):
            counter = program.add_node(lp.CourierNode(self.counter))

    if self._max_executor_steps:
        with program.group("coordinator"):
            _ = program.add_node(lp.CourierNode(self.coordinator, counter))

    with program.group("trainer"):
        trainer = program.add_node(
            lp.CourierNode(self.trainer, replay, counter))

    with program.group("evaluator"):
        program.add_node(
            lp.CourierNode(self.evaluator, trainer, counter, trainer))

    if not self._num_caches:
        # Use the trainer as a single variable source.
        sources = [trainer]
    else:
        with program.group("cacher"):
            # Create a set of trainer caches.
            sources = []
            for _ in range(self._num_caches):
                cacher = program.add_node(
                    lp.CacherNode(
                        trainer, refresh_interval_ms=2000, stale_after_ms=4000))
                sources.append(cacher)

    with program.group("executor"):
        # Add executors which pull round-robin from our variable sources.
        for executor_id in range(self._num_exectors):
            source = sources[executor_id % len(sources)]
            program.add_node(
                lp.CourierNode(
                    self.executor,
                    executor_id,
                    replay,
                    source,
                    counter,
                    trainer,
                ))

    return program
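
# Usage sketch (hedged): launching the DIAL graph program built above.
# `system` stands in for an instance of the enclosing system class, and the
# launch settings are illustrative; only `lp.launch` and `lp.LaunchType` are
# assumed.
program = system.build(name="dial")
lp.launch(
    program,
    launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING,
    terminal="current_terminal")
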
def make_distributed_offline_experiment(
    experiment: config.OfflineExperimentConfig,
    *,
    checkpointing_config: Optional[config.CheckpointingConfig] = None,
    make_snapshot_models: Optional[SnapshotModelFactory] = None,
    name='agent',
    program: Optional[lp.Program] = None):
  """Builds distributed agent based on a builder."""

  if checkpointing_config is None:
    checkpointing_config = config.CheckpointingConfig()

  def build_model_saver(variable_source: core.VariableSource):
    environment = experiment.environment_factory(0)
    spec = specs.make_environment_spec(environment)
    networks = experiment.network_factory(spec)
    models = make_snapshot_models(networks, spec)
    # TODO(raveman): Decouple checkpointing and snapshotting configs.
    return snapshotter.JAXSnapshotter(
        variable_source=variable_source,
        models=models,
        path=checkpointing_config.directory,
        add_uid=checkpointing_config.add_uid)

  def build_counter():
    return savers.CheckpointingRunner(
        counting.Counter(),
        key='counter',
        subdirectory='counter',
        time_delta_minutes=5,
        directory=checkpointing_config.directory,
        add_uid=checkpointing_config.add_uid,
        max_to_keep=checkpointing_config.max_to_keep)

  def build_learner(
      random_key: networks_lib.PRNGKey,
      counter: Optional[counting.Counter] = None,
  ):
    """The Learning part of the agent."""

    dummy_seed = 1
    spec = (
        experiment.environment_spec or
        specs.make_environment_spec(experiment.environment_factory(dummy_seed)))

    # Creates the networks to optimize (online) and target networks.
    networks = experiment.network_factory(spec)

    dataset_key, random_key = jax.random.split(random_key)
    iterator = experiment.demonstration_dataset_factory(dataset_key)
    # make_demonstrations is responsible for putting data onto appropriate
    # training devices, so here we apply prefetch, so that data is copied over
    # in the background.
    iterator = utils.prefetch(iterable=iterator, buffer_size=1)
    counter = counting.Counter(counter, 'learner')
    learner = experiment.builder.make_learner(
        random_key=random_key,
        networks=networks,
        dataset=iterator,
        logger_fn=experiment.logger_factory,
        environment_spec=spec,
        counter=counter)

    learner = savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        directory=checkpointing_config.directory,
        add_uid=checkpointing_config.add_uid,
        max_to_keep=checkpointing_config.max_to_keep)

    return learner

  if not program:
    program = lp.Program(name=name)

  key = jax.random.PRNGKey(experiment.seed)

  counter = program.add_node(lp.CourierNode(build_counter), label='counter')

  if experiment.max_num_learner_steps is not None:
    program.add_node(
        lp.CourierNode(
            lp_utils.StepsLimiter,
            counter,
            experiment.max_num_learner_steps,
            steps_key='learner_steps'),
        label='counter')

  learner_key, key = jax.random.split(key)
  learner_node = lp.CourierNode(build_learner, learner_key, counter)
  learner = learner_node.create_handle()
  program.add_node(learner_node, label='learner')

  for evaluator in experiment.get_evaluator_factories():
    evaluator_key, key = jax.random.split(key)
    program.add_node(
        lp.CourierNode(evaluator, evaluator_key, learner, counter,
                       experiment.builder.make_actor),
        label='evaluator')

  if make_snapshot_models and checkpointing_config:
    program.add_node(
        lp.CourierNode(build_model_saver, learner), label='model_saver')

  return program
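
# Usage sketch (hedged) for the offline variant. `my_offline_builder`,
# `my_environment_factory`, `my_network_factory` and
# `my_demonstration_dataset_factory` are placeholders for project-specific
# factories; only the signature above and `lp.launch` are taken as given.
offline_config = config.OfflineExperimentConfig(
    builder=my_offline_builder,
    environment_factory=my_environment_factory,
    network_factory=my_network_factory,
    demonstration_dataset_factory=my_demonstration_dataset_factory,
    seed=0,
    max_num_learner_steps=100_000)
program = make_distributed_offline_experiment(experiment=offline_config)
lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING)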