def test_commutator_expand_evaluate():
    """Check Leibniz-rule expansion and explicit evaluation of commutators"""
    space = LocalSpace("0")
    A, B, C, D, E = (OperatorSymbol(label, hs=space) for label in "ABCDE")

    # [A, BCDE] expanded with the product (Leibniz) rule
    expr = Commutator(A, B * C * D * E)
    leibniz = (
        B * C * D * Commutator(A, E)
        + B * C * Commutator(A, D) * E
        + B * Commutator(A, C) * D * E
        + Commutator(A, B) * C * D * E
    )
    evaluated = A * B * C * D * E - B * C * D * E * A
    assert expand_commutators_leibniz(expr) == leibniz
    assert expr.doit([Commutator]) == evaluated
    assert leibniz.doit([Commutator]).expand() == evaluated

    # Same expansion, but without multiplying out the nested sums
    unexpanded = (
        B * (C * (D * Commutator(A, E) + Commutator(A, D) * E)
             + Commutator(A, C) * D * E)
        + Commutator(A, B) * C * D * E
    )
    assert expand_commutators_leibniz(expr, expand_expr=False) == unexpanded

    # [ABC, D]: product in the first argument
    assert expand_commutators_leibniz(Commutator(A * B * C, D)) == (
        A * B * Commutator(C, D)
        + A * Commutator(B, D) * C
        + Commutator(A, D) * B * C
    )

    # [AB, CD]: products in both arguments
    assert expand_commutators_leibniz(Commutator(A * B, C * D)) == (
        A * Commutator(B, C) * D
        + C * A * Commutator(B, D)
        + C * Commutator(A, D) * B
        + Commutator(A, C) * B * D
    )
def test_exception_teardown():
    """Test that teardown works when breaking out due to an exception

    Both the extra rule added to ScalarTimesOperator and the simplification
    removed from OperatorPlus must be undone by `temporary_rules` when its
    body is left via an exception.
    """

    class TemporaryRulesException(Exception):
        pass

    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    rule_name = 'extra'
    rule = (pattern_head(6, a), lambda: b)
    # Back up a *copy*: `remove` below mutates the list in place, so a plain
    # reference would be corrupted as well if `temporary_rules` failed to
    # isolate the list, and the restore in `finally` would then be a no-op.
    simplifications = OperatorPlus.simplifications.copy()
    try:
        with temporary_rules(ScalarTimesOperator, OperatorPlus):
            ScalarTimesOperator.add_rule(rule_name, rule[0], rule[1])
            OperatorPlus.simplifications.remove(scalars_to_op)
            raise TemporaryRulesException
    except TemporaryRulesException:
        assert rule not in ScalarTimesOperator._rules.values()
        assert scalars_to_op in OperatorPlus.simplifications
    finally:
        # Even if this failed we don't want to make a mess for other tests
        try:
            ScalarTimesOperator.del_rules(rule_name)
        except KeyError:
            pass
        OperatorPlus.simplifications = simplifications
def test_series_expand():
    """Test series expansion of commutator

    The scalar prefactors of both arguments are polynomials in `t`; the
    expansion must collect the coefficients of each order in front of
    Commutator(A, B).
    """
    hs = LocalSpace("0")
    A = OperatorSymbol('A', hs=hs)
    B = OperatorSymbol('B', hs=hs)
    a3, a2, a1, a0, b3, b2, b1, b0, t, t0 = symbols(
        'a_3, a_2, a_1, a_0, b_3, b_2, b_1, b_0, t, t_0')

    # Cubic prefactors, expanded around t = 0 up to second order
    A_form = (a3 * t**3 + a2 * t**2 + a1 * t + a0) * A
    B_form = (b3 * t**3 + b2 * t**2 + b1 * t + b0) * B
    comm = Commutator.create(A_form, B_form)
    terms = comm.series_expand(t, 0, 2)
    assert terms == (
        a0 * b0 * Commutator(A, B),
        (a0 * b1 + a1 * b0) * Commutator(A, B),
        (a0 * b2 + a1 * b1 + a2 * b0) * Commutator(A, B),
    )

    # Linear prefactors, expanded around t = t0 up to first order
    A_form = (a1 * t + a0) * A
    B_form = (b1 * t + b0) * B
    comm = Commutator.create(A_form, B_form)
    terms = comm.series_expand(t, t0, 1)
    assert terms == (
        ((a0 * b0 + a0 * b1 * t0 + a1 * b0 * t0 + a1 * b1 * t0**2) *
         Commutator(A, B)),
        (a0 * b1 + a1 * b0 + 2 * a1 * b1 * t0) * Commutator(A, B),
    )

    # No t-dependence: the first-order term is zero
    comm = Commutator.create(A, B)
    terms = comm.series_expand(t, t0, 1)
    assert terms == (Commutator(A, B), ZeroOperator)
def test_finditer():
    """Check that `finditer` yields every match of a pattern in an expression"""
    h1 = LocalSpace("h1")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in ("a", "b", "c"))
    # A creation operator whose local identifier is also rendered as 'c'
    h1_custom = LocalSpace("h1", local_identifiers={'Create': 'c'})
    c_local = Create(hs=h1_custom)
    expr = 2 * (a * b * c - b * c * a + a * b)

    sym_pattern = wc('sym', head=OperatorSymbol)
    matches = list(sym_pattern.finditer(expr))
    assert all('sym' in match for match in matches)
    assert len(matches) == 8
    assert {match['sym'] for match in matches} == {a, b, c}

    op = wc(head=Operator)
    three_factors = pattern(OperatorTimes, op, op, op).findall(expr)
    assert three_factors == [a * b * c, b * c * a]

    assert len(list(pattern(LocalOperator).finditer(expr))) == 0
    substituted = expr.substitute({c: c_local})
    assert len(list(pattern(LocalOperator).finditer(substituted))) == 2
def test_exception_teardown():
    """Check that the instance-caching context managers clean up after
    themselves when their body is left through an exception"""

    class InstanceCachingException(Exception):
        pass

    def is_cached(expr):
        # Look up the *current* OperatorPlus cache on every call
        return expr in OperatorPlus._instances.values()

    h1 = LocalSpace("caching")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in ("a", "b", "c"))
    expr1 = a + b

    # Part 1: no_instance_caching must re-enable caching on the way out
    saved_caching_flag = Expression.instance_caching
    try:
        with no_instance_caching():
            expr2 = a + c
            raise InstanceCachingException
    except InstanceCachingException:
        expr3 = b + c
        assert is_cached(expr1)
        assert not is_cached(expr2)
        assert is_cached(expr3)
    finally:
        # Even if this failed we don't want to make a mess for other tests
        Expression.instance_caching = saved_caching_flag

    # Part 2: temporary_instance_cache must put the original cache back
    saved_instances = OperatorPlus._instances
    try:
        with temporary_instance_cache(OperatorPlus):
            expr2 = a + c
            raise InstanceCachingException
    except InstanceCachingException:
        assert is_cached(expr1)
        assert not is_cached(expr2)
    finally:
        # Even if this failed we don't want to make a mess for other tests
        OperatorPlus._instances = saved_instances
def disjunct_commutative_test_data():
    """Return a list of ``(factors, expected_order)`` pairs.

    Each pair combines a list of factors acting on various Hilbert spaces
    with the factor order it is expected to end up in. The consuming test is
    not part of this function.
    """
    A1, B1, C1 = (OperatorSymbol(label, hs=1) for label in ("A", "B", "C"))
    A2, B2 = (OperatorSymbol(label, hs=2) for label in ("A", "B"))
    A3 = OperatorSymbol("A", hs=3)
    B4 = OperatorSymbol("B", hs=4)
    tr_A1 = tr(A1, over_space=1)
    tr_A2 = tr(A2, over_space=2)

    # Same labels, but on spaces carrying an explicit order_index
    hs1_ordered = LocalSpace(1, order_index=2)
    hs2_ordered = LocalSpace(2, order_index=1)
    A1_m = OperatorSymbol("A", hs=hs1_ordered)
    B1_m = OperatorSymbol("B", hs=hs1_ordered)
    B2_m = OperatorSymbol("B", hs=hs2_ordered)

    ket_1 = BasisKet(1, hs=1)
    ketbra = KetBra(BasisKet(0, hs=1), ket_1)
    braket = BraKet(ket_1, ket_1)

    # fmt: off
    return [
        ([B2, B1, A1], [B1, A1, B2]),
        ([B2_m, B1_m, A1_m], [B2_m, B1_m, A1_m]),
        ([B1_m, A1_m, B2_m], [B2_m, B1_m, A1_m]),
        ([B1, A2, C1, tr_A2], [tr_A2, B1, C1, A2]),
        ([A1, B1 + B2], [A1, B1 + B2]),
        ([B1 + B2, A1], [B1 + B2, A1]),
        ([A3 + B4, A1 + A2], [A1 + A2, A3 + B4]),
        ([A1 + A2, A3 + B4], [A1 + A2, A3 + B4]),
        ([B4 + A3, A2 + A1], [A1 + A2, A3 + B4]),
        ([tr_A2, tr_A1], [tr_A1, tr_A2]),
        ([A2, ketbra, A1], [ketbra, A1, A2]),
        ([A2, braket, A1], [braket, A1, A2]),
    ]
    # fmt: on
def test_commutator_hs():
    """Check which Hilbert space a commutator lives in"""
    hs1, hs2 = LocalSpace("1"), LocalSpace("2")
    A = OperatorSymbol('A', hs=hs1)
    B, C = (OperatorSymbol(label, hs=hs2) for label in ('B', 'C'))
    # Both arguments act on hs2 only
    assert Commutator.create(B, C).space == hs2
    # A + C spans hs1 * hs2, so the commutator does as well
    assert Commutator.create(B, A + C).space == hs1 * hs2
def test_commutator_oder():
    """Check that swapping the arguments of a commutator flips its sign"""
    hs = LocalSpace("0")
    A, B = OperatorSymbol('A', hs=hs), OperatorSymbol('B', hs=hs)
    assert Commutator.create(B, A) == -Commutator(A, B)

    # Same property for the ladder operators, with both sides via `create`
    a, a_dag = Destroy(hs=hs), Create(hs=hs)
    assert Commutator.create(a, a_dag) == -Commutator.create(a_dag, a)
def test_pull_out_scalars():
    """Check that scalar prefactors are moved in front of a commutator"""
    hs = LocalSpace("sys")
    A, B = (OperatorSymbol(label, hs=hs) for label in ('A', 'B'))
    alpha, beta = symbols('alpha, beta')
    bare = Commutator(A, B)
    assert Commutator.create(alpha * A, B) == alpha * bare
    assert Commutator.create(A, beta * B) == beta * bare
    assert Commutator.create(alpha * A, beta * B) == alpha * beta * bare
def test_disjunct_hs():
    """Test that commutator of objects in disjoint Hilbert spaces is zero

    Scalars count as acting trivially, so any commutator involving a scalar
    must vanish as well.
    """
    hs1 = LocalSpace("1")
    hs2 = LocalSpace("2")
    alpha, beta = symbols('alpha, beta')
    A = OperatorSymbol('A', hs=hs1)
    B = OperatorSymbol('B', hs=hs2)
    assert Commutator.create(A, B) == ZeroOperator
    assert Commutator.create(alpha, beta) == ZeroOperator
    assert Commutator.create(alpha, B) == ZeroOperator
    assert Commutator.create(A, beta) == ZeroOperator
def test_diff():
    """Check differentiation of commutators with respect to a scalar symbol"""
    hs = LocalSpace("0")
    A, B = (OperatorSymbol(label, hs=hs) for label in ('A', 'B'))
    alpha, t = symbols('alpha, t')
    # d/dt [alpha t^2 A, t B] = d/dt (alpha t^3 [A, B]) = 3 alpha t^2 [A, B]
    expected = 3 * alpha * t**2 * Commutator(A, B)
    for construct in (Commutator, Commutator.create):
        assert construct(alpha * t**2 * A, t * B).diff(t) == expected
    assert Commutator(A, B).diff(t) == ZeroOperator
def testSPreSPostRules(self):
    """Check how SPre/SPost super-operators combine under multiplication.

    Products of SPre keep the operator order, products of SPost reverse it,
    and SPost(d) * SPre(e) equals SPre(e) * SPost(d).
    """
    h1 = LocalSpace("h1")
    d = OperatorSymbol("d", hs=h1)
    e = OperatorSymbol("e", hs=h1)
    dpre = SPre(d)
    epre = SPre(e)
    dpost = SPost(d)
    epost = SPost(e)
    assert dpre * epre == SPre(d * e)
    assert dpost * epost == SPost(e * d)
    assert dpost * epre == SPre(e) * SPost(d)
def testCombination(self):
    """Check associativity when super-operators are applied to an operator"""
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    A, B = (SuperOperatorSymbol(label, hs=h1) for label in ("A", "B"))
    applied_in_turn = A * (B * a)
    product_first = (A * B) * a
    assert applied_in_turn == product_first
def testEqual2(self):
    """Check that ``A * a`` equals the explicit SuperOperatorTimesOperator"""
    h1 = LocalSpace("h1")
    A = SuperOperatorSymbol("A", hs=h1)
    a = OperatorSymbol("a", hs=h1)
    explicit = SuperOperatorTimesOperator(A, a)
    assert A * a == explicit
def test_symbol():
    """Test qutip conversion of an OperatorSymbol through a user mapping.

    The mapping may send the symbol either directly to a ``qutip.Qobj`` or to
    a callable that produces one. Both forms are checked for the symbol on
    its own and as a factor in a product with another operator.
    """
    expN = OperatorSymbol("expN", hs=1)
    hs1 = LocalSpace("sym1", dimension=10)
    hs2 = LocalSpace("sym2", dimension=5)
    N = Create(hs=hs1) * Destroy(hs=hs1)
    M = Create(hs=hs2) * Destroy(hs=hs2)

    def assert_close(actual, expected):
        # Compare the dense matrices (Frobenius norm of the difference)
        diff = actual.data.toarray() - expected.data.toarray()
        assert np.linalg.norm(diff) < 1e-8

    expected_expN = convert_to_qutip(N).expm()
    expected_expNM = qutip.tensor(expected_expN, convert_to_qutip(M))

    converter1 = {expN: convert_to_qutip(N).expm()}
    converter2 = {expN: lambda expr: convert_to_qutip(N).expm()}
    # Previously the product was only ever converted with converter1, so the
    # callable form was never exercised inside a product.
    for converter in (converter1, converter2):
        assert_close(convert_to_qutip(expN, mapping=converter), expected_expN)
        assert_close(
            convert_to_qutip(expN * M, mapping=converter), expected_expNM)
def test_extra_rules():
    """Check that a temporarily added rule is applied on creation and is
    gone again once the context exits"""
    h1 = LocalSpace("h1")
    a, b = (OperatorSymbol(label, hs=h1) for label in ("a", "b"))
    hs_repr = "LocalSpace('h1')"

    def replacement():
        return b

    rule = (pattern_head(6, a), replacement)
    with temporary_rules(ScalarTimesOperator):
        ScalarTimesOperator.add_rule('extra', *rule)
        assert ('extra', rule) in ScalarTimesOperator._rules.items()
        # With the extra rule, 2 * a * 3 becomes b, so the sum is b + 3 * b
        assert 2 * a * 3 + 3 * (2 * a * 3) == 4 * b
    assert rule not in ScalarTimesOperator._rules.values()

    # Without the extra rule the same expression collapses to 24 * a
    expected = ("ScalarTimesOperator(ScalarValue(24), "
                "OperatorSymbol('a', hs=" + hs_repr + "))")
    assert srepr(2 * a * 3 + 3 * (2 * a * 3)) == expected
def test_commutator_expansion():
    """Check that ``expand`` distributes a commutator over sums"""
    hs = LocalSpace("0")
    A, B, C, D = (OperatorSymbol(label, hs=hs) for label in "ABCD")
    alpha = symbols('alpha')
    AC, AD = Commutator(A, C), Commutator(A, D)
    BC, BD = Commutator(B, C), Commutator(B, D)

    assert Commutator(A + B, C).expand() == AC + BC
    assert Commutator(A, B + C).expand() == Commutator(A, B) + AC

    all_four = AC + AD + BC + BD
    assert Commutator(A + B, C + D).expand() == all_four
    # The scalar summand alpha contributes nothing to the expansion
    assert Commutator(A + B, C + D + alpha).expand() == all_four
def test_proto_expr_as_sequence():
    """Check that a ProtoExpr can be indexed and measured like a sequence"""
    h1 = LocalSpace("h1")
    proto_expr = ProtoExpr.from_expr(OperatorSymbol("a", hs=h1))
    assert len(proto_expr) == 2
    first, second = proto_expr[0], proto_expr[1]
    assert first == 'a'
    assert second == h1
def test_context_instance_caching():
    """Check that instance caching can be suppressed or swapped out
    temporarily"""
    h1 = LocalSpace("caching")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in ("a", "b", "c"))

    def is_cached(expr):
        # Look up the *current* OperatorPlus cache on every call
        return expr in OperatorPlus._instances.values()

    expr1 = a + b
    assert is_cached(expr1)

    with no_instance_caching():
        # Existing entries stay available, new ones are not added
        assert is_cached(expr1)
        expr2 = a + c
        assert not is_cached(expr2)

    with temporary_instance_cache(OperatorPlus):
        # A fresh, empty cache is installed for the duration of the block
        assert len(OperatorPlus._instances) == 0
        expr2 = a + c
        assert is_cached(expr2)

    # The original cache is back afterwards
    assert is_cached(expr1)
    assert not is_cached(expr2)
def test_simplify():
    """Check simplification of an expression with manually supplied rules"""
    h1 = LocalSpace("h1")
    a, b, c, d = (OperatorSymbol(label, hs=h1) for label in "abcd")
    expr = 2 * (a * b * c - b * c * a)
    A_ = wc('A', head=Operator)
    B_ = wc('B', head=Operator)
    C_ = wc('C', head=Operator)

    def b_times_c_equal_d(B, C):
        # Rewrite the product b*c to d; refuse every other pair
        if B.label == 'b' and C.label == 'c':
            return d
        raise CannotSimplify

    def register_bc_rule():
        OperatorTimes.add_rule(
            'extra', pattern_head(B_, C_), b_times_c_equal_d)

    with temporary_rules(OperatorTimes):
        register_bc_rule()
        new_expr = expr.rebuild()

    def make_commutator_symbol(A, B):
        label = "Commut%s%s" % (A.label.upper(), B.label.upper())
        return OperatorSymbol(label, hs=A.space)

    # A*B - B*A  ->  a new symbol named after the two factors
    commutator_rule = (
        pattern(
            OperatorPlus,
            pattern(OperatorTimes, A_, B_),
            pattern(ScalarTimesOperator, -1, pattern(OperatorTimes, B_, A_)),
        ),
        make_commutator_symbol,
    )
    assert commutator_rule[0].match(new_expr.term)

    with temporary_rules(OperatorTimes):
        register_bc_rule()
        new_expr = _apply_rules(expr, [commutator_rule])
    assert srepr(new_expr) == (
        "ScalarTimesOperator(ScalarValue(2), OperatorSymbol('CommutAD', "
        "hs=LocalSpace('h1')))")
def testZeroOne(self):
    """Check that the identity and zero super-operators act trivially on an
    operator: identity leaves it unchanged, zero maps it to ZeroOperator"""
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    z = ZeroSuperOperator
    one = IdentitySuperOperator
    assert one * a == a
    assert z * a == ZeroOperator
def test_findall():
    """Check that `findall` collects every match of a pattern"""
    h1 = LocalSpace("h1")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in ("a", "b", "c"))
    # A creation operator whose local identifier is also rendered as 'c'
    h1_custom = LocalSpace("h1", local_identifiers={'Create': 'c'})
    c_local = Create(hs=h1_custom)
    expr = 2 * (a * b * c - b * c * a + a * b)

    found_symbols = pattern(OperatorSymbol).findall(expr)
    assert len(found_symbols) == 8
    assert set(found_symbols) == {a, b, c}

    op = wc(head=Operator)
    products = pattern(OperatorTimes, op, op, op).findall(expr)
    assert products == [a * b * c, b * c * a]

    assert len(pattern(LocalOperator).findall(expr)) == 0
    substituted = expr.substitute({c: c_local})
    assert len(pattern(LocalOperator).findall(substituted)) == 2
def test_custom_repr():
    """Test switching the repr format globally and within a context.

    The global switch via `init_printing` is always reset in a `finally`, so
    a failing assertion cannot leak srepr-style printing into later tests.
    """
    A = OperatorSymbol('A', hs=1)
    # Accept either the unicode or the ascii default rendering
    default_reprs = ['Â⁽¹⁾', 'A^(1)']
    srepr_A = "OperatorSymbol('A', hs=LocalSpace('1'))"

    assert repr(A) in default_reprs
    init_printing(repr_format='srepr', reset=True)
    try:
        assert repr(A) == srepr_A
    finally:
        init_printing(reset=True)
    assert repr(A) in default_reprs

    with configure_printing(repr_format='srepr'):
        assert repr(A) == srepr_A
    assert repr(A) in default_reprs
def test_sympy_setting():
    """Check that printer keyword settings reach the sympy sub-printer"""
    a = symbols('a')
    expr = atan(a) * OperatorSymbol("A", hs=1)
    default_tex = r'\operatorname{atan}{\left(a \right)} \hat{A}^{(1)}'
    full_tex = r'\arctan{\left(a \right)} \hat{A}^{(1)}'
    assert latex(expr) == default_tex
    assert latex(expr, inv_trig_style='full') == full_tex
def test_sympy_tex_cached():
    """Check that a cache entry overrides how a sympy sub-expression is
    rendered in LaTeX"""
    a = symbols('a')
    expr = (a ** 2 / 2) * OperatorSymbol("A", hs=1)
    assert latex(expr) == r'\frac{a^{2}}{2} \hat{A}^{(1)}'
    assert (latex(expr, cache={a: r'\alpha'})
            == r'\frac{\alpha^{2}}{2} \hat{A}^{(1)}')
def test_extra_binary_rules():
    """Check that a temporarily added binary rule is applied on creation and
    is gone again once the context exits"""
    h1 = LocalSpace("h1")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in ("a", "b", "c"))
    A_ = wc('A', head=Operator)
    B_ = wc('B', head=Operator)

    def to_c(A, B):
        return c

    # A*B - B*A  ->  c
    rule = (
        pattern_head(
            pattern(OperatorTimes, A_, B_),
            pattern(ScalarTimesOperator, -1, pattern(OperatorTimes, B_, A_)),
        ),
        to_c,
    )
    with temporary_rules(OperatorPlus):
        OperatorPlus.add_rule('extra', *rule)
        assert ('extra', rule) in OperatorPlus._binary_rules.items()
        assert (2 * (a * b - b * a + IdentityOperator)
                == 2 * (c + IdentityOperator))
    assert rule not in OperatorPlus._binary_rules.values()
def test_substitute_sub_expr(H_JC):
    """Check that non-atomic sub-expressions can be replaced, both through
    the `substitute` method and the free `substitute` function"""
    hil_a, hil_b = LocalSpace('A'), LocalSpace('B')
    omega_a, omega_b, g = symbols('omega_a, omega_b, g')
    a = Destroy(hs=hil_a)
    a_dag = a.dag()
    b = Destroy(hs=hil_b)
    b_dag = b.dag()

    n_op_a = OperatorSymbol('n', hs=hil_a)
    n_op_b = OperatorSymbol('n', hs=hil_b)
    x_op = OperatorSymbol('x', hs=H_JC.space)
    x_plus_hc = x_op + x_op.dag()

    mapping = {
        a_dag * a: n_op_a,
        b_dag * b: n_op_b,
        (a_dag * b + b_dag * a): x_plus_hc,
    }
    expected = omega_a * n_op_a + omega_b * n_op_b + 2 * g * x_plus_hc

    assert H_JC.substitute(mapping) == expected
    assert substitute(H_JC, mapping) == expected
def test_exception_teardown():
    """Check that configure_printing restores the previous settings when its
    body raises"""

    class ConfigurePrintingException(Exception):
        pass

    init_printing(show_hs_label=True, repr_format='ascii')
    try:
        with configure_printing(show_hs_label=False, repr_format='srepr'):
            raise ConfigurePrintingException
    except ConfigurePrintingException:
        # The settings from init_printing above must be active again
        assert repr(OperatorSymbol('A', hs=1)) == 'A^(1)'
    finally:
        # Even if this failed we don't want to make a mess for other tests
        init_printing(reset=True)
def test_no_rules():
    """Check expression creation when rule application is switched off for
    one or more operations"""
    A, B = (OperatorSymbol(label, hs=0) for label in ('A', 'B'))

    def expr():
        return Commutator.create(2 * A, 2 * (3 * B))

    def myrepr(e):
        return srepr(e, cache={A: 'A', B: 'B'})

    fully_simplified = (
        'ScalarTimesOperator(ScalarValue(12), Commutator(A, B))')
    assert myrepr(expr()) == fully_simplified

    with temporary_rules(ScalarTimesOperator, clear=True):
        assert myrepr(expr()) == (
            'ScalarTimesOperator(ScalarValue(4), '
            'ScalarTimesOperator(ScalarValue(3), Commutator(A, B)))')

    with temporary_rules(Commutator, clear=True):
        assert myrepr(expr()) == (
            'Commutator(ScalarTimesOperator(ScalarValue(2), A), '
            'ScalarTimesOperator(ScalarValue(6), B))')

    with temporary_rules(Commutator, ScalarTimesOperator, clear=True):
        assert myrepr(expr()) == (
            'Commutator(ScalarTimesOperator(ScalarValue(2), A), '
            'ScalarTimesOperator(ScalarValue(2), '
            'ScalarTimesOperator(ScalarValue(3), B)))')

    # All rules are active again after the contexts have exited
    assert myrepr(expr()) == fully_simplified
pattern_head('O', FullSpace, a=1, b=2, conditions=[true_cond]), Pattern( args=['O', FullSpace], kwargs={'a': 1, 'b': 2}, conditions=[true_cond], ), ), ] for pat1, pat2 in patterns: print(repr(pat1)) assert pat1 == pat2 # test expressions two_t = 2 * Symbol('t') two_O = 2 * OperatorSymbol('O', hs=FullSpace) proto_two_O = ProtoExpr([2, OperatorSymbol('O', hs=FullSpace)], {}) proto_kwargs = ProtoExpr([1, 2], {'a': '3', 'b': 4}) proto_kw_only = ProtoExpr([], {'a': 1, 'b': 2}) proto_ints2 = ProtoExpr([1, 2], {}) proto_ints3 = ProtoExpr([1, 2, 3], {}) proto_ints4 = ProtoExpr([1, 2, 3, 4], {}) proto_ints5 = ProtoExpr([1, 2, 3, 4, 5], {}) # test patterns and wildcards wc_a_int_2 = wc('a', head=(ScalarValue, int), conditions=[lambda i: i == 2]) wc_a_int_3 = wc('a', head=(ScalarValue, int), conditions=[lambda i: i == 3]) wc_a_int = wc('a', head=int) wc_label_str = wc('label', head=str) wc_hs = wc('space', head=HilbertSpace) pattern_two_O = pattern(