    @classmethod
    def create_for_inference(cls,
                             model_buffer: bytearray,
                             sample_rate: int,
                             channels: int,
                             label_file_paths: List[str],
                             score_calibration_md: Optional[
                                 metadata_info.ScoreCalibrationMd] = None):
        """Creates mandatory metadata for TFLite Support inference.

        The parameters required in this method are mandatory when using TFLite
        Support features, such as Task library and Codegen tool (Android Studio ML
        Binding). Other metadata fields will be set to default. If other fields need
        to be filled, use the method `create_from_metadata_info` to edit them.

        Args:
          model_buffer: valid buffer of the model file.
          sample_rate: the sample rate in Hz when the audio was captured.
          channels: the channel count of the audio.
          label_file_paths: paths to the label files [1] in the classification
            tensor. Pass in an empty list if the model does not have any label file.
          score_calibration_md: information of the score calibration operation [2]
            in the classification tensor. Optional if the model does not use score
            calibration.
          [1]:
            https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95
          [2]:
            https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434

        Returns:
          A MetadataWriter object.
        """
        # To make the Task Library work properly, sample_rate and channels need
        # to be positive.
        if sample_rate <= 0:
            raise ValueError(
                "sample_rate should be positive, but got {}.".format(
                    sample_rate))

        if channels <= 0:
            raise ValueError(
                "channels should be positive, but got {}.".format(channels))

        input_md = metadata_info.InputAudioTensorMd(_INPUT_NAME,
                                                    _INPUT_DESCRIPTION,
                                                    sample_rate, channels)

        output_md = metadata_info.ClassificationTensorMd(
            name=_OUTPUT_NAME,
            description=_OUTPUT_DESCRIPTION,
            label_files=[
                metadata_info.LabelFileMd(file_path=file_path)
                for file_path in label_file_paths
            ],
            tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0],
            score_calibration_md=score_calibration_md)

        return cls.create_from_metadata_info(model_buffer,
                                             input_md=input_md,
                                             output_md=output_md)
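
# A minimal usage sketch for create_for_inference above. It assumes the public
# `tflite_support.metadata_writers` package layout; the model, label-file and
# export paths are placeholders.
from tflite_support.metadata_writers import audio_classifier
from tflite_support.metadata_writers import writer_utils

writer = audio_classifier.MetadataWriter.create_for_inference(
    writer_utils.load_file("model.tflite"),
    sample_rate=16000,
    channels=1,
    label_file_paths=["labels.txt"])

# populate() returns the model buffer with the metadata attached; write it out.
writer_utils.save_file(writer.populate(), "model_with_metadata.tflite")
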
    def test_create_metadata_by_default_should_succeed(self):
        audio_tensor_md = metadata_info.InputAudioTensorMd()

        metadata_json = _metadata.convert_to_json(
            _create_dummy_model_metadata_with_tensor(
                audio_tensor_md.create_metadata()))
        expected_json = test_utils.load_file(
            self._EXPECTED_TENSOR_DEFAULT_JSON, "r")
        self.assertEqual(metadata_json, expected_json)

    def test_create_metadata_should_succeed(self):
        audio_tensor_md = metadata_info.InputAudioTensorMd(
            self._NAME, self._DESCRIPTION, self._SAMPLE_RATE, self._CHANNELS)

        metadata_json = _metadata.convert_to_json(
            _create_dummy_model_metadata_with_tensor(
                audio_tensor_md.create_metadata()))
        expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r")
        self.assertEqual(metadata_json, expected_json)

    def test_create_metadata_fail_with_negative_channels(self):
        negative_channels = -1
        with self.assertRaises(ValueError) as error:
            tensor_md = metadata_info.InputAudioTensorMd(
                channels=negative_channels)
            tensor_md.create_metadata()

        self.assertEqual(
            "channels should be non-negative, but got {}.".format(
                negative_channels), str(error.exception))

    def test_create_metadata_fail_with_negative_sample_rate(self):
        negative_sample_rate = -1
        with self.assertRaises(ValueError) as error:
            tensor_md = metadata_info.InputAudioTensorMd(
                sample_rate=negative_sample_rate)
            tensor_md.create_metadata()

        self.assertEqual(
            f"sample_rate should be non-negative, but got {negative_sample_rate}.",
            str(error.exception))

    @classmethod
    def create_from_metadata_info_for_multihead(
        cls,
        model_buffer: bytearray,
        general_md: Optional[metadata_info.GeneralMd] = None,
        input_md: Optional[metadata_info.InputAudioTensorMd] = None,
        output_md_list: Optional[List[
            metadata_info.ClassificationTensorMd]] = None):
        """Creates a MetadataWriter instance for multihead models.

        Args:
          model_buffer: valid buffer of the model file.
          general_md: general information about the model. If not specified,
            default general metadata will be generated.
          input_md: input audio tensor information. If not specified, default
            input metadata will be generated.
          output_md_list: information of each output tensor head. If not
            specified, default metadata will be generated for each output
            tensor. If `tensor_name` in each `ClassificationTensorMd` instance
            is not specified, elements in `output_md_list` need to have a
            one-to-one mapping with the output tensors [1] in the TFLite model.
          [1]:
            https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L605-L612

        Returns:
          A MetadataWriter object.
        """

        if general_md is None:
            general_md = metadata_info.GeneralMd(
                name=_MODEL_NAME, description=_MODEL_DESCRIPTION)

        if input_md is None:
            input_md = metadata_info.InputAudioTensorMd(
                name=_INPUT_NAME, description=_INPUT_DESCRIPTION)

        associated_files = []
        for md in output_md_list or []:
            associated_files.extend(
                [file.file_path for file in md.associated_files or []])

        return super().create_from_metadata_info(
            model_buffer=model_buffer,
            general_md=general_md,
            input_md=[input_md],
            output_md=output_md_list,
            associated_files=associated_files)
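
# A hedged sketch of the export step that usually follows (the test below only
# checks the generated metadata JSON). `writer_utils` is assumed to be the
# helper module from the same `tflite_support.metadata_writers` package, and
# the export path is a placeholder.
from tflite_support.metadata_writers import writer_utils


def export_with_metadata(writer, export_path="multihead_with_metadata.tflite"):
    """Writes the metadata-populated model buffer of `writer` to `export_path`.

    `writer` is a MetadataWriter such as the one returned by
    create_from_metadata_info_for_multihead above.
    """
    writer_utils.save_file(writer.populate(), export_path)
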
  def test_create_from_metadata_info_succeeds_for_multihead(self):
    calibration_file1 = test_utils.create_calibration_file(
        self.get_temp_dir(), "score_cali_1.txt")
    calibration_file2 = test_utils.create_calibration_file(
        self.get_temp_dir(), "score_cali_2.txt")

    general_md = metadata_info.GeneralMd(name="AudioClassifier")
    input_md = metadata_info.InputAudioTensorMd(
        name="audio_clip", sample_rate=_SAMPLE_RATE, channels=_CHANNELS)
    # The output tensors in the model are: Identity, Identity_1
    # Create metadata in a different order to test if MetadataWriter can correct
    # it.
    output_head_md_1 = metadata_info.ClassificationTensorMd(
        name="head1",
        label_files=[
            metadata_info.LabelFileMd("labels_en_1.txt"),
            metadata_info.LabelFileMd("labels_cn_1.txt")
        ],
        score_calibration_md=metadata_info.ScoreCalibrationMd(
            _metadata_fb.ScoreTransformationType.LOG,
            _DEFAULT_SCORE_CALIBRATION_VALUE, calibration_file1),
        tensor_name="Identity_1")
    output_head_md_2 = metadata_info.ClassificationTensorMd(
        name="head2",
        label_files=[
            metadata_info.LabelFileMd("labels_en_2.txt"),
            metadata_info.LabelFileMd("labels_cn_2.txt")
        ],
        score_calibration_md=metadata_info.ScoreCalibrationMd(
            _metadata_fb.ScoreTransformationType.LOG,
            _DEFAULT_SCORE_CALIBRATION_VALUE, calibration_file2),
        tensor_name="Identity")

    writer = (
        audio_classifier.MetadataWriter.create_from_metadata_info_for_multihead(
            test_utils.load_file(_MULTIHEAD_MODEL), general_md, input_md,
            [output_head_md_1, output_head_md_2]))

    metadata_json = writer.get_metadata_json()
    expected_json = test_utils.load_file(_JSON_MULTIHEAD, "r")
    self.assertEqual(metadata_json, expected_json)

    def add_audio_input(self,
                        sample_rate: int,
                        channels: int,
                        name: str = _INPUT_AUDIO_NAME,
                        description: str = _INPUT_AUDIO_DESCRIPTION):
        """Marks the next input tensor as an audio input."""
        # To make the Task Library work properly, sample_rate and channels need
        # to be positive.
        if sample_rate <= 0:
            raise ValueError(
                'sample_rate should be positive, but got {}.'.format(
                    sample_rate))
        if channels <= 0:
            raise ValueError(
                'channels should be positive, but got {}.'.format(channels))

        input_md = metadata_info.InputAudioTensorMd(name=name,
                                                    description=description,
                                                    sample_rate=sample_rate,
                                                    channels=channels)
        self._input_mds.append(input_md)
        return self
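
# A hedged sketch of the chained, builder-style call enabled by add_audio_input
# above. The writer class that owns the method (and its other add_* builders)
# is not shown in this listing, so `writer` below stands in for such an
# instance; the sample rate and channel count are placeholder values.
def add_mono_16k_audio_input(writer):
    """Registers a single-channel, 16 kHz audio input tensor on `writer`.

    Because add_audio_input returns `self`, this call can be chained with the
    writer's other builder methods before the metadata is finally populated.
    """
    return writer.add_audio_input(sample_rate=16000, channels=1)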