def _get_internal_bag_state(self, name, element_coder, state_type: StateType):
    """Return the internal bag state called ``name`` for the current key.

    A state already held in ``self._internal_state_cache`` under
    ``(name, self._encoded_current_key)`` is returned unchanged. Otherwise a
    new state is built from a ``BagStateSpec`` via ``self._create_bag_state``.
    The newly built state is NOT inserted into the cache by this method.

    :param name: state name.
    :param element_coder: coder for the bag's elements.
    :param state_type: forwarded unchanged to ``_create_bag_state``.
    :return: the cached or newly created internal bag state.
    """
    lookup = (name, self._encoded_current_key)
    found = self._internal_state_cache.get(lookup)
    if found is None:
        bag_spec = userstate.BagStateSpec(name, element_coder)
        found = self._create_bag_state(bag_spec, state_type)
    return found
class SyncFn(beam.DoFn):
    """Buffer values per key and emit them together once ``size`` have arrived.

    Pending values live in a single bag-state entry: a dict mapping each key to
    the list of values buffered for it so far. When a key's list reaches
    ``size`` elements it is emitted as a tuple and removed from the dict.
    """

    STATE = userstate.BagStateSpec('state', coders.PickleCoder())

    def __init__(self, size):
        """Create the DoFn.

        :param size: number of values to collect per key before emitting.
        :raises ValueError: if ``size`` is not positive.
        """
        super().__init__()
        # Validate with an explicit raise rather than ``assert`` so the check
        # still runs under ``python -O``.
        if not size > 0:
            raise ValueError('Must provide a positive size')
        self.size = size

    def process(self, element, state=beam.DoFn.StateParam(STATE)):
        """Add ``element``'s value to its key's buffer; yield it when full.

        :param element: a ``(key, value)`` pair.
        :yields: a tuple of ``size`` values for a key whose buffer filled up.
        """
        key, value = element
        # The bag holds at most one entry: the dict of pending values per key.
        stored = list(state.read())
        cache = stored[0] if stored else {}
        values = cache.get(key, [])
        values.append(value)
        if len(values) == self.size:
            cache.pop(key, None)
            yield tuple(values)
        else:
            cache[key] = values
        # Rewrite the single bag entry (or leave the bag empty).
        state.clear()
        if cache:
            state.add(cache)
def _get_internal_bag_state(self, name, element_coder): cached_state = self._all_internal_states.get((name, self._current_key)) if cached_state is not None: return cached_state state_spec = userstate.BagStateSpec(name, element_coder) internal_state = self._create_state(state_spec) self._all_internal_states.put((name, self._current_key), internal_state) return internal_state
def _run_pardo_state_timers(self, windowed):
    """Check that bag state plus a watermark timer batch elements correctly.

    Elements are buffered in bag state and flushed either when the buffer
    reaches ``buffer_size`` or when the watermark timer fires. With
    ``windowed`` set, even and odd elements go to different fixed windows, and
    each emitted batch must come from a single window.

    :param windowed: whether to use ``FixedWindows(1)`` instead of the global
        window.
    """
    state_spec = userstate.BagStateSpec('state', beam.coders.StrUtf8Coder())
    timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
    elements = list('abcdefgh')
    buffer_size = 3

    class BufferDoFn(beam.DoFn):
        def process(self,
                    kv,
                    ts=beam.DoFn.TimestampParam,
                    timer=beam.DoFn.TimerParam(timer_spec),
                    state=beam.DoFn.StateParam(state_spec)):
            _, element = kv
            state.add(element)
            buffer = state.read()
            # For real use, we'd keep track of this size separately.
            # Use buffer_size (not a literal) so the flush threshold always
            # matches what is_buffered_correctly asserts.
            if len(list(buffer)) >= buffer_size:
                state.clear()
                yield buffer
            else:
                timer.set(ts + 1)

        @userstate.on_timer(timer_spec)
        def process_timer(self, state=beam.DoFn.StateParam(state_spec)):
            buffer = state.read()
            state.clear()
            yield buffer

    def is_buffered_correctly(actual):
        # Pickling self in the closure for asserts gives errors (only on jenkins).
        self = FnApiRunnerTest('__init__')
        # Actual should be a grouping of the inputs into batches of size
        # at most buffer_size, but the actual batching is nondeterministic
        # based on ordering and trigger firing timing.
        self.assertEqual(sorted(sum((list(b) for b in actual), [])), elements)
        self.assertEqual(max(len(list(buffer)) for buffer in actual), buffer_size)
        if windowed:
            # Elements were assigned to windows based on their parity.
            # Assert that each grouping consists of elements belonging to the
            # same window to ensure states and timers were properly partitioned.
            for b in actual:
                parity = {ord(e) % 2 for e in b}
                self.assertEqual(1, len(parity), b)

    with self.create_pipeline() as p:
        actual = (
            p
            | beam.Create(elements)
            # Send even and odd elements to different windows.
            | beam.Map(lambda e: window.TimestampedValue(e, ord(e) % 2))
            | beam.WindowInto(
                window.FixedWindows(1) if windowed else window.GlobalWindows())
            | beam.Map(lambda x: ('key', x))
            | beam.ParDo(BufferDoFn()))

        assert_that(actual, is_buffered_correctly)
def _get_internal_bag_state(self, name, element_coder): cached_state = self._internal_state_cache.get( (name, self._encoded_current_key)) if cached_state is not None: return cached_state # The created internal state would not be put into the internal state cache # at once. The internal state cache is only updated when the current key changes. # The reason is that the state cache size may be smaller that the count of activated # state (i.e. the state with current key). state_spec = userstate.BagStateSpec(name, element_coder) internal_state = self._create_bag_state(state_spec) return internal_state
class _StatefulJobOutputsFn(beam.DoFn):
    """Accumulate job task outputs per aggregation key and emit when complete.

    Each payload's ``output`` list is appended to the entry for
    ``payload[level]``. Once that entry has been seen the expected number of
    times, its accumulated outputs are yielded and the entry is dropped from
    state.
    """

    STATE = userstate.BagStateSpec('state', coders.PickleCoder())

    def process(self, element, level, state=beam.DoFn.StateParam(STATE)):
        """Record one task's outputs and yield any aggregation that completed.

        :param element: a ``(beam_key, payload)`` pair (see example below).
        :param level: a ``JobAggregateLevel`` member in ``STATEFUL``.
        :yields: the accumulated output list of each completed entry.
        :raises NotImplementedError: if ``level`` is neither ``JOB`` nor
            ``GRAPH``.
        """
        assert level in JobAggregateLevel.STATEFUL
        # example payload structure...
        # {
        #   'source': Any
        #   'graphid': 0,
        #   'jobtasks': {0: 3, 1: 3},
        #   'jobid': 0,
        #   'taskid': 2,
        #   'output': [
        #     '/tmp/job-0_output-0.task-2.ext',
        #     '/tmp/job-0_output-1.task-2.ext',
        #   ],
        # }
        _, payload = element

        # There are two values we will track that differ depending on the
        # aggregation type/level desired.
        #
        # - key  : aggregation per-unique value
        # - size : total number of times expected to see `key`
        key = payload[level]
        if level == JobAggregateLevel.JOB:
            # str(key) is to deal with json making all dict keys strings
            size = payload['jobtasks'][str(key)]
        elif level == JobAggregateLevel.GRAPH:
            size = sum(payload['jobtasks'].values())
        else:
            raise NotImplementedError

        # The bag stores (key, (seen, outputs)) pairs.
        cache = dict(state.read())
        seen, data = cache.get(key, (0, []))
        seen += 1
        data.extend(payload['output'])
        cache[key] = (seen, data)

        # Rewrite the bag: complete entries are yielded and dropped, the rest
        # are carried forward. ``cache`` must not be mutated inside this loop;
        # popping from it here would raise "RuntimeError: dictionary changed
        # size during iteration" on the next step.
        state.clear()
        for k, v in cache.items():
            # NOTE(review): ``size`` was computed for ``key`` but is compared
            # against every entry; this assumes the bag holds only one key or
            # that all keys share the same size -- confirm.
            if v[0] == size:
                yield v[1]
            else:
                state.add((k, v))
def test_pardo_state_timers(self):
    """Check that bag state plus a watermark timer batch elements correctly.

    Elements for a single key are buffered in bag state and flushed either
    when the buffer reaches ``buffer_size`` or when the watermark timer fires.
    """
    state_spec = userstate.BagStateSpec('state', beam.coders.StrUtf8Coder())
    timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
    elements = list('abcdefgh')
    buffer_size = 3

    class BufferDoFn(beam.DoFn):
        def process(self,
                    kv,
                    ts=beam.DoFn.TimestampParam,
                    timer=beam.DoFn.TimerParam(timer_spec),
                    state=beam.DoFn.StateParam(state_spec)):
            _, element = kv
            state.add(element)
            buffer = state.read()
            # For real use, we'd keep track of this size separately.
            # Use buffer_size (not a literal) so the flush threshold always
            # matches what is_buffered_correctly asserts.
            if len(list(buffer)) >= buffer_size:
                state.clear()
                yield buffer
            else:
                timer.set(ts + 1)

        @userstate.on_timer(timer_spec)
        def process_timer(self, state=beam.DoFn.StateParam(state_spec)):
            buffer = state.read()
            state.clear()
            yield buffer

    def is_buffered_correctly(actual):
        # Pickling self in the closure for asserts gives errors (only on jenkins).
        self = FnApiRunnerTest('__init__')
        # Actual should be a grouping of the inputs into batches of size
        # at most buffer_size, but the actual batching is nondeterministic
        # based on ordering and trigger firing timing.
        self.assertEqual(sorted(sum((list(b) for b in actual), [])), elements)
        self.assertEqual(max(len(list(buffer)) for buffer in actual), buffer_size)

    with self.create_pipeline() as p:
        actual = (
            p
            | beam.Create(elements)
            | beam.Map(lambda x: ('key', x))
            | beam.ParDo(BufferDoFn()))

        assert_that(actual, is_buffered_correctly)
def _get_internal_bag_state(self, name, namespace, element_coder, ttl_config): encoded_namespace = self._encode_namespace(namespace) cached_state = self._internal_state_cache.get( (name, self._encoded_current_key, encoded_namespace)) if cached_state is not None: return cached_state # The created internal state would not be put into the internal state cache # at once. The internal state cache is only updated when the current key changes. # The reason is that the state cache size may be smaller that the count of activated # state (i.e. the state with current key). if isinstance(element_coder, FieldCoder): element_coder = FlinkCoder(element_coder) state_spec = userstate.BagStateSpec(name, element_coder) internal_state = self._create_bag_state(state_spec, encoded_namespace, ttl_config) return internal_state
def _create_deduplicate_fn(self):
    """Build a stateful DoFn that passes through only the first element per key.

    A ``'seen'`` bag state marks a key as already emitted. If
    ``self.processing_time_duration`` / ``self.event_time_duration`` are not
    None, a processing-time / event-time timer is armed for that duration.
    Either timer clears the marker when it fires, so the key can be emitted
    again afterwards.

    :return: a ``DeduplicationFn`` instance.
    """
    proc_ttl = self.processing_time_duration
    event_ttl = self.event_time_duration
    processing_timer_spec = userstate.TimerSpec(
        'processing_timer', TimeDomain.REAL_TIME)
    event_timer_spec = userstate.TimerSpec('event_timer', TimeDomain.WATERMARK)
    state_spec = userstate.BagStateSpec('seen', BooleanCoder())

    class DeduplicationFn(core.DoFn):
        def process(
                self,
                kv,
                ts=core.DoFn.TimestampParam,
                seen_state=core.DoFn.StateParam(state_spec),
                processing_timer=core.DoFn.TimerParam(processing_timer_spec),
                event_timer=core.DoFn.TimerParam(event_timer_spec)):
            already_seen = True in seen_state.read()
            if not already_seen:
                # Arm the expiry timers before marking the key as seen.
                if proc_ttl is not None:
                    processing_timer.set(timestamp.Timestamp.now() + proc_ttl)
                if event_ttl is not None:
                    event_timer.set(ts + event_ttl)
                seen_state.add(True)
                yield kv

        @userstate.on_timer(processing_timer_spec)
        def process_processing_timer(
                self, seen_state=core.DoFn.StateParam(state_spec)):
            seen_state.clear()

        @userstate.on_timer(event_timer_spec)
        def process_event_timer(
                self, seen_state=core.DoFn.StateParam(state_spec)):
            seen_state.clear()

    return DeduplicationFn()
def test_metrics(self):
    """Run a stateful DoFn and verify its counter and state-cache metrics.

    The DoFn increments a user counter and reads/extends bag state for every
    element. The test then checks that the expected counter value and
    state-cache gauges/counters were written by the FileReporter to
    ``self.test_metrics_path``.
    """
    counter_name = 'elem_counter'
    state_spec = userstate.BagStateSpec('state', VarIntCoder())

    class DoFn(beam.DoFn):
        def __init__(self):
            self.counter = Metrics.counter(self.__class__, counter_name)
            # Pass args lazily instead of pre-formatting with '%'.
            _LOGGER.info('counter: %s', self.counter.metric_name)

        def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
            # Trigger materialization
            list(state.read())
            state.add(1)
            self.counter.inc()

    options = self.create_options()
    # Test only supports parallelism of 1
    options._all_options['parallelism'] = 1
    # Create multiple bundles to test cache metrics
    options._all_options['max_bundle_size'] = 10
    options._all_options['max_bundle_time_millis'] = 95130590130
    experiments = options.view_as(DebugOptions).experiments or []
    experiments.append('state_cache_size=123')
    options.view_as(DebugOptions).experiments = experiments

    with Pipeline(self.get_runner(), options) as p:
        # pylint: disable=expression-not-assigned
        (
            p
            | "create" >> beam.Create(list(range(0, 110)))
            | "mapper" >> beam.Map(lambda x: (x % 10, 'val'))
            | "stateful" >> beam.ParDo(DoFn()))

    lines_expected = {'counter: 110'}
    # NOTE(review): `streaming` is not defined in this block; presumably a
    # module-level flag set by the test harness -- confirm.
    if streaming:
        lines_expected.update([
            # Gauges for the last finished bundle
            'stateful.beam.metric:statecache:capacity: 123',
            # These are off by 10 because the first bundle contains all the
            # keys once. Caching is only initialized after the first bundle.
            # Caching depends on the cache token which is lazily initialized
            # by the Runner's StateRequestHandlers.
            'stateful.beam.metric:statecache:size: 10',
            'stateful.beam.metric:statecache:get: 10',
            'stateful.beam.metric:statecache:miss: 0',
            'stateful.beam.metric:statecache:hit: 10',
            'stateful.beam.metric:statecache:put: 0',
            'stateful.beam.metric:statecache:extend: 10',
            'stateful.beam.metric:statecache:evict: 0',
            # Counters
            # (total of get/hit will be off by 10 due to the caching
            # only getting initialized after the first bundle.
            # Caching depends on the cache token which is lazily
            # initialized by the Runner's StateRequestHandlers).
            'stateful.beam.metric:statecache:get_total: 100',
            'stateful.beam.metric:statecache:miss_total: 10',
            'stateful.beam.metric:statecache:hit_total: 90',
            'stateful.beam.metric:statecache:put_total: 10',
            'stateful.beam.metric:statecache:extend_total: 100',
            'stateful.beam.metric:statecache:evict_total: 0',
        ])
    else:
        # Batch has a different processing model. All values for
        # a key are processed at once.
        lines_expected.update([
            # Gauges
            'stateful).beam.metric:statecache:capacity: 123',
            # For the first key, the cache token will not be set yet.
            # It's lazily initialized after first access in StateRequestHandlers
            'stateful).beam.metric:statecache:size: 9',
            # We have 11 here because there are 110 / 10 elements per key
            'stateful).beam.metric:statecache:get: 11',
            'stateful).beam.metric:statecache:miss: 1',
            'stateful).beam.metric:statecache:hit: 10',
            # State is flushed back once per key
            'stateful).beam.metric:statecache:put: 1',
            'stateful).beam.metric:statecache:extend: 1',
            'stateful).beam.metric:statecache:evict: 0',
            # Counters
            'stateful).beam.metric:statecache:get_total: 99',
            'stateful).beam.metric:statecache:miss_total: 9',
            'stateful).beam.metric:statecache:hit_total: 90',
            'stateful).beam.metric:statecache:put_total: 9',
            'stateful).beam.metric:statecache:extend_total: 9',
            'stateful).beam.metric:statecache:evict_total: 0',
        ])

    # Collect every expected metric string that appears somewhere in the
    # reporter's output file.
    lines_actual = set()
    with open(self.test_metrics_path, 'r') as f:
        for line in f:
            lines_actual.update(
                metric_str for metric_str in lines_expected
                if metric_str in line)
    self.assertSetEqual(lines_actual, lines_expected)
class JoinFn(beam.DoFn):
    """Join auctions with persons by person id, emitting one pair at a time.

    A person may submit any number of auctions. So once a person record is
    seen, it is kept in state until the expiry timer fires, in order to match
    future auctions by that person. Each auction is associated with at most
    one person, so auctions only need to be buffered in state until the
    corresponding person record arrives -- which may already have happened.
    """

    AUCTIONS = 'auctions_state'
    PERSON = 'person_state'
    PERSON_EXPIRING = 'person_state_expiring'

    auction_spec = userstate.BagStateSpec(AUCTIONS, nexmark_model.Auction.CODER)
    person_spec = userstate.ReadModifyWriteStateSpec(
        PERSON, nexmark_model.Person.CODER)
    person_timer_spec = userstate.TimerSpec(
        PERSON_EXPIRING, userstate.TimeDomain.WATERMARK)

    def __init__(self, max_auction_wait_time):
        # Added to a person's date_time to schedule the expiry timer.
        self.max_auction_wait_time = max_auction_wait_time

    def process(
            self,
            element,
            auction_state=beam.DoFn.StateParam(auction_spec),
            person_state=beam.DoFn.StateParam(person_spec),
            person_timer=beam.DoFn.TimerParam(person_timer_spec)):
        """Yield ``(auction, person)`` pairs for one person id's co-grouped input.

        :param element: ``(person_id, group)`` where ``group`` maps
            ``nexmark_query_util.AUCTION_TAG`` / ``PERSON_TAG`` to the
            auctions / persons for that id.
        """
        # extract group with tags from element tuple
        _, group = element

        existing_person = person_state.read()
        if existing_person:
            # the person exists in person_state for this person id
            for auction in group[nexmark_query_util.AUCTION_TAG]:
                yield auction, existing_person
            return

        new_person = None
        for person in group[nexmark_query_util.PERSON_TAG]:
            if new_person:
                # Only the first person record per id is used.
                logging.error(
                    'two new person with same key: %s and %s', person, new_person)
                continue
            new_person = person

            # read all pending auctions for this person id, output and flush it
            pending_auctions = auction_state.read()
            if pending_auctions:
                for pending_auction in pending_auctions:
                    yield pending_auction, new_person
                auction_state.clear()

            # output new auction for this person id
            for auction in group[nexmark_query_util.AUCTION_TAG]:
                yield auction, new_person

            # remember person for max_auction_wait_time seconds for future
            # auctions
            person_state.write(new_person)
            person_timer.set(new_person.date_time + self.max_auction_wait_time)

        # we are done if we have seen a new person
        if new_person:
            return

        # remember auction until we see person
        for auction in group[nexmark_query_util.AUCTION_TAG]:
            auction_state.add(auction)

    @on_timer(person_timer_spec)
    def expiry(self, person_state=beam.DoFn.StateParam(person_spec)):
        """Forget the stored person once the expiry timer fires."""
        person_state.clear()