Example #1
    def op_repo_replacement(sdfg: SDFG, state: SDFGState, **kwargs):
        # kwargs named in the ONNX schema become attributes of the new node
        attrs = {
            name: value
            for name, value in kwargs.items() if name in dace_schema.attributes
        }
        onnx_node = cls(name=cls_name, **attrs)
        state.add_node(onnx_node)

        # the remaining kwargs map input/output connector names to array names
        input_names = {p.name for p in dace_schema.inputs}
        output_names = {p.name for p in dace_schema.outputs}
        inputs = {
            name: arr_name
            for name, arr_name in kwargs.items() if name in input_names
        }
        outputs = {
            name: arr_name
            for name, arr_name in kwargs.items() if name in output_names
        }

        # connect each input and output array to the node through AccessNodes
        for inp, arr_name in inputs.items():
            read = state.add_read(arr_name)
            state.add_edge(read, None, onnx_node, inp,
                           sdfg.make_array_memlet(arr_name))

        for outp, arr_name in outputs.items():
            write = state.add_write(arr_name)
            state.add_edge(onnx_node, outp, write, None,
                           sdfg.make_array_memlet(arr_name))
        return []
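
Both replacement functions close over cls, cls_name, and dace_schema, so they are presumably generated once per ONNX operator and registered with DaCe's operator repository, letting calls to the op inside a @dace.program resolve to them. A minimal sketch of such a registration, assuming DaCe's op_repository.replaces decorator and a hypothetical "daceml.onnx." name prefix:

    import dace.frontend.common.op_repository as oprepo

    def register_op_repo_replacement(cls, cls_name, dace_schema):
        def op_repo_replacement(sdfg, state, **kwargs):
            ...  # body as in Example #1 or Example #2

        # the registration name is an assumption; the real prefix may differ
        oprepo.replaces('daceml.onnx.' + cls_name)(op_repo_replacement)
        return op_repo_replacement
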
Example #2
    def op_repo_replacement(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState,
                            **kwargs):
        attrs = {
            name: value
            for name, value in kwargs.items() if name in dace_schema.attributes
        }
        # remove used attrs
        kwargs = {k: v for k, v in kwargs.items() if k not in attrs}

        onnx_node = cls(name=cls_name, **attrs)
        state.add_node(onnx_node)

        input_names = dace_schema.non_variadic_inputs()
        variadic_inputs = dace_schema.variadic_inputs()

        output_names = dace_schema.non_variadic_outputs()
        variadic_outputs = dace_schema.variadic_outputs()

        inputs = {
            name: arr_name
            for name, arr_name in kwargs.items()
            if (name in input_names or
                # variadic params
                ("__" in name
                 and parse_variadic_param(name)[0] in variadic_inputs))
        }

        kwargs = {k: v for k, v in kwargs.items() if k not in inputs}

        outputs = {
            name: arr_name
            for name, arr_name in kwargs.items()
            if (name in output_names or
                # variadic params
                ("__" in name
                 and parse_variadic_param(name)[0] in variadic_outputs))
        }

        kwargs = {k: v for k, v in kwargs.items() if k not in outputs}

        if len(kwargs) > 0:
            raise TypeError(f"Unknown arguments {', '.join(kwargs)}")

        for inp, arr_name in inputs.items():
            read = state.add_read(arr_name)
            state.add_edge(read, None, onnx_node, inp,
                           sdfg.make_array_memlet(arr_name))
            onnx_node.add_in_connector(inp)

        for outp, arr_name in outputs.items():
            write = state.add_write(arr_name)
            state.add_edge(onnx_node, outp, write, None,
                           sdfg.make_array_memlet(arr_name))
            onnx_node.add_out_connector(outp)
        return []
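
Example #2 additionally accepts variadic connectors, whose keyword arguments follow a double-underscore naming convention (e.g. inputs__0); parse_variadic_param splits such a name into its base parameter and index. A minimal sketch of that helper, assuming exactly the name__number convention visible above (the real helper may validate more strictly):

    def parse_variadic_param(param: str):
        # split e.g. 'inputs__0' into ('inputs', 0); assumes 'name__number'
        name, sep, number = param.rpartition("__")
        if not sep or not name or not number.isdigit():
            raise ValueError(f"invalid variadic parameter '{param}': "
                             f"expected 'name__number'")
        return name, int(number)
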
Example #3
    def apply(self, sdfg: dace.SDFG):
        # Extract the subgraph, execute it, and insert an AccessNode for the
        # result. This method of execution is slow but simple; a better option
        # would be to call the ORT C API from a Python object (like the
        # OpChecker).

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
        log.debug(f"Applying constant folding: {node} in {state}")

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            assert constant_name not in parent.clean_weights
            parent.weights[constant_name] = torch.from_numpy(
                np.array(shape_desc.shape, np.int64))

            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            global UNIQUE_ID
            UNIQUE_ID += 1
            sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)

                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    inputs['array_' +
                           edge.dst_conn] = input_value.cpu().numpy()[()]
                else:
                    inputs['array_' + edge.dst_conn] = input_value.clone()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # copy into an array of size [1] so that we can "return"
                    # the output from the SDFG
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))

                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                if len(desc.shape) == 0:
                    empty_array = np.empty((1, ), desc.dtype.as_numpy_dtype())
                else:
                    empty_array = np.empty(tuple(desc.shape),
                                           desc.dtype.as_numpy_dtype())

                empty_array = torch.from_numpy(empty_array)

                if desc.storage is dtypes.StorageType.GPU_Global:
                    empty_array = empty_array.cuda()

                outputs['array_' + edge.src_conn] = empty_array

            sub_sdfg(**outputs, **inputs)

            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                assert type(output_value) is torch.Tensor

                if not dtypes.can_access(dtypes.ScheduleType.CPU_Multicore,
                                         desc.storage):
                    cpu_desc = copy.deepcopy(desc)
                    cpu_desc.storage = dtypes.StorageType.CPU_Heap
                    cpu_desc.transient = False
                    desc.transient = True
                    copy_in_name = sdfg.temp_data_name()
                    clean_copy_in_name = clean_onnx_name(copy_in_name)
                    sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                    access_constant = state.add_access(clean_constant_name)
                    state.add_edge(state.add_read(clean_copy_in_name), None,
                                   access_constant, None,
                                   sdfg.make_array_memlet(clean_copy_in_name))

                    name_to_add = copy_in_name
                else:
                    access_constant = state.add_read(clean_constant_name)
                    name_to_add = constant_name

                if isinstance(desc, dt.Scalar):
                    parent.weights[name_to_add] = output_value.reshape(())
                else:
                    parent.weights[name_to_add] = output_value

                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now useless nodes with a reverse BFS
        remove_node_and_computation(sdfg, state, node)
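
Example #3 delegates the final cleanup to a remove_node_and_computation helper, while Example #4 below inlines the equivalent reverse-BFS cleanup. A plausible sketch of the helper, based on that inline version (an assumption; the real helper may also clean up now-unused data descriptors):

    from collections import deque

    def remove_node_and_computation(sdfg, state, node):
        # remove node, then walk backwards removing any upstream node whose
        # outputs are no longer consumed by anything
        queue = deque([node])
        while len(queue) > 0:
            current_node = queue.popleft()

            edges = state.in_edges(current_node)
            state.remove_node(current_node)
            for e in edges:
                next_node = e.src
                if len(state.out_edges(next_node)) == 0:
                    queue.append(next_node)
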
Example #4
    def apply(self, sdfg: dace.SDFG):
        # Extract the subgraph, execute it, and insert an AccessNode for the result

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            assert constant_name not in parent.clean_weights
            parent.weights[constant_name] = np.array(shape_desc.shape,
                                                     np.int64)

            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            sub_sdfg = dace.SDFG("sub_sdfg")
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)

                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    inputs['array_' + edge.dst_conn] = input_value[()]
                else:
                    inputs['array_' + edge.dst_conn] = input_value.copy()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # copy into an array of size [1] so that we can "return"
                    # the output from the SDFG
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))

                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                if len(desc.shape) == 0:
                    outputs['array_' + edge.src_conn] = np.empty(
                        (1, ), desc.dtype.as_numpy_dtype())
                else:
                    outputs['array_' + edge.src_conn] = np.empty(
                        tuple(desc.shape), desc.dtype.as_numpy_dtype())

            sub_sdfg(**outputs, **inputs)

            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                if isinstance(desc, dt.Scalar):
                    parent.weights[constant_name] = output_value.reshape(())
                else:
                    parent.weights[constant_name] = output_value

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now useless nodes with a reverse BFS
        queue = deque([node])
        while len(queue) > 0:
            current_node = queue.popleft()

            edges = state.in_edges(current_node)
            state.remove_node(current_node)
            for e in edges:
                next_node = e.src
                if len(state.out_edges(next_node)) == 0:
                    queue.append(next_node)
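
In both Example #3 and Example #4 the folded value is computed by calling the sub-SDFG directly with preallocated buffers (sub_sdfg(**outputs, **inputs)): the SDFG is compiled and run, and the output arrays are filled in place. A standalone sketch of that calling pattern (the scale program below is purely illustrative and not part of the transformation):

    import dace
    import numpy as np

    @dace.program
    def scale(x: dace.float64[10], y: dace.float64[10]):
        for i in dace.map[0:10]:
            y[i] = 2 * x[i]

    sdfg = scale.to_sdfg()
    x = np.arange(10, dtype=np.float64)
    y = np.empty(10, dtype=np.float64)
    sdfg(x=x, y=y)  # arrays are passed by keyword and written in place
    assert np.allclose(y, 2 * x)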