Beispiel #1
0
def float_power(
    a: Union[Tensor, Number], b: Union[Tensor, Number], out: Optional[Tensor] = None
) -> Tensor:
    """
    Raises ``a`` to the power ``b``, always computing in double precision.

    The operands are converted to ``torch.complex128`` when the higher of
    their dtypes is complex and to ``torch.float64`` otherwise, then
    broadcast against each other before ``prims.pow`` is applied.

    Args:
        a: base, a tensor-like object or a Python number.
        b: exponent, a tensor-like object or a Python number.
        out: optional destination. When given, it is passed through
            ``_maybe_resize_out`` with the result's shape and the result is
            copied into it with ``copy_to``.

    Returns:
        The computed power, or the value returned by ``copy_to`` when ``out``
        is supplied.
    """

    # Type checks. Operands are validated against TensorLike, the same type
    # the `out` check uses, so tensor-like operands are not rejected here.
    assert isinstance(a, (TensorLike, Number))
    assert isinstance(b, (TensorLike, Number))
    assert out is None or isinstance(out, TensorLike)

    # Special-cases Number x Number case
    if isinstance(a, Number) and isinstance(b, Number):
        a, b = utils.wrap_scalars(a, b)

    # Handles type promotion: only the complex-ness of the promoted dtype
    # matters, because the computation is always done in 64-bit precision.
    dtype = (
        torch.complex128
        if utils.is_complex_dtype(utils.get_higher_dtype(a, b))
        else torch.float64
    )
    a, b = _convert_dtype(a, b, dtype=dtype)

    # Broadcasting
    a, b = broadcast(a, b)

    result = prims.pow(a, b)

    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]

    return result
Beispiel #2
0
    def _ref(
        a: Union[Tensor, Number],
        b: Union[Tensor, Number],
        *,
        out: Optional[Tensor] = None
    ) -> Tensor:
        """
        Elementwise binary reference wrapping ``prim``.

        ``prim`` and ``type_promotion`` are not defined in this function;
        they come from the enclosing scope. The operands are promoted to a
        computation dtype, broadcast, and passed to ``prim``. Unless
        ``type_promotion`` is ``ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT``,
        the result is then converted to the result dtype. When ``out`` is
        given, the result is copied into it and the value returned by
        ``copy_to`` is returned.
        """
        # Argument type validation
        assert isinstance(a, (TensorLike, Number))
        assert isinstance(b, (TensorLike, Number))
        assert out is None or isinstance(out, TensorLike)

        # Two plain numbers are wrapped before any dtype logic runs
        both_scalars = isinstance(a, Number) and isinstance(b, Number)
        if both_scalars:
            a, b = utils.wrap_scalars(a, b)

        # Type promotion yields a dtype to compute in and a dtype to return
        compute_dtype, final_dtype = _elementwise_dtypes(
            a, b, type_promotion=type_promotion
        )
        a, b = _convert_dtype(a, b, dtype=compute_dtype)

        # Bring both operands to a common shape, then apply the primitive
        a, b = broadcast(a, b)
        value = prim(a, b)

        needs_cast = type_promotion is not ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
        if needs_cast:
            (value,) = _convert_dtype(value, dtype=final_dtype)

        # Without a destination the computed value is returned directly
        if out is None:
            return value

        resized = _maybe_resize_out(out, value.shape)
        return copy_to(resized, value, allow_cross_device=False)  # type: ignore[arg-type]
Beispiel #3
0
def sub(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    alpha: Optional[Number] = None,
    out: Optional[Tensor] = None
):
    """
    Reference implementation of torch.sub

    Evaluates ``a - b``. When ``alpha`` is supplied, ``b`` is multiplied by
    ``alpha`` before the subtraction. When ``out`` is supplied, the result is
    copied into it and the value returned by ``copy_to`` is returned.
    """

    # Validate argument types up front
    assert isinstance(a, (TensorLike, Number))
    assert isinstance(b, (TensorLike, Number))
    assert out is None or isinstance(out, TensorLike)
    assert alpha is None or isinstance(alpha, Number)

    # Two plain numbers are wrapped before any dtype logic runs
    both_scalars = isinstance(a, Number) and isinstance(b, Number)
    if both_scalars:
        a, b = utils.wrap_scalars(a, b)

    # Type promotion yields a dtype to compute in and a dtype to return
    compute_dtype, final_dtype = _elementwise_dtypes(
        a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
    )
    a, b = _convert_dtype(a, b, dtype=compute_dtype)

    a, b = broadcast(a, b)

    if alpha is not None:
        # alpha is cast with the Python type that utils.dtype_to_type returns
        # for the computation dtype; the assert restricts which alpha types
        # are accepted (an int alpha is allowed for a bool computation dtype).
        scalar_type = utils.dtype_to_type(compute_dtype)
        assert utils.is_lesser_type(type(alpha), scalar_type) or (
            compute_dtype is torch.bool and type(alpha) is int
        )
        b = prims.mul(b, scalar_type(alpha))

    difference = prims.sub(a, b)

    # Convert from the computation dtype to the result dtype
    (difference,) = _convert_dtype(difference, dtype=final_dtype)

    # Without a destination the computed value is returned directly
    if out is None:
        return difference

    resized = _maybe_resize_out(out, difference.shape)
    return copy_to(resized, difference, allow_cross_device=False)  # type: ignore[arg-type]