def test_state_spec_proto_conversion(self):
    """Each user-state spec below serializes with the BAG state URN protocol."""
    expected_protocol = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    specs = [
        BagStateSpec('statename', VarIntCoder()),
        CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10)),
        SetStateSpec('setstatename', VarIntCoder()),
        ReadModifyWriteStateSpec('valuestatename', VarIntCoder()),
    ]
    for spec in specs:
        # A fresh context per spec, as each conversion is checked in isolation.
        proto = spec.to_runner_api(pipeline_context.PipelineContext())
        self.assertEqual(expected_protocol, proto.protocol)
def expand(self, pbegin):
    """Expand into the external (Java) GenerateSequence transform.

    Args:
      pbegin: the input; must be a ``pvalue.PBegin`` because this transform
        can only be applied at the root of a pipeline.

    Returns:
      The result of applying an ``ExternalTransform`` with URN ``self._urn``
      whose payload is an ``ExternalConfigurationPayload`` holding the
      varint-encoded ``start`` and any optional settings that were given.

    Raises:
      Exception: if ``pbegin`` is not a ``PBegin``.
    """
    if not isinstance(pbegin, pvalue.PBegin):
        raise Exception("GenerateSequence must be a root transform")

    coder = VarIntCoder()
    coder_urn = ['beam:coder:varint:v1']

    def _varint_config(value):
        # Every setting of this transform is a varint-encoded integer.
        return ConfigValue(coder_urn=coder_urn, payload=coder.encode(value))

    args = {'start': _varint_config(self.start)}
    # Optional settings are omitted only when unset (None). A truthiness check
    # would also drop legitimate zeros, e.g. stop=0 would be sent as "no stop".
    optional_settings = (
        ('stop', self.stop),
        ('elements_per_period', self.elements_per_period),
        ('max_read_time', self.max_read_time),
    )
    for name, value in optional_settings:
        if value is not None:
            args[name] = _varint_config(value)

    payload = ExternalConfigurationPayload(configuration=args)
    return pbegin.apply(
        ExternalTransform(
            self._urn, payload.SerializeToString(), self.expansion_service))
def test_extend_fetches_initial_state(self):
    """Appending before any read must still keep the pre-existing bag contents."""
    coder = VarIntCoder()
    coder_impl = coder.get_impl()

    class UnderlyingStateHandler(object):
        """In-memory fake holding one bag's encoded contents as a byte string."""

        def set_value(self, value):
            self._encoded_values = coder.encode(value)

        def get_raw(self, *args):
            # No continuation token: everything is returned in one chunk.
            return self._encoded_values, None

        def append_raw(self, _key, data):
            # Named `data` (not `bytes`) so the builtin is not shadowed.
            self._encoded_values += data

        def clear(self, *args):
            self._encoded_values = bytes()

        @contextlib.contextmanager
        def process_instruction_id(self, bundle_id):
            yield

    underlying_state_handler = UnderlyingStateHandler()
    state_cache = statecache.StateCache(100)
    handler = sdk_worker.CachingStateHandler(state_cache,
                                             underlying_state_handler)

    state = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))

    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState())

    def get():
        return list(handler.blocking_get(state, coder_impl, True))

    def append(value):
        handler.extend(state, coder_impl, [value], True)

    def clear():
        handler.clear(state, True)

    # Initialize state
    underlying_state_handler.set_value(42)
    with handler.process_instruction_id('bundle', [cache_token]):
        # Append without reading beforehand
        append(43)
        self.assertEqual(get(), [42, 43])
        clear()
        self.assertEqual(get(), [])
        append(44)
        self.assertEqual(get(), [44])
def test_spec_construction(self):
    """State/timer spec constructors accept valid arguments and reject bad ones."""
    # (valid constructor, [constructors expected to raise TypeError]) per type.
    state_cases = [
        (lambda: BagStateSpec('statename', VarIntCoder()),
         [lambda: BagStateSpec(123, VarIntCoder())]),
        (lambda: CombiningValueStateSpec(
            'statename', VarIntCoder(), TopCombineFn(10)),
         [lambda: CombiningValueStateSpec(
             123, VarIntCoder(), TopCombineFn(10)),
          lambda: CombiningValueStateSpec(
              'statename', VarIntCoder(), object())]),
        (lambda: SetStateSpec('setstatename', VarIntCoder()),
         [lambda: SetStateSpec(123, VarIntCoder()),
          lambda: SetStateSpec('setstatename', object())]),
        (lambda: ReadModifyWriteStateSpec('valuestatename', VarIntCoder()),
         [lambda: ReadModifyWriteStateSpec(123, VarIntCoder()),
          lambda: ReadModifyWriteStateSpec('valuestatename', object())]),
    ]
    for build_valid, invalid_builders in state_cases:
        build_valid()
        for build_invalid in invalid_builders:
            with self.assertRaises(TypeError):
                build_invalid()

    # TODO: add more spec tests
    with self.assertRaises(ValueError):
        DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))

    for domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
        TimerSpec('timer', domain)
    with self.assertRaises(ValueError):
        TimerSpec('timer', 'bogus_time_domain')
    with self.assertRaises(ValueError):
        DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))
class CountAndSchedule(beam.DoFn):
    """Counts elements and schedules a watermark timer on minute boundaries.

    process() bumps a counter stored in a bag and sets the timer to the start
    of the minute after the element's (UTC) timestamp, in epoch seconds. When
    the timer fires it emits the count, resets it to 0 and re-arms itself 60
    seconds later.
    """
    COUNTER = BagStateSpec('counter', VarIntCoder())
    SCHEDULED_TIMESTAMP = BagStateSpec('nextSchedule', VarIntCoder())
    TIMER = TimerSpec('timer', TimeDomain.WATERMARK)

    def process(self,
                element,
                timestamp=beam.DoFn.TimestampParam,
                timer=beam.DoFn.TimerParam(TIMER),
                counter=beam.DoFn.StateParam(COUNTER),
                next_schedule=beam.DoFn.StateParam(SCHEDULED_TIMESTAMP),
                *args,
                **kwargs):
        # The bag holds at most one value: the running count.
        [previous_count] = list(counter.read()) or [0]
        counter.clear()
        counter.add(previous_count + 1)

        minute_start = timestamp.to_utc_datetime().replace(
            second=0, microsecond=0)
        next_minute = minute_start + timedelta(minutes=1)
        tick = calendar.timegm(next_minute.timetuple())
        timer.set(tick)
        next_schedule.clear()
        next_schedule.add(tick)

    @on_timer(TIMER)
    def timer_ticked(self,
                     timer=beam.DoFn.TimerParam(TIMER),
                     counter=beam.DoFn.StateParam(COUNTER),
                     next_schedule=beam.DoFn.StateParam(SCHEDULED_TIMESTAMP)):
        print("TICKTICK")
        [count_so_far] = counter.read()
        [fired_at] = next_schedule.read()
        following_tick = fired_at + 60
        next_schedule.clear()
        next_schedule.add(following_tick)
        counter.clear()
        counter.add(0)
        timer.clear()
        timer.set(following_tick)
        yield {'count': count_so_far, 'timestamp': fired_at}
def test_implicit_payload_builder_with_bytes(self):
    """ImplicitSchemaPayloadBuilder on native-str values produces the expected payload.

    On Python 3 the result must match ``PayloadBase.args``. On Python 2 native
    strings are bytes, so the bytes coder URN is expected for string fields.
    """
    values = PayloadBase.bytes_values
    builder = ImplicitSchemaPayloadBuilder(values)
    result = builder.build()
    if sys.version_info[0] < 3:
        # in python 2.x bytes coder will be inferred
        args = {
            'integer_example': ConfigValue(
                coder_urn=['beam:coder:varint:v1'],
                payload=VarIntCoder().get_impl().encode_nested(
                    values['integer_example'])),
            # bytes_values contains 'boolean', so the built payload has it as
            # well; it was missing from this expectation.
            'boolean': ConfigValue(
                coder_urn=['beam:coder:bool:v1'],
                payload=BooleanCoder().get_impl().encode_nested(
                    values['boolean'])),
            'string_example': ConfigValue(
                coder_urn=['beam:coder:bytes:v1'],
                payload=StrUtf8Coder().get_impl().encode_nested(
                    values['string_example'])),
            'list_of_strings': ConfigValue(
                coder_urn=['beam:coder:iterable:v1', 'beam:coder:bytes:v1'],
                payload=IterableCoder(StrUtf8Coder()).get_impl().encode_nested(
                    values['list_of_strings'])),
            'optional_kv': ConfigValue(
                coder_urn=['beam:coder:kv:v1', 'beam:coder:bytes:v1',
                           'beam:coder:double:v1'],
                payload=TupleCoder([StrUtf8Coder(), FloatCoder()])
                .get_impl().encode_nested(values['optional_kv'])),
        }
        expected = get_payload(args)
    else:
        expected = get_payload(PayloadBase.args)
    self.assertEqual(result, expected)
class GenerateRecords(beam.DoFn):
    """Emits 'value' from a processing-time timer until total_records is reached."""
    EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.REAL_TIME)
    COUNT_STATE = CombiningValueStateSpec('count_state', VarIntCoder(),
                                          CountCombineFn())

    def __init__(self, frequency, total_records):
        self.total_records = total_records
        self.frequency = frequency

    def process(self, element, emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
        # Processing time timers should be set on ABSOLUTE TIME.
        emit_timer.set(self.frequency)
        yield element[1]

    @on_timer(EMIT_TIMER)
    def emit_values(self,
                    emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
                    count_state=beam.DoFn.StateParam(COUNT_STATE)):
        emitted = count_state.read() or 0
        if emitted != self.total_records:
            count_state.add(1)
            # Processing time timers should be set on ABSOLUTE TIME.
            emit_timer.set(emitted + 1 + self.frequency)
            yield 'value'
class IndexAssigningStatefulDoFn(DoFn):
    """Pairs each value with a zero-based running index kept in INDEX_STATE."""
    INDEX_STATE = BagStateSpec('index', VarIntCoder())

    def process(self, element, state=DoFn.StateParam(INDEX_STATE)):
        _, value = element
        # The bag stores at most one integer: the next index to hand out.
        [current_index] = list(state.read()) or [0]
        yield (value, current_index)
        state.clear()
        state.add(current_index + 1)
class TestStatefulDoFn(DoFn):
    """An example stateful DoFn with state and timers.

    None of the methods use their state/timer parameters; each body only
    yields a fixed marker (or the input element). What differs between the
    methods is which state specs and timers they declare as parameters.
    Presumably the class is used to test how those declarations are
    discovered and validated; TODO confirm against the calling tests.
    """
    BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
    BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
    EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)
    # Handled by on_expiry_family, which also receives the dynamic timer tag.
    EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family', TimeDomain.WATERMARK)

    def process(
        self,
        element,
        t=DoFn.TimestampParam,
        buffer_1=DoFn.StateParam(BUFFER_STATE_1),
        buffer_2=DoFn.StateParam(BUFFER_STATE_2),
        timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
        timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
        dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
        # Pass-through: the input element is emitted unchanged.
        yield element

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(
        self,
        window=DoFn.WindowParam,
        timestamp=DoFn.TimestampParam,
        key=DoFn.KeyParam,
        buffer=DoFn.StateParam(BUFFER_STATE_1),
        timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
        timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
        timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        yield 'expired1'

    @on_timer(EXPIRY_TIMER_2)
    def on_expiry_2(
        self,
        buffer=DoFn.StateParam(BUFFER_STATE_2),
        timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
        timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        yield 'expired2'

    @on_timer(EXPIRY_TIMER_3)
    def on_expiry_3(
        self,
        buffer_1=DoFn.StateParam(BUFFER_STATE_1),
        buffer_2=DoFn.StateParam(BUFFER_STATE_2),
        timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        yield 'expired3'

    @on_timer(EXPIRY_TIMER_FAMILY)
    def on_expiry_family(
        self,
        dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY),
        dynamic_timer_tag=DoFn.DynamicTimerTagParam):
        # Emits the tag of the timer that fired alongside a fixed marker.
        yield (dynamic_timer_tag, 'expired_dynamic_timer')
class SetStatefulDoFn(beam.DoFn):
    """Adds each value to a set state and emits the sum of the set's contents."""
    SET_STATE = SetStateSpec('buffer', VarIntCoder())

    def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)):
        _, value = element
        set_state.add(value)
        yield sum(set_state.read())
class StatefulPrintDoFn(beam.DoFn):
    """Logs every element with the step name and a running counter."""
    COUNTER_SPEC = ReadModifyWriteStateSpec('counter', VarIntCoder())

    def __init__(self, step_name):
        self._step_name = step_name

    def process(self, element, counter=beam.DoFn.StateParam(COUNTER_SPEC)):
        seen_so_far = counter.read() or 0  # unset state reads as falsy -> 0
        logging.info('Print [%s] (counter:%d): %s',
                     self._step_name, seen_so_far, element)
        counter.write(seen_so_far + 1)
def __init__(self, start, stop=None, elements_per_period=None,
             max_read_time=None, expansion_service=None):
    """Configure the external (Java) GenerateSequence transform.

    Args:
      start: first value of the sequence (required).
      stop: optional end bound; omitted from the payload when None.
      elements_per_period: optional rate setting; omitted when None.
      max_read_time: optional read-time limit; omitted when None.
      expansion_service: passed through to the base class constructor.
    """
    coder = VarIntCoder()
    # ConfigValue.coder_urn takes a list of URNs (the other payload builders
    # pass one). A bare string is not a list of URNs.
    coder_urn = ['beam:coder:varint:v1']

    def _varint_config(value):
        return ConfigValue(coder_urn=coder_urn, payload=coder.encode(value))

    args = {'start': _varint_config(start)}
    # Only None means "not set"; a truthiness test would drop explicit zeros
    # (e.g. stop=0).
    optional_settings = (
        ('stop', stop),
        ('elements_per_period', elements_per_period),
        ('max_read_time', max_read_time),
    )
    for name, value in optional_settings:
        if value is not None:
            args[name] = _varint_config(value)

    payload = ExternalConfigurationPayload(configuration=args)
    super(GenerateSequence, self).__init__(
        'beam:external:java:generate_sequence:v1',
        payload.SerializeToString(),
        expansion_service)
def expand(self, pbegin):
    """Expand into the external (Java) GenerateSequence transform.

    Args:
      pbegin: the input; must be a ``pvalue.PBegin`` because this transform
        can only be applied at the root of a pipeline.

    Returns:
      The result of applying an ``ExternalTransform`` with URN ``self._urn``
      and a payload holding the varint-encoded settings that were given.

    Raises:
      Exception: if ``pbegin`` is not a ``PBegin``.
    """
    if not isinstance(pbegin, pvalue.PBegin):
        raise Exception("GenerateSequence must be a root transform")

    coder = VarIntCoder()
    coder_urn = ['beam:coder:varint:v1']

    def _varint_config(value):
        return ConfigValue(coder_urn=coder_urn, payload=coder.encode(value))

    args = {'start': _varint_config(self.start)}
    # Skip optional settings only when they are None; a truthiness test would
    # also drop explicit zeros such as stop=0.
    for name, value in (('stop', self.stop),
                        ('elements_per_period', self.elements_per_period),
                        ('max_read_time', self.max_read_time)):
        if value is not None:
            args[name] = _varint_config(value)

    payload = ExternalConfigurationPayload(configuration=args)
    return pbegin.apply(
        ExternalTransform(self._urn, payload.SerializeToString(),
                          self.expansion_service))
def test_continuation_token(self):
    """Reads stay lazy (ContinuationIterable) while backed by continuation
    tokens; after clear() the cached value is a plain list, even after many
    individual appends."""
    fake_backend = self.UnderlyingStateHandler()
    handler = sdk_worker.CachingStateHandler(
        statecache.StateCache(100), fake_backend)
    coder = VarIntCoder()

    state = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))
    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState())

    def get(materialize=True):
        fetched = handler.blocking_get(state, coder.get_impl())
        if not materialize:
            return fetched
        return list(fetched)

    def get_type():
        return type(get(materialize=False))

    def append(*items):
        handler.extend(state, coder.get_impl(), items)

    def clear():
        handler.clear(state)

    fake_backend.set_continuations(True)
    fake_backend.set_values([45, 46, 47], coder)
    with handler.process_instruction_id('bundle', [cache_token]):
        self.assertEqual(get_type(), CachingStateHandler.ContinuationIterable)
        self.assertEqual(get(), [45, 46, 47])
        append(48, 49)
        self.assertEqual(get_type(), CachingStateHandler.ContinuationIterable)
        self.assertEqual(get(), [45, 46, 47, 48, 49])

        clear()
        self.assertEqual(get_type(), list)
        self.assertEqual(get(), [])
        append(1)
        self.assertEqual(get(), [1])
        append(2, 3)
        self.assertEqual(get(), [1, 2, 3])

        clear()
        # One extend() call per element, on purpose.
        for n in range(1000):
            append(n)
        self.assertEqual(get_type(), list)
        self.assertEqual(get(), list(range(1000)))
class SimpleTestSetStatefulDoFn(DoFn):
    """Collects values in a set state; emits them sorted when the timer fires."""
    BUFFER_STATE = SetStateSpec('buffer', VarIntCoder())
    EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

    def process(self,
                element,
                buffer=DoFn.StateParam(BUFFER_STATE),
                timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        _, item = element
        buffer.add(item)
        timer1.set(20)

    @on_timer(EXPIRY_TIMER)
    def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)):
        collected = buffer.read()
        yield sorted(collected)
def from_proto(field_type):
    """
    Build the :class:`Coder` matching a protocol field-type description.

    BIGINT maps to a varint coder; ROW maps to a :class:`RowCoder` whose field
    coders are derived recursively. Any other type name is rejected.

    :param field_type: the protocol representation of the field type
    :return: :class:`Coder`
    :raises ValueError: if the field type is not supported
    """
    type_names = flink_fn_execution_pb2.Schema.TypeName
    kind = field_type.type_name
    if kind == type_names.BIGINT:
        return VarIntCoder()
    if kind == type_names.ROW:
        field_coders = [from_proto(f.type) for f in field_type.row_schema.fields]
        return RowCoder(field_coders)
    raise ValueError("field_type %s is not supported." % field_type)
class SimpleTestStatefulDoFn(DoFn):
    """Buffers values in a combining state and, on expiry, emits them sorted
    and concatenated into one string."""
    BUFFER_STATE = CombiningValueStateSpec(
        'buffer', IterableCoder(VarIntCoder()), ToListCombineFn())
    EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

    def process(self,
                element,
                buffer=DoFn.StateParam(BUFFER_STATE),
                timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        _, item = element
        buffer.add(item)
        timer1.set(20)

    @on_timer(EXPIRY_TIMER)
    def expiry_callback(self,
                        buffer=DoFn.StateParam(BUFFER_STATE),
                        timer=DoFn.TimerParam(EXPIRY_TIMER)):
        ordered = sorted(buffer.read())
        yield ''.join(map(str, ordered))
class BagInStateOutputAfterTimer(beam.DoFn):
    """Stores incoming values in a set state; on the timer, emits each stored
    value paired with a random key in [0, 1000]."""
    SET_STATE = SetStateSpec('buffer', VarIntCoder())
    EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)

    def process(self,
                element,
                set_state=beam.DoFn.StateParam(SET_STATE),
                emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
        _, batch = element
        for item in batch:
            set_state.add(item)
        emit_timer.set(1)

    @on_timer(EMIT_TIMER)
    def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
        return [(random.randint(0, 1000), item) for item in set_state.read()]
def test_append_clear_with_preexisting_state(self):
    """Appends and clears behave correctly on top of pre-existing bag state."""
    state = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))

    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState())

    coder = VarIntCoder()

    underlying_state_handler = self.UnderlyingStateHandler()
    state_cache = statecache.StateCache(100)
    handler = sdk_worker.CachingStateHandler(state_cache,
                                             underlying_state_handler)

    def get():
        # Materialize so the comparisons below check contents, not the
        # concrete iterable type blocking_get happens to return.
        return list(handler.blocking_get(state, coder.get_impl()))

    def append(iterable):
        handler.extend(state, coder.get_impl(), iterable)

    def clear():
        handler.clear(state)

    # Initialize state
    underlying_state_handler.set_value(42, coder)
    with handler.process_instruction_id('bundle', [cache_token]):
        # Append without reading beforehand
        append([43])
        self.assertEqual(get(), [42, 43])
        clear()
        self.assertEqual(get(), [])
        append([44, 45])
        self.assertEqual(get(), [44, 45])
        append((46, 47))
        self.assertEqual(get(), [44, 45, 46, 47])
        clear()
        append(range(1000))
        self.assertEqual(get(), list(range(1000)))
class StatefulBufferingFn(DoFn):
    """Tags each value with the position at which it was first buffered.

    Emits ``('<value>_<index>', 1)``. A value already in the buffer reuses its
    buffered position; a new value is appended and receives the current count
    (read before the count is incremented).
    """
    BUFFER_STATE = BagStateSpec('buffer', StrUtf8Coder())
    COUNT_STATE = userstate.CombiningValueStateSpec(
        'count', VarIntCoder(), CountCombineFn())

    def process(self,
                element,
                buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
                count_state=beam.DoFn.StateParam(COUNT_STATE)):
        _, value = element
        buffered = list(buffer_state.read())
        try:
            index_value = buffered.index(value)
        except ValueError:
            # Only "value not present" is expected here; other errors (e.g.
            # from the state backend) must propagate instead of being hidden.
            buffer_state.add(value)
            index_value = count_state.read()
            count_state.add(1)
        yield ('{}_{}'.format(value, index_value), 1)
class TestStatefulDoFn(DoFn):
    """An example stateful DoFn with state and timers.

    The bodies ignore their state/timer parameters and only yield the input
    element or a fixed marker string. The methods differ only in which state
    specs and timers they declare. Presumably the class exists so tests can
    check how those declarations are collected; TODO confirm with callers.
    """
    BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
    BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
    EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)

    def process(self,
                element,
                t=DoFn.TimestampParam,
                buffer_1=DoFn.StateParam(BUFFER_STATE_1),
                buffer_2=DoFn.StateParam(BUFFER_STATE_2),
                timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
                timer_2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        # Pass-through: the input element is emitted unchanged.
        yield element

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(self,
                    buffer=DoFn.StateParam(BUFFER_STATE_1),
                    timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
                    timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        yield 'expired1'

    @on_timer(EXPIRY_TIMER_2)
    def on_expiry_2(self,
                    buffer=DoFn.StateParam(BUFFER_STATE_2),
                    timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        yield 'expired2'

    @on_timer(EXPIRY_TIMER_3)
    def on_expiry_3(self,
                    buffer_1=DoFn.StateParam(BUFFER_STATE_1),
                    buffer_2=DoFn.StateParam(BUFFER_STATE_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        yield 'expired3'
class SetStateClearingStatefulDoFn(beam.DoFn):
    """Once the set holds five values, replaces them with {100} and arms a
    timer that emits the set's sorted contents."""
    SET_STATE = SetStateSpec('buffer', VarIntCoder())
    EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)

    def process(self,
                element,
                set_state=beam.DoFn.StateParam(SET_STATE),
                emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
        _, value = element
        set_state.add(value)

        snapshot = list(set_state.read())
        if len(snapshot) == 5:
            set_state.clear()
            set_state.add(100)
            emit_timer.set(1)

    @on_timer(EMIT_TIMER)
    def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
        contents = set_state.read()
        yield sorted(contents)
class GeneralTriggerManagerDoFn(DoFn):
    """A trigger manager that supports all windowing / triggering cases.

    This implements a DoFn that manages triggering on a per-key basis. All
    elements for a single key are processed together. Per-key state holds data
    related to all windows.

    Incoming values are stored as (window, TimestampedValue) pairs in
    WINDOW_ELEMENT_PAIRS. Each call to ``_fire_eligible_windows`` drains that
    state, emits windows whose trigger says to fire, and writes back elements
    that must be retained (not finished, and not discarded after firing).
    """

    # TODO(BEAM-12026) Add support for Global and custom window fns.
    KNOWN_WINDOWS = SetStateSpec('known_windows', IntervalWindowCoder())
    FINISHED_WINDOWS = SetStateSpec('finished_windows', IntervalWindowCoder())
    # Both "last known" states combine with max, i.e. keep the largest value.
    LAST_KNOWN_TIME = CombiningValueStateSpec('last_known_time', combine_fn=max)
    LAST_KNOWN_WATERMARK = CombiningValueStateSpec(
        'last_known_watermark', combine_fn=max)

    # TODO(pabloem) What's the coder for the elements/keys here?
    WINDOW_ELEMENT_PAIRS = BagStateSpec(
        'all_elements', TupleCoder([IntervalWindowCoder(), PickleCoder()]))
    WINDOW_TAG_VALUES = BagStateSpec(
        'per_window_per_tag_value_state',
        TupleCoder([IntervalWindowCoder(), StrUtf8Coder(), VarIntCoder()]))

    PROCESSING_TIME_TIMER = TimerSpec(
        'processing_time_timer', TimeDomain.REAL_TIME)
    WATERMARK_TIMER = TimerSpec('watermark_timer', TimeDomain.WATERMARK)

    def __init__(self, windowing: Windowing):
        self.windowing = windowing
        # Only session windows are merging. Other windows are non-merging.
        self.merging_windows = self.windowing.windowfn.is_merging()

    def process(
        self,
        element: typing.Tuple[
            K, typing.Iterable[windowed_value.WindowedValue]],
        all_elements: BagRuntimeState = DoFn.StateParam(
            WINDOW_ELEMENT_PAIRS),  # type: ignore
        latest_processing_time: AccumulatingRuntimeState = DoFn.StateParam(
            LAST_KNOWN_TIME),  # type: ignore
        latest_watermark: AccumulatingRuntimeState = DoFn.StateParam(  # type: ignore
            LAST_KNOWN_WATERMARK),
        window_tag_values: BagRuntimeState = DoFn.StateParam(
            WINDOW_TAG_VALUES),  # type: ignore
        windows_state: SetRuntimeState = DoFn.StateParam(
            KNOWN_WINDOWS),  # type: ignore
        finished_windows_state: SetRuntimeState = DoFn.StateParam(  # type: ignore
            FINISHED_WINDOWS),
        processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
        watermark_timer=DoFn.TimerParam(WATERMARK_TIMER),
        *args,
        **kwargs):
        """Buffer a key's values per window, update triggers, and fire.

        Returns the generator produced by ``_fire_eligible_windows``, limited
        to the windows touched by this element.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        key, windowed_values = element
        watermark = read_watermark(latest_watermark)

        # Group incoming values by window as TimestampedValues.
        windows_to_elements = collections.defaultdict(list)
        for wv in windowed_values:
            for window in wv.windows:
                # ignore expired windows
                if watermark > window.end + self.windowing.allowed_lateness:
                    continue
                # NOTE(review): reads FINISHED_WINDOWS once per (value, window)
                # pair; it is not modified in this loop, so it could be read
                # once up front.
                if window in finished_windows_state.read():
                    continue
                windows_to_elements[window].append(
                    TimestampedValue(wv.value, wv.timestamp))

        # Processing merging of windows
        if self.merging_windows:
            old_windows = set(windows_state.read())
            all_windows = old_windows.union(list(windows_to_elements))
            if all_windows != old_windows:
                merge_context = TriggerMergeContext(
                    all_windows, context, self.windowing)
                self.windowing.windowfn.merge(merge_context)

                merged_windows_to_elements = collections.defaultdict(list)
                for window, values in windows_to_elements.items():
                    # Follow merged_away links until reaching a window that was
                    # not merged away (presumably the surviving merged window).
                    while window in merge_context.merged_away:
                        window = merge_context.merged_away[window]
                    merged_windows_to_elements[window].extend(values)
                windows_to_elements = merged_windows_to_elements

            for w in windows_to_elements:
                windows_state.add(w)
        # Done processing merging of windows

        seen_windows = set()
        for w in windows_to_elements:
            window_context = context.for_window(w)
            seen_windows.add(w)
            for value_w_timestamp in windows_to_elements[w]:
                _LOGGER.debug(value_w_timestamp)
                all_elements.add((w, value_w_timestamp))
                # NOTE(review): passes the whole `windowed_values` batch rather
                # than `value_w_timestamp`; harmless for triggers that ignore
                # the element argument, but worth confirming.
                self.windowing.triggerfn.on_element(
                    windowed_values, w, window_context)

        return self._fire_eligible_windows(
            key, TimeDomain.WATERMARK, watermark, None, context, seen_windows)

    def _fire_eligible_windows(
        self,
        key: K,
        time_domain,
        timestamp: Timestamp,
        timer_tag: typing.Optional[str],
        context: 'FnRunnerStatefulTriggerContext',
        windows_of_interest: typing.Optional[
            typing.Set[BoundedWindow]] = None):
        """Yield ``(key, [WindowedValue, ...])`` for every window that fires.

        This is a generator: nothing below, including clearing the buffered
        element state, happens until the caller iterates the result.

        Args:
          key: the key whose windows are evaluated.
          time_domain: the time domain passed to ``should_fire``.
          timestamp: the time passed to ``should_fire`` / ``on_fire``.
          timer_tag: only used for debug logging here.
          context: the per-key trigger context holding the state handles.
          windows_of_interest: if given, only these windows may fire; None
            means every buffered window is considered.
        """
        windows_to_elements = context.windows_to_elements_map()
        # Drain the buffer; elements to keep are re-added at the end.
        context.all_elements_state.clear()

        fired_windows = set()
        _LOGGER.debug(
            '%s - tag %s - timestamp %s', time_domain, timer_tag, timestamp)
        for w, elems in windows_to_elements.items():
            if windows_of_interest is not None and w not in windows_of_interest:
                # windows_of_interest=None means that we care about all windows.
                # If we care only about some windows, and this window is not one of
                # them, then we do not intend to fire this window.
                continue
            window_context = context.for_window(w)
            if self.windowing.triggerfn.should_fire(time_domain, timestamp, w,
                                                    window_context):
                finished = self.windowing.triggerfn.on_fire(
                    timestamp, w, window_context)
                _LOGGER.debug('Firing on window %s. Finished: %s', w, finished)
                fired_windows.add(w)
                if finished:
                    context.finished_windows_state.add(w)
                # TODO(pabloem): Format the output: e.g. pane info
                elems = [
                    WindowedValue(e.value, e.timestamp, (w, )) for e in elems
                ]
                yield (key, elems)

        finished_windows: typing.Set[BoundedWindow] = set(
            context.finished_windows_state.read())
        # Add elements that were not fired back into state.
        for w, elems in windows_to_elements.items():
            for e in elems:
                # Drop elements of finished windows, and of windows that just
                # fired when the accumulation mode discards fired panes.
                if (w in finished_windows or
                    (w in fired_windows and
                     self.windowing.accumulation_mode == AccumulationMode.DISCARDING)):
                    continue
                context.all_elements_state.add((w, e))

    @on_timer(PROCESSING_TIME_TIMER)
    def processing_time_trigger(
        self,
        key=DoFn.KeyParam,
        timer_tag=DoFn.DynamicTimerTagParam,
        timestamp=DoFn.TimestampParam,
        latest_processing_time=DoFn.StateParam(LAST_KNOWN_TIME),
        all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
        processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
        window_tag_values: BagRuntimeState = DoFn.StateParam(
            WINDOW_TAG_VALUES),  # type: ignore
        finished_windows_state: SetRuntimeState = DoFn.StateParam(  # type: ignore
            FINISHED_WINDOWS),
        watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Evaluate all buffered windows in the REAL_TIME domain."""
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=None,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(
            key, TimeDomain.REAL_TIME, timestamp, timer_tag, context)
        # `result` is a lazy generator, so this add runs BEFORE any of the
        # firing logic above executes. NOTE(review): confirm that is intended.
        latest_processing_time.add(timestamp)
        return result

    @on_timer(WATERMARK_TIMER)
    def watermark_trigger(
        self,
        key=DoFn.KeyParam,
        timer_tag=DoFn.DynamicTimerTagParam,
        timestamp=DoFn.TimestampParam,
        latest_watermark=DoFn.StateParam(LAST_KNOWN_WATERMARK),
        all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
        processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
        window_tag_values: BagRuntimeState = DoFn.StateParam(
            WINDOW_TAG_VALUES),  # type: ignore
        finished_windows_state: SetRuntimeState = DoFn.StateParam(  # type: ignore
            FINISHED_WINDOWS),
        watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Evaluate all buffered windows in the WATERMARK domain."""
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=None,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(
            key, TimeDomain.WATERMARK, timestamp, timer_tag, context)
        # Runs before `result` is iterated (lazy generator); see the note in
        # processing_time_trigger.
        latest_watermark.add(timestamp)
        return result
def test_caching(self):
    """Cached reads are reused only while a matching cache token is present.

    The fake backend returns a fresh, incremented counter on every uncached
    read, so the expected numbers show exactly which reads hit the cache.
    """
    coder = VarIntCoder()
    coder_impl = coder.get_impl()

    class FakeUnderlyingState(object):
        """Each get_raw call bumps and returns a counter, encoded as one value."""

        def set_counter(self, n):
            self._counter = n

        def get_raw(self, *args):
            self._counter += 1
            return coder.encode(self._counter), None

        @contextlib.contextmanager
        def process_instruction_id(self, bundle_id):
            yield

    fake_state = FakeUnderlyingState()
    state_cache = statecache.StateCache(100)
    handler = sdk_worker.CachingStateHandler(state_cache, fake_state)

    state1 = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))
    state2 = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state2'))
    side1 = beam_fn_api_pb2.StateKey(
        multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
            transform_id='transform', side_input_id='side1'))
    side2 = beam_fn_api_pb2.StateKey(
        iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
            transform_id='transform', side_input_id='side2'))

    CacheToken = beam_fn_api_pb2.ProcessBundleRequest.CacheToken
    state_token1 = CacheToken(
        token=b'state_token1', user_state=CacheToken.UserState())
    state_token2 = CacheToken(
        token=b'state_token2', user_state=CacheToken.UserState())
    side1_token1 = CacheToken(
        token=b'side1_token1',
        side_input=CacheToken.SideInput(
            transform_id='transform', side_input_id='side1'))
    side1_token2 = CacheToken(
        token=b'side1_token2',
        side_input=CacheToken.SideInput(
            transform_id='transform', side_input_id='side1'))

    def get_as_list(key):
        return list(handler.blocking_get(key, coder_impl))

    # (counter start, bundle id, cache tokens, ordered (key, expected) reads).
    # Read order matters: every uncached read advances the fake counter.
    bundles = [
        (100, 'bundle1', [], [
            (state1, 101),  # uncached
            (state2, 102),  # uncached
            (state1, 101),  # cached on bundle
            (side1, 103),  # uncached
            (side2, 104),  # uncached
        ]),
        (200, 'bundle2', [state_token1, side1_token1], [
            (state1, 201),  # uncached
            (state2, 202),  # uncached
            (state1, 201),  # cached on state token1
            (side1, 203),  # uncached
            (side1, 203),  # cached on side1_token1
            (side2, 204),  # uncached
            (side2, 204),  # cached on bundle
        ]),
        (300, 'bundle3', [state_token1, side1_token1], [
            (state1, 201),  # cached on state token1
            (state2, 202),  # cached on state token1
            (state1, 201),  # cached on state token1
            (side1, 203),  # cached on side1_token1
            (side1, 203),  # cached on side1_token1
            (side2, 301),  # uncached
            (side2, 301),  # cached on bundle
        ]),
        (400, 'bundle4', [state_token2, side1_token1], [
            (state1, 401),  # uncached
            (state2, 402),  # uncached
            (state1, 401),  # cached on state token2
            (side1, 203),  # cached on side1_token1
            (side1, 203),  # cached on side1_token1
            (side2, 403),  # uncached
            (side2, 403),  # cached on bundle
        ]),
        (500, 'bundle5', [state_token2, side1_token2], [
            (state1, 401),  # cached on state token2
            (state2, 402),  # cached on state token2
            (state1, 401),  # cached on state token2
            (side1, 501),  # uncached
            (side1, 501),  # cached on side1_token2
            (side2, 502),  # uncached
            (side2, 502),  # cached on bundle
        ]),
    ]

    for counter_start, bundle_id, tokens, reads in bundles:
        fake_state.set_counter(counter_start)
        with handler.process_instruction_id(bundle_id, tokens):
            for key, expected in reads:
                self.assertEqual(get_as_list(key), [expected])
class PayloadBase(object):
  """Shared fixtures and tests for external-transform payload builders.

  Subclasses supply the two ``get_payload_from_*`` hooks; the ``test_*``
  methods compare what those hooks produce against ``get_payload(args)``.
  """

  # Sample configuration, with text given as unicode literals.
  values = dict(
      integer_example=1,
      boolean=True,
      string_example=u'thing',
      list_of_strings=[u'foo', u'bar'],
      optional_kv=(u'key', 1.1),
      optional_integer=None,
  )

  # Same configuration, with text given as native ``str`` literals.
  bytes_values = dict(
      integer_example=1,
      boolean=True,
      string_example='thing',
      list_of_strings=['foo', 'bar'],
      optional_kv=('key', 1.1),
      optional_integer=None,
  )

  # Expected ConfigValue per field: the coder URN chain together with the
  # nested encoding of the corresponding entry of ``values``.
  # ``optional_integer`` (sample value None) has no entry here.
  args = dict(
      integer_example=ConfigValue(
          coder_urn=['beam:coder:varint:v1'],
          payload=VarIntCoder().get_impl().encode_nested(
              values['integer_example'])),
      boolean=ConfigValue(
          coder_urn=['beam:coder:bool:v1'],
          payload=BooleanCoder().get_impl().encode_nested(
              values['boolean'])),
      string_example=ConfigValue(
          coder_urn=['beam:coder:string_utf8:v1'],
          payload=StrUtf8Coder().get_impl().encode_nested(
              values['string_example'])),
      list_of_strings=ConfigValue(
          coder_urn=['beam:coder:iterable:v1', 'beam:coder:string_utf8:v1'],
          payload=IterableCoder(StrUtf8Coder()).get_impl().encode_nested(
              values['list_of_strings'])),
      optional_kv=ConfigValue(
          coder_urn=[
              'beam:coder:kv:v1',
              'beam:coder:string_utf8:v1',
              'beam:coder:double:v1',
          ],
          payload=TupleCoder([StrUtf8Coder(), FloatCoder()
                              ]).get_impl().encode_nested(
                                  values['optional_kv'])),
  )

  def get_payload_from_typing_hints(self, values):
    """Build an ExternalConfigurationPayload using python typing hints.

    Abstract hook: this base version only raises NotImplementedError.
    """
    raise NotImplementedError

  def get_payload_from_beam_typehints(self, values):
    """Build an ExternalConfigurationPayload using beam typehints.

    Abstract hook: this base version only raises NotImplementedError.
    """
    raise NotImplementedError

  def test_typing_payload_builder(self):
    """Typing-hint builder output for ``values`` equals the expected payload."""
    self.assertEqual(
        self.get_payload_from_typing_hints(self.values),
        get_payload(self.args))

  def test_typing_payload_builder_with_bytes(self):
    """Native-str input still yields the string_utf8 coder.

    Under python 2.x these literals are not unicode, yet the same payload
    is expected.
    """
    self.assertEqual(
        self.get_payload_from_typing_hints(self.bytes_values),
        get_payload(self.args))

  def test_typehints_payload_builder(self):
    """Beam-typehint builder output for ``values`` equals the expected payload."""
    self.assertEqual(
        self.get_payload_from_beam_typehints(self.values),
        get_payload(self.args))

  def test_typehints_payload_builder_with_bytes(self):
    """Native-str input still yields the string_utf8 coder.

    Under python 2.x these literals are not unicode, yet the same payload
    is expected.
    """
    self.assertEqual(
        self.get_payload_from_beam_typehints(self.bytes_values),
        get_payload(self.args))

  def test_optional_error(self):
    """A None value is rejected unless its typehint is Optional."""
    # dict.fromkeys maps every field name to None.
    self.assertRaises(
        RuntimeError,
        self.get_payload_from_typing_hints,
        dict.fromkeys(self.values))
def test_metrics(self):
  """Run a stateful DoFn and verify its counter and state-cache metrics.

  For every element, the DoFn reads its bag state, appends to it and
  increments a user counter. After the pipeline finishes, each line of the
  file at ``self.test_metrics_path`` is scanned for the expected metric
  strings. The file is presumably written by the FileReporter configured for
  this test (confirm).
  """
  counter_name = 'elem_counter'
  state_spec = userstate.BagStateSpec('state', VarIntCoder())

  class DoFn(beam.DoFn):
    def __init__(self):
      self.counter = Metrics.counter(self.__class__, counter_name)
      # Pass lazy %-args so the logging framework does the formatting,
      # rather than formatting eagerly with the % operator.
      _LOGGER.info('counter: %s', self.counter.metric_name)

    def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
      # Trigger materialization
      list(state.read())
      state.add(1)
      self.counter.inc()

  options = self.create_options()
  # Test only supports parallelism of 1
  options._all_options['parallelism'] = 1
  # Create multiple bundles to test cache metrics
  options._all_options['max_bundle_size'] = 10
  # NOTE(review): presumably a huge time limit so bundles are cut by size,
  # not by time -- confirm.
  options._all_options['max_bundle_time_millis'] = 95130590130
  experiments = options.view_as(DebugOptions).experiments or []
  experiments.append('state_cache_size=123')
  options.view_as(DebugOptions).experiments = experiments

  with Pipeline(self.get_runner(), options) as p:
    # pylint: disable=expression-not-assigned
    (
        p
        | "create" >> beam.Create(list(range(0, 110)))
        | "mapper" >> beam.Map(lambda x: (x % 10, 'val'))
        | "stateful" >> beam.ParDo(DoFn()))

  lines_expected = {'counter: 110'}
  # NOTE(review): ``streaming`` is not defined in this method; presumably a
  # module- or enclosing-scope flag -- confirm.
  if streaming:
    lines_expected.update([
        # Gauges for the last finished bundle
        'stateful.beam.metric:statecache:capacity: 123',
        # These are off by 10 because the first bundle contains all the keys
        # once. Caching is only initialized after the first bundle. Caching
        # depends on the cache token which is lazily initialized by the
        # Runner's StateRequestHandlers.
        'stateful.beam.metric:statecache:size: 10',
        'stateful.beam.metric:statecache:get: 10',
        'stateful.beam.metric:statecache:miss: 0',
        'stateful.beam.metric:statecache:hit: 10',
        'stateful.beam.metric:statecache:put: 0',
        'stateful.beam.metric:statecache:extend: 10',
        'stateful.beam.metric:statecache:evict: 0',
        # Counters
        # (total of get/hit will be off by 10 due to the caching
        # only getting initialized after the first bundle.
        # Caching depends on the cache token which is lazily
        # initialized by the Runner's StateRequestHandlers).
        'stateful.beam.metric:statecache:get_total: 100',
        'stateful.beam.metric:statecache:miss_total: 10',
        'stateful.beam.metric:statecache:hit_total: 90',
        'stateful.beam.metric:statecache:put_total: 10',
        'stateful.beam.metric:statecache:extend_total: 100',
        'stateful.beam.metric:statecache:evict_total: 0',
    ])
  else:
    # Batch has a different processing model. All values for
    # a key are processed at once.
    lines_expected.update([
        # Gauges
        'stateful).beam.metric:statecache:capacity: 123',
        # For the first key, the cache token will not be set yet.
        # It's lazily initialized after first access in StateRequestHandlers
        'stateful).beam.metric:statecache:size: 9',
        # We have 11 here because there are 110 / 10 elements per key
        'stateful).beam.metric:statecache:get: 11',
        'stateful).beam.metric:statecache:miss: 1',
        'stateful).beam.metric:statecache:hit: 10',
        # State is flushed back once per key
        'stateful).beam.metric:statecache:put: 1',
        'stateful).beam.metric:statecache:extend: 1',
        'stateful).beam.metric:statecache:evict: 0',
        # Counters
        'stateful).beam.metric:statecache:get_total: 99',
        'stateful).beam.metric:statecache:miss_total: 9',
        'stateful).beam.metric:statecache:hit_total: 90',
        'stateful).beam.metric:statecache:put_total: 9',
        'stateful).beam.metric:statecache:extend_total: 9',
        'stateful).beam.metric:statecache:evict_total: 0',
    ])

  # Collect every expected metric string that occurs as a substring of any
  # line in the metrics file. Iterating the file object yields each line up
  # to EOF, so no manual readline() loop is needed.
  with open(self.test_metrics_path, 'r') as f:
    lines_actual = {
        metric_str
        for line in f
        for metric_str in lines_expected
        if metric_str in line
    }
  self.assertSetEqual(lines_actual, lines_expected)