Example #1
 def _get_internal_bag_state(self, name, element_coder, state_type: StateType):
     cached_state = self._internal_state_cache.get((name, self._encoded_current_key))
     if cached_state is not None:
         return cached_state
     state_spec = userstate.BagStateSpec(name, element_coder)
     internal_state = self._create_bag_state(state_spec, state_type)
     return internal_state
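Note that, unlike Example #3 below, the freshly created state is deliberately not written back into the cache here. The comment carried by Examples #5 and #8 explains why: the cache is only updated when the current key changes, since the cache may be smaller than the number of states active under the current key.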
Example #2
class SyncFn(beam.DoFn):

    STATE = userstate.BagStateSpec('state', coders.PickleCoder())

    def __init__(self, size):
        assert size > 0, 'Must provide a positive size'
        self.size = size

    def process(self, element, state=beam.DoFn.StateParam(STATE)):
        key, value = element

        cache = list(state.read())
        if cache:
            cache = cache[0]
        else:
            cache = {}

        values = cache.get(key, [])
        values.append(value)

        if len(values) == self.size:
            if key in cache:
                del cache[key]
            yield tuple(values)
        else:
            cache[key] = values

        state.clear()
        if cache:
            state.add(cache)
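A minimal way to exercise SyncFn on the direct runner; the key and values below are made up for illustration (stateful DoFns require keyed input):

import apache_beam as beam

with beam.Pipeline() as p:
    (p
     | beam.Create([('k', 1), ('k', 2), ('k', 3)])
     | beam.ParDo(SyncFn(3))
     | beam.Map(print))  # prints (1, 2, 3) once the third value arrives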
Example #3
 def _get_internal_bag_state(self, name, element_coder):
     cached_state = self._all_internal_states.get((name, self._current_key))
     if cached_state is not None:
         return cached_state
     state_spec = userstate.BagStateSpec(name, element_coder)
     internal_state = self._create_state(state_spec)
     self._all_internal_states.put((name, self._current_key), internal_state)
     return internal_state
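For context, a self-contained sketch of the kind of get/put cache these helpers consult. LruStateCache and its eviction policy are illustrative assumptions, not the real runner class; only the (name, key) cache-key shape is taken from the examples above:

from collections import OrderedDict

class LruStateCache:
    """Hypothetical (name, encoded_key) -> state cache with LRU eviction."""

    def __init__(self, max_entries):
        self._max_entries = max_entries
        self._entries = OrderedDict()

    def get(self, cache_key):
        state = self._entries.get(cache_key)
        if state is not None:
            self._entries.move_to_end(cache_key)  # mark as recently used
        return state

    def put(self, cache_key, state):
        self._entries[cache_key] = state
        self._entries.move_to_end(cache_key)
        if len(self._entries) > self._max_entries:
            self._entries.popitem(last=False)  # evict the least recently used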
Example #4
  def _run_pardo_state_timers(self, windowed):
    state_spec = userstate.BagStateSpec('state', beam.coders.StrUtf8Coder())
    timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
    elements = list('abcdefgh')
    buffer_size = 3

    class BufferDoFn(beam.DoFn):
      def process(self,
                  kv,
                  ts=beam.DoFn.TimestampParam,
                  timer=beam.DoFn.TimerParam(timer_spec),
                  state=beam.DoFn.StateParam(state_spec)):
        _, element = kv
        state.add(element)
        buffer = state.read()
        # For real use, we'd keep track of this size separately.
        if len(list(buffer)) >= buffer_size:
          state.clear()
          yield buffer
        else:
          timer.set(ts + 1)

      @userstate.on_timer(timer_spec)
      def process_timer(self, state=beam.DoFn.StateParam(state_spec)):
        buffer = state.read()
        state.clear()
        yield buffer

    def is_buffered_correctly(actual):
      # Pickling self in the closure for asserts gives errors (only on Jenkins).
      self = FnApiRunnerTest('__init__')
      # Actual should be a grouping of the inputs into batches of size
      # at most buffer_size, but the actual batching is nondeterministic
      # based on ordering and trigger firing timing.
      self.assertEqual(sorted(sum((list(b) for b in actual), [])), elements)
      self.assertEqual(max(len(list(buffer)) for buffer in actual), buffer_size)
      if windowed:
        # Elements were assigned to windows based on their parity.
        # Assert that each grouping consists of elements belonging to the
        # same window to ensure states and timers were properly partitioned.
        for b in actual:
          parity = set(ord(e) % 2 for e in b)
          self.assertEqual(1, len(parity), b)

    with self.create_pipeline() as p:
      actual = (
          p
          | beam.Create(elements)
          # Send even and odd elements to different windows.
          | beam.Map(lambda e: window.TimestampedValue(e, ord(e) % 2))
          | beam.WindowInto(window.FixedWindows(1) if windowed
                            else window.GlobalWindows())
          | beam.Map(lambda x: ('key', x))
          | beam.ParDo(BufferDoFn()))

      assert_that(actual, is_buffered_correctly)
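The tail of the input is flushed by the timer: whenever fewer than buffer_size elements are buffered, timer.set(ts + 1) (re)arms a watermark timer, and once the input is exhausted the watermark passes ts + 1 and process_timer drains whatever is left in the bag.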
Example #5
 def _get_internal_bag_state(self, name, element_coder):
     cached_state = self._internal_state_cache.get(
         (name, self._encoded_current_key))
     if cached_state is not None:
         return cached_state
     # The newly created internal state is not put into the internal state
     # cache immediately. The cache is only updated when the current key
     # changes, because the state cache size may be smaller than the number
     # of activated states (i.e. the states under the current key).
     state_spec = userstate.BagStateSpec(name, element_coder)
     internal_state = self._create_bag_state(state_spec)
     return internal_state
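A hedged sketch of the key-change hook that the comment describes; set_current_key, _active_states, and the flush logic are assumptions for illustration, not the real API:

def set_current_key(self, new_encoded_key):
    if new_encoded_key == self._encoded_current_key:
        return
    # Only now are the states that were activated under the old key written
    # back into the (bounded) internal state cache.
    for name, state in self._active_states.items():
        self._internal_state_cache.put(
            (name, self._encoded_current_key), state)
    self._active_states.clear()
    self._encoded_current_key = new_encoded_key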
Example #6
class _StatefulJobOutputsFn(beam.DoFn):

    STATE = userstate.BagStateSpec('state', coders.PickleCoder())

    def process(self, element, level, state=beam.DoFn.StateParam(STATE)):
        assert level in JobAggregateLevel.STATEFUL

        # example payload structure...
        # {
        #     'source': Any
        #     'graphid': 0,
        #     'jobtasks': {0: 3, 1: 3},
        #     'jobid': 0,
        #     'taskid': 2,
        #     'output': [
        #         '/tmp/job-0_output-0.task-2.ext',
        #         '/tmp/job-0_output-1.task-2.ext',
        #     ],
        # }
        _, payload = element

        # There are two values we will track that differ depending on the
        # aggregation type/level desired.
        #
        # - key : aggregation per-unique value
        # - size : total number of times expected to see `key`

        key = payload[level]
        if level == JobAggregateLevel.JOB:
            # str(key) is to deal with json making all dict keys strings
            size = payload['jobtasks'][str(key)]
        elif level == JobAggregateLevel.GRAPH:
            size = sum(payload['jobtasks'].values())
        else:
            raise NotImplementedError

        cache = dict(state.read())
        seen, data = cache.get(key, (0, []))
        seen += 1
        data.extend(payload['output'])
        cache[key] = (seen, data)
        state.clear()

        # Iterate over a copy: popping entries while iterating the dict
        # directly would raise a RuntimeError in Python 3.
        for k, v in list(cache.items()):
            # Fire once we have seen `size` occurrences of this key.
            if size == v[0]:
                # cprint('fire-{}: {}'.format(level, k), 'red', attrs=['bold'])
                yield cache.pop(k)[1]
            else:
                state.add((k, v))
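The firing condition is easier to see outside of Beam. A standalone sketch of the same seen/size bookkeeping (a plain dict stands in for the bag state, and track is a made-up name):

cache = {}

def track(key, outputs, size):
    seen, data = cache.get(key, (0, []))
    seen += 1
    data.extend(outputs)
    if seen == size:               # every task for this key has reported in
        cache.pop(key, None)
        return data                # "fire": emit the accumulated outputs
    cache[key] = (seen, data)      # otherwise keep waiting
    return None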
Example #7
    def test_pardo_state_timers(self):
        state_spec = userstate.BagStateSpec('state',
                                            beam.coders.StrUtf8Coder())
        timer_spec = userstate.TimerSpec('timer',
                                         userstate.TimeDomain.WATERMARK)
        elements = list('abcdefgh')
        buffer_size = 3

        class BufferDoFn(beam.DoFn):
            def process(self,
                        kv,
                        ts=beam.DoFn.TimestampParam,
                        timer=beam.DoFn.TimerParam(timer_spec),
                        state=beam.DoFn.StateParam(state_spec)):
                _, element = kv
                state.add(element)
                buffer = state.read()
                # For real use, we'd keep track of this size separately.
                if len(list(buffer)) >= buffer_size:
                    state.clear()
                    yield buffer
                else:
                    timer.set(ts + 1)

            @userstate.on_timer(timer_spec)
            def process_timer(self, state=beam.DoFn.StateParam(state_spec)):
                buffer = state.read()
                state.clear()
                yield buffer

        def is_buffered_correctly(actual):
            # Pickling self in the closure for asserts gives errors (only on Jenkins).
            self = FnApiRunnerTest('__init__')
            # Actual should be a grouping of the inputs into batches of size
            # at most buffer_size, but the actual batching is nondeterministic
            # based on ordering and trigger firing timing.
            self.assertEqual(sorted(sum((list(b) for b in actual), [])),
                             elements)
            self.assertEqual(max(len(list(buffer)) for buffer in actual),
                             buffer_size)

        with self.create_pipeline() as p:
            actual = (p
                      | beam.Create(elements)
                      | beam.Map(lambda x: ('key', x))
                      | beam.ParDo(BufferDoFn()))

            assert_that(actual, is_buffered_correctly)
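This is the unwindowed variant of Example #4: the buffering and timer logic are identical, but without a WindowInto step every element lands in the global window, so there is no per-window partitioning of state to assert on.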
Example #8
 def _get_internal_bag_state(self, name, namespace, element_coder,
                             ttl_config):
     encoded_namespace = self._encode_namespace(namespace)
     cached_state = self._internal_state_cache.get(
         (name, self._encoded_current_key, encoded_namespace))
     if cached_state is not None:
         return cached_state
     # The newly created internal state is not put into the internal state
     # cache immediately. The cache is only updated when the current key
     # changes, because the state cache size may be smaller than the number
     # of activated states (i.e. the states under the current key).
     if isinstance(element_coder, FieldCoder):
         element_coder = FlinkCoder(element_coder)
     state_spec = userstate.BagStateSpec(name, element_coder)
     internal_state = self._create_bag_state(state_spec, encoded_namespace,
                                             ttl_config)
     return internal_state
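Same pattern as Example #5, extended with a state namespace and a TTL config. Note that a FieldCoder is wrapped in a FlinkCoder before the BagStateSpec is built.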
Example #9
    def _create_deduplicate_fn(self):
        processing_timer_spec = userstate.TimerSpec('processing_timer',
                                                    TimeDomain.REAL_TIME)
        event_timer_spec = userstate.TimerSpec('event_timer',
                                               TimeDomain.WATERMARK)
        state_spec = userstate.BagStateSpec('seen', BooleanCoder())
        processing_time_duration = self.processing_time_duration
        event_time_duration = self.event_time_duration

        class DeduplicationFn(core.DoFn):
            def process(
                self,
                kv,
                ts=core.DoFn.TimestampParam,
                seen_state=core.DoFn.StateParam(state_spec),
                processing_timer=core.DoFn.TimerParam(processing_timer_spec),
                event_timer=core.DoFn.TimerParam(event_timer_spec)):
                if True in seen_state.read():
                    return

                if processing_time_duration is not None:
                    processing_timer.set(timestamp.Timestamp.now() +
                                         processing_time_duration)
                if event_time_duration is not None:
                    event_timer.set(ts + event_time_duration)
                seen_state.add(True)
                yield kv

            @userstate.on_timer(processing_timer_spec)
            def process_processing_timer(
                self, seen_state=core.DoFn.StateParam(state_spec)):
                seen_state.clear()

            @userstate.on_timer(event_timer_spec)
            def process_event_timer(
                self, seen_state=core.DoFn.StateParam(state_spec)):
                seen_state.clear()

        return DeduplicationFn()
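This DoFn is essentially how Beam's built-in deduplication transform works. A hedged usage sketch, assuming apache_beam.transforms.deduplicate.DeduplicatePerKey (which wraps a DoFn like the one above):

import apache_beam as beam
from apache_beam.transforms import deduplicate
from apache_beam.utils.timestamp import Duration

with beam.Pipeline() as p:
    (p
     | beam.Create([('k', 'v'), ('k', 'v'), ('k2', 'w')])
     | deduplicate.DeduplicatePerKey(event_time_duration=Duration(seconds=600))
     | beam.Map(print))  # only the first element per key in the window survives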
Example #10
        def test_metrics(self):
            """Run a simple DoFn that increments a counter and verifies state
      caching metrics. Verifies that its expected value is written to a
      temporary file by the FileReporter"""

            counter_name = 'elem_counter'
            state_spec = userstate.BagStateSpec('state', VarIntCoder())

            class DoFn(beam.DoFn):
                def __init__(self):
                    self.counter = Metrics.counter(self.__class__,
                                                   counter_name)
                    _LOGGER.info('counter: %s' % self.counter.metric_name)

                def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
                    # Trigger materialization
                    list(state.read())
                    state.add(1)
                    self.counter.inc()

            options = self.create_options()
            # Test only supports parallelism of 1
            options._all_options['parallelism'] = 1
            # Create multiple bundles to test cache metrics
            options._all_options['max_bundle_size'] = 10
            options._all_options['max_bundle_time_millis'] = 95130590130
            experiments = options.view_as(DebugOptions).experiments or []
            experiments.append('state_cache_size=123')
            options.view_as(DebugOptions).experiments = experiments
            with Pipeline(self.get_runner(), options) as p:
                # pylint: disable=expression-not-assigned
                (p
                 | "create" >> beam.Create(list(range(0, 110)))
                 | "mapper" >> beam.Map(lambda x: (x % 10, 'val'))
                 | "stateful" >> beam.ParDo(DoFn()))

            lines_expected = {'counter: 110'}
            if streaming:
                lines_expected.update([
                    # Gauges for the last finished bundle
                    'stateful.beam.metric:statecache:capacity: 123',
                    # These are off by 10 because the first bundle contains all the keys
                    # once. Caching is only initialized after the first bundle. Caching
                    # depends on the cache token which is lazily initialized by the
                    # Runner's StateRequestHandlers.
                    'stateful.beam.metric:statecache:size: 10',
                    'stateful.beam.metric:statecache:get: 10',
                    'stateful.beam.metric:statecache:miss: 0',
                    'stateful.beam.metric:statecache:hit: 10',
                    'stateful.beam.metric:statecache:put: 0',
                    'stateful.beam.metric:statecache:extend: 10',
                    'stateful.beam.metric:statecache:evict: 0',
                    # Counters
                    # (total of get/hit will be off by 10 due to the caching
                    # only getting initialized after the first bundle.
                    # Caching depends on the cache token which is lazily
                    # initialized by the Runner's StateRequestHandlers).
                    'stateful.beam.metric:statecache:get_total: 100',
                    'stateful.beam.metric:statecache:miss_total: 10',
                    'stateful.beam.metric:statecache:hit_total: 90',
                    'stateful.beam.metric:statecache:put_total: 10',
                    'stateful.beam.metric:statecache:extend_total: 100',
                    'stateful.beam.metric:statecache:evict_total: 0',
                ])
            else:
                # Batch has a different processing model. All values for
                # a key are processed at once. The expected strings below are
                # matched as substrings of the full metric line, whose Flink
                # operator name in batch mode ends in 'stateful)'.
                lines_expected.update([
                    # Gauges
                    'stateful).beam.metric:statecache:capacity: 123',
                    # For the first key, the cache token will not be set yet.
                    # It's lazily initialized after first access in StateRequestHandlers
                    'stateful).beam.metric:statecache:size: 9',
                    # We have 11 here because there are 110 / 10 elements per key
                    'stateful).beam.metric:statecache:get: 11',
                    'stateful).beam.metric:statecache:miss: 1',
                    'stateful).beam.metric:statecache:hit: 10',
                    # State is flushed back once per key
                    'stateful).beam.metric:statecache:put: 1',
                    'stateful).beam.metric:statecache:extend: 1',
                    'stateful).beam.metric:statecache:evict: 0',
                    # Counters
                    'stateful).beam.metric:statecache:get_total: 99',
                    'stateful).beam.metric:statecache:miss_total: 9',
                    'stateful).beam.metric:statecache:hit_total: 90',
                    'stateful).beam.metric:statecache:put_total: 9',
                    'stateful).beam.metric:statecache:extend_total: 9',
                    'stateful).beam.metric:statecache:evict_total: 0',
                ])
            lines_actual = set()
            with open(self.test_metrics_path, 'r') as f:
                line = f.readline()
                while line:
                    for metric_str in lines_expected:
                        if metric_str in line:
                            lines_actual.add(metric_str)
                    line = f.readline()
            self.assertSetEqual(lines_actual, lines_expected)
Example #11
class JoinFn(beam.DoFn):
  """
  Join auctions and person by person id and emit their product one pair at
  a time.

  We know a person may submit any number of auctions. Thus new person event
  must have the person record stored in persistent state in order to match
  future auctions by that person.

  However we know that each auction is associated with at most one person, so
  only need to store auction records in persistent state until we have seen the
  corresponding person record. And of course may have already seen that record.
  """

  AUCTIONS = 'auctions_state'
  PERSON = 'person_state'
  PERSON_EXPIRING = 'person_state_expiring'

  auction_spec = userstate.BagStateSpec(AUCTIONS, nexmark_model.Auction.CODER)
  person_spec = userstate.ReadModifyWriteStateSpec(
      PERSON, nexmark_model.Person.CODER)
  person_timer_spec = userstate.TimerSpec(
      PERSON_EXPIRING, userstate.TimeDomain.WATERMARK)

  def __init__(self, max_auction_wait_time):
    self.max_auction_wait_time = max_auction_wait_time

  def process(
      self,
      element,
      auction_state=beam.DoFn.StateParam(auction_spec),
      person_state=beam.DoFn.StateParam(person_spec),
      person_timer=beam.DoFn.TimerParam(person_timer_spec)):
    # extract group with tags from element tuple
    _, group = element

    existing_person = person_state.read()
    if existing_person:
      # the person exists in person_state for this person id
      for auction in group[nexmark_query_util.AUCTION_TAG]:
        yield auction, existing_person
      return

    new_person = None
    for person in group[nexmark_query_util.PERSON_TAG]:
      if not new_person:
        new_person = person
      else:
        logging.error(
            'two new persons with the same key: %s and %s' % (person, new_person))
        continue
      # read all pending auctions for this person id, output and flush it
      pending_auctions = auction_state.read()
      if pending_auctions:
        for pending_auction in pending_auctions:
          yield pending_auction, new_person
        auction_state.clear()
      # output new auction for this person id
      for auction in group[nexmark_query_util.AUCTION_TAG]:
        yield auction, new_person
      # remember person for max_auction_wait_time seconds for future auctions
      person_state.write(new_person)
      person_timer.set(new_person.date_time + self.max_auction_wait_time)
    # we are done if we have seen a new person
    if new_person:
      return

    # remember auction until we see person
    for auction in group[nexmark_query_util.AUCTION_TAG]:
      auction_state.add(auction)

  @on_timer(person_timer_spec)
  def expiry(self, person_state=beam.DoFn.StateParam(person_spec)):
    person_state.clear()
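A hedged sketch of the upstream wiring this DoFn expects: auctions keyed by seller and persons keyed by id, co-grouped so each element arrives as (person_id, {AUCTION_TAG: [...], PERSON_TAG: [...]}). The two input PCollections are hypothetical; the tags come from nexmark_query_util:

joined = (
    {
        nexmark_query_util.AUCTION_TAG: auctions_by_seller,  # hypothetical
        nexmark_query_util.PERSON_TAG: persons_by_id,        # hypothetical
    }
    | beam.CoGroupByKey()
    | beam.ParDo(JoinFn(max_auction_wait_time=600)))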