Exemplo n.º 1
def cutout_state(state: SDFGState, *nodes: nd.Node, make_copy: bool = True) -> SDFG:
    """
    Cut out a subgraph of a state from an SDFG to run separately for localized testing or optimization.
    The subgraph defined by the list of nodes will be extended to include access nodes of data containers necessary
    to run the graph separately. In addition, all transient data containers created outside the cut out graph will
    become global.

    :param state: The SDFG state in which the subgraph resides.
    :param nodes: The nodes in the subgraph to cut out.
    :param make_copy: If True, deep-copies every SDFG element in the copy. Otherwise, original references are kept.
    :return: A new SDFG named ``<parent name>_cutout`` that contains a single state holding the extended subgraph.
    """
    # Either deep-copy every node/memlet or reuse the original objects as-is
    create_element = copy.deepcopy if make_copy else (lambda x: x)
    sdfg = state.parent
    subgraph: StateSubgraphView = StateSubgraphView(state, nodes)
    subgraph = _extend_subgraph_with_access_nodes(state, subgraph)
    other_arrays = _containers_defined_outside(sdfg, state, subgraph)

    # Make a new SDFG with the included constants, used symbols, and data containers
    new_sdfg = SDFG(f'{state.parent.name}_cutout', sdfg.constants_prop)
    defined_syms = subgraph.defined_symbols()
    freesyms = subgraph.free_symbols
    for sym in freesyms:
        new_sdfg.add_symbol(sym, defined_syms[sym])

    for dnode in subgraph.data_nodes():
        # Several access nodes may refer to the same container; register it only once
        if dnode.data in new_sdfg.arrays:
            continue
        new_desc = sdfg.arrays[dnode.data].clone()
        # If transient is defined outside, it becomes a global
        if dnode.data in other_arrays:
            new_desc.transient = False
        new_sdfg.add_datadesc(dnode.data, new_desc)

    # Add a single state with the extended subgraph
    new_state = new_sdfg.add_state(state.label, is_start_state=True)
    # Maps each original node to the node object used in the new state
    inserted_nodes: Dict[nd.Node, nd.Node] = {}
    for e in subgraph.edges():
        if e.src not in inserted_nodes:
            inserted_nodes[e.src] = create_element(e.src)
        if e.dst not in inserted_nodes:
            inserted_nodes[e.dst] = create_element(e.dst)
        new_state.add_edge(inserted_nodes[e.src], e.src_conn, inserted_nodes[e.dst], e.dst_conn, create_element(e.data))

    # Insert remaining isolated nodes
    for n in subgraph.nodes():
        if n not in inserted_nodes:
            inserted_nodes[n] = create_element(n)
            new_state.add_node(inserted_nodes[n])

    # Remove remaining dangling connectors from scope nodes
    for node in inserted_nodes.values():
        # The set difference builds a fresh set, so removing connectors while looping is safe
        used_connectors = set(e.dst_conn for e in new_state.in_edges(node))
        for conn in (node.in_connectors.keys() - used_connectors):
            node.remove_in_connector(conn)
        used_connectors = set(e.src_conn for e in new_state.out_edges(node))
        for conn in (node.out_connectors.keys() - used_connectors):
            node.remove_out_connector(conn)

    return new_sdfg
Exemplo n.º 2
def promote_scalars_to_symbols(sdfg: sd.SDFG,
                               ignore: Optional[Set[str]] = None,
                               transients_only: bool = True,
                               integers_only: bool = True) -> Set[str]:
    """
    Promotes all matching transient scalars to SDFG symbols, changing all
    tasklets to inter-state assignments. This enables the transformed symbols
    to be used within states as part of memlets, and allows further
    transformations (such as loop detection) to use the information for
    optimization.

    :param sdfg: The SDFG to run the pass on.
    :param ignore: An optional set of strings of scalars to ignore.
    :param transients_only: If False, also considers global data descriptors (e.g., arguments).
    :param integers_only: If False, also considers non-integral descriptors for promotion.
    :return: Set of promoted scalars.
    :note: Operates in-place.
    """
    # Process:
    # 1. Find scalars to promote
    # 2. For every assignment tasklet/access:
    #    2.1. Fission state to isolate assignment
    #    2.2. Replace assignment with inter-state edge assignment
    # 3. For every read of the scalar:
    #    3.1. If destination is tasklet, remove node, edges, and connectors
    #    3.2. If used in tasklet as subscript or connector, modify tasklet code
    #    3.3. If destination is array, change to tasklet that copies symbol data
    # 4. Remove newly-isolated access nodes
    # 5. Remove data descriptors and add symbols to SDFG
    # 6. Replace subscripts in all interstate conditions and assignments
    # 7. Make indirections with symbols a single memlet
    to_promote = find_promotable_scalars(sdfg,
                                         transients_only=transients_only,
                                         integers_only=integers_only)
    if ignore:
        to_promote -= ignore
    if len(to_promote) == 0:
        return to_promote

    for state in sdfg.nodes():
        scalar_nodes = [
            n for n in state.nodes()
            if isinstance(n, nodes.AccessNode) and n.data in to_promote
        ]
        # Step 2: Assignment tasklets
        for node in scalar_nodes:
            # Access nodes that are only read have no assignment to convert
            if state.in_degree(node) == 0:
                continue
            in_edge = state.in_edges(node)[0]
            input = in_edge.src

            # There is only zero or one incoming edges by definition
            tasklet_inputs = [e.src for e in state.in_edges(input)]
            # Step 2.1
            new_state = xfh.state_fission(
                sdfg,
                gr.SubgraphView(state, set([input, node] + tasklet_inputs)))
            new_isedge: sd.InterstateEdge = sdfg.out_edges(new_state)[0]
            # Step 2.2
            # Re-fetch the nodes from the fissioned state before inspecting them
            node: nodes.AccessNode = new_state.sink_nodes()[0]
            input = new_state.in_edges(node)[0].src
            if isinstance(input, nodes.Tasklet):
                # Convert tasklet to interstate edge
                newcode: str = ''
                if input.language is dtypes.Language.Python:
                    newcode = astutils.unparse(input.code.code[0].value)
                elif input.language is dtypes.Language.CPP:
                    newcode = translate_cpp_tasklet_to_python(
                        input.code.as_string.strip())

                # Replace tasklet inputs with incoming edges
                for e in new_state.in_edges(input):
                    memlet_str: str = e.data.data
                    if (e.data.subset is not None and not isinstance(
                            sdfg.arrays[memlet_str], dt.Scalar)):
                        memlet_str += '[%s]' % e.data.subset
                    newcode = re.sub(r'\b%s\b' % re.escape(e.dst_conn),
                                     memlet_str, newcode)
                # Add interstate edge assignment
                new_isedge.data.assignments[node.data] = newcode
            elif isinstance(input, nodes.AccessNode):
                memlet: mm.Memlet = in_edge.data
                if (memlet.src_subset and
                        not isinstance(sdfg.arrays[memlet.data], dt.Scalar)):
                    new_isedge.data.assignments[
                        node.data] = '%s[%s]' % (input.data, memlet.src_subset)
                else:
                    new_isedge.data.assignments[node.data] = input.data

            # Clean up all nodes after assignment was transferred
            new_state.remove_nodes_from(new_state.nodes())

    # Step 3: Scalar reads
    remove_scalar_reads(sdfg, {k: k for k in to_promote})

    # Step 4: Isolated nodes
    for state in sdfg.nodes():
        scalar_nodes = [
            n for n in state.nodes()
            if isinstance(n, nodes.AccessNode) and n.data in to_promote
        ]
        state.remove_nodes_from(
            [n for n in scalar_nodes if len(state.all_edges(n)) == 0])

    # Step 5: Data descriptor management
    for scalar in to_promote:
        desc = sdfg.arrays[scalar]
        sdfg.remove_data(scalar, validate=False)
        # If the scalar is already a symbol (e.g., as part of an array size),
        # do not re-add the symbol
        if scalar not in sdfg.symbols:
            sdfg.add_symbol(scalar, desc.dtype)

    # Step 6: Inter-state edge cleanup
    # One precompiled pattern per scalar, matching "<scalar>[...]" subscripts
    cleanup_re = {
        s: re.compile(fr'\b{re.escape(s)}\[.*?\]')
        for s in to_promote
    }
    promo = TaskletPromoterDict({k: k for k in to_promote})
    for edge in sdfg.edges():
        ise: InterstateEdge = edge.data
        # Condition
        if not edge.data.is_unconditional():
            if ise.condition.language is dtypes.Language.Python:
                for stmt in ise.condition.code:
                    promo.visit(stmt)
            elif ise.condition.language is dtypes.Language.CPP:
                for scalar in to_promote:
                    ise.condition = cleanup_re[scalar].sub(
                        scalar, ise.condition.as_string)

        # Assignments
        for aname, assignment in ise.assignments.items():
            for scalar in to_promote:
                if scalar in assignment:
                    # Carry the cleaned string forward so that a later scalar
                    # does not overwrite the substitution made for an earlier
                    # one (replacing the value of an existing key is safe
                    # while iterating over items())
                    assignment = cleanup_re[scalar].sub(
                        scalar, assignment.strip())
                    ise.assignments[aname] = assignment

    # Step 7: Indirection
    remove_symbol_indirection(sdfg)

    return to_promote
Exemplo n.º 3
class ONNXModel:
    """Loads an ONNX model into an SDFG."""
    # NOTE(review): this copy of the class has lost lines (docstring
    # delimiters, `else:` branches, closing brackets, several statement
    # bodies). The code below is reproduced unchanged; every gap that can be
    # seen is flagged with a NOTE(review) comment so it can be restored from
    # the original source.
    def __init__(self, name, model: onnx.ModelProto, cuda=False):
        # NOTE(review): the next four lines read as the constructor docstring,
        # but the opening/closing triple quotes are missing.
        Constructs a new ONNXImporter.
        :param name: the name for the SDFG.
        :param model: the model to import.
        :param cuda: if `True`, weights will be passed as cuda arrays.

        graph: onnx.GraphProto = model.graph

        # A fresh SDFG with a single state receives the whole ONNX graph
        self.sdfg = SDFG(name)
        self.cuda = cuda
        self.state = self.sdfg.add_state()

        # Add all values to the SDFG, check for unsupported ops

        # Maps ONNX value name -> its value-info entry
        self.value_infos = {}

        self.inputs = []
        self.outputs = []

        # Walk graph inputs (flagged True) followed by graph outputs (flagged False)
        for value, is_input in chain(zip(graph.input, repeat(True)),
                                     zip(graph.output, repeat(False))):
            if not value.HasField("name"):
                raise ValueError("Got input or output without name")
            # NOTE(review): the body of this `if` (and any `else`) is missing;
            # presumably the name is recorded in self.inputs / self.outputs — confirm.
            if is_input:

            self.value_infos[value.name] = value

        # Intermediate values: keep the entry already recorded for inputs/outputs
        for value in graph.value_info:
            if not value.HasField("name"):
                raise ValueError("Got input or output without name")
            if value.name not in self.value_infos:
                self.value_infos[value.name] = value

        # add weights
        self.weights = {}
        # NOTE(review): loop body missing; presumably each initializer is
        # passed to self._add_constant_tensor — confirm.
        for init in graph.initializer:

        # Maps ONNX value name -> access node, so each value gets a single node
        access_nodes = {}
        self._idx_to_node = []
        for i, node in enumerate(graph.node):
            if not has_onnx_node(node.op_type):
                # NOTE(review): the .format(...) arguments and closing
                # parentheses are missing.
                raise ValueError("Unsupported ONNX operator: '{}'".format(

            # extract the op attributes

            # NOTE(review): the closing brace of this dict comprehension is missing.
            op_attributes = {
                attribute_proto.name: convert_attribute_proto(attribute_proto)
                for attribute_proto in node.attribute

            # Use the cleaned ONNX node name when there is one.
            # NOTE(review): an `else:` is evidently missing before the second
            # assignment (fallback name "<op_type>_<index>").
            if node.HasField("name"):
                node_name = clean_onnx_name(node.name)
                node_name = node.op_type + "_" + str(i)

            # construct the dace node
            op_node = get_onnx_node(node.op_type)(node_name, **op_attributes)
            # NOTE(review): nothing visible adds op_node to self.state or to
            # self._idx_to_node — confirm whether those lines were lost.

            # Enumerate the node's inputs, then its outputs; param_idx restarts at 0 for outputs
            for param_idx, (name, is_input) in chain(
                    enumerate(zip(node.input, repeat(True))),
                    enumerate(zip(node.output, repeat(False)))):
                if clean_onnx_name(name) not in self.sdfg.arrays:
                    if name not in self.value_infos:
                        raise ValueError(
                            "Could not find array with name '{}'".format(name))
                    # NOTE(review): the array is read from self.sdfg.arrays
                    # further down, so a line that registers it (presumably
                    # via self._add_value_info) appears to be missing here.

                # get the access node
                # NOTE(review): an `else:` is evidently missing before the
                # AccessNode construction, and no visible line adds the new
                # node to the state.
                if name in access_nodes:
                    access = access_nodes[name]
                    self._update_access_type(access, is_input)
                    access = nd.AccessNode(
                        clean_onnx_name(name), AccessType.ReadOnly
                        if is_input else AccessType.WriteOnly)
                    access_nodes[name] = access

                # get the connector name
                params = op_node.schema.inputs if is_input else op_node.schema.outputs
                params_len = len(params)
                if param_idx >= params_len:
                    # this is a variadic parameter. Then the last parameter of the parameter must be variadic.
                    if params[-1].param_type != ONNXParameterType.Variadic:
                        # NOTE(review): the param_idx / params_len keyword
                        # arguments and closing parentheses of .format() are missing.
                        raise ValueError(
                            "Expected the last {i_or_o} parameter to be variadic,"
                            " since the {i_or_o} with idx {param_idx} has more parameters than the schema ({params_len})"
                            .format(i_or_o="input" if is_input else "output",
                    # Extra variadic occurrences are numbered "<name>__1", "<name>__2", ...
                    conn_name = params[-1].name + "__" + str(param_idx -
                                                             params_len + 1)
                elif params[
                        param_idx].param_type == ONNXParameterType.Variadic:
                    # this is a variadic parameter, and it is within the range of params, so it must be the first
                    # instance of a variadic parameter
                    # NOTE(review): an `else:` is evidently missing before the
                    # second assignment (non-variadic parameters use the plain name).
                    conn_name = params[param_idx].name + "__0"
                    conn_name = params[param_idx].name

                data_desc = self.sdfg.arrays[clean_onnx_name(name)]

                # add the connector if required, and add an edge
                # NOTE(review): the connector-adding statements, the `else:`
                # branch and the calls owning the argument lists below are
                # missing; only fragments of the two edge-creation calls remain.
                if is_input:
                    if conn_name not in op_node.in_connectors:
                        access, None, op_node, conn_name,
                    if conn_name not in op_node.out_connectors:

                        op_node, conn_name, access, None,

        # NOTE(review): whatever statements preceded the comment inside this
        # `if` are missing.
        if self.cuda:

            # set all gpu transients to be persistent
            for _, _, arr in self.sdfg.arrays_recursive():
                if arr.transient and arr.storage == StorageType.GPU_Global:
                    arr.lifetime = AllocationLifetime.Persistent

    # Widens a node's access type to ReadWrite once it is used both as input and output.
    # NOTE(review): declared without `self` but invoked as
    # self._update_access_type(access, is_input) in __init__; that only works
    # with a @staticmethod decorator, which is not visible here.
    def _update_access_type(node: dace.nodes.AccessNode, is_input: bool):
        if node.access == AccessType.ReadOnly and not is_input:
            node.access = AccessType.ReadWrite
        elif node.access == AccessType.WriteOnly and is_input:
            node.access = AccessType.ReadWrite

    # Registers an initializer tensor as a scalar/array in the SDFG and stores
    # its value, converted with numpy_helper.to_array, in self.weights under
    # the original (uncleaned) tensor name.
    def _add_constant_tensor(self, tensor: onnx.TensorProto):
        if not tensor.HasField("name"):
            raise ValueError("Got tensor without name")

        if not tensor.HasField("data_type"):
            # NOTE(review): the .format(...) argument and closing parentheses are missing.
            raise ValueError("Initializer tensor '{}' has no type".format(

        name = clean_onnx_name(tensor.name)

        dtype = onnx_tensor_type_to_typeclass(tensor.data_type)

        # NOTE(review): two `else:` lines are evidently missing below — one
        # before `dims = ...` (non-scalar case) and one before
        # `existing_arr = ...` (name already registered: check dtype and shape).
        if len(tensor.dims) == 0:
            # this is a scalar
            self.sdfg.add_scalar(name, dtype)
            dims = [d for d in tensor.dims]
            if name not in self.sdfg.arrays:
                self.sdfg.add_array(name, dims, dtype)
                existing_arr = self.sdfg.arrays[name]
                if existing_arr.dtype != dtype:
                    raise ValueError(
                        "Invalid ONNX model; found two values with name '{}', but different dtypes ({} and {})"
                        .format(name, existing_arr.dtype, dtype))
                if tuple(existing_arr.shape) != tuple(dims):
                    raise ValueError(
                        "Invalid ONNX model; found two values with name '{}', but different dimensions ({} and {})"
                        .format(name, existing_arr.shape, dims))

        self.weights[tensor.name] = numpy_helper.to_array(tensor)

    # Validates that a value-info entry has a name, a tensor shape and an
    # element type, then builds its shape from the ONNX dimensions.
    # NOTE(review): the end of this method is missing, so how the shape is
    # ultimately used cannot be seen here.
    def _add_value_info(self, value_info: onnx.ValueInfoProto):
        if not value_info.HasField("name"):
            raise ValueError("Got value without name")

        name = value_info.name

        if not _nested_HasField(value_info, "type.tensor_type.shape"):
            raise ValueError(
                "Value '{}' does not have a shape in this graph."
                " Please run shape inference before importing.".format(name))

        tensor_type = value_info.type.tensor_type

        if not tensor_type.HasField("elem_type"):
            raise ValueError(
                "Value '{}' does not have a type in this graph."
                " Please run type inference before importing.".format(name))

        shape = []
        for d in tensor_type.shape.dim:
            # NOTE(review): the body of the `dim_value` branch is missing
            # (presumably it appends the fixed size to `shape`).
            if d.HasField("dim_value"):
            elif d.HasField("dim_param"):
                # Symbolic dimension: parse it and rename its symbols to cleaned names
                parsed = pystr_to_symbolic(d.dim_param)

                for sym in parsed.free_symbols:
                    # NOTE(review): the body of this `if` is missing
                    # (presumably it registers the symbol on the SDFG).
                    if clean_onnx_name(str(sym)) not in self.sdfg.symbols:
                    parsed = parsed.subs(
                        sym, dace.symbol(clean_onnx_name(str(sym))))

                # NOTE(review): the statement using `parsed` and an `else:`
                # before this raise are missing, as are the .format(...)
                # argument and closing parentheses.
                raise ValueError(
                    "Value '{}' does not have a shape in this graph."
                    " Please run shape inference before importing.".format(
        # Values that are neither graph inputs nor graph outputs are transient
        transient = name not in self.inputs and name not in self.outputs
        # NOTE(review): the body of this `if` and everything after it are missing.
        if len(shape) == 0:

    # Runs a deep copy of the SDFG on the given inputs. Positional arguments
    # are matched to self.inputs in order. Returns the single output array, or
    # a tuple of output arrays when there is more than one.
    def __call__(self, *args, **inputs):
        sdfg = deepcopy(self.sdfg)

        # convert the positional args to kwargs
        if len(args) > len(self.inputs):
            raise ValueError("Expected {} arguments, got {}".format(
                len(self.inputs), len(args)))

        inputs.update(dict(zip(self.inputs, args)))

        # check that there are no missing inputs
        if len(set(self.inputs).difference(inputs)) != 0:
            # NOTE(review): the join(...) argument and closing parentheses are missing.
            raise ValueError("Missing inputs {}".format(", ".join(

        # check that there are no unknown inputs
        # NOTE symbols can only be passed as kwargs
        # NOTE(review): the start of the expression inside len(...) is
        # missing, as are the arguments of the error message below.
        if len(
                    sdfg.free_symbols)) != 0:
            raise ValueError("Unknown inputs {}".format(", ".join(

        # Symbols keep their names; other inputs are keyed by cleaned ONNX name.
        # NOTE(review): an `else:` is evidently missing before the second assignment.
        clean_inputs = {}
        for input, arr in inputs.items():
            if input in sdfg.free_symbols:
                clean_inputs[input] = arr
                clean_inputs[clean_onnx_name(input)] = arr

        # add the weights
        # NOTE(review): `else:` lines are evidently missing below — zero-dim
        # weights are passed as scalars via arr[()], CUDA weights are moved to
        # the device, and the remaining case passes a copy.
        params = {}
        for name, arr in self.weights.items():
            if len(arr.shape) == 0:
                params[clean_onnx_name(name)] = arr[()]
                if self.cuda:
                    clean_name = clean_onnx_name(name)
                    sdfg.arrays[clean_name].storage = StorageType.GPU_Global
                    params[clean_name] = numba.cuda.to_device(arr)
                    params[clean_onnx_name(name)] = arr.copy()

        # NOTE(review): the dict argument and closing parentheses of this call are missing.
        inferred_symbols = infer_symbols_from_shapes(sdfg, {
        # TODO @orausch if this is removed the SDFG complains
        # TypeError: Type mismatch for argument ONNX_unk__493: expected scalar type, got <class 'sympy.core.numbers.Integer'>
        # fix this better
        inferred_symbols = {k: int(v) for k, v in inferred_symbols.items()}

        # Substitutes every free symbol of a dimension with its inferred integer value
        def eval_dim(dim):
            for sym in dim.free_symbols:
                dim = dim.subs(sym, inferred_symbols[sym.name])
            return dim

        outputs = OrderedDict()
        # create numpy arrays for the outputs
        for output in self.outputs:
            clean_name = clean_onnx_name(output)
            arr = sdfg.arrays[clean_name]

            # TODO @orausch add error handling for evalf
            # NOTE(review): the closing bracket of this list and the remaining
            # np.empty(...) arguments are missing.
            shape = [
                eval_dim(d) if type(d) is dace.symbol else d for d in arr.shape
            outputs[clean_name] = np.empty(shape,


        # Inputs, weights, output buffers and inferred symbols are all passed as keyword arguments
        sdfg(**clean_inputs, **params, **outputs, **inferred_symbols)

        if len(outputs) == 1:
            return next(iter(outputs.values()))

        return tuple(outputs.values())
Exemplo n.º 4
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.
        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # Nodes whose enclosing scope lies outside the subgraph determine the
        # (single) scope that the subgraph must live in
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                # Remember the write-conflict resolution of the incoming edge
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = {}
    output_names = {}
    # Outer data name -> (name inside the nested SDFG, subset covered so far)
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            # Data already registered: grow its subset to cover this edge too
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            # Data already registered: grow its subset to cover this edge too
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s) for s in (state.symbols_defined_at(top_scopenode).keys()
                                                | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            # Remember one empty-memlet edge in case nothing else connects
            empty_input = edge
            continue

        name = input_names[edge]
        # Each nested connector is wired exactly once
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            # Remember one empty-memlet edge in case nothing else connects
            empty_output = edge
            continue

        name = output_names[edge]
        # Each nested connector is wired exactly once
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name, Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None, Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
Exemplo n.º 5
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        The subgraph nodes are moved into a single state of a new SDFG, a
        nested SDFG node is added to ``state`` in their place, and the edges
        that crossed the subgraph boundary are reconnected to that node.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG. Defaults to
                     ``'nested_' + state.label``.
        :param full_data: If True, nests entire input/output data. Otherwise
                          the nested descriptors are shaped after the subset
                          of the crossing memlet and inner memlets are offset
                          accordingly.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope,
                           or it contains only part of a scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # A node whose enclosing scope lies outside the subgraph determines
        # the scope the nested SDFG will live in; all such nodes must agree
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in subgraph.source_nodes():
        inputs.extend(state.in_edges(node))
    for node in subgraph.sink_nodes():
        outputs.extend(state.out_edges(node))

    # Edges with an empty memlet carry no data and get no connector. They are
    # filtered out once so that every generated connector name stays paired
    # with the edge it was created for (zipping the names against the
    # unfiltered lists would misalign them after the first empty memlet).
    inputs = [e for e in inputs if e.data.data is not None]
    outputs = [e for e in outputs if e.data.data is not None]

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = []
    output_names = []
    for edge in inputs:
        name = '__in_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        input_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))
    for edge in outputs:
        name = '__out_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        output_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))

    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for name, edge in zip(input_names, inputs):
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_edge)))
    for name, edge in zip(output_names, outputs):
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for name, edge in zip(input_names, inputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
    for name, edge in zip(output_names, outputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            # Empty memlet keeps the new access node inside the scope
            state.add_nedge(entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, name,
                       Memlet.from_array(name, sdfg.arrays[name]))
    for name in output_arrays:
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, EmptyMemlet())
        state.add_edge(nested_sdfg, name, node, None,
                       Memlet.from_array(name, sdfg.arrays[name]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg
Exemplo n.º 6
    def apply(self, sdfg: sd.SDFG):
        """ Replaces the matched for-loop with a map over the loop body.

            If the loop body spans several states, they are first moved into
            a nested SDFG inside one new state. The body state's contents are
            then wrapped in a map over the iteration variable, and the loop's
            guard/back edges are removed so that the body leads directly to
            the state after the loop.

            :param sdfg: The SDFG containing the matched loop. Modified
                         in-place.
        """
        # Obtain loop information
        guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
        body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
        after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

        # Obtain iteration variable, range, and stride
        itervar, (start, end, step), (_, body_end) = find_for_loop(
            sdfg, guard, body, itervar=self.itervar)

        # Find all loop-body states
        states = set([body_end])
        to_visit = [body]
        while to_visit:
            state = to_visit.pop(0)
            if state is body_end:
                # Do not walk past the last body state (back into the guard)
                continue
            for _, dst, _ in sdfg.out_edges(state):
                if dst not in states:
                    to_visit.append(dst)
            states.add(state)

        # Nest loop-body states
        if len(states) > 1:

            # Find read/write sets
            read_set, write_set = set(), set()
            for state in states:
                rset, wset = state.read_and_write_sets()
                read_set |= rset
                write_set |= wset
                # Add data from edges
                for src in states:
                    for dst in states:
                        for edge in sdfg.edges_between(src, dst):
                            for s in edge.data.free_symbols:
                                if s in sdfg.arrays:
                                    read_set.add(s)

            # Find NestedSDFG's unique data
            rw_set = read_set | write_set
            unique_set = set()
            for name in rw_set:
                if not sdfg.arrays[name].transient:
                    continue
                # A transient is unique to the body if no access node outside
                # the loop-body states refers to it
                found = False
                for state in sdfg.states():
                    if state in states:
                        continue
                    for node in state.nodes():
                        if (isinstance(node, nodes.AccessNode) and
                                node.data == name):
                            found = True
                            break
                if not found:
                    unique_set.add(name)

            # Find NestedSDFG's connectors
            read_set = {n for n in read_set if n not in unique_set or not sdfg.arrays[n].transient}
            write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient}

            # Create NestedSDFG and add all loop-body states and edges
            # Also, find defined symbols in NestedSDFG
            fsymbols = set(sdfg.free_symbols)
            new_body = sdfg.add_state('single_state_body')
            nsdfg = SDFG("loop_body", constants=sdfg.constants, parent=new_body)
            nsdfg.add_node(body, is_start_state=True)
            body.parent = nsdfg
            exit_state = nsdfg.add_state('exit')
            nsymbols = dict()
            for state in states:
                if state is body:
                    continue
                nsdfg.add_node(state)
                state.parent = nsdfg
            for state in states:
                if state is body:
                    continue
                for src, dst, data in sdfg.in_edges(state):
                    nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols})
                    nsdfg.add_edge(src, dst, data)
            nsdfg.add_edge(body_end, exit_state, InterstateEdge())

            # Move guard -> body edge to guard -> new_body
            for src, dst, data, in sdfg.edges_between(guard, body):
                sdfg.add_edge(src, new_body, data)
            # Move body_end -> guard edge to new_body -> guard
            for src, dst, data in sdfg.edges_between(body_end, guard):
                sdfg.add_edge(new_body, dst, data)
            # Delete loop-body states and edges from parent SDFG
            for state in states:
                for e in sdfg.all_edges(state):
                    sdfg.remove_edge(e)
                sdfg.remove_node(state)
            # Add NestedSDFG arrays
            for name in read_set | write_set:
                nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name])
                nsdfg.arrays[name].transient = False
            for name in unique_set:
                # Data only used by the loop body moves into the nested SDFG
                nsdfg.arrays[name] = sdfg.arrays[name]
                del sdfg.arrays[name]
            # Add NestedSDFG node
            cnode = new_body.add_nested_sdfg(nsdfg, None, read_set, write_set)
            if sdfg.parent:
                for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items():
                    if s not in cnode.symbol_mapping:
                        cnode.symbol_mapping[s] = m
                        nsdfg.add_symbol(s, sdfg.symbols[s])
            for name in read_set:
                r = new_body.add_read(name)
                new_body.add_edge(
                    r, None, cnode, name,
                    memlet.Memlet.from_array(name, sdfg.arrays[name]))
            for name in write_set:
                w = new_body.add_write(name)
                new_body.add_edge(
                    cnode, name, w, None,
                    memlet.Memlet.from_array(name, sdfg.arrays[name]))

            # Fix SDFG symbols
            for sym in sdfg.free_symbols - fsymbols:
                del sdfg.symbols[sym]
            for sym, dtype in nsymbols.items():
                nsdfg.symbols[sym] = dtype

            # Change body state reference
            body = new_body

        # NOTE(review): the explicit "== True" looks deliberate -- presumably
        # `step < 0` can be a symbolic relation rather than a bool; confirm
        # before simplifying.
        if (step < 0) == True:
            # If step is negative, we have to flip start and end to produce a
            # correct map with a positive increment
            start, end, step = end, start, -step

        # If necessary, make a nested SDFG with assignments
        isedge = sdfg.edges_between(guard, body)[0]
        symbols_to_remove = set()
        if len(isedge.data.assignments) > 0:
            nsdfg = helpers.nest_state_subgraph(
                sdfg, body, gr.SubgraphView(body, body.nodes()))
            for sym in isedge.data.free_symbols:
                if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
                    continue
                if sym in sdfg.symbols:
                    nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym)
                    nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym])
                elif sym in sdfg.arrays:
                    if sym in nsdfg.sdfg.arrays:
                        raise NotImplementedError
                    rnode = body.add_read(sym)
                    desc = copy.deepcopy(sdfg.arrays[sym])
                    desc.transient = False
                    nsdfg.sdfg.add_datadesc(sym, desc)
                    body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym))

            # Move the assignments from the guard -> body edge onto a new
            # edge in front of the nested state
            nstate = nsdfg.sdfg.node(0)
            init_state = nsdfg.sdfg.add_state_before(nstate)
            nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
            nisedge.data.assignments = isedge.data.assignments
            symbols_to_remove = set(nisedge.data.assignments.keys())
            for k in nisedge.data.assignments.keys():
                if k in nsdfg.symbol_mapping:
                    del nsdfg.symbol_mapping[k]
            isedge.data.assignments = {}

        # Captured before the map nodes are added, so these are the sources
        # and sinks of the original body
        source_nodes = body.source_nodes()
        sink_nodes = body.sink_nodes()

        map = nodes.Map(body.label + "_map", [itervar], [(start, end, step)])
        entry = nodes.MapEntry(map)
        exit = nodes.MapExit(map)
        body.add_node(entry)
        body.add_node(exit)

        # If the map uses symbols from data containers, instantiate reads
        containers_to_read = entry.free_symbols & sdfg.arrays.keys()
        for rd in containers_to_read:
            # We are guaranteed that this is always a scalar, because
            # can_be_applied makes sure there are no sympy functions in each of
            # the loop expresions
            access_node = body.add_read(rd)
            body.add_memlet_path(access_node,
                                 entry,
                                 dst_conn=rd,
                                 memlet=memlet.Memlet(rd))

        # Reroute all memlets through the entry and exit nodes
        for n in source_nodes:
            if isinstance(n, nodes.AccessNode):
                for e in body.out_edges(n):
                    body.remove_edge(e)
                    body.add_edge_pair(entry,
                                       e.dst,
                                       n,
                                       e.data,
                                       internal_connector=e.dst_conn)
            else:
                body.add_nedge(entry, n, memlet.Memlet())
        for n in sink_nodes:
            if isinstance(n, nodes.AccessNode):
                for e in body.in_edges(n):
                    body.remove_edge(e)
                    body.add_edge_pair(exit,
                                       e.src,
                                       n,
                                       e.data,
                                       internal_connector=e.src_conn)
            else:
                body.add_nedge(n, exit, memlet.Memlet())

        # Get rid of the loop exit condition edge
        after_edge = sdfg.edges_between(guard, after)[0]
        sdfg.remove_edge(after_edge)

        # Remove the assignment on the edge to the guard
        for e in sdfg.in_edges(guard):
            if itervar in e.data.assignments:
                del e.data.assignments[itervar]

        # Remove the condition on the entry edge
        condition_edge = sdfg.edges_between(guard, body)[0]
        condition_edge.data.condition = CodeBlock("1")

        # Get rid of backedge to guard
        sdfg.remove_edge(sdfg.edges_between(body, guard)[0])

        # Route body directly to after state, maintaining any other assignments
        # it might have had
        sdfg.add_edge(
            body, after,
            InterstateEdge(assignments=after_edge.data.assignments))

        # If this had made the iteration variable a free symbol, we can remove
        # it from the SDFG symbols
        if itervar in sdfg.free_symbols:
            sdfg.remove_symbol(itervar)
        for sym in symbols_to_remove:
            if helpers.is_symbol_unused(sdfg, sym):
                sdfg.remove_symbol(sym)
Exemplo n.º 7
def generate_reference(name, chain):
    """Generates a simple, unoptimized SDFG to run on the CPU, for verification
    purposes.

    One state is created per kernel, chained in topological order of the
    chain's graph, and each stencil node is connected to whole-array reads
    and writes.

    :param name: Name of the generated SDFG.
    :param chain: Stencil chain; this function reads its ``constants`` and
                  ``graph`` attributes.
    :return: The generated SDFG.
    :raise ValueError: An input node has no outputs.
    """

    sdfg = SDFG(name)

    for k, v in chain.constants.items():
        sdfg.add_constant(k, v["value"], dace.data.Scalar(v["data_type"]))

    (dimensions_to_skip, shape, vector_length, parameters, iterators,
     memcopy_indices, memcopy_accesses) = _generate_init(chain)

    prev_state = sdfg.add_state("init")

    # Throw vectorization in the bin for the reference code
    vector_length = 1

    shape = tuple(map(int, shape))

    input_shapes = {}  # Maps inputs to their shape tuple

    for node in chain.graph.nodes():
        if isinstance(node, Input) or isinstance(node, Output):
            if isinstance(node, Input):
                # Only the first output is inspected to find the dimensions
                # the input spans; the else branch fires when there is none
                for output in node.outputs.values():
                    pars = tuple(
                        output["input_dims"]
                    ) if "input_dims" in output and output[
                        "input_dims"] is not None else tuple(parameters)
                    arr_shape = tuple(s for s, p in zip(shape, parameters)
                                      if p in pars)
                    input_shapes[node.name] = arr_shape
                    break
                else:
                    raise ValueError("No outputs found for input node.")
            else:
                arr_shape = shape
            if len(arr_shape) > 0:
                try:
                    sdfg.add_array(node.name, arr_shape, node.data_type)
                except NameError:
                    # NOTE(review): presumably raised when the name is
                    # already registered (e.g. a field that is both input
                    # and output) -- confirm against SDFG.add_array
                    sdfg.data(
                        node.name).access = dace.dtypes.AccessType.ReadWrite
            else:
                # Zero-dimensional inputs are passed as symbols
                sdfg.add_symbol(node.name, node.data_type)

    # Every remaining edge source becomes a full-shape transient
    for link in chain.graph.edges(data=True):
        name = link[0].name
        if name not in sdfg.arrays and name not in sdfg.symbols:
            sdfg.add_array(name, shape, link[0].data_type, transient=True)
            input_shapes[name] = tuple(shape)

    input_iterators = {
        k: tuple("0:{}".format(s) for s in v)
        for k, v in input_shapes.items()
    }

    # Enforce dependencies via topological sort
    for node in nx.topological_sort(chain.graph):

        if not isinstance(node, Kernel):
            continue

        state = sdfg.add_state(node.name)
        sdfg.add_edge(prev_state, state, dace.InterstateEdge())

        (stencil_node, input_to_connector,
         output_to_connector) = _generate_stencil(node, chain, shape,
                                                  dimensions_to_skip)
        stencil_node.implementation = "CPU"

        for field, connector in input_to_connector.items():

            if len(input_iterators[field]) == 0:
                continue  # Scalar variable

            # Outer memory read
            read_node = state.add_read(field)
            state.add_memlet_path(read_node,
                                  stencil_node,
                                  dst_conn=connector,
                                  memlet=dace.Memlet.simple(
                                      field,
                                      ", ".join(input_iterators[field])))

        for _, connector in output_to_connector.items():

            # Outer write
            write_node = state.add_write(node.name)
            state.add_memlet_path(stencil_node,
                                  write_node,
                                  src_conn=connector,
                                  memlet=dace.Memlet.simple(
                                      node.name, ", ".join("0:{}".format(s)
                                                           for s in shape)))

        prev_state = state

    return sdfg