示例#1
0
    def test_while_batched(self, with_function=True):
        """A while loop with a (counter, accumulator) carry, vmapped twice.

        `product(x, y)` computes x * y by repeated addition inside a
        `lax.while_loop`. Vmapping it over `xs` and then over `ys` produces a
        batched while loop whose bound `y` differs per batch element. Both
        carry outputs are compared between JAX and the TF conversion.

        Args:
          with_function: if true, wrap the converted function in
            `tf.function` before calling it.
        """
        with jax_to_tf.enable_jit():

            def product(x, y):
                # Equivalent to "x * y" implemented as:
                #      res = 0.
                #      for(i=0; i < y; i++)
                #         res += x
                return lax.while_loop(
                    lambda idx_carry: idx_carry[0] < y, lambda idx_carry:
                    (idx_carry[0] + 1, idx_carry[1] + x), (0, 0.))

            # We use vmap to compute result[i, j] = i * j
            xs = np.arange(4, dtype=np.int32)
            ys = np.arange(5, dtype=np.int32)

            def product_xs_y(xs, y):
                return jax.vmap(product, in_axes=(0, None))(xs, y)

            def product_xs_ys(xs, ys):
                return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)

            f_jax = product_xs_ys
            f_tf = jax_to_tf.convert(f_jax)
            if with_function:
                f_tf = tf.function(f_tf)
            res_jax = f_jax(xs, ys)
            res_tf = f_tf(xs, ys)
            # zip() silently stops at the shorter sequence; make sure no
            # output is dropped from the comparison.
            self.assertEqual(len(res_jax), len(res_tf))
            for r_tf, r_jax in zip(res_tf, res_jax):
                np.testing.assert_allclose(r_tf, r_jax)
示例#2
0
 def test_gather(self):
     """Checks `jnp.take` along each axis against its TF conversion."""
     data = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
     idx = np.array([0, 1], dtype=np.int32)
     for ax in (0, 1):
         # The lambda is used only within this iteration, so capturing the
         # loop variable is safe.
         take_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=ax))  # pylint: disable=cell-var-from-loop
         take_tf = tf.function(jax_to_tf.convert(take_jax))
         expected = take_jax(data, idx)
         actual = take_tf(data, idx)
         np.testing.assert_allclose(expected, actual)
示例#3
0
 def test_unary_elementwise(self, f_jax=lax.abs):
     """Compares a unary op with its TF conversion on finite outputs.

     Non-finite entries are filtered out of each result independently before
     the comparison.
     """
     # NOTE(review): if the two results have non-finite values at different
     # positions, the filtered arrays are misaligned.
     inputs = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
                       dtype=np.float32)
     converted = tf.function(jax_to_tf.convert(f_jax))
     out_jax = f_jax(inputs)
     out_tf = converted(inputs)
     finite_jax = out_jax[np.isfinite(out_jax)]
     finite_tf = out_tf[np.isfinite(out_tf)]
     self.assertAllClose(finite_jax, finite_tf, atol=1e-4)
示例#4
0
 def test_bitwise_not(self):
     """Checks `lax.bitwise_not` on int32 inputs, which must match exactly."""
     x = np.array([-1, 3, -2, 0, 0, 2, 1, 3], dtype=np.int32)
     f_jax = jax.jit(lax.bitwise_not)
     f_tf = tf.function(jax_to_tf.convert(f_jax))
     r_jax = f_jax(x)
     r_tf = f_tf(x)
     # Integer results are always finite and a bitwise op is exact, so
     # compare the full arrays for equality rather than with a float
     # tolerance.
     np.testing.assert_array_equal(r_jax, r_tf)
示例#5
0
 def test_boolean_gather(self):
     """Checks `jnp.take` on a boolean array along each axis against TF."""
     # `np.bool` (a deprecated alias of builtin `bool`) raises AttributeError
     # on NumPy 1.24-1.26; `np.bool_` works on every NumPy version.
     values = np.array([[True, True], [False, True], [False, False]],
                       dtype=np.bool_)
     indices = np.array([0, 1], dtype=np.int32)
     for axis in [0, 1]:
         f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
         f_tf = tf.function(jax_to_tf.convert(f_jax))
         # Boolean results are compared exactly; a float tolerance is not
         # meaningful for them.
         np.testing.assert_array_equal(f_jax(values, indices),
                                       f_tf(values, indices))
示例#6
0
 def test_binary_logical_elementwise(self, f_jax):
     """Checks a binary op on uint32 inputs against its TF conversion.

     Args:
       f_jax: the binary JAX op under test. It is presumably a bitwise/logical
         op whose result is integral or boolean (confirm against the
         parameterization).
     """
     a = np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=np.uint32)
     b = np.array([1, 2, 3, 0, 1, 0, 2, 3], dtype=np.uint32)
     f_tf = tf.function(jax_to_tf.convert(f_jax))
     r_jax = f_jax(a, b)
     r_tf = f_tf(a, b)
     # Integer/boolean results are always finite and exact. assertAllClose's
     # relative tolerance could accept mismatches for large integer values,
     # so compare for exact equality.
     np.testing.assert_array_equal(r_jax, r_tf)
示例#7
0
 def test_concat(self):
     """Concatenates arrays of mixed dtypes and compares JAX with TF."""
     pieces = [np.array([1, 2], dtype=dtype)
               for dtype in (np.float32, np.int32, np.int8)]
     concat_jax = jax.jit(lambda parts: jnp.concatenate(parts, axis=0))
     concat_tf = tf.function(jax_to_tf.convert(concat_jax))
     np.testing.assert_allclose(concat_jax(pieces), concat_tf(pieces))
示例#8
0
 def test_gradients_disabled(self):
     """Taking a TF gradient through a converted function raises ValueError."""
     converted_tan = jax_to_tf.convert(jnp.tan)
     point = tf.ones([])
     with tf.GradientTape() as gradient_tape:
         gradient_tape.watch(point)
         out = converted_tan(point)
     expected_msg = 'jax2tf currently does not support gradients'
     with self.assertRaisesRegex(ValueError, expected_msg):
         gradient_tape.gradient(out, point)
示例#9
0
 def test_squeeze(self):
     """Squeezes unit axis 1, axis 3, and both, comparing JAX with TF."""
     in_shape = (2, 1, 3, 1)
     data = np.arange(np.prod(in_shape), dtype=np.float32).reshape(in_shape)
     for squeeze_dims in ((1,), (3,), (1, 3)):
         squeeze_jax = jax.jit(lambda v: jnp.squeeze(v, axis=squeeze_dims))  # pylint: disable=cell-var-from-loop
         squeeze_tf = tf.function(jax_to_tf.convert(squeeze_jax))
         np.testing.assert_allclose(squeeze_jax(data), squeeze_tf(data))
示例#10
0
文件: stax_test.py 项目: samuela/jax
 def test_res_net(self):
     """Runs ResNet50 inference directly in JAX and via its TF conversion."""
     rng = jax.random.PRNGKey(0)
     # NOTE(review): the axis layout of this shape depends on the conv
     # dimension numbers used inside `resnet50` — confirm.
     input_shape = (224, 224, 3, 1)
     init_fn, apply_fn = resnet50.ResNet50(1000)
     _, params = init_fn(rng, input_shape)
     infer = functools.partial(apply_fn, params)
     images = np.array(jax.random.normal(rng, input_shape))
     expected = infer(images)
     actual = jax_to_tf.convert(infer)(images)
     # NOTE(review): rtol=0.5 is very loose; the reason is not shown here.
     np.testing.assert_allclose(expected, actual, rtol=0.5)
示例#11
0
    def test_cond(self, with_function=False):
        """Converts a `lax.cond` and checks both predicate values.

        Args:
          with_function: if true, wrap the converted function in
            `tf.function` before calling it.
        """
        with jax_to_tf.enable_jit():

            def branchy(pred, operand):
                return lax.cond(pred, lambda on_true: on_true + 1.,
                                lambda on_false: on_false, operand)

            converted = jax_to_tf.convert(branchy)
            if with_function:
                converted = tf.function(converted)
            for pred in (True, False):
                np.testing.assert_allclose(converted(pred, 1.),
                                           branchy(pred, 1.))
示例#12
0
 def ConvertAndCompare(self, func_jax: Callable, *args,
                       with_function: bool = False,
                       atol=None,
                       rtol=None) -> Tuple[Any, Any]:
   """Runs `func_jax` and its TF conversion on `args` and compares them.

   Args:
     func_jax: the JAX function to convert.
     *args: positional arguments passed unchanged to both functions.
     with_function: if True, wrap the converted function in `tf.function`.
     atol: absolute tolerance forwarded to `assertAllClose`.
     rtol: relative tolerance forwarded to `assertAllClose`.

   Returns:
     The pair `(res_jax, res_tf)` of raw results.
   """
   converted = jax_to_tf.convert(func_jax)
   func_tf = tf.function(converted) if with_function else converted
   res_jax = func_jax(*args)
   res_tf = func_tf(*args)
   self.assertAllClose(res_jax, res_tf, atol=atol, rtol=rtol)
   return res_jax, res_tf
示例#13
0
 def test_type_promotion(self, f_jax):
     """Checks a binary op on every pair drawn from a few input dtypes."""
     f_tf = tf.function(jax_to_tf.convert(f_jax))
     # We only test a few types here, as tensorflow does not support many
     # types like uint* or bool in binary ops.
     types = [np.int32, np.int64, np.float32]
     dtype_pairs = [(lhs, rhs) for lhs in types for rhs in types]
     for x_dtype, y_dtype in dtype_pairs:
         lhs_arr = np.array([1, 2], dtype=x_dtype)
         rhs_arr = np.array([3, 4], dtype=y_dtype)
         np.testing.assert_allclose(f_jax(lhs_arr, rhs_arr),
                                    f_tf(lhs_arr, rhs_arr))
示例#14
0
 def test_trinary_elementwise(self, f_jax):
     """Compares a three-operand op with its TF conversion on finite outputs.

     Non-finite entries are filtered out of each result independently before
     the comparison.
     """
     first = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6],
                      dtype=np.float32)
     second = np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6],
                       dtype=np.float32)
     third = np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
                      dtype=np.float32)
     converted = tf.function(jax_to_tf.convert(f_jax))
     out_jax = f_jax(first, second, third)
     out_tf = converted(first, second, third)
     finite_jax = out_jax[np.isfinite(out_jax)]
     finite_tf = out_tf[np.isfinite(out_tf)]
     self.assertAllClose(finite_jax, finite_tf, atol=1e-4)
示例#15
0
 def testSavedModel(self):
     """Round-trips a converted function through a TF SavedModel.

     The converted function is exported with a scalar float32 input
     signature, saved to disk, and reloaded. Its result is checked against
     the JAX result both before and after the round trip.
     """
     f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
     model = tf.Module()
     model.f = tf.function(jax_to_tf.convert(f_jax),
                           input_signature=[tf.TensorSpec([], tf.float32)])
     # Build the input with the dtype declared in `input_signature`;
     # `np.array(0.7)` alone would be float64.
     x = np.array(0.7, dtype=np.float32)
     expected = f_jax(x)
     np.testing.assert_allclose(model.f(x), expected)
     # Roundtrip through saved model on disk.
     model_dir = os.path.join(absltest.get_default_test_tmpdir(),
                              str(id(model)))
     tf.saved_model.save(model, model_dir)
     restored_model = tf.saved_model.load(model_dir)
     np.testing.assert_allclose(restored_model.f(x), expected)
示例#16
0
 def test_prngsplit(self):
     """Compares `jax.random.split` with its TF conversion on edge-case keys."""
     split_jax = jax.jit(lambda key: jax.random.split(key, 2))
     split_tf = tf.function(jax_to_tf.convert(split_jax))
     max_u32 = 0xFFFFFFFF
     raw_pairs = ([0, 0], [max_u32, 0], [0, max_u32], [max_u32, max_u32])
     candidate_keys = [jax.random.PRNGKey(42)]
     candidate_keys += [np.array(pair, dtype=np.uint32) for pair in raw_pairs]
     for rng_key in candidate_keys:
         key_pairs = zip(split_jax(rng_key), split_tf(rng_key))
         for jax_key, tf_key in key_pairs:
             np.testing.assert_equal(jax_key, tf_key)
示例#17
0
    def test_while_single_carry(self, with_function=False):
        """A `lax.while_loop` whose carry is a single scalar.

        Args:
          with_function: if true, wrap the converted function in
            `tf.function` before calling it.
        """
        with jax_to_tf.enable_jit():

            def count_to_four(start):
                # Equivalent to:
                #      for(i=start; i < 4; i++);
                def keep_going(i):
                    return i < 4

                def step(i):
                    return i + 1

                return lax.while_loop(keep_going, step, start)

            converted = jax_to_tf.convert(count_to_four)
            if with_function:
                converted = tf.function(converted)
            np.testing.assert_allclose(count_to_four(0), converted(0))
示例#18
0
    def test_scan(self, with_function=False):
        """Converts a `lax.scan` that accumulates a dot product.

        The scan returns the final carry and the stacked carry values from
        before each step. Both outputs are compared between JAX and the TF
        conversion.

        Args:
          with_function: if true, wrap the converted function in
            `tf.function` before calling it.
        """
        def f_jax(xs, ys):
            # Equivalent to:
            #    res = 0.
            #    for x, y in zip(xs, ys):
            #      res += x * y
            def body(carry, inputs):
                x, y = inputs
                return carry + x * y, carry

            return lax.scan(body, 0., (xs, ys))

        f_tf = jax_to_tf.convert(f_jax)
        if with_function:
            f_tf = tf.function(f_tf)
        arg = np.arange(10, dtype=np.float32)
        res_jax = f_jax(arg, arg)
        res_tf = f_tf(arg, arg)
        # zip() silently stops at the shorter sequence; make sure no output
        # is dropped from the comparison.
        self.assertEqual(len(res_jax), len(res_tf))
        for r_jax, r_tf in zip(res_jax, res_tf):
            np.testing.assert_allclose(r_tf, r_jax)
示例#19
0
 def test_binary_elementwise(self, f_jax=lax.add):
     """Compares a binary op with its TF conversion on finite outputs.

     Non-finite entries are filtered out of each result independently before
     the comparison.
     """
     lhs = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
                    dtype=np.float32)
     rhs = np.array([-1.6, 1.4, 1.0, 0.0, 0.1, 0.2, 1, 1.4, -1.6],
                    dtype=np.float32)
     converted = tf.function(jax_to_tf.convert(f_jax))
     r_jax = f_jax(lhs, rhs)
     r_tf = converted(lhs, rhs)
     if f_jax in (lax.igamma, lax.igammac):
         # JAX outputs 0 and 1 instead of NaN for values outside the domain,
         # whereas TensorFlow does so for other argument combinations. Map
         # every 0 and 1 to NaN in both results so the finite filter below
         # drops them. np.copy also yields writeable arrays.
         r_jax = np.copy(r_jax)
         r_jax[(r_jax == 0) | (r_jax == 1)] = np.nan
         r_tf = np.copy(r_tf)
         r_tf[(r_tf == 0) | (r_tf == 1)] = np.nan
     finite_jax = r_jax[np.isfinite(r_jax)]
     finite_tf = r_tf[np.isfinite(r_tf)]
     self.assertAllClose(finite_jax, finite_tf, atol=1e-4)
示例#20
0
    def test_while(self, with_function=False):
        """A while loop whose cond and body functions capture array constants.

        Args:
          with_function: if true, wrap the converted function in
            `tf.function` before calling it.
        """
        with jax_to_tf.enable_jit():
            # Some constants to capture in the loop's cond and body functions
            cond_const = np.ones(3, dtype=np.float32)
            body_const1 = np.full_like(cond_const, 1.)
            body_const2 = np.full_like(cond_const, 2.)

            def func(x):
                # Equivalent to:
                #      c = [1, 1, 1]
                #      for(i=0; i < 3; i++)
                #        c += [1, 1, 1] + [2, 2, 2]
                #
                # The function is set-up so that it captures constants in the
                # body of the functionals. This covers some cases in the representation
                # of the lax.while primitive.
                def cond(idx_carry):
                    i, _ = idx_carry
                    return i < jnp.sum(lax.tie_in(
                        i, cond_const))  # Capture cond_const

                def body(idx_carry):
                    i, c = idx_carry
                    return (i + 1, c + body_const1 + body_const2)

                return lax.while_loop(cond, body, (0, x))

            f_jax = func
            f_tf = jax_to_tf.convert(f_jax)
            if with_function:
                f_tf = tf.function(f_tf)
            # Named `x_in` rather than `input` to avoid shadowing the builtin.
            x_in = cond_const
            res_jax = f_jax(x_in)
            res_tf = f_tf(x_in)
            # zip() silently stops at the shorter sequence; make sure no
            # output is dropped from the comparison.
            self.assertEqual(len(res_jax), len(res_tf))
            for r_jax, r_tf in zip(res_jax, res_tf):
                np.testing.assert_allclose(r_jax, r_tf)
示例#21
0
 def test_pad(self):
     """Pads with config (low=3, high=1, interior=2) and compares JAX with TF."""
     operand = np.array([1, 2], dtype=np.float32)
     padding_config = [(3, 1, 2)]
     pad_jax = jax.jit(lambda x: jax.lax.pad(x, 0.0, padding_config))
     pad_tf = tf.function(jax_to_tf.convert(pad_jax))
     np.testing.assert_allclose(pad_jax(operand), pad_tf(operand))
示例#22
0
 def test_reduce_ops_with_numerical_input(self, f_jax):
     """Compares `f_jax` on a float32 input with its TF conversion."""
     # The input is a one-element list holding the array; presumably f_jax
     # treats it as array-like — confirm against the parameterization.
     inputs = [np.array([1, 2, 3], dtype=np.float32)]
     converted = tf.function(jax_to_tf.convert(f_jax))
     expected = f_jax(inputs)
     np.testing.assert_allclose(expected, converted(inputs))
示例#23
0
 def test_cumulated_ops(self, f_jax):
     """Compares `f_jax` (presumably a cumulative op) with its TF conversion."""
     vec = np.array([1, 2, 3], dtype=np.float32)
     converted = tf.function(jax_to_tf.convert(f_jax))
     expected = f_jax(vec)
     actual = converted(vec)
     np.testing.assert_allclose(expected, actual)
示例#24
0
 def test_nested_jit(self):
     """Converts a jitted function that calls another jitted function."""
     point = 0.7
     nested = jax.jit(lambda v: jnp.sin(jax.jit(jnp.cos)(v)))
     converted = jax_to_tf.convert(nested)
     np.testing.assert_allclose(nested(point), converted(point))
示例#25
0
 def test_variable_input(self):
     """A converted function accepts a `tf.Variable` and returns a tf.Tensor."""
     def sin_of_cos(x):
         return jnp.sin(jnp.cos(x))

     converted = jax_to_tf.convert(sin_of_cos)
     var = tf.Variable(0.7)
     self.assertIsInstance(converted(var), tf.Tensor)
     self.assertAllClose(sin_of_cos(0.7), converted(var))
示例#26
0
 def test_gather_rank_change(self):
     """Indexes a 2-D array with a 2-D index array, yielding a rank-3 result."""
     table = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
     rows = jnp.array([[1, 1, 2], [0, 1, 0]])
     lookup_jax = jax.jit(lambda idx: table[idx])
     lookup_tf = tf.function(jax_to_tf.convert(lookup_jax))
     np.testing.assert_allclose(lookup_jax(rows), lookup_tf(rows))
示例#27
0
 def test_scatter_static(self, op):
     """Applies the indexed-update `op` on a static strided slice."""
     base = np.ones((5, 6), dtype=np.float32)
     scalar_update = np.float32(6.)
     target = jax.ops.index[::2, 3:]
     scatter_jax = jax.jit(lambda v, u: op(v, target, u))
     scatter_tf = tf.function(jax_to_tf.convert(scatter_jax))
     expected = scatter_jax(base, scalar_update)
     np.testing.assert_allclose(expected, scatter_tf(base, scalar_update))
示例#28
0
 def test_zeros_like(self):
     """Converts `jax.ad_util.zeros_like_jaxval` and checks a float32 scalar."""
     scalar = np.float32(2.)
     zeros_like = jax.ad_util.zeros_like_jaxval
     converted = jax_to_tf.convert(zeros_like)
     self.assertEqual(zeros_like(scalar), converted(scalar))
示例#29
0
 def test_stop_gradient(self):
     """Converted `lax.stop_gradient` passes the forward value through."""
     converted = jax_to_tf.convert(lax.stop_gradient)
     result = converted(tf.ones([]))
     self.assertEqual(result, 1.)
示例#30
0
 def test_reduce_ops_with_boolean_input(self, f_jax):
     """Compares `f_jax` on a boolean input with its TF conversion."""
     # `np.bool` (a deprecated alias of builtin `bool`) raises AttributeError
     # on NumPy 1.24-1.26; `np.bool_` works on every NumPy version.
     values = [np.array([True, False, True], dtype=np.bool_)]
     f_tf = tf.function(jax_to_tf.convert(f_jax))
     np.testing.assert_allclose(f_jax(values), f_tf(values))