Esempio n. 1
0
    def create_for_inference(
            cls,
            model_buffer: bytearray,
            input_norm_mean: List[float],
            input_norm_std: List[float],
            label_file_paths: List[str],
            score_calibration_md: Optional[
                metadata_info.ScoreCalibrationMd] = None):
        """Builds the minimum metadata needed for TFLite Support inference.

        Only the fields that TFLite Support features (e.g. the Task library
        and the Codegen tool / Android Studio ML Binding) require are taken as
        parameters; every other metadata field keeps its default. To customize
        the remaining fields, call `create_from_metadata_info` directly.

        The input is described as an RGB image tensor whose tensor type is
        read from the first input tensor of `model_buffer`.

        Args:
          model_buffer: valid buffer of the model file.
          input_norm_mean: the mean value used in the input tensor
            normalization [1].
          input_norm_std: the std value used in the input tensor
            normalization [1].
          label_file_paths: paths to the label files [2] in the category
            tensor. Pass in an empty list if the model does not have any
            label file.
          score_calibration_md: information of the score calibration
            operation [3] in the classification tensor. Optional if the model
            does not use score calibration.
          [1]:
            https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
          [2]:
            https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L108
          [3]:
            https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434

        Returns:
          A MetadataWriter object, as produced by
          `cls.create_from_metadata_info`.
        """
        # The input tensor type is taken from the model itself; only the first
        # input tensor is considered.
        image_tensor_type = writer_utils.get_input_tensor_types(
            model_buffer)[0]
        image_md = metadata_info.InputImageTensorMd(
            name=_INPUT_NAME,
            description=_INPUT_DESCRIPTION,
            norm_mean=input_norm_mean,
            norm_std=input_norm_std,
            color_space_type=_metadata_fb.ColorSpaceType.RGB,
            tensor_type=image_tensor_type)

        # One label-file entry per supplied path; an empty input list yields
        # an empty list of label files.
        label_files = [
            metadata_info.LabelFileMd(file_path=path)
            for path in label_file_paths
        ]
        # NOTE: `_OUTPUT_CATRGORY_NAME` is spelled exactly as it is referenced
        # in this file; its definition lives outside this method.
        category_md = metadata_info.CategoryTensorMd(
            name=_OUTPUT_CATRGORY_NAME,
            description=_OUTPUT_CATEGORY_DESCRIPTION,
            label_files=label_files)

        score_md = metadata_info.ClassificationTensorMd(
            name=_OUTPUT_SCORE_NAME,
            description=_OUTPUT_SCORE_DESCRIPTION,
            score_calibration_md=score_calibration_md)

        return cls.create_from_metadata_info(
            model_buffer,
            input_md=image_md,
            output_category_md=category_md,
            output_score_md=score_md)
Esempio n. 2
0
 def test_get_input_tensor_types(self):
     """Checks the input tensor types reported for the test model.

     Loads the model file named by `_MODEL_NAME` and asserts that
     `writer_utils.get_input_tensor_types` returns a single-element list
     holding `_EXPECTED_INPUT_TYPES`.
     """
     model_buffer = test_utils.load_file(_MODEL_NAME)
     actual_types = writer_utils.get_input_tensor_types(
         model_buffer=model_buffer)
     expected_types = [_EXPECTED_INPUT_TYPES]
     self.assertEqual(actual_types, expected_types)
 def _input_tensor_type(self, idx):
     """Returns the type of the input tensor at position `idx`.

     The types are read from `self._model_buffer` through
     `writer_utils.get_input_tensor_types`, and `idx` indexes into the
     returned sequence.
     """
     input_types = writer_utils.get_input_tensor_types(self._model_buffer)
     return input_types[idx]