def compute_gradients(self,
                      postprocessed_batch: SampleBatch) -> ModelGradients:
    """Computes gradients for a train batch, averaged over all devices.

    Steps:
      1. Zero-pads the batch for sequence models unless it is a
         `SampleBatch` already marked `zero_padded`.
      2. Flags the batch with `is_training = True` so the model can use it.
      3. Single device: uses the whole batch. Multiple devices: cuts it into
         `len(self.devices)` contiguous shards whose lengths differ by at
         most one, each bound to its own device via `_lazy_tensor_dict`.
      4. Loads `self.model`'s weights into every tower in
         `self.model_gpu_towers`, then runs `_multi_gpu_parallel_grad_calc`.
      5. Multiple devices: averages each gradient over the towers (on
         `self.device`), assigns the averages to `self.model`'s parameters'
         `.grad`, and combines the per-tower stats with `all_tower_reduce`.

    Args:
        postprocessed_batch: The batch to compute gradients on. Its
            `is_training` attribute is set in place; the padding helper's
            return value is not used, so padding is presumably in place too.

    Returns:
        A tuple `(grads, fetches)`: the list of gradients (may contain None
        entries), and the dict from `extra_compute_grad_fetches()` extended
        with the gradient stats under `LEARNER_STATS_KEY`.
    """
    already_padded = (isinstance(postprocessed_batch, SampleBatch)
                      and postprocessed_batch.zero_padded)
    if not already_padded:
        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
        )

    # Let the model know that this batch is used for training.
    postprocessed_batch.is_training = True

    num_devices = len(self.devices)
    if num_devices == 1:
        # Single device: the whole batch goes to the one tower, unsliced.
        batches = [self._lazy_tensor_dict(postprocessed_batch)]
    else:
        # Several devices: each shard takes `rows_left // devices_left`
        # rows, so any remainder ends up on the later shards.
        batches = []
        rows_left = len(postprocessed_batch)
        offset = 0
        for devices_left, device in zip(range(num_devices, 0, -1),
                                        self.devices):
            shard_len = rows_left // devices_left
            shard = postprocessed_batch.slice(offset, offset + shard_len)
            batches.append(self._lazy_tensor_dict(shard, device=device))
            offset += shard_len
            rows_left -= shard_len

    # Start every tower from the main model's current weights.
    main_weights = self.model.state_dict()
    for tower in self.model_gpu_towers:
        tower.load_state_dict(main_weights)

    # Loss/gradient pass, one batch per tower (maybe parallelized). Each
    # output entry is consumed below as a (grads, stats) pair.
    tower_outputs = self._multi_gpu_parallel_grad_calc(batches)

    if num_devices > 1:
        # Average every gradient that tower 0 produced over all towers,
        # gathering the per-tower tensors on `self.device` first.
        all_grads = []
        for idx, first_grad in enumerate(tower_outputs[0][0]):
            if first_grad is None:
                all_grads.append(None)
            else:
                per_tower = [out[0][idx].to(self.device)
                             for out in tower_outputs]
                all_grads.append(torch.mean(torch.stack(per_tower), dim=0))
        # Install the averaged gradients on the main model.
        for idx, param in enumerate(self.model.parameters()):
            param.grad = all_grads[idx]
        # Combine the per-tower stats structures with `all_tower_reduce`.
        from ray.rllib.execution.train_ops import all_tower_reduce
        grad_info = tree.map_structure_with_path(
            lambda path, *leaves: all_tower_reduce(path, *leaves),
            *[out[1] for out in tower_outputs])
    else:
        all_grads, grad_info = tower_outputs[0]

    # NOTE(review): presumably the latency is accumulated once per
    # optimizer, making this a per-optimizer average — confirm.
    grad_info["allreduce_latency"] /= len(self._optimizers)
    grad_info.update(self.extra_grad_info(postprocessed_batch))

    fetches = self.extra_compute_grad_fetches()
    return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
def compute_gradients(self,
                      postprocessed_batch: SampleBatch) -> ModelGradients:
    """Computes gradients for a train batch, averaged over all devices.

    Steps:
      1. Multiple devices: splits the batch (without shuffling) into at
         most one slice per device using `minibatches`.
      2. Zero-pads every slice for sequence models unless
         `postprocessed_batch` is a `SampleBatch` already marked
         `zero_padded`.
      3. Flags each slice with `is_training = True` and binds it to its
         device via `_lazy_tensor_dict`.
      4. Multiple devices: loads `self.model`'s weights into every tower in
         `self.model_gpu_towers`. Then runs `_multi_gpu_parallel_grad_calc`.
      5. Multiple devices: averages each gradient over the towers (on
         `self.device`), assigns the averages to `self.model`'s parameters'
         `.grad`, and combines the per-tower stats with `all_tower_reduce`.

    Args:
        postprocessed_batch: The batch to compute gradients on.

    Returns:
        A tuple `(grads, fetches)`: the list of gradients (may contain None
        entries), and the dict from `extra_compute_grad_fetches()` extended
        with the gradient stats under `LEARNER_STATS_KEY`.
    """
    num_devices = len(self.devices)

    # For multi-GPU, split the batch into (at most) n slices (n=#GPUs).
    if num_devices == 1:
        batches = [postprocessed_batch]
    else:
        from ray.rllib.utils.sgd import minibatches

        # Ceil-divide the length. A floor division can yield more slices
        # than devices when the length is not a multiple of `num_devices`
        # (the extra slice would never be prepared below), and gives a
        # slice size of 0 when the batch is shorter than `num_devices`.
        slice_size = (len(postprocessed_batch) + num_devices - 1) // \
            num_devices
        batches = list(
            minibatches(postprocessed_batch, slice_size, shuffle=False))

    if not isinstance(postprocessed_batch, SampleBatch) or \
            not postprocessed_batch.zero_padded:
        for b in batches:
            pad_batch_to_sequences_of_same_size(
                b,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

    # Flag each slice for training and bind it to its device. Keep the
    # object returned by `_lazy_tensor_dict` rather than assuming it
    # converts `b` in place.
    for idx, (b, device) in enumerate(zip(batches, self.devices)):
        b.is_training = True
        batches[idx] = self._lazy_tensor_dict(b, device=device)

    if num_devices > 1:
        # Multi-GPU case: copy weights of the main model to all towers.
        state_dict = self.model.state_dict()
        for tower in self.model_gpu_towers:
            tower.load_state_dict(state_dict)

    # Do the (maybe parallelized) gradient calculation step. Each output
    # entry is consumed below as a (grads, stats) pair.
    tower_outputs = self._multi_gpu_parallel_grad_calc(batches)

    if num_devices > 1:
        # Mean-reduce every gradient that tower 0 produced over all towers,
        # gathering the per-tower tensors on `self.device` first.
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack([
                            t[0][i].to(self.device) for t in tower_outputs
                        ]),
                        dim=0))
            else:
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]
        # Reduce stats over towers as well.
        from ray.rllib.execution.train_ops import all_tower_reduce
        grad_info = tree.map_structure_with_path(
            lambda p, *t: all_tower_reduce(p, *t),
            *[t[1] for t in tower_outputs])
    else:
        # Single device case.
        all_grads, grad_info = tower_outputs[0]

    # NOTE(review): presumably the latency is accumulated once per
    # optimizer, making this a per-optimizer average — confirm.
    grad_info["allreduce_latency"] /= len(self._optimizers)
    grad_info.update(self.extra_grad_info(postprocessed_batch))

    fetches = self.extra_compute_grad_fetches()

    return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})