# NOTE(review): tag_array_axes is defined again further down in this module
# with the same body; that later definition rebinds the name, so this one is
# unreachable at runtime. It is kept consistent with the later definition
# here, but it is a candidate for deletion.
def tag_array_axes(kernel, ary_names, dim_tags):
    """Set the dimension implementation tags of one or more arrays.

    :arg ary_names: names of the arrays to re-tag, either as an iterable
        of names or as one comma-separated string (whitespace around each
        name is stripped).
    :arg dim_tags: a tuple of
        :class:`loopy.kernel.array.ArrayDimImplementationTag` or a string
        that parses to one, as interpreted by
        :func:`loopy.kernel.array.parse_array_dim_tags`.
    :returns: the kernel obtained by updating each named array in turn.

    .. versionchanged:: 2016.2

        This function was called ``tag_data_axes`` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger
    from loopy.kernel.array import parse_array_dim_tags

    if isinstance(ary_names, str):
        ary_names = [ary_name.strip() for ary_name in ary_names.split(",")]

    for ary_name in ary_names:
        achng = ArrayChanger(kernel, ary_name)
        ary = achng.get()

        # parse_array_dim_tags receives per-array information (axis count,
        # target axes, axis names), so parsing has to happen per array.
        new_dim_tags = parse_array_dim_tags(
                dim_tags,
                n_axes=ary.num_user_axes(),
                use_increasing_target_axes=ary.max_target_axes > 1,
                dim_names=ary.dim_names)

        kernel = achng.with_changed_array(
                ary.copy(dim_tags=tuple(new_dim_tags)))

    return kernel
def set_array_axis_names(kernel, ary_names, dim_names):
    """Set the axis names of one or more arrays.

    :arg ary_names: names of the arrays to modify, either as an iterable
        of names or as one comma-separated string such as ``"a, b"``
        (whitespace around each name is stripped).
    :arg dim_names: the new axis names, either as a sequence or as one
        comma-separated string such as ``"i, j"`` (whitespace around each
        name is stripped). Non-``None`` values are stored as a tuple.
    :returns: the kernel obtained by updating each named array in turn.

    .. versionchanged:: 2016.2

        This function was called ``set_array_dim_names`` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger

    if isinstance(ary_names, str):
        # Strip like tag_array_axes does, so that "a, b" finds array "b"
        # rather than " b".
        ary_names = [ary_name.strip() for ary_name in ary_names.split(",")]

    if isinstance(dim_names, str):
        dim_names = tuple(
                dim_name.strip() for dim_name in dim_names.split(","))
    elif dim_names is not None:
        # Normalize any sequence (e.g. a list) to an immutable tuple, the
        # same container split_array_dim stores for dim_names.
        dim_names = tuple(dim_names)

    for ary_name in ary_names:
        achng = ArrayChanger(kernel, ary_name)
        ary = achng.get()
        ary = ary.copy(dim_names=dim_names)
        kernel = achng.with_changed_array(ary)

    return kernel
def tag_array_axes(kernel, ary_names, dim_tags):
    """Re-tag the axes of the named arrays with *dim_tags*.

    :arg ary_names: an iterable of array names, or a single string of
        comma-separated names (each name has surrounding whitespace
        removed).
    :arg dim_tags: a tuple of
        :class:`loopy.kernel.array.ArrayDimImplementationTag` or a string
        that parses to one. See
        :func:`loopy.kernel.array.parse_array_dim_tags` for a description of
        the allowed string format.

        For example, *dim_tags* could be ``"N2,N0,N1"`` to determine
        that the second axis is the fastest-varying, the last is
        the next-fastest, and the first is the slowest.
    :returns: the kernel obtained by re-tagging each named array in turn.

    .. versionchanged:: 2016.2

        This function was called ``tag_data_axes`` before version 2016.2.
    """
    from loopy.kernel.tools import ArrayChanger
    from loopy.kernel.array import parse_array_dim_tags

    names = ([part.strip() for part in ary_names.split(",")]
             if isinstance(ary_names, str) else ary_names)

    for name in names:
        changer = ArrayChanger(kernel, name)
        array = changer.get()

        # Tags are parsed against this particular array's axis count,
        # target-axis count and axis names.
        parsed_tags = parse_array_dim_tags(
                dim_tags,
                n_axes=array.num_user_axes(),
                use_increasing_target_axes=array.max_target_axes > 1,
                dim_names=array.dim_names)

        retagged = array.copy(dim_tags=tuple(parsed_tags))
        kernel = changer.with_changed_array(retagged)

    return kernel
def split_array_dim(kernel, arrays_and_axes, count, auto_split_inames=True,
        split_kwargs=None):
    """Split one axis of each of several arrays into an inner and an outer
    axis, rewriting all accesses to those arrays accordingly.

    :arg arrays_and_axes: a list of tuples *(array, axis_nr)* indicating
        that the index in *axis_nr* should be split. The tuples may
        also be *(array, axis_nr, "F")*, indicating that the index will
        be split as it would be according to Fortran order.

        *array* may name a temporary variable or an argument.

        If *arrays_and_axes* is a :class:`tuple`, it is automatically
        wrapped in a list, to make single splits easier.

    :arg count: The group size to use in the split. The split axis keeps
        length *count*; the new outer axis gets length
        ``div_ceil(old_length, count)`` and stride ``count * old_stride``.
    :arg auto_split_inames: Whether to automatically split inames
        encountered in the specified indices. If false, each index ``i``
        on the split axis is instead rewritten to ``i % count`` (inner)
        and ``i // count`` (outer).
    :arg split_kwargs: arguments to pass to :func:`loopy.split_iname`
    :returns: the transformed kernel, or *kernel* unchanged if *count* is 1.
    :raises RuntimeError: if an entry of *arrays_and_axes* is malformed,
        an array is listed twice, an order is not ``"C"``/``"F"``, an
        array's dim tags are unknown or not fixed-stride on the split axis,
        or (with *auto_split_inames*) an access's split index is not a
        single iname.

    Note that splits on the corresponding inames are carried out
    implicitly. The inames may *not* be split beforehand. (There's no *really*
    good reason for this--this routine is just not smart enough to deal with
    this.)
    """

    if count == 1:
        return kernel

    if split_kwargs is None:
        split_kwargs = {}

    # {{{ process input into array_to_rest

    # where "rest" is the non-array-name part of the input tuples
    # in arrays_and_axes
    def normalize_rest(rest):
        if len(rest) == 1:
            return (rest[0], "C")
        elif len(rest) == 2:
            return rest
        else:
            # *rest* is itself a tuple: it must be wrapped, otherwise "%"
            # treats it as the argument list and raises TypeError instead
            # of reporting the bad input.
            raise RuntimeError("split instruction '%s' not understood"
                    % (rest,))

    if isinstance(arrays_and_axes, tuple):
        arrays_and_axes = [arrays_and_axes]

    array_to_rest = {
            tup[0]: normalize_rest(tup[1:]) for tup in arrays_and_axes}

    # a dict shorter than the input list means an array name was repeated
    if len(arrays_and_axes) != len(array_to_rest):
        raise RuntimeError("cannot split multiple axes of the same variable")

    del arrays_and_axes

    # }}}

    # {{{ adjust arrays

    from loopy.kernel.tools import ArrayChanger
    from loopy.kernel.array import FixedStrideArrayDimTag
    from pytools import div_ceil

    for array_name, (axis, order) in array_to_rest.items():
        achng = ArrayChanger(kernel, array_name)
        ary = achng.get()

        # {{{ adjust shape

        new_shape = ary.shape
        if new_shape is not None:
            new_shape = list(new_shape)
            axis_len = new_shape[axis]
            new_shape[axis] = count
            outer_len = div_ceil(axis_len, count)

            if order == "F":
                new_shape.insert(axis+1, outer_len)
            elif order == "C":
                new_shape.insert(axis, outer_len)
            else:
                raise RuntimeError("order '%s' not understood" % order)
            new_shape = tuple(new_shape)

        # }}}

        # {{{ adjust dim tags

        if ary.dim_tags is None:
            raise RuntimeError("dim_tags of '%s' are not known" % array_name)
        new_dim_tags = list(ary.dim_tags)

        old_dim_tag = ary.dim_tags[axis]

        if not isinstance(old_dim_tag, FixedStrideArrayDimTag):
            raise RuntimeError("axis %d of '%s' is not tagged fixed-stride"
                    % (axis, array_name))

        old_stride = old_dim_tag.stride
        # the outer axis steps over whole groups of *count* inner entries
        outer_stride = count*old_stride

        if order == "F":
            new_dim_tags.insert(axis+1, FixedStrideArrayDimTag(outer_stride))
        elif order == "C":
            new_dim_tags.insert(axis, FixedStrideArrayDimTag(outer_stride))
        else:
            raise RuntimeError("order '%s' not understood" % order)

        new_dim_tags = tuple(new_dim_tags)

        # }}}

        # {{{ adjust dim_names

        new_dim_names = ary.dim_names
        if new_dim_names is not None:
            new_dim_names = list(new_dim_names)
            existing_name = new_dim_names[axis]
            new_dim_names[axis] = existing_name + "_inner"
            outer_name = existing_name + "_outer"

            if order == "F":
                new_dim_names.insert(axis+1, outer_name)
            elif order == "C":
                new_dim_names.insert(axis, outer_name)
            else:
                raise RuntimeError("order '%s' not understood" % order)

            new_dim_names = tuple(new_dim_names)

        # }}}

        kernel = achng.with_changed_array(ary.copy(
            shape=new_shape, dim_tags=new_dim_tags, dim_names=new_dim_names))

    # }}}

    # maps each automatically split iname to its (outer, inner) iname pair
    split_vars = {}

    var_name_gen = kernel.get_var_name_generator()

    def split_access_axis(expr):
        axis_nr, order = array_to_rest[expr.aggregate.name]

        idx = expr.index
        if not isinstance(idx, tuple):
            idx = (idx,)
        idx = list(idx)

        axis_idx = idx[axis_nr]

        if auto_split_inames:
            from pymbolic.primitives import Variable
            if not isinstance(axis_idx, Variable):
                raise RuntimeError(
                        "found access '%s' in which axis %d is not a "
                        "single variable--cannot split "
                        "(Have you tried to do the split yourself, manually, "
                        "beforehand? If so, you shouldn't.)"
                        % (expr, axis_nr))

            split_iname = axis_idx.name
            if split_iname not in kernel.all_inames():
                raise RuntimeError(
                        "found access '%s' in which axis %d is indexed by "
                        "'%s', which is not an iname--cannot split"
                        % (expr, axis_nr, split_iname))

            try:
                outer_iname, inner_iname = split_vars[split_iname]
            except KeyError:
                outer_iname = var_name_gen(split_iname+"_outer")
                inner_iname = var_name_gen(split_iname+"_inner")
                split_vars[split_iname] = outer_iname, inner_iname

            inner_index = Variable(inner_iname)
            outer_index = Variable(outer_iname)

        else:
            from loopy.symbolic import simplify_using_aff
            inner_index = simplify_using_aff(kernel, axis_idx % count)
            outer_index = simplify_using_aff(kernel, axis_idx // count)

        idx[axis_nr] = inner_index

        # Insert relative to *this* array's split axis (axis_nr). The loop
        # variable "axis" from the array-adjust loop above only holds the
        # axis of the last array processed and must not be used here.
        if order == "F":
            idx.insert(axis_nr+1, outer_index)
        elif order == "C":
            idx.insert(axis_nr, outer_index)
        else:
            raise RuntimeError("order '%s' not understood" % order)

        return expr.aggregate.index(tuple(idx))

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, var_name_gen)
    aash = ArrayAxisSplitHelper(rule_mapping_context,
            set(array_to_rest.keys()), split_access_axis)
    kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))

    if auto_split_inames:
        from loopy import split_iname
        for iname, (outer_iname, inner_iname) in split_vars.items():
            kernel = split_iname(kernel, iname, count,
                    outer_iname=outer_iname, inner_iname=inner_iname,
                    **split_kwargs)

    return kernel
def _split_array_axis_inner(kernel, array_name, axis_nr, count, order="C"):
    """Split axis *axis_nr* of the array *array_name* into two axes.

    The original axis becomes an inner axis of length *count*. A new outer
    axis of length ``div_ceil(old_length, count)`` and stride
    ``count * old_stride`` is inserted before it (*order* ``"C"``) or after
    it (*order* ``"F"``). If the array has axis names, the split axis' name
    gains an ``_inner`` suffix and the new axis is named with ``_outer``.
    Every access index ``i`` on the split axis is rewritten to
    ``i % count`` (inner) and ``i // count`` (outer).

    :returns: the transformed kernel, or *kernel* itself if *count* is 1.
    :raises RuntimeError: if *order* is not ``"C"``/``"F"``, if the array's
        dim tags are unknown, or if the split axis is not tagged
        fixed-stride.
    """
    if count == 1:
        return kernel

    def with_axis_split(entries, inner, outer):
        # Return *entries* as a tuple whose slot *axis_nr* holds *inner*,
        # with *outer* inserted before it ("C") or after it ("F").
        result = list(entries)
        result[axis_nr] = inner
        if order == "F":
            result.insert(axis_nr + 1, outer)
        elif order == "C":
            result.insert(axis_nr, outer)
        else:
            raise RuntimeError("order '%s' not understood" % order)
        return tuple(result)

    from loopy.kernel.tools import ArrayChanger
    from loopy.kernel.array import FixedStrideArrayDimTag
    from pytools import div_ceil

    changer = ArrayChanger(kernel, array_name)
    array = changer.get()

    # {{{ shape: inner axis keeps *count* entries, outer axis counts groups

    split_shape = array.shape
    if split_shape is not None:
        split_shape = with_axis_split(
                split_shape, count, div_ceil(split_shape[axis_nr], count))

    # }}}

    # {{{ dim tags: the outer axis steps over whole groups of *count* entries

    if array.dim_tags is None:
        raise RuntimeError("dim_tags of '%s' are not known" % array_name)

    inner_tag = array.dim_tags[axis_nr]
    if not isinstance(inner_tag, FixedStrideArrayDimTag):
        raise RuntimeError("axis %d of '%s' is not tagged fixed-stride"
                % (axis_nr, array_name))

    split_tags = with_axis_split(
            array.dim_tags, inner_tag,
            FixedStrideArrayDimTag(count * inner_tag.stride))

    # }}}

    # {{{ dim names: suffix the split axis' name

    split_names = array.dim_names
    if split_names is not None:
        base_name = split_names[axis_nr]
        split_names = with_axis_split(
                split_names, base_name + "_inner", base_name + "_outer")

    # }}}

    kernel = changer.with_changed_array(array.copy(
        shape=split_shape, dim_tags=split_tags, dim_names=split_names))

    # {{{ rewrite accesses to the split array

    var_name_gen = kernel.get_var_name_generator()

    from loopy.symbolic import simplify_using_aff

    def split_access_axis(expr):
        indices = expr.index
        if not isinstance(indices, tuple):
            indices = (indices,)

        axis_index = indices[axis_nr]
        return expr.aggregate.index(with_axis_split(
                indices,
                simplify_using_aff(kernel, axis_index % count),
                simplify_using_aff(kernel, axis_index // count)))

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, var_name_gen)
    helper = ArrayAxisSplitHelper(rule_mapping_context, {array_name},
            split_access_axis)

    # }}}

    return rule_mapping_context.finish_kernel(helper.map_kernel(kernel))