Esempio n. 1
0
def test_IndexedBase_shape():
    """Check shape storage, shape-sensitive equality and index-count checks."""
    i, j, m, n = symbols('i j m n', integer=True)
    square = IndexedBase('a', shape=(m, m))
    rect = IndexedBase('a', shape=(m, n))
    rect_elem = rect[i, j]

    # The shape given at construction is exposed as a Tuple.
    assert rect.shape == Tuple(m, n)

    # Same label, different shape: elements differ until n is replaced by m.
    assert square[i, j] != rect_elem
    assert square[i, j] == rect_elem.subs(n, m)

    # Both the base and an element can be rebuilt from func/args.
    assert rect.func(*rect.args) == rect
    assert rect_elem.func(*rect_elem.args) == rect_elem

    # A rank-2 shape rejects one index as well as three indices.
    for wrong_indices in (i, (i, i, j)):
        raises(IndexException, lambda: rect[wrong_indices])
Esempio n. 2
0
def cse(expr):
    """Simplify a complicated sympy expression into a list of statements.

    The ``Sum`` atoms of ``expr`` are passed, together with ``expr``
    itself, to ``sympy_cse``.  Every extracted sub-expression becomes an
    ``Assign`` statement; when its variable has to be indexed, the
    statement is wrapped in a ``For`` loop and an ``empty`` allocation is
    emitted for it.

    :param expr: sympy expression to decompose.
    :return: ``[expr]`` when ``expr`` contains no ``Sum``; otherwise the
        list ``imports + allocations + statements``, whose last statement
        assigns the rewritten expression to a newly created variable.
    :raises ValueError: if no summation range is found for one of the
        indices of an intermediate variable.
    """
    ls = list(expr.atoms(Sum))
    if not ls:
        return [expr]
    ls += [expr]
    (ls, _) = sympy_cse(ls)

    (vars_old, stmts) = map(list, zip(*ls))
    vars_new = []
    # Copy before extending so the set handed out by ``expr`` is never mutated.
    free_gl = set(expr.free_symbols)
    free_gl.update(expr.atoms(IndexedBase))
    free_gl.update(vars_old)
    stmts.append(expr)

    # Create one new variable per extracted sub-expression; it is indexed by
    # the symbols of the sub-expression that are not in ``free_gl``.
    for i in range(len(stmts) - 1):
        free = stmts[i].free_symbols
        free = free.difference(free_gl)
        free = list(free)
        var = create_variable(stmts[i])
        if free:
            var = IndexedBase(var)[free]
        vars_new.append(var)

    # Substitute the new variables into the following statement and into the
    # final expression.
    for i in range(len(stmts) - 1):
        stmts[i + 1] = stmts[i + 1].replace(vars_old[i], vars_new[i])
        stmts[-1] = stmts[-1].replace(stmts[i], vars_new[i])

    allocate = []
    for i in range(len(stmts) - 1):
        stmts[i] = Assign(vars_new[i], stmts[i])
        stmts[i] = pyccel_sum(stmts[i])
        if isinstance(vars_new[i], Indexed):
            ind = vars_new[i].indices
            tp = list(stmts[i + 1].atoms(Tuple))
            size = [None] * len(ind)
            # Look for a Tuple of the form (index, lower, upper) in the next
            # statement to get the extent of each index.
            for (j, k) in enumerate(ind):
                for t in tp:
                    if k == t[0]:
                        size[j] = t[2] - t[1] + 1
                        break
            # Compare against None explicitly: an extent that evaluates to
            # zero is falsy but is still a range that was found.
            if any(extent is None for extent in size):
                raise ValueError('Unable to find range of index')
            name = str(vars_new[i].base)
            var = Symbol(name)
            # NOTE(review): only ind[0]/size[0] are used below, so
            # intermediates with several indices are presumably not handled
            # yet -- confirm.
            stmt = Assign(var, Function('empty')(size[0]))
            allocate.append(stmt)
            stmts[i] = For(ind[0],
                           Function('range')(size[0]), [stmts[i]],
                           strict=False)
    lhs = create_variable(expr)
    stmts[-1] = Assign(lhs, stmts[-1])
    imports = [Import('empty', 'numpy')]
    return imports + allocate + stmts
def test_indexed_is_constant():
    """is_constant on an Indexed follows the symbols used as its indices."""
    A = IndexedBase("A")
    i, j, k = symbols("i,j,k")

    # One index.
    one_index = A[i]
    assert not one_index.is_constant()
    assert one_index.is_constant(j)

    # Two indices, the first one an expression in i.
    two_indices = A[1+2*i, k]
    assert not two_indices.is_constant()
    assert not two_indices.is_constant(i)
    assert two_indices.is_constant(j)
    assert not two_indices.is_constant(k)
Esempio n. 4
0
def IndexedBases(s):
    """
    Declare several IndexedBase objects at once.

    :param s: string of names separated by white space
    returns the IndexedBase objects as a tuple, in the order of the names
    """
    return tuple(IndexedBase(name) for name in s.split())
Esempio n. 5
0
def test_Indexed():
    """Lambdify a double Sum over an Indexed object and evaluate it on an array."""
    # Issue #10934
    if not numpy:
        skip("numpy not installed")

    a = IndexedBase('a')
    i, j = symbols('i j')
    b = numpy.array([[1, 2], [3, 4]])
    # Sum over the symbols declared above, so the test does not rely on
    # summation variables defined outside this function.
    assert lambdify(a, Sum(a[i, j], (i, 0, 1), (j, 0, 1)))(b) == 10
Esempio n. 6
0
def test_IndexedBase_sugar():
    """Indexing an IndexedBase gives the same object as building Indexed directly."""
    i, j = symbols('i j', integer=True)
    a = symbols('a')
    explicit = Indexed(a, i, j)
    base = IndexedBase(a)

    # Plain subscript syntax.
    assert explicit == base[i, j]
    # Indices packed in a tuple, a list or a sympy Tuple.
    for packed in ((i, j), [i, j], Tuple(i, j)):
        assert explicit == base[packed]
Esempio n. 7
0
 def initFromString(self, s):
     """Initialise this object as one of the predefined tensors.

     ``s`` is matched case-insensitively against the known names ``'kd'``
     and ``'eps'``.  On a match, ``symbol`` is set to the corresponding
     ``IndexedBase`` and ``dic``, ``range`` and ``dim`` are reset to
     ``None`` and ``sym`` to ``False``.  For any other name a critical
     message is logged and the object is left untouched.

     :param s: name of the predefined tensor.
     """
     predefined = {
         'kd': 'KD',    # Kronecker delta
         'eps': 'Eps',  # presumably the Levi-Civita tensor -- TODO confirm
     }
     label = predefined.get(s.lower())
     if label is None:
         # The placeholder must actually be filled in, otherwise the
         # message shows a literal "{s}" instead of the offending name.
         loggingCritical("Error : Unkown tensor object '{s}'.".format(s=s))
         return

     self.symbol = IndexedBase(label)
     self.dic = None
     self.range = None
     self.dim = None
     self.sym = False
Esempio n. 8
0
def test_CircularOrthogonalEnsemble():
    """Compare joint_eigen_distribution of COE('U', 3) with the expected Lambda."""
    ensemble = COE('U', 3)
    j = Dummy('j', integer=True, positive=True)
    k = Dummy('k', integer=True, positive=True)
    t = IndexedBase('t')

    actual = joint_eigen_distribution(ensemble)

    # Expected density: product of |exp(I t_j) - exp(I t_k)| over j > k.
    gap = Abs(exp(I * t[j]) - exp(I * t[k]))
    density = Product(gap, (j, k + 1, 3), (k, 1, 2)) / (48 * pi**2)
    expected = Lambda((t[1], t[2], t[3]), density)

    assert actual.dummy_eq(expected)
Esempio n. 9
0
def test_tensor_math_ops(free_alg):
    """Exercise the arithmetic operations on tensors.

    Negation, addition, multiplication (also with scalars), the commutator,
    subtraction and division by a symbol are checked.
    """

    dr = free_alg
    names = dr.names
    rng = names.R
    v = names.v
    w = Vec('w')
    x = IndexedBase('x')
    i, j, k = names.R_dumms[:3]
    a = sympify('a')

    # Two single-term tensors to work with.
    sum_v = dr.sum((i, rng), x[i] * v[i])
    sum_w = dr.sum((i, rng), x[i] * w[i])
    assert sum_v.n_terms == 1
    assert sum_w.n_terms == 1

    # Negation.
    negated = -sum_v
    assert negated == dr.sum((i, rng), -x[i] * v[i])

    # Addition of a number, from either side.
    v_plus_two = sum_v + 2
    assert v_plus_two.n_terms == 2
    assert v_plus_two == 2 + sum_v

    # Addition of a symbol, from either side.
    w_plus_a = sum_w + a
    assert w_plus_a.n_terms == 2
    assert w_plus_a == a + sum_w

    # Product of the two sums; scalar multiplication is covered as well.
    product = v_plus_two * w_plus_a
    expected = (2 * a + a * sum_v + 2 * sum_w + dr.sum(
        (i, rng), (j, rng), x[i] * x[j] * v[i] * w[j]))
    assert product.simplify() == expected.simplify()

    # Commutator of a tensor with itself vanishes.
    self_comm = sum_v | sum_v
    assert self_comm.simplify() == 0

    # Commutator of two different tensors; this also tests subtraction.
    cross_comm = sum_v | sum_w
    forward = dr.sum((i, rng), (j, rng), x[i] * x[j] * v[i] * w[j])
    backward = dr.sum((i, rng), (j, rng), x[j] * x[i] * w[i] * v[j])
    expected = forward - backward
    assert cross_comm.simplify() == expected.simplify()

    # Division by a symbol that was not present before.
    alpha = symbols('alpha')
    assert alpha not in sum_v.free_vars
    divided = sum_v / alpha
    assert divided.n_terms == 1
    local_terms = divided.local_terms
    assert len(local_terms) == 1
    only_term = local_terms[0]
    assert only_term.sums == ((i, rng), )
    assert only_term.amp == x[i] / alpha
    assert only_term.vecs == (v[i], )
    assert alpha in divided.free_vars
Esempio n. 10
0
def test_eval_compression(three_ranges):
    """Test compression of optimized evaluations.

    The two targets are

    .. math::

        U X V + U Y V

    and

    .. math::

        U X W + U Y W

    constructed such that the multiplication with U should be carried out
    first.  After U has been factored out, the intermediate U (X + Y) is a
    sum of a single product intermediate.  The test passes when only two
    intermediates are produced, without the unnecessary addition of a
    single product.

    """

    # Context.
    dr = three_ranges
    names = dr.names

    a = names.a  # Small range
    i, j, k = names.i, names.j, names.k  # Big range

    # Indexed bases of the inputs ...
    u, v, w, x, y = (IndexedBase(label) for label in ('U', 'V', 'W', 'X', 'Y'))
    # ... and of the intermediate and the two results.
    s, t1, t2 = (IndexedBase(label) for label in ('S', 'T1', 'T2'))

    # Targets: the shared intermediate multiplied by V and by W.
    s_def = dr.define_einst(s[i, j], u[i, k] * x[k, j] + u[i, k] * y[k, j])
    targets = [
        dr.define_einst(result[i, j], s_def[i, a] * factor[a, j])
        for result, factor in ((t1, v), (t2, w))
    ]

    # Optimization.
    evals = optimize(targets, substs=dr.substs)
    assert len(evals) == 4

    # Correctness.
    assert verify_eval_seq(evals, targets, simplify=False)
Esempio n. 11
0
def test_2d_block_1():
    """Assemble the same block form with the python and fortran kernels and compare."""
    print('============== test_2d_block_1 ================')

    # Weak formulation.
    x, y = symbols('x y')
    u = IndexedBase('u')
    v = IndexedBase('v')
    weak_form = Rot(u) * Rot(v) + Div(u) * Div(v) + 0.2 * Dot(u, v)
    a = Lambda((x, y, v, u), weak_form)

    # Discretisation parameters: degrees and number of elements per direction.
    p1, p2 = 2, 2
    ne1, ne2 = 8, 8

    print('> Grid   :: [{ne1},{ne2}]'.format(ne1=ne1, ne2=ne2))
    print('> Degree :: [{p1},{p2}]'.format(p1=p1, p2=p2))

    # One spline space per direction, combined into a tensor space.
    grids = [linspace(0., 1., n_elements + 1) for n_elements in (ne1, ne2)]
    V1 = SplineSpace(p1, grid=grids[0])
    V2 = SplineSpace(p2, grid=grids[1])
    W = TensorFemSpace(V1, V2)

    # Vector space built from two copies of the tensor space.
    V = VectorFemSpace(W, W)

    # Compile both kernels first, then assemble with each of them.
    kernels = [
        compile_kernel('kernel_block_1', a, V, backend=backend)
        for backend in ('python', 'fortran')
    ]
    M_py, M_f90 = [assemble_matrix(V, kernel) for kernel in kernels]

    assert_identical_coo(M_py, M_f90)
Esempio n. 12
0
def test_CircularSymplecticEnsemble():
    """Compare joint_eigen_distribution of CSE('U', 3) with the expected Lambda."""
    ensemble = CSE('U', 3)
    j = Dummy('j', integer=True, positive=True)
    k = Dummy('k', integer=True, positive=True)
    t = IndexedBase('t')

    actual = joint_eigen_distribution(ensemble)

    # Expected density: product of |exp(I t_j) - exp(I t_k)|**4 over j > k.
    gap = Abs(exp(I * t[j]) - exp(I * t[k]))
    density = Product(gap**4, (j, k + 1, 3), (k, 1, 2)) / (720 * pi**3)
    expected = Lambda((t[1], t[2], t[3]), density)

    assert actual.dummy_eq(expected)
Esempio n. 13
0
def main():
    """Build Eq(Utt, c**2 * Uxx), solve it for U[x, y, z, t + 1] and print the code."""
    dx, dt, x, y, z, t, c = symbols('dx dt x y z t c')
    U = IndexedBase('U')
    n = 2
    coords = (x, y, z, t)

    # Third entry of each Deriv result, along coordinate 0 and coordinate 3.
    Uxx = Deriv(U, list(coords), 0, dx, n)[2]
    Utt = Deriv(U, list(coords), 3, dt, n)[2]
    equation = Eq(Utt, (c**2) * Uxx)

    # Left- and right-hand side of the printed assignment.
    unknown_code = print_myccode(U[x, y, z, t + 1])
    update_code = print_myccode(solve(equation, U[x, y, z, t + 1])[0])
    print(unknown_code + "=" + update_code)
Esempio n. 14
0
def test_indexed_by_grid(grid):
    """ Ensure that an Indexed object gets correctly indexed by the Grid indices. """

    base = IndexedBase("test")
    single_index = Idx(Symbol("i", integer=True))
    indexed = base[single_index]

    expected = base[grid.indices]
    assert grid.indexed_by_grid(indexed) == expected
Esempio n. 15
0
def test_IndexedBase_sugar():
    """Index-passing variants agree, and integer indices are stored as Integers."""
    i, j = symbols('i j', integer=True)
    a = symbols('a')
    explicit = Indexed(a, i, j)
    base = IndexedBase(a)

    # Plain subscript syntax.
    assert explicit == base[i, j]
    # Indices packed in a tuple, a list or a sympy Tuple.
    for packed in ((i, j), [i, j], Tuple(i, j)):
        assert explicit == base[packed]

    # Every argument after the first one of base[1, 0] has is_Integer set.
    stored_indices = base[1, 0].args[1:]
    assert all(index.is_Integer for index in stored_indices)
Esempio n. 16
0
def test_Indexed_shape_precedence():
    """The shape of the base wins over Idx ranges; ranges still come from Idx."""
    i, j = symbols("i j", integer=True)
    o, p = symbols("o p", integer=True)
    n, m = symbols("n m", integer=True)
    a = IndexedBase("a", shape=(o, p))
    assert a.shape == Tuple(o, p)

    # Both indices carry a range.
    both_ranged = Indexed(a, Idx(i, m), Idx(j, n))
    assert both_ranged.ranges == [Tuple(0, m - 1), Tuple(0, n - 1)]
    assert both_ranged.shape == Tuple(o, p)

    # The second index has no range.
    one_ranged = Indexed(a, Idx(i, m), Idx(j))
    assert one_ranged.ranges == [Tuple(0, m - 1), Tuple(None, None)]
    assert one_ranged.shape == Tuple(o, p)
Esempio n. 17
0
    def __getitem__(self, indices):
        """Index the given symbol.

        In drudge scripts, all symbols are by themselves indexed bases.  A
        sequence of indices is unpacked into separate indices of the
        result; any other object is used as a single index.

        :param indices: a single index or a sequence of indices.
        :return: a ``DrsIndexed`` built from this object's drudge, an
            ``IndexedBase`` of the original symbol, and the indices.
        """
        # ``collections.Sequence`` was removed in Python 3.10; the abstract
        # base classes are only available from ``collections.abc``.
        from collections.abc import Sequence

        base = IndexedBase(self._orig)
        if isinstance(indices, Sequence):
            return DrsIndexed(self._drudge, base, *indices)
        else:
            return DrsIndexed(self._drudge, base, indices)
Esempio n. 18
0
def test_tensor_can_be_simplified_amp(free_alg):
    """Test the amplitude simplification for tensors.

    This goes beyond trivial amplitude simplification: the focus is on the
    dispatching to SymPy and on the simplification of deltas.  The master
    simplification is checked at the end.
    """

    dr = free_alg
    names = dr.names
    rng_r = names.R
    rng_s = names.S
    v = names.v
    i, j = names.R_dumms[:2]
    alpha = names.alpha

    x = IndexedBase('x')
    y = IndexedBase('y')
    theta = sympify('theta')

    # Three terms: a sin**2 part, a cos**2 part with a delta, and a term with
    # a delta across the two ranges.
    sin_term = dr.sum(
        (i, rng_r),
        sin(theta)**2 * x[i] * v[i])
    cos_term = dr.sum(
        (i, rng_r), (j, rng_r),
        cos(theta)**2 * x[j] * KroneckerDelta(i, j) * v[i])
    cross_term = dr.sum(
        (i, rng_r), (alpha, rng_s),
        KroneckerDelta(i, alpha) * y[i] * v[i])
    tensor = sin_term + cos_term + cross_term
    assert tensor.n_terms == 3

    # Delta and amplitude simplification removes one term.
    after_deltas = tensor.simplify_deltas().simplify_amps()
    assert after_deltas.n_terms == 2

    # Merging afterwards leaves a single term.
    merged = after_deltas.reset_dumms().merge().simplify_amps()
    assert merged.n_terms == 1
    expected = dr.sum((i, rng_r), x[i] * v[i])
    assert merged == expected

    # The master simplification gets there in one call.
    assert tensor.simplify() == expected
Esempio n. 19
0
def test_numbers_can_substitute_vectors(free_alg, full_balance):
    """Vectors in a tensor can be replaced by plain numbers."""

    dr = free_alg
    names = dr.names

    x = IndexedBase('x')
    y = IndexedBase('y')
    rng = names.R
    i, j, k, l = symbols('i j k l')
    v = names.v
    w = Vec('w')

    orig = dr.sum((i, rng), (j, rng),
                  x[i, j] * v[i] * w[j] + y[i, j] * v[i] * v[j])

    # Replacing v by zero kills every term.
    with_zero = orig.subst(v[k], 0, full_balance=full_balance).simplify()
    assert with_zero == 0

    # Replacing v by one leaves w in the first term and a scalar second term.
    with_one = orig.subst(v[i], 1, full_balance=full_balance).simplify()
    expected = dr.sum((i, rng), (j, rng), x[j, i] * w[i] + y[i, j])
    assert with_one == expected
Esempio n. 20
0
def test_amp_sums_can_be_simplified(free_alg):
    """Summations that only involve the amplitude can be carried out."""
    dr = free_alg
    v = dr.names.v
    n, i, j = symbols('n i j')
    x = IndexedBase('x')
    rng = Range('D', 0, n)

    # The summation over i only touches the factor i**2.
    tensor = dr.sum((i, rng), (j, rng), i**2 * x[j] * v[j])
    simplified = tensor.simplify_sums()

    closed_form = n**3 / 3 - n**2 / 2 + n / 6
    assert simplified == dr.sum((j, rng), closed_form * x[j] * v[j])
Esempio n. 21
0
def test_special_substitution_of_identity(free_alg):
    """Substituting a definition of the integer one, which stands for identity.
    """

    dr = free_alg
    names = dr.names

    x = IndexedBase('x')
    t = IndexedBase('y')
    a = IndexedBase('a')
    i, j = names.i, names.j
    v = names.v
    w = Vec('w')

    # The a[i] term carries no vector; the definition of 1 gives it one.
    orig = dr.sum((i, names.R), x[i] * v[i] + a[i])
    ident_def = dr.define(1, dr.einst(t[i] * w[i]))

    substituted = orig.subst_all([ident_def])

    untouched_part = dr.einst(x[i] * v[i])
    replaced_part = dr.sum((i, names.R), (j, names.R), a[i] * t[j] * w[j])
    assert dr.simplify(substituted - untouched_part - replaced_part) == 0
Esempio n. 22
0
def test_einstein_sum_for_both_particles_and_holes(parthole):
    """Einstein convention sums each dummy over both ranges."""
    dr = parthole
    names = dr.names
    x = IndexedBase('x')

    summand = x[names.p, names.q] * names.c_[names.p, names.q]
    via_einst = dr.einst(summand).simplify()
    assert via_einst.n_terms == 4

    # The same tensor written with explicit summations over both ranges.
    both_ranges = (dr.part_range, dr.hole_range)
    explicit = dr.sum((names.p, both_ranges), (names.q, both_ranges), summand)
    assert via_einst == explicit.simplify()
Esempio n. 23
0
def test_disconnected_outer_product_factorization(spark_ctx):
    """Optimize an expression that contains a disconnected outer product.
    """

    # Context: one range R from 0 to n with seven dummies.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    rng = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d, e = dumms[:5]
    dr.set_dumms(rng, dumms)
    dr.add_resolver_for_dumms()

    # Indexed bases of the inputs and of the result.
    u, x, y, z = (IndexedBase(label) for label in ('U', 'X', 'Y', 'Z'))
    t = IndexedBase('T')

    # Target: U[a, b] times a factor that shares no index with it.
    with_x = u[a, b] * z[c, e] * x[e, c]
    with_y = u[a, b] * z[c, e] * y[e, c]
    targets = [dr.define_einst(t[a, b], with_x + with_y)]

    # Optimization.
    evals = optimize(targets)
    assert len(evals) == 3

    # Correctness.
    assert verify_eval_seq(evals, targets, simplify=False)

    # Cost, in total and for the leading terms only.
    total_cost = get_flop_cost(evals)
    leading_cost = get_flop_cost(evals, leading=True)
    assert total_cost == 4 * n ** 2
    assert leading_cost == 4 * n ** 2
Esempio n. 24
0
def test_ccsd_doubles(parthole_drudge):
    """Test discovery of effective T in CCSD doubles equation.

    The purpose is similar to that of the CCSD energy test.  Here the
    external indices add complexity, which necessitates the ``ALL``
    strategy for optimization.
    """

    dr = parthole_drudge
    names = dr.names

    a, b, c, d = names.V_dumms[:4]
    i, j = names.O_dumms[:2]
    u = dr.two_body
    t = IndexedBase('t')
    dr.set_dbbar_base(t, 2)

    residual = dr.define_einst(
        IndexedBase('r')[a, b, i, j],
        t[c, d, i, j] * u[a, b, c, d] + u[a, b, c, d] * t[c, i] * t[d, j]
    )
    targets = [residual]

    def optimized_cost(primary):
        """Optimize with the given primary strategy, check it, return its cost."""
        eval_seq = optimize(
            targets, substs={names.nv: names.no * 10},
            strategy=primary | Strategy.SUM | Strategy.COMMON
        )
        assert verify_eval_seq(eval_seq, targets)
        assert len(eval_seq) == 2
        return get_flop_cost(eval_seq)

    all_cost = optimized_cost(Strategy.ALL)
    best_cost = optimized_cost(Strategy.BEST)

    cost_gap = best_cost - all_cost
    assert cost_gap.xreplace({names.no: 1, names.nv: 10}) > 0
Esempio n. 25
0
def test_matrix_chain(three_ranges):
    """Test a basic matrix chain multiplication problem.

    A very simple chain of three matrices is used to test the factorization
    facilities.  The matrices :math:`x`, :math:`y`, and :math:`z` have
    shapes :math:`m\\times n`, :math:`n \\times l`, and :math:`l \\times m`
    respectively.  In the factorization, we are going to set
    :math:`n = 2 m` and :math:`l = 3 m`.

    """

    dr = three_ranges
    names = dr.names
    m, n, l = names.m, names.n, names.l

    # Indexed bases of the three matrices, with their shapes.
    x, y, z = (
        IndexedBase(label, shape=shape)
        for label, shape in (('x', (m, n)), ('y', (n, l)), ('z', (l, m)))
    )

    result_base = IndexedBase('t')
    chain = x[names.a, names.i] * y[names.i, names.p] * z[names.p, names.b]
    target = dr.define_einst(result_base[names.a, names.b], chain)

    # Factorization.
    targets = [target]
    stats = {}
    eval_seq = optimize(targets, substs=dr.substs, stats=stats)
    assert stats['Number of nodes'] < 2**3
    assert len(eval_seq) == 2

    # Correctness.
    assert verify_eval_seq(eval_seq, targets)

    # Cost, in total and for the leading terms only.
    expected_cost = 2 * l * m * n + 2 * m**2 * n
    total_cost = get_flop_cost(eval_seq)
    leading_cost = get_flop_cost(eval_seq, leading=True)
    assert total_cost == expected_cost
    assert leading_cost == expected_cost
Esempio n. 26
0
def colourful_tensor(simple_drudge):
    """Build a tensor definition with enough variety for large code coverage.
    """

    dr = simple_drudge
    names = dr.names

    # Register the indexed bases and the scalar symbols with the drudge.
    x = IndexedBase('x')
    u = IndexedBase('u')
    v = IndexedBase('v')
    dr.set_name(x, u, v)

    r, s = symbols('r s')
    dr.set_name(r, s)

    a, b, c = names.R_dumms[:3]

    # A term with a rational prefactor minus a summed product.
    prefactor = (2 * r) / (3 * s)
    summed = dr.sum((c, names.R), u[a, c] * v[c, b] * c**2 / 2)
    return dr.define(x[a, b], (prefactor * u[b, a] - summed))
Esempio n. 27
0
def __substituteGamma(expr,
                      *args,
                      gamma=IndexedBase(r'\gamma', integer=True, shape=1)):
    """
    Substitute gamma[i] by args[i] in expression.

    :param Expr expr: expression
    :param args: entries of the gamma vector
    :param Expr gamma: optional symbol to use for gamma
    :return: expr with gamma[i] substituted by args[i]
    """
    # The default label is a raw string: ``\g`` is not a valid escape
    # sequence, and the raw form gives the same backslash-gamma text without
    # the invalid-escape warning.
    return expr.subs({gamma[i]: value for i, value in enumerate(args)})
Esempio n. 28
0
def test_numbers_can_substitute_scalars(free_alg, full_balance):
    """Indexed scalars in a tensor can be replaced by plain numbers."""

    dr = free_alg
    names = dr.names

    x = IndexedBase('x')
    y = IndexedBase('y')
    rng = Range('D', 0, 2)
    i, j, k, l = symbols('i j k l')
    dr.set_dumms(rng, [i, j, k, l])
    v = names.v

    orig = dr.sum((i, rng), x[i]**2 * x[j] * y[k] * v[l])

    def substitute_x(index, number):
        """Replace x[index] by number in orig and simplify the result."""
        replaced = orig.subst(x[index], number, full_balance=full_balance)
        return replaced.simplify()

    assert substitute_x(i, 0) == 0
    assert substitute_x(j, 1) == dr.sum(2 * y[k] * v[l])
    assert substitute_x(k, 2) == dr.sum(16 * y[k] * v[l])
Esempio n. 29
0
def test_einstein_convention(free_alg):
    """Test Einstein summation convention utility.

    The more complex aspects of the Einstein convention facility are
    covered here, in particular external indices and the creation of
    definitions.
    """

    dr = free_alg
    names = dr.names

    o = IndexedBase('o')
    v = IndexedBase('v')
    w = IndexedBase('w')
    i, j, k = names.R_dumms[:3]

    raw_amp_1 = o[i, k] * v[k, j]
    raw_amp_2 = o[i, k] * w[k, j]
    raw_amp = raw_amp_1 + raw_amp_2
    amps_by_position = (raw_amp_1, raw_amp_2)

    # The raw amplitude and a tensor built from it must behave the same.
    for inp in [raw_amp, dr.sum(raw_amp)]:
        tensor, exts = dr.einst(inp, auto_exts=True)
        assert exts == {i, j}
        for position, term in enumerate(tensor.local_terms):
            # Only k is summed; i and j stay external.
            assert len(term.sums) == 1
            assert term.sums[0] == (k, names.R)
            if position < len(amps_by_position):
                assert term.amp == amps_by_position[position]
            assert len(term.vecs) == 0

    # Automatic formation of a definition from a base name.
    tensor_def = dr.define_einst('r', raw_amp, auto_exts=True)
    assert len(tensor_def.exts) == 2
    assert tensor_def.exts[0] == (i, names.R)
    assert tensor_def.exts[1] == (j, names.R)
    assert tensor_def.base == IndexedBase('r')
    assert tensor_def.rhs == dr.einst(raw_amp)
Esempio n. 30
0
def test_genmb_simplify_simple_expressions(genmb, par_level, full_simplify,
                                           simple_merge):
    """Simplify a product of two one-body operators under the given settings."""

    dr = genmb  # type: GenMBDrudge

    c_ = dr.op[AN]
    c_dag = dr.op[CR]
    rng = dr.names.L
    a, b, c, d = dr.names.L_dumms[:4]

    t = IndexedBase('t')
    u = IndexedBase('u')

    inp = dr.sum((a, rng), (b, rng), (c, rng), (d, rng),
                 t[a, b] * u[c, d] * c_dag[a] * c_[b] * c_dag[c] * c_[d])

    # Apply the requested settings, checking that each one reads back.
    requested = (
        ('wick_parallel', par_level),
        ('full_simplify', full_simplify),
        ('simple_merge', simple_merge),
    )
    for setting, value in requested:
        setattr(dr, setting, value)
        assert getattr(dr, setting) == value

    res = inp.simplify()

    # Put the settings back before looking at the result.
    dr.wick_parallel = 0
    assert dr.wick_parallel == 0
    dr.full_simplify = True
    assert dr.full_simplify
    dr.simple_merge = False
    assert not dr.simple_merge

    assert res.n_terms == 2

    two_body = t[a, c] * u[b, d] * c_dag[a] * c_dag[b] * c_[d] * c_[c]
    one_body = t[a, c] * u[c, b] * c_dag[a] * c_[b]
    expected = dr.einst(two_body + one_body).simplify()

    assert res == expected
Esempio n. 31
0
def test_issue_17652():
    """Regression test issue #17652.

    The label of an IndexedBase must keep the exact Symbol subclass it was
    given instead of being upcast.
    """
    class SubClass(Symbol):
        pass

    label = SubClass('X')
    assert type(label) == SubClass

    base = IndexedBase(label)
    # Neither the original object nor the stored label changed type.
    assert type(label) == SubClass
    assert type(base.label) == SubClass
Esempio n. 32
0
 def __new__(typ, name, **kwargs):
     """Create the instance via IndexedBase.__new__ from ``name`` only.

     Extra keyword arguments are accepted but not passed on.
     """
     return IndexedBase.__new__(typ, name)