예제 #1
0
파일: data.py 프로젝트: shwina/loopy
def tag_array_axes(knl, ary_names, dim_tags):
    """
    Apply the axis tags *dim_tags* to every array named in *ary_names*.

    *ary_names* is either an iterable of array names or one string of
    comma-separated names (whitespace around each name is stripped).
    For each array, *dim_tags* is parsed by
    :func:`loopy.kernel.array.parse_array_dim_tags` against that array's
    number of user axes, its target-axis count and its axis names, and the
    parsed tags replace the array's ``dim_tags``. Returns the new kernel.

    .. versionchanged:: 2016.2

        This function was called :func:`tag_data_axes` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger

    if isinstance(ary_names, str):
        ary_names = [name.strip() for name in ary_names.split(",")]

    for name in ary_names:
        changer = ArrayChanger(knl, name)
        array = changer.get()

        from loopy.kernel.array import parse_array_dim_tags
        parsed_tags = parse_array_dim_tags(
                dim_tags,
                n_axes=array.num_user_axes(),
                use_increasing_target_axes=array.max_target_axes > 1,
                dim_names=array.dim_names)

        knl = changer.with_changed_array(
                array.copy(dim_tags=tuple(parsed_tags)))

    return knl
예제 #2
0
파일: data.py 프로젝트: cmsquared/loopy
def tag_array_axes(knl, ary_names, dim_tags):
    """
    Re-tag the axes of one or more arrays of *knl*.

    :arg ary_names: an iterable of array names, or a single string of
        comma-separated names (whitespace around each name is ignored).
    :arg dim_tags: a tag specification, parsed separately for each array by
        :func:`loopy.kernel.array.parse_array_dim_tags`.
    :returns: the kernel with each named array's ``dim_tags`` replaced.

    .. versionchanged:: 2016.2

        This function was called :func:`tag_data_axes` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger

    def retag(kernel, name):
        # Parse *dim_tags* against this particular array, then swap it in.
        changer = ArrayChanger(kernel, name)
        ary = changer.get()

        from loopy.kernel.array import parse_array_dim_tags
        tags = tuple(parse_array_dim_tags(dim_tags,
                n_axes=ary.num_user_axes(),
                use_increasing_target_axes=ary.max_target_axes > 1,
                dim_names=ary.dim_names))
        return changer.with_changed_array(ary.copy(dim_tags=tags))

    if isinstance(ary_names, str):
        ary_names = [part.strip() for part in ary_names.split(",")]

    for name in ary_names:
        knl = retag(knl, name)

    return knl
예제 #3
0
파일: data.py 프로젝트: navjotk/loopy
def set_array_dim_names(kernel, ary_names, dim_names):
    """
    Set the axis names of every array named in *ary_names* to *dim_names*.

    :arg ary_names: an iterable of array names, or a single string of
        comma-separated names.
    :arg dim_names: a tuple of axis names, or a single string of
        comma-separated names (converted to a tuple).
    :returns: the kernel with the updated arrays.

    When given as strings, whitespace around each comma-separated name is
    stripped, so ``"a, b"`` refers to the arrays ``a`` and ``b``.
    """
    from loopy.kernel.tools import ArrayChanger

    if isinstance(ary_names, str):
        # Strip so that "a, b" names array "b", not the nonexistent " b".
        ary_names = [name.strip() for name in ary_names.split(",")]

    if isinstance(dim_names, str):
        dim_names = tuple(name.strip() for name in dim_names.split(","))

    for ary_name in ary_names:
        achng = ArrayChanger(kernel, ary_name)
        ary = achng.get()

        kernel = achng.with_changed_array(ary.copy(dim_names=dim_names))

    return kernel
예제 #4
0
파일: data.py 프로젝트: navjotk/loopy
def tag_data_axes(knl, ary_names, dim_tags):
    """
    Parse *dim_tags* for each array named in *ary_names* and store the
    result as that array's ``dim_tags``.

    :arg ary_names: an iterable of array names, or a single string of
        comma-separated names. Whitespace around each name is stripped.
    :arg dim_tags: a tag specification passed to
        :func:`loopy.kernel.array.parse_array_dim_tags` together with the
        array's number of user axes.
    :returns: the kernel with the updated arrays.
    """
    from loopy.kernel.tools import ArrayChanger
    from loopy.kernel.array import parse_array_dim_tags

    if isinstance(ary_names, str):
        # Strip so that "a, b" names array "b", not the nonexistent " b".
        ary_names = [name.strip() for name in ary_names.split(",")]

    for ary_name in ary_names:
        achng = ArrayChanger(knl, ary_name)
        ary = achng.get()

        new_dim_tags = parse_array_dim_tags(dim_tags,
                n_axes=ary.num_user_axes(),
                use_increasing_target_axes=ary.max_target_axes > 1)

        knl = achng.with_changed_array(ary.copy(dim_tags=tuple(new_dim_tags)))

    return knl
예제 #5
0
파일: data.py 프로젝트: shwina/loopy
def set_array_axis_names(kernel, ary_names, dim_names):
    """
    Set the axis names of every array named in *ary_names* to *dim_names*.

    :arg ary_names: an iterable of array names, or a single string of
        comma-separated names.
    :arg dim_names: a tuple of axis names, or a single string of
        comma-separated names (converted to a tuple).
    :returns: the kernel with the updated arrays.

    When given as strings, whitespace around each comma-separated name is
    stripped, so ``"a, b"`` refers to the arrays ``a`` and ``b``.

    .. versionchanged:: 2016.2

        This function was called :func:`set_array_dim_names` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger

    if isinstance(ary_names, str):
        # Strip so that "a, b" names array "b", not the nonexistent " b".
        ary_names = [name.strip() for name in ary_names.split(",")]

    if isinstance(dim_names, str):
        dim_names = tuple(name.strip() for name in dim_names.split(","))

    for ary_name in ary_names:
        achng = ArrayChanger(kernel, ary_name)
        ary = achng.get()

        kernel = achng.with_changed_array(ary.copy(dim_names=dim_names))

    return kernel
예제 #6
0
파일: data.py 프로젝트: cmsquared/loopy
def set_array_axis_names(kernel, ary_names, dim_names):
    """
    Set the axis names of every array named in *ary_names* to *dim_names*.

    :arg ary_names: an iterable of array names, or a single string of
        comma-separated names.
    :arg dim_names: a tuple of axis names, or a single string of
        comma-separated names (converted to a tuple).
    :returns: the kernel with the updated arrays.

    When given as strings, whitespace around each comma-separated name is
    stripped, so ``"a, b"`` refers to the arrays ``a`` and ``b``.

    .. versionchanged:: 2016.2

        This function was called :func:`set_array_dim_names` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger

    if isinstance(ary_names, str):
        # Strip so that "a, b" names array "b", not the nonexistent " b".
        ary_names = [name.strip() for name in ary_names.split(",")]

    if isinstance(dim_names, str):
        dim_names = tuple(name.strip() for name in dim_names.split(","))

    for ary_name in ary_names:
        achng = ArrayChanger(kernel, ary_name)
        ary = achng.get()

        kernel = achng.with_changed_array(ary.copy(dim_names=dim_names))

    return kernel
예제 #7
0
def tag_array_axes(kernel, ary_names, dim_tags):
    """
    Set the axis implementation tags of one or more arrays.

    :arg ary_names: an iterable of array names, or a single string of
        comma-separated names (whitespace around each name is ignored).
    :arg dim_tags: a tuple of
        :class:`loopy.kernel.array.ArrayDimImplementationTag` or a string that
        parses to one. See :func:`loopy.kernel.array.parse_array_dim_tags` for a
        description of the allowed string format.

        For example, *dim_tags* could be ``"N2,N0,N1"`` to determine
        that the second axis is the fastest-varying, the last is
        the next-fastest, and the first is the slowest.
    :returns: the kernel with each named array's ``dim_tags`` replaced.

    .. versionchanged:: 2016.2

        This function was called ``tag_data_axes`` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger

    names = (list(map(str.strip, ary_names.split(",")))
             if isinstance(ary_names, str) else ary_names)

    for name in names:
        changer = ArrayChanger(kernel, name)
        target = changer.get()

        from loopy.kernel.array import parse_array_dim_tags
        # The spec is parsed per array: axis count, target-axis count and
        # axis names may differ between the arrays being tagged.
        tags = tuple(parse_array_dim_tags(
            dim_tags,
            n_axes=target.num_user_axes(),
            use_increasing_target_axes=target.max_target_axes > 1,
            dim_names=target.dim_names))

        kernel = changer.with_changed_array(target.copy(dim_tags=tags))

    return kernel
예제 #8
0
파일: padding.py 프로젝트: inducer/loopy
def split_array_dim(kernel, arrays_and_axes, count, auto_split_inames=True,
        split_kwargs=None):
    """
    :arg arrays_and_axes: a list of tuples *(array, axis_nr)* indicating
        that the index in *axis_nr* should be split. The tuples may
        also be *(array, axis_nr, "F")*, indicating that the index will
        be split as it would be according to Fortran order.

        *array* may name a temporary variable or an argument.

        If *arrays_and_axes* is a :class:`tuple`, it is automatically
        wrapped in a list, to make single splits easier.

    :arg count: The group size to use in the split.
    :arg auto_split_inames: Whether to automatically split inames
        encountered in the specified indices.
    :arg split_kwargs: arguments to pass to :func:`loopy.split_iname`
    :returns: the transformed kernel (*kernel* itself if *count* is 1).
    :raises RuntimeError: if a split instruction is malformed, if the same
        array is listed more than once, if an order other than ``"C"`` or
        ``"F"`` is given, if an array's dim_tags are unknown or the split
        axis is not tagged fixed-stride, or (with *auto_split_inames*) if
        an access's index along the split axis is not a single iname.

    Note that splits on the corresponding inames are carried out implicitly.
    The inames may *not* be split beforehand. (There's no *really* good reason
    for this--this routine is just not smart enough to deal with this.)
    """

    if count == 1:
        return kernel

    if split_kwargs is None:
        split_kwargs = {}

    def insert_outer(seq, axis_nr, order, item):
        # Insert the new "outer" entry beside the inner entry that stays at
        # *axis_nr*: before it for C order, after it for Fortran order.
        if order == "F":
            seq.insert(axis_nr+1, item)
        elif order == "C":
            seq.insert(axis_nr, item)
        else:
            raise RuntimeError("order '%s' not understood" % (order,))

    # {{{ process input into array_to_rest

    # where "rest" is the non-array-name part of the input tuples
    # in arrays_and_axes
    def normalize_rest(rest):
        if len(rest) == 1:
            return (rest[0], "C")
        elif len(rest) == 2:
            return rest
        else:
            # Wrap *rest* in a 1-tuple: with a bare tuple, "%" would use its
            # elements as format arguments and raise TypeError instead.
            raise RuntimeError(
                    "split instruction '%s' not understood" % (rest,))

    if isinstance(arrays_and_axes, tuple):
        arrays_and_axes = [arrays_and_axes]

    array_to_rest = {
            tup[0]: normalize_rest(tup[1:]) for tup in arrays_and_axes}

    if len(arrays_and_axes) != len(array_to_rest):
        raise RuntimeError("cannot split multiple axes of the same variable")

    del arrays_and_axes

    # }}}

    # {{{ adjust arrays

    from loopy.kernel.tools import ArrayChanger
    from loopy.kernel.array import FixedStrideArrayDimTag
    from pytools import div_ceil

    for array_name, (axis, order) in array_to_rest.items():
        achng = ArrayChanger(kernel, array_name)
        ary = achng.get()

        # {{{ adjust shape

        new_shape = ary.shape
        if new_shape is not None:
            new_shape = list(new_shape)
            axis_len = new_shape[axis]
            new_shape[axis] = count
            insert_outer(new_shape, axis, order, div_ceil(axis_len, count))
            new_shape = tuple(new_shape)

        # }}}

        # {{{ adjust dim tags

        if ary.dim_tags is None:
            raise RuntimeError("dim_tags of '%s' are not known" % array_name)

        old_dim_tag = ary.dim_tags[axis]
        if not isinstance(old_dim_tag, FixedStrideArrayDimTag):
            raise RuntimeError("axis %d of '%s' is not tagged fixed-stride"
                    % (axis, array_name))

        # The outer index steps over whole groups of *count* inner entries.
        new_dim_tags = list(ary.dim_tags)
        insert_outer(new_dim_tags, axis, order,
                FixedStrideArrayDimTag(count*old_dim_tag.stride))
        new_dim_tags = tuple(new_dim_tags)

        # }}}

        # {{{ adjust dim_names

        new_dim_names = ary.dim_names
        if new_dim_names is not None:
            new_dim_names = list(new_dim_names)
            existing_name = new_dim_names[axis]
            new_dim_names[axis] = existing_name + "_inner"
            insert_outer(new_dim_names, axis, order, existing_name + "_outer")
            new_dim_names = tuple(new_dim_names)

        # }}}

        kernel = achng.with_changed_array(ary.copy(
            shape=new_shape, dim_tags=new_dim_tags, dim_names=new_dim_names))

    # }}}

    # {{{ rewrite accesses to the split arrays

    split_vars = {}

    var_name_gen = kernel.get_var_name_generator()

    def split_access_axis(expr):
        axis_nr, order = array_to_rest[expr.aggregate.name]

        idx = expr.index
        if not isinstance(idx, tuple):
            idx = (idx,)
        idx = list(idx)

        axis_idx = idx[axis_nr]

        if auto_split_inames:
            from pymbolic.primitives import Variable
            if not isinstance(axis_idx, Variable):
                raise RuntimeError("found access '%s' in which axis %d is not a "
                        "single variable--cannot split "
                        "(Have you tried to do the split yourself, manually, "
                        "beforehand? If so, you shouldn't.)"
                        % (expr, axis_nr))

            split_iname = axis_idx.name
            # Explicit check rather than assert, which -O would strip.
            if split_iname not in kernel.all_inames():
                raise RuntimeError("found access '%s' in which axis %d is "
                        "indexed by '%s', which is not an iname--cannot split"
                        % (expr, axis_nr, split_iname))

            try:
                outer_iname, inner_iname = split_vars[split_iname]
            except KeyError:
                outer_iname = var_name_gen(split_iname+"_outer")
                inner_iname = var_name_gen(split_iname+"_inner")
                split_vars[split_iname] = outer_iname, inner_iname

            inner_index = Variable(inner_iname)
            outer_index = Variable(outer_iname)

        else:
            from loopy.symbolic import simplify_using_aff
            inner_index = simplify_using_aff(kernel, axis_idx % count)
            outer_index = simplify_using_aff(kernel, axis_idx // count)

        idx[axis_nr] = inner_index
        # Insert relative to this access's own axis_nr -- not the leftover
        # loop variable of the array-adjusting loop above, which only holds
        # the axis of the last array processed.
        insert_outer(idx, axis_nr, order, outer_index)

        return expr.aggregate.index(tuple(idx))

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, var_name_gen)
    aash = ArrayAxisSplitHelper(rule_mapping_context,
            set(array_to_rest), split_access_axis)
    kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))

    # }}}

    if auto_split_inames:
        from loopy import split_iname
        for iname, (outer_iname, inner_iname) in split_vars.items():
            kernel = split_iname(kernel, iname, count,
                    outer_iname=outer_iname, inner_iname=inner_iname,
                    **split_kwargs)

    return kernel
예제 #9
0
파일: padding.py 프로젝트: inducer/loopy
def _split_array_axis_inner(kernel, array_name, axis_nr, count, order="C"):
    """
    Split axis *axis_nr* of array *array_name* into an inner axis of length
    *count* plus a new outer axis, and rewrite all accesses to the array.

    The outer axis goes before the inner one for ``order="C"`` and after it
    for ``order="F"``. An access index *i* along *axis_nr* becomes
    ``i % count`` (inner) and ``i // count`` (outer), each passed through
    :func:`loopy.symbolic.simplify_using_aff`. Returns *kernel* unchanged
    when *count* is 1.

    :raises RuntimeError: if *order* is neither ``"C"`` nor ``"F"``, if the
        array's dim_tags are unknown, or if the axis is not tagged
        fixed-stride.
    """
    if count == 1:
        return kernel

    def outer_at():
        # Position of the new outer entry relative to the inner entry
        # that stays at *axis_nr*.
        if order == "C":
            return axis_nr
        if order == "F":
            return axis_nr+1
        raise RuntimeError("order '%s' not understood" % order)

    # {{{ adjust arrays

    from loopy.kernel.tools import ArrayChanger

    changer = ArrayChanger(kernel, array_name)
    ary = changer.get()

    from pytools import div_ceil

    shape = ary.shape
    if shape is not None:
        shape = list(shape)
        full_len = shape[axis_nr]
        shape[axis_nr] = count
        n_groups = div_ceil(full_len, count)
        shape.insert(outer_at(), n_groups)
        shape = tuple(shape)

    if ary.dim_tags is None:
        raise RuntimeError("dim_tags of '%s' are not known" % array_name)
    dim_tags = list(ary.dim_tags)

    split_tag = ary.dim_tags[axis_nr]

    from loopy.kernel.array import FixedStrideArrayDimTag
    if not isinstance(split_tag, FixedStrideArrayDimTag):
        raise RuntimeError("axis %d of '%s' is not tagged fixed-stride"
                % (axis_nr, array_name))

    # The outer index steps over whole groups of *count* inner entries.
    group_stride = count*split_tag.stride
    dim_tags.insert(outer_at(), FixedStrideArrayDimTag(group_stride))
    dim_tags = tuple(dim_tags)

    dim_names = ary.dim_names
    if dim_names is not None:
        dim_names = list(dim_names)
        base_name = dim_names[axis_nr]
        dim_names[axis_nr] = base_name + "_inner"
        outer_label = base_name + "_outer"
        dim_names.insert(outer_at(), outer_label)
        dim_names = tuple(dim_names)

    kernel = changer.with_changed_array(ary.copy(
        shape=shape, dim_tags=dim_tags, dim_names=dim_names))

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    def split_access_axis(expr):
        index = expr.index
        index = list(index if isinstance(index, tuple) else (index,))

        whole = index[axis_nr]

        from loopy.symbolic import simplify_using_aff
        inner = simplify_using_aff(kernel, whole % count)
        outer = simplify_using_aff(kernel, whole // count)

        index[axis_nr] = inner
        index.insert(outer_at(), outer)

        return expr.aggregate.index(tuple(index))

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, var_name_gen)
    aash = ArrayAxisSplitHelper(rule_mapping_context,
            {array_name}, split_access_axis)
    return rule_mapping_context.finish_kernel(aash.map_kernel(kernel))
예제 #10
0
def split_array_dim(kernel,
                    arrays_and_axes,
                    count,
                    auto_split_inames=True,
                    split_kwargs=None):
    """
    :arg arrays_and_axes: a list of tuples *(array, axis_nr)* indicating
        that the index in *axis_nr* should be split. The tuples may
        also be *(array, axis_nr, "F")*, indicating that the index will
        be split as it would be according to Fortran order.

        *array* may name a temporary variable or an argument.

        If *arrays_and_axes* is a :class:`tuple`, it is automatically
        wrapped in a list, to make single splits easier.

    :arg count: The group size to use in the split.
    :arg auto_split_inames: Whether to automatically split inames
        encountered in the specified indices.
    :arg split_kwargs: arguments to pass to :func:`loopy.split_iname`
    :returns: the transformed kernel (*kernel* itself if *count* is 1).
    :raises RuntimeError: if a split instruction is malformed, if the same
        array is listed more than once, if an order other than ``"C"`` or
        ``"F"`` is given, if an array's dim_tags are unknown or the split
        axis is not tagged fixed-stride, or (with *auto_split_inames*) if
        an access's index along the split axis is not a single iname.

    Note that splits on the corresponding inames are carried out implicitly.
    The inames may *not* be split beforehand. (There's no *really* good reason
    for this--this routine is just not smart enough to deal with this.)
    """

    if count == 1:
        return kernel

    if split_kwargs is None:
        split_kwargs = {}

    def insert_outer(seq, axis_nr, order, item):
        # Insert the new "outer" entry beside the inner entry that stays at
        # *axis_nr*: before it for C order, after it for Fortran order.
        if order == "F":
            seq.insert(axis_nr + 1, item)
        elif order == "C":
            seq.insert(axis_nr, item)
        else:
            raise RuntimeError("order '%s' not understood" % (order, ))

    # {{{ process input into array_to_rest

    # where "rest" is the non-array-name part of the input tuples
    # in arrays_and_axes
    def normalize_rest(rest):
        if len(rest) == 1:
            return (rest[0], "C")
        elif len(rest) == 2:
            return rest
        else:
            # Wrap *rest* in a 1-tuple: with a bare tuple, "%" would use its
            # elements as format arguments and raise TypeError instead.
            raise RuntimeError("split instruction '%s' not understood" %
                               (rest, ))

    if isinstance(arrays_and_axes, tuple):
        arrays_and_axes = [arrays_and_axes]

    array_to_rest = {
        tup[0]: normalize_rest(tup[1:])
        for tup in arrays_and_axes
    }

    if len(arrays_and_axes) != len(array_to_rest):
        raise RuntimeError("cannot split multiple axes of the same variable")

    del arrays_and_axes

    # }}}

    # {{{ adjust arrays

    from loopy.kernel.tools import ArrayChanger
    from loopy.kernel.array import FixedStrideArrayDimTag
    from pytools import div_ceil

    for array_name, (axis, order) in array_to_rest.items():
        achng = ArrayChanger(kernel, array_name)
        ary = achng.get()

        # {{{ adjust shape

        new_shape = ary.shape
        if new_shape is not None:
            new_shape = list(new_shape)
            axis_len = new_shape[axis]
            new_shape[axis] = count
            insert_outer(new_shape, axis, order, div_ceil(axis_len, count))
            new_shape = tuple(new_shape)

        # }}}

        # {{{ adjust dim tags

        if ary.dim_tags is None:
            raise RuntimeError("dim_tags of '%s' are not known" % array_name)

        old_dim_tag = ary.dim_tags[axis]
        if not isinstance(old_dim_tag, FixedStrideArrayDimTag):
            raise RuntimeError("axis %d of '%s' is not tagged fixed-stride" %
                               (axis, array_name))

        # The outer index steps over whole groups of *count* inner entries.
        new_dim_tags = list(ary.dim_tags)
        insert_outer(new_dim_tags, axis, order,
                     FixedStrideArrayDimTag(count * old_dim_tag.stride))
        new_dim_tags = tuple(new_dim_tags)

        # }}}

        # {{{ adjust dim_names

        new_dim_names = ary.dim_names
        if new_dim_names is not None:
            new_dim_names = list(new_dim_names)
            existing_name = new_dim_names[axis]
            new_dim_names[axis] = existing_name + "_inner"
            insert_outer(new_dim_names, axis, order, existing_name + "_outer")
            new_dim_names = tuple(new_dim_names)

        # }}}

        kernel = achng.with_changed_array(
            ary.copy(shape=new_shape,
                     dim_tags=new_dim_tags,
                     dim_names=new_dim_names))

    # }}}

    # {{{ rewrite accesses to the split arrays

    split_vars = {}

    var_name_gen = kernel.get_var_name_generator()

    def split_access_axis(expr):
        axis_nr, order = array_to_rest[expr.aggregate.name]

        idx = expr.index
        if not isinstance(idx, tuple):
            idx = (idx, )
        idx = list(idx)

        axis_idx = idx[axis_nr]

        if auto_split_inames:
            from pymbolic.primitives import Variable
            if not isinstance(axis_idx, Variable):
                raise RuntimeError(
                    "found access '%s' in which axis %d is not a "
                    "single variable--cannot split "
                    "(Have you tried to do the split yourself, manually, "
                    "beforehand? If so, you shouldn't.)" % (expr, axis_nr))

            split_iname = axis_idx.name
            # Explicit check rather than assert, which -O would strip.
            if split_iname not in kernel.all_inames():
                raise RuntimeError(
                    "found access '%s' in which axis %d is indexed by '%s', "
                    "which is not an iname--cannot split" %
                    (expr, axis_nr, split_iname))

            try:
                outer_iname, inner_iname = split_vars[split_iname]
            except KeyError:
                outer_iname = var_name_gen(split_iname + "_outer")
                inner_iname = var_name_gen(split_iname + "_inner")
                split_vars[split_iname] = outer_iname, inner_iname

            inner_index = Variable(inner_iname)
            outer_index = Variable(outer_iname)

        else:
            from loopy.symbolic import simplify_using_aff
            inner_index = simplify_using_aff(kernel, axis_idx % count)
            outer_index = simplify_using_aff(kernel, axis_idx // count)

        idx[axis_nr] = inner_index
        # Insert relative to this access's own axis_nr -- not the leftover
        # loop variable of the array-adjusting loop above, which only holds
        # the axis of the last array processed.
        insert_outer(idx, axis_nr, order, outer_index)

        return expr.aggregate.index(tuple(idx))

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, var_name_gen)
    aash = ArrayAxisSplitHelper(rule_mapping_context,
                                set(array_to_rest), split_access_axis)
    kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))

    # }}}

    if auto_split_inames:
        from loopy import split_iname
        for iname, (outer_iname, inner_iname) in split_vars.items():
            kernel = split_iname(kernel,
                                 iname,
                                 count,
                                 outer_iname=outer_iname,
                                 inner_iname=inner_iname,
                                 **split_kwargs)

    return kernel
예제 #11
0
def _split_array_axis_inner(kernel, array_name, axis_nr, count, order="C"):
    """
    Split one axis of one array and rewrite every access to that array.

    :arg array_name: name of the array whose axis is split.
    :arg axis_nr: the axis to split; it keeps the inner part (length
        *count* in the shape) and a new outer axis is inserted beside it.
    :arg count: the inner group size; a value of 1 returns *kernel* as is.
    :arg order: ``"C"`` puts the outer axis before the inner one, ``"F"``
        puts it after.
    :returns: the transformed kernel. Each access index *i* along the axis
        becomes ``i % count`` (inner) and ``i // count`` (outer), simplified
        with :func:`loopy.symbolic.simplify_using_aff`.
    :raises RuntimeError: on an unknown *order*, unknown dim_tags, or a
        split axis that is not tagged fixed-stride.
    """
    if count == 1:
        return kernel

    def outer_position():
        # Where the new outer entry goes, relative to the inner entry that
        # stays at *axis_nr*.
        if order == "F":
            return axis_nr + 1
        if order == "C":
            return axis_nr
        raise RuntimeError("order '%s' not understood" % order)

    # {{{ adjust arrays

    from loopy.kernel.tools import ArrayChanger

    achng = ArrayChanger(kernel, array_name)
    ary = achng.get()

    from pytools import div_ceil

    # {{{ adjust shape

    new_shape = ary.shape
    if new_shape is not None:
        new_shape = list(new_shape)
        axis_len = new_shape[axis_nr]
        new_shape[axis_nr] = count
        outer_len = div_ceil(axis_len, count)
        new_shape.insert(outer_position(), outer_len)
        new_shape = tuple(new_shape)

    # }}}

    # {{{ adjust dim tags

    if ary.dim_tags is None:
        raise RuntimeError("dim_tags of '%s' are not known" % array_name)
    new_dim_tags = list(ary.dim_tags)

    from loopy.kernel.array import FixedStrideArrayDimTag

    old_dim_tag = ary.dim_tags[axis_nr]
    if not isinstance(old_dim_tag, FixedStrideArrayDimTag):
        raise RuntimeError("axis %d of '%s' is not tagged fixed-stride" %
                           (axis_nr, array_name))

    # The outer index steps over whole groups of *count* inner entries.
    outer_stride = count * old_dim_tag.stride
    new_dim_tags.insert(outer_position(),
                        FixedStrideArrayDimTag(outer_stride))
    new_dim_tags = tuple(new_dim_tags)

    # }}}

    # {{{ adjust dim_names

    new_dim_names = ary.dim_names
    if new_dim_names is not None:
        new_dim_names = list(new_dim_names)
        stem = new_dim_names[axis_nr]
        new_dim_names[axis_nr] = stem + "_inner"
        outer_name = stem + "_outer"
        new_dim_names.insert(outer_position(), outer_name)
        new_dim_names = tuple(new_dim_names)

    # }}}

    kernel = achng.with_changed_array(
        ary.copy(shape=new_shape,
                 dim_tags=new_dim_tags,
                 dim_names=new_dim_names))

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    def split_access_axis(expr):
        subscripts = expr.index
        if not isinstance(subscripts, tuple):
            subscripts = (subscripts, )
        subscripts = list(subscripts)

        original = subscripts[axis_nr]

        from loopy.symbolic import simplify_using_aff
        inner_index = simplify_using_aff(kernel, original % count)
        outer_index = simplify_using_aff(kernel, original // count)

        subscripts[axis_nr] = inner_index
        subscripts.insert(outer_position(), outer_index)

        return expr.aggregate.index(tuple(subscripts))

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, var_name_gen)
    helper = ArrayAxisSplitHelper(rule_mapping_context, {array_name},
                                  split_access_axis)
    return rule_mapping_context.finish_kernel(helper.map_kernel(kernel))