def test_concat_source_to_shuffle_sink(self):
  work = workitem.get_work_items(
      get_concat_source_to_shuffle_sink_message())
  self.assertIsNotNone(work)
  expected_sub_sources = []
  expected_sub_sources.append(
      io.TextFileSource(file_path='gs://sort_g/input_small_files/'
                        'ascii_sort_1MB_input.0000006',
                        start_offset=0, end_offset=1000000,
                        strip_trailing_newlines=True, coder=CODER))
  expected_sub_sources.append(
      io.TextFileSource(file_path='gs://sort_g/input_small_files/'
                        'ascii_sort_1MB_input.0000007',
                        start_offset=0, end_offset=1000000,
                        strip_trailing_newlines=True, coder=CODER))
  expected_concat_source = concat_reader.ConcatSource(expected_sub_sources)
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(expected_concat_source, output_coders=[CODER]),
          maptask.WorkerDoFn(serialized_fn='code', output_tags=['out'],
                             input=(1, 0), side_inputs=[],
                             output_coders=[CODER]),
          maptask.WorkerShuffleWrite(shuffle_kind='group_keys',
                                     shuffle_writer_config='opaque',
                                     input=(1, 0),
                                     output_coders=(CODER,))
      ]))
def test_concat_source_to_shuffle_sink(self):
  work = workitem.get_work_items(get_concat_source_to_shuffle_sink_message())
  self.assertIsNotNone(work)
  expected_sub_sources = []
  expected_sub_sources.append(
      io.TextFileSource(
          file_path='gs://sort_g/input_small_files/'
          'ascii_sort_1MB_input.0000006',
          start_offset=0, end_offset=1000000,
          strip_trailing_newlines=True, coder=CODER))
  expected_sub_sources.append(
      io.TextFileSource(
          file_path='gs://sort_g/input_small_files/'
          'ascii_sort_1MB_input.0000007',
          start_offset=0, end_offset=1000000,
          strip_trailing_newlines=True, coder=CODER))
  expected_concat_source = concat_reader.ConcatSource(expected_sub_sources)
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(expected_concat_source, tag=None),
          maptask.WorkerDoFn(serialized_fn='code', output_tags=['out'],
                             input=(1, 0), side_inputs=[]),
          maptask.WorkerShuffleWrite(
              shuffle_kind='group_keys',
              shuffle_writer_config='opaque',
              input=(1, 0),
              coders=(CODER.key_coder(), CODER.value_coder()))]))
def build_split_work_item(self, split_proto):
  lease_work_item_response_proto = dataflow.LeaseWorkItemResponse()
  work_item_proto = dataflow.WorkItem()
  lease_work_item_response_proto.workItems = [work_item_proto]
  source_operation_task = dataflow.SourceOperationRequest()
  work_item_proto.sourceOperationTask = source_operation_task
  source_operation_task.split = split_proto
  return workitem.get_work_items(lease_work_item_response_proto)
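# Usage sketch (not part of the original tests): build_split_work_item wraps a
# split request proto in a LeaseWorkItemResponse so that workitem.get_work_items
# parses it as a source-operation work item rather than a map task. A test built
# on it might look like the following; the proto type used for split_proto is an
# assumption, not taken from this file.
#
#   def test_custom_source_split(self):
#     split_proto = dataflow.SourceSplitRequest()  # hypothetical proto type
#     work = self.build_split_work_item(split_proto)
#     self.assertIsNotNone(work)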
def test_in_memory_source_to_flatten(self):
  work = workitem.get_work_items(get_in_memory_source_to_flatten_message())
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(
              inmemory.InMemorySource(
                  start_index=1, end_index=3,
                  elements=[base64.b64decode(v['value'])
                            for v in IN_MEMORY_ELEMENTS],
                  coder=CODER),
              output_coders=[CODER]),
          maptask.WorkerFlatten(inputs=[(0, 0)], output_coders=[CODER])
      ]))
def test_in_memory_source_to_flatten(self):
  work = workitem.get_work_items(get_in_memory_source_to_flatten_message())
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(
              inmemory.InMemorySource(
                  start_index=1, end_index=3,
                  elements=[base64.b64decode(v['value'])
                            for v in IN_MEMORY_ELEMENTS],
                  coder=CODER),
              tag=None),
          maptask.WorkerFlatten(inputs=[(0, 0)])]))
def test_ungrouped_shuffle_source_to_text_sink(self):
  work = workitem.get_work_items(
      get_shuffle_source_to_text_sink_message(UNGROUPED_SHUFFLE_SOURCE_SPEC))
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerUngroupedShuffleRead(
              start_shuffle_position='opaque',
              end_shuffle_position='opaque',
              shuffle_reader_config='opaque',
              coders=(CODER.key_coder(), CODER.value_coder())),
          maptask.WorkerWrite(io.TextFileSink(
              file_path_prefix='gs://somefile',
              append_trailing_newlines=True,
              coder=CODER), input=(0, 0))]))
def test_shuffle_source_to_text_sink(self):
  work = workitem.get_work_items(
      get_shuffle_source_to_text_sink_message(GROUPING_SHUFFLE_SOURCE_SPEC))
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerGroupingShuffleRead(
              start_shuffle_position='opaque',
              end_shuffle_position='opaque',
              shuffle_reader_config='opaque',
              coder=CODER,
              output_coders=[CODER]),
          maptask.WorkerWrite(fileio.NativeTextFileSink(
              file_path_prefix='gs://somefile',
              append_trailing_newlines=True,
              coder=CODER), input=(0, 0), output_coders=(CODER,))]))
def test_in_memory_source_to_text_sink(self):
  work = workitem.get_work_items(get_in_memory_source_to_text_sink_message())
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(
              inmemory.InMemorySource(
                  start_index=1, end_index=3,
                  elements=[base64.b64decode(v['value'])
                            for v in IN_MEMORY_ELEMENTS],
                  coder=CODER),
              tag=None),
          maptask.WorkerWrite(io.TextFileSink(
              file_path_prefix='gs://somefile',
              append_trailing_newlines=True,
              coder=CODER), input=(0, 0))]))
def test_ungrouped_shuffle_source_to_text_sink(self):
  work = workitem.get_work_items(
      get_shuffle_source_to_text_sink_message(UNGROUPED_SHUFFLE_SOURCE_SPEC))
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerUngroupedShuffleRead(
              start_shuffle_position='opaque',
              end_shuffle_position='opaque',
              shuffle_reader_config='opaque',
              coder=CODER,
              output_coders=[CODER]),
          maptask.WorkerWrite(fileio.NativeTextFileSink(
              file_path_prefix='gs://somefile',
              append_trailing_newlines=True,
              coder=CODER), input=(0, 0), output_coders=(CODER,))
      ]))
def test_text_source_to_shuffle_sink(self):
  work = workitem.get_work_items(get_text_source_to_shuffle_sink_message())
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(io.TextFileSource(
              file_path='gs://somefile',
              start_offset=123,
              end_offset=123123,
              strip_trailing_newlines=True,
              coder=CODER), tag=None),
          maptask.WorkerDoFn(
              serialized_fn='code', output_tags=['out'],
              input=(1, 0), side_inputs=[]),
          maptask.WorkerShuffleWrite(
              shuffle_kind='group_keys',
              shuffle_writer_config='opaque',
              input=(1, 0),
              coders=(CODER.key_coder(), CODER.value_coder()))]))
def test_in_memory_source_to_text_sink(self):
  work = workitem.get_work_items(get_in_memory_source_to_text_sink_message())
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(
              inmemory.InMemorySource(
                  start_index=1, end_index=3,
                  elements=[base64.b64decode(v['value'])
                            for v in IN_MEMORY_ELEMENTS],
                  coder=CODER),
              output_coders=[CODER]),
          maptask.WorkerWrite(fileio.NativeTextFileSink(
              file_path_prefix='gs://somefile',
              append_trailing_newlines=True,
              coder=CODER), input=(0, 0), output_coders=(CODER,))
      ]))
def test_text_source_to_shuffle_sink(self):
  work = workitem.get_work_items(get_text_source_to_shuffle_sink_message())
  self.assertEqual(
      (work.proto.id, work.map_task.operations),
      (1234, [
          maptask.WorkerRead(io.TextFileSource(
              file_path='gs://somefile',
              start_offset=123,
              end_offset=123123,
              strip_trailing_newlines=True,
              coder=CODER), output_coders=[CODER]),
          maptask.WorkerDoFn(
              serialized_fn='code', output_tags=['out'],
              input=(1, 0), side_inputs=[],
              output_coders=[CODER]),
          maptask.WorkerShuffleWrite(
              shuffle_kind='group_keys',
              shuffle_writer_config='opaque',
              input=(1, 0),
              output_coders=(CODER,))
      ]))
def run(self):
  """Runs the worker loop for leasing and executing work items."""
  if self.running_in_gce:
    auth.set_running_in_gce(worker_executing_project=self.project_id)

  # Deferred exceptions are used as a way to report unrecoverable errors that
  # happen before they could be reported to the service. If it is not None,
  # worker will use the first work item to report deferred exceptions and
  # fail eventually.
  deferred_exception_details = None

  logging.info('Loading main session from the staging area...')
  try:
    self._load_main_session(self.local_staging_directory)
  except Exception:  # pylint: disable=broad-except
    deferred_exception_details = traceback.format_exc()
    logging.error('Could not load main session: %s',
                  deferred_exception_details, exc_info=True)

  # Start status HTTP server thread.
  thread = threading.Thread(target=self.status_server)
  thread.daemon = True
  thread.start()

  # Start the progress reporting thread.
  thread = threading.Thread(target=self.progress_reporting_thread)
  thread.daemon = True
  thread.start()

  # The batch execution context is currently a placeholder, so we don't yet
  # need to have it change between work items.
  execution_context = maptask.BatchExecutionContext()
  work_item = None
  # Loop forever leasing work items, executing them, and reporting status.
  while True:
    # TODO(silviuc): Do we still need the outer try/except?
    try:
      # Lease a work item. The lease_work call will retry for server errors
      # (e.g., 500s) however it will not retry for a 404 (no item to lease).
      # In such cases we introduce random sleep delays with the code below.
      should_sleep = False
      try:
        work = self.client.lease_work(self)
        work_item = workitem.get_work_items(work, self.environment,
                                            execution_context)
        if work_item is None:
          should_sleep = True
      except HttpError as exn:
        # Not found errors (404) are benign. The rest are not and must be
        # re-raised.
        if exn.status_code != 404:
          raise
        should_sleep = True
      if should_sleep:
        logging.debug('No work items. Sleeping a bit ...')
        # The sleeping is done with a bit of jitter to avoid having workers
        # requesting leases in lock step.
        time.sleep(1.0 * (1 - 0.5 * random.random()))
        continue

      with logger.PerThreadLoggingContext(
          work_item_id=work_item.proto.id,
          stage_name=work_item.map_task.stage_name):
        # TODO(silviuc): Add more detailed timing and profiling support.
        start_time = time.time()
        if deferred_exception_details:
          # Report (fatal) deferred exceptions that happened earlier. This
          # workflow will fail with the deferred exception.
          with work_item.lock:
            self.set_current_work_item_and_executor(
                work_item, executor.MapTaskExecutor())
          work_item.map_task.executed_operations = []
          self.report_completion_status(
              work_item, exception_details=deferred_exception_details)
          work_item.done = True
        else:
          # Do the work. The do_work() call will mark the work completed or
          # failed. The progress reporting_thread will take care of sending
          # updates and updating in the workitem object the reporting indexes
          # and duration for the lease.
          self.do_work(work_item)
        logging.info('Completed work item: %s in %.9f seconds',
                     work_item.proto.id, time.time() - start_time)
    except Exception:  # pylint: disable=broad-except
      # This is an exception raised outside of executing a work item most
      # likely while leasing a work item. We log an error and march on.
      logging.error('Exception in worker loop: %s',
                    traceback.format_exc(), exc_info=True)
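# Sketch (assumption, not part of the worker code): the deferred-exception
# mechanism above amounts to capturing a startup traceback instead of raising,
# then reporting it against the first leased work item so the workflow fails
# visibly through the service. Stripped down, with hypothetical helper names:
#
#   deferred_exception_details = None
#   try:
#     risky_startup_step()          # hypothetical, e.g. loading the main session
#   except Exception:               # pylint: disable=broad-except
#     deferred_exception_details = traceback.format_exc()
#   ...
#   if deferred_exception_details:
#     report_failure(work_item, deferred_exception_details)  # hypothetical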
def run(self):
  """Runs the worker loop for leasing and executing work items."""
  if self.running_in_gce:
    auth.set_running_in_gce(worker_executing_project=self.project_id)

  # Deferred exceptions are used as a way to report unrecoverable errors that
  # happen before they could be reported to the service. If it is not None,
  # worker will use the first work item to report deferred exceptions and
  # fail eventually.
  # TODO(silviuc): Add the deferred exception mechanism to streaming worker
  deferred_exception_details = None

  if self.environment_info_path is not None:
    try:
      environment.check_sdk_compatibility(self.environment_info_path)
    except Exception:  # pylint: disable=broad-except
      deferred_exception_details = traceback.format_exc()
      logging.error('SDK compatibility check failed: %s',
                    deferred_exception_details, exc_info=True)

  if deferred_exception_details is None:
    logging.info('Loading main session from the staging area...')
    try:
      self._load_main_session(self.local_staging_directory)
    except Exception:  # pylint: disable=broad-except
      deferred_exception_details = traceback.format_exc()
      logging.error('Could not load main session: %s',
                    deferred_exception_details, exc_info=True)

  # Start status HTTP server thread.
  thread = threading.Thread(target=self.status_server)
  thread.daemon = True
  thread.start()

  # The batch execution context is currently a placeholder, so we don't yet
  # need to have it change between work items.
  execution_context = maptask.BatchExecutionContext()
  work_item = None
  # Loop forever leasing work items, executing them, and reporting status.
  while not self._shutdown:
    try:
      # Lease a work item. The lease_work call will retry for server errors
      # (e.g., 500s) however it will not retry for a 404 (no item to lease).
      # In such cases we introduce random sleep delays with the code below.
      should_sleep = False
      try:
        work = self.client.lease_work(self.worker_info_for_client(),
                                      self.default_desired_lease_duration())
        work_item = workitem.get_work_items(work, self.environment,
                                            execution_context)
        if work_item is None:
          should_sleep = True
      except HttpError as exn:
        # Not found errors (404) are benign. The rest are not and must be
        # re-raised.
        if exn.status_code != 404:
          raise
        should_sleep = True
      if should_sleep:
        logging.debug('No work items. Sleeping a bit ...')
        # The sleeping is done with a bit of jitter to avoid having workers
        # requesting leases in lock step.
        time.sleep(1.0 * (1 - 0.5 * random.random()))
        continue

      stage_name = None
      if work_item.map_task:
        stage_name = work_item.map_task.stage_name
      with logger.PerThreadLoggingContext(
          work_item_id=work_item.proto.id, stage_name=stage_name):
        # TODO(silviuc): Add more detailed timing and profiling support.
        start_time = time.time()
        # Do the work. The do_work() call will mark the work completed or
        # failed. The progress reporting_thread will take care of sending
        # updates and updating in the workitem object the reporting indexes
        # and duration for the lease.
        if self.work_item_profiling:
          with profiler.Profile(profile_id=work_item.proto.id,
                                profile_location=self.profile_location,
                                log_results=True):
            self.do_work(
                work_item,
                deferred_exception_details=deferred_exception_details)
        else:
          self.do_work(
              work_item,
              deferred_exception_details=deferred_exception_details)
        logging.info('Completed work item: %s in %.9f seconds',
                     work_item.proto.id, time.time() - start_time)
    except Exception:  # pylint: disable=broad-except
      # This is an exception raised outside of executing a work item most
      # likely while leasing a work item. We log an error and march on.
      logging.error('Exception in worker loop: %s',
                    traceback.format_exc(), exc_info=True)
      # Sleeping a bit after Exception to prevent a busy loop.
      time.sleep(1)
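# Sketch (assumption, not part of the worker code): the jittered sleep used in
# the lease loop above keeps idle workers from polling the service in lock
# step. Each worker waits a uniformly random duration in [0.5 * base, base]
# seconds, so retries from a fleet of workers spread out over time:
#
#   import random
#   import time
#
#   def jittered_sleep(base_seconds=1.0):  # hypothetical helper name
#     # random.random() is uniform in [0, 1), so the factor is in (0.5, 1.0].
#     time.sleep(base_seconds * (1 - 0.5 * random.random()))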