Exemple #1
def test_basic_handling_range_with_variable_bounds(spark_ctx):
    """Test the treatment of ranges with variable bounds.

    Here we use a simple example that slightly resembles the angular momentum
    handling in quantum physics.  Here we concentrate on basic operations of
    dummy resetting and mapping of scalar functions.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    dr = Drudge(spark_ctx)

    j1, j2 = symbols('j1 j2')
    m1, m2 = symbols('m1, m2')
    j_max = symbols('j_max')
    j = Range('j', 0, j_max)
    m = Range('m')
    dr.set_dumms(j, [j1, j2])
    dr.set_dumms(m, [m1, m2])

    v = Vec('v')
    x = IndexedBase('x')
    # The upper bound of the m summation depends on the outer dummy j2.
    tensor = dr.sum((j2, j), (m2, m[0, j2]), x[j2, m2] * v[j2, m2])

    reset = tensor.reset_dumms()
    assert reset.n_terms == 1
    term = reset.local_terms[0]
    assert len(term.sums) == 2
    # The two summations are accepted in either order.
    if term.sums[0][1].label == 'j':
        j_sum, m_sum = term.sums
    else:
        m_sum, j_sum = term.sums
    assert j_sum[0] == j1
    assert j_sum[1].args == j.args
    assert m_sum[0] == m1
    assert m_sum[1].label == 'm'
    assert m_sum[1].lower == 0
    assert m_sum[1].upper == j1  # Important!
    assert term.amp == x[j1, m1]
    assert term.vecs == (v[j1, m1], )

    # Test that functions can be mapped to the bounds.
    #
    # NOTE(review): the trailing argument of this call was lost in this copy
    # of the file.  ``skip_ranges=False`` is reconstructed so that the action
    # also reaches the range bounds checked below -- confirm against the
    # ``map2scalars`` signature.
    repled = reset.map2scalars(
        lambda expr: expr.xreplace({j_max: 10}), skip_ranges=False
    )
    assert repled.n_terms == 1
    term = repled.local_terms[0]
    checked = False
    for _, i in term.sums:
        if i.label == 'j':
            assert i.lower == 0
            assert i.upper == 10
            checked = True
    # Make sure that the j summation was actually found and examined.
    assert checked
Exemple #2
def test_simple_scalar_optimization(spark_ctx):
    """Test optimization of a simple scalar.

    There is not much optimization that can be done for simple scalars.  But we
    need to ensure that we get correct result here.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    dr = Drudge(spark_ctx)

    a, b, r = symbols('a b r')
    # A single scalar target defined as a plain product.
    targets = [dr.define(r, a * b)]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
Exemple #3
def test_optimization_handles_nonlinear_factors(spark_ctx):
    """Test optimization with nonlinear factors.

    Here a factor is the square of an indexed quantity.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    dumms = symbols('a b c d e f g h')
    dr.set_dumms(r, dumms)
    a, b, c, d = dumms[:4]

    u = symbols('u')
    s = IndexedBase('s')

    # NOTE(review): the opening of this definition was lost in this copy of
    # the file.  ``dr.define(u, ...)`` is reconstructed from the unbalanced
    # parentheses and the otherwise unused scalar ``u`` -- confirm.
    targets = [
        dr.define(u, dr.sum(
            (a, r), (b, r), (c, r), (d, r),
            32 * s[a, c] ** 2 * s[b, d] ** 2 +
            32 * s[a, c] * s[a, d] * s[b, c] * s[b, d]
        ))
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
Exemple #4
def test_optimization_handles_coeffcients(spark_ctx):
    """Test optimization of scalar intermediates scaled by coefficients.

    This test comes from PoST theory.  It tests the optimization of tensor
    evaluations with scalar intermediates scaled by a factor.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    a, b = symbols('a b')
    dr.set_dumms(r, [a, b])

    # From here on ``r`` names the indexed base of the target rather than the
    # range, which has already been registered above.
    r = IndexedBase('r')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')

    targets = [
        dr.define(
            r[a, b],
            dr.sum(2 * eps[a] * t[a, b]) - 2 * eps[b] * t[a, b]
        )
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
Exemple #5
def test_optimization_handles_scalar_intermediates(spark_ctx):
    """Test optimization of scalar intermediates scaling other tensors.

    This is set as a special test primarily since it would entail the same
    collectible giving residues with different ranges.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    dumms = symbols('a b c d e')
    dr.set_dumms(r, dumms)
    a, b, c = dumms[:3]

    u = IndexedBase('u')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')
    s = IndexedBase('s')

    # NOTE(review): the opening of this definition was lost in this copy of
    # the file.  ``dr.define(`` is reconstructed from the argument layout
    # (base, index/range pairs, then the right-hand side) -- confirm.
    targets = [
        dr.define(
            u, (a, r), (b, r),
            dr.sum((c, r), 8 * s[a, b] * eps[c] * t[a]) -
            8 * s[a, b] * eps[a] * t[a]
        )
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
Exemple #6
def test_drs_symb_call(spark_ctx):
    """Test calling methods by drs symbols.

    A ``DrsSymbol`` is called with an object as its first argument and the
    result is compared with what the attribute of the same name gives on that
    object.  Afterwards the same mechanism is tried on a vector with a name
    that the vector itself does not have.
    """
    class TestCls:
        def meth(self):
            return 'meth'

        def prop(self):
            return 'prop'

    obj = TestCls()
    meth = DrsSymbol(None, 'meth')
    assert meth(obj) == 'meth'
    prop = DrsSymbol(None, 'prop')
    assert prop(obj) == 'prop'
    invalid = DrsSymbol(None, 'invalid')
    # NOTE(review): the bodies of the two ``pytest.raises`` blocks below are
    # missing from this copy of the file, which makes the function a syntax
    # error.  The first presumably calls ``invalid`` on ``obj``; the statement
    # expected to raise the AttributeError mentioning 'prop' cannot be
    # recovered from here -- restore both from the upstream test.
    with pytest.raises(NameError):
    with pytest.raises(AttributeError) as exc:
    assert exc.value.args[0].find('prop') > 0

    # Test automatic raising to tensors.
    v = Vec('v')
    tensor_meth = 'local_terms'
    assert not hasattr(v, tensor_meth)  # Or the test just will not work.
    # NOTE(review): the closing bracket of the expected list below is also
    # missing from this copy.
    assert DrsSymbol(Drudge(spark_ctx), tensor_meth)(v) == [
        Term(sums=(), amp=Integer(1), vecs=(v, ))
Exemple #7
def test_drs_tensor_def_dispatch(spark_ctx):
    """Tests the dispatch to drudge for tensor definitions.

    Both a bare drs symbol and an indexed one are used as the left-hand side.
    For each, the ``<=`` operator and the ``def_as`` method are compared with
    the result of calling ``define`` on the drudge directly.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    dr = Drudge(spark_ctx)
    names = dr.names

    i_symb = Symbol('i')
    x = IndexedBase('x')
    rhs = x[i_symb]

    # NOTE(review): a statement appears to have been lost at this point in
    # this copy of the file (only blank lines remain) -- compare with the
    # upstream test.

    a = DrsSymbol(dr, 'a')
    i = DrsSymbol(dr, 'i')
    for lhs in [a, a[i]]:
        expected = dr.define(lhs, rhs)

        # Definition by the operator should not touch the name archive.
        def_ = lhs <= rhs
        assert def_ == expected
        assert not hasattr(names, 'a')
        assert not hasattr(names, '_a')

        # Definition by ``def_as`` also puts the names into the archive.
        def_ = lhs.def_as(rhs)
        assert def_ == expected
        assert names.a == expected
        if isinstance(lhs, DrsSymbol):
            assert names._a == Symbol('a')
        else:
            assert names._a == IndexedBase('a')

        # NOTE(review): the names set by ``def_as`` presumably have to be
        # removed again here, or the ``not hasattr`` assertions of the next
        # iteration cannot hold -- the clean-up statement seems to be missing
        # from this copy; confirm against the upstream test.
Exemple #8
def test_sums_can_be_expanded(spark_ctx):
    """Test the summation expansion facility.

    Here we have essentially a direct product of two ranges and expand it.  The
    usage here also includes some preliminary steps typical in the usage of
    this facility.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    # Imported locally since the functions for the components are only needed
    # in this test.
    from sympy import Function

    dr = Drudge(spark_ctx)

    comp = Range('P')
    r1, r2 = symbols('r1, r2')
    dr.set_dumms(comp, [r1, r2])

    a = IndexedBase('a')
    v = Vec('v')

    # NOTE(review): ``x`` and ``y`` are called below but their definitions are
    # missing from this copy of the file.  They are reconstructed as undefined
    # SymPy functions giving the components of a composite index -- confirm.
    x = Function('x')
    y = Function('y')

    # A simple thing written in terms of composite indices.
    orig = dr.sum((r1, comp), (r2, comp), a[r1] * a[r2] * v[r1] * v[r2])

    # Rewrite the expression in terms of components.  Here, r1 should be
    # construed as a simple Wild.
    rewritten = orig.subst_all([(a[r1], a[x(r1), y(r1)]),
                                (v[r1], v[x(r1), y(r1)])])

    # Expand the summation over r.
    x_dim = Range('X')
    y_dim = Range('Y')
    x1, x2 = symbols('x1 x2')
    dr.set_dumms(x_dim, [x1, x2])
    y1, y2 = symbols('y1 y2')
    dr.set_dumms(y_dim, [y1, y2])

    # Each composite dummy rN is replaced by the pair of dummies xN and yN,
    # whose names are derived from the name of the composite dummy.
    res = rewritten.expand_sums(
        comp, lambda r: [(Symbol(str(r).replace('r', 'x')), x_dim, x(r)),
                         (Symbol(str(r).replace('r', 'y')), y_dim, y(r))])

    assert (res - dr.sum(
        (x1, x_dim), (y1, y_dim), (x2, x_dim), (y2, y_dim),
        a[x1, y1] * a[x2, y2] * v[x1, y1] * v[x2, y2])).simplify() == 0
def test_drudge_injects_names():
    """Test the name injection method of drudge.

    The drudge is built on a stand-in context object and the test finally
    expects the names ``string_name_`` and ``one_`` to be available.
    """

    # Dummy drudge.
    dr = Drudge(types.SimpleNamespace(defaultParallelism=1))

    string_name = 'string_name'

    # NOTE(review): the statements that register the names in the drudge and
    # inject them are missing from this copy of the file (only blank lines
    # remain), so ``string_name_`` and ``one_`` below are undefined here --
    # restore them from the upstream test.

    assert string_name_ == string_name
    assert one_ == 1
Exemple #10
def simple_drudge(spark_ctx):
    """Form a simple drudge with some basic information.

    The drudge is created from the given context and gets one range, ``R``
    from 0 to the symbol ``n``, with the dummies a to g.  The drudge is
    returned.
    """
    # NOTE(review): presumably meant to be registered as a pytest fixture; no
    # decorator is present in this copy of the file -- confirm.

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    dr.set_dumms(r, dumms)

    return dr
Exemple #11
def test_disconnected_outer_product_factorization(spark_ctx):
    """Test optimization of expressions with disconnected outer products.

    The target is the sum of two terms, each of which is the product of a
    matrix with a scalar given by a full contraction of two other matrices.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    # Basic context setting-up.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d, e = dumms[:5]
    dr.set_dumms(r, dumms)

    # The indexed bases.
    u = IndexedBase('U')
    x = IndexedBase('X')
    y = IndexedBase('Y')
    z = IndexedBase('Z')
    t = IndexedBase('T')

    # The target.
    target = dr.define_einst(
        t[a, b],
        u[a, b] * z[c, e] * x[e, c] + u[a, b] * z[c, e] * y[e, c]
    )
    targets = [target]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)

    # Test the cost.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 4 * n ** 2
    assert leading_cost == 4 * n ** 2
Exemple #12
def test_conjugation_optimization(spark_ctx):
    """Test optimization of expressions containing complex conjugate.

    The two terms of the target share the factor ``x`` and differ only in
    which conjugated matrix it is contracted with.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    a, b, c, d = symbols('a b c d')
    dr.set_dumms(r, [a, b, c, d])

    p = IndexedBase('p')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')

    # NOTE(review): the opening of this definition was lost in this copy of
    # the file.  ``dr.define_einst(`` is reconstructed from the argument
    # layout (indexed left-hand side, then a right-hand side with the repeated
    # index c) -- confirm.
    targets = [
        dr.define_einst(
            p[a, b],
            x[a, c] * conjugate(y[c, b]) + x[a, c] * conjugate(z[c, b])
        )
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_removal_of_shallow_interms(spark_ctx):
    """Test removal of shallow intermediates.

    Here we have two intermediates,

    .. math::

        U X V + U Y V

    and

    .. math::

        U X W + U Y W

    and it has been deliberately made such that the multiplication with U should
    be carried out first.  Then after the collection of U, we have a shallow
    intermediate U (X + Y), which is a sum of a single product intermediate.
    This test succeeds when we have two intermediates only without the shallow
    ones.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    # Basic context setting-up.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)
    # A second range of half the size, used for the contraction with V and W.
    r_small = Range('R', 0, Rational(1 / 2) * n)

    dumms = symbols('a b c d')
    a, b, c = dumms[:3]
    dumms_small = symbols('e f g h')
    e = dumms_small[0]
    dr.set_dumms(r, dumms)
    dr.set_dumms(r_small, dumms_small)

    # The indexed bases.
    u = IndexedBase('U')
    v = IndexedBase('V')
    w = IndexedBase('W')
    x = IndexedBase('X')
    y = IndexedBase('Y')

    s = IndexedBase('S')
    t = IndexedBase('T')

    # The target.
    s_def = dr.define_einst(s[a, b], u[a, c] * x[c, b] + u[a, c] * y[c, b])
    targets = [
        dr.define_einst(t[a, b], s_def[a, e] * v[e, b]),
        dr.define_einst(t[a, b], s_def[a, e] * w[e, b])
    ]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 4

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)
Exemple #14
def test_matrix_factorization(spark_ctx):
    """Test a basic matrix multiplication factorization problem.

    In this test, there are four matrices involved, X, Y, U, and V.  And they
    are used in two test cases for different scenarios.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    # Basic context setting-up.

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d = dumms[:4]
    dr.set_dumms(r, dumms)

    # The indexed bases.
    x = IndexedBase('X')
    y = IndexedBase('Y')
    u = IndexedBase('U')
    v = IndexedBase('V')
    t = IndexedBase('T')

    # Test case 1.
    # The final expression to optimize is mathematically
    # .. math::
    #     (2 X - Y) * (2 U + V)
    # Here, the expression is to be given in its extended form originally, and
    # we test if it can be factorized into something similar to what we have
    # above. Here we have the signs and coefficients to have better code
    # coverage for these cases.  This test case concentrates more on the
    # horizontal complexity in the input.

    # The target.
    target = dr.define_einst(
        t[a, b], 4 * x[a, c] * u[c, b] + 2 * x[a, c] * v[c, b] -
        2 * y[a, c] * u[c, b] - y[a, c] * v[c, b])
    targets = [target]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)

    # Test the cost.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 2 * n**3 + 2 * n**2
    assert leading_cost == 2 * n**3
    cost = get_flop_cost(res, ignore_consts=False)
    assert cost == 2 * n**3 + 4 * n**2

    # Test case 2.
    # The final expression to optimize is mathematically
    # .. math::
    #     (X - 2 Y) * U * V
    # Different from the first test case, here we concentrate more on the
    # treatment of depth complexity in the input.  The sum intermediate needs to
    # be factored again.

    # The target.
    target = dr.define_einst(
        t[a, b], x[a, c] * u[c, d] * v[d, b] - 2 * y[a, c] * u[c, d] * v[d, b])
    targets = [target]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=True)

    # Test the cost.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 4 * n**3 + n**2
    assert leading_cost == 4 * n**3
    cost = get_flop_cost(res, ignore_consts=False)
    assert cost == 4 * n**3 + 2 * n**2

    # Test disabling summation optimization.  The cost is expected to differ
    # from the one just obtained with the default strategy.
    res = optimize(targets, strategy=Strategy.BEST)
    assert verify_eval_seq(res, targets, simplify=True)
    new_cost = get_flop_cost(res, ignore_consts=False)
    assert new_cost - cost != 0
Exemple #15
def free_alg(spark_ctx):
    """Set up and return a drudge for a free algebra.

    The drudge is created from the given context.  It gets two ranges with
    their dummies, a resolver and names for two symbols that can be over both
    ranges, symmetries for three indexed bases, and a tensor method called
    ``get_one``.
    """
    # NOTE(review): presumably meant to be registered as a pytest fixture; no
    # decorator is present in this copy of the file -- confirm.

    drudge = Drudge(spark_ctx)

    # The two ranges with their dummies.
    range_r = Range('R')
    drudge.set_dumms(range_r, sympify('i, j, k, l, m, n'))

    range_s = Range('S')
    drudge.set_dumms(range_s, symbols('alpha beta'))

    # For testing the Einstein over multiple ranges.
    both_ranges = (range_r, range_s)
    multi_1, multi_2 = symbols('a1 a2')
    drudge.add_resolver({multi_1: both_ranges, multi_2: both_ranges})
    drudge.set_name(multi_1, multi_2)

    # The vector is created but not used any further in this function.
    vec = Vec('v')

    # Symmetries of the indexed bases under permutation of the indices.
    swap = [1, 0]
    drudge.set_symm(IndexedBase('m'), Perm(swap, NEG))
    drudge.set_symm(IndexedBase('h'), Perm(swap, NEG | CONJ))
    drudge.set_symm(IndexedBase('rho'), Perm([1, 0, 3, 2]), valence=4)

    # A trivial tensor method giving one for anything.
    drudge.set_tensor_method('get_one', lambda x: 1)

    return drudge
Exemple #16
def test_optimization_of_common_terms(spark_ctx):
    """Test optimization of common terms in summations.

    In this test, there are just two matrices involved, X, Y.  The target reads

    .. math::

        T[a, b] = X[a, b] - X[b, a] + 2 Y[a, b] - 2 Y[b, a]

    Ideally, it should be evaluated as,

    .. math::

        I[a, b] = X[a, b] + 2 Y[a, b]
        T[a, b] = I[a, b] - I[b, a]

    or as

    .. math::

        I[a, b] = X[a, b] - 2 Y[b, a]
        T[a, b] = I[a, b] - I[b, a]

    Here, in order to emulate real cases where common term reference is in
    interplay with factorization, the X and Y matrices are written as :math:`X =
    S U` and :math:`Y = S V`.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    # Basic context setting-up.

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d = dumms[:4]
    dr.set_dumms(r, dumms)

    # The indexed bases.
    s = IndexedBase('S')
    u = IndexedBase('U')
    v = IndexedBase('V')

    x = dr.define(IndexedBase('X')[a, b], s[a, c] * u[c, b])
    y = dr.define(IndexedBase('Y')[a, b], s[a, c] * v[c, b])
    t = dr.define_einst(
        IndexedBase('t')[a, b], x[a, b] - x[b, a] + 2 * y[a, b] - 2 * y[b, a])

    targets = [t]
    eval_seq = optimize(targets)
    assert len(eval_seq) == 3

    # Check correctness.  The result of the verification has to be asserted,
    # or a wrong evaluation sequence would go unnoticed.
    assert verify_eval_seq(eval_seq, targets)

    # Check cost.
    cost = get_flop_cost(eval_seq)
    assert cost == 2 * n**3 + 2 * n**2
    cost = get_flop_cost(eval_seq, ignore_consts=False)
    assert cost == 2 * n**3 + 3 * n**2

    # Check the result when the common symmetrization optimization is disabled.
    eval_seq = optimize(targets, strategy=Strategy.DEFAULT & ~Strategy.COMMON)
    assert verify_eval_seq(eval_seq, targets)
    # NOTE(review): ``cost`` above was last computed with ignore_consts=False
    # while the new cost uses ignore_consts=True -- confirm that comparing the
    # two is intended.
    new_cost = get_flop_cost(eval_seq, ignore_consts=True)
    assert new_cost - cost != 0
Exemple #17
def test_matrix_chain(spark_ctx):
    """Test a basic matrix chain multiplication problem.

    Matrix chain multiplication problem is the classical problem that motivated
    the algorithm for single-term optimization in this package.  So here a very
    simple matrix chain multiplication problem with three matrices are used to
    test the factorization facilities.  In this simple test, we will have three
    matrices :math:`x`, :math:`y`, and :math:`z`, which are of shapes
    :math:`m\\times n`, :math:`n \\times l`, and :math:`l \\times m`
    respectively. In the factorization, we are going to set :math:`n = 2 m` and
    :math:`l = 3 m`.

    If we multiply the first two matrices first, the cost will be (two times)

    .. math::

        m n l + m^2 l

    Or if we multiply the last two matrices first, the cost will be (two times)

    .. math::

        m n l + m^2 n

    In addition to the classical matrix chain product, also tested is the
    trace of their cyclic product.

    .. math::

        t = \\sum_i \\sum_j \\sum_k x_{i, j} y_{j, k} z_{k, i}

    If we first take the product of :math:`Y Z`, the cost will be (two times)
    :math:`n m l + n m`. For first multiplying :math:`X Y` and :math:`Z X`,
    the costs will be (two times) :math:`n m l + m l` and :math:`n m l + n l`
    respectively.

    ``spark_ctx`` is the context object handed to the drudge constructor.
    """

    # Basic context setting-up.

    dr = Drudge(spark_ctx)

    # The sizes.
    m, n, l = symbols('m n l')

    # The ranges.
    m_range = Range('M', 0, m)
    n_range = Range('N', 0, n)
    l_range = Range('L', 0, l)

    dr.set_dumms(m_range, symbols('a b c'))
    dr.set_dumms(n_range, symbols('i j k'))
    dr.set_dumms(l_range, symbols('p q r'))

    # The indexed bases.
    x = IndexedBase('x', shape=(m, n))
    y = IndexedBase('y', shape=(n, l))
    z = IndexedBase('z', shape=(l, m))

    # The costs substitution.
    substs = {n: m * 2, l: m * 3}

    # Actual tests.

    # The dummies are accessed through the name archive of the drudge.
    p = dr.names

    target_base = IndexedBase('t')
    target = dr.define_einst(target_base[p.a, p.b],
                             x[p.a, p.i] * y[p.i, p.p] * z[p.p, p.b])

    # Perform the factorization.
    targets = [target]
    eval_seq = optimize(targets, substs=substs)
    assert len(eval_seq) == 2

    # Check the correctness.
    assert verify_eval_seq(eval_seq, targets)

    # Check the cost.  The expected one is that of multiplying the last two
    # matrices first, as given in the docstring.
    cost = get_flop_cost(eval_seq)
    leading_cost = get_flop_cost(eval_seq, leading=True)
    expected_cost = 2 * l * m * n + 2 * m**2 * n
    assert cost == expected_cost
    assert leading_cost == expected_cost
def three_ranges(spark_ctx):
    """Fixture with three ranges.

    This drudge has three ranges, named M, N, L with sizes m, n, and l,
    respectively.  It also has a substitution dictionary setting n = 2m and l
    = 3m.

    The drudge is created from the given context and returned.
    """
    # NOTE(review): described as a fixture, but no pytest fixture decorator is
    # present in this copy of the file -- confirm.

    dr = Drudge(spark_ctx)

    # The sizes.
    m, n, l = symbols('m n l')

    # The ranges.
    m_range = Range('M', 0, m)
    n_range = Range('N', 0, n)
    l_range = Range('L', 0, l)

    dr.set_dumms(m_range, symbols('a b c d e f g'))
    # NOTE(review): the dummies l, m, n given here are the same SymPy symbols
    # as the sizes above -- confirm that this overlap is intended.
    dr.set_dumms(n_range, symbols('i j k l m n'))
    dr.set_dumms(l_range, symbols('p q r'))
    dr.set_name(m, n, l)

    dr.substs = {n: m * 2, l: m * 3}

    return dr