Esempio n. 1
0
    def apply(self, sdfg):
        """Inline the assignments of the matched state's incoming edge.

        Every assignment returned by ``_assignments_to_consider`` for the
        first incoming edge of the matched state is substituted into that
        state and deleted from the edge. A name that no inter-state edge
        refers to afterwards is removed from the SDFG symbols; if its assigned
        value is itself an SDFG symbol, the name is additionally replaced by
        that value throughout the SDFG.

        :param sdfg: The SDFG to modify in place.
        """
        end_state = sdfg.nodes()[self.subgraph[
            StateAssignElimination._end_state]]
        in_edge = sdfg.in_edges(end_state)[0]
        candidates = _assignments_to_consider(sdfg, in_edge)

        # Inter-state assignments that read another assigned value are
        # undefined behavior (e.g., {m: n, n: m}), so each assignment can be
        # substituted into the state independently of the others.
        inlined = set()
        for name, value in candidates.items():
            end_state.replace(name, value)
            inlined.add(name)

        promoted = {}
        for name in inlined:
            # The edge no longer has to carry this assignment
            del in_edge.data.assignments[name]

            # Keep the symbol if any edge still refers to it
            if any(name in e.data.free_symbols for e in sdfg.edges()):
                continue

            # Otherwise schedule a replacement (only when the assigned value
            # is a known symbol) and drop the symbol from the SDFG
            value = candidates[name]
            if value in sdfg.symbols:
                promoted[name] = value
            if name in sdfg.symbols:
                sdfg.remove_symbol(name)

        def _replace_all(target, mapping):
            # Apply every (old -> new) pair of the mapping as a string
            # replacement on the target
            for old, new in mapping.items():
                target.replace(str(old), str(new))

        if promoted:
            symbolic.safe_replace(promoted, lambda m: _replace_all(sdfg, m))
Esempio n. 2
0
 def apply(self, sdfg):
     """Remove the matched start state of a nested SDFG.

     The assignments found on the first outgoing edge of the matched state
     are copied into the symbol mapping of the nested SDFG node that owns
     ``sdfg`` (``sdfg.parent_nsdfg_node``) before the state is removed.

     :param sdfg: The (nested) SDFG to modify in place.
     """
     start_state = sdfg.nodes()[self.subgraph[
         StartStateElimination.start_state]]
     nsdfg_node = sdfg.parent_nsdfg_node
     out_edge = sdfg.out_edges(start_state)[0]
     # Each assignment on the edge becomes an entry of the symbol mapping
     for symbol, value in out_edge.data.assignments.items():
         nsdfg_node.symbol_mapping[symbol] = value
     sdfg.remove_node(start_state)
Esempio n. 3
0
 def apply(self, sdfg):
     """Remove the matched state and the symbols orphaned by its removal.

     The assignments of the state's first incoming edge are looked up before
     the state is removed. Afterwards, every assigned name that is still
     reported in ``sdfg.free_symbols`` is removed from the SDFG.

     :param sdfg: The SDFG to modify in place.
     """
     end_state = sdfg.nodes()[self.subgraph[EndStateElimination._end_state]]
     # Grab the assignments first: the incoming edge goes away together with
     # the state
     in_edge = sdfg.in_edges(end_state)[0]
     assigned = in_edge.data.assignments
     sdfg.remove_node(end_state)
     # Drop the symbols that were only set on the deleted edge
     for name in assigned:
         if name in sdfg.free_symbols:
             sdfg.remove_symbol(name)
Esempio n. 4
0
 def apply(self, sdfg):
     """Inline all assignments of the matched state's incoming edge.

     Every assignment on the first incoming edge of the matched state is
     substituted into the state, and the edge is then left without any
     assignments.

     :param sdfg: The SDFG to modify in place.
     """
     end_state = sdfg.nodes()[self.subgraph[
         StateAssignElimination._end_state]]
     in_edge = sdfg.in_edges(end_state)[0]
     # Inter-state assignments that use an assigned value lead to undefined
     # behavior (e.g., {m: n, n: m}), so the substitutions are independent
     # and can be applied one at a time.
     for name, value in in_edge.data.assignments.items():
         end_state.replace(name, value)
     # All assignments are inlined by now; clear them from the edge
     in_edge.data.assignments = {}
Esempio n. 5
0
    def apply(self, sdfg):
        """Move alias assignments one inter-state edge earlier.

        The candidates are the assignments returned by ``_alias_assignments``
        for the edge between the two matched states. A candidate ``k: v`` is
        skipped when any of the following holds:

        * ``k`` appears in the condition of that edge;
        * the first incoming edge of the first state already assigns ``k`` a
          different value;
        * ``v`` is an SDFG symbol that is assigned on that incoming edge;
        * ``v`` is a scalar data container that is accessed in the first
          state.

        The remaining candidates are deleted from the edge between the two
        states and added to the incoming edge of the first state.

        :param sdfg: The SDFG to modify in place.
        """
        fstate = sdfg.nodes()[self.subgraph[SymbolAliasPromotion._first_state]]
        sstate = sdfg.nodes()[self.subgraph[
            SymbolAliasPromotion._second_state]]

        edge = sdfg.edges_between(fstate, sstate)[0].data
        in_edge = sdfg.in_edges(fstate)[0].data

        to_consider = _alias_assignments(sdfg, edge)

        # Neither the edge condition nor the first state changes while the
        # candidates are filtered, so both lookups are computed once instead
        # of once per candidate.
        condsyms = {str(s) for s in edge.condition_sympy().free_symbols}
        accessed_in_fstate = {
            n.data
            for n in fstate.nodes() if isinstance(n, nodes.AccessNode)
        }

        to_not_consider = set()
        for k, v in to_consider.items():
            # Remove symbols that are taking part in the edge's condition
            if k in condsyms:
                to_not_consider.add(k)
            # Remove symbols that are set in the in_edge
            # with a different assignment
            if k in in_edge.assignments and in_edge.assignments[k] != v:
                to_not_consider.add(k)
            # Remove symbols whose assignment (RHS) is a symbol
            # and is set in the in_edge.
            if v in sdfg.symbols and v in in_edge.assignments:
                to_not_consider.add(k)
            # Remove symbols whose assignment (RHS) is a scalar
            # and is set in the first state.
            if (v in sdfg.arrays and isinstance(sdfg.arrays[v], dt.Scalar)
                    and v in accessed_in_fstate):
                to_not_consider.add(k)

        for k in to_not_consider:
            del to_consider[k]

        # Promote the surviving assignments to the incoming edge
        for k, v in to_consider.items():
            del edge.assignments[k]
            in_edge.assignments[k] = v
Esempio n. 6
0
 def copy_memory(self, sdfg: sdfg.SDFG, dfg: state.StateSubgraphView,
                 state_id: int, src_node: nodes.Node, dst_node: nodes.Node,
                 edge: graph.MultiConnectorEdge,
                 function_stream: prettycode.CodeIOStream,
                 callsite_stream: prettycode.CodeIOStream):
     """
     Generate input/output memory copies from the array references to local
     variables (i.e. for the tasklet code).

     Exactly one line of code is written to ``callsite_stream``; it declares
     a local variable named after the destination connector of ``edge``.

     :param sdfg: SDFG whose data descriptors and constants are consulted.
     :param dfg: Not used by this method.
     :param state_id: Index of the state within ``sdfg.nodes()``.
     :param src_node: Not used by this method (``edge.src`` is read instead).
     :param dst_node: Node whose ``in_connectors`` provide the connector type.
     :param edge: The edge for which the accessor is generated.
     :param function_stream: Not used by this method.
     :param callsite_stream: Stream that receives the generated line.
     :raises RuntimeError: If the source/destination node types of the edge
                           are not handled, or if an access node refers to
                           data that is neither an array nor a scalar.
     """
     if isinstance(edge.src, nodes.AccessNode) and isinstance(
             edge.dst, nodes.Tasklet):  # handle AccessNode->Tasklet
         conn_type = dst_node.in_connectors[edge.dst_conn]
         if isinstance(conn_type, dtypes.pointer):  # pointer accessor
             line: str = "{} {} = &{}[0];".format(conn_type.ctype,
                                                  edge.dst_conn,
                                                  edge.src.data)
         elif isinstance(conn_type, dtypes.vector):  # vector accessor
             line = "{} {} = *({} *)(&{}[0]);".format(conn_type.ctype,
                                                      edge.dst_conn,
                                                      conn_type.ctype,
                                                      edge.src.data)
         else:  # scalar accessor
             arr = sdfg.arrays[edge.data.data]
             if isinstance(arr, data.Array):
                 line = "{}* {} = &{}[0];".format(conn_type.ctype,
                                                  edge.dst_conn,
                                                  edge.src.data)
             elif isinstance(arr, data.Scalar):
                 line = "{} {} = {};".format(conn_type.ctype, edge.dst_conn,
                                             edge.src.data)
             else:
                 # Without this branch `line` would stay unbound and the
                 # write below would fail with an UnboundLocalError.
                 raise RuntimeError(
                     "Not handling copy_memory case of data type {}.".format(
                         type(arr)))
     elif isinstance(edge.src, nodes.MapEntry) and isinstance(
             edge.dst, nodes.Tasklet):
         rtl_name = self.unique_name(edge.dst, sdfg.nodes()[state_id], sdfg)
         # Record the evaluated (end of the first map dimension + 1) for this
         # tasklet; presumably its unroll factor -- TODO confirm
         self.n_unrolled[rtl_name] = symbolic.evaluate(
             edge.src.map.range[0][1] + 1, sdfg.constants)
         line = f'{dst_node.in_connectors[edge.dst_conn]} {edge.dst_conn} = &{edge.data.data}[{edge.src.map.params[0]}*{edge.data.volume}];'
     else:
         raise RuntimeError(
             "Not handling copy_memory case of type {} -> {}.".format(
                 type(edge.src), type(edge.dst)))
     # write accessor to file
     callsite_stream.write(line)
Esempio n. 7
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Decide whether the incoming assignments of a state can be inlined.

        The candidate state must have exactly one incoming edge, and
        ``_assignments_to_consider`` must return at least one assignment for
        that edge. If the state has no outgoing edges this is sufficient.
        Otherwise none of the assigned names may appear among the free symbols
        of any other inter-state edge or of any other state of ``sdfg``.

        :param graph: Graph in which the candidate state is looked up.
        :param candidate: Mapping from pattern nodes to node indices in
                          ``graph``.
        :param expr_index: Not used by this check.
        :param sdfg: SDFG whose edges and states are inspected.
        :param strict: Not used by this check.
        :return: True if the transformation can be applied, False otherwise.
        """
        state = graph.nodes()[candidate[StateAssignElimination._end_state]]

        successors = graph.out_edges(state)
        predecessors = graph.in_edges(state)

        # Exactly one source is required
        if len(predecessors) != 1:
            return False
        in_edge = predecessors[0]

        # ...and it has to carry something that can be eliminated
        candidates = _assignments_to_consider(sdfg, in_edge)
        if not candidates:
            return False

        # An end state has no further edges that could use the symbols
        if not successors:
            return True

        names = set(candidates.keys())

        # The symbols must never be set/used again in any other edge
        if any(e is not in_edge and e.data.free_symbols & names
               for e in sdfg.edges()):
            return False

        # ...nor be used in any state other than the current one
        return not any(s is not state and s.free_symbols & names
                       for s in sdfg.nodes())
Esempio n. 8
0
    def apply(self, sdfg):
        """Fuse the two matched states into the first one.

        The inter-state edges between the two states are removed; any
        assignments they carry are merged into the assignments of every
        incoming edge of the first state. If one of the two states is empty
        it is simply removed and its edges are redirected to the other state.
        Otherwise all nodes and edges of the second state are added to the
        first state, source access nodes of the second state are merged with
        an access node of the same label in the first state, and the second
        state is removed from the SDFG.

        :param sdfg: The SDFG to modify in place.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s)
        for edge in sdfg.edges_between(first_state, second_state):
            if edge.data.assignments:
                for _src, _dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Source (data) nodes of the second state are the merge candidates
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        # (reversed, so that the search below starts from the end of that
        # order)
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, memlet in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, memlet)

        # Merge common (data) nodes
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n:
                    sdutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    # The surviving node is now read by the merged-in edges
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
Esempio n. 9
0
File: rtl.py Progetto: sscholbe/dace
    def unparse_tasklet(self, sdfg: sdfg.SDFG, dfg: state.StateSubgraphView,
                        state_id: int, node: nodes.Node,
                        function_stream: prettycode.CodeIOStream,
                        callsite_stream: prettycode.CodeIOStream):
        """
        Generate the code objects (and, for simulation, the C++ driver code)
        for a single RTL tasklet.

        A SystemVerilog code object wrapping the tasklet code is always
        appended to ``self.code_objects``. If ``self.hardware_target`` is set,
        additional control/top/package/synth code objects are appended (only
        for the ``'xilinx'`` vendor). Otherwise the simulation C++ code is
        generated: headers are appended to the SDFG global code and the main
        code is written to ``callsite_stream``.

        :param sdfg: SDFG that contains the tasklet.
        :param dfg: Not used by this method.
        :param state_id: Index of the tasklet's state within ``sdfg.nodes()``.
        :param node: The tasklet node to generate code for.
        :param function_stream: Not used by this method.
        :param callsite_stream: Stream that receives the simulation code.
        :raises NotImplementedError: If an array is connected to the tasklet
                                     while targeting hardware, or if the
                                     hardware vendor is not ``'xilinx'``.
        """

        # extract data
        state = sdfg.nodes()[state_id]
        tasklet = node

        # construct variables paths
        unique_name: str = "{}_{}_{}_{}".format(tasklet.name, sdfg.sdfg_id,
                                                sdfg.node_id(state),
                                                state.node_id(tasklet))

        # Collect all of the input and output connectors into buses and scalars
        # buses:   connector name -> (is_output, total_size, vec_len)
        # scalars: connector name -> (is_output, total_size * 8)
        buses = {}
        scalars = {}
        for edge in state.in_edges(tasklet):
            arr = sdfg.arrays[edge.src.data]
            # catch symbolic (compile time variables)
            check_issymbolic([
                tasklet.in_connectors[edge.dst_conn].veclen,
                tasklet.in_connectors[edge.dst_conn].bytes
            ], sdfg)

            # extract parameters
            vec_len = int(
                symbolic.evaluate(tasklet.in_connectors[edge.dst_conn].veclen,
                                  sdfg.constants))
            total_size = int(
                symbolic.evaluate(tasklet.in_connectors[edge.dst_conn].bytes,
                                  sdfg.constants))
            if isinstance(arr, data.Array):
                if self.hardware_target:
                    raise NotImplementedError(
                        'Array input for hardware* not implemented')
                else:
                    buses[edge.dst_conn] = (False, total_size, vec_len)
            elif isinstance(arr, data.Stream):
                buses[edge.dst_conn] = (False, total_size, vec_len)
            elif isinstance(arr, data.Scalar):
                # connector size in bytes times 8 (presumably the width in
                # bits -- TODO confirm)
                scalars[edge.dst_conn] = (False, total_size * 8)

        for edge in state.out_edges(tasklet):
            arr = sdfg.arrays[edge.dst.data]
            # catch symbolic (compile time variables)
            check_issymbolic([
                tasklet.out_connectors[edge.src_conn].veclen,
                tasklet.out_connectors[edge.src_conn].bytes
            ], sdfg)

            # extract parameters
            vec_len = int(
                symbolic.evaluate(tasklet.out_connectors[edge.src_conn].veclen,
                                  sdfg.constants))
            total_size = int(
                symbolic.evaluate(tasklet.out_connectors[edge.src_conn].bytes,
                                  sdfg.constants))
            if isinstance(arr, data.Array):
                if self.hardware_target:
                    # NOTE(review): this is the output loop, but the message
                    # says "input" -- confirm whether that is intended
                    raise NotImplementedError(
                        'Array input for hardware* not implemented')
                else:
                    buses[edge.src_conn] = (True, total_size, vec_len)
            elif isinstance(arr, data.Stream):
                buses[edge.src_conn] = (True, total_size, vec_len)
            elif isinstance(arr, data.Scalar):
                # NOTE(review): scalar outputs are only reported on stdout;
                # the connector is skipped and code generation continues
                print('Scalar output not implemented')

        # generate system verilog module components
        parameter_string: str = self.generate_rtl_parameters(sdfg.constants)
        inputs, outputs = self.generate_rtl_inputs_outputs(buses, scalars)

        # create rtl code object (that is later written to file)
        self.code_objects.append(
            codeobject.CodeObject(
                name="{}".format(unique_name),
                code=RTLCodeGen.RTL_HEADER.format(name=unique_name,
                                                  parameters=parameter_string,
                                                  inputs="\n".join(inputs),
                                                  outputs="\n".join(outputs)) +
                tasklet.code.code + RTLCodeGen.RTL_FOOTER,
                language="sv",
                target=RTLCodeGen,
                title="rtl",
                target_type="{}".format(unique_name),
                additional_compiler_kwargs="",
                linkable=True,
                environments=None))

        if self.hardware_target:
            if self.vendor == 'xilinx':
                # Configuration shared by all rtllib_* generators below
                rtllib_config = {
                    "name": unique_name,
                    "buses": {
                        name: ('m_axis' if is_output else 's_axis', vec_len)
                        for name, (is_output, _, vec_len) in buses.items()
                    },
                    "params": {
                        "scalars": {
                            name: total_size
                            for name, (_, total_size) in scalars.items()
                        },
                        "memory": {}
                    },
                    "ip_cores": tasklet.ip_cores if isinstance(
                        tasklet, nodes.RTLTasklet) else {},
                }

                # Verilog control module
                self.code_objects.append(
                    codeobject.CodeObject(name=f"{unique_name}_control",
                                          code=rtllib_control(rtllib_config),
                                          language="v",
                                          target=RTLCodeGen,
                                          title="rtl",
                                          target_type="{}".format(unique_name),
                                          additional_compiler_kwargs="",
                                          linkable=True,
                                          environments=None))

                # Verilog top-level module
                self.code_objects.append(
                    codeobject.CodeObject(name=f"{unique_name}_top",
                                          code=rtllib_top(rtllib_config),
                                          language="v",
                                          target=RTLCodeGen,
                                          title="rtl",
                                          target_type="{}".format(unique_name),
                                          additional_compiler_kwargs="",
                                          linkable=True,
                                          environments=None))

                # Tcl packaging script
                self.code_objects.append(
                    codeobject.CodeObject(name=f"{unique_name}_package",
                                          code=rtllib_package(rtllib_config),
                                          language="tcl",
                                          target=RTLCodeGen,
                                          title="rtl",
                                          target_type="scripts",
                                          additional_compiler_kwargs="",
                                          linkable=True,
                                          environments=None))

                # Tcl synthesis script
                self.code_objects.append(
                    codeobject.CodeObject(name=f"{unique_name}_synth",
                                          code=rtllib_synth(rtllib_config),
                                          language="tcl",
                                          target=RTLCodeGen,
                                          title="rtl",
                                          target_type="scripts",
                                          additional_compiler_kwargs="",
                                          linkable=True,
                                          environments=None))
            else:  # self.vendor != "xilinx"
                raise NotImplementedError(
                    'Only RTL codegen for Xilinx is implemented')
        else:  # not hardware_target
            # generate verilator simulation cpp code components
            # (`inputs`/`outputs` are rebound here; the RTL port lists above
            # have already been consumed)
            inputs, outputs = self.generate_cpp_inputs_outputs(tasklet)
            valid_zeros, ready_zeros = self.generate_cpp_zero_inits(tasklet)
            vector_init = self.generate_cpp_vector_init(tasklet)
            num_elements = self.generate_cpp_num_elements(tasklet)
            internal_state_str, internal_state_var = self.generate_cpp_internal_state(
                tasklet)
            read_input_hs = self.generate_input_hs(tasklet)
            feed_elements = self.generate_feeding(tasklet, inputs)
            in_ptrs, out_ptrs = self.generate_ptrs(tasklet)
            export_elements = self.generate_exporting(tasklet, outputs)
            write_output_hs = self.generate_write_output_hs(tasklet)
            hs_flags = self.generate_hs_flags(tasklet)
            input_hs_toggle = self.generate_input_hs_toggle(tasklet)
            output_hs_toggle = self.generate_output_hs_toggle(tasklet)
            running_condition = self.generate_running_condition(tasklet)

            # add header code to stream
            # (the general header is emitted only once per code generator
            # instance, the model header once per tasklet)
            if not self.cpp_general_header_added:
                sdfg.append_global_code(
                    cpp_code=RTLCodeGen.CPP_GENERAL_HEADER_TEMPLATE.format(
                        debug_include="// generic includes\n#include <iostream>"
                        if self.verilator_debug else ""))
                self.cpp_general_header_added = True
            sdfg.append_global_code(
                cpp_code=RTLCodeGen.CPP_MODEL_HEADER_TEMPLATE.format(
                    name=unique_name))

            # add main cpp code to stream
            # (all debug_* template fields are empty unless
            # self.verilator_debug is set)
            callsite_stream.write(contents=RTLCodeGen.CPP_MAIN_TEMPLATE.format(
                name=unique_name,
                inputs=inputs,
                outputs=outputs,
                num_elements=str.join('\n', num_elements),
                vector_init=vector_init,
                valid_zeros=str.join('\n', valid_zeros),
                ready_zeros=str.join('\n', ready_zeros),
                read_input_hs=str.join('\n', read_input_hs),
                feed_elements=str.join('\n', feed_elements),
                in_ptrs=str.join('\n', in_ptrs),
                out_ptrs=str.join('\n', out_ptrs),
                export_elements=str.join('\n', export_elements),
                write_output_hs=str.join('\n', write_output_hs),
                hs_flags=str.join('\n', hs_flags),
                input_hs_toggle=str.join('\n', input_hs_toggle),
                output_hs_toggle=str.join('\n', output_hs_toggle),
                running_condition=str.join(' && ', running_condition),
                internal_state_str=internal_state_str,
                internal_state_var=internal_state_var,
                debug_sim_start="std::cout << \"SIM {name} START\" << std::endl;"
                if self.verilator_debug else "",
                debug_internal_state="""
// report internal state
VL_PRINTF("[t=%lu] ap_aclk=%u ap_areset=%u valid_i=%u ready_i=%u valid_o=%u ready_o=%u \\n",
    main_time, model->ap_aclk, model->ap_areset,
    model->valid_i, model->ready_i, model->valid_o, model->ready_o);
VL_PRINTF("{internal_state_str}\\n", {internal_state_var});
std::cout << std::flush;
""".format(internal_state_str=internal_state_str,
            internal_state_var=internal_state_var)
                if self.verilator_debug else "",
                debug_sim_end="std::cout << \"SIM {name} END\" << std::endl;"
                if self.verilator_debug else ""),
                                  sdfg=sdfg,
                                  state_id=state_id,
                                  node_id=node)
Esempio n. 10
0
 def apply(self, sdfg):
     """Remove the state matched as ``_end_state`` from the SDFG.

     :param sdfg: The SDFG to modify in place.
     """
     end_state_index = self.subgraph[EndStateElimination._end_state]
     sdfg.remove_node(sdfg.nodes()[end_state_index])
Esempio n. 11
0
    def unparse_tasklet(self, sdfg: sdfg.SDFG, dfg: state.StateSubgraphView,
                        state_id: int, node: nodes.Node,
                        function_stream: prettycode.CodeIOStream,
                        callsite_stream: prettycode.CodeIOStream):
        """
        Generate the SystemVerilog code object and the simulation C++ code
        for a single RTL tasklet.

        The tasklet code is wrapped in the RTL header/footer and appended to
        ``self.code_objects``; the C++ headers are appended to the SDFG
        global code and the main simulation code is written to
        ``callsite_stream``.

        :param sdfg: SDFG that contains the tasklet.
        :param dfg: Not used by this method.
        :param state_id: Index of the tasklet's state within ``sdfg.nodes()``.
        :param node: The tasklet node to generate code for.
        :param function_stream: Not used by this method.
        :param callsite_stream: Stream that receives the simulation code.
        """

        # extract data
        state = sdfg.nodes()[state_id]
        tasklet = node

        # construct variables paths
        # (name is built from the SDFG id, the state id and the tasklet id)
        unique_name: str = "top_{}_{}_{}".format(sdfg.sdfg_id,
                                                 sdfg.node_id(state),
                                                 state.node_id(tasklet))

        # generate system verilog module components
        parameter_string: str = self.generate_rtl_parameters(sdfg.constants)
        inputs, outputs = self.generate_rtl_inputs_outputs(sdfg, tasklet)

        # create rtl code object (that is later written to file)
        self.code_objects.append(
            codeobject.CodeObject(
                name="{}".format(unique_name),
                code=RTLCodeGen.RTL_HEADER.format(name=unique_name,
                                                  parameters=parameter_string,
                                                  inputs="\n".join(inputs),
                                                  outputs="\n".join(outputs)) +
                tasklet.code.code + RTLCodeGen.RTL_FOOTER,
                language="sv",
                target=RTLCodeGen,
                title="rtl",
                target_type="",
                additional_compiler_kwargs="",
                linkable=True,
                environments=None))

        # generate verilator simulation cpp code components
        # (`inputs`/`outputs` are rebound here; the RTL port lists above have
        # already been consumed)
        inputs, outputs = self.generate_cpp_inputs_outputs(tasklet)
        vector_init = self.generate_cpp_vector_init(tasklet)
        num_elements = self.generate_cpp_num_elements()
        internal_state_str, internal_state_var = self.generate_cpp_internal_state(
            tasklet)

        # add header code to stream
        # (the general header is emitted only once per code generator
        # instance, the model header once per tasklet)
        if not self.cpp_general_header_added:
            sdfg.append_global_code(
                cpp_code=RTLCodeGen.CPP_GENERAL_HEADER_TEMPLATE.format(
                    debug_include="// generic includes\n#include <iostream>"
                    if self.verilator_debug else ""))
            self.cpp_general_header_added = True
        sdfg.append_global_code(
            cpp_code=RTLCodeGen.CPP_MODEL_HEADER_TEMPLATE.format(
                name=unique_name))

        # add main cpp code to stream
        # (all debug_* template fields are empty unless self.verilator_debug
        # is set)
        callsite_stream.write(contents=RTLCodeGen.CPP_MAIN_TEMPLATE.format(
            name=unique_name,
            inputs=inputs,
            outputs=outputs,
            num_elements=num_elements,
            vector_init=vector_init,
            internal_state_str=internal_state_str,
            internal_state_var=internal_state_var,
            debug_sim_start="std::cout << \"SIM {name} START\" << std::endl;"
            if self.verilator_debug else "",
            debug_feed_element="std::cout << \"feed new element\" << std::endl;"
            if self.verilator_debug else "",
            debug_export_element="std::cout << \"export element\" << std::endl;"
            if self.verilator_debug else "",
            debug_internal_state="""
// report internal state 
VL_PRINTF("[t=%lu] clk_i=%u rst_i=%u valid_i=%u ready_i=%u valid_o=%u ready_o=%u \\n", main_time, model->clk_i, model->rst_i, model->valid_i, model->ready_i, model->valid_o, model->ready_o);
VL_PRINTF("{internal_state_str}\\n", {internal_state_var});
std::cout << std::flush;
""".format(internal_state_str=internal_state_str,
           internal_state_var=internal_state_var)
            if self.verilator_debug else "",
            debug_read_input_hs=
            "std::cout << \"remove read_input_hs flag\" << std::endl;"
            if self.verilator_debug else "",
            debug_output_hs=
            "std::cout << \"remove write_output_hs flag\" << std::endl;"
            if self.verilator_debug else "",
            debug_sim_end="std::cout << \"SIM {name} END\" << std::endl;"
            if self.verilator_debug else ""),
                              sdfg=sdfg,
                              state_id=state_id,
                              node_id=node)
Esempio n. 12
0
    def apply(self, sdfg):
        """Fuse the two matched states into the first one.

        The inter-state edges between the two states are removed; any
        assignments they carry are merged into the assignments of every
        incoming edge of the first state. If one of the two states is empty
        it is simply removed and its edges are redirected to the other state.
        Otherwise all nodes and edges of the second state are added to the
        first state, access nodes with matching labels at the boundary of the
        two states are merged, and the second state is removed from the SDFG.

        :param sdfg: The SDFG to modify in place.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Drop the connecting inter-state edge(s), pushing their assignments
        # onto the edges that lead into the first state
        for connecting in sdfg.edges_between(first_state, second_state):
            if connecting.data.assignments:
                for _, _, incoming in sdfg.in_edges(first_state):
                    incoming.assignments.update(connecting.data.assignments)
            sdfg.remove_edge(connecting)

        # Trivial fusion: nothing in the first state
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Trivial fusion: nothing in the second state
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # General fusion: both states have contents

        def _only_access_nodes(candidates):
            # Keep the data (access) nodes of the given node collection
            return [c for c in candidates if isinstance(c, nodes.AccessNode)]

        first_input = _only_access_nodes(nxutil.find_source_nodes(first_state))
        first_output = _only_access_nodes(nxutil.find_sink_nodes(first_state))
        second_input = _only_access_nodes(
            nxutil.find_source_nodes(second_state))

        # Inputs whose label is also an output of the first state are handled
        # through that output instead
        first_input = [
            src for src in first_input
            if not any(sink.label == src.label for sink in first_output)
        ]

        # Bring the contents of the second state into the first state
        for moved in second_state.nodes():
            first_state.add_node(moved)
        for u, u_conn, v, v_conn, memlet in second_state.edges():
            first_state.add_edge(u, u_conn, v, v_conn, memlet)

        # Inputs shared by both states: keep the node of the first state
        for kept in first_input:
            duplicate = next(
                (x for x in second_input if x.label == kept.label), None)
            if duplicate is None:
                continue
            nxutil.change_edge_src(first_state, duplicate, kept)
            first_state.remove_node(duplicate)
            second_input.remove(duplicate)

        # Outputs of the first state that the second state reads: keep the
        # node of the second state
        for produced in first_output:
            consumer = next(
                (x for x in second_input if x.label == produced.label), None)
            if consumer is None:
                continue
            nxutil.change_edge_dest(first_state, produced, consumer)
            first_state.remove_node(produced)
            second_input.remove(consumer)

        # The first state takes over the outgoing edges of the second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
Esempio n. 13
0
    def apply(self, sdfg):
        """ Replaces the matched reduce node with (up to) two nested maps.

            The inner map ranges over the reduction axes, while the outer map
            ranges over the rest of the input dimensions. The outer map is
            only created if there are non-reduced dimensions. The inner map
            contains a trivial tasklet (``out_1 = in_1``), while the outgoing
            edges carry the WCR of the reduce node.

            :param sdfg: The SDFG to modify in place.
            :raises NotImplementedError: If the reduce node has more than one
                                         distinct input node or more than one
                                         distinct output node.
        """
        graph = sdfg.nodes()[self.state_id]
        red_node = graph.nodes()[self.subgraph[ReduceExpansion._reduce]]

        # Collect the distinct source nodes (and the memlet of the first edge
        # coming from each of them)
        inputs = []
        in_memlets = []
        for src, _, _, _, memlet in graph.in_edges(red_node):
            if src not in inputs:
                inputs.append(src)
                in_memlets.append(memlet)
        if len(inputs) > 1:
            raise NotImplementedError

        # Same for the distinct destination nodes
        outputs = []
        out_memlets = []
        for _, _, dst, _, memlet in graph.out_edges(red_node):
            if dst not in outputs:
                outputs.append(dst)
                out_memlets.append(memlet)
        if len(outputs) > 1:
            raise NotImplementedError

        # NOTE(review): in_memlets[0] / out_memlets[0] are used below without
        # checking that the reduce node has any input/output edge -- confirm
        # that the pattern match guarantees this

        # No axes given: reduce over all dimensions of the input
        axes = red_node.axes
        if axes is None:
            axes = tuple(i for i in range(in_memlets[0].subset.dims()))

        # Split the input dimensions into reduced (inner map) and non-reduced
        # (outer map) ones; the map parameter of dimension i is "__dim_i"
        outer_map_range = {}
        inner_map_range = {}
        for idx, r in enumerate(in_memlets[0].subset):
            if idx in axes:
                inner_map_range.update({
                    "__dim_{}".format(str(idx)):
                    subsets.Range.dim_to_string(r)
                })
            else:
                outer_map_range.update({
                    "__dim_{}".format(str(idx)):
                    subsets.Range.dim_to_string(r)
                })

        if len(outer_map_range) > 0:
            outer_map_entry, outer_map_exit = graph.add_map(
                'reduce_outer', outer_map_range, schedule=red_node.schedule)

        # The schedule of the reduce node goes to the outermost map that is
        # created; an inner map nested in an outer one gets the default
        inner_map_entry, inner_map_exit = graph.add_map(
            'reduce_inner',
            inner_map_range,
            schedule=(dtypes.ScheduleType.Default
                      if len(outer_map_range) > 0 else red_node.schedule))

        tasklet = graph.add_tasklet(name='red_tasklet',
                                    inputs={'in_1'},
                                    outputs={'out_1'},
                                    code='out_1 = in_1')

        inner_map_entry.in_connectors = {'IN_1'}
        inner_map_entry.out_connectors = {'OUT_1'}

        # Input node -> outermost map entry: the original input memlet
        outer_in_memlet = dcpy(in_memlets[0])

        if len(outer_map_range) > 0:
            outer_map_entry.in_connectors = {'IN_1'}
            outer_map_entry.out_connectors = {'OUT_1'}
            graph.add_edge(inputs[0], None, outer_map_entry, 'IN_1',
                           outer_in_memlet)
        else:
            graph.add_edge(inputs[0], None, inner_map_entry, 'IN_1',
                           outer_in_memlet)

        # Outer map entry -> inner map entry: full range along the reduction
        # axes, a single outer-map index along every other dimension
        med_in_memlet = dcpy(in_memlets[0])
        med_in_range = []
        for idx, r in enumerate(med_in_memlet.subset):
            if idx in axes:
                med_in_range.append(r)
            else:
                med_in_range.append(("__dim_{}".format(str(idx)),
                                     "__dim_{}".format(str(idx)), 1))
        med_in_memlet.subset = subsets.Range(med_in_range)
        med_in_memlet.num_accesses = med_in_memlet.subset.num_elements()

        if len(outer_map_range) > 0:
            graph.add_edge(outer_map_entry, 'OUT_1', inner_map_entry, 'IN_1',
                           med_in_memlet)

        # Inner map entry -> tasklet: one index per dimension
        inner_in_memlet = dcpy(med_in_memlet)
        inner_in_idx = []
        for idx in range(len(inner_in_memlet.subset)):
            inner_in_idx.append("__dim_{}".format(str(idx)))
        inner_in_memlet.subset = subsets.Indices(inner_in_idx)
        inner_in_memlet.num_accesses = inner_in_memlet.subset.num_elements()
        graph.add_edge(inner_map_entry, 'OUT_1', tasklet, 'in_1',
                       inner_in_memlet)
        inner_map_exit.in_connectors = {'IN_1'}
        inner_map_exit.out_connectors = {'OUT_1'}

        # Tasklet -> inner map exit: the indices of the non-reduced
        # dimensions (index 0 if every dimension is reduced), with the WCR of
        # the reduce node
        inner_out_memlet = dcpy(out_memlets[0])
        inner_out_idx = []
        for idx, r in enumerate(inner_in_memlet.subset):
            if idx not in axes:
                inner_out_idx.append(r)
        if len(inner_out_idx) == 0:
            inner_out_idx = [0]

        inner_out_memlet.subset = subsets.Indices(inner_out_idx)
        inner_out_memlet.wcr = red_node.wcr
        inner_out_memlet.num_accesses = inner_out_memlet.subset.num_elements()
        graph.add_edge(tasklet, 'out_1', inner_map_exit, 'IN_1',
                       inner_out_memlet)

        # Outermost map exit -> output node: the ranges of the original
        # output memlet whose position is not in `axes` ((0, 0, 1) if none
        # remain), with the WCR of the reduce node
        outer_out_memlet = dcpy(out_memlets[0])
        outer_out_range = []
        for idx, r in enumerate(outer_out_memlet.subset):
            if idx not in axes:
                outer_out_range.append(r)
        if len(outer_out_range) == 0:
            outer_out_range = [(0, 0, 1)]

        outer_out_memlet.subset = subsets.Range(outer_out_range)
        outer_out_memlet.wcr = red_node.wcr

        if len(outer_map_range) > 0:
            outer_map_exit.in_connectors = {'IN_1'}
            outer_map_exit.out_connectors = {'OUT_1'}
            med_out_memlet = dcpy(inner_out_memlet)
            med_out_memlet.num_accesses = med_out_memlet.subset.num_elements()
            graph.add_edge(inner_map_exit, 'OUT_1', outer_map_exit, 'IN_1',
                           med_out_memlet)

            graph.add_edge(outer_map_exit, 'OUT_1', outputs[0], None,
                           outer_out_memlet)
        else:
            graph.add_edge(inner_map_exit, 'OUT_1', outputs[0], None,
                           outer_out_memlet)

        # Finally, detach and delete the reduce node
        # NOTE(review): only the first incoming/outgoing edge is removed
        # explicitly before the node itself is removed
        graph.remove_edge(graph.in_edges(red_node)[0])
        graph.remove_edge(graph.out_edges(red_node)[0])
        graph.remove_node(red_node)

        return