def testJVPOfGradOfIndexing(self):
    """jvp of grad works when the index tangent is a concrete zero array.

    The integer index argument receives an ordinary zeros array as its
    tangent rather than a symbolic zero; the call must still return values.
    """
    operand = jnp.ones((3, 4), jnp.float32)
    index = jnp.ones((3, ), jnp.int32)

    def summed_gather(arr, idx):
        return jnp.sum(arr[idx])

    out_primal, out_tangent = api.jvp(
        api.grad(summed_gather),
        (operand, index),
        (operand, onp.zeros_like(index)))

    # Every entry of `index` is 1, so row 1 is gathered three times: its
    # gradient entries are 3 and the other rows' entries are 0.
    row_weights = onp.array([0, 3, 0], dtype=onp.float32)
    expected = onp.broadcast_to(row_weights[:, None], (3, 4))
    self.assertAllClose(expected, out_primal, check_dtypes=True)
    # The gradient of a sum of gathered entries does not depend on `operand`,
    # so its tangent is all zeros.
    self.assertAllClose(onp.zeros_like(operand), out_tangent, check_dtypes=True)
def test_vmap_in_axes_tree_prefix_error(self):
    """An in_axes spec that is not a tree prefix of the args raises ValueError.

    Regression test for https://github.com/google/jax/issues/795. Here
    in_axes has two entries while the mapped function gets a single argument.
    """
    expected_message = (
        'vmap in_axes specification must be a tree prefix of the corresponding '
        r'value, got specification \(0, 0\) for value tree ')
    identity = lambda x: x
    with self.assertRaisesRegex(ValueError, expected_message):
        extensions.vmap(identity, in_axes=(0, 0))(tf_np.ones(3))
def testIssue187(self):
    """Indexing with two integer lists works eagerly and under jit."""

    def take_diagonal_entries(arr):
        # Selects elements (0, 0), (2, 2) and (4, 4).
        return arr[[0, 2, 4], [0, 2, 4]]

    # Only checks that this does not raise; the result is discarded.
    take_diagonal_entries(jnp.ones((5, 5)))

    data = onp.arange(25).reshape((5, 5))
    jitted_result = npe.jit(take_diagonal_entries)(data)
    numpy_result = take_diagonal_entries(data)
    self.assertAllClose(jitted_result, numpy_result, check_dtypes=False)
def test_vmap_in_axes_list(self):
    """in_axes given as a list behaves like the equivalent tuple.

    Regression test for https://github.com/google/jax/issues/2367.
    """
    params = {'a': 5., 'b': tf_np.ones(2)}
    lhs = tf_np.zeros(3)
    rhs = tf_np.arange(3.)

    def add_all(dct, x, y):
        return dct['a'] + dct['b'] + x + y

    # The dict is broadcast (None); both arrays are mapped over axis 0.
    from_tuple = extensions.vmap(add_all, (None, 0, 0))(params, lhs, rhs)
    from_list = extensions.vmap(add_all, [None, 0, 0])(params, lhs, rhs)
    self.assertAllClose(from_tuple, from_list)
def testIndexingEmptyDimension(self):
    """Indexing an axis of size 0 raises IndexError instead of a backend error.

    Regression test for issue 2671 (XLA error when indexing into a dimension
    of size 0).
    """
    empty_cols = jnp.ones((2, 0))
    # Slices, newaxis and a scalar index on axis 0 all succeed even though
    # axis 1 has size 0.
    _ = (empty_cols[0, :] + empty_cols[0, None] + empty_cols[0, 1:]
         + empty_cols[0, 1:3:2])

    numpy_pattern = "index .* is out of bounds for axis .* with size 0"
    wrapped_pattern = "index is out of bounds for axis .* with size 0"

    # Reference: NumPy's own error for the same out-of-bounds access.
    with self.assertRaisesRegex(IndexError, numpy_pattern):
        _ = onp.ones((2, 0))[0, 0]
    # The same access on the library array, evaluated eagerly ...
    with self.assertRaisesRegex(IndexError, wrapped_pattern):
        _ = empty_cols[0, 0]
    # ... and inside a jitted function that receives the index as an argument.
    with self.assertRaisesRegex(IndexError, wrapped_pattern):
        npe.jit(lambda i: empty_cols[0, i])(0)
def testMap(self):
    """tf_map applies a function to matching rows of a tuple of arrays."""
    shape = [2, 3]
    dtype = tf_np.int32
    zeros_in = tf_np.zeros(shape, dtype)
    ones_in = tf_np.ones(shape, dtype)
    expected = [ones_in + 10, zeros_in + 20]

    def swap_and_offset(elems):
        # Each call receives a tuple with one slice per input array.
        self.assertIsInstance(elems, tuple)
        for elem in elems:
            # NOTE(review): compares the shape against a list; this relies on
            # the shape type comparing equal to lists (e.g. tf.TensorShape).
            # A plain tuple shape would not compare equal here.
            self.assertEqual(elem.shape, shape[1:])
        first, second = elems
        return [second + 10, first + 20]

    result = extensions.tf_map(swap_and_offset, (zeros_in, ones_in))
    self.assertIsInstance(result, list)
    self.assertAllClose(result, expected)
def testIndexOutOfBounds(self):
    """A slice whose stop exceeds the length returns the whole array.

    Regression test for https://github.com/google/jax/issues/2245.
    """
    vec = jnp.ones(5)
    clamped = vec[:10]
    self.assertAllClose(vec, clamped, check_dtypes=True)