def test_self_checkpoint_with_relative_time(self):
  """Checks that a relative deferral shrinks as wall-clock time elapses.

  The remainder is deferred by ``timestamp.Duration(100)``. After sleeping for
  two seconds, the delay reported by ``deferred_status()`` must still be a
  ``timestamp.Duration`` and must be no larger than 98.
  """
  threadsafe_tracker = ThreadsafeRestrictionTracker(
      OffsetRestrictionTracker(OffsetRange(0, 10)))
  threadsafe_tracker.defer_remainder(timestamp.Duration(100))
  time.sleep(2)
  _, deferred_time = threadsafe_tracker.deferred_status()
  # assertIsInstance/assertLessEqual include the offending value in the
  # failure message, whereas assertTrue(...) only says "False is not true".
  self.assertIsInstance(deferred_time, timestamp.Duration)
  # The expectation = 100 - 2 - some_delta
  self.assertLessEqual(deferred_time, 98)
def test_self_checkpoint_immediately(self):
  """Checks ``defer_remainder()`` called with no arguments and no claims.

  Nothing has been claimed, so the deferred residual is expected to be the
  whole ``OffsetRange(0, 10)``, and the reported delay is expected to be a
  ``timestamp.Duration`` equal to 0.
  """
  restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
  threadsafe_tracker = ThreadsafeRestrictionTracker(restriction_tracker)
  threadsafe_tracker.defer_remainder()
  deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
  expected_residual = OffsetRange(0, 10)
  self.assertEqual(deferred_residual, expected_residual)
  # assertIsInstance reports the actual object and expected type on failure,
  # unlike assertTrue(isinstance(...)).
  self.assertIsInstance(deferred_time, timestamp.Duration)
  self.assertEqual(deferred_time, 0)
def test_api_expose(self):
  """Exercises the tracker calls reachable through ``RestrictionTrackerView``.

  Through the view, the test reads the current restriction, claims offset 0
  and defers the remainder. It then inspects the wrapped thread-safe tracker
  to confirm the residual is ``OffsetRange(1, 10)`` and the reported delay
  equals a default-constructed ``timestamp.Duration()``.
  """
  offset_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
  wrapped = ThreadsafeRestrictionTracker(offset_tracker)
  view = RestrictionTrackerView(wrapped)

  self.assertEqual(view.current_restriction(), OffsetRange(0, 10))
  claimed = view.try_claim(0)
  self.assertTrue(claimed)

  view.defer_remainder()
  # The deferral is requested through the view but observed on the wrapped
  # tracker.
  status = wrapped.deferred_status()
  residual, delay = status
  self.assertEqual(residual, OffsetRange(1, 10))
  self.assertEqual(delay, timestamp.Duration())
def test_self_checkpoint_with_absolute_time(self):
  """Checks that an absolute resume time is reported as a remaining delay.

  The remainder is deferred until ``now + Duration(100)``, passed as a
  ``timestamp.Timestamp``. After sleeping for two seconds,
  ``deferred_status()`` must report a ``timestamp.Duration`` no larger
  than 98.
  """
  threadsafe_tracker = ThreadsafeRestrictionTracker(
      OffsetRestrictionTracker(OffsetRange(0, 10)))
  now = timestamp.Timestamp.now()
  schedule_time = now + timestamp.Duration(100)
  # assertIsInstance/assertLessEqual include the offending value in the
  # failure message, whereas assertTrue(...) only says "False is not true".
  self.assertIsInstance(schedule_time, timestamp.Timestamp)
  threadsafe_tracker.defer_remainder(schedule_time)
  time.sleep(2)
  _, deferred_time = threadsafe_tracker.deferred_status()
  self.assertIsInstance(deferred_time, timestamp.Duration)
  # The expectation =
  # schedule_time - the time when deferred_status is called - some_delta
  self.assertLessEqual(deferred_time, 98)
def test_non_expose_apis(self):
  """Checks which tracker methods are not reachable through the view.

  Calling ``check_done``, ``current_progress``, ``try_split`` or
  ``deferred_status`` on a ``RestrictionTrackerView`` must raise
  ``AttributeError``.
  """
  view = RestrictionTrackerView(
      ThreadsafeRestrictionTracker(
          OffsetRestrictionTracker(OffsetRange(0, 10))))
  hidden_methods = (
      'check_done',
      'current_progress',
      'try_split',
      'deferred_status',
  )
  for method_name in hidden_methods:
    with self.assertRaises(AttributeError):
      getattr(view, method_name)()
def invoke_process(
    self,
    windowed_value,  # type: WindowedValue
    restriction_tracker=None,
    additional_args=None,
    additional_kwargs=None):
  # type: (...) -> Optional[SplitResultType]

  """Invokes the DoFn process method for ``windowed_value``.

  Args:
    windowed_value: the element to process, together with its windows.
    restriction_tracker: optional tracker to use. When it is falsy and the
      DoFn is splittable, one is created from the initial restriction of the
      element.
    additional_args: optional extra positional arguments forwarded to
      ``_invoke_process_per_window``.
    additional_kwargs: optional extra keyword arguments forwarded to
      ``_invoke_process_per_window``. The mapping is copied, so the caller's
      object is never modified.

  Returns:
    Whatever ``_invoke_process_per_window`` returns when a restriction tracker
    is in use; ``None`` otherwise.

  Raises:
    NotImplementedError: a tracker is in use, the value is in more than one
      window and the DoFn has windowed inputs.
    ValueError: a tracker is in use but the process method declares no
      restriction tracker parameter.
  """
  if not additional_args:
    additional_args = []
  # Work on a private copy: the tracker / watermark estimator entries added
  # below must not leak into a dictionary owned by the caller.
  additional_kwargs = dict(additional_kwargs) if additional_kwargs else {}

  self.context.set_element(windowed_value)
  # Call for the process function for each window if has windowed side inputs
  # or if the process accesses the window parameter. We can just call it once
  # otherwise as none of the arguments are changing

  if self.is_splittable and not restriction_tracker:
    restriction = self.invoke_initial_restriction(windowed_value.value)
    restriction_tracker = self.invoke_create_tracker(restriction)

  if restriction_tracker:
    if len(windowed_value.windows) > 1 and self.has_windowed_inputs:
      # Should never get here due to window explosion in
      # the upstream pair-with-restriction.
      raise NotImplementedError(
          'SDFs in multiply-windowed values with windowed arguments.')
    restriction_tracker_param = (
        self.signature.process_method.restriction_provider_arg_name)
    if not restriction_tracker_param:
      raise ValueError(
          'A RestrictionTracker %r was provided but DoFn does not have a '
          'RestrictionTrackerParam defined' % restriction_tracker)
    self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker(
        restriction_tracker)
    additional_kwargs[restriction_tracker_param] = (
        RestrictionTrackerView(self.threadsafe_restriction_tracker))

    if self.watermark_estimator:
      # The watermark estimator needs to be reset for every element.
      self.watermark_estimator.reset()
      additional_kwargs[self.watermark_estimator_param] = (
          self.watermark_estimator)
    try:
      self.current_windowed_value = windowed_value
      return self._invoke_process_per_window(
          windowed_value, additional_args, additional_kwargs)
    finally:
      # Clear both per-element fields once processing is over. Leaving
      # current_windowed_value set would keep a stale reference to the
      # element that has just been processed.
      self.threadsafe_restriction_tracker = None
      self.current_windowed_value = None

  elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
    # Explode the value so that each call sees exactly one window.
    for w in windowed_value.windows:
      self._invoke_process_per_window(
          WindowedValue(
              windowed_value.value, windowed_value.timestamp, (w, )),
          additional_args,
          additional_kwargs)
  else:
    self._invoke_process_per_window(
        windowed_value, additional_args, additional_kwargs)
  return None
class PerWindowInvoker(DoFnInvoker):
  """An invoker that processes elements considering windowing information.

  The constructor pre-computes the positional and keyword arguments for the
  DoFn process method, leaving placeholders for values that are only known
  per element (element, key, window, timestamp, pane info, state, timers and
  the bundle finalizer). Those placeholders are filled in on every call.
  """
  def __init__(
      self,
      output_processor,  # type: OutputProcessor
      signature,  # type: DoFnSignature
      context,  # type: DoFnContext
      side_inputs,  # type: Iterable[sideinputs.SideInputMap]
      input_args,
      input_kwargs,
      user_state_context,  # type: Optional[userstate.UserStateContext]
      bundle_finalizer_param  # type: Optional[core._BundleFinalizerParam]
  ):
    """Builds the argument template used for every process invocation.

    Args:
      output_processor: passed to the base class; receives process outputs.
      signature: signature of the DoFn being invoked.
      context: object whose ``set_element`` is called for each element.
      side_inputs: side input maps, indexed by window when invoking.
      input_args: positional arguments for the process method (may be falsy).
      input_kwargs: keyword arguments for the process method (may be falsy).
      user_state_context: source of state and timers, if any.
      bundle_finalizer_param: value substituted for ``BundleFinalizerParam``.

    Raises:
      ValueError: a ``SideInputParam`` default has neither a positional nor a
        keyword value.
    """
    super(PerWindowInvoker, self).__init__(output_processor, signature)
    self.side_inputs = side_inputs
    self.context = context
    self.process_method = signature.process_method.method_value
    default_arg_values = signature.process_method.defaults
    self.has_windowed_inputs = (
        not all(si.is_globally_windowed() for si in side_inputs) or
        (core.DoFn.WindowParam in default_arg_values) or
        signature.is_stateful_dofn())
    self.user_state_context = user_state_context
    self.is_splittable = signature.is_splittable_dofn()
    self.watermark_estimator = self.signature.get_watermark_estimator()
    self.watermark_estimator_param = (
        self.signature.process_method.watermark_estimator_arg_name
        if self.watermark_estimator else None)
    self.threadsafe_restriction_tracker = None  # type: Optional[ThreadsafeRestrictionTracker]
    self.current_windowed_value = None  # type: Optional[WindowedValue]
    self.bundle_finalizer_param = bundle_finalizer_param
    self.is_key_param_required = False

    # Try to prepare all the arguments that can just be filled in without any
    # additional work in the process function.
    # Also cache all the placeholders needed in the process function.

    # Flag to cache additional arguments on the first element if all
    # inputs are within the global window.
    self.cache_globally_windowed_args = not self.has_windowed_inputs

    input_args = input_args if input_args else []
    input_kwargs = input_kwargs if input_kwargs else {}

    arg_names = signature.process_method.args

    # Create placeholder for element parameter of DoFn.process() method.
    # Not to be confused with ArgumentPlaceHolder, which may be passed in
    # input_args and is a placeholder for side-inputs.
    class ArgPlaceholder(object):
      def __init__(self, placeholder):
        self.placeholder = placeholder

    if core.DoFn.ElementParam not in default_arg_values:
      # TODO(BEAM-7867): Handle cases in which len(arg_names) ==
      #   len(default_arg_values).
      args_to_pick = len(arg_names) - len(default_arg_values) - 1
      # Positional argument values for process(), with placeholders for special
      # values such as the element, timestamp, etc.
      args_with_placeholders = (
          [ArgPlaceholder(core.DoFn.ElementParam)] +
          input_args[:args_to_pick])
    else:
      args_to_pick = len(arg_names) - len(default_arg_values)
      args_with_placeholders = input_args[:args_to_pick]

    # Fill the OtherPlaceholders for context, key, window or timestamp
    remaining_args_iter = iter(input_args[args_to_pick:])
    for a, d in zip(arg_names[-len(default_arg_values):], default_arg_values):
      if core.DoFn.ElementParam == d:
        args_with_placeholders.append(ArgPlaceholder(d))
      elif core.DoFn.KeyParam == d:
        # Remember that the key has to be extracted from every element.
        self.is_key_param_required = True
        args_with_placeholders.append(ArgPlaceholder(d))
      elif core.DoFn.WindowParam == d:
        args_with_placeholders.append(ArgPlaceholder(d))
      elif core.DoFn.TimestampParam == d:
        args_with_placeholders.append(ArgPlaceholder(d))
      elif core.DoFn.PaneInfoParam == d:
        args_with_placeholders.append(ArgPlaceholder(d))
      elif core.DoFn.SideInputParam == d:
        # If no more args are present then the value must be passed via kwarg
        try:
          args_with_placeholders.append(next(remaining_args_iter))
        except StopIteration:
          if a not in input_kwargs:
            raise ValueError("Value for sideinput %s not provided" % a)
      elif isinstance(d, core.DoFn.StateParam):
        args_with_placeholders.append(ArgPlaceholder(d))
      elif isinstance(d, core.DoFn.TimerParam):
        args_with_placeholders.append(ArgPlaceholder(d))
      elif isinstance(d, type) and core.DoFn.BundleFinalizerParam == d:
        args_with_placeholders.append(ArgPlaceholder(d))
      else:
        # If no more args are present then the value must be passed via kwarg
        try:
          args_with_placeholders.append(next(remaining_args_iter))
        except StopIteration:
          pass
    args_with_placeholders.extend(list(remaining_args_iter))

    # Stash the list of placeholder positions for performance
    self.placeholders = [(i, x.placeholder)
                         for (i, x) in enumerate(args_with_placeholders)
                         if isinstance(x, ArgPlaceholder)]

    self.args_for_process = args_with_placeholders
    self.kwargs_for_process = input_kwargs

  def invoke_process(
      self,
      windowed_value,  # type: WindowedValue
      restriction_tracker=None,
      additional_args=None,
      additional_kwargs=None):
    # type: (...) -> Optional[SplitResultType]

    """Invokes the DoFn process method for ``windowed_value``.

    Args:
      windowed_value: the element to process, together with its windows.
      restriction_tracker: optional tracker to use. When it is falsy and the
        DoFn is splittable, one is created from the initial restriction of
        the element.
      additional_args: optional extra positional arguments.
      additional_kwargs: optional extra keyword arguments. The mapping is
        copied, so the caller's object is never modified.

    Returns:
      The result of ``_invoke_process_per_window`` when a restriction tracker
      is in use; ``None`` otherwise.

    Raises:
      NotImplementedError: a tracker is in use, the value is in more than one
        window and the DoFn has windowed inputs.
      ValueError: a tracker is in use but the process method declares no
        restriction tracker parameter.
    """
    if not additional_args:
      additional_args = []
    # Work on a private copy: the tracker / watermark estimator entries added
    # below must not leak into a dictionary owned by the caller.
    additional_kwargs = dict(additional_kwargs) if additional_kwargs else {}

    self.context.set_element(windowed_value)
    # Call for the process function for each window if has windowed side inputs
    # or if the process accesses the window parameter. We can just call it once
    # otherwise as none of the arguments are changing

    if self.is_splittable and not restriction_tracker:
      restriction = self.invoke_initial_restriction(windowed_value.value)
      restriction_tracker = self.invoke_create_tracker(restriction)

    if restriction_tracker:
      if len(windowed_value.windows) > 1 and self.has_windowed_inputs:
        # Should never get here due to window explosion in
        # the upstream pair-with-restriction.
        raise NotImplementedError(
            'SDFs in multiply-windowed values with windowed arguments.')
      restriction_tracker_param = (
          self.signature.process_method.restriction_provider_arg_name)
      if not restriction_tracker_param:
        raise ValueError(
            'A RestrictionTracker %r was provided but DoFn does not have a '
            'RestrictionTrackerParam defined' % restriction_tracker)
      self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker(
          restriction_tracker)
      additional_kwargs[restriction_tracker_param] = (
          RestrictionTrackerView(self.threadsafe_restriction_tracker))

      if self.watermark_estimator:
        # The watermark estimator needs to be reset for every element.
        self.watermark_estimator.reset()
        additional_kwargs[self.watermark_estimator_param] = (
            self.watermark_estimator)
      try:
        self.current_windowed_value = windowed_value
        return self._invoke_process_per_window(
            windowed_value, additional_args, additional_kwargs)
      finally:
        # Clear both per-element fields once processing is over. Leaving
        # current_windowed_value set would keep a stale reference to the
        # element that has just been processed.
        self.threadsafe_restriction_tracker = None
        self.current_windowed_value = None

    elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
      # Explode the value so that each call sees exactly one window.
      for w in windowed_value.windows:
        self._invoke_process_per_window(
            WindowedValue(
                windowed_value.value, windowed_value.timestamp, (w, )),
            additional_args,
            additional_kwargs)
    else:
      self._invoke_process_per_window(
          windowed_value, additional_args, additional_kwargs)
    return None

  def _invoke_process_per_window(
      self,
      windowed_value,  # type: WindowedValue
      additional_args,
      additional_kwargs,
  ):
    # type: (...) -> Optional[SplitResultResidual]

    """Fills in the argument placeholders and calls the process method.

    Args:
      windowed_value: the element to process. When the invoker has windowed
        inputs it must be in exactly one window.
      additional_args: extra values appended after the side inputs.
      additional_kwargs: extra keyword arguments for the process method.

    Returns:
      A ``SplitResultResidual`` when the DoFn is splittable and the tracker
      reports a deferred status; ``None`` otherwise.

    Raises:
      ValueError: a key is required but the element value cannot be unpacked
        into a ``(key, value)`` pair.
    """
    if self.has_windowed_inputs:
      window, = windowed_value.windows
      side_inputs = [si[window] for si in self.side_inputs]
      side_inputs.extend(additional_args)
      args_for_process, kwargs_for_process = util.insert_values_in_args(
          self.args_for_process, self.kwargs_for_process, side_inputs)
    elif self.cache_globally_windowed_args:
      # Attempt to cache additional args if all inputs are globally
      # windowed inputs when processing the first element.
      self.cache_globally_windowed_args = False

      # Fill in sideInputs if they are globally windowed
      global_window = GlobalWindow()
      self.args_for_process, self.kwargs_for_process = (
          util.insert_values_in_args(
              self.args_for_process, self.kwargs_for_process,
              [si[global_window] for si in self.side_inputs]))
      args_for_process, kwargs_for_process = (
          self.args_for_process, self.kwargs_for_process)
    else:
      args_for_process, kwargs_for_process = (
          self.args_for_process, self.kwargs_for_process)

    # Extract key in the case of a stateful DoFn. Note that in the case of a
    # stateful DoFn, we set during __init__ self.has_windowed_inputs to be
    # True. Therefore, windows will be exploded coming into this method, and
    # we can rely on the window variable being set above.
    if self.user_state_context or self.is_key_param_required:
      try:
        key, unused_value = windowed_value.value
      except (TypeError, ValueError):
        raise ValueError((
            'Input value to a stateful DoFn or KeyParam must be a KV tuple; '
            'instead, got \'%s\'.') % (windowed_value.value, ))

    for i, p in self.placeholders:
      if core.DoFn.ElementParam == p:
        args_for_process[i] = windowed_value.value
      elif core.DoFn.KeyParam == p:
        args_for_process[i] = key
      elif core.DoFn.WindowParam == p:
        args_for_process[i] = window
      elif core.DoFn.TimestampParam == p:
        args_for_process[i] = windowed_value.timestamp
      elif core.DoFn.PaneInfoParam == p:
        args_for_process[i] = windowed_value.pane_info
      elif isinstance(p, core.DoFn.StateParam):
        assert self.user_state_context is not None
        args_for_process[i] = (
            self.user_state_context.get_state(p.state_spec, key, window))
      elif isinstance(p, core.DoFn.TimerParam):
        assert self.user_state_context is not None
        args_for_process[i] = (
            self.user_state_context.get_timer(p.timer_spec, key, window))
      elif core.DoFn.BundleFinalizerParam == p:
        args_for_process[i] = self.bundle_finalizer_param

    if additional_kwargs:
      if kwargs_for_process is None:
        kwargs_for_process = additional_kwargs
      else:
        # dict.update avoids a loop variable that would shadow the element
        # key extracted above.
        kwargs_for_process.update(additional_kwargs)

    if kwargs_for_process:
      self.output_processor.process_outputs(
          windowed_value,
          self.process_method(*args_for_process, **kwargs_for_process))
    else:
      self.output_processor.process_outputs(
          windowed_value, self.process_method(*args_for_process))

    if self.is_splittable:
      assert self.threadsafe_restriction_tracker is not None
      # TODO: Consider calling check_done right after SDF.Process() finishing.
      # In order to do this, we need to know that current invoking dofn is
      # ProcessSizedElementAndRestriction.
      self.threadsafe_restriction_tracker.check_done()
      deferred_status = self.threadsafe_restriction_tracker.deferred_status()
      current_watermark = None
      if self.watermark_estimator:
        current_watermark = self.watermark_estimator.current_watermark()
      if deferred_status:
        deferred_restriction, deferred_timestamp = deferred_status
        element = windowed_value.value
        size = self.signature.get_restriction_provider().restriction_size(
            element, deferred_restriction)
        residual_value = ((element, deferred_restriction), size)
        return SplitResultResidual(
            residual_value=windowed_value.with_value(residual_value),
            current_watermark=current_watermark,
            deferred_timestamp=deferred_timestamp)
    return None

  def try_split(self, fraction):
    # type: (...) -> Optional[Tuple[SplitResultPrimary, SplitResultResidual]]

    """Attempts to split the element currently being processed.

    Args:
      fraction: forwarded unchanged to the tracker's ``try_split``.

    Returns:
      A ``(SplitResultPrimary, SplitResultResidual)`` pair when the tracker
      produces a split. ``None`` when no tracker or element is currently set,
      or when the tracker returns a falsy result.
    """
    # Read each field exactly once: invoke_process clears both of them when
    # it finishes, so re-reading them after the check could observe None.
    restriction_tracker = self.threadsafe_restriction_tracker
    current_windowed_value = self.current_windowed_value
    if restriction_tracker and current_windowed_value:
      # Temporary workaround for [BEAM-7473]: get current_watermark before
      # split, in case watermark gets advanced before getting split results.
      # In worst case, current_watermark is always stale, which is ok.
      if self.watermark_estimator:
        current_watermark = self.watermark_estimator.current_watermark()
      else:
        current_watermark = None
      split = restriction_tracker.try_split(fraction)
      if split:
        primary, residual = split
        element = current_windowed_value.value
        restriction_provider = self.signature.get_restriction_provider()
        primary_size = restriction_provider.restriction_size(element, primary)
        residual_size = restriction_provider.restriction_size(
            element, residual)
        primary_value = ((element, primary), primary_size)
        residual_value = ((element, residual), residual_size)
        return (
            SplitResultPrimary(
                primary_value=current_windowed_value.with_value(
                    primary_value)),
            SplitResultResidual(
                residual_value=current_windowed_value.with_value(
                    residual_value),
                current_watermark=current_watermark,
                deferred_timestamp=None))
    return None

  def current_element_progress(self):
    # type: () -> Optional[RestrictionProgress]

    """Returns the tracker's current progress, or None without a tracker."""
    restriction_tracker = self.threadsafe_restriction_tracker
    if restriction_tracker:
      return restriction_tracker.current_progress()
    else:
      return None
def test_defer_remainder_with_wrong_time_type(self):
  """Checks that ``defer_remainder`` raises ValueError for a plain int."""
  offset_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
  wrapped = ThreadsafeRestrictionTracker(offset_tracker)
  # The bare integer 10 is passed where the other tests in this file pass a
  # timestamp.Duration or timestamp.Timestamp.
  self.assertRaises(ValueError, wrapped.defer_remainder, 10)
def test_initialization(self):
  """Checks that wrapping a ``RangeSource`` raises ValueError."""
  def wrap_range_source():
    # Build the argument inside the callable so that it is evaluated under
    # assertRaises, as in a ``with`` block.
    return ThreadsafeRestrictionTracker(RangeSource(0, 1))

  self.assertRaises(ValueError, wrap_range_source)