def testMultivariateNormal(self, dim, dtype):
  jtu.skip_on_mac_xla_bug()
  r = onp.random.RandomState(dim)
  mean = r.randn(dim)
  cov_factor = r.randn(dim, dim)
  cov = onp.dot(cov_factor, cov_factor.T) + dim * onp.eye(dim)

  key = random.PRNGKey(0)
  rand = partial(random.multivariate_normal, mean=mean, cov=cov,
                 shape=(10000,))
  crand = api.jit(rand)

  uncompiled_samples = onp.asarray(rand(key), onp.float64)
  compiled_samples = onp.asarray(crand(key), onp.float64)

  inv_scale = scipy.linalg.lapack.dtrtri(onp.linalg.cholesky(cov),
                                         lower=True)[0]
  for samples in [uncompiled_samples, compiled_samples]:
    centered = samples - mean
    whitened = onp.einsum('nj,ij->ni', centered, inv_scale)

    # This is a quick-and-dirty multivariate normality check that tests that
    # a uniform mixture of the marginals along the covariance matrix's
    # eigenvectors follows a standard normal distribution.
    self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
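# A minimal standalone sketch (not part of the test suite; the name
# `demo_whitening` is hypothetical) of the identity the check above relies on:
# if x ~ N(mean, cov) and cov = L @ L.T is the Cholesky factorization, then
# inv(L) @ (x - mean) ~ N(0, I).
def demo_whitening():
  r = onp.random.RandomState(0)
  mean = onp.array([1.0, -2.0])
  cov = onp.array([[2.0, 0.3], [0.3, 1.0]])
  samples = r.multivariate_normal(mean, cov, size=100000)
  l = onp.linalg.cholesky(cov)
  whitened = onp.linalg.solve(l, (samples - mean).T).T
  # The empirical covariance of the whitened samples should be near identity.
  assert onp.allclose(onp.cov(whitened, rowvar=False), onp.eye(2), atol=5e-2)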
def testMultivariateNormalCovariance(self):
  # test code based on https://github.com/google/jax/issues/1869
  jtu.skip_on_mac_xla_bug()
  N = 100000
  cov = np.array([[0.19, 0.00, -0.13, 0.00],
                  [0.00, 0.29, 0.00, -0.23],
                  [-0.13, 0.00, 0.39, 0.00],
                  [0.00, -0.23, 0.00, 0.49]])
  mean = np.zeros(4)

  out_onp = onp.random.RandomState(0).multivariate_normal(mean, cov, N)

  key = random.PRNGKey(0)
  out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))

  var_onp = out_onp.var(axis=0)
  var_jnp = out_jnp.var(axis=0)
  self.assertAllClose(var_onp, var_jnp, rtol=1e-2, atol=1e-2,
                      check_dtypes=False)

  var_onp = onp.cov(out_onp, rowvar=False)
  var_jnp = onp.cov(out_jnp, rowvar=False)
  self.assertAllClose(var_onp, var_jnp, rtol=1e-2, atol=1e-2,
                      check_dtypes=False)
def testCategorical(self, p, axis, dtype, sample_shape):
  jtu.skip_on_mac_xla_bug()
  key = random.PRNGKey(0)
  p = onp.array(p, dtype=dtype)
  logits = onp.log(p) - 42  # test unnormalized
  shape = sample_shape + tuple(onp.delete(logits.shape, axis))
  rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, p)
  compiled_samples = crand(key, p)

  for samples in [uncompiled_samples, compiled_samples]:
    if axis < 0:
      axis += len(logits.shape)

    assert samples.shape == shape

    if len(p.shape[:-1]) > 0:
      for cat_index, p_ in enumerate(p):
        self._CheckChiSquared(samples[:, cat_index], pmf=lambda x: p_[x])
    else:
      self._CheckChiSquared(samples, pmf=lambda x: p[x])
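# Hedged illustration (helper name `demo_logit_shift` is hypothetical) of why
# the `- 42` shift above leaves the distribution unchanged: categorical
# probabilities are a softmax of the logits, and softmax is shift-invariant.
def demo_logit_shift():
  logits = onp.array([0.1, 1.0, 2.5])
  def softmax(z):
    e = onp.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()
  assert onp.allclose(softmax(logits), softmax(logits - 42.0))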
def testTensorsolve(self, m, nq, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  if m == 23:
    jtu.skip_on_mac_xla_bug()

  # According to the numpy docs the shapes are as follows: the coefficient
  # tensor a has shape b.shape + Q, with prod(Q) == prod(b.shape).
  # Therefore, n = prod(q).
  n, q = nq
  b_shape = (n, m)
  # To accomplish prod(Q) == prod(b.shape) we append the extra dim m to Q.
  Q = q + (m,)
  args_maker = lambda: [
      rng(b_shape + Q, dtype),  # = a
      rng(b_shape, dtype)]      # = b

  a, b = args_maker()
  result = np.linalg.tensorsolve(a, b)
  self.assertEqual(result.shape, Q)

  self._CheckAgainstNumpy(onp.linalg.tensorsolve,
                          np.linalg.tensorsolve, args_maker,
                          check_dtypes=True, tol=1e-3)
  self._CompileAndCheck(np.linalg.tensorsolve,
                        args_maker, check_dtypes=True,
                        rtol={onp.float64: 1e-13})
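# Illustrative sketch (name `demo_tensorsolve_shapes` is hypothetical) of the
# numpy shape contract exercised above: a.shape == b.shape + Q with
# prod(Q) == prod(b.shape), and the solution has shape Q.
def demo_tensorsolve_shapes():
  r = onp.random.RandomState(0)
  b = r.randn(2, 3)     # prod(b.shape) == 6
  a = r.randn(2, 3, 6)  # Q == (6,), so prod(Q) == 6
  x = onp.linalg.tensorsolve(a, b)
  assert x.shape == (6,)
  # tensordot contracts a's trailing Q axes with x, recovering b.
  assert onp.allclose(onp.tensordot(a, x, axes=1), b)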
def testIssue2131(self, n, dtype):
  jtu.skip_on_mac_xla_bug()
  args_maker_zeros = lambda: [onp.zeros((n, n), dtype)]
  osp_fun = lambda a: osp.linalg.expm(a)
  jsp_fun = lambda a: jsp.linalg.expm(a)
  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros,
                          check_dtypes=True)
  self._CompileAndCheck(jsp_fun, args_maker_zeros, check_dtypes=True)
def testPinv(self, shape, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  if shape == (7, 10000) and dtype in [onp.complex64, onp.float32]:
    jtu.skip_on_mac_xla_bug()
  args_maker = lambda: [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
                          check_dtypes=True, tol=1e-3)
  self._CompileAndCheck(np.linalg.pinv, args_maker, check_dtypes=True)
def testGamma(self, a, dtype):
  jtu.skip_on_mac_xla_bug()
  key = random.PRNGKey(0)
  rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, a)
  compiled_samples = crand(key, a)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
def testExponential(self, dtype):
  jtu.skip_on_mac_xla_bug()
  key = random.PRNGKey(0)
  rand = lambda key: random.exponential(key, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
def testEigvals(self, shape, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  if shape == (50, 50) and dtype == onp.complex64:
    jtu.skip_on_mac_xla_bug()
  n = shape[-1]
  args_maker = lambda: [rng(shape, dtype)]
  a, = args_maker()
  w1, _ = np.linalg.eig(a)
  w2 = np.linalg.eigvals(a)
  self.assertAllClose(w1, w2, check_dtypes=True)
def testBernoulli(self, p, dtype):
  jtu.skip_on_mac_xla_bug()
  key = random.PRNGKey(0)
  p = onp.array(p, dtype=dtype)
  rand = lambda key, p: random.bernoulli(key, p, (10000,))
  crand = api.jit(rand)

  uncompiled_samples = rand(key, p)
  compiled_samples = crand(key, p)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
def testLuFactor(self, n, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  if n == 200 and dtype == onp.complex64:
    jtu.skip_on_mac_xla_bug()
  args_maker = lambda: [rng((n, n), dtype)]

  x, = args_maker()
  lu, piv = jsp.linalg.lu_factor(x)
  l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
  u = onp.triu(lu)
  # Apply the row swaps recorded in piv to x so it can be compared with l @ u.
  for i in range(n):
    x[[i, piv[i]],] = x[[piv[i], i],]
  self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3,
                      atol=1e-3)
  self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
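# A small self-contained sketch of the bookkeeping above (the name
# `demo_lu_roundtrip` is hypothetical, and it assumes `osp` is scipy, as the
# expm tests in this file suggest): lu_factor packs unit-lower-triangular L
# and upper-triangular U into one matrix and records the row swaps in piv.
def demo_lu_roundtrip():
  r = onp.random.RandomState(0)
  x = r.randn(4, 4)
  lu, piv = osp.linalg.lu_factor(x)
  l = onp.tril(lu, -1) + onp.eye(4)
  u = onp.triu(lu)
  for i in range(4):  # apply the recorded row swaps to x
    x[[i, piv[i]], :] = x[[piv[i], i], :]
  assert onp.allclose(x, l @ u)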
def testExpm(self, n, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  if n == 50 and dtype in [onp.complex64, onp.float32]:
    jtu.skip_on_mac_xla_bug()
  args_maker = lambda: [rng((n, n), dtype)]

  osp_fun = lambda a: osp.linalg.expm(a)
  jsp_fun = lambda a: jsp.linalg.expm(a)
  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

  args_maker_triu = lambda: [onp.triu(rng((n, n), dtype))]
  jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
  self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu,
                          check_dtypes=True)
  self._CompileAndCheck(jsp_fun_triu, args_maker_triu, check_dtypes=True)
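# A minimal sanity sketch (name `demo_expm_diagonal` is hypothetical) behind
# the expm tests above: for a diagonal matrix, the matrix exponential is the
# elementwise exponential of the diagonal.
def demo_expm_diagonal():
  d = onp.array([0.5, -1.0, 2.0])
  assert onp.allclose(osp.linalg.expm(onp.diag(d)), onp.diag(onp.exp(d)))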
def testDirichlet(self, alpha, dtype):
  jtu.skip_on_mac_xla_bug()
  key = random.PRNGKey(0)
  rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, alpha)
  compiled_samples = crand(key, alpha)

  for samples in [uncompiled_samples, compiled_samples]:
    self.assertAllClose(samples.sum(-1), onp.ones(10000, dtype=dtype),
                        check_dtypes=True)
    alpha_sum = sum(alpha)
    for i, a in enumerate(alpha):
      self._CheckKolmogorovSmirnovCDF(samples[..., i],
                                      scipy.stats.beta(a, alpha_sum - a).cdf)
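# Hedged sketch (name `demo_dirichlet_marginal` is hypothetical) of the fact
# the KS check above relies on: if (x_1, ..., x_k) ~ Dirichlet(alpha), each
# marginal x_i follows a Beta(alpha_i, sum(alpha) - alpha_i) distribution.
def demo_dirichlet_marginal():
  r = onp.random.RandomState(0)
  alpha = onp.array([0.5, 1.0, 2.0])
  samples = r.dirichlet(alpha, size=100000)
  # Compare the empirical mean of the first coordinate against the Beta mean.
  assert onp.allclose(samples[:, 0].mean(), alpha[0] / alpha.sum(), atol=1e-2)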
def testInv(self, shape, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  if shape == (200, 200) and dtype == onp.float32:
    jtu.skip_on_mac_xla_bug()
  if jtu.device_under_test() == "gpu" and shape == (200, 200):
    raise unittest.SkipTest("Test is flaky on GPU")

  def args_maker():
    # Rejection-sample until numpy can invert a, so the input is well-posed.
    invertible = False
    while not invertible:
      a = rng(shape, dtype)
      try:
        onp.linalg.inv(a)
        invertible = True
      except onp.linalg.LinAlgError:
        pass
    return [a]

  self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
                          check_dtypes=True, tol=1e-3)
  self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
def testTensorinv(self, shape, dtype, rng_factory):
  _skip_if_unsupported_type(dtype)
  if shape[0] > 100:
    jtu.skip_on_mac_xla_bug()
  rng = rng_factory()

  def tensor_maker():
    invertible = False
    while not invertible:
      a = rng(shape, dtype)
      try:
        onp.linalg.inv(a)
        invertible = True
      except onp.linalg.LinAlgError:
        pass
    return a

  args_maker = lambda: [tensor_maker(), int(onp.floor(len(shape) / 2))]
  self._CheckAgainstNumpy(onp.linalg.tensorinv, np.linalg.tensorinv,
                          args_maker, check_dtypes=False, tol=1e-3)
  partial_inv = partial(np.linalg.tensorinv,
                        ind=int(onp.floor(len(shape) / 2)))
  self._CompileAndCheck(partial_inv, lambda: [tensor_maker()],
                        check_dtypes=False, rtol=1e-03, atol=1e-03)
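# Illustrative sketch (name `demo_tensorinv_roundtrip` is hypothetical) of
# numpy.tensorinv's `ind` argument, set above to floor(len(shape) / 2): the
# first `ind` axes act as the operator's output axes and the remaining axes
# as its input axes, with the two groups flattening to equal sizes.
def demo_tensorinv_roundtrip():
  r = onp.random.RandomState(0)
  a = r.randn(2, 3, 3, 2)  # first two dims flatten to 6, as do the last two
  a_inv = onp.linalg.tensorinv(a, ind=2)
  b = r.randn(2, 3)
  x = onp.tensordot(a_inv, b, axes=2)
  # a_inv undoes tensordot with a over the trailing two axes.
  assert onp.allclose(onp.tensordot(a, x, axes=2), b)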