Example 1
def _convert(x):
    if isinstance(x, TensorLike):
        return prims.convert_element_type(x, dtype)
    elif isinstance(x, Number):
        typ = utils.dtype_to_type(dtype)
        return typ(x)
    return x
Example 2
def celu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.celu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), python_type
                )
            )
            raise ValueError(msg)
        rhs = alpha * torch.expm1(torch.true_divide(a, alpha))  # type: ignore[arg-type]
    else:
        rhs = torch.expm1(a)

    return torch.where(a > 0, a, rhs)
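For context, the where/expm1 formulation above is the standard CELU, celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), with alpha treated as 1 when it is None. A minimal eager-mode sanity check (the alpha value and input tensor below are made up for illustration, not part of the original snippet) could look like:

import torch
import torch.nn.functional as F

x = torch.randn(8)
alpha = 0.5  # illustrative value

# Recompute the same where/expm1 formulation inline and compare with eager CELU.
ref_out = torch.where(x > 0, x, alpha * torch.expm1(torch.true_divide(x, alpha)))
print(torch.allclose(ref_out, F.celu(x, alpha=alpha)))  # should print True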
Example 3
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if beta is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), python_type):
            msg = "beta argument of type {0} cannot be safely cast to type {1}!".format(
                type(beta), python_type)
            raise ValueError(msg)
        scaled_input = a * beta
        rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)),
                                beta)  # type: ignore[arg-type]

    else:
        scaled_input = a
        rhs = torch.log1p(torch.exp(scaled_input))

    return torch.where(scaled_input > threshold, a, rhs)
Example 4
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    # Validates the cast is safe
    a_typ = utils.dtype_to_type(a.dtype)
    b_typ = utils.dtype_to_type(b.dtype)
    if a_typ is not utils.get_higher_type(a_typ, b_typ):
        msg = "{0} can't be cast safely to {1}!".format(b.dtype, a.dtype)
        raise RuntimeError(msg)

    # Validates the tensors have the same number of elements
    if a.numel() != b.numel():
        msg = "Attempting to copy {0} elements to a tensor with {1} elements!".format(
            b.numel(), a.numel())
        raise RuntimeError(msg)

    return a
Example 5
def leaky_relu(a: TensorLikeType,
               negative_slope: float = 0.01,
               inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu
    """

    if inplace:
        raise NotImplementedError

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(negative_slope), python_type):
        msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)
    return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope))
Example 6
def _maybe_convert_to_dtype(
        a: Union[TensorLikeType, NumberType, Sequence],
        dtype: torch.dtype) -> Union[TensorLikeType, NumberType, Sequence]:
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    raise ValueError(
        "Received type {0} that is neither a tensor nor a number!".format(
            type(a)))
Example 7
def _maybe_convert_to_dtype(
        a: Union[TensorLikeType, NumberType, Sequence],
        dtype: torch.dtype) -> Union[TensorLikeType, NumberType, Sequence]:
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    raise ValueError(
        "Received type {0} that is neither a tensor nor a number!".format(
            type(a)))
Example 8
def sub(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    alpha: Optional[Number] = None,
    out: Optional[Tensor] = None
):
    """
    Reference implementation of torch.sub
    """

    # Type checks
    assert isinstance(a, (TensorLike, Number))
    assert isinstance(b, (TensorLike, Number))
    assert out is None or isinstance(out, TensorLike)
    assert alpha is None or isinstance(alpha, Number)

    # Special-cases the Number x Number case
    if isinstance(a, Number) and isinstance(b, Number):
        a, b = utils.wrap_scalars(a, b)

    computation_dtype, result_dtype = _elementwise_dtypes(
        a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
    )
    a, b = _convert_dtype(a, b, dtype=computation_dtype)

    a, b = broadcast(a, b)

    if alpha is not None:
        alpha_promotion_type = utils.dtype_to_type(computation_dtype)
        assert utils.is_lesser_type(type(alpha), alpha_promotion_type) or (
            computation_dtype is torch.bool and type(alpha) is int
        )
        b = prims.mul(b, alpha_promotion_type(alpha))

    result = prims.sub(a, b)

    (result,) = _convert_dtype(result, dtype=result_dtype)

    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]

    return result
Example 9
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.sub
    """
    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        dtype = a.dtype if isinstance(
            a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!"
                .format(type(alpha), python_type))
            raise ValueError(msg)
        b = prims.mul(b, alpha)

    return prims.sub(a, b)
Example 10
def elu(a: TensorLikeType,
        alpha: Optional[NumberType] = None,
        inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!"
                .format(type(alpha), python_type))
            raise ValueError(msg)
        rhs = refs.mul(alpha, refs.expm1(a))
    else:
        rhs = refs.expm1(a)

    return refs.where(refs.gt(a, 0), a, rhs)
Example 11
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
    return torch.where(mask, utils.dtype_to_type(self.dtype)(value), self)
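As a usage illustration (the values below are made up), the decomposition above matches Tensor.masked_fill with a Python scalar fill value:

import torch

t = torch.zeros(4)
mask = torch.tensor([True, False, True, False])

# Positions where mask is True take the scalar value; the rest keep t's values.
print(t.masked_fill(mask, 1.5))  # tensor([1.5000, 0.0000, 1.5000, 0.0000])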
Example 12
def _elementwise_dtypes(*_args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND):
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Elementwise type promotion first decides which of four ordered types to use:

    bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
          then the result dtype is the dtype corresponding to the selected type
      - if no tensor with one or more dimensions has a dtype with the same corresponding type
          as the one selected, then the result dtype is the highest dtype among all tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
          all tensors with one or more dimensions

    The computation dtype is usually the result dtype, except for float16 and bfloat16, where
    the computation dtype is float32, and complex32, where the computation dtype is complex64.
    """

    args = tuple(filter(lambda x: x is not None, _args))

    # Type checking
    for arg in args:
        assert isinstance(arg, (Number, TensorLike))

    # Determines datatypes for each category
    scalar_args = filter(lambda x: isinstance(x, Number), args)
    scalar_type = reduce(
        lambda acc, x: utils.get_higher_type(acc, type(x)), scalar_args, bool  # type: ignore[arg-type, return-value]
    )

    scalar_tensors = filter(lambda t: isinstance(t, TensorLike) and t.ndim == 0, args)
    scalar_tensor_dtype = reduce(
        utils.get_higher_dtype, (t.dtype for t in scalar_tensors), torch.bool
    )
    scalar_tensor_type = utils.dtype_to_type(scalar_tensor_dtype)

    nonscalar_tensors = filter(
        lambda t: isinstance(t, TensorLike) and t.ndim != 0, args
    )
    nonscalar_tensor_dtype = reduce(
        utils.get_higher_dtype, (t.dtype for t in nonscalar_tensors), torch.bool
    )
    nonscalar_tensor_type = utils.dtype_to_type(nonscalar_tensor_dtype)

    typ = reduce(
        utils.get_higher_type, (scalar_type, scalar_tensor_type, nonscalar_tensor_type)
    )

    if nonscalar_tensor_type is typ:
        dtype = nonscalar_tensor_dtype
    elif scalar_tensor_type is typ:
        dtype = scalar_tensor_dtype
    else:
        # scalar type kind -> default torch dtype mapping
        if typ is bool:
            dtype = torch.bool
        elif typ is int:
            dtype = torch.int64
        elif typ is float:
            dtype = torch.get_default_dtype()
        else:
            # typ is complex
            dtype = (
                torch.complex128
                if torch.get_default_dtype() is torch.float64
                else torch.complex64
            )

    if type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT and (
        utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype)
    ):
        return torch.get_default_dtype(), torch.get_default_dtype()

    if type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return dtype, torch.bool

    if type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH:
        return _get_computation_dtype(dtype), dtype

    # DEFAULT type promotion
    return dtype, dtype
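To make the OP_MATH split between computation and result dtype concrete, here is a small eager-mode sketch (the values are chosen purely for illustration):

import torch

a = torch.ones(3, dtype=torch.float16)

# Adding a Python float does not promote a float16 tensor, so the result dtype
# stays float16 even though, per OP_MATH above, the computation dtype is float32.
print(torch.add(a, 1.5).dtype)  # torch.float16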
Example 13
def _elementwise_dtypes(
    *_args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it is cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

    bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
          then the result dtype is the (default) dtype corresponding to the selected type
          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
          all tensors with one or more dimensions of the output type, and if there are no such
          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
          (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion option computes per above, and uses the result dtype as the computation dtype.

    The OP_MATH, INT_TO_FLOAT, COMPLEX_TO_FLOAT and BOOL_TO_LONG type promotion options tweak the above slightly.
    OP_MATH determines a "computation dtype" from the result dtype, and the mapping is simple:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    INT_TO_FLOAT, COMPLEX_TO_FLOAT, and BOOL_TO_LONG compute the computation type in the same way, but INT_TO_FLOAT
    and BOOL_TO_LONG map the result dtype to another dtype first, and COMPLEX_TO_FLOAT maps its result dtype
    after the computation dtype is determined, as follows:

      INT_TO_FLOAT  maps all boolean and integer result dtypes to the default floating point dtype
      COMPLEX_TO_FLOAT  maps complex result dtypes to their corresponding floating point dtype
      BOOL_TO_LONG maps the boolean result dtype to long

    The "corresponding floating point dtypes" are:
      complex32  -> float16
      complex64  -> float32
      complex128 -> float64

    The ALWAYS_BOOL type promotion option always maps the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT          : nextafter
      OP_MATH          : add
      INT_TO_FLOAT     : sin
      COMPLEX_TO_FLOAT : abs
      BOOL_TO_LONG     : pow
      ALWAYS_BOOL      : eq

    """

    args = tuple(x for x in _args if x is not None)

    highest_type: type = bool
    for x in args:
        if not isinstance(x, (Number, TensorLike)):
            msg = (
                "Unexpected type {0} when computing elementwise type promotion!"
                .format(str(type(x))))
            raise ValueError(msg)

        if isinstance(x, Number):
            highest_type = utils.get_higher_type(highest_type, type(x))
        else:
            # x is a TensorLike
            highest_type = utils.get_higher_type(highest_type,
                                                 utils.dtype_to_type(x.dtype))

    result_dtype = None

    def _find_highest_dtype_filtered(args,
                                     filter,
                                     *,
                                     float_as_complex=False,
                                     all_tensors_equal=False
                                     ) -> Optional[torch.dtype]:
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and utils.is_float_dtype(_dtype):
                    _dtype = utils.corresponding_complex_dtype(_dtype)
                if x.ndim == 0 and not all_tensors_equal:
                    zero_dim_tensor_dtype = utils.get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype)
                else:
                    # x.ndim > 0 or all_tensors_equal
                    one_plus_dim_tensor_dtype = utils.get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype)

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype

    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, utils.is_float_dtype)
        result_dtype = (torch.get_default_dtype()
                        if result_dtype is None else result_dtype)
    elif highest_type is complex:
        # NOTE: complex x float type promotion is incorrectly implemented in PyTorch today
        # it will treat zero dim and non-zero-dim float and complex tensors equally
        # unless there's a non-zero-dim complex tensor
        # the following captures this oddity
        has_one_plus_dim_complex_tensor = False
        for x in args:
            if (isinstance(x, TensorLike) and x.ndim > 0
                    and utils.is_complex_dtype(x.dtype)):
                has_one_plus_dim_complex_tensor = True
                break

        if has_one_plus_dim_complex_tensor:
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
            )
        else:
            # no complex tensors of rank 1+
            # NOTE: bugged case where all tensors are equal
            result_dtype = _find_highest_dtype_filtered(
                args,
                lambda x: utils.is_float_dtype(x) or utils.is_complex_dtype(x),
                float_as_complex=True,
                all_tensors_equal=True,
            )

        if result_dtype is None:
            result_dtype = utils.corresponding_complex_dtype(
                torch.get_default_dtype())
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args,
                                                    utils.is_integer_dtype)
        result_dtype = torch.long if result_dtype is None else result_dtype
    else:
        # highest_type is bool
        result_dtype = torch.bool

    if type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return result_dtype, result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH:
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        if utils.is_integer_dtype(result_dtype) or utils.is_boolean_dtype(
                result_dtype):
            result_dtype = torch.get_default_dtype()
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
        if utils.is_complex_dtype(result_dtype):
            # Note: computation still occurs in complex
            return _get_computation_dtype(
                result_dtype), utils.corresponding_real_dtype(result_dtype)
        return _get_computation_dtype(result_dtype), result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
        if utils.is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        return result_dtype, result_dtype
    elif type_promotion is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return result_dtype, torch.bool
    else:
        raise ValueError("Unknown type promotion kind {0}".format(
            str(type_promotion)))
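A few eager-mode data points consistent with the result-dtype side of the kinds listed in the docstring (the computation dtype is internal and not observable from these calls; the exact float dtype for INT_TO_FLOAT depends on torch.get_default_dtype()):

import torch

i = torch.arange(4)                      # int64
h = torch.ones(3, dtype=torch.float16)
c = torch.ones(3, dtype=torch.complex64)

print(torch.sin(i).dtype)     # INT_TO_FLOAT: integer input -> default float dtype
print(torch.abs(c).dtype)     # COMPLEX_TO_FLOAT: complex64 -> torch.float32
print(torch.add(h, h).dtype)  # OP_MATH: result stays torch.float16
print(torch.eq(h, h).dtype)   # ALWAYS_BOOL: torch.bool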