def test_optimization_handles_scalar_intermediates(spark_ctx):
    """Test optimization of scalar intermediates scaling other tensors.

    This is set as a special test primarily since it would entail the same
    collectible giving residues with different ranges.
    """
    dr = Drudge(spark_ctx)

    # One symbolic range with a pool of dummy indices.
    n = symbols('n')
    r = Range('r', 0, n)
    dumms = symbols('a b c d e')
    dr.set_dumms(r, dumms)
    a, b, c = dumms[:3]
    dr.add_default_resolver(r)

    u = IndexedBase('u')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')
    s = IndexedBase('s')

    # Both terms share the scalar factor s[a, b]; the second term has no
    # summation, so the common collectible spans different ranges.
    summed_term = dr.sum((c, r), 8 * s[a, b] * eps[c] * t[a])
    target = dr.define(u, (a, r), (b, r), summed_term - 8 * s[a, b] * eps[a] * t[a])
    targets = [target]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_optimization_handles_coefficients(spark_ctx):
    """Test optimization of scalar intermediates scaled by coefficients.

    This test comes from PoST theory.  It tests the optimization of tensor
    evaluations with scalar intermediates scaled by a factor.
    """
    dr = Drudge(spark_ctx)

    # Symbolic range and its dummy indices.  Named ``rng`` so it is not
    # shadowed by the ``r`` indexed base defined below.
    n = symbols('n')
    rng = Range('r', 0, n)
    a, b = symbols('a b')
    dr.set_dumms(rng, [a, b])
    dr.add_default_resolver(rng)

    r = IndexedBase('r')
    eps = IndexedBase('epsilon')
    t = IndexedBase('t')

    # Two terms differ only in which index of ``t`` carries the epsilon
    # scalar; both are scaled by the coefficient 2.
    targets = [
        dr.define(
            r[a, b],
            dr.sum(2 * eps[a] * t[a, b]) - 2 * eps[b] * t[a, b]
        )
    ]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_optimization_handles_nonlinear_factors(spark_ctx):
    """Test optimization of sums with nonlinear factors.

    Here a factor is the square of an indexed quantity.
    """
    dr = Drudge(spark_ctx)

    # One symbolic range with a pool of dummy indices.
    n = symbols('n')
    r = Range('r', 0, n)
    dumms = symbols('a b c d e f g h')
    dr.set_dumms(r, dumms)
    a, b, c, d = dumms[:4]
    dr.add_default_resolver(r)

    u = symbols('u')
    s = IndexedBase('s')

    # A fully-contracted scalar: the first term contains squared factors,
    # the second the corresponding cross terms.
    targets = [
        dr.define(
            u,
            dr.sum(
                (a, r), (b, r), (c, r), (d, r),
                32 * s[a, c] ** 2 * s[b, d] ** 2
                + 32 * s[a, c] * s[a, d] * s[b, c] * s[b, d]
            )
        )
    ]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)
def test_drs_tensor_def_dispatch(spark_ctx):
    """Tests the dispatch to drudge for tensor definitions."""
    dr = Drudge(spark_ctx)
    names = dr.names

    i_symb = Symbol('i')
    x = IndexedBase('x')
    rhs = x[i_symb]
    dr.add_default_resolver(Range('R'))

    a = DrsSymbol(dr, 'a')
    i = DrsSymbol(dr, 'i')

    # Exercise both a scalar LHS and an indexed LHS.
    for lhs in [a, a[i]]:
        expected = dr.define(lhs, rhs)

        # The ``<=`` operator defines without registering any name.
        defn = lhs <= rhs
        assert defn == expected
        assert not hasattr(names, 'a')
        assert not hasattr(names, '_a')

        # ``def_as`` defines *and* registers the name in the archive.
        defn = lhs.def_as(rhs)
        assert defn == expected
        assert names.a == expected
        registered = Symbol('a') if isinstance(lhs, DrsSymbol) else IndexedBase('a')
        assert names._a == registered

        dr.unset_name(defn)
def test_conjugation_optimization(spark_ctx):
    """Test optimization of expressions containing complex conjugate.
    """
    dr = Drudge(spark_ctx)

    # Symbolic range and its dummy indices.
    n = symbols('n')
    r = Range('r', 0, n)
    a, b, c, d = symbols('a b c d')
    dr.set_dumms(r, [a, b, c, d])
    dr.add_default_resolver(r)

    p = IndexedBase('p')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')

    # Both terms share the x factor; the conjugated factors differ.
    rhs = x[a, c] * conjugate(y[c, b]) + x[a, c] * conjugate(z[c, b])
    targets = [dr.define_einst(p[a, b], rhs)]

    eval_seq = optimize(targets)
    assert verify_eval_seq(eval_seq, targets)