def pick_off_constants(expr):
    """Split the constant multipliers of *expr* off from its other factors.

    Only products are taken apart; nested products are handled recursively.
    A factor is treated as constant if ``pp.is_constant`` accepts it or if
    it is a ``p.Parameter``.

    :return: a tuple ``(constant, non_constant)`` that separates the
        constant multipliers from all other nodes in *expr*. If *expr* is
        not a product, ``(1, expr)`` is returned.
    """
    if not isinstance(expr, pp.Product):
        return 1, expr

    const_factors = []
    other_factors = []

    for factor in expr.children:
        if isinstance(factor, pp.Product):
            # A nested product contributes to both sides.
            inner_const, inner_rest = pick_off_constants(factor)
            const_factors.append(inner_const)
            other_factors.append(inner_rest)
            continue

        is_const = pp.is_constant(factor) or isinstance(factor, p.Parameter)
        target = const_factors if is_const else other_factors
        target.append(factor)

    return (
        pp.flattened_product(const_factors),
        pp.flattened_product(other_factors))
def find_substitution(expr):
    """Substitution callback that multiplies the collected common factors
    back onto an access to *var_name*.

    :arg expr: the expression node offered for substitution.
    :return: *expr* unchanged unless it is a ``Variable`` or ``Subscript``
        referring to *var_name* whose index key unifies with a recorded
        entry of *common_factors*; in that case, the product of the
        (unification-mapped) common factors and *expr*.
    """
    if isinstance(expr, Subscript):
        v = expr.aggregate.name
    elif isinstance(expr, Variable):
        v = expr.name
    else:
        return expr

    if v != var_name:
        return expr

    index_key = extract_index_key(expr)
    cf_index, unif_result = find_unifiable_cf_index(index_key)

    if cf_index is None:
        # No recorded index key unifies with this access, so no factors
        # were pulled out for it and there is nothing to multiply back
        # in. (Previously this fell through to 'unif_result.lmap' with
        # unif_result being None.)
        return expr

    unif_subst_map = SubstitutionMapper(
            make_subst_func(unif_result.lmap))

    _, my_common_factors = common_factors[cf_index]

    if my_common_factors is not None:
        return flattened_product(
                [unif_subst_map(cf) for cf in my_common_factors]
                + [expr])
    else:
        return expr
def get_temporary_decl(self, codegen_state, sched_index, temp_var, decl_info):
    """Build the declaration for the temporary described by *decl_info*.

    The base declaration is a ``POD`` of ``decl_info.dtype`` named
    ``decl_info.name``. For temporaries whose scope is
    ``temp_var_scope.PRIVATE``, the second component of the kernel's grid
    size upper bounds is prepended to the shape. If the resulting shape is
    non-empty, the declaration is wrapped in a one-dimensional ``ArrayOf``
    whose size is the product of all shape entries.
    """
    from loopy.target.c import POD  # uses the correct complex type

    decl = POD(self, decl_info.dtype, decl_info.name)

    full_shape = decl_info.shape

    if temp_var.scope == temp_var_scope.PRIVATE:
        # FIXME: This is a pretty coarse way of deciding what
        # private temporaries get duplicated. Refine? (See also
        # above in expr to code mapper)
        _gsize, local_size = \
                codegen_state.kernel.get_grid_size_upper_bounds_as_exprs()
        full_shape = local_size + full_shape

    if not full_shape:
        return decl

    from cgen import ArrayOf
    to_code = self.get_expression_to_code_mapper(codegen_state)
    size_code = to_code(
            p.flattened_product(full_shape),
            prec=PREC_NONE, type_context="i")

    return ArrayOf(decl, size_code)
def find_inner_deriv_and_coeff(expr):
    """Separate *expr* into a coefficient and a single derivative binding.

    :return: a tuple ``(coefficient, derivative)``. If *expr* itself is a
        derivative binding, the coefficient is ``1``. If *expr* is a
        product, the coefficient is the product of all factors that are
        not derivative bindings.
    :raises ValueError: if *expr* is a product containing more than one or
        no derivative binding among its flat factors, or if *expr* is
        neither a derivative binding nor a product.
    """
    if is_derivative_binding(expr):
        return 1, expr

    if not isinstance(expr, pp.Product):
        raise ValueError("unexpected node type '%s' inside "
                "second derivative in '%s'"
                % (type(expr).__name__, expr))

    found_derivs = []
    coeff_factors = []
    for factor in get_flat_factors(expr):
        bucket = found_derivs if is_derivative_binding(factor) \
                else coeff_factors
        bucket.append(factor)

    if len(found_derivs) > 1:
        raise ValueError("multiplied second derivatives in '%s'"
                % expr)

    if not found_derivs:
        # We'll only get called if there *is* a second derivative.
        # That we can't find it by picking apart the top-level
        # product is bad news.
        raise ValueError("second derivative inside nonlinearity "
                "in '%s'" % expr)

    return pp.flattened_product(coeff_factors), found_derivs[0]
def map_product(self, expr, derivatives):
    """Map a product, pulling scalar factors through derivative collection.

    The children of *expr* are partitioned with ``is_scalar``. If there is
    not exactly one non-scalar child, the whole expression is handed to a
    fresh ``DerivativeJoiner``. Otherwise the single non-scalar child is
    processed recursively with its own derivative dictionary, and every
    operand collected there is recorded in *derivatives* multiplied by the
    product of the scalar factors.

    :arg derivatives: a dict mapping operators to lists of operands;
        updated in place.
    :return: the mapped expression.
    """
    from grudge.symbolic.tools import is_scalar
    from pytools import partition

    scalars, nonscalars = partition(is_scalar, expr.children)

    if len(nonscalars) != 1:
        return DerivativeJoiner()(expr)

    from pymbolic import flattened_product
    factor = flattened_product(scalars)
    nonscalar, = nonscalars

    # Collect the derivatives of the non-scalar part separately so that
    # they can be scaled by *factor* before being merged into the
    # caller's dictionary.
    sub_derivatives = {}
    nonscalar = self.rec(nonscalar, sub_derivatives)

    for operator, operands in sub_derivatives.items():
        scaled = derivatives.setdefault(operator, [])
        for operand in operands:
            scaled.append(factor * operand)

    return factor * nonscalar
def map_product(self, expr):
    """Map a product, turning a leading operator into an operator binding.

    An empty product is returned unchanged. If the first child is an
    ``op.Operator``, it is bound (via ``sym.OperatorBinding``) to the
    mapped product of the remaining children; a warning is issued if that
    remaining product has more than one child. Otherwise the result is
    the mapped first child times the mapped product of the rest.
    """
    if not expr.children:
        return expr

    from pymbolic.primitives import flattened_product, Product

    head = expr.children[0]
    tail = flattened_product(expr.children[1:])

    if not isinstance(head, op.Operator):
        return self.rec(head) * self.rec(tail)

    if isinstance(tail, Product) and len(tail.children) > 1:
        from warnings import warn
        warn("Binding '%s' to more than one "
                "operand in a product is ambiguous - "
                "use the parenthesized form instead."
                % head)

    return sym.OperatorBinding(head, self.rec(tail))
def map_product(self, expr, type_context):
    """Generate code for a product, using helper calls for complex factors.

    If complex support is disabled, *type_context* is ``"i"``, or the
    inferred type of *expr* is not complex, the base class implementation
    is used. Otherwise the children are split into real and complex
    factors. The complex factors are combined pairwise into a balanced
    tree of ``<complex type name>_mul`` calls, and, if there are any real
    factors, their product is applied with ``<complex type name>_rmul``.
    """
    def base_impl(expr, type_context):
        return super(ExpressionToCExpressionMapper, self).map_product(
                expr, type_context)

    # 'type_context == "i"' is checked here because of the following
    # corner case: Code generation for subscripts comes through here,
    # and it may involve variables that we know nothing about (offsets
    # and such). If we fall into the allow_complex branch, we'll try to
    # do type inference on these variables, and stuff breaks. This
    # band-aid works around that. -AK
    if not self.allow_complex or type_context == "i":
        return base_impl(expr, type_context)

    tgt_dtype = self.infer_type(expr)

    if not tgt_dtype.is_complex():
        return base_impl(expr, type_context)

    tgt_name = self.complex_type_name(tgt_dtype)

    reals = []
    complexes = []
    for child in expr.children:
        if self.infer_type(child).is_complex():
            complexes.append(child)
        else:
            reals.append(child)

    real_prd = p.flattened_product(
            [self.rec(r, type_context) for r in reals])

    c_applied = [
            self.rec(c, type_context, tgt_dtype) for c in complexes]

    def binary_tree_mul(start, end):
        # Multiply c_applied[start:end] as a balanced binary tree.
        if start + 1 == end:
            return c_applied[start]
        mid = (start + end) // 2
        lprod = binary_tree_mul(start, mid)
        rprod = binary_tree_mul(mid, end)
        return var("%s_mul" % tgt_name)(lprod, rprod)

    complex_prd = binary_tree_mul(0, len(complexes))

    # Decide based on whether real factors exist, not on the truthiness
    # of their product: an empty product is 1 (which would emit a
    # pointless rmul), and a product that happens to be a literal 0
    # must not be dropped.
    if reals:
        return var("%s_rmul" % tgt_name)(real_prd, complex_prd)
    else:
        return complex_prd
def get_temporary_decl(self, codegen_state, schedule_index, temp_var,
        decl_info):
    """Build the declaration for the temporary described by *decl_info*.

    Starts from a ``POD`` of ``decl_info.dtype`` named ``decl_info.name``,
    wraps it in ``Const`` if ``temp_var.read_only`` is set, and, for a
    non-empty ``decl_info.shape``, in a one-dimensional ``ArrayOf`` whose
    size is the product of all shape entries.
    """
    decl = POD(self, decl_info.dtype, decl_info.name)

    if temp_var.read_only:
        from cgen import Const
        decl = Const(decl)

    if not decl_info.shape:
        return decl

    from cgen import ArrayOf
    to_code = self.get_expression_to_code_mapper(codegen_state)
    n_entries = to_code(
            p.flattened_product(decl_info.shape),
            prec=PREC_NONE, type_context="i")

    return ArrayOf(decl, n_entries)
def get_temporary_decl(self, codegen_state, schedule_index, temp_var,
        decl_info):
    """Build the declaration for the temporary described by *decl_info*.

    The declaration is assembled from the inside out:

    * a ``POD`` of ``decl_info.dtype`` named ``decl_info.name``,
    * wrapped in ``Const`` if ``temp_var.read_only`` is set,
    * wrapped in a one-dimensional ``ArrayOf`` sized by the product of
      ``decl_info.shape`` if that shape is non-empty,
    * wrapped in ``AlignedAttribute`` if ``temp_var.alignment`` is set.
    """
    result = POD(self, decl_info.dtype, decl_info.name)

    if temp_var.read_only:
        from cgen import Const
        result = Const(result)

    shape = decl_info.shape
    if shape:
        from cgen import ArrayOf
        expr_to_code = self.get_expression_to_code_mapper(codegen_state)
        total_size = expr_to_code(
                p.flattened_product(shape),
                prec=PREC_NONE, type_context="i")
        result = ArrayOf(result, total_size)

    alignment = temp_var.alignment
    if alignment:
        from cgen import AlignedAttribute
        result = AlignedAttribute(alignment, result)

    return result
def get_temporary_decl(self, codegen_state, sched_index, temp_var, decl_info):
    """Build the declaration for the temporary described by *decl_info*.

    The base declaration is a ``POD`` of ``decl_info.dtype`` named
    ``decl_info.name``. For temporaries in ``AddressSpace.PRIVATE``, the
    second component of the kernel's grid size upper bounds is prepended
    to the shape. If the resulting shape is non-empty, the declaration is
    wrapped in a one-dimensional ``ArrayOf`` whose size is the product of
    all shape entries.
    """
    from loopy.target.c import POD  # uses the correct complex type

    decl = POD(self, decl_info.dtype, decl_info.name)

    full_shape = decl_info.shape

    is_private = temp_var.address_space == AddressSpace.PRIVATE
    if is_private:
        # FIXME: This is a pretty coarse way of deciding what
        # private temporaries get duplicated. Refine? (See also
        # above in expr to code mapper)
        _gsize, local_size = \
                codegen_state.kernel.get_grid_size_upper_bounds_as_exprs()
        full_shape = local_size + full_shape

    if not full_shape:
        return decl

    from cgen import ArrayOf
    to_code = self.get_expression_to_code_mapper(codegen_state)
    size_code = to_code(
            p.flattened_product(full_shape),
            prec=PREC_NONE, type_context="i")

    return ArrayOf(decl, size_code)
def map_product(self, expr, *args, **kwargs):
    """Map every child of the product *expr* and rebuild the result with
    ``flattened_product``.

    Extra positional and keyword arguments are forwarded unchanged to
    ``self.rec`` for each child.
    """
    from pymbolic.primitives import flattened_product

    mapped_children = tuple([
        self.rec(factor, *args, **kwargs)
        for factor in expr.children])

    return flattened_product(mapped_children)
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Move factors common to all increments of *var_name* to its use sites.

    For every ``Assignment`` writing to *var_name*, the right-hand side is
    viewed as a sum of products. Factors that occur in every product term
    (ignoring a term equal to the left-hand side) are collected per index
    key, removed from those assignments, and multiplied onto every access
    to *var_name* in assignments that do not write to it.

    :arg kernel: the kernel to transform. Substitution rules, if present,
        are expanded first.
    :arg var_name: name of a temporary variable or argument of *kernel*.
    :arg vary_by_axes: axes of *var_name* along which the common factors
        may differ; the subscript indices along these axes form the index
        key. May be a sequence of axis indices and/or axis names, or a
        comma-separated string of axis names.
    :return: a copy of *kernel* with updated instructions.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: if an axis name is unknown, an axis index is out
        of bounds, *var_name* is written by a non-``Assignment``
        instruction, or a right-hand side term depends on *var_name*.
    """
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = dict(
                    (name, idx)
                    for idx, name in enumerate(var_descr.dim_names))
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # Valid axis indices are 0 .. num_user_axes()-1, hence '>='.
        # (An index equal to num_user_axes() used to slip through and
        # fail later when indexing into the subscript's index tuple.)
        if (
                vary_by_axes
                and (
                    min(vary_by_axes) < 0
                    or max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
            UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found)
    common_factors = []

    def find_unifiable_cf_index(index_key):
        # Returns (position in common_factors, unification result) for the
        # first recorded key that unifies with index_key, or (None, None).
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return any(
                lhs == var_name
                for lhs, _subscript in insn.assignees_and_indices())

    def iterate_as(cls, expr):
        if isinstance(expr, cls):
            for ch in expr.children:
                yield ch
        else:
            yield expr

    def checked_product_parts(term, insn):
        # Returns the set of factors of *term*, refusing any factor that
        # itself depends on var_name.
        parts = list(iterate_as(Product, term))
        for part in parts:
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                        "in RHS of instruction '%s'"
                        % (var_name, insn.id))

        return set(parts)

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-expression instruction"
                    % var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = checked_product_parts(term, insn)

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            _, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                    make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = checked_product_parts(term, insn)

                my_common_factors = set(
                        cf for cf in my_common_factors
                        if unif_subst_map(cf) in product_parts)

            common_factors[cf_index] = (index_key, my_common_factors)

            # }}}

    # }}}

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)

        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        mapped_my_common_factors = set(
                unif_subst_map(cf)
                for cf in my_common_factors)

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            new_sum_terms.append(
                    flattened_product([
                        part
                        for part in iterate_as(Product, term)
                        if part not in mapped_my_common_factors
                        ]))

        new_insns.append(
                insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # No factors were removed for this index key (mirrors the
            # removal loop above), so leave the access untouched rather
            # than dereferencing unif_result, which is None here.
            return expr

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is not None:
            return flattened_product(
                    [unif_subst_map(cf) for cf in my_common_factors]
                    + [expr])
        else:
            return expr

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
def map_product(self, expr):
    """Map every child of the product *expr* with ``self.rec`` and rebuild
    the result with ``flattened_product``.
    """
    from pymbolic.primitives import flattened_product

    # Lazily maps the children, like the generator it replaces.
    mapped_children = map(self.rec, expr.children)
    return flattened_product(mapped_children)
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Move factors common to all increments of *var_name* to its use sites.

    For every ``Assignment`` writing to *var_name*, the right-hand side is
    viewed as a sum of products. Factors that occur in every product term
    (ignoring a term equal to the left-hand side) are collected per index
    key, removed from those assignments, and multiplied onto every access
    to *var_name* in assignments that do not write to it.

    :arg kernel: the :class:`LoopKernel` to transform. Substitution rules,
        if present, are expanded first.
    :arg var_name: name of a temporary variable or argument of *kernel*.
    :arg vary_by_axes: axes of *var_name* along which the common factors
        may differ; the subscript indices along these axes form the index
        key. May be a sequence of axis indices and/or axis names, or a
        comma-separated string of axis names.
    :return: a copy of *kernel* with updated instructions.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: if an axis name is unknown, an axis index is out
        of bounds, *var_name* is written by a non-``Assignment``
        instruction, a right-hand side term depends on *var_name*, or no
        common factors are found at all.
    """
    assert isinstance(kernel, LoopKernel)

    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = {
                    name: idx
                    for idx, name in enumerate(var_descr.dim_names)}
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # Valid axis indices are 0 .. num_user_axes()-1, hence '>='.
        # (An index equal to num_user_axes() used to slip through and
        # fail later when indexing into the subscript's index tuple.)
        if (
                vary_by_axes
                and (
                    min(vary_by_axes) < 0
                    or max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
            UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found)
    common_factors = []

    def find_unifiable_cf_index(index_key):
        # Returns (position in common_factors, unification result) for the
        # first recorded key that unifies with index_key, or (None, None).
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return var_name in insn.assignee_var_names()

    def iterate_as(cls, expr):
        if isinstance(expr, cls):
            yield from expr.children
        else:
            yield expr

    def checked_product_parts(term, insn):
        # Returns the set of factors of *term*, refusing any factor that
        # itself depends on var_name.
        parts = list(iterate_as(Product, term))
        for part in parts:
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                        "in RHS of instruction '%s'"
                        % (var_name, insn.id))

        return set(parts)

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-single-assignment"
                    % var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = checked_product_parts(term, insn)

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            _, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                    make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = checked_product_parts(term, insn)

                my_common_factors = {
                        cf for cf in my_common_factors
                        if unif_subst_map(cf) in product_parts}

            common_factors[cf_index] = (index_key, my_common_factors)

            # }}}

    # }}}

    # Drop index keys for which no common factor survived.
    common_factors = [
            (ik, cf) for ik, cf in common_factors
            if cf]

    if not common_factors:
        raise LoopyError("no common factors found")

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)

        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        mapped_my_common_factors = {
                unif_subst_map(cf)
                for cf in my_common_factors}

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            new_sum_terms.append(
                    flattened_product([
                        part
                        for part in iterate_as(Product, term)
                        if part not in mapped_my_common_factors
                        ]))

        new_insns.append(
                insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # No factors were removed for this index key (mirrors the
            # removal loop above), so leave the access untouched rather
            # than dereferencing unif_result, which is None here.
            return expr

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is not None:
            return flattened_product(
                    [unif_subst_map(cf) for cf in my_common_factors]
                    + [expr])
        else:
            return expr

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)