def test_disconnected_outer_product_factorization(spark_ctx):
    """Check factorization across a disconnected outer product.

    Both terms of the target carry the factor ``U[a, b]``, which shares no
    index with the fully contracted remainders ``Z[c, e] * X[e, c]`` and
    ``Z[c, e] * Y[e, c]``.  The assertions require an evaluation sequence of
    three steps that reproduces the target, with both the total and the
    leading FLOP cost equal to ``4 * n ** 2``.
    """

    # One symbolic range of size n with its dummies.
    size = Symbol('n')
    drudge = Drudge(spark_ctx)
    full_range = Range('R', 0, size)
    dummies = symbols('a b c d e f g')
    drudge.set_dumms(full_range, dummies)
    drudge.add_resolver_for_dumms()
    # The fourth dummy is not needed by this test.
    a, b, c, _, e = dummies[:5]

    # Input tensors and the base of the result.
    base_u, base_x, base_y, base_z, out = (
        IndexedBase(label) for label in ('U', 'X', 'Y', 'Z', 'T')
    )

    # The expression to optimize, written in its expanded form.
    expanded = (
        base_u[a, b] * base_z[c, e] * base_x[e, c]
        + base_u[a, b] * base_z[c, e] * base_y[e, c]
    )
    targets = [drudge.define_einst(out[a, b], expanded)]

    # Run the optimizer and check the number of evaluation steps.
    steps = optimize(targets)
    assert len(steps) == 3

    # The evaluation sequence has to reproduce the original target.
    assert verify_eval_seq(steps, targets, simplify=False)

    # Total and leading costs coincide for this problem.
    total = get_flop_cost(steps)
    leading = get_flop_cost(steps, leading=True)
    expected = 4 * size ** 2
    assert total == expected
    assert leading == expected
def test_conjugation_optimization(spark_ctx):
    """Check optimization of a target containing complex conjugates.

    The target is ``p[a, b] = x[a, c] * conjugate(y[c, b]) + x[a, c] *
    conjugate(z[c, b])`` under the Einstein convention.  Only the correctness
    of the resulting evaluation sequence is asserted; no cost is checked.
    """

    drudge = Drudge(spark_ctx)

    # A single range whose dummies are resolved by default.
    size = symbols('n')
    rng = Range('r', 0, size)
    dummies = symbols('a b c d')
    a, b, c = dummies[:3]
    drudge.set_dumms(rng, list(dummies))
    drudge.add_default_resolver(rng)

    out = IndexedBase('p')
    left = IndexedBase('x')
    right_y = IndexedBase('y')
    right_z = IndexedBase('z')

    rhs = (
        left[a, c] * conjugate(right_y[c, b])
        + left[a, c] * conjugate(right_z[c, b])
    )
    targets = [drudge.define_einst(out[a, b], rhs)]

    steps = optimize(targets)
    assert verify_eval_seq(steps, targets)
def test_matrix_chain(spark_ctx):
    r"""Test a basic matrix chain multiplication problem.

    The matrix chain multiplication problem is the classical problem that
    motivated the algorithm for single-term optimization in this package.  A
    very simple instance with three matrices is used here to test the
    factorization facilities.

    The three matrices :math:`x`, :math:`y`, and :math:`z` have shapes
    :math:`m \times n`, :math:`n \times l`, and :math:`l \times m`
    respectively.  For the optimization, the substitutions :math:`n = 2 m`
    and :math:`l = 3 m` are supplied.

    If the first two matrices are multiplied first, the cost is (two times)

    .. math::

        m n l + m^2 l

    If the last two matrices are multiplied first, the cost is (two times)

    .. math::

        m n l + m^2 n

    Under the given substitutions the second order is the cheaper one, and it
    is the cost asserted below.  Only the chain product is checked by this
    test.
    """

    drudge = Drudge(spark_ctx)

    # Symbolic sizes of the three dimensions.
    dim_m, dim_n, dim_l = symbols('m n l')

    # One range per dimension, each with its own set of dummies.
    range_m = Range('M', 0, dim_m)
    range_n = Range('N', 0, dim_n)
    range_l = Range('L', 0, dim_l)
    for rng, dummy_names in (
            (range_m, 'a b c'),
            (range_n, 'i j k'),
            (range_l, 'p q r'),
    ):
        drudge.set_dumms(rng, symbols(dummy_names))
    drudge.add_resolver_for_dumms()

    # The three matrices of the chain, with explicit shapes.
    mat_x = IndexedBase('x', shape=(dim_m, dim_n))
    mat_y = IndexedBase('y', shape=(dim_n, dim_l))
    mat_z = IndexedBase('z', shape=(dim_l, dim_m))

    # Relative sizes handed to the optimizer for cost comparison.
    size_substs = {dim_n: dim_m * 2, dim_l: dim_m * 3}

    # The chain product t[a, b] = x[a, i] y[i, p] z[p, b].
    names = drudge.names
    out = IndexedBase('t')
    chain = (
        mat_x[names.a, names.i]
        * mat_y[names.i, names.p]
        * mat_z[names.p, names.b]
    )
    targets = [drudge.define_einst(out[names.a, names.b], chain)]

    # Perform the factorization; two steps are expected.
    steps = optimize(targets, substs=size_substs)
    assert len(steps) == 2

    # Check the correctness.
    assert verify_eval_seq(steps, targets)

    # Check the cost: y z first, then the product with x.
    total = get_flop_cost(steps)
    leading = get_flop_cost(steps, leading=True)
    expected = 2 * dim_l * dim_m * dim_n + 2 * dim_m ** 2 * dim_n
    assert total == expected
    assert leading == expected
def test_removal_of_shallow_interms(spark_ctx):
    """Test removal of shallow intermediates.

    An intermediate tensor

    .. math::

        S = U X + U Y

    is defined first, and the two targets are its products with V and W,

    .. math::

        U X V + U Y V

    and

    .. math::

        U X W + U Y W

    The contraction with V and W runs over a range of half the size, which is
    meant to make the multiplication with U the one carried out first.  After
    U is collected, U (X + Y) is a shallow intermediate: a sum consisting of
    a single product.  The test passes when the evaluation sequence has
    exactly four steps, i.e. the two targets plus two intermediates, without
    the shallow ones.
    """

    drudge = Drudge(spark_ctx)
    size = Symbol('n')

    # The full range and a range of half its size.
    # NOTE(review): both ranges are created with the label 'R' -- confirm
    # that sharing the label is intended.
    full_range = Range('R', 0, size)
    # ``1 / 2`` is evaluated by Python to the float 0.5 before it reaches
    # ``Rational``.
    half_range = Range('R', 0, Rational(1 / 2) * size)

    full_dummies = symbols('a b c d')
    half_dummies = symbols('e f g h')
    drudge.set_dumms(full_range, full_dummies)
    drudge.set_dumms(half_range, half_dummies)
    drudge.add_resolver_for_dumms()

    a, b, c = full_dummies[:3]
    e = half_dummies[0]

    # The indexed bases.
    base_u, base_v, base_w, base_x, base_y, base_s, out = (
        IndexedBase(label) for label in ('U', 'V', 'W', 'X', 'Y', 'S', 'T')
    )

    # The common left factor S = U X + U Y.
    s_def = drudge.define_einst(
        base_s[a, b],
        base_u[a, c] * base_x[c, b] + base_u[a, c] * base_y[c, b]
    )

    # One target per right factor, contracted over the small range.
    targets = [
        drudge.define_einst(out[a, b], s_def[a, e] * right[e, b])
        for right in (base_v, base_w)
    ]

    # The actual optimization.
    steps = optimize(targets)
    assert len(steps) == 4

    # Test the correctness.
    assert verify_eval_seq(steps, targets, simplify=False)
def test_optimization_of_common_terms(spark_ctx):
    """Test optimization of common terms in summations.

    In this test, there are just two matrices involved, X, Y.  The target
    reads

    .. math::

        T[a, b] = X[a, b] - X[b, a] + 2 Y[a, b] - 2 Y[b, a]

    Ideally, it should be evaluated as,

    .. math::

        I[a, b] = X[a, b] + 2 Y[a, b]
        T[a, b] = I[a, b] - I[b, a]

    or,

    .. math::

        I[a, b] = X[a, b] - 2 Y[b, a]
        T[a, b] = I[a, b] - I[b, a]

    Here, in order to emulate real cases where common term reference is in
    interplay with factorization, the X and Y matrices are written as
    :math:`X = S U` and :math:`Y = S V`.

    The test asserts a three-step evaluation sequence, its correctness, and
    its cost both with and without constant factors.  It then repeats the
    optimization with ``Strategy.COMMON`` switched off and asserts that the
    result is still correct but has a different cost.
    """

    #
    # Basic context setting-up.
    #

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c = dumms[:3]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    s = IndexedBase('S')
    u = IndexedBase('U')
    v = IndexedBase('V')

    x = dr.define(IndexedBase('X')[a, b], s[a, c] * u[c, b])
    y = dr.define(IndexedBase('Y')[a, b], s[a, c] * v[c, b])
    t = dr.define_einst(
        IndexedBase('t')[a, b],
        x[a, b] - x[b, a] + 2 * y[a, b] - 2 * y[b, a]
    )

    targets = [t]
    eval_seq = optimize(targets)
    assert len(eval_seq) == 3

    # Check correctness.  The result has to be asserted, as in the other
    # tests of this module; a bare call would discard it.
    assert verify_eval_seq(eval_seq, targets)

    # Check cost, first ignoring constant factors, then counting them.  The
    # two are kept in separate names so the baseline stays available.
    cost = get_flop_cost(eval_seq)
    assert cost == 2 * n ** 3 + 2 * n ** 2
    cost_with_consts = get_flop_cost(eval_seq, ignore_consts=False)
    assert cost_with_consts == 2 * n ** 3 + 3 * n ** 2

    # Check the result when the common symmetrization optimization is
    # disabled.
    eval_seq = optimize(
        targets, strategy=Strategy.DEFAULT & ~Strategy.COMMON
    )
    assert verify_eval_seq(eval_seq, targets)

    # Compare like with like: both costs here ignore constant factors.
    new_cost = get_flop_cost(eval_seq, ignore_consts=True)
    assert new_cost - cost != 0
def test_matrix_factorization(spark_ctx):
    """Test a basic matrix multiplication factorization problem.

    Four matrices X, Y, U, and V over a single range of size ``n`` are used
    in two scenarios:

    1. The expanded form of ``(2 X - Y) * (2 U + V)``.  Signs and
       coefficients are included for better code coverage, and the case
       concentrates on the horizontal complexity of the input.
    2. The expanded form of ``(X - 2 Y) * U * V``.  This case concentrates on
       depth complexity; the sum intermediate needs to be factored again.

    For each scenario the number of evaluation steps, the correctness of the
    sequence and its FLOP costs are asserted.  The second scenario is then
    optimized again with ``Strategy.BEST`` to check that the cost changes
    when summation optimization is disabled.
    """

    #
    # Common set-up: one range of size n and the five indexed bases.
    #

    size = Symbol('n')
    drudge = Drudge(spark_ctx)
    full_range = Range('R', 0, size)

    dummies = symbols('a b c d e f g')
    drudge.set_dumms(full_range, dummies)
    drudge.add_resolver_for_dumms()
    a, b, c, d = dummies[:4]

    mat_x, mat_y, mat_u, mat_v, out = (
        IndexedBase(label) for label in ('X', 'Y', 'U', 'V', 'T')
    )

    #
    # Scenario 1: (2 X - Y) * (2 U + V), given fully expanded.
    #

    expanded = (
        4 * mat_x[a, c] * mat_u[c, b]
        + 2 * mat_x[a, c] * mat_v[c, b]
        - 2 * mat_y[a, c] * mat_u[c, b]
        - mat_y[a, c] * mat_v[c, b]
    )
    targets = [drudge.define_einst(out[a, b], expanded)]

    steps = optimize(targets)
    assert len(steps) == 3

    # Correctness of the evaluation sequence.
    assert verify_eval_seq(steps, targets, simplify=False)

    # Costs with constant factors ignored.
    total = get_flop_cost(steps)
    leading = get_flop_cost(steps, leading=True)
    assert total == 2 * size ** 3 + 2 * size ** 2
    assert leading == 2 * size ** 3

    # Cost with constant factors counted.
    total_with_consts = get_flop_cost(steps, ignore_consts=False)
    assert total_with_consts == 2 * size ** 3 + 4 * size ** 2

    #
    # Scenario 2: (X - 2 Y) * U * V, given fully expanded.
    #

    expanded = (
        mat_x[a, c] * mat_u[c, d] * mat_v[d, b]
        - 2 * mat_y[a, c] * mat_u[c, d] * mat_v[d, b]
    )
    targets = [drudge.define_einst(out[a, b], expanded)]

    steps = optimize(targets)
    assert len(steps) == 3

    # Correctness of the evaluation sequence.
    assert verify_eval_seq(steps, targets, simplify=True)

    # Costs with constant factors ignored.
    total = get_flop_cost(steps)
    leading = get_flop_cost(steps, leading=True)
    assert total == 4 * size ** 3 + size ** 2
    assert leading == 4 * size ** 3

    # Cost with constant factors counted; also the baseline used below.
    total_with_consts = get_flop_cost(steps, ignore_consts=False)
    assert total_with_consts == 4 * size ** 3 + 2 * size ** 2

    # Test disabling summation optimization.
    steps = optimize(targets, strategy=Strategy.BEST)
    assert verify_eval_seq(steps, targets, simplify=True)
    cost_without_sum = get_flop_cost(steps, ignore_consts=False)
    assert cost_without_sum - total_with_consts != 0