def _add_residuals_and_channel_splits_to_deferred_inputs(
    self,
    splits,  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
    bundle_context_manager,  # type: execution.BundleContextManager
    last_sent,
    deferred_inputs  # type: MutableMapping[str, PartitionableBuffer]
):
  # type: (...) -> None

  """Collects residuals from bundle splits into ``deferred_inputs``.

  Args:
    splits: split responses returned by the SDK for the last bundle.
    bundle_context_manager: context for the stage being executed.
    last_sent: mapping of transform name to the buffer of encoded elements
      last sent to the SDK; consulted to re-slice channel splits by index.
    deferred_inputs: mutated in place — residual elements are appended
      here, keyed by the name of the input that should replay them.
  """
  # Per-transform index of the previous split's last primary element, so a
  # later (earlier-indexed) split does not re-add elements a prior split
  # already kept as residual.
  prev_stops = {}  # type: Dict[str, int]
  for split in splits:
    # SDK-initiated residuals arrive already encoded; just route each one
    # to the buffer of the input that feeds its transform.
    for delayed_application in split.residual_roots:
      name = bundle_context_manager.input_for(
          delayed_application.application.transform_id,
          delayed_application.application.input_id)
      if name not in deferred_inputs:
        deferred_inputs[name] = ListBuffer(
            coder_impl=bundle_context_manager.get_input_coder_impl(name))
      deferred_inputs[name].append(delayed_application.application.element)
    # Channel splits are expressed as element indices into the data we
    # sent, so we must decode the sent buffer to slice out the residual.
    for channel_split in split.channel_splits:
      coder_impl = bundle_context_manager.get_input_coder_impl(
          channel_split.transform_id)
      # TODO(SDF): This requires deterministic ordering of buffer iteration.
      # TODO(SDF): The return split is in terms of indices. Ideally,
      # a runner could map these back to actual positions to effectively
      # describe the two "halves" of the now-split range. Even if we have
      # to buffer each element we send (or at the very least a bit of
      # metadata, like position, about each of them) this should be doable
      # if they're already in memory and we are bounding the buffer size
      # (e.g. to 10mb plus whatever is eagerly read from the SDK). In the
      # case of non-split-points, we can either immediately replay the
      # "non-split-position" elements or record them as we do the other
      # delayed applications.

      # Decode and recode to split the encoded buffer by element index.
      all_elements = list(
          coder_impl.decode_all(
              b''.join(last_sent[channel_split.transform_id])))
      # Residual range: from the first residual element up to (inclusive)
      # the previous split's last primary element; for the first split of
      # this transform, through the end of the sent buffer.
      residual_elements = all_elements[
          channel_split.first_residual_element:prev_stops.get(
              channel_split.transform_id, len(all_elements)) + 1]
      if residual_elements:
        if channel_split.transform_id not in deferred_inputs:
          coder_impl = bundle_context_manager.get_input_coder_impl(
              channel_split.transform_id)
          deferred_inputs[channel_split.transform_id] = ListBuffer(
              coder_impl=coder_impl)
        deferred_inputs[channel_split.transform_id].append(
            coder_impl.encode_all(residual_elements))
      # Remember where the primary portion ended for subsequent splits.
      prev_stops[
          channel_split.transform_id] = channel_split.last_primary_element
def _collect_written_timers_and_add_to_fired_timers(
    self,
    bundle_context_manager,  # type: execution.BundleContextManager
    fired_timers  # type: Dict[Tuple[str, str], ListBuffer]
):
  # type: (...) -> None

  """Drains timers written by the last bundle and stages them for firing.

  For every timer family in the stage, decodes the timers written to its
  buffer, keeps only the most recent write per (user key, window), drops
  timers whose clear bit is set, and re-encodes the survivors into
  ``fired_timers``.  The written-timer buffer is cleared afterwards.
  """
  for transform_id, timer_family_id in bundle_context_manager.stage.timers:
    written_timers = bundle_context_manager.get_buffer(
        create_buffer_id(timer_family_id, kind='timers'), transform_id)
    timer_coder_impl = bundle_context_manager.get_timer_coder_impl(
        transform_id, timer_family_id)
    if written_timers.cleared:
      continue
    # A later write for the same (key, window) supersedes earlier ones.
    latest_by_key_and_window = {}
    for encoded_timers in written_timers:
      for timer in timer_coder_impl.decode_all(encoded_timers):
        latest_by_key_and_window[timer.user_key, timer.windows[0]] = timer
    out = create_OutputStream()
    for timer in latest_by_key_and_window.values():
      # Only timers that were not cleared actually fire.
      if not timer.clear_bit:
        timer_coder_impl.encode_to_stream(timer, out, True)
    fired_buffer = ListBuffer(coder_impl=timer_coder_impl)
    fired_buffer.append(out.get())
    fired_timers[(transform_id, timer_family_id)] = fired_buffer
    written_timers.clear()
def _run_bundle_multiple_times_for_testing(
    self,
    runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
    bundle_context_manager,  # type: execution.BundleContextManager
    data_input,
    data_output,  # type: DataOutput
    fired_timers,
    expected_output_timers,
    cache_token_generator
):
  # type: (...) -> None

  """If bundle_repeat > 0, replay every bundle for profiling and debugging.

  Each repeat checkpoints runner state first and restores it afterwards so
  the replays leave no visible state changes behind.
  """
  # All workers share state, so use any worker_handler.
  for _ in range(self._bundle_repeat):
    try:
      runner_execution_context.state_servicer.checkpoint()
      testing_bundle_manager = ParallelBundleManager(
          bundle_context_manager.worker_handlers,
          # Fix: the original passed the bound method object itself as
          # coder_impl (missing call parentheses), leaving the buffer with
          # a callable instead of a coder implementation.  Resolve the
          # input coder for the requesting transform instead.
          lambda pcoll_id, transform_id: ListBuffer(
              coder_impl=bundle_context_manager.get_input_coder_impl(
                  transform_id)),
          bundle_context_manager.get_input_coder_impl,
          bundle_context_manager.process_bundle_descriptor,
          self._progress_frequency,
          num_workers=self._num_workers,
          cache_token_generator=cache_token_generator)
      testing_bundle_manager.process_bundle(
          data_input, data_output, fired_timers, expected_output_timers)
    finally:
      # Undo any state mutations the replay made.
      runner_execution_context.state_servicer.restore()
def _add_sdk_delayed_applications_to_deferred_inputs(
    self, bundle_context_manager, bundle_result, deferred_inputs):
  # NOTE(review): this method is redefined later in this file with an
  # identical body plus type comments; at class-creation time the later
  # definition shadows this one, making this copy dead code.  One of the
  # two duplicates should be deleted.
  """Adds SDK-initiated residual roots to ``deferred_inputs`` in place."""
  for delayed_application in bundle_result.process_bundle.residual_roots:
    # Route each residual element to the input feeding its transform.
    name = bundle_context_manager.input_for(
        delayed_application.application.transform_id,
        delayed_application.application.input_id)
    if name not in deferred_inputs:
      deferred_inputs[name] = ListBuffer(
          coder_impl=bundle_context_manager.get_input_coder_impl(name))
    deferred_inputs[name].append(delayed_application.application.element)
def _run_bundle(
    self,
    runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
    bundle_context_manager,  # type: execution.BundleContextManager
    data_input,  # type: Dict[str, execution.PartitionableBuffer]
    data_output,  # type: DataOutput
    input_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
    expected_timer_output,  # type: Dict[Tuple[str, str], bytes]
    bundle_manager  # type: BundleManager
):
  # type: (...) -> Tuple[beam_fn_api_pb2.InstructionResponse, Dict[str, execution.PartitionableBuffer], Dict[Tuple[str, str], ListBuffer]]

  """Execute a bundle, and return a result object, and deferred inputs."""
  # NOTE(review): the definition of _run_bundle_multiple_times_for_testing
  # visible in this file takes (runner_execution_context,
  # bundle_context_manager, data_input, data_output, fired_timers,
  # expected_output_timers, cache_token_generator).  This call passes a
  # BundleManager in the bundle_context_manager position and omits
  # cache_token_generator, so as written it would raise a TypeError at
  # call time.  This looks like a version mismatch between the two
  # methods — confirm the intended signature and reconcile.
  self._run_bundle_multiple_times_for_testing(
      runner_execution_context,
      bundle_manager,
      data_input,
      data_output,
      input_timers,
      expected_timer_output)

  result, splits = bundle_manager.process_bundle(
      data_input, data_output, input_timers, expected_timer_output)
  # Now we collect all the deferred inputs remaining from bundle execution.
  # Deferred inputs can be:
  # - timers
  # - SDK-initiated deferred applications of root elements
  # - Runner-initiated deferred applications of root elements
  deferred_inputs = {}  # type: Dict[str, execution.PartitionableBuffer]

  fired_timers = {}  # type: Dict[Tuple[str, str], ListBuffer]

  self._collect_written_timers_and_add_to_fired_timers(
      bundle_context_manager, fired_timers)

  self._add_sdk_delayed_applications_to_deferred_inputs(
      bundle_context_manager, result, deferred_inputs)

  self._add_residuals_and_channel_splits_to_deferred_inputs(
      splits, bundle_context_manager, data_input, deferred_inputs)

  # After collecting deferred inputs, we 'pad' the structure with empty
  # buffers for other expected inputs.
  if deferred_inputs or fired_timers:
    # The worker will be waiting on these inputs as well.
    for other_input in data_input:
      if other_input not in deferred_inputs:
        deferred_inputs[other_input] = ListBuffer(
            coder_impl=bundle_context_manager.get_input_coder_impl(
                other_input))

  return result, deferred_inputs, fired_timers
def _add_sdk_delayed_applications_to_deferred_inputs(
    self,
    bundle_context_manager,  # type: execution.BundleContextManager
    bundle_result,  # type: beam_fn_api_pb2.InstructionResponse
    deferred_inputs  # type: MutableMapping[str, execution.PartitionableBuffer]
):
  # type: (...) -> None

  """Folds SDK-initiated residual roots into ``deferred_inputs``.

  Each residual root identifies the (transform, input) it should be
  replayed into; its encoded element is appended to that input's buffer,
  which is created on first use.
  """
  for residual in bundle_result.process_bundle.residual_roots:
    input_name = bundle_context_manager.input_for(
        residual.application.transform_id, residual.application.input_id)
    target_buffer = deferred_inputs.get(input_name)
    if target_buffer is None:
      target_buffer = ListBuffer(
          coder_impl=bundle_context_manager.get_input_coder_impl(input_name))
      deferred_inputs[input_name] = target_buffer
    target_buffer.append(residual.application.element)
def _store_side_inputs_in_state(
    self,
    runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
    data_side_input,  # type: DataSideInput
):
  # type: (...) -> None

  """Writes each side input's elements into the runner's state service.

  Buffered elements are regrouped by window (and, for multimap side
  inputs, by key) and appended raw under the matching StateKey so workers
  can read them while running the bundle.  Raises ValueError for an
  unrecognized side-input access pattern.
  """
  for (transform_id, tag), (buffer_id, si) in data_side_input.items():
    _, pcoll_id = split_buffer_id(buffer_id)
    value_coder = runner_execution_context.pipeline_context.coders[
        runner_execution_context.safe_coders[
            runner_execution_context.data_channel_coders[pcoll_id]]]
    # Regroup the buffered elements window-by-window (and key for maps).
    elements_by_window = WindowGroupingBuffer(si, value_coder)
    if buffer_id not in runner_execution_context.pcoll_buffers:
      runner_execution_context.pcoll_buffers[buffer_id] = ListBuffer(
          coder_impl=value_coder.get_impl())
    for element_data in runner_execution_context.pcoll_buffers[buffer_id]:
      elements_by_window.append(element_data)

    state_servicer = (
        runner_execution_context.worker_handler_manager.state_servicer)
    if si.urn == common_urns.side_inputs.ITERABLE.urn:
      for _, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                transform_id=transform_id, side_input_id=tag, window=window))
        state_servicer.append_raw(state_key, elements_data)
    elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
      for key, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id=transform_id,
                side_input_id=tag,
                window=window,
                key=key))
        state_servicer.append_raw(state_key, elements_data)
    else:
      raise ValueError("Unknown access pattern: '%s'" % si.urn)
def _extract_endpoints(bundle_context_manager,  # type: execution.BundleContextManager
                       runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                       ):
  # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput, Dict[Tuple[str, str], str]]

  """Returns maps of transform names to PCollection identifiers.

  Also mutates IO stages to point to the data ApiServiceDescriptor.

  Args:
    bundle_context_manager: manager holding the stage to extract endpoints
      for; its data ApiServiceDescriptor is written into the IO
      transforms' payloads.
    runner_execution_context: pipeline-wide execution information (coders,
      PCollection buffers).
  Returns:
    A tuple of (data_input, data_side_input, data_output,
    expected_timer_output) dictionaries.
      `data_input` maps a transform's unique name to its input
      PCollection buffer; `data_output` maps a transform's unique name to
      its output PCollection ID; `data_side_input` maps
      (transform unique name, tag) to (buffer id, access pattern); and
      `expected_timer_output` maps (transform unique name,
      timer family id) to a timer buffer id.
  """
  data_input = {}  # type: Dict[str, PartitionableBuffer]
  data_side_input = {}  # type: DataSideInput
  data_output = {}  # type: DataOutput
  # A mapping of {(transform_id, timer_family_id) : buffer_id}
  expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
  for transform in bundle_context_manager.stage.transforms:
    if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                              bundle_processor.DATA_OUTPUT_URN):
      # For IO transforms the payload initially carries the PCollection
      # (buffer) id; it is replaced below with a RemoteGrpcPort spec.
      pcoll_id = transform.spec.payload
      if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
        coder_id = runner_execution_context.data_channel_coders[only_element(
            transform.outputs.values())]
        # Fall back to the coder itself when no safe (length-prefixed)
        # substitute is registered.
        coder = runner_execution_context.pipeline_context.coders[
            runner_execution_context.safe_coders.get(coder_id, coder_id)]
        if pcoll_id == translations.IMPULSE_BUFFER:
          # Impulse inputs get a fresh one-element buffer.
          data_input[transform.unique_name] = ListBuffer(
              coder_impl=coder.get_impl())
          data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
        else:
          if pcoll_id not in runner_execution_context.pcoll_buffers:
            runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                coder_impl=coder.get_impl())
          data_input[transform.unique_name] = (
              runner_execution_context.pcoll_buffers[pcoll_id])
      elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
        data_output[transform.unique_name] = pcoll_id
        coder_id = runner_execution_context.data_channel_coders[only_element(
            transform.inputs.values())]
      else:
        raise NotImplementedError
      # Point the IO transform at the runner's data channel endpoint.
      data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
      data_api_service_descriptor = (
          bundle_context_manager.data_api_service_descriptor())
      if data_api_service_descriptor:
        data_spec.api_service_descriptor.url = (
            data_api_service_descriptor.url)
      transform.spec.payload = data_spec.SerializeToString()
    elif transform.spec.urn in translations.PAR_DO_URNS:
      payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
      for tag, si in payload.side_inputs.items():
        data_side_input[transform.unique_name, tag] = (
            create_buffer_id(transform.inputs[tag]), si.access_pattern)
      for timer_family_id in payload.timer_family_specs.keys():
        expected_timer_output[(transform.unique_name, timer_family_id)] = (
            create_buffer_id(timer_family_id, 'timers'))
  return data_input, data_side_input, data_output, expected_timer_output
def _run_stage(self,
               runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
               bundle_context_manager,  # type: execution.BundleContextManager
               ):
  # type: (...) -> beam_fn_api_pb2.InstructionResponse

  """Run an individual stage.

  Args:
    runner_execution_context (execution.FnApiRunnerExecutionContext): An
      object containing execution information for the pipeline.
    bundle_context_manager (execution.BundleContextManager): A description
      of the stage to execute and its execution context.
  """
  worker_handler_list = bundle_context_manager.worker_handlers
  worker_handler_manager = runner_execution_context.worker_handler_manager
  _LOGGER.info('Running %s', bundle_context_manager.stage.name)
  (data_input, data_side_input, data_output,
   expected_timer_output) = self._extract_endpoints(
       bundle_context_manager, runner_execution_context)
  worker_handler_manager.register_process_bundle_descriptor(
      bundle_context_manager.process_bundle_descriptor)

  # Store the required side inputs into state so it is accessible for the
  # worker when it runs this bundle.
  self._store_side_inputs_in_state(runner_execution_context, data_side_input)

  # Change cache token across bundle repeats
  cache_token_generator = FnApiRunner.get_cache_token_generator(static=False)

  self._run_bundle_multiple_times_for_testing(
      runner_execution_context,
      bundle_context_manager,
      data_input,
      data_output, {},
      expected_timer_output,
      cache_token_generator=cache_token_generator)

  bundle_manager = ParallelBundleManager(
      worker_handler_list,
      bundle_context_manager.get_buffer,
      bundle_context_manager.get_input_coder_impl,
      bundle_context_manager.process_bundle_descriptor,
      self._progress_frequency,
      num_workers=self._num_workers,
      cache_token_generator=cache_token_generator)

  # For the first time of processing, we don't have fired timers as inputs.
  result, splits = bundle_manager.process_bundle(data_input,
                                                 data_output, {},
                                                 expected_timer_output)

  last_result = result
  last_sent = data_input

  # We cannot split deferred_input until we include residual_roots to
  # merged results. Without residual_roots, the pipeline would stop
  # earlier and we could miss some data.  We also don't partition fired
  # timer inputs for the same reason, so re-execution below runs with a
  # single worker.
  bundle_manager._num_workers = 1
  # Replay loop: keep re-running the bundle while the previous run left
  # deferred inputs (timers, SDK residuals, or runner-initiated splits).
  while True:
    deferred_inputs = {}  # type: Dict[str, PartitionableBuffer]
    fired_timers = {}

    self._collect_written_timers_and_add_to_fired_timers(
        bundle_context_manager, fired_timers)
    # Queue any process-initiated delayed bundle applications.
    # NOTE(review): this loop duplicates
    # _add_sdk_delayed_applications_to_deferred_inputs defined elsewhere
    # in this file; consider calling that helper instead.
    for delayed_application in last_result.process_bundle.residual_roots:
      name = bundle_context_manager.input_for(
          delayed_application.application.transform_id,
          delayed_application.application.input_id)
      if name not in deferred_inputs:
        deferred_inputs[name] = ListBuffer(
            coder_impl=bundle_context_manager.get_input_coder_impl(name))
      deferred_inputs[name].append(delayed_application.application.element)
    # Queue any runner-initiated delayed bundle applications.
    self._add_residuals_and_channel_splits_to_deferred_inputs(
        splits, bundle_context_manager, last_sent, deferred_inputs)

    if deferred_inputs or fired_timers:
      # The worker will be waiting on these inputs as well.
      for other_input in data_input:
        if other_input not in deferred_inputs:
          deferred_inputs[other_input] = ListBuffer(
              coder_impl=bundle_context_manager.get_input_coder_impl(
                  other_input))
      # TODO(robertwb): merge results
      last_result, splits = bundle_manager.process_bundle(
          deferred_inputs, data_output, fired_timers, expected_timer_output)
      last_sent = deferred_inputs
      # Fold the latest run's monitoring infos and error status into the
      # accumulated result.
      result = beam_fn_api_pb2.InstructionResponse(
          process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
              monitoring_infos=monitoring_infos.consolidate(
                  itertools.chain(
                      result.process_bundle.monitoring_infos,
                      last_result.process_bundle.monitoring_infos))),
          error=result.error or last_result.error)
    else:
      break

  return result