Example #1
0
  def testInvokeTFLiteRewriterQuantizationFullIntegerSucceeds(self, converter):
    """Full-integer quantization forwards the representative dataset.

    Builds a rewriter with full-integer quantization enabled and checks two
    things:
      * the mocked converter receives the very same dataset callable;
      * the output file ``fname`` is written with content ``b'model'``.
    """
    converter.return_value = self.ConverterMock()

    template = self.create_temp_model_template()
    src_model, dst_model, _, dst_model_path = template

    def representative_dataset():
      # Yields two single-element samples: [array(0)] then [array(1)].
      for sample_value in range(2):
        yield [np.array(sample_value)]

    tfrw = tflite_rewriter.TFLiteRewriter(
        name='myrw',
        filename='fname',
        quantization_optimizations=[tf.lite.Optimize.DEFAULT],
        quantization_enable_full_integer=True,
        representative_dataset=representative_dataset)
    tfrw.perform_rewrite(src_model, dst_model)

    expected_optimizations = [tf.lite.Optimize.DEFAULT]
    converter.assert_called_once_with(
        saved_model_path=mock.ANY,
        quantization_optimizations=expected_optimizations,
        quantization_supported_types=[],
        representative_dataset=representative_dataset,
        signature_key=None)

    tflite_file = os.path.join(dst_model_path, 'fname')
    self.assertTrue(fileio.exists(tflite_file))
    with fileio.open(tflite_file, 'rb') as fh:
      contents = fh.read()
    self.assertEqual(contents, b'model')
    def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
        """Rewrites an asset-free SavedModel and checks the TFLite output.

        The converter is expected to be called with quantization disabled,
        and the first line of the written ``fname`` file must be 'model'.
        """
        converter.return_value = self.ConverterMock()

        src_dir, dst_dir = tempfile.mkdtemp(), tempfile.mkdtemp()

        pbtxt_path = os.path.join(src_dir,
                                  tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with fileio.open(pbtxt_path, 'wb') as out:
            out.write(six.ensure_binary('saved_model'))

        model_type = rewriter.ModelType
        src_model = rewriter.ModelDescription(model_type.SAVED_MODEL, src_dir)
        dst_model = rewriter.ModelDescription(model_type.TFLITE_MODEL, dst_dir)

        tfrw = tflite_rewriter.TFLiteRewriter(name='myrw', filename='fname')
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(
            saved_model_path=mock.ANY, enable_quantization=False)

        tflite_path = os.path.join(dst_dir, 'fname')
        self.assertTrue(fileio.exists(tflite_path))
        with fileio.open(tflite_path, 'rb') as inp:
            first_line = six.ensure_text(inp.readline())
        self.assertEqual(first_line, 'model')
Example #3
0
    def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
        """Rewrites a SavedModel with asset directories and checks the result.

        Setup: populates the source model with:
          * a pbtxt;
          * one file in the assets directory;
          * one file in the extra-assets directory.

        After running the rewriter, the test verifies:
          * the converter arguments;
          * the TFLite output file;
          * that both asset files exist under the destination directory
            with their original contents.
        """
        converter.return_value = self.ConverterMock()

        def write_text(path, text):
            # Writes `text` as bytes to `path` via tf.io.gfile.
            with tf.io.gfile.GFile(path, 'wb') as out:
                out.write(six.ensure_binary(text))

        def read_first_line(path):
            # Returns the first line of `path`, decoded to text.
            with tf.io.gfile.GFile(path, 'rb') as inp:
                return six.ensure_text(inp.readline())

        src_dir = tempfile.mkdtemp()
        dst_dir = tempfile.mkdtemp()

        write_text(
            os.path.join(src_dir, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT),
            'saved_model')

        # (sub-directory, file name) pairs; each file's content is its name.
        asset_layout = (
            (tf.saved_model.ASSETS_DIRECTORY, 'assets_file'),
            (EXTRA_ASSETS_DIRECTORY, 'assets_extra_file'),
        )
        for subdir, name in asset_layout:
            asset_dir = os.path.join(src_dir, subdir)
            tf.io.gfile.mkdir(asset_dir)
            write_text(os.path.join(asset_dir, name), name)

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_dir)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_dir)

        tfrw = tflite_rewriter.TFLiteRewriter(
            name='myrw',
            filename='fname',
            enable_experimental_new_converter=True,
            enable_quantization=True)
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(
            saved_model_path=mock.ANY,
            enable_experimental_new_converter=True,
            enable_quantization=True)

        tflite_path = os.path.join(dst_dir, 'fname')
        self.assertTrue(tf.io.gfile.exists(tflite_path))
        self.assertEqual(read_first_line(tflite_path), 'model')

        for subdir, name in asset_layout:
            copied_path = os.path.join(dst_dir, subdir, name)
            self.assertEqual(read_first_line(copied_path), name)
Example #4
0
    def testInvokeTFLiteRewriterQuantizationFullIntegerFails(self, converter):
        """Full-integer quantization without a representative dataset fails.

        The rewriter is built with ``quantization_enable_full_integer=True``
        but no ``representative_dataset``. Calling ``perform_rewrite`` is
        expected to raise ``NotImplementedError``.
        """
        converter.return_value = self.ConverterMock()

        src_model, dst_model, _, _ = self.create_temp_model_template()

        tfrw = tflite_rewriter.TFLiteRewriter(
            name='myrw',
            filename='fname',
            quantization_optimizations=[tf.lite.Optimize.DEFAULT],
            quantization_enable_full_integer=True)
        # The call must happen *inside* assertRaises. Writing
        # `assertRaises(E, tfrw.perform_rewrite(...))` evaluates the call
        # eagerly: the exception escapes before assertRaises can catch it,
        # and on success assertRaises is handed None instead of a callable.
        with self.assertRaises(NotImplementedError):
            tfrw.perform_rewrite(src_model, dst_model)
Example #5
0
    def testInvokeTFLiteRewriterWithSignatureKey(self, converter):
        """Checks the signature key reaches the converter as a one-item list."""
        converter.return_value = self.ConverterMock()

        src_model, dst_model, _, _ = self.create_temp_model_template()

        rewriter_under_test = tflite_rewriter.TFLiteRewriter(
            name='myrw', filename='fname', signature_key='tflite')
        rewriter_under_test.perform_rewrite(src_model, dst_model)

        # call_args is an (args, kwargs) pair; only the kwargs matter here.
        call_kwargs = converter.call_args[1]
        self.assertListEqual(call_kwargs['signature_keys'], ['tflite'])
Example #6
0
    def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
        """Rewrites a templated model with assets and checks the destination.

        Setup: one file is added to the assets directory and one to the
        extra-assets directory.

        After running the rewriter with default quantization optimization,
        the test verifies:
          * the converter arguments;
          * the TFLite output;
          * that both asset files appear in the destination with their
            original contents.
        """
        converter.return_value = self.ConverterMock()

        src_model, dst_model, src_model_path, dst_model_path = (
            self.create_temp_model_template())

        # (sub-directory, file name) pairs; each file's content is its name.
        asset_layout = (
            (tf.saved_model.ASSETS_DIRECTORY, 'assets_file'),
            (EXTRA_ASSETS_DIRECTORY, 'assets_extra_file'),
        )
        for subdir, name in asset_layout:
            asset_dir = os.path.join(src_model_path, subdir)
            fileio.mkdir(asset_dir)
            with fileio.open(os.path.join(asset_dir, name), 'wb') as out:
                out.write(six.ensure_binary(name))

        tfrw = tflite_rewriter.TFLiteRewriter(
            name='myrw',
            filename='fname',
            quantization_optimizations=[tf.lite.Optimize.DEFAULT])
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(
            saved_model_path=mock.ANY,
            quantization_optimizations=[tf.lite.Optimize.DEFAULT],
            quantization_supported_types=[],
            representative_dataset=None,
            signature_key=None)

        tflite_path = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(tflite_path))
        with fileio.open(tflite_path, 'rb') as inp:
            self.assertEqual(six.ensure_text(inp.readline()), 'model')

        for subdir, name in asset_layout:
            copied_path = os.path.join(dst_model_path, subdir, name)
            with fileio.open(copied_path, 'rb') as inp:
                self.assertEqual(six.ensure_text(inp.readline()), name)
Example #7
0
  def testInvokeConverterWithKwargs(self, converter):
    """An extra rewriter kwarg (output_arrays) is expected at the converter."""
    converter.return_value = self.ConverterMock()
    src_model, dst_model, _, _ = self.create_temp_model_template()

    rewriter_kwargs = dict(name='myrw', filename='fname', output_arrays=['head'])
    tfrw = tflite_rewriter.TFLiteRewriter(**rewriter_kwargs)
    tfrw.perform_rewrite(src_model, dst_model)

    # Default quantization settings plus the passed-through output_arrays.
    expected_call_kwargs = dict(
        saved_model_path=mock.ANY,
        quantization_optimizations=[],
        quantization_supported_types=[],
        representative_dataset=None,
        signature_key=None,
        output_arrays=['head'])
    converter.assert_called_once_with(**expected_call_kwargs)
Example #8
0
    def testInvokeTFLiteRewriterQuantizationFullIntegerFailsNoData(
            self, converter, model):
        """Full-integer quantization without a dataset is rejected at init.

        Constructing the rewriter with ``quantization_enable_full_integer``
        but no ``representative_dataset`` must raise ``ValueError``.
        """

        class FakeModel(object):
            pass

        model.return_value = FakeModel()
        converter.return_value = self.ConverterMock()

        with self.assertRaises(ValueError):
            tflite_rewriter.TFLiteRewriter(
                name='myrw',
                filename='fname',
                quantization_optimizations=[tf.lite.Optimize.DEFAULT],
                quantization_enable_full_integer=True)
Example #9
0
    def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
        """With only name/filename set, quantization settings stay empty.

        Verifies the converter kwargs and that ``fname`` is written with
        first line 'model'.
        """
        converter.return_value = self.ConverterMock()

        (src_model, dst_model, _,
         dst_model_path) = self.create_temp_model_template()

        tfrw = tflite_rewriter.TFLiteRewriter(name='myrw', filename='fname')
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(
            saved_model_path=mock.ANY,
            quantization_optimizations=[],
            quantization_supported_types=[],
            input_data=None)

        tflite_path = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(tflite_path))
        with fileio.open(tflite_path, 'rb') as inp:
            first_line = six.ensure_text(inp.readline())
        self.assertEqual(first_line, 'model')
Example #10
0
    def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
        """Rewrites an asset-free SavedModel using positional constructor args.

        Checks that ``fname`` is written and its first line is 'model'.
        """
        converter.return_value = self.ConverterMock()

        src_dir = tempfile.mkdtemp()
        dst_dir = tempfile.mkdtemp()

        pbtxt_path = os.path.join(src_dir,
                                  tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with tf.io.gfile.GFile(pbtxt_path, 'wb') as out:
            out.write(six.ensure_binary('saved_model'))

        model_type = rewriter.ModelType
        src_model = rewriter.ModelDescription(model_type.SAVED_MODEL, src_dir)
        dst_model = rewriter.ModelDescription(model_type.TFLITE_MODEL, dst_dir)

        # NOTE(review): the third positional argument (True) is unnamed here;
        # its meaning depends on TFLiteRewriter's signature - confirm.
        tfrw = tflite_rewriter.TFLiteRewriter('myrw', 'fname', True)
        tfrw.perform_rewrite(src_model, dst_model)

        tflite_path = os.path.join(dst_dir, 'fname')
        self.assertTrue(tf.io.gfile.exists(tflite_path))
        with tf.io.gfile.GFile(tflite_path, 'rb') as inp:
            first_line = six.ensure_text(inp.readline())
        self.assertEqual(first_line, 'model')
Example #11
0
    def testInvokeTFLiteRewriterQuantizationFloat16Succeeds(self, converter):
        """Float16 supported types are forwarded to the converter.

        Verifies the converter kwargs for DEFAULT optimization with float16
        support, and that ``fname`` is written with first line 'model'.
        """
        converter.return_value = self.ConverterMock()

        (src_model, dst_model, _,
         dst_model_path) = self.create_temp_model_template()

        tfrw = tflite_rewriter.TFLiteRewriter(
            name='myrw',
            filename='fname',
            quantization_optimizations=[tf.lite.Optimize.DEFAULT],
            quantization_supported_types=[tf.float16])
        tfrw.perform_rewrite(src_model, dst_model)

        # Fresh list literals, so the expectation cannot alias what was
        # handed to the rewriter.
        converter.assert_called_once_with(
            saved_model_path=mock.ANY,
            quantization_optimizations=[tf.lite.Optimize.DEFAULT],
            quantization_supported_types=[tf.float16],
            representative_dataset=None,
            signature_key=None)

        tflite_path = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(tflite_path))
        with fileio.open(tflite_path, 'rb') as inp:
            first_line = six.ensure_text(inp.readline())
        self.assertEqual(first_line, 'model')