Beispiel #1
0
def diff_kernel(knl,
                diff_outputs,
                by,
                diff_iname_prefix="diff_i",
                batch_axes_in_by=frozenset(),
                copy_outputs=set()):
    """Differentiate the variables *diff_outputs* of *knl* with respect to
    the argument named *by*.

    :arg knl: the kernel to differentiate. Default dependencies are added
        to it before differentiation.
    :arg diff_outputs: the names of the variables to differentiate, either
        as an iterable of names or as one comma-separated string (empty
        entries are ignored).
    :arg by: the name of the kernel argument to differentiate by. Its shape
        is handed to the differentiation context as *additional_shape*.
    :arg diff_iname_prefix: forwarded unchanged to the differentiation
        context.
    :arg batch_axes_in_by: a :class:`set` of axis indices in the variable named *by*
        that are not part of the differentiation.
        NOTE(review): this argument is currently not used by this function.
    :arg copy_outputs: names of variables that are passed to
        ``import_output_var`` on the differentiation context. The default
        set is shared between calls, but it is only iterated, never modified.
    :return: a tuple ``(new_kernel, result)``. *result* is a :class:`dict`
        mapping each name in *diff_outputs* to the value returned by
        ``get_diff_var`` for it (per the previous documentation: the name of
        a new variable holding the derivative, or *None* if no dependency
        exists).
    """

    from loopy.preprocess import add_default_dependencies
    knl = add_default_dependencies(knl)

    if isinstance(diff_outputs, str):
        diff_outputs = [
            dout.strip() for dout in diff_outputs.split(",") if dout.strip()
        ]

    by_arg = knl.arg_dict[by]
    additional_shape = by_arg.shape

    var_name_gen = knl.get_var_name_generator()

    # {{{ differentiate instructions

    diff_context = DifferentiationContext(knl,
                                          var_name_gen,
                                          by,
                                          diff_iname_prefix=diff_iname_prefix,
                                          additional_shape=additional_shape)

    # Record the derivative of *every* requested output. Rebinding 'result'
    # itself inside the loop would throw away all but the last entry.
    result = {}
    for dout in diff_outputs:
        result[dout] = diff_context.get_diff_var(dout)

    for cout in copy_outputs:
        diff_context.import_output_var(cout)

    # }}}

    return diff_context.get_new_kernel(), result
Beispiel #2
0
def diff_kernel(knl, diff_outputs, by, diff_iname_prefix="diff_i",
        batch_axes_in_by=frozenset(), copy_outputs=set()):
    """Differentiate the variables *diff_outputs* of *knl* with respect to
    the argument named *by*.

    :arg knl: the kernel to differentiate. Default dependencies are added
        to it before differentiation.
    :arg diff_outputs: the names of the variables to differentiate, either
        as an iterable of names or as one comma-separated string (empty
        entries are ignored).
    :arg by: the name of the kernel argument to differentiate by. Its shape
        is handed to the differentiation context as *additional_shape*.
    :arg diff_iname_prefix: forwarded unchanged to the differentiation
        context.
    :arg batch_axes_in_by: a :class:`set` of axis indices in the variable named *by*
        that are not part of the differentiation.
        NOTE(review): this argument is currently not used by this function.
    :arg copy_outputs: names of variables that are passed to
        ``import_output_var`` on the differentiation context. The default
        set is shared between calls, but it is only iterated, never modified.
    :return: a tuple ``(new_kernel, result)``. *result* is a :class:`dict`
        mapping each name in *diff_outputs* to the value returned by
        ``get_diff_var`` for it (per the previous documentation: the name of
        a new variable holding the derivative, or *None* if no dependency
        exists).
    """

    from loopy.preprocess import add_default_dependencies
    knl = add_default_dependencies(knl)

    if isinstance(diff_outputs, str):
        diff_outputs = [
                dout.strip() for dout in diff_outputs.split(",")
                if dout.strip()]

    by_arg = knl.arg_dict[by]
    additional_shape = by_arg.shape

    var_name_gen = knl.get_var_name_generator()

    # {{{ differentiate instructions

    diff_context = DifferentiationContext(
            knl, var_name_gen, by, diff_iname_prefix=diff_iname_prefix,
            additional_shape=additional_shape)

    # Record the derivative of *every* requested output. Rebinding 'result'
    # itself inside the loop would throw away all but the last entry.
    result = {}
    for dout in diff_outputs:
        result[dout] = diff_context.get_diff_var(dout)

    for cout in copy_outputs:
        diff_context.import_output_var(cout)

    # }}}

    return diff_context.get_new_kernel(), result
Beispiel #3
0
    def stringify(self, what=None, with_dependencies=False):
        """Return a multi-line, human-readable description of this kernel.

        :arg what: selects the sections to include. One of

            * *None*: all sections, with ``"Dependencies"`` included only
              if *with_dependencies* is true,
            * a :class:`set` of section names,
            * a string containing commas: split at the commas, each piece
              stripped and taken as a section name,
            * a string without commas: every *character* is looked up as
              the first letter of a section name.

            The section names are ``"name"``, ``"arguments"``,
            ``"domains"``, ``"tags"``, ``"variables"``, ``"rules"``,
            ``"instructions"``, ``"Dependencies"`` and ``"schedule"``.
        :arg with_dependencies: only consulted when *what* is *None*.
        :raises LoopyError: if *what* contains an unknown section name, or
            if an instruction of a type this method cannot print is
            encountered in the instructions section.
        :return: the description as a single newline-joined string.
        """
        all_what = set([
            "name",
            "arguments",
            "domains",
            "tags",
            "variables",
            "rules",
            "instructions",
            "Dependencies",
            "schedule",
            ])

        # Single-letter shorthand; the assert guards against two section
        # names sharing a first letter (hence the capital "D").
        first_letter_to_what = dict(
                (w[0], w) for w in all_what)
        assert len(first_letter_to_what) == len(all_what)

        if what is None:
            what = all_what.copy()
            if not with_dependencies:
                what.remove("Dependencies")

        if isinstance(what, str):
            if "," in what:
                what = what.split(",")
                what = set(s.strip() for s in what)
            else:
                what = set(
                        first_letter_to_what[w]
                        for w in what)

        if not (what <= all_what):
            raise LoopyError("invalid 'what' passed: %s"
                    % ", ".join(what-all_what))

        lines = []

        # Print from a copy that has the default dependencies added.
        from loopy.preprocess import add_default_dependencies
        kernel = add_default_dependencies(self)

        sep = 75*"-"

        if "name" in what:
            lines.append(sep)
            lines.append("KERNEL: " + kernel.name)

        if "arguments" in what:
            lines.append(sep)
            lines.append("ARGUMENTS:")
            for arg_name in sorted(kernel.arg_dict):
                lines.append(str(kernel.arg_dict[arg_name]))

        if "domains" in what:
            lines.append(sep)
            lines.append("DOMAINS:")
            # Indent each domain by two spaces per parent domain.
            for dom, parents in zip(kernel.domains, kernel.all_parents_per_domain()):
                lines.append(len(parents)*"  " + str(dom))

        if "tags" in what:
            lines.append(sep)
            lines.append("INAME IMPLEMENTATION TAGS:")
            for iname in sorted(kernel.all_inames()):
                line = "%s: %s" % (iname, kernel.iname_to_tag.get(iname))
                lines.append(line)

        if "variables" in what and kernel.temporary_variables:
            lines.append(sep)
            lines.append("TEMPORARIES:")
            for tv in sorted(six.itervalues(kernel.temporary_variables),
                    key=lambda tv: tv.name):
                lines.append(str(tv))

        if "rules" in what and kernel.substitutions:
            lines.append(sep)
            lines.append("SUBSTIUTION RULES:")
            for rule_name in sorted(six.iterkeys(kernel.substitutions)):
                lines.append(str(kernel.substitutions[rule_name]))

        if "instructions" in what:
            lines.append(sep)
            lines.append("INSTRUCTIONS:")
            loop_list_width = 35

            printed_insn_ids = set()

            Fore = self.options._fore
            Style = self.options._style

            import loopy as lp

            def print_insn(insn):
                # Emit each instruction once, after everything it depends on.
                if insn.id in printed_insn_ids:
                    return
                printed_insn_ids.add(insn.id)

                for dep_id in sorted(insn.depends_on):
                    print_insn(kernel.id_to_insn[dep_id])

                if isinstance(insn, lp.MultiAssignmentBase):
                    lhs = ", ".join(str(a) for a in insn.assignees)
                    rhs = str(insn.expression)
                    trailing = []
                elif isinstance(insn, lp.CInstruction):
                    lhs = ", ".join(str(a) for a in insn.assignees)
                    rhs = "CODE(%s|%s)" % (
                            ", ".join(str(x) for x in insn.read_variables),
                            ", ".join("%s=%s" % (name, expr)
                                for name, expr in insn.iname_exprs))

                    trailing = ["    "+l for l in insn.code.split("\n")]
                else:
                    # Without this branch, lhs/rhs/trailing would be unbound
                    # below and surface as an opaque UnboundLocalError.
                    raise LoopyError("unexpected instruction type: %s"
                            % type(insn).__name__)

                loop_list = ",".join(sorted(kernel.insn_inames(insn)))

                options = [Fore.GREEN+insn.id+Style.RESET_ALL]
                if insn.priority:
                    options.append("priority=%d" % insn.priority)
                if insn.tags:
                    options.append("tags=%s" % ":".join(insn.tags))
                if isinstance(insn, lp.Assignment) and insn.atomicity:
                    options.append("atomic=%s" % ":".join(
                        str(a) for a in insn.atomicity))
                if insn.groups:
                    options.append("groups=%s" % ":".join(insn.groups))
                if insn.conflicts_with_groups:
                    options.append(
                            "conflicts=%s" % ":".join(insn.conflicts_with_groups))
                if insn.no_sync_with:
                    options.append("no_sync_with=%s" % ":".join(insn.no_sync_with))

                if len(loop_list) > loop_list_width:
                    # Loop list too wide for its column: give it a line of
                    # its own and indent the instruction below it.
                    lines.append("[%s]" % loop_list)
                    lines.append("%s%s <- %s   # %s" % (
                        (loop_list_width+2)*" ", Fore.BLUE+lhs+Style.RESET_ALL,
                        Fore.MAGENTA+rhs+Style.RESET_ALL,
                        ", ".join(options)))
                else:
                    lines.append("[%s]%s%s <- %s   # %s" % (
                        loop_list, " "*(loop_list_width-len(loop_list)),
                        Fore.BLUE + lhs + Style.RESET_ALL,
                        Fore.MAGENTA+rhs+Style.RESET_ALL,
                        ",".join(options)))

                lines.extend(trailing)

                if insn.predicates:
                    lines.append(10*" " + "if (%s)" % " && ".join(insn.predicates))

            for insn in kernel.instructions:
                print_insn(insn)

        dep_lines = []
        for insn in kernel.instructions:
            if insn.depends_on:
                dep_lines.append("%s : %s" % (insn.id, ",".join(insn.depends_on)))

        if "Dependencies" in what and dep_lines:
            lines.append(sep)
            lines.append("DEPENDENCIES: "
                    "(use loopy.show_dependency_graph to visualize)")
            lines.extend(dep_lines)

        if "schedule" in what and kernel.schedule is not None:
            lines.append(sep)
            lines.append("SCHEDULE:")
            from loopy.schedule import dump_schedule
            lines.append(dump_schedule(kernel, kernel.schedule))

        lines.append(sep)

        return "\n".join(lines)
Beispiel #4
0
    def stringify(self, with_dependencies=False):
        """Return a multi-line, human-readable description of this kernel.

        The output lists, in order: kernel name, arguments, domains, iname
        implementation tags, temporaries (if any), substitution rules (if
        any), instructions, dependencies, and the schedule (if the kernel
        has one).

        :arg with_dependencies: if true, list the dependencies of each
            instruction below the ``DEPENDENCIES`` header. The header itself
            is emitted whenever at least one instruction has dependencies.
        :raises LoopyError: if an instruction of a type this method cannot
            print is encountered.
        :return: the description as a single newline-joined string.
        """
        lines = []

        # Print from a copy that has the default dependencies added.
        from loopy.preprocess import add_default_dependencies
        kernel = add_default_dependencies(self)

        sep = 75*"-"
        lines.append(sep)
        lines.append("KERNEL: " + kernel.name)
        lines.append(sep)
        lines.append("ARGUMENTS:")
        for arg_name in sorted(kernel.arg_dict):
            lines.append(str(kernel.arg_dict[arg_name]))
        lines.append(sep)
        lines.append("DOMAINS:")
        # Indent each domain by two spaces per parent domain.
        for dom, parents in zip(kernel.domains, kernel.all_parents_per_domain()):
            lines.append(len(parents)*"  " + str(dom))

        lines.append(sep)
        lines.append("INAME IMPLEMENTATION TAGS:")
        for iname in sorted(kernel.all_inames()):
            line = "%s: %s" % (iname, kernel.iname_to_tag.get(iname))
            lines.append(line)

        if kernel.temporary_variables:
            lines.append(sep)
            lines.append("TEMPORARIES:")
            for tv in sorted(six.itervalues(kernel.temporary_variables),
                    key=lambda tv: tv.name):
                lines.append(str(tv))

        if kernel.substitutions:
            lines.append(sep)
            lines.append("SUBSTIUTION RULES:")
            for rule_name in sorted(six.iterkeys(kernel.substitutions)):
                lines.append(str(kernel.substitutions[rule_name]))

        lines.append(sep)
        lines.append("INSTRUCTIONS:")
        loop_list_width = 35

        printed_insn_ids = set()

        import loopy as lp

        def print_insn(insn):
            # Emit each instruction once, after everything it depends on.
            if insn.id in printed_insn_ids:
                return
            printed_insn_ids.add(insn.id)

            for dep_id in sorted(insn.insn_deps):
                print_insn(kernel.id_to_insn[dep_id])

            if isinstance(insn, lp.Assignment):
                lhs = str(insn.assignee)
                rhs = str(insn.expression)
                trailing = []
            elif isinstance(insn, lp.CInstruction):
                lhs = ", ".join(str(a) for a in insn.assignees)
                rhs = "CODE(%s|%s)" % (
                        ", ".join(str(x) for x in insn.read_variables),
                        ", ".join("%s=%s" % (name, expr)
                            for name, expr in insn.iname_exprs))

                trailing = ["    "+l for l in insn.code.split("\n")]
            else:
                # Without this branch, lhs/rhs/trailing would be unbound
                # below and surface as an opaque UnboundLocalError.
                from loopy.diagnostic import LoopyError
                raise LoopyError("unexpected instruction type: %s"
                        % type(insn).__name__)

            loop_list = ",".join(sorted(kernel.insn_inames(insn)))

            options = [insn.id]
            if insn.priority:
                options.append("priority=%d" % insn.priority)
            if insn.tags:
                options.append("tags=%s" % ":".join(insn.tags))
            if insn.groups:
                options.append("groups=%s" % ":".join(insn.groups))
            if insn.conflicts_with_groups:
                options.append("conflicts=%s" % ":".join(insn.conflicts_with_groups))

            if len(loop_list) > loop_list_width:
                # Loop list too wide for its column: give it a line of its
                # own and indent the instruction below it.
                lines.append("[%s]" % loop_list)
                lines.append("%s%s <- %s   # %s" % (
                    (loop_list_width+2)*" ", lhs,
                    rhs, ", ".join(options)))
            else:
                lines.append("[%s]%s%s <- %s   # %s" % (
                    loop_list, " "*(loop_list_width-len(loop_list)),
                    lhs, rhs, ",".join(options)))

            lines.extend(trailing)

            if insn.predicates:
                lines.append(10*" " + "if (%s)" % " && ".join(insn.predicates))

        for insn in kernel.instructions:
            print_insn(insn)

        dep_lines = []
        for insn in kernel.instructions:
            if insn.insn_deps:
                dep_lines.append("%s : %s" % (insn.id, ",".join(insn.insn_deps)))
        if dep_lines:
            lines.append(sep)
            # The header (with its pointer to the graph visualizer) is shown
            # even when the individual dependency lines are suppressed.
            lines.append("DEPENDENCIES: "
                    "(use loopy.show_dependency_graph to visualize)")
            if with_dependencies:
                lines.extend(dep_lines)

        lines.append(sep)

        if kernel.schedule is not None:
            lines.append("SCHEDULE:")
            from loopy.schedule import dump_schedule
            lines.append(dump_schedule(kernel, kernel.schedule))
            lines.append(sep)

        return "\n".join(lines)
Beispiel #5
0
def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None):
    """Extract an assignment to a temporary variable
    as a :ref:`substitution-rule`. The temporary may be an array, in
    which case the array indices will become arguments to the substitution
    rule.

    :arg kernel: the kernel to transform.
    :arg temp_name: the name of the temporary variable whose assignment is
        to be turned into a substitution rule.
    :arg extra_arguments: additional argument names for the new rule, as a
        :class:`tuple` of names or as one comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.
    :raises LoopyError: if no assignment to *temp_name* exists, if the
        definition relevant for a usage site cannot be determined uniquely,
        if a defining instruction also reads *temp_name*, or if an assignee
        index of a defining instruction is not a plain variable.
    :return: the transformed kernel.

    This operation will change all usage sites
    of *temp_name* matched by *within*. If there
    are further usage sites of *temp_name*, then
    the original assignment to *temp_name* as well
    as the temporary variable is left in place.
    """

    if isinstance(extra_arguments, str):
        extra_arguments = tuple(s.strip() for s in extra_arguments.split(","))

    # {{{ establish the relevant definition of temp_name for each usage site

    # Dependencies are examined on a copy with substitution rules expanded
    # and default dependencies added.
    dep_kernel = expand_subst(kernel)
    from loopy.preprocess import add_default_dependencies
    dep_kernel = add_default_dependencies(dep_kernel)

    id_to_insn = dep_kernel.id_to_insn

    def get_relevant_definition_insn_id(usage_insn_id):
        # Walk the dependencies of usage_insn_id and return the id of the
        # single instruction writing temp_name, or None if there is none.
        insn = id_to_insn[usage_insn_id]

        def_id = set()
        for dep_id in insn.insn_deps:
            dep_insn = id_to_insn[dep_id]
            if temp_name in dep_insn.write_dependency_names():
                if temp_name in dep_insn.read_dependency_names():
                    raise LoopyError("instruction '%s' both reads *and* "
                            "writes '%s'--cannot transcribe to substitution "
                            "rule" % (dep_id, temp_name))

                def_id.add(dep_id)
            else:
                # Not a writer itself: look further up the dependency tree.
                rec_result = get_relevant_definition_insn_id(dep_id)
                if rec_result is not None:
                    def_id.add(rec_result)

        if len(def_id) > 1:
            raise LoopyError("more than one write to '%s' found in "
                    "depdendencies of '%s'--definition cannot be resolved "
                    "(writer instructions ids: %s)"
                    % (temp_name, usage_insn_id, ", ".join(def_id)))

        if not def_id:
            return None
        else:
            def_id, = def_id

        return def_id

    usage_to_definition = {}

    # Scan the *expanded* kernel so that uses of temp_name inside
    # substitution rules are attributed to the instructions invoking them,
    # consistent with the kernel the definitions are resolved on.
    for insn in dep_kernel.instructions:
        if temp_name not in insn.read_dependency_names():
            continue

        def_id = get_relevant_definition_insn_id(insn.id)
        if def_id is None:
            raise LoopyError("no write to '%s' found in dependency tree "
                    "of '%s'--definition cannot be resolved"
                    % (temp_name, insn.id))

        usage_to_definition[insn.id] = def_id

    definition_insn_ids = set()
    for insn in kernel.instructions:
        if temp_name in insn.write_dependency_names():
            definition_insn_ids.add(insn.id)

    # }}}

    # Fail early and clearly instead of running into a KeyError (or silently
    # dropping the variable) in the clean-up step further down.
    if not definition_insn_ids:
        raise LoopyError("no assignments to variable '%s' found" % temp_name)

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    tts = TemporaryToSubstChanger(rule_mapping_context,
            temp_name, definition_insn_ids,
            usage_to_definition, extra_arguments, within)

    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))

    from loopy.kernel.data import SubstitutionRule

    # {{{ create new substitution rules

    new_substs = kernel.substitutions.copy()
    for def_id, subst_name in six.iteritems(tts.definition_insn_id_to_subst_name):
        def_insn = kernel.id_to_insn[def_id]

        (_, indices), = def_insn.assignees_and_indices()

        arguments = []

        # The assignee's indices become the rule's formal arguments, so
        # each of them has to be a bare variable name.
        from pymbolic.primitives import Variable
        for i in indices:
            if not isinstance(i, Variable):
                raise LoopyError("In defining instruction '%s': "
                        "asignee index '%s' is not a plain variable. "
                        "Perhaps use loopy.affine_map_inames() "
                        "to perform substitution." % (def_id, i))

            arguments.append(i.name)

        new_substs[subst_name] = SubstitutionRule(
                name=subst_name,
                arguments=tuple(arguments) + extra_arguments,
                expression=def_insn.expression)

    # }}}

    # {{{ delete temporary variable if possible

    new_temp_vars = kernel.temporary_variables
    if not any(six.itervalues(tts.saw_unmatched_usage_sites)):
        # All usage sites matched--they're now substitution rules.
        # We can get rid of the variable.

        new_temp_vars = new_temp_vars.copy()
        del new_temp_vars[temp_name]

    # }}}

    # Drop the defining instructions that no longer have unmatched users.
    import loopy as lp
    kernel = lp.remove_instructions(
            kernel,
            set(
                insn_id
                for insn_id, still_used in six.iteritems(
                    tts.saw_unmatched_usage_sites)
                if not still_used))

    return kernel.copy(
            substitutions=new_substs,
            temporary_variables=new_temp_vars,
            )
Beispiel #6
0
def get_dot_dependency_graph(kernel, iname_cluster=True, use_insn_id=False):
    """Build a `dot <http://graphviz.org/>`_ description of the dependencies
    among the instructions of *kernel*.

    :arg iname_cluster: if true, group instructions into one subgraph
        cluster per loop of the kernel's schedule. A kernel without a
        schedule is scheduled first; if that raises :class:`RuntimeError`,
        a warning is issued and clustering is skipped.
    :arg use_insn_id: if true, nodes are labelled with the instruction id
        and carry the operation text as tooltip; otherwise it is the other
        way around.
    :return: the graph as a string in the dot language.
    """

    # make sure all automatically added stuff shows up
    from loopy.preprocess import add_default_dependencies
    kernel = add_default_dependencies(kernel)

    if iname_cluster and not kernel.schedule:
        try:
            from loopy.schedule import get_one_scheduled_kernel
            kernel = get_one_scheduled_kernel(kernel)
        except RuntimeError as sched_err:
            from warnings import warn
            iname_cluster = False
            warn("error encountered during scheduling for dep graph -- "
                    "cannot perform iname clustering: %s(%s)"
                    % (type(sched_err).__name__, sched_err))

    from loopy.kernel.data import MultiAssignmentBase, CInstruction

    # Maps an instruction id to the set of ids it depends on. Only
    # instructions that have dependencies get an entry.
    dep_graph = {}
    lines = []

    for insn in kernel.instructions:
        if isinstance(insn, MultiAssignmentBase):
            op = "%s <- %s" % (insn.assignees, insn.expression)
            # keep node text at a manageable length
            if len(op) > 200:
                op = op[:200] + "..."
        elif isinstance(insn, CInstruction):
            op = "<C instruction %s>" % insn.id
        else:
            op = "<instruction %s>" % insn.id

        if use_insn_id:
            insn_label, tooltip = insn.id, op
        else:
            insn_label, tooltip = op, insn.id

        # repr() minus its first and last character serves as the escaped
        # form of the text inside the dot string.
        node_line = '"%s" [label="%s",shape="box",tooltip="%s"];' % (
                insn.id,
                repr(insn_label)[1:-1],
                repr(tooltip)[1:-1],
                )
        lines.append(node_line)

        for dep in insn.depends_on:
            dep_graph.setdefault(insn.id, set()).add(dep)

    # {{{ O(n^3) transitive reduction

    # first, compute transitive closure by fixed point iteration
    keep_going = True
    while keep_going:
        keep_going = False

        for src in dep_graph:
            src_deps = dep_graph[src]
            for mid in src_deps.copy():
                for far in dep_graph.get(mid, set()).copy():
                    if far not in src_deps:
                        src_deps.add(far)
                        keep_going = True

    # then drop every edge that is implied by a two-step path
    for src in dep_graph:
        src_deps = dep_graph[src]
        for mid in src_deps.copy():
            for far in dep_graph.get(mid, set()).copy():
                if far in src_deps:
                    src_deps.remove(far)

    # }}}

    # edges point from the dependency to the dependent instruction
    for insn_id in dep_graph:
        for dep_id in dep_graph[insn_id]:
            lines.append("%s -> %s" % (dep_id, insn_id))

    if iname_cluster:
        from loopy.schedule import EnterLoop, LeaveLoop, RunInstruction, Barrier

        for sched_item in kernel.schedule:
            if isinstance(sched_item, EnterLoop):
                lines.append('subgraph cluster_%s { label="%s"'
                        % (sched_item.iname, sched_item.iname))
            elif isinstance(sched_item, LeaveLoop):
                lines.append("}")
            elif isinstance(sched_item, RunInstruction):
                lines.append(sched_item.insn_id)
            elif isinstance(sched_item, Barrier):
                # barriers do not show up in the graph
                continue
            else:
                raise LoopyError("schedule item not unterstood: %r" % sched_item)

    body = "\n".join(lines)
    return "digraph %s {\n%s\n}" % (kernel.name, body)
Beispiel #7
0
    def stringify(self, what=None, with_dependencies=False):
        """Return a multi-line, human-readable description of this kernel.

        :arg what: selects the sections to include. One of

            * *None*: all sections, with ``"Dependencies"`` included only
              if *with_dependencies* is true,
            * a :class:`set` of section names,
            * a string containing commas: split at the commas, each piece
              stripped and taken as a section name,
            * a string without commas: every *character* is looked up as
              the first letter of a section name.

            The section names are ``"name"``, ``"arguments"``,
            ``"domains"``, ``"tags"``, ``"variables"``, ``"rules"``,
            ``"instructions"``, ``"Dependencies"`` and ``"schedule"``.
        :arg with_dependencies: only consulted when *what* is *None*.
        :raises LoopyError: if *what* contains an unknown section name or
            an instruction of an unexpected type is encountered.
        :return: the description as a single newline-joined string.
        """
        all_what = {
            "name", "arguments", "domains", "tags", "variables",
            "rules", "instructions", "Dependencies", "schedule",
        }

        # Single-letter shorthand; the assert guards against two section
        # names sharing a first letter (hence the capital "D").
        first_letter_to_what = {w[0]: w for w in all_what}
        assert len(first_letter_to_what) == len(all_what)

        if what is None:
            what = set(all_what)
            if not with_dependencies:
                what.remove("Dependencies")
        elif isinstance(what, str):
            if "," in what:
                what = set(piece.strip() for piece in what.split(","))
            else:
                what = set(first_letter_to_what[letter] for letter in what)

        if not (what <= all_what):
            raise LoopyError("invalid 'what' passed: %s" %
                             ", ".join(what - all_what))

        lines = []

        # Print from a copy that has the default dependencies added.
        from loopy.preprocess import add_default_dependencies
        kernel = add_default_dependencies(self)

        sep = 75 * "-"

        def begin_section(title):
            # every section opens with a separator line and its heading
            lines.append(sep)
            lines.append(title)

        if "name" in what:
            begin_section("KERNEL: " + kernel.name)

        if "arguments" in what:
            begin_section("ARGUMENTS:")
            lines.extend(
                str(kernel.arg_dict[arg_name])
                for arg_name in sorted(kernel.arg_dict))

        if "domains" in what:
            begin_section("DOMAINS:")
            # indent each domain by two spaces per parent domain
            lines.extend(
                len(parents) * "  " + str(dom)
                for dom, parents in zip(kernel.domains,
                                        kernel.all_parents_per_domain()))

        if "tags" in what:
            begin_section("INAME IMPLEMENTATION TAGS:")
            lines.extend(
                "%s: %s" % (iname, kernel.iname_to_tag.get(iname))
                for iname in sorted(kernel.all_inames()))

        if "variables" in what and kernel.temporary_variables:
            begin_section("TEMPORARIES:")
            temporaries = sorted(six.itervalues(kernel.temporary_variables),
                                 key=lambda tv: tv.name)
            lines.extend(str(tv) for tv in temporaries)

        if "rules" in what and kernel.substitutions:
            begin_section("SUBSTIUTION RULES:")
            lines.extend(
                str(kernel.substitutions[rule_name])
                for rule_name in sorted(six.iterkeys(kernel.substitutions)))

        if "instructions" in what:
            begin_section("INSTRUCTIONS:")
            loop_list_width = 35

            Fore = self.options._fore
            Style = self.options._style

            import loopy as lp

            already_printed = set()

            def print_insn(insn):
                # Emit each instruction once, after everything it depends on.
                if insn.id in already_printed:
                    return
                already_printed.add(insn.id)

                for dep_id in sorted(insn.depends_on):
                    print_insn(kernel.id_to_insn[dep_id])

                # lhs/rhs make up the instruction text; 'trailing' holds
                # extra lines printed below it.
                trailing = []
                if isinstance(insn, lp.MultiAssignmentBase):
                    lhs = ", ".join(str(a) for a in insn.assignees)
                    rhs = str(insn.expression)
                elif isinstance(insn, lp.CInstruction):
                    lhs = ", ".join(str(a) for a in insn.assignees)
                    read_vars = ", ".join(
                        str(x) for x in insn.read_variables)
                    iname_defs = ", ".join(
                        "%s=%s" % (name, expr)
                        for name, expr in insn.iname_exprs)
                    rhs = "CODE(%s|%s)" % (read_vars, iname_defs)

                    trailing = [
                        "    " + code_line
                        for code_line in insn.code.split("\n")]
                elif isinstance(insn, lp.BarrierInstruction):
                    lhs = ""
                    rhs = "... %sbarrier" % insn.kind[0]
                elif isinstance(insn, lp.NoOpInstruction):
                    lhs = ""
                    rhs = "... nop"
                else:
                    raise LoopyError("unexpected instruction type: %s" %
                                     type(insn).__name__)

                loop_list = ",".join(sorted(kernel.insn_inames(insn)))

                options = [Fore.GREEN + insn.id + Style.RESET_ALL]

                def add_joined(fmt, values):
                    # append a colon-joined option if it has any content
                    if values:
                        options.append(fmt % ":".join(values))

                if insn.priority:
                    options.append("priority=%d" % insn.priority)
                add_joined("tags=%s", insn.tags)
                if isinstance(insn, lp.Assignment) and insn.atomicity:
                    options.append("atomic=%s" %
                                   ":".join(str(a) for a in insn.atomicity))
                add_joined("groups=%s", insn.groups)
                add_joined("conflicts=%s", insn.conflicts_with_groups)
                add_joined("no_sync_with=%s", insn.no_sync_with)

                colored_rhs = Fore.MAGENTA + rhs + Style.RESET_ALL
                if lhs:
                    core = "%s <- %s" % (
                        Fore.BLUE + lhs + Style.RESET_ALL, colored_rhs)
                else:
                    core = colored_rhs

                if len(loop_list) > loop_list_width:
                    # Loop list too wide for its column: give it a line of
                    # its own and indent the instruction below it.
                    lines.append("[%s]" % loop_list)
                    lines.append("%s%s   # %s" % (
                        (loop_list_width + 2) * " ", core,
                        ", ".join(options)))
                else:
                    padding = " " * (loop_list_width - len(loop_list))
                    lines.append("[%s]%s%s   # %s" % (
                        loop_list, padding, core, ",".join(options)))

                lines.extend(trailing)

                if insn.predicates:
                    lines.append(10 * " " +
                                 "if (%s)" % " && ".join(insn.predicates))

            for insn in kernel.instructions:
                print_insn(insn)

        dep_lines = [
            "%s : %s" % (insn.id, ",".join(insn.depends_on))
            for insn in kernel.instructions
            if insn.depends_on]

        if "Dependencies" in what and dep_lines:
            begin_section("DEPENDENCIES: "
                          "(use loopy.show_dependency_graph to visualize)")
            lines.extend(dep_lines)

        if "schedule" in what and kernel.schedule is not None:
            begin_section("SCHEDULE:")
            from loopy.schedule import dump_schedule
            lines.append(dump_schedule(kernel, kernel.schedule))

        lines.append(sep)

        return "\n".join(lines)
Beispiel #8
0
def assignment_to_subst(kernel,
                        lhs_name,
                        extra_arguments=(),
                        within=None,
                        force_retain_argument=False):
    """Turn an assignment (to a temporary variable or an argument) into a
    :ref:`substitution-rule`. If the assignee is an array, its indices
    become the arguments of the substitution rule.

    :arg lhs_name: the name of the variable whose assignment is extracted.
    :arg extra_arguments: additional argument names for the new rule, as a
        :class:`tuple` of names or as one comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.match.parse_stack_match`.
    :arg force_retain_argument: If True and if *lhs_name* is an argument, it is
        kept even if it is no longer referenced.
    :raises LoopyError: if no assignment to *lhs_name* exists, if the
        definition relevant for a usage site cannot be determined uniquely,
        if a defining instruction also reads *lhs_name*, or if the assignee
        of a defining instruction has an unsupported form.
    :return: the transformed kernel.

    All usage sites of *lhs_name* matched by *within* are changed. If
    other usage sites of *lhs_name* remain, the original assignment to
    *lhs_name* and the variable itself are left in place.
    """

    if isinstance(extra_arguments, str):
        extra_arguments = tuple(
            name.strip() for name in extra_arguments.split(","))

    # {{{ establish the relevant definition of lhs_name for each usage site

    # Dependencies are examined on a copy with substitution rules expanded
    # and default dependencies added.
    from loopy.preprocess import add_default_dependencies
    dep_kernel = add_default_dependencies(expand_subst(kernel))

    id_to_insn = dep_kernel.id_to_insn

    def get_relevant_definition_insn_id(usage_insn_id):
        # Walk the dependencies of usage_insn_id and return the id of the
        # single instruction writing lhs_name, or None if there is none.
        usage_insn = id_to_insn[usage_insn_id]

        writer_ids = set()
        for dep_id in usage_insn.depends_on:
            dep_insn = id_to_insn[dep_id]

            if lhs_name not in dep_insn.write_dependency_names():
                # Not a writer itself: look further up the dependency tree.
                inherited = get_relevant_definition_insn_id(dep_id)
                if inherited is not None:
                    writer_ids.add(inherited)
                continue

            if lhs_name in dep_insn.read_dependency_names():
                raise LoopyError(
                    "instruction '%s' both reads *and* "
                    "writes '%s'--cannot transcribe to substitution "
                    "rule" % (dep_id, lhs_name))

            writer_ids.add(dep_id)

        if len(writer_ids) > 1:
            raise LoopyError(
                "more than one write to '%s' found in "
                "depdendencies of '%s'--definition cannot be resolved "
                "(writer instructions ids: %s)" %
                (lhs_name, usage_insn_id, ", ".join(writer_ids)))

        if not writer_ids:
            return None

        (only_writer,) = writer_ids
        return only_writer

    usage_to_definition = {}

    for insn in dep_kernel.instructions:
        if lhs_name in insn.read_dependency_names():
            def_id = get_relevant_definition_insn_id(insn.id)
            if def_id is None:
                raise LoopyError("no write to '%s' found in dependency tree "
                                 "of '%s'--definition cannot be resolved" %
                                 (lhs_name, insn.id))

            usage_to_definition[insn.id] = def_id

    definition_insn_ids = set(
        insn.id
        for insn in kernel.instructions
        if lhs_name in insn.write_dependency_names())

    # }}}

    if not definition_insn_ids:
        raise LoopyError("no assignments to variable '%s' found" % lhs_name)

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())
    tts = AssignmentToSubstChanger(rule_mapping_context, lhs_name,
                                   definition_insn_ids, usage_to_definition,
                                   extra_arguments, within)

    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))

    from loopy.kernel.data import Assignment, SubstitutionRule
    from pymbolic.primitives import Variable, Subscript

    # {{{ create new substitution rules

    new_substs = kernel.substitutions.copy()
    for def_id, subst_name in six.iteritems(
            tts.definition_insn_id_to_subst_name):
        def_insn = kernel.id_to_insn[def_id]
        assert isinstance(def_insn, Assignment)

        assignee = def_insn.assignee
        if isinstance(assignee, Subscript):
            indices = assignee.index_tuple
        elif isinstance(assignee, Variable):
            indices = ()
        else:
            raise LoopyError("Unrecognized LHS type: %s" %
                             type(assignee).__name__)

        # The assignee's indices become the rule's formal arguments, so
        # each of them has to be a bare variable name.
        rule_arg_names = []
        for index in indices:
            if not isinstance(index, Variable):
                raise LoopyError("In defining instruction '%s': "
                                 "asignee index '%s' is not a plain variable. "
                                 "Perhaps use loopy.affine_map_inames() "
                                 "to perform substitution." % (def_id, index))

            rule_arg_names.append(index.name)

        new_substs[subst_name] = SubstitutionRule(
            name=subst_name,
            arguments=tuple(rule_arg_names) + extra_arguments,
            expression=def_insn.expression)

    # }}}

    # {{{ delete temporary variable if possible

    # (copied below if modified)
    new_temp_vars = kernel.temporary_variables
    new_args = kernel.args

    # True once no usage site was left unmatched, i.e. every use has been
    # replaced by a substitution rule and the variable can go away.
    fully_replaced = not any(six.itervalues(tts.saw_unmatched_usage_sites))

    if lhs_name in kernel.temporary_variables and fully_replaced:
        new_temp_vars = new_temp_vars.copy()
        del new_temp_vars[lhs_name]

    if (lhs_name in kernel.arg_dict
            and not force_retain_argument
            and fully_replaced):
        new_args = new_args[:]
        for position, arg in enumerate(new_args):
            if arg.name == lhs_name:
                del new_args[position]
                break

    # }}}

    # Drop the defining instructions that no longer have unmatched users.
    import loopy as lp
    obsolete_insn_ids = set(
        insn_id
        for insn_id, still_used in six.iteritems(
            tts.saw_unmatched_usage_sites)
        if not still_used)
    kernel = lp.remove_instructions(kernel, obsolete_insn_ids)

    return kernel.copy(
        substitutions=new_substs,
        temporary_variables=new_temp_vars,
        args=new_args,
    )