예제 #1
0
 def testModuleRunningWithUncompressedContext(self):
   """Loads a module by URL handle under the uncompressed load format.

   The resolver's `_request_gcs_location` is patched so that, rather than
   reaching a server, it returns the path of a module exported locally.
   The test then checks two things:
   - the patched request was issued exactly once, with the
     `?tf-hub-format=uncompressed` URL;
   - calling the loaded module on 11 evaluates to 121.
   """
   export_dir = os.path.join(self.get_temp_dir(), "module")
   handle = "https://tfhub.dev/google/model/1"
   expected_request = (
       "https://tfhub.dev/google/model/1?tf-hub-format=uncompressed")
   with tf.Graph().as_default():
     test_utils.export_module(export_dir)
     # Stand in for the remote server: the patched request hands back the
     # local path of the module exported just above.
     patched_request = mock.patch.object(
         uncompressed_module_resolver.HttpUncompressedFileResolver,
         "_request_gcs_location",
         return_value=export_dir)
     with patched_request as mocked_request:
       with test_utils.UncompressedLoadFormatContext():
         module = hub.Module(handle)
       # Verified after leaving the load-format context; the call was
       # already recorded while the module was being resolved.
       mocked_request.assert_called_once_with(expected_request)
     output = module(11)
     with tf.compat.v1.Session() as sess:
       self.assertAllClose(sess.run(output), 121)
예제 #2
0
 def test_load_format_uncompressed(self):
     """Runs the uncompressed-resolver check under the uncompressed load format.

     `_assert_uncompressed_resolver_called` is invoked while
     `test_utils.UncompressedLoadFormatContext` is active.
     """
     load_format = test_utils.UncompressedLoadFormatContext()
     with load_format:
         self._assert_uncompressed_resolver_called()
예제 #3
0
 def test_on_uncompressed_load_format(self):
     """Asserts every handle in `self.handles` is supported in uncompressed mode.

     Each handle is passed to `self.uncompressed_resolver.is_supported` while
     `test_utils.UncompressedLoadFormatContext` is active. The assertion
     message names the offending handle. Without it, a failure inside the
     loop would only report "False is not true", with no indication of
     which input caused it.
     """
     with test_utils.UncompressedLoadFormatContext():
         for handle in self.handles:
             self.assertTrue(
                 self.uncompressed_resolver.is_supported(handle),
                 msg="Uncompressed resolver should support handle %r" %
                 (handle,))
예제 #4
0
 def test_load_format_uncompressed(self):
     """The uncompressed resolver is expected to be called in both cases.

     The check runs twice inside `test_utils.UncompressedLoadFormatContext`:
     - first with only that context active;
     - then with `test_utils.RunningOnColabContext` also entered.
     """
     check_resolver_called = self._assert_uncompressed_resolver_called
     with test_utils.UncompressedLoadFormatContext():
         # Uncompressed load format on its own.
         check_resolver_called()
         # Uncompressed load format with the Colab context also active.
         with test_utils.RunningOnColabContext():
             check_resolver_called()