def training_iteration(self) -> ResultDict:
    # Shortcut.
    first_worker = self.workers.remote_workers()[0]

    # Run sampling and update steps on each worker in asynchronous fashion.
    sample_and_update_results = asynchronous_parallel_requests(
        remote_requests_in_flight=self.remote_requests_in_flight,
        actors=self.workers.remote_workers(),
        ray_wait_timeout_s=0.0,
        max_remote_requests_in_flight_per_actor=1,
        remote_fn=self._sample_and_train_torch_distributed,
    )

    # For all results collected:
    # - Update our counters and timers.
    # - Update the workers' global_vars.
    # - Build info dict using a LearnerInfoBuilder object.
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    steps_this_iter = 0
    for worker, results in sample_and_update_results.items():
        for result in results:
            steps_this_iter += result["env_steps"]
            self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
            self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
            self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
            self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
            self._timers[LEARN_ON_BATCH_TIMER].push(result["learn_on_batch_time"])
            self._timers[SAMPLE_TIMER].push(result["sample_time"])
            # Add partial learner info to builder object.
            learner_info_builder.add_learn_on_batch_results_multi_agent(
                result["info"]
            )

    # Broadcast the local set of global vars to all remote workers.
    global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
    for worker in self.workers.remote_workers():
        worker.set_global_vars.remote(global_vars)

    self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter

    # Sync down the weights from the first remote worker (only if we have
    # received some results from it).
    # As with the sync up, this is not really needed unless the user is
    # reading the local weights.
    if (
        self.config["keep_local_weights_in_sync"]
        and first_worker in sample_and_update_results
    ):
        self.workers.local_worker().set_weights(
            ray.get(first_worker.get_weights.remote())
        )

    # Return the merged learner info results.
    new_learner_info = learner_info_builder.finalize()
    if new_learner_info:
        self._curr_learner_info = new_learner_info
    return self._curr_learner_info

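# A minimal, hypothetical sketch (not the actual implementation) of the
# per-call result dict that `self._sample_and_train_torch_distributed` is
# expected to return, inferred only from the keys the loop above reads.
# The learn step shown here is a stand-in using RolloutWorker's generic
# `learn_on_batch()`, not the distributed torch update itself.
import time

def _sample_and_train_torch_distributed_sketch(worker: RolloutWorker) -> Dict[str, Any]:
    t0 = time.time()
    batch = worker.sample()
    sample_time = time.time() - t0

    t1 = time.time()
    # Stand-in for the distributed torch update; returns multi-agent info.
    info = worker.learn_on_batch(batch)
    learn_on_batch_time = time.time() - t1

    return {
        "agent_steps": batch.agent_steps(),
        "env_steps": batch.env_steps(),
        "learn_on_batch_time": learn_on_batch_time,
        "sample_time": sample_time,
        "info": info,
    }
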
def get_samples_and_store_to_replay_buffers(self):
    # Sample on the local worker if there are no remote workers
    # (num_workers == 0).
    if not self.workers.remote_workers():
        with self._timers[SAMPLE_TIMER]:
            local_sampling_worker = self.workers.local_worker()
            batch = local_sampling_worker.sample()
            actor = random.choice(self.replay_actors)
            ray.get(actor.add_batch.remote(batch))
            batch_statistics = {
                local_sampling_worker: [
                    {
                        "agent_steps": batch.agent_steps(),
                        "env_steps": batch.env_steps(),
                    }
                ]
            }
            return batch_statistics

    def remote_worker_sample_and_store(
        worker: RolloutWorker, replay_actors: List[ReplayActor]
    ):
        # This function is run as a remote function on sampling workers and
        # should only ever be used with the RolloutWorker's apply function.
        # It gathers samples and triggers the operation that stores them in
        # the replay actors directly from the rollout worker, instead of
        # returning the object refs for the samples to the driver process
        # and doing the storing operation there.
        _batch = worker.sample()
        _actor = random.choice(replay_actors)
        _actor.add_batch.remote(_batch)
        _batch_statistics = {
            "agent_steps": _batch.agent_steps(),
            "env_steps": _batch.env_steps(),
        }
        return _batch_statistics

    # Sample and store in the replay actors on the sampling workers.
    with self._timers[SAMPLE_TIMER]:
        # Results are a mapping from ActorHandle (RolloutWorker) to the
        # sample-batch statistics returned by that worker.
        num_samples_ready_dict: Dict[
            ActorHandle, T
        ] = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_sampling_requests_in_flight,
            actors=self.workers.remote_workers(),
            ray_wait_timeout_s=0.1,
            max_remote_requests_in_flight_per_actor=4,
            remote_fn=remote_worker_sample_and_store,
            remote_kwargs=[{"replay_actors": self.replay_actors}]
            * len(self.workers.remote_workers()),
        )
    return num_samples_ready_dict

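# Design note: in the remote path above, `_actor.add_batch.remote(_batch)` is
# fire-and-forget (no ray.get()), so sampling workers never block on the
# replay actors. Only the local-worker fallback blocks on the add, since with
# num_workers=0 there is no other work for the driver to overlap with.
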
def get_samples_from_workers(self) -> Dict[ActorHandle, List[SampleBatch]]:
    # Perform asynchronous sampling on all (remote) rollout workers.
    if self.workers.remote_workers():
        sample_batches: Dict[
            ActorHandle, List[ObjectRef]
        ] = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=self.workers.remote_workers(),
            ray_wait_timeout_s=self.config["sample_wait_timeout"],
            max_remote_requests_in_flight_per_actor=self.config[
                "max_sample_requests_in_flight_per_worker"
            ],
            return_result_obj_ref_ids=True,
        )
    else:
        # Only sampling on the local worker.
        sample_batches = {
            self.workers.local_worker(): [self.workers.local_worker().sample()]
        }
    return sample_batches

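# Usage sketch (hypothetical helper): because `return_result_obj_ref_ids=True`
# is passed above, the remote branch maps each ActorHandle to ObjectRefs rather
# than materialized batches, so a consumer has to ray.get() them first, while
# the local-worker fallback already holds concrete SampleBatches.
import ray

def _materialize_sample_batches(
    sample_batches: Dict[ActorHandle, list]
) -> Dict[ActorHandle, List[SampleBatch]]:
    return {
        actor: [
            ray.get(item) if isinstance(item, ray.ObjectRef) else item
            for item in items
        ]
        for actor, items in sample_batches.items()
    }
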
def process_experiences_tree_aggregation(
    self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
) -> Union[SampleBatchType, None]:
    # Flatten the per-actor lists of sample-batch refs into a single list.
    batches = [
        sample_batch_ref
        for refs_batch in actor_to_sample_batches_refs.values()
        for sample_batch_ref in refs_batch
    ]
    ready_processed_batches = []
    for batch in batches:
        # Hand each batch to a randomly chosen aggregator worker.
        aggregator = random.choice(self.aggregator_workers)
        processed_sample_batches: Dict[
            ActorHandle, List[ObjectRef]
        ] = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_aggregator_requests_in_flight,
            actors=[aggregator],
            remote_fn=lambda actor, b: actor.process_episodes(b),
            remote_kwargs=[{"b": batch}],
            ray_wait_timeout_s=self.config["aggregator_wait_timeout"],
            max_remote_requests_in_flight_per_actor=float("inf"),
        )
        for ready_sub_batches in processed_sample_batches.values():
            ready_processed_batches.extend(ready_sub_batches)

    # Also collect aggregation requests that were already in flight and have
    # become ready in the meantime.
    waiting_processed_sample_batches: Dict[
        ActorHandle, List[ObjectRef]
    ] = wait_asynchronous_requests(
        remote_requests_in_flight=self.remote_aggregator_requests_in_flight,
        ray_wait_timeout_s=self.config["aggregator_wait_timeout"],
    )
    for ready_sub_batches in waiting_processed_sample_batches.values():
        ready_processed_batches.extend(ready_sub_batches)

    return ready_processed_batches

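# A minimal, hypothetical aggregator-actor sketch, shown only to illustrate
# the `process_episodes()` call signature used by `remote_fn` above. The real
# aggregator's behavior (e.g., decompression, postprocessing) is not modeled;
# `ray` and RLlib's `SampleBatch` are assumed to be imported.
@ray.remote
class AggregatorWorkerSketch:
    def __init__(self, train_batch_size: int):
        self._buffer = []
        self._train_batch_size = train_batch_size

    def process_episodes(self, batch):
        # Accumulate incoming batches until a full train batch is ready,
        # then emit the concatenated batch and reset the buffer.
        self._buffer.append(batch)
        if sum(b.env_steps() for b in self._buffer) >= self._train_batch_size:
            out = SampleBatch.concat_samples(self._buffer)
            self._buffer = []
            return out
        return None
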
def training_iteration(self) -> ResultDict:
    # Trigger asynchronous rollouts on all RolloutWorkers.
    # - Rollout results are sent directly to the correct replay buffer
    #   shards, instead of here (to the driver).
    with self._timers[SAMPLE_TIMER]:
        sample_results = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=self.workers.remote_workers() or [self.workers.local_worker()],
            ray_wait_timeout_s=self.config["sample_wait_timeout"],
            max_remote_requests_in_flight_per_actor=2,
            remote_fn=self._sample_and_send_to_buffer,
        )
    # Update sample counters.
    for sample_result in sample_results.values():
        for (env_steps, agent_steps) in sample_result:
            self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
            self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

    # Trigger asynchronous training update requests on all learning
    # policies.
    with self._timers[LEARN_ON_BATCH_TIMER]:
        pol_actors = []
        args = []
        for pid, pol_actor, repl_actor in self.distributed_learners:
            pol_actors.append(pol_actor)
            args.append([repl_actor, pid])
        train_results = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=pol_actors,
            ray_wait_timeout_s=self.config["learn_wait_timeout"],
            max_remote_requests_in_flight_per_actor=2,
            remote_fn=self._update_policy,
            remote_args=args,
        )

    # Update training counters.
    for train_result in train_results.values():
        for result in train_result:
            if NUM_AGENT_STEPS_TRAINED in result:
                self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                    NUM_AGENT_STEPS_TRAINED
                ]

    # For those policies that have been updated in this iteration
    # (not all policies may have undergone an update, as we are
    # requesting updates asynchronously):
    # - Gather train infos.
    # - Update weights on those remote rollout workers that contain
    #   the respective policy.
    with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
        train_infos = {}
        policy_weights = {}
        for pol_actor, policy_results in train_results.items():
            # If all returned results share the same structure, average them;
            # otherwise, fall back to the most recent result.
            results_have_same_structure = True
            for result1, result2 in zip(policy_results, policy_results[1:]):
                try:
                    tree.assert_same_structure(result1, result2)
                except (ValueError, TypeError):
                    results_have_same_structure = False
                    break

            if len(policy_results) > 1 and results_have_same_structure:
                policy_result = tree.map_structure(
                    lambda *_args: sum(_args) / len(policy_results),
                    *policy_results,
                )
            else:
                policy_result = policy_results[-1]

            if policy_result:
                pid = self.distributed_learners.get_policy_id(pol_actor)
                train_infos[pid] = policy_result
                policy_weights[pid] = pol_actor.get_weights.remote()

        policy_weights_ref = ray.put(policy_weights)

        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
            "league_builder": self.league_builder.__getstate__(),
        }

        for worker in self.workers.remote_workers():
            worker.set_weights.remote(policy_weights_ref, global_vars)

    return train_infos

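# Tiny illustration of the `tree.map_structure` averaging used above (values
# hypothetical), assuming the dm-tree package: an element-wise mean over a
# list of structurally identical nested result dicts.
import tree  # dm-tree

_results = [{"loss": 2.0, "stats": {"kl": 0.1}}, {"loss": 4.0, "stats": {"kl": 0.3}}]
_mean = tree.map_structure(lambda *xs: sum(xs) / len(xs), *_results)
# -> {"loss": 3.0, "stats": {"kl": 0.2}}
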
def training_iteration(self) -> ResultDict:
    # Shortcut.
    local_worker = self.workers.local_worker()

    # Define the function executed in parallel by all RolloutWorkers to
    # collect samples + compute and return gradients (and other information).
    def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
        """Call sample() and compute_gradients() remotely on workers."""
        samples = worker.sample()
        grads, infos = worker.compute_gradients(samples)
        return {
            "grads": grads,
            "infos": infos,
            "agent_steps": samples.agent_steps(),
            "env_steps": samples.env_steps(),
        }

    # Perform rollouts and gradient calculations asynchronously.
    with self._timers[GRAD_WAIT_TIMER]:
        # Results are a mapping from ActorHandle (RolloutWorker) to their
        # returned gradient calculation results.
        async_results: Dict[ActorHandle, Dict] = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=self.workers.remote_workers(),
            ray_wait_timeout_s=0.0,
            max_remote_requests_in_flight_per_actor=1,
            remote_fn=sample_and_compute_grads,
        )

    # Loop through all fetched worker-computed gradients (if any) and apply
    # them, one by one, to the local worker's model.
    # After each apply step (one step per worker that returned some
    # gradients), update that particular worker's weights.
    global_vars = None
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for worker, result in async_results.items():
        # Apply gradients to local worker.
        with self._timers[APPLY_GRADS_TIMER]:
            local_worker.apply_gradients(result["grads"])
        self._timers[APPLY_GRADS_TIMER].push_units_processed(
            result["agent_steps"]
        )

        # Update all step counters.
        self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
        self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
        self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
        self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]

        # Create current global vars.
        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
        }

        # Synch updated weights back to the particular worker.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            weights = local_worker.get_weights(
                local_worker.get_policies_to_train()
            )
            worker.set_weights.remote(weights, global_vars)

        learner_info_builder.add_learn_on_batch_results_multi_agent(
            result["infos"]
        )

    # Update global vars of the local worker.
    if global_vars:
        local_worker.set_global_vars(global_vars)

    return learner_info_builder.finalize()

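# Design note: gradients are applied to the local model as they arrive, and
# only the producing worker's weights are refreshed right afterwards, so each
# worker may compute its next gradients on slightly stale weights. This is the
# classic asynchronous (A3C-style) update scheme rather than a synchronous,
# barrier-based one.
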
def training_iteration(self) -> ResultDict:
    # Trigger asynchronous rollouts on all RolloutWorkers.
    # - Rollout results are sent directly to the correct replay buffer
    #   shards, instead of here (to the driver).
    with self._timers[SAMPLE_TIMER]:
        sample_results = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=self.workers.remote_workers() or [self.workers.local_worker()],
            ray_wait_timeout_s=0.01,
            max_remote_requests_in_flight_per_actor=2,
            remote_fn=self._sample_and_send_to_buffer,
        )
    # Update sample counters.
    for (env_steps, agent_steps) in sample_results.values():
        self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
        self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

    # Trigger asynchronous training update requests on all learning
    # policies.
    with self._timers[LEARN_ON_BATCH_TIMER]:
        pol_actors = []
        args = []
        for pid, pol_actor, repl_actor in self.distributed_learners:
            pol_actors.append(pol_actor)
            args.append([repl_actor, pid])
        train_results = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=pol_actors,
            ray_wait_timeout_s=0.1,
            max_remote_requests_in_flight_per_actor=2,
            remote_fn=self._update_policy,
            remote_args=args,
        )

    # Update training counters.
    for result in train_results.values():
        if NUM_AGENT_STEPS_TRAINED in result:
            self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                NUM_AGENT_STEPS_TRAINED
            ]

    # For those policies that have been updated in this iteration
    # (not all policies may have undergone an update, as we are
    # requesting updates asynchronously):
    # - Gather train infos.
    # - Update weights on those remote rollout workers that contain
    #   the respective policy.
    with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
        train_infos = {}
        policy_weights = {}
        for pol_actor, policy_result in train_results.items():
            if policy_result:
                pid = self.distributed_learners.get_policy_id(pol_actor)
                train_infos[pid] = policy_result
                policy_weights[pid] = pol_actor.get_weights.remote()

        policy_weights_ref = ray.put(policy_weights)

        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
            "win_rates": self.win_rates,
        }

        for worker in self.workers.remote_workers():
            worker.set_weights.remote(policy_weights_ref, global_vars)

    return train_infos

def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
    self, num_samples_collected: Dict[ActorHandle, int]
) -> None:
    """Get samples from the replay buffer and place them on the learner queue.

    Args:
        num_samples_collected: A mapping from ActorHandle (RolloutWorker) to
            the number of samples returned by that remote worker. This is
            used to implement training intensity: triggering a certain
            amount of training based on the number of samples that have
            been collected since the last time training was triggered.
    """

    def wait_on_replay_actors(timeout: float) -> None:
        """Wait for the replay actors to finish sampling for timeout seconds.

        If the timeout is None, block on the actors indefinitely.
        """
        replay_samples_ready: Dict[ActorHandle, T] = wait_asynchronous_requests(
            remote_requests_in_flight=self.remote_replay_requests_in_flight,
            ray_wait_timeout_s=timeout,
        )
        for replay_actor, sample_batches in replay_samples_ready.items():
            for sample_batch in sample_batches:
                self.replay_sample_batches.append((replay_actor, sample_batch))

    num_samples_collected = sum(num_samples_collected.values())
    self.curr_num_samples_collected += num_samples_collected
    if self.curr_num_samples_collected >= self.config["train_batch_size"]:
        wait_on_replay_actors(None)
        training_intensity = int(self.config["training_intensity"] or 1)
        num_requests_to_launch = (
            self.curr_num_samples_collected / self.config["train_batch_size"]
        ) * training_intensity
        num_requests_to_launch = max(1, round(num_requests_to_launch))
        self.curr_num_samples_collected = 0
        for _ in range(num_requests_to_launch):
            rand_actor = random.choice(self.replay_actors)
            replay_samples_ready: Dict[
                ActorHandle, T
            ] = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_replay_requests_in_flight,
                actors=[rand_actor],
                ray_wait_timeout_s=0.1,
                max_remote_requests_in_flight_per_actor=num_requests_to_launch,
                remote_fn=lambda actor: actor.replay(),
            )
            for replay_actor, sample_batches in replay_samples_ready.items():
                for sample_batch in sample_batches:
                    self.replay_sample_batches.append(
                        (replay_actor, sample_batch)
                    )

    wait_on_replay_actors(0.1)

    # Add the sample batches to the learner queue.
    while self.replay_sample_batches:
        try:
            item = self.replay_sample_batches[0]
            # The replay buffer returns None if it has not been filled to
            # the minimum threshold yet.
            if item:
                self.learner_thread.inqueue.put(
                    self.replay_sample_batches[0], timeout=0.001
                )
                self.replay_sample_batches.pop(0)
        except queue.Full:
            break

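# Worked example of the request-count arithmetic above (numbers hypothetical):
# with train_batch_size=512, training_intensity=2, and 1280 samples collected
# since the last trigger, (1280 / 512) * 2 = 5.0, so max(1, round(5.0)) = 5
# replay requests are launched before the counter resets to 0.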