Example #1
0
  def testForiLoopTupleState(self):
    """fori_loop threading an (array, running total) tuple through its body."""

    def sum_first_n(arr, num):
      """Sums arr[:num] by carrying the array and the running total as state."""

      def body_fun(i, state):
        array, running = state
        elem = lax.dynamic_index_in_dim(array, i, 0, False)
        return (array, lax.add(running, elem))

      upper = lax.min(arr.shape[0], num)
      _, result = lax.fori_loop(0, upper, body_fun, (arr, 0.))
      return result

    jitted_sum = api.jit(sum_first_n)
    x = npr.RandomState(0).randn(10)

    for num in [0, 5, 10, 15]:
      expected = onp.sum(x[:num])
      self.assertAllClose(sum_first_n(x, num), expected, check_dtypes=False)
      # The jitted version is called twice with the same inputs (presumably
      # so the second call goes through the already-compiled path).
      for _ in range(2):
        self.assertAllClose(jitted_sum(x, num), expected, check_dtypes=False)
Example #2
0
  def test_custom_root_vector_with_solve_closure(self):
    """custom_root on a linear system whose solver closes over the answer."""

    def vector_solve(f, y):
      jac = api.jacobian(f)(y)
      return np.linalg.solve(jac, y)

    def linear_solve(a, b):
      residual = lambda y: high_precision_dot(a, y) - b
      precomputed = np.linalg.solve(a, b)

      def oracle(func, x0):
        # Ignores its inputs and returns the closed-over solution.
        return precomputed

      return lax.custom_root(residual, np.zeros_like(b), oracle, vector_solve)

    rng = onp.random.RandomState(0)
    a = rng.randn(2, 2)
    b = rng.randn(2)
    jtu.check_grads(linear_solve, (a, b), order=2)

    expected = np.linalg.solve(a, b)
    actual = api.jit(linear_solve)(a, b)
    self.assertAllClose(expected, actual, check_dtypes=True)
Example #3
0
 def test_jit_interleaving(self):
   """Several independent jit computations issuing taps must not interfere.

   Ten separately-jitted calls each perform ten `id_tap`s on a list of
   `nr_arrays` values; every tap must be delivered exactly once and every
   call must return the expected value.
   """
   count = 0  # Number of tap_func invocations observed by the receiver.
   nr_arrays = 5

   def tap_func(arg, **kwargs):
     nonlocal count
     assert len(arg) == nr_arrays
     count += 1

   # This is the function that we'll run multiple times. The loop bound is
   # called `nr_taps` (not `count`) so it does not shadow the tap counter.
   def func(x, nr_taps):
     for i in range(nr_taps):
       # id_tap returns its argument; keep the last element, x + (nr_arrays - 1).
       x = hcb.id_tap(tap_func, [x + j for j in range(nr_arrays)], i=i)[-1]
     return x

   with hcb.outfeed_receiver(receiver_name=self._testMethodName):
     x = jnp.array(1, dtype=np.int32)
     res = 0
     for _ in range(10):
       # No dependencies between the jit invocations
       res += api.jit(lambda x: func(x, 10))(x)
   self.assertEqual(100, count)
   # Each call returns 1 + 10 * (nr_arrays - 1) = 41; ten calls sum to 410.
   self.assertEqual(10 * (1 + 10 * (nr_arrays - 1)), res)
Example #4
0
    def testPoisson(self, lam, dtype):
        """Poisson samples (eager and jitted) match the target distribution."""
        key = random.PRNGKey(0)

        def rand(k, rate):
            return random.poisson(k, rate, (10000, ), dtype)

        samples_by_mode = [rand(key, lam), api.jit(rand)(key, lam)]

        for samples in samples_by_mode:
            self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
            # TODO(shoyer): determine error bounds for moments more rigorously
            # (e.g., based on the central limit theorem).
            # A Poisson distribution's mean and variance both equal lam.
            for statistic, tolerance in ((samples.mean(), 0.01),
                                         (samples.var(), 0.03)):
                self.assertAllClose(statistic, lam, rtol=tolerance,
                                    check_dtypes=False)
Example #5
0
  def testWeibullSample(self, concentration, scale):
    """Weibull samples match scipy's weibull_min moments and CDF."""
    num_samples = 10**5
    rng = random.PRNGKey(0)

    def rand(key):
      return random.weibull_min(key, scale, concentration, (num_samples,))

    params = dict(c=concentration, scale=scale)
    loc = scipy.stats.weibull_min.mean(**params)
    std = scipy.stats.weibull_min.std(**params)
    target_cdf = scipy.stats.weibull_min(**params).cdf

    for samples in (rand(rng), api.jit(rand)(rng)):
      # Check first and second moments, then the whole distribution.
      self.assertEqual((num_samples,), samples.shape)
      self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
      self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
      self._CheckKolmogorovSmirnovCDF(samples, target_cdf)
Example #6
0
  def test_grad_of_jit_compilation_caching(self):
    """grad of one jitted function at two inputs logs exactly two DEBUG records.

    NOTE(review): presumably the two records are compilation messages, so the
    count checks that the second grad call reuses cached compilations — confirm.
    """
    if not hasattr(self, "assertLogs"):
      raise unittest.SkipTest("test requires assertLogs (python 3)")

    lax.add(1, 2)  # make sure some initial warnings are already printed

    sin = api.jit(np.sin)

    prev_level = logging.get_verbosity()
    try:
      # Temporarily lower the threshold so DEBUG records are emitted/captured.
      logging.set_verbosity('DEBUG')
      with self.assertLogs(level=logging.DEBUG) as l:
        ans1 = api.grad(sin)(2.)
        ans2 = api.grad(sin)(3.)
    finally:
      # Restore the previous verbosity even if the block above raised.
      logging.set_verbosity(prev_level)
    self.assertLen(l.output, 2)

    # d/dx sin(x) = cos(x).
    self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
    self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)
Example #7
0
    def test_jit_unknown_tap(self):
        """A failing tap consumer raises, yet other taps and the result survive.

        The test expects `hcb.TapFunctionException` to escape the receiver
        context, while the jitted computation still returns 3 and both
        `id_print` outputs still reach `testing_stream`.
        """
        # Simulate an unknown tap function
        def func(x):
            # x1 = x + 1, printed with what="x1".
            x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
            # Tap through the unknown consumer; this is the tap expected to fail.
            x2 = hcb.id_tap(hcb._unknown_testing_consumer, x1 + 1, what="err")
            # x3 = x2 + 1, printed with what="x3".
            x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
            return x3

        with self.assertRaises(hcb.TapFunctionException):
            with hcb.outfeed_receiver(receiver_name=self._testMethodName):
                # `res` is used after the with-statements, so the error is
                # expected to surface only when the receiver context exits.
                res = api.jit(func)(0)
        # Even though the receiver thread raised, the main thread should still
        # return 3.
        self.assertEqual(3, res)
        # We should have received all others
        assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
        # NOTE(review): skipped if an assertion above fails, leaving stale
        # output in the shared stream; consider resetting in tearDown.
        testing_stream.reset()
Example #8
0
    def test_loop_1(self):
        """One loop with one state var, with transforms."""
        def f_op(inc):
            with loops.Scope() as s:
                s.out = 10.
                for _ in s.range(5):
                    s.out += inc
                return s.out

        def f_expected(inc):
            return 10 + 5 * inc

        self.assertAllClose(f_expected(2.), f_op(2.))
        self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
        # d/d(inc) of 10 + 5 * inc is 5: check grad alone and grad of jit
        # (the second check previously repeated the first verbatim).
        self.assertAllClose(5., api.grad(f_op)(2.))
        self.assertAllClose(5., api.grad(api.jit(f_op))(2.))
        inc_batch = np.arange(5, dtype=jnp.float_)
        self.assertAllClose(
            jnp.array([f_expected(inc) for inc in inc_batch],
                      dtype=jnp.float_),
            api.vmap(f_op)(inc_batch))
Example #9
0
  def test_while(self):
    """loops.Scope.while_range agrees with a plain Python while loop."""

    def f_op(init):
      with loops.Scope() as s:
        s.out = init
        for _ in s.while_range(lambda: s.out < 5.):
          s.out += 2.
        s.out += 1.
        return s.out

    def f_expected(init):
      value = init
      while value < 5.:
        value += 2.
      return value + 1.

    # Eager at 2., jitted at 2., eager at 1.
    for arg, fn in ((2., f_op), (2., api.jit(f_op)), (1., f_op)):
      self.assertAllClose(f_expected(arg), fn(arg), check_dtypes=True)

    init_batch = np.array([1., 2., 3.])
    expected_batch = np.array([f_expected(v) for v in init_batch])
    self.assertAllClose(expected_batch, api.vmap(f_op)(init_batch),
                        check_dtypes=True)
Example #10
0
  def test_root_scalar(self):
    """lax.root with a bisection solver finds sqrt(x**3) and differentiates."""

    def scalar_solve(f, y):
      return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
      """Bisection on [low, high]; the initial guess is ignored."""
      del x0  # unused

      def not_converged(bounds):
        lo, hi = bounds
        return hi - lo > tolerance

      def halve(bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Assumes func is increasing on the bracket: a positive value at
        # `mid` means the root lies in the lower half.
        go_left = func(mid) > 0
        return (np.where(go_left, lo, mid), np.where(go_left, mid, hi))

      lower_bound, _ = lax.while_loop(not_converged, halve, (low, high))
      return lower_bound

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      residual = lambda y: y ** 2 - x ** 3
      return lax.root(residual, 0.0, binary_search, tangent_solve)

    value, slope = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
    self.assertAllClose(slope, api.grad(pow)(5.0, 1.5), check_dtypes=False)

    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    # TODO(shoyer): reenable when batching works
    # inputs = np.array([4.0, 5.0])
    # results = api.vmap(sqrt_cubed)(inputs)
    # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    jitted_value = api.jit(sqrt_cubed)(5.0)
    self.assertAllClose(jitted_value, 5.0 ** 1.5, check_dtypes=False)
Example #11
0
  def testCategorical(self, p, axis, dtype, sample_shape):
    """Categorical samples (eager and jitted) follow the given probabilities."""
    key = random.PRNGKey(0)
    p = onp.array(p, dtype=dtype)
    logits = onp.log(p) - 42 # test unnormalized
    shape = sample_shape + tuple(onp.delete(logits.shape, axis))
    # Take the logits as an argument so the jitted version actually traces
    # them instead of baking them in as a closed-over constant.
    rand = lambda key, logits: random.categorical(key, logits, shape=shape,
                                                  axis=axis)
    crand = api.jit(rand)

    uncompiled_samples = rand(key, logits)
    compiled_samples = crand(key, logits)

    # Non-negative category axis, computed once (without mutating the `axis`
    # that `rand` closes over).
    cat_axis = axis + logits.ndim if axis < 0 else axis

    for samples in [uncompiled_samples, compiled_samples]:
      assert samples.shape == shape

      if p.ndim > 1:
        # Batched probabilities: move the category axis last so that row i
        # holds the category probabilities for batch element i, matching
        # samples[:, i]. Assumes a 2-D `p` and a 1-D `sample_shape`.
        per_batch_p = onp.moveaxis(p, cat_axis, -1)
        for cat_index, p_ in enumerate(per_batch_p):
          self._CheckChiSquared(samples[:, cat_index], pmf=lambda x: p_[x])
      else:
        self._CheckChiSquared(samples, pmf=lambda x: p[x])
Example #12
0
    def testWhileWithTuple(self):
        """while_loop over a (position, counter) tuple, eager and jitted."""
        limit = 10

        def loop_cond(state):
            return lax.lt(state[0], limit)

        def loop_body(state):
            pos, count = state
            return (lax.add(pos, 1), lax.add(count, 1))

        def loop(init):
            # Only the counter, i.e. the number of iterations, is returned.
            return lax.while_loop(loop_cond, loop_body, (init, 0))[1]

        cloop = api.jit(loop)

        self.assertEqual(loop(2), limit - 2)
        # Same input twice (presumably to hit the compiled path again).
        for _ in range(2):
            self.assertEqual(cloop(2), limit - 2)
        self.assertEqual(cloop(3), limit - 3)
Example #13
0
    def testRadamacher(self):
        """Rademacher samples are exactly -1 or +1, each about half the time."""
        rng = random.PRNGKey(0)
        num_samples = 10**5

        rand = lambda x: random.rademacher(x, (num_samples, ))
        crand = api.jit(rand)

        uncompiled_samples = rand(rng)
        compiled_samples = crand(rng)

        for samples in [uncompiled_samples, compiled_samples]:
            unique_values, counts = np.unique(samples, return_counts=True)
            # Not just two distinct values: they must be the two Rademacher
            # outcomes (np.unique returns them sorted).
            assert [int(v) for v in unique_values] == [-1, 1]
            assert len(counts) == 2

            for count in counts:
                self.assertAllClose(count / num_samples,
                                    0.5,
                                    rtol=1e-02,
                                    atol=1e-02)
Example #14
0
    def testMultivariateNormal(self, mean, cov, dtype):
        """Multivariate normal samples standardize to a standard normal."""
        key = random.PRNGKey(0)

        def rand(k, mu, sigma):
            return random.multivariate_normal(k, mu, sigma, (1000, ), dtype)

        crand = api.jit(rand)

        if ((hasattr(cov, "shape") and cov.ndim > 2)
                or (hasattr(mean, "shape") and mean.ndim > 1)):
            # Unsupported ranks must be rejected, both eagerly and under jit.
            for fn in (rand, crand):
                self.assertRaises(ValueError, fn, key, mean, cov)
            return

        uncompiled_samples = rand(key, mean, cov)
        compiled_samples = crand(key, mean, cov)
        if hasattr(cov, "shape") and cov.ndim == 2:
            # Full covariance: whiten with the inverse lower Cholesky factor.
            chol = onp.linalg.cholesky(cov)
            inv_scale = scipy.linalg.lapack.dtrtri(chol, lower=True)[0]
            rescale = lambda x: onp.tensordot(x, inv_scale, axes=(-1, 1))
        else:
            # Covariance without a 2-D shape: divide by its square root.
            rescale = lambda x: x / np.sqrt(cov)
        for samples in (uncompiled_samples, compiled_samples):
            standardized = rescale(samples - mean).reshape(-1)
            self._CheckKolmogorovSmirnovCDF(standardized,
                                            scipy.stats.norm().cdf)
Example #15
0
    def testPoisson(self, lam, dtype):
        """Poisson samples (eager and jitted) match the target distribution."""
        if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
            raise SkipTest(
                "random.poisson() not supported on TPU for 16-bit types.")
        key = random.PRNGKey(0)
        sample = lambda k, rate: random.poisson(k, rate, (10000, ), dtype)
        compiled_sample = api.jit(sample)

        expected_pmf = scipy.stats.poisson(lam).pmf
        for samples in (sample(key, lam), compiled_sample(key, lam)):
            self._CheckChiSquared(samples, expected_pmf)
            # TODO(shoyer): determine error bounds for moments more rigorously
            # (e.g., based on the central limit theorem).
            # A Poisson distribution's mean and variance both equal lam.
            self.assertAllClose(samples.mean(), lam, rtol=0.01,
                                check_dtypes=False)
            self.assertAllClose(samples.var(), lam, rtol=0.03,
                                check_dtypes=False)
Example #16
0
  def testLoopWithConjunctionCondition(self):
    """while_loop whose condition is the AND of two comparisons."""

    def sum_first_n(arr, num):
      """Sums arr[:num], stopping early at the end of the array."""

      def cond_fun(state):
        arr, num, i, _ = state
        in_requested = lax.lt(i, num)
        in_bounds = lax.lt(i, arr.shape[0])
        return lax.bitwise_and(in_requested, in_bounds)

      def body_fun(state):
        arr, num, i, total = state
        elem = lax.dynamic_index_in_dim(arr, i, 0, False)
        return (arr, num, lax.add(i, 1), lax.add(total, elem))

      _, _, _, total = lax.while_loop(cond_fun, body_fun, (arr, num, 0, 0.))
      return total

    cfun = api.jit(sum_first_n)
    x = npr.RandomState(0).randn(10)

    for num in [0, 5, 10, 15]:
      want = onp.sum(x[:num])
      self.assertAllClose(sum_first_n(x, num), want, check_dtypes=False)
      # Two jitted calls with identical inputs.
      for _ in range(2):
        self.assertAllClose(cfun(x, num), want, check_dtypes=False)
Example #17
0
  def DISABLED_testOnesBroadcastingConstantHandler(self):
    """Checks that a stride-zero ndarray constant is lowered via Broadcast.

    NOTE(review): presumably disabled through the DISABLED_ prefix (unittest
    only collects methods whose names start with "test"). Per the TODO it
    needs updating for jax3; it relies on the internal `x._node.c` attribute.
    """
    # TODO(mattjj): update this test for jax3

    def fun(x):
      ones = lnp.ones((3, 4))
      # The test expects lnp.ones to produce a stride-zero numpy array.
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      # list.append returns None, so `or` always falls through to the real
      # Broadcast; the append only records that the wrapper ran.
      # NOTE(review): the patch is never undone on `c`.
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    # (4,) ones + (3, 4) ones broadcasts to a (3, 4) array of 2s.
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)
Example #18
0
  def test_custom_linear_solve(self, symmetric):
    """custom_linear_solve with a jacobian-based solver, under grad and jit."""

    def explicit_jacobian_solve(matvec, b):
      matrix = api.jacobian(matvec)(b)
      return lax.stop_gradient(np.linalg.solve(matrix, b))

    def matrix_free_solve(matvec, b):
      # The same solver serves as both the forward and the transpose solve.
      return lax.custom_linear_solve(
          matvec, b, explicit_jacobian_solve, explicit_jacobian_solve,
          symmetric=symmetric)

    def linear_solve(a, b):
      return matrix_free_solve(partial(high_precision_dot, a), b)

    rng = onp.random.RandomState(0)
    a = rng.randn(3, 3)
    b = rng.randn(3)
    if symmetric:
      a = a + a.T  # symmetrize the operator
    jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)

    self.assertAllClose(np.linalg.solve(a, b), api.jit(linear_solve)(a, b),
                        check_dtypes=True)
Example #19
0
  def testMultivariateNormal(self, dim, dtype, method):
    """Samples whitened by the inverse Cholesky factor look standard normal.

    NOTE(review): `dtype` is not forwarded to random.multivariate_normal, so
    the samples use that function's default dtype — confirm this is intended.
    """
    r = np.random.RandomState(dim)
    mean = r.randn(dim)
    cov_factor = r.randn(dim, dim)
    cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

    key = random.PRNGKey(0)
    rand = partial(random.multivariate_normal, mean=mean, cov=cov,
                   shape=(10000,), method=method)
    crand = api.jit(rand)

    as_f64 = lambda s: np.asarray(s, np.float64)
    uncompiled_samples = as_f64(rand(key))
    compiled_samples = as_f64(crand(key))

    chol = np.linalg.cholesky(cov)
    inv_scale = scipy.linalg.lapack.dtrtri(chol, lower=True)[0]
    for samples in (uncompiled_samples, compiled_samples):
      # With cov = L L^T, z = L^{-1} (x - mean) has identity covariance, so
      # each component of z should be standard normal. Pooling all components
      # into one KS test is a quick-and-dirty multivariate normality check.
      whitened = np.einsum('nj,ij->ni', samples - mean, inv_scale)
      self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
Example #20
0
  def testDoublesidedMaxwellSample(self, loc, scale):
    """Double-sided Maxwell samples match the expected moments and CDF."""
    num_samples = 10**5
    rng = random.PRNGKey(0)

    # Use the `key` argument rather than the closed-over `rng`, so the jitted
    # version really receives the key as an input instead of a constant.
    rand = lambda key: random.double_sided_maxwell(
        key, loc, scale, (num_samples,))
    crand = api.jit(rand)

    mean = loc
    std = np.sqrt(3.) * scale

    uncompiled_samples = rand(rng)
    compiled_samples = crand(rng)

    # Compute the double sided maxwell CDF through the one sided maxwell CDF:
    # P(DSM <= x) = P(loc + scale * rademacher_sample * one_sided_sample <= x)
    #   = P(rademacher_sample * one_sided_sample <= (x - loc) / scale)
    #   = 1/2 P(one_sided_sample <= (x - loc) / scale)
    #     + 1/2 P(-one_sided_sample <= (x - loc) / scale)
    #   = 1/2 P(one_sided_sample <= (x - loc) / scale)
    #     + 1/2 P(one_sided_sample >= -(x - loc) / scale)
    #   = 1/2 CDF_one_maxwell((x - loc) / scale)
    #     + 1/2 (1 - CDF_one_maxwell(-(x - loc) / scale))
    def double_sided_maxwell_cdf(x, loc, scale):
      pos = scipy.stats.maxwell().cdf((x - loc) / scale)
      neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale))
      return (pos + neg) / 2

    for samples in [uncompiled_samples, compiled_samples]:
      # Check first and second moments.
      self.assertEqual((num_samples,), samples.shape)
      self.assertAllClose(np.mean(samples), mean, atol=0., rtol=0.1)
      self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)

      self._CheckKolmogorovSmirnovCDF(
          samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
Example #21
0
  def testCategorical(self, p, axis, dtype, sample_shape):
    """Categorical samples (eager and jitted) follow the given probabilities."""
    key = random.PRNGKey(0)
    p = np.array(p, dtype=dtype)
    logits = np.log(p) - 42 # test unnormalized
    out_shape = tuple(np.delete(logits.shape, axis))
    shape = sample_shape + out_shape
    rand = partial(random.categorical, shape=shape, axis=axis)
    crand = api.jit(rand)

    uncompiled_samples = rand(key, logits)
    compiled_samples = crand(key, logits)

    if axis < 0:
      axis += len(logits.shape)

    for samples in [uncompiled_samples, compiled_samples]:
      assert samples.shape == shape
      # Flatten all sample dimensions into one, whatever `sample_shape` is
      # (previously this hard-coded exactly 10000 draws).
      samples = jnp.reshape(samples, (-1,) + out_shape)
      if len(p.shape[:-1]) > 0:
        # Batched p: make rows index batch elements so that row i pairs with
        # samples column i (assumes 2-D p).
        ps = np.transpose(p, (1, 0)) if axis == 0 else p
        for cat_samples, cat_p in zip(samples.transpose(), ps):
          self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
      else:
        self._CheckChiSquared(samples, pmf=lambda x: p[x])
Example #22
0
  def test_custom_linear_solve_cholesky(self):
    """custom_linear_solve driven by a Cholesky factorization (symmetric)."""

    def positive_definite_solve(a, b):
      factors = jsp.linalg.cho_factor(a)

      def solve(matvec, x):
        del matvec  # unused: the factorization already encodes the operator
        return jsp.linalg.cho_solve(factors, x)

      return lax.custom_linear_solve(
          partial(np.dot, a), b, solve, symmetric=True)

    rng = onp.random.RandomState(0)
    a = rng.randn(2, 2)
    b = rng.randn(2)
    spd = np.dot(a, a.T)  # a @ a.T is symmetric positive (semi-)definite

    expected = np.linalg.solve(spd, b)
    self.assertAllClose(expected, positive_definite_solve(spd, b),
                        check_dtypes=True)
    self.assertAllClose(expected, api.jit(positive_definite_solve)(spd, b),
                        check_dtypes=True)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(lambda x, y: positive_definite_solve(np.dot(x, x.T), y),
                    (a, b), order=2)
Example #23
0
  def test_custom_linear_solve_lu(self):
    """custom_linear_solve with LU-based forward and transpose solvers."""

    def linear_solve(a, b):
      lu_a = jsp.linalg.lu_factor(a)
      lu_at = jsp.linalg.lu_factor(a.T)

      def solve(matvec, x):
        del matvec  # unused: the LU factors of `a` stand in for the operator
        return jsp.linalg.lu_solve(lu_a, x)

      def transpose_solve(vecmat, x):
        del vecmat  # unused: the LU factors of `a.T` stand in for it
        return jsp.linalg.lu_solve(lu_at, x)

      return lax.custom_linear_solve(
          partial(np.dot, a), b, solve, transpose_solve)

    rng = onp.random.RandomState(0)
    a = rng.randn(3, 3)
    b = rng.randn(3)

    self.assertAllClose(np.linalg.solve(a, b), linear_solve(a, b),
                        check_dtypes=True)

    for fn in (linear_solve,
               # regression test for https://github.com/google/jax/issues/1536
               api.jit(linear_solve)):
      jtu.check_grads(fn, (a, b), order=2)
Example #24
0
        def mc_sampling(count=10):
            """Averages gradient-descent MSE test predictions over random inits.

            For each of `count` parameter draws, builds the empirical 'ntk'
            kernels, runs the predictor at 1.0e8 and accumulates the test-set
            prediction; returns the mean over draws.
            """
            key = random.PRNGKey(100)
            init_fn, f, _ = _build_network(train_shape[1:], network,
                                           out_logits)
            ntk_fn = empirical.empirical_kernel_fn(f)
            kernel_fn = jit(
                lambda x1, x2, params: ntk_fn(x1, x2, params, 'ntk'))

            total = 0.
            for _ in range(count):
                key, split = random.split(key)
                _, params = init_fn(split, train_shape)

                predictor = predict.gradient_descent_mse(
                    kernel_fn(x_train, None, params), y_train,
                    kernel_fn(x_test, x_train, params))

                _, fx_pred_test = predictor(1.0e8, f(params, x_train),
                                            f(params, x_test))
                total += fx_pred_test
            return total / count
Example #25
0
        def f_pmapped(x, *args, **kwargs):
            """Applies `f` under jit/pmap, treating non-ndarray args as static.

            `np.ndarray` positional arguments are broadcast and passed as traced
            inputs; all other positional arguments and every keyword argument
            are closed over and folded into the compilation cache key.
            """
            args_np, args_np_idxs = [], []
            args_other = {}

            # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
            # https://github.com/google/jax/issues/912
            # Filter out `np.ndarray`s from other arguments.
            for i, arg in enumerate(args):
                if _is_np_ndarray(arg):
                    args_np.append(arg)
                    args_np_idxs.append(i)
                else:
                    args_other[i] = arg

            # Check cache before jitting. The key includes the ndarray
            # positions: the cached `_f` closes over `args_np_idxs`, so calls
            # that differ only in how many/where ndarrays are passed must not
            # share it (otherwise `f` gets called with the wrong arguments).
            # Keyword arguments are sorted so call-site order does not cause
            # spurious cache misses.
            _key = key + (tuple(args_np_idxs),
                          tuple(args_other.items()),
                          tuple(sorted(kwargs.items())))
            if _key in cache:
                _f = cache[_key]
            else:
                # Define a `np.ndarray`-only function as a closure over other arguments.
                def _f(_x, *_args_np):
                    # Merge args.
                    _args_np = {
                        i: _arg_np
                        for i, _arg_np in zip(args_np_idxs, _args_np)
                    }
                    _args = _merge_dicts(_args_np, args_other)
                    _args = tuple(v for k, v in sorted(_args.items()))
                    return f(_x, *_args, **kwargs)

                _f = jit(_f) if device_count == 0 else pmap(_f)
                cache[_key] = _f

            # Broadcast `np.ndarray` arguments and apply the new function to them.
            args_np = tree_map(broadcast, args_np)
            return _f(x, *args_np)
Example #26
0
 def loop_body(state):
     # Flag set whenever the body runs (presumably inspected by the caller).
     effect[0] = True
     pos, count = state

     # `inc` comes from the enclosing scope; `delta` is the jitted copy of it.
     def step(p, delta):
         return (lax.add(p, 1), lax.add(count, delta))

     return api.jit(step)(pos, inc)
Example #27
0
 def testPermutationErrors(self):
     """random.permutation rejects a float length and a traced length."""
     key = random.PRNGKey(0)
     # A non-integer length is a type error.
     with self.assertRaises(TypeError):
         random.permutation(key, 10.)
     # Under jit the length is abstract and cannot be made concrete.
     jitted_permutation = api.jit(random.permutation)
     with self.assertRaises(core.ConcretizationTypeError):
         jitted_permutation(key, 10)
Example #28
0
 def test_jit_error_no_consumer(self):
     """Running jitted id_print without an active outfeed receiver fails."""
     expected_msg = "outfeed_receiver is not started"
     printer = api.jit(lambda x: hcb.id_print(x))
     with self.assertRaisesRegex(ValueError, expected_msg):
         printer(0)
Example #29
0
 def test_jit_several_together(self):
     """One jitted id_print of a tuple holding several arrays."""
     arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))

     def print_all(x, y):
         return hcb.id_print((x, y, x * 2.))

     with hcb.outfeed_receiver(receiver_name=self._testMethodName):
         ones = jnp.ones(100, dtype=jnp.int32)
         api.jit(print_all)(arg, ones)
Example #30
0
 def test_jit_large(self):
     """Jitted id_print of a 10000-element int32 array."""
     flat = jnp.arange(10000, dtype=jnp.int32)
     arg = flat.reshape((10, 10, 5, -1))  # -> shape (10, 10, 5, 20)
     jitted_print = api.jit(hcb.id_print)
     with hcb.outfeed_receiver(receiver_name=self._testMethodName):
         jitted_print(arg)