Exemplo n.º 1
0
class FetcherTest(mox.MoxTestBase):
    """Record/replay tests for ``Fetcher`` backed by a mocked ``requests`` API.

    Each test records the calls it expects on the mocks, switches them to replay
    mode with ``self.mox.ReplayAll()``, then exercises the fetcher. Calls on a
    given mock are recorded in the order the fetcher is expected to make them,
    so keep that order when editing expectations.
    """

    def setUp(self):
        super(FetcherTest, self).setUp()

        # Stand-in for the ``requests`` API handed to the Fetcher; tests record ``get`` on it.
        self.requests = self.mox.CreateMockAnything()
        self.response = self.mox.CreateMock(requests.Response)
        self.fetcher = Fetcher(requests_api=self.requests)
        self.listener = self.mox.CreateMock(Fetcher.Listener)

    def expect_get(self, url, chunk_size_bytes, timeout_secs, listener=True):
        """Record a successful streamed GET whose body is 11 bytes in two chunks.

        :param url: URL the fetcher is expected to request.
        :param chunk_size_bytes: ``chunk_size`` expected on ``response.iter_content``.
        :param timeout_secs: ``timeout`` expected on ``requests.get``.
        :param listener: if True, also record ``listener.status(200, content_length=11)``.
        :returns: the chunks the mocked response yields, so callers can record
            per-chunk expectations.
        """
        self.requests.get(url, stream=True, timeout=timeout_secs).AndReturn(self.response)
        self.response.status_code = 200
        self.response.headers = {"content-length": "11"}
        if listener:
            self.listener.status(200, content_length=11)

        chunks = ["0123456789", "a"]
        self.response.iter_content(chunk_size=chunk_size_bytes).AndReturn(chunks)
        return chunks

    def test_get(self):
        """A plain fetch forwards every chunk to the listener, then finishes and closes."""
        for chunk in self.expect_get("http://bar", chunk_size_bytes=1024, timeout_secs=60):
            self.listener.recv_chunk(chunk)
        self.listener.finished()
        self.response.close()

        self.mox.ReplayAll()

        # 1 KB / 1 minute must translate into the recorded 1024-byte chunks and 60s timeout.
        self.fetcher.fetch("http://bar", self.listener, chunk_size=Amount(1, Data.KB), timeout=Amount(1, Time.MINUTES))

    def test_checksum_listener(self):
        """ChecksumListener feeds each chunk to the digest and exposes its hexdigest."""
        digest = self.mox.CreateMockAnything()
        for chunk in self.expect_get("http://baz", chunk_size_bytes=1, timeout_secs=37):
            self.listener.recv_chunk(chunk)
            digest.update(chunk)

        self.listener.finished()
        digest.hexdigest().AndReturn("42")

        self.response.close()

        self.mox.ReplayAll()

        checksum_listener = Fetcher.ChecksumListener(digest=digest)
        self.fetcher.fetch(
            "http://baz",
            checksum_listener.wrap(self.listener),
            chunk_size=Amount(1, Data.BYTES),
            timeout=Amount(37, Time.SECONDS),
        )
        self.assertEqual("42", checksum_listener.checksum)

    def test_download_listener(self):
        """DownloadListener writes the fetched body to the given file object."""
        downloaded = ""
        for chunk in self.expect_get("http://foo", chunk_size_bytes=1048576, timeout_secs=3600):
            self.listener.recv_chunk(chunk)
            downloaded += chunk

        self.listener.finished()
        self.response.close()

        self.mox.ReplayAll()

        with closing(Compatibility.StringIO()) as fp:
            self.fetcher.fetch(
                "http://foo",
                Fetcher.DownloadListener(fp).wrap(self.listener),
                chunk_size=Amount(1, Data.MB),
                timeout=Amount(1, Time.HOURS),
            )
            self.assertEqual(downloaded, fp.getvalue())

    def test_size_mismatch(self):
        """A body shorter than its content-length raises ``Fetcher.Error``."""
        self.requests.get("http://foo", stream=True, timeout=60).AndReturn(self.response)
        self.response.status_code = 200
        self.response.headers = {"content-length": "11"}
        self.listener.status(200, content_length=11)

        # Only 2 of the advertised 11 bytes arrive; no ``finished()`` is recorded.
        self.response.iter_content(chunk_size=1024).AndReturn(["a", "b"])
        self.listener.recv_chunk("a")
        self.listener.recv_chunk("b")

        self.response.close()

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.Error):
            self.fetcher.fetch(
                "http://foo", self.listener, chunk_size=Amount(1, Data.KB), timeout=Amount(1, Time.MINUTES)
            )

    def test_get_error_transient(self):
        """A connection error from ``requests.get`` surfaces as ``TransientError``."""
        self.requests.get("http://foo", stream=True, timeout=60).AndRaise(requests.ConnectionError)

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.TransientError):
            self.fetcher.fetch(
                "http://foo", self.listener, chunk_size=Amount(1, Data.KB), timeout=Amount(1, Time.MINUTES)
            )

    def test_get_error_permanent(self):
        """Too many redirects surfaces as ``PermanentError`` with no response code."""
        self.requests.get("http://foo", stream=True, timeout=60).AndRaise(requests.TooManyRedirects)

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.PermanentError) as e:
            self.fetcher.fetch(
                "http://foo", self.listener, chunk_size=Amount(1, Data.KB), timeout=Amount(1, Time.MINUTES)
            )
        self.assertIsNone(e.value.response_code)

    def test_http_error(self):
        """An HTTP 404 is reported to the listener and raised as ``PermanentError``."""
        self.requests.get("http://foo", stream=True, timeout=60).AndReturn(self.response)
        self.response.status_code = 404
        self.listener.status(404)

        self.response.close()

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.PermanentError) as e:
            self.fetcher.fetch(
                "http://foo", self.listener, chunk_size=Amount(1, Data.KB), timeout=Amount(1, Time.MINUTES)
            )
        self.assertEqual(404, e.value.response_code)

    def test_iter_content_error(self):
        """A timeout while streaming the body raises ``TransientError`` and still closes."""
        self.requests.get("http://foo", stream=True, timeout=60).AndReturn(self.response)
        self.response.status_code = 200
        # No content-length header, so the listener is told ``content_length=None``.
        self.response.headers = {}
        self.listener.status(200, content_length=None)

        self.response.iter_content(chunk_size=1024).AndRaise(requests.Timeout)
        self.response.close()

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.TransientError):
            self.fetcher.fetch(
                "http://foo", self.listener, chunk_size=Amount(1, Data.KB), timeout=Amount(1, Time.MINUTES)
            )

    def expect_download(self, path_or_fd=None):
        """Record and replay a download of ``http://1``.

        No listener is passed to ``download()``, so no listener calls are recorded.

        :param path_or_fd: forwarded to ``Fetcher.download``.
        :returns: ``(expected_content, path)`` where ``path`` is what download returned.
        """
        chunks = self.expect_get("http://1", chunk_size_bytes=13, timeout_secs=13, listener=False)
        downloaded = "".join(chunks)
        self.response.close()

        self.mox.ReplayAll()

        path = self.fetcher.download(
            "http://1", path_or_fd=path_or_fd, chunk_size=Amount(13, Data.BYTES), timeout=Amount(13, Time.SECONDS)
        )
        return downloaded, path

    def test_download(self):
        """Without a destination, download returns a path the caller must remove."""
        downloaded, path = self.expect_download()
        try:
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())
        finally:
            os.unlink(path)

    def test_download_fd(self):
        """Downloading into an open file returns that file's name."""
        with temporary_file() as fd:
            downloaded, path = self.expect_download(path_or_fd=fd)
            self.assertEqual(path, fd.name)
            fd.close()
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())

    def test_download_path(self):
        """Downloading into a path returns that same path."""
        with temporary_file() as fd:
            fd.close()
            downloaded, path = self.expect_download(path_or_fd=fd.name)
            self.assertEqual(path, fd.name)
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())
Exemplo n.º 2
0
class FetcherTest(mox.MoxTestBase):
    """Record/replay tests for ``Fetcher`` backed by a mocked ``requests`` API.

    Each test records the calls it expects on the mocks, switches them to
    replay mode with ``self.mox.ReplayAll()``, then exercises the fetcher.
    Calls on a given mock are recorded in the order the fetcher is expected
    to make them, so keep that order when editing expectations.
    """

    def setUp(self):
        super(FetcherTest, self).setUp()

        # Stand-in for the ``requests`` API; tests record ``get`` on it.
        self.requests = self.mox.CreateMockAnything()
        self.response = self.mox.CreateMock(requests.Response)
        self.fetcher = Fetcher(requests_api=self.requests)
        self.listener = self.mox.CreateMock(Fetcher.Listener)

    def expect_get(self, url, chunk_size_bytes, timeout_secs, listener=True):
        """Record a successful streamed GET whose body is 11 bytes in 2 chunks.

        :param url: URL the fetcher is expected to request.
        :param chunk_size_bytes: ``chunk_size`` expected on ``iter_content``.
        :param timeout_secs: ``timeout`` expected on ``requests.get``.
        :param listener: if True, also record
            ``listener.status(200, content_length=11)``.
        :returns: the chunks the mocked response yields.
        """
        self.requests.get(url, stream=True,
                          timeout=timeout_secs).AndReturn(self.response)
        self.response.status_code = 200
        self.response.headers = {'content-length': '11'}
        if listener:
            self.listener.status(200, content_length=11)

        chunks = ['0123456789', 'a']
        self.response.iter_content(
            chunk_size=chunk_size_bytes).AndReturn(chunks)
        return chunks

    def test_get(self):
        """A plain fetch forwards every chunk, then finishes and closes."""
        for chunk in self.expect_get('http://bar',
                                     chunk_size_bytes=1024,
                                     timeout_secs=60):
            self.listener.recv_chunk(chunk)
        self.listener.finished()
        self.response.close()

        self.mox.ReplayAll()

        # 1 KB / 1 minute must map to the recorded 1024 bytes / 60 seconds.
        self.fetcher.fetch('http://bar',
                           self.listener,
                           chunk_size=Amount(1, Data.KB),
                           timeout=Amount(1, Time.MINUTES))

    def test_checksum_listener(self):
        """ChecksumListener digests each chunk and exposes the hexdigest."""
        digest = self.mox.CreateMockAnything()
        for chunk in self.expect_get('http://baz',
                                     chunk_size_bytes=1,
                                     timeout_secs=37):
            self.listener.recv_chunk(chunk)
            digest.update(chunk)

        self.listener.finished()
        digest.hexdigest().AndReturn('42')

        self.response.close()

        self.mox.ReplayAll()

        checksum_listener = Fetcher.ChecksumListener(digest=digest)
        self.fetcher.fetch('http://baz',
                           checksum_listener.wrap(self.listener),
                           chunk_size=Amount(1, Data.BYTES),
                           timeout=Amount(37, Time.SECONDS))
        self.assertEqual('42', checksum_listener.checksum)

    def test_download_listener(self):
        """DownloadListener writes the fetched body to the given file."""
        downloaded = ''
        for chunk in self.expect_get('http://foo',
                                     chunk_size_bytes=1048576,
                                     timeout_secs=3600):
            self.listener.recv_chunk(chunk)
            downloaded += chunk

        self.listener.finished()
        self.response.close()

        self.mox.ReplayAll()

        with closing(Compatibility.StringIO()) as fp:
            self.fetcher.fetch('http://foo',
                               Fetcher.DownloadListener(fp).wrap(
                                   self.listener),
                               chunk_size=Amount(1, Data.MB),
                               timeout=Amount(1, Time.HOURS))
            self.assertEqual(downloaded, fp.getvalue())

    def test_size_mismatch(self):
        """A body shorter than its content-length raises ``Fetcher.Error``."""
        self.requests.get('http://foo', stream=True,
                          timeout=60).AndReturn(self.response)
        self.response.status_code = 200
        self.response.headers = {'content-length': '11'}
        self.listener.status(200, content_length=11)

        # Only 2 of the advertised 11 bytes arrive; no finished() recorded.
        self.response.iter_content(chunk_size=1024).AndReturn(['a', 'b'])
        self.listener.recv_chunk('a')
        self.listener.recv_chunk('b')

        self.response.close()

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.Error):
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size=Amount(1, Data.KB),
                               timeout=Amount(1, Time.MINUTES))

    def test_get_error_transient(self):
        """A connection error surfaces as ``TransientError``."""
        self.requests.get('http://foo', stream=True,
                          timeout=60).AndRaise(requests.ConnectionError)

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.TransientError):
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size=Amount(1, Data.KB),
                               timeout=Amount(1, Time.MINUTES))

    def test_get_error_permanent(self):
        """Too many redirects is a ``PermanentError`` with no response code."""
        self.requests.get('http://foo', stream=True,
                          timeout=60).AndRaise(requests.TooManyRedirects)

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.PermanentError) as e:
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size=Amount(1, Data.KB),
                               timeout=Amount(1, Time.MINUTES))
        self.assertIsNone(e.value.response_code)

    def test_http_error(self):
        """An HTTP 404 is reported and raised as ``PermanentError``."""
        self.requests.get('http://foo', stream=True,
                          timeout=60).AndReturn(self.response)
        self.response.status_code = 404
        self.listener.status(404)

        self.response.close()

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.PermanentError) as e:
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size=Amount(1, Data.KB),
                               timeout=Amount(1, Time.MINUTES))
        self.assertEqual(404, e.value.response_code)

    def test_iter_content_error(self):
        """A timeout mid-stream raises ``TransientError`` and still closes."""
        self.requests.get('http://foo', stream=True,
                          timeout=60).AndReturn(self.response)
        self.response.status_code = 200
        # No content-length header, so the listener gets content_length=None.
        self.response.headers = {}
        self.listener.status(200, content_length=None)

        self.response.iter_content(chunk_size=1024).AndRaise(requests.Timeout)
        self.response.close()

        self.mox.ReplayAll()

        with pytest.raises(self.fetcher.TransientError):
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size=Amount(1, Data.KB),
                               timeout=Amount(1, Time.MINUTES))

    def expect_download(self, path_or_fd=None):
        """Record and replay a download of ``http://1``.

        No listener is passed to ``download()``, so none is recorded.

        :param path_or_fd: forwarded to ``Fetcher.download``.
        :returns: ``(expected_content, path)``.
        """
        chunks = self.expect_get('http://1',
                                 chunk_size_bytes=13,
                                 timeout_secs=13,
                                 listener=False)
        downloaded = ''.join(chunks)
        self.response.close()

        self.mox.ReplayAll()

        path = self.fetcher.download('http://1',
                                     path_or_fd=path_or_fd,
                                     chunk_size=Amount(13, Data.BYTES),
                                     timeout=Amount(13, Time.SECONDS))
        return downloaded, path

    def test_download(self):
        """Without a destination, download returns a path to clean up."""
        downloaded, path = self.expect_download()
        try:
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())
        finally:
            os.unlink(path)

    def test_download_fd(self):
        """Downloading into an open file returns that file's name."""
        with temporary_file() as fd:
            downloaded, path = self.expect_download(path_or_fd=fd)
            self.assertEqual(path, fd.name)
            fd.close()
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())

    def test_download_path(self):
        """Downloading into a path returns that same path."""
        with temporary_file() as fd:
            fd.close()
            downloaded, path = self.expect_download(path_or_fd=fd.name)
            self.assertEqual(path, fd.name)
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())
Exemplo n.º 3
0
class FetcherTest(mox.MoxTestBase):
  """Record/replay tests for ``Fetcher`` backed by a mocked ``requests`` API.

  Each test records the calls it expects on the mocks, switches them to replay
  mode with ``self.mox.ReplayAll()``, then exercises the fetcher. Calls on a
  given mock are recorded in the order the fetcher is expected to make them.
  """

  def setUp(self):
    super(FetcherTest, self).setUp()

    # Stand-in for the ``requests`` API; tests record ``get`` on it.
    self.requests = self.mox.CreateMockAnything()
    self.response = self.mox.CreateMock(requests.Response)
    self.fetcher = Fetcher(requests_api=self.requests)
    self.listener = self.mox.CreateMock(Fetcher.Listener)

  def expect_get(self, url, chunk_size_bytes, timeout_secs, listener=True):
    """Record a successful streamed GET whose body is 11 bytes in two chunks.

    :param url: URL the fetcher is expected to request.
    :param chunk_size_bytes: ``chunk_size`` expected on ``iter_content``.
    :param timeout_secs: ``timeout`` expected on ``requests.get``.
    :param listener: if True, also record ``listener.status(200, content_length=11)``.
    :returns: the chunks the mocked response yields.
    """
    self.requests.get(url, allow_redirects=True, stream=True,
                      timeout=timeout_secs).AndReturn(self.response)
    self.response.status_code = 200
    self.response.headers = {'content-length': '11'}
    if listener:
      self.listener.status(200, content_length=11)

    chunks = ['0123456789', 'a']
    self.response.iter_content(chunk_size=chunk_size_bytes).AndReturn(chunks)
    return chunks

  def test_get(self):
    """A plain fetch forwards every chunk, then finishes and closes."""
    for chunk in self.expect_get('http://bar', chunk_size_bytes=1024, timeout_secs=60):
      self.listener.recv_chunk(chunk)
    self.listener.finished()
    self.response.close()

    self.mox.ReplayAll()

    self.fetcher.fetch('http://bar',
                       self.listener,
                       chunk_size_bytes=1024,
                       timeout_secs=60)

  def test_checksum_listener(self):
    """ChecksumListener digests each chunk and exposes the hexdigest."""
    digest = self.mox.CreateMockAnything()
    for chunk in self.expect_get('http://baz', chunk_size_bytes=1, timeout_secs=37):
      self.listener.recv_chunk(chunk)
      digest.update(chunk)

    self.listener.finished()
    digest.hexdigest().AndReturn('42')

    self.response.close()

    self.mox.ReplayAll()

    checksum_listener = Fetcher.ChecksumListener(digest=digest)
    self.fetcher.fetch('http://baz',
                       checksum_listener.wrap(self.listener),
                       chunk_size_bytes=1,
                       timeout_secs=37)
    self.assertEqual('42', checksum_listener.checksum)

  def test_download_listener(self):
    """DownloadListener writes the fetched body to the given file."""
    downloaded = ''
    for chunk in self.expect_get('http://foo', chunk_size_bytes=1048576, timeout_secs=3600):
      self.listener.recv_chunk(chunk)
      downloaded += chunk

    self.listener.finished()
    self.response.close()

    self.mox.ReplayAll()

    with closing(StringIO()) as fp:
      self.fetcher.fetch('http://foo',
                         Fetcher.DownloadListener(fp).wrap(self.listener),
                         chunk_size_bytes=1024 * 1024,
                         timeout_secs=60 * 60)
      self.assertEqual(downloaded, fp.getvalue())

  def test_size_mismatch(self):
    """A body shorter than its content-length raises ``Fetcher.Error``."""
    self.requests.get('http://foo', allow_redirects=True, stream=True,
                      timeout=60).AndReturn(self.response)
    self.response.status_code = 200
    self.response.headers = {'content-length': '11'}
    self.listener.status(200, content_length=11)

    # Only 2 of the advertised 11 bytes arrive; no finished() is recorded.
    self.response.iter_content(chunk_size=1024).AndReturn(['a', 'b'])
    self.listener.recv_chunk('a')
    self.listener.recv_chunk('b')

    self.response.close()

    self.mox.ReplayAll()

    with self.assertRaises(self.fetcher.Error):
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

  def test_get_error_transient(self):
    """A connection error surfaces as ``TransientError``."""
    self.requests.get('http://foo', allow_redirects=True, stream=True,
                      timeout=60).AndRaise(requests.ConnectionError)

    self.mox.ReplayAll()

    with self.assertRaises(self.fetcher.TransientError):
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

  def test_get_error_permanent(self):
    """Too many redirects is a ``PermanentError`` with no response code."""
    self.requests.get('http://foo', allow_redirects=True, stream=True,
                      timeout=60).AndRaise(requests.TooManyRedirects)

    self.mox.ReplayAll()

    with self.assertRaises(self.fetcher.PermanentError) as e:
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)
    self.assertIsNone(e.exception.response_code)

  def test_http_error(self):
    """An HTTP 404 is reported and raised as ``PermanentError``."""
    self.requests.get('http://foo', allow_redirects=True, stream=True,
                      timeout=60).AndReturn(self.response)
    self.response.status_code = 404
    self.listener.status(404)

    self.response.close()

    self.mox.ReplayAll()

    with self.assertRaises(self.fetcher.PermanentError) as e:
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)
    self.assertEqual(404, e.exception.response_code)

  def test_iter_content_error(self):
    """A timeout mid-stream raises ``TransientError`` and still closes."""
    self.requests.get('http://foo', allow_redirects=True, stream=True,
                      timeout=60).AndReturn(self.response)
    self.response.status_code = 200
    # No content-length header, so the listener gets content_length=None.
    self.response.headers = {}
    self.listener.status(200, content_length=None)

    self.response.iter_content(chunk_size=1024).AndRaise(requests.Timeout)
    self.response.close()

    self.mox.ReplayAll()

    with self.assertRaises(self.fetcher.TransientError):
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

  def expect_download(self, path_or_fd=None):
    """Record and replay a download of ``http://1``.

    No listener is passed to ``download()``, so none is recorded.

    :param path_or_fd: forwarded to ``Fetcher.download``.
    :returns: ``(expected_content, path)``.
    """
    chunks = self.expect_get('http://1', chunk_size_bytes=13, timeout_secs=13, listener=False)
    downloaded = ''.join(chunks)
    self.response.close()

    self.mox.ReplayAll()

    path = self.fetcher.download('http://1',
                                 path_or_fd=path_or_fd,
                                 chunk_size_bytes=13,
                                 timeout_secs=13)
    return downloaded, path

  def test_download(self):
    """Without a destination, download returns a path to clean up."""
    downloaded, path = self.expect_download()
    try:
      with open(path) as fp:
        self.assertEqual(downloaded, fp.read())
    finally:
      os.unlink(path)

  def test_download_fd(self):
    """Downloading into an open file returns that file's name."""
    with temporary_file() as fd:
      downloaded, path = self.expect_download(path_or_fd=fd)
      self.assertEqual(path, fd.name)
      fd.close()
      with open(path) as fp:
        self.assertEqual(downloaded, fp.read())

  def test_download_path(self):
    """Downloading into a path returns that same path."""
    with temporary_file() as fd:
      fd.close()
      downloaded, path = self.expect_download(path_or_fd=fd.name)
      self.assertEqual(path, fd.name)
      with open(path) as fp:
        self.assertEqual(downloaded, fp.read())
Exemplo n.º 4
0
class FetcherTest(unittest.TestCase):
    """Tests for ``Fetcher`` using ``mock`` doubles for HTTP and the listener.

    HTTP fetches go through ``self.requests`` (a ``requests.Session`` mock);
    local paths / ``file:`` URLs read real temporary files and must not issue
    HTTP requests.
    """

    def setUp(self):
        self.requests = mock.Mock(spec=requests.Session)
        self.response = mock.Mock(spec=requests.Response)
        self.fetcher = Fetcher('/unused/root/dir', requests_api=self.requests)
        self.listener = mock.create_autospec(Fetcher.Listener, spec_set=True)

    def status_call(self, status_code, content_length=None):
        """Return the ``mock.call`` expected for ``listener.status(...)``."""
        return mock.call.status(status_code, content_length=content_length)

    def ok_call(self, chunks):
        """Return the expected 200 status call for a body made of ``chunks``."""
        return self.status_call(200,
                                content_length=sum(len(c) for c in chunks))

    def assert_listener_calls(self,
                              expected_listener_calls,
                              chunks,
                              expect_finished=True):
        """Assert the listener received exactly the expected calls, in order.

        The expected sequence is ``expected_listener_calls`` followed by one
        ``recv_chunk`` per chunk and, if ``expect_finished``, ``finished()``.
        The caller's ``expected_listener_calls`` list is left unmodified.
        """
        expected = list(expected_listener_calls)
        expected.extend(mock.call.recv_chunk(chunk) for chunk in chunks)
        if expect_finished:
            expected.append(mock.call.finished())
        self.assertEqual(expected, self.listener.method_calls)

    def assert_local_file_fetch(self, url_prefix=''):
        """Fetch a temp file via ``url_prefix + path`` and check the listener.

        With ``chunk_size_bytes=10`` the 11-byte file arrives as two chunks.
        """
        chunks = ['0123456789', 'a']
        with temporary_file() as fp:
            for chunk in chunks:
                fp.write(chunk)
            fp.close()

            self.fetcher.fetch(url_prefix + fp.name,
                               self.listener,
                               chunk_size_bytes=10)

            self.assert_listener_calls([self.ok_call(chunks)], chunks)
            # A local fetch must not issue an HTTP GET. (The session object
            # itself is never called, so asserting on it would be vacuous.)
            self.requests.get.assert_not_called()

    def test_file_path(self):
        self.assert_local_file_fetch()

    def test_file_scheme(self):
        self.assert_local_file_fetch('file:')

    def assert_local_file_fetch_relative(self, url, *rel_path):
        """Check ``url`` is resolved under the Fetcher root at ``rel_path``."""
        expected_contents = b'proof'
        with temporary_dir() as root_dir:
            with safe_open(os.path.join(root_dir, *rel_path), 'wb') as fp:
                fp.write(expected_contents)
            with temporary_file() as download_fp:
                Fetcher(root_dir).download(url, path_or_fd=download_fp)
                download_fp.close()
                with open(download_fp.name, 'rb') as fp:
                    self.assertEqual(expected_contents, fp.read())

    def test_file_scheme_double_slash_relative(self):
        self.assert_local_file_fetch_relative('file://relative/path',
                                              'relative', 'path')

    def test_file_scheme_embedded_double_slash(self):
        self.assert_local_file_fetch_relative('file://a//strange//path', 'a',
                                              'strange', 'path')

    def test_file_scheme_triple_slash(self):
        self.assert_local_file_fetch('file://')

    def test_file_dne(self):
        """A missing local file raises ``PermanentError``."""
        with temporary_dir() as base:
            with self.assertRaises(self.fetcher.PermanentError):
                self.fetcher.fetch(os.path.join(base, 'dne'), self.listener)

    def test_file_no_perms(self):
        """An unreadable local file raises ``PermanentError``."""
        # NOTE(review): a process running as root may still read a mode-0 file.
        with temporary_dir() as base:
            no_perms = os.path.join(base, 'dne')
            touch(no_perms)
            os.chmod(no_perms, 0)
            self.assertTrue(os.path.exists(no_perms))
            with self.assertRaises(self.fetcher.PermanentError):
                self.fetcher.fetch(no_perms, self.listener)

    @contextmanager
    def expect_get(self,
                   url,
                   chunk_size_bytes,
                   timeout_secs,
                   chunks=None,
                   listener=True):
        """Stub a successful streamed GET and verify it once the body exits.

        Yields ``(chunks, expected_listener_calls)``. If the ``with`` body
        completes without raising, asserts ``requests.get`` and
        ``iter_content`` were each called exactly once with the expected args.
        """
        chunks = chunks or ['0123456789', 'a']
        size = sum(len(c) for c in chunks)

        self.requests.get.return_value = self.response
        self.response.status_code = 200
        self.response.headers = {'content-length': str(size)}
        self.response.iter_content.return_value = chunks

        yield chunks, [self.ok_call(chunks)] if listener else []

        self.requests.get.assert_called_once_with(url,
                                                  allow_redirects=True,
                                                  stream=True,
                                                  timeout=timeout_secs)
        self.response.iter_content.assert_called_once_with(
            chunk_size=chunk_size_bytes)

    def test_get(self):
        with self.expect_get('http://bar',
                             chunk_size_bytes=1024,
                             timeout_secs=60) as (chunks,
                                                  expected_listener_calls):

            self.fetcher.fetch('http://bar',
                               self.listener,
                               chunk_size_bytes=1024,
                               timeout_secs=60)

            self.assert_listener_calls(expected_listener_calls, chunks)
            self.response.close.assert_called_once_with()

    def test_checksum_listener(self):
        """ChecksumListener digests each chunk and exposes the hexdigest."""
        digest = mock.Mock(spec=hashlib.md5())
        digest.hexdigest.return_value = '42'
        checksum_listener = Fetcher.ChecksumListener(digest=digest)

        with self.expect_get('http://baz', chunk_size_bytes=1,
                             timeout_secs=37) as (chunks,
                                                  expected_listener_calls):

            self.fetcher.fetch('http://baz',
                               checksum_listener.wrap(self.listener),
                               chunk_size_bytes=1,
                               timeout_secs=37)

        self.assertEqual('42', checksum_listener.checksum)

        def expected_digest_calls():
            for chunk in chunks:
                yield mock.call.update(chunk)
            yield mock.call.hexdigest()

        self.assertEqual(list(expected_digest_calls()), digest.method_calls)

        self.assert_listener_calls(expected_listener_calls, chunks)
        self.response.close.assert_called_once_with()

    def concat_chunks(self, chunks):
        """Return the string formed by joining ``chunks`` in order."""
        return ''.join(chunks)

    def test_download_listener(self):
        """DownloadListener writes the fetched body to the given file."""
        with self.expect_get('http://foo',
                             chunk_size_bytes=1048576,
                             timeout_secs=3600) as (chunks,
                                                    expected_listener_calls):

            with closing(StringIO()) as fp:
                self.fetcher.fetch('http://foo',
                                   Fetcher.DownloadListener(fp).wrap(
                                       self.listener),
                                   chunk_size_bytes=1024 * 1024,
                                   timeout_secs=60 * 60)

                downloaded = self.concat_chunks(chunks)
                self.assertEqual(downloaded, fp.getvalue())

        self.assert_listener_calls(expected_listener_calls, chunks)
        self.response.close.assert_called_once_with()

    def test_size_mismatch(self):
        """A body shorter than its content-length raises ``Fetcher.Error``."""
        self.requests.get.return_value = self.response
        self.response.status_code = 200
        self.response.headers = {'content-length': '11'}
        chunks = ['a', 'b']
        self.response.iter_content.return_value = chunks

        with self.assertRaises(self.fetcher.Error):
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size_bytes=1024,
                               timeout_secs=60)

        self.requests.get.assert_called_once_with('http://foo',
                                                  allow_redirects=True,
                                                  stream=True,
                                                  timeout=60)
        self.response.iter_content.assert_called_once_with(chunk_size=1024)
        self.assert_listener_calls([self.status_call(200, content_length=11)],
                                   chunks,
                                   expect_finished=False)
        self.response.close.assert_called_once_with()

    def test_get_error_transient(self):
        """A connection error surfaces as ``TransientError``."""
        self.requests.get.side_effect = requests.ConnectionError

        with self.assertRaises(self.fetcher.TransientError):
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size_bytes=1024,
                               timeout_secs=60)

        self.requests.get.assert_called_once_with('http://foo',
                                                  allow_redirects=True,
                                                  stream=True,
                                                  timeout=60)

    def test_get_error_permanent(self):
        """Too many redirects is a ``PermanentError`` with no response code."""
        self.requests.get.side_effect = requests.TooManyRedirects

        with self.assertRaises(self.fetcher.PermanentError) as e:
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size_bytes=1024,
                               timeout_secs=60)

        self.assertIsNone(e.exception.response_code)
        self.requests.get.assert_called_once_with('http://foo',
                                                  allow_redirects=True,
                                                  stream=True,
                                                  timeout=60)

    def test_http_error(self):
        """An HTTP 404 is reported, raised as ``PermanentError``, and closed."""
        self.requests.get.return_value = self.response
        self.response.status_code = 404

        with self.assertRaises(self.fetcher.PermanentError) as e:
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size_bytes=1024,
                               timeout_secs=60)

        # These checks must run after the ``with`` block: any statement placed
        # after the raising call inside it would never execute.
        self.assertEqual(404, e.exception.response_code)
        self.requests.get.assert_called_once_with('http://foo',
                                                  allow_redirects=True,
                                                  stream=True,
                                                  timeout=60)
        self.listener.status.assert_called_once_with(404)
        self.response.close.assert_called_once_with()

    def test_iter_content_error(self):
        """A timeout mid-stream raises ``TransientError`` and still closes."""
        self.requests.get.return_value = self.response
        self.response.status_code = 200
        # No content-length header, so the listener gets content_length=None.
        self.response.headers = {}
        self.response.iter_content.side_effect = requests.Timeout

        with self.assertRaises(self.fetcher.TransientError):
            self.fetcher.fetch('http://foo',
                               self.listener,
                               chunk_size_bytes=1024,
                               timeout_secs=60)

        self.requests.get.assert_called_once_with('http://foo',
                                                  allow_redirects=True,
                                                  stream=True,
                                                  timeout=60)
        self.response.iter_content.assert_called_once_with(chunk_size=1024)
        self.listener.status.assert_called_once_with(200,
                                                     content_length=None)
        self.response.close.assert_called_once_with()

    def expect_download(self, path_or_fd=None):
        """Download ``http://1`` via a stubbed GET.

        :param path_or_fd: forwarded to ``Fetcher.download``.
        :returns: ``(expected_content, path)``.
        """
        with self.expect_get('http://1',
                             chunk_size_bytes=13,
                             timeout_secs=13,
                             listener=False) as (chunks,
                                                 _expected_listener_calls):

            path = self.fetcher.download('http://1',
                                         path_or_fd=path_or_fd,
                                         chunk_size_bytes=13,
                                         timeout_secs=13)

            self.response.close.assert_called_once_with()
            # Returning inside the ``with`` still runs expect_get's checks.
            return self.concat_chunks(chunks), path

    def test_download(self):
        """Without a destination, download returns a path to clean up."""
        downloaded, path = self.expect_download()
        try:
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())
        finally:
            os.unlink(path)

    def test_download_fd(self):
        """Downloading into an open file returns that file's name."""
        with temporary_file() as fd:
            downloaded, path = self.expect_download(path_or_fd=fd)
            self.assertEqual(path, fd.name)
            fd.close()
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())

    def test_download_path(self):
        """Downloading into a path returns that same path."""
        with temporary_file() as fd:
            fd.close()
            downloaded, path = self.expect_download(path_or_fd=fd.name)
            self.assertEqual(path, fd.name)
            with open(path) as fp:
                self.assertEqual(downloaded, fp.read())

    @mock.patch('time.time')
    def test_progress_listener(self, timer):
        """ProgressListener's final line reports 100%, total size and time."""
        timer.side_effect = [0, 1.137]

        stream = StringIO()
        progress_listener = Fetcher.ProgressListener(width=5,
                                                     chunk_size_bytes=1,
                                                     stream=stream)

        with self.expect_get('http://baz',
                             chunk_size_bytes=1,
                             timeout_secs=37,
                             chunks=[[1]] * 1024) as (chunks,
                                                      expected_listener_calls):

            self.fetcher.fetch('http://baz',
                               progress_listener.wrap(self.listener),
                               chunk_size_bytes=1,
                               timeout_secs=37)

        self.assert_listener_calls(expected_listener_calls, chunks)

        # We just test the last progress line which should indicate a 100% complete download.
        # We control progress bar width (5 dots), size (1KB) and total time downloading (fake 1.137s).
        self.assertEqual('100% ..... 1 KB 1.137s\n',
                         stream.getvalue().split('\r')[-1])
Exemplo n.º 5
0
class FetcherTest(unittest.TestCase):
  """Tests for ``Fetcher`` using mocked ``requests`` objects and real temporary files."""

  def setUp(self):
    # NB: only the top-level mocks below are spec'd. Their children (e.g. ``self.requests.get``,
    # ``self.response.close``) are plain Mocks that accept any attribute name, so call
    # verification must go through the real ``assert_*`` methods; a misspelled method such as
    # ``expect_called_once_with`` would be silently auto-created and check nothing.
    self.requests = mock.Mock(spec=requests.Session)
    self.response = mock.Mock(spec=requests.Response)
    self.fetcher = Fetcher('/unused/root/dir', requests_api=self.requests)
    self.listener = mock.create_autospec(Fetcher.Listener, spec_set=True)

  def status_call(self, status_code, content_length=None):
    """Return the ``listener.status`` call expected for ``status_code`` and ``content_length``."""
    return mock.call.status(status_code, content_length=content_length)

  def ok_call(self, chunks):
    """Return the expected HTTP 200 ``status`` call for a body made up of ``chunks``."""
    return self.status_call(200, content_length=sum(len(c) for c in chunks))

  def assert_listener_calls(self, expected_listener_calls, chunks, expect_finished=True):
    """Assert the exact sequence of calls recorded on ``self.listener``.

    :param expected_listener_calls: calls expected before any chunk arrives; the list is copied,
                                    not mutated.
    :param chunks: chunks expected to be delivered, in order, via ``recv_chunk``.
    :param expect_finished: whether a trailing ``finished()`` call is expected.
    """
    expected = list(expected_listener_calls)
    expected.extend(mock.call.recv_chunk(chunk) for chunk in chunks)
    if expect_finished:
      expected.append(mock.call.finished())
    self.assertEqual(expected, self.listener.method_calls)

  def assert_local_file_fetch(self, url_prefix=''):
    """Fetch a real temporary file addressed as ``url_prefix + <file path>``.

    Checks that the listener sees the file content in 10-byte chunks and that no HTTP GET is made.
    """
    chunks = [b'0123456789', b'a']
    with temporary_file() as fp:
      for chunk in chunks:
        fp.write(chunk)
      fp.close()

      self.fetcher.fetch(url_prefix + fp.name, self.listener, chunk_size_bytes=10)

      self.assert_listener_calls([self.ok_call(chunks)], chunks)
      # HTTP fetches go through ``self.requests.get`` (see ``expect_get``); asserting on
      # ``self.requests`` itself would only prove the Session mock was never called directly.
      self.requests.get.assert_not_called()

  def test_file_path(self):
    self.assert_local_file_fetch()

  def test_file_scheme(self):
    self.assert_local_file_fetch('file:')

  def assert_local_file_fetch_relative(self, url, *rel_path):
    """Download ``url`` via a Fetcher rooted at a temp dir containing a file at ``rel_path``."""
    expected_contents = b'proof'
    with temporary_dir() as root_dir:
      with safe_open(os.path.join(root_dir, *rel_path), 'wb') as fp:
        fp.write(expected_contents)
      with temporary_file() as download_fp:
        Fetcher(root_dir).download(url, path_or_fd=download_fp)
        download_fp.close()
        with open(download_fp.name, 'rb') as fp:
          self.assertEqual(expected_contents, fp.read())

  def test_file_scheme_double_slash_relative(self):
    self.assert_local_file_fetch_relative('file://relative/path', 'relative', 'path')

  def test_file_scheme_embedded_double_slash(self):
    self.assert_local_file_fetch_relative('file://a//strange//path', 'a', 'strange', 'path')

  def test_file_scheme_triple_slash(self):
    self.assert_local_file_fetch('file://')

  def test_file_dne(self):
    with temporary_dir() as base:
      with self.assertRaises(self.fetcher.PermanentError):
        self.fetcher.fetch(os.path.join(base, 'dne'), self.listener)

  def test_file_no_perms(self):
    with temporary_dir() as base:
      no_perms = os.path.join(base, 'dne')
      touch(no_perms)
      os.chmod(no_perms, 0)
      self.assertTrue(os.path.exists(no_perms))
      with self.assertRaises(self.fetcher.PermanentError):
        self.fetcher.fetch(no_perms, self.listener)

  @contextmanager
  def expect_get(self, url, chunk_size_bytes, timeout_secs, chunks=None, listener=True):
    """Stub a successful streamed GET of ``chunks`` and verify the request once the block ends.

    Yields ``(chunks, expected_listener_calls)``, where the latter holds the expected ``status``
    call, or is empty when ``listener`` is False. If the ``with`` body completes without raising,
    asserts that ``get`` and ``iter_content`` were each called exactly once with the expected
    arguments.
    """
    chunks = chunks or [b'0123456789', b'a']
    size = sum(len(c) for c in chunks)

    self.requests.get.return_value = self.response
    self.response.status_code = 200
    self.response.headers = {'content-length': str(size)}
    self.response.iter_content.return_value = chunks

    yield chunks, [self.ok_call(chunks)] if listener else []

    self.requests.get.assert_called_once_with(url, allow_redirects=True, stream=True,
                                              timeout=timeout_secs)
    self.response.iter_content.assert_called_once_with(chunk_size=chunk_size_bytes)

  def test_get(self):
    with self.expect_get('http://bar',
                         chunk_size_bytes=1024,
                         timeout_secs=60) as (chunks, expected_listener_calls):

      self.fetcher.fetch('http://bar',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

      self.assert_listener_calls(expected_listener_calls, chunks)
      self.response.close.assert_called_once_with()

  def test_checksum_listener(self):
    digest = mock.Mock(spec=hashlib.md5())
    digest.hexdigest.return_value = '42'
    checksum_listener = Fetcher.ChecksumListener(digest=digest)

    with self.expect_get('http://baz',
                         chunk_size_bytes=1,
                         timeout_secs=37) as (chunks, expected_listener_calls):

      self.fetcher.fetch('http://baz',
                         checksum_listener.wrap(self.listener),
                         chunk_size_bytes=1,
                         timeout_secs=37)

    self.assertEqual('42', checksum_listener.checksum)

    def expected_digest_calls():
      for chunk in chunks:
        yield mock.call.update(chunk)
      yield mock.call.hexdigest()

    self.assertEqual(list(expected_digest_calls()), digest.method_calls)

    self.assert_listener_calls(expected_listener_calls, chunks)
    self.response.close.assert_called_once_with()

  def concat_chunks(self, chunks):
    """Return the byte ``chunks`` joined into a single ``bytes`` value."""
    return b''.join(chunks)

  def test_download_listener(self):
    with self.expect_get('http://foo',
                         chunk_size_bytes=1048576,
                         timeout_secs=3600) as (chunks, expected_listener_calls):

      with closing(BytesIO()) as fp:
        self.fetcher.fetch('http://foo',
                           Fetcher.DownloadListener(fp).wrap(self.listener),
                           chunk_size_bytes=1024 * 1024,
                           timeout_secs=60 * 60)

        downloaded = self.concat_chunks(chunks)
        self.assertEqual(downloaded, fp.getvalue())

    self.assert_listener_calls(expected_listener_calls, chunks)
    self.response.close.assert_called_once_with()

  def test_size_mismatch(self):
    self.requests.get.return_value = self.response
    self.response.status_code = 200
    # Advertise 11 bytes but deliver only 2.
    self.response.headers = {'content-length': '11'}
    chunks = [b'a', b'b']
    self.response.iter_content.return_value = chunks

    with self.assertRaises(self.fetcher.Error):
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

    self.requests.get.assert_called_once_with('http://foo', allow_redirects=True, stream=True,
                                              timeout=60)
    self.response.iter_content.assert_called_once_with(chunk_size=1024)
    self.assert_listener_calls([self.status_call(200, content_length=11)], chunks,
                               expect_finished=False)
    self.response.close.assert_called_once_with()

  def test_get_error_transient(self):
    self.requests.get.side_effect = requests.ConnectionError

    with self.assertRaises(self.fetcher.TransientError):
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

    self.requests.get.assert_called_once_with('http://foo', allow_redirects=True, stream=True,
                                              timeout=60)

  def test_get_error_permanent(self):
    self.requests.get.side_effect = requests.TooManyRedirects

    with self.assertRaises(self.fetcher.PermanentError) as e:
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

    self.assertIsNone(e.exception.response_code)
    self.requests.get.assert_called_once_with('http://foo', allow_redirects=True, stream=True,
                                              timeout=60)

  def test_http_error(self):
    self.requests.get.return_value = self.response
    self.response.status_code = 404

    with self.assertRaises(self.fetcher.PermanentError) as e:
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

    # These checks must sit outside the ``assertRaises`` block: statements after the raising call
    # inside it are never executed.
    self.assertEqual(404, e.exception.response_code)
    self.requests.get.assert_called_once_with('http://foo', allow_redirects=True, stream=True,
                                              timeout=60)
    self.listener.status.assert_called_once_with(404)
    self.response.close.assert_called_once_with()

  def test_iter_content_error(self):
    self.requests.get.return_value = self.response
    self.response.status_code = 200
    self.response.headers = {}
    self.response.iter_content.side_effect = requests.Timeout

    with self.assertRaises(self.fetcher.TransientError):
      self.fetcher.fetch('http://foo',
                         self.listener,
                         chunk_size_bytes=1024,
                         timeout_secs=60)

    # Outside the ``assertRaises`` block so these checks actually run.
    self.requests.get.assert_called_once_with('http://foo', allow_redirects=True, stream=True,
                                              timeout=60)
    self.response.iter_content.assert_called_once_with(chunk_size=1024)
    self.listener.status.assert_called_once_with(200, content_length=None)
    self.response.close.assert_called_once_with()

  def expect_download(self, path_or_fd=None):
    """Run ``Fetcher.download`` of ``http://1`` against stubbed chunks.

    :param path_or_fd: forwarded to ``download``.
    :returns: ``(expected_bytes, path)``, where ``path`` is the value ``download`` returned.
    """
    with self.expect_get('http://1',
                         chunk_size_bytes=13,
                         timeout_secs=13,
                         listener=False) as (chunks, expected_listener_calls):

      path = self.fetcher.download('http://1',
                                   path_or_fd=path_or_fd,
                                   chunk_size_bytes=13,
                                   timeout_secs=13)

      self.response.close.assert_called_once_with()
      downloaded = self.concat_chunks(chunks)
      return downloaded, path

  def test_download(self):
    downloaded, path = self.expect_download()
    try:
      with open(path, 'rb') as fp:
        self.assertEqual(downloaded, fp.read())
    finally:
      os.unlink(path)

  def test_download_fd(self):
    with temporary_file() as fd:
      downloaded, path = self.expect_download(path_or_fd=fd)
      self.assertEqual(path, fd.name)
      fd.close()
      with open(path, 'rb') as fp:
        self.assertEqual(downloaded, fp.read())

  def test_download_path(self):
    with temporary_file() as fd:
      fd.close()
      downloaded, path = self.expect_download(path_or_fd=fd.name)
      self.assertEqual(path, fd.name)
      with open(path, 'rb') as fp:
        self.assertEqual(downloaded, fp.read())

  @mock.patch('time.time')
  def test_progress_listener(self, timer):
    timer.side_effect = [0, 1.137]

    stream = BytesIO()
    progress_listener = Fetcher.ProgressListener(width=5, chunk_size_bytes=1, stream=stream)

    with self.expect_get('http://baz',
                         chunk_size_bytes=1,
                         timeout_secs=37,
                         chunks=[[1]] * 1024) as (chunks, expected_listener_calls):

      self.fetcher.fetch('http://baz',
                         progress_listener.wrap(self.listener),
                         chunk_size_bytes=1,
                         timeout_secs=37)

    self.assert_listener_calls(expected_listener_calls, chunks)

    # We just test the last progress line which should indicate a 100% complete download.
    # We control progress bar width (5 dots), size (1KB) and total time downloading (fake 1.137s).
    self.assertEqual('100% ..... 1 KB 1.137s\n', stream.getvalue().decode('utf-8').split('\r')[-1])