Пример #1
0
 def testIndexDtypeError(self):
     """Check that a strided `.at[...].set(...)` update emits no warning.

     Regression test for https://github.com/google/jax/issues/2795.
     """
     # Create a throwaway array first to get rid of the startup warning
     # before warnings are promoted to errors below.
     jnp.array(1)
     with warnings.catch_warnings(record=True) as caught:
         # From here on, any warning is raised as an exception.
         warnings.simplefilter("error")
         base = jnp.zeros(5)
         base.at[::2].set(1)
         self.assertLen(caught, 0)
Пример #2
0
 def testFloatIndexingError(self):
   """Check that float indices raise IndexError in onp, jnp and under jit."""
   pattern = "only integers, slices.*are valid indices"
   failing_lookups = (
       # Verify onp behavior
       lambda: onp.zeros((2, 2))[(0, 0.)],
       # Test jnp
       lambda: jnp.zeros(2)[0.],
       lambda: jnp.zeros((2, 2))[(0, 0.)],
       # Test with jit
       lambda: npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.)),
   )
   # Each lookup gets its own context so every case must raise on its own.
   for lookup in failing_lookups:
     with self.assertRaisesRegex(IndexError, pattern):
       lookup()
Пример #3
0
    def _testEvalOnShapes(self, transformer, allow_static_outputs):
        """Check the shapes, dtypes and static outputs returned by `transformer(f)`.

        Args:
          transformer: callable invoked as
            `transformer(f, static_argnums=(2,), allow_static_outputs=...)`;
            the returned callable is then called with the same arguments
            as `f`.
          allow_static_outputs: forwarded to `transformer`. When true, `f`
            also returns a `Thing`, and the int and `Thing` outputs are
            expected back as-is. When false, the second output is expected
            to be a scalar `tf.TensorSpec` or `tf_np.ndarray`.
        """

        # A class that's not convertible to tensor
        class Thing:
            def __init__(self, value):
                self.value = value

        def f(a, b, reverse=False):
            total = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
            outputs = (total, 10)
            if allow_static_outputs:
                outputs = outputs + (Thing(20), )
            return tuple(reversed(outputs)) if reverse else outputs

        f_prime = transformer(
            f,
            static_argnums=(2, ),
            allow_static_outputs=allow_static_outputs)
        shape = [10]
        dtype = np.float16
        a = tf_np.zeros(shape=shape, dtype=dtype)
        b = tf_np.zeros(shape=shape, dtype=dtype)
        # Only the first (tensor) output of the reference call is compared.
        expected = f(a, b)[0]

        def verify(result):
            head = result[0]
            self.assertIsInstance(head, (tf.TensorSpec, tf_np.ndarray))
            self.assertAllEqual(expected.shape, head.shape)
            self.assertAllEqual(expected.dtype, head.dtype)
            if not allow_static_outputs:
                self.assertIsInstance(result[1],
                                      (tf.TensorSpec, tf_np.ndarray))
                self.assertAllEqual((), result[1].shape)
                return
            self.assertIsInstance(result[1], int)
            self.assertEqual(10, result[1])
            self.assertIsInstance(result[2], Thing)
            self.assertEqual(20, result[2].value)

        # Call twice since the code path is different on second call
        for _ in range(2):
            verify(f_prime(a, b))
        # Retrace with reverse=True and check again (also twice); the
        # outputs come back reversed, so flip them before verifying.
        for _ in range(2):
            verify(tuple(reversed(f_prime(a, b, True))))
Пример #4
0
    def _testEvalOnShapes(self, transformer):
        """Check that `transformer(f)` yields the shape and dtype of `f`'s result.

        Args:
          transformer: callable invoked as `transformer(f)`; the returned
            callable is then called with the same arguments as `f`.
        """

        def f(a, b):
            return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

        f_prime = transformer(f)
        a, b = (tf_np.zeros(shape=[10], dtype=np.float16) for _ in range(2))
        expected = f(a, b)
        # Call twice since the code path is different on second call
        for _ in range(2):
            got = f_prime(a, b)
            self.assertAllEqual(expected.shape, got.shape)
            self.assertAllEqual(expected.dtype, got.dtype)
Пример #5
0
 def testFloatIndexingError(self):
     """Check that float indexers raise TypeError on every indexing path.

     Covers plain indexing with a scalar and with a tuple index, indexing
     inside `npe.jit`, and the `ops.index_add` / `ops.index_update`
     helpers.
     """
     bad_index_type_error = ("Indexer must have integer or boolean type, "
                             "got indexer with type")
     # The tuple-index case used to be listed twice verbatim; the second
     # copy exercised nothing new, so it appears only once here.
     failing_calls = (
         lambda: jnp.zeros(2)[0.],
         lambda: jnp.zeros((2, 2))[(0, 0.)],
         lambda: npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.)),
         lambda: ops.index_add(jnp.zeros(2), 0., 1.),
         lambda: ops.index_update(jnp.zeros(2), 0., 1.),
     )
     # Each call gets its own context so every case must raise on its own.
     for call in failing_calls:
         with self.assertRaisesRegex(TypeError, bad_index_type_error):
             call()
Пример #6
0
    def test_vmap_in_axes_list(self):
        """Check that `vmap` gives the same result for tuple and list `in_axes`.

        Regression test for https://github.com/google/jax/issues/2367.
        """

        def add_all(params, x, y):
            return params['a'] + params['b'] + x + y

        dictionary = {'a': 5., 'b': tf_np.ones(2)}
        x = tf_np.zeros(3)
        y = tf_np.arange(3.)

        # The two calls differ only in the container type of the axes spec.
        from_tuple = extensions.vmap(add_all, (None, 0, 0))(dictionary, x, y)
        from_list = extensions.vmap(add_all, [None, 0, 0])(dictionary, x, y)
        self.assertAllClose(from_tuple, from_list)
Пример #7
0
    def testMap(self):
        """Check `tf_map` over a tuple of arrays with a list-returning function.

        The mapped function asserts that it receives a tuple whose elements
        have the input shape minus its first dimension; the overall result
        must be a list matching the element-wise expected values.
        """
        shape = [2, 3]
        dtype = tf_np.int32
        xs1 = tf_np.zeros(shape, dtype)
        xs2 = tf_np.ones(shape, dtype)
        ys_expected = [xs2 + 10, xs1 + 20]
        element_shape = shape[1:]

        def f(x):
            self.assertIsInstance(x, tuple)
            for element in x:
                self.assertEqual(element.shape, element_shape)
            first, second = x
            # Outputs are deliberately swapped relative to the inputs.
            return [second + 10, first + 20]

        ys = extensions.tf_map(f, (xs1, xs2))
        self.assertIsInstance(ys, list)
        self.assertAllClose(ys, ys_expected)
Пример #8
0
 def f_prime(*args):
   """Replace each leaf of `eval_on_shapes(f, **kwargs)(*args)` with zeros.

   `f` and `kwargs` come from the enclosing scope. Every leaf of the
   result is swapped for a `tf_np.zeros` array of that leaf's shape and
   dtype.
   """
   def zeros_like_leaf(leaf):
     return tf_np.zeros(leaf.shape, leaf.dtype)

   shape_result = extensions.eval_on_shapes(f, **kwargs)(*args)
   return tf.nest.map_structure(zeros_like_leaf, shape_result)
Пример #9
0
 def zeros(x):
   """Map every leaf of the structure `x` to a float32 scalar zero."""
   def scalar_zero(_):
     # The leaf's own value is ignored; only the structure of `x` matters.
     return tf_np.zeros([], np.float32)

   return tf.nest.map_structure(scalar_zero, x)
Пример #10
0
 def f_prime(a, b):
     """Return zeros shaped and typed like `eval_on_shapes(f)(a, b)`.

     `f` comes from the enclosing scope.
     """
     spec = extensions.eval_on_shapes(f)(a, b)
     return tf_np.zeros(shape=spec.shape, dtype=spec.dtype)