Beispiel #1
0
    def expressions():
        """ Builds the pattern graphs that this transformation matches.

            Both patterns contain a loop guard, the first loop state and an
            exit state, with the guard branching to the other two states.
            They differ only in whether the first loop state has an edge
            back to the guard.
            :return: A list with the single-state loop pattern followed by
                     the multi-state loop pattern.
        """
        guard = DetectLoop._loop_guard
        begin = DetectLoop._loop_begin
        exit_state = DetectLoop._exit_state

        # Case 1: Loop with one state (the loop state leads back to the guard)
        single_state_loop = sd.SDFG('_')
        single_state_loop.add_nodes_from([guard, begin, exit_state])
        for src, dst in ((guard, begin), (guard, exit_state), (begin, guard)):
            single_state_loop.add_edge(src, dst, sd.InterstateEdge())

        # Case 2: Loop with multiple states (no back-edge from state)
        multi_state_loop = sd.SDFG('_')
        multi_state_loop.add_nodes_from([guard, begin, exit_state])
        for src, dst in ((guard, begin), (guard, exit_state)):
            multi_state_loop.add_edge(src, dst, sd.InterstateEdge())

        return [single_state_loop, multi_state_loop]
Beispiel #2
0
class InlineSDFG(transformation.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG.

        In particular, the steps taken are:

        1. All transient arrays become transients of the parent
        2. If a source/sink node is one of the inputs/outputs:
          a. Remove it
          b. Reconnect through external edges (map/accessnode)
          c. Replace and reoffset memlets with external data descriptor
        3. If other nodes carry the names of inputs/outputs:
          a. Replace data with external data descriptor
          b. Replace and reoffset memlets with external data descriptor
        4. If source/sink node is not connected to a source/destination, and
           the nested SDFG is in a scope, connect to scope with empty memlets
        5. Remove all unused external inputs/output memlet paths
        6. Remove isolated nodes resulting from previous step

    """

    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), {}, {})

    @staticmethod
    def annotates_memlets():
        """ Reports whether this transformation annotates memlets.
            :return: Always True.
        """
        return True

    @staticmethod
    def expressions():
        """ Returns the pattern graphs matched by this transformation.
            :return: A one-element list holding the graph that
                     ``sdutil.node_path_graph`` builds from the placeholder
                     nested SDFG node.
        """
        nested_sdfg_pattern = sdutil.node_path_graph(InlineSDFG._nested_sdfg)
        return [nested_sdfg_pattern]

    @staticmethod
    def _check_strides(inner_strides: List[symbolic.SymbolicType],
                       outer_strides: List[symbolic.SymbolicType],
                       memlet: Memlet, nested_sdfg: nodes.NestedSDFG) -> bool:
        """
        Checks whether the strides of an array inside the nested SDFG line up
        with the strides of the outer array it is connected to, so that the
        array can be used directly after inlining. Accounts for memlet
        (un)squeezing and for the symbol mapping of the nested SDFG.
        :param inner_strides: The strides of the array inside the nested SDFG.
        :param outer_strides: The strides of the array in the external SDFG.
        :param memlet: Memlet connecting the outer array to the nested SDFG;
                       dimensions in which its subset has size 1 are skipped.
        :param nested_sdfg: Nested SDFG node with symbol mapping.
        :return: True if all strides match, False otherwise.
        """
        # Outer dimensions of extent 1 do not take part in the comparison
        squeezed_dims = {
            dim
            for dim, extent in enumerate(memlet.subset.size()) if extent == 1
        }
        kept_outer = [
            stride for dim, stride in enumerate(outer_strides)
            if dim not in squeezed_dims
        ]
        if not kept_outer:
            # Every dimension was squeezed away: compare against unit stride
            kept_outer = [1]
        if len(kept_outer) != len(inner_strides):
            return False

        # Express inner strides in terms of outer symbols
        symbol_subs = {
            symbolic.pystr_to_symbolic(inner_sym):
            symbolic.pystr_to_symbolic(outer_expr)
            for inner_sym, outer_expr in nested_sdfg.symbol_mapping.items()
        }
        mapped_inner = []
        for stride in inner_strides:
            if symbolic.issymbolic(stride):
                stride = stride.subs(symbol_subs)
            mapped_inner.append(stride)

        for inner, outer in zip(mapped_inner, kept_outer):
            if not (inner == outer):
                return False
        return True

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Decides whether the matched nested SDFG node can be inlined.

            The match is rejected if the node is flagged ``no_inline``, if
            the nested SDFG does not consist of exactly one state, if any
            connector has more than one edge, or (when the node is inside a
            scope) if the access nodes of the inner state are connected in a
            way that the checks below do not allow.
            :param graph: The state that contains the candidate node.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched pattern (not used here).
            :param sdfg: The SDFG containing ``graph`` (not used here).
            :param strict: Strict-mode flag (not used here).
            :return: True if the transformation can be applied.
        """
        nsdfg_node = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        if nsdfg_node.no_inline or len(nsdfg_node.sdfg.nodes()) != 1:
            return False

        # Ensure every connector has one incoming/outgoing edge
        used_inputs = set()
        used_outputs = set()
        for in_edge in graph.in_edges(nsdfg_node):
            if in_edge.dst_conn in used_inputs:
                return False
            used_inputs.add(in_edge.dst_conn)
        for out_edge in graph.out_edges(nsdfg_node):
            if out_edge.src_conn in used_outputs:
                return False
            used_outputs.add(out_edge.src_conn)

        # The remaining checks only apply to nested SDFGs inside a scope
        if graph.entry_node(nsdfg_node) is None:
            return True

        connected = used_inputs | used_outputs
        inner_state = nsdfg_node.sdfg.node(0)
        for anode in inner_state.nodes():
            if not isinstance(anode, nodes.AccessNode):
                continue

            # Output connectors may not have additional outputs
            if (anode.data in used_outputs
                    and inner_state.out_degree(anode) > 0):
                if (anode.data not in used_inputs
                        or inner_state.in_degree(anode) > 0):
                    return False

            # No two connectors may be directly connected to each other
            if anode.data in used_inputs:
                for inner_edge in inner_state.out_edges(anode):
                    if (isinstance(inner_edge.dst, nodes.AccessNode)
                            and inner_edge.dst.data in connected):
                        return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string describing a match.
            :param graph: The graph in which the match was found.
            :param candidate: The matched candidate (not used here).
            :return: The label of ``graph``.
        """
        return graph.label

    def _remove_edge_path(self,
                          state: SDFGState,
                          edge_map: Dict[str, MultiConnectorEdge],
                          unused: Set[str],
                          reverse: bool = False) -> List[MultiConnectorEdge]:
        """ Removes the memlet paths of unused edges from a state, in place.
            Each path is walked edge by edge and removal stops at the first
            edge whose connector is shared with another edge. If a whole path
            is removed, the node at its far end is removed too.
            :param state: The state in which to remove edges.
            :param edge_map: Mapping from identifier to edge; only entries
                             whose identifier is in ``unused`` are handled.
            :param unused: Set of edge identifiers to remove.
            :param reverse: If False, walks each path forward and checks
                            incoming edges of the destination; otherwise
                            walks backward and checks outgoing edges of the
                            source.
            :return: List of edges from removed nodes at the path's end.
        """
        def sole_edge_on_connector(path_edge) -> bool:
            # True if no other edge shares this edge's connector on the
            # node that the walk is moving away from
            if reverse:
                sharing = [
                    e for e in state.out_edges(path_edge.src)
                    if e.src_conn == path_edge.src_conn
                ]
            else:
                sharing = [
                    e for e in state.in_edges(path_edge.dst)
                    if e.dst_conn == path_edge.dst_conn
                ]
            return len(sharing) == 1

        dangling = []

        for identifier, edge in edge_map.items():
            if identifier not in unused:
                continue

            path = state.memlet_path(edge)
            walk = reversed(path) if reverse else path
            last_edge = None
            whole_path_removed = True
            for last_edge in walk:
                if not sole_edge_on_connector(last_edge):
                    whole_path_removed = False
                    break
                # Remove connectors as well
                state.remove_edge_and_connectors(last_edge)

            if not whole_path_removed or last_edge is None:
                continue

            # Reached terminus without breaking, remove external node
            end_node = last_edge.src if reverse else last_edge.dst

            # Keep track of edges on the other end of these nodes,
            # they will be used to reconnect to first/last
            # occurrence of access nodes in the inlined subgraph.
            if reverse:
                dangling.extend(state.in_edges(end_node))
            else:
                dangling.extend(state.out_edges(end_node))

            state.remove_node(end_node)

        return dangling

    def apply(self, sdfg: SDFG):
        """ Inlines the matched nested SDFG node into its parent state,
            modifying ``sdfg`` in place.

            The nested SDFG's code, constants and transient data descriptors
            are added to ``sdfg``, the nodes and edges of its first state are
            added to the parent state, the external edges are reconnected,
            and finally the nested SDFG node itself is removed.
            :param sdfg: The SDFG that contains the state (``self.state_id``)
                         holding the nested SDFG node.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        # A non-default schedule on the nested SDFG node is pushed into the
        # nested SDFG before its contents are moved out
        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        # Scope entry/exit surrounding the nested SDFG node (None if the
        # node is not inside a scope)
        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants (an existing outer constant wins; a differing inner
        # value only triggers a warning)
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        # inputs/outputs: connector name -> external edge
        # input_set/output_set: external data name -> connector name
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Access nodes that need to be reshaped: non-transient inner arrays
        # with more dimensions than the external memlet subset, or whose
        # strides do not match the outer array
        reshapes: Set[str] = set()
        for aname, array in nsdfg.arrays.items():
            if array.transient:
                continue
            edge = None
            if aname in inputs:
                edge = inputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if aname in outputs:
                edge = outputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if edge is not None and not InlineSDFG._check_strides(
                    array.strides, sdfg.arrays[edge.data.data].strides,
                    edge.data, nsdfg_node):
                reshapes.add(aname)

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        name = sdfg.add_datadesc('%s_%s' %
                                                 (nsdfg.label, edge.data.data),
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        # Inner source/sink access node -> external edge it is replaced by
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        # Source/sink access nodes that are dropped instead of inlined
        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        # (connector names map to the data of the external memlet)
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        # Add views whenever reshapes are necessary; these entries override
        # the connector -> external data mapping set above
        for dname in reshapes:
            desc = nsdfg.arrays[dname]
            # To avoid potential confusion, rename protected __return keyword
            if dname.startswith('__return'):
                newname = f'{nsdfg.name}_ret{dname[8:]}'
            else:
                newname = dname
            newname, _ = sdfg.add_view(newname,
                                       desc.shape,
                                       desc.dtype,
                                       storage=desc.storage,
                                       strides=desc.strides,
                                       offset=desc.offset,
                                       debuginfo=desc.debuginfo,
                                       allow_conflicts=desc.allow_conflicts,
                                       total_size=desc.total_size,
                                       alignment=desc.alignment,
                                       may_alias=desc.may_alias,
                                       find_new_name=True)
            repldict[dname] = newname

        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in nstate.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Add extra access nodes for out/in view nodes
        # NOTE(review): node.data was already renamed by the loop above, while
        # `reshapes` and `outputs` are keyed by the original inner names, so
        # this only triggers when the view kept the original name - confirm.
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in reshapes:
                if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                    # Such a node has to be in the output set
                    edge = outputs[node.data]

                    # Redirect outgoing edges through access node
                    out_edges = list(nstate.out_edges(node))
                    anode = nstate.add_access(edge.data.data)
                    vnode = nstate.add_access(node.data)
                    nstate.add_nedge(node, anode, edge.data)
                    nstate.add_nedge(anode, vnode, edge.data)
                    for e in out_edges:
                        nstate.remove_edge(e)
                        nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                        e.data)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state (all nodes except the
        # source/sink access nodes collected above)
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Reshape: add connections to viewed data
        self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                                  True)
        self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                                  False)

        # Modify all other internal edges pertaining to input/output nodes
        # (edges already handled by _modify_memlet_path are skipped)
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): inputs/outputs are keyed by connector name (str) while
        # source_accesses/sink_accesses hold AccessNode objects, so these set
        # differences do not remove any key - confirm this is intended.
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        # (access nodes of the inner state in topological order)
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            # (next() raises StopIteration if no access node matches)
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)

    def _modify_memlet_path(self, new_edges: Dict[nodes.Node,
                                                  MultiConnectorEdge],
                            nstate: SDFGState, state: SDFGState,
                            inputs: bool) -> Set[MultiConnectorEdge]:
        """ Reconnects inner source/sink access nodes through the external
            edges that replace them, and rewrites the memlets further along
            the resulting memlet trees.
            :param new_edges: Mapping from an inner access node to the
                              external edge that takes its place.
            :param nstate: The inner state, used to look up the edges of
                           each inner access node.
            :param state: The outer state, in which new edges are created.
            :param inputs: If True, handles outgoing edges of the inner nodes
                           and connects them from the external edge's source;
                           otherwise handles incoming edges and connects them
                           to the external edge's destination.
            :return: Set of edges in the memlet trees whose memlets were
                     modified (the newly created edges are not included).
        """
        touched = set()

        def unsqueeze_subtree(tree_node, outer_edge):
            # Rewrite this edge's memlet against the outer memlet, then
            # continue with the rest of the tree below it
            touched.add(tree_node.edge)
            tree_node.edge._data = helpers.unsqueeze_memlet(
                tree_node.edge.data, outer_edge.data)
            for subtree in tree_node.children:
                unsqueeze_subtree(subtree, outer_edge)

        for node, top_edge in new_edges.items():
            if inputs:
                inner_edges = nstate.out_edges(node)
            else:
                inner_edges = nstate.in_edges(node)

            for inner_edge in inner_edges:
                new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
                                                      top_edge.data)
                if inputs:
                    src, src_conn = top_edge.src, top_edge.src_conn
                    dst, dst_conn = inner_edge.dst, inner_edge.dst_conn
                else:
                    src, src_conn = inner_edge.src, inner_edge.src_conn
                    dst, dst_conn = top_edge.dst, top_edge.dst_conn
                new_edge = state.add_edge(src, src_conn, dst, dst_conn,
                                          new_memlet)

                # Modify all memlets going forward/backward
                for child in state.memlet_tree(new_edge).children:
                    unsqueeze_subtree(child, top_edge)

        return touched

    def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str,
                                                                      str],
                             new_edges: Dict[str, MultiConnectorEdge],
                             nstate: SDFGState, state: SDFGState,
                             inputs: bool):
        """ Connects reshaped (viewed) source/sink access nodes to the
            endpoints of their external edges.
            :param reshapes: Inner names of the data that is reshaped.
            :param repldict: Mapping from inner data names to the names used
                             after inlining.
            :param new_edges: Mapping from inner data name to external edge.
            :param nstate: The inner state, whose source/sink nodes are
                           inspected.
            :param state: The outer state, in which edges are added.
            :param inputs: If True, handles source nodes and connects them
                           from the external edge's source; otherwise handles
                           sink nodes and connects them to the external
                           edge's destination.
        """
        endpoints = nstate.source_nodes() if inputs else nstate.sink_nodes()
        # Replaced (outer) name -> original inner name
        inner_name_of = {repldict[inner]: inner for inner in reshapes}

        for node in endpoints:
            if (isinstance(node, nodes.AccessNode)
                    and node.data in inner_name_of):
                outer_edge = new_edges[inner_name_of[node.data]]
                if inputs:
                    state.add_edge(outer_edge.src, outer_edge.src_conn, node,
                                   None, outer_edge.data)
                else:
                    state.add_edge(node, None, outer_edge.dst,
                                   outer_edge.dst_conn, outer_edge.data)
Beispiel #3
0
class InlineSDFG(pattern_matching.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG.

        In particular, the steps taken are:

        1. All transient arrays become transients of the parent
        2. If a source/sink node is one of the inputs/outputs:
          a. Remove it
          b. Reconnect through external edges (map/accessnode)
          c. Replace and reoffset memlets with external data descriptor
        3. If other nodes carry the names of inputs/outputs:
          a. Replace data with external data descriptor
          b. Replace and reoffset memlets with external data descriptor
        4. If source/sink node is not connected to a source/destination, and
           the nested SDFG is in a scope, connect to scope with empty memlets
        5. Remove all unused external inputs/output memlet paths
        6. Remove isolated nodes resulting from previous step

    """

    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), set(), set())

    @staticmethod
    def annotates_memlets():
        """ Reports whether this transformation annotates memlets.
            :return: Always True.
        """
        return True

    @staticmethod
    def expressions():
        """ Returns the pattern graphs matched by this transformation.
            :return: A one-element list holding the graph that
                     ``sdutil.node_path_graph`` builds from the placeholder
                     nested SDFG node.
        """
        # Matches anything
        nested_sdfg_pattern = sdutil.node_path_graph(InlineSDFG._nested_sdfg)
        return [nested_sdfg_pattern]

    @staticmethod
    def _find_edge(state: SDFGState, node: nodes.Node,
                   connector: str) -> Optional[MultiConnectorEdge]:
        """ Finds the edge attached to a given connector of a node.
            Incoming edges are searched first, then outgoing edges.
            :param state: The state that contains the node.
            :param node: The node whose edges are searched.
            :param connector: Name of the connector to look for.
            :return: The first edge found on the connector.
            :raise NameError: If no edge uses the connector.
        """
        searches = ((state.in_edges, 'dst_conn'), (state.out_edges,
                                                   'src_conn'))
        for list_edges, conn_attr in searches:
            for edge in list_edges(node):
                if getattr(edge, conn_attr) == connector:
                    return edge
        raise NameError('Edge with connector %s not found on node %s' %
                        (connector, node))

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Decides whether the matched nested SDFG node can be inlined.

            The match is rejected if the nested SDFG does not consist of
            exactly one state, if any connector has more than one edge, or
            (when the node is inside a scope) if the access nodes of the
            inner state are connected in a way that the checks below do not
            allow. In strict mode, it is also rejected if a non-transient
            inner array has more dimensions than its external memlet subset.
            :param graph: The state that contains the candidate node.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched pattern (not used here).
            :param sdfg: The SDFG containing ``graph`` (not used here).
            :param strict: If True, also performs the reshaping check.
            :return: True if the transformation can be applied.
        """
        nsdfg_node = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        if len(nsdfg_node.sdfg.nodes()) != 1:
            return False

        # Ensure every connector has one incoming/outgoing edge
        used_inputs = set()
        used_outputs = set()
        for in_edge in graph.in_edges(nsdfg_node):
            if in_edge.dst_conn in used_inputs:
                return False
            used_inputs.add(in_edge.dst_conn)
        for out_edge in graph.out_edges(nsdfg_node):
            if out_edge.src_conn in used_outputs:
                return False
            used_outputs.add(out_edge.src_conn)

        # Ensure output connectors have no additional outputs (if in a scope),
        # and ensure no two connectors are directly connected to each other
        if graph.entry_node(nsdfg_node) is not None:
            connected = used_inputs | used_outputs
            inner_state = nsdfg_node.sdfg.node(0)
            for anode in inner_state.nodes():
                if not isinstance(anode, nodes.AccessNode):
                    continue

                if (anode.data in used_outputs
                        and inner_state.out_degree(anode) > 0):
                    if (anode.data not in used_inputs
                            or inner_state.in_degree(anode) > 0):
                        return False

                if anode.data in used_inputs:
                    for inner_edge in inner_state.out_edges(anode):
                        if (isinstance(inner_edge.dst, nodes.AccessNode)
                                and inner_edge.dst.data in connected):
                            return False

        if not strict:
            return True

        # If some reshaping that cannot be inlined / unsqueezed is happening,
        # do not match transformation in strict mode.
        for aname, array in nsdfg_node.sdfg.arrays.items():
            if array.transient:
                continue
            # _find_edge raises NameError if no edge uses this connector
            outer_edge = InlineSDFG._find_edge(graph, nsdfg_node, aname)
            if len(array.shape) > len(outer_edge.data.subset):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string describing a match.
            :param graph: The graph in which the match was found.
            :param candidate: The matched candidate (not used here).
            :return: The label of ``graph``.
        """
        return graph.label

    def _remove_edge_path(self,
                          state: SDFGState,
                          edge_map: Dict[str, MultiConnectorEdge],
                          unused: Set[str],
                          reverse: bool = False) -> List[MultiConnectorEdge]:
        """ Remove all edges along a path, until memlet tree contains siblings
            that should not be removed. Removes resulting isolated nodes as
            well. Operates in place.
            :param state: The state in which to remove edges.
            :param edge_map: Mapping from identifier to edge, used as a
                             predicate for removal.
            :param unused: Set of edge identifiers to remove.
            :param reverse: If False, removes forward in path, otherwise
                            backward.
            :return: List of edges from removed nodes at the path's end.
        """

        if reverse:
            edge_func = lambda e: state.out_edges(e.src)
            edge_pred = lambda pedge, e: e.src_conn == pedge.src_conn
        else:
            edge_func = lambda e: state.in_edges(e.dst)
            edge_pred = lambda pedge, e: e.dst_conn == pedge.dst_conn

        result = []

        for identifier, edge in edge_map.items():
            if identifier in unused:
                path = state.memlet_path(edge)
                pedge = None
                for pedge in (reversed(path) if reverse else path):
                    # If there are no other edges, it is safe to remove
                    if len([
                            e for e in edge_func(pedge) if edge_pred(pedge, e)
                    ]) == 1:
                        # Remove connectors as well
                        state.remove_edge_and_connectors(pedge)
                    else:
                        break
                else:  # Reached terminus without breaking, remove external node
                    if pedge is not None:
                        node = pedge.src if reverse else pedge.dst

                        # Keep track of edges on the other end of these nodes,
                        # they will be used to reconnect to first/last
                        # occurrence of access nodes in the inlined subgraph.
                        if reverse:
                            result.extend(state.in_edges(node))
                        else:
                            result.extend(state.out_edges(node))

                        state.remove_node(node)

        return result

    def apply(self, sdfg):
        """ Inlines the matched nested SDFG node into the state of ``sdfg``
            that contains it.

            The nodes and edges of the first state of the nested SDFG are
            added to the outer state, data and symbol names are replaced by
            their outer counterparts, the boundary memlets are reconnected,
            and finally the nested SDFG node itself is removed.

            :param sdfg: The SDFG that contains the state (``self.state_id``)
                         holding the nested SDFG node to inline. It is
                         modified in place.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        # Only the first state of the nested SDFG is inlined
        # (presumably the match guarantees a single state -- TODO confirm)
        nstate: SDFGState = nsdfg.nodes()[0]

        # Scope (if any) surrounding the nested SDFG node in the outer state
        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants: the outer value wins on a name clash; a mismatch only
        # produces a warning
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        # inputs/outputs: connector name -> outer edge
        # input_set/output_set: outer data name -> connector name
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    # New name is prefixed with the nested SDFG's label
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                # NOTE(review): this lookup assumes the memlet names a data
                # container present in nsdfg.arrays -- confirm that empty
                # memlets between code nodes cannot occur here
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        # Maps a boundary access node of the nested state to the outer edge
        # it will be reconnected through
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        # Non-transient access nodes at the sources/sinks of the nested
        # state; these are not copied to the outer state
        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        # (everything except the boundary access nodes collected above)
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # Replace data names with their top-level counterparts
        # (transient renames first, then connector name -> outer data name)
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in subgraph.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Modify all other internal edges pertaining to input/output nodes
        # (node.data already holds the outer name after the renaming above)
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    # The input edge takes precedence when the data is both
                    # read and written by the nested SDFG
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        # Skip edges already handled by _modify_memlet_path
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): inputs/outputs are keyed by connector name (str),
        # while source_accesses/sink_accesses contain access node objects, so
        # these set differences appear to never remove anything -- confirm
        # whether the intent was to subtract the data names of those nodes
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        # Access nodes of the nested state in topological order
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        # NOTE(review): next() is called without a default below, so a
        # removed edge whose data matches no access node raises StopIteration
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)

    def _modify_memlet_path(self, new_edges: Dict[nodes.Node,
                                                  MultiConnectorEdge],
                            nstate: SDFGState, state: SDFGState,
                            inputs: bool) -> Set[MultiConnectorEdge]:
        """ Reconnects boundary access nodes of an inlined SDFG through the
            outer edges and unsqueezes the memlets further along each path.

            :param new_edges: Mapping from a boundary access node of the
                              nested state to the outer edge replacing it.
            :param nstate: The nested state the access nodes belong to.
            :param state: The outer state that receives the new edges.
            :param inputs: True to process the edges leaving each access
                           node (input side), False to process the edges
                           entering it (output side).
            :return: Set of edges below the newly added edges in their
                     memlet trees, whose memlets were replaced.
        """
        touched = set()
        for access, outer in new_edges.items():
            if inputs:
                nested_edges = nstate.out_edges(access)
            else:
                nested_edges = nstate.in_edges(access)

            for nested in nested_edges:
                unsqueezed = helpers.unsqueeze_memlet(nested.data, outer.data)

                # The outer edge supplies one endpoint, the nested edge the
                # other one
                if inputs:
                    src, src_conn = outer.src, outer.src_conn
                    dst, dst_conn = nested.dst, nested.dst_conn
                else:
                    src, src_conn = nested.src, nested.src_conn
                    dst, dst_conn = outer.dst, outer.dst_conn
                bridge = state.add_edge(src, src_conn, dst, dst_conn,
                                        unsqueezed)
                tree = state.memlet_tree(bridge)

                # Modify all memlets going forward/backward: pre-order walk
                # over the descendants of the new edge (the new edge itself
                # is not added to the result)
                pending = list(tree.children)[::-1]
                while pending:
                    current = pending.pop()
                    touched.add(current.edge)
                    current.edge._data = helpers.unsqueeze_memlet(
                        current.edge.data, outer.data)
                    pending.extend(list(current.children)[::-1])

        return touched
Beispiel #4
0
 def expressions():
     """ Returns the list of pattern graphs: a single SDFG built with '_'. """
     # Matches anything
     match_all = sd.SDFG('_')
     return [match_all]
Beispiel #5
0
 def expressions():
     """ Returns the list of pattern graphs: a single SDFG built with '_'. """
     pattern = sd.SDFG('_')
     return [pattern]
Beispiel #6
0
class InlineSDFG(pattern_matching.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG """

    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), set(), set())

    @staticmethod
    def annotates_memlets():
        """ Always returns True.

            NOTE(review): presumably tells the transformation framework that
            this transformation sets the resulting memlets itself -- confirm
            against the base class.
        """
        return True

    @staticmethod
    def expressions():
        """ Returns the list of pattern graphs: one path graph made of the
            ``_nested_sdfg`` placeholder node only.
        """
        # Matches anything
        pattern = nxutil.node_path_graph(InlineSDFG._nested_sdfg)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched nested SDFG can be inlined.

            :param graph: Graph in which the match was found.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Unused.
            :param sdfg: Unused.
            :param strict: Unused.
            :return: True if the nested SDFG has exactly one node (state),
                     False otherwise.
        """
        graph_nodes = graph.nodes()
        matched = graph_nodes[candidate[InlineSDFG._nested_sdfg]]
        return len(matched.sdfg.nodes()) == 1

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the string describing a match: the label of ``graph``.

            :param graph: Graph in which the match was found.
            :param candidate: The match itself (unused).
            :return: ``graph.label``.
        """
        return graph.label

    def _modify_memlet(self, internal_memlet: Memlet, external_memlet: Memlet):
        """ Unsqueezes and offsets a memlet, as per the semantics of nested
            SDFGs.
            :param internal_memlet: The internal memlet (inside nested SDFG)
                                    before modification. It is copied, not
                                    modified.
            :param external_memlet: The external memlet before modification.
            :return: Offset Memlet to set on the resulting graph.
            :raise ValueError: If the internal memlet has more dimensions
                               than the external one.
            :raise NotImplementedError: If the external memlet has an
                                        ``other_subset``.
        """
        adjusted = dc(internal_memlet)
        adjusted.data = external_memlet.data

        outer_extents = external_memlet.subset.size()
        inner_rank = len(internal_memlet.subset)
        outer_rank = len(external_memlet.subset)

        if inner_rank > outer_rank:
            raise ValueError(
                'Unexpected extra dimensions in internal memlet '
                'while inlining SDFG.\nExternal memlet: %s\n'
                'Internal memlet: %s' % (external_memlet, internal_memlet))

        if inner_rank < outer_rank:
            # Dimensions of the external subset that have size 1
            unit_dims = [
                idx for idx, extent in enumerate(outer_extents) if extent == 1
            ]

            # Special case: If internal memlet is a range of size 1 with (0,0,1),
            #               ignore it when unsqueezing
            trivial_inner = (inner_rank == 1
                             and (internal_memlet.subset[0] == (0, 0, 1)
                                  or internal_memlet.subset[0] == 0))
            if trivial_inner:
                unit_dims = unit_dims[1:]

            adjusted.subset.unsqueeze(unit_dims)

        adjusted.subset.offset(external_memlet.subset, False)

        # TODO: Offset rest of memlet according to other_subset
        if external_memlet.other_subset is not None:
            raise NotImplementedError

        return adjusted

    def apply(self, sdfg):
        """ Inlines the matched nested SDFG node into the state of ``sdfg``
            that contains it.

            The nodes of the first state of the nested SDFG are added to the
            outer state, its edges are re-created there (reconnected to the
            outer sources/destinations where they touch the nested SDFG's
            connectors), and the nested SDFG node is removed.

            :param sdfg: The SDFG that contains the state (``self.state_id``)
                         holding the nested SDFG node. Modified in place.
        """
        graph = sdfg.nodes()[self.state_id]
        nsdfg_node = graph.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg = nsdfg_node.sdfg

        # Find original source/destination nodes
        # inputs:  connector name -> (outer source node, its connector, memlet)
        # outputs: connector name -> (outer dest. node, its connector, memlet)
        inputs = {}
        outputs = {}
        for e in graph.in_edges(nsdfg_node):
            inputs[e.dst_conn] = (e.src, e.src_conn, e.data)
        for e in graph.out_edges(nsdfg_node):
            outputs[e.src_conn] = (e.dst, e.dst_conn, e.data)

        # Mapping from nested data name to outer data name; starts with the
        # connectors and is extended with internal nodes below
        torename = {}
        torename.update({k: v[2].data for k, v in inputs.items()})
        torename.update({k: v[2].data for k, v in outputs.items()})
        # Names of external data feeding a non-'IN_' connector of an entry node
        entry_connectors = set()

        # Add SDFG nodes to top-level SDFG
        # (only the first state of the nested SDFG is processed)
        state = nsdfg.nodes()[0]
        for node in state.nodes():
            # Data access nodes
            if isinstance(node, nodes.AccessNode):
                # External node
                if node.data in inputs or node.data in outputs:
                    for _, _, dst, dst_conn, _ in state.out_edges(node):
                        # Custom entry connector case
                        if (isinstance(dst, nodes.EntryNode)
                                and dst_conn[0:3] != 'IN_'):
                            entry_connectors.add(node.data)
                            # The descriptor object is shared with the nested
                            # SDFG (no copy), so marking it transient also
                            # changes nsdfg.arrays[node.data]
                            sdfg.arrays[node.data] = nsdfg.arrays[node.data]
                            sdfg.arrays[node.data].transient = True
                            graph.add_node(node)
                            # Keep the inner name for this node
                            torename.pop(node.data)
                            break
                    # External nodes are not added via the add_node call at
                    # the end of the loop body
                    continue
                # Internal node (e.g., transient)
                if node.data not in torename:
                    name = node.data
                    # Name already exists
                    if name in sdfg.arrays:
                        # Try "<label>_<name>", then "<label>_<name>_<i>"
                        name = '%s_%s' % (nsdfg.label, node.data)
                        i = 0
                        while name in sdfg.arrays:
                            name = '%s_%s_%d' % (nsdfg.label, node.data, i)
                            i += 1
                    # Add transient
                    sdfg.arrays[name] = nsdfg.arrays[node.data]
                    # Rename all internal uses
                    torename[node.data] = name
            # Set all parents of nested SDFG nodes in the inlined SDFG to their
            # new parent
            elif isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = graph
                node.sdfg.parent_sdfg = sdfg

            graph.add_node(node)

        # TODO: Confirm that the following is always correct
        # Add Scalars of the nested SDFG to the parent
        for name, arr in nsdfg.arrays.items():
            if isinstance(arr, dt.Scalar) and name not in sdfg.arrays:
                sdfg.arrays[name] = arr

        # Reconnect edges to their original source
        for e in state.edges():
            # Case 1: edge leaves an access node that is an input connector
            if isinstance(e.src, nodes.AccessNode) and e.src.data in inputs:
                cnode, cconn, cmemlet = inputs[e.src.data]
                if e.src.data in entry_connectors:
                    # Keep the access node between the outer source and the
                    # inner destination
                    graph.add_edge(cnode, cconn, e.src, None, cmemlet)
                    graph.add_edge(e.src, None, e.dst, e.dst_conn, e.data)
                else:
                    # Connect to source node instead
                    newmemlet = self._modify_memlet(e.data, cmemlet)
                    graph.add_edge(cnode, cconn, e.dst, e.dst_conn, newmemlet)
            # Case 2: edge enters an access node that is an output connector
            elif isinstance(e.dst, nodes.AccessNode) and e.dst.data in outputs:
                cnode, cconn, cmemlet = outputs[e.dst.data]
                newmemlet = self._modify_memlet(e.data, cmemlet)
                # The output access node is also read inside the nested state
                if state.out_edges(e.dst):
                    graph.add_edge(e.src, e.src_conn, e.dst, e.dst_conn,
                                   newmemlet)
                    # NOTE(review): mutates the nested state's edge in place
                    # so that the add_edge call after this block starts at
                    # the access node -- confirm this is intended
                    e._src = e._dst
                    e._src_conn = e._dst_conn
                    # Remove wcr
                    newmemlet = dc(newmemlet)
                    newmemlet.wcr = None
                    newmemlet.other_subset = dc(newmemlet.subset)
                    for _, _, dst, _, memlet in graph.out_edges(cnode):
                        if isinstance(dst, nodes.AccessNode
                                      ) and memlet.data == cmemlet.data:
                            memlet.wcr = None
                    # # Remove output node
                    # out_conn = 'OUT_{}'.format(cconn[3:])
                    # for _, conn, dst, _, _ in graph.out_edges(cnode):
                    #     if conn == out_conn:
                    #         graph.remove_node(dst)
                    # # Remove connectors
                    # in_connectors = dc(cnode.in_connectors)
                    # in_connectors.remove(cconn)
                    # cnode.in_connectors = in_connectors
                    # out_connectors = dc(cnode.out_connectors)
                    # out_connectors.remove(out_conn)
                    # cnode.out_connectors = out_connectors
                # else:
                # Connect to destination node instead
                # (runs in both cases, since the "else" above is commented out)
                graph.add_edge(e.src, e.src_conn, cnode, cconn, newmemlet)
            # Case 3: any other edge whose data has to be renamed
            elif e.data.data in torename:
                if e.data.data in inputs:
                    newmemlet = self._modify_memlet(e.data,
                                                    inputs[e.data.data][2])
                elif e.data.data in outputs:
                    newmemlet = self._modify_memlet(e.data,
                                                    outputs[e.data.data][2])
                else:
                    # Rename data
                    cdata = torename[e.data.data]
                    newmemlet = dc(e.data)
                    newmemlet.data = cdata

                graph.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, newmemlet)
            else:
                # Do nothing
                # (the edge is copied to the outer state unchanged)
                graph.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

        # Rename all access nodes
        for node in state.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in torename:
                node.data = torename[node.data]

        # If an empty memlet was connected to the nested SDFG, reconnect
        # all source nodes with empty memlets
        # (such an edge has no destination connector, hence the None key)
        if None in inputs:
            cnode, cconn, cmemlet = inputs[None]
            for node in state.source_nodes():
                graph.add_edge(cnode, cconn, node, None, EmptyMemlet())

        # Remove the nested SDFG node
        graph.remove_node(nsdfg_node)

        # Remove input/output nodes from top-level graph if not connected to
        # any internal node
        for node, _, _ in list(inputs.values()) + list(outputs.values()):
            if len(graph.all_edges(node)) == 0:
                graph.remove_node(node)