Exemplo n.º 1
0
    def run(self, args):
        """Generate a template script that defines or modifies a TensorRT network.

        The emitted script imports ``func`` from polygraphy and ``tensorrt as trt``,
        then ends with a ``load_network`` stub decorated with ``func.extend`` around
        a network loader:

        * if a model file was supplied, the loader is emitted by the
          ``TrtNetworkLoaderArgs`` group and the stub takes
          ``(builder, network, parser)``;
        * otherwise ``CreateNetwork()`` is imported and used inline, and the stub
          takes ``(builder, network)``.

        Args:
            args: Parsed command-line arguments; ``args.output`` is the path
                handed to ``Script.save``.
        """
        template = Script(
            summary="Defines or modifies a TensorRT Network using the Network API.",
            always_create_runners=False,
        )
        template.add_import(imports=["func"], frm="polygraphy")
        template.add_import(imports=["tensorrt as trt"])

        if self.arg_groups[ModelArgs].model_file is None:
            # No model file: fall back to polygraphy's CreateNetwork() loader.
            template.add_import(imports=["CreateNetwork"],
                                frm="polygraphy.backend.trt")
            loader = safe("CreateNetwork()")
            stub_params = safe("builder, network")
        else:
            # A model file was given: let the loader arg group emit the loader;
            # the generated stub's signature then also includes `parser`.
            network_args = self.arg_groups[TrtNetworkLoaderArgs]
            loader = network_args.add_trt_network_loader(template)
            stub_params = safe("builder, network, parser")

        # Decorator line and function header: each is a format template filled
        # with an inlined value, appended in order.
        for fmt, value in (
            ("@func.extend({:})", loader),
            ("def load_network({:}):", stub_params),
        ):
            template.append_suffix(safe(fmt, inline(value)))
        template.append_suffix(
            safe(
                "\tpass # TODO: Set up the network here. This function should not return anything."
            ))

        template.save(args.output)
Exemplo n.º 2
0
    def run(self, args):
        """Generate a template script that creates a TensorRT builder configuration.

        The emitted script imports ``func`` from polygraphy and ``tensorrt as trt``,
        then ends with a ``load_config(config)`` stub decorated with
        ``func.extend`` around a config loader. The loader comes from the
        ``TrtConfigArgs`` group; if that returns a falsy value, a
        ``CreateConfig()`` loader is registered under the name
        ``create_trt_config`` and used instead.

        Args:
            args: Parsed command-line arguments; ``args.output`` is the path
                handed to ``Script.save``.
        """
        template = Script(summary="Creates a TensorRT Builder Configuration.",
                          always_create_runners=False)
        template.add_import(imports=["func"], frm="polygraphy")
        template.add_import(imports=["tensorrt as trt"])

        config_args = self.arg_groups[TrtConfigArgs]
        loader = config_args.add_trt_config_loader(template)
        if not loader:
            # The arg group produced no loader name: register a default one.
            template.add_import(imports=["CreateConfig"],
                                frm="polygraphy.backend.trt")
            loader = template.add_loader(safe("CreateConfig()"),
                                         "create_trt_config")

        # Decorator line and function header: each is a format template filled
        # with an inlined value, appended in order.
        for fmt, value in (
            ("@func.extend({:})", loader),
            ("def load_config({:}):", safe("config")),
        ):
            template.append_suffix(safe(fmt, inline(value)))
        template.append_suffix(
            safe(
                "\tpass # TODO: Set up the builder configuration here. This function should not return anything."
            ))

        template.save(args.output)