def make_program(num_producers: int) -> lp.Program:
  """Define the distributed program topology."""
  program = lp.Program('consumer_producers')

  # Use `program.group()` to group homogeneous nodes.
  with program.group('producer'):
    # Add a `CourierNode` to the program. `lp.CourierNode()` takes the producer
    # constructor and its arguments, and exposes it as an RPC server.
    # `program.add_node(lp.CourierNode(...))` returns a handle to this server.
    # These handles can then be passed to other nodes.
    producers = [
        program.add_node(lp.CourierNode(Producer)) for _ in range(num_producers)
    ]

  # Launch a single consumer that connects to the list of producers.
  # Note: Using `label` here creates a group containing a single node.
  node = lp.CourierNode(
      Consumer,
      producers=producers,
      stop_fn=lp.make_program_stopper(FLAGS.lp_launch_type))
  program.add_node(node, label='consumer')

  return program
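For context, a minimal launch sketch (assumed, not part of the original example): it wires `make_program` into an absl-style main, reusing the `FLAGS.lp_launch_type` flag already referenced above.

# Minimal launch sketch; assumes Producer/Consumer and the absl FLAGS object
# (with an lp_launch_type flag) are defined in the surrounding module.
import launchpad as lp

def main(_):
  program = make_program(num_producers=3)
  # Start every node of the program with the configured launch type.
  lp.launch(program, launch_type=FLAGS.lp_launch_type)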
Example #2
    def build(self, name='agent', program: Optional[lp.Program] = None):
        """Build the distributed agent topology."""
        if not program:
            program = lp.Program(name=name)

        key = jax.random.PRNGKey(self._seed)

        replay_node = lp.ReverbNode(
            self.replay,
            checkpoint_time_delta_minutes=(
                self._checkpointing_config
                .replay_checkpointing_time_delta_minutes))

        with program.group('replay'):
            if self._multithreading_colocate_learner_and_reverb:
                replay = replay_node.create_handle()
            else:
                replay = program.add_node(replay_node)

        with program.group('counter'):
            counter = program.add_node(lp.CourierNode(self.counter))
            if self._max_number_of_steps is not None:
                _ = program.add_node(
                    lp.CourierNode(self.coordinator, counter,
                                   self._max_number_of_steps))

        learner_key, key = jax.random.split(key)
        learner_node = lp.CourierNode(self.learner, learner_key, replay,
                                      counter)
        with program.group('learner'):
            if self._multithreading_colocate_learner_and_reverb:
                learner = learner_node.create_handle()
                program.add_node(
                    lp.MultiThreadingColocation([learner_node, replay_node]))
            else:
                learner = program.add_node(learner_node)

        def make_actor(random_key: networks_lib.PRNGKey,
                       policy_network: PolicyNetwork,
                       variable_source: core.VariableSource) -> core.Actor:
            return self._builder.make_actor(random_key,
                                            policy_network,
                                            variable_source=variable_source)

        with program.group('evaluator'):
            for evaluator in self._evaluator_factories:
                evaluator_key, key = jax.random.split(key)
                program.add_node(
                    lp.CourierNode(evaluator, evaluator_key, learner, counter,
                                   make_actor))

        with program.group('actor'):
            for actor_id in range(self._num_actors):
                actor_key, key = jax.random.split(key)
                program.add_node(
                    lp.CourierNode(self.actor, actor_key, replay, learner,
                                   counter, actor_id))

        if self._make_snapshot_models and self._checkpointing_config:
            with program.group('model_saver'):
                program.add_node(lp.CourierNode(self.model_saver, learner))

        return program
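The create_handle() calls above let a node's handle be passed to other constructors before the node itself is added, which is what makes the learner/replay colocation possible. A stripped-down sketch of that idiom follows; `make_reverb_tables` and `make_learner` are hypothetical factories, not part of this builder.

# Illustrative sketch of the handle-before-add colocation idiom.
import launchpad as lp

program = lp.Program('colocation_sketch')
replay_node = lp.ReverbNode(make_reverb_tables)  # hypothetical factory
# The handle exists before the node is added, so it can be handed to the
# learner constructor even though both nodes run in the same process.
learner_node = lp.CourierNode(make_learner, replay_node.create_handle())
program.add_node(
    lp.MultiThreadingColocation([learner_node, replay_node]), label='learner')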
def make_distributed_experiment(
        experiment: config.ExperimentConfig,
        num_actors: int,
        *,
        num_learner_nodes: int = 1,
        num_actors_per_node: int = 1,
        multithreading_colocate_learner_and_reverb: bool = False,
        checkpointing_config: Optional[config.CheckpointingConfig] = None,
        make_snapshot_models: Optional[SnapshotModelFactory] = None,
        name='agent',
        program: Optional[lp.Program] = None):
    """Builds distributed agent based on a builder."""

    if multithreading_colocate_learner_and_reverb and num_learner_nodes > 1:
        raise ValueError(
            'Replay and learner colocation is not yet supported when the learner is'
            ' spread across multiple nodes (num_learner_nodes > 1). Please contact'
            ' Acme devs if this is a feature you want. Got:'
            '\tmultithreading_colocate_learner_and_reverb='
            f'{multithreading_colocate_learner_and_reverb}'
            f'\tnum_learner_nodes={num_learner_nodes}.')

    if checkpointing_config is None:
        checkpointing_config = config.CheckpointingConfig()

    def build_replay():
        """The replay storage."""
        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))
        network = experiment.network_factory(spec)
        policy = config.make_policy(experiment=experiment,
                                    networks=network,
                                    environment_spec=spec,
                                    evaluation=False)
        return experiment.builder.make_replay_tables(spec, policy)

    def build_model_saver(variable_source: core.VariableSource):
        environment = experiment.environment_factory(0)
        spec = specs.make_environment_spec(environment)
        networks = experiment.network_factory(spec)
        models = make_snapshot_models(networks, spec)
        # TODO(raveman): Decouple checkpointing and snapshotting configs.
        return snapshotter.JAXSnapshotter(variable_source=variable_source,
                                          models=models,
                                          path=checkpointing_config.directory,
                                          subdirectory='snapshots',
                                          add_uid=checkpointing_config.add_uid)

    def build_counter():
        return savers.CheckpointingRunner(
            counting.Counter(),
            key='counter',
            subdirectory='counter',
            time_delta_minutes=5,
            directory=checkpointing_config.directory,
            add_uid=checkpointing_config.add_uid,
            max_to_keep=checkpointing_config.max_to_keep)

    def build_learner(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        counter: Optional[counting.Counter] = None,
        primary_learner: Optional[core.Learner] = None,
    ):
        """The Learning part of the agent."""

        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = experiment.network_factory(spec)

        iterator = experiment.builder.make_dataset_iterator(replay)
        # make_dataset_iterator is responsible for putting data onto the
        # appropriate training devices, so we only apply prefetch here so that
        # data is copied over in the background.
        iterator = utils.prefetch(iterable=iterator, buffer_size=1)
        counter = counting.Counter(counter, 'learner')
        learner = experiment.builder.make_learner(random_key, networks,
                                                  iterator,
                                                  experiment.logger_factory,
                                                  spec, replay, counter)

        if primary_learner is None:
            learner = savers.CheckpointingRunner(
                learner,
                key='learner',
                subdirectory='learner',
                time_delta_minutes=5,
                directory=checkpointing_config.directory,
                add_uid=checkpointing_config.add_uid,
                max_to_keep=checkpointing_config.max_to_keep)
        else:
            learner.restore(primary_learner.save())
            # NOTE: This initially synchronizes secondary learner states with the
            # primary one. Further synchronization should be handled by the learner
            # properly doing a pmap/pmean on the loss/gradients, respectively.

        return learner

    def build_actor(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        actor_id: ActorId,
    ) -> environment_loop.EnvironmentLoop:
        """The actor process."""
        environment_key, actor_key = jax.random.split(random_key)
        # Create environment and policy core.

        # Environments normally require uint32 as a seed.
        environment = experiment.environment_factory(
            utils.sample_uint32(environment_key))
        environment_spec = specs.make_environment_spec(environment)

        networks = experiment.network_factory(environment_spec)
        policy_network = config.make_policy(experiment=experiment,
                                            networks=networks,
                                            environment_spec=environment_spec,
                                            evaluation=False)
        adder = experiment.builder.make_adder(replay, environment_spec,
                                              policy_network)
        actor = experiment.builder.make_actor(actor_key, policy_network,
                                              environment_spec,
                                              variable_source, adder)

        # Create logger and counter.
        counter = counting.Counter(counter, 'actor')
        logger = experiment.logger_factory('actor', counter.get_steps_key(),
                                           actor_id)
        # Create the loop to connect environment and agent.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=experiment.observers)

    if not program:
        program = lp.Program(name=name)

    key = jax.random.PRNGKey(experiment.seed)

    replay_node = lp.ReverbNode(
        build_replay,
        checkpoint_time_delta_minutes=(
            checkpointing_config.replay_checkpointing_time_delta_minutes))
    replay = replay_node.create_handle()

    counter = program.add_node(lp.CourierNode(build_counter), label='counter')

    if experiment.max_num_actor_steps is not None:
        program.add_node(lp.CourierNode(lp_utils.StepsLimiter, counter,
                                        experiment.max_num_actor_steps),
                         label='counter')

    learner_key, key = jax.random.split(key)
    learner_node = lp.CourierNode(build_learner, learner_key, replay, counter)
    learner = learner_node.create_handle()
    variable_sources = [learner]

    if multithreading_colocate_learner_and_reverb:
        program.add_node(lp.MultiThreadingColocation(
            [learner_node, replay_node]),
                         label='learner')
    else:
        program.add_node(replay_node, label='replay')

        with program.group('learner'):
            program.add_node(learner_node)

            # Maybe create secondary learners, necessary when using multi-host
            # accelerators.
            # Warning! If you set num_learner_nodes > 1, make sure the learner class
            # does the appropriate pmap/pmean operations on the loss/gradients,
            # respectively.
            for _ in range(1, num_learner_nodes):
                learner_key, key = jax.random.split(key)
                variable_sources.append(
                    program.add_node(
                        lp.CourierNode(build_learner,
                                       learner_key,
                                       replay,
                                       primary_learner=learner)))
                # NOTE: Secondary learners are used to load-balance get_variables calls,
                # which is why they get added to the list of available variable sources.
                # NOTE: Only the primary learner checkpoints.
                # NOTE: Do not pass the counter to the secondary learners to avoid
                # double counting of learner steps.

    with program.group('actor'):
        # Create all actor threads.
        *actor_keys, key = jax.random.split(key, num_actors + 1)
        variable_sources = itertools.cycle(variable_sources)
        actor_nodes = [
            lp.CourierNode(build_actor, akey, replay, vsource, counter, aid)
            for aid, (akey, vsource) in enumerate(
                zip(actor_keys, variable_sources))
        ]

        # Create (maybe colocated) actor nodes.
        if num_actors_per_node == 1:
            for actor_node in actor_nodes:
                program.add_node(actor_node)
        else:
            for i in range(0, num_actors, num_actors_per_node):
                program.add_node(
                    lp.MultiThreadingColocation(
                        actor_nodes[i:i + num_actors_per_node]))

    for evaluator in experiment.get_evaluator_factories():
        evaluator_key, key = jax.random.split(key)
        program.add_node(lp.CourierNode(evaluator, evaluator_key, learner,
                                        counter,
                                        experiment.builder.make_actor),
                         label='evaluator')

    if make_snapshot_models and checkpointing_config:
        program.add_node(lp.CourierNode(build_model_saver, learner),
                         label='model_saver')

    return program
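A hedged usage sketch for make_distributed_experiment; `my_experiment` is a placeholder for a fully populated config.ExperimentConfig built elsewhere.

# Hypothetical usage; my_experiment must be a config.ExperimentConfig.
program = make_distributed_experiment(
    experiment=my_experiment,
    num_actors=4,
    num_actors_per_node=2,
    multithreading_colocate_learner_and_reverb=True)
lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_THREADING)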
Example #4
    def build(self, name: str = "dial") -> Any:
        """Build the distributed system as a graph program.

        Args:
            name (str, optional): system name. Defaults to "dial".

        Returns:
            Any: graph program for distributed system training.
        """

        program = lp.Program(name=name)
        counter = None

        with program.group("replay"):
            replay = program.add_node(lp.ReverbNode(self.replay))

        if self._checkpoint:
            with program.group("counter"):
                counter = program.add_node(lp.CourierNode(self.counter))

        if self._max_executor_steps:
            with program.group("coordinator"):
                _ = program.add_node(lp.CourierNode(self.coordinator, counter))

        with program.group("trainer"):
            trainer = program.add_node(
                lp.CourierNode(self.trainer, replay, counter))

        with program.group("evaluator"):
            program.add_node(
                lp.CourierNode(self.evaluator, trainer, counter, trainer))

        if not self._num_caches:
            # Use the trainer as a single variable source.
            sources = [trainer]
        else:
            with program.group("cacher"):
                # Create a set of trainer caches.
                sources = []
                for _ in range(self._num_caches):
                    cacher = program.add_node(
                        lp.CacherNode(trainer,
                                      refresh_interval_ms=2000,
                                      stale_after_ms=4000))
                    sources.append(cacher)

        with program.group("executor"):
            # Add executors which pull round-robin from our variable sources.
            for executor_id in range(self._num_exectors):
                source = sources[executor_id % len(sources)]
                program.add_node(
                    lp.CourierNode(
                        self.executor,
                        executor_id,
                        replay,
                        source,
                        counter,
                        trainer,
                    ))

        return program
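A hypothetical way this builder might be driven end to end; the system class name and its constructor arguments are placeholders, not the real API.

# Placeholder names: the concrete system class is defined elsewhere.
system = MyDIALSystem(num_executors=4, num_caches=2)
program = system.build(name="dial")
lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING)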
def make_distributed_offline_experiment(
        experiment: config.OfflineExperimentConfig,
        *,
        checkpointing_config: Optional[config.CheckpointingConfig] = None,
        make_snapshot_models: Optional[SnapshotModelFactory] = None,
        name='agent',
        program: Optional[lp.Program] = None):
    """Builds distributed agent based on a builder."""

    if checkpointing_config is None:
        checkpointing_config = config.CheckpointingConfig()

    def build_model_saver(variable_source: core.VariableSource):
        environment = experiment.environment_factory(0)
        spec = specs.make_environment_spec(environment)
        networks = experiment.network_factory(spec)
        models = make_snapshot_models(networks, spec)
        # TODO(raveman): Decouple checkpointing and snapshotting configs.
        return snapshotter.JAXSnapshotter(variable_source=variable_source,
                                          models=models,
                                          path=checkpointing_config.directory,
                                          add_uid=checkpointing_config.add_uid)

    def build_counter():
        return savers.CheckpointingRunner(
            counting.Counter(),
            key='counter',
            subdirectory='counter',
            time_delta_minutes=5,
            directory=checkpointing_config.directory,
            add_uid=checkpointing_config.add_uid,
            max_to_keep=checkpointing_config.max_to_keep)

    def build_learner(
        random_key: networks_lib.PRNGKey,
        counter: Optional[counting.Counter] = None,
    ):
        """The Learning part of the agent."""

        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = experiment.network_factory(spec)

        dataset_key, random_key = jax.random.split(random_key)
        iterator = experiment.demonstration_dataset_factory(dataset_key)
        # The demonstration dataset factory is responsible for putting data
        # onto the appropriate training devices, so we only apply prefetch here
        # so that data is copied over in the background.
        iterator = utils.prefetch(iterable=iterator, buffer_size=1)
        counter = counting.Counter(counter, 'learner')
        learner = experiment.builder.make_learner(
            random_key=random_key,
            networks=networks,
            dataset=iterator,
            logger_fn=experiment.logger_factory,
            environment_spec=spec,
            counter=counter)

        learner = savers.CheckpointingRunner(
            learner,
            key='learner',
            subdirectory='learner',
            time_delta_minutes=5,
            directory=checkpointing_config.directory,
            add_uid=checkpointing_config.add_uid,
            max_to_keep=checkpointing_config.max_to_keep)

        return learner

    if not program:
        program = lp.Program(name=name)

    key = jax.random.PRNGKey(experiment.seed)

    counter = program.add_node(lp.CourierNode(build_counter), label='counter')

    if experiment.max_num_learner_steps is not None:
        program.add_node(lp.CourierNode(lp_utils.StepsLimiter,
                                        counter,
                                        experiment.max_num_learner_steps,
                                        steps_key='learner_steps'),
                         label='counter')

    learner_key, key = jax.random.split(key)
    learner_node = lp.CourierNode(build_learner, learner_key, counter)
    learner = learner_node.create_handle()
    program.add_node(learner_node, label='learner')

    for evaluator in experiment.get_evaluator_factories():
        evaluator_key, key = jax.random.split(key)
        program.add_node(lp.CourierNode(evaluator, evaluator_key, learner,
                                        counter,
                                        experiment.builder.make_actor),
                         label='evaluator')

    if make_snapshot_models and checkpointing_config:
        program.add_node(lp.CourierNode(build_model_saver, learner),
                         label='model_saver')

    return program
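And a matching launch sketch for the offline variant; `my_offline_experiment` is a placeholder for a config.OfflineExperimentConfig.

# Hypothetical usage of the offline experiment builder.
program = make_distributed_offline_experiment(experiment=my_offline_experiment)
lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING)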