# Example #1
# 0
  def _export_tflite(self,
                     tflite_filepath,
                     label_filepath,
                     quantized=False,
                     quantization_steps=None,
                     representative_data=None,
                     inference_input_type=tf.float32,
                     inference_output_type=tf.float32,
                     with_metadata=False,
                     export_metadata_json_file=False):
    """Converts the retrained model to tflite format and saves it.

    Args:
      tflite_filepath: File path to save tflite model.
      label_filepath: File path to save labels.
      quantized: boolean, if True, save quantized model.
      quantization_steps: Number of post-training quantization calibration steps
        to run. Used only if `quantized` is True.
      representative_data: Representative data used for post-training
        quantization. Used only if `quantized` is True.
      inference_input_type: Target data type of real-number input arrays.
        Defaults to tf.float32. Must be one of
        `{tf.float32, tf.uint8, tf.int8}`.
      inference_output_type: Target data type of real-number output arrays.
        Defaults to tf.float32. Must be one of
        `{tf.float32, tf.uint8, tf.int8}`.
      with_metadata: Whether the output tflite model contains metadata.
      export_metadata_json_file: Whether to export metadata in a json file.
        If True, the json is written next to the tflite model. Used only if
        `with_metadata` is True.
    """
    # Delegate the actual conversion/saving to the base class.
    super(ImageClassifier,
          self)._export_tflite(tflite_filepath, quantized, quantization_steps,
                               representative_data, inference_input_type,
                               inference_output_type)

    # Metadata population is optional; bail out early when not requested
    # or when prerequisites are missing.
    if not with_metadata:
      return

    if not metadata.TFLITE_SUPPORT_TOOLS_INSTALLED:
      tf.compat.v1.logging.warning('Needs to install tflite-support package.')
      return

    if label_filepath is None:
      tf.compat.v1.logging.warning(
          'Label filepath is needed when exporting TFLite with metadata.')
      return

    # Attach model + label metadata to the already-written tflite file.
    model_info = metadata.get_model_info(self.model_spec, quantized=quantized)
    populator = metadata.MetadataPopulatorForImageClassifier(
        tflite_filepath, model_info, label_filepath)
    populator.populate()

    if export_metadata_json_file:
      metadata.export_metadata_json_file(tflite_filepath)
    def export(self,
               tflite_filename,
               label_filename,
               quantized=False,
               quantization_steps=None,
               representative_data=None,
               with_metadata=False,
               export_metadata_json_file=False):
        """Converts the retrained model based on `model_export_format`.

    Args:
      tflite_filename: File name to save tflite model.
      label_filename: File name to save labels.
      quantized: boolean, if True, save quantized model.
      quantization_steps: Number of post-training quantization calibration steps
        to run. Used only if `quantized` is True.
      representative_data: Representative data used for post-training
        quantization. Used only if `quantized` is True.
      with_metadata: Whether the output tflite model contains metadata.
      export_metadata_json_file: Whether to export metadata in a json file. If
        True, export the metadata in the same directory as tflite model. Used
        only if `with_metadata` is True.

    Raises:
      ValueError: If `model_export_format` is not TFLITE.
    """
        # Only TFLite export is implemented; fail fast for anything else.
        if self.model_export_format != mef.ModelExportFormat.TFLITE:
            raise ValueError(
                'Model Export Format %s is not supported currently.' %
                self.model_export_format)

        # Delegate the full export — including optional metadata population —
        # to `_export_tflite`, which already guards against a missing
        # tflite-support installation and a missing label file. Previously the
        # metadata logic was duplicated here without the label-file guard.
        self._export_tflite(
            tflite_filename,
            label_filename,
            quantized,
            quantization_steps,
            representative_data,
            with_metadata=with_metadata,
            export_metadata_json_file=export_metadata_json_file)