def __init__(self, tflite_filepath, **kwargs):
    """Set up a writer for the TFLite model stored at ``tflite_filepath``.

    Args:
      tflite_filepath: Path of the TFLite model; it is loaded with
        ``writer_utils.load_file`` and the result is kept as ``self._model``.
      **kwargs: Forwarded unchanged to ``md_info.GeneralMd`` to build the
        model-level (general) metadata stored as ``self._general_md``.
    """
    model_contents = writer_utils.load_file(tflite_filepath)
    self._model = model_contents
    general_md = md_info.GeneralMd(**kwargs)
    self._general_md = general_md
    # Input/output metadata start out empty; presumably filled in by other
    # methods of this class (not visible here) -- confirm against callers.
    self._inputs, self._outputs = [], []
def _export_metadata(self, tflite_filepath, index_to_label,
                     export_metadata_json_file):
    """Export TFLite metadata.

    Reads the TFLite model at ``tflite_filepath``, attaches general, input and
    classification-output metadata (including label files), and overwrites
    the model file in place with the metadata-populated model.

    Args:
      tflite_filepath: Path of the TFLite model. It is read and then
        overwritten with the populated model. All file I/O here goes through
        ``tf.io.gfile`` so reads and writes support the same filesystems.
      index_to_label: Labels of the custom classification head; passed to
        ``self._export_labels`` to write the custom label file.
      export_metadata_json_file: If truthy, also write the metadata as JSON to
        a file next to the model whose extension is replaced by ``.json``.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Label files are written into temp_dir and referenced by path from the
        # metadata, so everything that may read them (populate and
        # get_metadata_json) stays inside this `with` block.

        # Prepare metadata. Read with GFile to match the GFile write below;
        # builtin open() would reject paths only tf.io.gfile understands.
        with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
            model_buffer = f.read()
        # Computed once; indexed per head below ([0] YAMNet, [-1] custom).
        output_tensor_types = writer_utils.get_output_tensor_types(model_buffer)

        general_md = md_info.GeneralMd(
            name=self._MODEL_NAME,
            description=self._MODEL_DESCRIPTION,
            version=self._MODEL_VERSION,
            author=self._MODEL_AUTHOR,
            licenses=self._MODEL_LICENSES)
        input_md = md_info.InputAudioTensorMd(self._INPUT_NAME,
                                              self._INPUT_DESCRIPTION,
                                              self._SAMPLE_RATE,
                                              self._CHANNELS)

        # Save label files.
        custom_label_filepath = os.path.join(temp_dir, self._CUSTOM_LABEL_FILE)
        self._export_labels(custom_label_filepath, index_to_label)
        custom_output_md = md_info.ClassificationTensorMd(
            name=self._CUSTOM_OUTPUT_NAME,
            description=self._CUSTOM_OUTPUT_DESCRIPTION,
            label_files=[md_info.LabelFileMd(file_path=custom_label_filepath)],
            tensor_type=output_tensor_types[-1],
            score_calibration_md=None)

        if self._keep_yamnet_and_custom_heads:
            yamnet_label_filepath = os.path.join(temp_dir,
                                                 self._YAMNET_LABEL_FILE)
            self._export_labels(yamnet_label_filepath, self._yamnet_labels())
            yamnet_output_md = md_info.ClassificationTensorMd(
                name=self._YAMNET_OUTPUT_NAME,
                description=self._YAMNET_OUTPUT_DESCRIPTION,
                label_files=[
                    md_info.LabelFileMd(file_path=yamnet_label_filepath)
                ],
                tensor_type=output_tensor_types[0],
                score_calibration_md=None)
            # Listed in the same order as the tensor-type indices used above:
            # YAMNet head ([0]) first, custom head ([-1]) last.
            output_md = [yamnet_output_md, custom_output_md]
        else:
            output_md = [custom_output_md]

        # Populate metadata.
        writer = md_writer.MetadataWriter.create_from_metadata_info_for_multihead(
            model_buffer=model_buffer,
            general_md=general_md,
            input_md=input_md,
            output_md_list=output_md)
        output_model = writer.populate()

        with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
            f.write(output_model)

        if export_metadata_json_file:
            export_json_file = os.path.splitext(tflite_filepath)[0] + '.json'
            with tf.io.gfile.GFile(export_json_file, 'w') as f:
                f.write(writer.get_metadata_json())