def test_Tr(): A, B = symbols('A B', commutative=False) t = Tr(A * B) assert str(t) == 'Tr(A*B)'
def test_permute(): A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False) t = Tr(A * B * C * D * E * F * G) assert t.permute(0).args[0].args == (A, B, C, D, E, F, G) assert t.permute(2).args[0].args == (F, G, A, B, C, D, E) assert t.permute(4).args[0].args == (D, E, F, G, A, B, C) assert t.permute(6).args[0].args == (B, C, D, E, F, G, A) assert t.permute(8).args[0].args == t.permute(1).args[0].args assert t.permute(-1).args[0].args == (B, C, D, E, F, G, A) assert t.permute(-3).args[0].args == (D, E, F, G, A, B, C) assert t.permute(-5).args[0].args == (F, G, A, B, C, D, E) assert t.permute(-8).args[0].args == t.permute(-1).args[0].args t = Tr((A + B) * (B * B) * C * D) assert t.permute(2).args[0].args == (C, D, (A + B), (B**2)) t1 = Tr(A * B) t2 = t1.permute(1) assert id(t1) != id(t2) and t1 == t2
def test_permute(): A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False) t = Tr(A*B*C*D*E*F*G) assert t.permute(0).args[0].args == (A, B, C, D, E, F, G) assert t.permute(2).args[0].args == (F, G, A, B, C, D, E) assert t.permute(4).args[0].args == (D, E, F, G, A, B, C) assert t.permute(6).args[0].args == (B, C, D, E, F, G, A) assert t.permute(8).args[0].args == t.permute(1).args[0].args assert t.permute(-1).args[0].args == (B, C, D, E, F, G, A) assert t.permute(-3).args[0].args == (D, E, F, G, A, B, C) assert t.permute(-5).args[0].args == (F, G, A, B, C, D, E) assert t.permute(-8).args[0].args == t.permute(-1).args[0].args t = Tr((A + B)*(B*B)*C*D) assert t.permute(2).args[0].args == (C, D, (A + B), (B**2)) t1 = Tr(A*B) t2 = t1.permute(1) assert id(t1) != id(t2) and t1 == t2
def test_trace_new(): a, b, c, d, Y = symbols('a b c d Y') A, B, C, D = symbols('A B C D', commutative=False) assert Tr(a + b) == a + b assert Tr(A + B) == Tr(A) + Tr(B) # check trace args not implicitly permuted assert Tr(C * D * A * B).args[0].args == (C, D, A, B) # check for mul and adds assert Tr((a * b) + (c * d)) == (a * b) + (c * d) # Tr(scalar*A) = scalar*Tr(A) assert Tr(a * A) == a * Tr(A) assert Tr(a * A * B * b) == a * b * Tr(A * B) # since A is symbol and not commutative assert isinstance(Tr(A), Tr) # POW assert Tr(pow(a, b)) == a**b assert isinstance(Tr(pow(A, a)), Tr) # Matrix M = Matrix([[1, 1], [2, 2]]) assert Tr(M) == 3 # test indices in different forms # no index t = Tr(A) assert t.args[1] == Tuple() # single index t = Tr(A, 0) assert t.args[1] == Tuple(0) # index in a list t = Tr(A, [0]) assert t.args[1] == Tuple(0) t = Tr(A, [0, 1, 2]) assert t.args[1] == Tuple(0, 1, 2) # index is tuple t = Tr(A, (0)) assert t.args[1] == Tuple(0) t = Tr(A, (1, 2)) assert t.args[1] == Tuple(1, 2) # trace indices test t = Tr((A + B), [2]) assert t.args[0].args[1] == Tuple(2) and t.args[1].args[1] == Tuple(2) t = Tr(a * A, [2, 3]) assert t.args[1].args[1] == Tuple(2, 3) # class with trace method defined # to simulate numpy objects class Foo: def trace(self): return 1 assert Tr(Foo()) == 1 # argument test # check for value error, when either/both arguments are not provided pytest.raises(ValueError, lambda: Tr()) pytest.raises(ValueError, lambda: Tr(A, 1, 2)) # non-Expr objects assert isinstance(Tr(None), Tr)