Пример #1
0
    def apply(self, sdfg: SDFG):
        """
        Inline the matched nested SDFG node into the state that contains it.

        The first state of the nested SDFG is the one whose contents are
        moved. Its global/init/exit code and constants are merged into
        ``sdfg``, nested transients are re-registered in ``sdfg`` under
        fresh names, nested data names are replaced by their outer
        counterparts (or by newly added views when a reshape is needed),
        the nodes and edges are copied into the outer state and reconnected
        to the memlet paths that used to end at the nested SDFG node, and
        finally the nested SDFG node itself is removed.

        :param sdfg: The SDFG that contains the state with the nested SDFG
                     node; it is modified in-place.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        # Only the first state of the nested SDFG is inlined
        nstate: SDFGState = nsdfg.nodes()[0]

        # A non-default schedule on the nested SDFG node is handed to the
        # type-inference helper together with the nested SDFG
        # (presumably so that it becomes the default for the nested
        # contents -- TODO confirm against infer_types)
        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        # Scope (if any) that surrounds the nested SDFG node in the outer
        # state; used later to attach otherwise dangling nodes
        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants: an existing outer constant wins; a differing inner value
        # only produces a warning
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        # inputs/outputs: connector name -> outer edge
        # input_set/output_set: outer data name -> connector name
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Access nodes that need to be reshaped: non-transient nested arrays
        # with more dimensions than the subset of the outer memlet, or whose
        # strides fail the _check_strides test against the outer array
        reshapes: Set[str] = set()
        for aname, array in nsdfg.arrays.items():
            if array.transient:
                continue
            edge = None
            if aname in inputs:
                edge = inputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if aname in outputs:
                # An output edge takes precedence over the input edge for the
                # stride check below
                edge = outputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if edge is not None and not InlineSDFG._check_strides(
                    array.strides, sdfg.arrays[edge.data.data].strides,
                    edge.data, nsdfg_node):
                reshapes.add(aname)

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    # Requested name is "<nested label>_<data name>"
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                # Empty memlets carry no data and are skipped
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        name = sdfg.add_datadesc('%s_%s' %
                                                 (nsdfg.label, edge.data.data),
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        # Source/sink access nodes of non-transient, non-reshaped data are not
        # copied; they are mapped to the outer edge of the same-named connector
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        # (transient renames first, then connector name -> outer data name;
        # on a name collision the later update wins)
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        # Add views whenever reshapes are necessary
        for dname in reshapes:
            desc = nsdfg.arrays[dname]
            # To avoid potential confusion, rename protected __return keyword
            if dname.startswith('__return'):
                # dname[8:] is whatever follows the '__return' prefix
                newname = f'{nsdfg.name}_ret{dname[8:]}'
            else:
                newname = dname
            newname, _ = sdfg.add_view(newname,
                                       desc.shape,
                                       desc.dtype,
                                       storage=desc.storage,
                                       strides=desc.strides,
                                       offset=desc.offset,
                                       debuginfo=desc.debuginfo,
                                       allow_conflicts=desc.allow_conflicts,
                                       total_size=desc.total_size,
                                       alignment=desc.alignment,
                                       may_alias=desc.may_alias,
                                       find_new_name=True)
            # The view name overrides the plain outer-name replacement
            repldict[dname] = newname

        # Apply the renaming to access nodes and memlets of the nested state
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in nstate.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Add extra access nodes for out/in view nodes
        # NOTE(review): node.data has already been renamed above, while
        # `reshapes` and `outputs` are keyed by nested names/connectors, so
        # this branch only matches when add_view kept the original name --
        # confirm this is intended.
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in reshapes:
                if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                    # Such a node has to be in the output set
                    edge = outputs[node.data]

                    # Redirect outgoing edges through access node
                    # (node -> outer data access -> second view node, which
                    # takes over all former outgoing edges of `node`)
                    out_edges = list(nstate.out_edges(node))
                    anode = nstate.add_access(edge.data.data)
                    vnode = nstate.add_access(node.data)
                    nstate.add_nedge(node, anode, edge.data)
                    nstate.add_nedge(anode, vnode, edge.data)
                    for e in out_edges:
                        nstate.remove_edge(e)
                        nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                        e.data)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        # (everything except the source/sink access nodes collected above)
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Reshape: add connections to viewed data
        self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                                  True)
        self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                                  False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    # Input edge is preferred when the data is both read and
                    # written through the nested SDFG node
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    # Writes the private `_data` attribute of
                                    # the edge directly
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): the keys of `inputs`/`outputs` are connector names,
        # whereas `source_accesses`/`sink_accesses` hold access node objects,
        # so these set differences cannot remove any key -- confirm intent.
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        # (access nodes of the nested state in topological order)
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            # (next() raises StopIteration if no access node matches)
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)
Пример #2
0
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        The subgraph nodes are moved into the single state of a new SDFG,
        which is added to ``state`` as a nested SDFG node. Edges that crossed
        the subgraph boundary are reconnected to that node; data that is only
        accessed inside the subgraph is moved into the nested SDFG.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG. Defaults to
                     ``'nested_' + state.label``.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope,
                           or contains only part of a scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    # Membership in the subgraph is tested many times below; build the set
    # once instead of scanning the node list on every test
    subgraph_nodes = set(subgraph.nodes())

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph_nodes for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph_nodes for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # Every node whose enclosing scope lies outside the subgraph must
        # agree on that enclosing scope
        scope_node = scope_dict[node]
        if scope_node not in subgraph_nodes:
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)

    # Collect inputs and outputs of the nested SDFG (edges that cross the
    # subgraph boundary)
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in subgraph.nodes():
        for edge in state.in_edges(node):
            if edge.src not in subgraph_nodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in subgraph_nodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph_nodes)
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            # An empty memlet between two code nodes refers to no data;
            # adding None here would fail on the sdfg.arrays lookups below
            if edge.data.data is not None:
                subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = {}  # data name -> write-conflict resolution of first in-edge
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for dname in subgraph_transients:
        nsdfg.add_datadesc(dname, sdfg.arrays[dname])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for dname in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[dname])
        datadesc.transient = False
        nsdfg.add_datadesc(dname, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = {}
    output_names = {}
    # outer data name -> (name inside nested SDFG, subset covered so far)
    global_subsets: Dict[str, Tuple[str, Subset]] = {}

    def _global_name_for(edge: MultiConnectorEdge) -> str:
        """ Registers (or widens) the nested descriptor for a boundary edge
            and returns its name inside the nested SDFG. """
        outer_name = edge.data.data
        if outer_name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[outer_name])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(outer_name, datadesc, find_new_name=True)
            global_subsets[outer_name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[outer_name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    # No union could be computed: fall back to the full array
                    new_subset = Range.from_array(sdfg.arrays[outer_name])
                global_subsets[outer_name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        return new_name

    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        input_names[edge] = _global_name_for(edge)
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        output_names[edge] = _global_name_for(edge)
    ###################

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s) for s in (state.symbols_defined_at(top_scopenode).keys()
                                                | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []  # pairs of (original boundary edge, new inner edge)
    for edge, inner_name in input_names.items():
        node = nstate.add_read(inner_name)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = inner_name
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_memlet)))
    for edge, inner_name in output_names.items():
        node = nstate.add_write(inner_name)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = inner_name
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_memlet)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG (one outer edge per connector)
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            # Remember an empty-memlet edge in case nothing else connects
            empty_input = edge
            continue

        conn = input_names[edge]
        if conn in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, conn, data)
        reconnected_in.add(conn)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        conn = output_names[edge]
        if conn in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, conn, edge.dst, edge.dst_conn, data)
        reconnected_out.add(conn)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for dname in input_arrays:
        node = state.add_read(dname)
        if entry is not None:
            # Keep the new access node inside the enclosing scope
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, dname, Memlet.from_array(dname, sdfg.arrays[dname]))
    for dname, wcr in output_arrays.items():
        node = state.add_write(dname)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, dname, node, None, Memlet(data=dname, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
Пример #3
0
def generate_reference(name, chain):
    """Build a simple, unoptimized SDFG for ``chain``, for verification.

    One state is created per kernel, chained in topological order after an
    initial "init" state. Each state holds the stencil node of the kernel
    (with its implementation set to "CPU"), reads of its array inputs and
    writes of its output over the full iteration shape.

    :param name: Name given to the generated SDFG.
    :param chain: Stencil chain providing ``constants`` and ``graph``.
    :return: The generated SDFG.
    :raise ValueError: An input node of the chain has no outputs.
    """
    sdfg = SDFG(name)

    for const_name, const_info in chain.constants.items():
        sdfg.add_constant(const_name, const_info["value"],
                          dace.data.Scalar(const_info["data_type"]))

    # Vectorization is thrown in the bin for the reference code, so the
    # vector length (like the remaining unused entries) is discarded
    (dimensions_to_skip, shape, _vector_length, parameters, _iterators,
     _memcopy_indices, _memcopy_accesses) = _generate_init(chain)

    prev_state = sdfg.add_state("init")

    shape = tuple(int(extent) for extent in shape)
    # Range string covering the whole iteration space, used by every write
    full_range = ", ".join("0:{}".format(extent) for extent in shape)

    # Maps inputs to their shape tuple
    input_shapes = {}

    for node in chain.graph.nodes():
        if isinstance(node, Input):
            node_outputs = list(node.outputs.values())
            if not node_outputs:
                raise ValueError("No outputs found for input node.")
            # Only the first output decides which dimensions the input spans
            first_output = node_outputs[0]
            dims = (first_output["input_dims"]
                    if "input_dims" in first_output else None)
            pars = tuple(parameters) if dims is None else tuple(dims)
            arr_shape = tuple(extent
                              for extent, par in zip(shape, parameters)
                              if par in pars)
            input_shapes[node.name] = arr_shape
        elif isinstance(node, Output):
            arr_shape = shape
        else:
            continue

        if not arr_shape:
            # Zero-dimensional data is registered as a symbol
            sdfg.add_symbol(node.name, node.data_type)
            continue
        try:
            sdfg.add_array(node.name, arr_shape, node.data_type)
        except NameError:
            # add_array raised NameError: mark the data of that name as
            # read-write instead
            sdfg.data(node.name).access = dace.dtypes.AccessType.ReadWrite

    # Sources of graph edges not registered so far become full-shape
    # transient arrays
    for link in chain.graph.edges(data=True):
        producer = link[0]
        if producer.name in sdfg.arrays or producer.name in sdfg.symbols:
            continue
        sdfg.add_array(producer.name,
                       shape,
                       producer.data_type,
                       transient=True)
        input_shapes[producer.name] = tuple(shape)

    input_iterators = {}
    for field, extents in input_shapes.items():
        input_iterators[field] = tuple("0:{}".format(extent)
                                       for extent in extents)

    # Enforce dependencies via topological sort
    for node in nx.topological_sort(chain.graph):
        if not isinstance(node, Kernel):
            continue

        state = sdfg.add_state(node.name)
        sdfg.add_edge(prev_state, state, dace.InterstateEdge())

        stencil_node, input_to_connector, output_to_connector = (
            _generate_stencil(node, chain, shape, dimensions_to_skip))
        stencil_node.implementation = "CPU"

        # Outer memory reads
        for field, connector in input_to_connector.items():
            ranges = input_iterators[field]
            if not ranges:
                continue  # Scalar variable
            read_node = state.add_read(field)
            read_memlet = Memlet.simple(field, ", ".join(ranges))
            state.add_memlet_path(read_node,
                                  stencil_node,
                                  dst_conn=connector,
                                  memlet=read_memlet)

        # Outer writes
        for connector in output_to_connector.values():
            write_node = state.add_write(node.name)
            write_memlet = Memlet.simple(node.name, full_range)
            state.add_memlet_path(stencil_node,
                                  write_node,
                                  src_conn=connector,
                                  memlet=write_memlet)

        prev_state = state

    return sdfg
Пример #4
0
    def apply(self, outer_state: SDFGState, sdfg: SDFG):
        nsdfg_node = self.nested_sdfg
        nsdfg: SDFG = nsdfg_node.sdfg

        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Environments
        for nstate in nsdfg.nodes():
            for node in nstate.nodes():
                if isinstance(node, nodes.CodeNode):
                    node.environments |= nsdfg_node.environments

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Symbols
        outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
        for ise in sdfg.edges():
            outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in outer_state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in outer_state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

        # Access nodes that need to be reshaped
        # reshapes: Set(str) = set()
        # for aname, array in nsdfg.arrays.items():
        #     if array.transient:
        #         continue
        #     edge = None
        #     if aname in inputs:
        #         edge = inputs[aname]
        #         if len(array.shape) > len(edge.data.subset):
        #             reshapes.add(aname)
        #             continue
        #     if aname in outputs:
        #         edge = outputs[aname]
        #         if len(array.shape) > len(edge.data.subset):
        #             reshapes.add(aname)
        #             continue
        #     if edge is not None and not InlineMultistateSDFG._check_strides(
        #             array.strides, sdfg.arrays[edge.data.data].strides,
        #             edge.data, nsdfg_node):
        #         reshapes.add(aname)

        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}

        # All transients become transients of the parent (if data already
        # exists, find new name)
        for nstate in nsdfg.nodes():
            for node in nstate.nodes():
                if isinstance(node, nodes.AccessNode):
                    datadesc = nsdfg.arrays[node.data]
                    if node.data not in transients and datadesc.transient:
                        new_name = node.data
                        if (new_name in sdfg.arrays
                                or new_name in outer_symbols
                                or new_name in sdfg.constants):
                            new_name = f'{nsdfg.label}_{node.data}'

                        name = sdfg.add_datadesc(new_name,
                                                 datadesc,
                                                 find_new_name=True)
                        transients[node.data] = name

            # All transients of edges between code nodes are also added to parent
            for edge in nstate.edges():
                if (isinstance(edge.src, nodes.CodeNode)
                        and isinstance(edge.dst, nodes.CodeNode)):
                    if edge.data.data is not None:
                        datadesc = nsdfg.arrays[edge.data.data]
                        if edge.data.data not in transients and datadesc.transient:
                            new_name = edge.data.data
                            if (new_name in sdfg.arrays
                                    or new_name in outer_symbols
                                    or new_name in sdfg.constants):
                                new_name = f'{nsdfg.label}_{edge.data.data}'

                            name = sdfg.add_datadesc(new_name,
                                                     datadesc,
                                                     find_new_name=True)
                            transients[edge.data.data] = name

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        symbolic.safe_replace(repldict,
                              lambda m: replace_datadesc_names(nsdfg, m),
                              value_as_string=True)

        # Add views whenever reshapes are necessary
        # for dname in reshapes:
        #     desc = nsdfg.arrays[dname]
        #     # To avoid potential confusion, rename protected __return keyword
        #     if dname.startswith('__return'):
        #         newname = f'{nsdfg.name}_ret{dname[8:]}'
        #     else:
        #         newname = dname
        #     newname, _ = sdfg.add_view(newname,
        #                                desc.shape,
        #                                desc.dtype,
        #                                storage=desc.storage,
        #                                strides=desc.strides,
        #                                offset=desc.offset,
        #                                debuginfo=desc.debuginfo,
        #                                allow_conflicts=desc.allow_conflicts,
        #                                total_size=desc.total_size,
        #                                alignment=desc.alignment,
        #                                may_alias=desc.may_alias,
        #                                find_new_name=True)
        #     repldict[dname] = newname

        # Add extra access nodes for out/in view nodes
        # inv_reshapes = {repldict[r]: r for r in reshapes}
        # for nstate in nsdfg.nodes():
        #     for node in nstate.nodes():
        #         if isinstance(node,
        #                       nodes.AccessNode) and node.data in inv_reshapes:
        #             if nstate.in_degree(node) > 0 and nstate.out_degree(
        #                     node) > 0:
        #                 # Such a node has to be in the output set
        #                 edge = outputs[inv_reshapes[node.data]]

        #                 # Redirect outgoing edges through access node
        #                 out_edges = list(nstate.out_edges(node))
        #                 anode = nstate.add_access(edge.data.data)
        #                 vnode = nstate.add_access(node.data)
        #                 nstate.add_nedge(node, anode, edge.data)
        #                 nstate.add_nedge(anode, vnode, edge.data)
        #                 for e in out_edges:
        #                     nstate.remove_edge(e)
        #                     nstate.add_edge(vnode, e.src_conn, e.dst,
        #                                     e.dst_conn, e.data)

        # Make unique names for states
        statenames = set(s.label for s in sdfg.nodes())
        for nstate in nsdfg.nodes():
            if nstate.label in statenames:
                newname = data.find_new_name(nstate.label, statenames)
                statenames.add(newname)
                nstate.set_label(newname)

        #######################################################
        # Collect and modify interstate edges as necessary

        outer_assignments = set()
        for e in sdfg.edges():
            outer_assignments |= e.data.assignments.keys()

        inner_assignments = set()
        for e in nsdfg.edges():
            inner_assignments |= e.data.assignments.keys()

        assignments_to_replace = inner_assignments & outer_assignments
        sym_replacements: Dict[str, str] = {}
        allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
        for assign in assignments_to_replace:
            newname = data.find_new_name(assign, allnames)
            allnames.add(newname)
            sym_replacements[assign] = newname
        nsdfg.replace_dict(sym_replacements)

        #######################################################
        # Add nested SDFG states into top-level SDFG

        outer_start_state = sdfg.start_state

        sdfg.add_nodes_from(nsdfg.nodes())
        for ise in nsdfg.edges():
            sdfg.add_edge(ise.src, ise.dst, ise.data)

        #######################################################
        # Reconnect inlined SDFG

        source = nsdfg.start_state
        sinks = nsdfg.sink_nodes()

        # Reconnect state machine
        for e in sdfg.in_edges(outer_state):
            sdfg.add_edge(e.src, source, e.data)
        for e in sdfg.out_edges(outer_state):
            for sink in sinks:
                sdfg.add_edge(sink, e.dst, e.data)

        # Modify start state as necessary
        if outer_start_state is outer_state:
            sdfg.start_state = sdfg.node_id(source)

        # TODO: Modify memlets by offsetting
        # If both source and sink nodes are inputs/outputs, reconnect once
        # edges_to_ignore = self._modify_access_to_access(new_incoming_edges,
        #                                                 nsdfg, nstate, state,
        #                                                 orig_data)

        # source_to_outer = {n: e.src for n, e in new_incoming_edges.items()}
        # sink_to_outer = {n: e.dst for n, e in new_outgoing_edges.items()}
        # # If a source/sink node is one of the inputs/outputs, reconnect it,
        # # replacing memlets in outgoing/incoming paths
        # modified_edges = set()
        # modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
        #                                            state, sink_to_outer, True,
        #                                            edges_to_ignore)
        # modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
        #                                            state, source_to_outer,
        #                                            False, edges_to_ignore)

        # # Reshape: add connections to viewed data
        # self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
        #                           True)
        # self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
        #                           False)

        # Modify all other internal edges pertaining to input/output nodes
        # for nstate in nsdfg.nodes():
        #     for node in nstate.nodes():
        #         if isinstance(node, nodes.AccessNode):
        #             if node.data in input_set or node.data in output_set:
        #                 if node.data in input_set:
        #                     outer_edge = inputs[input_set[node.data]]
        #                 else:
        #                     outer_edge = outputs[output_set[node.data]]

        #                 for edge in state.all_edges(node):
        #                     if (edge not in modified_edges
        #                             and edge.data.data == node.data):
        #                         for e in state.memlet_tree(edge):
        #                             if e.data.data == node.data:
        #                                 e._data = helpers.unsqueeze_memlet(
        #                                     e.data, outer_edge.data)

        # Replace nested SDFG parents with new SDFG
        for nstate in nsdfg.nodes():
            nstate.parent = sdfg
            for node in nstate.nodes():
                if isinstance(node, nodes.NestedSDFG):
                    node.sdfg.parent_sdfg = sdfg
                    node.sdfg.parent_nsdfg_node = node

        #######################################################
        # Remove nested SDFG and state
        sdfg.remove_node(outer_state)

        return nsdfg.nodes()
Пример #5
0
def generate_sdfg(name,
                  chain,
                  synthetic_reads=False,
                  specialize_scalars=False):
    """Build an SDFG implementing the given stencil chain.

    The resulting SDFG consists of three states executed in sequence:
    ``initialize`` (copies host arrays into device arrays), ``compute``
    (reader maps, stencil nodes and writer maps that communicate through
    named streams) and ``finalize`` (copies device arrays back to the host).

    :param name: Name passed to the ``SDFG`` constructor.
    :param chain: Stencil chain to generate code for. This function reads
                  ``chain.constants``, ``chain.inputs`` and ``chain.graph``
                  and forwards the chain to ``_generate_init`` and
                  ``_generate_stencil``.
    :param synthetic_reads: If falsy, every input is read from a device
                            array that is filled from a ``<name>_host`` array
                            in the ``initialize`` state. If truthy, no input
                            arrays are created and each reader emits the
                            constant ``float(synthetic_reads)`` instead.
    :param specialize_scalars: If true, zero-dimensional inputs are loaded
                               with ``stencilflow.load_array`` and registered
                               as SDFG constants. Inputs whose file cannot be
                               found are silently left unspecialized.
    :return: The generated SDFG.
    :raises ValueError: If an input has no outputs, a stencil without output
                        fields still has inputs, or the configured FPGA
                        vendor is not supported.
    :raises RuntimeError: If an output does not have exactly one writer, or
                          the chain graph contains an unexpected node type.
    """
    sdfg = SDFG(name)

    for k, v in chain.constants.items():
        sdfg.add_constant(k, v["value"], dace.data.Scalar(v["data_type"]))

    if specialize_scalars:
        for k, v in chain.inputs.items():
            if len(v["input_dims"]) == 0:
                try:
                    val = stencilflow.load_array(v)
                except FileNotFoundError:
                    # No data file for this scalar: keep it unspecialized
                    continue
                print(f"Specialized constant {k} to {val}.")
                sdfg.add_constant(k, val)

    pre_state = sdfg.add_state("initialize")
    state = sdfg.add_state("compute")
    post_state = sdfg.add_state("finalize")

    sdfg.add_edge(pre_state, state, InterstateEdge())
    sdfg.add_edge(state, post_state, InterstateEdge())

    (dimensions_to_skip, shape, vector_length, parameters, iterators,
     memcopy_indices, memcopy_accesses) = _generate_init(chain)
    # Shape in units of vectors: the innermost dimension is divided by the
    # vector length
    vshape = list(shape)  # Copy
    if vector_length > 1:
        vshape[-1] //= vector_length

    def add_input(node, bank):
        """Add the reader for input ``node`` to the compute state.

        Scalar (zero-dimensional) inputs only become SDFG symbols. For all
        other inputs a map with a ``read_<name>`` tasklet is created, which
        writes to one ``read_<name>_to_<consumer>`` stream per consumer.

        :param node: Input node of the chain graph.
        :param bank: Value stored in ``location["bank"]`` of the device
                     array.
        :raises ValueError: If the input has no outputs.
        """

        # Collapse iterators and shape if input is lower dimensional
        for output in node.outputs.values():
            try:
                input_pars = output["input_dims"][:]
            except (KeyError, TypeError):
                # input_dims is missing or None: the input is full-dimensional
                input_pars = list(parameters)  # Copy
            break  # Just needed any output to retrieve the dimensions
        else:
            raise ValueError("Input {} is not connected to anything.".format(
                node.name))
        # If scalar, just add a symbol
        if len(input_pars) == 0:
            sdfg.add_symbol(node.name, node.data_type)
            return  # We're done
        parameter_list = list(parameters)
        input_shape = [shape[parameter_list.index(i)] for i in input_pars]
        input_accesses = str(functools.reduce(operator.mul, input_shape, 1))
        # Only vectorize the read if the innermost dimensions is read
        input_vector_length = (vector_length
                               if input_pars[-1] == parameters[-1] else 1)
        input_vtype = (dace.dtypes.vector(node.data_type, input_vector_length)
                       if input_vector_length > 1 else node.data_type)
        input_vshape = list(input_shape)
        if input_vector_length > 1:
            input_vshape[-1] //= input_vector_length

        # Sort to get deterministic output
        outputs = sorted([e[1].name for e in chain.graph.out_edges(node)])

        # Tasklet output connector names, one per consumer
        out_memlets = ["_" + o for o in outputs]

        map_entry, map_exit = state.add_map("read_" + node.name,
                                            iterators,
                                            schedule=ScheduleType.FPGA_Device)

        if not synthetic_reads:  # Read the input from memory

            # Host-side array, which will be an input argument
            sdfg.add_array(node.name + "_host", input_shape, node.data_type)

            # Device-side copy
            _, array = sdfg.add_array(node.name,
                                      input_vshape,
                                      input_vtype,
                                      storage=StorageType.FPGA_Global,
                                      transient=True)
            array.location["bank"] = bank
            access_node = state.add_read(node.name)

            # Copy data to the FPGA
            copy_host = pre_state.add_read(node.name + "_host")
            copy_fpga = pre_state.add_write(node.name)
            pre_state.add_memlet_path(copy_host,
                                      copy_fpga,
                                      memlet=Memlet.simple(
                                          copy_fpga,
                                          ", ".join("0:{}".format(s)
                                                    for s in input_vshape),
                                          num_accesses=input_accesses))

            # Forward the value read from memory to every consumer
            tasklet_code = "\n".join(
                ["{} = memory".format(o) for o in out_memlets])

            tasklet = state.add_tasklet("read_" + node.name, {"memory"},
                                        out_memlets, tasklet_code)

            vectorized_pars = input_pars
            # if input_vector_length > 1:
            #     vectorized_pars[-1] = "{}*{}".format(input_vector_length,
            #                                          vectorized_pars[-1])

            # Lower-dimensional arrays should buffer values and send them
            # multiple times
            is_lower_dim = len(input_shape) != len(shape)
            if is_lower_dim:
                buffer_name = node.name + "_buffer"
                sdfg.add_array(buffer_name,
                               input_shape,
                               input_vtype,
                               storage=StorageType.FPGA_Local,
                               transient=True)
                buffer_node = state.add_access(buffer_name)
                # Separate map that only iterates over the input's own
                # dimensions to fill the buffer
                buffer_entry, buffer_exit = state.add_map(
                    "buffer_" + node.name, {
                        k: "0:{}".format(v)
                        for k, v in zip(input_pars, input_shape)
                    },
                    schedule=dace.ScheduleType.FPGA_Device)
                buffer_tasklet = state.add_tasklet("buffer_" + node.name,
                                                   {"memory"}, {"buffer"},
                                                   "buffer = memory")
                state.add_memlet_path(access_node,
                                      buffer_entry,
                                      buffer_tasklet,
                                      dst_conn="memory",
                                      memlet=dace.Memlet.simple(
                                          access_node.data,
                                          ", ".join(vectorized_pars),
                                          num_accesses=1))
                state.add_memlet_path(buffer_tasklet,
                                      buffer_exit,
                                      buffer_node,
                                      src_conn="buffer",
                                      memlet=dace.Memlet.simple(
                                          buffer_node.data,
                                          ", ".join(input_pars),
                                          num_accesses=1))
                # The reader map then reads from the buffer instead of the
                # device array
                state.add_memlet_path(buffer_node,
                                      map_entry,
                                      tasklet,
                                      dst_conn="memory",
                                      memlet=dace.Memlet.simple(
                                          buffer_node.data,
                                          ", ".join(input_pars),
                                          num_accesses=1))
            else:

                state.add_memlet_path(access_node,
                                      map_entry,
                                      tasklet,
                                      dst_conn="memory",
                                      memlet=Memlet.simple(
                                          node.name,
                                          ", ".join(vectorized_pars),
                                          num_accesses=1))

        else:  # Generate synthetic inputs without memory

            # Every consumer receives the constant float(synthetic_reads)
            tasklet_code = "\n".join([
                "{} = {}".format(o, float(synthetic_reads))
                for o in out_memlets
            ])

            tasklet = state.add_tasklet("read_" + node.name, {}, out_memlets,
                                        tasklet_code)

            # Empty memlet: only places the tasklet inside the map scope
            state.add_memlet_path(map_entry, tasklet, memlet=dace.Memlet())

        # Add memlets to all FIFOs connecting to compute units
        for out_name, out_memlet in zip(outputs, out_memlets):
            stream_name = "read_{}_to_{}".format(node.name, out_name)
            write_node = state.add_write(stream_name)
            state.add_memlet_path(tasklet,
                                  map_exit,
                                  write_node,
                                  src_conn=out_memlet,
                                  memlet=Memlet.simple(stream_name,
                                                       "0",
                                                       num_accesses=1))

    def add_output(node, bank):
        """Add the writer for output ``node`` to the compute state.

        Creates (or reuses) the ``<name>_host`` and device arrays, a
        ``write_<name>`` map that drains the ``<writer>_to_write_<name>``
        stream into the device array, and the device-to-host copy in the
        finalize state.

        :param node: Output node of the chain graph.
        :param bank: Value stored in ``location["bank"]`` of the device
                     array.
        :raises RuntimeError: If the output does not have exactly one writer.
        """
        # Host-side array, which will be an output argument
        try:
            sdfg.add_array(node.name + "_host", shape, node.data_type)
            _, array = sdfg.add_array(node.name,
                                      vshape,
                                      dace.dtypes.vector(
                                          node.data_type, vector_length),
                                      storage=StorageType.FPGA_Global,
                                      transient=True)
            array.location["bank"] = bank
        except NameError:
            # This array is also read
            sdfg.data(node.name + "_host").access = dace.AccessType.ReadWrite
            sdfg.data(node.name).access = dace.AccessType.ReadWrite

        # Device-side copy
        write_node = state.add_write(node.name)

        # Copy data to the host
        copy_fpga = post_state.add_read(node.name)
        copy_host = post_state.add_write(node.name + "_host")
        post_state.add_memlet_path(copy_fpga,
                                   copy_host,
                                   memlet=Memlet.simple(
                                       copy_fpga,
                                       ", ".join(memcopy_indices),
                                       num_accesses=memcopy_accesses))

        map_entry, map_exit = state.add_map("write_" + node.name,
                                            iterators,
                                            schedule=ScheduleType.FPGA_Device)

        writers = chain.graph.in_edges(node)
        if len(writers) > 1:
            raise RuntimeError("Only one writer per output supported")
        if len(writers) == 0:
            # Fail with a descriptive error rather than a bare StopIteration
            raise RuntimeError(
                "Output {} is not connected to anything.".format(node.name))
        src = next(iter(writers))[0]

        in_memlet = "_" + src.name

        tasklet_code = "memory = " + in_memlet

        tasklet = state.add_tasklet("write_" + node.name, {in_memlet},
                                    {"memory"}, tasklet_code)

        vectorized_pars = copy.copy(parameters)
        # if vector_length > 1:
        #     vectorized_pars[-1] = "{}*{}".format(vector_length,
        #                                          vectorized_pars[-1])

        stream_name = "{}_to_write_{}".format(src.name, node.name)
        read_node = state.add_read(stream_name)

        state.add_memlet_path(read_node,
                              map_entry,
                              tasklet,
                              dst_conn=in_memlet,
                              memlet=Memlet.simple(stream_name,
                                                   "0",
                                                   num_accesses=1))

        state.add_memlet_path(tasklet,
                              map_exit,
                              write_node,
                              src_conn="memory",
                              memlet=Memlet.simple(node.name,
                                                   ", ".join(vectorized_pars),
                                                   num_accesses=1))

    def add_kernel(node):
        """Add the stencil node for kernel ``node`` to the compute state.

        The stencil node is connected to one stream per non-scalar input
        field and one stream per output field.

        :param node: Kernel node of the chain graph.
        :raises ValueError: If the stencil has inputs but no output fields,
                            or the configured FPGA vendor is unsupported.
        """

        (stencil_node, input_to_connector,
         output_to_connector) = _generate_stencil(node, chain, shape,
                                                  dimensions_to_skip)

        if len(stencil_node.output_fields) == 0:
            if len(input_to_connector) == 0:
                warnings.warn("Ignoring orphan stencil: {}".format(node.name))
            else:
                raise ValueError("Orphan stencil with inputs: {}".format(
                    node.name))
            return

        vendor_str = dace.config.Config.get("compiler", "fpga_vendor")
        if vendor_str == "intel_fpga":
            stencil_node.implementation = "Intel FPGA"
        elif vendor_str == "xilinx":
            stencil_node.implementation = "Xilinx"
        else:
            raise ValueError(f"Unsupported FPGA backend: {vendor_str}")
        state.add_node(stencil_node)

        # Whether each neighbor is a memory reader/writer (anything that is
        # not a kernel); this determines the name of the connecting stream
        is_from_memory = {
            e[0].name: not isinstance(e[0], stencilflow.kernel.Kernel)
            for e in chain.graph.in_edges(node)
        }
        is_to_memory = {
            e[1].name: not isinstance(e[1], stencilflow.kernel.Kernel)
            for e in chain.graph.out_edges(node)
        }

        # Add read nodes and memlets
        for field_name, connector in input_to_connector.items():

            try:
                # Scalars are symbols rather than data nodes
                if len(node.inputs[field_name]["input_dims"]) == 0:
                    continue
            except (KeyError, TypeError):
                pass  # input_dim is not defined or is None

            if is_from_memory[field_name]:
                stream_name = "read_{}_to_{}".format(field_name, node.name)
            else:
                stream_name = "{}_to_{}".format(field_name, node.name)

            # Outer memory read
            read_node = state.add_read(stream_name)
            state.add_memlet_path(read_node,
                                  stencil_node,
                                  dst_conn=connector,
                                  memlet=Memlet.simple(
                                      stream_name,
                                      "0",
                                      num_accesses=memcopy_accesses))

        # Add write nodes and memlets
        for output_name, connector in output_to_connector.items():

            # Add write node and memlet
            if is_to_memory[output_name]:
                stream_name = "{}_to_write_{}".format(node.name, output_name)
            else:
                stream_name = "{}_to_{}".format(node.name, output_name)

            # Outer write
            write_node = state.add_write(stream_name)
            state.add_memlet_path(stencil_node,
                                  write_node,
                                  src_conn=connector,
                                  memlet=Memlet.simple(
                                      stream_name,
                                      "0",
                                      num_accesses=memcopy_accesses))

    # First generate all connections between kernels and memories
    for link in chain.graph.edges(data=True):
        _add_pipe(sdfg, link, parameters, vector_length)

    # Banks are assigned round-robin over inputs and outputs
    bank = 0
    # Now generate all memory access functions so arrays are registered
    for node in chain.graph.nodes():
        if isinstance(node, Input):
            add_input(node, bank)
            bank = (bank + 1) % NUM_BANKS
        elif isinstance(node, Output):
            add_output(node, bank)
            bank = (bank + 1) % NUM_BANKS
        elif isinstance(node, Kernel):
            # Generate these separately after
            pass
        else:
            raise RuntimeError("Unexpected node type: {}".format(
                node.node_type))

    # Finally generate the compute kernels
    for node in chain.graph.nodes():
        if isinstance(node, Kernel):
            add_kernel(node)

    return sdfg