Beispiel #1
0
 def testModuleRunningOnColab(self):
     """Load a hub.Module from a tfhub.dev handle while running on Colab.

     A test module is exported to a temporary directory. The method
     `HttpUncompressedFileResolver._request_gcs_location` is then patched to
     return that directory, so no server is contacted. Inside the Colab
     context the handle must be resolved through that method exactly once,
     with the `tf-hub-format=uncompressed` query appended.

     The loaded module is applied to 11 and the result is compared to 121.
     The exported test module presumably squares its input; confirm against
     test_utils.export_module.
     """
     handle = "https://tfhub.dev/google/model/1"
     export_dir = os.path.join(self.get_temp_dir(), "module")
     with tf.Graph().as_default():
         test_utils.export_module(export_dir)
         # Stand in for the remote server: the location lookup hands back the
         # local, uncompressed export instead of a real GCS path.
         patcher = mock.patch.object(
             uncompressed_module_resolver.HttpUncompressedFileResolver,
             "_request_gcs_location",
             return_value=export_dir)
         with patcher as mocked_request:
             with test_utils.RunningOnColabContext():
                 module = hub.Module(handle)
             # The Colab path must have asked for the uncompressed format.
             mocked_request.assert_called_once_with(
                 "https://tfhub.dev/google/model/1?tf-hub-format=uncompressed")
         result = module(11)
         with tf.compat.v1.Session() as sess:
             self.assertAllClose(sess.run(result), 121)
Beispiel #2
0
 def test_load_format_uncompressed(self):
     """Check that the UNCOMPRESSED load format always selects the uncompressed resolver.

     Both checks run inside UncompressedLoadFormatContext. The second one
     additionally runs inside RunningOnColabContext.
     """
     expect_uncompressed = self._assert_uncompressed_resolver_called
     with test_utils.UncompressedLoadFormatContext():
         # Not on Colab.
         expect_uncompressed()
         # On Colab, the uncompressed resolver is still expected.
         with test_utils.RunningOnColabContext():
             expect_uncompressed()
Beispiel #3
0
 def test_load_format_auto(self):
     """Check resolver selection under the default AUTO load format.

     Off Colab the compressed resolver is expected to be used. On Colab the
     uncompressed resolver is expected instead.
     """
     # AUTO is the default ModelLoadFormat, so no load-format context is
     # entered here.
     self._assert_compressed_resolver_called()
     with test_utils.RunningOnColabContext():
         # Under AUTO, running on Colab switches to the uncompressed resolver.
         self._assert_uncompressed_resolver_called()
Beispiel #4
0
 def test_on_auto_load_format_on_colab(self):
     """Check that every handle is supported by the uncompressed resolver.

     The check runs with the AUTO load format while on Colab.
     """
     with test_utils.AutoLoadFormatContext(), test_utils.RunningOnColabContext():
         # all() stops at the first unsupported handle, just as a failing
         # assertTrue inside a loop would. With no handles it is vacuously
         # true.
         self.assertTrue(
             all(self.uncompressed_resolver.is_supported(handle)
                 for handle in self.handles))