Example No. 1
    def extract_bundle_inputs_and_outputs(self):
        # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]
        """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Returns:
      A tuple of (data_input, data_output, expected_timer_output) dictionaries.
        `data_input` is a dictionary mapping (transform_name, output_name) to a
        PCollection buffer; `data_output` is a dictionary mapping
        (transform_name, output_name) to a PCollection ID.
        `expected_timer_output` is a dictionary mapping transform_id and
        timer family ID to a buffer id for timers.
    """
        data_input = {}  # type: Dict[str, PartitionableBuffer]
        data_output = {}  # type: DataOutput
        # A mapping of {(transform_id, timer_family_id) : buffer_id}
        expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
        for transform in self.stage.transforms:
            if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                      bundle_processor.DATA_OUTPUT_URN):
                pcoll_id = transform.spec.payload
                if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                    coder_id = self.execution_context.data_channel_coders[
                        only_element(transform.outputs.values())]
                    coder = self.execution_context.pipeline_context.coders[
                        self.execution_context.safe_coders.get(
                            coder_id, coder_id)]
                    if pcoll_id == translations.IMPULSE_BUFFER:
                        data_input[transform.unique_name] = ListBuffer(
                            coder_impl=coder.get_impl())
                        data_input[transform.unique_name].append(
                            ENCODED_IMPULSE_VALUE)
                    else:
                        if pcoll_id not in self.execution_context.pcoll_buffers:
                            self.execution_context.pcoll_buffers[
                                pcoll_id] = ListBuffer(
                                    coder_impl=coder.get_impl())
                        data_input[transform.unique_name] = (
                            self.execution_context.pcoll_buffers[pcoll_id])
                elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                    data_output[transform.unique_name] = pcoll_id
                    coder_id = self.execution_context.data_channel_coders[
                        only_element(transform.inputs.values())]
                else:
                    raise NotImplementedError
                data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
                data_api_service_descriptor = (
                    self.data_api_service_descriptor())
                if data_api_service_descriptor:
                    data_spec.api_service_descriptor.url = (
                        data_api_service_descriptor.url)
                transform.spec.payload = data_spec.SerializeToString()
            elif transform.spec.urn in translations.PAR_DO_URNS:
                payload = proto_utils.parse_Bytes(
                    transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
                for timer_family_id in payload.timer_family_specs.keys():
                    expected_timer_output[
                        (transform.unique_name, timer_family_id)] = (
                            create_buffer_id(timer_family_id, 'timers'))
        return data_input, data_output, expected_timer_output
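
The data spec pattern above recurs in every example on this page: build a RemoteGrpcPort carrying the coder id and the data-plane endpoint, then serialize it into the IO transform's payload. Below is a minimal round-trip sketch (the coder id and URL are hypothetical, not from the original source) showing what a worker can recover from such a payload:

from apache_beam.portability.api import beam_fn_api_pb2

# Build a payload the same way the method above does (values are made up).
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id='some_coder_id')
data_spec.api_service_descriptor.url = 'localhost:12345'
payload = data_spec.SerializeToString()

# A worker parses the payload back to learn which coder and data-plane
# endpoint to use for this IO transform.
parsed = beam_fn_api_pb2.RemoteGrpcPort.FromString(payload)
assert parsed.coder_id == 'some_coder_id'
assert parsed.api_service_descriptor.url == 'localhost:12345'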
Example No. 2
def data_operation_spec(self):
  # Describes the gRPC data-plane endpoint this worker serves data on.
  url = 'localhost:%s' % self.data_port
  remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
  remote_grpc_port.api_service_descriptor.url = url
  return remote_grpc_port
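
Note that, unlike the other examples, this one leaves `coder_id` unset. RemoteGrpcPort is a proto3 message, so an unset string field simply reads back as the empty string; a quick check of that behavior:

from apache_beam.portability.api import beam_fn_api_pb2

port = beam_fn_api_pb2.RemoteGrpcPort()
port.api_service_descriptor.url = 'localhost:12345'
assert port.coder_id == ''  # proto3 string fields default to ''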
Example No. 3
  def _extract_endpoints(bundle_context_manager,  # type: execution.BundleContextManager
                         runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                         ):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput, Dict[Tuple[str, str], str]]

    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Args:
      stage (translations.Stage): The stage to extract endpoints
        for.
      data_api_service_descriptor: A GRPC endpoint descriptor for data plane.
    Returns:
      A tuple of (data_input, data_side_input, data_output) dictionaries.
        `data_input` is a dictionary mapping (transform_name, output_name) to a
        PCollection buffer; `data_output` is a dictionary mapping
        (transform_name, output_name) to a PCollection ID.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_side_input = {}  # type: DataSideInput
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in bundle_context_manager.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = runner_execution_context.data_channel_coders[only_element(
              transform.outputs.values())]
          coder = runner_execution_context.pipeline_context.coders[
              runner_execution_context.safe_coders.get(coder_id, coder_id)]
          if pcoll_id == translations.IMPULSE_BUFFER:
            data_input[transform.unique_name] = ListBuffer(
                coder_impl=coder.get_impl())
            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
          else:
            if pcoll_id not in runner_execution_context.pcoll_buffers:
              runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                  coder_impl=coder.get_impl())
            data_input[transform.unique_name] = (
                runner_execution_context.pcoll_buffers[pcoll_id])
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = runner_execution_context.data_channel_coders[only_element(
              transform.inputs.values())]
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        data_api_service_descriptor = (
            bundle_context_manager.data_api_service_descriptor())
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              create_buffer_id(transform.inputs[tag]), si.access_pattern)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_side_input, data_output, expected_timer_output
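
For reference, here is a minimal sketch of the buffer-id convention assumed by the timer mapping above, mirroring `create_buffer_id` from translations (which encodes a kind:name pair); treat the exact byte format as an implementation detail:

def create_buffer_id(name, kind='group'):
  # type: (str, str) -> bytes
  # Encodes a buffer identifier as b'<kind>:<name>'.
  return ('%s:%s' % (kind, name)).encode('utf-8')

# The timer mapping above therefore holds entries like:
#   ('MyParDo', 'ts-timer') -> b'timers:ts-timer'
assert create_buffer_id('ts-timer', 'timers') == b'timers:ts-timer'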
Example No. 4
def make_channel_payload(coder_id):
  # type: (str) -> bytes
  # `data_api_service_descriptor` is a free variable captured from the
  # enclosing scope.
  data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
  if data_api_service_descriptor:
    data_spec.api_service_descriptor.url = data_api_service_descriptor.url
  return data_spec.SerializeToString()
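
Since `data_api_service_descriptor` is a closure variable here, a self-contained variant needs to supply it explicitly. A sketch with a hypothetical endpoint (the endpoints_pb2 import path matches the other portability protos):

from apache_beam.portability.api import beam_fn_api_pb2, endpoints_pb2

# Hypothetical endpoint; in the original it comes from the enclosing scope.
data_api_service_descriptor = endpoints_pb2.ApiServiceDescriptor(
    url='localhost:12345')

def make_channel_payload(coder_id):
  # type: (str) -> bytes
  data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
  if data_api_service_descriptor:
    data_spec.api_service_descriptor.url = data_api_service_descriptor.url
  return data_spec.SerializeToString()

payload = make_channel_payload('some_coder_id')  # becomes transform.spec.payload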