Beispiel #1
0
  def test_load(self):
    """Checks hub.load() behaviour for the TensorFlow version in use.

    - If `tf_v1.saved_model` has no `load_v2` (TF 1.x), hub.load() must raise
      NotImplementedError.
    - If executing eagerly, a small TF2 SavedModel exported here is archived,
      fetched from the local test HTTP server and loaded back; the restored
      object must exist and expose `add`.
    - Otherwise (load_v2 available but not executing eagerly) the test does
      nothing.
    """
    if not hasattr(tf_v1.saved_model, "load_v2"):
      # assertRaises replaces the manual try/self.fail/except-pass pattern:
      # the test fails if nothing is raised, and any other exception type
      # still propagates.
      with self.assertRaises(
          NotImplementedError,
          msg="Failure expected. hub.load() is not supported in TF 1.x"):
        hub.load("@my/tf2_module/2")
    elif tf_v1.executing_eagerly():

      class AdderModule(tf.train.Checkpoint):

        @tf.function(
            input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
        def add(self, x):
          return x + x + 1.

      to_export = AdderModule()
      save_dir = os.path.join(self.get_temp_dir(), "saved_model_v2")
      tf.saved_model.save(to_export, save_dir)
      module_name = "test_module_v2.tgz"
      # Presumably archives save_dir as `module_name` where the test HTTP
      # server (self.server_port) can serve it — confirm in _create_tgz.
      self._create_tgz(save_dir, module_name)

      restored_module = hub.load(
          "http://localhost:%d/%s" % (self.server_port, module_name))
      self.assertIsNotNone(restored_module)
      self.assertTrue(hasattr(restored_module, "add"))
Beispiel #2
0
 def test_load_v1(self):
   """Loads the TF1-format module half_plus_two_v1.tar.gz over local HTTP.

   Only runs in V2 mode (`load_v2` available and executing eagerly). The test
   passes as long as hub.load() does not raise; it makes no assertion on the
   returned object.

   NOTE(review): another `test_load_v1` in this listing expects hub.load() to
   raise NotImplementedError for the same module — confirm which behaviour is
   current.
   """
   if (not hasattr(tf_v1.saved_model, "load_v2") or
       not tf_v1.executing_eagerly()):
     return  # The test only applies when running V2 mode.
   full_module_path = self._full_module_path("half_plus_two_v1.tar.gz")
   # os.chdir mutates process-global state; restore the original working
   # directory when the test finishes so it does not leak into other tests.
   # Registered as a cleanup (not restored right after start_http_server)
   # because the server may resolve files against the cwd at request time.
   self.addCleanup(os.chdir, os.getcwd())
   os.chdir(os.path.dirname(full_module_path))
   server_port = test_utils.start_http_server()
   handle = "http://localhost:%d/half_plus_two_v1.tar.gz" % server_port
   hub.load(handle)
 def test_load_v1(self):
     """Checks that hub.load() rejects a module stored in TF 1.x format.

     Only runs in V2 mode (`load_v2` available and executing eagerly). Serves
     half_plus_two_v1.tar.gz from a local HTTP server and expects
     NotImplementedError whose message names the handle exactly.
     """
     if (not hasattr(tf_v1.saved_model, "load_v2")
             or not tf_v1.executing_eagerly()):
         return  # The test only applies when running V2 mode.
     full_module_path = self._full_module_path("half_plus_two_v1.tar.gz")
     # os.chdir mutates process-global state; restore the original working
     # directory when the test finishes so it does not leak into other tests.
     # Registered as a cleanup because the server may resolve files against
     # the cwd at request time, i.e. during hub.load() below.
     self.addCleanup(os.chdir, os.getcwd())
     os.chdir(os.path.dirname(full_module_path))
     server_port = test_utils.start_http_server()
     handle = "http://localhost:%d/half_plus_two_v1.tar.gz" % server_port
     # assertRaises replaces the manual try/self.fail/except pattern; the
     # caught exception stays available on cm.exception for the message check.
     with self.assertRaises(
             NotImplementedError,
             msg="Loading v1 modules is not supported. Failure expected.") as cm:
         hub.load(handle)
     # Adjacent literals are concatenated at parse time, so `% handle`
     # formats the whole message.
     self.assertEqual(
         str(cm.exception), "TF Hub module '%s' is stored using TF 1.x "
         "format. Loading of the module using hub.load() is not "
         "supported." % handle)