def test_tier_0_additions_done_in_tier_0(self):
  """Three Diags are fully combined by tier 0, so tier 1 is never consulted."""
  operators = [linalg.LinearOperatorDiag([1.]) for _ in range(3)]
  tiers = [
      [linear_operator_addition._AddAndReturnDiag()],
      [_BadAdder()],
  ]
  # The _BadAdder lives in tier 1; a clean return proves every addition was
  # resolved in tier 0 and tier 1 was never reached.
  result = add_operators(operators, addition_tiers=tiers)
  self.assertEqual(1, len(result))
  self.assertIsInstance(result[0], linalg.LinearOperatorDiag)
def test_intermediate_tier_is_not_skipped(self):
  """A tier that cannot finish the sum hands off to the very next tier.

  NOTE(review): a method with this exact name is defined again later in this
  file; if both definitions are in the same class, only the later one is
  collected by the test runner. Confirm and keep a single copy.
  """
  first_diag = linalg.LinearOperatorDiag([1.])
  second_diag = linalg.LinearOperatorDiag([1.])
  lower = linalg.LinearOperatorLowerTriangular([[1.]])
  tiers = [
      [linear_operator_addition._AddAndReturnDiag()],
      [_BadAdder()],
      [linear_operator_addition._AddAndReturnTriL()],
  ]
  # Tier 0 can only produce Diags, so the TriL is left over; the search must
  # then visit tier 1, whose _BadAdder raises from can_add.
  with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
    add_operators([first_diag, second_diag, lower], addition_tiers=tiers)
def test_intermediate_tier_is_not_skipped(self):
  """A tier that cannot finish the sum must not be skipped over.

  The two Diags are combined in tier 0, but the TriL cannot be, so the
  addition proceeds to tier 1. That tier holds a _BadAdder whose can_add
  raises, which proves tier 1 was consulted before tier 2.
  """
  diag1 = linalg.LinearOperatorDiag([1.])
  diag2 = linalg.LinearOperatorDiag([1.])
  tril = linalg.LinearOperatorLowerTriangular([[1.]])
  addition_tiers = [
      [linear_operator_addition._AddAndReturnDiag()],
      [_BadAdder()],
      [linear_operator_addition._AddAndReturnTriL()],
  ]
  # tril cannot be added in tier 0, and the intermediate tier 1 with the
  # BadAdder will catch it and raise.
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
    add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
def test_tier_1_additions_done_by_tier_1(self):
  """Diags plus a TriL are finished by the TriL adder in tier 1."""
  summands = [
      linalg.LinearOperatorDiag([1.]),
      linalg.LinearOperatorDiag([1.]),
      linalg.LinearOperatorLowerTriangular([[1.]]),
  ]
  tiers = [
      [linear_operator_addition._AddAndReturnDiag()],
      [linear_operator_addition._AddAndReturnTriL()],
      [_BadAdder()],
  ]
  # Everything is absorbed by tier 1, so tier 2 (with the _BadAdder) is never
  # reached and nothing raises.
  result = add_operators(summands, addition_tiers=tiers)
  self.assertEqual(1, len(result))
  self.assertIsInstance(result[0], linalg.LinearOperatorLowerTriangular)
def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
  """With the TriL adder first, tier 0 alone produces a single TriL."""
  summands = [
      linalg.LinearOperatorDiag([1.]),
      linalg.LinearOperatorDiag([1.]),
      linalg.LinearOperatorLowerTriangular([[1.]]),
  ]
  tiers = [
      [linear_operator_addition._AddAndReturnTriL()],
      [linear_operator_addition._AddAndReturnDiag()],
      [_BadAdder()],
  ]
  # Tier 0 turns every summand (the Diags included) into a TriL, so neither
  # tier 1 nor tier 2 is used; in particular the _BadAdder never raises.
  result = add_operators(summands, addition_tiers=tiers)
  self.assertEqual(1, len(result))
  self.assertIsInstance(result[0], linalg.LinearOperatorLowerTriangular)
def test_cannot_add_everything_so_return_more_than_one_operator(self):
  """A Diag-only tier merges the Diags but leaves the TriL as its own term.

  NOTE(review): a method with this exact name is defined again later in this
  file; if both definitions are in the same class, only the later one is
  collected by the test runner. Confirm and keep a single copy.
  """
  diag1 = linalg.LinearOperatorDiag([1.])
  diag2 = linalg.LinearOperatorDiag([2.])
  tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
  addition_tiers = [
      [linear_operator_addition._AddAndReturnDiag()],
  ]
  # Tier 0 (the only tier) can only convert to Diag, so it combines the two
  # diags, but the TriL is unchanged.
  # Result should contain two operators, one Diag, one TriL.
  op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
  self.assertEqual(2, len(op_sum))
  found_diag = False
  found_tril = False
  # cached_session() replaces the deprecated test_session(), and
  # self.evaluate() replaces Tensor.eval(), which only works in graph mode.
  with self.cached_session():
    for op in op_sum:
      if isinstance(op, linalg.LinearOperatorDiag):
        found_diag = True
        self.assertAllClose([[3.]], self.evaluate(op.to_dense()))
      if isinstance(op, linalg.LinearOperatorLowerTriangular):
        found_tril = True
        self.assertAllClose([[5.]], self.evaluate(op.to_dense()))
  self.assertTrue(found_diag and found_tril)
def test_cannot_add_everything_so_return_more_than_one_operator(self):
  """Only a Diag tier is available, so the TriL survives as a second term."""
  diag1 = linalg.LinearOperatorDiag([1.])
  diag2 = linalg.LinearOperatorDiag([2.])
  tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
  tiers = [[linear_operator_addition._AddAndReturnDiag()]]
  # The lone tier folds 1 + 2 into a single Diag; the TriL passes through
  # untouched, giving exactly one Diag and one TriL.
  op_sum = add_operators([diag1, diag2, tril5], addition_tiers=tiers)
  self.assertEqual(2, len(op_sum))

  # Expected dense value for each operator type that should appear.
  expected_dense = (
      (linalg.LinearOperatorDiag, [[3.]]),
      (linalg.LinearOperatorLowerTriangular, [[5.]]),
  )
  seen_types = set()
  with self.cached_session():
    for summand in op_sum:
      for op_type, dense in expected_dense:
        if isinstance(summand, op_type):
          seen_types.add(op_type)
          self.assertAllClose(dense, summand.to_dense())
  self.assertTrue(all(op_type in seen_types for op_type, _ in expected_dense))
def setUp(self):
  """Run the base-class fixture, then build the Diag adder under test."""
  # Chain to the parent setUp so whatever per-test fixture the base TestCase
  # provides still runs. NOTE(review): for tf.test.TestCase this includes
  # clearing the cached session, resetting the default graph, and seeding;
  # confirm the base class.
  super().setUp()
  self._adder = linear_operator_addition._AddAndReturnDiag()