예제 #1
0
    def apply(self, sdfg: dace.SDFG) -> None:
        """Fuse the matched ``left`` and ``right`` library nodes into one node.

        A new ``HorizontalExecutionLibraryNode`` holding the concatenated OIR
        bodies and declarations of both nodes is added to the state. Edges of
        ``left``, ``right`` and of the nodes lying on paths between them are
        rewired to the new node or removed, ``left`` and ``right`` are deleted,
        and the intermediate nodes are cleaned up afterwards.

        Args:
            sdfg: SDFG containing the matched state (looked up through
                ``self.state_id``). The state is modified in place.
        """
        state = sdfg.node(self.state_id)
        left = self.left(sdfg)
        right = self.right(sdfg)

        # Merge source locations
        dinfo = self._merge_source_locations(left, right)

        # Merge the OIR nodes; the fused node keeps the iteration space of
        # ``left``.
        res = HorizontalExecutionLibraryNode(
            oir_node=oir.HorizontalExecution(
                body=left.as_oir().body + right.as_oir().body,
                declarations=left.as_oir().declarations +
                right.as_oir().declarations,
            ),
            iteration_space=left.iteration_space,
            debuginfo=dinfo,
        )
        state.add_node(res)

        # Inner nodes of every simple path from ``left`` to ``right``.
        intermediate_accesses = {
            n
            for path in nx.all_simple_paths(state.nx, left, right)
            for n in path[1:-1]
        }

        # Rewire edges and connectors to the fused node and delete the
        # edges that only connected ``left`` and ``right`` to each other.
        for edge in state.edges_between(left, right):
            state.remove_edge_and_connectors(edge)
        for acc in intermediate_accesses:
            for edge in state.in_edges(acc):
                if edge.src is not left:
                    rewire_edge(state, edge, dst=res)
                else:
                    state.remove_edge_and_connectors(edge)
            for edge in state.out_edges(acc):
                if edge.dst is not right:
                    rewire_edge(state, edge, src=res)
                else:
                    state.remove_edge_and_connectors(edge)
        for edge in state.in_edges(left):
            rewire_edge(state, edge, dst=res)
        for edge in state.out_edges(right):
            rewire_edge(state, edge, src=res)
        for edge in state.out_edges(left):
            rewire_edge(state, edge, src=res)
        for edge in state.in_edges(right):
            rewire_edge(state, edge, dst=res)
        state.remove_node(left)
        state.remove_node(right)

        # Clean up intermediate nodes according to the edges they have left.
        for acc in intermediate_accesses:
            if not state.in_edges(acc):
                if not state.out_edges(acc):
                    # Fully disconnected: nothing references it any more.
                    state.remove_node(acc)
                else:
                    # No incoming edge left: only a single edge into the
                    # fused node is expected, and it is dropped together with
                    # the matching input connector.
                    assert (len(state.edges_between(acc, res)) == 1
                            and len(state.out_edges(acc))
                            == 1), "Previously written array now read-only."
                    state.remove_node(acc)
                    res.remove_in_connector("IN_" + acc.label)
            elif not state.out_edges(acc):
                # ``out_edges`` has to be called with ``acc``: testing the
                # bound method itself is always truthy, which made this
                # branch unreachable.
                acc.access = dace.AccessType.WriteOnly
예제 #2
0
    def visit_HorizontalExecution(
        self,
        node: oir.HorizontalExecution,
        tmps_to_replace: Set[str],
        symtable: Dict[str, Any],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.HorizontalExecution:
        """Rebuild ``node`` with selected temporaries declared as local scalars.

        Every name in ``tmps_to_replace`` that is accessed as a field inside
        ``node`` gets a fresh name from ``new_symbol_name``. The body is
        re-visited with that mapping (passed as ``tmps_name_map``) and one
        ``oir.LocalScalar`` per renamed temporary is appended to the
        declarations, taking ``dtype`` and ``loc`` from ``symtable``.
        """
        field_names = node.iter_tree().if_isinstance(
            oir.FieldAccess).getattr("name")
        replaced_here = field_names.if_in(tmps_to_replace).to_set()

        renamed = {}
        for old_name in replaced_here:
            renamed[old_name] = new_symbol_name(old_name)

        new_body = self.visit(node.body,
                              tmps_name_map=renamed,
                              symtable=symtable,
                              **kwargs)

        scalar_decls = []
        for old_name in replaced_here:
            scalar_decls.append(
                oir.LocalScalar(name=renamed[old_name],
                                dtype=symtable[old_name].dtype,
                                loc=symtable[old_name].loc))

        return oir.HorizontalExecution(
            body=new_body,
            declarations=node.declarations + scalar_decls,
        )
예제 #3
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *,
                           ctx: Context) -> oir.VerticalLoop:
        """Lower a GTIR vertical loop to an OIR loop with a single section.

        Each statement of ``node.body`` is visited on its own (after calling
        ``ctx.reset_local_scalars()``) and wrapped in a separate
        ``oir.HorizontalExecution`` whose declarations are read from
        ``ctx.local_scalars`` after the visit. The loop's temporaries are
        appended to ``ctx.temp_fields`` as ``oir.Temporary`` declarations.
        """
        executions: List[oir.HorizontalExecution] = []
        for gtir_stmt in node.body:
            ctx.reset_local_scalars()
            lowered = self.visit(gtir_stmt, ctx=ctx)
            # A single statement is wrapped so that it can be flattened like
            # a list result.
            if isinstance(lowered, oir.Stmt):
                lowered = [lowered]
            flat_body = utils.flatten_list(lowered)
            execution = oir.HorizontalExecution(
                body=flat_body, declarations=ctx.local_scalars)
            executions.append(execution)

        ctx.temp_fields += [
            oir.Temporary(name=tmp.name,
                          dtype=tmp.dtype,
                          dimensions=tmp.dimensions)
            for tmp in node.temporaries
        ]

        section = oir.VerticalLoopSection(
            interval=self.visit(node.interval),
            horizontal_executions=executions,
            loc=node.loc,
        )
        return oir.VerticalLoop(loop_order=node.loop_order,
                                sections=[section])
예제 #4
0
 def visit_ParAssignStmt(
     self, node: gtir.ParAssignStmt, *, mask: oir.Expr = None, ctx: Context, **kwargs: Any
 ) -> None:
     """Register a horizontal execution holding the lowered assignment.

     The visited left and right sides form an ``oir.AssignStmt``. When a
     ``mask`` is given, the assignment is wrapped in an ``oir.MaskStmt``.
     The result is handed to ``ctx.add_horizontal_execution``; nothing is
     returned.
     """
     assignment = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))
     if mask is None:
         statements = [assignment]
     else:
         statements = [oir.MaskStmt(body=[assignment], mask=mask)]
     execution = oir.HorizontalExecution(body=statements, declarations=[])
     ctx.add_horizontal_execution(execution)
예제 #5
0
 def visit_HorizontalExecution(
     self,
     node: oir.HorizontalExecution,
     *,
     name_map: Dict[str, str],
     fills: List[oir.Stmt],
     flushes: List[oir.Stmt],
     **kwargs: Any,
 ) -> oir.HorizontalExecution:
     """Rebuild ``node`` with ``fills`` before and ``flushes`` after its body.

     The body is re-visited with ``name_map`` (and any extra keyword
     arguments); the declarations are taken over unchanged.
     """
     visited_body = self.visit(node.body, name_map=name_map, **kwargs)
     framed_body = fills + visited_body + flushes
     return oir.HorizontalExecution(body=framed_body, declarations=node.declarations)
예제 #6
0
 def visit_HorizontalExecution(
     self,
     node: oir.HorizontalExecution,
     local_tmps: Set[str],
     symtable: Dict[str, Any],
     **kwargs: Any,
 ) -> oir.HorizontalExecution:
     """Rebuild ``node`` declaring the local temporaries it accesses.

     For each name in ``local_tmps`` that is accessed as a field inside
     ``node``, an ``oir.LocalScalar`` with the ``dtype`` and ``loc`` found in
     ``symtable`` is appended to the declarations. Body and mask are
     re-visited with ``local_tmps``; ``symtable`` is not forwarded to them.
     """
     accessed_tmps = (
         node.iter_tree()
         .if_isinstance(oir.FieldAccess)
         .getattr("name")
         .if_in(local_tmps)
         .to_set()
     )
     scalar_decls = []
     for tmp_name in accessed_tmps:
         scalar_decls.append(
             oir.LocalScalar(
                 name=tmp_name, dtype=symtable[tmp_name].dtype, loc=symtable[tmp_name].loc
             )
         )
     declarations = node.declarations + scalar_decls

     new_body = self.visit(node.body, local_tmps=local_tmps, **kwargs)
     new_mask = self.visit(node.mask, local_tmps=local_tmps, **kwargs)
     return oir.HorizontalExecution(body=new_body, mask=new_mask, declarations=declarations)
예제 #7
0
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare a boolean mask temporary and emit the statement filling it.

    A ``DataType.BOOL`` temporary called ``name`` (with all three dimension
    flags set) is registered through ``ctx.add_decl``. A horizontal execution
    assigning ``cond`` to that temporary at zero offset is registered through
    ``ctx.add_horizontal_execution``. The temporary's declaration is returned.
    """
    mask_decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(mask_decl)

    mask_access = oir.FieldAccess(
        name=mask_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_decl.dtype,
    )
    store_condition = oir.AssignStmt(left=mask_access, right=cond)
    ctx.add_horizontal_execution(
        oir.HorizontalExecution(body=[store_condition], declarations=[])
    )
    return mask_decl
예제 #8
0
 def build(self):
     """Return an ``oir.HorizontalExecution`` with an empty body and no mask."""
     empty_execution = oir.HorizontalExecution(body=[], mask=None)
     return empty_execution
    def _merge(
        self,
        horizontal_executions: List[oir.HorizontalExecution],
        symtable: Dict[str, Any],
        new_symbol_name: Callable[[str], str],
        protected_fields: Set[str],
    ) -> List[oir.HorizontalExecution]:
        """Recursively inline horizontal executions into their consumers.

        Algorithm:
        1. Collect the fields written by the first horizontal execution.
        2. Find the later horizontal executions that read those fields.
        3. For every read offset, prepend a copy of the first execution's
           body (shifted by that offset) to each depending execution.
        4. Recurse on the resulting list, which no longer contains the first
           execution.

        The first execution is kept as is (and only the remainder is
        processed) when one of its fields is written again later, when it
        writes a protected field, when its body is longer than
        ``self.max_horizontal_execution_body_size``, or when it calls an
        expensive native function and
        ``self.allow_expensive_function_duplication`` is not set.
        """
        if len(horizontal_executions) < 2:
            return horizontal_executions

        head, *tail = horizontal_executions
        head_accesses = AccessCollector.apply(head)
        tail_accesses = AccessCollector.apply(tail)

        # The checks are evaluated lazily, in this order, stopping at the
        # first one that forbids merging.
        keep_head = bool(head_accesses.fields()
                         & tail_accesses.write_fields())
        keep_head = keep_head or bool(protected_fields
                                      & head_accesses.write_fields())
        keep_head = keep_head or (len(head.body) >
                                  self.max_horizontal_execution_body_size)
        if not keep_head and not self.allow_expensive_function_duplication:
            nf = common.NativeFunction
            expensive_funcs = {
                nf.SIN,
                nf.COS,
                nf.TAN,
                nf.ARCSIN,
                nf.ARCCOS,
                nf.ARCTAN,
                nf.SQRT,
                nf.EXP,
                nf.LOG,
            }
            called_funcs = head.iter_tree().if_isinstance(
                oir.NativeFuncCall).getattr("func")
            keep_head = any(func in expensive_funcs for func in called_funcs)

        if keep_head:
            return [head] + self._merge(tail, symtable, new_symbol_name,
                                        protected_fields)

        written = head_accesses.write_fields()
        merged_tail = []
        for consumer in tail:
            consumer_reads = AccessCollector.apply(consumer).read_offsets()
            offsets_per_field = [
                offsets for field, offsets in consumer_reads.items()
                if field in written
            ]
            read_offsets: Set[Tuple[int, int, int]] = set().union(
                *offsets_per_field)

            if not read_offsets:
                # ``consumer`` does not read anything ``head`` writes.
                merged_tail.append(consumer)
                continue

            # One fresh scalar per (written field, read offset) pair.
            offset_symbol_map = {}
            for name in written:
                for offset in read_offsets:
                    offset_symbol_map[(name, offset)] = new_symbol_name(name)

            consumer_body = self.visit(consumer.body,
                                       offset_symbol_map=offset_symbol_map)
            inlined_scalars = [
                oir.LocalScalar(name=new_name, dtype=symtable[old_name].dtype)
                for (old_name, _), new_name in offset_symbol_map.items()
            ]
            carried_over = [
                decl for decl in head.declarations
                if decl not in consumer.declarations
            ]
            merged = oir.HorizontalExecution(
                body=consumer_body,
                declarations=(consumer.declarations + inlined_scalars +
                              carried_over),
            )
            for offset in read_offsets:
                shifted_head = self.visit(
                    head.body,
                    shift=offset,
                    offset_symbol_map=offset_symbol_map,
                    symtable=symtable,
                )
                merged.body = shifted_head + merged.body
            merged_tail.append(merged)

        return self._merge(merged_tail, symtable, new_symbol_name,
                           protected_fields)
예제 #10
0
 def visit_HorizontalExecution(self, node: oir.HorizontalExecution) -> oir.HorizontalExecution:
     """Rebuild ``node`` with its body passed through ``self._merge``.

     Declarations and source location are taken over unchanged.
     """
     merged_body = self._merge(node.body)
     return oir.HorizontalExecution(
         body=merged_body, declarations=node.declarations, loc=node.loc
     )
예제 #11
0
    def visit_VerticalLoopSection(
        self,
        node: oir.VerticalLoopSection,
        *,
        block_extents: Dict[int, Extent],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoopSection:
        """Greedily merge consecutive horizontal executions of a section.

        An execution is merged into the previously kept one unless it reads,
        with a non-zero offset in either of the first two offset components,
        a field that the previous one writes, or unless their extents
        (looked up in ``block_extents`` by ``id``) differ. Local declarations
        whose names clash are renamed in the second execution through
        ``new_symbol_name``.
        """
        merged_hexecs = [node.horizontal_executions[0]]
        merged_extents = [block_extents[id(merged_hexecs[-1])]]

        for candidate in node.horizontal_executions[1:]:
            previous = merged_hexecs[-1]
            previous_extent = merged_extents[-1]

            previous_writes = AccessCollector.apply(previous).write_fields()
            candidate_reads = AccessCollector.apply(candidate).read_offsets()
            offset_read_fields = set()
            for name, offsets in candidate_reads.items():
                if not all(off[0] == 0 and off[1] == 0 for off in offsets):
                    offset_read_fields.add(name)

            blocking_reads = previous_writes & offset_read_fields
            candidate_extent = block_extents[id(candidate)]

            if blocking_reads or previous_extent != candidate_extent:
                # Cannot merge: keep the candidate as its own execution.
                merged_hexecs.append(candidate)
                merged_extents.append(candidate_extent)
                continue

            # Merge the candidate into the previous execution.
            clashing_names = {
                decl.name
                for decl in previous.declarations
            } & {decl.name
                 for decl in candidate.declarations}
            # Old -> new scalar names, applied to the candidate only.
            renames = {}
            for name in clashing_names:
                renames[name] = new_symbol_name(name)
            candidate_decls = {
                decl.name: decl
                for decl in candidate.declarations
            }

            renamed_body = self.visit(candidate.body,
                                      scalar_map=renames,
                                      **kwargs)

            untouched_decls = [
                decl for decl in candidate.declarations
                if decl.name not in clashing_names
            ]
            renamed_decls = [
                oir.ScalarDecl(name=renames[name],
                               dtype=candidate_decls[name].dtype)
                for name in clashing_names
            ]

            merged_hexecs[-1] = oir.HorizontalExecution(
                body=previous.body + renamed_body,
                declarations=(previous.declarations + untouched_decls +
                              renamed_decls),
            )

        return oir.VerticalLoopSection(
            interval=node.interval,
            horizontal_executions=merged_hexecs)
예제 #12
0
 def to_oir(self) -> oir.HorizontalExecution:
     """Build an ``oir.HorizontalExecution`` from this object's body, declarations and loc."""
     node_fields = {
         "body": self.body,
         "declarations": self.declarations,
         "loc": self.loc,
     }
     return oir.HorizontalExecution(**node_fields)