Пример #1
0
  def testRewriteThroughCustomVJPInScan(self):
    """Rewrite rules reach a custom_gradient function called inside lax.scan.

    ``foo(x) = x * 2`` with backward rule ``ct + x``; ``f`` runs ``foo`` once
    through a length-1 scan, so at x = 2 the value is 4 and the gradient is
    1 + 2 = 3.
    """

    @jax.custom_gradient
    def foo(x):
      return x * 2, lambda ct: ct + x

    def f(x):
      def step(carry, _):
        return foo(carry), None
      out, _ = lax.scan(step, x, None, length=1)
      return out

    x = 2.
    self.assertAllClose(f(x), 4.)
    self.assertAllClose(grad(f)(x), 3.)

    # Each entry: (rewrite rules, expected value, expected gradient).
    scenarios = [
        # mul -> div: value becomes 2 / 2 = 1; the add-only backward rule
        # still yields 3.
        ({lax.mul_p: lambda a, b: a / b}, 1., 3.),
        # add -> mul: value untouched; backward rule becomes 1 * 2 = 2.
        ({lax.add_p: lambda a, b: a * b}, 4., 2.),
    ]
    for rules, want_value, want_grad in scenarios:
      g = rewrite(f, rules)
      self.assertAllClose(g(x), want_value)
      self.assertAllClose(grad(g)(x), want_grad)
Пример #2
0
  def testRewriteThroughCustomVJP(self):
    """Rewrite rules reach a custom_gradient function and its backward rule.

    ``f(x) = x * 2`` with backward rule ``ct + x``: at x = 2 the value is 4
    and the gradient is 1 + 2 = 3.
    """

    @jax.custom_gradient
    def f(x):
      return x * 2, lambda ct: ct + x

    x = 2.
    self.assertAllClose(f(x), 4.)
    self.assertAllClose(grad(f)(x), 3.)

    expectations = (
        # mul -> div: value 2 / 2 = 1; the backward rule (add only) gives 3.
        ({lax.mul_p: lambda a, b: a / b}, 1., 3.),
        # add -> sub: value unchanged; backward rule becomes 1 - 2 = -1.
        ({lax.add_p: lambda a, b: a - b}, 4., -1.),
    )
    for rules, value, gradient in expectations:
      g = rewrite(f, rules)
      self.assertAllClose(g(x), value)
      self.assertAllClose(grad(g)(x), gradient)
Пример #3
0
    def testRewriteThroughCustomJVP(self):
        """Rewriting add changes a custom_jvp function's primal output.

        ``f(x) = x + 2`` with JVP rule tangent ``x * d``. Rewriting add to sub
        makes the primal ``2 - 2 = 0``. The tangent contains no add, so it
        stays at ``2 * 1 = 2``.
        """

        @jax.custom_jvp
        def f(x):
            return x + 2

        @f.defjvp
        def f_jvp(primals, tangents):
            (x,), (dx,) = primals, tangents
            return f(x), x * dx

        def check(fun, primal, tangent, gradient):
            # Value, forward-mode (jvp) and reverse-mode (grad) checks at x.
            self.assertAllClose(fun(x), primal)
            out, tan = jax.jvp(fun, (x,), (1.,))
            self.assertAllClose(out, primal)
            self.assertAllClose(tan, tangent)
            self.assertAllClose(grad(fun)(x), gradient)

        x = 2.
        check(f, 4., 2., 2.)
        check(rewrite(f, {lax.add_p: lambda a, b: a - b}), 0., 2., 2.)
Пример #4
0
    def testRewriteWithCustomGradients(self):
        """An empty rule table leaves a jax.nn.relu-based function unchanged."""
        x = jnp.array([2.0, 4.0])
        expected = jnp.array([2.0, 4.0])

        def f(x):
            return jax.nn.relu(x)

        self.assertAllClose(f(x), expected)
        identity_rewrite = rewrite(f, {})
        self.assertAllClose(identity_rewrite(x), expected)
Пример #5
0
  def testRewrite(self):
    """A mul_p -> add rule turns ``x * 2`` into ``x + 2``."""
    def f(x):
      return x * 2

    x = jnp.array([2.0, 4.0])
    self.assertAllClose(f(x), jnp.array([4.0, 8.0]))

    mul_as_add = {lax.mul_p: lambda a, b: a + b}
    rewritten = rewrite(f, mul_as_add)
    self.assertAllClose(rewritten(x), jnp.array([4.0, 6.0]))
Пример #6
0
  def testRewriteThroughCustomJVPInScan(self):
    """Rewrite rules reach a custom_jvp function called inside lax.scan.

    ``foo(x) = x + 2`` is a custom_jvp function whose JVP rule returns the
    tangent ``x * d``. ``f`` calls ``foo`` once through a length-1 scan.
    """

    @jax.custom_jvp
    def foo(x):
      return x + 2

    @foo.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      d, = tangents
      # The primal output comes from `foo` itself. Calling the scan wrapper
      # `f` here (a forward reference to the function below) would nest an
      # extra scan inside the JVP rule.
      return foo(x), x * d

    def f(x):
      out, _ = lax.scan(lambda c, _: (foo(c), None), x, None, length=1)
      return out

    x = 2.
    self.assertAllClose(f(x), 4.)
    f_primal, jvp = jax.jvp(f, (x,), (1.,))
    self.assertAllClose(f_primal, 4.)
    self.assertAllClose(jvp, 2.)
    self.assertAllClose(grad(f)(x), 2.)

    # add -> sub: primal becomes 2 - 2 = 0; the tangent `x * d` has no add.
    rewrites = {
        lax.add_p: lambda x, y: x - y
    }
    g = rewrite(f, rewrites)

    self.assertAllClose(g(x), 0.)
    g_primal, jvp = jax.jvp(g, (x,), (1.,))
    self.assertAllClose(g_primal, 0.)
    self.assertAllClose(jvp, 2.)
    self.assertAllClose(grad(g)(x), 2.)

    # mul -> add: primal has no mul; the tangent `x * d` becomes
    # `x + d` = 2 + 1 = 3.
    rewrites = {
        lax.mul_p: lambda x, y: x + y
    }
    g = rewrite(f, rewrites)

    self.assertAllClose(g(x), 4.)
    g_primal, jvp = jax.jvp(g, (x,), (1.,))
    self.assertAllClose(g_primal, 4.)
    self.assertAllClose(jvp, 3.)
    self.assertAllClose(grad(g)(x), 1.)
Пример #7
0
    def testRewriteThroughWhile(self):
        """Rewrite rules reach both the body and predicate of lax.while_loop."""

        def f(x):
            return lax.while_loop(lambda v: v < 5, lambda v: v + 1, x)

        x = 0
        self.assertAllClose(f(x), 5)

        # add -> a + b + 100: the first step jumps to 101, which already fails
        # the `< 5` predicate.
        bumped_add = {lax.add_p: lambda a, b: a + b + 100}
        self.assertAllClose(rewrite(f, bumped_add)(x), 101)

        # lt -> a < b + 5: the loop now runs while v < 10.
        shifted_lt = {lax.lt_p: lambda a, b: a < b + 5}
        self.assertAllClose(rewrite(f, shifted_lt)(x), 10)
Пример #8
0
    def testRewriteThroughForLoop(self):
        """A mul_p rule is applied inside the body of lax.fori_loop."""

        def f(x):
            return lax.fori_loop(1, 5, lambda i, acc: acc * i, x)

        x = 1
        # 1 * 1 * 2 * 3 * 4
        self.assertAllClose(f(x), 24)

        # With mul -> add: 1 + 1 + 2 + 3 + 4
        mul_as_add = {lax.mul_p: lambda a, b: a + b}
        self.assertAllClose(rewrite(f, mul_as_add)(x), 11)
Пример #9
0
    def testRewriteJIT(self):
        """Rewrite rules reach primitives inside a jit-decorated inner function."""

        def f(x):
            @jit
            def g(x):
                return x * 2

            return g(x)

        x = jnp.array([2.0, 4.0])
        # check_dtypes is passed by keyword. A bare positional `True` hides its
        # meaning, and `check_dtypes` is keyword-only in recent
        # JaxTestCase.assertAllClose signatures.
        self.assertAllClose(f(x), jnp.array([4.0, 8.0]), check_dtypes=True)

        rewritten = rewrite(f, {lax.mul_p: lambda x, y: x + y})
        self.assertAllClose(
            rewritten(x), jnp.array([4.0, 6.0]), check_dtypes=True)
Пример #10
0
  def testRewriteThroughScan(self):
    """Rewrite rules reach both the carry update and per-step output of scan."""
    def f(xs):
      def step(c, x):
        new_c = c * 2.
        return new_c, x - 2.
      return lax.scan(step, 1., xs)

    xs = jnp.arange(4.)
    final, outs = f(xs)
    self.assertAllClose(final, 16.)  # 1 * 2**4
    self.assertAllClose(outs, jnp.arange(4.) - 2.)

    rules = {
        lax.mul_p: lambda a, b: a + b,
        lax.sub_p: lambda a, b: a / b
    }
    final, outs = rewrite(f, rules)(xs)
    self.assertAllClose(final, 1. + 8.)  # four "+ 2." steps
    self.assertAllClose(outs, jnp.arange(4.) / 2.)
def rewrite_high_precision(fn):
    """Return ``fn`` rewritten with the ``HIGH_PRECISION_RULES`` rule table."""
    rules = HIGH_PRECISION_RULES
    return rewrite(fn, rules)