Пример #1
0
    def map_sum(self, expr):
        """Map each child of a sum, reusing *expr* when nothing changed.

        When every recursively mapped child is the very same object as the
        corresponding original child, the original sum is returned so that
        object identity is preserved.  Otherwise a new flattened sum is built
        from the mapped children.
        """
        new_children = [self.rec(child) for child in expr.children]

        unchanged = True
        for new_child, old_child in zip(new_children, expr.children):
            if new_child is not old_child:
                unchanged = False
                break

        if unchanged:
            return expr

        from pymbolic.primitives import flattened_sum
        return flattened_sum(new_children)
Пример #2
0
    def map_sum(self, expr, type_context):
        """Generate code for a sum, routing complex-valued sums to helpers.

        The base class implementation handles the sum when complex support is
        disabled, when generating index code (``type_context == "i"``), or
        when the inferred type of the sum is not complex.  For a complex sum,
        the real children are combined with a flattened sum, the complex
        children are combined by a balanced tree of ``<type>_add`` calls,
        and the two parts are joined with ``<type>_radd`` when the real part
        is truthy.
        """
        def base_impl(expr, type_context):
            return super(ExpressionToCExpressionMapper,
                         self).map_sum(expr, type_context)

        # 'type_context == "i"' is checked because of a corner case: code
        # generation for subscripts comes through here and may involve
        # variables that nothing is known about (offsets and such).  Taking
        # the allow_complex branch would run type inference on them and
        # break, so this check works around that. -AK
        if not self.allow_complex or type_context == "i":
            return base_impl(expr, type_context)

        tgt_dtype = self.infer_type(expr)
        if not tgt_dtype.is_complex():
            return base_impl(expr, type_context)

        tgt_name = self.complex_type_name(tgt_dtype)

        real_children = []
        complex_children = []
        for child in expr.children:
            if self.infer_type(child).is_complex():
                complex_children.append(child)
            else:
                real_children.append(child)

        real_sum = p.flattened_sum(
            [self.rec(child, type_context) for child in real_children])

        complex_terms = [
            self.rec(child, type_context, tgt_dtype)
            for child in complex_children]

        def tree_sum(terms):
            # Split at the midpoint so that nesting of the *_add calls stays
            # balanced (logarithmic depth).
            if len(terms) == 1:
                return terms[0]
            half = len(terms) // 2
            return var("%s_add" % tgt_name)(
                tree_sum(terms[:half]), tree_sum(terms[half:]))

        complex_sum = tree_sum(complex_terms)

        if not real_sum:
            return complex_sum
        return var("%s_radd" % tgt_name)(real_sum, complex_sum)
Пример #3
0
    def rebuild_optemplate(self):
        """Reassemble the operator template as a flattened sum of flux terms.

        Each entry of ``self.interiors`` contributes a flux operator applied
        to its ``field_expr``.  Each entry of ``self.boundaries`` contributes
        a boundary flux operator applied to its ``bpair``.  The quadrature
        variants of both operators are used whenever ``self.quadrature_tag``
        is not None.
        """
        from pymbolic.primitives import flattened_sum

        quad_tag = self.quadrature_tag
        summands = []

        for interior in self.interiors:
            if quad_tag is None:
                flux_op = FluxOperator(interior.flux_expr, self.is_lift)
            else:
                flux_op = QuadratureFluxOperator(interior.flux_expr, quad_tag)
            summands.append(flux_op(interior.field_expr))

        for bdry in self.boundaries:
            if quad_tag is None:
                flux_op = BoundaryFluxOperator(
                    bdry.flux_expr, bdry.bpair.tag, self.is_lift)
            else:
                flux_op = QuadratureBoundaryFluxOperator(
                    bdry.flux_expr, quad_tag, bdry.bpair.tag)
            summands.append(flux_op(bdry.bpair))

        return flattened_sum(summands)
Пример #4
0
    def map_operator_binding(self, expr):
        """Add rank-boundary flux contributions to bound interior fluxes.

        A flux operator bound directly to a field (rather than to a
        :class:`BoundaryPair`) is rewritten as the flattened sum of:

        - the original binding, plus
        - for every rank in ``self.interacting_ranks``, the same operator
          bound to ``BoundaryPair(field, exchanged, TAG_RANK_BOUNDARY(rank))``.

        Here ``exchanged`` applies a :class:`FluxExchangeOperator` for that
        rank to the field, component by component when the field is an
        object array.

        Everything else is delegated to
        :meth:`IdentityMapper.map_operator_binding`: non-flux operators,
        fluxes already bound to a boundary pair, and input that is not an
        :class:`OperatorBinding`.
        """
        from hedge.optemplate import \
                FluxOperatorBase, \
                BoundaryPair, OperatorBinding, \
                FluxExchangeOperator

        is_bound_interior_flux = (
                isinstance(expr, OperatorBinding)
                and isinstance(expr.op, FluxOperatorBase)
                # we're only worried about internal fluxes
                and not isinstance(expr.field, BoundaryPair))

        if not is_bound_interior_flux:
            # Non-OperatorBinding input is also handled here instead of
            # implicitly returning None from a mapper method.
            return IdentityMapper.map_operator_binding(self, expr)

        def func_on_scalar_or_vector(func, arg_fields):
            # No CSE necessary here--the compiler CSE's these
            # automatically.

            from hedge.tools import is_obj_array, make_obj_array
            if is_obj_array(arg_fields):
                # arg_fields (as an object array) isn't hashable
                # --make it so by turning it into a tuple
                arg_fields = tuple(arg_fields)

                return make_obj_array([
                    func(i, arg_fields) for i in range(len(arg_fields))])
            else:
                return func(0, (arg_fields,))

        from hedge.mesh import TAG_RANK_BOUNDARY

        def exchange_and_cse(rank):
            return func_on_scalar_or_vector(
                    lambda i, args: FluxExchangeOperator(i, rank, args),
                    expr.field)

        from pymbolic.primitives import flattened_sum
        return flattened_sum([expr] + [
            OperatorBinding(
                expr.op,
                BoundaryPair(expr.field, exchange_and_cse(rank),
                             TAG_RANK_BOUNDARY(rank)))
            for rank in self.interacting_ranks])
Пример #5
0
    def map_operator_binding(self, expr):
        """Add rank-boundary flux contributions to bound interior fluxes.

        A flux operator bound directly to a field (rather than to a
        :class:`BoundaryPair`) is rewritten as the flattened sum of:

        - the original binding, plus
        - for every rank in ``self.interacting_ranks``, the same operator
          bound to ``BoundaryPair(field, exchanged, TAG_RANK_BOUNDARY(rank))``.

        Here ``exchanged`` applies a :class:`FluxExchangeOperator` for that
        rank to the field, component by component when the field is an
        object array.

        Everything else is delegated to
        :meth:`IdentityMapper.map_operator_binding`: non-flux operators,
        fluxes already bound to a boundary pair, and input that is not an
        :class:`OperatorBinding`.
        """
        from hedge.optemplate import \
                FluxOperatorBase, \
                BoundaryPair, OperatorBinding, \
                FluxExchangeOperator

        is_bound_interior_flux = (
                isinstance(expr, OperatorBinding)
                and isinstance(expr.op, FluxOperatorBase)
                # we're only worried about internal fluxes
                and not isinstance(expr.field, BoundaryPair))

        if not is_bound_interior_flux:
            # Non-OperatorBinding input is also handled here instead of
            # implicitly returning None from a mapper method.
            return IdentityMapper.map_operator_binding(self, expr)

        def func_on_scalar_or_vector(func, arg_fields):
            # No CSE necessary here--the compiler CSE's these
            # automatically.

            from hedge.tools import is_obj_array, make_obj_array
            if is_obj_array(arg_fields):
                # arg_fields (as an object array) isn't hashable
                # --make it so by turning it into a tuple
                arg_fields = tuple(arg_fields)

                return make_obj_array([
                    func(i, arg_fields)
                    for i in range(len(arg_fields))])
            else:
                return func(0, (arg_fields,))

        from hedge.mesh import TAG_RANK_BOUNDARY

        def exchange_and_cse(rank):
            return func_on_scalar_or_vector(
                    lambda i, args: FluxExchangeOperator(i, rank, args),
                    expr.field)

        from pymbolic.primitives import flattened_sum
        return flattened_sum([expr] + [
            OperatorBinding(
                expr.op,
                BoundaryPair(expr.field, exchange_and_cse(rank),
                             TAG_RANK_BOUNDARY(rank)))
            for rank in self.interacting_ranks])
Пример #6
0
    def rebuild_optemplate(self):
        """Reassemble the operator template as a flattened sum of flux terms.

        Interior fluxes become (quadrature) flux operators applied to their
        field expressions.  Boundary fluxes become (quadrature) boundary flux
        operators applied to their boundary pairs.  The quadrature variants
        are chosen whenever ``self.quadrature_tag`` is not None.
        """
        from pymbolic.primitives import flattened_sum

        quad_tag = self.quadrature_tag
        summands = []

        for interior in self.interiors:
            if quad_tag is None:
                flux_op = FluxOperator(interior.flux_expr, self.is_lift)
            else:
                flux_op = QuadratureFluxOperator(interior.flux_expr, quad_tag)
            summands.append(flux_op(interior.field_expr))

        for bdry in self.boundaries:
            if quad_tag is None:
                flux_op = BoundaryFluxOperator(
                    bdry.flux_expr, bdry.bpair.tag, self.is_lift)
            else:
                flux_op = QuadratureBoundaryFluxOperator(
                    bdry.flux_expr, quad_tag, bdry.bpair.tag)
            summands.append(flux_op(bdry.bpair))

        return flattened_sum(summands)
Пример #7
0
    def map_derivative_source(self, expr):
        """Expand a derivative source into a sum of per-axis derivatives.

        The operand is mapped recursively.  The NablaComponents reported for
        it by ``self.derivative_collector`` determine the number of axes:
        one more than the largest ``ambient_axis``.  The result is the
        flattened sum, over those axes, of ``self.take_derivative(axis, ...)``
        applied to the operand mapped through
        ``self.nabla_component_to_unit_vector(expr.nabla_id, axis)``.

        :raises RuntimeError: if the collector yields something that is
            neither a NablaComponent nor a DerivativeSource.
        :raises ValueError: if the operand contains no NablaComponent, so
            that there is no axis to differentiate along.
        """
        rec_operand = self.rec(expr.operand)

        nablas = []
        for d_or_n in self.derivative_collector(rec_operand):
            if isinstance(d_or_n, prim.NablaComponent):
                nablas.append(d_or_n)
            elif isinstance(d_or_n, prim.DerivativeSource):
                # derivative sources need no handling here
                pass
            else:
                raise RuntimeError("unexpected result from "
                        "DerivativeSourceAndNablaComponentCollector")

        if not nablas:
            # Without this check, max() below fails with an opaque
            # "empty sequence" ValueError.
            raise ValueError(
                    "no NablaComponent found in the operand of derivative "
                    "source '%s'" % str(expr))

        n_axes = max(n.ambient_axis for n in nablas) + 1
        assert n_axes

        from pymbolic.primitives import flattened_sum
        return flattened_sum(
                self.take_derivative(
                    axis,
                    self.nabla_component_to_unit_vector(expr.nabla_id, axis)
                    (rec_operand))
                for axis in range(n_axes))
Пример #8
0
    def map_sum(self, expr):
        """Join inner derivatives of a sum's children into combined terms.

        Every child is run through a single ``_InnerDerivativeJoiner``.
        Children for which the joiner collects no derivatives are kept
        unchanged.  Otherwise the joiner's result replaces the child, and the
        collected operands are pooled per derivative operator.  For each
        operator, a term ``operator(sum of mapped operands)`` is prepended to
        the children, and the flattened sum of everything is returned.
        """
        from pymbolic.primitives import flattened_sum

        joiner = _InnerDerivativeJoiner()
        pooled = {}

        def join_child(child):
            found = {}
            joined = joiner(child, found)
            if not found:
                return child
            for op, op_operands in found.items():
                pooled.setdefault(op, []).extend(op_operands)
            return joined

        rewritten = [join_child(child) for child in expr.children]

        combined = [
            op(sum(self.rec(operand) for operand in op_operands))
            for op, op_operands in pooled.items()]

        # Each combined term is prepended in turn, so the last one built
        # ends up first.
        return flattened_sum(combined[::-1] + rewritten)
Пример #9
0
    def map_derivative_source(self, expr):
        """Expand a derivative source into a sum of per-axis derivatives.

        The operand is mapped recursively.  The NablaComponents reported for
        it by ``self.derivative_collector`` determine the number of axes:
        one more than the largest ``ambient_axis``.  The result is the
        flattened sum, over those axes, of ``self.take_derivative(axis, ...)``
        applied to the operand mapped through
        ``self.nabla_component_to_unit_vector(expr.nabla_id, axis)``.

        :raises RuntimeError: if the collector yields something that is
            neither a NablaComponent nor a DerivativeSource.
        :raises ValueError: if the operand contains no NablaComponent, so
            that there is no axis to differentiate along.
        """
        rec_operand = self.rec(expr.operand)

        nablas = []
        for d_or_n in self.derivative_collector(rec_operand):
            if isinstance(d_or_n, prim.NablaComponent):
                nablas.append(d_or_n)
            elif isinstance(d_or_n, prim.DerivativeSource):
                # derivative sources need no handling here
                pass
            else:
                raise RuntimeError(
                    "unexpected result from "
                    "DerivativeSourceAndNablaComponentCollector")

        if not nablas:
            # Without this check, max() below fails with an opaque
            # "empty sequence" ValueError.
            raise ValueError(
                "no NablaComponent found in the operand of derivative "
                "source '%s'" % str(expr))

        n_axes = max(n.ambient_axis for n in nablas) + 1
        assert n_axes

        from pymbolic.primitives import flattened_sum
        return flattened_sum(
            self.take_derivative(
                axis,
                self.nabla_component_to_unit_vector(expr.nabla_id, axis)(
                    rec_operand)) for axis in range(n_axes))
Пример #10
0
 def map_sum(self, expr, *args, **kwargs):
     """Rebuild a sum from its recursively mapped children.

     Any extra positional and keyword arguments are forwarded unchanged to
     ``self.rec`` for every child.
     """
     from pymbolic.primitives import flattened_sum
     mapped_children = []
     for child in expr.children:
         mapped_children.append(self.rec(child, *args, **kwargs))
     return flattened_sum(tuple(mapped_children))
Пример #11
0
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Move factors shared by all increments of *var_name* to its uses.

    Each :class:`Assignment` to *var_name* has a right-hand side that is a
    sum; terms equal to the assignee are skipped.  For the remaining terms,
    the set of product factors common to all of them is determined.
    Assignments whose index keys (the assignee's indices along
    *vary_by_axes*) unify share one set of common factors.  Those factors
    are removed from the increments.  They are then multiplied onto every
    access to *var_name* in :class:`Assignment` instructions that do not
    themselves assign *var_name*.

    :arg kernel: the kernel to transform.  Substitution rules are expanded
        first, since this transformation does not understand them.
    :arg var_name: name of a temporary variable or kernel argument.
    :arg vary_by_axes: axes along which the common factors may differ.
        Axes are given as indices or, for arrays with ``dim_names``, as
        names.  A comma-separated string of names is also accepted.
    :returns: a copy of *kernel* with rewritten instructions.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: in any of these cases:

        - an axis name is unknown;
        - an axis index is out of bounds;
        - *var_name* is written by a non-:class:`Assignment` instruction;
        - an RHS factor itself depends on *var_name*.
    """
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = {
                    name: idx
                    for idx, name in enumerate(var_descr.dim_names)}
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # Valid axis indices are 0 .. num_user_axes()-1; an index equal to
        # num_user_axes() would make extract_index_key() go out of range.
        if (
                vary_by_axes
                and
                (min(vary_by_axes) < 0
                or
                max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
            UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found)
    common_factors = []

    def find_unifiable_cf_index(index_key):
        # Return (position in common_factors, unification result) for the
        # first stored key that unifies with index_key, or (None, None).
        for i, (key, val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        # The indices of access_expr along vary_by_axes; () for a scalar.
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return any(
                lhs == var_name
                for lhs, _ in insn.assignees_and_indices())

    def iterate_as(cls, expr):
        # Children of expr if it is a cls node, else expr itself.
        if isinstance(expr, cls):
            for ch in expr.children:
                yield ch
        else:
            yield expr

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-expression instruction"
                    % var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                for part in iterate_as(Product, term):
                    if var_name in get_dependencies(part):
                        raise LoopyError("unexpected dependency on '%s' "
                                "in RHS of instruction '%s'"
                                % (var_name, insn.id))

                product_parts = set(iterate_as(Product, term))

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            _, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                    make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                for part in iterate_as(Product, term):
                    if var_name in get_dependencies(part):
                        raise LoopyError("unexpected dependency on '%s' "
                                "in RHS of instruction '%s'"
                                % (var_name, insn.id))

                product_parts = set(iterate_as(Product, term))

                my_common_factors = set(
                        cf for cf in my_common_factors
                        if unif_subst_map(cf) in product_parts)

            common_factors[cf_index] = (index_key, my_common_factors)

            # }}}

    # }}}

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        mapped_my_common_factors = set(
                unif_subst_map(cf)
                for cf in my_common_factors)

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            new_sum_terms.append(
                    flattened_product([
                        part
                        for part in iterate_as(Product, term)
                        if part not in mapped_my_common_factors
                        ]))

        new_insns.append(
                insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # No recorded common factors unify with this access (e.g. every
            # matching write assigns zero), so there is nothing to
            # multiply in.
            return expr

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is not None:
            return flattened_product(
                    [unif_subst_map(cf) for cf in my_common_factors]
                    + [expr])
        else:
            return expr

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
Пример #12
0
    def map_product(self, expr):
        """Expand derivative sources in a product into a sum over axes.

        For each child of the product, ``self.derivative_collector`` reports
        the NablaComponents and DerivativeSources it contains.  Each child
        can hold a DerivativeSource whose nabla id also occurs as a
        NablaComponent in the product; if ``self.restrict_to_id`` is set,
        the id must also match it.  For every such child, each current
        product term is replaced by one term per ambient axis.

        In each new term:

        - the source child is mapped through
          ``self.derivative_source_finder(nabla_id, self, axis)``;
        - the children holding that id's NablaComponents are mapped through
          ``self.nabla_component_to_unit_vector(nabla_id, axis)``.

        The result is the flattened sum of all terms.  Each term is rebuilt
        with ``type(expr)`` from factors passed through ``self.rec``.

        :raises RuntimeError: if the collector yields an unexpected object.
        :raises ValueError: if the product contains NablaComponents but no
            child contains a DerivativeSource.
        :raises NotImplementedError: if one child contains DerivativeSources
            for more than one nabla id.
        """
        # {{{ gather NablaComponents and DerivativeSources

        # one set of DerivativeSource nabla ids per child, in child order
        d_source_nabla_ids_per_child = []

        # id to set((child index, axis), ...)
        nabla_finder = {}

        # NOTE: rec_child is the raw child; it is not mapped through
        # self.rec at this point.
        for child_idx, rec_child in enumerate(expr.children):
            # NOTE(review): nabla_component_ids is filled but never read in
            # this method.
            nabla_component_ids = set()
            derivative_source_ids = set()

            nablas = []
            for d_or_n in self.derivative_collector(rec_child):
                if isinstance(d_or_n, prim.NablaComponent):
                    nabla_component_ids.add(d_or_n.nabla_id)
                    nablas.append(d_or_n)
                elif isinstance(d_or_n, prim.DerivativeSource):
                    derivative_source_ids.add(d_or_n.nabla_id)
                else:
                    raise RuntimeError(
                        "unexpected result from "
                        "DerivativeSourceAndNablaComponentCollector")

            d_source_nabla_ids_per_child.append(derivative_source_ids)

            for ncomp in nablas:
                nabla_finder.setdefault(ncomp.nabla_id, set()).add(
                    (child_idx, ncomp.ambient_axis))

        # }}}

        # NablaComponents with no DerivativeSource anywhere in the product
        # cannot be resolved.
        if nabla_finder and not any(d_source_nabla_ids_per_child):
            raise ValueError(
                "no derivative source found to resolve in '%s'"
                "--did you forget to wrap the term that should have its "
                "derivative taken in 'Derivative()(term)'?" % str(expr))

        # a list of lists, the outer level representing a sum, the inner a
        # product
        result = [list(expr.children)]

        for child_idx, (d_source_nabla_ids, child) in enumerate(
                zip(d_source_nabla_ids_per_child, expr.children)):
            # this child holds no derivative source: nothing to expand
            if not d_source_nabla_ids:
                continue

            if len(d_source_nabla_ids) > 1:
                raise NotImplementedError("more than one DerivativeSource per "
                                          "child in a product")

            nabla_id, = d_source_nabla_ids
            # a source with no matching NablaComponent in this product is
            # left unexpanded here
            try:
                nablas = nabla_finder[nabla_id]
            except KeyError:
                continue

            # only expand the nabla id selected by restrict_to_id, if set
            if self.restrict_to_id is not None and nabla_id != self.restrict_to_id:
                continue

            n_axes = max(axis for _, axis in nablas) + 1

            # Every current product term yields n_axes new terms, one per
            # axis.
            new_result = []
            for prod_term_list in result:
                for axis in range(n_axes):
                    new_ptl = prod_term_list[:]
                    dsfinder = self.derivative_source_finder(
                        nabla_id, self, axis)

                    new_ptl[child_idx] = dsfinder(new_ptl[child_idx])
                    # A child listed under several axes for this id is
                    # mapped once per listing.
                    for nabla_child_index, _ in nablas:
                        new_ptl[nabla_child_index] = \
                                self.nabla_component_to_unit_vector(nabla_id, axis)(
                                        new_ptl[nabla_child_index])

                    new_result.append(new_ptl)

            result = new_result

        # rebuild each product with the original product type, recursing
        # into every factor
        from pymbolic.primitives import flattened_sum
        return flattened_sum(
            type(expr)(tuple(
                self.rec(prod_term) for prod_term in prod_term_list))
            for prod_term_list in result)
Пример #13
0
    def emit_assignment(self, codegen_state, insn):
        """Emit ISPC code for the assignment *insn*.

        The right-hand side is generated in the type context of the
        assignee's dtype.

        An instruction tagged ``!streaming_store`` is emitted as
        ``streaming_store(<array> + <offset>, <value>)``, after two checks:

        - the flattened subscript contains local axis 0 exactly once, as a
          bare term (stride 1);
        - no short-vector type is used.

        The value is wrapped in ``broadcast(..., 0)`` when the RHS does not
        depend on local axis 0.  All other instructions become a plain
        assignment.

        :raises NotImplementedError: for atomic instructions.
        :raises LoopyError: if a streaming store's array type, subscript,
            stride, or data type is unsupported.
        """
        kernel = codegen_state.kernel
        ecm = codegen_state.expression_to_code_mapper

        assignee_var_name, = insn.assignee_var_names()

        lhs_var = codegen_state.kernel.get_var_descriptor(assignee_var_name)
        lhs_dtype = lhs_var.dtype

        if insn.atomicity:
            raise NotImplementedError("atomic ops in ISPC")

        from loopy.expression import dtype_to_type_context
        from pymbolic.mapper.stringifier import PREC_NONE

        rhs_type_context = dtype_to_type_context(kernel.target, lhs_dtype)
        rhs_code = ecm(insn.expression,
                       prec=PREC_NONE,
                       type_context=rhs_type_context,
                       needed_dtype=lhs_dtype)

        lhs = insn.assignee

        # {{{ handle streaming stores

        if "!streaming_store" in insn.tags:
            ary = ecm.find_array(lhs)

            from loopy.kernel.array import get_access_info
            from pymbolic import evaluate

            from loopy.symbolic import simplify_using_aff
            index_tuple = tuple(
                simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)

            access_info = get_access_info(
                kernel.target, ary, index_tuple,
                # use the codegen state passed to this method
                lambda expr: evaluate(expr, codegen_state.var_subst_map),
                codegen_state.vectorization_info)

            from loopy.kernel.data import GlobalArg, TemporaryVariable

            if not isinstance(ary, (GlobalArg, TemporaryVariable)):
                raise LoopyError("array type not supported in ISPC: %s" %
                                 type(ary).__name__)

            if len(access_info.subscripts) != 1:
                raise LoopyError("streaming stores must have a subscript")
            subscript, = access_info.subscripts

            from pymbolic.primitives import Sum, flattened_sum, Variable
            if isinstance(subscript, Sum):
                terms = subscript.children
            else:
                # a non-sum subscript is itself the single term
                terms = (subscript,)

            new_terms = []

            from loopy.kernel.data import LocalIndexTag
            from loopy.symbolic import get_dependencies

            def is_local_axis0(iname):
                # True if iname is tagged as local index axis 0.
                tag = kernel.iname_to_tag.get(iname)
                return isinstance(tag, LocalIndexTag) and tag.axis == 0

            saw_l0 = False
            for term in terms:
                if isinstance(term, Variable) and is_local_axis0(term.name):
                    if saw_l0:
                        raise LoopyError("streaming store must have stride 1 "
                                         "in local index, got: %s" % subscript)
                    saw_l0 = True
                    continue

                # any other appearance of local axis 0 means stride != 1
                for dep in get_dependencies(term):
                    if is_local_axis0(dep):
                        raise LoopyError(
                            "streaming store must have stride 1 "
                            "in local index, got: %s" % subscript)

                new_terms.append(term)

            if not saw_l0:
                raise LoopyError("streaming store must have stride 1 in "
                                 "local index, got: %s" % subscript)

            if access_info.vector_index is not None:
                raise LoopyError("streaming store may not use a short-vector "
                                 "data type")

            rhs_has_programindex = any(
                is_local_axis0(dep)
                for dep in get_dependencies(insn.expression))

            if not rhs_has_programindex:
                rhs_code = "broadcast(%s, 0)" % rhs_code

            from cgen import Statement
            return Statement(
                "streaming_store(%s + %s, %s)" %
                (access_info.array_name,
                 ecm(flattened_sum(new_terms), PREC_NONE, 'i'), rhs_code))

        # }}}

        from cgen import Assign
        return Assign(ecm(lhs, prec=PREC_NONE, type_context=None), rhs_code)
Пример #14
0
    def map_sum(self, expr):
        """Return the flattened sum of the recursively mapped children of *expr*."""
        from pymbolic.primitives import flattened_sum

        mapped_children = [self.rec(ch) for ch in expr.children]
        return flattened_sum(mapped_children)
Пример #15
0
 def map_sum(self, expr):
     """Map every child of the sum and recombine them with flattened_sum."""
     from pymbolic.primitives import flattened_sum
     mapped = []
     for child in expr.children:
         mapped.append(self.rec(child))
     return flattened_sum(tuple(mapped))
Пример #16
0
 def map_polynomial(self, expr, enclosing_prec, *args, **kwargs):
     """Map a polynomial by expanding it into an explicit sum of terms.

     ``expr.data`` is traversed in reverse order, producing one
     ``coeff * base**exp`` term per ``(exp, coeff)`` pair.  The resulting
     flattened sum is passed back to ``self.rec`` with the same enclosing
     precedence and extra arguments.
     """
     from pymbolic.primitives import flattened_sum
     terms = []
     for exp, coeff in expr.data[::-1]:
         terms.append(coeff*expr.base**exp)
     expanded = flattened_sum(terms)
     return self.rec(expanded, enclosing_prec, *args, **kwargs)
Пример #17
0
    def map_product(self, expr):
        """Expand derivative sources in a product into a sum over axes.

        For each child of the product, ``self.derivative_collector`` reports
        the NablaComponents and DerivativeSources it contains.  Each child
        can hold a DerivativeSource whose nabla id also occurs as a
        NablaComponent in the product; if ``self.restrict_to_id`` is set,
        the id must also match it.  For every such child, each current
        product term is replaced by one term per ambient axis.

        In each new term:

        - the source child is mapped through
          ``self.derivative_source_finder(nabla_id, self, axis)``;
        - the children holding that id's NablaComponents are mapped through
          ``self.nabla_component_to_unit_vector(nabla_id, axis)``.

        The result is the flattened sum of all terms.  Each term is rebuilt
        with ``type(expr)`` from factors passed through ``self.rec``.

        :raises RuntimeError: if the collector yields an unexpected object.
        :raises ValueError: if the product contains NablaComponents but no
            child contains a DerivativeSource.
        :raises NotImplementedError: if one child contains DerivativeSources
            for more than one nabla id.
        """
        # {{{ gather NablaComponents and DerivativeSources

        # one set of DerivativeSource nabla ids per child, in child order
        d_source_nabla_ids_per_child = []

        # id to set((child index, axis), ...)
        nabla_finder = {}

        # NOTE: rec_child is the raw child; it is not mapped through
        # self.rec at this point.
        for child_idx, rec_child in enumerate(expr.children):
            # NOTE(review): nabla_component_ids is filled but never read in
            # this method.
            nabla_component_ids = set()
            derivative_source_ids = set()

            nablas = []
            for d_or_n in self.derivative_collector(rec_child):
                if isinstance(d_or_n, prim.NablaComponent):
                    nabla_component_ids.add(d_or_n.nabla_id)
                    nablas.append(d_or_n)
                elif isinstance(d_or_n, prim.DerivativeSource):
                    derivative_source_ids.add(d_or_n.nabla_id)
                else:
                    raise RuntimeError("unexpected result from "
                            "DerivativeSourceAndNablaComponentCollector")

            d_source_nabla_ids_per_child.append(derivative_source_ids)

            for ncomp in nablas:
                nabla_finder.setdefault(
                        ncomp.nabla_id, set()).add((child_idx, ncomp.ambient_axis))

        # }}}

        # NablaComponents with no DerivativeSource anywhere in the product
        # cannot be resolved.
        if nabla_finder and not any(d_source_nabla_ids_per_child):
            raise ValueError("no derivative source found to resolve in '%s'"
                    "--did you forget to wrap the term that should have its "
                    "derivative taken in 'Derivative()(term)'?" % str(expr))

        # a list of lists, the outer level representing a sum, the inner a
        # product
        result = [list(expr.children)]

        for child_idx, (d_source_nabla_ids, child) in enumerate(
                zip(d_source_nabla_ids_per_child, expr.children)):
            # this child holds no derivative source: nothing to expand
            if not d_source_nabla_ids:
                continue

            if len(d_source_nabla_ids) > 1:
                raise NotImplementedError("more than one DerivativeSource per "
                        "child in a product")

            nabla_id, = d_source_nabla_ids
            # a source with no matching NablaComponent in this product is
            # left unexpanded here
            try:
                nablas = nabla_finder[nabla_id]
            except KeyError:
                continue

            # only expand the nabla id selected by restrict_to_id, if set
            if self.restrict_to_id is not None and nabla_id != self.restrict_to_id:
                continue

            n_axes = max(axis for _, axis in nablas) + 1

            # Every current product term yields n_axes new terms, one per
            # axis.
            new_result = []
            for prod_term_list in result:
                for axis in range(n_axes):
                    new_ptl = prod_term_list[:]
                    dsfinder = self.derivative_source_finder(nabla_id, self, axis)

                    new_ptl[child_idx] = dsfinder(new_ptl[child_idx])
                    # A child listed under several axes for this id is
                    # mapped once per listing.
                    for nabla_child_index, _ in nablas:
                        new_ptl[nabla_child_index] = \
                                self.nabla_component_to_unit_vector(nabla_id, axis)(
                                        new_ptl[nabla_child_index])

                    new_result.append(new_ptl)

            result = new_result

        # rebuild each product with the original product type, recursing
        # into every factor
        from pymbolic.primitives import flattened_sum
        return flattened_sum(
                    type(expr)(tuple(
                        self.rec(prod_term) for prod_term in prod_term_list))
                    for prod_term_list in result)
Пример #18
0
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Move factors common to all increments of *var_name* to its read sites.

    Every :class:`Assignment` writing *var_name* is inspected. Its right-hand
    side is viewed as a sum; each summand other than the assignee itself (the
    ``x = x + ...`` self term) is viewed as a product. Factors shared by all
    such products, across all writes whose index key (the indices along
    *vary_by_axes*) unifies with the same key, are removed from those writes.
    Then every read of *var_name* in the other assignments is multiplied by
    those factors.

    :arg kernel: a :class:`LoopKernel`. Substitution rules are expanded first
        because this transformation does not understand them.
    :arg var_name: name of a temporary variable or kernel argument.
    :arg vary_by_axes: axes along which the common factors may differ, given
        as integer axis indices and/or dimension names, or as a single
        comma-separated string of dimension names. Names are resolved (and
        bounds checked) only when the variable is an ``ArrayBase``.
    :returns: a copy of *kernel* with the rewritten instructions.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: if an axis name is unknown or an axis index is out of
        bounds, if *var_name* is written by a non-:class:`Assignment`, if a
        summand depends on *var_name* itself, or if no common factors exist.
    """
    assert isinstance(kernel, LoopKernel)
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = {
                name: idx
                for idx, name in enumerate(var_descr.dim_names)
            }
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError(
                        "axis name '%s' not understood " % ax) from None
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # Valid axis indices are 0 .. num_user_axes()-1, hence ">=".
        if (vary_by_axes
                and (min(vary_by_axes) < 0
                     or max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero, flattened_sum,
                                     flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
                                UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found); the factors are expressed
    # in terms of the variables occurring in index_key.
    # NOTE: this name is rebound (filtered) below; the closures read it at
    # call time and so see the filtered list.
    common_factors = []

    def find_unifiable_cf_index(index_key):
        """Return ``(i, unification_result)`` for the first entry of
        *common_factors* whose key unifies with *index_key*, else
        ``(None, None)``."""
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        """Return the indices of *access_expr* along *vary_by_axes*."""
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return var_name in insn.assignee_var_names()

    def iterate_as(cls, expr):
        """Yield the children of *expr* if it is a *cls*, else *expr*."""
        if isinstance(expr, cls):
            yield from expr.children
        else:
            yield expr

    def get_product_parts(insn, term):
        """Return the set of factors of *term*, raising :class:`LoopyError`
        if any factor depends on *var_name*."""
        parts = set()
        for part in iterate_as(Product, term):
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                                 "in RHS of instruction '%s'" %
                                 (var_name, insn.id))
            parts.add(part)
        return parts

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-single-assignment" %
                             var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = get_product_parts(insn, term)

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            key, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = get_product_parts(insn, term)

                my_common_factors = {
                    cf
                    for cf in my_common_factors
                    if unif_subst_map(cf) in product_parts
                }

            # Keep the original key: the stored factors are written in terms
            # of its variables, not those of this instruction's index_key.
            common_factors[cf_index] = (key, my_common_factors)

            # }}}

    # }}}

    common_factors = [(ik, cf) for ik, cf in common_factors if cf]

    if not common_factors:
        raise LoopyError("no common factors found")

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        mapped_my_common_factors = {
            unif_subst_map(cf)
            for cf in my_common_factors
        }

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            new_sum_terms.append(
                flattened_product([
                    part for part in iterate_as(Product, term)
                    if part not in mapped_my_common_factors
                ]))

        new_insns.append(insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        """Multiply a read of *var_name* by the factors removed from the
        writes whose key unifies with its index key."""
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # No remaining entry matches this key, so nothing was factored
            # out of the corresponding writes; leave the read unchanged.
            return expr

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        return flattened_product(
            [unif_subst_map(cf) for cf in my_common_factors] + [expr])

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    # NOTE(review): only Assignment readers are rewritten; reads of var_name
    # in other instruction kinds are left as-is -- confirm this is intended.
    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
Пример #19
0
 def map_polynomial(self, expr, other, *args, **kwargs):
     """Compare polynomial *expr* with *other* by expanding both into sums.

     Each operand is rewritten as ``sum(coeff * base**exp)`` over its
     ``(exp, coeff)`` data (visited in reverse order), using that operand's
     *own* base. The two sums are then compared with :meth:`rec`. Returns
     ``False`` without expanding anything if the operand types differ.
     """
     from pymbolic.primitives import flattened_sum

     def as_sum(poly):
         # Expand in terms of poly's own base, so that polynomials in
         # different variables do not compare equal.
         return flattened_sum(
             [coeff * poly.base**exp for exp, coeff in poly.data[::-1]])

     return type(expr) is type(other) \
             and self.rec(as_sum(expr), as_sum(other), *args, **kwargs)
Пример #20
0
 def map_polynomial(self, expr, enclosing_prec, *args, **kwargs):
     """Map polynomial *expr* by delegating to its expanded form.

     Each ``(exp, coeff)`` pair of ``expr.data`` is visited in reverse order
     and turned into ``coeff * expr.base**exp``. The terms are combined with
     ``flattened_sum``, and the sum is passed to :meth:`rec` with the same
     *enclosing_prec* and extra arguments.
     """
     from pymbolic.primitives import flattened_sum

     expanded_terms = [
         coeff * expr.base**exp
         for exp, coeff in reversed(expr.data)
     ]
     expanded = flattened_sum(expanded_terms)
     return self.rec(expanded, enclosing_prec, *args, **kwargs)
Пример #21
0
    def emit_assignment(self, codegen_state, insn):
        """Generate ISPC code for the assignment *insn*.

        Ordinary assignments become a ``cgen.Assign``. Instructions tagged
        ``!streaming_store`` are emitted as
        ``streaming_store(<array> + <offset>, <rhs>)``. In that case the bare
        local-axis-0 iname term is dropped from the flat subscript to form
        ``<offset>``; that iname must occur exactly once and with stride 1.
        If the right-hand side does not depend on a local-axis-0 iname, it is
        wrapped in ``broadcast(<rhs>, 0)``.

        :raises NotImplementedError: for atomic assignments.
        :raises LoopyError: if a streaming store targets an unsupported array
            type, does not have exactly one flat subscript, uses a
            short-vector type, or does not index local axis 0 with stride 1.
        """
        kernel = codegen_state.kernel
        ecm = codegen_state.expression_to_code_mapper

        assignee_var_name, = insn.assignee_var_names()

        lhs_var = codegen_state.kernel.get_var_descriptor(assignee_var_name)
        lhs_dtype = lhs_var.dtype

        if insn.atomicity:
            raise NotImplementedError("atomic ops in ISPC")

        from loopy.expression import dtype_to_type_context
        from pymbolic.mapper.stringifier import PREC_NONE

        rhs_type_context = dtype_to_type_context(kernel.target, lhs_dtype)
        rhs_code = ecm(insn.expression, prec=PREC_NONE,
                    type_context=rhs_type_context,
                    needed_dtype=lhs_dtype)

        lhs = insn.assignee

        # {{{ handle streaming stores

        if "!streaming_store" in insn.tags:
            ary = ecm.find_array(lhs)

            from loopy.kernel.array import get_access_info
            from pymbolic import evaluate

            from loopy.symbolic import simplify_using_aff
            index_tuple = tuple(
                    simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)

            access_info = get_access_info(kernel.target, ary, index_tuple,
                    lambda expr: evaluate(expr, codegen_state.var_subst_map),
                    codegen_state.vectorization_info)

            from loopy.kernel.data import ArrayArg, TemporaryVariable

            if not isinstance(ary, (ArrayArg, TemporaryVariable)):
                raise LoopyError("array type not supported in ISPC: %s"
                        % type(ary).__name__)

            if len(access_info.subscripts) != 1:
                raise LoopyError("streaming stores must have a subscript")
            subscript, = access_info.subscripts

            from pymbolic.primitives import Sum, flattened_sum, Variable
            if isinstance(subscript, Sum):
                terms = subscript.children
            else:
                # A non-Sum subscript is itself the only term.
                terms = (subscript,)

            new_terms = []

            from loopy.kernel.data import LocalIndexTag, filter_iname_tags_by_type
            from loopy.symbolic import get_dependencies

            saw_l0 = False
            for term in terms:
                if (isinstance(term, Variable)
                            and kernel.iname_tags_of_type(term.name, LocalIndexTag)):
                    tag, = kernel.iname_tags_of_type(
                        term.name, LocalIndexTag, min_num=1, max_num=1)
                    if tag.axis == 0:
                        if saw_l0:
                            raise LoopyError(
                                "streaming store must have stride 1 in "
                                "local index, got: %s" % subscript)
                        saw_l0 = True
                        # The stride-1 local-axis-0 term is not part of the
                        # explicit offset passed to streaming_store.
                        continue
                else:
                    # Any other appearance of a local-axis-0 iname inside a
                    # term (e.g. scaled) is rejected.
                    for dep in get_dependencies(term):
                        if filter_iname_tags_by_type(
                                kernel.iname_to_tags.get(dep, []), LocalIndexTag):
                            tag, = filter_iname_tags_by_type(
                                kernel.iname_to_tags.get(dep, []), LocalIndexTag, 1)
                            if tag.axis == 0:
                                raise LoopyError(
                                    "streaming store must have stride 1 in "
                                    "local index, got: %s" % subscript)

                # Every other term (including bare inames on other local
                # axes) stays part of the offset.
                new_terms.append(term)

            if not saw_l0:
                raise LoopyError("streaming store must have stride 1 in "
                        "local index, got: %s" % subscript)

            if access_info.vector_index is not None:
                raise LoopyError("streaming store may not use a short-vector "
                        "data type")

            # Outer loop over dependencies, inner over that dependency's tags.
            rhs_has_programindex = any(
                isinstance(tag, LocalIndexTag) and tag.axis == 0
                for dep in get_dependencies(insn.expression)
                for tag in kernel.iname_to_tags.get(dep, ()))

            if not rhs_has_programindex:
                rhs_code = "broadcast(%s, 0)" % rhs_code

            from cgen import Statement
            return Statement(
                    "streaming_store(%s + %s, %s)"
                    % (
                        access_info.array_name,
                        ecm(flattened_sum(new_terms), PREC_NONE, 'i'),
                        rhs_code))

        # }}}

        from cgen import Assign
        return Assign(ecm(lhs, prec=PREC_NONE, type_context=None), rhs_code)