Exemple #1
0
    def test_gradients_with_custom_vjp(self, with_function=True):
        """Gradients through jax2tf honor a function's custom VJP rule.

        The primal computes x * x, but the registered backward rule scales the
        cotangent by a residual of 3 * x. The expected derivative at 4. is
        therefore 3. * 4. rather than 2. * 4. This is checked in JAX first, then
        after conversion using a `tf.GradientTape`.
        """
        @jax.custom_vjp
        def f(x):
            return x * x

        def fwd_rule(x):
            # Forward pass: a -> (b, residual saved for the backward pass).
            primal_out = f(x)
            residual = 3. * x
            return primal_out, residual

        def bwd_rule(residual, ct_out):
            # Backward pass: (residual, CT b) -> one cotangent per primal input,
            # hence the 1-tuple.
            return (residual * ct_out,)

        f.defvjp(fwd_rule, bwd_rule)

        expected_value = 4. * 4.
        expected_grad = 3. * 4.
        self.assertAllClose(expected_value, f(4.))
        self.assertAllClose(expected_grad, jax.grad(f)(4.))

        converted = jax2tf.convert(f, with_gradient=True)
        if with_function:
            converted = tf.function(converted, autograph=False)
        self.assertAllClose(expected_value, converted(4.))

        x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
        with tf.GradientTape() as tape:
            tape.watch(x)
            y = converted(x)

        self.assertAllClose(expected_value, y)
        self.assertAllClose(expected_grad, tape.gradient(y, x))
 def check_avals(*, args: Sequence[jax2tf.jax2tf.TfVal],
                 polymorphic_shapes: Sequence[Optional[Union[str, PS]]],
                 expected_avals: Sequence[core.ShapedArray]):
   """Assert that `_args_to_avals_and_env` yields `expected_avals` for `args`.

   `self` is a free variable here. It is presumably captured from an enclosing
   test method, so this helper only works when nested inside one.

   Args:
     args: the TF values passed to the function under test.
     polymorphic_shapes: polymorphic shape specs (str, PS, or None) forwarded
       unchanged to the function under test.
     expected_avals: the abstract values the call is expected to return.
   """
   # Map each arg's TF dtype to the corresponding JAX dtype.
   arg_dtypes = tuple(
       jax2tf.jax2tf._to_jax_dtype(jax2tf.dtype_of_val(a)) for a in args)
   # The function under test. Its second result (the shape environment) is
   # not checked by this helper.
   avals, _ = jax2tf.jax2tf._args_to_avals_and_env(
       args, arg_dtypes, polymorphic_shapes)
   self.assertEqual(expected_avals, avals)
Exemple #3
0
    def test_gradients_with_custom_jvp(self, with_function=True):
        """Gradients through jax2tf honor a function's custom JVP rule.

        The primal is x * x, but the JVP rule declares the tangent to be
        3 * x * x_dot, so the expected derivative at 4. is 3. * 4.
        """
        @jax.custom_jvp
        def f(x):
            return x * x

        @f.defjvp
        def f_jvp(primals, tangents):
            (x,) = primals
            (x_dot,) = tangents
            # This differs from the true derivative (2 * x * x_dot). That lets
            # the assertions below confirm that the custom rule is used.
            return f(x), 3. * x * x_dot

        expected_value = 4. * 4.
        expected_grad = 3. * 4.
        self.assertAllClose(expected_value, f(4.))
        self.assertAllClose(expected_grad, jax.grad(f)(4.))

        converted = jax2tf.convert(f, with_gradient=True)
        if with_function:
            converted = tf.function(converted, autograph=False)
        self.assertAllClose(expected_value, converted(4.))

        x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
        with tf.GradientTape() as tape:
            tape.watch(x)
            y = converted(x)

        self.assertAllClose(expected_value, y)
        self.assertAllClose(expected_grad, tape.gradient(y, x))
Exemple #4
0
  def test_gradients(self, with_function=True):
    """Check gradients of a converted function with two inputs and two outputs.

    For f(x, y) = (x * x, x * y) at x=4., y=5., the expected partials are
    du/dx = 2x, du/dy = 0, dv/dx = y and dv/dy = x.
    """
    def f(x, y):
      return x * x, x * y

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    default_float_type = jax2tf.dtype_of_val(4.)
    # Both variables share the same default float dtype, computed once above.
    x = tf.Variable(4., dtype=default_float_type)
    y = tf.Variable(5., dtype=default_float_type)
    # persistent=True: the tape is queried four times below.
    with tf.GradientTape(persistent=True) as tape:
      u, v = f_tf(x, y)

    self.assertAllClose(2. * 4., tape.gradient(u, x))
    self.assertAllClose(0., tape.gradient(u, y))
    self.assertAllClose(5., tape.gradient(v, x))
    self.assertAllClose(4., tape.gradient(v, y))
Exemple #5
0
    def test_gradients_pytree(self, with_function=True):
        """Check gradients when the converted function uses pytrees on both sides.

        The input is an (x, y) tuple and the output is a dict keyed by "one"
        and "two". Gradients are checked for every (output, input) pair.
        """
        def f(xy: Tuple[float, float]) -> Dict[str, float]:
            x, y = xy
            return {"one": x * x, "two": x * y}

        converted = jax2tf.convert(f, with_gradient=True)
        if with_function:
            converted = tf.function(converted, autograph=False)
        dtype = jax2tf.dtype_of_val(4.)
        x = tf.Variable(4., dtype=dtype)
        y = tf.Variable(5., dtype=dtype)
        # persistent=True: the tape is queried once per expectation below.
        with tf.GradientTape(persistent=True) as tape:
            outputs = converted((x, y))

        # (expected gradient, output key, variable differentiated against)
        expectations = [
            (2. * 4., "one", x),
            (0., "one", y),
            (5., "two", x),
            (4., "two", y),
        ]
        for expected, key, var in expectations:
            self.assertAllClose(expected, tape.gradient(outputs[key], var))
Exemple #6
0
  def test_gradients_with_ordered_dict_input(self, with_function=True):
    """Check gradients when the converted function takes an OrderedDict input.

    f sums every element of every value. The gradient with respect to each
    input is therefore an array of ones with that input's shape.
    """
    def f(inputs):
      out = 0.0
      for v in inputs.values():
        out += jnp.sum(v)
      return out

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    default_float_type = jax2tf.dtype_of_val(4.)
    x = tf.Variable([4.], dtype=default_float_type)
    y = tf.Variable([4., 5.], dtype=default_float_type)
    # Insertion order ('r' before 'd') differs from sorted key order. This is
    # presumably intended to exercise OrderedDict pytree handling; confirm.
    inputs = OrderedDict([('r', x), ('d', y)])
    # persistent=True: the tape is queried twice below.
    with tf.GradientTape(persistent=True) as tape:
      u = f_tf(inputs)

    self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy())
    self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy())
Exemple #7
0
 def test_variable_input(self):
     """A converted function accepts a `tf.Variable` and returns a `tf.Tensor`."""
     def f_jax(x):
         return jnp.sin(jnp.cos(x))

     f_tf = jax2tf.convert(f_jax)
     var = tf.Variable(0.7, dtype=jax2tf.dtype_of_val(0.7))
     self.assertIsInstance(f_tf(var), tf.Tensor)
     # The TF result must match the JAX result computed on the plain Python float.
     self.assertAllClose(f_jax(0.7), f_tf(var))
Exemple #8
0
 def _make_one_array_signature(tf_arg):
     """Return a `tf.TensorSpec` matching the shape and dtype of `tf_arg`."""
     shape = np.shape(tf_arg)
     dtype = jax2tf.dtype_of_val(tf_arg)
     return tf.TensorSpec(shape, dtype)
Exemple #9
0
 def _get_signature(tf_arg):
     """Build a `tf.TensorSpec` describing the shape and dtype of `tf_arg`."""
     return tf.TensorSpec(
         shape=np.shape(tf_arg),
         dtype=jax2tf.dtype_of_val(tf_arg),
     )