예제 #1
0
    def apply(self, sdfg):
        """
        Substitutes the assignments found on the single incoming inter-state
        edge of the matched end state directly into that state, then deletes
        those assignments from the edge.

        Symbols that are no longer free on any inter-state edge are removed
        from the SDFG; if their assigned value is itself an SDFG symbol, the
        remaining occurrences in the SDFG are replaced by that value.

        :param sdfg: The SDFG on which the transformation operates.
        """
        end_state = sdfg.nodes()[self.subgraph[
            StateAssignElimination._end_state]]
        incoming = sdfg.in_edges(end_state)[0]

        # An inter-state assignment that reads another assigned value is
        # undefined behavior (e.g., {m: n, n: m}), so every assignment may be
        # substituted independently of the others.
        candidates = _assignments_to_consider(sdfg, incoming)
        for name, value in candidates.items():
            end_state.replace(name, value)

        pending = {}
        for name in set(candidates):
            # Drop the assignment from the incoming edge
            del incoming.data.assignments[name]

            # A symbol still used by some inter-state edge must be kept
            if any(name in e.data.free_symbols for e in sdfg.edges()):
                continue

            # Otherwise, schedule the replacement and remove the symbol
            value = candidates[name]
            if value in sdfg.symbols:
                pending[name] = value
            if name in sdfg.symbols:
                sdfg.remove_symbol(name)

        def _replace_in_sdfg(mapping):
            for key, val in mapping.items():
                sdfg.replace(str(key), str(val))

        if pending:
            symbolic.safe_replace(pending, _replace_in_sdfg)
예제 #2
0
    def _check_strides(inner_strides: List[symbolic.SymbolicType],
                       outer_strides: List[symbolic.SymbolicType],
                       memlet: Memlet, nested_sdfg: nodes.NestedSDFG) -> bool:
        """
        Tests whether the strides of an array inside a nested SDFG agree with
        the strides of the corresponding outer array once the nested SDFG is
        inlined. The nested SDFG symbol mapping is applied to the inner
        strides, and outer dimensions in which the memlet subset has size 1
        are disregarded (memlet unsqueeze).
        :param inner_strides: The strides of the array inside the nested SDFG.
        :param outer_strides: The strides of the array in the external SDFG.
        :param memlet: Memlet whose subset sizes decide which outer
                       dimensions are ignored.
        :param nested_sdfg: Nested SDFG node with symbol mapping.
        :return: True if all strides match, False otherwise.
        """
        # Express the inner strides in terms of outer symbols
        mapped = list(inner_strides)

        def _substitute(symmap):
            for idx, stride in enumerate(mapped):
                if symbolic.issymbolic(stride):
                    mapped[idx] = stride.subs(symmap)

        symbolic.safe_replace(nested_sdfg.symbol_mapping, _substitute)

        # Exact match, no unsqueezing required
        if mapped == list(outer_strides):
            return True

        # Ignore outer dimensions in which the memlet covers one element
        unit_dims = {
            dim
            for dim, extent in enumerate(memlet.subset.size()) if extent == 1
        }
        remaining = [
            stride for dim, stride in enumerate(outer_strides)
            if dim not in unit_dims
        ]
        if not remaining:
            remaining = [1]

        if len(remaining) != len(mapped):
            return False

        for inner, outer in zip(mapped, remaining):
            if not (inner == outer):
                return False
        return True
예제 #3
0
    def apply(self, _, sdfg: SDFG):
        """
        Substitutes the assignments on the single incoming inter-state edge of
        the end state into that state and into everything reached from it by
        ``sdfg.bfs_edges`` (edges and their destination states), then deletes
        the assignments from the edge.

        Symbols that are no longer free on any inter-state edge are removed
        from the SDFG, and any occurrences still free in the SDFG are
        replaced by the assigned value.

        :param _: Unused.
        :param sdfg: The SDFG on which the transformation operates.
        """
        end_state = self.end_state
        in_edge = sdfg.in_edges(end_state)[0]

        # An inter-state assignment that reads another assigned value is
        # undefined behavior (e.g., {m: n, n: m}), so every assignment may be
        # substituted independently of the others.
        candidates = _assignments_to_consider(sdfg, in_edge, True)

        def _substitute(mapping, target, **options):
            # Clash-safe replacement of ``mapping`` inside ``target``
            def _callback(step):
                for key, value in step.items():
                    target.replace(str(key), str(value), **options)

            symbolic.safe_replace(mapping, _callback)

        # Replace in the end state itself, then in every successor
        _substitute(candidates, end_state)
        seen = {in_edge}
        for succ_edge in sdfg.bfs_edges(end_state):
            if succ_edge not in seen:
                # Assignment targets (keys) on later edges stay untouched
                _substitute(candidates, succ_edge.data, replace_keys=False)
                seen.add(succ_edge)
            successor = succ_edge.dst
            if successor not in seen:
                _substitute(candidates, successor)
                seen.add(successor)

        leftovers = {}
        for name in candidates.keys():
            # Drop the assignment from the incoming edge
            del in_edge.data.assignments[name]

            # A symbol still used by some inter-state edge must be kept
            if any(name in e.data.free_symbols for e in sdfg.edges()):
                continue

            # Otherwise, remove the symbol and replace what is still free
            if name in sdfg.symbols:
                sdfg.remove_symbol(name)
            if name in sdfg.free_symbols:
                leftovers[name] = candidates[name]

        if leftovers:
            _substitute(leftovers, sdfg)
예제 #4
0
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """
        Decides whether the matched pair of maps (first map exit followed by
        second map entry) may be fused.

        The match is rejected when any of the following holds:

        * a write-conflict-resolution memlet entering the first map exit
          writes data that is read through the second map entry;
        * an out-edge of the first map exit leads to anything other than an
          access node, or the accessed data appears in more than one access
          node across the states of ``sdfg``;
        * ``MapFusion.find_permutation`` finds no permutation between the
          two maps;
        * an intermediate access node feeding the second map has more than
          one outgoing edge;
        * a non-intermediate input of the second map is a descendant of an
          intermediate access node;
        * a subset read by the second map is not covered by a subset written
          by the first map;
        * data that is both an input of the first map and an output of the
          second map is accessed with differing subsets.

        :param graph: State in which the candidate was matched.
        :param candidate: Mapping from pattern nodes to node indices in
                          ``graph``.
        :param expr_index: Unused in this method.
        :param sdfg: SDFG containing ``graph``.
        :param permissive: Unused in this method.
        :return: True if the maps can be fused, False otherwise.
        """
        first_map_exit = graph.nodes()[candidate[MapFusion.first_map_exit]]
        first_map_entry = graph.entry_node(first_map_exit)
        second_map_entry = graph.nodes()[candidate[MapFusion.second_map_entry]]
        second_map_exit = graph.exit_node(second_map_entry)

        for _in_e in graph.in_edges(first_map_exit):
            if _in_e.data.wcr is not None:
                for _out_e in graph.out_edges(second_map_entry):
                    if _out_e.data.data == _in_e.data.data:
                        # wcr is on a node that is used in the second map, quit
                        return False
        # Check whether there is a pattern map -> access -> map.
        intermediate_nodes = set()
        intermediate_data = set()
        for _, _, dst, _, _ in graph.out_edges(first_map_exit):
            if isinstance(dst, nodes.AccessNode):
                intermediate_nodes.add(dst)
                intermediate_data.add(dst.data)

                # If the array is accessed by more than one access node in
                # any state of the SDFG (not only this state), fail.
                num_occurrences = len([
                    n for s in sdfg.nodes() for n in s.nodes()
                    if isinstance(n, nodes.AccessNode) and n.data == dst.data
                ])
                if num_occurrences > 1:
                    return False
            else:
                return False
        # Check map ranges
        perm = MapFusion.find_permutation(first_map_entry.map,
                                          second_map_entry.map)
        if perm is None:
            return False

        # Check if any intermediate transient is also going to another location
        second_inodes = set(e.src for e in graph.in_edges(second_map_entry)
                            if isinstance(e.src, nodes.AccessNode))
        transients_to_remove = intermediate_nodes & second_inodes
        # Earlier form of the check below, kept for reference:
        # if any(e.dst != second_map_entry for n in transients_to_remove
        #        for e in graph.out_edges(n)):
        if any(graph.out_degree(n) > 1 for n in transients_to_remove):
            return False

        # Create a dict that maps parameters of the first map to those of the
        # second map.
        params_dict = {}
        for _index, _param in enumerate(second_map_entry.map.params):
            params_dict[_param] = first_map_entry.map.params[perm[_index]]

        # Memlets written inside the first map (edges into its exit node)
        out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

        # Check that input set of second map is provided by the output set
        # of the first map, or other unrelated maps
        for second_edge in graph.out_edges(second_map_entry):
            # Memlets that do not come from one of the intermediate arrays
            if second_edge.data.data not in intermediate_data:
                # however, if intermediate_data eventually leads to
                # second_memlet.data, need to fail.
                for _n in intermediate_nodes:
                    source_node = _n
                    destination_node = graph.memlet_path(second_edge)[0].src
                    # NOTE: Assumes graph has networkx version
                    if destination_node in nx.descendants(
                            graph._nx, source_node):
                        return False
                continue

            provided = False

            # Compute second subset with respect to first subset's symbols
            sbs_permuted = dcpy(second_edge.data.subset)
            if sbs_permuted:
                # Create intermediate dicts to avoid conflicts, such as {i:j, j:i}
                symbolic.safe_replace(params_dict,
                                      lambda m: sbs_permuted.replace(m))

            for first_memlet in out_memlets:
                if first_memlet.data != second_edge.data.data:
                    continue

                # If there is a covered subset, it is provided
                if first_memlet.subset.covers(sbs_permuted):
                    provided = True
                    break

            # If none of the output memlets of the first map provide the info,
            # fail.
            if provided is False:
                return False

        # Checking for stencil pattern and common input/output data
        # (after fusing the maps)
        first_map_inputnodes = {
            e.src: e.src.data
            for e in graph.in_edges(first_map_entry)
            if isinstance(e.src, nodes.AccessNode)
        }
        # Input views are swapped for the source node of their view edge, so
        # that the comparison below uses that node's data name.
        input_views = set()
        viewed_inputnodes = dict()
        for n in first_map_inputnodes.keys():
            if isinstance(n.desc(sdfg), data.View):
                input_views.add(n)
        for v in input_views:
            del first_map_inputnodes[v]
            e = sdutil.get_view_edge(graph, v)
            if e:
                first_map_inputnodes[e.src] = e.src.data
                viewed_inputnodes[e.src.data] = v
        second_map_outputnodes = {
            e.dst: e.dst.data
            for e in graph.out_edges(second_map_exit)
            if isinstance(e.dst, nodes.AccessNode)
        }
        # Same treatment for output views, using the destination node of the
        # view edge.
        output_views = set()
        viewed_outputnodes = dict()
        for n in second_map_outputnodes:
            if isinstance(n.desc(sdfg), data.View):
                output_views.add(n)
        for v in output_views:
            del second_map_outputnodes[v]
            e = sdutil.get_view_edge(graph, v)
            if e:
                second_map_outputnodes[e.dst] = e.dst.data
                viewed_outputnodes[e.dst.data] = v
        # Data that is read by the first map and written by the second map
        common_data = set(first_map_inputnodes.values()).intersection(
            set(second_map_outputnodes.values()))
        if common_data:
            # Translate back to the view names used on the map's memlets
            input_data = [
                viewed_inputnodes[d].data
                if d in viewed_inputnodes.keys() else d for d in common_data
            ]
            input_accesses = [
                graph.memlet_path(e)[-1].data.src_subset
                for e in graph.out_edges(first_map_entry)
                if e.data.data in input_data
            ]
            # All input accesses must be pairwise identical: offsetting one
            # subset by another must leave (0, 0, 1) in every dimension.
            if len(input_accesses) > 1:
                for i, a in enumerate(input_accesses[:-1]):
                    for b in input_accesses[i + 1:]:
                        if isinstance(a, subsets.Indices):
                            c = subsets.Range.from_indices(a)
                            c.offset(b, negative=True)
                        else:
                            c = a.offset_new(b, negative=True)
                        for r in c:
                            if r != (0, 0, 1):
                                return False

            output_data = [
                viewed_outputnodes[d].data
                if d in viewed_outputnodes.keys() else d for d in common_data
            ]
            output_accesses = [
                graph.memlet_path(e)[0].data.dst_subset
                for e in graph.in_edges(second_map_exit)
                if e.data.data in output_data
            ]

            # Compute output accesses with respect to first map's symbols
            oacc_permuted = [dcpy(a) for a in output_accesses]
            for a in oacc_permuted:
                # Create intermediate dicts to avoid conflicts, such as {i:j, j:i}
                symbolic.safe_replace(params_dict, lambda m: a.replace(m))

            # Compare the first input access against every (permuted) output
            # access using the same zero-offset test as above.
            # NOTE(review): indexes input_accesses[0] without an emptiness
            # check - confirm input_accesses cannot be empty here.
            a = input_accesses[0]
            for b in oacc_permuted:
                if isinstance(a, subsets.Indices):
                    c = subsets.Range.from_indices(a)
                    c.offset(b, negative=True)
                else:
                    c = a.offset_new(b, negative=True)
                for r in c:
                    if r != (0, 0, 1):
                        return False

        # Success
        return True
예제 #5
0
    def apply(self, sdfg: SDFG):
        """
        Moves the start state of the nested SDFG out into the outer SDFG.

        A new state is added before the state that contains the nested SDFG
        node. The nodes and edges of the nested SDFG's start state are placed
        into it, with data names and symbols renamed according to the nested
        SDFG's symbol mapping and connectors. The assignments and condition
        of the first inner inter-state edge are carried over to the
        inter-state edge between the new state and the original state.
        Finally, the inner start state is removed and the destination of that
        inner edge becomes the nested SDFG's start state.

        :param sdfg: The outer SDFG containing the nested SDFG node.
        """
        nsdfg: nodes.NestedSDFG = self.nsdfg(sdfg)
        state = sdfg.node(self.state_id)

        new_state = sdfg.add_state_before(state)
        isedge = sdfg.edges_between(new_state, state)[0]

        # Find relevant symbol and data descriptor mapping
        # (inner name -> outer name/expression, all as strings)
        mapping: Dict[str, str] = {}
        mapping.update({k: str(v) for k, v in nsdfg.symbol_mapping.items()})
        mapping.update({
            k: next(iter(state.in_edges_by_connector(nsdfg, k))).data.data
            for k in nsdfg.in_connectors
        })
        mapping.update({
            k: next(iter(state.out_edges_by_connector(nsdfg, k))).data.data
            for k in nsdfg.out_connectors
        })

        # Get internal state and interstate edge
        source_state = nsdfg.sdfg.start_state
        nisedge = nsdfg.sdfg.out_edges(source_state)[0]

        # Add state contents (nodes)
        new_state.add_nodes_from(source_state.nodes())

        # Replace data descriptors and symbols on state graph
        for node in source_state.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in mapping:
                node.data = mapping[node.data]
        for edge in source_state.edges():
            edge.data.replace(mapping)
            if edge.data.data in mapping:
                edge.data.data = mapping[edge.data.data]

        # Add state contents (edges)
        for edge in source_state.edges():
            new_state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                               edge.data)

        # Safe replacement of edge contents. The callback must apply the
        # dictionary handed to it by safe_replace (first the names mapped to
        # temporaries, then the temporaries mapped to the final values);
        # iterating ``mapping`` directly would bypass the two-step scheme and
        # break on clashing replacements such as {i: j, j: i}.
        def replfunc(m):
            for k, v in m.items():
                # Keys/values may be symbolic objects; replace by string form
                nisedge.data.replace(str(k), str(v), replace_keys=False)

        symbolic.safe_replace(mapping, replfunc)

        # Add interstate edge
        for akey, aval in nisedge.data.assignments.items():
            # Map assignment to outer edge; prefix with the nested SDFG label
            # if the name is already taken in the outer SDFG
            if akey not in sdfg.symbols and akey not in sdfg.arrays:
                newname = akey
            else:
                newname = nsdfg.label + '_' + akey

            isedge.data.assignments[newname] = aval

            # Add symbol to outer SDFG
            sdfg.add_symbol(newname, nsdfg.sdfg.symbols[akey])

            # Add symbol mapping to nested SDFG
            nsdfg.symbol_mapping[akey] = newname

        isedge.data.condition = nisedge.data.condition

        # Clean nested SDFG
        nsdfg.sdfg.remove_node(source_state)

        # Set new starting state
        nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)
예제 #6
0
    def apply(self, outer_state: SDFGState, sdfg: SDFG):
        """
        Inlines all states of the matched nested SDFG into ``sdfg``, taking
        the place of ``outer_state``.

        Global/init/exit code and constants of the nested SDFG are merged
        into ``sdfg``, nested transients are registered in ``sdfg`` (renamed
        on conflict), nested symbols and data names are replaced by their
        outer counterparts, and the nested states and inter-state edges are
        added to ``sdfg`` and wired to the predecessors and successors of
        ``outer_state``, which is removed at the end.

        :param outer_state: The state that contains the nested SDFG node.
        :param sdfg: The SDFG into which the nested states are inlined.
        :return: The states of the nested SDFG, which are now part of
                 ``sdfg``.
        """
        nsdfg_node = self.nested_sdfg
        nsdfg: SDFG = nsdfg_node.sdfg

        # Propagate a non-default schedule of the nested SDFG node inwards
        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Environments
        for nstate in nsdfg.nodes():
            for node in nstate.nodes():
                if isinstance(node, nodes.CodeNode):
                    node.environments |= nsdfg_node.environments

        # Constants (the outer value is kept on mismatch; only a warning is
        # issued)
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Symbols (declared symbols plus those defined on inter-state edges)
        outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
        for ise in sdfg.edges():
            outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in outer_state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in outer_state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

        # Access nodes that need to be reshaped
        # reshapes: Set(str) = set()
        # for aname, array in nsdfg.arrays.items():
        #     if array.transient:
        #         continue
        #     edge = None
        #     if aname in inputs:
        #         edge = inputs[aname]
        #         if len(array.shape) > len(edge.data.subset):
        #             reshapes.add(aname)
        #             continue
        #     if aname in outputs:
        #         edge = outputs[aname]
        #         if len(array.shape) > len(edge.data.subset):
        #             reshapes.add(aname)
        #             continue
        #     if edge is not None and not InlineMultistateSDFG._check_strides(
        #             array.strides, sdfg.arrays[edge.data.data].strides,
        #             edge.data, nsdfg_node):
        #         reshapes.add(aname)

        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}

        # All transients become transients of the parent (if data already
        # exists, find new name)
        for nstate in nsdfg.nodes():
            for node in nstate.nodes():
                if isinstance(node, nodes.AccessNode):
                    datadesc = nsdfg.arrays[node.data]
                    if node.data not in transients and datadesc.transient:
                        new_name = node.data
                        # Prefix with the nested SDFG label on name clashes
                        if (new_name in sdfg.arrays
                                or new_name in outer_symbols
                                or new_name in sdfg.constants):
                            new_name = f'{nsdfg.label}_{node.data}'

                        name = sdfg.add_datadesc(new_name,
                                                 datadesc,
                                                 find_new_name=True)
                        transients[node.data] = name

            # All transients of edges between code nodes are also added to parent
            for edge in nstate.edges():
                if (isinstance(edge.src, nodes.CodeNode)
                        and isinstance(edge.dst, nodes.CodeNode)):
                    if edge.data.data is not None:
                        datadesc = nsdfg.arrays[edge.data.data]
                        if edge.data.data not in transients and datadesc.transient:
                            new_name = edge.data.data
                            if (new_name in sdfg.arrays
                                    or new_name in outer_symbols
                                    or new_name in sdfg.constants):
                                new_name = f'{nsdfg.label}_{edge.data.data}'

                            name = sdfg.add_datadesc(new_name,
                                                     datadesc,
                                                     find_new_name=True)
                            transients[edge.data.data] = name

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        # (transients -> new names, connectors -> outer data names)
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        symbolic.safe_replace(repldict,
                              lambda m: replace_datadesc_names(nsdfg, m),
                              value_as_string=True)

        # Add views whenever reshapes are necessary
        # for dname in reshapes:
        #     desc = nsdfg.arrays[dname]
        #     # To avoid potential confusion, rename protected __return keyword
        #     if dname.startswith('__return'):
        #         newname = f'{nsdfg.name}_ret{dname[8:]}'
        #     else:
        #         newname = dname
        #     newname, _ = sdfg.add_view(newname,
        #                                desc.shape,
        #                                desc.dtype,
        #                                storage=desc.storage,
        #                                strides=desc.strides,
        #                                offset=desc.offset,
        #                                debuginfo=desc.debuginfo,
        #                                allow_conflicts=desc.allow_conflicts,
        #                                total_size=desc.total_size,
        #                                alignment=desc.alignment,
        #                                may_alias=desc.may_alias,
        #                                find_new_name=True)
        #     repldict[dname] = newname

        # Add extra access nodes for out/in view nodes
        # inv_reshapes = {repldict[r]: r for r in reshapes}
        # for nstate in nsdfg.nodes():
        #     for node in nstate.nodes():
        #         if isinstance(node,
        #                       nodes.AccessNode) and node.data in inv_reshapes:
        #             if nstate.in_degree(node) > 0 and nstate.out_degree(
        #                     node) > 0:
        #                 # Such a node has to be in the output set
        #                 edge = outputs[inv_reshapes[node.data]]

        #                 # Redirect outgoing edges through access node
        #                 out_edges = list(nstate.out_edges(node))
        #                 anode = nstate.add_access(edge.data.data)
        #                 vnode = nstate.add_access(node.data)
        #                 nstate.add_nedge(node, anode, edge.data)
        #                 nstate.add_nedge(anode, vnode, edge.data)
        #                 for e in out_edges:
        #                     nstate.remove_edge(e)
        #                     nstate.add_edge(vnode, e.src_conn, e.dst,
        #                                     e.dst_conn, e.data)

        # Make unique names for states
        statenames = set(s.label for s in sdfg.nodes())
        for nstate in nsdfg.nodes():
            if nstate.label in statenames:
                newname = data.find_new_name(nstate.label, statenames)
                statenames.add(newname)
                nstate.set_label(newname)

        #######################################################
        # Collect and modify interstate edges as necessary

        outer_assignments = set()
        for e in sdfg.edges():
            outer_assignments |= e.data.assignments.keys()

        inner_assignments = set()
        for e in nsdfg.edges():
            inner_assignments |= e.data.assignments.keys()

        # Symbols assigned on both inner and outer inter-state edges are
        # renamed inside the nested SDFG before merging.
        assignments_to_replace = inner_assignments & outer_assignments
        sym_replacements: Dict[str, str] = {}
        allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
        for assign in assignments_to_replace:
            newname = data.find_new_name(assign, allnames)
            allnames.add(newname)
            sym_replacements[assign] = newname
        nsdfg.replace_dict(sym_replacements)

        #######################################################
        # Add nested SDFG states into top-level SDFG

        # Remember the start state before adding nodes, for the check below
        outer_start_state = sdfg.start_state

        sdfg.add_nodes_from(nsdfg.nodes())
        for ise in nsdfg.edges():
            sdfg.add_edge(ise.src, ise.dst, ise.data)

        #######################################################
        # Reconnect inlined SDFG

        source = nsdfg.start_state
        sinks = nsdfg.sink_nodes()

        # Reconnect state machine: predecessors of the outer state lead to
        # the nested start state, and every nested sink state leads to the
        # successors of the outer state.
        for e in sdfg.in_edges(outer_state):
            sdfg.add_edge(e.src, source, e.data)
        for e in sdfg.out_edges(outer_state):
            for sink in sinks:
                sdfg.add_edge(sink, e.dst, e.data)

        # Modify start state as necessary
        if outer_start_state is outer_state:
            sdfg.start_state = sdfg.node_id(source)

        # TODO: Modify memlets by offsetting
        # If both source and sink nodes are inputs/outputs, reconnect once
        # edges_to_ignore = self._modify_access_to_access(new_incoming_edges,
        #                                                 nsdfg, nstate, state,
        #                                                 orig_data)

        # source_to_outer = {n: e.src for n, e in new_incoming_edges.items()}
        # sink_to_outer = {n: e.dst for n, e in new_outgoing_edges.items()}
        # # If a source/sink node is one of the inputs/outputs, reconnect it,
        # # replacing memlets in outgoing/incoming paths
        # modified_edges = set()
        # modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
        #                                            state, sink_to_outer, True,
        #                                            edges_to_ignore)
        # modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
        #                                            state, source_to_outer,
        #                                            False, edges_to_ignore)

        # # Reshape: add connections to viewed data
        # self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
        #                           True)
        # self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
        #                           False)

        # Modify all other internal edges pertaining to input/output nodes
        # for nstate in nsdfg.nodes():
        #     for node in nstate.nodes():
        #         if isinstance(node, nodes.AccessNode):
        #             if node.data in input_set or node.data in output_set:
        #                 if node.data in input_set:
        #                     outer_edge = inputs[input_set[node.data]]
        #                 else:
        #                     outer_edge = outputs[output_set[node.data]]

        #                 for edge in state.all_edges(node):
        #                     if (edge not in modified_edges
        #                             and edge.data.data == node.data):
        #                         for e in state.memlet_tree(edge):
        #                             if e.data.data == node.data:
        #                                 e._data = helpers.unsqueeze_memlet(
        #                                     e.data, outer_edge.data)

        # Replace nested SDFG parents with new SDFG
        for nstate in nsdfg.nodes():
            nstate.parent = sdfg
            for node in nstate.nodes():
                if isinstance(node, nodes.NestedSDFG):
                    node.sdfg.parent_sdfg = sdfg
                    node.sdfg.parent_nsdfg_node = node

        #######################################################
        # Remove nested SDFG and state
        sdfg.remove_node(outer_state)

        return nsdfg.nodes()