def test_intermediate_tier_is_not_skipped(self):
    """Operators left over after tier 0 must be offered to tier 1.

    Tier 1 holds only a `_BadAdder`. The test expects an AssertionError
    matching "BadAdder.can_add called", so reaching it proves the
    intermediate tier was consulted instead of being skipped in favor of
    the TriL adder in tier 2.
    """
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorTriL([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
        [linear_operator_addition._AddAndReturnTriL()],
    ]
    # tril cannot be added in tier 0, and the intermediate tier 1 with the
    # _BadAdder will catch it and raise.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
        add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
def test_intermediate_tier_is_not_skipped(self):
    """Operators left over after tier 0 must be offered to tier 1.

    Tier 1 holds only a `_BadAdder`. The test expects an AssertionError
    matching "BadAdder.can_add called", so reaching it proves the
    intermediate tier was consulted instead of being skipped in favor of
    the TriL adder in tier 2.
    """
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorTriL([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
        [linear_operator_addition._AddAndReturnTriL()],
    ]
    # tril cannot be added in tier 0, and the intermediate tier 1 with the
    # _BadAdder will catch it and raise.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
        add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
def test_tier_1_additions_done_by_tier_1(self):
    """Tier 1 (the TriL adder) finishes the sum, so tier 2 is never used.

    The `_BadAdder` in tier 2 would raise if consulted. The sum is expected
    to collapse to a single LinearOperatorTriL.
    """
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorTriL([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [linear_operator_addition._AddAndReturnTriL()],
        [_BadAdder()],
    ]
    # Should not raise since all were added by tier 1, and the
    # _BadAdder was never reached.
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorTriL)
def test_tier_1_additions_done_by_tier_1(self):
    """Tier 1 (the TriL adder) finishes the sum, so tier 2 is never used.

    The `_BadAdder` in tier 2 would raise if consulted. The sum is expected
    to collapse to a single LinearOperatorTriL.
    """
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorTriL([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [linear_operator_addition._AddAndReturnTriL()],
        [_BadAdder()],
    ]
    # Should not raise since all were added by tier 1, and the
    # _BadAdder was never reached.
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorTriL)
def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
    """With the TriL adder first, tier 0 sums everything on its own.

    NOTE: despite the method name, the comments below state that tier 0
    (not tier 1) does the work here. Tiers 1 and 2 are never consulted,
    so the `_BadAdder` in tier 2 never raises.
    """
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorTriL([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnTriL()],
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
    ]
    # Tier 0 could convert to TriL, and this converted everything to TriL,
    # including the Diags.
    # Tier 1 was never used.
    # Tier 2 was never used (therefore, _BadAdder didn't raise).
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorTriL)
def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
    """With the TriL adder first, tier 0 sums everything on its own.

    NOTE: despite the method name, the comments below state that tier 0
    (not tier 1) does the work here. Tiers 1 and 2 are never consulted,
    so the `_BadAdder` in tier 2 never raises.
    """
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorTriL([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnTriL()],
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
    ]
    # Tier 0 could convert to TriL, and this converted everything to TriL,
    # including the Diags.
    # Tier 1 was never used.
    # Tier 2 was never used (therefore, _BadAdder didn't raise).
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorTriL)
def setUp(self):
    """Give each test its own `_AddAndReturnTriL` adder as `self._adder`."""
    tril_adder = linear_operator_addition._AddAndReturnTriL()
    self._adder = tril_adder
def setUp(self):
    """Give each test its own `_AddAndReturnTriL` adder as `self._adder`."""
    tril_adder = linear_operator_addition._AddAndReturnTriL()
    self._adder = tril_adder