Ejemplo n.º 1
0
 def get_buffer(self, buffer_id, transform_id):
     # type: (bytes, str) -> PartitionableBuffer
     """Return the buffer for ``buffer_id``, creating and caching it on first use.

     The buffer kind is the first component returned by
     ``split_buffer_id(buffer_id)``:

     * ``'materialize'`` -- a ``ListBuffer`` stored in
       ``execution_context.pcoll_buffers``, built with the coder impl from
       ``self.get_input_coder_impl(transform_id)``.
     * ``'timers'`` -- a ``ListBuffer`` stored in
       ``execution_context.timer_buffers``; here the second component of the
       id is the timer family id, and the coder impl comes from
       ``self.get_timer_coder_impl(transform_id, name)``.
     * ``'group'`` -- a ``GroupingBuffer`` stored in
       ``execution_context.pcoll_buffers``; here the second component is the
       id of the original GroupByKey transform.

     Repeated calls with the same ``buffer_id`` return the cached object.

     Args:
       buffer_id: encoded buffer identifier.
       transform_id: id of the transform the buffer is requested for.

     Returns:
       The (possibly newly created) buffer.

     Raises:
       NotImplementedError: if the buffer kind is none of the above.
     """
     ctx = self.execution_context
     kind, name = split_buffer_id(buffer_id)

     if kind == 'materialize':
         if buffer_id not in ctx.pcoll_buffers:
             ctx.pcoll_buffers[buffer_id] = ListBuffer(
                 coder_impl=self.get_input_coder_impl(transform_id))
         return ctx.pcoll_buffers[buffer_id]

     if kind == 'timers':
         # For timer buffers, ``name`` is the timer family id.
         if buffer_id not in ctx.timer_buffers:
             ctx.timer_buffers[buffer_id] = ListBuffer(
                 self.get_timer_coder_impl(transform_id, name))
         return ctx.timer_buffers[buffer_id]

     if kind == 'group':
         # Grouping write: ``name`` is the original GroupByKey transform id.
         # The GroupingBuffer is configured from the coders registered for
         # that transform's single input and single output PCollection and
         # from the output PCollection's windowing strategy.
         if buffer_id not in ctx.pcoll_buffers:
             gbk_proto = ctx.pipeline_components.transforms[name]
             input_pcoll = only_element(list(gbk_proto.inputs.values()))
             output_pcoll = only_element(list(gbk_proto.outputs.values()))
             coders = ctx.pipeline_context.coders
             pre_gbk_coder = coders[ctx.safe_coders[
                 ctx.data_channel_coders[input_pcoll]]]
             post_gbk_coder = coders[ctx.safe_coders[
                 ctx.data_channel_coders[output_pcoll]]]
             windowing_strategy = ctx.pipeline_context.windowing_strategies[
                 ctx.safe_windowing_strategies[
                     ctx.pipeline_components.pcollections[output_pcoll].
                     windowing_strategy_id]]
             ctx.pcoll_buffers[buffer_id] = GroupingBuffer(
                 pre_gbk_coder, post_gbk_coder, windowing_strategy)
         return ctx.pcoll_buffers[buffer_id]

     # Only 'materialize', 'timers' and 'group' buffers are produced today;
     # special side-input writes, if ever added, would need handling here.
     raise NotImplementedError(buffer_id)
Ejemplo n.º 2
0
    def commit_side_inputs_to_state(
            self,
            data_side_input,  # type: DataSideInput
    ):
        # type: (...) -> None
        """Load the buffered contents of every data side input into state.

        For each ``(consuming_transform_id, tag) -> (buffer_id, func_spec)``
        entry, the elements held in ``self.pcoll_buffers[buffer_id]`` are
        grouped by window with a ``WindowGroupingBuffer``. If no buffer exists
        yet, an empty ``ListBuffer`` is registered first. The grouped data is
        then appended to ``self.state_servicer`` under an iterable or a
        multimap side-input ``StateKey``, chosen by ``func_spec.urn``.

        Args:
          data_side_input: maps ``(consuming_transform_id, tag)`` to
            ``(buffer_id, func_spec)``.

        Raises:
          ValueError: if ``func_spec.urn`` is neither the ITERABLE nor the
            MULTIMAP side-input URN.
        """
        for consumer_and_tag, buffer_and_spec in data_side_input.items():
            consuming_transform_id, tag = consumer_and_tag
            buffer_id, func_spec = buffer_and_spec

            _, pcoll_id = split_buffer_id(buffer_id)
            value_coder = self.pipeline_context.coders[self.safe_coders[
                self.data_channel_coders[pcoll_id]]]
            by_window = WindowGroupingBuffer(func_spec, value_coder)

            if buffer_id not in self.pcoll_buffers:
                # Register an empty buffer so there is always something to read.
                self.pcoll_buffers[buffer_id] = ListBuffer(
                    coder_impl=value_coder.get_impl())
            for element_data in self.pcoll_buffers[buffer_id]:
                by_window.append(element_data)

            # Decide the state-key shape before emitting anything.
            access_urn = func_spec.urn
            if access_urn == common_urns.side_inputs.ITERABLE.urn:
                keyed = False
            elif access_urn == common_urns.side_inputs.MULTIMAP.urn:
                keyed = True
            else:
                raise ValueError("Unknown access pattern: '%s'" % access_urn)

            for key, window, elements_data in by_window.encoded_items():
                if keyed:
                    state_key = beam_fn_api_pb2.StateKey(
                        multimap_side_input=beam_fn_api_pb2.StateKey.
                        MultimapSideInput(transform_id=consuming_transform_id,
                                          side_input_id=tag,
                                          window=window,
                                          key=key))
                else:
                    # Iterable side inputs are not keyed; ``key`` is ignored.
                    state_key = beam_fn_api_pb2.StateKey(
                        iterable_side_input=beam_fn_api_pb2.StateKey.
                        IterableSideInput(transform_id=consuming_transform_id,
                                          side_input_id=tag,
                                          window=window))
                self.state_servicer.append_raw(state_key, elements_data)
Ejemplo n.º 3
0
  def _store_side_inputs_in_state(self,
                                  runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                                  data_side_input,  # type: DataSideInput
                                 ):
    # type: (...) -> None
    """Copy each data side input's buffered elements into runner state.

    For every ``(transform_id, tag) -> (buffer_id, si)`` entry, the elements
    of ``runner_execution_context.pcoll_buffers[buffer_id]`` are grouped by
    window with a ``WindowGroupingBuffer``. If that buffer does not exist, an
    empty ``ListBuffer`` is registered there first. The grouped data is then
    appended to ``worker_handler_manager.state_servicer`` under an iterable
    or a multimap side-input ``StateKey``, selected by ``si.urn``.

    Args:
      runner_execution_context: supplies coders, PCollection buffers and the
        state servicer.
      data_side_input: maps ``(transform_id, tag)`` to ``(buffer_id, si)``.

    Raises:
      ValueError: if ``si.urn`` is neither the ITERABLE nor the MULTIMAP
        side-input URN.
    """
    ctx = runner_execution_context
    for consumer_and_tag, buffer_and_spec in data_side_input.items():
      transform_id, tag = consumer_and_tag
      buffer_id, si = buffer_and_spec

      _, pcoll_id = split_buffer_id(buffer_id)
      coder_id = ctx.safe_coders[ctx.data_channel_coders[pcoll_id]]
      value_coder = ctx.pipeline_context.coders[coder_id]

      grouped = WindowGroupingBuffer(si, value_coder)
      buffers = ctx.pcoll_buffers
      if buffer_id not in buffers:
        # Presumably nothing was written for this side input; register an
        # empty buffer so the read below always has a source.
        buffers[buffer_id] = ListBuffer(coder_impl=value_coder.get_impl())
      for element_data in buffers[buffer_id]:
        grouped.append(element_data)

      urn = si.urn
      if urn == common_urns.side_inputs.ITERABLE.urn:
        is_multimap = False
      elif urn == common_urns.side_inputs.MULTIMAP.urn:
        is_multimap = True
      else:
        raise ValueError("Unknown access pattern: '%s'" % urn)

      for key, window, elements_data in grouped.encoded_items():
        if is_multimap:
          state_key = beam_fn_api_pb2.StateKey(
              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                  transform_id=transform_id,
                  side_input_id=tag,
                  window=window,
                  key=key))
        else:
          # Iterable side inputs carry no key component.
          state_key = beam_fn_api_pb2.StateKey(
              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                  transform_id=transform_id, side_input_id=tag, window=window))
        ctx.worker_handler_manager.state_servicer.append_raw(
            state_key, elements_data)
Ejemplo n.º 4
0
    def __init__(self, stages):
        # type: (List[translations.Stage]) -> None
        """Build the stage / PCollection dependency graph for ``stages``.

        For each stage a ``StageNode`` is registered in ``_stages_by_name``.
        ``PCollectionNode``s are registered in ``_pcollections_by_name`` and
        wired to it as follows:

        1. Every ``DATA_INPUT_URN`` transform contributes an input node. Impulse
           inputs are named after the transform's ``unique_name``; others use
           the PCollection name from ``split_buffer_id``. Each such edge bumps
           the node's ``consumers`` count.
        2. Every timer family of a ParDo transform contributes an input node
           named ``(transform.unique_name, timer_family_id)``. A fresh node is
           always stored under that name, and its ``consumers`` is bumped.
        3. Every ``DATA_OUTPUT_URN`` transform contributes an output node, with
           this stage recorded as one of its producers.
        4. Every name from ``stage.side_inputs()`` is added to the stage's
           ``side_inputs``; no consumer count is recorded for these.

        Finally ``self._verify(stages)`` is called.
        """
        self._pcollections_by_name: Dict[
            Union[str, translations.TimerFamilyId],
            WatermarkManager.PCollectionNode] = {}
        self._stages_by_name: Dict[str, WatermarkManager.StageNode] = {}

        def node_for(name):
            # type: (Union[str, translations.TimerFamilyId]) -> WatermarkManager.PCollectionNode
            """Return the node registered under ``name``, creating it if absent."""
            node = self._pcollections_by_name.get(name)
            if node is None:
                node = WatermarkManager.PCollectionNode(name)
                self._pcollections_by_name[name] = node
            assert isinstance(node, WatermarkManager.PCollectionNode)
            return node

        for stage in stages:
            stage_name = stage.name
            stage_node = WatermarkManager.StageNode(stage_name)
            self._stages_by_name[stage_name] = stage_node

            # 1. Data inputs become consumed input nodes of the stage.
            for transform in stage.transforms:
                if transform.spec.urn != bundle_processor.DATA_INPUT_URN:
                    continue
                buffer_id = transform.spec.payload
                if buffer_id == translations.IMPULSE_BUFFER:
                    input_name = transform.unique_name
                else:
                    _, input_name = split_buffer_id(buffer_id)
                input_node = node_for(input_name)
                input_node.consumers += 1
                stage_node.inputs.add(input_node)

            # 2. Timer families of ParDos become additional input nodes.
            for transform in stage.transforms:
                if transform.spec.urn not in translations.PAR_DO_URNS:
                    continue
                pardo_payload = proto_utils.parse_Bytes(
                    transform.spec.payload,
                    beam_runner_api_pb2.ParDoPayload)
                for timer_family_id in pardo_payload.timer_family_specs.keys():
                    timer_name = (transform.unique_name, timer_family_id)
                    # Always store a brand-new node under the timer name.
                    timer_node = WatermarkManager.PCollectionNode(timer_name)
                    self._pcollections_by_name[timer_name] = timer_node
                    stage_node.inputs.add(timer_node)
                    timer_node.consumers += 1

            # 3. Data outputs become output nodes produced by this stage.
            for transform in stage.transforms:
                if transform.spec.urn != bundle_processor.DATA_OUTPUT_URN:
                    continue
                _, output_name = split_buffer_id(transform.spec.payload)
                output_node = node_for(output_name)
                output_node.producers.add(stage_node)
                stage_node.outputs.add(output_node)

            # 4. Side inputs are tracked separately from regular inputs.
            for side_input_name in stage.side_inputs():
                stage_node.side_inputs.add(node_for(side_input_name))

        self._verify(stages)