def test_to_retract_stream(self):
    """A grouped SUM converted to a retract stream emits add/retract message pairs."""
    self.env.set_parallelism(1)
    t_env = StreamTableEnvironment.create(
        self.env, environment_settings=EnvironmentSettings.in_streaming_mode())
    source_table = t_env.from_elements(
        [(1, "Hi", "Hello"), (1, "Hi", "Hello")], ["a", "b", "c"])
    summed = source_table.group_by("c").select("a.sum, c as b")
    retract_stream = t_env.to_retract_stream(
        table=summed, type_info=Types.ROW([Types.LONG(), Types.STRING()]))

    sink = DataStreamTestSinkFunction()
    retract_stream.map(lambda record: record).add_sink(sink)
    self.env.execute("test_to_retract_stream")

    # The second "Hello" row retracts the partial sum (1) and re-emits the new sum (2).
    self.assertEqual(sink.get_results(True), [
        "(True, Row(f0=1, f1='Hello'))",
        "(False, Row(f0=1, f1='Hello'))",
        "(True, Row(f0=2, f1='Hello'))",
    ])
def test_add_classpaths(self):
    """Jars registered via add_classpaths must be visible when the plan is built."""
    # Locate the Kafka SQL connector jars built inside the Flink source tree.
    connector_dir = _find_flink_source_root() + '/flink-connectors/flink-sql-connector-kafka'
    jar_urls = ['file://' + jar for jar in glob.glob(connector_dir + '/target/flink*.jar')]
    self.env.add_classpaths(*jar_urls)

    consumer_props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
    row_type = Types.ROW([Types.INT(), Types.STRING()])

    json_schema = JsonRowDeserializationSchema.builder() \
        .type_info(type_info=row_type).build()
    # Constructing the consumer raises ClassNotFoundException unless the Kafka
    # connector jars were added to the pipeline classpaths above.
    consumer = FlinkKafkaConsumer('test_source_topic', json_schema, consumer_props)
    self.env.add_source(consumer).print()
    self.env.get_execution_plan()
def test_map_function_without_data_types(self):
    """Apply MyMapFunction without an explicit output type and collect results as strings."""
    self.env.set_parallelism(1)
    source = self.env.from_collection(
        [('ab', decimal.Decimal(1)),
         ('bdc', decimal.Decimal(2)),
         ('cfgs', decimal.Decimal(3)),
         ('deeefg', decimal.Decimal(4))],
        type_info=Types.ROW([Types.STRING(), Types.BIG_DEC()]))
    mapped = source.map(MyMapFunction())
    collector = DataStreamCollectUtil()
    collector.collect(mapped)
    self.env.execute('map_function_test')

    expected = [
        "('ab', 2, Decimal('1'))",
        "('bdc', 3, Decimal('2'))",
        "('cfgs', 4, Decimal('3'))",
        "('deeefg', 6, Decimal('4'))",
    ]
    # Compare independently of emission order.
    self.assertEqual(sorted(expected), sorted(collector.results()))
def test_compile(self):
    """A KafkaSink compiles into separate Writer and Committer plan nodes."""
    builder = KafkaSink.builder().set_bootstrap_servers('localhost:9092')
    sink = builder.set_record_serializer(self._build_serialization_schema()).build()

    empty_stream = self.env.from_collection([], type_info=Types.STRING())
    empty_stream.sink_to(sink)

    nodes = json.loads(self.env.get_execution_plan())['nodes']
    self.assertEqual(nodes[1]['type'], 'Sink: Writer')
    self.assertEqual(nodes[2]['type'], 'Sink: Committer')
def test_partition_custom(self):
    """partition_custom must invoke the partitioner with the downstream partition count."""
    ds = self.env.from_collection(
        [('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2),
         ('f', 7), ('g', 7), ('h', 8), ('i', 8), ('j', 9)],
        type_info=Types.ROW([Types.STRING(), Types.INT()]))

    expected_num_partitions = 5

    def my_partitioner(key, num_partitions):
        # ``assert a, b`` only checks that ``a`` is truthy (``b`` is just the message),
        # so the previous form never compared the two values.
        assert expected_num_partitions == num_partitions, num_partitions
        return key % num_partitions

    partitioned_stream = ds.map(lambda x: x, output_type=Types.ROW([Types.STRING(), Types.INT()]))\
        .set_parallelism(4).partition_custom(my_partitioner, lambda x: x[1])

    # NOTE(review): the Java PartitionCustomTestMapFunction presumably validates the
    # partition each record arrives in; TODO confirm against its implementation.
    JPartitionCustomTestMapFunction = get_gateway().jvm\
        .org.apache.flink.python.util.PartitionCustomTestMapFunction
    test_map_stream = DataStream(
        partitioned_stream._j_data_stream.map(JPartitionCustomTestMapFunction()))
    test_map_stream.set_parallelism(expected_num_partitions).add_sink(self.test_sink)

    self.env.execute('test_partition_custom')
def test_json_row_serialization_deserialization_schema(self):
    """Round-trip JSON through the Java JSON row deserializer and serializer.

    Fields absent from an input document are written back as explicit ``null``.
    """
    jsons = [
        "{\"svt\":\"2020-02-24T12:58:09.209+0800\"}",
        "{\"svt\":\"2020-02-24T12:58:09.209+0800\", "
        "\"ops\":{\"id\":\"281708d0-4092-4c21-9233-931950b6eccf\"},\"ids\":[1, 2, 3]}",
        "{\"svt\":\"2020-02-24T12:58:09.209+0800\"}",
    ]
    expected_jsons = [
        "{\"svt\":\"2020-02-24T12:58:09.209+0800\",\"ops\":null,\"ids\":null}",
        "{\"svt\":\"2020-02-24T12:58:09.209+0800\","
        "\"ops\":{\"id\":\"281708d0-4092-4c21-9233-931950b6eccf\"},"
        "\"ids\":[1,2,3]}",
        "{\"svt\":\"2020-02-24T12:58:09.209+0800\",\"ops\":null,\"ids\":null}",
    ]

    row_schema = Types.ROW_NAMED(["svt", "ops", "ids"], [
        Types.STRING(),
        Types.ROW_NAMED(['id'], [Types.STRING()]),
        Types.PRIMITIVE_ARRAY(Types.INT()),
    ])
    serializer = JsonRowSerializationSchema.builder().with_type_info(row_schema).build()
    deserializer = JsonRowDeserializationSchema.builder().type_info(row_schema).build()

    for source_json, expected_json in zip(jsons, expected_jsons):
        j_row = deserializer._j_deserialization_schema.deserialize(
            bytes(source_json, encoding='utf-8'))
        round_tripped = str(serializer._j_serialization_schema.serialize(j_row),
                            encoding='utf-8')
        self.assertEqual(expected_json, round_tripped)
def test_set_value_only_deserializer(self):
    """set_value_only_deserializer wraps each supported schema in the value-only wrapper."""

    def _assert_wrapped(value_schema: DeserializationSchema, expected_class: str):
        # Build a source around the schema, then inspect the Java-side wrapper and
        # the schema it wraps.
        kafka_source = KafkaSource.builder() \
            .set_bootstrap_servers('localhost:9092') \
            .set_topics('test_topic') \
            .set_value_only_deserializer(value_schema) \
            .build()
        wrapper = get_field_value(kafka_source.get_java_function(), 'deserializationSchema')
        self.assertEqual(
            wrapper.getClass().getCanonicalName(),
            'org.apache.flink.connector.kafka.source.reader.deserializer'
            '.KafkaValueOnlyDeserializationSchemaWrapper')
        wrapped = get_field_value(wrapper, 'deserializationSchema')
        self.assertEqual(wrapped.getClass().getCanonicalName(), expected_class)

    _assert_wrapped(SimpleStringSchema(),
                    'org.apache.flink.api.common.serialization.SimpleStringSchema')

    json_schema = JsonRowDeserializationSchema.builder() \
        .type_info(Types.ROW([Types.STRING()])).build()
    _assert_wrapped(json_schema, 'org.apache.flink.formats.json.JsonRowDeserializationSchema')

    csv_schema = CsvRowDeserializationSchema.Builder(Types.ROW([Types.STRING()])).build()
    _assert_wrapped(csv_schema, 'org.apache.flink.formats.csv.CsvRowDeserializationSchema')

    avro_schema_string = """ { "type": "record", "name": "test_record", "fields": [] } """
    avro_schema = AvroRowDeserializationSchema(avro_schema_string=avro_schema_string)
    _assert_wrapped(avro_schema, 'org.apache.flink.formats.avro.AvroRowDeserializationSchema')
def key_by(self, key_selector: Union[Callable, KeySelector],
           key_type_info: TypeInformation = None) -> 'KeyedStream':
    """
    Creates a new KeyedStream that uses the provided key for partitioning its operator states.

    :param key_selector: The KeySelector to be used for extracting the key for partitioning.
    :param key_type_info: The type information describing the key type.
    :return: The DataStream with partitioned state(i.e. KeyedStream).
    :raises TypeError: if key_selector is neither callable nor a KeySelector.
    """
    # Plain callables are adapted to the KeySelector interface.
    if callable(key_selector):
        key_selector = KeySelectorFunctionWrapper(key_selector)
    if not isinstance(key_selector, (KeySelector, KeySelectorFunctionWrapper)):
        raise TypeError("Parameter key_selector should be a type of KeySelector.")

    gateway = get_gateway()
    pickled_key_selector_cls = gateway.jvm \
        .org.apache.flink.datastream.runtime.functions.python.PickledKeySelector
    input_type_info = typeinfo._from_java_type(
        self._j_data_stream.getTransformation().getOutputType())

    # Without an explicit key type, the key is typed as a pickled byte array.
    key_is_pickled = key_type_info is None
    if key_is_pickled:
        key_type_info = Types.PICKLED_BYTE_ARRAY()

    # Pair every record with its extracted key: (key, original_record).
    keyed_pairs = self.map(
        lambda value: (key_selector.get_key(value), value),
        type_info=Types.ROW([key_type_info, input_type_info]))
    keyed_pairs.name(
        gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
        .STREAM_KEY_BY_MAP_OPERATOR_NAME)

    keyed_stream = KeyedStream(
        keyed_pairs._j_data_stream.keyBy(
            pickled_key_selector_cls(key_is_pickled),
            key_type_info.get_java_type_info()),
        self)
    keyed_stream._original_data_type_info = input_type_info
    return keyed_stream
def test_window_aggregate_process(self):
    """Run a session-window aggregate whose result is post-processed by a ProcessWindowFunction."""
    data_stream = self.env.from_collection(
        [('a', 1), ('a', 2), ('b', 3), ('a', 6), ('b', 8), ('b', 9), ('a', 15)],
        type_info=Types.TUPLE([Types.STRING(), Types.INT()]))  # type: DataStream
    watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
        .with_timestamp_assigner(SecondColumnTimestampAssigner())

    class MyAggregateFunction(AggregateFunction):
        """Sums the second field. The accumulator is ``(running_sum, key)``."""

        def create_accumulator(self) -> Tuple[int, str]:
            return 0, ''

        def add(self, value: Tuple[str, int], accumulator: Tuple[int, str]) -> Tuple[int, str]:
            return value[1] + accumulator[0], value[0]

        def get_result(self, accumulator: Tuple[int, str]) -> Tuple[str, int]:
            # The accumulator is (sum, key); emit (key, sum). The old annotation
            # Tuple[str, int] contradicted create_accumulator/add/merge.
            return accumulator[1], accumulator[0]

        def merge(self, acc_a: Tuple[int, str], acc_b: Tuple[int, str]) -> Tuple[int, str]:
            return acc_a[0] + acc_b[0], acc_a[1]

    class MyProcessWindowFunction(ProcessWindowFunction):

        def process(self, key: str, context: ProcessWindowFunction.Context,
                    elements: Iterable[Tuple[str, int]]) -> Iterable[str]:
            # Only the first element (the aggregate result) is used.
            agg_result = next(iter(elements))
            yield "key {} timestamp sum {}".format(agg_result[0], agg_result[1])

        def clear(self, context: ProcessWindowFunction.Context) -> None:
            pass

    data_stream.assign_timestamps_and_watermarks(watermark_strategy) \
        .key_by(lambda x: x[0], key_type=Types.STRING()) \
        .window(EventTimeSessionWindows.with_gap(Time.milliseconds(2))) \
        .aggregate(MyAggregateFunction(),
                   window_function=MyProcessWindowFunction(),
                   accumulator_type=Types.TUPLE([Types.INT(), Types.STRING()]),
                   output_type=Types.STRING()) \
        .add_sink(self.test_sink)
    self.env.execute('test_time_window_aggregate_accumulator_type')
    results = self.test_sink.get_results()
    expected = [
        'key a timestamp sum 15',
        'key a timestamp sum 3',
        'key a timestamp sum 6',
        'key b timestamp sum 17',
        'key b timestamp sum 3',
    ]
    self.assert_equals_sorted(expected, results)
def test_event_time_tumbling_window(self):
    """Count elements per 5 ms tumbling event-time window for a single key.

    Expected outputs read as (key, window start, window end, count).
    """
    events = [('hi', 1), ('hi', 2), ('hi', 3), ('hi', 4),
              ('hi', 5), ('hi', 8), ('hi', 9), ('hi', 15)]
    source = self.env.from_collection(
        events, type_info=Types.TUPLE([Types.STRING(), Types.INT()]))  # type: DataStream
    timestamps = WatermarkStrategy.for_monotonous_timestamps() \
        .with_timestamp_assigner(SecondColumnTimestampAssigner())

    windowed = source.assign_timestamps_and_watermarks(timestamps) \
        .key_by(lambda event: event[0], key_type=Types.STRING()) \
        .window(TumblingEventTimeWindows.of(Time.milliseconds(5)))
    counted = windowed.process(
        CountWindowProcessFunction(),
        Types.TUPLE([Types.STRING(), Types.LONG(), Types.LONG(), Types.INT()]))
    counted.add_sink(self.test_sink)
    self.env.execute('test_event_time_tumbling_window')

    self.assert_equals_sorted(['(hi,0,5,4)', '(hi,5,10,3)', '(hi,15,20,1)'],
                              self.test_sink.get_results())
def test_session_window_late_merge(self):
    """A late element bridging two sessions merges them into one window."""
    source = self.env.from_collection(
        [('hi', 0), ('hi', 8), ('hi', 4)],
        type_info=Types.TUPLE([Types.STRING(), Types.INT()]))  # type: DataStream
    timestamps = WatermarkStrategy.for_monotonous_timestamps() \
        .with_timestamp_assigner(SecondColumnTimestampAssigner())

    # With a 5 ms gap, 0 and 8 start separate sessions; 4 overlaps both.
    sessions = source.assign_timestamps_and_watermarks(timestamps) \
        .key_by(lambda event: event[0], key_type=Types.STRING()) \
        .window(EventTimeSessionWindows.with_gap(Time.milliseconds(5)))
    counted = sessions.process(
        CountWindowProcessFunction(),
        Types.TUPLE([Types.STRING(), Types.LONG(), Types.LONG(), Types.INT()]))
    counted.add_sink(self.test_sink)
    self.env.execute('test_session_window_late_merge')

    self.assert_equals_sorted(['(hi,0,13,3)'], self.test_sink.get_results())
def _check_record(data, topic, serialized_data):
    """Send ``Row(data)`` through a KafkaSink's record serializer and check the result.

    :param data: value for the single STRING column of the input row.
    :param topic: topic the serialized record is expected to target (chosen by ``_select``).
    :param serialized_data: expected value of the produced record's ``value()``.
    """
    row_type = Types.ROW([Types.STRING()])
    record_builder = KafkaRecordSerializationSchema.builder().set_topic_selector(_select)
    record_builder = record_builder.set_value_serialization_schema(
        JsonRowSerializationSchema.builder().with_type_info(row_type).build())
    serialization_schema = record_builder.build()

    sink = KafkaSink.builder() \
        .set_bootstrap_servers('localhost:9092') \
        .set_record_serializer(serialization_schema) \
        .build()

    mock_stream = MockDataStream(Types.ROW([Types.STRING()]))
    mock_stream.sink_to(sink)
    fed_row = mock_stream.feed(Row(data))  # type: Row

    j_record = serialization_schema._j_serialization_schema.serialize(
        to_java_data_structure(fed_row), None, None)
    self.assertEqual(j_record.topic(), topic)
    self.assertIsNone(j_record.key())
    self.assertEqual(j_record.value(), serialized_data)
def popular_destination_query():
    """Print destinations whose ride count in a 15-minute hop window (5-minute slide)
    exceeds ``args.threshold``.

    Relies on the module-level ``args`` and ``create_table_ddl`` defined elsewhere.
    """
    env = StreamExecutionEnvironment.get_execution_environment()
    t_env = StreamTableEnvironment.create(stream_execution_environment=env)

    watermark_clause = "WATERMARK FOR pickupTime AS pickupTime - INTERVAL '30' SECONDS"
    t_env.execute_sql(create_table_ddl(watermark_clause))

    query = f"""SELECT destLocationId, wstart, wend, cnt FROM (SELECT destLocationId, HOP_START(pickupTime, INTERVAL '5' MINUTE, INTERVAL '15' MINUTE) AS wstart, HOP_END(pickupTime, INTERVAL '5' MINUTE, INTERVAL '15' MINUTE) AS wend, COUNT(destLocationId) AS cnt FROM (SELECT pickupTime, destLocationId FROM TaxiRide) GROUP BY destLocationId, HOP(pickupTime, INTERVAL '5' MINUTE, INTERVAL '15' MINUTE) ) WHERE cnt > {args.threshold} """
    popular = t_env.sql_query(query)

    output_type = Types.ROW_NAMED(
        ['destLocationId', 'wstart', 'wend', 'cnt'],
        [Types.INT(), Types.SQL_TIMESTAMP(), Types.SQL_TIMESTAMP(), Types.LONG()])
    t_env.to_append_stream(popular, output_type).print()

    env.execute('Popular-Destination')
def test_from_data_stream_with_schema(self):
    """Convert a named-ROW DataStream to a Table using an explicit Schema and read it back."""
    from pyflink.table import Schema

    ds = self.env.from_collection(
        [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
        type_info=Types.ROW_NAMED(
            ["a", "b", "c"], [Types.INT(), Types.STRING(), Types.STRING()]))
    schema = Schema.new_builder() \
        .column("a", DataTypes.INT()) \
        .column("b", DataTypes.STRING()) \
        .column("c", DataTypes.STRING()) \
        .build()
    table = self.t_env.from_data_stream(ds, schema)

    # Give the TableResult and the iterator from collect() distinct names; the old code
    # rebound ``result`` inside the ``with`` statement, shadowing the TableResult.
    table_result = table.execute()
    with table_result.collect() as rows:
        collected_result = sorted(str(row) for row in rows)

    expected_result = sorted(str(row) for row in [Row(1, 'Hi', 'Hello'), Row(2, 'Hello', 'Hi')])
    self.assertEqual(expected_result, collected_result)
def test_csv_row_serialization_schema(self):
    """Round-trip rows through CSV (de)serialization with custom quote, escape and delimiters."""
    JRow = get_gateway().jvm.org.apache.flink.types.Row

    # Shared Java row: fields 0 and 2 are fixed; field 1 is replaced per assertion.
    j_row = JRow(3)
    j_row.setField(0, "BEGIN")
    j_row.setField(2, "END")

    def field_assertion(field_info, csv_value, value, field_delimiter):
        """Serialize ``value`` as the middle field, compare to ``csv_value``, then parse it back."""
        row_info = Types.ROW([Types.STRING(), field_info, Types.STRING()])
        expected_csv = "BEGIN" + field_delimiter + csv_value + field_delimiter + "END\n"
        j_row.setField(1, value)

        # Configure both schemas with the caller's field_delimiter (previously hard-coded
        # to ';') so they always agree with expected_csv.
        csv_row_serialization_schema = CsvRowSerializationSchema.Builder(row_info)\
            .set_escape_character('*').set_quote_character('\'')\
            .set_array_element_delimiter(':').set_field_delimiter(field_delimiter).build()
        csv_row_deserialization_schema = CsvRowDeserializationSchema.Builder(row_info)\
            .set_escape_character('*').set_quote_character('\'')\
            .set_array_element_delimiter(':').set_field_delimiter(field_delimiter).build()

        serialized_bytes = csv_row_serialization_schema._j_serialization_schema.serialize(j_row)
        self.assertEqual(expected_csv, str(serialized_bytes, encoding='utf-8'))

        j_deserialized_row = csv_row_deserialization_schema._j_deserialization_schema\
            .deserialize(expected_csv.encode("utf-8"))
        self.assertTrue(j_row.equals(j_deserialized_row))

    field_assertion(Types.STRING(), "'123''4**'", "123'4*", ";")
    field_assertion(Types.STRING(), "'a;b''c'", "a;b'c", ";")
    field_assertion(Types.INT(), "12", 12, ";")

    test_j_row = JRow(2)
    test_j_row.setField(0, "1")
    test_j_row.setField(1, "hello")
    field_assertion(Types.ROW([Types.STRING(), Types.STRING()]), "'1:hello'", test_j_row, ";")
    test_j_row.setField(1, "hello world")
    field_assertion(Types.ROW([Types.STRING(), Types.STRING()]), "'1:hello world'", test_j_row,
                    ";")
    field_assertion(Types.STRING(), "null", "null", ";")
def test_primitive_array_type_info(self):
    """PRIMITIVE_ARRAY(FLOAT) fields survive an identity map unchanged."""
    source = self.env.from_collection(
        [(1, [1.1, 1.2, 1.30]), (2, [2.1, 2.2, 2.3]), (3, [3.1, 3.2, 3.3])],
        type_info=Types.ROW([Types.INT(), Types.PRIMITIVE_ARRAY(Types.FLOAT())]))
    identity = source.map(
        lambda row: row,
        output_type=Types.ROW([Types.INT(), Types.PRIMITIVE_ARRAY(Types.FLOAT())]))
    identity.add_sink(self.test_sink)
    self.env.execute("test primitive array type info")

    actual = self.test_sink.get_results()
    expected = ['1,[1.1, 1.2, 1.3]', '2,[2.1, 2.2, 2.3]', '3,[3.1, 3.2, 3.3]']
    self.assertEqual(sorted(expected), sorted(actual))
def test_kinesis_firehose_sink(self):
    """KinesisFirehoseSink builder settings reach the Java sink and the plan node name."""
    import json

    sink_properties = {
        'aws.region': 'eu-west-1',
        'aws.credentials.provider.basic.accesskeyid': 'aws_access_key_id',
        'aws.credentials.provider.basic.secretkey': 'aws_secret_access_key'
    }
    ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                  type_info=Types.ROW([Types.STRING(), Types.INT()]))

    kinesis_firehose_sink = KinesisFirehoseSink.builder() \
        .set_firehose_client_properties(sink_properties) \
        .set_serialization_schema(SimpleStringSchema()) \
        .set_delivery_stream_name('stream-1') \
        .set_fail_on_error(False) \
        .set_max_batch_size(500) \
        .set_max_in_flight_requests(50) \
        .set_max_buffered_requests(10000) \
        .set_max_batch_size_in_bytes(5 * 1024 * 1024) \
        .set_max_time_in_buffer_ms(5000) \
        .set_max_record_size_in_bytes(1 * 1024 * 1024) \
        .build()

    ds.sink_to(kinesis_firehose_sink).name('kinesis firehose sink')
    # The execution plan is a JSON document: parse it with json.loads rather than
    # eval(), which executes arbitrary code and cannot parse JSON true/false/null.
    plan = json.loads(self.env.get_execution_plan())
    self.assertEqual('kinesis firehose sink: Writer', plan['nodes'][1]['type'])
    self.assertEqual(
        get_field_value(kinesis_firehose_sink.get_java_function(), 'failOnError'), False)
    self.assertEqual(
        get_field_value(kinesis_firehose_sink.get_java_function(), 'deliveryStreamName'),
        'stream-1')
def test_event_time_session_window_with_purging_trigger(self):
    """Event-time session windows (3 ms gap) fired through a PurgingTrigger-wrapped EventTimeTrigger."""
    events = [('hi', 1), ('hi', 2), ('hi', 3), ('hi', 4), ('hi', 8), ('hi', 9), ('hi', 15)]
    source = self.env.from_collection(
        events, type_info=Types.TUPLE([Types.STRING(), Types.INT()]))  # type: DataStream
    timestamps = WatermarkStrategy.for_monotonous_timestamps() \
        .with_timestamp_assigner(SecondColumnTimestampAssigner())

    triggered = source.assign_timestamps_and_watermarks(timestamps) \
        .key_by(lambda event: event[0], key_type=Types.STRING()) \
        .window(EventTimeSessionWindows.with_gap(Time.milliseconds(3))) \
        .trigger(PurgingTrigger.of(EventTimeTrigger.create()))
    counted = triggered.process(
        CountWindowProcessFunction(),
        Types.TUPLE([Types.STRING(), Types.LONG(), Types.LONG(), Types.INT()]))
    counted.add_sink(self.test_sink)
    self.env.execute('test_event_time_session_window_with_purging_trigger')

    self.assert_equals_sorted(['(hi,1,7,4)', '(hi,8,12,2)', '(hi,15,18,1)'],
                              self.test_sink.get_results())
def test_map_function_with_data_types_and_function_object(self):
    """Apply MyMapFunction with an explicit three-column ROW output type."""
    source = self.env.from_collection(
        [('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
        type_info=Types.ROW([Types.STRING(), Types.INT()]))
    mapped = source.map(MyMapFunction(),
                        output_type=Types.ROW([Types.STRING(), Types.INT(), Types.INT()]))
    mapped.add_sink(self.test_sink)
    self.env.execute('map_function_test')

    actual = self.test_sink.get_results(False)
    self.assertEqual(sorted(['ab,2,1', 'bdc,3,2', 'cfgs,4,3', 'deeefg,6,4']), sorted(actual))
def test_event_time_dynamic_gap_session_window(self):
    """Event-time session windows whose gap comes from MySessionWindowTimeGapExtractor."""
    self.env.set_parallelism(1)
    events = [('hi', 1), ('hi', 2), ('hi', 3), ('hi', 4), ('hi', 9), ('hi', 9), ('hi', 15)]
    source = self.env.from_collection(
        events, type_info=Types.TUPLE([Types.STRING(), Types.INT()]))  # type: DataStream
    timestamps = WatermarkStrategy.for_monotonous_timestamps() \
        .with_timestamp_assigner(SecondColumnTimestampAssigner())

    sessions = source.assign_timestamps_and_watermarks(timestamps) \
        .key_by(lambda event: event[0], key_type=Types.STRING()) \
        .window(EventTimeSessionWindows.with_dynamic_gap(MySessionWindowTimeGapExtractor()))
    counted = sessions.process(
        CountWindowProcessFunction(),
        Types.TUPLE([Types.STRING(), Types.LONG(), Types.LONG(), Types.INT()]))
    counted.add_sink(self.test_sink)
    self.env.execute('test_event_time_dynamic_gap_session_window')

    self.assert_equals_sorted(['(hi,1,8,4)', '(hi,9,30,3)'], self.test_sink.get_results())
def test_window_reduce_process(self):
    """Session-window reduce followed by a ProcessWindowFunction reporting the window start."""
    source = self.env.from_collection(
        [('a', 1), ('a', 2), ('b', 3), ('a', 6), ('b', 8), ('b', 9), ('a', 15)],
        type_info=Types.TUPLE([Types.STRING(), Types.INT()]))  # type: DataStream
    timestamps = WatermarkStrategy.for_monotonous_timestamps() \
        .with_timestamp_assigner(SecondColumnTimestampAssigner())

    class MyProcessFunction(ProcessWindowFunction):

        def clear(self, context: ProcessWindowFunction.Context) -> None:
            pass

        def process(self, key, context: ProcessWindowFunction.Context,
                    elements: Iterable[Tuple[str, int]]) -> Iterable[str]:
            window_start = context.window().start
            reduced = next(iter(elements))
            yield "current window start at {}, reduce result {}".format(window_start, reduced)

    sessions = source.assign_timestamps_and_watermarks(timestamps) \
        .key_by(lambda event: event[0], key_type=Types.STRING()) \
        .window(EventTimeSessionWindows.with_gap(Time.milliseconds(2)))
    # Keep the newer key and sum the counts.
    reduced_stream = sessions.reduce(lambda acc, cur: (cur[0], acc[1] + cur[1]),
                                     window_function=MyProcessFunction(),
                                     output_type=Types.STRING())
    reduced_stream.add_sink(self.test_sink)
    self.env.execute('test_time_window_reduce_process')

    self.assert_equals_sorted([
        "current window start at 1, reduce result ('a', 3)",
        "current window start at 15, reduce result ('a', 15)",
        "current window start at 3, reduce result ('b', 3)",
        "current window start at 6, reduce result ('a', 6)",
        "current window start at 8, reduce result ('b', 17)",
    ], self.test_sink.get_results())
def _create_parquet_array_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
    """Build a schema with a string-array and an int-array column, plus one sample row.

    :return: ``(row_type, row_type_info, data)``: the Table API ROW type (arrays bridged
        to ``java.util.ArrayList``), the matching DataStream ROW_NAMED type info, and a
        single-element list of sample rows.
    """

    def _array_field(name, element_type):
        # Every array column is bridged to java.util.ArrayList.
        return DataTypes.FIELD(
            name, DataTypes.ARRAY(element_type).bridged_to('java.util.ArrayList'))

    row_type = DataTypes.ROW([
        _array_field('string_array', DataTypes.STRING()),
        _array_field('int_array', DataTypes.INT()),
    ])
    row_type_info = Types.ROW_NAMED(
        ['string_array', 'int_array'],
        [Types.LIST(Types.STRING()), Types.LIST(Types.INT())])
    sample_rows = [Row(string_array=['a', 'b', 'c'], int_array=[1, 2, 3])]
    return row_type, row_type_info, sample_rows
def test_rabbitmq_connectors(self):
    """RMQSource and RMQSink forward queue settings to their Java functions."""
    connection_config = RMQConnectionConfig.Builder() \
        .set_host('localhost') \
        .set_port(5672) \
        .set_virtual_host('/') \
        .set_user_name('guest') \
        .set_password('guest') \
        .build()
    row_type = Types.ROW([Types.INT(), Types.STRING()])

    json_deserializer = JsonRowDeserializationSchema.builder() \
        .type_info(type_info=row_type).build()
    source = RMQSource(connection_config, 'source_queue', True, json_deserializer)
    self.assertEqual(get_field_value(source.get_java_function(), 'queueName'),
                     'source_queue')
    self.assertTrue(get_field_value(source.get_java_function(), 'usesCorrelationId'))

    json_serializer = JsonRowSerializationSchema.builder().with_type_info(row_type).build()
    sink = RMQSink(connection_config, 'sink_queue', json_serializer)
    self.assertEqual(get_field_value(sink.get_java_function(), 'queueName'), 'sink_queue')
def test_sql_timestamp_type_info(self):
    """SQL DATE/TIME/TIMESTAMP values survive an identity map, keeping millisecond precision."""

    def temporal_row_type():
        # A fresh ROW(SQL_DATE, SQL_TIME, SQL_TIMESTAMP) type per call site.
        return Types.ROW([Types.SQL_DATE(), Types.SQL_TIME(), Types.SQL_TIMESTAMP()])

    record = (datetime.date(2021, 1, 9),
              datetime.time(12, 0, 0),
              datetime.datetime(2021, 1, 9, 12, 0, 0, 11000))
    source = self.env.from_collection([record], type_info=temporal_row_type())
    source.map(lambda row: row, output_type=temporal_row_type()).add_sink(self.test_sink)
    self.env.execute("test sql timestamp type info")

    self.assertEqual(['+I[2021-01-09, 12:00:00, 2021-01-09 12:00:00.011]'],
                     self.test_sink.get_results())
def _check_serialization_schema_implementations(check_function):
    """Call ``check_function(schema, java_class_name)`` for each supported serialization schema.

    :param check_function: callable taking a schema instance and the fully qualified
        Java class name it is expected to wrap.
    """
    row_type = Types.ROW([Types.STRING()])

    json_schema = JsonRowSerializationSchema.builder().with_type_info(row_type).build()
    check_function(json_schema, 'org.apache.flink.formats.json.JsonRowSerializationSchema')

    csv_schema = CsvRowSerializationSchema.Builder(row_type).build()
    check_function(csv_schema, 'org.apache.flink.formats.csv.CsvRowSerializationSchema')

    avro_schema_string = """ { "type": "record", "name": "test_record", "fields": [] } """
    avro_schema = AvroRowSerializationSchema(avro_schema_string=avro_schema_string)
    check_function(avro_schema, 'org.apache.flink.formats.avro.AvroRowSerializationSchema')

    check_function(SimpleStringSchema(),
                   'org.apache.flink.api.common.serialization.SimpleStringSchema')
def __init__(self, name: str, elem_type_info: TypeInformation):
    """
    Constructor of the ListStateDescriptor.

    :param name: The name of the state.
    :param elem_type_info: the type information of the state element; currently this
        must be a PickledBytesTypeInfo.
    :raises ValueError: if elem_type_info is not a PickledBytesTypeInfo.
    """
    if isinstance(elem_type_info, PickledBytesTypeInfo):
        # The state itself is described as a LIST of the element type.
        super(ListStateDescriptor, self).__init__(name, Types.LIST(elem_type_info))
        return
    raise ValueError(
        "The type information of the element could only be "
        "PickledBytesTypeInfo (created via Types.PICKLED_BYTE_ARRAY()) "
        "currently, got %s" % type(elem_type_info))
def test_window_aggregate_passthrough(self):
    """Session-window aggregate emitting (key, even_count - odd_count) per window."""
    source = self.env.from_collection(
        [('a', 1), ('a', 2), ('b', 3), ('a', 6), ('b', 8), ('b', 9), ('a', 15)],
        type_info=Types.TUPLE([Types.STRING(), Types.INT()]))  # type: DataStream
    timestamps = WatermarkStrategy.for_monotonous_timestamps() \
        .with_timestamp_assigner(SecondColumnTimestampAssigner())

    class MyAggregateFunction(AggregateFunction):
        """Accumulator is (key, {0: even_count, 1: odd_count})."""

        def create_accumulator(self) -> Tuple[str, Dict[int, int]]:
            return '', {0: 0, 1: 0}

        def add(self, value: Tuple[str, int],
                accumulator: Tuple[str, Dict[int, int]]) -> Tuple[str, Dict[int, int]]:
            parity_counts = accumulator[1]
            parity_counts[value[1] % 2] += 1  # updates the accumulator's dict in place
            return value[0], parity_counts

        def get_result(self, accumulator: Tuple[str, Dict[int, int]]) -> Tuple[str, int]:
            parity_counts = accumulator[1]
            return accumulator[0], parity_counts[0] - parity_counts[1]

        def merge(self, acc_a: Tuple[str, Dict[int, int]],
                  acc_b: Tuple[str, Dict[int, int]]) -> Tuple[str, Dict[int, int]]:
            merged_counts = {parity: acc_a[1][parity] + acc_b[1][parity] for parity in (0, 1)}
            return acc_a[0], merged_counts

    sessions = source.assign_timestamps_and_watermarks(timestamps) \
        .key_by(lambda event: event[0], key_type=Types.STRING()) \
        .window(EventTimeSessionWindows.with_gap(Time.milliseconds(2)))
    aggregated = sessions.aggregate(MyAggregateFunction(),
                                    output_type=Types.TUPLE([Types.STRING(), Types.INT()]))
    aggregated.add_sink(self.test_sink)
    self.env.execute('test_time_window_aggregate_passthrough')

    self.assert_equals_sorted(['(a,-1)', '(a,0)', '(a,1)', '(b,-1)', '(b,0)'],
                              self.test_sink.get_results())
def test_key_by_on_connect_stream(self):
    """key_by on connected streams routes same-key records in order on each input.

    NOTE(review): presumably MyKeySelector keys by the int field, which makes 'a'/'b'
    and 'c'/'d' share keys; confirm against its definition.
    """
    left = self.env.from_collection(
        [('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
        type_info=Types.ROW([Types.STRING(), Types.INT()])) \
        .key_by(MyKeySelector(), key_type_info=Types.INT())
    right = self.env.from_collection(
        [('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
        type_info=Types.ROW([Types.STRING(), Types.INT()]))

    class AssertKeyCoMapFunction(CoMapFunction):
        """Asserts 'b' follows 'a' and 'd' follows 'c', tracked separately per input."""

        def __init__(self):
            self.pre1 = None
            self.pre2 = None

        @staticmethod
        def _check_predecessor(current, previous):
            if current == 'b':
                assert previous == 'a'
            if current == 'd':
                assert previous == 'c'

        def map1(self, value):
            self._check_predecessor(value[0], self.pre1)
            self.pre1 = value[0]
            return value

        def map2(self, value):
            self._check_predecessor(value[0], self.pre2)
            self.pre2 = value[0]
            return value

    connected = left.connect(right).key_by(MyKeySelector(), MyKeySelector(),
                                           key_type_info=Types.INT())
    connected.map(AssertKeyCoMapFunction()).add_sink(self.test_sink)
    self.env.execute('key_by_test')

    # Each input contributes every row once, so each row is expected twice.
    single_input = ["Row(f0='a', f1=0)", "Row(f0='b', f1=0)", "Row(f0='c', f1=1)",
                    "Row(f0='d', f1=1)", "Row(f0='e', f1=2)"]
    self.assertEqual(sorted(single_input * 2), sorted(self.test_sink.get_results(True)))
def test_co_map_function_with_data_types(self):
    """Apply MyCoMapFunction over an int stream connected to a string stream."""
    self.env.set_parallelism(1)
    numbers = self.env.from_collection(
        [(1, 1), (2, 2), (3, 3)], type_info=Types.ROW([Types.INT(), Types.INT()]))
    letters = self.env.from_collection(
        [("a", "a"), ("b", "b"), ("c", "c")],
        type_info=Types.ROW([Types.STRING(), Types.STRING()]))

    connected = numbers.connect(letters)
    connected.map(MyCoMapFunction(), output_type=Types.STRING()).add_sink(self.test_sink)
    self.env.execute('co_map_function_test')

    actual = self.test_sink.get_results(False)
    self.assertEqual(sorted(['2', '3', '4', 'a', 'b', 'c']), sorted(actual))
def test_count_sliding_window(self):
    """Sum the int field over count-based sliding windows (CountSlidingWindowAssigner(2, 1))."""
    source = self.env.from_collection(
        [(1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
        type_info=Types.TUPLE([Types.INT(), Types.STRING()]))  # type: DataStream

    windowed = source.key_by(lambda event: event[1], key_type=Types.STRING()) \
        .window(CountSlidingWindowAssigner(2, 1))
    summed = windowed.apply(SumWindowFunction(), Types.TUPLE([Types.STRING(), Types.INT()]))
    summed.add_sink(self.test_sink)
    self.env.execute('test_count_sliding_window')

    self.assert_equals_sorted(['(hello,6)', '(hi,8)', '(hi,4)', '(hello,10)'],
                              self.test_sink.get_results())