def __stripmine(self, sdfg, graph, candidate):
    """Tile every dimension of the candidate map at once.

    The candidate map is turned into the inner (per-tile) map, and a new
    outer map with one ``<prefix>_<param>`` parameter per original
    dimension iterates over the tiles. Memlets inside the map scope are
    shifted by the tile offset (``new_dim * tile_size``), and the edges
    around the original map entry/exit are rerouted through the new outer
    map entry/exit.

    :param sdfg: SDFG that owns ``graph``; used to look up data descriptors.
    :param graph: State containing the candidate map.
    :param candidate: Pattern match; ``candidate[OrthogonalTiling._map_entry]``
                      is the index of the map entry node in ``graph``.
    :return: ``(target_dim, new_dim, new_map)``: the (symbolic) parameter
             and the outer parameter name of the last processed dimension,
             and the new outer map.
    :raises ValueError: if a range subset inside the scope is not a
                        ``(begin, end, step)`` triple.
    :raises NotImplementedError: for memlet subsets that are neither
                                 ``Indices`` nor ``Range``.
    """
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[OrthogonalTiling._map_entry]]
    map_exit = graph.exit_nodes(map_entry)[0]

    # Map subgraph
    map_subgraph = graph.scope_subgraph(map_entry)

    # Retrieve transformation properties.
    prefix = self.prefix
    tile_sizes = self.tile_sizes
    divides_evenly = self.divides_evenly

    new_param = []
    new_range = []

    for dim_idx in range(len(map_entry.map.params)):
        # Dimensions beyond the given tile sizes reuse the last tile size.
        if dim_idx >= len(tile_sizes):
            tile_size = tile_sizes[-1]
        else:
            tile_size = tile_sizes[dim_idx]

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]

        new_dim = prefix + '_' + target_dim

        # Number of tiles in this dimension
        if divides_evenly:
            tile_num = '(%s + 1 - %s) / %s' % (symbolic.symstr(td_to),
                                               symbolic.symstr(td_from),
                                               str(tile_size))
        else:
            tile_num = 'int_ceil((%s + 1 - %s), %s)' % (
                symbolic.symstr(td_to), symbolic.symstr(td_from),
                str(tile_size))

        # Outer map values (over all tiles)
        nd_from = 0
        nd_to = symbolic.pystr_to_symbolic(str(tile_num) + ' - 1')
        nd_step = 1

        # Inner map values (over one tile)
        td_from_new = dace.symbolic.pystr_to_symbolic(td_from)
        td_to_new_exact = symbolic.pystr_to_symbolic(
            'min(%s + 1 - %s * %s, %s + %s) - 1' %
            (symbolic.symstr(td_to), str(new_dim), str(tile_size),
             td_from_new, str(tile_size)))
        td_to_new_approx = symbolic.pystr_to_symbolic(
            '%s + %s - 1' % (td_from_new, str(tile_size)))

        # Outer map (over all tiles)
        new_dim_range = (nd_from, nd_to, nd_step)
        new_param.append(new_dim)
        new_range.append(new_dim_range)

        # Inner map (over one tile). Without even division the last tile
        # may be partial, so the exact (min-clamped) bound is kept too.
        if divides_evenly:
            td_to_new = td_to_new_approx
        else:
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

        # Fix subgraph memlets: shift every use of ``target_dim`` by the
        # start offset of the current tile.
        target_dim = dace.symbolic.pystr_to_symbolic(target_dim)
        offset = dace.symbolic.pystr_to_symbolic(
            '%s * %s' % (new_dim, str(tile_size)))
        for _, _, _, _, memlet in map_subgraph.edges():
            old_subset = memlet.subset
            if isinstance(old_subset, dace.subsets.Indices):
                new_indices = []
                for idx in old_subset:
                    new_idx = idx.subs(target_dim, target_dim + offset)
                    new_indices.append(new_idx)
                memlet.subset = dace.subsets.Indices(new_indices)
            elif isinstance(old_subset, dace.subsets.Range):
                new_ranges = []
                for i, old_range in enumerate(old_subset):
                    if len(old_range) == 3:
                        b, e, s, = old_range
                        t = old_subset.tile_sizes[i]
                    else:
                        # Wrap in a 1-tuple: ``old_range`` is itself a
                        # tuple and would otherwise be unpacked by ``%``.
                        raise ValueError(
                            'Range %s is invalid.' % (old_range, ))
                    new_b = b.subs(target_dim, target_dim + offset)
                    new_e = e.subs(target_dim, target_dim + offset)
                    new_s = s.subs(target_dim, target_dim + offset)
                    new_t = t.subs(target_dim, target_dim + offset)
                    new_ranges.append((new_b, new_e, new_s, new_t))
                memlet.subset = dace.subsets.Range(new_ranges)
            else:
                raise NotImplementedError

    new_map = nodes.Map(prefix + '_' + map_entry.map.label, new_param,
                        subsets.Range(new_range))
    new_map_entry = nodes.MapEntry(new_map)
    new_exit = nodes.MapExit(new_map)

    # Make internal map's schedule to "not parallel"
    map_entry.map._schedule = dtypes.ScheduleType.Default

    # Redirect/create edges. All memlets of the same (non-scalar) data are
    # merged into one outer edge; its connector number is the smallest of
    # the original connector numbers.
    new_in_edges = {}
    for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
        if isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
            continue
        new_subset = copy.deepcopy(memlet.subset)
        conn_num = int(conn[4:])  # 'OUT_<n>' -> n
        if memlet.data in new_in_edges:
            src, src_conn, dest, dest_conn, new_memlet, num = \
                new_in_edges[memlet.data]
            new_memlet.subset = calc_set_union(
                new_memlet.data, sdfg.arrays[new_memlet.data],
                new_memlet.subset, new_subset)
            new_memlet.num_accesses = new_memlet.num_elements()
            new_in_edges[memlet.data] = (src, src_conn, dest, dest_conn,
                                         new_memlet, min(num, conn_num))
        else:
            new_memlet = dcpy(memlet)
            new_memlet.subset = new_subset
            new_memlet.num_accesses = new_memlet.num_elements()
            new_in_edges[memlet.data] = (new_map_entry, None, map_entry,
                                         None, new_memlet, conn_num)
    nxutil.change_edge_dest(graph, map_entry, new_map_entry)

    new_out_edges = {}
    for _src, _, _dest, dst_conn, memlet in graph.in_edges(map_exit):
        if isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
            continue
        # Copy so the outer memlet does not alias the inner memlet's subset
        new_subset = copy.deepcopy(memlet.subset)
        # The number comes from the exit node's own 'IN_<n>' connector; the
        # source connector belongs to the producing node. Compare as ints.
        conn_num = int(dst_conn[3:])
        if memlet.data in new_out_edges:
            src, src_conn, dest, dest_conn, new_memlet, num = \
                new_out_edges[memlet.data]
            new_memlet.subset = calc_set_union(
                new_memlet.data, sdfg.arrays[new_memlet.data],
                new_memlet.subset, new_subset)
            new_memlet.num_accesses = new_memlet.num_elements()
            new_out_edges[memlet.data] = (src, src_conn, dest, dest_conn,
                                          new_memlet, min(num, conn_num))
        else:
            new_memlet = dcpy(memlet)
            new_memlet.subset = new_subset
            new_memlet.num_accesses = new_memlet.num_elements()
            new_out_edges[memlet.data] = (map_exit, None, new_exit, None,
                                          new_memlet, conn_num)
    nxutil.change_edge_src(graph, map_exit, new_exit)

    # Connector related work follows
    # 1. Dictionary 'old_connector_number': 'new_connector_number'
    # 2. New node in/out connectors
    # 3. New edges
    in_conn_nums = [entry[-1] for entry in new_in_edges.values()]
    in_conn = {num: i + 1 for i, num in enumerate(in_conn_nums)}
    new_map_entry.in_connectors = {
        'IN_' + str(i + 1) for i in range(len(in_conn_nums))
    }
    new_map_entry.out_connectors = {
        'OUT_' + str(i + 1) for i in range(len(in_conn_nums))
    }
    for src, _, dst, _, memlet, num in new_in_edges.values():
        graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                       'IN_' + str(in_conn[num]), memlet)

    out_conn_nums = [
        entry[-1] for entry in new_out_edges.values()
        if entry[2] is new_exit
    ]
    out_conn = {num: i + 1 for i, num in enumerate(out_conn_nums)}
    new_exit.in_connectors = {
        'IN_' + str(i + 1) for i in range(len(out_conn_nums))
    }
    new_exit.out_connectors = {
        'OUT_' + str(i + 1) for i in range(len(out_conn_nums))
    }
    for src, _, dst, _, memlet, num in new_out_edges.values():
        graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                       'IN_' + str(out_conn[num]), memlet)

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map
def apply(self, sdfg):
    """Fuse a map that writes an intermediate array with the reduction
    that consumes that array.

    The matched pattern is selected by ``self.expr_index``:

    * 0 -- reduction map without an outer map;
    * 1 -- reduction with an outer map; the tasklet map is split into an
      outer ``<label>_nonreduce`` map and an inner map according to
      ``MapReduceFusion.find_permutation``;
    * 2 -- a reduce node; the tasklet's output memlet is redirected to the
      output array with the reduce node's ``wcr`` and ``identity``, and the
      method returns early.

    For patterns 0 and 1, a transient ``'tmp'`` array is inserted between
    the tasklet and the reduction tasklet.

    :param sdfg: SDFG whose state ``self.state_id`` holds the match.
    :raises RuntimeError: if no memlet into the tasklet map exit carries the
                          intermediate array, or if the reduction tasklet
                          does not have exactly one input edge.
    """
    def gnode(nname):
        # Look up a matched node of this pattern by its pattern key.
        return graph.nodes()[self.subgraph[nname]]

    expr_index = self.expr_index
    graph = sdfg.nodes()[self.state_id]
    # NOTE(review): ``tasklet`` and ``rmap_entry`` are not used below.
    tasklet = gnode(MapReduceFusion._tasklet)
    tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
    in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
    if expr_index == 0:  # Reduce without outer map
        rmap_entry = graph.nodes()[self.subgraph[
            MapReduceFusion._rmap_in_entry]]
    elif expr_index == 1:  # Reduce with outer map
        rmap_out_entry = graph.nodes()[self.subgraph[
            MapReduceFusion._rmap_out_entry]]
        rmap_out_exit = graph.nodes()[self.subgraph[
            MapReduceFusion._rmap_out_exit]]
        rmap_in_entry = graph.nodes()[self.subgraph[
            MapReduceFusion._rmap_in_entry]]
        rmap_tasklet = graph.nodes()[self.subgraph[
            MapReduceFusion._rmap_in_tasklet]]

    # The write-conflict-resolution source: the reduce node itself for
    # pattern 2, otherwise the node matched as ``_rmap_in_cr``.
    if expr_index == 2:
        rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
    else:
        rmap_cr = graph.nodes()[self.subgraph[MapReduceFusion._rmap_in_cr]]

    out_array = gnode(MapReduceFusion._out_array)

    # Set nodes to remove according to the expression index
    nodes_to_remove = [in_array]
    if expr_index == 0:
        nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
    elif expr_index == 1:
        nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_entry))
        nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
        nodes_to_remove.append(gnode(MapReduceFusion._rmap_out_exit))
    else:
        nodes_to_remove.append(gnode(MapReduceFusion._reduce))

    # If no other edges lead to mapexit, remove it. Otherwise, keep
    # it and remove reduction incoming/outgoing edges
    if expr_index != 2 and len(graph.in_edges(tmap_exit)) == 1:
        nodes_to_remove.append(tmap_exit)

    # Find the memlet that writes the intermediate array into the map exit
    memlet_edge = None
    for edge in graph.in_edges(tmap_exit):
        if edge.data.data == in_array.data:
            memlet_edge = edge
            break
    if memlet_edge is None:
        raise RuntimeError('Reduction memlet cannot be None')

    if expr_index == 0:  # Reduce without outer map
        # Index order does not matter, merge as-is
        pass
    elif expr_index == 1:  # Reduce with outer map
        tmap = tmap_exit.map
        perm_outer, perm_inner = MapReduceFusion.find_permutation(
            tmap, rmap_out_entry.map, rmap_in_entry.map, memlet_edge.data)

        # Split tasklet map into tmap_out -> tmap_in (according to
        # reduction)
        omap = nodes.Map(
            tmap.label + '_nonreduce',
            [p for i, p in enumerate(tmap.params) if i in perm_outer],
            [r for i, r in enumerate(tmap.range) if i in perm_outer],
            tmap.schedule, tmap.unroll, tmap.is_async)
        tmap.params = [
            p for i, p in enumerate(tmap.params) if i in perm_inner
        ]
        tmap.range = [
            r for i, r in enumerate(tmap.range) if i in perm_inner
        ]
        omap_entry = nodes.MapEntry(omap)
        # The outer reduction map's exit is reused as the exit of the new
        # outer map.
        # NOTE(review): this same node (gnode(_rmap_out_exit)) is also in
        # ``nodes_to_remove`` for expr_index == 1 -- verify that removing
        # it below is intended.
        omap_exit = rmap_out_exit
        rmap_out_exit.map = omap

        # Reconnect graph to new map
        tmap_entry = graph.entry_node(tmap_exit)
        tmap_in_edges = list(graph.in_edges(tmap_entry))
        for e in tmap_in_edges:
            nxutil.change_edge_dest(graph, tmap_entry, omap_entry)
        for e in tmap_in_edges:
            graph.add_edge(omap_entry, e.src_conn, tmap_entry, e.dst_conn,
                           copy.copy(e.data))
    elif expr_index == 2:  # Reduce node
        # Find correspondence between map indices and array outputs
        tmap = tmap_exit.map
        perm = MapReduceFusion.find_permutation_reduce(
            tmap, rmap_cr, graph, memlet_edge.data)

        output_subset = [tmap.params[d] for d in perm]
        if len(output_subset) == 0:  # Output is a scalar
            output_subset = [0]

        array_edge = graph.out_edges(rmap_cr)[0]

        # Delete relevant edges and nodes
        graph.remove_edge(memlet_edge)
        graph.remove_nodes_from(nodes_to_remove)

        # Add new edges and nodes
        # From tasklet to map exit
        graph.add_edge(
            memlet_edge.src, memlet_edge.src_conn, memlet_edge.dst,
            memlet_edge.dst_conn,
            Memlet(out_array.data, memlet_edge.data.num_accesses,
                   subsets.Indices(output_subset), memlet_edge.data.veclen,
                   rmap_cr.wcr, rmap_cr.identity))

        # From map exit to output array ('IN_<n>' -> 'OUT_<n>')
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
            array_edge.dst, array_edge.dst_conn,
            Memlet(array_edge.data.data, array_edge.data.num_accesses,
                   array_edge.data.subset, array_edge.data.veclen,
                   rmap_cr.wcr, rmap_cr.identity))

        return

    # Remove tmp array node prior to the others, so that a new one
    # can be created in its stead (see below)
    graph.remove_node(nodes_to_remove[0])
    nodes_to_remove = nodes_to_remove[1:]

    # Create tasklet -> tmp -> tasklet connection
    tmp = graph.add_array(
        'tmp',
        memlet_edge.data.subset.bounding_box_size(),
        sdfg.arrays[memlet_edge.data.data].dtype,
        transient=True)
    tasklet_tmp_memlet = copy.deepcopy(memlet_edge.data)
    tasklet_tmp_memlet.data = tmp.data
    # NOTE(review): this assigns a string (not a subset object) to
    # ``subset`` -- confirm the Memlet in use accepts/parses it.
    tasklet_tmp_memlet.subset = ShapeProperty.to_string(tmp.shape)

    # Modify memlet to point to output array
    memlet_edge.data.data = out_array.data

    # Recover reduction axes from CR reduce subset: axes whose index
    # expression contains '__i' are treated as reduced.
    # NOTE(review): ``rmap_tasklet`` is only bound in the expr_index == 1
    # branch above; reaching this line with expr_index == 0 would raise
    # NameError -- confirm whether pattern 0 can get here.
    # Also assumes that subset is an ``Indices`` (uses ``.indices``).
    reduce_cr_subset = graph.in_edges(rmap_tasklet)[0].data.subset
    reduce_axes = []
    for ind, crvar in enumerate(reduce_cr_subset.indices):
        if '__i' in str(crvar):
            reduce_axes.append(ind)

    # Modify memlet access index by filtering out reduction axes
    if True:  # expr_index == 0:
        newindices = []
        for ind, ovar in enumerate(memlet_edge.data.subset.indices):
            if ind not in reduce_axes:
                newindices.append(ovar)
        # Everything was reduced: address the output as a scalar
        if len(newindices) == 0:
            newindices = [0]

        memlet_edge.data.subset = subsets.Indices(newindices)

    # The tasklet now writes into 'tmp' instead of the map exit
    graph.remove_edge(memlet_edge)
    graph.add_edge(memlet_edge.src, memlet_edge.src_conn, tmp,
                   memlet_edge.dst_conn, tasklet_tmp_memlet)

    red_edges = list(graph.in_edges(rmap_tasklet))
    if len(red_edges) != 1:
        raise RuntimeError('CR edge must be unique')

    # 'tmp' feeds the reduction tasklet through its existing input connector
    tmp_tasklet_memlet = copy.deepcopy(tasklet_tmp_memlet)
    graph.add_edge(tmp, None, rmap_tasklet, red_edges[0].dst_conn,
                   tmp_tasklet_memlet)

    for e in graph.edges_between(rmap_tasklet, rmap_cr):
        e.data.subset = memlet_edge.data.subset

    # Move output edges to point directly to CR node
    if expr_index == 1:
        # Set output memlet between CR node and outer reduction map to
        # contain the same subset as the one pointing to the CR node
        for e in graph.out_edges(rmap_cr):
            e.data.subset = memlet_edge.data.subset

        # NOTE(review): ``rmap_out`` is the same node as ``omap_exit`` here
        # (both come from ``_rmap_out_exit``), so this call re-sources edges
        # onto the node they already leave -- verify intent.
        rmap_out = gnode(MapReduceFusion._rmap_out_exit)
        nxutil.change_edge_src(graph, rmap_out, omap_exit)

    # Remove nodes
    graph.remove_nodes_from(nodes_to_remove)

    # For unrelated outputs, connect original output to rmap_out
    if expr_index == 1 and tmap_exit not in nodes_to_remove:
        other_out_edges = list(graph.out_edges(tmap_exit))
        for e in other_out_edges:
            graph.remove_edge(e)
            graph.add_edge(e.src, e.src_conn, omap_exit, None, e.data)
            graph.add_edge(omap_exit, None, e.dst, e.dst_conn,
                           copy.copy(e.data))
def _stripmine(self, sdfg, graph, candidate):
    """Strip-mine one dimension (``self.dim_idx``) of the candidate map.

    A new outer map (parameter chosen by ``self._find_new_dim``) iterates
    over the tiles of the selected dimension. The original map becomes the
    inner map and iterates over a single tile; if ``self.strided`` is set,
    it instead runs from the outer index to the original end with step
    ``tile_size``. The outer map takes over the original schedule, and the
    inner map is made sequential. Edges into the map entry and out of the
    map exit are rerouted through the new outer entry/exit.

    :param sdfg: SDFG that owns ``graph``; used to look up data descriptors.
    :param graph: State containing the candidate map.
    :param candidate: Pattern match; ``candidate[StripMining._map_entry]``
                      is the index of the map entry node in ``graph``.
    :return: ``(target_dim, new_dim, new_map)``: the strip-mined parameter,
             the new outer parameter and the new outer map.
    """
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[StripMining._map_entry]]
    map_exit = graph.exit_nodes(map_entry)[0]

    # Retrieve transformation properties.
    dim_idx = self.dim_idx
    new_dim_prefix = self.new_dim_prefix
    tile_size = self.tile_size
    divides_evenly = self.divides_evenly
    strided = self.strided

    # By default tiles are contiguous: the stride equals the tile size.
    tile_stride = self.tile_stride
    if tile_stride is None or len(tile_stride) == 0:
        tile_stride = tile_size

    # Retrieve parameter and range of dimension to be strip-mined.
    target_dim = map_entry.map.params[dim_idx]
    td_from, td_to, td_step = map_entry.map.range[dim_idx]

    # Create new map. Replace by cloning???
    new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                 target_dim)
    nd_from = 0
    nd_to = symbolic.pystr_to_symbolic(
        'int_ceil(%s + 1 - %s, %s) - 1' %
        (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride))
    nd_step = 1
    new_dim_range = (nd_from, nd_to, nd_step)
    new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                        subsets.Range([new_dim_range]))
    new_map_entry = nodes.MapEntry(new_map)
    new_map_exit = nodes.MapExit(new_map)

    # Change the range of the selected dimension to iterate over a single
    # tile
    if strided:
        td_from_new = symbolic.pystr_to_symbolic(new_dim)
        td_to_new_approx = td_to
        td_step = symbolic.pystr_to_symbolic(tile_size)
    else:
        td_from_new = symbolic.pystr_to_symbolic(
            '%s + %s * %s' %
            (symbolic.symstr(td_from), str(new_dim), tile_stride))
        td_to_new_exact = symbolic.pystr_to_symbolic(
            'min(%s + 1, %s + %s * %s + %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride,
             str(new_dim), tile_size))
        td_to_new_approx = symbolic.pystr_to_symbolic(
            '%s + %s * %s + %s - 1' %
            (symbolic.symstr(td_from), tile_stride, str(new_dim),
             tile_size))
    if divides_evenly or strided:
        td_to_new = td_to_new_approx
    else:
        # The last tile may be partial: keep the exact (min-clamped) bound
        td_to_new = dace.symbolic.SymExpr(td_to_new_exact, td_to_new_approx)
    map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

    # Make internal map's schedule to "not parallel"
    new_map.schedule = map_entry.map.schedule
    map_entry.map.schedule = dtypes.ScheduleType.Sequential

    # Redirect edges
    new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
    nxutil.change_edge_dest(graph, map_entry, new_map_entry)
    new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
    nxutil.change_edge_src(graph, map_exit, new_map_exit)

    # Create new entry edges. Numbered ('OUT_<n>') non-scalar memlets get a
    # subset widened over the inner map's range; memlets sharing data and
    # connector are merged into one.
    new_in_edges = dict()
    entry_in_conn = set()
    entry_out_conn = set()
    for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
        if (src_conn is not None and src_conn[:4] == 'OUT_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = src_conn[4:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_in_edges:
                old_subset = new_in_edges[key].subset
                new_in_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                entry_in_conn.add('IN_' + conn)
                entry_out_conn.add('OUT_' + conn)
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                new_in_edges[key] = new_memlet
        else:
            # Scalars and non-numbered connectors are copied unchanged
            if src_conn is not None and src_conn[:4] == 'OUT_':
                conn = src_conn[4:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                in_conn = src_conn
                out_conn = src_conn
            if in_conn:
                entry_in_conn.add(in_conn)
            if out_conn:
                entry_out_conn.add(out_conn)
            new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_entry.out_connectors = entry_out_conn
    map_entry.in_connectors = entry_in_conn
    for (_, in_conn, out_conn), memlet in new_in_edges.items():
        graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

    # Create new exit edges (mirror of the entry side, keyed on the exit
    # node's 'IN_<n>' input connectors)
    new_out_edges = dict()
    exit_in_conn = set()
    exit_out_conn = set()
    for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
        if (dst_conn is not None and dst_conn[:3] == 'IN_'
                and not isinstance(sdfg.arrays[memlet.data],
                                   dace.data.Scalar)):
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            conn = dst_conn[3:]
            key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
            if key in new_out_edges:
                old_subset = new_out_edges[key].subset
                new_out_edges[key].subset = calc_set_union(
                    old_subset, new_subset)
            else:
                exit_in_conn.add('IN_' + conn)
                exit_out_conn.add('OUT_' + conn)
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                new_out_edges[key] = new_memlet
        else:
            # Scalars and non-numbered connectors are copied unchanged.
            # Use this edge's own destination connector and record the edge
            # among the *exit* edges.
            if dst_conn is not None and dst_conn[:3] == 'IN_':
                conn = dst_conn[3:]
                in_conn = 'IN_' + conn
                out_conn = 'OUT_' + conn
            else:
                in_conn = dst_conn
                out_conn = dst_conn
            if in_conn:
                exit_in_conn.add(in_conn)
            if out_conn:
                exit_out_conn.add(out_conn)
            new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
    new_map_exit.in_connectors = exit_in_conn
    map_exit.out_connectors = exit_out_conn
    for (_, in_conn, out_conn), memlet in new_out_edges.items():
        graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map
def apply(self, sdfg):
    """Move the data used by the matched state into FPGA global memory.

    Array access nodes at the sources of the state are copied into
    transient ``fpga_<name>`` arrays (``StorageType.FPGA_Global``) in a new
    ``pre_<label>`` state, placed before the matched state. Targets of WCR
    memlets, including those reached through nested SDFGs, are copied as
    well. Array sink nodes are copied back from ``fpga_<name>`` in a new
    ``post_<label>`` state placed after it. Memlets in the matched state
    that refer to copied data are renamed to the FPGA copies.
    Non-array data (e.g. streams) is skipped.

    :param sdfg: SDFG containing the matched state.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Original data name -> descriptor of its 'fpga_<name>' copy
    fpga_data = {}

    def add_fpga_copy(name, desc):
        # Register the transient FPGA-global twin of array ``name``.
        fpga_data[name] = sdfg.add_array(
            'fpga_' + name,
            desc.shape,
            desc.dtype,
            materialize_func=desc.materialize_func,
            transient=True,
            storage=dtypes.StorageType.FPGA_Global,
            allow_conflicts=desc.allow_conflicts,
            strides=desc.strides,
            offset=desc.offset)

    # Input nodes may also be nodes with WCR memlets: their initial contents
    # must reach the FPGA too. We have to recur across nested SDFGs to find
    # them. The traced *nodes* (not their data names) are recorded, because
    # the loops below filter ``input_nodes`` by node type and test
    # membership in ``wcr_input_nodes`` with nodes.
    wcr_input_nodes = set()
    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if not isinstance(node, dace.graph.nodes.AccessNode):
            continue
        for e in graph.all_edges(node):
            if e.data.wcr is None:
                continue
            trace = dace.sdfg.trace_nested_access(node, graph,
                                                  parent_sdfg[graph])
            for node_trace, state_trace, sdfg_trace in trace:
                # Find the accessed node in our scope
                if state_trace == state and sdfg_trace == sdfg:
                    outer_node = node_trace
                    break
            else:
                # This does not trace back to the current state, so
                # we don't care
                continue
            if outer_node not in wcr_input_nodes:
                wcr_input_nodes.add(outer_node)
                if outer_node not in input_nodes:
                    input_nodes.append(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            # The copy below always writes 'fpga_<name>', so the FPGA array
            # has to exist for WCR targets as well.
            if node.data not in fpga_data:
                add_fpga_copy(node.data, desc)

            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            # WCR targets are written in this state, so they are left in
            # place here; when they are sink nodes, the output handling
            # below replaces them.
            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data not in fpga_data:
                add_fpga_copy(node.data, desc)

            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            fpga_node = state.add_write('fpga_' + node.data)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    def nested_veclen(nsdfg_node, conn, default):
        # Vector length of the memlets on the access node named ``conn``
        # inside ``nsdfg_node`` (the last match wins); ``default`` if none.
        result = default
        for inner_state in nsdfg_node.sdfg.states():
            for n in inner_state.nodes():
                if (isinstance(n, dace.graph.nodes.AccessNode)
                        and n.data == conn):
                    inner_edges = inner_state.all_edges(n)
                    if inner_edges:
                        # assuming all memlets have the same vector length
                        result = inner_edges[0].data.veclen
        return result

    # Rename memlets to the FPGA copies and propagate vector info from
    # nested SDFGs
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # Reset per edge so one edge's vector length does not leak into
        # the following, unrelated edges.
        veclen_ = 1
        if isinstance(dst, dace.graph.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            veclen_ = nested_veclen(dst, dst_conn, veclen_)
        if isinstance(src, dace.graph.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            veclen_ = nested_veclen(src, src_conn, veclen_)

        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
            mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def apply(self, sdfg):
    """Fuse the matched second state into the first state.

    Interstate assignments on the connecting edge(s) are copied onto every
    edge entering the first state, and the connecting edges are removed.
    If either state is empty, it is dropped and its edges are rerouted.
    Otherwise, the second state's nodes and edges are added to the first
    state. The second state's source access nodes are then merged with
    same-label access nodes of the first state, in this order:

    1. source nodes of the first state that are not also its sinks;
    2. sink nodes of the first state;
    3. any access node, searched in reverse topological order.

    Each merged node is marked ``ReadWrite``.

    :param sdfg: SDFG containing both matched states.
    """
    sdfg_nodes = sdfg.nodes()
    first_state = sdfg_nodes[self.subgraph[StateFusion._first_state]]
    second_state = sdfg_nodes[self.subgraph[StateFusion._second_state]]

    def first_with_label(candidates, label):
        # First candidate carrying ``label``, or None when there is none.
        return next((c for c in candidates if c.label == label), None)

    def access_nodes(candidates):
        # Keep only data access nodes.
        return [c for c in candidates if isinstance(c, nodes.AccessNode)]

    # Drop the connecting interstate edge(s); their assignments are moved
    # to every edge that enters the first state.
    for connecting in sdfg.edges_between(first_state, second_state):
        if connecting.data.assignments:
            for _src, _dst, incoming in sdfg.in_edges(first_state):
                incoming.assignments.update(connecting.data.assignments)
        sdfg.remove_edge(connecting)

    # Trivial fusions: one of the two states has no content.
    if first_state.is_empty():
        nxutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return
    if second_state.is_empty():
        nxutil.change_edge_src(sdfg, second_state, first_state)
        nxutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # Boundary access nodes of both states. The first state's inputs exclude
    # any label that also appears among its outputs.
    first_output = access_nodes(nxutil.find_sink_nodes(first_state))
    second_input = access_nodes(nxutil.find_source_nodes(second_state))
    first_input = [
        candidate
        for candidate in access_nodes(nxutil.find_source_nodes(first_state))
        if first_with_label(first_output, candidate.label) is None
    ]

    # Snapshot the first state's access nodes (reverse topological order)
    # before the second state's nodes are added.
    order = [
        candidate
        for candidate in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(candidate, nodes.AccessNode)
    ]

    # Copy the whole second state into the first one.
    for moved in second_state.nodes():
        first_state.add_node(moved)
    for u, u_conn, v, v_conn, payload in second_state.edges():
        first_state.add_edge(u, u_conn, v, v_conn, payload)

    # Same-label inputs of both states share the first state's node.
    for kept in first_input:
        duplicate = first_with_label(second_input, kept.label)
        if duplicate is None:
            continue
        nxutil.change_edge_src(first_state, duplicate, kept)
        first_state.remove_node(duplicate)
        second_input.remove(duplicate)
        kept.access = dtypes.AccessType.ReadWrite

    # First-state outputs read by the second state: keep the second state's
    # node and route the first state's writes into it.
    for written in first_output:
        reader = first_with_label(second_input, written.label)
        if reader is None:
            continue
        nxutil.change_edge_dest(first_state, written, reader)
        first_state.remove_node(written)
        second_input.remove(reader)
        reader.access = dtypes.AccessType.ReadWrite

    # Remaining second-state inputs may match non-input/output access nodes
    # of the first state.
    for leftover in second_input:
        if first_state.in_degree(leftover) != 0:
            continue
        earlier = first_with_label(order, leftover.label)
        if not earlier:
            continue
        nxutil.change_edge_src(first_state, leftover, earlier)
        first_state.remove_node(leftover)
        earlier.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    nxutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def __stripmine(self, sdfg, graph, candidate):
    """Strip-mine one dimension (``self.dim_idx``) of the candidate map.

    A new outer map with parameter ``<new_dim_prefix>_<param>`` iterates
    over the tiles of the selected dimension. The original map becomes the
    inner map and iterates over a single tile; if ``self.strided`` is set,
    it instead runs from the outer index to the original end with step
    ``tile_size``. Edges into the map entry are rerouted through a new
    outer entry. Each ``MapExit`` among the map's exit nodes gets its own
    new outer exit.

    :param sdfg: SDFG that owns ``graph``; used to look up data descriptors.
    :param graph: State containing the candidate map.
    :param candidate: Pattern match; ``candidate[StripMining._map_entry]``
                      is the index of the map entry node in ``graph``.
    :return: ``(target_dim, new_dim, new_map)``: the strip-mined parameter,
             the new outer parameter and the new outer map.
    """
    # Retrieve map entry and exit nodes.
    map_entry = graph.nodes()[candidate[StripMining._map_entry]]
    map_exits = graph.exit_nodes(map_entry)

    # Retrieve transformation properties.
    dim_idx = self.dim_idx
    new_dim_prefix = self.new_dim_prefix
    tile_size = self.tile_size
    divides_evenly = self.divides_evenly
    strided = self.strided

    # Retrieve parameter and range of dimension to be strip-mined.
    target_dim = map_entry.map.params[dim_idx]
    td_from, td_to, td_step = map_entry.map.range[dim_idx]

    # Create new map. Replace by cloning???
    new_dim = new_dim_prefix + '_' + target_dim
    nd_from = 0
    nd_to = symbolic.pystr_to_symbolic(
        'int_ceil(%s + 1 - %s, %s) - 1' %
        (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_size))
    nd_step = 1
    new_dim_range = (nd_from, nd_to, nd_step)
    new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                        subsets.Range([new_dim_range]))
    new_map_entry = nodes.MapEntry(new_map)

    # Change the range of the selected dimension to iterate over a single
    # tile
    if strided:
        td_from_new = symbolic.pystr_to_symbolic(new_dim)
        td_to_new_approx = td_to
        td_step = symbolic.pystr_to_symbolic(tile_size)
    else:
        td_from_new = symbolic.pystr_to_symbolic(
            '%s + %s * %s' %
            (symbolic.symstr(td_from), str(new_dim), tile_size))
        td_to_new_exact = symbolic.pystr_to_symbolic(
            'min(%s + 1, %s + %s * %s + %s) - 1' %
            (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_size,
             str(new_dim), tile_size))
        td_to_new_approx = symbolic.pystr_to_symbolic(
            '%s + %s * %s + %s - 1' %
            (symbolic.symstr(td_from), tile_size, str(new_dim), tile_size))
    if divides_evenly or strided:
        td_to_new = td_to_new_approx
    else:
        # The last tile may be partial: keep the exact (min-clamped) bound
        td_to_new = dace.symbolic.SymExpr(td_to_new_exact, td_to_new_approx)
    map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

    # Make internal map's schedule to "not parallel"
    map_entry.map._schedule = dtypes.ScheduleType.Default

    # Redirect/create edges. Memlets of the same (non-scalar) data are
    # merged into one outer edge whose connector number is the smallest of
    # the original connector numbers.
    new_in_edges = {}
    for _src, conn, _dest, _, memlet in graph.out_edges(map_entry):
        if isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
            continue
        new_subset = calc_set_image(
            map_entry.map.params,
            map_entry.map.range,
            memlet.subset,
        )
        conn_num = int(conn[4:])  # 'OUT_<n>' -> n
        if memlet.data in new_in_edges:
            src, src_conn, dest, dest_conn, new_memlet, num = \
                new_in_edges[memlet.data]
            new_memlet.subset = calc_set_union(new_memlet.subset, new_subset)
            new_memlet.num_accesses = new_memlet.num_elements()
            new_in_edges[memlet.data] = (src, src_conn, dest, dest_conn,
                                         new_memlet, min(num, conn_num))
        else:
            new_memlet = dcpy(memlet)
            new_memlet.subset = new_subset
            new_memlet.num_accesses = new_memlet.num_elements()
            new_in_edges[memlet.data] = (new_map_entry, None, map_entry,
                                         None, new_memlet, conn_num)
    nxutil.change_edge_dest(graph, map_entry, new_map_entry)

    # One new outer exit per original MapExit. Entries are keyed by
    # (original exit, data) so that edges of different exits never merge.
    new_out_edges = {}
    new_exits = []
    for map_exit in map_exits:
        if not isinstance(map_exit, nodes.MapExit):
            continue
        new_exit = nodes.MapExit(new_map)
        new_exits.append(new_exit)
        for _src, _, _dest, dst_conn, memlet in graph.in_edges(map_exit):
            if isinstance(sdfg.arrays[memlet.data], dace.data.Scalar):
                continue
            new_subset = calc_set_image(
                map_entry.map.params,
                map_entry.map.range,
                memlet.subset,
            )
            # Number of the exit's own 'IN_<n>' connector, compared as int
            conn_num = int(dst_conn[3:])
            key = (map_exit, memlet.data)
            if key in new_out_edges:
                src, src_conn, dest, dest_conn, new_memlet, num = \
                    new_out_edges[key]
                new_memlet.subset = calc_set_union(new_memlet.subset,
                                                   new_subset)
                new_memlet.num_accesses = new_memlet.num_elements()
                new_out_edges[key] = (src, src_conn, dest, dest_conn,
                                      new_memlet, min(num, conn_num))
            else:
                new_memlet = dcpy(memlet)
                new_memlet.subset = new_subset
                new_memlet.num_accesses = new_memlet.num_elements()
                new_out_edges[key] = (map_exit, None, new_exit, None,
                                      new_memlet, conn_num)
        nxutil.change_edge_src(graph, map_exit, new_exit)

    # Renumber entry connectors densely (1..k) and create the entry edges
    in_conn_nums = [entry[-1] for entry in new_in_edges.values()]
    in_conn = {num: i + 1 for i, num in enumerate(in_conn_nums)}
    new_map_entry.in_connectors = {
        'IN_' + str(i + 1) for i in range(len(in_conn_nums))
    }
    new_map_entry.out_connectors = {
        'OUT_' + str(i + 1) for i in range(len(in_conn_nums))
    }
    for src, _, dst, _, memlet, num in new_in_edges.values():
        graph.add_edge(src, 'OUT_' + str(in_conn[num]), dst,
                       'IN_' + str(in_conn[num]), memlet)

    # Same for every new exit, considering only the edges that belong to it
    for new_exit in new_exits:
        exit_edges = [
            entry for entry in new_out_edges.values() if entry[2] is new_exit
        ]
        out_conn = {entry[-1]: i + 1 for i, entry in enumerate(exit_edges)}
        new_exit.in_connectors = {
            'IN_' + str(i + 1) for i in range(len(exit_edges))
        }
        new_exit.out_connectors = {
            'OUT_' + str(i + 1) for i in range(len(exit_edges))
        }
        for src, _, dst, _, memlet, num in exit_edges:
            graph.add_edge(src, 'OUT_' + str(out_conn[num]), dst,
                           'IN_' + str(out_conn[num]), memlet)

    # Return strip-mined dimension.
    return target_dim, new_dim, new_map
def apply(self, sdfg):
    """Fuse the matched second state into the matched first state.

    Interstate edges from the first to the second state are removed. Any
    assignments they carry are copied onto every edge entering the first
    state.

    Special cases:
      * empty first state: its incoming edges are redirected to the second
        state, and the first state is removed;
      * empty second state: its incoming and outgoing edges are moved onto
        the first state, and the second state is removed.

    Otherwise, all nodes and edges of the second state are added to the
    first state, and matching (by label) data nodes are merged:
      * a second-state source access node whose label matches a first-state
        source access node (one that is not also a sink by label) is folded
        into that first-state node;
      * a first-state sink access node whose label matches a remaining
        second-state source access node is replaced by the latter.

    Finally, the second state's outgoing interstate edges are moved to the
    first state, and the second state is removed.

    :param sdfg: The SDFG containing both states; modified in place.
    """
    all_states = sdfg.nodes()
    first_state = all_states[self.subgraph[StateFusion._first_state]]
    second_state = all_states[self.subgraph[StateFusion._second_state]]

    # Drop the connecting interstate edge(s), forwarding their symbol
    # assignments to the edges that lead into the first state.
    connecting = sdfg.edges_between(first_state, second_state)
    for link in connecting:
        if link.data.assignments:
            for _, _, incoming_data in sdfg.in_edges(first_state):
                incoming_data.assignments.update(link.data.assignments)
        sdfg.remove_edge(link)

    if first_state.is_empty():
        # Nothing to merge: let the predecessors jump straight to state 2.
        nxutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    if second_state.is_empty():
        # Nothing to merge: splice state 2 out, keeping its connectivity.
        nxutil.change_edge_src(sdfg, second_state, first_state)
        nxutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    def _access_only(candidates):
        # Keep only data (access) nodes from a source/sink node listing.
        return [cand for cand in candidates
                if isinstance(cand, nodes.AccessNode)]

    def _same_label(target, pool):
        # First node in ``pool`` sharing ``target``'s label, else None.
        return next((cand for cand in pool if cand.label == target.label),
                    None)

    sinks_of_first = _access_only(nxutil.find_sink_nodes(first_state))
    sources_of_second = _access_only(nxutil.find_source_nodes(second_state))
    # Sources of the first state, minus those also appearing as its sinks.
    sources_of_first = [
        src_node
        for src_node in _access_only(nxutil.find_source_nodes(first_state))
        if not any(snk.label == src_node.label for snk in sinks_of_first)
    ]

    # Transplant the whole second state into the first one.
    for moved in second_state.nodes():
        first_state.add_node(moved)
    for src, src_conn, dst, dst_conn, edge_data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, edge_data)

    # Fold duplicated read nodes of state 2 into the matching reads of
    # state 1.
    for keeper in sources_of_first:
        duplicate = _same_label(keeper, sources_of_second)
        if duplicate is None:
            continue
        nxutil.change_edge_src(first_state, duplicate, keeper)
        first_state.remove_node(duplicate)
        sources_of_second.remove(duplicate)

    # Let state 1's written nodes feed directly into state 2's readers.
    for writer in sinks_of_first:
        reader = _same_label(writer, sources_of_second)
        if reader is None:
            continue
        nxutil.change_edge_dest(first_state, writer, reader)
        first_state.remove_node(writer)
        sources_of_second.remove(reader)

    # Redirect edges and remove second state
    nxutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def apply(self, sdfg):
    """Offload the matched state's array I/O to FPGA global memory.

    Every source/sink ``AccessNode`` of the matched state whose descriptor
    is a ``dace.data.Array`` is swapped for a node of a new transient
    ``'fpga_' + name`` array with ``FPGA_Global`` storage.
      * Host-to-FPGA copies go into a new ``'pre_' + label`` state, placed
        before the matched state.
      * FPGA-to-host copies go into a new ``'post_' + label`` state, placed
        after it.
    Memlets inside the matched state that referenced a transferred array
    are then renamed to that array's FPGA counterpart. Finally,
    ``fpga_update(state, 0)`` is invoked.

    :param sdfg: SDFG containing the matched state; modified in place.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Original data name -> FPGA array created for it. Both passes (and the
    # final memlet rename) key this dict by ``node.data``, so an array that
    # is both read and written reuses a single FPGA copy.
    fpga_data = {}

    if input_nodes:
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)
        for node in input_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Copy the original array into its FPGA counterpart in the
            # pre-state.
            fpga_node = type(node)(fpga_array)
            pre_state.add_node(node)
            pre_state.add_node(fpga_node)

            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(array, full_range.num_elements(), full_range,
                                1)
            pre_state.add_edge(node, None, fpga_node, None, mem)

            # Inside the matched state, read from the FPGA node instead.
            state.add_node(fpga_node)
            nxutil.change_edge_src(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)
        for node in output_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Copy the FPGA array back into the original in the post-state.
            fpga_node = type(node)(fpga_array)
            post_state.add_node(node)
            post_state.add_node(fpga_node)

            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                full_range, 1)
            post_state.add_edge(fpga_node, None, node, None, mem)

            # Inside the matched state, write to the FPGA node instead.
            state.add_node(fpga_node)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    # Point every memlet that referenced a transferred array at its FPGA
    # copy. Each memlet is renamed after its OWN data; the previous code
    # used the leftover loop variable ``node``, which renamed all of them
    # to the last processed array and raised NameError when no I/O nodes
    # existed.
    for _, _, _, _, mem in state.edges():
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data

    fpga_update(state, 0)
def apply(self, sdfg):
    """Offload the matched state's array I/O to FPGA global memory.

    Source/sink array ``AccessNode``s of the matched state are replaced by
    nodes of transient ``'fpga_' + name`` arrays with ``FPGA_Global``
    storage.
      * Host-to-FPGA copies are emitted into a new ``'pre_' + label``
        state.
      * FPGA-to-host copies are emitted into a new ``'post_' + label``
        state.
      * Access nodes targeted by WCR memlets inside nested SDFGs are also
        copied into their FPGA array in the pre-state.
    Memlets of the matched state that refer to a transferred array are
    renamed to the FPGA array. They are also given a vector length read
    from nested SDFG memlets. Finally, ``fpga_update(sdfg, state, 0)`` is
    invoked.

    :param sdfg: SDFG containing the matched state; modified in place.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()

    for node, graph in state.all_nodes_recursive():
        if isinstance(node, dace.graph.nodes.AccessNode):
            for e in graph.all_edges(node):
                if e.data.wcr is not None:
                    # This is an output node with wcr
                    # find the target in the parent sdfg
                    # following the structure State->SDFG->State-> SDFG
                    # from the current_state we have to go two levels up
                    parent_state = graph.parent.parent
                    if parent_state is not None:
                        for parent_edges in parent_state.edges():
                            if parent_edges.src_conn == e.dst.data or (
                                    isinstance(parent_edges.dst,
                                               dace.graph.nodes.AccessNode)
                                    and e.dst.data == parent_edges.dst.data):
                                # This must be copied to device. Record each
                                # target only once: several matching edges
                                # used to append it repeatedly, emitting a
                                # redundant copy per duplicate.
                                target = parent_edges.dst
                                if target not in wcr_input_nodes:
                                    input_nodes.append(target)
                                    wcr_input_nodes.add(target)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:

            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.shape,
                    array.dtype,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Host -> FPGA copy in the pre-state.
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            # WCR targets stay in place here; they are output nodes and are
            # rerouted by the output pass below.
            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:

        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:

            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.shape,
                    array.dtype,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # FPGA -> host copy in the post-state.
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            fpga_node = state.add_write('fpga_' + node.data)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    # NOTE(review): veclen_ is not reset per edge, so a value picked up from
    # one nested-SDFG edge is also applied to every later renamed memlet
    # (depends on state.edges() order) — confirm this is intended.
    veclen_ = 1

    # propagate vector info from a nested sdfg
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # need to go inside the nested SDFG and grab the vector length
        if isinstance(dst, dace.graph.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            for inner_state in dst.sdfg.states():
                for n in inner_state.nodes():
                    if (isinstance(n, dace.graph.nodes.AccessNode)
                            and n.data == dst_conn):
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen
        if isinstance(src, dace.graph.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            for inner_state in src.sdfg.states():
                for n in inner_state.nodes():
                    if (isinstance(n, dace.graph.nodes.AccessNode)
                            and n.data == src_conn):
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen

        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
            mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def apply(self, sdfg: sd.SDFG):
    """Convert ``sdfg`` so that its computation runs on the GPU.

    The numbered steps below do the following:
      0. Collect non-transient data that is read (except data feeding
         dynamic map ranges) or written. Also collect non-transient WCR
         targets without identity, and top-level code nodes.
      1. Clone each collected non-scalar input, and each output, into a
         transient ``GPU_Global`` descriptor. Rename access nodes and
         memlets to the clone.
      2. Prepend a host-to-GPU copy-in state (minus ``exclude_copyin``).
      3. Append a GPU-to-host copy-out state after every original sink
         state (minus ``exclude_copyout``).
      4. Give top-level transients ``GPU_Global`` storage. Optionally mark
         them ``toplevel``, and optionally make nested transients
         registers.
      5. Wrap free top-level code nodes (except ``EmptyTasklet`` and
         ``exclude_tasklets``) in a one-iteration ``GPU_Device`` map.
      6. Schedule top-level maps/reduces on ``GPU_Device``. Optionally make
         inner maps ``Sequential``.
      7. Copy arrays back to the host before interstate edges whose
         conditions use them.
      8. If ``strict_transform`` is set, apply strict transformations.

    :param sdfg: The SDFG to transform; modified in place.
    """
    #######################################################
    # Step 0: SDFG metadata

    # Find all input and output data descriptors. The lists hold unique
    # (name, descriptor) pairs in first-seen order. Repeats are detected via
    # the companion name sets: a name is never equal to a (name, descriptor)
    # tuple, so testing the lists themselves never found a repeat.
    input_nodes = []
    output_nodes = []
    input_names = set()
    output_names = set()
    global_code_nodes = [[] for _ in sdfg.nodes()]

    for i, state in enumerate(sdfg.nodes()):
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and not node.desc(sdfg).transient):
                if (state.out_degree(node) > 0
                        and node.data not in input_names):
                    # Special case: nodes that lead to dynamic map ranges
                    # must stay on host
                    for e in state.out_edges(node):
                        last_edge = state.memlet_path(e)[-1]
                        if (isinstance(last_edge.dst, nodes.EntryNode)
                                and last_edge.dst_conn
                                and not last_edge.dst_conn.startswith('IN_')):
                            break
                    else:
                        input_nodes.append((node.data, node.desc(sdfg)))
                        input_names.add(node.data)
                if (state.in_degree(node) > 0
                        and node.data not in output_names):
                    output_nodes.append((node.data, node.desc(sdfg)))
                    output_names.add(node.data)
            elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                if not isinstance(node, nodes.EmptyTasklet):
                    global_code_nodes[i].append(node)

        # Input nodes may also be nodes with WCR memlets and no identity
        for e in state.edges():
            if e.data.wcr is not None and e.data.wcr_identity is None:
                wcr_data = e.data.data
                if (wcr_data not in input_names
                        and not sdfg.arrays[wcr_data].transient):
                    input_nodes.append((wcr_data, sdfg.arrays[wcr_data]))
                    input_names.add(wcr_data)

    start_state = sdfg.start_state
    end_states = sdfg.sink_nodes()

    #######################################################
    # Step 1: Create cloned GPU arrays and replace originals

    def clone_to_gpu(dataname, desc):
        # Register a transient GPU_Global clone of ``desc`` and return the
        # name add_datadesc registered it under.
        newdesc = desc.clone()
        newdesc.storage = dtypes.StorageType.GPU_Global
        newdesc.transient = True
        return sdfg.add_datadesc('gpu_' + dataname, newdesc,
                                 find_new_name=True)

    cloned_arrays = {}
    for inodename, inode in input_nodes:
        if isinstance(inode, data.Scalar):
            # Scalars can remain on host
            continue
        cloned_arrays[inodename] = clone_to_gpu(inodename, inode)
    for onodename, onode in output_nodes:
        if onodename in cloned_arrays:
            continue
        cloned_arrays[onodename] = clone_to_gpu(onodename, onode)

    # Replace nodes
    for state in sdfg.nodes():
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data in cloned_arrays):
                node.data = cloned_arrays[node.data]

    # Replace memlets
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data in cloned_arrays:
                edge.data.data = cloned_arrays[edge.data.data]

    def add_copy(copy_state, host_name, gpu_name, debuginfo, to_gpu):
        # Add a full-array copy between a host array and its GPU clone to
        # ``copy_state``. The memlet always describes the host array.
        host_node = nodes.AccessNode(host_name, debuginfo=debuginfo)
        gpu_node = nodes.AccessNode(gpu_name, debuginfo=debuginfo)
        if to_gpu:
            src_node, dst_node = host_node, gpu_node
        else:
            src_node, dst_node = gpu_node, host_node
        copy_state.add_node(src_node)
        copy_state.add_node(dst_node)
        copy_state.add_nedge(
            src_node, dst_node,
            memlet.Memlet.from_array(host_name, host_node.desc(sdfg)))

    #######################################################
    # Step 2: Create copy-in state
    excluded_copyin = set(self.exclude_copyin.split(','))

    copyin_state = sdfg.add_state(sdfg.label + '_copyin')
    sdfg.add_edge(copyin_state, start_state, ed.InterstateEdge())

    for nname, desc in input_nodes:
        if nname in excluded_copyin or nname not in cloned_arrays:
            continue
        add_copy(copyin_state, nname, cloned_arrays[nname], desc.debuginfo,
                 to_gpu=True)

    #######################################################
    # Step 3: Create copy-out state
    excluded_copyout = set(self.exclude_copyout.split(','))

    copyout_state = sdfg.add_state(sdfg.label + '_copyout')
    for state in end_states:
        sdfg.add_edge(state, copyout_state, ed.InterstateEdge())

    for nname, desc in output_nodes:
        if nname in excluded_copyout or nname not in cloned_arrays:
            continue
        add_copy(copyout_state, nname, cloned_arrays[nname], desc.debuginfo,
                 to_gpu=False)

    #######################################################
    # Step 4: Modify transient data storage

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, nodes.AccessNode) and node.desc(sdfg).transient:
                nodedesc = node.desc(sdfg)

                # Special case: nodes that lead to dynamic map ranges must
                # stay on host
                if any(
                        isinstance(state.memlet_path(e)[-1].dst,
                                   nodes.EntryNode)
                        for e in state.out_edges(node)):
                    continue

                if sdict[node] is None:
                    # NOTE: the cloned arrays match too but it's the same
                    # storage so we don't care
                    nodedesc.storage = dtypes.StorageType.GPU_Global

                    # Try to move allocation/deallocation out of loops
                    if (self.toplevel_trans
                            and not isinstance(nodedesc, data.Stream)):
                        nodedesc.toplevel = True
                else:
                    # Make internal transients registers
                    if self.register_trans:
                        nodedesc.storage = dtypes.StorageType.Register

    #######################################################
    # Step 5: Wrap free tasklets and nested SDFGs with a GPU map
    excluded_tasklets = self.exclude_tasklets.split(',')

    for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
        for gcode in gcodes:
            if gcode.label in excluded_tasklets:
                continue

            # Create map and connectors
            me, mx = state.add_map(
                gcode.label + '_gmap', {gcode.label + '__gmapi': '0:1'},
                schedule=dtypes.ScheduleType.GPU_Device)
            # Store in/out edges in lists so that they don't get corrupted
            # when they are removed from the graph
            in_edges = list(state.in_edges(gcode))
            out_edges = list(state.out_edges(gcode))
            me.in_connectors = set('IN_' + e.dst_conn for e in in_edges)
            me.out_connectors = set('OUT_' + e.dst_conn for e in in_edges)
            mx.in_connectors = set('IN_' + e.src_conn for e in out_edges)
            mx.out_connectors = set('OUT_' + e.src_conn for e in out_edges)

            # Create memlets through map
            for e in in_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                               e.data)
                state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                               e.data)
            for e in out_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                               e.data)
                state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                               e.data)

            # Map without inputs
            if not in_edges:
                state.add_nedge(me, gcode, memlet.EmptyMemlet())

    #######################################################
    # Step 6: Change all top-level maps and Reduce nodes to GPU schedule

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, (nodes.EntryNode, nodes.Reduce)):
                if sdict[node] is None:
                    node.schedule = dtypes.ScheduleType.GPU_Device
                elif (isinstance(node, nodes.EntryNode)
                      and self.sequential_innermaps):
                    node.schedule = dtypes.ScheduleType.Sequential

    #######################################################
    # Step 7: Introduce copy-out if data used in outgoing interstate edges

    cloned_names = set(cloned_arrays.keys())
    for state in list(sdfg.nodes()):
        arrays_used = set()
        for e in sdfg.out_edges(state):
            # Used arrays = intersection between symbols and cloned arrays
            arrays_used.update(
                set(e.data.condition_symbols()) & cloned_names)

        # Create a state and copy out used arrays
        if arrays_used:
            co_state = sdfg.add_state(state.label + '_icopyout')

            # Reconnect outgoing edges to after interim copyout state.
            # change_edge_src moves all of the state's out-edges in one
            # call; the previous per-edge loop repeated it needlessly.
            nxutil.change_edge_src(sdfg, state, co_state)
            # Add unconditional edge to interim state
            sdfg.add_edge(state, co_state, ed.InterstateEdge())

            # Add copy-out nodes
            for nname in arrays_used:
                desc = sdfg.arrays[nname]
                add_copy(co_state, nname, cloned_arrays[nname],
                         desc.debuginfo, to_gpu=False)

    #######################################################
    # Step 8: Strict transformations
    if not self.strict_transform:
        return

    # Apply strict state fusions greedily.
    sdfg.apply_strict_transformations()