Beispiel #1
0
  def testAdvancedIndexingManually(self):
    """Mixed advanced/basic indexing agrees between eager execution and jit."""
    x = np.random.RandomState(0).randn(3, 4, 5)
    index_array = np.array([0, 2, -1, 0])

    indexers = [
        # Integer array between an ellipsis and a trailing full slice.
        lambda x, index_array: x[..., index_array, :],
        # Two integer arrays separated by a slice, plus a new axis.
        lambda x, index_array: x[..., index_array, :, index_array, None],
        # Integer arrays of shapes (4,) and (4, 1) that broadcast together,
        # on either side of an ellipsis, plus a new axis.
        lambda x, index_array: x[index_array, ..., index_array[:, None], None],
    ]
    for op in indexers:
      jitted_op = api.jit(op)
      eager_result = op(x, index_array)
      jitted_result = jitted_op(x, index_array)
      self.assertAllClose(eager_result, jitted_result)
Beispiel #2
0
    def test_jit_on_nondefault_backend(self):
        """jit follows argument placement unless `device`/`backend` is given."""
        cpu_devices = api.devices("cpu")
        self.assertNotEmpty(cpu_devices)
        cpu0 = cpu_devices[0]

        # This test only makes sense when the default backend is not CPU.
        default_device = api.devices()[0]
        self.assertNotEqual(default_device.platform, "cpu")

        value_on_cpu = api.device_put(1, device=cpu0)
        self.assertEqual(value_on_cpu.device_buffer.device(), cpu0)

        def my_sin(x):
            return jnp.sin(x)

        def placement_of(result):
            return result.device_buffer.device()

        # No device spec: the result lands where the input data lives.
        self.assertEqual(placement_of(api.jit(my_sin)(2)), default_device)
        self.assertEqual(placement_of(api.jit(my_sin)(value_on_cpu)), cpu0)

        # An explicit `device` or `backend` decides the placement instead.
        self.assertEqual(placement_of(api.jit(my_sin, device=cpu0)(2)), cpu0)
        self.assertEqual(placement_of(api.jit(my_sin, backend="cpu")(2)), cpu0)
Beispiel #3
0
    def test_closed_over_values_device_placement(self):
        """Constants inside a jitted function follow the requested backend."""
        # see https://github.com/google/jax/issues/1431
        def f():
            return jnp.add(3., 4.)

        default_result = api.jit(f)()
        self.assertNotEqual(default_result.device_buffer.device(),
                            api.devices('cpu')[0])

        cpu_result = api.jit(f, backend='cpu')()
        self.assertEqual(cpu_result.device_buffer.device(),
                         api.devices('cpu')[0])
Beispiel #4
0
 def testFloatIndexingError(self):
   """Float-typed indexers raise TypeError, both eagerly and under jit."""
   BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
   with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
     jnp.zeros(2)[0.]
   # Float in the last position of a tuple index.
   with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
     jnp.zeros((2, 2))[(0, 0.)]
   # Float in the first position of a tuple index.
   with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
     jnp.zeros((2, 2))[(0., 0)]
   # The same error is expected when the index is an argument of a jitted function.
   with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
     api.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
   # The indexed-update helpers are expected to reject float indices too.
   with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
     ops.index_add(jnp.zeros(2), 0., 1.)
   with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
     ops.index_update(jnp.zeros(2), 0., 1.)
Beispiel #5
0
  def testIndexingEmptyDimension(self):
    # Issue 2671: XLA error when indexing into dimension of size 0
    x = jnp.ones((2, 0))
    # Basic indexing that only touches the size-0 axis 1 must still succeed.
    _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]

    numpy_pattern = "index .* is out of bounds for axis .* with size 0"
    jax_pattern = "index is out of bounds for axis .* with size 0"

    with self.assertRaisesRegex(IndexError, numpy_pattern):
      _ = np.ones((2, 0))[0, 0]  # NumPy's own error, for reference
    with self.assertRaisesRegex(IndexError, jax_pattern):
      _ = x[0, 0]  # eager JAX indexing
    with self.assertRaisesRegex(IndexError, jax_pattern):
      api.jit(lambda i: x[0, i])(0)  # JAX indexing under jit
Beispiel #6
0
  def testCategorical(self, p, axis, dtype, sample_shape):
    """random.categorical samples match the probabilities `p`, eagerly and jitted.

    Category probabilities run along `axis` of `p`. Each distribution gets a
    chi-squared goodness-of-fit check.
    """
    key = random.PRNGKey(0)
    p = np.array(p, dtype=dtype)
    logits = np.log(p) - 42  # shifted so the logits are unnormalized
    out_shape = tuple(np.delete(logits.shape, axis))
    shape = sample_shape + out_shape
    rand = partial(random.categorical, shape=shape, axis=axis)
    crand = api.jit(rand)

    uncompiled_samples = rand(key, logits)
    compiled_samples = crand(key, logits)

    if axis < 0:
      axis += len(logits.shape)

    for samples in [uncompiled_samples, compiled_samples]:
      # A real assertion (not a bare `assert`), so it still runs under `python -O`.
      self.assertEqual(samples.shape, shape)
      # Collapse all sample dimensions into one, whatever `sample_shape` is.
      samples = jnp.reshape(samples, (-1,) + out_shape)
      if p.ndim > 1:
        # Put the batch dimension first, so each row of `ps` is one distribution.
        ps = np.transpose(p, (1, 0)) if axis == 0 else p
        for cat_samples, cat_p in zip(samples.transpose(), ps):
          # Bind `cat_p` explicitly rather than relying on late binding.
          pmf = lambda x, cat_p=cat_p: np.where(
              x < len(cat_p), cat_p[np.minimum(len(cat_p) - 1, x)], 0.0)
          self._CheckChiSquared(cat_samples, pmf=pmf)
      else:
        pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
        self._CheckChiSquared(samples, pmf=pmf)
Beispiel #7
0
    def testMultivariateNormal(self, dim, dtype, method):
        """Whitened multivariate-normal samples look standard normal."""
        rng = np.random.RandomState(dim)
        mean = rng.randn(dim)
        factor = rng.randn(dim, dim)
        # factor @ factor.T is positive semi-definite. Adding dim * I makes
        # it strictly positive definite.
        cov = np.dot(factor, factor.T) + dim * np.eye(dim)

        key = random.PRNGKey(0)
        rand = partial(random.multivariate_normal, mean=mean, cov=cov,
                       shape=(10000, ), method=method)
        crand = api.jit(rand)

        def as_float64(a):
            return np.asarray(a, np.float64)

        eager_samples = as_float64(rand(key))
        jitted_samples = as_float64(crand(key))

        # The inverse of the lower Cholesky factor L (cov = L @ L.T) maps
        # draws from N(mean, cov) to draws from N(0, I).
        lower_factor = np.linalg.cholesky(cov)
        inv_scale = scipy.linalg.lapack.dtrtri(lower_factor, lower=True)[0]
        for samples in (eager_samples, jitted_samples):
            whitened = np.einsum('nj,ij->ni', samples - mean, inv_scale)
            # Quick-and-dirty multivariate normality check. It tests that a
            # uniform mixture of the whitened marginals follows a standard
            # normal distribution.
            self._CheckKolmogorovSmirnovCDF(whitened.ravel(),
                                            scipy.stats.norm().cdf)
Beispiel #8
0
  def testIssue187(self):
    """Indexing with two integer lists works eagerly and under jit."""
    take_diagonal = lambda a: a[[0, 2, 4], [0, 2, 4]]

    take_diagonal(jnp.ones((5, 5)))  # doesn't crash

    x = np.arange(25).reshape((5, 5))
    ans = api.jit(take_diagonal)(x)
    expected = take_diagonal(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
Beispiel #9
0
  def testGamma(self, a, dtype):
    """Gamma(a) samples pass a KS test, both eagerly and under jit."""
    key = random.PRNGKey(0)

    def rand(key, a):
      return random.gamma(key, a, (10000,), dtype)

    eager = rand(key, a)
    jitted = api.jit(rand)(key, a)

    target_cdf = scipy.stats.gamma(a).cdf
    for samples in (eager, jitted):
      self._CheckKolmogorovSmirnovCDF(samples, target_cdf)
Beispiel #10
0
  def testLogistic(self, dtype):
    """Logistic samples pass a KS test, both eagerly and under jit."""
    key = random.PRNGKey(0)
    draw = lambda k: random.logistic(k, (10000,), dtype)

    sample_sets = [draw(key), api.jit(draw)(key)]
    for samples in sample_sets:
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
Beispiel #11
0
def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
  """Build a jitted NumPy-style binary op with a separate boolean path.

  Args:
    numpy_fn: NumPy function being wrapped. Its ``__name__`` is passed to
      ``_promote_args``, and the function itself is passed to ``_wraps``.
    lax_fn: lax function applied when the promoted dtype is not bool.
    bool_lax_fn: lax function applied when the promoted dtype is bool.
    lax_doc: if true, pass ``lax_fn``'s docstring, minus its first
      paragraph, to ``_wraps`` as ``lax_description``.

  Returns:
    The jitted function, wrapped by ``_wraps(numpy_fn, ...)``.
  """
  def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
  fn = jit(fn, inline=True)
  if lax_doc:
    # `__doc__` is None when docstrings are stripped (python -OO). Treat
    # that as an empty description instead of raising AttributeError.
    lax_docstring = lax_fn.__doc__ or ''
    doc = dedent('\n\n'.join(lax_docstring.split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
  else:
    return _wraps(numpy_fn)(fn)
Beispiel #12
0
  def testPareto(self, b, dtype):
    """Pareto(b) samples pass a KS test, both eagerly and under jit."""
    key = random.PRNGKey(0)
    draw = lambda k, shape_param: random.pareto(k, shape_param, (10000,), dtype)
    jitted_draw = api.jit(draw)

    results = (draw(key, b), jitted_draw(key, b))
    reference_cdf = scipy.stats.pareto(b).cdf
    for samples in results:
      self._CheckKolmogorovSmirnovCDF(samples, reference_cdf)
Beispiel #13
0
  def testT(self, df, dtype):
    """Student-t(df) samples pass a KS test, both eagerly and under jit."""
    key = random.PRNGKey(0)

    def rand(key, df):
      return random.t(key, df, (10000,), dtype)

    eager_samples = rand(key, df)
    jitted_samples = api.jit(rand)(key, df)

    for samples in (eager_samples, jitted_samples):
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
Beispiel #14
0
  def testRngUniform(self, dtype):
    """Uniform samples pass collision and KS checks, eagerly and under jit."""
    key = random.PRNGKey(0)
    draw = lambda k: random.uniform(k, (10000,), dtype)

    eager, jitted = draw(key), api.jit(draw)(key)

    mantissa_bits = jnp.finfo(dtype).nmant
    for samples in (eager, jitted):
      self._CheckCollisions(samples, mantissa_bits)
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
Beispiel #15
0
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
  """Build a jitted NumPy-style binary op that forwards to a single lax op.

  Args:
    numpy_fn: NumPy function being wrapped. Its ``__name__`` is passed to
      the promotion helper, and the function itself is passed to ``_wraps``.
    lax_fn: lax function applied to the promoted arguments.
    promote_to_inexact: if true, promote with ``_promote_args_inexact``
      instead of ``_promote_args``.
    lax_doc: if true, pass ``lax_fn``'s docstring, minus its first
      paragraph, to ``_wraps`` as ``lax_description``.

  Returns:
    The jitted function, wrapped by ``_wraps(numpy_fn, ...)``.
  """
  if promote_to_inexact:
    fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
  else:
    fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
  fn = jit(fn, inline=True)
  if lax_doc:
    # `__doc__` is None when docstrings are stripped (python -OO). Treat
    # that as an empty description instead of raising AttributeError.
    lax_docstring = lax_fn.__doc__ or ''
    doc = dedent('\n\n'.join(lax_docstring.split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
  else:
    return _wraps(numpy_fn)(fn)
Beispiel #16
0
  def test_convert_element_type(self):
    # Regression test for part of https://github.com/google/jax/issues/5982
    # Build an int64 scalar inside the enable_x64() context, then cast it to
    # int32 after leaving that context, both eagerly and under jit.
    with enable_x64():
      x = jnp.int64(1)
    self.assertEqual(x.dtype, jnp.int64)

    to_int32 = lambda v: v.astype(jnp.int32)
    for cast in (to_int32, api.jit(to_int32)):
      self.assertEqual(cast(x).dtype, jnp.int32)
Beispiel #17
0
  def testBernoulli(self, p, dtype):
    """Bernoulli(p) samples pass a chi-squared test, eagerly and under jit."""
    key = random.PRNGKey(0)
    p = np.array(p, dtype=dtype)  # `dtype` only sets the type of `p`

    def rand(key, p):
      return random.bernoulli(key, p, (10000,))

    outcomes = (rand(key, p), api.jit(rand)(key, p))
    for samples in outcomes:
      self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
Beispiel #18
0
 def test_prng_seeds_and_keys(self, seed, type, jit, key):
   """PRNGKey(seed) yields the expected uint32 key, eagerly or under jit."""
   int32_limits = np.iinfo('int32')
   # Without x64, a Python int seed outside int32 cannot be passed through jit.
   if (jit and type is int and not config.x64_enabled
       and not int32_limits.min <= seed <= int32_limits.max):
     self.skipTest("Expected failure: integer out of range for jit.")
   seed = type(seed)
   make_key = api.jit(random.PRNGKey) if jit else random.PRNGKey
   actual = make_key(seed)
   self.assertArraysEqual(actual, jnp.array(key, dtype=jnp.uint32))
Beispiel #19
0
  def testNormalComplex(self, dtype):
    """Complex normal samples have N(0, 1/2) real/imag parts and keep dtype."""
    key = random.PRNGKey(0)
    draw = lambda k: random.normal(k, (10000,), dtype)
    eager, jitted = draw(key), api.jit(draw)(key)

    # Each component has variance 1/2, i.e. standard deviation 1/sqrt(2).
    component_cdf = scipy.stats.norm(scale=1/np.sqrt(2)).cdf
    for samples in (eager, jitted):
      for component in (jnp.real(samples), jnp.imag(samples)):
        self._CheckKolmogorovSmirnovCDF(component, component_cdf)
      self.assertEqual(dtype, samples.dtype)
Beispiel #20
0
  def testUnpacking(self):
    """Tuple-unpacking an array gives the same sum with and without jit."""

    def foo(x):
      first, second, third = x
      return first + second + third

    eager_sum = foo(np.arange(3))
    jitted_sum = api.jit(foo)(np.arange(3))

    self.assertAllClose(eager_sum, jitted_sum)
Beispiel #21
0
  def testBeta(self, a, b, dtype):
    """Beta(a, b) samples pass a KS test, eagerly and under jit (x64 only)."""
    if not config.x64_enabled:
      raise SkipTest("skip test except on X64")
    key = random.PRNGKey(0)

    def rand(key, a, b):
      return random.beta(key, a, b, (10000,), dtype)

    eager = rand(key, a, b)
    jitted = api.jit(rand)(key, a, b)

    beta_cdf = scipy.stats.beta(a, b).cdf
    for samples in (eager, jitted):
      self._CheckKolmogorovSmirnovCDF(samples, beta_cdf)
  def test_impl(self, ad="simple"):
    """call_tf(tf.math.sin) matches jnp.sin, both eagerly and under jit."""
    call_tf = CALL_TF_IMPLEMENTATIONS[ad]

    def f_jax(x):
      return jnp.sin(x)

    def f_outside(x):
      # NOTE(review): result_shape=x presumably declares that the result
      # has the same shape/dtype as x. Confirm against call_tf.
      return call_tf(tf.math.sin, x, result_shape=x)

    eager_res = f_outside(3.)
    self.assertAllClose(f_jax(3.), eager_res)
    jitted_res = api.jit(f_outside)(3.)
    self.assertAllClose(f_jax(3.), jitted_res)
Beispiel #23
0
  def testDirichlet(self, alpha, dtype):
    """Dirichlet samples sum to 1 and have Beta marginals, eagerly and jitted."""
    key = random.PRNGKey(0)
    draw = lambda k, conc: random.dirichlet(k, conc, (10000,), dtype)
    jitted_draw = api.jit(draw)

    eager_samples = draw(key, alpha)
    jitted_samples = jitted_draw(key, alpha)

    alpha_sum = sum(alpha)
    for samples in (eager_samples, jitted_samples):
      # Every draw lies on the probability simplex.
      self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype))
      # Component i is marginally Beta(alpha_i, sum(alpha) - alpha_i).
      for i, a in enumerate(alpha):
        marginal_cdf = scipy.stats.beta(a, alpha_sum - a).cdf
        self._CheckKolmogorovSmirnovCDF(samples[..., i], marginal_cdf)
Beispiel #24
0
  def testPermutationInteger(self):
    """permutation(key, n) shuffles range(n) identically eagerly and jitted."""
    key = random.PRNGKey(0)
    n = 100
    rand = lambda key: random.permutation(key, n)

    perm1 = rand(key)
    perm2 = api.jit(rand)(key)

    identity = np.arange(n)
    self.assertAllClose(perm1, perm2)
    self.assertEqual(perm1.dtype, perm2.dtype)
    self.assertFalse(np.all(perm1 == identity))  # seems unlikely!
    self.assertAllClose(np.sort(perm1), identity, check_dtypes=False)
Beispiel #25
0
  def testPoisson(self, lam, dtype):
    """Poisson(lam) samples pass chi-squared and moment checks, eager and jitted."""
    key = random.PRNGKey(0)
    rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
    crand = api.jit(rand)

    sample_sets = [rand(key, lam), crand(key, lam)]

    for samples in sample_sets:
      self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
      # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
      # based on the central limit theorem).
      # A Poisson(lam) variable has mean and variance both equal to lam.
      self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
      self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
Beispiel #26
0
  def testRngRandint(self, dtype):
    """randint samples fall in [lo, hi), both eagerly and under jit."""
    lo, hi = 5, 10

    key = random.PRNGKey(0)
    draw = lambda k: random.randint(k, (10000,), lo, hi, dtype)

    for samples in (draw(key), api.jit(draw)(key)):
      # Half-open interval: lo is inclusive, hi is exclusive.
      self.assertTrue(np.all(lo <= samples))
      self.assertTrue(np.all(samples < hi))
Beispiel #27
0
  def testShuffle(self, dtype):
    """random.shuffle warns, and permutes identically eagerly and under jit."""
    key = random.PRNGKey(0)
    x = np.arange(100).astype(dtype)
    rand = lambda key: random.shuffle(key, x)
    crand = api.jit(rand)

    # Each call, eager or jitted, is expected to emit a FutureWarning.
    shuffled = []
    for fn in (rand, crand):
      with self.assertWarns(FutureWarning):
        shuffled.append(fn(key))
    perm1, perm2 = shuffled

    self.assertAllClose(perm1, perm2)
    self.assertFalse(np.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
Beispiel #28
0
  def testTruncatedNormal(self, dtype):
    """truncated_normal samples stay in (-0.3, 0.3) and follow truncnorm.

    Both the eager and the jitted samples are checked against the bounds and
    the reference distribution.
    """
    lower, upper = -0.3, 0.3
    key = random.PRNGKey(0)
    rand = lambda key: random.truncated_normal(key, lower, upper, (10000,), dtype)
    crand = api.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    for samples in [uncompiled_samples, compiled_samples]:
      # Bounds are checked for the jitted samples too, not only the eager ones.
      self.assertGreater(np.min(samples), lower)
      self.assertLess(np.max(samples), upper)
      self._CheckKolmogorovSmirnovCDF(
          samples, scipy.stats.truncnorm(lower, upper).cdf)
Beispiel #29
0
  def testPermutationArray(self, dtype, shape):
    """permutation shuffles an array identically eagerly and under jit."""
    key = random.PRNGKey(0)

    def make_input():
      return jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)

    x = make_input()
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2)
    if x.shape[0] > 1:
      self.assertFalse(np.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
    # The input must be left unchanged by permuting it.
    self.assertArraysAllClose(x, make_input())
Beispiel #30
0
  def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
                       rtol=None, atol=None, check_cache_misses=True):
    """Helper method for running JAX compilation and allclose assertions.

    Runs `fun` eagerly (op-by-op) and through `api.jit`, and asserts that the
    results agree. It also asserts that the jitted function is not re-traced
    on a repeated call.

    Args:
      fun: the function under test.
      args_maker: zero-argument callable returning the positional arguments
        for `fun` (they are unpacked with `*`). It is called twice.
      check_dtypes: forwarded to `assertAllClose`.
      rtol: relative tolerance forwarded to `assertAllClose`.
      atol: absolute tolerance forwarded to `assertAllClose`.
      check_cache_misses: if true, assert that a second eager call of `fun`
        adds no misses to `dispatch.xla_primitive_callable`'s cache.
    """
    args = args_maker()

    # Python-level body of the jitted function. It is executed only while
    # jit traces. `python_should_be_executing` is read from the enclosing
    # scope at call time, so the reassignments below take effect here.
    def wrapped_fun(*args):
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_should_be_executing = True
    python_ans = fun(*args)

    # np.shape on the raw outputs must agree with np.shape on the same
    # outputs after conversion with np.asarray.
    python_shapes = tree_map(lambda x: np.shape(x), python_ans)
    np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
    self.assertEqual(python_shapes, np_shapes)

    # A second eager call with the same arguments should reuse the cached
    # per-primitive executables, so the miss count must not change.
    cache_misses = dispatch.xla_primitive_callable.cache_info().misses
    python_ans = fun(*args)
    if check_cache_misses:
      self.assertEqual(
          cache_misses, dispatch.xla_primitive_callable.cache_info().misses,
          "Compilation detected during second call of {} in op-by-op "
          "mode.".format(fun))

    # The first jitted call traces `wrapped_fun`, so Python execution is allowed.
    cfun = api.jit(wrapped_fun)
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    # The second call must hit the compiled cache. If it retraced,
    # wrapped_fun's assertTrue would fail.
    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)

    # Fresh arguments: the existing compilation must be reused without
    # retracing. This presumably relies on args_maker returning the same
    # shapes/dtypes each time (TODO confirm).
    args = args_maker()

    python_should_be_executing = True
    python_ans = fun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)