Exemplo n.º 1
0
    def apply(self, sdfg: dace.SDFG) -> None:
        """Fuse the matched ``left`` and ``right`` nodes into one library node.

        A new ``HorizontalExecutionLibraryNode`` is added to the state:
        - its OIR body and declarations are ``left``'s followed by ``right``'s;
        - it uses the iteration space of ``left``;
        - it carries the merged debug info of both nodes.

        Edges are then moved onto the new node:
        - every edge touching ``left``, ``right``, or a node lying strictly
          inside a simple path from ``left`` to ``right`` is rewired onto the
          new node, or removed if it linked the two fused nodes;
        - ``left`` and ``right`` are deleted;
        - the intermediate nodes are cleaned up.

        Args:
            sdfg: SDFG containing the state ``self.state_id`` in which the
                match was found. It is modified in place.
        """
        state = sdfg.node(self.state_id)
        left = self.left(sdfg)
        right = self.right(sdfg)

        # Merge source locations of both nodes for the fused node's debug info.
        dinfo = self._merge_source_locations(left, right)

        # Merge OIR nodes: left's statements/declarations precede right's.
        res = HorizontalExecutionLibraryNode(
            oir_node=oir.HorizontalExecution(
                body=left.as_oir().body + right.as_oir().body,
                declarations=left.as_oir().declarations +
                right.as_oir().declarations,
            ),
            iteration_space=left.iteration_space,
            debuginfo=dinfo,
        )
        state.add_node(res)

        # Nodes strictly between left and right (endpoints excluded) on any
        # simple path; these are expected to be access nodes (``.access`` and
        # ``.label`` are used below).
        intermediate_accesses = {
            n
            for path in nx.all_simple_paths(state.nx, left, right)
            for n in path[1:-1]
        }

        # Rewire edges and connectors onto ``res`` and drop the ones that link
        # the two fused nodes. Edge collections are copied with ``list`` since
        # edges are removed/rewired while iterating.
        for edge in list(state.edges_between(left, right)):
            state.remove_edge_and_connectors(edge)
        for acc in intermediate_accesses:
            for edge in list(state.in_edges(acc)):
                if edge.src is not left:
                    rewire_edge(state, edge, dst=res)
                else:
                    state.remove_edge_and_connectors(edge)
            for edge in list(state.out_edges(acc)):
                if edge.dst is not right:
                    rewire_edge(state, edge, src=res)
                else:
                    state.remove_edge_and_connectors(edge)
        for edge in list(state.in_edges(left)):
            rewire_edge(state, edge, dst=res)
        for edge in list(state.out_edges(right)):
            rewire_edge(state, edge, src=res)
        for edge in list(state.out_edges(left)):
            rewire_edge(state, edge, src=res)
        for edge in list(state.in_edges(right)):
            rewire_edge(state, edge, dst=res)
        state.remove_node(left)
        state.remove_node(right)

        # Clean up intermediate access nodes left behind by the rewiring.
        for acc in intermediate_accesses:
            if not state.in_edges(acc):
                if not state.out_edges(acc):
                    # Fully disconnected: nothing reads or writes it any more.
                    state.remove_node(acc)
                else:
                    # No writer remains; its single reader must be ``res``.
                    # Presumably the data now flows entirely inside the fused
                    # node, so the access node and res's input connector go.
                    assert (len(state.edges_between(acc, res)) == 1
                            and len(state.out_edges(acc))
                            == 1), "Previously written array now read-only."
                    state.remove_node(acc)
                    res.remove_in_connector("IN_" + acc.label)
            elif not state.out_edges(acc):
                # Still written but no longer read: mark the access write-only.
                # Must call ``out_edges(acc)``; the bare bound method is always
                # truthy and would make this branch unreachable.
                acc.access = dace.AccessType.WriteOnly
Exemplo n.º 2
0
def offsets_match(left: HorizontalExecutionLibraryNode,
                  right: HorizontalExecutionLibraryNode) -> bool:
    """Report whether ``left`` and ``right`` have no conflicting ij offsets.

    Field accesses are collected from each node's OIR. Their horizontal (ij)
    offsets are then compared in two ways:

    * ``read_after_write_conflicts`` between ``left``'s write offsets and
      ``right``'s read offsets;
    * ``write_after_read_conflicts`` between ``left``'s read offsets and
      ``right``'s write offsets.

    Returns:
        True when the union of both conflict results is empty/falsy.
    """
    accesses_of_left = AccessCollector.apply(left.as_oir())
    accesses_of_right = AccessCollector.apply(right.as_oir())

    raw_conflicts = read_after_write_conflicts(
        ij_offsets(accesses_of_left.write_offsets()),
        ij_offsets(accesses_of_right.read_offsets()),
    )
    war_conflicts = write_after_read_conflicts(
        ij_offsets(accesses_of_left.read_offsets()),
        ij_offsets(accesses_of_right.write_offsets()),
    )
    return not (raw_conflicts | war_conflicts)
Exemplo n.º 3
0
 def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *,
                               iteration_spaces, **kwargs):
     """Wrap an OIR ``HorizontalExecution`` in a library node.

     Args:
         node: The horizontal execution to wrap.
         iteration_spaces: Mapping keyed by ``id(node)`` giving the node's
             iteration space.
         **kwargs: Accepted for visitor compatibility; unused here.

     Returns:
         A ``HorizontalExecutionLibraryNode`` named after the node's ``id``.
     """
     node_key = id(node)
     label = f"HorizontalExecution_{node_key}"
     space = iteration_spaces[node_key]
     return HorizontalExecutionLibraryNode(
         name=label,
         oir_node=node,
         iteration_space=space,
     )
Exemplo n.º 4
0
def get_vertical_loop_section_sdfg(section: "VerticalLoopSection") -> SDFG:
    """Build an SDFG running a section's horizontal executions one per state.

    The SDFG begins with a ``start_state`` that holds no nodes. Then, for each
    entry of ``section.horizontal_executions`` in order:

    - a new state is added, holding a single ``HorizontalExecutionLibraryNode``;
    - that state is chained to the previous one with a default-constructed
      ``InterstateEdge``.
    """
    # Local import, presumably to avoid an import cycle with gtc.dace.nodes.
    from gtc.dace.nodes import HorizontalExecutionLibraryNode

    sdfg = SDFG(f"VerticalLoopSection_{id(section)}")
    previous = sdfg.add_state("start_state", is_start_state=True)
    for horizontal_execution in section.horizontal_executions:
        current = sdfg.add_state(
            f"HorizontalExecution_{id(horizontal_execution)}_state"
        )
        sdfg.add_edge(previous, current, InterstateEdge())
        current.add_node(
            HorizontalExecutionLibraryNode(oir_node=horizontal_execution)
        )
        previous = current
    return sdfg
Exemplo n.º 5
0
 def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *, block_extents, **kwargs):
     """Wrap an OIR ``HorizontalExecution`` in a library node with its extent.

     Args:
         node: The horizontal execution to wrap.
         block_extents: Mapping keyed by ``id(node)`` giving the node's extent.
         **kwargs: Accepted for visitor compatibility; unused here.

     Returns:
         A ``HorizontalExecutionLibraryNode`` named after the node's ``id``.
     """
     node_key = id(node)
     node_extent = block_extents[node_key]
     return HorizontalExecutionLibraryNode(
         name=f"HorizontalExecution_{node_key}",
         oir_node=node,
         extent=node_extent,
     )