def add(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    alpha: Optional[Number] = None,
    out: Optional[Tensor] = None,
):
    """
    Reference implementation of torch.add
    """
    # Validate argument types up front.
    assert isinstance(a, (TensorLike, Number))
    assert isinstance(b, (TensorLike, Number))
    assert out is None or isinstance(out, TensorLike)
    assert alpha is None or isinstance(alpha, Number)

    # Number x Number: wrap both scalars into tensors before promotion.
    if isinstance(a, Number) and isinstance(b, Number):
        a, b = utils.wrap_scalars(a, b)

    compute_dtype, result_dtype = _elementwise_dtypes(
        a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
    )
    a, b = _convert_dtype(a, b, dtype=compute_dtype)
    a, b = broadcast(a, b)

    if alpha is not None:
        # alpha must be safely castable to the computation type;
        # bool tensors additionally accept int alphas.
        py_type = utils.dtype_to_type(compute_dtype)
        assert utils.is_lesser_type(type(alpha), py_type) or (
            compute_dtype is torch.bool and type(alpha) is int
        )
        b = prims.mul(b, py_type(alpha))

    summed = prims.add(a, b)
    (summed,) = _convert_dtype(summed, dtype=result_dtype)

    if out is None:
        return summed
    # Resize out to the result shape, then copy the result into it.
    out = _maybe_resize_out(out, summed.shape)
    return copy_to(out, summed, allow_cross_device=False)  # type: ignore[arg-type]
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add
    """
    a, b = _maybe_broadcast(a, b)

    # Fast path: no scaling requested.
    if alpha is None:
        return prims.add(a, b)

    # Pick the dtype that alpha must be compatible with: prefer a's
    # when a is a tensor, otherwise fall back to b's.
    if isinstance(a, TensorLike):
        dtype = a.dtype
    else:
        dtype = b.dtype  # type: ignore[union-attr]
    python_type = utils.dtype_to_type(dtype)
    if not utils.is_weakly_lesser_type(type(alpha), python_type):
        msg = "alpha argument of type {0} cannot be safely cast to type {1}!".format(
            type(alpha), python_type
        )
        raise ValueError(msg)

    # Scale b by alpha, then add.
    return prims.add(a, prims.mul(b, alpha))
def _wrapper(a, b, broadcast_dimensions):
    """Broadcast ``a`` into ``b``'s shape along the given dims, then add."""
    return prims.add(prims.broadcast_in_dim(a, b.shape, broadcast_dimensions), b)