예제 #1
0
    def test_error_level(self, logger):
        """WARN level logs parse errors instead of raising; the default raises.

        ``logger`` is supplied from outside this snippet (presumably a mock
        patched in by a decorator — NOTE(review): confirm); only its
        ``error`` call list is inspected.
        """
        transpile("x + 1 (", error_level=ErrorLevel.WARN)
        # assertIn instead of a bare ``assert``: bare asserts are stripped
        # under ``python -O`` (the check would silently pass) and report no
        # useful diff on failure.
        first_logged_message = str(logger.error.call_args_list[0][0][0])
        self.assertIn(
            "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 7.\n  x + 1 \033[4m(\033[0m",
            first_logged_message,
        )

        # Without an explicit error level the same input must raise.
        with self.assertRaises(ParseError):
            transpile("x + 1 (")
예제 #2
0
    def test_types(self):
        """Type-prefix shorthand and ``::`` casts are rendered as explicit CASTs."""
        cases = (
            ("INT x", "CAST(x AS INT)"),
            ("VARCHAR x y", "CAST(x AS VARCHAR) AS y"),
            ("STRING x y", "CAST(x AS TEXT) AS y"),
            ("x::INT", "CAST(x AS INT)"),
            ("x::INTEGER", "CAST(x AS INT)"),
            ("x::INT y", "CAST(x AS INT) AS y"),
            ("x::INT AS y", "CAST(x AS INT) AS y"),
        )
        for source_sql, expected_sql in cases:
            self.validate(source_sql, expected_sql)

        # ``x::z`` must be rejected by the parser.
        with self.assertRaises(ParseError):
            transpile("x::z")
예제 #3
0
 def test_partial(self):
     """Each line of partial.sql round-trips (stripped) when errors are ignored."""
     fixture_path = os.path.join(self.fixtures_dir, "partial.sql")
     with open(fixture_path, encoding="utf-8") as fixture:
         for line in fixture:
             result = transpile(line, error_level=ErrorLevel.IGNORE)
             self.assertEqual(result[0], line.strip())
예제 #4
0
 def test_custom_transform(self):
     """A ``type_mapping`` entry overrides the generated name of the INT type."""
     mapping = {exp.DataType.Type.INT: "SPECIAL INT"}
     generated = transpile("SELECT CAST(a AS INT) FROM x", type_mapping=mapping)
     self.assertEqual(generated[0], "SELECT CAST(a AS SPECIAL INT) FROM x")
예제 #5
0
    def test_comments(self):
        """Line and block comments, single- or multi-line, are dropped from output."""
        expected = "SELECT 1 FROM foo"
        # The triple-quoted inputs keep their original embedded indentation.
        inputs = (
            "SELECT 1 FROM foo -- comment",
            "SELECT 1 /* inline */ FROM foo -- comment",
            """
            SELECT 1 -- comment
            FROM foo -- comment
            """,
            """
            SELECT 1 /* big comment
             like this */
            FROM foo -- comment
            """,
        )
        for raw in inputs:
            self.assertEqual(transpile(raw)[0], expected)
예제 #6
0
    def test_pretty(self):
        """Each input in pretty.sql must pretty-print to the chunk that follows it.

        The fixture is split on ``;``. Even-indexed chunks are inputs and the
        next odd-indexed chunk is the expected pretty output. A trailing
        unpaired chunk is ignored.
        """
        # Explicit encoding: without it open() uses the locale's preferred
        # encoding, so the fixture could decode differently across machines.
        with open(os.path.join(self.fixtures_dir, 'pretty.sql'),
                  encoding='utf-8') as f:
            chunks = f.read().split(';')

        # zip() stops at the shorter slice, so an odd trailing chunk is
        # skipped exactly as an explicit ``i + 1 < len(chunks)`` guard would.
        for sql, expected in zip(chunks[0::2], chunks[1::2]):
            generated = transpile(sql, pretty=True)[0]
            self.assertEqual(generated, expected.strip())
예제 #7
0
    def test_hive(self):
        """Hive identifier quoting, type round-trips and Hive -> Presto rewrites."""
        # Each entry: (input SQL, transpile options, expected first output).
        cases = (
            ('SELECT "a"."b" FROM "foo"', {'write': 'hive'},
             "SELECT `a`.`b` FROM `foo`"),
            ('SELECT CAST(`a`.`b` AS SMALLINT) FROM foo',
             {'read': 'hive', 'write': 'hive'},
             'SELECT CAST(`a`.`b` AS SMALLINT) FROM foo'),
            ('SELECT "a"."b" FROM foo',
             {'write': 'hive', 'identify': True},
             'SELECT `a`.`b` FROM `foo`'),
            ('SELECT APPROX_COUNT_DISTINCT(a) FROM foo',
             {'read': 'hive', 'write': 'presto'},
             'SELECT APPROX_DISTINCT(a) FROM foo'),
            ('CREATE TABLE test STORED AS PARQUET AS SELECT 1',
             {'read': 'hive', 'write': 'presto'},
             "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1"),
            ("SELECT GET_JSON_OBJECT(x, '$.name')",
             {'read': 'hive', 'write': 'presto'},
             "SELECT JSON_EXTRACT(x, '$.name')"),
        )
        for source_sql, options, expected in cases:
            self.assertEqual(transpile(source_sql, **options)[0], expected)
예제 #8
0
    def test_spark(self):
        """Spark identifier quoting, SHORT type name, Spark -> Presto rewrites and hints."""

        def check(source_sql, expected, **options):
            # Compare only the first generated statement.
            self.assertEqual(transpile(source_sql, **options)[0], expected)

        check('SELECT "a"."b" FROM "foo"', "SELECT `a`.`b` FROM `foo`",
              write='spark')
        check('SELECT CAST(`a`.`b` AS SMALLINT) FROM foo',
              'SELECT CAST(`a`.`b` AS SHORT) FROM foo',
              read='spark')
        check('SELECT "a"."b" FROM foo', 'SELECT `a`.`b` FROM `foo`',
              write='spark', identify=True)
        check('SELECT APPROX_COUNT_DISTINCT(a) FROM foo',
              'SELECT APPROX_DISTINCT(a) FROM foo',
              read='spark', write='presto')
        check('CREATE TABLE test STORED AS PARQUET AS SELECT 1',
              "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
              read='spark', write='presto')
        # Optimizer hints must survive untouched.
        check('SELECT /*+ COALESCE(3) */ * FROM x',
              'SELECT /*+ COALESCE(3) */ * FROM x',
              read='spark')
예제 #9
0
    def test_pretty(self):
        """Pretty output matches pretty.sql and parses to the same tree as its input.

        The fixture alternates ``input;expected;`` chunks; an unpaired
        trailing chunk is ignored.
        """
        fixture_path = os.path.join(self.fixtures_dir, "pretty.sql")
        with open(fixture_path, encoding="utf-8") as fixture:
            chunks = fixture.read().split(";")
            inputs, outputs = chunks[0::2], chunks[1::2]
            for sql, raw_expected in zip(inputs, outputs):
                expected = raw_expected.strip()
                self.assertEqual(transpile(sql, pretty=True)[0], expected)
                # Pretty-printing must not change the parsed meaning.
                self.assertEqual(parse_one(sql), parse_one(expected))
                    self.assertEqual(parse_one(sql), parse_one(pretty))
예제 #10
0
 def test_mysql(self):
     """A MySQL -> MySQL round-trip leaves backticks and the INT cast unchanged."""
     query = 'SELECT CAST(`a`.`b` AS INT) FROM foo'
     outputs = transpile(query, read='mysql', write='mysql')
     self.assertEqual(outputs[0], query)
예제 #11
0
    def test_spark(self):
        """Spark dialect: quoting, type names, hints, escaping and cross-dialect rewrites.

        ``self.validate(sql, expected, **kwargs)`` is a helper defined outside
        this method; judging by its use it asserts that transpiling ``sql``
        with the given ``read``/``write``/other options yields ``expected``.
        """
        # Identifier quoting, type names and basic Spark <-> Presto rewrites.
        self.validate(
            'SELECT "a"."b" FROM "foo"',
            "SELECT `a`.`b` FROM `foo`",
            write="spark",
        )

        self.validate("CAST(a AS TEXT)", "CAST(a AS STRING)", write="spark")
        self.validate(
            "SELECT CAST(`a`.`b` AS SMALLINT) FROM foo",
            "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
            read="spark",
        )
        self.validate(
            'SELECT "a"."b" FROM foo',
            "SELECT `a`.`b` FROM `foo`",
            write="spark",
            identify=True,
        )
        self.validate(
            "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
            "SELECT APPROX_DISTINCT(a) FROM foo",
            read="spark",
            write="presto",
        )
        self.validate(
            "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
            "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
            read="spark",
            write="presto",
        )

        # Array constructors: Spark ARRAY(...) vs Presto ARRAY[...] vs DuckDB LIST_VALUE(...).
        self.validate("ARRAY(0, 1, 2)", "ARRAY[0, 1, 2]", read="spark", write="presto")
        self.validate(
            "ARRAY(0, 1, 2)", "LIST_VALUE(0, 1, 2)", read="spark", write="duckdb"
        )
        # Optimizer hints (single and multiple) are preserved verbatim.
        self.validate(
            "SELECT /*+ COALESCE(3) */ * FROM x",
            "SELECT /*+ COALESCE(3) */ * FROM x",
            read="spark",
        )
        self.validate(
            "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
            "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
            read="spark",
        )
        # A doubled quote inside a Presto string is emitted backslash-escaped for Spark.
        self.validate(
            "x IN ('a', 'a''b')", "x IN ('a', 'a\\'b')", read="presto", write="spark"
        )

        # DuckDB STRUCT_EXTRACT becomes dotted, backtick-quoted field access.
        self.validate(
            "STRUCT_EXTRACT(x, 'abc')", "x.`abc`", read="duckdb", write="spark"
        )
        self.validate(
            "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')",
            "x.`y`.`abc`",
            read="duckdb",
            write="spark",
        )

        # A string literal passed to MONTH is cast to DATE in DuckDB output.
        self.validate(
            "MONTH('2021-03-01')",
            "MONTH(CAST('2021-03-01' AS DATE))",
            read="spark",
            write="duckdb",
        )
        self.validate("MONTH(x)", "MONTH(x)", read="duckdb", write="spark")

        # NOTE(review): in a normal Python string "\u6bdb" is already the
        # character 毛, so both arguments are the same text. If an escaped SQL
        # literal was meant to be tested, the input needs "\\u6bdb".
        self.validate("'\u6bdb'", "'毛'", read="spark")

        # LEFT/RIGHT are rewritten to SUBSTRING for Presto.
        self.validate(
            "SELECT LEFT(x, 2), RIGHT(x, 2)",
            "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
            read="spark",
            write="presto",
        )

        # A recursive CTE going to Spark must raise UnsupportedError at RAISE level.
        with self.assertRaises(UnsupportedError):
            transpile(
                "WITH RECURSIVE t(n) AS (VALUES (1) UNION ALL SELECT n+1 FROM t WHERE n < 100 ) SELECT sum(n) FROM t",
                read="presto",
                write="spark",
                unsupported_level=ErrorLevel.RAISE,
            )

        # CROSS JOIN UNNEST maps to LATERAL VIEW EXPLODE (POSEXPLODE for WITH ORDINALITY).
        self.validate(
            "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
            "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
            write="spark",
        )
        self.validate(
            "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
            "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
            write="spark",
        )
        self.validate(
            "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t (a)",
            "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a",
            write="spark",
        )
예제 #12
0
    def test_hive(self):
        """Hive dialect: quoting, literals, operators, date/time functions and aggregates.

        ``self.validate(sql, expected, **kwargs)`` is a helper defined outside
        this method; judging by its use it asserts that transpiling ``sql``
        with the given options yields ``expected``. ``identity=False``
        presumably disables an extra self-round-trip check inside it
        (TODO confirm against the helper).
        """
        # Identifier quoting and basic Hive <-> Presto rewrites.
        sql = transpile('SELECT "a"."b" FROM "foo"', write="hive")[0]
        self.assertEqual(sql, "SELECT `a`.`b` FROM `foo`")
        self.validate("""'["x"]'""", """'["x"]'""", write="hive", identity=True)
        self.validate(
            "SELECT CAST(`a`.`b` AS SMALLINT) FROM foo",
            "SELECT CAST(`a`.`b` AS SMALLINT) FROM foo",
            read="hive",
            write="hive",
        )
        self.validate(
            'SELECT "a"."b" FROM foo',
            "SELECT `a`.`b` FROM `foo`",
            write="hive",
            identify=True,
        )
        self.validate(
            "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
            "SELECT APPROX_DISTINCT(a) FROM foo",
            read="hive",
            write="presto",
        )
        self.validate(
            "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
            "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
            read="hive",
            write="presto",
        )
        self.validate(
            "SELECT GET_JSON_OBJECT(x, '$.name')",
            "SELECT JSON_EXTRACT_SCALAR(x, '$.name')",
            read="hive",
            write="presto",
        )

        # Hive MAP(k1, v1, k2, v2) becomes Presto MAP(ARRAY[keys], ARRAY[values]).
        self.validate(
            "MAP(a, b, c, d)",
            "MAP(ARRAY[a, c], ARRAY[b, d])",
            read="hive",
            write="presto",
        )
        # Single-argument LOG maps to LN; the two-argument form is kept.
        self.validate("LOG(10)", "LN(10)", read="hive", write="presto")
        self.validate("LOG(2, 10)", "LOG(2, 10)", read="hive", write="presto")
        # String quoting: double-quoted Hive strings become single-quoted, and
        # embedded quotes are escaped the way each target dialect expects.
        self.validate("'\"x\"'", "'\"x\"'", read="hive", write="presto")
        self.validate("\"'x'\"", "'''x'''", read="hive", write="presto")
        self.validate('ds = "2020-01-01"', "ds = '2020-01-01'", read="hive")
        self.validate("ds = \"1''2\"", "ds = '1\\'\\'2'", read="hive")
        self.validate("ds = \"1''2\"", "ds = '1''''2'", read="hive", write="presto")
        # Operators: `==` becomes `=`; `div` becomes a cast of `/` for Presto.
        self.validate("x == 1", "x = 1", read="hive")
        self.validate("x == 1", "x = 1", read="hive", write="presto")
        self.validate("x div y", "CAST(x / y AS INTEGER)", read="hive", write="presto")

        # Date/time function rewrites between the generic form, Hive, Presto and DuckDB.
        self.validate(
            "STR_TO_TIME('2020-01-01', 'yyyy-MM-dd')",
            "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
            write="hive",
            identity=False,
        )
        self.validate(
            "STR_TO_TIME('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
            "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
            write="hive",
            identity=False,
        )
        self.validate(
            "STR_TO_TIME(x, 'yyyy')",
            "FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy'))",
            write="hive",
            identity=False,
        )
        self.validate(
            "DATE_ADD('2020-01-01', 1)",
            "TS_OR_DS_ADD('2020-01-01', 1, 'DAY')",
            read="hive",
            write=None,
            identity=False,
        )
        self.validate(
            "DATE_ADD('2020-01-01', 1)",
            "DATE_ADD('2020-01-01', 1)",
            read="hive",
        )
        # DATE_SUB is expressed as DATE_ADD with a negated amount.
        self.validate(
            "DATE_SUB('2020-01-01', 1)",
            "DATE_ADD('2020-01-01', 1 * -1)",
            read="hive",
        )
        self.validate(
            "DATE_SUB('2020-01-01', 1)",
            "DATE_FORMAT(DATE_ADD('DAY', 1 * -1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d')), '%Y-%m-%d')",
            read="hive",
            write="presto",
        )
        self.validate(
            "DATE_ADD('2020-01-01', 1)",
            "DATE_FORMAT(DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d')), '%Y-%m-%d')",
            read="hive",
            write="presto",
        )
        self.validate(
            "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
            "DATE_FORMAT(DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d')), '%Y-%m-%d')",
            write="presto",
            identity=False,
        )
        self.validate(
            "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
            "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
            write="duckdb",
            identity=False,
        )
        self.validate(
            "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
            "STRFTIME(CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY, '%Y-%m-%d')",
            write="duckdb",
            identity=False,
        )
        self.validate(
            "DATE_ADD('2020-01-01', 1)",
            "STRFTIME(CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY, '%Y-%m-%d')",
            read="hive",
            write="duckdb",
        )
        # DATEDIFF: Presto output swaps the argument order and adds a 'day' unit.
        self.validate(
            "DATEDIFF('2020-01-02', '2020-01-02')",
            "DATE_DIFF(DATE_STR_TO_DATE('2020-01-02'), DATE_STR_TO_DATE('2020-01-02'))",
            read="hive",
            write=None,
            identity=False,
        )
        self.validate(
            "DATEDIFF('2020-01-02', '2020-01-01')",
            "DATEDIFF('2020-01-02', '2020-01-01')",
            read="hive",
        )
        self.validate(
            "DATEDIFF(TO_DATE(y), x)",
            "DATE_DIFF('day', DATE_PARSE(x, '%Y-%m-%d'), DATE_PARSE(DATE_FORMAT(DATE_PARSE(SUBSTR(y, 1, 10), '%Y-%m-%d'), '%Y-%m-%d'), '%Y-%m-%d'))",
            read="hive",
            write="presto",
        )
        self.validate(
            "DATEDIFF('2020-01-02', '2020-01-01')",
            "DATE_DIFF('day', DATE_PARSE('2020-01-01', '%Y-%m-%d'), DATE_PARSE('2020-01-02', '%Y-%m-%d'))",
            read="hive",
            write="presto",
        )

        # Aggregate names and conditional expressions.
        self.validate("COLLECT_LIST(x)", "ARRAY_AGG(x)", read="hive", write="presto")
        self.validate("ARRAY_AGG(x)", "COLLECT_LIST(x)", read="presto", write="hive")
        self.validate("COLLECT_SET(x)", "SET_AGG(x)", read="hive", write="presto")
        self.validate("SET_AGG(x)", "COLLECT_SET(x)", read="presto", write="hive")
        self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive")
        self.validate(
            "CASE WHEN 1 THEN x ELSE 0 END",
            "CASE WHEN 1 THEN x ELSE 0 END",
            write="hive",
        )

        # Unix-time / time-string conversions.
        self.validate(
            "UNIX_TIMESTAMP(x)",
            "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')",
            read="hive",
            identity=False,
        )
        self.validate(
            "TIME_STR_TO_UNIX(x)",
            "UNIX_TIMESTAMP(x)",
            write="hive",
        )
        self.validate(
            "TIME_STR_TO_TIME(x)",
            "x",
            write="hive",
        )
        self.validate(
            "TIME_TO_TIME_STR(x)",
            "x",
            write="hive",
        )
        self.validate(
            "UNIX_TO_TIME_STR(x)",
            "FROM_UNIXTIME(x)",
            write="hive",
        )
        self.validate(
            "FROM_UNIXTIME(x)",
            "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d %H:%i:%S')",
            read="hive",
            write="presto",
        )
        self.validate(
            "TS_OR_DS_TO_DATE(x)",
            "TO_DATE(x)",
            write="hive",
            identity=False,
        )
        self.validate(
            "TO_DATE(x)",
            "TS_OR_DS_TO_DATE_STR(x)",
            read="hive",
            identity=False,
        )

        # DuckDB STRUCT_EXTRACT becomes dotted, backtick-quoted field access.
        self.validate(
            "STRUCT_EXTRACT(x, 'abc')", "x.`abc`", read="duckdb", write="hive"
        )
        self.validate(
            "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')",
            "x.`y`.`abc`",
            read="duckdb",
            write="hive",
        )

        # String literals passed to MONTH/DAY are cast to DATE in DuckDB output.
        self.validate(
            "MONTH('2021-03-01')",
            "MONTH(CAST('2021-03-01' AS DATE))",
            read="hive",
            write="duckdb",
        )
        self.validate("MONTH(x)", "MONTH(x)", read="duckdb", write="hive")

        self.validate(
            "DAY('2021-03-01')",
            "DAY(CAST('2021-03-01' AS DATE))",
            read="hive",
            write="duckdb",
        )
        self.validate("DAY(x)", "DAY(x)", read="duckdb", write="hive")

        # Backslashes in string literals are handled differently by Hive and Presto.
        self.validate("'\\\\a'", "'\\\\a'", read="hive")
        self.validate("'\\\\a'", "'\\a'", read="hive", write="presto")
        self.validate("'\\a'", "'\\\\a'", read="presto", write="hive")

        # Hive numeric literal suffixes become explicit casts.
        self.validate("1s", "CAST(1 AS SMALLINT)", read="hive")
        self.validate("1S", "CAST(1 AS SMALLINT)", read="hive")
        self.validate("1Y", "CAST(1 AS TINYINT)", read="hive")
        self.validate("1L", "CAST(1 AS BIGINT)", read="hive")
        self.validate("1.0bd", "CAST(1.0 AS DECIMAL)", read="hive")

        # CAST / TRY_CAST handling across Hive, Presto and StarRocks.
        self.validate("TRY_CAST(1 AS INT)", "CAST(1 AS INT)", write="hive")
        self.validate(
            "CAST(1 AS INT)", "TRY_CAST(1 AS INTEGER)", read="hive", write="presto"
        )
        self.validate(
            "CAST(1 AS INT)", "CAST(1 AS INT)", read="hive", write="starrocks"
        )
예제 #13
0
 def test_identity(self):
     """Every line of identity.sql must transpile back to itself, stripped."""
     path = os.path.join(self.fixtures_dir, "identity.sql")
     with open(path, encoding="utf-8") as fixture:
         for line in fixture:
             generated = transpile(line)[0]
             self.assertEqual(generated, line.strip())
예제 #14
0
 def test_if(self):
     """IF(...) is rewritten as a CASE expression, with or without an ELSE branch."""
     expectations = [
         ('SELECT IF(a > 1, 1, 0) FROM foo',
          'SELECT CASE WHEN a > 1 THEN 1 ELSE 0 END FROM foo'),
         ('SELECT IF(a > 1, 1) FROM foo',
          'SELECT CASE WHEN a > 1 THEN 1 END FROM foo'),
     ]
     for query, expected in expectations:
         self.assertEqual(transpile(query)[0], expected)
예제 #15
0
 def test_postgres(self):
     """Postgres output spells the DOUBLE type as DOUBLE PRECISION."""
     outputs = transpile('SELECT CAST(`a`.`b` AS DOUBLE) FROM foo',
                         read='postgres', write='postgres')
     expected = 'SELECT CAST(`a`.`b` AS DOUBLE PRECISION) FROM foo'
     self.assertEqual(outputs[0], expected)
예제 #16
0
 def test_identity(self):
     """Every line of identity.sql must transpile back to itself, stripped.

     The fixture is opened as UTF-8 explicitly so the result does not depend
     on the platform's locale encoding.
     """
     with open(os.path.join(self.fixtures_dir, 'identity.sql'),
               encoding='utf-8') as f:
         for sql in f:
             self.assertEqual(transpile(sql)[0], sql.strip())
예제 #17
0
 def test_sqlite(self):
     """SQLite output renders a SMALLINT cast as INTEGER."""
     options = {'read': 'sqlite', 'write': 'sqlite'}
     query = 'SELECT CAST(`a`.`b` AS SMALLINT) FROM foo'
     self.assertEqual(transpile(query, **options)[0],
                      'SELECT CAST(`a`.`b` AS INTEGER) FROM foo')
예제 #18
0
    def test_presto(self):
        """Presto dialect: quoting, types, operators, functions and unsupported features.

        ``self.validate(sql, expected, **kwargs)`` is a helper defined outside
        this method; judging by its use it asserts that transpiling ``sql``
        with the given options yields ``expected``.
        """
        # Identifier quoting between Presto and Spark.
        self.validate(
            'SELECT "a"."b" FROM foo',
            'SELECT "a"."b" FROM "foo"',
            read="presto",
            write="presto",
            identify=True,
        )
        self.validate(
            "SELECT a.b FROM foo", "SELECT a.b FROM foo", read="presto", write="spark"
        )
        self.validate(
            'SELECT "a"."b" FROM foo',
            "SELECT `a`.`b` FROM `foo`",
            read="presto",
            write="spark",
            identify=True,
        )
        self.validate(
            "SELECT a.b FROM foo",
            "SELECT `a`.`b` FROM `foo`",
            read="presto",
            write="spark",
            identify=True,
        )
        # Array literals and parameterized / nested types.
        self.validate(
            "SELECT ARRAY[1, 2]", "SELECT ARRAY(1, 2)", read="presto", write="spark"
        )
        self.validate(
            "CAST(a AS ARRAY(INT))",
            "CAST(a AS ARRAY(INTEGER))",
            read="presto",
            write="presto",
        )
        self.validate(
            "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
            "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
            read="presto",
            write="presto",
        )
        self.validate(
            "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(ARRAY(INT(9))))",
            "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(ARRAY(INTEGER(9))))",
            read="presto",
            write="presto",
        )
        self.validate(
            "CAST(ARRAY[1, 2] AS ARRAY<BIGINT>)",
            "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
            read="presto",
            write="presto",
        )
        self.validate(
            "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
            "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
            read="presto",
            write="presto",
        )
        self.validate("CAST(a AS TEXT)", "CAST(a AS VARCHAR)", write="presto")
        self.validate("CAST(a AS STRING)", "CAST(a AS VARCHAR)", write="presto")
        self.validate(
            "CAST(a AS VARCHAR)", "CAST(a AS STRING)", read="presto", write="spark"
        )

        # Hive bitwise operators become Presto BITWISE_* functions.
        self.validate("x & 1", "BITWISE_AND(x, 1)", read="hive", write="presto")
        self.validate("~x", "BITWISE_NOT(x)", read="hive", write="presto")
        self.validate("x | 1", "BITWISE_OR(x, 1)", read="hive", write="presto")
        self.validate(
            "x << 1", "BITWISE_ARITHMETIC_SHIFT_LEFT(x, 1)", read="hive", write="presto"
        )
        self.validate(
            "x >> 1",
            "BITWISE_ARITHMETIC_SHIFT_RIGHT(x, 1)",
            read="hive",
            write="presto",
        )
        self.validate("x & 1 > 0", "BITWISE_AND(x, 1) > 0", read="hive", write="presto")

        # Regex matching: REGEXP_LIKE <-> RLIKE / REGEXP.
        self.validate("REGEXP_LIKE(a, 'x')", "a RLIKE 'x'", read="presto", write="hive")
        self.validate("a RLIKE 'x'", "REGEXP_LIKE(a, 'x')", read="hive", write="presto")
        self.validate(
            "a REGEXP 'x'", "REGEXP_LIKE(a, 'x')", read="hive", write="presto"
        )
        self.validate(
            "CASE WHEN x > 1 THEN 1 WHEN x > 2 THEN 2 END",
            "CASE WHEN x > 1 THEN 1 WHEN x > 2 THEN 2 END",
            write="presto",
        )

        # Array functions.
        self.validate(
            "ARRAY_CONTAINS(x, 1)", "CONTAINS(x, 1)", read="hive", write="presto"
        )
        self.validate("SIZE(x)", "CARDINALITY(x)", read="hive", write="presto")
        self.validate("CARDINALITY(x)", "SIZE(x)", read="presto", write="hive")
        self.validate("ARRAY_SIZE(x)", "CARDINALITY(x)", write="presto", identity=False)

        self.validate(
            "PERCENTILE(x, 0.5)",
            "APPROX_PERCENTILE(x, 0.5)",
            read="hive",
            write="presto",
            unsupported_level=ErrorLevel.IGNORE,
        )

        # String position: STRPOS (Presto) vs LOCATE (Hive, needle first).
        self.validate(
            "STR_POSITION(x, 'a')", "STRPOS(x, 'a')", write="presto", identity=False
        )
        self.validate(
            "STR_POSITION(x, 'a')", "LOCATE('a', x)", read="presto", write="hive"
        )
        self.validate("LOCATE('a', x)", "STRPOS(x, 'a')", read="hive", write="presto")
        self.validate(
            "LOCATE('a', x, 3)",
            "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
            read="hive",
            write="presto",
        )

        # Date/time format strings and conversion functions.
        self.validate(
            "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%s')",
            "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
            read="presto",
            write="hive",
        )
        self.validate(
            "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')",
            "FROM_UNIXTIME(UNIX_TIMESTAMP(x))",
            read="presto",
            write="hive",
        )
        self.validate(
            "DATE_PARSE(x, '%Y-%m-%d')",
            "FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-dd'))",
            read="presto",
            write="hive",
        )
        self.validate(
            "TIME_STR_TO_UNIX(x)",
            "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %H:%i:%S'))",
            write="presto",
        )
        self.validate(
            "TIME_STR_TO_TIME(x)",
            "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')",
            write="presto",
        )
        self.validate(
            "TIME_TO_TIME_STR(x)",
            "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
            write="presto",
        )
        self.validate(
            "UNIX_TO_TIME_STR(x)",
            "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d %H:%i:%S')",
            write="presto",
        )
        self.validate(
            "FROM_UNIXTIME(x)",
            "FROM_UNIXTIME(x)",
            read="presto",
            write="hive",
        )
        self.validate(
            "TO_UNIXTIME(x)",
            "UNIX_TIMESTAMP(x)",
            read="presto",
            write="hive",
        )
        self.validate(
            "DATE_ADD('day', 1, x)",
            "DATE_ADD(x, 1)",
            read="presto",
            write="hive",
        )
        self.validate(
            "DATE_DIFF('day', a, b)",
            "DATEDIFF(b, a)",
            read="presto",
            write="hive",
        )
        self.validate(
            "DATE_DIFF(a, b)",
            "DATE_DIFF('day', b, a)",
            write="presto",
            identity=False,
        )
        self.validate(
            "TS_OR_DS_TO_DATE(x)",
            "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
            write="presto",
            identity=False,
        )
        self.validate(
            "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
            "STR_TO_TIME(SUBSTR(x, 1, 10), '%Y-%m-%d')",
            read="presto",
            identity=False,
        )

        self.validate(
            "SELECT APPROX_DISTINCT(a) FROM foo",
            "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
            read="presto",
            write="spark",
        )

        # With unsupported features ignored, the second APPROX_DISTINCT argument is dropped.
        sql = transpile(
            "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
            read="presto",
            write="spark",
            unsupported_level=ErrorLevel.IGNORE,
        )[0]
        self.assertEqual(sql, "SELECT APPROX_COUNT_DISTINCT(a) FROM foo")

        # CTAS storage format: Presto WITH (FORMAT = ...) <-> Spark STORED AS.
        ctas = "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1"
        self.assertEqual(transpile(ctas, read="presto", write="presto")[0], ctas)

        sql = transpile(ctas, read="presto", write="spark")[0]
        self.assertEqual(sql, "CREATE TABLE test STORED AS PARQUET AS SELECT 1")

        sql = transpile(
            "SELECT JSON_EXTRACT(x, '$.name')", read="presto", write="spark"
        )[0]
        self.assertEqual(sql, "SELECT GET_JSON_OBJECT(x, '$.name')")

        # "\\w" yields the same runtime text as the former "\w" without relying
        # on an invalid escape sequence (a SyntaxWarning on Python 3.12+).
        self.validate(
            "INITCAP('new york')",
            "REGEXP_REPLACE('new york', '(\\w)(\\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
            read="hive",
            write="presto",
        )

        # Quote escaping: Presto doubles quotes, Hive output backslash-escapes them.
        self.validate("''''", "''''", read="presto", write="presto")
        self.validate("''''", "'\\''", read="presto", write="hive")
        self.validate("'x'", "'x'", read="presto", write="presto")
        self.validate("'x'", "'x'", read="presto", write="hive")
        self.validate("'''x'''", "'''x'''", read="presto", write="presto")
        self.validate("'''x'''", "'\\'x\\''", read="presto", write="hive")
        self.validate("'''x'", "'\\'x'", read="presto", write="hive")
        self.validate("'x'''", "'x\\''", read="presto", write="hive")

        # DuckDB STRUCT_EXTRACT becomes dotted, double-quoted field access.
        self.validate(
            "STRUCT_EXTRACT(x, 'abc')", 'x."abc"', read="duckdb", write="presto"
        )
        self.validate(
            "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')",
            'x."y"."abc"',
            read="duckdb",
            write="presto",
        )

        # Hive MONTH/DAY arguments are parsed as dates for Presto.
        self.validate("MONTH(x)", "MONTH(x)", read="presto", write="spark")
        self.validate("MONTH(x)", "MONTH(x)", read="presto", write="hive")
        self.validate(
            "MONTH(x)",
            "MONTH(DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d'))",
            read="hive",
            write="presto",
        )

        self.validate("DAY(x)", "DAY(x)", read="presto", write="hive")
        self.validate(
            "DAY(x)",
            "DAY(DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d'))",
            read="hive",
            write="presto",
        )
        # CONCAT_WS becomes ARRAY_JOIN for Presto.
        self.validate(
            "CONCAT_WS('-', 'a', 'b')",
            "ARRAY_JOIN(ARRAY['a', 'b'], '-')",
            write="presto",
        )
        self.validate("CONCAT_WS('-', x)", "ARRAY_JOIN(x, '-')", write="presto")
        self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="presto")
        self.validate(
            "CASE WHEN 1 THEN x ELSE 0 END",
            "CASE WHEN 1 THEN x ELSE 0 END",
            write="presto",
        )
        self.validate("x[y]", "x[y]", read="presto", identity=False)
        self.validate("x[y]", "x[y]", write="presto", identity=False)

        # The same APPROX_DISTINCT accuracy argument raises at RAISE level.
        with self.assertRaises(UnsupportedError):
            transpile(
                "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
                read="presto",
                write="spark",
                unsupported_level=ErrorLevel.RAISE,
            )

        # TABLESAMPLE is dropped when ignored and raises at RAISE level.
        self.validate(
            "SELECT * FROM x TABLESAMPLE(10)",
            "SELECT * FROM x",
            read="hive",
            write="presto",
            unsupported_level=ErrorLevel.IGNORE,
        )

        # Non-ASCII string literals round-trip unchanged.
        self.validate("'\u6bdb'", "'\u6bdb'", read="presto")

        with self.assertRaises(UnsupportedError):
            transpile(
                "SELECT * FROM x TABLESAMPLE(10)",
                read="hive",
                write="presto",
                unsupported_level=ErrorLevel.RAISE,
            )

        # identify=True quotes identifiers but not literals such as NULL/TRUE.
        self.validate(
            "SELECT NULL as foo FROM baz",
            'SELECT NULL AS "foo" FROM "baz"',
            read="presto",
            write="presto",
            identify=True,
        )
        self.validate(
            "SELECT true as foo FROM baz",
            'SELECT TRUE AS "foo" FROM "baz"',
            read="presto",
            write="presto",
            identify=True,
        )
        self.validate(
            "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz",
            "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz",
            read="presto",
            write="presto",
            identify=False,
        )
        self.validate(
            "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz",
            'SELECT IF(COALESCE("bar", 0) = 1, TRUE, FALSE) AS "foo" FROM "baz"',
            read="hive",
            write="presto",
            identify=True,
        )
        # LATERAL VIEW EXPLODE / POSEXPLODE <-> CROSS JOIN UNNEST [WITH ORDINALITY].
        self.validate(
            "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
            "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t (a) CROSS JOIN UNNEST(z) AS u (b)",
            write="presto",
        )
        self.validate(
            "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
            "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
            write="presto",
        )
        self.validate(
            "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a",
            "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t (a)",
            write="presto",
        )

        # NOTE(review): no space between ")" and "AS" in the input below;
        # presumably deliberate (tokenizer check) — confirm.
        self.validate(
            "SELECT a FROM x CROSS JOIN UNNEST(ARRAY(y))AS t (a)",
            "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
            read="presto",
            write="hive",
        )
        self.validate(
            "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
            "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t (a)",
            read="hive",
            write="presto",
        )
예제 #19
0
 def test_msaccess(self):
     """Bracket-quoted identifiers survive an MS Access round-trip unchanged."""
     # NOTE(review): the dialect name is spelled 'msacess' (one "c"). Confirm
     # this matches the registered dialect name and is not a typo.
     dialect = 'msacess'
     query = 'SELECT [a].[b] FROM [foo]'
     self.assertEqual(transpile(query, read=dialect, write=dialect)[0], query)
예제 #20
0
 def validate(self, sql, target, **kwargs):
     """Assert that the first statement transpiled from ``sql`` equals ``target``.

     ``kwargs`` are forwarded unchanged to ``transpile``.
     """
     generated = transpile(sql, **kwargs)[0]
     self.assertEqual(generated, target)
예제 #21
0
    def test_presto(self):
        """Presto quoting, Presto -> Spark rewrites and unsupported-argument handling."""
        # Each entry: (input SQL, transpile options, expected first output).
        cases = [
            ('SELECT "a"."b" FROM foo',
             dict(read='presto', write='presto', identify=True),
             'SELECT "a"."b" FROM "foo"'),
            ('SELECT a.b FROM foo',
             dict(read='presto', write='spark'),
             'SELECT a.b FROM foo'),
            ('SELECT "a"."b" FROM foo',
             dict(read='presto', write='spark', identify=True),
             'SELECT `a`.`b` FROM `foo`'),
            ('SELECT a.b FROM foo',
             dict(read='presto', write='spark', identify=True),
             'SELECT `a`.`b` FROM `foo`'),
            ('SELECT APPROX_DISTINCT(a) FROM foo',
             dict(read='presto', write='spark'),
             'SELECT APPROX_COUNT_DISTINCT(a) FROM foo'),
            # With unsupported features ignored, the second argument is dropped.
            ('SELECT APPROX_DISTINCT(a, 0.1) FROM foo',
             dict(read='presto', write='spark',
                  unsupported_level=ErrorLevel.IGNORE),
             'SELECT APPROX_COUNT_DISTINCT(a) FROM foo'),
        ]
        for source_sql, options, expected in cases:
            self.assertEqual(transpile(source_sql, **options)[0], expected)

        # CTAS storage format: Presto WITH (FORMAT = ...) <-> Spark STORED AS.
        ctas = "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1"
        self.assertEqual(transpile(ctas, read='presto', write='presto')[0], ctas)
        self.assertEqual(transpile(ctas, read='presto', write='spark')[0],
                         "CREATE TABLE test STORED AS PARQUET AS SELECT 1")

        json_sql = transpile("SELECT JSON_EXTRACT(x, '$.name')",
                             read='presto', write='spark')[0]
        self.assertEqual(json_sql, "SELECT GET_JSON_OBJECT(x, '$.name')")

        # The same unsupported argument raises when unsupported_level is RAISE.
        with self.assertRaises(UnsupportedError):
            transpile('SELECT APPROX_DISTINCT(a, 0.1) FROM foo',
                      read='presto', write='spark',
                      unsupported_level=ErrorLevel.RAISE)
예제 #22
0
파일: __main__.py 프로젝트: blthree/sqlglot
import sys

import sqlglot

# Pretty-print the SQL given as the first command-line argument, reading and
# writing the Spark dialect, one generated statement per line.
if len(sys.argv) < 2:
    # Report a usage line instead of crashing with an IndexError traceback.
    sys.exit("usage: python -m sqlglot '<sql>'")

for sql in sqlglot.transpile(sys.argv[1],
                             read='spark',
                             write='spark',
                             pretty=True):
    print(sql)
예제 #23
0
    "--parse",
    dest="parse",
    action="store_true",
    help="Parse and return the expression tree",
)
parser.add_argument(
    "--error-level",
    dest="error_level",
    type=str,
    default="RAISE",
    help="IGNORE, WARN, RAISE (default)",
)

args = parser.parse_args()

# The flag is matched case-insensitively against the ErrorLevel member names.
# An unknown value is reported through the argument parser (usage + exit
# status 2) instead of escaping as a bare KeyError traceback.
try:
    error_level = sqlglot.ErrorLevel[args.error_level.upper()]
except KeyError:
    parser.error(
        "argument --error-level: invalid choice: {!r} (choose from {})".format(
            args.error_level,
            ", ".join(level.name for level in sqlglot.ErrorLevel),
        )
    )

if args.parse:
    # --parse: print the parsed expression trees instead of regenerated SQL.
    sqls = sqlglot.parse(args.sql, read=args.read, error_level=error_level)
else:
    sqls = sqlglot.transpile(
        args.sql,
        read=args.read,
        write=args.write,
        identify=args.identify,
        pretty=args.pretty,
        error_level=error_level,
    )

for sql in sqls:
    print(sql)