Example #1
class SparseCooMatrix:
    """Sparse matrix in COO form."""

    values: hints.Array
    """Non-zero matrix values. Shape should be `(*, N)`."""
    coords: SparseCooCoordinates
    """Row and column indices of non-zero entries. Shapes should be `(*, N)`."""
    shape: Tuple[int, int] = jdc.static_field()
    """Shape of matrix."""

    # Shape checks break under vmap
    # def __post_init__(self):
    #     assert self.coords.rows.shape == self.coords.cols.shape == self.values.shape
    #     assert len(self.shape) == 2

    def __matmul__(self, other: hints.Array):
        """Compute `Ax`, where `x` is a 1D vector."""
        assert other.shape == (
            self.shape[1],
        ), "Matrix-vector products are only supported for 1D vectors!"
        # Scatter-add: each nonzero A[i, j] contributes A[i, j] * x[j] to row i
        # of the output.
        return (
            jnp.zeros(self.shape[0], dtype=other.dtype)
            .at[self.coords.rows]
            .add(self.values * other[self.coords.cols])
        )

    def as_dense(self) -> jnp.ndarray:
        """Convert to a dense JAX array."""
        # TODO: untested
        return (
            jnp.zeros(self.shape)
            .at[self.coords.rows, self.coords.cols]
            .set(self.values)
        )

    @staticmethod
    def from_scipy_coo_matrix(
            matrix: scipy.sparse.coo_matrix) -> "SparseCooMatrix":
        """Build from a sparse scipy matrix."""
        return SparseCooMatrix(
            values=matrix.data,
            coords=SparseCooCoordinates(
                rows=matrix.row,
                cols=matrix.col,
            ),
            shape=matrix.shape,
        )

    def as_scipy_coo_matrix(self) -> scipy.sparse.coo_matrix:
        """Convert to a sparse scipy matrix."""
        return scipy.sparse.coo_matrix(
            (self.values, (self.coords.rows, self.coords.cols)),
            shape=self.shape)

    @property
    def T(self):
        """Return transpose of our sparse matrix."""
        return SparseCooMatrix(
            values=self.values,
            coords=SparseCooCoordinates(
                rows=self.coords.cols,
                cols=self.coords.rows,
            ),
            shape=self.shape[::-1],
        )
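
A quick usage sketch for the class above (a minimal example, assuming `SparseCooMatrix` and the `SparseCooCoordinates` container it references are importable as defined here):

import jax.numpy as jnp
import numpy as onp
import scipy.sparse

# 2x3 matrix with three non-zero entries.
scipy_matrix = scipy.sparse.coo_matrix(
    (onp.array([1.0, 2.0, 3.0]), (onp.array([0, 0, 1]), onp.array([0, 2, 1]))),
    shape=(2, 3),
)
A = SparseCooMatrix.from_scipy_coo_matrix(scipy_matrix)

# `A @ x` scatter-adds `values * x[cols]` into the output rows; it should match
# the dense matrix-vector product.
x = jnp.array([1.0, 1.0, 1.0])
assert jnp.allclose(A @ x, A.as_dense() @ x)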
Example #2
class _FactorBase:
    # For why we have two classes:
    # https://github.com/python/mypy/issues/5374#issuecomment-650656381

    variables: Tuple[VariableBase, ...] = jdc.static_field()
    """Variables connected to this factor. 1-to-1, in-order correspondence with
    `VariableValueTuple`."""

    noise_model: noises.NoiseModelBase
    """Noise model."""
Example #3
class _NonlinearSolverBase:
    # For why we have two classes:
    # https://github.com/python/mypy/issues/5374#issuecomment-650656381
    """Nonlinear solver interface."""

    verbose: Boolean = jdc.static_field(default=True)
    """Set to `True` to enable printing."""

    linear_solver: sparse.LinearSubproblemSolverBase = jdc.field(
        default_factory=lambda: sparse.CholmodSolver()
    )
    """Solver to use for linear subproblems."""


class FixedIterationGaussNewtonSolver(NonlinearSolverBase[NonlinearSolverState]):
    """Alternative version of Gauss-Newton solver, which ignores convergence checks."""

    unroll: bool = jdc.static_field(default=True)
    """Set to `True` to unroll the optimization loop in Python; set to `False`
    to use `jax.lax.while_loop`."""

    # To unroll the optimizer loop, we must have a concrete (static) iteration count.
    iterations: int = jdc.static_field(default=10)
    """Number of iterations to run."""

    @overrides
    def _initialize_state(
        self,
        graph: "StackedFactorGraph",
        initial_assignments: VariableAssignments,
    ) -> NonlinearSolverState:
        # Initialize
        cost, residual_vector = graph.compute_cost(initial_assignments)
        return NonlinearSolverState(
            iterations=0,
            assignments=initial_assignments,
            cost=cost,
            residual_vector=residual_vector,
            done=False,
        )

    @overrides
    def _step(
        self,
        graph: "StackedFactorGraph",
        state_prev: NonlinearSolverState,
    ) -> NonlinearSolverState:
        """Linearize, solve linear subproblem, and update on manifold."""

        self._hcb_print(
            lambda i, cost: f"Iteration #{i}: cost={str(cost)}",
            i=state_prev.iterations,
            cost=state_prev.cost,
        )

        # Linearize graph
        A: sparse.SparseCooMatrix = graph.compute_whitened_residual_jacobian(
            assignments=state_prev.assignments,
            residual_vector=state_prev.residual_vector,
        )
        ATb = -(A.T @ state_prev.residual_vector)

        # Solve linear subproblem
        local_delta_assignments = VariableAssignments(
            storage=self.linear_solver.solve_subproblem(
                A=A,
                ATb=ATb,
                lambd=0.0,
                iteration=state_prev.iterations,
            ),
            storage_layout=graph.local_storage_layout,
        )

        # On-manifold retraction
        assignments = state_prev.assignments.manifold_retract(
            local_delta_assignments=local_delta_assignments,
        )

        # No convergence check here: this solver terminates after a fixed
        # number of iterations.
        cost, residual_vector = graph.compute_cost(assignments)
        done = state_prev.iterations >= (self.iterations - 1)

        return NonlinearSolverState(
            iterations=state_prev.iterations + 1,
            assignments=assignments,
            cost=cost,
            residual_vector=residual_vector,
            done=done,
        )

    @jax.jit
    @overrides
    def solve(
        self,
        graph: "StackedFactorGraph",
        initial_assignments: VariableAssignments,
    ) -> VariableAssignments:
        """Run MAP inference on a factor graph."""

        # Initialize; `_initialize_state` computes the initial cost and
        # residual vector.
        state = self._initialize_state(graph, initial_assignments)

        # Optimization
        if self.unroll:
            # Concrete iteration count: the Python loop is unrolled into the
            # jitted trace.
            for _ in range(self.iterations):
                state = self._step(graph, state)
        else:
            state = jax.lax.while_loop(
                cond_fun=lambda state: jnp.logical_not(state.done),
                body_fun=functools.partial(self._step, graph),
                init_val=state,
            )

        self._hcb_print(
            lambda i, cost:
            f"Terminated @ iteration #{i}: cost={str(cost).ljust(15)}",
            i=state.iterations,
            cost=state.cost,
        )

        return state.assignments
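
The `unroll` flag trades compile time for dispatch overhead: a Python loop bakes every step into the jitted trace, while `jax.lax.while_loop` keeps the compiled program small. A self-contained sketch of the same control-flow pattern, using a toy state and step function rather than this library's types:

import functools

import jax
import jax.numpy as jnp

def step(state):
    # Toy "solver" step: halve the value, bump the iteration counter.
    x, i = state
    return (x * 0.5, i + 1)

@functools.partial(jax.jit, static_argnums=(1, 2))
def run(x, iterations: int, unroll: bool):
    state = (x, 0)
    if unroll:
        # Static iteration count: loop is unrolled at trace time.
        for _ in range(iterations):
            state = step(state)
    else:
        state = jax.lax.while_loop(
            cond_fun=lambda s: s[1] < iterations,
            body_fun=step,
            init_val=state,
        )
    return state[0]

print(run(jnp.array(8.0), 3, True))   # 1.0
print(run(jnp.array(8.0), 3, False))  # 1.0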
Example #5
class VariableAssignments:
    """Storage class that maps variables to values."""

    storage: hints.Array
    """Values of variables stacked and flattened. Can either be local or global
    parameterizations, depending on the value of `.storage_layout.local_flag`."""

    storage_layout: StorageLayout = jdc.static_field()
    """Metadata for how variables are stored."""

    @staticmethod
    def make_from_defaults(variables: Iterable[VariableBase]) -> "VariableAssignments":
        """Create an assignment object from the default values corresponding to each variable."""

        # Figure out how variables are stored
        storage_layout = StorageLayout.make(variables, local=False)

        # Stack variable values in order
        storage = jnp.concatenate(
            [
                jnp.tile(
                    jax.jit(variable_type.flatten)(variable_type.get_default_value()),
                    reps=(storage_layout.count_from_variable_type[variable_type],),
                )
                for variable_type in storage_layout.get_variable_types()
            ],
            axis=0,
        )
        assert storage.shape == (storage_layout.dim,)

        return VariableAssignments(storage=storage, storage_layout=storage_layout)

    @staticmethod
    def make_from_dict(
        assignments: Dict[VariableBase, hints.VariableValue],
    ) -> "VariableAssignments":
        """Create an assignment object from a full set of assignments."""

        return VariableAssignments.make_from_partial_dict(
            assignments.keys(), assignments
        )

    @staticmethod
    def make_from_partial_dict(
        variables: Iterable[VariableBase],
        assignments: Dict[VariableBase, hints.VariableValue],
    ) -> "VariableAssignments":
        """Create an assignment object from a variables and assignments. Missing
        assignments are assigned the default variable values."""

        # Figure out how variables are stored
        storage_layout = StorageLayout.make(variables, local=False)

        # Stack variable values in order
        storage = jnp.concatenate(
            [
                jax.jit(variable.flatten)(
                    assignments[variable]
                    if assignments is not None and variable in assignments
                    else variable.get_default_value()
                )
                for variable in storage_layout.get_variables()
            ],
            axis=0,
        )
        assert storage.shape == (storage_layout.dim,)

        return VariableAssignments(storage=storage, storage_layout=storage_layout)

    @functools.partial(jax.jit, static_argnums=1)
    def update_storage_layout(
        self, storage_layout: StorageLayout
    ) -> "VariableAssignments":
        """Returns a new assignments object representing the same variable->value
        mapping, but with an updated storage layout.

        The primary motivation of this method is that the storage layout of an
        assignments object can sometimes be shuffled with respect to the layout
        expected by a graph (StackedFactorGraph)."""

        # No-op if storage layouts already match.
        if self.storage_layout == storage_layout:
            return self

        assert self.storage_layout.dim == storage_layout.dim
        assert self.storage_layout.local_flag == storage_layout.local_flag
        assert set(self.storage_layout.get_variables()) == set(
            storage_layout.get_variables()
        )
        dim = storage_layout.dim
        variables = storage_layout.get_variables()
        local_flag = storage_layout.local_flag

        shuffle_indices = jnp.zeros(dim, dtype=jnp.int32)
        for variable in variables:
            source_index = self.storage_layout.index_from_variable[variable]
            target_index = storage_layout.index_from_variable[variable]
            variable_dim = (
                variable.get_local_parameter_dim()
                if local_flag
                else variable.get_parameter_dim()
            )
            shuffle_indices = shuffle_indices.at[
                target_index : target_index + variable_dim
            ].set(jnp.arange(source_index, source_index + variable_dim))

        new_storage = self.storage[shuffle_indices]
        assert new_storage.shape == self.storage.shape
        return VariableAssignments(storage=new_storage, storage_layout=storage_layout)

    def as_dict(self) -> Dict[VariableBase, hints.VariableValue]:
        """Grab assignments as a variable -> value dictionary."""
        return {v: self.get_value(v) for v in self.get_variables()}

    def __repr__(self):
        value_from_variable = {
            variable: self.get_value(variable) for variable in self.get_variables()
        }
        k: VariableBase

        contents: str = "\n".join(
            [
                f"    {i}.{k.__class__.__name__}: {v}"
                for i, (k, v) in enumerate(value_from_variable.items())
            ]
        )
        return f"VariableAssignments(\n{contents}\n)"

    def get_variables(self) -> Collection[VariableBase]:
        """Helper for iterating over variables."""
        return self.storage_layout.get_variables()

    def get_value(self, variable: VariableBase[VariableValueType]) -> VariableValueType:
        """Get value corresponding to specific variable."""
        index = self.storage_layout.index_from_variable[variable]
        return type(variable).unflatten(
            self.storage[index : index + variable.get_parameter_dim()]
        )

    def get_stacked_value(
        self, variable_type: Type[VariableBase[VariableValueType]]
    ) -> VariableValueType:
        """Get values of all variables corresponding to a specific type."""
        index = self.storage_layout.index_from_variable_type[variable_type]
        count = self.storage_layout.count_from_variable_type[variable_type]
        return jax.vmap(variable_type.unflatten)(
            self.storage[
                index : index + variable_type.get_parameter_dim() * count
            ].reshape((count, variable_type.get_parameter_dim()))
        )

    def set_value(
        self, variable: VariableBase[VariableValueType], value: VariableValueType
    ) -> "VariableAssignments":
        """Update a value corresponding to specific variable."""

        index = self.storage_layout.index_from_variable[variable]
        with jdc.copy_and_mutate(self) as output:
            output.storage = (
                jnp.asarray(output.storage)  # In case storage vector is an onp array
                .at[index : index + type(variable).get_parameter_dim()]
                .set(type(variable).flatten(value))
            )
        return output

    @jax.jit
    def manifold_retract(
        self, local_delta_assignments: "VariableAssignments"
    ) -> "VariableAssignments":
        """Update variables on manifold."""

        # Check that inputs make sense
        assert not self.storage_layout.local_flag
        assert local_delta_assignments.storage_layout.local_flag

        # On-manifold retractions, one variable type at a time!
        new_storage = jnp.zeros_like(self.storage)
        variable_type: Type[VariableBase]
        for variable_type in self.storage_layout.index_from_variable_type.keys():

            # Get locations
            count = self.storage_layout.count_from_variable_type[variable_type]
            storage_index = self.storage_layout.index_from_variable_type[variable_type]
            local_storage_index = (
                local_delta_assignments.storage_layout.index_from_variable_type[
                    variable_type
                ]
            )
            dim = variable_type.get_parameter_dim()
            local_dim = variable_type.get_local_parameter_dim()

            # Get batched variables
            batched_values_flat = self.storage[
                storage_index : storage_index + dim * count
            ].reshape((count, dim))
            batched_deltas = local_delta_assignments.storage[
                local_storage_index : local_storage_index + local_dim * count
            ].reshape((count, local_dim))

            # Batched variable update
            new_storage = new_storage.at[
                storage_index : storage_index + dim * count
            ].set(
                jax.vmap(variable_type.flatten)(
                    jax.vmap(variable_type.manifold_retract)(
                        jax.vmap(variable_type.unflatten)(batched_values_flat),
                        batched_deltas,
                    )
                ).flatten()
            )

        return jdc.replace(self, storage=new_storage)
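
A standalone sketch of the gather-based relayout in `update_storage_layout` above, with two hypothetical variables: "a" (dim 2), stored first in the source layout, and "b" (dim 3), stored first in the target layout:

import jax.numpy as jnp

storage = jnp.array([1.0, 2.0, 10.0, 20.0, 30.0])  # [a0, a1, b0, b1, b2]

shuffle_indices = jnp.zeros(5, dtype=jnp.int32)
# Target layout puts "b" at indices 0:3, sourced from 2:5 in the old storage...
shuffle_indices = shuffle_indices.at[0:3].set(jnp.arange(2, 5))
# ...and "a" at indices 3:5, sourced from 0:2.
shuffle_indices = shuffle_indices.at[3:5].set(jnp.arange(0, 2))

print(storage[shuffle_indices])  # [10. 20. 30. 1. 2.]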
Example #6
class StackedFactorGraph:
    """Dataclass for vectorized factor graph computations.

    Improves runtime by stacking factors based on their group key.
    """

    factor_stacks: List[FactorStack]
    jacobian_coords: sparse.SparseCooCoordinates
    storage_layout: StorageLayout = jdc.static_field()
    local_storage_layout: StorageLayout = jdc.static_field()
    residual_dim: int = jdc.static_field()

    # Shape checks break under vmap
    # def __post_init__(self):
    #     """Check that inputs make sense!"""
    #     for stacked_factor in self.factor_stacks:
    #         N = stacked_factor.num_factors
    #         for value_indices, variable in zip(
    #             stacked_factor.value_indices,
    #             stacked_factor.factor.variables,
    #         ):
    #             assert value_indices.shape == (N, variable.get_parameter_dim())

    def get_variables(self) -> Collection[VariableBase]:
        return self.local_storage_layout.get_variables()

    @staticmethod
    def make(
        factors: Iterable[FactorBase],
        use_onp: bool = True,
    ) -> "StackedFactorGraph":
        """Create a factor graph from a set of factors."""

        # Start by grouping our factors and grabbing a list of (ordered!) variables
        factors_from_group: DefaultDict[GroupKey, List[FactorBase]] = defaultdict(list)
        variables_ordered_set: Dict[VariableBase, None] = {}
        for factor in factors:
            # Each factor is ultimately just a pytree node; in order for a set of
            # factors to be batchable, they must share the same:
            group_key: GroupKey = (
                # (1) Treedef. Note that variables can be different as long as their
                # types are the same.
                jax.tree_structure(factor.anonymize_variables()),
                # (2) Leaf shapes: contained array shapes must match
                tuple(
                    leaf.shape if hasattr(leaf, "shape") else ()
                    for leaf in jax.tree_leaves(factor)
                ),
            )

            # Record factor and variables
            factors_from_group[group_key].append(factor)
            for v in factor.variables:
                variables_ordered_set[v] = None
        variables = list(variables_ordered_set.keys())

        # Fields we want to populate
        stacked_factors: List[FactorStack] = []
        jacobian_coords: List[sparse.SparseCooCoordinates] = []

        # Create storage layouts: these describe which part of our storage
        # object is allocated to each variable
        storage_layout = StorageLayout.make(variables, local=False)
        local_storage_layout = StorageLayout.make(variables, local=True)

        # Prepare each factor group
        residual_offset = 0
        for group_key, group in factors_from_group.items():
            # Make factor stack
            stacked_factors.append(
                FactorStack.make(
                    group,
                    storage_layout,
                    use_onp=use_onp,
                )
            )

            # Compute Jacobian coordinates
            #
            # These should be N pairs of (row, col) indices, where rows correspond to
            # residual indices and columns correspond to local parameter indices
            jacobian_coords.extend(
                FactorStack.compute_jacobian_coords(
                    factors=group,
                    local_storage_layout=local_storage_layout,
                    row_offset=residual_offset,
                )
            )
            residual_offset += stacked_factors[-1].get_residual_dim()

        jacobian_coords_concat: sparse.SparseCooCoordinates = jax.tree_map(
            lambda *arrays: onp.concatenate(arrays, axis=0), *jacobian_coords
        )

        return StackedFactorGraph(
            factor_stacks=stacked_factors,
            jacobian_coords=jacobian_coords_concat,
            storage_layout=storage_layout,
            local_storage_layout=local_storage_layout,
            residual_dim=residual_offset,
        )

    @jax.jit
    def compute_whitened_residual_vector(
        self, assignments: VariableAssignments
    ) -> jnp.ndarray:
        """Computes flattened+whitened residual vector associated with our factor graph.

        Args:
            assignments (VariableAssignments): Variable assignments.

        Returns:
            jnp.ndarray: Residual vector.
        """

        # Resolve storage layout mismatches. Factor stack computations will raise an
        # assertion error if the storage layout is incorrect.
        assignments = assignments.update_storage_layout(self.storage_layout)

        # Flatten and concatenate residuals from all groups.
        stacked_factor: FactorStack
        residual_vector = jnp.concatenate(
            [
                jax.vmap(
                    type(stacked_factor.factor.noise_model).whiten_residual_vector
                )(
                    stacked_factor.factor.noise_model,
                    stacked_factor.compute_residual_vector(assignments),
                ).flatten()
                for stacked_factor in self.factor_stacks
            ],
            axis=0,
        )
        assert residual_vector.shape == (self.residual_dim,)
        return residual_vector

    @jax.jit
    def compute_cost(
        self, assignments: VariableAssignments
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Compute the sum of squared residuals associated with a factor graph. Also
        returns intermediate (whitened) residual vector.

        Args:
            assignments (VariableAssignments): Variable assignments.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Scalar cost, residual vector.
        """
        residual_vector = self.compute_whitened_residual_vector(assignments)
        cost = jnp.sum(residual_vector**2)
        return cost, residual_vector

    @jax.jit
    def compute_joint_nll(
        self,
        assignments: VariableAssignments,
        include_constants: bool = False,
    ) -> jnp.ndarray:
        """Compute the joint negative log-likelihood density associated with a set of
        variable assignments. Assumes Gaussian noise models.

        Args:
            assignments (VariableAssignments): Variable assignments.
            include_constants (bool): Whether or not to include constant terms.

        Returns:
            jnp.ndarray: Scalar cost negative log-likelihood.
        """

        if include_constants:
            raise NotImplementedError()

        # Add Mahalanobis distance terms to NLL
        joint_nll = jnp.sum(self.compute_whitened_residual_vector(assignments) ** 2)

        # Add log-determinant terms
        for stacked_factor in self.factor_stacks:
            noise_model = stacked_factor.factor.noise_model
            if isinstance(noise_model, noises.Gaussian):
                cov_determinants = -2.0 * jnp.log(
                    jnp.abs(
                        jnp.linalg.det(
                            cast(noises.Gaussian, noise_model).sqrt_precision_matrix
                        )
                    )
                )
            elif isinstance(noise_model, noises.DiagonalGaussian):
                cov_determinants = -2.0 * jnp.log(
                    jnp.abs(
                        jnp.prod(
                            cast(
                                noises.DiagonalGaussian, noise_model
                            ).sqrt_precision_diagonal,
                            axis=-1,
                        )
                    )
                )
            else:
                assert False, f"Joint NLL not supported for {type(noise_model)}"
            assert cov_determinants.shape == (stacked_factor.num_factors,)

            joint_nll = joint_nll + jnp.sum(cov_determinants)

        return joint_nll

    @jax.jit
    def compute_whitened_residual_jacobian(
        self,
        assignments: VariableAssignments,
        residual_vector: hints.Array,
    ) -> sparse.SparseCooMatrix:
        """Compute the Jacobian of a graph's residual vector with respect to the stacked
        local delta vectors. Shape should be `(residual_dim, local_delta_storage_dim)`."""

        # Resolve storage layout mismatches. Factor stack computations will raise an
        # assertion error if the storage layout is incorrect.
        assignments = assignments.update_storage_layout(self.storage_layout)

        # Linearize factors by group.
        A_values_list: List[jnp.ndarray] = []
        residual_start = 0
        for stacked_factor in self.factor_stacks:
            residual_end = residual_start + stacked_factor.get_residual_dim()
            stacked_residual_vector = residual_vector[
                residual_start:residual_end
            ].reshape(
                (
                    stacked_factor.num_factors,
                    stacked_factor.factor.get_residual_dim(),
                )
            )

            # Compute all Jacobians and whiten.
            for jacobian in stacked_factor.compute_residual_jacobian(assignments):
                A_values_list.append(
                    jax.vmap(type(stacked_factor.factor.noise_model).whiten_jacobian)(
                        stacked_factor.factor.noise_model,
                        jacobian,
                        residual_vector=stacked_residual_vector,
                    )
                )
            residual_start = residual_end
        assert residual_end == self.residual_dim

        # Build Jacobian.
        A = sparse.SparseCooMatrix(
            values=jnp.concatenate([A.flatten() for A in A_values_list]),
            coords=self.jacobian_coords,
            shape=(self.residual_dim, self.local_storage_layout.dim),
        )
        return A

    def solve(
        self,
        initial_assignments: VariableAssignments,
        solver: NonlinearSolverBase = GaussNewtonSolver(),
    ) -> VariableAssignments:
        """Solve MAP inference problem."""
        # Note that the solver will handle storage layout mismatches.
        return solver.solve(graph=self, initial_assignments=initial_assignments)
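
The group keys in `make` formalize when two factors can be stacked: their pytrees must share a treedef and leaf shapes. A standalone sketch with hypothetical dict pytrees standing in for factors:

import jax
import jax.numpy as jnp

a = {"scale": jnp.ones(3)}
b = {"scale": jnp.ones(3) * 2.0}
c = {"scale": jnp.ones(4)}  # Same treedef, different leaf shape.

def group_key(pytree):
    return (
        jax.tree_util.tree_structure(pytree),
        tuple(leaf.shape for leaf in jax.tree_util.tree_leaves(pytree)),
    )

assert group_key(a) == group_key(b)  # Batchable.
assert group_key(a) != group_key(c)  # Not batchable.

# Batchable pytrees can be stacked leaf-by-leaf, as in `FactorStack.make`.
stacked = jax.tree_util.tree_map(lambda *arrays: jnp.stack(arrays, axis=0), a, b)
print(stacked["scale"].shape)  # (2, 3)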
Example #7
class FactorStack(Generic[FactorType]):
    """A set of factors, with their parameters stacked."""

    num_factors: int = jdc.static_field()
    factor: FactorType

    value_indices: Tuple[hints.Array, ...]
    """The storage indices corresponding to the flattened value of each factor input."""

    storage_layout: StorageLayout = jdc.static_field()
    """The layout used to compute the value indices."""

    def __post_init__(self):
        # There should be one set of indices for each variable.
        assert len(self.value_indices) == len(self.factor.variables)

    # Shape checks break under vmap
    #     # Check that shapes make sense.
    #     for variable, indices in zip(self.factor.variables, self.value_indices):
    #         residual_dim = self.factor.noise_model.get_residual_dim()
    #         assert indices.shape == (
    #             self.num_factors,
    #             variable.get_parameter_dim(),
    #         )
    #         assert residual_dim == self.factor.get_residual_dim()

    @staticmethod
    def make(
        factors: Sequence[FactorType],
        storage_layout: StorageLayout,
        use_onp: bool,
    ) -> "FactorStack[FactorType]":
        """Make a stacked factor."""

        # For one-off computations, onp has much less overhead than jnp.
        jnp = onp if use_onp else globals()["jnp"]

        # Stack factors in our group.
        # This requires that the treedefs of each factor match, which won't be
        # the case when factors are connected to different variables!
        stacked_factor: FactorType = jax.tree_map(
            lambda *arrays: jnp.stack(arrays, axis=0),
            *map(FactorBase.anonymize_variables, factors),  # type: ignore
            # > https://github.com/python/mypy/issues/1317
        )

        # Get indices for each variable of each factor.
        value_indices_list: Tuple[List[onp.ndarray], ...] = tuple(
            [] for _ in range(len(stacked_factor.variables))
        )
        for factor in factors:
            for i, variable in enumerate(factor.variables):
                assert isinstance(
                    variable, type(factors[0].variables[i])
                ), "Variable types of stacked factors must match"
                storage_pos = storage_layout.index_from_variable[variable]
                value_indices_list[i].append(
                    onp.arange(storage_pos, storage_pos + variable.get_parameter_dim())
                )

        # Stack: end result should be Tuple[array of shape (N, parameter_dim), ...].
        value_indices_stacked: Tuple[onp.ndarray, ...] = tuple(
            onp.array(indices) for indices in value_indices_list
        )

        # Record values.
        return FactorStack(
            num_factors=len(factors),
            factor=stacked_factor,
            value_indices=value_indices_stacked,
            storage_layout=storage_layout,
        )

    @staticmethod
    def compute_jacobian_coords(
        factors: Sequence[FactorType],
        local_storage_layout: StorageLayout,
        row_offset: int,
    ) -> List[sparse.SparseCooCoordinates]:
        """Computes Jacobian coordinates for a factor stack. One array of indices per
        variable."""

        variable_types: List[Type[VariableBase]] = [
            type(v) for v in factors[0].variables
        ]

        # Get indices for each variable.
        local_value_indices_list: Tuple[List[onp.ndarray], ...] = tuple(
            [] for _ in range(len(variable_types))
        )
        for factor in factors:
            for i, variable in enumerate(factor.variables):
                # Record local parameterization indices.
                storage_pos = local_storage_layout.index_from_variable[variable]
                local_value_indices_list[i].append(
                    onp.arange(
                        storage_pos,
                        storage_pos + variable.get_local_parameter_dim(),
                    )
                )

        # Stack: end result should be Tuple[array of shape (N, parameter_dim), ...].
        local_value_indices_stacked: Tuple[onp.ndarray, ...] = tuple(
            onp.array(indices) for indices in local_value_indices_list
        )

        # Get residual indices.
        num_factors = len(factors)
        residual_dim = factors[0].get_residual_dim()
        residual_indices = onp.arange(num_factors * residual_dim).reshape(
            (num_factors, residual_dim)
        )

        # Get Jacobian coordinates.
        jacobian_coords: List[sparse.SparseCooCoordinates] = []
        for variable_index, variable_type in enumerate(variable_types):
            variable_dim = variable_type.get_local_parameter_dim()

            coords = onp.stack(
                (
                    # Row indices.
                    onp.broadcast_to(
                        residual_indices[:, :, None],
                        (num_factors, residual_dim, variable_dim),
                    )
                    + row_offset,
                    # Column indices.
                    onp.broadcast_to(
                        local_value_indices_stacked[variable_index][:, None, :],
                        (num_factors, residual_dim, variable_dim),
                    ),
                ),
                axis=-1,
            ).reshape((num_factors * residual_dim * variable_dim, 2))

            jacobian_coords.append(
                sparse.SparseCooCoordinates(
                    rows=coords[:, 0],
                    cols=coords[:, 1],
                )
            )

        return jacobian_coords

    def get_residual_dim(self) -> int:
        return self.factor.get_residual_dim() * self.num_factors

    def compute_residual_vector(self, assignments: VariableAssignments) -> jnp.ndarray:
        """Compute stacked residual vectors.

        Shape of output should be `(N, stacked_factor.factor.get_residual_dim())`.
        """

        assert assignments.storage_layout == self.storage_layout

        # Stack inputs to our factors.
        values_stacked = tuple(
            jax.vmap(type(variable).unflatten)(assignments.storage[indices])
            for variable, indices in zip(self.factor.variables, self.value_indices)
        )

        # Vectorized residual computation.
        # The type of `values_stacked` should match `FactorVariableValues`.
        residual_vector = jax.vmap(type(self.factor).compute_residual_vector)(
            self.factor,
            self.factor.build_variable_value_tuple(values_stacked),
        )
        return residual_vector

    def compute_residual_jacobian(
        self,
        assignments: VariableAssignments,
    ) -> Tuple[jnp.ndarray, ...]:
        """Compute stacked Jacobian matrices, one for each variable.

        Shape of each Jacobian array should be `(N, residual dim, local parameter dim)`.
        """

        assert assignments.storage_layout == self.storage_layout

        # Stack inputs to our factors.
        values_stacked = tuple(
            jax.vmap(variable.unflatten)(assignments.storage[indices])
            for indices, variable in zip(self.value_indices, self.factor.variables)
        )

        # Compute Jacobians wrt local parameterizations.
        # The type of `values_stacked` should match `FactorVariableValues`.
        jacobians = jax.vmap(type(self.factor).compute_residual_jacobians)(
            self.factor,
            self.factor.build_variable_value_tuple(values_stacked),
        )
        return jacobians
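
A small sketch of the coordinate broadcasting in `compute_jacobian_coords` above, with hypothetical sizes (2 factors, residual_dim=2, variable_dim=3): every (residual row, local parameter column) pair for a factor's variable becomes one COO coordinate.

import numpy as onp

num_factors, residual_dim, variable_dim = 2, 2, 3
residual_indices = onp.arange(num_factors * residual_dim).reshape(
    (num_factors, residual_dim)
)
# Hypothetical local storage indices for this variable, one row per factor.
local_value_indices = onp.array([[0, 1, 2], [3, 4, 5]])

coords = onp.stack(
    (
        # Row indices: repeat each residual index across the variable's columns.
        onp.broadcast_to(
            residual_indices[:, :, None], (num_factors, residual_dim, variable_dim)
        ),
        # Column indices: repeat each parameter index across the residual rows.
        onp.broadcast_to(
            local_value_indices[:, None, :], (num_factors, residual_dim, variable_dim)
        ),
    ),
    axis=-1,
).reshape((num_factors * residual_dim * variable_dim, 2))

print(coords[:3])  # [[0 0], [0 1], [0 2]]: residual 0 touches parameters 0..2.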