예제 #1
0
def fusion(sdfg: dace.SDFG,
           graph: dace.SDFGState,
           subgraph: Union[SubgraphView, List[SubgraphView]] = None,
           **kwargs):
    """Run ``SubgraphFusion.fuse`` on one or several subgraphs of ``graph``.

    :param sdfg: SDFG that owns ``graph``.
    :param graph: State in which the maps to fuse live.
    :param subgraph: A single subgraph or a list of them. When falsy, the
                     whole ``graph`` is used as the only target.
    :param kwargs: Attribute name / value pairs that are set on the
                   ``SubgraphFusion`` instance before it is used.
    """
    if not subgraph:
        subgraph = graph
    targets = subgraph if isinstance(subgraph, List) else [subgraph]

    # Configure the transformation object from the keyword arguments.
    fuser = SubgraphFusion()
    for attr_name, attr_value in kwargs.items():
        setattr(fuser, attr_name, attr_value)

    for target in targets:
        top_level_entries = helpers.get_highest_scope_maps(sdfg, graph, target)
        is_view = isinstance(target, SubgraphView)

        # Drop the map entries (and their exits, if listed) from the view
        # before the transformation is applied.
        if is_view:
            for entry in top_level_entries:
                target.nodes().remove(entry)
                exit_node = graph.exit_node(entry)
                if exit_node in target.nodes():
                    target.nodes().remove(exit_node)

        print(f"Subgraph Fusion on map entries {top_level_entries}")
        fuser.fuse(sdfg, graph, top_level_entries)

        # Register the map entry produced by the fusion in the view.
        if is_view:
            target.nodes().append(fuser._global_map_entry)
예제 #2
0
    def apply(self, sdfg, subgraph, do_not_override=None, **kwargs):
        """Fuse the highest-scope maps of ``subgraph``.

        Stores ``subgraph`` on the instance, collects its highest-scope map
        entries and forwards them to ``self.fuse``.

        :param sdfg: SDFG that contains the subgraph.
        :param subgraph: Subgraph view to operate on; its ``graph``
                         attribute is used as the enclosing state.
        :param do_not_override: Forwarded positionally to ``self.fuse``.
                                Defaults to a new empty list on every call.
        :param kwargs: Extra keyword arguments forwarded to ``self.fuse``.
        """
        # Use a fresh list per call: a literal ``[]`` default is created once
        # at definition time and would be shared (and possibly mutated)
        # across unrelated invocations.
        if do_not_override is None:
            do_not_override = []

        self.subgraph = subgraph

        graph = subgraph.graph

        map_entries = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
        self.fuse(sdfg, graph, map_entries, do_not_override, **kwargs)
예제 #3
0
    def apply(self, sdfg, map_base_variables=None):
        """Expand the highest-scope maps of this transformation's subgraph.

        :param sdfg: SDFG the transformation is applied to.
        :param map_base_variables: Passed through to ``self.expand`` under
                                   the keyword of the same name.
        """
        # Resolve the subgraph view and the state it belongs to.
        view = self.subgraph_view(sdfg)
        state = view.graph

        # Collect the base maps and hand them to the expansion routine.
        top_level_maps = helpers.get_highest_scope_maps(sdfg, state, view)
        self.expand(sdfg,
                    state,
                    top_level_maps,
                    map_base_variables=map_base_variables)
예제 #4
0
파일: smax_test.py 프로젝트: 1C4nfaN/dace
def expand_maps(sdfg: dace.SDFG,
                graph: dace.SDFGState,
                subgraph: Union[SubgraphView, List[SubgraphView]] = None,
                **kwargs):
    """Run ``MultiExpansion.expand`` on one or several subgraphs of ``graph``.

    :param sdfg: SDFG that owns ``graph``.
    :param graph: State in which the maps to expand live.
    :param subgraph: A single subgraph or a list of them. When falsy, the
                     whole ``graph`` is used as the only target.
    :param kwargs: Attribute name / value pairs that are set on the
                   ``MultiExpansion`` instance before it is used.
    """
    if not subgraph:
        subgraph = graph
    targets = subgraph if isinstance(subgraph, List) else [subgraph]

    # Configure the transformation object from the keyword arguments.
    expander = MultiExpansion()
    for attr_name, attr_value in kwargs.items():
        setattr(expander, attr_name, attr_value)

    for target in targets:
        top_level_entries = helpers.get_highest_scope_maps(sdfg, graph, target)
        expander.expand(sdfg, graph, top_level_entries)
예제 #5
0
    def match(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        """Tell whether the subgraph qualifies for this transformation.

        Returns False when any node of ``subgraph`` is missing from
        ``subgraph.graph``, when at most one highest-scope map is found, or
        when ``helpers.common_map_base_ranges`` yields an empty result for
        those maps. Returns True otherwise.
        """
        state = subgraph.graph

        # Every node of the view has to live in the state it points to.
        if any(n not in state.nodes() for n in subgraph.nodes()):
            return False

        # Gather the maps and the base ranges they have in common.
        top_level_maps = helpers.get_highest_scope_maps(sdfg, state, subgraph)
        shared_ranges = helpers.common_map_base_ranges(top_level_maps)

        # Need more than one map, and at least one common base range.
        return len(top_level_maps) > 1 and len(shared_ranges) != 0
예제 #6
0
    def can_be_applied(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        '''
        Decide whether the highest-scope maps of ``subgraph`` can be fused.

        Fusible if
        1. Maps have the same access sets and ranges in order
           (and all map entries share the same schedule)
        2. Any nodes in between two maps are AccessNodes only, without WCR
           There is at most one AccessNode only on a path between two maps,
           no other nodes are allowed
        3. The exiting memlets' subsets to an intermediate edge must cover
           the respective incoming memlets' subset into the next map

        :param sdfg: SDFG that contains the subgraph.
        :param subgraph: Subgraph view whose maps are examined; its
                         ``graph`` attribute is used as the enclosing state.
        :return: True if all checks pass, False otherwise.
        :raises NotImplementedError: If a node between two maps has an
                                     incoming edge whose source is not one
                                     of the map exits.
        '''
        # get graph
        graph = subgraph.graph
        for node in subgraph.nodes():
            if node not in graph.nodes():
                return False

        # next, get all the maps
        map_entries = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
        maps = [map_entry.map for map_entry in map_entries]

        # 1. check whether all map ranges and indices are the same
        if len(maps) <= 1:
            return False
        base_map = maps[0]
        for current_map in maps:
            if current_map.get_param_num() != base_map.get_param_num():
                return False
            if not all(p1 == p2 for (p1, p2) in zip(current_map.params,
                                                    base_map.params)):
                return False
            if not current_map.range == base_map.range:
                return False

        # 1.1 check whether all map entries have the same schedule
        schedule = map_entries[0].schedule
        if not all(entry.schedule == schedule for entry in map_entries):
            return False

        # 2. check intermediate feasiblility
        # see map_fusion.py for similar checks
        # we are being more relaxed here

        # 2.1 do some preparation work first:
        # calculate all out_nodes and intermediate_nodes
        # definition see in apply()
        # A node written by a map exit is "intermediate" if it feeds one of
        # the map entries and an "out node" if it is a sink or feeds anything
        # else; a node may end up in both sets.
        intermediate_nodes = set()
        out_nodes = set()
        for map_exit in map_exits:
            for edge in graph.out_edges(map_exit):
                current_node = edge.dst
                if len(graph.out_edges(current_node)) == 0:
                    out_nodes.add(current_node)
                else:
                    for dst_edge in graph.out_edges(current_node):
                        if dst_edge.dst in map_entries:
                            intermediate_nodes.add(current_node)
                        else:
                            out_nodes.add(current_node)

        # 2.2 topological feasibility:
        # For each intermediate and out node: must never reach any map
        # entry if it is not connected to map entry immediately
        visited = set()

        # for memoization purposes
        def visit_descendants(graph, node, visited, map_entries):
            # Reaching one of the maps to be fused through an indirect path
            # is exactly what this check has to reject. Without this test
            # the traversal could never report a failure.
            if node in map_entries:
                return False
            # if we have already been at this node
            if node in visited:
                return True
            # not necessary to add if there aren't any other in connections
            if len(graph.in_edges(node)) > 1:
                visited.add(node)
            for oedge in graph.out_edges(node):
                if not visit_descendants(graph, oedge.dst, visited,
                                         map_entries):
                    return False
            return True

        for node in intermediate_nodes | out_nodes:
            # these nodes must not lead to a map entry
            nodes_to_check = set()
            for oedge in graph.out_edges(node):
                if oedge.dst not in map_entries:
                    nodes_to_check.add(oedge.dst)

            for forbidden_node in nodes_to_check:
                if not visit_descendants(graph, forbidden_node, visited,
                                         map_entries):
                    return False

        # 2.3 memlet feasibility
        # For each intermediate node, look at whether inner adjacent
        # memlets of the exiting map cover inner adjacent memlets
        # of the next entering map.
        # We also check for any WCRs on the fly.

        for node in intermediate_nodes:
            upper_subsets = set()
            lower_subsets = set()
            # First, determine which dimensions of the memlet ranges
            # change with the map, we do not need to care about the other dimensions.
            dims_to_discard = SubgraphFusion.get_invariant_dimensions(
                sdfg, graph, map_entries, map_exits, node)

            # find upper_subsets
            for in_edge in graph.in_edges(node):
                # first check for WCRs
                if in_edge.data.wcr:
                    return False
                if in_edge.src in map_exits:
                    # second-to-last edge of the memlet path, i.e. the edge
                    # just before the one entering ``node``
                    edge = graph.memlet_path(in_edge)[-2]
                    subset_to_add = dcpy(edge.data.subset
                                         if edge.data.data == node.data
                                         else edge.data.other_subset)
                    subset_to_add.pop(dims_to_discard)
                    upper_subsets.add(subset_to_add)
                else:
                    raise NotImplementedError("Nodes between two maps to be "
                                              "fused with *incoming* edges "
                                              "from outside the maps are not "
                                              "allowed yet.")

            # find lower_subsets
            for out_edge in graph.out_edges(node):
                if out_edge.dst in map_entries:
                    # cannot use memlet tree here as there could be
                    # not just one map succeeding. Do it manually
                    for oedge in graph.out_edges(out_edge.dst):
                        # match the entry's outgoing connector to the
                        # incoming one by stripping their prefixes
                        # (presumably "OUT_" / "IN_" -- confirm against the
                        # connector naming convention)
                        if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                            # Select the subset from ``oedge`` itself; the
                            # variable ``edge`` belongs to the loop above.
                            subset_to_add = dcpy(oedge.data.subset
                                                 if oedge.data.data == node.data
                                                 else oedge.data.other_subset)
                            subset_to_add.pop(dims_to_discard)
                            lower_subsets.add(subset_to_add)

            upper_iter = iter(upper_subsets)
            union_upper = next(upper_iter)

            # TODO: add this check at a later point
            # We assume that upper_subsets for each data array
            # are contiguous
            # or do the full check if possible (intersection needed)

            # FORNOW: just omit warning if unsure
            for lower_subset in lower_subsets:
                covers = False
                for upper_subset in upper_subsets:
                    if upper_subset.covers(lower_subset):
                        covers = True
                        break
                if not covers:
                    warnings.warn(
                        f"WARNING: For node {node}, please check assure that "
                        "incoming memlets cover outgoing ones. Ambiguous check (WIP)."
                    )

            # now take union of upper subsets
            for subs in upper_iter:
                union_upper = subsets.union(union_upper, subs)
                if not union_upper:
                    # something went wrong using union -- we'd rather abort
                    return False

            # finally check coverage
            for lower_subset in lower_subsets:
                if not union_upper.covers(lower_subset):
                    return False

        return True