import io
import logging

import onnx
import onnxoptimizer
import torch
from torch.onnx import OperatorExportTypes

logger = logging.getLogger(__name__)


def export_onnx_model(model, inputs):
    """
    Trace and export a model to ONNX format.
    Args:
        model (nn.Module): the model to export; must be in eval mode.
        inputs (tuple[torch.Tensor]): the model will be called by `model(*inputs)`
    Returns:
        onnx.ModelProto: the exported and optimized ONNX model
    """
    assert isinstance(model, torch.nn.Module)

    # Make sure all modules are in eval mode: ONNX export may change a module's
    # training state if the states are inconsistent.
    def _check_eval(module):
        assert not module.training

    model.apply(_check_eval)

    logger.info("Beginning ONNX file converting")
    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    logger.info("Completed convert of ONNX model")

    # Apply ONNX optimization passes
    logger.info("Beginning ONNX optimization passes")
    all_passes = onnxoptimizer.get_available_passes()
    passes = [
        "extract_constant_to_initializer", "eliminate_unused_initializer",
        "fuse_bn_into_conv"
    ]
    assert all(p in all_passes for p in passes), f"Unknown optimization pass in {passes}"
    onnx_model = onnxoptimizer.optimize(onnx_model, passes)
    logger.info("Completed ONNX model path optimization")
    return onnx_model
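
A minimal usage sketch (the toy module and `dummy_inputs` below are illustrative assumptions, not part of the original snippet):

if __name__ == "__main__":
    # Build a small eval-mode module; any traceable nn.Module works.
    toy_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    dummy_inputs = (torch.randn(1, 3, 32, 32),)  # unpacked via model(*inputs)
    exported = export_onnx_model(toy_model, dummy_inputs)
    onnx.save(exported, "toy_model.onnx")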
Example #2
from typing import List, Union

import mlrun


def optimize(
    context: mlrun.MLClientCtx,
    model_path: str,
    optimizations: Union[List[str], str] = None,
    fixed_point: bool = False,
    optimized_model_name: str = None,
):
    """
    Optimize the given ONNX model.

    :param context:              The MLRun function execution context.
    :param model_path:           Path to the ONNX model object.
    :param optimizations:        List of optimization passes to apply. To see which optimizations are
                                 available, pass "help". If None, all available optimizations will be
                                 applied. Defaults to None.
    :param fixed_point:          Whether to optimize the weights using fixed-point arithmetic. Defaults
                                 to False.
    :param optimized_model_name: The name of the optimized model. If None, the original model will be
                                 overridden. Defaults to None.
    """
    # Import the model handler:
    import onnxoptimizer
    from mlrun.frameworks.onnx import ONNXModelHandler

    # Print the available optimizations if "help" was passed:
    if optimizations == "help":
        available_passes = "\n* ".join(onnxoptimizer.get_available_passes())
        print(f"The available optimizations are:\n* {available_passes}")
        return

    # Create the model handler:
    model_handler = ONNXModelHandler(model_path=model_path, context=context)

    # Load the ONNX model:
    model_handler.load()

    # Optimize the model using the given configurations:
    model_handler.optimize(optimizations=optimizations, fixed_point=fixed_point)

    # Rename if needed:
    if optimized_model_name is not None:
        model_handler.set_model_name(model_name=optimized_model_name)

    # Log the optimized model:
    model_handler.log()
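
A minimal invocation sketch, assuming the function runs locally with a standalone MLRun context; the project name, model URI, and output name are hypothetical:

if __name__ == "__main__":
    # Standalone context for a local run; in a pipeline, MLRun injects the context.
    context = mlrun.get_or_create_ctx("optimize-onnx")
    optimize(
        context=context,
        model_path="store://models/onnx-demo/my-model:latest",  # hypothetical store URI
        optimizations=["eliminate_identity", "fuse_bn_into_conv"],
        optimized_model_name="my-model-optimized",
    )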