def fusion(sdfg: dace.SDFG,
           graph: dace.SDFGState,
           subgraph: Union[SubgraphView, List[SubgraphView]] = None,
           **kwargs):
    """
    Run ``SubgraphFusion.fuse`` on the highest-scope maps of one or more
    subgraphs of ``graph``.

    :param sdfg: SDFG handed to the helpers and to ``fuse``.
    :param graph: State that contains the maps to fuse.
    :param subgraph: A single subgraph or a list of subgraphs. When falsy
                     (e.g. ``None``), ``graph`` itself is used.
    :param kwargs: Each key/value pair is set as an attribute on the
                   ``SubgraphFusion`` instance before fusing.
    """
    if not subgraph:
        subgraph = graph
    targets = subgraph if isinstance(subgraph, List) else [subgraph]

    map_fusion = SubgraphFusion()
    for attr_name, attr_value in kwargs.items():
        setattr(map_fusion, attr_name, attr_value)

    for sg in targets:
        map_entries = helpers.get_highest_scope_maps(sdfg, graph, sg)
        is_view = isinstance(sg, SubgraphView)

        # Drop the map entries (and their exits, when present) from the
        # subgraph view before the transformation is applied.
        if is_view:
            for map_entry in map_entries:
                sg.nodes().remove(map_entry)
                map_exit = graph.exit_node(map_entry)
                if map_exit in sg.nodes():
                    sg.nodes().remove(map_exit)

        print(f"Subgraph Fusion on map entries {map_entries}")
        map_fusion.fuse(sdfg, graph, map_entries)

        # Register the map entry produced by the fusion in the view.
        if is_view:
            sg.nodes().append(map_fusion._global_map_entry)
def apply(self, sdfg, subgraph, do_not_override=None, **kwargs):
    """
    Fuse the highest-scope maps of ``subgraph``.

    Stores ``subgraph`` on the instance, collects the map entries returned
    by ``helpers.get_highest_scope_maps`` and forwards them to ``self.fuse``.

    :param sdfg: SDFG handed to the helper and to ``fuse``.
    :param subgraph: Subgraph to operate on; its ``graph`` attribute is
                     used as the enclosing graph.
    :param do_not_override: Forwarded positionally to ``fuse``. When
                            omitted (or ``None``) a fresh empty list is
                            used for this call.
    :param kwargs: Forwarded unchanged to ``fuse``.
    """
    # A list literal as default would be shared between all calls; build a
    # new one per call so nothing downstream can leak state across calls.
    if do_not_override is None:
        do_not_override = []

    self.subgraph = subgraph
    graph = subgraph.graph
    map_entries = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
    self.fuse(sdfg, graph, map_entries, do_not_override, **kwargs)
def apply(self, sdfg, map_base_variables=None):
    """
    Expand the highest-scope maps of this transformation's subgraph.

    :param sdfg: SDFG used to obtain the subgraph view and handed to
                 ``expand``.
    :param map_base_variables: Forwarded to ``expand`` as the keyword
                               argument of the same name.
    """
    view = self.subgraph_view(sdfg)
    state = view.graph

    # Collect the maps reported by the helper and expand them.
    top_level_maps = helpers.get_highest_scope_maps(sdfg, state, view)
    self.expand(sdfg,
                state,
                top_level_maps,
                map_base_variables=map_base_variables)
def expand_maps(sdfg: dace.SDFG,
                graph: dace.SDFGState,
                subgraph: Union[SubgraphView, List[SubgraphView]] = None,
                **kwargs):
    """
    Run ``MultiExpansion.expand`` on the highest-scope maps of one or more
    subgraphs of ``graph``.

    :param sdfg: SDFG handed to the helper and to ``expand``.
    :param graph: State that contains the maps to expand.
    :param subgraph: A single subgraph or a list of subgraphs. When falsy
                     (e.g. ``None``), ``graph`` itself is used.
    :param kwargs: Each key/value pair is set as an attribute on the
                   ``MultiExpansion`` instance before expanding.
    """
    if not subgraph:
        subgraph = graph
    targets = subgraph if isinstance(subgraph, List) else [subgraph]

    trafo_expansion = MultiExpansion()
    for attr_name, attr_value in kwargs.items():
        setattr(trafo_expansion, attr_name, attr_value)

    for sg in targets:
        top_level_maps = helpers.get_highest_scope_maps(sdfg, graph, sg)
        trafo_expansion.expand(sdfg, graph, top_level_maps)
def match(sdfg: SDFG, subgraph: SubgraphView) -> bool:
    """
    Decide whether the subgraph qualifies for this transformation.

    The subgraph matches when all of the following hold:

    1. every node of ``subgraph`` is contained in ``subgraph.graph``;
    2. ``helpers.get_highest_scope_maps`` reports more than one map;
    3. ``helpers.common_map_base_ranges`` yields a non-empty result for
       those maps.

    :param sdfg: SDFG handed to ``helpers.get_highest_scope_maps``.
    :param subgraph: Subgraph view to inspect.
    :return: ``True`` if all conditions hold, ``False`` otherwise.
    """
    graph = subgraph.graph

    # All nodes of the view have to live in the same graph.
    graph_nodes = graph.nodes()
    if any(node not in graph_nodes for node in subgraph.nodes()):
        return False

    maps = helpers.get_highest_scope_maps(sdfg, graph, subgraph)

    # Reject on the map count first: with at most one map the answer is
    # already known, so the common-range computation is never invoked on
    # an empty or single-element list.
    if len(maps) <= 1:
        return False

    # The maps need at least one range in common.
    brng = helpers.common_map_base_ranges(maps)
    return len(brng) != 0
def can_be_applied(sdfg: SDFG, subgraph: SubgraphView) -> bool:
    '''
    Fusible if
    1. Maps have the same access sets and ranges in order
    2. Any nodes in between two maps are AccessNodes only, without WCR
       There is at most one AccessNode only on a path between two maps,
       no other nodes are allowed
    3. The exiting memlets' subsets to an intermediate edge must cover
       the respective incoming memlets' subset into the next map

    :param sdfg: SDFG handed to the helpers.
    :param subgraph: Subgraph view whose highest-scope maps are checked.
    :return: True if none of the checks below rejects the subgraph.
    :raises NotImplementedError: if a node between two maps has an
                                 incoming edge whose source is not one of
                                 the map exits.
    '''
    graph = subgraph.graph

    # every node of the view has to be part of the graph
    for node in subgraph.nodes():
        if node not in graph.nodes():
            return False

    # next, get all the maps
    map_entries = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
    map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
    maps = [map_entry.map for map_entry in map_entries]

    # 1. check whether all map ranges and indices are the same
    if len(maps) <= 1:
        return False
    base_map = maps[0]
    for current_map in maps:
        if current_map.get_param_num() != base_map.get_param_num():
            return False
        if not all(p1 == p2 for (p1, p2) in zip(current_map.params,
                                                base_map.params)):
            return False
        if not current_map.range == base_map.range:
            return False

    # 1.1 check whether all map entries have the same schedule
    schedule = map_entries[0].schedule
    if not all(entry.schedule == schedule for entry in map_entries):
        return False

    # 2. check intermediate feasibility
    # see map_fusion.py for similar checks
    # we are being more relaxed here

    # 2.1 do some preparation work first:
    # classify every node directly behind a map exit:
    #   - intermediate node: has an out-edge into one of the map entries
    #   - out node: has no out-edges, or an out-edge to something else
    # (a node can end up in both sets)
    intermediate_nodes = set()
    out_nodes = set()
    for map_exit in map_exits:
        for edge in graph.out_edges(map_exit):
            current_node = edge.dst
            successor_edges = graph.out_edges(current_node)
            if len(successor_edges) == 0:
                out_nodes.add(current_node)
            else:
                for dst_edge in successor_edges:
                    if dst_edge.dst in map_entries:
                        intermediate_nodes.add(current_node)
                    else:
                        out_nodes.add(current_node)

    # 2.2 topological feasibility:
    # For each intermediate and out node: must never reach any map
    # entry if it is not connected to map entry immediately
    visited = set()  # for memoization purposes

    def visit_descendants(graph, node, visited, map_entries):
        # Returns False as soon as a map entry is reachable from `node`.
        # if we have already been at this node
        if node in visited:
            return True
        # reaching one of the map entries indirectly is forbidden;
        # without this check the traversal could never fail
        if node in map_entries:
            return False
        # not necessary to add if there aren't any other in connections
        if len(graph.in_edges(node)) > 1:
            visited.add(node)
        for oedge in graph.out_edges(node):
            if not visit_descendants(graph, oedge.dst, visited,
                                     map_entries):
                return False
        return True

    for node in intermediate_nodes | out_nodes:
        # direct successors that are not map entries must not lead to
        # a map entry further down
        nodes_to_check = set()
        for oedge in graph.out_edges(node):
            if oedge.dst not in map_entries:
                nodes_to_check.add(oedge.dst)

        for forbidden_node in nodes_to_check:
            if not visit_descendants(graph, forbidden_node, visited,
                                     map_entries):
                return False

    # 2.3 memlet feasibility
    # For each intermediate node, look at whether inner adjacent
    # memlets of the exiting map cover inner adjacent memlets
    # of the next entering map.
    # We also check for any WCRs on the fly.
    for node in intermediate_nodes:
        upper_subsets = set()
        lower_subsets = set()

        # First, determine which dimensions of the memlet ranges
        # change with the map, we do not need to care about the other
        # dimensions.
        dims_to_discard = SubgraphFusion.get_invariant_dimensions(
            sdfg, graph, map_entries, map_exits, node)

        # find upper_subsets
        for in_edge in graph.in_edges(node):
            # first check for WCRs
            if in_edge.data.wcr:
                return False
            if in_edge.src in map_exits:
                # second-to-last edge of the memlet path, i.e. the edge
                # right before the one that enters `node`
                edge = graph.memlet_path(in_edge)[-2]
                subset_to_add = dcpy(edge.data.subset
                                     if edge.data.data == node.data
                                     else edge.data.other_subset)
                subset_to_add.pop(dims_to_discard)
                upper_subsets.add(subset_to_add)
            else:
                raise NotImplementedError("Nodes between two maps to be "
                                          "fused with *incoming* edges "
                                          "from outside the maps are not "
                                          "allowed yet.")

        # find lower_subsets
        for out_edge in graph.out_edges(node):
            if out_edge.dst in map_entries:
                # cannot use memlet tree here as there could be
                # not just one map succeeding. Do it manually
                for oedge in graph.out_edges(out_edge.dst):
                    # edges without a connector cannot be matched
                    if oedge.src_conn is None or out_edge.dst_conn is None:
                        continue
                    # compare connector names after dropping their
                    # first 3 resp. 2 characters (presumably the
                    # "OUT"/"IN" prefixes -- TODO confirm)
                    if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                        # use the inner edge itself to pick the subset;
                        # the previous code consulted a stale `edge`
                        # left over from the upper-subset loop
                        subset_to_add = dcpy(
                            oedge.data.subset
                            if oedge.data.data == node.data
                            else oedge.data.other_subset)
                        subset_to_add.pop(dims_to_discard)
                        lower_subsets.add(subset_to_add)

        upper_iter = iter(upper_subsets)
        union_upper = next(upper_iter)

        # TODO: add this check at a later point
        # We assume that upper_subsets for each data array
        # are contiguous
        # or do the full check if possible (intersection needed)

        # FORNOW: just emit a warning if unsure
        for lower_subset in lower_subsets:
            covers = any(
                upper_subset.covers(lower_subset)
                for upper_subset in upper_subsets)
            if not covers:
                warnings.warn(
                    f"WARNING: For node {node}, please ensure that "
                    "incoming memlets cover outgoing ones. "
                    "Ambiguous check (WIP).")

        # now take union of upper subsets
        for subs in upper_iter:
            union_upper = subsets.union(union_upper, subs)
            if not union_upper:
                # something went wrong using union -- we'd rather abort
                return False

        # finally check coverage
        for lower_subset in lower_subsets:
            if not union_upper.covers(lower_subset):
                return False

    return True