示例#1
0
    def test_destructure(self):
        """Run `self.check` on a function that unpacks a pair and keeps item 0."""
        def first_of_pair(key):
            # Unpack both elements so a non-pair argument still fails here.
            first, _second = key
            return first

        self.check(first_of_pair, ['2'], '', {}, [(2, )], ['int_'],
                   jtu.rand_int(self.rng(), 0, 10))
示例#2
0
  def testSegmentReduce(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
    """Compare a segment reduction against a straightforward NumPy loop."""
    data_rng = jtu.rand_default(self.rng())
    # Ids are drawn from [-2, 3), so some of them fall outside the valid range.
    segment_rng = jtu.rand_int(self.rng(), low=-2, high=3)

    def args_maker():
      return [data_rng(shape, dtype), segment_rng(shape[:1], jnp.int32)]

    # Infinite identities cannot be stored in integer arrays; clamp them to
    # the dtype's representable extremes instead.
    if np.issubdtype(dtype, np.integer):
      limits = np.iinfo(dtype)
      if np.isposinf(identity):
        identity = limits.max
      elif np.isneginf(identity):
        identity = limits.min

    def jnp_fun(data, segment_ids):
      return reducer(data, segment_ids, num_segments=num_segments,
                     bucket_size=bucket_size)

    def np_fun(data, segment_ids):
      if num_segments is None:
        size = segment_ids.max() + 1
      else:
        size = num_segments
      out = np.full((size,) + shape[1:], identity, dtype)
      for seg, row in zip(segment_ids, data):
        if not 0 <= seg < size:
          continue  # out-of-range segment ids are dropped by the reference
        out[seg] = op(out[seg], row).astype(dtype)
      return out

    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    # Only compile-check with an explicit num_segments (presumably because the
    # output size is otherwise data-dependent — confirm).
    if num_segments is not None:
      self._CompileAndCheck(jnp_fun, args_maker)
示例#3
0
 def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
   """Gradient-check lax.scatter_min (order 2, fwd and rev modes)."""
   operand_rng = jtu.rand_default(self.rng())
   index_rng = jtu.rand_int(self.rng(), high=max(arg_shape))
   # Replace the template indices with random ones of the same shape/dtype.
   idxs = index_rng(idxs.shape, idxs.dtype)

   def scatter_min(x, y):
     return lax.scatter_min(x, idxs, y, dnums)

   operand = operand_rng(arg_shape, dtype)
   updates = operand_rng(update_shape, dtype)
   check_grads(scatter_min, (operand, updates), 2, ["fwd", "rev"], 1e-2, 1e-2)
示例#4
0
  def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
    """Check `sph_harm` under JIT and against the `osp_special` reference."""
    # Orders m come from [-l_max, l_max]; degrees are n = |m|.
    m_rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
    lsp_special_fn = partial(lsp_special.sph_harm, n_max=l_max)

    def args_maker():
      m = m_rng((num_z,), dtype)
      n = abs(m)
      theta = jnp.linspace(-4.0, 5.0, num_z)
      phi = jnp.linspace(-2.0, 1.0, num_z)
      return m, n, theta, phi

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_special_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
示例#5
0
class BatchingTest(jtu.JaxTestCase):
    def testConstantFunction(self):
        """vmap of an input-ignoring function yields the constant per example."""
        batch = onp.ones(4)
        result = vmap(lambda _: 3)(batch)
        self.assertAllClose(result, onp.full(4, 3.0), check_dtypes=False)

    def testNestedBatchingMatMat(self):
        """Nested vmaps of np.vdot reproduce a matrix-matrix product.

        The inner vmap maps rows of A against a shared vector; the outer vmap
        maps over columns of B (in_axes 1) and writes them to output axis 1.
        The traced jaxpr is then expected to contain exactly one equation.
        """
        matvec = vmap(np.vdot, in_axes=(0, None))
        matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

        R = onp.random.RandomState(0).randn
        A = R(4, 3)
        B = R(3, 2)

        ans = matmat(A, B)
        expected = onp.dot(A, B)
        self.assertAllClose(ans, expected, check_dtypes=False)

        # this is a crude check that we only call a single dot
        def pv_like(x):
            # Abstract stand-in for x carrying only its shape and dtype.
            aval = ShapedArray(onp.shape(x), onp.result_type(x))
            return pe.PartialVal((aval, unit))

        def make_jaxpr(fun, example_args):
            # NOTE(review): whether `map` here is the builtin (lazy iterator)
            # or a list-returning helper depends on this file's imports.
            jaxpr, _, _, _ = trace_to_jaxpr(fun, map(pv_like, example_args))
            return jaxpr

        jaxpr = make_jaxpr(matmat, (A, B))
        self.assertEqual(len(jaxpr.eqns), 1)

    def testPerExampleGradients(self):
        """vmap over grad yields one gradient per example with batched shapes."""
        def predict(params, inputs):
            for W, b in params:
                outputs = np.dot(W, inputs) + b
                inputs = np.tanh(outputs)
            return outputs

        def loss(params, data):
            inputs, targets = data
            predictions = predict(params, inputs)
            return np.sum((predictions - targets)**2)

        batch_size = 5
        layer_sizes = [3, 2, 4]

        R = onp.random.RandomState(0).randn
        # W has shape (out, in) and b has shape (out,) for each layer.
        params = [(R(m, n), R(m))
                  for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

        # Derive batch shapes from batch_size/layer_sizes so they cannot
        # drift apart from the shape assertions below.
        input_batch = R(batch_size, layer_sizes[0])
        target_batch = R(batch_size, layer_sizes[-1])
        batch = (input_batch, target_batch)

        ans = vmap(partial(grad(loss), params))(batch)

        # zip() would silently skip layers if the lengths disagreed.
        self.assertEqual(len(ans), len(params))
        for (dW, db), (W, b) in zip(ans, params):
            self.assertEqual(dW.shape, (batch_size, ) + W.shape)
            self.assertEqual(db.shape, (batch_size, ) + b.shape)

    def testJacobians(self):
        """vmap-built forward and reverse Jacobians agree with the analytic one."""
        def jacbwd(f, x):
            # One VJP per output basis vector, stacked on axis 0: each row is
            # dy_j/dx, giving shape (size(y),) + shape(x).
            y, pullback = vjp(f, x)
            std_basis = onp.eye(onp.size(y)).reshape((-1, ) + onp.shape(y))
            jac_flat, = vmap(pullback, out_axes=0)(std_basis)
            return jac_flat.reshape(onp.shape(y) + onp.shape(x))

        def jacfwd(f, x):
            # One JVP per input basis vector, stacked on axis 0: each entry is
            # the column dy/dx_i, giving shape (size(x),) + shape(y). Move the
            # input axis last before reshaping to shape(y) + shape(x).
            pushfwd = lambda v: jvp(f, (x, ), (v, ))
            std_basis = onp.eye(onp.size(x)).reshape((-1, ) + onp.shape(x))
            y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
            jac_flat = onp.moveaxis(onp.asarray(jac_flat), 0, -1)
            return jac_flat.reshape(onp.shape(y) + onp.shape(x))

        R = onp.random.RandomState(0).randn

        A = R(4, 3)
        b = R(4)
        f = lambda x: np.tanh(np.dot(A, x) + b)

        x = R(3)
        fwd = jacfwd(f, x)
        bwd = jacbwd(f, x)
        self.assertAllClose(fwd, bwd, check_dtypes=False)

        # d/dx tanh(A x + b) = diag(1 - tanh(A x + b)**2) @ A, shape (4, 3).
        expected = (1 - onp.tanh(onp.dot(A, x) + b)**2)[:, None] * A
        self.assertAllClose(fwd, expected, check_dtypes=False,
                            atol=1e-5, rtol=1e-5)

    def testBatchOfCompile(self):
        """jit(vmap(jitted f)) traces f once across repeated same-shape calls."""
        trace_log = []

        @jit
        def f(x):
            trace_log.append(None)  # expected to execute only while tracing
            return x + x

        g = jit(vmap(f))
        for scale in (1, 2):
            self.assertAllClose(g(scale * onp.ones(2)),
                                2 * scale * onp.ones(2),
                                check_dtypes=False)
            self.assertEqual(len(trace_log), 1)

    def testSliceLax(self):
        """vmap of lax.slice takes elements [2, 4) from every row."""
        x = onp.random.RandomState(0).randn(5, 10)
        sliced = vmap(lambda row: lax.slice(row, (2, ), (4, )))(x)
        self.assertAllClose(sliced, x[:, 2:4], check_dtypes=False)

    def testSliceNumpy(self):
        """NumPy-style integer indexing under vmap shifts by one axis."""
        x = onp.random.RandomState(0).randn(10, 5, 3, 7)
        # Per example, ex[:, 2] indexes the example's axis 1, which is axis 2
        # of the batched array.
        picked = vmap(lambda ex: ex[:, 2])(x)
        self.assertAllClose(picked, x[:, :, 2], check_dtypes=False)

    def testRevLax(self):
        """vmap of lax.rev over axis 0 with batch on axis 0 and on axis 1."""
        reverse_first = lambda x: lax.rev(x, [0])
        x = onp.random.RandomState(0).randn(2, 3)

        for in_axes, out_axes, expected in [
                ((0, ), 0, x[:, ::-1]),
                ((1, ), 1, x[::-1, :]),
        ]:
            ans = vmap(reverse_first, in_axes, out_axes)(x)
            self.assertAllClose(ans, expected, check_dtypes=False)

    def testRevNumpy(self):
        """Reversing per-example axis 1 via slicing, for batch axes 0, 1 and 2."""
        flip = lambda x: x[:, ::-1]
        x = onp.random.RandomState(0).randn(3, 2, 4)

        # Which array axis gets reversed depends on where the batch axis sits.
        for in_axes, out_axes, expected in [
                ((0, ), 0, x[:, :, ::-1]),
                ((1, ), 1, x[:, :, ::-1]),
                ((2, ), 2, x[:, ::-1, :]),
        ]:
            ans = vmap(flip, in_axes, out_axes)(x)
            self.assertAllClose(ans, expected, check_dtypes=False)

    def testNpMaximum(self):
        """vmap of np.maximum(x, 0) matches onp.maximum on the full array."""
        x = onp.random.RandomState(0).randn(10, 5, 3, 7)
        relu = lambda v: np.maximum(v, 0.0)
        self.assertAllClose(vmap(relu)(x), onp.maximum(x, 0.0),
                            check_dtypes=False)

    def testNpGtrThan(self):
        """vmap of an elementwise `>` comparison returns the same boolean mask."""
        x = onp.random.RandomState(0).randn(10, 5, 3, 7)
        mask = vmap(lambda v: v > 1.0)(x)
        self.assertAllClose(mask, x > 1.0, check_dtypes=True)

    def testNpMaximumPerExampleGrad(self):
        """Per-example gradients of sum(relu(x W)**2) match a closed form."""
        R = onp.random.RandomState(0).randn
        x = R(10, 5)
        W = R(5, 5)

        def fun(W, x):
            return np.sum(np.maximum(np.dot(x, W), 0.0)**2)

        per_example = vmap(partial(grad(fun), W))(x)

        W_t = np.transpose(W)
        for i in range(x.shape[0]):
            row = x[i:i + 1]  # keep a leading axis: shape (1, 5)
            # dL/dW = 2 * x^T relu(x W), assembled here via transposes.
            hidden = np.maximum(np.dot(W_t, np.transpose(row)), 0.0)
            expected = np.transpose(2.0 * np.dot(hidden, row))
            self.assertAllClose(per_example[i], expected, check_dtypes=False)

    def testDotGeneral(self):
        """vmap of lax.dot_general with the batch axis in various positions."""
        R = onp.random.RandomState(0).randn
        # Contract lhs axis 2 with rhs axis 1; batch over axis 0 of both.
        dnums = [((2, ), (1, )), ((0, ), (0, ))]
        fun = lambda x, y: lax.dot_general(x, y, dnums)

        # Both mapped on axis 0: equivalent to adding that axis as an extra
        # dot_general batch dimension.
        x = R(10, 3, 4, 5)
        y = R(10, 3, 5, 6)
        ans = vmap(fun)(x, y)
        expected = lax.dot_general(x, y, [((3, ), (2, )), ((0, 1), (0, 1))])
        self.assertAllClose(ans, expected, check_dtypes=True)

        # Mapped on different interior axes.
        x = R(3, 4, 10, 5)
        y = R(3, 10, 5, 6)
        ans = vmap(fun, in_axes=(2, 1))(x, y)
        expected = onp.stack(
            [fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

        # Only the lhs is mapped (on its last axis).
        x = R(3, 4, 5, 10)
        y = R(3, 5, 6)
        ans = vmap(fun, in_axes=(3, None))(x, y)
        expected = onp.stack([fun(x[..., i], y) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

        # Only the rhs is mapped.
        x = R(3, 4, 5)
        y = R(3, 5, 10, 6)
        ans = vmap(fun, in_axes=(None, 2))(x, y)
        expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

    def testDot(self):
        """Nested vmap over np.dot behaves like a vector-vector gufunc."""
        # these tests are based on @shoyer's notebook studying gufuncs

        def vecvec(a, b):
            dot = np.dot
            # Wrap one vmap per extra leading dimension. An operand without
            # enough dimensions for a layer is not mapped (in_axes None).
            for ndim in range(1, max(a.ndim, b.ndim)):
                a_ax = 0 if a.ndim > ndim else None
                b_ax = 0 if b.ndim > ndim else None
                dot = vmap(dot, in_axes=(a_ax, b_ax))
            return dot(a, b)

        # Use unittest assertions rather than bare `assert`, which is removed
        # under `python -O` and reports nothing useful on failure.
        for a_shape, expected_shape in [((3, ), ()),
                                        ((2, 3), (2, )),
                                        ((4, 2, 3), (4, 2))]:
            out = vecvec(np.zeros(a_shape), np.zeros((3, )))
            self.assertEqual(out.shape, expected_shape)

    def testDot2(self):
        """vmap of np.dot over paired vectors equals a row-wise einsum."""
        rng = onp.random.RandomState(0)
        xs = rng.randn(10, 3)
        ys = rng.randn(10, 3)
        self.assertAllClose(vmap(np.dot)(xs, ys),
                            onp.einsum('ni,ni->n', xs, ys),
                            check_dtypes=False)

    def testDot3(self):
        """vmap of np.dot mapped over axis 1 of xs with ys shared."""
        rng = onp.random.RandomState(0)
        xs = rng.randn(5, 8, 10)
        ys = rng.randn(10, 1)
        batched_dot = vmap(np.dot, in_axes=(1, None))
        self.assertAllClose(batched_dot(xs, ys),
                            onp.einsum('inj,jk->nik', xs, ys),
                            check_dtypes=False)

    def testPad(self):
        """vmap of lax.pad equals padding each example independently."""
        R = onp.random.RandomState(0).randn
        zero = onp.float32(0)

        # (per-axis (low, high, interior) padding config, batched input shape)
        for config, shape in [([(1, 2, 1)], (5, 10)),
                              ([(1, 2, 1), (0, 1, 0)], (5, 10, 3))]:
            fun = lambda x: lax.pad(x, zero, config)
            x = R(*shape).astype(onp.float32)
            ans = vmap(fun)(x)
            expected_ans = np.stack([fun(row) for row in x])
            self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testConcatenate(self):
        """vmap of lax.concatenate with mixed and unmapped batch axes."""
        def R(*shape):
            # Every call reseeds with 0, matching the original helper.
            return onp.random.RandomState(0).randn(*shape).astype(onp.float32)

        def concat_along(dim):
            return lambda *args: lax.concatenate(args, dimension=dim)

        x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
        ans = vmap(concat_along(0), in_axes=(0, 1, None))(x, y, z)
        expected_ans = onp.concatenate(
            [x, onp.swapaxes(y, 0, 1), onp.broadcast_to(z, (10, 4, 3))], 1)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
        ans = vmap(concat_along(1), in_axes=(0, None, 2))(x, y, z)
        expected_ans = onp.concatenate(
            [x, onp.broadcast_to(y, (10, 2, 3)), onp.moveaxis(z, 2, 0)], 2)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testJacobianIssue54(self):
        """Regression: https://github.com/google/jax/issues/54 must not crash."""
        def func(xs):
            # Rebuild the array from its rows via Python-level iteration.
            return np.array(list(xs))

        xs = np.ones((5, 1))
        for jacobian in (jacrev, jacfwd):
            jacobian(func)(xs)  # only checks that this does not raise

    def testAny(self):
        """vmap of np.any reduces each row (https://github.com/google/jax/issues/108)."""
        rows = np.array([[True, False], [False, False]])
        self.assertAllClose(vmap(np.any)(rows), np.array([True, False]),
                            check_dtypes=True)

    @jtu.skip_on_devices("tpu")
    def testHessian(self):
        """Hessian of sum(max(x, 0)**2) + t is diagonal (pinned values below)."""
        # test based on code from sindhwani@google
        def fun(x, t):
            return np.sum(np.power(np.maximum(x, 0.0), 2)) + t

        x = onp.array([-1., -0.5, 0., 0.5, 1.0])
        # Expected diagonal; the entry at the kink x == 0 is pinned to 0.5.
        expected = onp.diag([0., 0., 0.5, 2., 2.])
        ans = hessian(lambda v: fun(v, 0.0))(x)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testDynamicSlice(self):
        """Indexing with a possibly-batched index under vmap (dynamic_slice)."""
        index = lambda arr, i: arr[i]
        table = onp.arange(30).reshape((10, 3))
        idx = onp.array([0, 1, 2, 1, 0] * 2)

        # Batched array, shared index.
        self.assertAllClose(vmap(index, in_axes=(0, None))(table, 1),
                            table[:, 1], check_dtypes=False)
        # Batched array and batched index.
        self.assertAllClose(vmap(index, in_axes=(0, 0))(table, idx),
                            table[onp.arange(10), idx], check_dtypes=False)
        # Shared array, batched index.
        vec = onp.arange(3)
        self.assertAllClose(vmap(index, in_axes=(None, 0))(vec, idx),
                            vec[idx], check_dtypes=False)

    def testDynamicUpdateSlice(self):
        """vmap of dynamic_update_index_in_dim with batched/unbatched args."""
        # Seed like the rest of this file so failures are reproducible.
        rng = onp.random.RandomState(0)
        update = lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i,
                                                                 axis=0)

        # Batched operand and update, shared index: writes column 1.
        x = rng.randn(10, 3)
        y = rng.randn(10)
        ans = vmap(update, in_axes=(0, 0, None))(x, y, 1)
        expected = x.copy()
        expected[:, 1] = y
        self.assertAllClose(ans, expected, check_dtypes=False)

        # Shared operand, batched update (y from above) and batched index.
        x = rng.randn(3)
        idx = onp.array([0, 1, 2, 1, 0] * 2)
        ans = vmap(update, in_axes=(None, 0, 0))(x, y, idx)
        expected = onp.broadcast_to(x, (10, 3)).copy()
        expected[onp.arange(10), idx] = y
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testRandom(self):
        """vmap over PRNGKey and normal matches sampling seed by seed."""
        seed_values = onp.arange(10)
        keys = vmap(random.PRNGKey)(seed_values)
        sample = partial(random.normal, shape=(3, 2))
        ans = vmap(sample)(keys)
        expected = onp.stack(
            [random.normal(random.PRNGKey(s), (3, 2)) for s in seed_values])
        self.assertAllClose(ans, expected, check_dtypes=False)
        # All 10 * 3 * 2 sampled values should be distinct.
        assert len(onp.unique(ans)) == 10 * 3 * 2

    def testSortKeyVal(self):
        """vmap of lax.sort_key_val across in/out batch-axis combinations."""
        k = onp.arange(12)[::-1].reshape(3, 4)
        v = onp.random.RandomState(0).permutation(12).reshape(3, 4)
        sort0 = partial(lax.sort_key_val, dimension=0)

        # Every key row is descending, so each expected result is the
        # reversal of the corresponding input along the sorted axis.
        cases = [
            # (in_axes, out_axes, args, expected keys, expected values)
            ((0, 0), 0, (k, v), k[:, ::-1], v[:, ::-1]),
            ((1, 1), 1, (k, v), k[::-1, :], v[::-1, :]),
            ((0, 1), 0, (k, v.T), k[:, ::-1], v[:, ::-1]),
            ((1, 0), 0, (k.T, v), k[:, ::-1], v[:, ::-1]),
            ((None, 0), 0, (k[0], v),
             onp.broadcast_to(k[0, ::-1], (3, 4)), v[:, ::-1]),
            ((1, None), 0, (k.T, v[0]),
             k[:, ::-1], onp.broadcast_to(v[0, ::-1], (3, 4))),
        ]
        for in_axes, out_axes, args, expected_k, expected_v in cases:
            sk, sv = vmap(sort0, in_axes, out_axes)(*args)
            self.assertAllClose(sk, expected_k, check_dtypes=True)
            self.assertAllClose(sv, expected_v, check_dtypes=True)

    def testConvGeneralDilated(self):
        """Per-example conv outputs and gradients under vmap match direct ones."""
        # Seed like the rest of this file so failures are reproducible.
        rng = onp.random.RandomState(0)
        W = np.array(rng.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(rng.randn(10, 5, 5, 1), dtype=onp.float32)
        batch = X.shape[0]

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Add a singleton N axis so each vmapped example is an NHWC batch of 1.
        examples = np.reshape(X, (batch, 1, 5, 5, 1))

        # Test forward prop.
        per_example = vmap(partial(f, W))(examples)
        per_example = np.reshape(per_example, (batch, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss, W))(examples)
        per_example_direct = np.stack([
            grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) for i in range(batch)
        ])
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    def testConvGeneralDilatedBatchNotMajor(self):
        """vmap of a conv whose N dimension is not the leading axis."""
        # Seed like the rest of this file so failures are reproducible.
        rng = onp.random.RandomState(0)
        W = np.array(rng.randn(3, 3, 1, 4), dtype=onp.float32)
        x = np.array(rng.randn(3, 5, 7, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            return y

        # vmap output is (3, H, W, N, C); fold the vmapped axis into N.
        per_example = vmap(partial(f, W))(x)
        per_example = np.reshape(np.transpose(per_example, (1, 2, 0, 3, 4)),
                                 (5, 5, 21, 4))
        # Equivalent direct call: merge the leading axis of x into its N axis.
        per_example_direct = f(
            W, np.reshape(np.transpose(x, (1, 0, 2, 3, 4)), (5, 21, 5, 1)))
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    def testMaxPool(self):
        """Per-example conv + max-pool outputs and gradients under vmap."""
        # Seed like the rest of this file so failures are reproducible.
        rng = onp.random.RandomState(0)
        W = np.array(rng.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(rng.randn(10, 5, 5, 1), dtype=onp.float32)
        batch = X.shape[0]

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            y = lax.reduce_window(y, -np.inf, lax.max, (1, 2, 2, 1),
                                  (1, 1, 1, 1), 'SAME')
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Add a singleton N axis so each vmapped example is an NHWC batch of 1.
        examples = np.reshape(X, (batch, 1, 5, 5, 1))

        # Test forward prop.
        per_example = vmap(partial(f, W))(examples)
        per_example = np.reshape(per_example, (batch, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss, W))(examples)
        per_example_direct = np.stack([
            grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) for i in range(batch)
        ])
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    def testSumPool(self):
        """Per-example conv + sum-pool outputs and gradients under vmap."""
        # Seed like the rest of this file so failures are reproducible.
        rng = onp.random.RandomState(0)
        W = np.array(rng.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(rng.randn(10, 5, 5, 1), dtype=onp.float32)
        batch = X.shape[0]

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            y = lax.reduce_window(y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1),
                                  'SAME')
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Add a singleton N axis so each vmapped example is an NHWC batch of 1.
        examples = np.reshape(X, (batch, 1, 5, 5, 1))

        # Test forward prop.
        per_example = vmap(partial(f, W))(examples)
        per_example = np.reshape(per_example, (batch, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss, W))(examples)
        per_example_direct = np.stack([
            grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) for i in range(batch)
        ])
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    def testSelect(self):
        """vmap of lax.select with each operand batched or shared."""
        f32 = onp.float32
        cases = [
            # (in_axes, out_axes, pred, on_true, on_false, expected)
            (0, 0, onp.array([True, False]), onp.array([0, 1]),
             onp.array([2, 3]), onp.array([0, 3])),
            ((0, None, None), 0, onp.array([False, True]), onp.array([0, 1]),
             onp.array([2, 3]), onp.array([[2, 3], [0, 1]])),
            ((None, 0, None), 0, True, onp.array([0, 1], f32),
             onp.array(3, f32), onp.array([0, 1], f32)),
            ((0, 0, None), 0, onp.array([False, True]), onp.array([0, 1], f32),
             onp.array(3, f32), onp.array([3, 1], f32)),
            ((0, None, 1), 1, onp.array([False, True]), onp.array([2], f32),
             onp.array([[3, 4]], f32), onp.array([[3, 2]], f32)),
        ]
        for in_axes, out_axes, pred, on_true, on_false, expected in cases:
            ans = vmap(lax.select, in_axes, out_axes)(pred, on_true, on_false)
            self.assertAllClose(ans, expected, check_dtypes=True)

    def testLaxLinalgCholesky(self):
        """vmap of lax_linalg.cholesky matches numpy for batch axes 0 and 1."""
        def gram_matrices(shape):
            m = onp.random.RandomState(0).randn(*shape).astype(onp.float32)
            # m @ conj(m)^T is Hermitian positive semi-definite.
            return onp.matmul(m, onp.conj(onp.swapaxes(m, -1, -2)))

        a = gram_matrices((10, 5, 5))
        self.assertAllClose(vmap(lax_linalg.cholesky)(a),
                            onp.linalg.cholesky(a), check_dtypes=False)

        b = gram_matrices((10, 5, 5))
        b_trans = onp.swapaxes(b, 0, 1)  # shape is (5, 10, 5)
        self.assertAllClose(
            vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans),
            onp.linalg.cholesky(b), check_dtypes=False)

    def testLaxLinalgTriangularSolve(self):
        """vmap of lax_linalg.triangular_solve with various batched operands."""
        a = onp.random.RandomState(0).randn(4, 10, 4).astype(onp.float32)
        # Add the identity to every (4, 4) slice a[:, i, :].
        a += onp.eye(4, dtype=np.float32)[:, None, :]
        b = onp.random.RandomState(0).randn(5, 4, 10).astype(onp.float32)
        solve = lax_linalg.triangular_solve

        cases = [
            # (in_axes, vmapped args, reference solve for example i)
            ((1, 2), (a, b), lambda i: solve(a[:, i], b[..., i])),
            ((None, 2), (a[:, 0], b), lambda i: solve(a[:, 0], b[..., i])),
            ((1, None), (a, b[..., 0]), lambda i: solve(a[:, i], b[..., 0])),
        ]
        for in_axes, args, single in cases:
            ans = vmap(solve, in_axes=in_axes)(*args)
            expected = onp.stack([single(i) for i in range(10)])
            self.assertAllClose(ans, expected, check_dtypes=True)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.int32]
        # Each case: (mapped axis of the operand, full operand shape including
        # that axis, gather indices, gather dimension numbers, slice sizes).
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (3, 5), onp.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, (10, 3), onp.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (1, (10, 3, 5), onp.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                 slice_sizes, rng, rng_idx):
        """vmap of lax.gather over the operand matches gathering per slice.

        The operand is mapped along ``axis`` and ``idxs`` is shared by every
        example. ``rng_idx`` comes from the parameterization but is not used
        in this body; the indices are the fixed ``idxs`` arrays.
        """
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        ans = vmap(fun, (axis, None))(operand, idxs)
        # Reference: gather from each slice of the operand taken along `axis`.
        expected = onp.stack([
            fun(operand[(slice(None), ) * axis + (i, )], idxs)
            for i in range(operand.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.float64]
        # Each case: (mapped axis of the operand, full operand shape including
        # that axis, gather indices, gather dimension numbers, slice sizes).
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (3, 5), onp.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, (10, 3), onp.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (1, (10, 3, 5), onp.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                     slice_sizes, rng, rng_idx):
        """vmap of grad(sum(sin(gather))) over the operand matches per slice.

        Only floating dtypes are parameterized here, since grad is taken with
        respect to the operand. ``rng_idx`` is unused in this body.
        """
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        # Scalar loss so that grad w.r.t. the operand is well defined.
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        ans = vmap(gfun, (axis, None))(operand, idxs)
        # Reference: gradient for each slice of the operand taken along `axis`.
        expected = onp.stack([
            gfun(operand[(slice(None), ) * axis + (i, )], idxs)
            for i in range(operand.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.int32]
        # Each case: (mapped axis of the *indices*, operand shape, batched
        # gather indices, gather dimension numbers, slice sizes).
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (5, ), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                 slice_sizes, rng, rng_idx):
        """vmap of lax.gather over the indices matches gathering per index set.

        The operand is shared; ``idxs`` is mapped along ``axis``. ``rng_idx``
        is unused in this body.
        """
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        ans = vmap(fun, (None, axis))(operand, idxs)
        # Reference: gather once per slice of `idxs` taken along `axis`.
        expected = onp.stack([
            fun(operand, idxs[(slice(None), ) * axis + (i, )])
            for i in range(idxs.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {"testcase_name":
             "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                 jtu.format_shape_dtype_string(shape, dtype), axis, idxs,
                 dnums, slice_sizes),
         "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs,
         "dnums": dnums, "slice_sizes": slice_sizes,
         "rng": rng, "rng_idx": rng_idx}
        for dtype in [onp.float32, onp.float64]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (5,), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0,)), (1,)),
            (1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0,)), (2,)),
            (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0,)), (1, 3)),
            (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0, 1)), (1, 3)),
        ]
        for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                     slice_sizes, rng, rng_idx):
        """Gradient of a gather-based loss, vmapped over the index array.

        The loss is sum(sin(gather(x, idx))); its gradient w.r.t. the shared
        operand is computed for every index slice along `axis`, both via vmap
        and via an explicit loop, and the two must agree.
        """
        gather_fn = partial(lax.gather, dimension_numbers=dnums,
                            slice_sizes=slice_sizes)

        def loss(x, idx):
            return np.sum(np.sin(gather_fn(x, idx)))

        loss_grad = grad(loss)
        operand = rng(shape, dtype)
        batched = vmap(loss_grad, (None, axis))(operand, idxs)
        prefix = (slice(None),) * axis
        reference = onp.stack([loss_grad(operand, idxs[prefix + (i,)])
                               for i in range(idxs.shape[axis])])
        self.assertAllClose(batched, reference, check_dtypes=False)

    @parameterized.named_parameters(
        {"testcase_name":
             "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
             .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
                     idxs_axis, idxs, dnums, slice_sizes),
         "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape,
         "dtype": dtype, "idxs": idxs, "dnums": dnums,
         "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
        for dtype in [onp.float32, onp.int32]
        for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
            (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0,)), (1,)),
            (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0,)), (2,)),
            (0, 1, (2, 10, 5), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0,)), (1, 3)),
            (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0, 1)), (1, 3)),
        ]
        for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                              dnums, slice_sizes, rng, rng_idx):
        """vmap of lax.gather with operand and indices both batched.

        The operand is mapped along `op_axis` and the indices along
        `idxs_axis` in lockstep; the reference pairs up the i-th slices of
        each and stacks the per-pair gathers.
        """
        gather_fn = partial(lax.gather, dimension_numbers=dnums,
                            slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        # Lockstep mapping only makes sense if both batch sizes agree.
        assert operand.shape[op_axis] == idxs.shape[idxs_axis]
        batched = vmap(gather_fn, (op_axis, idxs_axis))(operand, idxs)
        op_prefix = (slice(None),) * op_axis
        idx_prefix = (slice(None),) * idxs_axis
        reference = onp.stack([
            gather_fn(operand[op_prefix + (i,)], idxs[idx_prefix + (i,)])
            for i in range(idxs.shape[idxs_axis])])
        self.assertAllClose(batched, reference, check_dtypes=False)

    @parameterized.named_parameters(
        {"testcase_name":
             "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
             .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
                     idxs_axis, idxs, dnums, slice_sizes),
         "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape,
         "dtype": dtype, "idxs": idxs, "dnums": dnums,
         "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
        # Differentiation needs floating-point operands (a gradient w.r.t.
        # an integer input is meaningless), so use inexact dtypes only, the
        # same set testGatherGradBatchedIndices uses.
        for dtype in [onp.float32, onp.float64]
        for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
            (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0,)), (1,)),
            (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0,)), (2,)),
            (0, 1, (2, 10, 5), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0,)), (1, 3)),
            (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
             lax.GatherDimensionNumbers(offset_dims=(1,),
                                        collapsed_slice_dims=(0,),
                                        start_index_map=(0, 1)), (1, 3)),
        ]
        for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                                  dnums, slice_sizes, rng, rng_idx):
        """Gradient of a gather loss, vmapped over operand and indices jointly.

        The loss is sum(sin(gather(x, idx))). Its gradient w.r.t. the operand
        is computed with vmap over (`op_axis`, `idxs_axis`) and compared to a
        loop over matching slices of both inputs. `rng_idx` comes from the
        parameter dict but is not used here.
        """
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        # Both inputs are mapped together, so their batch sizes must agree.
        # A TestCase assertion (unlike a bare assert) survives `python -O`.
        self.assertEqual(operand.shape[op_axis], idxs.shape[idxs_axis])
        ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
        expected = onp.stack([
            gfun(operand[(slice(None), ) * op_axis + (i, )],
                 idxs[(slice(None), ) * idxs_axis + (i, )])
            for i in range(idxs.shape[idxs_axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testNumpyIndexing1(self):
        """vmap over the index array of a NumPy-style fancy index `a[:, ind]`.

        Each row of `ind` is used as an integer index into axis 1 of the
        shared array `a`; the batched result must equal stacking the per-row
        results.
        """
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        ind = onp.array([[0, 1], [2, 0]])

        def f(a, ind):
            return a[:, ind]

        expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
        ans = vmap(f, (None, 0))(a, ind)
        # A TestCase assertion is not stripped under `python -O` (a bare
        # `assert` is) and reports the mismatching values on failure.
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testNumpyIndexing2(self):
        """vmap over a function that fancy-indexes with a constant index array.

        `f` selects columns [0, 2] of axis 1 of its argument. The vmapped
        result is compared against applying `f` to each slice along axis 1
        of `a` and restacking on that axis.
        """
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

        def f(a):
            inds = np.array([0, 2])
            return a[:, inds]

        ans = vmap(f)(a)
        expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])],
                             axis=1)
        # A TestCase assertion is not stripped under `python -O` (a bare
        # `assert` is) and reports the mismatching values on failure.
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testTranspose(self):
        """`.T` inside vmap transposes each example's trailing two axes."""
        batch = onp.arange(4 * 3 * 3).reshape((4, 3, 3))
        result = vmap(lambda m: m + m.T)(batch)
        self.assertAllClose(result, batch + onp.swapaxes(batch, -1, -2),
                            check_dtypes=False)

    def testTransposePermutation(self):
        """np.transpose with an explicit permutation under vmap.

        Each case is (input shape, per-example permutation, mapped axis,
        equivalent permutation of the full batched array).
        """
        cases = [
            ((6, 3, 4, 5), (1, 0, 2), 0, (0, 2, 1, 3)),
            ((6, 3, 4, 5), (1, 2, 0), 0, (0, 2, 3, 1)),
            ((3, 4, 6, 5), (1, 2, 0), 2, (2, 1, 3, 0)),
        ]
        for in_shape, perm, batch_axis, full_perm in cases:
            arr = onp.arange(6 * 3 * 4 * 5).reshape(in_shape)
            # The lambda is traced immediately by the vmap call below, so
            # closing over the loop variable `perm` is safe.
            got = vmap(lambda x: np.transpose(x, perm), in_axes=batch_axis)(arr)
            want = onp.transpose(arr, full_perm)
            self.assertAllClose(got, want, check_dtypes=False)

    def testIssue354(self):
        """grad through vmap of a Cholesky-based function must not crash.

        Inputs come from a seeded RandomState so that the data, and any
        failure it triggers, are reproducible from run to run.
        """
        rand = onp.random.RandomState(0)
        psd_mat = rand.randn(20, 10)
        # A^T A is positive semidefinite; for a random 20x10 A it is (almost
        # surely) positive definite, as the Cholesky factorization requires.
        psd_mat = psd_mat.T.dot(psd_mat)
        vec = rand.randn(10)

        def f(scale):
            scaled_mat = scale * psd_mat
            chol = np.linalg.cholesky(scaled_mat)
            return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2)

        vmapped_f = vmap(f)
        vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x)))

        scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
        ans = vmapped_f_grad(scales)  # don't crash!
        expected = onp.stack([grad(f)(scale) for scale in scales])
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testIssue387(self):
        """Regression test for https://github.com/google/jax/issues/387.

        hessian of a jitted function that also traces an unused pairwise
        distance computation must not crash on UnshapedArray.
        """
        positions = onp.random.RandomState(0).rand(100, 2)

        def dist_sq(R):
            diff = R[:, np.newaxis, :] - R[np.newaxis, :, :]
            zero = np.zeros_like(diff)
            wrap = np.where(np.abs(diff) < 0.5, zero, 0.5 * np.sign(diff))
            diff = diff - wrap
            return np.sum(diff**2, axis=2)

        @jit
        def f(R):
            # The result is deliberately unused; the call just has to be
            # traced as part of f.
            dist_sq(R)
            return np.sum(R**2)

        hessian(f)(positions)  # don't crash on UnshapedArray

    def testIssue489(self):
        """vmap over keys of a while_loop that draws and splits PRNG keys."""
        def f(key):
            def body_fn(carry):
                cur_key = carry[1]
                sample = random.uniform(cur_key, (), dtype=np.float64)
                next_key, _ = random.split(cur_key)
                return sample, next_key

            def keep_going(carry):
                return carry[0] > 0.5

            final_sample, _ = lax.while_loop(keep_going, body_fn, (1., key))
            return final_sample

        keys = random.split(random.PRNGKey(0), 2)
        print(vmap(f)(keys))  # no crash

    def testEmptyTuples(self):
        """Empty tuples among vmapped inputs or outputs must not crash."""
        xs = onp.array([0, 1])
        want = onp.array([1, 2])
        # An empty tuple passed as an extra vectorized input.
        out = vmap(lambda x, _: x + 1)(xs, ())
        self.assertAllClose(out, want, check_dtypes=False)
        # An empty tuple returned as part of the vectorized output.
        out, empty = vmap(lambda x: (x + 1, ()))(xs)
        self.assertAllClose(out, want, check_dtypes=False)
        self.assertEqual((), empty)

    def testIndexAddBatchedIndexesOnly(self):
        """index_add with only the index batched builds the identity matrix.

        x (ten zeros) and y (1.) are shared; mapping idx over 0..9 adds 1. at
        position idx of row idx.
        """
        def add_at(x, idx, y):
            return jax.ops.index_add(x, jax.ops.index[idx], y)

        per_index = vmap(add_at, (None, 0, None))
        result = per_index(onp.zeros((10,)), onp.arange(10), 1.)
        self.assertAllClose(result, onp.eye(10), check_dtypes=False)
示例#6
0
class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Build a zero-argument callable that draws one array per (shape, dtype)."""
    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
    return args_maker

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
         "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS)))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    """An elementwise lnp op agrees with its NumPy twin, eagerly and under jit."""
    make_args = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_op, make_args, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
         "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in filter(
          _dtypes_are_compatible_for_bitwise_ops,
          CombosWithReplacement(rec.dtypes, rec.nargs)))
      for rec in JAX_BITWISE_OP_RECORDS))
  def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    """Bitwise lnp ops agree with NumPy; 64-bit cases need jax_enable_x64."""
    # The dtype scan only happens when x64 is off, exactly as before.
    if not FLAGS.jax_enable_x64:
      if any(onp.iinfo(d).bits == 64 for d in dtypes):
        self.skipTest("x64 types are disabled by jax_enable_x64")
    make_args = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_op, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis,
          "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for out_dtype in [None] + rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, out_dtype, axis, keepdims):
    """Reductions with an explicit accumulator `dtype=` match NumPy."""
    def onp_fun(x):
      return onp_op(x, axis, dtype=out_dtype, keepdims=keepdims)

    def lnp_fun(x):
      return lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)

    def make_args():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_NO_DTYPE_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducerNoDtype(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims):
    """Reductions called without a `dtype=` argument match NumPy."""
    def onp_fun(x):
      return onp_op(x, axis, keepdims=keepdims)

    def lnp_fun(x):
      return lnp_op(x, axis, keepdims=keepdims)

    def make_args():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "shape": shape, "dtype": dtype, "axis": axis}
      for shape in all_shapes for dtype in all_dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])))
  def testCountNonzero(self, shape, dtype, axis):
    """lnp.count_nonzero matches NumPy on inputs seeded with some zeros."""
    sprinkle_zeros = jtu.rand_some_zero()

    def onp_fun(x):
      return onp.count_nonzero(x, axis)

    def lnp_fun(x):
      return lnp.count_nonzero(x, axis)

    def make_args():
      return [sprinkle_zeros(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):
    """argmin/argmax along every valid (possibly negative) axis match NumPy."""
    onp_fun = lambda arr: onp_op(arr, axis)
    lnp_fun = lambda arr: lnp_op(arr, axis)
    make_args = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """lnp.dot matches onp.dot across scalar/vector/matrix/tensor operands."""
    def make_args():
      lhs = rng(lhs_shape, lhs_dtype)
      rhs = rng(rhs_shape, rhs_dtype)
      return [lhs, rhs]

    self._CheckAgainstNumpy(onp.dot, lnp.dot, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """lnp.matmul matches onp.matmul, including batch broadcasting."""
    def make_args():
      lhs = rng(lhs_shape, lhs_dtype)
      rhs = rng(rhs_shape, rhs_dtype)
      return [lhs, rhs]

    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, make_args,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
          axes),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "axes": axes, "rng": rng}
      for rng in [jtu.rand_default()]
      for lhs_shape, rhs_shape, axes in [
          [(2, 3, 4), (3, 4, 5, 6), 2],
          [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
          [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
          [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
      ]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
  def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng):
    """lnp.tensordot matches NumPy for int and explicit-pair `axes` specs."""
    def make_args():
      lhs = rng(lhs_shape, lhs_dtype)
      rhs = rng(rhs_shape, rhs_dtype)
      return [lhs, rhs]

    def lnp_fun(a, b):
      return lnp.tensordot(a, b, axes)

    def onp_fun(a, b):
      return onp.tensordot(a, b, axes)

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": jtu.rand_default()}
      # TODO(phawkins): support integer dtypes too.
      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)
      for lhs_shape, rhs_shape in [
        (l, r) for l, r in CombosWithReplacement(all_shapes, 2)
        if len(jtu._dims_of_shape(l)) == 0
        or len(jtu._dims_of_shape(r)) == 0
        or l[-1] == r[-1]]))
  def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """lnp.inner matches NumPy when an operand is scalar or last dims agree."""
    def make_args():
      lhs = rng(lhs_shape, lhs_dtype)
      rhs = rng(rhs_shape, rhs_dtype)
      return [lhs, rhs]

    def onp_fun(lhs, rhs):
      return onp.inner(lhs, rhs)

    def lnp_fun(lhs, rhs):
      return lnp.inner(lhs, rhs)

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in number_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)]))
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    """lnp.clip with Python-constant bounds (either may be None) matches NumPy."""
    bounds = dict(a_min=a_min, a_max=a_max)

    def onp_fun(x):
      return onp.clip(x, **bounds)

    def lnp_fun(x):
      return lnp.clip(x, **bounds)

    def make_args():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in number_dtypes
      for decimals in [0, 1, -2]))
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    """lnp.round with a constant `decimals` matches NumPy."""
    is_integer = onp.issubdtype(dtype, onp.integer)
    if is_integer and decimals < 0:
      self.skipTest("Integer rounding with decimals < 0 not implemented")

    def onp_fun(x):
      return onp.round(x, decimals=decimals)

    def lnp_fun(x):
      return lnp.round(x, decimals=decimals)

    def make_args():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_rpadwidth={}_rconstantvalues={}".format(
          jtu.format_shape_dtype_string(shape, dtype), pad_width_rank,
          constant_values_rank),
       "shape": shape, "dtype": dtype, "pad_width_rank": pad_width_rank,
       "constant_values_rank": constant_values_rank, "rng": jtu.rand_default(),
       "irng": jtu.rand_int(3)}
      for shape in all_shapes for dtype in all_dtypes
      for pad_width_rank in range(3)
      for constant_values_rank in range(3)))
  def testPad(self, shape, dtype, pad_width_rank, constant_values_rank, rng,
              irng):
    """Constant-mode lnp.pad matches onp.pad for pad/value specs of rank 0-2."""
    # Dropping leading entries of [ndim, 2] gives shapes [ndim, 2], [2] or [].
    full_spec = [len(shape), 2]
    pad_width = irng(full_spec[2 - pad_width_rank:], onp.int32)
    constant_vals_shape = full_spec[2 - constant_values_rank:]

    def onp_fun(x, constant_vals):
      # The NumPy reference returns x untouched when pad_width is empty.
      if pad_width.size:
        return onp.pad(x, pad_width, mode='constant',
                       constant_values=constant_vals)
      return x

    def lnp_fun(x, constant_vals):
      return lnp.pad(x, pad_width, mode='constant',
                     constant_values=constant_vals)

    def make_args():
      # The fill values are drawn before the operand, as in the original.
      constant_vals = rng(constant_vals_shape, dtype)
      operand = rng(shape, dtype)
      return operand, constant_vals

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape))))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    """lnp.concatenate of mixed-dtype arrays with unequal lengths on `axis`."""
    cut = axis % len(base_shape)
    head, tail = base_shape[:cut], base_shape[cut + 1:]
    # Each input's extent along the concat axis cycles through 3, 1, 4, ...
    shapes = [head + (size,) + tail
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]

    def onp_fun(*args):
      return onp.concatenate(args, axis=axis)

    def lnp_fun(*args):
      return lnp.concatenate(args, axis=axis)

    def make_args():
      return [rng(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for dtypes in CombosWithReplacement(default_dtypes, 2)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape))))
  def testAppend(self, axis, base_shape, dtypes, rng):
    """lnp.append along `axis` of two mixed-dtype arrays matches NumPy."""
    cut = axis % len(base_shape)
    head, tail = base_shape[:cut], base_shape[cut + 1:]
    # The two inputs differ in extent along the append axis (3 then 1).
    shapes = [head + (size,) + tail
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]

    def onp_fun(arr, values):
      return onp.append(arr, values, axis=axis)

    def lnp_fun(arr, values):
      return lnp.append(arr, values, axis=axis)

    def make_args():
      return [rng(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
       "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
       "rng": jtu.rand_default()}
      for repeats in [0, 1, 2]
      for dtype in default_dtypes
      for shape in all_shapes
      for axis in [None] + list(range(-len(shape), len(shape)))))
  def testRepeat(self, axis, shape, dtype, repeats, rng):
    """lnp.repeat with a scalar repeat count (0 included) matches NumPy."""
    def onp_fun(arg):
      return onp.repeat(arg, repeats=repeats, axis=axis)

    def lnp_fun(arg):
      return lnp.repeat(arg, repeats=repeats, axis=axis)

    def make_args():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
          onp.dtype(dtype).name, m, n, k),
       "m": m, "n": n, "k": k, "dtype": dtype, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for n in [0, 4]
      for m in [None, 0, 1, 3, 4]
      for k in list(range(-4, 4))))
  def testTri(self, m, n, k, dtype, rng):
    """lnp.tri(n, M=m, k=k) matches NumPy; it takes no array inputs.

    `rng` comes from the parameter dict but is unused because the function
    has no array arguments.
    """
    kwargs = dict(M=m, k=k, dtype=dtype)

    def onp_fun():
      return onp.tri(n, **kwargs)

    def lnp_fun():
      return lnp.tri(n, **kwargs)

    def no_args():
      return []

    self._CheckAgainstNumpy(onp_fun, lnp_fun, no_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, no_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_k={}".format(
          op, jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "op": op, "k": k,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 1]
      for op in ["tril", "triu"]
      for k in list(range(-3, 3))))
  def testTriLU(self, dtype, shape, op, k, rng):
    """lnp.tril / lnp.triu with diagonal offset `k` match NumPy."""
    onp_op = getattr(onp, op)
    lnp_op = getattr(lnp, op)

    def onp_fun(arg):
      return onp_op(arg, k=k)

    def lnp_fun(arg):
      return lnp_op(arg, k=k)

    def make_args():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "k": k, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) in (1, 2)]
      for k in list(range(-4, 4))))
  def testDiag(self, shape, dtype, k, rng):
    """lnp.diag (build from 1-D, extract from 2-D) with offset `k` matches NumPy."""
    def onp_fun(arg):
      return onp.diag(arg, k)

    def lnp_fun(arg):
      return lnp.diag(arg, k)

    def make_args():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_shape={}_offset={}_axis1={}_axis2={}".format(
               jtu.format_shape_dtype_string(shape, dtype), offset, axis1,
               axis2),
           dtype=dtype, shape=shape, offset=offset, axis1=axis1, axis2=axis2,
           rng=jtu.rand_default())
      for dtype in default_dtypes
      for shape in [s for s in all_shapes if len(s) >= 2]
      for axis1 in range(-len(shape), len(shape))
      for axis2 in [a for a in range(-len(shape), len(shape))
                    if a % len(shape) != axis1 % len(shape)]
      for offset in range(-4, 4)))
  def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng):
    """lnp.diagonal over every distinct (axis1, axis2) pair matches NumPy."""
    diag_args = (offset, axis1, axis2)

    def onp_fun(arg):
      return onp.diagonal(arg, *diag_args)

    def lnp_fun(arg):
      return lnp.diagonal(arg, *diag_args)

    def args_maker():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      # The formatted value is the dtype name, so label it as such (it was
      # previously mislabelled "_shape=").
      {"testcase_name": "_dtype={}_n={}".format(onp.dtype(dtype).name, n),
       "dtype": dtype, "n": n}
      for dtype in default_dtypes
      for n in list(range(4))))
  def testIdentity(self, n, dtype):
    """lnp.identity(n, dtype) must match onp.identity for small n, incl. 0."""
    onp_fun = lambda: onp.identity(n, dtype)
    lnp_fun = lambda: lnp.identity(n, dtype)
    # identity takes no array operands; n and dtype are static.
    args_maker = lambda: []
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format(
               jtu.format_shape_dtype_string(shape, dtype),
               out_dtype, offset, axis1, axis2),
           dtype=dtype, out_dtype=out_dtype, shape=shape, offset=offset,
           axis1=axis1, axis2=axis2, rng=jtu.rand_default())
      for dtype in default_dtypes
      for out_dtype in [None] + number_dtypes
      for shape in [s for s in all_shapes if len(s) >= 2]
      for (axis1, axis2) in itertools.combinations(range(len(shape)), 2)
      for offset in range(-4, 4)))
  def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng):
    """lnp.trace with explicit offset, axes and output dtype matches NumPy."""
    trace_args = (offset, axis1, axis2, out_dtype)

    def onp_fun(arg):
      return onp.trace(arg, *trace_args)

    def lnp_fun(arg):
      return lnp.trace(arg, *trace_args)

    def args_maker():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "dtypes": dtypes, "rng": rng}
      for dtypes in [
        [onp.float32],
        [onp.float32, onp.float32],
        [onp.float32, onp.int32, onp.float32],
        [onp.float32, onp.int64, onp.float32],
        [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for rng in [jtu.rand_default()]))
  def testStack(self, shape, dtypes, rng):
    """lnp.stack over same-shape operands of mixed dtypes matches onp.stack.

    The mixed-dtype lists exercise type promotion across stacked operands.
    """
    # A single argument: the list of arrays to stack.
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    # As in every other test here, the NumPy reference goes first and the lnp
    # implementation second (the two were previously swapped).
    self._CheckAgainstNumpy(onp.stack, lnp.stack, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_inshape={}_outdtype={}".format(
               jtu.format_shape_dtype_string(shape, fill_value_dtype),
               onp.dtype(out_dtype).name),
           shape=shape, fill_value_dtype=fill_value_dtype,
           out_dtype=out_dtype, rng=jtu.rand_default())
      for shape in array_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes))
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    """lnp.full with a traced scalar fill value and explicit dtype matches NumPy."""
    def onp_fun(fill_value):
      return onp.full(shape, fill_value, dtype=out_dtype)

    def lnp_fun(fill_value):
      return lnp.full(shape, fill_value, dtype=out_dtype)

    def args_maker():
      return [rng((), fill_value_dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_inshape={}_filldtype={}_outdtype={}".format(
               jtu.format_shape_dtype_string(shape, in_dtype),
               onp.dtype(fill_value_dtype).name,
               onp.dtype(out_dtype).name),
           shape=shape, in_dtype=in_dtype,
           fill_value_dtype=fill_value_dtype, out_dtype=out_dtype,
           rng=jtu.rand_default())
      for shape in array_shapes
      for in_dtype in default_dtypes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes))
  def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng):
    """lnp.full_like with an explicit output dtype matches onp.full_like."""
    def onp_fun(x, fill_value):
      return onp.full_like(x, fill_value, dtype=out_dtype)

    def lnp_fun(x, fill_value):
      return lnp.full_like(x, fill_value, dtype=out_dtype)

    def args_maker():
      # Draw the template array before the scalar fill value.
      template = rng(shape, in_dtype)
      fill = rng((), fill_value_dtype)
      return [template, fill]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_{}_axis={}_{}sections".format(
               jtu.format_shape_dtype_string(shape, dtype), axis,
               num_sections),
           shape=shape, num_sections=num_sections, axis=axis, dtype=dtype,
           rng=jtu.rand_default())
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes))
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    """lnp.split into a static number of equal sections matches onp.split."""
    def onp_fun(x):
      return onp.split(x, num_sections, axis=axis)

    def lnp_fun(x):
      return lnp.split(x, num_sections, axis=axis)

    def args_maker():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_inshape={}_outshape={}".format(
               jtu.format_shape_dtype_string(arg_shape, dtype),
               jtu.format_shape_dtype_string(out_shape, dtype)),
           arg_shape=arg_shape, out_shape=out_shape, dtype=dtype,
           rng=jtu.rand_default())
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
          ((), (1, 1, 1)),
          ((7, 0), (0, 42, 101)),
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ]))
  def testReshape(self, arg_shape, out_shape, dtype, rng):
    """lnp.reshape (incl. int, -1 and empty-array targets) matches onp.reshape."""
    def onp_fun(x):
      return onp.reshape(x, out_shape)

    def lnp_fun(x):
      return lnp.reshape(x, out_shape)

    def args_maker():
      return [rng(arg_shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      # expand_dims accepts any axis in [-(ndim + 1), ndim]. The previous range
      # (-ndim + 1, ndim) was empty for arg_shape=(), so the scalar case was
      # never generated, and it also skipped both extreme axes.
      for dim in range(-len(arg_shape) - 1, len(arg_shape) + 1)))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    """lnp.expand_dims at every valid static axis matches onp.expand_dims."""
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_inshape={}_axes=({},{})".format(
               jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
           arg_shape=arg_shape, dtype=dtype, ax1=ax1, ax2=ax2,
           rng=jtu.rand_default())
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes))
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    """lnp.swapaxes with static (possibly negative) axes matches NumPy."""
    def onp_fun(x):
      return onp.swapaxes(x, ax1, ax2)

    def lnp_fun(x):
      return lnp.swapaxes(x, ax1, ax2)

    def args_maker():
      return [rng(arg_shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_inshape={}_axis={}".format(
               jtu.format_shape_dtype_string(arg_shape, dtype), ax),
           arg_shape=arg_shape, dtype=dtype, ax=ax, rng=jtu.rand_default())
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes))
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    """lnp.squeeze with None, int and tuple axes matches onp.squeeze."""
    def onp_fun(x):
      return onp.squeeze(x, ax)

    def lnp_fun(x):
      return lnp.squeeze(x, ax)

    def args_maker():
      return [rng(arg_shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_arg{}".format(i), arg=arg)
      for i, arg in enumerate([
          [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ])))
  def testArray(self, arg):
    """lnp.array on nested lists (incl. mixed NumPy leaves) matches onp.array."""
    def args_maker():
      return [arg]

    self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

  def testArrayAsarrayMethod(self):
    """lnp.array should convert an object through its __asarray__ hook."""
    class arraylike(object):
      def __asarray__(self, dtype=None):
        return 3.
    a = arraylike()
    ans = lnp.array(a)
    # Use a unittest assertion instead of a bare `assert`: bare asserts are
    # stripped when Python runs with -O, which would silently disable the check.
    self.assertEqual(ans, 3.)

  def testAllClose(self):
    """lnp.allclose over paired tuples must agree between eager and jit."""
    rs = onp.random.RandomState(0)
    mat = rs.randn(2, 2)
    vec = rs.randn(2)

    def same(list1, list2):
      close = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      flags = [close(lhs, rhs) for lhs, rhs in zip(list1, list2)]
      return lnp.all(lnp.array(flags))

    jitted_same = api.jit(same)

    eager_equal = same((mat, vec), (mat, vec))
    jit_equal = jitted_same((mat, vec), (mat, vec))
    # Doubling the vector must break closeness under the 1e-3 tolerances.
    jit_differs = jitted_same((mat, vec), (mat, 2 * vec))

    self.assertTrue(eager_equal)
    self.assertTrue(jit_equal)
    self.assertFalse(jit_differs)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def testOnesBroadcastingConstantHandler(self):
    """Checks that adding a stride-zero lnp.ones constant emits a Broadcast.

    This test is currently disabled: skipTest runs unconditionally before
    anything else, so none of the body below executes until the TODO is
    resolved.
    """
    # TODO(mattjj): update this test for jax3
    self.skipTest("test needs jax3 update")

    def fun(x):
      # Expects lnp.ones to return a NumPy array with all-zero strides, i.e. a
      # broadcast view rather than a materialized buffer.
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      # NOTE(review): `x._node.c` reaches into tracer internals that appear to
      # predate the current tracing machinery (hence the skip); confirm before
      # re-enabling.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      # list.append returns None, so `or` always falls through to the real
      # Broadcast after recording the call.
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    # A (4,) array of ones plus a (3, 4) array of ones gives a (3, 4) array of 2s.
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    """A stride-zero (broadcast_to) NumPy constant captured by jit works."""
    base = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    # broadcast_to returns a view with zero strides on the expanded axes.
    const = onp.broadcast_to(base, (3, 2, 3, 4, 5, 6))

    @api.jit
    def fun(x):
      return x * const

    self.assertAllClose(fun(3.), 3. * const, check_dtypes=False)

  def testIsInstanceNdarrayDuringTracing(self):
    """Values seen inside a jitted function must be lnp.ndarray instances."""
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    api.jit(f)(onp.ones(3))


  def testNonArrayErrorMessage(self):
    """Passing a raw Python list where an array is required raises TypeError.

    Checked for both an elementwise op (add) and dot, eagerly and under jit.
    """
    lhs = [1., 2.]
    rhs = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    for fn in (g, f, api.jit(g), api.jit(f)):
      self.assertRaises(TypeError, fn, lhs, rhs)

  def testAbstractionErrorMessage(self):
    """Python control flow on jit-traced values must raise TypeError."""

    @api.jit
    def f(x, n):
      # range() needs a concrete integer, but n is traced under jit.
      for _ in range(n):
        x = x * x
      return x

    @api.jit
    def g(x):
      # Branching requires bool(x > 0.), which a traced value cannot provide.
      return x * 2 if x > 0. else x + 2

    self.assertRaises(TypeError, f, 3., 3)
    self.assertRaises(TypeError, g, 3.)

  def testTracingPrimitiveWithNoTranslationErrorMessage(self):
    """A function wrapped by lnp._not_implemented should fail only under jit.

    Currently disabled: skipTest runs unconditionally first, so the body below
    never executes until the TODO is addressed.
    """
    # TODO(mattjj): update this for jax3
    self.skipTest("test needs jax3 update")
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    # Tracing the wrapped function under jit is expected to raise
    # NotImplementedError.
    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_{}_axis={}".format(
               jtu.format_shape_dtype_string(shape, dtype), axis),
           rng=rng, shape=shape, dtype=dtype, axis=axis)
      for shape in [(3,), (2, 3)]
      for dtype in default_dtypes
      for axis in range(len(shape))
      for rng in [jtu.rand_default()]))
  def testFlip(self, shape, dtype, axis, rng):
    """lnp.flip along each single axis matches onp.flip."""
    def lnp_op(x):
      return lnp.flip(x, axis)

    def onp_op(x):
      return onp.flip(x, axis)

    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_{}_k={}_axes={}".format(
               jtu.format_shape_dtype_string(shape, dtype), k, axes),
           rng=rng, shape=shape, dtype=dtype, k=k, axes=axes)
      for shape, axes in [
          [(2, 3), (0, 1)],
          [(2, 3), (1, 0)],
          [(4, 3, 2), (0, 2)],
          [(4, 3, 2), (2, 1)],
      ]
      for k in range(-3, 4)
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testRot90(self, shape, dtype, k, axes, rng):
    """lnp.rot90 for k in [-3, 3] over several axis pairs matches onp.rot90."""
    def lnp_op(x):
      return lnp.rot90(x, k, axes)

    def onp_op(x):
      return onp.rot90(x, k, axes)

    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  # TODO(mattjj): test infix operator overrides

  def testRavel(self):
    """The .ravel() method on a traced array compiles and matches eager mode."""
    rs = onp.random.RandomState(0)

    def args_maker():
      return [rs.randn(3, 4).astype("float32")]

    def ravel(x):
      return x.ravel()

    self._CompileAndCheck(ravel, args_maker, check_dtypes=True)

  def testAstype(self):
    """x.astype(lnp.int32) behaves the same on NumPy and lnp arrays."""
    rs = onp.random.RandomState(0)

    def args_maker():
      return [rs.randn(3, 4).astype("float32")]

    def to_int32(x):
      return x.astype(lnp.int32)

    # The same method call serves as both the NumPy reference and the op.
    self._CheckAgainstNumpy(to_int32, to_int32, args_maker, check_dtypes=True)
    self._CompileAndCheck(to_int32, args_maker, check_dtypes=True)

  # TODO(mattjj): test other ndarray-like method overrides

  def testOnpMean(self):
    # from https://github.com/google/jax/issues/125
    x = lax.add(lnp.eye(3), 0.)
    ans = onp.mean(x)
    self.assertAllClose(ans, onp.array([1./3, 1./3, 1./3]), check_dtypes=False)

  # TODO(mattjj): more exhaustive arange tests
  def testArangeOnFloats(self):
    # from https://github.com/google/jax/issues/145
    expected = onp.arange(0.0, 1.0, 0.1)
    ans = lnp.arange(0.0, 1.0, 0.1)
    self.assertAllClose(expected, ans, check_dtypes=True)

  def testSortManually(self):
    # manual tests for sort are nice because we don't have to worry about ties.
    # lax.sort is tested combinatorially.
    ans = lnp.sort(onp.array([16, 15, 23, 42, 8, 4]))
    expected = onp.array([4, 8, 15, 16, 23, 42])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a, axis=None)
    expected = onp.array([[1, 1, 3, 4]])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a)  # last axis
    expected = onp.array([[1, 4], [1, 3]])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a, axis=0)
    expected = onp.array([[1, 1], [3, 4]])
    self.assertAllClose(expected, ans, check_dtypes=True)

  def testArgsortManually(self):
    x = onp.array([16, 15, 23, 42, 8, 4])
    ans = lnp.argsort(x)
    expected = onp.argsort(x)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=0)
    expected = onp.argsort(x, axis=0)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=1)
    expected = onp.argsort(x, axis=1)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=None)
    expected = onp.argsort(x, axis=None)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x)
    expected = onp.argsort(x)
    self.assertAllClose(expected, ans, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
      for shape in [(3,), (3, 4), (3, 4, 5)]
      for axis in itertools.chain(range(len(shape)), [-1], [None])
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testTakeAlongAxis(self, shape, dtype, axis, rng):
    """lnp.take_along_axis with argsort indices matches onp.take_along_axis.

    The NumPy comparison only runs when the installed NumPy provides
    take_along_axis; the compile check always runs.
    """
    def args_maker():
      x = rng(shape, dtype)
      # argsort indices are valid along `axis` (or the flattened array when
      # axis is None), matching what take_along_axis expects.
      i = onp.argsort(x, axis=axis)
      return x, i

    lnp_op = lambda x, i: lnp.take_along_axis(x, i, axis=axis)

    if hasattr(onp, "take_along_axis"):
      onp_op = lambda x, i: onp.take_along_axis(x, i, axis=axis)
      # As in every other test here, the NumPy reference goes first and the
      # lnp implementation second (the two were previously swapped).
      self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)