def add_trt_network_loader(self, script):
    model_file = self.model_args.model_file
    model_type = self.model_args.model_type
    outputs = args_util.get_outputs_for_script(script, self.outputs)

    if model_type == "trt-network-script":
        script.add_import(imports=["InvokeFromScript"], frm="polygraphy.backend.common")
        loader_str = make_invocable("InvokeFromScript", model_file, name=self.trt_network_func_name)
        loader_name = script.add_loader(loader_str, "load_network")
    # When loading from ONNX, we need to disable custom outputs since TRT requires dtypes on outputs,
    # which our marking function doesn't guarantee.
    elif self.onnx_loader_args is not None and self.onnx_loader_args.should_use_onnx_loader(disable_custom_outputs=True):
        script.add_import(imports=["NetworkFromOnnxBytes"], frm="polygraphy.backend.trt")
        onnx_loader = self.onnx_loader_args.add_serialized_onnx_loader(script, disable_custom_outputs=True)
        loader_str = make_invocable(
            "NetworkFromOnnxBytes",
            self.trt_plugin_args.wrap_if_plugins(script, onnx_loader),
            explicit_precision=self.explicit_precision,
        )
        loader_name = script.add_loader(loader_str, "parse_network_from_onnx")
    else:
        script.add_import(imports=["NetworkFromOnnxPath"], frm="polygraphy.backend.trt")
        loader_str = make_invocable(
            "NetworkFromOnnxPath",
            self.trt_plugin_args.wrap_if_plugins(script, model_file),
            explicit_precision=self.explicit_precision,
        )
        loader_name = script.add_loader(loader_str, "parse_network_from_onnx")

    MODIFY_NETWORK = "ModifyNetworkOutputs"
    modify_network_str = make_invocable(MODIFY_NETWORK, loader_name, outputs=outputs, exclude_outputs=self.exclude_outputs)
    if modify_network_str != make_invocable(MODIFY_NETWORK, loader_name):
        script.add_import(imports=[MODIFY_NETWORK], frm="polygraphy.backend.trt")
        loader_name = script.add_loader(modify_network_str, "modify_network")

    return loader_name

def add_trt_serialized_engine_loader(self, script):
    assert self.model_args is not None, "ModelArgs is required for engine deserialization!"

    script.add_import(imports=["EngineFromBytes"], frm="polygraphy.backend.trt")
    script.add_import(imports=["BytesFromPath"], frm="polygraphy.backend.common")

    load_engine = script.add_loader(make_invocable("BytesFromPath", self.model_args.model_file), "load_engine_bytes")
    return script.add_loader(
        make_invocable("EngineFromBytes", self.trt_plugin_args.wrap_if_plugins(script, load_engine)),
        "deserialize_engine",
    )

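# A minimal sketch (not taken from the source) of the script text that add_trt_serialized_engine_loader()
# assembles: read the serialized engine bytes, then deserialize them. The path "model.engine" is
# hypothetical, plugin wrapping is omitted, and the variable names are illustrative since
# Script.add_loader() may suffix them to avoid collisions.
def _example_deserialize_engine_script():
    from polygraphy.backend.common import BytesFromPath
    from polygraphy.backend.trt import EngineFromBytes

    # Both loaders are lazy callables; nothing is read or deserialized until they are invoked.
    load_engine_bytes = BytesFromPath("model.engine")
    deserialize_engine = EngineFromBytes(load_engine_bytes)
    return deserialize_engine
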
def add_to_script(self, script): script.add_import(imports=["GsFromOnnx"], frm="polygraphy.backend.onnx") script.add_import(imports=["PluginRefRunner"], frm="polygraphy.backend.pluginref") onnx_name = self.onnx_loader_args.add_onnx_loader(script) loader_name = script.add_loader( make_invocable("GsFromOnnx", onnx_name), "pluginref") script.add_runner(make_invocable("PluginRefRunner", loader_name))
def add_to_script(self, script, loader_name): if self.do_shape_inference: script.add_import(imports=["InferShapes"], frm="polygraphy.backend.onnx") loader_name = script.add_loader( make_invocable("InferShapes", loader_name), "infer_shapes") return loader_name
def add_to_script(self, script): script.add_import(imports=["TfRunner"], frm="polygraphy.backend.tf") graph_name = self.tf_loader_args.add_to_script(script) config_name = self.tf_config_args.add_to_script(script) script.add_import(imports=["SessionFromGraph"], frm="polygraphy.backend.tf") loader_name = script.add_loader( make_invocable("SessionFromGraph", graph_name, config=config_name), "build_tf_session") script.add_runner( make_invocable("TfRunner", loader_name, timeline_path=self.timeline_path))
def _get_modify_onnx_loader(self, script, loader_name, disable_custom_outputs=None):
    if disable_custom_outputs:
        outputs = None
        exclude_outputs = None
    else:
        outputs = args_util.get_outputs_for_script(script, self.outputs)
        exclude_outputs = self.exclude_outputs

    if outputs or exclude_outputs:
        script.add_import(imports=["ModifyOutputs as ModifyOnnxOutputs"], frm="polygraphy.backend.onnx")
        loader_name = script.add_loader(
            make_invocable("ModifyOnnxOutputs", loader_name, outputs=outputs, exclude_outputs=exclude_outputs),
            "modify_outputs",
        )

    if self.onnx_shape_inference_args is not None:
        loader_name = self.onnx_shape_inference_args.add_to_script(script, loader_name)

    return loader_name

def add_onnx_loader(self, script, disable_custom_outputs=None, suffix=None):
    model_type = self.model_args.model_type

    if model_type.is_onnx():
        loader_name = self.model_args.model_file
        if self.onnx_shape_inference_args is not None:
            loader_name = self.onnx_shape_inference_args.add_to_script(script, loader_name)

        if loader_name == self.model_args.model_file:  # Shape inference loader isn't being used, have to load.
            script.add_import(imports=["OnnxFromPath"], frm="polygraphy.backend.onnx")
            loader_str = make_invocable(
                "OnnxFromPath", self.model_args.model_file, external_data_dir=self.load_external_data
            )
            loader_name = script.add_loader(loader_str, "load_onnx", suffix=suffix)
    elif model_type.is_tf():
        if self.tf2onnx_loader_args is None:
            G_LOGGER.critical("Could not load: {:}. Is it an ONNX model?".format(self.model_args.model_file))
        loader_name = self.tf2onnx_loader_args.add_to_script(script)
    else:
        G_LOGGER.critical("Model type: {:} cannot be converted to ONNX.".format(model_type))

    loader_name = self._get_modify_onnx_loader(script, loader_name, disable_custom_outputs=disable_custom_outputs)

    if self.onnx_save_args is not None:
        loader_name = self.onnx_save_args.add_save_onnx(script, loader_name)

    return loader_name

def add_to_script(self, script): script.add_import(imports=["Comparator"], frm="polygraphy.comparator") RESULTS_VAR_NAME = inline(safe("results")) comparator_run = make_invocable( "Comparator.run", script.get_runners(), warm_up=self.warm_up, data_loader=self.data_loader_args.add_to_script(script), use_subprocess=self.use_subprocess, save_inputs_path=self.save_inputs) script.append_suffix( safe("\n# Runner Execution\n{results} = {:}", comparator_run, results=RESULTS_VAR_NAME)) if self.save_results: G_LOGGER.verbose("Will save runner results to: {:}".format( self.save_results)) script.add_import(imports=["util"], frm="polygraphy") script.append_suffix( safe("\n# Save results\n{results}.save({:})", self.save_results, results=RESULTS_VAR_NAME)) return RESULTS_VAR_NAME
def add_serialized_onnx_loader(self, script, disable_custom_outputs=None):
    script.add_import(imports=["BytesFromOnnx"], frm="polygraphy.backend.onnx")

    onnx_loader = self.add_onnx_loader(script, disable_custom_outputs=disable_custom_outputs)
    return script.add_loader(make_invocable("BytesFromOnnx", onnx_loader), "serialize_onnx")

def add_to_script(self, script): script.add_import(imports=["OnnxrtRunner"], frm="polygraphy.backend.onnxrt") if self.onnx_loader_args.should_use_onnx_loader(): onnx_name = self.onnx_loader_args.add_serialized_onnx_loader( script) else: onnx_name = self.model_args.model_file script.add_import(imports=["SessionFromOnnx"], frm="polygraphy.backend.onnxrt") loader_name = script.add_loader( make_invocable("SessionFromOnnx", onnx_name), "build_onnxrt_session") script.add_runner(make_invocable("OnnxrtRunner", loader_name))
def add_to_script(self, script, loader_name): if self.do_shape_inference: script.add_import(imports=["InferShapes"], frm="polygraphy.backend.onnx") external_data_dir = self.onnx_loader_args.load_external_data if self.onnx_loader_args is not None else None loader_name = script.add_loader( make_invocable("InferShapes", loader_name, external_data_dir=external_data_dir), "infer_shapes" ) return loader_name
def add_to_script(self, script): script.add_import(imports=["TrtRunner"], frm="polygraphy.backend.trt") if self.model_args.model_type == "engine": loader_name = self.trt_engine_loader_args.add_trt_serialized_engine_loader(script) else: loader_name = self.trt_engine_loader_args.add_trt_build_engine_loader(script) script.add_runner(make_invocable("TrtRunner", loader_name))
def add_save_onnx(self, script, loader_name):
    if self.path is None:
        return loader_name

    script.add_import(imports=["SaveOnnx"], frm="polygraphy.backend.onnx")
    loader_name = script.add_loader(
        make_invocable("SaveOnnx", loader_name, path=self.path, external_data_path=self.save_external_data),
        "save_onnx",
    )

    # Need to run shape inference again after processing the graph since it may have changed.
    if self.onnx_shape_inference_args is not None:
        loader_name = self.onnx_shape_inference_args.add_to_script(script, loader_name)

    return loader_name

def _add_to_script(self, script, user_input_metadata_str=None):
    needs_invoke = False
    using_random_data = False

    if self.data_loader_script:
        script.add_import(imports=["mod"], frm="polygraphy")
        data_loader = make_invocable("mod.import_from_script", self.data_loader_script, name=self.data_loader_func_name)
        needs_invoke = True
    elif self.load_inputs:
        script.add_import(imports=["load_json"], frm="polygraphy.json")
        data_loader = safe(
            "[]\nfor input_data_path in {load_inputs}:"
            "\n\t{data_loader}.extend(load_json(input_data_path, description='input data'))",
            load_inputs=self.load_inputs,
            data_loader=Script.DATA_LOADER_NAME,
        )
    else:
        using_random_data = True
        if user_input_metadata_str is None and self.model_args is not None and self.model_args.input_shapes:
            user_input_metadata_str = self.model_args.input_shapes

        if user_input_metadata_str:
            script.add_import(imports=["TensorMetadata"], frm="polygraphy.common")

        data_loader = make_invocable_if_nondefault(
            "DataLoader",
            seed=self.seed,
            iterations=self.iterations,
            input_metadata=user_input_metadata_str,
            int_range=self.int_range,
            float_range=self.float_range,
            val_range=self.val_range,
        )
        if data_loader:
            script.add_import(imports=["DataLoader"], frm="polygraphy.comparator")

    if using_random_data != self.is_using_random_data():
        G_LOGGER.internal_error("is_using_random_data() reported a false positive!")

    return script.set_data_loader(data_loader), needs_invoke

def add_trt_build_engine_loader(self, script, network_name=None):
    if network_name:
        network_loader_name = network_name
    else:
        assert self.trt_network_loader_args is not None, "TrtNetworkLoaderArgs is required for engine building!"
        network_loader_name = self.trt_network_loader_args.add_trt_network_loader(script)

    assert self.trt_config_args is not None, "TrtConfigArgs is required for engine building!"

    script.add_import(imports=["EngineFromNetwork"], frm="polygraphy.backend.trt")
    config_loader_name = self.trt_config_args.add_trt_config_loader(script)
    loader_str = make_invocable(
        "EngineFromNetwork",
        self.trt_plugin_args.wrap_if_plugins(script, network_loader_name),
        config=config_loader_name,
        save_timing_cache=self.trt_config_args.timing_cache,
    )
    loader_name = script.add_loader(loader_str, "build_engine")

    if self.trt_engine_save_args is not None:
        loader_name = self.trt_engine_save_args.add_save_engine(script, loader_name)
    return loader_name

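# A hedged sketch (not generated verbatim) of the engine-building script that
# add_trt_build_engine_loader() produces when the network comes from an ONNX file: parse the
# network, create a builder config, build the engine, and optionally save it. "model.onnx",
# "model.engine", and the fp16 flag are illustrative choices only.
def _example_build_engine_script():
    from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, NetworkFromOnnxPath, SaveEngine

    parse_network_from_onnx = NetworkFromOnnxPath("model.onnx")
    create_trt_config = CreateConfig(fp16=True)
    build_engine = EngineFromNetwork(parse_network_from_onnx, config=create_trt_config)
    save_engine = SaveEngine(build_engine, path="model.engine")
    return save_engine
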
def add_to_script(self, script, suffix=None):
    G_LOGGER.verbose(
        "Attempting to load as a TensorFlow model, using TF2ONNX to convert to ONNX. "
        "If this is not correct, please specify --model-type",
        mode=LogMode.ONCE,
    )
    script.add_import(imports=["OnnxFromTfGraph"], frm="polygraphy.backend.onnx")
    loader_str = make_invocable(
        "OnnxFromTfGraph",
        self.tf_loader_args.add_to_script(script, disable_custom_outputs=True, suffix=suffix),
        opset=self.opset,
        fold_constant=self.fold_constant,
    )
    loader_name = script.add_loader(loader_str, "export_onnx_from_tf", suffix=suffix)
    return loader_name

def add_to_script(self, script, user_input_metadata_str=None):
    """
    Adds a DataLoader to the script.

    Args:
        user_input_metadata_str (str(TensorMetadata)):
                The name of a variable containing TensorMetadata.
                This will control the shape and data type of the generated data.
    """
    if self.data_loader_script:
        script.add_import(imports=["invoke_from_script"], frm="polygraphy.backend.common")
        data_loader = make_invocable("invoke_from_script", self.data_loader_script, name=self.data_loader_func_name)
    elif self.load_inputs:
        script.add_import(imports=["load_json"], frm="polygraphy.json")
        data_loader = safe(
            "[]\nfor input_data_path in {load_inputs}:"
            "\n\t{data_loader}.extend(load_json(input_data_path, description='input data'))",
            load_inputs=self.load_inputs,
            data_loader=Script.DATA_LOADER_NAME,
        )
    else:
        if user_input_metadata_str is None and self.model_args is not None and self.model_args.input_shapes:
            user_input_metadata_str = self.model_args.input_shapes

        if user_input_metadata_str:
            script.add_import(imports=["TensorMetadata"], frm="polygraphy.common")

        data_loader = make_invocable_if_nondefault(
            "DataLoader",
            seed=self.seed,
            iterations=self.iterations,
            input_metadata=user_input_metadata_str,
            int_range=self.int_range,
            float_range=self.float_range,
            val_range=self.val_range,
        )
        if data_loader:
            script.add_import(imports=["DataLoader"], frm="polygraphy.comparator")

    return script.set_data_loader(data_loader)

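# A hedged sketch of the synthetic data loader the generated script ends up with when non-default
# options are supplied; the seed and value range below are illustrative, not defaults taken from
# the source.
def _example_data_loader():
    from polygraphy.comparator import DataLoader

    return DataLoader(seed=1, val_range=(0.0, 1.0))
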
def add_data_loader(self, script, *args, **kwargs):
    """
    Adds a DataLoader to the script.

    Args:
        user_input_metadata_str (str(TensorMetadata)):
                The name of a variable containing TensorMetadata.
                This will control the shape and data type of the generated data.

    Returns:
        str: The data loader, as a string. This may either be the variable name,
             or an invocation of the data loader function.
    """
    data_loader, needs_invoke = self._add_to_script(script, *args, **kwargs)
    # Data loaders imported from user scripts are functions, so the generated script must invoke them.
    if needs_invoke:
        data_loader = make_invocable(data_loader)
    return data_loader

def test_invoke_none_args(self):
    # A positional None is preserved, but keyword arguments that are None are dropped.
    assert make_invocable("Dummy", None).unwrap() == "Dummy(None)"
    assert make_invocable("Dummy", x=None).unwrap() == "Dummy()"

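# A hedged usage sketch of make_invocable(), assuming it is imported from polygraphy.tools.script
# as in the surrounding code: it renders a constructor call as script text, dropping keyword
# arguments that are None (as the test above checks). The "Dummy" name and arguments are illustrative.
def _example_make_invocable_usage():
    from polygraphy.tools.script import make_invocable

    print(make_invocable("Dummy", 1, x=2).unwrap())  # Renders roughly as: Dummy(1, x=2)
    print(make_invocable("Dummy", x=None).unwrap())  # Renders as: Dummy()
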
def add_to_script(self, script): script.add_import(imports=["TrtLegacyRunner"], frm="polygraphy.backend.trt_legacy") G_LOGGER.warning("Legacy TensorRT runner only supports implicit batch TensorFlow/UFF, ONNX, and Caffe models") load_engine = self.model_args.model_file if self.model_args.model_type == "engine" else None loader_name = None if self.model_args.model_type == "onnx": script.add_import(imports=["ParseNetworkFromOnnxLegacy"], frm="polygraphy.backend.trt_legacy") onnx_loader = self.onnx_loader_args.add_onnx_loader(script, disable_custom_outputs=True) loader_name = script.add_loader( make_invocable("ParseNetworkFromOnnxLegacy", onnx_loader), "parse_network_from_onnx_legacy" ) elif self.model_args.model_type == "caffe": script.add_import(imports=["LoadNetworkFromCaffe"], frm="polygraphy.backend.trt_legacy") loader_name = script.add_loader( make_invocable( "LoadNetworkFromCaffe", self.model_args.model_file, self.caffe_model, self.trt_outputs, self.batch_size, ), "parse_network_from_caffe", ) elif load_engine is None: script.add_import(imports=["LoadNetworkFromUff"], frm="polygraphy.backend.trt_legacy") if self.model_args.model_type == "uff": script.add_import(imports=["LoadUffFile"], frm="polygraphy.backend.trt_legacy") shapes = {name: shape for name, (_, shape) in self.model_args.input_shapes.items()} loader_name = script.add_loader( make_invocable( "LoadUffFile", self.model_args.model_file, util.default(shapes, {}), self.trt_outputs ), "load_uff_file", ) else: script.add_import(imports=["ConvertToUff"], frm="polygraphy.backend.trt_legacy") loader_name = script.add_loader( make_invocable( "ConvertToUff", self.tf_loader_args.add_to_script(script), save_uff=self.save_uff, preprocessor=self.preprocessor, ), "convert_to_uff", ) loader_name = script.add_loader( make_invocable("LoadNetworkFromUff", loader_name, uff_order=self.uff_order), "uff_network_loader" ) runner_str = make_invocable( "TrtLegacyRunner", network_loader=loader_name, max_workspace_size=self.trt_config_args.workspace, max_batch_size=self.batch_size, fp16=self.trt_config_args.fp16, tf32=self.trt_config_args.tf32, load_engine=load_engine, save_engine=self.trt_engine_save_args.path, layerwise=self.trt_outputs == constants.MARK_ALL, plugins=self.trt_engine_loader_args.plugins, ) script.add_runner(runner_str)
def add_to_script(self, script, results_name):
    script.add_import(imports=["Comparator"], frm="polygraphy.comparator")

    if self.load_results:
        script.add_import(imports=["util"], frm="polygraphy")
        script.add_import(imports=["RunResults"], frm="polygraphy.comparator")
        script.append_suffix(
            safe(
                "\n# Load results\nfor load_output in {:}:\n\t{results}.extend(RunResults.load(load_output))",
                self.load_results,
                results=results_name,
            )
        )

    if self.top_k is not None:
        script.add_import(imports=["PostprocessFunc"], frm="polygraphy.comparator")
        script.append_suffix(
            safe(
                "\n# Postprocessing - Apply Top-{top_k}\n"
                "{results} = Comparator.postprocess({results}, PostprocessFunc.topk_func(k={top_k}))",
                top_k=self.top_k,
                results=results_name,
            )
        )

    SUCCESS_VAR_NAME = inline(safe("success"))
    script.append_suffix(safe("\n{success} = True", success=SUCCESS_VAR_NAME))

    # Only do comparisons if there's actually something to compare.
    if len(self.runners) > 1 or self.load_results:
        script.append_suffix(safe("# Accuracy Comparison"))

        compare_func_str = make_invocable_if_nondefault(
            "CompareFunc.basic_compare_func",
            rtol=self.rtol,
            atol=self.atol,
            check_shapes=False if self.no_shape_check else None,
            fail_fast=self.fail_fast,
            check_error_stat=self.check_error_stat,
        )
        compare_func = None
        if compare_func_str:
            script.add_import(imports=["CompareFunc"], frm="polygraphy.comparator")
            compare_func = inline(safe("compare_func"))
            script.append_suffix(safe("{:} = {:}", compare_func, compare_func_str))

        compare_accuracy = make_invocable(
            "Comparator.compare_accuracy", results_name, compare_func=compare_func, fail_fast=self.fail_fast
        )
        script.append_suffix(safe("{success} &= bool({:})\n", compare_accuracy, success=SUCCESS_VAR_NAME))

    if self.validate:
        script.append_suffix(
            safe(
                "# Validation\n{success} &= Comparator.validate({results}, check_inf=True, check_nan=True)\n",
                success=SUCCESS_VAR_NAME,
                results=results_name,
            )
        )

    return SUCCESS_VAR_NAME

def add_trt_config_loader(self, script):
    profiles = []
    for (min_shape, opt_shape, max_shape) in self.profile_dicts:
        profile_str = "Profile()"
        for name in min_shape.keys():
            profile_str += safe(
                ".add({:}, min={:}, opt={:}, max={:})", name, min_shape[name], opt_shape[name], max_shape[name]
            ).unwrap()
        profiles.append(profile_str)

    if profiles:
        script.add_import(imports=["Profile"], frm="polygraphy.backend.trt")
        profiles = safe("[\n\t{:}\n]", inline(safe(",\n\t".join(profiles))))
        profile_name = script.add_loader(profiles, "profiles")
    else:
        profile_name = None

    calibrator = None
    if any(arg is not None for arg in [self.calibration_cache, self.calibration_base_class]) and not self.int8:
        G_LOGGER.warning(
            "Some int8 calibrator options were set, but int8 precision is not enabled. "
            "Calibration options will be ignored. Please set --int8 to enable calibration."
        )

    if self.int8 and self.data_loader_args is not None:  # We cannot do calibration if there is no data loader.
        script.add_import(imports=["Calibrator"], frm="polygraphy.backend.trt")
        script.add_import(imports=["DataLoader"], frm="polygraphy.comparator")
        data_loader_name = self.data_loader_args.add_data_loader(script)
        if self.calibration_base_class:
            script.add_import(imports=["tensorrt as trt"])

        calibrator = make_invocable(
            "Calibrator",
            data_loader=data_loader_name if data_loader_name else inline(safe("DataLoader()")),
            cache=self.calibration_cache,
            BaseClass=self.calibration_base_class,
            quantile=self.quantile,
            regression_cutoff=self.regression_cutoff,
        )

    algo_selector = None
    if self.load_tactics is not None:
        script.add_import(imports=["TacticReplayer"], frm="polygraphy.backend.trt")
        algo_selector = make_invocable("TacticReplayer", replay=self.load_tactics)
    elif self.save_tactics is not None:
        script.add_import(imports=["TacticRecorder"], frm="polygraphy.backend.trt")
        algo_selector = make_invocable("TacticRecorder", record=self.save_tactics)

    if self.tactic_sources is not None:
        script.add_import(imports=["tensorrt as trt"])

    if self.trt_config_script is not None:
        script.add_import(imports=["InvokeFromScript"], frm="polygraphy.backend.common")
        config_loader_str = make_invocable("InvokeFromScript", self.trt_config_script, name=self.trt_config_func_name)
    else:
        config_loader_str = make_invocable_if_nondefault(
            "CreateTrtConfig",
            max_workspace_size=self.workspace,
            tf32=self.tf32,
            fp16=self.fp16,
            int8=self.int8,
            strict_types=self.strict_types,
            restricted=self.restricted,
            profiles=profile_name,
            calibrator=calibrator,
            load_timing_cache=(
                self.timing_cache if self.timing_cache and os.path.exists(self.timing_cache) else None
            ),
            algorithm_selector=algo_selector,
            sparse_weights=self.sparse_weights,
            tactic_sources=self.tactic_sources,
        )
        if config_loader_str is not None:
            script.add_import(imports=["CreateConfig as CreateTrtConfig"], frm="polygraphy.backend.trt")

    if config_loader_str is not None:
        config_loader_name = script.add_loader(config_loader_str, "create_trt_config")
    else:
        config_loader_name = None

    return config_loader_name

def wrap_if_plugins(self, script, loader_name):
    if self.plugins:
        script.add_import(imports=["LoadPlugins"], frm="polygraphy.backend.trt")
        loader_str = make_invocable("LoadPlugins", plugins=self.plugins, obj=loader_name)
        loader_name = script.add_loader(loader_str, "load_plugins")
    return loader_name

def add_save_engine(self, script, loader_name):
    if self.path is None:
        return loader_name

    script.add_import(imports=["SaveEngine"], frm="polygraphy.backend.trt")
    return script.add_loader(make_invocable("SaveEngine", loader_name, path=self.path), "save_engine")

def add_to_script(self, script): script.add_import(imports=["TrtLegacyRunner"], frm="polygraphy.backend.trt_legacy") G_LOGGER.warning("Legacy TensorRT runner only supports implicit batch TensorFlow/UFF, ONNX, and Caffe models") load_engine = self.model_args.model_file if self.model_args.model_type == "engine" else None loader_name = None if self.model_args.model_type == "onnx": script.add_import(imports=["ParseNetworkFromOnnxLegacy"], frm="polygraphy.backend.trt_legacy") onnx_loader = self.onnx_loader_args.add_onnx_loader(script, disable_custom_outputs=True) loader_name = script.add_loader( make_invocable("ParseNetworkFromOnnxLegacy", onnx_loader), "parse_network_from_onnx_legacy" ) elif self.model_args.model_type == "caffe": script.add_import(imports=["LoadNetworkFromCaffe"], frm="polygraphy.backend.trt_legacy") loader_name = script.add_loader( make_invocable( "LoadNetworkFromCaffe", self.model_args.model_file, self.caffe_model, self.trt_outputs, self.batch_size, ), "parse_network_from_caffe", ) elif load_engine is None: script.add_import(imports=["LoadNetworkFromUff"], frm="polygraphy.backend.trt_legacy") if self.model_args.model_type == "uff": script.add_import(imports=["LoadUffFile"], frm="polygraphy.backend.trt_legacy") shapes = {name: shape for name, (_, shape) in self.model_args.input_shapes.items()} loader_name = script.add_loader( make_invocable( "LoadUffFile", self.model_args.model_file, util.default(shapes, {}), self.trt_outputs ), "load_uff_file", ) else: script.add_import(imports=["ConvertToUff"], frm="polygraphy.backend.trt_legacy") loader_name = script.add_loader( make_invocable( "ConvertToUff", self.tf_loader_args.add_to_script(script), save_uff=self.save_uff, preprocessor=self.preprocessor, ), "convert_to_uff", ) loader_name = script.add_loader( make_invocable("LoadNetworkFromUff", loader_name, uff_order=self.uff_order), "uff_network_loader" ) calibrator = None if ( self.trt_config_args.int8 and self.data_loader_args is not None ): # We cannot do calibration if there is no data loader. script.add_import(imports=["Calibrator"], frm="polygraphy.backend.trt") script.add_import(imports=["DataLoader"], frm="polygraphy.comparator") data_loader_name = self.data_loader_args.add_data_loader(script) if self.calibration_base_class: script.add_import(imports=["tensorrt as trt"]) calibrator = make_invocable( "Calibrator", data_loader=data_loader_name if data_loader_name else inline(safe("DataLoader()")), cache=self.calibration_cache, BaseClass=self.calibration_base_class, quantile=self.quantile, regression_cutoff=self.regression_cutoff, ) runner_str = make_invocable( "TrtLegacyRunner", network_loader=loader_name, max_workspace_size=self.trt_config_args.workspace, max_batch_size=self.batch_size, fp16=self.trt_config_args.fp16, tf32=self.trt_config_args.tf32, load_engine=load_engine, save_engine=self.trt_engine_save_args.path, layerwise=self.trt_outputs == constants.MARK_ALL, plugins=self.trt_engine_loader_args.plugins, int8=self.trt_config_args.int8, calibrator=calibrator, use_dla=self.use_dla, allow_gpu_fallback=self.allow_gpu_fallback, ) script.add_runner(runner_str)
def add_to_script(self, script, disable_custom_outputs=None, suffix=None):
    if disable_custom_outputs:
        outputs = None
    else:
        outputs = args_util.get_outputs_for_script(script, self.outputs)

    model_file = self.model_args.model_file
    model_type = self.model_args.model_type

    if model_type == "ckpt":
        G_LOGGER.verbose(
            "Loading a TensorFlow checkpoint. Please ensure you are not using the --use-subprocess flag",
            mode=LogMode.ONCE,
        )
        script.add_import(imports=["GraphFromCkpt"], frm="polygraphy.backend.tf")
        loader_id = "load_ckpt"
        loader_str = make_invocable("GraphFromCkpt", model_file, self.ckpt)
    elif model_type == "keras":
        script.add_import(imports=["GraphFromKeras"], frm="polygraphy.backend.tf")
        loader_id = "load_keras"
        loader_str = make_invocable("GraphFromKeras", model_file)
    elif model_type == "frozen":
        script.add_import(imports=["GraphFromFrozen"], frm="polygraphy.backend.tf")
        G_LOGGER.verbose(
            "Attempting to load as a frozen graph. If this is not correct, please specify --model-type",
            mode=LogMode.ONCE,
        )
        loader_id = "load_frozen"
        loader_str = make_invocable("GraphFromFrozen", model_file)
    else:
        G_LOGGER.critical("Model type: {:} cannot be imported with TensorFlow.".format(model_type))

    loader_name = script.add_loader(loader_str, loader_id, suffix=suffix)

    if self.freeze_graph:
        script.add_import(imports=["OptimizeGraph"], frm="polygraphy.backend.tf")
        loader_name = script.add_loader(make_invocable("OptimizeGraph", loader_name), "optimize_graph", suffix=suffix)

    if self.tftrt:
        script.add_import(imports=["UseTfTrt"], frm="polygraphy.backend.tf")
        loader_str = make_invocable(
            "UseTfTrt",
            loader_name,
            max_workspace_size=self.trt_config_args.workspace,
            fp16=self.trt_config_args.fp16,
            int8=self.trt_config_args.int8,
            max_batch_size=self.trt_legacy_args.batch_size,
            is_dynamic_op=self.dynamic_op,
            minimum_segment_size=self.minimum_segment_size,
        )
        loader_name = script.add_loader(loader_str, "use_tftrt", suffix=suffix)

    MODIFY_TF = "ModifyGraphOutputs"
    modify_tf_str = make_invocable(MODIFY_TF, loader_name, outputs=outputs)
    if modify_tf_str != make_invocable(MODIFY_TF, loader_name):
        script.add_import(imports=[MODIFY_TF], frm="polygraphy.backend.tf")
        loader_name = script.add_loader(modify_tf_str, "modify_tf")

    engine_dir = None
    if self.tftrt:
        engine_dir = self.trt_engine_save_args.path

    WRITE_TF = "SaveGraph"
    write_tf_str = make_invocable(
        WRITE_TF, loader_name, path=self.save_pb, tensorboard_dir=self.save_tensorboard, engine_dir=engine_dir
    )
    if write_tf_str != make_invocable(WRITE_TF, loader_name):
        script.add_import(imports=[WRITE_TF], frm="polygraphy.backend.tf")
        loader_name = script.add_loader(write_tf_str, "save_tf")

    return loader_name