def clear_output_dirs(output_dict: Dict[str, List[types.Artifact]]) -> None:
    """Empty the directories that back the given output artifacts.

    For every artifact in `output_dict`, a URI that is an existing directory
    with at least one entry is removed and then re-created. URIs that are not
    directories, or that are already empty, are left untouched.

    Args:
      output_dict: Mapping whose values are lists of artifacts; only each
        artifact's `uri` is used.
    """
    for artifacts in output_dict.values():
        for artifact in artifacts:
            uri = artifact.uri
            if not fileio.isdir(uri):
                continue
            if not fileio.listdir(uri):
                # Already empty: nothing to clear.
                continue
            # Drop the directory with its contents, then recreate it empty.
            fileio.rmtree(uri)
            fileio.mkdir(uri)
def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
    # Builds a source model directory holding a saved-model file plus assets
    # and extra-assets files, runs the rewriter with quantization enabled,
    # and checks the converter call and the files under the destination.

    def write_file(path, text):
        # Store `text` as bytes at `path` through the fileio layer.
        with fileio.open(path, 'wb') as out:
            out.write(six.ensure_binary(text))

    def first_line(path):
        # Return the first line of `path`, decoded to text.
        with fileio.open(path, 'rb') as src:
            return six.ensure_text(src.readline())

    mock_converter = self.ConverterMock()
    converter.return_value = mock_converter

    src_model_path = tempfile.mkdtemp()
    dst_model_path = tempfile.mkdtemp()

    write_file(
        os.path.join(src_model_path,
                     tf.saved_model.SAVED_MODEL_FILENAME_PBTXT),
        'saved_model')

    # In both directories the file's content equals the file's name.
    asset_layout = (
        (tf.saved_model.ASSETS_DIRECTORY, 'assets_file'),
        (EXTRA_ASSETS_DIRECTORY, 'assets_extra_file'),
    )
    for dir_name, file_name in asset_layout:
        asset_dir = os.path.join(src_model_path, dir_name)
        fileio.mkdir(asset_dir)
        write_file(os.path.join(asset_dir, file_name), file_name)

    src_model = rewriter.ModelDescription(
        rewriter.ModelType.SAVED_MODEL, src_model_path)
    dst_model = rewriter.ModelDescription(
        rewriter.ModelType.TFLITE_MODEL, dst_model_path)

    tfrw = tflite_rewriter.TFLiteRewriter(
        name='myrw', filename='fname', enable_quantization=True)
    tfrw.perform_rewrite(src_model, dst_model)

    converter.assert_called_once_with(
        saved_model_path=mock.ANY, enable_quantization=True)

    # The converted model is written under the configured file name.
    expected_model = os.path.join(dst_model_path, 'fname')
    self.assertTrue(fileio.exists(expected_model))
    self.assertEqual(first_line(expected_model), 'model')

    # Both asset directories are present under the destination path.
    for dir_name, file_name in asset_layout:
        copied_file = os.path.join(dst_model_path, dir_name, file_name)
        self.assertEqual(first_line(copied_file), file_name)
def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
    # Adds assets and extra-assets files to the template source model, runs
    # the rewriter with the default quantization optimization, and checks
    # the converter arguments and the files under the destination path.

    def dump(path, payload):
        # Store `payload` as bytes at `path` through the fileio layer.
        with fileio.open(path, 'wb') as sink:
            sink.write(six.ensure_binary(payload))

    def head(path):
        # Return the first line of `path`, decoded to text.
        with fileio.open(path, 'rb') as stream:
            return six.ensure_text(stream.readline())

    converter_mock = self.ConverterMock()
    converter.return_value = converter_mock

    (src_model, dst_model, src_model_path,
     dst_model_path) = self.create_temp_model_template()

    # In both directories the file's content equals the file's name.
    asset_files = (
        (tf.saved_model.ASSETS_DIRECTORY, 'assets_file'),
        (EXTRA_ASSETS_DIRECTORY, 'assets_extra_file'),
    )
    for subdir, filename in asset_files:
        source_dir = os.path.join(src_model_path, subdir)
        fileio.mkdir(source_dir)
        dump(os.path.join(source_dir, filename), filename)

    tfrw = tflite_rewriter.TFLiteRewriter(
        name='myrw',
        filename='fname',
        quantization_optimizations=[tf.lite.Optimize.DEFAULT])
    tfrw.perform_rewrite(src_model, dst_model)

    converter.assert_called_once_with(
        saved_model_path=mock.ANY,
        quantization_optimizations=[tf.lite.Optimize.DEFAULT],
        quantization_supported_types=[],
        representative_dataset=None,
        signature_key=None)

    # The converted model is written under the configured file name.
    expected_model = os.path.join(dst_model_path, 'fname')
    self.assertTrue(fileio.exists(expected_model))
    self.assertEqual(head(expected_model), 'model')

    # Both asset directories are present under the destination path.
    for subdir, filename in asset_files:
        self.assertEqual(
            head(os.path.join(dst_model_path, subdir, filename)), filename)
def _rewrite(self, original_model: rewriter.ModelDescription,
             rewritten_model: rewriter.ModelDescription):
    """Rewrites the provided model.

    Builds a temporary TFLite-compatible copy of the original SavedModel
    inside the output path, converts it, writes the converted model to
    `<rewritten_model.path>/<self._filename>`, and finally copies the assets
    and extra-assets directories when the corresponding flags are set and
    the source directories exist. The temporary copy is removed whether or
    not the conversion succeeds.

    Args:
      original_model: A `ModelDescription` specifying the original model to
        be rewritten.
      rewritten_model: A `ModelDescription` specifying the format and
        location of the rewritten model.

    Raises:
      ValueError: If the model could not be successfully rewritten.
    """
    if rewritten_model.model_type not in [
            rewriter.ModelType.TFLITE_MODEL, rewriter.ModelType.ANY_MODEL
    ]:
        raise ValueError(
            'TFLiteConverter can only convert to the TFLite format.')

    source_dir = _ensure_str(original_model.path)
    target_dir = _ensure_str(rewritten_model.path)

    # TODO(dzats): We create a temporary directory with a SavedModel that
    # does not contain an assets or assets.extra directory. Remove this when
    # the TFLite converter can convert models having these directories.
    tmp_model_dir = os.path.join(
        target_dir, 'tmp-rewrite-' + str(int(time.time())))
    if fileio.exists(tmp_model_dir):
        raise ValueError('TFLiteConverter is unable to create a unique path '
                         'for the temp rewriting directory.')

    fileio.makedirs(tmp_model_dir)
    try:
        _create_tflite_compatible_saved_model(source_dir, tmp_model_dir)

        converter = self._create_tflite_converter(
            saved_model_path=tmp_model_dir,
            quantization_optimizations=self._quantization_optimizations,
            quantization_supported_types=self._quantization_supported_types,
            representative_dataset=self._representative_dataset,
            signature_key=self._signature_key,
            **self._kwargs)
        tflite_model = converter.convert()

        output_path = os.path.join(target_dir, self._filename)
        with fileio.open(_ensure_str(output_path), 'wb') as f:
            f.write(_ensure_bytes(tflite_model))
    finally:
        # Always drop the scratch copy: a failed conversion must not leave a
        # stale 'tmp-rewrite-*' directory inside the output path.
        fileio.rmtree(tmp_model_dir)

    # First create every destination directory, then copy, so the order of
    # filesystem operations matches the two-phase layout used previously.
    copy_pairs = []
    optional_dirs = (
        (self._copy_assets, tf.saved_model.ASSETS_DIRECTORY),
        (self._copy_assets_extra, EXTRA_ASSETS_DIRECTORY),
    )
    for enabled, dir_name in optional_dirs:
        if not enabled:
            continue
        src = os.path.join(source_dir, dir_name)
        dst = os.path.join(target_dir, dir_name)
        # A model without this directory simply has nothing to copy.
        if fileio.isdir(src):
            fileio.mkdir(dst)
            copy_pairs.append((src, dst))

    for src, dst in copy_pairs:
        io_utils.copy_dir(src, dst)