示例#1
0
def export_tensorflow(
    tokenizer: PreTrainedTokenizer,
    model: TFPreTrainedModel,
    config: OnnxConfig,
    opset: int,
    output: Path,
) -> Tuple[List[str], List[str]]:
    """
    Export a TensorFlow model to an ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used to generate the dummy inputs that define the input signature.
        model ([`TFPreTrainedModel`]):
            The model to export. Its `config` is mutated in place (`return_dict` is forced to `True` and any
            `config.values_override` entries are applied).
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Path the exported ONNX model is saved to (passed as-is to `onnx.save`).

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the ONNX configuration.
    """
    import tensorflow as tf

    import onnx
    import tf2onnx

    model.config.return_dict = True

    # Check if we need to override certain configuration items
    if config.values_override is not None:
        logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
        for override_config_key, override_config_value in config.values_override.items():
            logger.info(f"\t- {override_config_key} -> {override_config_value}")
            setattr(model.config, override_config_key, override_config_value)

    # Generate dummy inputs and work out which of them the model accepts.
    # NOTE(review): the "inputs match" flag is deliberately discarded here, so a mismatch does not abort the
    # TensorFlow export -- confirm this is intended before tightening it.
    model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
    _, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
    onnx_outputs = list(config.outputs.keys())

    # One TensorSpec per dummy input, named after the input key, forms the Keras input signature.
    input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
    onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
    onnx.save(onnx_model, output.as_posix())
    # NOTE(review): restore_ops() is called without a preceding patch_ops() in this function; kept as-is to
    # preserve behaviour -- confirm whether it is needed.
    config.restore_ops()

    return matched_inputs, onnx_outputs
示例#2
0
def export_pytorch(
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    config: OnnxConfig,
    opset: int,
    output: Path,
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch model to an ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used to generate the dummy inputs traced during export.
        model ([`PreTrainedModel`]):
            The model to export. It is switched to eval mode and its `config` is mutated in place
            (`return_dict` is forced to `True` and any `config.values_override` entries are applied).
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Path the exported ONNX model is written to (passed as `f` to `torch.onnx.export`).

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the ONNX configuration.

    Raises:
        TypeError: If `model` is not a [`PreTrainedModel`].
        ValueError: If the model inputs and the inputs declared by `config` do not match.
    """
    # Fail early and explicitly: previously a non-PyTorch model skipped the export and crashed on the final
    # `return` with an UnboundLocalError.
    if not issubclass(type(model), PreTrainedModel):
        raise TypeError(f"export_pytorch expects a PreTrainedModel instance, got: {type(model).__name__}")

    import torch
    from torch.onnx import export as onnx_export

    logger.info(f"Using framework PyTorch: {torch.__version__}")
    with torch.no_grad():
        model.config.return_dict = True
        model.eval()

        # Check if we need to override certain configuration items
        if config.values_override is not None:
            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
            for override_config_key, override_config_value in config.values_override.items():
                logger.info(f"\t- {override_config_key} -> {override_config_value}")
                setattr(model.config, override_config_key, override_config_value)

        # Ensure inputs match
        # TODO: Check when exporting QA we provide "is_pair=True"
        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
        onnx_outputs = list(config.outputs.keys())

        if not inputs_match:
            raise ValueError("Model and config inputs doesn't match")

        config.patch_ops()
        try:
            export_kwargs = {
                "f": output.as_posix(),
                "input_names": list(config.inputs.keys()),
                "output_names": onnx_outputs,
                "dynamic_axes": dict(chain(config.inputs.items(), config.outputs.items())),
                "do_constant_folding": True,
                "opset_version": opset,
            }

            # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
            # so they are only passed on older versions for backwards compatibility.
            if parse(torch.__version__) <= parse("1.10.99"):
                export_kwargs["use_external_data_format"] = config.use_external_data_format(model.num_parameters())
                export_kwargs["enable_onnx_checker"] = True

            # export can work with named args but the dict containing named args
            # has to be the last element of the args tuple.
            onnx_export(model, (model_inputs,), **export_kwargs)
        finally:
            # Always undo the op patching, even when the export fails.
            config.restore_ops()

    return matched_inputs, onnx_outputs
示例#3
0
def validate_model_outputs(
    config: OnnxConfig,
    tokenizer: PreTrainedTokenizer,
    reference_model: Union[PreTrainedModel, TFPreTrainedModel],
    onnx_model: Path,
    onnx_named_outputs: List[str],
    atol: float,
):
    """
    Check that an exported ONNX model reproduces the outputs of its reference model.

    Dummy inputs generated from `config` are fed to both `reference_model` and an ONNX Runtime session (CPU
    execution provider) loaded from `onnx_model`; output names, shapes and values are then compared.

    Args:
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration used to generate dummy inputs and to flatten nested inputs/outputs.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer passed to `config.generate_dummy_inputs`.
        reference_model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The original model whose outputs serve as the reference.
        onnx_model (`Path`):
            Path of the exported ONNX model to load.
        onnx_named_outputs (`List[str]`):
            Names of the ONNX outputs to fetch and compare.
        atol (`float`):
            Absolute tolerance used when comparing output values.

    Raises:
        ValueError: If a requested ONNX output name is missing from the reference outputs, or if an output's
            shape or values differ from the reference.
    """
    from onnxruntime import InferenceSession, SessionOptions

    logger.info("Validating ONNX model...")

    # The framework of the reference model decides both the dummy tensor type and how outputs become numpy.
    is_torch_reference = issubclass(type(reference_model), PreTrainedModel)

    # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
    # dynamic input shapes.
    dummy_framework = TensorType.PYTORCH if is_torch_reference else TensorType.TENSORFLOW
    reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=dummy_framework)

    # ONNX Runtime session
    session_options = SessionOptions()
    ort_session = InferenceSession(onnx_model.as_posix(), session_options, providers=["CPUExecutionProvider"])

    # Reference outputs, with nested collections (i.e. past_keys) flattened into a single-level dict
    reference_outputs = reference_model(**reference_model_inputs)
    flat_reference_outputs = {}
    for output_name, output_value in reference_outputs.items():
        # "past_key_values" is already taken by the ONNX inputs, so the matching ONNX output is named "present"
        key = "present" if output_name == "past_key_values" else output_name
        if not isinstance(output_value, (list, tuple)):
            flat_reference_outputs[key] = output_value
            continue
        flat_reference_outputs.update(config.flatten_output_collection_property(key, output_value))

    # ONNX inputs: same flattening, converted to numpy arrays
    onnx_inputs = {}
    for input_name, input_value in reference_model_inputs.items():
        if not isinstance(input_value, (list, tuple)):
            onnx_inputs[input_name] = input_value.numpy()
            continue
        flattened_inputs = config.flatten_output_collection_property(input_name, input_value)
        for tensor_name, pt_tensor in flattened_inputs.items():
            onnx_inputs[tensor_name] = pt_tensor.numpy()

    # Outputs computed by the ONNX model
    ort_results = session.run(onnx_named_outputs, onnx_inputs) if False else ort_session.run(
        onnx_named_outputs, onnx_inputs
    )

    # Every requested ONNX output must exist among the reference outputs
    ref_outputs_set = set(flat_reference_outputs.keys())
    onnx_outputs_set = set(onnx_named_outputs)
    unexpected_names = onnx_outputs_set.difference(ref_outputs_set)
    if unexpected_names:
        logger.info(
            f"\t-[x] ONNX model output names {onnx_outputs_set} do not match reference model {ref_outputs_set}"
        )
        raise ValueError(
            "Outputs doesn't match between reference model and ONNX exported model: "
            f"{unexpected_names}"
        )
    logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")

    # Compare shape and values output by output
    for name, ort_value in zip(onnx_named_outputs, ort_results):
        reference_tensor = flat_reference_outputs[name]
        if is_torch_reference:
            ref_value = reference_tensor.detach().numpy()
        else:
            ref_value = reference_tensor.numpy()
        logger.info(f'\t- Validating ONNX Model output "{name}":')

        # Shape
        if ort_value.shape != ref_value.shape:
            logger.info(f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}")
            raise ValueError(
                "Outputs shape doesn't match between reference model and ONNX exported model: "
                f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
            )
        logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")

        # Values
        if not np.allclose(ref_value, ort_value, atol=atol):
            logger.info(f"\t\t-[x] values not close enough (atol: {atol})")
            raise ValueError(
                "Outputs values doesn't match between reference model and ONNX exported model: "
                f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}"
            )
        logger.info(f"\t\t-[✓] all values close (atol: {atol})")
示例#4
0
def export(
    tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used to generate the dummy inputs traced during export.
        model ([`PreTrainedModel`]):
            The model to export. It is switched to eval mode and its `config` is mutated in place
            (`return_dict` is forced to `True` and any `config.values_override` entries are applied).
        config ([`OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Path the exported ONNX model is written to (passed as `f` to `torch.onnx.export`).

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the ONNX configuration.

    Raises:
        ImportError: If PyTorch is not installed.
        AssertionError: If the installed PyTorch version does not support dict inputs for ONNX export.
        ValueError: If the model inputs and the inputs declared by `config` do not match.
    """
    if not is_torch_available():
        raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch

    # Aliased so the import does not shadow this function's own name inside its body.
    from torch.onnx import export as onnx_export

    from ..file_utils import torch_version

    if not is_torch_onnx_dict_inputs_support_available():
        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")

    logger.info(f"Using framework PyTorch: {torch.__version__}")
    with torch.no_grad():
        model.config.return_dict = True
        model.eval()

        # Check if we need to override certain configuration items
        if config.values_override is not None:
            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
            for override_config_key, override_config_value in config.values_override.items():
                logger.info(f"\t- {override_config_key} -> {override_config_value}")
                setattr(model.config, override_config_key, override_config_value)

        # Ensure inputs match
        # TODO: Check when exporting QA we provide "is_pair=True"
        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
        onnx_outputs = list(config.outputs.keys())

        if not inputs_match:
            raise ValueError("Model and config inputs doesn't match")

        config.patch_ops()
        try:
            # export can work with named args, but the dict containing the named args has to be the last
            # element of the args tuple
            onnx_export(
                model,
                (model_inputs,),
                f=output.as_posix(),
                input_names=list(config.inputs.keys()),
                output_names=onnx_outputs,
                dynamic_axes=dict(chain(config.inputs.items(), config.outputs.items())),
                do_constant_folding=True,
                use_external_data_format=config.use_external_data_format(model.num_parameters()),
                enable_onnx_checker=True,
                opset_version=opset,
            )
        finally:
            # Always undo the op patching, even when the export fails.
            config.restore_ops()

    return matched_inputs, onnx_outputs