def map_sum(self, expr):
    """Map each summand of *expr*; return *expr* itself (identically) when
    no child was rewritten, else a flattened sum of the mapped children.
    """
    mapped_children = []
    any_changed = False
    for child in expr.children:
        mapped = self.rec(child)
        if mapped is not child:
            any_changed = True
        mapped_children.append(mapped)

    if not any_changed:
        # Preserve object identity so callers can cheaply detect "no change".
        return expr

    from pymbolic.primitives import flattened_sum
    return flattened_sum(mapped_children)
def map_sum(self, expr, type_context):
    """Emit a C-level expression for a sum, routing sums whose inferred
    type is complex through the target's complex-arithmetic helper
    functions (``<type>_add`` / ``<type>_radd``).

    :arg type_context: short string hint from the caller; ``"i"`` marks
        integer contexts such as subscript expressions.
    """
    def base_impl(expr, type_context):
        # Plain (non-complex) case: defer to the superclass mapper.
        return super(ExpressionToCExpressionMapper, self).map_sum(expr, type_context)

    # I've added 'type_context == "i"' because of the following
    # idiotic corner case: Code generation for subscripts comes
    # through here, and it may involve variables that we know
    # nothing about (offsets and such). If we fall into the allow_complex
    # branch, we'll try to do type inference on these variables,
    # and stuff breaks. This band-aid works around that. -AK
    if not self.allow_complex or type_context == "i":
        return base_impl(expr, type_context)

    tgt_dtype = self.infer_type(expr)
    is_complex = tgt_dtype.is_complex()

    if not is_complex:
        return base_impl(expr, type_context)
    else:
        tgt_name = self.complex_type_name(tgt_dtype)

        # Partition summands by whether their inferred type is complex.
        reals = []
        complexes = []
        for child in expr.children:
            if self.infer_type(child).is_complex():
                complexes.append(child)
            else:
                reals.append(child)

        real_sum = p.flattened_sum(
                [self.rec(r, type_context) for r in reals])

        # Complex summands are mapped with the target dtype forced.
        c_applied = [
                self.rec(c, type_context, tgt_dtype)
                for c in complexes]

        def binary_tree_add(start, end):
            # Pairwise (tree-shaped) reduction over c_applied[start:end]
            # using the target's complex-add helper function.
            if start + 1 == end:
                return c_applied[start]
            mid = (start + end) // 2
            lsum = binary_tree_add(start, mid)
            rsum = binary_tree_add(mid, end)
            return var("%s_add" % tgt_name)(lsum, rsum)

        complex_sum = binary_tree_add(0, len(c_applied))

        if real_sum:
            # Fold the purely-real partial sum in with one real+complex add.
            return var("%s_radd" % tgt_name)(real_sum, complex_sum)
        else:
            return complex_sum
def rebuild_optemplate(self):
    """Reassemble this flux batch into a sum of bound flux operators.

    Interior contributions are emitted first (in stored order), followed
    by boundary contributions; the quadrature variants of the operators
    are used whenever ``self.quadrature_tag`` is set.
    """
    use_quadrature = self.quadrature_tag is not None

    summands = []
    for interior in self.interiors:
        if use_quadrature:
            bound_op = QuadratureFluxOperator(
                    interior.flux_expr, self.quadrature_tag)
        else:
            bound_op = FluxOperator(interior.flux_expr, self.is_lift)
        summands.append(bound_op(interior.field_expr))

    for bdry in self.boundaries:
        if use_quadrature:
            bound_op = QuadratureBoundaryFluxOperator(
                    bdry.flux_expr, self.quadrature_tag, bdry.bpair.tag)
        else:
            bound_op = BoundaryFluxOperator(
                    bdry.flux_expr, bdry.bpair.tag, self.is_lift)
        summands.append(bound_op(bdry.bpair))

    from pymbolic.primitives import flattened_sum
    return flattened_sum(summands)
def map_operator_binding(self, expr):
    # Insert inter-rank flux-exchange terms alongside bound interior flux
    # operators so that data from neighboring MPI ranks participates in
    # the flux computation.
    from hedge.optemplate import \
            FluxOperatorBase, \
            BoundaryPair, OperatorBinding, \
            FluxExchangeOperator

    if isinstance(expr, OperatorBinding):
        if isinstance(expr.op, FluxOperatorBase):
            if isinstance(expr.field, BoundaryPair):
                # we're only worried about internal fluxes
                return IdentityMapper.map_operator_binding(self, expr)

            # by now we've narrowed it down to a bound interior flux

            def func_on_scalar_or_vector(func, arg_fields):
                # Apply *func* per component for object-array fields,
                # or once (component 0) for scalar fields.
                # No CSE necessary here--the compiler CSE's these
                # automatically.
                from hedge.tools import is_obj_array, make_obj_array
                if is_obj_array(arg_fields):
                    # arg_fields (as an object array) isn't hashable
                    # --make it so by turning it into a tuple
                    arg_fields = tuple(arg_fields)

                    return make_obj_array([
                        func(i, arg_fields)
                        for i in range(len(arg_fields))
                        ])
                else:
                    return func(0, (arg_fields, ))

            from hedge.mesh import TAG_RANK_BOUNDARY

            def exchange_and_cse(rank):
                # Build the flux-exchange operator(s) pulling data from *rank*.
                return func_on_scalar_or_vector(
                        lambda i, args: FluxExchangeOperator(i, rank, args),
                        expr.field)

            from pymbolic.primitives import flattened_sum
            # Keep the original interior flux and add one boundary-flux
            # term per interacting rank.
            return flattened_sum([expr] + [
                OperatorBinding(
                    expr.op,
                    BoundaryPair(expr.field,
                        exchange_and_cse(rank),
                        TAG_RANK_BOUNDARY(rank)))
                for rank in self.interacting_ranks
                ])
        else:
            # Bound non-flux operator: recurse without modification.
            return IdentityMapper.map_operator_binding(self, expr)
    # NOTE(review): a non-OperatorBinding *expr* falls through and returns
    # None -- presumably this method is only dispatched for OperatorBinding
    # nodes; confirm against the mapper's dispatch mechanism.
def map_operator_binding(self, expr):
    # Augment bound interior flux operators with flux-exchange terms for
    # every interacting MPI rank, so remote neighbor data enters the flux.
    from hedge.optemplate import \
            FluxOperatorBase, \
            BoundaryPair, OperatorBinding, \
            FluxExchangeOperator

    if isinstance(expr, OperatorBinding):
        if isinstance(expr.op, FluxOperatorBase):
            if isinstance(expr.field, BoundaryPair):
                # we're only worried about internal fluxes
                return IdentityMapper.map_operator_binding(self, expr)

            # by now we've narrowed it down to a bound interior flux

            def func_on_scalar_or_vector(func, arg_fields):
                # Apply *func* componentwise for object-array fields, or
                # once (as component 0) for a scalar field.
                # No CSE necessary here--the compiler CSE's these
                # automatically.
                from hedge.tools import is_obj_array, make_obj_array
                if is_obj_array(arg_fields):
                    # arg_fields (as an object array) isn't hashable
                    # --make it so by turning it into a tuple
                    arg_fields = tuple(arg_fields)

                    return make_obj_array([
                        func(i, arg_fields)
                        for i in range(len(arg_fields))])
                else:
                    return func(0, (arg_fields,))

            from hedge.mesh import TAG_RANK_BOUNDARY

            def exchange_and_cse(rank):
                # Flux-exchange operator(s) fetching data from *rank*.
                return func_on_scalar_or_vector(
                        lambda i, args: FluxExchangeOperator(i, rank, args),
                        expr.field)

            from pymbolic.primitives import flattened_sum
            # Original interior flux plus one rank-boundary flux per
            # interacting rank.
            return flattened_sum([expr]
                    + [OperatorBinding(expr.op, BoundaryPair(
                        expr.field,
                        exchange_and_cse(rank),
                        TAG_RANK_BOUNDARY(rank)))
                    for rank in self.interacting_ranks])
        else:
            # Bound non-flux operator: recurse unchanged.
            return IdentityMapper.map_operator_binding(self, expr)
    # NOTE(review): non-OperatorBinding input falls through and returns
    # None -- verify this method is only ever dispatched for
    # OperatorBinding nodes.
def rebuild_optemplate(self):
    """Turn this flux batch back into an operator-template sum: interior
    fluxes first, then boundary fluxes, using quadrature operator variants
    when ``self.quadrature_tag`` is set.
    """
    quad_tag = self.quadrature_tag

    def summands():
        for interior in self.interiors:
            if quad_tag is None:
                yield FluxOperator(interior.flux_expr,
                        self.is_lift)(interior.field_expr)
            else:
                yield QuadratureFluxOperator(
                        interior.flux_expr, quad_tag)(interior.field_expr)

        for bdry in self.boundaries:
            if quad_tag is None:
                yield BoundaryFluxOperator(bdry.flux_expr, bdry.bpair.tag,
                        self.is_lift)(bdry.bpair)
            else:
                yield QuadratureBoundaryFluxOperator(
                        bdry.flux_expr, quad_tag, bdry.bpair.tag)(bdry.bpair)

    from pymbolic.primitives import flattened_sum
    return flattened_sum(summands())
def map_derivative_source(self, expr):
    """Resolve a derivative source into a sum, over all ambient axes that
    carry a nabla component, of axis derivatives of the mapped operand.
    """
    rec_operand = self.rec(expr.operand)

    nabla_components = []
    for collected in self.derivative_collector(rec_operand):
        if isinstance(collected, prim.NablaComponent):
            nabla_components.append(collected)
        elif not isinstance(collected, prim.DerivativeSource):
            raise RuntimeError("unexpected result from "
                    "DerivativeSourceAndNablaComponentCollector")

    # Highest axis seen determines how many axis derivatives to take.
    n_axes = max(comp.ambient_axis for comp in nabla_components) + 1
    assert n_axes

    from pymbolic.primitives import flattened_sum
    return flattened_sum(
            self.take_derivative(
                axis,
                self.nabla_component_to_unit_vector(expr.nabla_id, axis)
                (rec_operand))
            for axis in range(n_axes))
def map_sum(self, expr):
    """Map a sum while hoisting inner derivatives: children whose inner
    derivatives can be joined contribute their operands to a shared
    per-operator pool, which is prepended to the sum as joined terms.
    """
    joiner = _InnerDerivativeJoiner()
    derivatives = {}

    def process_child(child):
        collected = {}
        joined = joiner(child, collected)
        if not collected:
            # Nothing to hoist from this child; keep it as-is.
            return child
        for operator, operands in collected.items():
            derivatives.setdefault(operator, []).extend(operands)
        return joined

    new_children = [process_child(child) for child in expr.children]

    # Prepend one joined term per derivative operator found.
    for operator, operands in derivatives.items():
        new_children.insert(
                0, operator(sum(self.rec(operand) for operand in operands)))

    from pymbolic.primitives import flattened_sum
    return flattened_sum(new_children)
def map_derivative_source(self, expr):
    """Expand a derivative source into a flattened sum of per-axis
    derivatives of the recursively mapped operand.
    """
    rec_operand = self.rec(expr.operand)

    nablas = []
    for item in self.derivative_collector(rec_operand):
        if isinstance(item, prim.NablaComponent):
            nablas.append(item)
        elif isinstance(item, prim.DerivativeSource):
            # Nested sources are handled elsewhere; ignore here.
            pass
        else:
            raise RuntimeError(
                    "unexpected result from "
                    "DerivativeSourceAndNablaComponentCollector")

    n_axes = max(ncomp.ambient_axis for ncomp in nablas) + 1
    assert n_axes

    from pymbolic.primitives import flattened_sum
    axis_terms = []
    for axis in range(n_axes):
        unit_vector = self.nabla_component_to_unit_vector(expr.nabla_id, axis)
        axis_terms.append(self.take_derivative(axis, unit_vector(rec_operand)))
    return flattened_sum(axis_terms)
def map_sum(self, expr, *args, **kwargs):
    """Recursively map every summand (forwarding extra arguments) and
    rebuild the node with a flattening sum constructor.
    """
    from pymbolic.primitives import flattened_sum
    mapped = []
    for child in expr.children:
        mapped.append(self.rec(child, *args, **kwargs))
    return flattened_sum(tuple(mapped))
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Find factors common to every increment of *var_name* in *kernel*,
    remove them from the increment statements, and multiply them back in
    at the sites where *var_name* is subsequently read.

    :arg var_name: name of the temporary or argument being incremented.
    :arg vary_by_axes: subscript axes (indices or dimension names) along
        which the common factors are allowed to differ.
    """
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = dict(
                    (name, idx)
                    for idx, name in enumerate(var_descr.dim_names))
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            # Accept either an axis name (string) or a plain index.
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        if (
                vary_by_axes
                and
                (min(vary_by_axes) < 0
                    or
                    max(vary_by_axes) > var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
            UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found)
    common_factors = []

    def find_unifiable_cf_index(index_key):
        # Return (list index, unification record) of the entry whose key
        # unifies with *index_key*, or (None, None) if there is none.
        for i, (key, val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        # Project an access to *var_name* onto the vary_by_axes subscripts.
        if isinstance(access_expr, Variable):
            return ()
        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return any(
                lhs == var_name
                for lhs, sbscript in insn.assignees_and_indices())

    def iterate_as(cls, expr):
        # Yield *expr*'s children if it is an instance of *cls* (Sum or
        # Product); otherwise yield *expr* itself as the single item.
        if isinstance(expr, cls):
            for ch in expr.children:
                yield ch
        else:
            yield expr

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment
    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-expression instruction"
                    % var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    # Skip the self-increment term ("x = x + ...").
                    continue

                for part in iterate_as(Product, term):
                    if var_name in get_dependencies(part):
                        raise LoopyError("unexpected dependency on '%s' "
                                "in RHS of instruction '%s'"
                                % (var_name, insn.id))

                product_parts = set(iterate_as(Product, term))

                # Intersect factor sets across all summed terms.
                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            _, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                    make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                for part in iterate_as(Product, term):
                    if var_name in get_dependencies(part):
                        raise LoopyError("unexpected dependency on '%s' "
                                "in RHS of instruction '%s'"
                                % (var_name, insn.id))

                product_parts = set(iterate_as(Product, term))

                # Keep only factors that (after unification) still occur.
                my_common_factors = set(
                        cf for cf in my_common_factors
                        if unif_subst_map(cf) in product_parts)

            common_factors[cf_index] = (index_key, my_common_factors)

            # }}}

    # }}}

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        (_, index_key), = insn.assignees_and_indices()

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        # NOTE(review): index_key is reassigned here, shadowing the value
        # unpacked from assignees_and_indices() above -- confirm intended.
        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        mapped_my_common_factors = set(
                unif_subst_map(cf)
                for cf in my_common_factors)

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            # Divide out the common factors from each summed term.
            new_sum_terms.append(
                    flattened_product([
                        part
                        for part in iterate_as(Product, term)
                        if part not in mapped_my_common_factors
                        ]))

        new_insns.append(
                insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        # At each read of var_name, multiply the removed factors back in.
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is not None:
            return flattened_product(
                    [unif_subst_map(cf) for cf in my_common_factors]
                    + [expr])
        else:
            return expr

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            # Increments themselves were already rewritten above.
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
def map_product(self, expr):
    """Resolve derivative sources inside a product by distributing each
    source over the ambient axes of its matching nabla components,
    producing a sum of per-axis products.
    """
    # {{{ gather NablaComponents and DerivativeSources

    d_source_nabla_ids_per_child = []

    # id to set((child index, axis), ...)
    nabla_finder = {}

    for child_idx, rec_child in enumerate(expr.children):
        nabla_component_ids = set()
        derivative_source_ids = set()

        nablas = []
        for d_or_n in self.derivative_collector(rec_child):
            if isinstance(d_or_n, prim.NablaComponent):
                nabla_component_ids.add(d_or_n.nabla_id)
                nablas.append(d_or_n)
            elif isinstance(d_or_n, prim.DerivativeSource):
                derivative_source_ids.add(d_or_n.nabla_id)
            else:
                raise RuntimeError(
                        "unexpected result from "
                        "DerivativeSourceAndNablaComponentCollector")

        d_source_nabla_ids_per_child.append(derivative_source_ids)

        for ncomp in nablas:
            nabla_finder.setdefault(ncomp.nabla_id, set()).add(
                    (child_idx, ncomp.ambient_axis))

    # }}}

    if nabla_finder and not any(d_source_nabla_ids_per_child):
        # Nabla components with no source anywhere cannot be resolved.
        raise ValueError(
                "no derivative source found to resolve in '%s'"
                "--did you forget to wrap the term that should have its "
                "derivative taken in 'Derivative()(term)'?" % str(expr))

    # a list of lists, the outer level presenting a sum, the inner a product
    result = [list(expr.children)]

    for child_idx, (d_source_nabla_ids, child) in enumerate(
            zip(d_source_nabla_ids_per_child, expr.children)):
        if not d_source_nabla_ids:
            continue

        if len(d_source_nabla_ids) > 1:
            raise NotImplementedError("more than one DerivativeSource per "
                    "child in a product")

        nabla_id, = d_source_nabla_ids

        try:
            nablas = nabla_finder[nabla_id]
        except KeyError:
            # Source with no matching nabla components: nothing to expand.
            continue

        if self.restrict_to_id is not None and nabla_id != self.restrict_to_id:
            continue

        n_axes = max(axis for _, axis in nablas) + 1

        # Expand each partial product once per axis: replace the source in
        # this child and the matching nabla components in the others.
        new_result = []
        for prod_term_list in result:
            for axis in range(n_axes):
                new_ptl = prod_term_list[:]
                dsfinder = self.derivative_source_finder(
                        nabla_id, self, axis)
                new_ptl[child_idx] = dsfinder(new_ptl[child_idx])
                for nabla_child_index, _ in nablas:
                    new_ptl[nabla_child_index] = \
                            self.nabla_component_to_unit_vector(nabla_id, axis)(
                                    new_ptl[nabla_child_index])

                new_result.append(new_ptl)

        result = new_result

    from pymbolic.primitives import flattened_sum
    return flattened_sum(
            type(expr)(tuple(
                self.rec(prod_term) for prod_term in prod_term_list))
            for prod_term_list in result)
def emit_assignment(self, codegen_state, insn):
    """Emit an ISPC assignment statement for *insn*.

    Instructions tagged ``!streaming_store`` are emitted as calls to
    ISPC's ``streaming_store`` instead of plain assignments; this
    requires the store address to have stride 1 in the axis-0 local
    index (ISPC's ``programIndex``).

    :returns: a :class:`cgen.Statement` for streaming stores, else a
        :class:`cgen.Assign`.
    :raises LoopyError: on unsupported array types or subscripts that
        do not meet the streaming-store stride requirements.
    """
    kernel = codegen_state.kernel
    ecm = codegen_state.expression_to_code_mapper

    assignee_var_name, = insn.assignee_var_names()

    lhs_var = codegen_state.kernel.get_var_descriptor(assignee_var_name)
    lhs_dtype = lhs_var.dtype

    if insn.atomicity:
        raise NotImplementedError("atomic ops in ISPC")

    from loopy.expression import dtype_to_type_context
    from pymbolic.mapper.stringifier import PREC_NONE

    rhs_type_context = dtype_to_type_context(kernel.target, lhs_dtype)
    rhs_code = ecm(insn.expression, prec=PREC_NONE,
                type_context=rhs_type_context,
                needed_dtype=lhs_dtype)

    lhs = insn.assignee

    # {{{ handle streaming stores

    if "!streaming_store" in insn.tags:
        ary = ecm.find_array(lhs)

        from loopy.kernel.array import get_access_info
        from pymbolic import evaluate

        from loopy.symbolic import simplify_using_aff
        index_tuple = tuple(
                simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)

        access_info = get_access_info(
                kernel.target, ary, index_tuple,
                # FIX: use the *codegen_state* parameter; 'self' (the AST
                # builder) does not carry a 'codegen_state' attribute.
                lambda expr: evaluate(expr, codegen_state.var_subst_map),
                codegen_state.vectorization_info)

        from loopy.kernel.data import GlobalArg, TemporaryVariable

        if not isinstance(ary, (GlobalArg, TemporaryVariable)):
            # FIX: was 'type(ary).__name' (missing trailing dunder), which
            # raised AttributeError instead of producing the message.
            raise LoopyError("array type not supported in ISPC: %s"
                    % type(ary).__name__)

        if len(access_info.subscripts) != 1:
            raise LoopyError("streaming stores must have a subscript")
        subscript, = access_info.subscripts

        from pymbolic.primitives import Sum, flattened_sum, Variable
        if isinstance(subscript, Sum):
            terms = subscript.children
        else:
            # FIX: was '(subscript.children,)'. A non-Sum subscript (e.g.
            # a lone Variable) has no meaningful '.children' here; the
            # single term is the subscript itself.
            terms = (subscript,)

        new_terms = []

        from loopy.kernel.data import LocalIndexTag
        from loopy.symbolic import get_dependencies

        # Separate out the axis-0 local-index term (the programIndex
        # contribution); all remaining terms form the store base address.
        saw_l0 = False
        for term in terms:
            if (isinstance(term, Variable)
                    and isinstance(
                        kernel.iname_to_tag.get(term.name), LocalIndexTag)
                    and kernel.iname_to_tag.get(term.name).axis == 0):
                if saw_l0:
                    raise LoopyError("streaming store must have stride 1 "
                            "in local index, got: %s" % subscript)
                saw_l0 = True
                continue
            else:
                # Any other term must be independent of the axis-0 local
                # index, otherwise the store is not stride-1.
                for dep in get_dependencies(term):
                    if (isinstance(
                            kernel.iname_to_tag.get(dep), LocalIndexTag)
                            and kernel.iname_to_tag.get(dep).axis == 0):
                        raise LoopyError(
                                "streaming store must have stride 1 "
                                "in local index, got: %s" % subscript)

                new_terms.append(term)

        if not saw_l0:
            raise LoopyError("streaming store must have stride 1 in "
                    "local index, got: %s" % subscript)

        if access_info.vector_index is not None:
            raise LoopyError("streaming store may not use a short-vector "
                    "data type")

        rhs_has_programindex = any(
                isinstance(kernel.iname_to_tag.get(dep), LocalIndexTag)
                and kernel.iname_to_tag.get(dep).axis == 0
                for dep in get_dependencies(insn.expression))

        if not rhs_has_programindex:
            # Uniform RHS: replicate the value across the gang.
            rhs_code = "broadcast(%s, 0)" % rhs_code

        from cgen import Statement
        return Statement(
                "streaming_store(%s + %s, %s)"
                % (
                    access_info.array_name,
                    ecm(flattened_sum(new_terms), PREC_NONE, 'i'),
                    rhs_code))

    # }}}

    from cgen import Assign
    return Assign(ecm(lhs, prec=PREC_NONE, type_context=None), rhs_code)
def map_sum(self, expr):
    """Map every summand recursively and rebuild via a flattening sum."""
    from pymbolic.primitives import flattened_sum
    mapped_children = [self.rec(child) for child in expr.children]
    return flattened_sum(mapped_children)
def map_sum(self, expr):
    """Recursively map each child of the sum, then flatten the result."""
    from pymbolic.primitives import flattened_sum
    return flattened_sum(self.rec(summand) for summand in expr.children)
def map_polynomial(self, expr, enclosing_prec, *args, **kwargs):
    """Stringify a polynomial by expanding it into the equivalent sum of
    coefficient-times-power terms (reverse data order) and mapping that.
    """
    from pymbolic.primitives import flattened_sum
    terms = []
    for exponent, coefficient in reversed(expr.data):
        terms.append(coefficient * expr.base**exponent)
    return self.rec(flattened_sum(terms), enclosing_prec, *args, **kwargs)
def map_product(self, expr):
    """Resolve derivative sources inside a product: each source is
    distributed over the ambient axes of its matching nabla components,
    yielding a sum of per-axis product terms.
    """
    # {{{ gather NablaComponents and DerivativeSources

    d_source_nabla_ids_per_child = []

    # id to set((child index, axis), ...)
    nabla_finder = {}

    for child_idx, rec_child in enumerate(expr.children):
        nabla_component_ids = set()
        derivative_source_ids = set()

        nablas = []
        for d_or_n in self.derivative_collector(rec_child):
            if isinstance(d_or_n, prim.NablaComponent):
                nabla_component_ids.add(d_or_n.nabla_id)
                nablas.append(d_or_n)
            elif isinstance(d_or_n, prim.DerivativeSource):
                derivative_source_ids.add(d_or_n.nabla_id)
            else:
                raise RuntimeError("unexpected result from "
                        "DerivativeSourceAndNablaComponentCollector")

        d_source_nabla_ids_per_child.append(derivative_source_ids)

        for ncomp in nablas:
            nabla_finder.setdefault(
                    ncomp.nabla_id, set()).add((child_idx, ncomp.ambient_axis))

    # }}}

    if nabla_finder and not any(d_source_nabla_ids_per_child):
        # Nabla components without any source cannot be resolved.
        raise ValueError("no derivative source found to resolve in '%s'"
                "--did you forget to wrap the term that should have its "
                "derivative taken in 'Derivative()(term)'?" % str(expr))

    # a list of lists, the outer level presenting a sum, the inner a product
    result = [list(expr.children)]

    for child_idx, (d_source_nabla_ids, child) in enumerate(
            zip(d_source_nabla_ids_per_child, expr.children)):
        if not d_source_nabla_ids:
            continue

        if len(d_source_nabla_ids) > 1:
            raise NotImplementedError("more than one DerivativeSource per "
                    "child in a product")

        nabla_id, = d_source_nabla_ids

        try:
            nablas = nabla_finder[nabla_id]
        except KeyError:
            # Source with no matching nabla components: nothing to expand.
            continue

        if self.restrict_to_id is not None and nabla_id != self.restrict_to_id:
            continue

        n_axes = max(axis for _, axis in nablas) + 1

        # Expand each partial product once per axis: substitute the
        # source in this child and the matching nabla components in the
        # other children.
        new_result = []
        for prod_term_list in result:
            for axis in range(n_axes):
                new_ptl = prod_term_list[:]
                dsfinder = self.derivative_source_finder(nabla_id, self, axis)
                new_ptl[child_idx] = dsfinder(new_ptl[child_idx])
                for nabla_child_index, _ in nablas:
                    new_ptl[nabla_child_index] = \
                            self.nabla_component_to_unit_vector(nabla_id, axis)(
                                    new_ptl[nabla_child_index])

                new_result.append(new_ptl)

        result = new_result

    from pymbolic.primitives import flattened_sum
    return flattened_sum(
            type(expr)(tuple(
                self.rec(prod_term) for prod_term in prod_term_list))
            for prod_term_list in result)
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Find factors common to every increment of *var_name* in *kernel*,
    divide them out of the increment statements, and multiply them back
    in wherever *var_name* is read afterwards.

    :arg var_name: name of the temporary or argument being incremented.
    :arg vary_by_axes: subscript axes (indices or dimension names) along
        which the common factors are allowed to differ.
    :raises LoopyError: if no common factors are found, or if the
        increments are not simple assignments.
    """
    assert isinstance(kernel, LoopKernel)
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = {
                    name: idx
                    for idx, name in enumerate(var_descr.dim_names)
                    }
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            # Accept either an axis name (string) or a plain index.
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        if (vary_by_axes
                and
                (min(vary_by_axes) < 0
                    or
                    max(vary_by_axes) > var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
            UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found)
    common_factors = []

    def find_unifiable_cf_index(index_key):
        # Return (list index, unification record) of the entry whose key
        # unifies with *index_key*, or (None, None) if there is none.
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        # Project an access to *var_name* onto the vary_by_axes subscripts.
        if isinstance(access_expr, Variable):
            return ()
        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return var_name in insn.assignee_var_names()

    def iterate_as(cls, expr):
        # Yield *expr*'s children if it is an instance of *cls* (Sum or
        # Product); otherwise yield *expr* itself as the single item.
        if isinstance(expr, cls):
            yield from expr.children
        else:
            yield expr

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment
    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-single-assignment"
                    % var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    # Skip the self-increment term ("x = x + ...").
                    continue

                for part in iterate_as(Product, term):
                    if var_name in get_dependencies(part):
                        raise LoopyError("unexpected dependency on '%s' "
                                "in RHS of instruction '%s'"
                                % (var_name, insn.id))

                product_parts = set(iterate_as(Product, term))

                # Intersect factor sets across all summed terms.
                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            _, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                    make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                for part in iterate_as(Product, term):
                    if var_name in get_dependencies(part):
                        raise LoopyError("unexpected dependency on '%s' "
                                "in RHS of instruction '%s'"
                                % (var_name, insn.id))

                product_parts = set(iterate_as(Product, term))

                # Keep only factors that (after unification) still occur.
                my_common_factors = {
                        cf for cf in my_common_factors
                        if unif_subst_map(cf) in product_parts
                        }

            common_factors[cf_index] = (index_key, my_common_factors)

            # }}}

    # }}}

    # Drop entries whose factor set ended up empty.
    common_factors = [
            (ik, cf) for ik, cf in common_factors
            if cf]

    if not common_factors:
        raise LoopyError("no common factors found")

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(insn.assignee)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        # NOTE(review): index_key is recomputed here from lhs, which is
        # the same expression as insn.assignee above -- redundant but
        # harmless; kept as-is.
        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        mapped_my_common_factors = {
                unif_subst_map(cf)
                for cf in my_common_factors
                }

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            # Divide out the common factors from each summed term.
            new_sum_terms.append(
                    flattened_product([
                        part
                        for part in iterate_as(Product, term)
                        if part not in mapped_my_common_factors
                        ]))

        new_insns.append(insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        # At each read of var_name, multiply the removed factors back in.
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is not None:
            return flattened_product(
                    [unif_subst_map(cf) for cf in my_common_factors]
                    + [expr])
        else:
            return expr

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            # Increments themselves were already rewritten above.
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
def map_polynomial(self, expr, other, *args, **kwargs):
    """Unify two polynomials by comparing their expanded sum-of-terms
    forms; returns False immediately if the node types differ.
    """
    if not (type(expr) == type(other)):  # noqa: E721 (exact-type match intended)
        return False

    from pymbolic.primitives import flattened_sum

    def expand_to_sum(poly):
        # Equivalent sum of coefficient*base**exponent terms, in reverse
        # data order.
        return flattened_sum(
                [coeff * poly.base**exp for exp, coeff in poly.data[::-1]])

    return self.rec(expand_to_sum(expr), expand_to_sum(other),
            *args, **kwargs)
def map_polynomial(self, expr, enclosing_prec, *args, **kwargs):
    """Rewrite the polynomial as a flattened sum of its terms (reverse
    data order) and delegate mapping to that sum.
    """
    from pymbolic.primitives import flattened_sum
    expanded = flattened_sum(
            [coeff * expr.base**exp for exp, coeff in reversed(expr.data)])
    return self.rec(expanded, enclosing_prec, *args, **kwargs)
def emit_assignment(self, codegen_state, insn):
    """Emit an ISPC assignment statement for *insn*.

    Instructions tagged ``!streaming_store`` are emitted as calls to
    ISPC's ``streaming_store`` instead of plain assignments; this
    requires the store address to have stride 1 in the axis-0 local
    index (ISPC's ``programIndex``).

    :returns: a :class:`cgen.Statement` for streaming stores, else a
        :class:`cgen.Assign`.
    :raises LoopyError: on unsupported array types or subscripts that
        do not meet the streaming-store stride requirements.
    """
    kernel = codegen_state.kernel
    ecm = codegen_state.expression_to_code_mapper

    assignee_var_name, = insn.assignee_var_names()

    lhs_var = codegen_state.kernel.get_var_descriptor(assignee_var_name)
    lhs_dtype = lhs_var.dtype

    if insn.atomicity:
        raise NotImplementedError("atomic ops in ISPC")

    from loopy.expression import dtype_to_type_context
    from pymbolic.mapper.stringifier import PREC_NONE

    rhs_type_context = dtype_to_type_context(kernel.target, lhs_dtype)
    rhs_code = ecm(insn.expression, prec=PREC_NONE,
                type_context=rhs_type_context,
                needed_dtype=lhs_dtype)

    lhs = insn.assignee

    # {{{ handle streaming stores

    if "!streaming_store" in insn.tags:
        ary = ecm.find_array(lhs)

        from loopy.kernel.array import get_access_info
        from pymbolic import evaluate

        from loopy.symbolic import simplify_using_aff
        index_tuple = tuple(
                simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)

        access_info = get_access_info(kernel.target, ary, index_tuple,
                lambda expr: evaluate(expr, codegen_state.var_subst_map),
                codegen_state.vectorization_info)

        from loopy.kernel.data import ArrayArg, TemporaryVariable

        if not isinstance(ary, (ArrayArg, TemporaryVariable)):
            # FIX: was 'type(ary).__name' (missing trailing dunder), which
            # raised AttributeError instead of producing the message.
            raise LoopyError("array type not supported in ISPC: %s"
                    % type(ary).__name__)

        if len(access_info.subscripts) != 1:
            raise LoopyError("streaming stores must have a subscript")
        subscript, = access_info.subscripts

        from pymbolic.primitives import Sum, flattened_sum, Variable
        if isinstance(subscript, Sum):
            terms = subscript.children
        else:
            # FIX: was '(subscript.children,)'. A non-Sum subscript (e.g.
            # a lone Variable) has no meaningful '.children' here; the
            # single term is the subscript itself.
            terms = (subscript,)

        new_terms = []

        from loopy.kernel.data import LocalIndexTag, filter_iname_tags_by_type
        from loopy.symbolic import get_dependencies

        # Separate out the axis-0 local-index term (the programIndex
        # contribution); the remaining terms form the store base address.
        saw_l0 = False
        for term in terms:
            if (isinstance(term, Variable)
                    and kernel.iname_tags_of_type(term.name, LocalIndexTag)):
                tag, = kernel.iname_tags_of_type(
                        term.name, LocalIndexTag, min_num=1, max_num=1)
                if tag.axis == 0:
                    if saw_l0:
                        raise LoopyError(
                                "streaming store must have stride 1 in "
                                "local index, got: %s" % subscript)
                    saw_l0 = True
                    continue
            else:
                # Any other term must be independent of the axis-0 local
                # index, otherwise the store is not stride-1.
                for dep in get_dependencies(term):
                    if filter_iname_tags_by_type(
                            kernel.iname_to_tags.get(dep, []), LocalIndexTag):
                        tag, = filter_iname_tags_by_type(
                                kernel.iname_to_tags.get(dep, []),
                                LocalIndexTag, 1)
                        if tag.axis == 0:
                            raise LoopyError(
                                    "streaming store must have stride 1 in "
                                    "local index, got: %s" % subscript)

                new_terms.append(term)

        if not saw_l0:
            raise LoopyError("streaming store must have stride 1 in "
                    "local index, got: %s" % subscript)

        if access_info.vector_index is not None:
            raise LoopyError("streaming store may not use a short-vector "
                    "data type")

        # FIX: the 'for' clauses were in the wrong order ('for tag in
        # kernel.iname_tags(dep)' before 'for dep in ...'), so 'dep' was
        # used before being bound, raising NameError at runtime.
        rhs_has_programindex = any(
                isinstance(tag, LocalIndexTag) and tag.axis == 0
                for dep in get_dependencies(insn.expression)
                for tag in kernel.iname_tags(dep))

        if not rhs_has_programindex:
            # Uniform RHS: replicate the value across the gang.
            rhs_code = "broadcast(%s, 0)" % rhs_code

        from cgen import Statement
        return Statement(
                "streaming_store(%s + %s, %s)"
                % (
                    access_info.array_name,
                    ecm(flattened_sum(new_terms), PREC_NONE, 'i'),
                    rhs_code))

    # }}}

    from cgen import Assign
    return Assign(ecm(lhs, prec=PREC_NONE, type_context=None), rhs_code)