def gemv_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, alpha, beta, trans=None):
    """
    Insert a ``Gemv`` library node into ``state`` and connect it to ``A``, ``x`` and ``y``.

    :param pv: Program visitor (not used by this function).
    :param sdfg: SDFG whose ``arrays`` are consulted when ``trans`` is ``None``.
    :param state: State that receives the access nodes, the library node and the edges.
    :param A: Name of the data connected to the ``_A`` input connector.
    :param x: Name of the data connected to the ``_x`` input connector.
    :param y: Name of the data connected to the ``_y`` connector(s).
    :param alpha: Forwarded to ``Gemv`` as ``alpha``.
    :param beta: Forwarded to ``Gemv`` as ``beta``. If it differs from ``0``, ``y`` is
                 additionally read and fed into the node.
    :param trans: Forwarded to ``Gemv`` as ``transA``. If ``None``, it is set to whether
                  the first dimension of ``x`` equals the first dimension of ``A``.
    :return: An empty list.
    """
    # Derive the transposition flag from the shapes when the caller left it open
    if trans is None:
        x_rows = sdfg.arrays[x].shape[0]
        a_rows = sdfg.arrays[A].shape[0]
        trans = (x_rows == a_rows)

    # Access nodes and library node
    matrix_node = state.add_read(A)
    vector_node = state.add_read(x)
    result_node = state.add_write(y)
    gemv = Gemv('gemv', transA=trans, alpha=alpha, beta=beta)
    state.add_node(gemv)

    # Wiring
    for access, conn, dataname in ((matrix_node, '_A', A), (vector_node, '_x', x)):
        state.add_edge(access, None, gemv, conn, mm.Memlet(dataname))
    state.add_edge(gemv, '_y', result_node, None, mm.Memlet(y))

    # A non-zero beta means the previous contents of y contribute to the result
    if beta != 0:
        accumulate_node = state.add_read(y)
        state.add_edge(accumulate_node, None, gemv, '_y', mm.Memlet(y))

    return []
def apply(self, sdfg: SDFG) -> nodes.AccessNode:
    """
    Replace the matched array access node by a stream.

    A new transient stream is created, every memlet on the memlet paths of the
    first incoming and the first outgoing edge of the access node is rewritten
    to access element ``0`` of that stream, and the access node itself is
    replaced by a separate write node and read node of the stream.

    :param sdfg: The SDFG the transformation is applied to.
    :return: The pair ``(write node, read node)`` of new stream access nodes.
             NOTE: this is a tuple, although the annotation names a single node.
    """
    state = sdfg.node(self.state_id)
    access: nodes.AccessNode = self.access(sdfg)

    # Memlet paths on both sides of the access node
    first_edge = state.in_edges(access)[0]
    second_edge = state.out_edges(access)[0]
    first_mpath = state.memlet_path(first_edge)
    second_mpath = state.memlet_path(second_edge)

    # Create new stream of shape 1
    desc = sdfg.arrays[access.data]
    name, newdesc = sdfg.add_stream(access.data,
                                    desc.dtype,
                                    buffer_size=self.buffer_size,
                                    storage=self.storage,
                                    transient=True,
                                    find_new_name=True)

    # The original array descriptor can go if no other state accesses it
    used_elsewhere = any(ostate is not state and any(n.data == access.data for n in ostate.data_nodes())
                         for ostate in sdfg.nodes())
    if not used_elsewhere:
        del sdfg.arrays[access.data]

    # Replace memlets in both paths with stream accesses
    for path in (first_mpath, second_mpath):
        for e in path:
            e.data = mm.Memlet(data=name, subset='0')
            for endpoint, conn in ((e.src, e.src_conn), (e.dst, e.dst_conn)):
                # Nested SDFGs must see the stream as well
                if isinstance(endpoint, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(endpoint, conn, newdesc)

    # Replace array access node with two stream access nodes
    wnode = state.add_write(name)
    rnode = state.add_read(name)
    state.remove_edge(first_edge)
    state.add_edge(first_edge.src, first_edge.src_conn, wnode, first_edge.dst_conn, first_edge.data)
    state.remove_edge(second_edge)
    state.add_edge(rnode, second_edge.src_conn, second_edge.dst, second_edge.dst_conn, second_edge.data)

    # Remove original access node
    state.remove_node(access)

    return wnode, rnode
def bmmnode(pv, sdfg: dace.SDFG, state: dace.SDFGState, A, B, C, alpha=1, beta=0, trans_a=False, trans_b=False):
    """
    Insert a ``BatchedMatMul`` library node into ``state`` computing into ``C`` from ``A`` and ``B``.

    :param pv: Program visitor (not used by this function).
    :param sdfg: The SDFG (not used by this function).
    :param state: State that receives the access nodes, the library node and the edges.
    :param A: Name of the data connected to the ``_a`` input connector.
    :param B: Name of the data connected to the ``_b`` input connector.
    :param C: Name of the data connected to the ``_c`` output connector.
    :param alpha: Stored on the node as ``alpha``.
    :param beta: Stored on the node as ``beta``.
    :param trans_a: Stored on the node as ``transA``.
    :param trans_b: Stored on the node as ``transB``.
    :return: An empty list.
    """
    # Access nodes
    lhs_node = state.add_read(A)
    rhs_node = state.add_read(B)
    out_node = state.add_write(C)

    # Library node and its properties
    bmm = BatchedMatMul('bmm')
    bmm.alpha = alpha
    bmm.beta = beta
    bmm.transA = trans_a
    bmm.transB = trans_b
    state.add_node(bmm)

    # Wiring
    for access, conn, dataname in ((lhs_node, '_a', A), (rhs_node, '_b', B)):
        state.add_edge(access, None, bmm, conn, mm.Memlet(dataname))
    state.add_edge(bmm, '_c', out_node, None, mm.Memlet(C))

    return []
def _handle_connectors(state: sd.SDFGState, node: nodes.Tasklet, mapping: Dict[str, Tuple[str, subsets.Range]],
                       ignore: Set[str], in_edges: bool) -> bool:
    """
    Adds new connectors and removes unused connectors after indirection
    promotion.

    :param state: State containing ``node``.
    :param node: Tasklet whose connectors are rewritten.
    :param mapping: Maps each new connector name to a pair of
                    (original connector name, subset for the new memlet).
    :param ignore: Original connector names that must be kept.
    :param in_edges: If True, operate on input connectors/incoming edges,
                     otherwise on output connectors/outgoing edges.
    :note: Despite the ``bool`` annotation there is no return statement, so
           the function returns ``None``.
    """
    # Index the existing edges by the connector they use on the tasklet
    orig_edges = {}
    if in_edges:
        for e in state.in_edges(node):
            orig_edges[e.dst_conn] = e
        add_connector = node.add_in_connector
        remove_connector = node.remove_in_connector
    else:
        for e in state.out_edges(node):
            orig_edges[e.src_conn] = e
        add_connector = node.add_out_connector
        remove_connector = node.remove_out_connector

    # Create each new connector together with an edge that mirrors the
    # original edge, but carries the promoted subset
    for cname, (orig, subset) in mapping.items():
        add_connector(cname)
        orig_edge = orig_edges[orig]
        new_memlet = mm.Memlet(data=orig_edge.data.data, subset=subset)
        if in_edges:
            state.add_edge(orig_edge.src, orig_edge.src_conn, orig_edge.dst, cname, new_memlet)
        else:
            state.add_edge(orig_edge.src, cname, orig_edge.dst, orig_edge.dst_conn, new_memlet)

    # Drop the replaced connectors and their edges, unless asked to keep them
    for conn in {orig for orig, _ in mapping.values()} - ignore:
        state.remove_edge(orig_edges[conn])
        remove_connector(conn)
def gemv_libnode(sdfg: SDFG, state: SDFGState, A, B, C, alpha, beta, trans_a=False, trans_b=False):
    """
    Insert a ``Gemm`` library node into ``state`` and connect it to ``A``, ``B`` and ``C``.

    :param sdfg: The SDFG (not used by this function).
    :param state: State that receives the access nodes, the library node and the edges.
    :param A: Name of the data connected to the ``_a`` input connector.
    :param B: Name of the data connected to the ``_b`` input connector.
    :param C: Name of the data connected to the ``_c`` output connector (and to
              ``_cin`` if ``beta`` is non-zero).
    :param alpha: Forwarded to ``Gemm`` as ``alpha``.
    :param beta: Forwarded to ``Gemm`` as ``beta``.
    :param trans_a: Forwarded to ``Gemm`` as ``transA``.
    :param trans_b: Forwarded to ``Gemm`` as ``transB``.
    :return: An empty list.
    :note: The function is named ``gemv_libnode`` although it creates a ``Gemm`` node.
    """
    # Access nodes and library node
    lhs_node = state.add_read(A)
    rhs_node = state.add_read(B)
    out_node = state.add_write(C)
    gemm = Gemm('gemm', transA=trans_a, transB=trans_b, alpha=alpha, beta=beta)
    state.add_node(gemm)

    # Wiring
    for access, conn, dataname in ((lhs_node, '_a', A), (rhs_node, '_b', B)):
        state.add_edge(access, None, gemm, conn, mm.Memlet(dataname))
    state.add_edge(gemm, '_c', out_node, None, mm.Memlet(C))

    # A non-zero beta means the previous contents of C contribute to the result
    if beta != 0:
        accumulate_node = state.add_read(C)
        state.add_edge(accumulate_node, None, gemm, '_cin', mm.Memlet(C))

    return []
def create_deeply_nested_sdfg():
    """
    Build an SDFG that copies ``x`` to ``y`` through a nested SDFG sitting
    inside an unrolled map, with all arrays placed on HBM banks, and apply the
    FPGA transformations to it.

    :return: The transformed SDFG.
    """

    def place_on_hbm(container, name, banks):
        # Annotate an array of `container` with its HBM bank range
        container.arrays[name].location["memorytype"] = "hbm"
        container.arrays[name].location["bank"] = banks

    unrolled = dtypes.ScheduleType.Unrolled

    # Outer SDFG: x (banks 0-3) and y (banks 4-7) around an unrolled map over k
    sdfg = dace.SDFG("deepnest_test")
    state: dace.SDFGState = sdfg.add_state("init")
    xarr = state.add_array("x", [4, 10], dace.float32)
    place_on_hbm(sdfg, "x", "0:4")
    yarr = state.add_array("y", [4, 10], dace.float32)
    place_on_hbm(sdfg, "y", "4:8")

    top_map_entry, top_map_exit = state.add_map("topmap", {"k": "0:2"})
    top_map_entry.schedule = unrolled

    # Nested SDFG: element-wise copy xin -> xout, row selected by 2*k+w
    nsdfg = dace.SDFG("nest")
    nstate = nsdfg.add_state("nested_state")
    x_read = nstate.add_array("xin", [4, 10], dace.float32, dtypes.StorageType.FPGA_Global)
    x_write = nstate.add_array("xout", [4, 10], dace.float32, dtypes.StorageType.FPGA_Global)
    place_on_hbm(nsdfg, "xin", "0:4")
    place_on_hbm(nsdfg, "xout", "4:8")

    map_entry, map_exit = nstate.add_map("map1", {"w": "0:2"})
    map_entry.schedule = unrolled
    imap_entry, imap_exit = nstate.add_map("map2", {"i": "0:10"})
    nope = nstate.add_tasklet("nop", {"_in": None}, {"_out": None}, "_out = _in")

    input_mem = mem.Memlet("xin[2*k+w, i]")
    output_mem = mem.Memlet("xout[2*k+w, i]")
    nstate.add_memlet_path(x_read, map_entry, imap_entry, nope, memlet=input_mem, dst_conn="_in")
    nstate.add_memlet_path(nope, imap_exit, map_exit, x_write, memlet=output_mem, src_conn="_out")

    # Embed the nested SDFG in the outer map and connect x / y to it
    nsdfg_node = state.add_nested_sdfg(nsdfg, state, {"xin"}, {"xout"})
    state.add_memlet_path(xarr,
                          top_map_entry,
                          nsdfg_node,
                          memlet=mem.Memlet.from_array("x", sdfg.arrays["x"]),
                          dst_conn="xin")
    state.add_memlet_path(nsdfg_node,
                          top_map_exit,
                          yarr,
                          memlet=mem.Memlet.from_array("y", sdfg.arrays["y"]),
                          src_conn="xout")

    sdfg.apply_fpga_transformations()
    return sdfg
def test_3_interface_to_2_banks():
    """
    Build an SDFG in which one tasklet reads two elements of an HBM array
    spanning two banks and writes a third, check the bank assignment in the
    generated code, then run it and check the computed sum.

    :return: The SDFG that was built and executed.
    """
    sdfg = SDFG("test_4_interface_to_2_banks")
    state = sdfg.add_state()

    # 2x2 array distributed over HBM banks 0 and 1
    _, desc_a = sdfg.add_array("a", [2, 2], dace.int32)
    desc_a.location["memorytype"] = "HBM"
    desc_a.location["bank"] = "0:2"

    reader = state.add_read("a")
    writer = state.add_write("a")
    adder = state.add_tasklet("r1", {"_x1", "_x2"}, {"_y1"}, "_y1 = _x1 + _x2")
    map_entry, map_exit = state.add_map("m", {"k": "0:2"}, dtypes.ScheduleType.Unrolled)

    state.add_memlet_path(reader, map_entry, adder, memlet=memlet.Memlet("a[0, 0]"), dst_conn="_x1")
    state.add_memlet_path(reader, map_entry, adder, memlet=memlet.Memlet("a[1, 0]"), dst_conn="_x2")
    state.add_memlet_path(adder, map_exit, writer, memlet=memlet.Memlet("a[0, 1]"), src_conn="_y1")

    sdfg.apply_fpga_transformations()
    assert sdfg.apply_transformations(InlineSDFG) == 1
    assert sdfg.apply_transformations(MapUnroll) == 1

    # Redirect the output of the first tasklet found to a[1, 1]
    kernel_state = sdfg.states()[0]
    first_tasklet = next((n for n in kernel_state.nodes() if isinstance(n, dace.sdfg.nodes.Tasklet)), None)
    if first_tasklet is not None:
        kernel_state.out_edges(first_tasklet)[0].data.subset = subsets.Range.from_string("1, 1")

    # Inspect the bank assignment in the generated code
    bank_assignment = sdfg.generate_code()[3].clean_code
    assert bank_assignment.count("sp") == 6
    assert bank_assignment.count("HBM[0]") == 3
    assert bank_assignment.count("HBM[1]") == 3

    # Execute: a[0, 1] must receive a[0, 0] + a[1, 0]
    a = np.zeros([2, 2], np.int32)
    a[0, 0] = 2
    a[1, 0] = 3
    sdfg(a=a)
    assert a[0, 1] == 5

    return sdfg
def expansion(node, parent_state, parent_sdfg, **kwargs):
    """
    Expand the library node into a nested SDFG that computes
    ``_res[i, j] = alpha * _x[i] * _y[j] + _A[i, j]`` with a mapped tasklet.

    :param node: Library node to expand; provides ``label``, ``alpha`` and ``validate``.
    :param parent_state: State containing ``node``.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param kwargs: Ignored.
    :return: The nested SDFG node implementing the operation.
    :raises NotImplementedError: If a connected memlet does not cover its whole array.
    """
    node.validate(parent_sdfg, parent_state)

    inputs = ('_A', '_x', '_y')
    outputs = ('_res', )
    in_edges = [next(parent_state.in_edges_by_connector(node, conn)) for conn in inputs]
    out_edges = [next(parent_state.out_edges_by_connector(node, conn)) for conn in outputs]

    # Connector name -> (edge, descriptor), inputs first, then outputs
    connected = list(zip(inputs, in_edges)) + list(zip(outputs, out_edges))
    arrays = {conn: parent_sdfg.arrays[edge.data.data] for conn, edge in connected}

    # TODO: Support memlet subsets
    for conn, edge in connected:
        if edge.data.subset != sbs.Range.from_array(arrays[conn]):
            raise NotImplementedError

    sdfg = dace.SDFG(f'{node.label}_sdfg')
    sdfg.add_symbol('M', int)
    sdfg.add_symbol('N', int)
    sdfg.add_symbol('alpha', arrays['_A'].dtype)

    # Inside the nested SDFG every container is passed in from outside
    for name, desc in arrays.items():
        inner_desc = copy.deepcopy(desc)
        inner_desc.transient = False
        sdfg.add_datadesc(name, inner_desc)

    state = sdfg.add_state()
    map_ranges = {'_i': '0:M', '_j': '0:N'}
    tasklet_inputs = {
        'a': mm.Memlet('_A[_i, _j]'),
        'xin': mm.Memlet('_x[_i]'),
        'yin': mm.Memlet('_y[_j]'),
    }
    tasklet_outputs = {'aout': mm.Memlet('_res[_i, _j]')}
    state.add_mapped_tasklet(
        'ger',
        map_ranges,
        tasklet_inputs,
        'aout = alpha * xin * yin + a',
        tasklet_outputs,
        external_edges=True,
    )

    # Bind the symbols of the nested SDFG to the output shape and the node's alpha
    outshape = arrays['_res'].shape
    symbol_mapping = {'M': outshape[0], 'N': outshape[1], 'alpha': node.alpha}
    return nodes.NestedSDFG(node.label, sdfg, set(inputs), set(outputs), symbol_mapping)
def axpy_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, a, x, y, result):
    """
    Insert an ``Axpy`` library node into ``state`` and connect it to ``x``, ``y`` and ``result``.

    :param pv: Program visitor (not used by this function).
    :param sdfg: The SDFG (not used by this function).
    :param state: State that receives the access nodes, the library node and the edges.
    :param a: Forwarded to ``Axpy`` as ``a``.
    :param x: Name of the data connected to the ``_x`` input connector.
    :param y: Name of the data connected to the ``_y`` input connector.
    :param result: Name of the data connected to the ``_res`` output connector.
    :return: An empty list.
    """
    # Access nodes and library node
    x_node = state.add_read(x)
    y_node = state.add_read(y)
    result_node = state.add_write(result)
    axpy = Axpy('axpy', a=a)
    state.add_node(axpy)

    # Wiring
    for access, conn, dataname in ((x_node, '_x', x), (y_node, '_y', y)):
        state.add_edge(access, None, axpy, conn, mm.Memlet(dataname))
    state.add_edge(axpy, '_res', result_node, None, mm.Memlet(result))

    return []
def dot_libnode(sdfg: SDFG, state: SDFGState, x, y, result):
    """
    Insert a ``Dot`` library node into ``state`` and connect it to ``x``, ``y`` and ``result``.

    :param sdfg: SDFG whose ``arrays`` provide the first dimension of ``x``
                 (passed to ``Dot`` as ``n``).
    :param state: State that receives the access nodes, the library node and the edges.
    :param x: Name of the data connected to the ``_x`` input connector.
    :param y: Name of the data connected to the ``_y`` input connector.
    :param result: Name of the data connected to the ``_result`` output connector.
    :return: An empty list.
    """
    # Access nodes and library node
    x_node = state.add_read(x)
    y_node = state.add_read(y)
    result_node = state.add_write(result)
    length = sdfg.arrays[x].shape[0]
    dot = Dot('dot', n=length)
    state.add_node(dot)

    # Wiring
    for access, conn, dataname in ((x_node, '_x', x), (y_node, '_y', y)):
        state.add_edge(access, None, dot, conn, mm.Memlet(dataname))
    state.add_edge(dot, '_result', result_node, None, mm.Memlet(result))

    return []
def dot_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, x, y, result, acctype=None):
    """
    Insert a ``Dot`` library node into ``state`` and connect it to ``x``, ``y`` and ``result``.

    :param pv: Program visitor (not used by this function).
    :param sdfg: SDFG whose ``arrays`` provide the first dimension of ``x``
                 (passed to ``Dot`` as ``n``).
    :param state: State that receives the access nodes, the library node and the edges.
    :param x: Name of the data connected to the ``_x`` input connector.
    :param y: Name of the data connected to the ``_y`` input connector.
    :param result: Name of the data connected to the ``_result`` output connector.
    :param acctype: Forwarded to ``Dot`` as ``accumulator_type``.
    :return: An empty list.
    """
    # Access nodes and library node
    x_node = state.add_read(x)
    y_node = state.add_read(y)
    result_node = state.add_write(result)
    length = sdfg.arrays[x].shape[0]
    dot = Dot('dot', n=length, accumulator_type=acctype)
    state.add_node(dot)

    # Wiring
    for access, conn, dataname in ((x_node, '_x', x), (y_node, '_y', y)):
        state.add_edge(access, None, dot, conn, mm.Memlet(dataname))
    state.add_edge(dot, '_result', result_node, None, mm.Memlet(result))

    return []
def ger_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, output, alpha):
    """
    Insert a ``Ger`` library node into ``state`` and connect it to ``A``, ``x``, ``y`` and ``output``.

    :param pv: Program visitor (not used by this function).
    :param sdfg: The SDFG (not used by this function).
    :param state: State that receives the access nodes, the library node and the edges.
    :param A: Name of the data connected to the ``_A`` input connector.
    :param x: Name of the data connected to the ``_x`` input connector.
    :param y: Name of the data connected to the ``_y`` input connector.
    :param output: Name of the data connected to the ``_res`` output connector.
    :param alpha: Forwarded to ``Ger`` as ``alpha``.
    :return: An empty list.
    """
    # Access nodes and library node
    matrix_node = state.add_read(A)
    x_node = state.add_read(x)
    y_node = state.add_read(y)
    out_node = state.add_write(output)
    ger = Ger('ger', alpha=alpha)
    state.add_node(ger)

    # Wiring
    for access, conn, dataname in ((matrix_node, '_A', A), (x_node, '_x', x), (y_node, '_y', y)):
        state.add_edge(access, None, ger, conn, mm.Memlet(dataname))
    state.add_edge(ger, '_res', out_node, None, mm.Memlet(output))

    return []
def expressions():
    """
    Build the pattern graph to match: ``_map_exit`` with one outgoing edge to
    each of ``_array1`` and ``_array2``.

    :return: A list containing the single pattern graph.
    """
    # Matching
    #   \======/
    #    |    |
    #    o    o
    exit_node = OutMergeArrays._map_exit
    targets = (OutMergeArrays._array1, OutMergeArrays._array2)

    pattern = SDFGState()
    for pattern_node in (*targets, exit_node):
        pattern.add_node(pattern_node)
    for target in targets:
        pattern.add_edge(exit_node, None, target, None, memlet.Memlet())

    return [pattern]
def expressions():
    """
    Build the pattern graph to match: ``_array1`` and ``_array2`` each with
    one edge into ``_map_entry``.

    :return: A list containing the single pattern graph.
    """
    # Matching
    #    o    o
    #    |    |
    #   /======\
    entry_node = InMergeArrays._map_entry
    sources = (InMergeArrays._array1, InMergeArrays._array2)

    pattern = SDFGState()
    for pattern_node in (*sources, entry_node):
        pattern.add_node(pattern_node)
    for source in sources:
        pattern.add_edge(source, None, entry_node, None, memlet.Memlet())

    return [pattern]
def _streamify_recursive(node: nodes.NestedSDFG, to_replace: str, desc: data.Stream):
    """
    Helper function that changes an array in a nested SDFG to a stream.

    The descriptor named ``to_replace`` in the nested SDFG is overwritten with
    a non-transient copy of ``desc``, and every memlet on a path touching an
    access node of that data is rewritten to access element ``0`` of the
    stream. Nested SDFGs found on those paths are processed recursively.

    :param node: Nested SDFG node whose inner SDFG is modified.
    :param to_replace: Name of the data container inside the nested SDFG.
    :param desc: Stream descriptor to install (a deep copy is used).
    """
    nsdfg: SDFG = node.sdfg

    # The stream is handed in from the outside, hence not transient in here
    inner_desc = copy.deepcopy(desc)
    inner_desc.transient = False
    nsdfg.arrays[to_replace] = inner_desc

    def descend_if_nested(endpoint, connector, memlet):
        # Propagate the change into a nested SDFG at either end of an edge
        if isinstance(endpoint, nodes.NestedSDFG):
            memlet.dynamic = True
            _streamify_recursive(endpoint, connector, inner_desc)

    # Replace memlets in path with stream access
    for state in nsdfg.nodes():
        for dnode in state.data_nodes():
            if dnode.data == to_replace:
                for edge in state.all_edges(dnode):
                    for e in state.memlet_path(edge):
                        e.data = mm.Memlet(data=to_replace, subset='0', other_subset=e.data.other_subset)
                        descend_if_nested(e.src, e.src_conn, e.data)
                        descend_if_nested(e.dst, e.dst_conn, e.data)
def mkc(sdfg: dace.SDFG,
        state_before,
        src_name,
        dst_name,
        src_storage=None,
        dst_storage=None,
        src_shape=None,
        dst_shape=None,
        copy_expr=None,
        src_loc=None,
        dst_loc=None):
    """
    Helper MaKe_Copy that creates and appends states performing exactly one
    copy. If a provided array name already exists, the existing array is used
    and all newly passed values for it are ignored.

    :param sdfg: SDFG to which the state (and possibly arrays) are added.
    :param state_before: State after which the new state is inserted. If
                         ``None``, the new state is added as the start state.
    :param src_name: Name of the source array.
    :param dst_name: Name of the destination array.
    :param src_storage: Storage type used if the source array is created.
    :param dst_storage: Storage type used if the destination array is created.
    :param src_shape: Shape used if the source array is created, and for the
                      returned source NumPy array.
    :param dst_shape: Shape used if the destination array is created, and for
                      the returned destination NumPy array.
    :param copy_expr: Memlet expression of the copy; defaults to ``src_name``.
    :param src_loc: Optional ``(memorytype, bank)`` pair for a created source array.
    :param dst_loc: Optional ``(memorytype, bank)`` pair for a created destination array.
    :return: A tuple ``(state, a_np_arr, b_np_arr)``. Each NumPy array is a
             zero-filled ``int32`` array of the given shape, or ``None`` if
             the shape is ``None`` or NumPy could not create the array.
    """
    if copy_expr is None:
        copy_expr = src_name
    if state_before is None:
        state = sdfg.add_state(is_start_state=True)
    else:
        state = sdfg.add_state_after(state_before)

    def mkarray(name, shape, storage, loc):
        # Reuse an existing array; otherwise create one (transient iff it lives in FPGA storage)
        if name in sdfg.arrays:
            return sdfg.arrays[name]
        is_transient = storage in _FPGA_STORAGE_TYPES
        arr = sdfg.add_array(name, shape, dace.int32, storage, transient=is_transient)
        if loc is not None:
            arr[1].location["memorytype"] = loc[0]
            arr[1].location["bank"] = loc[1]
        return arr

    def zeros_or_none(shape):
        # Best effort: shapes NumPy cannot handle simply yield no array.
        # Only ordinary exceptions are suppressed, so that KeyboardInterrupt
        # and SystemExit still propagate.
        if shape is None:
            return None
        try:
            return np.zeros(shape, dtype=np.int32)
        except Exception:
            return None

    mkarray(src_name, src_shape, src_storage, src_loc)
    mkarray(dst_name, dst_shape, dst_storage, dst_loc)

    # The copy itself: a single edge between the two access nodes
    aAcc = state.add_access(src_name)
    bAcc = state.add_access(dst_name)
    state.add_edge(aAcc, None, bAcc, None, mem.Memlet(copy_expr))

    a_np_arr = zeros_or_none(src_shape)
    b_np_arr = zeros_or_none(dst_shape)
    return (state, a_np_arr, b_np_arr)
def _Subscript(self, t: ast.Subscript):
    """
    Emit the C++ array expression for a subscript that addresses exactly one element.

    :param t: The subscript AST node.
    :raises SyntaxError: If the subscript covers a number of elements other than one.
    """
    from dace.frontend.python.astutils import subscript_to_slice

    target, raw_range = subscript_to_slice(t, self.sdfg.arrays)
    rng = subsets.Range(raw_range)

    # Only single-element accesses can be written as a scalar expression
    if rng.num_elements() != 1:
        raise SyntaxError('Range subscripts disallowed in interstate edges')

    element_memlet = mmlt.Memlet(data=target, subset=rng)
    self.write(cpp_array_expr(self.sdfg, element_memlet))
def four_interface_to_2_banks(mem_type, decouple_interfaces):
    """
    Build an SDFG in which one tasklet reads two elements of a two-bank array
    and writes a third, check the bank assignment in the generated code, then
    run it and check the computed sum.

    :param mem_type: Value stored as the array's ``memorytype`` location; also
                     used to build the SDFG name and the strings searched for
                     in the generated code.
    :param decouple_interfaces: Value temporarily set for the
                                ``compiler.xilinx.decouple_array_interfaces``
                                configuration entry.
    :return: The SDFG that was built and executed.
    """
    sdfg = SDFG("test_4_interface_to_2_banks_" + mem_type)
    state = sdfg.add_state()

    # 2x2 array distributed over banks 0 and 1 of the requested memory type
    _, desc_a = sdfg.add_array("a", [2, 2], dace.int32)
    desc_a.location["memorytype"] = mem_type
    desc_a.location["bank"] = "0:2"

    acc_read1 = state.add_read("a")
    acc_write1 = state.add_write("a")
    t1 = state.add_tasklet("r1", set(["_x1", "_x2"]), set(["_y1"]), "_y1 = _x1 + _x2")
    m1_in, m1_out = state.add_map("m", {"k": "0:2"}, dtypes.ScheduleType.Unrolled)

    state.add_memlet_path(acc_read1, m1_in, t1, memlet=memlet.Memlet("a[0, 0]"), dst_conn="_x1")
    state.add_memlet_path(acc_read1, m1_in, t1, memlet=memlet.Memlet("a[1, 0]"), dst_conn="_x2")
    state.add_memlet_path(t1, m1_out, acc_write1, memlet=memlet.Memlet("a[0, 1]"), src_conn="_y1")

    sdfg.apply_fpga_transformations()
    assert sdfg.apply_transformations(InlineSDFG) == 1
    assert sdfg.apply_transformations(MapUnroll) == 1

    # Redirect the output of the first tasklet found to a[1, 1]
    for node in sdfg.states()[0].nodes():
        if isinstance(node, dace.sdfg.nodes.Tasklet):
            sdfg.states()[0].out_edges(node)[0].data.subset = subsets.Range.from_string("1, 1")
            break

    with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=decouple_interfaces):
        bank_assignment = sdfg.generate_code()[3].clean_code

        # If we are not decoupling array interfaces we will use fewer memory
        # interfaces. The expected value is parenthesized on purpose:
        # `assert x == 6 if flag else 4` would parse as
        # `assert ((x == 6) if flag else 4)` and never fail for a false flag.
        expected_total = 6 if decouple_interfaces else 4
        expected_per_bank = 3 if decouple_interfaces else 2
        assert bank_assignment.count("sp") == expected_total
        assert bank_assignment.count(mem_type + "[0]") == expected_per_bank
        assert bank_assignment.count(mem_type + "[1]") == expected_per_bank

        # Execute: a[0, 1] must receive a[0, 0] + a[1, 0]
        a = np.zeros([2, 2], np.int32)
        a[0, 0] = 2
        a[1, 0] = 3
        sdfg(a=a)
        assert a[0, 1] == 5

    return sdfg
def create_dynamic_memlet_sdfg():
    """
    Build an SDFG with dynamic memlets and a dynamic map range on HBM arrays,
    and apply the FPGA transformations to it.

    The inner map iterates ``i`` over ``0:_dynbound``, where ``_dynbound`` is
    fed from ``x[1, 0]`` through a connector on the inner map entry.

    :return: The transformed SDFG.
    """

    def place_on_hbm(name, banks):
        # Annotate an array of the SDFG with its HBM bank range
        sdfg.arrays[name].location["memorytype"] = "hbm"
        sdfg.arrays[name].location["bank"] = banks

    sdfg = dace.SDFG("dyn_memlet")
    state: dace.SDFGState = sdfg.add_state("dyn_memlet")

    xarr = state.add_array("x", [4, 10], dace.int32)
    place_on_hbm("x", "0:4")
    yarr = state.add_array("y", [4, 10], dace.int32)
    place_on_hbm("y", "4:8")

    # Unrolled map over the banks, regular map with a data-dependent bound inside
    hbm_map_enter, hbm_map_exit = state.add_map("hbmmap", {"k": "0:4"}, dtypes.ScheduleType.Unrolled)
    arr_map_enter, arr_map_exit = state.add_map("map", {"i": "0:_dynbound"})

    # The tasklet does not write on every path, hence the dynamic memlets
    tasklet_code = ("if(i == 2):\n"
                    " _out = 2\n"
                    "elif (_in != 2):\n"
                    " _out = _in\n")
    tasklet = state.add_tasklet("dyn", {"_in"}, {"_out"}, tasklet_code)

    state.add_memlet_path(xarr,
                          hbm_map_enter,
                          arr_map_enter,
                          tasklet,
                          memlet=mem.Memlet("x[k, i]", dynamic=True),
                          dst_conn="_in")
    state.add_memlet_path(tasklet,
                          arr_map_exit,
                          hbm_map_exit,
                          yarr,
                          memlet=mem.Memlet("y[k, i]", dynamic=True),
                          src_conn="_out")

    # Feed the dynamic bound of the inner map
    state.add_memlet_path(xarr, hbm_map_enter, arr_map_enter, memlet=mem.Memlet("x[1, 0]"), dst_conn="_dynbound")

    sdfg.apply_fpga_transformations()
    return sdfg
def _make_view(self, sdfg: SDFG, graph: SDFGState, in_array: nodes.AccessNode, out_array: nodes.AccessNode,
               e1: graph.MultiConnectorEdge[mm.Memlet], b_subset: subsets.Subset, b_dims_to_pop: typing.List[int]):
    """
    Turn ``in_array`` into a view, or bypass it if everything feeding it is
    already a view.

    If every incoming edge of ``in_array`` originates from an access node
    whose descriptor is a ``View`` (this includes the case of no incoming
    edges), those edges are redirected straight to ``out_array`` and
    ``in_array`` is removed from the graph and from ``sdfg.arrays``.
    Otherwise the descriptor of ``in_array`` is replaced by a ``View``.

    :param sdfg: SDFG owning the data descriptors.
    :param graph: State containing both access nodes.
    :param in_array: Access node to convert or remove.
    :param out_array: Access node that ``in_array`` writes to.
    :param e1: The edge from ``in_array`` to ``out_array``.
    :param b_subset: Subset of ``out_array`` used on redirected edges.
    :param b_dims_to_pop: Dimensions of ``out_array`` skipped when deriving the view strides.
    """
    in_desc = sdfg.arrays[in_array.data]
    out_desc = sdfg.arrays[out_array.data]

    # NOTE: We do not want to create another view, if the immediate
    # ancestors of in_array are views as well. We just remove it.
    ancestor_descs = []
    for edge in graph.in_edges(in_array):
        src = edge.src
        ancestor_descs.append(src.desc(sdfg) if isinstance(src, nodes.AccessNode) else None)
    only_views = all(desc and isinstance(desc, data.View) for desc in ancestor_descs)

    if only_views:
        for edge in graph.in_edges(in_array):
            a_subset, _ = _validate_subsets(edge, sdfg.arrays)
            bypass = mm.Memlet(out_array.data,
                               subset=b_subset,
                               other_subset=a_subset,
                               wcr=e1.data.wcr,
                               wcr_nonatomic=e1.data.wcr_nonatomic)
            graph.add_edge(edge.src, edge.src_conn, out_array, None, bypass)
            graph.remove_edge(edge)
        graph.remove_edge(e1)
        graph.remove_node(in_array)
        if in_array.data in sdfg.arrays:
            del sdfg.arrays[in_array.data]
        return

    # Use the strides of out_array (minus the popped dimensions) when the
    # popped dimensions account exactly for the difference in rank
    rank_gap = len(out_desc.shape) - len(in_desc.shape)
    if b_dims_to_pop and len(b_dims_to_pop) == rank_gap:
        view_strides = [stride for dim, stride in enumerate(out_desc.strides) if dim not in b_dims_to_pop]
    else:
        view_strides = in_desc.strides

    sdfg.arrays[in_array.data] = data.View(in_desc.dtype, in_desc.shape, True, in_desc.allow_conflicts,
                                           out_desc.storage, out_desc.location, view_strides, in_desc.offset,
                                           out_desc.may_alias, dtypes.AllocationLifetime.Scope, in_desc.alignment,
                                           in_desc.debuginfo, in_desc.total_size)
def apply(self, sdfg):
    """
    Remove the second array ``B`` from the pattern ``A -- e1 --> B -- e2 --> others``.

    Three outcomes are possible:

    * If the ``A`` subset of ``e1`` differs in size from the shape of ``B``
      and all successors of ``B`` are views, ``B`` is bypassed and removed.
    * If the sizes differ but not all successors are views, the descriptor
      of ``B`` is replaced by a ``View`` and the graph is left untouched.
    * Otherwise every memlet in the memlet trees leaving ``B`` is rewritten
      to refer to ``A``, the outgoing edges are reconnected to ``A`` and
      ``B`` is removed.

    :param sdfg: The SDFG the transformation is applied to.
    """
    def gnode(nname):
        # Resolve a pattern node to the matched node in the state
        return graph.nodes()[self.subgraph[nname]]

    graph = sdfg.nodes()[self.state_id]
    in_array = gnode(RedundantSecondArray._in_array)
    out_array = gnode(RedundantSecondArray._out_array)
    in_desc = sdfg.arrays[in_array.data]
    out_desc = sdfg.arrays[out_array.data]

    # We assume the following pattern: A -- e1 --> B -- e2 --> others

    # 1. Get edge e1 and extract subsets for arrays A and B
    e1 = graph.edges_between(in_array, out_array)[0]
    a_subset, b1_subset = _validate_subsets(e1, sdfg.arrays)

    # Find extraneous A or B subset dimensions
    a_dims_to_pop = []
    b_dims_to_pop = []
    aset = a_subset
    popped = []
    if a_subset and b1_subset and a_subset.dims() != b1_subset.dims():
        a_size = a_subset.size_exact()
        b_size = b1_subset.size_exact()
        if a_subset.dims() > b1_subset.dims():
            # A has more dimensions: remember the popped ones so they can be
            # pushed back when composing subsets below
            a_dims_to_pop = find_dims_to_pop(a_size, b_size)
            aset, popped = pop_dims(a_subset, a_dims_to_pop)
        else:
            b_dims_to_pop = find_dims_to_pop(b_size, a_size)

    # If the src subset does not cover the removed array, create a view.
    if a_subset and any(m != a for m, a in zip(a_subset.size(), out_desc.shape)):
        # NOTE: We do not want to create another view, if the immediate
        # successors of out_array are views as well. We just remove it.
        out_successors_desc = [
            e.dst.desc(sdfg) if isinstance(e.dst, nodes.AccessNode) else None for e in graph.out_edges(out_array)
        ]
        if all([desc and isinstance(desc, data.View) for desc in out_successors_desc]):
            # Connect A directly to every successor of B, then drop B
            for e in graph.out_edges(out_array):
                _, b_subset = _validate_subsets(e, sdfg.arrays)
                graph.add_edge(
                    in_array, None, e.dst, e.dst_conn,
                    mm.Memlet(in_array.data,
                              subset=a_subset,
                              other_subset=b_subset,
                              wcr=e1.data.wcr,
                              wcr_nonatomic=e1.data.wcr_nonatomic))
                graph.remove_edge(e)
            graph.remove_edge(e1)
            graph.remove_node(out_array)
            if out_array.data in sdfg.arrays:
                del sdfg.arrays[out_array.data]
            return
        # Otherwise B becomes a view; it takes its strides from A (minus the
        # popped dimensions) when those account exactly for the rank difference
        view_strides = out_desc.strides
        if (a_dims_to_pop and len(a_dims_to_pop) == len(in_desc.shape) - len(out_desc.shape)):
            view_strides = [s for i, s in enumerate(in_desc.strides) if i not in a_dims_to_pop]
        sdfg.arrays[out_array.data] = data.View(out_desc.dtype, out_desc.shape, True, out_desc.allow_conflicts,
                                                in_desc.storage, in_desc.location, view_strides, out_desc.offset,
                                                in_desc.may_alias, dtypes.AllocationLifetime.Scope,
                                                out_desc.alignment, out_desc.debuginfo, out_desc.total_size)
        return

    # 2. Iterate over the e2 edges and traverse the memlet tree
    for e2 in graph.out_edges(out_array):
        path = graph.memlet_tree(e2)
        # Write-conflict resolution of e1 is carried along the tree; once
        # set, it applies to all subsequently visited edges
        wcr = e1.data.wcr
        wcr_nonatomic = e1.data.wcr_nonatomic
        for e3 in path:
            # 2-a. Extract subsets for array B and others
            b3_subset, other_subset = _validate_subsets(e3, sdfg.arrays, src_name=out_array.data)
            # 2-b. Modify memlet to match array A. Example:
            # A -- (0, a:b)/(c:c+b) --> B -- (c+d)/None --> others
            # A -- (0, a+d)/None --> others
            e3.data.data = in_array.data
            # (c+d) - (c:c+b) = (d)
            b3_subset.offset(b1_subset, negative=True)
            # (0, a:b)(d) = (0, a+d) (or offset for indices)
            if b3_subset and b_dims_to_pop:
                bset, _ = pop_dims(b3_subset, b_dims_to_pop)
            else:
                bset = b3_subset
            e3.data.subset = compose_and_push_back(aset, bset, a_dims_to_pop, popped)
            # NOTE: This fixes the following case:
            # A ----> A[subset] ----> ... -----> Tasklet
            # Tasklet is not data, so it doesn't have an other subset.
            if isinstance(e3.dst, nodes.AccessNode):
                e3.data.other_subset = other_subset
            else:
                e3.data.other_subset = None
            wcr = wcr or e3.data.wcr
            wcr_nonatomic = wcr_nonatomic or e3.data.wcr_nonatomic
            e3.data.wcr = wcr
            e3.data.wcr_nonatomic = wcr_nonatomic

        # 2-c. Remove edge and add new one
        graph.remove_edge(e2)
        e2.data.wcr = wcr
        e2.data.wcr_nonatomic = wcr_nonatomic
        graph.add_edge(in_array, e2.src_conn, e2.dst, e2.dst_conn, e2.data)

    # Finally, remove out_array node
    graph.remove_node(out_array)
    if out_array.data in sdfg.arrays:
        try:
            sdfg.remove_data(out_array.data)
        except ValueError:  # Already in use (e.g., with Views)
            pass
def apply(self, sdfg: SDFG) -> nodes.AccessNode:
    """
    Replace the accesses of the matched data node by streams, and create
    separate read/write components that move data between the original
    container and those streams.

    For ``expr_index == 0`` the outgoing edges of the access node are
    converted (reads), otherwise the incoming edges (writes). One stream is
    created per edge; edges with identical map ranges and identical
    canonicalized memlets share one component.

    :param sdfg: The SDFG the transformation is applied to.
    :return: The list of stream access nodes created for the components.
             NOTE: this is a list, although the annotation names a single node.
    """
    state = sdfg.node(self.state_id)
    dnode: nodes.AccessNode = self.access(sdfg)
    if self.expr_index == 0:
        edges = state.out_edges(dnode)
    else:
        edges = state.in_edges(dnode)

    # To understand how many components we need to create, all map ranges
    # throughout memlet paths must match exactly. We thus create a
    # dictionary of unique ranges
    mapping: Dict[Tuple[subsets.Range], List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    ranges = {}
    for edge in edges:
        mpath = state.memlet_path(edge)
        ranges[edge] = _collect_map_ranges(state, mpath)
        mapping[tuple(r[1] for r in ranges[edge])].append(edge)

    # Collect all edges with the same memory access pattern
    components_to_create: Dict[Tuple[symbolic.SymbolicType],
                               List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    for edges_with_same_range in mapping.values():
        for edge in edges_with_same_range:
            # Get memlet path and innermost edge
            # (deep-copied, since the memlets on the path are overwritten below)
            mpath = state.memlet_path(edge)
            innermost_edge = copy.deepcopy(mpath[-1] if self.expr_index == 0 else mpath[0])

            # Store memlets of the same access in the same component
            expr = _canonicalize_memlet(innermost_edge.data, ranges[edge])
            components_to_create[expr].append((innermost_edge, edge))
    components = list(components_to_create.values())

    # Split out components that have dependencies between them to avoid
    # deadlocks
    if self.expr_index == 0:
        ccs_to_add = []
        for i, component in enumerate(components):
            edges_to_remove = set()
            for cedge in component:
                # An edge whose destination is reachable from the destination
                # of another edge in the same component gets its own component
                if any(nx.has_path(state.nx, o[1].dst, cedge[1].dst) for o in component if o is not cedge):
                    ccs_to_add.append([cedge])
                    edges_to_remove.add(cedge)
            if edges_to_remove:
                components[i] = [c for c in component if c not in edges_to_remove]
        components.extend(ccs_to_add)
    # End of split

    desc = sdfg.arrays[dnode.data]

    # Create new streams of shape 1
    streams = {}  # edge -> name of the stream created for it
    mpaths = {}  # edge -> memlet path of the edge
    for edge in edges:
        name, newdesc = sdfg.add_stream(dnode.data,
                                        desc.dtype,
                                        buffer_size=self.buffer_size,
                                        storage=self.storage,
                                        transient=True,
                                        find_new_name=True)
        streams[edge] = name
        mpath = state.memlet_path(edge)
        mpaths[edge] = mpath

        # Replace memlets in path with stream access
        for e in mpath:
            e.data = mm.Memlet(data=name, subset='0', other_subset=e.data.other_subset)
            if isinstance(e.src, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.src, e.src_conn, newdesc)
            if isinstance(e.dst, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.dst, e.dst_conn, newdesc)

        # Replace access node and memlet tree with one access
        if self.expr_index == 0:
            replacement = state.add_read(name)
            state.remove_edge(edge)
            state.add_edge(replacement, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
        else:
            replacement = state.add_write(name)
            state.remove_edge(edge)
            state.add_edge(edge.src, edge.src_conn, replacement, edge.dst_conn, edge.data)

    # Make read/write components
    ionodes = []
    for component in components:

        # Pick the first edge as the edge to make the component from
        innermost_edge, outermost_edge = component[0]
        mpath = mpaths[outermost_edge]
        mapname = streams[outermost_edge]
        innermost_edge.data.other_subset = None

        # Get edge data and streams
        if self.expr_index == 0:
            # Read component: one read of the original data, broadcast to
            # every stream of the component
            opname = 'read'
            path = [e.dst for e in mpath[:-1]]
            rmemlets = [(dnode, '__inp', innermost_edge.data)]
            wmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_write(name)
                ionodes.append(ionode)
                wmemlets.append((ionode, '__out%d' % i, mm.Memlet(data=name, subset='0')))
            code = '\n'.join('__out%d = __inp' % i for i in range(len(component)))
        else:
            # More than one input stream might mean a data race, so we only
            # address the first one in the tasklet code
            if len(component) > 1:
                warnings.warn(f'More than one input found for the same index for {dnode.data}')
            opname = 'write'
            path = [state.entry_node(e.src) for e in reversed(mpath[1:])]
            wmemlets = [(dnode, '__out', innermost_edge.data)]
            rmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_read(name)
                ionodes.append(ionode)
                rmemlets.append((ionode, '__inp%d' % i, mm.Memlet(data=name, subset='0')))
            code = '__out = __inp0'

        # Create map structure for read/write component
        # (one new map per entry in `path`, copying its parameters, ranges and schedule)
        maps = []
        for entry in path:
            map: nodes.Map = entry.map
            maps.append(
                state.add_map(f'__s{opname}_{mapname}', [(p, r) for p, r in zip(map.params, map.range)],
                              map.schedule))
        tasklet = state.add_tasklet(
            f'{opname}_{mapname}',
            {m[1]
             for m in rmemlets},
            {m[1]
             for m in wmemlets},
            code,
        )
        # Inputs pass through the map entries, outputs through the map exits in reverse order
        for node, cname, memlet in rmemlets:
            state.add_memlet_path(node, *(me for me, _ in maps), tasklet, dst_conn=cname, memlet=memlet)
        for node, cname, memlet in wmemlets:
            state.add_memlet_path(tasklet, *(mx for _, mx in reversed(maps)), node, src_conn=cname, memlet=memlet)

    return ionodes
def create_deeply_nested_sdfg():
    """
    Build an SDFG that copies ``x`` to ``y`` through a two-state nested SDFG
    placed inside an unrolled map.

    The first nested state copies element-wise through an unrolled map and a
    regular map; the second nested state overwrites ``xout[mpt, 0]`` with
    ``15.0``, where ``mpt`` is assigned from ``k`` on the interstate edge.

    :return: The constructed SDFG.
    """
    unrolled = dtypes.ScheduleType.Unrolled

    # Outer SDFG: x and y around an unrolled map over k
    sdfg = dace.SDFG("deepnest_test")
    state: dace.SDFGState = sdfg.add_state("init")
    xarr = state.add_array("x", [4, 100], dace.float32)
    yarr = state.add_array("y", [4, 100], dace.float32)
    top_entry, top_exit = state.add_map("topmap", {"k": "0:2"})
    top_entry.schedule = unrolled

    # Nested SDFG, first state: element-wise copy xin -> xout
    nsdfg = dace.SDFG("nest")
    nstate = nsdfg.add_state("nested_state", True)
    x_read = nstate.add_array("xin", [4, 100], dace.float32)
    x_write = nstate.add_array("xout", [4, 100], dace.float32)
    unroll_entry, unroll_exit = nstate.add_map("map1", {"w": "0:2"})
    unroll_entry.schedule = unrolled
    inner_entry, inner_exit = nstate.add_map("map2", {"i": "0:100"})
    nope = nstate.add_tasklet("nop", {"_in": None}, {"_out": None}, "_out = _in")

    input_mem = mem.Memlet("xin[2*k+w, i]")
    output_mem = mem.Memlet("xout[2*k+w, i]")
    nstate.add_memlet_path(x_read, unroll_entry, inner_entry, nope, memlet=input_mem, dst_conn="_in")
    nstate.add_memlet_path(nope, inner_exit, unroll_exit, x_write, memlet=output_mem, src_conn="_out")

    # Nested SDFG, second state: overwrite a single element of xout
    nstate2 = nsdfg.add_state("second_nest")
    overwrite = nstate2.add_tasklet("overwrite", set(), {"_out"}, "_out = 15.0")
    x_write2 = nstate2.add_write("xout")
    nstate2.add_memlet_path(overwrite, x_write2, memlet=mem.Memlet("xout[mpt, 0]"), src_conn="_out")
    nsdfg.add_edge(nstate, nstate2, InterstateEdge(None, {"mpt": "k"}))

    # Embed the nested SDFG in the outer map and connect x / y to it
    nsdfg_node = state.add_nested_sdfg(nsdfg, state, {"xin"}, {"xout"})
    nsdfg_node.unique_name = "SomeUniqueName"
    state.add_memlet_path(xarr,
                          top_entry,
                          nsdfg_node,
                          memlet=mem.Memlet.from_array("x", sdfg.arrays["x"]),
                          dst_conn="xin")
    state.add_memlet_path(nsdfg_node,
                          top_exit,
                          yarr,
                          memlet=mem.Memlet.from_array("y", sdfg.arrays["y"]),
                          src_conn="xout")

    return sdfg
def apply(self, sdfg: sd.SDFG):
    """
    Converts the matched for-loop into a map scope.

    If the loop body consists of more than one state, all body states are
    first moved into a new nested SDFG that lives in a single new state, so
    that the remainder of the transformation only deals with one body state.
    The body state's contents are then wrapped in a map over the loop range,
    and the loop's guard edges (exit condition, back-edge, iteration variable
    assignments) are removed or neutralized.

    :param sdfg: The SDFG that contains the loop.
    :note: Operates in-place on the SDFG.
    """
    # Obtain loop information
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), (_, body_end) = find_for_loop(
        sdfg, guard, body, itervar=self.itervar)

    # Find all loop-body states (breadth-first traversal from the first body
    # state; traversal does not continue past ``body_end``)
    states = set([body_end])
    to_visit = [body]
    while to_visit:
        state = to_visit.pop(0)
        if state is body_end:
            continue
        for _, dst, _ in sdfg.out_edges(state):
            if dst not in states:
                to_visit.append(dst)
        states.add(state)

    # Nest loop-body states
    if len(states) > 1:

        # Find read/write sets
        read_set, write_set = set(), set()
        for state in states:
            rset, wset = state.read_and_write_sets()
            read_set |= rset
            write_set |= wset
        # Add data from edges: data containers that appear as free symbols on
        # interstate edges between body states count as reads
        for src in states:
            for dst in states:
                for edge in sdfg.edges_between(src, dst):
                    for s in edge.data.free_symbols:
                        if s in sdfg.arrays:
                            read_set.add(s)

        # Find NestedSDFG's unique data: transients that have no access node
        # in any state outside of the loop body
        rw_set = read_set | write_set
        unique_set = set()
        for name in rw_set:
            if not sdfg.arrays[name].transient:
                continue
            found = False
            for state in sdfg.states():
                if state in states:
                    continue
                for node in state.nodes():
                    if (isinstance(node, nodes.AccessNode)
                            and node.data == name):
                        found = True
                        # NOTE: only leaves the node loop; remaining states
                        # are still scanned, ``found`` stays set
                        break
            if not found:
                unique_set.add(name)

        # Find NestedSDFG's connectors (everything except unique transients)
        read_set = {
            n
            for n in read_set
            if n not in unique_set or not sdfg.arrays[n].transient
        }
        write_set = {
            n
            for n in write_set
            if n not in unique_set or not sdfg.arrays[n].transient
        }

        # Create NestedSDFG and add all loop-body states and edges
        # Also, find defined symbols in NestedSDFG
        fsymbols = set(sdfg.free_symbols)
        new_body = sdfg.add_state('single_state_body')
        nsdfg = SDFG("loop_body", constants=sdfg.constants, parent=new_body)
        nsdfg.add_node(body, is_start_state=True)
        body.parent = nsdfg
        exit_state = nsdfg.add_state('exit')
        nsymbols = dict()
        for state in states:
            if state is body:
                continue
            nsdfg.add_node(state)
            state.parent = nsdfg
        for state in states:
            if state is body:
                continue
            for src, dst, data in sdfg.in_edges(state):
                # Symbols assigned on the moved edges must also be declared
                # in the nested SDFG (done under "Fix SDFG symbols" below)
                nsymbols.update({
                    s: sdfg.symbols[s]
                    for s in data.assignments.keys() if s in sdfg.symbols
                })
                nsdfg.add_edge(src, dst, data)
        nsdfg.add_edge(body_end, exit_state, InterstateEdge())

        # Move guard -> body edge to guard -> new_body
        for src, dst, data, in sdfg.edges_between(guard, body):
            sdfg.add_edge(src, new_body, data)
        # Move body_end -> guard edge to new_body -> guard
        for src, dst, data in sdfg.edges_between(body_end, guard):
            sdfg.add_edge(new_body, dst, data)

        # Delete loop-body states and edges from parent SDFG
        for state in states:
            for e in sdfg.all_edges(state):
                sdfg.remove_edge(e)
            sdfg.remove_node(state)

        # Add NestedSDFG arrays: connector data is copied as non-transient,
        # unique transients are moved (and removed from the parent SDFG)
        for name in read_set | write_set:
            nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name])
            nsdfg.arrays[name].transient = False
        for name in unique_set:
            nsdfg.arrays[name] = sdfg.arrays[name]
            del sdfg.arrays[name]

        # Add NestedSDFG node
        cnode = new_body.add_nested_sdfg(nsdfg, None, read_set, write_set)
        if sdfg.parent:
            # Forward symbol mappings of the enclosing nested SDFG node that
            # the new node does not define itself
            for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items():
                if s not in cnode.symbol_mapping:
                    cnode.symbol_mapping[s] = m
                    nsdfg.add_symbol(s, sdfg.symbols[s])
        for name in read_set:
            r = new_body.add_read(name)
            new_body.add_edge(
                r, None, cnode, name,
                memlet.Memlet.from_array(name, sdfg.arrays[name]))
        for name in write_set:
            w = new_body.add_write(name)
            new_body.add_edge(
                cnode, name, w, None,
                memlet.Memlet.from_array(name, sdfg.arrays[name]))

        # Fix SDFG symbols: drop symbols that only became free in the parent
        # because of the nesting, and declare edge-assigned ones inside
        for sym in sdfg.free_symbols - fsymbols:
            del sdfg.symbols[sym]
        for sym, dtype in nsymbols.items():
            nsdfg.symbols[sym] = dtype

        # Change body state reference
        body = new_body

    # NOTE(review): the explicit ``== True`` presumably guards against
    # ``step`` being symbolic (comparison not reducible to a plain bool) --
    # confirm before simplifying
    if (step < 0) == True:
        # If step is negative, we have to flip start and end to produce a
        # correct map with a positive increment
        start, end, step = end, start, -step

    # If necessary, make a nested SDFG with assignments
    isedge = sdfg.edges_between(guard, body)[0]
    symbols_to_remove = set()
    if len(isedge.data.assignments) > 0:
        nsdfg = helpers.nest_state_subgraph(
            sdfg, body, gr.SubgraphView(body, body.nodes()))
        # Make everything the edge reads available inside the nested SDFG
        for sym in isedge.data.free_symbols:
            if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
                continue
            if sym in sdfg.symbols:
                nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym)
                nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym])
            elif sym in sdfg.arrays:
                if sym in nsdfg.sdfg.arrays:
                    raise NotImplementedError
                rnode = body.add_read(sym)
                nsdfg.add_in_connector(sym)
                desc = copy.deepcopy(sdfg.arrays[sym])
                desc.transient = False
                nsdfg.sdfg.add_datadesc(sym, desc)
                body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym))

        # Move the assignments onto a new edge in front of the nested SDFG's
        # first state
        nstate = nsdfg.sdfg.node(0)
        init_state = nsdfg.sdfg.add_state_before(nstate)
        nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
        nisedge.data.assignments = isedge.data.assignments
        symbols_to_remove = set(nisedge.data.assignments.keys())
        # Symbols assigned inside must not also be mapped from the outside
        for k in nisedge.data.assignments.keys():
            if k in nsdfg.symbol_mapping:
                del nsdfg.symbol_mapping[k]
        isedge.data.assignments = {}

    # Sources and sinks are collected before the map nodes are added
    source_nodes = body.source_nodes()
    sink_nodes = body.sink_nodes()

    map = nodes.Map(body.label + "_map", [itervar], [(start, end, step)])
    entry = nodes.MapEntry(map)
    exit = nodes.MapExit(map)
    body.add_node(entry)
    body.add_node(exit)

    # If the map uses symbols from data containers, instantiate reads
    containers_to_read = entry.free_symbols & sdfg.arrays.keys()
    for rd in containers_to_read:
        # We are guaranteed that this is always a scalar, because
        # can_be_applied makes sure there are no sympy functions in each of
        # the loop expressions
        access_node = body.add_read(rd)
        body.add_memlet_path(access_node,
                             entry,
                             dst_conn=rd,
                             memlet=memlet.Memlet(rd))

    # Reroute all memlets through the entry and exit nodes
    for n in source_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.out_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(entry,
                                   e.dst,
                                   n,
                                   e.data,
                                   internal_connector=e.dst_conn)
        else:
            # Non-data sources are attached to the entry with an empty memlet
            body.add_nedge(entry, n, memlet.Memlet())
    for n in sink_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.in_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(exit,
                                   e.src,
                                   n,
                                   e.data,
                                   internal_connector=e.src_conn)
        else:
            # Non-data sinks are attached to the exit with an empty memlet
            body.add_nedge(n, exit, memlet.Memlet())

    # Get rid of the loop exit condition edge
    after_edge = sdfg.edges_between(guard, after)[0]
    sdfg.remove_edge(after_edge)

    # Remove the assignment on the edge to the guard
    for e in sdfg.in_edges(guard):
        if itervar in e.data.assignments:
            del e.data.assignments[itervar]

    # Remove the condition on the entry edge
    condition_edge = sdfg.edges_between(guard, body)[0]
    condition_edge.data.condition = CodeBlock("1")

    # Get rid of backedge to guard
    sdfg.remove_edge(sdfg.edges_between(body, guard)[0])

    # Route body directly to after state, maintaining any other assignments
    # it might have had
    sdfg.add_edge(
        body, after,
        sd.InterstateEdge(assignments=after_edge.data.assignments))

    # If this had made the iteration variable a free symbol, we can remove
    # it from the SDFG symbols
    if itervar in sdfg.free_symbols:
        sdfg.remove_symbol(itervar)
    for sym in symbols_to_remove:
        if helpers.is_symbol_unused(sdfg, sym):
            sdfg.remove_symbol(sym)
def apply(self, sdfg):
    """
    Offloads the matched state to FPGA.

    For every array-typed source/sink access node of the state (plus nodes
    reached through write-conflict-resolution memlets), an ``fpga_``-prefixed
    transient with ``FPGA_Global`` storage is used. A ``pre_`` state copies
    inputs into those transients and a ``post_`` state copies outputs back;
    the access nodes in the state itself are replaced by accesses to the
    ``fpga_`` containers. Finally, memlets in the state are renamed, vector
    lengths are propagated from nested SDFGs, and ``fpga_update`` is invoked.

    :param sdfg: The SDFG that contains the state.
    :note: Operates in-place on the SDFG.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = sdutil.find_source_nodes(state)
    output_nodes = sdutil.find_sink_nodes(state)

    # Maps original data name -> result of sdfg.add_array for its FPGA copy
    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    # NOTE(review): ``stack`` is not referenced again in this method
    stack = []

    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if isinstance(node, dace.sdfg.nodes.AccessNode):
            for e in graph.all_edges(node):
                if e.data.wcr is not None:
                    trace = dace.sdfg.trace_nested_access(
                        node, graph, parent_sdfg[graph])
                    for node_trace, state_trace, sdfg_trace in trace:
                        # Find the name of the accessed node in our scope
                        if state_trace == state and sdfg_trace == sdfg:
                            outer_node = node_trace
                            break
                    else:
                        # This does not trace back to the current state, so
                        # we don't care
                        continue
                    input_nodes.append(outer_node)
                    wcr_input_nodes.add(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:

            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array
            # NOTE(review): for a WCR input node with no existing entry in
            # ``fpga_data`` no array is created here, yet the copy below
            # still targets 'fpga_' + node.data -- presumably such nodes are
            # also sink nodes and get their array in the output section;
            # confirm

            # Copy the full host array into the FPGA array in the pre-state
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            # WCR nodes stay in the state; pure inputs are replaced by a
            # read of the FPGA copy
            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                sdutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        # Insert the pre-state in front of the transformed state
        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:

        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:

            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array
            # fpga_node = type(node)(fpga_array)

            # Copy the full FPGA array back to the host array in the
            # post-state
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            # Replace the original sink by a write to the FPGA copy
            fpga_node = state.add_write('fpga_' + node.data)
            sdutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        # Insert the post-state behind the transformed state
        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    # NOTE: ``veclen_`` is not reset per edge; a value picked up from a
    # nested SDFG carries over to the edges visited afterwards
    veclen_ = 1

    # propagate vector info from a nested sdfg
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # need to go inside the nested SDFG and grab the vector length
        if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            for inner_state in dst.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(n, dace.sdfg.nodes.AccessNode
                                  ) and n.data == dst_conn:
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen
        if isinstance(src, dace.sdfg.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            for inner_state in src.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(n, dace.sdfg.nodes.AccessNode
                                  ) and n.data == src_conn:
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen

        # Point memlets of offloaded containers to their FPGA copies
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
        mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def apply(self, sdfg: sd.SDFG):
    """
    Offloads the whole SDFG to GPU.

    The transformation proceeds in the numbered steps marked below:
    non-transient data is cloned into ``gpu_``-prefixed ``GPU_Global``
    transients, copy-in/copy-out states are added around the SDFG, transient
    storage is adjusted, free top-level code nodes are wrapped in
    single-iteration GPU maps, top-level maps and library nodes are scheduled
    on the GPU, interim copy-out states are added where interstate edges read
    cloned data, and strict transformations are optionally applied.

    :param sdfg: The SDFG to transform.
    :note: Operates in-place on the SDFG.
    """

    #######################################################
    # Step 0: SDFG metadata

    # Find all input and output data descriptors, as (name, descriptor)
    input_nodes = []
    output_nodes = []
    # The lists above hold tuples, so membership tests by data name must use
    # these sets (testing a name against the tuple lists never matches)
    input_names = set()
    output_names = set()
    global_code_nodes = [[] for _ in sdfg.nodes()]

    for i, state in enumerate(sdfg.nodes()):
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and not node.desc(sdfg).transient):
                if (state.out_degree(node) > 0
                        and node.data not in input_names):
                    # Special case: nodes that lead to top-level dynamic
                    # map ranges must stay on host
                    for e in state.out_edges(node):
                        last_edge = state.memlet_path(e)[-1]
                        if (isinstance(last_edge.dst, nodes.EntryNode)
                                and last_edge.dst_conn
                                and not last_edge.dst_conn.startswith('IN_')
                                and sdict[last_edge.dst] is None):
                            break
                    else:
                        input_nodes.append((node.data, node.desc(sdfg)))
                        input_names.add(node.data)
                if (state.in_degree(node) > 0
                        and node.data not in output_names):
                    output_nodes.append((node.data, node.desc(sdfg)))
                    output_names.add(node.data)
            elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                if not isinstance(node,
                                  (nodes.LibraryNode, nodes.NestedSDFG)):
                    global_code_nodes[i].append(node)

        # Input nodes may also be nodes with WCR memlets and no identity
        for e in state.edges():
            if e.data.wcr is not None:
                if (e.data.data not in input_names
                        and not sdfg.arrays[e.data.data].transient):
                    input_nodes.append(
                        (e.data.data, sdfg.arrays[e.data.data]))
                    input_names.add(e.data.data)

    # Captured before the copy-in/copy-out states are added
    start_state = sdfg.start_state
    end_states = sdfg.sink_nodes()

    #######################################################
    # Step 1: Create cloned GPU arrays and replace originals

    cloned_arrays = {}
    for inodename, inode in set(input_nodes):
        if isinstance(inode, data.Scalar):  # Scalars can remain on host
            continue
        if inode.storage == dtypes.StorageType.GPU_Global:
            continue
        newdesc = inode.clone()
        newdesc.storage = dtypes.StorageType.GPU_Global
        newdesc.transient = True
        name = sdfg.add_datadesc('gpu_' + inodename,
                                 newdesc,
                                 find_new_name=True)
        cloned_arrays[inodename] = name

    for onodename, onode in set(output_nodes):
        if onodename in cloned_arrays:
            continue
        if onode.storage == dtypes.StorageType.GPU_Global:
            continue
        newdesc = onode.clone()
        newdesc.storage = dtypes.StorageType.GPU_Global
        newdesc.transient = True
        name = sdfg.add_datadesc('gpu_' + onodename,
                                 newdesc,
                                 find_new_name=True)
        cloned_arrays[onodename] = name

    # Replace nodes
    for state in sdfg.nodes():
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data in cloned_arrays):
                node.data = cloned_arrays[node.data]

    # Replace memlets
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data in cloned_arrays:
                edge.data.data = cloned_arrays[edge.data.data]

    #######################################################
    # Step 2: Create copy-in state
    excluded_copyin = self.exclude_copyin.split(',')

    copyin_state = sdfg.add_state(sdfg.label + '_copyin')
    sdfg.add_edge(copyin_state, start_state, sd.InterstateEdge())

    for nname, desc in dtypes.deduplicate(input_nodes):
        if nname in excluded_copyin or nname not in cloned_arrays:
            continue
        src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(cloned_arrays[nname],
                                     debuginfo=desc.debuginfo)
        copyin_state.add_node(src_array)
        copyin_state.add_node(dst_array)
        copyin_state.add_nedge(
            src_array, dst_array,
            memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

    #######################################################
    # Step 3: Create copy-out state
    excluded_copyout = self.exclude_copyout.split(',')

    copyout_state = sdfg.add_state(sdfg.label + '_copyout')
    for state in end_states:
        sdfg.add_edge(state, copyout_state, sd.InterstateEdge())

    for nname, desc in dtypes.deduplicate(output_nodes):
        if nname in excluded_copyout or nname not in cloned_arrays:
            continue
        src_array = nodes.AccessNode(cloned_arrays[nname],
                                     debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        copyout_state.add_node(src_array)
        copyout_state.add_node(dst_array)
        copyout_state.add_nedge(
            src_array, dst_array,
            memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

    #######################################################
    # Step 4: Modify transient data storage

    # Storage types that are left untouched
    gpu_storage = [
        dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared,
        dtypes.StorageType.CPU_Pinned
    ]

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node,
                          nodes.AccessNode) and node.desc(sdfg).transient:
                nodedesc = node.desc(sdfg)

                # Special case: nodes that lead to dynamic map ranges must
                # stay on host
                if any(
                        isinstance(
                            state.memlet_path(e)[-1].dst, nodes.EntryNode)
                        for e in state.out_edges(node)):
                    continue

                if sdict[node] is None and nodedesc.storage not in gpu_storage:
                    # NOTE: the cloned arrays match too but it's the same
                    # storage so we don't care
                    nodedesc.storage = dtypes.StorageType.GPU_Global

                    # Try to move allocation/deallocation out of loops
                    if (self.toplevel_trans
                            and not isinstance(nodedesc, data.Stream)):
                        nodedesc.lifetime = dtypes.AllocationLifetime.SDFG
                elif nodedesc.storage not in gpu_storage:
                    # Make internal transients registers
                    if self.register_trans:
                        nodedesc.storage = dtypes.StorageType.Register

    #######################################################
    # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

    excluded_tasklets = self.exclude_tasklets.split(',')

    for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
        for gcode in gcodes:
            if gcode.label in excluded_tasklets:
                continue
            # Create map and connectors
            me, mx = state.add_map(gcode.label + '_gmap',
                                   {gcode.label + '__gmapi': '0:1'},
                                   schedule=dtypes.ScheduleType.GPU_Device)
            # Store in/out edges in lists so that they don't get corrupted
            # when they are removed from the graph
            in_edges = list(state.in_edges(gcode))
            out_edges = list(state.out_edges(gcode))
            me.in_connectors = {('IN_' + e.dst_conn): None
                                for e in in_edges}
            me.out_connectors = {('OUT_' + e.dst_conn): None
                                 for e in in_edges}
            mx.in_connectors = {('IN_' + e.src_conn): None
                                for e in out_edges}
            mx.out_connectors = {('OUT_' + e.src_conn): None
                                 for e in out_edges}

            # Create memlets through map
            for e in in_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                               e.data)
                state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                               e.data)
            for e in out_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                               e.data)
                state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                               e.data)

            # Map without inputs
            if not in_edges:
                state.add_nedge(me, gcode, memlet.Memlet())

    #######################################################
    # Step 6: Change all top-level maps and library nodes to GPU schedule

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, (nodes.EntryNode, nodes.LibraryNode)):
                if sdict[node] is None:
                    node.schedule = dtypes.ScheduleType.GPU_Device
                elif self.sequential_innermaps:
                    node.schedule = dtypes.ScheduleType.Sequential

    #######################################################
    # Step 7: Introduce copy-out if data used in outgoing interstate edges

    # Iterate over a copy, since interim states are added to the SDFG
    for state in list(sdfg.nodes()):
        arrays_used = set()
        for e in sdfg.out_edges(state):
            # Used arrays = intersection between symbols and cloned arrays
            arrays_used.update(
                set(e.data.free_symbols) & set(cloned_arrays.keys()))

        # Create a state and copy out used arrays
        if len(arrays_used) > 0:
            co_state = sdfg.add_state(state.label + '_icopyout')

            # Reconnect outgoing edges to after interim copyout state (one
            # call re-sources every outgoing edge of ``state``)
            sdutil.change_edge_src(sdfg, state, co_state)
            # Add unconditional edge to interim state
            sdfg.add_edge(state, co_state, sd.InterstateEdge())

            # Add copy-out nodes
            for nname in arrays_used:
                desc = sdfg.arrays[nname]
                src_array = nodes.AccessNode(cloned_arrays[nname],
                                             debuginfo=desc.debuginfo)
                dst_array = nodes.AccessNode(nname,
                                             debuginfo=desc.debuginfo)
                co_state.add_node(src_array)
                co_state.add_node(dst_array)
                co_state.add_nedge(
                    src_array, dst_array,
                    memlet.Memlet.from_array(dst_array.data,
                                             dst_array.desc(sdfg)))

    #######################################################
    # Step 8: Strict transformations
    if not self.strict_transform:
        return

    # Apply strict state fusions greedily.
    sdfg.apply_strict_transformations()
def apply(self, sdfg: sd.SDFG):
    """
    Converts the matched for-loop into a map scope.

    The contents of the loop body state are wrapped in a map over the loop
    range. Assignments on the guard-to-body edge are first moved into a
    nested SDFG. Afterwards the loop's exit-condition edge and back-edge are
    removed, the iteration variable assignments are dropped, and the body is
    connected directly to the state after the loop.

    :param sdfg: The SDFG that contains the loop.
    :note: Operates in-place on the SDFG.
    """
    # Obtain loop information
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

    # NOTE(review): the explicit ``== True`` presumably guards against
    # ``step`` being symbolic (comparison not reducible to a plain bool) --
    # confirm before simplifying
    if (step < 0) == True:
        # If step is negative, we have to flip start and end to produce a
        # correct map with a positive increment
        start, end, step = end, start, -step

    # If necessary, make a nested SDFG with assignments
    isedge = sdfg.edges_between(guard, body)[0]
    symbols_to_remove = set()
    if len(isedge.data.assignments) > 0:
        nsdfg = helpers.nest_state_subgraph(
            sdfg, body, gr.SubgraphView(body, body.nodes()))
        # Make everything the edge reads available inside the nested SDFG
        for sym in isedge.data.free_symbols:
            if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
                continue
            if sym in sdfg.symbols:
                nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym)
                nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym])
            elif sym in sdfg.arrays:
                if sym in nsdfg.sdfg.arrays:
                    raise NotImplementedError
                rnode = body.add_read(sym)
                nsdfg.add_in_connector(sym)
                desc = copy.deepcopy(sdfg.arrays[sym])
                desc.transient = False
                nsdfg.sdfg.add_datadesc(sym, desc)
                body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym))

        # Move the assignments onto a new edge in front of the nested SDFG's
        # first state
        nstate = nsdfg.sdfg.node(0)
        init_state = nsdfg.sdfg.add_state_before(nstate)
        nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
        nisedge.data.assignments = isedge.data.assignments
        symbols_to_remove = set(nisedge.data.assignments.keys())
        # Symbols assigned inside must not also be mapped from the outside
        for k in nisedge.data.assignments.keys():
            if k in nsdfg.symbol_mapping:
                del nsdfg.symbol_mapping[k]
        isedge.data.assignments = {}

    # Sources and sinks are collected before the map nodes are added
    source_nodes = body.source_nodes()
    sink_nodes = body.sink_nodes()

    map = nodes.Map(body.label + "_map", [itervar], [(start, end, step)])
    entry = nodes.MapEntry(map)
    exit = nodes.MapExit(map)
    body.add_node(entry)
    body.add_node(exit)

    # If the map uses symbols from data containers, instantiate reads
    containers_to_read = entry.free_symbols & sdfg.arrays.keys()
    for rd in containers_to_read:
        # We are guaranteed that this is always a scalar, because
        # can_be_applied makes sure there are no sympy functions in each of
        # the loop expressions
        access_node = body.add_read(rd)
        body.add_memlet_path(access_node,
                             entry,
                             dst_conn=rd,
                             memlet=memlet.Memlet(rd))

    # Reroute all memlets through the entry and exit nodes
    for n in source_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.out_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(entry,
                                   e.dst,
                                   n,
                                   e.data,
                                   internal_connector=e.dst_conn)
        else:
            # Non-data sources are attached to the entry with an empty memlet
            body.add_nedge(entry, n, memlet.Memlet())
    for n in sink_nodes:
        if isinstance(n, nodes.AccessNode):
            for e in body.in_edges(n):
                body.remove_edge(e)
                body.add_edge_pair(exit,
                                   e.src,
                                   n,
                                   e.data,
                                   internal_connector=e.src_conn)
        else:
            # Non-data sinks are attached to the exit with an empty memlet
            body.add_nedge(n, exit, memlet.Memlet())

    # Get rid of the loop exit condition edge
    after_edge = sdfg.edges_between(guard, after)[0]
    sdfg.remove_edge(after_edge)

    # Remove the assignment on the edge to the guard
    for e in sdfg.in_edges(guard):
        if itervar in e.data.assignments:
            del e.data.assignments[itervar]

    # Remove the condition on the entry edge
    condition_edge = sdfg.edges_between(guard, body)[0]
    condition_edge.data.condition = CodeBlock("1")

    # Get rid of backedge to guard
    sdfg.remove_edge(sdfg.edges_between(body, guard)[0])

    # Route body directly to after state, maintaining any other assignments
    # it might have had
    sdfg.add_edge(
        body, after,
        sd.InterstateEdge(assignments=after_edge.data.assignments))

    # If this had made the iteration variable a free symbol, we can remove
    # it from the SDFG symbols
    if itervar in sdfg.free_symbols:
        sdfg.remove_symbol(itervar)
    for sym in symbols_to_remove:
        if helpers.is_symbol_unused(sdfg, sym):
            sdfg.remove_symbol(sym)
def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
    """
    Splits the matched map into one map per component of its contents.

    Two cases are handled, selected by ``self.expr_index``:

    * ``0``: the map contains a subgraph; fission is performed on the map's
      scope subgraph in ``graph``.
    * ``1``: the map contains a nested SDFG; fission is performed on every
      state of that nested SDFG, and the nested SDFG node is reconnected
      around the (removed) outer map.

    Each component is surrounded by a new ``_fission`` map with the outer
    map's range, data on component borders is extended by the map
    dimensions, and the outer map is finally removed.

    :param graph: The state in which the map resides.
    :param sdfg: The SDFG that contains the state.
    :note: Operates in-place; comments below state that memlet propagation is
           expected to run afterwards to correct outer memlets.
    """
    map_entry = self.map_entry
    map_exit = graph.exit_node(map_entry)
    nsdfg_node: Optional[nodes.NestedSDFG] = None

    # Obtain subgraph to perform fission to
    if self.expr_index == 0:  # Map with subgraph
        subgraphs = [(graph,
                      graph.scope_subgraph(map_entry,
                                           include_entry=False,
                                           include_exit=False))]
        parent = sdfg
    else:  # Map with nested SDFG
        nsdfg_node = self.nested_sdfg
        subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
        parent = nsdfg_node.sdfg
    # Arrays that already had the map dimensions prepended (shared across
    # all processed subgraphs)
    modified_arrays = set()

    # Get map information
    outer_map: nodes.Map = map_entry.map
    mapsize = outer_map.range.size()

    # Add new symbols from outer map to nested SDFG
    if self.expr_index == 1:
        map_syms = outer_map.range.free_symbols
        for edge in graph.out_edges(map_entry):
            if edge.data.data:
                map_syms.update(edge.data.subset.free_symbols)
        for edge in graph.in_edges(map_exit):
            if edge.data.data:
                map_syms.update(edge.data.subset.free_symbols)
        for sym in map_syms:
            symname = str(sym)
            if symname in outer_map.params:
                continue
            if symname not in nsdfg_node.symbol_mapping.keys():
                nsdfg_node.symbol_mapping[symname] = sym
                nsdfg_node.sdfg.symbols[
                    symname] = graph.symbols_defined_at(
                        nsdfg_node)[symname]

        # Remove map symbols from nested mapping
        for name in outer_map.params:
            if str(name) in nsdfg_node.symbol_mapping:
                del nsdfg_node.symbol_mapping[str(name)]
            if str(name) in nsdfg_node.sdfg.symbols:
                del nsdfg_node.sdfg.symbols[str(name)]

    for state, subgraph in subgraphs:
        # Pairs of (component input node, component output node)
        components = MapFission._components(subgraph)
        sources = subgraph.source_nodes()
        sinks = subgraph.sink_nodes()

        # Collect external edges
        if self.expr_index == 0:
            external_edges_entry = list(state.out_edges(map_entry))
            external_edges_exit = list(state.in_edges(map_exit))
        else:
            # In a nested SDFG, external edges are those that touch
            # non-transient access nodes
            external_edges_entry = [
                e for e in subgraph.edges()
                if (isinstance(e.src, nodes.AccessNode)
                    and not nsdfg_node.sdfg.arrays[e.src.data].transient)
            ]
            external_edges_exit = [
                e for e in subgraph.edges()
                if (isinstance(e.dst, nodes.AccessNode)
                    and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
            ]

        # Map external edges to outer memlets
        edge_to_outer = {}
        for edge in external_edges_entry:
            if self.expr_index == 0:
                # Subgraphs use the corresponding outer map edges
                path = state.memlet_path(edge)
                eindex = path.index(edge)
                edge_to_outer[edge] = path[eindex - 1]
            else:
                # Nested SDFGs use the internal map edges of the node
                outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                  if e.dst_conn == edge.src.data)
                edge_to_outer[edge] = outer_edge

        for edge in external_edges_exit:
            if self.expr_index == 0:
                path = state.memlet_path(edge)
                eindex = path.index(edge)
                edge_to_outer[edge] = path[eindex + 1]
            else:
                # Nested SDFGs use the internal map edges of the node
                outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                  if e.src_conn == edge.dst.data)
                edge_to_outer[edge] = outer_edge

        # Collect all border arrays and code->code edges
        arrays = MapFission._border_arrays(
            nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
            subgraph)
        # Data name -> edges that lead from a component output directly
        # into a code node
        scalars = defaultdict(list)
        for _, component_out in components:
            for e in subgraph.out_edges(component_out):
                if isinstance(e.dst, nodes.CodeNode):
                    scalars[e.data.data].append(e)

        # Create new arrays for scalars
        for scalar, edges in scalars.items():
            desc = parent.arrays[scalar]
            del parent.arrays[scalar]
            # The replacement transient has one element per map iteration
            name, newdesc = parent.add_transient(
                scalar,
                mapsize,
                desc.dtype,
                desc.storage,
                lifetime=desc.lifetime,
                debuginfo=desc.debuginfo,
                allow_conflicts=desc.allow_conflicts,
                find_new_name=True)

            # Add extra nodes in component boundaries
            for edge in edges:
                anode = state.add_access(name)
                sbs = subsets.Range.from_string(','.join(outer_map.params))
                # Offset memlet by map range begin (to fit the transient)
                sbs.offset([r[0] for r in outer_map.range], True)
                state.add_edge(
                    edge.src, edge.src_conn, anode, None,
                    mm.Memlet.simple(
                        name,
                        sbs,
                        num_accesses=outer_map.range.num_elements()))
                state.add_edge(
                    anode, None, edge.dst, edge.dst_conn,
                    mm.Memlet.simple(
                        name,
                        sbs,
                        num_accesses=outer_map.range.num_elements()))
                state.remove_edge(edge)

        # Add extra maps around components
        # NOTE(review): ``new_map_entries`` is only appended to and not read
        # again within this method
        new_map_entries = []
        for component_in, component_out in components:
            me, mx = state.add_map(outer_map.label + '_fission',
                                   [(p, '0:1') for p in outer_map.params],
                                   outer_map.schedule,
                                   unroll=outer_map.unroll,
                                   debuginfo=outer_map.debuginfo)

            # Add dynamic input connectors
            for conn in map_entry.in_connectors:
                if not conn.startswith('IN_'):
                    me.add_in_connector(conn)

            # The placeholder '0:1' ranges are replaced by a copy of the
            # outer map's range
            me.map.range = dcpy(outer_map.range)
            new_map_entries.append(me)

            # Reconnect edges through new map
            for e in state.in_edges(component_in):
                state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                # Reconnect inner edges at source directly to external nodes
                if self.expr_index == 0 and e in external_edges_entry:
                    state.add_edge(edge_to_outer[e].src,
                                   edge_to_outer[e].src_conn, me, None,
                                   dcpy(edge_to_outer[e].data))
                else:
                    state.add_edge(e.src, e.src_conn, me, None,
                                   dcpy(e.data))
                state.remove_edge(e)
            # Empty memlet edge in nested SDFGs
            if state.in_degree(component_in) == 0:
                state.add_edge(me, None, component_in, None, mm.Memlet())

            for e in state.out_edges(component_out):
                state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                # Reconnect inner edges at sink directly to external nodes
                if self.expr_index == 0 and e in external_edges_exit:
                    state.add_edge(mx, None, edge_to_outer[e].dst,
                                   edge_to_outer[e].dst_conn,
                                   dcpy(edge_to_outer[e].data))
                else:
                    state.add_edge(mx, None, e.dst, e.dst_conn,
                                   dcpy(e.data))
                state.remove_edge(e)
            # Empty memlet edge in nested SDFGs
            if state.out_degree(component_out) == 0:
                state.add_edge(component_out, None, mx, None, mm.Memlet())

        # Connect other sources/sinks not in components (access nodes)
        # directly to external nodes
        if self.expr_index == 0:
            for node in sources:
                if isinstance(node, nodes.AccessNode):
                    for edge in state.in_edges(node):
                        outer_edge = edge_to_outer[edge]
                        memlet = dcpy(edge.data)
                        # Prepend the full map range to the copied subset
                        memlet.subset = subsets.Range(
                            outer_map.range.ranges + memlet.subset.ranges)
                        state.add_edge(outer_edge.src, outer_edge.src_conn,
                                       edge.dst, edge.dst_conn, memlet)

            for node in sinks:
                if isinstance(node, nodes.AccessNode):
                    for edge in state.out_edges(node):
                        outer_edge = edge_to_outer[edge]
                        state.add_edge(edge.src, edge.src_conn,
                                       outer_edge.dst, outer_edge.dst_conn,
                                       dcpy(outer_edge.data))

        # Augment arrays by prepending map dimensions
        for array in arrays:
            if array in modified_arrays:
                continue
            desc = parent.arrays[array]
            if isinstance(
                    desc,
                    dt.Scalar):  # Scalar needs to be augmented to an array
                desc = dt.Array(desc.dtype, desc.shape, desc.transient,
                                desc.allow_conflicts, desc.storage,
                                desc.location, desc.strides, desc.offset,
                                False, desc.lifetime, 0, desc.debuginfo,
                                desc.total_size, desc.start_offset)
                parent.arrays[array] = desc
            # Each prepended dimension gets the current total size as its
            # stride, innermost map dimension first
            for sz in reversed(mapsize):
                desc.strides = [desc.total_size] + list(desc.strides)
                desc.total_size = desc.total_size * sz

            desc.shape = mapsize + list(desc.shape)
            desc.offset = [0] * len(mapsize) + list(desc.offset)
            modified_arrays.add(array)

        # Fill scope connectors so that memlets can be tracked below
        state.fill_scope_connectors()

        # Correct connectors and memlets in nested SDFGs to account for
        # missing outside map
        if self.expr_index == 1:
            to_correct = ([(e, e.src) for e in external_edges_entry] +
                          [(e, e.dst) for e in external_edges_exit])
            corrected_nodes = set()
            for edge, node in to_correct:
                if isinstance(node, nodes.AccessNode):
                    # Every access node is corrected only once
                    if node in corrected_nodes:
                        continue
                    corrected_nodes.add(node)

                    outer_edge = edge_to_outer[edge]
                    desc = parent.arrays[node.data]

                    # Modify shape of internal array to match outer one
                    outer_desc = sdfg.arrays[outer_edge.data.data]
                    if not isinstance(desc, dt.Scalar):
                        desc.shape = outer_desc.shape
                    if isinstance(desc, dt.Array):
                        desc.strides = outer_desc.strides
                        desc.total_size = outer_desc.total_size

                    # Inside the nested SDFG, offset all memlets to include
                    # the offsets from within the map.
                    # NOTE: Relies on propagation to fix outer memlets
                    for internal_edge in state.all_edges(node):
                        for e in state.memlet_tree(internal_edge):
                            e.data.subset.offset(desc.offset, False)
                            e.data.subset = helpers.unsqueeze_memlet(
                                e.data, outer_edge.data).subset

                    # Only after offsetting memlets we can modify the
                    # overall offset
                    if isinstance(desc, dt.Array):
                        desc.offset = outer_desc.offset

        # Fill in memlet trees for border transients
        # NOTE: Memlet propagation should run to correct the outer edges
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in arrays:
                for edge in state.all_edges(node):
                    for e in state.memlet_tree(edge):
                        # Prepend map dimensions to memlet
                        e.data.subset = subsets.Range(
                            [(pystr_to_symbolic(d) - r[0],
                              pystr_to_symbolic(d) - r[0], 1)
                             for d, r in zip(outer_map.params,
                                             outer_map.range)] +
                            e.data.subset.ranges)

    # If nested SDFG, reconnect nodes around map and modify memlets
    if self.expr_index == 1:
        for edge in graph.in_edges(map_entry):
            # Only regular ``IN_`` connectors are forwarded
            if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                continue

            # Modify edge coming into nested SDFG to include entire array
            desc = sdfg.arrays[edge.data.data]
            edge.data.subset = subsets.Range.from_array(desc)
            edge.data.num_accesses = edge.data.subset.num_elements()

            # Find matching edge inside map (connector names compared
            # without their ``OUT_`` / ``IN_`` prefixes)
            inner_edge = next(
                e for e in graph.out_edges(map_entry)
                if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
            graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                           inner_edge.dst_conn, dcpy(edge.data))

        for edge in graph.out_edges(map_exit):
            # Modify edge coming out of nested SDFG to include entire array
            desc = sdfg.arrays[edge.data.data]
            edge.data.subset = subsets.Range.from_array(desc)

            # Find matching edge inside map
            inner_edge = next(e for e in graph.in_edges(map_exit)
                              if e.dst_conn[3:] == edge.src_conn[4:])
            graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                           edge.dst_conn, dcpy(edge.data))

    # Remove outer map
    graph.remove_nodes_from([map_entry, map_exit])
def _promote_tasklet_input(tasklet, conn, symname):
    """
    Rewrites the code of ``tasklet`` so that the former input connector
    ``conn`` is referred to by the symbol ``symname``.

    Python tasklets have every statement visited by a ``TaskletPromoter``;
    C++ tasklets get a whole-word textual replacement. Tasklets in any other
    language are left untouched.
    """
    language = tasklet.language
    if language is dtypes.Language.Python:
        promoter = TaskletPromoter(conn, symname)
        for statement in tasklet.code.code:
            promoter.visit(statement)
    elif language is dtypes.Language.CPP:
        # Replace whole-word matches (identifiers) in code
        identifier = r'\b%s\b' % re.escape(conn)
        tasklet.code.code = re.sub(identifier, symname, tasklet.code.as_string)


def _promote_access_node_input(state, access_node, edge, symname):
    """
    Replaces a scalar-to-access-node copy by a tasklet that writes the symbol
    value into ``access_node``.

    :param state: The state that contains ``access_node``.
    :param access_node: Destination access node of the removed edge.
    :param edge: The (already removed) edge; its destination connector and
                 destination subset are reused for the new edge.
    :param symname: Name of the symbol to assign.
    :return: The newly added tasklet.
    """
    assign = state.add_tasklet('symassign', {}, {'__out'},
                               '__out = %s' % symname)
    write_memlet = mm.Memlet(data=access_node.data,
                             subset=edge.data.dst_subset,
                             volume=1)
    state.add_edge(assign, '__out', access_node, edge.dst_conn, write_memlet)
    return assign


def _promote_nested_sdfg_input(sdfg, scalar_name, nsdfg_node, conn, symname):
    """
    Turns the input connector ``conn`` of a nested SDFG node into a symbol.

    :param sdfg: The SDFG containing ``nsdfg_node``; the symbol type is taken
                 from the descriptor of ``scalar_name`` in this SDFG.
    :param scalar_name: Name of the promoted scalar in ``sdfg``.
    :param nsdfg_node: The nested SDFG node being modified.
    :param conn: Name of the input connector (and inner container) to remove.
    :param symname: Name of the replacement symbol in the outer SDFG.
    """
    # Pick a name that clashes with neither a symbol nor a container of the
    # nested SDFG, appending _1, _2, ... to the outer name as necessary
    tmp_symname = symname
    val = 0
    while (tmp_symname in nsdfg_node.sdfg.symbols
           or tmp_symname in nsdfg_node.sdfg.arrays):
        val += 1
        tmp_symname = f'{symname}_{val}'

    # Descend recursively to remove the scalar inside the nested SDFG
    remove_scalar_reads(nsdfg_node.sdfg, {conn: tmp_symname})

    for inner_edge in nsdfg_node.sdfg.edges():
        edge_data = inner_edge.data
        edge_data.replace(conn, tmp_symname)
        # Subscripted occurrences in assignments are stripped as well
        for target, expression in edge_data.assignments.items():
            parsed = ast.parse(expression)
            parsed = astutils.RemoveSubscripts({tmp_symname}).visit(parsed)
            edge_data.assignments[target] = astutils.unparse(parsed)
        edge_data.replace(tmp_symname + '[0]', tmp_symname)

    # Drop the inner container and connector, then register the symbol and
    # map it to the outer symbol
    nsdfg_node.sdfg.remove_data(conn, validate=False)
    nsdfg_node.remove_in_connector(conn)
    nsdfg_node.sdfg.symbols[tmp_symname] = sdfg.arrays[scalar_name].dtype
    nsdfg_node.symbol_mapping[tmp_symname] = symname


def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]):
    """
    Strips every read of the given promoted scalars out of an SDFG.

    For each access node whose data name is a key of ``array_names``, all
    edges of the memlet trees leaving that node are removed together with
    their connectors, and each consumer is rewritten to use the replacement
    symbol instead:

      * tasklets have their code rewritten (Python and C++);
      * access nodes are fed by a new ``symassign`` tasklet that writes the
        symbol;
      * nested SDFGs are processed recursively, and the connector is turned
        into a symbol of the nested SDFG;
      * scope entry/exit nodes need no rewriting and are skipped.

    Access nodes left without any edges are removed from their state.

    :param sdfg: The SDFG to operate on.
    :param array_names: Mapping from the name of each scalar to replace to
                        the name of its replacement symbol.
    :raises ValueError: If a consumer has a node type other than the ones
                        listed above.
    :note: Operates in-place on the SDFG.
    """
    for state in sdfg.nodes():
        scalar_nodes = []
        for candidate in state.nodes():
            if not isinstance(candidate, nodes.AccessNode):
                continue
            if candidate.data in array_names:
                scalar_nodes.append(candidate)

        for scalar_node in scalar_nodes:
            symname = array_names[scalar_node.data]
            for root_edge in state.out_edges(scalar_node):
                for edge in state.memlet_tree(root_edge):
                    # Step 3.1: detach the consumer from the scalar
                    consumer = edge.dst
                    state.remove_edge_and_connectors(edge)

                    if isinstance(consumer, nodes.Tasklet):
                        # Step 3.2
                        _promote_tasklet_input(consumer, edge.dst_conn,
                                               symname)
                    elif isinstance(consumer, nodes.AccessNode):
                        # Step 3.3; the new tasklet takes the consumer's
                        # place in the reconnection check below
                        consumer = _promote_access_node_input(
                            state, consumer, edge, symname)
                    elif isinstance(consumer, nodes.NestedSDFG):
                        _promote_nested_sdfg_input(sdfg, scalar_node.data,
                                                   consumer, edge.dst_conn,
                                                   symname)
                    elif isinstance(consumer,
                                    (nodes.EntryNode, nodes.ExitNode)):
                        # Skip
                        continue
                    else:
                        raise ValueError(
                            'Node type "%s" not supported for promotion' %
                            type(consumer).__name__)

                    # If nodes were disconnected, reconnect with empty memlet
                    if isinstance(edge.src, nodes.EntryNode):
                        if len(state.edges_between(edge.src, consumer)) == 0:
                            state.add_nedge(edge.src, consumer, mm.Memlet())

        # Remove newly-isolated nodes
        isolated = []
        for scalar_node in scalar_nodes:
            if len(state.all_edges(scalar_node)) == 0:
                isolated.append(scalar_node)
        state.remove_nodes_from(isolated)
def apply(self, _, sdfg):
    """
    Redirects the data accessed by the matched state to FPGA global memory.

    Source and sink access nodes of the state whose containers are either
    non-transient or shared transients, plus outer access nodes that are
    reached through write-conflict-resolution (WCR) memlets, are treated as
    the state's external data. For each such array a transient
    ``fpga_``-prefixed array with ``FPGA_Global`` storage is registered in
    the SDFG. A ``pre_`` state copying the inputs to those arrays is
    inserted before the state, and a ``post_`` state copying the outputs
    back is inserted after it. Non-array containers (e.g., streams) are
    skipped.

    :param _: Unused.
    :param sdfg: The SDFG that contains the state to transform.
    """
    state = self.state
    shared_transients = set(sdfg.shared_transients())

    def is_visible_outside(candidate):
        # Only access nodes of non-transient or shared-transient containers
        # are relevant outside this FPGA kernel
        if not isinstance(candidate, nodes.AccessNode):
            return False
        return (not sdfg.arrays[candidate.data].transient
                or candidate.data in shared_transients)

    input_nodes = [
        n for n in sdutil.find_source_nodes(state) if is_visible_outside(n)
    ]
    output_nodes = [
        n for n in sdutil.find_sink_nodes(state) if is_visible_outside(n)
    ]

    # Host container name -> result of sdfg.add_array for its FPGA copy
    fpga_data = {}

    def register_fpga_array(host_node, host_desc):
        # Create the FPGA-global counterpart of a host array and move the
        # location information from the host descriptor over to it
        device_array = sdfg.add_array(
            'fpga_' + host_node.data,
            host_desc.shape,
            host_desc.dtype,
            transient=True,
            storage=dtypes.StorageType.FPGA_Global,
            allow_conflicts=host_desc.allow_conflicts,
            strides=host_desc.strides,
            offset=host_desc.offset)
        device_array[1].location = copy.copy(host_desc.location)
        host_desc.location.clear()
        fpga_data[host_node.data] = device_array

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for inner_node, container in state.all_nodes_recursive():
        if isinstance(container, dace.SDFG):
            parent_sdfg[inner_node] = container
        if not isinstance(inner_node, dace.sdfg.nodes.AccessNode):
            continue
        for in_edge in container.in_edges(inner_node):
            if in_edge.data.wcr is None:
                continue
            trace = dace.sdfg.trace_nested_access(inner_node, container,
                                                  parent_sdfg[container])
            # Find the accessed node in our scope, i.e., the first trace
            # entry that belongs to this state and SDFG and has an outer
            # node
            outer_node = None
            for node_trace, _memlet_trace, state_trace, sdfg_trace in trace:
                if state_trace == state and sdfg_trace == sdfg:
                    _inner, traced_node = node_trace
                    if traced_node is not None:
                        outer_node = traced_node
                        break
            if outer_node is None:
                # This does not trace back to the current state, so we
                # don't care
                continue
            input_nodes.append(outer_node)
            wcr_input_nodes.add(outer_node)

    if input_nodes:
        # Create pre_state, which copies host inputs to the FPGA arrays
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for host_node in input_nodes:
            if not isinstance(host_node, dace.sdfg.nodes.AccessNode):
                continue
            host_desc = host_node.desc(sdfg)
            if not isinstance(host_desc, dace.data.Array):
                # TODO: handle streams
                continue

            is_wcr = host_node in wcr_input_nodes
            if host_node.data not in fpga_data and not is_wcr:
                register_fpga_array(host_node, host_desc)

            copy_src = pre_state.add_read(host_node.data)
            copy_dst = pre_state.add_write('fpga_' + host_node.data)
            pre_state.add_edge(
                copy_src, None, copy_dst, None,
                memlet.Memlet(data=host_node.data,
                              subset=subsets.Range.from_array(host_desc)))

            # WCR nodes stay in the state; plain source nodes are swapped
            # for a read of the FPGA array
            if not is_wcr:
                device_node = state.add_read('fpga_' + host_node.data)
                sdutil.change_edge_src(state, host_node, device_node)
                state.remove_node(host_node)

        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:
        # Create post_state, which copies the FPGA arrays back to the host
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for host_node in output_nodes:
            if not isinstance(host_node, dace.sdfg.nodes.AccessNode):
                continue
            host_desc = host_node.desc(sdfg)
            if not isinstance(host_desc, dace.data.Array):
                # TODO: handle streams
                continue

            if host_node.data not in fpga_data:
                register_fpga_array(host_node, host_desc)

            copy_dst = post_state.add_write(host_node.data)
            copy_src = post_state.add_read('fpga_' + host_node.data)
            post_state.add_edge(
                copy_src, None, copy_dst, None,
                memlet.Memlet(f"fpga_{node.data}", None,
                              subsets.Range.from_array(host_desc))
                if False else memlet.Memlet(
                    'fpga_' + host_node.data, None,
                    subsets.Range.from_array(host_desc)))

            device_node = state.add_write('fpga_' + host_node.data)
            sdutil.change_edge_dest(state, host_node, device_node)
            state.remove_node(host_node)

        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    # Memlets in the state that still name a host container with an FPGA
    # copy are pointed at the 'fpga_'-prefixed container instead
    for _src, _src_conn, _dst, _dst_conn, edge_memlet in state.edges():
        if edge_memlet.data is None:
            continue
        if edge_memlet.data in fpga_data:
            edge_memlet.data = 'fpga_' + edge_memlet.data

    # NOTE(review): fpga_update is defined elsewhere in this file;
    # presumably it propagates the storage change into nested scopes and
    # SDFGs -- confirm against its definition
    fpga_update(sdfg, state, 0)