def _collect_written_timers_and_add_to_fired_timers(
    self,
    bundle_context_manager,  # type: execution.BundleContextManager
    fired_timers  # type: Dict[Tuple[str, str], ListBuffer]
):
    # type: (...) -> None
    """Moves the timers written by a bundle into ``fired_timers``.

    For each (transform_id, timer_family_id) pair listed in the stage's
    timers, the matching written-timer buffer is decoded and reduced so that
    only the last timer decoded for a given (user key, first window) pair is
    kept. Kept timers whose clear bit is unset are re-encoded into one chunk,
    which is stored in a fresh ``ListBuffer`` under
    ``fired_timers[(transform_id, timer_family_id)]``. The written-timer
    buffer is cleared afterwards. Buffers already flagged as cleared are
    left alone.

    Args:
      bundle_context_manager: Provides the stage, the written-timer buffers
        and the timer coder implementations.
      fired_timers: Mapping that is updated in place.
    """
    for transform_id, timer_family_id in bundle_context_manager.stage.timers:
        buffer_id = create_buffer_id(timer_family_id, kind='timers')
        pending = bundle_context_manager.get_buffer(buffer_id, transform_id)
        coder_impl = bundle_context_manager.get_timer_coder_impl(
            transform_id, timer_family_id)
        if pending.cleared:
            continue

        # A later timer for the same (key, window) replaces the earlier one.
        latest = {}
        for chunk in pending:
            for timer in coder_impl.decode_all(chunk):
                latest[timer.user_key, timer.windows[0]] = timer

        stream = create_OutputStream()
        for timer in latest.values():
            # Only add not cleared timer to fired timers.
            if timer.clear_bit:
                continue
            coder_impl.encode_to_stream(timer, stream, True)

        fired = ListBuffer(coder_impl=coder_impl)
        fired_timers[(transform_id, timer_family_id)] = fired
        fired.append(stream.get())
        pending.clear()
def _build_data_side_inputs_map(stages):
    # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]
    """Builds an index mapping stages to side input descriptors.

    A side input descriptor is a map of side input IDs to side input access
    patterns for all of the outputs of a stage that will be consumed as a
    side input.

    Args:
      stages: The stages to index. They are traversed exactly once, so a
        one-shot iterable such as a generator is accepted.

    Returns:
      A dict with one entry per stage name. Each value maps
      ``(consuming transform unique_name, side input tag)`` to a
      ``(buffer id, access pattern)`` pair, for the side inputs produced by
      that stage.

    Raises:
      KeyError: If a side input read by a transform whose URN is in
        ``translations.PAR_DO_URNS`` has no producing stage among ``stages``.
        A stage is not recorded as the producer of a pcollection it lists
        among its own side inputs.
    """
    transform_consumers = collections.defaultdict(
        list
    )  # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]]
    all_side_inputs = set()  # type: Set[str]
    producing_stages_by_pcoll = {}
    data_side_inputs_by_producing_stage = {
    }  # type: Dict[str, DataSideInput]

    # Single pass over the stages: collect consumers, side inputs and
    # producers together so `stages` is never iterated a second time.
    for stage in stages:
        data_side_inputs_by_producing_stage[stage.name] = {}
        # Evaluate the stage's side inputs once instead of once per output,
        # and materialize them because they are tested for membership
        # repeatedly.
        stage_side_inputs = frozenset(stage.side_inputs())
        all_side_inputs.update(stage_side_inputs)
        for transform in stage.transforms:
            for pcoll in transform.inputs.values():
                transform_consumers[pcoll].append(transform)
            for pcoll in transform.outputs.values():
                if pcoll in stage_side_inputs:
                    continue
                # The first stage seen producing a pcollection wins.
                producing_stages_by_pcoll.setdefault(pcoll, stage)

    for side_pc in all_side_inputs:
        for consuming_transform in transform_consumers[side_pc]:
            if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
                continue
            producing_stage = producing_stages_by_pcoll[side_pc]
            payload = proto_utils.parse_Bytes(
                consuming_transform.spec.payload,
                beam_runner_api_pb2.ParDoPayload)
            descriptors = data_side_inputs_by_producing_stage[
                producing_stage.name]
            for si_tag, side_input in payload.side_inputs.items():
                if consuming_transform.inputs[si_tag] != side_pc:
                    continue
                side_input_id = (consuming_transform.unique_name, si_tag)
                descriptors[side_input_id] = (
                    translations.create_buffer_id(side_pc),
                    side_input.access_pattern)

    return data_side_inputs_by_producing_stage
def extract_bundle_inputs_and_outputs(self):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]
    """Collects the data inputs, data outputs and timer outputs of the stage.

    Also mutates IO transforms to point to the data ApiServiceDescriptor:
    the payload of every data input/output transform is replaced by a
    serialized ``RemoteGrpcPort``.

    Returns:
      A tuple of (data_input, data_output, expected_timer_output)
      dictionaries. `data_input` maps the unique name of each data input
      transform to a PCollection buffer; `data_output` maps the unique name
      of each data output transform to the PCollection ID read from its
      payload. `expected_timer_output` maps (transform unique name, timer
      family ID) to a buffer id for timers.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]

    for transform in self.stage.transforms:
        urn = transform.spec.urn
        name = transform.unique_name

        if urn not in (bundle_processor.DATA_INPUT_URN,
                       bundle_processor.DATA_OUTPUT_URN):
            if urn in translations.PAR_DO_URNS:
                payload = proto_utils.parse_Bytes(
                    transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
                for timer_family_id in payload.timer_family_specs.keys():
                    timer_key = (name, timer_family_id)
                    expected_timer_output[timer_key] = create_buffer_id(
                        timer_family_id, 'timers')
            continue

        # From here on the transform is a data input or a data output; its
        # payload initially carries the PCollection id.
        pcoll_id = transform.spec.payload
        if urn == bundle_processor.DATA_INPUT_URN:
            context = self.execution_context
            produced = only_element(transform.outputs.values())
            coder_id = context.data_channel_coders[produced]
            coder = context.pipeline_context.coders[
                context.safe_coders.get(coder_id, coder_id)]
            if pcoll_id == translations.IMPULSE_BUFFER:
                impulse = ListBuffer(coder_impl=coder.get_impl())
                data_input[name] = impulse
                impulse.append(ENCODED_IMPULSE_VALUE)
            else:
                if pcoll_id not in context.pcoll_buffers:
                    context.pcoll_buffers[pcoll_id] = ListBuffer(
                        coder_impl=coder.get_impl())
                data_input[name] = context.pcoll_buffers[pcoll_id]
        elif urn == bundle_processor.DATA_OUTPUT_URN:
            data_output[name] = pcoll_id
            consumed = only_element(transform.inputs.values())
            coder_id = self.execution_context.data_channel_coders[consumed]
        else:
            raise NotImplementedError

        port = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        descriptor = self.data_api_service_descriptor()
        if descriptor:
            port.api_service_descriptor.url = descriptor.url
        transform.spec.payload = port.SerializeToString()

    return data_input, data_output, expected_timer_output
def _extract_endpoints(
    bundle_context_manager,  # type: execution.BundleContextManager
    runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput, Dict[Tuple[str, str], str]]
    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO transforms to point to the data ApiServiceDescriptor:
    the payload of every data input/output transform is replaced by a
    serialized ``RemoteGrpcPort``.

    Args:
      bundle_context_manager: Provides the stage whose transforms are
        inspected and the data API service descriptor.
      runner_execution_context: Provides the data channel coders, safe
        coders, pipeline coders and the shared PCollection buffers.

    Returns:
      A tuple of (data_input, data_side_input, data_output,
      expected_timer_output) dictionaries. `data_input` maps the unique name
      of each data input transform to a PCollection buffer;
      `data_side_input` maps (transform unique name, side input tag) to a
      (buffer id, access pattern) pair; `data_output` maps the unique name
      of each data output transform to the PCollection ID read from its
      payload; `expected_timer_output` maps (transform unique name, timer
      family ID) to a buffer id for timers.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_side_input = {}  # type: DataSideInput
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in bundle_context_manager.stage.transforms:
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
            # IO transforms initially carry the PCollection id as payload.
            pcoll_id = transform.spec.payload
            if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                coder_id = runner_execution_context.data_channel_coders[
                    only_element(transform.outputs.values())]
                coder = runner_execution_context.pipeline_context.coders[
                    runner_execution_context.safe_coders.get(
                        coder_id, coder_id)]
                if pcoll_id == translations.IMPULSE_BUFFER:
                    data_input[transform.unique_name] = ListBuffer(
                        coder_impl=coder.get_impl())
                    data_input[transform.unique_name].append(
                        ENCODED_IMPULSE_VALUE)
                else:
                    if pcoll_id not in runner_execution_context.pcoll_buffers:
                        runner_execution_context.pcoll_buffers[pcoll_id] = (
                            ListBuffer(coder_impl=coder.get_impl()))
                    data_input[transform.unique_name] = (
                        runner_execution_context.pcoll_buffers[pcoll_id])
            elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                data_output[transform.unique_name] = pcoll_id
                coder_id = runner_execution_context.data_channel_coders[
                    only_element(transform.inputs.values())]
            else:
                raise NotImplementedError
            data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
            data_api_service_descriptor = (
                bundle_context_manager.data_api_service_descriptor())
            if data_api_service_descriptor:
                data_spec.api_service_descriptor.url = (
                    data_api_service_descriptor.url)
            transform.spec.payload = data_spec.SerializeToString()
        elif transform.spec.urn in translations.PAR_DO_URNS:
            payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
            for tag, si in payload.side_inputs.items():
                data_side_input[transform.unique_name, tag] = (
                    create_buffer_id(transform.inputs[tag]),
                    si.access_pattern)
            for timer_family_id in payload.timer_family_specs:
                expected_timer_output[(
                    transform.unique_name, timer_family_id)] = (
                        create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_side_input, data_output, expected_timer_output