def get_type(self) -> TypeInformation:
    """
    Gets the type of the stream.

    :return: The type of the DataStream.
    """
    j_type_info = self._j_data_stream.getType()
    return typeinfo._from_java_type(j_type_info)
def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
    """
    Applies a Filter transformation on a DataStream. The transformation calls a
    FilterFunction for each element of the DataStream and retains only those element
    for which the function returns true. Elements for which the function returns
    false are filtered. The user can also extend RichFilterFunction to gain access to
    other features provided by the RichFunction interface.

    :param func: The FilterFunction that is called for each element of the DataStream.
    :return: The filtered DataStream.
    """

    class FilterFlatMap(FlatMapFunction):
        # Adapts a FilterFunction to a FlatMapFunction: emit the value only
        # when the wrapped filter accepts it (filter is implemented on top of
        # the existing flat_map machinery).
        def __init__(self, filter_func):
            self._func = filter_func

        def flat_map(self, value):
            if self._func.filter(value):
                yield value

    # Check FilterFunction first: a FilterFunction subclass may also be
    # callable, and must not be wrapped a second time. Use callable() for
    # consistency with key_by().
    if isinstance(func, FilterFunction):
        pass
    elif callable(func):
        func = FilterFunctionWrapper(func)
    else:
        raise TypeError(
            "func must be a Callable or instance of FilterFunction.")

    # Reuse the input's output type so the filtered stream keeps the same
    # element type as its upstream.
    j_input_type = self._j_data_stream.getTransformation().getOutputType()
    type_info = typeinfo._from_java_type(j_input_type)
    j_data_stream = self.flat_map(FilterFlatMap(func), type_info=type_info)._j_data_stream
    filtered_stream = DataStream(j_data_stream)
    filtered_stream.name("Filter")
    return filtered_stream
def get_type_info(self):
    """Return the TypeInformation for this row type, computing and caching it on first use."""
    if self._type_info is not None:
        return self._type_info
    jvm = get_gateway().jvm
    converter = jvm.org.apache.flink.table.types.utils.LegacyTypeInfoDataTypeConverter
    j_type_info = converter.toLegacyTypeInfo(_to_java_data_type(self._row_type))
    self._type_info = _from_java_type(j_type_info)
    return self._type_info
def key_by(self, key_selector: Union[Callable, KeySelector],
           key_type_info: TypeInformation = None) -> 'KeyedStream':
    """
    Creates a new KeyedStream that uses the provided key for partitioning its operator states.

    :param key_selector: The KeySelector to be used for extracting the key for partitioning.
    :param key_type_info: The type information describing the key type.
    :return: The DataStream with partitioned state(i.e. KeyedStream).
    """
    # Accept a plain callable by wrapping it into the KeySelector interface.
    if callable(key_selector):
        key_selector = KeySelectorFunctionWrapper(key_selector)
    if not isinstance(key_selector, (KeySelector, KeySelectorFunctionWrapper)):
        raise TypeError("Parameter key_selector should be a type of KeySelector.")

    gateway = get_gateway()
    PickledKeySelector = gateway.jvm \
        .org.apache.flink.datastream.runtime.functions.python.PickledKeySelector
    j_output_type_info = self._j_data_stream.getTransformation().getOutputType()
    output_type_info = typeinfo._from_java_type(j_output_type_info)
    is_key_pickled_byte_array = False
    if key_type_info is None:
        # No explicit key type: fall back to carrying the key as pickled bytes.
        # The flag tells the Java-side PickledKeySelector how to read the key.
        key_type_info = Types.PICKLED_BYTE_ARRAY()
        is_key_pickled_byte_array = True

    # Prepend the extracted key to each element, producing (key, value) rows,
    # so the Java-side selector can key on column 0 without calling back into
    # Python per record.
    intermediate_map_stream = self.map(lambda x: (key_selector.get_key(x), x),
                                       type_info=Types.ROW([key_type_info, output_type_info]))
    # Name the map with the well-known constant; presumably PythonConfigUtil
    # later recognizes this operator by name — TODO confirm.
    intermediate_map_stream.name(gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
                                 .STREAM_KEY_BY_MAP_OPERATOR_NAME)
    generated_key_stream = KeyedStream(intermediate_map_stream._j_data_stream
                                       .keyBy(PickledKeySelector(is_key_pickled_byte_array),
                                              key_type_info.get_java_type_info()),
                                       self)
    # Remember the pre-keying element type so downstream operators can strip
    # the synthetic (key, value) wrapping.
    generated_key_stream._original_data_type_info = output_type_info
    return generated_key_stream
def pickled_bytes_to_python_converter(data, field_type):
    # Recursively converts pickled byte payload(s) coming from the Java side
    # into Python values according to the given TypeInformation.
    if isinstance(field_type, RowTypeInfo):
        # Rows arrive as a sequence whose element 0 is skipped here
        # (presumably row metadata such as the RowKind — TODO confirm against
        # the Java-side serializer); the remaining elements are the fields.
        data = zip(list(data[1:]), field_type.get_field_types())
        fields = []
        for d, d_type in data:
            fields.append(pickled_bytes_to_python_converter(d, d_type))
        return tuple(fields)
    else:
        # NOTE(review): pickle.loads on data received over the JVM bridge —
        # assumes a trusted channel; never feed untrusted bytes through here.
        data = pickle.loads(data)
        if field_type == Types.SQL_TIME():
            # Internal representation appears to be microseconds since
            # midnight — decompose into h/m/s/us.
            seconds, microseconds = divmod(data, 10**6)
            minutes, seconds = divmod(seconds, 60)
            hours, minutes = divmod(minutes, 60)
            return datetime.time(hours, minutes, seconds, microseconds)
        elif field_type == Types.SQL_DATE():
            return field_type.from_internal_type(data)
        elif field_type == Types.SQL_TIMESTAMP():
            # Unpickled value is a datetime; convert to epoch microseconds
            # before handing it to from_internal_type.
            return field_type.from_internal_type(int(data.timestamp() * 10**6))
        elif field_type == Types.FLOAT():
            # Float arrives as a string representation; literal_eval parses it
            # back into a number without the risks of eval().
            return field_type.from_internal_type(ast.literal_eval(data))
        elif is_basic_array_type_info(
                field_type) or is_primitive_array_type_info(field_type):
            # Arrays: convert each element with the component type.
            element_type = typeinfo._from_java_type(
                field_type.get_java_type_info().getComponentInfo())
            elements = []
            for element_bytes in data:
                elements.append(
                    pickled_bytes_to_python_converter(element_bytes, element_type))
            return elements
        else:
            return field_type.from_internal_type(data)
def json_schema(self, json_schema: str):
    """
    Creates a JSON deserialization schema for the given JSON schema.

    :param json_schema: JSON schema describing the result type.
    """
    if json_schema is None:
        raise TypeError("The json_schema must not be None.")
    jvm = get_gateway().jvm
    j_type_info = jvm.org.apache.flink.formats.json.JsonRowSchemaConverter.convert(json_schema)
    self._type_info = typeinfo._from_java_type(j_type_info)
    return self
def from_collection(self, collection: List[Any],
                    type_info: TypeInformation = None) -> DataStream:
    """
    Creates a data stream from the given non-empty collection. The type of the data stream
    is that of the elements in the collection.

    Note that this operation will result in a non-parallel data stream source, i.e. a data
    stream source with parallelism one.

    :param collection: The collection of elements to create the data stream from.
    :param type_info: The TypeInformation for the produced data stream
    :return: the data stream representing the given collection.
    """
    if type_info is not None:
        # Convert each element to its internal representation before shipping.
        wrapper_type = _from_java_type(type_info.get_java_type_info())
        collection = [wrapper_type.to_internal_type(item) for item in collection]
    return self._from_collection(collection, type_info)
def kafka_connector_assertion(self, flink_kafka_consumer_clz, flink_kafka_producer_clz):
    # Shared assertion helper for Kafka consumer/producer connector tests.
    # Builds a consumer and a producer from the given classes and verifies,
    # via reflection on private Java fields, that the Python-side
    # configuration reached the underlying Java objects.
    source_topic = 'test_source_topic'
    sink_topic = 'test_sink_topic'
    props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
    type_info = Types.ROW([Types.INT(), Types.STRING()])

    # Test for kafka consumer
    deserialization_schema = JsonRowDeserializationSchema.builder() \
        .type_info(type_info=type_info).build()

    flink_kafka_consumer = flink_kafka_consumer_clz(source_topic, deserialization_schema, props)
    flink_kafka_consumer.set_start_from_earliest()
    flink_kafka_consumer.set_commit_offsets_on_checkpoints(True)

    # Properties dict must have been passed through to the Java consumer.
    j_properties = get_private_field(flink_kafka_consumer.get_java_function(), 'properties')
    self.assertEqual('localhost:9092', j_properties.getProperty('bootstrap.servers'))
    self.assertEqual('test_group', j_properties.getProperty('group.id'))
    self.assertTrue(get_private_field(flink_kafka_consumer.get_java_function(),
                                      'enableCommitOnCheckpoints'))
    j_start_up_mode = get_private_field(flink_kafka_consumer.get_java_function(), 'startupMode')

    # The deserializer's produced type must round-trip back to the Python
    # type_info it was built from.
    j_deserializer = get_private_field(flink_kafka_consumer.get_java_function(), 'deserializer')
    j_deserialize_type_info = invoke_java_object_method(j_deserializer, "getProducedType")
    deserialize_type_info = typeinfo._from_java_type(j_deserialize_type_info)
    self.assertTrue(deserialize_type_info == type_info)
    # set_start_from_earliest() must map to StartupMode.EARLIEST.
    self.assertTrue(j_start_up_mode.equals(get_gateway().jvm
                                           .org.apache.flink.streaming.connectors
                                           .kafka.config.StartupMode.EARLIEST))
    j_topic_desc = get_private_field(flink_kafka_consumer.get_java_function(),
                                     'topicsDescriptor')
    j_topics = invoke_java_object_method(j_topic_desc, 'getFixedTopics')
    self.assertEqual(['test_source_topic'], list(j_topics))

    # Test for kafka producer
    serialization_schema = JsonRowSerializationSchema.builder().with_type_info(type_info) \
        .build()
    flink_kafka_producer = flink_kafka_producer_clz(sink_topic, serialization_schema, props)
    flink_kafka_producer.set_write_timestamp_to_kafka(False)

    j_producer_config = get_private_field(flink_kafka_producer.get_java_function(),
                                          'producerConfig')
    self.assertEqual('localhost:9092', j_producer_config.getProperty('bootstrap.servers'))
    self.assertEqual('test_group', j_producer_config.getProperty('group.id'))
    # set_write_timestamp_to_kafka(False) must reach the Java flag.
    self.assertFalse(get_private_field(flink_kafka_producer.get_java_function(),
                                       'writeTimestampToKafka'))
def test_from_java_type(self):
    """Round-trip check: converting each Python TypeInformation to its Java
    counterpart and back yields an equal instance."""
    type_infos = [
        Types.INT(),
        Types.SHORT(),
        Types.LONG(),
        Types.FLOAT(),
        Types.DOUBLE(),
        Types.CHAR(),
        Types.BYTE(),
        Types.BIG_INT(),
        Types.BIG_DEC(),
        Types.SQL_DATE(),
        Types.SQL_TIME(),
        Types.SQL_TIMESTAMP(),
        Types.ROW([Types.INT(), Types.STRING()]),
        Types.TUPLE([Types.CHAR(), Types.INT()]),
        Types.PRIMITIVE_ARRAY(Types.INT()),
        Types.OBJECT_ARRAY(Types.SQL_DATE()),
        Types.PICKLED_BYTE_ARRAY(),
        Types.SQL_DATE(),
        Types.MAP(Types.INT(), Types.STRING()),
        Types.LIST(Types.INT()),
    ]
    for type_info in type_infos:
        self.assertEqual(type_info,
                         _from_java_type(type_info.get_java_type_info()))
def get_produced_type(self) -> TypeInformation:
    """Return the TypeInformation of the elements produced by the wrapped Java function."""
    j_produced_type = self._j_function.getProducedType()
    return typeinfo._from_java_type(j_produced_type)