def clip(a, a_min=None, a_max=None):
    """Clamp the elements of ``a`` to the closed interval [a_min, a_max].

    Args:
      a: value to clip; its dtype (via ``_dtype``) determines the dtype of
        the bounds.
      a_min: lower bound, or ``None`` to leave the lower side unbounded.
      a_max: upper bound, or ``None`` to leave the upper side unbounded.

    Returns:
      ``lax.clamp(a_min, a, a_max)`` with both bounds converted to the dtype
      of ``a``.
    """
    import numpy as np  # local import keeps this block self-contained

    dtype = _dtype(a)
    if a_min is None or a_max is None:
        # A missing bound must be a true no-op. For floating dtypes the finite
        # extremes (finfo.min / finfo.max) are NOT no-ops: clamping to them
        # would turn -inf / +inf into finite values, unlike numpy.clip.
        # Integer dtypes have no infinity, so their extremes are the no-op.
        # iinfo.min is an integer, whereas finfo.min is a floating scalar.
        info = _dtype_info(dtype)
        is_integer = isinstance(info.min, (int, np.integer))
        if a_min is None:
            a_min = info.min if is_integer else -np.inf
        if a_max is None:
            a_max = info.max if is_integer else np.inf
    # lax operations do not implicitly promote, so bring both bounds to a's
    # dtype before clamping.
    if _dtype(a_min) != dtype:
        a_min = lax.convert_element_type(a_min, dtype)
    if _dtype(a_max) != dtype:
        a_max = lax.convert_element_type(a_max, dtype)
    return lax.clamp(a_min, a, a_max)
def f_jax(mi, x, ma):
    """Clamp ``x`` into the interval [``mi``, ``ma``] using ``lax.clamp``.

    Args:
      mi: lower bound.
      x: value to clamp.
      ma: upper bound.

    Returns:
      Whatever ``lax.clamp(mi, x, ma)`` returns.
    """
    lower, upper = mi, ma
    clamped = lax.clamp(lower, x, upper)
    return clamped