# Example #1
0
def test_Tr():
    """``str`` of the trace of a noncommutative product is ``Tr(<product>)``."""
    noncomm = symbols('A B', commutative=False)
    product = noncomm[0] * noncomm[1]
    assert str(Tr(product)) == 'Tr(A*B)'
# Example #2
0
def test_permute():
    """Cyclic rotation of the factors inside ``Tr`` via ``Tr.permute``.

    Positive shifts move trailing factors to the front, negative shifts move
    leading factors to the back, and shifts wrap modulo the number of factors
    (seven here, so a shift of 8 matches 1 and -8 matches -1).
    """
    # NOTE(review): ``test_permute`` is defined more than once in this source;
    # if the definitions share a module only the last one is collected.
    A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False)
    t = Tr(A * B * C * D * E * F * G)

    def factors(trace, shift):
        # Argument tuple of the product inside the rotated trace.
        return trace.permute(shift).args[0].args

    right_rotations = {
        0: (A, B, C, D, E, F, G),
        2: (F, G, A, B, C, D, E),
        4: (D, E, F, G, A, B, C),
        6: (B, C, D, E, F, G, A),
    }
    for shift, expected in right_rotations.items():
        assert factors(t, shift) == expected
    assert factors(t, 8) == factors(t, 1)

    left_rotations = {
        -1: (B, C, D, E, F, G, A),
        -3: (D, E, F, G, A, B, C),
        -5: (F, G, A, B, C, D, E),
    }
    for shift, expected in left_rotations.items():
        assert factors(t, shift) == expected
    assert factors(t, -8) == factors(t, -1)

    # Compound factors (a sum and a power) rotate as single units.
    t = Tr((A + B) * (B * B) * C * D)
    assert factors(t, 2) == (C, D, (A + B), (B**2))

    # Rotating yields a distinct object that still compares equal.
    t1 = Tr(A * B)
    t2 = t1.permute(1)
    assert t1 is not t2 and t1 == t2
# Example #3
0
def test_permute():
    """``Tr.permute(shift)`` cyclically rotates the factors of the traced product.

    A shift larger than the factor count (or more negative than minus the
    count) wraps around, so on seven factors 8 acts like 1 and -8 like -1.
    """
    # NOTE(review): ``test_permute`` is defined more than once in this source;
    # if the definitions share a module only the last one is collected.
    A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False)
    seven = Tr(A*B*C*D*E*F*G)

    def rotate(shift):
        return seven.permute(shift).args[0].args

    for shift, expected in [(0, (A, B, C, D, E, F, G)),
                            (2, (F, G, A, B, C, D, E)),
                            (4, (D, E, F, G, A, B, C)),
                            (6, (B, C, D, E, F, G, A))]:
        assert rotate(shift) == expected
    assert rotate(8) == rotate(1)

    for shift, expected in [(-1, (B, C, D, E, F, G, A)),
                            (-3, (D, E, F, G, A, B, C)),
                            (-5, (F, G, A, B, C, D, E))]:
        assert rotate(shift) == expected
    assert rotate(-8) == rotate(-1)

    mixed = Tr((A + B)*(B*B)*C*D)
    assert mixed.permute(2).args[0].args == (C, D, (A + B), (B**2))

    original = Tr(A*B)
    rotated = original.permute(1)
    # A new object is returned, yet it compares equal to the source trace.
    assert original is not rotated
    assert original == rotated
# Example #4
0
def test_trace_new():
    """Construction rules of ``Tr``.

    Covers linearity over sums, pulling out commutative factors, evaluation of
    scalars/matrices/objects with a ``trace()`` method, normalisation of the
    optional index argument to a ``Tuple``, and argument-count validation.
    """
    a, b, c, d = symbols('a b c d')
    A, B, C, D = symbols('A B C D', commutative=False)

    # Commutative expressions evaluate to themselves; sums split linearly.
    assert Tr(a + b) == a + b
    assert Tr(A + B) == Tr(A) + Tr(B)

    # check trace args not implicitly permuted
    assert Tr(C * D * A * B).args[0].args == (C, D, A, B)

    # check for mul and adds
    assert Tr((a * b) + (c * d)) == (a * b) + (c * d)
    # Tr(scalar*A) = scalar*Tr(A)
    assert Tr(a * A) == a * Tr(A)
    assert Tr(a * A * B * b) == a * b * Tr(A * B)

    # since A is symbol and not commutative
    assert isinstance(Tr(A), Tr)

    # POW
    assert Tr(pow(a, b)) == a**b
    assert isinstance(Tr(pow(A, a)), Tr)

    # Matrix: the diagonal 1 + 2 is summed.
    M = Matrix([[1, 1], [2, 2]])
    assert Tr(M) == 3

    # Index specs in every accepted form are normalised to a Tuple in args[1].
    # no index
    t = Tr(A)
    assert t.args[1] == Tuple()

    # single index
    t = Tr(A, 0)
    assert t.args[1] == Tuple(0)

    # index in a list
    t = Tr(A, [0])
    assert t.args[1] == Tuple(0)

    t = Tr(A, [0, 1, 2])
    assert t.args[1] == Tuple(0, 1, 2)

    # index is tuple -- the trailing comma matters: ``(0)`` is just the int 0
    # and would only repeat the single-index case above.
    t = Tr(A, (0,))
    assert t.args[1] == Tuple(0)

    t = Tr(A, (1, 2))
    assert t.args[1] == Tuple(1, 2)

    # Indices are carried into every term of a split sum and past scalars.
    t = Tr((A + B), [2])
    assert t.args[0].args[1] == Tuple(2)
    assert t.args[1].args[1] == Tuple(2)

    t = Tr(a * A, [2, 3])
    assert t.args[1].args[1] == Tuple(2, 3)

    # class with trace method defined
    # to simulate numpy objects
    class Foo:
        def trace(self):
            return 1

    assert Tr(Foo()) == 1

    # argument test
    # no arguments, or more than (expr, indices), raises ValueError
    with pytest.raises(ValueError):
        Tr()
    with pytest.raises(ValueError):
        Tr(A, 1, 2)

    # non-Expr objects stay as an unevaluated Tr
    assert isinstance(Tr(None), Tr)