class StreamGroupAggregateOperation(StatefulFunctionOperation):

    def __init__(self, name, spec, counter_factory, sampler, consumers, keyed_state_backend):
        self.generate_update_before = spec.serialized_fn.generate_update_before
        self.grouping = [i for i in spec.serialized_fn.grouping]
        self.group_agg_function = None
        # If the upstream generates retract messages, we need to add an additional count1() agg
        # to track the current accumulated message count. If all the messages are retracted, we
        # need to send a DELETE message downstream.
        self.index_of_count_star = spec.serialized_fn.index_of_count_star
        self.state_cache_size = spec.serialized_fn.state_cache_size
        self.state_cleaning_enabled = spec.serialized_fn.state_cleaning_enabled
        self.data_view_specs = extract_data_view_specs(spec.serialized_fn.udfs)
        super(StreamGroupAggregateOperation, self).__init__(
            name, spec, counter_factory, sampler, consumers, keyed_state_backend)

    def open_func(self):
        self.group_agg_function.open(FunctionContext(self.base_metric_group))

    def generate_func(self, serialized_fn):
        user_defined_aggs = []
        input_extractors = []
        for i in range(len(serialized_fn.udfs)):
            if i != self.index_of_count_star:
                user_defined_agg, input_extractor = extract_user_defined_aggregate_function(
                    serialized_fn.udfs[i])
            else:
                user_defined_agg = Count1AggFunction()

                def dummy_input_extractor(value):
                    return []
                input_extractor = dummy_input_extractor
            user_defined_aggs.append(user_defined_agg)
            input_extractors.append(input_extractor)
        aggs_handler_function = SimpleAggsHandleFunction(
            user_defined_aggs,
            input_extractors,
            self.index_of_count_star,
            self.data_view_specs)
        key_selector = RowKeySelector(self.grouping)
        if len(self.data_view_specs) > 0:
            state_value_coder = DataViewFilterCoder(self.data_view_specs)
        else:
            state_value_coder = PickleCoder()
        self.group_agg_function = GroupAggFunction(
            aggs_handler_function,
            key_selector,
            self.keyed_state_backend,
            state_value_coder,
            self.generate_update_before,
            self.state_cleaning_enabled,
            self.index_of_count_star)
        return lambda it: map(self.process_element_or_timer, it), []

    def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
        # the structure of the input data:
        # [element_type, element (for process_element), timestamp (for timer), key (for timer)]
        # all the fields are nullable except the "element_type"
        if input_data[0] != TRIGGER_TIMER:
            return self.group_agg_function.process_element(input_data[1])
        else:
            self.group_agg_function.on_timer(input_data[3])
            return []

    def teardown(self):
        if self.group_agg_function is not None:
            self.group_agg_function.close()
        super().teardown()
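# The [element_type, element, timestamp, key] dispatch contract above can be
# exercised in isolation. A minimal sketch, assuming NORMAL_RECORD = 0 and
# TRIGGER_TIMER = 1 and using a stub in place of GroupAggFunction; these
# stand-ins are illustrative assumptions, not PyFlink's actual internals.
NORMAL_RECORD = 0
TRIGGER_TIMER = 1

class StubGroupAggFunction:
    def process_element(self, row):
        # a real GroupAggFunction would update keyed state and emit
        # insert/update/delete messages; here we just echo the row
        return ["+I[%s]" % (row,)]

    def on_timer(self, key):
        print("state cleanup timer fired for key", key)

def dispatch(agg, input_data):
    # mirrors process_element_or_timer: records go to process_element,
    # while timers carry the key in the fourth slot
    if input_data[0] != TRIGGER_TIMER:
        return agg.process_element(input_data[1])
    else:
        agg.on_timer(input_data[3])
        return []

stub = StubGroupAggFunction()
print(dispatch(stub, (NORMAL_RECORD, ("a", 1), None, None)))  # ["+I[('a', 1)]"]
dispatch(stub, (TRIGGER_TIMER, None, 1234567890, ("a",)))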
class StreamGroupAggregateOperation(StatefulFunctionOperation):

    def __init__(self, spec, keyed_state_backend):
        self.generate_update_before = spec.serialized_fn.generate_update_before
        self.grouping = [i for i in spec.serialized_fn.grouping]
        self.group_agg_function = None
        # If the upstream generates retract messages, we need to add an additional count1() agg
        # to track the current accumulated message count. If all the messages are retracted, we
        # need to send a DELETE message downstream.
        self.index_of_count_star = spec.serialized_fn.index_of_count_star
        self.count_star_inserted = spec.serialized_fn.count_star_inserted
        self.state_cache_size = spec.serialized_fn.state_cache_size
        self.state_cleaning_enabled = spec.serialized_fn.state_cleaning_enabled
        self.data_view_specs = extract_data_view_specs(spec.serialized_fn.udfs)
        super(StreamGroupAggregateOperation, self).__init__(spec, keyed_state_backend)

    def open(self):
        self.group_agg_function.open(FunctionContext(self.base_metric_group))

    def close(self):
        self.group_agg_function.close()

    def generate_func(self, serialized_fn):
        user_defined_aggs = []
        input_extractors = []
        filter_args = []
        # stores the indexes of the distinct views used by the agg functions
        distinct_indexes = []
        # stores the indexes of the functions which share the same distinct view
        # and their filter args
        distinct_info_dict = {}
        for i in range(len(serialized_fn.udfs)):
            user_defined_agg, input_extractor, filter_arg, distinct_index = \
                extract_user_defined_aggregate_function(
                    i, serialized_fn.udfs[i], distinct_info_dict)
            user_defined_aggs.append(user_defined_agg)
            input_extractors.append(input_extractor)
            filter_args.append(filter_arg)
            distinct_indexes.append(distinct_index)
        distinct_view_descriptors = {}
        for agg_index_list, filter_arg_list in distinct_info_dict.values():
            if -1 in filter_arg_list:
                # If there is a non-filtered call, we don't need to check the filters
                # before writing to the distinct data view.
                filter_arg_list = []
            # use the agg index of the first function as the key of the shared distinct view
            distinct_view_descriptors[agg_index_list[0]] = DistinctViewDescriptor(
                input_extractors[agg_index_list[0]], filter_arg_list)
        aggs_handler_function = SimpleAggsHandleFunction(
            user_defined_aggs,
            input_extractors,
            self.index_of_count_star,
            self.count_star_inserted,
            self.data_view_specs,
            filter_args,
            distinct_indexes,
            distinct_view_descriptors)
        key_selector = RowKeySelector(self.grouping)
        if len(self.data_view_specs) > 0:
            state_value_coder = DataViewFilterCoder(self.data_view_specs)
        else:
            state_value_coder = PickleCoder()
        self.group_agg_function = GroupAggFunction(
            aggs_handler_function,
            key_selector,
            self.keyed_state_backend,
            state_value_coder,
            self.generate_update_before,
            self.state_cleaning_enabled,
            self.index_of_count_star)
        return self.process_element_or_timer, []

    def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
        # the structure of the input data:
        # [element_type, element (for process_element), timestamp (for timer), key (for timer)]
        # all the fields are nullable except the "element_type"
        if input_data[0] != TRIGGER_TIMER:
            return self.group_agg_function.process_element(input_data[1])
        else:
            self.group_agg_function.on_timer(input_data[3])
            return []
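# The shared-distinct-view reduction in generate_func above can be shown with
# a minimal standalone sketch. SimpleDescriptor stands in for
# DistinctViewDescriptor, and the sample distinct_info_dict is invented for
# illustration: two aggregates (indexes 0 and 2) share one distinct view on
# the same column, one filtered (filter arg 3) and one unfiltered (filter
# arg -1).
class SimpleDescriptor:
    def __init__(self, input_extractor, filter_args):
        self.input_extractor = input_extractor
        self.filter_args = filter_args

input_extractors = [lambda row: [row[0]], None, lambda row: [row[0]]]
distinct_info_dict = {"distinct_on_col0": ([0, 2], [3, -1])}

distinct_view_descriptors = {}
for agg_index_list, filter_arg_list in distinct_info_dict.values():
    if -1 in filter_arg_list:
        # an unfiltered call means every row reaches the distinct view,
        # so per-function filters no longer gate writes to it
        filter_arg_list = []
    # the first function's agg index keys the shared view
    distinct_view_descriptors[agg_index_list[0]] = SimpleDescriptor(
        input_extractors[agg_index_list[0]], filter_arg_list)

assert list(distinct_view_descriptors) == [0]
assert distinct_view_descriptors[0].filter_args == []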