示例#1
0
  def testModuleDownloadPermissionDenied(self):
    """Downloading below a non-writable directory must raise PermissionDenied."""
    # Mode 0o500 lets the owner read and traverse the directory but not create
    # entries in it.
    locked_parent = os.path.join(self.get_temp_dir(), "readonly")
    os.mkdir(locked_parent, 0o500)
    target_dir = os.path.join(locked_parent, "module")

    def unused_download_fn(handle, tmp_dir):
      # Reaching this function means the lock file was written after all.
      del handle, tmp_dir
      self.fail("This should not be called. Already writing the lockfile "
                "is expected to raise an error.")

    expected_error = tf.errors.PermissionDeniedError
    with self.assertRaises(expected_error):
      resolver.atomic_download("module", unused_download_fn, target_dir)
示例#2
0
    def testModuleAlreadyDownloaded(self):
        """Checks the outcome when the module directory appears mid-download.

        The fake download function plays a rogue process: it ignores the temp
        directory and creates the final module directory itself, so that
        directory already exists when the download function returns. A second
        download attempt must then succeed without writing anything.
        """
        gfile = tf.compat.v1.gfile
        module_dir = os.path.join(self.get_temp_dir(), "module")

        def fake_download_fn_with_rogue_behavior(handle, tmp_dir):
            del handle, tmp_dir
            # Populate the final module directory directly.
            gfile.MakeDirs(module_dir)
            rogue_file = os.path.join(module_dir, "file")
            # Looked up through tf_utils at call time so that the patch
            # installed further down would intercept this call.
            tf_utils.atomic_write_string_to_file(rogue_file, "content", False)

        def download_and_verify():
            # One download attempt plus the checks shared by both phases.
            resolved_dir = resolver.atomic_download(
                "module", fake_download_fn_with_rogue_behavior, module_dir)
            self.assertEqual(module_dir, resolved_dir)
            self.assertEqual(gfile.ListDirectory(module_dir), ["file"])
            lock_path = resolver._lock_filename(module_dir)
            self.assertFalse(gfile.Exists(lock_path))

        download_and_verify()

        # The parent directory holds the module and its descriptor only.
        parent_dir = os.path.abspath(os.path.join(module_dir, ".."))
        siblings = sorted(gfile.ListDirectory(parent_dir))
        self.assertEqual(siblings, ["module", "module.descriptor.txt"])

        descriptor_path = resolver._module_descriptor_file(module_dir)
        descriptor_text = tf_utils.read_file_to_string(descriptor_path)
        descriptor_regex = (
            "Module: module\n"
            "Download Time: .*\n"
            "Downloader Hostname: %s .PID:%d." %
            (re.escape(socket.gethostname()), os.getpid()))
        self.assertRegexpMatches(descriptor_text, descriptor_regex)

        # Second attempt: tf_utils.atomic_write_string_to_file() is patched to
        # raise. Since the module is already downloaded, the function is
        # expected never to be called, so the download still succeeds.
        never_raised = ValueError("This error should never be raised!")
        with mock.patch.object(
                tf_utils,
                "atomic_write_string_to_file",
                side_effect=never_raised):
            download_and_verify()
    def __call__(self, handle):
        """Resolves `handle` through `resolver.atomic_download`.

        Args:
          handle: Module URL; it is also used to derive the local module
            directory via `_module_dir`.

        Returns:
          The result of `resolver.atomic_download` for this handle.
        """
        module_dir = _module_dir(handle)

        def download(handle, tmp_dir):
            """Fetch a module via HTTP(S), handling redirect and download headers."""
            # One-element list instead of a plain variable: assigning to a
            # plain name inside redirect_request() would only create a local
            # there and leave the value seen by download() untouched.
            final_url = [handle]
            request = url.Request(_append_compressed_format_query(handle))

            # Track redirects so that the URL handed to DownloadManager is
            # the one the module was finally served from. This allows
            # publishers (if they choose) to provide the same URL for both a
            # module download and its documentation.

            class LoggingHTTPRedirectHandler(url.HTTPRedirectHandler):
                def redirect_request(self, req, fp, code, msg, headers,
                                     newurl):
                    new_request = url.HTTPRedirectHandler.redirect_request(
                        self, req, fp, code, msg, headers, newurl)
                    # Only record redirects that are actually followed.
                    if new_request is not None:
                        final_url[0] = newurl
                    return new_request

            url_opener = url.build_opener(LoggingHTTPRedirectHandler)
            response = url_opener.open(request)
            return resolver.DownloadManager(
                final_url[0]).download_and_uncompress(response, tmp_dir)

        return resolver.atomic_download(handle, download, module_dir,
                                        self._lock_file_timeout_sec())
示例#4
0
  def testModuleConcurrentDownload(self):
    """Checks that an overlapping second download reuses the first one.

    The second resolver.atomic_download() call runs on a thread that is
    started from inside the first call's download function, i.e. while the
    first download is still in progress. The second call's own download
    function is not expected to run.
    """
    module_dir = os.path.join(self.get_temp_dir(), "module")

    # An assertion failure raised on a worker thread does not fail the test,
    # so unexpected calls are recorded here and checked on the main thread.
    second_download_calls = []

    def second_download_fn(handle, tmp_dir):
      del tmp_dir
      second_download_calls.append(handle)
      self.fail("This should not be called. The module should have been "
                "downloaded already.")

    second_download_thread = threading.Thread(
        target=resolver.atomic_download,
        args=(
            "module",
            second_download_fn,
            module_dir,
        ))

    def first_download_fn(handle, tmp_dir):
      del handle, tmp_dir
      tf_v1.gfile.MakeDirs(module_dir)
      tf_utils.atomic_write_string_to_file(
          os.path.join(module_dir, "file"), "content", False)
      # Kick off the competing download while this one is still running.
      second_download_thread.start()

    self.assertEqual(module_dir,
                     resolver.atomic_download("module", first_download_fn,
                                              module_dir))
    second_download_thread.join(30)
    # join() returns silently on timeout, so check the thread really ended.
    self.assertFalse(second_download_thread.is_alive(),
                     "The second download did not finish within 30 seconds.")
    self.assertEqual([], second_download_calls,
                     "The second download function should not have run.")
示例#5
0
  def testModuleAlreadyDownloaded(self):
    """Checks the outcome when the module directory appears mid-download.

    The fake download function plays a rogue process: it ignores the temp
    directory and creates the final module directory itself, so that directory
    already exists when the download function returns.
    """
    module_dir = os.path.join(self.get_temp_dir(), "module")

    def fake_download_fn_with_rogue_behavior(handle, tmp_dir):
      del handle, tmp_dir
      # Populate the final module directory directly.
      tf_v1.gfile.MakeDirs(module_dir)
      payload_path = os.path.join(module_dir, "file")
      tf_utils.atomic_write_string_to_file(payload_path, "content", False)

    resolved_dir = resolver.atomic_download(
        "module", fake_download_fn_with_rogue_behavior, module_dir)
    self.assertEqual(module_dir, resolved_dir)

    # Only the rogue file is present and no lock file is left.
    self.assertEqual(tf_v1.gfile.ListDirectory(module_dir), ["file"])
    lock_path = resolver._lock_filename(module_dir)
    self.assertFalse(tf_v1.gfile.Exists(lock_path))

    # The parent directory holds the module and its descriptor only.
    parent_dir = os.path.abspath(os.path.join(module_dir, ".."))
    siblings = sorted(tf_v1.gfile.ListDirectory(parent_dir))
    self.assertEqual(siblings, ["module", "module.descriptor.txt"])

    descriptor_path = resolver._module_descriptor_file(module_dir)
    descriptor_text = tf_utils.read_file_to_string(descriptor_path)
    descriptor_regex = (
        "Module: module\n"
        "Download Time: .*\n"
        "Downloader Hostname: %s .PID:%d." % (re.escape(socket.gethostname()),
                                              os.getpid()))
    self.assertRegexpMatches(descriptor_text, descriptor_regex)
示例#6
0
    def testModuleDownloadedWhenEmptyFolderExists(self):
        """Checks that an empty module directory does not count as a download.

        Simulates a cached module whose directory still exists but whose
        files were deleted; the download is expected to run anyway.
        """
        gfile = tf.compat.v1.gfile
        module_dir = os.path.join(self.get_temp_dir(), "module")

        def fake_download_fn(handle, tmp_dir):
            del handle, tmp_dir
            gfile.MakeDirs(module_dir)
            payload_path = os.path.join(module_dir, "file")
            tf_utils.atomic_write_string_to_file(payload_path, "content",
                                                 False)

        # Start from a module directory that exists but is empty.
        self.assertFalse(gfile.Exists(module_dir))
        gfile.MakeDirs(module_dir)

        resolved_dir = resolver.atomic_download("module", fake_download_fn,
                                                module_dir)
        self.assertEqual(module_dir, resolved_dir)

        # The download function ran: its file is there, the lock is not.
        self.assertEqual(gfile.ListDirectory(module_dir), ["file"])
        lock_path = resolver._lock_filename(module_dir)
        self.assertFalse(gfile.Exists(lock_path))

        # The parent directory holds the module and its descriptor only.
        parent_dir = os.path.abspath(os.path.join(module_dir, ".."))
        siblings = sorted(gfile.ListDirectory(parent_dir))
        self.assertEqual(siblings, ["module", "module.descriptor.txt"])

        descriptor_path = resolver._module_descriptor_file(module_dir)
        descriptor_text = tf_utils.read_file_to_string(descriptor_path)
        descriptor_regex = (
            "Module: module\n"
            "Download Time: .*\n"
            "Downloader Hostname: %s .PID:%d." %
            (re.escape(socket.gethostname()), os.getpid()))
        self.assertRegexpMatches(descriptor_text, descriptor_regex)
示例#7
0
  def testModuleConcurrentDownload(self):
    """Checks that an overlapping second download reuses the first one.

    The second resolver.atomic_download() call runs on a thread that is
    started from inside the first call's download function, i.e. while the
    first download is still in progress. The second call's own download
    function is not expected to run.
    """
    module_dir = os.path.join(self.get_temp_dir(), "module")

    # An assertion failure raised on a worker thread does not fail the test,
    # so unexpected calls are recorded here and checked on the main thread.
    second_download_calls = []

    def second_download_fn(handle, tmp_dir):
      del tmp_dir
      second_download_calls.append(handle)
      self.fail("This should not be called. The module should have been "
                "downloaded already.")

    second_download_thread = threading.Thread(
        target=resolver.atomic_download,
        args=(
            "module",
            second_download_fn,
            module_dir,
        ))

    def first_download_fn(handle, tmp_dir):
      del handle, tmp_dir
      tf.gfile.MakeDirs(module_dir)
      tf_utils.atomic_write_string_to_file(
          os.path.join(module_dir, "file"), "content", False)
      # Kick off the competing download while this one is still running.
      second_download_thread.start()

    self.assertEqual(module_dir,
                     resolver.atomic_download("module", first_download_fn,
                                              module_dir))
    second_download_thread.join(30)
    # join() returns silently on timeout, so check the thread really ended.
    self.assertFalse(second_download_thread.is_alive(),
                     "The second download did not finish within 30 seconds.")
    self.assertEqual([], second_download_calls,
                     "The second download function should not have run.")
示例#8
0
  def testModuleAlreadyDownloaded(self):
    """Covers a module directory that shows up while a download is running.

    The download function used here behaves like a rogue process: it creates
    the final module directory on its own instead of using the temp directory
    it was given, so the directory exists before the download call finishes.
    """
    final_dir = os.path.join(self.get_temp_dir(), "module")

    def fake_download_fn_with_rogue_behavior(handle, tmp_dir):
      del handle, tmp_dir
      # Write straight into the final location.
      tf.gfile.MakeDirs(final_dir)
      tf_utils.atomic_write_string_to_file(
          os.path.join(final_dir, "file"), "content", False)

    returned_dir = resolver.atomic_download(
        "module", fake_download_fn_with_rogue_behavior, final_dir)
    self.assertEqual(final_dir, returned_dir)

    module_entries = tf.gfile.ListDirectory(final_dir)
    self.assertEqual(module_entries, ["file"])
    lock_file = resolver._lock_filename(final_dir)
    self.assertFalse(tf.gfile.Exists(lock_file))

    # Next to the module there must be its descriptor file and nothing else.
    containing_dir = os.path.abspath(os.path.join(final_dir, ".."))
    self.assertEqual(
        sorted(tf.gfile.ListDirectory(containing_dir)),
        ["module", "module.descriptor.txt"])

    descriptor = tf_utils.read_file_to_string(
        resolver._module_descriptor_file(final_dir))
    hostname_regex = re.escape(socket.gethostname())
    expected_regex = (
        "Module: module\n"
        "Download Time: .*\n"
        "Downloader Hostname: %s .PID:%d." % (hostname_regex, os.getpid()))
    self.assertRegexpMatches(descriptor, expected_regex)
    def __call__(self, handle):
        """Resolves `handle` through `resolver.atomic_download`.

        Args:
          handle: Path of the module file; it is opened with `tf.gfile.GFile`
            and also used to derive the local module directory.

        Returns:
          The result of `resolver.atomic_download` for this handle.
        """
        module_dir = _module_dir(handle)

        def download(handle, tmp_dir):
            """Opens `handle` and passes it to `download_and_uncompress`."""
            # Binary mode, because the file object is handed to
            # download_and_uncompress(); text mode would try to decode the
            # bytes into a string on Python 3.
            # NOTE(review): assumes the payload is a binary archive, as the
            # callee's name suggests - confirm.
            # The `with` block closes the file even if extraction fails.
            with tf.gfile.GFile(handle, "rb") as fileobj:
                return resolver.DownloadManager(
                    handle).download_and_uncompress(fileobj, tmp_dir)

        return resolver.atomic_download(handle, download, module_dir,
                                        LOCK_FILE_TIMEOUT_SEC)
示例#10
0
  def testModuleLockLostDownloadKilled(self):
    """Checks cleanup after a download dies having lost its lock file."""
    module_dir = os.path.join(self.get_temp_dir(), "module")
    abort_message = "Download aborted."

    def kill_download(handle, tmp_dir):
      del handle, tmp_dir
      # Deleting the lock file stands in for losing the lock.
      lock_path = resolver._lock_filename(module_dir)
      tf_v1.gfile.Remove(lock_path)
      # The raised error stands in for the download being aborted.
      raise OSError(abort_message)

    try:
      resolver.atomic_download("module", kill_download, module_dir)
    except OSError:
      pass
    else:
      self.fail("atomic_download() should have thrown an exception.")

    # Nothing may be left behind in the directory that would hold the module.
    parent_dir = os.path.abspath(os.path.join(module_dir, ".."))
    leftovers = tf_v1.gfile.ListDirectory(parent_dir)
    self.assertEqual(leftovers, [])
示例#11
0
    def _get_module_path(self, handle):
        """Resolves `handle` through `resolver.atomic_download`.

        Args:
          handle: Path of the module file; it is opened with `tf.gfile.GFile`
            and, together with `self._cache_dir`, used to derive the local
            module directory.

        Returns:
          The result of `resolver.atomic_download` for this handle.
        """
        module_dir = _module_dir(self._cache_dir, handle)

        def download(handle, tmp_dir):
            """Opens `handle` and passes it to `download_and_uncompress`."""
            # Binary mode, because the file object is handed to
            # download_and_uncompress(); text mode would try to decode the
            # bytes into a string on Python 3.
            # NOTE(review): assumes the payload is a binary archive, as the
            # callee's name suggests - confirm.
            # The `with` block closes the file even if extraction fails.
            with tf.gfile.GFile(handle, "rb") as fileobj:
                return resolver.download_and_uncompress(
                    handle, fileobj, tmp_dir)

        return resolver.atomic_download(handle, download, module_dir,
                                        LOCK_FILE_TIMEOUT_SEC)
示例#12
0
  def testModuleLockLostDownloadKilled(self):
    """A download that loses its lock and then fails must leave no files."""
    target_dir = os.path.join(self.get_temp_dir(), "module")
    containing_dir = os.path.abspath(os.path.join(target_dir, ".."))
    failure_text = "Download aborted."

    def kill_download(handle, tmp_dir):
      del handle, tmp_dir
      # Removing the lock file simulates the lock being lost; the error
      # raised right after simulates the download being aborted.
      tf.gfile.Remove(resolver._lock_filename(target_dir))
      raise OSError(failure_text)

    try:
      resolver.atomic_download("module", kill_download, target_dir)
    except OSError:
      pass
    else:
      self.fail("atomic_download() should have thrown an exception.")

    # All files are expected to have been cleaned up.
    remaining = tf.gfile.ListDirectory(containing_dir)
    self.assertEqual(remaining, [])
示例#13
0
  def testNotFoundGCSBucket(self):
    """Checks that NotFoundError from writing a file is not swallowed."""
    # When trying to use a non-existent GCS bucket,
    # tf_utils.atomic_write_string_to_file raises tf.errors.NotFoundError, and
    # resolver.atomic_download is expected to let it propagate. Other errors
    # that may arise from bad network connectivity are ignored by
    # resolver.atomic_download and retried infinitely.
    module_dir = ""

    def dummy_download_fn(handle, tmp_dir):
      del handle, tmp_dir

    # Simulate missing GCS bucket by raising NotFoundError in
    # atomic_write_string_to_file.
    with mock.patch(
        "tensorflow_hub.tf_utils.atomic_write_string_to_file") as mock_:
      mock_.side_effect = tf.errors.NotFoundError(None, None, "Test")
      try:
        resolver.atomic_download("module", dummy_download_fn, module_dir)
      except tf.errors.NotFoundError as e:
        self.assertEqual("Test", e.message)
      else:
        # self.fail() instead of a bare `assert False`: it still fires when
        # assertions are stripped (python -O) and it says what went wrong.
        self.fail("atomic_download() should have raised NotFoundError.")
    def __call__(self, handle):
        """Resolves `handle` through `resolver.atomic_download`.

        Args:
          handle: Module URL; it is also used to derive the local module
            directory via `_module_dir`.

        Returns:
          The result of `resolver.atomic_download` for this handle.
        """
        def download(handle, tmp_dir):
            """Requests the module via `self._call_urlopen` and unpacks it."""
            compressed_url = _append_compressed_format_query(handle)
            response = self._call_urlopen(url.Request(compressed_url))
            manager = resolver.DownloadManager(handle)
            return manager.download_and_uncompress(response, tmp_dir)

        target_dir = _module_dir(handle)
        lock_timeout_sec = self._lock_file_timeout_sec()
        return resolver.atomic_download(handle, download, target_dir,
                                        lock_timeout_sec)
    def __call__(self, handle):
        """Resolves `handle` through `resolver.atomic_download`, via a proxy.

        Args:
          handle: Module URL; it is also used to derive the local module
            directory via `_module_dir`.

        Returns:
          The result of `resolver.atomic_download` for this handle.
        """
        module_dir = _module_dir(handle)

        def download(handle, tmp_dir):
            """Fetch a module via HTTP(S), handling redirect and download headers."""
            # One-element list instead of a plain variable: assigning to a
            # plain name inside redirect_request() would only create a local
            # there and leave the value seen by download() untouched.
            final_url = [handle]
            request = url.Request(_append_compressed_format_query(handle))

            # Track redirects so that the URL handed to DownloadManager is
            # the one the module was finally served from. This allows
            # publishers (if they choose) to provide the same URL for both a
            # module download and its documentation.

            class LoggingHTTPRedirectHandler(url.HTTPRedirectHandler):
                def redirect_request(self, req, fp, code, msg, headers,
                                     newurl):
                    new_request = url.HTTPRedirectHandler.redirect_request(
                        self, req, fp, code, msg, headers, newurl)
                    # Only record redirects that are actually followed.
                    if new_request is not None:
                        final_url[0] = newurl
                    return new_request

            # Constructed without arguments, so the proxy settings are not
            # hard-coded here, i.e. there is no need for
            # ProxyHandler({'https': 'http://localhost:3128'}).
            # If the proxy needs credentials, an auth handler with
            # add_password(...) would have to be added to the opener as well.
            proxy_handler = url.ProxyHandler()

            # SECURITY NOTE(review): certificate and hostname verification
            # are switched off as a hack, so the downloaded module is not
            # protected against tampering in transit. Behaviour is left as
            # is; remove these two overrides once the proxy's CA is trusted.
            ctx = ssl.create_default_context()
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE
            https_handler = url.HTTPSHandler(context=ctx)
            # Enable more logging. If you don't see any logging, the traffic
            # doesn't hit https yet.
            https_handler.set_http_debuglevel(1)

            # The redirect handler has to be passed to the opener, otherwise
            # it is never invoked and final_url never changes.
            url_opener = url.build_opener(proxy_handler, https_handler,
                                          LoggingHTTPRedirectHandler)
            response = url_opener.open(request)
            return resolver.DownloadManager(
                final_url[0]).download_and_uncompress(response, tmp_dir)

        return resolver.atomic_download(handle, download, module_dir,
                                        self._lock_file_timeout_sec())