def reorder(
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    ncols: Optional[int] = None,
):
    """
    Permute a COO triple using the ordering produced by `reorder_perm`.

    Args:
        data: values of the COO entries
        row: row indices of the COO entries
        col: col indices of the COO entries
        ncols: number of columns, forwarded unchanged to `reorder_perm`

    Returns:
        `(data, row, col)` as a 3-tuple, each array gathered with the same
        permutation so entries stay aligned across the three arrays.
    """
    # Validate the triple before computing anything from it.
    assert_coo(data, row, col)
    order = reorder_perm(row, col, ncols)
    # One permutation is shared by all three arrays so they stay in lockstep.
    return tuple(arr.take(order) for arr in (data, row, col))
def symmetrize(
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    ncols: Optional[int] = None,
):
    """
    Get data of `(A + A.T) / 2` assuming `A` has symmetric sparsity.

    Entry `k` of the result averages `data[k]` with the value found at the
    transposed position. That position is located by asking `reorder_perm`
    for the ordering of the transposed indices `(col, row)`.

    NOTE(review): pairing entries this way presumably requires `data`/`row`/
    `col` to already be in the canonical order that `reorder_perm` produces
    (e.g. the output of `reorder`). `ncols` is also passed through unchanged
    for `A.T`, which assumes `A` is square. Confirm both against callers.

    Args:
        data: values of `A`
        row: row indices of `A`
        col: col indices of `A`
        ncols: number of columns of `A`

    Returns:
        `sym_data`, same shape as `data`. The dtype matches `data` for
        floating inputs. Integer `data` is promoted to a floating dtype,
        because the division by 2 is true division.
    """
    # Same input validation as `reorder`, which takes the same COO triple.
    assert_coo(data, row, col)
    # Swapping row/col gives the permutation that gathers A.T's values into
    # A's entry order (under symmetric sparsity).
    perm = reorder_perm(row=col, col=row, ncols=ncols)
    return (data + data.take(perm)) / 2