Beispiel #1
0
    def as_oir(self):
        """Lower this node back into an OIR ``VerticalLoop``.

        Each ``(interval, sdfg)`` pair in ``self.sections`` becomes one
        ``VerticalLoopSection``. Inside an SDFG, states are walked in the order
        given by ``sdfg.topological_sort(sdfg.start_state)``. Inside each state,
        nodes are walked in ``nx.topological_sort(state.nx)`` order, and only
        ``HorizontalExecutionLibraryNode`` instances are kept. Each kept node is
        converted with ``node.as_oir()`` and then passed through an
        ``OIRFieldRenamer`` built from ``get_node_name_mapping(state, node)``.

        A single ``SourceLocation`` is derived from ``self.debuginfo``. Falsy
        start line / column fall back to ``1`` and a falsy filename falls back
        to ``"<unknown>"``. That location is attached to every section and to
        the returned loop.
        """
        debug = self.debuginfo
        loc = SourceLocation(
            debug.start_line or 1,
            debug.start_column or 1,
            debug.filename or "<unknown>",
        )

        def lower_sdfg(sdfg):
            # Gather the renamed OIR horizontal executions in execution order.
            lowered = []
            for state in sdfg.topological_sort(sdfg.start_state):
                for node in nx.topological_sort(state.nx):
                    if not isinstance(node, HorizontalExecutionLibraryNode):
                        continue
                    renamer = OIRFieldRenamer(get_node_name_mapping(state, node))
                    lowered.append(renamer.visit(node.as_oir()))
            return lowered

        sections = [
            VerticalLoopSection(
                interval=interval,
                horizontal_executions=lower_sdfg(sdfg),
                loc=loc,
            )
            for interval, sdfg in self.sections
        ]
        return VerticalLoop(
            sections=sections,
            loop_order=self.loop_order,
            caches=self.caches,
            loc=loc,
        )
Beispiel #2
0
    def _split_entry_level(
        loop_order: common.LoopOrder,
        section: oir.VerticalLoopSection,
        new_symbol_name: Callable[[str], str],
    ) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]:
        """Split the entry level of a loop section.

        For a forward loop, the entry part spans from the interval start to a
        bound at the same level with offset ``start.offset + 1``. The rest
        spans from that bound to the original end.

        For a backward loop, the entry part spans from a bound at the end
        level with offset ``end.offset - 1`` to the original end. The rest
        spans from the original start to that bound.

        Every ``oir.Decl`` found in ``section`` is given a new name (via
        ``new_symbol_name``) in the entry part. Matching ``ScalarAccess`` nodes
        there are rewritten too, so the two resulting sections do not share
        declared symbol names.

        Args:
            loop_order: forward or backward order.
            section: loop section to split.
            new_symbol_name: callable mapping an existing declaration name to
                the replacement name used in the entry section.

        Returns:
            Two loop sections, ``(entry_section, rest_section)``. Both keep
            ``section.loc``. Only the entry section's horizontal executions
            have renamed symbols. The rest section reuses
            ``section.horizontal_executions`` unchanged.
        """
        assert loop_order in (common.LoopOrder.FORWARD,
                              common.LoopOrder.BACKWARD)
        if loop_order == common.LoopOrder.FORWARD:
            bound = common.AxisBound(level=section.interval.start.level,
                                     offset=section.interval.start.offset + 1)
            entry_interval = oir.Interval(start=section.interval.start,
                                          end=bound)
            rest_interval = oir.Interval(start=bound, end=section.interval.end)
        else:
            bound = common.AxisBound(level=section.interval.end.level,
                                     offset=section.interval.end.offset - 1)
            entry_interval = oir.Interval(start=bound,
                                          end=section.interval.end)
            rest_interval = oir.Interval(start=section.interval.start,
                                         end=bound)
        decls = list(section.iter_tree().if_isinstance(oir.Decl))
        decls_map = {decl.name: new_symbol_name(decl.name) for decl in decls}

        class FixSymbolNameClashes(NodeTranslator):
            # Rebuilt nodes keep the original ``loc`` so source locations are
            # not lost. This mirrors how the sections below keep section.loc.
            def visit_ScalarAccess(self,
                                   node: oir.ScalarAccess) -> oir.ScalarAccess:
                if node.name not in decls_map:
                    return node
                return oir.ScalarAccess(name=decls_map[node.name],
                                        dtype=node.dtype,
                                        loc=node.loc)

            def visit_LocalScalar(self,
                                  node: oir.LocalScalar) -> oir.LocalScalar:
                # Every LocalScalar in the section is a Decl, so it is
                # always present in decls_map.
                return oir.LocalScalar(name=decls_map[node.name],
                                       dtype=node.dtype,
                                       loc=node.loc)

        return (
            oir.VerticalLoopSection(
                interval=entry_interval,
                horizontal_executions=FixSymbolNameClashes().visit(
                    section.horizontal_executions),
                loc=section.loc,
            ),
            oir.VerticalLoopSection(
                interval=rest_interval,
                horizontal_executions=section.horizontal_executions,
                loc=section.loc,
            ),
        )
Beispiel #3
0
    def as_oir(self):
        """Convert this node into an OIR ``VerticalLoop``.

        Each ``(interval, sdfg)`` pair in ``self.sections`` yields one
        ``VerticalLoopSection``. Its horizontal executions are built as
        follows:
          - SDFG states are visited via ``sdfg.topological_sort(sdfg.start_state)``.
          - Within each state, nodes are visited via ``nx.topological_sort(state.nx)``.
          - Every ``HorizontalExecutionLibraryNode`` is lowered with ``as_oir()``.
          - The lowered result is renamed with
            ``OIRFieldRenamer(get_node_name_mapping(state, node))``.

        The loop order and caches are copied from ``self``.
        """

        def lowered_executions(sdfg):
            # Lazily yield renamed OIR horizontal executions in execution order.
            for state in sdfg.topological_sort(sdfg.start_state):
                for node in nx.topological_sort(state.nx):
                    if isinstance(node, HorizontalExecutionLibraryNode):
                        renamer = OIRFieldRenamer(
                            get_node_name_mapping(state, node))
                        yield renamer.visit(node.as_oir())

        loop_sections = []
        for interval, sdfg in self.sections:
            executions = list(lowered_executions(sdfg))
            loop_sections.append(
                VerticalLoopSection(interval=interval,
                                    horizontal_executions=executions))

        return VerticalLoop(
            sections=loop_sections,
            loop_order=self.loop_order,
            caches=self.caches,
        )
Beispiel #4
0
 def build(self) -> VerticalLoopSection:
     """Assemble a ``VerticalLoopSection`` from the builder's collected state.

     Returns:
         A section whose ``interval`` is ``self._interval`` and whose
         ``horizontal_executions`` is ``self._horizontal_executions``.
     """
     section_kwargs = dict(
         interval=self._interval,
         horizontal_executions=self._horizontal_executions,
     )
     return VerticalLoopSection(**section_kwargs)