def test_add_remove_actors(self):
    """Exercise adding and removing actors on an AsyncRequestsManager.

    Scenario:
      * a manager created with no workers tracks nothing and refuses calls;
      * after adding one worker, a call produces exactly one in-flight request;
      * two further calls are issued, then the worker is removed while those
        requests are still outstanding;
      * once they finish, both results are still returned for that worker and
        no pending bookkeeping remains.
    """
    no_workers = []
    manager = AsyncRequestsManager(
        no_workers, max_remote_requests_in_flight_per_worker=2
    )

    # Every internal bookkeeping container must start out empty.
    initial_sizes = (
        len(manager._all_workers),
        len(manager._remote_requests_in_flight),
        len(manager._pending_to_actor),
        len(manager._pending_remotes),
    )
    if any(size != 0 for size in initial_sizes):
        raise ValueError("We should have no workers in this case.")

    # The call itself stays inside the assert, exactly as before.
    assert not manager.call(lambda w: w.task()), (
        "Task shouldn't have been launched since there are no workers in the "
        "manager."
    )

    actor = RemoteRLlibActor.remote(sleep_time=0.1)
    manager.add_workers(actor)
    manager.call(lambda w: w.task())

    # One worker, one request in flight for it, one pending remote.
    single_sizes = (
        len(manager._remote_requests_in_flight[actor]),
        len(manager._pending_to_actor),
        len(manager._all_workers),
        len(manager._pending_remotes),
    )
    if any(size != 1 for size in single_sizes):
        raise ValueError("We should have 1 worker and 1 pending request")

    time.sleep(3)
    manager.get_ready()

    # test worker removal
    for expected_pending in (1, 2):
        manager.call(lambda w: w.task())
        assert len(manager._pending_remotes) == expected_pending

    manager.remove_workers(actor)

    if len(manager._all_workers) != 0:
        raise ValueError("We should have no workers that we can schedule tasks to")

    still_pending = (
        len(manager._pending_remotes) == 2
        and len(manager._pending_to_actor) == 2
    )
    if not still_pending:
        raise ValueError(
            "We should still have 2 pending requests in flight from the worker"
        )

    time.sleep(3)
    ready = manager.get_ready()

    fully_drained = (
        len(ready) == 1
        and len(ready[actor]) == 2
        and len(manager._pending_remotes) == 0
        and len(manager._pending_to_actor) == 0
    )
    if not fully_drained:
        raise ValueError(
            "We should have 2 ready results from the worker and no pending requests"
        )
class Impala(Algorithm):
    """Importance weighted actor/learner architecture (IMPALA) Algorithm

    == Overview of data flow in IMPALA ==
    1. Policy evaluation in parallel across `num_workers` actors produces
       batches of size `rollout_fragment_length * num_envs_per_worker`.
    2. If enabled, the replay buffer stores and produces batches of size
       `rollout_fragment_length * num_envs_per_worker`.
    3. If enabled, the minibatch ring buffer stores and replays batches of
       size `train_batch_size` up to `num_sgd_iter` times per batch.
    4. The learner thread executes data parallel SGD across `num_gpus` GPUs
       on batches of size `train_batch_size`.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return ImpalaConfig().to_dict()

    @override(Algorithm)
    def get_default_policy_class(
        self, config: PartialAlgorithmConfigDict
    ) -> Optional[Type[Policy]]:
        """Pick the policy class from `framework` and `vtrace`.

        V-trace enabled -> IMPALA policies; otherwise the A3C policies.
        """
        if config["framework"] == "torch":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_torch_policy import (
                    ImpalaTorchPolicy,
                )

                return ImpalaTorchPolicy
            else:
                from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

                return A3CTorchPolicy
        elif config["framework"] == "tf":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF1Policy

                return ImpalaTF1Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy
        else:
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF2Policy

                return ImpalaTF2Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy

    @override(Algorithm)
    def validate_config(self, config):
        # Call the super class' validation method first.
        super().validate_config(config)

        # Check the IMPALA specific config.

        if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
            deprecation_warning(
                "num_data_loader_buffers", "num_multi_gpu_tower_stacks", error=False
            )
            config["num_multi_gpu_tower_stacks"] = config["num_data_loader_buffers"]

        if config["entropy_coeff"] < 0.0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")

        # Check whether worker to aggregation-worker ratio makes sense.
        if config["num_aggregation_workers"] > config["num_workers"]:
            raise ValueError(
                "`num_aggregation_workers` must be smaller than or equal "
                "`num_workers`! Aggregation makes no sense otherwise."
            )
        elif config["num_aggregation_workers"] > config["num_workers"] / 2:
            logger.warning(
                "`num_aggregation_workers` should be significantly smaller "
                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
                "less."
            )

        # If two separate optimizers/loss terms used for tf, must also set
        # `_tf_policy_handles_more_than_one_loss` to True.
        if config["_separate_vf_optimizer"] is True:
            # Only supported to tf so far.
            # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
            #  models to return separate sets of weights in order to create
            #  the different torch optimizers).
            if config["framework"] not in ["tf", "tf2", "tfe"]:
                raise ValueError(
                    "`_separate_vf_optimizer` only supported to tf so far!"
                )
            if config["_tf_policy_handles_more_than_one_loss"] is False:
                logger.warning(
                    "`_tf_policy_handles_more_than_one_loss` must be set to "
                    "True, for TFPolicy to support more than one loss "
                    "term/optimizer! Auto-setting it to True."
                )
                config["_tf_policy_handles_more_than_one_loss"] = True

    @override(Algorithm)
    def setup(self, config: PartialAlgorithmConfigDict):
        """Build the aggregation/sampling managers and start the learner thread.

        All of this is only done when the execution-plan API is disabled;
        otherwise `execution_plan()` creates its own learner thread.
        """
        super().setup(config)

        if self.config["_disable_execution_plan_api"]:
            # Create extra aggregation workers and assign each rollout worker to
            # one of them.
            self.batches_to_place_on_learner = []
            self.batch_being_built = []
            if self.config["num_aggregation_workers"] > 0:
                # This spawns `num_aggregation_workers` actors that aggregate
                # experiences coming from RolloutWorkers in parallel. We force
                # colocation on the same node (localhost) to maximize data bandwidth
                # between them and the learner.
                localhost = platform.node()
                assert localhost != "", (
                    "ERROR: Cannot determine local node name! "
                    "`platform.node()` returned empty string."
                )
                all_co_located = create_colocated_actors(
                    actor_specs=[
                        # (class, args, kwargs={}, count=1)
                        (
                            AggregatorWorker,
                            [
                                self.config,
                            ],
                            {},
                            self.config["num_aggregation_workers"],
                        )
                    ],
                    node=localhost,
                )
                self._aggregator_workers = [
                    actor for actor_groups in all_co_located for actor in actor_groups
                ]
                self._aggregator_actor_manager = AsyncRequestsManager(
                    self._aggregator_workers,
                    max_remote_requests_in_flight_per_worker=self.config[
                        "max_requests_in_flight_per_aggregator_worker"
                    ],
                    ray_wait_timeout_s=self.config["timeout_s_aggregator_manager"],
                )

            else:
                # Create our local mixin buffer if the num of aggregation workers is 0.
                self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
                    capacity=(
                        self.config["replay_buffer_num_slots"]
                        if self.config["replay_buffer_num_slots"] > 0
                        else 1
                    ),
                    replay_ratio=self.config["replay_ratio"],
                    replay_mode=ReplayMode.LOCKSTEP,
                )

            # Sampling results come back as object refs; they are resolved
            # either by the aggregation workers or in
            # `process_experiences_directly`.
            self._sampling_actor_manager = AsyncRequestsManager(
                self.workers.remote_workers(),
                max_remote_requests_in_flight_per_worker=self.config[
                    "max_requests_in_flight_per_sampler_worker"
                ],
                return_object_refs=True,
                ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
            )

            # Create and start the learner thread.
            self._learner_thread = make_learner_thread(
                self.workers.local_worker(), self.config
            )
            self._learner_thread.start()
            self.workers_that_need_updates = set()

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        """Sample, process, queue for learning, collect results, sync weights."""
        unprocessed_sample_batches = self.get_samples_from_workers()

        # Every worker that just returned samples should get fresh weights.
        self.workers_that_need_updates |= unprocessed_sample_batches.keys()

        if self.config["num_aggregation_workers"] > 0:
            batch = self.process_experiences_tree_aggregation(
                unprocessed_sample_batches
            )
        else:
            batch = self.process_experiences_directly(unprocessed_sample_batches)

        self.concatenate_batches_and_pre_queue(batch)
        self.place_processed_samples_on_learner_queue()
        train_results = self.process_trained_results()

        self.update_workers_if_necessary()

        return train_results

    @staticmethod
    @override(Algorithm)
    def execution_plan(workers, config, **kwargs):
        assert (
            len(kwargs) == 0
        ), "IMPALA execution_plan does NOT take any additional parameters"

        if config["num_aggregation_workers"] > 0:
            train_batches = gather_experiences_tree_aggregation(workers, config)
        else:
            train_batches = gather_experiences_directly(workers, config)

        # Start the learner thread.
        learner_thread = make_learner_thread(workers.local_worker(), config)
        learner_thread.start()

        # This sub-flow sends experiences to the learner.
        enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            enqueue_op = enqueue_op.zip_with_source_actor().for_each(
                BroadcastUpdateLearnerWeights(
                    learner_thread,
                    workers,
                    broadcast_interval=config["broadcast_interval"],
                )
            )

        def record_steps_trained(item):
            count, fetches, _ = item
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
            metrics.counters[STEPS_TRAINED_COUNTER] += count
            return item

        # This sub-flow updates the steps trained counter based on learner
        # output.
        dequeue_op = Dequeue(
            learner_thread.outqueue, check=learner_thread.is_alive
        ).for_each(record_steps_trained)

        merged_op = Concurrently(
            [enqueue_op, dequeue_op], mode="async", output_indexes=[1]
        )

        # Callback for APPO to use to update KL, target network periodically.
        # The input to the callback is the learner fetches dict.
        if config["after_train_step"]:
            merged_op = merged_op.for_each(lambda t: t[1]).for_each(
                config["after_train_step"](workers, config)
            )

        return StandardMetricsReporting(merged_op, workers, config).for_each(
            learner_thread.add_learner_metrics
        )

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[
                {
                    # Driver + Aggregation Workers:
                    # Force to be on same node to maximize data bandwidth
                    # between aggregation workers and the learner (driver).
                    # Aggregation workers tree-aggregate experiences collected
                    # from RolloutWorkers (n rollout workers map to m
                    # aggregation workers, where m < n) and always use 1 CPU
                    # each.
                    "CPU": cf["num_cpus_for_driver"] + cf["num_aggregation_workers"],
                    "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
                }
            ]
            + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                    **cf["custom_resources_per_worker"],
                }
                for _ in range(cf["num_workers"])
            ]
            + (
                [
                    {
                        # Evaluation (remote) workers.
                        # Note: The local eval worker is located on the driver
                        # CPU or not even created iff >0 eval workers.
                        "CPU": eval_config.get(
                            "num_cpus_per_worker", cf["num_cpus_per_worker"]
                        ),
                        "GPU": eval_config.get(
                            "num_gpus_per_worker", cf["num_gpus_per_worker"]
                        ),
                        **eval_config.get(
                            "custom_resources_per_worker",
                            cf["custom_resources_per_worker"],
                        ),
                    }
                    for _ in range(cf["evaluation_num_workers"])
                ]
                if cf["evaluation_interval"]
                else []
            ),
            strategy=config.get("placement_strategy", "PACK"),
        )

    def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]):
        """Concatenate batches that are being returned from rollout workers

        Batches are accumulated in `self.batch_being_built`; as soon as their
        total `count` reaches `train_batch_size` they are concatenated into a
        single batch and appended to `self.batches_to_place_on_learner`.

        Args:
            batches: batches of experiences from rollout workers
        """

        def aggregate_into_larger_batch():
            if (
                sum(b.count for b in self.batch_being_built)
                >= self.config["train_batch_size"]
            ):
                batch_to_add = SampleBatch.concat_samples(self.batch_being_built)
                self.batches_to_place_on_learner.append(batch_to_add)
                self.batch_being_built = []

        for batch in batches:
            self.batch_being_built.append(batch)
            aggregate_into_larger_batch()

    def get_samples_from_workers(self) -> Dict[ActorHandle, List[SampleBatch]]:
        """Launch/collect sampling requests.

        Returns:
            A mapping from worker to the results it returned. For remote
            workers these are object refs (the sampling manager is created
            with `return_object_refs=True`); without remote workers the
            local worker's SampleBatch is returned directly.
        """
        # Perform asynchronous sampling on all (remote) rollout workers.
        if self.workers.remote_workers():
            self._sampling_actor_manager.call_on_all_available(
                lambda worker: worker.sample()
            )
            sample_batches: Dict[
                ActorHandle, List[ObjectRef]
            ] = self._sampling_actor_manager.get_ready()
        else:
            # only sampling on the local worker
            sample_batches = {
                self.workers.local_worker(): [self.workers.local_worker().sample()]
            }
        return sample_batches

    def place_processed_samples_on_learner_queue(self) -> None:
        """Move every pre-queued train batch onto the learner's in-queue.

        The put blocks while the learner's in-queue is full. This applies
        back-pressure on sampling; a non-blocking put that retries on
        `queue.Full` would busy-spin the main thread against the learner
        thread. `num_samples_added_to_queue` is reset here and accumulates
        the `count` of every batch placed during this call.
        """
        self._counters["num_samples_added_to_queue"] = 0

        while self.batches_to_place_on_learner:
            batch = self.batches_to_place_on_learner[0]
            if self._learner_thread.inqueue.full():
                # We are about to wait for the learner to free a slot.
                self._counters["num_times_learner_queue_full"] += 1
            self._learner_thread.inqueue.put(batch, block=True)
            self.batches_to_place_on_learner.pop(0)
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.count
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            self._counters["num_samples_added_to_queue"] += batch.count

    def process_trained_results(self) -> ResultDict:
        """Drain the learner's out-queue and update the trained-steps counters.

        Raises:
            RuntimeError: If the learner thread is no longer alive.
        """
        # Get learner outputs/stats from output queue.
        learner_infos = []
        num_env_steps_trained = 0
        num_agent_steps_trained = 0

        for _ in range(self._learner_thread.outqueue.qsize()):
            if self._learner_thread.is_alive():
                (
                    env_steps,
                    agent_steps,
                    learner_results,
                ) = self._learner_thread.outqueue.get(timeout=0.001)
                num_env_steps_trained += env_steps
                num_agent_steps_trained += agent_steps
                if learner_results:
                    learner_infos.append(learner_results)
            else:
                raise RuntimeError("The learner thread died in while training")

        if not learner_infos:
            # Nothing new this iteration: report the learner's latest info.
            final_learner_info = copy.deepcopy(self._learner_thread.learner_info)
        else:
            builder = LearnerInfoBuilder()
            for info in learner_infos:
                builder.add_learn_on_batch_results_multi_agent(info)
            final_learner_info = builder.finalize()

        # Update the steps trained counters.
        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
        self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
        self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained

        return final_learner_info

    def process_experiences_directly(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        """Resolve samples and pass them through the local mixin replay buffer.

        Args:
            actor_to_sample_batches_refs: Worker -> list of sample batches or
                object refs to sample batches.

        Returns:
            The (possibly empty) list of batches replayed by the mixin buffer.
        """
        processed_batches = []
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        if not batches:
            return processed_batches
        # Remote samples arrive as object refs; local ones are already batches.
        if isinstance(batches[0], ray.ObjectRef):
            batches = ray.get(batches)
        for batch in batches:
            batch = batch.decompress_if_needed()
            self.local_mixin_buffer.add_batch(batch)
            batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
            if batch:
                processed_batches.append(batch)
        return processed_batches

    def process_experiences_tree_aggregation(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        """Send sample refs to the aggregation workers and collect ready results.

        Args:
            actor_to_sample_batches_refs: Worker -> list of sample batch refs.

        Returns:
            All processed batches the aggregation workers have ready now.
        """
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        ready_processed_batches = []
        for batch in batches:
            self._aggregator_actor_manager.call(
                lambda actor, b: actor.process_episodes(b), fn_kwargs={"b": batch}
            )

        waiting_processed_sample_batches: Dict[
            ActorHandle, List[ObjectRef]
        ] = self._aggregator_actor_manager.get_ready()
        for ready_sub_batches in waiting_processed_sample_batches.values():
            ready_processed_batches.extend(ready_sub_batches)

        return ready_processed_batches

    def update_workers_if_necessary(self) -> None:
        """Broadcast learner weights every `broadcast_interval` training steps.

        Only workers that returned samples since the last broadcast are
        updated. The local worker's global vars are always refreshed.
        """
        # Only need to update workers if there are remote workers.
        global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
        self._counters["steps_since_broadcast"] += 1
        if (
            self.workers.remote_workers()
            and self._counters["steps_since_broadcast"]
            >= self.config["broadcast_interval"]
            and self.workers_that_need_updates
        ):
            weights = ray.put(self.workers.local_worker().get_weights())
            self._counters["steps_since_broadcast"] = 0
            self._learner_thread.weights_updated = False
            self._counters["num_weight_broadcasts"] += 1

            for worker in self.workers_that_need_updates:
                worker.set_weights.remote(weights, global_vars)
            self.workers_that_need_updates = set()

        # Update global vars of the local worker.
        self.workers.local_worker().set_global_vars(global_vars)

    @override(Algorithm)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handle the failures of remote sampling workers

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(
                removed_workers, remove_in_flight_requests=True
            )
            self._sampling_actor_manager.add_workers(new_workers)

    @override(Algorithm)
    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
        result = super()._compile_iteration_results(
            step_ctx=step_ctx, iteration_results=iteration_results
        )
        result = self._learner_thread.add_learner_metrics(
            result, overwrite_learner_info=False
        )
        return result
class ApexDQN(DQN):
    """Ape-X DQN: distributed sampling into sharded replay actors.

    Rollout workers push samples straight into replay actors; the driver
    requests replay samples, feeds them to a learner thread, and sends the
    resulting priorities back to the replay actor each batch came from.
    """

    @override(Trainable)
    def setup(self, config: PartialAlgorithmConfigDict):
        super().setup(config)

        # Shortcut: If execution_plan, thread and buffer will be created in there.
        if self.config["_disable_execution_plan_api"] is False:
            return

        # Tag those workers (top 1/3rd indices) that we should collect episodes from
        # for metrics due to `PerWorkerEpsilonGreedy` exploration strategy.
        # Note: `-n // 3` floors toward -inf, i.e. this keeps the last
        # ceil(n / 3) workers (always at least one).
        if self.workers.remote_workers():
            self._remote_workers_for_metrics = self.workers.remote_workers()[
                -len(self.workers.remote_workers()) // 3 :
            ]

        num_replay_buffer_shards = self.config["optimizer"]["num_replay_buffer_shards"]

        # Create copy here so that we can modify without breaking other logic
        replay_actor_config = copy.deepcopy(self.config["replay_buffer_config"])

        # Split the total capacity evenly across the shards.
        replay_actor_config["capacity"] = (
            self.config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
        )

        ReplayActor = ray.remote(num_cpus=0)(replay_actor_config["type"])

        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if replay_actor_config["replay_buffer_shards_colocated_with_driver"]:
            self._replay_actors = create_colocated_actors(
                actor_specs=[
                    # (class, args, kwargs={}, count)
                    (
                        ReplayActor,
                        None,
                        replay_actor_config,
                        num_replay_buffer_shards,
                    )
                ],
                node=platform.node(),  # localhost
            )[
                0
            ]  # [0]=only one item in `actor_specs`.
        else:
            # Place replay buffer shards on any node(s).
            # The config dict holds constructor *keyword* arguments (the
            # colocated branch above passes it as kwargs too); unpacking it
            # with a single `*` would pass the key strings positionally.
            self._replay_actors = [
                ReplayActor.remote(**replay_actor_config)
                for _ in range(num_replay_buffer_shards)
            ]
        self._replay_actor_manager = AsyncRequestsManager(
            self._replay_actors,
            max_remote_requests_in_flight_per_worker=self.config[
                "max_requests_in_flight_per_replay_worker"
            ],
            ray_wait_timeout_s=self.config["timeout_s_replay_manager"],
        )
        self._sampling_actor_manager = AsyncRequestsManager(
            self.workers.remote_workers(),
            max_remote_requests_in_flight_per_worker=self.config[
                "max_requests_in_flight_per_sampler_worker"
            ],
            ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
        )
        self.learner_thread = LearnerThread(self.workers.local_worker())
        self.learner_thread.start()
        self.steps_since_update = defaultdict(int)
        weights = self.workers.local_worker().get_weights()
        self.curr_learner_weights = ray.put(weights)
        self.curr_num_samples_collected = 0
        self.replay_sample_batches = []
        self._num_ts_trained_since_last_target_update = 0

    @classmethod
    @override(DQN)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return ApexDQNConfig().to_dict()

    @override(DQN)
    def validate_config(self, config):
        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
        # Call DQN's validation method.
        super().validate_config(config)

    @override(DQN)
    def training_step(self) -> ResultDict:
        num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers()
        worker_samples_collected = defaultdict(int)

        for worker, samples_infos in num_samples_ready_dict.items():
            for samples_info in samples_infos:
                self._counters[NUM_AGENT_STEPS_SAMPLED] += samples_info["agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += samples_info["env_steps"]
                worker_samples_collected[worker] += samples_info["agent_steps"]

        # update the weights of the workers that returned samples
        # only do this if there are remote workers (config["num_workers"] > 1)
        if self.workers.remote_workers():
            self.update_workers(worker_samples_collected)

        # trigger a sample from the replay actors and enqueue operation to the
        # learner thread.
        self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            worker_samples_collected
        )
        self.update_replay_sample_priority()

        return copy.deepcopy(self.learner_thread.learner_info)

    def get_samples_and_store_to_replay_buffers(self):
        """Sample experiences and store them in randomly chosen replay actors.

        Returns:
            Worker -> list of `{"agent_steps", "env_steps"}` dicts, one per
            sample call that finished.
        """
        # in the case the num_workers = 0
        if not self.workers.remote_workers():
            with self._timers[SAMPLE_TIMER]:
                local_sampling_worker = self.workers.local_worker()
                batch = local_sampling_worker.sample()
                actor = random.choice(self._replay_actors)
                ray.get(actor.add.remote(batch))
                batch_statistics = {
                    local_sampling_worker: [
                        {
                            "agent_steps": batch.agent_steps(),
                            "env_steps": batch.env_steps(),
                        }
                    ]
                }
                return batch_statistics

        def remote_worker_sample_and_store(
            worker: RolloutWorker, replay_actors: List[ActorHandle]
        ):
            # This function is run as a remote function on sampling workers,
            # and should only be used with the RolloutWorker's apply function ever.
            # It is used to gather samples, and trigger the operation to store them to
            # replay actors from the rollout worker instead of returning the obj
            # refs for the samples to the driver process and doing the sampling
            # operation on there.
            _batch = worker.sample()
            _actor = random.choice(replay_actors)
            _actor.add.remote(_batch)
            _batch_statistics = {
                "agent_steps": _batch.agent_steps(),
                "env_steps": _batch.env_steps(),
            }
            return _batch_statistics

        # Sample and Store in the Replay Actors on the sampling workers.
        with self._timers[SAMPLE_TIMER]:
            self._sampling_actor_manager.call_on_all_available(
                remote_worker_sample_and_store,
                fn_kwargs={"replay_actors": self._replay_actors},
            )
            num_samples_ready_dict = self._sampling_actor_manager.get_ready()
        return num_samples_ready_dict

    def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int:
        """Update the remote workers that have samples ready.

        A worker is updated once it has accumulated at least
        `optimizer.max_weight_sync_delay` samples since its last update.

        Args:
            _num_samples_ready: A mapping from ActorHandle (RolloutWorker) to
                the number of samples returned by the remote worker.

        Returns:
            The number of remote workers whose weights were updated.
        """
        max_steps_weight_sync_delay = self.config["optimizer"]["max_weight_sync_delay"]
        # Update our local copy of the weights if the learner thread has updated
        # the learner worker's weights
        if self.learner_thread.weights_updated:
            self.learner_thread.weights_updated = False
            weights = self.workers.local_worker().get_weights()
            self.curr_learner_weights = ray.put(weights)

        num_workers_updated = 0
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            for (
                remote_sampler_worker,
                num_samples_collected,
            ) in _num_samples_ready.items():
                self.steps_since_update[remote_sampler_worker] += num_samples_collected
                if (
                    self.steps_since_update[remote_sampler_worker]
                    >= max_steps_weight_sync_delay
                ):
                    # NOTE(review): this class only increments
                    # NUM_ENV/AGENT_STEPS_TRAINED; confirm STEPS_TRAINED_COUNTER
                    # is maintained elsewhere, else this timestep stays stale.
                    remote_sampler_worker.set_weights.remote(
                        self.curr_learner_weights,
                        {"timestep": self._counters[STEPS_TRAINED_COUNTER]},
                    )
                    self.steps_since_update[remote_sampler_worker] = 0
                    self._counters["num_weight_syncs"] += 1
                    num_workers_updated += 1
        return num_workers_updated

    def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
        self, num_samples_collected: Dict[ActorHandle, int]
    ) -> None:
        """Get samples from the replay buffer and place them on the learner queue.

        Args:
            num_samples_collected: A mapping from ActorHandle (RolloutWorker) to
                number of samples returned by the remote worker. This is used to
                implement training intensity which is the concept of triggering a
                certain amount of training based on the number of samples that have
                been collected since the last time that training was triggered.
        """

        def wait_on_replay_actors() -> None:
            """Collect whatever replay-actor sample results are ready.

            Results are fetched through the replay actor manager's
            `get_ready()` (bounded by its configured `ray_wait_timeout_s`)
            and appended to `self.replay_sample_batches` as
            `(replay_actor, sample_batch)` tuples.
            """
            _replay_samples_ready = self._replay_actor_manager.get_ready()
            for _replay_actor, _sample_batches in _replay_samples_ready.items():
                for _sample_batch in _sample_batches:
                    self.replay_sample_batches.append((_replay_actor, _sample_batch))

        total_new_samples = sum(num_samples_collected.values())
        self.curr_num_samples_collected += total_new_samples
        wait_on_replay_actors()
        if self.curr_num_samples_collected >= self.config["train_batch_size"]:
            # NOTE(review): `int()` truncates fractional training intensities;
            # confirm this is intended.
            training_intensity = int(self.config["training_intensity"] or 1)
            num_requests_to_launch = (
                self.curr_num_samples_collected / self.config["train_batch_size"]
            ) * training_intensity
            num_requests_to_launch = max(1, round(num_requests_to_launch))
            self.curr_num_samples_collected = 0
            for _ in range(num_requests_to_launch):
                self._replay_actor_manager.call(
                    lambda actor, num_items: actor.sample(num_items),
                    fn_args=[self.config["train_batch_size"]],
                )
            wait_on_replay_actors()

        # add the sample batches to the learner queue
        while self.replay_sample_batches:
            _, sample_batch = self.replay_sample_batches[0]
            # The replay buffer returns None if it has not been filled to the
            # minimum threshold yet: drop such entries instead of feeding
            # them to the learner. (Checking the whole `(actor, batch)` tuple
            # for truthiness, as before, never filtered anything.)
            if sample_batch is None:
                self.replay_sample_batches.pop(0)
                continue
            try:
                self.learner_thread.inqueue.put(
                    self.replay_sample_batches[0], timeout=0.001
                )
            except queue.Full:
                break
            self.replay_sample_batches.pop(0)

    def update_replay_sample_priority(self) -> None:
        """Update the priorities of the sample batches with new priorities that
        are computed by the learner thread.

        Raises:
            RuntimeError: If the learner thread is no longer alive.
        """
        num_samples_trained_this_itr = 0
        # Missing/None alpha means no prioritized replay (avoids `None > 0`).
        prioritized_replay_alpha = (
            self.config["replay_buffer_config"].get("prioritized_replay_alpha") or 0.0
        )
        for _ in range(self.learner_thread.outqueue.qsize()):
            if self.learner_thread.is_alive():
                (
                    replay_actor,
                    priority_dict,
                    env_steps,
                    agent_steps,
                ) = self.learner_thread.outqueue.get(timeout=0.001)
                if prioritized_replay_alpha > 0:
                    replay_actor.update_priorities.remote(priority_dict)
                num_samples_trained_this_itr += env_steps
                self.update_target_networks(env_steps)
                self._counters[NUM_ENV_STEPS_TRAINED] += env_steps
                self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps
                self.workers.local_worker().set_global_vars(
                    {"timestep": self._counters[NUM_ENV_STEPS_TRAINED]}
                )
            else:
                raise RuntimeError("The learner thread died in while training")

        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_samples_trained_this_itr
        self._timers["learner_dequeue"] = self.learner_thread.queue_timer
        self._timers["learner_grad"] = self.learner_thread.grad_timer
        self._timers["learner_overall"] = self.learner_thread.overall_timer

    def update_target_networks(self, num_new_trained_samples) -> None:
        """Update the target networks once `target_network_update_freq`
        newly trained timesteps have accumulated."""
        self._num_ts_trained_since_last_target_update += num_new_trained_samples
        if (
            self._num_ts_trained_since_last_target_update
            >= self.config["target_network_update_freq"]
        ):
            self._num_ts_trained_since_last_target_update = 0
            with self._timers[TARGET_NET_UPDATE_TIMER]:
                to_update = self.workers.local_worker().get_policies_to_train()
                self.workers.local_worker().foreach_policy_to_train(
                    lambda p, pid: pid in to_update and p.update_target()
                )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
                STEPS_TRAINED_COUNTER
            ]

    @override(Algorithm)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handle the failures of remote sampling workers

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(
                removed_workers, remove_in_flight_requests=True
            )
            self._sampling_actor_manager.add_workers(new_workers)

    @override(Algorithm)
    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
        result = super()._compile_iteration_results(
            step_ctx=step_ctx, iteration_results=iteration_results
        )
        replay_stats = ray.get(
            self._replay_actors[0].stats.remote(self.config["optimizer"].get("debug"))
        )
        exploration_infos_list = self.workers.foreach_policy_to_train(
            lambda p, pid: {pid: p.get_exploration_state()}
        )
        exploration_infos = {}
        for info in exploration_infos_list:
            # we're guaranteed that each info has policy ids that are unique
            exploration_infos.update(info)
        other_results = {
            "exploration_infos": exploration_infos,
            "learner_queue": self.learner_thread.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }

        result["info"].update(other_results)
        return result

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[
                {
                    # Local worker + replay buffer actors.
                    # Force replay buffers to be on same node to maximize
                    # data bandwidth between buffers and the learner (driver).
                    # Replay buffer actors each contain one shard of the total
                    # replay buffer and use 1 CPU each.
                    "CPU": cf["num_cpus_for_driver"]
                    + cf["optimizer"]["num_replay_buffer_shards"],
                    "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
                }
            ]
            + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                    **cf["custom_resources_per_worker"],
                }
                for _ in range(cf["num_workers"])
            ]
            + (
                [
                    {
                        # Evaluation workers.
                        # Note: The local eval worker is located on the driver
                        # CPU.
                        "CPU": eval_config.get(
                            "num_cpus_per_worker", cf["num_cpus_per_worker"]
                        ),
                        "GPU": eval_config.get(
                            "num_gpus_per_worker", cf["num_gpus_per_worker"]
                        ),
                        **eval_config.get(
                            "custom_resources_per_worker",
                            cf["custom_resources_per_worker"],
                        ),
                    }
                    for _ in range(cf["evaluation_num_workers"])
                ]
                if cf["evaluation_interval"]
                else []
            ),
            strategy=config.get("placement_strategy", "PACK"),
        )
class A3C(Trainer):
    """Asynchronous advantage actor-critic (A3C).

    Each rollout worker samples and computes gradients on its own. The driver
    applies every returned gradient to the local worker's model, one at a
    time, and then sends the updated weights back to the worker that
    produced it.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return A3CConfig().to_dict()

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)
        self._worker_manager = AsyncRequestsManager(
            self.workers.remote_workers(), max_remote_requests_in_flight_per_worker=1
        )

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["entropy_coeff"] < 0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")
        if config["num_workers"] <= 0 and config["sample_async"]:
            raise ValueError("`num_workers` for A3C must be >= 1!")

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

            return A3CTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CStaticGraphTFPolicy

            return A3CStaticGraphTFPolicy
        else:
            from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CEagerTFPolicy

            return A3CEagerTFPolicy

    @override(Trainer)
    def training_step(self) -> ResultDict:
        """Collect worker gradients, apply them locally, and sync weights back.

        Without remote workers (allowed by `validate_config` when
        `sample_async` is False), samples and gradients are computed on the
        local worker instead of silently skipping training.
        """
        # Shortcut.
        local_worker = self.workers.local_worker()

        # Define the function executed in parallel by all RolloutWorkers to collect
        # samples + compute and return gradients (and other information).

        def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
            """Call sample() and compute_gradients() remotely on workers."""
            samples = worker.sample()
            grads, infos = worker.compute_gradients(samples)
            return {
                "grads": grads,
                "infos": infos,
                "agent_steps": samples.agent_steps(),
                "env_steps": samples.env_steps(),
            }

        # Perform rollouts and gradient calculations asynchronously.
        with self._timers[GRAD_WAIT_TIMER]:
            # Results are a mapping from ActorHandle (RolloutWorker) to their
            # returned gradient calculation results.
            if self.workers.remote_workers():
                self._worker_manager.call_on_all_available(sample_and_compute_grads)
                async_results = self._worker_manager.get_ready()
            else:
                async_results = {local_worker: [sample_and_compute_grads(local_worker)]}

        # Loop through all fetched worker-computed gradients (if any)
        # and apply them - one by one - to the local worker's model.
        # After each apply step (one step per worker that returned some gradients),
        # update that particular worker's weights.
        global_vars = None
        learner_info_builder = LearnerInfoBuilder(num_devices=1)
        for worker, results in async_results.items():
            for result in results:
                # Apply gradients to local worker.
                with self._timers[APPLY_GRADS_TIMER]:
                    local_worker.apply_gradients(result["grads"])
                self._timers[APPLY_GRADS_TIMER].push_units_processed(
                    result["agent_steps"]
                )

                # Update all step counters.
                self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
                self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
                self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]

                learner_info_builder.add_learn_on_batch_results_multi_agent(
                    result["infos"]
                )

                # Create current global vars.
                global_vars = {
                    "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
                }

                # Synch updated weights back to the particular (remote) worker.
                # The local worker already holds the updated weights.
                if worker is not local_worker:
                    with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                        weights = local_worker.get_weights(
                            local_worker.get_policies_to_train()
                        )
                        worker.set_weights.remote(weights, global_vars)

        # Update global vars of the local worker.
        if global_vars:
            local_worker.set_global_vars(global_vars)

        return learner_info_builder.finalize()

    @override(Trainer)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handle failures on remote A3C workers.

        Requests still in flight on a removed worker are dropped as well, so
        they are never fetched from a dead actor.

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        self._worker_manager.remove_workers(
            removed_workers, remove_in_flight_requests=True
        )
        self._worker_manager.add_workers(new_workers)