def setup(self, config: PartialTrainerConfigDict):
    """Create the replay buffer actors, learner thread and bookkeeping state.

    After delegating to the parent ``setup``, this returns early when
    ``_disable_execution_plan_api`` is ``False`` (the execution plan creates
    the learner thread and buffers itself). Otherwise it:

    * tags the trailing ``ceil(n / 3)`` remote workers for metrics collection,
    * creates ``num_replay_buffer_shards`` replay actors, each configured with
      an equal share of the total ``capacity``,
    * starts the learner thread and initializes the counters/dicts used by
      the training-iteration code path.

    Args:
        config: The (partial) trainer config forwarded to the parent class.
    """
    super().setup(config)

    # Shortcut: If execution_plan, thread and buffer will be created in there.
    if self.config["_disable_execution_plan_api"] is False:
        return

    # Tag those workers (top 1/3rd indices) that we should collect episodes
    # from for metrics due to `PerWorkerEpsilonGreedy` exploration strategy.
    # `-n // 3` floors a negative number, i.e. the slice keeps the last
    # ceil(n / 3) workers (always at least one).
    if self.workers.remote_workers():
        self._remote_workers_for_metrics = self.workers.remote_workers()[
            -len(self.workers.remote_workers()) // 3:]

    num_replay_buffer_shards = self.config["optimizer"][
        "num_replay_buffer_shards"]

    # Create copy here so that we can modify without breaking other logic.
    replay_actor_config = copy.deepcopy(self.config["replay_buffer_config"])
    # Each shard only holds its share of the total capacity.
    replay_actor_config["capacity"] = (
        self.config["replay_buffer_config"]["capacity"] //
        num_replay_buffer_shards)

    ReplayActor = ray.remote(num_cpus=0)(replay_actor_config["type"])

    # Place all replay buffer shards on the same node as the learner
    # (driver process that runs this execution plan).
    if replay_actor_config["replay_buffer_shards_colocated_with_driver"]:
        self.replay_actors = create_colocated_actors(
            actor_specs=[
                # (class, args, kwargs={}, count)
                (
                    ReplayActor,
                    None,
                    replay_actor_config,
                    num_replay_buffer_shards,
                )
            ],
            node=platform.node(),  # localhost
        )[0]  # [0]=only one item in `actor_specs`.
    # Place replay buffer shards on any node(s).
    else:
        # The config must be passed as keyword arguments (`**`), matching the
        # colocated branch above. A single `*` would unpack the dict's keys
        # (plain strings) as positional constructor arguments.
        self.replay_actors = [
            ReplayActor.remote(**replay_actor_config)
            for _ in range(num_replay_buffer_shards)
        ]

    self.learner_thread = LearnerThread(self.workers.local_worker())
    self.learner_thread.start()

    # Per-worker number of collected steps since the last weight sync.
    self.steps_since_update = defaultdict(int)
    # Object-store reference to the learner weights.
    weights = self.workers.local_worker().get_weights()
    self.curr_learner_weights = ray.put(weights)

    # In-flight remote requests, keyed by the actor they were sent to.
    self.remote_sampling_requests_in_flight: DefaultDict[
        ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
    self.remote_replay_requests_in_flight: DefaultDict[
        ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

    self.curr_num_samples_collected = 0
    self.replay_sample_batches = []
    self._num_ts_trained_since_last_target_update = 0
def apex_execution_plan(workers: WorkerSet,
                        config: dict) -> LocalIterator[dict]:
    """Assemble the APEX execution plan from three concurrent sub-pipelines.

    The sub-pipelines are: (1) rollouts stored into the replay actors (plus
    weight updates of the producing worker), (2) replayed batches enqueued
    for the learner thread, and (3) learner outputs dequeued to update
    priorities, metrics and the target network. Only items of (3) are
    emitted by the merged operator.

    Args:
        workers: The WorkerSet used for sampling and for the local learner.
        config: The trainer config dict.

    Returns:
        The metrics-reporting iterator with APEX-specific info added.
    """
    # Create a number of replay buffer actors.
    num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
    actor_ctor_args = [
        num_replay_buffer_shards,
        config["learning_starts"],
        config["buffer_size"],
        config["train_batch_size"],
        config["prioritized_replay_alpha"],
        config["prioritized_replay_beta"],
        config["prioritized_replay_eps"],
        config["multiagent"]["replay_mode"],
        config.get("replay_sequence_length", 1),
    ]
    replay_actors = create_colocated(ReplayActor, actor_ctor_args,
                                     num_replay_buffer_shards)

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    def update_prio_and_stats(item: Tuple["ActorHandle", dict, int]) -> None:
        """Push new priorities to the source actor and refresh metrics."""
        source_actor, priorities, trained = item
        source_actor.update_priorities.remote(priorities)
        shared = _get_shared_metrics()
        # The learner thread runs outside the pipeline, so the trained-steps
        # counter has to be bumped by hand.
        shared.counters[STEPS_TRAINED_COUNTER] += trained
        shared.timers["learner_dequeue"] = learner_thread.queue_timer
        shared.timers["learner_grad"] = learner_thread.grad_timer
        shared.timers["learner_overall"] = learner_thread.overall_timer

    # The following steps run concurrently.
    # (1) Generate rollouts and store them in our replay buffer actors.
    # Update the weights of the worker that generated the batch.
    rollouts = ParallelRollouts(workers, mode="async", num_async=2)
    store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
    # Weight updates are only needed when remote workers exist.
    if workers.remote_workers():
        weight_updater = UpdateWorkerWeights(
            learner_thread,
            workers,
            max_weight_sync_delay=(
                config["optimizer"]["max_weight_sync_delay"]))
        store_op = store_op.zip_with_source_actor().for_each(weight_updater)

    # (2) Read experiences from the replay buffer actors and send them to
    # the learner thread via its in-queue.
    post_fn = config.get("before_learn_on_batch")
    if not post_fn:
        # No (truthy) hook configured: pass batches through unchanged.
        post_fn = lambda b, *a: b  # noqa: E731
    replayed = Replay(actors=replay_actors, num_async=4)
    post_processed = replayed.for_each(
        lambda x: post_fn(x, workers, config))
    replay_op = post_processed.zip_with_source_actor().for_each(
        Enqueue(learner_thread.inqueue))

    # (3) Get priorities back from the learner thread and apply them to the
    # replay buffer actors.
    dequeued = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive)
    target_updater = UpdateTargetNetwork(
        workers,
        config["target_network_update_freq"],
        by_steps_trained=True)
    update_op = dequeued.for_each(update_prio_and_stats).for_each(
        target_updater)

    if config["training_intensity"]:
        # Execute (1), (2) with a fixed intensity ratio.
        rr_weights = calculate_rr_weights(config) + ["*"]
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="round_robin",
            output_indexes=[2],
            round_robin_weights=rr_weights)
    else:
        # Execute (1), (2), (3) asynchronously as fast as possible. Only
        # output items from (3) since metrics aren't available before then.
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="async",
            output_indexes=[2])

    def add_apex_metrics(result: dict) -> dict:
        """Add replay, exploration and learner info to a training result."""
        debug_flag = config["optimizer"].get("debug")
        replay_stats = ray.get(replay_actors[0].stats.remote(debug_flag))
        exploration_infos = workers.foreach_trainable_policy(
            lambda p, _: p.get_exploration_info())
        extra_info = {
            "exploration_infos": exploration_infos,
            "learner_queue": learner_thread.learner_queue_size.stats(),
            "learner": copy.deepcopy(learner_thread.stats),
            "replay_shard_0": replay_stats,
        }
        result["info"].update(extra_info)
        return result

    # Only report metrics from the workers with the lowest 1/3 of epsilons.
    all_remote = workers.remote_workers()
    tail_start = -len(workers.remote_workers()) // 3
    selected_workers = all_remote[tail_start:]

    reporting = StandardMetricsReporting(
        merged_op, workers, config, selected_workers=selected_workers)
    return reporting.for_each(add_apex_metrics)
def execution_plan(workers: WorkerSet, config: dict,
                   **kwargs) -> LocalIterator[dict]:
    """Assemble the APEX execution plan from three concurrent sub-pipelines.

    The sub-pipelines are: (1) rollouts stored into the replay actors (plus
    weight updates of the producing worker), (2) replayed batches enqueued
    for the learner thread, and (3) learner outputs dequeued to update
    priorities, metrics and the target network. Only items of (3) are
    emitted by the merged operator.

    Args:
        workers: The WorkerSet used for sampling and for the local learner.
        config: The trainer config dict.
        **kwargs: Must be empty; any extra argument fails the assertion.

    Returns:
        The metrics-reporting iterator with APEX-specific info added.
    """
    assert (
        len(kwargs) == 0
    ), "Apex execution_plan does NOT take any additional parameters"

    # Create a number of replay buffer actors.
    num_replay_buffer_shards = config["optimizer"][
        "num_replay_buffer_shards"]
    # Each shard only holds its share of the total capacity.
    buffer_size = (config["replay_buffer_config"]["capacity"] //
                   num_replay_buffer_shards)

    replay_actor_args = [
        num_replay_buffer_shards,
        config["learning_starts"],
        buffer_size,
        config["train_batch_size"],
        config["replay_buffer_config"]["prioritized_replay_alpha"],
        config["replay_buffer_config"]["prioritized_replay_beta"],
        config["replay_buffer_config"]["prioritized_replay_eps"],
        config["multiagent"]["replay_mode"],
        config["replay_buffer_config"].get("replay_sequence_length", 1),
    ]

    # Place all replay buffer shards on the same node as the learner
    # (driver process that runs this execution plan).
    if config["replay_buffer_shards_colocated_with_driver"]:
        replay_actors = create_colocated_actors(
            actor_specs=[
                # (class, args, kwargs={}, count)
                (ReplayActor, replay_actor_args, {},
                 num_replay_buffer_shards)
            ],
            node=platform.node(),  # localhost
        )[0]  # [0]=only one item in `actor_specs`.
    # Place replay buffer shards on any node(s).
    else:
        # Shards must be created via `.remote(...)`: everything below talks
        # to them as actor handles (`.stats.remote`,
        # `.update_priorities.remote`).
        replay_actors = [
            ReplayActor.remote(*replay_actor_args)
            for _ in range(num_replay_buffer_shards)
        ]

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    # Update experience priorities post learning.
    def update_prio_and_stats(
            item: Tuple[ActorHandle, dict, int, int]) -> None:
        """Push new priorities to the source actor and refresh metrics."""
        actor, prio_dict, env_count, agent_count = item
        if config.get("prioritized_replay"):
            actor.update_priorities.remote(prio_dict)
        metrics = _get_shared_metrics()
        # Manually update the steps trained counter since the learner
        # thread is executing outside the pipeline.
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
        metrics.counters[STEPS_TRAINED_COUNTER] += env_count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in one of our replay buffer
    # actors. Update the weights of the worker that generated the batch.
    rollouts = ParallelRollouts(workers, mode="async", num_async=2)
    store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        store_op = store_op.zip_with_source_actor().for_each(
            UpdateWorkerWeights(
                learner_thread,
                workers,
                max_weight_sync_delay=(
                    config["optimizer"]["max_weight_sync_delay"]),
            ))

    # (2) Read experiences from one of the replay buffer actors and send
    # to the learner thread via its in-queue.
    # Falls back to an identity function when no hook is configured.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = (
        Replay(actors=replay_actors, num_async=4)
        .for_each(lambda x: post_fn(x, workers, config))
        .zip_with_source_actor()
        .for_each(Enqueue(learner_thread.inqueue)))

    # (3) Get priorities back from learner thread and apply them to the
    # replay buffer actors.
    update_op = (
        Dequeue(learner_thread.outqueue, check=learner_thread.is_alive)
        .for_each(update_prio_and_stats)
        .for_each(
            UpdateTargetNetwork(
                workers,
                config["target_network_update_freq"],
                by_steps_trained=True)))

    if config["training_intensity"]:
        # Execute (1), (2) with a fixed intensity ratio.
        rr_weights = calculate_rr_weights(config) + ["*"]
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="round_robin",
            output_indexes=[2],
            round_robin_weights=rr_weights,
        )
    else:
        # Execute (1), (2), (3) asynchronously as fast as possible. Only
        # output items from (3) since metrics aren't available before
        # then.
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="async",
            output_indexes=[2])

    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result: dict) -> dict:
        """Add replay, exploration and learner info to a training result."""
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_policy_to_train(
            lambda p, _: p.get_exploration_state())
        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": learner_thread.learner_queue_size.stats(),
            LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
            "replay_shard_0": replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of
    # epsilons.
    selected_workers = workers.remote_workers()[
        -len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)
class ApexTrainer(DQNTrainer):
    """DQN trainer variant using sharded replay actors and a learner thread.

    Two code paths are provided: the static ``execution_plan`` (used when
    ``_disable_execution_plan_api`` is ``False``) and ``training_iteration``
    together with its helper methods (state for which is built in ``setup``).
    """

    @override(Trainable)
    def setup(self, config: PartialTrainerConfigDict):
        """Create the replay actors, learner thread and bookkeeping state.

        Returns early, right after the parent ``setup``, when
        ``_disable_execution_plan_api`` is ``False`` because the execution
        plan creates the learner thread and buffers itself.

        Args:
            config: The (partial) trainer config forwarded to the parent.
        """
        super().setup(config)

        # Shortcut: If execution_plan, thread and buffer will be created in
        # there.
        if self.config["_disable_execution_plan_api"] is False:
            return

        # Tag those workers (top 1/3rd indices) that we should collect
        # episodes from for metrics due to `PerWorkerEpsilonGreedy`
        # exploration strategy. `-n // 3` floors a negative number, so the
        # slice keeps the last ceil(n / 3) workers.
        if self.workers.remote_workers():
            self._remote_workers_for_metrics = self.workers.remote_workers()[
                -len(self.workers.remote_workers()) // 3:]

        num_replay_buffer_shards = self.config["optimizer"][
            "num_replay_buffer_shards"]
        # Each shard only holds its share of the total capacity.
        buffer_size = (self.config["replay_buffer_config"]["capacity"] //
                       num_replay_buffer_shards)

        replay_actor_args = [
            num_replay_buffer_shards,
            self.config["learning_starts"],
            buffer_size,
            self.config["train_batch_size"],
            self.config["replay_buffer_config"]["prioritized_replay_alpha"],
            self.config["replay_buffer_config"]["prioritized_replay_beta"],
            self.config["replay_buffer_config"]["prioritized_replay_eps"],
            self.config["multiagent"]["replay_mode"],
            self.config["replay_buffer_config"].get(
                "replay_sequence_length", 1),
        ]

        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if self.config["replay_buffer_shards_colocated_with_driver"]:
            self.replay_actors = create_colocated_actors(
                actor_specs=[
                    # (class, args, kwargs={}, count)
                    (ReplayActor, replay_actor_args, {},
                     num_replay_buffer_shards)
                ],
                node=platform.node(),  # localhost
            )[0]  # [0]=only one item in `actor_specs`.
        # Place replay buffer shards on any node(s).
        else:
            self.replay_actors = [
                ReplayActor.remote(*replay_actor_args)
                for _ in range(num_replay_buffer_shards)
            ]

        self.learner_thread = LearnerThread(self.workers.local_worker())
        self.learner_thread.start()

        # Per-worker number of collected steps since the last weight sync.
        self.steps_since_update = defaultdict(int)
        # Object-store reference to the learner weights.
        weights = self.workers.local_worker().get_weights()
        self.curr_learner_weights = ray.put(weights)

        # In-flight remote requests, keyed by the actor they were sent to.
        self.remote_sampling_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
        self.remote_replay_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

        self.curr_num_samples_collected = 0
        self.replay_sample_batches = []
        self._num_ts_trained_since_last_target_update = 0

    @classmethod
    @override(DQNTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the APEX default config."""
        return APEX_DEFAULT_CONFIG

    @override(DQNTrainer)
    def validate_config(self, config):
        """Reject multi-GPU configs, then run the parent's validation.

        Raises:
            ValueError: If ``config["num_gpus"]`` is greater than 1.
        """
        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
        # Call DQN's validation method.
        super().validate_config(config)
        # if config["_disable_execution_plan_api"]:
        #     if not config.get("training_intensity", 1.0) > 0:
        #         raise ValueError("training_intensity must be > 0")

    @override(Trainable)
    def training_iteration(self) -> ResultDict:
        """Run one sample -> sync -> replay -> priority-update round.

        Returns:
            A deep copy of the learner thread's current ``learner_info``.
        """
        num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers()
        worker_samples_collected = defaultdict(int)

        for worker, samples_infos in num_samples_ready_dict.items():
            for samples_info in samples_infos:
                self._counters[NUM_AGENT_STEPS_SAMPLED] += samples_info[
                    "agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += samples_info[
                    "env_steps"]
                worker_samples_collected[worker] += samples_info["agent_steps"]

        # update the weights of the workers that returned samples
        # only do this if there are remote workers (config["num_workers"] > 1)
        if self.workers.remote_workers():
            self.update_workers(worker_samples_collected)
        # trigger a sample from the replay actors and enqueue operation to the
        # learner thread.
        self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            worker_samples_collected)
        self.update_replay_sample_priority()

        return copy.deepcopy(self.learner_thread.learner_info)

    @staticmethod
    @override(DQNTrainer)
    def execution_plan(workers: WorkerSet, config: dict,
                       **kwargs) -> LocalIterator[dict]:
        """Assemble the APEX execution plan from three concurrent pipelines.

        The pipelines are: (1) rollouts stored into the replay actors (plus
        weight updates of the producing worker), (2) replayed batches
        enqueued for the learner thread, and (3) learner outputs dequeued to
        update priorities, metrics and the target network. Only items of (3)
        are emitted by the merged operator.

        Args:
            workers: The WorkerSet used for sampling and the local learner.
            config: The trainer config dict.
            **kwargs: Must be empty; any extra argument fails the assertion.

        Returns:
            The metrics-reporting iterator with APEX-specific info added.
        """
        assert (
            len(kwargs) == 0
        ), "Apex execution_plan does NOT take any additional parameters"

        # Create a number of replay buffer actors.
        num_replay_buffer_shards = config["optimizer"][
            "num_replay_buffer_shards"]
        # Each shard only holds its share of the total capacity.
        buffer_size = (config["replay_buffer_config"]["capacity"] //
                       num_replay_buffer_shards)

        replay_actor_args = [
            num_replay_buffer_shards,
            config["learning_starts"],
            buffer_size,
            config["train_batch_size"],
            config["replay_buffer_config"]["prioritized_replay_alpha"],
            config["replay_buffer_config"]["prioritized_replay_beta"],
            config["replay_buffer_config"]["prioritized_replay_eps"],
            config["multiagent"]["replay_mode"],
            config["replay_buffer_config"].get("replay_sequence_length", 1),
        ]

        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if config["replay_buffer_shards_colocated_with_driver"]:
            replay_actors = create_colocated_actors(
                actor_specs=[
                    # (class, args, kwargs={}, count)
                    (ReplayActor, replay_actor_args, {},
                     num_replay_buffer_shards)
                ],
                node=platform.node(),  # localhost
            )[0]  # [0]=only one item in `actor_specs`.
        # Place replay buffer shards on any node(s).
        else:
            # Shards must be created via `.remote(...)` (as `setup` does):
            # everything below talks to them as actor handles
            # (`.stats.remote`, `.update_priorities.remote`).
            replay_actors = [
                ReplayActor.remote(*replay_actor_args)
                for _ in range(num_replay_buffer_shards)
            ]

        # Start the learner thread.
        learner_thread = LearnerThread(workers.local_worker())
        learner_thread.start()

        # Update experience priorities post learning.
        def update_prio_and_stats(
                item: Tuple[ActorHandle, dict, int, int]) -> None:
            """Push new priorities to the source actor, refresh metrics."""
            actor, prio_dict, env_count, agent_count = item
            if config.get("prioritized_replay"):
                actor.update_priorities.remote(prio_dict)
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
            metrics.counters[STEPS_TRAINED_COUNTER] += env_count
            metrics.timers["learner_dequeue"] = learner_thread.queue_timer
            metrics.timers["learner_grad"] = learner_thread.grad_timer
            metrics.timers["learner_overall"] = learner_thread.overall_timer

        # We execute the following steps concurrently:
        # (1) Generate rollouts and store them in one of our replay buffer
        # actors. Update the weights of the worker that generated the batch.
        rollouts = ParallelRollouts(workers, mode="async", num_async=2)
        store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            store_op = store_op.zip_with_source_actor().for_each(
                UpdateWorkerWeights(
                    learner_thread,
                    workers,
                    max_weight_sync_delay=(
                        config["optimizer"]["max_weight_sync_delay"]),
                ))

        # (2) Read experiences from one of the replay buffer actors and send
        # to the learner thread via its in-queue.
        # Falls back to an identity function when no hook is configured.
        post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
        replay_op = (
            Replay(actors=replay_actors, num_async=4)
            .for_each(lambda x: post_fn(x, workers, config))
            .zip_with_source_actor()
            .for_each(Enqueue(learner_thread.inqueue)))

        # (3) Get priorities back from learner thread and apply them to the
        # replay buffer actors.
        update_op = (
            Dequeue(learner_thread.outqueue, check=learner_thread.is_alive)
            .for_each(update_prio_and_stats)
            .for_each(
                UpdateTargetNetwork(
                    workers,
                    config["target_network_update_freq"],
                    by_steps_trained=True)))

        if config["training_intensity"]:
            # Execute (1), (2) with a fixed intensity ratio.
            rr_weights = calculate_rr_weights(config) + ["*"]
            merged_op = Concurrently(
                [store_op, replay_op, update_op],
                mode="round_robin",
                output_indexes=[2],
                round_robin_weights=rr_weights,
            )
        else:
            # Execute (1), (2), (3) asynchronously as fast as possible. Only
            # output items from (3) since metrics aren't available before
            # then.
            merged_op = Concurrently(
                [store_op, replay_op, update_op],
                mode="async",
                output_indexes=[2])

        # Add in extra replay and learner metrics to the training result.
        def add_apex_metrics(result: dict) -> dict:
            """Add replay, exploration and learner info to a result."""
            replay_stats = ray.get(replay_actors[0].stats.remote(
                config["optimizer"].get("debug")))
            exploration_infos = workers.foreach_policy_to_train(
                lambda p, _: p.get_exploration_state())
            result["info"].update({
                "exploration_infos": exploration_infos,
                "learner_queue": learner_thread.learner_queue_size.stats(),
                LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
                "replay_shard_0": replay_stats,
            })
            return result

        # Only report metrics from the workers with the lowest 1/3 of
        # epsilons.
        selected_workers = workers.remote_workers()[
            -len(workers.remote_workers()) // 3:]

        return StandardMetricsReporting(
            merged_op, workers, config,
            selected_workers=selected_workers).for_each(add_apex_metrics)

    def get_samples_and_store_to_replay_buffers(self):
        """Sample new experiences and store them in the replay actors.

        Without remote workers, the local worker samples one batch which is
        synchronously added to a randomly chosen replay actor. Otherwise,
        sample-and-store requests are issued asynchronously to the remote
        workers.

        Returns:
            A mapping from worker to a list of per-batch statistics dicts
            with the keys "agent_steps" and "env_steps" (for the remote case:
            whatever `asynchronous_parallel_requests` returns for the
            requests that are ready).
        """
        # in the case the num_workers = 0
        if not self.workers.remote_workers():
            with self._timers[SAMPLE_TIMER]:
                local_sampling_worker = self.workers.local_worker()
                batch = local_sampling_worker.sample()
                actor = random.choice(self.replay_actors)
                # Block until the batch has been added.
                ray.get(actor.add_batch.remote(batch))
                batch_statistics = {
                    local_sampling_worker: [{
                        "agent_steps": batch.agent_steps(),
                        "env_steps": batch.env_steps(),
                    }]
                }
                return batch_statistics

        def remote_worker_sample_and_store(worker: RolloutWorker,
                                           replay_actors: List[ReplayActor]):
            # This function is run as a remote function on sampling workers,
            # and should only be used with the RolloutWorker's apply function
            # ever. It is used to gather samples, and trigger the operation
            # to store them to replay actors from the rollout worker instead
            # of returning the obj refs for the samples to the driver process
            # and doing the sampling operation on there.
            _batch = worker.sample()
            _actor = random.choice(replay_actors)
            # Fire-and-forget: the add is not awaited here.
            _actor.add_batch.remote(_batch)
            _batch_statistics = {
                "agent_steps": _batch.agent_steps(),
                "env_steps": _batch.env_steps(),
            }
            return _batch_statistics

        # Sample and Store in the Replay Actors on the sampling workers.
        with self._timers[SAMPLE_TIMER]:
            # Results are a mapping from ActorHandle (RolloutWorker) to the
            # batch statistics returned by `remote_worker_sample_and_store`.
            num_samples_ready_dict: Dict[
                ActorHandle, T] = asynchronous_parallel_requests(
                    remote_requests_in_flight=(
                        self.remote_sampling_requests_in_flight),
                    actors=self.workers.remote_workers(),
                    ray_wait_timeout_s=0.1,
                    max_remote_requests_in_flight_per_actor=4,
                    remote_fn=remote_worker_sample_and_store,
                    remote_kwargs=[{
                        "replay_actors": self.replay_actors
                    }] * len(self.workers.remote_workers()),
                )
        return num_samples_ready_dict

    def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int:
        """Update the remote workers that have samples ready.

        Args:
            _num_samples_ready: A mapping from ActorHandle (RolloutWorker) to
                the number of samples returned by the remote worker.

        Returns:
            The number of remote workers whose weights were updated.
        """
        max_steps_weight_sync_delay = self.config["optimizer"][
            "max_weight_sync_delay"]
        # Update our local copy of the weights if the learner thread has
        # updated the learner worker's weights.
        if self.learner_thread.weights_updated:
            self.learner_thread.weights_updated = False
            weights = self.workers.local_worker().get_weights()
            self.curr_learner_weights = ray.put(weights)

        num_workers_updated = 0
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            for (
                    remote_sampler_worker,
                    num_samples_collected,
            ) in _num_samples_ready.items():
                self.steps_since_update[
                    remote_sampler_worker] += num_samples_collected
                # Only sync once the worker has collected enough steps since
                # its last sync.
                if (self.steps_since_update[remote_sampler_worker] >=
                        max_steps_weight_sync_delay):
                    remote_sampler_worker.set_weights.remote(
                        self.curr_learner_weights,
                        {"timestep": self._counters[STEPS_TRAINED_COUNTER]},
                    )
                    self.steps_since_update[remote_sampler_worker] = 0
                    self._counters["num_weight_syncs"] += 1
                    num_workers_updated += 1
        return num_workers_updated

    def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            self, num_samples_collected: Dict[ActorHandle, int]) -> None:
        """Get samples from the replay buffer and place them on the learner queue.

        Args:
            num_samples_collected: A mapping from ActorHandle (RolloutWorker)
                to number of samples returned by the remote worker. This is
                used to implement training intensity which is the concept of
                triggering a certain amount of training based on the number
                of samples that have been collected since the last time that
                training was triggered.
        """

        def stash_ready_batches(ready: Dict[ActorHandle, T]) -> None:
            """Append all ready (replay_actor, batch) pairs to the backlog."""
            for replay_actor, sample_batches in ready.items():
                for sample_batch in sample_batches:
                    self.replay_sample_batches.append(
                        (replay_actor, sample_batch))

        def wait_on_replay_actors(timeout: float) -> None:
            """Wait for the replay actors to finish sampling for timeout seconds.

            If the timeout is None, then block on the actors indefinitely.
            """
            replay_samples_ready: Dict[ActorHandle,
                                       T] = wait_asynchronous_requests(
                                           remote_requests_in_flight=(
                                               self.
                                               remote_replay_requests_in_flight
                                           ),
                                           ray_wait_timeout_s=timeout,
                                       )
            stash_ready_batches(replay_samples_ready)

        total_new_samples = sum(num_samples_collected.values())
        self.curr_num_samples_collected += total_new_samples
        if self.curr_num_samples_collected >= self.config["train_batch_size"]:
            # Drain all outstanding replay requests before launching new ones.
            wait_on_replay_actors(None)
            training_intensity = int(self.config["training_intensity"] or 1)
            num_requests_to_launch = (
                self.curr_num_samples_collected /
                self.config["train_batch_size"]) * training_intensity
            num_requests_to_launch = max(1, round(num_requests_to_launch))
            self.curr_num_samples_collected = 0
            for _ in range(num_requests_to_launch):
                rand_actor = random.choice(self.replay_actors)
                replay_samples_ready: Dict[
                    ActorHandle, T] = asynchronous_parallel_requests(
                        remote_requests_in_flight=(
                            self.remote_replay_requests_in_flight),
                        actors=[rand_actor],
                        ray_wait_timeout_s=0.1,
                        max_remote_requests_in_flight_per_actor=(
                            num_requests_to_launch),
                        remote_fn=lambda actor: actor.replay(),
                    )
                # Harvest after every launch so that results returned by an
                # earlier call are never overwritten by a later one.
                stash_ready_batches(replay_samples_ready)
        wait_on_replay_actors(0.1)

        # add the sample batches to the learner queue
        while self.replay_sample_batches:
            try:
                item = self.replay_sample_batches[0]
                # the replay buffer returns none if it has not been filled to
                # the minimum threshold yet.
                # NOTE(review): `item` is an `(actor, batch)` tuple and thus
                # always truthy, so this check does not filter out `None`
                # batches -- confirm whether the learner thread handles them.
                if item:
                    self.learner_thread.inqueue.put(item, timeout=0.001)
                self.replay_sample_batches.pop(0)
            except queue.Full:
                # Learner queue is saturated; keep the backlog for next time.
                break

    def update_replay_sample_priority(self) -> int:
        """Update the priorities of the sample batches with new priorities that are
        computed by the learner thread.

        Returns:
            The number of samples trained by the learner thread since the
            last training iteration.

        Raises:
            RuntimeError: If the learner thread is no longer alive.
        """
        num_samples_trained_this_itr = 0
        # Only process the items that are in the out-queue right now.
        for _ in range(self.learner_thread.outqueue.qsize()):
            if not self.learner_thread.is_alive():
                raise RuntimeError("The learner thread died in while training")
            (
                replay_actor,
                priority_dict,
                env_steps,
                agent_steps,
            ) = self.learner_thread.outqueue.get(timeout=0.001)
            if self.config["prioritized_replay"]:
                replay_actor.update_priorities.remote(priority_dict)
            num_samples_trained_this_itr += env_steps
            self.update_target_networks(env_steps)
            self._counters[NUM_ENV_STEPS_TRAINED] += env_steps
            self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps
            self.workers.local_worker().set_global_vars(
                {"timestep": self._counters[NUM_ENV_STEPS_TRAINED]})

        self._counters[
            STEPS_TRAINED_THIS_ITER_COUNTER] = num_samples_trained_this_itr
        self._timers["learner_dequeue"] = self.learner_thread.queue_timer
        self._timers["learner_grad"] = self.learner_thread.grad_timer
        self._timers["learner_overall"] = self.learner_thread.overall_timer
        return num_samples_trained_this_itr

    def update_target_networks(self, num_new_trained_samples) -> None:
        """Update the target networks.

        The update only happens once at least
        ``config["target_network_update_freq"]`` samples have been trained
        on since the previous update.

        Args:
            num_new_trained_samples: Number of newly trained samples to add
                to the running total.
        """
        self._num_ts_trained_since_last_target_update += num_new_trained_samples
        if (self._num_ts_trained_since_last_target_update >=
                self.config["target_network_update_freq"]):
            self._num_ts_trained_since_last_target_update = 0
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target())
            self._counters[NUM_TARGET_UPDATES] += 1
            # NOTE(review): this reads STEPS_TRAINED_COUNTER while
            # `update_replay_sample_priority` increments
            # NUM_ENV_STEPS_TRAINED -- confirm this is the intended counter.
            self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
                STEPS_TRAINED_COUNTER]

    @override(Trainer)
    def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
        """Extend the parent's step results with APEX-specific info.

        Adds per-policy exploration state, learner-queue statistics and the
        stats of the first replay shard to ``result["info"]``.
        """
        result = super()._compile_step_results(
            step_ctx=step_ctx, step_attempt_results=step_attempt_results)
        replay_stats = ray.get(self.replay_actors[0].stats.remote(
            self.config["optimizer"].get("debug")))
        exploration_infos_list = self.workers.foreach_policy_to_train(
            lambda p, pid: {pid: p.get_exploration_state()})
        exploration_infos = {}
        for info in exploration_infos_list:
            # we're guaranteed that each info has policy ids that are unique
            exploration_infos.update(info)
        other_results = {
            "exploration_infos": exploration_infos,
            "learner_queue": self.learner_thread.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }

        result["info"].update(other_results)
        return result

    @classmethod
    @override(Trainable)
    def default_resource_request(cls, config):
        """Return the placement group factory describing required resources.

        Args:
            config: User config; merged on top of the default config.

        Returns:
            A PlacementGroupFactory with one bundle for the driver plus
            replay shards, one per rollout worker and, when
            ``evaluation_interval`` is set, one per evaluation worker.
        """
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[{
                # Local worker + replay buffer actors.
                # Force replay buffers to be on same node to maximize
                # data bandwidth between buffers and the learner (driver).
                # Replay buffer actors each contain one shard of the total
                # replay buffer and use 1 CPU each.
                "CPU": cf["num_cpus_for_driver"] +
                cf["optimizer"]["num_replay_buffer_shards"],
                "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
            }] + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                } for _ in range(cf["num_workers"])
            ] + ([
                {
                    # Evaluation workers.
                    # Note: The local eval worker is located on the driver
                    # CPU.
                    "CPU": eval_config.get("num_cpus_per_worker",
                                           cf["num_cpus_per_worker"]),
                    "GPU": eval_config.get("num_gpus_per_worker",
                                           cf["num_gpus_per_worker"]),
                } for _ in range(cf["evaluation_num_workers"])
            ] if cf["evaluation_interval"] else []),
            strategy=config.get("placement_strategy", "PACK"),
        )
class ApexTrainer(DQNTrainer):
    """APEX-DQN trainer.

    Rollout workers write their samples into a set of replay-buffer actors
    (shards). The driver requests train batches from those shards, feeds
    them to a `LearnerThread`, and sends the priorities computed by the
    learner back to the shard that produced each batch.
    """

    @override(Trainable)
    def setup(self, config: PartialTrainerConfigDict):
        """Create the replay-buffer shards, the learner thread and counters.

        Nothing is created here when `_disable_execution_plan_api` is False.

        Args:
            config: The (partial) trainer config, forwarded to the parent
                `setup()`.
        """
        super().setup(config)

        # Shortcut: If execution_plan, thread and buffer will be created in
        # there.
        if self.config["_disable_execution_plan_api"] is False:
            return

        # Tag those workers (top 1/3rd indices) that we should collect
        # episodes from for metrics due to `PerWorkerEpsilonGreedy`
        # exploration strategy.
        remote_workers = self.workers.remote_workers()
        if remote_workers:
            # `-len(...) // 3` floors towards -inf, so at least one worker
            # is always selected.
            self._remote_workers_for_metrics = remote_workers[
                -len(remote_workers) // 3:]

        num_replay_buffer_shards = self.config["optimizer"][
            "num_replay_buffer_shards"]

        # Create copy here so that we can modify without breaking other logic
        replay_actor_config = copy.deepcopy(
            self.config["replay_buffer_config"])
        # Each shard only holds its share of the total capacity.
        replay_actor_config["capacity"] = (
            self.config["replay_buffer_config"]["capacity"] //
            num_replay_buffer_shards)

        ReplayActor = ray.remote(num_cpus=0)(replay_actor_config["type"])

        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if replay_actor_config["replay_buffer_shards_colocated_with_driver"]:
            self.replay_actors = create_colocated_actors(
                actor_specs=[
                    # (class, args, kwargs={}, count)
                    (
                        ReplayActor,
                        None,
                        replay_actor_config,
                        num_replay_buffer_shards,
                    )
                ],
                node=platform.node(),  # localhost
            )[0]  # [0]=only one item in `actor_specs`.
        # Place replay buffer shards on any node(s).
        else:
            # Pass the config as keyword arguments, exactly like the
            # colocated branch does. Unpacking with a single `*` would hand
            # only the dict's keys to the constructor as positional args.
            self.replay_actors = [
                ReplayActor.remote(**replay_actor_config)
                for _ in range(num_replay_buffer_shards)
            ]

        self.learner_thread = LearnerThread(self.workers.local_worker())
        self.learner_thread.start()

        # Steps sampled per remote worker since its last weight sync.
        self.steps_since_update = defaultdict(int)
        weights = self.workers.local_worker().get_weights()
        self.curr_learner_weights = ray.put(weights)

        self.remote_sampling_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
        self.remote_replay_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

        self.curr_num_samples_collected = 0
        # Pending `(replay_actor, sample_batch)` tuples that still have to be
        # handed to the learner thread.
        self.replay_sample_batches = []
        self._num_ts_trained_since_last_target_update = 0

    @classmethod
    @override(DQNTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return `APEX_DEFAULT_CONFIG`."""
        return APEX_DEFAULT_CONFIG

    @override(DQNTrainer)
    def validate_config(self, config):
        """Reject multi-GPU configs, then run the parent's validation.

        Raises:
            ValueError: If `config["num_gpus"]` is greater than 1.
        """
        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
        # Call DQN's validation method.
        super().validate_config(config)
        # if config["_disable_execution_plan_api"]:
        #     if not config.get("training_intensity", 1.0) > 0:
        #         raise ValueError("training_intensity must be > 0")

    @override(Trainable)
    def training_iteration(self) -> ResultDict:
        """Run one sample -> store -> replay -> learn bookkeeping round.

        Returns:
            A deep copy of the learner thread's current `learner_info`.
        """
        num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers()
        worker_samples_collected = defaultdict(int)

        for worker, samples_infos in num_samples_ready_dict.items():
            for samples_info in samples_infos:
                self._counters[NUM_AGENT_STEPS_SAMPLED] += samples_info[
                    "agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += samples_info[
                    "env_steps"]
                worker_samples_collected[worker] += samples_info[
                    "agent_steps"]

        # update the weights of the workers that returned samples
        # only do this if there are remote workers
        # (config["num_workers"] > 1)
        if self.workers.remote_workers():
            self.update_workers(worker_samples_collected)

        # trigger a sample from the replay actors and enqueue operation to
        # the learner thread.
        self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            worker_samples_collected)
        self.update_replay_sample_priority()

        return copy.deepcopy(self.learner_thread.learner_info)

    def get_samples_and_store_to_replay_buffers(self):
        """Collect new samples and store them in the replay-buffer shards.

        Without remote workers, the local worker samples one batch which is
        added (blocking) to a randomly chosen shard. Otherwise, sampling and
        storing is requested asynchronously on the remote workers.

        Returns:
            A mapping from worker to a list of dicts, each with the keys
            "agent_steps" and "env_steps" of one collected batch.
        """
        # in the case the num_workers = 0
        if not self.workers.remote_workers():
            with self._timers[SAMPLE_TIMER]:
                local_sampling_worker = self.workers.local_worker()
                batch = local_sampling_worker.sample()
                actor = random.choice(self.replay_actors)
                ray.get(actor.add.remote(batch))
                batch_statistics = {
                    local_sampling_worker: [{
                        "agent_steps": batch.agent_steps(),
                        "env_steps": batch.env_steps(),
                    }]
                }
                return batch_statistics

        def remote_worker_sample_and_store(worker: RolloutWorker,
                                           replay_actors: List[ActorHandle]):
            # This function is run as a remote function on sampling workers,
            # and should only be used with the RolloutWorker's apply function
            # ever. It is used to gather samples, and trigger the operation
            # to store them to replay actors from the rollout worker instead
            # of returning the obj refs for the samples to the driver process
            # and doing the sampling operation on there.
            _batch = worker.sample()
            _actor = random.choice(replay_actors)
            # Not awaited: only the step counts travel back to the driver.
            _actor.add.remote(_batch)
            _batch_statistics = {
                "agent_steps": _batch.agent_steps(),
                "env_steps": _batch.env_steps(),
            }
            return _batch_statistics

        # Sample and Store in the Replay Actors on the sampling workers.
        with self._timers[SAMPLE_TIMER]:
            remote_workers = self.workers.remote_workers()
            # Results are a mapping from ActorHandle (RolloutWorker) to the
            # batch statistics returned by `remote_worker_sample_and_store`.
            num_samples_ready_dict: Dict[
                ActorHandle, T] = asynchronous_parallel_requests(
                    remote_requests_in_flight=self.
                    remote_sampling_requests_in_flight,
                    actors=remote_workers,
                    ray_wait_timeout_s=0.1,
                    max_remote_requests_in_flight_per_actor=4,
                    remote_fn=remote_worker_sample_and_store,
                    remote_kwargs=[{
                        "replay_actors": self.replay_actors
                    }] * len(remote_workers),
                )
        return num_samples_ready_dict

    def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]
                       ) -> int:
        """Update the remote workers that have samples ready.

        Args:
            _num_samples_ready: A mapping from ActorHandle (RolloutWorker) to
                the number of samples returned by the remote worker.

        Returns:
            The number of remote workers whose weights were updated.
        """
        max_steps_weight_sync_delay = self.config["optimizer"][
            "max_weight_sync_delay"]

        # Update our local copy of the weights if the learner thread has
        # updated the learner worker's weights
        if self.learner_thread.weights_updated:
            self.learner_thread.weights_updated = False
            weights = self.workers.local_worker().get_weights()
            self.curr_learner_weights = ray.put(weights)

        num_workers_updated = 0
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            for (
                    remote_sampler_worker,
                    num_samples_collected,
            ) in _num_samples_ready.items():
                self.steps_since_update[
                    remote_sampler_worker] += num_samples_collected
                if (self.steps_since_update[remote_sampler_worker] <
                        max_steps_weight_sync_delay):
                    continue
                # NOTE(review): this class only increments the
                # NUM_*_STEPS_TRAINED counters; confirm that
                # STEPS_TRAINED_COUNTER is maintained elsewhere.
                remote_sampler_worker.set_weights.remote(
                    self.curr_learner_weights,
                    {"timestep": self._counters[STEPS_TRAINED_COUNTER]},
                )
                self.steps_since_update[remote_sampler_worker] = 0
                self._counters["num_weight_syncs"] += 1
                num_workers_updated += 1
        return num_workers_updated

    def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            self, num_samples_collected: Dict[ActorHandle, int]) -> None:
        """Get samples from the replay buffer and place them on the learner
        queue.

        Args:
            num_samples_collected: A mapping from ActorHandle (RolloutWorker)
                to number of samples returned by the remote worker. This is
                used to implement training intensity which is the concept of
                triggering a certain amount of training based on the number
                of samples that have been collected since the last time that
                training was triggered.
        """

        def stash_ready_batches(ready: Dict[ActorHandle, T]) -> None:
            """Remember every returned batch together with its shard."""
            for replay_actor, sample_batches in ready.items():
                self.replay_sample_batches.extend(
                    (replay_actor, sample_batch)
                    for sample_batch in sample_batches)

        def wait_on_replay_actors(timeout: float) -> None:
            """Wait for the replay actors to finish sampling for timeout
            seconds.

            If the timeout is None, then block on the actors indefinitely.
            """
            stash_ready_batches(
                wait_asynchronous_requests(
                    remote_requests_in_flight=self.
                    remote_replay_requests_in_flight,
                    ray_wait_timeout_s=timeout,
                ))

        train_batch_size = self.config["train_batch_size"]
        self.curr_num_samples_collected += sum(
            num_samples_collected.values())

        if self.curr_num_samples_collected >= train_batch_size:
            wait_on_replay_actors(None)
            training_intensity = int(self.config["training_intensity"] or 1)
            num_requests_to_launch = max(
                1,
                round((self.curr_num_samples_collected / train_batch_size) *
                      training_intensity))
            self.curr_num_samples_collected = 0
            for _ in range(num_requests_to_launch):
                rand_actor = random.choice(self.replay_actors)
                stash_ready_batches(
                    asynchronous_parallel_requests(
                        remote_requests_in_flight=self.
                        remote_replay_requests_in_flight,
                        actors=[rand_actor],
                        ray_wait_timeout_s=0.1,
                        max_remote_requests_in_flight_per_actor=(
                            num_requests_to_launch),
                        remote_args=[[train_batch_size]],
                        remote_fn=lambda actor, num_items: actor.sample(
                            num_items),
                    ))

        wait_on_replay_actors(0.1)

        # add the sample batches to the learner queue
        while self.replay_sample_batches:
            item = self.replay_sample_batches[0]
            # the replay buffer returns none if it has not been filled to
            # the minimum threshold yet.
            # NOTE(review): entries appended by this method are
            # `(replay_actor, sample_batch)` tuples, which are always truthy,
            # so a None batch is still forwarded; confirm that the learner
            # thread tolerates it.
            if item:
                try:
                    self.learner_thread.inqueue.put(item, timeout=0.001)
                except queue.Full:
                    # Keep the entry; it is retried on the next call.
                    break
            self.replay_sample_batches.pop(0)

    def update_replay_sample_priority(self) -> None:
        """Update the priorities of the sample batches with new priorities
        that are computed by the learner thread.

        Raises:
            RuntimeError: If the learner thread is no longer alive while
                results are still waiting in its out-queue.
        """
        # A missing or None alpha is treated as "no prioritized replay"
        # instead of raising a TypeError on the comparison.
        alpha = self.config["replay_buffer_config"].get(
            "prioritized_replay_alpha")
        prioritized_replay = alpha is not None and alpha > 0

        num_samples_trained_this_itr = 0
        for _ in range(self.learner_thread.outqueue.qsize()):
            if not self.learner_thread.is_alive():
                raise RuntimeError("The learner thread died in while training")
            (
                replay_actor,
                priority_dict,
                env_steps,
                agent_steps,
            ) = self.learner_thread.outqueue.get(timeout=0.001)
            if prioritized_replay:
                replay_actor.update_priorities.remote(priority_dict)
            num_samples_trained_this_itr += env_steps
            self.update_target_networks(env_steps)
            self._counters[NUM_ENV_STEPS_TRAINED] += env_steps
            self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps
            self.workers.local_worker().set_global_vars(
                {"timestep": self._counters[NUM_ENV_STEPS_TRAINED]})

        self._counters[
            STEPS_TRAINED_THIS_ITER_COUNTER] = num_samples_trained_this_itr
        self._timers["learner_dequeue"] = self.learner_thread.queue_timer
        self._timers["learner_grad"] = self.learner_thread.grad_timer
        self._timers["learner_overall"] = self.learner_thread.overall_timer

    def update_target_networks(self, num_new_trained_samples) -> None:
        """Update the target networks.

        The update only happens once at least `target_network_update_freq`
        steps have been trained since the previous update.

        Args:
            num_new_trained_samples: Number of steps trained since the last
                call.
        """
        self._num_ts_trained_since_last_target_update += (
            num_new_trained_samples)
        if (self._num_ts_trained_since_last_target_update <
                self.config["target_network_update_freq"]):
            return

        self._num_ts_trained_since_last_target_update = 0
        with self._timers[TARGET_NET_UPDATE_TIMER]:
            local_worker = self.workers.local_worker()
            to_update = local_worker.get_policies_to_train()
            local_worker.foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target())
        self._counters[NUM_TARGET_UPDATES] += 1
        # NOTE(review): this class only increments the NUM_*_STEPS_TRAINED
        # counters; confirm that STEPS_TRAINED_COUNTER is maintained
        # elsewhere.
        self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
            STEPS_TRAINED_COUNTER]

    @override(Trainer)
    def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
        """Extend the parent's results with exploration, learner-queue and
        replay-shard-0 statistics (stored under `result["info"]`).
        """
        result = super()._compile_step_results(
            step_ctx=step_ctx, step_attempt_results=step_attempt_results)

        # Only the first shard is queried for statistics.
        replay_stats = ray.get(self.replay_actors[0].stats.remote(
            self.config["optimizer"].get("debug")))

        exploration_infos = {}
        for info in self.workers.foreach_policy_to_train(
                lambda p, pid: {pid: p.get_exploration_state()}):
            # we're guaranteed that each info has policy ids that are unique
            exploration_infos.update(info)

        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": self.learner_thread.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        })
        return result

    @classmethod
    @override(Trainable)
    def default_resource_request(cls, config):
        """Build the placement group describing all required resources.

        Args:
            config: User config; overrides the values of the default config.

        Returns:
            A `PlacementGroupFactory` with one bundle for the driver plus
            replay-buffer shards, one per rollout worker and, if
            `evaluation_interval` is set, one per evaluation worker.
        """
        cf = dict(cls.get_default_config(), **config)
        eval_config = cf["evaluation_config"]

        driver_bundle = {
            # Local worker + replay buffer actors.
            # Force replay buffers to be on same node to maximize
            # data bandwidth between buffers and the learner (driver).
            # Replay buffer actors each contain one shard of the total
            # replay buffer and use 1 CPU each.
            "CPU": cf["num_cpus_for_driver"] +
            cf["optimizer"]["num_replay_buffer_shards"],
            "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
        }

        rollout_bundles = [
            {
                # RolloutWorkers.
                "CPU": cf["num_cpus_per_worker"],
                "GPU": cf["num_gpus_per_worker"],
            } for _ in range(cf["num_workers"])
        ]

        eval_bundles = []
        if cf["evaluation_interval"]:
            eval_bundles = [
                {
                    # Evaluation workers.
                    # Note: The local eval worker is located on the driver
                    # CPU.
                    "CPU": eval_config.get("num_cpus_per_worker",
                                           cf["num_cpus_per_worker"]),
                    "GPU": eval_config.get("num_gpus_per_worker",
                                           cf["num_gpus_per_worker"]),
                } for _ in range(cf["evaluation_num_workers"])
            ]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[driver_bundle] + rollout_bundles + eval_bundles,
            strategy=config.get("placement_strategy", "PACK"),
        )