Example #1
def meta_cdist_forward(x1, x2, p, compute_mode):
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
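A quick shape sanity check (assuming the eager torch.cdist kernel agrees with this meta function): the batch dimensions broadcast and the last two dimensions become (r1, r2).

import torch

x1 = torch.randn(2, 3, 4)   # batch (2,), 3 rows, 4 columns
x2 = torch.randn(5, 4)      # no batch, 5 rows, 4 columns
out = torch.cdist(x1, x2)   # batch dims broadcast: (2,) with () -> (2,)
print(out.shape)            # torch.Size([2, 3, 5]) == broadcast batch + (r1, r2)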
Example #2
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    if a.dtype != b.dtype:
        msg = "Attempting to compare tensors of different dtypes {0} and {1}!".format(
            a.dtype, b.dtype)
        raise ValueError(msg)
    if rtol < 0:
        msg = "rtol must be greater than or equal to zero, but got {0}!".format(
            rtol)
        raise ValueError(msg)
    if atol < 0:
        msg = "atol must be greater than or equal to zero, but got {0}!".format(
            atol)
        raise ValueError(msg)

    close = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype)
                      or utils.is_complex_dtype(a.dtype)):
        close = logical_or(close, logical_and(isnan(a), isnan(b)))

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(
            a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    result = logical_or(
        close,
        logical_and(isfinite(actual_error), le(actual_error, allowed_error)))

    return result
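A small illustration of the tolerance and equal_nan behavior, using the public torch.isclose, which is assumed to match this reference:

import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([1.0 + 1e-9, float("nan")])
print(torch.isclose(a, b))                  # tensor([ True, False]) -- NaN != NaN by default
print(torch.isclose(a, b, equal_nan=True))  # tensor([True, True])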
Example #3
def norm(self: Tensor,
         p: float = 2,
         dim: Optional[List[int]] = None,
         keepdim: bool = False):
    if dim is None:
        dim = []

    if p == 0:
        return (self != 0).sum(dim, keepdim=keepdim)
    elif p == float('inf'):
        return self.abs().amax(dim, keepdim=keepdim)
    elif p == -float('inf'):
        return self.abs().amin(dim, keepdim=keepdim)

    def fast_pow(x, ord):
        if ord == 1.0:
            return x
        elif ord == 2.0:
            return x.square()
        elif ord == 0.5:
            return x.sqrt()
        else:
            return x.pow(ord)

    if not (p % 2.0 == 0.0 and utils.is_float_dtype(self.dtype)):
        self = self.abs()

    return fast_pow(fast_pow(self, p).sum(dim, keepdim=keepdim), 1.0 / p)
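A hedged usage sketch, assuming Tensor.norm dispatches to this decomposition for these cases:

import torch

t = torch.tensor([3.0, -4.0])
print(t.norm(p=2))              # tensor(5.) -- abs() skipped since 2 % 2 == 0 on a float tensor
print(t.norm(p=0))              # counts non-zero elements: 2
print(t.norm(p=float("inf")))   # tensor(4.) -- max absolute value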
Example #4
def meta_linalg_qr_helper(input, mode):
    if mode == "reduced":
        compute_q = True
        reduced_mode = True
    elif mode == "complete":
        compute_q = True
        reduced_mode = False
    elif mode == "r":
        compute_q = False
        reduced_mode = True
    else:
        raise RuntimeError(f"qr received unrecognized mode {mode}")
    check(input.ndim >= 2, lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor")
    check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: f"expected float or complex tensor, but got {input.dtype}"
    )
    m = input.size(-2)
    n = input.size(-1)
    mn = min(m, n)
    if compute_q:
        Qt_shape = list(input.size())
        Qt_shape[-2] = mn if reduced_mode else m
        Qt_shape[-1] = m
        Q = input.new_empty(Qt_shape)
        Q.transpose_(-2, -1)
    else:
        Q = input.new_empty(0)
    Rt_shape = list(input.size())
    Rt_shape[-2] = n
    Rt_shape[-1] = mn if reduced_mode or not compute_q else m
    R = input.new_empty(Rt_shape)
    R.transpose_(-2, -1)
    return (Q, R)
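The transpose_ calls above allocate the outputs column-major (Fortran-contiguous), matching LAPACK. A shape sketch via torch.linalg.qr, assumed consistent with this meta helper:

import torch

A = torch.randn(5, 3)                       # m=5, n=3, mn=3
Q, R = torch.linalg.qr(A, mode="reduced")
print(Q.shape, R.shape)                     # torch.Size([5, 3]) torch.Size([3, 3])
Q, R = torch.linalg.qr(A, mode="complete")
print(Q.shape, R.shape)                     # torch.Size([5, 5]) torch.Size([5, 3])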
Example #5
    def _find_highest_dtype_filtered(args,
                                     filter,
                                     *,
                                     float_as_complex=False,
                                     all_tensors_equal=False
                                     ) -> Optional[torch.dtype]:
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and utils.is_float_dtype(_dtype):
                    _dtype = utils.corresponding_complex_dtype(_dtype)
                if x.ndim == 0 and not all_tensors_equal:
                    zero_dim_tensor_dtype = utils.get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype)
                else:
                    # x.ndim > 0 or all_tensors_equal
                    one_plus_dim_tensor_dtype = utils.get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype)

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype
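The "prefer tensors with one or more dimensions" rule mirrors PyTorch's usual promotion, where a zero-dim tensor does not promote a same-category tensor that has dimensions. A minimal illustration, assuming standard eager promotion:

import torch

scalar = torch.tensor(1.0, dtype=torch.float64)   # 0-dim
vec = torch.ones(2, dtype=torch.float16)          # 1-dim
print((scalar + vec).dtype)                       # torch.float16 -- the 1+ dim tensor's dtype wins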
Example #6
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, List) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = reduction_dtypes(
        x, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        # Avoid computing a sqrt in abs and then squaring (more stable)
        # This could potentially be done for complex dtypes as
        # x = torch.real(torch.conj(x) * x))
        # and it should be more stable, but it's not clear whether it'll be faster on, say
        # CPU (abs is 1 vectorised operation), so leaving it just for real dtypes for now
        if not (ord % 2.0 == 0.0 and is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))
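A few hedged spot checks of the branches above, via the public torch.linalg.vector_norm:

import torch

v = torch.tensor([3.0, -4.0])
print(torch.linalg.vector_norm(v))                     # tensor(5.)
print(torch.linalg.vector_norm(v, ord=0))              # tensor(2.) -- counts non-zeros
print(torch.linalg.vector_norm(v, ord=float("inf")))   # tensor(4.) -- max |x|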
Example #7
def meta_linalg_cholesky_ex(input, upper=False, check_errors=False):
    check(input.ndim >= 2, lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor")
    check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: f"expected float or complex tensor, but got {input.dtype}"
    )
    check(input.size(-1) == input.size(-2), lambda: f"expected square matrix but got {input.shape}")
    L = input.new_empty(input.size())
    L.transpose_(-2, -1)
    info_sizes = input.size()[:-2]
    info = input.new_empty(info_sizes, dtype=torch.int)
    return L, info
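A shape/dtype sketch via torch.linalg.cholesky_ex, assumed to agree with this meta function: L matches the input shape and info drops the last two dimensions.

import torch

A = torch.eye(3).expand(2, 3, 3).contiguous()   # batch of 2 identity matrices (positive definite)
L, info = torch.linalg.cholesky_ex(A)
print(L.shape, info.shape)                      # torch.Size([2, 3, 3]) torch.Size([2])
print(info.dtype)                               # torch.int32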
Example #8
def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions
    """
    if dtype is not None:
        check(
            is_float_dtype(dtype) or is_complex_dtype(dtype),
            lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        check(
            is_complex_dtype(dtype) == is_complex_dtype(x_dtype),
            lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
                fn_name=fn_name,
                d="complex" if is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        check(
            get_higher_dtype(dtype, x_dtype) == dtype,
            lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            "without narrowing to the specified dtype ({dtype})",
        )
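Illustrative calls (hypothetical, assuming torch.linalg.vector_norm routes its dtype kwarg through this check): widening is allowed, while a non-float dtype or a narrowing conversion is rejected.

import torch

x = torch.ones(3, dtype=torch.float32)
torch.linalg.vector_norm(x, dtype=torch.float64)             # ok: widening is allowed
# torch.linalg.vector_norm(x, dtype=torch.int64)             # rejected: dtype must be float/complex
# torch.linalg.vector_norm(x.double(), dtype=torch.float32)  # rejected: would narrow float64 -> float32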
Example #9
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path_index_select(src, output, padding_idx):
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
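Output layout sketch via the public F.embedding_bag (which returns only the first of the four tensors produced here), assumed to agree with this meta function:

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 2, 4])                         # three bags: [1, 2], [4, 5], [4, 3]
print(F.embedding_bag(indices, weight, offsets).shape)    # torch.Size([3, 4]) == (num_bags, embedding_dim)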
Example #10
def _elementwise_dtypes(
    *_args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it is cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

    bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
          then the result dtype is the (default) dtype corresponding to the selected type
          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
          all tensors with one or more dimensions of the output type, and if there are no such
          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
          (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion option computes per above, and uses the result dtype as the computation dtype.

    The OP_MATH, INT_TO_FLOAT, COMPLEX_TO_FLOAT and BOOL_TO_LONG type promotion options tweak the above slightly.
    OP_MATH determines a "computation dtype" from the result dtype, and the mapping is simple:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    INT_TO_FLOAT, COMPLEX_TO_FLOAT, and BOOL_TO_LONG compute the computation type in the same way, but INT_TO_FLOAT
    and BOOL_TO_LONG map the result dtype to another dtype first, and COMPLEX_TO_FLOAT maps its result dtype
    after the computation dtype is determined, as follows:

      INT_TO_FLOAT  maps all boolean and integer result dtypes to the default floating point dtype
      COMPLEX_TO_FLOAT  maps complex result dtypes to their corresponding floating point dtype
      BOOL_TO_LONG maps the boolean result dtype to long

    The "corresponding floating point dtypes" are:
      complex32  -> float16
      complex64  -> float32
      complex128 -> float64

    The ALWAYS_BOOL type promotion option always maps the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT          : nextafter
      OP_MATH          : add
      INT_TO_FLOAT     : sin
      COMPLEX_TO_FLOAT : abs
      BOOL_TO_LONG     : pow
      ALWAYS_BOOL      : eq

    """

    args = tuple(x for x in _args if x is not None)

    highest_type: type = bool
    for x in args:
        if not isinstance(x, (Number, TensorLike)):
            msg = (
                "Unexpected type {0} when computing elementwise type promotion!"
                .format(str(type(x))))
            raise ValueError(msg)

        if isinstance(x, Number):
            highest_type = utils.get_higher_type(highest_type, type(x))
        else:
            # x is a TensorLike
            highest_type = utils.get_higher_type(highest_type,
                                                 utils.dtype_to_type(x.dtype))

    result_dtype = None

    def _find_highest_dtype_filtered(args,
                                     filter,
                                     *,
                                     float_as_complex=False,
                                     all_tensors_equal=False
                                     ) -> Optional[torch.dtype]:
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and utils.is_float_dtype(_dtype):
                    _dtype = utils.corresponding_complex_dtype(_dtype)
                if x.ndim == 0 and not all_tensors_equal:
                    zero_dim_tensor_dtype = utils.get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype)
                else:
                    # x.ndim > 0 or all_tensors_equal
                    one_plus_dim_tensor_dtype = utils.get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype)

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype

    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, utils.is_float_dtype)
        result_dtype = (torch.get_default_dtype()
                        if result_dtype is None else result_dtype)
    elif highest_type is complex:
        # NOTE: complex x float type promotion is incorrectly implemented in PyTorch today
        # it will treat zero dim and non-zero-dim float and complex tensors equally
        # unless there's a non-zero-dim complex tensor
        # the following captures this oddity
        has_one_plus_dim_complex_tensor = False
        for x in args:
            if (isinstance(x, TensorLike) and x.ndim > 0
                    and utils.is_complex_dtype(x.dtype)):
                has_one_plus_dim_complex_tensor = True
                break

        if has_one_plus_dim_complex_tensor:
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
            )
        else:
            # no complex tensors of rank 1+
            # NOTE: bugged case where all tensors are equal
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
                all_tensors_equal=True,
            )

        if result_dtype is None:
            result_dtype = utils.corresponding_complex_dtype(
                torch.get_default_dtype())
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args,
                                                    utils.is_integer_dtype)
        result_dtype = torch.long if result_dtype is None else result_dtype
    else:
        # highest_type is bool
        result_dtype = torch.bool

    if type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return result_dtype, result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH:
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        if utils.is_integer_dtype(result_dtype) or utils.is_boolean_dtype(
                result_dtype):
            result_dtype = torch.get_default_dtype()
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
        if utils.is_complex_dtype(result_dtype):
            # Note: computation still occurs in complex
            return _get_computation_dtype(
                result_dtype), utils.corresponding_real_dtype(result_dtype)
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
        if utils.is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        return result_dtype, result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return result_dtype, torch.bool
    else:
        raise ValueError("Unknown type promotion kind {0}".format(
            str(type_promotion)))
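A few concrete promotions matching the rules in the docstring (checked against eager PyTorch semantics, which this helper is assumed to mirror):

import torch

print((torch.ones(2, dtype=torch.int64) + torch.ones(2, dtype=torch.float16)).dtype)       # torch.float16 (long + half -> half)
print((1.5 + torch.ones(2, dtype=torch.int32)).dtype)                                      # torch.float32 (Python float + int tensor -> default float)
print((torch.ones(2, dtype=torch.float64) + torch.ones(2, dtype=torch.complex64)).dtype)   # torch.complex128 (double + cfloat -> cdouble)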