def _split_backward_equations(self, backward_assignments, subexp_symgen):
    """Split each backward assignment into a main assignment plus shared subexpressions.

    Each entry of ``backward_assignments`` is paired with the stencil direction at
    the same position in ``self.stencil``. It is then handed to
    ``self._split_backward_equations_recursive``. All of those calls share one
    subexpression list and one coefficient dict, so work done for one direction
    is visible when processing the next.

    Args:
        backward_assignments: one assignment per stencil direction, consumed in order.
        subexp_symgen: symbol generator forwarded to the recursive helper and to the
            resulting ``AssignmentCollection``.

    Returns:
        An ``AssignmentCollection`` whose main assignments keep stencil order. Its
        subexpressions have been topologically sorted.
    """
    # Shared accumulators. The recursive helper presumably appends new
    # subexpressions here and caches coefficients it has already seen.
    # NOTE(review): confirm against _split_backward_equations_recursive.
    subexpressions = []
    coeff_cache = {}

    # The comprehension runs the helper calls in stencil order, so the shared
    # accumulators are filled exactly as a sequential loop would fill them.
    main_assignments = [
        self._split_backward_equations_recursive(
            assignment, subexpressions, direction, subexp_symgen, coeff_cache)
        for assignment, direction in zip(backward_assignments, self.stencil)
    ]

    collection = AssignmentCollection(main_assignments,
                                      subexpressions=subexpressions,
                                      subexpression_symbol_generator=subexp_symgen)
    # Reorder only the subexpressions; main assignments stay aligned with the stencil.
    collection.topological_sort(sort_main_assignments=False)
    return collection
def test_assignment_collection():
    """Exercise core AssignmentCollection operations.

    Covers subexpression creation, topological sorting, subexpression inlining
    and elimination, and the str/HTML representations.
    """
    ac = AssignmentCollection([Assignment(z, x + y)], [], subexpression_symbol_generator=symbol_gen)

    # add_subexpression names the new subexpression with the next symbol from symbol_gen.
    lhs = ac.add_subexpression(t)
    assert lhs == sp.Symbol("a_0")

    # The new subexpression refers to `t`, so once `t = 3` is added, sorting
    # must place the definition of `t` first.
    ac.subexpressions.append(Assignment(t, 3))
    ac.topological_sort(sort_main_assignments=False, sort_subexpressions=True)
    assert ac.subexpressions[0].lhs == t

    # Inserting a symbol that no subexpression defines must leave the collection unchanged.
    assert ac.new_with_inserted_subexpression(sp.Symbol("not_defined")) == ac

    # Inlining `t` should match substituting away every subexpression except `lhs`.
    ac_inserted = ac.new_with_inserted_subexpression(t)
    ac_inserted2 = ac.new_without_subexpressions({lhs})
    # Compare whole lists, not zip() pairs: zip truncates to the shorter side and
    # would let a missing or extra assignment pass unnoticed.
    assert list(ac_inserted.all_assignments) == list(ac_inserted2.all_assignments)

    assert ac_inserted.subexpressions[0] == Assignment(lhs, 3)
    assert 'a_0' in str(ac_inserted)
    assert '<table' in ac_inserted._repr_html_()