def step(self):
    """Learn on one replay batch taken from the in-queue.

    Takes one ``(ra, replay)`` item from ``self.inqueue``. If ``replay`` is
    not None, the local worker learns on it, ``self.learner_info`` is
    refreshed, and ``(ra, per-policy priorities, replay.count)`` is put on
    ``self.outqueue``. ``ra`` is passed through unchanged.
    """
    with self.overall_timer:
        # Time how long we wait for the next item.
        with self.queue_timer:
            ra, replay = self.inqueue.get()

        if replay is not None:
            priorities = {}
            with self.grad_timer:
                # The builder gives the results dict one structure no matter
                # the setup (multi-GPU, multi-agent, minibatch SGD, tf vs
                # torch).
                info_builder = LearnerInfoBuilder(num_devices=1)
                per_policy_results = self.local_worker.learn_on_batch(replay)
                for policy_id, policy_results in per_policy_results.items():
                    info_builder.add_learn_on_batch_results(
                        policy_results, policy_id)
                    td_error = policy_results["td_error"]
                    policy_batch = replay.policy_batches[policy_id]
                    # Switch off auto-conversion from numpy to torch/tf
                    # tensors for the indices. This may lead to errors
                    # when sent to the buffer for processing
                    # (may get manipulated if they are part of a tensor).
                    policy_batch.set_get_interceptor(None)
                    priorities[policy_id] = (
                        policy_batch.get("batch_indexes"),
                        td_error,
                    )
                self.learner_info = info_builder.finalize()
                self.grad_timer.push_units_processed(replay.count)
            self.outqueue.put((ra, priorities, replay.count))

        self.learner_queue_size.push(self.inqueue.qsize())
        self.weights_updated = True
        # Falls back to 0 when there was no batch (or no count).
        units_processed = (replay and replay.count) or 0
        self.overall_timer.push_units_processed(units_processed)
def step(self) -> Optional[_NextValueNotReady]:
    """Learn on one batch taken from the minibatch buffer.

    Returns:
        A ``_NextValueNotReady`` instance if the minibatch buffer raised
        ``queue.Empty``; otherwise None, after the batch has been learned on
        and ``(batch.count, learner_stats)`` has been put on
        ``self.outqueue``.
    """
    with self.queue_timer:
        try:
            batch, _ = self.minibatch_buffer.get()
        except queue.Empty:
            return _NextValueNotReady()

    with self.grad_timer:
        # The builder gives the results dict one structure no matter the
        # setup (multi-GPU, multi-agent, minibatch SGD, tf vs torch).
        info_builder = LearnerInfoBuilder(num_devices=1)
        per_policy_results = self.local_worker.learn_on_batch(batch)
        for policy_id, policy_results in per_policy_results.items():
            info_builder.add_learn_on_batch_results(policy_results, policy_id)
        self.learner_info = info_builder.finalize()

        # Keep only the stats sub-dict of each policy for the out-queue.
        learner_stats = {}
        for policy_id, policy_info in self.learner_info.items():
            learner_stats[policy_id] = policy_info[LEARNER_STATS_KEY]
        self.weights_updated = True

    self.num_steps += 1
    self.outqueue.put((batch.count, learner_stats))
    self.learner_queue_size.push(self.inqueue.qsize())
def process_trained_results(self) -> ResultDict:
    """Drain the learner thread's output queue and merge what it produced.

    Each queued item is an ``(env_steps, agent_steps, learner_results)``
    tuple. The step counts are summed into the trained-steps counters and the
    non-empty learner results are merged with a ``LearnerInfoBuilder``.

    Returns:
        The merged learner info, or a deep copy of the learner thread's
        current ``learner_info`` if no new results were queued.

    Raises:
        RuntimeError: If the learner thread is found not alive while queued
            items are being processed.
    """
    learner_infos = []
    num_env_steps_trained = 0
    num_agent_steps_trained = 0

    # Only drain as many items as are queued right now.
    for _ in range(self._learner_thread.outqueue.qsize()):
        if not self._learner_thread.is_alive():
            raise RuntimeError("The learner thread died while training")
        (
            env_steps,
            agent_steps,
            learner_results,
        ) = self._learner_thread.outqueue.get(timeout=0.001)
        num_env_steps_trained += env_steps
        num_agent_steps_trained += agent_steps
        if learner_results:
            learner_infos.append(learner_results)

    if not learner_infos:
        # Nothing new since the last call: reuse the thread's latest info.
        final_learner_info = copy.deepcopy(self._learner_thread.learner_info)
    else:
        builder = LearnerInfoBuilder()
        for info in learner_infos:
            builder.add_learn_on_batch_results_multi_agent(info)
        final_learner_info = builder.finalize()

    # Update the steps trained counters.
    self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
    self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained

    return final_learner_info
def do_minibatch_sgd(
    samples,
    policies,
    local_worker,
    num_sgd_iter,
    sgd_minibatch_size,
    standardize_fields,
):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for optimization.
        standardize_fields (list): List of sample field names that should be
            normalized prior to optimization. The fields are overwritten in
            place on each policy's batch.

    Returns:
        The dict produced by ``LearnerInfoBuilder.finalize()`` after every
        per-minibatch result has been added to the builder.

    Raises:
        ValueError: If a policy is recurrent and its configured
            ``max_seq_len`` is larger than ``sgd_minibatch_size``.
    """
    # Handle everything as if multi-agent.
    samples = samples.as_multi_agent()

    # Use LearnerInfoBuilder as a unified way to build the final
    # results dict from `learn_on_loaded_batch` call(s).
    # This makes sure results dicts always have the same structure
    # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
    # tf vs torch).
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for policy_id, policy in policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        # Check to make sure that the sgd_minibatch_size is not smaller
        # than max_seq_len otherwise this will cause indexing errors while
        # performing sgd when using a RNN or Attention model
        if (
            policy.is_recurrent()
            and policy.config["model"]["max_seq_len"] > sgd_minibatch_size
        ):
            raise ValueError(
                "`sgd_minibatch_size` ({}) cannot be smaller than "
                "`max_seq_len` ({}).".format(
                    sgd_minibatch_size, policy.config["model"]["max_seq_len"]
                )
            )

        for _ in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                # Wrap the minibatch as a one-policy multi-agent batch and
                # pick this policy's results out of the returned dict.
                results = local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch}, minibatch.count)
                )[policy_id]
                learner_info_builder.add_learn_on_batch_results(
                    results, policy_id
                )

    learner_info = learner_info_builder.finalize()
    return learner_info
def step(self) -> None:
    """Learn on one pre-loaded buffer using ``self.policy``.

    Takes a ``(buffer_idx, released)`` pair from the ready-buffer queue,
    learns on that buffer, stores the result in ``self.learner_info`` and
    puts ``(number of samples loaded into the buffer, learner stats)`` on
    ``self.outqueue``.
    """
    assert self.loader_thread.is_alive()

    with self.load_wait_timer:
        buffer_idx, released = self.ready_tower_stacks_buffer.get()

    with self.grad_timer:
        # The builder gives the results dict one structure no matter the
        # setup (multi-GPU, multi-agent, minibatch SGD, tf vs torch).
        info_builder = LearnerInfoBuilder(
            num_devices=len(self.policy.devices))
        policy_results = self.policy.learn_on_loaded_batch(
            offset=0, buffer_index=buffer_idx)
        info_builder.add_learn_on_batch_results(policy_results)
        self.learner_info = info_builder.finalize()
        default_policy_stats = self.learner_info[DEFAULT_POLICY_ID][
            LEARNER_STATS_KEY]
        learner_stats = {DEFAULT_POLICY_ID: default_policy_stats}
        self.weights_updated = True

    # Hand the buffer index back to the idle queue when flagged as released.
    if released:
        self.idle_tower_stacks.put(buffer_idx)

    num_loaded = self.policy.get_num_samples_loaded_into_buffer(buffer_idx)
    self.outqueue.put((num_loaded, learner_stats))
    self.learner_queue_size.push(self.inqueue.qsize())
def training_iteration(self) -> ResultDict:
    """Collect finished sample-and-train results from the remote workers.

    Issues asynchronous ``_sample_and_train_torch_distributed`` requests to
    all remote workers, folds whatever results are returned into the
    counters, timers and a ``LearnerInfoBuilder``, broadcasts the new
    ``timestep`` global var, and optionally syncs the local worker's weights
    from the first remote worker.

    Returns:
        The newly merged learner info if any was produced; otherwise the
        previously stored ``self._curr_learner_info``.
    """
    # Fetch the remote worker list once; it is used for the request, the
    # broadcast and the weight sync below.
    remote_workers = self.workers.remote_workers()
    first_worker = remote_workers[0]

    # Run sampling and update steps on each worker in asynchronous fashion.
    sample_and_update_results = asynchronous_parallel_requests(
        remote_requests_in_flight=self.remote_requests_in_flight,
        actors=remote_workers,
        ray_wait_timeout_s=0.0,
        max_remote_requests_in_flight_per_actor=1,  # 2
        remote_fn=self._sample_and_train_torch_distributed,
    )

    # For all results collected:
    # - Update our counters and timers.
    # - Update the worker's global_vars.
    # - Build info dict using a LearnerInfoBuilder object.
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    steps_this_iter = 0
    # Only the per-worker result lists are needed here, not the worker keys.
    for results in sample_and_update_results.values():
        for result in results:
            steps_this_iter += result["env_steps"]
            self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
            self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
            self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
            self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
            self._timers[LEARN_ON_BATCH_TIMER].push(
                result["learn_on_batch_time"]
            )
            self._timers[SAMPLE_TIMER].push(result["sample_time"])
            # Add partial learner info to builder object.
            learner_info_builder.add_learn_on_batch_results_multi_agent(
                result["info"]
            )

    # Broadcast the local set of global vars.
    global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
    for remote_worker in remote_workers:
        remote_worker.set_global_vars.remote(global_vars)

    self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter

    # Sync down the weights from 1st remote worker (only if we have received
    # some results from it).
    # As with the sync up, this is not really needed unless the user is
    # reading the local weights.
    if (
        self.config["keep_local_weights_in_sync"]
        and first_worker in sample_and_update_results
    ):
        self.workers.local_worker().set_weights(
            ray.get(first_worker.get_weights.remote())
        )

    # Return the merged learner info; keep the previous one if this
    # iteration produced nothing new.
    new_learner_info = learner_info_builder.finalize()
    if new_learner_info:
        self._curr_learner_info = new_learner_info
    return self._curr_learner_info
def training_step(self) -> ResultDict:
    """Collect finished sample-and-train results via the worker manager.

    Schedules ``_sample_and_train_torch_distributed`` on all available
    workers, folds the ready results into the counters, timers and a
    ``LearnerInfoBuilder``, broadcasts the new ``timestep`` global var, and
    optionally syncs the local worker's weights from the first remote worker.

    Returns:
        The newly merged learner info if any was produced; otherwise the
        previously stored ``self._curr_learner_info``.
    """
    # Fetch the remote worker list once; it is used for both the broadcast
    # and the weight sync below.
    remote_workers = self.workers.remote_workers()
    first_worker = remote_workers[0]

    self._ddppo_worker_manager.call_on_all_available(
        self._sample_and_train_torch_distributed
    )
    sample_and_update_results = self._ddppo_worker_manager.get_ready()

    # For all results collected:
    # - Update our counters and timers.
    # - Update the worker's global_vars.
    # - Build info dict using a LearnerInfoBuilder object.
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    # Only the per-worker result lists are needed here, not the worker keys.
    for results in sample_and_update_results.values():
        for result in results:
            self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
            self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
            self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
            self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
            self._timers[LEARN_ON_BATCH_TIMER].push(
                result["learn_on_batch_time"]
            )
            self._timers[SAMPLE_TIMER].push(result["sample_time"])
            # Add partial learner info to builder object.
            learner_info_builder.add_learn_on_batch_results_multi_agent(
                result["info"]
            )

    # Broadcast the local set of global vars.
    global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
    for remote_worker in remote_workers:
        remote_worker.set_global_vars.remote(global_vars)

    # Sync down the weights from 1st remote worker (only if we have received
    # some results from it).
    # As with the sync up, this is not really needed unless the user is
    # reading the local weights.
    if (
        self.config["keep_local_weights_in_sync"]
        and first_worker in sample_and_update_results
    ):
        self.workers.local_worker().set_weights(
            ray.get(first_worker.get_weights.remote())
        )

    # Return the merged learner info; keep the previous one if this step
    # produced nothing new.
    new_learner_info = learner_info_builder.finalize()
    if new_learner_info:
        self._curr_learner_info = new_learner_info
    return self._curr_learner_info
def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Run minibatch SGD over the given samples, policy by policy.

    Args:
        samples (SampleBatch): Batch of samples to optimize. A plain
            SampleBatch is wrapped as a MultiAgentBatch under
            DEFAULT_POLICY_ID.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for optimization.
        standardize_fields (list): List of sample field names that should be
            normalized prior to optimization.

    Returns:
        The dict produced by ``LearnerInfoBuilder.finalize()`` after every
        per-minibatch result has been added to the builder.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    # The builder gives the results dict one structure no matter the setup
    # (multi-GPU, multi-agent, minibatch SGD, tf vs torch).
    info_builder = LearnerInfoBuilder(num_devices=1)

    for policy_id in policies.keys():
        # Policies without any data in this batch are skipped.
        if policy_id not in samples.policy_batches:
            continue

        policy_batch = samples.policy_batches[policy_id]
        for field_name in standardize_fields:
            policy_batch[field_name] = standardized(policy_batch[field_name])

        for _ in range(num_sgd_iter):
            for minibatch in minibatches(policy_batch, sgd_minibatch_size):
                single_policy_batch = MultiAgentBatch(
                    {policy_id: minibatch}, minibatch.count)
                all_results = local_worker.learn_on_batch(single_policy_batch)
                info_builder.add_learn_on_batch_results(
                    all_results[policy_id], policy_id)

    return info_builder.finalize()
def step(self) -> None:
    """Learn on one pre-loaded buffer for every policy that is to be trained.

    Takes a ``(buffer_idx, released)`` pair from the ready-buffer queue,
    calls ``learn_on_loaded_batch`` on each trainable policy in
    ``self.policy_map``, stores the merged result in ``self.learner_info``
    and puts ``(env-steps, agent-steps, learner info)`` on ``self.outqueue``.
    Both step counts are the summed number of samples loaded into the buffer
    across the trained policies.
    """
    assert self.loader_thread.is_alive()
    with self.load_wait_timer:
        buffer_idx, released = self.ready_tower_stacks_buffer.get()

    get_num_samples_loaded_into_buffer = 0
    with self.grad_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=len(self.devices))

        for pid in self.policy_map.keys():
            # Not a policy-to-train.
            if not self.local_worker.is_policy_to_train(pid):
                continue
            policy = self.policy_map[pid]
            policy_results = policy.learn_on_loaded_batch(
                offset=0, buffer_index=buffer_idx
            )
            # Pass the policy id so the results are recorded for this
            # policy rather than under the builder's default policy id.
            learner_info_builder.add_learn_on_batch_results(policy_results, pid)
            self.weights_updated = True
            get_num_samples_loaded_into_buffer += (
                policy.get_num_samples_loaded_into_buffer(buffer_idx)
            )

        self.learner_info = learner_info_builder.finalize()

    if released:
        self.idle_tower_stacks.put(buffer_idx)

    # Put tuple: env-steps, agent-steps, and learner info into the queue.
    self.outqueue.put(
        (
            get_num_samples_loaded_into_buffer,
            get_num_samples_loaded_into_buffer,
            self.learner_info,
        )
    )
    self.learner_queue_size.push(self.inqueue.qsize())
def process_trained_results(self) -> ResultDict:
    """Drain the learner thread's output queue and merge what it produced.

    Returns:
        The merged learner info, or a deep copy of the learner thread's
        current ``learner_info`` if no new results were queued.

    Raises:
        RuntimeError: If the learner thread is found not alive while queued
            items are being processed.
    """
    learner_thread = self._learner_thread
    collected_infos = []
    env_steps_total = 0
    agent_steps_total = 0

    # Process as many items as are queued right now.
    for _ in range(learner_thread.outqueue.qsize()):
        if not learner_thread.is_alive():
            raise RuntimeError("The learner thread died while training")
        env_steps, agent_steps, learner_results = learner_thread.outqueue.get(
            timeout=0.001)
        env_steps_total += env_steps
        agent_steps_total += agent_steps
        if learner_results:
            collected_infos.append(learner_results)

    if collected_infos:
        # Accumulate learner stats using the `LearnerInfoBuilder` utility.
        builder = LearnerInfoBuilder()
        for learner_info in collected_infos:
            builder.add_learn_on_batch_results_multi_agent(learner_info)
        final_learner_info = builder.finalize()
    else:
        # Nothing new happened since last time, use the same learner stats.
        final_learner_info = copy.deepcopy(learner_thread.learner_info)

    # Update the steps trained counters.
    self._counters[NUM_ENV_STEPS_TRAINED] += env_steps_total
    self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps_total

    return final_learner_info
def multi_gpu_train_one_step(trainer, train_batch) -> Dict:
    """Multi-GPU version of train_one_step.

    Uses the policies' `load_batch_into_buffer` and `learn_on_loaded_batch`
    methods to be more efficient wrt CPU/GPU data transfers. For example,
    when doing multiple passes through a train batch (e.g. for PPO) using
    `config.num_sgd_iter`, the actual train batch is only split once and
    loaded once into the GPU(s).

    Examples:
        >>> from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
        >>> trainer = [...] # doctest: +SKIP
        >>> train_batch = synchronous_parallel_sample(trainer.workers) # doctest: +SKIP
        >>> # This trains the policy on one batch.
        >>> results = multi_gpu_train_one_step(trainer, train_batch) # doctest: +SKIP
        {"default_policy": ...}

    Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as
    well as the LOAD_BATCH_TIMER and LEARN_ON_BATCH_TIMER timers of the
    `trainer` object.
    """
    config = trainer.config
    local_worker = trainer.workers.local_worker()
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get(
        "sgd_minibatch_size", config["train_batch_size"])

    # Number of devices in use: the GPU count rounded up, or 1 for CPU.
    num_devices = int(math.ceil(config["num_gpus"] or 1))

    # Round the total batch size down to a multiple of the device count.
    per_device_batch_size = sgd_minibatch_size // num_devices
    batch_size = per_device_batch_size * num_devices
    assert batch_size % num_devices == 0
    assert batch_size >= num_devices, "Batch size too small!"

    # Treat single-agent batches like multi-agent ones.
    train_batch = train_batch.as_multi_agent()

    # Phase 1: load each trainable policy's batch into its buffer 0.
    load_timer = trainer._timers[LOAD_BATCH_TIMER]
    with load_timer:
        num_loaded_samples = {}
        for policy_id, policy_batch in train_batch.policy_batches.items():
            if not local_worker.is_policy_to_train(policy_id, train_batch):
                continue

            # Some columns may be compressed.
            policy_batch.decompress_if_needed()

            # Buffer index 0 is used for the entire train batch.
            target_policy = local_worker.policy_map[policy_id]
            num_loaded_samples[policy_id] = (
                target_policy.load_batch_into_buffer(
                    policy_batch, buffer_index=0))

    # Phase 2: minibatch SGD over the loaded data.
    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # The builder gives the results dict one structure no matter the
        # setup (multi-GPU, multi-agent, minibatch SGD, tf vs torch).
        info_builder = LearnerInfoBuilder(num_devices=num_devices)

        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = local_worker.policy_map[policy_id]
            num_batches = max(
                1, int(samples_per_device) // int(per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(num_sgd_iter):
                # Visit the minibatches in a fresh random order each epoch.
                for shuffled_index in np.random.permutation(num_batches):
                    # The first argument is the offset of this minibatch
                    # within the pre-loaded train batch.
                    results = policy.learn_on_loaded_batch(
                        shuffled_index * per_device_batch_size,
                        buffer_index=0)
                    info_builder.add_learn_on_batch_results(
                        results, policy_id)

        # Tower reduce and finalize results.
        learner_info = info_builder.finalize()

    load_timer.push_units_processed(train_batch.count)
    learn_timer.push_units_processed(train_batch.count)

    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    return learner_info
def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    """Load the batch into the policy buffers and run minibatch SGD on it.

    Also updates the shared metrics (timers, trained-step counters, learner
    info) and, if there are remote workers, sends them the local worker's
    weights together with the global vars.

    Returns:
        A tuple of the (multi-agent) input batch and the finalized learner
        info.
    """
    _check_sample_batch_type(samples)

    # Treat single-agent batches like multi-agent ones.
    samples = samples.as_multi_agent()

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Phase 1: load each trainable policy's batch into its buffer 0.
    with load_timer:
        num_loaded_samples = {}
        for policy_id, policy_batch in samples.policy_batches.items():
            if not self.local_worker.is_policy_to_train(policy_id, samples):
                continue

            # Some columns may be compressed.
            policy_batch.decompress_if_needed()

            # Buffer index 0 is used for the entire train batch.
            target_policy = self.local_worker.policy_map[policy_id]
            num_loaded_samples[policy_id] = (
                target_policy.load_batch_into_buffer(
                    policy_batch, buffer_index=0))

    # Phase 2: minibatch SGD over the loaded data.
    with learn_timer:
        # The builder gives the results dict one structure no matter the
        # setup (multi-GPU, multi-agent, minibatch SGD, tf vs torch).
        info_builder = LearnerInfoBuilder(num_devices=len(self.devices))

        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = self.local_worker.policy_map[policy_id]
            num_batches = max(
                1,
                int(samples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(self.num_sgd_iter):
                # Visit the minibatches in a fresh random order each epoch.
                for shuffled_index in np.random.permutation(num_batches):
                    # The first argument is the offset of this minibatch
                    # within the pre-loaded train batch.
                    results = policy.learn_on_loaded_batch(
                        shuffled_index * self.per_device_batch_size,
                        buffer_index=0,
                    )
                    info_builder.add_learn_on_batch_results(
                        results, policy_id)

        # Tower reduce and finalize results.
        learner_info = info_builder.finalize()

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    counters = metrics.counters
    counters[STEPS_TRAINED_COUNTER] += samples.count
    counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
    counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = learner_info

    # Push the freshly trained weights to the remote workers, if any.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            local_weights = self.workers.local_worker().get_weights(
                self.local_worker.get_policies_to_train())
            weights = ray.put(local_weights)
            for remote_worker in self.workers.remote_workers():
                remote_worker.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, learner_info
def training_iteration(self) -> ResultDict:
    """Apply worker-computed gradients to the local worker, one by one.

    Remote workers sample and compute gradients asynchronously. For every
    worker that returned a result, its gradients are applied to the local
    worker, the step counters are updated, and the local worker's current
    weights are sent back to that worker.

    Returns:
        The dict produced by ``LearnerInfoBuilder.finalize()`` from the
        ``"infos"`` entries of all results received in this call.
    """
    local_worker = self.workers.local_worker()

    # Runs on each RolloutWorker: collect samples, compute gradients, and
    # return them together with the step counts.
    def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
        """Sample on the worker and compute gradients from those samples."""
        samples = worker.sample()
        grads, infos = worker.compute_gradients(samples)
        return {
            "grads": grads,
            "infos": infos,
            "agent_steps": samples.agent_steps(),
            "env_steps": samples.env_steps(),
        }

    # Perform rollouts and gradient calculations asynchronously.
    with self._timers[GRAD_WAIT_TIMER]:
        # Maps each ActorHandle (RolloutWorker) to its returned result.
        async_results: Dict[ActorHandle, Dict] = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=self.workers.remote_workers(),
            ray_wait_timeout_s=0.0,
            max_remote_requests_in_flight_per_actor=1,
            remote_fn=sample_and_compute_grads,
        )

    global_vars = None
    info_builder = LearnerInfoBuilder(num_devices=1)
    apply_timer = self._timers[APPLY_GRADS_TIMER]
    counters = self._counters

    for worker, result in async_results.items():
        # Apply this worker's gradients to the local worker.
        with apply_timer:
            local_worker.apply_gradients(result["grads"])
        agent_steps = result["agent_steps"]
        apply_timer.push_units_processed(agent_steps)

        # Update all step counters.
        env_steps = result["env_steps"]
        counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps
        counters[NUM_ENV_STEPS_SAMPLED] += env_steps
        counters[NUM_AGENT_STEPS_TRAINED] += agent_steps
        counters[NUM_ENV_STEPS_TRAINED] += env_steps

        # Current global vars, based on the updated sampled-steps counter.
        global_vars = {"timestep": counters[NUM_AGENT_STEPS_SAMPLED]}

        # Send the updated weights back to the worker that just reported.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            weights = local_worker.get_weights(
                local_worker.get_policies_to_train())
            worker.set_weights.remote(weights, global_vars)

        info_builder.add_learn_on_batch_results_multi_agent(result["infos"])

    # Update global vars of the local worker (only if any result arrived).
    if global_vars:
        local_worker.set_global_vars(global_vars)

    return info_builder.finalize()
def multi_gpu_train_one_step(trainer, train_batch) -> Dict:
    """Train on one batch using the policies' pre-loaded device buffers.

    Each trainable policy's batch is loaded once into buffer 0 via
    `load_batch_into_buffer`; minibatch SGD then runs on the loaded data via
    `learn_on_loaded_batch`. Afterwards, the local worker's weights are sent
    to all remote workers (if any).

    Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as
    well as the LOAD_BATCH_TIMER, LEARN_ON_BATCH_TIMER and (when there are
    remote workers) WORKER_UPDATE_TIMER timers of the `trainer` object.

    Args:
        trainer: Object providing `config`, `workers`, `_timers` and
            `_counters`.
        train_batch: The batch to train on; converted with
            `as_multi_agent()`.

    Returns:
        The dict produced by ``LearnerInfoBuilder.finalize()`` after every
        per-minibatch result has been added to the builder.
    """
    config = trainer.config
    workers = trainer.workers
    local_worker = workers.local_worker()
    # The number of SGD epochs lives under `num_sgd_iter`. The misspelled
    # `sgd_num_iter` key read previously is kept only as a fallback.
    num_sgd_iter = config.get("num_sgd_iter", config.get("sgd_num_iter", 1))
    sgd_minibatch_size = config.get("sgd_minibatch_size", config["train_batch_size"])

    # Determine the number of devices (GPUs or 1 CPU) we use.
    num_devices = int(math.ceil(config["num_gpus"] or 1))

    # Make sure total batch size is dividable by the number of devices.
    # Batch size per tower.
    per_device_batch_size = sgd_minibatch_size // num_devices
    # Total batch size.
    batch_size = per_device_batch_size * num_devices
    assert batch_size % num_devices == 0
    assert batch_size >= num_devices, "Batch size too small!"

    # Handle everything as if multi-agent.
    train_batch = train_batch.as_multi_agent()

    # Load data into GPUs.
    load_timer = trainer._timers[LOAD_BATCH_TIMER]
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in train_batch.policy_batches.items():
            # Not a policy-to-train.
            if not local_worker.is_policy_to_train(policy_id, train_batch):
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffers, if we are training
            # asynchronously.
            num_loaded_samples[policy_id] = local_worker.policy_map[
                policy_id
            ].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=num_devices)

        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = local_worker.policy_map[policy_id]
            num_batches = max(1, int(samples_per_device) // int(per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        permutation[batch_index] * per_device_batch_size,
                        buffer_index=0,
                    )

                    learner_info_builder.add_learn_on_batch_results(results, policy_id)

        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(train_batch.count)
    learn_timer.push_units_processed(train_batch.count)

    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if workers.remote_workers():
        with trainer._timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                local_worker.get_weights(
                    local_worker.get_policies_to_train(train_batch)
                )
            )
            for e in workers.remote_workers():
                e.set_weights.remote(weights)

    return learner_info