def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    Computes ``log1p(exp(a * beta)) / beta`` elementwise. Wherever the
    scaled input ``a * beta`` is greater than ``threshold`` the original
    input ``a`` is returned unchanged instead.

    Args:
        a: input tensor.
        beta: optional scale applied to ``a`` before the exponential and
            divided back out afterwards. When ``None`` no scaling or
            division is performed.
        threshold: cut-off on the scaled input above which ``a`` is passed
            through as-is.
        inplace: not supported by this reference.

    Returns:
        The elementwise softplus of ``a``.

    Raises:
        NotImplementedError: if ``inplace`` is true.
        ValueError: if ``utils.is_weakly_lesser_type`` rejects casting
            ``type(beta)`` to the Python type matching ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    if beta is None:
        # Unscaled variant: the threshold is compared against `a` itself.
        scaled = a
        smooth: TensorLikeType = refs.log1p(refs.exp(scaled))
    else:
        # Validate the scalar before it takes part in any tensor arithmetic.
        tensor_py_type = utils.dtype_to_type(a.dtype)
        beta_py_type = type(beta)
        if not utils.is_weakly_lesser_type(beta_py_type, tensor_py_type):
            raise ValueError(
                "beta argument of type {0} cannot be safely cast to type {1}!".format(
                    beta_py_type, tensor_py_type
                )
            )
        scaled = refs.mul(a, beta)
        smooth = refs.true_divide(refs.log1p(refs.exp(scaled)), beta)

    # The comparison uses the scaled input, but the passthrough value is the
    # unscaled `a`.
    above_threshold = refs.gt(scaled, threshold)
    return refs.where(above_threshold, a, smooth)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu

    Computes ``scale * where(a > 0, a, alpha * expm1(a))`` elementwise,
    with the fixed ``alpha`` and ``scale`` constants defined in the body.

    Args:
        a: input tensor.
        inplace: not supported by this reference.

    Returns:
        The elementwise SELU of ``a``.

    Raises:
        NotImplementedError: if ``inplace`` is true.
    """
    if inplace:
        raise NotImplementedError

    # Fixed SELU constants; they are not configurable through the signature.
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    # Value used for the non-positive elements.
    negative_branch = refs.mul(alpha, refs.expm1(a))
    is_positive = refs.gt(a, 0)
    unscaled = refs.where(is_positive, a, negative_branch)
    return refs.mul(scale, unscaled)
def elu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu

    Computes ``where(a > 0, a, alpha * expm1(a))`` elementwise. When
    ``alpha`` is ``None`` the multiplication is skipped and the
    non-positive branch is plain ``expm1(a)``.

    Args:
        a: input tensor.
        alpha: optional multiplier for the non-positive branch.
        inplace: not supported by this reference.

    Returns:
        The elementwise ELU of ``a``.

    Raises:
        NotImplementedError: if ``inplace`` is true.
        ValueError: if ``utils.is_weakly_lesser_type`` rejects casting
            ``type(alpha)`` to the Python type matching ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    if alpha is None:
        negative_branch: TensorLikeType = refs.expm1(a)
    else:
        # Validate the scalar before it takes part in any tensor arithmetic.
        tensor_py_type = utils.dtype_to_type(a.dtype)
        alpha_py_type = type(alpha)
        if not utils.is_weakly_lesser_type(alpha_py_type, tensor_py_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    alpha_py_type, tensor_py_type
                )
            )
        negative_branch = refs.mul(alpha, refs.expm1(a))

    # Strictly positive elements pass through; everything else takes the
    # exponential branch.
    is_positive = refs.gt(a, 0)
    return refs.where(is_positive, a, negative_branch)