コード例 #1
0
def _make_elementwise_binary_reference(
    prim: Callable,
    *,
    type_promotion_kind,
    aten_op=infer_aten_op,
    has_out=True,
) -> Callable:
    """Wrap a binary elementwise ``prim`` into a reference function.

    The returned callable type-promotes its ``a`` and ``b`` arguments via
    ``elementwise_type_promotion_wrapper`` (using ``type_promotion_kind``),
    broadcasts them with ``_maybe_broadcast`` and then calls ``prim`` on the
    broadcast pair.

    Args:
        prim: the primitive, invoked as ``prim(a, b)`` on the broadcast inputs.
        type_promotion_kind: forwarded to ``elementwise_type_promotion_wrapper``.
        aten_op: the ATen op the reference is registered as a decomposition
            of. When left as the ``infer_aten_op`` sentinel, the op is looked
            up on ``torch.ops.aten`` under the part of ``prim.__name__`` that
            precedes the first ``"."``. Pass ``None`` to skip registration.
        has_out: when true, the reference is additionally wrapped with
            ``out_wrapper`` (presumably adding ``out=`` support -- confirm
            against ``out_wrapper``).

    Returns:
        The (possibly ``out_wrapper``-wrapped) reference function.
    """
    # Build the promotion decorator before defining ``_ref``. This matches
    # the evaluation order of the ``@decorator(...)`` syntax.
    promote = elementwise_type_promotion_wrapper(
        type_promoting_args=("a", "b"), type_promotion_kind=type_promotion_kind
    )

    # The parameter names must stay ``a`` and ``b``: they are referenced by
    # name in ``type_promoting_args`` above.
    def _ref(
        a: Union[Tensor, NumberType],
        b: Union[Tensor, NumberType],
    ) -> Tensor:
        lhs, rhs = _maybe_broadcast(a, b)
        return prim(lhs, rhs)

    reference = promote(_ref)
    if has_out:
        reference = out_wrapper(reference)

    if aten_op is infer_aten_op:
        # Keep only the text before the first dot ("foo.bar" -> "foo").
        # A name without a dot is used unchanged.
        op_name = prim.__name__.partition(".")[0]
        aten_op = getattr(torch.ops.aten, op_name)

    if aten_op is not None:
        register_decomposition(aten_op)(reference)

    return reference
コード例 #2
0
def _make_elementwise_unary_reference(
    prim: Callable, *, type_promotion_kind, aten_op=infer_aten_op, has_out=True
) -> Callable:
    """Wrap a unary elementwise ``prim`` into a reference function.

    The returned callable type-promotes its ``a`` argument via
    ``elementwise_type_promotion_wrapper`` (using ``type_promotion_kind``)
    and then calls ``prim(a)``.

    Args:
        prim: the primitive, invoked as ``prim(a)``.
        type_promotion_kind: forwarded to ``elementwise_type_promotion_wrapper``.
        aten_op: the ATen op the reference is registered as a decomposition
            of. When left as the ``infer_aten_op`` sentinel, the op is looked
            up on ``torch.ops.aten`` under the part of ``prim.__name__`` that
            precedes the first ``"."``. Pass ``None`` to skip registration.
        has_out: when true (the default, matching the previous unconditional
            behaviour), the reference is additionally wrapped with
            ``out_wrapper`` (presumably adding ``out=`` support -- confirm
            against ``out_wrapper``).

    Returns:
        The (possibly ``out_wrapper``-wrapped) reference function.
    """

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a",), type_promotion_kind=type_promotion_kind
    )
    def _ref(a: Tensor) -> Tensor:
        return prim(a)

    # Same knob as ``has_out`` on ``_make_elementwise_binary_reference``.
    # ``out_wrapper`` is still applied outermost (after type promotion), as
    # before, but callers can now opt out of it.
    if has_out:
        _ref = out_wrapper(_ref)

    if aten_op is infer_aten_op:
        aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
    if aten_op is not None:
        register_decomposition(aten_op)(_ref)

    return _ref