Example #1
        def run(self) -> dict:
            s3_blobstore = S3BlobStore.from_environment()
            state = self.get_state_copy()
            # `event` and `logger` are not defined in this excerpt; they are captured
            # from the enclosing scope, which is not shown here.
            will_cache = should_cache_file(event[Key.CONTENT_TYPE],
                                           event[_Key.SIZE])
            if not will_cache:
                logger.info("Not caching %s with content-type %s size %s",
                            self.source_key, event[Key.CONTENT_TYPE],
                            event[_Key.SIZE])
            if _Key.NEXT_PART not in state or _Key.LAST_PART not in state:
                # The next/last part boundaries are missing from the state; derive
                # them from the branch id (this worker's slice number).
                parts_per_branch = (
                    (self.part_count + LAMBDA_PARALLELIZATION_FACTOR - 1) //
                    LAMBDA_PARALLELIZATION_FACTOR)
                state[_Key.NEXT_PART] = self.slice_num * parts_per_branch + 1
                state[_Key.LAST_PART] = min(
                    state[_Key.PART_COUNT],
                    state[_Key.NEXT_PART] + parts_per_branch - 1)
                self.save_state(state)

            if state[_Key.NEXT_PART] > state[_Key.LAST_PART]:
                state[Key.FINISHED] = True
                _determine_cache_tagging(will_cache, self.destination_bucket,
                                         self.destination_key)
                return state

            queue = collections.deque(
                s3_blobstore.find_next_missing_parts(
                    self.destination_bucket, self.destination_key,
                    self.upload_id, self.part_count, state[_Key.NEXT_PART],
                    state[_Key.LAST_PART] - state[_Key.NEXT_PART] + 1))

            if len(queue) == 0:
                state[Key.FINISHED] = True
                _determine_cache_tagging(will_cache, self.destination_bucket,
                                         self.destination_key)
                return state

            class ProgressReporter(parallel_worker.Reporter):
                # The first parameter is named `inner_self` so that the enclosing
                # method's `self` (and its `state`) stay reachable via the closure.
                def report_progress(inner_self, first_incomplete: int) -> None:
                    state[_Key.NEXT_PART] = first_incomplete
                    self.save_state(state)

            class CopyPartTask(parallel_worker.Task):
                def run(inner_self, subtask_id: int) -> None:
                    self.copy_one_part(subtask_id)

            runner = parallel_worker.Runner(CONCURRENT_REQUESTS, CopyPartTask,
                                            queue, ProgressReporter())
            results = runner.run()
            assert all(results)

            state[Key.FINISHED] = True
            _determine_cache_tagging(will_cache, self.destination_bucket,
                                     self.destination_key)
            return state
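
The example above subclasses parallel_worker.Reporter and parallel_worker.Task and drives them through parallel_worker.Runner. The following is a minimal sketch of what those interfaces could look like, inferred only from how they are used here and in the tests below; the real module's constructor signatures and reporting rules may differ.

import abc
import concurrent.futures
import threading
from typing import Any, List, Sequence, Type


class Reporter(abc.ABC):
    """Receives the index of the first still-incomplete subtask after each completion."""

    @abc.abstractmethod
    def report_progress(self, first_incomplete: int) -> None:
        raise NotImplementedError()


class Task(abc.ABC):
    """A unit of work; run() is invoked once per subtask id."""

    def __init__(self, incomplete: Sequence[int]) -> None:
        self.incomplete = list(incomplete)

    @abc.abstractmethod
    def run(self, subtask_id: int) -> Any:
        raise NotImplementedError()


class Runner:
    """Runs a Task over a list of subtask ids with bounded concurrency."""

    def __init__(self, concurrency: int, task_class: Type[Task],
                 incomplete: Sequence[int], reporter: Reporter) -> None:
        self._concurrency = concurrency
        self._task = task_class(incomplete)
        self._incomplete = list(incomplete)
        self._reporter = reporter
        self._lock = threading.Lock()

    def run(self) -> List[Any]:
        pending = set(self._incomplete)
        results: List[Any] = []
        with concurrent.futures.ThreadPoolExecutor(self._concurrency) as executor:
            futures = {executor.submit(self._task.run, subtask_id): subtask_id
                       for subtask_id in self._incomplete}
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                with self._lock:
                    pending.discard(futures[future])
                    # Report the smallest subtask id not yet finished; once everything
                    # is done, report len(incomplete) to match the tests below (which
                    # assume contiguous ids starting at zero).
                    first_incomplete = (min(pending) if pending
                                        else len(self._incomplete))
                    self._reporter.report_progress(first_incomplete)
                # Treat a None return as success so that `assert all(results)` in the
                # copy example above holds.
                results.append(True if result is None else result)
        return results
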
    def test_sequential_complete(self, subtasks=5):
        """Sequentially complete tasks."""
        incomplete = list(range(subtasks))
        reporter = RecordingReporter()
        runner = parallel_worker.Runner(8, LatchedTask, incomplete, reporter)
        task: LatchedTask = runner._task
        with ConcurrentContext(runner.run) as context:
            # After each mark_can_run, sleep briefly to give the background runner
            # time to report progress before checking the latest report.
            for ix in range(subtasks):
                task.mark_can_run(ix)
                time.sleep(0.1)
                self.assertEqual(reporter.progress_reports[-1], ix + 1)

            results = list(context.result())
            self.assertEqual(len(results), subtasks)
            self.assertTrue(all(results))
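
RecordingReporter and LatchedTask are test fixtures that are not part of this excerpt. Below is a plausible stand-in for each, inferred from usage and shown subclassing the Reporter/Task sketched above (the real fixtures would subclass the parallel_worker base classes): RecordingReporter records every progress callback, and LatchedTask blocks each subtask until the test releases it with mark_can_run.

import threading
from typing import List, Sequence


class RecordingReporter(Reporter):
    """Keeps every first_incomplete value reported by the runner."""

    def __init__(self) -> None:
        self.progress_reports: List[int] = []

    def report_progress(self, first_incomplete: int) -> None:
        self.progress_reports.append(first_incomplete)


class LatchedTask(Task):
    """Each subtask blocks until the test calls mark_can_run(subtask_id)."""

    def __init__(self, incomplete: Sequence[int]) -> None:
        super().__init__(incomplete)
        # One latch per subtask; run() waits until the matching latch is released.
        self._latches = {subtask_id: threading.Event() for subtask_id in incomplete}

    def mark_can_run(self, subtask_id: int) -> None:
        self._latches[subtask_id].set()

    def run(self, subtask_id: int) -> bool:
        self._latches[subtask_id].wait()
        return True
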
    def test_random_complete(self):
        """Complete tasks in some random order."""
        incomplete      = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # noqa: E221
        sequence        = [5, 4, 0, 8, 1, 9, 2, 3, 7, 6]  # noqa: E221
        next_incomplete = [0, 0, 1, 1, 2, 2, 3, 6, 6, 10]  # noqa: E221
        reporter = RecordingReporter()
        runner = parallel_worker.Runner(8, LatchedTask, incomplete, reporter)
        task: LatchedTask = runner._task
        with ConcurrentContext(runner.run) as context:
            for can_run, expected_next_incomplete in zip(
                    sequence, next_incomplete):
                task.mark_can_run(can_run)
                time.sleep(0.1)
                self.assertEqual(reporter.progress_reports[-1],
                                 expected_next_incomplete)

            results = list(context.result())
            self.assertEqual(len(results), 10)
            self.assertTrue(all(results))
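
ConcurrentContext is likewise not shown in this excerpt. Judging from its usage, it runs the given callable on a background thread so the test can interact with the in-flight runner, and result() blocks until that callable returns. A sketch under that assumption:

import concurrent.futures
from typing import Any, Callable, Optional


class ConcurrentContext:
    """Runs a callable in the background for the duration of a with-block."""

    def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        self._fn, self._args, self._kwargs = fn, args, kwargs
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._future: Optional[concurrent.futures.Future] = None

    def __enter__(self) -> "ConcurrentContext":
        self._future = self._executor.submit(self._fn, *self._args, **self._kwargs)
        return self

    def result(self) -> Any:
        # Blocks until the background callable has returned.
        return self._future.result()

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self._executor.shutdown(wait=True)
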
    def test_empty_tasklist(self):
        """Test the case where the tasklist is empty."""
        reporter = RecordingReporter()
        runner = parallel_worker.Runner(8, LatchedTask, [], reporter)
        results = runner.run()
        self.assertEqual(list(results), [])
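
For completeness, a tiny standalone demo wiring the sketched Runner, Task, and Reporter together; DoublingTask is purely illustrative and none of this comes from the project's real code.

class DoublingTask(Task):
    """Trivial task: return twice the subtask id."""

    def run(self, subtask_id: int) -> int:
        return subtask_id * 2


if __name__ == "__main__":
    reporter = RecordingReporter()
    runner = Runner(4, DoublingTask, [0, 1, 2, 3], reporter)
    print(runner.run())               # results in completion order, e.g. [0, 2, 4, 6]
    print(reporter.progress_reports)  # ends with 4 once every subtask has finished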