Esempio n. 1
0
class StreamGroupAggregateOperation(StatefulFunctionOperation):
    """Stateful operation that runs a streaming group aggregation.

    Incoming records are routed by ``process_element_or_timer`` either to the aggregation
    (regular elements) or to the timer handler (``TRIGGER_TIMER`` records). The per-key
    aggregation itself is performed by the ``GroupAggFunction`` assembled in ``generate_func``.
    """

    def __init__(self, name, spec, counter_factory, sampler, consumers, keyed_state_backend):
        serialized_fn = spec.serialized_fn
        # The attributes below are read by generate_func(), so they are assigned before the
        # base-class constructor runs. NOTE(review): presumably the base constructor is what
        # invokes generate_func(); confirm against StatefulFunctionOperation.
        self.generate_update_before = serialized_fn.generate_update_before
        self.grouping = list(serialized_fn.grouping)
        # Assigned by generate_func(); stays None until then, which teardown() checks for.
        self.group_agg_function = None
        # If the upstream generates retract messages, an additional count1() agg is added to
        # track the current number of accumulated messages. If all the messages are retracted,
        # a DELETE message must be sent downstream.
        self.index_of_count_star = serialized_fn.index_of_count_star
        self.state_cache_size = serialized_fn.state_cache_size
        self.state_cleaning_enabled = serialized_fn.state_cleaning_enabled
        self.data_view_specs = extract_data_view_specs(serialized_fn.udfs)
        # Zero-argument super() binds to this class itself. The explicit
        # super(StreamGroupAggregateOperation, self) form looks the class name up as a module
        # global at call time, which breaks if that name has been rebound to another class.
        super().__init__(name, spec, counter_factory, sampler, consumers, keyed_state_backend)

    def open_func(self):
        """Open the group aggregation function with a context wrapping the base metric group."""
        self.group_agg_function.open(FunctionContext(self.base_metric_group))

    def generate_func(self, serialized_fn):
        """Build the group aggregation function described by ``serialized_fn``.

        Side effect: stores the resulting ``GroupAggFunction`` on ``self.group_agg_function``.

        :param serialized_fn: serialized function spec; each entry of its ``udfs`` describes one
            aggregate, except the one at ``self.index_of_count_star``, which is replaced by the
            injected count1() aggregate.
        :return: a ``(func, [])`` tuple, where ``func`` lazily maps an iterable of input records
            through ``process_element_or_timer``.
        """
        def dummy_input_extractor(value):
            # The count1() agg takes no arguments from the input row.
            return []

        user_defined_aggs = []
        input_extractors = []
        for i, udf in enumerate(serialized_fn.udfs):
            if i == self.index_of_count_star:
                user_defined_agg = Count1AggFunction()
                input_extractor = dummy_input_extractor
            else:
                user_defined_agg, input_extractor = extract_user_defined_aggregate_function(udf)
            user_defined_aggs.append(user_defined_agg)
            input_extractors.append(input_extractor)
        aggs_handler_function = SimpleAggsHandleFunction(
            user_defined_aggs,
            input_extractors,
            self.index_of_count_star,
            self.data_view_specs)
        key_selector = RowKeySelector(self.grouping)
        # Data views need a coder that filters them out of the stored state value; otherwise
        # the accumulators are pickled as-is.
        if len(self.data_view_specs) > 0:
            state_value_coder = DataViewFilterCoder(self.data_view_specs)
        else:
            state_value_coder = PickleCoder()
        self.group_agg_function = GroupAggFunction(
            aggs_handler_function,
            key_selector,
            self.keyed_state_backend,
            state_value_coder,
            self.generate_update_before,
            self.state_cleaning_enabled,
            self.index_of_count_star)
        return lambda it: map(self.process_element_or_timer, it), []

    def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
        """Dispatch one input record to the aggregation or to the timer handler.

        The structure of the input data is
        ``[element_type, element (for process_element), timestamp (for timer), key (for timer)]``;
        all fields are nullable except ``element_type``. The timestamp field is not used here.

        :return: the result of ``process_element`` for regular elements, or an empty list for
            timer records.
        """
        if input_data[0] != TRIGGER_TIMER:
            return self.group_agg_function.process_element(input_data[1])
        self.group_agg_function.on_timer(input_data[3])
        return []

    def teardown(self):
        """Close the aggregation function, if it was built, then run the base teardown."""
        if self.group_agg_function is not None:
            self.group_agg_function.close()
        super().teardown()
Esempio n. 2
0
class StreamGroupAggregateOperation(StatefulFunctionOperation):
    """Stateful operation that runs a streaming group aggregation with distinct/filter support.

    Incoming records are routed by ``process_element_or_timer`` either to the aggregation
    (regular elements) or to the timer handler (``TRIGGER_TIMER`` records). The per-key
    aggregation itself is performed by the ``GroupAggFunction`` assembled in ``generate_func``.
    """

    def __init__(self, spec, keyed_state_backend):
        serialized_fn = spec.serialized_fn
        # The attributes below are read by generate_func(), so they are assigned before the
        # base-class constructor runs. NOTE(review): presumably the base constructor is what
        # invokes generate_func(); confirm against StatefulFunctionOperation.
        self.generate_update_before = serialized_fn.generate_update_before
        self.grouping = list(serialized_fn.grouping)
        # Assigned by generate_func(); stays None until then, which close() checks for.
        self.group_agg_function = None
        # If the upstream generates retract messages, an additional count1() agg is added to
        # track the current number of accumulated messages. If all the messages are retracted,
        # a DELETE message must be sent downstream.
        self.index_of_count_star = serialized_fn.index_of_count_star
        self.count_star_inserted = serialized_fn.count_star_inserted
        self.state_cache_size = serialized_fn.state_cache_size
        self.state_cleaning_enabled = serialized_fn.state_cleaning_enabled
        self.data_view_specs = extract_data_view_specs(serialized_fn.udfs)
        # Zero-argument super() binds to this class itself. The explicit
        # super(StreamGroupAggregateOperation, self) form looks the class name up as a module
        # global at call time, which breaks if that name has been rebound to another class.
        super().__init__(spec, keyed_state_backend)

    def open(self):
        """Open the group aggregation function with a context wrapping the base metric group."""
        self.group_agg_function.open(FunctionContext(self.base_metric_group))

    def generate_func(self, serialized_fn):
        """Build the group aggregation function described by ``serialized_fn``.

        Side effect: stores the resulting ``GroupAggFunction`` on ``self.group_agg_function``.

        :param serialized_fn: serialized function spec; each entry of its ``udfs`` describes one
            aggregate, possibly with a filter argument and a shared distinct view.
        :return: a ``(process_element_or_timer, [])`` tuple.
        """
        user_defined_aggs = []
        input_extractors = []
        filter_args = []
        # Stores the indexes of the distinct views which the agg functions use.
        distinct_indexes = []
        # Maps each shared distinct view to (indexes of the functions sharing it, their filter
        # args). It is filled in by extract_user_defined_aggregate_function().
        distinct_info_dict = {}
        for i, udf in enumerate(serialized_fn.udfs):
            user_defined_agg, input_extractor, filter_arg, distinct_index = \
                extract_user_defined_aggregate_function(i, udf, distinct_info_dict)
            user_defined_aggs.append(user_defined_agg)
            input_extractors.append(input_extractor)
            filter_args.append(filter_arg)
            distinct_indexes.append(distinct_index)
        distinct_view_descriptors = {}
        for agg_index_list, filter_arg_list in distinct_info_dict.values():
            if -1 in filter_arg_list:
                # If there is a non-filter call, there is no need to check the filter before
                # writing the distinct data view.
                filter_arg_list = []
            # The agg index of the first function is used as the key of the shared distinct view.
            first_agg_index = agg_index_list[0]
            distinct_view_descriptors[first_agg_index] = DistinctViewDescriptor(
                input_extractors[first_agg_index], filter_arg_list)
        aggs_handler_function = SimpleAggsHandleFunction(
            user_defined_aggs, input_extractors, self.index_of_count_star,
            self.count_star_inserted, self.data_view_specs, filter_args,
            distinct_indexes, distinct_view_descriptors)
        key_selector = RowKeySelector(self.grouping)
        # Data views need a coder that filters them out of the stored state value; otherwise
        # the accumulators are pickled as-is.
        if len(self.data_view_specs) > 0:
            state_value_coder = DataViewFilterCoder(self.data_view_specs)
        else:
            state_value_coder = PickleCoder()
        self.group_agg_function = GroupAggFunction(
            aggs_handler_function, key_selector, self.keyed_state_backend,
            state_value_coder, self.generate_update_before,
            self.state_cleaning_enabled, self.index_of_count_star)
        return self.process_element_or_timer, []

    def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
        """Dispatch one input record to the aggregation or to the timer handler.

        The structure of the input data is
        ``[element_type, element (for process_element), timestamp (for timer), key (for timer)]``;
        all fields are nullable except ``element_type``. The timestamp field is not used here.

        :return: the result of ``process_element`` for regular elements, or an empty list for
            timer records.
        """
        if input_data[0] != TRIGGER_TIMER:
            return self.group_agg_function.process_element(input_data[1])
        self.group_agg_function.on_timer(input_data[3])
        return []

    def close(self):
        """Close the aggregation function if it was built."""
        if self.group_agg_function is not None:
            self.group_agg_function.close()