def convert_to_python_obj(data, type_info):
    if type_info == Types.PICKLED_BYTE_ARRAY():
        return pickle.loads(data)
    elif isinstance(type_info, ExternalTypeInfo):
        return convert_to_python_obj(data, type_info._type_info)
    else:
        gateway = get_gateway()
        pickle_bytes = gateway.jvm.PythonBridgeUtils. \
            getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
        if isinstance(type_info, (RowTypeInfo, TupleTypeInfo)):
            field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())
            fields = []
            for field_bytes, field_type in field_data:
                if len(field_bytes) == 0:
                    fields.append(None)
                else:
                    fields.append(
                        pickled_bytes_to_python_converter(field_bytes, field_type))
            if isinstance(type_info, RowTypeInfo):
                # The first element of pickle_bytes encodes the RowKind as a
                # little-endian byte; the remaining elements are the pickled fields.
                return Row.of_kind(
                    RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields)
            else:
                return tuple(fields)
        else:
            return pickled_bytes_to_python_converter(pickle_bytes, type_info)
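# A minimal sketch of the RowKind decoding step above, runnable without a JVM.
# It assumes RowKind is an enum with INSERT == 0, as in pyflink.common; the
# header byte here is a hypothetical stand-in for pickle_bytes[0].

from pyflink.common import Row, RowKind

header = b'\x00'  # stands in for pickle_bytes[0]; 0x00 is assumed to map to INSERT
kind = RowKind(int.from_bytes(header, 'little'))
row = Row.of_kind(kind, 1, 'pyflink')
assert row.get_row_kind() == RowKind.INSERT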
def key_by(self,
           key_selector: Union[Callable, KeySelector],
           key_type_info: TypeInformation = None) -> 'KeyedStream':
    """
    Creates a new KeyedStream that uses the provided key for partitioning its
    operator states.

    :param key_selector: The KeySelector to be used for extracting the key for
                         partitioning.
    :param key_type_info: The type information describing the key type.
    :return: The DataStream with partitioned state (i.e. KeyedStream).
    """
    if callable(key_selector):
        key_selector = KeySelectorFunctionWrapper(key_selector)
    if not isinstance(key_selector, (KeySelector, KeySelectorFunctionWrapper)):
        raise TypeError("Parameter key_selector should be a type of KeySelector.")

    gateway = get_gateway()
    PickledKeySelector = gateway.jvm \
        .org.apache.flink.datastream.runtime.functions.python.PickledKeySelector
    j_output_type_info = self._j_data_stream.getTransformation().getOutputType()
    output_type_info = typeinfo._from_java_type(j_output_type_info)
    is_key_pickled_byte_array = False
    if key_type_info is None:
        # Without an explicit key type, treat the key as an opaque pickled byte array.
        key_type_info = Types.PICKLED_BYTE_ARRAY()
        is_key_pickled_byte_array = True

    intermediate_map_stream = self.map(
        lambda x: (key_selector.get_key(x), x),
        type_info=Types.ROW([key_type_info, output_type_info]))
    intermediate_map_stream.name(
        gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
        .STREAM_KEY_BY_MAP_OPERATOR_NAME)
    generated_key_stream = KeyedStream(
        intermediate_map_stream._j_data_stream.keyBy(
            PickledKeySelector(is_key_pickled_byte_array),
            key_type_info.get_java_type_info()),
        self)
    generated_key_stream._original_data_type_info = output_type_info
    return generated_key_stream
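# A hedged usage sketch for the key_by variant defined above: keying a tuple
# stream by its first field. Passing an explicit key_type_info avoids the
# pickled-byte-array fallback shown in the method body.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
ds = env.from_collection(
    [('hello', 1), ('world', 2), ('hello', 3)],
    type_info=Types.TUPLE([Types.STRING(), Types.INT()]))
# The lambda is wrapped in a KeySelectorFunctionWrapper internally.
keyed = ds.key_by(lambda value: value[0], key_type_info=Types.STRING())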
def collect(self, data_stream: DataStream):
    gateway = get_gateway()
    self._is_python_objects = data_stream.get_type() == Types.PICKLED_BYTE_ARRAY()
    self._j_data_stream_test_collect_sink = gateway.jvm \
        .org.apache.flink.python.util.DataStreamTestCollectSink(self._is_python_objects)
    data_stream._j_data_stream.addSink(self._j_data_stream_test_collect_sink)
def _from_collection(self,
                     elements: List[Any],
                     type_info: TypeInformation = None) -> DataStream:
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
    serializer = self.serializer
    try:
        with temp_file:
            # Dump the elements to a temporary file via the pickle serializer.
            serializer.serialize(elements, temp_file)
        gateway = get_gateway()
        # If the user has not defined the element data types, read the pickled data
        # back as a list of byte arrays.
        if type_info is None:
            j_objs = gateway.jvm.PythonBridgeUtils.readPickledBytes(temp_file.name)
            output_type_info = Types.PICKLED_BYTE_ARRAY()  # type: TypeInformation
        else:
            j_objs = gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name)
            output_type_info = type_info
        # Since the flink python module depends on the table module, we can reuse its
        # utilities when implementing the Python DataStream API.
        PythonTableUtils = gateway.jvm \
            .org.apache.flink.table.planner.utils.python.PythonTableUtils
        execution_config = self._j_stream_execution_environment.getConfig()
        j_input_format = PythonTableUtils.getCollectionInputFormat(
            j_objs, output_type_info.get_java_type_info(), execution_config)

        JInputFormatSourceFunction = gateway.jvm.org.apache.flink.streaming.api.functions. \
            source.InputFormatSourceFunction
        JBoundedness = gateway.jvm.org.apache.flink.api.connector.source.Boundedness

        j_data_stream_source = invoke_method(
            self._j_stream_execution_environment,
            "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment",
            "addSource",
            [JInputFormatSourceFunction(j_input_format,
                                        output_type_info.get_java_type_info()),
             "Collection Source",
             output_type_info.get_java_type_info(),
             JBoundedness.BOUNDED],
            ["org.apache.flink.streaming.api.functions.source.SourceFunction",
             "java.lang.String",
             "org.apache.flink.api.common.typeinfo.TypeInformation",
             "org.apache.flink.api.connector.source.Boundedness"])
        j_data_stream_source.forceNonParallel()
        return DataStream(j_data_stream=j_data_stream_source)
    finally:
        os.unlink(temp_file.name)
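# A short sketch of the two paths through _from_collection, via its public
# from_collection wrapper (assumed here to delegate to the method above):
# without type_info the elements are read back as pickled byte arrays; with
# type_info they are converted into typed Java objects.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()

# No type_info: readPickledBytes path, elements stay opaque pickled bytes.
ds_pickled = env.from_collection([{'a': 1}, {'b': 2}])

# Explicit type_info: readPythonObjects path, elements become typed rows.
ds_typed = env.from_collection(
    [(1, 'flink'), (2, 'pyflink')],
    type_info=Types.ROW([Types.INT(), Types.STRING()]))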
def test_from_collection_with_data_types(self):
    # Verify from_collection for a collection of single objects.
    ds = self.env.from_collection(['Hi', 'Hello'], type_info=Types.STRING())
    ds.add_sink(self.test_sink)
    self.env.execute("test from collection with single object")
    results = self.test_sink.get_results(False)
    expected = ['Hello', 'Hi']
    results.sort()
    expected.sort()
    self.assertEqual(expected, results)

    # Verify from_collection for a collection of composite objects such as tuples.
    ds = self.env.from_collection(
        [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
          bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13),
          datetime.time(hour=12, minute=0, second=0, microsecond=123000),
          datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [1, 2, 3],
          decimal.Decimal('1000000000000000000.05'),
          decimal.Decimal('1000000000000000000.0599999999999'
                          '9999899999999999')),
         (2, None, 2, True, 43878, 9147483648, 9.87, 2.98936,
          bytearray(b'flink'), 'pyflink', datetime.date(2015, 10, 14),
          datetime.time(hour=11, minute=2, second=2, microsecond=234500),
          datetime.datetime(2020, 4, 15, 8, 2, 6, 235000), [2, 4, 6],
          decimal.Decimal('2000000000000000000.74'),
          decimal.Decimal('2000000000000000000.061111111111111'
                          '11111111111111'))],
        type_info=Types.ROW(
            [Types.LONG(), Types.LONG(), Types.SHORT(), Types.BOOLEAN(),
             Types.SHORT(), Types.INT(), Types.FLOAT(), Types.DOUBLE(),
             Types.PICKLED_BYTE_ARRAY(), Types.STRING(), Types.SQL_DATE(),
             Types.SQL_TIME(), Types.SQL_TIMESTAMP(),
             Types.BASIC_ARRAY(Types.LONG()), Types.BIG_DEC(), Types.BIG_DEC()]))
    ds.add_sink(self.test_sink)
    self.env.execute("test from collection with tuple object")
    results = self.test_sink.get_results(False)
    # If the user specifies the data types of the input data, the collected result
    # should be in row format. Note the expected Java-side overflow: 43878 wraps to
    # -21658 as a SHORT, and 9147483648 wraps to 557549056 as an INT.
    expected = [
        '+I[1, null, 1, true, 32767, -2147483648, 1.23, 1.98932, '
        '[102, 108, 105, 110, 107], '
        'pyflink, 2014-09-13, 12:00:00, 2018-03-11 03:00:00.123, [1, 2, 3], '
        '1000000000000000000.05, 1000000000000000000.05999999999999999899999999999]',
        '+I[2, null, 2, true, -21658, 557549056, 9.87, 2.98936, '
        '[102, 108, 105, 110, 107], '
        'pyflink, 2015-10-14, 11:02:02, 2020-04-15 08:02:06.235, [2, 4, 6], '
        '2000000000000000000.74, 2000000000000000000.06111111111111111111111111111]']
    results.sort()
    expected.sort()
    self.assertEqual(expected, results)
def __init__(self, tag_id: str, type_info: Optional[Union[TypeInformation, list]] = None):
    if not tag_id:
        raise ValueError("OutputTag tag_id cannot be None or empty string")
    self.tag_id = tag_id
    if type_info is None:
        self.type_info = Types.PICKLED_BYTE_ARRAY()
    elif isinstance(type_info, list):
        self.type_info = RowTypeInfo(type_info)
    elif not isinstance(type_info, TypeInformation):
        raise TypeError("OutputTag type_info must be None, list or TypeInformation")
    else:
        self.type_info = type_info
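# A hedged sketch of OutputTag in use: the list shorthand builds a RowTypeInfo,
# and a tag created without type_info falls back to PICKLED_BYTE_ARRAY as in
# __init__ above. The routing predicate and class name are illustrative.

from pyflink.common.typeinfo import Types
from pyflink.datastream import OutputTag, ProcessFunction

late_tag = OutputTag('late', [Types.LONG(), Types.STRING()])  # list -> RowTypeInfo
debug_tag = OutputTag('debug')  # no type_info -> PICKLED_BYTE_ARRAY fallback

class RouteLate(ProcessFunction):
    def process_element(self, value, ctx):
        if value[0] < 0:
            # Yielding a (tag, value) pair emits to the side output stream.
            yield late_tag, value
        else:
            yield value

# Side outputs are then read with get_side_output, e.g.:
# main = ds.process(RouteLate()); late = main.get_side_output(late_tag)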
def _from_collection(self,
                     elements: List[Any],
                     type_info: TypeInformation = None) -> DataStream:
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
    serializer = self.serializer
    try:
        with temp_file:
            # Dump the elements to a temporary file via the pickle serializer.
            serializer.dump_to_stream(elements, temp_file)
        gateway = get_gateway()
        # If the user has not defined the element data types, read the pickled data
        # back as a list of byte arrays.
        if type_info is None:
            j_objs = gateway.jvm.PythonBridgeUtils.readPickledBytes(temp_file.name)
            output_type_info = Types.PICKLED_BYTE_ARRAY()  # type: TypeInformation
        else:
            j_objs = gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name)
            output_type_info = type_info
        # Since the flink python module depends on the table module, we can reuse its
        # utilities when implementing the Python DataStream API.
        PythonTableUtils = gateway.jvm \
            .org.apache.flink.table.planner.utils.python.PythonTableUtils
        execution_config = self._j_stream_execution_environment.getConfig()
        j_input_format = PythonTableUtils.getCollectionInputFormat(
            j_objs, output_type_info.get_java_type_info(), execution_config)

        j_data_stream_source = self._j_stream_execution_environment.createInput(
            j_input_format, output_type_info.get_java_type_info())
        j_data_stream_source.forceNonParallel()
        return DataStream(j_data_stream=j_data_stream_source)
    finally:
        os.unlink(temp_file.name)
def open(self, runtime_context: RuntimeContext,
         internal_timer_service: InternalTimerService):
    self.window_function.open(runtime_context)
    self.num_late_records_dropped = runtime_context.get_metrics_group().counter(
        self.LATE_ELEMENTS_DROPPED_METRIC_NAME)
    self.internal_timer_service = internal_timer_service
    self.trigger_context = Context(runtime_context, internal_timer_service, self.trigger)
    self.process_context = WindowContext(
        self.window_assigner, runtime_context, self.window_function,
        self.internal_timer_service)
    self.window_assigner_context = WindowAssignerContext(
        self.internal_timer_service, runtime_context)

    # Create (or restore) the state that holds the actual window contents.
    # NOTE: the state may be None in the case of the overriding evicting window operator.
    if self.window_state_descriptor is not None:
        self.window_state = get_or_create_keyed_state(
            runtime_context, self.window_state_descriptor)

    if isinstance(self.window_assigner, MergingWindowAssigner):
        if isinstance(self.window_state, InternalMergingState):
            self.window_merging_state = self.window_state
        # TODO: the type info is just a placeholder currently; it should be the real
        # type serializer once user-defined state type serializers are supported.
        merging_sets_state_descriptor = ListStateDescriptor(
            "merging-window-set", Types.PICKLED_BYTE_ARRAY())
        self.merging_sets_state = get_or_create_keyed_state(
            runtime_context, merging_sets_state_descriptor)

    self.merge_function = WindowMergeFunction(self)
def test_from_java_type(self):
    # Each Python TypeInformation should survive a round trip through its Java
    # counterpart. (Types.SQL_DATE() appeared twice in the original checks and is
    # listed once here.)
    type_infos = [
        Types.INT(), Types.SHORT(), Types.LONG(), Types.FLOAT(), Types.DOUBLE(),
        Types.CHAR(), Types.BYTE(), Types.BIG_INT(), Types.BIG_DEC(),
        Types.SQL_DATE(), Types.SQL_TIME(), Types.SQL_TIMESTAMP(),
        Types.ROW([Types.INT(), Types.STRING()]),
        Types.TUPLE([Types.CHAR(), Types.INT()]),
        Types.PRIMITIVE_ARRAY(Types.INT()),
        Types.OBJECT_ARRAY(Types.SQL_DATE()),
        Types.PICKLED_BYTE_ARRAY(),
        Types.MAP(Types.INT(), Types.STRING()),
        Types.LIST(Types.INT()),
    ]
    for type_info in type_infos:
        self.assertEqual(type_info,
                         _from_java_type(type_info.get_java_type_info()))
def open(self, ctx: Context[Any, CountWindow]):
    count_descriptor = ValueStateDescriptor(
        'slide-count-assigner', Types.PICKLED_BYTE_ARRAY())
    self._count = ctx.get_partitioned_state(count_descriptor)
def __init__(self, count_elements: int):
    self._count_elements = count_elements
    self._count_state_desc = ValueStateDescriptor(
        "trigger-count-%s" % count_elements, Types.PICKLED_BYTE_ARRAY())
    self._ctx = None  # type: TriggerContext
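# A minimal sketch of how the descriptor above is typically consumed, assuming
# PyFlink's Trigger/TriggerContext API (on_element, get_partitioned_state,
# TriggerResult). The running count lives in keyed ValueState under the
# PICKLED_BYTE_ARRAY placeholder type; the class is illustrative, not the
# library's CountTrigger.

from pyflink.common.typeinfo import Types
from pyflink.datastream.state import ValueStateDescriptor
from pyflink.datastream.window import Trigger, TriggerResult

class SketchCountTrigger(Trigger):
    """Fires once a window has buffered count_elements records (illustrative)."""

    def __init__(self, count_elements: int):
        self._count_elements = count_elements
        self._count_state_desc = ValueStateDescriptor(
            "trigger-count-%s" % count_elements, Types.PICKLED_BYTE_ARRAY())

    def on_element(self, element, timestamp, window, ctx) -> TriggerResult:
        count_state = ctx.get_partitioned_state(self._count_state_desc)
        count = (count_state.value() or 0) + 1
        if count >= self._count_elements:
            count_state.clear()
            return TriggerResult.FIRE
        count_state.update(count)
        return TriggerResult.CONTINUE

    def on_processing_time(self, time, window, ctx) -> TriggerResult:
        return TriggerResult.CONTINUE

    def on_event_time(self, time, window, ctx) -> TriggerResult:
        return TriggerResult.CONTINUE

    def on_merge(self, window, ctx) -> None:
        pass

    def clear(self, window, ctx) -> None:
        ctx.get_partitioned_state(self._count_state_desc).clear()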