def __init__(self, do_fn):
  # We add a property here for all methods defined by Beam DoFn features.
  assert isinstance(do_fn, core.DoFn)
  self.do_fn = do_fn

  self.process_method = MethodWrapper(do_fn, 'process')
  self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
  self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')

  restriction_provider = self.get_restriction_provider()
  self.initial_restriction_method = (
      MethodWrapper(restriction_provider, 'initial_restriction')
      if restriction_provider else None)
  self.restriction_coder_method = (
      MethodWrapper(restriction_provider, 'restriction_coder')
      if restriction_provider else None)
  self.create_tracker_method = (
      MethodWrapper(restriction_provider, 'create_tracker')
      if restriction_provider else None)
  self.split_method = (
      MethodWrapper(restriction_provider, 'split')
      if restriction_provider else None)

  self._validate()

  # Handle stateful DoFns.
  self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
  self.timer_methods = {}
  if self._is_stateful_dofn:
    # Populate timer firing methods, keyed by TimerSpec.
    _, all_timer_specs = userstate.get_dofn_specs(do_fn)
    for timer_spec in all_timer_specs:
      method = timer_spec._attached_callback
      self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__)

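# A minimal sketch (hypothetical DoFn, not part of this module) of what the
# timer discovery in __init__ above consumes: @userstate.on_timer records its
# callback on the TimerSpec as _attached_callback, so the signature can wrap
# that method keyed by the spec.
import apache_beam as beam
from apache_beam.coders import VarIntCoder
from apache_beam.transforms import userstate
from apache_beam.transforms.timeutil import TimeDomain


class _BufferingDoFn(beam.DoFn):
  BUFFER = userstate.BagStateSpec('buffer', VarIntCoder())
  FLUSH = userstate.TimerSpec('flush', TimeDomain.WATERMARK)

  def process(self,
              element,
              window=beam.DoFn.WindowParam,
              buffer=beam.DoFn.StateParam(BUFFER),
              flush=beam.DoFn.TimerParam(FLUSH)):
    buffer.add(element[1])
    flush.set(window.end)

  @userstate.on_timer(FLUSH)
  def flush_buffer(self, buffer=beam.DoFn.StateParam(BUFFER)):
    yield sum(buffer.read())
    buffer.clear()

# For this DoFn, timer_methods would map FLUSH to a MethodWrapper around
# flush_buffer, discovered through FLUSH._attached_callback.
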
def should_execute_serially(self, applied_ptransform):
  """Returns True if this applied_ptransform should run one bundle at a time.

  Some TransformEvaluators use a global state object to keep track of their
  global execution state. For example, the evaluator for _GroupByKeyOnly uses
  this state as an in-memory dictionary to buffer keys.

  Serially executed evaluators act as syncing points in the graph, and
  execution will not move forward until they receive all of their inputs.
  Once they have received all of their input, they release the combined
  output. That output may consist of multiple bundles, as they may divide
  their output into pieces before releasing.

  Args:
    applied_ptransform: Transform to be used for execution.

  Returns:
    True if the executor should execute applied_ptransform serially.
  """
  if isinstance(applied_ptransform.transform,
                (core._GroupByKeyOnly,
                 _StreamingGroupByKeyOnly,
                 _StreamingGroupAlsoByWindow,
                 _NativeWrite)):
    return True
  elif (isinstance(applied_ptransform.transform, core.ParDo) and
        is_stateful_dofn(applied_ptransform.transform.dofn)):
    return True
  return False

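# Illustrative sketch (hypothetical pipeline, reusing the hypothetical
# _BufferingDoFn defined in the sketch above): both the GroupByKey (a
# core._GroupByKeyOnly under the hood) and the stateful ParDo are transforms
# for which should_execute_serially returns True in the direct runner, so each
# is evaluated one bundle at a time.
import apache_beam as beam

with beam.Pipeline() as p:
  pairs = p | beam.Create([('k', 1), ('k', 2), ('k', 3)])
  _ = pairs | 'SerialGBK' >> beam.GroupByKey()
  _ = pairs | 'SerialStatefulParDo' >> beam.ParDo(_BufferingDoFn())
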
def visit_transform(self, applied_ptransform):
  transform = applied_ptransform.transform
  # The FnApiRunner does not support streaming execution.
  if isinstance(transform, TestStream):
    self.supported_by_fnapi_runner = False
  # The FnApiRunner does not support reads from NativeSources.
  if (isinstance(transform, beam.io.Read) and
      isinstance(transform.source, NativeSource)):
    self.supported_by_fnapi_runner = False
  # The FnApiRunner does not support the use of _NativeWrites.
  if isinstance(transform, _NativeWrite):
    self.supported_by_fnapi_runner = False
  if isinstance(transform, beam.ParDo):
    dofn = transform.dofn
    # The FnApiRunner does not support execution of CombineFns with
    # deferred side inputs.
    if isinstance(dofn, CombineValuesDoFn):
      args, kwargs = transform.raw_side_inputs
      args_to_check = itertools.chain(args, kwargs.values())
      if any(isinstance(arg, ArgumentPlaceholder)
             for arg in args_to_check):
        self.supported_by_fnapi_runner = False
    if userstate.is_stateful_dofn(dofn):
      _, timer_specs = userstate.get_dofn_specs(dofn)
      for timer in timer_specs:
        if timer.time_domain == TimeDomain.REAL_TIME:
          self.supported_by_fnapi_runner = False

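# A minimal sketch (hypothetical DoFn) of the last check above: declaring a
# processing-time (REAL_TIME) timer is what makes a ParDo unsupported by the
# FnApiRunner in this visitor.
import apache_beam as beam
from apache_beam.transforms import userstate
from apache_beam.transforms.timeutil import TimeDomain


class _ProcessingTimeTimerDoFn(beam.DoFn):
  FLUSH = userstate.TimerSpec('flush', TimeDomain.REAL_TIME)

  def process(self, element, flush=beam.DoFn.TimerParam(FLUSH)):
    yield element

  @userstate.on_timer(FLUSH)
  def on_flush(self):
    pass

# visit_transform would see timer.time_domain == TimeDomain.REAL_TIME for this
# DoFn's timer spec and set supported_by_fnapi_runner to False.
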
def start_bundle(self):
  transform = self._applied_ptransform.transform

  self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
  for output_tag in self._applied_ptransform.outputs:
    output_pcollection = pvalue.PCollection(None, tag=output_tag)
    output_pcollection.producer = self._applied_ptransform
    self._tagged_receivers[output_tag] = (
        self._evaluation_context.create_bundle(output_pcollection))
    self._tagged_receivers[output_tag].tag = output_tag

  self._counter_factory = counters.CounterFactory()

  # TODO(aaltay): Consider storing the serialized form as an optimization.
  dofn = (
      pickler.loads(pickler.dumps(transform.dofn))
      if self._perform_dofn_pickle_test else transform.dofn)

  args = transform.args if hasattr(transform, 'args') else []
  kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

  self.user_state_context = None
  self.user_timer_map = {}
  if is_stateful_dofn(dofn):
    kv_type_hint = self._applied_ptransform.inputs[0].element_type
    if kv_type_hint and kv_type_hint != Any:
      coder = coders.registry.get_coder(kv_type_hint)
      self.key_coder = coder.key_coder()
    else:
      self.key_coder = coders.registry.get_coder(Any)

    self.user_state_context = DirectUserStateContext(
        self._step_context, dofn, self.key_coder)
    _, all_timer_specs = get_dofn_specs(dofn)
    for timer_spec in all_timer_specs:
      self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

  self.runner = DoFnRunner(
      dofn,
      args,
      kwargs,
      self._side_inputs,
      self._applied_ptransform.inputs[0].windowing,
      tagged_receivers=self._tagged_receivers,
      step_name=self._applied_ptransform.full_label,
      state=DoFnState(self._counter_factory),
      user_state_context=self.user_state_context)
  self.runner.setup()
  self.runner.start()

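# A minimal sketch of the key-coder lookup above (the type hints here are made
# up): a concrete KV element type yields the coder for its key component, while
# a missing or Any hint falls back to the registry's default coder.
from apache_beam import coders
from apache_beam import typehints

kv_coder = coders.registry.get_coder(typehints.KV[str, int])
key_coder = kv_coder.key_coder()                           # coder for the str keys
fallback_coder = coders.registry.get_coder(typehints.Any)  # registry default coder
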
def start_bundle(self):
  transform = self._applied_ptransform.transform

  self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
  for output_tag in self._applied_ptransform.outputs:
    output_pcollection = pvalue.PCollection(None, tag=output_tag)
    output_pcollection.producer = self._applied_ptransform
    self._tagged_receivers[output_tag] = (
        self._evaluation_context.create_bundle(output_pcollection))
    self._tagged_receivers[output_tag].tag = output_tag

  self._counter_factory = counters.CounterFactory()

  # TODO(aaltay): Consider storing the serialized form as an optimization.
  dofn = (
      pickler.loads(pickler.dumps(transform.dofn))
      if self._perform_dofn_pickle_test else transform.dofn)

  args = transform.args if hasattr(transform, 'args') else []
  kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

  self.user_state_context = None
  self.user_timer_map = {}
  if is_stateful_dofn(dofn):
    kv_type_hint = self._applied_ptransform.inputs[0].element_type
    if kv_type_hint and kv_type_hint != typehints.Any:
      coder = coders.registry.get_coder(kv_type_hint)
      self.key_coder = coder.key_coder()
    else:
      self.key_coder = coders.registry.get_coder(typehints.Any)

    self.user_state_context = DirectUserStateContext(
        self._step_context, dofn, self.key_coder)
    _, all_timer_specs = get_dofn_specs(dofn)
    for timer_spec in all_timer_specs:
      self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

  self.runner = DoFnRunner(
      dofn,
      args,
      kwargs,
      self._side_inputs,
      self._applied_ptransform.inputs[0].windowing,
      tagged_receivers=self._tagged_receivers,
      step_name=self._applied_ptransform.full_label,
      state=DoFnState(self._counter_factory),
      user_state_context=self.user_state_context)
  self.runner.start()

def test_stateful_dofn_detection(self):
  self.assertFalse(is_stateful_dofn(DoFn()))
  self.assertTrue(is_stateful_dofn(TestStatefulDoFn()))

def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_inputs_proto=None):

  if side_inputs_proto:
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag, si in side_inputs_proto.items()]
    tagged_side_inputs.sort(
        key=lambda tag_si: int(
            re.match('side([0-9]+)(-.*)?$', tag_si[0]).group(1)))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler, transform_id, tag, si,
            input_tags_to_coders[tag])
        for tag, si in tagged_side_inputs]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match the 'out' prefix injected by the Dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    side_input_tags = side_inputs_proto or ()
    pcoll_id, = [
        pcoll for tag, pcoll in transform_proto.inputs.items()
        if tag not in side_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  if userstate.is_stateful_dofn(dofn_data[0]):
    input_coder = factory.get_only_input_coder(transform_proto)
    user_state_context = FnApiUserStateContext(
        factory.state_handler,
        transform_id,
        input_coder.key_coder(),
        input_coder.window_coder)
  else:
    user_state_context = None

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])

  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps,
          user_state_context),
      transform_proto.unique_name,
      consumers,
      output_tags)

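# A minimal sketch of the numeric tag sort above (the tag values are made up):
# Dataflow-style side-input tags look like 'side<N>' or 'side<N>-<suffix>' and
# must be ordered by N, not lexicographically.
import re

tags = ['side10', 'side2', 'side1-extra']
tags.sort(key=lambda tag: int(re.match('side([0-9]+)(-.*)?$', tag).group(1)))
assert tags == ['side1-extra', 'side2', 'side10']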