Пример #1
0
    def _get_outputs_using_onnxruntime(self, output_nodes_name):
        """
        Get outputs using onnxruntime.

        A feed dict is built from ``self.inferred_model`` and
        ``self.input_nodes``, then the tensors named in ``output_nodes_name``
        are fetched from ``self.model`` (loaded via ``self.model_path``).

        Args:
            output_nodes_name: Names of the tensors whose values are fetched.

        Returns:
            The result of ``fetch_output_from_onnx_model`` for those names.
        """
        inputs = build_feed_dict(self.inferred_model, self.input_nodes)
        return fetch_output_from_onnx_model(
            self.model, self.model_path, inputs, output_nodes_name)
    def _get_outputs_using_onnxruntime(self, output_nodes_name):
        """
        Get outputs using onnxruntime.

        A feed dict is built from ``self.inferred_model`` and
        ``self.input_nodes``, then the tensors named in ``output_nodes_name``
        are fetched from ``self.model``.

        Args:
            output_nodes_name: Names of the tensors whose values are fetched.

        Returns:
            The result of ``fetch_output_from_onnx_model`` for those names.
        """
        # A per-input dtype table used to be built here but was never read;
        # it was removed because its DTYPE_MAP lookup could raise KeyError
        # for unmapped element types even though the result was discarded.
        feed_dict = build_feed_dict(self.inferred_model, self.input_nodes)
        return fetch_output_from_onnx_model(self.model, feed_dict,
                                            output_nodes_name)
 def _for_resize():
     """
     Do resize nodes: materialize the shape-reference tensors by inference.

     For each node name in ``self.dynamic_resize_node``, the name of input
     index 3 of ``self._nodes_dict[node]`` is collected. The model is then
     run once (feed dict from ``self.model`` / ``self.input_nodes``) to fetch
     those tensors. Each fetched value is stored in ``self.tensors_dict`` as
     an ``OnnxTensor`` keyed by tensor name. Does nothing when
     ``self.dynamic_resize_node`` is empty.
     """
     # ``self`` comes from the enclosing method's scope. It is only read here,
     # never rebound, so a plain closure lookup suffices and no ``nonlocal``
     # declaration is required.
     if not self.dynamic_resize_node:
         return
     # NOTE(review): index 3 is presumably the Resize ``sizes`` input
     # (X, roi, scales, sizes) -- confirm against the supported opsets.
     output_tensors = [
         self._nodes_dict[node].input_name_list[3]
         for node in self.dynamic_resize_node
     ]
     feed_dict = build_feed_dict(self.model, self.input_nodes)
     fetch_dict = fetch_output_from_onnx_model(
         self.model, feed_dict=feed_dict, output_nodes=output_tensors)
     for opt_tensor_name, value in fetch_dict.items():
         self.tensors_dict[opt_tensor_name] = OnnxTensor(
             value, opt_tensor_name)
Пример #4
0
    def _onnx_infer(self, infer_inputs_shape):
        """
        Run onnx inference to get outputs of constant nodes.

        The values of every output of ``self._constant_nodes`` are fetched
        and stored in ``self._outputs_infer``. Afterwards, any entry of
        ``self._onnx_model.graph.output`` whose name was not a graph output
        before the call is removed. This restores the original output list,
        including when inference raises.

        Args:
            infer_inputs_shape (dict): Input shape for running inference.
        """
        feed_dict = build_feed_dict(self._onnx_model, infer_inputs_shape)
        output_nodes_name = [name
                             for node in self._constant_nodes
                             for name in node.output]
        # A set, since it is only used for membership tests below.
        original_outputs = {nd.name for nd in self._onnx_model.graph.output}
        try:
            self._outputs_infer = fetch_output_from_onnx_model(
                self._onnx_model, self.model_path, feed_dict,
                output_nodes_name)
        finally:
            # NOTE(review): presumably fetch_output_from_onnx_model adds the
            # requested tensors as extra graph outputs; strip them again even
            # on failure so the model is not left modified.
            graph_outputs = self._onnx_model.graph.output
            idx = 0
            while idx < len(graph_outputs):
                cur_opt = graph_outputs[idx]
                if cur_opt.name not in original_outputs:
                    # Removing shifts later items down, so idx stays put.
                    graph_outputs.remove(cur_opt)
                    continue
                idx += 1