def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    Evaluates ``log1p(exp(a * beta)) / beta`` and selects ``a`` itself
    wherever ``a * beta > threshold``. When ``beta`` is None the input is
    neither scaled nor divided, i.e. the result is ``log1p(exp(a))`` with
    ``a`` compared directly against ``threshold``.

    Raises:
        NotImplementedError: if ``inplace`` is truthy.
        ValueError: if ``type(beta)`` is not weakly lesser than the Python
            type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    if beta is None:
        # Unscaled path: no type check is needed because nothing is mixed in.
        soft = torch.log1p(torch.exp(a))
        return torch.where(a > threshold, a, soft)

    tensor_py_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(beta), tensor_py_type):
        raise ValueError(
            "beta argument of type {0} cannot be safely cast to type {1}!".format(
                type(beta), tensor_py_type
            )
        )

    scaled = a * beta
    soft = torch.true_divide(torch.log1p(torch.exp(scaled)), beta)  # type: ignore[arg-type]
    # The threshold is compared against the scaled input, but the value
    # selected above the threshold is the unscaled input.
    return torch.where(scaled > threshold, a, soft)
def celu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.celu

    Keeps ``a`` wherever ``a > 0`` and otherwise uses
    ``alpha * expm1(a / alpha)``. When ``alpha`` is None the fallback is
    plain ``expm1(a)``.

    Raises:
        NotImplementedError: if ``inplace`` is truthy.
        ValueError: if ``type(alpha)`` is not weakly lesser than the Python
            type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    negative_branch: TensorLikeType
    if alpha is None:
        negative_branch = torch.expm1(a)
    else:
        tensor_py_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), tensor_py_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), tensor_py_type
                )
            )
        shrunk = torch.true_divide(a, alpha)  # type: ignore[arg-type]
        negative_branch = alpha * torch.expm1(shrunk)

    return torch.where(a > 0, a, negative_branch)
def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    """
    Convert the scalar ``a`` to ``typ`` when that cast is considered safe.

    Returns ``typ(a)``.

    Raises:
        ValueError: if ``a`` is not a ``Number``, or if ``type(a)`` is not
            weakly lesser than ``typ``.
    """
    scalar_type = type(a)

    if not isinstance(a, Number):
        raise ValueError(
            "Found unknown type {0} when trying to convert scalars!".format(
                scalar_type
            )
        )

    if utils.is_weakly_lesser_type(scalar_type, typ):
        return typ(a)

    raise ValueError(
        "Scalar {0} of type {1} cannot be safely cast to type {2}!".format(
            a, scalar_type, typ
        )
    )
def leaky_relu(
    a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu

    Keeps ``a`` wherever ``a > 0`` and otherwise uses ``a * negative_slope``.

    Raises:
        NotImplementedError: if ``inplace`` is truthy.
        ValueError: if ``type(negative_slope)`` is not weakly lesser than the
            Python type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    python_type = utils.dtype_to_type(a.dtype)
    if utils.is_weakly_lesser_type(type(negative_slope), python_type):
        is_positive = torch.gt(a, 0)
        leaked = torch.mul(a, negative_slope)
        return torch.where(is_positive, a, leaked)

    msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
    raise ValueError(msg)
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.sub

    Computes ``a - alpha * b`` after broadcasting ``a`` and ``b`` against each
    other. When ``alpha`` is None, ``b`` is subtracted unscaled.

    Args:
        a: the minuend, a tensor or a Python number.
        b: the subtrahend, a tensor or a Python number.
        alpha: optional multiplier applied to ``b`` before subtracting.

    Raises:
        ValueError: if ``type(alpha)`` is not weakly lesser than the Python
            type corresponding to the dtype of ``a`` (when ``a`` is a
            TensorLike) or otherwise of ``b``.
    """
    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        # The dtype used for the safety check comes from whichever operand is
        # a tensor, preferring `a`.
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), python_type
                )
            )
            raise ValueError(msg)
        b = prims.mul(b, alpha)

    return prims.sub(a, b)
def elu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu

    Keeps ``a`` wherever ``a > 0`` and otherwise uses ``alpha * expm1(a)``.
    When ``alpha`` is None the fallback is plain ``expm1(a)``.

    Raises:
        NotImplementedError: if ``inplace`` is truthy.
        ValueError: if ``type(alpha)`` is not weakly lesser than the Python
            type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    negative_branch: TensorLikeType
    if alpha is None:
        negative_branch = refs.expm1(a)
    else:
        tensor_py_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), tensor_py_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), tensor_py_type
                )
            )
        negative_branch = refs.mul(alpha, refs.expm1(a))

    return refs.where(refs.gt(a, 0), a, negative_branch)