def start_bundle(self):
    """Create the output bundles and a started ``DoFnRunner`` for a bundle.

    One bundle is created per tag in the applied transform's outputs and
    stored in ``self._tagged_receivers``.  The DoFn is round-tripped through
    the pickler only when ``self._perform_dofn_pickle_test`` is truthy.
    Missing ``args`` / ``kwargs`` attributes on the transform fall back to an
    empty list / dict.
    """
    applied = self._applied_ptransform
    ptransform = applied.transform

    # Register one freshly created bundle for every declared output tag.
    receivers = _TaggedReceivers(self._evaluation_context)
    self._tagged_receivers = receivers
    for tag in applied.outputs:
        pcoll = pvalue.PCollection(None, tag=tag)
        pcoll.producer = applied
        receivers[tag] = self._evaluation_context.create_bundle(pcoll)
        receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
        dofn = pickler.loads(pickler.dumps(ptransform.dofn))
    else:
        dofn = ptransform.dofn

    if hasattr(ptransform, 'args'):
        args = ptransform.args
    else:
        args = []
    if hasattr(ptransform, 'kwargs'):
        kwargs = ptransform.kwargs
    else:
        kwargs = {}

    runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory))
    self.runner = runner
    runner.start()
Beispiel #2
0
  def start_bundle(self):
    """Create output bundles, wrap the DoFn and start a ``DoFnRunner``.

    The DoFn is always round-tripped through the pickler.  It is wrapped in
    ``TypeCheckWrapperDoFn`` only when pipeline options are present and their
    ``TypeOptions`` view has ``runtime_type_check`` set, and is always wrapped
    in ``OutputCheckWrapperDoFn``.
    """
    applied = self._applied_ptransform
    ptransform = applied.transform

    # Register one freshly created bundle for every declared output tag.
    receivers = _TaggedReceivers(self._evaluation_context)
    self._tagged_receivers = receivers
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      receivers[tag] = self._evaluation_context.create_bundle(pcoll)
      receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = pickler.loads(pickler.dumps(ptransform.dofn))

    options = self._evaluation_context.pipeline_options
    if options is not None:
      if options.view_as(TypeOptions).runtime_type_check:
        dofn = TypeCheckWrapperDoFn(dofn, ptransform.get_type_hints())

    dofn = OutputCheckWrapperDoFn(dofn, applied.full_label)
    runner = DoFnRunner(
        dofn,
        ptransform.args,
        ptransform.kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner = runner
    runner.start()
Beispiel #3
0
    def start_bundle(self):
        """Set up output bundles and start a ``DoFnRunner`` for the ParDo.

        The main output bundle is registered under the ``None`` tag.  When the
        parent transform is a ``core._MultiParDo``, one extra bundle is
        created for each tag in ``transform.side_output_tags``; otherwise the
        evaluator must have exactly one output.  The pickled-and-restored DoFn
        is optionally wrapped for runtime type checking and always wrapped
        with an output check, choosing the ``NewDoFn`` flavour of each wrapper
        when the DoFn is a ``core.NewDoFn``.
        """
        applied = self._applied_ptransform
        ptransform = applied.transform
        context = self._evaluation_context

        receivers = _TaggedReceivers(context)
        self._tagged_receivers = receivers

        parent = applied.parent
        if isinstance(parent.transform,
                      core._MultiParDo):  # pylint: disable=protected-access
            outputs_tuple = parent.outputs[0]
            assert isinstance(outputs_tuple, pvalue.DoOutputsTuple)
            main_tag = outputs_tuple._main_tag  # pylint: disable=protected-access
            main_pcoll = outputs_tuple[main_tag]

            for tag in ptransform.side_output_tags:
                side_pcoll = outputs_tuple[tag]
                receivers[tag] = context.create_bundle(side_pcoll)
                receivers[tag].tag = tag
        else:
            assert len(self._outputs) == 1
            main_pcoll = list(self._outputs)[0]

        # The main output is keyed by None (main_tag is None).
        receivers[None] = context.create_bundle(main_pcoll)
        receivers[None].tag = None

        self._counter_factory = counters.CounterFactory()

        # TODO(aaltay): Consider storing the serialized form as an optimization.
        dofn = pickler.loads(pickler.dumps(ptransform.dofn))

        options = context.pipeline_options
        if options is not None:
            if options.view_as(TypeOptions).runtime_type_check:
                # TODO(sourabhbajaj): Remove this if-else
                type_check_cls = (
                    TypeCheckWrapperNewDoFn
                    if isinstance(dofn, core.NewDoFn) else TypeCheckWrapperDoFn)
                dofn = type_check_cls(dofn, ptransform.get_type_hints())

        # TODO(sourabhbajaj): Remove this if-else
        output_check_cls = (
            OutputCheckWrapperNewDoFn
            if isinstance(dofn, core.NewDoFn) else OutputCheckWrapperDoFn)
        dofn = output_check_cls(dofn, applied.full_label)

        runner = DoFnRunner(
            dofn,
            ptransform.args,
            ptransform.kwargs,
            self._side_inputs,
            applied.inputs[0].windowing,
            tagged_receivers=receivers,
            step_name=applied.full_label,
            state=DoFnState(self._counter_factory),
            scoped_metrics_container=self.scoped_metrics_container)
        self.runner = runner
        runner.start()
Beispiel #4
0
  def start_bundle(self):
    """Create output bundles, user-state plumbing and a started runner.

    For a stateful DoFn (per ``is_stateful_dofn``) this picks a key coder,
    builds a ``DirectUserStateContext`` and fills ``self.user_timer_map`` with
    the DoFn's timer specs keyed by ``'user/<timer name>'``.  Otherwise
    ``self.user_state_context`` stays ``None`` and the timer map stays empty.
    The runner has ``setup()`` and then ``start()`` called on it.
    """
    applied = self._applied_ptransform
    ptransform = applied.transform

    # Register one freshly created bundle for every declared output tag.
    receivers = _TaggedReceivers(self._evaluation_context)
    self._tagged_receivers = receivers
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      receivers[tag] = self._evaluation_context.create_bundle(pcoll)
      receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
      dofn = pickler.loads(pickler.dumps(ptransform.dofn))
    else:
      dofn = ptransform.dofn

    if hasattr(ptransform, 'args'):
      args = ptransform.args
    else:
      args = []
    if hasattr(ptransform, 'kwargs'):
      kwargs = ptransform.kwargs
    else:
      kwargs = {}

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      element_type = applied.inputs[0].element_type
      if element_type and element_type != Any:
        # A concrete element type hint: use the key half of its coder.
        self.key_coder = coders.registry.get_coder(element_type).key_coder()
      else:
        self.key_coder = coders.registry.get_coder(Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      _, timer_specs = get_dofn_specs(dofn)
      self.user_timer_map.update(
          ('user/%s' % spec.name, spec) for spec in timer_specs)

    runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner = runner
    runner.setup()
    runner.start()
  def start_bundle(self):
    """Create the output bundles and a started ``DoFnRunner`` for a bundle.

    The DoFn is round-tripped through the pickler only when
    ``self._perform_dofn_pickle_test`` is truthy.  Missing ``args`` /
    ``kwargs`` attributes on the transform fall back to an empty list / dict.
    The runner is given this evaluator's ``scoped_metrics_container``.
    """
    applied = self._applied_ptransform
    ptransform = applied.transform

    # Register one freshly created bundle for every declared output tag.
    receivers = _TaggedReceivers(self._evaluation_context)
    self._tagged_receivers = receivers
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      receivers[tag] = self._evaluation_context.create_bundle(pcoll)
      receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
      dofn = pickler.loads(pickler.dumps(ptransform.dofn))
    else:
      dofn = ptransform.dofn

    if hasattr(ptransform, 'args'):
      args = ptransform.args
    else:
      args = []
    if hasattr(ptransform, 'kwargs'):
      kwargs = ptransform.kwargs
    else:
      kwargs = {}

    runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner = runner
    runner.start()
Beispiel #6
0
class _ParDoEvaluator(_TransformEvaluator):
  """TransformEvaluator for ParDo transform.

  ``start_bundle`` builds the output bundles and a ``DoFnRunner``,
  ``process_element`` feeds elements to that runner, and ``finish_bundle``
  finishes the runner and packages the bundles and counters into a
  ``TransformResult``.
  """

  def __init__(self, evaluation_context, applied_ptransform,
               input_committed_bundle, side_inputs, scoped_metrics_container,
               perform_dofn_pickle_test=True):
    """Forward the common arguments to the base class and keep the pickle flag.

    Args:
      perform_dofn_pickle_test: when truthy, ``start_bundle`` round-trips the
        DoFn through the pickler before running it.
    """
    super(_ParDoEvaluator, self).__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs,
        scoped_metrics_container)
    # This is a workaround for SDF implementation. SDF implementation adds state
    # to the SDF that is not picklable.
    self._perform_dofn_pickle_test = perform_dofn_pickle_test

  def start_bundle(self):
    """Create output bundles, wrap the DoFn and start a ``DoFnRunner``."""
    applied = self._applied_ptransform
    ptransform = applied.transform

    # Register one freshly created bundle for every declared output tag.
    receivers = _TaggedReceivers(self._evaluation_context)
    self._tagged_receivers = receivers
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      receivers[tag] = self._evaluation_context.create_bundle(pcoll)
      receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
      dofn = pickler.loads(pickler.dumps(ptransform.dofn))
    else:
      dofn = ptransform.dofn

    options = self._evaluation_context.pipeline_options
    if options is not None:
      if options.view_as(TypeOptions).runtime_type_check:
        dofn = TypeCheckWrapperDoFn(dofn, ptransform.get_type_hints())

    dofn = OutputCheckWrapperDoFn(dofn, applied.full_label)
    if hasattr(ptransform, 'args'):
      args = ptransform.args
    else:
      args = []
    if hasattr(ptransform, 'kwargs'):
      kwargs = ptransform.kwargs
    else:
      kwargs = {}

    runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner = runner
    runner.start()

  def process_element(self, element):
    """Hand a single element to the DoFn runner."""
    self.runner.process(element)

  def finish_bundle(self):
    """Finish the runner and return a ``TransformResult`` for this bundle."""
    self.runner.finish()
    receivers = self._tagged_receivers
    output_bundles = receivers.values()
    bundle_counters = self._counter_factory.get_counters()
    undeclared_values = receivers.undeclared_in_memory_tag_values
    return TransformResult(
        self, output_bundles, [], bundle_counters, None, undeclared_values)
  def start_bundle(self):
    """Create output bundles, user-state plumbing and a started runner.

    For a stateful DoFn (per ``is_stateful_dofn``) this picks a key coder,
    builds a ``DirectUserStateContext`` and fills ``self.user_timer_map`` with
    the DoFn's timer specs keyed by ``'user/<timer name>'``.  Otherwise
    ``self.user_state_context`` stays ``None`` and the timer map stays empty.
    """
    applied = self._applied_ptransform
    ptransform = applied.transform

    # Register one freshly created bundle for every declared output tag.
    receivers = _TaggedReceivers(self._evaluation_context)
    self._tagged_receivers = receivers
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      receivers[tag] = self._evaluation_context.create_bundle(pcoll)
      receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
      dofn = pickler.loads(pickler.dumps(ptransform.dofn))
    else:
      dofn = ptransform.dofn

    if hasattr(ptransform, 'args'):
      args = ptransform.args
    else:
      args = []
    if hasattr(ptransform, 'kwargs'):
      kwargs = ptransform.kwargs
    else:
      kwargs = {}

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      element_type = applied.inputs[0].element_type
      if element_type and element_type != typehints.Any:
        # A concrete element type hint: use the key half of its coder.
        self.key_coder = coders.registry.get_coder(element_type).key_coder()
      else:
        self.key_coder = coders.registry.get_coder(typehints.Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      _, timer_specs = get_dofn_specs(dofn)
      self.user_timer_map.update(
          ('user/%s' % spec.name, spec) for spec in timer_specs)

    runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner = runner
    runner.start()
Beispiel #8
0
  def start_bundle(self):
    """Set up output bundles and start a ``DoFnRunner`` for the ParDo.

    The main output bundle is registered under the ``None`` tag.  When the
    parent transform is a ``core._MultiParDo``, one extra bundle is created
    for each tag in ``transform.side_output_tags``; otherwise the evaluator
    must have exactly one output.  The pickled-and-restored DoFn is optionally
    wrapped for runtime type checking and always wrapped with an output check.
    """
    applied = self._applied_ptransform
    ptransform = applied.transform
    context = self._evaluation_context

    receivers = _TaggedReceivers(context)
    self._tagged_receivers = receivers

    parent = applied.parent
    if isinstance(parent.transform,
                  core._MultiParDo):  # pylint: disable=protected-access
      outputs_tuple = parent.outputs[0]
      assert isinstance(outputs_tuple, pvalue.DoOutputsTuple)
      main_tag = outputs_tuple._main_tag  # pylint: disable=protected-access
      main_pcoll = outputs_tuple[main_tag]

      for tag in ptransform.side_output_tags:
        side_pcoll = outputs_tuple[tag]
        receivers[tag] = context.create_bundle(side_pcoll)
        receivers[tag].tag = tag
    else:
      assert len(self._outputs) == 1
      main_pcoll = list(self._outputs)[0]

    # The main output is keyed by None (main_tag is None).
    receivers[None] = context.create_bundle(main_pcoll)
    receivers[None].tag = None

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = pickler.loads(pickler.dumps(ptransform.dofn))

    options = context.pipeline_options
    if options is not None:
      if options.view_as(TypeOptions).runtime_type_check:
        dofn = TypeCheckWrapperDoFn(dofn, ptransform.get_type_hints())

    dofn = OutputCheckWrapperDoFn(dofn, applied.full_label)
    runner = DoFnRunner(
        dofn,
        ptransform.args,
        ptransform.kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner = runner
    runner.start()
Beispiel #9
0
class _ParDoEvaluator(_TransformEvaluator):
    """TransformEvaluator for ParDo transform.

    ``start_bundle`` builds the output bundles and a ``DoFnRunner``,
    ``process_element`` feeds elements to that runner, and ``finish_bundle``
    finishes the runner and packages the bundles and counters into a
    ``TransformResult``.
    """
    def start_bundle(self):
        """Create output bundles, wrap the DoFn and start a ``DoFnRunner``."""
        applied = self._applied_ptransform
        ptransform = applied.transform

        # Register one freshly created bundle for every declared output tag.
        receivers = _TaggedReceivers(self._evaluation_context)
        self._tagged_receivers = receivers
        for tag in applied.outputs:
            pcoll = pvalue.PCollection(None, tag=tag)
            pcoll.producer = applied
            receivers[tag] = self._evaluation_context.create_bundle(pcoll)
            receivers[tag].tag = tag

        self._counter_factory = counters.CounterFactory()

        # TODO(aaltay): Consider storing the serialized form as an optimization.
        dofn = pickler.loads(pickler.dumps(ptransform.dofn))

        options = self._evaluation_context.pipeline_options
        if options is not None:
            if options.view_as(TypeOptions).runtime_type_check:
                dofn = TypeCheckWrapperDoFn(dofn, ptransform.get_type_hints())

        dofn = OutputCheckWrapperDoFn(dofn, applied.full_label)
        runner = DoFnRunner(
            dofn,
            ptransform.args,
            ptransform.kwargs,
            self._side_inputs,
            applied.inputs[0].windowing,
            tagged_receivers=receivers,
            step_name=applied.full_label,
            state=DoFnState(self._counter_factory),
            scoped_metrics_container=self.scoped_metrics_container)
        self.runner = runner
        runner.start()

    def process_element(self, element):
        """Hand a single element to the DoFn runner."""
        self.runner.process(element)

    def finish_bundle(self):
        """Finish the runner and return a ``TransformResult`` for this bundle."""
        self.runner.finish()
        receivers = self._tagged_receivers
        output_bundles = receivers.values()
        bundle_counters = self._counter_factory.get_counters()
        undeclared_values = receivers.undeclared_in_memory_tag_values
        return TransformResult(
            self._applied_ptransform, output_bundles, None, bundle_counters,
            None, undeclared_values)
Beispiel #10
0
class _ParDoEvaluator(_TransformEvaluator):
  """TransformEvaluator for ParDo transform.

  ``start_bundle`` builds the output bundles and a ``DoFnRunner``,
  ``process_element`` feeds elements to that runner, and ``finish_bundle``
  finishes the runner and packages the bundles and counters into a
  ``TransformResult``.
  """
  def start_bundle(self):
    """Create output bundles, wrap the DoFn and start a ``DoFnRunner``."""
    applied = self._applied_ptransform
    ptransform = applied.transform

    # Register one freshly created bundle for every declared output tag.
    receivers = _TaggedReceivers(self._evaluation_context)
    self._tagged_receivers = receivers
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      receivers[tag] = self._evaluation_context.create_bundle(pcoll)
      receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = pickler.loads(pickler.dumps(ptransform.dofn))

    options = self._evaluation_context.pipeline_options
    if options is not None:
      if options.view_as(TypeOptions).runtime_type_check:
        dofn = TypeCheckWrapperDoFn(dofn, ptransform.get_type_hints())

    dofn = OutputCheckWrapperDoFn(dofn, applied.full_label)
    runner = DoFnRunner(
        dofn,
        ptransform.args,
        ptransform.kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner = runner
    runner.start()

  def process_element(self, element):
    """Hand a single element to the DoFn runner."""
    self.runner.process(element)

  def finish_bundle(self):
    """Finish the runner and return a ``TransformResult`` for this bundle."""
    self.runner.finish()
    receivers = self._tagged_receivers
    output_bundles = receivers.values()
    bundle_counters = self._counter_factory.get_counters()
    undeclared_values = receivers.undeclared_in_memory_tag_values
    return TransformResult(
        self._applied_ptransform, output_bundles, [], bundle_counters, None,
        undeclared_values)
Beispiel #11
0
class _ParDoEvaluator(_TransformEvaluator):
    """TransformEvaluator for ParDo transform.

    ``start_bundle`` builds the output bundles, optional user-state context
    and a ``DoFnRunner``; ``process_element`` and ``process_timer`` feed work
    to that runner; ``finish_bundle`` finishes the runner, commits user state
    when present and returns a ``TransformResult``.
    """
    def __init__(
            self,
            evaluation_context,  # type: EvaluationContext
            applied_ptransform,  # type: AppliedPTransform
            input_committed_bundle,
            side_inputs,
            perform_dofn_pickle_test=True):
        """Forward the common arguments to the base class and keep the flag.

        Args:
          perform_dofn_pickle_test: when truthy, ``start_bundle`` round-trips
            the DoFn through the pickler before running it.
        """
        super(_ParDoEvaluator,
              self).__init__(evaluation_context, applied_ptransform,
                             input_committed_bundle, side_inputs)
        # This is a workaround for SDF implementation. SDF implementation adds state
        # to the SDF that is not picklable.
        self._perform_dofn_pickle_test = perform_dofn_pickle_test

    def start_bundle(self):
        """Create output bundles, user-state plumbing and a started runner.

        For a stateful DoFn (per ``is_stateful_dofn``) this picks a key coder,
        builds a ``DirectUserStateContext`` and fills ``self.user_timer_map``
        with the DoFn's timer specs keyed by ``'user/<timer name>'``.
        Otherwise ``self.user_state_context`` stays ``None`` and the timer map
        stays empty.
        """
        transform = self._applied_ptransform.transform

        # Register one freshly created bundle for every declared output tag.
        self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
        for output_tag in self._applied_ptransform.outputs:
            output_pcollection = pvalue.PCollection(None, tag=output_tag)
            output_pcollection.producer = self._applied_ptransform
            self._tagged_receivers[output_tag] = (
                self._evaluation_context.create_bundle(output_pcollection))
            self._tagged_receivers[output_tag].tag = output_tag

        self._counter_factory = counters.CounterFactory()

        # TODO(aaltay): Consider storing the serialized form as an optimization.
        dofn = (pickler.loads(pickler.dumps(transform.dofn))
                if self._perform_dofn_pickle_test else transform.dofn)

        # Transforms lacking args/kwargs attributes run with no extra arguments.
        args = transform.args if hasattr(transform, 'args') else []
        kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

        self.user_state_context = None
        self.user_timer_map = {}
        if is_stateful_dofn(dofn):
            kv_type_hint = self._applied_ptransform.inputs[0].element_type
            if kv_type_hint and kv_type_hint != Any:
                # A concrete element type hint: use the key half of its coder.
                coder = coders.registry.get_coder(kv_type_hint)
                self.key_coder = coder.key_coder()
            else:
                self.key_coder = coders.registry.get_coder(Any)

            self.user_state_context = DirectUserStateContext(
                self._step_context, dofn, self.key_coder)
            _, all_timer_specs = get_dofn_specs(dofn)
            for timer_spec in all_timer_specs:
                self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

        self.runner = DoFnRunner(dofn,
                                 args,
                                 kwargs,
                                 self._side_inputs,
                                 self._applied_ptransform.inputs[0].windowing,
                                 tagged_receivers=self._tagged_receivers,
                                 step_name=self._applied_ptransform.full_label,
                                 state=DoFnState(self._counter_factory),
                                 user_state_context=self.user_state_context)
        self.runner.start()

    def process_timer(self, timer_firing):
        """Dispatch a fired user timer to the DoFn runner.

        A firing whose name is not in ``self.user_timer_map`` is logged as a
        warning and skipped.  Previously the code logged the warning and then
        still indexed the map, raising ``KeyError``.
        """
        if timer_firing.name not in self.user_timer_map:
            _LOGGER.warning('Unknown timer fired: %s', timer_firing)
            return
        timer_spec = self.user_timer_map[timer_firing.name]
        self.runner.process_user_timer(
            timer_spec, self.key_coder.decode(timer_firing.encoded_key),
            timer_firing.window, timer_firing.timestamp)

    def process_element(self, element):
        """Hand a single element to the DoFn runner."""
        self.runner.process(element)

    def finish_bundle(self):
        """Finish the runner, commit user state and return the result."""
        self.runner.finish()
        bundles = list(self._tagged_receivers.values())
        result_counters = self._counter_factory.get_counters()
        # Only stateful DoFns have a user state context to commit.
        if self.user_state_context:
            self.user_state_context.commit()
        return TransformResult(self, bundles, [], result_counters, None)
class _ParDoEvaluator(_TransformEvaluator):
  """TransformEvaluator for ParDo transform.

  ``start_bundle`` builds the output bundles, optional user-state context and
  a ``DoFnRunner``; ``process_element`` and ``process_timer`` feed work to
  that runner; ``finish_bundle`` finishes the runner, commits user state when
  present and returns a ``TransformResult``.
  """

  def __init__(self, evaluation_context, applied_ptransform,
               input_committed_bundle, side_inputs,
               perform_dofn_pickle_test=True):
    """Forward the common arguments to the base class and keep the pickle flag.

    Args:
      perform_dofn_pickle_test: when truthy, ``start_bundle`` round-trips the
        DoFn through the pickler before running it.
    """
    super(_ParDoEvaluator, self).__init__(
        evaluation_context, applied_ptransform, input_committed_bundle,
        side_inputs)
    # This is a workaround for SDF implementation. SDF implementation adds state
    # to the SDF that is not picklable.
    self._perform_dofn_pickle_test = perform_dofn_pickle_test

  def start_bundle(self):
    """Create output bundles, user-state plumbing and a started runner.

    For a stateful DoFn (per ``is_stateful_dofn``) this picks a key coder,
    builds a ``DirectUserStateContext`` and fills ``self.user_timer_map`` with
    the DoFn's timer specs keyed by ``'user/<timer name>'``.  Otherwise
    ``self.user_state_context`` stays ``None`` and the timer map stays empty.
    """
    transform = self._applied_ptransform.transform

    # Register one freshly created bundle for every declared output tag.
    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = (pickler.loads(pickler.dumps(transform.dofn))
            if self._perform_dofn_pickle_test else transform.dofn)

    # Transforms lacking args/kwargs attributes run with no extra arguments.
    args = transform.args if hasattr(transform, 'args') else []
    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      kv_type_hint = self._applied_ptransform.inputs[0].element_type
      if kv_type_hint and kv_type_hint != typehints.Any:
        # A concrete element type hint: use the key half of its coder.
        coder = coders.registry.get_coder(kv_type_hint)
        self.key_coder = coder.key_coder()
      else:
        self.key_coder = coders.registry.get_coder(typehints.Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      _, all_timer_specs = get_dofn_specs(dofn)
      for timer_spec in all_timer_specs:
        self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

    self.runner = DoFnRunner(
        dofn, args, kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.start()

  def process_timer(self, timer_firing):
    """Dispatch a fired user timer to the DoFn runner.

    A firing whose name is not in ``self.user_timer_map`` is logged as a
    warning and skipped.  Previously the code logged the warning and then
    still indexed the map, raising ``KeyError``.
    """
    if timer_firing.name not in self.user_timer_map:
      logging.warning('Unknown timer fired: %s', timer_firing)
      return
    timer_spec = self.user_timer_map[timer_firing.name]
    self.runner.process_user_timer(
        timer_spec, self.key_coder.decode(timer_firing.encoded_key),
        timer_firing.window, timer_firing.timestamp)

  def process_element(self, element):
    """Hand a single element to the DoFn runner."""
    self.runner.process(element)

  def finish_bundle(self):
    """Finish the runner, commit user state and return the result."""
    self.runner.finish()
    bundles = list(self._tagged_receivers.values())
    result_counters = self._counter_factory.get_counters()
    # Only stateful DoFns have a user state context to commit.
    if self.user_state_context:
      self.user_state_context.commit()
    return TransformResult(
        self, bundles, [], result_counters, None)
Beispiel #13
0
class _ParDoEvaluator(_TransformEvaluator):
    """TransformEvaluator for ParDo transform.

    ``start_bundle`` builds the output bundles and a ``DoFnRunner``,
    ``process_element`` feeds elements to that runner, and ``finish_bundle``
    finishes the runner and packages the bundles and counters into a
    ``TransformResult``.
    """
    def __init__(self, evaluation_context, applied_ptransform,
                 input_committed_bundle, side_inputs,
                 scoped_metrics_container):
        """Pass every argument straight through to the base class."""
        super(_ParDoEvaluator, self).__init__(
            evaluation_context,
            applied_ptransform,
            input_committed_bundle,
            side_inputs,
            scoped_metrics_container)

    def start_bundle(self):
        """Set up output bundles and start a ``DoFnRunner`` for the ParDo.

        The main output bundle is registered under the ``None`` tag.  When the
        parent transform is a ``core._MultiParDo``, one extra bundle is
        created for each tag in ``transform.side_output_tags``; otherwise the
        evaluator must have exactly one output.
        """
        applied = self._applied_ptransform
        ptransform = applied.transform
        context = self._evaluation_context

        receivers = _TaggedReceivers(context)
        self._tagged_receivers = receivers

        parent = applied.parent
        if isinstance(parent.transform,
                      core._MultiParDo):  # pylint: disable=protected-access
            outputs_tuple = parent.outputs[0]
            assert isinstance(outputs_tuple, pvalue.DoOutputsTuple)
            main_tag = outputs_tuple._main_tag  # pylint: disable=protected-access
            main_pcoll = outputs_tuple[main_tag]

            for tag in ptransform.side_output_tags:
                side_pcoll = outputs_tuple[tag]
                receivers[tag] = context.create_bundle(side_pcoll)
                receivers[tag].tag = tag
        else:
            assert len(self._outputs) == 1
            main_pcoll = list(self._outputs)[0]

        # The main output is keyed by None (main_tag is None).
        receivers[None] = context.create_bundle(main_pcoll)
        receivers[None].tag = None

        self._counter_factory = counters.CounterFactory()

        # TODO(aaltay): Consider storing the serialized form as an optimization.
        dofn = pickler.loads(pickler.dumps(ptransform.dofn))

        options = context.pipeline_options
        if options is not None:
            if options.view_as(TypeOptions).runtime_type_check:
                dofn = TypeCheckWrapperDoFn(dofn, ptransform.get_type_hints())

        dofn = OutputCheckWrapperDoFn(dofn, applied.full_label)
        runner = DoFnRunner(
            dofn,
            ptransform.args,
            ptransform.kwargs,
            self._side_inputs,
            applied.inputs[0].windowing,
            tagged_receivers=receivers,
            step_name=applied.full_label,
            state=DoFnState(self._counter_factory),
            scoped_metrics_container=self.scoped_metrics_container)
        self.runner = runner
        runner.start()

    def process_element(self, element):
        """Hand a single element to the DoFn runner."""
        self.runner.process(element)

    def finish_bundle(self):
        """Finish the runner and return a ``TransformResult`` for this bundle."""
        self.runner.finish()
        receivers = self._tagged_receivers
        output_bundles = receivers.values()
        bundle_counters = self._counter_factory.get_counters()
        undeclared_values = receivers.undeclared_in_memory_tag_values
        return TransformResult(
            self._applied_ptransform, output_bundles, None, None,
            bundle_counters, None, undeclared_values)