Code Example #1
import numpy as np
from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork,
                                    NetworkFromOnnxPath, SaveEngine, TrtRunner)


def main():
    # We can compose multiple lazy loaders together to get the desired conversion.
    # In this case, we want ONNX -> TensorRT Network -> TensorRT engine (w/ fp16).
    #
    # NOTE: `build_engine` is a *callable* that returns an engine, not the engine itself.
    #   To get the engine directly, you can use the immediately evaluated functional API.
    #   See examples/api/06_immediate_eval_api for details.
    # Note that `config` is an optional argument.
    build_engine = EngineFromNetwork(
        NetworkFromOnnxPath("identity.onnx"), config=CreateConfig(fp16=True))

    # To reuse the engine elsewhere, we can serialize and save it to a file.
    # The `SaveEngine` lazy loader will return the TensorRT engine when called,
    # which allows us to chain it together with other loaders.
    build_engine = SaveEngine(build_engine, path="identity.engine")

    # Once our loader is ready, inference is simply a matter of constructing a runner,
    # activating it with a context manager (i.e. `with TrtRunner(...)`) and calling `infer()`.
    #
    # NOTE: You can use the activate() function instead of a context manager, but you will need to make sure to
    # deactivate() to avoid a memory leak. For that reason, a context manager is the safer option.
    with TrtRunner(build_engine) as runner:
        inp_data = np.ones(shape=(1, 1, 2, 2), dtype=np.float32)

        # NOTE: The runner owns the output buffers and is free to reuse them between `infer()` calls.
        # Thus, if you want to store results from multiple inferences, you should use `copy.deepcopy()`.
        outputs = runner.infer(feed_dict={"x": inp_data})

        assert np.array_equal(outputs["y"],
                              inp_data)  # It's an identity model!

        print("Inference succeeded!")
Code Example #2
    def check_network(self, suffix):
        """
        Checks whether the provided network is accurate compared to golden values.

        Returns:
            OrderedDict[str, OutputCompareResult]:
                    A mapping of output names to an object describing whether they matched, and what the
                    required tolerances were.
        """
        from polygraphy.comparator import Comparator, CompareFunc, DataLoader
        from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, ModifyNetwork, SaveEngine

        severity = G_LOGGER.severity if self.args.show_output else G_LOGGER.CRITICAL
        with G_LOGGER.verbosity(severity=severity):
            data_loader = tool_util.get_data_loader(self.args)

            self.args.strict_types = True  # HACK: Override strict types so things actually run in the right precision.
            config_loader = tool_util.get_trt_config_loader(self.args, data_loader)
            config = config_loader(self.builder, self.network)

            suffix = "-{:}-{:}".format(suffix, self.precision)
            engine_path = misc.insert_suffix(self.args.save_engine, suffix)

            self.builder, self.network, self.parser = ModifyNetwork(
                (self.builder, self.network, self.parser),
                outputs=self.args.trt_outputs)()

            engine_loader = SaveEngine(
                EngineFromNetwork((self.builder, self.network, self.parser), config),
                path=engine_path)

            runners = [TrtRunner(engine_loader)]

            results = Comparator.run(runners, data_loader=data_loader)
            if self.args.validate:
                Comparator.validate(results)
            results.update(self.golden)

            compare_func = CompareFunc.basic_compare_func(
                atol=self.args.atol,
                rtol=self.args.rtol,
                check_shapes=not self.args.no_shape_check)
            accuracy_result = Comparator.compare_accuracy(
                results, compare_func=compare_func)

        # First iteration of the first runner pair.
        tolerances = list(accuracy_result.values())[0][0]
        for name, req_tol in tolerances.items():
            if bool(req_tol):
                G_LOGGER.success(
                    "PASSED | Output: {:} | Required Tolerances: {:}".format(name, req_tol))
            else:
                G_LOGGER.error(
                    "FAILED | Output: {:} | Required Tolerances: {:}".format(name, req_tol))
        return accuracy_result
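
For context, here is how a caller might consume the returned mapping. This is a minimal sketch, assuming Comparator.compare_accuracy returns Polygraphy's AccuracyResult (an OrderedDict that evaluates to True only when every comparison passed); `tool` and the "fp16" suffix are hypothetical:

# `tool` is a hypothetical instance of the class defining check_network().
accuracy_result = tool.check_network("fp16")

# AccuracyResult's truth value reflects whether every output of every
# iteration matched within the required tolerances.
if bool(accuracy_result):
    print("check_network: all outputs matched")
else:
    print("check_network: at least one output mismatched")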
Code Example #3
File: test_loader.py  Project: stjordanis/TensorRT
def test_save_engine(self, identity_network):
    with tempfile.NamedTemporaryFile() as outpath:
        engine_loader = SaveEngine(EngineFromNetwork(identity_network),
                                   path=outpath.name)
        with engine_loader():
            assert is_file_non_empty(outpath.name)
Code Example #4
File: test_loader.py  Project: leo-XUKANG/TensorRT-1
def test_save_engine(self, load_identity):
    with tempfile.NamedTemporaryFile() as outpath:
        engine_loader = SaveEngine(EngineFromNetwork(load_identity), path=outpath.name)
        with engine_loader() as engine:
            check_file_non_empty(outpath.name)
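
Both tests above only assert that SaveEngine wrote a non-empty file. To round-trip the result, the serialized engine can be deserialized and run; a minimal sketch, assuming the identity.engine path and "x" input from Code Example #1:

import numpy as np
from polygraphy.backend.common import BytesFromPath
from polygraphy.backend.trt import EngineFromBytes, TrtRunner

# Deserialize a previously saved engine, then run inference with it.
load_engine = EngineFromBytes(BytesFromPath("identity.engine"))
with TrtRunner(load_engine) as runner:
    outputs = runner.infer(
        feed_dict={"x": np.ones(shape=(1, 1, 2, 2), dtype=np.float32)})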
Code Example #5
# NOTE: The original snippet is truncated above this call; `SessionFromOnnx` is an
# assumed reconstruction, since `build_onnxrt_session` feeds OnnxrtRunner below.
build_onnxrt_session = SessionFromOnnx(
    '/work/gitlab/tensorrt-cookbook-in-chinese/08-Tool/Polygraphy/runExample/model.onnx'
)
parse_network_from_onnx = NetworkFromOnnxPath(
    '/work/gitlab/tensorrt-cookbook-in-chinese/08-Tool/Polygraphy/runExample/model.onnx'
)
profiles = [
    Profile().add('tensor-0',
                  min=[1, 1, 28, 28],
                  opt=[4, 1, 28, 28],
                  max=[16, 1, 28, 28])
]
create_trt_config = CreateTrtConfig(max_workspace_size=1000000000,
                                    profiles=profiles)
build_engine = EngineFromNetwork(parse_network_from_onnx,
                                 config=create_trt_config)
save_engine = SaveEngine(build_engine, path='model-FP32.plan')

# Runners
runners = [
    OnnxrtRunner(build_onnxrt_session),
    TrtRunner(save_engine),
]

# Runner Execution
results = Comparator.run(runners, data_loader=data_loader)

success = True
# Accuracy Comparison
compare_func = CompareFunc.simple(rtol={'': 0.001}, atol={'': 0.001})
success &= bool(Comparator.compare_accuracy(results,
                                            compare_func=compare_func))