def configuration(cls, plm=None, method='lgesql', table_path='data/tables.json', tables='data/tables.bin', db_dir='data/database'):
    """Initialise preprocessing state as attributes on ``cls``.

    Sets the ASDL grammar, SQL transition system, schema tables, evaluator,
    tokenizer, word vocabulary, relation vocabulary and graph factory on
    the class object.

    Args:
        plm: Name of a pretrained language model directory under
            ``./pretrained_models``; its tokenizer is loaded with
            ``AutoTokenizer.from_pretrained``. When ``None``, a
            ``Word2vecUtils`` instance, an identity tokenizer and the
            GloVe 42B vocabulary are used instead.
        method: Method name stored on ``cls`` and forwarded to
            ``GraphFactory``.
        table_path: Path forwarded to ``Evaluator`` together with ``db_dir``.
        tables: Either a filesystem path (``str`` or ``os.PathLike``) to a
            pickled tables object, or the already-loaded object itself.
        db_dir: Database directory forwarded to ``Evaluator``.
    """
    cls.plm, cls.method = plm, method
    cls.grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
    cls.trans = TransitionSystem.get_class_by_lang('sql')(cls.grammar)
    if isinstance(tables, (str, os.PathLike)):
        # Context manager closes the handle promptly instead of relying on GC.
        # NOTE: pickle can execute arbitrary code on load - only point this
        # at trusted, locally generated files.
        with open(tables, 'rb') as f:
            cls.tables = pickle.load(f)
    else:
        cls.tables = tables
    cls.evaluator = Evaluator(cls.trans, table_path, db_dir)
    if plm is None:
        # Static-embedding setup: inputs are passed through unchanged
        # (identity "tokenizer") and looked up in the GloVe vocabulary.
        cls.word2vec = Word2vecUtils()
        cls.tokenizer = lambda x: x
        cls.word_vocab = Vocab(
            padding=True, unk=True, boundary=True, default=UNK,
            filepath='./pretrained_models/glove.42b.300d/vocab.txt',
            specials=SCHEMA_TYPES)  # word vocab for glove.42B.300d
    else:
        cls.tokenizer = AutoTokenizer.from_pretrained(
            os.path.join('./pretrained_models', plm))
        cls.word_vocab = cls.tokenizer.get_vocab()
    cls.relation_vocab = Vocab(padding=False, unk=False, boundary=False, iterable=RELATIONS, default=None)
    cls.graph_factory = GraphFactory(cls.method, cls.relation_vocab)
def process_dataset(processor, dataset, tables, output_path=None, skip_large=False, verbose=False):
    """Run ``process_example`` over every entry of ``dataset``.

    Args:
        processor: Object forwarded unchanged to ``process_example``.
        dataset: Sized iterable of example dicts; each must have a
            ``'db_id'`` key (``len(dataset)`` is used for the summary).
        tables: Mapping from ``db_id`` to a schema dict that contains a
            ``'column_names'`` entry.
        output_path: If not ``None``, the processed list is pickled to
            this path.
        skip_large: If true, entries whose database has more than 100
            columns are dropped.
        verbose: Print a banner per sample and forward verbosity to
            ``process_example``.

    Returns:
        List of processed examples in input order, excluding skipped ones.
    """
    from utils.constants import GRAMMAR_FILEPATH
    grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
    trans = TransitionSystem.get_class_by_lang('sql')(grammar)
    processed_dataset = []
    for idx, entry in enumerate(dataset):
        db = tables[entry['db_id']]
        # Optionally drop examples on very wide schemas (> 100 columns).
        if skip_large and len(db['column_names']) > 100:
            continue
        if verbose:
            print('*************** Processing %d-th sample **************' % (idx))
        entry = process_example(processor, entry, db, trans, verbose=verbose)
        processed_dataset.append(entry)
    print('In total, process %d samples , skip %d extremely large databases.' % (len(processed_dataset), len(dataset) - len(processed_dataset)))
    if output_path is not None:
        # serialize preprocessed dataset; the context manager guarantees the
        # pickle is flushed and the file closed before we return.
        with open(output_path, 'wb') as f:
            pickle.dump(processed_dataset, f)
    return processed_dataset
return action_list elif realized_field.value is not None: return [SelectTableAction(int(realized_field.value))] else: return [] else: raise ValueError('unknown primitive field type') if __name__ == '__main__': try: from evaluation import evaluate, build_foreign_key_map_from_json except Exception: print('Cannot find evaluator ...') grammar = ASDLGrammar.from_filepath('asdl/sql/grammar/sql_asdl_v2.txt') print('Total number of productions:', len(grammar)) for each in grammar.productions: print(each) print('Total number of types:', len(grammar.types)) for each in grammar.types: print(each) print('Total number of fields:', len(grammar.fields)) for each in grammar.fields: print(each) spider_trans = SQLTransitionSystem(grammar) kmaps = build_foreign_key_map_from_json('data/tables.json') dbs_list = json.load(open('data/tables.json', 'r')) dbs = {} for each in dbs_list: