def apply(self, sdfg: SDFG):
    """Vectorize the matched map scope for SVE.

    If the map has more than one dimension, all dimensions except the last
    are split off with ``extract_map_dims`` and only the remaining map is
    processed. That map gets the ``SVE_Map`` schedule. Connector types are
    inferred over its scope subgraph and applied. Vector types are then
    inferred and applied inside it with length ``util.SVE_LEN``, with strided
    accesses allowed.
    """
    state = sdfg.node(self.state_id)
    map_entry = self.map_entry(sdfg)
    current_map = map_entry.map

    # Expand the innermost map if multidimensional: dimensions 0..n-2 are
    # extracted, and the second result (the remaining map, presumably
    # holding only the last dimension) is the one that gets vectorized.
    if len(current_map.params) > 1:
        _, map_entry = dace.transformation.helpers.extract_map_dims(
            sdfg, map_entry, list(range(len(current_map.params) - 1)))
        current_map = map_entry.map

    # Scope subgraph including the map entry and exit nodes
    subgraph = state.scope_subgraph(map_entry)

    # Set the schedule
    current_map.schedule = dace.dtypes.ScheduleType.SVE_Map

    # Infer all connector types and apply them
    inferred = infer_types.infer_connector_types(sdfg, state, subgraph)
    infer_types.apply_connector_types(inferred)

    # Infer vector connectors and AccessNodes and apply them
    vector_inference.infer_vectors(
        sdfg,
        state,
        map_entry,
        util.SVE_LEN,
        flags=vector_inference.VectorInferenceFlags.Allow_Stride,
        apply=True)
def apply(self, sdfg: SDFG):
    """Hoist the nested SDFG's first interstate edge into the outer SDFG.

    A new state is inserted before the state containing the nested SDFG
    node. The interstate edge leaving the nested start state is processed as
    follows:

    * Its assignments and condition are rewritten from inner names to outer
      names and placed on the new outer edge (``new_state -> state``).
    * Each assigned symbol is added to the outer SDFG. The name gets the
      nested SDFG's label as a prefix if it is already used there as a
      symbol or array.
    * The symbol is passed back in through the nested SDFG's symbol mapping.

    The nested start state is then removed and its successor becomes the
    new nested start state.

    NOTE(review): the nested start state's contents are not moved, so this
    presumably relies on the match condition guaranteeing that state is
    empty -- confirm against ``can_be_applied``.
    """
    nsdfg: nodes.NestedSDFG = self.nsdfg(sdfg)
    state = sdfg.node(self.state_id)
    new_state = sdfg.add_state_before(state)
    isedge = sdfg.edges_between(new_state, state)[0]

    # Inner name -> outer name: symbols through the symbol mapping, input
    # connectors through the data container of the edge feeding them
    mapping: Dict[str, str] = {}
    mapping.update({k: str(v) for k, v in nsdfg.symbol_mapping.items()})
    mapping.update({
        k: next(iter(state.in_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.in_connectors
    })

    # Hoist the edge that leaves the nested start state. Its source state is
    # removed below, so it must be the start state; the first entry of
    # ``edges()`` is not guaranteed to be that edge.
    source_state = nsdfg.sdfg.start_state
    nisedge = nsdfg.sdfg.out_edges(source_state)[0]

    # Safe (two-phase) replacement of edge contents, so that mappings such
    # as {M: N, N: M} do not clash: first rename every key to a unique
    # temporary name, then rename each temporary to its target
    for k in mapping:
        nisedge.data.replace(k, '__dacesym_' + k, replace_keys=False)
    for k, v in mapping.items():
        nisedge.data.replace('__dacesym_' + k, v, replace_keys=False)

    for akey, aval in nisedge.data.assignments.items():
        # Map assignment to outer edge, avoiding a name that is already a
        # symbol or array in the outer SDFG.
        # NOTE(review): the prefixed name itself is not checked for clashes.
        if akey not in sdfg.symbols and akey not in sdfg.arrays:
            newname = akey
        else:
            newname = nsdfg.label + '_' + akey
        isedge.data.assignments[newname] = aval

        # Add symbol to outer SDFG
        sdfg.add_symbol(newname, nsdfg.sdfg.symbols[akey])

        # Add symbol mapping to nested SDFG
        nsdfg.symbol_mapping[akey] = newname

    isedge.data.condition = nisedge.data.condition

    # Clean nested SDFG: the removed state was the start state, so make its
    # successor (the hoisted edge's destination) the new start state
    nsdfg.sdfg.remove_node(source_state)
    nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)
def apply(self, sdfg: SDFG):
    """Hoist the start state of a nested SDFG into the outer SDFG.

    A new state is inserted before the state containing the nested SDFG
    node, and the nested start state's nodes and edges are moved into it.
    Access-node data names and memlets are rewritten from inner names to
    outer names (symbols via the symbol mapping, connectors via the data on
    the outer edges attached to them).

    The assignments and condition of the nested start state's first
    outgoing interstate edge are rewritten likewise and placed on the new
    outer edge (``new_state -> state``). Assigned symbols are added to the
    outer SDFG and mapped back into the nested SDFG. Finally the nested
    start state is removed and its successor becomes the new nested start
    state.
    """
    nsdfg: nodes.NestedSDFG = self.nsdfg(sdfg)
    state = sdfg.node(self.state_id)
    new_state = sdfg.add_state_before(state)
    isedge = sdfg.edges_between(new_state, state)[0]

    # Find relevant symbol and data descriptor mapping (inner -> outer)
    mapping: Dict[str, str] = {}
    mapping.update({k: str(v) for k, v in nsdfg.symbol_mapping.items()})
    mapping.update({
        k: next(iter(state.in_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.in_connectors
    })
    mapping.update({
        k: next(iter(state.out_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.out_connectors
    })

    # Get internal state and interstate edge
    source_state = nsdfg.sdfg.start_state
    nisedge = nsdfg.sdfg.out_edges(source_state)[0]

    # Add state contents (nodes).
    # NOTE(review): NestedSDFG nodes moved here keep their old parent
    # pointers -- confirm whether they need updating.
    new_state.add_nodes_from(source_state.nodes())

    # Replace data descriptors and symbols on state graph
    for node in source_state.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in mapping:
            node.data = mapping[node.data]
    for edge in source_state.edges():
        edge.data.replace(mapping)
        if edge.data.data in mapping:
            edge.data.data = mapping[edge.data.data]

    # Add state contents (edges)
    for edge in source_state.edges():
        new_state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

    # Safe replacement of edge contents. safe_replace performs the
    # replacement in steps (through temporary names) and hands each step's
    # mapping to the callback. The callback must therefore apply the mapping
    # it receives; applying ``mapping`` directly here would bypass the safe
    # two-step scheme.
    def replfunc(repl):
        for k, v in repl.items():
            nisedge.data.replace(str(k), str(v), replace_keys=False)

    symbolic.safe_replace(mapping, replfunc)

    # Add interstate edge
    for akey, aval in nisedge.data.assignments.items():
        # Map assignment to outer edge, avoiding a name that is already a
        # symbol or array in the outer SDFG
        if akey not in sdfg.symbols and akey not in sdfg.arrays:
            newname = akey
        else:
            newname = nsdfg.label + '_' + akey
        isedge.data.assignments[newname] = aval

        # Add symbol to outer SDFG
        sdfg.add_symbol(newname, nsdfg.sdfg.symbols[akey])

        # Add symbol mapping to nested SDFG
        nsdfg.symbol_mapping[akey] = newname

    isedge.data.condition = nisedge.data.condition

    # Clean nested SDFG
    nsdfg.sdfg.remove_node(source_state)

    # Set new starting state
    nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)
def apply(self, sdfg):
    """Fuse the matched ``second_state`` into ``first_state``.

    The interstate edge(s) between the two states are removed. Any
    assignments on them are copied onto every edge entering
    ``first_state``. Then one of three cases applies:

    * ``first_state`` is empty: its in-edges are redirected to
      ``second_state`` and it is removed.
    * ``second_state`` is empty: its edges are redirected to
      ``first_state`` and it is removed.
    * Otherwise: all nodes and edges of ``second_state`` are added to
      ``first_state``. Top-level source access nodes of the second state are
      merged into matching top-level access nodes of the first state, and
      ``second_state`` is removed.

    If the removed state was the SDFG's start state, the surviving state
    becomes the start state.
    """
    if isinstance(self.subgraph[StateFusion.first_state], SDFGState):
        first_state: SDFGState = self.subgraph[StateFusion.first_state]
        second_state: SDFGState = self.subgraph[StateFusion.second_state]
    else:
        first_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.first_state])
        second_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.second_state])

    # Capture the start state before mutating the graph. Once a state has
    # been removed, ``sdfg.start_state`` can no longer be that state, so a
    # check made after removal never detects that the start state was the
    # one removed.
    start_state = sdfg.start_state

    def remove_state(removed, survivor):
        # Remove ``removed``; if it was the start state, make ``survivor``
        # the start state (its node id is looked up after the removal).
        sdfg.remove_node(removed)
        if start_state == removed:
            sdfg.start_state = sdfg.node_id(survivor)

    # Remove interstate edge(s), copying their assignments onto the edges
    # that enter the first state
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for _, _, other_data in sdfg.in_edges(first_state):
                other_data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        remove_state(first_state, second_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        remove_state(second_state, first_state)
        return

    # Normal case: both states are not empty

    # Source access nodes of the second state are the merge candidates
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]
    top2 = top_level_nodes(second_state)

    # Merge second state to first state.
    # First keep a backup of the topological order of the first state's
    # top-level access nodes; reversed, so topologically-later nodes come
    # first.
    sdict = first_state.scope_dict()
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode) and sdict[x] is None
    ]
    for node in second_state.nodes():
        if isinstance(node, nodes.NestedSDFG):
            # update parent information
            node.sdfg.parent = first_state
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    top = top_level_nodes(first_state)

    # Merge common (data) nodes
    for node in second_input:
        # merge only top level nodes, skip everything else
        if node not in top2:
            continue

        if first_state.in_degree(node) == 0:
            candidates = [
                x for x in order if x.data == node.data and x in top
            ]
            if len(candidates) == 0:
                continue
            elif len(candidates) == 1:
                n = candidates[0]
            else:
                # Choose first candidate that intersects memlets
                for cand in candidates:
                    if StateFusion.memlets_intersect(
                            first_state, [cand], False, second_state, [node],
                            True):
                        n = cand
                        break
                else:
                    # No node intersects, use topologically-last node
                    n = candidates[0]

            sdutil.change_edge_src(first_state, node, n)
            first_state.remove_node(node)
            n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    remove_state(second_state, first_state)