示例#1
0
 def get_map_state(self, state_descriptor: MapStateDescriptor) -> MapState:
     """Return the keyed MapState described by *state_descriptor*.

     Map keys and values are encoded with ``PickleCoder``. Only available
     to functions running on a KeyedStream.
     """
     backend = self._keyed_state_backend
     if not backend:
         raise Exception(
             "This state is only accessible by functions executed on a KeyedStream."
         )
     return backend.get_map_state(
         state_descriptor.name, PickleCoder(), PickleCoder())
示例#2
0
def extract_data_view_specs_from_accumulator(current_index, accumulator):
    """Extract data view specs from a built-in aggregate's accumulator.

    Scans the accumulator fields and builds a spec for every ``MapView`` /
    ``ListView`` found, naming each ``builtInAgg%df%d`` from the aggregate
    index and the field index.

    :param current_index: index of the aggregate function being inspected.
    :param accumulator: iterable of accumulator fields.
    :return: list of ``MapViewSpec`` / ``ListViewSpec`` for the view fields.
    """
    # for built in functions we extract the data view specs from their accumulator
    extracted_specs = []
    # TODO: infer the coder from the input types and output type of the built-in functions
    for i, field in enumerate(accumulator):
        if isinstance(field, MapView):
            extracted_specs.append(
                MapViewSpec("builtInAgg%df%d" % (current_index, i), i,
                            PickleCoder(), PickleCoder()))
        elif isinstance(field, ListView):
            extracted_specs.append(
                ListViewSpec("builtInAgg%df%d" % (current_index, i), i,
                             PickleCoder()))
    return extracted_specs
示例#3
0
    def generate_func(self, serialized_fn):
        """Assemble the group-aggregate machinery for *serialized_fn*.

        Stores a ``GroupAggFunction`` on ``self.group_agg_function`` and
        returns ``(mapper, [])`` where the mapper applies
        ``process_element_or_timer`` across an input iterable.
        """
        user_defined_aggs = []
        input_extractors = []

        def _ignore_input(value):
            # COUNT(*) consumes no arguments from the input row.
            return []

        for idx, udf in enumerate(serialized_fn.udfs):
            if idx == self.index_of_count_star:
                # COUNT(*) is served by a built-in implementation.
                agg = Count1AggFunction()
                extractor = _ignore_input
            else:
                agg, extractor = extract_user_defined_aggregate_function(udf)
            user_defined_aggs.append(agg)
            input_extractors.append(extractor)

        aggs_handler_function = SimpleAggsHandleFunction(
            user_defined_aggs,
            input_extractors,
            self.index_of_count_star,
            self.data_view_specs)
        key_selector = RowKeySelector(self.grouping)

        # Accumulator state needs a filtering coder when data views exist.
        state_value_coder = (
            DataViewFilterCoder(self.data_view_specs)
            if self.data_view_specs else PickleCoder())

        self.group_agg_function = GroupAggFunction(
            aggs_handler_function,
            key_selector,
            self.keyed_state_backend,
            state_value_coder,
            self.generate_update_before,
            self.state_cleaning_enabled,
            self.index_of_count_star)
        return lambda it: map(self.process_element_or_timer, it), []
示例#4
0
 def get_aggregating_state(
         self, state_descriptor: AggregatingStateDescriptor) -> AggregatingState:
     """Return the keyed AggregatingState described by *state_descriptor*.

     Only available to functions running on a KeyedStream.
     """
     if not self._keyed_state_backend:
         raise Exception("This state is only accessible by functions executed on a KeyedStream.")
     return self._keyed_state_backend.get_aggregating_state(
         state_descriptor.get_name(), PickleCoder(), state_descriptor.get_agg_function())
示例#5
0
 def get_partitioned_state(self,
                           state_descriptor: StateDescriptor) -> State:
     """Create a window-namespaced state for the given descriptor.

     Supports value, list and map state descriptors; the returned state is
     scoped to the current window via ``set_current_namespace``.

     :param state_descriptor: descriptor of the state to create.
     :return: the state, namespaced to ``self.window``.
     :raises Exception: if the descriptor type is not supported.
     """
     if isinstance(state_descriptor, ValueStateDescriptor):
         state = self._state_backend.get_value_state(
             state_descriptor.name, PickleCoder())
     elif isinstance(state_descriptor, ListStateDescriptor):
         state = self._state_backend.get_list_state(state_descriptor.name,
                                                    PickleCoder())
     elif isinstance(state_descriptor, MapStateDescriptor):
         state = self._state_backend.get_map_state(state_descriptor.name,
                                                   PickleCoder(),
                                                   PickleCoder())
     else:
         # Fixed wording: the original message read "Unknown supported
         # StateDescriptor", which is self-contradictory.
         raise Exception("Unsupported StateDescriptor %s" %
                         state_descriptor)
     state.set_current_namespace(self.window)
     return state
示例#6
0
    def generate_func(self, serialized_fn):
        """Build the group-aggregate processing function from *serialized_fn*.

        Extracts one aggregate function (plus its input extractor, filter
        argument and distinct-view index) per serialized UDF, wires the
        shared distinct-view descriptors, and stores the resulting
        ``GroupAggFunction`` on ``self.group_agg_function``.

        :param serialized_fn: serialized function spec; the aggregate UDFs
            are read from ``serialized_fn.udfs``.
        :return: a tuple ``(self.process_element_or_timer, [])``.
        """
        user_defined_aggs = []
        input_extractors = []
        filter_args = []
        # stores the indexes of the distinct views which the agg functions used
        distinct_indexes = []
        # stores the indexes of the functions which share the same distinct view
        # and the filter args of them
        distinct_info_dict = {}
        for i in range(len(serialized_fn.udfs)):
            if i != self.index_of_count_star:
                user_defined_agg, input_extractor, filter_arg, distinct_index = \
                    extract_user_defined_aggregate_function(
                        i, serialized_fn.udfs[i], distinct_info_dict)
            else:
                # COUNT(*) is served by a built-in implementation that
                # ignores its input; it is never filtered nor distinct.
                user_defined_agg = Count1AggFunction()
                filter_arg = -1
                distinct_index = -1

                def dummy_input_extractor(value):
                    return []
                input_extractor = dummy_input_extractor
            user_defined_aggs.append(user_defined_agg)
            input_extractors.append(input_extractor)
            filter_args.append(filter_arg)
            distinct_indexes.append(distinct_index)
        distinct_view_descriptors = {}
        for agg_index_list, filter_arg_list in distinct_info_dict.values():
            if -1 in filter_arg_list:
                # If there is a non-filter call, we don't need to check filter or not before
                # writing the distinct data view.
                filter_arg_list = []
            # use the agg index of the first function as the key of shared distinct view
            distinct_view_descriptors[agg_index_list[0]] = DistinctViewDescriptor(
                input_extractors[agg_index_list[0]], filter_arg_list)
        aggs_handler_function = SimpleAggsHandleFunction(
            user_defined_aggs,
            input_extractors,
            self.index_of_count_star,
            self.data_view_specs,
            filter_args,
            distinct_indexes,
            distinct_view_descriptors)
        key_selector = RowKeySelector(self.grouping)
        # Accumulator state needs a filtering coder when data views exist;
        # otherwise plain pickling suffices.
        if len(self.data_view_specs) > 0:
            state_value_coder = DataViewFilterCoder(self.data_view_specs)
        else:
            state_value_coder = PickleCoder()
        self.group_agg_function = GroupAggFunction(
            aggs_handler_function,
            key_selector,
            self.keyed_state_backend,
            state_value_coder,
            self.generate_update_before,
            self.state_cleaning_enabled,
            self.index_of_count_star)
        return self.process_element_or_timer, []
示例#7
0
 def open(self, state_data_view_store):
     """Open every wrapped UDF and materialize its state-backed data views.

     Builds ``self._udf_data_views`` as one dict per UDF, mapping the
     accumulator field index to the list/map view created by the store.
     """
     for udf in self._udfs:
         udf.open(state_data_view_store.get_runtime_context())
     self._udf_data_views = []
     for specs in self._udf_data_view_specs:
         views = {}
         for spec in specs:
             if isinstance(spec, ListViewSpec):
                 views[spec.field_index] = \
                     state_data_view_store.get_state_list_view(
                         spec.state_id, PickleCoder())
             elif isinstance(spec, MapViewSpec):
                 views[spec.field_index] = \
                     state_data_view_store.get_state_map_view(
                         spec.state_id, PickleCoder(), PickleCoder())
         self._udf_data_views.append(views)
示例#8
0
 def __init__(self, aggs_handle: AggsHandleFunction,
              key_selector: RowKeySelector,
              state_backend: RemoteKeyedStateBackend,
              generate_update_before: bool, state_cleaning_enabled: bool,
              index_of_count_star: int):
     """Wire together the collaborators of the group-aggregate function."""
     self.aggs_handle = aggs_handle
     self.key_selector = key_selector
     self.state_backend = state_backend
     # Currently we do not support user-defined type accumulator.
     # So any accumulators can be encoded by the PickleCoder.
     self.state_value_coder = PickleCoder()
     self.generate_update_before = generate_update_before
     self.state_cleaning_enabled = state_cleaning_enabled
     self.record_counter = RecordCounter.of(index_of_count_star)
示例#9
0
 def open(self, state_data_view_store):
     """Open every wrapped UDF and create its state-backed data views.

     Additionally creates one shared distinct-view map state per entry in
     ``self._distinct_view_descriptors``, keyed by aggregate index.
     """
     for udf in self._udfs:
         udf.open(state_data_view_store.get_runtime_context())
     self._udf_data_views = []
     for specs in self._udf_data_view_specs:
         views = {}
         for spec in specs:
             if isinstance(spec, ListViewSpec):
                 views[spec.field_index] = \
                     state_data_view_store.get_state_list_view(
                         spec.state_id, PickleCoder())
             elif isinstance(spec, MapViewSpec):
                 views[spec.field_index] = \
                     state_data_view_store.get_state_map_view(
                         spec.state_id, PickleCoder(), PickleCoder())
         self._udf_data_views.append(views)
     for key in self._distinct_view_descriptors:
         self._distinct_data_views[key] = \
             state_data_view_store.get_state_map_view(
                 "agg%ddistinct" % key, PickleCoder(), PickleCoder())
示例#10
0
class PairRecordsFn(beam.DoFn):
  """Pairs two consecutive elements after shuffle.

  The first value seen for a key is buffered in bag state; when the next
  value arrives, the pair is emitted and the buffer is cleared.
  """
  BUFFER = BagStateSpec('buffer', PickleCoder())

  def process(self, element, buffer=beam.DoFn.StateParam(BUFFER)):
    """Yield ``(previous_value, value)`` once two values have been seen."""
    try:
      previous_element = list(buffer.read())[0]
    except IndexError:
      # Narrowed from a bare `except:`: only an empty buffer (IndexError
      # from `[0]`) means "no previous element"; other errors propagate.
      previous_element = []
    unused_key, value = element

    if previous_element:
      yield (previous_element, value)
      buffer.clear()
    else:
      buffer.add(value)
示例#11
0
    def generate_func(self, serialized_fn):
        """Build the group-aggregate processing function from *serialized_fn*.

        Every serialized UDF is handed to
        ``extract_user_defined_aggregate_function`` (no special COUNT(*)
        branch here); shared distinct views are wired afterwards and the
        process function created via ``self.create_process_function`` is
        stored on ``self.group_agg_function``.

        :param serialized_fn: serialized function spec; the aggregate UDFs
            are read from ``serialized_fn.udfs``.
        :return: a tuple ``(self.process_element_or_timer, [])``.
        """
        user_defined_aggs = []
        input_extractors = []
        filter_args = []
        # stores the indexes of the distinct views which the agg functions used
        distinct_indexes = []
        # stores the indexes of the functions which share the same distinct view
        # and the filter args of them
        distinct_info_dict = {}
        for i in range(len(serialized_fn.udfs)):
            user_defined_agg, input_extractor, filter_arg, distinct_index = \
                extract_user_defined_aggregate_function(
                    i, serialized_fn.udfs[i], distinct_info_dict)
            user_defined_aggs.append(user_defined_agg)
            input_extractors.append(input_extractor)
            filter_args.append(filter_arg)
            distinct_indexes.append(distinct_index)
        distinct_view_descriptors = {}
        for agg_index_list, filter_arg_list in distinct_info_dict.values():
            if -1 in filter_arg_list:
                # If there is a non-filter call, we don't need to check filter or not before
                # writing the distinct data view.
                filter_arg_list = []
            # use the agg index of the first function as the key of shared distinct view
            distinct_view_descriptors[
                agg_index_list[0]] = DistinctViewDescriptor(
                    input_extractors[agg_index_list[0]], filter_arg_list)

        key_selector = RowKeySelector(self.grouping)
        # Accumulator state needs a filtering coder when data views exist;
        # otherwise plain pickling suffices.
        if len(self.data_view_specs) > 0:
            state_value_coder = DataViewFilterCoder(self.data_view_specs)
        else:
            state_value_coder = PickleCoder()

        self.group_agg_function = self.create_process_function(
            user_defined_aggs, input_extractors, filter_args, distinct_indexes,
            distinct_view_descriptors, key_selector, state_value_coder)

        return self.process_element_or_timer, []
示例#12
0
 def get_list_state(self,
                    state_descriptor: ListStateDescriptor) -> ListState:
     """Return the keyed ListState named by *state_descriptor*."""
     backend = self._keyed_state_backend
     return backend.get_list_state(state_descriptor.name, PickleCoder())
示例#13
0
 def get_state(self, state_descriptor: ValueStateDescriptor) -> ValueState:
     """Return the keyed ValueState named by *state_descriptor*."""
     backend = self._keyed_state_backend
     return backend.get_value_state(state_descriptor.name, PickleCoder())
示例#14
0
 def get_reducing_state(
         self, state_descriptor: ReducingStateDescriptor) -> ReducingState:
     """Return the keyed ReducingState described by *state_descriptor*."""
     return self._keyed_state_backend.get_reducing_state(
         state_descriptor.get_name(),
         PickleCoder(),
         state_descriptor.get_reduce_function())
示例#15
0
 def get_map_state(self, state_descriptor: MapStateDescriptor) -> MapState:
     """Return the keyed MapState named by *state_descriptor*."""
     return self._keyed_state_backend.get_map_state(
         state_descriptor.name, PickleCoder(), PickleCoder())
示例#16
0
 def get_aggregating_state(
         self,
         state_descriptor: AggregatingStateDescriptor) -> AggregatingState:
     """Return the keyed AggregatingState described by *state_descriptor*."""
     backend = self._keyed_state_backend
     return backend.get_aggregating_state(
         state_descriptor.get_name(),
         PickleCoder(),
         state_descriptor.get_agg_function())
示例#17
0
class GeneralTriggerManagerDoFn(DoFn):
    """A trigger manager that supports all windowing / triggering cases.

  This implements a DoFn that manages triggering in a per-key basis. All
  elements for a single key are processed together. Per-key state holds data
  related to all windows.
  """

    # TODO(BEAM-12026) Add support for Global and custom window fns.
    KNOWN_WINDOWS = SetStateSpec('known_windows', IntervalWindowCoder())
    FINISHED_WINDOWS = SetStateSpec('finished_windows', IntervalWindowCoder())
    LAST_KNOWN_TIME = CombiningValueStateSpec('last_known_time',
                                              combine_fn=max)
    LAST_KNOWN_WATERMARK = CombiningValueStateSpec('last_known_watermark',
                                                   combine_fn=max)

    # TODO(pabloem) What's the coder for the elements/keys here?
    WINDOW_ELEMENT_PAIRS = BagStateSpec(
        'all_elements', TupleCoder([IntervalWindowCoder(),
                                    PickleCoder()]))
    WINDOW_TAG_VALUES = BagStateSpec(
        'per_window_per_tag_value_state',
        TupleCoder([IntervalWindowCoder(),
                    StrUtf8Coder(),
                    VarIntCoder()]))

    PROCESSING_TIME_TIMER = TimerSpec('processing_time_timer',
                                      TimeDomain.REAL_TIME)
    WATERMARK_TIMER = TimerSpec('watermark_timer', TimeDomain.WATERMARK)

    def __init__(self, windowing: Windowing):
        """Remember the windowing strategy and whether its window fn merges."""
        self.windowing = windowing
        # Only session windows are merging. Other windows are non-merging.
        self.merging_windows = self.windowing.windowfn.is_merging()

    def process(
            self,
            element: typing.Tuple[
                K, typing.Iterable[windowed_value.WindowedValue]],
            all_elements: BagRuntimeState = DoFn.StateParam(
                WINDOW_ELEMENT_PAIRS),  # type: ignore
            latest_processing_time: AccumulatingRuntimeState = DoFn.StateParam(
                LAST_KNOWN_TIME),  # type: ignore
            latest_watermark: AccumulatingRuntimeState = DoFn.
        StateParam(  # type: ignore
            LAST_KNOWN_WATERMARK),
            window_tag_values: BagRuntimeState = DoFn.StateParam(
                WINDOW_TAG_VALUES),  # type: ignore
            windows_state: SetRuntimeState = DoFn.StateParam(
                KNOWN_WINDOWS),  # type: ignore
            finished_windows_state: SetRuntimeState = DoFn.
        StateParam(  # type: ignore
            FINISHED_WINDOWS),
            processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
            watermark_timer=DoFn.TimerParam(WATERMARK_TIMER),
            *args,
            **kwargs):
        """Process all windowed values buffered for a single key.

        Discards expired and already-finished windows, merges windows when
        the window fn is merging, feeds each element to the trigger, and
        returns the generator of fired ``(key, [WindowedValue])`` panes from
        ``_fire_eligible_windows``.
        """
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        key, windowed_values = element
        watermark = read_watermark(latest_watermark)

        # Group incoming values by window, skipping anything that can no
        # longer fire (past allowed lateness, or in a finished window).
        windows_to_elements = collections.defaultdict(list)
        for wv in windowed_values:
            for window in wv.windows:
                # ignore expired windows
                if watermark > window.end + self.windowing.allowed_lateness:
                    continue
                if window in finished_windows_state.read():
                    continue
                windows_to_elements[window].append(
                    TimestampedValue(wv.value, wv.timestamp))

        # Processing merging of windows
        if self.merging_windows:
            old_windows = set(windows_state.read())
            all_windows = old_windows.union(list(windows_to_elements))
            if all_windows != old_windows:
                merge_context = TriggerMergeContext(all_windows, context,
                                                    self.windowing)
                self.windowing.windowfn.merge(merge_context)

                # Re-bucket the grouped elements under the surviving windows,
                # following chains of merged-away windows to their target.
                merged_windows_to_elements = collections.defaultdict(list)
                for window, values in windows_to_elements.items():
                    while window in merge_context.merged_away:
                        window = merge_context.merged_away[window]
                    merged_windows_to_elements[window].extend(values)
                windows_to_elements = merged_windows_to_elements

            for w in windows_to_elements:
                windows_state.add(w)
        # Done processing merging of windows

        # Buffer every element into state and notify the trigger of each one.
        seen_windows = set()
        for w in windows_to_elements:
            window_context = context.for_window(w)
            seen_windows.add(w)
            for value_w_timestamp in windows_to_elements[w]:
                _LOGGER.debug(value_w_timestamp)
                all_elements.add((w, value_w_timestamp))
                self.windowing.triggerfn.on_element(windowed_values, w,
                                                    window_context)

        return self._fire_eligible_windows(key, TimeDomain.WATERMARK,
                                           watermark, None, context,
                                           seen_windows)

    def _fire_eligible_windows(self,
                               key: K,
                               time_domain,
                               timestamp: Timestamp,
                               timer_tag: typing.Optional[str],
                               context: 'FnRunnerStatefulTriggerContext',
                               windows_of_interest: typing.Optional[
                                   typing.Set[BoundedWindow]] = None):
        """Fire every window eligible at *timestamp*; re-buffer the rest.

        Yields ``(key, elements)`` for each window the trigger fires.
        Elements of finished windows — and, in DISCARDING accumulation mode,
        of fired windows — are not written back into state.
        """
        windows_to_elements = context.windows_to_elements_map()
        # State is cleared up front; unfired elements are re-added below.
        context.all_elements_state.clear()

        fired_windows = set()
        _LOGGER.debug('%s - tag %s - timestamp %s', time_domain, timer_tag,
                      timestamp)
        for w, elems in windows_to_elements.items():
            if windows_of_interest is not None and w not in windows_of_interest:
                # windows_of_interest=None means that we care about all windows.
                # If we care only about some windows, and this window is not one of
                # them, then we do not intend to fire this window.
                continue
            window_context = context.for_window(w)
            if self.windowing.triggerfn.should_fire(time_domain, timestamp, w,
                                                    window_context):
                finished = self.windowing.triggerfn.on_fire(
                    timestamp, w, window_context)
                _LOGGER.debug('Firing on window %s. Finished: %s', w, finished)
                fired_windows.add(w)
                if finished:
                    context.finished_windows_state.add(w)
                # TODO(pabloem): Format the output: e.g. pane info
                elems = [
                    WindowedValue(e.value, e.timestamp, (w, )) for e in elems
                ]
                yield (key, elems)

        finished_windows: typing.Set[BoundedWindow] = set(
            context.finished_windows_state.read())
        # Add elements that were not fired back into state.
        for w, elems in windows_to_elements.items():
            for e in elems:
                if (w in finished_windows or
                    (w in fired_windows and self.windowing.accumulation_mode
                     == AccumulationMode.DISCARDING)):
                    continue
                context.all_elements_state.add((w, e))

    @on_timer(PROCESSING_TIME_TIMER)
    def processing_time_trigger(
        self,
        key=DoFn.KeyParam,
        timer_tag=DoFn.DynamicTimerTagParam,
        timestamp=DoFn.TimestampParam,
        latest_processing_time=DoFn.StateParam(LAST_KNOWN_TIME),
        all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
        processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
        window_tag_values: BagRuntimeState = DoFn.StateParam(
            WINDOW_TAG_VALUES),  # type: ignore
        finished_windows_state: SetRuntimeState = DoFn.
        StateParam(  # type: ignore
            FINISHED_WINDOWS),
        watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Timer handler: fire windows eligible in the REAL_TIME domain."""
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=latest_processing_time,
            latest_watermark=None,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(key, TimeDomain.REAL_TIME,
                                             timestamp, timer_tag, context)
        # Record the firing time so later timers see monotonic progress
        # (LAST_KNOWN_TIME combines with max).
        latest_processing_time.add(timestamp)
        return result

    @on_timer(WATERMARK_TIMER)
    def watermark_trigger(
        self,
        key=DoFn.KeyParam,
        timer_tag=DoFn.DynamicTimerTagParam,
        timestamp=DoFn.TimestampParam,
        latest_watermark=DoFn.StateParam(LAST_KNOWN_WATERMARK),
        all_elements=DoFn.StateParam(WINDOW_ELEMENT_PAIRS),
        processing_time_timer=DoFn.TimerParam(PROCESSING_TIME_TIMER),
        window_tag_values: BagRuntimeState = DoFn.StateParam(
            WINDOW_TAG_VALUES),  # type: ignore
        finished_windows_state: SetRuntimeState = DoFn.
        StateParam(  # type: ignore
            FINISHED_WINDOWS),
        watermark_timer=DoFn.TimerParam(WATERMARK_TIMER)):
        """Timer handler: fire windows eligible in the WATERMARK domain."""
        context = FnRunnerStatefulTriggerContext(
            processing_time_timer=processing_time_timer,
            watermark_timer=watermark_timer,
            latest_processing_time=None,
            latest_watermark=latest_watermark,
            all_elements_state=all_elements,
            window_tag_values=window_tag_values,
            finished_windows_state=finished_windows_state)
        result = self._fire_eligible_windows(key, TimeDomain.WATERMARK,
                                             timestamp, timer_tag, context)
        # Advance the recorded watermark (LAST_KNOWN_WATERMARK combines
        # with max).
        latest_watermark.add(timestamp)
        return result