def test_good_signatures(self):
    # A minimal stateful DoFn: one bag state, one watermark timer and a
    # callback attached to that timer.
    class BasicStatefulDoFn(DoFn):
        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
        EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

        def process(
                self,
                element,
                buffer=DoFn.StateParam(BUFFER_STATE),
                timer1=DoFn.TimerParam(EXPIRY_TIMER)):
            yield element

        @on_timer(EXPIRY_TIMER)
        def expiry_callback(
                self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
            yield element

    # Check get_dofn_specs() and the timer callbacks recorded in the
    # DoFnSignature for the single-timer DoFn.
    basic_dofn = BasicStatefulDoFn()
    basic_signature = self._validate_dofn(basic_dofn)
    basic_expected = (
        {BasicStatefulDoFn.BUFFER_STATE},
        {BasicStatefulDoFn.EXPIRY_TIMER},
    )
    self.assertEqual(basic_expected, get_dofn_specs(basic_dofn))
    basic_wrapper = basic_signature.timer_methods[
        BasicStatefulDoFn.EXPIRY_TIMER]
    self.assertEqual(basic_dofn.expiry_callback, basic_wrapper.method_value)

    # Same checks for the DoFn with two states and three timers.
    multi_dofn = TestStatefulDoFn()
    multi_signature = self._validate_dofn(multi_dofn)
    multi_expected = (
        {
            TestStatefulDoFn.BUFFER_STATE_1,
            TestStatefulDoFn.BUFFER_STATE_2,
        },
        {
            TestStatefulDoFn.EXPIRY_TIMER_1,
            TestStatefulDoFn.EXPIRY_TIMER_2,
            TestStatefulDoFn.EXPIRY_TIMER_3,
        },
    )
    self.assertEqual(multi_expected, get_dofn_specs(multi_dofn))

    callbacks_by_timer = (
        (TestStatefulDoFn.EXPIRY_TIMER_1, multi_dofn.on_expiry_1),
        (TestStatefulDoFn.EXPIRY_TIMER_2, multi_dofn.on_expiry_2),
        (TestStatefulDoFn.EXPIRY_TIMER_3, multi_dofn.on_expiry_3),
    )
    for timer_spec, callback in callbacks_by_timer:
        self.assertEqual(
            callback,
            multi_signature.timer_methods[timer_spec].method_value)
def visit_transform(self, applied_ptransform):
    """Clear ``supported_by_fnapi_runner`` for unsupported transforms.

    Args:
        applied_ptransform: the applied transform being visited; only its
            ``transform`` attribute is inspected.
    """
    xform = applied_ptransform.transform

    # The FnApiRunner does not support streaming execution.
    if isinstance(xform, TestStream):
        self.supported_by_fnapi_runner = False

    # The FnApiRunner does not support reads from NativeSources.
    if isinstance(xform, beam.io.Read):
        if isinstance(xform.source, NativeSource):
            self.supported_by_fnapi_runner = False

    # The FnApiRunner does not support the use of _NativeWrites.
    if isinstance(xform, _NativeWrite):
        self.supported_by_fnapi_runner = False

    # The remaining checks only concern ParDo transforms.
    if not isinstance(xform, beam.ParDo):
        return

    dofn = xform.dofn

    # The FnApiRunner does not support execution of CombineFns with
    # deferred side inputs.
    if isinstance(dofn, CombineValuesDoFn):
        positional, keyword = xform.raw_side_inputs
        for side_input in itertools.chain(positional, keyword.values()):
            if isinstance(side_input, ArgumentPlaceholder):
                self.supported_by_fnapi_runner = False
                break

    # A stateful DoFn with any REAL_TIME timer is flagged as unsupported.
    if userstate.is_stateful_dofn(dofn):
        declared_timers = userstate.get_dofn_specs(dofn)[1]
        if any(spec.time_domain == TimeDomain.REAL_TIME
               for spec in declared_timers):
            self.supported_by_fnapi_runner = False
def __init__(self, do_fn):
    """Wrap the methods of ``do_fn`` that Beam DoFn features rely on.

    Args:
        do_fn: the DoFn to inspect; asserted to be a ``core.DoFn``.
    """
    assert isinstance(do_fn, core.DoFn)
    self.do_fn = do_fn

    # One attribute per method defined by a Beam DoFn feature.
    self.process_method = MethodWrapper(do_fn, 'process')
    self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
    self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')

    provider = self.get_restriction_provider()

    def wrap_provider_method(method_name):
        # Restriction methods are wrapped only when a provider is present;
        # otherwise the attribute is left as None.
        if provider:
            return MethodWrapper(provider, method_name)
        return None

    self.initial_restriction_method = wrap_provider_method(
        'initial_restriction')
    self.restriction_coder_method = wrap_provider_method('restriction_coder')
    self.create_tracker_method = wrap_provider_method('create_tracker')
    self.split_method = wrap_provider_method('split')

    self._validate()

    # Stateful DoFns additionally expose their timer firing methods,
    # keyed by TimerSpec.
    self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
    self.timer_methods = {}
    if self._is_stateful_dofn:
        declared_timers = userstate.get_dofn_specs(do_fn)[1]
        for spec in declared_timers:
            callback = spec._attached_callback
            self.timer_methods[spec] = MethodWrapper(
                do_fn, callback.__name__)
def __init__(self, do_fn):
    """Build the signature of ``do_fn`` from its wrapped methods.

    Args:
        do_fn: the DoFn to inspect; asserted to be a ``core.DoFn``.
    """
    assert isinstance(do_fn, core.DoFn)
    self.do_fn = do_fn

    # A property is added for every method defined by Beam DoFn features.
    self.process_method = MethodWrapper(do_fn, 'process')
    self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
    self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')

    restriction_provider = self.get_restriction_provider()

    # (attribute on self, method on the provider), in assignment order.
    # Each attribute is None when there is no restriction provider.
    provider_methods = (
        ('initial_restriction_method', 'initial_restriction'),
        ('restriction_coder_method', 'restriction_coder'),
        ('create_tracker_method', 'create_tracker'),
        ('split_method', 'split'),
    )
    for attribute, method_name in provider_methods:
        if restriction_provider:
            wrapper = MethodWrapper(restriction_provider, method_name)
        else:
            wrapper = None
        setattr(self, attribute, wrapper)

    self._validate()

    # Handle stateful DoFns: record the timer firing methods, keyed by
    # TimerSpec.
    self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
    self.timer_methods = {}
    if self._is_stateful_dofn:
        _, timer_specs = userstate.get_dofn_specs(do_fn)
        for spec in timer_specs:
            callback_name = spec._attached_callback.__name__
            self.timer_methods[spec] = MethodWrapper(do_fn, callback_name)
def setup(self):
    # type: () -> None
    """Deserialize the DoFn and build the runner and receiver for it."""
    with self.scoped_start_state:
        super(DoOperation, self).setup()

        # See fn_data in dataflow_runner.py
        unpickled = pickler.loads(self.spec.serialized_fn)
        fn, args, kwargs, tags_and_types, window_fn = unpickled

        state = common.DoFnState(self.counter_factory)
        state.step_name = self.name_context.logging_name()

        # Map each output tag to its receiver so values emitted by the DoFn
        # reach the right place. The main output is stored under None; side
        # outputs are stored under their tag with the prefix stripped.
        self.tagged_receivers = _TaggedReceivers(
            self.counter_factory, self.name_context.logging_name())

        side_prefix = PropertyNames.OUT + '_'
        for position, output_name in enumerate(self.spec.output_tags):
            if output_name == PropertyNames.OUT:
                receiver_key = None  # type: Optional[str]
            elif output_name.startswith(side_prefix):
                receiver_key = output_name[len(side_prefix):]
            else:
                raise ValueError(
                    'Unexpected output name for operation: %s' % output_name)
            self.tagged_receivers[receiver_key] = self.receivers[position]

        if self.user_state_context:
            self.user_state_context.update_timer_receivers(
                self.tagged_receivers)
            # Index the DoFn's timer specs by name.
            _, declared_timers = userstate.get_dofn_specs(fn)
            self.timer_specs = {
                timer.name: timer for timer in declared_timers
            }

        # Side inputs are read only if they were not already provided.
        if self.side_input_maps is None:
            self.side_input_maps = (
                list(self._read_side_inputs(tags_and_types))
                if tags_and_types else [])

        self.dofn_runner = common.DoFnRunner(
            fn,
            args,
            kwargs,
            self.side_input_maps,
            window_fn,
            tagged_receivers=self.tagged_receivers,
            step_name=self.name_context.logging_name(),
            state=state,
            user_state_context=self.user_state_context,
            operation_name=self.name_context.metrics_name())
        self.dofn_runner.setup()

        # Use the runner directly when it is already a Receiver; otherwise
        # wrap it in an adapter.
        if isinstance(self.dofn_runner, Receiver):
            self.dofn_receiver = self.dofn_runner
        else:
            self.dofn_receiver = DoFnRunnerReceiver(self.dofn_runner)
def test_good_signatures(self):
    # DoFn under test: a single bag state plus a single watermark timer
    # whose callback is registered through @on_timer.
    class BasicStatefulDoFn(DoFn):
        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
        EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

        def process(self, element,
                    buffer=DoFn.StateParam(BUFFER_STATE),
                    timer1=DoFn.TimerParam(EXPIRY_TIMER)):
            yield element

        @on_timer(EXPIRY_TIMER)
        def expiry_callback(self, element,
                            timer=DoFn.TimerParam(EXPIRY_TIMER)):
            yield element

    def check_timer_method(expected_method, dofn_signature, timer_spec):
        # The signature must map the timer spec to the DoFn's bound method.
        self.assertEqual(
            expected_method,
            dofn_signature.timer_methods[timer_spec].method_value)

    # Validate get_dofn_specs() and timer callbacks in DoFnSignature.
    dofn = BasicStatefulDoFn()
    dofn_signature = self._validate_dofn(dofn)
    state_specs = {BasicStatefulDoFn.BUFFER_STATE}
    timer_specs = {BasicStatefulDoFn.EXPIRY_TIMER}
    self.assertEqual((state_specs, timer_specs), get_dofn_specs(dofn))
    check_timer_method(
        dofn.expiry_callback, dofn_signature, BasicStatefulDoFn.EXPIRY_TIMER)

    # Repeat for the DoFn declaring two states and three timers.
    dofn = TestStatefulDoFn()
    dofn_signature = self._validate_dofn(dofn)
    state_specs = {
        TestStatefulDoFn.BUFFER_STATE_1,
        TestStatefulDoFn.BUFFER_STATE_2,
    }
    timer_specs = {
        TestStatefulDoFn.EXPIRY_TIMER_1,
        TestStatefulDoFn.EXPIRY_TIMER_2,
        TestStatefulDoFn.EXPIRY_TIMER_3,
    }
    self.assertEqual((state_specs, timer_specs), get_dofn_specs(dofn))
    check_timer_method(
        dofn.on_expiry_1, dofn_signature, TestStatefulDoFn.EXPIRY_TIMER_1)
    check_timer_method(
        dofn.on_expiry_2, dofn_signature, TestStatefulDoFn.EXPIRY_TIMER_2)
    check_timer_method(
        dofn.on_expiry_3, dofn_signature, TestStatefulDoFn.EXPIRY_TIMER_3)
def setup(self):
    # type: () -> None
    """Deserialize the DoFn, map output tags to receivers, set up a runner."""
    with self.scoped_start_state:
        super(DoOperation, self).setup()

        # See fn_data in dataflow_runner.py
        unpickled = pickler.loads(self.spec.serialized_fn)
        fn, args, kwargs, tags_and_types, window_fn = unpickled

        state = common.DoFnState(self.counter_factory)
        state.step_name = self.name_context.logging_name()

        # Map output tags to receivers so values emitted by the DoFn are
        # dispatched correctly. The main output is either the only output or
        # the one tagged 'None'; it is additionally registered under the
        # key None.
        self.tagged_receivers = _TaggedReceivers(
            self.counter_factory, self.name_context.logging_name())

        output_tags = self.spec.output_tags
        if len(output_tags) == 1:
            self.tagged_receivers[None] = self.receivers[0]
            self.tagged_receivers[output_tags[0]] = self.receivers[0]
        else:
            for position, output_tag in enumerate(output_tags):
                self.tagged_receivers[output_tag] = self.receivers[position]
                if output_tag == 'None':
                    self.tagged_receivers[None] = self.receivers[position]

        if self.user_state_context:
            # Index the DoFn's timer specs by name.
            _, declared_timers = userstate.get_dofn_specs(fn)
            self.timer_specs = {
                timer.name: timer for timer in declared_timers
            }

        # Side inputs are read only if they were not already provided.
        if self.side_input_maps is None:
            self.side_input_maps = (
                list(self._read_side_inputs(tags_and_types))
                if tags_and_types else [])

        self.dofn_runner = common.DoFnRunner(
            fn,
            args,
            kwargs,
            self.side_input_maps,
            window_fn,
            tagged_receivers=self.tagged_receivers,
            step_name=self.name_context.logging_name(),
            state=state,
            user_state_context=self.user_state_context,
            operation_name=self.name_context.metrics_name())
        self.dofn_runner.setup()
def start(self):
    """Deserialize the DoFn, build its runner and receiver, then start it."""
    with self.scoped_start_state:
        super(DoOperation, self).start()

        # See fn_data in dataflow_runner.py
        unpickled = pickler.loads(self.spec.serialized_fn)
        fn, args, kwargs, tags_and_types, window_fn = unpickled

        state = common.DoFnState(self.counter_factory)
        state.step_name = self.name_context.logging_name()

        # Map each output tag to its receiver so values emitted by the DoFn
        # reach the right place. The main output is stored under None; side
        # outputs are stored under their tag with the prefix stripped.
        self.tagged_receivers = _TaggedReceivers(
            self.counter_factory, self.name_context.logging_name())

        side_prefix = PropertyNames.OUT + '_'
        for position, output_name in enumerate(self.spec.output_tags):
            if output_name == PropertyNames.OUT:
                receiver_key = None
            elif output_name.startswith(side_prefix):
                receiver_key = output_name[len(side_prefix):]
            else:
                raise ValueError(
                    'Unexpected output name for operation: %s' % output_name)
            self.tagged_receivers[receiver_key] = self.receivers[position]

        if self.user_state_context:
            self.user_state_context.update_timer_receivers(
                self.tagged_receivers)
            # Index the DoFn's timer specs by name.
            _, declared_timers = userstate.get_dofn_specs(fn)
            self.timer_specs = {
                timer.name: timer for timer in declared_timers
            }

        # Side inputs are read only if they were not already provided.
        if self.side_input_maps is None:
            self.side_input_maps = (
                list(self._read_side_inputs(tags_and_types))
                if tags_and_types else [])

        self.dofn_runner = common.DoFnRunner(
            fn,
            args,
            kwargs,
            self.side_input_maps,
            window_fn,
            tagged_receivers=self.tagged_receivers,
            step_name=self.name_context.logging_name(),
            state=state,
            user_state_context=self.user_state_context,
            operation_name=self.name_context.metrics_name())

        # Use the runner directly when it is already a Receiver; otherwise
        # wrap it in an adapter.
        if isinstance(self.dofn_runner, Receiver):
            self.dofn_receiver = self.dofn_runner
        else:
            self.dofn_receiver = DoFnRunnerReceiver(self.dofn_runner)

        self.dofn_runner.start()
def start_bundle(self):
    """Create output bundles and a DoFnRunner, then set up and start it."""
    applied = self._applied_ptransform
    transform = applied.transform

    # One output bundle per declared output tag.
    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for tag in applied.outputs:
        pcoll = pvalue.PCollection(None, tag=tag)
        pcoll.producer = applied
        self._tagged_receivers[tag] = (
            self._evaluation_context.create_bundle(pcoll))
        self._tagged_receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
        # Round-trip the DoFn through the pickler.
        dofn = pickler.loads(pickler.dumps(transform.dofn))
    else:
        dofn = transform.dofn

    args = getattr(transform, 'args', [])
    kwargs = getattr(transform, 'kwargs', {})

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
        # Take the key coder from the input's element type hint when one is
        # given; otherwise use the coder registered for Any.
        element_hint = applied.inputs[0].element_type
        if element_hint and element_hint != Any:
            self.key_coder = (
                coders.registry.get_coder(element_hint).key_coder())
        else:
            self.key_coder = coders.registry.get_coder(Any)

        self.user_state_context = DirectUserStateContext(
            self._step_context, dofn, self.key_coder)
        # Timer specs are indexed as 'user/<timer name>'.
        declared_timers = get_dofn_specs(dofn)[1]
        self.user_timer_map.update(
            ('user/%s' % spec.name, spec) for spec in declared_timers)

    self.runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.setup()
    self.runner.start()
def start_bundle(self):
    """Create output bundles and a DoFnRunner, then start the runner."""
    applied = self._applied_ptransform
    transform = applied.transform

    # One output bundle per declared output tag.
    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for tag in applied.outputs:
        pcoll = pvalue.PCollection(None, tag=tag)
        pcoll.producer = applied
        self._tagged_receivers[tag] = (
            self._evaluation_context.create_bundle(pcoll))
        self._tagged_receivers[tag].tag = tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    if self._perform_dofn_pickle_test:
        # Round-trip the DoFn through the pickler.
        dofn = pickler.loads(pickler.dumps(transform.dofn))
    else:
        dofn = transform.dofn

    args = getattr(transform, 'args', [])
    kwargs = getattr(transform, 'kwargs', {})

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
        # Take the key coder from the input's element type hint when one is
        # given; otherwise use the coder registered for typehints.Any.
        element_hint = applied.inputs[0].element_type
        if element_hint and element_hint != typehints.Any:
            self.key_coder = (
                coders.registry.get_coder(element_hint).key_coder())
        else:
            self.key_coder = coders.registry.get_coder(typehints.Any)

        self.user_state_context = DirectUserStateContext(
            self._step_context, dofn, self.key_coder)
        # Timer specs are indexed as 'user/<timer name>'.
        declared_timers = get_dofn_specs(dofn)[1]
        self.user_timer_map.update(
            {'user/%s' % spec.name: spec for spec in declared_timers})

    self.runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        applied.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=applied.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.start()
def __init__(self, step_context, dofn, key_coder):
    """Record the DoFn's state and timer specs and build its state tags.

    Args:
        step_context: stored as ``self.step_context``.
        dofn: the DoFn whose specs are read via ``get_dofn_specs``.
        key_coder: stored as ``self.key_coder``.

    Raises:
        ValueError: if a state spec is neither a BagStateSpec nor a
            CombiningValueStateSpec.
    """
    self.step_context = step_context
    self.dofn = dofn
    self.key_coder = key_coder

    specs = userstate.get_dofn_specs(dofn)
    self.all_state_specs, self.all_timer_specs = specs

    # Both supported spec kinds use a list state tag named
    # 'user/<spec name>'.
    self.state_tags = {}
    for spec in self.all_state_specs:
        tag_name = 'user/%s' % spec.name
        if not isinstance(
                spec,
                (userstate.BagStateSpec, userstate.CombiningValueStateSpec)):
            raise ValueError('Invalid state spec: %s' % spec)
        self.state_tags[spec] = _ListStateTag(tag_name)

    self.cached_states = {}
    self.cached_timers = {}
def __init__(self, step_context, dofn, key_coder):
    """Store the context arguments and derive a state tag per state spec.

    Args:
        step_context: stored as ``self.step_context``.
        dofn: the DoFn whose specs are read via ``get_dofn_specs``.
        key_coder: stored as ``self.key_coder``.

    Raises:
        ValueError: if a state spec is neither a BagStateSpec nor a
            CombiningValueStateSpec.
    """
    self.step_context = step_context
    self.dofn = dofn
    self.key_coder = key_coder

    state_specs, timer_specs = userstate.get_dofn_specs(dofn)
    self.all_state_specs = state_specs
    self.all_timer_specs = timer_specs

    # Spec kinds that are backed by a list state tag.
    list_backed_kinds = (
        userstate.BagStateSpec,
        userstate.CombiningValueStateSpec,
    )

    self.state_tags = {}
    for one_spec in self.all_state_specs:
        key = 'user/%s' % one_spec.name
        if isinstance(one_spec, list_backed_kinds):
            self.state_tags[one_spec] = _ListStateTag(key)
        else:
            raise ValueError('Invalid state spec: %s' % one_spec)

    self.cached_states = {}
    self.cached_timers = {}
def has_timers(self):
    # type: () -> bool
    """Return True if the wrapped DoFn declares at least one timer spec."""
    timer_specs = userstate.get_dofn_specs(self.do_fn)[1]
    return bool(timer_specs)
def has_timers(self):
    """Tell whether ``get_dofn_specs`` reports any timer specs for the DoFn."""
    specs = userstate.get_dofn_specs(self.do_fn)
    # The second element of the pair holds the timer specs.
    return True if specs[1] else False