class BadStatefulDoFn1(DoFn):
    """Fixture DoFn whose ``process`` binds the same state spec twice.

    Both ``b1`` and ``b2`` default to ``DoFn.StateParam(BUFFER_STATE)``.

    NOTE(review): the class name suggests this is a negative example for
    DoFn signature validation -- confirm against the test that uses it.
    """
    # Bag state named 'buffer', constructed with a BytesCoder.
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())

    def process(
            self,
            element,
            b1=DoFn.StateParam(BUFFER_STATE),
            b2=DoFn.StateParam(BUFFER_STATE)):
        # The input element is yielded unchanged; neither state handle is used.
        yield element
def _pardo_group_into_batches(
        input_coder, batch_size, max_buffering_duration_secs, clock=time.time):
    """Build a stateful DoFn that buffers ``(key, value)`` pairs into batches.

    Args:
        input_coder: coder handed to both the element bag spec and the count
            spec.
        batch_size: a batch is flushed as soon as the running count reaches
            this value.
        max_buffering_duration_secs: when greater than zero, a real-time timer
            is armed on the first buffered element so the batch is flushed
            after this many seconds even if it is not full.
        clock: zero-argument callable returning the current time; used to
            compute the buffering deadline.

    Returns:
        A new ``_GroupIntoBatchesDoFn`` instance.
    """
    ELEMENT_STATE = BagStateSpec('values', input_coder)
    COUNT_STATE = CombiningValueStateSpec(
        'count', input_coder, CountCombineFn())
    WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
    BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)

    class _GroupIntoBatchesDoFn(DoFn):
        """Buffers elements in state and emits ``(key, [values])`` batches."""

        def process(
                self,
                element,
                window=DoFn.WindowParam,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE),
                window_timer=DoFn.TimerParam(WINDOW_TIMER),
                buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
            # The Python SDK has no allowed-lateness support, so the flush is
            # scheduled exactly at the end of the window. See
            # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
            window_timer.set(window.end)

            element_state.add(element)
            count_state.add(1)
            buffered = count_state.read()

            is_first_in_batch = buffered == 1
            if is_first_in_batch and max_buffering_duration_secs > 0:
                # A buffering limit is configured: start the countdown with
                # the first element of the batch.
                # pylint: disable=deprecated-method
                deadline = clock() + max_buffering_duration_secs
                buffering_timer.set(deadline)

            if buffered >= batch_size:
                return self.flush_batch(
                    element_state, count_state, buffering_timer)

        @on_timer(WINDOW_TIMER)
        def on_window_timer(
                self,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE),
                buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
            """Flush whatever is buffered when the window timer fires."""
            return self.flush_batch(
                element_state, count_state, buffering_timer)

        @on_timer(BUFFERING_TIMER)
        def on_buffering_timer(
                self,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE),
                buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
            """Flush whatever is buffered when the buffering timer fires."""
            return self.flush_batch(
                element_state, count_state, buffering_timer)

        def flush_batch(self, element_state, count_state, buffering_timer):
            """Yield the buffered batch (if any) and reset state and timer."""
            pairs = list(element_state.read())
            if not pairs:
                return

            # The key of the first buffered pair is used for the whole batch.
            key, _ = pairs[0]
            values = [pair_value for (_pair_key, pair_value) in pairs]

            element_state.clear()
            count_state.clear()
            buffering_timer.clear()
            yield key, values

    return _GroupIntoBatchesDoFn()
class IndexAssigningStatefulDoFn(DoFn):
    """Pairs each incoming value with a running index kept in bag state."""
    INDEX_STATE = BagStateSpec('index', VarIntCoder())

    def process(self, element, state=DoFn.StateParam(INDEX_STATE)):
        """Yield ``(value, index)`` and store ``index + 1`` for the next call."""
        _, value = element

        stored = list(state.read())
        if stored:
            # The bag is expected to hold exactly one entry; unpacking raises
            # if it holds more.
            (current_index,) = stored
        else:
            current_index = 0

        yield (value, current_index)

        # Replace the stored index with its successor.
        state.clear()
        state.add(current_index + 1)
class BadStatefulDoFn4(DoFn):
    """Fixture DoFn whose timer callback binds the same timer spec twice.

    ``expiry_callback`` is registered for ``EXPIRY_TIMER_1`` but both of its
    timer parameters default to ``DoFn.TimerParam(EXPIRY_TIMER_2)``. The class
    defines no ``process`` method of its own.

    NOTE(review): the class name suggests this is a negative example for
    DoFn signature validation -- confirm against the test that uses it.
    """
    # BUFFER_STATE is declared but not referenced by any method in this class.
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

    @on_timer(EXPIRY_TIMER_1)
    def expiry_callback(
            self,
            element,
            t1=DoFn.TimerParam(EXPIRY_TIMER_2),
            t2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        # Yields its ``element`` argument unchanged; the timers are not used.
        yield element
class TestStatefulDoFn(DoFn):
    """An example stateful DoFn with state and timers.

    Declares two bag states and three watermark timers. ``process`` and each
    timer callback request a different subset of them as default-valued
    parameters. ``process`` yields its input unchanged and each callback
    yields a fixed marker string; none of the state or timer handles is
    actually used inside the method bodies.
    """
    BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
    BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
    EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)

    def process(
            self,
            element,
            t=DoFn.TimestampParam,
            buffer_1=DoFn.StateParam(BUFFER_STATE_1),
            buffer_2=DoFn.StateParam(BUFFER_STATE_2),
            timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
            timer_2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        # Requests both states and timers 1 and 2 (timer 3 is not requested).
        yield element

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(
            self,
            window=DoFn.WindowParam,
            timestamp=DoFn.TimestampParam,
            key=DoFn.KeyParam,
            buffer=DoFn.StateParam(BUFFER_STATE_1),
            timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
            timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
            timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        # Requests window, timestamp and key, the first state, and all three
        # timers.
        yield 'expired1'

    @on_timer(EXPIRY_TIMER_2)
    def on_expiry_2(
            self,
            buffer=DoFn.StateParam(BUFFER_STATE_2),
            timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
            timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        # Requests the second state and timers 2 and 3.
        yield 'expired2'

    @on_timer(EXPIRY_TIMER_3)
    def on_expiry_3(
            self,
            buffer_1=DoFn.StateParam(BUFFER_STATE_1),
            buffer_2=DoFn.StateParam(BUFFER_STATE_2),
            timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        # Requests both states and timer 3 only.
        yield 'expired3'
class BasicStatefulDoFn(DoFn):
    """Minimal DoFn declaring one bag state and one watermark timer.

    Both ``process`` and the timer callback yield their ``element`` argument
    unchanged; the state and timer handles are requested but never used.
    """
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
    EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

    def process(
            self,
            element,
            buffer=DoFn.StateParam(BUFFER_STATE),
            timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        yield element

    @on_timer(EXPIRY_TIMER)
    def expiry_callback(self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
        # NOTE(review): this callback takes a positional ``element`` parameter
        # without a default -- confirm how the caller supplies it.
        yield element
class StatefulDoFn(DoFn):
    """Stateful DoFn whose ``process`` returns the result of a recursive helper."""
    BYTES_STATE = BagStateSpec('bytes', BytesCoder())

    def return_recursive(self, count):
        """Return ``["some string"]`` for ``count == 0`` and ``None`` otherwise.

        NOTE(review): for a non-zero ``count`` the recursive call's result is
        discarded, so the caller receives ``None``. This looks deliberate
        (a DoFn returning ``None``) -- confirm before "fixing" it.
        """
        if count != 0:
            self.return_recursive(count - 1)
            return None
        return ["some string"]

    def process(self, element, counter=DoFn.StateParam(BYTES_STATE)):
        """Delegate to ``return_recursive(1)``; the state handle is unused."""
        return self.return_recursive(1)
def test_spec_construction(self):
    """Check which spec/param constructor arguments are accepted or rejected."""
    # --- BagStateSpec: the name must be a string.
    BagStateSpec('statename', VarIntCoder())
    self.assertRaises(
        AssertionError, lambda: BagStateSpec(123, VarIntCoder()))

    # --- CombiningValueStateSpec: string name and a valid combine fn.
    CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10))
    self.assertRaises(
        AssertionError,
        lambda: CombiningValueStateSpec(
            123, VarIntCoder(), TopCombineFn(10)))
    self.assertRaises(
        AssertionError,
        lambda: CombiningValueStateSpec(
            'statename', VarIntCoder(), object()))
    # TODO: add more spec tests

    # --- A state spec is not acceptable where a timer spec is required.
    self.assertRaises(
        ValueError,
        lambda: DoFn.TimerParam(BagStateSpec('elements', BytesCoder())))

    # --- TimerSpec: only known time domains are accepted.
    TimerSpec('timer', TimeDomain.WATERMARK)
    TimerSpec('timer', TimeDomain.REAL_TIME)
    self.assertRaises(
        ValueError, lambda: TimerSpec('timer', 'bogus_time_domain'))

    # --- A timer spec is not acceptable where a state spec is required.
    self.assertRaises(
        ValueError,
        lambda: DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK)))
class SimpleTestStatefulDoFn(DoFn):
    """Buffers tagged values in bag state and emits them joined on expiry."""
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
    EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

    def process(
            self,
            element,
            buffer=DoFn.StateParam(BUFFER_STATE),
            timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        """Store ``b'A' + <value as latin1 bytes>`` and set the timer to 20."""
        _, value = element
        encoded_value = str(value).encode('latin1')
        buffer.add(b'A' + encoded_value)
        timer1.set(20)

    @on_timer(EXPIRY_TIMER)
    def expiry_callback(
            self,
            buffer=DoFn.StateParam(BUFFER_STATE),
            timer=DoFn.TimerParam(EXPIRY_TIMER)):
        """Yield every buffered entry, sorted and concatenated into one value."""
        ordered_entries = sorted(buffer.read())
        yield b''.join(ordered_entries)
class PairRecordsFn(beam.DoFn):
    """Pairs two consecutive elements after shuffle.

    The first value seen for a key is parked in bag state; the next value for
    that key is emitted together with it as ``(previous, current)`` and the
    state is cleared, so values are consumed two at a time.
    """
    BUFFER = BagStateSpec('buffer', PickleCoder())

    def process(self, element, buffer=beam.DoFn.StateParam(BUFFER)):
        """Emit ``(previous_value, value)`` if a value is parked, else park it.

        Args:
            element: a ``(key, value)`` pair; the key is ignored here.
            buffer: bag state holding at most one parked value.

        Yields:
            ``(previous_value, value)`` when a previous value was parked.
        """
        unused_key, value = element

        # Decide on "is something parked?" by the bag being non-empty, not by
        # the truthiness of the parked value: a falsy value (0, '', [], ...)
        # must still be paired. Errors from reading state are allowed to
        # propagate instead of being silently treated as "nothing parked".
        parked = list(buffer.read())
        if parked:
            yield (parked[0], value)
            buffer.clear()
        else:
            buffer.add(value)
class _RemoveDuplicates(beam.DoFn):
    """Lets the metadata through only the first time state is empty for a key."""
    FILES_STATE = BagStateSpec('files', StrUtf8Coder())

    def process(self, element, file_state=beam.DoFn.StateParam(FILES_STATE)):
        """Yield ``element[1]`` unless the state bag already has an entry.

        ``element[0]`` is the path recorded in state and used in log messages.
        """
        path, file_metadata = element[0], element[1]

        already_recorded = list(file_state.read())
        if already_recorded:
            _LOGGER.debug('File %s was already read', path)
            return

        file_state.add(path)
        _LOGGER.debug('Generated entry for file %s', path)
        yield file_metadata
class StatefulDoFnWithTimerWithTypo1(DoFn):  # pylint: disable=unused-variable
    """Fixture DoFn in which two callbacks are registered for the same timer.

    Both ``on_expiry_1`` and ``on_expiry_2`` are decorated with
    ``@on_timer(EXPIRY_TIMER_1)``; no callback is registered for
    ``EXPIRY_TIMER_2``.

    NOTE(review): presumably a negative example for timer-callback
    validation -- confirm against the test that uses it.
    """
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

    def process(self, element):
        pass

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired1'

    # Note that we mistakenly associate this with the first timer.
    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired2'
class CollectingFn(beam.DoFn):
    """Accumulates integer values per key and emits their sum once enough arrive."""
    BUFFER_STATE = BagStateSpec('buffer', VarIntCoder())
    COUNT_STATE = CombiningValueStateSpec('count', sum)

    def process(
            self,
            element,
            buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
            count_state=beam.DoFn.StateParam(COUNT_STATE)):
        """Buffer ``int(element[1].decode())``; flush at ``NUM_RECORDS`` values.

        Yields:
            The sum of all buffered values, once the count read from state is
            at least ``NUM_RECORDS``. Both states are cleared afterwards.
        """
        parsed = int(element[1].decode())
        buffer_state.add(parsed)
        count_state.add(1)

        seen_so_far = count_state.read()
        threshold_reached = seen_so_far >= NUM_RECORDS
        if not threshold_reached:
            return

        total = sum(buffer_state.read())
        yield total
        count_state.clear()
        buffer_state.clear()
class StatefulDoFnWithTimerWithTypo3(DoFn):
    """Fixture DoFn that uses two timers but registers a callback for only one.

    ``process`` requests both ``EXPIRY_TIMER_1`` and ``EXPIRY_TIMER_2``, yet
    only ``on_expiry_1`` carries an ``@on_timer`` decorator; ``on_expiry_2``
    is a plain method.

    NOTE(review): presumably a negative example for timer-callback
    validation -- confirm against the test that uses it.
    """
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

    def process(
            self,
            element,
            timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
            timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        pass

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired1'

    # No @on_timer decorator here, so this method is not associated with
    # EXPIRY_TIMER_2 (or any other timer) by this class body.
    def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired2'

    # Use a stable string value for matching.
    def __repr__(self):
        return 'StatefulDoFnWithTimerWithTypo3'
def _pardo_group_into_batches(batch_size, input_coder):
    """Build a stateful DoFn that groups ``(key, value)`` pairs into batches.

    Args:
        batch_size: a batch is emitted once the running count reaches this
            value.
        input_coder: coder handed to both the element bag spec and the count
            spec.

    Returns:
        A new ``_GroupIntoBatchesDoFn`` instance.
    """
    ELEMENT_STATE = BagStateSpec('values', input_coder)
    COUNT_STATE = CombiningValueStateSpec(
        'count', input_coder, CountCombineFn())
    EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

    class _GroupIntoBatchesDoFn(DoFn):
        """Buffers elements in state and emits ``(key, [values])`` batches."""

        def process(
                self,
                element,
                window=DoFn.WindowParam,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE),
                expiry_timer=DoFn.TimerParam(EXPIRY_TIMER)):
            # The Python SDK has no allowed-lateness support, so the leftover
            # flush is scheduled exactly at the end of the window. See
            # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
            expiry_timer.set(window.end)

            element_state.add(element)
            count_state.add(1)
            buffered = count_state.read()
            if not buffered >= batch_size:
                return

            pairs = list(element_state.read())
            # The key of the first buffered pair is used for the whole batch.
            key, _ = pairs[0]
            values = [pair_value for (_pair_key, pair_value) in pairs]
            yield (key, values)
            element_state.clear()
            count_state.clear()

        @on_timer(EXPIRY_TIMER)
        def expiry(
                self,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE)):
            """Emit any partially filled batch when the expiry timer fires."""
            pairs = list(element_state.read())
            if not pairs:
                return

            key, _ = pairs[0]
            values = [pair_value for (_pair_key, pair_value) in pairs]
            yield (key, values)
            element_state.clear()
            count_state.clear()

    return _GroupIntoBatchesDoFn()
class StatefulBufferingFn(DoFn):
    """Tags each value with the index at which it was first seen for its key.

    Distinct values are appended to a bag; a separate counter supplies the
    index assigned to a value the first time it appears.
    """
    BUFFER_STATE = BagStateSpec('buffer', StrUtf8Coder())
    COUNT_STATE = userstate.CombiningValueStateSpec(
        'count', VarIntCoder(), CountCombineFn())

    def process(
            self,
            element,
            buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
            count_state=beam.DoFn.StateParam(COUNT_STATE)):
        """Yield ``('<value>_<index>', 1)`` for the incoming value.

        Args:
            element: a ``(key, value)`` pair; the key is not used here.
            buffer_state: bag of values already seen.
            count_state: running count of distinct values seen.

        Yields:
            A ``(label, 1)`` pair where ``label`` is ``'{value}_{index}'``.
            For an already-buffered value the index is its position in the
            bag as read back; for a new value it is the counter's value
            before being incremented.
        """
        _, value = element

        # Read state outside the ``try`` so that only a genuine "value not
        # buffered yet" (ValueError from list.index) takes the new-value
        # path; state-read failures propagate instead of being swallowed.
        seen_values = list(buffer_state.read())
        try:
            index_value = seen_values.index(value)
        except ValueError:
            buffer_state.add(value)
            index_value = count_state.read()
            count_state.add(1)

        yield ('{}_{}'.format(value, index_value), 1)
class BagStateClearingStatefulDoFn(beam.DoFn):
    """Buffers values and emits the whole bag from one callback on two timers.

    ``emit_values`` is registered for both ``EMIT_TWICE_TIMER`` and
    ``EMIT_TIMER``; it never clears the bag.
    """
    BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder())
    EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
    EMIT_TWICE_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)

    def process(
            self,
            element,
            bag_state=beam.DoFn.StateParam(BAG_STATE),
            emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
            emit_twice_timer=beam.DoFn.TimerParam(EMIT_TWICE_TIMER)):
        """Buffer ``element[1]`` and set the two timers to 100 and 1000."""
        bag_state.add(element[1])
        emit_twice_timer.set(100)
        emit_timer.set(1000)

    @on_timer(EMIT_TWICE_TIMER)
    @on_timer(EMIT_TIMER)
    def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
        """Yield every value currently held in the bag."""
        yield from bag_state.read()
class HashJoinStatefulDoFn(DoFn):
    """Joins pairs of values per key, reporting values that stay unmatched."""
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
    UNMATCHED_TIMER = TimerSpec('unmatched', TimeDomain.WATERMARK)

    def process(
            self,
            element,
            state=DoFn.StateParam(BUFFER_STATE),
            timer=DoFn.TimerParam(UNMATCHED_TIMER)):
        """Emit a joined record if a value is buffered; otherwise buffer it.

        When nothing is buffered, the value is stored and the timer is set to
        100. When a value is already buffered, a ``Record<...>`` is yielded
        and both the state and the timer are cleared.
        """
        key, value = element
        waiting = list(state.read())
        if waiting:
            yield b'Record<%s,%s,%s>' % (key, waiting[0], value)
            state.clear()
            timer.clear()
        else:
            state.add(value)
            timer.set(100)

    @on_timer(UNMATCHED_TIMER)
    def expiry_callback(self, state=DoFn.StateParam(BUFFER_STATE)):
        """Emit the single buffered value as ``Unmatched<...>`` and clear state."""
        leftover = list(state.read())
        assert len(leftover) == 1, leftover
        unmatched_value = leftover[0]
        state.clear()
        yield b'Unmatched<%s>' % (unmatched_value,)
class StatefulDoFnWithTimerWithTypo2(DoFn):
    """Fixture DoFn in which two timer callbacks share one method name.

    Both callbacks are named ``on_expiry_1``. In a Python class body the
    second definition rebinds the name, so the class attribute
    ``on_expiry_1`` ends up referring to the function decorated for
    ``EXPIRY_TIMER_2``.

    NOTE(review): presumably a negative example for timer-callback
    validation -- confirm against the test that uses it.
    """
    BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

    def process(
            self,
            element,
            timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
            timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        pass

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired1'

    # Note that we mistakenly reuse the "on_expiry_1" name; this is valid
    # syntactically in Python.
    @on_timer(EXPIRY_TIMER_2)
    def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired2'

    # Use a stable string value for matching.
    def __repr__(self):
        return 'StatefulDoFnWithTimerWithTypo2'
def test_param_construction(self):
    """State and timer params must reject a spec of the other kind."""
    # A timer spec handed to StateParam is an error ...
    self.assertRaises(
        ValueError,
        lambda: DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK)))
    # ... and so is a state spec handed to TimerParam.
    self.assertRaises(
        ValueError,
        lambda: DoFn.TimerParam(BagStateSpec('elements', BytesCoder())))
class GeneralTriggerManagerDoFn(DoFn):
    """A trigger manager that supports all windowing / triggering cases.

    This implements a DoFn that manages triggering in a per-key basis. All
    elements for a single key are processed together. Per-key state holds data
    related to all windows.

    ``process`` buffers incoming windowed values into state and then asks the
    configured trigger function which windows may fire; the two timer
    callbacks re-run that firing check when a processing-time or watermark
    timer goes off.
    """

    # TODO(BEAM-12026) Add support for Global and custom window fns.
    # Windows seen so far (only written when the window fn is merging).
    KNOWN_WINDOWS = SetStateSpec('known_windows', IntervalWindowCoder())
    # Windows for which the trigger reported "finished" on firing.
    FINISHED_WINDOWS = SetStateSpec('finished_windows', IntervalWindowCoder())
    # Largest timestamp passed to each timer callback so far (combined with
    # max).
    LAST_KNOWN_TIME = CombiningValueStateSpec('last_known_time', combine_fn=max)
    LAST_KNOWN_WATERMARK = CombiningValueStateSpec(
        'last_known_watermark', combine_fn=max)

    # TODO(pabloem) What's the coder for the elements/keys here?
    # Buffered (window, timestamped value) pairs awaiting firing.
    WINDOW_ELEMENT_PAIRS = BagStateSpec(
        'all_elements', TupleCoder([IntervalWindowCoder(), PickleCoder()]))
    # Handed to the trigger context; not read or written directly in this
    # class.
    WINDOW_TAG_VALUES = BagStateSpec(
        'per_window_per_tag_value_state',
        TupleCoder([IntervalWindowCoder(), StrUtf8Coder(), VarIntCoder()]))

    PROCESSING_TIME_TIMER = TimerSpec(
        'processing_time_timer', TimeDomain.REAL_TIME)
    WATERMARK_TIMER = TimerSpec('watermark_timer', TimeDomain.WATERMARK)

    def __init__(self, windowing: Windowing):
        """Store the windowing strategy and cache whether its window fn merges."""
        self.windowing = windowing
        # Only session windows are merging. Other windows are non-merging.
        self.merging_windows = self.windowing.windowfn.is_merging()

    def process(
            self,
            element: typing.Tuple[
                K, typing.Iterable[windowed_value.WindowedValue]],
            all_elements: BagRuntimeState = DoFn.StateParam(
                WINDOW_ELEMENT_PAIRS),  # type: ignore
            latest_processing_time: AccumulatingRuntimeState = DoFn.StateParam(
                LAST_KNOWN_TIME),  # type: ignore
            latest_watermark: AccumulatingRuntimeState = DoFn.StateParam(  # type: ignore
                LAST_KNOWN_WATERMARK),
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            windows_state: SetRuntimeState = DoFn.StateParam(
                KNOWN_WINDOWS),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.StateParam(  # type: ignore
                FINISHED_WINDOWS),
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER),
            *args,
            **kwargs):
        """Buffer the values for one key and fire any windows that are ready.

        Args:
            element: a ``(key, windowed_values)`` pair.
            all_elements, latest_processing_time, latest_watermark,
            window_tag_values, windows_state, finished_windows_state:
                runtime state handles for the specs declared on the class.
            processing_time_timer, watermark_timer: runtime timer handles.
            *args, **kwargs: accepted but not used in this method.

        Returns:
            The generator produced by ``_fire_eligible_windows``, restricted
            to the windows that received values in this call.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        key, windowed_values = element
        watermark = read_watermark(latest_watermark)

        # Group the incoming values by window, dropping values whose window
        # is past the allowed lateness or already marked finished.
        windows_to_elements = collections.defaultdict(list)
        for wv in windowed_values:
            for window in wv.windows:
                # ignore expired windows
                if watermark > window.end + self.windowing.allowed_lateness:
                    continue
                if window in finished_windows_state.read():
                    continue
                windows_to_elements[window].append(
                    TimestampedValue(wv.value, wv.timestamp))

        # Processing merging of windows
        if self.merging_windows:
            old_windows = set(windows_state.read())
            all_windows = old_windows.union(list(windows_to_elements))
            if all_windows != old_windows:
                # New windows appeared: let the window fn merge them with the
                # known ones.
                merge_context = TriggerMergeContext(
                    all_windows, context, self.windowing)
                self.windowing.windowfn.merge(merge_context)

                merged_windows_to_elements = collections.defaultdict(list)
                for window, values in windows_to_elements.items():
                    # Follow the merged_away chain to the surviving window.
                    while window in merge_context.merged_away:
                        window = merge_context.merged_away[window]
                    merged_windows_to_elements[window].extend(values)
                windows_to_elements = merged_windows_to_elements

            for w in windows_to_elements:
                windows_state.add(w)
        # Done processing merging of windows

        seen_windows = set()
        for w in windows_to_elements:
            window_context = context.for_window(w)
            seen_windows.add(w)
            for value_w_timestamp in windows_to_elements[w]:
                _LOGGER.debug(value_w_timestamp)
                all_elements.add((w, value_w_timestamp))
                # NOTE(review): on_element receives the full
                # ``windowed_values`` iterable rather than the single value
                # being buffered -- confirm that this is what the trigger fn
                # expects.
                self.windowing.triggerfn.on_element(
                    windowed_values, w, window_context)

        return self._fire_eligible_windows(
            key, TimeDomain.WATERMARK, watermark, None, context, seen_windows)

    def _fire_eligible_windows(
            self,
            key: K,
            time_domain,
            timestamp: Timestamp,
            timer_tag: typing.Optional[str],
            context: 'FnRunnerStatefulTriggerContext',
            windows_of_interest: typing.Optional[
                typing.Set[BoundedWindow]] = None):
        """Yield ``(key, elements)`` for every window whose trigger fires.

        This is a generator: none of the state changes below happen until it
        is iterated.

        Args:
            key: the key to emit alongside fired elements.
            time_domain: time domain passed to the trigger's ``should_fire``.
            timestamp: timestamp passed to ``should_fire`` and ``on_fire``.
            timer_tag: only used in the debug log line.
            context: trigger context providing buffered elements and state.
            windows_of_interest: if given, only these windows are considered
                for firing; ``None`` means all buffered windows.

        Yields:
            ``(key, [WindowedValue, ...])`` once per fired window.
        """
        # Snapshot the buffered elements and empty the bag; elements that
        # must be kept are written back at the end.
        windows_to_elements = context.windows_to_elements_map()
        context.all_elements_state.clear()

        fired_windows = set()
        _LOGGER.debug(
            '%s - tag %s - timestamp %s', time_domain, timer_tag, timestamp)
        for w, elems in windows_to_elements.items():
            if windows_of_interest is not None and w not in windows_of_interest:
                # windows_of_interest=None means that we care about all windows.
                # If we care only about some windows, and this window is not one of
                # them, then we do not intend to fire this window.
                continue
            window_context = context.for_window(w)
            if self.windowing.triggerfn.should_fire(time_domain,
                                                    timestamp,
                                                    w,
                                                    window_context):
                finished = self.windowing.triggerfn.on_fire(
                    timestamp, w, window_context)
                _LOGGER.debug('Firing on window %s. Finished: %s', w, finished)
                fired_windows.add(w)
                if finished:
                    context.finished_windows_state.add(w)
                # TODO(pabloem): Format the output: e.g. pane info
                elems = [
                    WindowedValue(e.value, e.timestamp, (w, )) for e in elems
                ]
                yield (key, elems)

        finished_windows: typing.Set[BoundedWindow] = set(
            context.finished_windows_state.read())
        # Add elements that were not fired back into state.
        for w, elems in windows_to_elements.items():
            for e in elems:
                # Drop elements of finished windows, and of fired windows
                # when the accumulation mode is DISCARDING; keep the rest.
                if (w in finished_windows or
                    (w in fired_windows and self.windowing.accumulation_mode
                     == AccumulationMode.DISCARDING)):
                    continue
                context.all_elements_state.add((w, e))

    @on_timer(PROCESSING_TIME_TIMER)
    def processing_time_trigger(
            self,
            key=DoFn.KeyParam,
            timer_tag=DoFn.DynamicTimerTagParam,
            timestamp=DoFn.TimestampParam,
            latest_processing_time=DoFn.StateParam(LAST_KNOWN_TIME),
            all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.StateParam(  # type: ignore
                FINISHED_WINDOWS),
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Re-evaluate all buffered windows in the REAL_TIME domain.

        The context is built without a watermark state
        (``latest_watermark=None``). The timer's timestamp is recorded in
        ``latest_processing_time`` before the returned generator is consumed,
        because ``_fire_eligible_windows`` is lazy.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=None,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(
            key, TimeDomain.REAL_TIME, timestamp, timer_tag, context)
        latest_processing_time.add(timestamp)
        return result

    @on_timer(WATERMARK_TIMER)
    def watermark_trigger(
            self,
            key=DoFn.KeyParam,
            timer_tag=DoFn.DynamicTimerTagParam,
            timestamp=DoFn.TimestampParam,
            latest_watermark=DoFn.StateParam(LAST_KNOWN_WATERMARK),
            all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.StateParam(  # type: ignore
                FINISHED_WINDOWS),
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Re-evaluate all buffered windows in the WATERMARK domain.

        The context is built without a processing-time state
        (``latest_processing_time=None``). The timer's timestamp is recorded
        in ``latest_watermark`` before the returned generator is consumed,
        because ``_fire_eligible_windows`` is lazy.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=None,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(
            key, TimeDomain.WATERMARK, timestamp, timer_tag, context)
        latest_watermark.add(timestamp)
        return result
class SolveDoFn(beam.DoFn):
    """Stateful DoFn that maps a window of elements to a model and solves it.

    Four bag states carry, between invocations for the same key, the last
    processed timestamp, the last element list, the last model and the last
    sample set. Each bag is cleared before a new entry is added, so only the
    latest entry is kept.
    """
    PREV_TIMESTAMP = BagStateSpec(name="timestamp_state", coder=coders.PickleCoder())
    PREV_ELEMENTS = BagStateSpec(name="elements_state", coder=coders.PickleCoder())
    PREV_MODEL = BagStateSpec(name="model_state", coder=coders.PickleCoder())
    PREV_SAMPLESET = BagStateSpec(name="sampleset_state", coder=coders.PickleCoder())

    def process(
        self,
        value,
        timestamp=beam.DoFn.TimestampParam,
        timestamp_state=beam.DoFn.StateParam(PREV_TIMESTAMP),
        elements_state=beam.DoFn.StateParam(PREV_ELEMENTS),
        model_state=beam.DoFn.StateParam(PREV_MODEL),
        sampleset_state=beam.DoFn.StateParam(PREV_SAMPLESET),
        algorithm=None,
        algorithm_options=None,
        map_fn=None,
        solve_fn=None,
        unmap_fn=None,
        solver=LocalSolver(exact=False),  # default solver
        initial_mtype=sawatabi.constants.MODEL_ISING,
    ):
        """Map, solve and unmap one window of elements.

        Args:
            value: a ``(key, elements)`` pair; the key is ignored.
            timestamp: timestamp of the current element.
            timestamp_state, elements_state, model_state, sampleset_state:
                bag states holding the previous timestamp / elements / model /
                sample set.
            algorithm: one of the ``sawatabi.constants.ALGORITHM_*`` values
                handled below, or anything else for no algorithm-specific
                processing.
            algorithm_options: mapping read for ``"filter_fn"`` (partial) and
                ``"attenuation.key"`` / ``"attenuation.min_scale"``
                (attenuation).
            map_fn: ``(prev_model, prev_sampleset, elements, incoming,
                outgoing) -> model``.
            solve_fn: ``(solver, model, prev_sampleset, elements, incoming,
                outgoing) -> sampleset``.
            unmap_fn: ``(sampleset, elements, incoming, outgoing) -> result``.
            solver: solver object forwarded to ``solve_fn``. Note that the
                default instance is created once, when the class body is
                evaluated, and shared by all calls that do not pass one.
            initial_mtype: model type used when no previous model is stored.

        Yields:
            The value returned by ``unmap_fn`` on success, or a diagnostic
            string when the event is outdated or when mapping, solving or
            unmapping raises.
        """
        _, elements = value

        # Sort with the event time.
        # If we sort a list of tuples, the first element of the tuple is recognized as a key by default,
        # so just `sorted` is enough.
        sorted_elements = sorted(elements)

        # generator into a list
        timestamp_history = list(timestamp_state.read())
        elements_history = list(elements_state.read())
        model_history = list(model_state.read())
        sampleset_history = list(sampleset_state.read())

        # Extract the previous timestamp, elements, and model from state
        # (the last entry of each bag), falling back to defaults when empty.
        prev_timestamp = timestamp_history[-1] if timestamp_history else -1.0
        prev_elements = elements_history[-1] if elements_history else []
        if model_history:
            prev_model = model_history[-1]
        else:
            prev_model = sawatabi.model.LogicalModel(mtype=initial_mtype)
        prev_sampleset = sampleset_history[-1] if sampleset_history else None

        # Sometimes, when we use the sliding window algorithm for a bounded data (such as a local file),
        # we may receive an outdated event whose timestamp is older than timestamp of previously processed event.
        if float(timestamp) < float(prev_timestamp):
            # Report the *previous* timestamp in the second half of the
            # message. Stored timestamps were added as ``timestamp`` objects,
            # so they support to_utc_datetime(); the -1.0 placeholder (no
            # stored timestamp) is shown as-is.
            if timestamp_history:
                prev_display = prev_timestamp.to_utc_datetime()
            else:
                prev_display = prev_timestamp
            yield (
                f"The received event is outdated: Timestamp is {timestamp.to_utc_datetime()}, "
                + f"while an event with timestamp of {prev_display} has been already processed."
            )
            return

        # Algorithm specific operations
        # Incremental: Append current window into the all previous data.
        if algorithm == sawatabi.constants.ALGORITHM_INCREMENTAL:
            sorted_elements.extend(prev_elements)
            sorted_elements = sorted(sorted_elements)
        # Partial: Merge current window with the specified data.
        elif algorithm == sawatabi.constants.ALGORITHM_PARTIAL:
            filter_fn = algorithm_options["filter_fn"]
            filtered = filter(filter_fn, prev_elements)
            sorted_elements = list(filtered) + sorted_elements
            sorted_elements = sorted(sorted_elements)

        # Resolve outgoing elements in this iteration: the leading previous
        # elements whose first field is smaller than that of the first
        # current element.
        def resolve_outgoing(prev_elements, sorted_elements):
            outgoing = []
            for p in prev_elements:
                if p[0] >= sorted_elements[0][0]:
                    break
                outgoing.append(p)
            return outgoing

        outgoing = resolve_outgoing(prev_elements, sorted_elements)

        # Resolve incoming elements in this iteration: the trailing current
        # elements whose first field is larger than that of the last previous
        # element (everything, when there are no previous elements).
        def resolve_incoming(prev_elements, sorted_elements):
            incoming = []
            if not prev_elements:
                incoming = sorted_elements
            else:
                for v in reversed(sorted_elements):
                    if v[0] <= prev_elements[-1][0]:
                        break
                    incoming.insert(0, v)
            return incoming

        incoming = resolve_incoming(prev_elements, sorted_elements)

        # Clear the BagState so we can hold only the latest state, and
        # Register new timestamp and elements to the states
        timestamp_state.clear()
        timestamp_state.add(timestamp)
        elements_state.clear()
        elements_state.add(sorted_elements)

        # Map problem input to the model
        try:
            model = map_fn(prev_model, prev_sampleset, sorted_elements, incoming, outgoing)
        except Exception as e:
            yield f"Failed to map: {e}\n{traceback.format_exc()}"
            return

        # Clear the BagState so we can hold only the latest state, and
        # Register new model to the state
        model_state.clear()
        model_state.add(model)

        # Algorithm specific operations
        # Attenuation: Update scale based on data timestamp.
        if algorithm == sawatabi.constants.ALGORITHM_ATTENUATION:
            model.to_physical()  # Resolve removed interactions. TODO: Deal with placeholders.
            ref_timestamp = model._interactions_array[algorithm_options["attenuation.key"]]
            min_ts = min(ref_timestamp)
            max_ts = max(ref_timestamp)
            min_scale = algorithm_options["attenuation.min_scale"]
            if min_ts < max_ts:
                # Linearly rescale from min_scale (oldest) to 1.0 (newest).
                for i, t in enumerate(ref_timestamp):
                    new_scale = (1.0 - min_scale) / (max_ts - min_ts) * (t - min_ts) + min_scale
                    model._interactions_array["scale"][i] = new_scale

        # Solve and unmap to the solution
        try:
            sampleset = solve_fn(solver, model, prev_sampleset, sorted_elements, incoming, outgoing)
        except Exception as e:
            yield f"Failed to solve: {e}\n{traceback.format_exc()}"
            return

        # Clear the BagState so we can hold only the latest state, and
        # Register new sampleset to the state
        sampleset_state.clear()
        sampleset_state.add(sampleset)

        try:
            yield unmap_fn(sampleset, sorted_elements, incoming, outgoing)
        except Exception as e:
            yield f"Failed to unmap: {e}\n{traceback.format_exc()}"
def _pardo_group_into_batches_with_multi_bags(
        input_coder, batch_size, max_buffering_duration_secs, clock=time.time):
    """Build a batching DoFn that spreads buffered elements over four bags.

    Args:
        input_coder: coder handed to every element bag spec and the count
            spec.
        batch_size: a batch is flushed as soon as the running count reaches
            this value.
        max_buffering_duration_secs: when greater than zero, a real-time timer
            is armed on the first buffered element so the batch is flushed
            after this many seconds even if it is not full.
        clock: zero-argument callable returning the current time; used to
            compute the buffering deadline.

    Returns:
        A new ``_GroupIntoBatchesDoFnWithMultiBags`` instance.
    """
    ELEMENT_STATE_0 = BagStateSpec('values0', input_coder)
    ELEMENT_STATE_1 = BagStateSpec('values1', input_coder)
    ELEMENT_STATE_2 = BagStateSpec('values2', input_coder)
    ELEMENT_STATE_3 = BagStateSpec('values3', input_coder)
    COUNT_STATE = CombiningValueStateSpec(
        'count', input_coder, CountCombineFn())
    WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
    BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)

    class _GroupIntoBatchesDoFnWithMultiBags(DoFn):
        """Buffers elements round-robin across four bags and emits batches."""

        def process(
                self,
                element,
                window=DoFn.WindowParam,
                element_state_0=DoFn.StateParam(ELEMENT_STATE_0),
                element_state_1=DoFn.StateParam(ELEMENT_STATE_1),
                element_state_2=DoFn.StateParam(ELEMENT_STATE_2),
                element_state_3=DoFn.StateParam(ELEMENT_STATE_3),
                count_state=DoFn.StateParam(COUNT_STATE),
                window_timer=DoFn.TimerParam(WINDOW_TIMER),
                buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
            # The Python SDK has no allowed-lateness support, so the flush is
            # scheduled exactly at the end of the window. See
            # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
            window_timer.set(window.end)

            count_state.add(1)
            buffered = count_state.read()

            # The running count selects which of the four bags gets the
            # element.
            bags = [
                element_state_0,
                element_state_1,
                element_state_2,
                element_state_3,
            ]
            target_bag = bags[buffered % 4]
            target_bag.add(element)

            is_first_in_batch = buffered == 1
            if is_first_in_batch and max_buffering_duration_secs > 0:
                # A buffering limit is configured: start the countdown with
                # the first element of the batch.
                deadline = clock() + max_buffering_duration_secs
                buffering_timer.set(deadline)

            if buffered >= batch_size:
                return self.flush_batch(bags, count_state, buffering_timer)

        @on_timer(WINDOW_TIMER)
        def on_window_timer(
                self,
                element_state_0=DoFn.StateParam(ELEMENT_STATE_0),
                element_state_1=DoFn.StateParam(ELEMENT_STATE_1),
                element_state_2=DoFn.StateParam(ELEMENT_STATE_2),
                element_state_3=DoFn.StateParam(ELEMENT_STATE_3),
                count_state=DoFn.StateParam(COUNT_STATE),
                buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
            """Flush whatever is buffered when the window timer fires."""
            bags = [
                element_state_0,
                element_state_1,
                element_state_2,
                element_state_3,
            ]
            return self.flush_batch(bags, count_state, buffering_timer)

        @on_timer(BUFFERING_TIMER)
        def on_buffering_timer(
                self,
                element_state_0=DoFn.StateParam(ELEMENT_STATE_0),
                element_state_1=DoFn.StateParam(ELEMENT_STATE_1),
                element_state_2=DoFn.StateParam(ELEMENT_STATE_2),
                element_state_3=DoFn.StateParam(ELEMENT_STATE_3),
                count_state=DoFn.StateParam(COUNT_STATE),
                buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
            """Flush whatever is buffered when the buffering timer fires."""
            bags = [
                element_state_0,
                element_state_1,
                element_state_2,
                element_state_3,
            ]
            return self.flush_batch(bags, count_state, buffering_timer)

        def flush_batch(self, element_states, count_state, buffering_timer):
            """Drain every bag, reset count and timer, and yield the batch.

            State and timer are cleared even when nothing was buffered; in
            that case nothing is yielded.
            """
            collected = []
            for bag in element_states:
                for pair_key, pair_value in bag.read():
                    # The key of the last pair read is the one emitted.
                    key = pair_key
                    collected.append(pair_value)
                bag.clear()

            count_state.clear()
            buffering_timer.clear()

            if collected:
                yield key, collected

    return _GroupIntoBatchesDoFnWithMultiBags()