Example #1
    def testAssets(self):
        original_asset_file = os.path.join(self.get_temp_dir(), "hello.txt")
        _write_string_to_file(original_asset_file, "hello world")

        with tf.Graph().as_default() as graph:
            asset_tensor = tf.constant(original_asset_file, name="file")
            graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS,
                                    asset_tensor)
            saved_model_lib.add_signature("default", {},
                                          {"default": asset_tensor})

        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)

        export_dir = os.path.join(self.get_temp_dir(), "exported")
        handler.export(export_dir)

        # Check that asset file got written to the expected place:
        exported_asset_file = os.path.join(export_dir, "assets", "hello.txt")
        self.assertTrue(tf_v1.gfile.Exists(exported_asset_file))

        loaded_handler = saved_model_lib.load(export_dir)
        with _instantiate_meta_graph(loaded_handler).as_default():
            with tf_v1.Session() as sess:
                self.assertEqual(sess.run("file:0"),
                                 tf.compat.as_bytes(exported_asset_file))
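These snippets come from TensorFlow Hub's TF1-style tests and library code, so their imports are not shown. A minimal sketch of what they appear to assume (the exact aliases, in particular tf_v1, are an assumption; module, saved_model_lib and tf_utils are internal tensorflow_hub modules referenced by later examples):

import os

from absl import logging
import tensorflow as tf
import tensorflow.compat.v1 as tf_v1  # assumed alias for the tf_v1 used throughout
import tensorflow_hub as hub
from tensorflow_hub import module          # used in Examples #6 and #11
from tensorflow_hub import saved_model_lib  # used in Examples #1 and #2
from tensorflow_hub import tf_utils         # used in Example #5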
Example #2
    def testWithMultipleAssetsWithSameBasename(self):
        tmp_asset_dir = os.path.join(self.get_temp_dir(), "asset")
        file_a = os.path.join(tmp_asset_dir, "a", "hello.txt")
        file_b = os.path.join(tmp_asset_dir, "b", "hello.txt")
        tf_v1.gfile.MakeDirs(os.path.dirname(file_a))
        tf_v1.gfile.MakeDirs(os.path.dirname(file_b))
        _write_string_to_file(file_a, "hello A")
        _write_string_to_file(file_b, "hello B")
        with tf.Graph().as_default() as graph:
            asset_a = tf.constant(file_a, name="file_a")
            asset_b = tf.constant(file_b, name="file_b")
            graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS, asset_a)
            graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS, asset_b)
            saved_model_lib.add_signature("default", {}, {"default": asset_a})

        export_dir = os.path.join(self.get_temp_dir(), "exported")
        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        handler.export(export_dir)
        tf_v1.gfile.DeleteRecursively(tmp_asset_dir)

        loaded_handler = saved_model_lib.load(export_dir)
        with _instantiate_meta_graph(loaded_handler).as_default():
            with tf_v1.Session() as sess:
                self.assertEqual(_read_file_to_string(sess.run("file_a:0")),
                                 "hello A")
                self.assertEqual(_read_file_to_string(sess.run("file_b:0")),
                                 "hello B")
Example #3
def export_module_spec(spec, export_path):
  """Export module with random initialization."""
  with tf_v1.Graph().as_default():
    m = hub.Module(spec)
    with tf_v1.Session() as session:
      session.run(tf_v1.initializers.global_variables())
      m.export(export_path, session)
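A hypothetical usage sketch for the function above: build a trivial ModuleSpec with hub.create_module_spec and export a randomly initialized copy of it. The module_fn and the export path below are illustrative, not from the original code:

def _identity_module_fn():
  # Hypothetical module_fn: a single signature that passes its input through.
  x = tf_v1.placeholder(tf.float32, shape=[None])
  hub.add_signature(inputs=x, outputs=x)


spec = hub.create_module_spec(_identity_module_fn)
export_module_spec(spec, os.path.join("/tmp", "identity_module"))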
Example #4
  def test_http_locations(self):
    with tf.Graph().as_default():
      self._generate_module()

      m = hub.Module("http://localhost:%d/test_module.tgz" % self.server_port)
      out = m(11)
      with tf_v1.Session() as sess:
        self.assertAllClose(sess.run(out), 121)

      # Test caching using custom filesystem (file://) to make sure that the
      # TF Hub library can operate on such paths.
      try:
        root_dir = "file://%s" % self.get_temp_dir()
        cache_dir = "%s_%s" % (root_dir, "cache")
        tf_v1.gfile.MakeDirs(cache_dir)
        os.environ["TFHUB_CACHE_DIR"] = cache_dir
        m = hub.Module("http://localhost:%d/test_module.tgz" % self.server_port)
        out = m(11)
        with tf_v1.train.MonitoredSession() as sess:
          self.assertAllClose(sess.run(out), 121)

        cache_content = sorted(tf_v1.gfile.ListDirectory(cache_dir))
        logging.info("Cache context: %s", str(cache_content))
        self.assertEqual(2, len(cache_content))
        self.assertTrue(cache_content[1].endswith(".descriptor.txt"))
        module_files = sorted(tf_v1.gfile.ListDirectory(
            os.path.join(cache_dir, cache_content[0])))
        self.assertListEqual(
            ["assets", "saved_model.pb", "tfhub_module.pb", "variables"],
            module_files)
      finally:
        os.unsetenv("TFHUB_CACHE_DIR")
Example #5
  def test_module_export_vocab_on_custom_fs(self):
    root_dir = "file://%s" % self.get_temp_dir()
    export_dir = "%s_%s" % (root_dir, "export")
    tf_v1.gfile.MakeDirs(export_dir)
    # Create a module with a vocab file located on a custom filesystem.
    vocab_dir = os.path.join(root_dir, "vocab_location")
    tf_v1.gfile.MakeDirs(vocab_dir)
    vocab_filename = os.path.join(vocab_dir, "tokens.txt")
    tf_utils.atomic_write_string_to_file(vocab_filename, "one", False)

    def create_assets_module_fn():

      def assets_module_fn():
        indices = tf_v1.placeholder(dtype=tf.int64, name="indices")
        table = index_to_string_table_from_file(
            vocabulary_file=vocab_filename, default_value="UNKNOWN")
        outputs = table.lookup(indices)
        hub.add_signature(inputs=indices, outputs=outputs)

      return assets_module_fn

    with tf.Graph().as_default():
      assets_module_fn = create_assets_module_fn()
      spec = hub.create_module_spec(assets_module_fn)
      embedding_module = hub.Module(spec)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.tables_initializer())
        embedding_module.export(export_dir, sess)

    module_files = tf_v1.gfile.ListDirectory(export_dir)
    self.assertListEqual(
        ["assets", "saved_model.pb", "tfhub_module.pb", "variables"],
        sorted(module_files))
    module_files = tf_v1.gfile.ListDirectory(os.path.join(export_dir, "assets"))
    self.assertListEqual(["tokens.txt"], module_files)
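Example #5 refers to index_to_string_table_from_file by its bare name; presumably it is imported from TensorFlow's lookup ops, along these lines (an assumption, since the import is not shown):

from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file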
Example #6
    def testModuleInNestedScope(self):
        with tf.Graph().as_default():
            with tf_v1.variable_scope("foo"):
                m = module.Module(_ModuleSpec())
                result = m([1, 2])
            with tf_v1.Session() as session:
                self.assertAllEqual(session.run(result), [2, 4])
Example #7
  def _generate_module(self):
    spec = hub.create_module_spec(self._stateless_module_fn)
    m = hub.Module(spec, name="test_module")
    out = m(10)

    export_path = os.path.join(self.get_temp_dir(), "module")
    with tf_v1.Session() as sess:
      sess.run(tf_v1.global_variables_initializer())
      self.assertAllClose(sess.run(out), 100)
      m.export(export_path, sess)

    self._create_tgz(export_path)
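_generate_module relies on self._stateless_module_fn and self._create_tgz, which are not shown. Judging by the assertions (m(10) == 100 here and m(11) == 121 in Example #4), the module_fn presumably squares its input; a hypothetical sketch:

  def _stateless_module_fn(self):
    """Squares its input (behavior inferred from the assertions above)."""
    x = tf_v1.placeholder(tf.int64)
    y = x * x
    hub.add_signature(inputs=x, outputs=y)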
Example #8
def export_module_spec(spec, path, checkpoint_path, name_transform_fn):
  """Helper function to ModuleSpec.export()."""
  with tf.Graph().as_default():
    m = Module(spec)
    assign_map = {
        name_transform_fn(name): value for name, value in m.variable_map.items()
    }
    tf_v1.train.init_from_checkpoint(checkpoint_path, assign_map)
    init_op = tf_v1.initializers.global_variables()
    with tf_v1.Session() as session:
      session.run(init_op)
      m.export(path, session)
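A hypothetical call sketch for the helper above: name_transform_fn maps each module variable name to its name in the training checkpoint, e.g. when the trained variables live under a "module/" scope. The paths below are illustrative:

export_module_spec(
    spec=hub.load_module_spec("/tmp/original_module"),
    path="/tmp/reexported_module",
    checkpoint_path="/tmp/training/model.ckpt-1000",
    name_transform_fn=lambda name: "module/" + name)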
Example #9
def _make_estimator_serving_session(estimator, serving_input_fn,
                                    checkpoint_path):
    """Returns a session constructed using `estimator` and `serving_input_fn`.

  The Estimator API does not provide an API to construct a graph and session,
  making it necessary for this function to replicate how an estimator builds
  a graph.

  This code is based on `Estimator.export_savedmodel` (another function that
  has to replicate how an estimator builds a graph).

  Args:
    estimator: tf.Estimator to use when constructing the session.
    serving_input_fn: A function that takes no arguments and returns a
      `ServingInputReceiver`. It is used to construct the session.
    checkpoint_path: The checkpoint path to restore in the session. Must not
      be None.
  """
    with tf.Graph().as_default() as g:
        mode = tf_v1.estimator.ModeKeys.PREDICT
        tf_v1.train.create_global_step(g)
        tf_v1.set_random_seed(estimator.config.tf_random_seed)
        serving_input_receiver = serving_input_fn()

        estimator_spec = estimator.model_fn(
            features=serving_input_receiver.features,
            labels=None,
            mode=mode,
            config=estimator.config)

        # pylint: disable=protected-access
        # Note that MonitoredSession(), despite its name, is not a Session and
        # can't be used to export Modules, since Savers can't be used with it.
        # Therefore this must use a raw tf.Session().
        session = tf_v1.Session(config=estimator._session_config)
        # pylint: enable=protected-access

        with session.as_default():
            # TODO(b/71839662): Consider if this needs to support TPUEstimatorSpec
            # which does not have a scaffold member.
            saver_for_restore = estimator_spec.scaffold.saver or tf_v1.train.Saver(
                sharded=True)
            saver_for_restore.restore(session, checkpoint_path)
        return session
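A hypothetical usage sketch, assuming a trained estimator whose model takes a float feature "x"; the estimator, shapes, and checkpoint path below are illustrative:

def serving_input_fn():
    # Declare the serving-time input and hand it to the receiver unchanged.
    x = tf_v1.placeholder(dtype=tf.float32, shape=[None, 3], name="x")
    return tf_v1.estimator.export.ServingInputReceiver(
        features={"x": x}, receiver_tensors={"x": x})


# `estimator` is assumed to be an already trained tf_v1.estimator.Estimator.
session = _make_estimator_serving_session(
    estimator, serving_input_fn, checkpoint_path="/tmp/model/model.ckpt-100")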
Example #10
    def createSavedModel(self):
        model_dir = os.path.join(self.get_temp_dir(), "saved_model")
        with tf.Graph().as_default():
            x = tf_v1.placeholder(dtype=tf.float32, shape=[None, 3])
            w = tf_v1.get_variable("weights", shape=[])
            y = x * w
            tf_v1.add_to_collection(_EXTRA_COLLECTION, y)

            init_op = tf_v1.assign(w, 2)

            with tf_v1.Session() as session:
                session.run(init_op)
                tf_v1.saved_model.simple_save(
                    session,
                    model_dir,
                    inputs={"x": x},
                    outputs={"y": y},
                )
        return model_dir
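_EXTRA_COLLECTION is not shown in the snippet; presumably it is a module-level constant naming a custom graph collection, for example:

_EXTRA_COLLECTION = "extra_collection"  # hypothetical value; any collection name works here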
Example #11
    def testModuleDictInput(self):
        with tf.Graph().as_default():
            m = module.Module(_ModuleSpec())
            result = m({"x": [1, 2]})
            with tf_v1.Session() as session:
                self.assertAllEqual(session.run(result), [2, 4])