Exemplo n.º 1
0
def precompute(kernel, subst_use, sweep_inames=[], within=None,
        storage_axes=None, temporary_name=None, precompute_inames=None,
        storage_axis_to_tag={}, default_tag="l.auto", dtype=None,
        fetch_bounding_box=False, temporary_is_local=None,
        insn_id=None):
    """Precompute the expression described in the substitution rule determined by
    *subst_use* and store it in a temporary array. A precomputation needs two
    things to operate, a list of *sweep_inames* (order irrelevant) and an
    ordered list of *storage_axes* (whose order will describe the axis ordering
    of the temporary array).

    :arg subst_use: Describes what to prefetch.

    The following objects may be given for *subst_use*:

    * The name of the substitution rule.

    * The tagged name ("name$tag") of the substitution rule.

    * A list of invocations of the substitution rule.
      This list of invocations, when swept across *sweep_inames*, then serves
      to define the footprint of the precomputation.

      Invocations may be tagged ("name$tag") to filter out a subset of the
      usage sites of the substitution rule. (Namely those usage sites that
      use the same tagged name.)

      Invocations may be given as a string or as a
      :class:`pymbolic.primitives.Expression` object.

      If only one invocation is to be given, then the only entry of the list
      may be given directly.

    If the list of invocations generating the footprint is not given,
    all (tag-matching, if desired) usage sites of the substitution rule
    are used to determine the footprint.

    The following cases can arise for each sweep axis:

    * The axis is an iname that occurs within arguments specified at
      usage sites of the substitution rule. This case is assumed covered
      by the storage axes provided for the argument.

    * The axis is an iname that occurs within the *value* of the rule, but not
      within its arguments. A new, dedicated storage axis is allocated for
      such an axis.

    :arg sweep_inames: A :class:`list` of inames and/or rule argument
        names to be swept.
        May also equivalently be a comma-separated string.
    :arg storage_axes: A :class:`list` of inames and/or rule argument
        names/indices to be used as storage axes.
        May also equivalently be a comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.
    :arg temporary_name:
        The temporary variable name to use for storing the precomputed data.
        If it does not exist, it will be created. If it does exist, its properties
        (such as size, type) are checked (and updated, if possible) to match
        its use.
    :arg precompute_inames:
        If the specified inames do not already exist, they will be
        created. If they do already exist, their loop domain is verified
        against the one required for this precomputation.
        May also equivalently be a comma-separated string.
    :arg storage_axis_to_tag: A mapping used to look up the tag of each newly
        created storage-axis iname. The lookup key is the entry of
        *storage_axes* as given, or the matching entry of *precompute_inames*
        if one was supplied for that axis. This mapping is only read.
    :arg default_tag: The tag (parsed by :func:`loopy.kernel.data.parse_tag`)
        given to newly created storage-axis inames that have no entry in
        *storage_axis_to_tag*.
    :arg dtype: The dtype of the temporary. *None* means ``loopy.auto``;
        anything else is converted using :class:`numpy.dtype`.
    :arg fetch_bounding_box: Passed on as *boxify_sweep* when the loop domain
        is augmented with the storage axes.
    :arg temporary_is_local: The ``is_local`` property of the temporary.
        *None* means ``loopy.auto``.
    :arg insn_id: The ID of the instruction performing the precomputation.

    If `storage_axes` is not specified, it defaults to the arrangement
    `<direct sweep axes><arguments>` with the direct sweep axes being the
    slower-varying indices.

    Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are
    eliminated.

    :returns: a copy of *kernel* containing the compute instruction, the
        rewritten rule invocations, the (new or adapted) temporary and the
        tags for the newly created inames.

    :raises RuntimeError: for unknown or unreferenced sweep inames, for a
        conflicting new storage axis name, if no invocation of the rule is
        found, or if a sweep iname is not at home in the sweep's leaf domain.
    :raises ValueError: if *subst_use* entries are malformed or disagree on
        rule name and tag.
    :raises LoopyError: if preexisting precompute inames or a preexisting
        temporary are incompatible with the requested precomputation.
    """

    # NOTE: the mutable defaults (sweep_inames, storage_axis_to_tag) are never
    # mutated below: the former is copied into a fresh list, the latter is only
    # read via .get().

    # {{{ check, standardize arguments

    if isinstance(sweep_inames, str):
        sweep_inames = [iname.strip() for iname in sweep_inames.split(",")]

    for iname in sweep_inames:
        if iname not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname"
                    % iname)

    sweep_inames = list(sweep_inames)
    sweep_inames_set = frozenset(sweep_inames)

    if isinstance(storage_axes, str):
        storage_axes = [ax.strip() for ax in storage_axes.split(",")]

    if isinstance(precompute_inames, str):
        precompute_inames = [iname.strip() for iname in precompute_inames.split(",")]

    from pymbolic.primitives import Variable, Call
    from loopy.symbolic import parse, TaggedVariable

    # A single use (string, rule name or invocation expression) may be given
    # directly instead of as a one-element list.
    if isinstance(subst_use, (str, Variable, Call)):
        subst_use = [subst_use]

    footprint_generators = None

    subst_name = None
    subst_tag = None

    for use in subst_use:
        if isinstance(use, str):
            use = parse(use)

        if isinstance(use, Call):
            # Only actual invocations contribute to the footprint; a bare
            # (tagged) name merely identifies the rule.
            if footprint_generators is None:
                footprint_generators = []

            footprint_generators.append(use)
            subst_name_as_expr = use.function
        else:
            subst_name_as_expr = use

        if isinstance(subst_name_as_expr, TaggedVariable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = subst_name_as_expr.tag
        elif isinstance(subst_name_as_expr, Variable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = None
        else:
            raise ValueError("unexpected type of subst_name")

        if (subst_name, subst_tag) == (None, None):
            subst_name, subst_tag = new_subst_name, new_subst_tag
        else:
            if (subst_name, subst_tag) != (new_subst_name, new_subst_tag):
                raise ValueError("not all uses in subst_use agree "
                        "on rule name and tag")

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    from loopy.kernel.data import parse_tag
    default_tag = parse_tag(default_tag)

    subst = kernel.substitutions[subst_name]
    c_subst_name = subst_name.replace(".", "_")

    # }}}

    # {{{ process invocations in footprint generators, start access_descriptors

    if footprint_generators:
        access_descriptors = []

        for fpg in footprint_generators:
            if isinstance(fpg, Variable):
                args = ()
            elif isinstance(fpg, Call):
                args = fpg.parameters
            else:
                raise ValueError("footprint generator must "
                        "be substitution rule invocation")

            access_descriptors.append(
                    RuleAccessDescriptor(
                        identifier=access_descriptor_id(args, None),
                        args=args
                        ))

    # }}}

    # {{{ gather up invocations in kernel code, finish access_descriptors

    if not footprint_generators:
        rule_mapping_context = SubstitutionRuleMappingContext(
                kernel.substitutions, kernel.get_var_name_generator())
        invg = RuleInvocationGatherer(
                rule_mapping_context, kernel, subst_name, subst_tag, within)
        del rule_mapping_context

        import loopy as lp
        for insn in kernel.instructions:
            if isinstance(insn, lp.ExpressionInstruction):
                invg(insn.assignee, kernel, insn)
                invg(insn.expression, kernel, insn)

        access_descriptors = invg.access_descriptors
        if not access_descriptors:
            raise RuntimeError("no invocations of '%s' found" % subst_name)

    # }}}

    # {{{ find inames used in arguments

    expanding_usage_arg_deps = set()

    for accdesc in access_descriptors:
        for arg in accdesc.args:
            expanding_usage_arg_deps.update(
                    get_dependencies(arg) & kernel.all_inames())

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    # {{{ use given / find new storage_axes

    # extra axes made necessary because they don't occur in the arguments
    extra_storage_axes = sweep_inames_set - expanding_usage_arg_deps

    from loopy.symbolic import SubstitutionRuleExpander
    submap = SubstitutionRuleExpander(kernel.substitutions)

    value_inames = get_dependencies(
            submap(subst.expression)
            ) & kernel.all_inames()
    if value_inames - expanding_usage_arg_deps < extra_storage_axes:
        raise RuntimeError("unreferenced sweep inames specified: "
                + ", ".join(extra_storage_axes
                    - value_inames - expanding_usage_arg_deps))

    new_iname_to_tag = {}

    if storage_axes is None:
        storage_axes = (
                list(extra_storage_axes)
                + list(range(len(subst.arguments))))

    prior_storage_axis_name_dict = {}

    storage_axis_names = []
    storage_axis_sources = []  # number for arg#, or iname

    # {{{ check for pre-existing precompute_inames

    if precompute_inames is not None:
        preexisting_precompute_inames = (
                set(precompute_inames) & kernel.all_inames())
    else:
        preexisting_precompute_inames = set()

    # }}}

    for i, saxis in enumerate(storage_axes):
        tag_lookup_saxis = saxis

        if saxis in subst.arguments:
            saxis = subst.arguments.index(saxis)

        storage_axis_sources.append(saxis)

        if isinstance(saxis, int):
            # argument index
            name = old_name = subst.arguments[saxis]
        else:
            old_name = saxis
            name = "%s_%s" % (c_subst_name, old_name)

        if (precompute_inames is not None
                and i < len(precompute_inames)
                and precompute_inames[i]):
            # user-supplied iname for this axis: use it verbatim
            name = precompute_inames[i]
            tag_lookup_saxis = name
            if (name not in preexisting_precompute_inames
                    and var_name_gen.is_name_conflicting(name)):
                raise RuntimeError("new storage axis name '%s' "
                        "conflicts with existing name" % name)
        else:
            name = var_name_gen(name)

        storage_axis_names.append(name)
        # Only inames created here get tagged; preexisting ones keep their tag.
        if name not in preexisting_precompute_inames:
            new_iname_to_tag[name] = storage_axis_to_tag.get(
                    tag_lookup_saxis, default_tag)

        prior_storage_axis_name_dict[name] = old_name

    del storage_axis_to_tag
    del storage_axes
    del precompute_inames

    # }}}

    # {{{ fill out access_descriptors[...].storage_axis_exprs

    access_descriptors = [
            accdesc.copy(
                storage_axis_exprs=storage_axis_exprs(
                    storage_axis_sources, accdesc.args))
            for accdesc in access_descriptors]

    # }}}

    expanding_inames = sweep_inames_set | frozenset(expanding_usage_arg_deps)
    assert expanding_inames <= kernel.all_inames()

    if storage_axis_names:
        # {{{ find domain to be changed

        change_inames = expanding_inames | preexisting_precompute_inames

        from loopy.kernel.tools import DomainChanger
        domch = DomainChanger(kernel, change_inames)

        if domch.leaf_domain_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.

            for iname in sweep_inames_set:
                if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
                    raise RuntimeError("sweep iname '%s' is not 'at home' in the "
                            "sweep's leaf domain" % iname)

        # }}}

        abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames,
                access_descriptors, len(storage_axis_names))

        non1_storage_axis_names = []
        for i, saxis in enumerate(storage_axis_names):
            if abm.non1_storage_axis_flags[i]:
                non1_storage_axis_names.append(saxis)
            else:
                # Check this first: preexisting inames were never entered
                # into new_iname_to_tag, so removing them from it beforehand
                # would fail with a KeyError and hide this diagnostic.
                if saxis in preexisting_precompute_inames:
                    raise LoopyError("precompute axis %d (1-based) was "
                            "eliminated as "
                            "having length 1 but also mapped to existing "
                            "iname '%s'" % (i+1, saxis))

                new_iname_to_tag.pop(saxis, None)

        mod_domain = domch.domain

        # {{{ modify the domain, taking into account preexisting inames

        # inames may already exist in mod_domain, add them primed to start
        primed_non1_saxis_names = [
                iname+"'" for iname in non1_storage_axis_names]

        mod_domain = abm.augment_domain_with_sweep(
            domch.domain, primed_non1_saxis_names,
            boxify_sweep=fetch_bounding_box)

        check_domain = mod_domain

        for i, saxis in enumerate(non1_storage_axis_names):
            # refreshed every iteration because project_out below shifts
            # dimension indices
            var_dict = mod_domain.get_var_dict(isl.dim_type.set)

            if saxis in preexisting_precompute_inames:
                # add equality constraint between existing and new variable

                dt, dim_idx = var_dict[saxis]
                saxis_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                # Must be an equality: an inequality would leave the existing
                # iname unconstrained once the primed one is projected out.
                mod_domain = mod_domain.add_constraint(
                        isl.Constraint.equality_from_aff(new_var_aff - saxis_aff))

                # project out the new one
                mod_domain = mod_domain.project_out(dt, dim_idx, 1)

            else:
                # remove the prime from the new variable
                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis)

        # {{{ check that we got the desired domain

        check_domain = check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain = mod_domain

        # re-add the prime from the new variable
        var_dict = mod_check_domain.get_var_dict(isl.dim_type.set)

        for saxis in non1_storage_axis_names:
            dt, dim_idx = var_dict[saxis]
            mod_check_domain = mod_check_domain.set_dim_name(dt, dim_idx, saxis+"'")

        mod_check_domain = mod_check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("domain of preexisting inames does not match "
                    "domain needed for precompute")

        # }}}

        # {{{ check that we didn't shrink the original domain

        # project out the new names from the modified domain
        orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set))
        mod_check_domain = mod_domain.project_out_except(
                orig_domain_inames, [isl.dim_type.set])

        check_domain = domch.domain

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("original domain got shrunk by applying the precompute")

        # }}}

        # }}}

        new_kernel_domains = domch.get_domains_with(mod_domain)

    else:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains

        non1_storage_axis_names = []
        abm = NoOpArrayToBufferMap()

    kernel = kernel.copy(domains=new_kernel_domains)

    # {{{ set up compute insn

    if temporary_name is None:
        temporary_name = var_name_gen(based_on=c_subst_name)

    assignee = var(temporary_name)

    if non1_storage_axis_names:
        assignee = assignee.index(
                tuple(var(iname) for iname in non1_storage_axis_names))

    # {{{ process substitutions on compute instruction

    storage_axis_subst_dict = {}

    for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices):
        if arg_name in non1_storage_axis_names:
            arg = var(arg_name)
        else:
            # eliminated (length-1) axis: only the base index remains
            arg = 0

        storage_axis_subst_dict[
                prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())

    expr_subst_map = RuleAwareSubstitutionMapper(
            rule_mapping_context,
            make_subst_func(storage_axis_subst_dict),
            within=parse_stack_match(None))

    compute_expression = expr_subst_map(subst.expression, kernel, None)

    # }}}

    from loopy.kernel.data import ExpressionInstruction
    if insn_id is None:
        insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name)

    compute_insn = ExpressionInstruction(
            id=insn_id,
            assignee=assignee,
            expression=compute_expression)

    # }}}

    # {{{ substitute rule into expressions in kernel (if within footprint)

    invr = RuleInvocationReplacer(rule_mapping_context,
            subst_name, subst_tag, within,
            access_descriptors, abm,
            storage_axis_names, storage_axis_sources,
            non1_storage_axis_names,
            temporary_name)

    kernel = invr.map_kernel(kernel)
    kernel = kernel.copy(
            instructions=[compute_insn] + kernel.instructions)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # {{{ set up temp variable

    import loopy as lp
    if dtype is None:
        dtype = lp.auto
    else:
        dtype = np.dtype(dtype)

    if temporary_is_local is None:
        temporary_is_local = lp.auto

    new_temp_shape = tuple(abm.non1_storage_shape)

    new_temporary_variables = kernel.temporary_variables.copy()
    if temporary_name not in new_temporary_variables:
        temp_var = lp.TemporaryVariable(
                name=temporary_name,
                dtype=dtype,
                base_indices=(0,)*len(new_temp_shape),
                shape=new_temp_shape,
                is_local=temporary_is_local)

    else:
        temp_var = new_temporary_variables[temporary_name]

        # {{{ check and adapt existing temporary

        if temp_var.dtype is lp.auto:
            # existing temporary has no fixed dtype: the requested one wins
            pass
        elif temp_var.dtype is not lp.auto and dtype is lp.auto:
            dtype = temp_var.dtype
        elif temp_var.dtype is not lp.auto and dtype is not lp.auto:
            if temp_var.dtype != dtype:
                raise LoopyError("Existing and new dtype of temporary '%s' "
                        "do not match (existing: %s, new: %s)"
                        % (temporary_name, temp_var.dtype, dtype))

        temp_var = temp_var.copy(dtype=dtype)

        if len(temp_var.shape) != len(new_temp_shape):
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching number of dimensions "
                    "(existing: %d, new: %d)"
                    % (temporary_name,
                        len(temp_var.shape), len(new_temp_shape)))

        if temp_var.base_indices != (0,) * len(new_temp_shape):
            raise LoopyError("Existing temporary '%s' does not have "
                    "all-zero base indices (existing: %s)"
                    % (temporary_name, temp_var.base_indices))

        # grow the existing temporary as needed to hold the new footprint
        new_temp_shape = tuple(
                max(i, ex_i)
                for i, ex_i in zip(new_temp_shape, temp_var.shape))

        temp_var = temp_var.copy(shape=new_temp_shape)

        if temporary_is_local == temp_var.is_local:
            pass
        elif temporary_is_local is lp.auto:
            temporary_is_local = temp_var.is_local
        elif temp_var.is_local is lp.auto:
            pass
        else:
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching values of 'is_local' "
                    "(existing: %s, new: %s)"
                    % (temporary_name,
                        temp_var.is_local, temporary_is_local))

        temp_var = temp_var.copy(is_local=temporary_is_local)

        # }}}

    new_temporary_variables[temporary_name] = temp_var

    kernel = kernel.copy(
            temporary_variables=new_temporary_variables)

    # }}}

    from loopy import tag_inames
    return tag_inames(kernel, new_iname_to_tag)
Exemplo n.º 2
0
def buffer_array(
    kernel,
    var_name,
    buffer_inames,
    init_expression=None,
    store_expression=None,
    within=None,
    default_tag="l.auto",
    temporary_is_local=None,
    fetch_bounding_box=False,
):
    """Route accesses to *var_name* in matching instructions through a new
    temporary buffer.

    An instruction initializing the buffer is always added. An instruction
    storing the buffer back into *var_name* is added only if one of the
    modified instructions writes to the buffer.

    :arg var_name: name of a kernel argument or temporary variable.
    :arg buffer_inames: inames across which the buffer is swept, as a list
        or a comma-separated string. Each must be a known iname.
    :arg init_expression: Either *None* (indicating the prior value of the buffered
        array should be read) or an expression optionally involving the
        variable 'base' (which references the associated location in the array
        being buffered). May be given as a string.
    :arg store_expression: Either *None* or an expression involving
        variables 'base' and 'buffer' (without array indices). May be given
        as a string.
    :arg within: a stack match, parsed with
        :func:`loopy.context_matching.parse_stack_match`.
    :arg default_tag: tag given to the newly created init/store inames.
    :arg temporary_is_local: ``is_local`` of the buffer temporary; *None*
        means ``loopy.auto``.
    :arg fetch_bounding_box: passed as *boxify_sweep* when the loop domain is
        augmented with the init/store inames.

    :returns: the transformed kernel.
    :raises RuntimeError: for an unknown buffer iname, or one that is not at
        home in the sweep's leaf domain.
    :raises ValueError: if *var_name* is neither an argument nor a temporary.
    """

    import loopy as lp
    from loopy.context_matching import parse_stack_match
    from loopy.kernel.data import ArrayBase, ExpressionInstruction

    def parse_if_string(expr):
        # Strings are parsed into expressions, anything else passes through.
        if not isinstance(expr, str):
            return expr

        from loopy.symbolic import parse

        return parse(expr)

    # {{{ process arguments

    init_expression = parse_if_string(init_expression)
    store_expression = parse_if_string(store_expression)

    if isinstance(buffer_inames, str):
        stripped = (piece.strip() for piece in buffer_inames.split(","))
        buffer_inames = [piece for piece in stripped if piece]

    for candidate in buffer_inames:
        if candidate not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname" % candidate)

    buffer_inames = list(buffer_inames)
    buffer_inames_set = frozenset(buffer_inames)

    within = parse_stack_match(within)

    if var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    elif var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    else:
        raise ValueError("variable '%s' not found" % var_name)

    # Anything that is not an array is treated as having an empty shape.
    var_shape = var_descr.shape if isinstance(var_descr, ArrayBase) else ()
    ndim = len(var_shape)

    if temporary_is_local is None:
        temporary_is_local = lp.auto

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    # {{{ collect write accesses to var_name in matching instructions

    within_inames = set()
    access_descriptors = []

    for candidate_insn in kernel.instructions:
        if within(kernel, candidate_insn.id, ()):
            for written_name, subscript in candidate_insn.assignees_and_indices():
                if written_name != var_name:
                    continue

                subscript_inames = get_dependencies(subscript) & kernel.all_inames()
                within_inames.update(subscript_inames - buffer_inames_set)
                access_descriptors.append(
                    AccessDescriptor(identifier=candidate_insn.id, storage_axis_exprs=subscript)
                )

    # }}}

    # {{{ find fetch/store inames

    init_inames = []
    store_inames = []
    new_iname_to_tag = {}

    for axis in range(ndim):
        # Generation order (init first, then store, per axis) determines
        # the names handed out by the generator.
        fresh_init = var_name_gen("%s_init_%d" % (var_name, axis))
        fresh_store = var_name_gen("%s_store_%d" % (var_name, axis))

        init_inames.append(fresh_init)
        store_inames.append(fresh_store)

        new_iname_to_tag[fresh_init] = default_tag
        new_iname_to_tag[fresh_store] = default_tag

    # }}}

    # {{{ modify loop domain

    non1_init_inames = []
    non1_store_inames = []

    if not var_shape:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains
        abm = NoOpArrayToBufferMap()
    else:
        from loopy.kernel.tools import DomainChanger

        domch = DomainChanger(kernel, buffer_inames_set | within_inames)
        leaf_index = domch.leaf_domain_index

        if leaf_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.
            for swept in buffer_inames_set:
                if kernel.get_home_domain_index(swept) != domch.leaf_domain_index:
                    raise RuntimeError(
                        "buffer iname '%s' is not 'at home' in the sweep's leaf domain" % swept
                    )

        abm = ArrayToBufferMap(kernel, domch.domain, buffer_inames, access_descriptors, ndim)

        for axis, (axis_init, axis_store) in enumerate(zip(init_inames, store_inames)):
            if abm.non1_storage_axis_flags[axis]:
                non1_init_inames.append(axis_init)
                non1_store_inames.append(axis_store)
            else:
                # eliminated axes get no iname, hence no tag
                del new_iname_to_tag[axis_init]
                del new_iname_to_tag[axis_store]

        with_init = abm.augment_domain_with_sweep(
            domch.domain, non1_init_inames, boxify_sweep=fetch_bounding_box
        )
        with_init_and_store = abm.augment_domain_with_sweep(
            with_init, non1_store_inames, boxify_sweep=fetch_bounding_box
        )
        new_kernel_domains = domch.get_domains_with(with_init_and_store)

    # }}}

    def indexed_by(expr, inames):
        # Subscript expr with the given inames, or leave it bare if there are none.
        if not inames:
            return expr
        return expr.index(tuple(var(name) for name in inames))

    def array_subscript(non1_inames):
        # Per-axis index into the original array: the base index of the
        # buffer footprint, plus the loop iname on axes that were kept.
        remaining = iter(non1_inames)
        result = []
        for axis in range(ndim):
            ax_subscript = abm.storage_base_indices[axis]
            if abm.non1_storage_axis_flags[axis]:
                ax_subscript += var(next(remaining))
            result.append(ax_subscript)
        return result

    # {{{ set up temp variable

    buf_var_name = var_name_gen(based_on=var_name + "_buf")
    buffer_shape = abm.non1_storage_shape

    new_temporary_variables = kernel.temporary_variables.copy()
    new_temporary_variables[buf_var_name] = lp.TemporaryVariable(
        name=buf_var_name,
        dtype=var_descr.dtype,
        base_indices=(0,) * len(buffer_shape),
        shape=tuple(buffer_shape),
        is_local=temporary_is_local,
    )

    # }}}

    buf_var = var(buf_var_name)

    # {{{ generate init instruction

    buf_var_init = indexed_by(buf_var, non1_init_inames)

    init_base = var(var_name)
    init_subscript = array_subscript(non1_init_inames)
    if init_subscript:
        init_base = init_base.index(tuple(init_subscript))

    if init_expression is None:
        init_expression = init_base
    else:
        replace_base = SubstitutionMapper(make_subst_func({"base": init_base}))
        init_expression = replace_base(init_expression)

    init_insn_id = kernel.make_unique_instruction_id(based_on="init_" + var_name)

    init_instruction = ExpressionInstruction(
        id=init_insn_id,
        assignee=buf_var_init,
        expression=init_expression,
        forced_iname_deps=frozenset(within_inames),
        insn_deps=frozenset(),
        insn_deps_is_final=True,
    )

    # }}}

    # {{{ redirect accesses in the kernel to the buffer

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator()
    )
    aar = ArrayAccessReplacer(rule_mapping_context, var_name, within, abm, buf_var)
    kernel = aar.map_kernel(kernel)
    kernel = rule_mapping_context.finish_kernel(kernel)

    did_write = False
    for modified_id in aar.modified_insn_ids:
        written = kernel.id_to_insn[modified_id].assignees_and_indices()
        for written_name, _ in written:
            if written_name == buf_var_name:
                did_write = True

    # }}}

    # {{{ add init_insn_id to insn_deps

    init_dep = frozenset([init_insn_id])

    new_insns = [
        existing.copy(
            insn_deps=(
                (existing.insn_deps if existing.insn_deps is not None else frozenset())
                | init_dep
            )
        )
        if existing.id in aar.modified_insn_ids
        else existing
        for existing in kernel.instructions
    ]

    # }}}

    # {{{ generate store instruction

    buf_var_store = indexed_by(buf_var, non1_store_inames)

    store_target = var(var_name)
    store_subscript = array_subscript(non1_store_inames)
    if store_subscript:
        store_target = store_target.index(tuple(store_subscript))

    if store_expression is None:
        store_expression = buf_var_store
    else:
        replace_names = SubstitutionMapper(
            make_subst_func({"base": store_target, "buffer": buf_var_store})
        )
        store_expression = replace_names(store_expression)

    # The store instruction (and its id) is built even if it ends up unused.
    store_insn_id = kernel.make_unique_instruction_id(based_on="store_" + var_name)
    store_instruction = ExpressionInstruction(
        id=store_insn_id,
        insn_deps=frozenset(aar.modified_insn_ids),
        assignee=store_target,
        expression=store_expression,
        forced_iname_deps=frozenset(within_inames),
    )

    # }}}

    new_insns.append(init_instruction)
    if did_write:
        new_insns.append(store_instruction)

    kernel = kernel.copy(
        domains=new_kernel_domains,
        instructions=new_insns,
        temporary_variables=new_temporary_variables,
    )

    from loopy import tag_inames

    return tag_inames(kernel, new_iname_to_tag)