def run(self) -> dict:
    """Copy this branch's slice of the multipart parts, then return the updated state.

    The work is resumable: slice boundaries and progress (next/last part) are
    persisted via ``save_state`` so a re-invocation picks up where the previous
    run stopped.  Returns the state dict with ``Key.FINISHED`` set to True once
    every part in this slice has been copied (or the slice had nothing to do).
    """
    s3_blobstore = S3BlobStore.from_environment()
    state = self.get_state_copy()
    # FIX: the original read `event[...]`, but no `event` name exists in this
    # scope (NameError at runtime).  The content-type/size are read from the
    # state copy instead.  NOTE(review): confirm Key.CONTENT_TYPE and _Key.SIZE
    # are populated in the state produced by get_state_copy().
    will_cache = should_cache_file(state[Key.CONTENT_TYPE], state[_Key.SIZE])
    if not will_cache:
        logger.info("Not caching %s with content-type %s size %s",
                    self.source_key, state[Key.CONTENT_TYPE], state[_Key.SIZE])
    if _Key.NEXT_PART not in state or _Key.LAST_PART not in state:
        # missing the next/last part data. calculate that from the branch id information.
        # ceiling division: every branch gets the same number of parts, the last may run short.
        parts_per_branch = ((self.part_count + LAMBDA_PARALLELIZATION_FACTOR - 1)
                            // LAMBDA_PARALLELIZATION_FACTOR)
        # S3 multipart part numbers are 1-indexed.
        state[_Key.NEXT_PART] = self.slice_num * parts_per_branch + 1
        state[_Key.LAST_PART] = min(state[_Key.PART_COUNT],
                                    state[_Key.NEXT_PART] + parts_per_branch - 1)
        self.save_state(state)
    if state[_Key.NEXT_PART] > state[_Key.LAST_PART]:
        # this slice has no parts assigned (more branches than parts, or already done).
        state[Key.FINISHED] = True
        _determine_cache_tagging(will_cache, self.destination_bucket, self.destination_key)
        return state
    # collect the parts in our range that are not yet present at the destination.
    queue = collections.deque(s3_blobstore.find_next_missing_parts(
        self.destination_bucket,
        self.destination_key,
        self.upload_id,
        self.part_count,
        state[_Key.NEXT_PART],
        state[_Key.LAST_PART] - state[_Key.NEXT_PART] + 1))
    if len(queue) == 0:
        # everything in our range was already copied by an earlier run.
        state[Key.FINISHED] = True
        _determine_cache_tagging(will_cache, self.destination_bucket, self.destination_key)
        return state

    class ProgressReporter(parallel_worker.Reporter):
        # persist progress after each completed prefix so a timed-out or
        # re-run worker resumes at the first incomplete part.
        def report_progress(inner_self, first_incomplete: int):
            state[_Key.NEXT_PART] = first_incomplete
            self.save_state(state)

    class CopyPartTask(parallel_worker.Task):
        # each subtask copies exactly one part of the multipart upload.
        def run(inner_self, subtask_id: int) -> None:
            self.copy_one_part(subtask_id)

    runner = parallel_worker.Runner(CONCURRENT_REQUESTS, CopyPartTask, queue, ProgressReporter())
    results = runner.run()
    # every subtask must have reported success before we mark the slice finished.
    assert all(results)
    state[Key.FINISHED] = True
    _determine_cache_tagging(will_cache, self.destination_bucket, self.destination_key)
    return state
def test_sequential_complete(self, subtasks=5):
    """Release tasks strictly in order and verify progress advances one step at a time."""
    pending = list(range(subtasks))
    reporter = RecordingReporter()
    runner = parallel_worker.Runner(8, LatchedTask, pending, reporter)
    task: LatchedTask = runner._task
    with ConcurrentContext(runner.run) as context:
        for idx in range(subtasks):
            task.mark_can_run(idx)
            # brief pause so the progress report lands before we inspect it.
            time.sleep(0.1)
            self.assertEqual(reporter.progress_reports[-1], idx + 1)
        results = list(context.result())
    self.assertEqual(len(results), subtasks)
    self.assertTrue(all(results))
def test_random_complete(self):
    """Release tasks out of order and verify the reported first-incomplete index."""
    pending = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    release_order = [5, 4, 0, 8, 1, 9, 2, 3, 7, 6]
    expected_progress = [0, 0, 1, 1, 2, 2, 3, 6, 6, 10]
    reporter = RecordingReporter()
    runner = parallel_worker.Runner(8, LatchedTask, pending, reporter)
    task: LatchedTask = runner._task
    with ConcurrentContext(runner.run) as context:
        for released, expected in zip(release_order, expected_progress):
            task.mark_can_run(released)
            # brief pause so the progress report lands before we inspect it.
            time.sleep(0.1)
            self.assertEqual(reporter.progress_reports[-1], expected)
        results = list(context.result())
    self.assertEqual(len(results), 10)
    self.assertTrue(all(results))
def test_empty_tasklist(self):
    """An empty tasklist should complete immediately with no results."""
    reporter = RecordingReporter()
    runner = parallel_worker.Runner(8, LatchedTask, [], reporter)
    self.assertEqual(list(runner.run()), [])