def test_test_async_requests_task_doesnt_buffering(self):
    """Tests that the async manager drops requests instead of buffering them.

    Two workers with at most 2 in-flight requests each give a capacity of 4:
    of 8 calls only the first 4 are scheduled and the remaining 4 are
    rejected (``call`` returns a falsy value). After the results have been
    collected with ``get_ready``, the capacity is free again and 4 new calls
    must all be scheduled directly.
    """
    workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
    manager = AsyncRequestsManager(
        workers, max_remote_requests_in_flight_per_worker=2
    )
    for i in range(8):
        scheduled = manager.call(lambda w: w.task())
        if i < 4:
            assert scheduled, "We should have scheduled the task"
        else:
            assert not scheduled, (
                "We should not have scheduled the task because"
                " all workers are busy."
            )
    assert len(manager._pending_remotes) == 4, "We should have 4 pending requests"
    time.sleep(3)
    ready_requests = manager.get_ready()
    for worker in workers:
        if not len(ready_requests[worker]) == 2:
            raise Exception(
                "We should return the 2 ready requests in this case from each "
                "actors."
            )
    # Nothing was buffered, so once get_ready() has freed every in-flight
    # slot, each of these new calls must be scheduled directly.
    for _ in range(4):
        assert manager.call(lambda w: w.task()), "We should have scheduled the task"
    time.sleep(3)
    ready_requests = manager.get_ready()
    for worker in workers:
        if not len(ready_requests[worker]) == 2:
            raise Exception(
                "We should return the 2 ready requests in this case from each "
                "actors"
            )
def test_add_remove_actors(self):
    """Check that workers can be added to and removed from the manager.

    Starts from an empty manager, adds a single worker, schedules requests on
    it, then removes it and verifies that requests which were already in
    flight can still be collected through ``get_ready``.
    """
    manager = AsyncRequestsManager([], max_remote_requests_in_flight_per_worker=2)

    # Phase 1: an empty manager tracks nothing and refuses to schedule.
    sizes_when_empty = {
        len(manager._all_workers),
        len(manager._remote_requests_in_flight),
        len(manager._pending_to_actor),
        len(manager._pending_remotes),
    }
    if sizes_when_empty != {0}:
        raise ValueError("We should have no workers in this case.")
    launched = manager.call(lambda w: w.task())
    assert not launched, (
        "Task shouldn't have been launched since there are no "
        "workers in the manager."
    )

    # Phase 2: after adding one worker, a call produces one pending request.
    actor = RemoteRLlibActor.remote(sleep_time=0.1)
    manager.add_workers(actor)
    manager.call(lambda w: w.task())
    sizes_with_one_request = {
        len(manager._remote_requests_in_flight[actor]),
        len(manager._pending_to_actor),
        len(manager._all_workers),
        len(manager._pending_remotes),
    }
    if sizes_with_one_request != {1}:
        raise ValueError("We should have 1 worker and 1 pending request")
    time.sleep(3)
    manager.get_ready()

    # Phase 3: removing the worker must keep its in-flight requests pending.
    for expected_pending in (1, 2):
        manager.call(lambda w: w.task())
        assert len(manager._pending_remotes) == expected_pending
    manager.remove_workers(actor)
    if len(manager._all_workers) != 0:
        raise ValueError("We should have no workers that we can schedule tasks to")
    if len(manager._pending_remotes) != 2 or len(manager._pending_to_actor) != 2:
        raise ValueError(
            "We should still have 2 pending requests in flight from the worker"
        )

    # Phase 4: the removed worker's results are still returned, then drained.
    time.sleep(3)
    result = manager.get_ready()
    fully_drained = (
        len(result) == 1
        and len(result[actor]) == 2
        and len(manager._pending_remotes) == 0
        and len(manager._pending_to_actor) == 0
    )
    if not fully_drained:
        raise ValueError(
            "We should have 2 ready results from the worker and no pending requests"
        )
def test_call_to_actor(self):
    """Tests that ``call(..., actor=...)`` targets one specific worker.

    Only the targeted worker must show up in the ready results. Targeting an
    actor that was never added to the manager must raise ``ValueError``.
    """
    workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
    worker_not_in_manager = RemoteRLlibActor.remote(sleep_time=0.1)
    manager = AsyncRequestsManager(
        workers, max_remote_requests_in_flight_per_worker=2)
    manager.call(lambda w: w.task(), actor=workers[0])
    time.sleep(3)
    results = manager.get_ready()
    # Fail if EITHER condition is violated. The previous
    # `not len(results) == 1 and ...` only failed when both were wrong.
    if len(results) != 1 or workers[0] not in results:
        raise Exception(
            "We should return the 1 ready requests in this case from the worker we "
            "called to")
    with pytest.raises(ValueError, match=".*has not been added to the manager.*"):
        manager.call(lambda w: w.task(), actor=worker_not_in_manager)
def test_round_robin_scheduling(self):
    """Check that successive calls are dispatched to the workers in turn.

    With two workers, calls 0 and 1 each land on a different worker (1 request
    in flight each), and calls 2 and 3 wrap around (2 requests in flight each).
    """
    workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
    manager = AsyncRequestsManager(
        workers, max_remote_requests_in_flight_per_worker=2
    )
    for call_idx in range(4):
        expected_actor = workers[call_idx % len(workers)]
        manager.call(lambda w: w.task())
        expected_in_flight = 1 if call_idx < 2 else 2
        in_flight = len(manager._remote_requests_in_flight[expected_actor])
        assert in_flight == expected_in_flight, (
            f"We should have {expected_in_flight} request in flight for the "
            "actor that we just scheduled on"
        )
def test_async_requests_manager_num_returns(self):
    """Check that results are returned as soon as each worker's task finishes.

    Two fast workers (0.1s tasks) and two slow workers (5s tasks) each get a
    single request. After 3s exactly two results must be ready. After a
    further 7s the remaining two must be ready as well.
    """
    fast_workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
    slow_workers = [RemoteRLlibActor.remote(sleep_time=5) for _ in range(2)]
    manager = AsyncRequestsManager(
        fast_workers + slow_workers, max_remote_requests_in_flight_per_worker=1
    )
    for _ in range(4):
        manager.call(lambda w: w.task())

    def expect_two_ready(error_message):
        # Collect whatever is ready now and insist on exactly two actors.
        if len(manager.get_ready()) != 2:
            raise Exception(error_message)

    time.sleep(3)
    expect_two_ready(
        "We should return the 2 ready requests in this case from the actors"
        " that have shorter tasks"
    )
    time.sleep(7)
    expect_two_ready(
        "We should return the 2 ready requests in this case from the actors"
        " that have longer tasks"
    )
def test_args_kwargs(self):
    """Tests that ``call`` forwards ``fn_args`` and ``fn_kwargs`` to the function.

    Two requests are issued with positional arguments (``fn_args``) and,
    after those are collected, two more with keyword arguments
    (``fn_kwargs``). Each round must yield 2 ready results from the single
    worker.
    """
    workers = [RemoteRLlibActor.remote(sleep_time=0.1)]
    manager = AsyncRequestsManager(
        workers, max_remote_requests_in_flight_per_worker=2)
    for _ in range(2):
        manager.call(lambda w, a, b: w.task2(a, b), fn_args=[1, 2])
    time.sleep(3)
    if not len(manager.get_ready()[workers[0]]) == 2:
        raise Exception(
            "We should return the 2 ready requests in this case from the calls"
            " made with `fn_args`")
    for _ in range(2):
        manager.call(lambda w, a, b: w.task2(a, b), fn_kwargs=dict(a=1, b=2))
    time.sleep(3)
    if not len(manager.get_ready()[workers[0]]) == 2:
        raise Exception(
            "We should return the 2 ready requests in this case from the calls"
            " made with `fn_kwargs`")
class Impala(Algorithm):
    """Importance weighted actor/learner architecture (IMPALA) Algorithm

    == Overview of data flow in IMPALA ==
    1. Policy evaluation in parallel across `num_workers` actors produces
       batches of size `rollout_fragment_length * num_envs_per_worker`.
    2. If enabled, the replay buffer stores and produces batches of size
       `rollout_fragment_length * num_envs_per_worker`.
    3. If enabled, the minibatch ring buffer stores and replays batches of
       size `train_batch_size` up to `num_sgd_iter` times per batch.
    4. The learner thread executes data parallel SGD across `num_gpus` GPUs
       on batches of size `train_batch_size`.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return ImpalaConfig().to_dict()

    @override(Algorithm)
    def get_default_policy_class(
        self, config: PartialAlgorithmConfigDict
    ) -> Optional[Type[Policy]]:
        """Pick the policy class from `framework` and whether `vtrace` is on."""
        if config["framework"] == "torch":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_torch_policy import (
                    ImpalaTorchPolicy,
                )

                return ImpalaTorchPolicy
            else:
                from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

                return A3CTorchPolicy
        elif config["framework"] == "tf":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF1Policy

                return ImpalaTF1Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy
        else:
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF2Policy

                return ImpalaTF2Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy

    @override(Algorithm)
    def validate_config(self, config):
        """Validate (and partially auto-fix) IMPALA-specific config keys.

        Raises:
            ValueError: If `entropy_coeff` is negative, if
                `num_aggregation_workers` exceeds `num_workers`, or if
                `_separate_vf_optimizer` is requested for a non-tf framework.
        """
        # Call the super class' validation method first.
        super().validate_config(config)

        # Check the IMPALA specific config.

        if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
            deprecation_warning(
                "num_data_loader_buffers", "num_multi_gpu_tower_stacks", error=False
            )
            config["num_multi_gpu_tower_stacks"] = config["num_data_loader_buffers"]

        if config["entropy_coeff"] < 0.0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")

        # Check whether worker to aggregation-worker ratio makes sense.
        if config["num_aggregation_workers"] > config["num_workers"]:
            raise ValueError(
                "`num_aggregation_workers` must be smaller than or equal "
                "`num_workers`! Aggregation makes no sense otherwise."
            )
        elif config["num_aggregation_workers"] > config["num_workers"] / 2:
            logger.warning(
                "`num_aggregation_workers` should be significantly smaller "
                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
                "less."
            )

        # If two separate optimizers/loss terms used for tf, must also set
        # `_tf_policy_handles_more_than_one_loss` to True.
        if config["_separate_vf_optimizer"] is True:
            # Only supported to tf so far.
            # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
            #  models to return separate sets of weights in order to create
            #  the different torch optimizers).
            if config["framework"] not in ["tf", "tf2", "tfe"]:
                raise ValueError(
                    "`_separate_vf_optimizer` only supported to tf so far!"
                )
            if config["_tf_policy_handles_more_than_one_loss"] is False:
                logger.warning(
                    "`_tf_policy_handles_more_than_one_loss` must be set to "
                    "True, for TFPolicy to support more than one loss "
                    "term/optimizer! Auto-setting it to True."
                )
                config["_tf_policy_handles_more_than_one_loss"] = True

    @override(Algorithm)
    def setup(self, config: PartialAlgorithmConfigDict):
        """Build actor managers, buffers and the learner thread.

        Only done when `_disable_execution_plan_api` is set. Otherwise the
        execution plan creates its own learner thread.
        """
        super().setup(config)

        if self.config["_disable_execution_plan_api"]:
            # Create extra aggregation workers and assign each rollout worker to
            # one of them.
            self.batches_to_place_on_learner = []
            self.batch_being_built = []
            if self.config["num_aggregation_workers"] > 0:
                # This spawns `num_aggregation_workers` actors that aggregate
                # experiences coming from RolloutWorkers in parallel. We force
                # colocation on the same node (localhost) to maximize data bandwidth
                # between them and the learner.
                localhost = platform.node()
                assert localhost != "", (
                    "ERROR: Cannot determine local node name! "
                    "`platform.node()` returned empty string."
                )
                all_co_located = create_colocated_actors(
                    actor_specs=[
                        # (class, args, kwargs={}, count=1)
                        (
                            AggregatorWorker,
                            [
                                self.config,
                            ],
                            {},
                            self.config["num_aggregation_workers"],
                        )
                    ],
                    node=localhost,
                )
                self._aggregator_workers = [
                    actor for actor_groups in all_co_located for actor in actor_groups
                ]
                self._aggregator_actor_manager = AsyncRequestsManager(
                    self._aggregator_workers,
                    max_remote_requests_in_flight_per_worker=self.config[
                        "max_requests_in_flight_per_aggregator_worker"
                    ],
                    ray_wait_timeout_s=self.config["timeout_s_aggregator_manager"],
                )

            else:
                # Create our local mixin buffer if the num of aggregation workers is 0.
                self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
                    capacity=(
                        self.config["replay_buffer_num_slots"]
                        if self.config["replay_buffer_num_slots"] > 0
                        else 1
                    ),
                    replay_ratio=self.config["replay_ratio"],
                    replay_mode=ReplayMode.LOCKSTEP,
                )

            self._sampling_actor_manager = AsyncRequestsManager(
                self.workers.remote_workers(),
                max_remote_requests_in_flight_per_worker=self.config[
                    "max_requests_in_flight_per_sampler_worker"
                ],
                return_object_refs=True,
                ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
            )

            # Create and start the learner thread.
            self._learner_thread = make_learner_thread(
                self.workers.local_worker(), self.config
            )
            self._learner_thread.start()
            self.workers_that_need_updates = set()

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        """Run one non-blocking IMPALA iteration.

        Collect samples, process them (directly or via aggregation workers),
        batch them up for the learner, drain learner results, and broadcast
        weights when due.
        """
        unprocessed_sample_batches = self.get_samples_from_workers()

        self.workers_that_need_updates |= unprocessed_sample_batches.keys()

        if self.config["num_aggregation_workers"] > 0:
            batch = self.process_experiences_tree_aggregation(
                unprocessed_sample_batches
            )
        else:
            batch = self.process_experiences_directly(unprocessed_sample_batches)

        self.concatenate_batches_and_pre_queue(batch)
        self.place_processed_samples_on_learner_queue()
        train_results = self.process_trained_results()

        self.update_workers_if_necessary()

        return train_results

    @staticmethod
    @override(Algorithm)
    def execution_plan(workers, config, **kwargs):
        assert (
            len(kwargs) == 0
        ), "IMPALA execution_plan does NOT take any additional parameters"

        if config["num_aggregation_workers"] > 0:
            train_batches = gather_experiences_tree_aggregation(workers, config)
        else:
            train_batches = gather_experiences_directly(workers, config)

        # Start the learner thread.
        learner_thread = make_learner_thread(workers.local_worker(), config)
        learner_thread.start()

        # This sub-flow sends experiences to the learner.
        enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            enqueue_op = enqueue_op.zip_with_source_actor().for_each(
                BroadcastUpdateLearnerWeights(
                    learner_thread,
                    workers,
                    broadcast_interval=config["broadcast_interval"],
                )
            )

        def record_steps_trained(item):
            count, fetches, _ = item
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
            metrics.counters[STEPS_TRAINED_COUNTER] += count
            return item

        # This sub-flow updates the steps trained counter based on learner
        # output.
        dequeue_op = Dequeue(
            learner_thread.outqueue, check=learner_thread.is_alive
        ).for_each(record_steps_trained)

        merged_op = Concurrently(
            [enqueue_op, dequeue_op], mode="async", output_indexes=[1]
        )

        # Callback for APPO to use to update KL, target network periodically.
        # The input to the callback is the learner fetches dict.
        if config["after_train_step"]:
            merged_op = merged_op.for_each(lambda t: t[1]).for_each(
                config["after_train_step"](workers, config)
            )

        return StandardMetricsReporting(merged_op, workers, config).for_each(
            learner_thread.add_learner_metrics
        )

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[
                {
                    # Driver + Aggregation Workers:
                    # Force to be on same node to maximize data bandwidth
                    # between aggregation workers and the learner (driver).
                    # Aggregation workers tree-aggregate experiences collected
                    # from RolloutWorkers (n rollout workers map to m
                    # aggregation workers, where m < n) and always use 1 CPU
                    # each.
                    "CPU": cf["num_cpus_for_driver"] + cf["num_aggregation_workers"],
                    "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
                }
            ]
            + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                    **cf["custom_resources_per_worker"],
                }
                for _ in range(cf["num_workers"])
            ]
            + (
                [
                    {
                        # Evaluation (remote) workers.
                        # Note: The local eval worker is located on the driver
                        # CPU or not even created iff >0 eval workers.
                        "CPU": eval_config.get(
                            "num_cpus_per_worker", cf["num_cpus_per_worker"]
                        ),
                        "GPU": eval_config.get(
                            "num_gpus_per_worker", cf["num_gpus_per_worker"]
                        ),
                        **eval_config.get(
                            "custom_resources_per_worker",
                            cf["custom_resources_per_worker"],
                        ),
                    }
                    for _ in range(cf["evaluation_num_workers"])
                ]
                if cf["evaluation_interval"]
                else []
            ),
            strategy=config.get("placement_strategy", "PACK"),
        )

    def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]):
        """Concatenate batches that are being returned from rollout workers

        Batches accumulate in `self.batch_being_built` until their summed
        `count` reaches `train_batch_size`. The concatenation is then moved
        to `self.batches_to_place_on_learner`.

        Args:
            batches: batches of experiences from rollout workers
        """

        def aggregate_into_larger_batch():
            if (
                sum(b.count for b in self.batch_being_built)
                >= self.config["train_batch_size"]
            ):
                batch_to_add = SampleBatch.concat_samples(self.batch_being_built)
                self.batches_to_place_on_learner.append(batch_to_add)
                self.batch_being_built = []

        for batch in batches:
            self.batch_being_built.append(batch)
            aggregate_into_larger_batch()

    def get_samples_from_workers(self) -> Dict[ActorHandle, List[SampleBatch]]:
        """Collect whatever samples are ready, keyed by the worker that made them.

        With remote workers, new `sample()` requests are launched on every
        worker that has free capacity. The ready results are returned as
        ObjectRefs, because the sampling manager was built with
        `return_object_refs=True`. Without remote workers, the local worker
        samples synchronously.
        """
        # Perform asynchronous sampling on all (remote) rollout workers.
        if self.workers.remote_workers():
            self._sampling_actor_manager.call_on_all_available(
                lambda worker: worker.sample()
            )
            sample_batches: Dict[
                ActorHandle, List[ObjectRef]
            ] = self._sampling_actor_manager.get_ready()
        else:
            # only sampling on the local worker
            sample_batches = {
                self.workers.local_worker(): [self.workers.local_worker().sample()]
            }
        return sample_batches

    def place_processed_samples_on_learner_queue(self) -> None:
        """Move pending train batches onto the learner thread's input queue.

        Batches are enqueued in FIFO order. When the queue is full, this waits
        in short timed `put` attempts instead of busy-spinning on a
        non-blocking `put`. That keeps back-pressure on the sampling loop
        without starving the learner thread of CPU.

        Raises:
            RuntimeError: If the learner thread has died while batches are
                still waiting to be enqueued (previously this looped forever).
        """
        self._counters["num_samples_added_to_queue"] = 0

        while self.batches_to_place_on_learner:
            batch = self.batches_to_place_on_learner[0]
            try:
                # Timed blocking put: wakes up as soon as the learner frees a
                # slot, but returns periodically so we can detect a dead
                # learner thread instead of hanging.
                self._learner_thread.inqueue.put(batch, block=True, timeout=0.1)
            except queue.Full:
                self._counters["num_times_learner_queue_full"] += 1
                if not self._learner_thread.is_alive():
                    raise RuntimeError("The learner thread died while training")
            else:
                self.batches_to_place_on_learner.pop(0)
                self._counters[NUM_ENV_STEPS_SAMPLED] += batch.count
                self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
                # Accumulate: the counter is reset above, so it reports the
                # total enqueued during this call, not just the last batch.
                self._counters["num_samples_added_to_queue"] += batch.count

    def process_trained_results(self) -> ResultDict:
        """Drain the learner's output queue and update the trained-step counters.

        Returns:
            The learner info built from this iteration's learner results. If
            no new results arrived, a deep copy of the learner thread's last
            `learner_info` is returned instead.

        Raises:
            RuntimeError: If the learner thread is not alive.
        """
        # Get learner outputs/stats from output queue.
        learner_infos = []
        num_env_steps_trained = 0
        num_agent_steps_trained = 0

        for _ in range(self._learner_thread.outqueue.qsize()):
            if self._learner_thread.is_alive():
                (
                    env_steps,
                    agent_steps,
                    learner_results,
                ) = self._learner_thread.outqueue.get(timeout=0.001)
                num_env_steps_trained += env_steps
                num_agent_steps_trained += agent_steps
                if learner_results:
                    learner_infos.append(learner_results)
            else:
                raise RuntimeError("The learner thread died in while training")
        if not learner_infos:
            final_learner_info = copy.deepcopy(self._learner_thread.learner_info)
        else:
            builder = LearnerInfoBuilder()
            for info in learner_infos:
                builder.add_learn_on_batch_results_multi_agent(info)
            final_learner_info = builder.finalize()

        # Update the steps trained counters.
        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
        self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
        self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained

        return final_learner_info

    def process_experiences_directly(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        """Decompress samples and pass them through the local mix-in buffer.

        Args:
            actor_to_sample_batches_refs: Per-worker lists of sample batches,
                or ObjectRefs to them. Refs are resolved with `ray.get`.

        Returns:
            The (possibly empty) list of batches replayed by the mix-in buffer.
        """
        processed_batches = []
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        if not batches:
            return processed_batches
        if isinstance(batches[0], ray.ObjectRef):
            batches = ray.get(batches)
        for batch in batches:
            batch = batch.decompress_if_needed()
            self.local_mixin_buffer.add_batch(batch)
            batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
            if batch:
                processed_batches.append(batch)
        return processed_batches

    def process_experiences_tree_aggregation(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        """Send samples to the aggregation workers and collect ready results.

        Args:
            actor_to_sample_batches_refs: Per-worker lists of sample-batch refs.

        Returns:
            Processed batches that the aggregation workers have finished so
            far. These may stem from earlier calls.
        """
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        ready_processed_batches = []
        for batch in batches:
            # `b` is passed via fn_kwargs, so the lambda does not late-bind
            # the loop variable.
            self._aggregator_actor_manager.call(
                lambda actor, b: actor.process_episodes(b), fn_kwargs={"b": batch}
            )

        waiting_processed_sample_batches: Dict[
            ActorHandle, List[ObjectRef]
        ] = self._aggregator_actor_manager.get_ready()
        for ready_sub_batches in waiting_processed_sample_batches.values():
            ready_processed_batches.extend(ready_sub_batches)

        return ready_processed_batches

    def update_workers_if_necessary(self) -> None:
        """Broadcast learner weights to workers that produced samples, if due.

        A broadcast happens at most every `broadcast_interval` calls and only
        targets workers recorded in `self.workers_that_need_updates`. The
        local worker's global vars are updated on every call.
        """
        # Only need to update workers if there are remote workers.
        global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
        self._counters["steps_since_broadcast"] += 1
        if (
            self.workers.remote_workers()
            and self._counters["steps_since_broadcast"]
            >= self.config["broadcast_interval"]
            and self.workers_that_need_updates
        ):
            weights = ray.put(self.workers.local_worker().get_weights())
            self._counters["steps_since_broadcast"] = 0
            self._learner_thread.weights_updated = False
            self._counters["num_weight_broadcasts"] += 1

            for worker in self.workers_that_need_updates:
                worker.set_weights.remote(weights, global_vars)
            self.workers_that_need_updates = set()

        # Update global vars of the local worker.
        self.workers.local_worker().set_global_vars(global_vars)

    @override(Algorithm)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handle the failures of remote sampling workers

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(
                removed_workers, remove_in_flight_requests=True
            )
            self._sampling_actor_manager.add_workers(new_workers)

    @override(Algorithm)
    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
        result = super()._compile_iteration_results(
            step_ctx=step_ctx, iteration_results=iteration_results
        )
        result = self._learner_thread.add_learner_metrics(
            result, overwrite_learner_info=False
        )
        return result
class ApexDQN(DQN):
    """Distributed DQN with sharded replay actors and a separate learner thread.

    Rollout workers sample and push their batches directly into remote replay
    actors. The driver pulls replayed batches from those actors into a learner
    thread, and sends updated priorities back to the replay actor each batch
    came from.
    """

    @override(Trainable)
    def setup(self, config: PartialAlgorithmConfigDict):
        super().setup(config)

        # Shortcut: If execution_plan, thread and buffer will be created in there.
        if self.config["_disable_execution_plan_api"] is False:
            return

        # Tag those workers (top 1/3rd indices) that we should collect episodes from
        # for metrics due to `PerWorkerEpsilonGreedy` exploration strategy.
        if self.workers.remote_workers():
            self._remote_workers_for_metrics = self.workers.remote_workers()[
                -len(self.workers.remote_workers()) // 3 :
            ]

        num_replay_buffer_shards = self.config["optimizer"]["num_replay_buffer_shards"]

        # Create copy here so that we can modify without breaking other logic
        replay_actor_config = copy.deepcopy(self.config["replay_buffer_config"])

        replay_actor_config["capacity"] = (
            self.config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
        )

        ReplayActor = ray.remote(num_cpus=0)(replay_actor_config["type"])

        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if replay_actor_config["replay_buffer_shards_colocated_with_driver"]:
            self._replay_actors = create_colocated_actors(
                actor_specs=[  # (class, args, kwargs={}, count)
                    (
                        ReplayActor,
                        None,
                        replay_actor_config,
                        num_replay_buffer_shards,
                    )
                ],
                node=platform.node(),  # localhost
            )[0]  # [0]=only one item in `actor_specs`.
        # Place replay buffer shards on any node(s).
        else:
            # Pass the config as keyword arguments, the same way the colocated
            # branch does (kwargs slot of the actor spec). A single `*` would
            # pass only the dict's keys as positional arguments.
            self._replay_actors = [
                ReplayActor.remote(**replay_actor_config)
                for _ in range(num_replay_buffer_shards)
            ]
        self._replay_actor_manager = AsyncRequestsManager(
            self._replay_actors,
            max_remote_requests_in_flight_per_worker=self.config[
                "max_requests_in_flight_per_replay_worker"
            ],
            ray_wait_timeout_s=self.config["timeout_s_replay_manager"],
        )
        self._sampling_actor_manager = AsyncRequestsManager(
            self.workers.remote_workers(),
            max_remote_requests_in_flight_per_worker=self.config[
                "max_requests_in_flight_per_sampler_worker"
            ],
            ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
        )
        self.learner_thread = LearnerThread(self.workers.local_worker())
        self.learner_thread.start()
        self.steps_since_update = defaultdict(int)
        weights = self.workers.local_worker().get_weights()
        self.curr_learner_weights = ray.put(weights)
        self.curr_num_samples_collected = 0
        self.replay_sample_batches = []
        self._num_ts_trained_since_last_target_update = 0

    @classmethod
    @override(DQN)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return ApexDQNConfig().to_dict()

    @override(DQN)
    def validate_config(self, config):
        """Reject `num_gpus` > 1, then run DQN's validation.

        Raises:
            ValueError: If `num_gpus` > 1.
        """
        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")

        # Call DQN's validation method.
        super().validate_config(config)

    @override(DQN)
    def training_step(self) -> ResultDict:
        """Run one non-blocking Ape-X iteration.

        Sample into the replay actors, sync weights to workers that sampled,
        feed replayed batches to the learner, and push back priorities.

        Returns:
            A deep copy of the learner thread's current `learner_info`.
        """
        num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers()
        worker_samples_collected = defaultdict(int)

        for worker, samples_infos in num_samples_ready_dict.items():
            for samples_info in samples_infos:
                self._counters[NUM_AGENT_STEPS_SAMPLED] += samples_info["agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += samples_info["env_steps"]
                worker_samples_collected[worker] += samples_info["agent_steps"]

        # update the weights of the workers that returned samples
        # only do this if there are remote workers (config["num_workers"] > 0)
        if self.workers.remote_workers():
            self.update_workers(worker_samples_collected)

        # trigger a sample from the replay actors and enqueue operation to the
        # learner thread.
        self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            worker_samples_collected
        )
        self.update_replay_sample_priority()

        return copy.deepcopy(self.learner_thread.learner_info)

    def get_samples_and_store_to_replay_buffers(self):
        """Sample experiences and store them in randomly chosen replay actors.

        Returns:
            A mapping from sampling worker to a list of dicts with the
            "agent_steps" and "env_steps" of each stored batch.
        """
        # in the case the num_workers = 0
        if not self.workers.remote_workers():
            with self._timers[SAMPLE_TIMER]:
                local_sampling_worker = self.workers.local_worker()
                batch = local_sampling_worker.sample()
                actor = random.choice(self._replay_actors)
                ray.get(actor.add.remote(batch))
                batch_statistics = {
                    local_sampling_worker: [
                        {
                            "agent_steps": batch.agent_steps(),
                            "env_steps": batch.env_steps(),
                        }
                    ]
                }
                return batch_statistics

        def remote_worker_sample_and_store(
            worker: RolloutWorker, replay_actors: List[ActorHandle]
        ):
            # This function is run as a remote function on sampling workers,
            # and should only be used with the RolloutWorker's apply function ever.
            # It is used to gather samples, and trigger the operation to store them to
            # replay actors from the rollout worker instead of returning the obj
            # refs for the samples to the driver process and doing the sampling
            # operation on there.
            _batch = worker.sample()
            _actor = random.choice(replay_actors)
            _actor.add.remote(_batch)
            _batch_statistics = {
                "agent_steps": _batch.agent_steps(),
                "env_steps": _batch.env_steps(),
            }
            return _batch_statistics

        # Sample and Store in the Replay Actors on the sampling workers.
        with self._timers[SAMPLE_TIMER]:
            self._sampling_actor_manager.call_on_all_available(
                remote_worker_sample_and_store,
                fn_kwargs={"replay_actors": self._replay_actors},
            )
            num_samples_ready_dict = self._sampling_actor_manager.get_ready()
        return num_samples_ready_dict

    def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int:
        """Update the remote workers that have samples ready.

        Args:
            _num_samples_ready: A mapping from ActorHandle (RolloutWorker) to
                the number of samples returned by the remote worker.

        Returns:
            The number of remote workers whose weights were updated.
        """
        max_steps_weight_sync_delay = self.config["optimizer"]["max_weight_sync_delay"]
        # Update our local copy of the weights if the learner thread has updated
        # the learner worker's weights
        if self.learner_thread.weights_updated:
            self.learner_thread.weights_updated = False
            weights = self.workers.local_worker().get_weights()
            self.curr_learner_weights = ray.put(weights)

        num_workers_updated = 0
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            for (
                remote_sampler_worker,
                num_samples_collected,
            ) in _num_samples_ready.items():
                self.steps_since_update[remote_sampler_worker] += num_samples_collected
                if (
                    self.steps_since_update[remote_sampler_worker]
                    >= max_steps_weight_sync_delay
                ):
                    remote_sampler_worker.set_weights.remote(
                        self.curr_learner_weights,
                        {"timestep": self._counters[STEPS_TRAINED_COUNTER]},
                    )
                    self.steps_since_update[remote_sampler_worker] = 0
                    # Count only actual syncs.
                    self._counters["num_weight_syncs"] += 1
                    num_workers_updated += 1
        return num_workers_updated

    def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
        self, num_samples_collected: Dict[ActorHandle, int]
    ) -> None:
        """Get samples from the replay buffer and place them on the learner queue.

        Args:
            num_samples_collected: A mapping from ActorHandle (RolloutWorker) to
                number of samples returned by the remote worker. This is used to
                implement training intensity which is the concept of triggering a
                certain amount of training based on the number of samples that have
                been collected since the last time that training was triggered.
        """

        def wait_on_replay_actors() -> None:
            """Collect finished replay-actor sample requests.

            Waiting is bounded by the replay manager's `ray_wait_timeout_s`.
            Results are appended to `self.replay_sample_batches` as
            `(replay_actor, sample_batch)` pairs.
            """
            _replay_samples_ready = self._replay_actor_manager.get_ready()
            for _replay_actor, _sample_batches in _replay_samples_ready.items():
                for _sample_batch in _sample_batches:
                    self.replay_sample_batches.append((_replay_actor, _sample_batch))

        num_samples_collected = sum(num_samples_collected.values())
        self.curr_num_samples_collected += num_samples_collected
        wait_on_replay_actors()
        if self.curr_num_samples_collected >= self.config["train_batch_size"]:
            training_intensity = int(self.config["training_intensity"] or 1)
            num_requests_to_launch = (
                self.curr_num_samples_collected / self.config["train_batch_size"]
            ) * training_intensity
            num_requests_to_launch = max(1, round(num_requests_to_launch))
            self.curr_num_samples_collected = 0
            for _ in range(num_requests_to_launch):
                self._replay_actor_manager.call(
                    lambda actor, num_items: actor.sample(num_items),
                    fn_args=[self.config["train_batch_size"]],
                )
            wait_on_replay_actors()

        # add the sample batches to the learner queue
        while self.replay_sample_batches:
            item = self.replay_sample_batches[0]
            _, sample_batch = item
            # The replay buffer returns None if it has not been filled to the
            # minimum threshold yet. Check the batch itself, not the
            # (actor, batch) tuple, which is always truthy. Such entries are
            # dropped, so they must still be popped below.
            if sample_batch is not None:
                try:
                    self.learner_thread.inqueue.put(item, timeout=0.001)
                except queue.Full:
                    break
            self.replay_sample_batches.pop(0)

    def update_replay_sample_priority(self) -> None:
        """Update the priorities of the sample batches with new priorities that
        are computed by the learner thread.

        Also updates target networks, trained-step counters and learner
        timers.

        Raises:
            RuntimeError: If the learner thread is not alive.
        """
        # A missing (or None) alpha means "no prioritization". Comparing None
        # with 0 would raise TypeError.
        use_prioritized_replay = (
            self.config["replay_buffer_config"].get("prioritized_replay_alpha") or 0
        ) > 0
        num_samples_trained_this_itr = 0
        for _ in range(self.learner_thread.outqueue.qsize()):
            if self.learner_thread.is_alive():
                (
                    replay_actor,
                    priority_dict,
                    env_steps,
                    agent_steps,
                ) = self.learner_thread.outqueue.get(timeout=0.001)
                if use_prioritized_replay:
                    replay_actor.update_priorities.remote(priority_dict)
                num_samples_trained_this_itr += env_steps
                self.update_target_networks(env_steps)
                self._counters[NUM_ENV_STEPS_TRAINED] += env_steps
                self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps
                self.workers.local_worker().set_global_vars(
                    {"timestep": self._counters[NUM_ENV_STEPS_TRAINED]}
                )
            else:
                raise RuntimeError("The learner thread died in while training")

        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_samples_trained_this_itr
        self._timers["learner_dequeue"] = self.learner_thread.queue_timer
        self._timers["learner_grad"] = self.learner_thread.grad_timer
        self._timers["learner_overall"] = self.learner_thread.overall_timer

    def update_target_networks(self, num_new_trained_samples) -> None:
        """Update the target networks."""
        self._num_ts_trained_since_last_target_update += num_new_trained_samples
        if (
            self._num_ts_trained_since_last_target_update
            >= self.config["target_network_update_freq"]
        ):
            self._num_ts_trained_since_last_target_update = 0
            with self._timers[TARGET_NET_UPDATE_TIMER]:
                to_update = self.workers.local_worker().get_policies_to_train()
                self.workers.local_worker().foreach_policy_to_train(
                    lambda p, pid: pid in to_update and p.update_target()
                )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
                STEPS_TRAINED_COUNTER
            ]

    @override(Algorithm)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handle the failures of remote sampling workers

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(
                removed_workers, remove_in_flight_requests=True
            )
            self._sampling_actor_manager.add_workers(new_workers)

    @override(Algorithm)
    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
        result = super()._compile_iteration_results(
            step_ctx=step_ctx, iteration_results=iteration_results
        )
        replay_stats = ray.get(
            self._replay_actors[0].stats.remote(self.config["optimizer"].get("debug"))
        )
        exploration_infos_list = self.workers.foreach_policy_to_train(
            lambda p, pid: {pid: p.get_exploration_state()}
        )
        exploration_infos = {}
        for info in exploration_infos_list:
            # we're guaranteed that each info has policy ids that are unique
            exploration_infos.update(info)
        other_results = {
            "exploration_infos": exploration_infos,
            "learner_queue": self.learner_thread.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }

        result["info"].update(other_results)
        return result

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[
                {
                    # Local worker + replay buffer actors.
                    # Force replay buffers to be on same node to maximize
                    # data bandwidth between buffers and the learner (driver).
                    # Replay buffer actors each contain one shard of the total
                    # replay buffer and use 1 CPU each.
                    "CPU": cf["num_cpus_for_driver"]
                    + cf["optimizer"]["num_replay_buffer_shards"],
                    "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
                }
            ]
            + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                    **cf["custom_resources_per_worker"],
                }
                for _ in range(cf["num_workers"])
            ]
            + (
                [
                    {
                        # Evaluation workers.
                        # Note: The local eval worker is located on the driver
                        # CPU.
                        "CPU": eval_config.get(
                            "num_cpus_per_worker", cf["num_cpus_per_worker"]
                        ),
                        "GPU": eval_config.get(
                            "num_gpus_per_worker", cf["num_gpus_per_worker"]
                        ),
                        **eval_config.get(
                            "custom_resources_per_worker",
                            cf["custom_resources_per_worker"],
                        ),
                    }
                    for _ in range(cf["evaluation_num_workers"])
                ]
                if cf["evaluation_interval"]
                else []
            ),
            strategy=config.get("placement_strategy", "PACK"),
        )
class AlphaStar(appo.APPO):
    """AlphaStar-style league training built on top of APPO.

    Compared to APPO, this algorithm:
    - builds and updates a league via a LeagueBuilder object (configured
      through `league_builder_config`), which is stepped after every
      `Algorithm.step()` call,
    - trains every learnable policy on its own remote policy actor, each
      paired with a replay-buffer actor (managed by `DistributedLearners`),
    - has RolloutWorkers send their samples directly to the respective
      replay-buffer actors instead of returning them to the driver.
    """

    _allow_unknown_subkeys = appo.APPO._allow_unknown_subkeys + [
        "league_builder_config",
    ]
    _override_all_subkeys_if_type_changes = (
        appo.APPO._override_all_subkeys_if_type_changes
        + [
            "league_builder_config",
        ]
    )

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        """Returns the resource bundles needed to run this algorithm.

        Bundles (in order): driver (CPU only), one per RolloutWorker (CPU
        only), one per policy-learner shard (1 CPU for the replay buffer plus
        either fractional GPUs or -- with `_fake_gpus` -- extra CPUs), and one
        per evaluation worker if `evaluation_interval` is set.

        Args:
            config: The (partial) user config; merged over the default config.

        Returns:
            A PlacementGroupFactory describing all bundles.
        """
        cf = dict(cls.get_default_config(), **config)
        # Construct a dummy LeagueBuilder, such that it gets the opportunity to
        # adjust the multiagent config, according to its setup, and we can then
        # properly infer the resources to allocate.
        from_config(cf["league_builder_config"], trainer=None, trainer_config=cf)

        max_num_policies_to_train = cf["max_num_policies_to_train"] or len(
            cf["multiagent"].get("policies_to_train") or cf["multiagent"]["policies"]
        )
        # NOTE(review): `num_learner_shards` is used with `range()` and as a
        # divisor below; a fractional `num_gpus` (< number of policies) or zero
        # trainable policies would make this fail -- confirm intended limits.
        num_learner_shards = min(
            cf["num_gpus"] or max_num_policies_to_train, max_num_policies_to_train
        )
        num_gpus_per_shard = cf["num_gpus"] / num_learner_shards
        num_policies_per_shard = max_num_policies_to_train / num_learner_shards

        fake_gpus = cf["_fake_gpus"]

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[
                {
                    # Driver (no GPUs).
                    "CPU": cf["num_cpus_for_driver"],
                }
            ]
            + [
                {
                    # RolloutWorkers (no GPUs).
                    "CPU": cf["num_cpus_per_worker"],
                }
                for _ in range(cf["num_workers"])
            ]
            + [
                {
                    # Policy learners (and Replay buffer shards).
                    # 1 CPU for the replay buffer.
                    # 1 CPU (or fractional GPU) for each learning policy.
                    "CPU": 1 + (num_policies_per_shard if fake_gpus else 0),
                    "GPU": 0 if fake_gpus else num_gpus_per_shard,
                }
                for _ in range(num_learner_shards)
            ]
            + (
                [
                    {
                        # Evaluation (remote) workers.
                        # Note: The local eval worker is located on the driver
                        # CPU or not even created iff >0 eval workers.
                        "CPU": eval_config.get(
                            "num_cpus_per_worker", cf["num_cpus_per_worker"]
                        ),
                    }
                    for _ in range(cf["evaluation_num_workers"])
                ]
                if cf["evaluation_interval"]
                else []
            ),
            # Read from the merged config so a default placement strategy is
            # honored when the user config does not set one.
            strategy=cf.get("placement_strategy", "PACK"),
        )

    @classmethod
    @override(appo.APPO)
    def get_default_config(cls) -> AlgorithmConfigDict:
        """Returns the default config dict (built from `AlphaStarConfig`)."""
        return AlphaStarConfig().to_dict()

    @override(appo.APPO)
    def validate_config(self, config: AlgorithmConfigDict):
        """Creates `self.league_builder`, then runs APPO's config validation.

        The LeagueBuilder is created first so it can build/adjust the
        multiagent part of `config` before it gets validated.
        """
        # Create the LeagueBuilder object, allowing it to build the multiagent
        # config as well.
        self.league_builder = from_config(
            config["league_builder_config"], trainer=self, trainer_config=config
        )
        super().validate_config(config)

    @override(appo.APPO)
    def setup(self, config: PartialAlgorithmConfigDict):
        """Sets up RolloutWorkers, policy-learner/replay actors and managers.

        Args:
            config: The algorithm config. `num_gpus` is temporarily capped at
                1 for the parent's setup and restored on `self.config`
                afterwards (the full GPU count is used by the learner actors).
        """
        # Call super's setup to validate config, create RolloutWorkers
        # (train and eval), etc..
        num_gpus_saved = config["num_gpus"]
        config["num_gpus"] = min(config["num_gpus"], 1)
        super().setup(config)
        self.config["num_gpus"] = num_gpus_saved

        # - Create n policy learner actors (@ray.remote-converted Policies) on
        #   one or more GPU nodes.
        # - On each such node, also locate one replay buffer shard.

        ma_cfg = self.config["multiagent"]
        local_worker = self.workers.local_worker()
        policies_to_train = local_worker.get_policies_to_train()
        # By default, set max_num_policies_to_train to the number of policy IDs
        # provided in the multiagent config.
        if self.config["max_num_policies_to_train"] is None:
            self.config["max_num_policies_to_train"] = len(policies_to_train)

        # Single CPU replay shard (co-located with GPUs so we can place the
        # policies on the same machine(s)).
        num_gpus = (
            0.01 if (self.config["num_gpus"] and not self.config["_fake_gpus"]) else 0
        )
        ReplayActor = ray.remote(
            num_cpus=1,
            num_gpus=num_gpus,
        )(MixInMultiAgentReplayBuffer)

        # Setup remote replay buffer shards and policy learner actors
        # (located on any GPU machine in the cluster):
        replay_actor_args = [
            self.config["replay_buffer_capacity"],
            self.config["replay_buffer_replay_ratio"],
        ]

        # Create a DistributedLearners utility object and set it up with
        # the initial first n learnable policies (found in the config).
        distributed_learners = DistributedLearners(
            config=self.config,
            max_num_policies_to_train=self.config["max_num_policies_to_train"],
            replay_actor_class=ReplayActor,
            replay_actor_args=replay_actor_args,
        )
        for pid, policy_spec in ma_cfg["policies"].items():
            if pid in policies_to_train:
                distributed_learners.add_policy(pid, policy_spec)

        # Store distributed_learners on all RolloutWorkers
        # so they know, to which replay shard to send samples to.

        def _set_policy_learners(worker):
            worker._distributed_learners = distributed_learners

        # The local worker needs it too: `training_step` samples on the local
        # worker when there are no remote workers (e.g. num_workers=0).
        _set_policy_learners(local_worker)
        ray.get(
            [
                w.apply.remote(_set_policy_learners)
                for w in self.workers.remote_workers()
            ]
        )

        self.distributed_learners = distributed_learners
        self._sampling_actor_manager = AsyncRequestsManager(
            self.workers.remote_workers(),
            max_remote_requests_in_flight_per_worker=self.config[
                "max_requests_in_flight_per_sampler_worker"
            ],
            ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
        )
        policy_actors = [policy_actor for _, policy_actor, _ in distributed_learners]
        self._learner_worker_manager = AsyncRequestsManager(
            workers=policy_actors,
            max_remote_requests_in_flight_per_worker=self.config[
                "max_requests_in_flight_per_learner_worker"
            ],
            ray_wait_timeout_s=self.config["timeout_s_learner_manager"],
        )

    @override(Algorithm)
    def step(self) -> ResultDict:
        """Runs one full step (train + evaluate), then one league-build step."""
        # Perform a full step (including evaluation).
        result = super().step()

        # Based on the (train + evaluate) results, perform a step of
        # league building.
        self.league_builder.build_league(result=result)

        return result

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        """Samples, triggers async policy updates, and syncs weights.

        Returns:
            Dict mapping policy ID to its (possibly averaged) train results,
            for those policies whose update results were ready this iteration.
        """
        # Trigger asynchronous rollouts on all RolloutWorkers.
        # - Rollout results are sent directly to correct replay buffer
        #   shards, instead of here (to the driver).
        with self._timers[SAMPLE_TIMER]:
            # if there are no remote workers (e.g. num_workers=0)
            if not self.workers.remote_workers():
                worker = self.workers.local_worker()
                statistics = worker.apply(self._sample_and_send_to_buffer)
                sample_results = {worker: [statistics]}
            else:
                self._sampling_actor_manager.call_on_all_available(
                    self._sample_and_send_to_buffer
                )
                sample_results = self._sampling_actor_manager.get_ready()
        # Update sample counters.
        for sample_result in sample_results.values():
            for (env_steps, agent_steps) in sample_result:
                self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
                self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

        # Trigger asynchronous training update requests on all learning
        # policies.
        with self._timers[LEARN_ON_BATCH_TIMER]:
            for pid, pol_actor, repl_actor in self.distributed_learners:
                # Policy actors added after setup (e.g. via `add_policy`) are
                # registered with the manager lazily here.
                if pol_actor not in self._learner_worker_manager.workers:
                    self._learner_worker_manager.add_workers(pol_actor)
                self._learner_worker_manager.call(
                    self._update_policy, actor=pol_actor, fn_args=[repl_actor, pid]
                )
            train_results = self._learner_worker_manager.get_ready()

        # Update trained-steps counters (`_update_policy` may return an empty
        # or falsy result if no update happened).
        for train_result in train_results.values():
            for result in train_result:
                if result and NUM_AGENT_STEPS_TRAINED in result:
                    self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                        NUM_AGENT_STEPS_TRAINED
                    ]

        # For those policies that have been updated in this iteration
        # (not all policies may have undergone an update as we are
        # requesting updates asynchronously):
        # - Gather train infos.
        # - Update weights to those remote rollout workers that contain
        #   the respective policy.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            train_infos = {}
            policy_weights = {}
            for pol_actor, policy_results in train_results.items():
                # Only average multiple results if they all share one nested
                # structure; otherwise fall back to the most recent result.
                results_have_same_structure = True
                for result1, result2 in zip(policy_results, policy_results[1:]):
                    try:
                        tree.assert_same_structure(result1, result2)
                    except (ValueError, TypeError):
                        results_have_same_structure = False
                        break
                if len(policy_results) > 1 and results_have_same_structure:
                    policy_result = tree.map_structure(
                        lambda *_args: sum(_args) / len(policy_results),
                        *policy_results,
                    )
                else:
                    policy_result = policy_results[-1]
                if policy_result:
                    pid = self.distributed_learners.get_policy_id(pol_actor)
                    train_infos[pid] = policy_result
                    policy_weights[pid] = pol_actor.get_weights.remote()

            policy_weights_ref = ray.put(policy_weights)

            global_vars = {
                "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
                "league_builder": self.league_builder.__getstate__(),
            }

            for worker in self.workers.remote_workers():
                worker.set_weights.remote(policy_weights_ref, global_vars)

        return train_infos

    @override(Algorithm)
    def add_policy(
        self,
        policy_id: PolicyID,
        policy_cls: Type[Policy],
        *,
        observation_space: Optional[gym.spaces.Space] = None,
        action_space: Optional[gym.spaces.Space] = None,
        config: Optional[PartialAlgorithmConfigDict] = None,
        policy_state: Optional[PolicyState] = None,
        **kwargs,
    ) -> Policy:
        """Adds a policy to all RolloutWorkers and, if trainable, a learner.

        A policy-learner actor is created only if `policy_id` is contained in
        `kwargs["policies_to_train"]`; a missing or None value means no
        learner actor is created.

        Returns:
            The newly added Policy (as returned by the parent's `add_policy`).
        """
        # Add the new policy to all our train- and eval RolloutWorkers
        # (including the local worker).
        new_policy = super().add_policy(
            policy_id,
            policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config,
            policy_state=policy_state,
            **kwargs,
        )

        # Do we have to create a policy-learner actor from it as well?
        # An explicit `policies_to_train=None` must not raise a TypeError.
        # NOTE(review): a callable `policies_to_train` is not supported here.
        policies_to_train = kwargs.get("policies_to_train") or []
        if policy_id in policies_to_train:
            new_policy_actor = self.distributed_learners.add_policy(
                policy_id,
                PolicySpec(
                    policy_cls,
                    new_policy.observation_space,
                    new_policy.action_space,
                    self.config,
                ),
            )
            # Set state of new policy actor, if provided.
            if policy_state is not None:
                ray.get(new_policy_actor.set_state.remote(policy_state))

        return new_policy

    @override(Algorithm)
    def cleanup(self) -> None:
        """Runs the parent's cleanup, then stops all policy/replay actors."""
        super().cleanup()
        # Stop all policy- and replay actors.
        self.distributed_learners.stop()

    @staticmethod
    def _sample_and_send_to_buffer(worker: RolloutWorker):
        """Samples on `worker` and ships per-policy batches to replay actors.

        Batches of policies without a replay actor (i.e. not trainable) are
        dropped.

        Returns:
            Tuple of (env-steps, agent-steps) contained in the sample.
        """
        # Generate a sample.
        sample = worker.sample()
        # Send the per-agent SampleBatches to the correct buffer(s),
        # depending on which policies participated in the episode.
        assert isinstance(sample, MultiAgentBatch)
        for pid, batch in sample.policy_batches.items():
            # Don't send data, if policy is not trainable.
            replay_actor, _ = worker._distributed_learners.get_replay_and_policy_actors(
                pid
            )
            if replay_actor is not None:
                ma_batch = MultiAgentBatch({pid: batch}, batch.count)
                replay_actor.add_batch.remote(ma_batch)
        # Return counts (env-steps, agent-steps).
        return sample.count, sample.agent_steps()

    @staticmethod
    def _update_policy(policy: Policy, replay_actor: ActorHandle, pid: PolicyID):
        """Runs one learning update on `policy` from `replay_actor`'s data.

        Also tracks trained steps on the policy object, periodically updates
        the target network and, if `use_kl_loss` is set, the KL coefficient.

        Returns:
            The train results of `learn_on_batch_from_replay_buffer` (returned
            unchanged, and without target/KL updates, if falsy).
        """
        # Lazily attach per-policy bookkeeping to the policy object itself.
        if not hasattr(policy, "_target_and_kl_stats"):
            policy._target_and_kl_stats = {
                LAST_TARGET_UPDATE_TS: 0,
                NUM_TARGET_UPDATES: 0,
                NUM_AGENT_STEPS_TRAINED: 0,
                TARGET_NET_UPDATE_TIMER: _Timer(),
            }

        train_results = policy.learn_on_batch_from_replay_buffer(
            replay_actor=replay_actor, policy_id=pid
        )

        if not train_results:
            return train_results

        # Update target net and KL.
        with policy._target_and_kl_stats[TARGET_NET_UPDATE_TIMER]:
            policy._target_and_kl_stats[NUM_AGENT_STEPS_TRAINED] += train_results[
                NUM_AGENT_STEPS_TRAINED
            ]
            target_update_freq = (
                policy.config["num_sgd_iter"]
                * policy.config["replay_buffer_capacity"]
                * policy.config["train_batch_size"]
            )
            cur_ts = policy._target_and_kl_stats[NUM_AGENT_STEPS_TRAINED]
            last_update = policy._target_and_kl_stats[LAST_TARGET_UPDATE_TS]

            # Update target networks on all policy learners.
            if cur_ts - last_update > target_update_freq:
                policy._target_and_kl_stats[NUM_TARGET_UPDATES] += 1
                policy._target_and_kl_stats[LAST_TARGET_UPDATE_TS] = cur_ts
                policy.update_target()

            # Also update Policy's current KL coeff.
            if policy.config["use_kl_loss"]:
                kl = train_results[LEARNER_STATS_KEY].get("kl")
                assert kl is not None, train_results
                # Make the actual `Policy.update_kl()` call.
                policy.update_kl(kl)

        return train_results

    @override(appo.APPO)
    def __getstate__(self) -> dict:
        """Returns the parent's state plus the LeagueBuilder's state."""
        state = super().__getstate__()
        state.update(
            {
                "league_builder": self.league_builder.__getstate__(),
            }
        )
        return state

    @override(appo.APPO)
    def __setstate__(self, state: dict) -> None:
        """Restores the LeagueBuilder's and the parent's state from `state`.

        The caller's `state` dict is left unmodified. The "league_builder"
        entry goes to the LeagueBuilder; the remaining entries go to the parent.
        """
        state_copy = state.copy()
        # Pop from the copy (not the caller's dict) so `state` is not mutated
        # and the parent does not receive the league-builder entry.
        self.league_builder.__setstate__(state_copy.pop("league_builder", {}))
        super().__setstate__(state_copy)