Example #1
    def merge(self, merge_context):
        # type: (window.WindowFn.MergeContext) -> None
        worker_handler = self.worker_handle()

        assert self.windowed_input_coder_impl is not None
        assert self.windowed_output_coder_impl is not None
        process_bundle_id = self.uid('process')
        to_worker = worker_handler.data_conn.output_stream(
            process_bundle_id, self.TO_SDK_TRANSFORM)
        to_worker.write(
            self.windowed_input_coder_impl.encode_nested(
                window.GlobalWindows.windowed_value(
                    (b'', merge_context.windows))))
        to_worker.close()

        process_bundle_req = beam_fn_api_pb2.InstructionRequest(
            instruction_id=process_bundle_id,
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_id=self._bundle_processor_id))
        result_future = worker_handler.control_conn.push(process_bundle_req)
        for output in worker_handler.data_conn.input_elements(
                process_bundle_id, [self.FROM_SDK_TRANSFORM],
                abort_callback=lambda: bool(result_future.is_done() and
                                            result_future.get().error)):
            if isinstance(output, beam_fn_api_pb2.Elements.Data):
                windowed_result = self.windowed_output_coder_impl.decode_nested(
                    output.data)
                for merge_result, originals in windowed_result.value[1][1]:
                    merge_context.merge(originals, merge_result)
            else:
                raise RuntimeError("Unexpected data: %s" % output)

        result = result_future.get()
        if result.error:
            raise RuntimeError(result.error)
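
All of these examples follow the same round trip: write the encoded input over the data plane, push an InstructionRequest on the control plane, drain the output elements, then check result.error. The abort_callback handed to input_elements lets that read loop give up early once the bundle has already failed. A rough, hypothetical sketch of the idea (the real logic lives in Beam's data_plane module and differs in detail; is_last and transform_id are assumed field names):

def read_until_done(next_element, expected_transforms, abort_callback):
    # Hypothetical helper: keep yielding data-plane elements until every
    # expected transform has delivered its final chunk, or the bundle has
    # already failed on the control plane.
    remaining = set(expected_transforms)
    while remaining:
        if abort_callback():
            # Caller will surface result_future.get().error instead of hanging.
            return
        element = next_element()  # blocks for the next Elements.Data message
        if element.is_last:
            remaining.discard(element.transform_id)
        yield element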
Example #2
  def process_bundle(self, inputs, expected_outputs):
    # Unique id for the instruction processing this bundle.
    BundleManager._uid_counter += 1
    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter

    # Register the bundle descriptor, if needed.
    if self._registered:
      registration_future = None
    else:
      process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=[self._bundle_descriptor]))
      registration_future = self._controller.control_handler.push(
          process_bundle_registration)
      self._registered = True

    # Write all the input data to the channel.
    for (transform_id, name), elements in inputs.items():
      data_out = self._controller.data_plane_handler.output_stream(
          process_bundle_id, beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=name))
      for element_data in elements:
        data_out.write(element_data)
      data_out.close()

    # Actually start the bundle.
    if registration_future and registration_future.get().error:
      raise RuntimeError(registration_future.get().error)
    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=process_bundle_id,
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=self._bundle_descriptor.id))
    result_future = self._controller.control_handler.push(process_bundle)

    with ProgressRequester(
        self._controller, process_bundle_id, self._progress_frequency):
      # Gather all output data.
      expected_targets = [
          beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                                 name=output_name)
          for (transform_id, output_name), _ in expected_outputs.items()]
      logging.debug('Gather all output data from %s.', expected_targets)
      for output in self._controller.data_plane_handler.input_elements(
          process_bundle_id,
          expected_targets,
          abort_callback=lambda: (result_future.is_done()
                                  and result_future.get().error)):
        target_tuple = (
            output.target.primitive_transform_reference, output.target.name)
        if target_tuple in expected_outputs:
          self._get_buffer(expected_outputs[target_tuple]).append(output.data)

      logging.debug('Wait for the bundle to finish.')
      result = result_future.get()

    if result.error:
      raise RuntimeError(result.error)
    return result
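
Note that Example #2 never decodes the outputs inline; self._get_buffer(...) simply collects the raw encoded chunks per logical buffer so a later stage can replay them. A minimal stand-in for that helper (the OutputBuffers name and shape are assumptions, not Beam API):

from collections import defaultdict

class OutputBuffers(object):
  """Hypothetical sketch of the buffer registry behind _get_buffer()."""

  def __init__(self):
    self._buffers = defaultdict(list)  # buffer id -> list of encoded chunks

  def get(self, buffer_id):
    return self._buffers[buffer_id]

# Usage mirroring the gather loop above:
#   buffers.get(expected_outputs[target_tuple]).append(output.data)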
Example #3
    def _run_map_task(self, map_task, control_handler, state_handler,
                      data_plane_handler, data_operation_spec):
        registration, sinks, input_data = self._map_task_registration(
            map_task, state_handler, data_operation_spec)
        control_handler.push(registration)
        process_bundle = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_reference=registration.register.
                process_bundle_descriptor[0].id))

        for (transform_id, name), elements in input_data.items():
            data_out = data_plane_handler.output_stream(
                process_bundle.instruction_id,
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_id, name=name))
            data_out.write(elements)
            data_out.close()

        control_handler.push(process_bundle)
        while True:
            result = control_handler.pull()
            if result.instruction_id == process_bundle.instruction_id:
                if result.error:
                    raise RuntimeError(result.error)
                expected_targets = [
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=transform_id,
                        name=output_name)
                    for (transform_id, output_name), _ in sinks.items()
                ]
                for output in data_plane_handler.input_elements(
                        process_bundle.instruction_id, expected_targets):
                    target_tuple = (
                        output.target.primitive_transform_reference,
                        output.target.name)
                    if target_tuple not in sinks:
                        # Unconsumed output.
                        continue
                    sink_op = sinks[target_tuple]
                    coder = sink_op.output_coders[0]
                    input_stream = create_InputStream(output.data)
                    elements = []
                    while input_stream.size() > 0:
                        elements.append(coder.get_impl().decode_from_stream(
                            input_stream, True))
                    if not sink_op.write_windowed_values:
                        elements = [e.value for e in elements]
                    for e in elements:
                        sink_op.output_buffer.append(e)
                return
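
The decode loop at the end of this example keeps pulling values off the input stream until it is exhausted, because a single data-plane chunk may carry many concatenated elements. A minimal round trip with a concrete Beam coder, assuming the internal coder_impl helpers behave as they are used above:

from apache_beam import coders
from apache_beam.coders.coder_impl import create_InputStream

coder_impl = coders.VarIntCoder().get_impl()
# Concatenate several nested-encoded elements, as they would arrive in output.data.
data = b''.join(coder_impl.encode_nested(i) for i in range(3))

input_stream = create_InputStream(data)
elements = []
while input_stream.size() > 0:
    elements.append(coder_impl.decode_from_stream(input_stream, True))
assert elements == [0, 1, 2]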
Example #4
    def run_stage(self, controller, pipeline_components, stage, pcoll_buffers,
                  safe_coders):

        context = pipeline_context.PipelineContext(pipeline_components)
        data_operation_spec = controller.data_operation_spec()

        def extract_endpoints(stage):
            # Returns maps of transform names to PCollection identifiers.
            # Also mutates IO stages to point to the data_operation_spec.
            data_input = {}
            data_side_input = {}
            data_output = {}
            for transform in stage.transforms:
                if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                          bundle_processor.DATA_OUTPUT_URN):
                    pcoll_id = transform.spec.payload
                    if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                        target = transform.unique_name, only_element(
                            transform.outputs)
                        data_input[target] = pcoll_id
                    elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                        target = transform.unique_name, only_element(
                            transform.inputs)
                        data_output[target] = pcoll_id
                    else:
                        raise NotImplementedError
                    if data_operation_spec:
                        transform.spec.payload = (
                            data_operation_spec.SerializeToString())
                    else:
                        transform.spec.payload = ""
                elif transform.spec.urn == urns.PARDO_TRANSFORM:
                    payload = proto_utils.parse_Bytes(
                        transform.spec.payload,
                        beam_runner_api_pb2.ParDoPayload)
                    for tag, si in payload.side_inputs.items():
                        data_side_input[transform.unique_name, tag] = (
                            'materialize:' + transform.inputs[tag],
                            beam.pvalue.SideInputData.from_runner_api(
                                si, None))
            return data_input, data_side_input, data_output

        logging.info('Running %s', stage.name)
        logging.debug('       %s', stage)
        data_input, data_side_input, data_output = extract_endpoints(stage)

        process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._next_uid(),
            transforms={
                transform.unique_name: transform
                for transform in stage.transforms
            },
            pcollections=dict(pipeline_components.pcollections.items()),
            coders=dict(pipeline_components.coders.items()),
            windowing_strategies=dict(
                pipeline_components.windowing_strategies.items()),
            environments=dict(pipeline_components.environments.items()))

        process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[process_bundle_descriptor]))

        process_bundle = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_reference=process_bundle_descriptor.
                id))

        # Write all the input data to the channel.
        for (transform_id, name), pcoll_id in data_input.items():
            data_out = controller.data_plane_handler.output_stream(
                process_bundle.instruction_id,
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_id, name=name))
            for element_data in pcoll_buffers[pcoll_id]:
                data_out.write(element_data)
            data_out.close()

        # Store the required side inputs into state.
        for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
            elements_by_window = _WindowGroupingBuffer(si)
            for element_data in pcoll_buffers[pcoll_id]:
                elements_by_window.append(element_data)
            for window, elements_data in elements_by_window.items():
                state_key = beam_fn_api_pb2.StateKey(
                    multimap_side_input=beam_fn_api_pb2.StateKey.
                    MultimapSideInput(ptransform_id=transform_id,
                                      side_input_id=tag,
                                      window=window))
                controller.state_handler.blocking_append(
                    state_key, elements_data, process_bundle.instruction_id)

        # Register and start running the bundle.
        logging.debug('Register and start running the bundle')
        controller.control_handler.push(process_bundle_registration)
        controller.control_handler.push(process_bundle)

        # Wait for the bundle to finish.
        logging.debug('Wait for the bundle to finish.')
        while True:
            result = controller.control_handler.pull()
            if result and result.instruction_id == process_bundle.instruction_id:
                if result.error:
                    raise RuntimeError(result.error)
                break

        expected_targets = [
            beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                                   name=output_name)
            for (transform_id, output_name), _ in data_output.items()
        ]

        # Gather all output data.
        logging.debug('Gather all output data from %s.', expected_targets)

        for output in controller.data_plane_handler.input_elements(
                process_bundle.instruction_id, expected_targets):
            target_tuple = (output.target.primitive_transform_reference,
                            output.target.name)
            if target_tuple in data_output:
                pcoll_id = data_output[target_tuple]
                if pcoll_id.startswith('materialize:'):
                    # Just store the data chunks for replay.
                    pcoll_buffers[pcoll_id].append(output.data)
                elif pcoll_id.startswith('group:'):
                    # This is a grouping write, create a grouping buffer if needed.
                    if pcoll_id not in pcoll_buffers:
                        original_gbk_transform = pcoll_id.split(':', 1)[1]
                        transform_proto = pipeline_components.transforms[
                            original_gbk_transform]
                        input_pcoll = only_element(
                            transform_proto.inputs.values())
                        output_pcoll = only_element(
                            transform_proto.outputs.values())
                        pre_gbk_coder = context.coders[
                            safe_coders[pipeline_components.
                                        pcollections[input_pcoll].coder_id]]
                        post_gbk_coder = context.coders[
                            safe_coders[pipeline_components.
                                        pcollections[output_pcoll].coder_id]]
                        windowing_strategy = context.windowing_strategies[
                            pipeline_components.pcollections[output_pcoll].
                            windowing_strategy_id]
                        pcoll_buffers[pcoll_id] = _GroupingBuffer(
                            pre_gbk_coder, post_gbk_coder, windowing_strategy)
                    pcoll_buffers[pcoll_id].append(output.data)
                else:
                    # These should be the only two identifiers we produce for now,
                    # but special side input writes may go here.
                    raise NotImplementedError(pcoll_id)
        return result
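
Examples #4 and #6 lean on the small only_element(...) helper to assert that a transform has exactly one input or output PCollection. Its behaviour presumably amounts to tuple unpacking, roughly:

def only_element(iterable):
    # Unpack exactly one element; raises ValueError if there are zero or many.
    element, = iterable
    return element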
Example #5
    def process_bundle(
        self,
        inputs,  # type: Mapping[str, execution.PartitionableBuffer]
        expected_outputs,  # type: DataOutput
        fired_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
        expected_output_timers,  # type: Dict[Tuple[str, str], str]
        dry_run=False,
    ):
        # type: (...) -> BundleProcessResult
        # Unique id for the instruction processing this bundle.
        with BundleManager._lock:
            BundleManager._uid_counter += 1
            process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
            self._worker_handler = self.bundle_context_manager.worker_handlers[
                BundleManager._uid_counter %
                len(self.bundle_context_manager.worker_handlers)]

        split_manager = self._select_split_manager()
        if not split_manager:
            # Send timers.
            for transform_id, timer_family_id in expected_output_timers.keys():
                self._send_timers_to_worker(
                    process_bundle_id, transform_id, timer_family_id,
                    fired_timers.get((transform_id, timer_family_id), []))

            # If there is no split_manager, write all input data to the channel.
            for transform_id, elements in inputs.items():
                self._send_input_to_worker(process_bundle_id, transform_id,
                                           elements)

        # Actually start the bundle.
        process_bundle_req = beam_fn_api_pb2.InstructionRequest(
            instruction_id=process_bundle_id,
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_id=self.bundle_context_manager.
                process_bundle_descriptor.id,
                cache_tokens=[next(self._cache_token_generator)]))
        result_future = self._worker_handler.control_conn.push(
            process_bundle_req)

        split_results = []  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
        with ProgressRequester(self._worker_handler, process_bundle_id,
                               self._progress_frequency):

            if split_manager:
                split_results = self._generate_splits_for_testing(
                    split_manager, inputs, process_bundle_id)

            expect_reads = list(expected_outputs.keys())
            expect_reads.extend(list(expected_output_timers.keys()))

            # Gather all output data.
            for output in self._worker_handler.data_conn.input_elements(
                    process_bundle_id,
                    expect_reads,
                    abort_callback=lambda: (
                        result_future.is_done() and result_future.get().error)):
                if isinstance(output,
                              beam_fn_api_pb2.Elements.Timers) and not dry_run:
                    with BundleManager._lock:
                        timer_buffer = self.bundle_context_manager.get_buffer(
                            expected_output_timers[(output.transform_id,
                                                    output.timer_family_id)],
                            output.transform_id)
                        if timer_buffer.cleared:
                            timer_buffer.reset()
                        timer_buffer.append(output.timers)
                if isinstance(output,
                              beam_fn_api_pb2.Elements.Data) and not dry_run:
                    with BundleManager._lock:
                        self.bundle_context_manager.get_buffer(
                            expected_outputs[output.transform_id],
                            output.transform_id).append(output.data)

            _LOGGER.debug('Wait for the bundle %s to finish.' %
                          process_bundle_id)
            result = result_future.get()  # type: beam_fn_api_pb2.InstructionResponse

        if result.error:
            raise RuntimeError(result.error)

        if result.process_bundle.requires_finalization:
            finalize_request = beam_fn_api_pb2.InstructionRequest(
                finalize_bundle=beam_fn_api_pb2.FinalizeBundleRequest(
                    instruction_id=process_bundle_id))
            self._worker_handler.control_conn.push(finalize_request)

        return result, split_results
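
Example #5 attaches cache_tokens=[next(self._cache_token_generator)] to every ProcessBundleRequest so the SDK worker can cache user state across bundles and invalidate it when the token changes. A hypothetical token generator (field names follow the beam_fn_api proto; the real FnApiRunner implementation may differ):

import itertools

from apache_beam.portability.api import beam_fn_api_pb2

def cache_token_generator():
    # Yield a fresh user-state cache token per bundle.
    for i in itertools.count(1):
        yield beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState(),
            token=b'token-%d' % i)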
Example #6
  def run_stage(
      self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):

    coders = pipeline_context.PipelineContext(pipeline_components).coders
    data_operation_spec = controller.data_operation_spec()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data_operation_spec.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        pcoll_id = transform.spec.payload
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            data_input[target] = pcoll_id
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
          else:
            raise NotImplementedError
          if data_operation_spec:
            transform.spec.payload = data_operation_spec.SerializeToString()
            transform.spec.any_param.Pack(data_operation_spec)
          else:
            transform.spec.payload = ""
            transform.spec.any_param.Clear()
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)
    if data_side_input:
      raise NotImplementedError('Side inputs.')

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[process_bundle_descriptor]))

    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=
            process_bundle_descriptor.id))

    # Write all the input data to the channel.
    for (transform_id, name), pcoll_id in data_input.items():
      data_out = controller.data_plane_handler.output_stream(
          process_bundle.instruction_id, beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=name))
      for element_data in pcoll_buffers[pcoll_id]:
        data_out.write(element_data)
      data_out.close()

    # Register and start running the bundle.
    controller.control_handler.push(process_bundle_registration)
    controller.control_handler.push(process_bundle)

    # Wait for the bundle to finish.
    while True:
      result = controller.control_handler.pull()
      if result and result.instruction_id == process_bundle.instruction_id:
        if result.error:
          raise RuntimeError(result.error)
        break

    # Gather all output data.
    expected_targets = [
        beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                               name=output_name)
        for (transform_id, output_name), _ in data_output.items()]
    for output in controller.data_plane_handler.input_elements(
        process_bundle.instruction_id, expected_targets):
      target_tuple = (
          output.target.primitive_transform_reference, output.target.name)
      if target_tuple in data_output:
        pcoll_id = data_output[target_tuple]
        if pcoll_id.startswith('materialize:'):
          # Just store the data chunks for replay.
          pcoll_buffers[pcoll_id].append(output.data)
        elif pcoll_id.startswith('group:'):
          # This is a grouping write, create a grouping buffer if needed.
          if pcoll_id not in pcoll_buffers:
            original_gbk_transform = pcoll_id.split(':', 1)[1]
            transform_proto = pipeline_components.transforms[
                original_gbk_transform]
            input_pcoll = only_element(transform_proto.inputs.values())
            output_pcoll = only_element(transform_proto.outputs.values())
            pre_gbk_coder = coders[safe_coders[
                pipeline_components.pcollections[input_pcoll].coder_id]]
            post_gbk_coder = coders[safe_coders[
                pipeline_components.pcollections[output_pcoll].coder_id]]
            pcoll_buffers[pcoll_id] = _GroupingBuffer(
                pre_gbk_coder, post_gbk_coder)
          pcoll_buffers[pcoll_id].append(output.data)
        else:
          # These should be the only two identifiers we produce for now,
          # but special side input writes may go here.
          raise NotImplementedError(pcoll_id)