예제 #1
0
def clear_output_dirs(output_dict: Dict[str, List[types.Artifact]]) -> None:
    """Empties the directories referenced by the output artifacts' URIs.

    For every artifact whose URI is an existing, non-empty directory, the
    directory is removed recursively and then recreated empty. URIs that are
    not directories, or directories that are already empty, are left as-is.

    Args:
      output_dict: Mapping from output key to the artifacts under that key.
        Only the artifacts are used; the keys are ignored.
    """
    for artifact_list in output_dict.values():
        for artifact in artifact_list:
            uri = artifact.uri
            if fileio.isdir(uri) and fileio.listdir(uri):
                # Recreate the directory so the URI still exists afterwards.
                fileio.rmtree(uri)
                fileio.mkdir(uri)
예제 #2
0
    def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
        """Checks that assets and assets.extra are copied to the TFLite output.

        Also checks that `enable_quantization=True` is forwarded to the mocked
        converter and that the rewritten model is written as `fname` in the
        destination directory.
        """
        # Local import: shutil is only needed for the temp-dir cleanup below.
        import shutil

        m = self.ConverterMock()
        converter.return_value = m

        # tempfile.mkdtemp() never deletes what it creates; register cleanup so
        # each run of this test does not leave two directories behind.
        src_model_path = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, src_model_path, ignore_errors=True)
        dst_model_path = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, dst_model_path, ignore_errors=True)

        saved_model_path = os.path.join(
            src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with fileio.open(saved_model_path, 'wb') as f:
            f.write(six.ensure_binary('saved_model'))

        assets_dir = os.path.join(src_model_path,
                                  tf.saved_model.ASSETS_DIRECTORY)
        fileio.mkdir(assets_dir)
        assets_file_path = os.path.join(assets_dir, 'assets_file')
        with fileio.open(assets_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_file'))

        assets_extra_dir = os.path.join(src_model_path, EXTRA_ASSETS_DIRECTORY)
        fileio.mkdir(assets_extra_dir)
        assets_extra_file_path = os.path.join(assets_extra_dir,
                                              'assets_extra_file')
        with fileio.open(assets_extra_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_extra_file'))

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_model_path)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_model_path)

        tfrw = tflite_rewriter.TFLiteRewriter(name='myrw',
                                              filename='fname',
                                              enable_quantization=True)
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(saved_model_path=mock.ANY,
                                          enable_quantization=True)
        expected_model = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(expected_model))
        with fileio.open(expected_model, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'model')

        expected_assets_file = os.path.join(dst_model_path,
                                            tf.saved_model.ASSETS_DIRECTORY,
                                            'assets_file')
        with fileio.open(expected_assets_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'assets_file')

        expected_assets_extra_file = os.path.join(dst_model_path,
                                                  EXTRA_ASSETS_DIRECTORY,
                                                  'assets_extra_file')
        with fileio.open(expected_assets_extra_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()),
                             'assets_extra_file')
예제 #3
0
    def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
        """Rewriting with quantization keeps both asset directories.

        The mocked converter must be called once with the default optimization
        list and the default values of the other quantization arguments. The
        rewritten model and both asset files must appear in the destination.
        """
        converter.return_value = self.ConverterMock()

        src_model, dst_model, src_model_path, dst_model_path = (
            self.create_temp_model_template())

        # (sub-directory, file name) pairs; each file's content is its name.
        asset_layout = (
            (tf.saved_model.ASSETS_DIRECTORY, 'assets_file'),
            (EXTRA_ASSETS_DIRECTORY, 'assets_extra_file'),
        )

        for sub_dir, file_name in asset_layout:
            src_dir = os.path.join(src_model_path, sub_dir)
            fileio.mkdir(src_dir)
            with fileio.open(os.path.join(src_dir, file_name), 'wb') as out:
                out.write(six.ensure_binary(file_name))

        rewriter_under_test = tflite_rewriter.TFLiteRewriter(
            name='myrw',
            filename='fname',
            quantization_optimizations=[tf.lite.Optimize.DEFAULT])
        rewriter_under_test.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(
            saved_model_path=mock.ANY,
            quantization_optimizations=[tf.lite.Optimize.DEFAULT],
            quantization_supported_types=[],
            representative_dataset=None,
            signature_key=None)

        model_file = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(model_file))
        with fileio.open(model_file, 'rb') as model_in:
            self.assertEqual(six.ensure_text(model_in.readline()), 'model')

        for sub_dir, file_name in asset_layout:
            copied_file = os.path.join(dst_model_path, sub_dir, file_name)
            with fileio.open(copied_file, 'rb') as copied_in:
                self.assertEqual(six.ensure_text(copied_in.readline()),
                                 file_name)
예제 #4
0
  def _rewrite(self, original_model: rewriter.ModelDescription,
               rewritten_model: rewriter.ModelDescription):
    """Rewrites the provided model.

    Converts the SavedModel at `original_model.path` with the TFLite converter.
    The result is written to `<rewritten_model.path>/<self._filename>`. When
    enabled, the original model's assets and assets.extra directories are
    copied next to the TFLite model.

    Args:
      original_model: A `ModelDescription` specifying the original model to be
        rewritten.
      rewritten_model: A `ModelDescription` specifying the format and location
        of the rewritten model.

    Raises:
      ValueError: If the model could not be successfully rewritten.
    """
    if rewritten_model.model_type not in [
        rewriter.ModelType.TFLITE_MODEL, rewriter.ModelType.ANY_MODEL
    ]:
      raise ValueError('TFLiteConverter can only convert to the TFLite format.')

    original_path = _ensure_str(original_model.path)
    rewritten_path = _ensure_str(rewritten_model.path)

    # TODO(dzats): We create a temporary directory with a SavedModel that does
    # not contain an assets or assets.extra directory. Remove this when the
    # TFLite converter can convert models having these directories.
    tmp_model_dir = os.path.join(
        rewritten_path, 'tmp-rewrite-' + str(int(time.time())))
    if fileio.exists(tmp_model_dir):
      raise ValueError('TFLiteConverter is unable to create a unique path '
                       'for the temp rewriting directory.')

    fileio.makedirs(tmp_model_dir)
    # The temp dir lives inside the output directory. Remove it even when
    # preparation, conversion or writing fails, so a failed rewrite does not
    # leave it behind.
    try:
      _create_tflite_compatible_saved_model(original_path, tmp_model_dir)

      converter = self._create_tflite_converter(
          saved_model_path=tmp_model_dir,
          quantization_optimizations=self._quantization_optimizations,
          quantization_supported_types=self._quantization_supported_types,
          representative_dataset=self._representative_dataset,
          signature_key=self._signature_key,
          **self._kwargs)
      tflite_model = converter.convert()

      output_path = os.path.join(rewritten_path, self._filename)
      with fileio.open(_ensure_str(output_path), 'wb') as f:
        f.write(_ensure_bytes(tflite_model))
    finally:
      fileio.rmtree(tmp_model_dir)

    # Optionally mirror asset directories from the original model. All
    # destination directories are created before any copying starts.
    copy_pairs = []
    for enabled, dir_name in (
        (self._copy_assets, tf.saved_model.ASSETS_DIRECTORY),
        (self._copy_assets_extra, EXTRA_ASSETS_DIRECTORY),
    ):
      if not enabled:
        continue
      src = os.path.join(original_path, dir_name)
      dst = os.path.join(rewritten_path, dir_name)
      if fileio.isdir(src):
        fileio.mkdir(dst)
        copy_pairs.append((src, dst))
    for src, dst in copy_pairs:
      io_utils.copy_dir(src, dst)