Esempio n. 1
0
def trunc(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.trunc(x))
Esempio n. 2
0
def bitwise_not(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.bitwise_not(x))
Esempio n. 3
0
def bitwise_xor(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.bitwise_xor(x1, x2))
Esempio n. 4
0
def fmin(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.fmin(x1, x2))
Esempio n. 5
0
def clip(a, a_min=None, a_max=None):
  if isinstance(a, JaxArray): a = a.value
  if isinstance(a_min, JaxArray): a_min = a_min.value
  if isinstance(a_max, JaxArray): a_max = a_max.value
  return JaxArray(jnp.clip(a, a_min, a_max))
Esempio n. 6
0
def fabs(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.fabs(x))
Esempio n. 7
0
def heaviside(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.heaviside(x1, x2))
Esempio n. 8
0
def nansum(a, axis=None, dtype=None, keepdims=None):
  if isinstance(a, JaxArray): a = a.value
  r = jnp.nansum(a=a, axis=axis, dtype=dtype, keepdims=keepdims)
  return r if axis is None else JaxArray(r)
Esempio n. 9
0
def ediff1d(a, to_end=None, to_begin=None):
  if isinstance(a, JaxArray): a = a.value
  if isinstance(to_end, JaxArray): to_end = to_end.value
  if isinstance(to_begin, JaxArray): to_begin = to_begin.value
  return JaxArray(jnp.ediff1d(a, to_end=to_end, to_begin=to_begin))
Esempio n. 10
0
def nancumprod(a, axis=None, dtype=None):
  if isinstance(a, JaxArray): a = a.value
  return JaxArray(jnp.nancumprod(a=a, axis=axis, dtype=dtype))
Esempio n. 11
0
def cumsum(a, axis=None, dtype=None):
  if isinstance(a, JaxArray): a = a.value
  return JaxArray(jnp.cumsum(a=a, axis=axis, dtype=dtype))
Esempio n. 12
0
def median(a, axis=None, keepdims=False):
  if isinstance(a, JaxArray): a = a.value
  r = jnp.median(a, axis=axis, keepdims=keepdims)
  return r if axis is None else JaxArray(r)
Esempio n. 13
0
def diff(a, n=1, axis: int = -1, prepend=None, append=None):
  if isinstance(a, JaxArray): a = a.value
  return JaxArray(jnp.diff(a, n=n, axis=axis, prepend=prepend, append=append))
Esempio n. 14
0
def sum(a, axis=None, dtype=None, keepdims=None, initial=None, where=None):
  if isinstance(a, JaxArray): a = a.value
  r = jnp.sum(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
  return r if axis is None else JaxArray(r)
Esempio n. 15
0
def sqrt(x):
  if isinstance(x, JaxArray):
    return JaxArray(jnp.sqrt(x.value))
  else:
    return jnp.sqrt(x)
Esempio n. 16
0
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
  if isinstance(a, JaxArray): a = a.value
  if isinstance(b, JaxArray): b = b.value
  return JaxArray(jnp.cross(a, b, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis))
Esempio n. 17
0
def cbrt(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.cbrt(x))
Esempio n. 18
0
def isnan(x):
  if isinstance(x, JaxArray):
    return JaxArray(jnp.isnan(x.value))
  else:
    return jnp.isnan(x)
Esempio n. 19
0
def sign(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.sign(x))
Esempio n. 20
0
def nextafter(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.nextafter(x1, x2))
Esempio n. 21
0
def maximum(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.maximum(x1, x2))
Esempio n. 22
0
def copysign(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.copysign(x1, x2))
Esempio n. 23
0
def interp(x, xp, fp, left=None, right=None, period=None):
  if isinstance(x, JaxArray): x = x.value
  if isinstance(xp, JaxArray): xp = xp.value
  if isinstance(fp, JaxArray): fp = fp.value
  return JaxArray(jnp.interp(x, xp, fp, left=left, right=right, period=period))
Esempio n. 24
0
def ldexp(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.ldexp(x1, x2))
Esempio n. 25
0
def angle(z, deg=False):
  if isinstance(z, JaxArray): z = z.value
  a = jnp.angle(z)
  if deg:
    a *= 180 / pi
  return JaxArray(a)
Esempio n. 26
0
def frexp(x):
  if isinstance(x, JaxArray): x = x.value
  mantissa, exponent = jnp.frexp(x)
  return JaxArray(mantissa), JaxArray(exponent)
Esempio n. 27
0
def invert(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.invert(x))
Esempio n. 28
0
def convolve(a, v, mode='full'):
  if isinstance(a, JaxArray): a = a.value
  if isinstance(v, JaxArray): v = v.value
  return JaxArray(jnp.convolve(a, v, mode))
Esempio n. 29
0
 def rand(self, *dn):
     return JaxArray(
         jr.uniform(self.split_key(), shape=dn, minval=0., maxval=1.))
Esempio n. 30
0
def rint(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.rint(x))