示例#1
0
    def test_add_remove_actors(self):
        """Tests that the async manager can properly add and remove actors.

        Walks an AsyncRequestsManager through three phases: an empty manager
        (calls must not launch anything), adding one worker and launching a
        request, and removing that worker while two requests are still in
        flight (those requests must still be collectable via `get_ready()`).
        """

        workers = []
        manager = AsyncRequestsManager(
            workers, max_remote_requests_in_flight_per_worker=2
        )
        # Plain asserts throughout (consistent with the other checks in this
        # test) so pytest can show the compared values on failure.
        assert (
            len(manager._all_workers)
            == len(manager._remote_requests_in_flight)
            == len(manager._pending_to_actor)
            == len(manager._pending_remotes)
            == 0
        ), "We should have no workers in this case."

        assert not manager.call(lambda w: w.task()), (
            "Task shouldn't have been "
            "launched since there are no "
            "workers in the manager."
        )
        worker = RemoteRLlibActor.remote(sleep_time=0.1)
        manager.add_workers(worker)
        manager.call(lambda w: w.task())
        assert (
            len(manager._remote_requests_in_flight[worker])
            == len(manager._pending_to_actor)
            == len(manager._all_workers)
            == len(manager._pending_remotes)
            == 1
        ), "We should have 1 worker and 1 pending request"
        # Give the in-flight task time to finish, then drain it: the removal
        # phase below expects to start with zero pending requests.
        time.sleep(3)
        manager.get_ready()

        # test worker removal
        for i in range(2):
            manager.call(lambda w: w.task())
            assert len(manager._pending_remotes) == i + 1
        manager.remove_workers(worker)
        assert len(manager._all_workers) == 0, (
            "We should have no workers that we can schedule tasks to"
        )
        assert (
            len(manager._pending_remotes) == 2
            and len(manager._pending_to_actor) == 2
        ), "We should still have 2 pending requests in flight from the worker"
        # Requests launched before removal must still be retrievable.
        time.sleep(3)
        result = manager.get_ready()
        assert (
            len(result) == 1
            and len(result[worker]) == 2
            and len(manager._pending_remotes) == 0
            and len(manager._pending_to_actor) == 0
        ), "We should have 2 ready results from the worker and no pending requests"
示例#2
0
class Impala(Algorithm):
    """Importance weighted actor/learner architecture (IMPALA) Algorithm

    == Overview of data flow in IMPALA ==
    1. Policy evaluation in parallel across `num_workers` actors produces
       batches of size `rollout_fragment_length * num_envs_per_worker`.
    2. If enabled, the replay buffer stores and produces batches of size
       `rollout_fragment_length * num_envs_per_worker`.
    3. If enabled, the minibatch ring buffer stores and replays batches of
       size `train_batch_size` up to `num_sgd_iter` times per batch.
    4. The learner thread executes data parallel SGD across `num_gpus` GPUs
       on batches of size `train_batch_size`.
    """
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        """Return IMPALA's default config as a plain dict."""
        return ImpalaConfig().to_dict()

    @override(Algorithm)
    def get_default_policy_class(
            self,
            config: PartialAlgorithmConfigDict) -> Optional[Type[Policy]]:
        """Pick the policy class for the configured framework.

        With `vtrace` enabled the IMPALA (V-trace) policy is returned,
        otherwise the corresponding A3C policy. `"torch"` and `"tf"` are
        matched explicitly; any other framework value falls through to the
        TF2 / A3C-TF policies.
        """
        if config["framework"] == "torch":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_torch_policy import (
                    ImpalaTorchPolicy, )

                return ImpalaTorchPolicy
            else:
                from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

                return A3CTorchPolicy
        elif config["framework"] == "tf":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF1Policy

                return ImpalaTF1Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy
        else:
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF2Policy

                return ImpalaTF2Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy

    @override(Algorithm)
    def validate_config(self, config):
        """Validate (and in two places patch) an IMPALA config in place.

        Raises:
            ValueError: On a negative `entropy_coeff`, more aggregation
                workers than rollout workers, or `_separate_vf_optimizer`
                with a non-tf framework.
        """
        # Call the super class' validation method first.
        super().validate_config(config)

        # Check the IMPALA specific config.

        # Map the deprecated key onto its replacement (warning only).
        if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
            deprecation_warning("num_data_loader_buffers",
                                "num_multi_gpu_tower_stacks",
                                error=False)
            config["num_multi_gpu_tower_stacks"] = config[
                "num_data_loader_buffers"]

        if config["entropy_coeff"] < 0.0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")

        # Check whether worker to aggregation-worker ratio makes sense.
        if config["num_aggregation_workers"] > config["num_workers"]:
            raise ValueError(
                "`num_aggregation_workers` must be smaller than or equal "
                "`num_workers`! Aggregation makes no sense otherwise.")
        elif config["num_aggregation_workers"] > config["num_workers"] / 2:
            logger.warning(
                "`num_aggregation_workers` should be significantly smaller "
                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
                "less.")

        # If two separate optimizers/loss terms used for tf, must also set
        # `_tf_policy_handles_more_than_one_loss` to True.
        if config["_separate_vf_optimizer"] is True:
            # Only supported to tf so far.
            # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
            #  models to return separate sets of weights in order to create
            #  the different torch optimizers).
            if config["framework"] not in ["tf", "tf2", "tfe"]:
                raise ValueError(
                    "`_separate_vf_optimizer` only supported to tf so far!")
            if config["_tf_policy_handles_more_than_one_loss"] is False:
                logger.warning(
                    "`_tf_policy_handles_more_than_one_loss` must be set to "
                    "True, for TFPolicy to support more than one loss "
                    "term/optimizer! Auto-setting it to True.")
                config["_tf_policy_handles_more_than_one_loss"] = True

    @override(Algorithm)
    def setup(self, config: PartialAlgorithmConfigDict):
        """Create the training-step machinery (non-execution-plan mode only).

        Sets up either aggregation actors (tree aggregation) or a local
        mix-in replay buffer, the sampling request manager, and starts the
        learner thread. In execution-plan mode, `execution_plan` builds its
        own learner thread instead.
        """
        super().setup(config)

        if self.config["_disable_execution_plan_api"]:
            # Create extra aggregation workers and assign each rollout worker to
            # one of them.
            # Train batches that are complete and wait for a learner queue slot.
            self.batches_to_place_on_learner = []
            # Sample batches collected so far towards the next train batch.
            self.batch_being_built = []
            if self.config["num_aggregation_workers"] > 0:
                # This spawns `num_aggregation_workers` actors that aggregate
                # experiences coming from RolloutWorkers in parallel. We force
                # colocation on the same node (localhost) to maximize data bandwidth
                # between them and the learner.
                localhost = platform.node()
                assert localhost != "", (
                    "ERROR: Cannot determine local node name! "
                    "`platform.node()` returned empty string.")
                all_co_located = create_colocated_actors(
                    actor_specs=[
                        # (class, args, kwargs={}, count=1)
                        (
                            AggregatorWorker,
                            [
                                self.config,
                            ],
                            {},
                            self.config["num_aggregation_workers"],
                        )
                    ],
                    node=localhost,
                )
                self._aggregator_workers = [
                    actor for actor_groups in all_co_located
                    for actor in actor_groups
                ]
                self._aggregator_actor_manager = AsyncRequestsManager(
                    self._aggregator_workers,
                    max_remote_requests_in_flight_per_worker=self.
                    config["max_requests_in_flight_per_aggregator_worker"],
                    ray_wait_timeout_s=self.
                    config["timeout_s_aggregator_manager"],
                )

            else:
                # Create our local mixin buffer if the num of aggregation workers is 0.
                # Capacity falls back to 1 slot when no replay slots are configured.
                self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
                    capacity=(self.config["replay_buffer_num_slots"]
                              if self.config["replay_buffer_num_slots"] > 0
                              else 1),
                    replay_ratio=self.config["replay_ratio"],
                    replay_mode=ReplayMode.LOCKSTEP,
                )

            # Results come back as ObjectRefs (`return_object_refs=True`); they
            # are resolved later by the processing step.
            self._sampling_actor_manager = AsyncRequestsManager(
                self.workers.remote_workers(),
                max_remote_requests_in_flight_per_worker=self.
                config["max_requests_in_flight_per_sampler_worker"],
                return_object_refs=True,
                ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
            )

            # Create and start the learner thread.
            self._learner_thread = make_learner_thread(
                self.workers.local_worker(), self.config)
            self._learner_thread.start()
            # Workers that returned samples since their last weight broadcast.
            self.workers_that_need_updates = set()

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        """Run one asynchronous IMPALA iteration.

        Collects ready samples, processes them (tree aggregation or local
        mix-in replay), builds train batches, feeds the learner queue,
        gathers learner results and broadcasts weights when due.

        Returns:
            The learner info built from this iteration's learner results.
        """
        unprocessed_sample_batches = self.get_samples_from_workers()

        self.workers_that_need_updates |= unprocessed_sample_batches.keys()

        if self.config["num_aggregation_workers"] > 0:
            batch = self.process_experiences_tree_aggregation(
                unprocessed_sample_batches)
        else:
            batch = self.process_experiences_directly(
                unprocessed_sample_batches)

        self.concatenate_batches_and_pre_queue(batch)
        self.place_processed_samples_on_learner_queue()
        train_results = self.process_trained_results()

        self.update_workers_if_necessary()

        return train_results

    @staticmethod
    @override(Algorithm)
    def execution_plan(workers, config, **kwargs):
        """Build the legacy (execution-plan API) IMPALA dataflow."""
        assert (
            len(kwargs) == 0
        ), "IMPALA execution_plan does NOT take any additional parameters"

        if config["num_aggregation_workers"] > 0:
            train_batches = gather_experiences_tree_aggregation(
                workers, config)
        else:
            train_batches = gather_experiences_directly(workers, config)

        # Start the learner thread.
        learner_thread = make_learner_thread(workers.local_worker(), config)
        learner_thread.start()

        # This sub-flow sends experiences to the learner.
        enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            enqueue_op = enqueue_op.zip_with_source_actor().for_each(
                BroadcastUpdateLearnerWeights(
                    learner_thread,
                    workers,
                    broadcast_interval=config["broadcast_interval"],
                ))

        def record_steps_trained(item):
            count, fetches, _ = item
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
            metrics.counters[STEPS_TRAINED_COUNTER] += count
            return item

        # This sub-flow updates the steps trained counter based on learner
        # output.
        dequeue_op = Dequeue(
            learner_thread.outqueue,
            check=learner_thread.is_alive).for_each(record_steps_trained)

        merged_op = Concurrently([enqueue_op, dequeue_op],
                                 mode="async",
                                 output_indexes=[1])

        # Callback for APPO to use to update KL, target network periodically.
        # The input to the callback is the learner fetches dict.
        if config["after_train_step"]:
            merged_op = merged_op.for_each(lambda t: t[1]).for_each(
                config["after_train_step"](workers, config))

        return StandardMetricsReporting(merged_op, workers, config).for_each(
            learner_thread.add_learner_metrics)

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        """Describe the resources needed as a PlacementGroupFactory.

        `config` is layered over the default config; every value, including
        the placement strategy, is read from that merged dict.
        """
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[{
                # Driver + Aggregation Workers:
                # Force to be on same node to maximize data bandwidth
                # between aggregation workers and the learner (driver).
                # Aggregation workers tree-aggregate experiences collected
                # from RolloutWorkers (n rollout workers map to m
                # aggregation workers, where m < n) and always use 1 CPU
                # each.
                "CPU":
                cf["num_cpus_for_driver"] + cf["num_aggregation_workers"],
                "GPU":
                0 if cf["_fake_gpus"] else cf["num_gpus"],
            }] + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                    **cf["custom_resources_per_worker"],
                } for _ in range(cf["num_workers"])
            ] + ([
                {
                    # Evaluation (remote) workers.
                    # Note: The local eval worker is located on the driver
                    # CPU or not even created iff >0 eval workers.
                    "CPU":
                    eval_config.get("num_cpus_per_worker",
                                    cf["num_cpus_per_worker"]),
                    "GPU":
                    eval_config.get("num_gpus_per_worker",
                                    cf["num_gpus_per_worker"]),
                    **eval_config.get(
                        "custom_resources_per_worker",
                        cf["custom_resources_per_worker"],
                    ),
                } for _ in range(cf["evaluation_num_workers"])
            ] if cf["evaluation_interval"] else []),
            # Read from the merged dict like every other key above, so a
            # default-config value is honored when the user sets none.
            strategy=cf.get("placement_strategy", "PACK"),
        )

    def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]):
        """Concatenate batches that are being returned from rollout workers

        Batches are appended to `self.batch_being_built`. Whenever its total
        `count` reaches `train_batch_size`, the collected batches are
        concatenated into one train batch and queued in
        `self.batches_to_place_on_learner`.

        Args:
            batches: batches of experiences from rollout workers

        """
        def aggregate_into_larger_batch():
            if (sum(b.count for b in self.batch_being_built) >=
                    self.config["train_batch_size"]):
                batch_to_add = SampleBatch.concat_samples(
                    self.batch_being_built)
                self.batches_to_place_on_learner.append(batch_to_add)
                self.batch_being_built = []

        for batch in batches:
            self.batch_being_built.append(batch)
            aggregate_into_larger_batch()

    def get_samples_from_workers(self) -> Dict[ActorHandle, List[SampleBatch]]:
        """Collect whatever samples are ready.

        Returns:
            Dict mapping each worker to its ready results. With remote
            workers these are ObjectRefs (the sampling manager is created
            with `return_object_refs=True`); with only a local worker, a
            single freshly sampled batch is returned directly.
        """
        # Perform asynchronous sampling on all (remote) rollout workers.
        if self.workers.remote_workers():
            self._sampling_actor_manager.call_on_all_available(
                lambda worker: worker.sample())
            sample_batches: Dict[
                ActorHandle,
                List[ObjectRef]] = self._sampling_actor_manager.get_ready()
        else:
            # only sampling on the local worker
            sample_batches = {
                self.workers.local_worker():
                [self.workers.local_worker().sample()]
            }
        return sample_batches

    def place_processed_samples_on_learner_queue(self) -> None:
        """Move pre-queued train batches onto the learner thread's input queue.

        Batches are taken from `self.batches_to_place_on_learner` in FIFO
        order. While the input queue is full this waits with a bounded,
        blocking `put` (counting each timed-out attempt in
        `num_times_learner_queue_full`) instead of busy-looping on a
        non-blocking `put`, which burned CPU and contended with the learner
        thread that has to drain this very queue.

        Raises:
            RuntimeError: If the learner thread is dead while the queue is
                full (nothing would ever drain it).
        """
        # A blocking put returns as soon as a slot frees up; the timeout only
        # bounds how often learner liveness is re-checked.
        put_timeout_s = 0.1
        self._counters["num_samples_added_to_queue"] = 0

        while self.batches_to_place_on_learner:
            batch = self.batches_to_place_on_learner[0]
            try:
                self._learner_thread.inqueue.put(
                    batch, block=True, timeout=put_timeout_s)
            except queue.Full:
                self._counters["num_times_learner_queue_full"] += 1
                if not self._learner_thread.is_alive():
                    raise RuntimeError(
                        "The learner thread died while the learner queue "
                        "was full")
                continue
            self.batches_to_place_on_learner.pop(0)
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.count
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Accumulate over all batches placed in this call (the counter
            # was reset above); plain assignment kept only the last batch.
            self._counters["num_samples_added_to_queue"] += batch.count

    def process_trained_results(self) -> ResultDict:
        """Drain the learner output queue and update the trained counters.

        Each queue entry is `(env_steps, agent_steps, learner_results)`.

        Returns:
            Learner info built from this call's non-empty learner results,
            or a deep copy of the learner thread's last `learner_info` when
            there were none.

        Raises:
            RuntimeError: If the learner thread is no longer alive.
        """
        # Get learner outputs/stats from output queue.
        final_learner_info = {}
        learner_infos = []
        num_env_steps_trained = 0
        num_agent_steps_trained = 0

        # Only drain what is already queued, so this never waits on the
        # learner.
        for _ in range(self._learner_thread.outqueue.qsize()):
            if self._learner_thread.is_alive():
                (
                    env_steps,
                    agent_steps,
                    learner_results,
                ) = self._learner_thread.outqueue.get(timeout=0.001)
                num_env_steps_trained += env_steps
                num_agent_steps_trained += agent_steps
                if learner_results:
                    learner_infos.append(learner_results)
            else:
                raise RuntimeError("The learner thread died in while training")
        if not learner_infos:
            final_learner_info = copy.deepcopy(
                self._learner_thread.learner_info)
        else:
            builder = LearnerInfoBuilder()
            for info in learner_infos:
                builder.add_learn_on_batch_results_multi_agent(info)
            final_learner_info = builder.finalize()

        # Update the steps trained counters.
        self._counters[
            STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
        self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
        self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained

        return final_learner_info

    def process_experiences_directly(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        """Pass sampled batches through the local mix-in replay buffer.

        Args:
            actor_to_sample_batches_refs: Worker -> list of sampled batches
                (ObjectRefs, or already-materialized batches for the
                local-worker-only case).

        Returns:
            The (possibly empty) list of non-empty batches replayed from the
            mix-in buffer, one replay attempt per incoming batch.
        """
        processed_batches = []
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        if not batches:
            return processed_batches
        # Remote samples arrive as ObjectRefs; local ones are already batches.
        if isinstance(batches[0], ray.ObjectRef):
            batches = ray.get(batches)
        for batch in batches:
            batch = batch.decompress_if_needed()
            self.local_mixin_buffer.add_batch(batch)
            batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
            if batch:
                processed_batches.append(batch)
        return processed_batches

    def process_experiences_tree_aggregation(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        """Hand sampled batches to the aggregation actors and collect results.

        Args:
            actor_to_sample_batches_refs: Worker -> list of sample batch refs.

        Returns:
            The list of processed batches that are ready from the
            aggregation actors (possibly from requests of earlier calls).
        """
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        ready_processed_batches = []
        for batch in batches:
            # Batch is passed via `fn_kwargs`, so the lambda captures no loop
            # variable.
            self._aggregator_actor_manager.call(
                lambda actor, b: actor.process_episodes(b),
                fn_kwargs={"b": batch})

        waiting_processed_sample_batches: Dict[
            ActorHandle,
            List[ObjectRef]] = self._aggregator_actor_manager.get_ready()
        for ready_sub_batches in waiting_processed_sample_batches.values():
            ready_processed_batches.extend(ready_sub_batches)

        return ready_processed_batches

    def update_workers_if_necessary(self) -> None:
        """Broadcast learner weights to workers that sampled since last sync.

        `steps_since_broadcast` is incremented once per call (i.e. per
        training step); a broadcast happens once it reaches
        `broadcast_interval` and at least one remote worker needs weights.
        The local worker's global vars are updated on every call.
        """
        # Only need to update workers if there are remote workers.
        global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
        self._counters["steps_since_broadcast"] += 1
        if (self.workers.remote_workers()
                and self._counters["steps_since_broadcast"] >=
                self.config["broadcast_interval"]
                and self.workers_that_need_updates):
            # Put the weights in the object store once, shared by all workers.
            weights = ray.put(self.workers.local_worker().get_weights())
            self._counters["steps_since_broadcast"] = 0
            self._learner_thread.weights_updated = False
            self._counters["num_weight_broadcasts"] += 1

            for worker in self.workers_that_need_updates:
                worker.set_weights.remote(weights, global_vars)
            self.workers_that_need_updates = set()

        # Update global vars of the local worker.
        self.workers.local_worker().set_global_vars(global_vars)

    @override(Algorithm)
    def on_worker_failures(self, removed_workers: List[ActorHandle],
                           new_workers: List[ActorHandle]):
        """Handle the failures of remote sampling workers

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(
                removed_workers, remove_in_flight_requests=True)
            self._sampling_actor_manager.add_workers(new_workers)

    @override(Algorithm)
    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
        """Add the learner thread's metrics to the base iteration results."""
        result = super()._compile_iteration_results(
            step_ctx=step_ctx, iteration_results=iteration_results)
        result = self._learner_thread.add_learner_metrics(
            result, overwrite_learner_info=False)
        return result
示例#3
0
class ApexDQN(DQN):
    @override(Trainable)
    def setup(self, config: PartialAlgorithmConfigDict):
        """Set up replay actors, request managers and the learner thread.

        The APEX-specific part only runs when `_disable_execution_plan_api`
        is True; otherwise the execution plan creates the learner thread and
        replay buffers itself.
        """
        super().setup(config)

        # Shortcut: If execution_plan, thread and buffer will be created in there.
        if self.config["_disable_execution_plan_api"] is False:
            return

        # Tag those workers (top 1/3rd indices) that we should collect episodes from
        # for metrics due to `PerWorkerEpsilonGreedy` exploration strategy.
        # Note: `-n // 3` floors the negative number, so this keeps the last
        # ceil(n / 3) workers.
        if self.workers.remote_workers():
            self._remote_workers_for_metrics = self.workers.remote_workers(
            )[-len(self.workers.remote_workers()) // 3:]

        num_replay_buffer_shards = self.config["optimizer"][
            "num_replay_buffer_shards"]

        # Create copy here so that we can modify without breaking other logic
        replay_actor_config = copy.deepcopy(
            self.config["replay_buffer_config"])

        # Split the total configured capacity across the shards.
        replay_actor_config["capacity"] = (
            self.config["replay_buffer_config"]["capacity"] //
            num_replay_buffer_shards)

        ReplayActor = ray.remote(num_cpus=0)(replay_actor_config["type"])

        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if replay_actor_config["replay_buffer_shards_colocated_with_driver"]:
            self._replay_actors = create_colocated_actors(
                actor_specs=[  # (class, args, kwargs={}, count)
                    (
                        ReplayActor,
                        None,
                        replay_actor_config,
                        num_replay_buffer_shards,
                    )
                ],
                node=platform.node(),  # localhost
            )[0]  # [0]=only one item in `actor_specs`.
        # Place replay buffer shards on any node(s).
        else:
            # Pass the config as keyword arguments, just like the colocated
            # branch does via the `kwargs` slot of its actor spec.
            # (`*replay_actor_config` would pass the dict's KEYS as
            # positional arguments.)
            self._replay_actors = [
                ReplayActor.remote(**replay_actor_config)
                for _ in range(num_replay_buffer_shards)
            ]
        self._replay_actor_manager = AsyncRequestsManager(
            self._replay_actors,
            max_remote_requests_in_flight_per_worker=self.
            config["max_requests_in_flight_per_replay_worker"],
            ray_wait_timeout_s=self.config["timeout_s_replay_manager"],
        )
        self._sampling_actor_manager = AsyncRequestsManager(
            self.workers.remote_workers(),
            max_remote_requests_in_flight_per_worker=self.
            config["max_requests_in_flight_per_sampler_worker"],
            ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
        )
        self.learner_thread = LearnerThread(self.workers.local_worker())
        self.learner_thread.start()
        # Per-worker samples collected since that worker's last weight sync.
        self.steps_since_update = defaultdict(int)
        weights = self.workers.local_worker().get_weights()
        # Object-store copy of the learner weights sent to sampling workers.
        self.curr_learner_weights = ray.put(weights)
        self.curr_num_samples_collected = 0
        # (replay_actor, sample_batch) pairs waiting for the learner queue.
        self.replay_sample_batches = []
        self._num_ts_trained_since_last_target_update = 0

    @classmethod
    @override(DQN)
    def get_default_config(cls) -> AlgorithmConfigDict:
        """Return the APEX-DQN default configuration as a plain dict."""
        apex_config = ApexDQNConfig()
        return apex_config.to_dict()

    @override(DQN)
    def validate_config(self, config):
        """Validate an APEX-DQN config.

        Multi-GPU learning is rejected up front; every other check is
        delegated to DQN's validator.

        Raises:
            ValueError: If `num_gpus` is greater than 1.
        """
        wants_multi_gpu = config["num_gpus"] > 1
        if wants_multi_gpu:
            raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
        # Remaining (DQN-level) validation.
        super().validate_config(config)

    @override(DQN)
    def training_step(self) -> ResultDict:
        """Run one asynchronous APEX-DQN iteration.

        Samples (storing into the replay actors), tallies sampled steps,
        syncs weights to remote samplers, requests replay samples for the
        learner, and applies the learner's new priorities.

        Returns:
            A deep copy of the learner thread's current `learner_info`.
        """
        sample_infos_by_worker = self.get_samples_and_store_to_replay_buffers()
        agent_steps_by_worker = defaultdict(int)
        counters = self._counters

        for sampler, infos in sample_infos_by_worker.items():
            for info in infos:
                agent_steps = info["agent_steps"]
                counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps
                counters[NUM_ENV_STEPS_SAMPLED] += info["env_steps"]
                agent_steps_by_worker[sampler] += agent_steps

        # Weight syncs only concern remote samplers (num_workers > 0).
        if self.workers.remote_workers():
            self.update_workers(agent_steps_by_worker)

        # Ask the replay actors for training data, hand it to the learner
        # thread, then push freshly computed priorities back.
        self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            agent_steps_by_worker)
        self.update_replay_sample_priority()

        return copy.deepcopy(self.learner_thread.learner_info)

    def get_samples_and_store_to_replay_buffers(self):
        """Sample experiences and store them in the replay actors.

        With no remote workers, the local worker samples one batch and the
        driver waits (`ray.get`) for a randomly chosen replay actor to add
        it. Otherwise each available remote worker runs the helper below,
        which samples and forwards its batch to a random replay actor
        itself, so only step-count dicts come back to the driver.

        Returns:
            Dict mapping each sampling worker to a list of dicts with
            "agent_steps" and "env_steps" keys.
        """
        remote_workers = self.workers.remote_workers()
        if not remote_workers:
            # num_workers == 0: sample and store from the driver.
            with self._timers[SAMPLE_TIMER]:
                local_worker = self.workers.local_worker()
                sampled_batch = local_worker.sample()
                target_replay_actor = random.choice(self._replay_actors)
                ray.get(target_replay_actor.add.remote(sampled_batch))
                step_counts = {
                    "agent_steps": sampled_batch.agent_steps(),
                    "env_steps": sampled_batch.env_steps(),
                }
                return {local_worker: [step_counts]}

        def remote_worker_sample_and_store(worker: RolloutWorker,
                                           replay_actors: List[ActorHandle]):
            # Executed on a sampling worker through the sampler manager. The
            # replay actors arrive via `fn_kwargs` rather than through `self`.
            sampled = worker.sample()
            chosen_actor = random.choice(replay_actors)
            chosen_actor.add.remote(sampled)
            return {
                "agent_steps": sampled.agent_steps(),
                "env_steps": sampled.env_steps(),
            }

        with self._timers[SAMPLE_TIMER]:
            self._sampling_actor_manager.call_on_all_available(
                remote_worker_sample_and_store,
                fn_kwargs={"replay_actors": self._replay_actors},
            )
            ready_stats = self._sampling_actor_manager.get_ready()
        return ready_stats

    def update_workers(self, _num_samples_ready: Dict[ActorHandle,
                                                      int]) -> int:
        """Update the remote workers that have samples ready.

        Each worker accumulates the samples it returned since its last sync.
        Once that reaches `config["optimizer"]["max_weight_sync_delay"]`, the
        current learner weights are sent to it and its tally is reset.

        Args:
            _num_samples_ready: A mapping from ActorHandle (RolloutWorker) to
                the number of samples returned by the remote worker.
        Returns:
            The number of remote workers whose weights were updated.
        """
        max_steps_weight_sync_delay = self.config["optimizer"][
            "max_weight_sync_delay"]
        # Update our local copy of the weights if the learner thread has updated
        # the learner worker's weights
        if self.learner_thread.weights_updated:
            self.learner_thread.weights_updated = False
            weights = self.workers.local_worker().get_weights()
            self.curr_learner_weights = ray.put(weights)

        num_workers_updated = 0
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            for (
                    remote_sampler_worker,
                    num_samples_collected,
            ) in _num_samples_ready.items():
                self.steps_since_update[
                    remote_sampler_worker] += num_samples_collected
                if (self.steps_since_update[remote_sampler_worker] >=
                        max_steps_weight_sync_delay):
                    remote_sampler_worker.set_weights.remote(
                        self.curr_learner_weights,
                        {"timestep": self._counters[STEPS_TRAINED_COUNTER]},
                    )
                    self.steps_since_update[remote_sampler_worker] = 0
                    # Count only actual syncs (previously every examined
                    # worker was counted, synced or not).
                    self._counters["num_weight_syncs"] += 1
                    num_workers_updated += 1
        return num_workers_updated

    def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            self, num_samples_collected: Dict[ActorHandle, int]) -> None:
        """Get samples from the replay buffer and place them on the learner queue.

        Args:
            num_samples_collected: A mapping from ActorHandle (RolloutWorker) to
                number of samples returned by the remote worker. This is used to
                implement training intensity which is the concept of triggering a
                certain amount of training based on the number of samples that have
                been collected since the last time that training was triggered.

        """
        def wait_on_replay_actors() -> None:
            """Collect the replay-actor sample requests that are ready.

            Each ready result is appended to `self.replay_sample_batches` as a
            `(replay_actor, sample_batch)` pair. Any waiting is presumably
            bounded by the manager's `ray_wait_timeout_s`
            (`timeout_s_replay_manager`) — confirm in AsyncRequestsManager.
            """
            _replay_samples_ready = self._replay_actor_manager.get_ready()

            for _replay_actor, _sample_batches in _replay_samples_ready.items(
            ):
                for _sample_batch in _sample_batches:
                    self.replay_sample_batches.append(
                        (_replay_actor, _sample_batch))

        train_batch_size = self.config["train_batch_size"]
        # Separate name: don't rebind the dict parameter to an int.
        total_samples_collected = sum(num_samples_collected.values())
        self.curr_num_samples_collected += total_samples_collected
        wait_on_replay_actors()
        if self.curr_num_samples_collected >= train_batch_size:
            training_intensity = int(self.config["training_intensity"] or 1)
            num_requests_to_launch = (
                self.curr_num_samples_collected /
                train_batch_size) * training_intensity
            num_requests_to_launch = max(1, round(num_requests_to_launch))
            self.curr_num_samples_collected = 0
            for _ in range(num_requests_to_launch):
                self._replay_actor_manager.call(
                    lambda actor, num_items: actor.sample(num_items),
                    fn_args=[train_batch_size],
                )
            wait_on_replay_actors()

        # add the sample batches to the learner queue
        while self.replay_sample_batches:
            item = self.replay_sample_batches[0]
            _, sample_batch = item
            # The replay buffer returns None if it has not been filled to the
            # minimum threshold yet. Check the batch itself: the old
            # `if item:` tested the (always truthy) tuple, and had it ever
            # been false the loop would have spun forever without popping.
            if sample_batch is None:
                self.replay_sample_batches.pop(0)
                continue
            try:
                self.learner_thread.inqueue.put(item, timeout=0.001)
            except queue.Full:
                # Keep the remaining items for the next call.
                break
            self.replay_sample_batches.pop(0)

    def update_replay_sample_priority(self) -> None:
        """Update the priorities of the sample batches with new priorities that are
        computed by the learner thread.

        Drains the items currently in the learner thread's output queue. Each item
        is a `(replay_actor, priority_dict, env_steps, agent_steps)` tuple. For
        each item this method:
          - sends `priority_dict` back to `replay_actor` when prioritized replay
            is enabled (`prioritized_replay_alpha` > 0);
          - advances the target-network update schedule by `env_steps`;
          - increments the trained env/agent step counters and pushes the new
            env-step count to the local worker as the global `timestep`.
        Finally it records this iteration's trained-sample count and copies the
        learner thread's timers.

        Raises:
            RuntimeError: If the learner thread is no longer alive while items
                are being drained.
        """
        learner = self.learner_thread
        num_samples_trained_this_itr = 0
        for _ in range(learner.outqueue.qsize()):
            if not learner.is_alive():
                raise RuntimeError("The learner thread died while training")
            (
                replay_actor,
                priority_dict,
                env_steps,
                agent_steps,
            ) = learner.outqueue.get(timeout=0.001)
            alpha = self.config["replay_buffer_config"].get(
                "prioritized_replay_alpha")
            # Treat a missing/None alpha as "prioritized replay disabled"
            # instead of evaluating `None > 0`, which raises TypeError.
            if (alpha or 0) > 0:
                replay_actor.update_priorities.remote(priority_dict)
            num_samples_trained_this_itr += env_steps
            self.update_target_networks(env_steps)
            self._counters[NUM_ENV_STEPS_TRAINED] += env_steps
            self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps
            self.workers.local_worker().set_global_vars(
                {"timestep": self._counters[NUM_ENV_STEPS_TRAINED]})

        self._counters[
            STEPS_TRAINED_THIS_ITER_COUNTER] = num_samples_trained_this_itr
        self._timers["learner_dequeue"] = learner.queue_timer
        self._timers["learner_grad"] = learner.grad_timer
        self._timers["learner_overall"] = learner.overall_timer

    def update_target_networks(self, num_new_trained_samples) -> None:
        """Sync target networks once enough new samples have been trained on.

        Adds `num_new_trained_samples` to a running count. When that count
        reaches `config["target_network_update_freq"]`, the count is reset,
        `update_target()` is called on every policy in the local worker's
        policies-to-train (timed under TARGET_NET_UPDATE_TIMER), and the update
        is recorded in the counters.

        Args:
            num_new_trained_samples: Number of samples just trained on.
        """
        self._num_ts_trained_since_last_target_update += num_new_trained_samples
        update_freq = self.config["target_network_update_freq"]
        if self._num_ts_trained_since_last_target_update < update_freq:
            # Not enough training since the last sync; nothing to do yet.
            return

        self._num_ts_trained_since_last_target_update = 0
        with self._timers[TARGET_NET_UPDATE_TIMER]:
            local_worker = self.workers.local_worker()
            trainable_ids = local_worker.get_policies_to_train()
            local_worker.foreach_policy_to_train(
                lambda p, pid: pid in trainable_ids and p.update_target())
        self._counters[NUM_TARGET_UPDATES] += 1
        # NOTE(review): this records STEPS_TRAINED_COUNTER, while
        # update_replay_sample_priority accumulates NUM_ENV_STEPS_TRAINED;
        # confirm STEPS_TRAINED_COUNTER is kept up to date elsewhere.
        self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
            STEPS_TRAINED_COUNTER]

    @override(Algorithm)
    def on_worker_failures(self, removed_workers: List[ActorHandle],
                           new_workers: List[ActorHandle]):
        """Swap failed remote sampling workers out of the sampling manager.

        Only acts when `_disable_execution_plan_api` is set. In that case the
        removed workers are dropped from the sampling actor manager together
        with their in-flight requests, and the replacement workers are added.

        Args:
            removed_workers: Handles of the workers that were removed.
            new_workers: Handles of the newly created replacement workers.
        """
        if not self.config["_disable_execution_plan_api"]:
            return
        sampling_manager = self._sampling_actor_manager
        sampling_manager.remove_workers(
            removed_workers, remove_in_flight_requests=True)
        sampling_manager.add_workers(new_workers)

    @override(Algorithm)
    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
        """Extend the base iteration results with replay and learner info.

        On top of the parent's results, adds to `result["info"]`:
          - "exploration_infos": exploration state per policy id, for every
            policy to train;
          - "learner_queue": stats of the learner thread's queue size;
          - "replay_shard_0": stats fetched from the first replay actor.
        """
        result = super()._compile_iteration_results(
            step_ctx=step_ctx, iteration_results=iteration_results)

        debug_stats = self.config["optimizer"].get("debug")
        replay_stats = ray.get(self._replay_actors[0].stats.remote(debug_stats))

        # The call yields single-entry {policy_id: exploration_state} dicts;
        # policy ids are unique across them, so they merge into one mapping.
        exploration_infos = {
            policy_id: state
            for per_policy in self.workers.foreach_policy_to_train(
                lambda p, pid: {pid: p.get_exploration_state()})
            for policy_id, state in per_policy.items()
        }

        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": self.learner_thread.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        })
        return result

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        """Build the placement group this algorithm needs.

        Bundles, in order:
          1. Driver bundle: `num_cpus_for_driver` CPUs plus one CPU per replay
             buffer shard, and `num_gpus` GPUs (0 when `_fake_gpus` is set).
          2. One bundle per rollout worker (`num_workers`).
          3. One bundle per evaluation worker (`evaluation_num_workers`), only
             when `evaluation_interval` is truthy. Values come from
             `evaluation_config`, falling back to the main config.

        Args:
            config: Partial user config, merged over the default config.

        Returns:
            A PlacementGroupFactory using the merged config's
            `placement_strategy` (default "PACK").
        """
        cf = dict(cls.get_default_config(), **config)

        # Tolerate an explicit `evaluation_config: None`.
        eval_config = cf["evaluation_config"] or {}

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[{
                # Local worker + replay buffer actors.
                # Force replay buffers to be on same node to maximize
                # data bandwidth between buffers and the learner (driver).
                # Replay buffer actors each contain one shard of the total
                # replay buffer and use 1 CPU each.
                "CPU":
                cf["num_cpus_for_driver"] +
                cf["optimizer"]["num_replay_buffer_shards"],
                "GPU":
                0 if cf["_fake_gpus"] else cf["num_gpus"],
            }] + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                    **cf["custom_resources_per_worker"],
                } for _ in range(cf["num_workers"])
            ] + ([
                {
                    # Evaluation workers.
                    # Note: The local eval worker is located on the driver
                    # CPU.
                    "CPU":
                    eval_config.get("num_cpus_per_worker",
                                    cf["num_cpus_per_worker"]),
                    "GPU":
                    eval_config.get("num_gpus_per_worker",
                                    cf["num_gpus_per_worker"]),
                    **eval_config.get(
                        "custom_resources_per_worker",
                        cf["custom_resources_per_worker"],
                    ),
                } for _ in range(cf["evaluation_num_workers"])
            ] if cf["evaluation_interval"] else []),
            # Read from the merged config (like every other key) so a default
            # strategy from get_default_config() is honored.
            strategy=cf.get("placement_strategy", "PACK"),
        )
示例#4
0
class A3C(Trainer):
    """Asynchronous advantage actor-critic trainer.

    Remote workers sample and compute gradients asynchronously. The local
    worker applies each worker's gradients as they arrive and sends its
    updated weights back to that worker.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the default A3C configuration as a plain dict."""
        default_config = A3CConfig()
        return default_config.to_dict()

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        """Run base setup, then build the async manager for remote workers.

        Each remote worker may have at most one in-flight request.
        """
        super().setup(config)
        remote_workers = self.workers.remote_workers()
        self._worker_manager = AsyncRequestsManager(
            remote_workers, max_remote_requests_in_flight_per_worker=1)

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Validate A3C-specific settings after the base validation.

        Raises:
            ValueError: If `entropy_coeff` is negative, or if `sample_async`
                is enabled while `num_workers` <= 0.
        """
        super().validate_config(config)

        if config["entropy_coeff"] < 0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")
        no_remote_workers = config["num_workers"] <= 0
        if no_remote_workers and config["sample_async"]:
            raise ValueError("`num_workers` for A3C must be >= 1!")

    @override(Trainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        """Select the policy class for `config["framework"]`.

        "torch" -> A3CTorchPolicy, "tf" -> A3CStaticGraphTFPolicy, anything
        else -> A3CEagerTFPolicy. Each import sits inside its branch, so only
        the chosen policy module is imported.
        """
        framework = config["framework"]
        if framework == "torch":
            from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

            return A3CTorchPolicy
        if framework == "tf":
            from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CStaticGraphTFPolicy

            return A3CStaticGraphTFPolicy
        from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CEagerTFPolicy

        return A3CEagerTFPolicy

    def training_step(self) -> ResultDict:
        """Run one asynchronous A3C training step.

        Asks all available remote workers (via `call_on_all_available`) to
        sample and compute gradients, then collects whatever is ready. For
        each worker with results:
          - its gradients are applied one result at a time to the local worker;
          - the step counters are updated;
          - the local worker's updated weights and global vars are sent back
            to that same worker.

        Returns:
            The finalized learner info for all applied results.
        """
        local_worker = self.workers.local_worker()

        def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
            """Call sample() and compute_gradients() remotely on workers."""
            batch = worker.sample()
            gradients, grad_infos = worker.compute_gradients(batch)
            return {
                "grads": gradients,
                "infos": grad_infos,
                "agent_steps": batch.agent_steps(),
                "env_steps": batch.env_steps(),
            }

        with self._timers[GRAD_WAIT_TIMER]:
            self._worker_manager.call_on_all_available(
                sample_and_compute_grads)
            # Maps each worker's ActorHandle to the results it returned.
            async_results = self._worker_manager.get_ready()

        global_vars = None
        learner_info_builder = LearnerInfoBuilder(num_devices=1)
        for worker, worker_results in async_results.items():
            for grad_result in worker_results:
                agent_steps = grad_result["agent_steps"]
                env_steps = grad_result["env_steps"]

                with self._timers[APPLY_GRADS_TIMER]:
                    local_worker.apply_gradients(grad_result["grads"])
                self._timers[APPLY_GRADS_TIMER].push_units_processed(
                    agent_steps)

                for counter_key, increment in (
                    (NUM_AGENT_STEPS_SAMPLED, agent_steps),
                    (NUM_ENV_STEPS_SAMPLED, env_steps),
                    (NUM_AGENT_STEPS_TRAINED, agent_steps),
                    (NUM_ENV_STEPS_TRAINED, env_steps),
                ):
                    self._counters[counter_key] += increment

                learner_info_builder.add_learn_on_batch_results_multi_agent(
                    grad_result["infos"])

            # Global vars reflect the counters after this worker's results.
            global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}

            # Send the updated weights back to the worker that produced
            # these gradients.
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                trainable_policies = local_worker.get_policies_to_train()
                weights = local_worker.get_weights(trainable_policies)
                worker.set_weights.remote(weights, global_vars)

        # Only update the local worker if at least one worker reported back.
        if global_vars:
            local_worker.set_global_vars(global_vars)

        return learner_info_builder.finalize()

    @override(Trainer)
    def on_worker_failures(self, removed_workers: List[ActorHandle],
                           new_workers: List[ActorHandle]):
        """Replace failed remote A3C workers in the async request manager.

        Args:
            removed_workers: Handles of the workers that were removed.
            new_workers: Handles of the newly created replacement workers.
        """
        manager = self._worker_manager
        manager.remove_workers(removed_workers)
        manager.add_workers(new_workers)