Example #1
    def test_get_set_db_storage_paths(self):
        if on_windows():
            checkpoints_path = "file:/C:/var/checkpoints/"
            storage_path = [
                "file:/C:/var/db_storage_dir1/",
                "file:/C:/var/db_storage_dir2/",
                "file:/C:/var/db_storage_dir3/"
            ]
            expected = [
                "C:\\var\\db_storage_dir1", "C:\\var\\db_storage_dir2",
                "C:\\var\\db_storage_dir3"
            ]
        else:
            checkpoints_path = "file://var/checkpoints/"
            storage_path = [
                "file://var/db_storage_dir1/", "file://var/db_storage_dir2/",
                "file://var/db_storage_dir3/"
            ]
            expected = [
                "/db_storage_dir1", "/db_storage_dir2", "/db_storage_dir3"
            ]

        state_backend = RocksDBStateBackend(checkpoints_path)
        state_backend.set_db_storage_paths(*storage_path)
        self.assertEqual(state_backend.get_db_storage_paths(), expected)
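
The accessors exercised in this test are the same ones used when wiring RocksDB into a job. A minimal sketch follows, assuming RocksDBStateBackend lives in pyflink.datastream.state_backend (the import path is an assumption) and using placeholder local paths:

from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.state_backend import RocksDBStateBackend  # assumed import path

env = StreamExecutionEnvironment.get_execution_environment()
# Checkpoints go to the URI passed to the constructor; local working state is
# spread across the configured DB storage paths.
backend = RocksDBStateBackend("file:///tmp/checkpoints/")
backend.set_db_storage_paths("/tmp/db_storage_dir1", "/tmp/db_storage_dir2")
env.set_state_backend(backend)
print(backend.get_db_storage_paths())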
Example #2
    def setUp(self):
        provision_info = json_format.Parse('{"retrievalToken": "test_token"}',
                                           ProvisionInfo())
        response = GetProvisionInfoResponse(info=provision_info)

        def get_unused_port():
            sock = socket.socket()
            sock.bind(('', 0))
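            # Binding to port 0 asks the OS to pick a free ephemeral port,
            # which getsockname() then reveals.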
            port = sock.getsockname()[1]
            sock.close()
            return port

        class ProvisionService(ProvisionServiceServicer):
            def GetProvisionInfo(self, request, context):
                return response

        def start_test_provision_server():
            server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
            add_ProvisionServiceServicer_to_server(ProvisionService(), server)
            port = get_unused_port()
            server.add_insecure_port('[::]:' + str(port))
            server.start()
            return server, port

        self.provision_server, self.provision_port = start_test_provision_server()

        self.env = dict(os.environ)
        self.env["python"] = sys.executable
        self.env["FLINK_BOOT_TESTING"] = "1"
        self.env["BOOT_LOG_DIR"] = os.path.join(self.env["FLINK_HOME"], "log")

        self.tmp_dir = tempfile.mkdtemp(str(time.time()), dir=self.tempdir)
        # Assume that this file is in the flink-python source code directory.
        flink_python_source_root = os.path.dirname(
            os.path.dirname(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
        runner_script = "pyflink-udf-runner.bat" if on_windows() else \
            "pyflink-udf-runner.sh"
        self.runner_path = os.path.join(flink_python_source_root, "bin",
                                        runner_script)
Example #3
class PythonBootTests(PyFlinkTestCase):

    def setUp(self):
        provision_info = json_format.Parse('{"retrievalToken": "test_token"}', ProvisionInfo())
        response = GetProvisionInfoResponse(info=provision_info)

        def get_unused_port():
            sock = socket.socket()
            sock.bind(('', 0))
            port = sock.getsockname()[1]
            sock.close()
            return port

        class ProvisionService(ProvisionServiceServicer):
            def GetProvisionInfo(self, request, context):
                return response

        def start_test_provision_server():
            server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
            add_ProvisionServiceServicer_to_server(ProvisionService(), server)
            port = get_unused_port()
            server.add_insecure_port('[::]:' + str(port))
            server.start()
            return server, port

        self.provision_server, self.provision_port = start_test_provision_server()

        self.env = dict(os.environ)
        self.env["python"] = sys.executable
        self.env["FLINK_BOOT_TESTING"] = "1"
        self.env["BOOT_LOG_DIR"] = os.path.join(self.env["FLINK_HOME"], "log")

        self.tmp_dir = tempfile.mkdtemp(str(time.time()), dir=self.tempdir)
        # Assume that this file is in the flink-python source code directory.
        flink_python_source_root = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
        runner_script = "pyflink-udf-runner.bat" if on_windows() else \
            "pyflink-udf-runner.sh"
        self.runner_path = os.path.join(
            flink_python_source_root, "bin", runner_script)

    def run_boot_py(self):
        args = [self.runner_path, "--id", "1",
                "--logging_endpoint", "localhost:0000",
                "--artifact_endpoint", "whatever",
                "--provision_endpoint", "localhost:%d" % self.provision_port,
                "--control_endpoint", "localhost:0000",
                "--semi_persist_dir", self.tmp_dir]

        return subprocess.call(args, env=self.env)

    def test_python_boot(self):
        exit_code = self.run_boot_py()
        self.assertEqual(exit_code, 0, "boot.py exited with a non-zero code.")

    @unittest.skipIf(on_windows(), "'subprocess.check_output' on Windows always returns an "
                                   "empty string, so this test is skipped.")
    def test_param_validation(self):
        args = [self.runner_path]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No id provided.", exit_message)

        args = [self.runner_path, "--id", "1"]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No logging endpoint provided.", exit_message)

        args = [self.runner_path, "--id", "1",
                "--logging_endpoint", "localhost:0000"]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No provision endpoint provided.", exit_message)

        args = [self.runner_path, "--id", "1",
                "--logging_endpoint", "localhost:0000",
                "--provision_endpoint", "localhost:%d" % self.provision_port]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No control endpoint provided.", exit_message)

    def test_set_working_directory(self):
        JProcessPythonEnvironmentManager = \
            get_gateway().jvm.org.apache.flink.python.env.beam.ProcessPythonEnvironmentManager

        output_file = os.path.join(self.tmp_dir, "output.txt")
        pyflink_dir = os.path.join(self.tmp_dir, "pyflink")
        os.mkdir(pyflink_dir)
        # just create an empty file
        open(os.path.join(pyflink_dir, "__init__.py"), 'a').close()
        fn_execution_dir = os.path.join(pyflink_dir, "fn_execution")
        os.mkdir(fn_execution_dir)
        open(os.path.join(fn_execution_dir, "__init__.py"), 'a').close()
        with open(os.path.join(fn_execution_dir, "boot.py"), "w") as f:
            f.write("import os\nwith open(r'%s', 'w') as f:\n    f.write(os.getcwd())" %
                    output_file)
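        # The boot.py written above simply records its working directory, so
        # the test can verify which directory the runner switched to.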

        # Test that the name of the working directory variable used by the UDF
        # runner is consistent with ProcessPythonEnvironmentManager.
        self.env[JProcessPythonEnvironmentManager.PYTHON_WORKING_DIR] = self.tmp_dir
        self.env["python"] = sys.executable
        args = [self.runner_path]
        subprocess.check_output(args, env=self.env)
        process_cwd = None
        if os.path.exists(output_file):
            with open(output_file, 'r') as f:
                process_cwd = f.read()

        self.assertEqual(os.path.realpath(self.tmp_dir),
                         process_cwd,
                         "setting working directory variable is not work!")

    def tearDown(self):
        self.provision_server.stop(0)
        try:
            if self.tmp_dir is not None:
                shutil.rmtree(self.tmp_dir)
        except Exception:
            pass
Example #4
class TypesTests(PyFlinkTestCase):
    def test_infer_schema(self):
        from decimal import Decimal

        class A(object):
            def __init__(self):
                self.a = 1

        from collections import namedtuple
        Point = namedtuple('Point', 'x y')

        data = [
            True,
            1,
            "a",
            u"a",
            datetime.date(1970, 1, 1),
            datetime.time(0, 0, 0),
            datetime.datetime(1970, 1, 1, 0, 0),
            1.0,
            array.array("d", [1]),
            [1],
            (1, ),
            Point(1.0, 5.0),
            {"a": 1},
            bytearray(1),
            Decimal(1),
            Row(a=1),
            Row("a")(1),
            A(),
        ]

        expected = [
            'BooleanType(true)',
            'BigIntType(true)',
            'VarCharType(2147483647, true)',
            'VarCharType(2147483647, true)',
            'DateType(true)',
            'TimeType(0, true)',
            'LocalZonedTimestampType(6, true)',
            'DoubleType(true)',
            "ArrayType(DoubleType(false), true)",
            "ArrayType(BigIntType(true), true)",
            'RowType(RowField(_1, BigIntType(true), ...))',
            'RowType(RowField(x, DoubleType(true), ...),RowField(y, DoubleType(true), ...))',
            'MapType(VarCharType(2147483647, false), BigIntType(true), true)',
            'VarBinaryType(2147483647, true)',
            'DecimalType(38, 18, true)',
            'RowType(RowField(a, BigIntType(true), ...))',
            'RowType(RowField(a, BigIntType(true), ...))',
            'RowType(RowField(a, BigIntType(true), ...))',
        ]

        schema = _infer_schema_from_data([data])
        self.assertEqual(expected, [repr(f.data_type) for f in schema.fields])
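
The inference helper can also be driven directly with plain Python rows. A minimal sketch, assuming the same internal helper used by this test module (_infer_schema_from_data is a private API, so the import path is an assumption):

from pyflink.table.types import _infer_schema_from_data  # assumed import path

schema = _infer_schema_from_data([[1, "a", 2.5]], names=["id", "name", "score"])
# schema is a RowType; print each field's inferred data type.
for name, field in zip(schema.field_names(), schema.fields):
    print(name, repr(field.data_type))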

    def test_infer_schema_nulltype(self):
        elements = [
            Row(c1=[], c2={}, c3=None),
            Row(c1=[Row(a=1, b='s')], c2={"key": Row(c=1.0, d="2")}, c3="")
        ]
        schema = _infer_schema_from_data(elements)
        self.assertTrue(isinstance(schema, RowType))
        self.assertEqual(3, len(schema.fields))

        # first column is array
        self.assertTrue(isinstance(schema.fields[0].data_type, ArrayType))

        # element type of first column is struct
        self.assertTrue(
            isinstance(schema.fields[0].data_type.element_type, RowType))

        self.assertTrue(
            isinstance(
                schema.fields[0].data_type.element_type.fields[0].data_type,
                BigIntType))
        self.assertTrue(
            isinstance(
                schema.fields[0].data_type.element_type.fields[1].data_type,
                VarCharType))

        # second column is map
        self.assertTrue(isinstance(schema.fields[1].data_type, MapType))
        self.assertTrue(
            isinstance(schema.fields[1].data_type.key_type, VarCharType))
        self.assertTrue(
            isinstance(schema.fields[1].data_type.value_type, RowType))

        # third column is varchar
        self.assertTrue(isinstance(schema.fields[2].data_type, VarCharType))

    def test_infer_schema_not_enough_names(self):
        schema = _infer_schema_from_data([["a", "b"]], ["col1"])
        self.assertEqual(schema.names, ['col1', '_2'])

    def test_infer_schema_fails(self):
        with self.assertRaises(TypeError):
            _infer_schema_from_data([[1, 1], ["x", 1]], names=["a", "b"])

    def test_infer_nested_schema(self):
        NestedRow = Row("f1", "f2")
        data1 = [
            NestedRow([1, 2], {"row1": 1.0}),
            NestedRow([2, 3], {"row2": 2.0})
        ]
        schema1 = _infer_schema_from_data(data1)
        expected1 = [
            'ArrayType(BigIntType(true), true)',
            'MapType(VarCharType(2147483647, false), DoubleType(true), true)'
        ]
        self.assertEqual(expected1,
                         [repr(f.data_type) for f in schema1.fields])

        data2 = [
            NestedRow([[1, 2], [2, 3]], [1, 2]),
            NestedRow([[2, 3], [3, 4]], [2, 3])
        ]
        schema2 = _infer_schema_from_data(data2)
        expected2 = [
            'ArrayType(ArrayType(BigIntType(true), true), true)',
            'ArrayType(BigIntType(true), true)'
        ]
        self.assertEqual(expected2,
                         [repr(f.data_type) for f in schema2.fields])

    def test_convert_row_to_dict(self):
        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
        self.assertEqual(1, row.as_dict()['l'][0].a)
        self.assertEqual(1.0, row.as_dict()['d']['key'].c)

    def test_udt(self):
        p = ExamplePoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), ExamplePointUDT())
        _create_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
        self.assertRaises(
            ValueError,
            lambda: _create_type_verifier(ExamplePointUDT())([1.0, 2.0]))

        p = PythonOnlyPoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), PythonOnlyUDT())
        _create_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
        self.assertRaises(
            ValueError,
            lambda: _create_type_verifier(PythonOnlyUDT())([1.0, 2.0]))

    def test_nested_udt_in_df(self):
        expected_schema = DataTypes.ROW() \
            .add("_1", DataTypes.BIGINT()).add("_2", DataTypes.ARRAY(PythonOnlyUDT()))
        data = (1, [PythonOnlyPoint(float(1), float(2))])
        self.assertEqual(expected_schema, _infer_type(data))

        expected_schema = DataTypes.ROW().add("_1", DataTypes.BIGINT()).add(
            "_2", DataTypes.MAP(DataTypes.BIGINT(False), PythonOnlyUDT()))
        p = (1, {1: PythonOnlyPoint(1, float(2))})
        self.assertEqual(expected_schema, _infer_type(p))

    def test_struct_type(self):
        row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
            .add("f2", DataTypes.STRING(nullable=True))
        row2 = DataTypes.ROW([
            DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)),
            DataTypes.FIELD("f2", DataTypes.STRING(nullable=True), None)
        ])
        self.assertEqual(row1.field_names(), row2.names)
        self.assertEqual(row1, row2)

        row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
            .add("f2", DataTypes.STRING(nullable=True))
        row2 = DataTypes.ROW(
            [DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))])
        self.assertNotEqual(row1.field_names(), row2.names)
        self.assertNotEqual(row1, row2)

        row1 = (DataTypes.ROW().add(
            DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))).add(
                "f2", DataTypes.STRING(nullable=True)))
        row2 = DataTypes.ROW([
            DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)),
            DataTypes.FIELD("f2", DataTypes.STRING(nullable=True))
        ])
        self.assertEqual(row1.field_names(), row2.names)
        self.assertEqual(row1, row2)

        row1 = (DataTypes.ROW().add(
            DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))).add(
                "f2", DataTypes.STRING(nullable=True)))
        row2 = DataTypes.ROW(
            [DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))])
        self.assertNotEqual(row1.field_names(), row2.names)
        self.assertNotEqual(row1, row2)

        # Catch exception raised during improper construction
        self.assertRaises(ValueError, lambda: DataTypes.ROW().add("name"))

        row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
            .add("f2", DataTypes.STRING(nullable=True))
        for field in row1:
            self.assertIsInstance(field, RowField)

        row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
            .add("f2", DataTypes.STRING(nullable=True))
        self.assertEqual(len(row1), 2)

        row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
            .add("f2", DataTypes.STRING(nullable=True))
        self.assertIs(row1["f1"], row1.fields[0])
        self.assertIs(row1[0], row1.fields[0])
        self.assertEqual(row1[0:1], DataTypes.ROW(row1.fields[0:1]))
        self.assertRaises(KeyError, lambda: row1["f9"])
        self.assertRaises(IndexError, lambda: row1[9])
        self.assertRaises(TypeError, lambda: row1[9.9])

    def test_infer_bigint_type(self):
        longrow = [Row(f1='a', f2=100000000000000)]
        schema = _infer_schema_from_data(longrow)
        self.assertEqual(DataTypes.BIGINT(), schema.fields[1].data_type)
        self.assertEqual(DataTypes.BIGINT(), _infer_type(1))
        self.assertEqual(DataTypes.BIGINT(), _infer_type(2**10))
        self.assertEqual(DataTypes.BIGINT(), _infer_type(2**20))
        self.assertEqual(DataTypes.BIGINT(), _infer_type(2**31 - 1))
        self.assertEqual(DataTypes.BIGINT(), _infer_type(2**31))
        self.assertEqual(DataTypes.BIGINT(), _infer_type(2**61))
        self.assertEqual(DataTypes.BIGINT(), _infer_type(2**71))

    def test_merge_type(self):
        self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.NULL()),
                         DataTypes.BIGINT())
        self.assertEqual(_merge_type(DataTypes.NULL(), DataTypes.BIGINT()),
                         DataTypes.BIGINT())

        self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.BIGINT()),
                         DataTypes.BIGINT())

        self.assertEqual(
            _merge_type(DataTypes.ARRAY(DataTypes.BIGINT()),
                        DataTypes.ARRAY(DataTypes.BIGINT())),
            DataTypes.ARRAY(DataTypes.BIGINT()))
        with self.assertRaises(TypeError):
            _merge_type(DataTypes.ARRAY(DataTypes.BIGINT()),
                        DataTypes.ARRAY(DataTypes.DOUBLE()))

        self.assertEqual(
            _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
                        DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
            DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()))
        with self.assertRaises(TypeError):
            _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
                        DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT()))
        with self.assertRaises(TypeError):
            _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
                        DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))

        self.assertEqual(
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.BIGINT()),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.BIGINT()),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ])),
            DataTypes.ROW([
                DataTypes.FIELD('f1', DataTypes.BIGINT()),
                DataTypes.FIELD('f2', DataTypes.STRING())
            ]))
        with self.assertRaises(TypeError):
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.BIGINT()),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.DOUBLE()),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]))

        self.assertEqual(
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ROW(
                            [DataTypes.FIELD('f2', DataTypes.BIGINT())]))
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ROW(
                            [DataTypes.FIELD('f2', DataTypes.BIGINT())]))
                ])),
            DataTypes.ROW([
                DataTypes.FIELD(
                    'f1',
                    DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))
            ]))
        with self.assertRaises(TypeError):
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ROW(
                            [DataTypes.FIELD('f2', DataTypes.BIGINT())]))
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ROW(
                            [DataTypes.FIELD('f2', DataTypes.STRING())]))
                ]))

        self.assertEqual(
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ])),
            DataTypes.ROW([
                DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                DataTypes.FIELD('f2', DataTypes.STRING())
            ]))
        with self.assertRaises(TypeError):
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.DOUBLE())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]))

        self.assertEqual(
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ])),
            DataTypes.ROW([
                DataTypes.FIELD(
                    'f1', DataTypes.MAP(DataTypes.STRING(),
                                        DataTypes.BIGINT())),
                DataTypes.FIELD('f2', DataTypes.STRING())
            ]))
        with self.assertRaises(TypeError):
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE())),
                    DataTypes.FIELD('f2', DataTypes.STRING())
                ]))

        self.assertEqual(
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ARRAY(
                            DataTypes.MAP(DataTypes.STRING(),
                                          DataTypes.BIGINT())))
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ARRAY(
                            DataTypes.MAP(DataTypes.STRING(),
                                          DataTypes.BIGINT())))
                ])),
            DataTypes.ROW([
                DataTypes.FIELD(
                    'f1',
                    DataTypes.ARRAY(
                        DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())))
            ]))
        with self.assertRaises(TypeError):
            _merge_type(
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ARRAY(
                            DataTypes.MAP(DataTypes.STRING(),
                                          DataTypes.BIGINT())))
                ]),
                DataTypes.ROW([
                    DataTypes.FIELD(
                        'f1',
                        DataTypes.ARRAY(
                            DataTypes.MAP(DataTypes.DOUBLE(),
                                          DataTypes.BIGINT())))
                ]))

    def test_array_types(self):
        # This test needs to make sure that the selected Scala type is at least
        # as large as Python's type. This is necessary because Python's array
        # types depend on the C implementation on the machine, so there is no
        # machine-independent correspondence between Python's array types and
        # Scala types.
        # See: https://docs.python.org/2/library/array.html

        def assert_collect_success(typecode, value, element_type):
            self.assertEqual(
                element_type,
                str(_infer_type(array.array(typecode, [value])).element_type))

        # supported string types
        #
        # String types in Python's array module are "u" for Py_UNICODE and "c"
        # for char. "u" will be removed in Python 4, and "c" is not supported
        # in Python 3.
        supported_string_types = []
        if sys.version_info[0] < 4:
            supported_string_types += ['u']
            # test unicode
            assert_collect_success('u', u'a', 'CHAR')

        # supported float and double
        #
        # Test max, min, and precision for float and double, assuming IEEE 754
        # floating-point format.
        supported_fractional_types = ['f', 'd']
        assert_collect_success('f', ctypes.c_float(1e+38).value, 'FLOAT')
        assert_collect_success('f', ctypes.c_float(1e-38).value, 'FLOAT')
        assert_collect_success('f', ctypes.c_float(1.123456).value, 'FLOAT')
        assert_collect_success('d', sys.float_info.max, 'DOUBLE')
        assert_collect_success('d', sys.float_info.min, 'DOUBLE')
        assert_collect_success('d', sys.float_info.epsilon, 'DOUBLE')

        def get_int_data_type(size):
            if size <= 8:
                return "TINYINT"
            if size <= 16:
                return "SMALLINT"
            if size <= 32:
                return "INT"
            if size <= 64:
                return "BIGINT"

        # supported signed int types
        #
        # The size of C types varies by implementation, so we need to make sure
        # that there is no overflow error on the platform running this test.
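        # For example, on a platform where typecode 'i' is a 4-byte C int,
        # ctypes.sizeof(ctype) * 8 == 32, max_val == 2**31, and both
        # 2**31 - 1 and -2**31 are expected to be inferred as INT.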
        supported_signed_int_types = list(
            set(_array_signed_int_typecode_ctype_mappings.keys()).intersection(
                set(_array_type_mappings.keys())))
        for t in supported_signed_int_types:
            ctype = _array_signed_int_typecode_ctype_mappings[t]
            max_val = 2**(ctypes.sizeof(ctype) * 8 - 1)
            assert_collect_success(t, max_val - 1,
                                   get_int_data_type(ctypes.sizeof(ctype) * 8))
            assert_collect_success(t, -max_val,
                                   get_int_data_type(ctypes.sizeof(ctype) * 8))

        # supported unsigned int types
        #
        # The JVM does not have unsigned types, so we need to be careful to
        # avoid overflow errors.
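        # For example, a 2-byte unsigned 'H' holds values up to 2**16 - 1,
        # which overflows the signed 16-bit SMALLINT, so the mapping must use
        # at least INT (hence get_int_data_type(size + 1) below).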
        supported_unsigned_int_types = list(
            set(_array_unsigned_int_typecode_ctype_mappings.keys()).
            intersection(set(_array_type_mappings.keys())))
        for t in supported_unsigned_int_types:
            ctype = _array_unsigned_int_typecode_ctype_mappings[t]
            max_val = 2**(ctypes.sizeof(ctype) * 8 - 1)
            assert_collect_success(
                t, max_val, get_int_data_type(ctypes.sizeof(ctype) * 8 + 1))

        # all supported types
        #
        # Make sure the types tested above:
        # 1. are all supported types
        # 2. cover all supported types
        supported_types = (supported_string_types +
                           supported_fractional_types +
                           supported_signed_int_types +
                           supported_unsigned_int_types)
        self.assertEqual(set(supported_types),
                         set(_array_type_mappings.keys()))

        # all unsupported types
        #
        # The keys of _array_type_mappings form the complete list of supported
        # types; any typecode not in _array_type_mappings is unsupported.
        all_types = set(array.typecodes)
        unsupported_types = all_types - set(supported_types)
        # test unsupported types
        for t in unsupported_types:
            with self.assertRaises(TypeError):
                _infer_schema_from_data([Row(myarray=array.array(t))])

    def test_data_type_eq(self):
        lt = DataTypes.BIGINT()
        lt2 = pickle.loads(pickle.dumps(DataTypes.BIGINT()))
        self.assertEqual(lt, lt2)

    def test_decimal_type(self):
        t1 = DataTypes.DECIMAL(10, 0)
        t2 = DataTypes.DECIMAL(10, 2)
        self.assertTrue(t2 is not t1)
        self.assertNotEqual(t1, t2)

    def test_datetype_equal_zero(self):
        dt = DataTypes.DATE()
        self.assertEqual(dt.from_sql_type(0), datetime.date(1970, 1, 1))

    @unittest.skipIf(on_windows(),
                     "Windows x64 system only support the datetime not larger "
                     "than time.ctime(32536799999), so this test can't run "
                     "under Windows platform")
    def test_timestamp_microsecond(self):
        tst = DataTypes.TIMESTAMP()
        self.assertEqual(
            tst.to_sql_type(datetime.datetime.max) % 1000000, 999999)

    @unittest.skipIf(on_windows(),
                     "Windows x64 system only support the datetime not larger "
                     "than time.ctime(32536799999), so this test can't run "
                     "under Windows platform")
    def test_local_zoned_timestamp_type(self):
        lztst = DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE()
        last_abbreviation = DataTypes.TIMESTAMP_LTZ()
        self.assertEqual(lztst, last_abbreviation)

        ts = datetime.datetime(1970, 1, 1, 0, 0, 0, 0000)
        self.assertEqual(0, lztst.to_sql_type(ts))

        import pytz
        # suppose the timezone of the data is +9:00
        timezone = pytz.timezone("Asia/Tokyo")
        orig_epoch = LocalZonedTimestampType.EPOCH_ORDINAL
        try:
            # suppose the local timezone is +8:00
            LocalZonedTimestampType.EPOCH_ORDINAL = 28800000000
            ts_tokyo = timezone.localize(ts)
            self.assertEqual(-3600000000, lztst.to_sql_type(ts_tokyo))
        finally:
            LocalZonedTimestampType.EPOCH_ORDINAL = orig_epoch

        if sys.version_info >= (3, 6):
            ts2 = lztst.from_sql_type(0)
            self.assertEqual(ts.astimezone(), ts2.astimezone())

    def test_zoned_timestamp_type(self):
        ztst = ZonedTimestampType()
        ts = datetime.datetime(1970, 1, 1, 0, 0, 0, 0,
                               tzinfo=UTCOffsetTimezone(1))
        self.assertEqual((0, 3600), ztst.to_sql_type(ts))

        ts2 = ztst.from_sql_type((0, 3600))
        self.assertEqual(ts, ts2)

    def test_day_time_interval_type(self):
        ymt = DataTypes.INTERVAL(DataTypes.DAY(), DataTypes.SECOND())
        td = datetime.timedelta(days=1, seconds=10)
        self.assertEqual(86410000000, ymt.to_sql_type(td))

        td2 = ymt.from_sql_type(86410000000)
        self.assertEqual(td, td2)

    def test_empty_row(self):
        row = Row()
        self.assertEqual(len(row), 0)

    def test_invalid_create_row(self):
        row_class = Row("c1", "c2")
        self.assertRaises(ValueError, lambda: row_class(1, 2, 3))

    def test_nullable(self):
        t = DataType(nullable=False)

        self.assertEqual(t._nullable, False)
        t_nullable = t.nullable()
        self.assertEqual(t_nullable._nullable, True)

    def test_not_null(self):
        t = DataType(nullable=True)

        self.assertEqual(t._nullable, True)
        t_notnull = t.not_null()
        self.assertEqual(t_notnull._nullable, False)
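
The nullability helpers exercised in the last two tests also come up when declaring schemas by hand. A minimal sketch, assuming DataTypes is imported from pyflink.table as in this test module:

from pyflink.table import DataTypes

# Factory methods take a nullable flag, and a type instance can be flipped
# afterwards with nullable() / not_null(), as the tests above verify.
name_type = DataTypes.STRING(nullable=False)
assert name_type.nullable()._nullable is True    # _nullable is an internal flag
assert name_type.not_null()._nullable is False

row_type = DataTypes.ROW([
    DataTypes.FIELD("name", DataTypes.STRING(nullable=False)),
    DataTypes.FIELD("score", DataTypes.DOUBLE())
])
print(row_type.field_names())  # ['name', 'score']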
Example #5
class StreamDependencyTests(DependencyTests, PyFlinkStreamTableTestCase):
    def setUp(self):
        super(StreamDependencyTests, self).setUp()
        origin_execution_mode = os.environ.get('_python_worker_execution_mode')
        os.environ['_python_worker_execution_mode'] = "loopback"
        try:
            self.st_env = TableEnvironment.create(
                EnvironmentSettings.in_streaming_mode())
        finally:
            if origin_execution_mode is not None:
                os.environ[
                    '_python_worker_execution_mode'] = origin_execution_mode

    def test_set_requirements_without_cached_directory(self):
        requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("cloudpickle==1.2.2")
        self.st_env.set_python_requirements(requirements_txt_path)

        def check_requirements(i):
            import cloudpickle  # noqa # pylint: disable=unused-import
            assert '_PYTHON_REQUIREMENTS_INSTALL_DIR' in os.environ
            return i

        self.st_env.create_temporary_system_function(
            "check_requirements",
            udf(check_requirements, DataTypes.BIGINT(), DataTypes.BIGINT()))
        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.st_env.register_table_sink("Results", table_sink)
        t = self.st_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select(expr.call('check_requirements', t.a),
                 t.a).execute_insert("Results").wait()

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[1, 1]", "+I[2, 2]", "+I[3, 3]"])

    def test_set_requirements_with_cached_directory(self):
        tmp_dir = self.tempdir
        requirements_txt_path = os.path.join(
            tmp_dir, "requirements_txt_" + str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("python-package1==0.0.0")

        requirements_dir_path = os.path.join(
            tmp_dir, "requirements_dir_" + str(uuid.uuid4()))
        os.mkdir(requirements_dir_path)
        package_file_name = "python-package1-0.0.0.tar.gz"
        with open(os.path.join(requirements_dir_path, package_file_name),
                  'wb') as f:
            import base64
            # This base64 data is encoded from a python package file which includes a
            # "python_package1" module. The module contains a "plus(a, b)" function.
            # The base64 can be recomputed with the following code:
            # base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8")
            f.write(
                base64.b64decode(
                    "H4sICNefrV0C/2Rpc3QvcHl0aG9uLXBhY2thZ2UxLTAuMC4wLnRhcgDtmVtv2jAYhnPtX2H1CrRCY+ckI"
                    "XEx7axuUA11u5imyICTRc1JiVnHfv1MKKWjYxwKEdPehws7xkmUfH5f+3PyqfqWpa1cjG5EKFnLbOvfhX"
                    "FQTI3nOPPSdavS5Pa8nGMwy3Esi3ke9wyTObbnGNQxamBSKlFQavzUryG8ldG6frpbEGx4yNmDLMp/hPy"
                    "P8b+6fNN613vdP1z8XdteG3+ug/17/F3Hcw1qIv5H54NUYiyUaH2SRRllaYeytkl6IpEdujI2yH2XapCQ"
                    "wSRJRDHt0OveZa//uUfeZonUvUO5bHo+0ZcoVo9bMhFRvGx9H41kWj447aUsR0WUq+pui8arWKggK5Jli"
                    "wGOo/95q79ovXi6/nfyf246Dof/n078fT9KI+X77Xx6BP83bX4Xf5NxT7dz7toO/L8OxjKgeTwpG+KcDp"
                    "sdQjWFVJMipYI+o0MCk4X/t2UYtqI0yPabCHb3f861XcD/Ty/+Y5nLdCzT0dSPo/SmbKsf6un+b7KV+Ls"
                    "W4/D/OoC9w/930P9eGwM75//csrD+Q/6P/P/k9D/oX3988Wqw1bS/tf6tR+s/m3EG/ddBqXO9XKf15C8p"
                    "P9k4HZBtBgzZaVW5vrfKcj+W32W82ygEB9D/Xu9+4/qfP9L/rBv0X1v87yONKRX61/qfzwqjIDzIPTbv/"
                    "7or3/88i0H/tfBFW7s/s/avRInQH06ieEy7tDrQeYHUdRN7wP+n/vf62LOH/pld7f9xz7a5Pfufedy0oP"
                    "86iJI8KxStAq6yLC4JWdbbVbWRikR2z1ZGytk5vauW3QdnBFE6XqwmykazCesAAAAAAAAAAAAAAAAAAAA"
                    "AAAAAAAAAAAAAAOBw/AJw5CHBAFAAAA=="))
        self.st_env.set_python_requirements(requirements_txt_path,
                                            requirements_dir_path)

        def add_one(i):
            from python_package1 import plus
            return plus(i, 1)

        self.st_env.create_temporary_system_function(
            "add_one", udf(add_one, DataTypes.BIGINT(), DataTypes.BIGINT()))
        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.st_env.register_table_sink("Results", table_sink)
        t = self.st_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select(expr.call('add_one', t.a),
                 t.a).execute_insert("Results").wait()

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[2, 1]", "+I[3, 2]", "+I[4, 3]"])

    def test_add_python_archive(self):
        tmp_dir = self.tempdir
        archive_dir_path = os.path.join(tmp_dir,
                                        "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("2")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        self.t_env.add_python_archive(archive_file_path, "data")

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i + int(f.read())

        self.t_env.create_temporary_system_function(
            "add_from_file",
            udf(add_from_file, DataTypes.BIGINT(), DataTypes.BIGINT()))
        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)
        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select(expr.call('add_from_file', t.a),
                 t.a).execute_insert("Results").wait()

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[3, 1]", "+I[4, 2]", "+I[5, 3]"])

    @unittest.skipIf(on_windows(),
                     "Symbolic link is not supported on Windows, skipping.")
    def test_set_environment(self):
        python_exec = sys.executable
        tmp_dir = self.tempdir
        python_exec_link_path = os.path.join(tmp_dir, "py_exec")
        os.symlink(python_exec, python_exec_link_path)
        self.st_env.get_config().set_python_executable(python_exec_link_path)

        def check_python_exec(i):
            import os
            assert os.environ["python"] == python_exec_link_path
            return i

        self.st_env.create_temporary_system_function(
            "check_python_exec",
            udf(check_python_exec, DataTypes.BIGINT(), DataTypes.BIGINT()))

        def check_pyflink_gateway_disabled(i):
            from pyflink.java_gateway import get_gateway
            get_gateway()
            return i

        self.st_env.create_temporary_system_function(
            "check_pyflink_gateway_disabled",
            udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(),
                DataTypes.BIGINT()))

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.st_env.register_table_sink("Results", table_sink)
        t = self.st_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select(
            expr.call('check_python_exec', t.a),
            expr.call('check_pyflink_gateway_disabled', t.a)) \
            .execute_insert("Results").wait()

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[1, 1]", "+I[2, 2]", "+I[3, 3]"])
Example #6
class StreamExecutionEnvironmentTests(PyFlinkTestCase):
    def setUp(self):
        self.env = StreamExecutionEnvironment.get_execution_environment()
        self.env.set_parallelism(2)
        self.test_sink = DataStreamTestSinkFunction()

    def test_get_config(self):
        execution_config = self.env.get_config()

        self.assertIsInstance(execution_config, ExecutionConfig)

    def test_get_set_parallelism(self):
        self.env.set_parallelism(10)

        parallelism = self.env.get_parallelism()

        self.assertEqual(parallelism, 10)

    def test_get_set_buffer_timeout(self):
        self.env.set_buffer_timeout(12000)

        timeout = self.env.get_buffer_timeout()

        self.assertEqual(timeout, 12000)

    def test_get_set_default_local_parallelism(self):
        self.env.set_default_local_parallelism(8)

        parallelism = self.env.get_default_local_parallelism()

        self.assertEqual(parallelism, 8)

    def test_set_get_restart_strategy(self):
        self.env.set_restart_strategy(RestartStrategies.no_restart())

        restart_strategy = self.env.get_restart_strategy()

        self.assertEqual(restart_strategy, RestartStrategies.no_restart())

    def test_add_default_kryo_serializer(self):
        self.env.add_default_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer"
        )

        class_dict = self.env.get_config().get_default_kryo_serializer_classes()

        self.assertEqual(
            class_dict, {
                'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                'org.apache.flink.runtime.state'
                '.StateBackendTestBase$CustomKryoTestSerializer'
            })

    def test_register_type_with_kryo_serializer(self):
        self.env.register_type_with_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer"
        )

        class_dict = self.env.get_config() \
            .get_registered_types_with_kryo_serializer_classes()

        self.assertEqual(
            class_dict, {
                'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                'org.apache.flink.runtime.state'
                '.StateBackendTestBase$CustomKryoTestSerializer'
            })

    def test_register_type(self):
        self.env.register_type(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo")

        type_list = self.env.get_config().get_registered_pojo_types()

        self.assertEqual(
            type_list,
            ['org.apache.flink.runtime.state.StateBackendTestBase$TestPojo'])

    def test_get_set_max_parallelism(self):
        self.env.set_max_parallelism(12)

        parallelism = self.env.get_max_parallelism()

        self.assertEqual(parallelism, 12)

    def test_set_runtime_mode(self):
        self.env.set_runtime_mode(RuntimeExecutionMode.BATCH)

        config = invoke_java_object_method(
            self.env._j_stream_execution_environment, "getConfiguration")
        runtime_mode = config.getValue(
            get_gateway().jvm.org.apache.flink.configuration.ExecutionOptions.RUNTIME_MODE)

        self.assertEqual(runtime_mode, "BATCH")

    def test_operation_chaining(self):
        self.assertTrue(self.env.is_chaining_enabled())

        self.env.disable_operator_chaining()

        self.assertFalse(self.env.is_chaining_enabled())

    def test_get_checkpoint_config(self):
        checkpoint_config = self.env.get_checkpoint_config()

        self.assertIsInstance(checkpoint_config, CheckpointConfig)

    def test_get_set_checkpoint_interval(self):
        self.env.enable_checkpointing(30000)

        interval = self.env.get_checkpoint_interval()

        self.assertEqual(interval, 30000)

    def test_get_set_checkpointing_mode(self):
        mode = self.env.get_checkpointing_mode()
        self.assertEqual(mode, CheckpointingMode.EXACTLY_ONCE)

        self.env.enable_checkpointing(30000, CheckpointingMode.AT_LEAST_ONCE)

        mode = self.env.get_checkpointing_mode()

        self.assertEqual(mode, CheckpointingMode.AT_LEAST_ONCE)

    def test_get_state_backend(self):
        state_backend = self.env.get_state_backend()

        self.assertIsNone(state_backend)

    def test_set_state_backend(self):
        input_backend = MemoryStateBackend()

        self.env.set_state_backend(input_backend)

        output_backend = self.env.get_state_backend()

        self.assertEqual(output_backend._j_memory_state_backend,
                         input_backend._j_memory_state_backend)

    def test_is_changelog_state_backend_enabled(self):
        self.assertIsNone(self.env.is_changelog_state_backend_enabled())

    def test_enable_changelog_state_backend(self):

        self.env.enable_changelog_state_backend(True)

        self.assertTrue(self.env.is_changelog_state_backend_enabled())

        self.env.enable_changelog_state_backend(False)

        self.assertFalse(self.env.is_changelog_state_backend_enabled())

    def test_get_set_stream_time_characteristic(self):
        default_time_characteristic = self.env.get_stream_time_characteristic()

        self.assertEqual(default_time_characteristic,
                         TimeCharacteristic.EventTime)

        self.env.set_stream_time_characteristic(
            TimeCharacteristic.ProcessingTime)

        time_characteristic = self.env.get_stream_time_characteristic()

        self.assertEqual(time_characteristic,
                         TimeCharacteristic.ProcessingTime)

    @unittest.skip(
        "The Python API does not support DataStream yet; refactor this test later.")
    def test_get_execution_plan(self):
        tmp_dir = tempfile.gettempdir()
        source_path = os.path.join(tmp_dir, 'streaming.csv')
        tmp_csv = os.path.join(tmp_dir, 'streaming2.csv')
        field_names = ["a", "b", "c"]
        field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]

        t_env = StreamTableEnvironment.create(self.env)
        csv_source = CsvTableSource(source_path, field_names, field_types)
        t_env.register_table_source("Orders", csv_source)
        t_env.register_table_sink(
            "Results", CsvTableSink(field_names, field_types, tmp_csv))
        t_env.from_path("Orders").execute_insert("Results").wait()

        plan = self.env.get_execution_plan()

        json.loads(plan)

    def test_execute(self):
        tmp_dir = tempfile.gettempdir()
        field_names = ['a', 'b', 'c']
        field_types = [
            DataTypes.BIGINT(),
            DataTypes.STRING(),
            DataTypes.STRING()
        ]
        t_env = StreamTableEnvironment.create(self.env)
        t_env.register_table_sink(
            'Results',
            CsvTableSink(
                field_names, field_types,
                os.path.join(tmp_dir, '{}.csv'.format(round(time.time())))))
        execution_result = exec_insert_table(
            t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c']),
            'Results')
        self.assertIsNotNone(execution_result.get_job_id())
        self.assertIsNotNone(execution_result.get_net_runtime())
        self.assertEqual(len(execution_result.get_all_accumulator_results()),
                         0)
        self.assertIsNone(
            execution_result.get_accumulator_result('accumulator'))
        self.assertIsNotNone(str(execution_result))

    def test_from_collection_without_data_types(self):
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection")
        results = self.test_sink.get_results(True)
        # The user did not specify data types for the input data, so the
        # collected results should be in tuple format, the same as the inputs.
        expected = ["(1, 'Hi', 'Hello')", "(2, 'Hello', 'Hi')"]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_from_collection_with_data_types(self):
        # verify from_collection for the collection with single object.
        ds = self.env.from_collection(['Hi', 'Hello'],
                                      type_info=Types.STRING())
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection with single object")
        results = self.test_sink.get_results(False)
        expected = ['Hello', 'Hi']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

        # verify from_collection for the collection with multiple objects like tuple.
        ds = self.env.from_collection(
            [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
              bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13),
              datetime.time(hour=12, minute=0, second=0, microsecond=123000),
              datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [1, 2, 3],
              decimal.Decimal('1000000000000000000.05'),
              decimal.Decimal('1000000000000000000.0599999999999'
                              '9999899999999999')),
             (2, None, 2, True, 43878, 9147483648, 9.87, 2.98936,
              bytearray(b'flink'), 'pyflink', datetime.date(2015, 10, 14),
              datetime.time(hour=11, minute=2, second=2, microsecond=234500),
              datetime.datetime(2020, 4, 15, 8, 2, 6, 235000), [2, 4, 6],
              decimal.Decimal('2000000000000000000.74'),
              decimal.Decimal('2000000000000000000.061111111111111'
                              '11111111111111'))],
            type_info=Types.ROW([
                Types.LONG(),
                Types.LONG(),
                Types.SHORT(),
                Types.BOOLEAN(),
                Types.SHORT(),
                Types.INT(),
                Types.FLOAT(),
                Types.DOUBLE(),
                Types.PICKLED_BYTE_ARRAY(),
                Types.STRING(),
                Types.SQL_DATE(),
                Types.SQL_TIME(),
                Types.SQL_TIMESTAMP(),
                Types.BASIC_ARRAY(Types.LONG()),
                Types.BIG_DEC(),
                Types.BIG_DEC()
            ]))
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection with tuple object")
        results = self.test_sink.get_results(False)
        # If the user specifies data types for the input data, the collected
        # results should be in row format.
        expected = [
            '+I[1, null, 1, true, 32767, -2147483648, 1.23, 1.98932, [102, 108, 105, 110, 107], '
            'pyflink, 2014-09-13, 12:00:00, 2018-03-11 03:00:00.123, [1, 2, 3], '
            '1000000000000000000.05, 1000000000000000000.05999999999999999899999999999]',
            '+I[2, null, 2, true, -21658, 557549056, 9.87, 2.98936, [102, 108, 105, 110, 107], '
            'pyflink, 2015-10-14, 11:02:02, 2020-04-15 08:02:06.235, [2, 4, 6], '
            '2000000000000000000.74, 2000000000000000000.06111111111111111111111111111]'
        ]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_add_custom_source(self):
        custom_source = SourceFunction(
            "org.apache.flink.python.util.MyCustomSourceFunction")
        ds = self.env.add_source(custom_source,
                                 type_info=Types.ROW(
                                     [Types.INT(), Types.STRING()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test add custom source")
        results = self.test_sink.get_results(False)
        expected = [
            '+I[3, Mike]', '+I[1, Marry]', '+I[4, Ted]', '+I[5, Jack]',
            '+I[0, Bob]', '+I[2, Henry]'
        ]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_read_text_file(self):
        texts = ["Mike", "Marry", "Ted", "Jack", "Bob", "Henry"]
        text_file_path = self.tempdir + '/text_file'
        with open(text_file_path, 'a') as f:
            for text in texts:
                f.write(text)
                f.write('\n')

        ds = self.env.read_text_file(text_file_path)
        ds.add_sink(self.test_sink)
        self.env.execute("test read text file")
        results = self.test_sink.get_results()
        results.sort()
        texts.sort()
        self.assertEqual(texts, results)

    def test_execute_async(self):
        ds = self.env.from_collection(
            [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
            type_info=Types.ROW([Types.INT(),
                                 Types.STRING(),
                                 Types.STRING()]))
        ds.add_sink(self.test_sink)
        job_client = self.env.execute_async("test execute async")
        job_id = job_client.get_job_id()
        self.assertIsNotNone(job_id)
        execution_result = job_client.get_job_execution_result().result()
        self.assertEqual(str(job_id), str(execution_result.get_job_id()))

    def test_add_python_file(self):
        import uuid
        python_file_dir = os.path.join(self.tempdir,
                                       "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(python_file_dir, "test_dep1.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")

        def plus_two_map(value):
            from test_dep1 import add_two
            return add_two(value)

        get_j_env_configuration(self.env._j_stream_execution_environment).\
            setString("taskmanager.numberOfTaskSlots", "10")
        self.env.add_python_file(python_file_path)
        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds = ds.map(plus_two_map, Types.LONG()) \
               .slot_sharing_group("data_stream") \
               .map(lambda i: i, Types.LONG()) \
               .slot_sharing_group("table")

        python_file_path = os.path.join(python_file_dir, "test_dep2.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_three(a):\n    return a + 3")

        def plus_three(value):
            from test_dep2 import add_three
            return add_three(value)

        t_env = StreamTableEnvironment.create(
            stream_execution_environment=self.env,
            environment_settings=EnvironmentSettings.in_streaming_mode())
        self.env.add_python_file(python_file_path)

        from pyflink.table.udf import udf
        from pyflink.table.expressions import col
        add_three = udf(plus_three, result_type=DataTypes.BIGINT())

        tab = t_env.from_data_stream(ds, 'a') \
                   .select(add_three(col('a')))
        t_env.to_append_stream(tab, Types.ROW([Types.LONG()])) \
             .map(lambda i: i[0]) \
             .add_sink(self.test_sink)
        self.env.execute("test add_python_file")
        result = self.test_sink.get_results(True)
        expected = ['6', '7', '8', '9', '10']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)
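
add_python_file is the DataStream-side counterpart of the TableEnvironment dependency calls sketched after Example #5. A minimal sketch, with the file path and module name as placeholders:

from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
# Ship an extra module to the Python workers; map/flat_map UDFs can then
# simply "import my_dep" (placeholder module name).
env.add_python_file("/path/to/my_dep.py")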

    def test_set_requirements_without_cached_directory(self):
        import uuid
        requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("cloudpickle==1.2.2")
        self.env.set_python_requirements(requirements_txt_path)

        def check_requirements(i):
            import cloudpickle
            assert os.path.abspath(cloudpickle.__file__).startswith(
                os.environ['_PYTHON_REQUIREMENTS_INSTALL_DIR'])
            return i

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_requirements).add_sink(self.test_sink)
        self.env.execute("test set requirements without cache dir")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_set_requirements_with_cached_directory(self):
        import uuid
        tmp_dir = self.tempdir
        requirements_txt_path = os.path.join(
            tmp_dir, "requirements_txt_" + str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("python-package1==0.0.0")

        requirements_dir_path = os.path.join(
            tmp_dir, "requirements_dir_" + str(uuid.uuid4()))
        os.mkdir(requirements_dir_path)
        package_file_name = "python-package1-0.0.0.tar.gz"
        with open(os.path.join(requirements_dir_path, package_file_name),
                  'wb') as f:
            import base64
            # This base64 data is encoded from a python package file which includes a
            # "python_package1" module. The module contains a "plus(a, b)" function.
            # The base64 string can be recomputed with the following code:
            # base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8")
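            # (For reference only: such a tarball could be produced from a hypothetical
            # project that defines a "python_package1" module, e.g. via
            # `python setup.py sdist`, and re-encoded as shown above.)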
            f.write(
                base64.b64decode(
                    "H4sICNefrV0C/2Rpc3QvcHl0aG9uLXBhY2thZ2UxLTAuMC4wLnRhcgDtmVtv2jAYhnPtX2H1CrRCY+ckI"
                    "XEx7axuUA11u5imyICTRc1JiVnHfv1MKKWjYxwKEdPehws7xkmUfH5f+3PyqfqWpa1cjG5EKFnLbOvfhX"
                    "FQTI3nOPPSdavS5Pa8nGMwy3Esi3ke9wyTObbnGNQxamBSKlFQavzUryG8ldG6frpbEGx4yNmDLMp/hPy"
                    "P8b+6fNN613vdP1z8XdteG3+ug/17/F3Hcw1qIv5H54NUYiyUaH2SRRllaYeytkl6IpEdujI2yH2XapCQ"
                    "wSRJRDHt0OveZa//uUfeZonUvUO5bHo+0ZcoVo9bMhFRvGx9H41kWj447aUsR0WUq+pui8arWKggK5Jli"
                    "wGOo/95q79ovXi6/nfyf246Dof/n078fT9KI+X77Xx6BP83bX4Xf5NxT7dz7toO/L8OxjKgeTwpG+KcDp"
                    "sdQjWFVJMipYI+o0MCk4X/t2UYtqI0yPabCHb3f861XcD/Ty/+Y5nLdCzT0dSPo/SmbKsf6un+b7KV+Ls"
                    "W4/D/OoC9w/930P9eGwM75//csrD+Q/6P/P/k9D/oX3988Wqw1bS/tf6tR+s/m3EG/ddBqXO9XKf15C8p"
                    "P9k4HZBtBgzZaVW5vrfKcj+W32W82ygEB9D/Xu9+4/qfP9L/rBv0X1v87yONKRX61/qfzwqjIDzIPTbv/"
                    "7or3/88i0H/tfBFW7s/s/avRInQH06ieEy7tDrQeYHUdRN7wP+n/vf62LOH/pld7f9xz7a5Pfufedy0oP"
                    "86iJI8KxStAq6yLC4JWdbbVbWRikR2z1ZGytk5vauW3QdnBFE6XqwmykazCesAAAAAAAAAAAAAAAAAAAA"
                    "AAAAAAAAAAAAAAOBw/AJw5CHBAFAAAA=="))
        self.env.set_python_requirements(requirements_txt_path,
                                         requirements_dir_path)

        def add_one(i):
            from python_package1 import plus
            return plus(i, 1)

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_one).add_sink(self.test_sink)
        self.env.execute("test set requirements with cachd dir")
        result = self.test_sink.get_results(True)
        expected = ['2', '3', '4', '5', '6']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_add_python_archive(self):
        import uuid
        import shutil
        tmp_dir = self.tempdir
        archive_dir_path = os.path.join(tmp_dir,
                                        "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("2")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
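        # The archive is extracted into a directory named "data" under the working directory
        # of the Python workers, so the UDF below can read "data/data.txt" via a relative path.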
        self.env.add_python_archive(archive_file_path, "data")

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i + int(f.read())

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_from_file).add_sink(self.test_sink)
        self.env.execute("test set python archive")
        result = self.test_sink.get_results(True)
        expected = ['3', '4', '5', '6', '7']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    @unittest.skipIf(on_windows(),
                     "Symbolic link is not supported on Windows, skipping.")
    def test_set_stream_env(self):
        import sys
        python_exec = sys.executable
        tmp_dir = self.tempdir
        python_exec_link_path = os.path.join(tmp_dir, "py_exec")
        os.symlink(python_exec, python_exec_link_path)
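        # set_python_executable() selects the interpreter used to launch the Python workers;
        # the UDF below verifies it via the "python" environment variable of the worker process.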
        self.env.set_python_executable(python_exec_link_path)

        def check_python_exec(i):
            import os
            assert os.environ["python"] == python_exec_link_path
            return i

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_python_exec).add_sink(self.test_sink)
        self.env.execute("test set python executable")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_add_jars(self):
        # find kafka connector jars
        flink_source_root = _find_flink_source_root()
        jars_abs_path = flink_source_root + '/flink-connectors/flink-sql-connector-kafka'
        specific_jars = glob.glob(jars_abs_path + '/target/flink*.jar')
        specific_jars = [
            'file://' + specific_jar for specific_jar in specific_jars
        ]

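        # add_jars() registers the connector jars via the "pipeline.jars" configuration, so
        # they are shipped with the job and the Kafka connector classes can be loaded.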
        self.env.add_jars(*specific_jars)
        source_topic = 'test_source_topic'
        props = {
            'bootstrap.servers': 'localhost:9092',
            'group.id': 'test_group'
        }
        type_info = Types.ROW([Types.INT(), Types.STRING()])

        # Test for kafka consumer
        deserialization_schema = JsonRowDeserializationSchema.builder() \
            .type_info(type_info=type_info).build()

        # A ClassNotFoundException will be raised if the kafka connector is not added to the
        # pipeline jars.
        kafka_consumer = FlinkKafkaConsumer(source_topic,
                                            deserialization_schema, props)
        self.env.add_source(kafka_consumer).print()
        self.env.get_execution_plan()

    def test_add_classpaths(self):
        # find kafka connector jars
        flink_source_root = _find_flink_source_root()
        jars_abs_path = flink_source_root + '/flink-connectors/flink-sql-connector-kafka'
        specific_jars = glob.glob(jars_abs_path + '/target/flink*.jar')
        specific_jars = [
            'file://' + specific_jar for specific_jar in specific_jars
        ]

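        # add_classpaths() adds the URLs to "pipeline.classpaths" instead; the jars are put on
        # the classpath but not shipped, so they typically must be reachable from all nodes.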
        self.env.add_classpaths(*specific_jars)
        source_topic = 'test_source_topic'
        props = {
            'bootstrap.servers': 'localhost:9092',
            'group.id': 'test_group'
        }
        type_info = Types.ROW([Types.INT(), Types.STRING()])

        # Test for kafka consumer
        deserialization_schema = JsonRowDeserializationSchema.builder() \
            .type_info(type_info=type_info).build()

        # A ClassNotFoundException will be raised if the kafka connector is not added to the
        # pipeline classpaths.
        kafka_consumer = FlinkKafkaConsumer(source_topic,
                                            deserialization_schema, props)
        self.env.add_source(kafka_consumer).print()
        self.env.get_execution_plan()

    def test_generate_stream_graph_with_dependencies(self):
        python_file_dir = os.path.join(self.tempdir,
                                       "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(
            python_file_dir, "test_stream_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")
        self.env.add_python_file(python_file_path)

        def plus_two_map(value):
            from test_stream_dependency_manage_lib import add_two
            return value[0], add_two(value[1])

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i[0], i[1] + int(f.read())

        from_collection_source = self.env.from_collection(
            [('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
            type_info=Types.ROW([Types.STRING(), Types.INT()]))
        from_collection_source.name("From Collection")
        keyed_stream = from_collection_source.key_by(lambda x: x[1],
                                                     key_type=Types.INT())

        plus_two_map_stream = keyed_stream.map(plus_two_map).name(
            "Plus Two Map").set_parallelism(3)

        add_from_file_map = plus_two_map_stream.map(add_from_file).name(
            "Add From File Map")

        test_stream_sink = add_from_file_map.add_sink(
            self.test_sink).name("Test Sink")
        test_stream_sink.set_parallelism(4)

        archive_dir_path = os.path.join(self.tempdir,
                                        "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("3")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        self.env.add_python_archive(archive_file_path, "data")

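        # get_execution_plan() returns the StreamGraph as a JSON string; parse it to inspect
        # the generated nodes, their parallelism and ship strategies.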
        nodes = eval(self.env.get_execution_plan())['nodes']

        # The StreamGraph should be as below:
        # Source: From Collection -> _stream_key_by_map_operator ->
        # Plus Two Map -> Add From File Map -> Sink: Test Sink.

        # Source: From Collection and _stream_key_by_map_operator should have the same
        # parallelism.
        self.assertEqual(nodes[0]['parallelism'], nodes[1]['parallelism'])

        # The parallelism of Plus Two Map should be 3
        self.assertEqual(nodes[2]['parallelism'], 3)

        # The ship_strategy for Source: From Collection and _stream_key_by_map_operator should be
        # FORWARD
        self.assertEqual(nodes[1]['predecessors'][0]['ship_strategy'],
                         "FORWARD")

        # The ship_strategy for _keyed_stream_values_operator and Plus Two Map should be
        # HASH
        self.assertEqual(nodes[2]['predecessors'][0]['ship_strategy'], "HASH")

        # The parallelism of Sink: Test Sink should be 4
        self.assertEqual(nodes[4]['parallelism'], 4)

        env_config_with_dependencies = dict(
            get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil.
            getEnvConfigWithDependencies(
                self.env._j_stream_execution_environment).toMap())

        # Make sure that user specified files and archives are correctly added.
        self.assertIsNotNone(env_config_with_dependencies['python.files'])
        self.assertIsNotNone(env_config_with_dependencies['python.archives'])

    def tearDown(self) -> None:
        self.test_sink.clear()
Example #7
class BlinkStreamDependencyTests(DependencyTests,
                                 PyFlinkBlinkStreamTableTestCase):
    def test_set_requirements_without_cached_directory(self):
        requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("cloudpickle==1.2.2")
        self.t_env.set_python_requirements(requirements_txt_path)

        def check_requirements(i):
            import cloudpickle
            assert os.path.abspath(cloudpickle.__file__).startswith(
                os.environ['_PYTHON_REQUIREMENTS_INSTALL_DIR'])
            return i

        self.t_env.register_function(
            "check_requirements",
            udf(check_requirements, DataTypes.BIGINT(), DataTypes.BIGINT()))
        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)
        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select("check_requirements(a), a").insert_into("Results")
        self.t_env.execute("test")

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["1,1", "2,2", "3,3"])

    def test_set_requirements_with_cached_directory(self):
        tmp_dir = self.tempdir
        requirements_txt_path = os.path.join(
            tmp_dir, "requirements_txt_" + str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("python-package1==0.0.0")

        requirements_dir_path = os.path.join(
            tmp_dir, "requirements_dir_" + str(uuid.uuid4()))
        os.mkdir(requirements_dir_path)
        package_file_name = "python-package1-0.0.0.tar.gz"
        with open(os.path.join(requirements_dir_path, package_file_name),
                  'wb') as f:
            from pyflink.fn_execution.tests.process_mode_test_data import file_data
            import base64
            f.write(
                base64.b64decode(
                    json.loads(file_data[package_file_name])["data"]))
        self.t_env.set_python_requirements(requirements_txt_path,
                                           requirements_dir_path)

        def add_one(i):
            from python_package1 import plus
            return plus(i, 1)

        self.t_env.register_function(
            "add_one", udf(add_one, DataTypes.BIGINT(), DataTypes.BIGINT()))
        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)
        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select("add_one(a), a").insert_into("Results")
        self.t_env.execute("test")

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["2,1", "3,2", "4,3"])

    def test_add_python_archive(self):
        tmp_dir = self.tempdir
        archive_dir_path = os.path.join(tmp_dir,
                                        "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("2")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        self.t_env.add_python_archive(archive_file_path, "data")

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i + int(f.read())

        self.t_env.register_function(
            "add_from_file",
            udf(add_from_file, DataTypes.BIGINT(), DataTypes.BIGINT()))
        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)
        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select("add_from_file(a), a").insert_into("Results")
        self.t_env.execute("test")

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["3,1", "4,2", "5,3"])

    @unittest.skipIf(on_windows(),
                     "Symbolic link is not supported on Windows, skipping.")
    def test_set_environment(self):
        python_exec = sys.executable
        tmp_dir = self.tempdir
        python_exec_link_path = os.path.join(tmp_dir, "py_exec")
        os.symlink(python_exec, python_exec_link_path)
        self.t_env.get_config().set_python_executable(python_exec_link_path)

        def check_python_exec(i):
            import os
            assert os.environ["python"] == python_exec_link_path
            return i

        self.t_env.register_function(
            "check_python_exec",
            udf(check_python_exec, DataTypes.BIGINT(), DataTypes.BIGINT()))

        def check_pyflink_gateway_disabled(i):
            try:
                from pyflink.java_gateway import get_gateway
                get_gateway()
            except Exception as e:
                assert str(e).startswith(
                    "It's launching the PythonGatewayServer during Python UDF"
                    " execution which is unexpected.")
            else:
                raise Exception("The gateway server is not disabled!")
            return i

        self.t_env.register_function(
            "check_pyflink_gateway_disabled",
            udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(),
                DataTypes.BIGINT()))

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'],
            [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)
        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select("check_python_exec(a), check_pyflink_gateway_disabled(a)"
                 ).insert_into("Results")
        self.t_env.execute("test")

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["1,1", "2,2", "3,3"])

class StreamExecutionEnvironmentTests(PyFlinkTestCase):
    def setUp(self):
        self.env = StreamExecutionEnvironment.get_execution_environment()
        self.test_sink = DataStreamTestSinkFunction()

    def test_get_config(self):
        execution_config = self.env.get_config()

        self.assertIsInstance(execution_config, ExecutionConfig)

    def test_get_set_parallelism(self):

        self.env.set_parallelism(10)

        parallelism = self.env.get_parallelism()

        self.assertEqual(parallelism, 10)

    def test_get_set_buffer_timeout(self):

        self.env.set_buffer_timeout(12000)

        timeout = self.env.get_buffer_timeout()

        self.assertEqual(timeout, 12000)

    def test_get_set_default_local_parallelism(self):

        self.env.set_default_local_parallelism(8)

        parallelism = self.env.get_default_local_parallelism()

        self.assertEqual(parallelism, 8)

    def test_set_get_restart_strategy(self):

        self.env.set_restart_strategy(RestartStrategies.no_restart())

        restart_strategy = self.env.get_restart_strategy()

        self.assertEqual(restart_strategy, RestartStrategies.no_restart())

    def test_add_default_kryo_serializer(self):

        self.env.add_default_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer"
        )

        class_dict = self.env.get_config().get_default_kryo_serializer_classes(
        )

        self.assertEqual(
            class_dict, {
                'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                'org.apache.flink.runtime.state'
                '.StateBackendTestBase$CustomKryoTestSerializer'
            })

    def test_register_type_with_kryo_serializer(self):

        self.env.register_type_with_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer"
        )

        class_dict = self.env.get_config(
        ).get_registered_types_with_kryo_serializer_classes()

        self.assertEqual(
            class_dict, {
                'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                'org.apache.flink.runtime.state'
                '.StateBackendTestBase$CustomKryoTestSerializer'
            })

    def test_register_type(self):

        self.env.register_type(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo")

        type_list = self.env.get_config().get_registered_pojo_types()

        self.assertEqual(
            type_list,
            ['org.apache.flink.runtime.state.StateBackendTestBase$TestPojo'])

    def test_get_set_max_parallelism(self):

        self.env.set_max_parallelism(12)

        parallelism = self.env.get_max_parallelism()

        self.assertEqual(parallelism, 12)

    def test_operation_chaining(self):

        self.assertTrue(self.env.is_chaining_enabled())

        self.env.disable_operator_chaining()

        self.assertFalse(self.env.is_chaining_enabled())

    def test_get_checkpoint_config(self):

        checkpoint_config = self.env.get_checkpoint_config()

        self.assertIsInstance(checkpoint_config, CheckpointConfig)

    def test_get_set_checkpoint_interval(self):

        self.env.enable_checkpointing(30000)

        interval = self.env.get_checkpoint_interval()

        self.assertEqual(interval, 30000)

    def test_get_set_checkpointing_mode(self):
        mode = self.env.get_checkpointing_mode()
        self.assertEqual(mode, CheckpointingMode.EXACTLY_ONCE)

        self.env.enable_checkpointing(30000, CheckpointingMode.AT_LEAST_ONCE)

        mode = self.env.get_checkpointing_mode()

        self.assertEqual(mode, CheckpointingMode.AT_LEAST_ONCE)

    def test_get_state_backend(self):

        state_backend = self.env.get_state_backend()

        self.assertIsNone(state_backend)

    def test_set_state_backend(self):

        input_backend = MemoryStateBackend()

        self.env.set_state_backend(input_backend)

        output_backend = self.env.get_state_backend()

        self.assertEqual(output_backend._j_memory_state_backend,
                         input_backend._j_memory_state_backend)

    def test_get_set_stream_time_characteristic(self):

        default_time_characteristic = self.env.get_stream_time_characteristic()

        self.assertEqual(default_time_characteristic,
                         TimeCharacteristic.ProcessingTime)

        self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)

        time_characteristic = self.env.get_stream_time_characteristic()

        self.assertEqual(time_characteristic, TimeCharacteristic.EventTime)

    @unittest.skip(
        "Python API does not support DataStream now. refactor this test later")
    def test_get_execution_plan(self):
        tmp_dir = tempfile.gettempdir()
        source_path = os.path.join(tmp_dir, 'streaming.csv')
        tmp_csv = os.path.join(tmp_dir, 'streaming2.csv')
        field_names = ["a", "b", "c"]
        field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]

        t_env = StreamTableEnvironment.create(self.env)
        csv_source = CsvTableSource(source_path, field_names, field_types)
        t_env.register_table_source("Orders", csv_source)
        t_env.register_table_sink(
            "Results", CsvTableSink(field_names, field_types, tmp_csv))
        exec_insert_table(t_env.from_path("Orders"), "Results")

        plan = self.env.get_execution_plan()

        json.loads(plan)

    def test_execute(self):
        tmp_dir = tempfile.gettempdir()
        field_names = ['a', 'b', 'c']
        field_types = [
            DataTypes.BIGINT(),
            DataTypes.STRING(),
            DataTypes.STRING()
        ]
        t_env = StreamTableEnvironment.create(self.env)
        t_env.register_table_sink(
            'Results',
            CsvTableSink(
                field_names, field_types,
                os.path.join(tmp_dir, '{}.csv'.format(round(time.time())))))
        execution_result = exec_insert_table(
            t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c']),
            'Results')
        self.assertIsNotNone(execution_result.get_job_id())
        self.assertIsNotNone(execution_result.get_net_runtime())
        self.assertEqual(len(execution_result.get_all_accumulator_results()),
                         0)
        self.assertIsNone(
            execution_result.get_accumulator_result('accumulator'))
        self.assertIsNotNone(str(execution_result))

    def test_from_collection_without_data_types(self):
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection")
        results = self.test_sink.get_results(True)
        # If the user does not specify data types for the input data, the collected results
        # should be in the same tuple format as the inputs.
        expected = ["(1, 'Hi', 'Hello')", "(2, 'Hello', 'Hi')"]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_from_collection_with_data_types(self):
        ds = self.env.from_collection(
            [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
            type_info=Types.ROW([Types.INT(),
                                 Types.STRING(),
                                 Types.STRING()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection")
        results = self.test_sink.get_results(False)
        # If the user specifies the data types of the input data, the collected results
        # should be in row format.
        expected = ['1,Hi,Hello', '2,Hello,Hi']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_add_custom_source(self):
        custom_source = SourceFunction(
            "org.apache.flink.python.util.MyCustomSourceFunction")
        ds = self.env.add_source(custom_source,
                                 type_info=Types.ROW(
                                     [Types.INT(), Types.STRING()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test add custom source")
        results = self.test_sink.get_results(False)
        expected = ['3,Mike', '1,Marry', '4,Ted', '5,Jack', '0,Bob', '2,Henry']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_read_text_file(self):
        texts = ["Mike", "Marry", "Ted", "Jack", "Bob", "Henry"]
        text_file_path = self.tempdir + '/text_file'
        with open(text_file_path, 'a') as f:
            for text in texts:
                f.write(text)
                f.write('\n')

        ds = self.env.read_text_file(text_file_path)
        ds.add_sink(self.test_sink)
        self.env.execute("test read text file")
        results = self.test_sink.get_results()
        results.sort()
        texts.sort()
        self.assertEqual(texts, results)

    def test_execute_async(self):
        ds = self.env.from_collection(
            [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
            type_info=Types.ROW([Types.INT(),
                                 Types.STRING(),
                                 Types.STRING()]))
        ds.add_sink(self.test_sink)
        job_client = self.env.execute_async("test execute async")
        job_id = job_client.get_job_id()
        self.assertIsNotNone(job_id)
        execution_result = job_client.get_job_execution_result().result()
        self.assertEqual(str(job_id), str(execution_result.get_job_id()))

    def test_add_python_file(self):
        import uuid
        python_file_dir = os.path.join(self.tempdir,
                                       "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(
            python_file_dir, "test_stream_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")

        def plus_two_map(value):
            from test_stream_dependency_manage_lib import add_two
            return add_two(value)

        self.env.add_python_file(python_file_path)
        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(plus_two_map).add_sink(self.test_sink)
        self.env.execute("test add python file")
        result = self.test_sink.get_results(True)
        expected = ['3', '4', '5', '6', '7']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_set_requirements_without_cached_directory(self):
        import uuid
        requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("cloudpickle==1.2.2")
        self.env.set_python_requirements(requirements_txt_path)

        def check_requirements(i):
            import cloudpickle
            assert os.path.abspath(cloudpickle.__file__).startswith(
                os.environ['_PYTHON_REQUIREMENTS_INSTALL_DIR'])
            return i

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_requirements).add_sink(self.test_sink)
        self.env.execute("test set requirements without cache dir")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_set_requirements_with_cached_directory(self):
        import uuid
        tmp_dir = self.tempdir
        requirements_txt_path = os.path.join(
            tmp_dir, "requirements_txt_" + str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("python-package1==0.0.0")

        requirements_dir_path = os.path.join(
            tmp_dir, "requirements_dir_" + str(uuid.uuid4()))
        os.mkdir(requirements_dir_path)
        package_file_name = "python-package1-0.0.0.tar.gz"
        with open(os.path.join(requirements_dir_path, package_file_name),
                  'wb') as f:
            import base64
            # This base64 data is encoded from a python package file which includes a
            # "python_package1" module. The module contains a "plus(a, b)" function.
            # The base64 string can be recomputed with the following code:
            # base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8")
            f.write(
                base64.b64decode(
                    "H4sICNefrV0C/2Rpc3QvcHl0aG9uLXBhY2thZ2UxLTAuMC4wLnRhcgDtmVtv2jAYhnPtX2H1CrRCY+ckI"
                    "XEx7axuUA11u5imyICTRc1JiVnHfv1MKKWjYxwKEdPehws7xkmUfH5f+3PyqfqWpa1cjG5EKFnLbOvfhX"
                    "FQTI3nOPPSdavS5Pa8nGMwy3Esi3ke9wyTObbnGNQxamBSKlFQavzUryG8ldG6frpbEGx4yNmDLMp/hPy"
                    "P8b+6fNN613vdP1z8XdteG3+ug/17/F3Hcw1qIv5H54NUYiyUaH2SRRllaYeytkl6IpEdujI2yH2XapCQ"
                    "wSRJRDHt0OveZa//uUfeZonUvUO5bHo+0ZcoVo9bMhFRvGx9H41kWj447aUsR0WUq+pui8arWKggK5Jli"
                    "wGOo/95q79ovXi6/nfyf246Dof/n078fT9KI+X77Xx6BP83bX4Xf5NxT7dz7toO/L8OxjKgeTwpG+KcDp"
                    "sdQjWFVJMipYI+o0MCk4X/t2UYtqI0yPabCHb3f861XcD/Ty/+Y5nLdCzT0dSPo/SmbKsf6un+b7KV+Ls"
                    "W4/D/OoC9w/930P9eGwM75//csrD+Q/6P/P/k9D/oX3988Wqw1bS/tf6tR+s/m3EG/ddBqXO9XKf15C8p"
                    "P9k4HZBtBgzZaVW5vrfKcj+W32W82ygEB9D/Xu9+4/qfP9L/rBv0X1v87yONKRX61/qfzwqjIDzIPTbv/"
                    "7or3/88i0H/tfBFW7s/s/avRInQH06ieEy7tDrQeYHUdRN7wP+n/vf62LOH/pld7f9xz7a5Pfufedy0oP"
                    "86iJI8KxStAq6yLC4JWdbbVbWRikR2z1ZGytk5vauW3QdnBFE6XqwmykazCesAAAAAAAAAAAAAAAAAAAA"
                    "AAAAAAAAAAAAAAOBw/AJw5CHBAFAAAA=="))
        self.env.set_python_requirements(requirements_txt_path,
                                         requirements_dir_path)

        def add_one(i):
            from python_package1 import plus
            return plus(i, 1)

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_one).add_sink(self.test_sink)
        self.env.execute("test set requirements with cachd dir")
        result = self.test_sink.get_results(True)
        expected = ['2', '3', '4', '5', '6']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_add_python_archive(self):
        import uuid
        import shutil
        tmp_dir = self.tempdir
        archive_dir_path = os.path.join(tmp_dir,
                                        "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("2")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        self.env.add_python_archive(archive_file_path, "data")

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i + int(f.read())

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_from_file).add_sink(self.test_sink)
        self.env.execute("test set python archive")
        result = self.test_sink.get_results(True)
        expected = ['3', '4', '5', '6', '7']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    @unittest.skipIf(on_windows(),
                     "Symbolic link is not supported on Windows, skipping.")
    def test_set_stream_env(self):
        import sys
        python_exec = sys.executable
        tmp_dir = self.tempdir
        python_exec_link_path = os.path.join(tmp_dir, "py_exec")
        os.symlink(python_exec, python_exec_link_path)
        self.env.set_python_executable(python_exec_link_path)

        def check_python_exec(i):
            import os
            assert os.environ["python"] == python_exec_link_path
            return i

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_python_exec).add_sink(self.test_sink)
        self.env.execute("test set python executable")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_generate_stream_graph_with_dependencies(self):

        python_file_dir = os.path.join(self.tempdir,
                                       "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(
            python_file_dir, "test_stream_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")
        self.env.add_python_file(python_file_path)

        def plus_two_map(value):
            from test_stream_dependency_manage_lib import add_two
            return value[0], add_two(value[1])

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i[0], i[1] + int(f.read())

        from_collection_source = self.env.from_collection(
            [('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
            type_info=Types.ROW([Types.STRING(), Types.INT()]))
        from_collection_source.name("From Collection")
        keyed_stream = from_collection_source.key_by(lambda x: x[1],
                                                     key_type_info=Types.INT())

        plus_two_map_stream = keyed_stream.map(plus_two_map).name(
            "Plus Two Map").set_parallelism(3)

        add_from_file_map = plus_two_map_stream.map(add_from_file).name(
            "Add From File Map")

        test_stream_sink = add_from_file_map.add_sink(
            self.test_sink).name("Test Sink")
        test_stream_sink.set_parallelism(4)

        archive_dir_path = os.path.join(self.tempdir,
                                        "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("3")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        self.env.add_python_archive(archive_file_path, "data")

        nodes = eval(self.env.get_execution_plan())['nodes']

        # The StreamGraph should be as below:
        # Source: From Collection -> _stream_key_by_map_operator -> _keyed_stream_values_operator ->
        # Plus Two Map -> Add From File Map -> Sink: Test Sink.

        # Source: From Collection and _stream_key_by_map_operator should have the same
        # parallelism.
        self.assertEqual(nodes[0]['parallelism'], nodes[1]['parallelism'])

        # _keyed_stream_values_operator and Plus Two Map should have the same parallelism.
        self.assertEqual(nodes[3]['parallelism'], 3)
        self.assertEqual(nodes[2]['parallelism'], nodes[3]['parallelism'])

        # The ship_strategy for Source: From Collection and _stream_key_by_map_operator should be
        # FORWARD
        self.assertEqual(nodes[1]['predecessors'][0]['ship_strategy'],
                         "FORWARD")

        # The ship_strategy for _keyed_stream_values_operator and Plus Two Map should be
        # FORWARD
        self.assertEqual(nodes[3]['predecessors'][0]['ship_strategy'],
                         "FORWARD")

        # The parallelism of Sink: Test Sink should be 4
        self.assertEqual(nodes[5]['parallelism'], 4)

        env_config_with_dependencies = dict(
            get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil.
            getEnvConfigWithDependencies(
                self.env._j_stream_execution_environment).toMap())

        # Make sure that user specified files and archives are correctly added.
        self.assertIsNotNone(env_config_with_dependencies['python.files'])
        self.assertIsNotNone(env_config_with_dependencies['python.archives'])

    def tearDown(self) -> None:
        self.test_sink.clear()
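

# A minimal stand-alone sketch (not one of the examples above) of how the dependency
# management APIs exercised by these tests are typically combined in a plain PyFlink job.
# The file, requirements and archive paths below are hypothetical placeholders.
def _dependency_usage_sketch():
    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    # Ship an extra Python module with the job so that UDFs can import it.
    env.add_python_file("/path/to/my_udf_lib.py")
    # Install third-party requirements on the workers; the optional second argument points
    # at a directory of pre-downloaded packages that is used for offline installation.
    env.set_python_requirements("/path/to/requirements.txt", "/path/to/cached_packages")
    # Extract the archive into a directory named "data" in the workers' working directory.
    env.add_python_archive("/path/to/data.zip", "data")

    ds = env.from_collection([1, 2, 3])
    ds.map(lambda i: i + 1, Types.LONG()).print()
    env.execute("dependency usage sketch")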