예제 #1
0
    def test_cond_multiple_results(self):
        """Converts a `lax.cond` whose branches each return a 2-tuple.

        The converted function is compared against JAX once with the
        predicate True and once with it False, so each branch's result is
        the one selected in one of the two comparisons.
        """
        def true_branch(t):
            return t + 1., 1.

        def false_branch(f):
            return f + 2., 2.

        def f_jax(pred, x):
            return lax.cond(pred, true_branch, false_branch, x)

        for pred in (True, False):
            self.ConvertAndCompare(f_jax, pred, jnp.float_(1.))
예제 #2
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
  def test_pytrees(self):
    """Converts a function that both takes and returns a pytree.

    The argument is a tuple of a scalar and a dict of scalars. The result has
    the same structure, with the scalar doubled and every dict value tripled.
    """
    def f_jax(x: Tuple[float, Dict[str, float]]) -> Tuple[float, Dict[str, float]]:
      scalar, mapping = x
      doubled = scalar * 2.
      tripled = {key: val * 3. for key, val in mapping.items()}
      return doubled, tripled

    arg = (jnp.float_(.7), {"a": jnp.float_(.8), "b": jnp.float_(.9)})
    self.ConvertAndCompare(f_jax, arg)
예제 #3
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
  def test_gradients_with_custom_vjp(self, with_function=True):
    """Check gradients, for a function with custom VJP.

    `square(x)` computes `x * x`, but its custom VJP stores `3 * x` as the
    residual and multiplies it by the output cotangent. The expected gradient
    at 4. is therefore 3. * 4., not the true derivative 2. * 4. The test
    checks that the jax2tf-converted function follows the custom rule under a
    `tf.GradientTape`. When `with_function` is true, the converted function is
    also wrapped in `tf.function(..., autograph=False)`.
    """
    @jax.custom_vjp
    def square(x):
      return x * x

    def square_fwd(x):
      # a -> (b, residual)
      primal = square(x)
      return primal, 3. * x

    def square_bwd(residual, ct_out):
      # (residual, CT b) -> (CT a,): one cotangent per primal input.
      return (residual * ct_out,)

    square.defvjp(square_fwd, square_bwd)

    expected_value = 4. * 4.
    expected_grad = 3. * 4.

    # Check the custom rule in plain JAX before converting.
    self.assertAllClose(expected_value, square(4.))
    self.assertAllClose(expected_grad, jax.grad(square)(4.))

    converted = jax2tf.convert(square, with_gradient=True)
    if with_function:
      converted = tf.function(converted, autograph=False)
    self.assertAllClose(expected_value, converted(jnp.float_(4.)))

    var = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
    with tf.GradientTape() as tape:
      tape.watch(var)
      out = converted(var)

    self.assertAllClose(expected_value, out)
    self.assertAllClose(expected_grad, tape.gradient(out, var))
예제 #4
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
  def test_gradients_with_custom_jvp(self, with_function=True):
    """Check gradients, for a function with custom JVP.

    `square(x)` computes `x * x`, but its custom JVP reports a tangent of
    `3 * x * x_dot`. The expected gradient at 4. is therefore 3. * 4. The test
    checks that the jax2tf-converted function follows the custom rule under a
    `tf.GradientTape`. When `with_function` is true, the converted function is
    also wrapped in `tf.function(..., autograph=False)`.
    """
    @jax.custom_jvp
    def square(x):
      return x * x

    @square.defjvp
    def square_jvp(primals, tangents):
      (x,) = primals
      (x_dot,) = tangents
      # Deliberately not the true derivative 2 * x, so the custom rule is
      # observable in the gradient.
      return square(x), 3. * x * x_dot

    expected_value = 4. * 4.
    expected_grad = 3. * 4.

    # Check the custom rule in plain JAX before converting.
    self.assertAllClose(expected_value, square(4.))
    self.assertAllClose(expected_grad, jax.grad(square)(4.))

    converted = jax2tf.convert(square, with_gradient=True)
    if with_function:
      converted = tf.function(converted, autograph=False)
    self.assertAllClose(expected_value, converted(jnp.float_(4.)))

    var = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
    with tf.GradientTape() as tape:
      tape.watch(var)
      out = converted(var)

    self.assertAllClose(expected_value, out)
    self.assertAllClose(expected_grad, tape.gradient(out, var))
예제 #5
0
  def test_cond_custom_jvp(self):
    """Conversion of function with custom JVP, inside cond.

    This exercises the custom_jvp_call_jaxpr primitives. The custom JVP rule
    reports `3 * x * x_dot` for `square(x) = x * x`. The cond applies `square`
    on its true branch and the identity on its false branch.
    """
    @jax.custom_jvp
    def square(x):
      return x * x

    @square.defjvp
    def square_jvp(primals, tangents):
      (x,) = primals
      (x_dot,) = tangents
      return square(x), 3. * x * x_dot

    def g(x):
      return lax.cond(True, square, lambda y: y, x)

    arg = jnp.float_(0.7)
    for transform in (None, "jvp", "vmap", "jvp_vmap", "grad", "grad_vmap"):
      self.TransformConvertAndCompare(g, arg, transform)
예제 #6
0
  def test_while_custom_jvp(self):
    """Conversion of function with custom JVP, inside while.

    This exercises the custom_jvp_call_jaxpr primitives. The loop carry is
    `(counter, value)`. Each iteration increments the counter and applies
    `square` to the value, until the counter reaches 10. Only forward-mode and
    batching transforms are checked, because `lax.while_loop` does not support
    reverse-mode differentiation.
    """
    @jax.custom_jvp
    def square(x):
      return x * x

    @square.defjvp
    def square_jvp(primals, tangents):
      (x,) = primals
      (x_dot,) = tangents
      return square(x), 3. * x * x_dot

    def keep_going(carry):
      return carry[0] < 10

    def step(carry):
      counter, value = carry
      return counter + 1., square(value)

    def g(x):
      return lax.while_loop(keep_going, step, (0., x))

    arg = jnp.float_(0.7)
    for transform in (None, "jvp", "vmap", "jvp_vmap"):
      self.TransformConvertAndCompare(g, arg, transform)
예제 #7
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
  def test_remat_free_var(self):
    """Converts a `jax.remat` function that closes over a traced value.

    The rematerialized inner function takes no arguments. It only returns
    `doubled`, which it captures from the enclosing scope. Both the plain
    function and its gradient are compared.
    """
    def f(x):
      doubled = 2 * x

      @jax.remat
      def inner():
        return doubled

      return inner()

    arg = jnp.float_(3.)
    for transform in (None, "grad"):
      self.TransformConvertAndCompare(f, arg, transform)
예제 #8
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
  def test_custom_vjp(self):
    """Conversion of function with custom VJP.

    `square(x)` computes `x * x`. Its custom VJP stores `3 * x` as the residual,
    so gradients follow the custom rule rather than the true derivative.
    No `jvp` transforms are checked, because JAX does not support
    forward-mode differentiation of a `custom_vjp` function.
    """
    @jax.custom_vjp
    def square(x):
      return x * x

    def square_fwd(x):
      # a -> (b, residual)
      primal = square(x)
      return primal, 3. * x

    def square_bwd(residual, ct_out):
      # (residual, CT b) -> (CT a,)
      return (residual * ct_out,)

    square.defvjp(square_fwd, square_bwd)

    arg = jnp.float_(0.7)
    for transform in (None, "vmap", "grad", "grad_vmap"):
      self.TransformConvertAndCompare(square, arg, transform)
예제 #9
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
  def test_custom_jvp(self):
    """Conversion of function with custom JVP.

    `square(x)` computes `x * x`, but its custom JVP reports a tangent of
    `3 * x * x_dot`. Forward-mode, batching and reverse-mode transforms are
    all checked against the converted function.
    """
    @jax.custom_jvp
    def square(x):
      return x * x

    @square.defjvp
    def square_jvp(primals, tangents):
      (x,) = primals
      (x_dot,) = tangents
      return square(x), 3. * x * x_dot

    arg = jnp.float_(0.7)
    for transform in (None, "jvp", "vmap", "jvp_vmap", "grad", "grad_vmap"):
      self.TransformConvertAndCompare(square, arg, transform)
예제 #10
0
  def test_cond_custom_vjp(self):
    """Conversion of function with custom VJP, inside cond.

    This exercises the custom_vjp_call_jaxpr primitives. The cond applies the
    custom-VJP `f` on its true branch and the identity on its false branch.
    No `jvp` transforms are checked, because JAX does not support forward-mode
    differentiation of a `custom_vjp` function. Plain `grad` is checked in
    addition to `grad_vmap`.
    """
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), 3. * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return (residual * ct_b,)

    f.defvjp(f_fwd, f_bwd)

    def g(x):
      return lax.cond(True, f, lambda y: y, x)

    arg = jnp.float_(0.7)
    self.TransformConvertAndCompare(g, arg, None)
    self.TransformConvertAndCompare(g, arg, "vmap")
    # Unbatched reverse mode through the cond, matching the coverage used for
    # other custom-derivative tests.
    self.TransformConvertAndCompare(g, arg, "grad")
    self.TransformConvertAndCompare(g, arg, "grad_vmap")
예제 #11
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
 def test_jit(self):
   """Converts a `jax.jit`-wrapped composition `sin(cos(x))`."""
   def sin_of_cos(x):
     return jnp.sin(jnp.cos(x))

   f_jax = jax.jit(sin_of_cos)
   self.ConvertAndCompare(f_jax, jnp.float_(0.7))
예제 #12
0
파일: jax2tf_test.py 프로젝트: d3sm0/jax
 def test_basics(self):
   """Converts `sin(cos(x))` and checks the converted result is a `tf.Tensor`.

   `ConvertAndCompare` returns a pair whose second element is the result of the
   converted TF function. Checking its type verifies that the conversion really
   produces TF values, not just numerically matching ones.
   """
   f_jax = lambda x: jnp.sin(jnp.cos(x))
   _, res_tf = self.ConvertAndCompare(f_jax, jnp.float_(0.7))
   self.assertIsInstance(res_tf, tf.Tensor)
예제 #13
0
 def test_basics(self):
   """Converts `sin(cos(x))` and checks the converted result is a `tf.Tensor`."""
   def f_jax(x):
     return jnp.sin(jnp.cos(x))

   arg = jnp.float_(0.7)
   _res_jax, res_tf = self.ConvertAndCompare(f_jax, arg)
   self.assertIsInstance(res_tf, tf.Tensor)
예제 #14
0
  def test_cond_units(self, with_function=True):
    """Converts a `lax.cond` whose branches are both the identity.

    The function and its `jax.grad` are each compared against JAX. Note that
    `with_function` is accepted for interface compatibility but is not used
    by this test body.
    """
    def take_true(x):
      return x

    def take_false(y):
      return y

    def g(x):
      return lax.cond(True, take_true, take_false, x)

    for fun in (g, jax.grad(g)):
      self.ConvertAndCompare(fun, jnp.float_(0.7))
예제 #15
0
 def test_cond_partial_eval(self):
   """Converts the gradient of a `lax.cond` whose branches close over `x`.

   `x` is both the cond operand and a value captured by each branch. Per the
   test name, this is presumably meant to exercise partial evaluation of cond
   under differentiation.
   """
   def f(x):
     def mul_branch(op):
       return op * x

     def add_branch(op):
       return op + x

     return lax.cond(True, mul_branch, add_branch, x)

   grad_f = jax.grad(f)
   self.ConvertAndCompare(grad_f, jnp.float_(1.))
예제 #16
0
  def test_cond(self):
    """Converts a scalar `lax.cond`, comparing with both predicate values.

    The true branch adds 1. to the operand. The false branch returns it
    unchanged.
    """
    def add_one(t):
      return t + 1.

    def passthrough(f):
      return f

    def f_jax(pred, x):
      return lax.cond(pred, add_one, passthrough, x)

    for pred in (True, False):
      self.ConvertAndCompare(f_jax, pred, jnp.float_(1.))