Exemplo n.º 1
0
def generate_arg_setup(gen, kernel, implemented_data_info, options):
    """Emit the "set up array arguments" section of a generated invoker.

    Walks *implemented_data_info* and, for every array argument, writes
    Python source through *gen* that (depending on *options*) converts
    numpy arrays to device arrays, insists that required arguments were
    supplied, allocates written output arrays that were left as ``None``,
    and validates dtype, shape, strides and offset of what was passed.

    :arg gen: callable receiving one line of generated source per call;
        also used with ``Indentation`` to nest the emitted blocks.
    :arg kernel: object providing ``get_written_variables()`` and the
        ``impl_arg_to_arg`` mapping, which are the only attributes read
        here.
    :arg implemented_data_info: iterable of argument descriptors; the
        attributes read are ``name``, ``base_name``, ``arg_class``,
        ``dtype``, ``shape``, ``strides``, ``unvec_shape``,
        ``unvec_strides`` and ``allows_offset``.
    :arg options: object with the boolean attributes ``no_numpy`` and
        ``skip_arg_checks``.
    :returns: a list of source-text expressions, one per kernel argument,
        in the order encountered: ``"<name>.base_data"`` for
        ``GlobalArg``/``ConstantArg`` arrays and ``"<name>"`` otherwise.
    :raises LoopyError: if a kernel argument follows an entry that is not
        a ``KernelArgument``, or if a written array that may need to be
        allocated has a dtype that is not a ``NumpyType``.
    """
    import loopy as lp

    from loopy.kernel.data import KernelArgument
    from loopy.kernel.array import ArrayBase
    from loopy.symbolic import StringifyMapper
    from pymbolic import var

    gen("# {{{ set up array arguments")
    gen("")

    if not options.no_numpy:
        # Flags that the generated code updates while it inspects each
        # argument (numpy array seen / non-None non-numpy array seen).
        gen("_lpy_encountered_numpy = False")
        gen("_lpy_encountered_dev = False")
        gen("")

    args = []

    strify = StringifyMapper()

    # The two helpers below only depend on `strify`, so they are defined
    # once here rather than once per processed argument.

    def strify_allowing_none(shape_axis):
        # An unknown axis length is rendered as the literal text "None".
        if shape_axis is None:
            return "None"
        else:
            return strify(shape_axis)

    def strify_tuple(t):
        # Always produce valid tuple syntax, including for 0 and 1 entries.
        if len(t) == 0:
            return "()"
        else:
            return "(%s,)" % ", ".join(
                strify_allowing_none(sa) for sa in t)

    # Loop-invariant: query the written variables once instead of once per
    # argument.
    written_vars = kernel.get_written_variables()

    expect_no_more_arguments = False

    for arg in implemented_data_info:
        is_written = arg.base_name in written_vars
        kernel_arg = kernel.impl_arg_to_arg.get(arg.name)

        if not issubclass(arg.arg_class, KernelArgument):
            # Not a kernel argument (per the error message below, a global
            # temporary). Nothing is passed for it, but from here on no
            # further kernel arguments may appear.
            expect_no_more_arguments = True
            continue

        if expect_no_more_arguments:
            raise LoopyError("Further arguments encountered after arg info "
                             "describing a global temporary variable")

        if not issubclass(arg.arg_class, ArrayBase):
            # Non-array arguments are passed through by name, unchecked.
            args.append(arg.name)
            continue

        gen("# {{{ process %s" % arg.name)
        gen("")

        if not options.no_numpy:
            gen("if isinstance(%s, _lpy_np.ndarray):" % arg.name)
            with Indentation(gen):
                gen("# synchronous, nothing to worry about")
                gen("%s = _lpy_cl_array.to_device("
                    "queue, %s, allocator=allocator)" % (arg.name, arg.name))
                gen("_lpy_encountered_numpy = True")
            gen("elif %s is not None:" % arg.name)
            with Indentation(gen):
                gen("_lpy_encountered_dev = True")

            gen("")

        # {{{ "must be supplied" checks

        if not options.skip_arg_checks and not is_written:
            gen("if %s is None:" % arg.name)
            with Indentation(gen):
                gen("raise RuntimeError(\"input argument '%s' must "
                    "be supplied\")" % arg.name)
                gen("")

        if (is_written and arg.arg_class is lp.ImageArg
                and not options.skip_arg_checks):
            gen("if %s is None:" % arg.name)
            with Indentation(gen):
                gen("raise RuntimeError(\"written image '%s' must "
                    "be supplied\")" % arg.name)
                gen("")

        if is_written and arg.shape is None and not options.skip_arg_checks:
            gen("if %s is None:" % arg.name)
            with Indentation(gen):
                gen("raise RuntimeError(\"written argument '%s' has "
                    "unknown shape and must be supplied\")" % arg.name)
                gen("")

        # }}}

        possibly_made_by_loopy = False

        # {{{ allocate written arrays, if needed

        # Only written global/constant arrays whose shape is completely
        # known can be allocated by the generated code.
        if is_written and arg.arg_class in [lp.GlobalArg, lp.ConstantArg] \
                and arg.shape is not None \
                and all(si is not None for si in arg.shape):

            if not isinstance(arg.dtype, NumpyType):
                raise LoopyError("do not know how to pass arg of type '%s'" %
                                 arg.dtype)

            possibly_made_by_loopy = True
            gen("_lpy_made_by_loopy = False")
            gen("")

            gen("if %s is None:" % arg.name)
            with Indentation(gen):
                num_axes = len(arg.strides)
                for i in range(num_axes):
                    gen("_lpy_shape_%d = %s" % (i, strify(arg.unvec_shape[i])))

                # unvec_strides are multiplied by the item size to obtain
                # the strides handed to the array constructor.
                itemsize = kernel_arg.dtype.numpy_dtype.itemsize
                for i in range(num_axes):
                    gen("_lpy_strides_%d = %s" %
                        (i, strify(itemsize * arg.unvec_strides[i])))

                if not options.skip_arg_checks:
                    for i in range(num_axes):
                        gen("assert _lpy_strides_%d > 0, "
                            "\"'%s' has negative stride in axis %d\"" %
                            (i, arg.name, i))

                sym_strides = tuple(
                    var("_lpy_strides_%d" % i) for i in range(num_axes))
                sym_shape = tuple(
                    var("_lpy_shape_%d" % i) for i in range(num_axes))

                # Offset of the last element plus one item: the number of
                # bytes the strided array spans.
                alloc_size_expr = (
                    sum(astrd * (alen - 1)
                        for alen, astrd in zip(sym_shape, sym_strides)) +
                    itemsize)

                gen("_lpy_alloc_size = %s" % strify(alloc_size_expr))
                gen("%(name)s = _lpy_cl_array.Array(queue, %(shape)s, "
                    "%(dtype)s, strides=%(strides)s, "
                    "data=allocator(_lpy_alloc_size), allocator=allocator)" %
                    dict(name=arg.name,
                         shape=strify(sym_shape),
                         strides=strify(sym_strides),
                         dtype=python_dtype_str(kernel_arg.dtype.numpy_dtype)))

                # NOTE(review): the temporaries are only deleted when
                # argument checks are enabled; kept as in the original.
                if not options.skip_arg_checks:
                    for i in range(num_axes):
                        gen("del _lpy_shape_%d" % i)
                        gen("del _lpy_strides_%d" % i)
                    gen("del _lpy_alloc_size")
                    gen("")

                gen("_lpy_made_by_loopy = True")
                gen("")

        # }}}

        # {{{ argument checking

        if arg.arg_class in [lp.GlobalArg, lp.ConstantArg] \
                and not options.skip_arg_checks:
            # Arrays allocated just above need no checking. "if True:"
            # keeps the emitted block structure uniform otherwise.
            if possibly_made_by_loopy:
                gen("if not _lpy_made_by_loopy:")
            else:
                gen("if True:")

            with Indentation(gen):
                gen("if %s.dtype != %s:" %
                    (arg.name, python_dtype_str(kernel_arg.dtype.numpy_dtype)))
                with Indentation(gen):
                    gen("raise TypeError(\"dtype mismatch on argument '%s' "
                        "(got: %%s, expected: %s)\" %% %s.dtype)" %
                        (arg.name, arg.dtype, arg.name))

                # {{{ generate shape checking code

                shape_mismatch_msg = (
                    "raise TypeError(\"shape mismatch on argument '%s' "
                    "(got: %%s, expected: %%s)\" "
                    "%% (%s.shape, %s))" %
                    (arg.name, arg.name, strify_tuple(arg.unvec_shape)))

                if kernel_arg.shape is None:
                    # Shape entirely unknown: nothing to compare against.
                    pass

                elif any(shape_axis is None
                         for shape_axis in kernel_arg.shape):
                    # Partially known shape: check the number of axes, then
                    # each axis whose length is known.
                    gen("if len(%s.shape) != %s:" %
                        (arg.name, len(arg.unvec_shape)))
                    with Indentation(gen):
                        gen(shape_mismatch_msg)

                    for i, shape_axis in enumerate(arg.unvec_shape):
                        if shape_axis is None:
                            continue

                        gen("if %s.shape[%d] != %s:" %
                            (arg.name, i, strify(shape_axis)))
                        with Indentation(gen):
                            gen(shape_mismatch_msg)

                else:  # not None, no Nones in tuple
                    gen("if %s.shape != %s:" %
                        (arg.name, strify(arg.unvec_shape)))
                    with Indentation(gen):
                        gen(shape_mismatch_msg)

                # }}}

                if arg.unvec_strides and kernel_arg.dim_tags:
                    itemsize = kernel_arg.dtype.numpy_dtype.itemsize
                    sym_strides = tuple(itemsize * s_i
                                        for s_i in arg.unvec_strides)
                    gen("if %s.strides != %s:" %
                        (arg.name, strify(sym_strides)))
                    with Indentation(gen):
                        gen("raise TypeError(\"strides mismatch on "
                            "argument '%s' (got: %%s, expected: %%s)\" "
                            "%% (%s.strides, %s))" %
                            (arg.name, arg.name, strify(sym_strides)))

                if not arg.allows_offset:
                    gen("if %s.offset:" % arg.name)
                    with Indentation(gen):
                        gen("raise ValueError(\"Argument '%s' does not "
                            "allow arrays with offsets. Try passing "
                            "default_offset=loopy.auto to make_kernel()."
                            "\")" % arg.name)
                        gen("")

        # }}}

        if possibly_made_by_loopy and not options.skip_arg_checks:
            gen("del _lpy_made_by_loopy")
            gen("")

        if arg.arg_class in [lp.GlobalArg, lp.ConstantArg]:
            args.append("%s.base_data" % arg.name)
        else:
            args.append("%s" % arg.name)

        gen("")

        gen("# }}}")
        gen("")

    gen("# }}}")
    gen("")

    return args
Exemplo n.º 2
0
def generate_integer_arg_finding_from_shapes(gen, kernel,
                                             implemented_data_info):
    """Emit code that infers missing integer arguments from array shapes.

    For every ``GlobalArg`` in *implemented_data_info*, each shape axis
    whose expression depends on exactly one variable is examined. If that
    variable is an integral kernel argument, the axis expression is solved
    for it in terms of ``<array>.shape[<axis>]``. The generated code then
    assigns the integer argument from the first such array that is not
    ``None``, but only when the integer argument itself is ``None``.

    :arg gen: callable receiving one line of generated source per call;
        also used with ``Indentation`` to nest the emitted blocks.
    :arg kernel: object whose ``arg_dict`` maps argument names to objects
        with a ``dtype`` offering ``is_integral()``.
    :arg implemented_data_info: iterable of argument descriptors; the
        attributes read are ``arg_class``, ``name`` and ``shape``.

    Failures to solve for an argument are reported as a
    ``ParameterFinderWarning`` and otherwise ignored.
    """
    from warnings import warn

    from loopy.kernel.data import GlobalArg
    from loopy.symbolic import DependencyMapper, StringifyMapper
    from pymbolic import var
    from pymbolic.algorithm import solve_affine_equations_for

    dep_map = DependencyMapper()
    strify = StringifyMapper()

    # a mapping from integer argument names to a list of tuples
    # (arg_name, expression), where expression is a
    # unary function of kernel.arg_dict[arg_name]
    # returning the desired integer argument.
    iarg_to_sources = {}

    for arg in implemented_data_info:
        if arg.arg_class is not GlobalArg:
            continue

        # An array without any shape information cannot contribute a
        # source; iterating over None would raise TypeError.
        if arg.shape is None:
            continue

        sym_shape = var(arg.name).attr("shape")
        for axis_nr, shape_i in enumerate(arg.shape):
            if shape_i is None:
                continue

            # Only axes depending on exactly one variable can be solved
            # for that variable.
            deps = dep_map(shape_i)
            if len(deps) != 1:
                continue

            integer_arg_var, = deps

            if not kernel.arg_dict[integer_arg_var.name].dtype.is_integral():
                continue

            try:
                # Solve `shape_i == <arg>.shape[axis_nr]` for the integer
                # argument.
                iarg_expr = solve_affine_equations_for(
                    [integer_arg_var.name],
                    [(shape_i, sym_shape.index(axis_nr))]
                    )[integer_arg_var]
            except Exception as e:
                # Deliberately best-effort: finding the argument is a
                # convenience, so any failure only produces a warning.
                warn(
                    "Unable to generate code to automatically "
                    "find '%s' from the shape of '%s':\n%s" %
                    (integer_arg_var.name, arg.name, str(e)),
                    ParameterFinderWarning)
            else:
                iarg_to_sources.setdefault(integer_arg_var.name, []) \
                        .append((arg.name, iarg_expr))

    gen("# {{{ find integer arguments from shapes")
    gen("")

    for iarg_name, sources in iarg_to_sources.items():
        gen("if %s is None:" % iarg_name)
        with Indentation(gen):
            # The first source becomes "if", all later ones "elif", so the
            # first array that was actually passed determines the value.
            if_stmt = "if"
            for arg_name, value_expr in sources:
                gen("%s %s is not None:" % (if_stmt, arg_name))
                with Indentation(gen):
                    gen("%s = %s" % (iarg_name, strify(value_expr)))

                if_stmt = "elif"

        gen("")

    gen("# }}}")
    gen("")
Exemplo n.º 3
0
    def generate_arg_setup(
            self, gen, kernel, implemented_data_info, options):
        """Emit the "set up array arguments" section of a generated invoker.

        Walks *implemented_data_info* and, for every array argument,
        writes Python source through *gen* that (depending on *options*)
        insists that required arguments were supplied, allocates written
        output arrays that were left as ``None``, and validates dtype,
        shape, strides and offset of what was passed.

        Handling of non-numpy input, allocation, dtype spelling, the
        strides comparison and the final argument expression are delegated
        to ``self.handle_non_numpy_arg``, ``self.handle_alloc``,
        ``self.python_dtype_str``, ``self.get_strides_check_expr`` and
        ``self.get_arg_pass`` respectively.

        :arg gen: callable receiving one line of generated source per
            call; also used with ``Indentation`` to nest emitted blocks.
        :arg kernel: object providing ``get_written_variables()`` and the
            ``impl_arg_to_arg`` mapping, which are the only attributes
            read here.
        :arg implemented_data_info: iterable of argument descriptors; the
            attributes read are ``name``, ``base_name``, ``arg_class``,
            ``dtype``, ``shape``, ``unvec_shape``, ``unvec_strides`` and
            ``allows_offset``.
        :arg options: object with the boolean attributes ``no_numpy`` and
            ``skip_arg_checks``.
        :returns: a list of source-text expressions, one per kernel
            argument, in the order encountered: the result of
            ``self.get_arg_pass(arg)`` for ``ArrayArg``/``ConstantArg``
            arrays and the plain argument name otherwise.
        :raises LoopyError: if a kernel argument follows an entry that is
            not a ``KernelArgument``, or if a written array that may need
            to be allocated has a dtype that is not a ``NumpyType``.
        """
        import loopy as lp

        from loopy.kernel.data import KernelArgument
        from loopy.kernel.array import ArrayBase
        from loopy.symbolic import StringifyMapper
        from loopy.types import NumpyType

        gen("# {{{ set up array arguments")
        gen("")

        if not options.no_numpy:
            # Flags for the generated code; initialized here, presumably
            # updated by what handle_non_numpy_arg emits (not shown here).
            gen("_lpy_encountered_numpy = False")
            gen("_lpy_encountered_dev = False")
            gen("")

        args = []

        strify = StringifyMapper()

        # The two helpers below only depend on `strify`, so they are
        # defined once here rather than once per processed argument.

        def strify_allowing_none(shape_axis):
            # An unknown axis length is rendered as the literal "None".
            if shape_axis is None:
                return "None"
            else:
                return strify(shape_axis)

        def strify_tuple(t):
            # Always produce valid tuple syntax, also for 0 or 1 entries.
            if len(t) == 0:
                return "()"
            else:
                return "(%s,)" % ", ".join(
                        strify_allowing_none(sa)
                        for sa in t)

        # Loop-invariant: query the written variables once instead of once
        # per argument.
        written_vars = kernel.get_written_variables()

        expect_no_more_arguments = False

        for arg in implemented_data_info:
            is_written = arg.base_name in written_vars
            kernel_arg = kernel.impl_arg_to_arg.get(arg.name)

            if not issubclass(arg.arg_class, KernelArgument):
                # Not a kernel argument (per the error message below, a
                # global temporary). Nothing is passed for it, but from
                # here on no further kernel arguments may appear.
                expect_no_more_arguments = True
                continue

            if expect_no_more_arguments:
                raise LoopyError("Further arguments encountered after arg info "
                        "describing a global temporary variable")

            if not issubclass(arg.arg_class, ArrayBase):
                # Non-array arguments are passed through by name, unchecked.
                args.append(arg.name)
                continue

            gen("# {{{ process %s" % arg.name)
            gen("")

            if not options.no_numpy:
                self.handle_non_numpy_arg(gen, arg)

            # {{{ "must be supplied" checks

            if not options.skip_arg_checks and not is_written:
                gen("if %s is None:" % arg.name)
                with Indentation(gen):
                    gen("raise RuntimeError(\"input argument '%s' must "
                            'be supplied")' % arg.name)
                    gen("")

            if (is_written
                    and arg.arg_class is lp.ImageArg
                    and not options.skip_arg_checks):
                gen("if %s is None:" % arg.name)
                with Indentation(gen):
                    gen("raise RuntimeError(\"written image '%s' must "
                            'be supplied")' % arg.name)
                    gen("")

            if is_written and arg.shape is None and not options.skip_arg_checks:
                gen("if %s is None:" % arg.name)
                with Indentation(gen):
                    gen("raise RuntimeError(\"written argument '%s' has "
                            'unknown shape and must be supplied")' % arg.name)
                    gen("")

            # }}}

            possibly_made_by_loopy = False

            # {{{ allocate written arrays, if needed

            # Only written array/constant arguments whose shape is
            # completely known can be allocated by the generated code.
            if is_written and arg.arg_class in [lp.ArrayArg, lp.ConstantArg] \
                    and arg.shape is not None \
                    and all(si is not None for si in arg.shape):

                if not isinstance(arg.dtype, NumpyType):
                    raise LoopyError("do not know how to pass arg of type '%s'"
                            % arg.dtype)

                possibly_made_by_loopy = True
                gen("_lpy_made_by_loopy = False")
                gen("")

                gen("if %s is None:" % arg.name)
                with Indentation(gen):
                    self.handle_alloc(
                        gen, arg, kernel_arg, strify, options.skip_arg_checks)
                    gen("_lpy_made_by_loopy = True")
                    gen("")

            # }}}

            # {{{ argument checking

            if arg.arg_class in [lp.ArrayArg, lp.ConstantArg] \
                    and not options.skip_arg_checks:
                # Arrays allocated just above need no checking. "if True:"
                # keeps the emitted block structure uniform otherwise.
                if possibly_made_by_loopy:
                    gen("if not _lpy_made_by_loopy:")
                else:
                    gen("if True:")

                with Indentation(gen):
                    gen("if %s.dtype != %s:"
                            % (arg.name, self.python_dtype_str(
                                gen, kernel_arg.dtype.numpy_dtype)))
                    with Indentation(gen):
                        gen("raise TypeError(\"dtype mismatch on argument '%s' "
                                '(got: %%s, expected: %s)" %% %s.dtype)'
                                % (arg.name, arg.dtype, arg.name))

                    # {{{ generate shape checking code

                    shape_mismatch_msg = (
                            "raise TypeError(\"shape mismatch on argument '%s' "
                            '(got: %%s, expected: %%s)" '
                            "%% (%s.shape, %s))"
                            % (arg.name, arg.name, strify_tuple(arg.unvec_shape)))

                    if kernel_arg.shape is None:
                        # Shape entirely unknown: nothing to compare with.
                        pass

                    elif any(shape_axis is None for shape_axis in kernel_arg.shape):
                        # Partially known shape: check the number of axes,
                        # then each axis whose length is known.
                        gen("if len(%s.shape) != %s:"
                                % (arg.name, len(arg.unvec_shape)))
                        with Indentation(gen):
                            gen(shape_mismatch_msg)

                        for i, shape_axis in enumerate(arg.unvec_shape):
                            if shape_axis is None:
                                continue

                            gen("if %s.shape[%d] != %s:"
                                    % (arg.name, i, strify(shape_axis)))
                            with Indentation(gen):
                                gen(shape_mismatch_msg)

                    else:  # not None, no Nones in tuple
                        gen("if %s.shape != %s:"
                                % (arg.name, strify(arg.unvec_shape)))
                        with Indentation(gen):
                            gen(shape_mismatch_msg)

                    # }}}

                    if arg.unvec_strides and kernel_arg.dim_tags:
                        # unvec_strides are multiplied by the item size
                        # before being compared with the array's strides.
                        itemsize = kernel_arg.dtype.numpy_dtype.itemsize
                        sym_strides = tuple(
                                itemsize*s_i for s_i in arg.unvec_strides)

                        ndim = len(arg.unvec_shape)
                        shape = ["_lpy_shape_%d" % i for i in range(ndim)]
                        strides = ["_lpy_stride_%d" % i for i in range(ndim)]

                        # Unpack shape and strides into per-axis names that
                        # the strides check expression refers to.
                        gen("({},) = {}.shape".format(", ".join(shape), arg.name))
                        gen("({},) = {}.strides".format(
                            ", ".join(strides), arg.name))

                        gen("if not (%s):"
                                % self.get_strides_check_expr(
                                    shape, strides,
                                    (strify(s) for s in sym_strides)))
                        with Indentation(gen):
                            # The error message only reports strides of
                            # axes longer than one.
                            gen("_lpy_got = tuple(stride "
                                    "for (dim, stride) in zip(%s.shape, %s.strides) "
                                    "if dim > 1)"
                                    % (arg.name, arg.name))
                            gen("_lpy_expected = tuple(stride "
                                    "for (dim, stride) in zip(%s.shape, %s) "
                                    "if dim > 1)"
                                    % (arg.name, strify_tuple(sym_strides)))

                            gen('raise TypeError("strides mismatch on '
                                    "argument '%s' "
                                    "(after removing unit length dims, "
                                    'got: %%s, expected: %%s)" '
                                    "%% (_lpy_got, _lpy_expected))"
                                    % arg.name)

                    if not arg.allows_offset:
                        # hasattr guard: the passed object need not have an
                        # 'offset' attribute at all.
                        gen("if hasattr({}, 'offset') and {}.offset:".format(
                                arg.name, arg.name))
                        with Indentation(gen):
                            gen("raise ValueError(\"Argument '%s' does not "
                                    "allow arrays with offsets. Try passing "
                                    "default_offset=loopy.auto to make_kernel()."
                                    '")' % arg.name)
                            gen("")

            # }}}

            if possibly_made_by_loopy and not options.skip_arg_checks:
                gen("del _lpy_made_by_loopy")
                gen("")

            if arg.arg_class in [lp.ArrayArg, lp.ConstantArg]:
                args.append(self.get_arg_pass(arg))
            else:
                args.append("%s" % arg.name)

            gen("")

            gen("# }}}")
            gen("")

        gen("# }}}")
        gen("")

        return args