# Example #1
0
def test_basic_handling_range_with_variable_bounds(spark_ctx):
    """Test the treatment of ranges with variable bounds.

    The example loosely resembles angular-momentum handling in quantum
    physics: the range of the ``m`` dummy has an upper bound given by the
    ``j`` dummy.  Two basic operations are checked:

    1. Dummy resetting must also rename the dummy occurring inside the bounds
       of a dependent range (``m[0, j2]`` must become ``m[0, j1]``).
    2. Mapping a scalar function with ``skip_ranges=False`` must also reach
       the symbolic bounds of the ranges (``j_max`` is replaced by 10).
    """

    dr = Drudge(spark_ctx)

    j1, j2 = symbols('j1 j2')
    m1, m2 = symbols('m1 m2')
    j_max = symbols('j_max')
    j = Range('j', 0, j_max)
    m = Range('m')
    dr.set_dumms(j, [j1, j2])
    dr.set_dumms(m, [m1, m2])

    v = Vec('v')
    x = IndexedBase('x')
    # The m summation is bounded by the j dummy j2, which resetting renames.
    tensor = dr.sum((j2, j), (m2, m[0, j2]), x[j2, m2] * v[j2, m2])

    # Resetting should rename j2 -> j1 and m2 -> m1 everywhere, including the
    # occurrence of j2 inside the upper bound of the m range.
    reset = tensor.reset_dumms()
    assert reset.n_terms == 1
    term = reset.local_terms[0]
    assert len(term.sums) == 2
    # The relative order of the two summations is not pinned down here, so
    # tell them apart by the label of their ranges.
    if term.sums[0][1].label == 'j':
        j_sum, m_sum = term.sums
    else:
        m_sum, j_sum = term.sums
    assert j_sum[0] == j1
    assert j_sum[1].args == j.args
    assert m_sum[0] == m1
    assert m_sum[1].label == 'm'
    assert m_sum[1].lower == 0
    assert m_sum[1].upper == j1  # Important!
    assert term.amp == x[j1, m1]
    assert term.vecs == (v[j1, m1], )

    # Test that functions can be mapped to the bounds.
    repled = reset.map2scalars(
        lambda expr: expr.xreplace({j_max: 10}), skip_ranges=False
    )
    assert repled.n_terms == 1
    term = repled.local_terms[0]
    j_ranges = [range_ for _, range_ in term.sums if range_.label == 'j']
    # At least one j range must be present, and each must have had its
    # symbolic upper bound j_max replaced by the concrete value 10.
    assert j_ranges
    for j_range in j_ranges:
        assert j_range.lower == 0
        assert j_range.upper == 10
# Example #2
0
def test_tensor_has_basic_operations(free_alg):
    """Test some of the basic operations on tensors.

    Tested in this function:

        1. Addition.
        2. Merge.
        3. Free variable.
        4. Dummy reset.
        5. Equality comparison.
        6. Expansion, both full and shallow.
        7. Concrete summation over an explicit list of values.
        8. Mapping to scalars and to amplitudes.
        9. Base presence testing.
        10. Einstein summation over multiple ranges.
    """

    dr = free_alg
    p = dr.names
    i, j, k, l, m = p.R_dumms[:5]
    x = IndexedBase('x')
    r = p.R
    v = p.v
    tensor = dr.sum((l, r), x[i, l] * v[l]) + dr.sum((m, r), x[j, m] * v[m])

    # Without dummy resetting, they cannot be merged: the two terms use
    # different dummy names (l and m) for the same summation.
    assert tensor.n_terms == 2
    assert tensor.merge().n_terms == 2

    # Free variables are important for dummy resetting.
    free_vars = tensor.free_vars
    assert free_vars == {x.label, i, j}

    # Reset dummy.
    reset = tensor.reset_dumms()
    expected = (
        dr.sum((k, r), x[i, k] * v[k]) + dr.sum((k, r), x[j, k] * v[k])
    )
    assert reset == expected
    assert reset.local_terms == expected.local_terms

    # Merge the terms.
    merged = reset.merge()
    assert merged.n_terms == 1
    term = merged.local_terms[0]
    assert term == Term(((k, r), ), x[i, k] + x[j, k], (v[k], ))

    # Slightly separate test for expansion.
    c, d = symbols('c d')
    tensor = dr.sum((i, r), x[i] * (c + d) * v[i])
    assert tensor.n_terms == 1
    expanded = tensor.expand()
    assert expanded.n_terms == 2
    # Make sure shallow expansion does not delve into the tree.
    shallowly_expanded = tensor.shallow_expand()
    assert shallowly_expanded.n_terms == 1

    # Make sure shallow expansion does the job on the top-level.
    y = IndexedBase('y')
    tensor = dr.sum((i, r), (x[i] * (c + d) + y[i]) * v[i])
    assert tensor.n_terms == 1
    expanded = tensor.expand()
    assert expanded.n_terms == 3
    shallowly_expanded = tensor.shallow_expand()
    assert shallowly_expanded.n_terms == 2

    # Here we also test concrete summation facility: j runs over the explicit
    # values c and d rather than over a symbolic range.
    expected = dr.sum((i, r), (j, [c, d]), x[i] * j * v[i])
    assert expected == dr.sum(
        (i, r), x[i] * c * v[i] + x[i] * d * v[i]
    ).expand()

    # Test mapping to scalars.  `y` is still the IndexedBase('y') bound above.
    tensor = dr.sum((i, r), x[i] * v[i, j])
    substs = {x: y, j: c}
    res = tensor.map2scalars(lambda expr: expr.xreplace(substs))
    assert res == dr.sum((i, r), y[i] * v[i, c])
    # With skip_vecs=True the vector indices are left untouched, which should
    # coincide with mapping only the amplitudes.
    res = tensor.map2scalars(
        lambda expr: expr.xreplace(substs), skip_vecs=True
    )
    assert res == dr.sum((i, r), y[i] * v[i, j])
    assert res == tensor.map2amps(lambda expr: expr.xreplace(substs))

    # Test base presence.
    tensor = dr.einst(x[i] * v[i])
    assert tensor.has_base(x)
    assert tensor.has_base(v)
    assert not tensor.has_base(IndexedBase('y'))
    assert not tensor.has_base(Vec('w'))

    # Test Einstein summation over multiple ranges.  The expected 4 terms
    # presumably come from a1 and a2 each being summed over both R and S
    # (2 x 2 combinations) -- depends on how the fixture sets up a1/a2.
    a1, a2 = p.a1, p.a2
    summand = x[a1, a2] * v[a1, a2]
    res = dr.einst(summand).simplify()
    assert res.n_terms == 4
    ranges = (p.R, p.S)
    assert res == dr.sum((a1, ranges), (a2, ranges), summand).simplify()