Beispiel #1
0
    def test_ctas(self):
        """Check ``Rewriter.ctas`` output as rendered by the Hive generator.

        Covers a plain CTAS, a CTAS with ``db`` and ``file_format``, that the
        default (copying) rewriter leaves its input expression unchanged,
        chained ``ctas`` calls, and in-place rewriting with ``copy=False``.
        """
        source = parse("SELECT * FROM y")[0]
        sql = Hive().generator().generate

        # Plain CREATE TABLE ... AS SELECT.
        plain = Rewriter(source).ctas('x').expression
        self.assertEqual(sql(plain), 'CREATE TABLE x AS SELECT * FROM y')

        # Database-qualified target plus a storage format.
        qualified = Rewriter(source).ctas('x', db='foo', file_format='parquet').expression
        self.assertEqual(
            sql(qualified),
            'CREATE TABLE foo.x STORED AS parquet AS SELECT * FROM y'
        )

        # The two rewrites above must not have altered the original SELECT.
        self.assertEqual(sql(source), 'SELECT * FROM y')

        # A second ctas() on the same rewriter replaces the target table name.
        chained = Rewriter(source).ctas('x')
        self.assertEqual(sql(chained.expression), 'CREATE TABLE x AS SELECT * FROM y')
        self.assertEqual(
            sql(chained.ctas('y').expression),
            'CREATE TABLE y AS SELECT * FROM y'
        )

        # With copy=False the rewrite is applied to the passed-in expression itself.
        existing_ctas = parse("CREATE TABLE x AS SELECT * FROM y")[0]
        Rewriter(existing_ctas, copy=False).ctas('x', file_format='ORC')
        self.assertEqual(
            sql(existing_ctas),
            'CREATE TABLE x STORED AS ORC AS SELECT * FROM y'
        )
Beispiel #2
0
    def test_findall(self):
        """findall(exp.Table) visits every table, including ones in nested subqueries.

        The expected list pins the order in which findall yields the tables
        for this query.
        """
        # NOTE(review): "CROSS JOIN (...) z ON x.c = y.foo" repeats the previous
        # JOIN's condition; CROSS JOIN normally takes no ON clause, so this looks
        # like a copy-paste slip in the fixture. Kept as-is because the expected
        # table order depends only on the tables, not on the join conditions.
        query = """
            SELECT *
            FROM (
                SELECT b.*
                FROM a.b b
            ) x
            JOIN (
              SELECT c.foo
              FROM a.c c
              WHERE foo = 1
            ) y
              ON x.c = y.foo
            CROSS JOIN (
              SELECT *
              FROM (
                SELECT d.bar
                FROM d
              ) nested
            ) z
              ON x.c = y.foo
            """
        expression = parse(query)[0]

        # Each findall item unpacks into three values; only the first (the
        # matched node) is needed to read the table name.
        table_names = [node.args['this'].text for node, _, _ in expression.findall(exp.Table)]
        self.assertEqual(table_names, ['d', 'c', 'b'])
Beispiel #3
0
    def test_multi(self):
        """Parse a string holding two ';'-terminated SELECTs into two expressions.

        Uses ``self.assertEqual`` rather than bare ``assert``. Bare asserts are
        stripped under ``python -O`` and report no values on failure.
        """
        expressions = parse("""
            SELECT * FROM a; SELECT * FROM b;
        """)

        self.assertEqual(len(expressions), 2)
        # Walk from each statement to the name stored under the first entry of
        # its FROM clause and check it matches the table in the source text.
        for expression, table in zip(expressions, ("a", "b")):
            self.assertEqual(
                expression.args["from"].args["expressions"][0].args["this"].args["this"],
                table,
            )
Beispiel #4
0
def bench_sqlglot():
    """Benchmark target: parse the module-level ``TEST_SQL`` with sqlglot and return the result."""
    parsed = sqlglot.parse(TEST_SQL)
    return parsed
Beispiel #5
0
 def test_find(self):
     """find() is truthy for a present node type and falsy for an absent one; findall lists tables."""
     tree = parse("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")[0]

     self.assertTrue(tree.find(exp.Create))
     self.assertFalse(tree.find(exp.Group))

     # Target table first, then the table read by the SELECT.
     names = [node.args['this'].text for node, _, _ in tree.findall(exp.Table)]
     self.assertEqual(names, ['x', 'y'])
Beispiel #6
0
 def test_command(self):
     """SET / ADD JAR commands and a SELECT, split on ';', each regenerate their own text via .sql()."""
     expected_sql = ["SET x = 1", "ADD JAR s3://a", "SELECT 1"]
     parsed = parse("SET x = 1; ADD JAR s3://a; SELECT 1")

     self.assertEqual(len(parsed), 3)
     for node, want in zip(parsed, expected_sql):
         self.assertEqual(node.sql(), want)
Beispiel #7
0
def sqlglot_parse(sql):
    """Parse ``sql`` with sqlglot using ``ErrorLevel.IGNORE`` and return the result.

    The IGNORE level is passed presumably so malformed input does not abort the
    caller (e.g. a benchmark run). The parsed expressions are returned, as
    ``bench_sqlglot`` does, instead of being discarded. Callers that ignore the
    return value behave exactly as before.
    """
    return sqlglot.parse(sql, error_level=sqlglot.ErrorLevel.IGNORE)
Beispiel #8
0
    "--parse",
    dest="parse",
    action="store_true",
    help="Parse and return the expression tree",
)
# Normalize the error level to upper case and let argparse validate it against
# sqlglot.ErrorLevel's member names. An unknown value then produces a usage
# error instead of a KeyError traceback at the enum lookup below.
parser.add_argument(
    "--error-level",
    dest="error_level",
    type=str.upper,
    choices=list(sqlglot.ErrorLevel.__members__),
    default="RAISE",
    help="IGNORE, WARN, RAISE (default)",
)

args = parser.parse_args()
# Already upper-cased and validated by argparse, so this lookup cannot fail.
error_level = sqlglot.ErrorLevel[args.error_level]

if args.parse:
    # --parse: print the parsed expression trees instead of transpiled SQL.
    sqls = sqlglot.parse(args.sql, read=args.read, error_level=error_level)
else:
    sqls = sqlglot.transpile(
        args.sql,
        read=args.read,
        write=args.write,
        identify=args.identify,
        pretty=args.pretty,
        error_level=error_level,
    )

for sql in sqls:
    print(sql)