def apply(self, graph: SDFGState, sdfg: SDFG) -> Tuple[nodes.MapEntry, nodes.MapExit]:
    """ Collapses two maps into one.

        :param graph: The state containing both map scopes.
        :param sdfg: The SDFG to apply the transformation to.
        :return: A 2-tuple of the new map entry and exit nodes.
    """
    # Gather the entry and matching exit node of each scope, then
    # delegate the actual merge to the SDFG utility function.
    outer_entry, inner_entry = self.outer_map_entry, self.inner_map_entry
    return sdutil.merge_maps(graph, outer_entry, graph.exit_node(outer_entry),
                             inner_entry, graph.exit_node(inner_entry))
def can_be_applied(cls, state: SDFGState, candidate, expr_index, sdfg: SDFG, strict=False) -> bool:
    """ Checks whether a map can be vectorized for SVE.

        The map scope must contain only Tasklets and AccessNodes, all
        connector types must map to an SVE type with a single common bit
        width, no stride or WCR may depend on the innermost loop
        parameter, only supported copies may occur, and the vector
        inference algorithm must succeed.

        :param state: The state containing the candidate map.
        :param candidate: Mapping from pattern nodes to state node ids.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG containing the state.
        :param strict: Unused; kept for API compatibility.
        :return: True if the map can be vectorized with SVE.
    """
    import warnings

    map_entry = state.node(candidate[cls.map_entry])
    current_map = map_entry.map
    subgraph = state.scope_subgraph(map_entry)
    subgraph_contents = state.scope_subgraph(map_entry,
                                             include_entry=False,
                                             include_exit=False)

    # Prevent infinite repeats
    if current_map.schedule == dace.dtypes.ScheduleType.SVE_Map:
        return False

    # Infer all connector types for later checks (without modifying the graph)
    inferred = infer_types.infer_connector_types(sdfg, state, subgraph)

    ########################
    # Ensure only Tasklets and AccessNodes are within the map
    for node, _ in subgraph_contents.all_nodes_recursive():
        if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)):
            return False

    ########################
    # Check for unsupported datatypes on the connectors (including on the Map itself)
    bit_widths = set()
    for node, _ in subgraph.all_nodes_recursive():
        for conn in node.in_connectors:
            t = inferred[(node, conn, True)]
            bit_widths.add(util.get_base_type(t).bytes)
            if t.type not in sve.util.TYPE_TO_SVE:
                return False
        for conn in node.out_connectors:
            t = inferred[(node, conn, False)]
            bit_widths.add(util.get_base_type(t).bytes)
            if t.type not in sve.util.TYPE_TO_SVE:
                return False

    # Multiple different bit widths occurring (messes up the predicates)
    if len(bit_widths) > 1:
        return False

    ########################
    # Check for unsupported memlets
    param_name = current_map.params[-1]
    # Hoisted loop invariant: symbol of the innermost map parameter.
    param_sym = symbolic.symbol(param_name)
    for e, _ in subgraph.all_edges_recursive():
        # Check for unsupported strides
        # The only unsupported strides are the ones containing the innermost
        # loop param because they are not constant during a vector step
        if param_sym in e.data.get_stride(sdfg, map_entry.map).free_symbols:
            return False

        # Check for unsupported WCR
        if e.data.wcr is not None:
            # Unsupported reduction type
            reduction_type = dace.frontend.operations.detect_reduction_type(
                e.data.wcr)
            if reduction_type not in sve.util.REDUCTION_TYPE_TO_SVE:
                return False

            # Param in memlet during WCR is not supported
            if param_name in e.data.subset.free_symbols and e.data.wcr_nonatomic:
                return False

            # vreduce is not supported
            # NOTE: the memlet path yields edges; take the destination *node*
            # (previously the edge itself was tested, making these checks dead).
            dst_node = state.memlet_path(e)[-1].dst
            if isinstance(dst_node, nodes.Tasklet):
                if isinstance(dst_node.in_connectors[e.dst_conn], dtypes.vector):
                    return False
            elif isinstance(dst_node, nodes.AccessNode):
                desc = dst_node.desc(sdfg)
                if isinstance(desc, data.Scalar) and isinstance(
                        desc.dtype, dtypes.vector):
                    return False

    ########################
    # Check for invalid copies in the subgraph
    for node, _ in subgraph.all_nodes_recursive():
        if not isinstance(node, nodes.Tasklet):
            continue

        for e in state.in_edges(node):
            # Check for valid copies from other tasklets and/or streams
            if e.data.data is not None:
                src_node = state.memlet_path(e)[0].src
                if not isinstance(src_node, (nodes.Tasklet, nodes.AccessNode)):
                    # Make sure we only have Code->Code copies and from arrays
                    return False

                if isinstance(src_node, nodes.AccessNode):
                    src_desc = src_node.desc(sdfg)
                    if isinstance(src_desc, dace.data.Stream):
                        # Stream pops are not implemented
                        return False

    # Run the vector inference algorithm to check if vectorization is feasible
    try:
        vector_inference.infer_vectors(
            sdfg,
            state,
            map_entry,
            util.SVE_LEN,
            flags=vector_inference.VectorInferenceFlags.Allow_Stride,
            apply=False)
    except vector_inference.VectorInferenceException as ex:
        # Emit a real warning instead of printing 'UserWarning: ...' to stdout.
        warnings.warn(f'Vector inference failed! {ex}')
        return False

    return True
def apply(self, graph: SDFGState, sdfg: SDFG):
    """ This method applies the mapfusion transformation.

        Other than the removal of the second map entry node (SME), and the
        first map exit (FME) node, it has the following side effects:

        1.  Any transient adjacent to both FME and SME with degree = 2 will
            be removed. The tasklets that use/produce it shall be connected
            directly with a scalar/new transient (if the dataflow is more
            than a single scalar)

        2.  If this transient is adjacent to FME and SME and has other uses,
            it will be adjacent to the new map exit post fusion. Tasklet->
            Tasklet edges will ALSO be added as mentioned above.

        3.  If an access node is adjacent to FME but not SME, it will be
            adjacent to new map exit post fusion.

        4.  If an access node is adjacent to SME but not FME, it will be
            adjacent to the new map entry node post fusion.

        :param graph: The state in which both maps reside.
        :param sdfg: The SDFG to apply the transformation to.
    """
    first_exit = self.first_map_exit
    first_entry = graph.entry_node(first_exit)
    second_entry = self.second_map_entry
    second_exit = graph.exit_node(second_entry)

    # Collect the access nodes sitting between the two maps (all direct
    # successors of the first map exit).
    intermediate_nodes = set()
    for _, _, dst, _, _ in graph.out_edges(first_exit):
        intermediate_nodes.add(dst)
        assert isinstance(dst, nodes.AccessNode)

    # Check if an access node refers to non transient memory, or transient
    # is used at another location (cannot erase)
    do_not_erase = set()
    for node in intermediate_nodes:
        if sdfg.arrays[node.data].transient is False:
            do_not_erase.add(node)
        else:
            # A transient may only be erased if ALL of its in-edges come from
            # the first map exit and ALL of its out-edges go to the second
            # map entry (the for/else runs only if no in-edge disqualified it).
            for edge in graph.in_edges(node):
                if edge.src != first_exit:
                    do_not_erase.add(node)
                    break
            else:
                for edge in graph.out_edges(node):
                    if edge.dst != second_entry:
                        do_not_erase.add(node)
                        break

    # Find permutation between first and second scopes
    perm = self.find_permutation(first_entry.map, second_entry.map)
    params_dict = {}
    for index, param in enumerate(first_entry.map.params):
        params_dict[param] = second_entry.map.params[perm[index]]

    # Replaces (in memlets and tasklet) the second scope map
    # indices with the permuted first map indices.
    # This works in two passes to avoid problems when e.g., exchanging two
    # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then
    # i,i).
    second_scope = graph.scope_subgraph(second_entry)
    for firstp, secondp in params_dict.items():
        if firstp != secondp:
            replace(second_scope, secondp, '__' + secondp + '_fused')
    for firstp, secondp in params_dict.items():
        if firstp != secondp:
            replace(second_scope, '__' + secondp + '_fused', firstp)

    # Isolate First exit node
    ############################
    edges_to_remove = set()
    nodes_to_remove = set()
    for edge in graph.in_edges(first_exit):
        tree = graph.memlet_tree(edge)
        # Root of the memlet tree is the outermost edge; its destination is
        # the intermediate access node fed by this producer edge.
        access_node = tree.root().edge.dst
        if access_node not in do_not_erase:
            out_edges = [
                e for e in graph.out_edges(access_node)
                if e.dst == second_entry
            ]
            # In this transformation, there can only be one edge to the
            # second map
            assert len(out_edges) == 1

            # Get source connector to the second map
            # (strip the 'IN_' prefix to obtain the connector id)
            connector = out_edges[0].dst_conn[3:]

            new_dsts = []
            # Look at the second map entry out-edges to get the new
            # destinations (strip the 'OUT_' prefix for comparison)
            for e in graph.out_edges(second_entry):
                if e.src_conn[4:] == connector:
                    new_dsts.append(e)
            if not new_dsts:  # Access node is not used in the second map
                nodes_to_remove.add(access_node)
                continue

            # Add a transient scalar/array
            self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst,
                            new_dsts[0].dst_conn, new_dsts[1:])

            edges_to_remove.add(edge)

            # Remove transient node between the two maps
            nodes_to_remove.add(access_node)
        else:  # The case where intermediate array node cannot be removed
            # Node will become an output of the second map exit
            out_e = tree.parent.edge
            conn = second_exit.next_connector()
            graph.add_edge(
                second_exit,
                'OUT_' + conn,
                out_e.dst,
                out_e.dst_conn,
                dcpy(out_e.data),
            )
            second_exit.add_out_connector('OUT_' + conn)

            graph.add_edge(edge.src, edge.src_conn, second_exit, 'IN_' + conn,
                           dcpy(edge.data))
            second_exit.add_in_connector('IN_' + conn)

            edges_to_remove.add(out_e)
            edges_to_remove.add(edge)

            # If the second map needs this node, link the connector
            # that generated this to the place where it is needed, with a
            # temp transient/scalar for memlet to be generated
            for out_e in graph.out_edges(second_entry):
                second_memlet_path = graph.memlet_path(out_e)
                source_node = second_memlet_path[0].src
                if source_node == access_node:
                    self.fuse_nodes(sdfg, graph, edge, out_e.dst,
                                    out_e.dst_conn)

    ###
    # First scope exit is isolated and can now be safely removed
    for e in edges_to_remove:
        graph.remove_edge(e)
    graph.remove_nodes_from(nodes_to_remove)
    graph.remove_node(first_exit)

    # Isolate second_entry node
    ###########################
    for edge in graph.in_edges(second_entry):
        tree = graph.memlet_tree(edge)
        access_node = tree.root().edge.src
        if access_node in intermediate_nodes:
            # Already handled above, can be safely removed
            graph.remove_edge(edge)
            continue
        # This is an external input to the second map which will now go
        # through the first map.
        conn = first_entry.next_connector()
        graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn,
                       dcpy(edge.data))
        first_entry.add_in_connector('IN_' + conn)
        graph.remove_edge(edge)
        # Reroute every inner consumer edge through the first map entry.
        for out_enode in tree.children:
            out_e = out_enode.edge
            graph.add_edge(
                first_entry,
                'OUT_' + conn,
                out_e.dst,
                out_e.dst_conn,
                dcpy(out_e.data),
            )
            graph.remove_edge(out_e)
            first_entry.add_out_connector('OUT_' + conn)

    ###
    # Second node is isolated and can now be safely removed
    graph.remove_node(second_entry)

    # Fix scope exit to point to the right map
    second_exit.map = first_entry.map
def apply(self, graph: SDFGState, sdfg: SDFG):
    """ Applies the MPI transformation.

        Strip-mines the matched map by the communicator size, schedules the
        resulting outer map as an MPI map, and inserts local-storage
        transients for every input and output of the distributed map.

        :param graph: The state containing the map.
        :param sdfg: The SDFG to apply the transformation to.
    """
    map_entry = self.map_entry

    # Avoiding import loops
    from dace.transformation.dataflow.strip_mining import StripMining
    from dace.transformation.dataflow.local_storage import InLocalStorage, OutLocalStorage, LocalStorage

    rangeexpr = str(map_entry.map.range.num_elements())

    # Computed once and reused for every nested transformation below.
    sdfg_id = sdfg.sdfg_id

    stripmine_subgraph = {
        StripMining.map_entry: self.subgraph[MPITransformMap.map_entry]
    }
    stripmine = StripMining(sdfg, sdfg_id, self.state_id, stripmine_subgraph,
                            self.expr_index)
    stripmine.dim_idx = -1
    stripmine.new_dim_prefix = "mpi"
    # Split the iteration range evenly across the MPI communicator size.
    stripmine.tile_size = "(" + rangeexpr + "/__dace_comm_size)"
    stripmine.divides_evenly = True
    stripmine.apply(graph, sdfg)

    # Find all in-edges that lead to the map entry; strip-mining just
    # wrapped map_entry in a new outer map, so these edges come from it.
    edges = [
        e for e in graph.in_edges(map_entry)
        if isinstance(e.src, nodes.EntryNode)
    ]
    outer_map = edges[0].src

    # Add MPI schedule attribute to outer map
    outer_map.map._schedule = dtypes.ScheduleType.MPI

    # Now create a transient for each array
    for e in edges:
        in_local_storage_subgraph = {
            LocalStorage.node_a: graph.node_id(outer_map),
            LocalStorage.node_b: self.subgraph[MPITransformMap.map_entry]
        }
        in_local_storage = InLocalStorage(sdfg, sdfg_id, self.state_id,
                                          in_local_storage_subgraph,
                                          self.expr_index)
        in_local_storage.array = e.data.data
        in_local_storage.apply(graph, sdfg)

    # Transform OutLocalStorage for each output of the MPI map
    in_map_exit = graph.exit_node(map_entry)
    out_map_exit = graph.exit_node(outer_map)

    for e in graph.out_edges(out_map_exit):
        name = e.data.data
        outlocalstorage_subgraph = {
            LocalStorage.node_a: graph.node_id(in_map_exit),
            LocalStorage.node_b: graph.node_id(out_map_exit)
        }
        outlocalstorage = OutLocalStorage(sdfg, sdfg_id, self.state_id,
                                          outlocalstorage_subgraph,
                                          self.expr_index)
        outlocalstorage.array = name
        outlocalstorage.apply(graph, sdfg)