def testJacobianIssue54(self): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): return jnp.array(list(xs)) xs = jnp.ones((5, 1)) jacrev(func)(xs) # don't crash jacfwd(func)(xs) # don't crash
def jacob_constr_blocks(q, x_obs_seq, partition=0):
    """Return non-zero blocks of constraint function Jacobian.

    Input state q can be decomposed into q = [u, v₀, v₁, v₂] where global
    latent state (parameters) are determined by u, initial subsequence by v₀,
    middle subsequences by v₁ and final subsequence by v₂.

    Constraint function can then be decomposed as

        c(q) = [c₀(u, v₀), c₁(u, v₁), c₂(u, v₂)]

    Constraint Jacobian ∂c(q) has block structure

        ∂c(q) = [[∂₀c₀(u, v₀), ∂₁c₀(u, v₀), 0,           0          ],
                 [∂₀c₁(u, v₁), 0,           ∂₁c₁(u, v₁), 0          ],
                 [∂₀c₂(u, v₂), 0,           0,           ∂₁c₂(u, v₂)]]

    Args:
        q: Flat state vector. It is split into u (first ``dim_z`` entries),
            v₀ (next ``dim_x`` entries) and the remainder, which is reshaped
            to rows of length ``dim_v``. ``dim_z``, ``dim_x`` and ``dim_v``
            are names defined outside this function.
        x_obs_seq: Observation sequence, forwarded to
            ``partition_into_subseqs``.
        partition: Partition index, forwarded to ``partition_into_subseqs``.

    Returns:
        A pair ``(jacs_u, jacs_v)``. ``jacs_u`` holds the three blocks
        (∂₀c₀, ∂₀c₁, ∂₀c₂) and ``jacs_v`` holds (∂₁c₀, ∂₁c₁, ∂₁c₂). The
        middle blocks are computed under ``vmap``, so they carry a leading
        axis over the middle subsequences.
    """
    def g_y_bar(u, v, w_0, b):
        # Constraint value for subsequence block b as a function of (u, v).
        z = generate_z(u)
        if b == 0:
            # For the initial block, v is [v₀, flattened first subsequence]
            # (see v_bars below). Its leading dim_x entries become w_0.
            w_0, v = split(v, (dim_x,))
        v_seq = np.reshape(v, (-1, dim_v))
        return generate_y_bar(z, w_0, v_seq, b)

    u, v_0, v_seq_flat = split(q, (dim_z, dim_x))
    v_seq = np.reshape(v_seq_flat, (-1, dim_v))
    # The third item returned (y_bars) is not needed for the Jacobian blocks.
    v_subseqs, w_inits, _ = partition_into_subseqs(
        v_seq, v_0, x_obs_seq, partition)
    # Flatten each block's inputs so g_y_bar is differentiated w.r.t. a
    # vector. The middle block keeps one flattened row per subsequence so
    # that it can be vmapped over.
    v_bars = (
        np.concatenate([v_0, v_subseqs[0].flatten()]),
        np.reshape(v_subseqs[1], (v_subseqs[1].shape[0], -1)),
        v_subseqs[2].flatten(),
    )
    # Differentiate w.r.t. u (arg 0) and v (arg 1). w_0 and b are held fixed.
    jac_g_y_bar = api.jacrev(g_y_bar, (0, 1))
    jacob_funcs = (
        jac_g_y_bar,
        # Middle subsequences: u and b are shared, v and w_0 are mapped.
        api.vmap(jac_g_y_bar, (None, 0, 0, None)),
        jac_g_y_bar,
    )
    # Each call yields (∂₀c_b, ∂₁c_b). Transpose the per-block pairs into
    # ((∂₀c₀, ∂₀c₁, ∂₀c₂), (∂₁c₀, ∂₁c₁, ∂₁c₂)).
    return tuple(zip(*[
        jac_fn(u, v_bars[b], w_inits[b], b)
        for b, jac_fn in enumerate(jacob_funcs)
    ]))
def test_remat_vmap(self):
    """Check that an ``api.remat``-wrapped function composes with vmap,
    forward-mode and reverse-mode Jacobians."""
    @api.remat
    def sin_sin(v):
        return lax.sin(lax.sin(v))

    xs = onp.arange(3.)

    # vmap: elementwise sin(sin(x)).
    self.assertAllClose(api.vmap(sin_sin)(xs), onp.sin(onp.sin(xs)),
                        check_dtypes=False)

    # Both Jacobian modes should give the diagonal chain-rule derivative
    # cos(sin(x)) * cos(x).
    expected_jac = onp.diag(onp.cos(onp.sin(xs)) * onp.cos(xs))
    for jac_transform in (api.jacfwd, api.jacrev):
        self.assertAllClose(jac_transform(sin_sin)(xs), expected_jac,
                            check_dtypes=False)