Esempio n. 1
0
    def apply(self, sdfg: SDFG):
        """
        Hoist the first interstate edge of a nested SDFG to the outer SDFG.

        A new state is inserted before the state that holds the nested SDFG
        node. The edge between the new state and the original one receives
        the assignments and the condition of the nested SDFG's first edge,
        after nested names have been replaced through the symbol mapping and
        the data of the memlets entering the input connectors. The source
        state of that nested edge is then removed from the nested SDFG.

        :param sdfg: The outer SDFG that contains the nested SDFG node.
        """
        nested: nodes.NestedSDFG = self.nsdfg(sdfg)
        host_state = sdfg.node(self.state_id)

        # The edge between the inserted state and the host state will carry
        # the hoisted assignments and condition
        pre_state = sdfg.add_state_before(host_state)
        outer_edge = sdfg.edges_between(pre_state, host_state)[0]

        # Nested name -> outer expression: mapped symbols first, then the
        # data of the first memlet entering each input connector
        mapping: Dict[str, str] = {}
        for inner_sym, outer_expr in nested.symbol_mapping.items():
            mapping[inner_sym] = str(outer_expr)
        for conn in nested.in_connectors:
            incoming = next(
                iter(host_state.in_edges_by_connector(nested, conn)))
            mapping[conn] = incoming.data.data

        inner_edge = nested.sdfg.edges()[0]

        # Replace in two passes through temporary '__dacesym_' names, so a
        # replacement value that is also a key is not replaced a second time
        for name in mapping:
            inner_edge.data.replace(name,
                                    '__dacesym_' + name,
                                    replace_keys=False)
        for name, outer_expr in mapping.items():
            inner_edge.data.replace('__dacesym_' + name,
                                    outer_expr,
                                    replace_keys=False)

        for inner_name, value in inner_edge.data.assignments.items():
            # Keep the nested name unless the outer SDFG already uses it for
            # a symbol or an array; in that case prefix the nested label
            clashes = inner_name in sdfg.symbols or inner_name in sdfg.arrays
            if clashes:
                outer_name = nested.label + '_' + inner_name
            else:
                outer_name = inner_name

            outer_edge.data.assignments[outer_name] = value

            # Declare the symbol outside and forward it into the nested SDFG
            sdfg.add_symbol(outer_name, nested.sdfg.symbols[inner_name])
            nested.symbol_mapping[inner_name] = outer_name

        outer_edge.data.condition = inner_edge.data.condition

        # The source state of the hoisted edge is dropped from the nested SDFG
        nested.sdfg.remove_node(inner_edge.src)
Esempio n. 2
0
    def apply(self, sdfg):
        """
        Move alias assignments one interstate edge earlier.

        The candidate assignments reported by ``_alias_assignments`` for the
        edge between the two matched states are moved onto the first edge
        entering the first state, except for candidates that:

        * appear as a free symbol of the edge's condition,
        * are already assigned a different value on the incoming edge,
        * alias a symbol that is itself assigned on the incoming edge, or
        * alias a scalar that is accessed in the first state.

        :param sdfg: The SDFG that contains both matched states.
        """
        fstate = sdfg.nodes()[self.subgraph[SymbolAliasPromotion._first_state]]
        sstate = sdfg.nodes()[self.subgraph[
            SymbolAliasPromotion._second_state]]

        edge = sdfg.edges_between(fstate, sstate)[0].data
        in_edge = sdfg.in_edges(fstate)[0].data

        to_consider = _alias_assignments(sdfg, edge)
        if not to_consider:
            # Nothing to promote; the condition is never parsed in this case
            return

        # The edge is not modified while candidates are filtered, so the
        # free symbols of its condition are collected once up front
        condsyms = {str(s) for s in edge.condition_sympy().free_symbols}

        to_not_consider = set()
        for k, v in to_consider.items():
            # Remove symbols that are taking part in the edge's condition
            if k in condsyms:
                to_not_consider.add(k)
            # Remove symbols that are set in the in_edge
            # with a different assignment
            if k in in_edge.assignments and in_edge.assignments[k] != v:
                to_not_consider.add(k)
            # Remove symbols whose assignment (RHS) is a symbol
            # and is set in the in_edge.
            if v in sdfg.symbols and v in in_edge.assignments:
                to_not_consider.add(k)
            # Remove symbols whose assignment (RHS) is a scalar
            # and is set in the first state.
            if v in sdfg.arrays and isinstance(sdfg.arrays[v], dt.Scalar):
                if any(
                        isinstance(n, nodes.AccessNode) and n.data == v
                        for n in fstate.nodes()):
                    to_not_consider.add(k)

        for k in to_not_consider:
            del to_consider[k]

        # Promote the surviving assignments to the incoming edge
        for k, v in to_consider.items():
            del edge.assignments[k]
            in_edge.assignments[k] = v
Esempio n. 3
0
    def apply(self, sdfg):
        """
        Fuse the two matched states into a single state.

        The interstate edges between the two states are removed, and their
        assignments are merged into every edge entering the first state. If
        one of the states is empty it is dropped and its edges are
        redirected. Otherwise the nodes and edges of the second state are
        added to the first state, source access nodes of the second state
        are merged into same-labeled access nodes of the first state, and
        the second state is removed from the SDFG.

        :param sdfg: The SDFG that contains both matched states.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s), keeping their assignments on the edges
        # that enter the first state
        for edge in sdfg.edges_between(first_state, second_state):
            if edge.data.assignments:
                for _, _, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Source access nodes of the second state. This must be collected
        # before the merge, while the second state's graph is still separate
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Merge second state to first state
        # First keep a backup of the access nodes of the first state in
        # reverse topological order (last node in the order comes first)
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes: a source node coming from the second
        # state is replaced by the first same-labeled node found in `order`
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n:
                    sdutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    # The surviving node now has both incoming and outgoing
                    # uses, so it is marked as read-write
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
Esempio n. 4
0
    def apply(self, sdfg: SDFG):
        """
        Hoist the start state of a nested SDFG, together with its first
        outgoing interstate edge, to the outer SDFG.

        A new state is inserted before the state that holds the nested SDFG
        node. The nodes and edges of the nested SDFG's start state are added
        to it after nested data/symbol names have been replaced by the outer
        ones. The assignments and condition of the start state's first
        outgoing edge are copied to the edge between the new state and the
        original one. Finally the start state is removed from the nested
        SDFG, which then starts at the destination of the hoisted edge.

        :param sdfg: The outer SDFG that contains the nested SDFG node.
        """
        nsdfg: nodes.NestedSDFG = self.nsdfg(sdfg)
        state = sdfg.node(self.state_id)

        new_state = sdfg.add_state_before(state)
        isedge = sdfg.edges_between(new_state, state)[0]

        # Find relevant symbol and data descriptor mapping: mapped symbols,
        # then the data of the first memlet attached to every connector
        mapping: Dict[str, str] = {}
        mapping.update({k: str(v) for k, v in nsdfg.symbol_mapping.items()})
        mapping.update({
            k: next(iter(state.in_edges_by_connector(nsdfg, k))).data.data
            for k in nsdfg.in_connectors
        })
        mapping.update({
            k: next(iter(state.out_edges_by_connector(nsdfg, k))).data.data
            for k in nsdfg.out_connectors
        })

        # Get internal state and interstate edge
        source_state = nsdfg.sdfg.start_state
        nisedge = nsdfg.sdfg.out_edges(source_state)[0]

        # Add state contents (nodes)
        new_state.add_nodes_from(source_state.nodes())

        # Replace data descriptors and symbols on state graph
        for node in source_state.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in mapping:
                node.data = mapping[node.data]
        for edge in source_state.edges():
            edge.data.replace(mapping)
            if edge.data.data in mapping:
                edge.data.data = mapping[edge.data.data]

        # Add state contents (edges)
        for edge in source_state.edges():
            new_state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                               edge.data)

        # Safe replacement of edge contents. The callback has to apply the
        # dictionary it is handed (not the complete mapping): safe_replace
        # stages the replacement through temporary names so that clashing
        # names such as {M: N, N: M} are not replaced twice.
        def replfunc(m):
            for k, v in m.items():
                nisedge.data.replace(k, str(v), replace_keys=False)

        symbolic.safe_replace(mapping, replfunc)

        # Add interstate edge
        for akey, aval in nisedge.data.assignments.items():
            # Map assignment to outer edge; prefix the nested SDFG label if
            # the name is already taken by an outer symbol or array
            if akey not in sdfg.symbols and akey not in sdfg.arrays:
                newname = akey
            else:
                newname = nsdfg.label + '_' + akey

            isedge.data.assignments[newname] = aval

            # Add symbol to outer SDFG
            sdfg.add_symbol(newname, nsdfg.sdfg.symbols[akey])

            # Add symbol mapping to nested SDFG
            nsdfg.symbol_mapping[akey] = newname

        isedge.data.condition = nisedge.data.condition

        # Clean nested SDFG
        nsdfg.sdfg.remove_node(source_state)

        # Set new starting state
        nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)
Esempio n. 5
0
 def apply(self, _, sdfg: SDFG):
     """
     Overwrite the condition of the first interstate edge that goes from
     ``state_a`` to ``state_b`` with the code block ``"1"``.

     :param _: Unused.
     :param sdfg: The SDFG that contains both states.
     """
     link = sdfg.edges_between(self.state_a, self.state_b)[0]
     new_condition = CodeBlock("1")
     link.data.condition = new_condition
Esempio n. 6
0
 def apply(self, _, sdfg: SDFG):
     """
     Remove the first interstate edge that goes from ``state_a`` to
     ``state_b``.

     :param _: Unused.
     :param sdfg: The SDFG that contains both states.
     """
     doomed = sdfg.edges_between(self.state_a, self.state_b)[0]
     sdfg.remove_edge(doomed)
Esempio n. 7
0
    def apply(self, sdfg):
        """
        Fuse the two matched states into a single state.

        The interstate edges between the two states are removed, and their
        assignments are merged into every edge entering the first state. If
        one of the states is empty it is removed and its edges are
        redirected. Otherwise the nodes and edges of the second state are
        added to the first state, and source access nodes of the second
        state are merged into access nodes of the first state that refer to
        the same data. Whenever a state is removed, the SDFG start state is
        reassigned if it compares equal to the removed state.

        :param sdfg: The SDFG that contains both matched states.
        """
        # The subgraph dictionary may hold either the states themselves or
        # their node IDs in the SDFG
        if isinstance(self.subgraph[StateFusion.first_state], SDFGState):
            first_state: SDFGState = self.subgraph[StateFusion.first_state]
            second_state: SDFGState = self.subgraph[StateFusion.second_state]
        else:
            first_state: SDFGState = sdfg.node(
                self.subgraph[StateFusion.first_state])
            second_state: SDFGState = sdfg.node(
                self.subgraph[StateFusion.second_state])

        # Remove interstate edge(s), keeping their assignments on the edges
        # that enter the first state
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            if sdfg.start_state == first_state:
                sdfg.start_state = sdfg.node_id(second_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            if sdfg.start_state == second_state:
                sdfg.start_state = sdfg.node_id(first_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        # NOTE(review): first_input and first_output are not referenced
        # again after the filtering step below; only second_input is used
        first_input = [
            node for node in sdutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in sdutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Collected before the merge, while the second state is still a
        # separate graph (presumably its nodes outside any scope - confirm
        # against top_level_nodes)
        top2 = top_level_nodes(second_state)

        # first input = first input - first output
        first_input = [
            node for node in first_input
            if next((x for x in first_output
                     if x.data == node.data), None) is None
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes:
        # access nodes of the first state whose scope entry is None, in
        # reverse topological order (topologically-last node comes first)
        sdict = first_state.scope_dict()
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode) and sdict[x] is None
        ]
        for node in second_state.nodes():
            if isinstance(node, nodes.NestedSDFG):
                # update parent information: the nested SDFG now lives in
                # the first state
                node.sdfg.parent = first_state
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Computed after the merge, so it also covers the nodes that were
        # just added from the second state
        top = top_level_nodes(first_state)

        # Merge common (data) nodes
        for node in second_input:

            # merge only top level nodes, skip everything else
            if node not in top2:
                continue

            if first_state.in_degree(node) == 0:
                # Nodes of the first state that refer to the same data
                candidates = [
                    x for x in order if x.data == node.data and x in top
                ]
                if len(candidates) == 0:
                    continue
                elif len(candidates) == 1:
                    n = candidates[0]
                else:
                    # Choose first candidate that intersects memlets
                    for cand in candidates:
                        if StateFusion.memlets_intersect(
                                first_state, [cand], False, second_state,
                            [node], True):
                            n = cand
                            break
                    else:
                        # No node intersects, use topologically-last node
                        n = candidates[0]

                # Replace the second state's node by the chosen node and
                # mark the chosen node as read-write
                sdutil.change_edge_src(first_state, node, n)
                first_state.remove_node(node)
                n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if sdfg.start_state == second_state:
            sdfg.start_state = sdfg.node_id(first_state)
Esempio n. 8
0
    def apply(self, sdfg):
        """
        Fuse the two matched states into a single state.

        Interstate edges between the states are removed and their
        assignments are merged into every edge entering the first state. An
        empty state is simply dropped. Otherwise the second state's nodes
        and edges are added to the first state, and its source access nodes
        are merged with same-labeled source and sink access nodes of the
        first state before the second state is removed.

        :param sdfg: The SDFG that contains both matched states.
        """
        head = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        tail = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Drop the connecting edges, keeping their assignments on the edges
        # that enter the first state
        for connecting in sdfg.edges_between(head, tail):
            if connecting.data.assignments:
                for _, _, incoming in sdfg.in_edges(head):
                    incoming.assignments.update(connecting.data.assignments)
            sdfg.remove_edge(connecting)

        # An empty first state is bypassed
        if head.is_empty():
            nxutil.change_edge_dest(sdfg, head, tail)
            sdfg.remove_node(head)
            return

        # An empty second state is absorbed by the first state
        if tail.is_empty():
            nxutil.change_edge_src(sdfg, tail, head)
            nxutil.change_edge_dest(sdfg, tail, head)
            sdfg.remove_node(tail)
            return

        # From here on, both states contain nodes

        def access_only(candidates):
            # Keep only the access (data) nodes
            return [c for c in candidates if isinstance(c, nodes.AccessNode)]

        head_inputs = access_only(nxutil.find_source_nodes(head))
        head_outputs = access_only(nxutil.find_sink_nodes(head))
        tail_inputs = access_only(nxutil.find_source_nodes(tail))

        # Inputs whose label also appears among the outputs are handled as
        # outputs only
        head_inputs = [
            inp for inp in head_inputs
            if not any(out.label == inp.label for out in head_outputs)
        ]

        # Copy the contents of the second state into the first state
        for moved in tail.nodes():
            head.add_node(moved)
        for u, u_conn, v, v_conn, memlet in tail.edges():
            head.add_edge(u, u_conn, v, v_conn, memlet)

        # Same-labeled inputs of both states collapse into the node of the
        # first state
        for kept in head_inputs:
            duplicate = next(
                (c for c in tail_inputs if c.label == kept.label), None)
            if duplicate is None:
                continue
            nxutil.change_edge_src(head, duplicate, kept)
            head.remove_node(duplicate)
            tail_inputs.remove(duplicate)

        # An output of the first state collapses into the same-labeled input
        # of the second state
        for produced in head_outputs:
            consumer = next(
                (c for c in tail_inputs if c.label == produced.label), None)
            if consumer is None:
                continue
            nxutil.change_edge_dest(head, produced, consumer)
            head.remove_node(produced)
            tail_inputs.remove(consumer)

        # Outgoing interstate edges of the second state now leave the first
        nxutil.change_edge_src(sdfg, tail, head)
        sdfg.remove_node(tail)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1