Example #1
def free_alg(spark_ctx):
    """Initialize the environment for a free algebra."""

    dr = Drudge(spark_ctx)

    r = Range('R')
    dumms = sympify('i, j, k, l, m, n')
    dr.set_dumms(r, dumms)

    s = Range('S')
    s_dumms = symbols('alpha beta')
    dr.set_dumms(s, s_dumms)

    dr.add_resolver_for_dumms()

    # For testing Einstein summation over multiple ranges.
    a1, a2 = symbols('a1 a2')
    dr.add_resolver({a1: (r, s), a2: (r, s)})
    dr.set_name(a1, a2)

    v = Vec('v')
    dr.set_name(v)

    m = IndexedBase('m')
    dr.set_symm(m, Perm([1, 0], NEG))

    h = IndexedBase('h')
    dr.set_symm(h, Perm([1, 0], NEG | CONJ))

    rho = IndexedBase('rho')
    dr.set_symm(rho, Perm([1, 0, 3, 2]), valence=4)

    dr.set_tensor_method('get_one', lambda x: 1)

    return dr
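As a minimal usage sketch (hypothetical, not part of the original suite), the symmetries registered above can be exercised like this: contracting the antisymmetric base `m` with a symmetric factor should simplify to zero.

def check_m_antisymmetry(free_alg):
    """Hypothetical check of the antisymmetry registered for m."""
    dr = free_alg
    p = dr.names
    i, j = p.R_dumms[:2]
    x = IndexedBase('x')
    # m[i, j] is antisymmetric while x[i] * x[j] is symmetric, so the
    # contraction should cancel upon canonicalization.
    tensor = dr.sum((i, p.R), (j, p.R), p.m[i, j] * x[i] * x[j])
    assert tensor.simplify() == 0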
Example #2
def test_varsh_872_5(nuclear: NuclearBogoliubovDrudge):
    """Test simplification based on the rule in Varshalovich 8.7.2 Eq (5).
    """
    dr = nuclear
    a, alpha, b, beta, b_prm, beta_prm = symbols(
        'a alpha b beta bprm betaprm', integer=True
    )
    c, gamma = symbols('c gamma', integer=True)
    sums = [
        (alpha, Range('m', -a, a + 1)),
        (gamma, Range('M', -c, c + 1))
    ]
    amp = CG(a, alpha, b, beta, c, gamma) * CG(
        a, alpha, b_prm, beta_prm, c, gamma
    )

    expected = (
            KroneckerDelta(b, b_prm) * KroneckerDelta(beta, beta_prm)
            * (2 * c + 1) / (2 * b + 1)
    )
    for sums_i in [sums, reversed(sums)]:
        tensor = dr.sum(*sums_i, amp)
        res = tensor.deep_simplify().merge()
        assert res.n_terms == 1
        term = res.local_terms[0]
        assert len(term.sums) == 0
        assert len(term.vecs) == 0
        assert (term.amp - expected).simplify() == 0
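In conventional notation, the orthogonality rule being checked is (transcribed from the code above, not an addition to the test)

.. math::

    \sum_{\alpha, \gamma} C^{c \gamma}_{a \alpha, b \beta}
    C^{c \gamma}_{a \alpha, b' \beta'}
    = \delta_{b b'} \delta_{\beta \beta'} \frac{2 c + 1}{2 b + 1}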
Example #3
def test_bounds_operations_for_ranges():
    """Properties about bounds for ranges.
    """

    l, u = sympify('l, u')
    r = Range('R')
    assert not r.bounded

    magic = 10
    d = {r: magic}
    assert d[r] == magic

    assert Range('S') not in d

    # Adding bounds should not change the identity of the range.
    r_lu = r[l, u]
    assert r_lu.bounded
    assert r_lu.lower == l
    assert r_lu.upper == u
    assert d[r_lu] == 10

    for r_02 in [r_lu[0, 2], r_lu.map(lambda x: x.xreplace({l: 0, u: 2}))]:
        assert d[r_02] == 10
        assert r_02.bounded
        assert r_02.lower == 0
        assert r_02.upper == 2
        assert r_02.size == 2
Example #4
def three_ranges(spark_ctx):
    """Fixture with three ranges.

    This drudge has three ranges, named M, N, L with sizes m, n, and l,
    respectively.  It also has a substitution dictionary setting n = 2m and l
    = 3m.

    """

    dr = Drudge(spark_ctx)

    # The sizes.
    m, n, l = symbols('m n l')

    # The ranges.
    m_range = Range('M', 0, m)
    n_range = Range('N', 0, n)
    l_range = Range('L', 0, l)

    dr.set_dumms(m_range, symbols('a b c d e f g'))
    dr.set_dumms(n_range, symbols('i j k l m n'))
    dr.set_dumms(l_range, symbols('p q r'))
    dr.add_resolver_for_dumms()
    dr.set_name(m, n, l)

    dr.substs = {n: m * 2, l: m * 3}

    return dr
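A hypothetical sketch of how this fixture could be used: its substitution dictionary goes to the optimizer (as with `optimize(targets, substs=substs)` in the matrix chain example below), so that costs are compared in terms of m alone.

def check_three_ranges_substs(three_ranges):
    """Hypothetical use of the fixture with the optimizer."""
    dr = three_ranges
    p = dr.names
    x = IndexedBase('x')
    y = IndexedBase('y')
    # A simple M x L by L x N product; p, as a dummy of range L, is
    # summed implicitly by the Einstein convention.
    t = dr.define_einst(
        IndexedBase('t')[p.a, p.i], x[p.a, p.p] * y[p.p, p.i]
    )
    eval_seq = optimize([t], substs=dr.substs)
    assert verify_eval_seq(eval_seq, [t])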
Example #5
def test_basic_handling_range_with_variable_bounds(spark_ctx):
    """Test the treatment of ranges with variable bounds.

    Here we use a simple example that slightly resembles angular momentum
    handling in quantum physics, concentrating on the basic operations of
    dummy resetting and the mapping of scalar functions.
    """

    dr = Drudge(spark_ctx)

    j1, j2 = symbols('j1 j2')
    m1, m2 = symbols('m1, m2')
    j_max = symbols('j_max')
    j = Range('j', 0, j_max)
    m = Range('m')
    dr.set_dumms(j, [j1, j2])
    dr.set_dumms(m, [m1, m2])

    v = Vec('v')
    x = IndexedBase('x')
    tensor = dr.sum((j2, j), (m2, m[0, j2]), x[j2, m2] * v[j2, m2])

    reset = tensor.reset_dumms()
    assert reset.n_terms == 1
    term = reset.local_terms[0]
    assert len(term.sums) == 2
    if term.sums[0][1].label == 'j':
        j_sum, m_sum = term.sums
    else:
        m_sum, j_sum = term.sums
    assert j_sum[0] == j1
    assert j_sum[1].args == j.args
    assert m_sum[0] == m1
    assert m_sum[1].label == 'm'
    assert m_sum[1].lower == 0
    assert m_sum[1].upper == j1  # Important!
    assert term.amp == x[j1, m1]
    assert term.vecs == (v[j1, m1], )

    # Test that functions can be mapped to the bounds.
    repled = reset.map2scalars(lambda x: x.xreplace({j_max: 10}),
                               skip_ranges=False)
    assert repled.n_terms == 1
    term = repled.local_terms[0]
    checked = False
    for _, i in term.sums:
        if i.label == 'j':
            assert i.lower == 0
            assert i.upper == 10
            checked = True
        continue
    assert checked
Example #6
def test_simple_terms_can_be_canonicalized():
    """Test the canonicalization of very simple terms.

    In this test, all the terms have a very simple appearance.  So rather than
    testing that the canonicalization really canonicalizes all the equivalent
    forms, here we check that the canonicalized form is the most intuitive
    form we expect.
    """

    l = Range('L')
    x = IndexedBase('x')
    i, j = sympify('i, j')

    # A term without the vector part, canonicalization without symmetry.
    term = sum_term([(j, l), (i, l)], x[i, j])[0]
    res = term.canon()
    expected = sum_term([(i, l), (j, l)], x[i, j])[0]
    assert res == expected

    # A term without the vector part, canonicalization with symmetry.
    m = Range('M')
    term = sum_term([(j, m), (i, l)], x[j, i])[0]
    for neg, conj in itertools.product([IDENT, NEG], [IDENT, CONJ]):
        acc = neg | conj
        group = Group([Perm([1, 0], acc)])
        res = term.canon(symms={x: group})
        expected_amp = x[i, j]
        if neg == NEG:
            expected_amp *= -1
        if conj == CONJ:
            expected_amp = conjugate(expected_amp)
        expected = sum_term([(i, l), (j, m)], expected_amp)[0]
        assert res == expected
        continue

    # In the absence of symmetry, the two indices should not be permuted.
    res = term.canon()
    expected = sum_term([(i, l), (j, m)], x[j, i])[0]
    assert res == expected

    # Now we add vectors to the terms.
    v = Vec('v')
    term = sum_term([(i, l), (j, l)], v[i] * v[j])[0]

    # Without anything added, it should already be in the canonical form.
    assert term.canon() == term

    # When we flip the colour of the vectors, we should get something different.
    res = term.canon(vec_colour=lambda idx, vec, term: -idx)
    expected = sum_term([(j, l), (i, l)], v[i] * v[j])[0]
    assert res == expected
Example #7
def test_diag_tight_binding_hamiltonian(spark_ctx):
    """Test automatic diagonalization of the tight-binding Hamiltonian.

    The primary target of this test is the simplification of amplitude
    summations.
    """

    n = Symbol('N', integer=True)
    dr = GenMBDrudge(spark_ctx, orb=(
        (Range('L', 0, n), symbols('x y z x1 x2', integer=True)),
    ))

    # The reciprocal space range and dummies.
    k, q = symbols('k q', integer=True)
    dr.set_dumms(Range('R', 0, n), [k, q])

    p = dr.names
    h = Symbol('h')  # Hopping neighbours.
    delta = Symbol('Delta')
    c_dag = p.c_dag
    c_ = p.c_
    a = p.L_dumms[0]

    # Hamiltonian in real space.
    real_ham = dr.sum((a, p.L), (h, 1, -1),
                      delta * c_dag[a + h] * c_[a]).simplify()
    assert real_ham.n_terms == 2

    # Unitary Fourier transform.
    cr_def = (c_dag[a],
              dr.sum((k, p.R),
                     (1 / sqrt(n)) * exp(-I * 2 * pi * k * a / n) * c_dag[k]))
    an_def = (c_[a],
              dr.sum((k, p.R),
                     (1 / sqrt(n)) * exp(I * 2 * pi * k * a / n) * c_[k]))
    rec_ham = real_ham.subst_all([cr_def, an_def])
    res = rec_ham.simplify()

    assert res.n_terms == 1
    res_term = res.local_terms[0]
    assert len(res_term.sums) == 1
    dumm = res_term.sums[0][0]
    assert res_term.sums[0][1] == p.R
    # Here we mostly check the Hamiltonian has been diagonalized.
    assert len(res_term.vecs) == 2
    for i in res_term.vecs:
        assert len(i.indices) == 2
        assert i.indices[1] == dumm
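For orientation, the diagonal form the assertions above verify corresponds to the standard textbook result (stated here in conventional notation, not taken from the test output):

.. math::

    H = \Delta \sum_{a} \sum_{h = \pm 1} c^{\dagger}_{a + h} c_{a}
      = 2 \Delta \sum_{k} \cos\left(\frac{2 \pi k}{N}\right)
        c^{\dagger}_{k} c_{k}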
Example #8
def test_su2_on_1d_heisenberg_model(spark_ctx):
    """Test the SU2 drudge on 1D Heisenberg model with abstract lattice indices.

    This test also acts as the test for the default resolver.
    """

    dr = SU2LatticeDrudge(spark_ctx)
    l = Range('L')
    dr.set_dumms(l, symbols('i j k l m n'))
    dr.add_default_resolver(l)

    p = dr.names
    j_z = p.J_
    j_p = p.J_p
    j_m = p.J_m
    i = p.i
    half = Rational(1, 2)

    coupling = Symbol('J')
    ham = dr.sum(
        (i, l),
        j_z[i] * j_z[i + 1] +
        j_p[i] * j_m[i + 1] / 2 + j_m[i] * j_p[i + 1] / 2
    ) * coupling

    s_sq = dr.sum(
        (i, l),
        j_z[i] * j_z[i] + half * j_p[i] * j_m[i] + half * j_m[i] * j_p[i]
    )

    comm = (ham | s_sq).simplify()
    assert comm == 0
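In conventional notation, the two operators built above are

.. math::

    H = J \sum_{i} \left[
        J^{z}_{i} J^{z}_{i + 1}
        + \frac{1}{2} J^{+}_{i} J^{-}_{i + 1}
        + \frac{1}{2} J^{-}_{i} J^{+}_{i + 1}
    \right],
    \qquad
    S^{2} = \sum_{i} \left[
        J^{z}_{i} J^{z}_{i}
        + \frac{1}{2} J^{+}_{i} J^{-}_{i}
        + \frac{1}{2} J^{-}_{i} J^{+}_{i}
    \right]

and the `|` operator forms their commutator, which the test asserts to vanish.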
Example #9
def test_drs_tensor_def_dispatch(spark_ctx):
    """Tests the dispatch to drudge for tensor definitions."""

    dr = Drudge(spark_ctx)
    names = dr.names

    i_symb = Symbol('i')
    x = IndexedBase('x')
    rhs = x[i_symb]

    dr.add_default_resolver(Range('R'))

    a = DrsSymbol(dr, 'a')
    i = DrsSymbol(dr, 'i')
    for lhs in [a, a[i]]:
        expected = dr.define(lhs, rhs)

        def_ = lhs <= rhs
        assert def_ == expected
        assert not hasattr(names, 'a')
        assert not hasattr(names, '_a')

        def_ = lhs.def_as(rhs)
        assert def_ == expected
        assert names.a == expected
        if isinstance(lhs, DrsSymbol):
            assert names._a == Symbol('a')
        else:
            assert names._a == IndexedBase('a')
        dr.unset_name(def_)
Example #10
def test_varsh_872_4(nuclear: NuclearBogoliubovDrudge):
    """Test simplification based on Varshalovich 8.7.2 Eq (4)."""
    dr = nuclear
    c, gamma, c_prm, gamma_prm = symbols('c gamma cprm gammaprm', integer=True)
    a, alpha, b, beta = symbols('a alpha b beta', integer=True)

    m_range = Range('m')
    sums = [
        (alpha, m_range[-a, a + 1]), (beta, m_range[-b, b + 1])
    ]
    amp = CG(a, alpha, b, beta, c, gamma) * CG(
        a, alpha, b, beta, c_prm, gamma_prm
    )

    # Make sure that the pattern matching works in any way the summations are
    # written.
    for sums_i in [sums, reversed(sums)]:
        tensor = dr.sum(*sums_i, amp)
        res = tensor.simplify_am()
        assert res.n_terms == 1
        term = res.local_terms[0]
        assert len(term.sums) == 0
        assert term.amp == KroneckerDelta(
            c, c_prm
        ) * KroneckerDelta(gamma, gamma_prm)
Example #11
def test_optimization_handles_nonlinear_factors(spark_ctx):
    """Test optimization of with nonlinear factors.

    Here a factor is the square of an indexed quantity.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    dumms = symbols('a b c d e f g h')
    dr.set_dumms(r, dumms)
    a, b, c, d = dumms[:4]
    dr.add_default_resolver(r)

    u = symbols('u')
    s = IndexedBase('s')

    targets = [
        dr.define(
            u,
            dr.sum((a, r), (b, r), (c, r), (d, r),
                   32 * s[a, c]**2 * s[b, d]**2 +
                   32 * s[a, c] * s[a, d] * s[b, c] * s[b, d]))
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
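Naive evaluation of this target costs :math:`O(n^4)`.  One factorization the optimizer could find (a sketch of the math, not the package's actual output) is

.. math::

    u = 32 \Big( \sum_{a, c} s_{a c}^{2} \Big)^{2}
      + 32 \sum_{c, d} \Big( \sum_{a} s_{a c} s_{a d} \Big)^{2}

which needs only :math:`O(n^3)` operations.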
Example #12
def test_varsh_911_8(nuclear: NuclearBogoliubovDrudge):
    """Test simplification based on the rule in Varshalovich 9.1.1 Eq (8).
    """
    dr = nuclear
    j, m, j12, m12, j2, m2, j1, m1, j_prm, m_prm, j23, m23, j3, m3 = symbols(
        'j m j12 m12 j2 m2 j1 m1 jprm mprm j23 m23 j3 m3', integer=True
    )
    m_range = Range('m')
    sums = [(m_i, m_range[-j_i, j_i + 1]) for m_i, j_i in [
        (m1, j1), (m2, j2), (m3, j3), (m12, j12), (m23, j23)
    ]]
    amp = CG(j12, m12, j3, m3, j, m) * CG(j1, m1, j2, m2, j12, m12) * CG(
        j1, m1, j23, m23, j_prm, m_prm
    ) * CG(j2, m2, j3, m3, j23, m23)

    expected = (
            KroneckerDelta(j, j_prm) * KroneckerDelta(m, m_prm)
            * (-1) ** (j1 + j2 + j3 + j)
            * sqrt(2 * j12 + 1) * sqrt(2 * j23 + 1)
            * Wigner6j(j1, j2, j12, j3, j, j23)
    )

    # For performance reasons, just test a random arrangement of the summations.
    random.shuffle(sums)
    tensor = dr.sum(*sums, amp)
    res = tensor.deep_simplify().merge()
    assert res.n_terms == 1
    term = res.local_terms[0]
    assert len(term.sums) == 0
    assert len(term.vecs) == 0
    assert (term.amp - expected).simplify() == 0
Example #13
def test_optimization_handles_coefficients(spark_ctx):
    """Test optimization of scalar intermediates scaled by coefficients.

    This test comes from PoST theory.  It tests the optimization of tensor
    evaluations with scalar intermediates scaled by a factor.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    a, b = symbols('a b')
    dr.set_dumms(r, [a, b])
    dr.add_default_resolver(r)

    r_base = IndexedBase('r')  # A distinct name, to avoid shadowing the range r.
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')

    targets = [
        dr.define(r_base[a, b],
                  dr.sum(2 * eps[a] * t[a, b]) - 2 * eps[b] * t[a, b])
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
Example #14
def test_optimization_handles_scalar_intermediates(spark_ctx):
    """Test optimization of scalar intermediates scaling other tensors.

    This is set up as a special test primarily because it entails the same
    collectible giving residues with different ranges.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    dumms = symbols('a b c d e')
    dr.set_dumms(r, dumms)
    a, b, c = dumms[:3]
    dr.add_default_resolver(r)

    u = IndexedBase('u')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')
    s = IndexedBase('s')

    targets = [
        dr.define(
            u, (a, r), (b, r),
            dr.sum((c, r), 8 * s[a, b] * eps[c] * t[a]) -
            8 * s[a, b] * eps[a] * t[a])
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
Example #15
def test_range_has_basic_operations():
    """Test the basic operations on ranges."""

    a_symb = sympify('a')
    b_symb = sympify('b')

    bound0 = Range('B', 'a', 'b')
    bound1 = Range('B', a_symb, b_symb)
    symb0 = Range('S')
    symb1 = Range('S')

    assert bound0 == bound1
    assert hash(bound0) == hash(bound1)
    assert symb0 == symb1
    assert hash(symb0) == hash(symb1)

    assert bound0 != symb0
    assert hash(bound0) != hash(symb0)

    assert bound0.label == 'B'
    assert bound0.lower == a_symb
    assert bound0.upper == b_symb
    assert bound0.args == (bound1.label, bound1.lower, bound1.upper)
    assert bound0.size == b_symb - a_symb
    assert bound0.replace_label('B1') == Range('B1', a_symb, b_symb)

    assert symb0.label == 'S'
    assert symb0.lower is None
    assert symb0.upper is None
    assert len(symb0.args) == 1
    assert symb0.args[0] == symb1.label
Example #16
def test_trivial_sums_can_be_simplified(free_alg):
    """Test the simplification facility for trivial sums."""
    dr = free_alg
    r = Range('D', 0, 2)

    a, b = symbols('a b')
    tensor = dr.sum(1) + dr.sum((a, r), 1) + dr.sum((a, r), (b, r), 1)
    res = tensor.simplify()
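    # 1 (empty sum) + 2 (single sum over D) + 2 * 2 (double sum) == 7.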
    assert res == dr.sum(7)
Example #17
def test_amp_sums_can_be_simplified(free_alg):
    """Test the simplification facility for more complex amplitude sums."""
    dr = free_alg
    v = dr.names.v
    n, i, j = symbols('n i j')
    x = IndexedBase('x')
    r = Range('D', 0, n)

    tensor = dr.sum((i, r), (j, r), i**2 * x[j] * v[j])
    res = tensor.simplify_sums()
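    # By Faulhaber's formula, Sum(i**2, (i, 0, n - 1)) == n**3/3 - n**2/2 + n/6.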
    assert res == dr.sum((j, r), (n**3 / 3 - n**2 / 2 + n / 6) * x[j] * v[j])
Example #18
def test_sums_can_be_expanded(spark_ctx):
    """Test the summation expansion facility.

    Here we have essentially a direct product of two ranges and expand it.
    The test also includes some preliminary steps typical of this usage
    pattern.
    """

    dr = Drudge(spark_ctx)

    comp = Range('P')
    r1, r2 = symbols('r1, r2')
    dr.set_dumms(comp, [r1, r2])

    a = IndexedBase('a')
    v = Vec('v')

    # Functions extracting the x and y components of a composite index.
    x = Function('x')
    y = Function('y')

    # A simple thing written in terms of composite indices.
    orig = dr.sum((r1, comp), (r2, comp), a[r1] * a[r2] * v[r1] * v[r2])

    # Rewrite the expression in terms of components.  Here, r1 should be
    # construed as a simple Wild.
    rewritten = orig.subst_all([(a[r1], a[x(r1), y(r1)]),
                                (v[r1], v[x(r1), y(r1)])])

    # Expand the summation over r.
    x_dim = Range('X')
    y_dim = Range('Y')
    x1, x2 = symbols('x1 x2')
    dr.set_dumms(x_dim, [x1, x2])
    y1, y2 = symbols('y1 y2')
    dr.set_dumms(y_dim, [y1, y2])

    res = rewritten.expand_sums(
        comp, lambda r: [(Symbol(str(r).replace('r', 'x')), x_dim, x(r)),
                         (Symbol(str(r).replace('r', 'y')), y_dim, y(r))])

    assert (res - dr.sum(
        (x1, x_dim), (y1, y_dim), (x2, x_dim), (y2, y_dim),
        a[x1, y1] * a[x2, y2] * v[x1, y1] * v[x2, y2])).simplify() == 0
Example #19
def simple_drudge(spark_ctx):
    """Form a simple drudge with some basic information.
    """

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    return dr
Example #20
def test_wigner3j_sum_to_wigner6j(nuclear: NuclearBogoliubovDrudge):
    """Test simplification of sum of product of four 3j's to a 6j.

    This test tries to simplify the original LHS of the equation from the
    Wolfram website.
    """

    dr = nuclear
    j1, j2, j3, jprm3, j4, j5, j6 = symbols(
        'j1 j2 j3 jprm3 j4 j5 j6', integer=True
    )
    m1, m2, m3, mprm3, m4, m5, m6 = symbols(
        'm1 m2 m3 mprm3 m4 m5 m6', integer=True
    )

    m_range = Range('m')
    sums = [(m_i, m_range[-j_i, j_i + 1]) for m_i, j_i in [
        (m1, j1), (m2, j2), (m4, j4), (m5, j5), (m6, j6)
    ]]

    phase = (-1) ** (
            j1 + j2 + j4 + j5 + j6 - m1 - m2 - m4 - m5 - m6
    )
    amp = (
            Wigner3j(j2, m2, j3, -m3, j1, m1)
            * Wigner3j(j1, -m1, j5, m5, j6, m6)
            * Wigner3j(j5, -m5, jprm3, mprm3, j4, m4)
            * Wigner3j(j4, -m4, j2, -m2, j6, -m6)
    )

    expected = (
            ((-1) ** (j3 - m3) / (2 * j3 + 1))
            * KroneckerDelta(j3, jprm3) * KroneckerDelta(m3, mprm3)
            * Wigner6j(j1, j2, j3, j4, j5, j6)
    ).expand().simplify()

    # For performance reasons, just test a random arrangement of the summations.
    random.shuffle(sums)
    tensor = dr.sum(*sums, phase * amp)
    res = tensor.deep_simplify().merge()
    assert res.n_terms == 1
    term = res.local_terms[0]
    assert len(term.sums) == 0
    assert len(term.vecs) == 0
    assert (term.amp - expected).simplify() == 0
Example #21
def test_drudge_has_names(free_alg):
    """Test the name archive for drudge objects.

    Here selected names are tested to make sure all the code paths are covered.
    """

    p = free_alg.names

    # Range and dummy related.
    assert p.R == Range('R')
    assert len(p.R_dumms) == 6
    assert p.R_dumms[0] == p.i
    assert p.R_dumms[-1] == p.n

    # Vector bases.
    assert p.v == Vec('v')

    # Scalar bases.
    assert p.m == IndexedBase('m')
Example #22
def test_disconnected_outer_product_factorization(spark_ctx):
    """Test optimization of expressions with disconnected outer products.
    """

    # Basic context setting-up.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d, e = dumms[:5]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    u = IndexedBase('U')
    x = IndexedBase('X')
    y = IndexedBase('Y')
    z = IndexedBase('Z')
    t = IndexedBase('T')

    # The target.
    target = dr.define_einst(
        t[a, b],
        u[a, b] * z[c, e] * x[e, c] + u[a, b] * z[c, e] * y[e, c]
    )
    targets = [target]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)

    # Test the cost.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
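    # The traces collapse into scalar intermediates, so every remaining
    # contraction is O(n**2); no O(n**3) matrix product is ever formed.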
    assert cost == 4 * n ** 2
    assert leading_cost == 4 * n ** 2
Example #23
def test_numbers_can_substitute_scalars(free_alg, full_balance):
    """Test substituting scalars with numbers."""

    dr = free_alg
    p = dr.names

    x = IndexedBase('x')
    y = IndexedBase('y')
    r = Range('D', 0, 2)
    i, j, k, l = symbols('i j k l')
    dr.set_dumms(r, [i, j, k, l])
    v = p.v

    orig = dr.sum((i, r), x[i]**2 * x[j] * y[k] * v[l])

    res = orig.subst(x[i], 0, full_balance=full_balance).simplify()
    assert res == 0
    res = orig.subst(x[j], 1, full_balance=full_balance).simplify()
    assert res == dr.sum(2 * y[k] * v[l])
    res = orig.subst(x[k], 2, full_balance=full_balance).simplify()
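    # Each of the three x factors becomes 2, giving 2**2 * 2 == 8, doubled by
    # the sum over the two values of i.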
    assert res == dr.sum(16 * y[k] * v[l])
Example #24
def test_symb_resolvers():
    """Test the functionality of symbol resolvers in strict mode."""
    r = Range('R')
    a, b = symbols('a b')

    strict, normal = [
        functools.partial(try_resolve_range,
                          sums_dict={},
                          resolvers=[SymbResolver([(r, [a])], strict=i)])
        for i in [True, False]
    ]

    # Strict mode.
    assert strict(a) == r
    assert strict(b) is None
    assert strict(a + 1) is None

    # Normal mode.
    assert normal(a) == r
    assert normal(b) is None
    assert normal(a + 1) == r
Example #25
def test_conjugation_optimization(spark_ctx):
    """Test optimization of expressions containing complex conjugate.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    a, b, c, d = symbols('a b c d')
    dr.set_dumms(r, [a, b, c, d])
    dr.add_default_resolver(r)

    p = IndexedBase('p')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')

    targets = [
        dr.define_einst(
            p[a, b],
            x[a, c] * conjugate(y[c, b]) + x[a, c] * conjugate(z[c, b]))
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
Example #26
def mprod():
    """A fixture for a term looking like a matrix product.

    This can be used to test some basic operations on terms.
    """

    i, j, k = sympify('i, j, k')
    n = sympify('n')
    l = Range('L', 1, n)
    a = IndexedBase('a', shape=(n, n))
    b = IndexedBase('b', shape=(n, n))
    v = Vec('v')

    prod = sum_term([(i, l), (j, l), (k, l)], a[i, j] * b[j, k] * v[i] * v[k])

    assert len(prod) == 1
    return prod[0], types.SimpleNamespace(
        i=i, j=j, k=k, l=l, a=a, b=b, v=v, n=n
    )
Example #27
def test_matrix_chain(spark_ctx):
    """Test a basic matrix chain multiplication problem.

    The matrix chain multiplication problem is the classical problem that
    motivated the algorithm for single-term optimization in this package.  So
    here a very simple matrix chain multiplication problem with three matrices
    is used to test the factorization facilities.  In this simple test, we
    have three matrices :math:`x`, :math:`y`, and :math:`z`, of shapes
    :math:`m \\times n`, :math:`n \\times l`, and :math:`l \\times m`
    respectively.  In the factorization, we set :math:`n = 2 m` and
    :math:`l = 3 m`.

    If we multiply the first two matrices first, the cost will be (two times)

    .. math::

        m n l + m^2 l

    Or if we multiply the last two matrices first, the cost will be (two times)

    .. math::

        m n l + m^2 n

    In addition to the classical matrix chain product, also tested is the
    trace of their cyclic product.

    .. math::

        t = \\sum_i \\sum_j \\sum_k x_{i, j} y_{j, k} z_{k, i}

    If we first take the product :math:`Y Z`, the cost will be (two times)
    :math:`n m l + n m`.  If we instead first multiply :math:`X Y` or
    :math:`Z X`, the costs will be (two times) :math:`n m l + m l` and
    :math:`n m l + n l` respectively.

    """

    #
    # Basic context setting-up.
    #

    dr = Drudge(spark_ctx)

    # The sizes.
    m, n, l = symbols('m n l')

    # The ranges.
    m_range = Range('M', 0, m)
    n_range = Range('N', 0, n)
    l_range = Range('L', 0, l)

    dr.set_dumms(m_range, symbols('a b c'))
    dr.set_dumms(n_range, symbols('i j k'))
    dr.set_dumms(l_range, symbols('p q r'))
    dr.add_resolver_for_dumms()

    # The indexed bases.
    x = IndexedBase('x', shape=(m, n))
    y = IndexedBase('y', shape=(n, l))
    z = IndexedBase('z', shape=(l, m))

    # The costs substitution.
    substs = {n: m * 2, l: m * 3}

    #
    # Actual tests.
    #

    p = dr.names

    target_base = IndexedBase('t')
    target = dr.define_einst(target_base[p.a, p.b],
                             x[p.a, p.i] * y[p.i, p.p] * z[p.p, p.b])

    # Perform the factorization.
    targets = [target]
    eval_seq = optimize(targets, substs=substs)
    assert len(eval_seq) == 2

    # Check the correctness.
    assert verify_eval_seq(eval_seq, targets)

    # Check the cost.
    cost = get_flop_cost(eval_seq)
    leading_cost = get_flop_cost(eval_seq, leading=True)
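    # With n = 2m and l = 3m, contracting y with z first costs
    # 2*l*m*n + 2*m**2*n == 16 m**3, beating the 18 m**3 of the
    # x-y-first ordering.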
    expected_cost = 2 * l * m * n + 2 * m**2 * n
    assert cost == expected_cost
    assert leading_cost == expected_cost
Example #28
0
def test_matrix_factorization(spark_ctx):
    """Test a basic matrix multiplication factorization problem.

    In this test, there are four matrices involved, X, Y, U, and V.  And they
    are used in two test cases for different scenarios.

    """

    #
    # Basic context setting-up.
    #

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d = dumms[:4]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    x = IndexedBase('X')
    y = IndexedBase('Y')
    u = IndexedBase('U')
    v = IndexedBase('V')
    t = IndexedBase('T')

    #
    # Test case 1.
    #
    # The final expression to optimize is mathematically
    #
    # .. math::
    #
    #     (2 X - Y) * (2 U + V)
    #
    # Here, the expression is given in its expanded form originally, and we
    # test whether it can be factorized into something similar to the form
    # above.  Signs and coefficients are included for better code coverage.
    # This test case concentrates more on the horizontal complexity of the
    # input.
    #

    # The target.
    target = dr.define_einst(
        t[a, b], 4 * x[a, c] * u[c, b] + 2 * x[a, c] * v[c, b] -
        2 * y[a, c] * u[c, b] - y[a, c] * v[c, b])
    targets = [target]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)

    # Test the cost.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 2 * n**3 + 2 * n**2
    assert leading_cost == 2 * n**3
    cost = get_flop_cost(res, ignore_consts=False)
    assert cost == 2 * n**3 + 4 * n**2

    #
    # Test case 2.
    #
    # The final expression to optimize is mathematically
    #
    # .. math::
    #
    #     (X - 2 Y) * U * V
    #
    # Unlike the first test case, here we concentrate more on the treatment
    # of depth complexity in the input.  The sum intermediate needs to be
    # factored again.
    #

    # The target.
    target = dr.define_einst(
        t[a, b], x[a, c] * u[c, d] * v[d, b] - 2 * y[a, c] * u[c, d] * v[d, b])
    targets = [target]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=True)

    # Test the cost.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 4 * n**3 + n**2
    assert leading_cost == 4 * n**3
    cost = get_flop_cost(res, ignore_consts=False)
    assert cost == 4 * n**3 + 2 * n**2

    # Test disabling summation optimization.
    res = optimize(targets, strategy=Strategy.BEST)
    assert verify_eval_seq(res, targets, simplify=True)
    new_cost = get_flop_cost(res, ignore_consts=False)
    assert new_cost - cost != 0
Example #29
def test_removal_of_shallow_interms(spark_ctx):
    """Test removal of shallow intermediates.

    Here we have two intermediates,

    .. math::

        U X V + U Y V

    and

    .. math::

        U X W + U Y W

    and it has been deliberately arranged so that the multiplication with U
    should be carried out first.  Then, after the collection of U, we have a
    shallow intermediate U (X + Y), which is a sum over a single product
    intermediate.  This test succeeds when only the two genuine intermediates
    remain, without the shallow one.

    """

    # Basic context setting-up.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)
    r_small = Range('R', 0, Rational(1, 2) * n)

    dumms = symbols('a b c d')
    a, b, c = dumms[:3]
    dumms_small = symbols('e f g h')
    e = dumms_small[0]
    dr.set_dumms(r, dumms)
    dr.set_dumms(r_small, dumms_small)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    u = IndexedBase('U')
    v = IndexedBase('V')
    w = IndexedBase('W')
    x = IndexedBase('X')
    y = IndexedBase('Y')

    s = IndexedBase('S')
    t = IndexedBase('T')

    # The target.
    s_def = dr.define_einst(s[a, b], u[a, c] * x[c, b] + u[a, c] * y[c, b])
    targets = [
        dr.define_einst(t[a, b], s_def[a, e] * v[e, b]),
        dr.define_einst(t[a, b], s_def[a, e] * w[e, b])
    ]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 4

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)
Example #30
def test_optimization_of_common_terms(spark_ctx):
    """Test optimization of common terms in summations.

    In this test, there are just two matrices involved, X, Y.  The target reads

    .. math::

        T[a, b] = X[a, b] - X[b, a] + 2 Y[a, b] - 2 Y[b, a]

    Ideally, it should be evaluated as,

    .. math::

        I[a, b] = X[a, b] + 2 Y[a, b]
        T[a, b] = I[a, b] - I[b, a]

    or,

    .. math::

        I[a, b] = X[a, b] - 2 Y[b, a]
        T[a, b] = I[a, b] - I[b, a]

    Here, in order to emulate real cases where common term reference is in
    interplay with factorization, the X and Y matrices are written as :math:`X =
    S U` and :math:`Y = S V`.

    """

    #
    # Basic context setting-up.
    #

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d = dumms[:4]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    s = IndexedBase('S')
    u = IndexedBase('U')
    v = IndexedBase('V')

    x = dr.define(IndexedBase('X')[a, b], s[a, c] * u[c, b])
    y = dr.define(IndexedBase('Y')[a, b], s[a, c] * v[c, b])
    t = dr.define_einst(
        IndexedBase('t')[a, b], x[a, b] - x[b, a] + 2 * y[a, b] - 2 * y[b, a])

    targets = [t]
    eval_seq = optimize(targets)
    assert len(eval_seq) == 3

    # Check correctness.
    assert verify_eval_seq(eval_seq, targets)

    # Check cost.
    cost = get_flop_cost(eval_seq)
    assert cost == 2 * n**3 + 2 * n**2
    cost = get_flop_cost(eval_seq, ignore_consts=False)
    assert cost == 2 * n**3 + 3 * n**2

    # Check the result when the common symmetrization optimization is disabled.
    eval_seq = optimize(targets, strategy=Strategy.DEFAULT & ~Strategy.COMMON)
    assert verify_eval_seq(eval_seq, targets)
    new_cost = get_flop_cost(eval_seq, ignore_consts=True)
    assert new_cost - cost != 0