def random(self, size=None):
  """Draw floats uniformly from the half-open interval ``[0, 1)``.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
    Samples drawn with a fresh key split from this random state.
  """
  key = self.split_key()
  out_shape = _size2shape(size)
  samples = jr.uniform(key, shape=out_shape, minval=0., maxval=1.)
  return JaxArray(samples)
def fftfreq(n, d=1.0):
  """Return the discrete Fourier transform sample frequencies.

  Thin wrapper around ``jax.numpy.fft.fftfreq``.

  Parameters
  ----------
  n : int
    Window length.
  d : float
    Sample spacing.

  Returns
  -------
  JaxArray
    The sample frequencies.
  """
  freqs = jax.numpy.fft.fftfreq(n=n, d=d)
  return JaxArray(freqs)
def ifft(a, n=None, axis=-1, norm=None):
  """Compute the one-dimensional inverse discrete Fourier transform.

  ``a`` is first converted with ``as_device_array``. The remaining arguments
  are forwarded unchanged to ``jax.numpy.fft.ifft``.

  Returns
  -------
  JaxArray
    The transformed array.
  """
  arr = as_device_array(a)
  transformed = jax.numpy.fft.ifft(a=arr, n=n, axis=axis, norm=norm)
  return JaxArray(transformed)
def standard_normal(size=None):
  """Draw samples from the standard normal distribution (mean 0, std 1).

  A key is split from the module-level ``DEFAULT`` random state.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.normal(key, shape=out_shape)
  return JaxArray(draws)
def uniform(low=0.0, high=1.0, size=None):
  """Draw samples uniformly from ``[low, high)``.

  Parameters
  ----------
  low, high :
    Interval bounds, forwarded to ``jax.random.uniform`` as ``minval`` and
    ``maxval``.
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.uniform(key, shape=out_shape, minval=low, maxval=high)
  return JaxArray(draws)
def poisson(lam=1.0, size=None):
  """Draw samples from a Poisson distribution with rate ``lam``.

  Parameters
  ----------
  lam :
    Expected number of events, forwarded to ``jax.random.poisson``.
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  counts = jr.poisson(key, lam=lam, shape=out_shape)
  return JaxArray(counts)
def standard_exponential(size=None):
  """Draw samples from the standard exponential distribution (rate 1).

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.exponential(key, shape=out_shape)
  return JaxArray(draws)
def bernoulli(self, p, size=None):
  """Draw Bernoulli samples that are ``True`` with probability ``p``.

  Parameters
  ----------
  p :
    Success probability, forwarded to ``jax.random.bernoulli``.
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = self.split_key()
  out_shape = _size2shape(size)
  flags = jr.bernoulli(key, p=p, shape=out_shape)
  return JaxArray(flags)
def rand(*dn):
  """Return uniform ``[0, 1)`` samples whose shape is given positionally.

  Example: ``rand(2, 3)`` yields a ``(2, 3)`` array. The dimensions are
  passed directly as the shape tuple.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  draws = jr.uniform(key, shape=dn, minval=0., maxval=1.)
  return JaxArray(draws)
def pareto(self, a, size=None):
  """Draw samples from a Pareto distribution with shape parameter ``a``.

  ``a`` is forwarded to ``jax.random.pareto`` as its ``b`` argument.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = self.split_key()
  out_shape = _size2shape(size)
  draws = jr.pareto(key, b=a, shape=out_shape)
  return JaxArray(draws)
def truncated_normal(self, lower, upper, size, scale=1.):
  """Draw truncated standard-normal samples and multiply them by ``scale``.

  ``lower`` and ``upper`` bound the *standard* normal before scaling. The
  returned values therefore lie between ``lower * scale`` and
  ``upper * scale``, assuming ``scale`` is positive.

  Parameters
  ----------
  size :
    Requested output shape (required); converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = self.split_key()
  out_shape = _size2shape(size)
  unit = jr.truncated_normal(key, lower=lower, upper=upper, shape=out_shape)
  scaled = unit * scale
  return JaxArray(scaled)
def laplace(self, loc=0.0, scale=1.0, size=None):
  """Draw samples from the Laplace (double exponential) distribution.

  Samples are computed as ``loc + scale * X``, where ``X`` is drawn from the
  standard Laplace distribution by ``jax.random.laplace``.

  Parameters
  ----------
  loc : float or array_like
    Location (peak) of the distribution.
  scale : float or array_like
    Scale (spread) of the distribution.
  size : optional
    Requested output shape; converted with ``_size2shape``. When ``None``,
    the output has the broadcast shape of ``loc`` and ``scale``, so every
    element gets its own independent draw.

  Returns
  -------
  JaxArray
  """
  import jax.numpy as jnp

  loc = loc.value if isinstance(loc, JaxArray) else loc
  scale = scale.value if isinstance(scale, JaxArray) else scale
  key = self.split_key()
  if size is None:
    # One independent draw per parameter element (NumPy semantics), rather
    # than one scalar draw broadcast over all elements.
    out_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
  else:
    out_shape = _size2shape(size)
  samples = jr.laplace(key, shape=out_shape)
  return JaxArray(samples * scale + loc)
def shuffle(self, x, axis=0):
  """Randomly shuffle the elements of ``x`` along ``axis``.

  Each one-dimensional slice along ``axis`` is permuted independently. This
  matches the behaviour of the former ``jax.random.shuffle``. The result is
  returned as a new array.

  Parameters
  ----------
  x : JaxArray or array_like
    Array to shuffle; a ``JaxArray`` is unwrapped to its ``.value`` first.
  axis : int
    Axis along which to shuffle.

  Returns
  -------
  JaxArray
  """
  x = x.value if isinstance(x, JaxArray) else x
  # ``jax.random.shuffle`` is deprecated. JAX documents
  # ``permutation(..., independent=True)`` as its drop-in replacement.
  permuted = jr.permutation(self.split_key(), x, axis=axis, independent=True)
  return JaxArray(permuted)
def permutation(self, x):
  """Return a random permutation of ``x``.

  A ``JaxArray`` input is unwrapped to its ``.value`` before being forwarded
  to ``jax.random.permutation``.

  Returns
  -------
  JaxArray
  """
  if isinstance(x, JaxArray):
    x = x.value
  permuted = jr.permutation(self.split_key(), x)
  return JaxArray(permuted)
def normal(loc=0.0, scale=1.0, size=None):
  """Draw samples from a normal (Gaussian) distribution.

  Samples are computed as ``loc + scale * Z``, where ``Z`` is standard normal.

  Parameters
  ----------
  loc : float or array_like
    Mean of the distribution.
  scale : float or array_like
    Standard deviation of the distribution.
  size : optional
    Requested output shape; converted with ``_size2shape``. When ``None``,
    the output has the broadcast shape of ``loc`` and ``scale``, so every
    element gets its own independent draw.

  Returns
  -------
  JaxArray
  """
  import jax.numpy as jnp

  loc = loc.value if isinstance(loc, JaxArray) else loc
  scale = scale.value if isinstance(scale, JaxArray) else scale
  key = DEFAULT.split_key()
  if size is None:
    # Without this, an array ``loc``/``scale`` would receive one scalar draw
    # broadcast over all elements, making every sample identical.
    out_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
  else:
    out_shape = _size2shape(size)
  return JaxArray(jr.normal(key, shape=out_shape) * scale + loc)
def randn(*dn):
  """Return standard-normal samples whose shape is given positionally.

  Example: ``randn(4, 2)`` yields a ``(4, 2)`` array.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  draws = jr.normal(key, shape=dn)
  return JaxArray(draws)
def pareto(a, size=None):
  """Draw samples from a Pareto distribution with shape parameter ``a``.

  Uses the module-level ``DEFAULT`` random state. ``a`` is forwarded to
  ``jax.random.pareto`` as its ``b`` argument.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.pareto(key, b=a, shape=out_shape)
  return JaxArray(draws)
def random_sample(size=None):
  """Draw floats uniformly from the half-open interval ``[0, 1)``.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.uniform(key, shape=out_shape, minval=0., maxval=1.)
  return JaxArray(draws)
def standard_cauchy(size=None):
  """Draw samples from the standard Cauchy distribution.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.cauchy(key, shape=out_shape)
  return JaxArray(draws)
def shuffle(x):
  """Return a randomly permuted copy of ``x``, permuting along its first axis.

  Unlike ``numpy.random.shuffle``, this returns the permuted result as a new
  ``JaxArray`` instead of shuffling in place. A ``JaxArray`` input is
  unwrapped to its ``.value`` first.

  Returns
  -------
  JaxArray
  """
  raw = x.value if isinstance(x, JaxArray) else x
  permuted = jr.permutation(DEFAULT.split_key(), raw)
  return JaxArray(permuted)
def standard_gamma(shape, size=None):
  """Draw samples from a unit-scale Gamma distribution.

  Parameters
  ----------
  shape :
    Gamma shape parameter, forwarded to ``jax.random.gamma`` as ``a``.
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.gamma(key, a=shape, shape=out_shape)
  return JaxArray(draws)
def beta(a, b, size=None):
  """Draw samples from a Beta distribution with parameters ``a`` and ``b``.

  ``JaxArray`` parameters are unwrapped to their ``.value`` before sampling.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  a, b = (p.value if isinstance(p, JaxArray) else p for p in (a, b))
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.beta(key, a=a, b=b, shape=out_shape)
  return JaxArray(draws)
def standard_t(df, size=None):
  """Draw samples from Student's t distribution with ``df`` degrees of freedom.

  Parameters
  ----------
  size : optional
    Requested output shape; converted with ``_size2shape``.

  Returns
  -------
  JaxArray
  """
  key = DEFAULT.split_key()
  out_shape = _size2shape(size)
  draws = jr.t(key, df=df, shape=out_shape)
  return JaxArray(draws)
def exponential(scale=1.0, size=None):
  """Draw samples from an exponential distribution.

  Samples are computed as ``scale * E``, where ``E`` is drawn from the standard
  (rate 1) exponential distribution by ``jax.random.exponential``.

  Parameters
  ----------
  scale : float or array_like
    Scale parameter (the inverse of the rate).
  size : optional
    Requested output shape; converted with ``_size2shape``. When ``None``,
    the output has the shape of ``scale``, so every element gets its own
    independent draw.

  Returns
  -------
  JaxArray
  """
  import jax.numpy as jnp

  scale = scale.value if isinstance(scale, JaxArray) else scale
  key = DEFAULT.split_key()
  if size is None:
    out_shape = jnp.shape(scale)
  else:
    out_shape = _size2shape(size)
  return JaxArray(jr.exponential(key, shape=out_shape) * scale)
def fft2(a, s=None, axes=(-2, -1), norm=None):
  """Compute the two-dimensional discrete Fourier transform.

  ``a`` is first converted with ``as_device_array``. The remaining arguments
  are forwarded unchanged to ``jax.numpy.fft.fft2``.

  Returns
  -------
  JaxArray
  """
  arr = as_device_array(a)
  spectrum = jax.numpy.fft.fft2(a=arr, s=s, axes=axes, norm=norm)
  return JaxArray(spectrum)
def gamma(shape, scale=1.0, size=None):
  """Draw samples from a Gamma distribution.

  Samples are computed as ``scale * G``, where ``G`` is drawn by
  ``jax.random.gamma`` with shape parameter ``shape`` and unit scale.

  Parameters
  ----------
  shape : float or array_like
    Gamma shape parameter (``k``).
  scale : float or array_like
    Gamma scale parameter (``theta``).
  size : optional
    Requested output shape; converted with ``_size2shape``. When ``None``,
    the output has the broadcast shape of ``shape`` and ``scale``, so every
    element gets its own independent draw.

  Returns
  -------
  JaxArray
  """
  import jax.numpy as jnp

  shape = shape.value if isinstance(shape, JaxArray) else shape
  scale = scale.value if isinstance(scale, JaxArray) else scale
  key = DEFAULT.split_key()
  if size is None:
    out_shape = jnp.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
  else:
    out_shape = _size2shape(size)
  return JaxArray(jr.gamma(key, a=shape, shape=out_shape) * scale)
def fftshift(x, axes=None):
  """Shift the zero-frequency component to the centre of the spectrum.

  A ``JaxArray`` input is unwrapped to its ``.value``. The result comes from
  ``jax.numpy.fft.fftshift``.

  Returns
  -------
  JaxArray
  """
  raw = x.value if isinstance(x, JaxArray) else x
  shifted = jax.numpy.fft.fftshift(x=raw, axes=axes)
  return JaxArray(shifted)
def logistic(loc=0.0, scale=1.0, size=None):
  """Draw samples from a logistic distribution.

  Samples are computed as ``loc + scale * X``, where ``X`` is drawn from the
  standard logistic distribution by ``jax.random.logistic``.

  Parameters
  ----------
  loc : float or array_like
    Location (mean) of the distribution.
  scale : float or array_like
    Scale of the distribution.
  size : optional
    Requested output shape; converted with ``_size2shape``. When ``None``,
    the output has the broadcast shape of ``loc`` and ``scale``, so every
    element gets its own independent draw.

  Returns
  -------
  JaxArray
  """
  import jax.numpy as jnp

  loc = loc.value if isinstance(loc, JaxArray) else loc
  scale = scale.value if isinstance(scale, JaxArray) else scale
  key = DEFAULT.split_key()
  if size is None:
    out_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
  else:
    out_shape = _size2shape(size)
  samples = jr.logistic(key, shape=out_shape)
  return JaxArray(samples * scale + loc)
def ifftn(a, s=None, axes=None, norm=None):
  """Compute the N-dimensional inverse discrete Fourier transform.

  ``a`` is first converted with ``as_device_array``. The remaining arguments
  are forwarded unchanged to ``jax.numpy.fft.ifftn``.

  Returns
  -------
  JaxArray
  """
  arr = as_device_array(a)
  restored = jax.numpy.fft.ifftn(a=arr, s=s, axes=axes, norm=norm)
  return JaxArray(restored)
def randn(self, *dn):
  """Return standard-normal samples whose shape is given positionally.

  A fresh key is split from this random state. Example: ``rs.randn(3)``
  yields a length-3 array.

  Returns
  -------
  JaxArray
  """
  key = self.split_key()
  draws = jr.normal(key, shape=dn)
  return JaxArray(draws)