Example #1
0
  def testJacobianIssue54(self):
    # test modeling the code in https://github.com/google/jax/issues/54

    def func(xs):
      return jnp.array(list(xs))

    xs = jnp.ones((5, 1))
    jacrev(func)(xs)  # don't crash
    jacfwd(func)(xs)  # don't crash
Example #2
0
        def jacob_constr_blocks(q, x_obs_seq, partition=0):
            """Return non-zero blocks of constraint function Jacobian.

            Input state q can be decomposed into q = [u, v₀, v₁, v₂]
            where global latent state (parameters) are determined by u,
            initial subsequence by v₀, middle subsequences by v₁ and final
            subsequence by v₂.

            Constraint function can then be decomposed as

                c(q) = [c₀(u, v₀), c₁(u, v₁), c₂(u, v₂)]

            Constraint Jacobian ∂c(q) has block structure

                ∂c(q) = [[∂₀c₀(u, v₀), ∂₁c₀(u, v₀),     0      ,     0      ]
                         [∂₀c₁(u, v₁),     0      , ∂₁c₁(u, v₁),     0      ]
                         [∂₀c₂(u, v₂),     0      ,     0      , ∂₁c₂(u, v₂)]]

            where ∂₀ denotes the derivative with respect to u and ∂₁ the
            derivative with respect to the subsequence-specific variables.

            Args:
                q: Flat state vector. It is divided by the helper `split`
                    into `u`, `v₀` and a remainder using `(dim_z, dim_x)`
                    (presumably block sizes; confirm against `split`). The
                    remainder is reshaped into rows of length `dim_v`.
                x_obs_seq: Observation sequence, forwarded unchanged to
                    `partition_into_subseqs`.
                partition: Forwarded to `partition_into_subseqs`; presumably
                    selects how the sequence is divided into subsequences.
                    TODO: confirm against that helper.

            Returns:
                A pair `(blocks_u, blocks_v)`. `blocks_u` holds the three ∂₀
                blocks (∂₀c₀, ∂₀c₁, ∂₀c₂) and `blocks_v` the three ∂₁ blocks
                (∂₁c₀, ∂₁c₁, ∂₁c₂). The middle (b == 1) blocks carry a
                leading axis over the middle subsequences, added by `vmap`.
            """
            def g_y_bar(u, v, w_0, b):
                z = generate_z(u)
                if b == 0:
                    # Initial subsequence: the initial state is not passed in
                    # separately. It is taken from the first `dim_x` entries
                    # of `v` (v₀ is concatenated in front in `v_bars` below),
                    # so the ∂/∂v block for b == 0 also covers v₀.
                    w_0, v = split(v, (dim_x, ))
                v_seq = np.reshape(v, (-1, dim_v))
                return generate_y_bar(z, w_0, v_seq, b)

            u, v_0, v_seq_flat = split(q, (
                dim_z,
                dim_x,
            ))
            v_seq = np.reshape(v_seq_flat, (-1, dim_v))
            (v_subseqs, w_inits,
             y_bars) = partition_into_subseqs(v_seq, v_0, x_obs_seq, partition)
            # Per-partition arguments for g_y_bar: the initial block prepends
            # v₀, the middle subsequences keep one row per subsequence (so
            # they can be mapped over), and the final block is flattened.
            v_bars = (np.concatenate([v_0, v_subseqs[0].flatten()]),
                      np.reshape(v_subseqs[1], (v_subseqs[1].shape[0], -1)),
                      v_subseqs[2].flatten())
            # Jacobians with respect to u (argnum 0) and v (argnum 1).
            jac_g_y_bar = api.jacrev(g_y_bar, (0, 1))
            # Middle subsequences are handled in one batched call: map over
            # axis 0 of v and w_0, while u and b are shared.
            jacob_funcs = (jac_g_y_bar,
                           api.vmap(jac_g_y_bar,
                                    (None, 0, 0, None)), jac_g_y_bar)
            # Each call returns (∂/∂u, ∂/∂v); zip(*) regroups the three pairs
            # into (all ∂/∂u blocks, all ∂/∂v blocks).
            return tuple(
                zip(*[
                    jacob_funcs[b](u, v_bars[b], w_inits[b], b)
                    for b in range(3)
                ]))
Example #3
0
  def test_remat_vmap(self):
    """Check that api.remat composes with vmap, jacfwd and jacrev.

    The rematerialized function computes sin(sin(x)). Its batched value and
    its forward- and reverse-mode Jacobians are compared against NumPy
    reference values.
    """
    @api.remat
    def nested_sin(x):
      return lax.sin(lax.sin(x))

    inputs = onp.arange(3.)

    # Mapping an elementwise function gives the elementwise result.
    self.assertAllClose(api.vmap(nested_sin)(inputs),
                        onp.sin(onp.sin(inputs)),
                        check_dtypes=False)

    # The function acts elementwise, so its Jacobian is diagonal. By the
    # chain rule the diagonal entries are cos(sin(x)) * cos(x).
    diagonal_jacobian = onp.diag(onp.cos(onp.sin(inputs)) * onp.cos(inputs))
    for jacobian_transform in (api.jacfwd, api.jacrev):
      self.assertAllClose(jacobian_transform(nested_sin)(inputs),
                          diagonal_jacobian,
                          check_dtypes=False)