def __init__(self, word_emb, char_emb, N_word, N_h=100, N_depth=2, gpu=False,
             use_ca=True, trainable_emb=False):
    """Build the SQLNet modules (character-embedding variant).

    Args:
        word_emb: pretrained word-embedding table passed to the embedding layer.
        char_emb: pretrained character-embedding table.
        N_word: word-embedding dimension (predictors receive N_word*2 inputs).
        N_h: hidden size of the predictor networks.
        N_depth: number of recurrent layers in the predictors.
        gpu: when True, move the whole module to CUDA.
        use_ca: enable column attention in the predictors.
        trainable_emb: when True, use one embedding layer per sub-task.
    """
    super(SQLNet, self).__init__()

    # Flags and sizes.
    self.use_ca = use_ca
    self.trainable_emb = trainable_emb
    self.gpu = gpu
    self.N_h = N_h
    self.N_depth = N_depth
    self.max_col_num = 45
    self.max_tok_num = 200

    # Vocabulary of SQL control tokens and the supported condition operators.
    self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>']
    self.COND_OPS = ['EQL', 'GT', 'LT']

    # Embedding layers: one per sub-task when trainable, otherwise shared.
    # NOTE(review): the per-task layers are constructed with trainable=False
    # even though trainable_emb is True here — confirm this is intentional.
    if trainable_emb:
        self.agg_embed_layer = CharacterEmbedding(
            word_emb, char_emb, N_word, gpu, self.SQL_TOK,
            our_model=True, trainable=False)
        self.sel_embed_layer = CharacterEmbedding(
            word_emb, char_emb, N_word, gpu, self.SQL_TOK,
            our_model=True, trainable=False)
        self.cond_embed_layer = CharacterEmbedding(
            word_emb, char_emb, N_word, gpu, self.SQL_TOK,
            our_model=True, trainable=False)
    else:
        self.embed_layer = CharacterEmbedding(
            word_emb, char_emb, N_word, gpu, self.SQL_TOK,
            our_model=True, trainable=trainable_emb)

    # Sub-task predictors: aggregator, selected column, WHERE clause.
    self.agg_pred = AggPredictor(N_word * 2, N_h, N_depth, use_ca=use_ca)
    self.sel_pred = SelPredictor(N_word * 2, N_h, N_depth, self.max_tok_num,
                                 use_ca=use_ca)
    self.cond_pred = SQLNetCondPredictor(N_word * 2, N_h, N_depth,
                                         self.max_col_num, self.max_tok_num,
                                         use_ca, gpu)

    # Shared losses / activations.
    self.CE = nn.CrossEntropyLoss()
    self.softmax = nn.Softmax()
    self.log_softmax = nn.LogSoftmax()
    self.bce_logit = nn.BCEWithLogitsLoss()

    if gpu:
        self.cuda()
def __init__(self, word_emb, N_word, N_h=100, N_depth=2, gpu=False,
             use_ca=True, trainable_emb=False):
    """Build the SQLNet modules (word-embedding / multi-select variant).

    Args:
        word_emb: pretrained word-embedding table.
        N_word: word-embedding dimension fed to every predictor.
        N_h: hidden size of the predictor networks.
        N_depth: number of recurrent layers in the predictors.
        gpu: when True, move the whole module to CUDA.
        use_ca: enable column attention in the predictors.
        trainable_emb: whether the shared embedding layer is trainable.
    """
    super(SQLNet, self).__init__()

    # Flags and sizes.
    self.use_ca = use_ca
    self.trainable_emb = trainable_emb
    self.gpu = gpu
    self.N_h = N_h
    self.N_depth = N_depth
    self.max_col_num = 45
    self.max_tok_num = 200

    # SQL control tokens; this variant uses symbolic operators and supports OR.
    self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR',
                    '==', '>', '<', '!=', '<BEG>']
    self.COND_OPS = ['>', '<', '==', '!=']

    # Single shared word-embedding layer.
    self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK,
                                     our_model=True, trainable=trainable_emb)

    # Sub-task predictors.
    # Number of selected columns.
    self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)
    # Which columns are selected.
    self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num,
                                 use_ca=use_ca)
    # Aggregation function for each selected column.
    self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)
    # Condition count, columns, operators and values.
    self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth,
                                         self.max_col_num, self.max_tok_num,
                                         use_ca, gpu)
    # Relationship between conditions ('and' / 'or').
    self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth,
                                                  use_ca=use_ca)

    # Shared losses / activations.
    self.CE = nn.CrossEntropyLoss()
    self.softmax = nn.Softmax(dim=-1)
    self.log_softmax = nn.LogSoftmax()
    self.bce_logit = nn.BCEWithLogitsLoss()

    if gpu:
        self.cuda()
class SQLNet(nn.Module):
    """SQLNet text-to-SQL model.

    Predicts, from a tokenized natural-language question and the table's
    column names, the three parts of a simple SQL query: the aggregator,
    the selected column, and the WHERE conditions (column, operator, value).

    Fixes applied in this revision (behavior otherwise unchanged):
      * Python 3 compatibility: `print` statements -> `print()` calls and
        `unicode` -> `str` in `check_acc` (consistent with the `print(...)`
        calls already used in `interaction_beam_forward`); `map(...)` results
        in `loss` materialized as lists before `np.array` (a py3 map object
        would become a 0-d object array).
      * `nn.Softmax(dim=-1)`: every visible use applies it to a 2-D
        `(1, -1)` view, where `dim=-1` matches the legacy implicit dim.
      * Removed unused locals (`B`/`max_x_len` in `forward`, `agg_ops` in
        `check_acc`).
    """

    def __init__(self, word_emb, N_word, N_h=100, N_depth=2, gpu=False,
                 use_ca=True, trainable_emb=False, dr=0.3, temperature=False):
        """Build the predictors and shared losses.

        Args:
            word_emb: pretrained word-embedding table.
            N_word: word-embedding dimension fed to every predictor.
            N_h: hidden size of the predictor networks.
            N_depth: number of recurrent layers in the predictors.
            gpu: when True, move the whole module to CUDA.
            use_ca: enable column attention in the predictors.
            trainable_emb: when True, use one trainable embedding layer
                per sub-task; otherwise a single shared layer.
            dr: dropout rate forwarded to the predictors.
            temperature: when True, `loss` also returns the per-task losses.
        """
        super(SQLNet, self).__init__()
        self.use_ca = use_ca
        self.trainable_emb = trainable_emb
        self.temperature = temperature
        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth

        self.max_col_num = 45
        self.max_tok_num = 200
        # SQL control tokens and the condition operators this variant emits.
        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND',
                        'EQL', 'GT', 'LT', '<BEG>']
        self.COND_OPS = ['EQL', 'GT', 'LT']

        # Word embedding: per-task layers when trainable, else shared.
        if trainable_emb:
            self.agg_embed_layer = WordEmbedding(
                word_emb, N_word, gpu, self.SQL_TOK,
                our_model=True, trainable=trainable_emb)
            self.sel_embed_layer = WordEmbedding(
                word_emb, N_word, gpu, self.SQL_TOK,
                our_model=True, trainable=trainable_emb)
            self.cond_embed_layer = WordEmbedding(
                word_emb, N_word, gpu, self.SQL_TOK,
                our_model=True, trainable=trainable_emb)
        else:
            self.embed_layer = WordEmbedding(
                word_emb, N_word, gpu, self.SQL_TOK,
                our_model=True, trainable=trainable_emb)

        # Predict aggregator.
        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca,
                                     dr=dr, temperature=temperature)
        # Predict selected column.
        self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num,
                                     use_ca=use_ca, dr=dr,
                                     temperature=temperature)
        # Predict number of conditions, condition columns/ops/values.
        self.cond_pred = SQLNetCondPredictor(
            N_word, N_h, N_depth, self.max_col_num, self.max_tok_num,
            use_ca, gpu, dr=dr, temperature=temperature)

        self.CE = nn.CrossEntropyLoss()
        # dim=-1: all visible call sites feed 2-D (1, -1) views, where this
        # matches the legacy implicit-dim behavior.
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()
        if gpu:
            self.cuda()

    def generate_gt_where_seq(self, q, col, query):
        """Map each WHERE-value of each gold query to question-token indices.

        For every example, every condition value is encoded as
        [<BEG>, tok_idx..., <END>] where indices point into
        ['<BEG>'] + question tokens + ['<END>']; unknown tokens map to 0.

        Args:
            q: list of tokenized questions.
            col: list of column-name lists (unused here, kept for signature).
            query: list of tokenized gold SQL queries.

        Returns:
            list (per example) of lists (per condition) of index sequences.

        Raises:
            RuntimeError: if a condition contains no known operator.
        """
        ret_seq = []
        for cur_q, cur_col, cur_query in zip(q, col, query):
            cur_values = []
            # Start scanning right after 'WHERE' (or skip if absent).
            st = cur_query.index(u'WHERE') + 1 \
                if u'WHERE' in cur_query else len(cur_query)
            all_toks = ['<BEG>'] + cur_q + ['<END>']
            while st < len(cur_query):
                # Each condition ends at the next 'AND' or at end of query.
                ed = len(cur_query) if 'AND' not in cur_query[st:] \
                    else cur_query[st:].index('AND') + st
                if 'EQL' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('EQL') + st
                elif 'GT' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('GT') + st
                elif 'LT' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('LT') + st
                else:
                    raise RuntimeError("No operator in it!")
                # The value is everything between the operator and 'AND'.
                this_str = ['<BEG>'] + cur_query[op + 1:ed] + ['<END>']
                cur_seq = [all_toks.index(s) if s in all_toks else 0
                           for s in this_str]
                cur_values.append(cur_seq)
                st = ed + 1
            ret_seq.append(cur_values)
        return ret_seq

    def forward(self, q, col, col_num, pred_entry,
                gt_where=None, gt_cond=None, reinforce=False, gt_sel=None):
        """Run the requested sub-task predictors.

        Args:
            q: list of tokenized questions.
            col: list of column-name token lists.
            col_num: number of columns per example.
            pred_entry: (pred_agg, pred_sel, pred_cond) booleans selecting
                which scores to compute.
            gt_where: gold WHERE-value index sequences (teacher forcing).
            gt_cond: gold conditions (teacher forcing).
            reinforce: forwarded to the condition predictor.
            gt_sel: gold selected column (teacher forcing for agg).

        Returns:
            (agg_score, sel_score, cond_score); entries not requested are None.
        """
        pred_agg, pred_sel, pred_cond = pred_entry

        agg_score = None
        sel_score = None
        cond_score = None

        if self.trainable_emb:
            # Each sub-task has its own embedding layer.
            if pred_agg:
                x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
                col_inp_var, col_name_len, col_len = \
                    self.agg_embed_layer.gen_col_batch(col)
                agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var,
                                          col_name_len, col_len, col_num,
                                          gt_sel=gt_sel)
            if pred_sel:
                x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
                col_inp_var, col_name_len, col_len = \
                    self.sel_embed_layer.gen_col_batch(col)
                sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                          col_name_len, col_len, col_num)
            if pred_cond:
                x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
                col_inp_var, col_name_len, col_len = \
                    self.cond_embed_layer.gen_col_batch(col)
                cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var,
                                            col_name_len, col_len, col_num,
                                            gt_where, gt_cond,
                                            reinforce=reinforce)
        else:
            # Shared embedding: embed once, reuse for every sub-task.
            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = \
                self.embed_layer.gen_col_batch(col)
            if pred_agg:
                agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var,
                                          col_name_len, col_len, col_num,
                                          gt_sel=gt_sel)
            if pred_sel:
                sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                          col_name_len, col_len, col_num)
            if pred_cond:
                cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var,
                                            col_name_len, col_len, col_num,
                                            gt_where, gt_cond,
                                            reinforce=reinforce)

        return (agg_score, sel_score, cond_score)

    def interaction_beam_forward(self, q, col, raw_q, raw_col, col_num,
                                 beam_size, dec_prefix, stop_step=None,
                                 avoid_items=None, confirmed_items=None,
                                 dropout_rate=0.0, bool_collect_choices=False,
                                 bool_verbal=False):
        """
        @author: Ziyu Yao
        Beam search decoding for interactive sql generation.
        Only support batch size=1 and self.trainable_emb=True.

        The decoder walks a stack of pending decisions per hypothesis:
        "sc" (select column) -> "sa" (select aggregator) -> "wc" (where
        columns) -> "wo" (where operator) -> "wv" (where value).
        `dec_prefix` forces an initial decision sequence; `avoid_items` /
        `confirmed_items` exclude / pin candidates at given decision indices.
        Returns completed hypotheses (or the live beam when `stop_step` hits).
        """
        assert self.trainable_emb, "Support trainable_emb=True only."
        assert len(q) == 1

        dec_prefix = dec_prefix[::-1]  # pop() consumes from the front
        hypotheses = [Hypothesis(dec_prefix)]
        completed_hypotheses = []
        table_name = None

        while True:
            new_hypotheses = []
            for hyp in hypotheses:
                if hyp.stack.isEmpty():
                    # Nothing left to decide: finalize this hypothesis.
                    # Sort conds by its col idx.
                    conds = hyp.sql_i['conds']
                    sorted_conds = sorted(conds, key=lambda x: x[0])
                    hyp.sql_i['conds'] = sorted_conds
                    hyp.sql = generate_sql_q1(hyp.sql_i, raw_q[0], raw_col[0])
                    if bool_verbal:
                        print("Completed %d-th hypotheses: "
                              % len(completed_hypotheses))
                        print("tag_seq:{}".format(hyp.tag_seq))
                        print("dec_seq: {}".format(hyp.dec_seq))
                        print("sql_i: {}".format(hyp.sql_i))
                        print("sql: {}".format(hyp.sql))
                    completed_hypotheses.append(hyp)  # add to completion
                else:
                    vet = hyp.stack.pop()
                    if vet[0] == "sc":
                        # --- SELECT column decision ---
                        x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = \
                            self.sel_embed_layer.gen_col_batch(
                                col, dropout_rate=dropout_rate)
                        sel_score = self.sel_pred(
                            x_emb_var, x_len, col_inp_var, col_name_len,
                            col_len, col_num,
                            dropout_rate=dropout_rate).view(1, -1)
                        prob_sc = self.softmax(sel_score).data.cpu().numpy()[0]

                        hyp.tag_seq.append((OUTSIDE, 'select', 1.0, None))

                        if len(hyp.dec_prefix):
                            partial_vet, sc_idx = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            sc_candidates = [sc_idx]
                        else:
                            sc_candidates = np.argsort(-prob_sc)
                            # rm avoid candidates
                            if avoid_items is not None and \
                                    hyp.dec_seq_idx in avoid_items:
                                sc_candidates = [
                                    sc_idx for sc_idx in sc_candidates
                                    if sc_idx not in
                                    avoid_items[hyp.dec_seq_idx]]
                            sc_candidates = sc_candidates[:beam_size]

                        for sc_idx in sc_candidates:
                            # Reuse hyp in place when the choice is forced.
                            if len(sc_candidates) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()
                            sc_name = raw_col[0][sc_idx]
                            step_hyp.sql_i['sel'] = sc_idx
                            step_hyp.dec_seq.append((vet, sc_idx))
                            step_hyp.tag_seq.append(
                                (SELECT_COL, (table_name, sc_name, sc_idx),
                                 prob_sc[sc_idx], step_hyp.dec_seq_idx))
                            step_hyp.add_logprob(np.log(prob_sc[sc_idx]))
                            step_hyp.stack.push(("sa", (sc_idx, sc_name)))
                            step_hyp.dec_seq_idx += 1
                            new_hypotheses.append(step_hyp)

                    elif vet[0] == "sa":
                        # --- SELECT aggregator decision ---
                        sc_idx, sc_name = vet[1]
                        x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = \
                            self.agg_embed_layer.gen_col_batch(
                                col, dropout_rate=dropout_rate)
                        agg_score = self.agg_pred(
                            x_emb_var, x_len, col_inp_var, col_name_len,
                            col_len, col_num, gt_sel=[sc_idx],
                            dropout_rate=dropout_rate).view(1, -1)
                        prob_sa = self.softmax(agg_score).data.cpu().numpy()[0]

                        if len(hyp.dec_prefix):
                            partial_vet, sa_idx = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            sa_candidates = [sa_idx]
                        else:
                            sa_candidates = np.argsort(-prob_sa)
                            if avoid_items is not None and \
                                    hyp.dec_seq_idx in avoid_items:
                                sa_candidates = [
                                    sa_idx for sa_idx in sa_candidates
                                    if sa_idx not in
                                    avoid_items[hyp.dec_seq_idx]]
                            sa_candidates = sa_candidates[:beam_size]

                        for sa_idx in sa_candidates:
                            if len(sa_candidates) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()
                            sa_name = AGG_OPS[sa_idx]
                            if sa_name == 'None':
                                sa_name = 'none_agg'  # for q gen usage
                            step_hyp.sql_i['agg'] = sa_idx
                            step_hyp.dec_seq.append((vet, sa_idx))
                            step_hyp.tag_seq.append(
                                (SELECT_AGG, (table_name, sc_name, sc_idx),
                                 (sa_name, sa_idx), prob_sa[sa_idx],
                                 step_hyp.dec_seq_idx))
                            step_hyp.add_logprob(np.log(prob_sa[sa_idx]))
                            step_hyp.stack.push(("wc", None))
                            step_hyp.dec_seq_idx += 1
                            new_hypotheses.append(step_hyp)

                    elif vet[0] == "wc":
                        # --- WHERE column-set decision (count + columns) ---
                        hyp.tag_seq.append((OUTSIDE, 'where', 1.0, None))
                        hyp.sql_i['conds'] = []
                        step_hypotheses = []

                        x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = \
                            self.cond_embed_layer.gen_col_batch(
                                col, dropout_rate=dropout_rate)
                        # wn, wc
                        cond_num_score, cond_col_score = \
                            self.cond_pred.cols_forward(
                                x_emb_var, x_len, col_inp_var, col_name_len,
                                col_len, dropout_rate=dropout_rate)
                        prob_wn = self.softmax(
                            cond_num_score.view(1, -1)).data.cpu().numpy()[0]
                        prob_wc = self.sigmoid(
                            cond_col_score.view(1, -1)).data.cpu().numpy()[0]

                        if len(hyp.dec_prefix):
                            partial_vet, wn, wc_list = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            col_num_cols_pair = [(wn, wc_list)]
                        else:
                            col_num_cols_pair = []
                            sorted_col_num = np.argsort(-prob_wn)
                            sorted_cols = np.argsort(-prob_wc)

                            # filter avoid_items
                            if avoid_items is not None and \
                                    hyp.dec_seq_idx in avoid_items:
                                sorted_cols = [
                                    col_idx for col_idx in sorted_cols
                                    if col_idx not in
                                    avoid_items[hyp.dec_seq_idx]]
                                sorted_col_num = [
                                    col_num for col_num in sorted_col_num
                                    if col_num <= len(sorted_cols)]

                            # fix confirmed items
                            if confirmed_items is not None and \
                                    hyp.dec_seq_idx in confirmed_items:
                                fixed_cols = list(
                                    confirmed_items[hyp.dec_seq_idx])
                                sorted_col_num = [
                                    col_num - len(fixed_cols)
                                    for col_num in sorted_col_num
                                    if col_num >= len(fixed_cols)]
                                sorted_cols = [
                                    col_idx for col_idx in sorted_cols
                                    if col_idx not in fixed_cols]
                            else:
                                fixed_cols = []

                            if bool_collect_choices:
                                # fake searching to collect some choices
                                col_num_cols_pair.extend([
                                    (1, [col_idx])
                                    for col_idx in sorted_cols[:beam_size]])
                            else:
                                for col_num in sorted_col_num:  # [:beam_size]
                                    if col_num == 0:
                                        col_num_cols_pair.append(
                                            (len(fixed_cols), fixed_cols))
                                    elif col_num == 1:
                                        col_num_cols_pair.extend([
                                            (len(fixed_cols) + 1,
                                             fixed_cols + [col_idx])
                                            for col_idx in
                                            sorted_cols[:beam_size]])
                                    elif beam_size == 1:
                                        top_cols = list(sorted_cols[:col_num])
                                        # top_cols.sort()
                                        col_num_cols_pair.append(
                                            (len(fixed_cols) + col_num,
                                             fixed_cols + top_cols))
                                    else:
                                        # Enumerate column combinations from
                                        # the top-10 to reduce beam search time.
                                        combs = combinations(
                                            sorted_cols[:10], col_num)
                                        comb_score = []
                                        for comb in combs:
                                            score = sum([
                                                np.log(prob_wc[c_idx])
                                                for c_idx in comb])
                                            comb_score.append((comb, score))
                                        sorted_comb_score = sorted(
                                            comb_score, key=lambda x: x[1],
                                            reverse=True)[:beam_size]
                                        for comb, _ in sorted_comb_score:
                                            comb_cols = list(comb)
                                            # comb_cols.sort()
                                            col_num_cols_pair.append(
                                                (len(fixed_cols) + col_num,
                                                 fixed_cols + comb_cols))

                        for col_num, cols in col_num_cols_pair:
                            if len(col_num_cols_pair) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()

                            step_hyp.dec_seq.append((vet, col_num, cols))
                            step_hyp.add_logprob(np.log(prob_wn[col_num]))
                            for wc_idx in cols:
                                wc_name = raw_col[0][wc_idx]
                                step_hyp.tag_seq.append(
                                    (WHERE_COL,
                                     (table_name, wc_name, wc_idx),
                                     prob_wc[wc_idx], step_hyp.dec_seq_idx))
                                step_hyp.add_logprob(np.log(prob_wc[wc_idx]))
                                step_hyp.stack.push(("wo", (wc_idx, wc_name)))
                            # One dec_seq entry (the whole column set) per
                            # decision, so the index advances once.
                            step_hyp.dec_seq_idx += 1
                            step_hypotheses.append(step_hyp)

                        step_hypotheses = Hypothesis.sort_hypotheses(
                            step_hypotheses, beam_size, 0.0)
                        new_hypotheses.extend(step_hypotheses)

                    elif vet[0] == "wo":
                        # --- WHERE operator decision for one column ---
                        wc_idx, wc_name = vet[1]
                        chosen_col_gt = [[wc_idx]]
                        x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = \
                            self.cond_embed_layer.gen_col_batch(
                                col, dropout_rate=dropout_rate)
                        cond_op_score = self.cond_pred.op_forward(
                            x_emb_var, x_len, col_inp_var, col_name_len,
                            col_len, chosen_col_gt,
                            dropout_rate=dropout_rate).view(
                                1, 4, -1)  # [B=1, 4, |OPS|]
                        prob_wo = self.softmax(
                            cond_op_score[:, 0, :]).data.cpu().numpy()[0]

                        if len(hyp.dec_prefix):
                            partial_vet, wo_idx = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            wo_candidates = [wo_idx]
                        else:
                            wo_candidates = np.argsort(-prob_wo)
                            if avoid_items is not None and \
                                    hyp.dec_seq_idx in avoid_items:
                                wo_candidates = [
                                    wo_idx for wo_idx in wo_candidates
                                    if wo_idx not in
                                    avoid_items[hyp.dec_seq_idx]]
                            wo_candidates = wo_candidates[:beam_size]

                        for wo_idx in wo_candidates:
                            if len(wo_candidates) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()
                            wo_name = COND_OPS[wo_idx]
                            step_hyp.dec_seq.append((vet, wo_idx))
                            step_hyp.tag_seq.append(
                                (WHERE_OP, ((table_name, wc_name, wc_idx),),
                                 (wo_name, wo_idx), prob_wo[wo_idx],
                                 step_hyp.dec_seq_idx))
                            step_hyp.add_logprob(np.log(prob_wo[wo_idx]))
                            step_hyp.stack.push(
                                ("wv", (wc_idx, wc_name, wo_idx, wo_name)))
                            step_hyp.dec_seq_idx += 1
                            new_hypotheses.append(step_hyp)

                    elif vet[0] == "wv":
                        # --- WHERE value decision for one (column, op) ---
                        wc_idx, wc_name, wo_idx, wo_name = vet[1]
                        x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = \
                            self.cond_embed_layer.gen_col_batch(
                                col, dropout_rate=dropout_rate)

                        given_idxes, avoid_idxes_list = None, None
                        if len(hyp.dec_prefix):
                            partial_vet, given_idxes = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                        elif avoid_items is not None and \
                                hyp.dec_seq_idx in avoid_items:
                            avoid_idxes_list = list(
                                avoid_items[hyp.dec_seq_idx])

                        str_idxes_prob_pairs = self.cond_pred.val_beam_search(
                            x_emb_var, x_len, col_inp_var, col_name_len,
                            col_len, [[wc_idx]], beam_size,
                            avoid_idxes_list=avoid_idxes_list,
                            given_idxes=given_idxes,
                            dropout_rate=dropout_rate)
                        all_toks = ['<BEG>'] + q[0] + ['<END>']

                        for str_idxes, logprob in str_idxes_prob_pairs:
                            if len(str_idxes_prob_pairs) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()

                            # get val_str: decode token indices up to <END>.
                            cur_cond_str_toks = []
                            for wd_idx in str_idxes[1:]:
                                str_val = all_toks[wd_idx]
                                if str_val == '<END>':
                                    break
                                cur_cond_str_toks.append(str_val)
                            val_str = SQLNet.merge_tokens(
                                cur_cond_str_toks, raw_q[0])

                            step_hyp.sql_i['conds'].append(
                                [wc_idx, wo_idx, val_str])
                            step_hyp.dec_seq.append((vet, str_idxes))
                            step_hyp.tag_seq.append(
                                (WHERE_VAL, ((table_name, wc_name, wc_idx),),
                                 (wo_name, wo_idx), (str_idxes, val_str),
                                 np.exp(logprob), hyp.dec_seq_idx))
                            step_hyp.add_logprob(logprob)
                            step_hyp.dec_seq_idx += 1
                            new_hypotheses.append(step_hyp)

            if len(new_hypotheses) == 0:
                # No live hypotheses left: sort and return the completed ones.
                sorted_completed_hypotheses = Hypothesis.sort_hypotheses(
                    completed_hypotheses, beam_size, 0.0)
                return sorted_completed_hypotheses

            # if bool_verbal:
            #     print("Before sorting...")
            #     Hypothesis.print_hypotheses(new_hypotheses)
            hypotheses = Hypothesis.sort_hypotheses(
                new_hypotheses, beam_size, 0.0)
            if bool_verbal:
                print("\nAfter sorting...")
                Hypothesis.print_hypotheses(hypotheses)

            if stop_step is not None:
                # For one-step beam search; the partial_seq lengths must be
                # the same for all hyps.
                dec_seq_length = len(hypotheses[0].dec_seq)
                if dec_seq_length == stop_step + 1:
                    for hyp in hypotheses:
                        assert len(hyp.dec_seq) == dec_seq_length
                    return hypotheses

    def loss(self, score, truth_num, pred_entry, gt_where):
        """Compute the training loss for the requested sub-tasks.

        Args:
            score: (agg_score, sel_score, cond_score) from `forward`.
            truth_num: per-example gold labels; usage shows x[0]=agg label,
                x[1]=sel label, x[2]=cond count, x[3]=cond column indices,
                x[4]=cond operator labels.
            pred_entry: (pred_agg, pred_sel, pred_cond) booleans.
            gt_where: gold WHERE-value index sequences
                (see `generate_gt_where_seq`).

        Returns:
            [loss] or, when self.temperature,
            [loss, loss_sel, loss_agg, loss_cond].
        """
        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score

        loss = 0
        loss_agg, loss_sel, loss_cond = 0., 0., 0.
        if pred_agg:
            # list(...) so np.array builds an int array (py3 map is lazy).
            agg_truth = list(map(lambda x: x[0], truth_num))
            data = torch.from_numpy(np.array(agg_truth))
            if self.gpu:
                agg_truth_var = Variable(data.cuda())
            else:
                agg_truth_var = Variable(data)
            loss_agg = self.CE(agg_score, agg_truth_var)
            loss += loss_agg

        if pred_sel:
            sel_truth = list(map(lambda x: x[1], truth_num))
            data = torch.from_numpy(np.array(sel_truth))
            if self.gpu:
                sel_truth_var = Variable(data.cuda())
            else:
                sel_truth_var = Variable(data)
            loss_sel = self.CE(sel_score, sel_truth_var)
            loss += loss_sel

        if pred_cond:
            B = len(truth_num)
            cond_num_score, cond_col_score, \
                cond_op_score, cond_str_score = cond_score
            # Evaluate the number of conditions.
            cond_num_truth = list(map(lambda x: x[2], truth_num))
            data = torch.from_numpy(np.array(cond_num_truth))
            if self.gpu:
                cond_num_truth_var = Variable(data.cuda())
            else:
                cond_num_truth_var = Variable(data)
            cond_num_loss = self.CE(cond_num_score, cond_num_truth_var)
            loss_cond += cond_num_loss
            loss += cond_num_loss

            # Evaluate the columns of conditions: multi-label target with a
            # weighted binary cross-entropy (positives weighted 3x).
            T = len(cond_col_score[0])
            truth_prob = np.zeros((B, T), dtype=np.float32)
            for b in range(B):
                if len(truth_num[b][3]) > 0:
                    truth_prob[b][list(truth_num[b][3])] = 1
            data = torch.from_numpy(truth_prob)
            if self.gpu:
                cond_col_truth_var = Variable(data.cuda())
            else:
                cond_col_truth_var = Variable(data)

            sigm = nn.Sigmoid()
            cond_col_prob = sigm(cond_col_score)
            bce_loss = -torch.mean(
                3 * (cond_col_truth_var *
                     torch.log(cond_col_prob + 1e-10)) +
                (1 - cond_col_truth_var) * torch.log(1 - cond_col_prob + 1e-10)
            )
            loss_cond += bce_loss
            loss += bce_loss

            # Evaluate the operator of conditions.
            for b in range(len(truth_num)):
                if len(truth_num[b][4]) == 0:
                    continue
                data = torch.from_numpy(np.array(truth_num[b][4]))
                if self.gpu:
                    cond_op_truth_var = Variable(data.cuda())
                else:
                    cond_op_truth_var = Variable(data)
                cond_op_pred = cond_op_score[b, :len(truth_num[b][4])]
                cond_op_loss_b = (self.CE(cond_op_pred, cond_op_truth_var)
                                  / len(truth_num))
                loss_cond += cond_op_loss_b
                loss += cond_op_loss_b

            # Evaluate the strings of conditions.
            for b in range(len(gt_where)):
                for idx in range(len(gt_where[b])):
                    cond_str_truth = gt_where[b][idx]
                    if len(cond_str_truth) == 1:
                        continue
                    # Skip the leading <BEG>; predict the rest of the value.
                    data = torch.from_numpy(np.array(cond_str_truth[1:]))
                    if self.gpu:
                        cond_str_truth_var = Variable(data.cuda())
                    else:
                        cond_str_truth_var = Variable(data)
                    str_end = len(cond_str_truth) - 1
                    cond_str_pred = cond_str_score[b, idx, :str_end]
                    cond_str_loss_b = (
                        self.CE(cond_str_pred, cond_str_truth_var)
                        / (len(gt_where) * len(gt_where[b])))
                    loss_cond += cond_str_loss_b
                    loss += cond_str_loss_b

        if self.temperature:
            return [loss, loss_sel, loss_agg, loss_cond]
        return [loss]

    def check_acc(self, vis_info, pred_queries, gt_queries, pred_entry):
        """Count prediction errors against the gold queries.

        Args:
            vis_info: visualization data (only used by the local debug helper).
            pred_queries: predicted query dicts ('agg', 'sel', 'conds').
            gt_queries: gold query dicts of the same shape.
            pred_entry: (pred_agg, pred_sel, pred_cond) booleans.

        Returns:
            (np.array((agg_err, sel_err, cond_err)), tot_err) where tot_err
            counts examples wrong in any requested component.
        """
        def pretty_print(vis_data):
            # Debug helper; not called in the accuracy computation.
            print('question:', vis_data[0])
            print('headers: (%s)' % (' || '.join(vis_data[1])))
            print('query:', vis_data[2])

        def gen_cond_str(conds, header):
            # Debug helper; renders a WHERE clause as text.
            if len(conds) == 0:
                return 'None'
            cond_str = []
            for cond in conds:
                cond_str.append(header[cond[0]] + ' ' +
                                self.COND_OPS[cond[1]] + ' ' +
                                str(cond[2]).lower())
            return 'WHERE ' + ' AND '.join(cond_str)

        pred_agg, pred_sel, pred_cond = pred_entry

        tot_err = agg_err = sel_err = cond_err = 0.0
        cond_num_err = cond_col_err = cond_op_err = cond_val_err = 0.0
        for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)):
            good = True
            if pred_agg:
                agg_pred = pred_qry['agg']
                agg_gt = gt_qry['agg']
                if agg_pred != agg_gt:
                    agg_err += 1
                    good = False

            if pred_sel:
                sel_pred = pred_qry['sel']
                sel_gt = gt_qry['sel']
                if sel_pred != sel_gt:
                    sel_err += 1
                    good = False

            if pred_cond:
                cond_pred = pred_qry['conds']
                cond_gt = gt_qry['conds']
                flag = True
                # Condition count must match...
                if len(cond_pred) != len(cond_gt):
                    flag = False
                    cond_num_err += 1
                # ...then the set of condition columns...
                if flag and set(x[0] for x in cond_pred) != \
                        set(x[0] for x in cond_gt):
                    flag = False
                    cond_col_err += 1
                # ...then each operator (matched by column)...
                for idx in range(len(cond_pred)):
                    if not flag:
                        break
                    gt_idx = tuple(
                        x[0] for x in cond_gt).index(cond_pred[idx][0])
                    if flag and cond_gt[gt_idx][1] != cond_pred[idx][1]:
                        flag = False
                        cond_op_err += 1
                # ...then each value (case-insensitive string compare).
                for idx in range(len(cond_pred)):
                    if not flag:
                        break
                    gt_idx = tuple(
                        x[0] for x in cond_gt).index(cond_pred[idx][0])
                    if flag and str(cond_gt[gt_idx][2]).lower() != \
                            str(cond_pred[idx][2]).lower():
                        flag = False
                        cond_val_err += 1

                if not flag:
                    cond_err += 1
                    good = False

            if not good:
                tot_err += 1

        return np.array((agg_err, sel_err, cond_err)), tot_err

    @staticmethod
    def merge_tokens(tok_list, raw_tok_str):
        """Detokenize `tok_list` into a string, guided by the raw question.

        Spaces are re-inserted only where the lowercased raw string
        `raw_tok_str` shows the tokens were separated; PTB-style bracket and
        quote tokens are mapped back to their literal characters.
        """
        tok_str = raw_tok_str.lower()
        alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
        special = {
            '-LRB-': '(',
            '-RRB-': ')',
            '-LSB-': '[',
            '-RSB-': ']',
            '``': '"',
            '\'\'': '"',
            '--': u'\u2013'
        }
        ret = ''
        double_quote_appear = 0
        for raw_tok in tok_list:
            if not raw_tok:
                continue
            tok = special.get(raw_tok, raw_tok)
            if tok == '"':
                # Track open/close state to decide spacing around quotes.
                double_quote_appear = 1 - double_quote_appear

            if len(ret) == 0:
                pass
            elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                ret = ret + ' '
            elif len(ret) > 0 and ret + tok in tok_str:
                pass
            elif tok == '"':
                if double_quote_appear:
                    ret = ret + ' '
            elif tok[0] not in alphabet:
                pass
            elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
                    and (ret[-1] != '"' or not double_quote_appear):
                ret = ret + ' '
            ret = ret + tok
        return ret.strip()

    def gen_query(self, score, q, col, raw_q, raw_col, pred_entry,
                  reinforce=False, verbose=False):
        """Decode predictor scores into query dicts.

        Args:
            score: (agg_score, sel_score, cond_score) from `forward`.
            q: tokenized questions (for WHERE-value decoding).
            col: column-name token lists (unused here, kept for signature).
            raw_q: raw question strings (for value detokenization).
            raw_col: raw column names (unused here, kept for signature).
            pred_entry: (pred_agg, pred_sel, pred_cond) booleans.
            reinforce, verbose: kept for signature compatibility.

        Returns:
            list of query dicts with keys among 'agg', 'sel', 'conds'.
        """
        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score

        ret_queries = []
        # Batch size from whichever score is available.
        if pred_agg:
            B = len(agg_score)
        elif pred_sel:
            B = len(sel_score)
        elif pred_cond:
            B = len(cond_score[0])
        for b in range(B):
            cur_query = {}
            if pred_agg:
                cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
            if pred_sel:
                cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
            if pred_cond:
                cur_query['conds'] = []
                cond_num_score, cond_col_score, \
                    cond_op_score, cond_str_score = \
                    [x.data.cpu().numpy() for x in cond_score]
                cond_num = np.argmax(cond_num_score[b])
                all_toks = ['<BEG>'] + q[b] + ['<END>']
                # Take the top-`cond_num` columns by score.
                max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
                for idx in range(cond_num):
                    cur_cond = []
                    cur_cond.append(max_idxes[idx])
                    cur_cond.append(np.argmax(cond_op_score[b][idx]))
                    # Greedy token-by-token value decoding up to <END>.
                    cur_cond_str_toks = []
                    for str_score in cond_str_score[b][idx]:
                        str_tok = np.argmax(str_score[:len(all_toks)])
                        str_val = all_toks[str_tok]
                        if str_val == '<END>':
                            break
                        cur_cond_str_toks.append(str_val)
                    cur_cond.append(
                        SQLNet.merge_tokens(cur_cond_str_toks, raw_q[b]))
                    cur_query['conds'].append(cur_cond)
            ret_queries.append(cur_query)

        return ret_queries