def get_function(function: str, gamma: float) -> Callable:
    """Return the element-wise mapping selected by ``function``.

    Args:
        function: Name of the mapping. One of ``"tanh"``, ``"identity"``
            or ``"constant"``.
        gamma: Scale applied to the input before ``tanh``. It is ignored
            by the ``"identity"`` and ``"constant"`` mappings.

    Returns:
        A one-argument callable:

        * ``"tanh"``: ``ep.tanh(gamma * x)``.
        * ``"identity"``: returns its argument unchanged.
        * ``"constant"``: ``(abs(x) > 0).astype(int)``, i.e. 1 where the
          input is non-zero and 0 where it is zero.

    Raises:
        ValueError: If ``function`` is not one of the names above.
    """
    # Guard-clause dispatch. Each branch returns immediately, so the final
    # ``raise`` is reached only for unrecognised names.
    if function == "tanh":
        return lambda t: ep.tanh(gamma * t)
    if function == "identity":
        return lambda t: t
    if function == "constant":
        # NOTE(review): despite the name, this is a non-zero indicator,
        # not a constant function.
        return lambda t: (abs(t) > 0).astype(int)
    # NOTE(review): the message mentions DCT; presumably the caller is a
    # DCT-based transform. Confirm against the call sites.
    raise ValueError("Function given for DCT is incorrect.")
def test_tanh(t: Tensor) -> Tensor:
    """Apply eagerpy's ``tanh`` to ``t`` and return the result.

    Unlike the ``"tanh"`` mapping from ``get_function``, no ``gamma``
    scaling is applied here.
    """
    # NOTE(review): the ``test_`` prefix means pytest may collect this as a
    # test if it lives in a test module, and then fail looking for a ``t``
    # fixture. Confirm this is intended as a helper.
    activated = ep.tanh(t)
    return activated