def float_power(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    out: Optional[Tensor] = None,
) -> Tensor:
    assert isinstance(a, (Tensor, Number))
    assert isinstance(b, (Tensor, Number))
    assert out is None or isinstance(out, TensorLike)

    # Special-cases Number x Number case
    if isinstance(a, Number) and isinstance(b, Number):
        a, b = utils.wrap_scalars(a, b)

    # Handles type promotion
    dtype = utils.get_higher_dtype(a, b)
    if utils.is_complex_dtype(dtype):
        dtype = torch.complex128
    else:
        dtype = torch.float64

    a, b = _convert_dtype(a, b, dtype=dtype)

    # Broadcasting
    a, b = broadcast(a, b)

    result = prims.pow(a, b)

    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]

    return result
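# A minimal usage sketch (an assumption about the intended call pattern, not
# part of the original module). Note that float_power always computes in
# float64, or complex128 for complex inputs, matching torch.float_power:
#
#     >>> import torch
#     >>> float_power(torch.tensor([2, 3]), 2)
#     tensor([4., 9.], dtype=torch.float64)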
def _ref(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    out: Optional[Tensor] = None,
) -> Tensor:
    assert isinstance(a, (TensorLike, Number))
    assert isinstance(b, (TensorLike, Number))
    assert out is None or isinstance(out, TensorLike)

    # Special-cases Number x Number case
    if isinstance(a, Number) and isinstance(b, Number):
        a, b = utils.wrap_scalars(a, b)

    # Handles type promotion
    computation_dtype, result_dtype = _elementwise_dtypes(
        a, b, type_promotion=type_promotion
    )
    a, b = _convert_dtype(a, b, dtype=computation_dtype)

    # Broadcasting
    a, b = broadcast(a, b)

    result = prim(a, b)

    if type_promotion is not ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        (result,) = _convert_dtype(result, dtype=result_dtype)

    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]

    return result
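# `_ref` reads `prim` and `type_promotion` as free variables, so it is
# presumably the body of a factory that stamps out one reference per
# elementwise binary prim. A hypothetical sketch of such a factory (the name
# and signature below are illustrative assumptions, not the actual helper):
#
#     def _make_elementwise_binary_reference(prim, *, type_promotion):
#         def _ref(a, b, *, out=None):
#             ...  # body as above, with prim and type_promotion bound
#         return _ref
#
#     mul = _make_elementwise_binary_reference(
#         prims.mul, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
#     )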
def sub(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    alpha: Optional[Number] = None,
    out: Optional[Tensor] = None,
):
    """
    Reference implementation of torch.sub
    """

    # Type checks
    assert isinstance(a, (TensorLike, Number))
    assert isinstance(b, (TensorLike, Number))
    assert out is None or isinstance(out, TensorLike)
    assert alpha is None or isinstance(alpha, Number)

    # Special-cases Number x Number case
    if isinstance(a, Number) and isinstance(b, Number):
        a, b = utils.wrap_scalars(a, b)

    # Handles type promotion
    computation_dtype, result_dtype = _elementwise_dtypes(
        a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
    )
    a, b = _convert_dtype(a, b, dtype=computation_dtype)

    # Broadcasting
    a, b = broadcast(a, b)

    # alpha scales b before the subtraction
    if alpha is not None:
        alpha_promotion_type = utils.dtype_to_type(computation_dtype)
        assert utils.is_lesser_type(type(alpha), alpha_promotion_type) or (
            computation_dtype is torch.bool and type(alpha) is int
        )
        b = prims.mul(b, alpha_promotion_type(alpha))

    result = prims.sub(a, b)

    (result,) = _convert_dtype(result, dtype=result_dtype)

    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]

    return result
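# Usage sketch (illustrative, assuming the helpers above are in scope);
# alpha scales b before the subtraction, as in torch.sub:
#
#     >>> import torch
#     >>> sub(torch.tensor([5.0, 6.0]), torch.tensor([1.0, 2.0]), alpha=2)
#     tensor([3., 2.])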