def get_buffer(self, buffer_id, transform_id):
    # type: (bytes, str) -> PartitionableBuffer
    """Return the buffer registered for ``buffer_id``, creating it on first use.

    The buffer id is split into a kind and a name, and the kind selects
    both the buffer type and the cache it lives in:

    * ``'materialize'``: a ``ListBuffer`` built with the input coder impl of
      ``transform_id``, cached in ``execution_context.pcoll_buffers``.
    * ``'timers'``: a ``ListBuffer`` built with the timer coder impl for
      ``transform_id`` and the timer family named in the buffer id, cached
      in ``execution_context.timer_buffers``.
    * ``'group'``: a ``GroupingBuffer`` for the transform named in the
      buffer id, cached in ``execution_context.pcoll_buffers``.

    Raises:
      NotImplementedError: if the buffer id has any other kind.
    """
    kind, name = split_buffer_id(buffer_id)
    if kind not in ('materialize', 'timers', 'group'):
        # Only the three kinds above are produced for now; special side
        # input writes may need handling here in the future.
        raise NotImplementedError(buffer_id)

    context = self.execution_context

    if kind == 'timers':
        # For timer buffers, ``name`` is the timer family id.
        timer_buffers = context.timer_buffers
        if buffer_id not in timer_buffers:
            timer_coder_impl = self.get_timer_coder_impl(transform_id, name)
            timer_buffers[buffer_id] = ListBuffer(timer_coder_impl)
        return timer_buffers[buffer_id]

    pcoll_buffers = context.pcoll_buffers
    if buffer_id in pcoll_buffers:
        return pcoll_buffers[buffer_id]

    if kind == 'materialize':
        pcoll_buffers[buffer_id] = ListBuffer(
            coder_impl=self.get_input_coder_impl(transform_id))
        return pcoll_buffers[buffer_id]

    # Remaining kind is 'group': a grouping write whose buffer does not
    # exist yet. ``name`` is the id of the original GroupByKey transform.
    def safe_coder_for(pcoll_id):
        # Resolve a PCollection's data-channel coder id to its safe coder.
        channel_coder_id = context.data_channel_coders[pcoll_id]
        return context.pipeline_context.coders[
            context.safe_coders[channel_coder_id]]

    gbk_proto = context.pipeline_components.transforms[name]
    gbk_input = only_element(list(gbk_proto.inputs.values()))
    gbk_output = only_element(list(gbk_proto.outputs.values()))
    pre_gbk_coder = safe_coder_for(gbk_input)
    post_gbk_coder = safe_coder_for(gbk_output)

    # The windowing strategy is taken from the GBK's output PCollection.
    output_strategy_id = (
        context.pipeline_components.pcollections[gbk_output]
        .windowing_strategy_id)
    windowing_strategy = context.pipeline_context.windowing_strategies[
        context.safe_windowing_strategies[output_strategy_id]]

    pcoll_buffers[buffer_id] = GroupingBuffer(
        pre_gbk_coder, post_gbk_coder, windowing_strategy)
    return pcoll_buffers[buffer_id]
def commit_side_inputs_to_state(
    self,
    data_side_input,  # type: DataSideInput
):
    # type: (...) -> None
    """Write the contents of each side-input buffer to the state servicer.

    ``data_side_input`` maps ``(consuming transform id, side input tag)`` to
    ``(buffer id, access-pattern spec)``. For every entry, the elements in
    the corresponding PCollection buffer (created empty if it does not exist
    yet) are regrouped with a ``WindowGroupingBuffer`` and each encoded
    group is appended to state under an iterable or multimap side-input
    state key, depending on the spec's URN.

    Raises:
      ValueError: if a spec's URN is neither the iterable nor the multimap
        side-input URN.
    """
    for (consumer_id, side_input_tag), (buffer_id, access_spec) in (
            data_side_input.items()):
        _, pcoll_id = split_buffer_id(buffer_id)
        safe_coder_id = self.safe_coders[self.data_channel_coders[pcoll_id]]
        value_coder = self.pipeline_context.coders[safe_coder_id]
        grouped = WindowGroupingBuffer(access_spec, value_coder)

        # Make sure a (possibly empty) buffer exists before reading it.
        if buffer_id not in self.pcoll_buffers:
            self.pcoll_buffers[buffer_id] = ListBuffer(
                coder_impl=value_coder.get_impl())
        for encoded_element in self.pcoll_buffers[buffer_id]:
            grouped.append(encoded_element)

        # Choose how state keys are built for this access pattern; the
        # iterable pattern ignores the key component of each group.
        urn = access_spec.urn
        if urn == common_urns.side_inputs.ITERABLE.urn:

            def state_key_for(key, window):
                return beam_fn_api_pb2.StateKey(
                    iterable_side_input=beam_fn_api_pb2.StateKey.
                    IterableSideInput(
                        transform_id=consumer_id,
                        side_input_id=side_input_tag,
                        window=window))
        elif urn == common_urns.side_inputs.MULTIMAP.urn:

            def state_key_for(key, window):
                return beam_fn_api_pb2.StateKey(
                    multimap_side_input=beam_fn_api_pb2.StateKey.
                    MultimapSideInput(
                        transform_id=consumer_id,
                        side_input_id=side_input_tag,
                        window=window,
                        key=key))
        else:
            raise ValueError("Unknown access pattern: '%s'" % urn)

        for key, window, elements_data in grouped.encoded_items():
            self.state_servicer.append_raw(
                state_key_for(key, window), elements_data)
def _store_side_inputs_in_state(
    self,
    runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
    data_side_input,  # type: DataSideInput
):
    # type: (...) -> None
    """Append every side input's buffered elements to the state servicer.

    For each ``(transform id, tag) -> (buffer id, side-input spec)`` entry
    of ``data_side_input``, the matching PCollection buffer on
    ``runner_execution_context`` is read (and created empty if absent), its
    elements are regrouped through a ``WindowGroupingBuffer``, and every
    encoded group is appended via
    ``runner_execution_context.worker_handler_manager.state_servicer``
    under an iterable or multimap side-input state key.

    Raises:
      ValueError: if the spec's URN is neither the iterable nor the multimap
        side-input URN.
    """
    ctx = runner_execution_context
    iterable_urn = common_urns.side_inputs.ITERABLE.urn

    for side_input_key, side_input_source in data_side_input.items():
        transform_id, tag = side_input_key
        buffer_id, si = side_input_source

        _, pcoll_id = split_buffer_id(buffer_id)
        channel_coder_id = ctx.data_channel_coders[pcoll_id]
        value_coder = ctx.pipeline_context.coders[
            ctx.safe_coders[channel_coder_id]]
        elements_by_window = WindowGroupingBuffer(si, value_coder)

        # Read from an empty buffer when nothing was written for this id.
        if not (buffer_id in ctx.pcoll_buffers):
            ctx.pcoll_buffers[buffer_id] = ListBuffer(
                coder_impl=value_coder.get_impl())
        for element_data in ctx.pcoll_buffers[buffer_id]:
            elements_by_window.append(element_data)

        if si.urn == iterable_urn:
            # Iterable side inputs are keyed by window only.
            for _, window, encoded in elements_by_window.encoded_items():
                iterable_key = beam_fn_api_pb2.StateKey.IterableSideInput(
                    transform_id=transform_id,
                    side_input_id=tag,
                    window=window)
                state_key = beam_fn_api_pb2.StateKey(
                    iterable_side_input=iterable_key)
                ctx.worker_handler_manager.state_servicer.append_raw(
                    state_key, encoded)
        elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
            # Multimap side inputs are keyed by both window and key.
            for key, window, encoded in elements_by_window.encoded_items():
                multimap_key = beam_fn_api_pb2.StateKey.MultimapSideInput(
                    transform_id=transform_id,
                    side_input_id=tag,
                    window=window,
                    key=key)
                state_key = beam_fn_api_pb2.StateKey(
                    multimap_side_input=multimap_key)
                ctx.worker_handler_manager.state_servicer.append_raw(
                    state_key, encoded)
        else:
            raise ValueError("Unknown access pattern: '%s'" % si.urn)
def __init__(self, stages):
    # type: (List[translations.Stage]) -> None
    """Build the stage / PCollection node graph for ``stages``.

    One ``StageNode`` is created per stage and stored in
    ``_stages_by_name``. ``PCollectionNode`` objects are stored in
    ``_pcollections_by_name``, keyed by PCollection name, or by a
    ``(transform unique name, timer family id)`` tuple for timers. Each
    stage is linked to its data inputs, timers, data outputs and side
    inputs, in that order, and ``self._verify(stages)`` is called once all
    stages have been processed.
    """
    self._pcollections_by_name: Dict[
        Union[str, translations.TimerFamilyId],
        WatermarkManager.PCollectionNode] = {}
    self._stages_by_name: Dict[str, WatermarkManager.StageNode] = {}

    nodes = self._pcollections_by_name

    def node_for(name):
        # Return the node registered under ``name``, creating it if needed.
        if name not in nodes:
            nodes[name] = WatermarkManager.PCollectionNode(name)
        return nodes[name]

    def register_input(
        pcname: str, snode: WatermarkManager.StageNode
    ) -> WatermarkManager.PCollectionNode:
        # Record ``pcname`` as an input of ``snode`` and count the consumer.
        pcnode = node_for(pcname)
        pcnode.consumers += 1
        assert isinstance(pcnode, WatermarkManager.PCollectionNode)
        snode.inputs.add(pcnode)
        return pcnode

    for stage in stages:
        stage_name = stage.name
        stage_node = WatermarkManager.StageNode(stage_name)
        self._stages_by_name[stage_name] = stage_node

        # 1. Stage inputs: create nodes for them in _pcollections_by_name
        #    and add them as inputs of the stage node.
        for transform in stage.transforms:
            if transform.spec.urn != bundle_processor.DATA_INPUT_URN:
                continue
            buffer_id = transform.spec.payload
            if buffer_id == translations.IMPULSE_BUFFER:
                # Impulse reads are registered under the transform's name.
                input_name = transform.unique_name
            else:
                _, input_name = split_buffer_id(buffer_id)
            register_input(input_name, stage_node)

        # 2. Stage timers: each timer family becomes an input of the stage.
        for transform in stage.transforms:
            if transform.spec.urn not in translations.PAR_DO_URNS:
                continue
            pardo_payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
            for timer_family_id in pardo_payload.timer_family_specs.keys():
                timer_key = (transform.unique_name, timer_family_id)
                # Timer nodes are always created fresh, replacing any node
                # already stored under the same key.
                timer_node = WatermarkManager.PCollectionNode(timer_key)
                nodes[timer_key] = timer_node
                assert isinstance(
                    timer_node, WatermarkManager.PCollectionNode)
                stage_node.inputs.add(timer_node)
                timer_node.consumers += 1

        # 3. Stage outputs: create nodes for them in _pcollections_by_name
        #    and record this stage as their producer.
        for transform in stage.transforms:
            if transform.spec.urn != bundle_processor.DATA_OUTPUT_URN:
                continue
            _, output_name = split_buffer_id(transform.spec.payload)
            output_node = node_for(output_name)
            assert isinstance(output_node, WatermarkManager.PCollectionNode)
            output_node.producers.add(stage_node)
            stage_node.outputs.add(output_node)

        # 4. Stage side inputs: create nodes for them in
        #    _pcollections_by_name and attach them to the stage.
        for side_input_name in stage.side_inputs():
            side_input_node = node_for(side_input_name)
            assert isinstance(
                side_input_node, WatermarkManager.PCollectionNode)
            stage_node.side_inputs.add(side_input_node)

    self._verify(stages)