def test_temporary_simplifications():
    """Check that the ``simplifications`` class attribute can be edited locally.

    ``scalars_to_op`` is removed from ``OperatorPlus.simplifications`` inside
    a ``temporary_rules`` block; once the block is left, it must be present
    again.
    """

    def is_registered():
        # Look the attribute up on every call instead of caching the list, so
        # the check always sees whatever is currently bound on the class.
        return scalars_to_op in OperatorPlus.simplifications

    assert is_registered()
    with temporary_rules(OperatorPlus):
        OperatorPlus.simplifications.remove(scalars_to_op)
        assert not is_registered()
    assert is_registered()
def test_exception_teardown():
    """Test that teardown works when breaking out due to an exception.

    A rule is added to ``ScalarTimesOperator`` and ``scalars_to_op`` is
    removed from ``OperatorPlus.simplifications`` inside a ``temporary_rules``
    block, which is then aborted by raising an exception. Both modifications
    must have been undone by the time the exception is handled.
    """

    class TemporaryRulesException(Exception):
        """Marker exception used only to abort the ``with`` block below."""

    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    rule_name = 'extra'
    rule = (pattern_head(6, a), lambda: b)
    # Keep a handle on the current value so ``finally`` can put it back.
    simplifications = OperatorPlus.simplifications
    try:
        with temporary_rules(ScalarTimesOperator, OperatorPlus):
            ScalarTimesOperator.add_rule(rule_name, rule[0], rule[1])
            OperatorPlus.simplifications.remove(scalars_to_op)
            raise TemporaryRulesException
    except TemporaryRulesException:
        assert rule not in ScalarTimesOperator._rules.values()
        assert scalars_to_op in OperatorPlus.simplifications
    finally:
        # Even if this failed we don't want to make a mess for other tests
        try:
            ScalarTimesOperator.del_rules(rule_name)
        except KeyError:
            # Expected when teardown worked: the rule is already gone.
            pass
        OperatorPlus.simplifications = simplifications
def test_simplify():
    """Test simplification of an expression according to manual rules.

    A temporary ``OperatorTimes`` rule rewrites ``b * c`` as ``d``. The
    rebuilt expression must then match a hand-written commutator pattern, and
    applying that pattern as a rule must produce the symbol ``CommutAD``.
    """
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    c = OperatorSymbol("c", hs=h1)
    d = OperatorSymbol("d", hs=h1)

    expr = 2 * (a * b * c - b * c * a)

    wc_A = wc('A', head=Operator)
    wc_B = wc('B', head=Operator)
    wc_C = wc('C', head=Operator)

    def b_times_c_equal_d(B, C):
        # Parameter names mirror the wildcard names 'B' and 'C'.
        if not (B.label == 'b' and C.label == 'c'):
            raise CannotSimplify
        return d

    def register_bc_rule():
        # Build a fresh pattern for every registration.
        OperatorTimes.add_rule(
            'extra', pattern_head(wc_B, wc_C), b_times_c_equal_d
        )

    def make_commutator_symbol(A, B):
        name = "Commut%s%s" % (A.label.upper(), B.label.upper())
        return OperatorSymbol(name, hs=A.space)

    with temporary_rules(OperatorTimes):
        register_bc_rule()
        new_expr = expr.rebuild()

    forward_product = pattern(OperatorTimes, wc_A, wc_B)
    negated_reverse_product = pattern(
        ScalarTimesOperator, -1, pattern(OperatorTimes, wc_B, wc_A)
    )
    commutator_rule = (
        pattern(OperatorPlus, forward_product, negated_reverse_product),
        make_commutator_symbol,
    )
    assert commutator_rule[0].match(new_expr.term)

    with temporary_rules(OperatorTimes):
        register_bc_rule()
        new_expr = _apply_rules(expr, [commutator_rule])

    expected = (
        "ScalarTimesOperator(ScalarValue(2), OperatorSymbol('CommutAD', "
        "hs=LocalSpace('h1')))"
    )
    assert srepr(new_expr) == expected
def rewrite_with_operator_pm_cc(expr):
    """Try to rewrite expr using :class:`.OperatorPlusMinusCC`.

    The expression is rebuilt while three temporary ``OperatorPlus`` rules
    ('PM1', 'PM2', 'PM3') are installed. They combine an operator with its
    adjoint, optionally with a sign or a common scalar prefactor, into an
    :class:`.OperatorPlusMinusCC` instance.

    Example::

        >>> A = OperatorSymbol('A', hs=1)
        >>> sum = A + A.dag()
        >>> sum2 = rewrite_with_operator_pm_cc(sum)
        >>> print(ascii(sum2))
        A^(1) + c.c.
    """
    # TODO: move this to the toolbox
    from qalgebra.toolbox.core import temporary_rules

    # The parameter names of the rule callbacks mirror the wildcard names
    # ("A", "B", "c", "d") declared below.

    def _combine_operator_p_cc(A, B):
        if not (B.adjoint() == A):
            raise CannotSimplify
        return OperatorPlusMinusCC(A, sign=+1)

    def _combine_operator_m_cc(A, B):
        if not (B.adjoint() == A):
            raise CannotSimplify
        return OperatorPlusMinusCC(A, sign=-1)

    def _scal_combine_operator_pm_cc(c, A, d, B):
        if not (B.adjoint() == A):
            raise CannotSimplify
        if c == d:
            return c * OperatorPlusMinusCC(A, sign=+1)
        if c == -d:
            return c * OperatorPlusMinusCC(A, sign=-1)
        raise CannotSimplify

    op_a = wc("A", head=Operator)
    op_b = wc("B", head=Operator)
    scalar_c = wc("c", head=Scalar)
    scalar_d = wc("d", head=Scalar)

    # NOTE(review): ``clear=True`` presumably drops the regular OperatorPlus
    # rules while the block is active -- confirm against ``temporary_rules``.
    with temporary_rules(OperatorPlus, clear=True):
        OperatorPlus.add_rule(
            'PM1', pattern_head(op_a, op_b), _combine_operator_p_cc
        )

        minus_b = pattern(ScalarTimesOperator, -1, op_b)
        OperatorPlus.add_rule(
            'PM2', pattern_head(minus_b, op_a), _combine_operator_m_cc
        )

        c_times_a = pattern(ScalarTimesOperator, scalar_c, op_a)
        d_times_b = pattern(ScalarTimesOperator, scalar_d, op_b)
        OperatorPlus.add_rule(
            'PM3',
            pattern_head(c_times_a, d_times_b),
            _scal_combine_operator_pm_cc,
        )

        # The rebuild has to happen while the temporary rules are installed.
        rewritten = expr.rebuild()
    return rewritten
def test_no_rules():
    """Test creation of expr when rule application is suppressed.

    The same commutator is created under ``temporary_rules(..., clear=True)``
    for ``ScalarTimesOperator``, for ``Commutator`` and for both at once. Its
    representation is compared with the expected form each time. A final check
    confirms that the result from before the blocks is obtained again.
    """
    A, B = (OperatorSymbol(s, hs=0) for s in ('A', 'B'))

    def make_expr():
        # Created anew on each call so that the rules active at call time
        # are the ones that get applied.
        return Commutator.create(2 * A, 2 * (3 * B))

    def myrepr(e):
        return srepr(e, cache={A: 'A', B: 'B'})

    fully_simplified = 'ScalarTimesOperator(ScalarValue(12), Commutator(A, B))'

    assert myrepr(make_expr()) == fully_simplified

    with temporary_rules(ScalarTimesOperator, clear=True):
        assert myrepr(make_expr()) == (
            'ScalarTimesOperator(ScalarValue(4), '
            'ScalarTimesOperator(ScalarValue(3), Commutator(A, B)))'
        )

    with temporary_rules(Commutator, clear=True):
        assert myrepr(make_expr()) == (
            'Commutator(ScalarTimesOperator(ScalarValue(2), A), '
            'ScalarTimesOperator(ScalarValue(6), B))'
        )

    with temporary_rules(Commutator, ScalarTimesOperator, clear=True):
        assert myrepr(make_expr()) == (
            'Commutator(ScalarTimesOperator(ScalarValue(2), A), '
            'ScalarTimesOperator(ScalarValue(2), '
            'ScalarTimesOperator(ScalarValue(3), B)))'
        )

    # Outside all temporary blocks the original result must come back.
    assert myrepr(make_expr()) == fully_simplified
def test_extra_rules():
    """Test creation of expr with extra rules.

    A temporary ``ScalarTimesOperator`` rule maps ``6 * a`` to ``b``. The rule
    must be registered and applied inside the ``temporary_rules`` block and
    must be gone after it.
    """
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    hs_repr = "LocalSpace('h1')"

    def replace_with_b():
        return b

    rule = (pattern_head(6, a), replace_with_b)
    rule_pattern, rule_replacement = rule

    with temporary_rules(ScalarTimesOperator):
        ScalarTimesOperator.add_rule('extra', rule_pattern, rule_replacement)
        assert ('extra', rule) in ScalarTimesOperator._rules.items()
        expr = 2 * a * 3 + 3 * (2 * a * 3)
        assert expr == 4 * b

    assert rule not in ScalarTimesOperator._rules.values()

    # The same expression built after the block.
    expr_without_rule = 2 * a * 3 + 3 * (2 * a * 3)
    expected = (
        "ScalarTimesOperator(ScalarValue(24), "
        "OperatorSymbol('a', hs=" + hs_repr + "))"
    )
    assert srepr(expr_without_rule) == expected
def test_extra_binary_rules():
    """Test creation of expr with extra binary rules.

    A temporary ``OperatorPlus`` rule replaces ``A * B - B * A`` with the
    symbol ``c``. The rule must appear in ``OperatorPlus._binary_rules`` and be
    applied inside the ``temporary_rules`` block, and must be gone after it.
    """
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    c = OperatorSymbol("c", hs=h1)
    wc_A = wc('A', head=Operator)
    wc_B = wc('B', head=Operator)

    def replace_with_c(A, B):
        # Parameter names mirror the wildcard names 'A' and 'B'.
        return c

    forward_product = pattern(OperatorTimes, wc_A, wc_B)
    negated_reverse_product = pattern(
        ScalarTimesOperator, -1, pattern(OperatorTimes, wc_B, wc_A)
    )
    rule = (
        pattern_head(forward_product, negated_reverse_product),
        replace_with_c,
    )
    rule_pattern, rule_replacement = rule

    with temporary_rules(OperatorPlus):
        OperatorPlus.add_rule('extra', rule_pattern, rule_replacement)
        assert ('extra', rule) in OperatorPlus._binary_rules.items()
        expr = 2 * (a * b - b * a + IdentityOperator)
        assert expr == 2 * (c + IdentityOperator)

    assert rule not in OperatorPlus._binary_rules.values()