def extract_bundle_inputs_and_outputs(self):
  # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]

  """Returns maps of transform names to PCollection identifiers.

  Also mutates IO stages to point to the data ApiServiceDescriptor.

  Returns:
    A tuple of (data_input, data_output, expected_timer_output) dictionaries.
      `data_input` is a dictionary mapping a transform's unique name to a
        PCollection buffer holding its input;
      `data_output` is a dictionary mapping a transform's unique name to the
        PCollection ID of its output;
      `expected_timer_output` is a dictionary mapping a
        (transform_id, timer_family_id) pair to a buffer id for timers.
  """
  data_input = {}  # type: Dict[str, PartitionableBuffer]
  data_output = {}  # type: DataOutput
  # A mapping of {(transform_id, timer_family_id): buffer_id}
  expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
  for transform in self.stage.transforms:
    if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                              bundle_processor.DATA_OUTPUT_URN):
      pcoll_id = transform.spec.payload
      if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
        coder_id = self.execution_context.data_channel_coders[only_element(
            transform.outputs.values())]
        coder = self.execution_context.pipeline_context.coders[
            self.execution_context.safe_coders.get(coder_id, coder_id)]
        if pcoll_id == translations.IMPULSE_BUFFER:
          data_input[transform.unique_name] = ListBuffer(
              coder_impl=coder.get_impl())
          data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
        else:
          if pcoll_id not in self.execution_context.pcoll_buffers:
            self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                coder_impl=coder.get_impl())
          data_input[transform.unique_name] = (
              self.execution_context.pcoll_buffers[pcoll_id])
      elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
        data_output[transform.unique_name] = pcoll_id
        coder_id = self.execution_context.data_channel_coders[only_element(
            transform.inputs.values())]
      else:
        raise NotImplementedError
      data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
      data_api_service_descriptor = self.data_api_service_descriptor()
      if data_api_service_descriptor:
        data_spec.api_service_descriptor.url = (
            data_api_service_descriptor.url)
      transform.spec.payload = data_spec.SerializeToString()
    elif transform.spec.urn in translations.PAR_DO_URNS:
      payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
      for timer_family_id in payload.timer_family_specs.keys():
        expected_timer_output[(transform.unique_name, timer_family_id)] = (
            create_buffer_id(timer_family_id, 'timers'))
  return data_input, data_output, expected_timer_output
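
# The buffer ids produced above by `create_buffer_id` are the same ids later
# decomposed by `split_buffer_id` in get_buffer below. As a rough,
# illustrative sketch only (the real helpers live elsewhere in the runner;
# the `_sketch_` names here are made up), they are assumed to round-trip a
# `kind:name` byte string:
def _sketch_create_buffer_id(name, kind='materialize'):
  # type: (str, str) -> bytes
  # e.g. _sketch_create_buffer_id('my_timers', 'timers') -> b'timers:my_timers'
  return ('%s:%s' % (kind, name)).encode('utf-8')


def _sketch_split_buffer_id(buffer_id):
  # type: (bytes) -> Tuple[str, str]
  # Split only on the first ':' so names containing ':' survive intact.
  kind, name = buffer_id.decode('utf-8').split(':', 1)
  return kind, name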
def get_buffer(self, buffer_id, transform_id):
  # type: (bytes, str) -> PartitionableBuffer

  """Returns the buffer for a given buffer id, which encodes (kind, name).

  For grouping-typed operations, we produce a ``GroupingBuffer``. For
  others, we produce a ``ListBuffer``.
  """
  kind, name = split_buffer_id(buffer_id)
  if kind == 'materialize':
    if buffer_id not in self.execution_context.pcoll_buffers:
      self.execution_context.pcoll_buffers[buffer_id] = ListBuffer(
          coder_impl=self.get_input_coder_impl(transform_id))
    return self.execution_context.pcoll_buffers[buffer_id]
  # For timer buffers, name is the timer family id.
  elif kind == 'timers':
    if buffer_id not in self.execution_context.timer_buffers:
      timer_coder_impl = self.get_timer_coder_impl(transform_id, name)
      self.execution_context.timer_buffers[buffer_id] = ListBuffer(
          timer_coder_impl)
    return self.execution_context.timer_buffers[buffer_id]
  elif kind == 'group':
    # This is a grouping write; create a grouping buffer if needed.
    if buffer_id not in self.execution_context.pcoll_buffers:
      original_gbk_transform = name
      transform_proto = self.execution_context.pipeline_components.transforms[
          original_gbk_transform]
      input_pcoll = only_element(list(transform_proto.inputs.values()))
      output_pcoll = only_element(list(transform_proto.outputs.values()))
      pre_gbk_coder = self.execution_context.pipeline_context.coders[
          self.execution_context.safe_coders[
              self.execution_context.data_channel_coders[input_pcoll]]]
      post_gbk_coder = self.execution_context.pipeline_context.coders[
          self.execution_context.safe_coders[
              self.execution_context.data_channel_coders[output_pcoll]]]
      windowing_strategy = (
          self.execution_context.pipeline_context.windowing_strategies[
              self.execution_context.safe_windowing_strategies[
                  self.execution_context.pipeline_components.
                  pcollections[output_pcoll].windowing_strategy_id]])
      self.execution_context.pcoll_buffers[buffer_id] = GroupingBuffer(
          pre_gbk_coder, post_gbk_coder, windowing_strategy)
  else:
    # These should be the only kinds of buffer ids we produce for now,
    # but special side input writes may go here.
    raise NotImplementedError(buffer_id)
  return self.execution_context.pcoll_buffers[buffer_id]
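
# For orientation, a minimal sketch of the buffer contract that get_buffer
# relies on (illustrative only; the real ListBuffer and GroupingBuffer also
# track coder impls, windowing, and a clear/reset lifecycle): a buffer
# accepts encoded pages via append(), iterates over them, and can be
# partitioned into n pieces for parallel re-execution.
class _SketchBuffer(object):
  def __init__(self):
    self._pages = []  # type: List[bytes]

  def append(self, page):
    # type: (bytes) -> None
    self._pages.append(page)

  def __iter__(self):
    return iter(self._pages)

  def partition(self, n):
    # type: (int) -> List[List[bytes]]
    # Deal pages round-robin into n roughly equal partitions.
    return [self._pages[i::n] for i in range(n)]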
def _generate_splits_for_testing(self,
                                 split_manager,
                                 inputs,  # type: Mapping[str, PartitionableBuffer]
                                 process_bundle_id):
  # type: (...) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse]
  split_results = []  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
  read_transform_id, buffer_data = only_element(inputs.items())
  byte_stream = b''.join(buffer_data)
  num_elements = len(
      list(
          self.bundle_context_manager.get_input_coder_impl(
              read_transform_id).decode_all(byte_stream)))

  # Start the split manager in case it wants to set any breakpoints.
  split_manager_generator = split_manager(num_elements)
  try:
    split_fraction = next(split_manager_generator)
    done = False
  except StopIteration:
    done = True

  # Send all the data.
  self._send_input_to_worker(
      process_bundle_id, read_transform_id, [byte_stream])

  assert self._worker_handler is not None

  # Execute the requested splits.
  while not done:
    if split_fraction is None:
      split_result = None
    else:
      split_request = beam_fn_api_pb2.InstructionRequest(
          process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest(
              instruction_id=process_bundle_id,
              desired_splits={
                  read_transform_id: beam_fn_api_pb2.
                  ProcessBundleSplitRequest.DesiredSplit(
                      fraction_of_remainder=split_fraction,
                      estimated_input_elements=num_elements)
              }))
      split_response = self._worker_handler.control_conn.push(
          split_request).get()  # type: beam_fn_api_pb2.InstructionResponse
      for t in (0.05, 0.1, 0.2):
        waiting = ('Instruction not running', 'not yet scheduled')
        if any(msg in split_response.error for msg in waiting):
          time.sleep(t)
          split_response = self._worker_handler.control_conn.push(
              split_request).get()
      if 'Unknown process bundle' in split_response.error:
        # It may have finished too fast.
        split_result = None
      elif split_response.error:
        raise RuntimeError(split_response.error)
      else:
        split_result = split_response.process_bundle_split
        split_results.append(split_result)
    try:
      split_fraction = split_manager_generator.send(split_result)
    except StopIteration:
      break
  return split_results
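
# For reference, a split_manager compatible with the protocol above is a
# generator: it is called with the input element count, yields a
# fraction_of_remainder (or None to skip a round), and is sent back the
# resulting ProcessBundleSplitResponse (or None) after each attempt. A
# hypothetical example, not part of the runner:
def _example_split_manager(num_elements):
  # Ask the SDK to hand back half of the remaining work once.
  first_result = yield 0.5
  # A real manager could inspect `first_result` and keep yielding further
  # fractions; returning (letting the generator finish) ends the session.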
def _extract_endpoints(bundle_context_manager,  # type: execution.BundleContextManager
                       runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                       ):
  # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput, Dict[Tuple[str, str], str]]

  """Returns maps of transform names to PCollection identifiers.

  Also mutates IO stages to point to the data ApiServiceDescriptor.

  Args:
    bundle_context_manager (execution.BundleContextManager): The context
      manager for the bundle whose stage is being processed; provides the
      stage and the data plane's GRPC ApiServiceDescriptor.
    runner_execution_context (execution.FnApiRunnerExecutionContext): The
      execution context carrying pipeline-wide coders, buffers, and
      components.

  Returns:
    A tuple of (data_input, data_side_input, data_output,
    expected_timer_output) dictionaries.
      `data_input` is a dictionary mapping a transform's unique name to a
        PCollection buffer holding its input;
      `data_side_input` is a dictionary mapping (transform_name, tag) to a
        (buffer id, access pattern) pair;
      `data_output` is a dictionary mapping a transform's unique name to the
        PCollection ID of its output;
      `expected_timer_output` is a dictionary mapping a
        (transform_id, timer_family_id) pair to a buffer id for timers.
  """
  data_input = {}  # type: Dict[str, PartitionableBuffer]
  data_side_input = {}  # type: DataSideInput
  data_output = {}  # type: DataOutput
  # A mapping of {(transform_id, timer_family_id): buffer_id}
  expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
  for transform in bundle_context_manager.stage.transforms:
    if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                              bundle_processor.DATA_OUTPUT_URN):
      pcoll_id = transform.spec.payload
      if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
        coder_id = runner_execution_context.data_channel_coders[only_element(
            transform.outputs.values())]
        coder = runner_execution_context.pipeline_context.coders[
            runner_execution_context.safe_coders.get(coder_id, coder_id)]
        if pcoll_id == translations.IMPULSE_BUFFER:
          data_input[transform.unique_name] = ListBuffer(
              coder_impl=coder.get_impl())
          data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
        else:
          if pcoll_id not in runner_execution_context.pcoll_buffers:
            runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                coder_impl=coder.get_impl())
          data_input[transform.unique_name] = (
              runner_execution_context.pcoll_buffers[pcoll_id])
      elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
        data_output[transform.unique_name] = pcoll_id
        coder_id = runner_execution_context.data_channel_coders[only_element(
            transform.inputs.values())]
      else:
        raise NotImplementedError
      data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
      data_api_service_descriptor = (
          bundle_context_manager.data_api_service_descriptor())
      if data_api_service_descriptor:
        data_spec.api_service_descriptor.url = (
            data_api_service_descriptor.url)
      transform.spec.payload = data_spec.SerializeToString()
    elif transform.spec.urn in translations.PAR_DO_URNS:
      payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
      for tag, si in payload.side_inputs.items():
        data_side_input[transform.unique_name, tag] = (
            create_buffer_id(transform.inputs[tag]), si.access_pattern)
      for timer_family_id in payload.timer_family_specs.keys():
        expected_timer_output[(transform.unique_name, timer_family_id)] = (
            create_buffer_id(timer_family_id, 'timers'))
  return data_input, data_side_input, data_output, expected_timer_output
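
# Concretely, the returned maps have shapes along these lines (all transform
# names, tags, and PCollection ids below are hypothetical, for illustration
# only):
#   data_input:            {'stage1/Read': <ListBuffer>}
#   data_side_input:       {('stage1/ParDo', 'side0'):
#                             (b'materialize:pcoll_1', <access pattern>)}
#   data_output:           {'stage1/Write': b'materialize:pcoll_2'}
#   expected_timer_output: {('stage1/ParDo', 'my_timers'): b'timers:my_timers'}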