Exemplo n.º 1
0
def synthesize_idis_for_extra_args(kernel, schedule_index):
    """
    Build the implemented-data descriptions for the extra arguments of the
    :class:`CallKernel` schedule item at *schedule_index*.

    Entries for the extra temporaries (each of which must have global scope)
    come first, in the order given by the schedule item, followed by one
    :class:`loopy.kernel.data.InameArg` entry of type ``kernel.index_dtype``
    per extra iname.

    :arg kernel: the kernel whose schedule is being inspected
    :arg schedule_index: index into ``kernel.schedule``; the item found there
        must be a :class:`CallKernel`
    :returns: A list of :class:`loopy.codegen.ImplementedDataInfo`
    """
    call_kernel = kernel.schedule[schedule_index]

    from loopy.codegen import ImplementedDataInfo
    from loopy.kernel.data import InameArg, temp_var_scope

    assert isinstance(call_kernel, CallKernel)

    result = []

    # Temporaries handed to the called kernel: each one may expand into
    # several implemented arguments, as decided by decl_info.
    for tv_name in call_kernel.extra_args:
        tv = kernel.temporary_variables[tv_name]
        assert tv.scope == temp_var_scope.GLOBAL
        result += tv.decl_info(kernel.target, index_dtype=kernel.index_dtype)

    # Inames handed to the called kernel: read-only index-typed values.
    result.extend(
        ImplementedDataInfo(
            target=kernel.target,
            name=iname,
            dtype=kernel.index_dtype,
            arg_class=InameArg,
            is_written=False)
        for iname in call_kernel.extra_inames)

    return result
Exemplo n.º 2
0
        def gen_decls(name_suffix, shape, strides, unvec_shape, unvec_strides,
                      stride_arg_axes, dtype, user_index):
            """
            Recursively yield :class:`ImplementedDataInfo` records for this
            array, handling one user-facing axis per level of recursion.

            The axis handled by a given call is ``len(user_index)``. Once no
            user axis is left (or the number of user axes is unknown), the
            call yields, in this order: the record for the array itself, an
            offset argument if ``self.offset`` is ``loopy.auto``, and then one
            argument per entry of *stride_arg_axes*.

            *target*, *index_dtype*, *is_written* and *array_shape* are read
            from the enclosing scope.

            :arg name_suffix: string appended to ``self.name`` to form the
                name of the implemented array; ``"_s<i>"`` is added for each
                separate-array axis
            :arg shape: tuple of axis lengths accumulated so far
            :arg strides: tuple of strides accumulated so far; an entry is
                *None* where the stride will be supplied by a stride argument
            :arg unvec_shape: shape tuple
                that accounts for :class:`loopy.kernel.array.VectorArrayDimTag`
                in a scalar manner
            :arg unvec_strides: strides tuple
                that accounts for :class:`loopy.kernel.array.VectorArrayDimTag`
                in a scalar manner
            :arg stride_arg_axes: a tuple of triples
                *(user_axis, impl_axis, unvec_impl_axis)*, one for each stride
                that needs its own argument
            :arg dtype: element type of the array; *None* selects
                ``self.dtype``
            :arg user_index: A tuple representing a (user-facing)
                multi-dimensional subscript. This is filled in with
                concrete integers when known (such as for separate-array
                dim tags), and with *None* where the index won't be
                known until run time.
            :raises LoopyError: if a separate-array or vector axis does not
                have an integer length, or if the dim tag of an axis is of an
                unsupported type
            """

            dtype = self.dtype if dtype is None else dtype

            # Every recursion level appends one entry to user_index, so its
            # length identifies the user axis this call is responsible for.
            axis = len(user_index)
            axis_count = self.num_user_axes(require_answer=False)

            if axis_count is not None and axis < axis_count:
                # {{{ recursion: dispatch on the dim tag of this axis

                tag = self.dim_tags[axis]

                if isinstance(tag, FixedStrideArrayDimTag):
                    axis_len = (
                        None if array_shape is None else array_shape[axis])

                    import loopy as lp
                    if tag.stride is lp.auto:
                        # Placeholder only: the base case substitutes the
                        # stride argument once the final array name is known.
                        axis_stride = None
                        next_stride_arg_axes = stride_arg_axes + (
                            (axis, len(strides), len(unvec_strides)),)
                    else:
                        axis_stride = tag.stride
                        next_stride_arg_axes = stride_arg_axes

                    yield from gen_decls(
                        name_suffix,
                        shape + (axis_len, ),
                        strides + (axis_stride, ),
                        unvec_shape + (axis_len, ),
                        unvec_strides + (axis_stride, ),
                        next_stride_arg_axes,
                        dtype,
                        user_index + (None, ))

                elif isinstance(
                        tag, (SeparateArrayArrayDimTag, VectorArrayDimTag)):
                    # Both kinds of axis need a length known at this point.
                    axis_len = array_shape[axis]
                    if not is_integer(axis_len):
                        raise LoopyError("shape of '%s' has non-constant "
                                         "integer axis %d (0-based)" %
                                         (self.name, axis))

                    if isinstance(tag, SeparateArrayArrayDimTag):
                        # One independent array per index along this axis;
                        # the axis itself adds nothing to shape or strides.
                        for i in range(axis_len):
                            yield from gen_decls(
                                name_suffix + "_s%d" % i,
                                shape,
                                strides,
                                unvec_shape,
                                unvec_strides,
                                stride_arg_axes,
                                dtype,
                                user_index + (i, ))
                    else:
                        # The axis is folded into the element type; only the
                        # unvectorized shape and strides record it.
                        yield from gen_decls(
                            name_suffix,
                            shape,
                            strides,
                            unvec_shape + (axis_len, ),
                            # vectors always have stride 1
                            unvec_strides + (1, ),
                            stride_arg_axes,
                            target.vector_dtype(dtype, axis_len),
                            user_index + (None, ))

                else:
                    raise LoopyError(
                        "unsupported array dim implementation tag '%s' "
                        "in array '%s'" % (tag, self.name))

                # }}}

                return

            # {{{ recursion base case

            array_name = self.name + name_suffix
            final_strides = list(strides)
            final_unvec_strides = list(unvec_strides)

            # Stride arguments are collected here and yielded last so that
            # the array itself comes first.
            stride_idis = []

            for s_user_axis, s_impl_axis, s_unvec_impl_axis in stride_arg_axes:
                stride_arg_name = array_name + "_stride%d" % s_user_axis

                from pymbolic import var
                stride_var = var(stride_arg_name)
                final_strides[s_impl_axis] = stride_var
                final_unvec_strides[s_unvec_impl_axis] = stride_var

                stride_idis.append(
                    ImplementedDataInfo(
                        target=target,
                        name=stride_arg_name,
                        dtype=index_dtype,
                        arg_class=ValueArg,
                        stride_for_name_and_axis=(array_name, s_impl_axis),
                        is_written=False))

            yield ImplementedDataInfo(
                target=target,
                name=array_name,
                base_name=self.name,
                arg_class=type(self),
                dtype=dtype,
                shape=shape,
                strides=tuple(final_strides),
                unvec_shape=unvec_shape,
                unvec_strides=tuple(final_unvec_strides),
                allows_offset=bool(self.offset),
                is_written=is_written)

            import loopy as lp

            if self.offset is lp.auto:
                yield ImplementedDataInfo(
                    target=target,
                    name=array_name + "_offset",
                    dtype=index_dtype,
                    arg_class=ValueArg,
                    offset_for_name=array_name,
                    is_written=False)

            yield from stride_idis

            # }}}