Example #1
    def testScanLinearize(self, jit_scan, jit_f):
        d = np.zeros(2)

        def f(c, a):
            assert a.shape == (3, )
            assert c.shape == (4, )
            b = np.sum(np.sin(a)) + np.sum(np.sin(c)) + np.sum(np.sin(d))
            c = np.sin(c * b)
            assert b.shape == ()
            return c, b

        if jit_f:
            f = api.jit(f)
        if jit_scan:
            scan = api.jit(lax.scan, static_argnums=(0,))
        else:
            scan = lax.scan

        as_ = np.ones((5, 3))
        c = np.ones(4)

        ans = api.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
        expected = api.linearize(lambda c, as_: scan_reference(f, c, as_), c,
                                 as_)[1](c, as_)
        self.assertAllClose(ans, expected, check_dtypes=False)
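For context, a minimal sketch of the linearize pattern these tests exercise, written against the public jax namespace (the internal api module used above exposes the same function): linearize evaluates f at the primals and returns the primal output together with a function that applies the JVP at that fixed point.

import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
y, f_jvp = jax.linearize(jnp.sin, x)  # y == jnp.sin(x)
t = f_jvp(jnp.ones_like(x))           # JVP at x, applied to a tangent vector
# t == jnp.cos(x), matching jax.jvp(jnp.sin, (x,), (jnp.ones_like(x),))[1]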
Example #2
    def testScanLinearize(self, jit_scan, jit_f):
        rng = onp.random.RandomState(0)

        d = rng.randn(2)

        def f(c, a):
            assert a.shape == (3, )
            assert c.shape == (4, )
            b = np.cos(
                np.sum(np.sin(a)) + np.sum(np.cos(c)) + np.sum(np.tan(d)))
            c = np.sin(c * b)
            assert b.shape == ()
            return c, b

        if jit_f:
            f = api.jit(f)
        if jit_scan:
            scan = api.jit(lax.scan, static_argnums=(0,))
        else:
            scan = lax.scan

        as_ = rng.randn(5, 3)
        c = rng.randn(4)

        ans = api.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
        expected = api.linearize(lambda c, as_: scan_reference(f, c, as_), c,
                                 as_)[1](c, as_)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #3
    def test_issue_871(self):
        T = np.array([[1., 2.], [3., 4.], [5., 6.]])
        x = np.array([1, 2, 3])

        y, f_jvp = api.linearize(np.sum, x)
        jtu.check_raises(lambda: f_jvp(T), ValueError,
                         ("linearized function called on tangent values "
                          "inconsistent with the original primal values."))

        y, f_jvp = api.linearize(api.jit(np.sum), x)
        jtu.check_raises(lambda: f_jvp(T), ValueError,
                         ("linearized function called on tangent values "
                          "inconsistent with the original primal values."))
Example #4
  def test_remat_scan(self):
    to_scan = lambda c, x: (np.sin(c), None)

    def f_noremat(x):
      y, _ = lax.scan(to_scan, x, onp.arange(3.))
      return y

    def f_yesremat(x):
      y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
      return y

    ans = f_yesremat(4.)
    expected = f_noremat(4.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(f_yesremat)(4.)
    expected = api.grad(f_noremat)(4.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
    scan_eqn, = jaxpr.jaxpr.eqns
    self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

    jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
    scan_eqn, = jaxpr.jaxpr.eqns
    self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
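The two assertIn checks verify that cos (the residual produced by differentiating sin) appears inside the scan equation's inner jaxpr, i.e. that under remat the cosine is recomputed within the scanned computation rather than saved out of it.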
Example #5
  def test_remat_basic(self):
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x)), 3.

    def f(x):
      x, _ = g(x)
      return x

    ans = f(2.)
    expected = onp.sin(onp.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans, f_lin = api.linearize(f, 2.)
    expected = onp.sin(onp.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = f_lin(3.)
    expected = onp.cos(onp.sin(2.)) * onp.cos(2.) * 3.
    self.assertAllClose(ans, expected, check_dtypes=False)

    sin_calls = []
    cos_calls = []
    sin_impl = lax.sin_p.impl
    cos_impl = lax.cos_p.impl
    try:
      lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
      lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
      f_lin(3.)
    finally:
      lax.sin_p.def_impl(sin_impl)
      lax.cos_p.def_impl(cos_impl)
    self.assertEqual(len(sin_calls), 1)
    self.assertEqual(len(cos_calls), 2)
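A note on the expected counts: when f_lin(3.) runs, remat re-executes the pieces of the primal computation that the linear part needs, so sin is evaluated once (recomputing sin(x), which cos(sin(x)) requires) while cos is evaluated twice, once per factor of the chain rule cos(sin(2.)) * cos(2.).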
Example #6
  def test_remat_freevars(self):
    def f1(x):
      y = 2 * np.sin(x)
      z = np.cos(x) * np.sin(y)
      return z

    def f2(x):
      y = 2 * np.sin(x)
      z = api.remat(lambda x: np.cos(x) * np.sin(y))(x)
      return z

    ans, f_lin = api.linearize(f2, 2.)
    expected, f_lin_expected = api.linearize(f1, 2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = f_lin(3.)
    expected = f_lin_expected(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #7
def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
  params = primals[:num_consts]
  solution = tuple(root_p.bind(*primals, num_consts=num_consts, jaxpr=jaxpr,
                               solve=solve, tangent_solve=tangent_solve))
  params_dot = tangents[:num_consts]

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxpr)
  f_fixed_params = lambda *solution: f(*(params + solution))
  f_fixed_solution = lambda *params: f(*(params + solution))

  _, rhs = ad.jvp(lu.wrap_init(f_fixed_solution)).call_wrapped(params, params_dot)
  _, f_jvp_wrt_solution = api.linearize(f_fixed_params, *solution)
  solution_dot = [-x for x in tangent_solve(f_jvp_wrt_solution, *rhs)]

  return solution, solution_dot
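A scalar sanity check of the formula in the comments, using only the public jax.linearize (the names below are illustrative, not part of the function above): for F(m, u) = u**2 - m with root u*(m) = sqrt(m), the implicit function theorem gives ∂u*(m)[v] = -(∂_1 F)^{-1}[∂_0 F[v]] = v / (2 sqrt(m)).

import jax

m, u = 4.0, 2.0                                 # u solves F(m, u) = 0
F = lambda m, u: u ** 2 - m
_, dF_du = jax.linearize(lambda u: F(m, u), u)  # ∂_1 F at the solution
_, dF_dm = jax.linearize(lambda m: F(m, u), m)  # ∂_0 F at the solution
u_dot = -dF_dm(1.0) / dF_du(1.0)                # 0.25 == 1 / (2 * m ** 0.5)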
Example #8
    def testScanRnn(self):
        r = npr.RandomState(0)

        n_in = 4
        n_hid = 2
        n_out = 1
        length = 3

        W_trans = r.randn(n_hid, n_hid + n_in)
        W_out = r.randn(n_out, n_hid + n_in)
        params = W_trans, W_out

        inputs = r.randn(length, n_in)
        targets = r.randn(length, n_out)

        def step(params, state, input):
            W_trans, W_out = params
            stacked = np.concatenate([state, input])
            output = np.tanh(np.dot(W_out, stacked))
            next_state = np.tanh(np.dot(W_trans, stacked))
            return next_state, output

        def rnn(params, inputs):
            init_state = np.zeros(n_hid)
            _, outputs = lax.scan(partial(step, params), init_state, inputs)
            return outputs

        def loss(params, inputs, targets):
            predictions = rnn(params, inputs)
            return np.sum((predictions - targets)**2)

        # evaluation doesn't crash
        loss(params, inputs, targets)

        # jvp evaluation doesn't crash
        api.jvp(lambda params: loss(params, inputs, targets), (params, ),
                (params, ))

        # jvp numerical check passes
        jtu.check_grads(loss, (params, inputs, targets),
                        order=2,
                        modes=["fwd"])

        # linearize works
        _, expected = api.jvp(loss, (params, inputs, targets),
                              (params, inputs, targets))
        _, linfun = api.linearize(loss, params, inputs, targets)
        ans = linfun(params, inputs, targets)
        self.assertAllClose(ans, expected, check_dtypes=False)

        # gradient evaluation doesn't crash
        api.grad(loss)(params, inputs, targets)

        # gradient check passes
        jtu.check_grads(loss, (params, inputs, targets), order=2)

        # we can vmap to batch things
        batch_size = 7
        batched_inputs = r.randn(batch_size, length, n_in)
        batched_targets = r.randn(batch_size, length, n_out)
        batched_loss = api.vmap(lambda x, y: loss(params, x, y))
        losses = batched_loss(batched_inputs, batched_targets)
        expected = onp.stack(
            list(
                map(lambda x, y: loss(params, x, y), batched_inputs,
                    batched_targets)))
        self.assertAllClose(losses, expected, check_dtypes=False)
Example #9
def jvp_unlinearized(f, primals, tangents):
    out, jvp = linearize(f, *primals)
    return out, jvp(*tangents)  # unpack: the linearized function takes tangents positionally
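A possible usage sketch, assuming the free name linearize is jax.linearize (the inputs are illustrative):

from jax import linearize
import jax.numpy as jnp

y, y_dot = jvp_unlinearized(jnp.sin, (1.0,), (1.0,))
# agrees with jax.jvp(jnp.sin, (1.0,), (1.0,)): y == sin(1.0), y_dot == cos(1.0)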
Example #10
def splitjvp(x):
    _, jvp = linearize(f, x)
    return jvp(np.ones_like(x))
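This fragment captures f and np (here jax.numpy) from its enclosing scope. A self-contained version for illustration, where f = jnp.sin is an assumption rather than part of the original:

import jax
import jax.numpy as jnp

f = jnp.sin  # assumed; the original captures f from its enclosing test

def splitjvp(x):
    _, jvp = jax.linearize(f, x)
    return jvp(jnp.ones_like(x))

# splitjvp(2.0) == jnp.cos(2.0), and jax.jit(splitjvp)(2.0) matches it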