def add_nodes_and_edges(self):
    """Lower this node into a single mapped tasklet inside ``self.res_state``.

    The map iterates over ``i`` and ``j``. Their ranges come from
    ``get_interval_range_str`` applied to the node's iteration-space
    intervals, with ``"__I"`` / ``"__J"`` as the second argument.

    Data names are derived from the node's connectors by slicing off the
    ``"IN_"`` / ``"OUT_"`` prefix length unconditionally. Each name gets its
    own read (inputs) or write (outputs) access node.

    The tasklet body is generated by ``TaskletCodegen.apply`` from
    ``self.node.oir_node``. It is added with ``external_edges=True`` and the
    explicit input/output access nodes built here.
    """
    state = self.res_state
    in_memlets, out_memlets = self.get_innermost_memlets()

    space = self.node.iteration_space
    map_ranges = {}
    map_ranges["i"] = get_interval_range_str(space.i_interval, "__I")
    map_ranges["j"] = get_interval_range_str(space.j_interval, "__J")

    def _strip_prefix(prefix, connectors):
        # Connector names carry a fixed prefix; drop that many leading chars.
        width = len(prefix)
        return [connector[width:] for connector in connectors]

    read_names = _strip_prefix("IN_", self.node.in_connectors)
    write_names = _strip_prefix("OUT_", self.node.out_connectors)

    read_nodes = {}
    for data_name in read_names:
        read_nodes[data_name] = state.add_read(data_name)

    write_nodes = {}
    for data_name in write_names:
        write_nodes[data_name] = state.add_write(data_name)

    tasklet_name = self.node.name + "_tasklet"
    tasklet_code = TaskletCodegen.apply(self.node.oir_node)
    state.add_mapped_tasklet(
        tasklet_name,
        map_ranges=map_ranges,
        inputs=in_memlets,
        outputs=out_memlets,
        input_nodes=read_nodes,
        output_nodes=write_nodes,
        code=tasklet_code,
        external_edges=True,
    )
def add_nodes_and_edges(self):
    """Expand every vertical section into a k-map wrapping a nested SDFG.

    For each ``(interval, section)`` pair in ``self.node.sections`` this adds
    the following to ``self.res_state``:

    * a map named ``<section.name>_map`` over ``k``, whose range is
      ``get_interval_range_str(interval, "__K")``;
    * a nested SDFG built from ``section``, with connectors named after the
      data its access nodes touch;
    * memlet paths from shared read access nodes through the map entry into
      the nested SDFG, and from the nested SDFG through the map exit to shared
      write access nodes. Subsets come from
      ``self.get_mapped_subsets_dicts(interval, section)``.

    Read and write access nodes are created at most once per data name and
    are reused by every section.
    """
    # For each section:
    #   access node -> map over k -> nested SDFG with the section's HEs
    # Shared access nodes, keyed by data name, reused across all sections.
    in_accesses = dict()
    out_accesses = dict()
    for interval, section in self.node.sections:
        interval_str = get_interval_range_str(interval, "__K")
        map_entry, map_exit = self.res_state.add_map(
            section.name + "_map", ndrange={"k": interval_str}
        )
        # Connector names for this section's nested SDFG.
        section_inputs = set()
        section_outputs = set()
        # NOTE(review): all_nodes_recursive() also descends into nested SDFGs
        # inside ``section``; any transient or differently named data found
        # there would become a connector/access node here too — confirm.
        for acc in (
            n
            for n, _ in section.all_nodes_recursive()
            if isinstance(n, dace.nodes.AccessNode)
        ):
            # Every access that is not write-only counts as a read.
            if acc.access != dace.AccessType.WriteOnly:
                if acc.data not in in_accesses:
                    in_accesses[acc.data] = self.res_state.add_read(acc.data)
                section_inputs.add(acc.data)
            # Every access that is not read-only counts as a write.
            if acc.access != dace.AccessType.ReadOnly:
                if acc.data not in out_accesses:
                    out_accesses[acc.data] = self.res_state.add_write(acc.data)
                section_outputs.add(acc.data)
        # NOTE(review): because the read nodes are shared, a later section
        # reads from the same node an earlier section does, not from the
        # earlier section's write node. Confirm that sections touching the
        # same data are ordered correctly elsewhere.
        nsdfg = self.res_state.add_nested_sdfg(
            sdfg=section,
            parent=None,
            inputs=section_inputs,
            outputs=section_outputs,
        )
        in_subsets, out_subsets = self.get_mapped_subsets_dicts(interval, section)
        # With no data edges, connect the nested SDFG to the map entry/exit via
        # empty memlets — presumably so it stays inside the map scope.
        if len(in_subsets) == 0:
            self.res_state.add_edge(map_entry, None, nsdfg, None, memlet=dace.memlet.Memlet())
        if len(out_subsets) == 0:
            self.res_state.add_edge(nsdfg, None, map_exit, None, memlet=dace.memlet.Memlet())
        # Assumes every name in in_subsets / out_subsets was seen among the
        # access nodes above; otherwise the dict lookups raise KeyError.
        for name, subset in in_subsets.items():
            self.res_state.add_memlet_path(
                in_accesses[name],
                map_entry,
                nsdfg,
                src_conn=None,
                dst_conn=name,
                memlet=dace.memlet.Memlet.simple(name, subset),
            )
        for name, subset in out_subsets.items():
            self.res_state.add_memlet_path(
                nsdfg,
                map_exit,
                out_accesses[name],
                src_conn=name,
                dst_conn=None,
                memlet=dace.memlet.Memlet.simple(name, subset, dynamic=False),
            )