def diff_kernel(knl, diff_outputs, by, diff_iname_prefix="diff_i",
        batch_axes_in_by=frozenset(), copy_outputs=frozenset()):
    """Differentiate the variables named in *diff_outputs* by the kernel
    argument named *by* and return the resulting kernel.

    :arg knl: the kernel to differentiate. Default dependencies are added
        to it before any differentiation takes place.
    :arg diff_outputs: the names of the variables to differentiate, either
        as an iterable of names or as one comma-separated string (empty
        entries and surrounding whitespace are dropped).
    :arg by: the name of the kernel argument to differentiate by. Its shape
        is handed to :class:`DifferentiationContext` as *additional_shape*.
    :arg diff_iname_prefix: handed to :class:`DifferentiationContext`
        unchanged.
    :arg batch_axes_in_by: a :class:`set` of axis indices in the variable
        named *by* that are not part of the differentiation. This function
        itself does not currently read this argument.
    :arg copy_outputs: an iterable of variable names, each of which is
        passed to ``import_output_var`` of the differentiation context.

    :return: a tuple ``(new_kernel, result)``, where *result* is the value
        ``get_diff_var`` returned for the *last* entry of *diff_outputs*,
        or an empty :class:`dict` if *diff_outputs* is empty.
    """

    from loopy.preprocess import add_default_dependencies
    knl = add_default_dependencies(knl)

    if isinstance(diff_outputs, str):
        diff_outputs = [
                dout.strip() for dout in diff_outputs.split(",")
                if dout.strip()]

    by_arg = knl.arg_dict[by]
    additional_shape = by_arg.shape

    var_name_gen = knl.get_var_name_generator()

    # {{{ differentiate instructions

    diff_context = DifferentiationContext(
            knl, var_name_gen, by,
            diff_iname_prefix=diff_iname_prefix,
            additional_shape=additional_shape)

    # NOTE(review): 'result' is overwritten on every iteration, so only the
    # derivative of the last requested output is returned. Kept as is so
    # that the shape of the return value does not change for callers.
    result = {}
    for dout in diff_outputs:
        result = diff_context.get_diff_var(dout)

    for cout in copy_outputs:
        diff_context.import_output_var(cout)

    # }}}

    return diff_context.get_new_kernel(), result
def diff_kernel(knl, diff_outputs, by, diff_iname_prefix="diff_i",
        batch_axes_in_by=frozenset(), copy_outputs=set()):
    """Build a kernel holding derivatives of *diff_outputs* by *by*.

    :arg diff_outputs: names of the variables to differentiate, given as an
        iterable of names or as a comma-separated string.
    :arg by: name of the kernel argument to differentiate by.
    :arg batch_axes_in_by: a :class:`set` of axis indices in the variable
        named *by* that are not part of the differentiation.
    :arg copy_outputs: names handed to ``import_output_var`` one by one.

    :return: a tuple of the new kernel and the value ``get_diff_var``
        returned for the last name in *diff_outputs* (an empty
        :class:`dict` if there were no names).
    """

    from loopy.preprocess import add_default_dependencies
    knl = add_default_dependencies(knl)

    if isinstance(diff_outputs, str):
        # Comma-separated form: keep only the non-blank, stripped names.
        stripped_names = (piece.strip() for piece in diff_outputs.split(","))
        diff_outputs = [name for name in stripped_names if name]

    shape_of_by = knl.arg_dict[by].shape
    name_generator = knl.get_var_name_generator()

    # {{{ differentiate instructions

    diff_context = DifferentiationContext(
            knl, name_generator, by,
            diff_iname_prefix=diff_iname_prefix,
            additional_shape=shape_of_by)

    result = {}
    for output_name in diff_outputs:
        result = diff_context.get_diff_var(output_name)

    for copied_name in copy_outputs:
        diff_context.import_output_var(copied_name)

    # }}}

    new_kernel = diff_context.get_new_kernel()
    return new_kernel, result
def stringify(self, what=None, with_dependencies=False):
    """Return a multi-line, human-readable description of this kernel.

    :arg what: selects the sections to include. May be

        * *None*: every section, with ``"Dependencies"`` included only if
          *with_dependencies* is true,
        * a :class:`set` of section names,
        * a comma-separated string of section names, or
        * a string without commas in which every character is the first
          letter of a section name.

        The section names are ``"name"``, ``"arguments"``, ``"domains"``,
        ``"tags"``, ``"variables"``, ``"rules"``, ``"instructions"``,
        ``"Dependencies"`` and ``"schedule"``.
    :arg with_dependencies: only consulted when *what* is *None*.

    :raises LoopyError: if *what* contains an unknown section name, or if
        an instruction of an unexpected type is encountered while printing
        the instructions section.
    :return: the description as a single string, terminated by a separator
        line.
    """
    all_what = set([
        "name",
        "arguments",
        "domains",
        "tags",
        "variables",
        "rules",
        "instructions",
        "Dependencies",
        "schedule",
        ])

    first_letter_to_what = dict(
            (w[0], w)
            for w in all_what)
    # The single-letter form only works while first letters are unique.
    assert len(first_letter_to_what) == len(all_what)

    if what is None:
        what = all_what.copy()
        if not with_dependencies:
            what.remove("Dependencies")

    if isinstance(what, str):
        if "," in what:
            what = set(s.strip() for s in what.split(","))
        else:
            what = set(
                    first_letter_to_what[w]
                    for w in what)

    if not (what <= all_what):
        raise LoopyError("invalid 'what' passed: %s"
                % ", ".join(what-all_what))

    lines = []

    from loopy.preprocess import add_default_dependencies
    kernel = add_default_dependencies(self)

    sep = 75*"-"

    if "name" in what:
        lines.append(sep)
        lines.append("KERNEL: " + kernel.name)

    if "arguments" in what:
        lines.append(sep)
        lines.append("ARGUMENTS:")
        for arg_name in sorted(kernel.arg_dict):
            lines.append(str(kernel.arg_dict[arg_name]))

    if "domains" in what:
        lines.append(sep)
        lines.append("DOMAINS:")
        # Indent each domain by its number of parent domains.
        for dom, parents in zip(kernel.domains, kernel.all_parents_per_domain()):
            lines.append(len(parents)*" " + str(dom))

    if "tags" in what:
        lines.append(sep)
        lines.append("INAME IMPLEMENTATION TAGS:")
        for iname in sorted(kernel.all_inames()):
            line = "%s: %s" % (iname, kernel.iname_to_tag.get(iname))
            lines.append(line)

    if "variables" in what and kernel.temporary_variables:
        lines.append(sep)
        lines.append("TEMPORARIES:")
        for tv in sorted(six.itervalues(kernel.temporary_variables),
                key=lambda tv: tv.name):
            lines.append(str(tv))

    if "rules" in what and kernel.substitutions:
        lines.append(sep)
        lines.append("SUBSTIUTION RULES:")
        for rule_name in sorted(six.iterkeys(kernel.substitutions)):
            lines.append(str(kernel.substitutions[rule_name]))

    if "instructions" in what:
        lines.append(sep)
        lines.append("INSTRUCTIONS:")
        loop_list_width = 35

        import loopy as lp

        printed_insn_ids = set()

        Fore = self.options._fore
        Style = self.options._style

        def print_insn(insn):
            # Print each instruction once, after everything it depends on.
            if insn.id in printed_insn_ids:
                return
            printed_insn_ids.add(insn.id)

            for dep_id in sorted(insn.depends_on):
                print_insn(kernel.id_to_insn[dep_id])

            if isinstance(insn, lp.MultiAssignmentBase):
                lhs = ", ".join(str(a) for a in insn.assignees)
                rhs = str(insn.expression)
                trailing = []
            elif isinstance(insn, lp.CInstruction):
                lhs = ", ".join(str(a) for a in insn.assignees)
                rhs = "CODE(%s|%s)" % (
                        ", ".join(str(x) for x in insn.read_variables),
                        ", ".join("%s=%s" % (name, expr)
                            for name, expr in insn.iname_exprs))

                trailing = [" "+l for l in insn.code.split("\n")]
            else:
                # Without this branch lhs/rhs/trailing would be unbound
                # below and surface as an opaque UnboundLocalError.
                raise LoopyError("unexpected instruction type: %s"
                        % type(insn).__name__)

            loop_list = ",".join(sorted(kernel.insn_inames(insn)))

            options = [Fore.GREEN+insn.id+Style.RESET_ALL]
            if insn.priority:
                options.append("priority=%d" % insn.priority)
            if insn.tags:
                options.append("tags=%s" % ":".join(insn.tags))
            if isinstance(insn, lp.Assignment) and insn.atomicity:
                options.append("atomic=%s" % ":".join(
                    str(a) for a in insn.atomicity))
            if insn.groups:
                options.append("groups=%s" % ":".join(insn.groups))
            if insn.conflicts_with_groups:
                options.append(
                        "conflicts=%s" % ":".join(insn.conflicts_with_groups))
            if insn.no_sync_with:
                options.append("no_sync_with=%s" % ":".join(insn.no_sync_with))

            if len(loop_list) > loop_list_width:
                # Loop list too wide for its column: give it its own line.
                lines.append("[%s]" % loop_list)
                lines.append("%s%s <- %s # %s" % (
                    (loop_list_width+2)*" ", Fore.BLUE+lhs+Style.RESET_ALL,
                    Fore.MAGENTA+rhs+Style.RESET_ALL,
                    ", ".join(options)))
            else:
                lines.append("[%s]%s%s <- %s # %s" % (
                    loop_list, " "*(loop_list_width-len(loop_list)),
                    Fore.BLUE + lhs + Style.RESET_ALL,
                    Fore.MAGENTA+rhs+Style.RESET_ALL,
                    ",".join(options)))

            lines.extend(trailing)

            if insn.predicates:
                lines.append(10*" " + "if (%s)" % " && ".join(insn.predicates))

        for insn in kernel.instructions:
            print_insn(insn)

    dep_lines = []
    for insn in kernel.instructions:
        if insn.depends_on:
            dep_lines.append("%s : %s" % (insn.id, ",".join(insn.depends_on)))

    if "Dependencies" in what and dep_lines:
        lines.append(sep)
        lines.append("DEPENDENCIES: "
                "(use loopy.show_dependency_graph to visualize)")
        lines.extend(dep_lines)

    if "schedule" in what and kernel.schedule is not None:
        lines.append(sep)
        lines.append("SCHEDULE:")
        from loopy.schedule import dump_schedule
        lines.append(dump_schedule(kernel, kernel.schedule))

    lines.append(sep)

    return "\n".join(lines)
def stringify(self, with_dependencies=False):
    """Return a multi-line, human-readable description of this kernel.

    The text lists, in order: the kernel name, its arguments, domains,
    iname implementation tags, temporaries (if any), substitution rules
    (if any), instructions, a dependencies section (if any instruction has
    dependencies) and the schedule (if one is present). Sections are
    delimited by a line of 75 dashes.

    :arg with_dependencies: if true, the per-instruction dependency lines
        are listed under the ``DEPENDENCIES`` heading; otherwise only the
        heading is emitted.
    :return: the description as a single string.
    """
    lines = []

    # Work on a copy of the kernel that has default dependencies added.
    from loopy.preprocess import add_default_dependencies
    kernel = add_default_dependencies(self)

    sep = 75*"-"

    lines.append(sep)
    lines.append("KERNEL: " + kernel.name)

    lines.append(sep)
    lines.append("ARGUMENTS:")
    for arg_name in sorted(kernel.arg_dict):
        lines.append(str(kernel.arg_dict[arg_name]))

    lines.append(sep)
    lines.append("DOMAINS:")
    # Each domain is indented by its number of parent domains.
    for dom, parents in zip(kernel.domains, kernel.all_parents_per_domain()):
        lines.append(len(parents)*" " + str(dom))

    lines.append(sep)
    lines.append("INAME IMPLEMENTATION TAGS:")
    for iname in sorted(kernel.all_inames()):
        line = "%s: %s" % (iname, kernel.iname_to_tag.get(iname))
        lines.append(line)

    if kernel.temporary_variables:
        lines.append(sep)
        lines.append("TEMPORARIES:")
        for tv in sorted(six.itervalues(kernel.temporary_variables),
                key=lambda tv: tv.name):
            lines.append(str(tv))

    if kernel.substitutions:
        lines.append(sep)
        lines.append("SUBSTIUTION RULES:")
        for rule_name in sorted(six.iterkeys(kernel.substitutions)):
            lines.append(str(kernel.substitutions[rule_name]))

    lines.append(sep)
    lines.append("INSTRUCTIONS:")
    # Width of the column that holds the bracketed list of inames.
    loop_list_width = 35

    printed_insn_ids = set()

    def print_insn(insn):
        # Emit each instruction at most once, after all of its
        # dependencies have been emitted.
        if insn.id in printed_insn_ids:
            return
        printed_insn_ids.add(insn.id)

        for dep_id in sorted(insn.insn_deps):
            print_insn(kernel.id_to_insn[dep_id])

        # NOTE(review): an instruction that is neither an Assignment nor a
        # CInstruction leaves lhs/rhs/trailing unassigned and fails below.
        if isinstance(insn, lp.Assignment):
            lhs = str(insn.assignee)
            rhs = str(insn.expression)
            trailing = []
        elif isinstance(insn, lp.CInstruction):
            lhs = ", ".join(str(a) for a in insn.assignees)
            rhs = "CODE(%s|%s)" % (
                    ", ".join(str(x) for x in insn.read_variables),
                    ", ".join("%s=%s" % (name, expr)
                        for name, expr in insn.iname_exprs))

            # The C code itself follows the summary line, one line each.
            trailing = [" "+l for l in insn.code.split("\n")]

        loop_list = ",".join(sorted(kernel.insn_inames(insn)))

        # Trailing comment: the instruction id followed by any set options.
        options = [insn.id]
        if insn.priority:
            options.append("priority=%d" % insn.priority)
        if insn.tags:
            options.append("tags=%s" % ":".join(insn.tags))
        if insn.groups:
            options.append("groups=%s" % ":".join(insn.groups))
        if insn.conflicts_with_groups:
            options.append("conflicts=%s" % ":".join(insn.conflicts_with_groups))

        if len(loop_list) > loop_list_width:
            # Iname list does not fit its column: put it on its own line.
            lines.append("[%s]" % loop_list)
            lines.append("%s%s <- %s # %s" % (
                (loop_list_width+2)*" ", lhs,
                rhs, ", ".join(options)))
        else:
            lines.append("[%s]%s%s <- %s # %s" % (
                loop_list, " "*(loop_list_width-len(loop_list)),
                lhs, rhs, ",".join(options)))

        lines.extend(trailing)

        if insn.predicates:
            lines.append(10*" " + "if (%s)" % " && ".join(insn.predicates))

    # 'lp' is looked up when print_insn runs, so importing it here, after
    # the definition but before the first call, is sufficient.
    import loopy as lp
    for insn in kernel.instructions:
        print_insn(insn)

    dep_lines = []
    for insn in kernel.instructions:
        if insn.insn_deps:
            dep_lines.append("%s : %s" % (insn.id, ",".join(insn.insn_deps)))

    if dep_lines:
        lines.append(sep)
        lines.append("DEPENDENCIES: "
                "(use loopy.show_dependency_graph to visualize)")

        # The heading is always shown; the details only on request.
        if with_dependencies:
            lines.extend(dep_lines)

    lines.append(sep)

    if kernel.schedule is not None:
        lines.append("SCHEDULE:")
        from loopy.schedule import dump_schedule
        lines.append(dump_schedule(kernel, kernel.schedule))
        lines.append(sep)

    return "\n".join(lines)
def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None):
    """Extract an assignment to a temporary variable
    as a :ref:`substitution-rule`. The temporary may be an array, in
    which case the array indices will become arguments to the substitution
    rule.

    :arg kernel: the kernel to transform.
    :arg temp_name: the name of the temporary variable whose assignment is
        to be turned into a substitution rule.
    :arg extra_arguments: additional argument names that are appended to
        the arguments of each new substitution rule. May be given as a
        tuple or as a comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.

    This operation will change all usage sites
    of *temp_name* matched by *within*. If there
    are further usage sites of *temp_name*, then
    the original assignment to *temp_name* as well as the temporary variable
    is left in place.

    :raises LoopyError: if a defining instruction also reads *temp_name*,
        if the definition reaching a usage site cannot be resolved to
        exactly one instruction, or if an index of the assignee in a
        defining instruction is not a plain variable.
    :return: a copy of the kernel with updated substitution rules and
        temporary variables.
    """

    if isinstance(extra_arguments, str):
        extra_arguments = tuple(s.strip() for s in extra_arguments.split(","))

    # {{{ establish the relevant definition of temp_name for each usage site

    # Dependencies are looked up on a copy with substitutions expanded and
    # default dependencies added; 'kernel' itself is left untouched here.
    dep_kernel = expand_subst(kernel)
    from loopy.preprocess import add_default_dependencies
    dep_kernel = add_default_dependencies(dep_kernel)

    id_to_insn = dep_kernel.id_to_insn

    def get_relevant_definition_insn_id(usage_insn_id):
        # Walk the dependencies of 'usage_insn_id' and return the id of the
        # single instruction writing temp_name that is found there, or
        # None if there is none. The search does not continue past a
        # writer; for non-writers it recurses into their dependencies.
        insn = id_to_insn[usage_insn_id]

        def_id = set()
        for dep_id in insn.insn_deps:
            dep_insn = id_to_insn[dep_id]
            if temp_name in dep_insn.write_dependency_names():
                if temp_name in dep_insn.read_dependency_names():
                    raise LoopyError("instruction '%s' both reads *and* "
                            "writes '%s'--cannot transcribe to substitution "
                            "rule" % (dep_id, temp_name))

                def_id.add(dep_id)
            else:
                rec_result = get_relevant_definition_insn_id(dep_id)
                if rec_result is not None:
                    def_id.add(rec_result)

        if len(def_id) > 1:
            raise LoopyError("more than one write to '%s' found in "
                    "depdendencies of '%s'--definition cannot be resolved "
                    "(writer instructions ids: %s)"
                    % (temp_name, usage_insn_id, ", ".join(def_id)))

        if not def_id:
            return None
        else:
            # Exactly one element is left: unpack it.
            def_id, = def_id

        return def_id

    # Maps the id of each instruction reading temp_name to the id of the
    # instruction that defines the value it reads.
    usage_to_definition = {}

    for insn in kernel.instructions:
        if temp_name not in insn.read_dependency_names():
            continue

        def_id = get_relevant_definition_insn_id(insn.id)
        if def_id is None:
            raise LoopyError("no write to '%s' found in dependency tree "
                    "of '%s'--definition cannot be resolved"
                    % (temp_name, insn.id))

        usage_to_definition[insn.id] = def_id

    # Ids of all instructions that write temp_name.
    definition_insn_ids = set()
    for insn in kernel.instructions:
        if temp_name in insn.write_dependency_names():
            definition_insn_ids.add(insn.id)

    # }}}

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    tts = TemporaryToSubstChanger(rule_mapping_context,
            temp_name, definition_insn_ids,
            usage_to_definition, extra_arguments, within)

    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))

    from loopy.kernel.data import SubstitutionRule

    # {{{ create new substitution rules

    new_substs = kernel.substitutions.copy()
    for def_id, subst_name in six.iteritems(tts.definition_insn_id_to_subst_name):
        def_insn = kernel.id_to_insn[def_id]

        # Unpacking enforces that there is exactly one assignee.
        (_, indices), = def_insn.assignees_and_indices()

        arguments = []

        from pymbolic.primitives import Variable
        for i in indices:
            if not isinstance(i, Variable):
                raise LoopyError("In defining instruction '%s': "
                        "asignee index '%s' is not a plain variable. "
                        "Perhaps use loopy.affine_map_inames() "
                        "to perform substitution." % (def_id, i))

            arguments.append(i.name)

        # The assignee's index variables become the rule's arguments,
        # followed by the caller-supplied extra arguments.
        new_substs[subst_name] = SubstitutionRule(
                name=subst_name,
                arguments=tuple(arguments) + extra_arguments,
                expression=def_insn.expression)

    # }}}

    # {{{ delete temporary variable if possible

    new_temp_vars = kernel.temporary_variables
    if not any(six.itervalues(tts.saw_unmatched_usage_sites)):
        # All usage sites matched--they're now substitution rules.
        # We can get rid of the variable.

        new_temp_vars = new_temp_vars.copy()
        del new_temp_vars[temp_name]

    # }}}

    # Remove the instructions whose entry in saw_unmatched_usage_sites is
    # false (presumably the defining instructions that no unmatched usage
    # site still relies on -- confirm against TemporaryToSubstChanger).
    import loopy as lp
    kernel = lp.remove_instructions(
            kernel,
            set(
                insn_id
                for insn_id, still_used in six.iteritems(
                    tts.saw_unmatched_usage_sites)
                if not still_used))

    return kernel.copy(
            substitutions=new_substs,
            temporary_variables=new_temp_vars,
            )
def get_dot_dependency_graph(kernel, iname_cluster=True, use_insn_id=False):
    """Return a string in the `dot <http://graphviz.org/>`_ language depicting
    dependencies among kernel instructions.

    :arg iname_cluster: if true, instructions are grouped into one cluster
        per loop of the schedule. If the kernel has no schedule, one is
        obtained first; if that fails with a :class:`RuntimeError`, a
        warning is issued and clustering is skipped.
    :arg use_insn_id: if true, nodes are labeled with the instruction id
        and carry the operation as tooltip; otherwise the roles are swapped.
    """

    # make sure all automatically added stuff shows up
    from loopy.preprocess import add_default_dependencies
    kernel = add_default_dependencies(kernel)

    if iname_cluster and not kernel.schedule:
        try:
            from loopy.schedule import get_one_scheduled_kernel
            kernel = get_one_scheduled_kernel(kernel)
        except RuntimeError as sched_error:
            iname_cluster = False

            from warnings import warn
            warn("error encountered during scheduling for dep graph -- "
                    "cannot perform iname clustering: %s(%s)"
                    % (type(sched_error).__name__, sched_error))

    from loopy.kernel.data import MultiAssignmentBase, CInstruction

    def describe(insn):
        # Short textual form of what the instruction does.
        if isinstance(insn, MultiAssignmentBase):
            text = "%s <- %s" % (insn.assignees, insn.expression)
            return (text[:200] + "...") if len(text) > 200 else text
        if isinstance(insn, CInstruction):
            return "<C instruction %s>" % insn.id
        return "<instruction %s>" % insn.id

    # instruction id -> set of ids of the instructions it depends on
    deps_of = {}
    dot_lines = []

    for insn in kernel.instructions:
        op = describe(insn)
        if use_insn_id:
            insn_label, tooltip = insn.id, op
        else:
            insn_label, tooltip = op, insn.id

        node_fields = (
                insn.id,
                repr(insn_label)[1:-1],
                repr(tooltip)[1:-1],
                )
        dot_lines.append(
                "\"%s\" [label=\"%s\",shape=\"box\",tooltip=\"%s\"];"
                % node_fields)

        for dep in insn.depends_on:
            deps_of.setdefault(insn.id, set()).add(dep)

    # {{{ O(n^3) transitive reduction

    # first, compute transitive closure by fixed point iteration
    grew = True
    while grew:
        grew = False

        for src in deps_of:
            for mid in deps_of[src].copy():
                for far in deps_of.get(mid, set()).copy():
                    if far not in deps_of[src]:
                        grew = True
                        deps_of[src].add(far)

    # then drop every edge that is implied by a two-step path
    for src in deps_of:
        direct = deps_of[src]
        for mid in direct.copy():
            for far in deps_of.get(mid, set()).copy():
                direct.discard(far)

    # }}}

    dot_lines.extend(
            "%s -> %s" % (dep, src)
            for src in deps_of
            for dep in deps_of[src])

    if iname_cluster:
        from loopy.schedule import (
                EnterLoop, LeaveLoop, RunInstruction, Barrier)

        for sched_item in kernel.schedule:
            if isinstance(sched_item, EnterLoop):
                dot_lines.append("subgraph cluster_%s { label=\"%s\""
                        % (sched_item.iname, sched_item.iname))
            elif isinstance(sched_item, LeaveLoop):
                dot_lines.append("}")
            elif isinstance(sched_item, RunInstruction):
                dot_lines.append(sched_item.insn_id)
            elif not isinstance(sched_item, Barrier):
                # barriers contribute nothing to the graph
                raise LoopyError("schedule item not unterstood: %r" % sched_item)

    graph_body = "\n".join(dot_lines)
    return "digraph %s {\n%s\n}" % (kernel.name, graph_body)
def stringify(self, what=None, with_dependencies=False):
    """Render this kernel as a multi-line, human-readable string.

    :arg what: the sections to render: *None* for all of them (the
        ``"Dependencies"`` section only if *with_dependencies* is true), a
        :class:`set` of section names, a comma-separated string of section
        names, or a comma-free string made up of the first letters of the
        desired section names.
    :arg with_dependencies: only consulted when *what* is *None*.
    :raises LoopyError: on unknown section names and on instructions of an
        unexpected type.
    """
    section_names = (
            "name",
            "arguments",
            "domains",
            "tags",
            "variables",
            "rules",
            "instructions",
            "Dependencies",
            "schedule",
            )
    all_what = set(section_names)

    first_letter_to_what = dict((name[0], name) for name in all_what)
    assert len(first_letter_to_what) == len(all_what)

    if what is None:
        what = all_what.copy()
        if not with_dependencies:
            what.remove("Dependencies")

    if isinstance(what, str):
        if "," in what:
            what = set(piece.strip() for piece in what.split(","))
        else:
            what = set(first_letter_to_what[letter] for letter in what)

    if not what <= all_what:
        raise LoopyError("invalid 'what' passed: %s"
                % ", ".join(what - all_what))

    lines = []

    from loopy.preprocess import add_default_dependencies
    kernel = add_default_dependencies(self)

    sep = 75 * "-"

    def begin_section(heading):
        # Every section opens with a separator followed by its heading.
        lines.append(sep)
        lines.append(heading)

    if "name" in what:
        begin_section("KERNEL: " + kernel.name)

    if "arguments" in what:
        begin_section("ARGUMENTS:")
        lines.extend(
                str(kernel.arg_dict[arg_name])
                for arg_name in sorted(kernel.arg_dict))

    if "domains" in what:
        begin_section("DOMAINS:")
        parents_per_domain = kernel.all_parents_per_domain()
        for dom, parents in zip(kernel.domains, parents_per_domain):
            lines.append(len(parents) * " " + str(dom))

    if "tags" in what:
        begin_section("INAME IMPLEMENTATION TAGS:")
        for iname in sorted(kernel.all_inames()):
            lines.append("%s: %s" % (iname, kernel.iname_to_tag.get(iname)))

    if "variables" in what and kernel.temporary_variables:
        begin_section("TEMPORARIES:")
        temporaries = sorted(
                six.itervalues(kernel.temporary_variables),
                key=lambda tv: tv.name)
        lines.extend(str(tv) for tv in temporaries)

    if "rules" in what and kernel.substitutions:
        begin_section("SUBSTIUTION RULES:")
        lines.extend(
                str(kernel.substitutions[rule_name])
                for rule_name in sorted(six.iterkeys(kernel.substitutions)))

    if "instructions" in what:
        begin_section("INSTRUCTIONS:")
        loop_list_width = 35

        already_printed = set()

        Fore = self.options._fore
        Style = self.options._style

        def print_insn(insn):
            # Depth-first: an instruction follows everything it depends on
            # and is printed only once.
            if insn.id in already_printed:
                return
            already_printed.add(insn.id)

            for dep_id in sorted(insn.depends_on):
                print_insn(kernel.id_to_insn[dep_id])

            trailing = []
            if isinstance(insn, lp.MultiAssignmentBase):
                lhs = ", ".join(str(a) for a in insn.assignees)
                rhs = str(insn.expression)
            elif isinstance(insn, lp.CInstruction):
                lhs = ", ".join(str(a) for a in insn.assignees)
                read_vars = ", ".join(str(x) for x in insn.read_variables)
                iname_bindings = ", ".join(
                        "%s=%s" % (name, expr)
                        for name, expr in insn.iname_exprs)
                rhs = "CODE(%s|%s)" % (read_vars, iname_bindings)
                trailing = [" " + l for l in insn.code.split("\n")]
            elif isinstance(insn, lp.BarrierInstruction):
                lhs = ""
                rhs = "... %sbarrier" % insn.kind[0]
            elif isinstance(insn, lp.NoOpInstruction):
                lhs = ""
                rhs = "... nop"
            else:
                raise LoopyError("unexpected instruction type: %s"
                        % type(insn).__name__)

            loop_list = ",".join(sorted(kernel.insn_inames(insn)))

            options = [Fore.GREEN + insn.id + Style.RESET_ALL]
            if insn.priority:
                options.append("priority=%d" % insn.priority)
            if insn.tags:
                options.append("tags=%s" % ":".join(insn.tags))
            if isinstance(insn, lp.Assignment) and insn.atomicity:
                atomicity_strs = (str(a) for a in insn.atomicity)
                options.append("atomic=%s" % ":".join(atomicity_strs))
            if insn.groups:
                options.append("groups=%s" % ":".join(insn.groups))
            if insn.conflicts_with_groups:
                options.append(
                        "conflicts=%s" % ":".join(insn.conflicts_with_groups))
            if insn.no_sync_with:
                options.append(
                        "no_sync_with=%s" % ":".join(insn.no_sync_with))

            core = Fore.MAGENTA + rhs + Style.RESET_ALL
            if lhs:
                core = "%s <- %s" % (Fore.BLUE + lhs + Style.RESET_ALL, core)

            if len(loop_list) <= loop_list_width:
                padding = " " * (loop_list_width - len(loop_list))
                lines.append("[%s]%s%s # %s" % (
                    loop_list, padding, core, ",".join(options)))
            else:
                # The iname list overflows its column: give it its own line.
                lines.append("[%s]" % loop_list)
                lines.append("%s%s # %s" % (
                    (loop_list_width + 2) * " ", core, ", ".join(options)))

            lines.extend(trailing)

            if insn.predicates:
                lines.append(
                        10 * " " + "if (%s)" % " && ".join(insn.predicates))

        import loopy as lp
        for insn in kernel.instructions:
            print_insn(insn)

    dep_lines = [
            "%s : %s" % (insn.id, ",".join(insn.depends_on))
            for insn in kernel.instructions
            if insn.depends_on]

    if "Dependencies" in what and dep_lines:
        begin_section("DEPENDENCIES: "
                "(use loopy.show_dependency_graph to visualize)")
        lines.extend(dep_lines)

    if "schedule" in what and kernel.schedule is not None:
        begin_section("SCHEDULE:")
        from loopy.schedule import dump_schedule
        lines.append(dump_schedule(kernel, kernel.schedule))

    lines.append(sep)

    return "\n".join(lines)
def assignment_to_subst(kernel, lhs_name, extra_arguments=(), within=None,
        force_retain_argument=False):
    """Turn an assignment (to a temporary variable or an argument) into a
    :ref:`substitution-rule`. If the assignee is an array, its indices
    become the arguments of the substitution rule.

    :arg extra_arguments: further argument names appended to the arguments
        of every new rule, as a tuple or a comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.match.parse_stack_match`.
    :arg force_retain_argument: If True and if *lhs_name* is an argument, it is
        kept even if it is no longer referenced.

    All usage sites of *lhs_name* matched by *within* are changed. If
    further usage sites of *lhs_name* remain, the original assignment to
    *lhs_name* and the variable itself are left in place.
    """

    if isinstance(extra_arguments, str):
        extra_arguments = tuple(
                piece.strip() for piece in extra_arguments.split(","))

    # {{{ establish the relevant definition of lhs_name for each usage site

    from loopy.preprocess import add_default_dependencies
    dep_kernel = add_default_dependencies(expand_subst(kernel))

    id_to_insn = dep_kernel.id_to_insn

    def get_relevant_definition_insn_id(usage_insn_id):
        # Collect the writers of lhs_name reachable through dependencies,
        # not looking past the first writer found on each path.
        writer_ids = set()

        for dep_id in id_to_insn[usage_insn_id].depends_on:
            dep_insn = id_to_insn[dep_id]

            if lhs_name not in dep_insn.write_dependency_names():
                inherited = get_relevant_definition_insn_id(dep_id)
                if inherited is not None:
                    writer_ids.add(inherited)
                continue

            if lhs_name in dep_insn.read_dependency_names():
                raise LoopyError(
                        "instruction '%s' both reads *and* "
                        "writes '%s'--cannot transcribe to substitution "
                        "rule" % (dep_id, lhs_name))

            writer_ids.add(dep_id)

        if len(writer_ids) > 1:
            raise LoopyError(
                    "more than one write to '%s' found in "
                    "depdendencies of '%s'--definition cannot be resolved "
                    "(writer instructions ids: %s)"
                    % (lhs_name, usage_insn_id, ", ".join(writer_ids)))

        if not writer_ids:
            return None

        only_writer, = writer_ids
        return only_writer

    usage_to_definition = {}

    for insn in dep_kernel.instructions:
        if lhs_name in insn.read_dependency_names():
            found_def_id = get_relevant_definition_insn_id(insn.id)
            if found_def_id is None:
                raise LoopyError("no write to '%s' found in dependency tree "
                        "of '%s'--definition cannot be resolved"
                        % (lhs_name, insn.id))

            usage_to_definition[insn.id] = found_def_id

    definition_insn_ids = set(
            insn.id
            for insn in kernel.instructions
            if lhs_name in insn.write_dependency_names())

    # }}}

    if not definition_insn_ids:
        raise LoopyError("no assignments to variable '%s' found"
                % lhs_name)

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    tts = AssignmentToSubstChanger(rule_mapping_context,
            lhs_name, definition_insn_ids,
            usage_to_definition, extra_arguments, within)

    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))

    from loopy.kernel.data import SubstitutionRule, Assignment
    from pymbolic.primitives import Variable, Subscript

    # {{{ create new substitution rules

    new_substs = kernel.substitutions.copy()
    for def_id, subst_name in six.iteritems(
            tts.definition_insn_id_to_subst_name):
        def_insn = kernel.id_to_insn[def_id]
        assert isinstance(def_insn, Assignment)

        assignee = def_insn.assignee
        if isinstance(assignee, Subscript):
            indices = assignee.index_tuple
        elif isinstance(assignee, Variable):
            indices = ()
        else:
            raise LoopyError(
                    "Unrecognized LHS type: %s"
                    % type(def_insn.assignee).__name__)

        rule_args = []
        for index in indices:
            if not isinstance(index, Variable):
                raise LoopyError("In defining instruction '%s': "
                        "asignee index '%s' is not a plain variable. "
                        "Perhaps use loopy.affine_map_inames() "
                        "to perform substitution." % (def_id, index))

            rule_args.append(index.name)

        new_substs[subst_name] = SubstitutionRule(
                name=subst_name,
                arguments=tuple(rule_args) + extra_arguments,
                expression=def_insn.expression)

    # }}}

    # {{{ delete temporary variable if possible

    # (copied below if modified)
    new_temp_vars = kernel.temporary_variables
    new_args = kernel.args

    # True if all usage sites matched--they're now substitution rules, and
    # the variable is no longer needed.
    all_usages_matched = not any(
            six.itervalues(tts.saw_unmatched_usage_sites))

    if lhs_name in kernel.temporary_variables and all_usages_matched:
        new_temp_vars = new_temp_vars.copy()
        del new_temp_vars[lhs_name]

    if (lhs_name in kernel.arg_dict
            and not force_retain_argument
            and all_usages_matched):
        new_args = new_args[:]
        for position, arg in enumerate(new_args):
            if arg.name == lhs_name:
                del new_args[position]
                break

    # }}}

    import loopy as lp
    no_longer_used_ids = set(
            insn_id
            for insn_id, still_used in six.iteritems(
                tts.saw_unmatched_usage_sites)
            if not still_used)
    kernel = lp.remove_instructions(kernel, no_longer_used_ids)

    return kernel.copy(
            substitutions=new_substs,
            temporary_variables=new_temp_vars,
            args=new_args,
            )