def batch_denormalize():
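    """
    Run the denormalizer unit test over the Spider training examples whose
    SQL appears in `complex_queries`, dropping into pdb on failures.
    Expects the data directory as the first command-line argument.
    """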
    data_dir = sys.argv[1]
    schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
    train_data = load_data_split_spider(data_dir, 'train', schema_graphs)
    # random.shuffle(train_data)
    for i, example in enumerate(train_data):
        if example.program not in complex_queries:
            continue
        schema_graph = schema_graphs.get_schema(example.db_id)
        if DEBUG:
            denormalizer_unit_test(example.program, schema_graph)
        else:
            try:
                denormalizer_unit_test(example.program, schema_graph)
            except (KeyError, AssertionError, pyparsing.ParseException) as e:
                example.pretty_print(schema_graph)
                print(str(e))
                import pdb
                pdb.set_trace()
        if i > 0 and i % 500 == 0:
            print('{} examples processed'.format(i))
def test_restore_clause_order():
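    """
    Check that denormalized SQL survives a round trip through execution order:
    denormalize -> convert_to_execution_order -> restore_clause_order should
    reproduce the denormalized query. Reads a parsed JSON file (argv[1]) and
    the Spider data directory (argv[2]).
    """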
    in_json = sys.argv[1]
    data_dir = sys.argv[2]
    schema_graphs = load_schema_graphs_spider(data_dir)

    with open(in_json) as f:
        content = json.load(f)
        num_errors = 0
        for i, example in enumerate(content):
            sql = example['query']
            print('Orig SQL: {}'.format(sql))
            db_name = example['db_id']
            schema = schema_graphs[db_name]
            # sn_sql = shallow_normalize(sql, schema)
            dn_sql, _ = denormalize(sql, schema)
            print('DN SQL:\t\t{}'.format(dn_sql))
            eo_sql = convert_to_execution_order(dn_sql, schema)
            restored_sql = restore_clause_order(eo_sql, schema)
            print('Restored SQL:\t{}'.format(restored_sql))
            # print('EO Pred SQL: {}'.format(eo_sql))
            # print('SN SQL:\t\t{}'.format(sn_sql))
            print()
            if dn_sql != restored_sql:
                num_errors += 1
                import pdb
                pdb.set_trace()
        print('{}/{} errors detected'.format(num_errors, len(content)))
def check_foreign_keys_in_queries():
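    """
    Scan the train and dev splits for foreign key pairs that appear in the
    gold SQL but are missing from the corresponding schema graph, and break
    into pdb whenever one is found.
    """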
    data_dir = sys.argv[1]
    dataset = sys.argv[2]
    schema_graphs = load_schema_graphs_spider(data_dir)
    train_data = load_data_split_spider(data_dir, 'train', schema_graphs)
    dev_data = load_data_split_spider(data_dir, 'dev', schema_graphs)
    tables = load_spider_tables(os.path.join(data_dir, 'tables.json'))
    table_dict = dict()
    for table in tables:
        table_dict[table['db_id']] = table
    in_json = os.path.join(data_dir, '{}.parsed.json'.format(dataset))
    with open(in_json) as f:
        parsed_sqls = json.load(f)

    for i, example in enumerate(train_data + dev_data):
        schema_graph = schema_graphs.get_schema(example.db_id)
        ast, _ = get_ast(example.program,
                         parsed_sqls,
                         denormalize_sql=True,
                         schema_graph=schema_graph)
        foreign_keys_readable, foreign_keys = extract_foreign_keys(
            ast, schema_graph)
        for f_key in foreign_keys:
            if tuple(sorted(f_key)) not in schema_graph.foreign_keys:
                print(example.program)
                print(json.dumps(ast, indent=4))
                print('Missing foreign key detected:')
                print('- {}'.format(schema_graph.get_field_signature(
                    f_key[0])))
                print('- {}'.format(schema_graph.get_field_signature(
                    f_key[1])))
                import pdb
                pdb.set_trace()
def batch_tokenize(data_dir, db_name, denormalization=False):
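    """
    Tokenize the denormalized SQL queries of the Spider dev split and report
    the average token count per query. Only db_name == 'spider' is supported.
    """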
    print('Tokenizing {}'.format(db_name))
    num_tokens_list = []
    if db_name == 'spider':
        schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
        in_json = os.path.join(data_dir, 'dev.json')
        with open(in_json) as f:
            content = json.load(f)
            for i, example in enumerate(content):
                sql = example['query']
                if sql.endswith(';'):
                    sql = sql[:-1]
                example_db_name = example['db_id']
                schema = schema_graphs[example_db_name]
                ast = parse(sql)
                denormalized_ast, _ = denormalize(ast,
                                                  schema,
                                                  return_parse_tree=True)
                tokens, token_types, constants = \
                    sql_tokenize(denormalized_ast, bu.tokenizer.tokenize,
                                 return_token_types=True, schema=schema,
                                 keep_singleton_fields=True, atomic_value=True,
                                 num_token=' <NUM> ', str_token=' <STRING> ')
                num_tokens_list.append(len(tokens))
    else:
        raise NotImplementedError
    print('{}: avg # tokens = {}'.format(db_name, np.mean(num_tokens_list)))
def test_schema_consistency():
    data_dir = sys.argv[1]
    db_name = 'flight_2'
    schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
    schema = schema_graphs[db_name]
    schema.pretty_print()

    in_sql = 'SELECT singer.Name FROM concert JOIN singer_in_concert ON singer_in_concert.Singer_ID = singer.Singer_ID WHERE concert.Year = 2014'
    in_sql = 'SELECT singer.concert FROM singer WHERE singer.age  >  (SELECT avg(singer.age) FROM singer)'
    in_sql = 'SELECT singer.Name, singer.Country FROM singer INTERSECT SELECT singer.Name, singer.Country, singer.Age FROM singer WHERE singer.Age = "?" ORDER BY singer.Age DESC'
    in_sql = 'SELECT COUNT(*) FROM singer'
    in_sql = 'SELECT concert.concert_Name, concert.Theme, COUNT(*) FROM concert GROUP BY concert.Theme, concert.Theme'
    in_sql = 'SELECT T2.name FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id  =  T2.singer_id JOIN concert AS T3 ON T1.concert_id  =  T3.concert_id WHERE T3.year  =  2014'
    in_sql = 'SELECT AIRPORTS.AirportCode FROM AIRPORTS JOIN FLIGHTS ON AIRPORTS.AirportCode = FLIGHTS.DestAirport OR AIRPORTS.AirportCode = FLIGHTS.SourceAirport GROUP BY AIRPORTS.AirportCode ORDER BY COUNT(*) DESC LIMIT 1'

    # in_sql = 'from singer select singer.Name , singer.Country union from singer where singer.Age = "age" select singer.Name , singer.Country , singer.Age order by singer.Age desc'
    # in_sql = 'from stadium join concert on stadium.Stadium_ID = concert.Stadium_ID where stadium.Capacity = (from stadium select max (stadium.Capacity)) select count (*)'
    # in_sql = 'from singer join singer_in_concert on singer.Singer_ID = singer_in_concert.Singer_ID join concert on singer_in_concert.Singer_ID = concert.concert_ID where concert.Year = 2014 select singer.Name'
    # in_sql = 'from Students select Students.other_student_details order by Students.other_student_details desc limit  <UNK>'
    # in_sql = 'from Sections join Sections on Addresses.address_id = * where Sections.section_name = "h" select Sections.section_name'
    # in_sql = 'from stadium join concert on stadium.Stadium_ID = concert.Stadium_ID where concert.Year = 2014 select stadium.Name , stadium.Location intersect from concert join stadium on stadium.Stadium_ID = stadium.Stadium_ID where concert.Year = 2015 select stadium.Name , stadium.Location'

    # in_sql, _ = denormalize(in_sql, schema)
    ast = parse(in_sql)
    check_schema_consistency(ast, schema, verbose=True)
def test_execution_order():
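    """
    Convert a few complex queries to execution order, re-parse them, and
    assert that the execution-order parse matches the original AST.
    """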
    in_sql = sys.argv[1]
    in_sql = "SELECT song_name FROM singer WHERE age  >  (SELECT avg(age) FROM singer)"
    data_dir = sys.argv[2]
    db_name = sys.argv[3]
    schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
    in_sqls = complex_queries[:4]
    db_names = ['flight_4', 'academic', 'baseball_1', 'voter_2']
    for db_name, in_sql in zip(db_names, in_sqls):
        schema = schema_graphs[db_name]
        ast = parse(in_sql)
        # print(json.dumps(ast, indent=4))
        ast_c = copy.deepcopy(ast)
        eo_sql = format(ast, schema, in_execution_order=True)
        eo_tokens = tokenize(in_sql,
                             bu.tokenizer.tokenize,
                             schema=schema,
                             in_execution_order=True)
        print('in_sql: {}'.format(in_sql))
        print('eo_sql: {}'.format(eo_sql))
        # print('eo_tokens: {}'.format(eo_tokens))
        eo_ast = eo_parse(eo_sql)
        assert (json.dumps(ast_c,
                           sort_keys=True) == json.dumps(eo_ast,
                                                         sort_keys=True))
        # print(json.dumps(eo_ast, indent=4))
        restored_sql = format(eo_ast, schema)
        # print('restored_sql: {}'.format(restored_sql))
        print()
def load_data_spider(args):
    """
    Load the Spider dataset released by Yu et al. (2018).
    """
    in_dir = args.data_dir
    dataset = dict()
    schema_graphs = load_schema_graphs_spider(
        in_dir,
        'spider',
        augment_with_wikisql=args.augment_with_wikisql,
        db_dir=args.db_dir)
    dataset['train'] = load_data_split_spider(
        in_dir,
        'train',
        schema_graphs,
        get_data_augmentation_tag(args),
        augment_with_wikisql=args.augment_with_wikisql)
    dataset['dev'] = load_data_split_spider(
        in_dir,
        'dev',
        schema_graphs,
        augment_with_wikisql=args.augment_with_wikisql)
    dataset['schema'] = schema_graphs

    fine_tune_set = load_data_split_spider(
        in_dir,
        'fine-tune',
        schema_graphs,
        augment_with_wikisql=args.augment_with_wikisql)
    if fine_tune_set:
        dataset['fine-tune'] = fine_tune_set
    return dataset
def test_value_extractor():
    in_sql = 'SELECT singer.Name, singer.Country FROM singer INTERSECT SELECT singer.Name, singer.Country, singer.Age FROM singer WHERE singer.Age = "?" ORDER BY singer.Age DESC'
    in_sql = 'SELECT avg(age) FROM Student WHERE StuID IN ( SELECT T1.StuID FROM Has_allergy AS T1 JOIN Allergy_Type AS T2 ON T1.Allergy  =  T2.Allergy WHERE T2.allergytype  =  "animal" INTERSECT SELECT T1.StuID FROM Has_allergy AS T1 JOIN Allergy_Type AS T2 ON T1.Allergy  =  T2.Allergy WHERE T2.allergytype  =  "animal")'
    in_sql = 'SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id  =  T2.head_id WHERE T2.temporary_acting  =  \'Yes\''
    in_sql = 'SELECT t3.title FROM authors AS t1 JOIN authorship AS t2 ON t1.authid  =  t2.authid JOIN papers AS t3 ON t2.paperid  =  t3.paperid JOIN inst AS t4 ON t2.instid  =  t4.instid WHERE t4.country  =  \"USA\" AND t2.authorder  =  2 AND t1.lname  =  \"Turon\"'
    data_dir = sys.argv[1]
    db_name = sys.argv[2]
    schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
    schema = schema_graphs[db_name]
    value_extractor_unit_test(in_sql, schema)
def test_restore_clause_order():
    in_sql = 'from (from countrylanguage where countrylanguage.Language = "spanish" select max (countrylanguage.Percentage)) as T0 JOIN countrylanguage ON T0. = countrylanguage.Percentage select count (*)'
    in_sql = 'from poker_player where poker_player.Earnings = (from poker_player select sum (poker_player.Earnings)) select poker_player.Money_Rank order by poker_player.Earnings desc limit 1'
    in_sql = 'from Student join Has_Pet on Student.StuID = Has_Pet.StuID join Pets on Has_Pet.PetID = Pets.PetID where Pets.PetType = "cat" select Student.Major , Student.Major'
    data_dir = sys.argv[1]
    db_name = sys.argv[2]
    schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
    schema = schema_graphs[db_name]
    print('eo_sql: {}'.format(in_sql))
    restored_sql = restore_clause_order(in_sql, schema)
    print('restored_sql: {}'.format(restored_sql))
def ensemble():
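    """
    Run inference on the Spider dev split with an ensemble of trained
    EncoderDecoderLFramework checkpoints and write the top prediction for
    each example to args.prediction_path.
    """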
    text_tokenize, program_tokenize, post_process, table_utils = tok.get_tokenizers(
        args)
    schema_graphs = schema_loader.load_schema_graphs_spider(
        args.codalab_data_dir, 'spider', db_dir=args.codalab_db_dir)
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize,
                                    normalized=(args.model_id
                                                in [utils.BRIDGE]))
    text_vocab = Vocabulary('text',
                            func_token_index=functional_token_index,
                            tu=table_utils)
    for v in table_utils.tokenizer.vocab:
        text_vocab.index_token(v, True,
                               table_utils.tokenizer.convert_tokens_to_ids(v))
    program_vocab = sql_reserved_tokens if args.pretrained_transformer else sql_reserved_tokens_revtok
    vocabs = {'text': text_vocab, 'program': program_vocab}
    examples = data_loader.load_data_split_spider(args.codalab_data_dir, 'dev',
                                                  schema_graphs)
    print('{} {} examples loaded'.format(len(examples), 'dev'))
    for i, example in enumerate(examples):
        schema_graph = schema_graphs.get_schema(example.db_id)
        preprocess_example('dev', example, args, None, text_tokenize,
                           program_tokenize, post_process, table_utils,
                           schema_graph, vocabs)
    print('{} {} examples processed'.format(len(examples), 'dev'))

    checkpoint_paths = [
        'ensemble_models/model1.tar', 'ensemble_models/model2.tar',
        'ensemble_models/model3.tar'
    ]

    sps = [EncoderDecoderLFramework(args) for _ in checkpoint_paths]
    for i, checkpoint_path in enumerate(checkpoint_paths):
        sps[i].schema_graphs = schema_graphs
        sps[i].load_checkpoint(checkpoint_path)
        sps[i].cuda()
        sps[i].eval()

    out_dict = sps[0].inference(
        examples,
        restore_clause_order=args.process_sql_in_execution_order,
        check_schema_consistency_=args.sql_consistency_check,
        inline_eval=False,
        model_ensemble=[sp.mdl for sp in sps],
        verbose=False)

    assert (sps[0].args.prediction_path is not None)
    out_txt = sps[0].args.prediction_path
    with open(out_txt, 'w') as o_f:
        for pred_sql in out_dict['pred_decoded']:
            o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))
def test_tokenizer():
    # for sql in complex_queries:
    if True:
        sql = 'SELECT avg(age) FROM Student WHERE StuID IN ( SELECT T1.StuID FROM Has_allergy AS T1 JOIN Allergy_Type AS T2 ON T1.Allergy  =  T2.Allergy WHERE T2.allergytype  =  "food" INTERSECT SELECT T1.StuID FROM Has_allergy AS T1 JOIN Allergy_Type AS T2 ON T1.Allergy  =  T2.Allergy WHERE T2.allergytype  =  "animal")'
        sql = 'SELECT T1.Name FROM Tourist_Attractions AS T1 JOIN VISITORS AS T2 JOIN VISITS AS T3 ON T1.Tourist_Attraction_ID  =  T3.Tourist_Attraction_ID AND T2.Tourist_ID  =  T3.Tourist_ID WHERE T2.Tourist_Details  =  "Vincent" INTERSECT SELECT T1.Name FROM Tourist_Attractions AS T1 JOIN VISITORS AS T2 JOIN VISITS AS T3 ON T1.Tourist_Attraction_ID  =  T3.Tourist_Attraction_ID AND T2.Tourist_ID  =  T3.Tourist_ID WHERE T2.Tourist_Details  =  "Marcelle"'
        sql = "SELECT Perpetrator_ID FROM perpetrator WHERE Year IN ('1995.0', '1994.0', '1982.0')"
        print(sql)
        data_dir = sys.argv[1]
        db_name = sys.argv[2]
        schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
        schema = schema_graphs[db_name]
        tokens = tokenize(sql,
                          bu.tokenizer.tokenize,
                          in_execution_order=True,
                          schema=schema)[0]
        print(tokens)
        print()