def testGetModulePathTar(self):
    """Resolving a handle that ends in `.tar` yields the module's files."""
    FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")
    tar_resolver = compressed_module_resolver.HttpCompressedFileResolver()
    tar_url = "http://localhost:%d/mock_module.tar" % self.server_port
    module_path = tar_resolver(tar_url)
    self.assertListEqual(
        sorted(os.listdir(module_path)), ["file1", "file2", "file3"])
def testAbandondedLockFile(self):
    """Checks that the caching procedure is resilient to an abandoned lock file.

    A lock file is written for the module's cache directory with no process
    behind it. With `_lock_file_timeout_sec` patched to return 10, resolving
    the same handle must still produce the extracted module files and must
    leave no lock file behind.
    """
    FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")

    # Create an "abandoned" lock file, i.e. a lock file with no process actively
    # downloading anymore. The lock location is derived from
    # `self.module_handle`, so that exact handle is resolved below; resolving a
    # different URL would bypass the lock and make this test vacuous.
    module_dir = compressed_module_resolver._module_dir(self.module_handle)
    task_uid = uuid.uuid4().hex
    lock_filename = resolver._lock_filename(module_dir)
    tf_utils.atomic_write_string_to_file(
        lock_filename, resolver._lock_file_contents(task_uid), overwrite=False)

    with mock.patch.object(
        compressed_module_resolver.HttpCompressedFileResolver,
        "_lock_file_timeout_sec",
        return_value=10):
        http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
        # After seeing the lock file is abandoned, this resolver will download
        # the module and return a path to the extracted contents.
        path = http_resolver(self.module_handle)

    files = os.listdir(path)
    self.assertListEqual(sorted(files), ["file1", "file2", "file3"])
    # The stale lock must have been cleaned up by the successful resolution.
    self.assertFalse(tf.compat.v1.gfile.Exists(lock_filename))
def _install_default_resolvers():
    """Registers the default resolver implementations on `registry.resolver`.

    The implementations are added in this order: `resolver.PathResolver`,
    `compressed_module_resolver.GcsCompressedFileResolver`, and
    `compressed_module_resolver.HttpCompressedFileResolver`.
    """
    default_impls = (
        resolver.PathResolver(),
        compressed_module_resolver.GcsCompressedFileResolver(),
        compressed_module_resolver.HttpCompressedFileResolver(),
    )
    for implementation in default_impls:
        registry.resolver.add_implementation(implementation)
def _get_default_resolvers():
    """Builds a new list of default resolver instances.

    Returns:
      A list holding, in order: `resolver.FailResolver`,
      `resolver.PathResolver`, `GcsCompressedFileResolver` and
      `HttpCompressedFileResolver` instances.
    """
    default_resolvers = [resolver.FailResolver(), resolver.PathResolver()]
    default_resolvers.append(
        compressed_module_resolver.GcsCompressedFileResolver())
    default_resolvers.append(
        compressed_module_resolver.HttpCompressedFileResolver())
    return default_resolvers
def testAppendFormatQuery(self):
    """Checks that `tf-hub-format=compressed` is appended to a handle's query.

    Existing query parameters, including a pre-existing `tf-hub-format`, are
    kept, and the compressed-format parameter is appended after them.
    """
    tests = [
        (
            "https://example.com/module.tar.gz",
            "https://example.com/module.tar.gz?tf-hub-format=compressed",
        ),
        (
            "https://example.com/module",
            "https://example.com/module?tf-hub-format=compressed",
        ),
        (
            "https://example.com/module?extra=abc",
            "https://example.com/module?extra=abc&tf-hub-format=compressed",
        ),
        (
            "https://example.com/module?extra=abc&tf-hub-format=test",
            ("https://example.com/module?extra=abc&"
             "tf-hub-format=test&tf-hub-format=compressed"),
        ),
    ]
    http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
    for handle, expected in tests:
        with self.subTest(handle=handle):
            # assertEqual, not assertTrue: assertTrue(x, msg) only checks that
            # x is truthy and uses its second argument as the failure message.
            self.assertEqual(
                http_resolver._append_compressed_format_query(handle), expected)
def testNoCacheDirSet(self):
    """With an empty cache-dir flag, the module path lands under the temp dir."""
    FLAGS.tfhub_cache_dir = ""
    resolve = compressed_module_resolver.HttpCompressedFileResolver()
    module_url = "http://localhost:%d/mock_module.tar.gz" % self.server_port
    module_path = resolve(module_url)
    self.assertListEqual(
        sorted(os.listdir(module_path)), ["file1", "file2", "file3"])
    self.assertStartsWith(module_path, tempfile.gettempdir())
def testModuleDescriptor(self):
    """Checks the contents of the descriptor file for a resolved module.

    The descriptor must name the module handle, contain a download time line,
    and record this host's name and process ID.
    """
    FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")
    http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
    path = http_resolver(self.module_handle)
    desc = tf_utils.read_file_to_string(resolver._module_descriptor_file(path))
    # Adjacent literals are concatenated first, then `%` formats the whole
    # pattern; handle and hostname are escaped so they match literally.
    expected_pattern = ("Module: %s\n"
                        "Download Time: .*\n"
                        "Downloader Hostname: %s .PID:%d." %
                        (re.escape(self.module_handle),
                         re.escape(socket.gethostname()), os.getpid()))
    # `assertRegexpMatches` is a deprecated alias that was removed in
    # Python 3.12; `assertRegex` is the supported name.
    self.assertRegex(desc, expected_pattern)
def testGetModulePathTarGz_withEnvVariable(self, env_value, expected_mode):
    """Checks that the cert-validation env variable drives `verify_mode`.

    While `resolver._TFHUB_DISABLE_CERT_VALIDATION` is set to `env_value` in
    the environment, a freshly built resolver's `_context.verify_mode` must
    equal `expected_mode`, and the module must still resolve to its files.
    """
    FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")
    patched_env = {resolver._TFHUB_DISABLE_CERT_VALIDATION: env_value}
    with unittest.mock.patch.dict(os.environ, patched_env):
        http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
        module_path = http_resolver(self.module_handle)
        self.assertEqual(http_resolver._context.verify_mode, expected_mode)
        self.assertCountEqual(
            os.listdir(module_path), ["file1", "file2", "file3"])
def testModuleAlreadyDownloaded(self):
    """Checks that resolving an already-cached module does not re-download it.

    The `mtime_nsec` of each module file is recorded after the first
    resolution and must be unchanged after resolving the same handle again.
    """
    FLAGS.tfhub_cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")
    http_resolver = compressed_module_resolver.HttpCompressedFileResolver()

    def _mtimes(module_path, file_names):
        # `tf.gfile` only exists in TF1; `tf.compat.v1.gfile` is also
        # available under TF2.
        return [
            tf.compat.v1.gfile.Stat(os.path.join(module_path, name)).mtime_nsec
            for name in file_names
        ]

    path = http_resolver(self.module_handle)
    files = sorted(os.listdir(path))
    self.assertListEqual(files, ["file1", "file2", "file3"])
    creation_times = _mtimes(path, files)

    # Call resolver again and make sure that the module is not downloaded again
    # by checking the timestamps of the module files.
    path = http_resolver(self.module_handle)
    files = sorted(os.listdir(path))
    self.assertListEqual(files, ["file1", "file2", "file3"])
    self.assertListEqual(creation_times, _mtimes(path, files))
def testCorruptedArchive(self):
    """Checks that a corrupted archive fails to resolve with an IOError.

    A text file is written as `bad_archive.tar.gz` in the current directory
    and requested from both the main server and the redirect server. Each
    request must raise IOError whose message names the requested URL.
    """
    # NOTE(review): assumes the test HTTP servers serve the current working
    # directory (configured outside this block) — confirm in setUp.
    with tf.compat.v1.gfile.GFile("bad_archive.tar.gz", mode="w") as f:
        f.write("bad_archive")
    http_resolver = compressed_module_resolver.HttpCompressedFileResolver()
    failure_msg = "Corrupted archive should have failed to resolve."

    with self.assertRaises(IOError, msg=failure_msg) as direct_ctx:
        http_resolver(
            "http://localhost:%d/bad_archive.tar.gz" % self.server_port)
    self.assertEqual(
        "http://localhost:%d/bad_archive.tar.gz does not appear "
        "to be a valid module." % self.server_port, str(direct_ctx.exception))

    with self.assertRaises(IOError, msg=failure_msg) as redirect_ctx:
        http_resolver(
            "http://localhost:%d/bad_archive.tar.gz" % self.redirect_server_port)
    # NOTE(review): a previous comment said this message should contain the
    # ultimate (redirected-to) URL, yet the expected text uses
    # `self.redirect_server_port`, i.e. the URL that was requested. Confirm
    # which behavior is intended.
    self.assertEqual(
        "http://localhost:%d/bad_archive.tar.gz does not appear "
        "to be a valid module." % self.redirect_server_port,
        str(redirect_ctx.exception))
def testGetModulePathTarGz(self):
    """Resolving `self.module_handle` yields a directory of the module files."""
    cache_dir = os.path.join(self.get_temp_dir(), "cache_dir")
    FLAGS.tfhub_cache_dir = cache_dir
    resolve = compressed_module_resolver.HttpCompressedFileResolver()
    module_path = resolve(self.module_handle)
    listed_files = sorted(os.listdir(module_path))
    self.assertListEqual(listed_files, ["file1", "file2", "file3"])