def _create_metadata_file(self):
  """Serializes a metadata flatbuffer that references two associated files.

  "file1" is attached at the model level and "file2" to the single output
  tensor metadata. Both names are also stored, as `str`, in
  `self.expected_recorded_files`.

  Returns:
    The full path of a temporary file holding the serialized metadata.
  """
  model_level_file = _metadata_fb.AssociatedFileT()
  model_level_file.name = b"file1"
  output_level_file = _metadata_fb.AssociatedFileT()
  output_level_file.name = b"file2"
  self.expected_recorded_files = [
      six.ensure_str(f.name) for f in (model_level_file, output_level_file)
  ]

  # Two input tensor metadata entries (the same object twice) and one output
  # entry that carries "file2".
  in_tensor_meta = _metadata_fb.TensorMetadataT()
  out_tensor_meta = _metadata_fb.TensorMetadataT()
  out_tensor_meta.associatedFiles = [output_level_file]
  subgraph_meta = _metadata_fb.SubGraphMetadataT()
  subgraph_meta.inputTensorMetadata = [in_tensor_meta] * 2
  subgraph_meta.outputTensorMetadata = [out_tensor_meta]

  model_meta = _metadata_fb.ModelMetadataT()
  model_meta.name = "Mobilenet_quantized"
  model_meta.associatedFiles = [model_level_file]
  model_meta.subgraphMetadata = [subgraph_meta]

  builder = flatbuffers.Builder(0)
  packed_root = model_meta.Pack(builder)
  builder.Finish(packed_root,
                 _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)

  path = self.create_tempfile().full_path
  with open(path, "wb") as out_file:
    out_file.write(builder.Output())
  return path
def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
  """Populating a model that already holds metadata and files again works.

  First, a placeholder metadata (different from `self._metadata_file`, but
  with the same number of input/output tensor metadata as
  `self._model_file`) is populated into the model together with two
  associated files. Then `self._metadata_file` is populated into the same
  model, and the result is checked against the golden metadata.
  """
  # Placeholder metadata: two input entries (same object) and one output.
  in_meta = _metadata_fb.TensorMetadataT()
  out_meta = _metadata_fb.TensorMetadataT()
  placeholder_subgraph = _metadata_fb.SubGraphMetadataT()
  placeholder_subgraph.inputTensorMetadata = [in_meta] * 2
  placeholder_subgraph.outputTensorMetadata = [out_meta]
  placeholder_meta = _metadata_fb.ModelMetadataT()
  placeholder_meta.subgraphMetadata = [placeholder_subgraph]

  builder = flatbuffers.Builder(0)
  builder.Finish(
      placeholder_meta.Pack(builder),
      _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
  placeholder_buf = builder.Output()

  # First pass: placeholder metadata plus both associated files.
  first_pass = _metadata.MetadataPopulator.with_model_file(self._model_file)
  first_pass.load_metadata_buffer(placeholder_buf)
  first_pass.load_associated_files([self._file1, self._file2])
  first_pass.populate()

  # Second pass: populate the real metadata file into the same model.
  second_pass = _metadata.MetadataPopulator.with_model_file(self._model_file)
  second_pass.load_metadata_file(self._metadata_file)
  second_pass.populate()

  self._assert_golden_metadata(self._model_file)
def _create_dummy_metadata(self):
  """Builds named placeholder model, input and output metadata objects.

  Returns:
    A `(model_metadata, input_metadata, output_metadata)` tuple whose `name`
    fields are `_MODEL_NAME`, `_INPUT_NAME` and `_OUTPUT_NAME` respectively.
  """
  model_md = _metadata_fb.ModelMetadataT()
  model_md.name = _MODEL_NAME

  input_md = _metadata_fb.TensorMetadataT()
  input_md.name = _INPUT_NAME

  output_md = _metadata_fb.TensorMetadataT()
  output_md.name = _OUTPUT_NAME

  return model_md, input_md, output_md
def create_from_metadata(
    cls,
    model_buffer: bytearray,
    model_metadata: Optional[_metadata_fb.ModelMetadataT] = None,
    input_metadata: Optional[List[_metadata_fb.TensorMetadataT]] = None,
    output_metadata: Optional[List[_metadata_fb.TensorMetadataT]] = None,
    associated_files: Optional[List[str]] = None):
  """Creates MetadataWriter based on the metadata Flatbuffers Python Objects.

  Args:
    model_buffer: valid buffer of the model file.
    model_metadata: general model metadata [1]. The subgraph_metadata will be
      refreshed with input_metadata and output_metadata. Note that when this
      object is provided, its `subgraphMetadata` field is overwritten in
      place.
    input_metadata: a list of metadata of the input tensors [2]. If None or
      empty, one empty TensorMetadataT per input tensor of the model's first
      subgraph is used instead.
    output_metadata: a list of metadata of the output tensors [3]. If None or
      empty, one empty TensorMetadataT per output tensor of the model's first
      subgraph is used instead.
    associated_files: paths to the associated files to be populated.

    [1]:
      https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L640-L681
    [2]:
      https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L590
    [3]:
      https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L599

  Returns:
    A MetadataWriter Object.
  """
  # Create empty tensor metadata when input_metadata/output_metadata are None
  # to bypass MetadataPopulator verification.
  if not input_metadata or not output_metadata:
    # Parse the model once, and only when a placeholder count is needed.
    subgraph = _schema_fb.Model.GetRootAsModel(model_buffer, 0).Subgraphs(0)
    # Each placeholder is a distinct object. `[obj] * n` would alias a single
    # instance n times, so a later edit to one entry (e.g. through the
    # caller's model_metadata) would silently change all of them.
    if not input_metadata:
      input_metadata = [
          _metadata_fb.TensorMetadataT()
          for _ in range(subgraph.InputsLength())
      ]
    if not output_metadata:
      output_metadata = [
          _metadata_fb.TensorMetadataT()
          for _ in range(subgraph.OutputsLength())
      ]

  subgraph_metadata = _metadata_fb.SubGraphMetadataT()
  subgraph_metadata.inputTensorMetadata = input_metadata
  subgraph_metadata.outputTensorMetadata = output_metadata
  if model_metadata is None:
    model_metadata = _metadata_fb.ModelMetadataT()
  model_metadata.subgraphMetadata = [subgraph_metadata]

  builder = flatbuffers.Builder(0)
  builder.Finish(
      model_metadata.Pack(builder),
      _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
  return cls(model_buffer, builder.Output(), associated_files)
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
  """Creates the input tensor metadata based on the information.

  The result is filled from `self.name`, `self.description`,
  `self.max_values` / `self.min_values` (as stats), `self.content_type` and,
  when non-empty, `self.associated_files`.

  Returns:
    A Flatbuffers Python object of the input metadata.
  """
  tensor_metadata = _metadata_fb.TensorMetadataT()
  tensor_metadata.name = self.name
  tensor_metadata.description = self.description

  # Create min and max values
  stats = _metadata_fb.StatsT()
  stats.max = self.max_values
  stats.min = self.min_values
  tensor_metadata.stats = stats

  # Create content properties. The union tags are compared by value (`==`),
  # not identity (`is`). They are presumably plain int constants generated by
  # flatbuffers, for which `is` only works by accident of CPython's small-int
  # caching.
  content = _metadata_fb.ContentT()
  if self.content_type == _metadata_fb.ContentProperties.FeatureProperties:
    content.contentProperties = _metadata_fb.FeaturePropertiesT()
  elif self.content_type == _metadata_fb.ContentProperties.ImageProperties:
    content.contentProperties = _metadata_fb.ImagePropertiesT()
  elif self.content_type == (
      _metadata_fb.ContentProperties.BoundingBoxProperties):
    content.contentProperties = _metadata_fb.BoundingBoxPropertiesT()
  # For any other content_type, contentProperties is left unset while the
  # type tag is still recorded.
  content.contentPropertiesType = self.content_type
  tensor_metadata.content = content

  # Create associated files
  if self.associated_files:
    tensor_metadata.associatedFiles = [
        associated_file.create_metadata()
        for associated_file in self.associated_files
    ]
  return tensor_metadata
def testGetRecordedAssociatedFileListWithSubgraphProcessUnits(
    self, tensor_type, tokenizer_type):
  """Files of a subgraph-level tokenizer process unit are recorded.

  Args:
    tensor_type: `TensorType.INPUT` or `TensorType.OUTPUT`. It selects
      whether the tokenizer is placed in the subgraph's input or output
      process units. Any other value raises ValueError.
    tokenizer_type: forwarded to `self._create_tokenizer`.
  """
  tokenizer, expected_files = self._create_tokenizer(tokenizer_type)

  # Attach the tokenizer to the requested side of the subgraph.
  subgraph = _metadata_fb.SubGraphMetadataT()
  if tensor_type is TensorType.INPUT:
    subgraph.inputProcessUnits = [tokenizer]
  elif tensor_type is TensorType.OUTPUT:
    subgraph.outputProcessUnits = [tokenizer]
  else:
    raise ValueError(
        "The tensor type, {0}, is unsupported.".format(tensor_type))

  # Placeholder tensor metadata sized to match self._model_file
  # (two inputs, one output), all sharing one object.
  placeholder = _metadata_fb.TensorMetadataT()
  subgraph.inputTensorMetadata = [placeholder] * 2
  subgraph.outputTensorMetadata = [placeholder]

  metadata_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
  file_paths = self._create_tempfiles(expected_files)

  populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
  populator.load_metadata_buffer(metadata_buffer)
  populator.load_associated_files(file_paths)
  populator.populate()

  recorded = populator.get_recorded_associated_file_list()
  self.assertEqual(set(recorded), set(expected_files))
def testPopulatedFullPathAssociatedFileShouldSucceed(self):
  """An associated file named by its full path is recorded by basename."""
  full_path_file = _metadata_fb.AssociatedFileT()
  full_path_file.name = self._file1

  subgraph = _metadata_fb.SubGraphMetadataT()
  subgraph.associatedFiles = [full_path_file]
  # Placeholder tensor metadata sized to match self._model_file
  # (two inputs, one output), all sharing one object.
  placeholder = _metadata_fb.TensorMetadataT()
  subgraph.inputTensorMetadata = [placeholder] * 2
  subgraph.outputTensorMetadata = [placeholder]
  metadata_buffer = self._create_model_meta_with_subgraph_meta(subgraph)

  populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
  populator.load_metadata_buffer(metadata_buffer)
  populator.load_associated_files([self._file1])
  populator.populate()

  # Only the basename should be recorded; the directory must be dropped.
  expected_names = {os.path.basename(self._file1)}
  recorded = populator.get_recorded_associated_file_list()
  self.assertEqual(set(recorded), expected_names)
def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self):
  """Metadata with too few output tensor metadata entries is rejected.

  The metadata built here has no output tensor metadata, while the expected
  number is 1. `load_metadata_buffer` must therefore raise a ValueError with
  the exact message checked below.
  """
  input_meta = _metadata_fb.TensorMetadataT()
  subgraph_meta = _metadata_fb.SubGraphMetadataT()
  subgraph_meta.inputTensorMetadata = [input_meta] * 2
  # outputTensorMetadata is deliberately left unset.
  model_meta = _metadata_fb.ModelMetadataT()
  model_meta.subgraphMetadata = [subgraph_meta]

  builder = flatbuffers.Builder(0)
  packed_root = model_meta.Pack(builder)
  builder.Finish(packed_root,
                 _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
  metadata_buffer = builder.Output()

  populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
  with self.assertRaises(ValueError) as error:
    populator.load_metadata_buffer(metadata_buffer)

  expected_message = (
      "The number of output tensors (1) should match the number of "
      "output tensor metadata (0)")
  self.assertEqual(expected_message, str(error.exception))