Esempio n. 1
0
 def testGetModulePathTar(self):
     """Resolves a plain ``.tar`` handle and checks the files it yields.

     The cache directory flag is pointed at a ``cache_dir`` folder under the
     test's temp directory before the HTTP resolver is invoked.
     """
     cache_root = os.path.join(self.get_temp_dir(), "cache_dir")
     FLAGS.tfhub_cache_dir = cache_root
     resolve = compressed_module_resolver.HttpCompressedFileResolver()
     tar_handle = "http://localhost:%d/mock_module.tar" % self.server_port
     module_path = resolve(tar_handle)
     # Directory listing order is not guaranteed, so compare it sorted.
     extracted = sorted(os.listdir(module_path))
     self.assertListEqual(extracted, ["file1", "file2", "file3"])
Esempio n. 2
0
    def testAbandondedLockFile(self):
        """Checks that resolution succeeds when a lock file already exists.

        A lock file is written up front, ``_lock_file_timeout_sec`` on the HTTP
        resolver class is patched to return 10, and the test then expects the
        module files to be listed and the lock file to be gone.
        """
        FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")

        # Plant an "abandoned" lock file, i.e. a lock file that no process is
        # actively downloading under anymore.
        stale_lock = resolver._lock_filename(
            compressed_module_resolver._module_dir(self.module_handle))
        owner_uid = uuid.uuid4().hex
        tf_utils.atomic_write_string_to_file(
            stale_lock,
            resolver._lock_file_contents(owner_uid),
            overwrite=False)

        resolver_cls = compressed_module_resolver.HttpCompressedFileResolver
        with mock.patch.object(
                resolver_cls, "_lock_file_timeout_sec", return_value=10):
            # Once the lock file is seen to be abandoned, this resolver
            # downloads the module and returns a path to the extracted
            # contents.
            module_path = resolver_cls()(
                "http://localhost:%d/mock_module.tar.gz" % self.server_port)

        extracted = sorted(os.listdir(module_path))
        self.assertListEqual(extracted, ["file1", "file2", "file3"])
        self.assertFalse(tf.compat.v1.gfile.Exists(stale_lock))
Esempio n. 3
0
def _install_default_resolvers():
    """Adds the path, GCS and HTTP resolvers to ``registry.resolver``.

    All three resolver instances are created first and then registered in
    that order.
    """
    default_impls = (
        resolver.PathResolver(),
        compressed_module_resolver.GcsCompressedFileResolver(),
        compressed_module_resolver.HttpCompressedFileResolver(),
    )
    for default_impl in default_impls:
        registry.resolver.add_implementation(default_impl)
Esempio n. 4
0
def _get_default_resolvers():
    """Returns fresh resolver instances as a list.

    The list order is: fail resolver, path resolver, GCS compressed-file
    resolver, HTTP compressed-file resolver.
    """
    fail_resolver = resolver.FailResolver()
    path_resolver = resolver.PathResolver()
    gcs_resolver = compressed_module_resolver.GcsCompressedFileResolver()
    http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
    return [fail_resolver, path_resolver, gcs_resolver, http_resolver]
Esempio n. 5
0
 def testAppendFormatQuery(self):
     """Checks the URL returned by ``_append_compressed_format_query``.

     Each case pairs an input handle with the URL expected once the
     ``tf-hub-format=compressed`` query parameter has been appended,
     including handles that already carry a query string or an existing
     ``tf-hub-format`` parameter.
     """
     tests = [
         (
             "https://example.com/module.tar.gz",
             "https://example.com/module.tar.gz?tf-hub-format=compressed",
         ),
         (
             "https://example.com/module",
             "https://example.com/module?tf-hub-format=compressed",
         ),
         (
             "https://example.com/module?extra=abc",
             "https://example.com/module?extra=abc&tf-hub-format=compressed",
         ),
         (
             "https://example.com/module?extra=abc&tf-hub-format=test",
             ("https://example.com/module?extra=abc&"
              "tf-hub-format=test&tf-hub-format=compressed"),
         ),
     ]
     http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
     for handle, expected in tests:
         # assertEqual rather than assertTrue: assertTrue takes its second
         # argument as the failure message, so it would accept any non-empty
         # result without ever comparing it to `expected`.
         self.assertEqual(
             http_resolver._append_compressed_format_query(handle),
             expected)
Esempio n. 6
0
 def testNoCacheDirSet(self):
     """Resolves a module with the cache-dir flag set to an empty string.

     The returned path must list the three module files and must start with
     the system temp directory.
     """
     FLAGS.tfhub_cache_dir = ""
     resolve = compressed_module_resolver.HttpCompressedFileResolver()
     module_path = resolve(
         "http://localhost:%d/mock_module.tar.gz" % self.server_port)
     extracted = sorted(os.listdir(module_path))
     self.assertListEqual(extracted, ["file1", "file2", "file3"])
     self.assertStartsWith(module_path, tempfile.gettempdir())
 def testModuleDescriptor(self):
   """Checks the descriptor file written next to a resolved module.

   The descriptor is expected to name the module handle, a download time,
   and the downloading host together with this process's PID.
   """
   FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")
   http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
   path = http_resolver(self.module_handle)
   desc = tf_utils.read_file_to_string(resolver._module_descriptor_file(path))
   # assertRegex replaces the deprecated assertRegexpMatches alias, which was
   # removed from unittest in Python 3.12.
   self.assertRegex(desc, "Module: %s\n"
                    "Download Time: .*\n"
                    "Downloader Hostname: %s .PID:%d." %
                    (re.escape(self.module_handle),
                     re.escape(socket.gethostname()), os.getpid()))
Esempio n. 8
0
    def testGetModulePathTarGz_withEnvVariable(self, env_value, expected_mode):
        """Checks the resolver's SSL verify mode for a given env setting.

        Args:
            env_value: Value placed in the environment under
                ``resolver._TFHUB_DISABLE_CERT_VALIDATION`` while resolving.
            expected_mode: Value that ``_context.verify_mode`` of the resolver
                must equal afterwards.
        """
        # Tests whether certificate validation is off or on when resolving a
        # URL. The environment variable is described as defaulting to "off"
        # and being turned on by setting it to "true".
        FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")

        env_override = {resolver._TFHUB_DISABLE_CERT_VALIDATION: env_value}
        with unittest.mock.patch.dict(os.environ, env_override):
            fetcher = compressed_module_resolver.HttpCompressedFileResolver()
            module_path = fetcher(self.module_handle)

        self.assertEqual(fetcher._context.verify_mode, expected_mode)
        extracted = os.listdir(module_path)
        self.assertCountEqual(extracted, ["file1", "file2", "file3"])
 def testModuleAlreadyDownloaded(self):
   """Resolves the same handle twice and compares file timestamps.

   Equal modification times across both calls are taken as evidence that
   the second call did not download the module again.
   """
   FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")
   http_resolver = compressed_module_resolver.HttpCompressedFileResolver()

   def _mtimes(module_path, names):
     # Modification time (nanoseconds) of each named file under module_path.
     return [
         tf.gfile.Stat(os.path.join(module_path, name)).mtime_nsec
         for name in names
     ]

   first_path = http_resolver(self.module_handle)
   first_files = sorted(os.listdir(first_path))
   self.assertListEqual(first_files, ["file1", "file2", "file3"])
   first_mtimes = _mtimes(first_path, first_files)

   # Call the resolver again and make sure that the module is not downloaded
   # again by checking the timestamps of the module files.
   second_path = http_resolver(self.module_handle)
   second_files = sorted(os.listdir(second_path))
   self.assertListEqual(second_files, ["file1", "file2", "file3"])
   self.assertListEqual(first_mtimes, _mtimes(second_path, second_files))
Esempio n. 10
0
 def testCorruptedArchive(self):
     """Expects an IOError with a specific message for an invalid archive.

     A file named ``bad_archive.tar.gz`` containing plain text is written,
     then requested through both the regular and the redirect server port.
     """
     with tf.compat.v1.gfile.GFile("bad_archive.tar.gz", mode="w") as archive:
         archive.write("bad_archive")
     http_resolver = compressed_module_resolver.HttpCompressedFileResolver()

     def _expect_invalid_module(port):
         # Resolving must raise IOError whose text names the URL on `port`;
         # completing without an exception is a test failure.
         try:
             http_resolver("http://localhost:%d/bad_archive.tar.gz" % port)
         except IOError as err:
             self.assertEqual(
                 "http://localhost:%d/bad_archive.tar.gz does not appear "
                 "to be a valid module." % port, str(err))
         else:
             self.fail("Corrupted archive should have failed to resolve.")

     _expect_invalid_module(self.server_port)
     # Check that the error message contains the ultimate (redirected to) URL.
     _expect_invalid_module(self.redirect_server_port)
Esempio n. 11
0
 def testGetModulePathTarGz(self):
     """Resolves ``self.module_handle`` and checks the files it yields.

     The cache directory flag is pointed at a ``cache_dir`` folder under the
     test's temp directory before the HTTP resolver is invoked.
     """
     cache_root = os.path.join(self.get_temp_dir(), "cache_dir")
     FLAGS.tfhub_cache_dir = cache_root
     resolve = compressed_module_resolver.HttpCompressedFileResolver()
     module_path = resolve(self.module_handle)
     # Directory listing order is not guaranteed, so compare it sorted.
     extracted = sorted(os.listdir(module_path))
     self.assertListEqual(extracted, ["file1", "file2", "file3"])