def test_add_wb(wb1, wb2):
    """Check that adding two Woodbury matrices yields a `Woodbury`.

    The dense-conversion warning assertion is only applied when the left
    factor of either low-rank part is structured.
    """
    expect_dense = structured(wb1.lr.left) or structured(wb2.lr.left)
    dense_warning = AssertDenseWarning(
        [
            "indexing into <diagonal>",
            "concatenating <diagonal>, <dense>",
        ]
    )
    with ConditionalContext(expect_dense, dense_warning):
        check_bin_op(B.add, wb1, wb2, asserted_type=Woodbury)
def test_add_lr(lr1, lr2):
    """Check that adding two low-rank matrices yields a `LowRank`.

    The dense-conversion warning assertion is only applied when the left
    factor of either operand is structured.
    """
    expect_dense = structured(lr1.left) or structured(lr2.left)
    dense_warning = AssertDenseWarning(
        [
            "indexing into <diagonal>",
            "concatenating <diagonal>, <dense>",
        ]
    )
    with ConditionalContext(expect_dense, dense_warning):
        check_bin_op(B.add, lr1, lr2, asserted_type=LowRank)
def test_matmul_wb(wb1, wb2):
    """Check that multiplying two Woodbury matrices yields a `Woodbury`.

    The dense-conversion warning assertion is applied when the left factor of
    either low-rank part is structured, or when both low-rank parts have
    rank one.
    """
    with ConditionalContext(
        structured(wb1.lr.left)
        or structured(wb2.lr.left)
        # Same meaning as the former `rank == 1 == rank == 1`, without the
        # redundant trailing comparison; matches `test_matmul_wb_lr`.
        or wb1.lr.rank == wb2.lr.rank == 1,
        AssertDenseWarning(
            [
                "indexing into <diagonal>",
                "concatenating <diagonal>, <dense>",
            ]
        ),
    ):
        _check_matmul(wb1, wb2, asserted_type=Woodbury)
def test_matmul_wb_lr(wb1, lr2):
    """Check Woodbury-times-low-rank products in both orders yield `LowRank`.

    The dense-conversion warning assertion is applied when either left factor
    is structured, or when both low-rank parts have rank one.
    """
    expect_dense = (
        structured(wb1.lr.left)
        or structured(lr2.left)
        or wb1.lr.rank == lr2.rank == 1
    )
    dense_warning = AssertDenseWarning(
        [
            "indexing into <diagonal>",
            "concatenating <diagonal>, <dense>",
        ]
    )
    with ConditionalContext(expect_dense, dense_warning):
        # Exercise both operand orders.
        for first, second in ((wb1, lr2), (lr2, wb1)):
            _check_matmul(first, second, asserted_type=LowRank)
def sample(self, state: B.RandomState, num: int = 1):
    """Draw samples.

    Args:
        state (random state): Random state.
        num (int): Number of samples. Defaults to `1`.

    Returns:
        tuple[random state, tensor]: Random state and sample.
    """
    # NOTE(review): `Normal(self.prec)` presumably draws zero-mean noise with
    # covariance `prec`; solving against `prec` then gives draws with
    # covariance `prec^{-1}` shifted by `prec^{-1} lam` — confirm.
    state, eps = Normal(self.prec).sample(state, num)
    prec_factor = B.chol(self.prec)
    draw = B.cholsolve(prec_factor, B.add(eps, self.lam))

    # JITs do not cope well with matrix types, so only keep the matrix type
    # when the result actually carries structure.
    if structured(draw):
        return state, draw
    return state, B.dense(draw)
def _conditional_warning(mats, message):
    """Build a context that asserts a dense warning only if needed.

    Args:
        mats: Matrices exposing `left` and `right` factors.
        message: Message(s) passed to `AssertDenseWarning`.

    Returns:
        ConditionalContext: Asserts the warning only when any left or right
            factor of `mats` is structured.
    """
    factors = [m.left for m in mats]
    factors.extend([m.right for m in mats])
    any_structured = structured(*factors)
    return ConditionalContext(any_structured, AssertDenseWarning(message))
def test_diag_wb(wb1):
    """Check `B.diag` on a Woodbury matrix.

    A dense warning about the low-rank diagonal is asserted only when one of
    the low-rank factors is structured.
    """
    dense_warning = AssertDenseWarning("getting the diagonal of <low-rank>")
    factors_structured = structured(wb1.lr.left, wb1.lr.right)
    with ConditionalContext(factors_structured, dense_warning):
        check_un_op(B.diag, wb1)
def test_matmul_diag_lr(lr1, lr2):
    """Check the diagonal of a product of two low-rank matrices.

    A dense warning about the low-rank diagonal is asserted only when the
    factors of either operand are structured.
    """
    expect_dense = structured(lr1.left, lr1.right) or structured(
        lr2.left, lr2.right
    )
    dense_warning = AssertDenseWarning("getting the diagonal of <low-rank>")
    with ConditionalContext(expect_dense, dense_warning):
        _check_matmul_diag(lr1, lr2)
def test_structured():
    """`structured` flags a diagonal matrix but not dense inputs."""
    assert structured(Diagonal(B.ones(3)))
    for unstructured in (Dense(B.ones(3, 3)), B.ones(3, 3)):
        assert not structured(unstructured)