def run(self, args):
    """
    Emit a template script for defining or modifying a TensorRT network.

    If ``ModelArgs.model_file`` is set, the base loader is whatever
    ``TrtNetworkLoaderArgs.add_trt_network_loader`` adds to the script, and the
    generated callback takes ``builder, network, parser``. Otherwise the base
    loader is ``CreateNetwork()`` and the callback takes ``builder, network``.

    The generated ``load_network`` function is decorated with
    ``@func.extend(<base loader>)``. Its body is only a placeholder ``pass``
    with a TODO comment. The finished script is saved to ``args.output``.

    Args:
        args: Parsed command-line arguments; only ``args.output`` is read here.
    """
    gen_script = Script(
        summary="Defines or modifies a TensorRT Network using the Network API.",
        always_create_runners=False,
    )
    gen_script.add_import(imports=["func"], frm="polygraphy")
    gen_script.add_import(imports=["tensorrt as trt"])

    have_model = self.arg_groups[ModelArgs].model_file is not None
    if not have_model:
        # No model to parse: start from a freshly created network, so the
        # callback signature has no parser argument.
        gen_script.add_import(imports=["CreateNetwork"], frm="polygraphy.backend.trt")
        base_loader = safe("CreateNetwork()")
        callback_params = safe("builder, network")
    else:
        # A model was supplied: reuse the loader emitted by the network-loader
        # argument group, whose callback also exposes the parser.
        base_loader = self.arg_groups[TrtNetworkLoaderArgs].add_trt_network_loader(gen_script)
        callback_params = safe("builder, network, parser")

    emit = gen_script.append_suffix
    emit(safe("@func.extend({:})", inline(base_loader)))
    emit(safe("def load_network({:}):", inline(callback_params)))
    emit(safe("\tpass # TODO: Set up the network here. This function should not return anything."))

    gen_script.save(args.output)
def run(self, args):
    """
    Emit a template script that creates a TensorRT builder configuration.

    The base loader comes from ``TrtConfigArgs.add_trt_config_loader``. If that
    call returns a falsy value, a ``CreateConfig()`` loader is registered in
    the script under the name ``create_trt_config`` and used instead.

    The generated ``load_config(config)`` function is decorated with
    ``@func.extend(<base loader>)``. Its body is only a placeholder ``pass``
    with a TODO comment. The finished script is saved to ``args.output``.

    Args:
        args: Parsed command-line arguments; only ``args.output`` is read here.
    """
    gen_script = Script(summary="Creates a TensorRT Builder Configuration.", always_create_runners=False)
    gen_script.add_import(imports=["func"], frm="polygraphy")
    gen_script.add_import(imports=["tensorrt as trt"])

    config_loader = self.arg_groups[TrtConfigArgs].add_trt_config_loader(gen_script)
    if not config_loader:
        # The argument group emitted no loader; fall back to a default CreateConfig().
        gen_script.add_import(imports=["CreateConfig"], frm="polygraphy.backend.trt")
        config_loader = gen_script.add_loader(safe("CreateConfig()"), "create_trt_config")
    callback_params = safe("config")

    emit = gen_script.append_suffix
    emit(safe("@func.extend({:})", inline(config_loader)))
    emit(safe("def load_config({:}):", inline(callback_params)))
    emit(
        safe(
            "\tpass # TODO: Set up the builder configuration here. This function should not return anything."
        )
    )

    gen_script.save(args.output)