Beispiel #1
0
    def get_type(self) -> TypeInformation:
        """
        Returns the type information of this stream.

        The Java type reported by the wrapped Java data stream is translated with
        ``typeinfo._from_java_type``.

        :return: The type of the DataStream.
        """
        j_type_info = self._j_data_stream.getType()
        return typeinfo._from_java_type(j_type_info)
Beispiel #2
0
    def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
        """
        Keeps only the elements of this DataStream that satisfy the given filter.

        The filter is evaluated for every element; elements for which it returns true are
        retained and all others are dropped. Internally the filter is run as a flat map that
        emits the element at most once, and the resulting stream is named "Filter". The user
        can also extend RichFilterFunction to gain access to other features provided by the
        RichFunction interface.

        :param func: The FilterFunction (or plain callable) that is called for each element of
                     the DataStream.
        :return: The filtered DataStream.
        :raises TypeError: If func is neither a Callable nor a FilterFunction.
        """
        if isinstance(func, Callable):
            # Plain callables are adapted so that they expose ``filter(value)``.
            predicate = FilterFunctionWrapper(func)
        elif isinstance(func, FilterFunction):
            predicate = func
        else:
            raise TypeError(
                "func must be a Callable or instance of FilterFunction.")

        class FilterFlatMap(FlatMapFunction):
            """Flat map that yields an element only when the wrapped filter accepts it."""

            def __init__(self, filter_func):
                self._func = filter_func

            def flat_map(self, value):
                if self._func.filter(value):
                    yield value

        # The flat map is declared with the output type of this stream's transformation.
        element_type_info = typeinfo._from_java_type(
            self._j_data_stream.getTransformation().getOutputType())
        flat_mapped = self.flat_map(FilterFlatMap(predicate), type_info=element_type_info)

        result = DataStream(flat_mapped._j_data_stream)
        result.name("Filter")
        return result
Beispiel #3
0
 def get_type_info(self):
     """
     Returns the type information derived from ``self._row_type``.

     The result is computed on first use and stored in ``self._type_info``; later calls
     return the stored value.

     :return: The type information converted from the legacy Java type info.
     """
     if self._type_info is not None:
         return self._type_info

     converter = get_gateway().jvm \
         .org.apache.flink.table.types.utils.LegacyTypeInfoDataTypeConverter
     j_legacy_type_info = converter.toLegacyTypeInfo(_to_java_data_type(self._row_type))
     self._type_info = _from_java_type(j_legacy_type_info)
     return self._type_info
Beispiel #4
0
    def key_by(self, key_selector: Union[Callable, KeySelector],
               key_type_info: TypeInformation = None) -> 'KeyedStream':
        """
        Creates a new KeyedStream that uses the provided key for partitioning its operator states.

        The stream is first mapped to rows of ``(key, element)``; the Java ``keyBy`` is then
        applied to that intermediate stream.

        :param key_selector: The KeySelector (or plain callable) to be used for extracting the
                             key for partitioning.
        :param key_type_info: The type information describing the key type. When omitted,
                              ``Types.PICKLED_BYTE_ARRAY()`` is used.
        :return: The DataStream with partitioned state(i.e. KeyedStream).
        :raises TypeError: If key_selector is neither callable nor a KeySelector.
        """
        selector = KeySelectorFunctionWrapper(key_selector) \
            if callable(key_selector) else key_selector
        if not isinstance(selector, (KeySelector, KeySelectorFunctionWrapper)):
            raise TypeError("Parameter key_selector should be a type of KeySelector.")

        jvm = get_gateway().jvm
        j_pickled_key_selector_cls = jvm \
            .org.apache.flink.datastream.runtime.functions.python.PickledKeySelector
        element_type_info = typeinfo._from_java_type(
            self._j_data_stream.getTransformation().getOutputType())

        # Remember whether the key type was defaulted; the Java key selector is told about it.
        use_pickled_key = key_type_info is None
        if use_pickled_key:
            key_type_info = Types.PICKLED_BYTE_ARRAY()

        # Pair every element with its key so the Java side can partition on the first field.
        key_value_stream = self.map(
            lambda x: (selector.get_key(x), x),
            type_info=Types.ROW([key_type_info, element_type_info]))
        key_value_stream.name(
            jvm.org.apache.flink.python.util.PythonConfigUtil.STREAM_KEY_BY_MAP_OPERATOR_NAME)

        j_keyed_stream = key_value_stream._j_data_stream.keyBy(
            j_pickled_key_selector_cls(use_pickled_key),
            key_type_info.get_java_type_info())
        keyed_stream = KeyedStream(j_keyed_stream, self)
        keyed_stream._original_data_type_info = element_type_info
        return keyed_stream
Beispiel #5
0
def pickled_bytes_to_python_converter(data, field_type):
    """
    Converts pickled data into Python values according to the given type information.

    For a RowTypeInfo, every item of ``data`` after the first one is converted recursively
    with the corresponding field type and the results are returned as a tuple. For any other
    type, ``data`` is unpickled and then converted depending on ``field_type``; array types
    are converted element by element.

    :param data: The pickled bytes, or for rows a sequence whose items after the first are the
                 pickled fields.
    :param field_type: The type information describing ``data``.
    :return: The converted Python value.
    """
    if isinstance(field_type, RowTypeInfo):
        # The first item of the row data is skipped; the rest is paired with the field types.
        return tuple(
            pickled_bytes_to_python_converter(item, item_type)
            for item, item_type in zip(list(data[1:]), field_type.get_field_types()))

    value = pickle.loads(data)

    if field_type == Types.SQL_TIME():
        # Break the value down into hours, minutes, seconds and microseconds.
        total_seconds, microseconds = divmod(value, 10**6)
        total_minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return datetime.time(hours, minutes, seconds, microseconds)

    if field_type == Types.SQL_DATE():
        return field_type.from_internal_type(value)

    if field_type == Types.SQL_TIMESTAMP():
        return field_type.from_internal_type(int(value.timestamp() * 10**6))

    if field_type == Types.FLOAT():
        return field_type.from_internal_type(ast.literal_eval(value))

    if is_basic_array_type_info(field_type) or is_primitive_array_type_info(field_type):
        # Each element is itself pickled and is converted with the array's component type.
        component_type = typeinfo._from_java_type(
            field_type.get_java_type_info().getComponentInfo())
        return [pickled_bytes_to_python_converter(element_bytes, component_type)
                for element_bytes in value]

    return field_type.from_internal_type(value)
Beispiel #6
0
        def json_schema(self, json_schema: str):
            """
            Sets the result type information from the given JSON schema.

            The schema is converted by the Java ``JsonRowSchemaConverter`` and the resulting
            type information is stored in ``self._type_info``.

            :param json_schema: JSON schema describing the result type.
            :return: This object, so that calls can be chained.
            :raises TypeError: If json_schema is None.
            """
            if json_schema is None:
                raise TypeError("The json_schema must not be None.")

            converter = get_gateway().jvm.org.apache.flink.formats.json.JsonRowSchemaConverter
            self._type_info = typeinfo._from_java_type(converter.convert(json_schema))
            return self
Beispiel #7
0
    def from_collection(self, collection: List[Any],
                        type_info: TypeInformation = None) -> DataStream:
        """
        Creates a data stream from the given non-empty collection. The type of the data stream is
        that of the elements in the collection.

        Note that this operation will result in a non-parallel data stream source, i.e. a data
        stream source with parallelism one.

        When type information is supplied, every element is first converted to its internal
        representation before the stream is created.

        :param collection: The collection of elements to create the data stream from.
        :param type_info: The TypeInformation for the produced data stream
        :return: the data stream representing the given collection.
        """
        if type_info is None:
            return self._from_collection(collection, type_info)

        wrapper_type = _from_java_type(type_info.get_java_type_info())
        internal_elements = list(map(wrapper_type.to_internal_type, collection))
        return self._from_collection(internal_elements, type_info)
Beispiel #8
0
    def kafka_connector_assertion(self, flink_kafka_consumer_clz, flink_kafka_producer_clz):
        """
        Checks the Java-side configuration of a Kafka consumer and a Kafka producer.

        A consumer and a producer are created from the given classes; private fields of their
        underlying Java functions are then read and compared with the configured values.

        :param flink_kafka_consumer_clz: The Kafka consumer class under test.
        :param flink_kafka_producer_clz: The Kafka producer class under test.
        """
        source_topic = 'test_source_topic'
        sink_topic = 'test_sink_topic'
        props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
        type_info = Types.ROW([Types.INT(), Types.STRING()])

        # Test for kafka consumer
        schema_builder = JsonRowDeserializationSchema.builder()
        deserialization_schema = schema_builder.type_info(type_info=type_info).build()

        flink_kafka_consumer = flink_kafka_consumer_clz(source_topic, deserialization_schema, props)
        flink_kafka_consumer.set_start_from_earliest()
        flink_kafka_consumer.set_commit_offsets_on_checkpoints(True)

        def consumer_field(field_name):
            # Reads a private field of the consumer's Java function.
            return get_private_field(flink_kafka_consumer.get_java_function(), field_name)

        j_properties = consumer_field('properties')
        self.assertEqual('localhost:9092', j_properties.getProperty('bootstrap.servers'))
        self.assertEqual('test_group', j_properties.getProperty('group.id'))
        self.assertTrue(consumer_field('enableCommitOnCheckpoints'))
        j_start_up_mode = consumer_field('startupMode')

        # The deserializer must report the same type information that was configured above.
        j_deserializer = consumer_field('deserializer')
        j_deserialize_type_info = invoke_java_object_method(j_deserializer, "getProducedType")
        deserialize_type_info = typeinfo._from_java_type(j_deserialize_type_info)
        self.assertTrue(deserialize_type_info == type_info)

        j_earliest = get_gateway().jvm \
            .org.apache.flink.streaming.connectors.kafka.config.StartupMode.EARLIEST
        self.assertTrue(j_start_up_mode.equals(j_earliest))

        j_topic_desc = consumer_field('topicsDescriptor')
        j_topics = invoke_java_object_method(j_topic_desc, 'getFixedTopics')
        self.assertEqual(['test_source_topic'], list(j_topics))

        # Test for kafka producer
        serialization_schema = JsonRowSerializationSchema.builder() \
            .with_type_info(type_info).build()
        flink_kafka_producer = flink_kafka_producer_clz(sink_topic, serialization_schema, props)
        flink_kafka_producer.set_write_timestamp_to_kafka(False)

        def producer_field(field_name):
            # Reads a private field of the producer's Java function.
            return get_private_field(flink_kafka_producer.get_java_function(), field_name)

        j_producer_config = producer_field('producerConfig')
        self.assertEqual('localhost:9092', j_producer_config.getProperty('bootstrap.servers'))
        self.assertEqual('test_group', j_producer_config.getProperty('group.id'))
        self.assertFalse(producer_field('writeTimestampToKafka'))
Beispiel #9
0
    def test_from_java_type(self):
        """
        Checks that ``_from_java_type`` round-trips type information.

        Each type information is converted to its Java type info and back; the result must be
        equal to the original.
        """
        # Factories are used so that each type information is created right before its check.
        type_info_factories = [
            Types.INT,
            Types.SHORT,
            Types.LONG,
            Types.FLOAT,
            Types.DOUBLE,
            Types.CHAR,
            Types.BYTE,
            Types.BIG_INT,
            Types.BIG_DEC,
            Types.SQL_DATE,
            Types.SQL_TIME,
            Types.SQL_TIMESTAMP,
            lambda: Types.ROW([Types.INT(), Types.STRING()]),
            lambda: Types.TUPLE([Types.CHAR(), Types.INT()]),
            lambda: Types.PRIMITIVE_ARRAY(Types.INT()),
            lambda: Types.OBJECT_ARRAY(Types.SQL_DATE()),
            Types.PICKLED_BYTE_ARRAY,
            Types.SQL_DATE,
            lambda: Types.MAP(Types.INT(), Types.STRING()),
            lambda: Types.LIST(Types.INT()),
        ]

        for create_type_info in type_info_factories:
            expected_type_info = create_type_info()
            self.assertEqual(expected_type_info,
                             _from_java_type(expected_type_info.get_java_type_info()))
Beispiel #10
0
 def get_produced_type(self) -> TypeInformation:
     """
     Returns the type information produced by the wrapped Java function.

     :return: The Python type information converted from the Java function's produced type.
     """
     j_produced_type = self._j_function.getProducedType()
     return typeinfo._from_java_type(j_produced_type)