def save_2index():
    """Build and save the index mappings used by the model as JSON files.

    Writes, in the current working directory:
      - grammar2index_grammar_2_sellist.json / _modifyop.json (rule indices)
      - terminals2index_grammar_2.json (terminal-token indices)
      - spider_tab2index.json, spider_col2index.json,
        spider_db_cols2ind.json (Spider table/column indices)
    Reads: spider_tables_lowercase.json.
    """
    grammar = Grammar('sql_simple_transition_2.bnf')
    # A second Grammar instance is required here; per the original author
    # "this shouldn't be necessary but it is" -- get_subgrammar presumably
    # mutates the instance it is called on. TODO confirm.
    grammar2 = Grammar('sql_simple_transition_2.bnf')
    sellist_gram = grammar.get_subgrammar('<sellist>')
    assert '<modifyop>' in grammar.gram_keys
    modifyopgram = grammar2.get_subgrammar('<modifyop>')
    dontuse = ['<start>']
    db_dict = []

    _, gram2ind = get_grammar2index(sellist_gram, db_dict, dontuse=dontuse)
    with open('grammar2index_grammar_2_sellist.json', 'w') as f:
        json.dump(gram2ind, f)

    _, gram2ind = get_grammar2index(modifyopgram, db_dict, dontuse=dontuse)
    with open('grammar2index_grammar_2_modifyop.json', 'w') as f:
        json.dump(gram2ind, f)

    _, w2ind = get_word2index(grammar, db_dict, dontuse=dontuse)
    with open('terminals2index_grammar_2.json', 'w') as f:
        json.dump(w2ind, f)

    with open('spider_tables_lowercase.json', 'r') as f:
        spider_db = json.loads(f.read())

    tab2ind = {}
    col2ind = {}
    allcols_db2ind = {}
    for db, tables in spider_db.items():
        dbp = db.lower()
        tab2ind[dbp] = {tab.lower(): i for i, tab in enumerate(tables)}
        # BUG FIX: the original enumerated a raw set here, so the indices
        # written to spider_db_cols2ind.json depended on string-hash
        # randomisation and changed between runs. Sorting makes the saved
        # mapping deterministic.
        all_cols = sorted({c.lower() for t in tables for c in tables[t]})
        allcols_db2ind[dbp] = {c: i for i, c in enumerate(all_cols)}
        col2ind[dbp] = {}
        for tab in tables:
            col2ind[dbp][tab.lower()] = {
                col.lower(): i for i, col in enumerate(tables[tab])
            }

    with open('spider_tab2index.json', 'w') as f:
        json.dump(tab2ind, f)
    with open('spider_col2index.json', 'w') as f:
        json.dump(col2ind, f)
    with open('spider_db_cols2ind.json', 'w') as f:
        json.dump(allcols_db2ind, f)
def resave_data(gfile='sql_simple_transition_2.bnf'):
    """Filter the pickled train/test dataframes in place and re-save them.

    Rows flagged by ``check_`` (presumably sequences that are too long or
    contain tokens outside the grammar -- verify against ``check_``'s
    definition) are dropped from both frames. The originals are backed up
    to ``<path>.bak`` before the destructive rewrite.

    Parameters
    ----------
    gfile : str
        Path to the BNF grammar file used to size the representation.

    Side effects: reads and overwrites the pickles at the module-level
    paths ``trdf`` (train) and ``tedf`` (test).
    """
    # Indices into the module-level ``dnames`` column-name sequence.
    cols = [0, 2]
    grm = Grammar(gfile)
    assert '<modifyop>' in grm.gr
    grm_terms = grm.terminal_toks
    # grm_terms1 = [g.upper() for g in grm_terms if '[' not in g]
    # grm_terms = [g for g in grm_terms if '[' in g]
    # grm_terms.extend(grm_terms1)
    td = pd.read_pickle(trdf)
    te = pd.read_pickle(tedf)
    # Back up both frames before the destructive filtering below.
    td.to_pickle(trdf + '.bak')
    te.to_pickle(tedf + '.bak')
    for i, c in enumerate(cols):
        if i == 0:
            # NOTE(review): i == 0 is skipped, so 'drop_0' is never created
            # in this loop -- yet the filtering below reads td['drop_0'] /
            # te['drop_0']. This only works if a 'drop_0' column already
            # exists in the pickled frames; otherwise it raises KeyError.
            # Confirm this is intentional.
            continue
        # A fresh Grammar per column: TrainableRepresentation presumably
        # mutates the grammar it is given -- TODO confirm.
        grm1 = Grammar(gfile)
        assert '<modifyop>' in grm1.gr
        max_ = TrainableRepresentation(
            grm1, x_dim=1, start=dnames[c], return_log=False).max_length
        numtr, numte = 0, 0  # NOTE(review): never used afterwards
        # Mark rows to drop for this column.
        td['drop_{}'.format(i)] = td[dnames[c]].apply(
            lambda x: check_(x, max_, grm_terms))
        te['drop_{}'.format(i)] = te[dnames[c]].apply(
            lambda x: check_(x, max_, grm_terms))
    print('Before shapes {}, {}'.format(td.shape, te.shape))
    # Keep only rows flagged by neither drop column, then remove the
    # temporary drop columns before saving.
    td = td[~td['drop_0']]
    td = td[~td['drop_1']]
    te = te[~te['drop_0']]
    te = te[~te['drop_1']]
    del td['drop_0']
    del td['drop_1']
    del te['drop_0']
    del te['drop_1']
    print('After shapes {}, {}'.format(td.shape, te.shape))
    td.to_pickle(trdf)
    te.to_pickle(tedf)
def __init__(self, fname, cuda=True):
    """Parse the grammar file ``fname``, cache its components on the
    instance, and build the learner modules.

    Parameters
    ----------
    fname : str
        Path to the BNF grammar file.
    cuda : bool
        Forwarded to ``create_learners``.
    """
    super().__init__()
    parsed = Grammar(fname)
    self.gr = parsed
    self.grammar = parsed.graph
    self.terminals = parsed.terminals
    self.terminal_toks = parsed.terminal_toks
    self.ors = parsed.ors
    # Position of each OR rule within self.ors, for constant-time lookup.
    self.or_loc = dict((sym, idx) for idx, sym in enumerate(self.ors))
    self.ands = parsed.ands
    self.learners = self.create_learners(cuda=cuda)
#heng = torch.zeros(self.rnn_hid, 1, self.word_dim) #hsql = torch.zeros(self.rnn_hid, 1, self.embed_size) heng = torch.zeros(1, 1, self.rnn_hid) hsql = torch.zeros(1, 1, self.rnn_hid) #rets, _, inds, ccount = self.grammar_propagation(self.possibles, x, heng, hsql) return self.grammar_propagation(self.possibles, x, heng, hsql) # def get_tokens(self, probs): if __name__ == '__main__': gram = Grammar('sql_simple_transition_2.bnf') st = '<modifyop>' rsl = RecursiveDecoder(gram, spider_db, hidden_size=256, word_dim=1024, rnn_hid=256, start=st) x = torch.randn(1024, 1, 11) _, _, _, toks, ccount = rsl.forward(x) print(ccount) print(toks) print(rsl.get_sql(toks)) # if os.path.exists('model/RecNN.pkl'): # model = torch.load('model/RecNN.pkl').to(device)