Example #1
0
    def test_good_signatures(self):
        """Check spec discovery and timer-callback wiring for valid DoFns.

        For a locally defined stateful DoFn and for TestStatefulDoFn, asserts
        that get_dofn_specs() returns the expected (state specs, timer specs)
        pair and that signature.timer_methods maps each timer spec to the
        matching bound callback of the DoFn instance.
        """
        class BasicStatefulDoFn(DoFn):
            # One bag state and one watermark timer, both used as parameter
            # defaults below.
            BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
            EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

            def process(
                    self,
                    element,
                    buffer=DoFn.StateParam(BUFFER_STATE),
                    timer1=DoFn.TimerParam(EXPIRY_TIMER)):
                yield element

            @on_timer(EXPIRY_TIMER)
            def expiry_callback(
                    self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
                yield element

        # Case 1: the minimal DoFn declared above.
        basic_dofn = BasicStatefulDoFn()
        basic_signature = self._validate_dofn(basic_dofn)
        self.assertEqual(
            ({BasicStatefulDoFn.BUFFER_STATE},
             {BasicStatefulDoFn.EXPIRY_TIMER}),
            get_dofn_specs(basic_dofn))
        basic_timer_method = basic_signature.timer_methods[
            BasicStatefulDoFn.EXPIRY_TIMER]
        self.assertEqual(
            basic_dofn.expiry_callback, basic_timer_method.method_value)

        # Case 2: TestStatefulDoFn, with two states and three timers.
        shared_dofn = TestStatefulDoFn()
        shared_signature = self._validate_dofn(shared_dofn)
        expected_state_specs = {
            TestStatefulDoFn.BUFFER_STATE_1,
            TestStatefulDoFn.BUFFER_STATE_2,
        }
        expected_timer_specs = {
            TestStatefulDoFn.EXPIRY_TIMER_1,
            TestStatefulDoFn.EXPIRY_TIMER_2,
            TestStatefulDoFn.EXPIRY_TIMER_3,
        }
        self.assertEqual(
            (expected_state_specs, expected_timer_specs),
            get_dofn_specs(shared_dofn))

        expected_callbacks = (
            (shared_dofn.on_expiry_1, TestStatefulDoFn.EXPIRY_TIMER_1),
            (shared_dofn.on_expiry_2, TestStatefulDoFn.EXPIRY_TIMER_2),
            (shared_dofn.on_expiry_3, TestStatefulDoFn.EXPIRY_TIMER_3),
        )
        for callback, timer_spec in expected_callbacks:
            self.assertEqual(
                callback,
                shared_signature.timer_methods[timer_spec].method_value)
Example #2
0
 def visit_transform(self, applied_ptransform):
     """Clear supported_by_fnapi_runner if the visited transform is unsupported.

     The flag is only ever set to False here; it is never reset to True.

     Args:
       applied_ptransform: the applied transform whose ``transform``
         attribute is inspected.
     """
     candidate = applied_ptransform.transform

     # The FnApiRunner supports neither streaming execution (TestStream)
     # nor the use of _NativeWrites.
     if isinstance(candidate, (TestStream, _NativeWrite)):
         self.supported_by_fnapi_runner = False

     # The FnApiRunner does not support reads from NativeSources.
     reads_native_source = (
         isinstance(candidate, beam.io.Read)
         and isinstance(candidate.source, NativeSource))
     if reads_native_source:
         self.supported_by_fnapi_runner = False

     # The remaining checks only apply to ParDo transforms.
     if not isinstance(candidate, beam.ParDo):
         return
     fn = candidate.dofn

     # The FnApiRunner does not support execution of CombineFns with
     # deferred side inputs.
     if isinstance(fn, CombineValuesDoFn):
         positional, keyword = candidate.raw_side_inputs
         side_inputs = itertools.chain(positional, keyword.values())
         if any(isinstance(value, ArgumentPlaceholder)
                for value in side_inputs):
             self.supported_by_fnapi_runner = False

     # A stateful DoFn with any timer in the REAL_TIME domain also clears
     # the flag.
     if userstate.is_stateful_dofn(fn):
         timer_specs = userstate.get_dofn_specs(fn)[1]
         if any(spec.time_domain == TimeDomain.REAL_TIME
                for spec in timer_specs):
             self.supported_by_fnapi_runner = False
Example #3
0
    def __init__(self, do_fn):
        """Wrap the methods of ``do_fn`` and of its restriction provider.

        Builds MethodWrapper objects for the process/bundle methods, for the
        restriction-provider methods (None when there is no provider), and
        for every timer callback of a stateful DoFn.

        Args:
          do_fn: the core.DoFn instance to inspect.
        """
        # We add a property here for all methods defined by Beam DoFn features.

        assert isinstance(do_fn, core.DoFn)
        self.do_fn = do_fn

        self.process_method = MethodWrapper(do_fn, 'process')
        self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
        self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')

        provider = self.get_restriction_provider()

        def wrap_provider_method(method_name):
            # Without a restriction provider there is nothing to wrap.
            if provider:
                return MethodWrapper(provider, method_name)
            return None

        self.initial_restriction_method = wrap_provider_method(
            'initial_restriction')
        self.restriction_coder_method = wrap_provider_method(
            'restriction_coder')
        self.create_tracker_method = wrap_provider_method('create_tracker')
        self.split_method = wrap_provider_method('split')

        self._validate()

        # Handle stateful DoFns.
        self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
        self.timer_methods = {}
        if self._is_stateful_dofn:
            # Populate timer firing methods, keyed by TimerSpec.
            timer_specs = userstate.get_dofn_specs(do_fn)[1]
            for spec in timer_specs:
                callback_name = spec._attached_callback.__name__
                self.timer_methods[spec] = MethodWrapper(do_fn, callback_name)
Example #4
0
  def __init__(self, do_fn):
    """Wrap the methods of ``do_fn`` and of its restriction provider.

    Builds MethodWrapper objects for the process/bundle methods, for the
    restriction-provider methods (None when there is no provider), and for
    every timer callback of a stateful DoFn.

    Args:
      do_fn: the core.DoFn instance to inspect.
    """
    # We add a property here for all methods defined by Beam DoFn features.

    assert isinstance(do_fn, core.DoFn)
    self.do_fn = do_fn

    self.process_method = MethodWrapper(do_fn, 'process')
    self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
    self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')

    provider = self.get_restriction_provider()

    def wrap_provider_method(method_name):
      # Without a restriction provider there is nothing to wrap.
      if provider:
        return MethodWrapper(provider, method_name)
      return None

    self.initial_restriction_method = wrap_provider_method(
        'initial_restriction')
    self.restriction_coder_method = wrap_provider_method('restriction_coder')
    self.create_tracker_method = wrap_provider_method('create_tracker')
    self.split_method = wrap_provider_method('split')

    self._validate()

    # Handle stateful DoFns.
    self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
    self.timer_methods = {}
    if self._is_stateful_dofn:
      # Populate timer firing methods, keyed by TimerSpec.
      timer_specs = userstate.get_dofn_specs(do_fn)[1]
      for spec in timer_specs:
        callback_name = spec._attached_callback.__name__
        self.timer_methods[spec] = MethodWrapper(do_fn, callback_name)
Example #5
0
    def setup(self):
        # type: () -> None
        """Unpickle the DoFn, wire its output receivers and build its runner.

        Everything runs inside ``self.scoped_start_state``. On completion
        ``tagged_receivers``, ``side_input_maps``, ``dofn_runner`` and
        ``dofn_receiver`` are set; ``timer_specs`` is set only when a user
        state context is present.

        Raises:
          ValueError: if an output tag is neither PropertyNames.OUT nor
            prefixed with PropertyNames.OUT + '_'.
        """
        with self.scoped_start_state:
            super(DoOperation, self).setup()

            # See fn_data in dataflow_runner.py
            unpickled = pickler.loads(self.spec.serialized_fn)
            fn, args, kwargs, tags_and_types, window_fn = unpickled

            dofn_state = common.DoFnState(self.counter_factory)
            dofn_state.step_name = self.name_context.logging_name()

            # Tag to output index map used to dispatch the side output values
            # emitted by the DoFn function to the appropriate receivers. The
            # main output is tagged with None and is associated with its
            # corresponding index.
            self.tagged_receivers = _TaggedReceivers(
                self.counter_factory, self.name_context.logging_name())

            prefix = PropertyNames.OUT + '_'
            for position, tag in enumerate(self.spec.output_tags):
                if tag == PropertyNames.OUT:
                    receiver_key = None  # type: Optional[str]
                elif tag.startswith(prefix):
                    # Strip the prefix to recover the user-visible tag.
                    receiver_key = tag[len(prefix):]
                else:
                    raise ValueError(
                        'Unexpected output name for operation: %s' % tag)
                self.tagged_receivers[receiver_key] = self.receivers[position]

            if self.user_state_context:
                self.user_state_context.update_timer_receivers(
                    self.tagged_receivers)
                _, declared_timers = userstate.get_dofn_specs(fn)
                self.timer_specs = {
                    timer.name: timer
                    for timer in declared_timers
                }

            # Side inputs are only read when they were not provided already.
            if self.side_input_maps is None:
                self.side_input_maps = (
                    list(self._read_side_inputs(tags_and_types))
                    if tags_and_types else [])

            self.dofn_runner = common.DoFnRunner(
                fn,
                args,
                kwargs,
                self.side_input_maps,
                window_fn,
                tagged_receivers=self.tagged_receivers,
                step_name=self.name_context.logging_name(),
                state=dofn_state,
                user_state_context=self.user_state_context,
                operation_name=self.name_context.metrics_name())
            self.dofn_runner.setup()

            # Use the runner directly when it already is a Receiver,
            # otherwise adapt it.
            if isinstance(self.dofn_runner, Receiver):
                self.dofn_receiver = self.dofn_runner
            else:
                self.dofn_receiver = DoFnRunnerReceiver(self.dofn_runner)
  def test_good_signatures(self):
    """Check spec discovery and timer-callback wiring for valid DoFns.

    For a locally defined stateful DoFn and for TestStatefulDoFn, asserts
    that get_dofn_specs() returns the expected (state specs, timer specs)
    pair and that signature.timer_methods maps each timer spec to the
    matching bound callback of the DoFn instance.
    """
    class BasicStatefulDoFn(DoFn):
      # One bag state and one watermark timer, both used as parameter
      # defaults below.
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          buffer=DoFn.StateParam(BUFFER_STATE),
          timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        yield element

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(
          self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
        yield element

    # Case 1: the minimal DoFn declared above.
    basic_dofn = BasicStatefulDoFn()
    basic_signature = self._validate_dofn(basic_dofn)
    self.assertEqual(
        ({BasicStatefulDoFn.BUFFER_STATE}, {BasicStatefulDoFn.EXPIRY_TIMER}),
        get_dofn_specs(basic_dofn))
    basic_timer_method = basic_signature.timer_methods[
        BasicStatefulDoFn.EXPIRY_TIMER]
    self.assertEqual(
        basic_dofn.expiry_callback, basic_timer_method.method_value)

    # Case 2: TestStatefulDoFn, with two states and three timers.
    shared_dofn = TestStatefulDoFn()
    shared_signature = self._validate_dofn(shared_dofn)
    expected_state_specs = {
        TestStatefulDoFn.BUFFER_STATE_1,
        TestStatefulDoFn.BUFFER_STATE_2,
    }
    expected_timer_specs = {
        TestStatefulDoFn.EXPIRY_TIMER_1,
        TestStatefulDoFn.EXPIRY_TIMER_2,
        TestStatefulDoFn.EXPIRY_TIMER_3,
    }
    self.assertEqual(
        (expected_state_specs, expected_timer_specs),
        get_dofn_specs(shared_dofn))

    expected_callbacks = (
        (shared_dofn.on_expiry_1, TestStatefulDoFn.EXPIRY_TIMER_1),
        (shared_dofn.on_expiry_2, TestStatefulDoFn.EXPIRY_TIMER_2),
        (shared_dofn.on_expiry_3, TestStatefulDoFn.EXPIRY_TIMER_3),
    )
    for callback, timer_spec in expected_callbacks:
      self.assertEqual(
          callback, shared_signature.timer_methods[timer_spec].method_value)
Example #7
0
    def setup(self):
        # type: () -> None
        """Unpickle the DoFn, wire its output receivers and build its runner.

        Everything runs inside ``self.scoped_start_state``. On completion
        ``tagged_receivers``, ``side_input_maps`` and ``dofn_runner`` are
        set; ``timer_specs`` is set only when a user state context is
        present.
        """
        with self.scoped_start_state:
            super(DoOperation, self).setup()

            # See fn_data in dataflow_runner.py
            unpickled = pickler.loads(self.spec.serialized_fn)
            fn, args, kwargs, tags_and_types, window_fn = unpickled

            dofn_state = common.DoFnState(self.counter_factory)
            dofn_state.step_name = self.name_context.logging_name()

            # Tag to output index map used to dispatch the output values
            # emitted by the DoFn function to the appropriate receivers. The
            # main output is either the only output or the output tagged with
            # 'None' and is associated with its corresponding index.
            self.tagged_receivers = _TaggedReceivers(
                self.counter_factory, self.name_context.logging_name())

            if len(self.spec.output_tags) != 1:
                for position, tag in enumerate(self.spec.output_tags):
                    self.tagged_receivers[tag] = self.receivers[position]
                    if tag == 'None':
                        # The string tag 'None' is also reachable under the
                        # key None.
                        self.tagged_receivers[None] = self.receivers[position]
            else:
                # A single output is registered under None and its own tag.
                self.tagged_receivers[None] = self.receivers[0]
                only_tag = self.spec.output_tags[0]
                self.tagged_receivers[only_tag] = self.receivers[0]

            if self.user_state_context:
                _, declared_timers = userstate.get_dofn_specs(fn)
                self.timer_specs = {
                    timer.name: timer
                    for timer in declared_timers
                }

            # Side inputs are only read when they were not provided already.
            if self.side_input_maps is None:
                self.side_input_maps = (
                    list(self._read_side_inputs(tags_and_types))
                    if tags_and_types else [])

            self.dofn_runner = common.DoFnRunner(
                fn,
                args,
                kwargs,
                self.side_input_maps,
                window_fn,
                tagged_receivers=self.tagged_receivers,
                step_name=self.name_context.logging_name(),
                state=dofn_state,
                user_state_context=self.user_state_context,
                operation_name=self.name_context.metrics_name())
            self.dofn_runner.setup()
Example #8
0
  def start(self):
    """Unpickle the DoFn, wire its output receivers and start its runner.

    Everything runs inside ``self.scoped_start_state``. On completion
    ``tagged_receivers``, ``side_input_maps``, ``dofn_runner`` and
    ``dofn_receiver`` are set; ``timer_specs`` is set only when a user state
    context is present.

    Raises:
      ValueError: if an output tag is neither PropertyNames.OUT nor prefixed
        with PropertyNames.OUT + '_'.
    """
    with self.scoped_start_state:
      super(DoOperation, self).start()

      # See fn_data in dataflow_runner.py
      unpickled = pickler.loads(self.spec.serialized_fn)
      fn, args, kwargs, tags_and_types, window_fn = unpickled

      dofn_state = common.DoFnState(self.counter_factory)
      dofn_state.step_name = self.name_context.logging_name()

      # Tag to output index map used to dispatch the side output values emitted
      # by the DoFn function to the appropriate receivers. The main output is
      # tagged with None and is associated with its corresponding index.
      self.tagged_receivers = _TaggedReceivers(
          self.counter_factory, self.name_context.logging_name())

      prefix = PropertyNames.OUT + '_'
      for position, tag in enumerate(self.spec.output_tags):
        if tag == PropertyNames.OUT:
          receiver_key = None
        elif tag.startswith(prefix):
          # Strip the prefix to recover the user-visible tag.
          receiver_key = tag[len(prefix):]
        else:
          raise ValueError('Unexpected output name for operation: %s' % tag)
        self.tagged_receivers[receiver_key] = self.receivers[position]

      if self.user_state_context:
        self.user_state_context.update_timer_receivers(self.tagged_receivers)
        _, declared_timers = userstate.get_dofn_specs(fn)
        self.timer_specs = {timer.name: timer for timer in declared_timers}

      # Side inputs are only read when they were not provided already.
      if self.side_input_maps is None:
        self.side_input_maps = (
            list(self._read_side_inputs(tags_and_types))
            if tags_and_types else [])

      self.dofn_runner = common.DoFnRunner(
          fn,
          args,
          kwargs,
          self.side_input_maps,
          window_fn,
          tagged_receivers=self.tagged_receivers,
          step_name=self.name_context.logging_name(),
          state=dofn_state,
          user_state_context=self.user_state_context,
          operation_name=self.name_context.metrics_name())

      # Use the runner directly when it already is a Receiver, otherwise
      # adapt it.
      if isinstance(self.dofn_runner, Receiver):
        self.dofn_receiver = self.dofn_runner
      else:
        self.dofn_receiver = DoFnRunnerReceiver(self.dofn_runner)

      self.dofn_runner.start()
Example #9
0
  def start_bundle(self):
    """Create output bundles, optional user state, and a started DoFnRunner.

    One bundle is created per output tag of the applied transform. For a
    stateful DoFn, a key coder, a DirectUserStateContext and a map from
    'user/<timer name>' to timer spec are also prepared. The runner is then
    set up and started.
    """
    applied = self._applied_ptransform
    transform = applied.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      bundle = self._evaluation_context.create_bundle(pcoll)
      self._tagged_receivers[tag] = bundle
      bundle.tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
      # Round-trip the DoFn through the pickler.
      dofn = pickler.loads(pickler.dumps(transform.dofn))
    else:
      dofn = transform.dofn

    args = getattr(transform, 'args', [])
    kwargs = getattr(transform, 'kwargs', {})

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      element_type = applied.inputs[0].element_type
      if element_type and element_type != Any:
        self.key_coder = coders.registry.get_coder(element_type).key_coder()
      else:
        # No usable type hint: fall back to the coder registered for Any.
        self.key_coder = coders.registry.get_coder(Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      timer_specs = get_dofn_specs(dofn)[1]
      self.user_timer_map.update(
          ('user/%s' % spec.name, spec) for spec in timer_specs)

    self.runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.setup()
    self.runner.start()
  def start_bundle(self):
    """Create output bundles, optional user state, and a started DoFnRunner.

    One bundle is created per output tag of the applied transform. For a
    stateful DoFn, a key coder, a DirectUserStateContext and a map from
    'user/<timer name>' to timer spec are also prepared. The runner is then
    started.
    """
    applied = self._applied_ptransform
    transform = applied.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for tag in applied.outputs:
      pcoll = pvalue.PCollection(None, tag=tag)
      pcoll.producer = applied
      bundle = self._evaluation_context.create_bundle(pcoll)
      self._tagged_receivers[tag] = bundle
      bundle.tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
      # Round-trip the DoFn through the pickler.
      dofn = pickler.loads(pickler.dumps(transform.dofn))
    else:
      dofn = transform.dofn

    args = getattr(transform, 'args', [])
    kwargs = getattr(transform, 'kwargs', {})

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      element_type = applied.inputs[0].element_type
      if element_type and element_type != typehints.Any:
        self.key_coder = coders.registry.get_coder(element_type).key_coder()
      else:
        # No usable type hint: fall back to the coder registered for Any.
        self.key_coder = coders.registry.get_coder(typehints.Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      timer_specs = get_dofn_specs(dofn)[1]
      self.user_timer_map.update(
          ('user/%s' % spec.name, spec) for spec in timer_specs)

    self.runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.start()
  def __init__(self, step_context, dofn, key_coder):
    """Record the DoFn's state/timer specs and build a tag per state spec.

    Args:
      step_context: stored as ``self.step_context``.
      dofn: the DoFn whose specs are read via userstate.get_dofn_specs().
      key_coder: stored as ``self.key_coder``.

    Raises:
      ValueError: if a state spec is neither a BagStateSpec nor a
        CombiningValueStateSpec.
    """
    self.step_context = step_context
    self.dofn = dofn
    self.key_coder = key_coder

    specs = userstate.get_dofn_specs(dofn)
    self.all_state_specs, self.all_timer_specs = specs

    # Both supported spec kinds map to a _ListStateTag.
    list_tagged_kinds = (
        userstate.BagStateSpec, userstate.CombiningValueStateSpec)
    self.state_tags = {}
    for spec in self.all_state_specs:
      tag_name = 'user/%s' % spec.name
      if not isinstance(spec, list_tagged_kinds):
        raise ValueError('Invalid state spec: %s' % spec)
      self.state_tags[spec] = _ListStateTag(tag_name)

    self.cached_states = {}
    self.cached_timers = {}
Example #12
0
    def __init__(self, step_context, dofn, key_coder):
        """Record the DoFn's state/timer specs and build a tag per state spec.

        Args:
          step_context: stored as ``self.step_context``.
          dofn: the DoFn whose specs are read via
            userstate.get_dofn_specs().
          key_coder: stored as ``self.key_coder``.

        Raises:
          ValueError: if a state spec is neither a BagStateSpec nor a
            CombiningValueStateSpec.
        """
        self.step_context = step_context
        self.dofn = dofn
        self.key_coder = key_coder

        specs = userstate.get_dofn_specs(dofn)
        self.all_state_specs, self.all_timer_specs = specs

        # Both supported spec kinds map to a _ListStateTag.
        list_tagged_kinds = (
            userstate.BagStateSpec, userstate.CombiningValueStateSpec)
        self.state_tags = {}
        for spec in self.all_state_specs:
            tag_name = 'user/%s' % spec.name
            if not isinstance(spec, list_tagged_kinds):
                raise ValueError('Invalid state spec: %s' % spec)
            self.state_tags[spec] = _ListStateTag(tag_name)

        self.cached_states = {}
        self.cached_timers = {}
Example #13
0
 def has_timers(self):
     # type: () -> bool
     """Return True if ``self.do_fn`` declares at least one timer spec."""
     # get_dofn_specs() yields (state specs, timer specs); only the second
     # element matters here.
     return bool(userstate.get_dofn_specs(self.do_fn)[1])
Example #14
0
 def has_timers(self):
   """Return True if ``self.do_fn`` declares at least one timer spec."""
   # get_dofn_specs() yields (state specs, timer specs); only the second
   # element matters here.
   return bool(userstate.get_dofn_specs(self.do_fn)[1])