def test_flat_aggregate_list_view(self):
    """Check a ListView-backed table aggregate while state-cache eviction happens.

    The Python bundle size and the Python state cache size are both set to
    "2", so cached state gets evicted in the middle of a bundle. The rows are
    grouped by column ``c``, concatenated per group with ``,`` via
    ``ListViewConcatTableAggregateFunction`` (defined elsewhere), and the
    sorted result is compared with a fixed expected frame.
    """
    import pandas as pd

    concat_fn = udtaf(ListViewConcatTableAggregateFunction())

    option_pairs = (
        ("python.fn-execution.bundle.size", "2"),
        # A cache this small forces eviction while a bundle is in flight.
        ("python.state.cache-size", "2"),
    )
    for option_key, option_value in option_pairs:
        self.t_env.get_config().get_configuration().set_string(
            option_key, option_value)

    rows = [(1, 'Hi', 'Hello'), (3, 'Hi', 'hi'), (3, 'Hi2', 'hi'),
            (3, 'Hi', 'hi'), (2, 'Hi', 'Hello'), (1, 'Hi2', 'Hello'),
            (3, 'Hi3', 'hi'), (3, 'Hi2', 'Hello'), (3, 'Hi3', 'hi'),
            (2, 'Hi3', 'Hello')]
    source = self.t_env.from_elements(rows, ['a', 'b', 'c'])

    grouped = source.group_by(source.c)
    concatenated = grouped.flat_aggregate(
        concat_fn(source.b, ',').alias("b"))
    result = concatenated.select(source.b, source.c).alias("a, c")

    # Sort by the grouping column so the comparison is order-independent.
    actual = result.to_pandas().sort_values('c').reset_index(drop=True)
    expected = pd.DataFrame(
        [["Hi,Hi,Hi2,Hi2,Hi3", "Hello"],
         ["Hi,Hi,Hi2,Hi2,Hi3", "Hello"],
         ["Hi,Hi2,Hi,Hi3,Hi3", "hi"],
         ["Hi,Hi2,Hi,Hi3,Hi3", "hi"]],
        columns=['a', 'c'])
    assert_frame_equal(actual, expected)
def test_flat_aggregate(self):
    """Chain two ``flat_aggregate`` calls that use the same table aggregate.

    ``Top2`` is defined elsewhere. Judging by its name and the expected
    output ``[7, 5]``, it emits the two largest input values. The first pass
    runs it per group of column ``c``; the second pass runs it over the whole
    intermediate result.
    """
    import pandas as pd

    top2 = udtaf(Top2())
    source = self.t_env.from_elements(
        [(1, 'Hi', 'Hello'), (3, 'Hi', 'hi'), (5, 'Hi2', 'hi'),
         (7, 'Hi', 'Hello'), (2, 'Hi', 'Hello')],
        ['a', 'b', 'c'])

    # First pass: aggregate column ``a`` within each ``c`` group.
    per_group = source.select(source.a, source.c) \
        .group_by(source.c) \
        .flat_aggregate(top2.alias('a')) \
        .select(source.a)

    # Second pass: aggregate the per-group winners without grouping.
    overall = per_group.flat_aggregate(top2.alias("b")).select("b")

    actual = overall.to_pandas()
    assert_frame_equal(actual, pd.DataFrame([[7], [5]], columns=['b']))
def row_operations():
    """Demonstrate the row-based Table API operations on a small JSON table.

    Builds an in-memory streaming table of ``(id, data)`` rows, where
    ``data`` is a JSON string. It then runs four pipelines, executing each
    immediately and printing its result:

    * ``map``: a scalar UDF returning ``ROW(id, country)``.
    * ``flat_map``: a table function yielding ``(id, city)``.
    * ``aggregate``: a UDAF computing ``(cnt, sum of tel)`` per country.
    * ``flat_aggregate``: a table aggregate emitting the two largest ``tel``
      values per country.
    """
    t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

    # define the source
    table = t_env.from_elements(
        elements=[
            (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
            (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
            (3, '{"name": "world", "tel": 124, "addr": {"country": "China", "city": "NewYork"}}'),
            (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
        ],
        schema=['id', 'data'])

    # map operation
    @udf(result_type=DataTypes.ROW([
        DataTypes.FIELD("id", DataTypes.BIGINT()),
        DataTypes.FIELD("country", DataTypes.STRING())
    ]))
    def extract_country(input_row: Row):
        data = json.loads(input_row.data)
        return Row(input_row.id, data['addr']['country'])

    table.map(extract_country) \
        .execute().print()
    # +----+----------------------+--------------------------------+
    # | op |                   id |                        country |
    # +----+----------------------+--------------------------------+
    # | +I |                    1 |                        Germany |
    # | +I |                    2 |                          China |
    # | +I |                    3 |                          China |
    # | +I |                    4 |                          China |
    # +----+----------------------+--------------------------------+

    # flat_map operation
    @udtf(result_types=[DataTypes.BIGINT(), DataTypes.STRING()])
    def extract_city(input_row: Row):
        data = json.loads(input_row.data)
        yield input_row.id, data['addr']['city']

    table.flat_map(extract_city) \
        .execute().print()
    # +----+----------------------+--------------------------------+
    # | op |                   f0 |                             f1 |
    # +----+----------------------+--------------------------------+
    # | +I |                    1 |                         Berlin |
    # | +I |                    2 |                       Shanghai |
    # | +I |                    3 |                        NewYork |
    # | +I |                    4 |                       Hangzhou |
    # +----+----------------------+--------------------------------+

    # Both the aggregate and flat_aggregate pipelines below need the JSON
    # fields as columns. Build that projection once and share it.
    flattened = table.add_columns(
        col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
        col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
        col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country'))

    # aggregate operation
    class CountAndSumAggregateFunction(AggregateFunction):
        """Accumulate ``Row(count, sum of tel)``. Supports retract and merge."""

        def get_value(self, accumulator):
            return Row(accumulator[0], accumulator[1])

        def create_accumulator(self):
            return Row(0, 0)

        def accumulate(self, accumulator, input_row):
            accumulator[0] += 1
            # ``tel`` was extracted as a STRING by json_value, so convert it.
            accumulator[1] += int(input_row.tel)

        def retract(self, accumulator, input_row):
            accumulator[0] -= 1
            accumulator[1] -= int(input_row.tel)

        def merge(self, accumulator, accumulators):
            for other_acc in accumulators:
                accumulator[0] += other_acc[0]
                accumulator[1] += other_acc[1]

        def get_accumulator_type(self):
            return DataTypes.ROW([
                DataTypes.FIELD("cnt", DataTypes.BIGINT()),
                DataTypes.FIELD("sum", DataTypes.BIGINT())
            ])

        def get_result_type(self):
            return DataTypes.ROW([
                DataTypes.FIELD("cnt", DataTypes.BIGINT()),
                DataTypes.FIELD("sum", DataTypes.BIGINT())
            ])

    count_sum = udaf(CountAndSumAggregateFunction())
    flattened \
        .group_by(col('country')) \
        .aggregate(count_sum.alias("cnt", "sum")) \
        .select(col('country'), col('cnt'), col('sum')) \
        .execute().print()
    # +----+--------------------------------+----------------------+----------------------+
    # | op |                        country |                  cnt |                  sum |
    # +----+--------------------------------+----------------------+----------------------+
    # | +I |                          China |                    3 |                  291 |
    # | +I |                        Germany |                    1 |                  123 |
    # +----+--------------------------------+----------------------+----------------------+

    # flat_aggregate operation
    class Top2(TableAggregateFunction):
        """Keep the two largest ``tel`` values; unused slots stay ``None``."""

        def emit_value(self, accumulator):
            for v in accumulator:
                # Test against None, not truthiness: a tel of 0 is a real
                # value and must still be emitted.
                if v is not None:
                    yield Row(v)

        def create_accumulator(self):
            # [largest, second largest]; None marks an empty slot.
            return [None, None]

        def accumulate(self, accumulator, input_row):
            tel = int(input_row.tel)
            if accumulator[0] is None or tel > accumulator[0]:
                # New maximum: demote the previous maximum to second place.
                accumulator[1] = accumulator[0]
                accumulator[0] = tel
            elif accumulator[1] is None or tel > accumulator[1]:
                accumulator[1] = tel

        def get_accumulator_type(self):
            return DataTypes.ARRAY(DataTypes.BIGINT())

        def get_result_type(self):
            return DataTypes.ROW([DataTypes.FIELD("tel", DataTypes.BIGINT())])

    top2 = udtaf(Top2())
    flattened \
        .group_by(col('country')) \
        .flat_aggregate(top2) \
        .select(col('country'), col('tel')) \
        .execute().print()