def _make_elementwise_binary_reference(
    prim: Callable,
    *,
    type_promotion_kind,
    aten_op=infer_aten_op,
    has_out=True,
) -> Callable:
    """Build a binary elementwise reference implementation around ``prim``.

    The returned callable takes two operands ``a`` and ``b`` (tensors or
    numbers). It passes them through ``_maybe_broadcast`` and then calls
    ``prim`` on the results. The callable is wrapped with
    ``elementwise_type_promotion_wrapper`` over both operands using
    ``type_promotion_kind``.

    Args:
        prim: the primitive invoked on the (possibly broadcast) operands.
        type_promotion_kind: forwarded to ``elementwise_type_promotion_wrapper``.
        aten_op: the aten op to register the reference as a decomposition for.
            - The default sentinel ``infer_aten_op`` looks the op up on
              ``torch.ops.aten`` by the part of ``prim.__name__`` before the
              first ``"."``.
            - ``None`` skips registration.
        has_out: when true, the reference is additionally wrapped with
            ``out_wrapper``.

    Returns:
        The wrapped reference callable. This is the same object that gets
        registered via ``register_decomposition``, when registration happens.
    """
    # Build the promotion decorator before defining the function, mirroring
    # the evaluation order of a decorator expression.
    promote = elementwise_type_promotion_wrapper(
        type_promoting_args=("a", "b"), type_promotion_kind=type_promotion_kind
    )

    def _ref(
        a: Union[Tensor, NumberType],
        b: Union[Tensor, NumberType],
    ) -> Tensor:
        # NOTE(review): _maybe_broadcast presumably brings both operands to a
        # common shape; it is defined elsewhere in this module.
        lhs, rhs = _maybe_broadcast(a, b)
        return prim(lhs, rhs)

    wrapped = promote(_ref)
    if has_out:
        wrapped = out_wrapper(wrapped)

    target = aten_op
    if target is infer_aten_op:
        # An overload name like "add.default" resolves to torch.ops.aten.add.
        target = getattr(torch.ops.aten, prim.__name__.partition(".")[0])
    if target is not None:
        register_decomposition(target)(wrapped)

    return wrapped
def _make_elementwise_unary_reference(
    prim: Callable,
    *,
    type_promotion_kind,
    aten_op=infer_aten_op,
    has_out=True,
) -> Callable:
    """Build a unary elementwise reference implementation around ``prim``.

    The returned callable takes a single operand ``a`` and returns
    ``prim(a)``. It is wrapped with ``elementwise_type_promotion_wrapper``
    over ``a`` using ``type_promotion_kind``.

    Args:
        prim: the primitive invoked on the operand.
        type_promotion_kind: forwarded to ``elementwise_type_promotion_wrapper``.
        aten_op: the aten op to register the reference as a decomposition for.
            - The default sentinel ``infer_aten_op`` looks the op up on
              ``torch.ops.aten`` by the part of ``prim.__name__`` before the
              first ``"."``.
            - ``None`` skips registration.
        has_out: when true (the default, matching the previous unconditional
            behaviour), the reference is additionally wrapped with
            ``out_wrapper``. This mirrors the ``has_out`` switch of the binary
            factory.

    Returns:
        The wrapped reference callable. This is the same object that gets
        registered via ``register_decomposition``, when registration happens.
    """

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a",), type_promotion_kind=type_promotion_kind
    )
    def _ref(a: Tensor) -> Tensor:
        return prim(a)

    # Applying out_wrapper after the promotion wrapper is equivalent to the
    # former stacked `@out_wrapper` decorator; it is now optional.
    if has_out:
        _ref = out_wrapper(_ref)

    if aten_op is infer_aten_op:
        # An overload name like "abs.default" resolves to torch.ops.aten.abs.
        aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
    if aten_op is not None:
        register_decomposition(aten_op)(_ref)

    return _ref