from dace import SDFG, InterstateEdge


def get_vertical_loop_section_sdfg(section: "VerticalLoopSection") -> SDFG:
    """Build an SDFG with one state per horizontal execution, chained in order."""
    from gtc.dace.nodes import HorizontalExecutionLibraryNode

    sdfg = SDFG("VerticalLoopSection_" + str(id(section)))
    old_state = sdfg.add_state("start_state", is_start_state=True)
    for he in section.horizontal_executions:
        new_state = sdfg.add_state("HorizontalExecution_" + str(id(he)) + "_state")
        # Unconditional interstate edges serialize the horizontal executions.
        sdfg.add_edge(old_state, new_state, InterstateEdge())
        new_state.add_node(HorizontalExecutionLibraryNode(oir_node=he))
        old_state = new_state
    return sdfg
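
# The function above serializes work by chaining one state per node. A minimal
# standalone sketch of the same chaining pattern (hypothetical payload, assumes
# only that dace is installed):
import dace

chain_sdfg = dace.SDFG("state_chain_example")
prev = chain_sdfg.add_state("start_state", is_start_state=True)
for i in range(3):
    nxt = chain_sdfg.add_state(f"step_{i}")
    # An unconditional interstate edge forces step_{i} to run after its predecessor.
    chain_sdfg.add_edge(prev, nxt, dace.InterstateEdge())
    prev = nxt
assert len(chain_sdfg.states()) == 4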

import dace
import dace.memlet as mem
import numpy as np

# Assumed module context: the set of storage types treated as FPGA-resident
# (and hence created as transients) by the helper below.
_FPGA_STORAGE_TYPES = {
    dace.StorageType.FPGA_Global,
    dace.StorageType.FPGA_Local,
    dace.StorageType.FPGA_Registers,
    dace.StorageType.FPGA_ShiftRegister,
}


def mkc(sdfg: dace.SDFG,
        state_before,
        src_name,
        dst_name,
        src_storage=None,
        dst_storage=None,
        src_shape=None,
        dst_shape=None,
        copy_expr=None,
        src_loc=None,
        dst_loc=None):
    """MaKe_Copy helper: create and append a state that performs exactly one copy.

    If an array with the given name already exists on the SDFG, the existing
    array is reused and all newly passed shape/storage/location values are
    ignored.
    """
    if copy_expr is None:
        copy_expr = src_name
    if state_before is None:
        state = sdfg.add_state(is_start_state=True)
    else:
        state = sdfg.add_state_after(state_before)

    def mkarray(name, shape, storage, loc):
        if name in sdfg.arrays:
            return sdfg.arrays[name]
        # Arrays in FPGA storage are created as transients.
        is_transient = storage in _FPGA_STORAGE_TYPES
        arr = sdfg.add_array(name, shape, dace.int32, storage, transient=is_transient)
        if loc is not None:
            arr[1].location["memorytype"] = loc[0]
            arr[1].location["bank"] = loc[1]
        return arr

    mkarray(src_name, src_shape, src_storage, src_loc)
    mkarray(dst_name, dst_shape, dst_storage, dst_loc)

    src_acc = state.add_access(src_name)
    dst_acc = state.add_access(dst_name)
    state.add_edge(src_acc, None, dst_acc, None, mem.Memlet(copy_expr))

    # Also hand back host-side numpy buffers matching the shapes, where the
    # shapes are concrete enough to allocate.
    a_np_arr, b_np_arr = None, None
    if src_shape is not None:
        try:
            a_np_arr = np.zeros(src_shape, dtype=np.int32)
        except (TypeError, ValueError):
            pass
    if dst_shape is not None:
        try:
            b_np_arr = np.zeros(dst_shape, dtype=np.int32)
        except (TypeError, ValueError):
            pass
    return (state, a_np_arr, b_np_arr)
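
# A hypothetical usage sketch of mkc (array names are made up): chain two
# single-copy states, host buffer -> FPGA buffer -> host buffer, then validate.
import dace

example_sdfg = dace.SDFG("mkc_example")
state0, src_arr, _ = mkc(example_sdfg, None, "src", "buf",
                         src_storage=dace.StorageType.Default,
                         dst_storage=dace.StorageType.FPGA_Global,
                         src_shape=[8], dst_shape=[8], copy_expr="src[0:8]")
# "buf" already exists, so its storage/shape arguments are ignored here.
state1, _, dst_arr = mkc(example_sdfg, state0, "buf", "dst",
                         src_storage=dace.StorageType.FPGA_Global,
                         dst_storage=dace.StorageType.Default,
                         dst_shape=[8], copy_expr="buf[0:8]")
example_sdfg.validate()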

# Assumed module context: gtir, GtirPipeline, make_args_data_from_gtir and the
# _pre/_post expansion helpers come from the surrounding gt4py/gtc backend module.
def _expand_and_finalize_sdfg(stencil_ir: gtir.Stencil, sdfg: dace.SDFG, layout_map) -> dace.SDFG:
    args_data = make_args_data_from_gtir(GtirPipeline(stencil_ir))

    # Stencils without effect: nothing is ever written, so return an empty SDFG.
    if all(info is None for info in args_data.field_info.values()):
        sdfg = dace.SDFG(stencil_ir.name)
        sdfg.add_state(stencil_ir.name)
        return sdfg

    # Mark all transients as persistent so they are allocated once at SDFG
    # initialization rather than on every invocation.
    for array in sdfg.arrays.values():
        if array.transient:
            array.lifetime = dace.AllocationLifetime.Persistent

    _pre_expand_trafos(sdfg)
    sdfg.expand_library_nodes(recursive=True)
    _specialize_transient_strides(sdfg, layout_map=layout_map)
    _post_expand_trafos(sdfg)
    return sdfg
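
# The persistence step above in isolation, on a toy SDFG (assumes only dace):
# a transient's lifetime is switched from the default scope-bound allocation to
# one allocation for the lifetime of the program.
import dace

toy = dace.SDFG("persistent_example")
toy.add_array("tmp", [10], dace.float64, transient=True)
for array in toy.arrays.values():
    if array.transient:
        array.lifetime = dace.AllocationLifetime.Persistent
assert toy.arrays["tmp"].lifetime == dace.AllocationLifetime.Persistent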

def split_condition_interstate_edges(sdfg: dace.SDFG):
    """Split interstate edges that both test a condition and perform assignments.

    The condition stays on the edge leaving the source state; the assignments
    move to a second edge behind a newly inserted interim state, so that each
    edge does only one thing.
    """
    edges_to_split = set()
    for isedge in sdfg.edges():
        if not isedge.data.is_unconditional() and len(isedge.data.assignments) > 0:
            edges_to_split.add(isedge)

    for ise in edges_to_split:
        sdfg.remove_edge(ise)
        interim = sdfg.add_state()
        sdfg.add_edge(ise.src, interim, dace.InterstateEdge(ise.data.condition))
        sdfg.add_edge(interim, ise.dst, dace.InterstateEdge(assignments=ise.data.assignments))
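
# A minimal sketch of the pass above on a toy SDFG (assumes dace is installed):
# one edge carries both a condition and an assignment and gets split in two.
import dace

split_sdfg = dace.SDFG("split_example")
first = split_sdfg.add_state("first", is_start_state=True)
last = split_sdfg.add_state("last")
# This edge is conditional *and* assigns, so the pass will split it.
split_sdfg.add_edge(first, last,
                    dace.InterstateEdge(condition="i < 10", assignments={"j": "i + 1"}))

split_condition_interstate_edges(split_sdfg)
assert len(split_sdfg.states()) == 3  # an interim state was inserted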

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

import dace
import networkx as nx
import numpy as np
from dace import SDFG
from dace.sdfg.graph import MultiConnectorEdge

# Assumed module context: Stencil, FieldDecl, ScalarDecl, Temporary, oir,
# AccessCollector, Interval, IntervalMapping, nodes_extent_calculation,
# data_type_to_typestr, HorizontalExecutionLibraryNode and
# VerticalLoopLibraryNode come from the surrounding gtc/oir modules.


class BaseOirSDFGBuilder(ABC):
    """Wire OIR library nodes into a single SDFG state.

    Subclasses define how shapes, k-subsets and access spaces are computed;
    `build` drives the common construction order (write-after-write, read and
    write edges per node, then write-after-read edges in a reverse pass).
    """

    has_transients = True

    def __init__(self, name, stencil: Stencil, nodes):
        self._stencil = stencil
        self._sdfg = SDFG(name)
        self._state = self._sdfg.add_state(name + "_state")
        self._extents = nodes_extent_calculation(nodes)

        self._dtypes = {decl.name: decl.dtype for decl in stencil.declarations + stencil.params}
        self._axes = {
            decl.name: decl.dimensions
            for decl in stencil.declarations + stencil.params
            if isinstance(decl, FieldDecl)
        }

        self._recent_write_acc: Dict[str, dace.nodes.AccessNode] = dict()
        self._recent_read_acc: Dict[str, dace.nodes.AccessNode] = dict()

        self._access_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        self._access_collection_cache: Dict[
            int, AccessCollector.CartesianAccessCollection] = dict()
        self._source_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        self._delete_candidates: List[MultiConnectorEdge] = list()

    def _access_space_to_subset(self, name, access_space):
        extent = self._extents[name]
        origin = (extent[0][0], extent[1][0])
        subsets = []
        if self._axes[name][0]:
            subsets.append(
                "{start}:__I{end:+d}".format(
                    start=origin[0] + access_space[0][0], end=origin[0] + access_space[0][1]
                )
            )
        if self._axes[name][1]:
            subsets.append(
                "{start}:__J{end:+d}".format(
                    start=origin[1] + access_space[1][0], end=origin[1] + access_space[1][1]
                )
            )
        return subsets

    def _are_nodes_ordered(self, name, node1, node2):
        assert name in self._access_nodes
        assert node1.data == name
        assert node2.data == name
        return self._access_nodes[name].index(node1) < self._access_nodes[name].index(node2)

    def _get_source(self, name):
        if name not in self._source_nodes:
            self._source_nodes[name] = self._state.add_read(name)
            if name not in self._access_nodes:
                self._access_nodes[name] = []
            self._access_nodes[name].insert(0, self._source_nodes[name])
        return self._source_nodes[name]

    def _get_new_sink(self, name):
        res = self._state.add_access(name)
        if name not in self._access_nodes:
            self._access_nodes[name] = []
        self._access_nodes[name].append(res)
        return res

    def _get_current_sink(self, name):
        if name in self._access_nodes:
            return self._access_nodes[name][-1]
        return None

    def _get_access_collection(
        self, node: "Union[HorizontalExecutionLibraryNode, VerticalLoopLibraryNode, SDFG]"
    ) -> AccessCollector.CartesianAccessCollection:
        if isinstance(node, SDFG):
            res = AccessCollector.CartesianAccessCollection([])
            for library_node in node.states()[0].nodes():
                if isinstance(
                    library_node, (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode)
                ):
                    collection = self._get_access_collection(library_node)
                    res._ordered_accesses.extend(collection._ordered_accesses)
            return res
        elif isinstance(node, HorizontalExecutionLibraryNode):
            if id(node.oir_node) not in self._access_collection_cache:
                self._access_collection_cache[id(node.oir_node)] = AccessCollector.apply(
                    node.oir_node
                ).cartesian_accesses()
            return self._access_collection_cache[id(node.oir_node)]
        else:
            assert isinstance(node, VerticalLoopLibraryNode)
            res = AccessCollector.CartesianAccessCollection([])
            for _, sdfg in node.sections:
                collection = self._get_access_collection(sdfg)
                res._ordered_accesses.extend(collection._ordered_accesses)
            return res

    def _get_recent_reads(self, name, interval):
        if name not in self._recent_read_acc:
            self._recent_read_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        return self._recent_read_acc[name][interval]

    def _get_recent_writes(self, name, interval):
        if name not in self._recent_write_acc:
            self._recent_write_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        return self._recent_write_acc[name][interval]

    def _set_read(self, name, interval, node):
        if name not in self._recent_read_acc:
            self._recent_read_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        self._recent_read_acc[name][interval] = node

    def _set_write(self, name, interval, node):
        if name not in self._recent_write_acc:
            self._recent_write_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        self._recent_write_acc[name][interval] = node

    def _reset_writes(self):
        self._recent_write_acc = dict()

    def _add_read_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        read_accesses: Dict[str, dace.nodes.AccessNode] = dict()
        for interval, access_collection in collections:
            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for candidate_access in self._get_recent_writes(name, read_interval):
                        if name not in read_accesses or self._are_nodes_ordered(
                            name, read_accesses[name], candidate_access
                        ):
                            # candidate_access is downstream from the recorded
                            # access, therefore the candidate is more recent.
                            read_accesses[name] = candidate_access

        for interval, access_collection in collections:
            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    if name not in read_accesses:
                        read_accesses[name] = self._get_source(name)
                    self._set_read(name, read_interval, read_accesses[name])

        for name, recent_access in read_accesses.items():
            node.add_in_connector("IN_" + name)
            self._state.add_edge(recent_access, None, node, "IN_" + name, dace.Memlet())

    def _add_write_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        write_accesses = dict()
        for interval, access_collection in collections:
            for name in access_collection.write_fields():
                access_node = self._get_current_sink(name)
                if access_node is None or (
                    (name not in write_accesses)
                    and (
                        access_node in self._get_recent_reads(name, interval)
                        or access_node in self._get_recent_writes(name, interval)
                        or nx.has_path(self._state.nx, access_node, node)
                    )
                ):
                    write_accesses[name] = self._get_new_sink(name)
                else:
                    write_accesses[name] = access_node
                self._set_write(name, interval, write_accesses[name])

        for name, access_node in write_accesses.items():
            node.add_out_connector("OUT_" + name)
            self._state.add_edge(node, "OUT_" + name, access_node, None, dace.Memlet())

    def _add_write_after_write_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        for interval, collection in collections:
            for name in collection.write_fields():
                for src in self._get_recent_writes(name, interval):
                    edge = self._state.add_edge(src, None, node, None, dace.Memlet())
                    self._delete_candidates.append(edge)

    def _add_write_after_read_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        for interval, collection in collections:
            for name in collection.read_fields():
                for offset in collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for dst in self._get_recent_writes(name, read_interval):
                        edge = self._state.add_edge(node, None, dst, None, dace.Memlet())
                        self._delete_candidates.append(edge)

        for interval, collection in collections:
            for name in collection.write_fields():
                self._set_write(name, interval, node)

    def add_node(self, node):
        self._state.add_node(node)

    def finalize(self):
        # Candidate ordering edges are only kept if removing them would break
        # the path between their endpoints; otherwise they are redundant.
        for edge in self._delete_candidates:
            assert edge.src_conn is None
            assert edge.dst_conn is None
            self._state.remove_edge(edge)
            if not nx.has_path(self._state.nx, edge.src, edge.dst):
                self._state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn, edge.data)

        self.add_subsets()
        self.add_arrays()

        for acc in (n for n in self._state.nodes() if isinstance(n, dace.nodes.AccessNode)):
            is_write = len(self._state.in_edges(acc)) > 0 and all(
                edge.data.data is not None for edge in self._state.in_edges(acc)
            )
            is_read = len(self._state.out_edges(acc)) > 0 and all(
                edge.data.data is not None for edge in self._state.out_edges(acc)
            )
            if is_read and is_write:
                acc.access = dace.AccessType.ReadWrite
            elif is_read:
                acc.access = dace.AccessType.ReadOnly
            else:
                assert is_write
                acc.access = dace.AccessType.WriteOnly

    def _get_sdfg(self):
        self.finalize()
        return self._sdfg

    def add_arrays(self):
        shapes = self.get_shapes()
        for decl in self._stencil.params + self._stencil.declarations:
            name = decl.name
            dtype = dace.dtypes.typeclass(np.dtype(data_type_to_typestr(self._dtypes[name])).name)
            if isinstance(decl, ScalarDecl):
                self._sdfg.add_symbol(name, stype=dtype)
            else:
                if name not in self._get_access_collection(self._sdfg).offsets():
                    continue
                assert name in self._dtypes
                strides = tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_{var}_stride")
                    for is_axis, var in zip(self._axes[name], "IJK")
                    if is_axis
                ) + tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_d{dim}_stride")
                    for dim, _ in enumerate(decl.data_dims)
                )
                self._sdfg.add_array(
                    name,
                    dtype=dtype,
                    shape=shapes[name],
                    strides=strides,
                    transient=isinstance(decl, Temporary) and self.has_transients,
                    lifetime=dace.AllocationLifetime.Persistent,
                )

    def add_subsets(self):
        decls = {decl.name: decl for decl in self._stencil.params + self._stencil.declarations}
        for node in self._state.nodes():
            if not isinstance(node, dace.nodes.LibraryNode):
                continue
            access_spaces_input, access_spaces_output = self.get_access_spaces(node)
            k_subset_strs_input, k_subset_strs_output = self.get_k_subsets(node)
            for edge in self._state.in_edges(node) + self._state.out_edges(node):
                if edge.dst_conn is not None:
                    name = edge.src.data
                    access_space = access_spaces_input[name]
                    subset_str_k = k_subset_strs_input.get(name, None)
                    dynamic = isinstance(node, HorizontalExecutionLibraryNode) and any(
                        isinstance(stmt, oir.MaskStmt) for stmt in node.oir_node.body
                    )
                elif edge.src_conn is not None:
                    name = edge.dst.data
                    access_space = access_spaces_output[name]
                    subset_str_k = k_subset_strs_output.get(name, None)
                    dynamic = False
                else:
                    continue
                subset_strs = self._access_space_to_subset(name, access_space)
                if subset_str_k is not None:
                    subset_strs.append(subset_str_k)
                for dim in decls[name].data_dims:
                    subset_strs.append(f"0:{dim}")
                edge.data = dace.Memlet.simple(
                    data=name, subset_str=",".join(subset_strs), dynamic=dynamic
                )

    @abstractmethod
    def get_k_size(self, name):
        pass

    @abstractmethod
    def add_read_edges(self, node):
        pass

    @abstractmethod
    def add_write_edges(self, node):
        pass

    @abstractmethod
    def add_write_after_read_edges(self, node):
        pass

    @abstractmethod
    def add_write_after_write_edges(self, node):
        pass

    @abstractmethod
    def get_k_subsets(self, node):
        pass

    @abstractmethod
    def get_access_spaces(self, node):
        pass

    @abstractmethod
    def get_shapes(self):
        pass

    @classmethod
    def build(cls, name, stencil: Stencil, nodes: List[dace.nodes.LibraryNode]):
        builder = cls(name, stencil, nodes)
        for n in nodes:
            builder.add_node(n)
            builder.add_write_after_write_edges(n)
            builder.add_read_edges(n)
            builder.add_write_edges(n)
        builder._reset_writes()
        for n in reversed(nodes):
            builder.add_write_after_read_edges(n)
        res = builder._get_sdfg()
        res.validate()
        return res
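
# A toy model (plain networkx, not the gtc implementation) of how the builder
# derives ordering edges: each compute node depends on the most recent writer
# of every field it reads, and write-after-write edges keep writers in order.
# The real builder additionally tracks k-intervals and adds write-after-read
# edges in a reverse pass.
import networkx as nx


def toy_build(computations):
    """computations: list of (node_name, reads, writes) triples."""
    g = nx.DiGraph()
    recent_write = {}  # field name -> last node that wrote it
    for node, reads, writes in computations:
        g.add_node(node)
        for field in reads:
            if field in recent_write:  # read-after-write dependency
                g.add_edge(recent_write[field], node)
        for field in writes:
            if field in recent_write:  # write-after-write dependency
                g.add_edge(recent_write[field], node)
            recent_write[field] = node
    return g


toy = toy_build([("he0", {"a"}, {"b"}), ("he1", {"b"}, {"c"}), ("he2", {"a"}, {"b"})])
assert nx.has_path(toy, "he0", "he1")  # he1 reads b, which he0 wrote
assert nx.has_path(toy, "he0", "he2")  # he2 overwrites b after he0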

# Note: this sample targets an older DaCe API (dace.graph.edges, EmptyMemlet,
# vector_length/veclen memlets). P, H, W, num_bins, dtype, itype and the
# make_* helpers come from the surrounding sample module.
def make_sdfg(specialize):
    if specialize:
        sdfg = SDFG("histogram_fpga_parallel_{}_{}x{}".format(P.get(), H.get(), W.get()))
    else:
        sdfg = SDFG("histogram_fpga_parallel_{}".format(P.get()))

    copy_to_fpga_state = make_copy_to_fpga_state(sdfg)
    state = sdfg.add_state("compute")

    # Compute module
    nested_sdfg = make_compute_nested_sdfg(state)
    tasklet = state.add_nested_sdfg(nested_sdfg, sdfg, {"A_pipe_in"}, {"hist_pipe_out"})
    A_pipes_out = state.add_stream("A_pipes", dtype, shape=(P,), transient=True,
                                   storage=StorageType.FPGA_Local)
    A_pipes_in = state.add_stream("A_pipes", dtype, shape=(P,), transient=True,
                                  storage=StorageType.FPGA_Local)
    hist_pipes_out = state.add_stream("hist_pipes", itype, shape=(P,), transient=True,
                                      storage=StorageType.FPGA_Local)
    unroll_entry, unroll_exit = state.add_map("unroll_compute", {"p": "0:P"},
                                              schedule=dace.ScheduleType.FPGA_Device,
                                              unroll=True)
    state.add_memlet_path(unroll_entry, A_pipes_in, memlet=EmptyMemlet())
    state.add_memlet_path(hist_pipes_out, unroll_exit, memlet=EmptyMemlet())
    state.add_memlet_path(A_pipes_in, tasklet, dst_conn="A_pipe_in",
                          memlet=Memlet.simple(A_pipes_in, "p", num_accesses="W*H"))
    state.add_memlet_path(tasklet, hist_pipes_out, src_conn="hist_pipe_out",
                          memlet=Memlet.simple(hist_pipes_out, "p", num_accesses="num_bins"))

    # Read module
    a_device = state.add_array("A_device", (H, W), dtype, transient=True,
                               storage=dace.dtypes.StorageType.FPGA_Global)
    read_entry, read_exit = state.add_map("read_map", {"h": "0:H", "w": "0:W:P"},
                                          schedule=ScheduleType.FPGA_Device)
    a_val = state.add_array("A_val", (P,), dtype, transient=True,
                            storage=StorageType.FPGA_Local)
    read_unroll_entry, read_unroll_exit = state.add_map("read_unroll", {"p": "0:P"},
                                                        schedule=ScheduleType.FPGA_Device,
                                                        unroll=True)
    read_tasklet = state.add_tasklet("read", {"A_in"}, {"A_pipe"}, "A_pipe = A_in[p]")
    state.add_memlet_path(a_device, read_entry, a_val,
                          memlet=Memlet(a_val, num_accesses=1, subset=Indices(["0"]),
                                        vector_length=P.get(),
                                        other_subset=Indices(["h", "w"])))
    state.add_memlet_path(a_val, read_unroll_entry, read_tasklet, dst_conn="A_in",
                          memlet=Memlet.simple(a_val, "0", veclen=P.get(), num_accesses=1))
    state.add_memlet_path(read_tasklet, read_unroll_exit, read_exit, A_pipes_out,
                          src_conn="A_pipe",
                          memlet=Memlet.simple(A_pipes_out, "p"))

    # Write module
    hist_pipes_in = state.add_stream("hist_pipes", itype, shape=(P,), transient=True,
                                     storage=StorageType.FPGA_Local)
    hist_device_out = state.add_array("hist_device", (num_bins,), itype, transient=True,
                                      storage=dace.dtypes.StorageType.FPGA_Global)
    merge_entry, merge_exit = state.add_map("merge", {"nb": "0:num_bins"},
                                            schedule=ScheduleType.FPGA_Device)
    merge_reduce = state.add_reduce("lambda a, b: a + b", (0,), "0",
                                    schedule=ScheduleType.FPGA_Device)
    state.add_memlet_path(hist_pipes_in, merge_entry, merge_reduce,
                          memlet=Memlet.simple(hist_pipes_in, "0:P", num_accesses=P))
    state.add_memlet_path(merge_reduce, merge_exit, hist_device_out,
                          memlet=dace.memlet.Memlet.simple(hist_device_out, "nb"))

    copy_to_host_state = make_copy_to_host_state(sdfg)

    sdfg.add_edge(copy_to_fpga_state, state, dace.graph.edges.InterstateEdge())
    sdfg.add_edge(state, copy_to_host_state, dace.graph.edges.InterstateEdge())

    return sdfg
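
# For reference, a plain-NumPy sketch of the histogram this pipeline computes.
# The actual binning rule lives in make_compute_nested_sdfg (not shown here);
# the normalized-float binning below is an assumption for illustration.
import numpy as np


def histogram_reference(A: np.ndarray, num_bins: int = 256) -> np.ndarray:
    hist = np.zeros(num_bins, dtype=np.uint32)
    bins = np.minimum((A.ravel() * num_bins).astype(int), num_bins - 1)
    for b in bins:
        hist[b] += 1
    return hist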

# Assumed module context: Node, BackwardContext, BackwardResult,
# AutoDiffException, detect_reduction_type, in_desc_with_name and
# out_desc_with_name come from the surrounding daceml.autodiff module.
def backward(
    forward_node: Node, context: BackwardContext,
    given_gradients: typing.List[typing.Optional[str]],
    required_gradients: typing.List[typing.Optional[str]]
) -> typing.Tuple[Node, BackwardResult]:
    reduction_type = detect_reduction_type(forward_node.wcr)

    if len(given_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one output edge".format(
                forward_node))
    if len(required_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one input edge".format(
                forward_node))

    input_name = next(iter(required_gradients))
    in_desc = in_desc_with_name(forward_node, context.forward_state, context.forward_sdfg,
                                input_name)

    output_name = next(iter(given_gradients))
    out_desc = out_desc_with_name(forward_node, context.forward_state, context.forward_sdfg,
                                  output_name)

    all_axes: typing.List[int] = list(range(len(in_desc.shape)))
    reduce_axes: typing.List[int] = all_axes if forward_node.axes is None else forward_node.axes
    non_reduce_axes: typing.List[int] = [i for i in all_axes if i not in reduce_axes]

    result = BackwardResult.empty()

    if reduction_type is dtypes.ReductionType.Sum:
        # In this case we simply scatter the gradient across the axes that
        # were reduced.
        sdfg = SDFG("_reverse_" + str(reduction_type).replace(".", "_") + "_")
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        sdfg.add_array(rev_input_conn_name, shape=out_desc.shape, dtype=out_desc.dtype)
        sdfg.add_array(rev_output_conn_name, shape=in_desc.shape, dtype=in_desc.dtype)

        state.add_mapped_tasklet(
            "_distribute_grad_" + str(reduction_type).replace(".", "_") + "_",
            {"i" + str(i): "0:{}".format(shape) for i, shape in enumerate(in_desc.shape)},
            {
                "__in": Memlet.simple(
                    rev_input_conn_name,
                    "0" if forward_node.axes is None else ",".join("i" + str(i)
                                                                   for i in non_reduce_axes))
            },
            "__out = __in",
            {
                "__out": Memlet.simple(rev_output_conn_name,
                                       ",".join("i" + str(i) for i in all_axes),
                                       wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return context.backward_state.add_nested_sdfg(sdfg, None, {rev_input_conn_name},
                                                      {rev_output_conn_name}), result
    else:
        raise AutoDiffException("Unsupported reduction type '{}'".format(reduction_type))
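
# A NumPy sketch of what the generated map computes for a Sum reduction
# (illustration only, hypothetical 4x5 input): the output gradient is
# broadcast back across the reduced axes.
import numpy as np

x_shape = (4, 5)
grad_out = np.ones(5)                        # gradient w.r.t. x.sum(axis=0)
grad_x = np.broadcast_to(grad_out, x_shape)  # scattered over the reduced axis
assert grad_x.shape == x_shape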