コード例 #1
0
ファイル: loader.py プロジェクト: phongphuhanam/TensorRT
    def call_impl(self):
        """
        Load the model and adjust which of its tensors are graph outputs.

        The first step depends on ``self.outputs``:

        - If it equals ``constants.MARK_ALL``, every tensor is marked via
          ``onnx_util.mark_layerwise``.
        - If it is any other non-None value, it is passed to
          ``onnx_util.mark_outputs``.
        - If it is None, the model's outputs are left as loaded.

        Afterwards, if ``self.exclude_outputs`` is not None, those names are
        removed via ``onnx_util.unmark_outputs``. Exclusion therefore runs
        after marking, including after "mark all".

        Returns:
            onnx.ModelProto: The ONNX model with modified outputs.
        """
        onnx_model = self.load()

        requested = self.outputs
        if requested == constants.MARK_ALL:
            G_LOGGER.verbose("Marking all ONNX tensors as outputs")
            onnx_model = onnx_util.mark_layerwise(onnx_model)
        elif requested is not None:
            onnx_model = onnx_util.mark_outputs(onnx_model, requested)

        # Applied last so excluded names are unmarked even if marking above
        # just added them.
        excluded = self.exclude_outputs
        if excluded is not None:
            onnx_model = onnx_util.unmark_outputs(onnx_model, excluded)

        return onnx_model
コード例 #2
0
    def __call__(self):
        """
        Modifies an ONNX model.

        Processing order:

        1. Obtain the model from ``misc.try_call(self._model)``. Only the
           first element of the returned pair is used. Presumably this
           invokes ``self._model`` when it is a loader callable; confirm
           against ``misc.try_call``.
        2. If ``self.do_shape_inference`` is truthy, run
           ``onnx_util.infer_shapes``.
        3. Mark outputs. If ``self.outputs`` equals ``constants.MARK_ALL``,
           every tensor is marked; any other non-None value is passed to
           ``onnx_util.mark_outputs``.
        4. If ``self.exclude_outputs`` is not None, unmark those names.
        5. Pass the result through ``onnx_util.check_model`` and return what
           it returns.

        Returns:
            onnx.ModelProto: The modified ONNX model.
        """
        current, _ = misc.try_call(self._model)

        if self.do_shape_inference:
            current = onnx_util.infer_shapes(current)

        requested = self.outputs
        if requested == constants.MARK_ALL:
            G_LOGGER.verbose("Marking all ONNX tensors as outputs")
            current = onnx_util.mark_layerwise(current)
        elif requested is not None:
            current = onnx_util.mark_outputs(current, requested)

        # Exclusion is applied after marking so it also overrides "mark all".
        excluded = self.exclude_outputs
        if excluded is not None:
            current = onnx_util.unmark_outputs(current, excluded)

        return onnx_util.check_model(current)