예제 #1
0
 def __init__(self, tflite_filepath, **kwargs):
     """Load a TFLite model and set up empty metadata containers.

     Args:
       tflite_filepath: Path of the TFLite model file, read via
         `writer_utils.load_file`.
       **kwargs: Forwarded unchanged to `md_info.GeneralMd`.
     """
     model_contents = writer_utils.load_file(tflite_filepath)
     self._model = model_contents
     self._general_md = md_info.GeneralMd(**kwargs)
     # Both start empty here; presumably populated by other methods of the
     # class (not visible in this excerpt) — confirm.
     self._inputs, self._outputs = [], []
예제 #2
0
    def _export_metadata(self, tflite_filepath, index_to_label,
                         export_metadata_json_file):
        """Export TFLite metadata.

        Builds general, audio-input and classification-output metadata for
        the model at `tflite_filepath`, packs it (including label files) into
        the model with `md_writer.MetadataWriter`, and overwrites the model
        file with the result.

        Args:
          tflite_filepath: Path of the TFLite model; it is read, and then
            overwritten with the metadata-populated model.
          index_to_label: Labels for the custom classification head, written
            via `self._export_labels` to a label file attached to that head.
          export_metadata_json_file: If truthy, also write the metadata as
            JSON to the model path with its extension replaced by `.json`.
        """
        # Label files only need to exist until `writer.populate()` has packed
        # them into the model, so they are written to a temporary directory.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Prepare metadata. Read with tf.io.gfile, the same API used for
            # writing below, so any path that can be written can also be read.
            with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
                model_buffer = f.read()
            # Parse the output tensor types once; every head below reuses them.
            output_tensor_types = writer_utils.get_output_tensor_types(
                model_buffer)

            general_md = md_info.GeneralMd(name=self._MODEL_NAME,
                                           description=self._MODEL_DESCRIPTION,
                                           version=self._MODEL_VERSION,
                                           author=self._MODEL_AUTHOR,
                                           licenses=self._MODEL_LICENSES)
            input_md = md_info.InputAudioTensorMd(self._INPUT_NAME,
                                                  self._INPUT_DESCRIPTION,
                                                  self._SAMPLE_RATE,
                                                  self._CHANNELS)

            # Save label files.
            custom_label_filepath = os.path.join(temp_dir,
                                                 self._CUSTOM_LABEL_FILE)
            self._export_labels(custom_label_filepath, index_to_label)

            # The custom head is described with the type of the model's last
            # output tensor.
            custom_output_md = md_info.ClassificationTensorMd(
                name=self._CUSTOM_OUTPUT_NAME,
                description=self._CUSTOM_OUTPUT_DESCRIPTION,
                label_files=[
                    md_info.LabelFileMd(file_path=custom_label_filepath)
                ],
                tensor_type=output_tensor_types[-1],
                score_calibration_md=None)

            if self._keep_yamnet_and_custom_heads:
                yamnet_label_filepath = os.path.join(temp_dir,
                                                     self._YAMNET_LABEL_FILE)
                self._export_labels(yamnet_label_filepath,
                                    self._yamnet_labels())

                # The YAMNet head is described with the type of the model's
                # first output tensor and is listed before the custom head.
                yamnet_output_md = md_info.ClassificationTensorMd(
                    name=self._YAMNET_OUTPUT_NAME,
                    description=self._YAMNET_OUTPUT_DESCRIPTION,
                    label_files=[
                        md_info.LabelFileMd(file_path=yamnet_label_filepath)
                    ],
                    tensor_type=output_tensor_types[0],
                    score_calibration_md=None)
                output_md = [yamnet_output_md, custom_output_md]
            else:
                output_md = [custom_output_md]

            # Populate metadata
            writer = md_writer.MetadataWriter.create_from_metadata_info_for_multihead(
                model_buffer=model_buffer,
                general_md=general_md,
                input_md=input_md,
                output_md_list=output_md)

            output_model = writer.populate()

            with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
                f.write(output_model)

            if export_metadata_json_file:
                metadata_json = writer.get_metadata_json()
                export_json_file = os.path.splitext(
                    tflite_filepath)[0] + '.json'
                # Same file API as the model write, for consistent path support.
                with tf.io.gfile.GFile(export_json_file, 'w') as f:
                    f.write(metadata_json)