Example 1
0
    def _export_tflite(self,
                       tflite_filepath,
                       quantization_config=None,
                       with_metadata=True,
                       export_metadata_json_file=False):
        """Exports the retrained model as a TFLite model file.

    Args:
      tflite_filepath: Destination path for the TFLite model.
      quantization_config: Post-training quantization configuration, or None
        to export without quantization.
      with_metadata: If True, packs TFLite metadata (including the label map)
        into the exported model.
      export_metadata_json_file: If True, also writes the metadata as a JSON
        file in the same directory as the TFLite model. Only used when
        `with_metadata` is True.
    """
        self.model_spec.export_tflite(tflite_filepath, quantization_config)

        if not with_metadata:
            return

        tf.compat.v1.logging.info(
            'Label file is inside the TFLite model with metadata.')
        # The label map only needs to live long enough for the populator to
        # copy it into the model, so write it to a throwaway directory.
        with tempfile.TemporaryDirectory() as scratch_dir:
            labels_path = os.path.join(scratch_dir, 'labelmap.txt')
            self._export_labels(labels_path)
            info = _get_model_info(self.model_spec, quantization_config)
            out_dir = os.path.dirname(tflite_filepath)
            writer = metadata_writer.MetadataPopulatorForObjectDetector(
                tflite_filepath, out_dir, info, labels_path)
            writer.populate(export_metadata_json_file)
Example 2
0
    def _export_tflite(
            self,
            tflite_filepath: str,
            quantization_type: QuantizationType = QuantizationType.INT8,
            representative_data: Optional[
                object_detector_dataloader.DataLoader] = None,
            quantization_config: Optional[configs.QuantizationConfig] = None,
            with_metadata: bool = True,
            export_metadata_json_file: bool = False) -> None:
        """Converts the retrained model to tflite format and saves it.

    Args:
      tflite_filepath: File path to save tflite model.
      quantization_type: Enum, type of post-training quantization. Accepted
        values are `INT8`, `FP16`, `FP32`, `DYNAMIC`. `FP16` means float16
        quantization with 2x smaller model, optimized for GPU. `INT8` means
        full integer quantization with 4x smaller model, 3x+ speedup,
        optimized for Edge TPU. `DYNAMIC` means dynamic range quantization
        with 4x smaller model, 2x-3x speedup. `FP32` means exporting the
        float model without quantization. Please refer to
        https://www.tensorflow.org/lite/performance/post_training_quantization
        for more details about the different post-training quantization
        techniques.
      representative_data: Representative dataset for full integer
        quantization. Used when `quantization_type=INT8`.
      quantization_config: Configuration for post-training quantization.
        Mutually exclusive with `quantization_type`.
      with_metadata: Whether the output tflite model contains metadata.
      export_metadata_json_file: Whether to export metadata in a json file.
        If True, export the metadata in the same directory as the tflite
        model. Used only if `with_metadata` is True.

    Raises:
      ValueError: If both `quantization_type` and `quantization_config` are
        set, or if `quantization_type` is INT8 but `representative_data` is
        not provided.
    """
        # NOTE: `quantization_type` defaults to INT8, so callers who want to
        # drive quantization via `quantization_config` must explicitly pass
        # `quantization_type=None`.
        if quantization_type and quantization_config:
            raise ValueError(
                'At most one of the parameters `quantization_type` and '
                '`quantization_config` can be set.')
        if (quantization_type == QuantizationType.INT8 and
                representative_data is None):
            raise ValueError('`representative_data` must be set when '
                             '`quantization_type=QuantizationType.INT8`.')

        # Batch size 1 so each representative sample is fed individually
        # during INT8 calibration.
        ds, _, _ = self._get_dataset_and_steps(representative_data,
                                               batch_size=1,
                                               is_training=False)

        self.model_spec.export_tflite(tflite_filepath, quantization_type, ds,
                                      quantization_config)

        if with_metadata:
            with tempfile.TemporaryDirectory() as temp_dir:
                tf.compat.v1.logging.info(
                    'Label file is inside the TFLite model with metadata.')
                label_filepath = os.path.join(temp_dir, 'labelmap.txt')
                self._export_labels(label_filepath)
                model_info = _get_model_info(self.model_spec,
                                             quantization_type,
                                             quantization_config)
                export_dir = os.path.dirname(tflite_filepath)
                populator = metadata_writer.MetadataPopulatorForObjectDetector(
                    tflite_filepath, export_dir, model_info, label_filepath)
                populator.populate(export_metadata_json_file)
Example 3
0
    def _export_tflite(
            self,
            tflite_filepath: str,
            quantization_config: configs.QuantizationConfigType = 'default',
            with_metadata: bool = True,
            export_metadata_json_file: bool = False) -> None:
        """Writes the retrained model to disk in TFLite format.

    Args:
      tflite_filepath: Destination path for the TFLite model.
      quantization_config: Post-training quantization configuration. The
        sentinel string 'default' resolves to `self.model_spec`'s default
        configuration; None exports an unquantized float model.
      with_metadata: If True, packs TFLite metadata (including the label map)
        into the exported model.
      export_metadata_json_file: If True, also writes the metadata as a JSON
        file in the same directory as the TFLite model. Only used when
        `with_metadata` is True.
    """
        if quantization_config == 'default':
            quantization_config = self.model_spec.get_default_quantization_config(
                self.representative_data)

        self.model_spec.export_tflite(self.model, tflite_filepath,
                                      quantization_config)

        if not with_metadata:
            return

        tf.compat.v1.logging.info(
            'Label file is inside the TFLite model with metadata.')
        # The label map only needs to live long enough for the populator to
        # copy it into the model, so write it to a throwaway directory.
        with tempfile.TemporaryDirectory() as scratch_dir:
            labels_path = os.path.join(scratch_dir, 'labelmap.txt')
            self._export_labels(labels_path)
            info = _get_model_info(self.model_spec, quantization_config)
            out_dir = os.path.dirname(tflite_filepath)
            writer = metadata_writer.MetadataPopulatorForObjectDetector(
                tflite_filepath, out_dir, info, labels_path)
            writer.populate(export_metadata_json_file)