class SparseCooMatrix:
    """Sparse matrix in COO form."""

    values: hints.Array
    """Non-zero matrix values. Shape should be `(*, N)`."""
    coords: SparseCooCoordinates
    """Row and column indices of non-zero entries. Shapes should be `(*, N)`."""
    shape: Tuple[int, int] = jdc.static_field()
    """Shape of matrix."""

    # NOTE: rows, cols, and values are expected to share a shape, but this is
    # deliberately not asserted in a `__post_init__`: shape checks break under vmap.

    def __matmul__(self, other: hints.Array):
        """Compute `Ax`, where `x` is a 1D vector.

        Entries that share a (row, col) coordinate contribute additively.
        """
        assert other.shape == (
            self.shape[1],
        ), "Inner product only supported for 1D vectors!"
        return (
            jnp.zeros(self.shape[0], dtype=other.dtype)
            .at[self.coords.rows]
            .add(self.values * other[self.coords.cols])
        )

    def as_dense(self) -> jnp.ndarray:
        """Convert to a dense JAX array.

        Entries that share a (row, col) coordinate are summed. This matches
        `__matmul__` and scipy's COO semantics; overwriting them would silently
        drop contributions. The output dtype follows `values`, after JAX dtype
        canonicalization.
        """
        values = jnp.asarray(self.values)
        return (
            jnp.zeros(self.shape, dtype=values.dtype)
            .at[self.coords.rows, self.coords.cols]
            .add(values)
        )

    @staticmethod
    def from_scipy_coo_matrix(
        matrix: scipy.sparse.coo_matrix,
    ) -> "SparseCooMatrix":
        """Build from a sparse scipy matrix."""
        return SparseCooMatrix(
            values=matrix.data,
            coords=SparseCooCoordinates(
                rows=matrix.row,
                cols=matrix.col,
            ),
            shape=matrix.shape,
        )

    def as_scipy_coo_matrix(self) -> scipy.sparse.coo_matrix:
        """Convert to a sparse scipy matrix."""
        return scipy.sparse.coo_matrix(
            (self.values, (self.coords.rows, self.coords.cols)), shape=self.shape
        )

    @property
    def T(self):
        """Return transpose of our sparse matrix (rows and cols swapped)."""
        return SparseCooMatrix(
            values=self.values,
            coords=SparseCooCoordinates(
                rows=self.coords.cols,
                cols=self.coords.rows,
            ),
            shape=self.shape[::-1],
        )
class _FactorBase:
    """Field declarations shared by all factors: connected variables and noise model."""

    # Split into two classes to work around a mypy limitation; see
    # https://github.com/python/mypy/issues/5374#issuecomment-650656381

    variables: Tuple[VariableBase, ...] = jdc.static_field()
    """The variables this factor connects to. Their order corresponds 1-to-1 with
    `VariableValueTuple`."""

    noise_model: noises.NoiseModelBase
    """Noise model of this factor."""
class _NonlinearSolverBase:
    """Nonlinear solver interface."""

    # For why we have two classes:
    # https://github.com/python/mypy/issues/5374#issuecomment-650656381

    # Annotated with the builtin `bool`; `Boolean` is not a builtin or a `typing`
    # name, and class-level annotations are evaluated when the class is defined.
    verbose: bool = jdc.static_field(default=True)
    """Set to `True` to enable printing."""

    linear_solver: sparse.LinearSubproblemSolverBase = jdc.field(
        default_factory=lambda: sparse.CholmodSolver()
    )
    """Solver to use for linear subproblems."""
class FixedIterationGaussNewtonSolver(NonlinearSolverBase[NonlinearSolverState]):
    """Alternative version of Gauss-Newton solver, which ignores convergence checks."""

    unroll: bool = jdc.static_field(default=True)

    # To unroll the optimizer loop, we must have a concrete (static) iteration count
    iterations: int = jdc.static_field(default=10)

    @overrides
    def _initialize_state(
        self,
        graph: "StackedFactorGraph",
        initial_assignments: VariableAssignments,
    ) -> NonlinearSolverState:
        """Build the starting solver state from the initial assignments.

        The state holds the initial cost and residual vector. `done` is set up
        front when the iteration budget is non-positive. Without this, the
        `while_loop` path would run one step while the unrolled path ran zero.
        """
        cost, residual_vector = graph.compute_cost(initial_assignments)
        return NonlinearSolverState(
            iterations=0,
            assignments=initial_assignments,
            cost=cost,
            residual_vector=residual_vector,
            done=self.iterations <= 0,
        )

    @overrides
    def _step(
        self,
        graph: "StackedFactorGraph",
        state_prev: NonlinearSolverState,
    ) -> NonlinearSolverState:
        """Linearize, solve linear subproblem, and update on manifold."""
        self._hcb_print(
            lambda i, cost: f"Iteration #{i}: cost={str(cost)}",
            i=state_prev.iterations,
            cost=state_prev.cost,
        )

        # Linearize graph
        A: sparse.SparseCooMatrix = graph.compute_whitened_residual_jacobian(
            assignments=state_prev.assignments,
            residual_vector=state_prev.residual_vector,
        )
        ATb = -(A.T @ state_prev.residual_vector)

        # Solve linear subproblem (undamped: lambd=0 for plain Gauss-Newton)
        local_delta_assignments = VariableAssignments(
            storage=self.linear_solver.solve_subproblem(
                A=A,
                ATb=ATb,
                lambd=0.0,
                iteration=state_prev.iterations,
            ),
            storage_layout=graph.local_storage_layout,
        )

        # On-manifold retraction
        assignments = state_prev.assignments.manifold_retract(
            local_delta_assignments=local_delta_assignments,
        )

        # Re-evaluate cost; termination depends only on the fixed iteration count
        cost, residual_vector = graph.compute_cost(assignments)
        done = state_prev.iterations >= (self.iterations - 1)
        return NonlinearSolverState(
            iterations=state_prev.iterations + 1,
            assignments=assignments,
            cost=cost,
            residual_vector=residual_vector,
            done=done,
        )

    @jax.jit
    @overrides
    def solve(
        self,
        graph: "StackedFactorGraph",
        initial_assignments: VariableAssignments,
    ) -> VariableAssignments:
        """Run MAP inference on a factor graph.

        Runs exactly `self.iterations` Gauss-Newton steps. When `self.unroll` is
        True they run as an unrolled Python loop; otherwise they run in a
        `jax.lax.while_loop`.
        """
        # Initialize (cost and residuals are computed inside `_initialize_state`)
        state = self._initialize_state(graph, initial_assignments)

        # Optimization
        if self.unroll:
            for _ in range(self.iterations):
                state = self._step(graph, state)
        else:
            state = jax.lax.while_loop(
                cond_fun=lambda state: jnp.logical_not(state.done),
                body_fun=functools.partial(self._step, graph),
                init_val=state,
            )

        self._hcb_print(
            lambda i, cost: f"Terminated @ iteration #{i}: cost={str(cost).ljust(15)}",
            i=state.iterations,
            cost=state.cost,
        )

        return state.assignments
class VariableAssignments:
    """Maps variables to values, backed by one flat storage vector."""

    storage: hints.Array
    """Every variable's value, flattened and stacked in layout order. Entries are
    local or global parameterizations according to
    `.storage_layout.local_flag`."""

    storage_layout: StorageLayout = jdc.static_field()
    """Describes where each variable lives inside `storage`."""

    @staticmethod
    def make_from_defaults(variables: Iterable[VariableBase]) -> "VariableAssignments":
        """Build assignments where each variable takes its type's default value.

        Args:
            variables: Variables to include.

        Returns:
            VariableAssignments: Assignments using a global (non-local) layout.
        """
        layout = StorageLayout.make(variables, local=False)

        # One tiled chunk of defaults per variable type, in layout order.
        chunks = []
        for variable_type in layout.get_variable_types():
            default_flat = jax.jit(variable_type.flatten)(
                variable_type.get_default_value()
            )
            repeat = layout.count_from_variable_type[variable_type]
            chunks.append(jnp.tile(default_flat, reps=(repeat,)))
        storage = jnp.concatenate(chunks, axis=0)

        assert storage.shape == (layout.dim,)
        return VariableAssignments(storage=storage, storage_layout=layout)

    @staticmethod
    def make_from_dict(
        assignments: Dict[VariableBase, hints.VariableValue],
    ) -> "VariableAssignments":
        """Build assignments from a dictionary that covers every variable."""
        variables = assignments.keys()
        return VariableAssignments.make_from_partial_dict(variables, assignments)

    @staticmethod
    def make_from_partial_dict(
        variables: Iterable[VariableBase],
        assignments: Dict[VariableBase, hints.VariableValue],
    ) -> "VariableAssignments":
        """Build assignments for `variables`, taking values from `assignments`.

        Variables missing from `assignments` (or all of them, if `assignments` is
        None) receive their default value.
        """
        layout = StorageLayout.make(variables, local=False)

        def flatten_one(variable):
            # Prefer an explicit assignment; otherwise use the default value.
            have_value = assignments is not None and variable in assignments
            value = assignments[variable] if have_value else variable.get_default_value()
            return jax.jit(variable.flatten)(value)

        storage = jnp.concatenate(
            [flatten_one(variable) for variable in layout.get_variables()],
            axis=0,
        )

        assert storage.shape == (layout.dim,)
        return VariableAssignments(storage=storage, storage_layout=layout)

    @functools.partial(jax.jit, static_argnums=1)
    def update_storage_layout(
        self, storage_layout: StorageLayout
    ) -> "VariableAssignments":
        """Return the same variable->value mapping rearranged into `storage_layout`.

        This exists because an assignments object's layout can be shuffled
        relative to the layout a graph (StackedFactorGraph) expects. The storage
        is permuted with a gather; the two layouts must have the same dim,
        local flag, and variable set.
        """
        # Nothing to do when the layouts already agree.
        if self.storage_layout == storage_layout:
            return self

        current = self.storage_layout
        assert current.dim == storage_layout.dim
        assert current.local_flag == storage_layout.local_flag
        assert set(current.get_variables()) == set(storage_layout.get_variables())

        use_local = storage_layout.local_flag
        gather = jnp.zeros(storage_layout.dim, dtype=jnp.int32)
        for variable in storage_layout.get_variables():
            src = current.index_from_variable[variable]
            dst = storage_layout.index_from_variable[variable]
            if use_local:
                width = variable.get_local_parameter_dim()
            else:
                width = variable.get_parameter_dim()
            # Slot [dst, dst + width) of the new storage reads [src, src + width).
            gather = gather.at[dst : dst + width].set(jnp.arange(src, src + width))

        reordered = self.storage[gather]
        assert reordered.shape == self.storage.shape
        return VariableAssignments(storage=reordered, storage_layout=storage_layout)

    def as_dict(self) -> Dict[VariableBase, hints.VariableValue]:
        """Return the assignments as a variable -> value dictionary."""
        return {variable: self.get_value(variable) for variable in self.get_variables()}

    def __repr__(self):
        entries = [
            f" {i}.{variable.__class__.__name__}: {value}"
            for i, (variable, value) in enumerate(self.as_dict().items())
        ]
        contents: str = "\n".join(entries)
        return f"VariableAssignments(\n{contents}\n)"

    def get_variables(self) -> Collection[VariableBase]:
        """Return the variables stored here, in layout order."""
        return self.storage_layout.get_variables()

    def get_value(self, variable: VariableBase[VariableValueType]) -> VariableValueType:
        """Unflatten and return the value stored for `variable`."""
        start = self.storage_layout.index_from_variable[variable]
        stop = start + variable.get_parameter_dim()
        return type(variable).unflatten(self.storage[start:stop])

    def get_stacked_value(
        self, variable_type: Type[VariableBase[VariableValueType]]
    ) -> VariableValueType:
        """Return the values of every variable of `variable_type`, stacked (vmapped)."""
        start = self.storage_layout.index_from_variable_type[variable_type]
        count = self.storage_layout.count_from_variable_type[variable_type]
        dim = variable_type.get_parameter_dim()
        rows = self.storage[start : start + dim * count].reshape((count, dim))
        return jax.vmap(variable_type.unflatten)(rows)

    def set_value(
        self, variable: VariableBase[VariableValueType], value: VariableValueType
    ) -> "VariableAssignments":
        """Return a copy with `variable` set to `value`; `self` is not modified."""
        start = self.storage_layout.index_from_variable[variable]
        stop = start + type(variable).get_parameter_dim()
        with jdc.copy_and_mutate(self) as output:
            # `jnp.asarray` in case the storage vector is a plain numpy array.
            as_jax = jnp.asarray(output.storage)
            output.storage = as_jax.at[start:stop].set(type(variable).flatten(value))
        return output

    @jax.jit
    def manifold_retract(
        self, local_delta_assignments: "VariableAssignments"
    ) -> "VariableAssignments":
        """Apply local deltas to every variable via each type's manifold retraction.

        `self` must use a global layout and `local_delta_assignments` a local one.
        The update is batched with `vmap`, one variable type at a time.
        """
        assert not self.storage_layout.local_flag
        assert local_delta_assignments.storage_layout.local_flag

        layout = self.storage_layout
        delta_layout = local_delta_assignments.storage_layout
        updated = jnp.zeros_like(self.storage)

        variable_type: Type[VariableBase]
        for variable_type in layout.index_from_variable_type:
            count = layout.count_from_variable_type[variable_type]
            start = layout.index_from_variable_type[variable_type]
            delta_start = delta_layout.index_from_variable_type[variable_type]
            dim = variable_type.get_parameter_dim()
            local_dim = variable_type.get_local_parameter_dim()

            flat_values = self.storage[start : start + dim * count].reshape(
                (count, dim)
            )
            deltas = local_delta_assignments.storage[
                delta_start : delta_start + local_dim * count
            ].reshape((count, local_dim))

            retracted = jax.vmap(variable_type.manifold_retract)(
                jax.vmap(variable_type.unflatten)(flat_values),
                deltas,
            )
            updated = updated.at[start : start + dim * count].set(
                jax.vmap(variable_type.flatten)(retracted).flatten()
            )

        return jdc.replace(self, storage=updated)
class StackedFactorGraph:
    """Dataclass for vectorized factor graph computations.

    Improves runtime by stacking factors based on their group key.
    """

    factor_stacks: List[FactorStack]
    jacobian_coords: sparse.SparseCooCoordinates
    storage_layout: StorageLayout = jdc.static_field()
    local_storage_layout: StorageLayout = jdc.static_field()
    residual_dim: int = jdc.static_field()

    # NOTE: per-stack shape validation is deliberately not done in a
    # `__post_init__`: shape checks break under vmap.

    def get_variables(self) -> Collection[VariableBase]:
        """Return the graph's variables, ordered as in the local storage layout."""
        return self.local_storage_layout.get_variables()

    @staticmethod
    def make(
        factors: Iterable[FactorBase],
        use_onp: bool = True,
    ) -> "StackedFactorGraph":
        """Create a factor graph from a set of factors.

        Factors that share an anonymized pytree structure and leaf shapes are
        stacked into one `FactorStack`.

        Args:
            factors: Factors to include.
            use_onp: Forwarded to `FactorStack.make`; selects numpy (True) or
                jax.numpy (False) for stacking factor leaves.

        Returns:
            StackedFactorGraph: The stacked graph.
        """
        # Start by grouping our factors and grabbing a list of (ordered!) variables
        factors_from_group: DefaultDict[GroupKey, List[FactorBase]] = defaultdict(list)
        variables_ordered_set: Dict[VariableBase, None] = {}
        for factor in factors:
            # Each factor is ultimately just a pytree node; in order for a set of
            # factors to be batchable, they must share the same:
            group_key: GroupKey = (
                # (1) Treedef. Note that variables can be different as long as their
                # types are the same.
                jax.tree_util.tree_structure(factor.anonymize_variables()),
                # (2) Leaf shapes: contained array shapes must match
                tuple(
                    leaf.shape if hasattr(leaf, "shape") else ()
                    for leaf in jax.tree_util.tree_leaves(factor)
                ),
            )

            # Record factor and variables
            factors_from_group[group_key].append(factor)
            for v in factor.variables:
                variables_ordered_set[v] = None
        variables = list(variables_ordered_set.keys())

        # Fields we want to populate
        stacked_factors: List[FactorStack] = []
        jacobian_coords: List[sparse.SparseCooCoordinates] = []

        # Create storage layout: this describes which parts of our storage object is
        # allocated to each variable
        storage_layout = StorageLayout.make(variables, local=False)
        local_storage_layout = StorageLayout.make(variables, local=True)

        # Prepare each factor group
        residual_offset = 0
        for group in factors_from_group.values():
            # Make factor stack
            stacked_factors.append(
                FactorStack.make(
                    group,
                    storage_layout,
                    use_onp=use_onp,
                )
            )

            # Compute Jacobian coordinates
            #
            # These should be N pairs of (row, col) indices, where rows correspond to
            # residual indices and columns correspond to local parameter indices
            jacobian_coords.extend(
                FactorStack.compute_jacobian_coords(
                    factors=group,
                    local_storage_layout=local_storage_layout,
                    row_offset=residual_offset,
                )
            )
            residual_offset += stacked_factors[-1].get_residual_dim()

        jacobian_coords_concat: sparse.SparseCooCoordinates = jax.tree_util.tree_map(
            lambda *arrays: onp.concatenate(arrays, axis=0), *jacobian_coords
        )

        return StackedFactorGraph(
            factor_stacks=stacked_factors,
            jacobian_coords=jacobian_coords_concat,
            storage_layout=storage_layout,
            local_storage_layout=local_storage_layout,
            residual_dim=residual_offset,
        )

    @jax.jit
    def compute_whitened_residual_vector(
        self, assignments: VariableAssignments
    ) -> jnp.ndarray:
        """Computes flattened+whitened residual vector associated with our factor graph.

        Args:
            assignments (VariableAssignments): Variable assignments.

        Returns:
            jnp.ndarray: Residual vector of shape `(residual_dim,)`.
        """
        # Resolve storage layout mismatches. Factor stack computations will raise an
        # assertion error if the storage layout is incorrect.
        assignments = assignments.update_storage_layout(self.storage_layout)

        # Flatten and concatenate residuals from all groups.
        residual_vector = jnp.concatenate(
            [
                jax.vmap(
                    type(stacked_factor.factor.noise_model).whiten_residual_vector
                )(
                    stacked_factor.factor.noise_model,
                    stacked_factor.compute_residual_vector(assignments),
                ).flatten()
                for stacked_factor in self.factor_stacks
            ],
            axis=0,
        )
        assert residual_vector.shape == (self.residual_dim,)
        return residual_vector

    @jax.jit
    def compute_cost(
        self, assignments: VariableAssignments
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Compute the sum of squared residuals associated with a factor graph. Also
        returns intermediate (whitened) residual vector.

        Args:
            assignments (VariableAssignments): Variable assignments.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Scalar cost, residual vector.
        """
        residual_vector = self.compute_whitened_residual_vector(assignments)
        cost = jnp.sum(residual_vector**2)
        return cost, residual_vector

    @functools.partial(jax.jit, static_argnames=("include_constants",))
    def compute_joint_nll(
        self,
        assignments: VariableAssignments,
        include_constants: bool = False,
    ) -> jnp.ndarray:
        """Compute the joint negative log-likelihood density associated with a set of
        variable assignments. Assumes Gaussian noise models.

        The result is the sum of squared whitened residuals plus, per factor,
        `-2 * log|det(sqrt_precision)|`.

        NOTE(review): the squared-residual term is not halved. Assuming whitening
        multiplies by the sqrt precision, this is twice the usual Gaussian NLL
        (minus constants). Confirm this scaling is intended.

        Args:
            assignments (VariableAssignments): Variable assignments.
            include_constants (bool): Whether or not to include constant terms.
                Treated as a static (compile-time) argument, so it can be passed
                explicitly without tracing.

        Returns:
            jnp.ndarray: Scalar cost negative log-likelihood.

        Raises:
            NotImplementedError: If `include_constants` is True.
        """

        if include_constants:
            raise NotImplementedError()

        # Add Mahalanobis distance terms to NLL
        joint_nll = jnp.sum(self.compute_whitened_residual_vector(assignments) ** 2)

        # Add log-determinant terms. These are computed in log space
        # (slogdet / sum of logs) so large or small determinants do not overflow or
        # underflow before the log is taken.
        for stacked_factor in self.factor_stacks:
            noise_model = stacked_factor.factor.noise_model
            if isinstance(noise_model, noises.Gaussian):
                _, log_abs_det = jnp.linalg.slogdet(
                    cast(noises.Gaussian, noise_model).sqrt_precision_matrix
                )
                cov_log_determinants = -2.0 * log_abs_det
            elif isinstance(noise_model, noises.DiagonalGaussian):
                cov_log_determinants = -2.0 * jnp.sum(
                    jnp.log(
                        jnp.abs(
                            cast(
                                noises.DiagonalGaussian, noise_model
                            ).sqrt_precision_diagonal
                        )
                    ),
                    axis=-1,
                )
            else:
                assert False, f"Joint NLL not supported for {type(noise_model)}"
            assert cov_log_determinants.shape == (stacked_factor.num_factors,)
            joint_nll = joint_nll + jnp.sum(cov_log_determinants)

        return joint_nll

    @jax.jit
    def compute_whitened_residual_jacobian(
        self,
        assignments: VariableAssignments,
        residual_vector: hints.Array,
    ) -> sparse.SparseCooMatrix:
        """Compute the Jacobian of a graph's residual vector with respect to the stacked
        local delta vectors.

        The returned sparse matrix has shape `(residual_dim, local_delta_storage_dim)`.
        `residual_vector` is sliced per stack and passed to each noise model's
        `whiten_jacobian`.
        """

        # Resolve storage layout mismatches. Factor stack computations will raise an
        # assertion error if the storage layout is incorrect.
        assignments = assignments.update_storage_layout(self.storage_layout)

        # Linearize factors by group.
        A_values_list: List[jnp.ndarray] = []
        residual_start = 0
        for stacked_factor in self.factor_stacks:
            residual_end = residual_start + stacked_factor.get_residual_dim()
            stacked_residual_vector = residual_vector[
                residual_start:residual_end
            ].reshape(
                (
                    stacked_factor.num_factors,
                    stacked_factor.factor.get_residual_dim(),
                )
            )

            # Compute all Jacobians and whiten.
            for jacobian in stacked_factor.compute_residual_jacobian(assignments):
                A_values_list.append(
                    jax.vmap(type(stacked_factor.factor.noise_model).whiten_jacobian)(
                        stacked_factor.factor.noise_model,
                        jacobian,
                        residual_vector=stacked_residual_vector,
                    )
                )
            residual_start = residual_end

        # `residual_start` is always bound, even when there are no factor stacks.
        assert residual_start == self.residual_dim

        # Build Jacobian.
        A = sparse.SparseCooMatrix(
            values=jnp.concatenate([values.flatten() for values in A_values_list]),
            coords=self.jacobian_coords,
            shape=(self.residual_dim, self.local_storage_layout.dim),
        )
        return A

    def solve(
        self,
        initial_assignments: VariableAssignments,
        solver: NonlinearSolverBase = GaussNewtonSolver(),
    ) -> VariableAssignments:
        """Solve MAP inference problem."""
        # Note that the solver will handle storage layout mismatches.
        return solver.solve(graph=self, initial_assignments=initial_assignments)
class FactorStack(Generic[FactorType]):
    """A set of factors, with their parameters stacked."""

    num_factors: int = jdc.static_field()
    factor: FactorType
    value_indices: Tuple[hints.Array, ...]
    """The storage indices corresponding to the flattened value of each factor input."""

    storage_layout: StorageLayout = jdc.static_field()
    """The layout used to compute the value indices."""

    def __post_init__(self):
        # There should be one set of indices for each variable type.
        assert len(self.value_indices) == len(self.factor.variables)

        # Per-variable shape checks are intentionally omitted: shape checks break
        # under vmap. An example is indices of shape
        # `(num_factors, variable.get_parameter_dim())`.

    @staticmethod
    def make(
        factors: Sequence[FactorType],
        storage_layout: StorageLayout,
        use_onp: bool,
    ) -> "FactorStack[FactorType]":
        """Make a stacked factor.

        Args:
            factors: Factors to stack. Their anonymized pytree structures must
                match, and the i-th variable of every factor must share a type.
            storage_layout: Layout used to look up each variable's storage index.
            use_onp: If True, stack factor leaves with numpy; otherwise jax.numpy.

        Returns:
            FactorStack: Stacked factor plus per-variable value indices, each of
            shape `(num_factors, parameter_dim)`.
        """

        # For one-off computations, onp has much less overhead than jnp. The local
        # is named `xnp` so the module-level `jnp` is not shadowed.
        xnp = onp if use_onp else jnp

        # Stack factors in our group.
        # This requires that the treedefs of each factor match, which won't be
        # the case when factors are connected to different variables!
        stacked_factor: FactorType = jax.tree_util.tree_map(
            lambda *arrays: xnp.stack(arrays, axis=0),
            *map(FactorBase.anonymize_variables, factors),  # type: ignore
            # > https://github.com/python/mypy/issues/1317
        )

        # Get indices for each variable of each factor.
        value_indices_list: Tuple[List[onp.ndarray], ...] = tuple(
            [] for _ in range(len(stacked_factor.variables))
        )
        for factor in factors:
            for i, variable in enumerate(factor.variables):
                assert isinstance(
                    variable, type(factors[0].variables[i])
                ), "Variable types of stacked factors must match"
                storage_pos = storage_layout.index_from_variable[variable]
                value_indices_list[i].append(
                    onp.arange(storage_pos, storage_pos + variable.get_parameter_dim())
                )

        # Stack: end result should be Tuple[array of shape (N, parameter_dim), ...].
        value_indices_stacked: Tuple[onp.ndarray, ...] = tuple(
            onp.array(indices) for indices in value_indices_list
        )

        # Record values.
        return FactorStack(
            num_factors=len(factors),
            factor=stacked_factor,
            value_indices=value_indices_stacked,
            storage_layout=storage_layout,
        )

    @staticmethod
    def compute_jacobian_coords(
        factors: Sequence[FactorType],
        local_storage_layout: StorageLayout,
        row_offset: int,
    ) -> List[sparse.SparseCooCoordinates]:
        """Computes Jacobian coordinates for a factor stack.

        Returns one `SparseCooCoordinates` per variable slot. Each has
        `num_factors * residual_dim * variable_dim` entries, laid out as a
        flattened `(num_factors, residual_dim, variable_dim)` grid. Rows are
        residual indices shifted by `row_offset`; columns are local-parameter
        storage indices.
        """

        variable_types: List[Type[VariableBase]] = [
            type(v) for v in factors[0].variables
        ]

        # Get indices for each variable.
        local_value_indices_list: Tuple[List[onp.ndarray], ...] = tuple(
            [] for _ in range(len(variable_types))
        )
        for factor in factors:
            for i, variable in enumerate(factor.variables):
                # Record local parameterization indices.
                storage_pos = local_storage_layout.index_from_variable[variable]
                local_value_indices_list[i].append(
                    onp.arange(
                        storage_pos,
                        storage_pos + variable.get_local_parameter_dim(),
                    )
                )

        # Stack: end result should be Tuple[array of shape (N, local_parameter_dim), ...].
        local_value_indices_stacked: Tuple[onp.ndarray, ...] = tuple(
            onp.array(indices) for indices in local_value_indices_list
        )

        # Get residual indices.
        num_factors = len(factors)
        residual_dim = factors[0].get_residual_dim()
        residual_indices = onp.arange(num_factors * residual_dim).reshape(
            (num_factors, residual_dim)
        )

        # Get Jacobian coordinates.
        jacobian_coords: List[sparse.SparseCooCoordinates] = []
        for variable_index, variable_type in enumerate(variable_types):
            variable_dim = variable_type.get_local_parameter_dim()

            coords = onp.stack(
                (
                    # Row indices.
                    onp.broadcast_to(
                        residual_indices[:, :, None],
                        (num_factors, residual_dim, variable_dim),
                    )
                    + row_offset,
                    # Column indices.
                    onp.broadcast_to(
                        local_value_indices_stacked[variable_index][:, None, :],
                        (num_factors, residual_dim, variable_dim),
                    ),
                ),
                axis=-1,
            ).reshape((num_factors * residual_dim * variable_dim, 2))
            jacobian_coords.append(
                sparse.SparseCooCoordinates(
                    rows=coords[:, 0],
                    cols=coords[:, 1],
                )
            )
        return jacobian_coords

    def get_residual_dim(self) -> int:
        """Total residual dimension of the stack (per-factor dim times factor count)."""
        return self.factor.get_residual_dim() * self.num_factors

    def compute_residual_vector(self, assignments: VariableAssignments) -> jnp.ndarray:
        """Compute stacked residual vectors.

        Shape of output should be `(N, stacked_factor.factor.get_residual_dim())`.
        """
        assert assignments.storage_layout == self.storage_layout

        # Stack inputs to our factors.
        values_stacked = tuple(
            jax.vmap(type(variable).unflatten)(assignments.storage[indices])
            for variable, indices in zip(self.factor.variables, self.value_indices)
        )

        # Vectorized residual computation.
        # The type of `values_stacked` should match `FactorVariableValues`.
        residual_vector = jax.vmap(type(self.factor).compute_residual_vector)(
            self.factor,
            self.factor.build_variable_value_tuple(values_stacked),
        )

        return residual_vector

    def compute_residual_jacobian(
        self,
        assignments: VariableAssignments,
    ) -> Tuple[jnp.ndarray, ...]:
        """Compute stacked Jacobian matrices, one for each variable.

        Shape of each Jacobian array should be
        `(N, residual dim, local parameter dim)`. This matches the
        `(num_factors, residual_dim, variable_dim)` coordinate grid that
        `compute_jacobian_coords` builds for the flattened values.
        """
        assert assignments.storage_layout == self.storage_layout

        # Stack inputs to our factors.
        values_stacked = tuple(
            jax.vmap(type(variable).unflatten)(assignments.storage[indices])
            for variable, indices in zip(self.factor.variables, self.value_indices)
        )

        # Compute Jacobians wrt local parameterizations.
        # The type of `values_stacked` should match `FactorVariableValues`.
        jacobians = jax.vmap(type(self.factor).compute_residual_jacobians)(
            self.factor,
            self.factor.build_variable_value_tuple(values_stacked),
        )
        return jacobians