def test_optimization_handles_scalar_intermediates(spark_ctx):
    """Optimize a target where scalar intermediates scale other tensors.

    This is kept as a dedicated test mainly because it makes the same
    collectible give residues with different ranges.
    """
    dr = Drudge(spark_ctx)

    # A single range of symbolic size with its dummies.
    size = symbols('n')
    rng = Range('r', 0, size)
    all_dumms = symbols('a b c d e')
    dr.set_dumms(rng, all_dumms)
    dr.add_default_resolver(rng)
    a, b, c = all_dumms[:3]

    u = IndexedBase('u')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')
    s = IndexedBase('s')

    # The definition is a summed term minus an unsummed term.
    summed_part = dr.sum((c, rng), 8 * s[a, b] * eps[c] * t[a])
    unsummed_part = 8 * s[a, b] * eps[a] * t[a]
    target = dr.define(u, (a, rng), (b, rng), summed_part - unsummed_part)
    targets = [target]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def free_alg(spark_ctx):
    """Build the drudge environment used for free-algebra tests.

    Two ranges, ``R`` and ``S``, are given dummies; the symbols ``a1`` and
    ``a2`` are resolved to both ranges; a vector ``v`` is registered by name;
    symmetries are set on the indexed bases ``m``, ``h`` and ``rho``; and a
    tensor method ``get_one`` is installed.

    :param spark_ctx: The Spark context handed to the drudge.
    :return: The configured drudge.
    """
    dr = Drudge(spark_ctx)

    # The two ranges with their dummies.
    r_range = Range('R')
    dr.set_dumms(r_range, sympify('i, j, k, l, m, n'))
    s_range = Range('S')
    dr.set_dumms(s_range, symbols('alpha beta'))
    dr.add_resolver_for_dumms()

    # For testing the Einstein convention over multiple ranges.
    multi_range_syms = symbols('a1 a2')
    dr.add_resolver({sym: (r_range, s_range) for sym in multi_range_syms})
    dr.set_name(*multi_range_syms)

    dr.set_name(Vec('v'))

    # Symmetries of the indexed bases.
    dr.set_symm(IndexedBase('m'), Perm([1, 0], NEG))
    dr.set_symm(IndexedBase('h'), Perm([1, 0], NEG | CONJ))
    dr.set_symm(IndexedBase('rho'), Perm([1, 0, 3, 2]), valence=4)

    dr.set_tensor_method('get_one', lambda x: 1)

    return dr
def test_optimization_handles_coeffcients(spark_ctx):
    """Optimize scalar intermediates that are scaled by coefficients.

    The case comes from PoST theory.  It exercises the optimization of tensor
    evaluations in which scalar intermediates are scaled by a factor.
    """
    dr = Drudge(spark_ctx)

    a, b = symbols('a b')
    rng = Range('r', 0, symbols('n'))
    dr.set_dumms(rng, [a, b])
    dr.add_default_resolver(rng)

    # The base of the result is also labelled 'r', like the range.
    res_base = IndexedBase('r')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')

    lhs = res_base[a, b]
    rhs = dr.sum(2 * eps[a] * t[a, b]) - 2 * eps[b] * t[a, b]
    targets = [dr.define(lhs, rhs)]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_optimization_handles_nonlinear_factors(spark_ctx):
    """Optimize an expression with nonlinear factors.

    One of the terms contains squares of an indexed quantity.
    """
    dr = Drudge(spark_ctx)

    rng = Range('r', 0, symbols('n'))
    all_dumms = symbols('a b c d e f g h')
    dr.set_dumms(rng, all_dumms)
    dr.add_default_resolver(rng)
    a, b, c, d = all_dumms[:4]

    u = symbols('u')
    s = IndexedBase('s')

    # The two summed terms: one with squared factors, one fully linear.
    squared_term = 32 * s[a, c]**2 * s[b, d]**2
    linear_term = 32 * s[a, c] * s[a, d] * s[b, c] * s[b, d]
    sums = [(dumm, rng) for dumm in (a, b, c, d)]

    targets = [dr.define(u, dr.sum(*sums, squared_term + linear_term))]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_basic_handling_range_with_variable_bounds(spark_ctx):
    """Check the handling of ranges whose bounds are variable.

    The example loosely resembles angular momentum handling in quantum
    physics.  Only the basic operations are covered: resetting dummies and
    mapping scalar functions onto the range bounds.
    """
    dr = Drudge(spark_ctx)

    j1, j2 = symbols('j1 j2')
    m1, m2 = symbols('m1, m2')
    j_max = symbols('j_max')
    j = Range('j', 0, j_max)
    m = Range('m')
    dr.set_dumms(j, [j1, j2])
    dr.set_dumms(m, [m1, m2])

    v = Vec('v')
    x = IndexedBase('x')

    # The m summation is bounded by the j dummy.
    tensor = dr.sum((j2, j), (m2, m[0, j2]), x[j2, m2] * v[j2, m2])

    reset = tensor.reset_dumms()
    assert reset.n_terms == 1
    term = reset.local_terms[0]
    assert len(term.sums) == 2

    # The two summations may come in either order.
    first, second = term.sums
    if first[1].label == 'j':
        j_sum, m_sum = first, second
    else:
        j_sum, m_sum = second, first

    j_dumm, j_range = j_sum
    assert j_dumm == j1
    assert j_range.args == j.args

    m_dumm, m_range = m_sum
    assert m_dumm == m1
    assert m_range.label == 'm'
    assert m_range.lower == 0
    assert m_range.upper == j1  # Important!

    assert term.amp == x[j1, m1]
    assert term.vecs == (v[j1, m1], )

    # Test that functions can be mapped to the bounds.
    repled = reset.map2scalars(
        lambda expr: expr.xreplace({j_max: 10}), skip_ranges=False
    )
    assert repled.n_terms == 1
    term = repled.local_terms[0]

    j_ranges_seen = 0
    for _, sum_range in term.sums:
        if sum_range.label != 'j':
            continue
        assert sum_range.lower == 0
        assert sum_range.upper == 10
        j_ranges_seen += 1
    # At least one j summation must have been checked.
    assert j_ranges_seen > 0
def simple_drudge(spark_ctx):
    """Create a minimal drudge with one range and its dummies.

    The drudge has a single range ``R`` running from 0 to the symbol ``n``
    with dummies ``a`` through ``g``.

    :param spark_ctx: The Spark context handed to the drudge.
    :return: The configured drudge.
    """
    dr = Drudge(spark_ctx)

    rng = Range('R', 0, Symbol('n'))
    dr.set_dumms(rng, symbols('a b c d e f g'))
    dr.add_resolver_for_dumms()

    return dr
def test_disconnected_outer_product_factorization(spark_ctx):
    """Optimize an expression containing disconnected outer products."""
    # Basic context setting-up.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    rng = Range('R', 0, n)
    all_dumms = symbols('a b c d e f g')
    dr.set_dumms(rng, all_dumms)
    dr.add_resolver_for_dumms()
    a, b, c, _, e = all_dumms[:5]

    # The indexed bases.
    u = IndexedBase('U')
    x = IndexedBase('X')
    y = IndexedBase('Y')
    z = IndexedBase('Z')
    t = IndexedBase('T')

    # The target: both terms share the factor U[a, b] Z[c, e].
    first_term = u[a, b] * z[c, e] * x[e, c]
    second_term = u[a, b] * z[c, e] * y[e, c]
    targets = [dr.define_einst(t[a, b], first_term + second_term)]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)

    # Test the cost; the total and the leading cost are expected to agree.
    total_cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    expected_cost = 4 * n ** 2
    assert total_cost == expected_cost
    assert leading_cost == expected_cost
def three_ranges(spark_ctx):
    """Create a drudge with three ranges.

    The ranges are labelled M, N and L, with sizes m, n and l respectively.
    The drudge also gets a substitution dictionary setting n = 2m and l = 3m.

    :param spark_ctx: The Spark context handed to the drudge.
    :return: The configured drudge.
    """
    dr = Drudge(spark_ctx)

    # The sizes.
    m_size, n_size, l_size = symbols('m n l')

    # Each range: its label, its size, and the names of its dummies.
    range_specs = (
        ('M', m_size, 'a b c d e f g'),
        ('N', n_size, 'i j k l m n'),
        ('L', l_size, 'p q r'),
    )
    for label, size, dumm_names in range_specs:
        dr.set_dumms(Range(label, 0, size), symbols(dumm_names))
    dr.add_resolver_for_dumms()

    dr.set_name(m_size, n_size, l_size)
    dr.substs = {n_size: m_size * 2, l_size: m_size * 3}

    return dr
def test_conjugation_optimization(spark_ctx):
    """Optimize an expression that contains complex conjugates."""
    dr = Drudge(spark_ctx)

    rng = Range('r', 0, symbols('n'))
    all_dumms = symbols('a b c d')
    dr.set_dumms(rng, list(all_dumms))
    dr.add_default_resolver(rng)
    a, b, c = all_dumms[:3]

    p = IndexedBase('p')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')

    # Both terms share the factor x[a, c] and differ in the conjugated one.
    with_y = x[a, c] * conjugate(y[c, b])
    with_z = x[a, c] * conjugate(z[c, b])
    targets = [dr.define_einst(p[a, b], with_y + with_z)]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_sums_can_be_expanded(spark_ctx):
    """Test the summation expansion facility.

    Here we essentially have a direct product of two ranges and expand a
    summation over the composite range into summations over the components.
    The usage also includes some preliminary steps typical of this usage
    paradigm.
    """
    # Imported locally so that the component functions below are defined
    # regardless of what the module imports from sympy.
    from sympy import Function

    dr = Drudge(spark_ctx)

    comp = Range('P')
    r1, r2 = symbols('r1, r2')
    dr.set_dumms(comp, [r1, r2])

    a = IndexedBase('a')
    v = Vec('v')

    # A simple thing written in terms of composite indices.
    orig = dr.sum((r1, comp), (r2, comp), a[r1] * a[r2] * v[r1] * v[r2])

    # Undefined functions extracting the two components of a composite index.
    # They have to be bound here: the substitution and the expansion below
    # both call ``x(...)`` and ``y(...)``.
    x = Function('x')
    y = Function('y')

    # Rewrite the expression in terms of components.  Here, r1 should be
    # construed as a simple Wild.
    rewritten = orig.subst_all([
        (a[r1], a[x(r1), y(r1)]),
        (v[r1], v[x(r1), y(r1)]),
    ])

    # Expand the summation over r.
    x_dim = Range('X')
    y_dim = Range('Y')
    x1, x2 = symbols('x1 x2')
    dr.set_dumms(x_dim, [x1, x2])
    y1, y2 = symbols('y1 y2')
    dr.set_dumms(y_dim, [y1, y2])

    # Each composite dummy rN is expanded into the pair xN, yN, with the
    # component functions x(rN), y(rN) giving the forms to be replaced.
    res = rewritten.expand_sums(
        comp,
        lambda r: [
            (Symbol(str(r).replace('r', 'x')), x_dim, x(r)),
            (Symbol(str(r).replace('r', 'y')), y_dim, y(r)),
        ],
    )

    expected = dr.sum(
        (x1, x_dim), (y1, y_dim), (x2, x_dim), (y2, y_dim),
        a[x1, y1] * a[x2, y2] * v[x1, y1] * v[x2, y2]
    )
    assert (res - expected).simplify() == 0
def test_matrix_chain(spark_ctx):
    """Test a basic matrix chain multiplication problem.

    The matrix chain multiplication problem is the classical problem that
    motivated the algorithm for single-term optimization in this package.  So
    here a very simple chain of three matrices is used to test the
    factorization facilities.

    The three matrices :math:`x`, :math:`y`, and :math:`z` have shapes
    :math:`m\\times n`, :math:`n \\times l`, and :math:`l \\times m`
    respectively.  In the factorization, :math:`n = 2 m` and :math:`l = 3 m`
    are set.

    If the first two matrices are multiplied first, the cost will be (two
    times)

    .. math::

        m n l + m^2 l

    Or if the last two matrices are multiplied first, the cost will be (two
    times)

    .. math::

        m n l + m^2 n

    In addition to the classical matrix chain product, the trace of their
    cyclic product is also of interest.

    .. math::

        t = \\sum_i \\sum_j \\sum_k x_{i, j} y_{j, k} z_{k, i}

    If the product :math:`Y Z` is taken first, the cost will be (two times)
    :math:`n m l + n m`.  For first multiplying :math:`X Y` and :math:`Z X`,
    the costs will be (two times) :math:`n m l + m l` and
    :math:`n m l + n l` respectively.
    """
    #
    # Basic context setting-up.
    #
    dr = Drudge(spark_ctx)

    # The sizes.
    m, n, l = symbols('m n l')

    # The ranges, each with its label, size, and dummy names.
    range_specs = (
        ('M', m, 'a b c'),
        ('N', n, 'i j k'),
        ('L', l, 'p q r'),
    )
    for label, size, dumm_names in range_specs:
        dr.set_dumms(Range(label, 0, size), symbols(dumm_names))
    dr.add_resolver_for_dumms()

    # The indexed bases.
    x = IndexedBase('x', shape=(m, n))
    y = IndexedBase('y', shape=(n, l))
    z = IndexedBase('z', shape=(l, m))

    # The costs substitution.
    substs = {n: m * 2, l: m * 3}

    #
    # Actual tests.
    #
    names = dr.names
    lhs = IndexedBase('t')[names.a, names.b]
    rhs = x[names.a, names.i] * y[names.i, names.p] * z[names.p, names.b]
    target = dr.define_einst(lhs, rhs)

    # Perform the factorization.
    targets = [target]
    eval_seq = optimize(targets, substs=substs)
    assert len(eval_seq) == 2

    # Check the correctness.
    assert verify_eval_seq(eval_seq, targets)

    # Check the cost; the total and the leading cost are expected to agree.
    total_cost = get_flop_cost(eval_seq)
    leading_cost = get_flop_cost(eval_seq, leading=True)
    expected_cost = 2 * l * m * n + 2 * m**2 * n
    assert total_cost == expected_cost
    assert leading_cost == expected_cost
def test_removal_of_shallow_interms(spark_ctx):
    """Test removal of shallow intermediates.

    Here we have two intermediates,

    .. math::

        U X V + U Y V

    and

    .. math::

        U X W + U Y W

    and it has been deliberately made such that the multiplication with U
    should be carried out first.  Then after the collection of U, we have a
    shallow intermediate U (X + Y), which is a sum of a single product
    intermediate.

    This test succeeds when we have two intermediates only, without the
    shallow ones.
    """
    # Basic context setting-up.
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)
    # The small range gets its own label: it has different bounds from R, and
    # reusing one label for ranges with different bounds is error-prone
    # (NOTE(review): drudge's Range documentation advises against it, as far
    # as I recall; please confirm).  The exact form Rational(1, 2) is used
    # rather than going through the float 1 / 2.
    r_small = Range('S', 0, Rational(1, 2) * n)

    dumms = symbols('a b c d')
    a, b, c = dumms[:3]
    dumms_small = symbols('e f g h')
    e = dumms_small[0]
    dr.set_dumms(r, dumms)
    dr.set_dumms(r_small, dumms_small)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    u = IndexedBase('U')
    v = IndexedBase('V')
    w = IndexedBase('W')
    x = IndexedBase('X')
    y = IndexedBase('Y')
    s = IndexedBase('S')
    t = IndexedBase('T')

    # The target.  The index e contracted with V and W is a dummy of the
    # small range.
    s_def = dr.define_einst(s[a, b], u[a, c] * x[c, b] + u[a, c] * y[c, b])
    targets = [
        dr.define_einst(t[a, b], s_def[a, e] * v[e, b]),
        dr.define_einst(t[a, b], s_def[a, e] * w[e, b]),
    ]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 4

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)
def test_optimization_of_common_terms(spark_ctx):
    """Test optimization of common terms in summations.

    In this test, there are just two matrices involved, X, Y.  The target
    reads

    .. math::

        T[a, b] = X[a, b] - X[b, a] + 2 Y[a, b] - 2 Y[b, a]

    Ideally, it should be evaluated as,

    .. math::

        I[a, b] = X[a, b] + 2 Y[a, b]
        T[a, b] = I[a, b] - I[b, a]

    or,

    .. math::

        I[a, b] = X[a, b] - 2 Y[b, a]
        T[a, b] = I[a, b] - I[b, a]

    Here, in order to emulate real cases where common term reference is in
    interplay with factorization, the X and Y matrices are written as
    :math:`X = S U` and :math:`Y = S V`.
    """
    #
    # Basic context setting-up.
    #
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c = dumms[:3]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    s = IndexedBase('S')
    u = IndexedBase('U')
    v = IndexedBase('V')

    x = dr.define(IndexedBase('X')[a, b], s[a, c] * u[c, b])
    y = dr.define(IndexedBase('Y')[a, b], s[a, c] * v[c, b])
    t = dr.define_einst(
        IndexedBase('t')[a, b],
        x[a, b] - x[b, a] + 2 * y[a, b] - 2 * y[b, a]
    )

    targets = [t]
    eval_seq = optimize(targets)
    assert len(eval_seq) == 3

    # Check correctness.  The result has to be asserted, as in the other
    # tests; otherwise an incorrect evaluation sequence goes unnoticed.
    assert verify_eval_seq(eval_seq, targets)

    # Check cost.
    cost = get_flop_cost(eval_seq)
    assert cost == 2 * n**3 + 2 * n**2
    cost = get_flop_cost(eval_seq, ignore_consts=False)
    assert cost == 2 * n**3 + 3 * n**2

    # Check the result when the common symmetrization optimization is disabled.
    eval_seq = optimize(targets, strategy=Strategy.DEFAULT & ~Strategy.COMMON)
    assert verify_eval_seq(eval_seq, targets)
    new_cost = get_flop_cost(eval_seq, ignore_consts=True)
    # NOTE(review): ``cost`` here was computed with ignore_consts=False while
    # ``new_cost`` uses ignore_consts=True, so the two are not measured the
    # same way.  Left unchanged; confirm which comparison is intended.
    assert new_cost - cost != 0
def test_matrix_factorization(spark_ctx):
    """Test a basic matrix multiplication factorization problem.

    Four matrices, X, Y, U, and V, are involved.  They are used in two test
    cases covering different scenarios.
    """
    #
    # Basic context setting-up.
    #
    dr = Drudge(spark_ctx)

    n = Symbol('n')
    rng = Range('R', 0, n)
    all_dumms = symbols('a b c d e f g')
    dr.set_dumms(rng, all_dumms)
    dr.add_resolver_for_dumms()
    a, b, c, d = all_dumms[:4]

    # The indexed bases.
    x = IndexedBase('X')
    y = IndexedBase('Y')
    u = IndexedBase('U')
    v = IndexedBase('V')
    t = IndexedBase('T')

    #
    # Test case 1.
    #
    # The final expression to optimize is mathematically
    #
    # .. math::
    #
    #     (2 X - Y) * (2 U + V)
    #
    # The expression is given in its expanded form, and we test if it can be
    # factorized into something similar to the form above.  The signs and
    # coefficients are there for better code coverage.  This case concentrates
    # on the horizontal complexity in the input.
    #
    expanded = (
        4 * x[a, c] * u[c, b] + 2 * x[a, c] * v[c, b]
        - 2 * y[a, c] * u[c, b] - y[a, c] * v[c, b]
    )
    targets = [dr.define_einst(t[a, b], expanded)]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=False)

    # Test the cost.
    plain_cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert plain_cost == 2 * n**3 + 2 * n**2
    assert leading_cost == 2 * n**3
    full_cost = get_flop_cost(res, ignore_consts=False)
    assert full_cost == 2 * n**3 + 4 * n**2

    #
    # Test case 2.
    #
    # The final expression to optimize is mathematically
    #
    # .. math::
    #
    #     (X - 2 Y) * U * V
    #
    # Different from the first case, this one concentrates on the treatment
    # of depth complexity in the input.  The sum intermediate needs to be
    # factored again.
    #
    chain = x[a, c] * u[c, d] * v[d, b] - 2 * y[a, c] * u[c, d] * v[d, b]
    targets = [dr.define_einst(t[a, b], chain)]

    # The actual optimization.
    res = optimize(targets)
    assert len(res) == 3

    # Test the correctness.
    assert verify_eval_seq(res, targets, simplify=True)

    # Test the cost.
    plain_cost = get_flop_cost(res)
    leading_cost = get_flop_cost(res, leading=True)
    assert plain_cost == 4 * n**3 + n**2
    assert leading_cost == 4 * n**3
    full_cost = get_flop_cost(res, ignore_consts=False)
    assert full_cost == 4 * n**3 + 2 * n**2

    # Test disabling summation optimization.
    res = optimize(targets, strategy=Strategy.BEST)
    assert verify_eval_seq(res, targets, simplify=True)
    best_only_cost = get_flop_cost(res, ignore_consts=False)
    assert best_only_cost - full_cost != 0