Example no. 1
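The excerpt omits its import block. A plausible sketch follows; the module paths are assumed from the upstream s3transfer test layout rather than confirmed by this excerpt, and test helpers such as StubbedClientTest, FileCreator, StreamWithError, StubbedClientManager, StubbedClientFactory, RenameFailingOSUtils, and skip_if_windows would come from the surrounding test package, so they are not listed.

# Assumed imports -- paths mirror the upstream s3transfer layout and are
# not confirmed by this excerpt.
import glob
import os
import queue
import signal
import unittest
from unittest import mock

import six

import ibm_botocore.exceptions
from ibm_botocore.exceptions import ClientError, ReadTimeoutError

from ibm_s3transfer.exceptions import CancelledError, RetriesExceededError
from ibm_s3transfer.processpool import (
    SHUTDOWN_SIGNAL, ClientFactory, GetObjectJob, GetObjectWorker,
    ProcessPoolDownloader, ProcessTransferConfig, TransferMonitor)
from ibm_s3transfer.utils import OSUtils
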
class TestGetObjectWorker(StubbedClientTest):
    def setUp(self):
        super(TestGetObjectWorker, self).setUp()
        self.files = FileCreator()
        self.queue = queue.Queue()
        self.client_factory = mock.Mock(ClientFactory)
        self.client_factory.create_client.return_value = self.client
        self.transfer_monitor = TransferMonitor()
        self.osutil = OSUtils()
        self.worker = GetObjectWorker(queue=self.queue,
                                      client_factory=self.client_factory,
                                      transfer_monitor=self.transfer_monitor,
                                      osutil=self.osutil)
        self.transfer_id = self.transfer_monitor.notify_new_transfer()
        self.bucket = 'bucket'
        self.key = 'key'
        self.remote_contents = b'my content'
        self.temp_filename = self.files.create_file('tempfile', '')
        self.extra_args = {}
        self.offset = 0
        self.final_filename = self.files.full_path('final_filename')
        self.stream = six.BytesIO(self.remote_contents)
        self.transfer_monitor.notify_expected_jobs_to_complete(
            self.transfer_id, 1000)

    def tearDown(self):
        super(TestGetObjectWorker, self).tearDown()
        self.files.remove_all()

    def add_get_object_job(self, **override_kwargs):
        kwargs = {
            'transfer_id': self.transfer_id,
            'bucket': self.bucket,
            'key': self.key,
            'temp_filename': self.temp_filename,
            'extra_args': self.extra_args,
            'offset': self.offset,
            'filename': self.final_filename
        }
        kwargs.update(override_kwargs)
        self.queue.put(GetObjectJob(**kwargs))

    def add_shutdown(self):
        self.queue.put(SHUTDOWN_SIGNAL)

    def add_stubbed_get_object_response(self, body=None, expected_params=None):
        if body is None:
            body = self.stream
        get_object_response = {'Body': body}

        if expected_params is None:
            expected_params = {'Bucket': self.bucket, 'Key': self.key}

        self.stubber.add_response('get_object', get_object_response,
                                  expected_params)

    def assert_contents(self, filename, contents):
        self.assertTrue(os.path.exists(filename))
        with open(filename, 'rb') as f:
            self.assertEqual(f.read(), contents)

    def assert_does_not_exist(self, filename):
        self.assertFalse(os.path.exists(filename))

    def test_run_is_final_job(self):
        self.add_get_object_job()
        self.add_shutdown()
        self.add_stubbed_get_object_response()
        self.transfer_monitor.notify_expected_jobs_to_complete(
            self.transfer_id, 1)

        self.worker.run()
        self.stubber.assert_no_pending_responses()
        self.assert_does_not_exist(self.temp_filename)
        self.assert_contents(self.final_filename, self.remote_contents)

    def test_run_jobs_is_not_final_job(self):
        self.add_get_object_job()
        self.add_shutdown()
        self.add_stubbed_get_object_response()
        self.transfer_monitor.notify_expected_jobs_to_complete(
            self.transfer_id, 1000)

        self.worker.run()
        self.stubber.assert_no_pending_responses()
        self.assert_contents(self.temp_filename, self.remote_contents)
        self.assert_does_not_exist(self.final_filename)

    def test_run_with_extra_args(self):
        self.add_get_object_job(extra_args={'VersionId': 'versionid'})
        self.add_shutdown()
        self.add_stubbed_get_object_response(expected_params={
            'Bucket': self.bucket,
            'Key': self.key,
            'VersionId': 'versionid'
        })

        self.worker.run()
        self.stubber.assert_no_pending_responses()

    def test_run_with_offset(self):
        offset = 1
        self.add_get_object_job(offset=offset)
        self.add_shutdown()
        self.add_stubbed_get_object_response()

        self.worker.run()
        with open(self.temp_filename, 'rb') as f:
            f.seek(offset)
            self.assertEqual(f.read(), self.remote_contents)

    def test_run_error_in_get_object(self):
        self.add_get_object_job()
        self.add_shutdown()
        self.stubber.add_client_error('get_object', 'NoSuchKey', 404)
        self.add_stubbed_get_object_response()

        self.worker.run()
        self.assertIsInstance(
            self.transfer_monitor.get_exception(self.transfer_id), ClientError)

    def test_run_does_retries_for_get_object(self):
        self.add_get_object_job()
        self.add_shutdown()
        self.add_stubbed_get_object_response(
            body=StreamWithError(self.stream, ReadTimeoutError(
                endpoint_url='')))
        self.add_stubbed_get_object_response()

        self.worker.run()
        self.stubber.assert_no_pending_responses()
        self.assert_contents(self.temp_filename, self.remote_contents)

    def test_run_can_exhaust_retries_for_get_object(self):
        self.add_get_object_job()
        self.add_shutdown()
        # 5 is the current setting for the maximum number of GetObject attempts
        for _ in range(5):
            self.add_stubbed_get_object_response(body=StreamWithError(
                self.stream, ReadTimeoutError(endpoint_url='')))

        self.worker.run()
        self.stubber.assert_no_pending_responses()
        self.assertIsInstance(
            self.transfer_monitor.get_exception(self.transfer_id),
            RetriesExceededError)

    def test_run_skips_get_object_on_previous_exception(self):
        self.add_get_object_job()
        self.add_shutdown()
        self.transfer_monitor.notify_exception(self.transfer_id, Exception())

        self.worker.run()
        # Note we did not add a stubbed response for get_object
        self.stubber.assert_no_pending_responses()

    def test_run_final_job_removes_file_on_previous_exception(self):
        self.add_get_object_job()
        self.add_shutdown()
        self.transfer_monitor.notify_exception(self.transfer_id, Exception())
        self.transfer_monitor.notify_expected_jobs_to_complete(
            self.transfer_id, 1)

        self.worker.run()
        self.stubber.assert_no_pending_responses()
        self.assert_does_not_exist(self.temp_filename)
        self.assert_does_not_exist(self.final_filename)

    def test_run_fails_to_rename_file(self):
        exception = OSError()
        osutil = RenameFailingOSUtils(exception)
        self.worker = GetObjectWorker(queue=self.queue,
                                      client_factory=self.client_factory,
                                      transfer_monitor=self.transfer_monitor,
                                      osutil=osutil)
        self.add_get_object_job()
        self.add_shutdown()
        self.add_stubbed_get_object_response()
        self.transfer_monitor.notify_expected_jobs_to_complete(
            self.transfer_id, 1)

        self.worker.run()
        self.assertEqual(self.transfer_monitor.get_exception(self.transfer_id),
                         exception)
        self.assert_does_not_exist(self.temp_filename)
        self.assert_does_not_exist(self.final_filename)

    @skip_if_windows('os.kill() with SIGINT not supported on Windows')
    def test_worker_cannot_be_killed(self):
        self.add_get_object_job()
        self.add_shutdown()
        self.transfer_monitor.notify_expected_jobs_to_complete(
            self.transfer_id, 1)

        def raise_ctrl_c(**kwargs):
            os.kill(os.getpid(), signal.SIGINT)

        mock_client = mock.Mock()
        mock_client.get_object = raise_ctrl_c
        self.client_factory.create_client.return_value = mock_client

        try:
            self.worker.run()
        except KeyboardInterrupt:
            self.fail('The worker should not have been killed by the '
                      'KeyboardInterrupt')


class TestProcessPoolDownloader(unittest.TestCase):
    def setUp(self):
        # The stubbed client needs to run in a manager so that it can be
        # shared across processes and properly consume the stubbed
        # responses across processes.
        self.manager = StubbedClientManager()
        self.manager.start()
        self.stubbed_client = self.manager.StubbedClient()
        self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client)

        self.client_factory_patch = mock.patch(
            'ibm_s3transfer.processpool.ClientFactory',
            self.stubbed_client_factory)
        self.client_factory_patch.start()
        self.files = FileCreator()

        self.config = ProcessTransferConfig(max_request_processes=1)
        self.downloader = ProcessPoolDownloader(config=self.config)
        self.bucket = 'mybucket'
        self.key = 'mykey'
        self.filename = self.files.full_path('filename')
        self.remote_contents = b'my content'
        self.stream = six.BytesIO(self.remote_contents)

    def tearDown(self):
        self.manager.shutdown()
        self.client_factory_patch.stop()
        self.files.remove_all()

    def assert_contents(self, filename, expected_contents):
        self.assertTrue(os.path.exists(filename))
        with open(filename, 'rb') as f:
            self.assertEqual(f.read(), expected_contents)

    def test_download_file(self):
        self.stubbed_client.add_response(
            'head_object', {'ContentLength': len(self.remote_contents)})
        self.stubbed_client.add_response('get_object', {'Body': self.stream})
        with self.downloader:
            self.downloader.download_file(self.bucket, self.key, self.filename)
        self.assert_contents(self.filename, self.remote_contents)

    def test_download_multiple_files(self):
        self.stubbed_client.add_response('get_object', {'Body': self.stream})
        self.stubbed_client.add_response(
            'get_object', {'Body': six.BytesIO(self.remote_contents)})
        with self.downloader:
            self.downloader.download_file(self.bucket,
                                          self.key,
                                          self.filename,
                                          expected_size=len(
                                              self.remote_contents))
            other_file = self.files.full_path('filename2')
            self.downloader.download_file(self.bucket,
                                          self.key,
                                          other_file,
                                          expected_size=len(
                                              self.remote_contents))
        self.assert_contents(self.filename, self.remote_contents)
        self.assert_contents(other_file, self.remote_contents)

    def test_download_file_ranged_download(self):
        half_of_content_length = int(len(self.remote_contents) / 2)
        self.stubbed_client.add_response(
            'head_object', {'ContentLength': len(self.remote_contents)})
        self.stubbed_client.add_response('get_object', {
            'Body':
            six.BytesIO(self.remote_contents[:half_of_content_length])
        })
        self.stubbed_client.add_response('get_object', {
            'Body':
            six.BytesIO(self.remote_contents[half_of_content_length:])
        })
        downloader = ProcessPoolDownloader(config=ProcessTransferConfig(
            multipart_chunksize=half_of_content_length,
            multipart_threshold=half_of_content_length,
            max_request_processes=1))
        with downloader:
            downloader.download_file(self.bucket, self.key, self.filename)
        self.assert_contents(self.filename, self.remote_contents)

    def test_download_file_extra_args(self):
        self.stubbed_client.add_response(
            'head_object', {'ContentLength': len(self.remote_contents)},
            expected_params={
                'Bucket': self.bucket,
                'Key': self.key,
                'VersionId': 'versionid'
            })
        self.stubbed_client.add_response('get_object', {'Body': self.stream},
                                         expected_params={
                                             'Bucket': self.bucket,
                                             'Key': self.key,
                                             'VersionId': 'versionid'
                                         })
        with self.downloader:
            self.downloader.download_file(
                self.bucket,
                self.key,
                self.filename,
                extra_args={'VersionId': 'versionid'})
        self.assert_contents(self.filename, self.remote_contents)

    def test_download_file_expected_size(self):
        self.stubbed_client.add_response('get_object', {'Body': self.stream})
        with self.downloader:
            self.downloader.download_file(self.bucket,
                                          self.key,
                                          self.filename,
                                          expected_size=len(
                                              self.remote_contents))
        self.assert_contents(self.filename, self.remote_contents)

    def test_cleans_up_tempfile_on_failure(self):
        self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
        with self.downloader:
            self.downloader.download_file(self.bucket,
                                          self.key,
                                          self.filename,
                                          expected_size=len(
                                              self.remote_contents))
        self.assertFalse(os.path.exists(self.filename))
        # Any tempfile should have been erased as well
        possible_matches = glob.glob('%s%s*' % (self.filename, os.extsep))
        self.assertEqual(possible_matches, [])

    def test_validates_extra_args(self):
        with self.downloader:
            with self.assertRaises(ValueError):
                self.downloader.download_file(
                    self.bucket,
                    self.key,
                    self.filename,
                    extra_args={'NotSupported': 'NotSupported'})

    def test_result_with_success(self):
        self.stubbed_client.add_response('get_object', {'Body': self.stream})
        with self.downloader:
            future = self.downloader.download_file(self.bucket,
                                                   self.key,
                                                   self.filename,
                                                   expected_size=len(
                                                       self.remote_contents))
            self.assertIsNone(future.result())

    def test_result_with_exception(self):
        self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
        with self.downloader:
            future = self.downloader.download_file(self.bucket,
                                                   self.key,
                                                   self.filename,
                                                   expected_size=len(
                                                       self.remote_contents))
            with self.assertRaises(ibm_botocore.exceptions.ClientError):
                future.result()

    def test_result_with_cancel(self):
        self.stubbed_client.add_response('get_object', {'Body': self.stream})
        with self.downloader:
            future = self.downloader.download_file(self.bucket,
                                                   self.key,
                                                   self.filename,
                                                   expected_size=len(
                                                       self.remote_contents))
            future.cancel()
            with self.assertRaises(CancelledError):
                future.result()

    def test_shutdown_with_no_downloads(self):
        downloader = ProcessPoolDownloader()
        try:
            downloader.shutdown()
        except AttributeError:
            self.fail(
                'The downloader should be able to be shut down even though '
                'it was never started.')

    def test_shutdown_with_no_downloads_and_ctrl_c(self):
        # Special shutdown logic happens if a KeyboardInterrupt is raised in
        # the context manager. However, this logic cannot run if the
        # downloader was never started, so the KeyboardInterrupt should be
        # the only exception propagated.
        with self.assertRaises(KeyboardInterrupt):
            with self.downloader:
                raise KeyboardInterrupt()