Exemplo n.º 1
0
    def test_state_spec_proto_conversion(self):
        # Each user-state spec flavour is expected to report the bag
        # user-state URN as its runner API protocol.
        expected_protocol = beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.user_state.BAG.urn)
        specs = [
            BagStateSpec('statename', VarIntCoder()),
            CombiningValueStateSpec('statename', VarIntCoder(),
                                    TopCombineFn(10)),
            SetStateSpec('setstatename', VarIntCoder()),
            ReadModifyWriteStateSpec('valuestatename', VarIntCoder()),
        ]
        for spec in specs:
            # A fresh context per spec, so no coder ids leak between cases.
            fresh_context = pipeline_context.PipelineContext()
            proto = spec.to_runner_api(fresh_context)
            self.assertEqual(expected_protocol, proto.protocol)
Exemplo n.º 2
0
  def expand(self, pbegin):
    """Expand into the cross-language Java GenerateSequence transform.

    Every configured field is encoded with ``VarIntCoder`` and sent to the
    expansion service inside an ``ExternalConfigurationPayload``.

    Args:
      pbegin: the pipeline root; must be a ``pvalue.PBegin``.

    Returns:
      The result of applying the ``ExternalTransform`` to ``pbegin``.

    Raises:
      Exception: if ``pbegin`` is not a ``pvalue.PBegin``.
    """
    if not isinstance(pbegin, pvalue.PBegin):
      raise Exception("GenerateSequence must be a root transform")

    coder = VarIntCoder()
    coder_urn = ['beam:coder:varint:v1']

    def _config(value):
      # Wrap one VarInt-encoded value for the external payload.
      return ConfigValue(coder_urn=coder_urn, payload=coder.encode(value))

    args = {'start': _config(self.start)}
    # Compare with None instead of relying on truthiness, so that an explicit
    # 0 (e.g. stop=0) is forwarded rather than silently dropped.
    optional_fields = (
        ('stop', self.stop),
        ('elements_per_period', self.elements_per_period),
        ('max_read_time', self.max_read_time),
    )
    for name, value in optional_fields:
      if value is not None:
        args[name] = _config(value)

    payload = ExternalConfigurationPayload(configuration=args)
    return pbegin.apply(
        ExternalTransform(
            self._urn,
            payload.SerializeToString(),
            self.expansion_service))
Exemplo n.º 3
0
    def test_extend_fetches_initial_state(self):
        """Appending before any read must keep the pre-existing state."""
        coder = VarIntCoder()
        coder_impl = coder.get_impl()

        class UnderlyingStateHandler(object):
            """In-memory fake of the handler wrapped by CachingStateHandler.

            Keeps one buffer of encoded values regardless of the state key it
            is asked about; ``get_raw`` returns the whole buffer with no
            continuation token.
            """
            def set_value(self, value):
                # Replace the stored state with the encoding of ``value``.
                self._encoded_values = coder.encode(value)

            def get_raw(self, *args):
                return self._encoded_values, None

            def append_raw(self, _key, data):
                # ``data`` is already encoded; named ``data`` (not ``bytes``)
                # so the builtin used by ``clear`` is not shadowed.
                self._encoded_values += data

            def clear(self, *args):
                self._encoded_values = bytes()

            @contextlib.contextmanager
            def process_instruction_id(self, bundle_id):
                yield

        underlying_state_handler = UnderlyingStateHandler()
        state_cache = statecache.StateCache(100)
        handler = sdk_worker.CachingStateHandler(state_cache,
                                                 underlying_state_handler)

        state = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))

        cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())

        def get():
            return list(handler.blocking_get(state, coder_impl, True))

        def append(value):
            handler.extend(state, coder_impl, [value], True)

        def clear():
            handler.clear(state, True)

        # Initialize state
        underlying_state_handler.set_value(42)
        with handler.process_instruction_id('bundle', [cache_token]):
            # Append without reading beforehand
            append(43)
            self.assertEqual(get(), [42, 43])
            clear()
            self.assertEqual(get(), [])
            append(44)
            self.assertEqual(get(), [44])
Exemplo n.º 4
0
    def test_spec_construction(self):
        def expect(exc_type, build):
            # Call ``build`` inside assertRaises so argument construction is
            # covered by the expectation exactly as an inline ``with`` would be.
            with self.assertRaises(exc_type):
                build()

        BagStateSpec('statename', VarIntCoder())
        expect(TypeError, lambda: BagStateSpec(123, VarIntCoder()))

        CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10))
        expect(TypeError,
               lambda: CombiningValueStateSpec(123, VarIntCoder(),
                                               TopCombineFn(10)))
        expect(TypeError,
               lambda: CombiningValueStateSpec('statename', VarIntCoder(),
                                               object()))

        SetStateSpec('setstatename', VarIntCoder())
        expect(TypeError, lambda: SetStateSpec(123, VarIntCoder()))
        expect(TypeError, lambda: SetStateSpec('setstatename', object()))

        ReadModifyWriteStateSpec('valuestatename', VarIntCoder())
        expect(TypeError, lambda: ReadModifyWriteStateSpec(123, VarIntCoder()))
        expect(TypeError,
               lambda: ReadModifyWriteStateSpec('valuestatename', object()))

        # TODO: add more spec tests
        # A state spec handed to TimerParam is rejected.
        expect(ValueError,
               lambda: DoFn.TimerParam(BagStateSpec('elements', BytesCoder())))

        TimerSpec('timer', TimeDomain.WATERMARK)
        TimerSpec('timer', TimeDomain.REAL_TIME)
        expect(ValueError, lambda: TimerSpec('timer', 'bogus_time_domain'))
        # A timer spec handed to StateParam is rejected.
        expect(ValueError,
               lambda: DoFn.StateParam(TimerSpec('timer',
                                                 TimeDomain.WATERMARK)))
Exemplo n.º 5
0
class CountAndSchedule(beam.DoFn):
    """Count elements and emit the running count from a minute-aligned timer.

    ``process`` increments the stored counter and arms a watermark timer for
    the start of the minute following the element's timestamp.
    ``timer_ticked`` emits the count, resets it to 0 and re-arms the timer
    60 seconds later.
    """

    COUNTER = BagStateSpec('counter', VarIntCoder())
    SCHEDULED_TIMESTAMP = BagStateSpec('nextSchedule', VarIntCoder())
    TIMER = TimerSpec('timer', TimeDomain.WATERMARK)

    def process(self,
                element,
                timestamp=beam.DoFn.TimestampParam,
                timer=beam.DoFn.TimerParam(TIMER),
                counter=beam.DoFn.StateParam(COUNTER),
                next_schedule=beam.DoFn.StateParam(SCHEDULED_TIMESTAMP),
                *args,
                **kwargs):
        # The counter bag holds a single value; an empty bag counts as 0.
        stored = list(counter.read()) or [0]
        previous_count, = stored
        counter.clear()
        counter.add(previous_count + 1)

        # Truncate to the minute, then step forward one minute so the timer
        # fires at the start of the following minute.
        minute_start = timestamp.to_utc_datetime().replace(second=0,
                                                           microsecond=0)
        next_minute = minute_start + timedelta(minutes=1)
        next_tick = calendar.timegm(next_minute.timetuple())
        timer.set(next_tick)

        next_schedule.clear()
        next_schedule.add(next_tick)

    @on_timer(TIMER)
    def timer_ticked(self,
                     timer=beam.DoFn.TimerParam(TIMER),
                     counter=beam.DoFn.StateParam(COUNTER),
                     next_schedule=beam.DoFn.StateParam(SCHEDULED_TIMESTAMP)):
        print("TICKTICK")
        emitted_count, = counter.read()
        fired_at, = next_schedule.read()
        following_tick = fired_at + 60

        # Move the recorded schedule one minute ahead.
        next_schedule.clear()
        next_schedule.add(following_tick)

        # Start the next interval from zero.
        counter.clear()
        counter.add(0)

        timer.clear()
        timer.set(following_tick)

        yield {'count': emitted_count, 'timestamp': fired_at}
Exemplo n.º 6
0
 def test_implicit_payload_builder_with_bytes(self):
   values = PayloadBase.bytes_values
   result = ImplicitSchemaPayloadBuilder(values).build()

   if sys.version_info[0] >= 3:
     expected = get_payload(PayloadBase.args)
   else:
     # In Python 2 the string values are bytes, so the bytes coder URN is
     # inferred for them (the payload itself is still StrUtf8-encoded).
     def encoded(coder, key):
       return coder.get_impl().encode_nested(values[key])

     expected = get_payload({
         'integer_example': ConfigValue(
             coder_urn=['beam:coder:varint:v1'],
             payload=encoded(VarIntCoder(), 'integer_example')),
         'string_example': ConfigValue(
             coder_urn=['beam:coder:bytes:v1'],
             payload=encoded(StrUtf8Coder(), 'string_example')),
         'list_of_strings': ConfigValue(
             coder_urn=['beam:coder:iterable:v1',
                        'beam:coder:bytes:v1'],
             payload=encoded(IterableCoder(StrUtf8Coder()),
                             'list_of_strings')),
         'optional_kv': ConfigValue(
             coder_urn=['beam:coder:kv:v1',
                        'beam:coder:bytes:v1',
                        'beam:coder:double:v1'],
             payload=encoded(TupleCoder([StrUtf8Coder(), FloatCoder()]),
                             'optional_kv')),
     })

   self.assertEqual(result, expected)
Exemplo n.º 7
0
        class GenerateRecords(beam.DoFn):
            """Emit 'value' from a processing-time timer, ``total_records`` times.

            ``process`` forwards the second field of each element and arms the
            timer; every firing records one emission and re-arms the timer
            until ``total_records`` emissions have been counted.
            """

            EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.REAL_TIME)
            COUNT_STATE = CombiningValueStateSpec('count_state', VarIntCoder(),
                                                  CountCombineFn())

            def __init__(self, frequency, total_records):
                self.frequency = frequency
                self.total_records = total_records

            def process(self,
                        element,
                        emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
                # Processing time timers should be set on ABSOLUTE TIME.
                emit_timer.set(self.frequency)
                yield element[1]

            @on_timer(EMIT_TIMER)
            def emit_values(self,
                            emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
                            count_state=beam.DoFn.StateParam(COUNT_STATE)):
                emitted = count_state.read() or 0
                if self.total_records != emitted:
                    count_state.add(1)
                    # Processing time timers should be set on ABSOLUTE TIME;
                    # each firing schedules the next one a step later.
                    emit_timer.set(emitted + 1 + self.frequency)
                    yield 'value'
Exemplo n.º 8
0
    class IndexAssigningStatefulDoFn(DoFn):
      """Pair each value with a running index kept in bag state (from 0)."""
      INDEX_STATE = BagStateSpec('index', VarIntCoder())

      def process(self, element, state=DoFn.StateParam(INDEX_STATE)):
        _, value = element
        # The bag stores the next index to hand out; an empty bag means 0.
        stored = list(state.read())
        next_index, = stored if stored else [0]
        yield (value, next_index)
        state.clear()
        state.add(next_index + 1)
Exemplo n.º 9
0
class TestStatefulDoFn(DoFn):
  """An example stateful DoFn with state and timers.

  Declares two bag states, three watermark timers and one timer used with a
  dynamic timer tag. Every method only yields the input element or a fixed
  marker string. NOTE(review): this looks like a fixture for exercising DoFn
  signature / parameter handling rather than real processing — confirm
  against the tests that use it.
  """

  # Bag states: the first holds bytes, the second holds varint-coded ints.
  BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
  BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
  # Event-time (watermark-domain) timers.
  EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
  EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
  EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)
  # Declared as a plain TimerSpec, but its callback takes
  # DynamicTimerTagParam, so it appears to be used as a timer family.
  EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family', TimeDomain.WATERMARK)

  def process(
      self,
      element,
      t=DoFn.TimestampParam,
      buffer_1=DoFn.StateParam(BUFFER_STATE_1),
      buffer_2=DoFn.StateParam(BUFFER_STATE_2),
      timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
    """Yield ``element`` unchanged; the declared state/timer params are unused."""
    yield element

  @on_timer(EXPIRY_TIMER_1)
  def on_expiry_1(
      self,
      window=DoFn.WindowParam,
      timestamp=DoFn.TimestampParam,
      key=DoFn.KeyParam,
      buffer=DoFn.StateParam(BUFFER_STATE_1),
      timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_1; yields the marker 'expired1'."""
    yield 'expired1'

  @on_timer(EXPIRY_TIMER_2)
  def on_expiry_2(
      self,
      buffer=DoFn.StateParam(BUFFER_STATE_2),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_2; yields the marker 'expired2'."""
    yield 'expired2'

  @on_timer(EXPIRY_TIMER_3)
  def on_expiry_3(
      self,
      buffer_1=DoFn.StateParam(BUFFER_STATE_1),
      buffer_2=DoFn.StateParam(BUFFER_STATE_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_3; yields the marker 'expired3'."""
    yield 'expired3'

  @on_timer(EXPIRY_TIMER_FAMILY)
  def on_expiry_family(
      self,
      dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY),
      dynamic_timer_tag=DoFn.DynamicTimerTagParam):
    """Callback for EXPIRY_TIMER_FAMILY; yields (tag, marker) for the firing."""
    yield (dynamic_timer_tag, 'expired_dynamic_timer')
Exemplo n.º 10
0
    class SetStatefulDoFn(beam.DoFn):
      """Add each value to set state and emit the sum of the set's contents."""

      SET_STATE = SetStateSpec('buffer', VarIntCoder())

      def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)):
        _, value = element
        set_state.add(value)
        yield sum(set_state.read())
Exemplo n.º 11
0
class StatefulPrintDoFn(beam.DoFn):
    """Log each element along with how many were processed before it."""
    COUNTER_SPEC = ReadModifyWriteStateSpec('counter', VarIntCoder())

    def __init__(self, step_name):
        self._step_name = step_name

    def process(self, element, counter=beam.DoFn.StateParam(COUNTER_SPEC)):
        # An unset (falsy) counter is treated as zero.
        seen = counter.read()
        if not seen:
            seen = 0
        logging.info('Print [%s] (counter:%d): %s', self._step_name, seen,
                     element)
        counter.write(seen + 1)
Exemplo n.º 12
0
    def __init__(self,
                 start,
                 stop=None,
                 elements_per_period=None,
                 max_read_time=None,
                 expansion_service=None):
        """Build the external payload for Java's GenerateSequence transform.

        Each provided value is VarInt-encoded into an
        ``ExternalConfigurationPayload``; field semantics are defined by the
        Java transform.

        Args:
          start: value for the ``start`` field; always sent.
          stop: optional value for the ``stop`` field.
          elements_per_period: optional value for ``elements_per_period``.
          max_read_time: optional value for ``max_read_time``.
          expansion_service: forwarded unchanged to the base class.
        """
        coder = VarIntCoder()
        # NOTE(review): a single URN string is passed here; if
        # ConfigValue.coder_urn is a repeated field in the proto version in
        # use, this should be a one-element list — confirm.
        coder_urn = 'beam:coder:varint:v1'

        def _config(value):
            # Wrap one VarInt-encoded value for the external payload.
            return ConfigValue(coder_urn=coder_urn,
                               payload=coder.encode(value))

        args = {'start': _config(start)}
        # Compare with None instead of relying on truthiness, so that an
        # explicit 0 (e.g. stop=0) is forwarded rather than silently dropped.
        optional_fields = (
            ('stop', stop),
            ('elements_per_period', elements_per_period),
            ('max_read_time', max_read_time),
        )
        for name, value in optional_fields:
            if value is not None:
                args[name] = _config(value)

        payload = ExternalConfigurationPayload(configuration=args)
        super(GenerateSequence,
              self).__init__('beam:external:java:generate_sequence:v1',
                             payload.SerializeToString(), expansion_service)
Exemplo n.º 13
0
    def expand(self, pbegin):
        """Expand into the cross-language Java GenerateSequence transform.

        Every configured field is encoded with ``VarIntCoder`` and sent to the
        expansion service inside an ``ExternalConfigurationPayload``.

        Args:
          pbegin: the pipeline root; must be a ``pvalue.PBegin``.

        Returns:
          The result of applying the ``ExternalTransform`` to ``pbegin``.

        Raises:
          Exception: if ``pbegin`` is not a ``pvalue.PBegin``.
        """
        if not isinstance(pbegin, pvalue.PBegin):
            raise Exception("GenerateSequence must be a root transform")

        coder = VarIntCoder()
        coder_urn = ['beam:coder:varint:v1']

        def _config(value):
            # Wrap one VarInt-encoded value for the external payload.
            return ConfigValue(coder_urn=coder_urn,
                               payload=coder.encode(value))

        args = {'start': _config(self.start)}
        # Compare with None instead of relying on truthiness, so that an
        # explicit 0 (e.g. stop=0) is forwarded rather than silently dropped.
        optional_fields = (
            ('stop', self.stop),
            ('elements_per_period', self.elements_per_period),
            ('max_read_time', self.max_read_time),
        )
        for name, value in optional_fields:
            if value is not None:
                args[name] = _config(value)

        payload = ExternalConfigurationPayload(configuration=args)
        return pbegin.apply(
            ExternalTransform(self._urn, payload.SerializeToString(),
                              self.expansion_service))
Exemplo n.º 14
0
    def test_continuation_token(self):
        underlying = self.UnderlyingStateHandler()
        handler = sdk_worker.CachingStateHandler(statecache.StateCache(100),
                                                 underlying)

        coder = VarIntCoder()

        state_key = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))

        cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())

        def read(materialize=True):
            fetched = handler.blocking_get(state_key, coder.get_impl())
            if materialize:
                return list(fetched)
            return fetched

        def read_type():
            return type(read(materialize=False))

        def extend(*values):
            handler.extend(state_key, coder.get_impl(), values)

        def clear():
            handler.clear(state_key)

        underlying.set_continuations(True)
        underlying.set_values([45, 46, 47], coder)
        with handler.process_instruction_id('bundle', [cache_token]):
            # With continuations enabled, reads come back as a
            # ContinuationIterable, including after an append.
            self.assertEqual(read_type(),
                             CachingStateHandler.ContinuationIterable)
            self.assertEqual(read(), [45, 46, 47])
            extend(48, 49)
            self.assertEqual(read_type(),
                             CachingStateHandler.ContinuationIterable)
            self.assertEqual(read(), [45, 46, 47, 48, 49])
            # Once cleared, reads return a plain list.
            clear()
            self.assertEqual(read_type(), list)
            self.assertEqual(read(), [])
            extend(1)
            self.assertEqual(read(), [1])
            extend(2, 3)
            self.assertEqual(read(), [1, 2, 3])
            clear()
            for number in range(1000):
                extend(number)
            self.assertEqual(read_type(), list)
            self.assertEqual(read(), list(range(1000)))
Exemplo n.º 15
0
    class SimpleTestSetStatefulDoFn(DoFn):
      """Collect values in set state; emit them sorted when the timer fires."""
      BUFFER_STATE = SetStateSpec('buffer', VarIntCoder())
      EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

      def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
                  timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        _, value = element
        buffer.add(value)
        # Arm the watermark timer for timestamp 20.
        timer1.set(20)

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)):
        contents = buffer.read()
        yield sorted(contents)
Exemplo n.º 16
0
def from_proto(field_type):
    """
    Build the :class:`Coder` that matches a protocol field type.

    BIGINT maps to :class:`VarIntCoder`; ROW maps to a :class:`RowCoder`
    whose field coders are built recursively from the row schema.

    :param field_type: the protocol representation of the field type
    :return: :class:`Coder`
    :raises ValueError: if the type name is neither BIGINT nor ROW
    """
    type_names = flink_fn_execution_pb2.Schema.TypeName
    kind = field_type.type_name
    if kind == type_names.BIGINT:
        return VarIntCoder()
    if kind == type_names.ROW:
        field_coders = [from_proto(field.type)
                        for field in field_type.row_schema.fields]
        return RowCoder(field_coders)
    raise ValueError("field_type %s is not supported." % field_type)
Exemplo n.º 17
0
    class SimpleTestStatefulDoFn(DoFn):
      """Buffer values via a to-list combiner; emit them when the timer fires."""
      BUFFER_STATE = CombiningValueStateSpec(
          'buffer', IterableCoder(VarIntCoder()), ToListCombineFn())
      EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

      def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
                  timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        _, value = element
        buffer.add(value)
        timer1.set(20)

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE),
                          timer=DoFn.TimerParam(EXPIRY_TIMER)):
        # Concatenate the string forms of the sorted values, no separator.
        pieces = map(str, sorted(buffer.read()))
        yield ''.join(pieces)
class BagInStateOutputAfterTimer(beam.DoFn):
    """Buffer incoming values, then emit them when the watermark timer fires.

    Despite the class name, values are held in set state. Each emitted item
    is a tuple of a random integer in [0, 1000] and one buffered value.
    """

    SET_STATE = SetStateSpec('buffer', VarIntCoder())
    EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)

    def process(self,
                element,
                set_state=beam.DoFn.StateParam(SET_STATE),
                emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
        _, batch = element
        for item in batch:
            set_state.add(item)
        emit_timer.set(1)

    @on_timer(EMIT_TIMER)
    def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
        keyed = []
        for item in set_state.read():
            keyed.append((random.randint(0, 1000), item))
        return keyed
Exemplo n.º 19
0
    def test_append_clear_with_preexisting_state(self):
        state_key = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))

        cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())

        coder = VarIntCoder()

        underlying = self.UnderlyingStateHandler()
        handler = sdk_worker.CachingStateHandler(statecache.StateCache(100),
                                                 underlying)

        def read():
            return handler.blocking_get(state_key, coder.get_impl())

        def extend(items):
            handler.extend(state_key, coder.get_impl(), items)

        def clear():
            handler.clear(state_key)

        # Seed the underlying handler before the bundle starts.
        underlying.set_value(42, coder)
        with handler.process_instruction_id('bundle', [cache_token]):
            # An append with no prior read must still keep the seeded value.
            extend([43])
            self.assertEqual(read(), [42, 43])
            clear()
            self.assertEqual(read(), [])
            # Lists, tuples and ranges are all accepted as values to append.
            extend([44, 45])
            self.assertEqual(read(), [44, 45])
            extend((46, 47))
            self.assertEqual(read(), [44, 45, 46, 47])
            clear()
            extend(range(1000))
            self.assertEqual(read(), list(range(1000)))
    class StatefulBufferingFn(DoFn):
        """Tag each value with an index and emit ``('<value>_<index>', 1)``.

        A value not yet in the buffer is added to it and gets the count
        state's current value as its index (the count is then incremented).
        A value already buffered gets its position in the buffer's contents.
        """
        BUFFER_STATE = BagStateSpec('buffer', StrUtf8Coder())
        COUNT_STATE = userstate.CombiningValueStateSpec(
            'count', VarIntCoder(), CountCombineFn())

        def process(self,
                    element,
                    buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
                    count_state=beam.DoFn.StateParam(COUNT_STATE)):
            _, value = element
            try:
                # NOTE(review): this assumes the bag preserves insertion
                # order, so the list position matches the count recorded when
                # the value was added — confirm for the runner in use.
                index_value = list(buffer_state.read()).index(value)
            except ValueError:
                # Only "value not found" means a new value; any other error
                # (e.g. a state-access failure) now propagates instead of
                # being silently treated as a first occurrence.
                buffer_state.add(value)
                index_value = count_state.read()
                count_state.add(1)

            yield ('{}_{}'.format(value, index_value), 1)
Exemplo n.º 21
0
class TestStatefulDoFn(DoFn):
    """An example stateful DoFn with state and timers.

    Declares two bag states and three watermark timers. Every method only
    yields the input element or a fixed marker string. NOTE(review): this
    looks like a fixture for exercising DoFn signature / parameter handling
    rather than real processing — confirm against the tests that use it.
    """

    # Bag states: the first holds bytes, the second holds varint-coded ints.
    BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
    BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
    # Event-time (watermark-domain) timers.
    EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
    EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
    EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)

    def process(self,
                element,
                t=DoFn.TimestampParam,
                buffer_1=DoFn.StateParam(BUFFER_STATE_1),
                buffer_2=DoFn.StateParam(BUFFER_STATE_2),
                timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
                timer_2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        """Yield ``element`` unchanged; the state/timer params are unused."""
        yield element

    @on_timer(EXPIRY_TIMER_1)
    def on_expiry_1(self,
                    buffer=DoFn.StateParam(BUFFER_STATE_1),
                    timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
                    timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        """Callback for EXPIRY_TIMER_1; yields the marker 'expired1'."""
        yield 'expired1'

    @on_timer(EXPIRY_TIMER_2)
    def on_expiry_2(self,
                    buffer=DoFn.StateParam(BUFFER_STATE_2),
                    timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        """Callback for EXPIRY_TIMER_2; yields the marker 'expired2'."""
        yield 'expired2'

    @on_timer(EXPIRY_TIMER_3)
    def on_expiry_3(self,
                    buffer_1=DoFn.StateParam(BUFFER_STATE_1),
                    buffer_2=DoFn.StateParam(BUFFER_STATE_2),
                    timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
        """Callback for EXPIRY_TIMER_3; yields the marker 'expired3'."""
        yield 'expired3'
Exemplo n.º 22
0
        class SetStateClearingStatefulDoFn(beam.DoFn):
            """Reset set state to {100} when it reaches five entries.

            When an added value brings the set to exactly five entries, the
            set is cleared, 100 is stored and the timer is armed; the timer
            callback emits the set's contents sorted.
            """

            SET_STATE = SetStateSpec('buffer', VarIntCoder())
            EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)

            def process(self,
                        element,
                        set_state=beam.DoFn.StateParam(SET_STATE),
                        emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
                _, value = element
                set_state.add(value)

                if len(list(set_state.read())) != 5:
                    return
                set_state.clear()
                set_state.add(100)
                emit_timer.set(1)

            @on_timer(EMIT_TIMER)
            def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
                ordered = sorted(set_state.read())
                yield ordered
Exemplo n.º 23
0
class GeneralTriggerManagerDoFn(DoFn):
    """A trigger manager that supports all windowing / triggering cases.

    This implements a DoFn that manages triggering on a per-key basis. All
    elements for a single key are processed together. Per-key state holds data
    related to all windows.
    """

    # TODO(BEAM-12026) Add support for Global and custom window fns.
    # Windows seen for this key; only read/written for merging windowfns.
    KNOWN_WINDOWS = SetStateSpec('known_windows', IntervalWindowCoder())
    # Windows whose trigger reported "finished" on fire; later input for them
    # is dropped and their buffered elements are not kept.
    FINISHED_WINDOWS = SetStateSpec('finished_windows', IntervalWindowCoder())
    # Processing-time timer timestamps, combined with max.
    LAST_KNOWN_TIME = CombiningValueStateSpec('last_known_time',
                                              combine_fn=max)
    # Watermark timer timestamps, combined with max.
    LAST_KNOWN_WATERMARK = CombiningValueStateSpec('last_known_watermark',
                                                   combine_fn=max)

    # TODO(pabloem) What's the coder for the elements/keys here?
    # Buffered (window, TimestampedValue) pairs awaiting a firing.
    WINDOW_ELEMENT_PAIRS = BagStateSpec(
        'all_elements', TupleCoder([IntervalWindowCoder(),
                                    PickleCoder()]))
    WINDOW_TAG_VALUES = BagStateSpec(
        'per_window_per_tag_value_state',
        TupleCoder([IntervalWindowCoder(),
                    StrUtf8Coder(),
                    VarIntCoder()]))

    PROCESSING_TIME_TIMER = TimerSpec('processing_time_timer',
                                      TimeDomain.REAL_TIME)
    WATERMARK_TIMER = TimerSpec('watermark_timer', TimeDomain.WATERMARK)

    def __init__(self, windowing: Windowing):
        self.windowing = windowing
        # Only session windows are merging. Other windows are non-merging.
        self.merging_windows = self.windowing.windowfn.is_merging()

    def process(
            self,
            element: typing.Tuple[
                K, typing.Iterable[windowed_value.WindowedValue]],
            all_elements: BagRuntimeState = DoFn.StateParam(
                WINDOW_ELEMENT_PAIRS),  # type: ignore
            latest_processing_time: AccumulatingRuntimeState = DoFn.StateParam(
                LAST_KNOWN_TIME),  # type: ignore
            latest_watermark: AccumulatingRuntimeState = DoFn.StateParam(
                LAST_KNOWN_WATERMARK),  # type: ignore
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            windows_state: SetRuntimeState = DoFn.StateParam(
                KNOWN_WINDOWS),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.StateParam(
                FINISHED_WINDOWS),  # type: ignore
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER),
            *args,
            **kwargs):
        """Buffers one key's windowed values and fires any ready windows.

        Values falling into expired or already-finished windows are dropped.
        For merging windowfns, new windows are merged with the known windows
        and values are re-assigned to their surviving window.

        Args:
          element: a ``(key, windowed_values)`` pair.
          (remaining parameters are state/timer handles injected by the runner)

        Returns:
          The generator from ``_fire_eligible_windows``, restricted to the
          windows touched by this call, yielding ``(key, [WindowedValue])``.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        key, windowed_values = element
        watermark = read_watermark(latest_watermark)
        # Read the finished-window set once; it is not modified in the loop
        # below, so re-reading it per (value, window) pair was wasted work.
        finished_windows = set(finished_windows_state.read())

        windows_to_elements = collections.defaultdict(list)
        for wv in windowed_values:
            for window in wv.windows:
                # ignore expired windows
                if watermark > window.end + self.windowing.allowed_lateness:
                    continue
                if window in finished_windows:
                    continue
                windows_to_elements[window].append(
                    TimestampedValue(wv.value, wv.timestamp))

        # Processing merging of windows
        if self.merging_windows:
            old_windows = set(windows_state.read())
            all_windows = old_windows.union(windows_to_elements)
            if all_windows != old_windows:
                merge_context = TriggerMergeContext(all_windows, context,
                                                    self.windowing)
                self.windowing.windowfn.merge(merge_context)

                merged_windows_to_elements = collections.defaultdict(list)
                for window, values in windows_to_elements.items():
                    # Follow the merge chain to the window that survived.
                    while window in merge_context.merged_away:
                        window = merge_context.merged_away[window]
                    merged_windows_to_elements[window].extend(values)
                windows_to_elements = merged_windows_to_elements

            for w in windows_to_elements:
                windows_state.add(w)
        # Done processing merging of windows

        seen_windows = set()
        for w, values in windows_to_elements.items():
            window_context = context.for_window(w)
            seen_windows.add(w)
            for value_w_timestamp in values:
                _LOGGER.debug(value_w_timestamp)
                all_elements.add((w, value_w_timestamp))
                # Notify the trigger about this single element (previously the
                # whole input iterable was passed on every iteration).
                self.windowing.triggerfn.on_element(value_w_timestamp.value, w,
                                                    window_context)

        return self._fire_eligible_windows(key, TimeDomain.WATERMARK,
                                           watermark, None, context,
                                           seen_windows)

    def _fire_eligible_windows(self,
                               key: K,
                               time_domain,
                               timestamp: Timestamp,
                               timer_tag: typing.Optional[str],
                               context: 'FnRunnerStatefulTriggerContext',
                               windows_of_interest: typing.Optional[
                                   typing.Set[BoundedWindow]] = None):
        """Generator: fires ready windows and rewrites the element buffer.

        The buffered elements are read into memory and the buffer state is
        cleared. Each window (optionally limited to ``windows_of_interest``)
        whose trigger says it should fire is emitted as
        ``(key, [WindowedValue, ...])``. Afterwards, elements are written back
        to the buffer unless their window is finished, or it fired while the
        accumulation mode is DISCARDING.

        Being a generator, none of this runs until the result is iterated.
        """
        windows_to_elements = context.windows_to_elements_map()
        context.all_elements_state.clear()

        fired_windows = set()
        _LOGGER.debug('%s - tag %s - timestamp %s', time_domain, timer_tag,
                      timestamp)
        for w, elems in windows_to_elements.items():
            if windows_of_interest is not None and w not in windows_of_interest:
                # windows_of_interest=None means that we care about all windows.
                # If we care only about some windows, and this window is not one
                # of them, then we do not intend to fire this window.
                continue
            window_context = context.for_window(w)
            if not self.windowing.triggerfn.should_fire(time_domain, timestamp,
                                                        w, window_context):
                continue
            finished = self.windowing.triggerfn.on_fire(
                timestamp, w, window_context)
            _LOGGER.debug('Firing on window %s. Finished: %s', w, finished)
            fired_windows.add(w)
            if finished:
                context.finished_windows_state.add(w)
            # TODO(pabloem): Format the output: e.g. pane info
            yield (key, [
                WindowedValue(e.value, e.timestamp, (w, )) for e in elems
            ])

        finished_windows: typing.Set[BoundedWindow] = set(
            context.finished_windows_state.read())
        discarding = (
            self.windowing.accumulation_mode == AccumulationMode.DISCARDING)
        # Add elements that were not fired back into state. The keep/drop
        # decision depends only on the window, so make it once per window.
        for w, elems in windows_to_elements.items():
            if w in finished_windows or (discarding and w in fired_windows):
                continue
            for e in elems:
                context.all_elements_state.add((w, e))

    @on_timer(PROCESSING_TIME_TIMER)
    def processing_time_trigger(
            self,
            key=DoFn.KeyParam,
            timer_tag=DoFn.DynamicTimerTagParam,
            timestamp=DoFn.TimestampParam,
            latest_processing_time=DoFn.StateParam(LAST_KNOWN_TIME),
            all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.StateParam(
                FINISHED_WINDOWS),  # type: ignore
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Processing-time timer callback: fires windows that are ready.

        NOTE(review): ``_fire_eligible_windows`` is a generator function, so
        the call below only creates the generator; ``latest_processing_time``
        is therefore updated before any window is evaluated. Confirm this
        ordering is intended before changing it.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=None,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(key, TimeDomain.REAL_TIME,
                                             timestamp, timer_tag, context)
        latest_processing_time.add(timestamp)
        return result

    @on_timer(WATERMARK_TIMER)
    def watermark_trigger(
            self,
            key=DoFn.KeyParam,
            timer_tag=DoFn.DynamicTimerTagParam,
            timestamp=DoFn.TimestampParam,
            latest_watermark=DoFn.StateParam(LAST_KNOWN_WATERMARK),
            all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.StateParam(
                FINISHED_WINDOWS),  # type: ignore
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Watermark timer callback: fires windows that are ready.

        NOTE(review): as in ``processing_time_trigger``, the generator returned
        by ``_fire_eligible_windows`` is lazy, so ``latest_watermark`` is
        updated before any window is evaluated. Confirm this is intended.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=None,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(key, TimeDomain.WATERMARK,
                                             timestamp, timer_tag, context)
        latest_watermark.add(timestamp)
        return result
Exemplo n.º 24
0
    def test_caching(self):
        """Check which reads CachingStateHandler serves from its cache.

        A fake backend returns a fresh counter value on every uncached read,
        so each expected value below reveals whether the read hit the backend
        or was served from the cache (per bundle, or per cache token).
        """
        coder = VarIntCoder()
        coder_impl = coder.get_impl()

        class FakeUnderlyingState(object):
            """Fake backend whose state "value" is a counter bumped per read."""
            def set_counter(self, n):
                self._counter = n

            def get_raw(self, *args):
                self._counter += 1
                return coder.encode(self._counter), None

            @contextlib.contextmanager
            def process_instruction_id(self, bundle_id):
                yield

        underlying_state = FakeUnderlyingState()
        state_cache = statecache.StateCache(100)
        caching_state_handler = sdk_worker.CachingStateHandler(
            state_cache, underlying_state)

        def user_state_key(state_id):
            return beam_fn_api_pb2.StateKey(
                bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                    user_state_id=state_id))

        state1 = user_state_key('state1')
        state2 = user_state_key('state2')
        side1 = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id='transform', side_input_id='side1'))
        side2 = beam_fn_api_pb2.StateKey(
            iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                transform_id='transform', side_input_id='side2'))

        cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken

        def user_state_token(token):
            return cache_token(token=token, user_state=cache_token.UserState())

        def side1_token(token):
            return cache_token(
                token=token,
                side_input=cache_token.SideInput(transform_id='transform',
                                                 side_input_id='side1'))

        state_token1 = user_state_token(b'state_token1')
        state_token2 = user_state_token(b'state_token2')
        side1_token1 = side1_token(b'side1_token1')
        side1_token2 = side1_token(b'side1_token2')

        # (bundle id, cache tokens, counter start, ordered (key, expected) reads)
        bundles = [
            ('bundle1', [], 100, [
                (state1, 101),  # uncached
                (state2, 102),  # uncached
                (state1, 101),  # cached on bundle
                (side1, 103),  # uncached
                (side2, 104),  # uncached
            ]),
            ('bundle2', [state_token1, side1_token1], 200, [
                (state1, 201),  # uncached
                (state2, 202),  # uncached
                (state1, 201),  # cached on state token1
                (side1, 203),  # uncached
                (side1, 203),  # cached on side1_token1
                (side2, 204),  # uncached
                (side2, 204),  # cached on bundle
            ]),
            ('bundle3', [state_token1, side1_token1], 300, [
                (state1, 201),  # cached on state token1
                (state2, 202),  # cached on state token1
                (state1, 201),  # cached on state token1
                (side1, 203),  # cached on side1_token1
                (side1, 203),  # cached on side1_token1
                (side2, 301),  # uncached
                (side2, 301),  # cached on bundle
            ]),
            ('bundle4', [state_token2, side1_token1], 400, [
                (state1, 401),  # uncached
                (state2, 402),  # uncached
                (state1, 401),  # cached on state token2
                (side1, 203),  # cached on side1_token1
                (side1, 203),  # cached on side1_token1
                (side2, 403),  # uncached
                (side2, 403),  # cached on bundle
            ]),
            ('bundle5', [state_token2, side1_token2], 500, [
                (state1, 401),  # cached on state token2
                (state2, 402),  # cached on state token2
                (state1, 401),  # cached on state token2
                (side1, 501),  # uncached
                (side1, 501),  # cached on side1_token2
                (side2, 502),  # uncached
                (side2, 502),  # cached on bundle
            ]),
        ]

        def read_all(key):
            return list(caching_state_handler.blocking_get(key, coder_impl))

        for bundle_id, tokens, counter_start, reads in bundles:
            underlying_state.set_counter(counter_start)
            with caching_state_handler.process_instruction_id(bundle_id, tokens):
                for key, expected in reads:
                    self.assertEqual(read_all(key), [expected])
Exemplo n.º 25
0
class PayloadBase(object):
    """Shared payload-builder tests, parameterized by two builder hooks.

    Subclasses implement ``get_payload_from_typing_hints`` and
    ``get_payload_from_beam_typehints``; every test compares their output
    with ``get_payload(self.args)``. Uses ``self.assertEqual`` /
    ``self.assertRaises``, so it is presumably mixed into a
    ``unittest.TestCase`` subclass.
    """
    values = {
        'integer_example': 1,
        'boolean': True,
        'string_example': u'thing',
        'list_of_strings': [u'foo', u'bar'],
        'optional_kv': (u'key', 1.1),
        'optional_integer': None,
    }

    bytes_values = {
        'integer_example': 1,
        'boolean': True,
        'string_example': 'thing',
        'list_of_strings': ['foo', 'bar'],
        'optional_kv': ('key', 1.1),
        'optional_integer': None,
    }

    # Expected ConfigValue per field: coder URN(s) plus the nested encoding.
    args = {
        'integer_example': ConfigValue(
            coder_urn=['beam:coder:varint:v1'],
            payload=VarIntCoder().get_impl().encode_nested(
                values['integer_example'])),
        'boolean': ConfigValue(
            coder_urn=['beam:coder:bool:v1'],
            payload=BooleanCoder().get_impl().encode_nested(
                values['boolean'])),
        'string_example': ConfigValue(
            coder_urn=['beam:coder:string_utf8:v1'],
            payload=StrUtf8Coder().get_impl().encode_nested(
                values['string_example'])),
        'list_of_strings': ConfigValue(
            coder_urn=['beam:coder:iterable:v1', 'beam:coder:string_utf8:v1'],
            payload=IterableCoder(StrUtf8Coder()).get_impl().encode_nested(
                values['list_of_strings'])),
        'optional_kv': ConfigValue(
            coder_urn=[
                'beam:coder:kv:v1',
                'beam:coder:string_utf8:v1',
                'beam:coder:double:v1',
            ],
            payload=TupleCoder([StrUtf8Coder(), FloatCoder()
                                ]).get_impl().encode_nested(
                                    values['optional_kv'])),
    }

    def get_payload_from_typing_hints(self, values):
        """Return ExternalConfigurationPayload based on python typing hints"""
        raise NotImplementedError

    def get_payload_from_beam_typehints(self, values):
        """Return ExternalConfigurationPayload based on beam typehints"""
        raise NotImplementedError

    def _assert_matches_args(self, build_payload, values):
        """Assert ``build_payload(values)`` equals the payload for ``args``."""
        self.assertEqual(build_payload(values), get_payload(self.args))

    def test_typing_payload_builder(self):
        self._assert_matches_args(self.get_payload_from_typing_hints,
                                  self.values)

    def test_typing_payload_builder_with_bytes(self):
        """The string_utf8 coder is chosen even for non-unicode (py2) strings."""
        self._assert_matches_args(self.get_payload_from_typing_hints,
                                  self.bytes_values)

    def test_typehints_payload_builder(self):
        self._assert_matches_args(self.get_payload_from_beam_typehints,
                                  self.values)

    def test_typehints_payload_builder_with_bytes(self):
        """The string_utf8 coder is chosen even for non-unicode (py2) strings."""
        self._assert_matches_args(self.get_payload_from_beam_typehints,
                                  self.bytes_values)

    def test_optional_error(self):
        """A None value is rejected unless its typehint is Optional."""
        all_none = {k: None for k in self.values}
        with self.assertRaises(RuntimeError):
            self.get_payload_from_typing_hints(all_none)
Exemplo n.º 26
0
        def test_metrics(self):
            """Run a stateful DoFn and check the state-cache metrics it reports.

            The DoFn increments a counter and reads/appends bag state for each
            element; the test then scans the file written by the FileReporter
            (``self.test_metrics_path``) for the expected metric lines.
            ``streaming`` comes from the enclosing scope.
            """

            counter_name = 'elem_counter'
            state_spec = userstate.BagStateSpec('state', VarIntCoder())

            class DoFn(beam.DoFn):
                def __init__(self):
                    self.counter = Metrics.counter(self.__class__,
                                                   counter_name)
                    # Let logging do the formatting lazily; same output text.
                    _LOGGER.info('counter: %s', self.counter.metric_name)

                def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
                    # Trigger materialization
                    list(state.read())
                    state.add(1)
                    self.counter.inc()

            options = self.create_options()
            # Test only supports parallelism of 1
            options._all_options['parallelism'] = 1
            # Create multiple bundles to test cache metrics
            options._all_options['max_bundle_size'] = 10
            options._all_options['max_bundle_time_millis'] = 95130590130
            experiments = options.view_as(DebugOptions).experiments or []
            experiments.append('state_cache_size=123')
            options.view_as(DebugOptions).experiments = experiments
            with Pipeline(self.get_runner(), options) as p:
                # pylint: disable=expression-not-assigned
                (p
                 | "create" >> beam.Create(list(range(0, 110)))
                 | "mapper" >> beam.Map(lambda x: (x % 10, 'val'))
                 | "stateful" >> beam.ParDo(DoFn()))

            lines_expected = {'counter: 110'}
            if streaming:
                lines_expected.update([
                    # Gauges for the last finished bundle
                    'stateful.beam.metric:statecache:capacity: 123',
                    # These are off by 10 because the first bundle contains all
                    # the keys once. Caching is only initialized after the first
                    # bundle. Caching depends on the cache token which is lazily
                    # initialized by the Runner's StateRequestHandlers.
                    'stateful.beam.metric:statecache:size: 10',
                    'stateful.beam.metric:statecache:get: 10',
                    'stateful.beam.metric:statecache:miss: 0',
                    'stateful.beam.metric:statecache:hit: 10',
                    'stateful.beam.metric:statecache:put: 0',
                    'stateful.beam.metric:statecache:extend: 10',
                    'stateful.beam.metric:statecache:evict: 0',
                    # Counters
                    # (total of get/hit will be off by 10 due to the caching
                    # only getting initialized after the first bundle.
                    # Caching depends on the cache token which is lazily
                    # initialized by the Runner's StateRequestHandlers).
                    'stateful.beam.metric:statecache:get_total: 100',
                    'stateful.beam.metric:statecache:miss_total: 10',
                    'stateful.beam.metric:statecache:hit_total: 90',
                    'stateful.beam.metric:statecache:put_total: 10',
                    'stateful.beam.metric:statecache:extend_total: 100',
                    'stateful.beam.metric:statecache:evict_total: 0',
                ])
            else:
                # Batch has a different processing model. All values for
                # a key are processed at once.
                lines_expected.update([
                    # Gauges
                    'stateful).beam.metric:statecache:capacity: 123',
                    # For the first key, the cache token will not be set yet.
                    # It's lazily initialized after first access in
                    # StateRequestHandlers
                    'stateful).beam.metric:statecache:size: 9',
                    # We have 11 here because there are 110 / 10 elements per key
                    'stateful).beam.metric:statecache:get: 11',
                    'stateful).beam.metric:statecache:miss: 1',
                    'stateful).beam.metric:statecache:hit: 10',
                    # State is flushed back once per key
                    'stateful).beam.metric:statecache:put: 1',
                    'stateful).beam.metric:statecache:extend: 1',
                    'stateful).beam.metric:statecache:evict: 0',
                    # Counters
                    'stateful).beam.metric:statecache:get_total: 99',
                    'stateful).beam.metric:statecache:miss_total: 9',
                    'stateful).beam.metric:statecache:hit_total: 90',
                    'stateful).beam.metric:statecache:put_total: 9',
                    'stateful).beam.metric:statecache:extend_total: 9',
                    'stateful).beam.metric:statecache:evict_total: 0',
                ])
            # Collect every expected metric string that occurs in some line.
            lines_actual = set()
            with open(self.test_metrics_path, 'r') as f:
                for line in f:
                    lines_actual.update(metric_str
                                        for metric_str in lines_expected
                                        if metric_str in line)
            self.assertSetEqual(lines_actual, lines_expected)