예제 #1
0
파일: expansion.py 프로젝트: fthaler/gt4py
    def add_nodes_and_edges(self):
        """Expand ``self.node`` into a single mapped tasklet in ``self.res_state``.

        The map spans the ``i`` and ``j`` intervals of the node's iteration
        space, rendered by ``get_interval_range_str`` with the symbol names
        ``"__I"`` / ``"__J"``. Each input connector gets a read access node and
        each output connector a write access node, keyed by the connector name
        with its leading ``IN_`` / ``OUT_`` characters removed. The tasklet code
        is produced by ``TaskletCodegen.apply`` from ``self.node.oir_node``.
        """
        in_memlets, out_memlets = self.get_innermost_memlets()

        space = self.node.iteration_space
        i_range = get_interval_range_str(space.i_interval, "__I")
        j_range = get_interval_range_str(space.j_interval, "__J")
        map_ranges = {"i": i_range, "j": j_range}

        # Connectors are expected to be named "IN_<data>" / "OUT_<data>"; the
        # prefix is sliced off by length only, it is not validated here.
        in_cut, out_cut = len("IN_"), len("OUT_")
        read_names = [connector[in_cut:] for connector in self.node.in_connectors]
        write_names = [connector[out_cut:] for connector in self.node.out_connectors]

        read_nodes = {}
        for data_name in read_names:
            read_nodes[data_name] = self.res_state.add_read(data_name)
        write_nodes = {}
        for data_name in write_names:
            write_nodes[data_name] = self.res_state.add_write(data_name)

        tasklet_label = self.node.name + "_tasklet"
        tasklet_code = TaskletCodegen.apply(self.node.oir_node)
        # With external_edges=True DaCe also connects the map to outer access
        # nodes, reusing the ones passed through input_nodes / output_nodes.
        self.res_state.add_mapped_tasklet(
            tasklet_label,
            map_ranges=map_ranges,
            inputs=in_memlets,
            outputs=out_memlets,
            input_nodes=read_nodes,
            output_nodes=write_nodes,
            code=tasklet_code,
            external_edges=True,
        )
예제 #2
0
파일: expansion.py 프로젝트: havogt/gt4py
    def add_nodes_and_edges(self):
        """Expand the node into one ``k``-map per section, each wrapping a nested SDFG.

        For every ``(interval, section)`` in ``self.node.sections`` this builds::

            access nodes -> map over k (interval) -> nested SDFG (section) -> access nodes

        Read and write access nodes in ``self.res_state`` are created lazily and
        shared by all sections: a data container gets at most one read node and
        at most one write node for the whole expansion.
        """
        read_nodes = {}
        write_nodes = {}

        for interval, section in self.node.sections:
            k_range = get_interval_range_str(interval, "__K")
            map_entry, map_exit = self.res_state.add_map(
                section.name + "_map", ndrange={"k": k_range}
            )

            # Classify every access node of the section (searched recursively):
            # anything not WriteOnly is treated as a read, anything not ReadOnly
            # as a write, so a ReadWrite access lands in both sets.
            reads, writes = set(), set()
            for node, _parent in section.all_nodes_recursive():
                if not isinstance(node, dace.nodes.AccessNode):
                    continue
                data = node.data
                if node.access != dace.AccessType.WriteOnly:
                    if data not in read_nodes:
                        read_nodes[data] = self.res_state.add_read(data)
                    reads.add(data)
                if node.access != dace.AccessType.ReadOnly:
                    if data not in write_nodes:
                        write_nodes[data] = self.res_state.add_write(data)
                    writes.add(data)

            nsdfg = self.res_state.add_nested_sdfg(
                sdfg=section, parent=None, inputs=reads, outputs=writes
            )
            in_subsets, out_subsets = self.get_mapped_subsets_dicts(interval, section)

            # Without any data memlet on a side, an empty memlet is added —
            # presumably to keep the nested SDFG attached to the map scope.
            if not in_subsets:
                self.res_state.add_edge(map_entry, None, nsdfg, None, memlet=dace.memlet.Memlet())
            if not out_subsets:
                self.res_state.add_edge(nsdfg, None, map_exit, None, memlet=dace.memlet.Memlet())

            for name, subset in in_subsets.items():
                source = read_nodes[name]
                memlet = dace.memlet.Memlet.simple(name, subset)
                self.res_state.add_memlet_path(
                    source, map_entry, nsdfg, src_conn=None, dst_conn=name, memlet=memlet
                )
            for name, subset in out_subsets.items():
                target = write_nodes[name]
                memlet = dace.memlet.Memlet.simple(name, subset, dynamic=False)
                self.res_state.add_memlet_path(
                    nsdfg, map_exit, target, src_conn=name, dst_conn=None, memlet=memlet
                )