import io
import logging

import onnx
import onnxoptimizer
import torch
from torch.onnx import OperatorExportTypes

logger = logging.getLogger(__name__)


def export_onnx_model(model, inputs):
    """
    Trace and export a model to ONNX format.

    Args:
        model (nn.Module): the model to trace and export.
        inputs (torch.Tensor): the model will be called by `model(*inputs)`.

    Returns:
        an onnx model
    """
    assert isinstance(model, torch.nn.Module)

    # Make sure all modules are in eval mode; ONNX may change the training
    # state of the module if the states are not consistent.
    def _check_eval(module):
        assert not module.training

    model.apply(_check_eval)

    logger.info("Beginning ONNX conversion")
    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())
    logger.info("Completed ONNX model conversion")

    # Apply ONNX optimization passes
    logger.info("Beginning ONNX optimization passes")
    all_passes = onnxoptimizer.get_available_passes()
    passes = [
        "extract_constant_to_initializer",
        "eliminate_unused_initializer",
        "fuse_bn_into_conv",
    ]
    assert all(p in all_passes for p in passes)
    onnx_model = onnxoptimizer.optimize(onnx_model, passes)
    logger.info("Completed ONNX optimization passes")
    return onnx_model
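
# A minimal usage sketch for export_onnx_model, not part of the exporter itself.
# The torchvision ResNet-18 model, the 1x3x224x224 input shape, and the output
# filename below are illustrative assumptions (torchvision >= 0.13 for the
# `weights=None` argument); any traceable nn.Module in eval mode works the same way.
if __name__ == "__main__":
    import torchvision

    example_model = torchvision.models.resnet18(weights=None)
    example_model.eval()  # export_onnx_model asserts every submodule is in eval mode

    # Inputs are passed as a tuple because the model is invoked via `model(*inputs)`.
    example_inputs = (torch.randn(1, 3, 224, 224),)

    exported = export_onnx_model(example_model, example_inputs)
    onnx.save(exported, "resnet18_optimized.onnx")  # hypothetical output path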
from typing import List

import mlrun


def optimize(
    context: mlrun.MLClientCtx,
    model_path: str,
    optimizations: List[str] = None,
    fixed_point: bool = False,
    optimized_model_name: str = None,
):
    """
    Optimize the given ONNX model.

    :param context:              The MLRun function execution context.
    :param model_path:           Path to the ONNX model object.
    :param optimizations:        List of possible optimizations. To see what optimizations are available,
                                 pass "help". If None, all of the optimizations will be used.
                                 Defaulted to None.
    :param fixed_point:          Optimize the weights using fixed point. Defaulted to False.
    :param optimized_model_name: The name of the optimized model. If None, the original model will be
                                 overridden. Defaulted to None.
    """
    # Import the model handler:
    import onnxoptimizer
    from mlrun.frameworks.onnx import ONNXModelHandler

    # Check if the available optimizations should be printed ("help" is passed):
    if optimizations == "help":
        available_passes = "\n* ".join(onnxoptimizer.get_available_passes())
        print(f"The available optimizations are:\n* {available_passes}")
        return

    # Create the model handler:
    model_handler = ONNXModelHandler(model_path=model_path, context=context)

    # Load the ONNX model:
    model_handler.load()

    # Optimize the model using the given configurations:
    model_handler.optimize(optimizations=optimizations, fixed_point=fixed_point)

    # Rename if needed:
    if optimized_model_name is not None:
        model_handler.set_model_name(model_name=optimized_model_name)

    # Log the optimized model:
    model_handler.log()
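
# A hedged sketch of running the optimize handler as an MLRun job. The function
# name, image, and model store path below are hypothetical placeholders, not
# values defined by this module; adjust them to your project. This assumes the
# standard mlrun.code_to_function / run workflow.
if __name__ == "__main__":
    fn = mlrun.code_to_function(
        name="onnx-optimize",  # hypothetical function name
        filename=__file__,
        kind="job",
        image="mlrun/mlrun",
        handler="optimize",
    )
    fn.run(
        params={
            "model_path": "store://models/my-project/my-onnx-model",  # hypothetical store URI
            "optimizations": None,  # None -> apply all available passes
            "fixed_point": False,
            "optimized_model_name": "my-onnx-model-optimized",
        },
        local=True,  # run in the local environment instead of on the cluster
    )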