Example #1
  def build(self, name='r2d2'):
    """Build the distributed agent topology."""
    program = lp.Program(name=name)

    with program.group('replay'):
      replay = program.add_node(lp.ReverbNode(self.replay))

    with program.group('counter'):
      counter = program.add_node(lp.CourierNode(self.counter))

    with program.group('learner'):
      learner = program.add_node(lp.CourierNode(self.learner, replay, counter))

    with program.group('cacher'):
      cacher = program.add_node(
          lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000))

    with program.group('evaluator'):
      program.add_node(lp.CourierNode(self.evaluator, cacher, counter))

    # Generate an epsilon for each actor.
    epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0)

    with program.group('actor'):
      for epsilon in epsilons:
        program.add_node(
            lp.CourierNode(self.actor, replay, cacher, counter, epsilon))

    return program
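
The build() methods in these examples only assemble the Launchpad topology; nothing runs until the returned program is handed to a launcher. Below is a minimal local-launch sketch, assuming `agent` is an instance of the distributed R2D2 builder class whose build() is shown above (the variable name and launch settings are illustrative, not part of the example):

  # Hypothetical usage: run the program built above in a single process
  # using Launchpad's local multi-threaded launcher.
  import launchpad as lp

  program = agent.build(name='r2d2')
  lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_THREADING)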
Example #2
    def build(self, name='MCTS'):
        """Builds the distributed agent topology."""
        program = lp.Program(name=name)

        with program.group('replay'):
            replay = program.add_node(lp.ReverbNode(self.replay),
                                      label='replay')

        with program.group('counter'):
            counter = program.add_node(lp.CourierNode(counting.Counter),
                                       label='counter')

        with program.group('learner'):
            learner = program.add_node(lp.CourierNode(self.learner, replay,
                                                      counter),
                                       label='learner')

        with program.group('evaluator'):
            program.add_node(lp.CourierNode(self.evaluator, learner, counter),
                             label='evaluator')

        with program.group('actor'):
            program.add_node(lp.CourierNode(self.actor, replay, learner,
                                            counter),
                             label='actor')

        return program
Example #3
  def build(self, name='impala'):
    """Build the distributed agent topology."""
    program = lp.Program(name=name)

    with program.group('replay'):
      queue = program.add_node(lp.ReverbNode(self.queue))

    with program.group('counter'):
      counter = program.add_node(lp.CourierNode(self.counter))

    with program.group('learner'):
      learner = program.add_node(
          lp.CourierNode(self.learner, queue, counter))

    with program.group('evaluator'):
      program.add_node(lp.CourierNode(self.evaluator, learner, counter))

    with program.group('cacher'):
      cacher = program.add_node(
          lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000))

    with program.group('actor'):
      for _ in range(self._num_actors):
        program.add_node(lp.CourierNode(self.actor, queue, cacher, counter))

    return program
Example #4
File: system.py, Project: NetColby/DNRL
    def build(self, name: str = "madqn") -> Any:
        """Build the distributed system as a graph program.

        Args:
            name (str, optional): system name. Defaults to "madqn".

        Returns:
            Any: graph program for distributed system training.
        """

        program = lp.Program(name=name)

        with program.group("replay"):
            replay = program.add_node(lp.ReverbNode(self.replay))

        with program.group("counter"):
            counter = program.add_node(lp.CourierNode(self.counter, self._checkpoint))

        if self._max_executor_steps:
            with program.group("coordinator"):
                _ = program.add_node(lp.CourierNode(self.coordinator, counter))

        with program.group("trainer"):
            trainer = program.add_node(lp.CourierNode(self.trainer, replay, counter))

        with program.group("evaluator"):
            program.add_node(lp.CourierNode(self.evaluator, trainer, counter, trainer))

        if not self._num_caches:
            # Use the trainer as a single variable source.
            sources = [trainer]
        else:
            with program.group("cacher"):
                # Create a set of trainer caches.
                sources = []
                for _ in range(self._num_caches):
                    cacher = program.add_node(
                        lp.CacherNode(
                            trainer, refresh_interval_ms=2000, stale_after_ms=4000
                        )
                    )
                    sources.append(cacher)

        with program.group("executor"):
            # Add executors which pull round-robin from our variable sources.
            for executor_id in range(self._num_exectors):
                source = sources[executor_id % len(sources)]
                program.add_node(
                    lp.CourierNode(
                        self.executor,
                        executor_id,
                        replay,
                        source,
                        counter,
                        trainer,
                    )
                )

        return program
Example #5
  def build(self, name='agent', program=None):
    """Build the distributed agent topology."""
    if not program:
      program = lp.Program(name=name)

    key = jax.random.PRNGKey(self._seed)

    replay_node = lp.ReverbNode(self.replay)
    with program.group('replay'):
      if self._multithreading_colocate_learner_and_reverb:
        replay = replay_node.create_handle()
      else:
        replay = program.add_node(replay_node)

    with program.group('counter'):
      counter = program.add_node(lp.CourierNode(self.counter))
      if self._max_number_of_steps is not None:
        _ = program.add_node(
            lp.CourierNode(self.coordinator, counter,
                           self._max_number_of_steps))

    learner_key, key = jax.random.split(key)
    learner_node = lp.CourierNode(self.learner, learner_key, replay, counter)
    with program.group('learner'):
      if self._multithreading_colocate_learner_and_reverb:
        learner = learner_node.create_handle()
        program.add_node(
            lp.MultiThreadingColocation([learner_node, replay_node]))
      else:
        learner = program.add_node(learner_node)

    def make_actor(random_key,
                   policy_network,
                   variable_source):
      return self._builder.make_actor(
          random_key, policy_network, variable_source=variable_source)

    with program.group('evaluator'):
      for evaluator in self._evaluator_factories:
        evaluator_key, key = jax.random.split(key)
        program.add_node(
            lp.CourierNode(evaluator, evaluator_key, learner, counter,
                           make_actor))

    with program.group('actor'):
      for actor_id in range(self._num_actors):
        actor_key, key = jax.random.split(key)
        program.add_node(
            lp.CourierNode(self.actor, actor_key, replay, learner, counter,
                           actor_id))

    return program
Example #6
    def build(self, name='dqn'):
        """Build the distributed agent topology."""
        program = lp.Program(name=name)

        with program.group('replay'):
            replay = program.add_node(lp.ReverbNode(self.replay))

        with program.group('counter'):
            counter = program.add_node(lp.CourierNode(self.counter))

            if self._max_actor_steps:
                program.add_node(
                    lp.CourierNode(self.coordinator, counter,
                                   self._max_actor_steps))

        with program.group('learner'):
            learner = program.add_node(
                lp.CourierNode(self.learner, replay, counter))

        with program.group('evaluator'):
            program.add_node(lp.CourierNode(self.evaluator, learner, counter))

        # Generate an epsilon for each actor.
        epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4),
                           axis=0)

        with program.group('cacher'):
            # Create a set of learner caches.
            sources = []
            for _ in range(self._num_caches):
                cacher = program.add_node(
                    lp.CacherNode(learner,
                                  refresh_interval_ms=2000,
                                  stale_after_ms=4000))
                sources.append(cacher)

        with program.group('actor'):
            # Add actors which pull round-robin from our variable sources.
            for actor_id, epsilon in enumerate(epsilons):
                source = sources[actor_id % len(sources)]
                program.add_node(
                    lp.CourierNode(self.actor, replay, source, counter,
                                   epsilon))

        return program
Example #7
    def build(self, name='dmpo'):
        """Build the distributed agent topology."""
        program = lp.Program(name=name)

        with program.group('replay'):
            replay = program.add_node(lp.ReverbNode(self.replay))

        with program.group('counter'):
            counter = program.add_node(lp.CourierNode(self.counter))

            if self._max_actor_steps:
                _ = program.add_node(
                    lp.CourierNode(self.coordinator, counter,
                                   self._max_actor_steps))

        with program.group('learner'):
            learner = program.add_node(
                lp.CourierNode(self.learner, replay, counter))

        with program.group('evaluator'):
            program.add_node(lp.CourierNode(self.evaluator, learner, counter))

        if not self._num_caches:
            # Use our learner as a single variable source.
            sources = [learner]
        else:
            with program.group('cacher'):
                # Create a set of learner caches.
                sources = []
                for _ in range(self._num_caches):
                    cacher = program.add_node(
                        lp.CacherNode(learner,
                                      refresh_interval_ms=2000,
                                      stale_after_ms=4000))
                    sources.append(cacher)

        with program.group('actor'):
            # Add actors which pull round-robin from our variable sources.
            for actor_id in range(self._num_actors):
                source = sources[actor_id % len(sources)]
                program.add_node(
                    lp.CourierNode(self.actor, replay, source, counter,
                                   actor_id))

        return program
Example #8
def make_distributed_experiment(
        experiment: config.ExperimentConfig,
        num_actors: int,
        *,
        num_learner_nodes: int = 1,
        num_actors_per_node: int = 1,
        multithreading_colocate_learner_and_reverb: bool = False,
        checkpointing_config: Optional[config.CheckpointingConfig] = None,
        make_snapshot_models: Optional[SnapshotModelFactory] = None,
        name='agent',
        program: Optional[lp.Program] = None):
    """Builds distributed agent based on a builder."""

    if multithreading_colocate_learner_and_reverb and num_learner_nodes > 1:
        raise ValueError(
            'Replay and learner colocation is not yet supported when the learner is'
            ' spread across multiple nodes (num_learner_nodes > 1). Please contact'
            ' Acme devs if this is a feature you want. Got:'
            '\tmultithreading_colocate_learner_and_reverb='
            f'{multithreading_colocate_learner_and_reverb}'
            f'\tnum_learner_nodes={num_learner_nodes}.')

    if checkpointing_config is None:
        checkpointing_config = config.CheckpointingConfig()

    def build_replay():
        """The replay storage."""
        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))
        network = experiment.network_factory(spec)
        policy = config.make_policy(experiment=experiment,
                                    networks=network,
                                    environment_spec=spec,
                                    evaluation=False)
        return experiment.builder.make_replay_tables(spec, policy)

    def build_model_saver(variable_source: core.VariableSource):
        environment = experiment.environment_factory(0)
        spec = specs.make_environment_spec(environment)
        networks = experiment.network_factory(spec)
        models = make_snapshot_models(networks, spec)
        # TODO(raveman): Decouple checkpointing and snapshotting configs.
        return snapshotter.JAXSnapshotter(variable_source=variable_source,
                                          models=models,
                                          path=checkpointing_config.directory,
                                          subdirectory='snapshots',
                                          add_uid=checkpointing_config.add_uid)

    def build_counter():
        return savers.CheckpointingRunner(
            counting.Counter(),
            key='counter',
            subdirectory='counter',
            time_delta_minutes=5,
            directory=checkpointing_config.directory,
            add_uid=checkpointing_config.add_uid,
            max_to_keep=checkpointing_config.max_to_keep)

    def build_learner(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        counter: Optional[counting.Counter] = None,
        primary_learner: Optional[core.Learner] = None,
    ):
        """The Learning part of the agent."""

        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = experiment.network_factory(spec)

        iterator = experiment.builder.make_dataset_iterator(replay)
        # make_dataset_iterator is responsible for putting data onto the
        # appropriate training devices; prefetch is applied here so that data
        # is copied over in the background.
        iterator = utils.prefetch(iterable=iterator, buffer_size=1)
        counter = counting.Counter(counter, 'learner')
        learner = experiment.builder.make_learner(random_key, networks,
                                                  iterator,
                                                  experiment.logger_factory,
                                                  spec, replay, counter)

        if primary_learner is None:
            learner = savers.CheckpointingRunner(
                learner,
                key='learner',
                subdirectory='learner',
                time_delta_minutes=5,
                directory=checkpointing_config.directory,
                add_uid=checkpointing_config.add_uid,
                max_to_keep=checkpointing_config.max_to_keep)
        else:
            learner.restore(primary_learner.save())
            # NOTE: This initially synchronizes secondary learner states with the
            # primary one. Further synchronization should be handled by the learner
            # properly doing a pmap/pmean on the loss/gradients, respectively.

        return learner

    def build_actor(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        actor_id: ActorId,
    ) -> environment_loop.EnvironmentLoop:
        """The actor process."""
        environment_key, actor_key = jax.random.split(random_key)
        # Create environment and policy core.

        # Environments normally require uint32 as a seed.
        environment = experiment.environment_factory(
            utils.sample_uint32(environment_key))
        environment_spec = specs.make_environment_spec(environment)

        networks = experiment.network_factory(environment_spec)
        policy_network = config.make_policy(experiment=experiment,
                                            networks=networks,
                                            environment_spec=environment_spec,
                                            evaluation=False)
        adder = experiment.builder.make_adder(replay, environment_spec,
                                              policy_network)
        actor = experiment.builder.make_actor(actor_key, policy_network,
                                              environment_spec,
                                              variable_source, adder)

        # Create logger and counter.
        counter = counting.Counter(counter, 'actor')
        logger = experiment.logger_factory('actor', counter.get_steps_key(),
                                           actor_id)
        # Create the loop to connect environment and agent.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=experiment.observers)

    if not program:
        program = lp.Program(name=name)

    key = jax.random.PRNGKey(experiment.seed)

    replay_node = lp.ReverbNode(
        build_replay,
        checkpoint_time_delta_minutes=(
            checkpointing_config.replay_checkpointing_time_delta_minutes))
    replay = replay_node.create_handle()

    counter = program.add_node(lp.CourierNode(build_counter), label='counter')

    if experiment.max_num_actor_steps is not None:
        program.add_node(lp.CourierNode(lp_utils.StepsLimiter, counter,
                                        experiment.max_num_actor_steps),
                         label='counter')

    learner_key, key = jax.random.split(key)
    learner_node = lp.CourierNode(build_learner, learner_key, replay, counter)
    learner = learner_node.create_handle()
    variable_sources = [learner]

    if multithreading_colocate_learner_and_reverb:
        program.add_node(lp.MultiThreadingColocation(
            [learner_node, replay_node]),
                         label='learner')
    else:
        program.add_node(replay_node, label='replay')

        with program.group('learner'):
            program.add_node(learner_node)

            # Maybe create secondary learners, necessary when using multi-host
            # accelerators.
            # Warning! If you set num_learner_nodes > 1, make sure the learner class
            # does the appropriate pmap/pmean operations on the loss/gradients,
            # respectively.
            for _ in range(1, num_learner_nodes):
                learner_key, key = jax.random.split(key)
                variable_sources.append(
                    program.add_node(
                        lp.CourierNode(build_learner,
                                       learner_key,
                                       replay,
                                       primary_learner=learner)))
                # NOTE: Secondary learners are used to load-balance get_variables calls,
                # which is why they get added to the list of available variable sources.
                # NOTE: Only the primary learner checkpoints.
                # NOTE: Do not pass the counter to the secondary learners to avoid
                # double counting of learner steps.

    with program.group('actor'):
        # Create all actor threads.
        *actor_keys, key = jax.random.split(key, num_actors + 1)
        variable_sources = itertools.cycle(variable_sources)
        actor_nodes = [
            lp.CourierNode(build_actor, akey, replay, vsource, counter, aid)
            for aid, (akey,
                      vsource) in enumerate(zip(actor_keys, variable_sources))
        ]

        # Create (maybe colocated) actor nodes.
        if num_actors_per_node == 1:
            for actor_node in actor_nodes:
                program.add_node(actor_node)
        else:
            for i in range(0, num_actors, num_actors_per_node):
                program.add_node(
                    lp.MultiThreadingColocation(
                        actor_nodes[i:i + num_actors_per_node]))

    for evaluator in experiment.get_evaluator_factories():
        evaluator_key, key = jax.random.split(key)
        program.add_node(lp.CourierNode(evaluator, evaluator_key, learner,
                                        counter,
                                        experiment.builder.make_actor),
                         label='evaluator')

    if make_snapshot_models and checkpointing_config:
        program.add_node(lp.CourierNode(build_model_saver, learner),
                         label='model_saver')

    return program
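
For completeness, a hedged sketch of driving make_distributed_experiment end to end. The factory names are placeholders, and the ExperimentConfig fields shown are assumptions about the Acme config module aliased as `config` in the example above:

  # Hypothetical driver. `my_builder`, `my_environment_factory` and
  # `my_network_factory` are assumed to be defined elsewhere for the agent
  # in question.
  import launchpad as lp
  from acme.jax.experiments import config

  experiment = config.ExperimentConfig(
      builder=my_builder,
      environment_factory=my_environment_factory,
      network_factory=my_network_factory,
      seed=0,
      max_num_actor_steps=1_000_000)

  program = make_distributed_experiment(experiment, num_actors=4)
  lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_THREADING)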