Ejemplo n.º 1
0
    def map_reduction(self, expr, expn_state):
        """Rewrite a reduction as two nested reductions.

        The split applies when every iname in ``self.inames`` is one of
        ``expr.inames`` and ``self.within`` accepts the current kernel,
        instruction and stack. The reduction inames are then partitioned into
        ``self.inames`` and the leftover ones:

        * ``self.direction == "in"``: ``self.inames`` go on the inner
          reduction, the leftover inames on the outer one.
        * ``self.direction == "out"``: ``self.inames`` go on the outer
          reduction, the leftover inames on the inner one.

        Both reductions keep ``expr.operation`` and
        ``expr.allow_simultaneous``. In all other cases the parent class's
        ``map_reduction`` handles *expr*.

        :raises NotImplementedError: if a reduction iname also occurs in
            ``expn_state.arg_context``.
        :raises AssertionError: if ``self.direction`` is neither ``"in"`` nor
            ``"out"``.
        """
        if set(expr.inames) & set(expn_state.arg_context):
            # FIXME: reduction inames that also occur in the argument context
            # are not handled yet.
            raise NotImplementedError()

        if not (self.inames <= set(expr.inames) and self.within(
                expn_state.kernel, expn_state.instruction, expn_state.stack)):
            return super(_ReductionSplitter,
                         self).map_reduction(expr, expn_state)

        leftover_inames = set(expr.inames) - self.inames

        if self.direction == "in":
            outer_inames, inner_inames = leftover_inames, self.inames
        elif self.direction == "out":
            outer_inames, inner_inames = self.inames, leftover_inames
        else:
            # Raised explicitly (rather than ``assert False``) so that an
            # invalid direction is still reported when asserts are disabled.
            raise AssertionError(
                f"invalid reduction split direction: {self.direction!r}")

        from loopy.symbolic import Reduction

        # allow_simultaneous is forwarded to BOTH levels, for either direction.
        return Reduction(
            expr.operation, tuple(outer_inames),
            Reduction(expr.operation, tuple(inner_inames),
                      self.rec(expr.expr, expn_state),
                      expr.allow_simultaneous),
            expr.allow_simultaneous)
Ejemplo n.º 2
0
    def map_reduction(self, expr: Reduction, *args: Any,
                      **kwargs: Any) -> Reduction:
        """Map a reduction, sending both its inames and its body through
        ``self.rec``.

        Each reduction iname is wrapped in a :class:`prim.Variable` and mapped
        with the same extra arguments as the body. The result must again be a
        variable; its name becomes the new reduction iname.

        :raises ValueError: if an iname maps to something other than a
            variable.
        """
        def mapped_name(iname):
            # An iname may only turn into another iname, never into a general
            # expression.
            mapped = self.rec(prim.Variable(iname), *args, **kwargs)
            if isinstance(mapped, prim.Variable):
                return mapped.name
            raise ValueError(f"reduction iname {iname} can only be renamed"
                             " to another iname")

        # Inames are mapped before the body, as a tuple built eagerly.
        mapped_inames = tuple([mapped_name(iname) for iname in expr.inames])
        mapped_body = self.rec(expr.expr, *args, **kwargs)

        return Reduction(expr.operation,
                         mapped_inames,
                         mapped_body,
                         allow_simultaneous=expr.allow_simultaneous)
Ejemplo n.º 3
0
    def map_reduction(self, expr, expn_state):
        """Rename reduction inames according to ``self.old_to_new``.

        The renaming applies when the reduction uses at least one iname from
        ``self.old_inames_set`` and ``self.within`` accepts the current
        kernel, instruction and stack. Inames that occur in
        ``expn_state.arg_context`` are kept as they are, as are inames without
        an entry in ``self.old_to_new``. Otherwise the parent class's
        ``map_reduction`` handles *expr*.
        """
        applies = (
            set(expr.inames) & self.old_inames_set
            and self.within(expn_state.kernel, expn_state.instruction,
                            expn_state.stack))
        if not applies:
            return super(_InameDuplicator,
                         self).map_reduction(expr, expn_state)

        renamed = []
        for iname in expr.inames:
            if iname in expn_state.arg_context:
                # Names from the argument context are left alone.
                renamed.append(iname)
            else:
                renamed.append(self.old_to_new.get(iname, iname))

        from loopy.symbolic import Reduction
        return Reduction(expr.operation, tuple(renamed),
                         self.rec(expr.expr, expn_state),
                         expr.allow_simultaneous)
Ejemplo n.º 4
0
    def map_reduction(self, expr, expn_state):
        """Replace ``self.split_iname`` in a reduction by the outer/inner pair.

        The replacement applies when ``self.split_iname`` is a reduction
        iname, does not occur in ``expn_state.arg_context``, and
        ``self.within`` accepts the current kernel, instruction and stack.
        The split iname is dropped and ``self.outer_iname`` and
        ``self.inner_iname`` are appended after the remaining inames.
        Otherwise the parent class's ``map_reduction`` handles *expr*.
        """
        target = self.split_iname
        if not (target in expr.inames
                and target not in expn_state.arg_context
                and self.within(expn_state.kernel, expn_state.instruction,
                                expn_state.stack)):
            return super(_InameSplitter, self).map_reduction(expr, expn_state)

        # Drop only the first occurrence of the split iname.
        remaining = list(expr.inames)
        del remaining[remaining.index(target)]
        replaced_inames = tuple(remaining) + (self.outer_iname,
                                              self.inner_iname)

        from loopy.symbolic import Reduction
        return Reduction(expr.operation, replaced_inames,
                         self.rec(expr.expr, expn_state),
                         expr.allow_simultaneous)
Ejemplo n.º 5
0
    def map_reduction(self, expr: Reduction) -> Reduction:
        """Apply ``self.subst_func`` to the reduction inames and map the body.

        Each reduction iname is wrapped in a :class:`prim.Variable` and passed
        to ``self.subst_func``. A result of ``None`` keeps the iname; a
        variable result renames the iname to that variable's name. The body
        is mapped with ``self.rec``.

        :raises ValueError: if ``self.subst_func`` returns something other
            than ``None`` or a variable for a reduction iname.
        """
        new_inames = []
        for iname in expr.inames:
            # The substitution function is queried with an expression node,
            # not with the bare name string.
            # NOTE(review): confirm no caller relies on subst_func receiving
            # a plain str here.
            iname_var = prim.Variable(iname)
            new_iname = self.subst_func(iname_var)
            if new_iname is None:
                # No substitution registered: the iname stays as it is.
                new_iname = iname_var
            elif not isinstance(new_iname, prim.Variable):
                raise ValueError(
                    f"reduction iname {iname} can only be renamed"
                    " to another iname")
            new_inames.append(new_iname.name)

        return Reduction(expr.operation,
                         tuple(new_inames),
                         self.rec(expr.expr),
                         allow_simultaneous=expr.allow_simultaneous)
Ejemplo n.º 6
0
    def map_reduction(self, expr, expn_state):
        """Count reduction uses per iname and rename repeated inames.

        Every visited reduction increments ``self.iname_to_red_count`` for
        each of its inames and, unless ``expr.allow_simultaneous`` is set,
        ``self.iname_to_nonsimultaneous_red_count`` as well.

        If ``self.within`` accepts the current kernel, instruction and stack
        and the reduction does not allow simultaneous use, each iname that is
        selected (``self.inames`` is ``None`` or contains it) and has been
        counted more than once gets a fresh name from
        ``self.rule_mapping_context``. The rename is recorded in
        ``self.old_to_new`` and substituted into the reduction body before
        the body is mapped. Otherwise the parent class's ``map_reduction``
        handles *expr*.
        """
        is_within = self.within(expn_state.kernel, expn_state.instruction,
                                expn_state.stack)

        # Bookkeeping happens for every reduction, matched or not.
        for iname in expr.inames:
            seen = self.iname_to_red_count.get(iname, 0)
            self.iname_to_red_count[iname] = seen + 1
            if not expr.allow_simultaneous:
                seen_nonsim = self.iname_to_nonsimultaneous_red_count.get(
                    iname, 0)
                self.iname_to_nonsimultaneous_red_count[iname] = (
                    seen_nonsim + 1)

        if not is_within or expr.allow_simultaneous:
            return super(_ReductionInameUniquifier,
                         self).map_reduction(expr, expn_state)

        from pymbolic import var

        renames = {}
        unique_inames = []
        for iname in expr.inames:
            selected = self.inames is None or iname in self.inames
            if selected and self.iname_to_red_count[iname] > 1:
                fresh = self.rule_mapping_context.make_unique_var_name(
                    iname)
                renames[iname] = var(fresh)
                self.old_to_new.append((iname, fresh))
                unique_inames.append(fresh)
            else:
                # Not selected, or first use of this iname: keep the name.
                unique_inames.append(iname)

        from loopy.symbolic import SubstitutionMapper
        from pymbolic.mapper.substitutor import make_subst_func

        from loopy.symbolic import Reduction

        # Rename inside the body first, then continue mapping the result.
        renamed_body = SubstitutionMapper(make_subst_func(renames))(
            expr.expr)
        return Reduction(
            expr.operation, tuple(unique_inames),
            self.rec(renamed_body, expn_state),
            expr.allow_simultaneous)
Ejemplo n.º 7
0
    def map_reduction(self, expr, expn_state):
        """Replace joined inames in a reduction by ``self.new_iname``.

        The reduction is considered when at least one of its inames, ignoring
        those that occur in ``expn_state.arg_context``, is in
        ``self.joined_inames`` and ``self.within`` accepts the current kernel,
        instruction and stack. The inames in ``self.joined_inames`` are then
        removed from the reduction inames and ``self.new_iname`` is added.
        Otherwise the parent class's ``map_reduction`` handles *expr*.

        :raises LoopyError: if the overlapping inames differ from the full
            set of reduction inames.
        """
        expr_inames = set(expr.inames)
        # ``-`` binds tighter than ``&``; the parentheses only make the
        # existing evaluation order explicit.
        overlap = self.joined_inames & (
            expr_inames - set(expn_state.arg_context))
        if overlap and self.within(expn_state.kernel, expn_state.instruction,
                                   expn_state.stack):
            if overlap != expr_inames:
                # Names are sorted so the message is reproducible.
                raise LoopyError(
                    "Cannot join inames '%s' if there is a reduction "
                    "that does not use all of the inames being joined. "
                    "(Found one with just '%s'.)" %
                    (", ".join(sorted(self.joined_inames)),
                     ", ".join(sorted(expr_inames))))

            new_inames = expr_inames - self.joined_inames
            new_inames.add(self.new_iname)

            from loopy.symbolic import Reduction
            return Reduction(expr.operation, tuple(new_inames),
                             self.rec(expr.expr, expn_state),
                             expr.allow_simultaneous)
        else:
            return super(_InameJoiner, self).map_reduction(expr, expn_state)
Ejemplo n.º 8
0
def make_einsum(spec, arg_names, **knl_creation_kwargs):
    r"""Returns a :class:`LoopKernel` for evaluating array-based
    operations using Einstein summation convention.

    :param spec: a string denoting the subscripts for
        summation as a comma-separated list of subscript labels.
        This follows the usual :func:`numpy.einsum` convention.
        Note that the explicit indicator `->` for the precise output
        form is required. Whitespace in the string is ignored.
    :param arg_names: a sequence of string types denoting
        the names of the array operands.
    :param \**knl_creation_kwargs: keyword arguments for kernel creation.
        See :func:`make_kernel` for a list of acceptable keyword
        parameters. If no ``name`` is given, the kernel is named
        ``einsum<number of indices>to<number of output indices>_kernel``.
    :raises ValueError: if *spec* does not contain exactly one ``->``, if
        the number of operand names differs from the number of operand
        subscripts, or if an output index is repeated.

    Each index ``i`` ranges over ``0 <= i < Ni``; indices that do not occur
    in the output are summed over. Indices are emitted in sorted order so
    that the same *spec* always produces the same domain and reduction.

    .. note::

        No attempt is being made to reduce the complexity
        of the resulting expression. This should be dealt with
        as part of a separate transformation.
    """
    # Spaces are not index labels; drop all whitespace before parsing.
    compact_spec = "".join(spec.split())

    arg_spec, arrow, out_spec = compact_spec.partition("->")
    if not arrow or "->" in out_spec:
        raise ValueError(
            f"Einsum spec '{spec}' must contain exactly one '->' separating "
            "the operand subscripts from the output subscripts."
        )
    arg_specs = arg_spec.split(",")

    if len(arg_names) != len(arg_specs):
        raise ValueError(
            f"Number of arg names ({arg_names}) should match the number "
            f"of arg specs: {arg_specs}. Length of arg names is {len(arg_names)}; "
            f"expecting {len(arg_specs)} arg names."
        )

    out_indices = set(out_spec)
    if len(out_indices) != len(out_spec):
        raise ValueError(
            f"Output subscripts '{out_spec}' does not contain all unique indices."
        )

    # Sets of strings iterate in an order that can change between interpreter
    # runs, so fix the order before anything is generated from it.
    all_indices = sorted(
        {idx for argsp in arg_specs for idx in argsp} | out_indices)
    sum_indices = [idx for idx in all_indices if idx not in out_indices]

    from pymbolic import var
    lhs = var("out")[tuple(var(i) for i in out_spec)]

    # Product of all subscripted operands.
    rhs = 1
    for arg_name, argsp in zip(arg_names, arg_specs):
        rhs = rhs * var(arg_name)[tuple(var(i) for i in argsp)]

    if sum_indices:
        rhs = Reduction("sum", tuple(var(idx) for idx in sum_indices), rhs)

    constraints = " and ".join(
        f"0 <= {idx} < N{idx}" for idx in all_indices)

    knl_creation_kwargs.setdefault(
        "name", f"einsum{len(all_indices)}to{len(out_indices)}_kernel")

    domain = "{[" + ",".join(all_indices) + "]: " + constraints + "}"
    return make_kernel(domain,
                       [Assignment(lhs, rhs)],
                       **knl_creation_kwargs)