def tanhshrink(a: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.tanhshrink.

    Computes ``a - tanh(a)`` via ``refs.tanh`` and ``refs.sub``.

    Args:
        a: the input; must be an instance of ``TensorLike``.

    Returns:
        The result of ``refs.sub(a, refs.tanh(a))``.

    Raises:
        RuntimeError: if ``a`` is not a ``TensorLike``.
    """
    is_tensor_input = isinstance(a, TensorLike)
    if is_tensor_input:
        squashed = refs.tanh(a)
        return refs.sub(a, squashed)
    # Non-tensor inputs are rejected explicitly rather than being passed on to refs.
    raise RuntimeError(
        "Expected a tensor input for an elementwise unary operation!"
    )
def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mish.

    Computes ``a * tanh(softplus(a))`` via ``refs.nn.functional.softplus``,
    ``refs.tanh`` and ``refs.mul``.

    Args:
        a: the input tensor.
        inplace: requests an in-place update; this reference does not
            support that mode.

    Returns:
        The result of ``refs.mul(a, refs.tanh(refs.nn.functional.softplus(a)))``.

    Raises:
        NotImplementedError: if ``inplace`` is truthy.
    """
    if not inplace:
        smooth_gate = refs.tanh(refs.nn.functional.softplus(a))
        return refs.mul(a, smooth_gate)
    # The in-place variant is not provided by this reference implementation.
    raise NotImplementedError