Example #1
    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        # Lazily replicate the scalar inv_scale / found_inf tensors onto every
        # device that holds gradients for this optimizer.
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError(
                            "Attempting to unscale FP16 gradients.")
                    else:
                        with torch.no_grad():
                            if param.grad.is_sparse:
                                # is_coalesced() == False means the sparse grad has values with duplicate indices.
                                # coalesce() deduplicates indices and adds all values that have the same index.
                                # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                                # so we should check the coalesced _values().
                                if param.grad.dtype is torch.float16:
                                    param.grad = param.grad.coalesce()
                                to_unscale = param.grad._values()
                            else:
                                to_unscale = param.grad

                            torch._amp_non_finite_check_and_unscale_(
                                to_unscale,
                                per_device_found_inf.get(param.grad.device),
                                per_device_inv_scale.get(param.grad.device))

        return per_device_found_inf._per_device_tensors
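
This method is internal to GradScaler; in a training loop it is reached indirectly through GradScaler.unscale_. Below is a minimal sketch of that flow, assuming CUDA is available and using a hypothetical toy model, optimizer, and synthetic loader purely for illustration:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    # Toy model, optimizer, and synthetic data; these names are illustrative,
    # not part of the GradScaler API.
    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler()
    loader = [(torch.randn(8, 16, device="cuda"),
               torch.randn(8, 4, device="cuda")) for _ in range(3)]

    for inputs, targets in loader:
        optimizer.zero_grad()
        with autocast():
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        scaler.scale(loss).backward()

        # unscale_() calls _unscale_grads_ under the hood with
        # inv_scale = 1 / current scale, dividing the grads in place and
        # flagging infs/NaNs per device.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)  # skipped if non-finite gradients were found
        scaler.update()

Unscaling before clip_grad_norm_ matters: the clip threshold is meant for the true gradients, not the scaled ones.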
Example #2
    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    else:
                        # Unscale the gradient in place and record on the
                        # per-device found_inf tensor if any values are
                        # non-finite.
                        torch._amp_non_finite_check_and_unscale_(
                            param.grad,
                            per_device_found_inf.get(param.grad.device),
                            per_device_inv_scale.get(param.grad.device))

        return per_device_found_inf._per_device_tensors
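
Example #2 is the same method without the sparse-gradient handling. Both versions rely on _MultiDeviceReplicator, an internal helper, to hand each gradient's device its own copy of inv_scale and found_inf. A minimal sketch of that idea follows (hypothetical class name; the real implementation in torch.cuda.amp may differ), assuming the only requirement is a lazily cached per-device copy of a master tensor:

    import torch

    class MultiDeviceReplicatorSketch:
        # Serves per-device copies of a master tensor, copying to a new
        # device only the first time that device is requested.
        def __init__(self, master_tensor):
            self.master = master_tensor
            self._per_device_tensors = {master_tensor.device: master_tensor}

        def get(self, device):
            if device not in self._per_device_tensors:
                self._per_device_tensors[device] = self.master.to(
                    device=device, non_blocking=True, copy=True)
            return self._per_device_tensors[device]

This is why _unscale_grads_ can simply return per_device_found_inf._per_device_tensors: after the loop it maps each device that held gradients to the found_inf flag recorded on that device.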