Example #1
0
    def map_reduction(expr, rec, nresults=1):
        """Move the body of the targeted reduction into a substitution rule.

        A reduction whose iname set differs from *inames_set* is rebuilt
        as-is, except that its body is passed through *rec*. For the
        targeted reduction, the body is registered in *substs* as a new
        :class:`loopy.kernel.data.SubstitutionRule` whose arguments are
        *inames*, and the reduction is rebuilt so that its body invokes
        that rule with the inames as arguments.

        *nresults* is accepted for signature compatibility but not used.

        :raises LoopyError: if the chosen rule name is already in *substs*.
        """
        reduction_cls = type(expr)

        if frozenset(expr.inames) != inames_set:
            # Not the reduction being targeted: only recurse into its body.
            return reduction_cls(
                    operation=expr.operation,
                    inames=expr.inames,
                    expr=rec(expr.expr),
                    allow_simultaneous=expr.allow_simultaneous)

        if subst_rule_name is not None:
            rule_name = subst_rule_name
        else:
            # Derive a fresh name from the inames via the name generator.
            rule_name = var_name_gen("red_%s_arg" % "_".join(inames))

        if rule_name in substs:
            raise LoopyError("substitution rule '%s' already exists"
                    % rule_name)

        from loopy.kernel.data import SubstitutionRule
        substs[rule_name] = SubstitutionRule(
                name=rule_name,
                arguments=tuple(inames),
                expression=expr.expr)

        from pymbolic import var
        rule_call = var(rule_name)(*[var(iname) for iname in inames])

        return reduction_cls(
                operation=expr.operation,
                inames=expr.inames,
                expr=rule_call,
                allow_simultaneous=expr.allow_simultaneous)
Example #2
0
    def _get_new_substitutions_and_renames(self):
        """Build the substitution rules that survive expression mapping.

        Every entry of *self.subst_rule_registry* is a mapped rule version
        ``(name, args, body)``. Entries for which
        *self.subst_rule_old_names* records no original names were never
        referenced and are dropped. Each remaining rule is named after
        ``min()`` of its recorded original names. If that name is already
        taken, it is passed through *self.make_unique_var_name* instead.

        Picking the final names needs the full set of rules, so it cannot
        happen during mapping. The method therefore also returns
        *subst_renames*, which maps each registry name to its final name
        for use on the processed expressions. The rule bodies returned here
        already have those renames applied.

        :returns: (new_substitutions, subst_renames)
        """

        from loopy.kernel.data import SubstitutionRule

        new_rules = {}
        renames = {}
        taken_names = set()

        registry_items = six.iteritems(self.subst_rule_registry)
        for key, (name, args, body) in registry_items:
            candidate_names = self.subst_rule_old_names.get(key, [])
            if not candidate_names:
                # Never referenced by any mapped expression: safe to omit.
                continue

            target_name = min(candidate_names)
            if target_name in taken_names:
                target_name = self.make_unique_var_name(target_name)

            renames[name] = target_name
            taken_names.add(target_name)

            new_rules[target_name] = SubstitutionRule(
                    name=target_name, arguments=args, expression=body)

        # Rule bodies may refer to other rules, so rename inside them too.
        renamer = SubstitutionRuleRenamer(renames)
        renamed_rules = dict(
                (rule_name, rule.copy(expression=renamer(rule.expression)))
                for rule_name, rule in six.iteritems(new_rules))

        return renamed_rules, renames
Example #3
0
    def _get_new_substitutions_and_renames(self):
        """Build the substitution rules that survive expression mapping.

        *self.subst_rule_registry* maps ``(args, body)`` keys to
        ``(name, orig_name)``. Only keys with a truthy entry in
        *self.subst_rule_use_count* are kept. A kept rule takes back its
        original name when it is the only kept rule derived from that
        original name. Otherwise it keeps its registry name.

        Choosing names this way needs every rule at once, so it cannot
        happen during mapping. The method therefore also returns
        *subst_renames*, mapping old rule names to new ones for the
        processed expressions. The returned rule bodies already have those
        renames applied.

        :returns: (new_substitutions, subst_renames)
        """

        from loopy.kernel.data import SubstitutionRule

        use_count = self.subst_rule_use_count
        registered = list(six.iteritems(self.subst_rule_registry))

        # Count how many *used* rule versions derive from each original name.
        orig_name_histogram = {}
        for key, (_, orig_name) in registered:
            if use_count.get(key, 0):
                orig_name_histogram[orig_name] = \
                        orig_name_histogram.get(orig_name, 0) + 1

        result, renames = {}, {}

        for key, (name, orig_name) in registered:
            args, body = key

            if not use_count.get(key, 0):
                continue

            # Reclaim the original name only if no other used version
            # competes for it.
            reclaim = orig_name_histogram[orig_name] == 1
            if reclaim and name != orig_name:
                renames[name] = orig_name
                name = orig_name

            result[name] = SubstitutionRule(
                    name=name, arguments=args, expression=body)

        renamer = SubstitutionRuleRenamer(renames)
        renamed_result = dict(
                (rule_name, rule.copy(expression=renamer(rule.expression)))
                for rule_name, rule in six.iteritems(result))

        return renamed_result, renames
Example #4
0
def extract_subst(kernel, subst_name, template, parameters=()):
    """Factor subexpressions matching *template* out into a substitution rule.

    :arg subst_name: The name of the substitution rule to be created. It must
        not already name a substitution rule of *kernel*.
    :arg template: Unification template expression, or a string that is
        parsed with :func:`pymbolic.parse`.
    :arg parameters: An iterable of parameters used in
        *template*, or a comma-separated string of the same.
    :returns: a copy of *kernel* in which every matched subexpression (in
        instructions and in existing substitution rules) is replaced by an
        invocation of the new rule *subst_name*.
    :raises RuntimeError: if *subst_name* already names a substitution rule,
        if a subexpression unifies with *template* in more than one way, or
        if no subexpression matches *template*.

    All targeted subexpressions must match ('unify with') *template*.
    The template may contain '*' wildcards that will have to match exactly
    across all unifications.
    """
    from pymbolic import var

    # The new rule and an existing rule of the same name would overwrite
    # one another in the result, silently changing what call sites mean.
    if subst_name in kernel.substitutions:
        raise RuntimeError(
            "substitution rule '%s' already exists" % subst_name)

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        parameters = tuple(s.strip() for s in parameters.split(","))

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        return var_name_gen(subst_name + "_wc")

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(lhs_mapping_candidates=set(parameters))

    # Instruction currently being walked, recorded in each descriptor. It is
    # None while walking substitution rules, which belong to no instruction.
    # Binding it up front also avoids an unbound closure variable when the
    # kernel has no instructions.
    current_insn = None

    def gather_exprs(expr, mapper):
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError(
                    "ambiguous unification of '%s' with template '%s'" %
                    (expr, template))

            urec, = urecs

            expr_descriptors.append(
                ExprDescriptor(insn=current_insn,
                               expr=expr,
                               unif_var_dict=dict(
                                   (lhs.name, rhs)
                                   for lhs, rhs in urec.equations)))
        else:
            # Only non-matching expressions are recursed into: matches
            # can't nest.
            mapper.fallback_mapper(expr)

    from loopy.symbolic import (CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    for current_insn in kernel.instructions:
        dfmapper(current_insn.assignees)
        dfmapper(current_insn.expression)

    current_insn = None
    for sr in six.itervalues(kernel.substitutions):
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    def replace_exprs(expr, mapper):
        # Matches are identified by object identity with the gathered
        # subexpressions, not by structural equality.
        for exprd in expr_descriptors:
            if expr is exprd.expr:
                break
        else:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name] for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        # The match is replaced wholesale; don't recurse into it.
        return result

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    new_insns = [insn.with_transformed_expressions(cbmapper)
                 for insn in kernel.instructions]

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
        subst_name:
        SubstitutionRule(
            name=subst_name,
            arguments=tuple(parameters),
            expression=template,
        )
    }

    for subst in six.itervalues(kernel.substitutions):
        new_substs[subst.name] = subst.copy(
            expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(instructions=new_insns, substitutions=new_substs)
Example #5
0
def assignment_to_subst(kernel,
                        lhs_name,
                        extra_arguments=(),
                        within=None,
                        force_retain_argument=False):
    """Extract an assignment (to a temporary variable or an argument)
    as a :ref:`substitution-rule`. The temporary may be an array, in
    which case the array indices will become arguments to the substitution
    rule.

    :arg lhs_name: name of the variable whose assignments are extracted.
    :arg extra_arguments: additional argument names appended after the
        index-derived arguments of each new rule. May be any iterable of
        names, or a comma-separated string of the same.
    :arg within: a stack match as understood by
        :func:`loopy.match.parse_stack_match`.
    :arg force_retain_argument: If True and if *lhs_name* is an argument, it is
        kept even if it is no longer referenced.
    :raises LoopyError: if a usage site's defining write cannot be uniquely
        resolved through the dependency graph, if a defining instruction
        also reads *lhs_name*, if *lhs_name* is never assigned, or if a
        defining assignee is not a plain variable or a subscript by plain
        variables.

    This operation will change all usage sites
    of *lhs_name* matched by *within*. If there
    are further usage sites of *lhs_name*, then
    the original assignment to *lhs_name* as well
    as the temporary variable is left in place.
    """

    if isinstance(extra_arguments, str):
        extra_arguments = tuple(s.strip() for s in extra_arguments.split(","))
    else:
        # Normalize so "tuple(arguments) + extra_arguments" below also works
        # when a list (or any other iterable) is passed.
        extra_arguments = tuple(extra_arguments)

    # {{{ establish the relevant definition of lhs_name for each usage site

    dep_kernel = expand_subst(kernel)
    from loopy.kernel.creation import apply_single_writer_depencency_heuristic
    dep_kernel = apply_single_writer_depencency_heuristic(dep_kernel)

    id_to_insn = dep_kernel.id_to_insn

    # insn id -> id of the defining write (or None). The answer depends only
    # on the fixed dependency graph of dep_kernel. Caching it keeps shared
    # dependency subgraphs from being re-walked, which is otherwise
    # exponential on diamond-shaped graphs.
    definition_cache = {}

    def get_relevant_definition_insn_id(usage_insn_id):
        if usage_insn_id in definition_cache:
            return definition_cache[usage_insn_id]

        insn = id_to_insn[usage_insn_id]

        def_ids = set()
        for dep_id in insn.depends_on:
            dep_insn = id_to_insn[dep_id]
            if lhs_name in dep_insn.write_dependency_names():
                if lhs_name in dep_insn.read_dependency_names():
                    raise LoopyError(
                        "instruction '%s' both reads *and* "
                        "writes '%s'--cannot transcribe to substitution "
                        "rule" % (dep_id, lhs_name))

                def_ids.add(dep_id)
            else:
                # Not a writer: look further up the dependency chain.
                rec_result = get_relevant_definition_insn_id(dep_id)
                if rec_result is not None:
                    def_ids.add(rec_result)

        if len(def_ids) > 1:
            raise LoopyError(
                "more than one write to '%s' found in "
                "dependencies of '%s'--definition cannot be resolved "
                "(writer instructions ids: %s)" %
                (lhs_name, usage_insn_id, ", ".join(sorted(def_ids))))

        if def_ids:
            (result,) = def_ids
        else:
            result = None

        definition_cache[usage_insn_id] = result
        return result

    usage_to_definition = {}

    for insn in dep_kernel.instructions:
        if lhs_name not in insn.read_dependency_names():
            continue

        def_id = get_relevant_definition_insn_id(insn.id)
        if def_id is None:
            raise LoopyError("no write to '%s' found in dependency tree "
                             "of '%s'--definition cannot be resolved" %
                             (lhs_name, insn.id))

        usage_to_definition[insn.id] = def_id

    definition_insn_ids = set(
        insn.id for insn in kernel.instructions
        if lhs_name in insn.write_dependency_names())

    # }}}

    if not definition_insn_ids:
        raise LoopyError("no assignments to variable '%s' found" % lhs_name)

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())
    tts = AssignmentToSubstChanger(rule_mapping_context, lhs_name,
                                   definition_insn_ids, usage_to_definition,
                                   extra_arguments, within)

    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))

    from loopy.kernel.data import SubstitutionRule, Assignment
    from pymbolic.primitives import Variable, Subscript

    # {{{ create new substitution rules

    new_substs = kernel.substitutions.copy()
    for def_id, subst_name in six.iteritems(
            tts.definition_insn_id_to_subst_name):
        def_insn = kernel.id_to_insn[def_id]

        assert isinstance(def_insn, Assignment)

        if isinstance(def_insn.assignee, Subscript):
            indices = def_insn.assignee.index_tuple
        elif isinstance(def_insn.assignee, Variable):
            indices = ()
        else:
            raise LoopyError("Unrecognized LHS type: %s" %
                             type(def_insn.assignee).__name__)

        arguments = []

        for i in indices:
            if not isinstance(i, Variable):
                raise LoopyError("In defining instruction '%s': "
                                 "assignee index '%s' is not a plain variable. "
                                 "Perhaps use loopy.affine_map_inames() "
                                 "to perform substitution." % (def_id, i))

            arguments.append(i.name)

        new_substs[subst_name] = SubstitutionRule(
            name=subst_name,
            arguments=tuple(arguments) + extra_arguments,
            expression=def_insn.expression)

    # }}}

    # {{{ delete temporary variable / argument if possible

    # (copied below if modified)
    new_temp_vars = kernel.temporary_variables
    new_args = kernel.args

    # True when every usage site matched and is now a substitution rule.
    all_usages_matched = not any(
        six.itervalues(tts.saw_unmatched_usage_sites))

    if lhs_name in kernel.temporary_variables and all_usages_matched:
        new_temp_vars = new_temp_vars.copy()
        del new_temp_vars[lhs_name]

    if (lhs_name in kernel.arg_dict
            and not force_retain_argument
            and all_usages_matched):
        new_args = [arg for arg in new_args if arg.name != lhs_name]

    # }}}

    # Definitions whose every usage site was converted are no longer needed.
    import loopy as lp
    kernel = lp.remove_instructions(
        kernel,
        set(insn_id for insn_id, still_used in six.iteritems(
            tts.saw_unmatched_usage_sites) if not still_used))

    return kernel.copy(
        substitutions=new_substs,
        temporary_variables=new_temp_vars,
        args=new_args,
    )
Example #6
0
def extract_subst(kernel, subst_name, template, parameters=(), within=None):
    """Factor subexpressions matching *template* out into a substitution rule.

    :arg subst_name: The name of the substitution rule to be created. It must
        not already name a substitution rule of the kernel.
    :arg template: Unification template expression, or a string that is
        parsed with :func:`pymbolic.parse`.
    :arg parameters: An iterable of parameters used in
        *template*, or a comma-separated string of the same.
    :arg within: An instance of :class:`loopy.match.MatchExpressionBase` or
        :class:`str` as understood by :func:`loopy.match.parse_match`.
        Restricts which instructions are searched and rewritten. Existing
        substitution rules are always searched and rewritten.
    :returns: a copy of *kernel* (or, for a translation unit containing
        exactly one kernel, a copy of the translation unit with that kernel
        transformed) in which every matched subexpression is replaced by an
        invocation of the new rule *subst_name*.
    :raises LoopyError: if a translation unit does not contain exactly one
        kernel, or if *subst_name* already names a substitution rule.
    :raises RuntimeError: if a subexpression unifies with *template* in more
        than one way, or if no subexpression matches *template*.

    All targeted subexpressions must match ('unify with') *template*.
    The template may contain '*' wildcards that will have to match exactly
    across all unifications.
    """

    if isinstance(kernel, TranslationUnit):
        kernel_names = [
            name for name, clbl in kernel.callables_table.items()
            if isinstance(clbl, CallableKernel)
        ]
        if len(kernel_names) != 1:
            raise LoopyError(
                "extract_subst requires a translation unit with exactly "
                "one kernel, found %d" % len(kernel_names))

        # Forward *within* as well, so the restriction is not silently lost.
        return kernel.with_kernel(
            extract_subst(kernel[kernel_names[0]], subst_name, template,
                          parameters, within=within))

    from pymbolic import var

    # The new rule and an existing rule of the same name would overwrite
    # one another in the result, silently changing what call sites mean.
    if subst_name in kernel.substitutions:
        raise LoopyError(
            "substitution rule '%s' already exists" % subst_name)

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        parameters = tuple(s.strip() for s in parameters.split(","))

    from loopy.match import parse_match
    within = parse_match(within)

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        return var_name_gen(subst_name + "_wc")

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(lhs_mapping_candidates=set(parameters))

    # Instruction currently being walked, recorded in each descriptor. It is
    # None while walking substitution rules, which belong to no instruction.
    # Binding it up front also avoids an unbound closure variable when the
    # kernel has no instructions.
    current_insn = None

    def gather_exprs(expr, mapper):
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError(
                    "ambiguous unification of '%s' with template '%s'" %
                    (expr, template))

            urec, = urecs

            expr_descriptors.append(
                ExprDescriptor(insn=current_insn,
                               expr=expr,
                               unif_var_dict={
                                   lhs.name: rhs
                                   for lhs, rhs in urec.equations
                               }))
        else:
            # Only non-matching expressions are recursed into: matches
            # can't nest.
            mapper.fallback_mapper(expr)

    from loopy.symbolic import (CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    from loopy.kernel.instruction import MultiAssignmentBase
    for current_insn in kernel.instructions:
        if (isinstance(current_insn, MultiAssignmentBase)
                and within(kernel, current_insn)):
            dfmapper(current_insn.assignees)
            dfmapper(current_insn.expression)

    current_insn = None
    for sr in kernel.substitutions.values():
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    def replace_exprs(expr, mapper):
        # Matches are identified by object identity with the gathered
        # subexpressions, not by structural equality.
        for exprd in expr_descriptors:
            if expr is exprd.expr:
                break
        else:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name] for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        # The match is replaced wholesale; don't recurse into it.
        return result

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    def transform_assignee(expr):
        # Assignment LHS's cannot be subst rules. Only their indices are
        # rewritten; the assigned variable itself is kept.

        import pymbolic.primitives as prim
        if isinstance(expr, tuple):
            return tuple(transform_assignee(expr_i) for expr_i in expr)

        elif isinstance(expr, prim.Subscript):
            return type(expr)(expr.aggregate, cbmapper(expr.index))

        elif isinstance(expr, prim.Variable):
            return expr
        else:
            raise ValueError("assignment LHS not understood")

    new_insns = []
    for insn in kernel.instructions:
        if within(kernel, insn):
            new_insns.append(
                insn.with_transformed_expressions(
                    cbmapper, assignee_f=transform_assignee))
        else:
            new_insns.append(insn)

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
        subst_name:
        SubstitutionRule(
            name=subst_name,
            arguments=tuple(parameters),
            expression=template,
        )
    }

    for subst in kernel.substitutions.values():
        new_substs[subst.name] = subst.copy(
            expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(instructions=new_insns, substitutions=new_substs)