def test_destructure(self):
  def d(key):
    key1, key2 = key
    return key1

  self.check(d, ['2'], '', {}, [(2,)], ['int_'],
             jtu.rand_int(self.rng(), 0, 10))
def testSegmentReduce(self, shape, dtype, reducer, op, identity, num_segments,
                      bucket_size):
  rng = jtu.rand_default(self.rng())
  idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
  args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]

  if np.issubdtype(dtype, np.integer):
    if np.isposinf(identity):
      identity = np.iinfo(dtype).max
    elif np.isneginf(identity):
      identity = np.iinfo(dtype).min

  jnp_fun = lambda data, segment_ids: reducer(
      data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)

  def np_fun(data, segment_ids):
    size = num_segments if num_segments is not None else (segment_ids.max() + 1)
    out = np.full((size,) + shape[1:], identity, dtype)
    for i, val in zip(segment_ids, data):
      if 0 <= i < size:
        out[i] = op(out[i], val).astype(dtype)
    return out

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  if num_segments is not None:
    self._CompileAndCheck(jnp_fun, args_maker)
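# The pure-NumPy np_fun above doubles as a readable spec of the segment-reduce
# semantics under test. A minimal sketch, assuming the sum-reducer
# configuration via jax.ops.segment_sum (values below are illustrative only):
#
#   data = jnp.array([1., 2., 3., 4.])
#   segment_ids = jnp.array([0, 0, 1, 1])
#   jax.ops.segment_sum(data, segment_ids, num_segments=2)  # -> [3., 7.]
#
# Ids outside [0, num_segments), which idx_rng can produce via its low=-2
# bound, are dropped, matching the `0 <= i < size` guard in np_fun.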
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
  rng = jtu.rand_default(self.rng())
  rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
  idxs = rng_idx(idxs.shape, idxs.dtype)
  scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
  x = rng(arg_shape, dtype)
  y = rng(update_shape, dtype)
  check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)
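# check_grads (from jax.test_util) compares forward- and reverse-mode autodiff
# of scatter_min against numerical differences, here up to second order with
# the given tolerances. The indices are closed over rather than passed as
# arguments, so only the operand and the updates are differentiated.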
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
  """Tests JIT compatibility and agreement with Numpy."""
  n_max = l_max
  shape = (num_z,)
  rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
  lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

  def args_maker():
    m = rng(shape, dtype)
    n = abs(m)
    theta = jnp.linspace(-4.0, 5.0, num_z)
    phi = jnp.linspace(-2.0, 1.0, num_z)
    return m, n, theta, phi

  with self.subTest('Test JIT compatibility'):
    self._CompileAndCheck(lsp_special_fn, args_maker)

  with self.subTest('Test against numpy.'):
    self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
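# Note on the argument convention: scipy.special.sph_harm takes
# (m, n, theta, phi) with |m| <= n, which is why args_maker draws m first and
# sets n = abs(m); the lsp_special wrapper is checked against that same
# convention above.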
class BatchingTest(jtu.JaxTestCase):

  def testConstantFunction(self):
    ans = vmap(lambda x: 3)(onp.ones(4))
    expected = 3 * onp.ones(4)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNestedBatchingMatMat(self):
    matvec = vmap(np.vdot, in_axes=(0, None))
    matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

    R = onp.random.RandomState(0).randn
    A = R(4, 3)
    B = R(3, 2)

    ans = matmat(A, B)
    expected = onp.dot(A, B)
    self.assertAllClose(ans, expected, check_dtypes=False)

    # this is a crude check that we only call a single dot
    def pv_like(x):
      aval = ShapedArray(onp.shape(x), onp.result_type(x))
      return pe.PartialVal((aval, unit))

    def make_jaxpr(fun, example_args):
      jaxpr, _, _, _ = trace_to_jaxpr(fun, map(pv_like, example_args))
      return jaxpr

    jaxpr = make_jaxpr(matmat, (A, B))
    self.assertEqual(len(jaxpr.eqns), 1)

  def testPerExampleGradients(self):
    def predict(params, inputs):
      for W, b in params:
        outputs = np.dot(W, inputs) + b
        inputs = np.tanh(outputs)
      return outputs

    def loss(params, data):
      inputs, targets = data
      predictions = predict(params, inputs)
      return np.sum((predictions - targets)**2)

    batch_size = 5
    layer_sizes = [3, 2, 4]

    R = onp.random.RandomState(0).randn
    params = [(R(m, n), R(m))
              for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

    input_vec = R(3)
    target_vec = R(4)
    datum = (input_vec, target_vec)

    input_batch = R(5, 3)
    target_batch = R(5, 4)
    batch = (input_batch, target_batch)

    ans = vmap(partial(grad(loss), params))(batch)

    for ans_pair, param_pair in zip(ans, params):
      dW, db = ans_pair
      W, b = param_pair
      self.assertEqual(dW.shape, (batch_size,) + W.shape)
      self.assertEqual(db.shape, (batch_size,) + b.shape)

  def testJacobians(self):
    def jacbwd(f, x):
      y, pullback = vjp(f, x)
      std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y))
      jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis)
      return jac_flat.reshape(onp.shape(y) + onp.shape(x))

    def jacfwd(f, x):
      pushfwd = lambda v: jvp(f, (x,), (v,))
      std_basis = onp.eye(onp.size(x)).reshape((-1,) + onp.shape(x))
      y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
      return jac_flat.reshape(onp.shape(y) + onp.shape(x))

    R = onp.random.RandomState(0).randn
    A = R(4, 3)
    b = R(4)
    f = lambda x: np.tanh(np.dot(A, x) + b)

    x = R(3)
    self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)

  def testBatchOfCompile(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x + x

    g = jit(vmap(f))
    self.assertAllClose(g(onp.ones(2)), 2 * onp.ones(2), check_dtypes=False)
    self.assertEqual(len(side), 1)
    self.assertAllClose(g(2 * onp.ones(2)), 4 * onp.ones(2),
                        check_dtypes=False)
    self.assertEqual(len(side), 1)

  def testSliceLax(self):
    fun = lambda x: lax.slice(x, (2,), (4,))
    R = onp.random.RandomState(0).randn
    x = R(5, 10)

    ans = vmap(fun)(x)
    expected_ans = x[:, 2:4]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testSliceNumpy(self):
    fun = lambda x: x[:, 2]
    R = onp.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, 2]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevLax(self):
    fun = lambda x: lax.rev(x, [0])
    R = onp.random.RandomState(0).randn
    x = R(2, 3)

    ans = vmap(fun)(x)
    expected_ans = x[:, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevNumpy(self):
    fun = lambda x: x[:, ::-1]
    R = onp.random.RandomState(0).randn
    x = R(3, 2, 4)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (2,), 2)(x)
    expected_ans = x[:, ::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpMaximum(self):
    fun = lambda x: np.maximum(x, 0.0)
    R = onp.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = onp.maximum(x, 0.0)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpGtrThan(self):
    R = onp.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(lambda x: x > 1.0)(x)
    expected_ans = x > 1.0
    self.assertAllClose(ans, expected_ans, check_dtypes=True)

  def testNpMaximumPerExampleGrad(self):
    R = onp.random.RandomState(0).randn
    x = R(10, 5)
    W = R(5, 5)

    fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0)**2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = np.transpose(W)
    for i in range(10):
      x_ex = x[i:i + 1]
      expected_ans = 2.0 * np.dot(
          np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex)
      expected_ans = np.transpose(expected_ans)
      self.assertAllClose(ans[i], expected_ans, check_dtypes=False)

  def testDotGeneral(self):
    R = onp.random.RandomState(0).randn

    x = R(10, 3, 4, 5)
    y = R(10, 3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun)(x, y)
    expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 10, 5)
    y = R(3, 10, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(2, 1))(x, y)
    expected = onp.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 5, 10)
    y = R(3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(3, None))(x, y)
    expected = onp.stack([fun(x[..., i], y) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 5)
    y = R(3, 5, 10, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(None, 2))(x, y)
    expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

  def testDot(self):
    # these tests are based on @shoyer's notebook studying gufuncs
    def vecvec(a, b):
      dot = np.dot
      for ndim in range(1, max(a.ndim, b.ndim)):
        a_ax = 0 if a.ndim > ndim else None
        b_ax = 0 if b.ndim > ndim else None
        dot = vmap(dot, in_axes=(a_ax, b_ax))
      return dot(a, b)

    assert vecvec(np.zeros((3,)), np.zeros((3,))).shape == ()
    assert vecvec(np.zeros((2, 3)), np.zeros((3,))).shape == (2,)
    assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2)

  def testDot2(self):
    R = onp.random.RandomState(0).randn
    xs = R(10, 3)
    ys = R(10, 3)
    ans = vmap(np.dot)(xs, ys)
    expected = onp.einsum('ni,ni->n', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDot3(self):
    R = onp.random.RandomState(0).randn
    xs = R(5, 8, 10)
    ys = R(10, 1)
    ans = vmap(np.dot, in_axes=(1, None))(xs, ys)
    expected = onp.einsum('inj,jk->nik', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testPad(self):
    R = onp.random.RandomState(0).randn

    fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)])
    x = R(5, 10).astype(onp.float32)
    ans = vmap(fun)(x)
    expected_ans = np.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)])
    x = R(5, 10, 3).astype(onp.float32)
    ans = vmap(fun)(x)
    expected_ans = np.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testConcatenate(self):
    R = lambda *shape: onp.random.RandomState(0).randn(*shape).astype(
        onp.float32)

    fun = lambda *args: lax.concatenate(args, dimension=0)
    x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
    ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
    expected_ans = onp.concatenate(
        [x, onp.swapaxes(y, 0, 1), onp.broadcast_to(z, (10, 4, 3))], 1)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    fun = lambda *args: lax.concatenate(args, dimension=1)
    x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
    ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
    expected_ans = onp.concatenate(
        [x, onp.broadcast_to(y, (10, 2, 3)), onp.moveaxis(z, 2, 0)], 2)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testJacobianIssue54(self):
    # test modeling the code in https://github.com/google/jax/issues/54
    def func(xs):
      return np.array([x for x in xs])

    xs = np.ones((5, 1))
    jacrev(func)(xs)  # don't crash
    jacfwd(func)(xs)  # don't crash

  def testAny(self):
    # test modeling the code in https://github.com/google/jax/issues/108
    ans = vmap(np.any)(np.array([[True, False], [False, False]]))
    expected = np.array([True, False])
    self.assertAllClose(ans, expected, check_dtypes=True)

  @jtu.skip_on_devices("tpu")
  def testHessian(self):
    # test based on code from sindhwani@google
    def fun(x, t):
      return np.sum(np.power(np.maximum(x, 0.0), 2)) + t

    x = onp.array([-1., -0.5, 0., 0.5, 1.0])

    ans = hessian(lambda x: fun(x, 0.0))(x)
    expected = onp.array([[0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0.5, 0., 0.],
                          [0., 0., 0., 2., 0.],
                          [0., 0., 0., 0., 2.]])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDynamicSlice(self):
    # test dynamic_slice via numpy indexing syntax
    x = onp.arange(30).reshape((10, 3))

    ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
    expected = x[:, 1]
    self.assertAllClose(ans, expected, check_dtypes=False)

    idx = onp.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
    expected = x[onp.arange(10), idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = onp.arange(3)
    idx = onp.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDynamicUpdateSlice(self):
    x = onp.random.randn(10, 3)
    y = onp.random.randn(10)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(0, 0, None))(x, y, 1)
    expected = x.copy()
    expected[:, 1] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = onp.random.randn(3)
    idx = onp.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(None, 0, 0))(x, y, idx)
    expected = onp.broadcast_to(x, (10, 3)).copy()
    expected[onp.arange(10), idx] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testRandom(self):
    seeds = vmap(random.PRNGKey)(onp.arange(10))
    ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
    expected = onp.stack([random.normal(random.PRNGKey(seed), (3, 2))
                          for seed in onp.arange(10)])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert len(onp.unique(ans)) == 10 * 3 * 2

  def testSortKeyVal(self):
    k = onp.arange(12)[::-1].reshape(3, 4)
    v = onp.random.RandomState(0).permutation(12).reshape(3, 4)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
    self.assertAllClose(sk, k[::-1, :], check_dtypes=True)
    self.assertAllClose(sv, v[::-1, :], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
    self.assertAllClose(sk, onp.broadcast_to(k[0, ::-1], (3, 4)),
                        check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, onp.broadcast_to(v[0, ::-1], (3, 4)),
                        check_dtypes=True)

  def testConvGeneralDilated(self):
    W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
    X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                   dimension_numbers)
      return y

    grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example = np.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [np.reshape(g, (1,) + g.shape)]
    per_example_direct = np.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

  def testConvGeneralDilatedBatchNotMajor(self):
    W = np.array(onp.random.randn(3, 3, 1, 4), dtype=onp.float32)
    x = np.array(onp.random.randn(3, 5, 7, 5, 1), dtype=onp.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
      y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                   dimension_numbers)
      return y

    per_example = vmap(partial(f, W))(x)
    per_example = np.reshape(np.transpose(per_example, (1, 2, 0, 3, 4)),
                             (5, 5, 21, 4))
    per_example_direct = f(
        W, np.reshape(np.transpose(x, (1, 0, 2, 3, 4)), (5, 21, 5, 1)))
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

  def testMaxPool(self):
    W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
    X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                   dimension_numbers)
      y = lax.reduce_window(y, -np.inf, lax.max, (1, 2, 2, 1), (1, 1, 1, 1),
                            'SAME')
      return y

    grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example = np.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [np.reshape(g, (1,) + g.shape)]
    per_example_direct = np.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

  def testSumPool(self):
    W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
    X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                   dimension_numbers)
      y = lax.reduce_window(y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1),
                            'SAME')
      return y

    grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example = np.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [np.reshape(g, (1,) + g.shape)]
    per_example_direct = np.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

  def testSelect(self):
    pred = onp.array([True, False])
    on_true = onp.array([0, 1])
    on_false = onp.array([2, 3])
    ans = vmap(lax.select)(pred, on_true, on_false)
    expected = onp.array([0, 3])
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = onp.array([False, True])
    on_true = onp.array([0, 1])
    on_false = onp.array([2, 3])
    ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
    expected = onp.array([[2, 3], [0, 1]])
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = True
    on_true = onp.array([0, 1], onp.float32)
    on_false = onp.array(3, onp.float32)
    ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
    expected = onp.array([0, 1], onp.float32)
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = onp.array([False, True])
    on_true = onp.array([0, 1], onp.float32)
    on_false = onp.array(3, onp.float32)
    ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
    expected = onp.array([3, 1], onp.float32)
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = onp.array([False, True])
    on_true = onp.array([2], onp.float32)
    on_false = onp.array([[3, 4]], onp.float32)
    ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
    expected = onp.array([[3, 2]], onp.float32)
    self.assertAllClose(ans, expected, check_dtypes=True)

  def testLaxLinalgCholesky(self):
    a = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
    a = onp.matmul(a, onp.conj(onp.swapaxes(a, -1, -2)))

    ans = vmap(lax_linalg.cholesky)(a)
    expected = onp.linalg.cholesky(a)
    self.assertAllClose(ans, expected, check_dtypes=False)

    b = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
    b = onp.matmul(b, onp.conj(onp.swapaxes(b, -1, -2)))
    b_trans = onp.swapaxes(b, 0, 1)  # shape is (5, 10, 5)

    ans = vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
    expected = onp.linalg.cholesky(b)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testLaxLinalgTriangularSolve(self):
    a = onp.random.RandomState(0).randn(4, 10, 4).astype(onp.float32)
    a += onp.eye(4, dtype=np.float32)[:, None, :]
    b = onp.random.RandomState(0).randn(5, 4, 10).astype(onp.float32)

    ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b)
    expected = onp.stack(
        [lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
    expected = onp.stack(
        [lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    ans = vmap(lax_linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
    expected = onp.stack(
        [lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs,
       "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng,
       "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), onp.array([[0], [2]]),
           lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1,)),
          (1, (10, 3), onp.array([[0], [0], [0]]),
           lax.GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(),
                                      start_index_map=(0,)),
           (2,)),
          (1, (10, 3, 5), onp.array([[0], [2], [1]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1, 3)),
          (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                               slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (axis, None))(operand, idxs)
    expected = onp.stack([fun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs,
       "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng,
       "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), onp.array([[0], [2]]),
           lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1,)),
          (1, (10, 3), onp.array([[0], [0], [0]]),
           lax.GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(),
                                      start_index_map=(0,)),
           (2,)),
          (1, (10, 3, 5), onp.array([[0], [2], [1]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1, 3)),
          (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                   slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (axis, None))(operand, idxs)
    expected = onp.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs,
       "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng,
       "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), onp.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1,)),
          (1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(),
                                      start_index_map=(0,)),
           (2,)),
          (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1, 3)),
          (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                               slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (None, axis))(operand, idxs)
    expected = onp.stack([fun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs,
       "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng,
       "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), onp.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1,)),
          (1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(),
                                      start_index_map=(0,)),
           (2,)),
          (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1, 3)),
          (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                   slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (None, axis))(operand, idxs)
    expected = onp.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name":
       "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
               idxs_axis, idxs, dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape,
       "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(),
                                      start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                            dnums, slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
    expected = onp.stack([
        fun(operand[(slice(None),) * op_axis + (i,)],
            idxs[(slice(None),) * idxs_axis + (i,)])
        for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name":
       "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
               idxs_axis, idxs, dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape,
       "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(),
                                      start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                                dnums, slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
    expected = onp.stack([
        gfun(operand[(slice(None),) * op_axis + (i,)],
             idxs[(slice(None),) * idxs_axis + (i,)])
        for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNumpyIndexing1(self):
    a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
    ind = onp.array([[0, 1], [2, 0]])

    def f(a, ind):
      return a[:, ind]

    expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
    ans = vmap(f, (None, 0))(a, ind)
    assert onp.all(ans == expected)

  def testNumpyIndexing2(self):
    a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

    def f(a):
      inds = np.array([0, 2])
      return a[:, inds]

    ans = vmap(f)(a)
    expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
    assert onp.all(ans == expected)

  def testTranspose(self):
    x = onp.arange(4 * 3 * 3).reshape((4, 3, 3))
    ans = vmap(lambda x: x + x.T)(x)
    expected = x + onp.swapaxes(x, -1, -2)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testTransposePermutation(self):
    x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x)
    expected = onp.transpose(x, (0, 2, 1, 3))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x)
    expected = onp.transpose(x, (0, 2, 3, 1))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
    ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x)
    expected = onp.transpose(x, (2, 1, 3, 0))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue354(self):
    psd_mat = onp.random.randn(20, 10)
    psd_mat = psd_mat.T.dot(psd_mat)
    vec = onp.random.randn(10)

    def f(scale):
      scaled_mat = scale * psd_mat
      chol = np.linalg.cholesky(scaled_mat)
      return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2)

    vmapped_f = vmap(f)
    vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x)))

    scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
    ans = vmapped_f_grad(scales)  # don't crash!
    expected = onp.stack([grad(f)(scale) for scale in scales])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue387(self):
    # https://github.com/google/jax/issues/387
    R = onp.random.RandomState(0).rand(100, 2)

    def dist_sq(R):
      dR = R[:, np.newaxis, :] - R[np.newaxis, :, :]
      zero = np.zeros_like(dR)
      dR = dR - np.where(np.abs(dR) < 0.5, zero, 0.5 * np.sign(dR))
      return np.sum(dR**2, axis=2)

    @jit
    def f(R):
      dr = dist_sq(R)
      return np.sum(R**2)

    H = hessian(f)(R)  # don't crash on UnshapedArray

  def testIssue489(self):
    def f(key):
      def body_fn(uk):
        key = uk[1]
        u = random.uniform(key, (), dtype=np.float64)
        key, _ = random.split(key)
        return u, key

      u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
      return u

    print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # no crash

  def testEmptyTuples(self):
    # Ensure there is no crash when a vectorized input contains empty tuples.
    result = vmap(lambda x, _: x + 1)(onp.array([0, 1]), ())
    self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False)

    # Ensure there is no crash when a vectorized output contains empty tuples.
    result, empty_tuple = vmap(lambda x: (x + 1, ()))(onp.array([0, 1]))
    self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False)
    self.assertEqual((), empty_tuple)

  def testIndexAddBatchedIndexesOnly(self):
    f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
    result = vmap(f, (None, 0, None))(onp.zeros((10,)), onp.arange(10), 1.)
    self.assertAllClose(result, onp.eye(10), check_dtypes=False)
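# A minimal reference for the vmap axis conventions exercised throughout
# BatchingTest (standard JAX API; shapes are illustrative only):
#
#   vmap(f)(xs)                         # map over axis 0 of every argument
#   vmap(f, in_axes=(0, None))(xs, y)   # map over axis 0 of xs; y is
#                                       # broadcast, not mapped
#   vmap(f, in_axes=1, out_axes=1)(xs)  # map over axis 1; place the mapped
#                                       # axis at position 1 of the output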
class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, dtypes), "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)} for shapes in filter( _shapes_are_broadcast_compatible, CombosWithReplacement(rec.shapes, rec.nargs)) for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)) for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS))) def testOp(self, onp_op, lnp_op, rng, shapes, dtypes): args_maker = self._GetArgsMaker(rng, shapes, dtypes) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( rec.test_name, shapes, dtypes), "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)} for shapes in filter( _shapes_are_broadcast_compatible, CombosWithReplacement(rec.shapes, rec.nargs)) for dtypes in filter( _dtypes_are_compatible_for_bitwise_ops, CombosWithReplacement(rec.dtypes, rec.nargs))) for rec in JAX_BITWISE_OP_RECORDS)) def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes): if not FLAGS.jax_enable_x64 and any( onp.iinfo(dtype).bits == 64 for dtype in dtypes): self.skipTest("x64 types are disabled by jax_enable_x64") args_maker = self._GetArgsMaker(rng, shapes, dtypes) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis, "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims), "rng": rec.rng, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis, "keepdims": keepdims} for rec in JAX_REDUCER_RECORDS for shape in rec.shapes for dtype in rec.dtypes for out_dtype in [None] + rec.dtypes for axis in set(range(-len(shape), len(shape))) | set([None]) for keepdims in [False, True])) def testReducer(self, onp_op, lnp_op, rng, shape, dtype, out_dtype, axis, keepdims): onp_fun = lambda x: onp_op(x, axis, dtype=out_dtype, keepdims=keepdims) lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), "rng": rec.rng, "shape": shape, "dtype": dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis, "keepdims": keepdims} for rec in JAX_REDUCER_NO_DTYPE_RECORDS for shape in rec.shapes for dtype in rec.dtypes for axis in set(range(-len(shape), len(shape))) | set([None]) for keepdims in [False, True])) def testReducerNoDtype(self, onp_op, lnp_op, rng, 
shape, dtype, axis, keepdims): onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims) lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}".format( jtu.format_shape_dtype_string(shape, dtype), axis), "shape": shape, "dtype": dtype, "axis": axis} for shape in all_shapes for dtype in all_dtypes for axis in set(range(-len(shape), len(shape))) | set([None]))) def testCountNonzero(self, shape, dtype, axis): rng = jtu.rand_some_zero() onp_fun = lambda x: onp.count_nonzero(x, axis) lnp_fun = lambda x: lnp.count_nonzero(x, axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "{}_inshape={}_axis={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis), "rng": rec.rng, "shape": shape, "dtype": dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis} for rec in JAX_ARGMINMAX_RECORDS for shape in rec.shapes for dtype in rec.dtypes for axis in range(-len(shape), len(shape)))) def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis): def onp_fun(array_to_reduce): return onp_op(array_to_reduce, axis) def lnp_fun(array_to_reduce): return lnp_op(array_to_reduce, axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}_{}".format( name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "rng": rng} for rng in [jtu.rand_default()] for name, lhs_shape, rhs_shape in [ ("matrix-scalar", (3, 3), ()), ("scalar-matrix", (), (3, 3)), ("matrix-vector", (4, 5), (5,)), ("vector-matrix", (6,), (6, 4)), ("matrix-matrix", (3, 4), (4, 5)), ("tensor-vector", (4, 3, 2), (2,)), ("vector-tensor", (2,), (3, 2, 4)), ("tensor-matrix", (4, 3, 2), (2, 5)), ("matrix-tensor", (5, 2), (3, 2, 4)), ("tensor-tensor", (2, 3, 4), (5, 4, 1))] for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2))) def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}_{}".format( name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "rng": rng} for rng in [jtu.rand_default()] for name, lhs_shape, rhs_shape in [ ("vector-vector", (3,), (3,)), ("matrix-vector", (3, 3), (3,)), ("vector-matrix", (3,), (3, 3)), ("matrix-matrix", (3, 3), (3, 3)), ("vector-tensor", (3,), (5, 3, 2)), ("tensor-vector", (5, 3, 2), (2,)), ("matrix-tensor", (5, 2), (3, 2, 4)), ("tensor-matrix", (5, 2, 3), (3, 2)), ("tensor-tensor", (5, 3, 
4), (5, 4, 1)), ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))] for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2))) def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}_{}".format( jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), axes), "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "axes": axes, "rng": rng} for rng in [jtu.rand_default()] for lhs_shape, rhs_shape, axes in [ [(2, 3, 4), (3, 4, 5, 6), 2], [(2, 3, 4), (5, 4, 3, 6), [1, 2]], [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], ] for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2))) def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] lnp_fun = lambda a, b: lnp.tensordot(a, b, axes) onp_fun = lambda a, b: onp.tensordot(a, b, axes) self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}".format( jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "rng": jtu.rand_default()} # TODO(phawkins): support integer dtypes too. 
for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2) for lhs_shape, rhs_shape in [ (l, r) for l, r in CombosWithReplacement(all_shapes, 2) if len(jtu._dims_of_shape(l)) == 0 or len(jtu._dims_of_shape(r)) == 0 or l[-1] == r[-1]])) def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] onp_fun = lambda lhs, rhs: onp.inner(lhs, rhs) lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs) self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_amin={}_amax={}".format( jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max, "rng": jtu.rand_default()} for shape in all_shapes for dtype in number_dtypes for a_min, a_max in [(-1, None), (None, 1), (-1, 1)])) def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng): onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_decimals={}".format( jtu.format_shape_dtype_string(shape, dtype), decimals), "shape": shape, "dtype": dtype, "decimals": decimals, "rng": jtu.rand_default()} for shape in all_shapes for dtype in number_dtypes for decimals in [0, 1, -2])) def testRoundStaticDecimals(self, shape, dtype, decimals, rng): if onp.issubdtype(dtype, onp.integer) and decimals < 0: self.skipTest("Integer rounding with decimals < 0 not implemented") onp_fun = lambda x: onp.round(x, decimals=decimals) lnp_fun = lambda x: lnp.round(x, decimals=decimals) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_rpadwidth={}_rconstantvalues={}".format( jtu.format_shape_dtype_string(shape, dtype), pad_width_rank, constant_values_rank), "shape": shape, "dtype": dtype, "pad_width_rank": pad_width_rank, "constant_values_rank": constant_values_rank, "rng": jtu.rand_default(), "irng": jtu.rand_int(3)} for shape in all_shapes for dtype in all_dtypes for pad_width_rank in range(3) for constant_values_rank in range(3))) def testPad(self, shape, dtype, pad_width_rank, constant_values_rank, rng, irng): pad_width = irng([len(shape), 2][2 - pad_width_rank:], onp.int32) def onp_fun(x, constant_vals): if pad_width.size == 0: return x return onp.pad(x, pad_width, mode='constant', constant_values=constant_vals) def lnp_fun(x, constant_vals): return lnp.pad(x, pad_width, mode='constant', constant_values=constant_vals) def args_maker(): constant_vals = rng([len(shape), 2][2 - constant_values_rank:], dtype) return rng(shape, dtype), constant_vals self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( axis, ",".join(str(d) for d in base_shape), ",".join(onp.dtype(dtype).name for dtype in dtypes)), "axis": axis, "base_shape": base_shape, "dtypes": dtypes, "rng": 
jtu.rand_default()} for num_arrs in [3] for dtypes in CombosWithReplacement(default_dtypes, num_arrs) for base_shape in [(4,), (3, 4), (2, 3, 4)] for axis in range(-len(base_shape)+1, len(base_shape)))) def testConcatenate(self, axis, base_shape, dtypes, rng): wrapped_axis = axis % len(base_shape) shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)] onp_fun = lambda *args: onp.concatenate(args, axis=axis) lnp_fun = lambda *args: lnp.concatenate(args, axis=axis) def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( axis, ",".join(str(d) for d in base_shape), ",".join(onp.dtype(dtype).name for dtype in dtypes)), "axis": axis, "base_shape": base_shape, "dtypes": dtypes, "rng": jtu.rand_default()} for dtypes in CombosWithReplacement(default_dtypes, 2) for base_shape in [(4,), (3, 4), (2, 3, 4)] for axis in range(-len(base_shape)+1, len(base_shape)))) def testAppend(self, axis, base_shape, dtypes, rng): wrapped_axis = axis % len(base_shape) shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)] onp_fun = lambda arr, values: onp.append(arr, values, axis=axis) lnp_fun = lambda arr, values: lnp.append(arr, values, axis=axis) def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, repeats), "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats, "rng": jtu.rand_default()} for repeats in [0, 1, 2] for dtype in default_dtypes for shape in all_shapes for axis in [None] + list(range(-len(shape), len(shape))))) def testRepeat(self, axis, shape, dtype, repeats, rng): onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis) lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}_m={}_n={}_k={}".format( onp.dtype(dtype).name, m, n, k), "m": m, "n": n, "k": k, "dtype": dtype, "rng": jtu.rand_default()} for dtype in default_dtypes for n in [0, 4] for m in [None, 0, 1, 3, 4] for k in list(range(-4, 4)))) def testTri(self, m, n, k, dtype, rng): onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype) lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype) args_maker = lambda: [] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_k={}".format( op, jtu.format_shape_dtype_string(shape, dtype), k), "dtype": dtype, "shape": shape, "op": op, "k": k, "rng": jtu.rand_default()} for dtype in default_dtypes for shape in [shape for shape in all_shapes if len(shape) >= 1] for op in ["tril", "triu"] for k in list(range(-3, 3)))) 
def testTriLU(self, dtype, shape, op, k, rng): onp_fun = lambda arg: getattr(onp, op)(arg, k=k) lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_k={}".format( jtu.format_shape_dtype_string(shape, dtype), k), "dtype": dtype, "shape": shape, "k": k, "rng": jtu.rand_default()} for dtype in default_dtypes for shape in [shape for shape in all_shapes if len(shape) in (1, 2)] for k in list(range(-4, 4)))) def testDiag(self, shape, dtype, k, rng): onp_fun = lambda arg: onp.diag(arg, k) lnp_fun = lambda arg: lnp.diag(arg, k) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2), "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1, "axis2": axis2, "rng": jtu.rand_default()} for dtype in default_dtypes for shape in [shape for shape in all_shapes if len(shape) >= 2] for axis1 in range(-len(shape), len(shape)) for axis2 in [a for a in range(-len(shape), len(shape)) if a % len(shape) != axis1 % len(shape)] for offset in list(range(-4, 4)))) def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng): onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2) lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_n={}".format(onp.dtype(dtype).name, n), "dtype": dtype, "n": n} for dtype in default_dtypes for n in list(range(4)))) def testIdentity(self, n, dtype): onp_fun = lambda: onp.identity(n, dtype) lnp_fun = lambda: lnp.identity(n, dtype) args_maker = lambda: [] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format( jtu.format_shape_dtype_string(shape, dtype), out_dtype, offset, axis1, axis2), "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset, "axis1": axis1, "axis2": axis2, "rng": jtu.rand_default()} for dtype in default_dtypes for out_dtype in [None] + number_dtypes for shape in [shape for shape in all_shapes if len(shape) >= 2] for (axis1, axis2) in itertools.combinations(range(len(shape)), 2) for offset in list(range(-4, 4)))) def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng): onp_fun = lambda arg: onp.trace(arg, offset, axis1, axis2, out_dtype) lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format( jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)), "shape": shape, "dtypes": dtypes, "rng": rng} for dtypes in [ [onp.float32], 
[onp.float32, onp.float32], [onp.float32, onp.int32, onp.float32], [onp.float32, onp.int64, onp.float32], [onp.float32, onp.int32, onp.float64], ] for shape in [(), (2,), (3, 4), (1, 100)] for rng in [jtu.rand_default()])) def testStack(self, shape, dtypes, rng): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] self._CheckAgainstNumpy(lnp.stack, onp.stack, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outdtype={}".format( jtu.format_shape_dtype_string(shape, fill_value_dtype), onp.dtype(out_dtype).name), "shape": shape, "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype, "rng": jtu.rand_default()} for shape in array_shapes for fill_value_dtype in default_dtypes for out_dtype in default_dtypes)) def testFull(self, shape, fill_value_dtype, out_dtype, rng): onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype) lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype) args_maker = lambda: [rng((), fill_value_dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format( jtu.format_shape_dtype_string(shape, in_dtype), onp.dtype(fill_value_dtype).name, onp.dtype(out_dtype).name), "shape": shape, "in_dtype": in_dtype, "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype, "rng": jtu.rand_default()} for shape in array_shapes for in_dtype in default_dtypes for fill_value_dtype in default_dtypes for out_dtype in default_dtypes)) def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng): onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype) lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype) args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}_{}sections".format( jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), "shape": shape, "num_sections": num_sections, "axis": axis, "dtype": dtype, "rng": jtu.rand_default()} for shape, axis, num_sections in [ ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] for dtype in default_dtypes)) def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng): onp_fun = lambda x: onp.split(x, num_sections, axis=axis) lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype)), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "rng": jtu.rand_default()} for dtype in default_dtypes for arg_shape, out_shape in [ (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), ((), (1, 1, 1)), ((7, 0), (0, 42, 101)), ((3, 4), 12), ((3, 4), (12,)), ((3, 4), -1), ((2, 1, 4), (-1,)), ((2, 2, 4), (2, 8)) ])) def testReshape(self, arg_shape, out_shape, dtype, rng): onp_fun = lambda x: onp.reshape(x, out_shape) lnp_fun = lambda 
x: lnp.reshape(x, out_shape) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_expanddim={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), dim), "arg_shape": arg_shape, "dtype": dtype, "dim": dim, "rng": jtu.rand_default()} for arg_shape in [(), (3,), (3, 4)] for dtype in default_dtypes for dim in range(-len(arg_shape)+1, len(arg_shape)))) def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng): onp_fun = lambda x: onp.expand_dims(x, dim) lnp_fun = lambda x: lnp.expand_dims(x, dim) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_axes=({},{})".format( jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2), "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2, "rng": jtu.rand_default()} for arg_shape, ax1, ax2 in [ ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] for dtype in default_dtypes)) def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng): onp_fun = lambda x: onp.swapaxes(x, ax1, ax2) lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_axis={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), ax), "arg_shape": arg_shape, "dtype": dtype, "ax": ax, "rng": jtu.rand_default()} for arg_shape, ax in [ ((3, 1), None), ((3, 1), 1), ((1, 3, 1), (0, 2)), ((1, 4, 1), (0,))] for dtype in default_dtypes)) def testSqueeze(self, arg_shape, dtype, ax, rng): onp_fun = lambda x: onp.squeeze(x, ax) lnp_fun = lambda x: lnp.squeeze(x, ax) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_arg{}".format(i), "arg": arg} for i, arg in enumerate([ [1, 2, 3], [1., 2., 3.], [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]], [[3, onp.array(2), 1], onp.arange(3.)], ]))) def testArray(self, arg): args_maker = lambda: [arg] self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True) def testArrayAsarrayMethod(self): class arraylike(object): def __asarray__(self, dtype=None): return 3. a = arraylike() ans = lnp.array(a) assert ans == 3. 
  def testAllClose(self):
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def testOnesBroadcastingConstantHandler(self):
    # TODO(mattjj): update this test for jax3
    self.skipTest("test needs jax3 update")

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for
      # stride-zero arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting
      # facilities, we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."
      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)

  def testIsInstanceNdarrayDuringTracing(self):
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)

  def testNonArrayErrorMessage(self):
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

  def testAbstractionErrorMessage(self):
    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))
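
  # A minimal illustrative sketch of the usual fix for the abstraction error
  # checked above, assuming api.jit supports static_argnums: marking `n`
  # static makes it a concrete Python int at trace time, so the loop can
  # unroll. The test name is ours.
  def testAbstractionWorkaroundSketch(self):
    @functools.partial(api.jit, static_argnums=(1,))
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    # (3^2)^2 == 81; no TypeError because `n` is not traced.
    self.assertAllClose(f(3., 2), 81., check_dtypes=False)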
  def testTracingPrimitiveWithNoTranslationErrorMessage(self):
    # TODO(mattjj): update this for jax3
    self.skipTest("test needs jax3 update")
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
      for shape in [(3,), (2, 3)]
      for dtype in default_dtypes
      for axis in range(len(shape))
      for rng in [jtu.rand_default()]))
  def testFlip(self, shape, dtype, axis, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.flip(x, axis)
    onp_op = lambda x: onp.flip(x, axis)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_k={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, axes),
       "rng": rng, "shape": shape, "dtype": dtype, "k": k, "axes": axes}
      for shape, axes in [
          [(2, 3), (0, 1)],
          [(2, 3), (1, 0)],
          [(4, 3, 2), (0, 2)],
          [(4, 3, 2), (2, 1)],
      ]
      for k in range(-3, 4)
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testRot90(self, shape, dtype, k, axes, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.rot90(x, k, axes)
    onp_op = lambda x: onp.rot90(x, k, axes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  # TODO(mattjj): test infix operator overrides

  def testRavel(self):
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)

  def testAstype(self):
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    op = lambda x: x.astype(lnp.int32)
    self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  # TODO(mattjj): test other ndarray-like method overrides

  def testOnpMean(self):
    # from https://github.com/google/jax/issues/125
    x = lax.add(lnp.eye(3), 0.)
    ans = onp.mean(x)
    # onp.mean of the 3x3 identity is the scalar 1/3.
    self.assertAllClose(ans, onp.array(1./3), check_dtypes=False)

  # TODO(mattjj): more exhaustive arange tests
  def testArangeOnFloats(self):
    # from https://github.com/google/jax/issues/145
    expected = onp.arange(0.0, 1.0, 0.1)
    ans = lnp.arange(0.0, 1.0, 0.1)
    self.assertAllClose(expected, ans, check_dtypes=True)
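
  # A minimal illustrative sketch of why the float arange above deserves a
  # regression test: with non-integer steps, rounding can perturb both the
  # values and the element count (a documented numpy caveat), so linspace is
  # often the safer spelling. The test name is ours; it assumes lnp.arange
  # matches onp.arange's length rule here.
  def testArangeVsLinspaceSketch(self):
    ans = lnp.arange(0.0, 1.0, 0.1)
    self.assertEqual(ans.shape, (10,))  # ceil((1.0 - 0.0) / 0.1) elements
    self.assertAllClose(ans, onp.linspace(0.0, 0.9, 10), check_dtypes=False)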
  def testSortManually(self):
    # manual tests for sort are nice because we don't have to worry about
    # ties. lax.sort is tested combinatorially.
    ans = lnp.sort(onp.array([16, 15, 23, 42, 8, 4]))
    expected = onp.array([4, 8, 15, 16, 23, 42])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a, axis=None)
    expected = onp.array([1, 1, 3, 4])  # axis=None sorts the flattened array
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a)  # last axis
    expected = onp.array([[1, 4], [1, 3]])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a, axis=0)
    expected = onp.array([[1, 1], [3, 4]])
    self.assertAllClose(expected, ans, check_dtypes=True)

  def testArgsortManually(self):
    x = onp.array([16, 15, 23, 42, 8, 4])
    ans = lnp.argsort(x)
    expected = onp.argsort(x)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=0)
    expected = onp.argsort(x, axis=0)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=1)
    expected = onp.argsort(x, axis=1)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=None)
    expected = onp.argsort(x, axis=None)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x)
    expected = onp.argsort(x)
    self.assertAllClose(expected, ans, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
      for shape in [(3,), (3, 4), (3, 4, 5)]
      for axis in itertools.chain(range(len(shape)), [-1], [None])
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testTakeAlongAxis(self, shape, dtype, axis, rng):
    def args_maker():
      x = rng(shape, dtype)
      i = onp.argsort(x, axis=axis)
      return x, i

    lnp_op = lambda x, i: lnp.take_along_axis(x, i, axis=axis)

    if hasattr(onp, "take_along_axis"):
      # onp.take_along_axis was added in numpy 1.15, so only compare against
      # it when it is available.
      onp_op = lambda x, i: onp.take_along_axis(x, i, axis=axis)
      self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
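
  # A minimal illustrative sketch of the identity behind args_maker above,
  # assuming lnp.take_along_axis matches onp's contract: gathering with
  # argsort indices reproduces the sorted array. The test name is ours.
  def testTakeAlongAxisSortIdentitySketch(self):
    x = onp.random.RandomState(0).randn(3, 4).astype("float32")
    i = onp.argsort(x, axis=1)
    self.assertAllClose(lnp.take_along_axis(x, i, axis=1),
                        onp.sort(x, axis=1), check_dtypes=False)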