示例#1
0
def sub(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    alpha: Optional[Number] = None,
    out: Optional[Tensor] = None
):
    """
    Reference implementation of torch.sub.

    Computes ``a - alpha * b`` (or ``a - b`` when ``alpha`` is None). The
    operands are converted to the OP_MATH computation dtype and broadcast.
    The difference is then converted back to the promoted result dtype.

    If ``out`` is given, it is passed through ``_maybe_resize_out`` with the
    result's shape. The result is then copied into it via ``copy_to`` (with
    cross-device copies disallowed), and ``copy_to``'s return value is
    returned. Otherwise the result tensor itself is returned.

    Argument validation uses ``assert``, so these checks are skipped when
    Python runs with ``-O``.
    """
    # Validate argument kinds.
    for operand in (a, b):
        assert isinstance(operand, (TensorLike, Number))
    assert out is None or isinstance(out, TensorLike)
    assert alpha is None or isinstance(alpha, Number)

    # Two plain Python numbers are wrapped via utils.wrap_scalars before any
    # dtype computation happens.
    both_scalars = isinstance(a, Number) and isinstance(b, Number)
    if both_scalars:
        a, b = utils.wrap_scalars(a, b)

    op_dtype, final_dtype = _elementwise_dtypes(
        a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
    )
    converted = _convert_dtype(a, b, dtype=op_dtype)
    a, b = converted
    a, b = broadcast(a, b)

    if alpha is not None:
        # alpha must be representable in the computation dtype's Python type.
        # An int alpha is additionally accepted when computing in torch.bool.
        scalar_type = utils.dtype_to_type(op_dtype)
        assert utils.is_lesser_type(type(alpha), scalar_type) or (
            op_dtype is torch.bool and type(alpha) is int
        )
        b = prims.mul(b, scalar_type(alpha))

    (diff,) = _convert_dtype(prims.sub(a, b), dtype=final_dtype)

    if out is None:
        return diff

    out = _maybe_resize_out(out, diff.shape)
    return copy_to(out, diff, allow_cross_device=False)  # type: ignore[arg-type]
示例#2
0
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.sub.

    Computes ``a - alpha * b`` (or ``a - b`` when ``alpha`` is None) after
    passing both operands through ``_maybe_broadcast``.

    Raises:
        ValueError: if ``type(alpha)`` is not weakly lesser than the Python
            type corresponding to the operands' dtype, i.e. ``alpha`` cannot
            be safely cast to it.
    """
    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        # alpha is validated against the dtype of ``a`` when it is a tensor,
        # otherwise against ``b``'s dtype.
        # NOTE(review): if both operands are Python numbers, ``b.dtype`` will
        # fail with AttributeError — confirm callers never pass that combination
        # together with alpha.
        dtype = a.dtype if isinstance(
            a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                f"alpha argument of type {type(alpha)} cannot be safely "
                f"cast to type {python_type}!"
            )
            raise ValueError(msg)
        b = prims.mul(b, alpha)

    return prims.sub(a, b)
示例#3
0
 def test_mul_complex(self):
     # Smoke test with no assertions: it passes as long as prims.mul accepts a
     # real-valued tensor multiplied by a complex Python scalar without raising.
     # The result is not inspected.
     lhs = torch.randn(2)
     prims.mul(lhs, 1 + 1j)