def testScanLinearize(self, jit_scan, jit_f):
  d = np.zeros(2)

  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = np.sum(np.sin(a)) + np.sum(np.sin(c)) + np.sum(np.sin(d))
    c = np.sin(c * b)
    assert b.shape == ()
    return c, b

  if jit_f:
    f = api.jit(f)
  if jit_scan:
    scan = api.jit(lax.scan, (0,))
  else:
    scan = lax.scan

  as_ = np.ones((5, 3))
  c = np.ones(4)

  ans = api.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
  expected = api.linearize(
      lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
  self.assertAllClose(ans, expected, check_dtypes=False)
def testScanLinearize(self, jit_scan, jit_f):
  rng = onp.random.RandomState(0)
  d = rng.randn(2)

  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = np.cos(np.sum(np.sin(a)) + np.sum(np.cos(c)) + np.sum(np.tan(d)))
    c = np.sin(c * b)
    assert b.shape == ()
    return c, b

  if jit_f:
    f = api.jit(f)
  if jit_scan:
    scan = api.jit(lax.scan, (0,))
  else:
    scan = lax.scan

  as_ = rng.randn(5, 3)
  c = rng.randn(4)

  ans = api.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
  expected = api.linearize(
      lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
  self.assertAllClose(ans, expected, check_dtypes=False)
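# `scan_reference` is assumed to be defined elsewhere in this suite as a
# pure-Python reference implementation of lax.scan, used to cross-check the
# primitive's behavior. A minimal sketch of what it might look like
# (illustrative only, not necessarily the exact definition used):
def scan_reference(f, init, xs):
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)  # one step of the scanned function
    ys.append(np.reshape(y, (1,) + onp.shape(y)))
  ys = np.concatenate(ys, axis=0)  # stack per-step outputs along axis 0
  return carry, ys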
def test_issue_871(self):
  T = np.array([[1., 2.], [3., 4.], [5., 6.]])
  x = np.array([1, 2, 3])

  y, f_jvp = api.linearize(np.sum, x)
  jtu.check_raises(lambda: f_jvp(T), ValueError,
                   ("linearized function called on tangent values "
                    "inconsistent with the original primal values."))

  y, f_jvp = api.linearize(api.jit(np.sum), x)
  jtu.check_raises(lambda: f_jvp(T), ValueError,
                   ("linearized function called on tangent values "
                    "inconsistent with the original primal values."))
def test_remat_scan(self):
  to_scan = lambda c, x: (np.sin(c), None)

  def f_noremat(x):
    y, _ = lax.scan(to_scan, x, onp.arange(3.))
    return y

  def f_yesremat(x):
    y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
    return y

  ans = f_yesremat(4.)
  expected = f_noremat(4.)
  self.assertAllClose(ans, expected, check_dtypes=False)

  ans = api.grad(f_yesremat)(4.)
  expected = api.grad(f_noremat)(4.)
  self.assertAllClose(ans, expected, check_dtypes=False)

  # under remat, the cos needed for the tangent of sin should be recomputed
  # inside the scanned jaxpr rather than saved as a residual, so it must
  # appear in the scan equation of the linearized and vjp'd functions
  jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
  scan_eqn, = jaxpr.jaxpr.eqns
  self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

  jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
  scan_eqn, = jaxpr.jaxpr.eqns
  self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
def test_remat_basic(self):
  @api.remat
  def g(x):
    return lax.sin(lax.sin(x)), 3.

  def f(x):
    x, _ = g(x)
    return x

  ans = f(2.)
  expected = onp.sin(onp.sin(2.))
  self.assertAllClose(ans, expected, check_dtypes=False)

  ans, f_lin = api.linearize(f, 2.)
  expected = onp.sin(onp.sin(2.))
  self.assertAllClose(ans, expected, check_dtypes=False)

  ans = f_lin(3.)
  expected = onp.cos(onp.sin(2.)) * onp.cos(2.) * 3.
  self.assertAllClose(ans, expected, check_dtypes=False)

  sin_calls = []
  cos_calls = []
  sin_impl = lax.sin_p.impl
  cos_impl = lax.cos_p.impl
  try:
    lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
    lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
    f_lin(3.)
  finally:
    lax.sin_p.def_impl(sin_impl)
    lax.cos_p.def_impl(cos_impl)
  self.assertEqual(len(sin_calls), 1)
  self.assertEqual(len(cos_calls), 2)
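# Why one sin and two cos in the counts above: the tangent of sin(sin(x)) at
# x = 2. is cos(sin(2.)) * cos(2.) * t, which takes two cos evaluations plus
# the primal intermediate sin(2.). Under remat that intermediate is not saved
# by linearization, so f_lin rebuilds it, giving exactly one sin call.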
def test_remat_freevars(self):
  def f1(x):
    y = 2 * np.sin(x)
    z = np.cos(x) * np.sin(y)
    return z

  def f2(x):
    y = 2 * np.sin(x)
    z = api.remat(lambda x: np.cos(x) * np.sin(y))(x)
    return z

  ans, f_lin = api.linearize(f2, 2.)
  expected, f_lin_expected = api.linearize(f1, 2.)
  self.assertAllClose(ans, expected, check_dtypes=False)

  ans = f_lin(3.)
  expected = f_lin_expected(3.)
  self.assertAllClose(ans, expected, check_dtypes=False)
def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
  params = primals[:num_consts]
  solution = tuple(root_p.bind(*primals, num_consts=num_consts, jaxpr=jaxpr,
                               solve=solve, tangent_solve=tangent_solve))
  params_dot = tangents[:num_consts]

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxpr)
  f_fixed_params = lambda *solution: f(*(params + solution))
  f_fixed_solution = lambda *params: f(*(params + solution))

  _, rhs = ad.jvp(lu.wrap_init(f_fixed_solution)).call_wrapped(
      params, params_dot)
  _, f_jvp_wrt_solution = api.linearize(f_fixed_params, *solution)
  solution_dot = [-x for x in tangent_solve(f_jvp_wrt_solution, *rhs)]

  return solution, solution_dot
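# Worked scalar instance of the formula above (illustrative helper, not part
# of the implementation): take F(m, u) = u**2 - m, whose solution is
# u*(m) = sqrt(m), so we expect du*/dm = 1 / (2 sqrt(m)).
def _check_root_jvp_example():
  m, m_dot = 4.0, 1.0
  u = onp.sqrt(m)       # the solution u*(m)
  rhs = -m_dot          # ∂_0 F(m, u*(m))[m_dot] = -m_dot
  x = rhs / (2.0 * u)   # tangent_solve: invert ∂_1 F, i.e. multiplication by 2u
  u_dot = -x            # the negation applied to tangent_solve's output
  assert onp.allclose(u_dot, m_dot / (2.0 * onp.sqrt(m)))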
def testScanRnn(self):
  r = npr.RandomState(0)

  n_in = 4
  n_hid = 2
  n_out = 1
  length = 3

  W_trans = r.randn(n_hid, n_hid + n_in)
  W_out = r.randn(n_out, n_hid + n_in)
  params = W_trans, W_out

  inputs = r.randn(length, n_in)
  targets = r.randn(length, n_out)

  def step(params, state, input):
    W_trans, W_out = params
    stacked = np.concatenate([state, input])
    output = np.tanh(np.dot(W_out, stacked))
    next_state = np.tanh(np.dot(W_trans, stacked))
    return next_state, output

  def rnn(params, inputs):
    init_state = np.zeros(n_hid)
    _, outputs = lax.scan(partial(step, params), init_state, inputs)
    return outputs

  def loss(params, inputs, targets):
    predictions = rnn(params, inputs)
    return np.sum((predictions - targets)**2)

  # evaluation doesn't crash
  loss(params, inputs, targets)

  # jvp evaluation doesn't crash
  api.jvp(lambda params: loss(params, inputs, targets), (params,), (params,))

  # jvp numerical check passes
  jtu.check_grads(loss, (params, inputs, targets), order=2, modes=["fwd"])

  # linearize works
  _, expected = api.jvp(loss, (params, inputs, targets),
                        (params, inputs, targets))
  _, linfun = api.linearize(loss, params, inputs, targets)
  ans = linfun(params, inputs, targets)
  self.assertAllClose(ans, expected, check_dtypes=False)

  # gradient evaluation doesn't crash
  api.grad(loss)(params, inputs, targets)

  # gradient check passes
  jtu.check_grads(loss, (params, inputs, targets), order=2)

  # we can vmap to batch things
  batch_size = 7
  batched_inputs = r.randn(batch_size, length, n_in)
  batched_targets = r.randn(batch_size, length, n_out)
  batched_loss = api.vmap(lambda x, y: loss(params, x, y))
  losses = batched_loss(batched_inputs, batched_targets)
  expected = onp.stack(list(map(lambda x, y: loss(params, x, y),
                                batched_inputs, batched_targets)))
  self.assertAllClose(losses, expected, check_dtypes=False)
def jvp_unlinearized(f, primals, tangents):
  out, jvp = linearize(f, *primals)
  return out, jvp(*tangents)
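# Minimal usage sketch: jvp_unlinearized should agree with a direct jvp. For
# example, with f = np.sin at primal 3. and tangent 1., both routes give
# (sin(3.), cos(3.)):
#   out, tangent = jvp_unlinearized(np.sin, (3.,), (1.,))
#   out_ref, tangent_ref = api.jvp(np.sin, (3.,), (1.,))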
def splitjvp(x):
  _, jvp = linearize(f, x)
  return jvp(np.ones_like(x))
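# Usage note (assumed context: `f` is closed over from the enclosing test and
# `linearize` is api.linearize): splitjvp pushes an all-ones tangent through
# the linearization of f at x, and composes with other transformations, e.g.
# api.jit(splitjvp)(x) or api.vmap(splitjvp) over a batch of inputs.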