def test_sql_with_index_col(self):
    """Check that ``index_col`` restores the index from columns of the queried frame.

    Each frame is bound to its ``{psdf_reset_index}`` placeholder through a
    keyword argument, as every other query in this suite does. The error
    tests here expect an unbound placeholder to raise ``KeyError``.
    """
    import pandas as pd

    # Single Index: move the index into a column, query, then restore it.
    psdf = ps.DataFrame(
        {"A": [1, 2, 3], "B": [4, 5, 6]},
        index=pd.Index(["a", "b", "c"], name="index"),
    )
    psdf_reset_index = psdf.reset_index()
    actual = ps.sql(
        "select * from {psdf_reset_index} where A > 1",
        index_col="index",
        psdf_reset_index=psdf_reset_index,
    )
    expected = psdf.iloc[[1, 2]]
    self.assert_eq(actual, expected)

    # MultiIndex: index_col also accepts a list of column names.
    psdf = ps.DataFrame(
        {"A": [1, 2, 3], "B": [4, 5, 6]},
        index=pd.MultiIndex.from_tuples(
            [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
        ),
    )
    psdf_reset_index = psdf.reset_index()
    actual = ps.sql(
        "select * from {psdf_reset_index} where A > 1",
        index_col=["index1", "index2"],
        psdf_reset_index=psdf_reset_index,
    )
    expected = psdf.iloc[[1, 2]]
    self.assert_eq(actual, expected)
def test_sql_with_python_objects(self):
    """Check that plain Python values bound as kwargs are usable inside a query.

    A ``str`` is selected back as a string value. A ``tuple`` is usable as
    the right-hand side of ``IN``.
    """
    self.assert_eq(
        ps.sql("SELECT {col} as a FROM range(1)", col="lit"),
        ps.DataFrame({"a": ["lit"]}),
    )
    # Only the placeholders that appear in the query are bound; the query
    # below has no {col}, so no ``col`` argument is passed.
    self.assert_eq(
        ps.sql("SELECT id FROM range(10) WHERE id IN {pred}", pred=(1, 2, 3)),
        ps.DataFrame({"id": [1, 2, 3]}),
    )
def test_sql_with_pandas_on_spark_objects(self):
    """Check that pandas-on-Spark columns can be referenced two ways.

    A column can be passed as its own kwarg, or reached through attribute
    access on the frame placeholder (``{tbl.a}``). Either way, selecting
    every column returns the original frame.
    """
    single = ps.DataFrame({"a": [1, 2, 3, 4]})
    via_kwarg = ps.sql("SELECT {col} FROM {tbl}", col=single.a, tbl=single)
    self.assert_eq(via_kwarg, single)
    via_attr = ps.sql("SELECT {tbl.a} FROM {tbl}", tbl=single)
    self.assert_eq(via_attr, single)

    pair = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
    via_kwargs = ps.sql(
        "SELECT {col}, {col2} FROM {tbl}", col=pair.A, col2=pair.B, tbl=pair
    )
    self.assert_eq(via_kwargs, pair)
    via_attrs = ps.sql("SELECT {tbl.A}, {tbl.B} FROM {tbl}", tbl=pair)
    self.assert_eq(via_attrs, pair)
def test_sql_with_pandas_objects(self):
    """Check that plain pandas objects can be bound as both column and table."""
    import pandas as pd

    frame = pd.DataFrame({"a": [1, 2, 3, 4]})
    result = ps.sql("SELECT {col} + 1 as a FROM {tbl}", col=frame.a, tbl=frame)
    self.assert_eq(result, frame + 1)
def test_error_bad_sql(self):
    """Check that syntactically invalid SQL raises ``ParseException``."""
    bad_query = "this is not valid sql"
    self.assertRaises(ParseException, ps.sql, bad_query)
def test_error_local_variable_not_captured(self):
    """Check that a same-named local does not fill a query placeholder.

    Renamed from a duplicate ``test_error_unsupported_type``, which a later
    method of the same name shadowed, so it never ran. ``some_dict`` exists
    in this scope but is not passed to ``ps.sql``. This test pins that the
    placeholder stays unbound and raises ``KeyError``, the same error the
    unbound-placeholder tests expect.
    """
    some_dict = {"a": 1}  # noqa: F841 -- deliberately defined but not passed
    with self.assertRaisesRegex(KeyError, "some_dict"):
        ps.sql("select * from {some_dict}")
def test_error_variable_not_exist_with_other_kwargs(self):
    """Check that one missing placeholder fails even when others are bound.

    Renamed from a duplicate ``test_error_variable_not_exist``, which a later
    method of the same name shadowed, so it never ran. Its legacy error
    message no longer matched the ``KeyError`` contract. ``{col}`` is bound
    here, but ``{variable_foo}`` is not, so formatting must still raise a
    ``KeyError`` naming the missing key.
    """
    with self.assertRaisesRegex(KeyError, "variable_foo"):
        ps.sql("select {col} from {variable_foo}", col="lit")
def test_series_not_referred(self):
    """Check that selecting a Series whose frame is not queried raises ``ValueError``.

    The query reads from ``range(10)``, not from the DataFrame that owns
    the Series.
    """
    frame = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
    with self.assertRaisesRegex(ValueError, "The series in {ser}"):
        series = frame.A
        ps.sql("SELECT {ser} FROM range(10)", ser=series)
def test_error_variable_not_exist(self):
    """Check that a placeholder with no matching kwarg raises ``KeyError`` naming it."""
    query = "select * from {variable_foo}"
    with self.assertRaisesRegex(KeyError, "variable_foo"):
        ps.sql(query)
def test_error_unsupported_type(self):
    """Check that an unbound ``{some_dict}`` placeholder raises ``KeyError``.

    NOTE(review): despite its name, this test passes no value of an
    unsupported type. It only checks that the unbound placeholder is
    reported by name.
    """
    missing_key = "some_dict"
    with self.assertRaisesRegex(KeyError, missing_key):
        ps.sql("select * from {some_dict}")