def add(condition, dimensions, as_else_block=False): nonlocal last_conditional if staggered_field.index_dimensions == 1: assignments = [ Assignment(staggered_field(d), expressions[d]) for d in dimensions ] a_coll = AssignmentCollection(assignments, list(subexpressions)) a_coll = a_coll.new_filtered( [staggered_field(d) for d in dimensions]) elif staggered_field.index_dimensions == 2: assert staggered_field.has_fixed_index_shape assignments = [ Assignment(staggered_field(d, i), expr) for d in dimensions for i, expr in enumerate(expressions[d]) ] a_coll = AssignmentCollection(assignments, list(subexpressions)) a_coll = a_coll.new_filtered([ staggered_field(d, i) for i in range(staggered_field.index_shape[1]) for d in dimensions ]) sp_assignments = [ SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments ] if as_else_block and last_conditional: new_cond = Conditional(condition, Block(sp_assignments)) last_conditional.false_block = Block([new_cond]) last_conditional = new_cond else: last_conditional = Conditional(condition, Block(sp_assignments)) final_assignments.append(last_conditional)
def __init__(
        self,
        main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
        subexpressions: Optional[Union[List[Assignment], Dict[sp.Expr, sp.Expr]]] = None,
        simplification_hints: Optional[Dict[str, Any]] = None,
        subexpression_symbol_generator: Optional[Iterator[sp.Symbol]] = None
) -> None:
    """Create a collection of main assignments plus supporting subexpressions.

    Args:
        main_assignments: list of assignments, or a dict mapping lhs -> rhs which is
            converted to a list of ``Assignment`` in dict order.
        subexpressions: list of assignments or dict mapping lhs -> rhs; ``None``
            (the default) means no subexpressions.
        simplification_hints: optional dict of hints; an empty dict if ``None``.
        subexpression_symbol_generator: iterator yielding fresh symbols for new
            subexpressions; a ``SymbolGen()`` is created if ``None``.
    """
    # None instead of a shared dict-literal default; each instance gets its own list.
    if subexpressions is None:
        subexpressions = []

    if isinstance(main_assignments, dict):
        main_assignments = [
            Assignment(k, v) for k, v in main_assignments.items()
        ]
    if isinstance(subexpressions, dict):
        subexpressions = [
            Assignment(k, v) for k, v in subexpressions.items()
        ]

    self.main_assignments = main_assignments
    self.subexpressions = subexpressions

    if simplification_hints is None:
        simplification_hints = {}
    self.simplification_hints = simplification_hints

    if subexpression_symbol_generator is None:
        self.subexpression_symbol_generator = SymbolGen()
    else:
        self.subexpression_symbol_generator = subexpression_symbol_generator
def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
    """Inline the subexpression defining ``symbol`` and drop it from the collection.

    Every subexpression whose lhs equals ``symbol`` is removed. The rhs of the last such
    subexpression is substituted for ``symbol`` in all remaining subexpressions and in
    the main assignments. If no subexpression defines ``symbol``, ``self`` is returned
    unchanged (not a copy).
    """
    defining = [se for se in self.subexpressions if se.lhs == symbol]
    if not defining:
        return self

    last_definition = defining[-1]
    substitution = {last_definition.lhs: last_definition.rhs}
    kept = [se for se in self.subexpressions if not (se.lhs == symbol)]

    def substitute_rhs(equations):
        return [Assignment(eq.lhs, fast_subs(eq.rhs, substitution)) for eq in equations]

    inlined_subexpressions = substitute_rhs(kept)
    inlined_main = substitute_rhs(self.main_assignments)
    return self.copy(inlined_main, inlined_subexpressions)
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function on both sides of each element of the passed list.

    Elements that lack an ``lhs`` or an ``rhs`` attribute (e.g. AST nodes) are passed
    through unchanged. Additional positional/keyword arguments are forwarded to the
    transformation function. The lhs is transformed before the rhs.
    """
    def transform_one(item):
        if not (hasattr(item, 'lhs') and hasattr(item, 'rhs')):
            return item
        new_lhs = transformation(item.lhs, *args, **kwargs)
        new_rhs = transformation(item.rhs, *args, **kwargs)
        return Assignment(new_lhs, new_rhs)

    return list(map(transform_one, assignment_list))
def new_with_substitutions(
        self, substitutions: Dict,
        add_substitutions_as_subexpressions: bool = False,
        substitute_on_lhs: bool = True,
        sort_topologically: bool = True) -> 'AssignmentCollection':
    """Return a new collection with terms replaced according to ``substitutions``.

    Args:
        substitutions: dict passed to ``fast_subs``; substitutions are applied in both
            the main assignments and the subexpressions
        add_substitutions_as_subexpressions: if True, each entry ``old -> new`` is also
            prepended to the subexpressions as ``Assignment(new, old)``
        substitute_on_lhs: if False, substitutions are applied only to right hand sides
        sort_topologically: only relevant when substitutions are added as
            subexpressions; if True the subexpressions are then sorted topologically

    Returns:
        New AssignmentCollection where substitutions have been applied; self is not altered.
    """
    if substitute_on_lhs:
        transform = transform_lhs_and_rhs
    else:
        transform = transform_rhs

    new_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
    new_main = transform(self.main_assignments, fast_subs, substitutions)

    if not add_substitutions_as_subexpressions:
        return self.copy(new_main, new_subexpressions)

    substitution_assignments = [
        Assignment(replacement, original) for original, replacement in substitutions.items()
    ]
    new_subexpressions = substitution_assignments + new_subexpressions
    if sort_topologically:
        new_subexpressions = sort_assignments_topologically(new_subexpressions)
    return self.copy(new_main, new_subexpressions)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
    """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.

    A subexpression of ``other`` whose lhs also appears among ``self``'s subexpressions
    is dropped only if its rhs (after applying renamings made so far) equals ``self``'s
    definition. Otherwise it gets a fresh symbol from ``self.subexpression_symbol_generator``
    and all later uses in ``other`` are rewritten to that symbol.

    Raises:
        AssertionError: if both collections define the same main-assignment lhs.
    """
    own_definitions = {e.lhs for e in self.main_assignments}
    other_definitions = {e.lhs for e in other.main_assignments}
    assert not own_definitions & other_definitions, \
        "Cannot merge collections, since both define the same symbols"

    own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
    substitution_dict = {}

    processed_other_subexpression_equations = []
    for other_subexpression_eq in other.subexpressions:
        other_lhs = other_subexpression_eq.lhs
        if other_lhs in own_subexpression_symbols:
            # Apply earlier renamings BEFORE comparing: a textually equal rhs may refer to
            # an already-renamed symbol and thus mean something different.
            new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
            if new_rhs == own_subexpression_symbols[other_lhs]:
                continue  # exactly the same subexpression equation exists already
            # different definition - a new name has to be introduced
            new_lhs = next(self.subexpression_symbol_generator)
            processed_other_subexpression_equations.append(Assignment(new_lhs, new_rhs))
            substitution_dict[other_lhs] = new_lhs
        else:
            processed_other_subexpression_equations.append(
                fast_subs(other_subexpression_eq, substitution_dict))

    processed_other_main_assignments = [
        fast_subs(eq, substitution_dict) for eq in other.main_assignments
    ]
    return self.copy(self.main_assignments + processed_other_main_assignments,
                     self.subexpressions + processed_other_subexpression_equations)
def sympy_cse(ac, **kwargs):
    """Searches for common subexpressions inside the assignment collection.

    The search covers both the existing subexpressions and the main assignments, using
    SymPy's ``cse``. Objects that are not ``Assignment`` instances (e.g. AST nodes) are
    not passed to ``cse``; they are placed into the new subexpression list.

    Args:
        ac: assignment collection to process.
        **kwargs: forwarded to ``sympy.cse``.

    Returns:
        A new assignment collection with the additional subexpressions found.
    """
    symbol_gen = ac.subexpression_symbol_generator

    all_assignments = [
        e for e in chain(ac.subexpressions, ac.main_assignments)
        if isinstance(e, Assignment)
    ]
    other_objects = [
        e for e in chain(ac.subexpressions, ac.main_assignments)
        if not isinstance(e, Assignment)
    ]
    # cse only sees Assignment instances, so the result must be split at the number of
    # *Assignment* subexpressions, not at len(ac.subexpressions) (which may include
    # non-Assignment objects and would shift main assignments into the subexpressions).
    num_subexpression_assignments = sum(
        1 for e in ac.subexpressions if isinstance(e, Assignment))

    replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen, **kwargs)

    replacement_eqs = [Assignment(*r) for r in replacements]

    modified_subexpressions = new_eq[:num_subexpression_assignments]
    modified_update_equations = new_eq[num_subexpression_assignments:]

    new_subexpressions = sort_assignments_topologically(
        other_objects + replacement_eqs + modified_subexpressions)
    return ac.copy(modified_update_equations, new_subexpressions)
def to_placeholder_function(expr, name):
    """Wraps an expression into a placeholder SymPy function called ``name``.

    Substituting a plain symbol for an expression would lose the information needed to
    differentiate it; a placeholder function keeps that information. Its derivative with
    respect to an argument is the precomputed derivative when that is constant, and
    otherwise a symbol ``_d<name>_d<arg>`` whose definition is stored in the class
    attribute ``subexpressions``.

    Examples:
        >>> x, t = sp.symbols("x, t")
        >>> temperature = x**2 + t**4 # some 'complicated' dependency
        >>> temperature_placeholder = to_placeholder_function(temperature, 'T')
        >>> diffusivity = temperature_placeholder + 42 * t
        >>> sp.diff(diffusivity, t)  # returns a symbol instead of the computed derivative
        _dT_dt + 42
        >>> result, subexpr = remove_placeholder_functions(diffusivity)
        >>> result
        T + 42*t
        >>> subexpr
        [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]
    """
    # Arguments of the placeholder, ordered by name for a deterministic signature.
    symbols = sorted(expr.atoms(sp.Symbol), key=lambda s: s.name)

    derivative_symbols = [sp.Symbol("_d{}_d{}".format(name, s.name)) for s in symbols]
    derivatives = [sp.diff(expr, s) for s in symbols]

    value_symbol = sp.Symbol(name)
    subexpressions = [Assignment(value_symbol, expr)]
    for d_symbol, d_expr in zip(derivative_symbols, derivatives):
        if not is_constant(d_expr):
            subexpressions.append(Assignment(d_symbol, d_expr))

    def fdiff(_, index):
        # SymPy numbers arguments from 1.
        derivative = derivatives[index - 1]
        if is_constant(derivative):
            return derivative
        return derivative_symbols[index - 1]

    namespace = {
        'fdiff': fdiff,
        'value': value_symbol,
        'subexpressions': subexpressions,
        'nargs': len(symbols),
    }
    placeholder_class = type(name, (sp.Function, PlaceholderFunction), namespace)
    return placeholder_class(*symbols)
def add_subexpressions_for_constants(ac):
    """Moves numeric constant factors/summands into subexpressions.

    SymPy only factors common terms out of a sum when they are symbols. Replacing each
    constant operand of an ``Add``/``Mul`` (other than +-1) by a symbol, and then
    collecting on those symbols, lets common numeric factors be pulled out. That can
    reduce the number of multiplications and expose more common subexpressions.

    A negative constant ``c`` is replaced by ``-s`` where ``s`` stands for ``-c``.
    Symbols are drawn from ``ac.subexpression_symbol_generator`` on first use, main
    assignments being visited before subexpressions. The constant definitions are
    prepended to the returned collection's subexpressions.
    """
    constant_symbols = defaultdict(
        lambda: next(ac.subexpression_symbol_generator))

    def visit(expr):
        children = list(expr.args)
        if not children:
            return expr
        if isinstance(expr, (sp.Add, sp.Mul)):
            for idx, child in enumerate(children):
                if is_constant(child) and abs(child) != 1:
                    children[idx] = (-constant_symbols[-child] if child < 0
                                     else constant_symbols[child])
        return expr.func(*map(visit, children))

    def rewrite_rhs(assignments, rewrite):
        return [Assignment(a.lhs, rewrite(a.rhs)) for a in assignments]

    main_assignments = rewrite_rhs(ac.main_assignments, visit)
    subexpressions = rewrite_rhs(ac.subexpressions, visit)

    symbols_to_collect = set(constant_symbols.values())

    def collect(rhs):
        return recursive_collect(rhs, symbols_to_collect, True)

    main_assignments = rewrite_rhs(main_assignments, collect)
    subexpressions = rewrite_rhs(subexpressions, collect)

    constant_definitions = [Assignment(symb, c) for c, symb in constant_symbols.items()]
    return ac.copy(main_assignments=main_assignments,
                   subexpressions=constant_definitions + subexpressions)
def transform_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function to the rhs of each element of the passed list.

    Elements lacking an ``lhs`` or an ``rhs`` attribute (e.g. AST nodes) are kept as
    they are. Extra positional/keyword arguments are forwarded to ``transformation``.
    """
    transformed = []
    for item in assignment_list:
        if hasattr(item, 'lhs') and hasattr(item, 'rhs'):
            item = Assignment(item.lhs, transformation(item.rhs, *args, **kwargs))
        transformed.append(item)
    return transformed
def subexpression_substitution_in_main_assignments(ac):
    """Replaces occurrences of existing subexpression right-hand sides in the main assignments.

    For each main assignment, every subexpression is tried in order via ``subs_additive``
    (with ``required_match_replacement=1.0``). The result keeps the original
    subexpressions and only rewrites the main assignments.
    """
    def substitute_all(rhs):
        for sub_expr in ac.subexpressions:
            rhs = subs_additive(rhs, sub_expr.lhs, sub_expr.rhs,
                                required_match_replacement=1.0)
        return rhs

    rewritten = [Assignment(a.lhs, substitute_all(a.rhs)) for a in ac.main_assignments]
    return ac.copy(rewritten)
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).

    This makes the term longer, and ``simplify`` usually undoes it. It can however
    expose more common subexpressions.

    Args:
        expr: input expression
        search_symbols: symbols to look for; e.g. for [x, y, z], terms like x*y, x*z and
            z*y are replaced
        positive: selects between the (x+y)**2 form (True) and the (x-y)**2 form
            (False). If None, the sign of the remaining numeric factor of the mixed term
            decides.
        replace_mixed: if a list is passed, x+y or x-y is replaced by a new symbol, and
            the defining equation is appended to this list (once per symbol)
    """
    if replace_mixed is None:
        already_replaced = set()
    else:
        already_replaced = {e.lhs for e in replace_mixed}

    if expr.is_Mul:
        found = set()
        hits = 0
        remaining = sp.Integer(1)
        for factor in expr.args:
            if factor in search_symbols:
                hits += 1
                found.add(factor)
            else:
                remaining *= factor

        if hits == 2 and len(found) == 2:
            u, v = sorted(found, key=lambda symbol: symbol.name)
            if positive is None:
                # Drop all symbols so only the numeric part decides the sign.
                numeric_part = remaining
                for atom in remaining.atoms(sp.Symbol):
                    numeric_part = numeric_part.subs(atom, 1)
                positive = numeric_part.is_positive
            assert positive is not None
            sign = 1 if positive else -1

            if replace_mixed is None:
                mixed_symbol = u + sign * v
            else:
                suffix = 'P' if positive else 'M'
                mixed_symbol = sp.Symbol((u.name + suffix + v.name).replace("_", ""))
                if mixed_symbol not in already_replaced:
                    already_replaced.add(mixed_symbol)
                    replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
            return sp.Rational(1, 2) * sign * remaining * (mixed_symbol ** 2 - u ** 2 - v ** 2)

    new_args = [replace_second_order_products(arg, search_symbols, positive, replace_mixed)
                for arg in expr.args]
    if not new_args:
        return expr
    return expr.func(*new_args, evaluate=False)
def subexpression_substitution_in_existing_subexpressions(ac):
    """Rewrites each subexpression in terms of the subexpressions that precede it.

    For subexpression number k, subexpressions 0..k-1 are applied in order: first via
    ``subs_additive`` (``required_match_replacement=1.0``), then via a direct
    ``subs(rhs, lhs)``. Main assignments are passed through unchanged.
    """
    subexpressions = ac.subexpressions
    rewritten = []
    for position, current in enumerate(subexpressions):
        rhs = current.rhs
        for earlier in subexpressions[:position]:
            rhs = subs_additive(rhs, earlier.lhs, earlier.rhs, required_match_replacement=1.0)
            rhs = rhs.subs(earlier.rhs, earlier.lhs)
        rewritten.append(Assignment(current.lhs, rhs))
    return ac.copy(ac.main_assignments, rewritten)
def update_rule_with_push_boundaries(collision_rule, field, boundary_spec, streaming_pattern='pull',
                                     timestep=Timestep.BOTH):
    """Build an update rule from ``collision_rule`` with loads, stores and boundary conditionals.

    - PDF loads (read through the streaming-pattern accessor) are prepended to the
      subexpressions.
    - PDF stores and one ``boundary_conditional`` per entry of ``boundary_spec`` are
      appended to the main assignments.
    - If the rule carries ``'split_groups'`` simplification hints, post-collision PDF
      symbols inside them are replaced by the corresponding field writes.
    """
    method = collision_rule.method
    accessor = get_accessor(streaming_pattern, timestep)

    reads = accessor.read(field, method.stencil)
    loads = [Assignment(pdf, access) for pdf, access in zip(method.pre_collision_pdf_symbols, reads)]
    writes = accessor.write(field, method.stencil)
    stores = [Assignment(access, pdf) for access, pdf in zip(writes, method.post_collision_pdf_symbols)]

    result = collision_rule.copy()
    result.subexpressions = loads + result.subexpressions
    result.main_assignments += stores
    result.main_assignments.extend(
        boundary_conditional(boundary, direction, streaming_pattern, timestep, method, field)
        for direction, boundary in boundary_spec.items())

    hints = result.simplification_hints
    if 'split_groups' in hints:
        substitutions = {
            pdf: access
            for access, pdf in zip(accessor.write(field, method.stencil),
                                   method.post_collision_pdf_symbols)
        }
        hints['split_groups'] = [
            [fast_subs(e, substitutions) for e in group] for group in hints['split_groups']
        ]

    return result
def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU,
                           **kernel_creation_args):
    """Create an index-field driven kernel applying ``boundary_functor``.

    The kernel body consists of the stencil's ``BoundaryOffsetInfo``, a read of the
    ``'dir'`` entry of ``index_field[0]`` into an int64 ``dir`` symbol, and the
    statements returned by ``boundary_functor``. Extra keyword arguments are forwarded to
    ``CreateKernelConfig``.

    NOTE(review): SOURCE also contains another ``create_boundary_kernel`` with a
    different signature; if both live in the same module the later definition wins.
    """
    offset_info = BoundaryOffsetInfo(stencil)
    dir_symbol = TypedSymbol("dir", np.int64)
    read_direction = Assignment(dir_symbol, index_field[0]('dir'))
    boundary_code = boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
    elements = [offset_info, read_direction, *boundary_code]

    config = CreateKernelConfig(index_fields=[index_field], target=target, **kernel_creation_args)
    return create_kernel(elements, config=config)
def create_boundary_kernel(field, index_field, stencil, boundary_functor, target='cpu', openmp=True):
    """Create an indexed boundary kernel (legacy ``target``/``openmp`` interface).

    The ``dir`` symbol takes the dtype of the ``'dir'`` field of the index field's
    structured numpy dtype. ``openmp`` is forwarded as ``cpu_openmp``.

    NOTE(review): SOURCE also contains another ``create_boundary_kernel`` with a
    different signature; if both live in the same module the later definition wins.
    """
    offset_info = BoundaryOffsetInfo(stencil)
    record_dtype = index_field.dtype.numpy_dtype
    dir_symbol = TypedSymbol("dir", record_dtype.fields['dir'][0])
    read_direction = Assignment(dir_symbol, index_field[0]('dir'))
    boundary_code = boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
    elements = [offset_info, read_direction, *boundary_code]
    return create_indexed_kernel(elements, [index_field], target=target, cpu_openmp=openmp)
def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
    """Append a subexpression ``lhs = rhs`` to this collection (in place).

    Args:
        rhs: right hand side of the new subexpression
        lhs: optional left hand side; if None, the next symbol of
            ``self.subexpression_symbol_generator`` is used
        topological_sort: if True, subexpressions are re-sorted topologically afterwards
            so definitions precede uses; if False the new subexpression stays at the end

    Returns:
        the left hand side symbol that was used (possibly freshly generated)
    """
    target_symbol = next(self.subexpression_symbol_generator) if lhs is None else lhs
    self.subexpressions.append(Assignment(target_symbol, rhs))
    if topological_sort:
        self.topological_sort(sort_subexpressions=True, sort_main_assignments=False)
    return target_symbol
def apply_sympy_optimisations(assignments):
    """Evaluate constant terms numerically and apply SymPy's C99 rewriting optimisations.

    Matching subexpressions are replaced by ``evalf(17)`` (e.g. :math:`\\sqrt{3}` becomes its
    floating point value); then the optimisations in ``optims_c99`` are applied (see
    ``sympy.codegen.rewriting``).

    Elements with an ``lhs`` attribute are rebuilt as ``Assignment(lhs, optimised rhs)``;
    other elements are kept. Afterwards ``optimize`` is called on every
    ``SympyAssignment`` found in the elements' atoms.
    """
    def is_evaluable_constant(e):
        # NOTE(review): in SymPy ``Expr.is_constant`` is a method, so ``e.is_constant``
        # (not called) is a bound method and always truthy; this test therefore reduces
        # to ``not e.is_integer``. Kept as-is to preserve generated code - confirm whether
        # ``e.is_constant()`` or ``e.is_number`` was intended.
        return hasattr(e, 'is_constant') and e.is_constant and not e.is_integer

    def evaluate(p):
        return p.evalf(17)

    optimisations = [ReplaceOptim(is_evaluable_constant, evaluate), *optims_c99]

    optimised = []
    for a in assignments:
        if hasattr(a, 'lhs'):
            optimised.append(Assignment(a.lhs, optimize(a.rhs, optimisations)))
        else:
            optimised.append(a)

    # Collect all nested AST assignments first, then optimise them in place.
    nested_nodes = [a.atoms(SympyAssignment) for a in optimised]
    for node_set in nested_nodes:
        for node in node_set:
            node.optimize(optimisations)

    return optimised
def set_main_assignments_from_dict(self, main_assignments_dict):
    """Replace the main assignments with one ``Assignment(lhs, rhs)`` per dict item, in dict order."""
    new_assignments = []
    for lhs, rhs in main_assignments_dict.items():
        new_assignments.append(Assignment(lhs, rhs))
    self.main_assignments = new_assignments
def set_sub_expressions_from_dict(self, sub_expressions_dict):
    """Replace the subexpressions with one ``Assignment(lhs, rhs)`` per dict item, in dict order."""
    new_subexpressions = []
    for lhs, rhs in sub_expressions_dict.items():
        new_subexpressions.append(Assignment(lhs, rhs))
    self.subexpressions = new_subexpressions
def assignment_adder(lhs, rhs):
    """Append ``Assignment(lhs, rhs)`` to the ``assignments`` list from the surrounding scope."""
    new_assignment = Assignment(lhs, rhs)
    assignments.append(new_assignment)