Code example #1
File: call_tf_test.py, Project: zizai/jax
    def test_grad_pytree(self, with_jit=False):
        def fun_tf(x: Dict, y: Tuple):
            return x["first"] * x["second"] + 3. * y[0] + 4. * y[1]

        x = dict(first=np.float32(3.), second=np.float32(4.))
        y = (np.float32(5.), np.float32(6.))
        grad_x = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(fun_tf)))(x, y)
        self.assertAllClose(dict(first=np.float32(4.), second=np.float32(3.)),
                            grad_x)
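Note: these snippets are excerpts from JAX's call_tf_test.py and omit the module's shared scaffolding. Below is a minimal sketch of the imports and the _maybe_jit helper they rely on; exact import paths (e.g. for jtu and tf_test_util) vary across JAX versions, so treat this as an assumption rather than the upstream code.

    from typing import Callable, Dict, Tuple
    import unittest

    import numpy as np
    import tensorflow as tf

    import jax
    from jax import lax
    import jax.numpy as jnp
    from jax.experimental import jax2tf
    from jax import test_util as jtu                         # path differs by version
    from jax.experimental.jax2tf.tests import tf_test_util   # path differs by version


    def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
        # Wrap func in jax.jit only when the parameterized test variant asks for it.
        return jax.jit(func) if with_jit else func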
Code example #2
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_with_value_capture(self, with_jit=True):
    outer_val = np.array(3., dtype=np.float32)

    def fun_tf(x):
      return x * outer_val + 1.

    x = np.float32(2.)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
Code example #3
File: call_tf_test.py, Project: zizai/jax
    def test_eval_devicearray_no_copy(self):
        if jtu.device_under_test() != "cpu":
            # TODO(necula): add tests for GPU and TPU
            raise unittest.SkipTest("no_copy test works only on CPU")
        # For a DeviceArray, zero-copy works even if the data is not aligned
        x = jnp.ones((3, 3), dtype=np.float32)
        res = jax2tf.call_tf(lambda x: x)(x)
        self.assertAllClose(x, res)
        self.assertTrue(np.shares_memory(x, res))
Code example #4
    def test_bool(self, with_jit=False):
        def fun_tf(x, y):
            return tf.math.logical_and(x, y)

        x = np.array([True, False, True, False], dtype=np.bool_)
        y = np.array([True, True, False, False], dtype=np.bool_)
        res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x, y)
        self.assertAllClose(
            np.array([True, False, False, False], dtype=np.bool_), res)
Code example #5
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_with_tensor_capture_x64(self, with_jit=True):
    outer_tensor = tf.constant(3., dtype=np.float64)

    def fun_tf(x):
      return x * tf.cast(outer_tensor * 3.14, tf.float32) + 1.

    x = np.float32(2.)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * 3. * 3.14 + 1., res, check_dtypes=False)
Code example #6
    def test_with_var_read(self, with_jit=True):
        outer_var = tf.Variable(3., dtype=np.float32)

        def fun_tf(x):
            return x * outer_var + 1.

        x = np.float32(2.)
        res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
        self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
Code example #7
    def test_with_tensor_capture(self, with_jit=False):
        outer_tensor = tf.constant(3., dtype=np.float32)

        def fun_tf(x):
            return x * outer_tensor + 1.

        x = np.float32(2.)
        res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
        self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
Code example #8
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_saved_model_simple(self):
    x = np.array([0.7, 0.8], dtype=np.float32)
    def f_jax(x):
      return jnp.sin(x)

    f_tf = jax2tf.convert(f_jax)
    restored_tf, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
    restored_jax = jax2tf.call_tf(restored_tf)
    self.assertAllClose(f_jax(x), restored_jax(x))
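tf_test_util.SaveAndLoadFunction is a test helper not shown in these excerpts. Under the assumption that it round-trips the function through a SavedModel on disk, a simplified sketch follows; the real helper also accepts options such as variables= and save_gradients=, and in some older snapshots it takes an input signature positionally and returns only the function (as in code example #13 below).

    import tempfile

    def save_and_load_function(f_tf, input_args):
        # Trace f_tf with a concrete signature, save it as a SavedModel,
        # reload it, and return both the restored function and the model.
        module = tf.Module()
        module.f = tf.function(
            f_tf, autograph=False,
            input_signature=[tf.TensorSpec.from_tensor(tf.constant(a))
                             for a in input_args])
        model_dir = tempfile.mkdtemp()
        tf.saved_model.save(module, model_dir)
        restored_model = tf.saved_model.load(model_dir)
        return restored_model.f, restored_model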
Code example #9
File: call_tf_test.py, Project: zizai/jax
    def test_eval_pytree(self, with_jit=True):
        def fun_tf(x: Dict, y: Tuple) -> Tuple:
            return (x["first"] * x["second"], y[0] + y[1])

        x = dict(first=np.float32(3.), second=np.float32(4.))
        y = (np.float64(5.), np.float64(6.))
        fun_jax = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))
        res = fun_jax(x, y)
        self.assertAllClose((np.float32(12.), np.float64(11.)), res)
Code example #10
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_function_compile_time_constant_inputs(self):
    # Call a function for which shape inference does not give an output
    # shape.
    x = np.array([1, 2, 3], dtype=np.int32)
    def fun_tf(x):  # x:i32[3]
      # Indexing with a dynamic slice makes the TF shape inference return
      # a partially known shape.
      end_idx = x[1]
      res = x[0:end_idx]
      return res

    # Call in eager mode. Should work!
    res1 = jax2tf.call_tf(fun_tf)(x)
    self.assertAllClose(x[0:x[1]], res1)

    # Now under jit, should fail because the function is not compileable
    with self.assertRaisesRegex(ValueError,
                                "Compiled TensorFlow function has unexpected parameter types"):
      fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
      fun_jax(x)
Code example #11
    def test_eval_non_compileable_dynamic_shape(self):
        # Check that in op-by-op we call a function in eager mode.
        def f_tf_non_compileable(x):
            return tf.cond(x[0], lambda: x[1:], lambda: x)

        f_jax = jax2tf.call_tf(f_tf_non_compileable)
        x = np.array([True, False], dtype=np.bool_)
        self.assertAllClose(f_tf_non_compileable(x), f_jax(x))

        with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
            jax.jit(f_jax)(x)
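_call_tf_dynamic_shape_error is a module-level constant defined elsewhere in call_tf_test.py: the regex matched against the error that call_tf raises under jit when the TF function's output shape is dynamic. Its exact text changes across JAX versions; a hypothetical placeholder:

    # Hypothetical placeholder; upstream this matches the actual message raised
    # by call_tf when the compiled TF function has a dynamic output shape.
    _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape"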
Code example #12
File: call_tf_test.py, Project: zhaowilliam/jax
  def test_with_var_read(self, with_jit=True):
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    outer_var = tf.Variable(3., dtype=np.float32)

    def fun_tf(x):
      return x * outer_var + 1.

    x = np.float32(2.)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
Code example #13
File: call_tf_test.py, Project: vishalbelsare/jax
    def test_round_trip_saved_model_no_gradients(self):
        # Save without gradients
        f_jax = jnp.sum

        x = np.array([0.7, 0.8], dtype=np.float32)
        f_tf = tf_test_util.SaveAndLoadFunction(
            jax2tf.convert(f_jax, with_gradient=True),
            [tf.TensorSpec(x.shape, dtype=x.dtype)],
            save_gradients=False)
        f_rt = jax2tf.call_tf(f_tf)

        self.assertAllClose(f_jax(x), f_rt(x))
Code example #14
File: call_tf_test.py, Project: zstreeter/jax
    def test_with_var_write_error(self, with_jit=True):
        if with_jit:
            raise unittest.SkipTest("variable writes not yet working")
        outer_var = tf.Variable(3., dtype=np.float32)

        def fun_tf(x):
            outer_var.assign(tf.constant(4.))
            return x * outer_var + 1.

        x = np.float32(2.)
        res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
        self.assertAllClose(x * 4. + 1, res, check_dtypes=False)
Code example #15
    def test_with_multiple_capture(self, with_jit=True):
        v2 = tf.Variable(2., dtype=np.float32)
        v3 = tf.Variable(3., dtype=np.float32)
        t4 = tf.constant(4., dtype=np.float32)
        t5 = tf.constant(5., dtype=np.float32)

        def fun_tf(x):
            return (x * v2 + t4) * v3 + t5

        x = np.float32(2.)
        res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
        self.assertAllClose((x * 2. + 4.) * 3. + 5., res, check_dtypes=False)
Code example #16
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_with_var_read_x64(self, with_jit=True):
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    outer_var_array = np.array([3., 4.], dtype=np.float64)
    outer_var = tf.Variable(outer_var_array)

    def fun_tf(x):
      return x * tf.cast(outer_var, x.dtype) + 1.

  x = np.array([2., 5.], dtype=np.float32)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
Code example #17
    def test_shape_polymorphism_error(self):
        x = np.array([.7, .8], dtype=np.float32)

        def fun_tf(x):
            return tf.math.sin(x)

        fun_jax = jax2tf.call_tf(fun_tf)

        fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])
        with self.assertRaisesRegex(
                ValueError,
                "call_tf cannot be applied to shape-polymorphic arguments"):
            fun_tf_rt(x)
Code example #18
File: call_tf_test.py, Project: zhaowilliam/jax
  def test_with_multiple_capture(self, with_jit=True):
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    v2 = tf.Variable(2., dtype=np.float32)
    v3 = tf.Variable(3., dtype=np.float32)
    t4 = tf.constant(4., dtype=np.float32)
    t5 = tf.constant(5., dtype=np.float32)

    def fun_tf(x):
      return (x * v2 + t4) * v3 + t5

    x = np.float32(2.)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose((x * 2. + 4.) * 3. + 5., res, check_dtypes=False)
Code example #19
File: call_tf_test.py, Project: zstreeter/jax
    def test_grad_custom(self, with_jit=False):
        @tf.custom_gradient
        def func_square_tf(x):
            # Like x ** 2, but with custom grad 3. * x
            def grad(dy, variables=None):
                return 3. * x * dy,

            return x * x, grad

        x = np.float32(4.)
        grad_x = _maybe_jit(with_jit,
                            jax.grad(jax2tf.call_tf(func_square_tf)))(x)
        self.assertAllClose(np.float32(3.) * x, grad_x)
Code example #20
    def test_saved_model_variables(self):
        param = np.array([1., 2.], dtype=np.float32)
        x = np.array([0.7, 0.8], dtype=np.float32)

        def f_jax(param, x):
            return jnp.sin(x) + jnp.cos(param)

        param_v = tf.Variable(param)
        f_tf = jax2tf.convert(f_jax)
        _, restored_model = tf_test_util.SaveAndLoadFunction(
            lambda x: f_tf(param_v, x), input_args=[x], variables=[param_v])
        restored_jax = jax2tf.call_tf(restored_model.f)
        self.assertAllClose(f_jax(param, x), restored_jax(x))
        self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x))
Code example #21
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_without_gradient_saved_model(self):
    # Explicitly with_gradient=False
    f_jax = jnp.sum

    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=False),
        input_args=[x])
    f_rt = jax2tf.call_tf(f_tf)

    self.assertAllClose(f_jax(x), f_rt(x))
    with self.assertRaisesRegex(Exception,
                                "Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
      jax.grad(f_rt)(x)
Code example #22
    def test_function_dynamic_shape(self):
        # Call a function for which shape inference does not give an output
        # shape.
        x = np.array([-1, 0, 1], dtype=np.int32)

        def fun_tf(x):  # x:i32[3]
            # The shape depends on the value of x
            return tf.cond(x[0] >= 0, lambda: x, lambda: x[1:])

        # Call in eager mode. Should work!
        res1 = jax2tf.call_tf(fun_tf)(x)
        expected = x[1:]
        self.assertAllClose(expected, res1, check_dtypes=False)

        # Now under jit, should fail because the function is not compileable
        with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
            fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
            fun_jax(x)

        # TODO(necula): this should work in op-by-op mode, but it fails because
        # jax2tf.convert does abstract evaluation.
        with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
            fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
            fun_tf_rt(x)
Code example #23
    def test_eval_non_compileable_strings(self):
        # Check that in op-by-op we call a function in eager mode.
        def f_tf_non_compileable(x):
            return tf.strings.length(tf.strings.format("Hello {}!", [x]))

        f_jax = jax2tf.call_tf(f_tf_non_compileable)
        x = np.float32(0.7)
        self.assertAllClose(f_tf_non_compileable(x).numpy(), f_jax(x))
        with self.assertRaisesRegex(ValueError,
                                    CallTfTest.call_tf_non_compileable):
            jax.jit(f_jax)(x)

        with self.assertRaisesRegex(ValueError,
                                    CallTfTest.call_tf_non_compileable):
            lax.cond(True, lambda x: f_jax(x), lambda x: f_jax(x), x)
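Likewise, CallTfTest.call_tf_non_compileable is a class attribute holding the expected error regex for TF functions that cannot be lowered to XLA. A hypothetical placeholder for how it might be declared:

    class CallTfTest(tf_test_util.JaxToTfTestCase):
        # Hypothetical placeholder; upstream this matches the error raised when
        # jit-compiling a call_tf whose TF function XLA cannot compile.
        call_tf_non_compileable = "Error compiling TensorFlow function"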
Code example #24
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_with_var_different_shape(self):
    # See https://github.com/google/jax/issues/6050
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    v = tf.Variable((4., 2.), dtype=tf.float32)

    def tf_func(x):
      return v + x
    x = np.float32(123.)
    tf_out = tf_func(x)

    jax_func = jax.jit(jax2tf.call_tf(tf_func))
    jax_out = jax_func(x)

    self.assertAllClose(tf_out, jax_out, check_dtypes=False)
Code example #25
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_grad_int_argument(self):
    # Similar to https://github.com/google/jax/issues/6975
    # state is a pytree that contains an integer and a boolean.
    # The function returns an integer and a boolean.
    def f(param, state, x):
      return param * x, state

    param = np.array([0.7, 0.9], dtype=np.float32)
    state = dict(array=np.float32(1.), counter=7, truth=True)
    x = np.float32(3.)

    # tf.function is important (jax2tf.call_tf applies it internally);
    # without it the bug does not appear.
    f_call_tf = jax2tf.call_tf(f)
    g_call_tf = jax.grad(lambda *args: jnp.sum(f_call_tf(*args)[0]))(param, state, x)
    g = jax.grad(lambda *args: jnp.sum(f(*args)[0]))(param, state, x)
    self.assertAllClose(g_call_tf, g)
Code example #26
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_saved_model_no_gradients(self):
    # Save without gradients
    f_jax = jnp.sum

    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=True), input_args=[x],
        save_gradients=False)
    f_rt = jax2tf.call_tf(f_tf)

    self.assertAllClose(f_jax(x), f_rt(x))
    # TODO: clean this up b/191117111: it should fail with a clear error
    # The following results in a confusing error:
    # TypeError: tf.Graph captured an external symbolic tensor.
    with self.assertRaises(TypeError):
      _ = jax.grad(f_rt)(x)
Code example #27
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_repro_193754660(self):
    # Try to reproduce b/193754660. I can't.
    # We have to have tf.function(jax2tf.convert(jax2tf.call_tf(f_tf))).
    # The get_compiler_ir will indeed fail for f_tf. Then we try to use
    # shape inference for f_tf.
    # I thought to use a f_tf that uses an op without shape inference, e.g.,
    # tfxla.gather. If we wash it through a saved_model I expect that shape
    # inference would not work on it. Instead, shape inference works!!!
    x = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32)
    def f_jax(x):
      return x[1]
    f_tf = jax2tf.convert(f_jax)
    f_tf_rt, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
    f_jax2 = jax2tf.call_tf(f_tf_rt)
    f_tf2 = jax2tf.convert(f_jax2)
    res = tf.function(f_tf2, autograph=False)(x)
    self.assertAllClose(res.numpy(), f_jax(x))
Code example #28
File: call_tf_test.py, Project: matthewfeickert/jax
  def test_custom_grad(self):
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), np.float32(3.) * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)

    f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
    x = np.float32(0.7)
    self.assertAllClose(f(x), f_rt(x))
    self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))
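This round trip works because jax2tf.convert(f, with_gradient=True) attaches a tf.custom_gradient to the converted function, while jax.grad of a call_tf-wrapped function computes its VJP by calling back into TensorFlow's gradient machinery (a tf.GradientTape over the wrapped function), so the custom VJP defined on the JAX side survives both conversions.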
Code example #29
    def test_eval_non_compileable_dynamic_shape(self):
        # Check that in op-by-op we call a function in eager mode.
        def f_tf_non_compileable(x):
            return tf.where(x)

        f_jax = jax2tf.call_tf(f_tf_non_compileable)
        x = np.array([True, False], dtype=np.bool_)
        self.assertAllClose(f_tf_non_compileable(x).numpy(), f_jax(x))

        if jtu.device_under_test() == "tpu":
            # TODO: This works on TPU!!!
            self.assertAllClose(
                f_tf_non_compileable(x).numpy(),
                jax.jit(f_jax)(x))
        else:
            with self.assertRaisesRegex(ValueError,
                                        CallTfTest.call_tf_non_compileable):
                jax.jit(f_jax)(x)
Code example #30
    def test_several_round_trips(self,
                                 f2_function=False,
                                 f2_saved_model=False,
                                 f4_function=False,
                                 f4_saved_model=False):
        x = np.array(.7, dtype=np.float32)

        # f(n)(x) = 2. * x^n
        def f(n):
            def fn(x):
                acc = np.array(2., dtype=x.dtype)
                for i in range(n):
                    acc *= x
                return acc

            return fn

        f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
        if f2_function:
            f2_tf = tf.function(f2_tf, autograph=False)
        if f2_saved_model:
            f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])

        self.assertAllClose(f(2)(x), f2_tf(x).numpy())
        _, (g_f2_ft, ) = tf_test_util.ComputeTfValueAndGrad(f2_tf, [x])
        self.assertAllClose(jax.grad(f(2))(x), g_f2_ft.numpy())

        f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
        self.assertAllClose(f(3)(x), f3_jax(x))
        self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
        self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))

        f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
        self.assertAllClose(f(4)(x), f4_tf(x).numpy())
        _, (g_f4_ft, ) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
        self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())

        if f4_function:
            f4_tf = tf.function(f4_tf, autograph=False)
        if f4_saved_model:
            f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])
        self.assertAllClose(f(4)(x), f4_tf(x).numpy())
        _, (g_f4_ft, ) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
        self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())
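For reference, the compositions match f(n)(x) = 2 * x^n at every stage: f2(x) = x * (2x) = 2x^2, f3(x) = x * (2x^2) = 2x^3, and f4(x) = x * (2x^3) = 2x^4, so each assertAllClose compares against f(2), f(3), or f(4), and the gradient checks compare against 2n * x^(n-1).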