def apply(self, sdfg, *args, **kwargs):
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[type(self)._match_node]]
    expansion = type(self).expansion(node, state, sdfg, *args, **kwargs)
    if isinstance(expansion, SDFG):
        expansion = state.add_nested_sdfg(expansion,
                                          sdfg,
                                          node.in_connectors,
                                          node.out_connectors,
                                          name=node.name,
                                          schedule=node.schedule,
                                          debuginfo=node.debuginfo)
    elif isinstance(expansion, nd.CodeNode):
        expansion.debuginfo = node.debuginfo
        if isinstance(expansion, nd.NestedSDFG):
            # Fix parent references
            nsdfg = expansion.sdfg
            nsdfg.parent = state
            nsdfg.parent_sdfg = sdfg
            nsdfg.update_sdfg_list([])
            nsdfg.parent_nsdfg_node = expansion

            # Update schedule to match library node schedule
            nsdfg.schedule = node.schedule
        elif isinstance(expansion, (nd.EntryNode, nd.LibraryNode)):
            if expansion.schedule is ScheduleType.Default:
                expansion.schedule = node.schedule
    else:
        raise TypeError("Node expansion must be a CodeNode or an SDFG")

    # Fix nested schedules
    if isinstance(expansion, nd.NestedSDFG):
        infer_types._set_default_schedule_types(expansion.sdfg,
                                                expansion.schedule, True)

    expansion.environments = copy.copy(
        set(map(lambda a: a.full_class_path(),
                type(self).environments)))
    sdutil.change_edge_dest(state, node, expansion)
    sdutil.change_edge_src(state, node, expansion)
    state.remove_node(node)
    type(self).postprocessing(sdfg, state, expansion)
def fix_sdfg(sdfg, graph):
    # Fix the SDFG, since for now it gets parsed incorrectly
    for node in graph.nodes():
        if isinstance(node, dace.sdfg.nodes.NestedSDFG):
            nested_original = node
            for edge in itertools.chain(graph.in_edges(node),
                                        graph.out_edges(node)):
                for e in graph.memlet_tree(edge):
                    if 'z' in e.data.subset.free_symbols:
                        new_subset = str(e.data.subset)
                        new_subset = new_subset.replace('z', '0:O')
                        e.data.subset = subsets.Range.from_string(new_subset)

    # Next, replace the nested SDFG
    inner_sdfg = helper_sdfg.to_sdfg()
    nnode = graph.add_nested_sdfg(inner_sdfg, sdfg, {'AA', 'BB', 'CC'}, {'CC'})

    # Redirect edges
    connectors = []
    for e in graph.in_edges(nested_original):
        connectors.append(e.dst_conn)
    connectors.sort()
    for e in graph.in_edges(nested_original):
        if e.dst_conn == connectors[0]:
            graph.add_edge(e.src, e.src_conn, e.dst, 'AA', e.data)
            graph.remove_edge(e)
        if e.dst_conn == connectors[1]:
            graph.add_edge(e.src, e.src_conn, e.dst, 'BB', e.data)
            graph.remove_edge(e)
        if e.dst_conn == connectors[2]:
            graph.add_edge(e.src, e.src_conn, e.dst, 'CC', e.data)
            graph.remove_edge(e)
    e = graph.out_edges(nested_original)[0]
    graph.add_edge(e.src, 'CC', e.dst, e.dst_conn, e.data)
    graph.remove_edge(e)

    utils.change_edge_dest(graph, nested_original, nnode)
    utils.change_edge_src(graph, nested_original, nnode)
    graph.remove_node(nested_original)

    sdfg.validate()
def apply(self, sdfg, *args, **kwargs):
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[type(self)._match_node]]
    expansion = type(self).expansion(node, state, sdfg, *args, **kwargs)
    if isinstance(expansion, dace.SDFG):
        # Modify internal schedules according to node schedule
        if node.schedule != ScheduleType.Default:
            for nstate in expansion.nodes():
                topnodes = nstate.scope_dict(node_to_children=True)[None]
                for topnode in topnodes:
                    if isinstance(topnode, (nd.EntryNode, nd.LibraryNode)):
                        topnode.schedule = node.schedule

        expansion = state.add_nested_sdfg(expansion,
                                          sdfg,
                                          node.in_connectors,
                                          node.out_connectors,
                                          name=node.name,
                                          debuginfo=node.debuginfo)
    elif isinstance(expansion, dace.sdfg.nodes.CodeNode):
        expansion.debuginfo = node.debuginfo
        if isinstance(expansion, dace.sdfg.nodes.NestedSDFG):
            # Fix parent references
            nsdfg = expansion.sdfg
            nsdfg.parent = state
            nsdfg.parent_sdfg = sdfg
            nsdfg.update_sdfg_list([])
            nsdfg.parent_nsdfg_node = expansion
    else:
        raise TypeError("Node expansion must be a CodeNode or an SDFG")

    expansion.environments = copy.copy(
        set(map(lambda a: a.__name__,
                type(self).environments)))
    sdutil.change_edge_dest(state, node, expansion)
    sdutil.change_edge_src(state, node, expansion)
    state.remove_node(node)
    type(self).postprocessing(sdfg, state, expansion)
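# Hedged usage sketch (an assumed example, not part of the excerpts above):
# expansions such as the apply() methods shown here run when library nodes are
# expanded, e.g. through SDFG.expand_library_nodes(). The program below is
# assumed to produce a Reduce library node via np.sum.
import dace
import numpy as np

@dace.program
def total(A: dace.float64[32]):
    return np.sum(A)

sdfg = total.to_sdfg()
# Each library node's registered expansion replaces the node in place,
# redirecting its edges exactly as in the apply() methods above.
sdfg.expand_library_nodes()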
def apply(self, sdfg):
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = sdutil.find_source_nodes(state)
    output_nodes = sdutil.find_sink_nodes(state)

    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    stack = []

    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if isinstance(node, dace.sdfg.nodes.AccessNode):
            for e in graph.all_edges(node):
                if e.data.wcr is not None:
                    trace = dace.sdfg.trace_nested_access(
                        node, graph, parent_sdfg[graph])
                    for node_trace, state_trace, sdfg_trace in trace:
                        # Find the name of the accessed node in our scope
                        if state_trace == state and sdfg_trace == sdfg:
                            outer_node = node_trace
                            break
                    else:
                        # This does not trace back to the current state, so
                        # we don't care
                        continue
                    input_nodes.append(outer_node)
                    wcr_input_nodes.add(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array

            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                sdutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array
            # fpga_node = type(node)(fpga_array)

            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            fpga_node = state.add_write('fpga_' + node.data)
            sdutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    veclen_ = 1

    # propagate vector info from a nested sdfg
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # need to go inside the nested SDFG and grab the vector length
        if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            for inner_state in dst.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(n, dace.sdfg.nodes.AccessNode
                                  ) and n.data == dst_conn:
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen
        if isinstance(src, dace.sdfg.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            for inner_state in src.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(n, dace.sdfg.nodes.AccessNode
                                  ) and n.data == src_conn:
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen

        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
            mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def _stripmine(self, sdfg, graph, candidate):
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[StripMining._map_entry]]
    map_exit = graph.exit_node(map_entry)

    # Retrieve transformation properties.
    dim_idx = self.dim_idx
    target_dim = map_entry.map.params[dim_idx]

    if self.tiling_type == 'ceilrange':
        new_dim, new_map, td_rng = self._create_ceil_range(
            sdfg, graph, map_entry)
    elif self.tiling_type == 'number_of_tiles':
        new_dim, new_map, td_rng = self._create_from_tile_numbers(
            sdfg, graph, map_entry)
    else:
        new_dim, new_map, td_rng = self._create_strided_range(
            sdfg, graph, map_entry)

    new_map_entry = nodes.MapEntry(new_map)
    new_map_exit = nodes.MapExit(new_map)

    td_to_new_approx = td_rng[1]
    if isinstance(td_to_new_approx, dace.symbolic.SymExpr):
        td_to_new_approx = td_to_new_approx.approx

    # Special case: If range is 1 and no prefix was specified, skip range
    if td_rng[0] == td_to_new_approx and target_dim == new_dim:
        map_entry.map.range = subsets.Range(
            [r for i, r in enumerate(map_entry.map.range) if i != dim_idx])
        map_entry.map.params = [
            p for i, p in enumerate(map_entry.map.params) if i != dim_idx
        ]
        if len(map_entry.map.params) == 0:
            raise ValueError('Strip-mining all dimensions of the map with '
                             'empty tiles is disallowed')
    else:
        map_entry.map.range[dim_idx] = td_rng

    # Make internal map's schedule to "not parallel"
    new_map.schedule = map_entry.map.schedule
    map_entry.map.schedule = dtypes.ScheduleType.Sequential

    # Redirect edges
    new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
    sdutil.change_edge_dest(graph, map_entry, new_map_entry)
    new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
    sdutil.change_edge_src(graph, map_exit, new_map_exit)

    # Create new entry edges
    new_in_edges = dict()
    entry_in_conn = {}
    entry_out_conn = {}
    for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
        if (src_conn is not None and src_conn[:4] == 'OUT_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = src_conn[4:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_in_edges.keys():
                old_subset = new_in_edges[key].subset
                new_in_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                entry_in_conn['IN_' + conn] = None
                entry_out_conn['OUT_' + conn] = None
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                if memlet.dynamic:
                    new_memlet.num_accesses = memlet.num_accesses
                else:
                    new_memlet.num_accesses = new_memlet.num_elements()
                new_in_edges[key] = new_memlet
        else:
            if src_conn is not None and src_conn[:4] == 'OUT_':
                conn = src_conn[4:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                in_conn = src_conn
                out_conn = src_conn
            if in_conn:
                entry_in_conn[in_conn] = None
            if out_conn:
                entry_out_conn[out_conn] = None
            new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_entry.out_connectors = entry_out_conn
    map_entry.in_connectors = entry_in_conn
    for (_, in_conn, out_conn), memlet in new_in_edges.items():
        graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

    # Create new exit edges
    new_out_edges = dict()
    exit_in_conn = {}
    exit_out_conn = {}
    for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
        if (dst_conn is not None and dst_conn[:3] == 'IN_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = dst_conn[3:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_out_edges.keys():
                old_subset = new_out_edges[key].subset
                new_out_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                exit_in_conn['IN_' + conn] = None
                exit_out_conn['OUT_' + conn] = None
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                if memlet.dynamic:
                    new_memlet.num_accesses = memlet.num_accesses
                else:
                    new_memlet.num_accesses = new_memlet.num_elements()
                new_out_edges[key] = new_memlet
        else:
            if dst_conn is not None and dst_conn[:3] == 'IN_':
                conn = dst_conn[3:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                in_conn = dst_conn
                out_conn = dst_conn
            if in_conn:
                exit_in_conn[in_conn] = None
            if out_conn:
                exit_out_conn[out_conn] = None
            new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_exit.in_connectors = exit_in_conn
    map_exit.out_connectors = exit_out_conn
    for (_, in_conn, out_conn), memlet in new_out_edges.items():
        graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

    # Skew if necessary
    if self.skew:
        xfh.offset_map(sdfg, graph, map_entry, dim_idx, td_rng[0])

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map
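# Hedged usage sketch (an assumed example, not from the excerpt above):
# applying StripMining to the map of a simple program, using its default
# tile size.
import dace
from dace.transformation.dataflow import StripMining

@dace.program
def scale(A: dace.float64[128]):
    for i in dace.map[0:128]:
        A[i] = A[i] * 2

sdfg = scale.to_sdfg()
# Strip-mines the map into an outer tile map and an inner intra-tile map,
# rerouting edges through the new map entry/exit as in _stripmine above.
sdfg.apply_transformations(StripMining)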
def _expand_reduce(self, sdfg, state, node):
    # expands a reduce into two nested maps
    # taken from legacy expand_reduce.py
    node.validate(sdfg, state)
    inedge: graph.MultiConnectorEdge = state.in_edges(node)[0]
    outedge: graph.MultiConnectorEdge = state.out_edges(node)[0]
    input_dims = len(inedge.data.subset)
    output_dims = len(outedge.data.subset)
    input_data = sdfg.arrays[inedge.data.data]
    output_data = sdfg.arrays[outedge.data.data]

    # Standardize axes
    axes = node.axes if node.axes else [i for i in range(input_dims)]

    # Create nested SDFG
    nsdfg = SDFG('reduce')

    nsdfg.add_array('_in',
                    inedge.data.subset.size(),
                    input_data.dtype,
                    strides=input_data.strides,
                    storage=input_data.storage)
    nsdfg.add_array('_out',
                    outedge.data.subset.size(),
                    output_data.dtype,
                    strides=output_data.strides,
                    storage=output_data.storage)

    if node.identity is not None:
        raise ValueError("Node identity has to be None at this point.")
    else:
        nstate = nsdfg.add_state()
    # END OF INIT

    # (If axes != all) Add outer map, which corresponds to the output range
    if len(axes) != input_dims:
        # Interleave input and output axes to match input memlet
        ictr, octr = 0, 0
        input_subset = []
        for i in range(input_dims):
            if i in axes:
                input_subset.append('_i%d' % ictr)
                ictr += 1
            else:
                input_subset.append('_o%d' % octr)
                octr += 1

        output_size = outedge.data.subset.size()

        ome, omx = nstate.add_map(
            'reduce_output', {
                '_o%d' % i: '0:%s' % symstr(sz)
                for i, sz in enumerate(outedge.data.subset.size())
            })
        outm = Memlet.simple('_out',
                             ','.join(
                                 ['_o%d' % i for i in range(output_dims)]),
                             wcr_str=node.wcr)
        inmm = Memlet.simple('_in', ','.join(input_subset))
    else:
        ome, omx = None, None
        outm = Memlet.simple('_out', '0', wcr_str=node.wcr)
        inmm = Memlet.simple(
            '_in', ','.join(['_i%d' % i for i in range(len(axes))]))

    # Add inner map, which corresponds to the range to reduce, containing
    # an identity tasklet
    ime, imx = nstate.add_map(
        'reduce_values', {
            '_i%d' % i: '0:%s' % symstr(inedge.data.subset.size()[axis])
            for i, axis in enumerate(sorted(axes))
        })

    # Add identity tasklet for reduction
    t = nstate.add_tasklet('identity', {'inp'}, {'out'}, 'out = inp')

    # Connect everything
    r = nstate.add_read('_in')
    w = nstate.add_write('_out')
    if ome:
        nstate.add_memlet_path(r, ome, ime, t, dst_conn='inp', memlet=inmm)
        nstate.add_memlet_path(t, imx, omx, w, src_conn='out', memlet=outm)
    else:
        nstate.add_memlet_path(r, ime, t, dst_conn='inp', memlet=inmm)
        nstate.add_memlet_path(t, imx, w, src_conn='out', memlet=outm)

    # Rename outer connectors and add to node
    inedge._dst_conn = '_in'
    outedge._src_conn = '_out'
    node.add_in_connector('_in')
    node.add_out_connector('_out')

    nsdfg = state.add_nested_sdfg(nsdfg,
                                  sdfg,
                                  node.in_connectors,
                                  node.out_connectors,
                                  schedule=node.schedule,
                                  name=node.name)

    utils.change_edge_dest(state, node, nsdfg)
    utils.change_edge_src(state, node, nsdfg)

    state.remove_node(node)

    return nsdfg
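# The wrap-into-a-nested-SDFG pattern used above, distilled into a minimal,
# assumed helper (hypothetical name; `inner` is an SDFG whose array names
# match the connectors of `node`, which must already be renamed accordingly).
import dace
import dace.sdfg.utils as sdutil

def wrap_in_nested_sdfg(sdfg: dace.SDFG, state: dace.SDFGState, node,
                        inner: dace.SDFG):
    # Add the inner SDFG as a nested SDFG node with the same connectors,
    # reroute all of the original node's edges to it, and drop the node.
    nsdfg_node = state.add_nested_sdfg(inner, sdfg, set(node.in_connectors),
                                       set(node.out_connectors))
    sdutil.change_edge_dest(state, node, nsdfg_node)
    sdutil.change_edge_src(state, node, nsdfg_node)
    state.remove_node(node)
    return nsdfg_node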
def apply(self, _, sdfg):
    state = self.state

    # Find source/sink (data) nodes that are relevant outside this FPGA
    # kernel
    shared_transients = set(sdfg.shared_transients())
    input_nodes = [
        n for n in sdutil.find_source_nodes(state)
        if isinstance(n, nodes.AccessNode) and
        (not sdfg.arrays[n.data].transient or n.data in shared_transients)
    ]
    output_nodes = [
        n for n in sdutil.find_sink_nodes(state)
        if isinstance(n, nodes.AccessNode) and
        (not sdfg.arrays[n.data].transient or n.data in shared_transients)
    ]

    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    stack = []

    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if isinstance(node, dace.sdfg.nodes.AccessNode):
            for e in graph.in_edges(node):
                if e.data.wcr is not None:
                    trace = dace.sdfg.trace_nested_access(
                        node, graph, parent_sdfg[graph])
                    for node_trace, memlet_trace, state_trace, sdfg_trace in trace:
                        # Find the name of the accessed node in our scope
                        if state_trace == state and sdfg_trace == sdfg:
                            _, outer_node = node_trace
                            if outer_node is not None:
                                break
                    else:
                        # This does not trace back to the current state, so
                        # we don't care
                        continue
                    input_nodes.append(outer_node)
                    wcr_input_nodes.add(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_array[1].location = copy.copy(desc.location)
                desc.location.clear()
                fpga_data[node.data] = fpga_array

            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            mem = memlet.Memlet(data=node.data,
                                subset=subsets.Range.from_array(desc))
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                sdutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_array[1].location = copy.copy(desc.location)
                desc.location.clear()
                fpga_data[node.data] = fpga_array
            # fpga_node = type(node)(fpga_array)

            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            mem = memlet.Memlet(f"fpga_{node.data}", None,
                                subsets.Range.from_array(desc))
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            fpga_node = state.add_write('fpga_' + node.data)
            sdutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    # propagate memlet info from a nested sdfg
    for src, src_conn, dst, dst_conn, mem in state.edges():
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data

    fpga_update(sdfg, state, 0)
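# Hedged usage sketch (an assumed example): offloading a whole program to FPGA
# global memory with the interstate FPGA transformation, which inserts the
# pre/post copy states built by the apply() methods above.
import dace
from dace.transformation.interstate import FPGATransformSDFG

@dace.program
def vec_add(A: dace.float32[64], B: dace.float32[64], C: dace.float32[64]):
    C[:] = A + B

sdfg = vec_add.to_sdfg()
sdfg.apply_transformations(FPGATransformSDFG)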
def apply(self, sdfg):
    first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
    second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

    # Remove interstate edge(s)
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for src, dst, other_data in sdfg.in_edges(first_state):
                other_data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # Normal case: both states are not empty

    # Find source/sink (data) nodes
    first_input = [
        node for node in sdutil.find_source_nodes(first_state)
        if isinstance(node, nodes.AccessNode)
    ]
    first_output = [
        node for node in sdutil.find_sink_nodes(first_state)
        if isinstance(node, nodes.AccessNode)
    ]
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]

    # first input = first input - first output
    first_input = [
        node for node in first_input
        if next((x for x in first_output
                 if x.label == node.label), None) is None
    ]

    # Merge second state to first state
    # First keep a backup of the topological sorted order of the nodes
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode)
    ]
    for node in second_state.nodes():
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Merge common (data) nodes
    for node in second_input:
        if first_state.in_degree(node) == 0:
            n = next((x for x in order if x.label == node.label), None)
            if n:
                sdutil.change_edge_src(first_state, node, n)
                first_state.remove_node(node)
                n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def apply(self, sdfg):
    if isinstance(self.subgraph[StateFusion.first_state], SDFGState):
        first_state: SDFGState = self.subgraph[StateFusion.first_state]
        second_state: SDFGState = self.subgraph[StateFusion.second_state]
    else:
        first_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.first_state])
        second_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.second_state])

    # Remove interstate edge(s)
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for src, dst, other_data in sdfg.in_edges(first_state):
                other_data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        if sdfg.start_state == first_state:
            sdfg.start_state = sdfg.node_id(second_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if sdfg.start_state == second_state:
            sdfg.start_state = sdfg.node_id(first_state)
        return

    # Normal case: both states are not empty

    # Find source/sink (data) nodes
    first_input = [
        node for node in sdutil.find_source_nodes(first_state)
        if isinstance(node, nodes.AccessNode)
    ]
    first_output = [
        node for node in sdutil.find_sink_nodes(first_state)
        if isinstance(node, nodes.AccessNode)
    ]
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]

    top2 = top_level_nodes(second_state)

    # first input = first input - first output
    first_input = [
        node for node in first_input
        if next((x for x in first_output
                 if x.data == node.data), None) is None
    ]

    # Merge second state to first state
    # First keep a backup of the topological sorted order of the nodes
    sdict = first_state.scope_dict()
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode) and sdict[x] is None
    ]
    for node in second_state.nodes():
        if isinstance(node, nodes.NestedSDFG):
            # update parent information
            node.sdfg.parent = first_state
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    top = top_level_nodes(first_state)

    # Merge common (data) nodes
    for node in second_input:
        # merge only top level nodes, skip everything else
        if node not in top2:
            continue

        if first_state.in_degree(node) == 0:
            candidates = [
                x for x in order if x.data == node.data and x in top
            ]
            if len(candidates) == 0:
                continue
            elif len(candidates) == 1:
                n = candidates[0]
            else:
                # Choose first candidate that intersects memlets
                for cand in candidates:
                    if StateFusion.memlets_intersect(first_state, [cand],
                                                     False, second_state,
                                                     [node], True):
                        n = cand
                        break
                else:
                    # No node intersects, use topologically-last node
                    n = candidates[0]

            sdutil.change_edge_src(first_state, node, n)
            first_state.remove_node(node)
            n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if sdfg.start_state == second_state:
        sdfg.start_state = sdfg.node_id(first_state)
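# Hedged usage sketch (an assumed example): fusing the states of a small
# program with StateFusion. `simplify=False` is assumed here so that separate
# states remain to be fused; older DaCe versions use `strict=False` instead.
import dace
from dace.transformation.interstate import StateFusion

@dace.program
def example(A: dace.float64[10], B: dace.float64[10]):
    tmp = A + 1
    B[:] = tmp * 2

sdfg = example.to_sdfg(simplify=False)
sdfg.apply_transformations_repeated(StateFusion)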
def _stripmine(self, sdfg, graph, candidate):
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[StripMining._map_entry]]
    map_exit = graph.exit_node(map_entry)

    # Retrieve transformation properties.
    dim_idx = self.dim_idx
    new_dim_prefix = self.new_dim_prefix
    tile_size = self.tile_size
    divides_evenly = self.divides_evenly
    strided = self.strided

    tile_stride = self.tile_stride
    if tile_stride is None or len(tile_stride) == 0:
        tile_stride = tile_size

    # Retrieve parameter and range of dimension to be strip-mined.
    target_dim = map_entry.map.params[dim_idx]
    td_from, td_to, td_step = map_entry.map.range[dim_idx]

    # Create new map. Replace by cloning map object?
    new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                 target_dim)
    nd_from = 0
    if symbolic.pystr_to_symbolic(tile_stride) == 1:
        nd_to = td_to
    else:
        nd_to = symbolic.pystr_to_symbolic(
            'int_ceil(%s + 1 - %s, %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride))
    nd_step = 1
    new_dim_range = (nd_from, nd_to, nd_step)
    new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                        subsets.Range([new_dim_range]))
    new_map_entry = nodes.MapEntry(new_map)
    new_map_exit = nodes.MapExit(new_map)

    # Change the range of the selected dimension to iterate over a single
    # tile
    if strided:
        td_from_new = symbolic.pystr_to_symbolic(new_dim)
        td_to_new_approx = td_to
        td_step = symbolic.pystr_to_symbolic(tile_size)
    else:
        td_from_new = symbolic.pystr_to_symbolic(
            '%s + %s * %s' %
            (symbolic.symstr(td_from), str(new_dim), tile_stride))
        td_to_new_exact = symbolic.pystr_to_symbolic(
            'min(%s + 1, %s + %s * %s + %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride,
             str(new_dim), tile_size))
        td_to_new_approx = symbolic.pystr_to_symbolic(
            '%s + %s * %s + %s - 1' %
            (symbolic.symstr(td_from), tile_stride, str(new_dim), tile_size))
    if divides_evenly or strided:
        td_to_new = td_to_new_approx
    else:
        td_to_new = dace.symbolic.SymExpr(td_to_new_exact, td_to_new_approx)

    # Special case: If range is 1 and no prefix was specified, skip range
    if td_from_new == td_to_new_approx and target_dim == new_dim:
        map_entry.map.range = subsets.Range(
            [r for i, r in enumerate(map_entry.map.range) if i != dim_idx])
        map_entry.map.params = [
            p for i, p in enumerate(map_entry.map.params) if i != dim_idx
        ]
        if len(map_entry.map.params) == 0:
            raise ValueError('Strip-mining all dimensions of the map with '
                             'empty tiles is disallowed')
    else:
        map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

    # Make internal map's schedule to "not parallel"
    new_map.schedule = map_entry.map.schedule
    map_entry.map.schedule = dtypes.ScheduleType.Sequential

    # Redirect edges
    new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
    sdutil.change_edge_dest(graph, map_entry, new_map_entry)
    new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
    sdutil.change_edge_src(graph, map_exit, new_map_exit)

    # Create new entry edges
    new_in_edges = dict()
    entry_in_conn = {}
    entry_out_conn = {}
    for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
        if (src_conn is not None and src_conn[:4] == 'OUT_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = src_conn[4:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_in_edges.keys():
                old_subset = new_in_edges[key].subset
                new_in_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                entry_in_conn['IN_' + conn] = None
                entry_out_conn['OUT_' + conn] = None
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                if memlet.dynamic:
                    new_memlet.num_accesses = memlet.num_accesses
                else:
                    new_memlet.num_accesses = new_memlet.num_elements()
                new_in_edges[key] = new_memlet
        else:
            if src_conn is not None and src_conn[:4] == 'OUT_':
                conn = src_conn[4:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                in_conn = src_conn
                out_conn = src_conn
            if in_conn:
                entry_in_conn[in_conn] = None
            if out_conn:
                entry_out_conn[out_conn] = None
            new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_entry.out_connectors = entry_out_conn
    map_entry.in_connectors = entry_in_conn
    for (_, in_conn, out_conn), memlet in new_in_edges.items():
        graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

    # Create new exit edges
    new_out_edges = dict()
    exit_in_conn = {}
    exit_out_conn = {}
    for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
        if (dst_conn is not None and dst_conn[:3] == 'IN_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = dst_conn[3:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_out_edges.keys():
                old_subset = new_out_edges[key].subset
                new_out_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                exit_in_conn['IN_' + conn] = None
                exit_out_conn['OUT_' + conn] = None
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                if memlet.dynamic:
                    new_memlet.num_accesses = memlet.num_accesses
                else:
                    new_memlet.num_accesses = new_memlet.num_elements()
                new_out_edges[key] = new_memlet
        else:
            if dst_conn is not None and dst_conn[:3] == 'IN_':
                conn = dst_conn[3:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                in_conn = dst_conn
                out_conn = dst_conn
            if in_conn:
                exit_in_conn[in_conn] = None
            if out_conn:
                exit_out_conn[out_conn] = None
            new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_exit.in_connectors = exit_in_conn
    map_exit.out_connectors = exit_out_conn
    for (_, in_conn, out_conn), memlet in new_out_edges.items():
        graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map