def _check_grad_norms(self, grad_norm):
    """Check that grad norms are consistent across workers.

    This worker's ``grad_norm`` is written into its own slot of
    ``self._grad_norm_buf`` (indexed by ``self.data_parallel_rank``) after
    the buffer has been zeroed; the buffer is then all-reduced over
    ``self.data_parallel_process_group`` and every entry is compared with
    entry 0. Nothing happens when ``self._grad_norm_buf`` is ``None``.

    Args:
        grad_norm: gradient norm computed by this worker.

    Raises:
        FloatingPointError: if any collected norm is non-finite, or if the
            largest absolute deviation from entry 0, relative to entry 0,
            is not below 1e-6.
    """
    if self._grad_norm_buf is None:
        return

    self._grad_norm_buf.zero_()
    self._grad_norm_buf[self.data_parallel_rank] = grad_norm
    # NOTE(review): presumably a sum-reduce, so each zeroed slot ends up
    # holding the owning worker's norm -- confirm against
    # distributed_utils.all_reduce.
    distributed_utils.all_reduce(
        self._grad_norm_buf, group=self.data_parallel_process_group
    )

    def is_consistent(tensor):
        # Require finiteness AND agreement: finiteness alone would accept
        # arbitrarily different norms across workers.
        if not torch.isfinite(tensor).all():
            return False
        max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
        # The 1e-6 in the denominator avoids dividing by a zero norm.
        return bool((max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())

    if not is_consistent(self._grad_norm_buf):
        pretty_detail = "\n".join(
            "rank {:3d} = {:.8f}".format(r, n)
            for r, n in enumerate(self._grad_norm_buf.tolist())
        )
        error_detail = "grad_norm across the workers:\n{}\n".format(
            pretty_detail
        )
        # use FloatingPointError to trigger NanDetector
        raise FloatingPointError(
            "Fatal error: gradients are inconsistent between workers. "
            "Try --ddp-backend=legacy_ddp. "
            "Or are you mixing up different generation of GPUs in training?"
            + "\n"
            + "-" * 80
            + "\n{}\n".format(error_detail)
            + "-" * 80
        )
def _sync_sample_ratios(self, ratios):
    """Synchronise sampling ratios across data-parallel workers.

    ``ratios`` is converted to a double-precision tensor and, when
    ``torch.distributed`` is initialised, all-reduced over the group
    returned by ``distributed_utils.get_data_parallel_group()``. The tensor
    is moved to CUDA first when CUDA is available. Without an initialised
    process group the values are returned unchanged.

    Args:
        ratios: sequence of per-dataset sampling ratios.

    Returns:
        A float64 NumPy array on the host holding the reduced ratios.
        NOTE(review): the reduction is presumably a sum, so callers would
        need to renormalise -- confirm against distributed_utils.all_reduce.
    """
    # in case the ratios are not precisely the same across processes
    # also to ensure every process updates the ratios at the same pace
    ratios = torch.DoubleTensor(ratios)
    if torch.distributed.is_initialized():
        if torch.cuda.is_available():
            # Rebind to the CUDA copy: the reduced values are read back from
            # ``ratios`` below, so reducing an unnamed temporary would
            # silently drop the result.
            ratios = ratios.cuda()
        distributed_utils.all_reduce(
            ratios, group=distributed_utils.get_data_parallel_group()
        )
    return ratios.cpu().numpy()
def all_reduce_params(params):
    """All-reduce the gradients of ``params`` through a single flat tensor.

    With several parameters, their gradients are packed back to back into
    ``self.buffer`` (zeros stand in for parameters without a gradient).
    With exactly one parameter, its gradient tensor is reduced directly; if
    it has no gradient, a zeroed slice of ``self.buffer`` is used, or a
    fresh zero tensor when the parameter does not fit in the buffer.

    The tensor is divided by ``self.world_size`` before the reduction, but
    only if at least one real gradient was present. Afterwards the reduced
    values are copied back into each ``p.grad``; parameters that had no
    gradient receive a cloned slice as their new ``grad``.

    ``self`` and ``utils`` come from the enclosing scope.
    """
    flat = self.buffer
    has_grad = False

    if len(params) > 1:
        # Pack every gradient into its own contiguous segment of the buffer.
        start = 0
        for param in params:
            count = param.numel()
            segment = flat[start : start + count]
            if param.grad is None:
                segment.zero_()
            else:
                segment.copy_(param.grad.data.view(-1))
                has_grad = True
            start += count
    else:
        # we only have a single grad to all-reduce
        param = params[0]
        if param.grad is not None:
            flat = param.grad.data
            has_grad = True
        elif param.numel() <= self.buffer.numel():
            flat = flat[: param.numel()]
            flat.zero_()
        else:
            flat = torch.zeros_like(param)

    # An all-zero tensor needs no scaling.
    if has_grad:
        flat.div_(self.world_size)

    utils.all_reduce(flat, self.process_group)

    # copy all-reduced grads back into their original place
    start = 0
    for param in params:
        count = param.numel()
        reduced = flat[start : start + count].view_as(param)
        if param.grad is None:
            # Clone so the new grad does not alias the shared buffer.
            param.grad = reduced.clone()
        else:
            param.grad.data.copy_(reduced)
        start += count
def _aggregate_model_parallel_grad_norm(total_norm):
    """Combine per-partition gradient norms across the model-parallel group.

    The norm is squared, all-reduced over the group returned by
    ``distributed_utils.get_model_parallel_group()``, and the square root
    of the reduced value is returned.

    Args:
        total_norm: gradient norm computed on this model-parallel worker.

    Returns:
        The square root of the all-reduced squared norm.
    """
    squared = total_norm ** 2
    # The reduced value is read back from ``squared``; the call's return
    # value is not used.
    distributed_utils.all_reduce(
        squared, group=distributed_utils.get_model_parallel_group()
    )
    return squared ** 0.5