def merge(self, merge_context):
  # type: (window.WindowFn.MergeContext) -> None
  worker_handler = self.worker_handle()

  assert self.windowed_input_coder_impl is not None
  assert self.windowed_output_coder_impl is not None
  process_bundle_id = self.uid('process')
  # Send the windows to be merged to the SDK worker.
  to_worker = worker_handler.data_conn.output_stream(
      process_bundle_id, self.TO_SDK_TRANSFORM)
  to_worker.write(
      self.windowed_input_coder_impl.encode_nested(
          window.GlobalWindows.windowed_value((b'', merge_context.windows))))
  to_worker.close()

  # Run a bundle on the worker to perform the actual merging.
  process_bundle_req = beam_fn_api_pb2.InstructionRequest(
      instruction_id=process_bundle_id,
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_id=self._bundle_processor_id))
  result_future = worker_handler.control_conn.push(process_bundle_req)

  # Read back the merge results and apply them via the merge context.
  for output in worker_handler.data_conn.input_elements(
      process_bundle_id, [self.FROM_SDK_TRANSFORM],
      abort_callback=lambda: bool(
          result_future.is_done() and result_future.get().error)):
    if isinstance(output, beam_fn_api_pb2.Elements.Data):
      windowed_result = self.windowed_output_coder_impl.decode_nested(
          output.data)
      for merge_result, originals in windowed_result.value[1][1]:
        merge_context.merge(originals, merge_result)
    else:
      raise RuntimeError("Unexpected data: %s" % output)

  result = result_future.get()
  if result.error:
    raise RuntimeError(result.error)
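# --- Illustrative sketch (not Beam code) of the abort-callback pattern used
# above and in the process_bundle variants below: a data-read loop polls a
# control-plane signal between reads, so a failed bundle aborts the read
# instead of blocking forever. All names here (read_elements, data_queue,
# the None sentinel) are hypothetical.
import queue
import threading

def read_elements(data_queue, abort_callback):
  # Yield elements until a None sentinel, checking for aborts between reads.
  while True:
    if abort_callback():
      raise RuntimeError('Control plane reported an error; aborting read.')
    try:
      element = data_queue.get(timeout=0.1)
    except queue.Empty:
      continue
    if element is None:  # Sentinel: the stream was closed cleanly.
      return
    yield element

# Usage: a late "bundle failure" unblocks the reader via the abort callback.
failed = threading.Event()
q = queue.Queue()
q.put(b'element-1')
threading.Timer(0.2, failed.set).start()  # Simulate a bundle failure.
try:
  for e in read_elements(q, abort_callback=failed.is_set):
    print(e)
except RuntimeError as exc:
  print(exc)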
def process_bundle(self, inputs, expected_outputs):
  # Unique id for the instruction processing this bundle.
  BundleManager._uid_counter += 1
  process_bundle_id = 'bundle_%s' % BundleManager._uid_counter

  # Register the bundle descriptor, if needed.
  if self._registered:
    registration_future = None
  else:
    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[self._bundle_descriptor]))
    registration_future = self._controller.control_handler.push(
        process_bundle_registration)
    self._registered = True

  # Write all the input data to the channel.
  for (transform_id, name), elements in inputs.items():
    data_out = self._controller.data_plane_handler.output_stream(
        process_bundle_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    for element_data in elements:
      data_out.write(element_data)
    data_out.close()

  # Actually start the bundle.
  if registration_future and registration_future.get().error:
    raise RuntimeError(registration_future.get().error)
  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=process_bundle_id,
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=self._bundle_descriptor.id))
  result_future = self._controller.control_handler.push(process_bundle)

  with ProgressRequester(
      self._controller, process_bundle_id, self._progress_frequency):
    # Gather all output data.
    expected_targets = [
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=output_name)
        for (transform_id, output_name), _ in expected_outputs.items()]
    logging.debug('Gather all output data from %s.', expected_targets)
    for output in self._controller.data_plane_handler.input_elements(
        process_bundle_id,
        expected_targets,
        abort_callback=lambda: (
            result_future.is_done() and result_future.get().error)):
      target_tuple = (
          output.target.primitive_transform_reference, output.target.name)
      if target_tuple in expected_outputs:
        self._get_buffer(expected_outputs[target_tuple]).append(output.data)

    logging.debug('Wait for the bundle to finish.')
    result = result_future.get()

  if result.error:
    raise RuntimeError(result.error)

  return result
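# --- Illustrative sketch (not Beam code) of the register-once handshake
# above: the descriptor registration is pushed only before the first bundle,
# and only the bundle that sent it waits on the registration future before
# trusting its own result. OnceRegistrar and register_fn are hypothetical.
import concurrent.futures

class OnceRegistrar(object):
  def __init__(self, executor, register_fn):
    self._executor = executor
    self._register_fn = register_fn
    self._registered = False

  def ensure_registered(self):
    # Returns a future on the first call, None on every later call.
    if self._registered:
      return None
    future = self._executor.submit(self._register_fn)
    self._registered = True
    return future

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
registrar = OnceRegistrar(executor, lambda: 'ok')
for bundle in range(3):
  registration_future = registrar.ensure_registered()
  if registration_future is not None and registration_future.result() != 'ok':
    raise RuntimeError('registration failed')
  print('bundle %d: registration %s' % (
      bundle, 'sent' if registration_future else 'skipped'))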
def _run_map_task(
    self, map_task, control_handler, state_handler, data_plane_handler,
    data_operation_spec):
  registration, sinks, input_data = self._map_task_registration(
      map_task, state_handler, data_operation_spec)
  control_handler.push(registration)
  # The bundle references the descriptor that was just registered.
  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=registration.register.
          process_bundle_descriptor[0].id))

  for (transform_id, name), elements in input_data.items():
    data_out = data_plane_handler.output_stream(
        process_bundle.instruction_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    data_out.write(elements)
    data_out.close()

  control_handler.push(process_bundle)
  while True:
    result = control_handler.pull()
    if result.instruction_id == process_bundle.instruction_id:
      if result.error:
        raise RuntimeError(result.error)
      expected_targets = [
          beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=output_name)
          for (transform_id, output_name), _ in sinks.items()]
      # Decode each output chunk with the sink's coder and buffer the
      # resulting elements.
      for output in data_plane_handler.input_elements(
          process_bundle.instruction_id, expected_targets):
        target_tuple = (
            output.target.primitive_transform_reference, output.target.name)
        if target_tuple not in sinks:
          # Unconsumed output.
          continue
        sink_op = sinks[target_tuple]
        coder = sink_op.output_coders[0]
        input_stream = create_InputStream(output.data)
        elements = []
        while input_stream.size() > 0:
          elements.append(coder.get_impl().decode_from_stream(
              input_stream, True))
        if not sink_op.write_windowed_values:
          elements = [e.value for e in elements]
        for e in elements:
          sink_op.output_buffer.append(e)
      return
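# --- Illustrative sketch (not Beam's coder API) of the drain-the-stream
# decode loop above: one data chunk may carry many concatenated elements, so
# decoding repeats until the input is exhausted. The 4-byte length prefix is
# a toy framing, not Beam's actual wire format.
import struct

def encode_all(elements):
  return b''.join(struct.pack('>I', len(e)) + e for e in elements)

def decode_all(data):
  elements = []
  offset = 0
  while offset < len(data):  # Same shape as `while input_stream.size() > 0`.
    (size,) = struct.unpack_from('>I', data, offset)
    offset += 4
    elements.append(data[offset:offset + size])
    offset += size
  return elements

assert decode_all(encode_all([b'a', b'bc', b''])) == [b'a', b'bc', b'']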
def run_stage(
    self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):

  context = pipeline_context.PipelineContext(pipeline_components)
  data_operation_spec = controller.data_operation_spec()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data_operation_spec.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          data_input[target] = pcoll_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
        else:
          raise NotImplementedError
        if data_operation_spec:
          transform.spec.payload = data_operation_spec.SerializeToString()
        else:
          transform.spec.payload = ""
      elif transform.spec.urn == urns.PARDO_TRANSFORM:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              'materialize:' + transform.inputs[tag],
              beam.pvalue.SideInputData.from_runner_api(si, None))
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug(' %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={transform.unique_name: transform
                  for transform in stage.transforms},
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      register=beam_fn_api_pb2.RegisterRequest(
          process_bundle_descriptor=[process_bundle_descriptor]))

  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=process_bundle_descriptor.id))

  # Write all the input data to the channel.
  for (transform_id, name), pcoll_id in data_input.items():
    data_out = controller.data_plane_handler.output_stream(
        process_bundle.instruction_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    for element_data in pcoll_buffers[pcoll_id]:
      data_out.write(element_data)
    data_out.close()

  # Store the required side inputs into state.
  for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
    elements_by_window = _WindowGroupingBuffer(si)
    for element_data in pcoll_buffers[pcoll_id]:
      elements_by_window.append(element_data)
    for window, elements_data in elements_by_window.items():
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=transform_id,
              side_input_id=tag,
              window=window))
      controller.state_handler.blocking_append(
          state_key, elements_data, process_bundle.instruction_id)

  # Register and start running the bundle.
  logging.debug('Register and start running the bundle')
  controller.control_handler.push(process_bundle_registration)
  controller.control_handler.push(process_bundle)

  # Wait for the bundle to finish.
  logging.debug('Wait for the bundle to finish.')
  while True:
    result = controller.control_handler.pull()
    if result and result.instruction_id == process_bundle.instruction_id:
      if result.error:
        raise RuntimeError(result.error)
      break

  expected_targets = [
      beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                             name=output_name)
      for (transform_id, output_name), _ in data_output.items()]

  # Gather all output data.
  logging.debug('Gather all output data from %s.', expected_targets)
  for output in controller.data_plane_handler.input_elements(
      process_bundle.instruction_id, expected_targets):
    target_tuple = (
        output.target.primitive_transform_reference, output.target.name)
    if target_tuple in data_output:
      pcoll_id = data_output[target_tuple]
      if pcoll_id.startswith('materialize:'):
        # Just store the data chunks for replay.
        pcoll_buffers[pcoll_id].append(output.data)
      elif pcoll_id.startswith('group:'):
        # This is a grouping write, create a grouping buffer if needed.
        if pcoll_id not in pcoll_buffers:
          original_gbk_transform = pcoll_id.split(':', 1)[1]
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(transform_proto.inputs.values())
          output_pcoll = only_element(transform_proto.outputs.values())
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components.pcollections[output_pcoll]
              .windowing_strategy_id]
          pcoll_buffers[pcoll_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
        pcoll_buffers[pcoll_id].append(output.data)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(pcoll_id)
  return result
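# --- Illustrative sketch (not Beam code) of the buffer-id dispatch above:
# the prefix of the PCollection buffer id selects the buffering strategy,
# and grouping buffers are created lazily on first write. GroupingBuffer and
# route_output are hypothetical stand-ins for _GroupingBuffer and the inline
# dispatch logic.
class GroupingBuffer(object):
  def __init__(self):
    self.chunks = []

  def append(self, data):
    self.chunks.append(data)

def route_output(pcoll_buffers, pcoll_id, data):
  if pcoll_id.startswith('materialize:'):
    # Raw chunks are stored as-is for replay into downstream stages.
    pcoll_buffers.setdefault(pcoll_id, []).append(data)
  elif pcoll_id.startswith('group:'):
    if pcoll_id not in pcoll_buffers:  # Lazily create the grouping buffer.
      pcoll_buffers[pcoll_id] = GroupingBuffer()
    pcoll_buffers[pcoll_id].append(data)
  else:
    raise NotImplementedError(pcoll_id)

buffers = {}
route_output(buffers, 'materialize:pc1', b'chunk')
route_output(buffers, 'group:MyGBK', b'kv-chunk')
print(sorted(buffers))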
def process_bundle(
    self,
    inputs,  # type: Mapping[str, execution.PartitionableBuffer]
    expected_outputs,  # type: DataOutput
    fired_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
    expected_output_timers,  # type: Dict[Tuple[str, str], str]
    dry_run=False,
):
  # type: (...) -> BundleProcessResult
  # Unique id for the instruction processing this bundle.
  with BundleManager._lock:
    BundleManager._uid_counter += 1
    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
    self._worker_handler = self.bundle_context_manager.worker_handlers[
        BundleManager._uid_counter % len(
            self.bundle_context_manager.worker_handlers)]

  split_manager = self._select_split_manager()
  if not split_manager:
    # Send timers.
    for transform_id, timer_family_id in expected_output_timers.keys():
      self._send_timers_to_worker(
          process_bundle_id,
          transform_id,
          timer_family_id,
          fired_timers.get((transform_id, timer_family_id), []))

    # If there is no split_manager, write all input data to the channel.
    for transform_id, elements in inputs.items():
      self._send_input_to_worker(process_bundle_id, transform_id, elements)

  # Actually start the bundle.
  process_bundle_req = beam_fn_api_pb2.InstructionRequest(
      instruction_id=process_bundle_id,
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_id=self.bundle_context_manager.
          process_bundle_descriptor.id,
          cache_tokens=[next(self._cache_token_generator)]))
  result_future = self._worker_handler.control_conn.push(process_bundle_req)

  split_results = []  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
  with ProgressRequester(self._worker_handler,
                         process_bundle_id,
                         self._progress_frequency):
    if split_manager:
      split_results = self._generate_splits_for_testing(
          split_manager, inputs, process_bundle_id)

    expect_reads = list(expected_outputs.keys())
    expect_reads.extend(list(expected_output_timers.keys()))

    # Gather all output data.
    for output in self._worker_handler.data_conn.input_elements(
        process_bundle_id,
        expect_reads,
        abort_callback=lambda: (
            result_future.is_done() and result_future.get().error)):
      if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run:
        with BundleManager._lock:
          timer_buffer = self.bundle_context_manager.get_buffer(
              expected_output_timers[(
                  output.transform_id, output.timer_family_id)],
              output.transform_id)
          if timer_buffer.cleared:
            timer_buffer.reset()
          timer_buffer.append(output.timers)
      if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run:
        with BundleManager._lock:
          self.bundle_context_manager.get_buffer(
              expected_outputs[output.transform_id],
              output.transform_id).append(output.data)

    _LOGGER.debug('Wait for the bundle %s to finish.' % process_bundle_id)
    result = result_future.get()  # type: beam_fn_api_pb2.InstructionResponse

  if result.error:
    raise RuntimeError(result.error)

  if result.process_bundle.requires_finalization:
    finalize_request = beam_fn_api_pb2.InstructionRequest(
        finalize_bundle=beam_fn_api_pb2.FinalizeBundleRequest(
            instruction_id=process_bundle_id))
    self._worker_handler.control_conn.push(finalize_request)

  return result, split_results
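# --- Illustrative sketch (not Beam code) of the locked round-robin worker
# selection above: a class-level counter, bumped under a class-level lock,
# both names each bundle and picks which worker handler runs it. RoundRobin
# and next_assignment are hypothetical names.
import threading

class RoundRobin(object):
  _lock = threading.Lock()
  _uid_counter = 0

  def __init__(self, workers):
    self._workers = workers

  def next_assignment(self):
    # Increment the shared counter atomically, then pick a worker from it.
    with RoundRobin._lock:
      RoundRobin._uid_counter += 1
      uid = RoundRobin._uid_counter
    return 'bundle_%s' % uid, self._workers[uid % len(self._workers)]

rr = RoundRobin(['worker-a', 'worker-b'])
for _ in range(3):
  print(rr.next_assignment())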
def run_stage(
    self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):

  coders = pipeline_context.PipelineContext(pipeline_components).coders
  data_operation_spec = controller.data_operation_spec()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data_operation_spec.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      pcoll_id = transform.spec.payload
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          data_input[target] = pcoll_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
        else:
          raise NotImplementedError
        if data_operation_spec:
          transform.spec.payload = data_operation_spec.SerializeToString()
          transform.spec.any_param.Pack(data_operation_spec)
        else:
          transform.spec.payload = ""
          transform.spec.any_param.Clear()
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug(' %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)

  if data_side_input:
    raise NotImplementedError('Side inputs.')

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={transform.unique_name: transform
                  for transform in stage.transforms},
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      register=beam_fn_api_pb2.RegisterRequest(
          process_bundle_descriptor=[process_bundle_descriptor]))

  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=process_bundle_descriptor.id))

  # Write all the input data to the channel.
  for (transform_id, name), pcoll_id in data_input.items():
    data_out = controller.data_plane_handler.output_stream(
        process_bundle.instruction_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    for element_data in pcoll_buffers[pcoll_id]:
      data_out.write(element_data)
    data_out.close()

  # Register and start running the bundle.
  controller.control_handler.push(process_bundle_registration)
  controller.control_handler.push(process_bundle)

  # Wait for the bundle to finish.
  while True:
    result = controller.control_handler.pull()
    if result and result.instruction_id == process_bundle.instruction_id:
      if result.error:
        raise RuntimeError(result.error)
      break

  # Gather all output data.
  expected_targets = [
      beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                             name=output_name)
      for (transform_id, output_name), _ in data_output.items()]
  for output in controller.data_plane_handler.input_elements(
      process_bundle.instruction_id, expected_targets):
    target_tuple = (
        output.target.primitive_transform_reference, output.target.name)
    if target_tuple in data_output:
      pcoll_id = data_output[target_tuple]
      if pcoll_id.startswith('materialize:'):
        # Just store the data chunks for replay.
        pcoll_buffers[pcoll_id].append(output.data)
      elif pcoll_id.startswith('group:'):
        # This is a grouping write, create a grouping buffer if needed.
        if pcoll_id not in pcoll_buffers:
          original_gbk_transform = pcoll_id.split(':', 1)[1]
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(transform_proto.inputs.values())
          output_pcoll = only_element(transform_proto.outputs.values())
          pre_gbk_coder = coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          pcoll_buffers[pcoll_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder)
        pcoll_buffers[pcoll_id].append(output.data)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(pcoll_id)
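# --- Illustrative sketch (not Beam code): recovering the originating GBK
# transform from a 'group:'-prefixed buffer id, as done above with
# pcoll_id.split(':', 1)[1]. maxsplit=1 keeps any colons inside the
# transform name intact. original_transform is a hypothetical helper.
def original_transform(pcoll_id):
  prefix, transform_name = pcoll_id.split(':', 1)
  assert prefix == 'group', pcoll_id
  return transform_name

assert original_transform('group:MyStage/GroupByKey') == 'MyStage/GroupByKey'
assert original_transform('group:ns:GBK') == 'ns:GBK'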