def offset_map(state, map_entry):
    """Re-base a map so that each dimension of its range starts at zero.

    For every map parameter whose lower bound (as reported by
    ``map_entry.map.range.min_element()``) is non-zero, all occurrences of the
    parameter inside the map's scope subgraph are rewritten to
    ``<param>+<lower>``. The per-dimension shifts (0 for dimensions that are
    already zero-based) are then passed to ``range.offset(..., negative=True)``,
    presumably subtracting them from the range — confirm against
    ``subsets.Range.offset``.

    :param state: State that contains the map; its scope subgraph is edited
                  in place.
    :param map_entry: Entry node of the map to re-base.
    """
    scope = state.scope_subgraph(map_entry)
    shifts = []
    for param, lower in zip(map_entry.map.params,
                            map_entry.map.range.min_element()):
        if lower == 0:
            # Dimension already starts at zero: nothing to rewrite.
            shifts.append(0)
            continue
        shifts.append(lower)
        # Compensate inside the scope for the range shift applied below.
        replace(scope, str(param), f'{param}+{lower}')
    map_entry.map.range.offset(shifts, negative=True)
def apply(self, sdfg):
    """ Apply the map fusion transformation to the matched pair of maps.

    Besides removing the second map entry node (SME) and the first map exit
    node (FME), the transformation has the following side effects:

    1. Any transient adjacent to both FME and SME with degree 2 is removed.
       The tasklets that produce/use it are connected directly, through a
       scalar or a new transient (if the dataflow is more than a single
       scalar).
    2. If such a transient is adjacent to FME and SME but has other uses, it
       becomes adjacent to the new map exit after fusion. Tasklet->tasklet
       edges are ALSO added as described in (1).
    3. An access node adjacent to FME but not SME becomes adjacent to the new
       map exit after fusion.
    4. An access node adjacent to SME but not FME becomes adjacent to the new
       map entry node after fusion.

    :param sdfg: The SDFG whose state ``self.state_id`` contains the matched
                 maps. That state is modified in place.
    """
    graph = sdfg.nodes()[self.state_id]
    # Resolve the four scope nodes of the two matched maps.
    first_exit = graph.nodes()[self.subgraph[MapFusion._first_map_exit]]
    first_entry = graph.entry_node(first_exit)
    second_entry = graph.nodes()[self.subgraph[
        MapFusion._second_map_entry]]
    second_exit = graph.exit_nodes(second_entry)[0]

    # Every node written by the first map exit is an intermediate access node.
    intermediate_nodes = set()
    for _, _, dst, _, _ in graph.out_edges(first_exit):
        intermediate_nodes.add(dst)
        assert isinstance(dst, nodes.AccessNode)

    # Check if an access node refers to non transient memory, or transient
    # is used at another location (cannot erase)
    do_not_erase = set()
    for node in intermediate_nodes:
        if sdfg.arrays[node.data].transient is False:
            do_not_erase.add(node)
        else:
            # for/else: out-edges are only inspected when every in-edge of
            # this node comes from the first map exit.
            for edge in graph.in_edges(node):
                if edge.src != first_exit:
                    do_not_erase.add(node)
                    break
            else:
                for edge in graph.out_edges(node):
                    if edge.dst != second_entry:
                        do_not_erase.add(node)
                        break

    # Find permutation between first and second scopes
    perm = MapFusion.find_permutation(first_entry.map, second_entry.map)
    params_dict = {}
    for index, param in enumerate(first_entry.map.params):
        params_dict[param] = second_entry.map.params[perm[index]]

    # Replaces (in memlets and tasklet) the second scope map
    # indices with the permuted first map indices.
    # This works in two passes through temporary "__<name>_fused" names to
    # avoid problems when e.g. exchanging two parameters (a single pass would
    # turn (j,i) into (j,j) and then (i,i)).
    second_scope = graph.scope_subgraph(second_entry)
    for firstp, secondp in params_dict.items():
        if firstp != secondp:
            replace(second_scope, secondp, '__' + secondp + '_fused')
    for firstp, secondp in params_dict.items():
        if firstp != secondp:
            replace(second_scope, '__' + secondp + '_fused', firstp)

    # Isolate First exit node ############################
    # Edges/nodes are collected and removed after the loop so that the
    # memlet paths queried inside the loop stay intact.
    edges_to_remove = set()
    nodes_to_remove = set()
    for edge in graph.in_edges(first_exit):
        memlet_path = graph.memlet_path(edge)
        edge_index = next(i for i, e in enumerate(memlet_path) if e == edge)
        # Intermediate access node this edge ultimately writes to.
        access_node = memlet_path[-1].dst
        if access_node not in do_not_erase:
            out_edges = [
                e for e in graph.out_edges(access_node) if e.dst == second_entry
            ]
            # In this transformation, there can only be one edge to the
            # second map
            assert len(out_edges) == 1

            # Get source connector to the second map ("IN_" prefix stripped)
            connector = out_edges[0].dst_conn[3:]

            new_dst = None
            new_dst_conn = None
            # Look at the second map entry out-edges to get the new
            # destination (matching "OUT_" connector, prefix stripped)
            for _e in graph.out_edges(second_entry):
                if _e.src_conn[4:] == connector:
                    new_dst = _e.dst
                    new_dst_conn = _e.dst_conn
                    break
            if new_dst is None:
                # Access node is not used in the second map
                nodes_to_remove.add(access_node)
                continue

            # If the source is an access node, modify the memlet to point
            # to it
            # NOTE(review): in this branch only edge.data is rewritten; no
            # new edge from edge.src to new_dst is added here, and `edge`
            # itself is removed below. Confirm that new_dst is still fed.
            if (isinstance(edge.src, nodes.AccessNode)
                    and edge.data.data != edge.src.data):
                edge.data.data = edge.src.data
                edge.data.subset = ("0" if edge.data.other_subset is None
                                    else edge.data.other_subset)
                edge.data.other_subset = None

            else:
                # Add a transient scalar/array
                # (self.fuse_nodes is defined elsewhere; presumably it wires
                # edge.src to new_dst through a new local container.)
                self.fuse_nodes(sdfg, graph, edge, new_dst, new_dst_conn)

            edges_to_remove.add(edge)

            # Remove transient node between the two maps
            nodes_to_remove.add(access_node)
        else:  # The case where intermediate array node cannot be removed
            # Node will become an output of the second map exit
            out_e = memlet_path[edge_index + 1]
            conn = second_exit.next_connector()
            graph.add_edge(
                second_exit,
                'OUT_' + conn,
                out_e.dst,
                out_e.dst_conn,
                dcpy(out_e.data),
            )
            second_exit.add_out_connector('OUT_' + conn)

            graph.add_edge(edge.src, edge.src_conn, second_exit,
                           'IN_' + conn, dcpy(edge.data))
            second_exit.add_in_connector('IN_' + conn)

            edges_to_remove.add(out_e)

            # If the second map needs this node, link the connector
            # that generated this to the place where it is needed, with a
            # temp transient/scalar for memlet to be generated
            for out_e in graph.out_edges(second_entry):
                second_memlet_path = graph.memlet_path(out_e)
                source_node = second_memlet_path[0].src
                if source_node == access_node:
                    self.fuse_nodes(sdfg, graph, edge, out_e.dst,
                                    out_e.dst_conn)

            edges_to_remove.add(edge)
    ###
    # First scope exit is isolated and can now be safely removed
    for e in edges_to_remove:
        graph.remove_edge(e)
    graph.remove_nodes_from(nodes_to_remove)
    graph.remove_node(first_exit)

    # Isolate second_entry node ###########################
    for edge in graph.in_edges(second_entry):
        memlet_path = graph.memlet_path(edge)
        edge_index = next(i for i, e in enumerate(memlet_path) if e == edge)
        access_node = memlet_path[0].src
        if access_node in intermediate_nodes:
            # Already handled above, can be safely removed
            graph.remove_edge(edge)
            continue

        # This is an external input to the second map which will now go
        # through the first map, using a fresh IN_/OUT_ connector pair.
        conn = first_entry.next_connector()
        graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn,
                       dcpy(edge.data))
        first_entry.add_in_connector('IN_' + conn)
        graph.remove_edge(edge)
        out_e = memlet_path[edge_index + 1]
        graph.add_edge(
            first_entry,
            'OUT_' + conn,
            out_e.dst,
            out_e.dst_conn,
            dcpy(out_e.data),
        )
        first_entry.add_out_connector('OUT_' + conn)

        graph.remove_edge(out_e)
    ###
    # Second node is isolated and can now be safely removed
    graph.remove_node(second_entry)

    # Fix scope exit to point to the right map
    second_exit.map = first_entry.map
def expand(self, sdfg, graph, map_entries, map_base_variables=None):
    """ Expand each map in a set into an outer and an inner map.

    All resulting outer maps share the same ranges and iteration variable
    names; memlets and tasklets are renamed accordingly. Each inner map holds
    the leftover dimensions of its original map. Maps that have no leftover
    dimensions are left untouched.

    :param sdfg: Underlying SDFG.
    :param graph: State in which the maps are expanded (modified in place).
    :param map_entries: List of map entries (``MapEntry``) to expand.
    :param map_base_variables: Optional list of iteration variable names.
        If None/empty, ``expand`` greedily searches for the maximal set of
        equal map ranges and moves those (renamed to the first map's variable
        names) into the outer map. If given, the ranges belonging to these
        variables are moved to the outer map; all maps are asserted to
        already use these names with identical ranges.
    """
    # NOTE: the loop variables below are called ``map_`` on purpose. Naming
    # them ``map`` would make ``map`` a local name in this whole function and
    # break the builtin ``map(str, ...)`` call in the dynamic-input handling.
    maps = [entry.map for entry in map_entries]
    if not maps:
        # Nothing to expand (and maps[0] below would fail).
        return

    if not map_base_variables:
        # find the maximal subset of variables to expand
        # greedy if there exist multiple ranges that are equal in a map
        map_base_ranges = helpers.common_map_base_ranges(maps)
        reassignments = helpers.find_reassignment(maps, map_base_ranges)

        # First, define the outer iteration variable names: take the first
        # map's parameters at the common ranges.
        first_map = maps[0]
        map_base_variables = []
        for rng in map_base_ranges:
            for i, param in enumerate(first_map.params):
                if (first_map.range[i] == rng
                        and param not in map_base_variables):
                    map_base_variables.append(param)
                    break

        params_dict = {}
        if self.debug:
            print("MultiExpansion::Map_base_variables:", map_base_variables)
            print("MultiExpansion::Map_base_ranges:", map_base_ranges)

        for map_ in maps:
            # For each map build a parameter renaming, starting from identity.
            params_dict_map = {param: param for param in map_.params}
            # Every reassignment != -1 is an index into map_base_variables
            # that the i-th parameter must be renamed to.
            for i, reassignment in enumerate(reassignments[map_]):
                if reassignment == -1:
                    # not part of the common base: nothing to do
                    continue
                current_var = map_.params[i]
                current_assignment = params_dict_map[current_var]
                target_assignment = map_base_variables[reassignment]
                if current_assignment == target_assignment:
                    continue
                if target_assignment in params_dict_map.values():
                    # Another parameter is currently renamed to the target
                    # name: swap the two *assignments* so the renaming stays
                    # a permutation.
                    other_var = next(key
                                     for key, value in params_dict_map.items()
                                     if value == target_assignment)
                    params_dict_map[current_var] = target_assignment
                    params_dict_map[other_var] = current_assignment
                else:
                    # just reassign
                    params_dict_map[current_var] = target_assignment
            params_dict[map_] = params_dict_map

        for map_, map_entry in zip(maps, map_entries):
            map_scope = graph.scope_subgraph(map_entry)
            params_dict_map = params_dict[map_]
            # Two passes through temporary names so that swapped parameters
            # (e.g. i <-> j) are not collapsed into one name.
            for firstp, secondp in params_dict_map.items():
                if firstp != secondp:
                    replace(map_scope, firstp, '__' + firstp + '_fused')
            for firstp, secondp in params_dict_map.items():
                if firstp != secondp:
                    replace(map_scope, '__' + firstp + '_fused', secondp)

            # now also rename the map's own iteration variables
            for i, param in enumerate(map_.params):
                map_.params[i] = params_dict_map[param]

        if self.debug:
            print("MultiExpansion::Params replaced")

    else:
        # Compute map_base_ranges from the first map and check that all maps
        # agree on them.
        map_base_ranges = []
        map0 = maps[0]
        for var in map_base_variables:
            index = map0.params.index(var)
            map_base_ranges.append(map0.range[index])

        for map_ in maps:
            for var, rng in zip(map_base_variables, map_base_ranges):
                assert map_.range[map_.params.index(var)] == rng

    # then expand all the maps
    for map_, map_entry in zip(maps, map_entries):
        if map_.get_param_num() == len(map_base_variables):
            # nothing to expand, continue
            continue

        map_exit = graph.exit_node(map_entry)

        # Leftover (non-base) dimensions go into the inner map.
        params_inner = []
        ranges_inner = []
        for param, rng in zip(map_.params, map_.range):
            if param not in map_base_variables:
                params_inner.append(param)
                ranges_inner.append(rng)

        inner_map = nodes.Map(
            label=map_.label + '_inner',
            params=params_inner,
            ndrange=subsets.Range(ranges_inner),
            schedule=(dtypes.ScheduleType.Sequential
                      if self.sequential_innermaps else
                      dtypes.ScheduleType.Default))

        map_.label = map_.label + '_outer'
        # Give each outer map its own params list and Range object so the
        # expanded maps (and the caller's map_base_variables list) do not
        # alias each other; Range matches how ranges_inner is built above.
        map_.params = list(map_base_variables)
        map_.range = subsets.Range(list(map_base_ranges))

        # create new map entries and exits
        map_entry_inner = nodes.MapEntry(inner_map)
        map_exit_inner = nodes.MapExit(inner_map)

        # analogously to Map_Expansion: route entry out-edges through the
        # inner entry
        for edge in graph.out_edges(map_entry):
            graph.remove_edge(edge)
            graph.add_memlet_path(map_entry,
                                  map_entry_inner,
                                  edge.dst,
                                  src_conn=edge.src_conn,
                                  memlet=edge.data,
                                  dst_conn=edge.dst_conn)

        dynamic_edges = dynamic_map_inputs(graph, map_entry)
        for edge in dynamic_edges:
            # Remove old edge and connector
            graph.remove_edge(edge)
            edge.dst._in_connectors.remove(edge.dst_conn)

            # Propagate to each range it belongs to
            path = []
            for mapnode in [map_entry, map_entry_inner]:
                path.append(mapnode)
                if any(edge.dst_conn in map(str, symbolic.symlist(r))
                       for r in mapnode.map.range):
                    graph.add_memlet_path(edge.src,
                                          *path,
                                          memlet=edge.data,
                                          src_conn=edge.src_conn,
                                          dst_conn=edge.dst_conn)

        # route exit in-edges through the inner exit
        for edge in graph.in_edges(map_exit):
            graph.remove_edge(edge)
            graph.add_memlet_path(edge.src,
                                  map_exit_inner,
                                  map_exit,
                                  memlet=edge.data,
                                  src_conn=edge.src_conn,
                                  dst_conn=edge.dst_conn)
def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
    """ Double-buffer the transients read at the top of a one-dimensional map.

    Steps performed (all in place):
    - The map range is shortened by one stride.
    - Every transient whose access node is a direct successor of the map
      entry gets a new leading dimension of size 2 (two buffer slots), and
      memlets in the map scope that touch it are rewritten via
      ``self._modify_memlet``.
    - The map is turned into a for-loop in a nested SDFG via
      ``MapToForLoop``.
    - The initial copies into these transients are moved to the nested
      SDFG's start state (evaluated at the loop start).
    - A copy of the loop body is appended to the nested SDFG's sink state
      (the last iteration).
    - The loop body gets copies that prefetch the next iteration
      (``param + stride``) into the other buffer slot.

    :param graph: State containing the matched map (``self.map_entry``).
    :param sdfg: SDFG that owns ``graph``.
    :return: The nested SDFG node created by ``MapToForLoop``.
    """
    map_entry = self.map_entry

    map_param = map_entry.map.params[0]  # Assuming one dimensional

    ##############################
    # Change condition of loop to one fewer iteration (so that the
    # final one reads from the last buffer)
    map_rstart, map_rend, map_rstride = map_entry.map.range[0]
    map_rend = symbolic.pystr_to_symbolic('(%s) - (%s)' %
                                          (map_rend, map_rstride))
    map_entry.map.range = subsets.Range([(map_rstart, map_rend,
                                          map_rstride)])

    ##############################
    # Gather transients to modify
    transients_to_modify = set(edge.dst.data
                               for edge in graph.out_edges(map_entry)
                               if isinstance(edge.dst, nodes.AccessNode))

    # Add dimension to transients and modify memlets
    for transient in transients_to_modify:
        desc: data.Array = sdfg.arrays[transient]
        # Using non-python syntax to ensure properties change
        # (new lists are assigned instead of mutating in place, presumably
        # so that the property setters run).
        # Prepend a leading dimension of size 2: one slot per buffer.
        desc.strides = [desc.total_size] + list(desc.strides)
        desc.shape = [2] + list(desc.shape)
        desc.offset = [0] + list(desc.offset)
        desc.total_size = desc.total_size * 2

    ##############################
    # Modify memlets to use map parameter as buffer index
    # NOTE(review): modified_subsets is only appended to in this method and
    # never read here.
    modified_subsets = []  # Store modified memlets for final state
    for edge in graph.scope_subgraph(map_entry).edges():
        if edge.data.data in transients_to_modify:
            edge.data.subset = self._modify_memlet(sdfg, edge.data.subset,
                                                   edge.data.data)
            modified_subsets.append(edge.data.subset)
        else:  # Could be other_subset
            path = graph.memlet_path(edge)
            src_node = path[0].src
            dst_node = path[-1].dst

            # other_subset could be None. In that case, recreate from array
            dataname = None
            if (isinstance(src_node, nodes.AccessNode)
                    and src_node.data in transients_to_modify):
                dataname = src_node.data
            elif (isinstance(dst_node, nodes.AccessNode)
                  and dst_node.data in transients_to_modify):
                dataname = dst_node.data
            if dataname is not None:
                subset = (edge.data.other_subset or
                          subsets.Range.from_array(sdfg.arrays[dataname]))
                edge.data.other_subset = self._modify_memlet(
                    sdfg, subset, dataname)
                modified_subsets.append(edge.data.other_subset)

    ##############################
    # Turn map into for loop
    map_to_for = MapToForLoop()
    map_to_for.setup_match(
        sdfg, self.sdfg_id, self.state_id,
        {MapToForLoop.map_entry: graph.node_id(self.map_entry)},
        self.expr_index)
    nsdfg_node, nstate = map_to_for.apply(graph, sdfg)

    ##############################
    # Gather node copies and remove memlets
    # (copies from source nodes into the double-buffered transients are
    # detached from the loop body; they are re-added below to the initial
    # state and, shifted by one stride, to the loop body).
    edges_to_replace = []
    for node in nstate.source_nodes():
        for edge in nstate.out_edges(node):
            if (isinstance(edge.dst, nodes.AccessNode)
                    and edge.dst.data in transients_to_modify):
                edges_to_replace.append(edge)
                nstate.remove_edge(edge)
        if nstate.out_degree(node) == 0:
            nstate.remove_node(node)

    ##############################
    # Add initial reads to initial nested state
    initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state
    initial_state.set_label('%s_init' % map_entry.map.label)
    for edge in edges_to_replace:
        initial_state.add_node(edge.src)
        rnode = edge.src
        wnode = initial_state.add_write(edge.dst.data)
        initial_state.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                               copy.deepcopy(edge.data))

    # All instances of the map parameter in this state become the loop start
    sd.replace(initial_state, map_param, map_rstart)
    # Initial writes go to the appropriate buffer
    # NOTE(review): with true division, (start / stride) % 2 is not an
    # integer when start is not a multiple of stride (e.g. start=1,
    # stride=2); confirm such maps are rejected before apply, or consider
    # ((start - start) / stride) style normalization.
    init_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                           (map_rstart, map_rstride))
    sd.replace(initial_state, '__dace_db_param', init_expr)

    ##############################
    # Modify main state's memlets

    # Divide by loop stride
    # (this replaces __dace_db_param everywhere currently in nstate; the
    # prefetch edges added further below still carry the symbol and are
    # handled by the second replace at the end.)
    new_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                          (map_param, map_rstride))
    sd.replace(nstate, '__dace_db_param', new_expr)

    ##############################
    # Add the main state's contents to the last state, modifying
    # memlets appropriately.
    final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0]
    final_state.set_label('%s_final_computation' % map_entry.map.label)
    dup_nstate = copy.deepcopy(nstate)
    final_state.add_nodes_from(dup_nstate.nodes())
    for e in dup_nstate.edges():
        final_state.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # If there is a WCR output with transient, only output in last state
    nstate: sd.SDFGState
    for node in nstate.sink_nodes():
        for e in list(nstate.in_edges(node)):
            if e.data.wcr is not None:
                path = nstate.memlet_path(e)
                if isinstance(path[0].src, nodes.AccessNode):
                    nstate.remove_memlet_path(e)

    ##############################
    # Add reads into next buffers to main state
    # (each copy is re-created with the map parameter advanced by one stride)
    for edge in edges_to_replace:
        rnode = copy.deepcopy(edge.src)
        nstate.add_node(rnode)
        wnode = nstate.add_write(edge.dst.data)
        new_memlet = copy.deepcopy(edge.data)
        if new_memlet.data in transients_to_modify:
            new_memlet.other_subset = self._replace_in_subset(
                new_memlet.other_subset, map_param,
                '(%s + %s)' % (map_param, map_rstride))
        else:
            new_memlet.subset = self._replace_in_subset(
                new_memlet.subset, map_param,
                '(%s + %s)' % (map_param, map_rstride))

        nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                        new_memlet)

    nstate.set_label('%s_double_buffered' % map_entry.map.label)
    # Divide by loop stride
    # (prefetches write into the *other* buffer slot: index + 1 mod 2)
    new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' %
                                          (map_param, map_rstride))
    sd.replace(nstate, '__dace_db_param', new_expr)

    # Remove symbol once done
    del nsdfg_node.sdfg.symbols['__dace_db_param']
    del nsdfg_node.symbol_mapping['__dace_db_param']

    return nsdfg_node
def apply(self, sdfg):
    """ Apply the (legacy) map fusion transformation to the matched maps.

    Besides removing the second map entry node (SME) and the first map exit
    node (FME), the transformation has the following side effects:

    1. Any transient adjacent to both FME and SME with degree 2 is removed.
       The tasklets that produce/use it are connected directly, through a
       scalar or a new transient (if the dataflow is more than a single
       scalar).
    2. If such a transient is adjacent to FME and SME but has other uses, it
       becomes adjacent to the new map exit after fusion. Tasklet->tasklet
       edges are ALSO added as described in (1).
    3. An access node adjacent to FME but not SME becomes adjacent to the new
       map exit after fusion.
    4. An access node adjacent to SME but not FME becomes adjacent to the new
       map entry node after fusion.

    :param sdfg: The SDFG whose state ``self.state_id`` contains the matched
                 maps. That state is modified in place.
    :return: ``False`` (before any modification) if an intermediate transient
             is accessed by more than one access node in the state,
             otherwise ``None``.
    """
    graph = sdfg.nodes()[self.state_id]
    first_exit = graph.nodes()[self.subgraph[MapFusion._first_map_exit]]
    first_entry = graph.entry_node(first_exit)
    second_entry = graph.nodes()[self.subgraph[MapFusion._second_map_entry]]
    second_exit = graph.exit_nodes(second_entry)[0]

    def _connect_through_local(edge, access_node, dst, dst_conn):
        """Route the data of ``edge`` (an in-edge of first_exit) to
        ``dst``/``dst_conn`` through a new local scalar (single element) or
        transient array sized like the memlet subset.

        Mutates ``edge.data`` so that it refers to the new local container.
        """
        local_name = "__s%d_n%d%s_n%d%s" % (
            self.state_id,
            graph.node_id(edge._src),
            edge.src_conn,
            graph.node_id(edge._dst),
            edge.dst_conn,
        )
        if edge.data.subset.num_elements() == 1:
            # We will add a scalar
            sdfg.add_scalar(
                local_name,
                dtype=access_node.desc(graph).dtype,
                storage=dtypes.StorageType.Register,
                toplevel=False,
                transient=True,
            )
            edge.data.data = local_name
            edge.data.subset = "0"
            graph.add_edge(edge._src, edge.src_conn, dst, dst_conn,
                           dcpy(edge.data))
        else:
            # We will add a transient of size = memlet subset size.
            # graph.add_transient returns an access node that can be used
            # as an edge endpoint (an SDFG-level add_transient would return
            # a data descriptor, which is not a graph node).
            local_node = graph.add_transient(
                local_name,
                edge.data.subset.size(),
                dtype=access_node.desc(graph).dtype,
                toplevel=False,
            )
            edge.data.data = local_name
            edge.data.subset = ",".join(
                ["0:" + str(_s) for _s in edge.data.subset.size()])
            graph.add_edge(edge._src, edge.src_conn, local_node, None,
                           dcpy(edge.data))
            graph.add_edge(local_node, None, dst, dst_conn, dcpy(edge.data))

    # Every output of the first map exit is an intermediate access node.
    intermediate_nodes = set()
    for _, _, dst, _, _ in graph.out_edges(first_exit):
        intermediate_nodes.add(dst)
        assert isinstance(dst, nodes.AccessNode)

    # Check if an access node refers to non transient memory, or transient
    # is used at another location (cannot erase)
    do_not_erase = set()
    for node in intermediate_nodes:
        if sdfg.arrays[node.data].transient is False:
            do_not_erase.add(node)
        else:
            # If the array is accessed anywhere else in this state, give up
            # (nothing has been modified yet at this point).
            num_occurrences = len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == node.data
            ])
            if num_occurrences > 1:
                return False

            # for/else: out-edges are only inspected when every in-edge of
            # this node comes from the first map exit.
            for edge in graph.in_edges(node):
                if edge.src != first_exit:
                    do_not_erase.add(node)
                    break
            else:
                for edge in graph.out_edges(node):
                    if edge.dst != second_entry:
                        do_not_erase.add(node)
                        break

    # Find permutation between first and second scopes
    if first_entry.map.params != second_entry.map.params:
        perm = MapFusion.find_permutation(first_entry.map, second_entry.map)
        params_dict = {}
        for _index, _param in enumerate(first_entry.map.params):
            params_dict[_param] = second_entry.map.params[perm[_index]]

        # Replace (in memlets and tasklets) the second scope's map indices
        # with the permuted first map indices. Two passes through temporary
        # "__<name>_fused" names so that exchanging parameters does not
        # collapse them (a single pass would turn (j,i) into (j,j) then
        # (i,i)).
        second_scope = graph.scope_subgraph(second_entry)
        for _firstp, _secondp in params_dict.items():
            if _firstp != _secondp:
                replace(second_scope, _secondp, '__' + _secondp + '_fused')
        for _firstp, _secondp in params_dict.items():
            if _firstp != _secondp:
                replace(second_scope, '__' + _secondp + '_fused', _firstp)

    ########Isolate First MapExit node###########
    for _edge in graph.in_edges(first_exit):
        _access_node = graph.find_node(_edge.data.data)
        # all outputs of first_exit are in intermediate_nodes set, so all
        # inputs to first_exit should also be!
        if _access_node not in do_not_erase:
            _new_dst = None
            _new_dst_conn = None
            # look at the second map entry out-edges to get the new
            # destination
            for _e in graph.out_edges(second_entry):
                if _e.data.data == _access_node.data:
                    _new_dst = _e.dst
                    _new_dst_conn = _e.dst_conn
                    break
            if _new_dst is None:
                # Access node is not even used in the second map
                graph.remove_node(_access_node)
                continue
            if (_edge.data.data == _access_node.data
                    and isinstance(_edge._src, nodes.AccessNode)):
                # The producer is itself an access node: read from it
                # directly.
                _edge.data.data = _edge._src.data
                _edge.data.subset = "0"
                graph.add_edge(
                    _edge._src,
                    _edge.src_conn,
                    _new_dst,
                    _new_dst_conn,
                    dcpy(_edge.data),
                )
            else:
                _connect_through_local(_edge, _access_node, _new_dst,
                                       _new_dst_conn)
            graph.remove_edge(_edge)
            ####Isolate this node#####
            for _in_e in graph.in_edges(_access_node):
                graph.remove_edge(_in_e)
            for _out_e in graph.out_edges(_access_node):
                graph.remove_edge(_out_e)
            graph.remove_node(_access_node)
        else:
            # _access_node will become an output of the second map exit
            for _out_e in graph.out_edges(first_exit):
                if _out_e.data.data == _access_node.data:
                    graph.add_edge(
                        second_exit,
                        None,
                        _out_e._dst,
                        _out_e.dst_conn,
                        dcpy(_out_e.data),
                    )

                    graph.remove_edge(_out_e)
                    break
            else:
                raise AssertionError(
                    "No out-edge was found that leads to {}".format(
                        _access_node))
            graph.add_edge(_edge._src, _edge.src_conn, second_exit, None,
                           dcpy(_edge.data))

            ### If the second map needs this node then link the connector
            # that generated this to the place where it is needed, with a
            # temp transient/scalar for memlet to be generated
            for _out_e in graph.out_edges(second_entry):
                if _out_e.data.data == _access_node.data:
                    _connect_through_local(_edge, _access_node, _out_e._dst,
                                           _out_e.dst_conn)
                    break
            graph.remove_edge(_edge)
    graph.remove_node(first_exit)  # Take a leap of faith

    #############Isolate second_entry node################
    for _edge in graph.in_edges(second_entry):
        _access_node = graph.find_node(_edge.data.data)
        if _access_node in intermediate_nodes:
            # Already handled above, just remove this
            graph.remove_edge(_edge)
            continue
        # This is an external input to the second map which will now go
        # through the first map.
        graph.add_edge(_edge._src, _edge.src_conn, first_entry, None,
                       dcpy(_edge.data))
        graph.remove_edge(_edge)
        for _out_e in graph.out_edges(second_entry):
            if _out_e.data.data == _access_node.data:
                graph.add_edge(
                    first_entry,
                    None,
                    _out_e._dst,
                    _out_e.dst_conn,
                    dcpy(_out_e.data),
                )
                graph.remove_edge(_out_e)
                break
        else:
            raise AssertionError(
                "No out-edge was found that leads to {}".format(
                    _access_node))

    graph.remove_node(second_entry)

    # Fix scope exit
    second_exit.map = first_entry.map
    # Connectors were added as None above; let the state name them.
    graph.fill_scope_connectors()