Beispiel #1
0
def multiply(a: LowRank, b: LowRank):
    """Element-wise product of two low-rank matrices.

    With ``a = La Ma Ra^T`` and ``b = Lb Mb Rb^T`` the product is again low
    rank:

    - its left factor holds every pairwise column product of ``La`` and ``Lb``;
    - its right factor holds every pairwise column product of ``Ra`` and ``Rb``;
    - its middle has entries ``Ma[i, j] * Mb[k, l]`` in row ``(i, k)`` and
      column ``(j, l)``.

    All factors are converted to dense first. A ``ToDenseWarning`` is emitted
    when any of the outer factors is structured.
    """
    assert_compatible(a, b)

    if structured(a.left, a.right, b.left, b.right):
        warn_upmodule(
            f"Multiplying {a} and {b}: converting factors to dense.",
            category=ToDenseWarning,
        )

    def pairwise_columns(x, y):
        # Products of every column of `x` with every column of `y`; the column
        # index of `x` varies slowest.
        x_cols = B.unstack(B.dense(x), axis=1)
        y_cols = B.unstack(B.dense(y), axis=1)
        products = [B.multiply(xc, yc) for xc in x_cols for yc in y_cols]
        return B.stack(*products, axis=1)

    def entries(x):
        # Nested list of scalars: `entries(x)[i][j]` is `x[i, j]`.
        return [B.unstack(row, axis=0) for row in B.unstack(B.dense(x), axis=0)]

    left = pairwise_columns(a.left, b.left)
    right = pairwise_columns(a.right, b.right)

    a_mid = entries(a.middle)
    b_mid = entries(b.middle)
    middle_rows = []
    for a_row in a_mid:
        for b_row in b_mid:
            row_entries = [x * y for x in a_row for y in b_row]
            middle_rows.append(B.stack(*row_entries, axis=0))
    middle = B.stack(*middle_rows, axis=0)

    return LowRank(left, right, middle)
Beispiel #2
0
def multiply(a: Kronecker, b: Kronecker):
    """Element-wise product of two Kronecker products.

    Uses ``kron(A, B) * kron(C, D) = kron(A * C, B * D)``, which needs the
    corresponding factors of `a` and `b` to have identical shapes.
    """
    shapes_a = (B.shape(a.left), B.shape(a.right))
    shapes_b = (B.shape(b.left), B.shape(b.right))
    assert shapes_a == shapes_b, (
        f"Kronecker products {a} and {b} must be compatible, but they are not."
    )
    assert_compatible(a.left, b.left)
    assert_compatible(a.right, b.right)
    new_left = B.multiply(a.left, b.left)
    new_right = B.multiply(a.right, b.right)
    return Kronecker(new_left, new_right)
Beispiel #3
0
def sum(a: LowRank, axis=None):
    """Sum a low-rank matrix ``L M R^T`` without forming it densely.

    Args:
        a (LowRank): Matrix to sum.
        axis (int, optional): ``None`` sums every entry. ``0`` sums over rows,
            giving one value per column. ``1`` sums over columns, giving one
            value per row. Any other value is handed to ``_raise``.
    """
    if axis is None:
        left_middle = B.matmul(a.left, a.middle)
        return B.sum(B.sum(left_middle, axis=0) * B.sum(a.right, axis=0))

    if axis == 0:
        # Column totals of `L`, broadcast against the rows of `R M^T`.
        left_totals = B.expand_dims(B.sum(a.left, axis=0), axis=0)
        right_middle_t = B.matmul(a.right, a.middle, tr_b=True)
        return B.sum(B.multiply(left_totals, right_middle_t), axis=1)

    if axis == 1:
        # Rows of `L M`, weighted by the column totals of `R`.
        left_middle = B.matmul(a.left, a.middle)
        right_totals = B.expand_dims(B.sum(a.right, axis=0), axis=0)
        return B.sum(B.multiply(left_middle, right_totals), axis=1)

    _raise(axis)
Beispiel #4
0
def diag(a: LowRank):
    """Diagonal of a low-rank matrix ``L M R^T``.

    Entry ``i`` is the inner product of row ``i`` of ``L M`` with row ``i`` of
    ``R``. Only the first ``_diag_len(a)`` rows of each are used. A
    ``ToDenseWarning`` is emitted when the outer factors are structured.
    """
    if structured(a.left, a.right):
        warn_upmodule(
            f"Getting the diagonal of {a}: converting the factors to dense.",
            category=ToDenseWarning,
        )
    n = _diag_len(a)
    lm_rows = B.dense(B.matmul(a.left, a.middle))[:n, :]
    r_rows = B.dense(a.right)[:n, :]
    return B.sum(B.multiply(lm_rows, r_rows), axis=1)
Beispiel #5
0
def matmul_diag(a, b, tr_a=False, tr_b=False):
    """Compute the diagonal of the matrix product of `a` and `b`.

    The full product is never formed.

    Args:
        a (matrix): First matrix.
        b (matrix): Second matrix.
        tr_a (bool, optional): Transpose first matrix. Defaults to `False`.
        tr_b (bool, optional): Transpose second matrix. Defaults to `False`.

    Returns:
        vector: Diagonal of matrix product of `a` and `b`.
    """
    # Assuming `_tr(x, True)` transposes `x`, entry `i` of the diagonal is
    # sum_k op(a)[i, k] * op(b)[k, i]. That is a column-wise dot product
    # between op(a)^T and op(b).
    a_t = _tr(a, not tr_a)
    b_op = _tr(b, tr_b)
    return B.sum(B.multiply(a_t, b_op), axis=0)
Beispiel #6
0
def multiply(a: AbstractMatrix, b: AbstractMatrix):
    """Fallback element-wise product: densify both operands and multiply.

    Emits a ``ToDenseWarning`` when either operand is structured.
    """
    if structured(a, b):
        warn_upmodule(
            f"Multiplying {a} and {b}: converting to dense.",
            category=ToDenseWarning,
        )
    dense_a = B.dense(a)
    dense_b = B.dense(b)
    return Dense(B.multiply(dense_a, dense_b))
Beispiel #7
0
def multiply(a: Woodbury, b: AbstractMatrix):
    """Element-wise product of a Woodbury matrix with any matrix.

    The product is distributed over the Woodbury matrix's diagonal part
    (`a.diag`) and low-rank part (`a.lr`), and the two results are added.
    """
    diag_part = B.multiply(a.diag, b)
    low_rank_part = B.multiply(a.lr, b)
    return B.add(diag_part, low_rank_part)
Beispiel #8
0
def multiply(a: Constant, b: LowRank):
    """Element-wise product of a constant matrix and a low-rank matrix.

    Multiplying by a constant only rescales the middle factor, so the left and
    right factors of `b` are reused unchanged.
    """
    assert_compatible(a, b)
    scaled_middle = B.multiply(a.const, b.middle)
    return LowRank(b.left, b.right, scaled_middle)
Beispiel #9
0
 def __mul__(self, other):
     """Element-wise product ``self * other``, delegated to ``B.multiply``."""
     product = B.multiply(self, other)
     return product
Beispiel #10
0
def matmul(a: Diagonal, b: Diagonal, tr_a=False, tr_b=False):
    """Matrix product of two diagonal matrices.

    `tr_a` and `tr_b` are only forwarded to the shape check. The result's
    diagonal is the element-wise product of the two input diagonals.
    """
    _assert_composable(a, b, tr_a=tr_a, tr_b=tr_b)
    product_diag = B.multiply(a.diag, b.diag)
    return Diagonal(product_diag)
Beispiel #11
0
def multiply(a: UpperTriangular, b: UpperTriangular):
    """Element-wise product of two upper-triangular matrices.

    The result is wrapped as upper triangular again.
    """
    product = B.multiply(a.mat, b.mat)
    return UpperTriangular(product)
Beispiel #12
0
def multiply(a: LowerTriangular, b: LowerTriangular):
    """Element-wise product of two lower-triangular matrices.

    The result is wrapped as lower triangular again.
    """
    product = B.multiply(a.mat, b.mat)
    return LowerTriangular(product)
Beispiel #13
0
def align(a, b):
    """Align two matrices according to identical columns.

    Columns are visited in order of increasing norm (computed with
    ``B.sum(..., axis=0)``). A column of `a` and a column of `b` are matched
    when their mean squared difference is below ``1e-10``. The sorted index
    lists are walked with cursors, so the merge is linear in the number of
    columns (after the sort).

    The result is stored with ``B.control_flow.set_outcome``. When
    ``B.control_flow.use_cache`` is set, the stored outcomes are returned
    instead of recomputing.

    Args:
        a (matrix): First matrix to align.
        b (matrix): Second matrix to align.

    Returns:
        tuple[matrix]: A four tuple. The first two elements are permutations to
            *align* `a` and `b`. The second two elements are permutations to *join*
            `a` and `b`. *Important:* The permutations assume that the last column
            is a column of zeros.
    """
    if B.control_flow.use_cache:
        a_perm = B.control_flow.get_outcome("align:a_perm")
        b_perm = B.control_flow.get_outcome("align:b_perm")
        a_join_perm = B.control_flow.get_outcome("align:a_join_perm")
        b_join_perm = B.control_flow.get_outcome("align:b_join_perm")
        return a_perm, b_perm, a_join_perm, b_join_perm

    def equal(index_a, index_b):
        dist = B.mean(B.subtract(a[..., :, index_a], b[..., :, index_b])**2)
        return dist < 1e-10

    # We need the norms later on. NOTE(review): `axis=0` gives per-column norms
    # for 2-D inputs; `equal` indexes with `...`, which suggests batched inputs
    # may occur. Confirm the norms are meant for 2-D only.
    a_norms = B.sum(B.multiply(a, a), axis=0)
    b_norms = B.sum(B.multiply(b, b), axis=0)

    # Perform sorting to enable linear-time algorithm. These need to be regular Python
    # lists.
    a_sorted_inds = list(B.argsort(a_norms))
    b_sorted_inds = list(B.argsort(b_norms))

    a_perm = []
    b_perm = []
    a_join_perm = []
    b_join_perm = []

    # Cursors into the sorted index lists. Advancing a cursor is O(1), unlike
    # `list.pop(0)`, which would make the merge quadratic.
    i, j = 0, 0
    n_a, n_b = len(a_sorted_inds), len(b_sorted_inds)

    while i < n_a and j < n_b:
        a_ind = a_sorted_inds[i]
        b_ind = b_sorted_inds[j]
        if equal(a_ind, b_ind):
            # Match: align the two columns; when joining, take it from `a` only.
            a_perm.append(a_ind)
            b_perm.append(b_ind)
            a_join_perm.append(a_ind)
            b_join_perm.append(-1)
            i += 1
            j += 1
        elif a_norms[a_ind] < b_norms[b_ind]:
            # No match and `a`'s column is smaller: it has no partner in `b`.
            a_perm.append(a_ind)
            b_perm.append(-1)
            a_join_perm.append(a_ind)
            b_join_perm.append(-1)
            i += 1
        else:
            # No match otherwise: `b`'s column has no partner in `a`.
            a_perm.append(-1)
            b_perm.append(b_ind)
            a_join_perm.append(-1)
            b_join_perm.append(b_ind)
            j += 1

    # At most one of the two lists can have indices left.
    a_rest = a_sorted_inds[i:]
    b_rest = b_sorted_inds[j:]
    if a_rest:
        a_perm.extend(a_rest)
        b_perm.extend([-1] * len(a_rest))
        a_join_perm.extend(a_rest)
        b_join_perm.extend([-1] * len(a_rest))
    if b_rest:
        a_perm.extend([-1] * len(b_rest))
        b_perm.extend(b_rest)
        a_join_perm.extend([-1] * len(b_rest))
        b_join_perm.extend(b_rest)

    B.control_flow.set_outcome("align:a_perm", a_perm)
    B.control_flow.set_outcome("align:b_perm", b_perm)
    B.control_flow.set_outcome("align:a_join_perm", a_join_perm)
    B.control_flow.set_outcome("align:b_join_perm", b_join_perm)

    return a_perm, b_perm, a_join_perm, b_join_perm
Beispiel #14
0
def kron(a: Constant, b: Constant):
    """Kronecker product of two constant matrices.

    The result is a constant matrix with value ``a.const * b.const``. Its
    shape is given by ``_product_shape(a, b)``.
    """
    value = B.multiply(a.const, b.const)
    shape = _product_shape(a, b)
    return Constant(value, *shape)
Beispiel #15
0
 def __rmul__(self, other):
     """Reflected element-wise product ``other * self``, via ``B.multiply``."""
     product = B.multiply(other, self)
     return product
Beispiel #16
0
def multiply(a: Dense, b: Dense):
    """Element-wise product of two dense matrices."""
    product = B.multiply(a.mat, b.mat)
    return Dense(product)
Beispiel #17
0
def multiply(a: Diagonal, b: Diagonal):
    """Element-wise product of two diagonal matrices.

    Only the diagonals interact, so they are multiplied directly.
    """
    product_diag = B.multiply(a.diag, b.diag)
    return Diagonal(product_diag)
Beispiel #18
0
def multiply(a: UpperTriangular, b: LowerTriangular):
    """Element-wise product of an upper- and a lower-triangular matrix.

    The diagonal is the only place where both factors can be non-zero. The
    result is therefore the diagonal matrix of the products of their
    diagonals.
    """
    diag_a = B.diag(a)
    diag_b = B.diag(b)
    return Diagonal(B.multiply(diag_a, diag_b))
Beispiel #19
0
def test_multiply_const_broadcasting():
    """Constant matrices broadcast element-wise; incompatible shapes raise."""
    assert B.shape(B.multiply(Constant(1, 3, 4), Constant(1, 1, 4))) == (3, 4)
    assert B.shape(B.multiply(Constant(1, 3, 4), Constant(1, 3, 1))) == (3, 4)
    # One `pytest.raises` per call. Inside a single block, the first raising
    # call would exit the block and the second check would never run.
    with pytest.raises(AssertionError):
        B.multiply(Constant(1, 3, 4), Constant(1, 4, 4))
    with pytest.raises(AssertionError):
        B.multiply(Constant(1, 3, 4), Constant(1, 3, 3))
Beispiel #20
0
def multiply(a: UpperTriangular, b: Constant):
    """Element-wise product of an upper-triangular matrix and a constant matrix.

    Multiplying by a constant scales every entry, so the result stays upper
    triangular. `b.const` is a scalar, so an incompatible shape of `b` would
    otherwise go unnoticed. The shapes are therefore checked explicitly, as in
    the other `Constant` products.
    """
    assert_compatible(a, b)
    return UpperTriangular(B.multiply(a.mat, b.const))