def test_optimization_handles_scalar_intermediates(spark_ctx):
    """Test optimization of scalar intermediates scaling other tensors.

    This is kept as a dedicated test mainly because it makes the same
    collectible produce residues over different ranges.
    """

    dr = Drudge(spark_ctx)

    size = symbols('n')
    rng = Range('r', 0, size)
    dummies = symbols('a b c d e')
    dr.set_dumms(rng, dummies)
    dr.add_default_resolver(rng)
    a, b, c = dummies[:3]

    u = IndexedBase('u')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')
    s = IndexedBase('s')

    # The summed term and the unsummed term share the factor s[a, b] * t[a],
    # differing only in which index epsilon carries.
    summed = dr.sum((c, rng), 8 * s[a, b] * eps[c] * t[a])
    correction = 8 * s[a, b] * eps[a] * t[a]
    targets = [dr.define(u, (a, rng), (b, rng), summed - correction)]

    assert verify_eval_seq(optimize(targets), targets)
def test_optimization_handles_coeffcients(spark_ctx):
    """Test optimization of scalar intermediates scaled by coefficients.

    This test comes from PoST theory.  It tests the optimization of tensor
    evaluations with scalar intermediates scaled by a factor.
    """

    dr = Drudge(spark_ctx)

    n = symbols('n')
    r = Range('r', 0, n)
    a, b = symbols('a b')
    dr.set_dumms(r, [a, b])
    dr.add_default_resolver(r)

    # The target base gets its own Python name so that it does not shadow the
    # range ``r`` above; its printed label stays ``r`` as before.
    res = IndexedBase('r')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')

    targets = [
        dr.define(
            res[a, b],
            dr.sum(2 * eps[a] * t[a, b]) - 2 * eps[b] * t[a, b]
        )
    ]
    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_optimization_handles_nonlinear_factors(spark_ctx):
    """Test optimization with nonlinear factors.

    Here a factor is the square of an indexed quantity.
    """

    dr = Drudge(spark_ctx)

    size = symbols('n')
    rng = Range('r', 0, size)
    dummies = symbols('a b c d e f g h')
    dr.set_dumms(rng, dummies)
    dr.add_default_resolver(rng)
    a, b, c, d = dummies[:4]

    u = symbols('u')
    s = IndexedBase('s')

    # All four dummies are summed over the same range.
    sums = [(idx, rng) for idx in (a, b, c, d)]
    squared = 32 * s[a, c] ** 2 * s[b, d] ** 2
    crossed = 32 * s[a, c] * s[a, d] * s[b, c] * s[b, d]
    targets = [dr.define(u, dr.sum(*sums, squared + crossed))]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_drs_tensor_def_dispatch(spark_ctx):
    """Tests the dispatch to drudge for tensor definitions."""

    dr = Drudge(spark_ctx)
    names = dr.names

    rhs = IndexedBase('x')[Symbol('i')]
    dr.add_default_resolver(Range('R'))

    a = DrsSymbol(dr, 'a')
    i = DrsSymbol(dr, 'i')

    # Exercise both a plain symbol and an indexed left-hand side.
    for lhs in [a, a[i]]:
        expected = dr.define(lhs, rhs)

        # ``<=`` yields the definition but must not add ``a``/``_a`` to names.
        def_ = lhs <= rhs
        assert def_ == expected
        for attr in ('a', '_a'):
            assert not hasattr(names, attr)

        # ``def_as`` yields the same definition and does register the names.
        def_ = lhs.def_as(rhs)
        assert def_ == expected
        assert names.a == expected
        base = Symbol('a') if isinstance(lhs, DrsSymbol) else IndexedBase('a')
        assert names._a == base

        # Clear the registered names before the next iteration.
        dr.unset_name(def_)
def test_simple_scalar_optimization(spark_ctx):
    """Test optimization of a simple scalar.

    Little can be optimized for a simple scalar, but the optimizer must still
    produce a correct evaluation sequence for it.
    """

    dr = Drudge(spark_ctx)

    a, b, r = symbols('a b r')
    targets = [dr.define(r, a * b)]

    assert verify_eval_seq(optimize(targets), targets)
def test_optimization_of_common_terms(spark_ctx):
    """Test optimization of common terms in summations.

    In this test, there are just two matrices involved, X, Y.  The target
    reads

    .. math::

        T[a, b] = X[a, b] - X[b, a] + 2 Y[a, b] - 2 Y[b, a]

    Ideally, it should be evaluated as,

    .. math::

        I[a, b] = X[a, b] + 2 Y[a, b]
        T[a, b] = I[a, b] - I[b, a]

    or,

    .. math::

        I[a, b] = X[a, b] - 2 Y[b, a]
        T[a, b] = I[a, b] - I[b, a]

    Here, in order to emulate real cases where common term reference is in
    interplay with factorization, the X and Y matrices are written as
    :math:`X = S U` and :math:`Y = S V`.
    """

    #
    # Basic context setting-up.
    #

    dr = Drudge(spark_ctx)

    n = Symbol('n')
    r = Range('R', 0, n)

    dumms = symbols('a b c d e f g')
    a, b, c, d = dumms[:4]
    dr.set_dumms(r, dumms)
    dr.add_resolver_for_dumms()

    # The indexed bases.
    s = IndexedBase('S')
    u = IndexedBase('U')
    v = IndexedBase('V')

    x = dr.define(IndexedBase('X')[a, b], s[a, c] * u[c, b])
    y = dr.define(IndexedBase('Y')[a, b], s[a, c] * v[c, b])
    t = dr.define_einst(
        IndexedBase('t')[a, b],
        x[a, b] - x[b, a] + 2 * y[a, b] - 2 * y[b, a]
    )

    targets = [t]
    eval_seq = optimize(targets)
    assert len(eval_seq) == 3

    # Check correctness.  The result must be asserted; a bare call would let
    # an incorrect evaluation sequence pass silently.
    assert verify_eval_seq(eval_seq, targets)

    # Check cost.  ``cost`` keeps the default count, which is lower than the
    # ``ignore_consts=False`` count, so constant factors are not counted.
    cost = get_flop_cost(eval_seq)
    assert cost == 2 * n**3 + 2 * n**2
    cost_with_consts = get_flop_cost(eval_seq, ignore_consts=False)
    assert cost_with_consts == 2 * n**3 + 3 * n**2

    # Check the result when the common symmetrization optimization is
    # disabled.  It must still be correct, and its cost, also counted with
    # constant factors ignored, must differ from the optimized ``cost``.
    eval_seq = optimize(targets, strategy=Strategy.DEFAULT & ~Strategy.COMMON)
    assert verify_eval_seq(eval_seq, targets)
    new_cost = get_flop_cost(eval_seq, ignore_consts=True)
    assert new_cost - cost != 0