def multiply(a: LowRank, b: LowRank):
    """Element-wise (Hadamard) product of two low-rank matrices.

    Writing `a = La Ma Ra^T` and `b = Lb Mb Rb^T`, the product is again low rank:

    * the left factor has columns `La[:, i] * Lb[:, k]`,
    * the right factor has columns `Ra[:, j] * Rb[:, l]`, and
    * the middle factor has entries `Ma[i, j] * Mb[k, l]`.

    In every case the index belonging to `a` varies slowest. Structured outer factors
    are converted to dense, which emits a `ToDenseWarning`.

    Args:
        a (:class:`.LowRank`): First factor.
        b (:class:`.LowRank`): Second factor.

    Returns:
        :class:`.LowRank`: Element-wise product of `a` and `b`.
    """
    assert_compatible(a, b)
    if structured(a.left, a.right, b.left, b.right):
        warn_upmodule(
            f"Multiplying {a} and {b}: converting factors to dense.",
            category=ToDenseWarning,
        )
    a_left, a_mid, a_right = B.dense(a.left), B.dense(a.middle), B.dense(a.right)
    b_left, b_mid, b_right = B.dense(b.left), B.dense(b.middle), B.dense(b.right)

    # Break the outer factors into lists of columns.
    a_left_cols = B.unstack(a_left, axis=1)
    a_right_cols = B.unstack(a_right, axis=1)
    b_left_cols = B.unstack(b_left, axis=1)
    b_right_cols = B.unstack(b_right, axis=1)
    # Break the middle factors into nested lists of scalars, indexed `[row][col]`.
    a_entries = [B.unstack(row, axis=0) for row in B.unstack(a_mid, axis=0)]
    b_entries = [B.unstack(row, axis=0) for row in B.unstack(b_mid, axis=0)]

    def _pairwise_products(xs, ys):
        # All products `x * y`, with the index into `xs` varying slowest.
        return [B.multiply(x, y) for x in xs for y in ys]

    new_left = B.stack(*_pairwise_products(a_left_cols, b_left_cols), axis=1)
    new_right = B.stack(*_pairwise_products(a_right_cols, b_right_cols), axis=1)

    middle_rows = []
    for a_row in a_entries:
        for b_row in b_entries:
            middle_rows.append(
                B.stack(*[x * y for x in a_row for y in b_row], axis=0)
            )
    new_middle = B.stack(*middle_rows, axis=0)

    return LowRank(new_left, new_right, new_middle)
def multiply(a: Kronecker, b: Kronecker):
    """Element-wise product of two Kronecker products.

    Uses `(A kron B) * (C kron D) = (A * C) kron (B * D)`. This identity requires
    the corresponding factors to have identical shapes, which is asserted.

    Args:
        a (:class:`.Kronecker`): First Kronecker product.
        b (:class:`.Kronecker`): Second Kronecker product.

    Returns:
        :class:`.Kronecker`: Kronecker product of the factor-wise products.
    """
    factor_pairs = [(a.left, b.left), (a.right, b.right)]
    shapes_equal = [B.shape(x) == B.shape(y) for x, y in factor_pairs]
    assert all(
        shapes_equal
    ), f"Kronecker products {a} and {b} must be compatible, but they are not."
    for x, y in factor_pairs:
        assert_compatible(x, y)
    return Kronecker(*[B.multiply(x, y) for x, y in factor_pairs])
def sum(a: LowRank, axis=None):
    """Sum the entries of a low-rank matrix `left @ middle @ right^T` without
    forming it.

    Args:
        a (:class:`.LowRank`): Matrix to sum.
        axis (int, optional): `None` sums all entries, `0` reduces over rows (one
            value per row of `a.right`), and `1` reduces over columns (one value per
            row of `a.left`). Any other value is passed to `_raise`.

    Returns:
        tensor: A scalar for `axis=None`, otherwise a vector.
    """
    if axis is None:
        # 1^T L M R^T 1 = (1^T L M) . (R^T 1)
        left_totals = B.sum(B.matmul(a.left, a.middle), axis=0)
        right_totals = B.sum(a.right, axis=0)
        return B.sum(left_totals * right_totals)
    if axis == 0:
        # 1^T L M R^T, computed row-wise against R M^T.
        row_weights = B.expand_dims(B.sum(a.left, axis=0), axis=0)
        weighted = B.multiply(row_weights, B.matmul(a.right, a.middle, tr_b=True))
        return B.sum(weighted, axis=1)
    if axis == 1:
        # L M R^T 1, computed row-wise against L M.
        projected = B.matmul(a.left, a.middle)
        col_weights = B.expand_dims(B.sum(a.right, axis=0), axis=0)
        return B.sum(B.multiply(projected, col_weights), axis=1)
    _raise(axis)
def diag(a: LowRank):
    """Diagonal of a low-rank matrix `left @ middle @ right^T`.

    Entry `i` of the diagonal is the inner product of row `i` of `left @ middle`
    with row `i` of `right`. Only the first `_diag_len(a)` rows are used.
    Structured outer factors are converted to dense, which emits a
    `ToDenseWarning`.

    Args:
        a (:class:`.LowRank`): Matrix to take the diagonal of.

    Returns:
        vector: Diagonal of `a`.
    """
    if structured(a.left, a.right):
        warn_upmodule(
            f"Getting the diagonal of {a}: converting the factors to dense.",
            category=ToDenseWarning,
        )
    n = _diag_len(a)
    scaled_left = B.matmul(a.left, a.middle)
    left_rows = B.dense(scaled_left)[:n, :]
    right_rows = B.dense(a.right)[:n, :]
    return B.sum(B.multiply(left_rows, right_rows), axis=1)
def matmul_diag(a, b, tr_a=False, tr_b=False):
    """Compute the diagonal of the matrix product of `a` and `b` without forming
    the product.

    Args:
        a (matrix): First matrix.
        b (matrix): Second matrix.
        tr_a (bool, optional): Transpose first matrix. Defaults to `False`.
        tr_b (bool, optional): Transpose second matrix. Defaults to `False`.

    Returns:
        vector: Diagonal of matrix product of `a` and `b`.
    """
    # Entry `i` is `sum_k op(a)[i, k] * op(b)[k, i]`. Laying out `op(a)` transposed
    # puts both operands in `[k, i]` order, so an element-wise product followed by a
    # reduction over axis 0 yields the diagonal.
    # NOTE(review): assumes `_tr(x, flag)` transposes `x` iff `flag` is true, and that
    # `op(a) @ op(b)` is square; non-square products would not broadcast here.
    a_by_col = _tr(a, not tr_a)
    b_by_col = _tr(b, tr_b)
    return B.sum(B.multiply(a_by_col, b_by_col), axis=0)
def multiply(a: AbstractMatrix, b: AbstractMatrix):
    """Fallback element-wise product: convert both operands to dense and multiply.

    Emits a `ToDenseWarning` if either operand is structured.

    Returns:
        :class:`.Dense`: Element-wise product of `a` and `b`.
    """
    if structured(a, b):
        warn_upmodule(
            f"Multiplying {a} and {b}: converting to dense.",
            category=ToDenseWarning,
        )
    dense_a, dense_b = B.dense(a), B.dense(b)
    return Dense(B.multiply(dense_a, dense_b))
def multiply(a: Woodbury, b: AbstractMatrix):
    """Element-wise product of a Woodbury matrix `a.diag + a.lr` with `b`.

    The element-wise product distributes over addition, so each term is multiplied
    separately and the results are added.
    """
    diag_term = B.multiply(a.diag, b)
    low_rank_term = B.multiply(a.lr, b)
    return B.add(diag_term, low_rank_term)
def multiply(a: Constant, b: LowRank):
    """Element-wise product of a constant matrix and a low-rank matrix.

    `c * (L M R^T) = L (c M) R^T`, so only the middle factor is scaled. The shapes of
    `a` and `b` are checked with `assert_compatible`.
    """
    assert_compatible(a, b)
    scaled_middle = B.multiply(a.const, b.middle)
    return LowRank(b.left, b.right, scaled_middle)
def __mul__(self, other):
    """Element-wise product `self * other`, delegated to `B.multiply`."""
    product = B.multiply(self, other)
    return product
def matmul(a: Diagonal, b: Diagonal, tr_a=False, tr_b=False):
    """Matrix product of two diagonal matrices.

    `tr_a` and `tr_b` are only passed to the composability check. The result is
    the diagonal matrix whose diagonal is the entry-wise product of the two
    diagonals.
    """
    _assert_composable(a, b, tr_a=tr_a, tr_b=tr_b)
    product_diag = B.multiply(a.diag, b.diag)
    return Diagonal(product_diag)
def multiply(a: UpperTriangular, b: UpperTriangular):
    """Element-wise product of two upper-triangular matrices, which is again upper
    triangular."""
    product = B.multiply(a.mat, b.mat)
    return UpperTriangular(product)
def multiply(a: LowerTriangular, b: LowerTriangular):
    """Element-wise product of two lower-triangular matrices, which is again lower
    triangular."""
    product = B.multiply(a.mat, b.mat)
    return LowerTriangular(product)
def align(a, b):
    """Align two matrices according to identical columns.

    Columns of `a` and `b` that are numerically equal (mean squared difference below
    `1e-10`) are paired up. Every other column is emitted unpaired, opposite a `-1`
    index.

    Args:
        a (matrix): First matrix to align.
        b (matrix): Second matrix to align.

    Returns:
        tuple[list]: A four tuple of index lists. The first two elements are
            permutations to *align* `a` and `b`. The second two elements are
            permutations to *join* `a` and `b`. *Important:* The permutations assume
            that the last column is a column of zeros, so an index of `-1` selects
            that zero column.
    """
    # When the control-flow cache is active, replay the permutations stored under
    # these keys by `set_outcome` (presumably by an earlier, uncached call) instead
    # of recomputing them.
    if B.control_flow.use_cache:
        a_perm = B.control_flow.get_outcome("align:a_perm")
        b_perm = B.control_flow.get_outcome("align:b_perm")
        a_join_perm = B.control_flow.get_outcome("align:a_join_perm")
        b_join_perm = B.control_flow.get_outcome("align:b_join_perm")
        return a_perm, b_perm, a_join_perm, b_join_perm

    # Two columns count as equal if their mean squared difference is below `1e-10`.
    def equal(index_a, index_b):
        dist = B.mean(B.subtract(a[..., :, index_a], b[..., :, index_b])**2)
        return dist < 1e-10

    # We need the norms later on. These are squared column norms: the sum over
    # axis 0 of the element-wise square.
    a_norms = B.sum(B.multiply(a, a), axis=0)
    b_norms = B.sum(B.multiply(b, b), axis=0)

    # Perform sorting to enable linear-time algorithm. These need to be regular Python
    # lists.
    # NOTE(review): `list.pop(0)` below is O(n) per call, so the merge loop is
    # quadratic in the number of columns. This is harmless for small ranks.
    a_sorted_inds = list(B.argsort(a_norms))
    b_sorted_inds = list(B.argsort(b_norms))

    a_perm = []
    b_perm = []
    a_join_perm = []
    b_join_perm = []
    # Merge the two norm-sorted index lists, similar to the merge step of merge sort.
    while a_sorted_inds and b_sorted_inds:
        # Match at the first index.
        if equal(a_sorted_inds[0], b_sorted_inds[0]):
            a_ind = a_sorted_inds.pop(0)
            b_ind = b_sorted_inds.pop(0)
            a_perm.append(a_ind)
            b_perm.append(b_ind)
            # The shared column enters the join via `a` only; `b` contributes the
            # zero column, so the column is not duplicated.
            a_join_perm.append(a_ind)
            b_join_perm.append(-1)
        # No match. Figure out which should be discarded: the column with the smaller
        # norm is emitted unpaired, opposite the zero column.
        # NOTE(review): on equal norms of non-equal columns, `b`'s column is emitted
        # first, which can miss a later match of that column with `a`. The output is
        # still consistent, but columns may be duplicated.
        elif a_norms[a_sorted_inds[0]] < b_norms[b_sorted_inds[0]]:
            a_ind = a_sorted_inds.pop(0)
            a_perm.append(a_ind)
            b_perm.append(-1)
            a_join_perm.append(a_ind)
            b_join_perm.append(-1)
        else:
            b_ind = b_sorted_inds.pop(0)
            a_perm.append(-1)
            b_perm.append(b_ind)
            a_join_perm.append(-1)
            b_join_perm.append(b_ind)

    # Either `a_sorted_inds` or `b_sorted_inds` can have indices left.
    if a_sorted_inds:
        a_perm.extend(a_sorted_inds)
        b_perm.extend([-1] * len(a_sorted_inds))
        a_join_perm.extend(a_sorted_inds)
        b_join_perm.extend([-1] * len(a_sorted_inds))
    if b_sorted_inds:
        a_perm.extend([-1] * len(b_sorted_inds))
        b_perm.extend(b_sorted_inds)
        a_join_perm.extend([-1] * len(b_sorted_inds))
        b_join_perm.extend(b_sorted_inds)

    # Record the outcomes under the same keys read in the cached branch above.
    B.control_flow.set_outcome("align:a_perm", a_perm)
    B.control_flow.set_outcome("align:b_perm", b_perm)
    B.control_flow.set_outcome("align:a_join_perm", a_join_perm)
    B.control_flow.set_outcome("align:b_join_perm", b_join_perm)

    return a_perm, b_perm, a_join_perm, b_join_perm
def kron(a: Constant, b: Constant):
    """Kronecker product of two constant matrices.

    The result is a constant matrix whose value is the product of the two constants
    and whose shape is given by `_product_shape(a, b)`.
    """
    value = B.multiply(a.const, b.const)
    shape = _product_shape(a, b)
    return Constant(value, *shape)
def __rmul__(self, other):
    """Reflected element-wise product `other * self`, delegated to `B.multiply`."""
    product = B.multiply(other, self)
    return product
def multiply(a: Dense, b: Dense):
    """Element-wise product of two dense matrices, wrapped back into `Dense`."""
    raw_product = B.multiply(a.mat, b.mat)
    return Dense(raw_product)
def multiply(a: Diagonal, b: Diagonal):
    """Element-wise product of two diagonal matrices: multiply the diagonals."""
    product_diag = B.multiply(a.diag, b.diag)
    return Diagonal(product_diag)
def multiply(a: UpperTriangular, b: LowerTriangular):
    """Element-wise product of an upper- and a lower-triangular matrix.

    Only the diagonal can be non-zero in both operands, so the result is the
    diagonal matrix of the product of their diagonals.
    """
    diag_a = B.diag(a)
    diag_b = B.diag(b)
    return Diagonal(B.multiply(diag_a, diag_b))
def test_multiply_const_broadcasting():
    """Element-wise products of `Constant`s broadcast size-one dimensions and reject
    incompatible shapes with an `AssertionError`."""
    # Size-one dimensions broadcast against the other operand.
    assert B.shape(B.multiply(Constant(1, 3, 4), Constant(1, 1, 4))) == (3, 4)
    assert B.shape(B.multiply(Constant(1, 3, 4), Constant(1, 3, 1))) == (3, 4)
    # Each incompatible case needs its own `pytest.raises` block. Inside a single
    # block, the first raise exits the block and later statements never run.
    with pytest.raises(AssertionError):
        B.multiply(Constant(1, 3, 4), Constant(1, 4, 4))
    with pytest.raises(AssertionError):
        B.multiply(Constant(1, 3, 4), Constant(1, 3, 3))
def multiply(a: UpperTriangular, b: Constant):
    """Element-wise product of an upper-triangular matrix and a constant matrix.

    Scaling by a constant preserves the zero pattern, so the result stays upper
    triangular. The shapes are checked first. Without the check, a `Constant` of
    incompatible shape would be silently treated as a plain scalar.

    Args:
        a (:class:`.UpperTriangular`): Upper-triangular matrix.
        b (:class:`.Constant`): Constant matrix.

    Returns:
        :class:`.UpperTriangular`: Element-wise product of `a` and `b`.
    """
    assert_compatible(a, b)
    return UpperTriangular(B.multiply(a.mat, b.const))