def test_non_expose_apis(self):
    """RestrictionTrackerView must hide the tracker's non-public methods.

    Calling check_done, current_progress, try_split or deferred_status
    through the view is expected to raise AttributeError.
    """
    tracker_view = iobase.RestrictionTrackerView(
        iobase.ThreadsafeRestrictionTracker(
            OffsetRestrictionTracker(OffsetRange(0, 10))))
    hidden_methods = (
        'check_done', 'current_progress', 'try_split', 'deferred_status')
    for method_name in hidden_methods:
        # Attribute lookup itself is expected to fail on the view.
        with self.assertRaises(AttributeError):
            getattr(tracker_view, method_name)()
def test_api_expose(self):
    """RestrictionTrackerView must forward its public methods to the tracker.

    Checks current_restriction and try_claim through the view, then verifies
    that calling defer_remainder on the view is reflected in the underlying
    ThreadsafeRestrictionTracker's deferred_status.
    """
    wrapped = iobase.ThreadsafeRestrictionTracker(
        OffsetRestrictionTracker(OffsetRange(0, 10)))
    view = iobase.RestrictionTrackerView(wrapped)

    self.assertEqual(view.current_restriction(), OffsetRange(0, 10))
    self.assertTrue(view.try_claim(0))

    view.defer_remainder()
    remainder, watermark = wrapped.deferred_status()
    # After claiming offset 0 and deferring with no arguments, the test
    # expects the remainder [1, 10) and a default-constructed Duration.
    self.assertEqual(remainder, OffsetRange(1, 10))
    self.assertEqual(watermark, timestamp.Duration())
def test_initialization(self):
    """RestrictionTrackerView rejects a tracker that is not thread-safe.

    Passing a bare OffsetRestrictionTracker, rather than one wrapped in
    iobase.ThreadsafeRestrictionTracker, is expected to raise ValueError.
    """
    with self.assertRaises(ValueError):
        unwrapped_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
        iobase.RestrictionTrackerView(unwrapped_tracker)
def invoke_process(self,
                   windowed_value,  # type: WindowedValue
                   restriction_tracker=None,
                   additional_args=None,
                   additional_kwargs=None
                  ):
    # type: (...) -> Optional[SplitResultType]
    """Invoke the DoFn's process method for one windowed element.

    Args:
      windowed_value: the element to process, with its windows and timestamp.
        It is also set as the current element on ``self.context``.
      restriction_tracker: optional restriction tracker. If the DoFn is
        splittable and none is given, one is created via
        ``invoke_initial_restriction`` / ``invoke_create_tracker``.
      additional_args: extra positional arguments forwarded to
        ``_invoke_process_per_window``; ``None`` or empty means none.
      additional_kwargs: extra keyword arguments forwarded to
        ``_invoke_process_per_window``; ``None`` or empty means none. When a
        restriction tracker is in use, the tracker view (and the watermark
        estimator, if any) is inserted into this mapping. Note that a
        non-empty dict passed by the caller is updated in place.

    Returns:
      The result of ``_invoke_process_per_window`` when a restriction tracker
      is in use; ``None`` otherwise.

    Raises:
      NotImplementedError: a restriction tracker is in use, the value has more
        than one window, and the DoFn has windowed inputs.
      ValueError: a restriction tracker is in use but the DoFn's process
        method declares no RestrictionTrackerParam.
    """
    if not additional_args:
        additional_args = []
    if not additional_kwargs:
        additional_kwargs = {}

    self.context.set_element(windowed_value)
    # Call for the process function for each window if has windowed side inputs
    # or if the process accesses the window parameter. We can just call it once
    # otherwise as none of the arguments are changing
    if self.is_splittable and not restriction_tracker:
        restriction = self.invoke_initial_restriction(windowed_value.value)
        restriction_tracker = self.invoke_create_tracker(restriction)

    if restriction_tracker:
        if len(windowed_value.windows) > 1 and self.has_windowed_inputs:
            # Should never get here due to window explosion in
            # the upstream pair-with-restriction.
            raise NotImplementedError(
                'SDFs in multiply-windowed values with windowed arguments.')
        restriction_tracker_param = (
            self.signature.process_method.restriction_provider_arg_name)
        if not restriction_tracker_param:
            raise ValueError(
                'A RestrictionTracker %r was provided but DoFn does not have a '
                'RestrictionTrackerParam defined' % restriction_tracker)
        # Local import, as in the original block.
        from apache_beam.io import iobase
        self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker(
            restriction_tracker)
        # The DoFn receives a view over the thread-safe tracker, not the
        # tracker itself.
        additional_kwargs[restriction_tracker_param] = (
            iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))

        if self.watermark_estimator:
            # The watermark estimator needs to be reset for every element.
            self.watermark_estimator.reset()
            additional_kwargs[self.watermark_estimator_param] = (
                self.watermark_estimator)
        try:
            self.current_windowed_value = windowed_value
            return self._invoke_process_per_window(
                windowed_value, additional_args, additional_kwargs)
        finally:
            # Clear the per-element state once processing ends. Resetting to
            # None (rather than re-assigning windowed_value) keeps the finished
            # element from lingering after this call.
            self.threadsafe_restriction_tracker = None
            self.current_windowed_value = None
    elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
        # Explode into one invocation per window so windowed arguments see a
        # single window each.
        for w in windowed_value.windows:
            self._invoke_process_per_window(
                WindowedValue(
                    windowed_value.value, windowed_value.timestamp, (w, )),
                additional_args, additional_kwargs)
    else:
        self._invoke_process_per_window(
            windowed_value, additional_args, additional_kwargs)
    return None