Пример #1
0
 def testJVPOfGradOfIndexing(self):
     """JVP of grad through integer indexing with a concrete zero index tangent.

     The index tangent is an ordinary zeros array rather than a symbolic zero;
     the call must still return a value.
     """
     def loss(arr, idx):
         return jnp.sum(arr[idx])

     arr = jnp.ones((3, 4), jnp.float32)
     idx = jnp.ones((3, ), jnp.int32)
     out_primal, out_tangent = api.jvp(
         api.grad(loss), (arr, idx), (arr, onp.zeros_like(idx)))

     # Every index is 1, so row 1 is gathered three times (gradient 3) and
     # rows 0 and 2 are never touched (gradient 0).
     row_grad = onp.array([0, 3, 0], dtype=onp.float32)
     expected = onp.broadcast_to(row_grad[:, None], (3, 4))
     self.assertAllClose(expected, out_primal, check_dtypes=True)
     # sum(arr[idx]) is linear in arr, so its gradient is constant in arr and
     # the JVP tangent is zero.
     self.assertAllClose(onp.zeros_like(arr), out_tangent, check_dtypes=True)
Пример #2
0
 def test_vmap_in_axes_tree_prefix_error(self):
   """A 2-tuple `in_axes` for a single argument must raise ValueError.

   Regression test for https://github.com/google/jax/issues/795.
   """
   # Both building the vmapped function and calling it happen inside the
   # context, matching the original lambda-based assertion.
   with self.assertRaisesRegex(
       ValueError,
       'vmap in_axes specification must be a tree prefix of the corresponding '
       r'value, got specification \(0, 0\) for value tree '):
     extensions.vmap(lambda x: x, in_axes=(0, 0))(tf_np.ones(3))
Пример #3
0
  def testIssue187(self):
    """Regression test for issue 187: indexing with paired integer lists.

    Eager indexing must not crash, and the jitted result must match NumPy.
    """
    idx = [0, 2, 4]
    _ = jnp.ones((5, 5))[idx, idx]  # doesn't crash

    base = onp.arange(25).reshape((5, 5))
    ans = npe.jit(lambda arr: arr[idx, idx])(base)
    expected = base[idx, idx]
    self.assertAllClose(ans, expected, check_dtypes=False)
Пример #4
0
    def test_vmap_in_axes_list(self):
        """`in_axes` given as a list must behave like the equivalent tuple.

        Regression test for https://github.com/google/jax/issues/2367.
        """
        def add_all(dct, x, y):
            return dct['a'] + dct['b'] + x + y

        # The dict argument is unbatched (None); x and y are mapped on axis 0.
        args = ({'a': 5., 'b': tf_np.ones(2)}, tf_np.zeros(3), tf_np.arange(3.))

        out_tuple = extensions.vmap(add_all, (None, 0, 0))(*args)
        out_list = extensions.vmap(add_all, [None, 0, 0])(*args)
        self.assertAllClose(out_tuple, out_list)
Пример #5
0
  def testIndexingEmptyDimension(self):
    """Indexing an array whose axis 1 has size 0.

    Regression test for issue 2671 (XLA error when indexing into a dimension
    of size 0). Slices must work. An integer index on the empty axis must
    raise IndexError in NumPy, in eager JAX, and in JAX under jit.
    """
    x = jnp.ones((2, 0))
    # These index expressions are valid even though axis 1 has size 0.
    _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]

    numpy_pattern = "index .* is out of bounds for axis .* with size 0"
    jax_pattern = "index is out of bounds for axis .* with size 0"

    with self.assertRaisesRegex(IndexError, numpy_pattern):
      _ = onp.ones((2, 0))[0, 0]  # The numpy error
    with self.assertRaisesRegex(IndexError, jax_pattern):
      _ = x[0, 0]  # JAX indexing
    with self.assertRaisesRegex(IndexError, jax_pattern):
      npe.jit(lambda i: x[0, i])(0)  # JAX indexing under jit
Пример #6
0
    def testMap(self):
        """`tf_map` over a tuple of two (2, 3) int32 arrays.

        The mapped function must receive a tuple whose elements have shape
        `shape[1:]`. The overall result must be a list close to the
        expected values.
        """
        shape = [2, 3]
        row_shape = shape[1:]
        dtype = tf_np.int32
        zeros = tf_np.zeros(shape, dtype)
        ones = tf_np.ones(shape, dtype)
        ys_expected = [ones + 10, zeros + 20]

        def swap_and_shift(pair):
            self.assertIsInstance(pair, tuple)
            for elem in pair:
                self.assertEqual(elem.shape, row_shape)
            first, second = pair
            return [second + 10, first + 20]

        ys = extensions.tf_map(swap_and_shift, (zeros, ones))
        self.assertIsInstance(ys, list)
        self.assertAllClose(ys, ys_expected)
Пример #7
0
 def testIndexOutOfBounds(self):
   """Slicing past the end of an array clamps to its length instead of raising.

   Regression test for https://github.com/google/jax/issues/2245.
   """
   ones = jnp.ones(5)
   self.assertAllClose(ones, ones[:10], check_dtypes=True)