Exemplo n.º 1
0
def toco_convert_protos(model_flags_str,
                        toco_flags_str,
                        input_data_str,
                        debug_info_str=None,
                        enable_mlir_converter=False):
    """Convert `input_data_str` according to model and toco parameters.

    Unless you know what you are doing consider using
    the more friendly `tf.compat.v1.lite.toco_convert`.

    Conversion runs in-process through `wrap_toco.wrapped_toco_convert` when
    MLIR conversion is requested or when `_toco_from_proto_bin` is falsy;
    otherwise it is delegated to `_run_toco_binary`.

    Args:
      model_flags_str: Serialized proto describing model properties, see
        `toco/model_flags.proto`.
      toco_flags_str: Serialized proto describing conversion properties, see
        `toco/toco_flags.proto`.
      input_data_str: Input data in serialized form (e.g. a graphdef is common,
        or it can be hlo text or proto)
      debug_info_str: Serialized `GraphDebugInfo` proto describing logging
        information. (default None)
      enable_mlir_converter: Enables MLIR-based conversion instead of the
        default TOCO conversion. (default False)

    Returns:
      Converted model in serialized form (e.g. a TFLITE model is common).

    Raises:
      ConverterError: When the in-process conversion fails, usually due to ops
        not being supported. It carries the original exception's message plus
        every entry returned by `_metrics_wrapper.retrieve_collected_errors()`,
        and the original exception is chained as its `__cause__`.
      RuntimeError: When conversion fails, an exception is raised with the
        error message embedded (per the original contract; presumably raised
        by the out-of-process `_run_toco_binary` path -- confirm there).
    """
    # Historically, TOCO conversion failures would trigger a crash, so we would
    # attempt to run the converter out-of-process. The MLIR conversion pipeline
    # surfaces errors instead, and can be safely run in-process.
    # NOTE(review): `_toco_from_proto_bin` presumably names the out-of-process
    # converter binary; when it is falsy there is nothing to delegate to.
    if enable_mlir_converter or not _toco_from_proto_bin:
        try:
            model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
                                                       toco_flags_str,
                                                       input_data_str,
                                                       debug_info_str,
                                                       enable_mlir_converter)
        except Exception as e:
            # Re-raise every failure as a ConverterError, attaching whatever
            # errors the metrics wrapper has collected.
            converter_error = ConverterError(str(e))
            for error_data in _metrics_wrapper.retrieve_collected_errors():
                converter_error.append_error(error_data)
            # Chain explicitly so the original failure is reported as the direct
            # cause instead of as an error raised while handling another one.
            raise converter_error from e
        return model_str

    return _run_toco_binary(model_flags_str, toco_flags_str, input_data_str,
                            debug_info_str)
Exemplo n.º 2
0
def convert(model_flags_str,
            conversion_flags_str,
            input_data_str,
            debug_info_str=None,
            enable_mlir_converter=True):
    """Converts `input_data_str` to a TFLite model.

    Conversion runs in-process through `wrap_toco.wrapped_toco_convert` when
    MLIR conversion is enabled or when `_deprecated_conversion_binary` is
    falsy; otherwise it is delegated to `_run_deprecated_conversion_binary`.

    Args:
      model_flags_str: Serialized proto describing model properties, see
        `model_flags.proto`.
      conversion_flags_str: Serialized proto describing conversion properties,
        see `toco/toco_flags.proto`.
      input_data_str: Input data in serialized form (e.g. a graphdef is common,
        or it can be hlo text or proto)
      debug_info_str: Serialized `GraphDebugInfo` proto describing logging
        information. (default None)
      enable_mlir_converter: Enables MLIR-based conversion. (default True)

    Returns:
      Converted model in serialized form (e.g. a TFLITE model is common).

    Raises:
      ConverterError: When the in-process conversion fails, usually due to ops
        not being supported. It carries the original exception's message plus
        every entry returned by `_metrics_wrapper.retrieve_collected_errors()`,
        and the original exception is chained as its `__cause__`.
      RuntimeError: When conversion fails, an exception is raised with the
        error message embedded (per the original contract; presumably raised
        by `_run_deprecated_conversion_binary` -- confirm there).
    """
    # Historically, deprecated conversion failures would trigger a crash, so we
    # attempt to run the converter out-of-process. The current MLIR conversion
    # pipeline surfaces errors instead, and can be safely run in-process.
    # NOTE(review): `_deprecated_conversion_binary` presumably names the
    # out-of-process converter; when it is falsy there is nothing to delegate to.
    if enable_mlir_converter or not _deprecated_conversion_binary:
        try:
            model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
                                                       conversion_flags_str,
                                                       input_data_str,
                                                       debug_info_str,
                                                       enable_mlir_converter)
        except Exception as e:
            # Re-raise every failure as a ConverterError, attaching whatever
            # errors the metrics wrapper has collected.
            converter_error = ConverterError(str(e))
            for error_data in _metrics_wrapper.retrieve_collected_errors():
                converter_error.append_error(error_data)
            # Chain explicitly so the original failure is reported as the direct
            # cause instead of as an error raised while handling another one.
            raise converter_error from e
        return model_str

    return _run_deprecated_conversion_binary(model_flags_str,
                                             conversion_flags_str,
                                             input_data_str, debug_info_str)
 def test_basic_retrieve_collected_errors_empty(self):
     """Asserts `metrics_wrapper.retrieve_collected_errors()` returns nothing.

     NOTE(review): assumes no error has been recorded earlier in this test
     process -- confirm test isolation if this ever becomes flaky.
     """
     self.assertEmpty(metrics_wrapper.retrieve_collected_errors())