예제 #1
0
 def call_impl(self):
     """
     Parse the serialized ONNX model into a freshly created TensorRT network.

     Returns:
         (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                 A TensorRT network, as well as the builder used to create it, and the parser
                 used to populate it.
     """
     created = super().call_impl()
     # NOTE(review): FreeOnException presumably releases these objects if an exception
     # escapes the block — confirm in util.
     with util.FreeOnException(created) as (builder, network, parser):
         # `_model_bytes` may be a plain value or a callable; the first element of the
         # result is what gets handed to the parser.
         model_bytes = util.invoke_if_callable(self._model_bytes)[0]
         parsed_ok = parser.parse(model_bytes)
         trt_util.check_onnx_parser_errors(parser, parsed_ok)
         return builder, network, parser
예제 #2
0
    def __call__(self):
        """
        Parse the ONNX model into a TensorRT network and report the leading
        dimension of the model's first input.

        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser, object):
                    The builder, the network populated from the ONNX model, the parser
                    used to populate it, and ``shape[0]`` of the first entry in the
                    model's input metadata.
        """
        from polygraphy.backend.onnx import util as onnx_util

        builder, network, parser = super().__call__()
        onnx_model, _ = misc.try_call(self.onnx_loader)
        # Only the shape of the first input is needed; its dtype is deliberately discarded.
        _, shape = list(
            onnx_util.get_input_metadata(onnx_model.graph).values())[0]

        parser.parse(onnx_model.SerializeToString())
        trt_util.check_onnx_parser_errors(parser)

        return builder, network, parser, shape[0]
예제 #3
0
    def call_impl(self):
        """
        Build a TensorRT network from the ONNX model and also report the leading
        dimension of the model's first input.

        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser, object):
                    The builder, the populated network, the parser, and ``shape[0]``
                    of the first entry in the model's input metadata.
        """
        from polygraphy.backend.onnx import util as onnx_util

        resources = super().call_impl()
        with util.FreeOnException(resources) as (builder, network, parser):
            onnx_model, _ = util.invoke_if_callable(self.onnx_loader)

            input_meta = onnx_util.get_input_metadata(onnx_model.graph)
            # Each metadata value unpacks into (dtype, shape); only the shape is kept.
            _dtype, first_shape = list(input_meta.values())[0]

            serialized = onnx_model.SerializeToString()
            parse_ok = parser.parse(serialized)
            trt_util.check_onnx_parser_errors(parser, parse_ok)

            return builder, network, parser, first_shape[0]
예제 #4
0
    def __call__(self):
        """
        Parse serialized ONNX model bytes into a TensorRT network.

        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                    A TensorRT network, as well as the builder used to create it, and the parser
                    used to populate it.
        """
        created = super().__call__()
        builder, network, parser = created

        # `_model_bytes` may be a value or a callable; the first element of the
        # try_call result is used as the model bytes.
        model_bytes = misc.try_call(self._model_bytes)[0]
        parser.parse(model_bytes)
        trt_util.check_onnx_parser_errors(parser)

        return builder, network, parser
예제 #5
0
        def __call__(self):
            """
            Parse an ONNX model stored on disk into a TensorRT network.

            Returns:
                (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                        A TensorRT network, as well as the builder used to create it, and the parser
                        used to populate it.
            """
            builder, network, parser = super().__call__()
            onnx_path = misc.try_call(self.path)[0]

            # parse_from_file (rather than parse) lets the ONNX parser remember where the
            # model file lives, which it may need to locate any external weights.
            parser.parse_from_file(onnx_path)
            trt_util.check_onnx_parser_errors(parser)

            return builder, network, parser
예제 #6
0
 def call_impl(self):
     """
     Parse an ONNX model file into a TensorRT network.

     On TensorRT 7.1 and newer, the file path is handed directly to the parser via
     ``parse_from_file``. On older versions, the file's bytes are read and parsing is
     delegated to ``network_from_onnx_bytes``.

     Returns:
         (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                 A TensorRT network, as well as the builder used to create it, and the parser
                 used to populate it.
     """
     path = util.invoke_if_callable(self.path)[0]
     use_parse_from_file = mod.version(trt.__version__) >= mod.version("7.1")

     if not use_parse_from_file:
         # Fallback for older TensorRT: load the raw bytes and reuse the bytes-based loader.
         from polygraphy.backend.common import bytes_from_path
         return network_from_onnx_bytes(bytes_from_path(path), self.explicit_precision)

     with util.FreeOnException(super().call_impl()) as (builder, network, parser):
         # parse_from_file (instead of parse) lets the ONNX parser keep track of the file's
         # location, which it may need when parsing external weights.
         parsed_ok = parser.parse_from_file(path)
         trt_util.check_onnx_parser_errors(parser, parsed_ok)
         return builder, network, parser