def testInvokeTFLiteRewriterQuantizationFullIntegerSucceeds(self, converter):
  """Full-integer quantization succeeds when a representative dataset is set.

  Verifies that the patched converter factory is called exactly once with the
  DEFAULT optimization, an empty supported-types list, the caller's
  representative dataset and no signature key, and that the rewritten model
  file exists and contains exactly b'model'.
  """
  converter.return_value = self.ConverterMock()
  src_model, dst_model, _, dst_model_path = self.create_temp_model_template()

  def sample_inputs():
    """Yields two single-element samples built from the integers 0 and 1."""
    for value in range(2):
      yield [np.array(value)]

  tflite_rw = tflite_rewriter.TFLiteRewriter(
      name='myrw',
      filename='fname',
      quantization_optimizations=[tf.lite.Optimize.DEFAULT],
      quantization_enable_full_integer=True,
      representative_dataset=sample_inputs)
  tflite_rw.perform_rewrite(src_model, dst_model)

  converter.assert_called_once_with(
      saved_model_path=mock.ANY,
      quantization_optimizations=[tf.lite.Optimize.DEFAULT],
      quantization_supported_types=[],
      representative_dataset=sample_inputs,
      signature_key=None)

  output_file = os.path.join(dst_model_path, 'fname')
  self.assertTrue(fileio.exists(output_file))
  with fileio.open(output_file, 'rb') as model_file:
    contents = model_file.read()
  self.assertEqual(contents, b'model')
def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
  """Rewrites a SavedModel that has no assets directory into a TFLite file.

  Builds a source directory holding only a placeholder SavedModel pbtxt (its
  content is the text 'saved_model') and runs the rewriter with just a name
  and filename. Then checks two things: the converter factory is called once
  with ``enable_quantization=False``, and the output file's first line is
  'model'.
  """
  import shutil  # Local import: only used to register temp-dir cleanup.

  converter.return_value = self.ConverterMock()
  src_model_path = tempfile.mkdtemp()
  dst_model_path = tempfile.mkdtemp()
  # tempfile.mkdtemp() directories are never deleted automatically; remove
  # them when the test finishes so repeated runs do not litter the temp dir.
  self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
  self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

  saved_model_path = os.path.join(
      src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
  with fileio.open(saved_model_path, 'wb') as f:
    f.write(six.ensure_binary('saved_model'))

  src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                        src_model_path)
  dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                        dst_model_path)

  tfrw = tflite_rewriter.TFLiteRewriter(name='myrw', filename='fname')
  tfrw.perform_rewrite(src_model, dst_model)

  converter.assert_called_once_with(
      saved_model_path=mock.ANY, enable_quantization=False)
  expected_model = os.path.join(dst_model_path, 'fname')
  self.assertTrue(fileio.exists(expected_model))
  with fileio.open(expected_model, 'rb') as f:
    self.assertEqual(six.ensure_text(f.readline()), 'model')
def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
  """Rewrites a SavedModel and copies its assets and extra-assets files.

  The source directory gets a placeholder SavedModel pbtxt, one file under
  ``tf.saved_model.ASSETS_DIRECTORY`` and one file under
  ``EXTRA_ASSETS_DIRECTORY``. After rewriting with the experimental converter
  and quantization enabled, the test checks:
    * the converter factory call, and
    * that the TFLite file and both asset files exist in the destination with
      their expected first lines.
  """
  import shutil  # Local import: only used to register temp-dir cleanup.

  converter.return_value = self.ConverterMock()
  src_model_path = tempfile.mkdtemp()
  dst_model_path = tempfile.mkdtemp()
  # tempfile.mkdtemp() directories are never deleted automatically; remove
  # them when the test finishes so repeated runs do not litter the temp dir.
  self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
  self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

  saved_model_path = os.path.join(
      src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
  with tf.io.gfile.GFile(saved_model_path, 'wb') as f:
    f.write(six.ensure_binary('saved_model'))

  assets_dir = os.path.join(src_model_path, tf.saved_model.ASSETS_DIRECTORY)
  tf.io.gfile.mkdir(assets_dir)
  assets_file_path = os.path.join(assets_dir, 'assets_file')
  with tf.io.gfile.GFile(assets_file_path, 'wb') as f:
    f.write(six.ensure_binary('assets_file'))

  assets_extra_dir = os.path.join(src_model_path, EXTRA_ASSETS_DIRECTORY)
  tf.io.gfile.mkdir(assets_extra_dir)
  assets_extra_file_path = os.path.join(assets_extra_dir, 'assets_extra_file')
  with tf.io.gfile.GFile(assets_extra_file_path, 'wb') as f:
    f.write(six.ensure_binary('assets_extra_file'))

  src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                        src_model_path)
  dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                        dst_model_path)

  tfrw = tflite_rewriter.TFLiteRewriter(
      name='myrw',
      filename='fname',
      enable_experimental_new_converter=True,
      enable_quantization=True)
  tfrw.perform_rewrite(src_model, dst_model)

  converter.assert_called_once_with(
      saved_model_path=mock.ANY,
      enable_experimental_new_converter=True,
      enable_quantization=True)

  expected_model = os.path.join(dst_model_path, 'fname')
  self.assertTrue(tf.io.gfile.exists(expected_model))
  with tf.io.gfile.GFile(expected_model, 'rb') as f:
    self.assertEqual(six.ensure_text(f.readline()), 'model')

  expected_assets_file = os.path.join(dst_model_path,
                                      tf.saved_model.ASSETS_DIRECTORY,
                                      'assets_file')
  with tf.io.gfile.GFile(expected_assets_file, 'rb') as f:
    self.assertEqual(six.ensure_text(f.readline()), 'assets_file')

  expected_assets_extra_file = os.path.join(dst_model_path,
                                            EXTRA_ASSETS_DIRECTORY,
                                            'assets_extra_file')
  with tf.io.gfile.GFile(expected_assets_extra_file, 'rb') as f:
    self.assertEqual(six.ensure_text(f.readline()), 'assets_extra_file')
def testInvokeTFLiteRewriterQuantizationFullIntegerFails(self, converter):
  """Full-integer quantization without a representative dataset must fail.

  The rewriter is configured with the DEFAULT optimization and full-integer
  quantization but no representative dataset. The test expects
  ``perform_rewrite`` to raise NotImplementedError.
  """
  converter.return_value = self.ConverterMock()
  src_model, dst_model, _, _ = self.create_temp_model_template()

  tfrw = tflite_rewriter.TFLiteRewriter(
      name='myrw',
      filename='fname',
      quantization_optimizations=[tf.lite.Optimize.DEFAULT],
      quantization_enable_full_integer=True)
  # perform_rewrite must execute *inside* assertRaises. The previous form,
  # assertRaises(Err, tfrw.perform_rewrite(...)), evaluated the call first and
  # passed its return value as the "callable". So the expected exception was
  # never checked by the assertion.
  with self.assertRaises(NotImplementedError):
    tfrw.perform_rewrite(src_model, dst_model)
def testInvokeTFLiteRewriterWithSignatureKey(self, converter):
  """A signature_key given to the rewriter is forwarded to the converter.

  Checks that the recorded call to the patched ``converter`` carries a
  ``signature_keys`` keyword argument equal to ['tflite'].
  """
  converter.return_value = self.ConverterMock()
  src_model, dst_model, _, _ = self.create_temp_model_template()

  tflite_rw = tflite_rewriter.TFLiteRewriter(
      name='myrw', filename='fname', signature_key='tflite')
  tflite_rw.perform_rewrite(src_model, dst_model)

  # NOTE(review): other tests here expect the converter factory to receive a
  # scalar `signature_key`; this one reads a list-valued `signature_keys`.
  # Confirm which function `converter` patches for this test.
  call_kwargs = converter.call_args[1]
  self.assertListEqual(call_kwargs['signature_keys'], ['tflite'])
def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
  """Rewrites a SavedModel and copies its assets and extra-assets files.

  Adds one file under ``tf.saved_model.ASSETS_DIRECTORY`` and one under
  ``EXTRA_ASSETS_DIRECTORY`` to a template source model, then rewrites it
  with the DEFAULT optimization. The test verifies the converter factory
  call and the first line of the TFLite file and of each copied asset.
  """

  def write_text(path, text):
    # Writes `text` as bytes to `path`.
    with fileio.open(path, 'wb') as handle:
      handle.write(six.ensure_binary(text))

  def first_line(path):
    # Returns the first line of `path`, decoded to text.
    with fileio.open(path, 'rb') as handle:
      return six.ensure_text(handle.readline())

  converter.return_value = self.ConverterMock()
  src_model, dst_model, src_model_path, dst_model_path = (
      self.create_temp_model_template())

  assets_dir = os.path.join(src_model_path, tf.saved_model.ASSETS_DIRECTORY)
  fileio.mkdir(assets_dir)
  write_text(os.path.join(assets_dir, 'assets_file'), 'assets_file')

  extra_dir = os.path.join(src_model_path, EXTRA_ASSETS_DIRECTORY)
  fileio.mkdir(extra_dir)
  write_text(os.path.join(extra_dir, 'assets_extra_file'), 'assets_extra_file')

  tflite_rw = tflite_rewriter.TFLiteRewriter(
      name='myrw',
      filename='fname',
      quantization_optimizations=[tf.lite.Optimize.DEFAULT])
  tflite_rw.perform_rewrite(src_model, dst_model)

  converter.assert_called_once_with(
      saved_model_path=mock.ANY,
      quantization_optimizations=[tf.lite.Optimize.DEFAULT],
      quantization_supported_types=[],
      representative_dataset=None,
      signature_key=None)

  model_file = os.path.join(dst_model_path, 'fname')
  self.assertTrue(fileio.exists(model_file))
  self.assertEqual(first_line(model_file), 'model')

  copied_asset = os.path.join(dst_model_path,
                              tf.saved_model.ASSETS_DIRECTORY, 'assets_file')
  self.assertEqual(first_line(copied_asset), 'assets_file')

  copied_extra = os.path.join(dst_model_path, EXTRA_ASSETS_DIRECTORY,
                              'assets_extra_file')
  self.assertEqual(first_line(copied_extra), 'assets_extra_file')
def testInvokeConverterWithKwargs(self, converter):
  """Extra rewriter keyword arguments are passed through to the converter.

  ``output_arrays`` is not one of the named rewriter options checked
  elsewhere. The test expects it to appear unchanged in the converter
  factory call, alongside the default (empty/None) quantization settings.
  """
  converter.return_value = self.ConverterMock()
  src_model, dst_model, _, _ = self.create_temp_model_template()

  extra_kwargs = {'output_arrays': ['head']}
  tflite_rw = tflite_rewriter.TFLiteRewriter(
      name='myrw', filename='fname', **extra_kwargs)
  tflite_rw.perform_rewrite(src_model, dst_model)

  converter.assert_called_once_with(
      saved_model_path=mock.ANY,
      quantization_optimizations=[],
      quantization_supported_types=[],
      representative_dataset=None,
      signature_key=None,
      output_arrays=['head'])
def testInvokeTFLiteRewriterQuantizationFullIntegerFailsNoData(
    self, converter, model):
  """Constructing a full-integer rewriter without a dataset raises ValueError.

  The error is expected from the TFLiteRewriter constructor itself, before
  any rewrite is attempted. Both patched objects are given placeholder
  return values.
  """

  class ModelMock(object):
    """Empty stand-in returned by the patched `model`."""

  model.return_value = ModelMock()
  converter.return_value = self.ConverterMock()

  with self.assertRaises(ValueError):
    tflite_rewriter.TFLiteRewriter(
        name='myrw',
        filename='fname',
        quantization_optimizations=[tf.lite.Optimize.DEFAULT],
        quantization_enable_full_integer=True)
def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
  """Rewrites a template SavedModel with default rewriter settings.

  Expects the converter factory to be called once with empty quantization
  settings and ``input_data=None``, and the output file's first line to be
  'model'.
  """
  converter.return_value = self.ConverterMock()
  src_model, dst_model, _, dst_model_path = self.create_temp_model_template()

  tflite_rw = tflite_rewriter.TFLiteRewriter(name='myrw', filename='fname')
  tflite_rw.perform_rewrite(src_model, dst_model)

  converter.assert_called_once_with(
      saved_model_path=mock.ANY,
      quantization_optimizations=[],
      quantization_supported_types=[],
      input_data=None)

  output_file = os.path.join(dst_model_path, 'fname')
  self.assertTrue(fileio.exists(output_file))
  with fileio.open(output_file, 'rb') as model_file:
    header = six.ensure_text(model_file.readline())
  self.assertEqual(header, 'model')
def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
  """Rewrites a SavedModel that has no assets directory into a TFLite file.

  Builds a source directory holding only a placeholder SavedModel pbtxt and
  runs the rewriter. It then checks that the output file exists and that its
  first line is 'model'. The converter call arguments are not inspected.
  """
  import shutil  # Local import: only used to register temp-dir cleanup.

  converter.return_value = self.ConverterMock()
  src_model_path = tempfile.mkdtemp()
  dst_model_path = tempfile.mkdtemp()
  # tempfile.mkdtemp() directories are never deleted automatically; remove
  # them when the test finishes so repeated runs do not litter the temp dir.
  self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
  self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

  saved_model_path = os.path.join(
      src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
  with tf.io.gfile.GFile(saved_model_path, 'wb') as f:
    f.write(six.ensure_binary('saved_model'))

  src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                        src_model_path)
  dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                        dst_model_path)

  # NOTE(review): the third positional argument (True) maps to whatever the
  # TFLiteRewriter signature's third parameter is in this API version; a
  # keyword argument would make the intent explicit.
  tfrw = tflite_rewriter.TFLiteRewriter('myrw', 'fname', True)
  tfrw.perform_rewrite(src_model, dst_model)

  expected_model = os.path.join(dst_model_path, 'fname')
  self.assertTrue(tf.io.gfile.exists(expected_model))
  with tf.io.gfile.GFile(expected_model, 'rb') as f:
    self.assertEqual(six.ensure_text(f.readline()), 'model')
def testInvokeTFLiteRewriterQuantizationFloat16Succeeds(self, converter):
  """Float16 quantization settings are forwarded to the converter factory.

  Expects a single converter call with the DEFAULT optimization,
  ``[tf.float16]`` as supported types, no representative dataset and no
  signature key. The rewritten file's first line must be 'model'.
  """
  converter.return_value = self.ConverterMock()
  src_model, dst_model, _, dst_model_path = self.create_temp_model_template()

  tflite_rw = tflite_rewriter.TFLiteRewriter(
      name='myrw',
      filename='fname',
      quantization_optimizations=[tf.lite.Optimize.DEFAULT],
      quantization_supported_types=[tf.float16])
  tflite_rw.perform_rewrite(src_model, dst_model)

  converter.assert_called_once_with(
      saved_model_path=mock.ANY,
      quantization_optimizations=[tf.lite.Optimize.DEFAULT],
      quantization_supported_types=[tf.float16],
      representative_dataset=None,
      signature_key=None)

  output_file = os.path.join(dst_model_path, 'fname')
  self.assertTrue(fileio.exists(output_file))
  with fileio.open(output_file, 'rb') as model_file:
    header = six.ensure_text(model_file.readline())
  self.assertEqual(header, 'model')