예제 #1
0
        class BadStatefulDoFn1(DoFn):
            """DoFn fixture whose ``process`` binds the same state spec twice.

            NOTE(review): the "Bad" prefix suggests this definition is meant to
            be rejected by signature validation -- confirm against the test
            that uses it.
            """
            BUFFER_STATE = BagStateSpec('buffer', BytesCoder())

            def process(self,
                        element,
                        b1=DoFn.StateParam(BUFFER_STATE),
                        b2=DoFn.StateParam(BUFFER_STATE)):
                # Both b1 and b2 are bound to BUFFER_STATE; the element is
                # passed through unchanged.
                yield element
예제 #2
0
def _pardo_group_into_batches(
    input_coder, batch_size, max_buffering_duration_secs, clock=time.time):
  """Builds a stateful DoFn that buffers (key, value) elements into batches.

  Args:
    input_coder: coder handed to both the element bag and the count state.
    batch_size: number of buffered elements at which a batch is flushed.
    max_buffering_duration_secs: when greater than zero, a real-time timer is
      set this many seconds after the first element of a batch.
    clock: zero-argument callable returning the current time in seconds.

  Returns:
    An instance of the DoFn defined below.
  """
  ELEMENT_STATE = BagStateSpec('values', input_coder)
  COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn())
  WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
  BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)

  class _GroupIntoBatchesDoFn(DoFn):
    """Buffers elements in state and flushes them on size or timer."""
    def process(
        self,
        element,
        window=DoFn.WindowParam,
        element_state=DoFn.StateParam(ELEMENT_STATE),
        count_state=DoFn.StateParam(COUNT_STATE),
        window_timer=DoFn.TimerParam(WINDOW_TIMER),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      """Buffers one element; returns a batch generator once the batch is full."""
      # Allowed lateness not supported in Python SDK
      # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
      window_timer.set(window.end)
      element_state.add(element)
      count_state.add(1)
      buffered_count = count_state.read()

      starts_new_batch = buffered_count == 1
      if starts_new_batch and max_buffering_duration_secs > 0:
        # First element of a batch: arm the buffering deadline, since a limit
        # was configured.
        # pylint: disable=deprecated-method
        deadline = clock() + max_buffering_duration_secs
        buffering_timer.set(deadline)

      if buffered_count >= batch_size:
        return self.flush_batch(element_state, count_state, buffering_timer)
      return None

    @on_timer(WINDOW_TIMER)
    def on_window_timer(
        self,
        element_state=DoFn.StateParam(ELEMENT_STATE),
        count_state=DoFn.StateParam(COUNT_STATE),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      """Flushes whatever is buffered when the window timer fires."""
      return self.flush_batch(element_state, count_state, buffering_timer)

    @on_timer(BUFFERING_TIMER)
    def on_buffering_timer(
        self,
        element_state=DoFn.StateParam(ELEMENT_STATE),
        count_state=DoFn.StateParam(COUNT_STATE),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      """Flushes whatever is buffered when the buffering timer fires."""
      return self.flush_batch(element_state, count_state, buffering_timer)

    def flush_batch(self, element_state, count_state, buffering_timer):
      """Yields (key, [values]) for the buffered elements and resets state.

      Nothing is yielded, and no state is touched, when the bag is empty.
      """
      buffered = list(element_state.read())
      if buffered:
        # The key is taken from the first buffered pair.
        batch_key, _ = buffered[0]
        values = [value for _, value in buffered]
        element_state.clear()
        count_state.clear()
        buffering_timer.clear()
        yield batch_key, values

  return _GroupIntoBatchesDoFn()
예제 #3
0
    class IndexAssigningStatefulDoFn(DoFn):
      """Pairs each value with a running index kept in bag state."""
      INDEX_STATE = BagStateSpec('index', VarIntCoder())

      def process(self, element, state=DoFn.StateParam(INDEX_STATE)):
        _, payload = element
        stored = list(state.read())
        # An empty bag means no index has been handed out yet, so start at 0.
        # The single-target unpack fails if the bag holds more than one entry.
        (current_index,) = stored if stored else [0]
        yield (payload, current_index)
        # Replace the stored index with its successor.
        state.clear()
        state.add(current_index + 1)
예제 #4
0
    class BadStatefulDoFn4(DoFn):
      """DoFn fixture whose timer callback binds the same timer spec twice.

      NOTE(review): the "Bad" prefix suggests this definition is meant to be
      rejected by signature validation -- confirm against the test that uses
      it.
      """
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

      @on_timer(EXPIRY_TIMER_1)
      def expiry_callback(self, element, t1=DoFn.TimerParam(EXPIRY_TIMER_2),
                          t2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        # Registered for EXPIRY_TIMER_1, yet both t1 and t2 are bound to
        # EXPIRY_TIMER_2.
        yield element
예제 #5
0
class TestStatefulDoFn(DoFn):
    """An example stateful DoFn with state and timers.

    Declares two bag states and three watermark timers. ``process`` and the
    three timer callbacks each request a different subset of them through
    parameter defaults; none of the bodies touch the state or timers.
    """

    BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
    BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
    EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)

    def process(self,
                element,
                t=DoFn.TimestampParam,
                buffer_1=DoFn.StateParam(BUFFER_STATE_1),
                buffer_2=DoFn.StateParam(BUFFER_STATE_2),
                timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
                timer_2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        """Pass the element through; requests both buffers and timers 1-2."""
        yield element

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(self,
                    window=DoFn.WindowParam,
                    timestamp=DoFn.TimestampParam,
                    key=DoFn.KeyParam,
                    buffer=DoFn.StateParam(BUFFER_STATE_1),
                    timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
                    timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        """Callback for EXPIRY_TIMER_1; emits the marker 'expired1'."""
        yield 'expired1'

    @on_timer(EXPIRY_TIMER_2)
    def on_expiry_2(self,
                    buffer=DoFn.StateParam(BUFFER_STATE_2),
                    timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        """Callback for EXPIRY_TIMER_2; emits the marker 'expired2'."""
        yield 'expired2'

    @on_timer(EXPIRY_TIMER_3)
    def on_expiry_3(self,
                    buffer_1=DoFn.StateParam(BUFFER_STATE_1),
                    buffer_2=DoFn.StateParam(BUFFER_STATE_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        """Callback for EXPIRY_TIMER_3; emits the marker 'expired3'."""
        yield 'expired3'
예제 #6
0
    class BasicStatefulDoFn(DoFn):
      """Minimal stateful DoFn fixture: one bag state and one watermark timer."""
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

      def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
                  timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        # The state and timer are requested but not used; the element is
        # passed through unchanged.
        yield element

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
        # NOTE(review): this timer callback declares a positional `element`
        # parameter -- confirm how the runner supplies it.
        yield element
예제 #7
0
        class StatefulDoFn(DoFn):
            """Stateful DoFn whose output is produced by a recursive helper."""
            BYTES_STATE = BagStateSpec('bytes', BytesCoder())

            def return_recursive(self, count):
                """Return ``["some string"]`` after recursing ``count`` levels.

                Args:
                    count: number of recursive steps before the base case;
                        expected to be a non-negative integer.

                Returns:
                    The single-element list produced by the base case.
                """
                if count == 0:
                    return ["some string"]
                # Propagate the base-case result up the call chain. Without
                # this ``return`` every call with count > 0 evaluated to None.
                return self.return_recursive(count - 1)

            def process(self, element, counter=DoFn.StateParam(BYTES_STATE)):
                """Emit the helper's result; the state parameter is unused."""
                return self.return_recursive(1)
예제 #8
0
    def test_spec_construction(self):
        """Checks argument validation of state specs, timer specs and params."""
        # Well-formed specs must construct without raising.
        BagStateSpec('statename', VarIntCoder())
        CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10))
        TimerSpec('timer', TimeDomain.WATERMARK)
        TimerSpec('timer', TimeDomain.REAL_TIME)

        # Each entry pairs the expected exception with a factory for the
        # malformed construction that should trigger it.
        # TODO: add more spec tests
        rejected_constructions = (
            (AssertionError, lambda: BagStateSpec(123, VarIntCoder())),
            (AssertionError,
             lambda: CombiningValueStateSpec(
                 123, VarIntCoder(), TopCombineFn(10))),
            (AssertionError,
             lambda: CombiningValueStateSpec(
                 'statename', VarIntCoder(), object())),
            (ValueError,
             lambda: DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))),
            (ValueError, lambda: TimerSpec('timer', 'bogus_time_domain')),
            (ValueError,
             lambda: DoFn.StateParam(
                 TimerSpec('timer', TimeDomain.WATERMARK))),
        )
        for expected_error, construct in rejected_constructions:
            with self.assertRaises(expected_error):
                construct()
예제 #9
0
    class SimpleTestStatefulDoFn(DoFn):
      """Buffers tagged values in bag state and emits them joined on a timer."""
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

      def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
                  timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        _, payload = element
        # Store the value as bytes, prefixed with a literal b'A'.
        encoded = str(payload).encode('latin1')
        buffer.add(b'A' + encoded)
        timer1.set(20)

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE),
                          timer=DoFn.TimerParam(EXPIRY_TIMER)):
        # Sort before joining so the emitted bytes do not depend on the order
        # in which the bag returns its entries.
        ordered = sorted(buffer.read())
        yield b''.join(ordered)
예제 #10
0
class PairRecordsFn(beam.DoFn):
  """Pairs two consecutive elements after shuffle.

  The first value seen is parked in bag state. The next value is emitted
  together with the parked one as ``(previous, current)`` and the state is
  cleared, so values are paired two by two.
  """
  BUFFER = BagStateSpec('buffer', PickleCoder())

  def process(self, element, buffer=beam.DoFn.StateParam(BUFFER)):
    """Emit ``(previous_value, value)`` when a value is already buffered.

    Args:
      element: a ``(key, value)`` pair; the key is not used here.
      buffer: bag state holding the value waiting for a partner.

    Yields:
      A ``(previous_value, value)`` tuple whenever the buffer was non-empty.
    """
    # Decide on the presence of a buffered entry, not on its truthiness, so a
    # falsy buffered value (0, '', [], ...) is still paired. Errors raised
    # while reading state now propagate instead of being swallowed by a bare
    # ``except``.
    buffered = list(buffer.read())
    unused_key, value = element

    if buffered:
      yield (buffered[0], value)
      buffer.clear()
    else:
      buffer.add(value)
예제 #11
0
파일: fileio.py 프로젝트: yifanmai/beam
class _RemoveDuplicates(beam.DoFn):
  """Emits the metadata of an element only while its bag state is empty."""

  FILES_STATE = BagStateSpec('files', StrUtf8Coder())

  def process(self, element, file_state=beam.DoFn.StateParam(FILES_STATE)):
    path, file_metadata = element[0], element[1]

    # Any entry already in the bag means this state has produced output before.
    if list(file_state.read()):
      _LOGGER.debug('File %s was already read', path)
      return

    file_state.add(path)
    _LOGGER.debug('Generated entry for file %s', path)
    yield file_metadata
예제 #12
0
      class StatefulDoFnWithTimerWithTypo1(DoFn):  # pylint: disable=unused-variable
        """DoFn fixture that registers two callbacks for the same timer spec.

        NOTE(review): presumably used to check that duplicate ``on_timer``
        registrations are rejected -- confirm against the enclosing test.
        """
        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
        EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
        EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

        def process(self, element):
          pass

        @on_timer(EXPIRY_TIMER_1)
        def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
          yield 'expired1'

        # Note that we mistakenly associate this with the first timer.
        # EXPIRY_TIMER_2 is therefore declared but never given a callback.
        @on_timer(EXPIRY_TIMER_1)
        def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
          yield 'expired2'
예제 #13
0
class CollectingFn(beam.DoFn):
    """Accumulates integer payloads and emits their sum once enough arrive."""
    BUFFER_STATE = BagStateSpec('buffer', VarIntCoder())
    COUNT_STATE = CombiningValueStateSpec('count', sum)

    def process(self,
                element,
                buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
                count_state=beam.DoFn.StateParam(COUNT_STATE)):
        """Buffer one decoded integer; emit the total at NUM_RECORDS elements.

        ``element[1]`` must support ``.decode()`` and decode to text that
        ``int`` can parse.
        """
        payload = element[1]
        buffer_state.add(int(payload.decode()))

        count_state.add(1)
        seen = count_state.read()
        if seen < NUM_RECORDS:
            return

        # Threshold reached: emit the running total, then reset both states.
        yield sum(buffer_state.read())
        count_state.clear()
        buffer_state.clear()
예제 #14
0
    class StatefulDoFnWithTimerWithTypo3(DoFn):
      """DoFn fixture that uses two timers but registers a callback for one.

      NOTE(review): presumably used to check that a timer without a callback
      is reported -- confirm against the enclosing test.
      """
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

      def process(self, element,
                  timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
                  timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        pass

      @on_timer(EXPIRY_TIMER_1)
      def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired1'

      # No @on_timer decorator here, so EXPIRY_TIMER_2 has no callback even
      # though process() requests it.
      def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired2'

      # Use a stable string value for matching.
      def __repr__(self):
        return 'StatefulDoFnWithTimerWithTypo3'
예제 #15
0
def _pardo_group_into_batches(batch_size, input_coder):
    """Builds a stateful DoFn that groups (key, value) elements into batches.

    Args:
        batch_size: number of buffered elements at which a batch is emitted.
        input_coder: coder handed to both the element bag and the count state.

    Returns:
        An instance of the DoFn defined below.
    """
    ELEMENT_STATE = BagStateSpec('values', input_coder)
    COUNT_STATE = CombiningValueStateSpec('count', input_coder,
                                          CountCombineFn())
    EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

    class _GroupIntoBatchesDoFn(DoFn):
        """Buffers elements in state; emits on batch size or on the timer."""
        def process(self,
                    element,
                    window=DoFn.WindowParam,
                    element_state=DoFn.StateParam(ELEMENT_STATE),
                    count_state=DoFn.StateParam(COUNT_STATE),
                    expiry_timer=DoFn.TimerParam(EXPIRY_TIMER)):
            """Buffer one element and emit a batch once batch_size is reached."""
            # Allowed lateness not supported in Python SDK
            # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
            expiry_timer.set(window.end)
            element_state.add(element)
            count_state.add(1)
            if count_state.read() < batch_size:
                return

            buffered = list(element_state.read())
            # The key is taken from the first buffered pair.
            batch_key, _ = buffered[0]
            values = [value for _, value in buffered]
            yield (batch_key, values)
            element_state.clear()
            count_state.clear()

        @on_timer(EXPIRY_TIMER)
        def expiry(self,
                   element_state=DoFn.StateParam(ELEMENT_STATE),
                   count_state=DoFn.StateParam(COUNT_STATE)):
            """Emit whatever is still buffered when the expiry timer fires."""
            buffered = list(element_state.read())
            if not buffered:
                return

            batch_key, _ = buffered[0]
            values = [value for _, value in buffered]
            yield (batch_key, values)
            element_state.clear()
            count_state.clear()

    return _GroupIntoBatchesDoFn()
    class StatefulBufferingFn(DoFn):
        """Tags each value with the index at which it was first buffered.

        Distinct values are appended to a bag state and a combining state
        counts how many have been stored. A value seen again is tagged with
        its position in the bag; a new value is tagged with the count read
        just before it is incremented.
        """
        BUFFER_STATE = BagStateSpec('buffer', StrUtf8Coder())
        COUNT_STATE = userstate.CombiningValueStateSpec(
            'count', VarIntCoder(), CountCombineFn())

        def process(self,
                    element,
                    buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
                    count_state=beam.DoFn.StateParam(COUNT_STATE)):
            """Emit ``('<value>_<index>', 1)`` for the element's value.

            Args:
                element: a ``(key, value)`` pair; the key is not used here.
                buffer_state: bag of the distinct values seen so far.
                count_state: running count of the values stored in the bag.

            Yields:
                A ``(label, 1)`` tuple where label is ``'{value}_{index}'``.
            """
            _, value = element
            # Read state outside the try so that only the "value not found"
            # case is handled; state-backend errors propagate instead of
            # being swallowed by a bare ``except``.
            seen_values = list(buffer_state.read())
            try:
                index_value = seen_values.index(value)
            except ValueError:
                # First sighting: store the value and take the count read
                # before the increment as its index.
                buffer_state.add(value)
                index_value = count_state.read()
                count_state.add(1)

            yield ('{}_{}'.format(value, index_value), 1)
예제 #17
0
        class BagStateClearingStatefulDoFn(beam.DoFn):
            """Buffers values in bag state and emits the bag from two timers."""

            BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder())
            EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
            EMIT_TWICE_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)

            def process(
                self,
                element,
                bag_state=beam.DoFn.StateParam(BAG_STATE),
                emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
                emit_twice_timer=beam.DoFn.TimerParam(EMIT_TWICE_TIMER)):
                """Store the element's second item and set both timers."""
                bag_state.add(element[1])
                emit_twice_timer.set(100)
                emit_timer.set(1000)

            # One method serves as the callback for both timer specs.
            @on_timer(EMIT_TWICE_TIMER)
            @on_timer(EMIT_TIMER)
            def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
                """Yield every value in the bag; the bag is not modified."""
                buffered = bag_state.read()
                for item in buffered:
                    yield item
예제 #18
0
    class HashJoinStatefulDoFn(DoFn):
      """Joins pairs of values through bag state; reports leftovers on a timer."""
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      UNMATCHED_TIMER = TimerSpec('unmatched', TimeDomain.WATERMARK)

      def process(self, element, state=DoFn.StateParam(BUFFER_STATE),
                  timer=DoFn.TimerParam(UNMATCHED_TIMER)):
        join_key, incoming = element
        waiting = list(state.read())
        if waiting:
          # A partner is already buffered: emit the joined record, then reset
          # the state and cancel the unmatched timer.
          yield b'Record<%s,%s,%s>' % (join_key, waiting[0], incoming)
          state.clear()
          timer.clear()
        else:
          # Nothing buffered yet: park this value and arm the unmatched timer.
          state.add(incoming)
          timer.set(100)

      @on_timer(UNMATCHED_TIMER)
      def expiry_callback(self, state=DoFn.StateParam(BUFFER_STATE)):
        leftover = list(state.read())
        # Exactly one value is expected to be waiting when this timer fires.
        assert len(leftover) == 1, leftover
        state.clear()
        yield b'Unmatched<%s>' % (leftover[0],)
예제 #19
0
    class StatefulDoFnWithTimerWithTypo2(DoFn):
      """DoFn fixture in which a second timer callback reuses a method name.

      NOTE(review): presumably used to check that the shadowed callback is
      detected during validation -- confirm against the enclosing test.
      """
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

      def process(self, element,
                  timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
                  timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        pass

      @on_timer(EXPIRY_TIMER_1)
      def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired1'

      # Note that we mistakenly reuse the "on_expiry_1" name; this is valid
      # syntactically in Python.
      # The class attribute "on_expiry_1" ends up bound to the definition
      # below, replacing the one above.
      @on_timer(EXPIRY_TIMER_2)
      def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired2'

      # Use a stable string value for matching.
      def __repr__(self):
        return 'StatefulDoFnWithTimerWithTypo2'
예제 #20
0
 def test_param_construction(self):
     """Checks that state and timer params reject the other kind of spec."""
     # Handing a state spec to TimerParam must raise ValueError.
     with self.assertRaises(ValueError):
         DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))
     # Handing a timer spec to StateParam must raise ValueError.
     with self.assertRaises(ValueError):
         DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))
예제 #21
0
class GeneralTriggerManagerDoFn(DoFn):
    """A trigger manager that supports all windowing / triggering cases.

    This implements a DoFn that manages triggering in a per-key basis. All
    elements for a single key are processed together. Per-key state holds data
    related to all windows.

    ``process`` buffers incoming windowed values in state and asks the
    trigger function whether any of the windows it just touched should fire.
    The two timer callbacks re-run the firing check, without a window
    restriction, when a processing-time or watermark timer goes off.
    """

    # TODO(BEAM-12026) Add support for Global and custom window fns.
    KNOWN_WINDOWS = SetStateSpec('known_windows', IntervalWindowCoder())
    FINISHED_WINDOWS = SetStateSpec('finished_windows', IntervalWindowCoder())
    LAST_KNOWN_TIME = CombiningValueStateSpec('last_known_time',
                                              combine_fn=max)
    LAST_KNOWN_WATERMARK = CombiningValueStateSpec('last_known_watermark',
                                                   combine_fn=max)

    # TODO(pabloem) What's the coder for the elements/keys here?
    WINDOW_ELEMENT_PAIRS = BagStateSpec(
        'all_elements', TupleCoder([IntervalWindowCoder(),
                                    PickleCoder()]))
    WINDOW_TAG_VALUES = BagStateSpec(
        'per_window_per_tag_value_state',
        TupleCoder([IntervalWindowCoder(),
                    StrUtf8Coder(),
                    VarIntCoder()]))

    PROCESSING_TIME_TIMER = TimerSpec('processing_time_timer',
                                      TimeDomain.REAL_TIME)
    WATERMARK_TIMER = TimerSpec('watermark_timer', TimeDomain.WATERMARK)

    def __init__(self, windowing: Windowing):
        """Store the windowing strategy and cache whether its windows merge."""
        self.windowing = windowing
        # Only session windows are merging. Other windows are non-merging.
        self.merging_windows = self.windowing.windowfn.is_merging()

    def process(
            self,
            element: typing.Tuple[
                K, typing.Iterable[windowed_value.WindowedValue]],
            all_elements: BagRuntimeState = DoFn.StateParam(
                WINDOW_ELEMENT_PAIRS),  # type: ignore
            latest_processing_time: AccumulatingRuntimeState = DoFn.StateParam(
                LAST_KNOWN_TIME),  # type: ignore
            latest_watermark: AccumulatingRuntimeState = DoFn.
        StateParam(  # type: ignore
            LAST_KNOWN_WATERMARK),
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            windows_state: SetRuntimeState = DoFn.StateParam(
                KNOWN_WINDOWS),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.
        StateParam(  # type: ignore
            FINISHED_WINDOWS),
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER),
            *args,
            **kwargs):
        """Buffer the key's windowed values and fire any eligible windows.

        Args:
            element: a ``(key, windowed_values)`` pair.
            all_elements: bag of ``(window, timestamped value)`` pairs.
            latest_processing_time: combining state (max) handed to the
                trigger context.
            latest_watermark: combining state (max) read through
                ``read_watermark`` to obtain the current watermark.
            window_tag_values: bag state handed to the trigger context.
            windows_state: set of known windows; only read and updated when
                the window fn is merging.
            finished_windows_state: set of windows whose trigger finished.
            processing_time_timer: timer handed to the trigger context.
            watermark_timer: timer handed to the trigger context.

        Returns:
            The generator produced by ``_fire_eligible_windows``, restricted
            to the windows that received elements in this call.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        key, windowed_values = element
        watermark = read_watermark(latest_watermark)

        # Group the incoming values by window, dropping values whose window
        # has expired or has already finished.
        windows_to_elements = collections.defaultdict(list)
        for wv in windowed_values:
            for window in wv.windows:
                # ignore expired windows
                if watermark > window.end + self.windowing.allowed_lateness:
                    continue
                if window in finished_windows_state.read():
                    continue
                windows_to_elements[window].append(
                    TimestampedValue(wv.value, wv.timestamp))

        # Processing merging of windows
        if self.merging_windows:
            old_windows = set(windows_state.read())
            all_windows = old_windows.union(list(windows_to_elements))
            # Only invoke the window fn's merge when this call introduced a
            # window that was not already known.
            if all_windows != old_windows:
                merge_context = TriggerMergeContext(all_windows, context,
                                                    self.windowing)
                self.windowing.windowfn.merge(merge_context)

                merged_windows_to_elements = collections.defaultdict(list)
                for window, values in windows_to_elements.items():
                    # Follow the chain of merges to the surviving window.
                    while window in merge_context.merged_away:
                        window = merge_context.merged_away[window]
                    merged_windows_to_elements[window].extend(values)
                windows_to_elements = merged_windows_to_elements

            for w in windows_to_elements:
                windows_state.add(w)
        # Done processing merging of windows

        seen_windows = set()
        for w in windows_to_elements:
            window_context = context.for_window(w)
            seen_windows.add(w)
            for value_w_timestamp in windows_to_elements[w]:
                _LOGGER.debug(value_w_timestamp)
                all_elements.add((w, value_w_timestamp))
                # NOTE(review): on_element receives the whole
                # `windowed_values` iterable rather than the single value
                # being added -- confirm this is intended.
                self.windowing.triggerfn.on_element(windowed_values, w,
                                                    window_context)

        return self._fire_eligible_windows(key, TimeDomain.WATERMARK,
                                           watermark, None, context,
                                           seen_windows)

    def _fire_eligible_windows(self,
                               key: K,
                               time_domain,
                               timestamp: Timestamp,
                               timer_tag: typing.Optional[str],
                               context: 'FnRunnerStatefulTriggerContext',
                               windows_of_interest: typing.Optional[
                                   typing.Set[BoundedWindow]] = None):
        """Yield ``(key, elements)`` for every window whose trigger fires.

        This is a generator: the element state is cleared up front, and the
        elements that must be kept are written back only after the firing
        loop has been fully consumed.

        Args:
            key: the key whose windows are being evaluated.
            time_domain: time domain passed to ``should_fire``.
            timestamp: time passed to ``should_fire`` and ``on_fire``.
            timer_tag: tag of the timer that caused this call, or None; only
                used for logging here.
            context: trigger context giving access to the per-key state.
            windows_of_interest: when given, only these windows may fire;
                None means every window is considered.

        Yields:
            ``(key, [WindowedValue, ...])`` for each window that fires.
        """
        windows_to_elements = context.windows_to_elements_map()
        context.all_elements_state.clear()

        fired_windows = set()
        _LOGGER.debug('%s - tag %s - timestamp %s', time_domain, timer_tag,
                      timestamp)
        for w, elems in windows_to_elements.items():
            if windows_of_interest is not None and w not in windows_of_interest:
                # windows_of_interest=None means that we care about all windows.
                # If we care only about some windows, and this window is not one of
                # them, then we do not intend to fire this window.
                continue
            window_context = context.for_window(w)
            if self.windowing.triggerfn.should_fire(time_domain, timestamp, w,
                                                    window_context):
                finished = self.windowing.triggerfn.on_fire(
                    timestamp, w, window_context)
                _LOGGER.debug('Firing on window %s. Finished: %s', w, finished)
                fired_windows.add(w)
                if finished:
                    context.finished_windows_state.add(w)
                # TODO(pabloem): Format the output: e.g. pane info
                elems = [
                    WindowedValue(e.value, e.timestamp, (w, )) for e in elems
                ]
                yield (key, elems)

        finished_windows: typing.Set[BoundedWindow] = set(
            context.finished_windows_state.read())
        # Add elements that were not fired back into state.
        for w, elems in windows_to_elements.items():
            for e in elems:
                # Elements are dropped when their window is finished, or when
                # it fired and the accumulation mode is DISCARDING.
                if (w in finished_windows or
                    (w in fired_windows and self.windowing.accumulation_mode
                     == AccumulationMode.DISCARDING)):
                    continue
                context.all_elements_state.add((w, e))

    @on_timer(PROCESSING_TIME_TIMER)
    def processing_time_trigger(
        self,
        key=DoFn.KeyParam,
        timer_tag=DoFn.DynamicTimerTagParam,
        timestamp=DoFn.TimestampParam,
        latest_processing_time=DoFn.StateParam(LAST_KNOWN_TIME),
        all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
        processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
        window_tag_values: BagRuntimeState = DoFn.StateParam(
            WINDOW_TAG_VALUES),  # type: ignore
        finished_windows_state: SetRuntimeState = DoFn.
        StateParam(  # type: ignore
            FINISHED_WINDOWS),
        watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Run the firing check in the REAL_TIME domain for all windows.

        The context is built without a watermark state. After the firing
        generator has been created, the timer's timestamp is added to the
        last-known processing time state.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=None,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(key, TimeDomain.REAL_TIME,
                                             timestamp, timer_tag, context)
        latest_processing_time.add(timestamp)
        return result

    @on_timer(WATERMARK_TIMER)
    def watermark_trigger(
        self,
        key=DoFn.KeyParam,
        timer_tag=DoFn.DynamicTimerTagParam,
        timestamp=DoFn.TimestampParam,
        latest_watermark=DoFn.StateParam(LAST_KNOWN_WATERMARK),
        all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
        processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
        window_tag_values: BagRuntimeState = DoFn.StateParam(
            WINDOW_TAG_VALUES),  # type: ignore
        finished_windows_state: SetRuntimeState = DoFn.
        StateParam(  # type: ignore
            FINISHED_WINDOWS),
        watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Run the firing check in the WATERMARK domain for all windows.

        The context is built without a processing-time state. After the
        firing generator has been created, the timer's timestamp is added to
        the last-known watermark state.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=None,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(key, TimeDomain.WATERMARK,
                                             timestamp, timer_tag, context)
        latest_watermark.add(timestamp)
        return result
예제 #22
0
    class SolveDoFn(beam.DoFn):
        """Stateful DoFn that maps a window of elements to a model and solves it.

        Four bag states carry the timestamp, elements, model and sample set
        from the previous invocation. Each bag is cleared before the new value
        is added, so it holds only the latest value.
        """

        PREV_TIMESTAMP = BagStateSpec(name="timestamp_state", coder=coders.PickleCoder())
        PREV_ELEMENTS = BagStateSpec(name="elements_state", coder=coders.PickleCoder())
        PREV_MODEL = BagStateSpec(name="model_state", coder=coders.PickleCoder())
        PREV_SAMPLESET = BagStateSpec(name="sampleset_state", coder=coders.PickleCoder())

        def process(
            self,
            value,
            timestamp=beam.DoFn.TimestampParam,
            timestamp_state=beam.DoFn.StateParam(PREV_TIMESTAMP),
            elements_state=beam.DoFn.StateParam(PREV_ELEMENTS),
            model_state=beam.DoFn.StateParam(PREV_MODEL),
            sampleset_state=beam.DoFn.StateParam(PREV_SAMPLESET),
            algorithm=None,
            algorithm_options=None,
            map_fn=None,
            solve_fn=None,
            unmap_fn=None,
            solver=LocalSolver(exact=False),  # default solver
            initial_mtype=sawatabi.constants.MODEL_ISING,
        ):
            """Map, solve and unmap one window of elements.

            Args:
                value: a ``(key, elements)`` pair; the key is not used here.
                timestamp: timestamp of the current element.
                timestamp_state: bag holding the previously processed timestamp.
                elements_state: bag holding the previously processed elements.
                model_state: bag holding the previously built model.
                sampleset_state: bag holding the previous sample set.
                algorithm: one of the ``sawatabi.constants.ALGORITHM_*`` values,
                    or None for no algorithm-specific handling.
                algorithm_options: dict of options read by the selected algorithm.
                map_fn: callable building the model from the previous model,
                    previous sample set, current, incoming and outgoing elements.
                solve_fn: callable producing a sample set from the solver and model.
                unmap_fn: callable turning the sample set into the emitted result.
                solver: solver handed to ``solve_fn``.
                initial_mtype: model type used when no previous model is stored.

            Yields:
                The value returned by ``unmap_fn``, or a diagnostic string when
                the event is outdated or when map/solve/unmap raises.
            """
            _, elements = value

            # Sort with the event time.
            # If we sort a list of tuples, the first element of the tuple is recognized as a key by default,
            # so just `sorted` is enough.
            sorted_elements = sorted(elements)

            # generator into a list
            timestamp_state_as_list = list(timestamp_state.read())
            elements_state_as_list = list(elements_state.read())
            model_state_as_list = list(model_state.read())
            sampleset_state_as_list = list(sampleset_state.read())

            # Extract the previous timestamp, elements, model and sample set from state,
            # falling back to defaults when nothing has been stored yet.
            prev_timestamp = timestamp_state_as_list[-1] if timestamp_state_as_list else -1.0
            prev_elements = elements_state_as_list[-1] if elements_state_as_list else []
            if model_state_as_list:
                prev_model = model_state_as_list[-1]
            else:
                prev_model = sawatabi.model.LogicalModel(mtype=initial_mtype)
            prev_sampleset = sampleset_state_as_list[-1] if sampleset_state_as_list else None

            # Sometimes, when we use the sliding window algorithm for a bounded data (such as a local file),
            # we may receive an outdated event whose timestamp is older than timestamp of previously processed event.
            if float(timestamp) < float(prev_timestamp):
                # Report the previously processed timestamp (the original message printed the
                # current timestamp twice). prev_timestamp is the plain -1.0 sentinel when no
                # state exists, so only call to_utc_datetime() when it is available.
                if hasattr(prev_timestamp, "to_utc_datetime"):
                    prev_timestamp_display = prev_timestamp.to_utc_datetime()
                else:
                    prev_timestamp_display = prev_timestamp
                yield (
                    f"The received event is outdated: Timestamp is {timestamp.to_utc_datetime()}, "
                    + f"while an event with timestamp of {prev_timestamp_display} has been already processed."
                )
                return

            # Algorithm specific operations
            # Incremental: Append current window into the all previous data.
            if algorithm == sawatabi.constants.ALGORITHM_INCREMENTAL:
                sorted_elements.extend(prev_elements)
                sorted_elements = sorted(sorted_elements)
            # Partial: Merge current window with the specified data.
            elif algorithm == sawatabi.constants.ALGORITHM_PARTIAL:
                filter_fn = algorithm_options["filter_fn"]
                filtered = filter(filter_fn, prev_elements)
                sorted_elements = list(filtered) + sorted_elements
                sorted_elements = sorted(sorted_elements)

            # Resolve outgoing elements in this iteration:
            # previous elements that sort before the first current element.
            def resolve_outgoing(prev_elements, sorted_elements):
                outgoing = []
                for p in prev_elements:
                    if p[0] >= sorted_elements[0][0]:
                        break
                    outgoing.append(p)
                return outgoing

            outgoing = resolve_outgoing(prev_elements, sorted_elements)

            # Resolve incoming elements in this iteration:
            # current elements that sort after the last previous element.
            def resolve_incoming(prev_elements, sorted_elements):
                incoming = []
                if len(prev_elements) == 0:
                    incoming = sorted_elements
                else:
                    for v in reversed(sorted_elements):
                        if v[0] <= prev_elements[-1][0]:
                            break
                        incoming.insert(0, v)
                return incoming

            incoming = resolve_incoming(prev_elements, sorted_elements)

            # Clear the BagState so we can hold only the latest state, and
            # Register new timestamp and elements to the states
            timestamp_state.clear()
            timestamp_state.add(timestamp)
            elements_state.clear()
            elements_state.add(sorted_elements)

            # Map problem input to the model
            try:
                model = map_fn(prev_model, prev_sampleset, sorted_elements, incoming, outgoing)
            except Exception as e:
                yield f"Failed to map: {e}\n{traceback.format_exc()}"
                return

            # Clear the BagState so we can hold only the latest state, and
            # Register new model to the state
            model_state.clear()
            model_state.add(model)

            # Algorithm specific operations
            # Attenuation: Update scale based on data timestamp.
            if algorithm == sawatabi.constants.ALGORITHM_ATTENUATION:
                model.to_physical()  # Resolve removed interactions. TODO: Deal with placeholders.
                ref_timestamp = model._interactions_array[algorithm_options["attenuation.key"]]
                min_ts = min(ref_timestamp)
                max_ts = max(ref_timestamp)
                min_scale = algorithm_options["attenuation.min_scale"]
                # Linearly interpolate the scale from min_scale (oldest) to 1.0 (newest);
                # skipped when all timestamps are equal, which would divide by zero.
                if min_ts < max_ts:
                    for i, t in enumerate(ref_timestamp):
                        new_scale = (1.0 - min_scale) / (max_ts - min_ts) * (t - min_ts) + min_scale
                        model._interactions_array["scale"][i] = new_scale

            # Solve and unmap to the solution
            try:
                sampleset = solve_fn(solver, model, prev_sampleset, sorted_elements, incoming, outgoing)
            except Exception as e:
                yield f"Failed to solve: {e}\n{traceback.format_exc()}"
                return

            # Clear the BagState so we can hold only the latest state, and
            # Register new sampleset to the state
            sampleset_state.clear()
            sampleset_state.add(sampleset)

            try:
                yield unmap_fn(sampleset, sorted_elements, incoming, outgoing)
            except Exception as e:
                yield f"Failed to unmap: {e}\n{traceback.format_exc()}"
예제 #23
0
def _pardo_group_into_batches_with_multi_bags(
  input_coder, batch_size, max_buffering_duration_secs, clock=time.time):
  """Build a stateful DoFn that batches elements using four bag states.

  Incoming elements are spread over four bags (chosen by the running count
  modulo 4). A batch is emitted when the count reaches ``batch_size``, when
  the window-end timer fires, or when the buffering timer fires.

  Args:
    input_coder: coder passed to every bag state spec and to the count spec.
    batch_size: number of buffered elements that triggers a flush.
    max_buffering_duration_secs: when greater than 0, the buffering timer is
      set to ``clock() + max_buffering_duration_secs`` on the first element
      of a batch.
    clock: zero-argument callable returning the current time.

  Returns:
    An instance of the DoFn defined below.
  """
  ELEMENT_STATE_0 = BagStateSpec('values0', input_coder)
  ELEMENT_STATE_1 = BagStateSpec('values1', input_coder)
  ELEMENT_STATE_2 = BagStateSpec('values2', input_coder)
  ELEMENT_STATE_3 = BagStateSpec('values3', input_coder)
  COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn())
  WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
  BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)

  class _GroupIntoBatchesDoFnWithMultiBags(DoFn):
    """Buffers elements in four bags and emits them as ``(key, values)``."""

    def process(
        self,
        element,
        window=DoFn.WindowParam,
        element_state_0=DoFn.StateParam(ELEMENT_STATE_0),
        element_state_1=DoFn.StateParam(ELEMENT_STATE_1),
        element_state_2=DoFn.StateParam(ELEMENT_STATE_2),
        element_state_3=DoFn.StateParam(ELEMENT_STATE_3),
        count_state=DoFn.StateParam(COUNT_STATE),
        window_timer=DoFn.TimerParam(WINDOW_TIMER),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      """Buffer one element; return a flush generator once the batch is full."""
      # Allowed lateness not supported in Python SDK
      # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
      window_timer.set(window.end)

      # Bump the counter first: its new value both selects the target bag
      # and decides whether this element completes a batch.
      count_state.add(1)
      buffered = count_state.read()

      bags = [element_state_0, element_state_1, element_state_2, element_state_3]
      target_bag = bags[buffered % 4]
      target_bag.add(element)

      # The first element of a batch starts the buffering clock, but only
      # when a buffering limit was actually configured.
      if buffered == 1 and max_buffering_duration_secs > 0:
        deadline = clock() + max_buffering_duration_secs
        buffering_timer.set(deadline)

      return (
          self.flush_batch(bags, count_state, buffering_timer)
          if buffered >= batch_size else None)

    @on_timer(WINDOW_TIMER)
    def on_window_timer(
        self,
        element_state_0=DoFn.StateParam(ELEMENT_STATE_0),
        element_state_1=DoFn.StateParam(ELEMENT_STATE_1),
        element_state_2=DoFn.StateParam(ELEMENT_STATE_2),
        element_state_3=DoFn.StateParam(ELEMENT_STATE_3),
        count_state=DoFn.StateParam(COUNT_STATE),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      """Flush whatever is buffered when the window-end timer fires."""
      bags = [element_state_0, element_state_1, element_state_2, element_state_3]
      return self.flush_batch(bags, count_state, buffering_timer)

    @on_timer(BUFFERING_TIMER)
    def on_buffering_timer(
        self,
        element_state_0=DoFn.StateParam(ELEMENT_STATE_0),
        element_state_1=DoFn.StateParam(ELEMENT_STATE_1),
        element_state_2=DoFn.StateParam(ELEMENT_STATE_2),
        element_state_3=DoFn.StateParam(ELEMENT_STATE_3),
        count_state=DoFn.StateParam(COUNT_STATE),
        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
      """Flush whatever is buffered when the buffering timer fires."""
      bags = [element_state_0, element_state_1, element_state_2, element_state_3]
      return self.flush_batch(bags, count_state, buffering_timer)

    def flush_batch(self, element_states, count_state, buffering_timer):
      """Drain every bag, reset the count and buffering timer, emit one batch.

      Each buffered item is unpacked as a ``(key, value)`` pair. The values
      are gathered bag by bag, and the emitted key is the key of the last
      pair read. Nothing is emitted when all bags are empty.
      """
      collected = []
      last_key = None
      for bag in element_states:
        # Read a bag fully, then clear it, before moving on to the next one.
        for pair_key, pair_value in bag.read():
          last_key = pair_key
          collected.append(pair_value)
        bag.clear()

      count_state.clear()
      buffering_timer.clear()

      if collected:
        yield last_key, collected

  return _GroupIntoBatchesDoFnWithMultiBags()