Esempio n. 1
0
    def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
        """Rewrites an asset-free SavedModel to TFLite and checks the output.

        Args:
          converter: mock injected by a patch decorator outside this view. It
            is expected to be called exactly once with `saved_model_path` and
            `enable_quantization=False`.
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        m = self.ConverterMock()
        converter.return_value = m

        src_model_path = tempfile.mkdtemp()
        dst_model_path = tempfile.mkdtemp()
        # mkdtemp() leaves deletion to the caller; remove both directories
        # when the test finishes so repeated runs do not accumulate them.
        self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
        self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

        # Placeholder SavedModel file so the source directory is non-empty.
        saved_model_path = os.path.join(
            src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with fileio.open(saved_model_path, 'wb') as f:
            f.write(six.ensure_binary('saved_model'))

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_model_path)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_model_path)

        tfrw = tflite_rewriter.TFLiteRewriter(name='myrw', filename='fname')
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(saved_model_path=mock.ANY,
                                          enable_quantization=False)
        # The rewritten model is written under `filename` in the destination.
        # Its first line is expected to be 'model' (presumably produced by
        # ConverterMock; TODO confirm).
        expected_model = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(expected_model))
        with fileio.open(expected_model, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'model')
Esempio n. 2
0
    def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
        """Rewrites a SavedModel with assets and checks the files are copied.

        Builds a source dir containing a placeholder SavedModel, an `assets`
        file, and an extra-assets file. It then runs TFLiteRewriter with the
        experimental converter and quantization enabled, and verifies the
        converter call, the written model, and both copied asset files.

        Args:
          converter: mock injected by a patch decorator outside this view. It
            is expected to be called once with `saved_model_path`,
            `enable_experimental_new_converter=True` and
            `enable_quantization=True`.
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        m = self.ConverterMock()
        converter.return_value = m

        src_model_path = tempfile.mkdtemp()
        dst_model_path = tempfile.mkdtemp()
        # mkdtemp() leaves deletion to the caller; remove both directories
        # when the test finishes so repeated runs do not accumulate them.
        self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
        self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

        saved_model_path = os.path.join(
            src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with tf.io.gfile.GFile(saved_model_path, 'wb') as f:
            f.write(six.ensure_binary('saved_model'))

        # Regular SavedModel assets directory with one file.
        assets_dir = os.path.join(src_model_path,
                                  tf.saved_model.ASSETS_DIRECTORY)
        tf.io.gfile.mkdir(assets_dir)
        assets_file_path = os.path.join(assets_dir, 'assets_file')
        with tf.io.gfile.GFile(assets_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_file'))

        # Extra-assets directory (name from a module constant) with one file.
        assets_extra_dir = os.path.join(src_model_path, EXTRA_ASSETS_DIRECTORY)
        tf.io.gfile.mkdir(assets_extra_dir)
        assets_extra_file_path = os.path.join(assets_extra_dir,
                                              'assets_extra_file')
        with tf.io.gfile.GFile(assets_extra_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_extra_file'))

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_model_path)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_model_path)

        tfrw = tflite_rewriter.TFLiteRewriter(
            name='myrw',
            filename='fname',
            enable_experimental_new_converter=True,
            enable_quantization=True)
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(
            saved_model_path=mock.ANY,
            enable_experimental_new_converter=True,
            enable_quantization=True)
        expected_model = os.path.join(dst_model_path, 'fname')
        self.assertTrue(tf.io.gfile.exists(expected_model))
        with tf.io.gfile.GFile(expected_model, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'model')

        # Both asset trees must be mirrored into the destination directory.
        expected_assets_file = os.path.join(dst_model_path,
                                            tf.saved_model.ASSETS_DIRECTORY,
                                            'assets_file')
        with tf.io.gfile.GFile(expected_assets_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'assets_file')

        expected_assets_extra_file = os.path.join(dst_model_path,
                                                  EXTRA_ASSETS_DIRECTORY,
                                                  'assets_extra_file')
        with tf.io.gfile.GFile(expected_assets_extra_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()),
                             'assets_extra_file')
Esempio n. 3
0
    def setUp(self):
        """Create fixed source and destination SavedModel descriptions."""
        super().setUp()
        saved_model_kind = rewriter.ModelType.SAVED_MODEL
        # Both descriptions use the SAVED_MODEL type; only the paths differ.
        self._source_model = rewriter.ModelDescription(
            saved_model_kind, '/path/to/src/model')
        self._dest_model = rewriter.ModelDescription(
            saved_model_kind, '/path/to/dst/model')
Esempio n. 4
0
    def testInvokeTFJSRewriter(self, converter):
        """TFJSRewriter must hand the raw src/dst paths to the converter.

        Args:
          converter: mock injected by a patch decorator outside this view;
            it is expected to receive (src_path, dst_path) positionally.
        """
        src_path, dst_path = '/path/to/src/model', '/path/to/dst/model'

        saved_model_desc = rewriter.ModelDescription(
            rewriter.ModelType.SAVED_MODEL, src_path)
        tfjs_model_desc = rewriter.ModelDescription(
            rewriter.ModelType.TFJS_MODEL, dst_path)

        tfjs_rw = tfjs_rewriter.TFJSRewriter(name='myrw')
        tfjs_rw.perform_rewrite(saved_model_desc, tfjs_model_desc)

        converter.assert_called_once_with(src_path, dst_path)
Esempio n. 5
0
    def create_temp_model_template(self):
        """Create temp source/destination dirs and matching model descriptions.

        The source directory gets a placeholder SavedModel file. Both
        directories are registered for removal when the current test ends.

        Returns:
          A tuple `(src_model, dst_model, src_model_path, dst_model_path)`:
          a SAVED_MODEL description of the source dir, a TFLITE_MODEL
          description of the destination dir, and the two directory paths.
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        src_model_path = tempfile.mkdtemp()
        dst_model_path = tempfile.mkdtemp()
        # mkdtemp() leaves deletion to the caller; tie removal to the test's
        # lifetime so callers of this helper do not leak directories.
        self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
        self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

        saved_model_path = os.path.join(
            src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with fileio.open(saved_model_path, 'wb') as f:
            f.write(six.ensure_binary('saved_model'))

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_model_path)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_model_path)

        return src_model, dst_model, src_model_path, dst_model_path
Esempio n. 6
0
    def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
        """Rewrites an asset-free SavedModel to TFLite and checks the output.

        Args:
          converter: mock injected by a patch decorator outside this view;
            its return value is replaced with a ConverterMock instance.
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        m = self.ConverterMock()
        converter.return_value = m

        src_model_path = tempfile.mkdtemp()
        dst_model_path = tempfile.mkdtemp()
        # mkdtemp() leaves deletion to the caller; remove both directories
        # when the test finishes so repeated runs do not accumulate them.
        self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
        self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

        saved_model_path = os.path.join(
            src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with tf.io.gfile.GFile(saved_model_path, 'wb') as f:
            f.write(six.ensure_binary('saved_model'))

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_model_path)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_model_path)

        # NOTE(review): the third positional `True` maps to a TFLiteRewriter
        # parameter whose name is not visible here. Confirm it against the
        # constructor signature and consider passing it by keyword.
        tfrw = tflite_rewriter.TFLiteRewriter('myrw', 'fname', True)
        tfrw.perform_rewrite(src_model, dst_model)
        expected_model = os.path.join(dst_model_path, 'fname')
        self.assertTrue(tf.io.gfile.exists(expected_model))
        with tf.io.gfile.GFile(expected_model, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'model')
Esempio n. 7
0
def _invoke_rewriter(src: str, dst: str, rewriter_inst: rewriter.BaseRewriter,
                     src_model_type: rewriter.ModelType,
                     dst_model_type: rewriter.ModelType):
  """Converts the provided model by invoking the specified rewriter.

  Args:
    src: Path to the source model.
    dst: Path where the destination model is to be written.
    rewriter_inst: instance of the rewriter to invoke.
    src_model_type: the `rewriter.ModelType` of the source model.
    dst_model_type: the `rewriter.ModelType` of the destination model.

  Raises:
    ValueError: if the source path is the same as the destination path. This
      includes spellings that differ only lexically, such as a trailing slash
      or a redundant `./` component.
  """
  import os  # Local import keeps this helper self-contained.

  # A plain string comparison misses equivalent spellings such as '/m' and
  # '/m/', so compare the lexically normalized forms instead.
  if os.path.normpath(src) == os.path.normpath(dst):
    raise ValueError('Source path and destination path cannot match.')

  original_model = rewriter.ModelDescription(src_model_type, src)
  rewritten_model = rewriter.ModelDescription(dst_model_type, dst)

  rewriter_inst.perform_rewrite(original_model, rewritten_model)