def _collect_written_timers_and_add_to_fired_timers(
        self,
        bundle_context_manager,  # type: execution.BundleContextManager
        fired_timers  # type: Dict[Tuple[str, str], ListBuffer]
    ):
        # type: (...) -> None

        """Collects timers written by the just-finished bundle, deduplicates
        them by (user key, window), drops timers whose clear bit is set, and
        adds the rest to `fired_timers` keyed by (transform id, timer family
        id). The written-timer buffers are cleared afterwards."""

        for (transform_id,
             timer_family_id) in bundle_context_manager.stage.timers:
            written_timers = bundle_context_manager.get_buffer(
                create_buffer_id(timer_family_id, kind='timers'), transform_id)
            timer_coder_impl = bundle_context_manager.get_timer_coder_impl(
                transform_id, timer_family_id)
            if not written_timers.cleared:
                timers_by_key_and_window = {}
                for elements_timers in written_timers:
                    for decoded_timer in timer_coder_impl.decode_all(
                            elements_timers):
                        timers_by_key_and_window[
                            decoded_timer.user_key,
                            decoded_timer.windows[0]] = decoded_timer
                out = create_OutputStream()
                for decoded_timer in timers_by_key_and_window.values():
                    # Only add timers that have not been cleared to the fired timers.
                    if not decoded_timer.clear_bit:
                        timer_coder_impl.encode_to_stream(
                            decoded_timer, out, True)
                fired_timers[(transform_id, timer_family_id)] = ListBuffer(
                    coder_impl=timer_coder_impl)
                fired_timers[(transform_id, timer_family_id)].append(out.get())
                written_timers.clear()
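To make the deduplication step above concrete, here is a minimal, self-contained sketch of the same logic. The Timer namedtuple and every value in it are invented stand-ins for Beam's decoded timers; only the fields used above are modeled. For each (user key, window) pair the last written timer wins, and timers with the clear bit set never reach the fired set:

from collections import namedtuple

# Invented stand-in for a decoded Beam timer (not Beam's actual type).
Timer = namedtuple('Timer', ['user_key', 'windows', 'clear_bit', 'fire_timestamp'])

written = [
    Timer(b'k1', ('w0', ), False, 10),
    Timer(b'k1', ('w0', ), False, 20),  # same (key, window): replaces the first
    Timer(b'k2', ('w0', ), True, 5),  # clear bit set: dropped below
]

timers_by_key_and_window = {}
for t in written:
    timers_by_key_and_window[t.user_key, t.windows[0]] = t

fired = [t for t in timers_by_key_and_window.values() if not t.clear_bit]
assert [t.fire_timestamp for t in fired] == [20]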
Example #2
    def _build_data_side_inputs_map(stages):
        # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]
        """Builds an index mapping stages to side input descriptors.

    A side input descriptor is a map of side input IDs to side input access
    patterns for all of the outputs of a stage that will be consumed as a
    side input.
    """
        transform_consumers = collections.defaultdict(
            list
        )  # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]]
        stage_consumers = collections.defaultdict(
            list)  # type: DefaultDict[str, List[translations.Stage]]

        def get_all_side_inputs():
            # type: () -> Set[str]
            # Note: this traversal also populates transform_consumers and
            # stage_consumers as a side effect.
            all_side_inputs = set()  # type: Set[str]
            for stage in stages:
                for transform in stage.transforms:
                    for input_pcoll in transform.inputs.values():
                        transform_consumers[input_pcoll].append(transform)
                        stage_consumers[input_pcoll].append(stage)
                for si in stage.side_inputs():
                    all_side_inputs.add(si)
            return all_side_inputs

        all_side_inputs = frozenset(get_all_side_inputs())
        data_side_inputs_by_producing_stage = {
        }  # type: Dict[str, DataSideInput]

        producing_stages_by_pcoll = {}

        for s in stages:
            data_side_inputs_by_producing_stage[s.name] = {}
            for transform in s.transforms:
                for o in transform.outputs.values():
                    if o in s.side_inputs():
                        continue
                    if o in producing_stages_by_pcoll:
                        continue
                    producing_stages_by_pcoll[o] = s

        for side_pc in all_side_inputs:
            for consuming_transform in transform_consumers[side_pc]:
                if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
                    continue
                producing_stage = producing_stages_by_pcoll[side_pc]
                payload = proto_utils.parse_Bytes(
                    consuming_transform.spec.payload,
                    beam_runner_api_pb2.ParDoPayload)
                for si_tag in payload.side_inputs:
                    if consuming_transform.inputs[si_tag] == side_pc:
                        side_input_id = (consuming_transform.unique_name,
                                         si_tag)
                        data_side_inputs_by_producing_stage[
                            producing_stage.name][side_input_id] = (
                                translations.create_buffer_id(side_pc),
                                payload.side_inputs[si_tag].access_pattern)

        return data_side_inputs_by_producing_stage
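For reference, the returned structure has the shape sketched below. Every name and id here is hypothetical; the access-pattern string is Beam's standard iterable side input URN, and the buffer id follows the kind:name byte-string encoding produced by translations.create_buffer_id:

# Hypothetical contents: producing stage name ->
#   {(consuming transform name, side input tag): (buffer id, access pattern)}
data_side_inputs_by_producing_stage = {
    'Stage[produce_side]': {
        ('ParDo(EnrichFn)', 'side0'): (
            b'materialize:pcoll_1',
            'beam:side_input:iterable:v1',
        ),
    },
}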
Example #3
  def extract_bundle_inputs_and_outputs(self):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]

    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Returns:
      A tuple of (data_input, data_output, expected_timer_output) dictionaries.
        `data_input` is a dictionary mapping a transform name to a PCollection
        buffer; `data_output` is a dictionary mapping a transform name to a
        PCollection ID; `expected_timer_output` is a dictionary mapping a
        (transform_id, timer_family_id) pair to a buffer id for timers.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in self.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.outputs.values())]
          coder = self.execution_context.pipeline_context.coders[
              self.execution_context.safe_coders.get(coder_id, coder_id)]
          if pcoll_id == translations.IMPULSE_BUFFER:
            data_input[transform.unique_name] = ListBuffer(
                coder_impl=coder.get_impl())
            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
          else:
            if pcoll_id not in self.execution_context.pcoll_buffers:
              self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                  coder_impl=coder.get_impl())
            data_input[transform.unique_name] = (
                self.execution_context.pcoll_buffers[pcoll_id])
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.inputs.values())]
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        data_api_service_descriptor = self.data_api_service_descriptor()
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_output, expected_timer_output
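The timer buffer ids placed in expected_timer_output come from create_buffer_id. Judging only from the call sites in these examples (create_buffer_id(timer_family_id, 'timers') here and create_buffer_id(timer_family_id, kind='timers') in the first example), it produces a kind-prefixed byte string; a minimal sketch under that assumption:

def create_buffer_id(name, kind='materialize'):
    # type: (str, str) -> bytes
    # Sketch inferred from the call sites above, not necessarily the
    # canonical Beam implementation: a kind-prefixed buffer id.
    return ('%s:%s' % (kind, name)).encode('utf-8')

assert create_buffer_id('my_family', 'timers') == b'timers:my_family'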
Example #4
  def _extract_endpoints(bundle_context_manager,  # type: execution.BundleContextManager
                         runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                         ):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput, Dict[Tuple[str, str], str]]

    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Args:
      bundle_context_manager (execution.BundleContextManager): The bundle
        context for the stage to extract endpoints for.
      runner_execution_context (execution.FnApiRunnerExecutionContext): The
        execution context holding the data channel coders and PCollection
        buffers.
    Returns:
      A tuple of (data_input, data_side_input, data_output,
        expected_timer_output) dictionaries.
        `data_input` is a dictionary mapping a transform name to a PCollection
        buffer; `data_side_input` is a dictionary mapping a (transform name,
        side input tag) pair to a (buffer id, access pattern) pair;
        `data_output` is a dictionary mapping a transform name to a
        PCollection ID; `expected_timer_output` is a dictionary mapping a
        (transform_id, timer_family_id) pair to a buffer id for timers.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_side_input = {}  # type: DataSideInput
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in bundle_context_manager.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = runner_execution_context.data_channel_coders[only_element(
              transform.outputs.values())]
          coder = runner_execution_context.pipeline_context.coders[
              runner_execution_context.safe_coders.get(coder_id, coder_id)]
          if pcoll_id == translations.IMPULSE_BUFFER:
            data_input[transform.unique_name] = ListBuffer(
                coder_impl=coder.get_impl())
            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
          else:
            if pcoll_id not in runner_execution_context.pcoll_buffers:
              runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                  coder_impl=coder.get_impl())
            data_input[transform.unique_name] = (
                runner_execution_context.pcoll_buffers[pcoll_id])
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = runner_execution_context.data_channel_coders[only_element(
              transform.inputs.values())]
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        data_api_service_descriptor = (
            bundle_context_manager.data_api_service_descriptor())
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              create_buffer_id(transform.inputs[tag]), si.access_pattern)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_side_input, data_output, expected_timer_output
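Both this example and the previous one rewrite the payload of every DATA_INPUT/DATA_OUTPUT transform into a serialized RemoteGrpcPort that points at the data-plane endpoint. A standalone sketch of that rewrite, assuming the standard apache_beam proto module path (the coder id and URL are placeholders):

from apache_beam.portability.api import beam_fn_api_pb2

# Carry the data-channel coder id and the data-plane endpoint URL, then
# serialize the port as the transform's new payload.
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id='coder-1')
data_spec.api_service_descriptor.url = 'localhost:12345'
serialized_payload = data_spec.SerializeToString()  # becomes transform.spec.payload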