Example #1
 def __getitem__(self, s):
     """ This is syntactic sugar that allows us to define an array type
         with the following syntax: dace.uint32[N,M] 
         @return: A data.Array data descriptor.
     """
     from dace import data
     if isinstance(s, (list, tuple)):
         return data.Array(self, tuple(s))
     return data.Array(self, (s, ))
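
A minimal usage sketch of the subscript syntax that this method enables; it assumes a standard DaCe installation where dace.uint32 is a typeclass and dace.symbol creates symbolic sizes:

import dace

N = dace.symbol('N')
M = dace.symbol('M')

mat = dace.uint32[N, M]  # tuple subscript -> data.Array(self, (N, M))
vec = dace.uint32[N]     # single subscript -> data.Array(self, (N,))
print(type(mat).__name__, mat.shape)  # an Array descriptor with shape (N, M)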
Example #2
import itertools
import sys

import dace
from dace import data


# test_configuration and repval are assumed to be defined elsewhere in this module
def runtest(
    title,
    einsum_str,
    output,
    dtype=dace.float16,
    sizes={
        'b': 8,
        'h': 16,
        'i': 1024,
        'j': 512,
        'k': 512,
        'p': 64,
        'u': 4096,
        'q': 3,
        'v': 2
    }):
    print('Generating ' + title)
    # Redirect stdout to a per-test CSV file; restored and closed at the end
    real_stdout = sys.stdout
    sys.stdout = open(f'{output}/{title}-configs.csv', 'w')

    # Split the einsum; use a new name to avoid shadowing the output directory
    inputs, out_str = einsum_str.split('->')
    a, b = inputs.split(',')

    allperms = list(
        itertools.product(itertools.permutations(a), itertools.permutations(b),
                          itertools.permutations(out_str)))
    for in1, in2, out in allperms:
        in1 = ''.join(in1)
        in2 = ''.join(in2)
        out = ''.join(out)
        einsum_perm = (in1 + ',' + in2 + '->' + out)
        adesc = data.Array(dtype, [sizes[k] for k in in1])
        bdesc = data.Array(dtype, [sizes[k] for k in in2])
        try:
            report, implementation = test_configuration(
                einsum_perm, adesc, bdesc)
        except Exception:
            print('ERROR: Failed "', einsum_perm, '". Skipping')
            continue
        if report is None:
            continue
        with open(title + '.csv', 'a') as fp:
            fp.write(','.join([in1, in2, out, implementation] +
                              repval(report)) + '\n')

    sys.stdout.close()
    sys.stdout = real_stdout
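
A hypothetical invocation of runtest; test_configuration and repval are assumed to be defined in the surrounding module, and the output directory must already exist:

runtest('batched-matmul', 'bik,bkj->bij', '/tmp/einsum-results')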
Example #3
    def __array_finalize__(self, obj):
        if obj is None:
            return
        from dace import data

        # Create a new descriptor
        self.descriptor = data.Array(
            types.typeclass(obj.dtype.type),
            obj.shape,
            materialize_func=None,
            transient=False,
            allow_conflicts=False)

        self._symlist = {}
Example #4
def _create_datadescriptor(obj):
    """ Creates a data descriptor from various types of objects.
        @see: dace.data.Data
    """
    if isinstance(obj, data.Data):
        return obj

    try:
        return obj.descriptor
    except AttributeError:
        if isinstance(obj, numpy.ndarray):
            return data.Array(dtype=types.typeclass(obj.dtype.type),
                              shape=obj.shape)
        if symbolic.issymbolic(obj):
            return data.Scalar(symbolic.symtype(obj))
        if isinstance(obj, types.typeclass):
            return data.Scalar(obj)
        return data.Scalar(types.typeclass(type(obj)))
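
A brief sketch of the dispatch behavior above, assuming the same dace modules (data, types, symbolic) that the function itself uses:

import numpy

arr_desc = _create_datadescriptor(numpy.zeros((10, 10), dtype=numpy.float32))
scal_desc = _create_datadescriptor(3.14)
# arr_desc is a data.Array of shape (10, 10) with a float32 typeclass;
# scal_desc is a data.Scalar wrapping the Python float type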
Example #5
    def _generate_copy_to_device(self, node: nodes.AccessNode, desc: dt.Array,
                                 ptr: str) -> Tuple[str, str, str]:
        """ Copies restored data to device and returns (preamble, postamble, name of new host pointer). """
        new_ptr = f'__dinstr_{node.data}'
        new_desc = dt.Array(desc.dtype, [desc.total_size - desc.start_offset])
        csize = cpp.sym2cpp(desc.total_size - desc.start_offset)

        # Emit synchronous memcpy
        preamble = f'''
        {{
        {new_desc.as_arg(name=new_ptr)} = new {desc.dtype.ctype}[{csize}];
        '''

        postamble = f'''
        {self.backend}Memcpy({ptr}, {new_ptr}, sizeof({desc.dtype.ctype}) * ({csize}), {self.backend}MemcpyHostToDevice);
        delete[] {new_ptr};
        }}
        '''

        return preamble, postamble, new_ptr
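
A sketch of how a caller might splice the returned pieces together; self, node, desc, and ptr come from the surrounding code generator, and restore_host_buffer is a hypothetical placeholder for the code that fills the temporary host buffer:

preamble, postamble, host_ptr = self._generate_copy_to_device(node, desc, ptr)
code = preamble                        # allocates the temporary host buffer
code += restore_host_buffer(host_ptr)  # hypothetical: fill host_ptr with restored data
code += postamble                      # synchronous memcpy to device, then delete[]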
Example #6
    def __new__(cls,
                shape,
                dtype=types.float32,
                materialize_func=None,
                allow_conflicts=False,
                *args,
                **kwargs):
        """ Initializes a DaCe ND-array.
            @param shape: The array shape (may contain symbols).
            @param dtype: The array data type.
            @param materialize_func: An optional string that contains a method
                                     to materialize array contents on demand.
                                     If not None, the array is not allocated 
                                     within the DaCe program.
            @param allow_conflicts: If True, suppresses warnings on conflicting
                                    array writes in DaCe programs without a 
                                    matching conflict resolution memlet.
        """
        # Avoiding import loops
        from dace import data

        tmpshape = shape
        shape = [symbolic.eval(s, 0) for s in shape]

        # Normalize dtype to a DaCe typeclass before it is used below
        if not isinstance(dtype, types.typeclass):
            dtype = types.typeclass(dtype.type)

        kwargs.update({'dtype': dtype.type})

        res = numpy.ndarray.__new__(cls, shape, *args, **kwargs)
        res._symlist = symbolic.symlist(tmpshape)
        for _, sym in res._symlist.items():
            sym._arrays_to_update.append(res)

        res.descriptor = data.Array(
            dtype,
            tmpshape,
            materialize_func=materialize_func,
            transient=False,
            allow_conflicts=allow_conflicts)
        return res
Example #7
    def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
        map_entry = self.map_entry
        map_exit = graph.exit_node(map_entry)
        nsdfg_node: Optional[nodes.NestedSDFG] = None

        # Obtain subgraph to perform fission to
        if self.expr_index == 0:  # Map with subgraph
            subgraphs = [(graph,
                          graph.scope_subgraph(map_entry,
                                               include_entry=False,
                                               include_exit=False))]
            parent = sdfg
        else:  # Map with nested SDFG
            nsdfg_node = self.nested_sdfg
            subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
            parent = nsdfg_node.sdfg
        modified_arrays = set()

        # Get map information
        outer_map: nodes.Map = map_entry.map
        mapsize = outer_map.range.size()

        # Add new symbols from outer map to nested SDFG
        if self.expr_index == 1:
            map_syms = outer_map.range.free_symbols
            for edge in graph.out_edges(map_entry):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for edge in graph.in_edges(map_exit):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for sym in map_syms:
                symname = str(sym)
                if symname in outer_map.params:
                    continue
                if symname not in nsdfg_node.symbol_mapping:
                    nsdfg_node.symbol_mapping[symname] = sym
                    nsdfg_node.sdfg.symbols[symname] = \
                        graph.symbols_defined_at(nsdfg_node)[symname]

            # Remove map symbols from nested mapping
            for name in outer_map.params:
                if str(name) in nsdfg_node.symbol_mapping:
                    del nsdfg_node.symbol_mapping[str(name)]
                if str(name) in nsdfg_node.sdfg.symbols:
                    del nsdfg_node.sdfg.symbols[str(name)]

        for state, subgraph in subgraphs:
            components = MapFission._components(subgraph)
            sources = subgraph.source_nodes()
            sinks = subgraph.sink_nodes()

            # Collect external edges
            if self.expr_index == 0:
                external_edges_entry = list(state.out_edges(map_entry))
                external_edges_exit = list(state.in_edges(map_exit))
            else:
                external_edges_entry = [
                    e for e in subgraph.edges()
                    if (isinstance(e.src, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.src.data].transient)
                ]
                external_edges_exit = [
                    e for e in subgraph.edges()
                    if (isinstance(e.dst, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
                ]

            # Map external edges to outer memlets
            edge_to_outer = {}
            for edge in external_edges_entry:
                if self.expr_index == 0:
                    # Subgraphs use the corresponding outer map edges
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex - 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                      if e.dst_conn == edge.src.data)
                    edge_to_outer[edge] = outer_edge

            for edge in external_edges_exit:
                if self.expr_index == 0:
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex + 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                      if e.src_conn == edge.dst.data)
                    edge_to_outer[edge] = outer_edge

            # Collect all border arrays and code->code edges
            arrays = MapFission._border_arrays(
                nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
                subgraph)
            scalars = defaultdict(list)
            for _, component_out in components:
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.CodeNode):
                        scalars[e.data.data].append(e)

            # Create new arrays for scalars
            for scalar, edges in scalars.items():
                desc = parent.arrays[scalar]
                del parent.arrays[scalar]
                name, newdesc = parent.add_transient(
                    scalar,
                    mapsize,
                    desc.dtype,
                    desc.storage,
                    lifetime=desc.lifetime,
                    debuginfo=desc.debuginfo,
                    allow_conflicts=desc.allow_conflicts,
                    find_new_name=True)

                # Add extra nodes in component boundaries
                for edge in edges:
                    anode = state.add_access(name)
                    sbs = subsets.Range.from_string(','.join(outer_map.params))
                    # Offset memlet by map range begin (to fit the transient)
                    sbs.offset([r[0] for r in outer_map.range], True)
                    state.add_edge(
                        edge.src, edge.src_conn, anode, None,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.add_edge(
                        anode, None, edge.dst, edge.dst_conn,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.remove_edge(edge)

            # Add extra maps around components
            new_map_entries = []
            for component_in, component_out in components:
                me, mx = state.add_map(outer_map.label + '_fission',
                                       [(p, '0:1') for p in outer_map.params],
                                       outer_map.schedule,
                                       unroll=outer_map.unroll,
                                       debuginfo=outer_map.debuginfo)

                # Add dynamic input connectors
                for conn in map_entry.in_connectors:
                    if not conn.startswith('IN_'):
                        me.add_in_connector(conn)

                me.map.range = dcpy(outer_map.range)
                new_map_entries.append(me)

                # Reconnect edges through new map
                for e in state.in_edges(component_in):
                    state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                    # Reconnect inner edges at source directly to external nodes
                    if self.expr_index == 0 and e in external_edges_entry:
                        state.add_edge(edge_to_outer[e].src,
                                       edge_to_outer[e].src_conn, me, None,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(e.src, e.src_conn, me, None,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.in_degree(component_in) == 0:
                    state.add_edge(me, None, component_in, None, mm.Memlet())

                for e in state.out_edges(component_out):
                    state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                    # Reconnect inner edges at sink directly to external nodes
                    if self.expr_index == 0 and e in external_edges_exit:
                        state.add_edge(mx, None, edge_to_outer[e].dst,
                                       edge_to_outer[e].dst_conn,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(mx, None, e.dst, e.dst_conn,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.out_degree(component_out) == 0:
                    state.add_edge(component_out, None, mx, None, mm.Memlet())
            # Connect other sources/sinks not in components (access nodes)
            # directly to external nodes
            if self.expr_index == 0:
                for node in sources:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.in_edges(node):
                            outer_edge = edge_to_outer[edge]
                            memlet = dcpy(edge.data)
                            memlet.subset = subsets.Range(
                                outer_map.range.ranges + memlet.subset.ranges)
                            state.add_edge(outer_edge.src, outer_edge.src_conn,
                                           edge.dst, edge.dst_conn, memlet)

                for node in sinks:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.out_edges(node):
                            outer_edge = edge_to_outer[edge]
                            state.add_edge(edge.src, edge.src_conn,
                                           outer_edge.dst, outer_edge.dst_conn,
                                           dcpy(outer_edge.data))

            # Augment arrays by prepending map dimensions
            for array in arrays:
                if array in modified_arrays:
                    continue
                desc = parent.arrays[array]
                if isinstance(desc, dt.Scalar):
                    # Scalars need to be augmented to arrays
                    desc = dt.Array(desc.dtype, desc.shape, desc.transient,
                                    desc.allow_conflicts, desc.storage,
                                    desc.location, desc.strides, desc.offset,
                                    False, desc.lifetime, 0, desc.debuginfo,
                                    desc.total_size, desc.start_offset)
                    parent.arrays[array] = desc
                for sz in reversed(mapsize):
                    desc.strides = [desc.total_size] + list(desc.strides)
                    desc.total_size = desc.total_size * sz

                desc.shape = mapsize + list(desc.shape)
                desc.offset = [0] * len(mapsize) + list(desc.offset)
                modified_arrays.add(array)

            # Fill scope connectors so that memlets can be tracked below
            state.fill_scope_connectors()

            # Correct connectors and memlets in nested SDFGs to account for
            # missing outside map
            if self.expr_index == 1:
                to_correct = ([(e, e.src) for e in external_edges_entry] +
                              [(e, e.dst) for e in external_edges_exit])
                corrected_nodes = set()
                for edge, node in to_correct:
                    if isinstance(node, nodes.AccessNode):
                        if node in corrected_nodes:
                            continue
                        corrected_nodes.add(node)

                        outer_edge = edge_to_outer[edge]
                        desc = parent.arrays[node.data]

                        # Modify shape of internal array to match outer one
                        outer_desc = sdfg.arrays[outer_edge.data.data]
                        if not isinstance(desc, dt.Scalar):
                            desc.shape = outer_desc.shape
                        if isinstance(desc, dt.Array):
                            desc.strides = outer_desc.strides
                            desc.total_size = outer_desc.total_size

                        # Inside the nested SDFG, offset all memlets to include
                        # the offsets from within the map.
                        # NOTE: Relies on propagation to fix outer memlets
                        for internal_edge in state.all_edges(node):
                            for e in state.memlet_tree(internal_edge):
                                e.data.subset.offset(desc.offset, False)
                                e.data.subset = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data).subset

                        # Only after offsetting memlets we can modify the
                        # overall offset
                        if isinstance(desc, dt.Array):
                            desc.offset = outer_desc.offset

            # Fill in memlet trees for border transients
            # NOTE: Memlet propagation should run to correct the outer edges
            for node in subgraph.nodes():
                if isinstance(node, nodes.AccessNode) and node.data in arrays:
                    for edge in state.all_edges(node):
                        for e in state.memlet_tree(edge):
                            # Prepend map dimensions to memlet
                            e.data.subset = subsets.Range(
                                [(pystr_to_symbolic(d) - r[0],
                                  pystr_to_symbolic(d) - r[0], 1) for d, r in
                                 zip(outer_map.params, outer_map.range)] +
                                e.data.subset.ranges)

        # If nested SDFG, reconnect nodes around map and modify memlets
        if self.expr_index == 1:
            for edge in graph.in_edges(map_entry):
                if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                    continue

                # Modify edge coming into nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)
                edge.data.num_accesses = edge.data.subset.num_elements()

                # Find matching edge inside map
                inner_edge = next(
                    e for e in graph.out_edges(map_entry)
                    if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
                graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                               inner_edge.dst_conn, dcpy(edge.data))

            for edge in graph.out_edges(map_exit):
                # Modify edge coming out of nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)

                # Find matching edge inside map
                inner_edge = next(e for e in graph.in_edges(map_exit)
                                  if e.dst_conn[3:] == edge.src_conn[4:])
                graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                               edge.dst_conn, dcpy(edge.data))

        # Remove outer map
        graph.remove_nodes_from([map_entry, map_exit])
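
A minimal sketch of applying this transformation through the standard DaCe transformation API; the import path below is the usual location of MapFission but may differ across versions:

from dace.transformation.dataflow import MapFission

# Apply map fission wherever the pattern matches on an existing SDFG
sdfg.apply_transformations_repeated(MapFission)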
Example #8
    def apply(self, sdfg):
        state = sdfg.nodes()[self.state_id]
        nested_sdfg = state.nodes()[self.subgraph[CopyToDevice._nested_sdfg]]
        storage = self.storage

        for _, edge in enumerate(state.in_edges(nested_sdfg)):

            src, src_conn, dst, dst_conn, memlet = edge
            dataname = memlet.data
            memdata = sdfg.arrays[dataname]

            if isinstance(memdata, data.Array):
                new_data = sdfg.add_array(
                    'device_' + dataname + '_in',
                    memdata.dtype, [
                        symbolic.overapproximate(r)
                        for r in memlet.bounding_box_size()
                    ],
                    transient=True,
                    storage=storage)
            elif isinstance(memdata, data.Scalar):
                new_data = sdfg.add_scalar(
                    'device_' + dataname + '_in',
                    memdata.dtype,
                    transient=True,
                    storage=storage)
            else:
                raise NotImplementedError

            data_node = nodes.AccessNode('device_' + dataname + '_in')

            to_data_mm = dcpy(memlet)
            from_data_mm = dcpy(memlet)
            from_data_mm.data = 'device_' + dataname + '_in'
            # Rebase the memlet subset to the origin of the new device array
            for ind, r in enumerate(memlet.subset):
                if isinstance(memlet.subset[ind], tuple):
                    begin = memlet.subset[ind][0] - r[0]
                    end = memlet.subset[ind][1] - r[0]
                    step = memlet.subset[ind][2]
                    from_data_mm.subset[ind] = (begin, end, step)
                else:
                    from_data_mm.subset[ind] -= r[0]

            state.remove_edge(edge)
            state.add_edge(src, src_conn, data_node, None, to_data_mm)
            state.add_edge(data_node, None, dst, dst_conn, from_data_mm)

        for _, edge in enumerate(state.out_edges(nested_sdfg)):

            src, src_conn, dst, dst_conn, memlet = edge
            dataname = memlet.data
            memdata = sdfg.arrays[dataname]

            if isinstance(memdata, data.Array):
                # Register the device array on the SDFG (data.Array alone would
                # not add it, and its first argument is the dtype, not a name)
                new_data = sdfg.add_array(
                    'device_' + dataname + '_out',
                    memdata.dtype, [
                        symbolic.overapproximate(r)
                        for r in memlet.bounding_box_size()
                    ],
                    transient=True,
                    storage=storage)
            elif isinstance(memdata, data.Scalar):
                new_data = sdfg.add_scalar(
                    'device_' + dataname + '_out',
                    memdata.dtype,
                    transient=True,
                    storage=storage)
            else:
                raise NotImplementedError

            data_node = nodes.AccessNode('device_' + dataname + '_out')

            to_data_mm = dcpy(memlet)
            from_data_mm = dcpy(memlet)
            to_data_mm.data = 'device_' + dataname + '_out'
            # Rebase the memlet subset to the origin of the new device array
            for ind, r in enumerate(memlet.subset):
                if isinstance(memlet.subset[ind], tuple):
                    begin = memlet.subset[ind][0] - r[0]
                    end = memlet.subset[ind][1] - r[0]
                    step = memlet.subset[ind][2]
                    to_data_mm.subset[ind] = (begin, end, step)
                else:
                    to_data_mm.subset[ind] -= r[0]

            state.remove_edge(edge)
            state.add_edge(src, src_conn, data_node, None, to_data_mm)
            state.add_edge(data_node, None, dst, dst_conn, from_data_mm)

        # Change storage for all data inside nested SDFG to device.
        change_storage(nested_sdfg.sdfg, storage)
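
A hypothetical application of CopyToDevice, assuming it is registered as a DaCe transformation and that apply_transformations accepts an options dictionary (as in recent DaCe versions):

from dace.dtypes import StorageType

sdfg.apply_transformations(CopyToDevice,
                           options={'storage': StorageType.GPU_Global})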
Example #9
    def expansion(node, state: SDFGState, sdfg: SDFG):
        # Extract input and output array views (as generated by memlets)
        inputs, outputs = _get_inputs_and_outputs(sdfg, state, node)

        unique_id = "{}_{}_{}_{}".format(clean_onnx_name(node.name),
                                         sdfg.sdfg_id, sdfg.node_id(state),
                                         state.node_id(node))
        _add_ort_init_code(sdfg)

        sdfg.append_global_code(
            "OrtExecutableKernel *__ort_kernel_{};\n".format(unique_id))
        sdfg.append_global_code(
            "OrtExecutableKernelContext *__ort_context_{};\n".format(
                unique_id))

        sdfg.append_init_code("""
        {{
        // Setup for {name}
        __ort_check_status(__ort_api->CreateExecutableKernelContext("{name}", "{op_type}", &__ort_context_{name}));
        """.format(name=unique_id, op_type=node.schema.name))

        # check if ORT supports CUDA for this node
        ##########################################

        # Default: all parameters are on the host if we execute on CPU
        outputs_on_host = [True] * len(outputs)
        inputs_on_host = [True] * len(inputs)

        actual_node_schedule = node.schedule
        if node.schedule in (ScheduleType.CPU_Multicore, ScheduleType.Default):
            provider_index = 0
        elif node.schedule == ScheduleType.GPU_Device:
            provider_index = 1
            try:
                # the ith position indicates whether the ith output is in host memory
                inputs_on_host, outputs_on_host = check_op(sdfg,
                                                           state,
                                                           node,
                                                           cuda=True)

            except ONNXOpValidationError as e:
                # fallback to CPU
                print("Falling back to CPU for node {}. Reason:\n{}".format(
                    node.name, str(e)))
                provider_index = 0
                actual_node_schedule = ScheduleType.Default
        else:
            raise NotImplementedError(
                "ORT expansion for schedule '{}' is not implemented".format(
                    node.schedule))

        # check if we need to insert device copies
        ##########################################

        # Maps each connector that requires a copy to the storage type that
        # must be connected to the tasklet
        input_copy_required = defaultdict(dict)
        output_copy_required = defaultdict(dict)

        assert len(
            node.iter_outputs_in_onnx_order(state)) == len(outputs_on_host)
        assert len(
            node.iter_inputs_in_onnx_order(state)) == len(inputs_on_host)

        # check outputs
        for edge, output_on_host in zip(node.iter_outputs_in_onnx_order(state),
                                        outputs_on_host):
            # get the memlet for this output
            array = sdfg.arrays[edge.data.data]

            if output_on_host:
                is_device_mismatch = not can_access(ScheduleType.Default,
                                                    array.storage)
            else:
                is_device_mismatch = not can_access(ScheduleType.GPU_Device,
                                                    array.storage)

            if (isinstance(array, dt.Scalar)
                    and actual_node_schedule == ScheduleType.GPU_Device):
                # ORT kernels expect scalars to be cudaMalloc'ed; copy during
                # expansion to enforce this
                is_device_mismatch = True
                output_copy_required[edge.src_conn]['copy_to_array'] = True

            if is_device_mismatch:
                # we need to insert a copy
                output_copy_required[edge.src_conn][
                    'storage'] = StorageType.Default if output_on_host else StorageType.GPU_Global

        # check inputs (same thing again)
        for edge, input_on_host in zip(node.iter_inputs_in_onnx_order(state),
                                       inputs_on_host):
            array = sdfg.arrays[edge.data.data]

            if input_on_host:
                is_device_mismatch = not can_access(ScheduleType.Default,
                                                    array.storage)
            else:
                is_device_mismatch = not can_access(ScheduleType.GPU_Device,
                                                    array.storage)

            if (isinstance(array, dt.Scalar)
                    and actual_node_schedule == ScheduleType.GPU_Device):
                # ORT kernels expect scalars to be cudaMalloc'ed; copy during
                # expansion to enforce this
                is_device_mismatch = True
                input_copy_required[edge.dst_conn]['copy_to_array'] = True

            if is_device_mismatch:
                # we need to insert a copy
                input_copy_required[edge.dst_conn][
                    'storage'] = StorageType.Default if input_on_host else StorageType.GPU_Global

        # begin codegen
        ##########################################
        tasklet_setup_code = ""
        tasklet_code = ""
        tasklet_cleanup_code = ""

        reversed_onnx_dtype_map = {
            v: k
            for k, v in ONNX_DTYPES_TO_DACE_TYPE_CLASS.items()
        }

        # emit code for inputs and outputs
        ##########################################
        in_connectors = {}
        out_connectors = {}

        for edge, is_input in node.iter_edges(state):

            parameter_name = edge.dst_conn if is_input else edge.src_conn

            if output_copy_required or input_copy_required:
                edge_connector_name = "_conn_" + parameter_name
            else:
                edge_connector_name = parameter_name

            input_output_string = "input" if is_input else "output"
            connector_dict = in_connectors if is_input else out_connectors
            memlet = edge.data
            desc = sdfg.arrays[memlet.data]
            sdfg.append_init_code("""
            // Add parameter {parameter_name}
            __ort_check_status(__ort_api->ExecutableKernelContext_Add{input_output_string}(__ort_context_{id}, ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_string}));
            """.format(id=unique_id,
                       type_string=reversed_onnx_dtype_map[desc.dtype].upper(),
                       parameter_name=parameter_name,
                       input_output_string=input_output_string.capitalize()))

            ort_value_name = "ort_value_{input_output_string}_{parameter_name}".format(
                input_output_string=input_output_string,
                parameter_name=parameter_name)

            copy_to_array = (
                (parameter_name in output_copy_required
                 and 'copy_to_array' in output_copy_required[parameter_name])
                or
                (parameter_name in input_copy_required
                 and 'copy_to_array' in input_copy_required[parameter_name]))
            if desc.storage == StorageType.Default:
                mem_info = "__ort_cpu_mem_info"
            elif desc.storage == StorageType.GPU_Global:
                mem_info = "__ort_cuda_mem_info"
            elif desc.storage == StorageType.CPU_Pinned:
                mem_info = "__ort_cuda_pinned_mem_info"
            else:
                raise ValueError(
                    "Unsupported storage type {} for input to ONNX node".
                    format(desc.storage))
            if (isinstance(desc, dt.Scalar) and
                    # when copying to array, the ort value is not a scalar but an array
                    not copy_to_array):

                tasklet_setup_code += """
                OrtValue* {ort_value_name};
                __ort_check_status(__ort_api->CreateTensorWithDataAsOrtValue(
                    {mem_info},
                    &{edge_connector_name},
                    {data_size} * sizeof({ctype}),
                    nullptr,
                    0,
                    ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_str},
                    &{ort_value_name}
                ));
                """.format(
                    input_output_string=input_output_string,
                    mem_info=mem_info,
                    edge_connector_name=edge_connector_name,
                    data_size=reduce(lambda x, y: x * y, desc.shape),
                    ctype=desc.dtype.ctype,
                    type_str=reversed_onnx_dtype_map[desc.dtype].upper(),
                    ort_value_name=ort_value_name)
                connector_dict[parameter_name] = None

            elif isinstance(desc, dt.Array) or copy_to_array:

                # when we copy a scalar to an array, that scalar naturally has shape []
                dims = [] if copy_to_array else desc.shape

                # setup dims array
                tasklet_setup_code += """
                int64_t {input_output_string}_{parameter_name}_dims[{dims_size}] = {{{dims}}};
                """.format(input_output_string=input_output_string,
                           parameter_name=parameter_name,
                           dims_size=len(dims),
                           dims=", ".join(str(s) for s in dims))

                connector_dict[parameter_name] = dace.pointer(desc.dtype)
                data = "const_cast < void * > (reinterpret_cast < const void * > ({}))".format(
                    edge_connector_name)

                tasklet_setup_code += """
                OrtValue* {ort_value_name};
                __ort_check_status(__ort_api->CreateTensorWithDataAsOrtValue(
                    {mem_info},
                    {data},
                    {data_size} * sizeof({ctype}),
                    {input_output_string}_{parameter_name}_dims,
                    {dims_size},
                    ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_str},
                    &{ort_value_name}
                ));
                """.format(
                    input_output_string=input_output_string,
                    data=data,
                    mem_info=mem_info,
                    parameter_name=parameter_name,
                    data_size=reduce(lambda x, y: x * y, desc.shape),
                    ctype=desc.dtype.ctype,
                    dims_size=len(dims),
                    type_str=reversed_onnx_dtype_map[desc.dtype].upper(),
                    ort_value_name=ort_value_name)
            else:
                raise NotImplementedError(
                    "Data-descriptor type {} not supported for ONNX nodes".
                    format(type(desc)))

            tasklet_code += "__ort_check_status(__ort_api->ExecutableKernel_Set{input_output_string_capital}(" \
                            "__ort_kernel_{unique_id}, {position}, {ort_value_name}));\n".format(
                input_output_string_capital=input_output_string.
                    capitalize(),
                ort_value_name=ort_value_name,
                unique_id=unique_id,
                position=get_position(node.schema, is_input,
                                      parameter_name))

            tasklet_cleanup_code += "__ort_api->ReleaseValue(ort_value_{input_output_string}_{parameter_name});\n".format(
                input_output_string=input_output_string,
                parameter_name=parameter_name)

        sdfg.append_init_code("// Setup attributes\n")

        for name, attr in node.schema.attributes.items():
            if hasattr(node, name):
                sdfg.append_init_code(
                    _gen_attr_init_code("__ort_context_{}".format(unique_id),
                                        node.schema.attributes[name],
                                        getattr(node, name)))

        sdfg.prepend_exit_code(
            "__ort_api->ReleaseExecutableKernelContext(__ort_context_{});\n".
            format(unique_id))
        sdfg.prepend_exit_code(
            "__ort_api->ReleaseExecutableKernel(__ort_kernel_{});\n".format(
                unique_id))

        tasklet_code += 'fprintf(stderr, "Launching {}\\n");\n'.format(
            unique_id)
        tasklet_code += "__ort_check_status(__ort_api->ExecutableKernel_Compute(__ort_kernel_{}));\n".format(
            unique_id)

        sdfg.append_init_code(
            "__ort_check_status(__ort_api->CreateExecutableKernel("
            "__ort_session, __ort_context_{id}, /*provider_index=*/{provider_index}, &__ort_kernel_{id}));\n"
            .format(provider_index=provider_index, id=unique_id))
        sdfg.append_init_code(
            "}} // end setup for context_{}".format(unique_id))

        tasklet_code = tasklet_setup_code + tasklet_code + tasklet_cleanup_code
        tasklet = nd.Tasklet('onnx_code',
                             in_connectors,
                             out_connectors,
                             tasklet_code,
                             language=dace.dtypes.Language.CPP)
        tasklet.environments = {"ONNXRuntime"}

        if output_copy_required or input_copy_required:
            nsdfg = dace.SDFG("nested_{}".format(unique_id))
            nstate = nsdfg.add_state()
            ntasklet = deepcopy(tasklet)

            # add a prefix to connectors to prevent shadowing of array names
            ntasklet.in_connectors = {
                "_conn_" + k: v
                for k, v in tasklet.in_connectors.items()
            }
            ntasklet.out_connectors = {
                "_conn_" + k: v
                for k, v in tasklet.out_connectors.items()
            }

            nstate.add_node(ntasklet)

            for edge, is_input in node.iter_edges(state):
                parameter_name = edge.dst_conn if is_input else edge.src_conn

                memlet = edge.data
                desc = sdfg.arrays[memlet.data]

                # Validate, then add the original array
                if not isinstance(desc, (dt.Array, dt.Scalar)):
                    raise ValueError(
                        "Unsupported data type {} connected to an ONNX tasklet"
                        .format(type(desc)))
                original_desc = deepcopy(desc)
                original_desc.transient = False
                nsdfg.add_datadesc(parameter_name, original_desc)

                if parameter_name not in (input_copy_required if is_input else
                                          output_copy_required):
                    if is_input:
                        access = nstate.add_read(parameter_name)
                        nstate.add_edge(access, None, ntasklet,
                                        "_conn_" + parameter_name,
                                        nsdfg.get_array_memlet(parameter_name))
                    else:
                        access = nstate.add_write(parameter_name)
                        nstate.add_edge(ntasklet, "_conn_" + parameter_name,
                                        access, None,
                                        nsdfg.get_array_memlet(parameter_name))
                    continue

                copy_options = input_copy_required[
                    parameter_name] if is_input else output_copy_required[
                        parameter_name]

                # add the copy of the descriptor
                if 'copy_to_array' in copy_options:
                    copy_desc = dt.Array(shape=[1], dtype=desc.dtype)
                else:
                    copy_desc = deepcopy(desc)

                copy_desc.transient = True
                copy_desc.storage = copy_options['storage']
                nsdfg.add_datadesc("copy_" + memlet.data, copy_desc)

                nmemlet = deepcopy(memlet)
                nmemlet.data = "copy_" + nmemlet.data
                if is_input:
                    access = nstate.add_read(parameter_name)
                    access_copy = nstate.add_access("copy_" + memlet.data)
                    nstate.add_edge(
                        access, None, access_copy, None,
                        nsdfg.get_array_memlet("copy_" + memlet.data))
                    nstate.add_edge(access_copy, None, ntasklet,
                                    "_conn_" + parameter_name, nmemlet)
                else:
                    access = nstate.add_write(parameter_name)
                    access_copy = nstate.add_access("copy_" + memlet.data)
                    nstate.add_edge(ntasklet, "_conn_" + parameter_name,
                                    access_copy, None, nmemlet)
                    nstate.add_edge(
                        access_copy, None, access, None,
                        nsdfg.get_array_memlet("copy_" + memlet.data))

            return nsdfg

        else:
            return tasklet