Example #1
def convert_to_python_obj(data, type_info):
    if type_info == Types.PICKLED_BYTE_ARRAY():
        return pickle.loads(data)
    elif isinstance(type_info, ExternalTypeInfo):
        return convert_to_python_obj(data, type_info._type_info)
    else:
        gateway = get_gateway()
        pickle_bytes = gateway.jvm.PythonBridgeUtils. \
            getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
        if isinstance(type_info, (RowTypeInfo, TupleTypeInfo)):
            field_data = zip(list(pickle_bytes[1:]),
                             type_info.get_field_types())
            fields = []
            for data, field_type in field_data:
                if len(data) == 0:
                    fields.append(None)
                else:
                    fields.append(
                        pickled_bytes_to_python_converter(data, field_type))
            if isinstance(type_info, RowTypeInfo):
                return Row.of_kind(
                    RowKind(int.from_bytes(pickle_bytes[0], 'little')),
                    *fields)
            else:
                return tuple(fields)
        else:
            return pickled_bytes_to_python_converter(pickle_bytes, type_info)
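A minimal standalone sketch of the PICKLED_BYTE_ARRAY fast path above: for that type, convert_to_python_obj reduces to a plain pickle round trip, so no JVM bridge is involved.

import pickle

# Pickled payload, as it would arrive for a PICKLED_BYTE_ARRAY-typed value.
payload = pickle.dumps({"name": "flink", "count": 2})

# Equivalent to the first branch of convert_to_python_obj.
restored = pickle.loads(payload)
assert restored == {"name": "flink", "count": 2}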
Example #2
    def key_by(self, key_selector: Union[Callable, KeySelector],
               key_type_info: TypeInformation = None) -> 'KeyedStream':
        """
        Creates a new KeyedStream that uses the provided key for partitioning its operator states.

        :param key_selector: The KeySelector to be used for extracting the key for partitioning.
        :param key_type_info: The type information describing the key type.
        :return: The DataStream with partitioned state (i.e. a KeyedStream).
        """
        if callable(key_selector):
            key_selector = KeySelectorFunctionWrapper(key_selector)
        if not isinstance(key_selector, (KeySelector, KeySelectorFunctionWrapper)):
            raise TypeError("Parameter key_selector should be a type of KeySelector.")

        gateway = get_gateway()
        PickledKeySelector = gateway.jvm \
            .org.apache.flink.datastream.runtime.functions.python.PickledKeySelector
        j_output_type_info = self._j_data_stream.getTransformation().getOutputType()
        output_type_info = typeinfo._from_java_type(j_output_type_info)
        is_key_pickled_byte_array = False
        if key_type_info is None:
            key_type_info = Types.PICKLED_BYTE_ARRAY()
            is_key_pickled_byte_array = True

        intermediate_map_stream = self.map(lambda x: (key_selector.get_key(x), x),
                                           type_info=Types.ROW([key_type_info, output_type_info]))
        intermediate_map_stream.name(gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
                                     .STREAM_KEY_BY_MAP_OPERATOR_NAME)
        generated_key_stream = KeyedStream(intermediate_map_stream._j_data_stream
                                           .keyBy(PickledKeySelector(is_key_pickled_byte_array),
                                                  key_type_info.get_java_type_info()), self)
        generated_key_stream._original_data_type_info = output_type_info
        return generated_key_stream
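A short usage sketch for key_by, assuming a StreamExecutionEnvironment is available: a plain callable is wrapped into a KeySelectorFunctionWrapper, and passing key_type_info avoids falling back to PICKLED_BYTE_ARRAY keys.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
ds = env.from_collection(
    [('a', 1), ('b', 2), ('a', 3)],
    type_info=Types.TUPLE([Types.STRING(), Types.INT()]))

# Key by the first tuple field; the explicit key type keeps the key strongly typed.
keyed = ds.key_by(lambda t: t[0], key_type_info=Types.STRING())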
Example #3
    def collect(self, data_stream: DataStream):
        gateway = get_gateway()
        self._is_python_objects = \
            data_stream.get_type() == Types.PICKLED_BYTE_ARRAY()
        self._j_data_stream_test_collect_sink = gateway.jvm \
            .org.apache.flink.python.util.DataStreamTestCollectSink(self._is_python_objects)
        data_stream._j_data_stream.addSink(self._j_data_stream_test_collect_sink)
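For context, a hedged sketch of the type check the sink relies on: a stream built without explicit type information carries PICKLED_BYTE_ARRAY type info, which tells the sink to unpickle each collected element back into a Python object.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
# No type_info given, so the elements stay pickled Python objects.
ds = env.from_collection([{'k': 1}, {'k': 2}])
assert ds.get_type() == Types.PICKLED_BYTE_ARRAY()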
Example #4
    def _from_collection(self,
                         elements: List[Any],
                         type_info: TypeInformation = None) -> DataStream:
        temp_file = tempfile.NamedTemporaryFile(delete=False,
                                                dir=tempfile.mkdtemp())
        serializer = self.serializer
        try:
            with temp_file:
                # dump the elements to a temporary file using the pickle serializer.
                serializer.serialize(elements, temp_file)
            gateway = get_gateway()
            # if the user has not defined the element data types, read the pickled data as a
            # list of byte arrays.
            if type_info is None:
                j_objs = gateway.jvm.PythonBridgeUtils.readPickledBytes(
                    temp_file.name)
                out_put_type_info = Types.PICKLED_BYTE_ARRAY()  # type: TypeInformation
            else:
                j_objs = gateway.jvm.PythonBridgeUtils.readPythonObjects(
                    temp_file.name)
                out_put_type_info = type_info
            # Since the flink python module depends on the table module, we can reuse its
            # utilities when implementing the python DataStream API.
            PythonTableUtils = gateway.jvm\
                .org.apache.flink.table.planner.utils.python.PythonTableUtils
            execution_config = self._j_stream_execution_environment.getConfig()
            j_input_format = PythonTableUtils.getCollectionInputFormat(
                j_objs, out_put_type_info.get_java_type_info(),
                execution_config)

            JInputFormatSourceFunction = gateway.jvm.org.apache.flink.streaming.api.functions.\
                source.InputFormatSourceFunction
            JBoundedness = gateway.jvm.org.apache.flink.api.connector.source.Boundedness

            j_data_stream_source = invoke_method(
                self._j_stream_execution_environment,
                "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment",
                "addSource", [
                    JInputFormatSourceFunction(
                        j_input_format,
                        out_put_type_info.get_java_type_info()),
                    "Collection Source",
                    out_put_type_info.get_java_type_info(),
                    JBoundedness.BOUNDED
                ], [
                    "org.apache.flink.streaming.api.functions.source.SourceFunction",
                    "java.lang.String",
                    "org.apache.flink.api.common.typeinfo.TypeInformation",
                    "org.apache.flink.api.connector.source.Boundedness"
                ])
            j_data_stream_source.forceNonParallel()
            return DataStream(j_data_stream=j_data_stream_source)
        finally:
            os.unlink(temp_file.name)
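A brief usage sketch of the public entry point that delegates to _from_collection above, assuming a fresh environment: with an explicit Row type, the elements are read back as typed Java objects rather than pickled bytes.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
# Tuples are mapped onto the declared Row fields.
ds = env.from_collection(
    [(1, 'hello'), (2, 'world')],
    type_info=Types.ROW([Types.INT(), Types.STRING()]))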
Example #5
    def test_from_collection_with_data_types(self):
        # verify from_collection for a collection of simple, non-composite objects.
        ds = self.env.from_collection(['Hi', 'Hello'], type_info=Types.STRING())
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection with single object")
        results = self.test_sink.get_results(False)
        expected = ['Hello', 'Hi']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

        # verify from_collection for a collection of composite objects such as tuples.
        ds = self.env.from_collection([(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
                                        bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13),
                                        datetime.time(hour=12, minute=0, second=0,
                                                      microsecond=123000),
                                        datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [1, 2, 3],
                                        decimal.Decimal('1000000000000000000.05'),
                                        decimal.Decimal('1000000000000000000.0599999999999'
                                                        '9999899999999999')),
                                       (2, None, 2, True, 43878, 9147483648, 9.87, 2.98936,
                                        bytearray(b'flink'), 'pyflink', datetime.date(2015, 10, 14),
                                        datetime.time(hour=11, minute=2, second=2,
                                                      microsecond=234500),
                                        datetime.datetime(2020, 4, 15, 8, 2, 6, 235000), [2, 4, 6],
                                        decimal.Decimal('2000000000000000000.74'),
                                        decimal.Decimal('2000000000000000000.061111111111111'
                                                        '11111111111111'))],
                                      type_info=Types.ROW(
                                          [Types.LONG(), Types.LONG(), Types.SHORT(),
                                           Types.BOOLEAN(), Types.SHORT(), Types.INT(),
                                           Types.FLOAT(), Types.DOUBLE(),
                                           Types.PICKLED_BYTE_ARRAY(),
                                           Types.STRING(), Types.SQL_DATE(), Types.SQL_TIME(),
                                           Types.SQL_TIMESTAMP(),
                                           Types.BASIC_ARRAY(Types.LONG()), Types.BIG_DEC(),
                                           Types.BIG_DEC()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection with tuple object")
        results = self.test_sink.get_results(False)
        # if the user specifies the data types of the input data, the collected results are
        # rendered in Row format.
        expected = [
            '+I[1, null, 1, true, 32767, -2147483648, 1.23, 1.98932, [102, 108, 105, 110, 107], '
            'pyflink, 2014-09-13, 12:00:00, 2018-03-11 03:00:00.123, [1, 2, 3], '
            '1000000000000000000.05, 1000000000000000000.05999999999999999899999999999]',
            '+I[2, null, 2, true, -21658, 557549056, 9.87, 2.98936, [102, 108, 105, 110, 107], '
            'pyflink, 2015-10-14, 11:02:02, 2020-04-15 08:02:06.235, [2, 4, 6], '
            '2000000000000000000.74, 2000000000000000000.06111111111111111111111111111]']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)
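The second expected row differs from the raw input because 43878 and 9147483648 overflow the declared SHORT and INT columns; a quick sketch of the two's-complement wrap-around behind those expected values:

def wrap(value, bits):
    # Interpret the low `bits` bits of value as a signed two's-complement integer.
    mask = (1 << bits) - 1
    v = value & mask
    return v - (1 << bits) if v >= (1 << (bits - 1)) else v

assert wrap(43878, 16) == -21658          # SHORT column
assert wrap(9147483648, 32) == 557549056  # INT column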
Example #6
    def __init__(self,
                 tag_id: str,
                 type_info: Optional[Union[TypeInformation, list]] = None):
        if not tag_id:
            raise ValueError("OutputTag tag_id cannot be None or empty string")
        self.tag_id = tag_id
        if type_info is None:
            self.type_info = Types.PICKLED_BYTE_ARRAY()
        elif isinstance(type_info, list):
            self.type_info = RowTypeInfo(type_info)
        elif not isinstance(type_info, TypeInformation):
            raise TypeError(
                "OutputTag type_info must be None, list or TypeInformation")
        else:
            self.type_info = type_info
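A minimal usage sketch, assuming OutputTag is importable from pyflink.datastream: the tag carries the side-output type, and omitting type_info falls back to PICKLED_BYTE_ARRAY as shown above.

from pyflink.common.typeinfo import Types
from pyflink.datastream import OutputTag

# Explicit type info keeps the side output strongly typed.
late_tag = OutputTag("late-data", Types.TUPLE([Types.STRING(), Types.LONG()]))

# Equivalent shorthand: a list of field types becomes a RowTypeInfo.
row_tag = OutputTag("errors", [Types.STRING(), Types.INT()])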
Example #7
    def _from_collection(self,
                         elements: List[Any],
                         type_info: TypeInformation = None) -> DataStream:
        temp_file = tempfile.NamedTemporaryFile(delete=False,
                                                dir=tempfile.mkdtemp())
        serializer = self.serializer
        try:
            with temp_file:
                # dump the elements to a temporary file using the pickle serializer.
                serializer.dump_to_stream(elements, temp_file)
            gateway = get_gateway()
            # if the user has not defined the element data types, read the pickled data as a
            # list of byte arrays.
            if type_info is None:
                j_objs = gateway.jvm.PythonBridgeUtils.readPickledBytes(
                    temp_file.name)
                out_put_type_info = Types.PICKLED_BYTE_ARRAY()  # type: TypeInformation
            else:
                j_objs = gateway.jvm.PythonBridgeUtils.readPythonObjects(
                    temp_file.name)
                out_put_type_info = type_info
            # Since the flink python module depends on the table module, we can reuse its
            # utilities when implementing the python DataStream API.
            PythonTableUtils = gateway.jvm\
                .org.apache.flink.table.planner.utils.python.PythonTableUtils
            execution_config = self._j_stream_execution_environment.getConfig()
            j_input_format = PythonTableUtils.getCollectionInputFormat(
                j_objs, out_put_type_info.get_java_type_info(),
                execution_config)

            j_data_stream_source = self._j_stream_execution_environment.createInput(
                j_input_format, out_put_type_info.get_java_type_info())
            j_data_stream_source.forceNonParallel()
            return DataStream(j_data_stream=j_data_stream_source)
        finally:
            os.unlink(temp_file.name)
Example #8
    def open(self, runtime_context: RuntimeContext,
             internal_timer_service: InternalTimerService):
        self.window_function.open(runtime_context)

        self.num_late_records_dropped = runtime_context.get_metrics_group() \
            .counter(self.LATE_ELEMENTS_DROPPED_METRIC_NAME)
        self.internal_timer_service = internal_timer_service
        self.trigger_context = Context(runtime_context, internal_timer_service,
                                       self.trigger)
        self.process_context = WindowContext(self.window_assigner,
                                             runtime_context,
                                             self.window_function,
                                             self.internal_timer_service)
        self.window_assigner_context = WindowAssignerContext(
            self.internal_timer_service, runtime_context)

        # create (or restore) the state that holds the actual window contents
        # NOTE - the state may be null in the case of the overriding evicting window operator
        if self.window_state_descriptor is not None:
            self.window_state = get_or_create_keyed_state(
                runtime_context, self.window_state_descriptor)

        if isinstance(self.window_assigner, MergingWindowAssigner):
            if isinstance(self.window_state, InternalMergingState):
                self.window_merging_state = self.window_state

            # TODO: the type info is currently just a placeholder; it should be
            # the real type serializer once user-defined state type serializers
            # are supported.
            merging_sets_state_descriptor = ListStateDescriptor(
                "merging-window-set", Types.PICKLED_BYTE_ARRAY())

            self.merging_sets_state = get_or_create_keyed_state(
                runtime_context, merging_sets_state_descriptor)

        self.merge_function = WindowMergeFunction(self)
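A hedged sketch of the state-descriptor pattern used above: PICKLED_BYTE_ARRAY lets keyed state hold arbitrary pickle-serialized Python objects, which is why it serves as the placeholder serializer for the merging window set.

from pyflink.common.typeinfo import Types
from pyflink.datastream.state import ListStateDescriptor, ValueStateDescriptor

# Arbitrary Python objects, serialized with pickle.
merging_sets = ListStateDescriptor("merging-window-set", Types.PICKLED_BYTE_ARRAY())
counter = ValueStateDescriptor("window-count", Types.PICKLED_BYTE_ARRAY())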
Example #9
    def test_from_java_type(self):
        basic_int_type_info = Types.INT()
        self.assertEqual(basic_int_type_info,
                         _from_java_type(basic_int_type_info.get_java_type_info()))

        basic_short_type_info = Types.SHORT()
        self.assertEqual(basic_short_type_info,
                         _from_java_type(basic_short_type_info.get_java_type_info()))

        basic_long_type_info = Types.LONG()
        self.assertEqual(basic_long_type_info,
                         _from_java_type(basic_long_type_info.get_java_type_info()))

        basic_float_type_info = Types.FLOAT()
        self.assertEqual(basic_float_type_info,
                         _from_java_type(basic_float_type_info.get_java_type_info()))

        basic_double_type_info = Types.DOUBLE()
        self.assertEqual(basic_double_type_info,
                         _from_java_type(basic_double_type_info.get_java_type_info()))

        basic_char_type_info = Types.CHAR()
        self.assertEqual(basic_char_type_info,
                         _from_java_type(basic_char_type_info.get_java_type_info()))

        basic_byte_type_info = Types.BYTE()
        self.assertEqual(basic_byte_type_info,
                         _from_java_type(basic_byte_type_info.get_java_type_info()))

        basic_big_int_type_info = Types.BIG_INT()
        self.assertEqual(basic_big_int_type_info,
                         _from_java_type(basic_big_int_type_info.get_java_type_info()))

        basic_big_dec_type_info = Types.BIG_DEC()
        self.assertEqual(basic_big_dec_type_info,
                         _from_java_type(basic_big_dec_type_info.get_java_type_info()))

        basic_sql_date_type_info = Types.SQL_DATE()
        self.assertEqual(basic_sql_date_type_info,
                         _from_java_type(basic_sql_date_type_info.get_java_type_info()))

        basic_sql_time_type_info = Types.SQL_TIME()
        self.assertEqual(basic_sql_time_type_info,
                         _from_java_type(basic_sql_time_type_info.get_java_type_info()))

        basic_sql_timestamp_type_info = Types.SQL_TIMESTAMP()
        self.assertEqual(basic_sql_timestamp_type_info,
                         _from_java_type(basic_sql_timestamp_type_info.get_java_type_info()))

        row_type_info = Types.ROW([Types.INT(), Types.STRING()])
        self.assertEqual(row_type_info, _from_java_type(row_type_info.get_java_type_info()))

        tuple_type_info = Types.TUPLE([Types.CHAR(), Types.INT()])
        self.assertEqual(tuple_type_info, _from_java_type(tuple_type_info.get_java_type_info()))

        primitive_int_array_type_info = Types.PRIMITIVE_ARRAY(Types.INT())
        self.assertEqual(primitive_int_array_type_info,
                         _from_java_type(primitive_int_array_type_info.get_java_type_info()))

        object_array_type_info = Types.OBJECT_ARRAY(Types.SQL_DATE())
        self.assertEqual(object_array_type_info,
                         _from_java_type(object_array_type_info.get_java_type_info()))

        pickled_byte_array_type_info = Types.PICKLED_BYTE_ARRAY()
        self.assertEqual(pickled_byte_array_type_info,
                         _from_java_type(pickled_byte_array_type_info.get_java_type_info()))

        map_type_info = Types.MAP(Types.INT(), Types.STRING())
        self.assertEqual(map_type_info,
                         _from_java_type(map_type_info.get_java_type_info()))

        list_type_info = Types.LIST(Types.INT())
        self.assertEqual(list_type_info,
                         _from_java_type(list_type_info.get_java_type_info()))
Example #10
    def open(self, ctx: Context[Any, CountWindow]):
        count_descriptor = ValueStateDescriptor('slide-count-assigner',
                                                Types.PICKLED_BYTE_ARRAY())
        self._count = ctx.get_partitioned_state(count_descriptor)
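A hypothetical sketch of the read-modify-write cycle on such a ValueState; SimpleValueState below is a stand-in for the real partitioned state, which returns None until first written.

class SimpleValueState:
    # Stand-in for pyflink's ValueState, for illustration only.
    def __init__(self):
        self._v = None

    def value(self):
        return self._v

    def update(self, v):
        self._v = v

count_state = SimpleValueState()
current = count_state.value() or 0   # value() is None before the first update
count_state.update(current + 1)
assert count_state.value() == 1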
Example #11
    def __init__(self, count_elements: int):
        self._count_elements = count_elements
        self._count_state_desc = ValueStateDescriptor(
            "trigger-count-%s" % count_elements, Types.PICKLED_BYTE_ARRAY())
        self._ctx = None  # type: TriggerContext
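A short usage sketch, assuming this __init__ belongs to a CountTrigger-style class and that keyed_stream is an existing KeyedStream: the trigger is attached to a windowed stream and fires once the per-key count state reaches count_elements.

from pyflink.common import Time
from pyflink.datastream.window import CountTrigger, TumblingEventTimeWindows

# Fire every 100 elements within each 10-second event-time window.
windowed = (keyed_stream
            .window(TumblingEventTimeWindows.of(Time.seconds(10)))
            .trigger(CountTrigger(100)))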