Beispiel #1
0
def pick_off_constants(expr):
    """Separate constant multipliers from the remaining factors of *expr*.

    :return: a tuple ``(constant, non_constant)``. If *expr* is a
        ``pp.Product``, every child that is constant (per ``pp.is_constant``)
        or is a ``p.Parameter`` goes into *constant*, and all other children
        go into *non_constant*; nested products are split recursively and
        their two halves distributed accordingly. Each half is combined with
        ``pp.flattened_product``. Any other node yields ``(1, expr)``.
    """
    if not isinstance(expr, pp.Product):
        return 1, expr

    const_factors = []
    other_factors = []

    for factor in expr.children:
        if isinstance(factor, pp.Product):
            # Recurse so constants buried in nested products are found too.
            nested_const, nested_rest = pick_off_constants(factor)
            const_factors.append(nested_const)
            other_factors.append(nested_rest)
            continue

        if pp.is_constant(factor) or isinstance(factor, p.Parameter):
            const_factors.append(factor)
        else:
            other_factors.append(factor)

    return (pp.flattened_product(const_factors),
            pp.flattened_product(other_factors))
Beispiel #2
0
    def find_substitution(expr):
        """Multiply a read of *var_name* by the common factors recorded for it.

        Returns *expr* unchanged unless it is a ``Variable`` or ``Subscript``
        naming *var_name* whose index key unifies with an entry of
        *common_factors* that carries factors; in that case those factors,
        with the unifier's substitution applied, are multiplied onto *expr*.
        """
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        # Nothing unified with this access: there is no unifier result to
        # use and no entry to index, so leave the access alone instead of
        # failing on ``unif_result.lmap`` / ``common_factors[None]``.
        if cf_index is None:
            return expr

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is None:
            return expr

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        return flattened_product(
                [unif_subst_map(cf) for cf in my_common_factors]
                + [expr])
Beispiel #3
0
    def get_temporary_decl(self, codegen_state, sched_index, temp_var,
                           decl_info):
        """Return the C declaration for the temporary *temp_var*.

        For private temporaries, the second component returned by
        ``get_grid_size_upper_bounds_as_exprs()`` (presumably the local size)
        is prepended to the shape. If the resulting shape is non-empty, the
        declaration is wrapped in ``ArrayOf`` with a length equal to the
        flattened product of that shape, rendered in index (``"i"``) context.
        """
        from loopy.target.c import POD  # uses the correct complex type

        decl = POD(self, decl_info.dtype, decl_info.name)
        full_shape = decl_info.shape

        if temp_var.scope == temp_var_scope.PRIVATE:
            # FIXME: This is a pretty coarse way of deciding what
            # private temporaries get duplicated. Refine? (See also
            # above in expr to code mapper)
            kernel = codegen_state.kernel
            _, local_size = kernel.get_grid_size_upper_bounds_as_exprs()
            full_shape = local_size + full_shape

        if not full_shape:
            return decl

        from cgen import ArrayOf
        to_code = self.get_expression_to_code_mapper(codegen_state)
        length = to_code(p.flattened_product(full_shape),
                         prec=PREC_NONE,
                         type_context="i")
        return ArrayOf(decl, length)
    def find_substitution(expr):
        """Multiply a read of *var_name* by the common factors recorded for it.

        Returns *expr* unchanged unless it is a ``Variable`` or ``Subscript``
        naming *var_name* whose index key unifies with an entry of
        *common_factors* that carries factors; in that case those factors,
        with the unifier's substitution applied, are multiplied onto *expr*.
        """
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        # Nothing unified with this access: there is no unifier result to
        # use and no entry to index, so leave the access alone instead of
        # failing on ``unif_result.lmap`` / ``common_factors[None]``.
        if cf_index is None:
            return expr

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is None:
            return expr

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        return flattened_product(
                [unif_subst_map(cf) for cf in my_common_factors]
                + [expr])
Beispiel #5
0
def find_inner_deriv_and_coeff(expr):
    """Split *expr* into ``(coefficient, derivative)``.

    A bare derivative binding yields ``(1, expr)``. For a ``pp.Product``, the
    flat factors (from ``get_flat_factors``) must contain exactly one
    derivative binding; it is returned together with the flattened product
    of all remaining factors as the coefficient.

    :raises ValueError: if the product contains more than one derivative
        binding or none at all, or if *expr* is neither a derivative binding
        nor a product.
    """
    if is_derivative_binding(expr):
        return 1, expr

    if not isinstance(expr, pp.Product):
        raise ValueError("unexpected node type '%s' inside "
                "second derivative in '%s'"
                % (type(expr).__name__, expr))

    derivatives = []
    coefficient_factors = []
    for factor in get_flat_factors(expr):
        bucket = (derivatives if is_derivative_binding(factor)
                  else coefficient_factors)
        bucket.append(factor)

    if len(derivatives) > 1:
        raise ValueError("multiplied second derivatives in '%s'"
                % expr)

    if not derivatives:
        # We'll only get called if there *is* a second derivative.
        # That we can't find it by picking apart the top-level
        # product is bad news.
        raise ValueError("second derivative inside nonlinearity "
                "in '%s'" % expr)

    (derivative,) = derivatives
    return pp.flattened_product(coefficient_factors), derivative
Beispiel #6
0
    def map_product(self, expr, derivatives):
        """Distribute a scalar coefficient over the derivatives found below.

        If *expr* has exactly one non-scalar factor, that factor is mapped
        recursively into a fresh dict; every operand collected there is
        multiplied by the product of the scalar factors and appended to
        *derivatives* under its operator, and ``factor * nonscalar`` is
        returned. Otherwise the whole product is handed to
        ``DerivativeJoiner``.
        """
        from grudge.symbolic.tools import is_scalar
        from pytools import partition
        scalars, nonscalars = partition(is_scalar, expr.children)

        if len(nonscalars) != 1:
            return DerivativeJoiner()(expr)

        from pymbolic import flattened_product
        factor = flattened_product(scalars)
        nonscalar, = nonscalars

        sub_derivatives = {}
        nonscalar = self.rec(nonscalar, sub_derivatives)

        for deriv_op, operands in sub_derivatives.items():
            for operand in operands:
                derivatives.setdefault(deriv_op, []).append(factor * operand)

        return factor * nonscalar
Beispiel #7
0
    def map_product(self, expr):
        """Turn a leading operator factor into an ``OperatorBinding``.

        An empty product is returned as-is. If the first factor is an
        ``op.Operator``, it is bound to the (recursively mapped) flattened
        product of the remaining factors; a warning is issued when that
        operand is itself a product of several factors. Otherwise the first
        factor and the product of the rest are mapped separately and
        multiplied.
        """
        if not expr.children:
            return expr

        from pymbolic.primitives import flattened_product, Product

        head = expr.children[0]
        tail = expr.children[1:]

        if not isinstance(head, op.Operator):
            return self.rec(head) * self.rec(flattened_product(tail))

        operand = flattened_product(tail)
        if isinstance(operand, Product) and len(operand.children) > 1:
            from warnings import warn
            warn("Binding '%s' to more than one "
                 "operand in a product is ambiguous - "
                 "use the parenthesized form instead." % head)
        return sym.OperatorBinding(head, self.rec(operand))
Beispiel #8
0
    def map_product(self, expr, type_context):
        """Map a product to C, using complex helper functions when needed.

        Defers to the base class when complex support is off, in index
        context (``type_context == "i"``), or when the product's inferred
        type is not complex. Otherwise the complex factors are combined with
        a balanced tree of ``<type>_mul`` calls, and the real factors (if
        any) are applied with a single ``<type>_rmul`` call.
        """
        def base_impl(expr, type_context):
            return super(ExpressionToCExpressionMapper,
                         self).map_product(expr, type_context)

        # I've added 'type_context == "i"' because of the following
        # idiotic corner case: Code generation for subscripts comes
        # through here, and it may involve variables that we know
        # nothing about (offsets and such). If we fall into the allow_complex
        # branch, we'll try to do type inference on these variables,
        # and stuff breaks. This band-aid works around that. -AK
        if not self.allow_complex or type_context == "i":
            return base_impl(expr, type_context)

        tgt_dtype = self.infer_type(expr)
        if not tgt_dtype.is_complex():
            return base_impl(expr, type_context)

        tgt_name = self.complex_type_name(tgt_dtype)

        reals = []
        complexes = []
        for child in expr.children:
            if self.infer_type(child).is_complex():
                complexes.append(child)
            else:
                reals.append(child)

        real_prd = p.flattened_product(
            [self.rec(r, type_context) for r in reals])

        c_applied = [
            self.rec(c, type_context, tgt_dtype) for c in complexes
        ]

        def binary_tree_mul(start, end):
            # Multiply c_applied[start:end] as a balanced binary tree.
            if start + 1 == end:
                return c_applied[start]
            mid = (start + end) // 2
            lprod = binary_tree_mul(start, mid)
            rprod = binary_tree_mul(mid, end)
            return var("%s_mul" % tgt_name)(lprod, rprod)

        complex_prd = binary_tree_mul(0, len(complexes))

        # Decide on the *presence* of real factors, not on the truthiness of
        # their product: a real product that is the constant 0 is falsy and
        # would otherwise be silently dropped, while an empty real part
        # (flattened to 1) would otherwise produce a pointless rmul by 1.
        if reals:
            return var("%s_rmul" % tgt_name)(real_prd, complex_prd)
        return complex_prd
Beispiel #9
0
    def get_temporary_decl(self, codegen_state, schedule_index, temp_var, decl_info):
        """Return the C declaration for the temporary *temp_var*.

        Read-only temporaries are wrapped in ``Const``. A non-empty shape
        turns the declaration into an ``ArrayOf`` whose length is the
        flattened product of the shape, rendered in index (``"i"``) context.
        """
        decl = POD(self, decl_info.dtype, decl_info.name)

        if temp_var.read_only:
            from cgen import Const
            decl = Const(decl)

        if not decl_info.shape:
            return decl

        from cgen import ArrayOf
        to_code = self.get_expression_to_code_mapper(codegen_state)
        length = to_code(p.flattened_product(decl_info.shape),
                         prec=PREC_NONE, type_context="i")
        return ArrayOf(decl, length)
Beispiel #10
0
    def get_temporary_decl(self, codegen_state, schedule_index, temp_var, decl_info):
        """Return the C declaration for the temporary *temp_var*.

        The declaration is built up in layers: ``Const`` for read-only
        temporaries, ``ArrayOf`` (length = flattened product of the shape,
        rendered in index context) for non-empty shapes, and finally
        ``AlignedAttribute`` when the temporary requests an alignment.
        """
        decl = POD(self, decl_info.dtype, decl_info.name)

        if temp_var.read_only:
            from cgen import Const
            decl = Const(decl)

        if decl_info.shape:
            from cgen import ArrayOf
            to_code = self.get_expression_to_code_mapper(codegen_state)
            length = to_code(p.flattened_product(decl_info.shape),
                             prec=PREC_NONE, type_context="i")
            decl = ArrayOf(decl, length)

        if temp_var.alignment:
            from cgen import AlignedAttribute
            decl = AlignedAttribute(temp_var.alignment, decl)

        return decl
Beispiel #11
0
    def get_temporary_decl(self, codegen_state, sched_index, temp_var, decl_info):
        """Return the C declaration for the temporary *temp_var*.

        For temporaries in the private address space, the second component
        returned by ``get_grid_size_upper_bounds_as_exprs()`` (presumably the
        local size) is prepended to the shape. A non-empty resulting shape
        turns the declaration into an ``ArrayOf`` whose length is the
        flattened product of that shape, rendered in index context.
        """
        from loopy.target.c import POD  # uses the correct complex type

        decl = POD(self, decl_info.dtype, decl_info.name)
        full_shape = decl_info.shape

        if temp_var.address_space == AddressSpace.PRIVATE:
            # FIXME: This is a pretty coarse way of deciding what
            # private temporaries get duplicated. Refine? (See also
            # above in expr to code mapper)
            kernel = codegen_state.kernel
            _, local_size = kernel.get_grid_size_upper_bounds_as_exprs()
            full_shape = local_size + full_shape

        if not full_shape:
            return decl

        from cgen import ArrayOf
        to_code = self.get_expression_to_code_mapper(codegen_state)
        length = to_code(p.flattened_product(full_shape),
                         prec=PREC_NONE, type_context="i")
        return ArrayOf(decl, length)
Beispiel #12
0
 def map_product(self, expr, *args, **kwargs):
     """Map each factor of *expr* and return their flattened product.

     Extra positional and keyword arguments are forwarded to ``self.rec``.
     """
     from pymbolic.primitives import flattened_product

     mapped_children = []
     for child in expr.children:
         mapped_children.append(self.rec(child, *args, **kwargs))

     return flattened_product(tuple(mapped_children))
Beispiel #13
0
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Factor out multiplicative factors shared by all updates of *var_name*
    and multiply them back in where *var_name* is read.

    :arg kernel: the kernel to transform. Substitution rules are expanded
        first because this transformation does not understand them.
    :arg var_name: name of a temporary variable or kernel argument.
    :arg vary_by_axes: axis indices (or axis names, if the array has
        ``dim_names``; a comma-separated string is also accepted) along which
        the common factors may differ. Assignments whose indices along these
        axes unify share one multiset of common factors.
    :returns: a copy of *kernel* with transformed instructions. Assignments
        with a zero right-hand side are left untouched.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: for unknown axis names, out-of-range axis indices,
        writes to *var_name* by non-``Assignment`` instructions, or
        right-hand-side terms that depend on *var_name* other than a bare
        occurrence of the assignee.
    """
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = dict(
                    (name, idx)
                    for idx, name in enumerate(var_descr.dim_names))
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # num_user_axes() is an axis *count*, so an index equal to it is
        # already out of bounds (it would fail later in extract_index_key).
        if (
                vary_by_axes
                and
                (min(vary_by_axes) < 0
                or
                max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from collections import Counter
    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
            UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found); the common factors are a
    # Counter (multiset) expressed in terms of the variables of index_key
    common_factors = []

    def find_unifiable_cf_index(index_key):
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return any(
                lhs == var_name
                for lhs, _ in insn.assignees_and_indices())

    def iterate_as(cls, expr):
        if isinstance(expr, cls):
            for ch in expr.children:
                yield ch
        else:
            yield expr

    def check_no_self_dependency(term, insn):
        # Raise if any factor of *term* depends on var_name.
        for part in iterate_as(Product, term):
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                        "in RHS of instruction '%s'"
                        % (var_name, insn.id))

    def factor_multiset(term):
        # A multiset rather than a set, so that repeated factors such as
        # the two x's in x*x*y are counted.
        return Counter(iterate_as(Product, term))

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-expression instruction"
                    % var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                check_no_self_dependency(term, insn)
                product_parts = factor_multiset(term)

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    # Counter '&' keeps the minimum multiplicity.
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            # Keep the key the factors were recorded against: the factors
            # are written in terms of *its* variables, and later lookups
            # unify against it.
            cf_key, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                    make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                check_no_self_dependency(term, insn)
                product_parts = factor_multiset(term)

                filtered = Counter()
                for cf, count in my_common_factors.items():
                    n_present = product_parts[unif_subst_map(cf)]
                    if n_present:
                        filtered[cf] = min(count, n_present)
                my_common_factors = filtered

            common_factors[cf_index] = (cf_key, my_common_factors)

            # }}}

    # }}}

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        mapped_my_common_factors = Counter(
                unif_subst_map(cf) for cf in my_common_factors.elements())

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            # Remove each common factor only as often as it was factored
            # out, so e.g. x*x*y with common factor x keeps one x.
            to_remove = mapped_my_common_factors.copy()
            remaining_parts = []
            for part in iterate_as(Product, term):
                if to_remove[part] > 0:
                    to_remove[part] -= 1
                else:
                    remaining_parts.append(part)

            new_sum_terms.append(flattened_product(remaining_parts))

        new_insns.append(
                insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        # No entry unifies with this access (e.g. it was only ever assigned
        # zero): unif_result is None here, so leave the access alone.
        if cf_index is None:
            return expr

        _, my_common_factors = common_factors[cf_index]

        if not my_common_factors:
            return expr

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        return flattened_product(
                [unif_subst_map(cf) for cf in my_common_factors.elements()]
                + [expr])

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
Beispiel #14
0
 def map_product(self, expr, *args, **kwargs):
     """Map each factor of *expr* and return their flattened product.

     Extra positional and keyword arguments are forwarded to ``self.rec``.
     """
     from pymbolic.primitives import flattened_product

     mapped_children = []
     for child in expr.children:
         mapped_children.append(self.rec(child, *args, **kwargs))

     return flattened_product(tuple(mapped_children))
Beispiel #15
0
    def map_product(self, expr):
        """Recursively map every factor and re-flatten the resulting product."""
        from pymbolic.primitives import flattened_product

        mapped_factors = [self.rec(factor) for factor in expr.children]
        return flattened_product(mapped_factors)
Beispiel #16
0
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Factor out multiplicative factors shared by all updates of *var_name*
    and multiply them back in where *var_name* is read.

    :arg kernel: a :class:`LoopKernel`. Substitution rules are expanded first
        because this transformation does not understand them.
    :arg var_name: name of a temporary variable or kernel argument.
    :arg vary_by_axes: axis indices (or axis names, if the array has
        ``dim_names``; a comma-separated string is also accepted) along which
        the common factors may differ. Assignments whose indices along these
        axes unify share one multiset of common factors.
    :returns: a copy of *kernel* with transformed instructions. Assignments
        with a zero right-hand side are left untouched.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: for unknown axis names, out-of-range axis indices,
        writes to *var_name* by non-``Assignment`` instructions, right-hand
        side terms that depend on *var_name* other than a bare occurrence of
        the assignee, or if no common factors are found at all.
    """
    assert isinstance(kernel, LoopKernel)
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = {
                name: idx
                for idx, name in enumerate(var_descr.dim_names)
            }
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # num_user_axes() is an axis *count*, so an index equal to it is
        # already out of bounds (it would fail later in extract_index_key).
        if (vary_by_axes
                and (min(vary_by_axes) < 0
                     or max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from collections import Counter
    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero, flattened_sum,
                                     flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
                                UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found); the common factors are a
    # Counter (multiset) expressed in terms of the variables of index_key
    common_factors = []

    def find_unifiable_cf_index(index_key):
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return var_name in insn.assignee_var_names()

    def iterate_as(cls, expr):
        if isinstance(expr, cls):
            yield from expr.children
        else:
            yield expr

    def check_no_self_dependency(term, insn):
        # Raise if any factor of *term* depends on var_name.
        for part in iterate_as(Product, term):
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                                 "in RHS of instruction '%s'" %
                                 (var_name, insn.id))

    def factor_multiset(term):
        # A multiset rather than a set, so that repeated factors such as
        # the two x's in x*x*y are counted.
        return Counter(iterate_as(Product, term))

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-single-assignment" %
                             var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                check_no_self_dependency(term, insn)
                product_parts = factor_multiset(term)

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    # Counter '&' keeps the minimum multiplicity.
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            # Keep the key the factors were recorded against: the factors
            # are written in terms of *its* variables, and later lookups
            # unify against it.
            cf_key, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                check_no_self_dependency(term, insn)
                product_parts = factor_multiset(term)

                filtered = Counter()
                for cf, count in my_common_factors.items():
                    n_present = product_parts[unif_subst_map(cf)]
                    if n_present:
                        filtered[cf] = min(count, n_present)
                my_common_factors = filtered

            common_factors[cf_index] = (cf_key, my_common_factors)

            # }}}

    # }}}

    common_factors = [(ik, cf) for ik, cf in common_factors if cf]

    if not common_factors:
        raise LoopyError("no common factors found")

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        mapped_my_common_factors = Counter(
            unif_subst_map(cf) for cf in my_common_factors.elements())

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            # Remove each common factor only as often as it was factored
            # out, so e.g. x*x*y with common factor x keeps one x.
            to_remove = mapped_my_common_factors.copy()
            remaining_parts = []
            for part in iterate_as(Product, term):
                if to_remove[part] > 0:
                    to_remove[part] -= 1
                else:
                    remaining_parts.append(part)

            new_sum_terms.append(flattened_product(remaining_parts))

        new_insns.append(insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        # No entry unifies with this access (e.g. its key had no common
        # factors and was filtered out above): unif_result is None here,
        # so leave the access alone.
        if cf_index is None:
            return expr

        _, my_common_factors = common_factors[cf_index]

        if not my_common_factors:
            return expr

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        return flattened_product(
            [unif_subst_map(cf) for cf in my_common_factors.elements()]
            + [expr])

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
Beispiel #17
0
 def map_product(self, expr):
     """Recursively map every factor and re-flatten the resulting product."""
     from pymbolic.primitives import flattened_product

     mapped_factors = [self.rec(factor) for factor in expr.children]
     return flattened_product(mapped_factors)