Example #1
def _reduction(
    a: Tensor,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
):  # it is usually SAME, but I want
    # ref writers to actually think about what to put here
    assert isinstance(a, TensorLike)
    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
            if dtype != out.dtype:
                raise RuntimeError(
                    "dtype argument and out dtype must match in reduction"
                )
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, int)
    if isinstance(dims, int):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    # even though some reductions, like amin or amax, don't strictly require type promotion,
    # all the math ops (including comparisons) are still defined only for a computation type,
    # so promotion will still happen. We are doing it explicitly here
    inp_dtype = dtype if dtype is not None else a.dtype
    computation_dtype = utils._get_computation_dtype(inp_dtype)
    a_converted = prims.convert_element_type(a, computation_dtype)
    result = prim(a_converted, dims)

    if keepdims:
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
    if out is not None:
        if dtype is None:
            if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
                if out.dtype != a.dtype:
                    raise RuntimeError("Expected the dtype for input and out to match")
            elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL:
                if out.dtype != torch.bool:
                    raise RuntimeError("Expected the dtype for input and out to match")
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]

    if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
        result_dtype = dtype if dtype else a.dtype
        result = prims.convert_element_type(result, result_dtype)
    return result
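
The output_dtype_kind argument controls whether the reduced result keeps the input's dtype or is forced to bool. A minimal sketch of both behaviors via the public eager ops these refs mirror (torch.amax for SAME, torch.all for ALWAYS_BOOL); the sample tensor is illustrative only:

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# SAME: the reduction result keeps the input dtype.
print(torch.amax(x, dim=1).dtype)  # torch.float32

# ALWAYS_BOOL: the reduction always returns a bool tensor, regardless of input dtype.
print(torch.all(x, dim=1).dtype)   # torch.bool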
Example #2
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    if a.dtype != b.dtype:
        msg = "Attempting to compare tensors of different dtypes {0} and {1}!".format(
            a.dtype, b.dtype)
        raise ValueError(msg)
    if rtol < 0:
        msg = "rtol must be greater than or equal to zero, but got {0}!".format(
            rtol)
        raise ValueError(msg)
    if atol < 0:
        msg = "atol must be greater than or equal to zero, but got {0}!".format(
            atol)
        raise ValueError(msg)

    close = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype)
                      or utils.is_complex_dtype(a.dtype)):
        close = logical_or(close, logical_and(isnan(a), isnan(b)))

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(
            a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    result = logical_or(
        close,
        logical_and(isfinite(actual_error), le(actual_error, allowed_error)))

    return result
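
A short usage sketch of the closeness semantics via the public torch.isclose, which this ref mirrors; the values are chosen only to exercise the default-tolerance, equal_nan, and rtol paths:

import torch

a = torch.tensor([1.0, float("nan"), 100.0])
b = torch.tensor([1.0 + 1e-9, float("nan"), 101.0])

# Default tolerances: the tiny difference is close, NaNs are not, 100 vs 101 is not.
print(torch.isclose(a, b))                  # tensor([ True, False, False])

# equal_nan=True treats NaN as equal to NaN.
print(torch.isclose(a, b, equal_nan=True))  # tensor([ True,  True, False])

# Loosening rtol makes |a - b| <= atol + rtol * |b| hold for the last pair.
print(torch.isclose(a, b, rtol=0.05))       # tensor([ True, False,  True])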
Example #3
def _convert(x):
    # Note: dtype is a free variable here; this is a nested helper that
    # captures the target dtype from its enclosing scope.
    if isinstance(x, TensorLike):
        return prims.convert_element_type(x, dtype)
    elif isinstance(x, Number):
        typ = utils.dtype_to_type(dtype)
        return typ(x)
    return x
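
A self-contained sketch of the same idea with plain torch, assuming the target dtype is captured in a closure as above; the dtype-to-Python-type mapping is a simplified stand-in for utils.dtype_to_type:

import numbers
import torch

def make_converter(dtype: torch.dtype):
    # Simplified stand-in for utils.dtype_to_type: map a torch dtype
    # to the matching Python scalar type.
    if dtype == torch.bool:
        py_type = bool
    elif dtype.is_complex:
        py_type = complex
    elif dtype.is_floating_point:
        py_type = float
    else:
        py_type = int

    def _convert(x):
        if isinstance(x, torch.Tensor):
            return x.to(dtype)        # tensor path: element-type conversion
        elif isinstance(x, numbers.Number):
            return py_type(x)         # scalar path: convert to the matching Python type
        return x                      # anything else passes through unchanged

    return _convert

convert = make_converter(torch.float32)
print(convert(torch.tensor([1, 2, 3])).dtype)  # torch.float32
print(convert(3))                              # 3.0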
Example #4
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, List) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = reduction_dtypes(
        x, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        # Avoid computing a sqrt in abs and then squaring (more stable)
        # This could potentially be done for complex dtypes as
        # x = torch.real(torch.conj(x) * x))
        # and it should be more stable, but it's not clear whether it'll be faster on, say
        # CPU (abs is 1 vectorised operation), so leaving it just for real dtypes for now
        if not (ord % 2.0 == 0.0 and is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))
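
A usage sketch against the public torch.linalg.vector_norm, which this ref mirrors; the ord values exercise the three implementation branches above (zero norm, infinity norm, and the general p-norm), and the dtype keyword shows the result-dtype handling:

import torch

x = torch.tensor([[3.0, 0.0], [0.0, 4.0]])

print(torch.linalg.vector_norm(x, ord=0))                       # tensor(2.)  -- count of nonzeros
print(torch.linalg.vector_norm(x, ord=float("inf")))            # tensor(4.)  -- max |x|
print(torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True))  # tensor([[3.], [4.]])
print(torch.linalg.vector_norm(x, ord=2, dtype=torch.float64).dtype)  # torch.float64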
Example #5
def _maybe_promote_tensor_fft(t: TensorLikeType,
                              require_complex: bool = False) -> TensorLikeType:
    """Helper to promote a tensor to a dtype supported by the FFT primitives"""
    cur_type = t.dtype
    new_type = _promote_type_fft(cur_type, require_complex)
    if cur_type == new_type:
        return t
    return prims.convert_element_type(t, new_type)
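
A small sketch of the promotion this helper performs, shown through the public torch.fft API; the promoted dtypes assume the default floating dtype is float32:

import torch

# fft promotes a real float32 input to the matching complex dtype.
x = torch.randn(8, dtype=torch.float32)
print(torch.fft.fft(x).dtype)   # torch.complex64

# Integer input is first promoted to the default float dtype, then to complex.
n = torch.arange(8)
print(torch.fft.fft(n).dtype)   # torch.complex64 (assuming default dtype float32)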
Example #6
def _maybe_convert_to_dtype(
        a: Union[TensorLikeType, NumberType, Sequence],
        dtype: torch.dtype) -> Union[TensorLikeType, NumberType, Sequence]:
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(
            type(a)))
Example #7
def _maybe_convert_to_dtype(
        a: Union[TensorLikeType, NumberType, Sequence],
        dtype: torch.dtype) -> Union[TensorLikeType, NumberType, Sequence]:
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(
            type(a)))
Example #8
def _maybe_convert_to_dtype(
    a: Union[TensorLikeType, NumberType, Sequence, None], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence, None]:
    import torch._prims as prims
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(type(a))
    )
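
A quick usage sketch of this final variant, assuming a PyTorch build (1.12+) where torch._refs exposes the helper; it is an internal API, so the import path is an assumption and may change:

import torch
from torch._refs import _maybe_convert_to_dtype  # internal API: import path is an assumption

print(_maybe_convert_to_dtype(torch.tensor([1, 2]), torch.float32).dtype)  # torch.float32
print(_maybe_convert_to_dtype(3, torch.float32))                           # 3.0 (Python float)
print(_maybe_convert_to_dtype((1, 2), torch.float32))                      # (1.0, 2.0)
print(_maybe_convert_to_dtype(None, torch.float32))                        # None (passthrough)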
Example #9
def matrix_norm(
    A: TensorLikeType,
    ord: Union[float, str] = "fro",
    dim: DimsType = (-2, -1),
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # shape
    check_is_matrix(A, "linalg.matrix_norm")
    # dim
    dim = utils.canonicalize_dims(A.ndim, dim)
    if isinstance(dim, int):
        dim = (dim, )  # type: ignore[assignment]
    check(
        len(dim) == 2,
        lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
    check(
        dim[0] != dim[1],
        lambda:
        "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
    )
    # dtype arg
    check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")

    if isinstance(ord, str):
        # ord
        check(
            ord in ("fro", "nuc"),
            lambda: "linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(A.dtype,
                            "linalg.matrix_norm",
                            allow_low_precision_dtypes=ord != "nuc")

        if ord == "fro":
            return vector_norm(A, 2, dim, keepdim, dtype=dtype)
        else:  # ord == "nuc"
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
    else:
        # ord
        abs_ord = abs(ord)
        check(
            abs_ord in (2, 1, float("inf")),
            lambda: "linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(A.dtype,
                            "linalg.matrix_norm",
                            allow_low_precision_dtypes=ord != 2)

        max_min = partial(torch.amax if ord > 0.0 else torch.amin,
                          keepdim=keepdim)

        if abs_ord == 2.0:
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
        else:  # 1, -1, inf, -inf
            dim0, dim1 = dim
            if abs_ord == float("inf"):
                dim0, dim1 = dim1, dim0
            if not keepdim and (dim0 < dim1):
                dim1 -= 1
            return max_min(
                vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype),
                dim1)
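
A usage sketch against the public torch.linalg.matrix_norm, which this ref mirrors; the ord values cover both the string ("fro", "nuc") and numeric (1, 2, inf) branches above:

import torch

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

print(torch.linalg.matrix_norm(A))             # Frobenius norm (default ord="fro")
print(torch.linalg.matrix_norm(A, ord="nuc"))  # nuclear norm: sum of singular values
print(torch.linalg.matrix_norm(A, ord=2))      # spectral norm: largest singular value
print(torch.linalg.matrix_norm(A, ord=1))      # max absolute column sum
print(torch.linalg.matrix_norm(A, ord=float("inf"), keepdim=True).shape)  # torch.Size([1, 1])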