def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
    """Rewrites an asset-free SavedModel to TFLite and checks the output file.

    Args:
      converter: Mock object (presumably injected by a patch decorator that is
        outside this view); its return value is set to a `ConverterMock`.
    """
    converter.return_value = self.ConverterMock()

    # Minimal source model: a temp directory holding a single saved-model file.
    source_dir = tempfile.mkdtemp()
    target_dir = tempfile.mkdtemp()
    pbtxt_path = os.path.join(source_dir,
                              tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
    with fileio.open(pbtxt_path, 'wb') as out:
        out.write(six.ensure_binary('saved_model'))

    original = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                         source_dir)
    rewritten = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                          target_dir)
    lite_rewriter = tflite_rewriter.TFLiteRewriter(name='myrw',
                                                   filename='fname')
    lite_rewriter.perform_rewrite(original, rewritten)

    # Quantization was not requested, so the converter must see it disabled.
    converter.assert_called_once_with(saved_model_path=mock.ANY,
                                      enable_quantization=False)

    # The rewritten model lands under the requested filename.
    output_file = os.path.join(target_dir, 'fname')
    self.assertTrue(fileio.exists(output_file))
    with fileio.open(output_file, 'rb') as src:
        first_line = six.ensure_text(src.readline())
    self.assertEqual(first_line, 'model')
def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
    """Rewrites a SavedModel with assets to TFLite and checks all outputs.

    The source model holds a saved-model file, an assets directory and an
    extra-assets directory, each with one file. After the rewrite the test
    expects the model file plus both asset files under the destination.

    Args:
      converter: Mock object (presumably injected by a patch decorator that is
        outside this view); its return value is set to a `ConverterMock`.
    """

    def write_text(path, text):
        # Store `text` as bytes at `path`.
        with tf.io.gfile.GFile(path, 'wb') as out:
            out.write(six.ensure_binary(text))

    def first_line(path):
        # Return the first line of `path` decoded to text.
        with tf.io.gfile.GFile(path, 'rb') as src:
            return six.ensure_text(src.readline())

    converter.return_value = self.ConverterMock()

    source_dir = tempfile.mkdtemp()
    target_dir = tempfile.mkdtemp()

    # Source layout: saved-model file, then assets, then extra assets.
    write_text(
        os.path.join(source_dir, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT),
        'saved_model')

    source_assets = os.path.join(source_dir, tf.saved_model.ASSETS_DIRECTORY)
    tf.io.gfile.mkdir(source_assets)
    write_text(os.path.join(source_assets, 'assets_file'), 'assets_file')

    source_extra_assets = os.path.join(source_dir, EXTRA_ASSETS_DIRECTORY)
    tf.io.gfile.mkdir(source_extra_assets)
    write_text(os.path.join(source_extra_assets, 'assets_extra_file'),
               'assets_extra_file')

    original = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                         source_dir)
    rewritten = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                          target_dir)
    lite_rewriter = tflite_rewriter.TFLiteRewriter(
        name='myrw',
        filename='fname',
        enable_experimental_new_converter=True,
        enable_quantization=True)
    lite_rewriter.perform_rewrite(original, rewritten)

    # Both feature flags must be forwarded to the converter.
    converter.assert_called_once_with(
        saved_model_path=mock.ANY,
        enable_experimental_new_converter=True,
        enable_quantization=True)

    output_file = os.path.join(target_dir, 'fname')
    self.assertTrue(tf.io.gfile.exists(output_file))
    self.assertEqual(first_line(output_file), 'model')

    # Asset files are expected at the same relative locations as in the source.
    copied_asset = os.path.join(target_dir, tf.saved_model.ASSETS_DIRECTORY,
                                'assets_file')
    self.assertEqual(first_line(copied_asset), 'assets_file')

    copied_extra_asset = os.path.join(target_dir, EXTRA_ASSETS_DIRECTORY,
                                      'assets_extra_file')
    self.assertEqual(first_line(copied_extra_asset), 'assets_extra_file')
def setUp(self):
    """Stores source and destination SavedModel descriptions on the test."""
    super().setUp()
    saved_model_type = rewriter.ModelType.SAVED_MODEL
    # Both descriptions use the SAVED_MODEL type; only the paths differ.
    self._source_model = rewriter.ModelDescription(saved_model_type,
                                                   '/path/to/src/model')
    self._dest_model = rewriter.ModelDescription(saved_model_type,
                                                 '/path/to/dst/model')
def testInvokeTFJSRewriter(self, converter):
    """Checks the TFJS rewriter passes both model paths to the converter.

    Args:
      converter: Mock object (presumably injected by a patch decorator that is
        outside this view) whose call arguments are asserted on.
    """
    source_path = '/path/to/src/model'
    target_path = '/path/to/dst/model'
    original = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                         source_path)
    rewritten = rewriter.ModelDescription(rewriter.ModelType.TFJS_MODEL,
                                          target_path)

    js_rewriter = tfjs_rewriter.TFJSRewriter(name='myrw')
    js_rewriter.perform_rewrite(original, rewritten)

    # Exactly one conversion, called positionally with (source, destination).
    converter.assert_called_once_with(source_path, target_path)
def create_temp_model_template(self):
    """Creates temp source/destination dirs and their model descriptions.

    A saved-model file containing the bytes of 'saved_model' is written into
    the source directory; the destination directory is left empty.

    Returns:
      A `(src_model, dst_model, src_model_path, dst_model_path)` tuple: the
      SAVED_MODEL and TFLITE_MODEL descriptions followed by the two directory
      paths they point at.
    """
    source_dir = tempfile.mkdtemp()
    target_dir = tempfile.mkdtemp()

    pbtxt_path = os.path.join(source_dir,
                              tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
    with fileio.open(pbtxt_path, 'wb') as out:
        out.write(six.ensure_binary('saved_model'))

    original = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                         source_dir)
    rewritten = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                          target_dir)
    return original, rewritten, source_dir, target_dir
def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
    """Rewrites an asset-free SavedModel to TFLite and checks the output file.

    Args:
      converter: Mock object (presumably injected by a patch decorator that is
        outside this view); its return value is set to a `ConverterMock`.
    """
    converter.return_value = self.ConverterMock()

    # Minimal source model: a temp directory holding a single saved-model file.
    source_dir = tempfile.mkdtemp()
    target_dir = tempfile.mkdtemp()
    pbtxt_path = os.path.join(source_dir,
                              tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
    with tf.io.gfile.GFile(pbtxt_path, 'wb') as out:
        out.write(six.ensure_binary('saved_model'))

    original = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                         source_dir)
    rewritten = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                          target_dir)
    # Arguments are passed positionally, exactly as the rewriter is invoked
    # here; the meaning of the third argument is defined by TFLiteRewriter.
    lite_rewriter = tflite_rewriter.TFLiteRewriter('myrw', 'fname', True)
    lite_rewriter.perform_rewrite(original, rewritten)

    # The rewritten model lands under the requested filename.
    output_file = os.path.join(target_dir, 'fname')
    self.assertTrue(tf.io.gfile.exists(output_file))
    with tf.io.gfile.GFile(output_file, 'rb') as src:
        first_line = six.ensure_text(src.readline())
    self.assertEqual(first_line, 'model')
def _invoke_rewriter(src: str, dst: str, rewriter_inst: rewriter.BaseRewriter,
                     src_model_type: rewriter.ModelType,
                     dst_model_type: rewriter.ModelType):
    """Runs a single rewriter to convert the model at `src` into `dst`.

    Args:
      src: Path to the source model.
      dst: Path where the destination model is to be written.
      rewriter_inst: Instance of the rewriter to invoke.
      src_model_type: The `rewriter.ModelType` of the source model.
      dst_model_type: The `rewriter.ModelType` of the destination model.

    Raises:
      ValueError: If the source path is the same as the destination path.
    """
    # Reject identical source and destination paths before doing any work.
    if src == dst:
        raise ValueError('Source path and destination path cannot match.')

    rewriter_inst.perform_rewrite(
        rewriter.ModelDescription(src_model_type, src),
        rewriter.ModelDescription(dst_model_type, dst))