def test_sql_delimiter():
    """Copy of test_sql_with_commands, but with an unconventional sql_delimiter.

    This test verifies not only that a custom delimiter splits SQL commands
    correctly, but also that each split statement is terminated with a
    semicolon instead of the reserved custom keyword.

    Because split_statements is a generator, the delimiter must not be passed
    in as a plain string: it might change while the generator is running, so
    SnowSQL passes in its cli object instead. Here a SQLDelimiter object plays
    that role. This test makes sure that behaviour is not broken by mistake.
    """
    # 'imi' is deliberately a substring of 'limit' (see the LINEITEM line), so
    # the splitter must only split where the delimiter stands alone as a token.
    delimiter = SQLDelimiter('imi')
    with StringIO(("create or replace view aaa\n"
                   "        as select * from\n"
                   "        LINEITEM limit 1000 {delimiter}\n"
                   "!spool $outfile\n"
                   "show views like 'AAA'{delimiter}\n"
                   "!spool off\n"
                   "drop view if exists aaa {delimiter}\n"
                   "show tables").format(delimiter=delimiter.sql_delimiter)) as f:
        itr = split_statements(f, delimiter=delimiter)
        # The custom delimiter is replaced by ';' in the yielded statement.
        assert next(itr) == ("create or replace view aaa\n"
                             "        as select * from\n"
                             "        LINEITEM limit 1000 ;", False)

        assert next(itr) == ("!spool $outfile", False)
        assert next(itr) == ("show views like 'AAA';", False)
        assert next(itr) == ("!spool off", False)
        assert next(itr) == ("drop view if exists aaa ;", False)
        # The trailing statement has no delimiter and is yielded unchanged.
        assert next(itr) == ("show tables", False)
        # Check exhaustion while the buffer is still open: if the generator
        # needed to read again, a closed buffer would raise ValueError
        # instead of StopIteration and mask the real result.
        with pytest.raises(StopIteration):
            next(itr)
# Example #2
# 0
def test_sql_splitting_various(sql, delimiter, split_stmnts):
    """Test various smaller SQL-splitting pitfalls.

    ``sql`` is split using ``delimiter`` (wrapped in a SQLDelimiter). Only the
    statement text, i.e. the first element of each yielded item, is compared
    against ``split_stmnts``.

    NOTE(review): the arguments are presumably supplied by a
    ``pytest.mark.parametrize`` decorator not visible here. TODO confirm.
    """
    with StringIO(sql) as sqlio:
        # A plain list comprehension; wrapping it in list() only made a copy.
        statements = [
            s[0] for s in split_statements(sqlio, delimiter=SQLDelimiter(delimiter))
        ]
    assert statements == split_stmnts


def test_sql_splitting_tokenization():
    """Test that sql_delimiter matching is token sensitive.

    Every distinct non-space character of ``raw_sql`` is used in turn as the
    delimiter. Each such character also occurs inside ``raw_sql`` itself, so
    the input must be split exactly once, where the character stands alone.
    """
    raw_sql = "select 123 as asd"
    # sorted() gives a deterministic order: iterating a set of str depends on
    # per-process hash randomization, so failures would otherwise surface in
    # a different order on each run.
    for c in sorted(set(raw_sql.replace(' ', ''))):
        sql = raw_sql + ' ' + c + ' ' + raw_sql
        with StringIO(sql) as sqlio:
            s = split_statements(sqlio, delimiter=SQLDelimiter(c))
            # The delimiter token is replaced by ';' in the first statement.
            assert next(s)[0] == raw_sql + ' ;', 'delimiter {!r}'.format(c)
            assert next(s)[0] == raw_sql, 'delimiter {!r}'.format(c)