예제 #1
0
def add_padding(kernel, variable, axis, align_bytes):
    """Round the stride of one axis of a kernel argument up to a multiple
    of *align_bytes*.

    :arg kernel: the kernel whose argument is padded. It is not modified;
        a copy carrying the updated argument is returned.
    :arg variable: name of the kernel argument to pad.
    :arg axis: index (into the argument's ``dim_tags``) of the axis whose
        stride is rounded up.
    :arg align_bytes: the multiple the stride is rounded up to.
        NOTE(review): the dim tag's stride is compared against this value
        directly, so both must be in the same unit -- confirm whether dim
        tag strides are expressed in bytes or in elements.
    :returns: a copy of *kernel* with the padded argument.
    :raises KeyError: if *variable* is not an argument of *kernel*.
    :raises RuntimeError: if the argument's dim tags are unknown, if *axis*
        is not tagged fixed-stride, or if its stride is not a known integer.
    """
    import numbers

    arg_to_idx = {arg.name: i for i, arg in enumerate(kernel.args)}
    arg_idx = arg_to_idx[variable]

    new_args = kernel.args[:]
    arg = new_args[arg_idx]

    if arg.dim_tags is None:
        raise RuntimeError("cannot add padding--dim_tags of '%s' "
                           "are not known" % variable)

    new_dim_tags = list(arg.dim_tags)
    dim_tag = new_dim_tags[axis]

    from loopy.kernel.array import FixedStrideArrayDimTag
    if not isinstance(dim_tag, FixedStrideArrayDimTag):
        raise RuntimeError("cannot find padding multiple--"
                           "axis %d of '%s' is not tagged fixed-stride" %
                           (axis, variable))

    stride = dim_tag.stride
    # Any integral type (including numpy integers) is acceptable; symbolic
    # or auto strides cannot be rounded up.
    if not isinstance(stride, numbers.Integral):
        raise RuntimeError("cannot add padding--stride of axis %d of '%s' "
                           "is not a known integer" % (axis, variable))

    from pytools import div_ceil
    padded_stride = div_ceil(int(stride), align_bytes) * align_bytes

    # Only the stride changes; keep the axis' position in the layout.
    new_dim_tags[axis] = FixedStrideArrayDimTag(
        padded_stride,
        layout_nesting_level=dim_tag.layout_nesting_level)

    new_args[arg_idx] = arg.copy(dim_tags=tuple(new_dim_tags))

    return kernel.copy(args=new_args)
예제 #2
0
    def with_descrs(self, arg_id_to_descr, callables_table):
        """Return a copy of this callable whose array arguments are
        described as contiguous row-major (C-order) arrays.

        Each :class:`~loopy.kernel.function_interface.ArrayArgDescriptor` in
        *arg_id_to_descr* gets fresh fixed-stride dim tags. Axis *k* gets a
        stride equal to the product of the extents of all later axes and a
        layout nesting level of ``ndim - k - 1``. Descriptors of any other
        kind are passed through unchanged.

        :returns: a tuple ``(new_callable, callables_table)``; the callables
            table is returned as received.
        """
        from loopy.kernel.function_interface import ArrayArgDescriptor
        from loopy.kernel.array import FixedStrideArrayDimTag

        new_arg_id_to_descr = arg_id_to_descr.copy()
        for arg_id, descr in arg_id_to_descr.items():
            if not isinstance(descr, ArrayArgDescriptor):
                continue

            # PETSc takes 1D arrays as arguments, so describe each array as
            # a flat, contiguous buffer laid out in C order.
            ndim = len(descr.shape)
            dim_tags = tuple(
                FixedStrideArrayDimTag(
                    stride=int(numpy.prod(descr.shape[axis + 1:])),
                    layout_nesting_level=ndim - axis - 1)
                for axis in range(ndim))
            new_arg_id_to_descr[arg_id] = descr.copy(dim_tags=dim_tags)

        return (self.copy(arg_id_to_descr=new_arg_id_to_descr),
                callables_table)
예제 #3
0
def test_rename_argument_with_auto_stride(ctx_factory):
    """Renaming an argument whose stride is ``lp.auto`` must rename it in
    the generated device code and in the invoker's keyword interface. The
    renamed kernel must still compute the right result."""
    from loopy.kernel.array import FixedStrideArrayDimTag

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    knl = lp.make_kernel(
            "{[i]: 0<=i<10}",
            """
            y[i] = x[i]
            """, [lp.GlobalArg("x", dtype=float,
                               shape=lp.auto,
                               dim_tags=[FixedStrideArrayDimTag(lp.auto)]), ...])

    knl = lp.rename_argument(knl, "x", "x_new")

    code_str = lp.generate_code_v2(knl).device_code()
    assert "double const *__restrict__ x_new," in code_str
    assert "double const *__restrict__ x," not in code_str

    x_new = np.random.rand(10)
    _evt, (out, ) = knl(queue, x_new=x_new)

    # The kernel is a plain elementwise copy, so the output must equal the
    # input exactly.
    assert np.array_equal(out, x_new)
예제 #4
0
def split_array_dim(kernel,
                    arrays_and_axes,
                    count,
                    auto_split_inames=True,
                    split_kwargs=None):
    """Split one axis of each given array into an (outer, inner) pair of
    axes, the inner one of length *count*, and rewrite all accesses.

    :arg arrays_and_axes: a list of tuples *(array, axis_nr)* indicating
        that the index in *axis_nr* should be split. The tuples may
        also be *(array, axis_nr, "F")*, indicating that the index will
        be split as it would be according to Fortran order.

        *array* may name a temporary variable or an argument.

        If *arrays_and_axes* is a :class:`tuple`, it is automatically
        wrapped in a list, to make single splits easier.

    :arg count: The group size to use in the split.
    :arg auto_split_inames: Whether to automatically split inames
        encountered in the specified indices.
    :arg split_kwargs: arguments to pass to :func:`loopy.split_iname`
    :returns: the transformed kernel (*kernel* itself if *count* is 1).
    :raises RuntimeError: on a malformed split tuple, on multiple splits of
        the same array, on an order other than ``"C"``/``"F"``, if an
        array's dim tags are unknown or the axis is not fixed-stride, or
        (with *auto_split_inames*) if an access index on the split axis is
        not a single variable.

    Note that splits on the corresponding inames are carried out implicitly.
    The inames may *not* be split beforehand. (There's no *really* good reason
    for this--this routine is just not smart enough to deal with this.)
    """

    if count == 1:
        return kernel

    if split_kwargs is None:
        split_kwargs = {}

    from pytools import div_ceil
    from loopy.kernel.array import FixedStrideArrayDimTag
    from loopy.kernel.tools import ArrayChanger

    def insert_outer(seq, axis_nr, order, item):
        """Insert *item*, the outer part of split axis *axis_nr*, into the
        list *seq*: after the axis for Fortran order, before it for C."""
        if order == "F":
            seq.insert(axis_nr + 1, item)
        elif order == "C":
            seq.insert(axis_nr, item)
        else:
            raise RuntimeError("order '%s' not understood" % order)

    # {{{ process input into array_to_rest

    # where "rest" is the non-array-name part of the input tuples
    # in arrays_and_axes, normalized to (axis_nr, order)
    def normalize_rest(rest):
        if len(rest) == 1:
            return (rest[0], "C")
        elif len(rest) == 2:
            return rest
        else:
            # rest is a tuple: wrap it so that %-formatting does not try to
            # consume its elements as separate format arguments.
            raise RuntimeError("split instruction '%s' not understood"
                               % (rest,))

    if isinstance(arrays_and_axes, tuple):
        arrays_and_axes = [arrays_and_axes]

    array_to_rest = {
        tup[0]: normalize_rest(tup[1:])
        for tup in arrays_and_axes
    }

    if len(arrays_and_axes) != len(array_to_rest):
        raise RuntimeError("cannot split multiple axes of the same variable")

    del arrays_and_axes

    # }}}

    # {{{ adjust arrays

    for array_name, (axis_nr, order) in array_to_rest.items():
        achng = ArrayChanger(kernel, array_name)
        ary = achng.get()

        # {{{ adjust shape

        new_shape = ary.shape
        if new_shape is not None:
            new_shape = list(new_shape)
            axis_len = new_shape[axis_nr]
            new_shape[axis_nr] = count
            insert_outer(new_shape, axis_nr, order, div_ceil(axis_len, count))
            new_shape = tuple(new_shape)

        # }}}

        # {{{ adjust dim tags

        if ary.dim_tags is None:
            raise RuntimeError("dim_tags of '%s' are not known" % array_name)
        new_dim_tags = list(ary.dim_tags)

        old_dim_tag = ary.dim_tags[axis_nr]
        if not isinstance(old_dim_tag, FixedStrideArrayDimTag):
            raise RuntimeError("axis %d of '%s' is not tagged fixed-stride" %
                               (axis_nr, array_name))

        # The inner axis keeps the old stride; one step along the outer
        # axis skips a whole group of *count* inner entries.
        outer_stride = count * old_dim_tag.stride
        insert_outer(new_dim_tags, axis_nr, order,
                     FixedStrideArrayDimTag(outer_stride))
        new_dim_tags = tuple(new_dim_tags)

        # }}}

        # {{{ adjust dim_names

        new_dim_names = ary.dim_names
        if new_dim_names is not None:
            new_dim_names = list(new_dim_names)
            existing_name = new_dim_names[axis_nr]
            new_dim_names[axis_nr] = existing_name + "_inner"
            insert_outer(new_dim_names, axis_nr, order,
                         existing_name + "_outer")
            new_dim_names = tuple(new_dim_names)

        # }}}

        kernel = achng.with_changed_array(
            ary.copy(shape=new_shape,
                     dim_tags=new_dim_tags,
                     dim_names=new_dim_names))

    # }}}

    split_vars = {}

    var_name_gen = kernel.get_var_name_generator()

    def split_access_axis(expr):
        # Look up this array's own split axis and order; never reuse the
        # loop variables left over from the array-adjustment loop above,
        # which hold the values for the *last* array processed.
        axis_nr, order = array_to_rest[expr.aggregate.name]

        idx = expr.index
        if not isinstance(idx, tuple):
            idx = (idx, )
        idx = list(idx)

        axis_idx = idx[axis_nr]

        if auto_split_inames:
            from pymbolic.primitives import Variable
            if not isinstance(axis_idx, Variable):
                raise RuntimeError(
                    "found access '%s' in which axis %d is not a "
                    "single variable--cannot split "
                    "(Have you tried to do the split yourself, manually, "
                    "beforehand? If so, you shouldn't.)" % (expr, axis_nr))

            split_iname = axis_idx.name
            assert split_iname in kernel.all_inames()

            try:
                outer_iname, inner_iname = split_vars[split_iname]
            except KeyError:
                outer_iname = var_name_gen(split_iname + "_outer")
                inner_iname = var_name_gen(split_iname + "_inner")
                split_vars[split_iname] = outer_iname, inner_iname

            inner_index = Variable(inner_iname)
            outer_index = Variable(outer_iname)

        else:
            from loopy.symbolic import simplify_using_aff
            inner_index = simplify_using_aff(kernel, axis_idx % count)
            outer_index = simplify_using_aff(kernel, axis_idx // count)

        idx[axis_nr] = inner_index
        insert_outer(idx, axis_nr, order, outer_index)

        return expr.aggregate.index(tuple(idx))

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, var_name_gen)
    aash = ArrayAxisSplitHelper(rule_mapping_context,
                                set(array_to_rest.keys()), split_access_axis)
    kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))

    if auto_split_inames:
        from loopy import split_iname
        for iname, (outer_iname, inner_iname) in split_vars.items():
            kernel = split_iname(kernel,
                                 iname,
                                 count,
                                 outer_iname=outer_iname,
                                 inner_iname=inner_iname,
                                 **split_kwargs)

    return kernel
예제 #5
0
def _split_array_axis_inner(kernel, array_name, axis_nr, count, order="C"):
    """Split axis *axis_nr* of array *array_name* into an inner axis of
    length *count* and an outer axis, and rewrite every access to it.

    The inner axis stays at position *axis_nr* and keeps its old stride.
    The outer axis gets stride ``count * old_stride``. It is placed before
    the inner axis for *order* ``"C"`` and after it for ``"F"``. An access
    index ``j`` on the split axis becomes ``j // count`` (outer) and
    ``j % count`` (inner), each simplified with
    :func:`loopy.symbolic.simplify_using_aff`.

    :returns: the transformed kernel, or *kernel* itself if *count* is 1.
    :raises RuntimeError: if *order* is neither ``"C"`` nor ``"F"``, if the
        array's dim tags are unknown, or if the axis is not fixed-stride.
    """
    if count == 1:
        return kernel

    def place_outer(items, outer_item):
        # The outer part lands right after the split axis for Fortran order
        # and right before it for C order.
        if order == "F":
            position = axis_nr + 1
        elif order == "C":
            position = axis_nr
        else:
            raise RuntimeError("order '%s' not understood" % order)
        items.insert(position, outer_item)

    # {{{ adjust the array declaration

    from loopy.kernel.tools import ArrayChanger

    changer = ArrayChanger(kernel, array_name)
    array = changer.get()

    from pytools import div_ceil

    # shape: inner extent becomes *count*, outer extent is the rounded-up
    # number of groups
    shape = array.shape
    if shape is not None:
        shape_list = list(shape)
        old_len = shape_list[axis_nr]
        shape_list[axis_nr] = count
        place_outer(shape_list, div_ceil(old_len, count))
        shape = tuple(shape_list)

    # dim tags: inner keeps the old stride, outer strides over whole groups
    if array.dim_tags is None:
        raise RuntimeError("dim_tags of '%s' are not known" % array_name)
    tags = list(array.dim_tags)

    split_tag = array.dim_tags[axis_nr]

    from loopy.kernel.array import FixedStrideArrayDimTag
    if not isinstance(split_tag, FixedStrideArrayDimTag):
        raise RuntimeError("axis %d of '%s' is not tagged fixed-stride" %
                           (axis_nr, array_name))

    place_outer(tags, FixedStrideArrayDimTag(count * split_tag.stride))
    tags = tuple(tags)

    # dim names: derive "<name>_inner" / "<name>_outer" from the old name
    names = array.dim_names
    if names is not None:
        names_list = list(names)
        base_name = names_list[axis_nr]
        names_list[axis_nr] = base_name + "_inner"
        place_outer(names_list, base_name + "_outer")
        names = tuple(names_list)

    kernel = changer.with_changed_array(
        array.copy(shape=shape, dim_tags=tags, dim_names=names))

    # }}}

    # {{{ rewrite accesses

    var_name_gen = kernel.get_var_name_generator()

    def split_access_axis(expr):
        raw_index = expr.index
        indices = list(raw_index) if isinstance(raw_index, tuple) \
            else [raw_index]

        split_index = indices[axis_nr]

        from loopy.symbolic import simplify_using_aff
        indices[axis_nr] = simplify_using_aff(kernel, split_index % count)
        place_outer(indices, simplify_using_aff(kernel, split_index // count))

        return expr.aggregate.index(tuple(indices))

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, var_name_gen)
    helper = ArrayAxisSplitHelper(rule_mapping_context, {array_name},
                                  split_access_axis)

    # }}}

    return rule_mapping_context.finish_kernel(helper.map_kernel(kernel))