def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
    """Unscale the gradients of every parameter in ``optimizer`` in place.

    Each parameter gradient is passed to
    ``torch._amp_non_finite_check_and_unscale_``. The call also receives the
    copies of ``found_inf`` and ``inv_scale`` that belong to that gradient's
    device. Parameters whose ``.grad`` is ``None`` are skipped.

    Sparse gradients are handled through their ``_values()`` tensor. A sparse
    fp16 gradient is coalesced first and the coalesced tensor is stored back
    on the parameter.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts whose
            ``"params"`` entry lists the parameters to process.
        inv_scale: tensor handed to ``_MultiDeviceReplicator``. Its per-device
            copies are the inverse-scale argument of the unscale op.
        found_inf: tensor handed to ``_MultiDeviceReplicator``. Its per-device
            copies are the found-inf argument of the unscale op.
        allow_fp16: when falsy, encountering an fp16 gradient is an error.

    Returns:
        ``_per_device_tensors`` of the ``found_inf`` replicator. NOTE(review):
        presumably a mapping device -> found_inf tensor; confirm against
        ``_MultiDeviceReplicator``.

    Raises:
        ValueError: if ``allow_fp16`` is falsy and a gradient is float16.
    """
    inv_scale_by_device = _MultiDeviceReplicator(inv_scale)
    found_inf_by_device = _MultiDeviceReplicator(found_inf)

    # Gradients are rewritten in place below; keep autograd from recording it.
    with torch.no_grad():
        for param_group in optimizer.param_groups:
            for p in param_group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if not allow_fp16 and grad.dtype == torch.float16:
                    raise ValueError("Attempting to unscale FP16 gradients.")

                if not grad.is_sparse:
                    target = grad
                else:
                    # An uncoalesced sparse grad may hold several values for
                    # one index. coalesce() sums them, and for scaled fp16
                    # values that sum can overflow. Coalesce first so that the
                    # values we check are the ones actually used.
                    if grad.dtype is torch.float16:
                        p.grad = grad.coalesce()
                    target = p.grad._values()

                device = p.grad.device
                torch._amp_non_finite_check_and_unscale_(
                    target,
                    found_inf_by_device.get(device),
                    inv_scale_by_device.get(device),
                )

    return found_inf_by_device._per_device_tensors
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
    """Unscale the gradients of every parameter in ``optimizer`` in place.

    For each parameter with a non-``None`` ``.grad``, this calls
    ``torch._amp_non_finite_check_and_unscale_`` on the gradient. Sparse
    gradients are handled through their ``_values()`` tensor. The call also
    receives the ``found_inf`` and ``inv_scale`` copies for that gradient's
    device. The op is in-place (trailing underscore). NOTE(review): it is
    presumably the op that multiplies by ``inv_scale`` and flags non-finite
    values in ``found_inf``; confirm against its kernel.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts whose
            ``"params"`` entry lists the parameters to process.
        inv_scale: tensor replicated per device via ``_MultiDeviceReplicator``
            and passed as the op's inverse-scale argument.
        found_inf: tensor replicated per device via ``_MultiDeviceReplicator``
            and passed as the op's found-inf argument.
        allow_fp16: when falsy, encountering an fp16 gradient is an error.

    Returns:
        ``_per_device_tensors`` of the ``found_inf`` replicator. NOTE(review):
        presumably a mapping device -> found_inf tensor; confirm against
        ``_MultiDeviceReplicator``.

    Raises:
        ValueError: if ``allow_fp16`` is falsy and a gradient is float16.
    """
    per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
    per_device_found_inf = _MultiDeviceReplicator(found_inf)

    # The gradients are modified in place, so disable autograd recording.
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                if (not allow_fp16) and param.grad.dtype == torch.float16:
                    raise ValueError("Attempting to unscale FP16 gradients.")

                if param.grad.is_sparse:
                    # is_coalesced() == False means the sparse grad has values
                    # with duplicate indices. coalesce() deduplicates indices
                    # and adds all values that share an index. For scaled
                    # fp16 values that addition can overflow, so check the
                    # coalesced values rather than the raw duplicates.
                    if param.grad.dtype == torch.float16:
                        param.grad = param.grad.coalesce()
                    # Operate on the dense values tensor of the sparse grad.
                    to_unscale = param.grad._values()
                else:
                    to_unscale = param.grad

                device = param.grad.device
                torch._amp_non_finite_check_and_unscale_(
                    to_unscale,
                    per_device_found_inf.get(device),
                    per_device_inv_scale.get(device),
                )

    return per_device_found_inf._per_device_tensors