Example #1
  def extract_bundle_inputs_and_outputs(self):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]

    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Returns:
      A tuple of (data_input, data_output, expected_timer_output) dictionaries.
        `data_input` is a dictionary mapping a transform name to a
        PCollection buffer; `data_output` is a dictionary mapping a
        transform name to a PCollection ID;
        `expected_timer_output` is a dictionary mapping (transform_id,
        timer_family_id) pairs to timer buffer IDs.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in self.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.outputs.values())]
          coder = self.execution_context.pipeline_context.coders[
              self.execution_context.safe_coders.get(coder_id, coder_id)]
          if pcoll_id == translations.IMPULSE_BUFFER:
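            # Impulse stages have no upstream input; seed the buffer with
            # the single encoded impulse element.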
            data_input[transform.unique_name] = ListBuffer(
                coder_impl=coder.get_impl())
            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
          else:
            if pcoll_id not in self.execution_context.pcoll_buffers:
              self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                  coder_impl=coder.get_impl())
            data_input[transform.unique_name] = (
                self.execution_context.pcoll_buffers[pcoll_id])
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.inputs.values())]
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        data_api_service_descriptor = self.data_api_service_descriptor()
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_output, expected_timer_output
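Example #1 builds timer buffer IDs with create_buffer_id, and the next example decomposes them with split_buffer_id. Both helpers follow a kind:name byte-string convention. A minimal sketch of the pair, assuming only the encoding these call sites imply (the real helpers live in the runner's translations module):

def create_buffer_id(name, kind='materialize'):
  # type: (str, str) -> bytes
  # Encode a buffer id as b'<kind>:<name>'.
  return ('%s:%s' % (kind, name)).encode('utf-8')


def split_buffer_id(buffer_id):
  # type: (bytes) -> Tuple[str, str]
  # Recover (kind, name) from b'<kind>:<name>'; names may contain ':'.
  kind, name = buffer_id.decode('utf-8').split(':', 1)
  return kind, name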
Example #2
  def get_buffer(self, buffer_id, transform_id):
    # type: (bytes, str) -> PartitionableBuffer

    """Returns the buffer for a given (operation_type, PCollection ID).

    For grouping-typed operations, we produce a ``GroupingBuffer``. For
    others, we produce a ``ListBuffer``.
    """
    kind, name = split_buffer_id(buffer_id)
    if kind == 'materialize':
      if buffer_id not in self.execution_context.pcoll_buffers:
        self.execution_context.pcoll_buffers[buffer_id] = ListBuffer(
            coder_impl=self.get_input_coder_impl(transform_id))
      return self.execution_context.pcoll_buffers[buffer_id]
    # For timer buffers, name is the timer family id.
    elif kind == 'timers':
      if buffer_id not in self.execution_context.timer_buffers:
        timer_coder_impl = self.get_timer_coder_impl(transform_id, name)
        self.execution_context.timer_buffers[buffer_id] = ListBuffer(
            timer_coder_impl)
      return self.execution_context.timer_buffers[buffer_id]
    elif kind == 'group':
      # This is a grouping write; create a grouping buffer if needed.
      if buffer_id not in self.execution_context.pcoll_buffers:
        original_gbk_transform = name
        transform_proto = self.execution_context.pipeline_components.transforms[
            original_gbk_transform]
        input_pcoll = only_element(list(transform_proto.inputs.values()))
        output_pcoll = only_element(list(transform_proto.outputs.values()))
        pre_gbk_coder = self.execution_context.pipeline_context.coders[
            self.execution_context.safe_coders[
                self.execution_context.data_channel_coders[input_pcoll]]]
        post_gbk_coder = self.execution_context.pipeline_context.coders[
            self.execution_context.safe_coders[
                self.execution_context.data_channel_coders[output_pcoll]]]
        windowing_strategy = (
            self.execution_context.pipeline_context.windowing_strategies[
                self.execution_context.safe_windowing_strategies[
                    self.execution_context.pipeline_components.
                    pcollections[output_pcoll].windowing_strategy_id]])
        self.execution_context.pcoll_buffers[buffer_id] = GroupingBuffer(
            pre_gbk_coder, post_gbk_coder, windowing_strategy)
    else:
      # 'materialize', 'timers' and 'group' are the only buffer kinds we
      # produce for now, though special side input writes may go here.
      raise NotImplementedError(buffer_id)
    return self.execution_context.pcoll_buffers[buffer_id]
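Given that encoding, the dispatch in get_buffer can be checked directly; a quick round trip using the sketch above (ids invented for illustration):

assert split_buffer_id(b'timers:my_timer_family') == ('timers', 'my_timer_family')
assert split_buffer_id(create_buffer_id('MyGBK', 'group')) == ('group', 'MyGBK')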
Example #3
    def _generate_splits_for_testing(
            self,
            split_manager,
            inputs,  # type: Mapping[str, PartitionableBuffer]
            process_bundle_id):
        # type: (...) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse]
        split_results = []  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
        read_transform_id, buffer_data = only_element(inputs.items())
        byte_stream = b''.join(buffer_data)
        num_elements = len(
            list(
                self.bundle_context_manager.get_input_coder_impl(
                    read_transform_id).decode_all(byte_stream)))

        # Start the split manager in case it wants to set any breakpoints.
        split_manager_generator = split_manager(num_elements)
        try:
            split_fraction = next(split_manager_generator)
            done = False
        except StopIteration:
            done = True

        # Send all the data.
        self._send_input_to_worker(process_bundle_id, read_transform_id,
                                   [byte_stream])

        assert self._worker_handler is not None

        # Execute the requested splits.
        while not done:
            if split_fraction is None:
                split_result = None
            else:
                split_request = beam_fn_api_pb2.InstructionRequest(
                    process_bundle_split=beam_fn_api_pb2.
                    ProcessBundleSplitRequest(
                        instruction_id=process_bundle_id,
                        desired_splits={
                            read_transform_id:
                            beam_fn_api_pb2.ProcessBundleSplitRequest.
                            DesiredSplit(fraction_of_remainder=split_fraction,
                                         estimated_input_elements=num_elements)
                        }))
                split_response = self._worker_handler.control_conn.push(
                    split_request).get()  # type: beam_fn_api_pb2.InstructionResponse
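                # A split request can race bundle startup; back off briefly
                # and retry while the runner reports the bundle as not yet
                # running.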
                for t in (0.05, 0.1, 0.2):
                    waiting = ('Instruction not running', 'not yet scheduled')
                    if any(msg in split_response.error for msg in waiting):
                        time.sleep(t)
                        split_response = self._worker_handler.control_conn.push(
                            split_request).get()
                if 'Unknown process bundle' in split_response.error:
                    # It may have finished too fast.
                    split_result = None
                elif split_response.error:
                    raise RuntimeError(split_response.error)
                else:
                    split_result = split_response.process_bundle_split
                    split_results.append(split_result)
            try:
                split_fraction = split_manager_generator.send(split_result)
            except StopIteration:
                break
        return split_results
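The split_manager argument follows a generator protocol: it is called with the bundle's element count, each value it yields is a fraction_of_remainder to request (or None to skip a round), and the runner sends every ProcessBundleSplitResponse back into the generator. A minimal sketch of a manager that requests one halfway split (the name is invented for illustration):

def halve_once_split_manager(num_elements):
  # Called with the bundle's element count before any data is sent.
  # Yielding 0.5 asks the runner to split at half of the remaining work;
  # the runner sends back the ProcessBundleSplitResponse, or None if the
  # split raced bundle completion.
  split_result = yield 0.5
  # Returning here raises StopIteration, which ends the split loop.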
Example #4
  def _extract_endpoints(bundle_context_manager,  # type: execution.BundleContextManager
                         runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                         ):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput, Dict[Tuple[str, str], str]]

    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Args:
      bundle_context_manager (execution.BundleContextManager): The context
        manager for the stage to extract endpoints for.
      runner_execution_context (execution.FnApiRunnerExecutionContext): The
        execution context for the current pipeline run.
    Returns:
      A tuple of (data_input, data_side_input, data_output,
      expected_timer_output) dictionaries.
        `data_input` is a dictionary mapping a transform name to a
        PCollection buffer; `data_side_input` is a dictionary mapping
        (transform_name, tag) to a (buffer ID, access pattern) pair;
        `data_output` is a dictionary mapping a transform name to a
        PCollection ID; `expected_timer_output` is a dictionary mapping
        (transform_id, timer_family_id) pairs to timer buffer IDs.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_side_input = {}  # type: DataSideInput
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in bundle_context_manager.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = runner_execution_context.data_channel_coders[only_element(
              transform.outputs.values())]
          coder = runner_execution_context.pipeline_context.coders[
              runner_execution_context.safe_coders.get(coder_id, coder_id)]
          if pcoll_id == translations.IMPULSE_BUFFER:
            data_input[transform.unique_name] = ListBuffer(
                coder_impl=coder.get_impl())
            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
          else:
            if pcoll_id not in runner_execution_context.pcoll_buffers:
              runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                  coder_impl=coder.get_impl())
            data_input[transform.unique_name] = (
                runner_execution_context.pcoll_buffers[pcoll_id])
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = runner_execution_context.data_channel_coders[only_element(
              transform.inputs.values())]
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        data_api_service_descriptor = (
            bundle_context_manager.data_api_service_descriptor())
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              create_buffer_id(transform.inputs[tag]), si.access_pattern)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_side_input, data_output, expected_timer_output
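Each data_side_input value pairs a materialize buffer ID with the side input's access pattern, a FunctionSpec whose urn is e.g. 'beam:side_input:iterable:v1'. A worked example of the buffer-ID half, reusing the create_buffer_id sketch above (transform, tag and PCollection names invented):

# For a ParDo 'pardo1' with side input tag 'side0' reading PCollection
# 'pc_5', the entry would be:
#   data_side_input[('pardo1', 'side0')] ==
#       (create_buffer_id('pc_5'), si.access_pattern)
assert create_buffer_id('pc_5') == b'materialize:pc_5'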