コード例 #1
0
ファイル: api_test.py プロジェクト: mitghi/jax
  def test_value_and_grad_argnums(self):
    """value_and_grad respects `argnums` and passes keyword args through."""
    def f(x, y, z, flag=False):
      # Fails unless the keyword argument reaches the wrapped function.
      assert flag
      return 1.0 * x + 2.0 * y + 3.0 * z

    args = (1.0, 1.0, 1.0)
    expected_value = f(*args, flag=True)
    cases = [
        ({}, 1.0),                          # default argnums: d/dx only
        ({'argnums': 1}, 2.0),
        ({'argnums': (2, 0)}, (3.0, 1.0)),  # tuple argnums -> tuple of grads
    ]
    for vg_kwargs, expected_grad in cases:
      result = api.value_and_grad(f, **vg_kwargs)(*args, flag=True)
      assert result == (expected_value, expected_grad)
コード例 #2
0
ファイル: api_test.py プロジェクト: mitghi/jax
  def test_defvjp_all_const(self):
    """A pullback that ignores its cotangent yields a constant gradient."""
    foo_p = Primitive('foo')

    def foo(x):
      return foo_p.bind(x)

    def fwd_and_vjp(x):
      # Primal output is x**2; the pullback ignores g and always returns 12.
      constant_pullback = lambda g: (12.,)
      return x**2, constant_pullback

    ad.defvjp_all(foo_p, fwd_and_vjp)
    value, gradient = api.value_and_grad(foo)(3.)
    self.assertAllClose(value, 9., check_dtypes=False)
    self.assertAllClose(gradient, 12., check_dtypes=True)
コード例 #3
0
ファイル: api_test.py プロジェクト: mitghi/jax
  def test_defvjp_all(self):
    """defvjp_all on a primitive composes with an outer scalar factor."""
    foo_p = Primitive('foo')

    def foo(x):
      # The outer factor 2 scales both the value and the gradient.
      return 2. * foo_p.bind(x)

    def fwd_and_vjp(x):
      return x**2, lambda g: (4 * g * np.sin(x),)

    ad.defvjp_all(foo_p, fwd_and_vjp)
    value, gradient = api.value_and_grad(foo)(3.)
    expected_value = 2 * 3.**2
    expected_grad = 4 * 2 * onp.sin(3.)
    self.assertAllClose(value, expected_value, check_dtypes=False)
    self.assertAllClose(gradient, expected_grad, check_dtypes=False)
コード例 #4
0
ファイル: api_test.py プロジェクト: yyht/jax
    def test_defvjp_all(self):
        """defvjp_all replaces the VJP of a custom_transforms function."""

        def custom_fwd(x):
            # Primal output matches foo, but the pullback returns g * x rather
            # than the true derivative, so the test can tell it was used.
            return np.sin(x), lambda g: (g * x, )

        @api.custom_transforms
        def foo(x):
            return np.sin(x)

        api.defvjp_all(foo, custom_fwd)
        value, gradient = api.value_and_grad(foo)(3.)
        self.assertAllClose(value, onp.sin(3.), check_dtypes=False)
        self.assertAllClose(gradient, 3., check_dtypes=False)
コード例 #5
0
ファイル: api_test.py プロジェクト: mitghi/jax
  def test_defvjp_use_ans(self):
    """A defvjp rule may read the primal output `ans`."""
    @api.custom_transforms
    def foo(x, y):
      return np.sin(x * y)

    def vjp_wrt_y(g, ans, x, y):
      # Deliberately non-standard rule so the test can tell it was used.
      return g * x * y + np.cos(ans)

    # None is passed for x's rule; only the y rule is customised.
    api.defvjp(foo, None, vjp_wrt_y)
    value, gradient = api.value_and_grad(foo, 1)(3., 4.)
    product = 3. * 4.
    self.assertAllClose(value, onp.sin(product), check_dtypes=False)
    self.assertAllClose(gradient, product + onp.cos(onp.sin(3. * 4)),
                        check_dtypes=False)
コード例 #6
0
 def update(i, opt_state, batch, rng, no_batch):
     """Compute the `elbo` value and gradient, then apply `opt_update`.

     Args:
         i: Step index forwarded to `opt_update`.
         opt_state: Optimiser state; parameters are read via `get_params`.
         batch: Data batch passed to `elbo`.
         rng: Passed to `elbo` (presumably a PRNG key -- TODO confirm).
         no_batch: Unused in this function.

     Returns:
         Tuple ``(new_opt_state, loss)`` where ``loss`` is the value of
         `elbo` at the current parameters.
     """
     params = get_params(opt_state)
     # The remaining positional `elbo` inputs come from the enclosing scope.
     elbo_args = (params, batch, predict_apply_fn, rng, kl_scale, loss_fn)
     # Original note: loss is scalar, gradient (1, hidden_dim) -- TODO confirm.
     loss_value, grads = value_and_grad(elbo)(*elbo_args, noise_std=noise_std)
     return opt_update(i, grads, opt_state), loss_value
コード例 #7
0
ファイル: api_test.py プロジェクト: yyht/jax
    def test_defvjp(self):
        """Custom VJP for y only; the gradient w.r.t. x comes out as zero."""

        @api.custom_transforms
        def foo(x, y):
            return np.sin(x * y)

        def y_rule(g, _, x, y):
            # `_` is the primal output, unused by this rule.
            return g * x * y

        api.defvjp(foo, None, y_rule)

        value, grad_x = api.value_and_grad(foo)(3., 4.)
        self.assertAllClose(value, onp.sin(3. * 4.), check_dtypes=False)
        self.assertAllClose(grad_x, 0., check_dtypes=False)

        grad_x, grad_y = api.grad(foo, (0, 1))(3., 4.)
        self.assertAllClose(grad_x, 0., check_dtypes=False)
        self.assertAllClose(grad_y, 3. * 4., check_dtypes=False)
コード例 #8
0
 def __init__(self, forward, backward, domain, image,
              val_and_grad_forward=None):
     """Store the forward/backward maps together with domain and image.

     Args:
         forward: Forward function, stored as ``self._forward``.
         backward: Backward function, stored as ``self._backward``.
         domain: Stored as ``self.domain``.
         image: Stored as ``self.image``.
         val_and_grad_forward: Optional callable returning value and gradient
             of ``forward``. When omitted it is built with
             ``api.value_and_grad(forward)``.
     """
     self._forward, self._backward = forward, backward
     self.domain, self.image = domain, image
     self._val_and_grad_forward = (
         api.value_and_grad(forward) if val_and_grad_forward is None
         else val_and_grad_forward)
コード例 #9
0
  def test_custom_root_scalar(self):
    """lax.custom_root: value, gradient and jit agree with x ** 1.5."""

    # TODO(shoyer): Figure out why this fails and re-enable it, if possible. My
    # best guess is that TPUs use less stable numerics for pow().
    if jtu.device_under_test() == "tpu":
      raise SkipTest("Test fails on TPU")

    def scalar_solve(f, y):
      # Assumes f is a scalar linear map, so f(1.0) is its slope.
      return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
      del x0  # unused

      def keep_going(bounds):
        lo, hi = bounds
        return hi - lo > tolerance

      def halve(bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_lower = func(mid) > 0
        return (np.where(go_lower, lo, mid), np.where(go_lower, mid, hi))

      solution, _ = lax.while_loop(keep_going, halve, (low, high))
      return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      residual = lambda y: y ** 2. - np.array(x) ** 3.
      return lax.custom_root(residual, 0.0, binary_search, tangent_solve)

    value, grad = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False, rtol=1e-6)
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False,
                        rtol=1e-7)
    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    # TODO(shoyer): reenable when batching works
    # inputs = np.array([4.0, 5.0])
    # results = api.vmap(sqrt_cubed)(inputs)
    # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    jitted_result = api.jit(sqrt_cubed)(5.0)
    self.assertAllClose(jitted_result, 5.0 ** 1.5, check_dtypes=False,
                        rtol={onp.float64:1e-7})
コード例 #10
0
ファイル: api_test.py プロジェクト: mitghi/jax
  def test_defvjp_all_multiple_arguments(self):
    """defvjp_all on a two-input primitive, differentiated w.r.t. one or both.

    Differentiating only w.r.t. the first argument also exercises passing a
    symbolic zero tangent for the second.
    """
    foo_p = Primitive('foo')

    def foo(x, y):
      return foo_p.bind(x, y)

    def vjpfun(x, y):
      primal_out = x**2 + y**3

      def pullback(g):
        # One cotangent per primal input, in (x, y) order.
        return g + x + y, g * x * 9.

      return primal_out, pullback

    ad.defvjp_all(foo_p, vjpfun)

    value, grad_x = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(value, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_x, 1. + 3. + 4., check_dtypes=False)

    both_grads = api.grad(foo, (0, 1))(3., 4.)
    expected = (1. + 3. + 4., 1. * 3. * 9.)
    self.assertAllClose(both_grads, expected, check_dtypes=False)
コード例 #11
0
  def test_root_scalar(self):
    """lax.root: value, gradient and jit agree with x ** 1.5."""

    def scalar_solve(f, y):
      # Assumes f is a scalar linear map, so f(1.0) is its slope.
      return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
      del x0  # unused

      def keep_going(bounds):
        lo, hi = bounds
        return hi - lo > tolerance

      def halve(bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_lower = func(mid) > 0
        return (np.where(go_lower, lo, mid), np.where(go_lower, mid, hi))

      solution, _ = lax.while_loop(keep_going, halve, (low, high))
      return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      residual = lambda y: y ** 2 - x ** 3
      return lax.root(residual, 0.0, binary_search, tangent_solve)

    value, grad = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    # TODO(shoyer): reenable when batching works
    # inputs = np.array([4.0, 5.0])
    # results = api.vmap(sqrt_cubed)(inputs)
    # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    jitted_result = api.jit(sqrt_cubed)(5.0)
    self.assertAllClose(jitted_result, 5.0 ** 1.5, check_dtypes=False)
コード例 #12
0
    def test_custom_implicit_solve(self):
        """_custom_implicit_solve differentiates through a bisection solver."""

        def scalar_solve(f, y):
            # Assumes f is a scalar linear map, so f(1.0) is its slope.
            return y / f(1.0)

        def _binary_search(func, params, low=0.0, high=100.0, tolerance=1e-6):
            def unconverged(bounds):
                lo, hi = bounds
                return hi - lo > tolerance

            def step(bounds):
                lo, hi = bounds
                mid = 0.5 * (lo + hi)
                go_lower = func(mid, params) > 0
                return (np.where(go_lower, lo, mid),
                        np.where(go_lower, mid, hi))

            solution, _ = lax.while_loop(unconverged, step, (low, high))
            return solution

        def scalar_solve2(f, y):
            # Same 1-D solve as scalar_solve, but through the Jacobian of f
            # and a dense linear solve on a length-1 vector.
            y_1d = y[np.newaxis]
            return np.linalg.solve(api.jacobian(f)(y_1d), y_1d).squeeze()

        sqrt_cubed = lambda y, x: y**2 - x**3

        solver = api._custom_implicit_solve(_binary_search, scalar_solve)
        value, grad = api.value_and_grad(solver, argnums=1)(sqrt_cubed, 5.0)
        self.assertAllClose(value, 5**1.5, check_dtypes=False)
        self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

        solver = api._custom_implicit_solve(_binary_search, scalar_solve2)
        grad = api.grad(solver, argnums=1)(sqrt_cubed, 5.0)
        self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)
コード例 #13
0
elif args.optimizer == 'momentum':
    opt_init, opt_apply, get_params = myopt.momentum(
        lr, args.momentum, weight_decay=args.weight_decay)
elif args.optimizer == 'adagrad':
    opt_init, opt_apply, get_params = optimizers.adagrad(lr, args.momentum)
elif args.optimizer == 'adam':
    opt_init, opt_apply, get_params = optimizers.adam(lr)

state = opt_init(params)

# Select the per-batch loss on network outputs `fx` and targets `y`.
if args.loss == 'logistic':
    def loss(fx, y):
        # Mean cross-entropy of softmax(fx) against target rows `y`
        # (presumably one-hot, cf. `accuracy_sum` below -- TODO confirm).
        return np.mean(-np.sum(logsoftmax(fx) * y, axis=1))
elif args.loss == 'squared':
    def loss(fx, y):
        # Mean over the batch of the per-example summed squared error.
        return np.mean(np.sum((fx - y)**2, axis=1))
else:
    # Fail here with a clear message instead of a NameError on `loss` the
    # first time a jitted loss function is called.
    raise ValueError(f'Unknown loss: {args.loss!r}')


def _loss_of_params(params, x, y):
    """Loss of the network `f` with `params` on inputs `x`, targets `y`."""
    return loss(f(params, x), y)


# Both jitted entry points share one objective definition.
value_and_grad_loss = jit(value_and_grad(_loss_of_params))
loss_fn = jit(_loss_of_params)
# Number of examples whose argmax prediction matches the argmax target.
accuracy_sum = jit(
    lambda fx, y: np.sum(np.argmax(fx, axis=1) == np.argmax(y, axis=1)))

# Create tensorboard writer
writer = SummaryWriter(logdir=args.logdir)

# Train the network
global_step, running_count = 0, 0
running_loss, running_loss_g = 0., 0.
if args.save_path is not None:
    # Save the initial parameters as step 0.
    save_path = os.path.join(args.save_path, f'{global_step}.npy')
    save_jax_params(params, save_path)
    test_logits = onp.zeros((n_test, args.num_classes), dtype=onp.float32)
    test_loader = tfds.as_numpy(test_data.batch(args.batch_size_test))
コード例 #14
0
    compute_score=True,
    tol=1e-3 * tau,
    alpha_1=hyper_prior[0],
    alpha_2=hyper_prior[1],
    threshold_lambda=1e6,
)
# Fit the reference regressor `reg` configured by the call just above (its
# head is outside this view).
reg.fit(theta_normed, dt)

# %%
# Run SBL on the same data with the same hyper_prior and tolerance
# (1e-3 * tau) as the reference fit, for comparison.
loss, mn, prior, metric = SBL(theta_normed, dt, beta_prior=hyper_prior, tol=1e-3 * tau)
# %%
# hyper_prior = (n_samples / 2, n_samples / 2 * 1 / (1e-6))

# Differentiate SBL's first output w.r.t. the first argument (theta_normed,
# the default argnums=0). Note tol here is 1e-3, not 1e-3 * tau as above.
# NOTE(review): value_and_grad needs a scalar output; whether SBL(...)[0] is
# scalar is not visible here. The next cell adds .sum(), which suggests it may
# not be -- TODO confirm.
hyper_prior = (1e-6, 1e-6)
f = lambda x, y: SBL(x, y, beta_prior=hyper_prior, tol=1e-3)[0]
grad_fn = value_and_grad(f)
grad_fn(theta_normed, dt)
# %%

# %%
# Same as the previous cell, but reduces SBL's first output with .sum() so
# the differentiated function returns a scalar.
hyper_prior = (1e-6, 1e-6)
f = lambda x, y: SBL(x, y, beta_prior=hyper_prior, tol=1e-3)[0].sum()
grad_fn = value_and_grad(f)
grad_fn(theta_normed, dt)
# %%

# Initial prior vector: one entry of 1.0 per feature, followed by a single
# entry 1 / (var(y) + 1e-7); the 1e-7 avoids division by zero when var(y) == 0.
# NOTE(review): `theta` and `y` come from outside this view, whereas the cells
# above use `theta_normed` and `dt` -- confirm these are the intended inputs.
n_samples, n_features = theta.shape
prior_init = jnp.concatenate(
    [jnp.ones((n_features,)), (1.0 / (jnp.var(y) + 1e-7))[jnp.newaxis]], axis=0
)
# %%
コード例 #15
0
    def __init__(self,
                 obs_interval,
                 num_steps_per_obs,
                 num_obs_per_subseq,
                 y_seq,
                 dim_z,
                 dim_x,
                 dim_v,
                 forward_func,
                 generate_x_0,
                 generate_z,
                 obs_func,
                 metric=None):
        """Set up a system for a time-discretised diffusion observed at
        equally spaced times, with constraints partitioned into subsequences.

        Args:
            obs_interval (float): Interobservation time interval.
            num_steps_per_obs (int): Number of discrete time steps to simulate
                between each observation time.
            num_obs_per_subseq (int): Average number of observations per
                partitioned subsequence. Must be a factor of `len(y_seq)`.
            y_seq (array): Two-dimensional array containing observations at
                equally spaced time intervals, with first axis of array
                corresponding to observation time index (in order of increasing
                time) and second axis corresponding to dimension of each
                (vector-valued) observation.
            dim_z (int): Dimension of parameter vector `z`.
            dim_x (int): Dimension of state vector `x`.
            dim_v (int): Dimension of noise vector `v` consumed by
                `forward_func` to approximate time step.
            forward_func (Callable[[array, array, array, float], array]):
                Function implementing forward step of time-discretisation of
                diffusion such that `forward_func(z, x, v, δ)`, for parameter
                vector `z`, current state `x` at time `t`, standard normal
                vector `v` and small timestep `δ`, is distributed
                approximately according to `X(t + δ) | X(t) = x, Z = z`.
            generate_x_0 (Callable[[array, array], array]): Generator function
                for the initial state such that `generate_x_0(z, v_0)` for
                parameter vector `z` and standard normal vector `v_0` is
                distributed according to prior distribution on `X(0) | Z = z`.
            generate_z (Callable[[array], array]): Generator function
                for parameter vector such that `generate_z(u)` for standard
                normal vector `u` is distributed according to prior distribution
                on parameter vector `Z`.
            obs_func (Callable[[array], array]): Function mapping from state
                vector `x` at an observation time to the corresponding observed
                vector `y = obs_func(x)`.
            metric (Matrix): Metric matrix representation. Should be either
                `None`, an `mici.matrices.IdentityMatrix` or a
                `mici.matrices.SymmetricBlockDiagonalMatrix` instance, with in
                the latter case the matrix having two blocks on the diagonal,
                the leftmost of size `dim_z x dim_z`, and the rightmost being
                positive diagonal. `None` (the default) is treated as an
                identity metric.

        Raises:
            NotImplementedError: If `metric` is not one of the supported
                forms, or if `num_obs_per_subseq` is not a factor of the number
                of observations `y_seq.shape[0]`.
        """

        # `None` must behave exactly like an explicit identity metric, both
        # here and in `chol_gram_blocks` below. Previously only this branch
        # handled `None`, so `chol_gram_blocks` fell through to the
        # block-diagonal path and hit the never-assigned `metric_2_diag`.
        use_identity_metric = metric is None or isinstance(
            metric, IdentityMatrix)
        if use_identity_metric:
            metric_1 = np.eye(dim_z)
            log_det_sqrt_metric_1 = 0
        elif (isinstance(metric, SymmetricBlockDiagonalMatrix)
              and isinstance(metric.blocks[1], PositiveDiagonalMatrix)):
            metric_1 = metric.blocks[0].array
            log_det_sqrt_metric_1 = metric.blocks[0].log_abs_det_sqrt
            metric_2_diag = metric.blocks[1].diagonal
        else:
            raise NotImplementedError(
                'Only identity and block diagonal metrics with diagonal lower '
                'right block currently supported.')

        num_obs, dim_y = y_seq.shape
        # Length of a single discrete time step.
        δ = obs_interval / num_steps_per_obs
        # The latent vector q is laid out as [u (dim_z), v_0 (dim_x), noise
        # for each of the num_obs * num_steps_per_obs steps (dim_v each)].
        if num_obs % num_obs_per_subseq != 0:
            raise NotImplementedError(
                'Only cases where num_obs_per_subseq is a factor of num_obs '
                'supported.')
        num_subseq = num_obs // num_obs_per_subseq
        # Selects the state at the last step before each observation time.
        obs_indices = slice(num_steps_per_obs - 1, None, num_steps_per_obs)
        num_step_per_subseq = num_obs_per_subseq * num_steps_per_obs
        # Partition 0: `num_subseq` equal-length observation subsequences.
        y_subseqs_p0 = np.reshape(y_seq, (num_subseq, num_obs_per_subseq, -1))
        # Partition 1: half-length first and last subsequences with
        # `num_subseq - 1` full-length subsequences in between.
        y_subseqs_p1 = split(
            y_seq, (num_obs_per_subseq // 2, num_obs - num_obs_per_subseq))
        y_subseqs_p1[1] = np.reshape(
            y_subseqs_p1[1], (num_subseq - 1, num_obs_per_subseq, dim_y))

        super().__init__(neg_log_dens=standard_normal_neg_log_dens,
                         grad_neg_log_dens=standard_normal_grad_neg_log_dens,
                         metric=metric)

        @api.jit
        def step_func(z, x, v):
            # Returns the new state both as scan carry and as scan output.
            x_n = forward_func(z, x, v, δ)
            return (x_n, x_n)

        @api.jit
        def generate_x_obs_seq(q):
            """Simulate the full path from `q` and return observed states."""
            u, v_0, v_seq_flat = split(q, (dim_z, dim_x))
            z = generate_z(u)
            x_0 = generate_x_0(z, v_0)
            v_seq = np.reshape(v_seq_flat, (-1, dim_v))
            _, x_seq = lax.scan(lambda x, v: step_func(z, x, v), x_0, v_seq)
            return x_seq[obs_indices]

        @api.partial(api.jit, static_argnums=(3, ))
        def partition_into_subseqs(v_seq, v_0, x_obs_seq, partition=0):
            """Partition noise increment and observation sequences.

            Partition sequences into either `num_subseq` equally sized
            subsequences (`partition == 0`) or `num_subseq - 1` equally sized
            subsequences plus initial and final 'half' subsequences.

            Returns a triple `(v_subseqs, w_inits, y_bars)`, each itself a
            3-sequence for the (first, middle, last) subsequence groups.
            """
            if partition == 0:
                v_subseqs = v_seq.reshape(
                    (num_subseq, num_step_per_subseq, dim_v))
                v_subseqs = (v_subseqs[0], v_subseqs[1:-1], v_subseqs[-1])
                x_obs_subseqs = x_obs_seq.reshape(
                    (num_subseq, num_obs_per_subseq, dim_x))
                w_inits = (v_0, x_obs_subseqs[:-2, -1], x_obs_subseqs[-2, -1])
                y_bars = (np.concatenate(
                    (y_subseqs_p0[0, :-1].flatten(), x_obs_subseqs[0, -1])),
                          np.concatenate((y_subseqs_p0[1:-1, :-1].reshape(
                              (num_subseq - 2, -1)), x_obs_subseqs[1:-1, -1]),
                                         -1), y_subseqs_p0[-1].flatten())
            else:
                v_subseqs = split(
                    v_seq, ((num_obs_per_subseq // 2) * num_steps_per_obs,
                            num_step_per_subseq * (num_subseq - 1)))
                v_subseqs[1] = v_subseqs[1].reshape(
                    (num_subseq - 1, num_step_per_subseq, dim_v))
                x_obs_subseqs = split(
                    x_obs_seq,
                    (num_obs_per_subseq // 2, num_obs - num_obs_per_subseq))
                x_obs_subseqs[1] = x_obs_subseqs[1].reshape(
                    (num_subseq - 1, num_obs_per_subseq, dim_x))
                w_inits = (v_0,
                           np.concatenate(
                               (x_obs_subseqs[0][-1:], x_obs_subseqs[1][:-1,
                                                                        -1]),
                               0), x_obs_subseqs[1][-1, -1])
                y_bars = (np.concatenate(
                    (y_subseqs_p1[0][:-1].flatten(), x_obs_subseqs[0][-1])),
                          np.concatenate((
                              y_subseqs_p1[1][:, :-1].reshape(
                                  (num_subseq - 1, -1)),
                              x_obs_subseqs[1][:, -1],
                          ), -1), y_subseqs_p1[2].flatten())
            return v_subseqs, w_inits, y_bars

        def generate_y_bar(z, w_0, v_seq, b):
            """Simulate one subsequence and return its constrained outputs.

            For `b == 0` the initial state is generated from noise `w_0`,
            otherwise `w_0` is the initial state itself. For `b == 2` returns
            all flattened observations; otherwise all but the last observation
            followed by the final simulated state.
            """
            x_0 = generate_x_0(z, w_0) if b == 0 else w_0
            _, x_seq = lax.scan(lambda x, v: step_func(z, x, v), x_0, v_seq)
            y_seq = obs_func(x_seq[obs_indices])
            return y_seq.flatten() if b == 2 else np.concatenate(
                (y_seq[:-1].flatten(), x_seq[-1]))

        @api.partial(api.jit, static_argnums=(2, ))
        def constr(q, x_obs_seq, partition=0):
            """Calculate constraint function for current partition."""
            u, v_0, v_seq_flat = split(q, (
                dim_z,
                dim_x,
            ))
            v_seq = v_seq_flat.reshape((-1, dim_v))
            z = generate_z(u)
            (v_subseqs, w_inits,
             y_bars) = partition_into_subseqs(v_seq, v_0, x_obs_seq, partition)
            # Middle subsequences are batched with vmap; z and b are shared.
            gen_funcs = (generate_y_bar,
                         api.vmap(generate_y_bar,
                                  (None, 0, 0, None)), generate_y_bar)
            return np.concatenate([
                (gen_funcs[b](z, w_inits[b], v_subseqs[b], b) -
                 y_bars[b]).flatten() for b in range(3)
            ])

        @api.jit
        def init_objective(q, x_obs_seq, reg_coeff):
            """Optimisation objective to find initial state on manifold.

            Returns `(objective, c)` where `c` is the mismatch between the
            simulated and target state at the end of each observation interval.
            """
            u, v_0, v_seq_flat = split(q, (
                dim_z,
                dim_x,
            ))
            v_subseqs = v_seq_flat.reshape((num_obs, num_steps_per_obs, dim_v))
            z = generate_z(u)
            x_0 = generate_x_0(z, v_0)
            # Each interval starts from the previous target state.
            x_inits = np.concatenate((x_0[None], x_obs_seq[:-1]), 0)

            def generate_final_state(z, v_seq, x_0):
                _, x_seq = lax.scan(lambda x, v: step_func(z, x, v), x_0,
                                    v_seq)
                return x_seq[-1]

            c = api.vmap(generate_final_state, in_axes=(None, 0, 0))(
                z, v_subseqs, x_inits) - x_obs_seq
            return 0.5 * np.mean(c**2) + 0.5 * reg_coeff * np.mean(q**2), c

        @api.partial(api.jit, static_argnums=(2, ))
        def jacob_constr_blocks(q, x_obs_seq, partition=0):
            """Return non-zero blocks of constraint function Jacobian.

            Input state q can be decomposed into q = [u, v₀, v₁, v₂]
            where global latent state (parameters) are determined by u,
            initial subsequence by v₀, middle subsequences by v₁ and final
            subsequence by v₂.

            Constraint function can then be decomposed as

                c(q) = [c₀(u, v₀), c₁(u, v₁), c₂(u, v₂)]

            Constraint Jacobian ∂c(q) has block structure

                ∂c(q) = [[∂₀c₀(u, v₀), ∂₁c₀(u, v₀),     0      ,     0      ]
                         [∂₀c₁(u, v₁),     0      , ∂₁c₁(u, v₁),     0      ]
                         [∂₀c₂(u, v₂),     0      ,     0      , ∂₁c₂(u, v₂)]]

            Returns `(dc_du, dc_dv)`, each a 3-tuple of blocks.
            """
            def g_y_bar(u, v, w_0, b):
                z = generate_z(u)
                if b == 0:
                    # For the first subsequence v also contains the initial
                    # state noise, so it is differentiated with the rest.
                    w_0, v = split(v, (dim_x, ))
                v_seq = np.reshape(v, (-1, dim_v))
                return generate_y_bar(z, w_0, v_seq, b)

            u, v_0, v_seq_flat = split(q, (
                dim_z,
                dim_x,
            ))
            v_seq = np.reshape(v_seq_flat, (-1, dim_v))
            (v_subseqs, w_inits,
             y_bars) = partition_into_subseqs(v_seq, v_0, x_obs_seq, partition)
            v_bars = (np.concatenate([v_0, v_subseqs[0].flatten()]),
                      np.reshape(v_subseqs[1], (v_subseqs[1].shape[0], -1)),
                      v_subseqs[2].flatten())
            jac_g_y_bar = api.jacrev(g_y_bar, (0, 1))
            jacob_funcs = (jac_g_y_bar,
                           api.vmap(jac_g_y_bar,
                                    (None, 0, 0, None)), jac_g_y_bar)
            return tuple(
                zip(*[
                    jacob_funcs[b](u, v_bars[b], w_inits[b], b)
                    for b in range(3)
                ]))

        @api.jit
        def chol_gram_blocks(dc_du, dc_dv):
            """Calculate Cholesky factors of decomposition of Gram matrix.

            Returns `(chol_C, chol_D)` where `chol_D[i]` factors
            D_i = dc_dv[i] M_v⁻¹ dc_dv[i]ᵀ (M_v the lower-right metric block,
            identity when the metric is the identity) and `chol_C` factors
            C = metric_1 + Σ_i dc_du[i]ᵀ D_i⁻¹ dc_du[i].
            """
            if use_identity_metric:
                D = tuple(
                    np.einsum('...ij,...kj', dc_dv[i], dc_dv[i])
                    for i in range(3))
            else:
                m_v = split(
                    metric_2_diag,
                    (dc_dv[0].shape[1], dc_dv[1].shape[0] * dc_dv[1].shape[2]))
                m_v[1] = m_v[1].reshape((dc_dv[1].shape[0], dc_dv[1].shape[2]))
                D = tuple(
                    np.einsum('...ij,...kj', dc_dv[i] /
                              m_v[i][..., None, :], dc_dv[i])
                    for i in range(3))
            chol_D = tuple(nla.cholesky(D[i]) for i in range(3))
            D_inv_dc_du = tuple(
                sla.cho_solve((chol_D[i], True), dc_du[i]) for i in range(3))
            chol_C = nla.cholesky(metric_1 + (
                dc_du[0].T @ D_inv_dc_du[0] +
                np.einsum('ijk,ijl->kl', dc_du[1], D_inv_dc_du[1]) +
                dc_du[2].T @ D_inv_dc_du[2]))
            return chol_C, chol_D

        @api.jit
        def log_det_sqrt_gram_from_chol(chol_C, chol_D):
            """Calculate log-det of Gram matrix from Cholesky factors.

            Uses the matrix determinant lemma: half log-det of the Gram matrix
            equals the log-diagonal sums of the D and C factors minus half the
            log-det of `metric_1`.
            """
            return (sum(
                np.log(np.abs(chol_D[i].diagonal(0, -2, -1))).sum()
                for i in range(3)) + np.log(np.abs(chol_C.diagonal())).sum() -
                    log_det_sqrt_metric_1)

        @api.partial(api.jit, static_argnums=(2, ))
        def log_det_sqrt_gram(q, x_obs_seq, partition=0):
            """Calculate log-determinant of constraint Jacobian Gram matrix.

            Returns the value together with the Jacobian blocks and Cholesky
            factors, so callers can reuse them (e.g. as `has_aux` output).
            """
            dc_du, dc_dv = jacob_constr_blocks(q, x_obs_seq, partition)
            chol_C, chol_D = chol_gram_blocks(dc_du, dc_dv)
            return (log_det_sqrt_gram_from_chol(chol_C, chol_D),
                    ((dc_du, dc_dv), (chol_C, chol_D)))

        @api.jit
        def lmult_by_jacob_constr(dc_du, dc_dv, vct):
            """Left-multiply vector by constraint Jacobian matrix."""
            vct_u, vct_v = split(vct, (dim_z, ))
            j0, j1, j2 = dc_dv[0].shape[1], dc_dv[1].shape[0], dc_dv[2].shape[
                1]
            return (np.vstack((dc_du[0], dc_du[1].reshape(
                (-1, dim_z)), dc_du[2])) @ vct_u + np.concatenate(
                    (dc_dv[0] @ vct_v[:j0],
                     np.einsum('ijk,ik->ij', dc_dv[1],
                               np.reshape(vct_v[j0:-j2], (j1, -1))).flatten(),
                     dc_dv[2] @ vct_v[-j2:])))

        @api.jit
        def rmult_by_jacob_constr(dc_du, dc_dv, vct):
            """Right-multiply vector by constraint Jacobian matrix."""
            vct_parts = split(
                vct,
                (dc_du[0].shape[0], dc_du[1].shape[0] * dc_du[1].shape[1]))
            vct_parts[1] = np.reshape(vct_parts[1], dc_du[1].shape[:2])
            return np.concatenate([
                vct_parts[0] @ dc_du[0] +
                np.einsum('ij,ijk->k', vct_parts[1], dc_du[1]) +
                vct_parts[2] @ dc_du[2], vct_parts[0] @ dc_dv[0],
                np.einsum('ij,ijk->ik', vct_parts[1],
                          dc_dv[1]).flatten(), vct_parts[2] @ dc_dv[2]
            ])

        @api.jit
        def lmult_by_inv_gram(dc_du, dc_dv, chol_C, chol_D, vct):
            """Left-multiply vector by inverse Gram matrix.

            Applies (D + dc_du metric_1⁻¹ dc_duᵀ)⁻¹ via the Woodbury identity,
            using the Cholesky factors of D and C from `chol_gram_blocks`.
            """
            vct_parts = split(
                vct,
                (dc_du[0].shape[0], dc_du[1].shape[0] * dc_du[1].shape[1]))
            vct_parts[1] = np.reshape(vct_parts[1], dc_du[1].shape[:2])
            D_inv_vct = [
                sla.cho_solve((chol_D[i], True), vct_parts[i])
                for i in range(3)
            ]
            dc_du_T_D_inv_vct = sum(
                np.einsum('...jk,...j->k', dc_du[i], D_inv_vct[i])
                for i in range(3))
            C_inv_dc_du_T_D_inv_vct = sla.cho_solve((chol_C, True),
                                                    dc_du_T_D_inv_vct)
            return np.concatenate([
                sla.cho_solve((chol_D[i], True), vct_parts[i] -
                              dc_du[i] @ C_inv_dc_du_T_D_inv_vct).flatten()
                for i in range(3)
            ])

        @api.jit
        def normal_space_component(vct, dc_du, dc_dv, chol_C, chol_D):
            """Compute Jᵀ G⁻¹ J vct for constraint Jacobian J, Gram matrix G."""
            return rmult_by_jacob_constr(
                dc_du, dc_dv,
                lmult_by_inv_gram(dc_du, dc_dv, chol_C, chol_D,
                                  lmult_by_jacob_constr(dc_du, dc_dv, vct)))

        @api.partial(api.jit, static_argnums=(2, 7, 8, 9, 10))
        def quasi_newton_projection(q, x_obs_seq, partition, dc_du_prev,
                                    dc_dv_prev, chol_C_prev, chol_D_prev,
                                    convergence_tol, position_tol,
                                    divergence_tol, max_iters):
            """Iteratively move `q` towards `constr(q) == 0`.

            Each iteration applies q -= Jᵀ G⁻¹ c(q), with the Jacobian blocks
            and Gram factors held fixed at the supplied `*_prev` values.

            Returns the final loop state `(q, num_iters, norm_delta_q, error)`.
            The loop stops on divergence (error above `divergence_tol` or NaN),
            convergence (error below `convergence_tol` and step below
            `position_tol`) or after `max_iters` iterations.
            """

            norm = lambda x: np.max(np.abs(x))

            def body_func(val):
                q, i, _, _ = val
                c = constr(q, x_obs_seq, partition)
                # Error is measured at q before this iteration's update.
                error = norm(c)
                delta_q = rmult_by_jacob_constr(
                    dc_du_prev, dc_dv_prev,
                    lmult_by_inv_gram(dc_du_prev, dc_dv_prev, chol_C_prev,
                                      chol_D_prev, c))
                q -= delta_q
                i += 1
                return q, i, norm(delta_q), error

            def cond_func(val):
                q, i, norm_delta_q, error, = val
                diverged = np.logical_or(error > divergence_tol,
                                         np.isnan(error))
                converged = np.logical_and(error < convergence_tol,
                                           norm_delta_q < position_tol)
                return np.logical_not(
                    np.logical_or((i >= max_iters),
                                  np.logical_or(diverged, converged)))

            # Initial step norm of inf guarantees at least one iteration.
            return lax.while_loop(cond_func, body_func, (q, 0, np.inf, -1.))

        self._generate_x_obs_seq = generate_x_obs_seq
        self._constr = constr
        self._jacob_constr_blocks = jacob_constr_blocks
        self._chol_gram_blocks = chol_gram_blocks
        self._log_det_sqrt_gram_from_chol = log_det_sqrt_gram_from_chol
        self._grad_log_det_sqrt_gram = api.jit(
            api.value_and_grad(log_det_sqrt_gram, has_aux=True), (2, ))
        self.value_and_grad_init_objective = api.jit(
            api.value_and_grad(init_objective, (0, ), has_aux=True))
        self._normal_space_component = normal_space_component
        self.quasi_newton_projection = quasi_newton_projection