def trunc(x):
    """Element-wise ``jnp.trunc`` of ``x``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.trunc(raw))
def bitwise_not(x):
    """Element-wise ``jnp.bitwise_not`` of ``x``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.bitwise_not(raw))
def bitwise_xor(x1, x2):
    """Element-wise ``jnp.bitwise_xor`` of ``x1`` and ``x2``.

    ``JaxArray`` operands are unwrapped to their ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.bitwise_xor(lhs, rhs))
def fmin(x1, x2):
    """Element-wise ``jnp.fmin`` of ``x1`` and ``x2``.

    ``JaxArray`` operands are unwrapped to their ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.fmin(lhs, rhs))
def clip(a, a_min=None, a_max=None):
    """Clip ``a`` to ``[a_min, a_max]`` via ``jnp.clip``.

    Any of ``a``, ``a_min`` and ``a_max`` given as a ``JaxArray`` is unwrapped
    to its ``.value`` first; the result is always wrapped in a ``JaxArray``.
    """
    data = a.value if isinstance(a, JaxArray) else a
    low = a_min.value if isinstance(a_min, JaxArray) else a_min
    high = a_max.value if isinstance(a_max, JaxArray) else a_max
    return JaxArray(jnp.clip(data, low, high))
def fabs(x):
    """Element-wise ``jnp.fabs`` (absolute value) of ``x``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.fabs(raw))
def heaviside(x1, x2):
    """Element-wise ``jnp.heaviside`` step function.

    ``x2`` supplies the value used where ``x1`` is zero. ``JaxArray`` operands
    are unwrapped to their ``.value`` first; the result is always wrapped in a
    ``JaxArray``.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.heaviside(lhs, rhs))
def nansum(a, axis=None, dtype=None, keepdims=None):
    """Sum of ``a`` via ``jnp.nansum``, which treats NaNs as zero.

    A ``JaxArray`` input is unwrapped to its ``.value`` first.

    Returns
    -------
    The raw ``jnp.nansum`` result when ``axis`` is None (a full reduction);
    otherwise that result wrapped in a ``JaxArray``.
    """
    data = a.value if isinstance(a, JaxArray) else a
    total = jnp.nansum(a=data, axis=axis, dtype=dtype, keepdims=keepdims)
    if axis is None:
        # Full reductions are handed back unwrapped.
        return total
    return JaxArray(total)
def ediff1d(a, to_end=None, to_begin=None):
    """Differences between consecutive elements of ``a`` via ``jnp.ediff1d``.

    ``to_begin`` / ``to_end`` are values to place before / after the
    differences. Any argument given as a ``JaxArray`` is unwrapped to its
    ``.value`` first; the result is always wrapped in a ``JaxArray``.
    """
    data = a.value if isinstance(a, JaxArray) else a
    tail = to_end.value if isinstance(to_end, JaxArray) else to_end
    head = to_begin.value if isinstance(to_begin, JaxArray) else to_begin
    return JaxArray(jnp.ediff1d(data, to_end=tail, to_begin=head))
def nancumprod(a, axis=None, dtype=None):
    """Cumulative product of ``a`` via ``jnp.nancumprod`` (NaNs treated as one).

    A ``JaxArray`` input is unwrapped to its ``.value`` first; the result is
    always returned wrapped in a ``JaxArray``.
    """
    data = a.value if isinstance(a, JaxArray) else a
    running = jnp.nancumprod(a=data, axis=axis, dtype=dtype)
    return JaxArray(running)
def cumsum(a, axis=None, dtype=None):
    """Cumulative sum of ``a`` along ``axis`` via ``jnp.cumsum``.

    A ``JaxArray`` input is unwrapped to its ``.value`` first; the result is
    always returned wrapped in a ``JaxArray``.
    """
    data = a.value if isinstance(a, JaxArray) else a
    running = jnp.cumsum(a=data, axis=axis, dtype=dtype)
    return JaxArray(running)
def median(a, axis=None, keepdims=False):
    """Median of ``a`` via ``jnp.median``.

    A ``JaxArray`` input is unwrapped to its ``.value`` first.

    Returns
    -------
    The raw ``jnp.median`` result when ``axis`` is None (a full reduction);
    otherwise that result wrapped in a ``JaxArray``.
    """
    data = a.value if isinstance(a, JaxArray) else a
    mid = jnp.median(data, axis=axis, keepdims=keepdims)
    if axis is None:
        # Full reductions are handed back unwrapped.
        return mid
    return JaxArray(mid)
def diff(a, n=1, axis: int = -1, prepend=None, append=None):
    """``n``-th discrete difference of ``a`` along ``axis`` via ``jnp.diff``.

    Parameters
    ----------
    a : array-like or JaxArray
        Input values.
    n : int
        Number of times the difference is taken.
    axis : int
        Axis along which the difference is taken (default: last axis).
    prepend, append : array-like, JaxArray or None
        Values joined to ``a`` along ``axis`` before differencing.

    Every array argument given as a ``JaxArray`` is unwrapped to its
    ``.value`` before calling ``jnp.diff``. The result is always wrapped in a
    ``JaxArray``.
    """
    if isinstance(a, JaxArray):
        a = a.value
    # Previously only ``a`` was unwrapped, so a JaxArray passed as prepend/
    # append reached jnp.diff as the wrapper object instead of its value.
    if isinstance(prepend, JaxArray):
        prepend = prepend.value
    if isinstance(append, JaxArray):
        append = append.value
    return JaxArray(jnp.diff(a, n=n, axis=axis, prepend=prepend, append=append))
def sum(a, axis=None, dtype=None, keepdims=None, initial=None, where=None):
    """Sum of ``a`` via ``jnp.sum``.

    Parameters
    ----------
    a : array-like or JaxArray
        Values to sum.
    axis, dtype, keepdims
        Forwarded unchanged to ``jnp.sum``.
    initial : scalar, JaxArray or None
        Starting value of the sum, forwarded to ``jnp.sum``.
    where : array-like, JaxArray or None
        Mask selecting which elements are included, forwarded to ``jnp.sum``.

    ``a``, ``initial`` and ``where`` given as a ``JaxArray`` are unwrapped to
    their ``.value`` before calling ``jnp.sum``.

    Returns
    -------
    The raw ``jnp.sum`` result when ``axis`` is None (a full reduction);
    otherwise that result wrapped in a ``JaxArray``.
    """
    if isinstance(a, JaxArray):
        a = a.value
    # Previously only ``a`` was unwrapped, so a JaxArray mask or initial
    # value reached jnp.sum as the wrapper object instead of its value.
    if isinstance(initial, JaxArray):
        initial = initial.value
    if isinstance(where, JaxArray):
        where = where.value
    r = jnp.sum(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
    # NOTE(review): with axis=None the raw result is returned even when
    # keepdims=True produces a non-scalar array; kept for compatibility.
    return r if axis is None else JaxArray(r)
def sqrt(x):
    """Element-wise ``jnp.sqrt`` of ``x``.

    The result is wrapped in a ``JaxArray`` only when ``x`` itself is a
    ``JaxArray``; any other input gets the raw ``jnp.sqrt`` result back.
    """
    if not isinstance(x, JaxArray):
        return jnp.sqrt(x)
    return JaxArray(jnp.sqrt(x.value))
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """Cross product of ``a`` and ``b`` via ``jnp.cross``.

    The axis arguments are forwarded unchanged. ``JaxArray`` operands are
    unwrapped to their ``.value`` first; the result is always wrapped in a
    ``JaxArray``.
    """
    u = a.value if isinstance(a, JaxArray) else a
    v = b.value if isinstance(b, JaxArray) else b
    product = jnp.cross(u, v, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)
    return JaxArray(product)
def cbrt(x):
    """Element-wise cube root (``jnp.cbrt``) of ``x``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.cbrt(raw))
def isnan(x):
    """Element-wise NaN test via ``jnp.isnan``.

    The result is wrapped in a ``JaxArray`` only when ``x`` itself is a
    ``JaxArray``; any other input gets the raw ``jnp.isnan`` result back.
    """
    if not isinstance(x, JaxArray):
        return jnp.isnan(x)
    return JaxArray(jnp.isnan(x.value))
def sign(x):
    """Element-wise ``jnp.sign`` of ``x``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.sign(raw))
def nextafter(x1, x2):
    """Element-wise ``jnp.nextafter``: next representable value of ``x1`` toward ``x2``.

    ``JaxArray`` operands are unwrapped to their ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    start = x1.value if isinstance(x1, JaxArray) else x1
    toward = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.nextafter(start, toward))
def maximum(x1, x2):
    """Element-wise ``jnp.maximum`` of ``x1`` and ``x2``.

    ``JaxArray`` operands are unwrapped to their ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.maximum(lhs, rhs))
def copysign(x1, x2):
    """Element-wise ``jnp.copysign``: magnitude of ``x1`` with the sign of ``x2``.

    ``JaxArray`` operands are unwrapped to their ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    magnitude = x1.value if isinstance(x1, JaxArray) else x1
    signs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.copysign(magnitude, signs))
def interp(x, xp, fp, left=None, right=None, period=None):
    """One-dimensional linear interpolation via ``jnp.interp``.

    Parameters
    ----------
    x : array-like or JaxArray
        Points at which to evaluate the interpolant.
    xp, fp : array-like or JaxArray
        Sample coordinates and their values.
    left, right : scalar, JaxArray or None
        Values returned for ``x`` below / above the range of ``xp``.
    period : scalar, JaxArray or None
        Period for periodic ``x`` coordinates.

    Every argument given as a ``JaxArray`` is unwrapped to its ``.value``
    before calling ``jnp.interp``. The result is always wrapped in a
    ``JaxArray``.
    """
    if isinstance(x, JaxArray):
        x = x.value
    if isinstance(xp, JaxArray):
        xp = xp.value
    if isinstance(fp, JaxArray):
        fp = fp.value
    # Previously left/right/period were passed through as-is, so a JaxArray
    # value (e.g. taken from fp) reached jnp.interp as the wrapper object.
    if isinstance(left, JaxArray):
        left = left.value
    if isinstance(right, JaxArray):
        right = right.value
    if isinstance(period, JaxArray):
        period = period.value
    return JaxArray(jnp.interp(x, xp, fp, left=left, right=right, period=period))
def ldexp(x1, x2):
    """Element-wise ``jnp.ldexp``: ``x1 * 2**x2``.

    ``JaxArray`` operands are unwrapped to their ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    mantissa = x1.value if isinstance(x1, JaxArray) else x1
    exponent = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.ldexp(mantissa, exponent))
def angle(z, deg=False):
    """Argument of (complex) ``z`` via ``jnp.angle``.

    When ``deg`` is truthy, the radian result is scaled by ``180 / pi`` to
    give degrees. A ``JaxArray`` input is unwrapped to its ``.value`` first;
    the result is always returned wrapped in a ``JaxArray``.
    """
    data = z.value if isinstance(z, JaxArray) else z
    radians = jnp.angle(data)
    # ``pi`` is a module-level name defined outside this function.
    result = radians * (180 / pi) if deg else radians
    return JaxArray(result)
def frexp(x):
    """Decompose ``x`` into mantissa and exponent via ``jnp.frexp``.

    A ``JaxArray`` input is unwrapped to its ``.value`` first.

    Returns
    -------
    tuple of (JaxArray, JaxArray)
        The mantissa and exponent arrays, each wrapped in a ``JaxArray``.
    """
    data = x.value if isinstance(x, JaxArray) else x
    mant, expo = jnp.frexp(data)
    return JaxArray(mant), JaxArray(expo)
def invert(x):
    """Element-wise ``jnp.invert`` (bitwise inversion) of ``x``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.invert(raw))
def convolve(a, v, mode='full'):
    """Discrete linear convolution of ``a`` and ``v`` via ``jnp.convolve``.

    ``mode`` is forwarded unchanged. ``JaxArray`` operands are unwrapped to
    their ``.value`` first; the result is always wrapped in a ``JaxArray``.
    """
    signal = a.value if isinstance(a, JaxArray) else a
    kernel = v.value if isinstance(v, JaxArray) else v
    return JaxArray(jnp.convolve(signal, kernel, mode=mode))
def rand(self, *dn):
    """Uniform samples with ``minval=0.`` and ``maxval=1.`` of shape ``dn``.

    A fresh key is obtained from ``self.split_key()`` for each call. The
    samples come from ``jr.uniform`` and are wrapped in a ``JaxArray``.
    """
    key = self.split_key()
    samples = jr.uniform(key, shape=dn, minval=0., maxval=1.)
    return JaxArray(samples)
def rint(x):
    """Element-wise ``jnp.rint`` (round to nearest integer) of ``x``.

    A ``JaxArray`` argument is unwrapped to its ``.value`` first; the result
    is always returned wrapped in a ``JaxArray``.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.rint(raw))