Ejemplo n.º 1
0
    def testMultivariateNormal(self, dim, dtype):
        """Whitened ``random.multivariate_normal`` samples should be N(0, 1).

        A random symmetric positive-definite covariance of size ``dim`` is
        built, 10000 samples are drawn both eagerly and under ``jit``, and
        each sample set is whitened with the inverse Cholesky factor of the
        covariance before a Kolmogorov-Smirnov test against the standard
        normal CDF.

        NOTE(review): ``dtype`` is not passed to the sampler, so every
        parameterization draws with the sampler's default dtype — confirm
        whether that is intended.
        """
        jtu.skip_on_mac_xla_bug()
        rs = onp.random.RandomState(dim)
        mean = rs.randn(dim)
        factor = rs.randn(dim, dim)
        # A Gram matrix plus a dim-scaled identity is symmetric positive definite.
        cov = onp.dot(factor, factor.T) + dim * onp.eye(dim)

        key = random.PRNGKey(0)
        sampler = partial(random.multivariate_normal, mean=mean, cov=cov,
                          shape=(10000, ))
        jitted_sampler = api.jit(sampler)

        # Eager draws first, then the jitted ones.
        draws = [onp.asarray(fn(key), onp.float64)
                 for fn in (sampler, jitted_sampler)]

        lower = onp.linalg.cholesky(cov)
        lower_inv = scipy.linalg.lapack.dtrtri(lower, lower=True)[0]
        for drawn in draws:
            # Row n of `white` is L^{-1} (x_n - mean).  If x ~ N(mean, L L^T),
            # every entry is i.i.d. standard normal, so pooling all entries
            # gives a quick-and-dirty normality check.
            white = onp.einsum('nj,ij->ni', drawn - mean, lower_inv)
            self._CheckKolmogorovSmirnovCDF(white.ravel(),
                                            scipy.stats.norm().cdf)
Ejemplo n.º 2
0
    def testMultivariateNormalCovariance(self):
        """Sample variance/covariance must match NumPy's sampler within 1e-2.

        Test code based on https://github.com/google/jax/issues/1869: draws
        from ``random.multivariate_normal`` and from NumPy's
        ``RandomState.multivariate_normal`` for the same 4x4 covariance are
        compared via per-coordinate variance and full sample covariance.
        """
        jtu.skip_on_mac_xla_bug()
        num_samples = 100000
        cov = np.array([[0.19, 0.00, -0.13, 0.00], [0.00, 0.29, 0.00, -0.23],
                        [-0.13, 0.00, 0.39, 0.00], [0.00, -0.23, 0.00, 0.49]])
        mean = np.zeros(4)

        reference = onp.random.RandomState(0).multivariate_normal(
            mean, cov, num_samples)
        drawn = random.multivariate_normal(random.PRNGKey(0),
                                           mean=mean,
                                           cov=cov,
                                           shape=(num_samples, ))

        tolerances = dict(rtol=1e-2, atol=1e-2, check_dtypes=False)

        # Per-coordinate variances.
        self.assertAllClose(reference.var(axis=0), drawn.var(axis=0),
                            **tolerances)

        # Full sample covariance matrices (columns are variables).
        self.assertAllClose(onp.cov(reference, rowvar=False),
                            onp.cov(drawn, rowvar=False),
                            **tolerances)
Ejemplo n.º 3
0
    def testCategorical(self, p, axis, dtype, sample_shape):
        """Chi-squared test of ``random.categorical`` against probabilities ``p``.

        ``p`` is turned into logits shifted by -42, so the sampler must cope
        with unnormalized input.  Samples are drawn eagerly and under ``jit``.
        For every categorical distribution in ``p`` (each 1-D slice along
        ``axis``), the matching column of samples is checked with a
        chi-squared goodness-of-fit test.

        Args:
          p: probabilities; the categories lie along ``axis``.
          axis: category axis of ``p`` (negative values allowed).
          dtype: dtype used for ``p`` and the derived logits.
          sample_shape: leading shape of the requested samples.
        """
        jtu.skip_on_mac_xla_bug()
        key = random.PRNGKey(0)
        p = onp.array(p, dtype=dtype)
        logits = onp.log(p) - 42  # test unnormalized
        shape = sample_shape + tuple(onp.delete(logits.shape, axis))
        # The ``p`` argument only mirrors a (key, params) signature; the
        # sampler closes over ``logits``, which ``jit`` bakes in as a constant.
        rand = lambda key, p: random.categorical(
            key, logits, shape=shape, axis=axis)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, p)
        compiled_samples = crand(key, p)

        # One row per categorical distribution: move the category axis last
        # and flatten the remaining axes.  (Indexing ``p`` by its leading
        # axis is only right when ``axis`` is the last axis.)
        num_categories = p.shape[axis]
        dists = onp.moveaxis(p, axis, -1).reshape(-1, num_categories)
        # Sampling is proportional to ``p`` along ``axis``, so normalize each
        # distribution before using it as a pmf (slices along a non-last
        # axis need not sum to 1).
        dists = dists.astype(onp.float64)
        dists /= dists.sum(axis=-1, keepdims=True)

        for samples in [uncompiled_samples, compiled_samples]:
            self.assertEqual(samples.shape, shape)
            # samples has shape sample_shape + batch_shape; after flattening,
            # column j holds the draws from distribution j.
            columns = onp.asarray(samples).reshape(-1, len(dists)).T
            for column, dist in zip(columns, dists):
                # Bind ``dist`` now to avoid late-binding closure surprises.
                self._CheckChiSquared(column, pmf=lambda x, d=dist: d[x])
Ejemplo n.º 4
0
  def testTensorsolve(self, m, nq, dtype, rng_factory):
    """Check ``np.linalg.tensorsolve`` shape, NumPy agreement, and jit.

    Per the NumPy docs, the coefficient tensor ``a`` has shape
    ``b.shape + Q`` with ``prod(Q) == prod(b.shape)``.  Here ``b`` has shape
    ``(n, m)`` and ``Q = q + (m,)``, so the parameterization ``nq = (n, q)``
    must satisfy ``prod(q) == n``.
    """
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    if m == 23:
      jtu.skip_on_mac_xla_bug()

    n, q = nq
    b_shape = (n, m)
    # Append the extra m dimension to Q so that prod(Q) == prod(b_shape).
    Q = q + (m,)
    args_maker = lambda: [
        rng(b_shape + Q, dtype),  # a
        rng(b_shape, dtype)]      # b

    # Solve the pair that was just drawn (previously it was discarded and a
    # second, different pair was solved instead).
    a, b = args_maker()
    result = np.linalg.tensorsolve(a, b)
    self.assertEqual(result.shape, Q)

    self._CheckAgainstNumpy(onp.linalg.tensorsolve,
                            np.linalg.tensorsolve, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.tensorsolve,
                          args_maker, check_dtypes=True,
                          rtol={onp.float64: 1e-13})
Ejemplo n.º 5
0
 def testIssue2131(self, n, dtype):
   """Regression test for issue 2131 (per the test name): expm of zeros.

   ``jsp.linalg.expm`` is compared with SciPy's ``expm`` on an ``(n, n)``
   all-zeros matrix of ``dtype``, and also checked under compilation.
   """
   jtu.skip_on_mac_xla_bug()

   def zeros_args():
     return [onp.zeros((n, n), dtype)]

   def scipy_expm(a):
     return osp.linalg.expm(a)

   def jax_expm(a):
     return jsp.linalg.expm(a)

   self._CheckAgainstNumpy(scipy_expm, jax_expm, zeros_args,
                           check_dtypes=True)
   self._CompileAndCheck(jax_expm, zeros_args, check_dtypes=True)
Ejemplo n.º 6
0
  def testPinv(self, shape, dtype, rng_factory):
    """Compare ``np.linalg.pinv`` with NumPy and check it under compilation."""
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    needs_mac_skip = (shape == (7, 10000)
                      and dtype in [onp.complex64, onp.float32])
    if needs_mac_skip:
      jtu.skip_on_mac_xla_bug()

    def args_maker():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.pinv, args_maker, check_dtypes=True)
Ejemplo n.º 7
0
    def testGamma(self, a, dtype):
        """K-S test of ``random.gamma`` draws (eager and jitted) vs. SciPy."""
        jtu.skip_on_mac_xla_bug()
        key = random.PRNGKey(0)

        def sample(key, a):
            return random.gamma(key, a, (10000, ), dtype)

        # Both sample sets are drawn before any check runs.
        draws = (sample(key, a), api.jit(sample)(key, a))
        for batch in draws:
            self._CheckKolmogorovSmirnovCDF(batch, scipy.stats.gamma(a).cdf)
Ejemplo n.º 8
0
    def testExponential(self, dtype):
        """K-S test of ``random.exponential`` draws (eager and jitted)."""
        jtu.skip_on_mac_xla_bug()
        key = random.PRNGKey(0)

        def sample(key):
            return random.exponential(key, (10000, ), dtype)

        # Both sample sets are drawn before any check runs.
        draws = (sample(key), api.jit(sample)(key))
        for batch in draws:
            self._CheckKolmogorovSmirnovCDF(batch, scipy.stats.expon().cdf)
Ejemplo n.º 9
0
 def testEigvals(self, shape, dtype, rng_factory):
   """Check that ``np.linalg.eigvals`` matches the eigenvalues from ``eig``.

   NOTE(review): the comparison is element-wise, so this relies on ``eig``
   and ``eigvals`` returning eigenvalues in the same order.
   """
   rng = rng_factory()
   _skip_if_unsupported_type(dtype)
   if shape == (50, 50) and dtype == onp.complex64:
     jtu.skip_on_mac_xla_bug()
   a = rng(shape, dtype)
   w1, _ = np.linalg.eig(a)
   w2 = np.linalg.eigvals(a)
   self.assertAllClose(w1, w2, check_dtypes=True)
Ejemplo n.º 10
0
    def testBernoulli(self, p, dtype):
        """Chi-squared test of ``random.bernoulli`` draws (eager and jitted)."""
        jtu.skip_on_mac_xla_bug()
        key = random.PRNGKey(0)
        p = onp.array(p, dtype=dtype)

        def sample(key, p):
            return random.bernoulli(key, p, (10000, ))

        # Both sample sets are drawn before any check runs.
        draws = (sample(key, p), api.jit(sample)(key, p))
        for batch in draws:
            self._CheckChiSquared(batch, scipy.stats.bernoulli(p).pmf)
Ejemplo n.º 11
0
  def testLuFactor(self, n, dtype, rng_factory):
    """Check that ``jsp.linalg.lu_factor`` yields a valid pivoted LU.

    ``lu`` packs a unit-lower-triangular L (strictly below the diagonal)
    and an upper-triangular U.  ``piv[k]`` names the row swapped with row
    ``k``.  Replaying those swaps on the input must reproduce ``L @ U``.
    """
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    if n == 200 and dtype == onp.complex64:
      jtu.skip_on_mac_xla_bug()

    def args_maker():
      return [rng((n, n), dtype)]

    a, = args_maker()
    lu, piv = jsp.linalg.lu_factor(a)
    lower = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
    upper = onp.triu(lu)
    # Apply the row interchanges in order; the order matters because swaps
    # compose.
    for k in range(n):
      j = piv[k]
      a[[k, j],] = a[[j, k],]
    self.assertAllClose(a, onp.matmul(lower, upper), check_dtypes=True,
                        rtol=1e-3, atol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
Ejemplo n.º 12
0
  def testExpm(self, n, dtype, rng_factory):
    """Compare ``jsp.linalg.expm`` with SciPy on dense and triangular input.

    The dense case uses random ``(n, n)`` matrices.  The triangular case
    feeds upper-triangular matrices with ``upper_triangular=True``.  Both
    are also checked under compilation.
    """
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    if n == 50 and dtype in [onp.complex64, onp.float32]:
      jtu.skip_on_mac_xla_bug()

    def dense_args():
      return [rng((n, n), dtype)]

    def triu_args():
      return [onp.triu(rng((n, n), dtype))]

    def scipy_expm(a):
      return osp.linalg.expm(a)

    def jax_expm(a):
      return jsp.linalg.expm(a)

    def jax_expm_triu(a):
      # Inputs from triu_args are upper triangular, so tell expm so.
      return jsp.linalg.expm(a, upper_triangular=True)

    # General square matrices.
    self._CheckAgainstNumpy(scipy_expm, jax_expm, dense_args,
                            check_dtypes=True)
    self._CompileAndCheck(jax_expm, dense_args, check_dtypes=True)

    # Upper-triangular matrices.
    self._CheckAgainstNumpy(scipy_expm, jax_expm_triu, triu_args,
                            check_dtypes=True)
    self._CompileAndCheck(jax_expm_triu, triu_args, check_dtypes=True)
Ejemplo n.º 13
0
    def testDirichlet(self, alpha, dtype):
        """Check ``random.dirichlet`` draws (eager and jitted).

        Each draw must sum to 1.  Each coordinate ``i`` is K-S tested
        against Beta(alpha_i, sum(alpha) - alpha_i), the Dirichlet marginal.
        """
        jtu.skip_on_mac_xla_bug()
        key = random.PRNGKey(0)

        def sample(key, alpha):
            return random.dirichlet(key, alpha, (10000, ), dtype)

        # Both sample sets are drawn before any check runs.
        draws = (sample(key, alpha), api.jit(sample)(key, alpha))
        total = sum(alpha)  # loop-invariant
        for batch in draws:
            self.assertAllClose(batch.sum(-1),
                                onp.ones(10000, dtype=dtype),
                                check_dtypes=True)
            for i, a_i in enumerate(alpha):
                marginal_cdf = scipy.stats.beta(a_i, total - a_i).cdf
                self._CheckKolmogorovSmirnovCDF(batch[..., i], marginal_cdf)
Ejemplo n.º 14
0
  def testInv(self, shape, dtype, rng_factory):
    """Compare ``np.linalg.inv`` with NumPy on random invertible matrices."""
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    if shape == (200, 200) and dtype == onp.float32:
      jtu.skip_on_mac_xla_bug()
    if jtu.device_under_test() == "gpu" and shape == (200, 200):
      raise unittest.SkipTest("Test is flaky on GPU")

    def args_maker():
      # Redraw until NumPy can invert the sample; singular draws are dropped.
      while True:
        candidate = rng(shape, dtype)
        try:
          onp.linalg.inv(candidate)
        except onp.linalg.LinAlgError:
          continue
        return [candidate]

    self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
Ejemplo n.º 15
0
  def testTensorinv(self, shape, dtype, rng_factory):
    """Compare ``np.linalg.tensorinv`` with NumPy and check it under jit.

    ``ind = len(shape) // 2``: the leading ``ind`` axes and the trailing
    axes must have equal element counts so the tensor flattens to a square
    matrix.

    Raises:
      ValueError: if ``shape`` cannot be split that way.  Without this
        check the retry loop below would never terminate.
    """
    _skip_if_unsupported_type(dtype)
    if shape[0] > 100:
      jtu.skip_on_mac_xla_bug()
    rng = rng_factory()
    ind = len(shape) // 2
    if onp.prod(shape[:ind]) != onp.prod(shape[ind:]):
      raise ValueError(
          "shape {} cannot be split at ind={} into a square matrix".format(
              shape, ind))

    def tensor_maker():
      # Redraw until the tensor is invertible *as tensorinv sees it*.
      # tensorinv inverts the tensor flattened to one square matrix.  For
      # rank > 2, onp.linalg.inv(a) would instead check a batch of trailing
      # 2-D matrices, which is a different condition.
      while True:
        a = rng(shape, dtype)
        try:
          onp.linalg.tensorinv(a, ind=ind)
        except onp.linalg.LinAlgError:
          continue
        return a

    args_maker = lambda: [tensor_maker(), ind]
    self._CheckAgainstNumpy(onp.linalg.tensorinv, np.linalg.tensorinv,
                            args_maker, check_dtypes=False, tol=1e-3)
    partial_inv = partial(np.linalg.tensorinv, ind=ind)
    self._CompileAndCheck(partial_inv, lambda: [tensor_maker()],
                          check_dtypes=False, rtol=1e-03, atol=1e-03)