示例#1
0
def reorder(data: jnp.ndarray,
            row: jnp.ndarray,
            col: jnp.ndarray,
            ncols: Optional[int] = None):
    """
    Permute the entries of a COO matrix by the permutation from `reorder_perm`.

    The inputs are first checked with `assert_coo`.

    Args:
        data: values of the matrix
        row: row indices of the matrix
        col: col indices of the matrix
        ncols: number of columns of the matrix, forwarded to `reorder_perm`

    Returns:
        A `(data, row, col)` tuple in which every array has been gathered
        with the same permutation.
    """
    assert_coo(data, row, col)
    order = reorder_perm(row=row, col=col, ncols=ncols)
    # One shared permutation keeps each value paired with its (row, col).
    return tuple(arr.take(order) for arr in (data, row, col))
示例#2
0
def symmetrize(
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    ncols: Optional[int] = None,
):
    """
    Get data of `(A + A.T) / 2` assuming `A` has symmetric sparsity.

    The values of `A.T` are obtained by gathering `data` with the permutation
    that `reorder_perm` returns for the transposed coordinates, i.e. with the
    roles of `row` and `col` swapped.

    NOTE(review): for the gathered values to line up entry-by-entry with
    `data`, this presumably also requires the entries of `A` to already be in
    the canonical order produced by `reorder` — confirm against
    `reorder_perm`.

    Args:
        data: values of `A`
        row: row indices of `A`
        col: col indices of `A`
        ncols: number of columns of `A`

    Returns:
        `sym_data`, same shape as `data`. The division by 2 is true division:
        floating-point (and complex) `data` keeps its dtype, while integer or
        boolean `data` is promoted to a floating-point dtype.
    """
    # Swapping `row` and `col` yields the ordering permutation for `A.T`.
    perm = reorder_perm(row=col, col=row, ncols=ncols)
    return (data + data.take(perm)) / 2