def test_get_set_db_storage_paths(self):
    if on_windows():
        checkpoints_path = "file:/C:/var/checkpoints/"
        storage_path = ["file:/C:/var/db_storage_dir1/",
                        "file:/C:/var/db_storage_dir2/",
                        "file:/C:/var/db_storage_dir3/"]
        expected = ["C:\\var\\db_storage_dir1",
                    "C:\\var\\db_storage_dir2",
                    "C:\\var\\db_storage_dir3"]
    else:
        checkpoints_path = "file://var/checkpoints/"
        storage_path = ["file://var/db_storage_dir1/",
                        "file://var/db_storage_dir2/",
                        "file://var/db_storage_dir3/"]
        expected = ["/db_storage_dir1",
                    "/db_storage_dir2",
                    "/db_storage_dir3"]
    state_backend = RocksDBStateBackend(checkpoints_path)
    state_backend.set_db_storage_paths(*storage_path)
    self.assertEqual(state_backend.get_db_storage_paths(), expected)
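
# A minimal usage sketch, not one of the original tests: it shows how a
# RocksDBStateBackend configured with the calls exercised above could be
# attached to an execution environment. The checkpoint URI and the local
# storage directories are placeholder values, and the helper name is
# illustrative only.
def _example_configure_rocksdb_backend():
    from pyflink.datastream import StreamExecutionEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    backend = RocksDBStateBackend("file:///tmp/checkpoints/")
    # Spread RocksDB's local working files across several directories.
    backend.set_db_storage_paths("/tmp/rocksdb_storage1", "/tmp/rocksdb_storage2")
    env.set_state_backend(backend)
    return env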
class PythonBootTests(PyFlinkTestCase):

    def setUp(self):
        provision_info = json_format.Parse('{"retrievalToken": "test_token"}', ProvisionInfo())
        response = GetProvisionInfoResponse(info=provision_info)

        def get_unused_port():
            sock = socket.socket()
            sock.bind(('', 0))
            port = sock.getsockname()[1]
            sock.close()
            return port

        class ProvisionService(ProvisionServiceServicer):
            def GetProvisionInfo(self, request, context):
                return response

        def start_test_provision_server():
            server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
            add_ProvisionServiceServicer_to_server(ProvisionService(), server)
            port = get_unused_port()
            server.add_insecure_port('[::]:' + str(port))
            server.start()
            return server, port

        self.provision_server, self.provision_port = start_test_provision_server()
        self.env = dict(os.environ)
        self.env["python"] = sys.executable
        self.env["FLINK_BOOT_TESTING"] = "1"
        self.env["BOOT_LOG_DIR"] = os.path.join(self.env["FLINK_HOME"], "log")
        self.tmp_dir = tempfile.mkdtemp(str(time.time()), dir=self.tempdir)
        # assume that this file is in the flink-python source code directory.
        flink_python_source_root = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
        runner_script = "pyflink-udf-runner.bat" if on_windows() else \
            "pyflink-udf-runner.sh"
        self.runner_path = os.path.join(
            flink_python_source_root, "bin", runner_script)

    def run_boot_py(self):
        args = [self.runner_path, "--id", "1",
                "--logging_endpoint", "localhost:0000",
                "--artifact_endpoint", "whatever",
                "--provision_endpoint", "localhost:%d" % self.provision_port,
                "--control_endpoint", "localhost:0000",
                "--semi_persist_dir", self.tmp_dir]
        return subprocess.call(args, env=self.env)

    def test_python_boot(self):
        exit_code = self.run_boot_py()
        self.assertTrue(exit_code == 0, "the boot.py exited with non-zero code.")

    @unittest.skipIf(on_windows(),
                     "'subprocess.check_output' on Windows always returns an empty "
                     "string, skip this test.")
    def test_param_validation(self):
        args = [self.runner_path]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No id provided.", exit_message)

        args = [self.runner_path, "--id", "1"]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No logging endpoint provided.", exit_message)

        args = [self.runner_path, "--id", "1",
                "--logging_endpoint", "localhost:0000"]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No provision endpoint provided.", exit_message)

        args = [self.runner_path, "--id", "1",
                "--logging_endpoint", "localhost:0000",
                "--provision_endpoint", "localhost:%d" % self.provision_port]
        exit_message = subprocess.check_output(args, env=self.env).decode("utf-8")
        self.assertIn("No control endpoint provided.", exit_message)

    def test_set_working_directory(self):
        JProcessPythonEnvironmentManager = \
            get_gateway().jvm.org.apache.flink.python.env.beam.ProcessPythonEnvironmentManager

        output_file = os.path.join(self.tmp_dir, "output.txt")
        pyflink_dir = os.path.join(self.tmp_dir, "pyflink")
        os.mkdir(pyflink_dir)
        # just create an empty file
        open(os.path.join(pyflink_dir, "__init__.py"), 'a').close()
        fn_execution_dir = os.path.join(pyflink_dir, "fn_execution")
        os.mkdir(fn_execution_dir)
        open(os.path.join(fn_execution_dir, "__init__.py"), 'a').close()
        with open(os.path.join(fn_execution_dir, "boot.py"), "w") as f:
            f.write("import os\nwith open(r'%s', 'w') as f:\n    f.write(os.getcwd())"
                    % output_file)

        # test whether the name of the working directory variable used by the UDF runner
        # is consistent with ProcessPythonEnvironmentManager.
        self.env[JProcessPythonEnvironmentManager.PYTHON_WORKING_DIR] = self.tmp_dir
        self.env["python"] = sys.executable
        args = [self.runner_path]
        subprocess.check_output(args, env=self.env)
        process_cwd = None
        if os.path.exists(output_file):
            with open(output_file, 'r') as f:
                process_cwd = f.read()

        self.assertEqual(os.path.realpath(self.tmp_dir), process_cwd,
                         "setting the working directory variable does not work!")

    def tearDown(self):
        self.provision_server.stop(0)
        try:
            if self.tmp_dir is not None:
                shutil.rmtree(self.tmp_dir)
        except:
            pass
class TypesTests(PyFlinkTestCase): def test_infer_schema(self): from decimal import Decimal class A(object): def __init__(self): self.a = 1 from collections import namedtuple Point = namedtuple('Point', 'x y') data = [ True, 1, "a", u"a", datetime.date(1970, 1, 1), datetime.time(0, 0, 0), datetime.datetime(1970, 1, 1, 0, 0), 1.0, array.array("d", [1]), [1], (1, ), Point(1.0, 5.0), { "a": 1 }, bytearray(1), Decimal(1), Row(a=1), Row("a")(1), A(), ] expected = [ 'BooleanType(true)', 'BigIntType(true)', 'VarCharType(2147483647, true)', 'VarCharType(2147483647, true)', 'DateType(true)', 'TimeType(0, true)', 'LocalZonedTimestampType(6, true)', 'DoubleType(true)', "ArrayType(DoubleType(false), true)", "ArrayType(BigIntType(true), true)", 'RowType(RowField(_1, BigIntType(true), ...))', 'RowType(RowField(x, DoubleType(true), ...),RowField(y, DoubleType(true), ...))', 'MapType(VarCharType(2147483647, false), BigIntType(true), true)', 'VarBinaryType(2147483647, true)', 'DecimalType(38, 18, true)', 'RowType(RowField(a, BigIntType(true), ...))', 'RowType(RowField(a, BigIntType(true), ...))', 'RowType(RowField(a, BigIntType(true), ...))', ] schema = _infer_schema_from_data([data]) self.assertEqual(expected, [repr(f.data_type) for f in schema.fields]) def test_infer_schema_nulltype(self): elements = [ Row(c1=[], c2={}, c3=None), Row(c1=[Row(a=1, b='s')], c2={"key": Row(c=1.0, d="2")}, c3="") ] schema = _infer_schema_from_data(elements) self.assertTrue(isinstance(schema, RowType)) self.assertEqual(3, len(schema.fields)) # first column is array self.assertTrue(isinstance(schema.fields[0].data_type, ArrayType)) # element type of first column is struct self.assertTrue( isinstance(schema.fields[0].data_type.element_type, RowType)) self.assertTrue( isinstance( schema.fields[0].data_type.element_type.fields[0].data_type, BigIntType)) self.assertTrue( isinstance( schema.fields[0].data_type.element_type.fields[1].data_type, VarCharType)) # second column is map self.assertTrue(isinstance(schema.fields[1].data_type, MapType)) self.assertTrue( isinstance(schema.fields[1].data_type.key_type, VarCharType)) self.assertTrue( isinstance(schema.fields[1].data_type.value_type, RowType)) # third column is varchar self.assertTrue(isinstance(schema.fields[2].data_type, VarCharType)) def test_infer_schema_not_enough_names(self): schema = _infer_schema_from_data([["a", "b"]], ["col1"]) self.assertTrue(schema.names, ['col1', '_2']) def test_infer_schema_fails(self): with self.assertRaises(TypeError): _infer_schema_from_data([[1, 1], ["x", 1]], names=["a", "b"]) def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") data1 = [ NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0}) ] schema1 = _infer_schema_from_data(data1) expected1 = [ 'ArrayType(BigIntType(true), true)', 'MapType(VarCharType(2147483647, false), DoubleType(true), true)' ] self.assertEqual(expected1, [repr(f.data_type) for f in schema1.fields]) data2 = [ NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3]) ] schema2 = _infer_schema_from_data(data2) expected2 = [ 'ArrayType(ArrayType(BigIntType(true), true), true)', 'ArrayType(BigIntType(true), true)' ] self.assertEqual(expected2, [repr(f.data_type) for f in schema2.fields]) def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.as_dict()['l'][0].a) self.assertEqual(1.0, row.as_dict()['d']['key'].c) def test_udt(self): p = ExamplePoint(1.0, 2.0) self.assertEqual(_infer_type(p), ExamplePointUDT()) 
_create_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) self.assertRaises( ValueError, lambda: _create_type_verifier(ExamplePointUDT()) ([1.0, 2.0])) p = PythonOnlyPoint(1.0, 2.0) self.assertEqual(_infer_type(p), PythonOnlyUDT()) _create_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) self.assertRaises( ValueError, lambda: _create_type_verifier(PythonOnlyUDT()) ([1.0, 2.0])) def test_nested_udt_in_df(self): expected_schema = DataTypes.ROW() \ .add("_1", DataTypes.BIGINT()).add("_2", DataTypes.ARRAY(PythonOnlyUDT())) data = (1, [PythonOnlyPoint(float(1), float(2))]) self.assertEqual(expected_schema, _infer_type(data)) expected_schema = DataTypes.ROW().add("_1", DataTypes.BIGINT()).add( "_2", DataTypes.MAP(DataTypes.BIGINT(False), PythonOnlyUDT())) p = (1, {1: PythonOnlyPoint(1, float(2))}) self.assertEqual(expected_schema, _infer_type(p)) def test_struct_type(self): row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \ .add("f2", DataTypes.STRING(nullable=True)) row2 = DataTypes.ROW([ DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)), DataTypes.FIELD("f2", DataTypes.STRING(nullable=True), None) ]) self.assertEqual(row1.field_names(), row2.names) self.assertEqual(row1, row2) row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \ .add("f2", DataTypes.STRING(nullable=True)) row2 = DataTypes.ROW( [DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))]) self.assertNotEqual(row1.field_names(), row2.names) self.assertNotEqual(row1, row2) row1 = (DataTypes.ROW().add( DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))).add( "f2", DataTypes.STRING(nullable=True))) row2 = DataTypes.ROW([ DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)), DataTypes.FIELD("f2", DataTypes.STRING(nullable=True)) ]) self.assertEqual(row1.field_names(), row2.names) self.assertEqual(row1, row2) row1 = (DataTypes.ROW().add( DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))).add( "f2", DataTypes.STRING(nullable=True))) row2 = DataTypes.ROW( [DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))]) self.assertNotEqual(row1.field_names(), row2.names) self.assertNotEqual(row1, row2) # Catch exception raised during improper construction self.assertRaises(ValueError, lambda: DataTypes.ROW().add("name")) row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \ .add("f2", DataTypes.STRING(nullable=True)) for field in row1: self.assertIsInstance(field, RowField) row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \ .add("f2", DataTypes.STRING(nullable=True)) self.assertEqual(len(row1), 2) row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \ .add("f2", DataTypes.STRING(nullable=True)) self.assertIs(row1["f1"], row1.fields[0]) self.assertIs(row1[0], row1.fields[0]) self.assertEqual(row1[0:1], DataTypes.ROW(row1.fields[0:1])) self.assertRaises(KeyError, lambda: row1["f9"]) self.assertRaises(IndexError, lambda: row1[9]) self.assertRaises(TypeError, lambda: row1[9.9]) def test_infer_bigint_type(self): longrow = [Row(f1='a', f2=100000000000000)] schema = _infer_schema_from_data(longrow) self.assertEqual(DataTypes.BIGINT(), schema.fields[1].data_type) self.assertEqual(DataTypes.BIGINT(), _infer_type(1)) self.assertEqual(DataTypes.BIGINT(), _infer_type(2**10)) self.assertEqual(DataTypes.BIGINT(), _infer_type(2**20)) self.assertEqual(DataTypes.BIGINT(), _infer_type(2**31 - 1)) self.assertEqual(DataTypes.BIGINT(), _infer_type(2**31)) self.assertEqual(DataTypes.BIGINT(), _infer_type(2**61)) self.assertEqual(DataTypes.BIGINT(), 
_infer_type(2**71)) def test_merge_type(self): self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.NULL()), DataTypes.BIGINT()) self.assertEqual(_merge_type(DataTypes.NULL(), DataTypes.BIGINT()), DataTypes.BIGINT()) self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.BIGINT()), DataTypes.BIGINT()) self.assertEqual( _merge_type(DataTypes.ARRAY(DataTypes.BIGINT()), DataTypes.ARRAY(DataTypes.BIGINT())), DataTypes.ARRAY(DataTypes.BIGINT())) with self.assertRaises(TypeError): _merge_type(DataTypes.ARRAY(DataTypes.BIGINT()), DataTypes.ARRAY(DataTypes.DOUBLE())) self.assertEqual( _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()), DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())), DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())) with self.assertRaises(TypeError): _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()), DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT())) with self.assertRaises(TypeError): _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()), DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE())) self.assertEqual( _merge_type( DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.BIGINT()), DataTypes.FIELD('f2', DataTypes.STRING()) ]), DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.BIGINT()), DataTypes.FIELD('f2', DataTypes.STRING()) ])), DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.BIGINT()), DataTypes.FIELD('f2', DataTypes.STRING()) ])) with self.assertRaises(TypeError): _merge_type( DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.BIGINT()), DataTypes.FIELD('f2', DataTypes.STRING()) ]), DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.DOUBLE()), DataTypes.FIELD('f2', DataTypes.STRING()) ])) self.assertEqual( _merge_type( DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ROW( [DataTypes.FIELD('f2', DataTypes.BIGINT())])) ]), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ROW( [DataTypes.FIELD('f2', DataTypes.BIGINT())])) ])), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())])) ])) with self.assertRaises(TypeError): _merge_type( DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ROW( [DataTypes.FIELD('f2', DataTypes.BIGINT())])) ]), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ROW( [DataTypes.FIELD('f2', DataTypes.STRING())])) ])) self.assertEqual( _merge_type( DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ]), DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ])), DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ])) with self.assertRaises(TypeError): _merge_type( DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ]), DataTypes.ROW([ DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.DOUBLE())), DataTypes.FIELD('f2', DataTypes.STRING()) ])) self.assertEqual( _merge_type( DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ]), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ])), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ])) with self.assertRaises(TypeError): _merge_type( DataTypes.ROW([ DataTypes.FIELD( 'f1', 
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())), DataTypes.FIELD('f2', DataTypes.STRING()) ]), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE())), DataTypes.FIELD('f2', DataTypes.STRING()) ])) self.assertEqual( _merge_type( DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ARRAY( DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()))) ]), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ARRAY( DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()))) ])), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ARRAY( DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()))) ])) with self.assertRaises(TypeError): _merge_type( DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ARRAY( DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()))) ]), DataTypes.ROW([ DataTypes.FIELD( 'f1', DataTypes.ARRAY( DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT()))) ])) def test_array_types(self): # This test need to make sure that the Scala type selected is at least # as large as the python's types. This is necessary because python's # array types depend on C implementation on the machine. Therefore there # is no machine independent correspondence between python's array types # and Scala types. # See: https://docs.python.org/2/library/array.html def assert_collect_success(typecode, value, element_type): self.assertEqual( element_type, str(_infer_type(array.array(typecode, [value])).element_type)) # supported string types # # String types in python's array are "u" for Py_UNICODE and "c" for char. # "u" will be removed in python 4, and "c" is not supported in python 3. supported_string_types = [] if sys.version_info[0] < 4: supported_string_types += ['u'] # test unicode assert_collect_success('u', u'a', 'CHAR') # supported float and double # # Test max, min, and precision for float and double, assuming IEEE 754 # floating-point format. supported_fractional_types = ['f', 'd'] assert_collect_success('f', ctypes.c_float(1e+38).value, 'FLOAT') assert_collect_success('f', ctypes.c_float(1e-38).value, 'FLOAT') assert_collect_success('f', ctypes.c_float(1.123456).value, 'FLOAT') assert_collect_success('d', sys.float_info.max, 'DOUBLE') assert_collect_success('d', sys.float_info.min, 'DOUBLE') assert_collect_success('d', sys.float_info.epsilon, 'DOUBLE') def get_int_data_type(size): if size <= 8: return "TINYINT" if size <= 16: return "SMALLINT" if size <= 32: return "INT" if size <= 64: return "BIGINT" # supported signed int types # # The size of C types changes with implementation, we need to make sure # that there is no overflow error on the platform running this test. supported_signed_int_types = list( set(_array_signed_int_typecode_ctype_mappings.keys()).intersection( set(_array_type_mappings.keys()))) for t in supported_signed_int_types: ctype = _array_signed_int_typecode_ctype_mappings[t] max_val = 2**(ctypes.sizeof(ctype) * 8 - 1) assert_collect_success(t, max_val - 1, get_int_data_type(ctypes.sizeof(ctype) * 8)) assert_collect_success(t, -max_val, get_int_data_type(ctypes.sizeof(ctype) * 8)) # supported unsigned int types # # JVM does not have unsigned types. We need to be very careful to make # sure that there is no overflow error. supported_unsigned_int_types = list( set(_array_unsigned_int_typecode_ctype_mappings.keys()). 
intersection(set(_array_type_mappings.keys()))) for t in supported_unsigned_int_types: ctype = _array_unsigned_int_typecode_ctype_mappings[t] max_val = 2**(ctypes.sizeof(ctype) * 8 - 1) assert_collect_success( t, max_val, get_int_data_type(ctypes.sizeof(ctype) * 8 + 1)) # all supported types # # Make sure the types tested above: # 1. are all supported types # 2. cover all supported types supported_types = (supported_string_types + supported_fractional_types + supported_signed_int_types + supported_unsigned_int_types) self.assertEqual(set(supported_types), set(_array_type_mappings.keys())) # all unsupported types # # Keys in _array_type_mappings is a complete list of all supported types, # and types not in _array_type_mappings are considered unsupported. all_types = set(array.typecodes) unsupported_types = all_types - set(supported_types) # test unsupported types for t in unsupported_types: with self.assertRaises(TypeError): _infer_schema_from_data([Row(myarray=array.array(t))]) def test_data_type_eq(self): lt = DataTypes.BIGINT() lt2 = pickle.loads(pickle.dumps(DataTypes.BIGINT())) self.assertEqual(lt, lt2) def test_decimal_type(self): t1 = DataTypes.DECIMAL(10, 0) t2 = DataTypes.DECIMAL(10, 2) self.assertTrue(t2 is not t1) self.assertNotEqual(t1, t2) def test_datetype_equal_zero(self): dt = DataTypes.DATE() self.assertEqual(dt.from_sql_type(0), datetime.date(1970, 1, 1)) @unittest.skipIf(on_windows(), "Windows x64 system only support the datetime not larger " "than time.ctime(32536799999), so this test can't run " "under Windows platform") def test_timestamp_microsecond(self): tst = DataTypes.TIMESTAMP() self.assertEqual( tst.to_sql_type(datetime.datetime.max) % 1000000, 999999) @unittest.skipIf(on_windows(), "Windows x64 system only support the datetime not larger " "than time.ctime(32536799999), so this test can't run " "under Windows platform") def test_local_zoned_timestamp_type(self): lztst = DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE() last_abbreviation = DataTypes.TIMESTAMP_LTZ() self.assertEqual(lztst, last_abbreviation) ts = datetime.datetime(1970, 1, 1, 0, 0, 0, 0000) self.assertEqual(0, lztst.to_sql_type(ts)) import pytz # suppose the timezone of the data is +9:00 timezone = pytz.timezone("Asia/Tokyo") orig_epoch = LocalZonedTimestampType.EPOCH_ORDINAL try: # suppose the local timezone is +8:00 LocalZonedTimestampType.EPOCH_ORDINAL = 28800000000 ts_tokyo = timezone.localize(ts) self.assertEqual(-3600000000, lztst.to_sql_type(ts_tokyo)) finally: LocalZonedTimestampType.EPOCH_ORDINAL = orig_epoch if sys.version_info >= (3, 6): ts2 = lztst.from_sql_type(0) self.assertEqual(ts.astimezone(), ts2.astimezone()) def test_zoned_timestamp_type(self): ztst = ZonedTimestampType() ts = datetime.datetime(1970, 1, 1, 0, 0, 0, 0000, tzinfo=UTCOffsetTimezone(1)) self.assertEqual((0, 3600), ztst.to_sql_type(ts)) ts2 = ztst.from_sql_type((0, 3600)) self.assertEqual(ts, ts2) def test_day_time_inteval_type(self): ymt = DataTypes.INTERVAL(DataTypes.DAY(), DataTypes.SECOND()) td = datetime.timedelta(days=1, seconds=10) self.assertEqual(86410000000, ymt.to_sql_type(td)) td2 = ymt.from_sql_type(86410000000) self.assertEqual(td, td2) def test_empty_row(self): row = Row() self.assertEqual(len(row), 0) def test_invalid_create_row(self): row_class = Row("c1", "c2") self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) def test_nullable(self): t = DataType(nullable=False) self.assertEqual(t._nullable, False) t_nullable = t.nullable() self.assertEqual(t_nullable._nullable, True) def test_not_null(self): t 
            = DataType(nullable=True)
        self.assertEqual(t._nullable, True)
        t_notnull = t.not_null()
        self.assertEqual(t_notnull._nullable, False)
class StreamDependencyTests(DependencyTests, PyFlinkStreamTableTestCase): def setUp(self): super(StreamDependencyTests, self).setUp() origin_execution_mode = os.environ['_python_worker_execution_mode'] os.environ['_python_worker_execution_mode'] = "loopback" try: self.st_env = TableEnvironment.create( EnvironmentSettings.in_streaming_mode()) finally: if origin_execution_mode is not None: os.environ[ '_python_worker_execution_mode'] = origin_execution_mode def test_set_requirements_without_cached_directory(self): requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4())) with open(requirements_txt_path, 'w') as f: f.write("cloudpickle==1.2.2") self.st_env.set_python_requirements(requirements_txt_path) def check_requirements(i): import cloudpickle # noqa # pylint: disable=unused-import assert '_PYTHON_REQUIREMENTS_INSTALL_DIR' in os.environ return i self.st_env.create_temporary_system_function( "check_requirements", udf(check_requirements, DataTypes.BIGINT(), DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.st_env.register_table_sink("Results", table_sink) t = self.st_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b']) t.select(expr.call('check_requirements', t.a), t.a).execute_insert("Results").wait() actual = source_sink_utils.results() self.assert_equals(actual, ["+I[1, 1]", "+I[2, 2]", "+I[3, 3]"]) def test_set_requirements_with_cached_directory(self): tmp_dir = self.tempdir requirements_txt_path = os.path.join( tmp_dir, "requirements_txt_" + str(uuid.uuid4())) with open(requirements_txt_path, 'w') as f: f.write("python-package1==0.0.0") requirements_dir_path = os.path.join( tmp_dir, "requirements_dir_" + str(uuid.uuid4())) os.mkdir(requirements_dir_path) package_file_name = "python-package1-0.0.0.tar.gz" with open(os.path.join(requirements_dir_path, package_file_name), 'wb') as f: import base64 # This base64 data is encoded from a python package file which includes a # "python_package1" module. The module contains a "plus(a, b)" function. 
# The base64 can be recomputed by following code: # base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8") f.write( base64.b64decode( "H4sICNefrV0C/2Rpc3QvcHl0aG9uLXBhY2thZ2UxLTAuMC4wLnRhcgDtmVtv2jAYhnPtX2H1CrRCY+ckI" "XEx7axuUA11u5imyICTRc1JiVnHfv1MKKWjYxwKEdPehws7xkmUfH5f+3PyqfqWpa1cjG5EKFnLbOvfhX" "FQTI3nOPPSdavS5Pa8nGMwy3Esi3ke9wyTObbnGNQxamBSKlFQavzUryG8ldG6frpbEGx4yNmDLMp/hPy" "P8b+6fNN613vdP1z8XdteG3+ug/17/F3Hcw1qIv5H54NUYiyUaH2SRRllaYeytkl6IpEdujI2yH2XapCQ" "wSRJRDHt0OveZa//uUfeZonUvUO5bHo+0ZcoVo9bMhFRvGx9H41kWj447aUsR0WUq+pui8arWKggK5Jli" "wGOo/95q79ovXi6/nfyf246Dof/n078fT9KI+X77Xx6BP83bX4Xf5NxT7dz7toO/L8OxjKgeTwpG+KcDp" "sdQjWFVJMipYI+o0MCk4X/t2UYtqI0yPabCHb3f861XcD/Ty/+Y5nLdCzT0dSPo/SmbKsf6un+b7KV+Ls" "W4/D/OoC9w/930P9eGwM75//csrD+Q/6P/P/k9D/oX3988Wqw1bS/tf6tR+s/m3EG/ddBqXO9XKf15C8p" "P9k4HZBtBgzZaVW5vrfKcj+W32W82ygEB9D/Xu9+4/qfP9L/rBv0X1v87yONKRX61/qfzwqjIDzIPTbv/" "7or3/88i0H/tfBFW7s/s/avRInQH06ieEy7tDrQeYHUdRN7wP+n/vf62LOH/pld7f9xz7a5Pfufedy0oP" "86iJI8KxStAq6yLC4JWdbbVbWRikR2z1ZGytk5vauW3QdnBFE6XqwmykazCesAAAAAAAAAAAAAAAAAAAA" "AAAAAAAAAAAAAAOBw/AJw5CHBAFAAAA==")) self.st_env.set_python_requirements(requirements_txt_path, requirements_dir_path) def add_one(i): from python_package1 import plus return plus(i, 1) self.st_env.create_temporary_system_function( "add_one", udf(add_one, DataTypes.BIGINT(), DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.st_env.register_table_sink("Results", table_sink) t = self.st_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b']) t.select(expr.call('add_one', t.a), t.a).execute_insert("Results").wait() actual = source_sink_utils.results() self.assert_equals(actual, ["+I[2, 1]", "+I[3, 2]", "+I[4, 3]"]) def test_add_python_archive(self): tmp_dir = self.tempdir archive_dir_path = os.path.join(tmp_dir, "archive_" + str(uuid.uuid4())) os.mkdir(archive_dir_path) with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f: f.write("2") archive_file_path = \ shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path) self.t_env.add_python_archive(archive_file_path, "data") def add_from_file(i): with open("data/data.txt", 'r') as f: return i + int(f.read()) self.t_env.create_temporary_system_function( "add_from_file", udf(add_from_file, DataTypes.BIGINT(), DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b']) t.select(expr.call('add_from_file', t.a), t.a).execute_insert("Results").wait() actual = source_sink_utils.results() self.assert_equals(actual, ["+I[3, 1]", "+I[4, 2]", "+I[5, 3]"]) @unittest.skipIf(on_windows(), "Symbolic link is not supported on Windows, skipping.") def test_set_environment(self): python_exec = sys.executable tmp_dir = self.tempdir python_exec_link_path = os.path.join(tmp_dir, "py_exec") os.symlink(python_exec, python_exec_link_path) self.st_env.get_config().set_python_executable(python_exec_link_path) def check_python_exec(i): import os assert os.environ["python"] == python_exec_link_path return i self.st_env.create_temporary_system_function( "check_python_exec", udf(check_python_exec, DataTypes.BIGINT(), DataTypes.BIGINT())) def check_pyflink_gateway_disabled(i): from pyflink.java_gateway import get_gateway get_gateway() return i self.st_env.create_temporary_system_function( "check_pyflink_gateway_disabled", 
            udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(), DataTypes.BIGINT()))

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.st_env.register_table_sink("Results", table_sink)
        t = self.st_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select(
            expr.call('check_python_exec', t.a),
            expr.call('check_pyflink_gateway_disabled', t.a)) \
            .execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[1, 1]", "+I[2, 2]", "+I[3, 3]"])
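
# A hedged, self-contained sketch, not part of the original tests: it shows how a
# tiny "python_package1" distribution like the one embedded above as base64 could
# be rebuilt and re-encoded. The directory layout and setup.py content are
# illustrative assumptions; only the final base64.b64encode(...) call mirrors the
# comment in test_set_requirements_with_cached_directory.
def _rebuild_python_package1_base64(work_dir):
    import base64
    import os
    import tarfile

    pkg_root = os.path.join(work_dir, "python-package1-0.0.0")
    module_dir = os.path.join(pkg_root, "python_package1")
    os.makedirs(module_dir)
    with open(os.path.join(module_dir, "__init__.py"), "w") as f:
        f.write("def plus(a, b):\n    return a + b\n")
    with open(os.path.join(pkg_root, "setup.py"), "w") as f:
        f.write("from setuptools import setup\n"
                "setup(name='python-package1', version='0.0.0',\n"
                "      packages=['python_package1'])\n")

    tar_path = os.path.join(work_dir, "python-package1-0.0.0.tar.gz")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(pkg_root, arcname="python-package1-0.0.0")
    with open(tar_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")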
class StreamExecutionEnvironmentTests(PyFlinkTestCase): def setUp(self): self.env = StreamExecutionEnvironment.get_execution_environment() self.env.set_parallelism(2) self.test_sink = DataStreamTestSinkFunction() def test_get_config(self): execution_config = self.env.get_config() self.assertIsInstance(execution_config, ExecutionConfig) def test_get_set_parallelism(self): self.env.set_parallelism(10) parallelism = self.env.get_parallelism() self.assertEqual(parallelism, 10) def test_get_set_buffer_timeout(self): self.env.set_buffer_timeout(12000) timeout = self.env.get_buffer_timeout() self.assertEqual(timeout, 12000) def test_get_set_default_local_parallelism(self): self.env.set_default_local_parallelism(8) parallelism = self.env.get_default_local_parallelism() self.assertEqual(parallelism, 8) def test_set_get_restart_strategy(self): self.env.set_restart_strategy(RestartStrategies.no_restart()) restart_strategy = self.env.get_restart_strategy() self.assertEqual(restart_strategy, RestartStrategies.no_restart()) def test_add_default_kryo_serializer(self): self.env.add_default_kryo_serializer( "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo", "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer" ) class_dict = self.env.get_config().get_default_kryo_serializer_classes( ) self.assertEqual( class_dict, { 'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo': 'org.apache.flink.runtime.state' '.StateBackendTestBase$CustomKryoTestSerializer' }) def test_register_type_with_kryo_serializer(self): self.env.register_type_with_kryo_serializer( "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo", "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer" ) class_dict = self.env.get_config( ).get_registered_types_with_kryo_serializer_classes() self.assertEqual( class_dict, { 'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo': 'org.apache.flink.runtime.state' '.StateBackendTestBase$CustomKryoTestSerializer' }) def test_register_type(self): self.env.register_type( "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo") type_list = self.env.get_config().get_registered_pojo_types() self.assertEqual( type_list, ['org.apache.flink.runtime.state.StateBackendTestBase$TestPojo']) def test_get_set_max_parallelism(self): self.env.set_max_parallelism(12) parallelism = self.env.get_max_parallelism() self.assertEqual(parallelism, 12) def test_set_runtime_mode(self): self.env.set_runtime_mode(RuntimeExecutionMode.BATCH) config = invoke_java_object_method( self.env._j_stream_execution_environment, "getConfiguration") runtime_mode = config.getValue(get_gateway( ).jvm.org.apache.flink.configuration.ExecutionOptions.RUNTIME_MODE) self.assertEqual(runtime_mode, "BATCH") def test_operation_chaining(self): self.assertTrue(self.env.is_chaining_enabled()) self.env.disable_operator_chaining() self.assertFalse(self.env.is_chaining_enabled()) def test_get_checkpoint_config(self): checkpoint_config = self.env.get_checkpoint_config() self.assertIsInstance(checkpoint_config, CheckpointConfig) def test_get_set_checkpoint_interval(self): self.env.enable_checkpointing(30000) interval = self.env.get_checkpoint_interval() self.assertEqual(interval, 30000) def test_get_set_checkpointing_mode(self): mode = self.env.get_checkpointing_mode() self.assertEqual(mode, CheckpointingMode.EXACTLY_ONCE) self.env.enable_checkpointing(30000, CheckpointingMode.AT_LEAST_ONCE) mode = self.env.get_checkpointing_mode() self.assertEqual(mode, 
CheckpointingMode.AT_LEAST_ONCE) def test_get_state_backend(self): state_backend = self.env.get_state_backend() self.assertIsNone(state_backend) def test_set_state_backend(self): input_backend = MemoryStateBackend() self.env.set_state_backend(input_backend) output_backend = self.env.get_state_backend() self.assertEqual(output_backend._j_memory_state_backend, input_backend._j_memory_state_backend) def test_is_changelog_state_backend_enabled(self): self.assertIsNone(self.env.is_changelog_state_backend_enabled()) def test_enable_changelog_state_backend(self): self.env.enable_changelog_state_backend(True) self.assertTrue(self.env.is_changelog_state_backend_enabled()) self.env.enable_changelog_state_backend(False) self.assertFalse(self.env.is_changelog_state_backend_enabled()) def test_get_set_stream_time_characteristic(self): default_time_characteristic = self.env.get_stream_time_characteristic() self.assertEqual(default_time_characteristic, TimeCharacteristic.EventTime) self.env.set_stream_time_characteristic( TimeCharacteristic.ProcessingTime) time_characteristic = self.env.get_stream_time_characteristic() self.assertEqual(time_characteristic, TimeCharacteristic.ProcessingTime) @unittest.skip( "Python API does not support DataStream now. refactor this test later") def test_get_execution_plan(self): tmp_dir = tempfile.gettempdir() source_path = os.path.join(tmp_dir + '/streaming.csv') tmp_csv = os.path.join(tmp_dir + '/streaming2.csv') field_names = ["a", "b", "c"] field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] t_env = StreamTableEnvironment.create(self.env) csv_source = CsvTableSource(source_path, field_names, field_types) t_env.register_table_source("Orders", csv_source) t_env.register_table_sink( "Results", CsvTableSink(field_names, field_types, tmp_csv)) t_env.from_path("Orders").execute_insert("Results").wait() plan = self.env.get_execution_plan() json.loads(plan) def test_execute(self): tmp_dir = tempfile.gettempdir() field_names = ['a', 'b', 'c'] field_types = [ DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING() ] t_env = StreamTableEnvironment.create(self.env) t_env.register_table_sink( 'Results', CsvTableSink( field_names, field_types, os.path.join('{}/{}.csv'.format(tmp_dir, round(time.time()))))) execution_result = exec_insert_table( t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c']), 'Results') self.assertIsNotNone(execution_result.get_job_id()) self.assertIsNotNone(execution_result.get_net_runtime()) self.assertEqual(len(execution_result.get_all_accumulator_results()), 0) self.assertIsNone( execution_result.get_accumulator_result('accumulator')) self.assertIsNotNone(str(execution_result)) def test_from_collection_without_data_types(self): ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]) ds.add_sink(self.test_sink) self.env.execute("test from collection") results = self.test_sink.get_results(True) # user does not specify data types for input data, the collected result should be in # in tuple format as inputs. expected = ["(1, 'Hi', 'Hello')", "(2, 'Hello', 'Hi')"] results.sort() expected.sort() self.assertEqual(expected, results) def test_from_collection_with_data_types(self): # verify from_collection for the collection with single object. 
ds = self.env.from_collection(['Hi', 'Hello'], type_info=Types.STRING()) ds.add_sink(self.test_sink) self.env.execute("test from collection with single object") results = self.test_sink.get_results(False) expected = ['Hello', 'Hi'] results.sort() expected.sort() self.assertEqual(expected, results) # verify from_collection for the collection with multiple objects like tuple. ds = self.env.from_collection( [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932, bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13), datetime.time(hour=12, minute=0, second=0, microsecond=123000), datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [1, 2, 3], decimal.Decimal('1000000000000000000.05'), decimal.Decimal('1000000000000000000.0599999999999' '9999899999999999')), (2, None, 2, True, 43878, 9147483648, 9.87, 2.98936, bytearray(b'flink'), 'pyflink', datetime.date(2015, 10, 14), datetime.time(hour=11, minute=2, second=2, microsecond=234500), datetime.datetime(2020, 4, 15, 8, 2, 6, 235000), [2, 4, 6], decimal.Decimal('2000000000000000000.74'), decimal.Decimal('2000000000000000000.061111111111111' '11111111111111'))], type_info=Types.ROW([ Types.LONG(), Types.LONG(), Types.SHORT(), Types.BOOLEAN(), Types.SHORT(), Types.INT(), Types.FLOAT(), Types.DOUBLE(), Types.PICKLED_BYTE_ARRAY(), Types.STRING(), Types.SQL_DATE(), Types.SQL_TIME(), Types.SQL_TIMESTAMP(), Types.BASIC_ARRAY(Types.LONG()), Types.BIG_DEC(), Types.BIG_DEC() ])) ds.add_sink(self.test_sink) self.env.execute("test from collection with tuple object") results = self.test_sink.get_results(False) # if user specifies data types of input data, the collected result should be in row format. expected = [ '+I[1, null, 1, true, 32767, -2147483648, 1.23, 1.98932, [102, 108, 105, 110, 107], ' 'pyflink, 2014-09-13, 12:00:00, 2018-03-11 03:00:00.123, [1, 2, 3], ' '1000000000000000000.05, 1000000000000000000.05999999999999999899999999999]', '+I[2, null, 2, true, -21658, 557549056, 9.87, 2.98936, [102, 108, 105, 110, 107], ' 'pyflink, 2015-10-14, 11:02:02, 2020-04-15 08:02:06.235, [2, 4, 6], ' '2000000000000000000.74, 2000000000000000000.06111111111111111111111111111]' ] results.sort() expected.sort() self.assertEqual(expected, results) def test_add_custom_source(self): custom_source = SourceFunction( "org.apache.flink.python.util.MyCustomSourceFunction") ds = self.env.add_source(custom_source, type_info=Types.ROW( [Types.INT(), Types.STRING()])) ds.add_sink(self.test_sink) self.env.execute("test add custom source") results = self.test_sink.get_results(False) expected = [ '+I[3, Mike]', '+I[1, Marry]', '+I[4, Ted]', '+I[5, Jack]', '+I[0, Bob]', '+I[2, Henry]' ] results.sort() expected.sort() self.assertEqual(expected, results) def test_read_text_file(self): texts = ["Mike", "Marry", "Ted", "Jack", "Bob", "Henry"] text_file_path = self.tempdir + '/text_file' with open(text_file_path, 'a') as f: for text in texts: f.write(text) f.write('\n') ds = self.env.read_text_file(text_file_path) ds.add_sink(self.test_sink) self.env.execute("test read text file") results = self.test_sink.get_results() results.sort() texts.sort() self.assertEqual(texts, results) def test_execute_async(self): ds = self.env.from_collection( [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')], type_info=Types.ROW([Types.INT(), Types.STRING(), Types.STRING()])) ds.add_sink(self.test_sink) job_client = self.env.execute_async("test execute async") job_id = job_client.get_job_id() self.assertIsNotNone(job_id) execution_result = job_client.get_job_execution_result().result() self.assertEqual(str(job_id), 
str(execution_result.get_job_id())) def test_add_python_file(self): import uuid python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4())) os.mkdir(python_file_dir) python_file_path = os.path.join(python_file_dir, "test_dep1.py") with open(python_file_path, 'w') as f: f.write("def add_two(a):\n return a + 2") def plus_two_map(value): from test_dep1 import add_two return add_two(value) get_j_env_configuration(self.env._j_stream_execution_environment).\ setString("taskmanager.numberOfTaskSlots", "10") self.env.add_python_file(python_file_path) ds = self.env.from_collection([1, 2, 3, 4, 5]) ds = ds.map(plus_two_map, Types.LONG()) \ .slot_sharing_group("data_stream") \ .map(lambda i: i, Types.LONG()) \ .slot_sharing_group("table") python_file_path = os.path.join(python_file_dir, "test_dep2.py") with open(python_file_path, 'w') as f: f.write("def add_three(a):\n return a + 3") def plus_three(value): from test_dep2 import add_three return add_three(value) t_env = StreamTableEnvironment.create( stream_execution_environment=self.env, environment_settings=EnvironmentSettings.in_streaming_mode()) self.env.add_python_file(python_file_path) from pyflink.table.udf import udf from pyflink.table.expressions import col add_three = udf(plus_three, result_type=DataTypes.BIGINT()) tab = t_env.from_data_stream(ds, 'a') \ .select(add_three(col('a'))) t_env.to_append_stream(tab, Types.ROW([Types.LONG()])) \ .map(lambda i: i[0]) \ .add_sink(self.test_sink) self.env.execute("test add_python_file") result = self.test_sink.get_results(True) expected = ['6', '7', '8', '9', '10'] result.sort() expected.sort() self.assertEqual(expected, result) def test_set_requirements_without_cached_directory(self): import uuid requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4())) with open(requirements_txt_path, 'w') as f: f.write("cloudpickle==1.2.2") self.env.set_python_requirements(requirements_txt_path) def check_requirements(i): import cloudpickle assert os.path.abspath(cloudpickle.__file__).startswith( os.environ['_PYTHON_REQUIREMENTS_INSTALL_DIR']) return i ds = self.env.from_collection([1, 2, 3, 4, 5]) ds.map(check_requirements).add_sink(self.test_sink) self.env.execute("test set requirements without cache dir") result = self.test_sink.get_results(True) expected = ['1', '2', '3', '4', '5'] result.sort() expected.sort() self.assertEqual(expected, result) def test_set_requirements_with_cached_directory(self): import uuid tmp_dir = self.tempdir requirements_txt_path = os.path.join( tmp_dir, "requirements_txt_" + str(uuid.uuid4())) with open(requirements_txt_path, 'w') as f: f.write("python-package1==0.0.0") requirements_dir_path = os.path.join( tmp_dir, "requirements_dir_" + str(uuid.uuid4())) os.mkdir(requirements_dir_path) package_file_name = "python-package1-0.0.0.tar.gz" with open(os.path.join(requirements_dir_path, package_file_name), 'wb') as f: import base64 # This base64 data is encoded from a python package file which includes a # "python_package1" module. The module contains a "plus(a, b)" function. 
# The base64 can be recomputed by following code: # base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8") f.write( base64.b64decode( "H4sICNefrV0C/2Rpc3QvcHl0aG9uLXBhY2thZ2UxLTAuMC4wLnRhcgDtmVtv2jAYhnPtX2H1CrRCY+ckI" "XEx7axuUA11u5imyICTRc1JiVnHfv1MKKWjYxwKEdPehws7xkmUfH5f+3PyqfqWpa1cjG5EKFnLbOvfhX" "FQTI3nOPPSdavS5Pa8nGMwy3Esi3ke9wyTObbnGNQxamBSKlFQavzUryG8ldG6frpbEGx4yNmDLMp/hPy" "P8b+6fNN613vdP1z8XdteG3+ug/17/F3Hcw1qIv5H54NUYiyUaH2SRRllaYeytkl6IpEdujI2yH2XapCQ" "wSRJRDHt0OveZa//uUfeZonUvUO5bHo+0ZcoVo9bMhFRvGx9H41kWj447aUsR0WUq+pui8arWKggK5Jli" "wGOo/95q79ovXi6/nfyf246Dof/n078fT9KI+X77Xx6BP83bX4Xf5NxT7dz7toO/L8OxjKgeTwpG+KcDp" "sdQjWFVJMipYI+o0MCk4X/t2UYtqI0yPabCHb3f861XcD/Ty/+Y5nLdCzT0dSPo/SmbKsf6un+b7KV+Ls" "W4/D/OoC9w/930P9eGwM75//csrD+Q/6P/P/k9D/oX3988Wqw1bS/tf6tR+s/m3EG/ddBqXO9XKf15C8p" "P9k4HZBtBgzZaVW5vrfKcj+W32W82ygEB9D/Xu9+4/qfP9L/rBv0X1v87yONKRX61/qfzwqjIDzIPTbv/" "7or3/88i0H/tfBFW7s/s/avRInQH06ieEy7tDrQeYHUdRN7wP+n/vf62LOH/pld7f9xz7a5Pfufedy0oP" "86iJI8KxStAq6yLC4JWdbbVbWRikR2z1ZGytk5vauW3QdnBFE6XqwmykazCesAAAAAAAAAAAAAAAAAAAA" "AAAAAAAAAAAAAAOBw/AJw5CHBAFAAAA==")) self.env.set_python_requirements(requirements_txt_path, requirements_dir_path) def add_one(i): from python_package1 import plus return plus(i, 1) ds = self.env.from_collection([1, 2, 3, 4, 5]) ds.map(add_one).add_sink(self.test_sink) self.env.execute("test set requirements with cachd dir") result = self.test_sink.get_results(True) expected = ['2', '3', '4', '5', '6'] result.sort() expected.sort() self.assertEqual(expected, result) def test_add_python_archive(self): import uuid import shutil tmp_dir = self.tempdir archive_dir_path = os.path.join(tmp_dir, "archive_" + str(uuid.uuid4())) os.mkdir(archive_dir_path) with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f: f.write("2") archive_file_path = \ shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path) self.env.add_python_archive(archive_file_path, "data") def add_from_file(i): with open("data/data.txt", 'r') as f: return i + int(f.read()) ds = self.env.from_collection([1, 2, 3, 4, 5]) ds.map(add_from_file).add_sink(self.test_sink) self.env.execute("test set python archive") result = self.test_sink.get_results(True) expected = ['3', '4', '5', '6', '7'] result.sort() expected.sort() self.assertEqual(expected, result) @unittest.skipIf(on_windows(), "Symbolic link is not supported on Windows, skipping.") def test_set_stream_env(self): import sys python_exec = sys.executable tmp_dir = self.tempdir python_exec_link_path = os.path.join(tmp_dir, "py_exec") os.symlink(python_exec, python_exec_link_path) self.env.set_python_executable(python_exec_link_path) def check_python_exec(i): import os assert os.environ["python"] == python_exec_link_path return i ds = self.env.from_collection([1, 2, 3, 4, 5]) ds.map(check_python_exec).add_sink(self.test_sink) self.env.execute("test set python executable") result = self.test_sink.get_results(True) expected = ['1', '2', '3', '4', '5'] result.sort() expected.sort() self.assertEqual(expected, result) def test_add_jars(self): # find kafka connector jars flink_source_root = _find_flink_source_root() jars_abs_path = flink_source_root + '/flink-connectors/flink-sql-connector-kafka' specific_jars = glob.glob(jars_abs_path + '/target/flink*.jar') specific_jars = [ 'file://' + specific_jar for specific_jar in specific_jars ] self.env.add_jars(*specific_jars) source_topic = 'test_source_topic' props = { 'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group' } type_info = 
Types.ROW([Types.INT(), Types.STRING()]) # Test for kafka consumer deserialization_schema = JsonRowDeserializationSchema.builder() \ .type_info(type_info=type_info).build() # Will get a ClassNotFoundException if not add the kafka connector into the pipeline jars. kafka_consumer = FlinkKafkaConsumer(source_topic, deserialization_schema, props) self.env.add_source(kafka_consumer).print() self.env.get_execution_plan() def test_add_classpaths(self): # find kafka connector jars flink_source_root = _find_flink_source_root() jars_abs_path = flink_source_root + '/flink-connectors/flink-sql-connector-kafka' specific_jars = glob.glob(jars_abs_path + '/target/flink*.jar') specific_jars = [ 'file://' + specific_jar for specific_jar in specific_jars ] self.env.add_classpaths(*specific_jars) source_topic = 'test_source_topic' props = { 'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group' } type_info = Types.ROW([Types.INT(), Types.STRING()]) # Test for kafka consumer deserialization_schema = JsonRowDeserializationSchema.builder() \ .type_info(type_info=type_info).build() # It Will raise a ClassNotFoundException if the kafka connector is not added into the # pipeline classpaths. kafka_consumer = FlinkKafkaConsumer(source_topic, deserialization_schema, props) self.env.add_source(kafka_consumer).print() self.env.get_execution_plan() def test_generate_stream_graph_with_dependencies(self): python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4())) os.mkdir(python_file_dir) python_file_path = os.path.join( python_file_dir, "test_stream_dependency_manage_lib.py") with open(python_file_path, 'w') as f: f.write("def add_two(a):\n return a + 2") self.env.add_python_file(python_file_path) def plus_two_map(value): from test_stream_dependency_manage_lib import add_two return value[0], add_two(value[1]) def add_from_file(i): with open("data/data.txt", 'r') as f: return i[0], i[1] + int(f.read()) from_collection_source = self.env.from_collection( [('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)], type_info=Types.ROW([Types.STRING(), Types.INT()])) from_collection_source.name("From Collection") keyed_stream = from_collection_source.key_by(lambda x: x[1], key_type=Types.INT()) plus_two_map_stream = keyed_stream.map(plus_two_map).name( "Plus Two Map").set_parallelism(3) add_from_file_map = plus_two_map_stream.map(add_from_file).name( "Add From File Map") test_stream_sink = add_from_file_map.add_sink( self.test_sink).name("Test Sink") test_stream_sink.set_parallelism(4) archive_dir_path = os.path.join(self.tempdir, "archive_" + str(uuid.uuid4())) os.mkdir(archive_dir_path) with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f: f.write("3") archive_file_path = \ shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path) self.env.add_python_archive(archive_file_path, "data") nodes = eval(self.env.get_execution_plan())['nodes'] # The StreamGraph should be as bellow: # Source: From Collection -> _stream_key_by_map_operator -> # Plus Two Map -> Add From File Map -> Sink: Test Sink. # Source: From Collection and _stream_key_by_map_operator should have same parallelism. 
        self.assertEqual(nodes[0]['parallelism'], nodes[1]['parallelism'])
        # The parallelism of Plus Two Map should be 3
        self.assertEqual(nodes[2]['parallelism'], 3)
        # The ship_strategy for Source: From Collection and _stream_key_by_map_operator
        # should be FORWARD
        self.assertEqual(nodes[1]['predecessors'][0]['ship_strategy'], "FORWARD")
        # The ship_strategy for _keyed_stream_values_operator and Plus Two Map
        # should be HASH
        self.assertEqual(nodes[2]['predecessors'][0]['ship_strategy'], "HASH")
        # The parallelism of Sink: Test Sink should be 4
        self.assertEqual(nodes[4]['parallelism'], 4)

        env_config_with_dependencies = dict(
            get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil.
            getEnvConfigWithDependencies(
                self.env._j_stream_execution_environment).toMap())

        # Make sure that user specified files and archives are correctly added.
        self.assertIsNotNone(env_config_with_dependencies['python.files'])
        self.assertIsNotNone(env_config_with_dependencies['python.archives'])

    def tearDown(self) -> None:
        self.test_sink.clear()
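
# A small illustrative helper, not part of the original tests. It assumes, as the
# tests above do, that get_execution_plan() returns a JSON string containing a
# "nodes" list, and simply pairs each node id with its parallelism for quick
# inspection; the helper name is illustrative only.
def _summarize_execution_plan(env):
    import json

    plan = json.loads(env.get_execution_plan())
    return [(node.get('id'), node.get('parallelism')) for node in plan['nodes']]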
class BlinkStreamDependencyTests(DependencyTests, PyFlinkBlinkStreamTableTestCase): def test_set_requirements_without_cached_directory(self): requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4())) with open(requirements_txt_path, 'w') as f: f.write("cloudpickle==1.2.2") self.t_env.set_python_requirements(requirements_txt_path) def check_requirements(i): import cloudpickle assert os.path.abspath(cloudpickle.__file__).startswith( os.environ['_PYTHON_REQUIREMENTS_INSTALL_DIR']) return i self.t_env.register_function( "check_requirements", udf(check_requirements, DataTypes.BIGINT(), DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b']) t.select("check_requirements(a), a").insert_into("Results") self.t_env.execute("test") actual = source_sink_utils.results() self.assert_equals(actual, ["1,1", "2,2", "3,3"]) def test_set_requirements_with_cached_directory(self): tmp_dir = self.tempdir requirements_txt_path = os.path.join( tmp_dir, "requirements_txt_" + str(uuid.uuid4())) with open(requirements_txt_path, 'w') as f: f.write("python-package1==0.0.0") requirements_dir_path = os.path.join( tmp_dir, "requirements_dir_" + str(uuid.uuid4())) os.mkdir(requirements_dir_path) package_file_name = "python-package1-0.0.0.tar.gz" with open(os.path.join(requirements_dir_path, package_file_name), 'wb') as f: from pyflink.fn_execution.tests.process_mode_test_data import file_data import base64 f.write( base64.b64decode( json.loads(file_data[package_file_name])["data"])) self.t_env.set_python_requirements(requirements_txt_path, requirements_dir_path) def add_one(i): from python_package1 import plus return plus(i, 1) self.t_env.register_function( "add_one", udf(add_one, DataTypes.BIGINT(), DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b']) t.select("add_one(a), a").insert_into("Results") self.t_env.execute("test") actual = source_sink_utils.results() self.assert_equals(actual, ["2,1", "3,2", "4,3"]) def test_add_python_archive(self): tmp_dir = self.tempdir archive_dir_path = os.path.join(tmp_dir, "archive_" + str(uuid.uuid4())) os.mkdir(archive_dir_path) with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f: f.write("2") archive_file_path = \ shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path) self.t_env.add_python_archive(archive_file_path, "data") def add_from_file(i): with open("data/data.txt", 'r') as f: return i + int(f.read()) self.t_env.register_function( "add_from_file", udf(add_from_file, DataTypes.BIGINT(), DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b']) t.select("add_from_file(a), a").insert_into("Results") self.t_env.execute("test") actual = source_sink_utils.results() self.assert_equals(actual, ["3,1", "4,2", "5,3"]) @unittest.skipIf(on_windows(), "Symbolic link is not supported on Windows, skipping.") def test_set_environment(self): python_exec = sys.executable tmp_dir = self.tempdir python_exec_link_path = os.path.join(tmp_dir, "py_exec") os.symlink(python_exec, python_exec_link_path) 
        self.t_env.get_config().set_python_executable(python_exec_link_path)

        def check_python_exec(i):
            import os
            assert os.environ["python"] == python_exec_link_path
            return i

        self.t_env.register_function(
            "check_python_exec",
            udf(check_python_exec, DataTypes.BIGINT(), DataTypes.BIGINT()))

        def check_pyflink_gateway_disabled(i):
            try:
                from pyflink.java_gateway import get_gateway
                get_gateway()
            except Exception as e:
                assert str(e).startswith(
                    "It's launching the PythonGatewayServer during Python UDF"
                    " execution which is unexpected.")
            else:
                raise Exception("The gateway server is not disabled!")
            return i

        self.t_env.register_function(
            "check_pyflink_gateway_disabled",
            udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(), DataTypes.BIGINT()))

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)
        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select("check_python_exec(a), check_pyflink_gateway_disabled(a)") \
            .insert_into("Results")
        self.t_env.execute("test")
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["1,1", "2,2", "3,3"])
class StreamExecutionEnvironmentTests(PyFlinkTestCase): def setUp(self): self.env = StreamExecutionEnvironment.get_execution_environment() self.test_sink = DataStreamTestSinkFunction() def test_get_config(self): execution_config = self.env.get_config() self.assertIsInstance(execution_config, ExecutionConfig) def test_get_set_parallelism(self): self.env.set_parallelism(10) parallelism = self.env.get_parallelism() self.assertEqual(parallelism, 10) def test_get_set_buffer_timeout(self): self.env.set_buffer_timeout(12000) timeout = self.env.get_buffer_timeout() self.assertEqual(timeout, 12000) def test_get_set_default_local_parallelism(self): self.env.set_default_local_parallelism(8) parallelism = self.env.get_default_local_parallelism() self.assertEqual(parallelism, 8) def test_set_get_restart_strategy(self): self.env.set_restart_strategy(RestartStrategies.no_restart()) restart_strategy = self.env.get_restart_strategy() self.assertEqual(restart_strategy, RestartStrategies.no_restart()) def test_add_default_kryo_serializer(self): self.env.add_default_kryo_serializer( "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo", "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer" ) class_dict = self.env.get_config().get_default_kryo_serializer_classes( ) self.assertEqual( class_dict, { 'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo': 'org.apache.flink.runtime.state' '.StateBackendTestBase$CustomKryoTestSerializer' }) def test_register_type_with_kryo_serializer(self): self.env.register_type_with_kryo_serializer( "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo", "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer" ) class_dict = self.env.get_config( ).get_registered_types_with_kryo_serializer_classes() self.assertEqual( class_dict, { 'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo': 'org.apache.flink.runtime.state' '.StateBackendTestBase$CustomKryoTestSerializer' }) def test_register_type(self): self.env.register_type( "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo") type_list = self.env.get_config().get_registered_pojo_types() self.assertEqual( type_list, ['org.apache.flink.runtime.state.StateBackendTestBase$TestPojo']) def test_get_set_max_parallelism(self): self.env.set_max_parallelism(12) parallelism = self.env.get_max_parallelism() self.assertEqual(parallelism, 12) def test_operation_chaining(self): self.assertTrue(self.env.is_chaining_enabled()) self.env.disable_operator_chaining() self.assertFalse(self.env.is_chaining_enabled()) def test_get_checkpoint_config(self): checkpoint_config = self.env.get_checkpoint_config() self.assertIsInstance(checkpoint_config, CheckpointConfig) def test_get_set_checkpoint_interval(self): self.env.enable_checkpointing(30000) interval = self.env.get_checkpoint_interval() self.assertEqual(interval, 30000) def test_get_set_checkpointing_mode(self): mode = self.env.get_checkpointing_mode() self.assertEqual(mode, CheckpointingMode.EXACTLY_ONCE) self.env.enable_checkpointing(30000, CheckpointingMode.AT_LEAST_ONCE) mode = self.env.get_checkpointing_mode() self.assertEqual(mode, CheckpointingMode.AT_LEAST_ONCE) def test_get_state_backend(self): state_backend = self.env.get_state_backend() self.assertIsNone(state_backend) def test_set_state_backend(self): input_backend = MemoryStateBackend() self.env.set_state_backend(input_backend) output_backend = self.env.get_state_backend() self.assertEqual(output_backend._j_memory_state_backend, 
                         input_backend._j_memory_state_backend)

    def test_get_set_stream_time_characteristic(self):
        default_time_characteristic = self.env.get_stream_time_characteristic()
        self.assertEqual(default_time_characteristic, TimeCharacteristic.ProcessingTime)
        self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)
        time_characteristic = self.env.get_stream_time_characteristic()
        self.assertEqual(time_characteristic, TimeCharacteristic.EventTime)

    @unittest.skip("Python API does not support DataStream now. refactor this test later")
    def test_get_execution_plan(self):
        tmp_dir = tempfile.gettempdir()
        source_path = os.path.join(tmp_dir + '/streaming.csv')
        tmp_csv = os.path.join(tmp_dir + '/streaming2.csv')
        field_names = ["a", "b", "c"]
        field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]

        t_env = StreamTableEnvironment.create(self.env)
        csv_source = CsvTableSource(source_path, field_names, field_types)
        t_env.register_table_source("Orders", csv_source)
        t_env.register_table_sink(
            "Results",
            CsvTableSink(field_names, field_types, tmp_csv))
        exec_insert_table(t_env.from_path("Orders"), "Results")

        plan = self.env.get_execution_plan()
        json.loads(plan)

    def test_execute(self):
        tmp_dir = tempfile.gettempdir()
        field_names = ['a', 'b', 'c']
        field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
        t_env = StreamTableEnvironment.create(self.env)
        t_env.register_table_sink(
            'Results',
            CsvTableSink(field_names, field_types,
                         os.path.join('{}/{}.csv'.format(tmp_dir, round(time.time())))))
        execution_result = exec_insert_table(
            t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c']), 'Results')
        self.assertIsNotNone(execution_result.get_job_id())
        self.assertIsNotNone(execution_result.get_net_runtime())
        self.assertEqual(len(execution_result.get_all_accumulator_results()), 0)
        self.assertIsNone(execution_result.get_accumulator_result('accumulator'))
        self.assertIsNotNone(str(execution_result))

    def test_from_collection_without_data_types(self):
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection")
        results = self.test_sink.get_results(True)
        # The user does not specify data types for the input data, so the collected result
        # should be in tuple format, like the inputs.
        expected = ["(1, 'Hi', 'Hello')", "(2, 'Hello', 'Hi')"]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_from_collection_with_data_types(self):
        ds = self.env.from_collection(
            [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
            type_info=Types.ROW([Types.INT(), Types.STRING(), Types.STRING()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection")
        results = self.test_sink.get_results(False)
        # If the user specifies data types for the input data, the collected result should be
        # in row format.
        expected = ['1,Hi,Hello', '2,Hello,Hi']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_add_custom_source(self):
        custom_source = SourceFunction(
            "org.apache.flink.python.util.MyCustomSourceFunction")
        ds = self.env.add_source(custom_source,
                                 type_info=Types.ROW([Types.INT(), Types.STRING()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test add custom source")
        results = self.test_sink.get_results(False)
        expected = ['3,Mike', '1,Marry', '4,Ted', '5,Jack', '0,Bob', '2,Henry']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_read_text_file(self):
        texts = ["Mike", "Marry", "Ted", "Jack", "Bob", "Henry"]
        text_file_path = self.tempdir + '/text_file'
        with open(text_file_path, 'a') as f:
            for text in texts:
                f.write(text)
                f.write('\n')

        ds = self.env.read_text_file(text_file_path)
        ds.add_sink(self.test_sink)
        self.env.execute("test read text file")
        results = self.test_sink.get_results()
        results.sort()
        texts.sort()
        self.assertEqual(texts, results)

    def test_execute_async(self):
        ds = self.env.from_collection(
            [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
            type_info=Types.ROW([Types.INT(), Types.STRING(), Types.STRING()]))
        ds.add_sink(self.test_sink)
        job_client = self.env.execute_async("test execute async")
        job_id = job_client.get_job_id()
        self.assertIsNotNone(job_id)
        execution_result = job_client.get_job_execution_result().result()
        self.assertEqual(str(job_id), str(execution_result.get_job_id()))

    def test_add_python_file(self):
        import uuid
        python_file_dir = os.path.join(self.tempdir,
                                       "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(
            python_file_dir, "test_stream_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n return a + 2")

        def plus_two_map(value):
            from test_stream_dependency_manage_lib import add_two
            return add_two(value)

        self.env.add_python_file(python_file_path)
        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(plus_two_map).add_sink(self.test_sink)
        self.env.execute("test add python file")
        result = self.test_sink.get_results(True)
        expected = ['3', '4', '5', '6', '7']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_set_requirements_without_cached_directory(self):
        import uuid
        requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("cloudpickle==1.2.2")
        self.env.set_python_requirements(requirements_txt_path)

        def check_requirements(i):
            import cloudpickle
            assert os.path.abspath(cloudpickle.__file__).startswith(
                os.environ['_PYTHON_REQUIREMENTS_INSTALL_DIR'])
            return i

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_requirements).add_sink(self.test_sink)
        self.env.execute("test set requirements without cache dir")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_set_requirements_with_cached_directory(self):
        import uuid
        tmp_dir = self.tempdir
        requirements_txt_path = os.path.join(
            tmp_dir, "requirements_txt_" + str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("python-package1==0.0.0")

        requirements_dir_path = os.path.join(
            tmp_dir, "requirements_dir_" + str(uuid.uuid4()))
        os.mkdir(requirements_dir_path)
        package_file_name = "python-package1-0.0.0.tar.gz"
        with open(os.path.join(requirements_dir_path, package_file_name), 'wb') as f:
            import base64
            # This base64 data is encoded from a python package file which includes a
            # "python_package1" module. The module contains a "plus(a, b)" function.
            # The base64 can be recomputed by the following code:
            # base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8")
            # (See the illustrative sketch after this test class for one way to regenerate
            # such a package.)
            f.write(base64.b64decode(
                "H4sICNefrV0C/2Rpc3QvcHl0aG9uLXBhY2thZ2UxLTAuMC4wLnRhcgDtmVtv2jAYhnPtX2H1CrRCY+ckI"
                "XEx7axuUA11u5imyICTRc1JiVnHfv1MKKWjYxwKEdPehws7xkmUfH5f+3PyqfqWpa1cjG5EKFnLbOvfhX"
                "FQTI3nOPPSdavS5Pa8nGMwy3Esi3ke9wyTObbnGNQxamBSKlFQavzUryG8ldG6frpbEGx4yNmDLMp/hPy"
                "P8b+6fNN613vdP1z8XdteG3+ug/17/F3Hcw1qIv5H54NUYiyUaH2SRRllaYeytkl6IpEdujI2yH2XapCQ"
                "wSRJRDHt0OveZa//uUfeZonUvUO5bHo+0ZcoVo9bMhFRvGx9H41kWj447aUsR0WUq+pui8arWKggK5Jli"
                "wGOo/95q79ovXi6/nfyf246Dof/n078fT9KI+X77Xx6BP83bX4Xf5NxT7dz7toO/L8OxjKgeTwpG+KcDp"
                "sdQjWFVJMipYI+o0MCk4X/t2UYtqI0yPabCHb3f861XcD/Ty/+Y5nLdCzT0dSPo/SmbKsf6un+b7KV+Ls"
                "W4/D/OoC9w/930P9eGwM75//csrD+Q/6P/P/k9D/oX3988Wqw1bS/tf6tR+s/m3EG/ddBqXO9XKf15C8p"
                "P9k4HZBtBgzZaVW5vrfKcj+W32W82ygEB9D/Xu9+4/qfP9L/rBv0X1v87yONKRX61/qfzwqjIDzIPTbv/"
                "7or3/88i0H/tfBFW7s/s/avRInQH06ieEy7tDrQeYHUdRN7wP+n/vf62LOH/pld7f9xz7a5Pfufedy0oP"
                "86iJI8KxStAq6yLC4JWdbbVbWRikR2z1ZGytk5vauW3QdnBFE6XqwmykazCesAAAAAAAAAAAAAAAAAAAA"
                "AAAAAAAAAAAAAAOBw/AJw5CHBAFAAAA=="))
        self.env.set_python_requirements(requirements_txt_path, requirements_dir_path)

        def add_one(i):
            from python_package1 import plus
            return plus(i, 1)

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_one).add_sink(self.test_sink)
        self.env.execute("test set requirements with cached dir")
        result = self.test_sink.get_results(True)
        expected = ['2', '3', '4', '5', '6']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_add_python_archive(self):
        import uuid
        import shutil
        tmp_dir = self.tempdir
        archive_dir_path = os.path.join(tmp_dir, "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("2")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        self.env.add_python_archive(archive_file_path, "data")

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i + int(f.read())

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_from_file).add_sink(self.test_sink)
        self.env.execute("test set python archive")
        result = self.test_sink.get_results(True)
        expected = ['3', '4', '5', '6', '7']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    @unittest.skipIf(on_windows(), "Symbolic link is not supported on Windows, skipping.")
    def test_set_stream_env(self):
        import sys
        python_exec = sys.executable
        tmp_dir = self.tempdir
        python_exec_link_path = os.path.join(tmp_dir, "py_exec")
        os.symlink(python_exec, python_exec_link_path)
        self.env.set_python_executable(python_exec_link_path)

        def check_python_exec(i):
            import os
            assert os.environ["python"] == python_exec_link_path
            return i

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_python_exec).add_sink(self.test_sink)
        self.env.execute("test set python executable")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_generate_stream_graph_with_dependencies(self):
        python_file_dir = os.path.join(self.tempdir,
                                       "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(
            python_file_dir, "test_stream_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n return a + 2")
        self.env.add_python_file(python_file_path)

        def plus_two_map(value):
            from test_stream_dependency_manage_lib import add_two
            return value[0], add_two(value[1])

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i[0], i[1] + int(f.read())

        from_collection_source = self.env.from_collection(
            [('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
            type_info=Types.ROW([Types.STRING(), Types.INT()]))
        from_collection_source.name("From Collection")
        keyed_stream = from_collection_source.key_by(lambda x: x[1], key_type_info=Types.INT())
        plus_two_map_stream = keyed_stream.map(plus_two_map).name(
            "Plus Two Map").set_parallelism(3)
        add_from_file_map = plus_two_map_stream.map(add_from_file).name(
            "Add From File Map")
        test_stream_sink = add_from_file_map.add_sink(
            self.test_sink).name("Test Sink")
        test_stream_sink.set_parallelism(4)

        archive_dir_path = os.path.join(self.tempdir, "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("3")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        self.env.add_python_archive(archive_file_path, "data")

        nodes = eval(self.env.get_execution_plan())['nodes']

        # The StreamGraph should be as below:
        # Source: From Collection -> _stream_key_by_map_operator -> _keyed_stream_values_operator
        # -> Plus Two Map -> Add From File Map -> Sink: Test Sink.

        # Source: From Collection and _stream_key_by_map_operator should have the same parallelism.
        self.assertEqual(nodes[0]['parallelism'], nodes[1]['parallelism'])

        # _keyed_stream_values_operator and Plus Two Map should have the same parallelism.
        self.assertEqual(nodes[3]['parallelism'], 3)
        self.assertEqual(nodes[2]['parallelism'], nodes[3]['parallelism'])

        # The ship_strategy between Source: From Collection and _stream_key_by_map_operator
        # should be FORWARD.
        self.assertEqual(nodes[1]['predecessors'][0]['ship_strategy'], "FORWARD")

        # The ship_strategy between _keyed_stream_values_operator and Plus Two Map
        # should be FORWARD.
        self.assertEqual(nodes[3]['predecessors'][0]['ship_strategy'], "FORWARD")

        # The parallelism of Sink: Test Sink should be 4.
        self.assertEqual(nodes[5]['parallelism'], 4)

        env_config_with_dependencies = dict(
            get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
            .getEnvConfigWithDependencies(self.env._j_stream_execution_environment).toMap())

        # Make sure that user-specified files and archives are correctly added.
        self.assertIsNotNone(env_config_with_dependencies['python.files'])
        self.assertIsNotNone(env_config_with_dependencies['python.archives'])

    def tearDown(self) -> None:
        self.test_sink.clear()
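

# Illustrative sketch, not part of the original test suite: one possible way to regenerate the
# base64-encoded "python-package1" tarball used in test_set_requirements_with_cached_directory
# above. The package layout and setup.py contents below are assumptions; any source distribution
# that installs a "python_package1" module exposing plus(a, b) would work. The returned string
# corresponds to the snippet quoted in that test's comment:
# base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8")
def _regenerate_python_package1_base64(work_dir):
    import base64
    import os
    import tarfile

    # Lay out a minimal package: python-package1-0.0.0/{setup.py, python_package1/__init__.py}.
    package_root = os.path.join(work_dir, "python-package1-0.0.0")
    module_dir = os.path.join(package_root, "python_package1")
    os.makedirs(module_dir)
    with open(os.path.join(module_dir, "__init__.py"), "w") as f:
        f.write("def plus(a, b):\n    return a + b\n")
    with open(os.path.join(package_root, "setup.py"), "w") as f:
        f.write("from setuptools import setup\n"
                "setup(name='python-package1', version='0.0.0', packages=['python_package1'])\n")

    # Pack the sources into a tar.gz and dump it as base64 so it can be pasted into the test.
    tar_path = os.path.join(work_dir, "python-package1-0.0.0.tar.gz")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(package_root, arcname="python-package1-0.0.0")
    with open(tar_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")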
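

# Illustrative sketch, not part of the original test suite: get_execution_plan() returns a JSON
# string, so it can also be parsed with json.loads (as test_get_execution_plan does) instead of
# eval. Only the 'nodes', 'parallelism', 'predecessors' and 'ship_strategy' fields already relied
# on by test_generate_stream_graph_with_dependencies are assumed here.
def _summarize_execution_plan(env):
    import json

    plan = json.loads(env.get_execution_plan())
    summary = []
    for node in plan['nodes']:
        # Collect each node's parallelism and the ship strategies of its incoming edges.
        ship_strategies = [p['ship_strategy'] for p in node.get('predecessors', [])]
        summary.append((node['parallelism'], ship_strategies))
    return summary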