Example #1
0
 def configuration(cls,
                   plm=None,
                   method='lgesql',
                   table_path='data/tables.json',
                   tables='data/tables.bin',
                   db_dir='data/database'):
     """Initialize class-level preprocessing state shared by all instances.

     Args:
         plm: name of a pretrained language model under ./pretrained_models,
             or None to fall back to GloVe word vectors.
         method: graph-encoding method name passed to GraphFactory.
         table_path: path to the raw tables JSON (used by the Evaluator).
         tables: either a path to a pickled schema file or an
             already-loaded tables object.
         db_dir: directory containing the SQLite databases.
     """
     cls.plm, cls.method = plm, method
     cls.grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
     cls.trans = TransitionSystem.get_class_by_lang('sql')(cls.grammar)
     # Accept either a filepath (load the pickle, closing the handle) or a
     # pre-loaded tables object; isinstance is the idiomatic type check.
     if isinstance(tables, str):
         with open(tables, 'rb') as f:
             cls.tables = pickle.load(f)
     else:
         cls.tables = tables
     cls.evaluator = Evaluator(cls.trans, table_path, db_dir)
     if plm is None:
         # No PLM: use GloVe embeddings with an identity tokenizer.
         cls.word2vec = Word2vecUtils()
         cls.tokenizer = lambda x: x
         cls.word_vocab = Vocab(
             padding=True,
             unk=True,
             boundary=True,
             default=UNK,
             filepath='./pretrained_models/glove.42b.300d/vocab.txt',
             specials=SCHEMA_TYPES)  # word vocab for glove.42B.300d
     else:
         # PLM path: tokenizer and vocabulary come from the pretrained model.
         cls.tokenizer = AutoTokenizer.from_pretrained(
             os.path.join('./pretrained_models', plm))
         cls.word_vocab = cls.tokenizer.get_vocab()
     cls.relation_vocab = Vocab(padding=False,
                                unk=False,
                                boundary=False,
                                iterable=RELATIONS,
                                default=None)
     cls.graph_factory = GraphFactory(cls.method, cls.relation_vocab)
Example #2
0
def process_dataset(processor, dataset, tables, output_path=None, skip_large=False, verbose=False):
    """Preprocess every example in *dataset* with *processor*.

    Args:
        processor: example-level preprocessor forwarded to process_example.
        dataset: iterable of example dicts, each with a 'db_id' key.
        tables: mapping from db_id to its schema dict ('column_names', ...).
        output_path: if given, pickle the processed list to this path.
        skip_large: skip examples whose database has more than 100 columns.
        verbose: print a progress line per example.

    Returns:
        The list of processed examples (skipped ones excluded).
    """
    from utils.constants import GRAMMAR_FILEPATH
    grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
    trans = TransitionSystem.get_class_by_lang('sql')(grammar)
    processed_dataset = []
    for idx, entry in enumerate(dataset):
        # Very wide schemas blow up the graph size; optionally skip them.
        if skip_large and len(tables[entry['db_id']]['column_names']) > 100: continue
        if verbose:
            print('*************** Processing %d-th sample **************' % (idx))
        entry = process_example(processor, entry, tables[entry['db_id']], trans, verbose=verbose)
        processed_dataset.append(entry)
    print('In total, process %d samples , skip %d extremely large databases.' % (len(processed_dataset), len(dataset) - len(processed_dataset)))
    if output_path is not None:
        # serialize preprocessed dataset; `with` guarantees flush/close
        with open(output_path, 'wb') as f:
            pickle.dump(processed_dataset, f)
    return processed_dataset
Example #3
0
                return action_list
            elif realized_field.value is not None:
                return [SelectTableAction(int(realized_field.value))]
            else:
                return []
        else:
            raise ValueError('unknown primitive field type')


if __name__ == '__main__':

    try:
        from evaluation import evaluate, build_foreign_key_map_from_json
    except Exception:
        print('Cannot find evaluator ...')
    grammar = ASDLGrammar.from_filepath('asdl/sql/grammar/sql_asdl_v2.txt')
    print('Total number of productions:', len(grammar))
    for each in grammar.productions:
        print(each)
    print('Total number of types:', len(grammar.types))
    for each in grammar.types:
        print(each)
    print('Total number of fields:', len(grammar.fields))
    for each in grammar.fields:
        print(each)

    spider_trans = SQLTransitionSystem(grammar)
    kmaps = build_foreign_key_map_from_json('data/tables.json')
    dbs_list = json.load(open('data/tables.json', 'r'))
    dbs = {}
    for each in dbs_list: