def map_reduction(expr, rec, nresults=1):
    """Reduction callback: move the body of the targeted reduction into a
    freshly registered substitution rule.

    Only reductions whose iname set equals ``inames_set`` (from the enclosing
    scope) are rewritten. Any other reduction is rebuilt with its body
    recursively mapped through *rec*.

    For a targeted reduction, a :class:`loopy.kernel.data.SubstitutionRule`
    taking ``inames`` as arguments is stored into ``substs``. The reduction
    body is then replaced by a call to that rule.

    :arg nresults: accepted for interface compatibility; not used here.
    :raises LoopyError: if the chosen rule name already exists in ``substs``.
    """
    if frozenset(expr.inames) != inames_set:
        # Not the reduction being targeted: keep it, but keep looking inside.
        return type(expr)(
                operation=expr.operation,
                inames=expr.inames,
                expr=rec(expr.expr),
                allow_simultaneous=expr.allow_simultaneous)

    # Use the caller-supplied rule name if there is one, else generate a
    # fresh name derived from the reduction inames.
    if subst_rule_name is not None:
        rule_name = subst_rule_name
    else:
        rule_name = var_name_gen("red_%s_arg" % "_".join(inames))

    if rule_name in substs:
        raise LoopyError("substitution rule '%s' already exists"
                % rule_name)

    from loopy.kernel.data import SubstitutionRule
    substs[rule_name] = SubstitutionRule(
            name=rule_name,
            arguments=tuple(inames),
            expression=expr.expr)

    from pymbolic import var
    rule_invocation = var(rule_name)(*[var(iname) for iname in inames])

    return type(expr)(
            operation=expr.operation,
            inames=expr.inames,
            expr=rule_invocation,
            allow_simultaneous=expr.allow_simultaneous)
def _get_new_substitutions_and_renames(self):
    """Build the substitution-rule dictionary that results from mapping all
    encountered expressions.

    Each registered (mapped) rule takes the smallest of the original names
    recorded for it in ``self.subst_rule_old_names``. If that name was
    already handed out, ``self.make_unique_var_name`` is used to obtain a
    fresh one. Rules with no recorded original name were never referenced
    and are dropped.

    Because the final names are only known after this pass, a renaming
    dictionary is also returned so callers can apply it to the processed
    expressions. The returned substitutions already have the renames applied.

    :returns: (new_substitutions, subst_renames)
    """
    from loopy.kernel.data import SubstitutionRule

    chosen_rules = {}
    renames = {}
    taken_names = set()

    for key, (name, args, body) in six.iteritems(self.subst_rule_registry):
        candidate_names = self.subst_rule_old_names.get(key, [])

        # No original name recorded: this rule was never referenced, so it
        # can be left out.
        if not candidate_names:
            continue

        chosen = min(candidate_names)
        if chosen in taken_names:
            chosen = self.make_unique_var_name(chosen)

        renames[name] = chosen
        taken_names.add(chosen)

        chosen_rules[chosen] = SubstitutionRule(
                name=chosen,
                arguments=args,
                expression=body)

    # {{{ perform renames on new substitutions

    renamer = SubstitutionRuleRenamer(renames)
    renamed_rules = dict(
            (rule_name, rule.copy(expression=renamer(rule.expression)))
            for rule_name, rule in six.iteritems(chosen_rules))

    # }}}

    return renamed_rules, renames
def _get_new_substitutions_and_renames(self):
    """Build the substitution-rule dictionary that results from mapping all
    encountered expressions.

    Only registry entries with a nonzero count in
    ``self.subst_rule_use_count`` are kept. When exactly one used variant
    derives from a given original rule name, that variant takes the original
    name back; otherwise variants keep their generated names.

    Because this cannot be decided in a single pass, a renaming dictionary is
    also returned for callers to apply to the processed expressions. The
    returned substitutions already have the renames applied.

    :returns: (new_substitutions, subst_renames)
    """
    from loopy.kernel.data import SubstitutionRule

    registry = self.subst_rule_registry
    use_count = self.subst_rule_use_count

    # For each original name, count how many *used* variants derive from it.
    variants_per_orig_name = {}
    for key, (_, orig_name) in six.iteritems(registry):
        if use_count.get(key, 0):
            variants_per_orig_name[orig_name] = \
                    variants_per_orig_name.get(orig_name, 0) + 1

    result = {}
    renames = {}

    for key, (name, orig_name) in six.iteritems(registry):
        # Registry keys are (arguments, body) pairs.
        args, body = key

        if not use_count.get(key, 0):
            continue

        # Sole used variant of its original rule: restore the original name.
        if variants_per_orig_name[orig_name] == 1 and name != orig_name:
            renames[name] = orig_name
            name = orig_name

        result[name] = SubstitutionRule(
                name=name,
                arguments=args,
                expression=body)

    # {{{ perform renames on new substitutions

    renamer = SubstitutionRuleRenamer(renames)
    renamed_result = dict(
            (rule_name, rule.copy(expression=renamer(rule.expression)))
            for rule_name, rule in six.iteritems(result))

    # }}}

    return renamed_result, renames
def extract_subst(kernel, subst_name, template, parameters=()):
    """Replace every subexpression of *kernel* that unifies with *template* by
    an invocation of a new substitution rule named *subst_name*.

    :arg subst_name: The name of the substitution rule to be created.
    :arg template: Unification template expression, either a :mod:`pymbolic`
        expression or a string that is parsed into one.
    :arg parameters: An iterable of parameters used in *template*, or a
        comma-separated string of the same.
    :raises RuntimeError: if a subexpression unifies with *template* in more
        than one way, or if no subexpression matches at all.
    :returns: a copy of *kernel* with rewritten instructions and
        substitution rules, plus the new rule *subst_name*.

    All targeted subexpressions must match ('unify with') *template*.
    The template may contain '*' wildcards that will have to match exactly
    across all unifications.
    """

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        parameters = tuple(
                s.strip() for s in parameters.split(","))

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        based_on = subst_name + "_wc"

        result = var_name_gen(based_on)
        return result

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(
            lhs_mapping_candidates=set(parameters))

    # Instruction currently being scanned, or None while scanning substitution
    # rules. Tracked explicitly instead of reading the enclosing loop variable.
    # Reading the loop variable misattributed substitution-rule matches to the
    # last instruction, and raised NameError for kernels without instructions.
    current_insn = [None]

    def gather_exprs(expr, mapper):
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError("ambiguous unification of '%s' with template '%s'"
                        % (expr, template))

            urec, = urecs

            expr_descriptors.append(
                    ExprDescriptor(
                        insn=current_insn[0],
                        expr=expr,
                        unif_var_dict=dict((lhs.name, rhs)
                            for lhs, rhs in urec.equations)))
            # Matched: can't nest, don't recurse.
        else:
            mapper.fallback_mapper(expr)

    from loopy.symbolic import (
            CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    for insn in kernel.instructions:
        current_insn[0] = insn
        dfmapper(insn.assignees)
        dfmapper(insn.expression)

    current_insn[0] = None
    for sr in six.itervalues(kernel.substitutions):
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    def replace_exprs(expr, mapper):
        # Matching is by object identity with the gathered subexpressions.
        found = False
        for exprd in expr_descriptors:
            if expr is exprd.expr:
                found = True
                break

        if not found:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name]
                for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        # Replaced: can't nest, don't recurse.
        return result

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    new_insns = []

    for insn in kernel.instructions:
        new_insns.append(insn.with_transformed_expressions(cbmapper))

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
            subst_name: SubstitutionRule(
                name=subst_name,
                arguments=tuple(parameters),
                expression=template,
                )}

    for subst in six.itervalues(kernel.substitutions):
        new_substs[subst.name] = subst.copy(
                expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(
            instructions=new_insns,
            substitutions=new_substs)
def assignment_to_subst(kernel, lhs_name, extra_arguments=(), within=None,
        force_retain_argument=False):
    """Extract an assignment (to a temporary variable or an argument)
    as a :ref:`substitution-rule`. The temporary may be an array, in
    which case the array indices will become arguments to the substitution
    rule.

    :arg extra_arguments: an iterable of additional argument names appended
        to the rule's index arguments, or a comma-separated string of the
        same.
    :arg within: a stack match as understood by
        :func:`loopy.match.parse_stack_match`.
    :arg force_retain_argument: If True and if *lhs_name* is an argument, it
        is kept even if it is no longer referenced.
    :raises LoopyError: if no definition of *lhs_name* exists or cannot be
        uniquely resolved for a usage site, if a defining instruction both
        reads and writes *lhs_name*, if a defining instruction is not a plain
        assignment, or if its left-hand side is not ``var`` or ``var[i, j, ...]``
        with plain-variable indices.

    This operation will change all usage sites
    of *lhs_name* matched by *within*. If there
    are further usage sites of *lhs_name*, then
    the original assignment to *lhs_name* as well
    as the temporary variable is left in place.
    """

    if isinstance(extra_arguments, str):
        extra_arguments = tuple(s.strip() for s in extra_arguments.split(","))
    else:
        # Accept any iterable: the argument tuple is built below by
        # concatenation, which would fail for e.g. a list.
        extra_arguments = tuple(extra_arguments)

    # {{{ establish the relevant definition of lhs_name for each usage site

    dep_kernel = expand_subst(kernel)
    from loopy.kernel.creation import apply_single_writer_depencency_heuristic
    dep_kernel = apply_single_writer_depencency_heuristic(dep_kernel)

    id_to_insn = dep_kernel.id_to_insn

    # Memoized results of get_relevant_definition_insn_id. Without caching,
    # the recursion revisits shared ancestors once per dependency path,
    # which is exponential for diamond-shaped dependency graphs.
    definition_cache = {}

    def get_relevant_definition_insn_id(usage_insn_id):
        """Return the id of the unique instruction writing *lhs_name* that
        *usage_insn_id* (transitively) depends on, or None if there is none.
        """
        try:
            return definition_cache[usage_insn_id]
        except KeyError:
            pass

        insn = id_to_insn[usage_insn_id]

        def_id = set()
        for dep_id in insn.depends_on:
            dep_insn = id_to_insn[dep_id]
            if lhs_name in dep_insn.write_dependency_names():
                if lhs_name in dep_insn.read_dependency_names():
                    raise LoopyError("instruction '%s' both reads *and* "
                            "writes '%s'--cannot transcribe to substitution "
                            "rule" % (dep_id, lhs_name))

                def_id.add(dep_id)
            else:
                # This dependency does not write lhs_name itself; look
                # further up the dependency tree.
                rec_result = get_relevant_definition_insn_id(dep_id)
                if rec_result is not None:
                    def_id.add(rec_result)

        if len(def_id) > 1:
            raise LoopyError("more than one write to '%s' found in "
                    "dependencies of '%s'--definition cannot be resolved "
                    "(writer instructions ids: %s)"
                    % (lhs_name, usage_insn_id, ", ".join(sorted(def_id))))

        if not def_id:
            result = None
        else:
            result, = def_id

        definition_cache[usage_insn_id] = result
        return result

    usage_to_definition = {}

    for insn in dep_kernel.instructions:
        if lhs_name not in insn.read_dependency_names():
            continue

        def_id = get_relevant_definition_insn_id(insn.id)
        if def_id is None:
            raise LoopyError("no write to '%s' found in dependency tree "
                    "of '%s'--definition cannot be resolved"
                    % (lhs_name, insn.id))

        usage_to_definition[insn.id] = def_id

    definition_insn_ids = set()
    for insn in kernel.instructions:
        if lhs_name in insn.write_dependency_names():
            definition_insn_ids.add(insn.id)

    # }}}

    if not definition_insn_ids:
        raise LoopyError("no assignments to variable '%s' found"
                % lhs_name)

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    tts = AssignmentToSubstChanger(rule_mapping_context,
            lhs_name, definition_insn_ids,
            usage_to_definition, extra_arguments, within)

    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))

    from loopy.kernel.data import SubstitutionRule

    # {{{ create new substitution rules

    new_substs = kernel.substitutions.copy()
    for def_id, subst_name in six.iteritems(
            tts.definition_insn_id_to_subst_name):
        def_insn = kernel.id_to_insn[def_id]

        from loopy.kernel.data import Assignment
        if not isinstance(def_insn, Assignment):
            # Validated explicitly (not via assert) so the check survives -O.
            raise LoopyError("defining instruction '%s' for '%s' is not an "
                    "assignment (got %s)--cannot transcribe to substitution "
                    "rule" % (def_id, lhs_name, type(def_insn).__name__))

        from pymbolic.primitives import Variable, Subscript
        if isinstance(def_insn.assignee, Subscript):
            indices = def_insn.assignee.index_tuple
        elif isinstance(def_insn.assignee, Variable):
            indices = ()
        else:
            raise LoopyError(
                    "Unrecognized LHS type: %s"
                    % type(def_insn.assignee).__name__)

        arguments = []

        for i in indices:
            if not isinstance(i, Variable):
                raise LoopyError("In defining instruction '%s': "
                        "assignee index '%s' is not a plain variable. "
                        "Perhaps use loopy.affine_map_inames() "
                        "to perform substitution." % (def_id, i))

            arguments.append(i.name)

        new_substs[subst_name] = SubstitutionRule(
                name=subst_name,
                arguments=tuple(arguments) + extra_arguments,
                expression=def_insn.expression)

    # }}}

    # {{{ delete temporary variable if possible

    # (copied below if modified)
    new_temp_vars = kernel.temporary_variables
    new_args = kernel.args

    # True if every usage site was matched and is now a substitution rule.
    all_usages_matched = not any(six.itervalues(tts.saw_unmatched_usage_sites))

    if lhs_name in kernel.temporary_variables and all_usages_matched:
        # We can get rid of the variable.
        new_temp_vars = new_temp_vars.copy()
        del new_temp_vars[lhs_name]

    if (lhs_name in kernel.arg_dict
            and not force_retain_argument
            and all_usages_matched):
        # We can get rid of the argument (first occurrence only).
        new_args = new_args[:]
        for i, arg in enumerate(new_args):
            if arg.name == lhs_name:
                del new_args[i]
                break

    # }}}

    import loopy as lp
    # Drop defining instructions whose value is no longer read anywhere.
    kernel = lp.remove_instructions(
            kernel,
            {insn_id
                for insn_id, still_used in six.iteritems(
                    tts.saw_unmatched_usage_sites)
                if not still_used})

    return kernel.copy(
            substitutions=new_substs,
            temporary_variables=new_temp_vars,
            args=new_args,
            )
def extract_subst(kernel, subst_name, template, parameters=(), within=None):
    """Replace every subexpression that unifies with *template* by an
    invocation of a new substitution rule named *subst_name*.

    :arg subst_name: The name of the substitution rule to be created.
    :arg template: Unification template expression, either a :mod:`pymbolic`
        expression or a string that is parsed into one.
    :arg parameters: An iterable of parameters used in *template*, or a
        comma-separated string of the same.
    :arg within: An instance of :class:`loopy.match.MatchExpressionBase` or
        :class:`str` as understood by :func:`loopy.match.parse_match`.
        Only instructions matched by *within* are searched and rewritten.
        Substitution rules are always searched and rewritten.
    :raises LoopyError: if *kernel* is a translation unit that does not
        contain exactly one kernel.
    :raises RuntimeError: if a subexpression unifies with *template* in more
        than one way, or if no subexpression matches at all.

    All targeted subexpressions must match ('unify with') *template*.
    The template may contain '*' wildcards that will have to match exactly
    across all unifications.
    """

    if isinstance(kernel, TranslationUnit):
        kernel_names = [
                i for i, clbl in kernel.callables_table.items()
                if isinstance(clbl, CallableKernel)
                ]
        if len(kernel_names) != 1:
            raise LoopyError(
                    "extract_subst requires a translation unit containing "
                    "exactly one kernel, found %d" % len(kernel_names))

        # Forward *within*; previously it was dropped here, so the
        # restriction was silently ignored for translation units.
        return kernel.with_kernel(
                extract_subst(kernel[kernel_names[0]], subst_name,
                    template, parameters, within=within))

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        parameters = tuple(
                s.strip() for s in parameters.split(","))

    from loopy.match import parse_match
    within = parse_match(within)

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        based_on = subst_name + "_wc"

        result = var_name_gen(based_on)
        return result

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(
            lhs_mapping_candidates=set(parameters))

    # Instruction currently being scanned, or None while scanning substitution
    # rules. Tracked explicitly instead of reading the enclosing loop variable.
    # Reading the loop variable misattributed substitution-rule matches to the
    # last instruction, and raised NameError for kernels without instructions.
    current_insn = [None]

    def gather_exprs(expr, mapper):
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError("ambiguous unification of '%s' with template '%s'"
                        % (expr, template))

            urec, = urecs

            expr_descriptors.append(
                    ExprDescriptor(
                        insn=current_insn[0],
                        expr=expr,
                        unif_var_dict={
                            lhs.name: rhs
                            for lhs, rhs in urec.equations
                            }))
            # Matched: can't nest, don't recurse.
        else:
            mapper.fallback_mapper(expr)

    from loopy.symbolic import (
            CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    from loopy.kernel.instruction import MultiAssignmentBase
    for insn in kernel.instructions:
        if isinstance(insn, MultiAssignmentBase) and within(kernel, insn):
            current_insn[0] = insn
            dfmapper(insn.assignees)
            dfmapper(insn.expression)

    current_insn[0] = None
    for sr in kernel.substitutions.values():
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    def replace_exprs(expr, mapper):
        # Matching is by object identity with the gathered subexpressions.
        found = False
        for exprd in expr_descriptors:
            if expr is exprd.expr:
                found = True
                break

        if not found:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name]
                for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        # Replaced: can't nest, don't recurse.
        return result

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    new_insns = []

    def transform_assignee(expr):
        # Assignment LHS's cannot be subst rules. Treat them
        # specially: only the index expressions may be rewritten.
        import pymbolic.primitives as prim
        if isinstance(expr, tuple):
            return tuple(
                    transform_assignee(expr_i)
                    for expr_i in expr)

        elif isinstance(expr, prim.Subscript):
            return type(expr)(expr.aggregate, cbmapper(expr.index))

        elif isinstance(expr, prim.Variable):
            return expr
        else:
            raise ValueError("assignment LHS not understood")

    for insn in kernel.instructions:
        if within(kernel, insn):
            new_insns.append(
                    insn.with_transformed_expressions(
                        cbmapper, assignee_f=transform_assignee))
        else:
            new_insns.append(insn)

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
            subst_name: SubstitutionRule(
                name=subst_name,
                arguments=tuple(parameters),
                expression=template,
                )}

    for subst in kernel.substitutions.values():
        new_substs[subst.name] = subst.copy(
                expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(
            instructions=new_insns,
            substitutions=new_substs)