# Example 1
def test_disconnected_outer_product_factorization(spark_ctx):
    """Test optimization of expressions with disconnected outer products.
    """

    # Set up a basic drudge context.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d, e = dumms[:5]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # Indexed bases appearing in the target.
    u, x, y, z, t = (IndexedBase(label) for label in 'UXYZT')

    # The target: a common outer factor U times two disconnected traces.
    targets = [dr.define_einst(
        t[a, b],
        u[a, b] * z[c, e] * x[e, c] + u[a, b] * z[c, e] * y[e, c]
    )]

    # Run the optimization; three evaluations are expected.
    res = optimize(targets)
    assert len(res) == 3

    # Correctness of the evaluation sequence.
    assert verify_eval_seq(res, targets, simplify=False)

    # Cost: everything should stay at quadratic scaling.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 4 * n ** 2
    assert leading_cost == 4 * n ** 2
# Example 2
def test_conjugation_optimization(spark_ctx):
    """Test optimization of expressions containing complex conjugate.
    """

    # Minimal context with a single symbolic range.
    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    a, b, c, d = symbols('a b c d')
    dr.set_dumms(r, [a, b, c, d])
    dr.add_default_resolver(r)

    # p is the target base; x multiplies the conjugated y and z.
    p, x, y, z = (IndexedBase(label) for label in 'pxyz')

    # Common factor x should be collectible despite the conjugations.
    targets = [dr.define_einst(
        p[a, b],
        x[a, c] * conjugate(y[c, b]) + x[a, c] * conjugate(z[c, b])
    )]

    # Optimize and check the evaluation sequence is equivalent.
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
# Example 3
def test_matrix_chain(spark_ctx):
    """Test a basic matrix chain multiplication problem.

    Matrix chain multiplication is the classical problem that motivated the
    single-term optimization algorithm in this package, so a very simple
    three-matrix chain is used here to exercise the factorization
    facilities.  The three matrices :math:`x`, :math:`y`, and :math:`z`
    have shapes :math:`m\\times n`, :math:`n \\times l`, and
    :math:`l \\times m` respectively, and for the factorization we set
    :math:`n = 2 m` and :math:`l = 3 m`.

    Multiplying the first two matrices first costs (two times)

    .. math::

        m n l + m^2 l

    while multiplying the last two first costs (two times)

    .. math::

        m n l + m^2 n

    In addition to the classical matrix chain product, the trace of the
    cyclic product is also tested,

    .. math::

        t = \\sum_i \\sum_j \\sum_k x_{i, j} y_{j, k} z_{k, i}

    Taking the product :math:`Y Z` first costs (two times)
    :math:`n m l + n m`; first multiplying :math:`X Y` or :math:`Z X`
    costs (two times) :math:`n m l + m l` and :math:`n m l + n l`
    respectively.

    """

    #
    # Context setup.
    #

    dr = Drudge(spark_ctx)

    # The three dimension sizes and their ranges.
    m, n, l = symbols('m n l')
    m_range = Range('M', 0, m)
    n_range = Range('N', 0, n)
    l_range = Range('L', 0, l)

    # Dummies for each range.
    for range_, names in [
        (m_range, 'a b c'), (n_range, 'i j k'), (l_range, 'p q r')
    ]:
        dr.set_dumms(range_, symbols(names))
    dr.add_resolver_for_dumms()

    # The three matrices in the chain.
    x = IndexedBase('x', shape=(m, n))
    y = IndexedBase('y', shape=(n, l))
    z = IndexedBase('z', shape=(l, m))

    # Size relations used for cost comparison: n = 2m, l = 3m.
    substs = {n: m * 2, l: m * 3}

    #
    # Actual tests.
    #

    p = dr.names

    # The chain product target.
    target_base = IndexedBase('t')
    target = dr.define_einst(target_base[p.a, p.b],
                             x[p.a, p.i] * y[p.i, p.p] * z[p.p, p.b])
    targets = [target]

    # Factorize: exactly two binary contractions are expected.
    eval_seq = optimize(targets, substs=substs)
    assert len(eval_seq) == 2

    # Correctness of the evaluation sequence.
    assert verify_eval_seq(eval_seq, targets)

    # Cost: the cheaper association should have been chosen.
    expected_cost = 2 * l * m * n + 2 * m ** 2 * n
    assert get_flop_cost(eval_seq) == expected_cost
    assert get_flop_cost(eval_seq, leading=True) == expected_cost
def test_removal_of_shallow_interms(spark_ctx):
    """Test removal of shallow intermediates.

    Here we have two intermediates,

    .. math::

        U X V + U Y V

    and

    .. math::

        U X W + U Y W

    and it has been deliberately made such that the multiplication with U should
    be carried out first.  Then after the collection of U, we have a shallow
    intermediate U (X + Y), which is a sum of a single product intermediate.
    This test succeeds when we have two intermediates only without the shallow
    ones.

    """

    # Basic context setting-up.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)
    # A half-sized range, to force the U multiplication to be done first.
    # NOTE: use Rational(1, 2) rather than Rational(1 / 2) -- the latter
    # performs Python float division first and only happens to be exact
    # because 0.5 is representable in binary; the two-argument form is the
    # correct exact-rational construction.
    r_small = Range('R', 0, Rational(1, 2) * n)

    dumms = symbols('a b c d')
    a, b, c = dumms[:3]
    dumms_small = symbols('e f g h')
    e = dumms_small[0]
    dr.set_dumms(r, dumms)
    dr.set_dumms(r_small, dumms_small)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    u = IndexedBase('U')
    v = IndexedBase('V')
    w = IndexedBase('W')
    x = IndexedBase('X')
    y = IndexedBase('Y')

    s = IndexedBase('S')
    t = IndexedBase('T')

    # The targets, both referencing the same shallow sum S = U X + U Y.
    s_def = dr.define_einst(s[a, b], u[a, c] * x[c, b] + u[a, c] * y[c, b])
    targets = [
        dr.define_einst(t[a, b], s_def[a, e] * v[e, b]),
        dr.define_einst(t[a, b], s_def[a, e] * w[e, b])
    ]

    # The actual optimization: four evaluations, no shallow intermediate.
    res = optimize(targets)
    assert len(res) == 4

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)
# Example 5
def test_optimization_of_common_terms(spark_ctx):
    """Test optimization of common terms in summations.

    In this test, there are just two matrices involved, X, Y.  The target reads

    .. math::

        T[a, b] = X[a, b] - X[b, a] + 2 Y[a, b] - 2 Y[b, a]

    Ideally, it should be evaluated as,

    .. math::

        I[a, b] = X[a, b] + 2 Y[a, b]
        T[a, b] = I[a, b] - I[b, a]

    or,

    .. math::

        I[a, b] = X[a, b] - 2 Y[b, a]
        T[a, b] = I[a, b] - I[b, a]

    Here, in order to emulate real cases where common term reference is in
    interplay with factorization, the X and Y matrices are written as :math:`X =
    S U` and :math:`Y = S V`.

    """

    #
    # Basic context setting-up.
    #

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d = dumms[:4]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    s = IndexedBase('S')
    u = IndexedBase('U')
    v = IndexedBase('V')

    x = dr.define(IndexedBase('X')[a, b], s[a, c] * u[c, b])
    y = dr.define(IndexedBase('Y')[a, b], s[a, c] * v[c, b])
    t = dr.define_einst(
        IndexedBase('t')[a, b], x[a, b] - x[b, a] + 2 * y[a, b] - 2 * y[b, a])

    targets = [t]
    eval_seq = optimize(targets)
    assert len(eval_seq) == 3

    # Check correctness.  BUG FIX: the result of verify_eval_seq was
    # previously discarded, so a verification failure went unnoticed.
    assert verify_eval_seq(eval_seq, targets)

    # Check cost.
    cost = get_flop_cost(eval_seq)
    assert cost == 2 * n**3 + 2 * n**2
    cost = get_flop_cost(eval_seq, ignore_consts=False)
    assert cost == 2 * n**3 + 3 * n**2

    # Check the result when the common symmetrization optimization is disabled.
    eval_seq = optimize(targets, strategy=Strategy.DEFAULT & ~Strategy.COMMON)
    # BUG FIX: this verification result was also previously discarded.
    assert verify_eval_seq(eval_seq, targets)
    new_cost = get_flop_cost(eval_seq, ignore_consts=True)
    assert new_cost - cost != 0
def test_matrix_factorization(spark_ctx):
    """Test a basic matrix multiplication factorization problem.

    Four matrices X, Y, U, and V are combined in two test cases covering
    different scenarios.

    """

    #
    # Context setup.
    #

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d = dumms[:4]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases used in both cases.
    x, y, u, v, t = (IndexedBase(label) for label in 'XYUVT')

    #
    # Test case 1: factorize the fully-expanded form of
    #
    # .. math::
    #
    #     (2 X - Y) * (2 U + V)
    #
    # back into something of that shape.  The signs and coefficients give
    # better code coverage, and this case concentrates on horizontal
    # complexity in the input.
    #

    target = dr.define_einst(
        t[a, b], 4 * x[a, c] * u[c, b] + 2 * x[a, c] * v[c, b] -
        2 * y[a, c] * u[c, b] - y[a, c] * v[c, b])
    targets = [target]

    # Optimize and verify.
    res = optimize(targets)
    assert len(res) == 3
    assert verify_eval_seq(res, targets, simplify=False)

    # Costs, with and without counting constant multiplications.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 2 * n ** 3 + 2 * n ** 2
    assert leading_cost == 2 * n ** 3
    cost = get_flop_cost(res, ignore_consts=False)
    assert cost == 2 * n ** 3 + 4 * n ** 2

    #
    # Test case 2: the expanded form of
    #
    # .. math::
    #
    #     (X - 2 Y) * U * V
    #
    # This case concentrates on depth complexity: the sum intermediate
    # needs to be factored again.
    #

    target = dr.define_einst(
        t[a, b], x[a, c] * u[c, d] * v[d, b] - 2 * y[a, c] * u[c, d] * v[d, b])
    targets = [target]

    # Optimize and verify.
    res = optimize(targets)
    assert len(res) == 3
    assert verify_eval_seq(res, targets, simplify=True)

    # Costs, with and without counting constant multiplications.
    cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert cost == 4 * n ** 3 + n ** 2
    assert leading_cost == 4 * n ** 3
    cost = get_flop_cost(res, ignore_consts=False)
    assert cost == 4 * n ** 3 + 2 * n ** 2

    # With summation optimization disabled, the cost should differ.
    res = optimize(targets, strategy=Strategy.BEST)
    assert verify_eval_seq(res, targets, simplify=True)
    new_cost = get_flop_cost(res, ignore_consts=False)
    assert new_cost - cost != 0