def test_set_environment(self):
    """Check the configured Python executable and the gateway guard inside UDFs.

    A symlink to the running interpreter is created in ``self.tempdir`` and
    set as the Python executable of the table config.  Two UDFs then verify,
    when they are evaluated, that the ``python`` environment variable equals
    the symlink path and that ``get_gateway()`` fails with the expected
    message.
    """
    link_path = os.path.join(self.tempdir, "py_exec")
    os.symlink(sys.executable, link_path)
    self.t_env.get_config().set_python_executable(link_path)

    def check_python_exec(i):
        # Imported locally so the function body is self-contained.
        import os
        assert os.environ["python"] == link_path
        return i

    def check_pyflink_gateway_disabled(i):
        try:
            from pyflink.java_gateway import get_gateway
            get_gateway()
        except Exception as e:
            assert str(e).startswith(
                "It's launching the PythonGatewayServer during Python UDF"
                " execution which is unexpected.")
            return i
        # Reaching this point means get_gateway() did not raise.
        raise Exception("The gateway server is not disabled!")

    checks = (
        ("check_python_exec", check_python_exec),
        ("check_pyflink_gateway_disabled", check_pyflink_gateway_disabled),
    )
    for function_name, function in checks:
        self.t_env.create_temporary_system_function(
            function_name,
            udf(function, DataTypes.BIGINT(), DataTypes.BIGINT()))

    sink = source_sink_utils.TestAppendSink(
        ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
    self.t_env.register_table_sink("Results", sink)

    t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
    checked = t.select(
        expr.call('check_python_exec', t.a),
        expr.call('check_pyflink_gateway_disabled', t.a))
    checked.execute_insert("Results").wait()

    self.assert_equals(
        source_sink_utils.results(),
        ["+I[1, 1]", "+I[2, 2]", "+I[3, 3]"])
def test_basic(self):
    """Build a TableDescriptor and check schema, partition keys, options and comment."""
    schema = (
        Schema.new_builder()
        .column("f0", DataTypes.STRING())
        .column("f1", DataTypes.BIGINT())
        .primary_key("f0")
        .build()
    )
    descriptor = (
        TableDescriptor.for_connector("test-connector")
        .schema(schema)
        .partitioned_by("f0")
        .comment("Test Comment")
        .build()
    )

    self.assertIsNotNone(descriptor.get_schema())

    partition_keys = descriptor.get_partition_keys()
    self.assertEqual(1, len(partition_keys))
    self.assertEqual("f0", partition_keys[0])

    options = descriptor.get_options()
    self.assertEqual(1, len(options))
    self.assertEqual("test-connector", options.get("connector"))

    self.assertEqual("Test Comment", descriptor.get_comment())
def test_chaining_scalar_function(self):
    """Run a query that nests three registered scalar functions and check the sink output."""
    add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
    subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())
    registrations = (
        ("add_one", add_one),
        ("subtract_one", subtract_one),
        ("add", add),
    )
    for function_name, function in registrations:
        self.t_env.register_function(function_name, function)

    sink = source_sink_utils.TestAppendSink(
        ['a', 'b', 'c'],
        [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.INT()])
    self.t_env.register_table_sink("Results", sink)

    rows = [(1, 2, 1), (2, 5, 2), (3, 1, 3)]
    source = self.t_env.from_elements(rows, ['a', 'b', 'c'])
    projected = source.select("add(add_one(a), subtract_one(b)), c, 1")
    projected.insert_into("Results")
    self.t_env.execute("test")

    self.assert_equals(source_sink_utils.results(), ["3,1,1", "7,2,1", "4,3,1"])
def test_table_function(self):
    """Chain three lateral joins over table functions and check the collected rows."""
    sink_columns = ['a', 'b', 'c']
    self._register_table_sink(
        sink_columns, [DataTypes.BIGINT() for _ in sink_columns])

    multi_emit = udtf(
        MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
    multi_num = udf(MultiNum(), result_type=DataTypes.BIGINT())

    source = self.t_env.from_elements(
        [(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])

    # Inner lateral join: produces the columns x and y.
    emitted = source.join_lateral(
        multi_emit(source.a, multi_num(source.b)).alias('x', 'y'))

    # Left outer lateral join: adds m.
    conditioned = emitted.left_outer_join_lateral(
        condition_multi_emit(emitted.x, emitted.y).alias('m'))
    conditioned = conditioned.select("x, y, m")

    # Left outer lateral join through identity: m is passed on as n.
    final = conditioned.left_outer_join_lateral(
        identity(conditioned.m).alias('n'))
    final = final.select("x, y, n")

    self.assert_equals(
        self._get_output(final),
        ["1,0,null", "1,1,null", "2,0,null", "2,1,null", "3,0,0", "3,0,1",
         "3,0,2", "3,1,1", "3,1,2", "3,2,2", "3,3,null"])
def intersect_batch():
    """Intersect two in-memory batch tables and write the result to a CSV file.

    The output goes to ``/tmp/table_intersect_batch.csv``; an existing file at
    that path is removed first.
    """
    result_file = "/tmp/table_intersect_batch.csv"
    columns = ["a", "b", "c"]
    left_rows = [(1, "ra", "raa"), (2, "lb", "lbb"), (3, "", "lcc"), (1, "ra", "raa")]
    right_rows = [(1, "ra", "raa"), (2, "", "rbb"), (3, "rc", "rcc"), (1, "ra", "raa")]

    b_env = ExecutionEnvironment.get_execution_environment()
    b_env.set_parallelism(1)
    bt_env = BatchTableEnvironment.create(b_env)

    # Remove the output of a previous run.
    if os.path.exists(result_file):
        os.remove(result_file)

    left = bt_env.from_elements(left_rows, columns).select("a, b, c")
    right = bt_env.from_elements(right_rows, columns).select("a, b, c")

    sink = CsvTableSink(
        ["a", "b", "c"],
        [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()],
        result_file)
    bt_env.register_table_sink("result", sink)

    left.intersect(right).insert_into("result")
    bt_env.execute("intersect batch")
def aggregate_func_python_sql_api():
    """Call a registered Java aggregate function from SQL and write the result to CSV.

    The output goes to ``/tmp/aggregate_func_python_sql_api.csv``; an existing
    file at that path is removed first.
    """
    result_file = "/tmp/aggregate_func_python_sql_api.csv"
    query = "SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user"

    b_env = ExecutionEnvironment.get_execution_environment()
    b_env.set_parallelism(1)
    bt_env = BatchTableEnvironment.create(b_env)

    scores = [("a", 1, 1), ("a", 2, 2), ("b", 3, 2), ("a", 5, 2)]
    source_table = bt_env.from_elements(scores, ["user", "points", "level"])

    # Remove the output of a previous run.
    if os.path.exists(result_file):
        os.remove(result_file)

    sink = CsvTableSink(
        ["a", "b"],
        [DataTypes.STRING(), DataTypes.BIGINT()],
        result_file)
    bt_env.register_table_sink("result", sink)

    bt_env.register_java_function("wAvg", "com.pyflink.table.WeightedAvg")
    bt_env.register_table("userScores", source_table)

    bt_env.sql_query(query).insert_into("result")
    bt_env.execute("aggregate func python sql api")
def test_explain_with_multi_sinks(self):
    """Explain a statement set with two INSERTs and check that a string is returned."""
    t_env = self.t_env
    field_names = ["a", "b", "c"]
    field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]

    source = t_env.from_elements(
        [(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])

    for sink_name, path in (("sink1", "path1"), ("sink2", "path2")):
        t_env.register_table_sink(
            sink_name, CsvTableSink(field_names, field_types, path))

    stmt_set = t_env.create_statement_set()
    stmt_set.add_insert_sql(
        "insert into sink1 select * from %s where a > 100" % source)
    stmt_set.add_insert_sql(
        "insert into sink2 select * from %s where a < 100" % source)

    plan = stmt_set.explain(
        ExplainDetail.ESTIMATED_COST, ExplainDetail.CHANGELOG_MODE)
    self.assertIsInstance(plan, str)
def test_where(self):
    """Filter a CSV-backed table with ``where`` and check the rows that reach the sink."""
    # Join directory and file name as separate components instead of
    # concatenating them with a hard-coded '/' separator.
    source_path = os.path.join(self.tempdir, 'streaming.csv')
    field_names = ["a", "b", "c"]
    field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
    data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
    csv_source = self.prepare_csv_source(source_path, data, field_types, field_names)

    t_env = self.t_env
    t_env.register_table_source("Source", csv_source)
    source = t_env.scan("Source")
    t_env.register_table_sink(
        "Results",
        field_names, field_types, source_sink_utils.TestAppendSink())

    result = source.where("a > 1 && b = 'Hello'")
    result.insert_into("Results")
    t_env.execute()

    actual = source_sink_utils.results()
    expected = ['2,Hello,Hello']
    self.assert_equals(actual, expected)
def test_explain_with_multi_sinks_with_blink_planner(self):
    """Explain two buffered INSERTs in a blink batch environment and check the result type."""
    settings = (
        EnvironmentSettings.new_instance()
        .in_batch_mode()
        .use_blink_planner()
        .build()
    )
    t_env = BatchTableEnvironment.create(environment_settings=settings)

    field_names = ["a", "b", "c"]
    field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
    source = t_env.from_elements(
        [(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])

    for sink_name, path in (("sink1", "path1"), ("sink2", "path2")):
        t_env.register_table_sink(
            sink_name, CsvTableSink(field_names, field_types, path))

    t_env.sql_update("insert into sink1 select * from %s where a > 100" % source)
    t_env.sql_update("insert into sink2 select * from %s where a < 100" % source)

    plan = t_env.explain(extended=True)
    self.assertIsInstance(plan, (str, unicode))
def test_execute(self):
    """Insert one row into a CSV sink and check the returned execution result."""
    tmp_dir = tempfile.gettempdir()
    field_names = ['a', 'b', 'c']
    field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
    t_env = StreamTableEnvironment.create(self.env)

    # Join directory and file name as separate components instead of
    # formatting them into one string with a hard-coded '/' separator.
    sink_path = os.path.join(tmp_dir, '{}.csv'.format(round(time.time())))
    t_env.register_table_sink(
        'Results',
        CsvTableSink(field_names, field_types, sink_path))

    execution_result = exec_insert_table(
        t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c']),
        'Results')

    self.assertIsNotNone(execution_result.get_job_id())
    self.assertIsNotNone(execution_result.get_net_runtime())
    self.assertEqual(len(execution_result.get_all_accumulator_results()), 0)
    self.assertIsNone(
        execution_result.get_accumulator_result('accumulator'))
    self.assertIsNotNone(str(execution_result))
def custom_test_sink_demo():
    """Left-outer-join two in-memory tables and write the result to a custom connector sink."""
    s_env = StreamExecutionEnvironment.get_execution_environment()
    s_env.set_parallelism(1)
    st_env = StreamTableEnvironment.create(s_env)

    left_rows = [(1, "1a", "1laa"), (2, "2a", "2aa"), (3, None, "3aa"),
                 (2, "4b", "4bb"), (5, "5a", "5aa")]
    right_rows = [(1, "1b", "1bb"), (2, None, "2bb"), (1, "3b", "3bb"),
                  (4, "4b", "4bb")]
    left = st_env.from_elements(left_rows, ["a", "b", "c"]).select("a, b, c")
    right = st_env.from_elements(right_rows, ["d", "e", "f"]).select("d, e, f")

    joined = left.left_outer_join(right, "a = d")
    result = joined.select("a, b, e")

    # use custom retract sink connector
    custom_connector = CustomConnectorDescriptor('pyflink-test', 1, False)
    sink_schema = (
        Schema()
        .field("a", DataTypes.BIGINT())
        .field("b", DataTypes.STRING())
        .field("c", DataTypes.STRING())
    )
    st_env.connect(custom_connector) \
        .with_schema(sink_schema) \
        .register_table_sink("sink")

    result.insert_into("sink")
    st_env.execute("custom test sink demo")
def test_explain_with_multi_sinks(self):
    """Explain two buffered INSERT statements and check that a string is returned."""
    t_env = self.t_env
    source = t_env.from_elements(
        [(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
    field_names = ["a", "b", "c"]
    field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]

    t_env.register_table_sink(
        "sink1",
        source_sink_utils.TestAppendSink(field_names, field_types))
    t_env.register_table_sink(
        "sink2",
        source_sink_utils.TestAppendSink(field_names, field_types))

    t_env.sql_update("insert into sink1 select * from %s where a > 100" % source)
    t_env.sql_update("insert into sink2 select * from %s where a < 100" % source)

    actual = t_env.explain(extended=True)
    # Use the TestCase assertion rather than a bare ``assert``: it is not
    # stripped under ``python -O`` and reports the offending value on failure.
    self.assertIsInstance(actual, str)
def test_statement_set(self):
    """Mix ``add_insert_sql`` and ``add_insert`` in a statement set and explain it."""
    t_env = self.t_env
    source = t_env.from_elements(
        [(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
    field_names = ["a", "b", "c"]
    field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]

    t_env.register_table_sink(
        "sink1",
        source_sink_utils.TestAppendSink(field_names, field_types))
    t_env.register_table_sink(
        "sink2",
        source_sink_utils.TestAppendSink(field_names, field_types))

    stmt_set = t_env.create_statement_set()
    stmt_set.add_insert_sql("insert into sink1 select * from %s where a > 100" % source) \
        .add_insert("sink2", source.filter("a < 100"), False)

    actual = stmt_set.explain(ExplainDetail.CHANGELOG_MODE)
    # Use the TestCase assertion rather than a bare ``assert``: it is not
    # stripped under ``python -O`` and reports the offending value on failure.
    self.assertIsInstance(actual, str)
def test_from_element_expression(self):
    """Create a table from ``row(...)`` expressions with an explicit schema and check the sink."""
    t_env = self.t_env
    field_names = ["a", "b", "c"]
    field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.FLOAT()]

    # Pair every column name with its type to form the row schema.
    schema = DataTypes.ROW([
        DataTypes.FIELD(name, data_type)
        for name, data_type in zip(field_names, field_types)
    ])

    t_env.register_table_sink(
        "Results", source_sink_utils.TestAppendSink(field_names, field_types))

    table = t_env.from_elements(
        [row(1, 'abc', 2.0), row(2, 'def', 3.0)], schema)
    table.execute_insert("Results").wait()

    self.assert_equals(
        source_sink_utils.results(), ['1,abc,2.0', '2,def,3.0'])
def test_set_environment(self):
    """Check the configured Python executable and the gateway guard inside UDFs.

    The test is skipped when ``os.symlink`` is unavailable.  Otherwise a
    symlink to the running interpreter is set as the Python executable of the
    table config, and two UDFs verify, when they are evaluated, that the
    ``python`` environment variable equals the symlink path and that
    ``get_gateway()`` fails with the expected message.
    """
    if not hasattr(os, "symlink"):
        # The message now names this test; it previously referred to a
        # non-existent 'test_set_python_exec'.
        self.skipTest("Symbolic link is not supported, skip testing 'test_set_environment'...")

    python_exec = sys.executable
    tmp_dir = self.tempdir
    python_exec_link_path = os.path.join(tmp_dir, "py_exec")
    os.symlink(python_exec, python_exec_link_path)
    self.t_env.get_config().set_python_executable(python_exec_link_path)

    def check_python_exec(i):
        # Imported locally so the function body is self-contained.
        import os
        assert os.environ["python"] == python_exec_link_path
        return i

    self.t_env.register_function(
        "check_python_exec",
        udf(check_python_exec, DataTypes.BIGINT(), DataTypes.BIGINT()))

    def check_pyflink_gateway_disabled(i):
        try:
            from pyflink.java_gateway import get_gateway
            get_gateway()
        except Exception as e:
            assert str(e).startswith("It's launching the PythonGatewayServer during Python UDF"
                                     " execution which is unexpected.")
        else:
            # get_gateway() returning normally is the failure case here.
            raise Exception("The gateway server is not disabled!")
        return i

    self.t_env.register_function(
        "check_pyflink_gateway_disabled",
        udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(), DataTypes.BIGINT()))

    table_sink = source_sink_utils.TestAppendSink(
        ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
    self.t_env.register_table_sink("Results", table_sink)

    t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
    t.select("check_python_exec(a), check_pyflink_gateway_disabled(a)").insert_into("Results")
    self.t_env.execute("test")

    actual = source_sink_utils.results()
    self.assert_equals(actual, ["1,1", "2,2", "3,3"])
def test_table_function_with_sql_query(self):
    """Use a registered table function in a SQL LEFT JOIN LATERAL and check the sink output."""
    sink_columns = ['a', 'b', 'c']
    sink = source_sink_utils.TestAppendSink(
        sink_columns, [DataTypes.BIGINT() for _ in sink_columns])
    self.t_env.register_table_sink("Results", sink)

    multi_emit = udtf(
        MultiEmit(),
        [DataTypes.BIGINT(), DataTypes.BIGINT()],
        [DataTypes.BIGINT(), DataTypes.BIGINT()])
    self.t_env.register_function("multi_emit", multi_emit)

    source = self.t_env.from_elements(
        [(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
    self.t_env.register_table("MyTable", source)

    query = (
        "SELECT a, x, y FROM MyTable LEFT JOIN LATERAL TABLE(multi_emit(a, b)) as T(x, y)"
        " ON TRUE"
    )
    self.t_env.sql_query(query).insert_into("Results")
    self.t_env.execute("test")

    self.assert_equals(
        source_sink_utils.results(), ["1,1,0", "2,2,0", "3,3,0", "3,3,1"])
def test_basic_functionality(self):
    """Combine pandas UDFs and a general Python UDF in one query and check the sink output."""
    # pandas UDF
    pandas_add_one = udf(
        lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="pandas")
    self.t_env.create_temporary_system_function("add_one", pandas_add_one)
    self.t_env.create_temporary_system_function("add", add)

    # general Python UDF
    general_subtract_one = udf(
        SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())
    self.t_env.create_temporary_system_function(
        "subtract_one", general_subtract_one)

    sink_columns = ['a', 'b', 'c', 'd']
    sink = source_sink_utils.TestAppendSink(
        sink_columns, [DataTypes.BIGINT() for _ in sink_columns])
    self.t_env.register_table_sink("Results", sink)

    t = self.t_env.from_elements(
        [(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])

    filtered = t.where(E.call('add_one', t.b) <= 3)
    projected = filtered.select(
        "a, b + 1, add(a + 1, subtract_one(c)) + 2, add(add_one(a), 1L)")
    exec_insert_table(projected, "Results")

    self.assert_equals(source_sink_utils.results(), ["1,3,6,3", "3,2,14,5"])
def test_table_function(self):
    """Chain an inner and a left outer lateral join via string expressions and check the rows."""
    sink_columns = ['a', 'b', 'c']
    self._register_table_sink(
        sink_columns, [DataTypes.BIGINT() for _ in sink_columns])

    multi_emit = udtf(
        MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
    multi_num = udf(MultiNum(), result_type=DataTypes.BIGINT())
    registrations = (
        ("multi_emit", multi_emit),
        ("condition_multi_emit", condition_multi_emit),
        ("multi_num", multi_num),
    )
    for function_name, function in registrations:
        self.t_env.register_function(function_name, function)

    source = self.t_env.from_elements(
        [(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
    emitted = source.join_lateral("multi_emit(a, multi_num(b)) as (x, y)")
    conditioned = emitted.left_outer_join_lateral(
        "condition_multi_emit(x, y) as m")
    result = conditioned.select("x, y, m")

    self.assert_equals(
        self._get_output(result),
        ["1,0,null", "1,1,null", "2,0,null", "2,1,null", "3,0,0", "3,0,1",
         "3,0,2", "3,1,1", "3,1,2", "3,2,2", "3,3,null"])
def test_set_environment(self):
    """Check the configured Python executable from inside UDFs on ``self.st_env``.

    A symlink to the running interpreter is set as the Python executable of
    the table config.  One UDF asserts that the ``python`` environment
    variable equals the symlink path; the other calls ``get_gateway()`` and
    returns its input.
    """
    link_path = os.path.join(self.tempdir, "py_exec")
    os.symlink(sys.executable, link_path)
    self.st_env.get_config().set_python_executable(link_path)

    def check_python_exec(i):
        # Imported locally so the function body is self-contained.
        import os
        assert os.environ["python"] == link_path
        return i

    def check_pyflink_gateway_disabled(i):
        from pyflink.java_gateway import get_gateway
        get_gateway()
        return i

    checks = (
        ("check_python_exec", check_python_exec),
        ("check_pyflink_gateway_disabled", check_pyflink_gateway_disabled),
    )
    for function_name, function in checks:
        self.st_env.create_temporary_system_function(
            function_name,
            udf(function, DataTypes.BIGINT(), DataTypes.BIGINT()))

    sink = source_sink_utils.TestAppendSink(
        ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
    self.st_env.register_table_sink("Results", sink)

    t = self.st_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
    checked = t.select(
        expr.call('check_python_exec', t.a),
        expr.call('check_pyflink_gateway_disabled', t.a))
    checked.execute_insert("Results").wait()

    self.assert_equals(
        source_sink_utils.results(),
        ["+I[1, 1]", "+I[2, 2]", "+I[3, 3]"])
def test_all_data_types(self):
    """Pass one row covering many data types through pandas UDFs and check the sink output.

    Each UDF below asserts the Python-side type of its argument (a
    ``pd.Series`` for scalar and array columns, a ``pd.DataFrame`` for the row
    column) and the type of its first element; a few also assert the value.
    All of them return their input unchanged except ``nested_array_func``,
    which returns the first element wrapped in a new ``pd.Series``.
    """
    import pandas as pd
    import numpy as np

    @udf(result_type=DataTypes.TINYINT(), func_type="pandas")
    def tinyint_func(tinyint_param):
        assert isinstance(tinyint_param, pd.Series)
        assert isinstance(tinyint_param[0], np.int8), \
            'tinyint_param of wrong type %s !' % type(tinyint_param[0])
        return tinyint_param

    @udf(result_type=DataTypes.SMALLINT(), func_type="pandas")
    def smallint_func(smallint_param):
        assert isinstance(smallint_param, pd.Series)
        assert isinstance(smallint_param[0], np.int16), \
            'smallint_param of wrong type %s !' % type(smallint_param[0])
        assert smallint_param[0] == 32767, 'smallint_param of wrong value %s' % smallint_param
        return smallint_param

    @udf(result_type=DataTypes.INT(), func_type="pandas")
    def int_func(int_param):
        assert isinstance(int_param, pd.Series)
        assert isinstance(int_param[0], np.int32), \
            'int_param of wrong type %s !' % type(int_param[0])
        assert int_param[0] == -2147483648, 'int_param of wrong value %s' % int_param
        return int_param

    @udf(result_type=DataTypes.BIGINT(), func_type="pandas")
    def bigint_func(bigint_param):
        assert isinstance(bigint_param, pd.Series)
        assert isinstance(bigint_param[0], np.int64), \
            'bigint_param of wrong type %s !' % type(bigint_param[0])
        return bigint_param

    @udf(result_type=DataTypes.BOOLEAN(), func_type="pandas")
    def boolean_func(boolean_param):
        assert isinstance(boolean_param, pd.Series)
        assert isinstance(boolean_param[0], np.bool_), \
            'boolean_param of wrong type %s !' % type(boolean_param[0])
        return boolean_param

    @udf(result_type=DataTypes.FLOAT(), func_type="pandas")
    def float_func(float_param):
        assert isinstance(float_param, pd.Series)
        assert isinstance(float_param[0], np.float32), \
            'float_param of wrong type %s !' % type(float_param[0])
        return float_param

    @udf(result_type=DataTypes.DOUBLE(), func_type="pandas")
    def double_func(double_param):
        assert isinstance(double_param, pd.Series)
        assert isinstance(double_param[0], np.float64), \
            'double_param of wrong type %s !' % type(double_param[0])
        return double_param

    @udf(result_type=DataTypes.STRING(), func_type="pandas")
    def varchar_func(varchar_param):
        assert isinstance(varchar_param, pd.Series)
        assert isinstance(varchar_param[0], str), \
            'varchar_param of wrong type %s !' % type(varchar_param[0])
        return varchar_param

    @udf(result_type=DataTypes.BYTES(), func_type="pandas")
    def varbinary_func(varbinary_param):
        assert isinstance(varbinary_param, pd.Series)
        assert isinstance(varbinary_param[0], bytes), \
            'varbinary_param of wrong type %s !' % type(varbinary_param[0])
        return varbinary_param

    @udf(result_type=DataTypes.DECIMAL(38, 18), func_type="pandas")
    def decimal_func(decimal_param):
        assert isinstance(decimal_param, pd.Series)
        assert isinstance(decimal_param[0], decimal.Decimal), \
            'decimal_param of wrong type %s !' % type(decimal_param[0])
        return decimal_param

    @udf(result_type=DataTypes.DATE(), func_type="pandas")
    def date_func(date_param):
        assert isinstance(date_param, pd.Series)
        assert isinstance(date_param[0], datetime.date), \
            'date_param of wrong type %s !' % type(date_param[0])
        return date_param

    @udf(result_type=DataTypes.TIME(), func_type="pandas")
    def time_func(time_param):
        assert isinstance(time_param, pd.Series)
        assert isinstance(time_param[0], datetime.time), \
            'time_param of wrong type %s !' % type(time_param[0])
        return time_param

    # Shared by the input row and by timestamp_func's value assertion.
    timestamp_value = datetime.datetime(1970, 1, 2, 0, 0, 0, 123000)

    @udf(result_type=DataTypes.TIMESTAMP(3), func_type="pandas")
    def timestamp_func(timestamp_param):
        assert isinstance(timestamp_param, pd.Series)
        assert isinstance(timestamp_param[0], datetime.datetime), \
            'timestamp_param of wrong type %s !' % type(timestamp_param[0])
        assert timestamp_param[0] == timestamp_value, \
            'timestamp_param is wrong value %s, should be %s!' % (timestamp_param[0],
                                                                  timestamp_value)
        return timestamp_param

    # Plain function: wrapped three times below, once per array element type.
    def array_func(array_param):
        assert isinstance(array_param, pd.Series)
        assert isinstance(array_param[0], np.ndarray), \
            'array_param of wrong type %s !' % type(array_param[0])
        return array_param

    array_str_func = udf(array_func,
                         result_type=DataTypes.ARRAY(DataTypes.STRING()),
                         func_type="pandas")

    array_timestamp_func = udf(array_func,
                               result_type=DataTypes.ARRAY(DataTypes.TIMESTAMP(3)),
                               func_type="pandas")

    array_int_func = udf(array_func,
                         result_type=DataTypes.ARRAY(DataTypes.INT()),
                         func_type="pandas")

    @udf(result_type=DataTypes.ARRAY(DataTypes.STRING()), func_type="pandas")
    def nested_array_func(nested_array_param):
        assert isinstance(nested_array_param, pd.Series)
        assert isinstance(nested_array_param[0], np.ndarray), \
            'nested_array_param of wrong type %s !' % type(nested_array_param[0])
        # Returns only the first element, wrapped in a new Series.
        return pd.Series(nested_array_param[0])

    # Used as the result type of row_func, the last sink column and field "u".
    row_type = DataTypes.ROW(
        [DataTypes.FIELD("f1", DataTypes.INT()),
         DataTypes.FIELD("f2", DataTypes.STRING()),
         DataTypes.FIELD("f3", DataTypes.TIMESTAMP(3)),
         DataTypes.FIELD("f4", DataTypes.ARRAY(DataTypes.INT()))])

    @udf(result_type=row_type, func_type="pandas")
    def row_func(row_param):
        # The row column is asserted to be a DataFrame with one Series per field.
        assert isinstance(row_param, pd.DataFrame)
        assert isinstance(row_param.f1, pd.Series)
        assert isinstance(row_param.f1[0], np.int32), \
            'row_param.f1 of wrong type %s !' % type(row_param.f1[0])
        assert isinstance(row_param.f2, pd.Series)
        assert isinstance(row_param.f2[0], str), \
            'row_param.f2 of wrong type %s !' % type(row_param.f2[0])
        assert isinstance(row_param.f3, pd.Series)
        assert isinstance(row_param.f3[0], datetime.datetime), \
            'row_param.f3 of wrong type %s !' % type(row_param.f3[0])
        assert isinstance(row_param.f4, pd.Series)
        assert isinstance(row_param.f4[0], np.ndarray), \
            'row_param.f4 of wrong type %s !' % type(row_param.f4[0])
        return row_param

    # Sink with one column per UDF call in the select below (21 columns).
    table_sink = source_sink_utils.TestAppendSink(
        ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
         'r', 's', 't', 'u'],
        [DataTypes.TINYINT(), DataTypes.SMALLINT(), DataTypes.INT(), DataTypes.BIGINT(),
         DataTypes.BOOLEAN(), DataTypes.BOOLEAN(), DataTypes.FLOAT(), DataTypes.DOUBLE(),
         DataTypes.STRING(), DataTypes.STRING(), DataTypes.BYTES(),
         DataTypes.DECIMAL(38, 18), DataTypes.DECIMAL(38, 18), DataTypes.DATE(),
         DataTypes.TIME(), DataTypes.TIMESTAMP(3), DataTypes.ARRAY(DataTypes.STRING()),
         DataTypes.ARRAY(DataTypes.TIMESTAMP(3)), DataTypes.ARRAY(DataTypes.INT()),
         DataTypes.ARRAY(DataTypes.STRING()), row_type])
    self.t_env.register_table_sink("Results", table_sink)

    # A single input row; field "t" is declared as a nested array of strings.
    t = self.t_env.from_elements(
        [(1, 32767, -2147483648, 1, True, False, 1.0, 1.0, 'hello', '中文',
          bytearray(b'flink'), decimal.Decimal('1000000000000000000.05'),
          decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
          datetime.date(2014, 9, 13), datetime.time(hour=1, minute=0, second=1),
          timestamp_value, ['hello', '中文', None], [timestamp_value], [1, 2],
          [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]))],
        DataTypes.ROW(
            [DataTypes.FIELD("a", DataTypes.TINYINT()),
             DataTypes.FIELD("b", DataTypes.SMALLINT()),
             DataTypes.FIELD("c", DataTypes.INT()),
             DataTypes.FIELD("d", DataTypes.BIGINT()),
             DataTypes.FIELD("e", DataTypes.BOOLEAN()),
             DataTypes.FIELD("f", DataTypes.BOOLEAN()),
             DataTypes.FIELD("g", DataTypes.FLOAT()),
             DataTypes.FIELD("h", DataTypes.DOUBLE()),
             DataTypes.FIELD("i", DataTypes.STRING()),
             DataTypes.FIELD("j", DataTypes.STRING()),
             DataTypes.FIELD("k", DataTypes.BYTES()),
             DataTypes.FIELD("l", DataTypes.DECIMAL(38, 18)),
             DataTypes.FIELD("m", DataTypes.DECIMAL(38, 18)),
             DataTypes.FIELD("n", DataTypes.DATE()),
             DataTypes.FIELD("o", DataTypes.TIME()),
             DataTypes.FIELD("p", DataTypes.TIMESTAMP(3)),
             DataTypes.FIELD("q", DataTypes.ARRAY(DataTypes.STRING())),
             DataTypes.FIELD("r", DataTypes.ARRAY(DataTypes.TIMESTAMP(3))),
             DataTypes.FIELD("s", DataTypes.ARRAY(DataTypes.INT())),
             DataTypes.FIELD("t", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))),
             DataTypes.FIELD("u", row_type)]))

    t.select(
        tinyint_func(t.a),
        smallint_func(t.b),
        int_func(t.c),
        bigint_func(t.d),
        boolean_func(t.e),
        boolean_func(t.f),
        float_func(t.g),
        double_func(t.h),
        varchar_func(t.i),
        varchar_func(t.j),
        varbinary_func(t.k),
        decimal_func(t.l),
        decimal_func(t.m),
        date_func(t.n),
        time_func(t.o),
        timestamp_func(t.p),
        array_str_func(t.q),
        array_timestamp_func(t.r),
        array_int_func(t.s),
        nested_array_func(t.t),
        row_func(t.u)) \
        .execute_insert("Results").wait()
    actual = source_sink_utils.results()
    # The expected row shows both decimals with 18 fractional digits.
    self.assert_equals(
        actual,
        ["+I[1, 32767, -2147483648, 1, true, false, 1.0, 1.0, hello, 中文, "
         "[102, 108, 105, 110, 107], 1000000000000000000.050000000000000000, "
         "1000000000000000000.059999999999999999, 2014-09-13, 01:00:01, "
         "1970-01-02 00:00:00.123, [hello, 中文, null], [1970-01-02 00:00:00.123], "
         "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02 00:00:00.123, [1, 2]]]"])
# Word count: build a blink-planner stream table environment, register a CSV
# file sink and count the words of `content`.
blink_settings = EnvironmentSettings.new_instance().use_blink_planner().build()
t_env = StreamTableEnvironment.create(env, environment_settings=blink_settings)

result_path = '/notebooks/output.csv'
print('Results directory:', result_path)

# Format and schema of the sink declare the same two columns.
sink_format = (
    OldCsv()
    .field_delimiter(',')
    .field('word', DataTypes.STRING())
    .field('count', DataTypes.BIGINT())
)
sink_schema = (
    Schema()
    .field('word', DataTypes.STRING())
    .field('count', DataTypes.BIGINT())
)
t_env.connect(FileSystem().path(result_path)) \
    .with_format(sink_format) \
    .with_schema(sink_schema) \
    .register_table_sink('Results')

# One (word, 1) pair per space-separated token of the input text.
elements = [(word, 1) for word in content.split(' ')]
word_counts = t_env.from_elements(elements, ['word', 'count']) \
    .group_by('word') \
    .select('word, count(1) as count')
word_counts.insert_into('Results')

t_env.execute('word_count')
class MultiEmit(TableFunction, unittest.TestCase): def open(self, function_context): mg = function_context.get_metric_group() self.counter = mg.add_group("key", "value").counter("my_counter") self.counter_sum = 0 def eval(self, x, y): self.counter.inc(y) self.counter_sum += y self.assertEqual(self.counter_sum, self.counter.get_count()) for i in range(y): yield x, i @udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_types=DataTypes.BIGINT()) def condition_multi_emit(x, y): if x == 3: return range(y, x) class MultiNum(ScalarFunction): def eval(self, x): return x * 2 if __name__ == '__main__': import unittest try:
.select(t.a, t.b + 1, add(t.a + 1, subtract_one(t.c)) + 2, add(add_one(t.a), 1)) result = self.collect(t) self.assert_equals(result, ["+I[1, 3, 6, 3]", "+I[3, 2, 14, 5]"]) class BlinkBatchPandasUDFITTests(PandasUDFITTests, BlinkPandasUDFITTests, PyFlinkBlinkBatchTableTestCase): pass class BlinkStreamPandasUDFITTests(PandasUDFITTests, BlinkPandasUDFITTests, PyFlinkBlinkStreamTableTestCase): pass @udf(result_type=DataTypes.BIGINT(), func_type='pandas') def add(i, j): return i + j if __name__ == '__main__': import unittest try: import xmlrunner testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: testRunner = None unittest.main(testRunner=testRunner, verbosity=2)
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment, DataTypes
from pyflink.table.descriptors import Schema, OldCsv, FileSystem
from pyflink.table.udf import udf

# Example: read two BIGINT columns from a CSV source, add them with a Python
# UDF and write the sum to a CSV sink.
env = StreamExecutionEnvironment.get_execution_environment()
env.set_parallelism(1)
t_env = StreamTableEnvironment.create(env)

# Scalar Python UDF: (BIGINT, BIGINT) -> BIGINT.
add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())

t_env.register_function("add", add)

t_env.connect(FileSystem().path('/opt/examples/data/udf_add_input')) \
    .with_format(OldCsv()
                 .field('a', DataTypes.BIGINT())
                 .field('b', DataTypes.BIGINT())) \
    .with_schema(Schema()
                 .field('a', DataTypes.BIGINT())
                 .field('b', DataTypes.BIGINT())) \
    .create_temporary_table('mySource')

t_env.connect(FileSystem().path('/opt/examples/data/udf_add_output')) \
    .with_format(OldCsv()
                 .field('sum', DataTypes.BIGINT())) \
    .with_schema(Schema()
                 .field('sum', DataTypes.BIGINT())) \
    .create_temporary_table('mySink')

t_env.from_path('mySource')\
    .select("add(a, b)") \
    .insert_into('mySink')

# insert_into() only declares the pipeline; nothing is written to the sink
# until the job is submitted.
t_env.execute("udf_add_job")
def test_collect_for_all_data_types(self):
    """Execute a one-row table covering many data types and compare the collected rows.

    The nested row strings in the expected result are bracketed (for example
    ``'[pyflink]'``) while the input holds the plain strings, and both
    decimals are expected with 18 fractional digits.
    """
    expected_result = [
        Row(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
            bytearray(b'pyflink'), 'pyflink',
            datetime.date(2014, 9, 13),
            datetime.time(12, 0, 0, 123000),
            datetime.datetime(2018, 3, 11, 3, 0, 0, 123000),
            [Row(['[pyflink]']), Row(['[pyflink]']), Row(['[pyflink]'])],
            {1: Row(['[flink]']), 2: Row(['[pyflink]'])},
            decimal.Decimal('1000000000000000000.050000000000000000'),
            decimal.Decimal('1000000000000000000.059999999999999999'))
    ]

    input_row = (
        1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
        bytearray(b'pyflink'), 'pyflink',
        datetime.date(2014, 9, 13),
        datetime.time(hour=12, minute=0, second=0, microsecond=123000),
        datetime.datetime(2018, 3, 11, 3, 0, 0, 123000),
        [Row(['pyflink']), Row(['pyflink']), Row(['pyflink'])],
        {1: Row(['flink']), 2: Row(['pyflink'])},
        decimal.Decimal('1000000000000000000.05'),
        decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
    )
    input_schema = DataTypes.ROW([
        DataTypes.FIELD("a", DataTypes.BIGINT()),
        DataTypes.FIELD("b", DataTypes.BIGINT()),
        DataTypes.FIELD("c", DataTypes.TINYINT()),
        DataTypes.FIELD("d", DataTypes.BOOLEAN()),
        DataTypes.FIELD("e", DataTypes.SMALLINT()),
        DataTypes.FIELD("f", DataTypes.INT()),
        DataTypes.FIELD("g", DataTypes.FLOAT()),
        DataTypes.FIELD("h", DataTypes.DOUBLE()),
        DataTypes.FIELD("i", DataTypes.BYTES()),
        DataTypes.FIELD("j", DataTypes.STRING()),
        DataTypes.FIELD("k", DataTypes.DATE()),
        DataTypes.FIELD("l", DataTypes.TIME()),
        DataTypes.FIELD("m", DataTypes.TIMESTAMP(3)),
        DataTypes.FIELD(
            "n",
            DataTypes.ARRAY(
                DataTypes.ROW([DataTypes.FIELD('ss2', DataTypes.STRING())]))),
        DataTypes.FIELD(
            "o",
            DataTypes.MAP(
                DataTypes.BIGINT(),
                DataTypes.ROW([DataTypes.FIELD('ss', DataTypes.STRING())]))),
        DataTypes.FIELD("p", DataTypes.DECIMAL(38, 18)),
        DataTypes.FIELD("q", DataTypes.DECIMAL(38, 18)),
    ])

    source = self.t_env.from_elements([input_row], input_schema)
    table_result = source.execute()

    # Drain the result iterator while the collect context is still open.
    with table_result.collect() as result:
        collected_result = list(result)

    self.assertEqual(expected_result, collected_result)
def test_all_data_types(self):
    """Pass one value of each column type through a Python UDF and check the sink output.

    Every UDF asserts the Python type and/or value it receives and returns its
    argument (``array_func`` returns the first inner list). The single input
    row is then written to the "Results" sink and compared as text.
    """
    import datetime
    import decimal

    def boolean_func(bool_param):
        assert isinstance(bool_param, bool), (
            'bool_param of wrong type %s !' % type(bool_param))
        return bool_param

    def tinyint_func(tinyint_param):
        assert isinstance(tinyint_param, int), (
            'tinyint_param of wrong type %s !' % type(tinyint_param))
        return tinyint_param

    def smallint_func(smallint_param):
        assert isinstance(smallint_param, int), (
            'smallint_param of wrong type %s !' % type(smallint_param))
        assert smallint_param == 32767, (
            'smallint_param of wrong value %s' % smallint_param)
        return smallint_param

    def int_func(int_param):
        assert isinstance(int_param, int), (
            'int_param of wrong type %s !' % type(int_param))
        assert int_param == -2147483648, (
            'int_param of wrong value %s' % int_param)
        return int_param

    def bigint_func(bigint_param):
        assert isinstance(bigint_param, int), (
            'bigint_param of wrong type %s !' % type(bigint_param))
        return bigint_param

    def bigint_func_none(bigint_param):
        assert bigint_param is None, (
            'bigint_param %s should be None!' % bigint_param)
        return bigint_param

    def float_func(float_param):
        assert isinstance(float_param, float) and float_equal(float_param, 1.23, 1e-6), (
            'float_param is wrong value %s !' % float_param)
        return float_param

    def double_func(double_param):
        assert isinstance(double_param, float) and float_equal(double_param, 1.98932, 1e-7), (
            'double_param is wrong value %s !' % double_param)
        return double_param

    def bytes_func(bytes_param):
        assert bytes_param == b'flink', (
            'bytes_param is wrong value %s !' % bytes_param)
        return bytes_param

    def str_func(str_param):
        assert str_param == 'pyflink', (
            'str_param is wrong value %s !' % str_param)
        return str_param

    def date_func(date_param):
        from datetime import date
        assert date_param == date(year=2014, month=9, day=13), (
            'date_param is wrong value %s !' % date_param)
        return date_param

    def time_func(time_param):
        from datetime import time
        assert time_param == time(hour=12, minute=0, second=0, microsecond=123000), (
            'time_param is wrong value %s !' % time_param)
        return time_param

    def timestamp_func(timestamp_param):
        from datetime import datetime
        assert timestamp_param == datetime(2018, 3, 11, 3, 0, 0, 123000), (
            'timestamp_param is wrong value %s !' % timestamp_param)
        return timestamp_param

    def array_func(array_param):
        assert array_param == [[1, 2, 3]], (
            'array_param is wrong value %s !' % array_param)
        return array_param[0]

    def map_func(map_param):
        assert map_param == {1: 'flink', 2: 'pyflink'}, (
            'map_param is wrong value %s !' % map_param)
        return map_param

    def decimal_func(decimal_param):
        from decimal import Decimal
        assert decimal_param == Decimal('1000000000000000000.050000000000000000'), (
            'decimal_param is wrong value %s !' % decimal_param)
        return decimal_param

    def decimal_cut_func(decimal_param):
        from decimal import Decimal
        assert decimal_param == Decimal('1000000000000000000.059999999999999999'), (
            'decimal_param is wrong value %s !' % decimal_param)
        return decimal_param

    # (registered name, python function, result type), in registration order.
    system_functions = [
        ("boolean_func", boolean_func, DataTypes.BOOLEAN()),
        ("tinyint_func", tinyint_func, DataTypes.TINYINT()),
        ("smallint_func", smallint_func, DataTypes.SMALLINT()),
        ("int_func", int_func, DataTypes.INT()),
        ("bigint_func", bigint_func, DataTypes.BIGINT()),
        ("bigint_func_none", bigint_func_none, DataTypes.BIGINT()),
        ("float_func", float_func, DataTypes.FLOAT()),
        ("double_func", double_func, DataTypes.DOUBLE()),
        ("bytes_func", bytes_func, DataTypes.BYTES()),
        ("str_func", str_func, DataTypes.STRING()),
        ("date_func", date_func, DataTypes.DATE()),
        ("time_func", time_func, DataTypes.TIME()),
        ("timestamp_func", timestamp_func, DataTypes.TIMESTAMP(3)),
        ("array_func", array_func, DataTypes.ARRAY(DataTypes.BIGINT())),
        ("map_func", map_func, DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())),
    ]
    for function_name, python_function, result_type in system_functions:
        self.t_env.create_temporary_system_function(
            function_name, udf(python_function, result_type=result_type))

    # NOTE(review): the two decimal UDFs go through register_function rather
    # than create_temporary_system_function like the others -- presumably on
    # purpose; confirm before unifying the registration API.
    legacy_functions = [
        ("decimal_func", decimal_func),
        ("decimal_cut_func", decimal_cut_func),
    ]
    for function_name, python_function in legacy_functions:
        self.t_env.register_function(
            function_name, udf(python_function, result_type=DataTypes.DECIMAL(38, 18)))

    sink_field_names = [
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
        'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
    ]
    sink_field_types = [
        DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.TINYINT(),
        DataTypes.BOOLEAN(), DataTypes.SMALLINT(), DataTypes.INT(),
        DataTypes.FLOAT(), DataTypes.DOUBLE(), DataTypes.BYTES(),
        DataTypes.STRING(), DataTypes.DATE(), DataTypes.TIME(),
        DataTypes.TIMESTAMP(3), DataTypes.ARRAY(DataTypes.BIGINT()),
        DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING()),
        DataTypes.DECIMAL(38, 18), DataTypes.DECIMAL(38, 18),
    ]
    table_sink = source_sink_utils.TestAppendSink(sink_field_names, sink_field_types)
    self.t_env.register_table_sink("Results", table_sink)

    input_row = (
        1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
        bytearray(b'flink'), 'pyflink',
        datetime.date(2014, 9, 13),
        datetime.time(hour=12, minute=0, second=0, microsecond=123000),
        datetime.datetime(2018, 3, 11, 3, 0, 0, 123000),
        [[1, 2, 3]],
        {1: 'flink', 2: 'pyflink'},
        decimal.Decimal('1000000000000000000.05'),
        decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
    )
    # Same column types as the sink, except "n": the source column is an array
    # of arrays and array_func returns its first inner array.
    input_type = DataTypes.ROW([
        DataTypes.FIELD("a", DataTypes.BIGINT()),
        DataTypes.FIELD("b", DataTypes.BIGINT()),
        DataTypes.FIELD("c", DataTypes.TINYINT()),
        DataTypes.FIELD("d", DataTypes.BOOLEAN()),
        DataTypes.FIELD("e", DataTypes.SMALLINT()),
        DataTypes.FIELD("f", DataTypes.INT()),
        DataTypes.FIELD("g", DataTypes.FLOAT()),
        DataTypes.FIELD("h", DataTypes.DOUBLE()),
        DataTypes.FIELD("i", DataTypes.BYTES()),
        DataTypes.FIELD("j", DataTypes.STRING()),
        DataTypes.FIELD("k", DataTypes.DATE()),
        DataTypes.FIELD("l", DataTypes.TIME()),
        DataTypes.FIELD("m", DataTypes.TIMESTAMP(3)),
        DataTypes.FIELD("n", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT()))),
        DataTypes.FIELD("o", DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())),
        DataTypes.FIELD("p", DataTypes.DECIMAL(38, 18)),
        DataTypes.FIELD("q", DataTypes.DECIMAL(38, 18)),
    ])
    t = self.t_env.from_elements([input_row], input_type)

    projection = ("bigint_func(a), bigint_func_none(b),"
                  "tinyint_func(c), boolean_func(d),"
                  "smallint_func(e),int_func(f),"
                  "float_func(g),double_func(h),"
                  "bytes_func(i),str_func(j),"
                  "date_func(k),time_func(l),"
                  "timestamp_func(m),array_func(n),"
                  "map_func(o),decimal_func(p),"
                  "decimal_cut_func(q)")
    exec_insert_table(t.select(projection), "Results")

    actual = source_sink_utils.results()
    # Currently the sink result precision of DataTypes.TIME(precision) only supports 0.
    expected = [
        "1,null,1,true,32767,-2147483648,1.23,1.98932,"
        "[102, 108, 105, 110, 107],pyflink,2014-09-13,"
        "12:00:00,2018-03-11 03:00:00.123,[1, 2, 3],"
        "{1=flink, 2=pyflink},1000000000000000000.050000000000000000,"
        "1000000000000000000.059999999999999999"
    ]
    self.assert_equals(actual, expected)
def test_udf_with_constant_params(self):
    """Call Python UDFs with SQL constant arguments and check the sink output.

    ``udf_with_constant_params`` validates every constant it receives and
    returns the column value plus the integer part of each numeric constant;
    ``udf_with_all_constant_params`` is invoked with constants only.
    """

    def udf_with_constant_params(p, null_param, tinyint_param, smallint_param,
                                 int_param, bigint_param, decimal_param,
                                 float_param, double_param, boolean_param,
                                 str_param, date_param, time_param,
                                 timestamp_param):
        from decimal import Decimal
        import datetime

        assert null_param is None, 'null_param is wrong value %s' % null_param

        assert isinstance(tinyint_param, int), (
            'tinyint_param of wrong type %s !' % type(tinyint_param))
        p += tinyint_param

        assert isinstance(smallint_param, int), (
            'smallint_param of wrong type %s !' % type(smallint_param))
        p += smallint_param

        assert isinstance(int_param, int), (
            'int_param of wrong type %s !' % type(int_param))
        p += int_param

        assert isinstance(bigint_param, int), (
            'bigint_param of wrong type %s !' % type(bigint_param))
        p += bigint_param

        assert decimal_param == Decimal('1.05'), (
            'decimal_param is wrong value %s ' % decimal_param)
        p += int(decimal_param)

        assert isinstance(float_param, float) and float_equal(float_param, 1.23, 1e-06), (
            'float_param is wrong value %s ' % float_param)
        p += int(float_param)

        assert isinstance(double_param, float) and float_equal(double_param, 1.98932, 1e-07), (
            'double_param is wrong value %s ' % double_param)
        p += int(double_param)

        assert boolean_param is True, 'boolean_param is wrong value %s' % boolean_param
        assert str_param == 'flink', 'str_param is wrong value %s' % str_param
        assert date_param == datetime.date(year=2014, month=9, day=13), (
            'date_param is wrong value %s' % date_param)
        assert time_param == datetime.time(hour=12, minute=0, second=0), (
            'time_param is wrong value %s' % time_param)
        assert timestamp_param == datetime.datetime(1999, 9, 10, 5, 20, 10), (
            'timestamp_param is wrong value %s' % timestamp_param)

        return p

    self.t_env.create_temporary_system_function(
        "udf_with_constant_params",
        udf(udf_with_constant_params, result_type=DataTypes.BIGINT()))
    self.t_env.create_temporary_system_function(
        "udf_with_all_constant_params",
        udf(lambda i, j: i + j, result_type=DataTypes.BIGINT()))

    sink_types = [DataTypes.BIGINT(), DataTypes.BIGINT()]
    table_sink = source_sink_utils.TestAppendSink(['a', 'b'], sink_types)
    self.t_env.register_table_sink("Results", table_sink)

    t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
    self.t_env.register_table("test_table", t)

    query = ("select udf_with_all_constant_params("
             "cast (1 as BIGINT),"
             "cast (2 as BIGINT)), "
             "udf_with_constant_params(a, "
             "cast (null as BIGINT),"
             "cast (1 as TINYINT),"
             "cast (1 as SMALLINT),"
             "cast (1 as INT),"
             "cast (1 as BIGINT),"
             "cast (1.05 as DECIMAL),"
             "cast (1.23 as FLOAT),"
             "cast (1.98932 as DOUBLE),"
             "true,"
             "'flink',"
             "cast ('2014-09-13' as DATE),"
             "cast ('12:00:00' as TIME),"
             "cast ('1999-9-10 05:20:10' as TIMESTAMP))"
             " from test_table")
    self.t_env.sql_query(query).insert_into("Results")
    self.t_env.execute("test")

    actual = source_sink_utils.results()
    # First column: 1 + 2. Second column: a plus 7 from the numeric constants.
    self.assert_equals(actual, ["3,8", "3,9", "3,10"])
def test_scalar_function(self):
    """Run scalar UDFs built from several kinds of Python callables.

    Covers a lambda, a ``SubtractOne`` instance, a ``CallablePlus`` instance,
    a ``functools.partial`` and a decorated zero-argument function, with
    ``python.metric.enabled`` set to ``'false'``. Rows are filtered with
    ``add_one(b) <= 3`` and the projected values are compared with the sink
    output.
    """
    # test metric disabled.
    self.t_env.get_config().get_configuration().set_string(
        'python.metric.enabled', 'false')

    # test lambda function
    add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())

    # test Python ScalarFunction
    subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT())

    # test callable function
    add_one_callable = udf(CallablePlus(), result_type=DataTypes.BIGINT())

    def partial_func(col, param):
        return col + param

    # test partial function
    import functools
    add_one_partial = udf(functools.partial(partial_func, param=1),
                          result_type=DataTypes.BIGINT())

    # check memory limit is set
    @udf(result_type=DataTypes.BIGINT())
    def check_memory_limit():
        # Subscripting os.environ never yields None (it raises KeyError for a
        # missing variable), so test membership explicitly and fail with a
        # readable message instead.
        assert '_PYTHON_WORKER_MEMORY_LIMIT' in os.environ, \
            '_PYTHON_WORKER_MEMORY_LIMIT is not set in the environment'
        return 1

    table_sink = source_sink_utils.TestAppendSink(
        ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
        [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(),
         DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(),
         DataTypes.BIGINT()])
    self.t_env.register_table_sink("Results", table_sink)

    t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
    # Only the rows with b == 2 and b == 1 satisfy add_one(b) <= 3.
    t.where(add_one(t.b) <= 3).select(
        add_one(t.a), subtract_one(t.b), add(t.a, t.c), add_one_callable(t.a),
        add_one_partial(t.a), check_memory_limit(), t.a) \
        .execute_insert("Results").wait()

    actual = source_sink_utils.results()
    self.assert_equals(
        actual, ["+I[2, 1, 4, 2, 2, 1, 1]", "+I[4, 0, 12, 4, 4, 1, 3]"])
def test_expressions(self):
    """Check the string form of expressions built by the expression helper functions."""
    from pyflink.table.expressions import UNBOUNDED_ROW, UNBOUNDED_RANGE, CURRENT_ROW, \
        CURRENT_RANGE

    col_a = col('a')
    col_b = col('b')
    col_c = col('c')

    # Literals, ranges and logical combinators.
    self.assertEqual('10', str(lit(10, DataTypes.INT(False))))
    self.assertEqual('rangeTo(1, 2)', str(range_(1, 2)))
    self.assertEqual('and(a, b, c)', str(and_(col_a, col_b, col_c)))
    self.assertEqual('or(a, b, c)', str(or_(col_a, col_b, col_c)))

    # Over-window boundary constants.
    self.assertEqual('unboundedRow()', str(UNBOUNDED_ROW))
    self.assertEqual('unboundedRange()', str(UNBOUNDED_RANGE))
    self.assertEqual('currentRow()', str(CURRENT_ROW))
    self.assertEqual('currentRange()', str(CURRENT_RANGE))

    # Date and time helpers.
    self.assertEqual('currentDate()', str(current_date()))
    self.assertEqual('currentTime()', str(current_time()))
    self.assertEqual('currentTimestamp()', str(current_timestamp()))
    self.assertEqual('localTime()', str(local_time()))
    self.assertEqual('localTimestamp()', str(local_timestamp()))
    self.assertEqual('toTimestampLtz(123, 0)', str(to_timestamp_ltz(123, 0)))

    overlaps = temporal_overlaps(
        lit("2:55:00").to_time,
        lit(1).hours,
        lit("3:30:00").to_time,
        lit(2).hours)
    self.assertEqual(
        "temporalOverlaps(cast('2:55:00', TIME(0)), 3600000, "
        "cast('3:30:00', TIME(0)), 7200000)",
        str(overlaps))

    self.assertEqual("dateFormat(time, '%Y, %d %M')",
                     str(date_format(col("time"), "%Y, %d %M")))

    day_diff = timestamp_diff(
        TimePointUnit.DAY,
        lit("2016-06-15").to_date,
        lit("2016-06-18").to_date)
    self.assertEqual(
        "timestampDiff(DAY, cast('2016-06-15', DATE), cast('2016-06-18', DATE))",
        str(day_diff))

    # Collection constructors.
    self.assertEqual('array(1, 2, 3)', str(array(1, 2, 3)))
    self.assertEqual("row('key1', 1)", str(row("key1", 1)))
    self.assertEqual("map('key1', 1, 'key2', 2, 'key3', 3)",
                     str(map_("key1", 1, "key2", 2, "key3", 3)))
    self.assertEqual('4', str(row_interval(4)))

    # Mathematical helpers.
    self.assertEqual('pi()', str(pi()))
    self.assertEqual('e()', str(e()))
    self.assertEqual('rand(4)', str(rand(4)))
    self.assertEqual('randInteger(4)', str(rand_integer(4)))
    self.assertEqual('atan2(1, 2)', str(atan2(1, 2)))
    self.assertEqual('minusPrefix(a)', str(negative(col_a)))

    # String, conditional and miscellaneous helpers.
    self.assertEqual('concat(a, b, c)', str(concat(col_a, col_b, col_c)))
    self.assertEqual("concat_ws(', ', b, c)", str(concat_ws(', ', col_b, col_c)))
    self.assertEqual('uuid()', str(uuid()))
    self.assertEqual('null', str(null_of(DataTypes.BIGINT())))
    self.assertEqual('log(a)', str(log(col_a)))
    self.assertEqual('ifThenElse(a, b, c)', str(if_then_else(col_a, col_b, col_c)))
    self.assertEqual('withColumns(a, b, c)', str(with_columns(col_a, col_b, col_c)))
    self.assertEqual('a.b.c(a)', str(call('a.b.c', col_a)))
def test_non_exist_udf_type(self):
    """Passing an unknown ``udf_type`` to ``udf`` must raise ``ValueError``."""
    expected_message = "The udf_type must be one of 'general, pandas'"
    with self.assertRaisesRegex(ValueError, expected_message):
        udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="non-exist")