def test_gradients_with_custom_vjp(self, with_function=True):
  """Check gradients, for a function with custom VJP.

  The custom backward rule yields 3 * x rather than the analytic 2 * x of
  x * x, so an expected gradient of 3 * 4 at x = 4 can only come from the
  custom rule, both in JAX and after conversion to TF.
  """
  @jax.custom_vjp
  def square(x):
    return x * x

  # Forward rule: a -> (b, residual); the residual saved is 3 * x.
  def square_fwd(x):
    primal_out = square(x)
    residual = 3. * x
    return primal_out, residual

  # Backward rule: (residual, cotangent of b) -> tuple of cotangents of a.
  def square_bwd(residual, ct_b):
    return (residual * ct_b,)

  square.defvjp(square_fwd, square_bwd)

  # Sanity-check the JAX function itself before converting it.
  self.assertAllClose(4. * 4., square(4.))
  self.assertAllClose(3. * 4., jax.grad(square)(4.))

  square_tf = jax2tf.convert(square, with_gradient=True)
  if with_function:
    square_tf = tf.function(square_tf, autograph=False)

  self.assertAllClose(4. * 4., square_tf(4.))

  float_dtype = jax2tf.dtype_of_val(4.)
  x = tf.Variable(4., dtype=float_dtype)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = square_tf(x)

  self.assertAllClose(4. * 4., y)
  self.assertAllClose(3. * 4., tape.gradient(y, x))
def check_avals(*, args: Sequence[jax2tf.jax2tf.TfVal],
                polymorphic_shapes: Sequence[Optional[Union[str, PS]]],
                expected_avals: Sequence[core.ShapedArray]):
  """Assert that `_args_to_avals_and_env` maps `args` to `expected_avals`.

  Args:
    args: the TF values passed to the function under test.
    polymorphic_shapes: one polymorphic-shape spec (or None) per argument.
    expected_avals: the abstract values the call is expected to produce.

  NOTE(review): `self` is not a parameter of this helper. It relies on a
  `self` from an enclosing scope (presumably the surrounding test method),
  which is not visible here.
  """
  arg_dtypes = tuple(
      jax2tf.jax2tf._to_jax_dtype(jax2tf.dtype_of_val(a)) for a in args)
  # The function under test. Only the avals are checked; the second element
  # of the returned pair (the shape environment) is not used.
  avals, _ = jax2tf.jax2tf._args_to_avals_and_env(
      args, arg_dtypes, polymorphic_shapes)
  self.assertEqual(expected_avals, avals)
def test_gradients_with_custom_jvp(self, with_function=True):
  """Check gradients, for a function with custom JVP.

  The custom tangent rule is 3 * x * x_dot rather than the analytic
  2 * x * x_dot of x * x, so a gradient of 3 * 4 at x = 4 shows the custom
  rule is honoured, both in JAX and after conversion to TF.
  """
  @jax.custom_jvp
  def square(x):
    return x * x

  @square.defjvp
  def square_jvp(primals, tangents):
    # Custom tangent: 3 * x * x_dot.
    (x,) = primals
    (x_dot,) = tangents
    return square(x), 3. * x * x_dot

  # Sanity-check the JAX function itself before converting it.
  self.assertAllClose(4. * 4., square(4.))
  self.assertAllClose(3. * 4., jax.grad(square)(4.))

  square_tf = jax2tf.convert(square, with_gradient=True)
  if with_function:
    square_tf = tf.function(square_tf, autograph=False)

  self.assertAllClose(4. * 4., square_tf(4.))

  float_dtype = jax2tf.dtype_of_val(4.)
  x = tf.Variable(4., dtype=float_dtype)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = square_tf(x)

  self.assertAllClose(4. * 4., y)
  self.assertAllClose(3. * 4., tape.gradient(y, x))
def test_gradients(self, with_function=True):
  """Check gradients of a two-output function converted with jax2tf.

  For f(x, y) = (u, v) = (x * x, x * y) at x = 4, y = 5 the test expects
  du/dx = 8, du/dy = 0, dv/dx = 5 and dv/dy = 4.
  """
  def f(x, y):
    return x * x, x * y

  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)

  # Both inputs use the same default float dtype.
  default_float_type = jax2tf.dtype_of_val(4.)
  x = tf.Variable(4., dtype=default_float_type)
  y = tf.Variable(5., dtype=default_float_type)
  # persistent=True because the tape is queried for four gradients below.
  with tf.GradientTape(persistent=True) as tape:
    u, v = f_tf(x, y)

  self.assertAllClose(2. * 4., tape.gradient(u, x))
  self.assertAllClose(0., tape.gradient(u, y))
  self.assertAllClose(5., tape.gradient(v, x))
  self.assertAllClose(4., tape.gradient(v, y))
def test_gradients_pytree(self, with_function=True):
  """Check gradients when input and output are pytrees (tuple in, dict out).

  For (x, y) -> {"one": x * x, "two": x * y} at x = 4, y = 5 the expected
  partials are d one/dx = 8, d one/dy = 0, d two/dx = 5, d two/dy = 4.
  """
  def to_dict(xy: Tuple[float, float]) -> Dict[str, float]:
    first, second = xy
    return {"one": first * first, "two": first * second}

  to_dict_tf = jax2tf.convert(to_dict, with_gradient=True)
  if with_function:
    to_dict_tf = tf.function(to_dict_tf, autograph=False)

  float_dtype = jax2tf.dtype_of_val(4.)
  x = tf.Variable(4., dtype=float_dtype)
  y = tf.Variable(5., dtype=float_dtype)
  # persistent=True because the tape is queried for four gradients below.
  with tf.GradientTape(persistent=True) as tape:
    outputs = to_dict_tf((x, y))

  self.assertAllClose(2. * 4., tape.gradient(outputs["one"], x))
  self.assertAllClose(0., tape.gradient(outputs["one"], y))
  self.assertAllClose(5., tape.gradient(outputs["two"], x))
  self.assertAllClose(4., tape.gradient(outputs["two"], y))
def test_gradients_with_ordered_dict_input(self, with_function=True):
  """Check gradients when the converted function takes an OrderedDict.

  f sums every element of every value in the dict, so the gradient with
  respect to each input is an array of ones with that input's shape.
  """
  def f(inputs):
    out = 0.0
    for v in inputs.values():
      out += jnp.sum(v)
    return out

  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)

  default_float_type = jax2tf.dtype_of_val(4.)
  x = tf.Variable([4.], dtype=default_float_type)
  y = tf.Variable([4., 5.], dtype=default_float_type)
  # 'r' is inserted before 'd' (not alphabetical order). Presumably this
  # exercises insertion-order handling of OrderedDict; confirm intent.
  inputs = OrderedDict([('r', x), ('d', y)])
  # persistent=True because the tape is queried for two gradients below.
  with tf.GradientTape(persistent=True) as tape:
    u = f_tf(inputs)

  self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy())
  self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy())
def test_variable_input(self):
  """A converted function accepts a tf.Variable and returns a tf.Tensor.

  The TF result must also match the JAX function evaluated on the same
  Python scalar.
  """
  f_jax = lambda x: jnp.sin(jnp.cos(x))
  f_tf = jax2tf.convert(f_jax)

  float_dtype = jax2tf.dtype_of_val(0.7)
  var = tf.Variable(0.7, dtype=float_dtype)

  self.assertIsInstance(f_tf(var), tf.Tensor)
  expected = f_jax(0.7)
  self.assertAllClose(expected, f_tf(var))
def _make_one_array_signature(tf_arg):
  """Return a tf.TensorSpec with the shape and jax2tf dtype of `tf_arg`."""
  arg_shape = np.shape(tf_arg)
  arg_dtype = jax2tf.dtype_of_val(tf_arg)
  return tf.TensorSpec(arg_shape, arg_dtype)
def _get_signature(tf_arg):
  """Build a tf.TensorSpec describing `tf_arg`'s shape and jax2tf dtype."""
  shape = np.shape(tf_arg)
  dtype = jax2tf.dtype_of_val(tf_arg)
  return tf.TensorSpec(shape, dtype)