def find_substitution(expr):
    """Return the replacement for *expr* at a usage site of ``var_name``.

    Intended as the substitution function of a ``SubstitutionMapper``.
    Expressions that are neither a ``Variable`` nor a ``Subscript``, or that
    refer to a different variable than ``var_name``, are returned unchanged.
    For an access to ``var_name`` whose index key unifies with a recorded
    entry of ``common_factors``, the access is multiplied by that entry's
    factors (rewritten through the unifier's variable mapping).

    ``var_name``, ``common_factors``, ``extract_index_key`` and
    ``find_unifiable_cf_index`` are taken from the enclosing scope.
    """
    if isinstance(expr, Subscript):
        v = expr.aggregate.name
    elif isinstance(expr, Variable):
        v = expr.name
    else:
        return expr

    if v != var_name:
        return expr

    index_key = extract_index_key(expr)
    cf_index, unif_result = find_unifiable_cf_index(index_key)

    if cf_index is None:
        # No recorded common-factor entry unifies with this access, so there
        # is nothing to multiply back in. Without this guard the lookup
        # below would dereference ``None.lmap``.
        return expr

    unif_subst_map = SubstitutionMapper(
        make_subst_func(unif_result.lmap))

    _, my_common_factors = common_factors[cf_index]

    if my_common_factors is None:
        return expr

    return flattened_product(
        [unif_subst_map(cf) for cf in my_common_factors]
        + [expr])
Beispiel #2
0
def get_loopy_instructions_as_maxima(kernel, prefix):
    """Render the assignments of *kernel* as a string of Maxima statements.

    Every assigned variable name is renamed to ``prefix + name``, both on the
    left-hand side and wherever it occurs in an expression. An instruction is
    emitted only after all instructions it depends on have been emitted.

    Sample use for code comparison::

        load("knl-optFalse.mac");
        load("knl-optTrue.mac");

        vname: bessel_j_8;

        un_name : concat(''un_, vname);
        opt_name : concat(''opt_, vname);

        print(ratsimp(ev(un_name - opt_name)));

    :raises RuntimeError: if an instruction is not an ``Assignment``.
    """
    from loopy.preprocess import add_boostability_and_automatic_dependencies
    kernel = add_boostability_and_automatic_dependencies(kernel)

    from pymbolic import var
    renamed_vars = {
        name: var(prefix + name)
        for knl_insn in kernel.instructions
        for name in knl_insn.assignee_var_names()
    }

    mstr = MaximaStringifyMapper()
    from loopy.symbolic import SubstitutionMapper
    from pymbolic.mapper.substitutor import make_subst_func
    substitute = SubstitutionMapper(make_subst_func(renamed_vars))

    from loopy.kernel import InstructionBase, Assignment

    lines = ["ratprint:false;"]
    emitted_ids = set()

    def emit(insn_or_id):
        # Dependencies arrive as instruction ids; resolve them to objects.
        if isinstance(insn_or_id, InstructionBase):
            insn = insn_or_id
        else:
            insn = kernel.id_to_insn[insn_or_id]

        if not isinstance(insn, Assignment):
            raise RuntimeError("non-single-output assignment not supported "
                               "in maxima export")

        # The membership test must stay inside the loop: emitting one
        # dependency may transitively emit a later one.
        for dep_id in insn.depends_on:
            if dep_id in emitted_ids:
                continue
            emit(dep_id)

        (assignee_name,) = insn.assignee_var_names()
        rhs_text = mstr(substitute(insn.expression))
        lines.append("%s%s : %s;" % (prefix, assignee_name, rhs_text))

        emitted_ids.add(insn.id)

    for knl_insn in kernel.instructions:
        if knl_insn.id in emitted_ids:
            continue
        emit(knl_insn)

    return "\n".join(lines)
Beispiel #3
0
    def process_expression_for_loopy(self, expr):
        """Return *expr* rewritten for loopy.

        First substitutes according to ``self.active_iname_aliases``, then
        runs the result through a ``SubscriptIndexBaseShifter`` built from
        this object.
        """
        from pymbolic.mapper.substitutor import make_subst_func
        from loopy.symbolic import SubstitutionMapper

        alias_subst = make_subst_func(self.active_iname_aliases)
        aliased_expr = SubstitutionMapper(alias_subst)(expr)

        return SubscriptIndexBaseShifter(self)(aliased_expr)
Beispiel #4
0
def loopy_substitute(expression: Any,
                     variable_assigments: Mapping[str, Any]) -> Any:
    """Apply the substitutions in *variable_assigments* to *expression*.

    If every entry maps a name to a variable of that same name (which
    includes the empty mapping), *expression* is returned as-is without
    running a mapper.
    """
    from loopy.symbolic import SubstitutionMapper
    from pymbolic.mapper.substitutor import make_subst_func

    def maps_to_itself(name, replacement):
        return (isinstance(replacement, prim.Variable)
                and replacement.name == name)

    # {{{ early exit for identity substitution

    is_identity = all(
        maps_to_itself(name, replacement)
        for name, replacement in variable_assigments.items())

    if is_identity:
        # Substituting would change nothing, so skip the traversal.
        return expression

    # }}}

    mapper = SubstitutionMapper(make_subst_func(variable_assigments))
    return mapper(expression)
Beispiel #5
0
    def map_reduction(self, expr, expn_state):
        """Count reduction inames and rename repeated ones within scope.

        Always bumps the per-iname counters for ``expr.inames``. If the
        current context matches ``self.within`` and the reduction does not
        allow simultaneous use, every iname that is selected (``self.inames``
        is None or contains it) and has been seen more than once gets a fresh
        unique name; the rename is recorded in ``self.old_to_new`` and the
        reduction is rebuilt over the renamed body. Otherwise the parent
        class handles the reduction.
        """
        within = self.within(expn_state.kernel, expn_state.instruction,
                             expn_state.stack)

        for iname in expr.inames:
            seen = self.iname_to_red_count.get(iname, 0)
            self.iname_to_red_count[iname] = seen + 1
            if not expr.allow_simultaneous:
                seen_nonsim = self.iname_to_nonsimultaneous_red_count.get(
                    iname, 0)
                self.iname_to_nonsimultaneous_red_count[iname] = (
                    seen_nonsim + 1)

        if not within or expr.allow_simultaneous:
            return super(_ReductionInameUniquifier,
                         self).map_reduction(expr, expn_state)

        from pymbolic import var

        renames = {}
        result_inames = []

        for iname in expr.inames:
            not_selected = (self.inames is not None
                            and iname not in self.inames)
            if not_selected or self.iname_to_red_count[iname] <= 1:
                # Keep the original name: either not selected for
                # uniquification or this is its first use.
                result_inames.append(iname)
            else:
                fresh_iname = (
                    self.rule_mapping_context.make_unique_var_name(iname))
                renames[iname] = var(fresh_iname)
                self.old_to_new.append((iname, fresh_iname))
                result_inames.append(fresh_iname)

        from loopy.symbolic import SubstitutionMapper
        from pymbolic.mapper.substitutor import make_subst_func

        from loopy.symbolic import Reduction

        renamed_body = SubstitutionMapper(make_subst_func(renames))(expr.expr)
        return Reduction(
            expr.operation,
            tuple(result_inames),
            self.rec(renamed_body, expn_state),
            expr.allow_simultaneous)
Beispiel #6
0
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Pull multiplicative factors shared by all updates of *var_name* out of
    those updates and multiply them back in where *var_name* is used.

    The transformation runs in three passes over ``kernel.instructions``:

    1. For every assignment to *var_name*, split the right-hand side into
       sum terms (skipping a term equal to the left-hand side) and intersect
       the product factors of those terms. Results are grouped by the index
       key of the assignee, i.e. its indices along *vary_by_axes*; keys that
       unify share one entry.
    2. Remove the common factors found from each of those assignments.
    3. In every other assignment, replace each access to *var_name* by that
       access times the matching common factors.

    :arg kernel: a ``LoopKernel``. Substitution rules, if present, are
        expanded first.
    :arg var_name: name of a temporary variable or argument of *kernel*.
    :arg vary_by_axes: axes along which the common factors may differ, as a
        comma-separated string or an iterable of axis indices and/or axis
        names (names are resolved via the array's ``dim_names``).
    :returns: a copy of *kernel* with the rewritten instructions.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: on an unknown axis name, an out-of-bounds axis index,
        an update of *var_name* by something other than an ``Assignment``,
        a right-hand side that depends on *var_name* other than via the
        bare increment term, or if no common factors exist.
    """
    assert isinstance(kernel, LoopKernel)
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = {
                name: idx
                for idx, name in enumerate(var_descr.dim_names)
            }
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # Valid axis indices are 0 .. num_user_axes()-1, so an index equal to
        # num_user_axes() is already out of bounds.
        if (vary_by_axes
                and (min(vary_by_axes) < 0
                     or max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero, flattened_sum,
                                     flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
                                UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found)
    common_factors = []

    def find_unifiable_cf_index(index_key):
        # Return (position in common_factors, unification record) for the
        # first entry whose key unifies with index_key, or (None, None).
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        # The index key consists of the indices along vary_by_axes.
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return var_name in insn.assignee_var_names()

    def iterate_as(cls, expr):
        # Yield the children of expr if it is a cls node, else expr itself.
        if isinstance(expr, cls):
            yield from expr.children
        else:
            yield expr

    def checked_product_parts(term, insn):
        # Return the set of product factors of term after making sure that
        # none of them depends on var_name.
        parts = list(iterate_as(Product, term))

        for part in parts:
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                                 "in RHS of instruction '%s'" %
                                 (var_name, insn.id))

        return set(parts)

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-single-assignment" %
                             var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = checked_product_parts(term, insn)

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            _, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = checked_product_parts(term, insn)

                my_common_factors = {
                    cf
                    for cf in my_common_factors
                    if unif_subst_map(cf) in product_parts
                }

            common_factors[cf_index] = (index_key, my_common_factors)

            # }}}

    # }}}

    # Drop entries whose set of common factors ended up empty.
    common_factors = [(ik, cf) for ik, cf in common_factors if cf]

    if not common_factors:
        raise LoopyError("no common factors found")

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        mapped_my_common_factors = {
            unif_subst_map(cf)
            for cf in my_common_factors
        }

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            new_sum_terms.append(
                flattened_product([
                    part for part in iterate_as(Product, term)
                    if part not in mapped_my_common_factors
                ]))

        new_insns.append(insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # Entries with no common factors were filtered out above, so an
            # access may match nothing; leave it untouched instead of
            # dereferencing a missing unification result.
            return expr

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is not None:
            return flattened_product(
                [unif_subst_map(cf) for cf in my_common_factors] + [expr])
        else:
            return expr

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
Beispiel #7
0
    def __init__(self, kernel, domain, sweep_inames, access_descriptors,
                 storage_axis_count):
        """Build the storage-to-sweep map and the augmented domain.

        :arg kernel: stored as ``self.kernel`` and passed on to
            ``build_global_storage_to_sweep_map`` and ``compute_bounds``.
        :arg domain: the isl domain that gets its sweep axes duplicated and
            from which storage bounds are computed.
        :arg sweep_inames: names of the sweep inames; a primed copy (name
            with ``'`` appended) is created for each.
        :arg access_descriptors: passed through to
            ``build_global_storage_to_sweep_map``.
        :arg storage_axis_count: number of storage axes; they are named
            ``_loopy_storage_<i>``.

        Sets ``kernel``, ``sweep_inames``, ``storage_axis_names``,
        ``primed_sweep_inames``, ``prime_sweep_inames``, ``stor2sweep``,
        ``non1_storage_axis_flags``, ``aug_domain``,
        ``storage_base_indices`` and ``non1_storage_shape`` on ``self``.
        """
        self.kernel = kernel
        self.sweep_inames = sweep_inames

        storage_axis_names = self.storage_axis_names = [
            "_loopy_storage_%d" % i for i in range(storage_axis_count)
        ]

        # {{{ duplicate sweep inames

        # The duplication is necessary, otherwise the storage fetch
        # inames remain weirdly tied to the original sweep inames.

        self.primed_sweep_inames = [psin + "'" for psin in sweep_inames]

        from loopy.isl_helpers import duplicate_axes
        # Taken before duplication; used below as the start position when
        # projecting out the duplicated sweep inames.
        dup_sweep_index = domain.space.dim(dim_type.out)
        domain_dup_sweep = duplicate_axes(domain, sweep_inames,
                                          self.primed_sweep_inames)

        # Mapper that replaces each sweep iname by the variable named after
        # its primed counterpart.
        self.prime_sweep_inames = SubstitutionMapper(
            make_subst_func({
                sin: var(psin)
                for sin, psin in zip(sweep_inames, self.primed_sweep_inames)
            }))

        # }}}

        self.stor2sweep = build_global_storage_to_sweep_map(
            kernel, access_descriptors, domain_dup_sweep, dup_sweep_index,
            storage_axis_names, sweep_inames, self.primed_sweep_inames,
            self.prime_sweep_inames)

        storage_base_indices, storage_shape = compute_bounds(
            kernel, domain, self.stor2sweep, self.primed_sweep_inames,
            storage_axis_names)

        # compute augmented domain

        # {{{ filter out unit-length dimensions

        # One flag per storage axis (True where the axis length is not 1);
        # non1_storage_shape keeps only the lengths of those axes.
        non1_storage_axis_flags = []
        non1_storage_shape = []

        for saxis_len in storage_shape:
            has_length_non1 = saxis_len != 1

            non1_storage_axis_flags.append(has_length_non1)

            if has_length_non1:
                non1_storage_shape.append(saxis_len)

        # }}}

        # {{{ subtract off the base indices
        # add the new, base-0 indices as new in dimensions

        sp = self.stor2sweep.get_space()
        stor_idx = sp.dim(dim_type.out)

        n_stor = storage_axis_count
        nn1_stor = len(non1_storage_shape)

        aug_domain = self.stor2sweep.move_dims(dim_type.out, stor_idx,
                                               dim_type.in_, 0,
                                               n_stor).range()

        # aug_domain space now:
        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes']

        aug_domain = aug_domain.insert_dims(dim_type.set, stor_idx, nn1_stor)

        # Name the inserted dimensions after the non-unit-length storage
        # axes, in order; inew counts how many have been named so far.
        inew = 0
        for i, name in enumerate(storage_axis_names):
            if non1_storage_axis_flags[i]:
                aug_domain = aug_domain.set_dim_name(dim_type.set,
                                                     stor_idx + inew, name)
                inew += 1

        # aug_domain space now:
        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'][n1_stor_axes]

        from loopy.symbolic import aff_from_expr
        for saxis, bi, s in zip(storage_axis_names, storage_base_indices,
                                storage_shape):
            if s != 1:
                # Constrain saxis == saxis' - bi for each non-unit-length
                # axis.
                cns = isl.Constraint.equality_from_aff(
                    aff_from_expr(aug_domain.get_space(),
                                  var(saxis) - (var(saxis + "'") - bi)))

                aug_domain = aug_domain.add_constraint(cns)

        # }}}

        # eliminate (primed) storage axes with non-zero base indices
        aug_domain = aug_domain.project_out(dim_type.set, stor_idx + nn1_stor,
                                            n_stor)

        # eliminate duplicated sweep_inames
        nsweep = len(sweep_inames)
        aug_domain = aug_domain.project_out(dim_type.set, dup_sweep_index,
                                            nsweep)

        self.non1_storage_axis_flags = non1_storage_axis_flags
        self.aug_domain = aug_domain
        self.storage_base_indices = storage_base_indices
        self.non1_storage_shape = non1_storage_shape
Beispiel #8
0
def _process_footprint_subscripts(kernel, rule_name, sweep_inames,
        footprint_subscripts, arg):
    """Track applied iname rewrites, deal with slice specifiers ':'.

    Each footprint subscript (a string to be parsed, a single expression or
    a tuple of expressions) is normalized to a tuple with one entry per user
    axis of *arg*, rewritten through ``kernel.applied_iname_rewrites``, and
    has every full slice replaced by a newly created kernel axis that is
    also appended to the sweep inames.

    :returns: a tuple ``(kernel, subst_use, sweep_inames,
        inames_to_be_removed)``. If *footprint_subscripts* is None, the
        inputs are handed back unchanged, with *rule_name* in the
        ``subst_use`` slot and an empty removal list.
    :raises ValueError: if a subscript does not have ``arg.num_user_axes()``
        entries.
    :raises NotImplementedError: for a slice other than the full slice.
    """

    name_gen = kernel.get_var_name_generator()

    from pymbolic.primitives import Variable

    if footprint_subscripts is None:
        return kernel, rule_name, sweep_inames, []

    if not isinstance(footprint_subscripts, (list, tuple)):
        footprint_subscripts = [footprint_subscripts]

    from loopy.symbolic import get_dependencies
    from pymbolic.primitives import Slice

    inames_to_be_removed = []
    resolved_subscripts = []

    for fsub in footprint_subscripts:
        if isinstance(fsub, str):
            from loopy.symbolic import parse
            fsub = parse(fsub)

        if not isinstance(fsub, tuple):
            fsub = (fsub,)

        if len(fsub) != arg.num_user_axes():
            raise ValueError("sweep index '%s' has the wrong number of dimensions"
                    % str(fsub))

        for rewrite in kernel.applied_iname_rewrites:
            from loopy.symbolic import SubstitutionMapper
            from pymbolic.mapper.substitutor import make_subst_func
            fsub = SubstitutionMapper(make_subst_func(rewrite))(fsub)

        fsub_dependencies = get_dependencies(fsub)

        resolved_axes = []
        for axis_nr, fsub_axis in enumerate(fsub):
            if not isinstance(fsub_axis, Slice):
                resolved_axes.append(fsub_axis)
                continue

            if fsub_axis.children != (None,):
                raise NotImplementedError("add_prefetch only "
                        "supports full slices")

            axis_name = name_gen(
                    based_on="%s_fetch_axis_%d" % (arg.name, axis_nr))

            kernel = _add_kernel_axis(kernel, axis_name, 0, arg.shape[axis_nr],
                    frozenset(sweep_inames) | fsub_dependencies)

            # Build a new list so the caller's sweep_inames is not mutated.
            sweep_inames = sweep_inames + [axis_name]

            inames_to_be_removed.append(axis_name)
            resolved_axes.append(Variable(axis_name))

        resolved_subscripts.append(tuple(resolved_axes))

    rule_var = Variable(rule_name)
    subst_use = [rule_var(*indices) for indices in resolved_subscripts]
    return kernel, subst_use, sweep_inames, inames_to_be_removed
Beispiel #9
0
    def emit_atomic_update(self, codegen_state, lhs_atomicity, lhs_var,
                           lhs_expr, rhs_expr, lhs_dtype, rhs_type_context):
        """Generate a compare-and-exchange loop that atomically updates
        *lhs_expr* to the value of *rhs_expr*.

        The loop reads the current value into a fresh "old" temporary,
        evaluates *rhs_expr* with *lhs_expr* replaced by that temporary into
        a fresh "new" temporary, and repeats until ``atomic_cmpxchg``
        (4-byte types) or ``atom_cmpxchg`` (8-byte types) reports that the
        location still held the old value. For floating-point types the
        temporaries and the target pointer are reinterpreted as ``int`` or
        ``long``.

        :arg lhs_var: the variable being updated; for floating-point types
            it must be an ``ArrayArg`` or ``TemporaryVariable`` in the
            global or local address space.
        :returns: a ``cgen.Block`` declaring both temporaries and containing
            the do-while loop.
        :raises LoopyError: on an unexpected item size or kind of variable.
        :raises NotImplementedError: if *lhs_dtype* is not a ``NumpyType``
            wrapping int32, int64, float32 or float64.
        """
        from pymbolic.mapper.stringifier import PREC_NONE

        # FIXME: Could detect operations, generate atomic_{add,...} when
        # appropriate.

        if isinstance(lhs_dtype, NumpyType) and lhs_dtype.numpy_dtype in [
                np.int32, np.int64, np.float32, np.float64
        ]:
            from cgen import Block, DoWhile, Assign
            from loopy.target.c import POD
            old_val_var = codegen_state.var_name_generator("loopy_old_val")
            new_val_var = codegen_state.var_name_generator("loopy_new_val")

            from loopy.kernel.data import (TemporaryVariable, AddressSpace,
                                           ArrayArg)
            ecm = codegen_state.expression_to_code_mapper.with_assignments({
                old_val_var:
                TemporaryVariable(old_val_var, lhs_dtype),
                new_val_var:
                TemporaryVariable(new_val_var, lhs_dtype),
            })

            lhs_expr_code = ecm(lhs_expr, prec=PREC_NONE, type_context=None)

            from pymbolic.mapper.substitutor import make_subst_func
            from pymbolic import var
            from loopy.symbolic import SubstitutionMapper

            # The new value is computed from the snapshot in old_val_var,
            # not from a second read of the target location.
            subst = SubstitutionMapper(
                make_subst_func({lhs_expr: var(old_val_var)}))
            rhs_expr_code = ecm(subst(rhs_expr),
                                prec=PREC_NONE,
                                type_context=rhs_type_context,
                                needed_dtype=lhs_dtype)

            if lhs_dtype.numpy_dtype.itemsize == 4:
                func_name = "atomic_cmpxchg"
            elif lhs_dtype.numpy_dtype.itemsize == 8:
                func_name = "atom_cmpxchg"
            else:
                raise LoopyError("unexpected atomic size")

            cast_str = ""
            old_val = old_val_var
            new_val = new_val_var

            if lhs_dtype.numpy_dtype.kind == "f":
                if lhs_dtype.numpy_dtype == np.float32:
                    ctype = "int"
                elif lhs_dtype.numpy_dtype == np.float64:
                    ctype = "long"
                else:
                    # Unreachable given the dtype whitelist above; an
                    # explicit raise survives ``python -O``.
                    raise AssertionError()

                is_array_or_temp = isinstance(
                    lhs_var, (ArrayArg, TemporaryVariable))

                if (is_array_or_temp
                        and lhs_var.address_space == AddressSpace.GLOBAL):
                    var_kind = "__global"
                elif (is_array_or_temp
                      and lhs_var.address_space == AddressSpace.LOCAL):
                    var_kind = "__local"
                else:
                    # The format string needs one placeholder per argument;
                    # with only one, building the message itself raised
                    # TypeError.
                    raise LoopyError("unexpected kind of variable '%s' in "
                                     "atomic operation: '%s'" %
                                     (lhs_var.name, type(lhs_var).__name__))

                old_val = "*(%s *) &" % ctype + old_val
                new_val = "*(%s *) &" % ctype + new_val
                cast_str = "(%s %s *) " % (var_kind, ctype)

            return Block([
                POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                    old_val_var),
                POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                    new_val_var),
                DoWhile(
                    "%(func_name)s("
                    "%(cast_str)s&(%(lhs_expr)s), "
                    "%(old_val)s, "
                    "%(new_val)s"
                    ") != %(old_val)s" % {
                        "func_name": func_name,
                        "cast_str": cast_str,
                        "lhs_expr": lhs_expr_code,
                        "old_val": old_val,
                        "new_val": new_val,
                    },
                    Block([
                        Assign(old_val_var, lhs_expr_code),
                        Assign(new_val_var, rhs_expr_code),
                    ]))
            ])
        else:
            raise NotImplementedError("atomic update for '%s'" % lhs_dtype)
Beispiel #10
0
def _fix_parameter(kernel, name, value, within=None):
    """Return a copy of *kernel* in which *name* is pinned to *value*.

    In every domain and in the assumptions, *name* (if present) is
    constrained to equal *value* and then projected out. *name* is removed
    from the argument list, and expressions in array arguments and temporary
    variables are substituted and partially evaluated. Instructions are
    rewritten through a rule-aware substitution limited by *within* (parsed
    with ``parse_stack_match``).
    """

    def process_set(s):
        var_dict = s.get_var_dict()

        if name not in var_dict:
            # This set does not mention the parameter.
            return s

        dt, idx = var_dict[name]

        value_aff = isl.Aff.zero_on_domain(s.space) + value

        from loopy.isl_helpers import iname_rel_aff
        name_equal_value_aff = iname_rel_aff(s.space, name, "==", value_aff)

        constrained = s.add_constraint(
            isl.Constraint.equality_from_aff(name_equal_value_aff))
        return constrained.project_out(dt, idx, 1)

    new_domains = [process_set(dom) for dom in kernel.domains]

    from pymbolic.mapper.substitutor import make_subst_func
    subst_func = make_subst_func({name: value})

    from loopy.symbolic import SubstitutionMapper, PartialEvaluationMapper
    subst_map = SubstitutionMapper(subst_func)
    ev_map = PartialEvaluationMapper()

    def map_expr(expr):
        substituted = subst_map(expr)
        return ev_map(substituted)

    from loopy.kernel.array import ArrayBase

    # The fixed parameter itself is dropped from the argument list.
    new_args = [
        arg.map_exprs(map_expr) if isinstance(arg, ArrayBase) else arg
        for arg in kernel.args
        if arg.name != name
    ]

    new_temp_vars = {
        tv.name: tv.map_exprs(map_expr)
        for tv in kernel.temporary_variables.values()
    }

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())
    esubst_map = RuleAwareSubstitutionMapper(rule_mapping_context,
                                             subst_func,
                                             within=within)

    mapped_kernel = esubst_map.map_kernel(
        kernel,
        within=within,
        # overwritten below, no need to map
        map_tvs=False,
        map_args=False)
    finished_kernel = rule_mapping_context.finish_kernel(mapped_kernel)

    return finished_kernel.copy(
        domains=new_domains,
        args=new_args,
        temporary_variables=new_temp_vars,
        assumptions=process_set(kernel.assumptions),
    )
Beispiel #11
0
    def emit_atomic_update(self, codegen_state, lhs_atomicity, lhs_var,
                           lhs_expr, rhs_expr, lhs_dtype, rhs_type_context):
        """Generate code that atomically updates *lhs_expr* to the value of
        *rhs_expr*.

        If *rhs_expr* is a sum in which *lhs_expr* occurs as exactly one
        addend (i.e. the update is an increment), a single ``atomicAdd`` of
        the remaining addends is emitted. Every other update is emitted as
        an ``atomicCAS`` retry loop: read the current value into an "old"
        temporary, evaluate *rhs_expr* with *lhs_expr* replaced by that
        temporary into a "new" temporary, and repeat until the location
        still held the old value. For floating-point types the temporaries
        and the target pointer are reinterpreted as ``int`` or ``long``.

        :returns: a ``cgen.Statement`` (increment case) or a ``cgen.Block``
            (compare-and-swap loop).
        :raises NotImplementedError: if *lhs_dtype* is not a ``NumpyType``
            wrapping int32, int64, float32 or float64.
        """

        from pymbolic.primitives import Sum
        from cgen import Statement
        from pymbolic.mapper.stringifier import PREC_NONE

        if not (isinstance(lhs_dtype, NumpyType)
                and lhs_dtype.numpy_dtype in [
                    np.int32, np.int64, np.float32, np.float64]):
            raise NotImplementedError("atomic update for '%s'" % lhs_dtype)

        # {{{ atomicAdd for plain increments

        if isinstance(rhs_expr, Sum):
            other_terms = tuple(
                c for c in rhs_expr.children if c != lhs_expr)
            lhs_occurrences = len(rhs_expr.children) - len(other_terms)

            # Only "lhs = lhs + <rest>" is an increment. If lhs is absent
            # from the sum the update is a plain store, and if it occurs
            # several times the increment itself depends on lhs; atomicAdd
            # of the remaining terms would be wrong in both cases, so those
            # fall through to the compare-and-swap loop.
            if lhs_occurrences == 1 and other_terms:
                ecm = self.get_expression_to_code_mapper(codegen_state)

                new_rhs_expr = Sum(other_terms)
                lhs_expr_code = ecm(lhs_expr)
                rhs_expr_code = ecm(new_rhs_expr)

                return Statement("atomicAdd(&{}, {})".format(
                    lhs_expr_code, rhs_expr_code))

        # }}}

        # {{{ general case: atomicCAS loop

        from cgen import Block, DoWhile, Assign
        from loopy.target.c import POD
        old_val_var = codegen_state.var_name_generator("loopy_old_val")
        new_val_var = codegen_state.var_name_generator("loopy_new_val")

        from loopy.kernel.data import TemporaryVariable
        ecm = codegen_state.expression_to_code_mapper.with_assignments(
            {
                old_val_var: TemporaryVariable(old_val_var, lhs_dtype),
                new_val_var: TemporaryVariable(new_val_var, lhs_dtype),
            })

        lhs_expr_code = ecm(lhs_expr,
                            prec=PREC_NONE,
                            type_context=None)

        from pymbolic.mapper.substitutor import make_subst_func
        from pymbolic import var
        from loopy.symbolic import SubstitutionMapper

        # The new value is computed from the snapshot in old_val_var, not
        # from a second read of the target location.
        subst = SubstitutionMapper(
            make_subst_func({lhs_expr: var(old_val_var)}))
        rhs_expr_code = ecm(subst(rhs_expr),
                            prec=PREC_NONE,
                            type_context=rhs_type_context,
                            needed_dtype=lhs_dtype)

        cast_str = ""
        old_val = old_val_var
        new_val = new_val_var

        if lhs_dtype.numpy_dtype.kind == "f":
            if lhs_dtype.numpy_dtype == np.float32:
                ctype = "int"
            elif lhs_dtype.numpy_dtype == np.float64:
                ctype = "long"
            else:
                raise AssertionError()

            old_val = "*(%s *) &" % ctype + old_val
            new_val = "*(%s *) &" % ctype + new_val
            cast_str = "(%s *) " % (ctype)

        return Block([
            POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                old_val_var),
            POD(self, NumpyType(lhs_dtype.dtype, target=self.target),
                new_val_var),
            DoWhile(
                "atomicCAS("
                "%(cast_str)s&(%(lhs_expr)s), "
                "%(old_val)s, "
                "%(new_val)s"
                ") != %(old_val)s" % {
                    "cast_str": cast_str,
                    "lhs_expr": lhs_expr_code,
                    "old_val": old_val,
                    "new_val": new_val,
                },
                Block([
                    Assign(old_val_var, lhs_expr_code),
                    Assign(new_val_var, rhs_expr_code),
                ]))
        ])

        # }}}
Beispiel #12
0
def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
        store_expression=None, within=None, default_tag="l.auto",
        temporary_scope=None, temporary_is_local=None,
        fetch_bounding_box=False):
    """Replace accesses to *var_name* with ones to a temporary, which is
    created and acts as a buffer. To perform this transformation, the access
    footprint to *var_name* is determined and a temporary of a suitable
    :class:`loopy.AddressSpace` and shape is created.

    By default, the value of the buffered cells in *var_name* are read prior to
    any (read/write) use, and the modified values are written out after use has
    concluded, but for special use cases (e.g. additive accumulation), the
    behavior can be modified using *init_expression* and *store_expression*.

    :arg buffer_inames: The inames across which the buffer should be usable--i.e.
        all possible values of these inames will be covered by the buffer footprint.
        A tuple of inames or a comma-separated string.
    :arg init_expression: Either *None* (indicating the prior value of the buffered
        array should be read) or an expression optionally involving the
        variable 'base' (which references the associated location in the array
        being buffered).
    :arg store_expression: Either *None*, *False*, or an expression involving
        variables 'base' and 'buffer' (without array indices).
        (*None* indicates that a default storage instruction should be used,
        *False* indicates that no storing of the temporary should occur
        at all.)
    :arg within: If not None, limit the action of the transformation to
        matching contexts.  See :func:`loopy.match.parse_stack_match`
        for syntax.
    :arg temporary_scope: If given, override the choice of
        :class:`AddressSpace` for the created temporary.
    :arg temporary_is_local: Deprecated, use *temporary_scope* instead.
        If not *None*, a true value selects ``AddressSpace.LOCAL`` and a
        false value selects ``AddressSpace.PRIVATE``. May not be given
        together with *temporary_scope*.
    :arg default_tag: The default :ref:`iname-tags` to be assigned to the
        inames used for fetching and storing
    :arg fetch_bounding_box: If the access footprint is non-convex
        (resulting in an error), setting this argument to *True* will force a
        rectangular (and hence convex) superset of the footprint to be
        fetched.

    :returns: a copy of *kernel* with the buffer temporary, the init
        instruction and (if the buffer is written and storing was not
        disabled) the store instruction added.
    :raises LoopyError: if both *temporary_is_local* and *temporary_scope*
        are given, if *var_name* is written through linear indexing, or if
        an assignee is not a valid lvalue.
    :raises RuntimeError: if a buffer iname is unknown or not at home in the
        leaf domain being changed.
    :raises ValueError: if *var_name* is neither an argument nor a temporary.
    """

    # {{{ unify temporary_scope / temporary_is_local

    from loopy.kernel.data import AddressSpace
    if temporary_is_local is not None:
        from warnings import warn
        warn("temporary_is_local is deprecated. Use temporary_scope instead",
                DeprecationWarning, stacklevel=2)

        if temporary_scope is not None:
            raise LoopyError("may not specify both temporary_is_local and "
                    "temporary_scope")

        if temporary_is_local:
            temporary_scope = AddressSpace.LOCAL
        else:
            temporary_scope = AddressSpace.PRIVATE

    del temporary_is_local

    # }}}

    # {{{ process arguments

    if isinstance(init_expression, str):
        from loopy.symbolic import parse
        init_expression = parse(init_expression)

    if isinstance(store_expression, str):
        from loopy.symbolic import parse
        store_expression = parse(store_expression)

    if isinstance(buffer_inames, str):
        buffer_inames = [s.strip()
                for s in buffer_inames.split(",") if s.strip()]

    for iname in buffer_inames:
        if iname not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname"
                    % iname)

    buffer_inames = list(buffer_inames)
    buffer_inames_set = frozenset(buffer_inames)

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    if var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    elif var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    else:
        raise ValueError("variable '%s' not found" % var_name)

    from loopy.kernel.data import ArrayBase
    if isinstance(var_descr, ArrayBase):
        var_shape = var_descr.shape
    else:
        # non-array variables are treated as zero-dimensional
        var_shape = ()

    if temporary_scope is None:
        import loopy as lp
        temporary_scope = lp.auto

    # }}}

    # {{{ caching

    from loopy import CACHING_ENABLED

    from loopy.preprocess import prepare_for_caching
    key_kernel = prepare_for_caching(kernel)
    cache_key = (key_kernel, var_name, tuple(buffer_inames),
            PymbolicExpressionHashWrapper(init_expression),
            PymbolicExpressionHashWrapper(store_expression), within,
            default_tag, temporary_scope, fetch_bounding_box)

    if CACHING_ENABLED:
        try:
            result = buffer_array_cache[cache_key]
            logger.info("%s: buffer_array cache hit" % kernel.name)
            return result
        except KeyError:
            pass

    # }}}

    var_name_gen = kernel.get_var_name_generator()
    within_inames = set()

    # {{{ collect write accesses to var_name

    from pymbolic.primitives import Variable, Subscript
    from loopy.symbolic import LinearSubscript

    access_descriptors = []
    for insn in kernel.instructions:
        if not within(kernel, insn.id, ()):
            continue

        for assignee in insn.assignees:
            if isinstance(assignee, Variable):
                assignee_name = assignee.name
                index = ()

            elif isinstance(assignee, Subscript):
                assignee_name = assignee.aggregate.name
                index = assignee.index_tuple

            elif isinstance(assignee, LinearSubscript):
                if assignee.aggregate.name == var_name:
                    raise LoopyError("buffer_array may not be applied in the "
                            "presence of linear write indexing into '%s'" % var_name)

                # A linear write into some other variable is irrelevant here.
                # Skip it explicitly: falling through would reuse
                # assignee_name/index from a previous assignee (or hit an
                # undefined name if this is the first one).
                continue

            else:
                raise LoopyError("invalid lvalue '%s'" % assignee)

            if assignee_name == var_name:
                within_inames.update(
                        (get_dependencies(index) & kernel.all_inames())
                        - buffer_inames_set)
                access_descriptors.append(
                        AccessDescriptor(
                            identifier=insn.id,
                            storage_axis_exprs=index))

    # }}}

    # {{{ find fetch/store inames

    init_inames = []
    store_inames = []
    new_iname_to_tag = {}

    for i in range(len(var_shape)):
        dim_name = str(i)
        if isinstance(var_descr, ArrayBase) and var_descr.dim_names is not None:
            dim_name = var_descr.dim_names[i]

        init_iname = var_name_gen(f"{var_name}_init_{dim_name}")
        store_iname = var_name_gen(f"{var_name}_store_{dim_name}")

        new_iname_to_tag[init_iname] = default_tag
        new_iname_to_tag[store_iname] = default_tag

        init_inames.append(init_iname)
        store_inames.append(store_iname)

    # }}}

    # {{{ modify loop domain

    non1_init_inames = []
    non1_store_inames = []

    if var_shape:
        # {{{ find domain to be changed

        from loopy.kernel.tools import DomainChanger
        domch = DomainChanger(kernel, buffer_inames_set | within_inames)

        if domch.leaf_domain_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.

            for iname in buffer_inames_set:
                if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
                    raise RuntimeError("buffer iname '%s' is not 'at home' in the "
                            "sweep's leaf domain" % iname)

        # }}}

        abm = ArrayToBufferMap(kernel, domch.domain, buffer_inames,
                access_descriptors, len(var_shape))

        for i in range(len(var_shape)):
            if abm.non1_storage_axis_flags[i]:
                non1_init_inames.append(init_inames[i])
                non1_store_inames.append(store_inames[i])
            else:
                # axes flagged as not "non1" get no fetch/store loop, so
                # their inames must not be tagged either
                del new_iname_to_tag[init_inames[i]]
                del new_iname_to_tag[store_inames[i]]

        new_domain = domch.domain
        new_domain = abm.augment_domain_with_sweep(
                    new_domain, non1_init_inames,
                    boxify_sweep=fetch_bounding_box)
        new_domain = abm.augment_domain_with_sweep(
                    new_domain, non1_store_inames,
                    boxify_sweep=fetch_bounding_box)
        new_kernel_domains = domch.get_domains_with(new_domain)
        del new_domain

    else:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains

        abm = NoOpArrayToBufferMap()

    # }}}

    # {{{ set up temp variable

    import loopy as lp

    buf_var_name = var_name_gen(based_on=var_name+"_buf")

    new_temporary_variables = kernel.temporary_variables.copy()
    temp_var = lp.TemporaryVariable(
            name=buf_var_name,
            dtype=var_descr.dtype,
            base_indices=(0,)*len(abm.non1_storage_shape),
            shape=tuple(abm.non1_storage_shape),
            address_space=temporary_scope)

    new_temporary_variables[buf_var_name] = temp_var

    # }}}

    buf_var = var(buf_var_name)

    def make_base_subscript(non1_inames):
        # Build the index tuple into *var_name* that corresponds to the
        # buffer entry addressed by *non1_inames*: the storage base index
        # of each axis, offset by the sweep iname on axes flagged "non1".
        subscript = []
        iname_idx = 0
        for i in range(len(var_shape)):
            ax_subscript = abm.storage_base_indices[i]
            if abm.non1_storage_axis_flags[i]:
                ax_subscript += var(non1_inames[iname_idx])
                iname_idx += 1
            subscript.append(ax_subscript)
        return tuple(subscript)

    # {{{ generate init instruction

    buf_var_init = buf_var
    if non1_init_inames:
        buf_var_init = buf_var_init.index(
                tuple(var(iname) for iname in non1_init_inames))

    init_base = var(var_name)

    init_subscript = make_base_subscript(non1_init_inames)
    if init_subscript:
        init_base = init_base.index(init_subscript)

    if init_expression is None:
        init_expression = init_base
    else:
        init_expression = SubstitutionMapper(
                make_subst_func({
                    "base": init_base,
                    }))(init_expression)

    init_insn_id = kernel.make_unique_instruction_id(based_on="init_"+var_name)
    from loopy.kernel.data import Assignment
    init_instruction = Assignment(id=init_insn_id,
                assignee=buf_var_init,
                expression=init_expression,
                within_inames=(
                    frozenset(within_inames)
                    | frozenset(non1_init_inames)),
                depends_on=frozenset(),
                depends_on_is_final=True)

    # }}}

    # {{{ redirect accesses to the buffer

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    aar = ArrayAccessReplacer(rule_mapping_context, var_name,
            within, abm, buf_var)
    kernel = rule_mapping_context.finish_kernel(aar.map_kernel(kernel))

    # A store is only needed if some rewritten instruction writes the buffer.
    did_write = any(
            buf_var_name in kernel.id_to_insn[insn_id].assignee_var_names()
            for insn_id in aar.modified_insn_ids)

    # }}}

    # {{{ add init_insn_id to depends_on

    new_insns = []

    def none_to_empty_set(s):
        if s is None:
            return frozenset()
        else:
            return s

    for insn in kernel.instructions:
        if insn.id in aar.modified_insn_ids:
            new_insns.append(
                    insn.copy(
                        depends_on=(
                            none_to_empty_set(insn.depends_on)
                            | frozenset([init_insn_id]))))
        else:
            new_insns.append(insn)

    # }}}

    # {{{ generate store instruction

    buf_var_store = buf_var
    if non1_store_inames:
        buf_var_store = buf_var_store.index(
                tuple(var(iname) for iname in non1_store_inames))

    store_subscript = make_base_subscript(non1_store_inames)

    store_target = var(var_name)
    if store_subscript:
        store_target = store_target.index(store_subscript)

    if store_expression is None:
        store_expression = buf_var_store
    elif store_expression is not False:
        # *False* means "do not store" and is not an expression, so it is
        # deliberately kept away from the substitution mapper.
        store_expression = SubstitutionMapper(
                make_subst_func({
                    "base": store_target,
                    "buffer": buf_var_store,
                    }))(store_expression)

    if store_expression is not False:
        store_instruction = Assignment(
                    id=kernel.make_unique_instruction_id(based_on="store_"+var_name),
                    depends_on=frozenset(aar.modified_insn_ids),
                    no_sync_with=frozenset([(init_insn_id, "any")]),
                    assignee=store_target,
                    expression=store_expression,
                    within_inames=(
                        frozenset(within_inames)
                        | frozenset(non1_store_inames)))
    else:
        did_write = False

    # }}}

    new_insns.append(init_instruction)
    if did_write:
        new_insns.append(store_instruction)
    else:
        # No store instruction is emitted, so the store inames must not be
        # tagged. Tags of axes without a store loop were already dropped
        # while modifying the loop domain, hence pop() rather than del.
        for iname in store_inames:
            new_iname_to_tag.pop(iname, None)

    kernel = kernel.copy(
            domains=new_kernel_domains,
            instructions=new_insns,
            temporary_variables=new_temporary_variables)

    from loopy import tag_inames
    kernel = tag_inames(kernel, new_iname_to_tag)

    from loopy.kernel.tools import assign_automatic_axes
    kernel = assign_automatic_axes(kernel)

    if CACHING_ENABLED:
        from loopy.preprocess import prepare_for_caching
        buffer_array_cache.store_if_not_present(
                cache_key, prepare_for_caching(kernel))

    return kernel
def pack_and_unpack_args_for_call_for_single_kernel(kernel,
                                                    callables_table,
                                                    call_name,
                                                    args_to_pack=None,
                                                    args_to_unpack=None):
    """
    Returns a copy of *kernel* with instructions added around each call to
    *call_name* that copy the arguments in *args_to_pack* into temporaries
    matching the alignment expected by the callee. The arguments in
    *args_to_unpack* are copied back from those temporaries after the call,
    with the appropriate data layout.

    :arg callables_table: A mapping from function names to in-kernel
        callables, used to look up the callee of each call instruction.
    :arg call_name: An instance of :class:`str` denoting the function call in
        the *kernel*.
    :arg args_to_pack: A list of the arguments as instances of :class:`str` which
        must be packed. If set *None*, it is interpreted that all the array
        arguments would be packed.
    :arg args_to_unpack: A list of the arguments as instances of :class:`str`
        which must be unpacked. If set *None*, it is interpreted that
        all the array arguments should be unpacked.

    :returns: *kernel* itself if no call was rewritten, otherwise a copy with
        the new domains, temporaries and instructions.
    :raises LoopyError: if an argument to pack does not name a sub-array
        reference of the call, or if an argument to unpack is not also packed.
    """
    assert isinstance(kernel, LoopKernel)
    new_domains = []
    new_tmps = kernel.temporary_variables.copy()
    old_insn_to_new_insns = {}

    # The generators are created once and shared by all call sites: a fresh
    # generator per instruction would not know the names handed out for the
    # previous call site and could repeat them.
    vng = kernel.get_var_name_generator()
    ing = kernel.get_instruction_id_generator()

    for insn in kernel.instructions:
        if not isinstance(insn, CallInstruction):
            # pack and unpack call only be done for CallInstructions.
            continue
        if insn.expression.function.name not in callables_table:
            continue

        in_knl_callable = callables_table[insn.expression.function.name]

        if in_knl_callable.name != call_name:
            # not the function we're looking for.
            continue
        in_knl_callable = in_knl_callable.with_packing_for_args()

        parameters = insn.expression.parameters

        # Resolve the None defaults separately for every call site, so that
        # the array arguments of the first matching call do not leak into
        # later ones.
        if args_to_pack is None:
            insn_args_to_pack = [
                par.subscript.aggregate.name
                for par in parameters + insn.assignees
                if isinstance(par, SubArrayRef) and (par.swept_inames)
            ]
        else:
            insn_args_to_pack = args_to_pack

        if args_to_unpack is None:
            insn_args_to_unpack = [
                par.subscript.aggregate.name
                for par in parameters + insn.assignees
                if isinstance(par, SubArrayRef) and (par.swept_inames)
            ]
        else:
            insn_args_to_unpack = args_to_unpack

        # {{{ sanity checks for args

        assert isinstance(insn_args_to_pack, list)
        assert isinstance(insn_args_to_unpack, list)

        for arg in insn_args_to_pack:
            found_sub_array_ref = False

            for par in parameters + insn.assignees:
                # checking that the given args is a sub array ref
                if isinstance(par,
                              SubArrayRef) and (par.subscript.aggregate.name
                                                == arg):
                    found_sub_array_ref = True
                    break
            if not found_sub_array_ref:
                raise LoopyError(
                    "No match found for packing arg '%s' of call '%s' "
                    "at insn '%s'." % (arg, call_name, insn.id))
        for arg in insn_args_to_unpack:
            if arg not in insn_args_to_pack:
                raise LoopyError("Argument %s should be packed in order to be "
                                 "unpacked." % arg)

        # }}}

        packing_insns = []
        unpacking_insns = []

        # {{{ handling ilp tags

        from loopy.kernel.data import IlpBaseTag, VectorizeTag
        import islpy as isl
        from pymbolic import var

        dim_type = isl.dim_type.set
        # NOTE(review): all() is vacuously true for inames without any tag,
        # so untagged inames are handled like ILP/vectorized ones here --
        # confirm that this is intended.
        ilp_inames = {
            iname
            for iname in insn.within_inames if all(
                isinstance(tag, (IlpBaseTag, VectorizeTag))
                for tag in kernel.iname_to_tags.get(iname, []))
        }
        new_ilp_inames = set()
        ilp_inames_map = {}
        for iname in ilp_inames:
            new_iname_name = vng(iname + "_ilp")
            ilp_inames_map[var(iname)] = var(new_iname_name)
            new_ilp_inames.add(new_iname_name)
        for iname in ilp_inames:
            # copy the home domain with the selected inames renamed
            new_domain = kernel.get_inames_domain(iname).copy()
            for i in range(new_domain.n_dim()):
                old_iname = new_domain.get_dim_name(dim_type, i)
                if old_iname in ilp_inames:
                    new_domain = new_domain.set_dim_name(
                        dim_type, i, ilp_inames_map[var(old_iname)].name)
            new_domains.append(new_domain)

        # }}}

        from pymbolic.mapper.substitutor import make_subst_func
        from loopy.symbolic import SubstitutionMapper

        # dict to store the new assignees and parameters, the mapping pattern
        # from arg_id to parameters is identical to InKernelCallable.arg_id_to_dtype
        id_to_parameters = tuple(enumerate(parameters)) + tuple(
            (-i - 1, assignee) for i, assignee in enumerate(insn.assignees))
        new_id_to_parameters = {}

        for arg_id, p in id_to_parameters:
            if isinstance(p, SubArrayRef) and (p.subscript.aggregate.name
                                               in insn_args_to_pack):
                # Only the swept inames get packing-/unpacking-specific
                # copies; the renaming in ilp_inames_map is applied
                # separately to the call instruction further down.
                new_pack_inames = {
                    iname: var(vng(iname.name + "_pack"))
                    for iname in p.swept_inames
                }
                new_unpack_inames = {
                    iname: var(vng(iname.name + "_unpack"))
                    for iname in p.swept_inames
                }

                # Updating the domains corresponding to the new inames.
                for iname in p.swept_inames:
                    new_domain_pack = kernel.get_inames_domain(
                        iname.name).copy()
                    new_domain_unpack = kernel.get_inames_domain(
                        iname.name).copy()
                    for i in range(new_domain_pack.n_dim()):
                        old_iname = new_domain_pack.get_dim_name(dim_type, i)
                        if var(old_iname) in new_pack_inames:
                            new_domain_pack = new_domain_pack.set_dim_name(
                                dim_type, i,
                                new_pack_inames[var(old_iname)].name)
                            new_domain_unpack = new_domain_unpack.set_dim_name(
                                dim_type, i,
                                new_unpack_inames[var(old_iname)].name)
                    new_domains.append(new_domain_pack)
                    new_domains.append(new_domain_unpack)

                arg = p.subscript.aggregate.name
                pack_name = vng(arg + "_pack")

                from loopy.kernel.data import (TemporaryVariable,
                                               temp_var_scope)

                if arg in kernel.arg_dict:
                    arg_in_caller = kernel.arg_dict[arg]
                else:
                    arg_in_caller = kernel.temporary_variables[arg]

                # the packed temporary uses the callee's layout
                pack_tmp = TemporaryVariable(
                    name=pack_name,
                    dtype=arg_in_caller.dtype,
                    dim_tags=in_knl_callable.arg_id_to_descr[arg_id].dim_tags,
                    shape=in_knl_callable.arg_id_to_descr[arg_id].shape,
                    scope=temp_var_scope.PRIVATE,
                )

                new_tmps[pack_name] = pack_tmp

                from loopy import Assignment
                pack_subst_mapper = SubstitutionMapper(
                    make_subst_func(new_pack_inames))
                unpack_subst_mapper = SubstitutionMapper(
                    make_subst_func(new_unpack_inames))

                # {{{ getting the lhs for packing and rhs for unpacking

                from loopy.isl_helpers import simplify_via_aff, make_slab

                # linearize the caller-side index using the caller's strides ...
                flatten_index = simplify_via_aff(
                    sum(dim_tag.stride * idx for dim_tag, idx in zip(
                        arg_in_caller.dim_tags, p.subscript.index_tuple)))

                # ... and split it up again using the callee's strides
                new_indices = []
                for dim_tag in in_knl_callable.arg_id_to_descr[
                        arg_id].dim_tags:
                    ind = flatten_index // dim_tag.stride
                    flatten_index -= (dim_tag.stride * ind)
                    new_indices.append(ind)

                new_indices = tuple(simplify_via_aff(i) for i in new_indices)

                pack_lhs_assignee = pack_subst_mapper(
                    var(pack_name).index(new_indices))
                unpack_rhs = unpack_subst_mapper(
                    var(pack_name).index(new_indices))

                # }}}

                packing_insns.append(
                    Assignment(
                        assignee=pack_lhs_assignee,
                        expression=pack_subst_mapper.map_subscript(
                            p.subscript),
                        within_inames=(
                            (insn.within_inames - ilp_inames)
                            | {new_pack_inames[i].name
                               for i in p.swept_inames}
                            | new_ilp_inames),
                        depends_on=insn.depends_on,
                        id=ing(insn.id + "_pack"),
                        depends_on_is_final=True))

                if p.subscript.aggregate.name in insn_args_to_unpack:
                    unpacking_insns.append(
                        Assignment(
                            expression=unpack_rhs,
                            assignee=unpack_subst_mapper.map_subscript(
                                p.subscript),
                            within_inames=(
                                (insn.within_inames - ilp_inames)
                                | {new_unpack_inames[i].name
                                   for i in p.swept_inames}
                                | new_ilp_inames),
                            id=ing(insn.id + "_unpack"),
                            depends_on=frozenset([insn.id]),
                            depends_on_is_final=True))

                # {{{ creating the sweep inames for the new sub array refs

                updated_swept_inames = []

                for _ in in_knl_callable.arg_id_to_descr[arg_id].shape:
                    updated_swept_inames.append(var(vng("i_packsweep_" + arg)))

                ctx = kernel.isl_context
                space = isl.Space.create_from_names(
                    ctx, set=[iname.name for iname in updated_swept_inames])
                iname_set = isl.BasicSet.universe(space)
                for iname, axis_length in zip(
                        updated_swept_inames,
                        in_knl_callable.arg_id_to_descr[arg_id].shape):
                    iname_set = iname_set & make_slab(space, iname.name, 0,
                                                      axis_length)
                new_domains.append(iname_set)

                # }}}

                new_id_to_parameters[arg_id] = SubArrayRef(
                    tuple(updated_swept_inames),
                    (var(pack_name).index(tuple(updated_swept_inames))))
            else:
                new_id_to_parameters[arg_id] = p

        if packing_insns:
            subst_mapper = SubstitutionMapper(make_subst_func(ilp_inames_map))
            new_call_insn = insn.with_transformed_expressions(subst_mapper)
            new_params = tuple(
                subst_mapper(new_id_to_parameters[i])
                for i, _ in enumerate(parameters))
            new_assignees = tuple(
                subst_mapper(new_id_to_parameters[-i - 1])
                for i, _ in enumerate(insn.assignees))
            new_call_insn = new_call_insn.copy(
                depends_on=new_call_insn.depends_on
                | {pack.id
                   for pack in packing_insns},
                within_inames=(
                    (new_call_insn.within_inames - ilp_inames)
                    | new_ilp_inames),
                expression=new_call_insn.expression.function(*new_params),
                assignees=new_assignees)
            old_insn_to_new_insns[insn.id] = (packing_insns + [new_call_insn] +
                                              unpacking_insns)

    if old_insn_to_new_insns:
        replaced_insn_ids = frozenset(old_insn_to_new_insns)

        new_instructions = []
        for insn in kernel.instructions:
            if insn.id in old_insn_to_new_insns:
                # Replacing the current instruction with the group of
                # instructions including the packing and unpacking instructions
                new_instructions.extend(old_insn_to_new_insns[insn.id])
            else:
                # for the instructions that depend on the call instruction that
                # are to be packed and unpacked, we need to add the complete
                # instruction block as a dependency for them.
                new_depends_on = insn.depends_on
                for old_insn_id in insn.depends_on & replaced_insn_ids:
                    # need to add the unpack instructions on dependencies.
                    new_depends_on |= frozenset(
                        i.id for i in old_insn_to_new_insns[old_insn_id])
                new_instructions.append(insn.copy(depends_on=new_depends_on))
        kernel = kernel.copy(domains=kernel.domains + new_domains,
                             instructions=new_instructions,
                             temporary_variables=new_tmps)

    return kernel
Beispiel #14
0
def extract_subst(kernel, subst_name, template, parameters=()):
    """Create a substitution rule *subst_name* from *template* and replace
    every subexpression matching the template with a reference to that rule.

    :arg subst_name: The name of the substitution rule to be created.
    :arg template: Unification template expression.
    :arg parameters: An iterable of parameters used in
        *template*, or a comma-separated string of the same.

    All targeted subexpressions must match ('unify with') *template*
    The template may contain '*' wildcards that will have to match exactly across all
    unifications.

    :returns: a copy of *kernel* with rewritten instructions and substitution
        rules, including the new rule *subst_name*.
    :raises RuntimeError: if a subexpression unifies with *template* in more
        than one way, or if no subexpression matches at all.
    """

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        # ignore empty pieces (empty string, trailing comma) rather than
        # treating "" as a parameter name
        parameters = tuple(
                s.strip() for s in parameters.split(",") if s.strip())

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        return var_name_gen(subst_name+"_wc")

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ deal with iname deps of template that are not independent_inames

    # (We call these 'matching_vars', because they have to match exactly in
    # every CSE. As above, they might need to be renamed to make them unique
    # within the kernel.)

    matching_vars = []
    old_to_new = {}

    all_inames = kernel.all_inames()

    # sorted so that the generated names do not depend on set iteration order
    for iname in sorted(get_dependencies(template)
            - set(parameters)
            - kernel.non_iname_variable_names()):
        if iname in all_inames:
            # need to rename to be unique
            new_iname = var_name_gen(iname)
            old_to_new[iname] = var(new_iname)
            matching_vars.append(new_iname)
        else:
            matching_vars.append(iname)

    if old_to_new:
        template = (
                SubstitutionMapper(make_subst_func(old_to_new))
                (template))

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(
            lhs_mapping_candidates=set(parameters) | set(matching_vars))

    def gather_exprs(expr, mapper):
        # 'insn' is the instruction currently being walked by the loops
        # below (None while walking substitution rules).
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError("ambiguous unification of '%s' with template '%s'"
                        % (expr, template))

            urec, = urecs

            expr_descriptors.append(
                    ExprDescriptor(
                        insn=insn,
                        expr=expr,
                        unif_var_dict={lhs.name: rhs
                            for lhs, rhs in urec.equations}))
        else:
            mapper.fallback_mapper(expr)
            # can't nest, don't recurse

    from loopy.symbolic import (
            CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    for insn in kernel.instructions:
        dfmapper(insn.expression)

    # Expressions inside substitution rules belong to no instruction. Reset
    # the variable so matches are not attributed to the last instruction
    # (and so it is defined even if the kernel has no instructions).
    insn = None
    for sr in kernel.substitutions.values():
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    def replace_exprs(expr, mapper):
        # Matches are recognized by object identity with the expressions
        # recorded while gathering.
        for exprd in expr_descriptors:
            if expr is exprd.expr:
                break
        else:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name]
                for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        return result
        # can't nest, don't recurse

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    new_insns = [
            insn.copy(expression=cbmapper(insn.expression))
            for insn in kernel.instructions]

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
            subst_name: SubstitutionRule(
                name=subst_name,
                arguments=tuple(parameters),
                expression=template,
                )}

    for subst in kernel.substitutions.values():
        new_substs[subst.name] = subst.copy(
                expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(
            instructions=new_insns,
            substitutions=new_substs)
Beispiel #15
0
def fix_parameters(kernel, within=None, **value_dict):
    """Pin the named parameters/arguments of *kernel* to constant values.

    Each keyword in *value_dict* is a *name*/*value* pair: *name* is
    fixed to *value*. *name* may refer to :ref:`domain-parameters` or
    :ref:`arguments`.

    :arg within: passed through ``loopy.match.parse_stack_match`` and
        handed to the rule-aware substitution of the instructions.
    :returns: *kernel* itself when no values are given; otherwise a copy
        with updated domains, arguments, temporary variables and
        assumptions.
    """

    if not value_dict:
        # nothing to fix
        return kernel

    def fix_in_set(isl_set):
        # For every fixed name that is a dimension of the set, add the
        # constraint name == value and then project that dimension out.
        for name, value in value_dict.items():
            dims = isl_set.get_var_dict()

            try:
                dim_type, dim_pos = dims[name]
            except KeyError:
                # name is not a dimension of this set; leave it untouched
                continue

            target_aff = isl.Aff.zero_on_domain(isl_set.space) + value

            from loopy.isl_helpers import iname_rel_aff
            equality = isl.Constraint.equality_from_aff(
                iname_rel_aff(isl_set.space, name, "==", target_aff))

            isl_set = isl_set.add_constraint(equality).project_out(
                dim_type, dim_pos, 1)

        return isl_set

    fixed_domains = [fix_in_set(dom) for dom in kernel.domains]

    from pymbolic.mapper.substitutor import make_subst_func
    subst_func = make_subst_func(value_dict)

    from loopy.symbolic import SubstitutionMapper, PartialEvaluationMapper
    substitute = SubstitutionMapper(subst_func)
    evaluate = PartialEvaluationMapper()

    def map_expr(expr):
        # substitute the fixed values, then partially evaluate the result
        return evaluate(substitute(expr))

    from loopy.kernel.array import ArrayBase
    # Arguments that were fixed are removed from the argument list;
    # array arguments get their expressions rewritten, everything else
    # is kept unchanged.
    remaining_args = [
        arg.map_exprs(map_expr) if isinstance(arg, ArrayBase) else arg
        for arg in kernel.args
        if arg.name not in value_dict]

    mapped_temp_vars = {
        tv.name: tv.map_exprs(map_expr)
        for tv in kernel.temporary_variables.values()}

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())
    rule_aware_subst = RuleAwareSubstitutionMapper(
        rule_mapping_context, subst_func, within=within)

    substituted_kernel = rule_mapping_context.finish_kernel(
        rule_aware_subst.map_kernel(kernel, within=within))

    return substituted_kernel.copy(
        domains=fixed_domains,
        args=remaining_args,
        temporary_variables=mapped_temp_vars,
        assumptions=fix_in_set(kernel.assumptions),
    )