Ejemplo n.º 1
0
def add_backward_desc_for_connector(backward_sdfg: dace.SDFG,
                                    forward_node: nd.Node,
                                    context: BackwardContext, connector: str,
                                    input: bool) -> str:
    """ Register a gradient descriptor for one connector of ``forward_node``.

        The array attached to the connector is looked up through the edge found in
        ``context.forward_state``; its descriptor in ``context.forward_sdfg`` is deep-copied,
        marked non-transient and added to ``backward_sdfg`` under ``<array name>_grad``
        (``find_new_name=True`` is passed to ``add_datadesc``).

        :param backward_sdfg: the sdfg that receives the new descriptor.
        :param forward_node: the forward node that owns the connector.
        :param context: the backward context; its ``forward_state`` and ``forward_sdfg`` are read.
        :param connector: the name of the connector on ``forward_node``.
        :param input: ``True`` to treat ``connector`` as an input connector, ``False`` for an
                      output connector.
        :return: the name returned by ``backward_sdfg.add_datadesc`` for the new descriptor.
    """
    # pick the edge lookup matching the connector direction
    lookup_edge = utils.in_edge_with_name if input else utils.out_edge_with_name
    connector_edge = lookup_edge(forward_node, context.forward_state,
                                 connector)
    forward_name = connector_edge.data.data

    # work on a copy so that the forward descriptor itself is left untouched
    grad_desc = copy.deepcopy(context.forward_sdfg.arrays[forward_name])
    grad_desc.transient = False

    return backward_sdfg.add_datadesc(forward_name + "_grad",
                                      grad_desc,
                                      find_new_name=True)
Ejemplo n.º 2
0
def add_backward_desc(backward_sdfg: dace.SDFG, forward_sdfg: dace.SDFG,
                      forward_desc: dt.Data, forward_name: str) -> str:
    """ Register a gradient descriptor derived from ``forward_desc`` in ``backward_sdfg``.

        The gradient name is built from ``forward_name + "_grad"`` and made unique with respect
        to the arrays of ``forward_sdfg``. The descriptor is a non-transient deep copy of
        ``forward_desc``.

        :param backward_sdfg: the sdfg that receives the new descriptor.
        :param forward_sdfg: the forward sdfg; its ``arrays`` are used for the name-clash check.
        :param forward_desc: the data descriptor of the forward array from ``forward_sdfg``.
        :param forward_name: a name for the forward array (does not have to match its actual
                             name).
        :return: the name returned by ``backward_sdfg.add_datadesc`` for the new descriptor.
    """
    # NOTE(review): uniqueness is checked against the forward sdfg's arrays only, not against
    # the arrays already present in ``backward_sdfg`` -- confirm this is intended.
    grad_name = utils.find_str_not_in_set(forward_sdfg.arrays,
                                          forward_name + "_grad")

    # work on a copy so that the forward descriptor itself is left untouched
    grad_desc = copy.deepcopy(forward_desc)
    grad_desc.transient = False

    return backward_sdfg.add_datadesc(grad_name, grad_desc)
Ejemplo n.º 3
0
    def apply(self, sdfg: dace.SDFG):
        """ Replace the matched ONNX node with constants computed ahead of time.

            An ``ONNXShape`` node is replaced by an int64 array holding the shape of its input
            descriptor. Any other node is deep-copied into a new single-state SDFG, that SDFG is
            called with the node's inputs taken from ``sdfg._parent_onnx_model.clean_weights``,
            and every output is stored in ``parent.weights`` and wired to the original consumers
            through a new AccessNode. The folded node is finally handed to
            ``remove_node_and_computation``.

            :param sdfg: the SDFG containing the matched node; its ``_parent_onnx_model``
                         attribute is read and that model's ``weights`` are extended.
        """
        # Extract the subgraph, execute it and insert an AccessNode to the result
        # this method of execution is slow but simple. A better option would be to call the ORT
        # C API from a python object (like the OpChecker).

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
        log.debug(f"Applying constant folding: {node} in {state}")

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            # register a 1-D int64 array with one entry per dimension of the input
            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            # NOTE(review): this checks the un-cleaned name against ``clean_weights`` while the
            # other branch below checks ``parent.weights`` -- confirm which one is intended.
            assert constant_name not in parent.clean_weights
            # the shape values themselves are stored as a weight of the parent model
            parent.weights[constant_name] = torch.from_numpy(
                np.array(shape_desc.shape, np.int64))

            # feed the consumer of the shape node from the new constant instead
            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            # UNIQUE_ID is a module-level counter; bumping it gives each sub-SDFG its own name
            global UNIQUE_ID
            UNIQUE_ID += 1
            sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
            sub_state = sub_sdfg.add_state()

            # the sub-SDFG operates on a copy; the original node stays in ``state`` for now
            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            # build one sub-SDFG array named 'array_<connector>' per input, and collect the
            # values that will be passed for it
            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)

                # presumably non-transient so that the array becomes an argument of the
                # sub-SDFG call -- TODO confirm
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    # zero-dimensional weights are passed as a numpy scalar
                    inputs['array_' +
                           edge.dst_conn] = input_value.cpu().numpy()[()]
                else:
                    # pass a copy rather than the stored weight tensor itself
                    inputs['array_' + edge.dst_conn] = input_value.clone()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            # build one sub-SDFG array named 'array_<connector>' per output, and allocate the
            # buffer that will receive its value
            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    # node -> transient scalar -> non-transient one-element array
                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))

                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                # descriptors with an empty shape get a one-element buffer
                if len(desc.shape) == 0:
                    empty_array = np.empty((1, ), desc.dtype.as_numpy_dtype())
                else:
                    empty_array = np.empty(tuple(desc.shape),
                                           desc.dtype.as_numpy_dtype())

                empty_array = torch.from_numpy(empty_array)

                # move the buffer to CUDA when the descriptor lives in GPU global storage
                if desc.storage is dtypes.StorageType.GPU_Global:
                    empty_array = empty_array.cuda()

                outputs['array_' + edge.src_conn] = empty_array

            # run the op; the results are read back from the ``outputs`` buffers below
            sub_sdfg(**outputs, **inputs)

            # register every computed output as a weight and rewire its consumer
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                assert type(output_value) is torch.Tensor

                if not dtypes.can_access(dtypes.ScheduleType.CPU_Multicore,
                                         desc.storage):
                    # the storage is not accessible from a CPU schedule: add a non-transient
                    # CPU_Heap copy of the descriptor and copy from it into the constant
                    cpu_desc = copy.deepcopy(desc)
                    cpu_desc.storage = dtypes.StorageType.CPU_Heap
                    cpu_desc.transient = False
                    # NOTE(review): ``desc`` was already passed to ``add_datadesc`` above; this
                    # only affects the registered descriptor if the SDFG keeps a reference to
                    # the same object -- confirm.
                    desc.transient = True
                    copy_in_name = sdfg.temp_data_name()
                    clean_copy_in_name = clean_onnx_name(copy_in_name)
                    sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                    access_constant = state.add_access(clean_constant_name)
                    state.add_edge(state.add_read(clean_copy_in_name), None,
                                   access_constant, None,
                                   sdfg.make_array_memlet(clean_copy_in_name))

                    # in this case the weight is stored under the name of the CPU copy
                    name_to_add = copy_in_name
                else:
                    access_constant = state.add_read(clean_constant_name)
                    name_to_add = constant_name

                if isinstance(desc, dt.Scalar):
                    # undo the one-element buffer used for scalar outputs
                    parent.weights[name_to_add] = output_value.reshape(())
                else:
                    parent.weights[name_to_add] = output_value

                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now useless nodes with a reverse BFS
        remove_node_and_computation(sdfg, state, node)
Ejemplo n.º 4
0
    def apply(self, sdfg: dace.SDFG):
        """ Replace the matched ONNX node with constants computed ahead of time.

            An ``ONNXShape`` node is replaced by an int64 array holding the shape of its input
            descriptor. Any other node is deep-copied into a new single-state SDFG, that SDFG is
            called with the node's inputs taken from ``sdfg._parent_onnx_model.clean_weights``,
            and every output is stored in ``parent.weights`` and wired to the original consumers
            through a new AccessNode. Afterwards the folded node is removed together with every
            producer that is left without outgoing edges.

            :param sdfg: the SDFG containing the matched node; its ``_parent_onnx_model``
                         attribute is read and that model's ``weights`` are extended.
        """
        # Extract the subgraph, execute it and insert an AccessNode to the result

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            # register a 1-D int64 array with one entry per dimension of the input
            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            # NOTE(review): this checks the un-cleaned name against ``clean_weights`` while the
            # other branch below checks ``parent.weights`` -- confirm which one is intended.
            assert constant_name not in parent.clean_weights
            # the shape values themselves are stored as a weight of the parent model
            parent.weights[constant_name] = np.array(shape_desc.shape,
                                                     np.int64)

            # feed the consumer of the shape node from the new constant instead
            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            sub_sdfg = dace.SDFG("sub_sdfg")
            sub_state = sub_sdfg.add_state()

            # the sub-SDFG operates on a copy; the original node stays in ``state`` for now
            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            # build one sub-SDFG array named 'array_<connector>' per input, and collect the
            # values that will be passed for it
            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)

                # presumably non-transient so that the array becomes an argument of the
                # sub-SDFG call -- TODO confirm
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    # zero-dimensional weights are passed as a scalar
                    inputs['array_' + edge.dst_conn] = input_value[()]
                else:
                    # pass a copy rather than the stored weight array itself
                    inputs['array_' + edge.dst_conn] = input_value.copy()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            # build one sub-SDFG array named 'array_<connector>' per output, and allocate the
            # buffer that will receive its value
            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    # node -> transient scalar -> non-transient one-element array
                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))

                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                # descriptors with an empty shape get a one-element buffer
                if len(desc.shape) == 0:
                    outputs['array_' + edge.src_conn] = np.empty(
                        (1, ), desc.dtype.as_numpy_dtype())
                else:
                    outputs['array_' + edge.src_conn] = np.empty(
                        tuple(desc.shape), desc.dtype.as_numpy_dtype())

            # run the op; the results are read back from the ``outputs`` buffers below
            sub_sdfg(**outputs, **inputs)

            # register every computed output as a weight and rewire its consumer
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                if isinstance(desc, dt.Scalar):
                    # undo the one-element buffer used for scalar outputs
                    parent.weights[constant_name] = output_value.reshape(())
                else:
                    parent.weights[constant_name] = output_value

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now useless nodes with a reverse BFS
        queue = deque([node])
        while queue:
            current_node = queue.popleft()

            # the in-edges have to be captured before the node is removed
            edges = state.in_edges(current_node)
            state.remove_node(current_node)
            for e in edges:
                next_node = e.src
                # A producer connected to ``current_node`` through several edges (e.g. one
                # constant feeding two connectors) shows up once per edge; queue it only once
                # so that an already removed node is never processed again.
                if (len(state.out_edges(next_node)) == 0
                        and next_node not in queue):
                    queue.append(next_node)