def test_tier_0_additions_done_in_tier_0(self):
   """Three Diags are fully summed by tier 0, so tier 1 is never consulted.

   Tier 1 holds a `_BadAdder`; the test only passes without an error if the
   Diag adder in tier 0 finished the whole sum on its own.
   """
   operators = [linalg.LinearOperatorDiag([1.]) for _ in range(3)]
   tiers = [
       [linear_operator_addition._AddAndReturnDiag()],
       [_BadAdder()],
   ]
   result = add_operators(operators, addition_tiers=tiers)
   # One Diag left means every operand was absorbed in tier 0.
   self.assertEqual(1, len(result))
   self.assertIsInstance(result[0], linalg.LinearOperatorDiag)
Example #2
0
 def test_intermediate_tier_is_not_skipped(self):
     """A tier that cannot finish the sum hands off to the very next tier.

     The Diag adder in tier 0 cannot absorb the TriL, so tier 1 (holding a
     `_BadAdder`) must be consulted before the TriL adder in tier 2. The
     expected AssertionError proves tier 1 was not skipped.
     """
     ops_to_add = [
         linalg.LinearOperatorDiag([1.]),
         linalg.LinearOperatorDiag([1.]),
         linalg.LinearOperatorLowerTriangular([[1.]]),
     ]
     tiers = [
         [linear_operator_addition._AddAndReturnDiag()],
         [_BadAdder()],
         [linear_operator_addition._AddAndReturnTriL()],
     ]
     with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
         add_operators(ops_to_add, addition_tiers=tiers)
 def test_intermediate_tier_is_not_skipped(self):
   """Tiers are visited in order; an intermediate tier is never skipped.

   Tier 0 (Diag adder) cannot absorb the TriL, so the sum is not finished
   after tier 0. Tier 1 holds a `_BadAdder`, which is expected to raise an
   AssertionError matching "BadAdder.can_add called" when consulted. Seeing
   that error proves tier 1 ran before the TriL adder in tier 2.
   """
   diag1 = linalg.LinearOperatorDiag([1.])
   diag2 = linalg.LinearOperatorDiag([1.])
   tril = linalg.LinearOperatorLowerTriangular([[1.]])
   addition_tiers = [
       [linear_operator_addition._AddAndReturnDiag()],
       [_BadAdder()],
       [linear_operator_addition._AddAndReturnTriL()],
   ]
   # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
   # assertRaisesRegex is the supported name.
   with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
     add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
 def test_tier_0_additions_done_in_tier_0(self):
   """Summing only Diags finishes in tier 0; the `_BadAdder` in tier 1 is unused.

   If tier 1 were reached, the `_BadAdder` there would make the call fail.
   """
   diag_adder_tier = [linear_operator_addition._AddAndReturnDiag()]
   bad_tier = [_BadAdder()]
   summands = [
       linalg.LinearOperatorDiag([1.]),
       linalg.LinearOperatorDiag([1.]),
       linalg.LinearOperatorDiag([1.]),
   ]
   total = add_operators(summands, addition_tiers=[diag_adder_tier, bad_tier])
   self.assertEqual(1, len(total))
   self.assertIsInstance(total[0], linalg.LinearOperatorDiag)
 def test_tier_1_additions_done_by_tier_1(self):
   """A Diag/TriL mix is finished by tier 1, before tier 2 is reached.

   Tier 0 (Diag adder) cannot absorb the TriL, so tier 1 (TriL adder) must
   complete the sum. Tier 2 holds a `_BadAdder`, so a clean result shows
   it was never reached.
   """
   operators = [
       linalg.LinearOperatorDiag([1.]),
       linalg.LinearOperatorDiag([1.]),
       linalg.LinearOperatorLowerTriangular([[1.]]),
   ]
   tiers = [
       [linear_operator_addition._AddAndReturnDiag()],
       [linear_operator_addition._AddAndReturnTriL()],
       [_BadAdder()],
   ]
   summed = add_operators(operators, addition_tiers=tiers)
   self.assertEqual(1, len(summed))
   self.assertIsInstance(summed[0], linalg.LinearOperatorLowerTriangular)
 def test_tier_1_additions_done_by_tier_1(self):
   """The TriL adder in tier 1 completes the sum; tier 2's `_BadAdder` is unused."""
   first_diag = linalg.LinearOperatorDiag([1.])
   second_diag = linalg.LinearOperatorDiag([1.])
   lower = linalg.LinearOperatorLowerTriangular([[1.]])
   diag_tier = [linear_operator_addition._AddAndReturnDiag()]
   tril_tier = [linear_operator_addition._AddAndReturnTriL()]
   bad_tier = [_BadAdder()]
   # No error is expected: tier 1 finishes before the _BadAdder is reached.
   out = add_operators(
       [first_diag, second_diag, lower],
       addition_tiers=[diag_tier, tril_tier, bad_tier])
   self.assertEqual(1, len(out))
   self.assertIsInstance(out[0], linalg.LinearOperatorLowerTriangular)
 def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
   """With the TriL adder first, everything is summed in tier 0.

   The TriL adder converts the Diags as well, so the Diag adder in tier 1
   and the `_BadAdder` in tier 2 are never used; the `_BadAdder` therefore
   never raises.
   """
   operators = [
       linalg.LinearOperatorDiag([1.]),
       linalg.LinearOperatorDiag([1.]),
       linalg.LinearOperatorLowerTriangular([[1.]]),
   ]
   tiers = [
       [linear_operator_addition._AddAndReturnTriL()],
       [linear_operator_addition._AddAndReturnDiag()],
       [_BadAdder()],
   ]
   result = add_operators(operators, addition_tiers=tiers)
   self.assertEqual(1, len(result))
   self.assertIsInstance(result[0], linalg.LinearOperatorLowerTriangular)
 def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
   """Tier 0's TriL adder absorbs the Diags too, so later tiers stay idle."""
   d_a = linalg.LinearOperatorDiag([1.])
   d_b = linalg.LinearOperatorDiag([1.])
   lower = linalg.LinearOperatorLowerTriangular([[1.]])
   tril_first = [linear_operator_addition._AddAndReturnTriL()]
   diag_second = [linear_operator_addition._AddAndReturnDiag()]
   bad_last = [_BadAdder()]
   # Tier 1 (Diag adder) and tier 2 (_BadAdder) are never consulted.
   summed = add_operators(
       [d_a, d_b, lower], addition_tiers=[tril_first, diag_second, bad_last])
   self.assertEqual(1, len(summed))
   self.assertIsInstance(summed[0], linalg.LinearOperatorLowerTriangular)
 def test_cannot_add_everything_so_return_more_than_one_operator(self):
   """When no tier can finish the sum, several operators are returned.

   The only tier is the Diag adder. It combines the two Diags ([1.] + [2.])
   into one Diag and leaves the TriL unchanged, so the result has two
   operators: a Diag with dense value [[3.]] and a TriL with [[5.]].
   """
   diag1 = linalg.LinearOperatorDiag([1.])
   diag2 = linalg.LinearOperatorDiag([2.])
   tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
   addition_tiers = [
       [linear_operator_addition._AddAndReturnDiag()],
   ]
   op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
   self.assertEqual(2, len(op_sum))
   found_diag = False
   found_tril = False
   # test_session() is deprecated in favour of cached_session(). Passing
   # tensors straight to assertAllClose (instead of calling .eval()) also
   # works under eager execution, matching the cached_session variant.
   with self.cached_session():
     for op in op_sum:
       if isinstance(op, linalg.LinearOperatorDiag):
         found_diag = True
         self.assertAllClose([[3.]], op.to_dense())
       if isinstance(op, linalg.LinearOperatorLowerTriangular):
         found_tril = True
         self.assertAllClose([[5.]], op.to_dense())
     # Separate checks so a failure says which operator was missing.
     self.assertTrue(found_diag)
     self.assertTrue(found_tril)
 def test_cannot_add_everything_so_return_more_than_one_operator(self):
   """A Diag-only tier leaves the TriL alone, so two operators come back.

   The two Diags are combined into one Diag with dense value [[3.]]; the
   TriL is returned unchanged with dense value [[5.]].
   """
   operators = [
       linalg.LinearOperatorDiag([1.]),
       linalg.LinearOperatorDiag([2.]),
       linalg.LinearOperatorLowerTriangular([[5.]]),
   ]
   tiers = [[linear_operator_addition._AddAndReturnDiag()]]
   op_sum = add_operators(operators, addition_tiers=tiers)
   self.assertEqual(2, len(op_sum))
   saw_diag = saw_tril = False
   with self.cached_session():
     for summand in op_sum:
       if isinstance(summand, linalg.LinearOperatorDiag):
         saw_diag = True
         self.assertAllClose([[3.]], summand.to_dense())
       if isinstance(summand, linalg.LinearOperatorLowerTriangular):
         saw_tril = True
         self.assertAllClose([[5.]], summand.to_dense())
     self.assertTrue(saw_diag and saw_tril)
Example #11
0
 def setUp(self):
     """Run the base-class fixture setup, then create the Diag adder under test."""
     # The original override skipped the base TestCase.setUp; call it so any
     # per-test fixture work done by the base class still happens.
     super().setUp()
     self._adder = linear_operator_addition._AddAndReturnDiag()
 def setUp(self):
   """Run the base-class fixture setup, then create the Diag adder under test."""
   # Overriding setUp without calling the base class skips its per-test
   # fixture work, so chain to it first.
   super().setUp()
   self._adder = linear_operator_addition._AddAndReturnDiag()