class Seq2SQL(nn.Module):
    """Seq2SQL-style model with aggregation, SELECT-column and WHERE heads.

    ``forward``, ``loss``, ``check_acc`` and ``gen_query`` all take a
    ``classif_flag`` triple ``(agg_classif, sel_classif, whr_classif)`` that
    selects which of the three heads take part.
    """

    def __init__(self, word_emb, num_words, num_hidden=100, num_layers=2,
                 use_gpu=True):
        super(Seq2SQL, self).__init__()
        self.word_emb = word_emb
        self.num_words = num_words
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        self.use_gpu = use_gpu
        self.max_col_num = 45
        self.max_tok_num = 200
        self.COND_OPS = ['EQL', 'GT', 'LT']
        self.SQL_TOK = ['<UNK>', '<BEG>', '<END>', 'WHERE', 'AND'] + self.COND_OPS

        # Word embedding layer.
        # NOTE(review): SuperModel/SQLNet in this file call
        # WordEmbedding(word_emb, N_word, gpu, SQL_TOK, ...), i.e. with gpu
        # before SQL_TOK; confirm which positional order is correct.
        self.embed_layer = WordEmbedding(word_emb, num_words, self.SQL_TOK,
                                         use_gpu)
        # Aggregation classifier
        self.agg_classifier = AggregationClassifier(num_words, num_hidden,
                                                    num_layers)
        # SELECT column(s)
        self.sel_classifier = SelectClassifier(num_words, num_hidden,
                                               num_layers, self.max_tok_num)
        # WHERE clause
        self.whr_classifier = WhereClassifier(num_words, num_hidden,
                                              num_layers, self.max_col_num,
                                              self.max_tok_num, use_gpu)
        if use_gpu:
            self.cuda()

    def generate_g_s(self, q, col, query):
        """Build target WHERE-clause index sequences, one per example.

        Target format (as token strings)::

            <BEG> WHERE cond1_col cond1_op cond1
                  AND cond2_col cond2_op cond2
                  AND ... <END>

        Every token of ``query`` from ``'WHERE'`` onwards is mapped to its
        first index in the per-example vocabulary
        ``SQL_TOK + column tokens (each column followed by ',') + [None] +
        question tokens + [None]``; unknown tokens map to 0.

        Args:
            q: list of tokenised questions (lists of str).
            col: per example, a list of tokenised column names.
            query: list of tokenised gold queries.

        Returns:
            List of integer index lists.
        """
        ret_seq = []
        for cur_q, cur_col, cur_query in zip(q, col, query):
            connect_col = [tok for col_tok in cur_col for tok in col_tok + [',']]
            all_toks = self.SQL_TOK + connect_col + [None] + cur_q + [None]
            cur_seq = [all_toks.index('<BEG>')]
            if 'WHERE' in cur_query:
                cur_where_query = cur_query[cur_query.index('WHERE'):]
                # A list (not ``map``) so that ``list + ...`` also works on
                # Python 3, where ``map`` returns an iterator.
                cur_seq = cur_seq + [
                    all_toks.index(tok) if tok in all_toks else 0
                    for tok in cur_where_query
                ]
            cur_seq.append(all_toks.index('<END>'))
            ret_seq.append(cur_seq)
        return ret_seq

    def forward(self, q, col, col_num, classif_flag, g_s=None,
                reinforce=False):
        """Score the enabled heads.

        Returns:
            ``(agg_score, sel_score, whr_score)``; a disabled head's entry is
            ``None``.
        """
        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = None, None, None

        x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
        if agg_classif:
            agg_score = self.agg_classifier(x_emb_var, x_len)
        if sel_classif:
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(col)
            sel_score = self.sel_classifier(x_emb_var, x_len, col_inp_var,
                                            col_name_len, col_len, col_num)
        if whr_classif:
            whr_score = self.whr_classifier(x_emb_var, x_len, g_s,
                                            reinforce=reinforce)
        return (agg_score, sel_score, whr_score)

    def _target_var(self, values):
        """Wrap a sequence of integer class ids as a LongTensor Variable.

        CrossEntropyLoss requires int64 targets; ``.long()`` makes that
        explicit instead of relying on the platform default int width.
        """
        var = Variable(torch.from_numpy(np.array(values)).long())
        if self.use_gpu:
            var = var.cuda()
        return var

    def loss(self, score, ref_score, classif_flag, g_s):
        """Sum the cross-entropy losses of the enabled heads.

        Args:
            score: ``(agg_score, sel_score, whr_score)`` from ``forward``.
            ref_score: per example, a sequence whose item 0 is the gold
                aggregation id and item 1 the gold SELECT column.
            classif_flag: which heads are enabled.
            g_s: gold WHERE index sequences (see ``generate_g_s``). Each
                sequence minus its leading ``<BEG>`` is the step target.

        Returns:
            Scalar loss. The WHERE term is averaged over the batch.
        """
        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = score
        loss = 0
        if agg_classif:
            agg_ref_var = self._target_var([x[0] for x in ref_score])
            loss += nn.CrossEntropyLoss()(agg_score, agg_ref_var)
        if sel_classif:
            sel_ref_var = self._target_var([x[1] for x in ref_score])
            loss += nn.CrossEntropyLoss()(sel_score, sel_ref_var)
        if whr_classif:
            g_s_len = len(g_s)
            for s, g_s_i in enumerate(g_s):
                whr_ref_var = self._target_var(g_s_i[1:])
                loss += (nn.CrossEntropyLoss()(
                    whr_score[s, :len(g_s_i) - 1], whr_ref_var) / g_s_len)
        return loss

    def reinforce_backward(self, score, rewards):
        """Back-propagate REINFORCE rewards through the sampled WHERE tokens.

        Once an example samples ``<END>``, its reward is zeroed for every
        later step.

        NOTE(review): relies on the legacy stochastic-Variable
        ``.reinforce()`` API, which later PyTorch releases removed; confirm
        the target PyTorch version.
        """
        agg_score, sel_score, whr_score = score
        cur_reward = rewards[:]
        eof = self.SQL_TOK.index('<END>')
        for whr_score_t in whr_score[1]:
            reward_inp = torch.FloatTensor(cur_reward).unsqueeze(1)
            if self.use_gpu:
                reward_inp = reward_inp.cuda()
            whr_score_t.reinforce(reward_inp)

            for b, _ in enumerate(rewards):
                if whr_score_t[b].data.cpu().numpy()[0] == eof:
                    cur_reward[b] = 0
        torch.autograd.backward(whr_score[1], [None for _ in whr_score[1]])
        return

    def check_acc(self, classif_queries, g_s_queries, classif_flag):
        """Count prediction errors against gold queries.

        Each query is a dict with ``'agg'``, ``'sel'`` and (when the WHERE
        head is enabled) ``'conds'`` entries. Each cond is indexed as
        ``[col, op, value]``.

        Args:
            classif_queries: predicted query dicts.
            g_s_queries: gold query dicts, aligned with the predictions.
            classif_flag: which heads are enabled; disabled heads never
                count as errors.

        WHERE conditions are compared by count, then column set, then the op
        per column, then the case-insensitive value per column. Only the
        first failing check is recorded.

        Returns:
            ``(np.array((agg_err, sel_err, whr_err)), tot_err)`` where
            ``tot_err`` counts queries with an error in any enabled head.
        """
        agg_classif, sel_classif, whr_classif = classif_flag
        text_type = type(u"")  # unicode on Python 2, str on Python 3
        tot_err = agg_err = sel_err = whr_err = 0.0
        for classif_qry, g_s_qry in zip(classif_queries, g_s_queries):
            agg_err_inc = 1 if agg_classif and classif_qry['agg'] != g_s_qry['agg'] else 0
            agg_err += agg_err_inc
            sel_err_inc = 1 if sel_classif and classif_qry['sel'] != g_s_qry['sel'] else 0
            sel_err += sel_err_inc

            # Initialised for every query so that the total below never reads
            # an unbound name when the WHERE head is disabled.
            flag = True
            if whr_classif:
                whr_pred = classif_qry['conds']
                whr_gold = g_s_qry['conds']
                gold_cols = tuple(x[0] for x in whr_gold)
                if len(whr_pred) != len(whr_gold):
                    flag = False
                elif set(x[0] for x in whr_pred) != set(gold_cols):
                    flag = False
                if flag:
                    for cond in whr_pred:
                        gold = whr_gold[gold_cols.index(cond[0])]
                        if gold[1] != cond[1]:
                            flag = False
                            break
                if flag:
                    for cond in whr_pred:
                        gold = whr_gold[gold_cols.index(cond[0])]
                        if text_type(gold[2]).lower() != text_type(cond[2]).lower():
                            flag = False
                            break
                if not flag:
                    whr_err += 1

            if agg_err_inc > 0 or sel_err_inc > 0 or not flag:
                tot_err += 1

        return np.array((agg_err, sel_err, whr_err)), tot_err

    def gen_query(self, score, q, col, raw_q, raw_col, classif_flag,
                  reinforce=False, verbose=False):
        """Decode model scores into query dicts.

        Args:
            score: ``(agg_score, sel_score, whr_score)`` from ``forward``.
                With ``reinforce=True``, ``whr_score[1]`` holds the sampled
                token choices per step.
            q, col: tokenised questions / column names (the WHERE vocabulary).
            raw_q, raw_col: untokenised question strings / column names, used
                to re-join WHERE tokens into surface strings.
            classif_flag: which heads to decode.
            verbose: print the decoded WHERE tokens.

        Returns:
            One dict per example. It has ``'agg'``/``'sel'`` argmax ids and a
            ``'conds'`` list of ``[col_idx, op_idx, value_str]`` for each
            enabled head.
        """

        def merge_tokens(tok_list, raw_tok_str):
            # Re-join tokens, inserting a space only where the raw text (or
            # generic punctuation heuristics) suggests one.
            tok_str = raw_tok_str.lower()
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_pair_track = 0
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_pair_track = 1 - double_quote_pair_track
                    if double_quote_pair_track:
                        ret = ret + ' '
                if len(ret) == 0:
                    pass
                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                    ret = ret + ' '
                elif len(ret) > 0 and ret + tok in tok_str:
                    pass
                elif (tok[0] not in string.ascii_lowercase) and (
                        tok[0] not in string.digits) and (tok[0] not in '$('):
                    pass
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) and \
                        (ret[-1] != '"' or not double_quote_pair_track):
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()

        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = score
        ret_queries = []
        # The batch size is taken from the first enabled head's scores.
        if agg_classif:
            batch_len = len(agg_score)
        elif sel_classif:
            batch_len = len(sel_score)
        elif reinforce:
            batch_len = len(whr_score[0])
        else:
            batch_len = len(whr_score)

        for b in range(batch_len):
            cur_query = {}
            if agg_classif:
                cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
            if sel_classif:
                cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
            if whr_classif:
                cur_query['conds'] = []
                all_toks = self.SQL_TOK + [
                    x for toks in col[b] for x in toks + [',']
                ] + [''] + q[b] + ['']
                whr_toks = []
                if reinforce:
                    for choices in whr_score[1]:
                        choice = choices[b].data.cpu().numpy()[0]
                        if choice < len(all_toks):
                            whr_val = all_toks[choice]
                        else:
                            whr_val = '<UNK>'
                        if whr_val == '<END>':
                            break
                        whr_toks.append(whr_val)
                else:
                    for where_score in whr_score[b].data.cpu().numpy():
                        whr_val = all_toks[np.argmax(where_score)]
                        if whr_val == '<END>':
                            break
                        whr_toks.append(whr_val)
                if verbose:
                    print(whr_toks)
                # Drop the leading WHERE token.
                if len(whr_toks) > 0:
                    whr_toks = whr_toks[1:]
                # Split on AND; in each segment the first EQL/GT/LT separates
                # the column tokens from the value tokens.
                st = 0
                while st < len(whr_toks):
                    cur_cond = [None, None, None]
                    ed = len(whr_toks) if 'AND' not in whr_toks[st:] \
                        else whr_toks[st:].index('AND') + st
                    if 'EQL' in whr_toks[st:ed]:
                        op = whr_toks[st:ed].index('EQL') + st
                        cur_cond[1] = 0
                    elif 'GT' in whr_toks[st:ed]:
                        op = whr_toks[st:ed].index('GT') + st
                        cur_cond[1] = 1
                    elif 'LT' in whr_toks[st:ed]:
                        op = whr_toks[st:ed].index('LT') + st
                        cur_cond[1] = 2
                    else:
                        op = st
                        cur_cond[1] = 0
                    sel_col = whr_toks[st:op]
                    to_idx = [x.lower() for x in raw_col[b]]
                    classif_col = merge_tokens(
                        sel_col, raw_q[b] + ' || ' + ' || '.join(raw_col[b]))
                    if classif_col in to_idx:
                        cur_cond[0] = to_idx.index(classif_col)
                    else:
                        cur_cond[0] = 0
                    cur_cond[2] = merge_tokens(whr_toks[op + 1:ed], raw_q[b])
                    cur_query['conds'].append(cur_cond)
                    st = ed + 1
            ret_queries.append(cur_query)

        return ret_queries
class SuperModel(nn.Module):
    """Grammar-driven SQL decoder assembled from per-decision predictors.

    ``full_forward`` expands a stack of grammar states, consulting one
    predictor module per decision, and builds a nested dict describing the
    query. The ``gen_*`` helpers render such a dict as a SQL string.
    """

    # Types accepted where an aggregator / connector string is expected
    # (str and unicode on Python 2, str on Python 3). This replaces the
    # Python-2-only ``basestring``.
    _STRING_TYPES = (type(u""), type(""))

    def __init__(self, word_emb, N_word, N_h=300, N_depth=2, gpu=True,
                 trainable_emb=False, table_type="std", use_hs=True):
        super(SuperModel, self).__init__()
        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth
        self.trainable_emb = trainable_emb
        self.table_type = table_type
        self.use_hs = use_hs
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>'
        ]

        # word embedding layer
        self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK,
                                         trainable=trainable_emb)

        # All sub-predictors share the same hyper-parameters and are switched
        # to eval mode right after construction.
        self.multi_sql = MultiSqlPredictor(N_word=N_word, N_h=N_h,
                                           N_depth=N_depth, gpu=gpu,
                                           use_hs=use_hs)
        self.multi_sql.eval()
        self.key_word = KeyWordPredictor(N_word=N_word, N_h=N_h,
                                         N_depth=N_depth, gpu=gpu,
                                         use_hs=use_hs)
        self.key_word.eval()
        self.col = ColPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth,
                                gpu=gpu, use_hs=use_hs)
        self.col.eval()
        self.op = OpPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth,
                              gpu=gpu, use_hs=use_hs)
        self.op.eval()
        self.agg = AggPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth,
                                gpu=gpu, use_hs=use_hs)
        self.agg.eval()
        self.root_teminal = RootTeminalPredictor(N_word=N_word, N_h=N_h,
                                                 N_depth=N_depth, gpu=gpu,
                                                 use_hs=use_hs)
        self.root_teminal.eval()
        self.des_asc = DesAscLimitPredictor(N_word=N_word, N_h=N_h,
                                            N_depth=N_depth, gpu=gpu,
                                            use_hs=use_hs)
        self.des_asc.eval()
        self.having = HavingPredictor(N_word=N_word, N_h=N_h,
                                      N_depth=N_depth, gpu=gpu,
                                      use_hs=use_hs)
        self.having.eval()
        self.andor = AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth,
                                    gpu=gpu, use_hs=use_hs)
        self.andor.eval()

        self.softmax = nn.Softmax()  # dim=1
        self.CE = nn.CrossEntropyLoss()
        self.log_softmax = nn.LogSoftmax()
        self.mlsml = nn.MultiLabelSoftMarginLoss()
        self.bce_logit = nn.BCEWithLogitsLoss()
        self.sigm = nn.Sigmoid()
        if gpu:
            self.cuda()
        # Number of times find_shortest_path failed to connect two tables.
        self.path_not_found = 0

    def forward(self, q_seq, history, tables):
        """Delegate to ``full_forward``."""
        return self.full_forward(q_seq, history, tables)

    def full_forward(self, q_seq, history, tables):
        """Decode the first question of the batch into a nested SQL dict.

        Each loop iteration pops one grammar state from a stack. The state
        queries one predictor (multi-SQL, keyword, column, operator,
        aggregator, root/terminal, order direction, having, and/or) using
        batch item 0's scores. The decision is appended to ``history[0]``
        and to the SQL dict currently being filled.

        Args:
            q_seq: batch of tokenised questions.
            history: ignored; it is rebuilt from scratch below.
            tables: table schemas, passed to ``to_batch_tables`` and
                ``index_to_column_name``.

        Returns:
            The outermost SQL dict, or ``None`` if decoding did not finish
            within 2 seconds (a guard against unbounded expansion).
        """
        B = len(q_seq)
        q_emb_var, q_len = self.embed_layer.gen_x_q_batch(q_seq)
        col_seq = to_batch_tables(tables, B, self.table_type)
        col_emb_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
            col_seq)

        mkw_emb_var = self.embed_layer.gen_word_list_embedding(
            ["none", "except", "intersect", "union"], (B))
        mkw_len = np.full(q_len.shape, 4, dtype=np.int64)
        kw_emb_var = self.embed_layer.gen_word_list_embedding(
            ["where", "group by", "order by"], (B))
        kw_len = np.full(q_len.shape, 3, dtype=np.int64)

        stack = Stack()
        stack.push(("root", None))
        # NOTE: list multiplication makes all B rows alias one list; only
        # history[0] is ever appended to.
        history = [["root"]] * B
        andor_cond = ""
        has_limit = False
        current_sql = {}
        # Parallel stacks restoring (current_sql, kw) when the grammar stack
        # shrinks below the size recorded at the matching "root" state.
        sql_stack = []
        idx_stack = []
        kw_stack = []
        kw = ""
        nested_label = ""
        timeout = time.time() + 2  # prevent infinite expansion
        failed = False
        while not stack.isEmpty():
            if time.time() > timeout:
                failed = True
                break
            vet = stack.pop()
            hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history)
            if len(idx_stack) > 0 and stack.size() < idx_stack[-1]:
                idx_stack.pop()
                current_sql = sql_stack.pop()
                kw = kw_stack.pop()

            if isinstance(vet, tuple) and vet[0] == "root":
                if history[0][-1] != "root":
                    history[0].append("root")
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                        history)
                if vet[1] != "original":
                    idx_stack.append(stack.size())
                    sql_stack.append(current_sql)
                    kw_stack.append(kw)
                else:
                    idx_stack.append(stack.size())
                    sql_stack.append(sql_stack[-1])
                    kw_stack.append(kw)
                if "sql" in current_sql:
                    current_sql["nested_sql"] = {}
                    current_sql["nested_label"] = nested_label
                    current_sql = current_sql["nested_sql"]
                elif isinstance(vet[1], dict):
                    vet[1]["sql"] = {}
                    current_sql = vet[1]["sql"]
                elif vet[1] != "original":
                    current_sql["sql"] = {}
                    current_sql = current_sql["sql"]
                if vet[1] == "nested" or vet[1] == "original":
                    stack.push("none")
                    history[0].append("none")
                else:
                    score = self.multi_sql.forward(q_emb_var, q_len,
                                                   hs_emb_var, hs_len,
                                                   mkw_emb_var, mkw_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    label = SQL_OPS[label]
                    history[0].append(label)
                    stack.push(label)
                    # Only a fresh multi-SQL decision may set the set-op
                    # label; a stale ``label`` from another predictor must
                    # not leak into nested_label.
                    if label != "none":
                        nested_label = label
            elif vet in ('intersect', 'except', 'union'):
                stack.push(("root", "nested"))
                stack.push(("root", "original"))
            elif vet == "none":
                score = self.key_word.forward(q_emb_var, q_len, hs_emb_var,
                                              hs_len, kw_emb_var, kw_len)
                kw_num_score, kw_score = [x.data.cpu().numpy() for x in score]
                num_kw = np.argmax(kw_num_score[0])
                kw_score = list(np.argsort(-kw_score[0])[:num_kw])
                kw_score.sort(reverse=True)
                # Distinct loop name: must not clobber ``kw`` (current clause).
                for kw_idx in kw_score:
                    stack.push(KW_OPS[kw_idx])
                stack.push("select")
            elif vet in ("select", "orderBy", "where", "groupBy", "having"):
                kw = vet
                current_sql[kw] = []
                history[0].append(vet)
                stack.push(("col", vet))
            elif isinstance(vet, tuple) and vet[0] == "col":
                score = self.col.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                         col_emb_var, col_len, col_name_len)
                col_num_score, col_score = [
                    x.data.cpu().numpy() for x in score
                ]
                col_num = np.argmax(col_num_score[0]) + 1  # double check
                cols = np.argsort(-col_score[0])[:col_num]
                for col in cols:
                    if vet[1] == "where":
                        stack.push(("op", "where", col))
                    elif vet[1] != "groupBy":
                        stack.push(("agg", vet[1], col))
                    elif vet[1] == "groupBy":
                        history[0].append(index_to_column_name(col, tables))
                        current_sql[kw].append(
                            index_to_column_name(col, tables))
                # Predict and/or when there are several WHERE columns; the
                # connector becomes the first element of the where list.
                if col_num > 1 and vet[1] == "where":
                    score = self.andor.forward(q_emb_var, q_len, hs_emb_var,
                                               hs_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    andor_cond = COND_OPS[label]
                    current_sql[kw].append(andor_cond)
                if vet[1] == "groupBy" and col_num > 0:
                    score = self.having.forward(
                        q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                        col_len, col_name_len,
                        np.full(B, cols[0], dtype=np.int64))
                    label = np.argmax(score[0].data.cpu().numpy())
                    if label == 1:
                        stack.push("having")
            elif isinstance(vet, tuple) and vet[0] == "agg":
                history[0].append(index_to_column_name(vet[2], tables))
                if vet[1] not in ("having", "orderBy"):  # DEBUG-ed 20180817
                    try:
                        current_sql[kw].append(
                            index_to_column_name(vet[2], tables))
                    except Exception:
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(
                            history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                    history)

                score = self.agg.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                         col_emb_var, col_len, col_name_len,
                                         np.full(B, vet[2], dtype=np.int64))
                agg_num_score, agg_score = [
                    x.data.cpu().numpy() for x in score
                ]
                agg_num = np.argmax(agg_num_score[0])  # double check
                agg_idxs = np.argsort(-agg_score[0])[:agg_num]
                if len(agg_idxs) > 0:
                    history[0].append(AGG_OPS[agg_idxs[0]])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(AGG_OPS[agg_idxs[0]])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2],
                                    AGG_OPS[agg_idxs[0]]))  # DEBUG-ed 20180817
                    else:
                        stack.push(("op", "having", vet[2],
                                    AGG_OPS[agg_idxs[0]]))
                for agg in agg_idxs[1:]:
                    history[0].append(index_to_column_name(vet[2], tables))
                    history[0].append(AGG_OPS[agg])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(
                            index_to_column_name(vet[2], tables))
                        current_sql[kw].append(AGG_OPS[agg])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], AGG_OPS[agg]))
                    else:
                        # Push the aggregator name, as the first-aggregator
                        # branch above does (previously the whole index array
                        # was pushed and ended up in the having list).
                        stack.push(("op", "having", vet[2], AGG_OPS[agg]))
                if len(agg_idxs) == 0:
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append("none_agg")
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], "none_agg"))
                    else:
                        stack.push(("op", "having", vet[2], "none_agg"))
            elif isinstance(vet, tuple) and vet[0] == "op":
                if vet[1] == "where":
                    history[0].append(index_to_column_name(vet[2], tables))
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                        history)
                score = self.op.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                        col_emb_var, col_len, col_name_len,
                                        np.full(B, vet[2], dtype=np.int64))

                op_num_score, op_score = [x.data.cpu().numpy() for x in score]
                # num_score 0 maps to 1 in truth: at least one op.
                op_num = np.argmax(op_num_score[0]) + 1
                ops = np.argsort(-op_score[0])[:op_num]
                if op_num > 0:
                    history[0].append(NEW_WHERE_OPS[ops[0]])
                    if vet[1] == "having":
                        stack.push(("root_teminal", vet[2], vet[3], ops[0]))
                    else:
                        stack.push(("root_teminal", vet[2], ops[0]))
                    for op in ops[1:]:
                        history[0].append(index_to_column_name(vet[2], tables))
                        history[0].append(NEW_WHERE_OPS[op])
                        if vet[1] == "having":
                            stack.push(("root_teminal", vet[2], vet[3], op))
                        else:
                            stack.push(("root_teminal", vet[2], op))
            elif isinstance(vet, tuple) and vet[0] == "root_teminal":
                score = self.root_teminal.forward(
                    q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                    col_len, col_name_len,
                    np.full(B, vet[1], dtype=np.int64))

                label = np.argmax(score[0].data.cpu().numpy())
                label = ROOT_TERM_OPS[label]
                if len(vet) == 4:
                    # having: (col, agg, op)
                    current_sql[kw].append(index_to_column_name(
                        vet[1], tables))
                    current_sql[kw].append(vet[2])
                    current_sql[kw].append(NEW_WHERE_OPS[vet[3]])
                else:
                    # where: (col, op)
                    try:
                        current_sql[kw].append(
                            index_to_column_name(vet[1], tables))
                    except Exception:
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(
                            history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                    current_sql[kw].append(NEW_WHERE_OPS[vet[2]])
                if label == "root":
                    # The value is a sub-query: decode it into a fresh dict.
                    history[0].append("root")
                    current_sql[kw].append({})
                    stack.push(("root", current_sql[kw][-1]))
                else:
                    current_sql[kw].append("terminal")
            elif isinstance(vet, tuple) and vet[0] == "des_asc":
                current_sql[kw].append(index_to_column_name(vet[1], tables))
                current_sql[kw].append(vet[2])
                score = self.des_asc.forward(
                    q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                    col_len, col_name_len,
                    np.full(B, vet[1], dtype=np.int64))
                label = np.argmax(score[0].data.cpu().numpy())
                dec_asc, has_limit = DEC_ASC_OPS[label]
                history[0].append(dec_asc)
                current_sql[kw].append(dec_asc)
                current_sql[kw].append(has_limit)
        if failed:
            return None
        print("history:{}".format(history[0]))
        if len(sql_stack) > 0:
            current_sql = sql_stack[0]
        return current_sql

    def gen_col(self, col, table, table_alias_dict):
        """Render column entry ``col`` (column index at ``col[2]``).

        The name is prefixed with ``T<alias>.`` when its table has an alias.
        """
        table_idx, colname = table["column_names_original"][col[2]][:2]
        if table_idx not in table_alias_dict:
            return colname
        return "T{}.{}".format(table_alias_dict[table_idx], colname)

    def gen_group_by(self, sql, kw, table, table_alias_dict):
        """Render ``"<kw> c1,c2,..."`` from a flat list of column entries."""
        ret = [self.gen_col(item, table, table_alias_dict) for item in sql]
        return "{} {}".format(kw, ",".join(ret))

    def gen_select(self, sql, kw, table, table_alias_dict):
        """Render ``"<kw> ..."`` from ``[col, agg, col, agg, ...]``.

        The aggregator is omitted when it is ``"none_agg"`` or not a string.
        """
        ret = []
        for i in range(0, len(sql), 2):
            col = self.gen_col(sql[i], table, table_alias_dict)
            agg = sql[i + 1]
            if agg == "none_agg" or not isinstance(agg, self._STRING_TYPES):  # DEBUG-ed 20180817
                ret.append(col)
            else:
                ret.append("{}({})".format(agg, col))
        return "{} {}".format(kw, ",".join(ret))

    def gen_where(self, sql, table, table_alias_dict):
        """Render a WHERE clause from ``[andor?, col, op, val, ...]``.

        An optional leading string is the connector (default ``"and"``).
        ``val`` is either ``"terminal"`` (rendered as a quoted placeholder)
        or a sub-query dict rendered recursively.
        """
        if len(sql) == 0:
            return ""
        start_idx = 0
        andor = "and"
        if isinstance(sql[0], self._STRING_TYPES):
            start_idx += 1
            andor = sql[0]
        ret = []
        for i in range(start_idx, len(sql), 3):
            col = self.gen_col(sql[i], table, table_alias_dict)
            op = sql[i + 1]
            val = sql[i + 2]
            if val == "terminal":
                where_item = "{} {} '{}'".format(col, op, val)
            else:
                where_item = "{} {} ({})".format(col, op,
                                                 self.gen_sql(val, table))
            if op == "between":
                # TODO: temporarily fixed
                where_item += " and 'terminal'"
            ret.append(where_item)
        return "where {}".format(" {} ".format(andor).join(ret))

    def gen_orderby(self, sql, table, table_alias_dict):
        """Render ORDER BY from ``[col, agg, dir, has_limit, ...]``.

        A final ``has_limit`` equal to True appends ``limit 1``. Returns ""
        for an empty list so that the caller skips the clause.
        """
        if not sql:
            return ""
        limit = "limit 1" if sql[-1] == True else ""  # noqa: E712
        ret = []
        for i in range(0, len(sql), 4):
            col = self.gen_col(sql[i], table, table_alias_dict)
            agg = sql[i + 1]
            if agg == "none_agg" or not isinstance(agg, self._STRING_TYPES):  # DEBUG-ed 20180817
                ret.append("{} {}".format(col, sql[i + 2]))
            else:
                ret.append("{}({}) {}".format(agg, col, sql[i + 2]))
        return "order by {} {}".format(",".join(ret), limit)

    def gen_having(self, sql, table, table_alias_dict):
        """Render HAVING from ``[col, agg, op, val, ...]``.

        Multiple conditions are joined with ``and``; a comma-separated HAVING
        list is not valid SQL.
        """
        ret = []
        for i in range(0, len(sql), 4):
            col = self.gen_col(sql[i], table, table_alias_dict)
            if sql[i + 1] != "none_agg":
                col = "{}({})".format(sql[i + 1], col)
            op = sql[i + 2]
            val = sql[i + 3]
            if val == "terminal":
                ret.append("{} {} '{}'".format(col, op, val))
            else:
                ret.append("{} {} ({})".format(col, op,
                                               self.gen_sql(val, table)))
        return "having {}".format(" and ".join(ret))

    def find_shortest_path(self, start, end, graph):
        """Breadth-first search for the shortest join path.

        Args:
            start, end: table indices.
            graph: mapping ``table -> [(neighbour, (own_col, neighbour_col))]``.

        Returns:
            ``[(table, (from_col, to_col)), ...]`` excluding ``start`` (``[]``
            when ``start == end``), or ``None`` when ``end`` is unreachable.
            In that case the miss is printed and ``self.path_not_found`` is
            incremented.
        """
        from collections import deque

        queue = deque([(start, [])])
        visited = {start}
        while queue:
            ele, path = queue.popleft()
            if ele == end:
                return path
            for node, cols in graph[ele]:
                if node not in visited:
                    visited.add(node)
                    queue.append((node, path + [(node, cols)]))
        print("table {} table {}".format(start, end))
        self.path_not_found += 1
        return None

    def gen_from(self, candidate_tables, table):
        """Build the FROM clause joining ``candidate_tables``.

        Tables are connected along shortest foreign-key paths from the first
        candidate. A table with no path is joined without an ``ON`` clause.

        Returns:
            ``(table_alias_dict, from_clause)``. The dict maps a table index
            to alias number ``n`` (rendered ``Tn``). It is empty when at most
            one table is involved, in which case no aliases are used.
        """
        if len(candidate_tables) <= 1:
            if len(candidate_tables) == 1:
                only_table = list(candidate_tables)[0]
            else:
                only_table = 0  # TODO: temporary fallback setting
            return {}, "from {}".format(
                table["table_names_original"][only_table])

        # Undirected FK graph: table -> [(neighbour, (own_col, neighbour_col))]
        graph = defaultdict(list)
        for acol, bcol in table["foreign_keys"]:
            t1 = table["column_names"][acol][0]
            t2 = table["column_names"][bcol][0]
            graph[t1].append((t2, (acol, bcol)))
            graph[t2].append((t1, (bcol, acol)))

        candidate_tables = list(candidate_tables)
        start = candidate_tables[0]
        table_alias_dict = {start: 1}
        idx = 2
        ret = "from {} as T1".format(table["table_names_original"][start])
        try:
            for end in candidate_tables[1:]:
                if end in table_alias_dict:
                    continue
                path = self.find_shortest_path(start, end, graph)
                prev_table = start
                if not path:
                    table_alias_dict[end] = idx
                    idx += 1
                    ret = "{} join {} as T{}".format(
                        ret,
                        table["table_names_original"][end],
                        table_alias_dict[end],
                    )
                    continue
                for node, (acol, bcol) in path:
                    if node in table_alias_dict:
                        prev_table = node
                        continue
                    table_alias_dict[node] = idx
                    idx += 1
                    ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(
                        ret, table["table_names_original"][node],
                        table_alias_dict[node], table_alias_dict[prev_table],
                        table["column_names_original"][acol][1],
                        table_alias_dict[node],
                        table["column_names_original"][bcol][1])
                    prev_table = node
        except Exception:
            # Best effort: report the schema and keep the partial clause.
            traceback.print_exc()
            print("db:{}".format(table["db_id"]))
        return table_alias_dict, ret

    def gen_sql(self, sql, table):
        """Render a decoded SQL dict (as built by ``full_forward``) to text.

        Clauses are emitted in the order SQL requires: SELECT, FROM, WHERE,
        GROUP BY, HAVING, ORDER BY. Any intersect/except/union sub-query
        follows them.
        """
        parent_sql = sql
        nested_sql = {}
        nested_label = ""
        if "nested_label" in sql:
            nested_label = sql["nested_label"]
            nested_sql = sql["nested_sql"]
            sql = sql["sql"]
        elif "sql" in sql:
            sql = sql["sql"]

        # Tables referenced by any column entry at this query level.
        candidate_tables = set()
        for key in sql:
            if key not in KW_WITH_COL:
                continue
            for item in sql[key]:
                if isinstance(item, tuple) and len(item) == 3:
                    table_idx = table["column_names"][item[2]][0]
                    if table_idx != -1:
                        candidate_tables.add(table_idx)
        table_alias_dict, from_clause = self.gen_from(candidate_tables, table)

        ret = []
        if "select" in sql:
            select_clause = self.gen_select(sql["select"], "select", table,
                                            table_alias_dict)
            if len(select_clause) > 0:
                ret.append(select_clause)
            else:
                print("select not found:{}".format(parent_sql))
        else:
            print("select not found:{}".format(parent_sql))
        if len(from_clause) > 0:
            ret.append(from_clause)
        if "where" in sql:
            where_clause = self.gen_where(sql["where"], table,
                                          table_alias_dict)
            if len(where_clause) > 0:
                ret.append(where_clause)
        if "groupBy" in sql:
            groupby_clause = self.gen_group_by(sql["groupBy"], "group by",
                                               table, table_alias_dict)
            if len(groupby_clause) > 0:
                ret.append(groupby_clause)
        # HAVING must come after GROUP BY and before ORDER BY.
        if "having" in sql:
            having_clause = self.gen_having(sql["having"], table,
                                            table_alias_dict)
            if len(having_clause) > 0:
                ret.append(having_clause)
        if "orderBy" in sql:
            orderby_clause = self.gen_orderby(sql["orderBy"], table,
                                              table_alias_dict)
            if len(orderby_clause) > 0:
                ret.append(orderby_clause)
        if len(nested_label) > 0:
            ret.append("{} {}".format(nested_label,
                                      self.gen_sql(nested_sql, table)))
        return " ".join(ret)

    def check_acc(self, pred_sql, gt_sql):
        """Placeholder: accuracy checking is not implemented for this model."""
        pass
class SQLNet(nn.Module): def __init__(self, word_emb, N_word, N_h=100, N_depth=2, gpu=False, use_ca=True, trainable_emb=False): super(SQLNet, self).__init__() self.use_ca = use_ca self.trainable_emb = trainable_emb self.gpu = gpu self.N_h = N_h self.N_depth = N_depth self.max_col_num = 45 self.max_tok_num = 200 self.SQL_TOK = [ '<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=', '<BEG>' ] self.COND_OPS = ['>', '<', '==', '!='] # 词向量,可选择自己训练或者使用训练好的词向量,这里选用加载好的词向量 self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK, our_model=True, trainable=trainable_emb) # 预测列数目 self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca) # 预测那个列被选中了 self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num, use_ca=use_ca) # 预测相应选定列的聚合函数 self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca) # 预测条件数、条件列、条件操作和条件值 self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth, self.max_col_num, self.max_tok_num, use_ca, gpu) # 预测条件关系,如“and”、“or” self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth, use_ca=use_ca) self.CE = nn.CrossEntropyLoss() #交叉熵损失函数 self.softmax = nn.Softmax(dim=-1) self.log_softmax = nn.LogSoftmax() self.bce_logit = nn.BCEWithLogitsLoss() if gpu: self.cuda() # q:问题,gt_cond_seq:三元组 目的:要选择那一列 def generate_gt_where_seq_test(self, q, gt_cond_seq): ret_seq = [] for cur_q, ans in zip(q, gt_cond_seq): temp_q = u"".join(cur_q) cur_q = [u'<BEG>'] + cur_q + [u'<END>'] # 在每个问题前加<BEG>和结尾加<END> record = [] #如果条件值在问题中,标记(TRUE,条件值) record_cond = [] for cond in ans: if cond[2] not in temp_q: record.append((False, cond[2])) else: record.append((True, cond[2])) for idx, item in enumerate(record): temp_ret_seq = [] if item[0]: temp_ret_seq.append(0) temp_ret_seq.extend( list( range( temp_q.index(item[1]) + 1, temp_q.index(item[1]) + len(item[1]) + 1))) #获取条件值的索引 temp_ret_seq.append(len(cur_q) - 1) else: temp_ret_seq.append([0, len(cur_q) - 1]) record_cond.append(temp_ret_seq) ret_seq.append(record_cond) return 
ret_seq #q:问题,col:表头名字,col_num:有几个表头列,gt_where:conds中条件值不出现在问题中,gt_conds:conds,gt_sel:选择那列,gt_sel_num:选择几列 def forward(self, q, col, col_num, gt_where=None, gt_cond=None, reinforce=False, gt_sel=None, gt_sel_num=None): B = len(q) #batch_size的大小 sel_num_score = None agg_score = None sel_score = None cond_score = None #预测聚合函数 if self.trainable_emb: x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col) col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch( col) max_x_len = max(x_len) agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_sel=gt_sel) x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col) col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch( col) max_x_len = max(x_len) sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num) x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col) col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch( col) max_x_len = max(x_len) cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce) where_rela_score = None else: x_emb_var, x_len = self.embed_layer.gen_x_batch( q, col ) #x_len:batch中每个问题的长度,[x_emb_var:batch_size,max_seq_len,word_embedding_size] col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch( col) #列名向量化,长度,几个列 sel_num_score = self.sel_num( x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num) #[16,4]对问题的编码经过lstm,linear,softmax之后乘以编码 # x_emb_var: embedding of each question # x_len: length of each question # col_inp_var: embedding of each header # col_name_len: length of each header # col_len: number of headers in each table, array type # col_num: number of headers in each table, list type if gt_sel_num: pr_sel_num = gt_sel_num else: pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1) sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, 
col_num) #【16,19】 if gt_sel: pr_sel = gt_sel else: num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1) sel = sel_score.data.cpu().numpy() pr_sel = [ list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num)) ] agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_sel=pr_sel, gt_sel_num=pr_sel_num) #【16,4,6】 where_rela_score = self.where_rela_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num) #【16,3】 cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce) #4=>[16,5] return (sel_num_score, sel_score, agg_score, cond_score, where_rela_score) def loss(self, score, truth_num, gt_where): sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score B = len(truth_num) loss = 0 # Evaluate select number sel_num_truth = list(map(lambda x: x[0], truth_num)) #聚合函数个数 sel_num_truth = torch.from_numpy( np.array(sel_num_truth)).long() #.astype(float)) if self.gpu: sel_num_truth = Variable(sel_num_truth.cuda()) else: sel_num_truth = Variable(sel_num_truth) #选择几个列的损失 loss += self.CE(sel_num_score, sel_num_truth) # Evaluate select column选择哪个列的损失 T = len(sel_score[0]) truth_prob = np.zeros((B, T), dtype=np.float32) for b in range(B): truth_prob[b][list(truth_num[b][1])] = 1 data = torch.from_numpy(truth_prob) if self.gpu: sel_col_truth_var = Variable(data.cuda()) else: sel_col_truth_var = Variable(data) sigm = nn.Sigmoid() sel_col_prob = sigm(sel_score) bce_loss = -torch.mean( 3 * (sel_col_truth_var * torch.log(sel_col_prob + 1e-10)) + (1 - sel_col_truth_var) * torch.log(1 - sel_col_prob + 1e-10) ) #这儿采用bceloss:-w*[y*log(x)+(1-y)*log(1-x)] loss += bce_loss # Evaluate select aggregation选择聚合函数的损失交叉熵 for b in range(len(truth_num)): data = torch.from_numpy(np.array(truth_num[b][2])) #真实的聚合函数 if self.gpu: sel_agg_truth_var = Variable(data.cuda()) else: sel_agg_truth_var = Variable(data.long()) sel_agg_pred = agg_score[b, :len(truth_num[b][1])] 
#聚合函数共六种 loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num) cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score # Evaluate the number of conditions预测多少个conds的损失交叉熵 cond_num_truth = list(map(lambda x: x[3], truth_num)) data = torch.from_numpy(np.array(cond_num_truth).astype(float)).long() if self.gpu: try: cond_num_truth_var = Variable(data.cuda()) except: print("cond_num_truth_var error") print(data) exit(0) else: cond_num_truth_var = Variable(data) loss += self.CE(cond_num_score, cond_num_truth_var) # Evaluate the columns of conditions评估条件列 T = len(cond_col_score[0]) truth_prob = np.zeros((B, T), dtype=np.float32) for b in range(B): if len(truth_num[b][4]) > 0: truth_prob[b][list(truth_num[b][4])] = 1 #条件列 data = torch.from_numpy(truth_prob) if self.gpu: cond_col_truth_var = Variable(data.cuda()) else: cond_col_truth_var = Variable(data) sigm = nn.Sigmoid() cond_col_prob = sigm(cond_col_score) bce_loss = -torch.mean( 3 * (cond_col_truth_var * torch.log(cond_col_prob + 1e-10)) + (1 - cond_col_truth_var) * torch.log(1 - cond_col_prob + 1e-10)) loss += bce_loss # Evaluate the operator of conditions评估操作条件 for b in range(len(truth_num)): if len(truth_num[b][5]) == 0: #条件类型 continue data = torch.from_numpy(np.array(truth_num[b][5])).long() if self.gpu: cond_op_truth_var = Variable(data.cuda()) else: cond_op_truth_var = Variable(data) cond_op_pred = cond_op_score[b, :len(truth_num[b][5])] # try: loss += (self.CE(cond_op_pred, cond_op_truth_var) / len(truth_num)) # except: # print(cond_op_pred) # print(cond_op_truth_var) # exit(0) #Evaluate the strings of conditions评估条件串 for b in range(len(gt_where)): for idx in range(len(gt_where[b])): cond_str_truth = gt_where[b][idx] if len(cond_str_truth) == 1: continue data = torch.from_numpy(np.array(cond_str_truth[1:])).long() if self.gpu: cond_str_truth_var = Variable(data.cuda()) else: cond_str_truth_var = Variable(data) str_end = len(cond_str_truth) - 1 cond_str_pred = cond_str_score[b, 
idx, :str_end] loss += (self.CE(cond_str_pred, cond_str_truth_var) \ / (len(gt_where) * len(gt_where[b]))) # Evaluate condition relationship, and / or评估条件关系 where_rela_truth = list(map(lambda x: x[6], truth_num)) data = torch.from_numpy(np.array(where_rela_truth)).long() if self.gpu: try: where_rela_truth = Variable(data.cuda()) except: print("where_rela_truth error") print(data) exit(0) else: where_rela_truth = Variable(data) loss += self.CE(where_rela_score, where_rela_truth) return loss def check_acc(self, vis_info, pred_queries, gt_queries): def gen_cond_str(conds, header): if len(conds) == 0: return 'None' cond_str = [] for cond in conds: cond_str.append(header[cond[0]] + ' ' + self.COND_OPS[cond[1]] + ' ' + unicode(cond[2]).lower()) return 'WHERE ' + ' AND '.join(cond_str) tot_err = sel_num_err = agg_err = sel_err = 0.0 cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0 for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)): good = True sel_pred, agg_pred, where_rela_pred = pred_qry['sel'], pred_qry[ 'agg'], pred_qry['cond_conn_op'] sel_gt, agg_gt, where_rela_gt = gt_qry['sel'], gt_qry[ 'agg'], gt_qry['cond_conn_op'] if where_rela_gt != where_rela_pred: good = False cond_rela_err += 1 if len(sel_pred) != len(sel_gt): good = False sel_num_err += 1 pred_sel_dict = { k: v for k, v in zip(list(sel_pred), list(agg_pred)) } gt_sel_dict = {k: v for k, v in zip(sel_gt, agg_gt)} if set(sel_pred) != set(sel_gt): good = False sel_err += 1 agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict.keys())] agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict.keys())] if agg_pred != agg_gt: good = False agg_err += 1 cond_pred = pred_qry['conds'] cond_gt = gt_qry['conds'] if len(cond_pred) != len(cond_gt): good = False cond_num_err += 1 else: cond_op_pred, cond_op_gt = {}, {} cond_val_pred, cond_val_gt = {}, {} for p, g in zip(cond_pred, cond_gt): cond_op_pred[p[0]] = p[1] cond_val_pred[p[0]] = p[2] cond_op_gt[g[0]] = g[1] 
cond_val_gt[g[0]] = g[2] if set(cond_op_pred.keys()) != set(cond_op_gt.keys()): cond_col_err += 1 good = False where_op_pred = [ cond_op_pred[x] for x in sorted(cond_op_pred.keys()) ] where_op_gt = [ cond_op_gt[x] for x in sorted(cond_op_gt.keys()) ] if where_op_pred != where_op_gt: cond_op_err += 1 good = False where_val_pred = [ cond_val_pred[x] for x in sorted(cond_val_pred.keys()) ] where_val_gt = [ cond_val_gt[x] for x in sorted(cond_val_gt.keys()) ] if where_val_pred != where_val_gt: cond_val_err += 1 good = False if not good: tot_err += 1 return np.array( (sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err, cond_op_err, cond_val_err, cond_rela_err)), tot_err def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False): """ :param score: :param q: token-questions :param col: token-headers :param raw_q: original question sequence :return: """ def merge_tokens(tok_list, raw_tok_str): tok_str = raw_tok_str # .lower() alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$(' special = { '-LRB-': '(', '-RRB-': ')', '-LSB-': '[', '-RSB-': ']', '``': '"', '\'\'': '"', '--': u'\u2013' } ret = '' double_quote_appear = 0 for raw_tok in tok_list: if not raw_tok: continue tok = special.get(raw_tok, raw_tok) if tok == '"': double_quote_appear = 1 - double_quote_appear if len(ret) == 0: pass elif len(ret) > 0 and ret + ' ' + tok in tok_str: ret = ret + ' ' elif len(ret) > 0 and ret + tok in tok_str: pass elif tok == '"': if double_quote_appear: ret = ret + ' ' # elif tok[0] not in alphabet: # pass elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \ and (ret[-1] != '"' or not double_quote_appear): ret = ret + ' ' ret = ret + tok return ret.strip() sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score # [64,4,6], [64,14], ..., [64,4] sel_num_score = sel_num_score.data.cpu().numpy() sel_score = sel_score.data.cpu().numpy() agg_score = agg_score.data.cpu().numpy() where_rela_score = where_rela_score.data.cpu().numpy() ret_queries = [] 
B = len(agg_score) cond_num_score,cond_col_score,cond_op_score,cond_str_score =\ [x.data.cpu().numpy() for x in cond_score] for b in range(B): cur_query = {} cur_query['sel'] = [] cur_query['agg'] = [] sel_num = np.argmax(sel_num_score[b]) max_col_idxes = np.argsort(-sel_score[b])[:sel_num] # find the most-probable columns' indexes max_agg_idxes = np.argsort(-agg_score[b])[:sel_num] cur_query['sel'].extend([int(i) for i in max_col_idxes]) cur_query['agg'].extend([i[0] for i in max_agg_idxes]) cur_query['cond_conn_op'] = np.argmax(where_rela_score[b]) cur_query['conds'] = [] cond_num = np.argmax(cond_num_score[b]) all_toks = ['<BEG>'] + q[b] + ['<END>'] max_idxes = np.argsort(-cond_col_score[b])[:cond_num] for idx in range(cond_num): cur_cond = [] cur_cond.append(max_idxes[idx]) # where-col cur_cond.append(np.argmax(cond_op_score[b][idx])) # where-op cur_cond_str_toks = [] for str_score in cond_str_score[b][idx]: str_tok = np.argmax(str_score[:len(all_toks)]) str_val = all_toks[str_tok] if str_val == '<END>': break cur_cond_str_toks.append(str_val) cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b])) cur_query['conds'].append(cur_cond) ret_queries.append(cur_query) return ret_queries