コード例 #1
0
ファイル: sqlnet.py プロジェクト: shanelleroman/seq2sql
class SQLNet(nn.Module):
    """SQLNet model scoring the aggregator, the selected column and the WHERE
    conditions of a query from tokenized questions and table headers.

    The code is kept valid under both Python 2 and Python 3 (no ``print``
    statements, no reliance on ``map`` returning a list, guarded ``unicode``).
    """

    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=100,
                 N_depth=2,
                 gpu=False,
                 use_ca=True,
                 trainable_emb=False):
        """Build the embedding layer(s) and the three sub-predictors.

        Args:
            word_emb: passed through to ``WordEmbedding``.
            N_word: embedding dimension, passed to every sub-module.
            N_h: hidden size passed to every predictor.
            N_depth: depth passed to every predictor.
            gpu: when True the model is moved to CUDA and loss targets are
                created on the GPU.
            use_ca: forwarded to every predictor.
            trainable_emb: when True the aggregator, column and condition
                predictors each get their own ``WordEmbedding``
                (``agg_embed_layer``, ``sel_embed_layer``,
                ``cond_embed_layer``); otherwise one shared ``embed_layer``.
        """
        super(SQLNet, self).__init__()
        self.use_ca = use_ca
        self.trainable_emb = trainable_emb

        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth

        self.max_col_num = 45
        self.max_tok_num = 200
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>'
        ]
        self.COND_OPS = ['EQL', 'GT', 'LT']

        def make_embedding():
            return WordEmbedding(word_emb,
                                 N_word,
                                 gpu,
                                 self.SQL_TOK,
                                 our_model=True,
                                 trainable=trainable_emb)

        # Submodules are registered in a fixed order (embeddings, agg, sel,
        # cond, losses); registration order determines parameter order and
        # state_dict layout, so keep it.
        if trainable_emb:
            self.agg_embed_layer = make_embedding()
            self.sel_embed_layer = make_embedding()
            self.cond_embed_layer = make_embedding()
        else:
            self.embed_layer = make_embedding()

        # Predict aggregator
        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)

        # Predict selected column
        self.sel_pred = SelPredictor(N_word,
                                     N_h,
                                     N_depth,
                                     self.max_tok_num,
                                     use_ca=use_ca)

        # Predict number of conditions, their columns, operators and values
        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth,
                                             self.max_col_num,
                                             self.max_tok_num, use_ca, gpu)

        self.CE = nn.CrossEntropyLoss()
        # NOTE: softmax/log_softmax use the implicit-dim constructor, which
        # recent PyTorch deprecates; neither is used within this class.
        self.softmax = nn.Softmax()
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()
        if gpu:
            self.cuda()

    def _to_variable(self, data):
        """Wrap tensor ``data`` in a Variable, moved to the GPU when ``self.gpu``."""
        if self.gpu:
            data = data.cuda()
        return Variable(data)

    def _long_target(self, values):
        """Build an int64 target Variable for ``nn.CrossEntropyLoss``.

        An explicit int64 dtype guarantees a LongTensor target on every
        platform (numpy's default integer is 32-bit on Windows).
        """
        return self._to_variable(
            torch.from_numpy(np.array(values, dtype=np.int64)))

    @staticmethod
    def _embed(layer, q, col):
        """Run ``layer`` over questions and headers.

        Returns:
            ``(x_emb_var, x_len, col_inp_var, col_name_len, col_len)`` from
            ``layer.gen_x_batch`` followed by ``layer.gen_col_batch``.
        """
        x_emb_var, x_len = layer.gen_x_batch(q, col)
        col_inp_var, col_name_len, col_len = layer.gen_col_batch(col)
        return x_emb_var, x_len, col_inp_var, col_name_len, col_len

    def generate_gt_where_seq(self, q, col, query):
        """Build, per example, index sequences for each WHERE-clause value.

        The WHERE part of ``cur_query`` (tokens after ``'where'``) is split on
        ``'and'``; in each piece the first of ``'='``, ``'>'``, ``'<'`` (checked
        in that order) is the operator and the tokens after it are the value.
        Each value is framed as ``['<BEG>'] + value + ['<END>']`` and every token
        is mapped to its index in ``['<BEG>'] + question + ['<END>']``, or 0 when
        the token is absent.

        Raises:
            RuntimeError: when a condition piece contains no operator.
        """
        ret_seq = []
        for cur_q, _, cur_query in zip(q, col, query):
            cur_values = []
            st = cur_query.index(u'where') + 1 if u'where' in cur_query \
                else len(cur_query)
            all_toks = ['<BEG>'] + cur_q + ['<END>']
            while st < len(cur_query):
                rest = cur_query[st:]
                ed = rest.index('and') + st if 'and' in rest \
                    else len(cur_query)
                segment = cur_query[st:ed]
                for op_tok in ('=', '>', '<'):
                    if op_tok in segment:
                        op = segment.index(op_tok) + st
                        break
                else:
                    raise RuntimeError("No operator in it!")
                this_str = ['<BEG>'] + cur_query[op + 1:ed] + ['<END>']
                cur_values.append([all_toks.index(s) if s in all_toks else 0
                                   for s in this_str])
                st = ed + 1
            ret_seq.append(cur_values)
        return ret_seq

    def forward(self,
                q,
                col,
                col_num,
                pred_entry,
                gt_where=None,
                gt_cond=None,
                reinforce=False,
                gt_sel=None):
        """Score the requested query parts for a batch.

        Args:
            q, col: questions and table headers, fed to the embedding layer(s).
            col_num: forwarded to every predictor.
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags selecting
                which parts are scored.
            gt_where, gt_cond, reinforce: forwarded to the condition predictor.
            gt_sel: forwarded to the aggregator predictor.

        Returns:
            ``(agg_score, sel_score, cond_score)``; unrequested parts are None.
        """
        pred_agg, pred_sel, pred_cond = pred_entry

        agg_score = None
        sel_score = None
        cond_score = None

        if self.trainable_emb:
            # Every part owns an embedding layer; embed only requested parts.
            def inputs_for(layer_name):
                return self._embed(getattr(self, layer_name), q, col)
        else:
            # One shared embedding, computed once for all parts.
            shared_inputs = self._embed(self.embed_layer, q, col)

            def inputs_for(layer_name):
                return shared_inputs

        if pred_agg:
            x_emb_var, x_len, col_inp_var, col_name_len, col_len = \
                inputs_for('agg_embed_layer')
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=gt_sel)

        if pred_sel:
            x_emb_var, x_len, col_inp_var, col_name_len, col_len = \
                inputs_for('sel_embed_layer')
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)

        if pred_cond:
            x_emb_var, x_len, col_inp_var, col_name_len, col_len = \
                inputs_for('cond_embed_layer')
            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)

        return (agg_score, sel_score, cond_score)

    def loss(self, score, truth_num, pred_entry, gt_where):
        """Return the summed training loss over the requested query parts.

        Args:
            score: ``(agg_score, sel_score, cond_score)`` as from ``forward``;
                ``cond_score`` unpacks to (num, col, op, str) scores.
            truth_num: per-example ground truth indexed as 0 = aggregator id,
                1 = selected column, 2 = number of conditions,
                3 = condition column indices, 4 = condition operator ids.
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags.
            gt_where: per-example lists of value index sequences (e.g. from
                ``generate_gt_where_seq``); the leading start index is not a
                target and length-1 sequences contribute nothing.
        """
        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score

        loss = 0
        if pred_agg:
            loss += self.CE(agg_score,
                            self._long_target([x[0] for x in truth_num]))

        if pred_sel:
            loss += self.CE(sel_score,
                            self._long_target([x[1] for x in truth_num]))

        if pred_cond:
            B = len(truth_num)
            cond_num_score, cond_col_score, \
                cond_op_score, cond_str_score = cond_score

            # Number of conditions.
            loss += self.CE(cond_num_score,
                            self._long_target([x[2] for x in truth_num]))

            # Condition columns: multi-label BCE with positives weighted 3x.
            T = len(cond_col_score[0])
            truth_prob = np.zeros((B, T), dtype=np.float32)
            for b in range(B):
                if len(truth_num[b][3]) > 0:
                    truth_prob[b][list(truth_num[b][3])] = 1
            cond_col_truth_var = self._to_variable(torch.from_numpy(truth_prob))
            cond_col_prob = torch.sigmoid(cond_col_score)
            bce_loss = -torch.mean(
                3 * (cond_col_truth_var * torch.log(cond_col_prob + 1e-10)) +
                (1 - cond_col_truth_var) * torch.log(1 - cond_col_prob + 1e-10))
            loss += bce_loss

            # Operator of each condition, averaged over the batch.
            for b in range(B):
                ops = truth_num[b][4]
                if len(ops) == 0:
                    continue
                cond_op_pred = cond_op_score[b, :len(ops)]
                loss += self.CE(cond_op_pred, self._long_target(ops)) / B

            # Value tokens of each condition (target skips the start index).
            for b, conds in enumerate(gt_where):
                for idx, cond_str_truth in enumerate(conds):
                    if len(cond_str_truth) == 1:
                        continue
                    str_end = len(cond_str_truth) - 1
                    cond_str_pred = cond_str_score[b, idx, :str_end]
                    loss += (self.CE(cond_str_pred,
                                     self._long_target(cond_str_truth[1:])) /
                             (len(gt_where) * len(conds)))

        return loss

    @staticmethod
    def _conds_match(cond_pred, cond_gt, text_type):
        """Return True when two ``[col, op, value]`` condition lists agree.

        They agree when they have the same length and the same set of columns,
        and, for every predicted condition, the first ground-truth condition on
        that column has the same operator and a case-insensitively equal value
        (compared via ``text_type(value).lower()``).
        """
        if len(cond_pred) != len(cond_gt):
            return False
        if set(x[0] for x in cond_pred) != set(x[0] for x in cond_gt):
            return False
        # First ground-truth condition per column (what tuple.index finds).
        gt_by_col = {}
        for cond in cond_gt:
            gt_by_col.setdefault(cond[0], cond)
        if any(gt_by_col[c[0]][1] != c[1] for c in cond_pred):
            return False
        return not any(
            text_type(gt_by_col[c[0]][2]).lower() != text_type(c[2]).lower()
            for c in cond_pred)

    def check_acc(self, vis_info, pred_queries, gt_queries, pred_entry):
        """Count prediction errors against ground truth.

        Args:
            vis_info: unused; kept for interface compatibility.
            pred_queries, gt_queries: parallel sequences of dicts with keys
                ``'agg'``, ``'sel'`` and ``'conds'`` (``[col, op, value]`` items).
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags selecting
                which parts are compared.

        Returns:
            ``(np.array((agg_err, sel_err, cond_err)), tot_err)`` where
            ``tot_err`` counts queries with at least one compared part wrong.
        """
        try:
            text_type = unicode  # Python 2
        except NameError:  # Python 3
            text_type = str

        pred_agg, pred_sel, pred_cond = pred_entry

        tot_err = agg_err = sel_err = cond_err = 0.0
        for pred_qry, gt_qry in zip(pred_queries, gt_queries):
            good = True
            if pred_agg and pred_qry['agg'] != gt_qry['agg']:
                agg_err += 1
                good = False

            if pred_sel and pred_qry['sel'] != gt_qry['sel']:
                sel_err += 1
                good = False

            if pred_cond and not self._conds_match(
                    pred_qry['conds'], gt_qry['conds'], text_type):
                cond_err += 1
                good = False

            if not good:
                tot_err += 1

        return np.array((agg_err, sel_err, cond_err)), tot_err

    @staticmethod
    def _merge_tokens(tok_list, raw_tok_str):
        """Join predicted value tokens back into a string.

        A space is inserted between tokens only when the lower-cased raw
        question contains the spaced form, or, failing that and the unspaced
        form, by punctuation/quote heuristics. PTB bracket/quote escapes such
        as ``-LRB-`` are mapped back to their characters; empty tokens are
        skipped. The result is stripped.
        """
        tok_str = raw_tok_str.lower()
        alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
        special = {
            '-LRB-': '(',
            '-RRB-': ')',
            '-LSB-': '[',
            '-RSB-': ']',
            '``': '"',
            '\'\'': '"',
            '--': u'\u2013'
        }
        ret = ''
        double_quote_appear = 0
        for raw_tok in tok_list:
            if not raw_tok:
                continue
            tok = special.get(raw_tok, raw_tok)
            if tok == '"':
                double_quote_appear = 1 - double_quote_appear

            if ret:
                if ret + ' ' + tok in tok_str:
                    ret += ' '
                elif ret + tok in tok_str:
                    pass
                elif tok == '"':
                    # Space before an opening quote only.
                    if double_quote_appear:
                        ret += ' '
                elif tok[0] not in alphabet:
                    pass
                elif (ret[-1] not in ('(', '/', u'\u2013', '#', '$', '&')
                      and (ret[-1] != '"' or not double_quote_appear)):
                    ret += ' '
            ret += tok
        return ret.strip()

    def gen_query(self,
                  score,
                  q,
                  col,
                  raw_q,
                  raw_col,
                  pred_entry,
                  reinforce=False,
                  verbose=False):
        """Decode batch scores into query dicts.

        Args:
            score: ``(agg_score, sel_score, cond_score)`` as from ``forward``.
            q: tokenized questions, used to map value-token indices to tokens.
            raw_q: raw question strings, used to re-space value tokens.
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags.
            col, raw_col, reinforce, verbose: unused.

        Returns:
            One dict per example with ``'agg'``, ``'sel'`` and/or ``'conds'``
            (``[column, operator id, value string]`` items) for the requested
            parts; an empty list when no part is requested.
        """
        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score

        if pred_agg:
            B = len(agg_score)
        elif pred_sel:
            B = len(sel_score)
        elif pred_cond:
            B = len(cond_score[0])
        else:
            return []

        if pred_cond:
            # Convert the whole batch once instead of once per example.
            cond_num_score, cond_col_score, cond_op_score, cond_str_score = \
                [x.data.cpu().numpy() for x in cond_score]

        ret_queries = []
        for b in range(B):
            cur_query = {}
            if pred_agg:
                cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
            if pred_sel:
                cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
            if pred_cond:
                cond_num = np.argmax(cond_num_score[b])
                all_toks = ['<BEG>'] + q[b] + ['<END>']
                # Highest-scoring columns, one per predicted condition.
                max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
                conds = []
                for idx in range(cond_num):
                    cur_cond_str_toks = []
                    for str_score in cond_str_score[b][idx]:
                        str_tok = np.argmax(str_score[:len(all_toks)])
                        str_val = all_toks[str_tok]
                        if str_val == '<END>':
                            break
                        cur_cond_str_toks.append(str_val)
                    conds.append([
                        max_idxes[idx],
                        np.argmax(cond_op_score[b][idx]),
                        self._merge_tokens(cur_cond_str_toks, raw_q[b]),
                    ])
                cur_query['conds'] = conds
            ret_queries.append(cur_query)

        return ret_queries
コード例 #2
0
class SQLNet(nn.Module):
    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=100,
                 N_depth=2,
                 gpu=False,
                 use_ca=True,
                 trainable_emb=False):
        """Assemble the multi-column SQLNet variant.

        Args:
            word_emb: passed through to ``WordEmbedding``.
            N_word: embedding dimension, passed to every sub-module.
            N_h: hidden size passed to every predictor.
            N_depth: depth passed to every predictor.
            gpu: when True the whole model is moved to CUDA.
            use_ca: forwarded to every predictor.
            trainable_emb: forwarded to ``WordEmbedding`` as ``trainable``.
        """
        super(SQLNet, self).__init__()

        # Plain configuration values.
        self.use_ca, self.trainable_emb = use_ca, trainable_emb
        self.gpu, self.N_h, self.N_depth = gpu, N_h, N_depth
        self.max_col_num, self.max_tok_num = 45, 200

        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR',
                        '==', '>', '<', '!=', '<BEG>']
        self.COND_OPS = ['>', '<', '==', '!=']

        # Submodules are registered in a fixed order; registration order
        # determines parameter order and state_dict layout, so keep it.
        self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK,
                                         our_model=True,
                                         trainable=trainable_emb)
        # How many columns are selected.
        self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)
        # Which columns are selected.
        self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num,
                                     use_ca=use_ca)
        # Aggregation function of each selected column.
        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)
        # Number of conditions, their columns, operators and values.
        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth,
                                             self.max_col_num,
                                             self.max_tok_num, use_ca, gpu)
        # Relationship between conditions ('and' / 'or').
        self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth,
                                                      use_ca=use_ca)

        self.CE = nn.CrossEntropyLoss()
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()

        if gpu:
            self.cuda()

    def generate_gt_where_seq_test(self, q, gt_cond_seq):
        """Map ground-truth condition values to index sequences over a question.

        For each question (a list of tokens) the tokens are concatenated into
        one string, and the question is framed as
        ``['<BEG>'] + tokens + ['<END>']`` with ``end`` the index of ``'<END>'``.
        Each condition value ``cond[2]`` found as a substring of the
        concatenated question becomes
        ``[0, start + 1, ..., start + len(value), end]``, where ``start`` is the
        character offset of its first occurrence. These are token indices only
        if every question token is a single character (TODO confirm the
        tokenisation).

        A value that is not found becomes ``[[0, end]]``, a nested list.
        NOTE(review): the found branch yields a flat list, while this branch
        nests it. The result has length 1, so ``loss`` skips it. This looks
        unintentional; confirm before changing it, since changing it alters
        training.
        """
        result = []
        for tokens, conds in zip(q, gt_cond_seq):
            joined = u"".join(tokens)
            framed = [u'<BEG>'] + tokens + [u'<END>']
            end_idx = len(framed) - 1

            # First pass: record whether each value occurs in the question.
            presence = [(cond[2] in joined, cond[2]) for cond in conds]

            per_question = []
            for is_present, value in presence:
                if not is_present:
                    per_question.append([[0, end_idx]])
                    continue
                first = joined.index(value) + 1
                span = list(range(first, first + len(value)))
                per_question.append([0] + span + [end_idx])
            result.append(per_question)
        return result

    def forward(self,
                q,
                col,
                col_num,
                gt_where=None,
                gt_cond=None,
                reinforce=False,
                gt_sel=None,
                gt_sel_num=None):
        B = len(q)

        sel_num_score = None
        agg_score = None
        sel_score = None
        cond_score = None
        #Predict aggregator
        if self.trainable_emb:
            x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(
                col)
            max_x_len = max(x_len)
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=gt_sel)

            x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(
                col)
            max_x_len = max(x_len)
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)

            x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(
                col)
            max_x_len = max(x_len)
            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)
            where_rela_score = None
        else:
            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
                col)
            sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var,
                                         col_name_len, col_len, col_num)
            # x_emb_var: embedding of each question
            # x_len: length of each question
            # col_inp_var: embedding of each header
            # col_name_len: length of each header
            # col_len: number of headers in each table, array type
            # col_num: number of headers in each table, list type
            if gt_sel_num:
                pr_sel_num = gt_sel_num
            else:
                pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(),
                                       axis=1)
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)

            if gt_sel:
                pr_sel = gt_sel
            else:
                num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
                sel = sel_score.data.cpu().numpy()
                pr_sel = [
                    list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))
                ]
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=pr_sel,
                                      gt_sel_num=pr_sel_num)

            where_rela_score = self.where_rela_pred(x_emb_var, x_len,
                                                    col_inp_var, col_name_len,
                                                    col_len, col_num)

            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)

        return (sel_num_score, sel_score, agg_score, cond_score,
                where_rela_score)

    def loss(self, score, truth_num, gt_where):
        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score

        B = len(truth_num)
        loss = 0

        # Evaluate select number
        sel_num_truth = list(map(lambda x: x[0], truth_num))
        #sel_num_truth=[torch.LongTensor(x[0]) for x in truth_num]
        sel_num_truth = torch.from_numpy(np.array(sel_num_truth))
        sel_num_truth = sel_num_truth.long()
        if self.gpu:
            sel_num_truth = Variable(sel_num_truth.cuda())
        else:
            sel_num_truth = Variable(sel_num_truth)
        loss += self.CE(sel_num_score, sel_num_truth)  #self.CE(input,target)

        # Evaluate select column
        T = len(sel_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            truth_prob[b][list(truth_num[b][1])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            sel_col_truth_var = Variable(data.cuda())
        else:
            sel_col_truth_var = Variable(data)
        sigm = nn.Sigmoid()
        sel_col_prob = sigm(sel_score)
        bce_loss = -torch.mean(
            3 * (sel_col_truth_var * torch.log(sel_col_prob + 1e-10)) +
            (1 - sel_col_truth_var) * torch.log(1 - sel_col_prob + 1e-10))
        loss += bce_loss

        # Evaluate select aggregation
        for b in range(len(truth_num)):
            data = torch.from_numpy(np.array(truth_num[b][2]))
            if self.gpu:
                sel_agg_truth_var = Variable(data.cuda())
                sel_agg_truth_var = sel_agg_truth_var.long()
            else:
                sel_agg_truth_var = Variable(data)
                sel_agg_truth_var = sel_agg_truth_var.long()
            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
            loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)

        cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score

        # Evaluate the number of conditions
        cond_num_truth = list(map(lambda x: x[3], truth_num))
        data = torch.from_numpy(np.array(cond_num_truth, dtype=np.float))
        if self.gpu:
            try:
                cond_num_truth_var = Variable(data.cuda())
                cond_num_truth_var = cond_num_truth_var.long()
            except:
                print("cond_num_truth_var error")
                print(data)
                exit(0)
        else:
            cond_num_truth_var = Variable(data)
            cond_num_truth_var = cond_num_truth_var.long()
        loss += self.CE(cond_num_score, cond_num_truth_var)

        # Evaluate the columns of conditions
        T = len(cond_col_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][4]) > 0:
                truth_prob[b][list(truth_num[b][4])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            cond_col_truth_var = Variable(data.cuda())
        else:
            cond_col_truth_var = Variable(data)

        sigm = nn.Sigmoid()
        cond_col_prob = sigm(cond_col_score)
        bce_loss = -torch.mean(
            3 * (cond_col_truth_var * torch.log(cond_col_prob + 1e-10)) +
            (1 - cond_col_truth_var) * torch.log(1 - cond_col_prob + 1e-10))
        loss += bce_loss

        # Evaluate the operator of conditions
        for b in range(len(truth_num)):
            if len(truth_num[b][5]) == 0:
                continue
            data = torch.from_numpy(np.array(truth_num[b][5]))
            if self.gpu:
                cond_op_truth_var = Variable(data.cuda())
            else:
                cond_op_truth_var = Variable(data)
            cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
            try:
                loss += (self.CE(cond_op_pred, cond_op_truth_var) /
                         len(truth_num))
            except:
                print(cond_op_pred)
                print(cond_op_truth_var)
                exit(0)

        #Evaluate the strings of conditions
        for b in range(len(gt_where)):
            for idx in range(len(gt_where[b])):
                cond_str_truth = gt_where[b][idx]
                if len(cond_str_truth) == 1:
                    continue
                data = torch.from_numpy(np.array(cond_str_truth[1:]))
                if self.gpu:
                    cond_str_truth_var = Variable(data.cuda())
                else:
                    cond_str_truth_var = Variable(data)
                str_end = len(cond_str_truth) - 1
                cond_str_pred = cond_str_score[b, idx, :str_end]
                loss += (self.CE(cond_str_pred, cond_str_truth_var) \
                        / (len(gt_where) * len(gt_where[b])))

        # Evaluate condition relationship, and / or
        where_rela_truth = map(lambda x: x[6], truth_num)
        data = torch.from_numpy(np.array(where_rela_truth))
        if self.gpu:
            try:
                where_rela_truth = Variable(data.cuda())
            except:
                print("where_rela_truth error")
                print(data)
                exit(0)
        else:
            where_rela_truth = Variable(data)
        loss += self.CE(where_rela_score, where_rela_truth)
        return loss

    def check_acc(self, vis_info, pred_queries, gt_queries):
        """Count per-component mismatches between predicted and gold queries.

        Each query is a dict with keys ``'sel'``, ``'agg'``, ``'cond_conn_op'``
        and ``'conds'``; every condition is indexed as ``[column, op, value]``.

        :param vis_info: unused; kept for interface compatibility.
        :param pred_queries: predicted query dicts.
        :param gt_queries: gold query dicts, paired positionally with
            ``pred_queries`` (``zip`` stops at the shorter list).
        :return: ``(errors, tot_err)``. ``errors`` is a numpy array of
            ``(sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err,
            cond_op_err, cond_val_err, cond_rela_err)`` counts. ``tot_err``
            is the number of queries with at least one mismatch.
        """
        tot_err = sel_num_err = agg_err = sel_err = 0.0
        cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0
        for pred_qry, gt_qry in zip(pred_queries, gt_queries):
            good = True
            sel_pred, agg_pred, where_rela_pred = (
                pred_qry['sel'], pred_qry['agg'], pred_qry['cond_conn_op'])
            sel_gt, agg_gt, where_rela_gt = (
                gt_qry['sel'], gt_qry['agg'], gt_qry['cond_conn_op'])

            # Connector between the WHERE conditions.
            if where_rela_gt != where_rela_pred:
                good = False
                cond_rela_err += 1

            if len(sel_pred) != len(sel_gt):
                good = False
                sel_num_err += 1

            # Pair each selected column with its aggregator. Duplicate
            # columns keep the last aggregator seen.
            pred_sel_dict = dict(zip(list(sel_pred), list(agg_pred)))
            gt_sel_dict = dict(zip(sel_gt, agg_gt))
            if set(sel_pred) != set(sel_gt):
                good = False
                sel_err += 1
            # Compare aggregators in column order, so the order in which
            # columns were emitted does not matter.
            agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict)]
            agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict)]
            if agg_pred != agg_gt:
                good = False
                agg_err += 1

            cond_pred = pred_qry['conds']
            cond_gt = gt_qry['conds']
            if len(cond_pred) != len(cond_gt):
                good = False
                cond_num_err += 1
            else:
                # Key operators and values by condition column.
                cond_op_pred, cond_op_gt = {}, {}
                cond_val_pred, cond_val_gt = {}, {}
                for p, g in zip(cond_pred, cond_gt):
                    cond_op_pred[p[0]] = p[1]
                    cond_val_pred[p[0]] = p[2]
                    cond_op_gt[g[0]] = g[1]
                    cond_val_gt[g[0]] = g[2]

                if set(cond_op_pred) != set(cond_op_gt):
                    cond_col_err += 1
                    good = False

                where_op_pred = [cond_op_pred[x] for x in sorted(cond_op_pred)]
                where_op_gt = [cond_op_gt[x] for x in sorted(cond_op_gt)]
                if where_op_pred != where_op_gt:
                    cond_op_err += 1
                    good = False

                where_val_pred = [
                    cond_val_pred[x] for x in sorted(cond_val_pred)
                ]
                where_val_gt = [cond_val_gt[x] for x in sorted(cond_val_gt)]
                if where_val_pred != where_val_gt:
                    cond_val_err += 1
                    good = False

            if not good:
                tot_err += 1

        return np.array(
            (sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err,
             cond_op_err, cond_val_err, cond_rela_err)), tot_err

    def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
        """Decode batched model scores into one structured query per example.

        :param score: ``(sel_num_score, sel_score, agg_score, cond_score,
            where_rela_score)`` tensors. ``cond_score`` is itself
            ``(cond_num_score, cond_col_score, cond_op_score, cond_str_score)``.
        :param q: token-questions, one token list per example
        :param col: token-headers (not read by this method)
        :param raw_q: original question strings, used to re-join value tokens
        :param reinforce: not read by this method
        :param verbose: not read by this method
        :return: list of dicts with keys ``'sel'``, ``'agg'``,
            ``'cond_conn_op'`` and ``'conds'``. Each cond is
            ``[column, op, value_string]``.
        """
        def merge_tokens(tok_list, raw_tok_str):
            """Re-join value tokens, using ``raw_tok_str`` to decide spacing."""
            tok_str = raw_tok_str  # .lower()
            # Map tokenizer escapes back to the characters they stand for.
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_appear = 0  # 1 while inside an open double quote
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_appear = 1 - double_quote_appear
                if not ret:
                    pass
                elif ret + ' ' + tok in tok_str:
                    # The spaced form appears verbatim in the question.
                    ret = ret + ' '
                elif ret + tok in tok_str:
                    # The unspaced form appears verbatim; glue directly.
                    pass
                elif tok == '"':
                    # Space before an opening quote, none before a closing one.
                    if double_quote_appear:
                        ret = ret + ' '
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
                        and (ret[-1] != '"' or not double_quote_appear):
                    # Default: add a space unless the previous character is an
                    # opener/prefix symbol or an opening quote.
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()

        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
        sel_num_score = sel_num_score.data.cpu().numpy()
        sel_score = sel_score.data.cpu().numpy()
        # agg_score[b] is sorted row-wise and each row's first index is taken,
        # so it must be 2-D per example (presumably slots x agg ops; confirm).
        agg_score = agg_score.data.cpu().numpy()
        where_rela_score = where_rela_score.data.cpu().numpy()
        cond_num_score, cond_col_score, cond_op_score, cond_str_score = \
            [x.data.cpu().numpy() for x in cond_score]

        ret_queries = []
        for b in range(len(agg_score)):
            sel_num = np.argmax(sel_num_score[b])
            # Highest-scoring columns / aggregators first.
            max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
            max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
            cur_query = {
                'sel': [int(i) for i in max_col_idxes],
                'agg': [i[0] for i in max_agg_idxes],
                'cond_conn_op': np.argmax(where_rela_score[b]),
                'conds': [],
            }

            cond_num = np.argmax(cond_num_score[b])
            all_toks = ['<BEG>'] + q[b] + ['<END>']
            max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
            for idx in range(cond_num):
                where_col = max_idxes[idx]
                where_op = np.argmax(cond_op_score[b][idx])
                cur_cond_str_toks = []
                for str_score in cond_str_score[b][idx]:
                    # Restrict the pointer to positions of this question.
                    str_val = all_toks[np.argmax(str_score[:len(all_toks)])]
                    if str_val == '<END>':
                        break
                    cur_cond_str_toks.append(str_val)
                cur_query['conds'].append(
                    [where_col, where_op,
                     merge_tokens(cur_cond_str_toks, raw_q[b])])
            ret_queries.append(cur_query)
        return ret_queries
コード例 #3
0
class SQLNet(nn.Module):
    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=120,
                 N_depth=2,
                 gpu=False,
                 trainable_emb=False):
        """Build the shared embedding, the four clause predictors and loss helpers.

        :param word_emb: embedding table passed to ``WordEmbedding``.
        :param N_word: word-vector size, passed to every sub-module.
        :param N_h: hidden size passed to the predictors.
        :param N_depth: depth passed to the predictors.
        :param gpu: when true, the finished model is moved with ``self.cuda()``.
        :param trainable_emb: forwarded as ``trainable`` to ``WordEmbedding``.
        """
        super(SQLNet, self).__init__()

        # Plain (non-module) configuration.
        self.trainable_emb, self.gpu = trainable_emb, gpu
        self.N_h, self.N_depth = N_h, N_depth
        self.max_col_num, self.max_tok_num = 45, 200
        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>']
        self.COND_OPS = ['EQL', 'GT', 'LT']

        # Child modules keep their original assignment order: nn.Module
        # registers children in that order, which fixes parameters() order.
        self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK,
                                         trainable=trainable_emb)
        predictor_args = (N_word, N_h, N_depth, gpu)
        self.sel_pred = SelPredictor(*predictor_args)      # select clause
        self.cond_pred = CondPredictor(*predictor_args)    # where conditions
        self.group_pred = GroupPredictor(*predictor_args)  # group by
        self.order_pred = OrderPredictor(*predictor_args)  # order by

        self.CE = nn.CrossEntropyLoss()
        # NOTE(review): Softmax/LogSoftmax use the deprecated implicit `dim`.
        self.softmax = nn.Softmax()
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()
        self.sigm = nn.Sigmoid()
        if gpu:
            self.cuda()

    def forward(self,
                q,
                col,
                col_num,
                pred_entry,
                gt_where=None,
                gt_cond=None,
                gt_sel=None):
        """Run every clause predictor over one batch.

        :param q: tokenized questions, embedded via
            ``self.embed_layer.gen_x_batch``.
        :param col: tokenized column headers, embedded via
            ``self.embed_layer.gen_col_batch``.
        :param col_num: forwarded unchanged to every predictor (presumably
            the number of columns per example; confirm).
        :param pred_entry: not read here; kept for interface compatibility.
        :param gt_where: not read here.
        :param gt_cond: forwarded to the condition predictor.
        :param gt_sel: forwarded to the select predictor.
        :return: ``(sel_score, cond_score, group_score, order_score)``.
        """
        x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col, is_q=True)
        col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
            col)

        # All predictors share the same question/column encodings; only the
        # select and condition predictors take the optional gt_* arguments.
        sel_score = self.sel_pred(x_emb_var,
                                  x_len,
                                  col_inp_var,
                                  col_len,
                                  col_num,
                                  col_name_len,
                                  gt_sel=gt_sel)
        cond_score = self.cond_pred(x_emb_var,
                                    x_len,
                                    col_inp_var,
                                    col_len,
                                    col_num,
                                    col_name_len,
                                    gt_cond=gt_cond)
        group_score = self.group_pred(x_emb_var, x_len, col_inp_var, col_len,
                                      col_num, col_name_len)
        order_score = self.order_pred(x_emb_var, x_len, col_inp_var, col_len,
                                      col_num, col_name_len)

        return (sel_score, cond_score, group_score, order_score)

    def loss(self, score, truth_num, pred_entry):
        pred_agg, pred_sel, pred_cond = pred_entry

        sel_score, cond_score, group_score, order_score = score

        sel_num_score, sel_col_score, agg_num_score, agg_op_score = sel_score
        cond_num_score, cond_col_score, cond_op_score = cond_score
        gby_num_score, gby_score, hv_score, hv_col_score, hv_agg_score, hv_op_score = group_score
        ody_num_score, ody_col_score, ody_agg_score, ody_par_score = order_score

        B = len(truth_num)
        loss = 0

        #----------loss for sel_pred -------------#

        # loss for sel agg # and sel agg
        for b in range(len(truth_num)):
            curr_col = truth_num[b][1][0]
            curr_col_num_aggs = 0
            gt_aggs_num = []
            for i, col in enumerate(truth_num[b][1]):
                if col != curr_col:
                    gt_aggs_num.append(curr_col_num_aggs)
                    curr_col = col
                    curr_col_num_aggs = 0
                if truth_num[b][0][i] != 0:
                    curr_col_num_aggs += 1
            gt_aggs_num.append(curr_col_num_aggs)
            # print gt_aggs_num
            data = torch.from_numpy(
                np.array(gt_aggs_num))  #supposed to be gt # of aggs
            if self.gpu:
                agg_num_truth_var = Variable(data.cuda())
            else:
                agg_num_truth_var = Variable(data)
            agg_num_pred = agg_num_score[
                b, :truth_num[b][5]]  # supposed to be gt # of select columns
            loss += (self.CE(agg_num_pred, agg_num_truth_var) \
                    / len(truth_num))
            # loss for sel agg prediction
            T = 6  #num agg ops
            truth_prob = np.zeros((truth_num[b][5], T), dtype=np.float32)
            gt_agg_by_sel = []
            curr_sel_aggs = []
            curr_col = truth_num[b][1][0]
            col_counter = 0
            for i, col in enumerate(truth_num[b][1]):
                if col != curr_col:
                    gt_agg_by_sel.append(curr_sel_aggs)
                    curr_col = col
                    col_counter += 1
                    curr_sel_aggs = [truth_num[b][0][i]]
                    truth_prob[col_counter][curr_sel_aggs] = 1
                else:
                    curr_sel_aggs.append(truth_num[b][0][i])
                    truth_prob[col_counter][curr_sel_aggs] = 1
            data = torch.from_numpy(truth_prob)
            if self.gpu:
                agg_op_truth_var = Variable(data.cuda())
            else:
                agg_op_truth_var = Variable(data)
            agg_op_prob = self.sigm(agg_op_score[b, :truth_num[b][5]])
            agg_bce_loss = -torch.mean( 3*(agg_op_truth_var * \
                    torch.log(agg_op_prob+1e-10)) + \
                    (1-agg_op_truth_var) * torch.log(1-agg_op_prob+1e-10) )
            loss += agg_bce_loss / len(truth_num)

        #Evaluate the number of select columns
        sel_num_truth = map(
            lambda x: x[5] - 1,
            truth_num)  #might need to be the length of the set of columms
        data = torch.from_numpy(np.array(sel_num_truth))
        if self.gpu:
            sel_num_truth_var = Variable(data.cuda())
        else:
            sel_num_truth_var = Variable(data)
        loss += self.CE(sel_num_score, sel_num_truth_var)
        # Evaluate the select columns
        T = len(sel_col_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            truth_prob[b][truth_num[b][1]] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            sel_col_truth_var = Variable(data.cuda())
        else:
            sel_col_truth_var = Variable(data)
        sel_col_prob = self.sigm(sel_col_score)
        sel_bce_loss = -torch.mean( 3*(sel_col_truth_var * \
                torch.log(sel_col_prob+1e-10)) + \
                (1-sel_col_truth_var) * torch.log(1-sel_col_prob+1e-10) )
        loss += sel_bce_loss
        #----------------loss for cond_pred--------------------#
        #cond_num_score, cond_col_score, cond_op_score = cond_score

        #Evaluate the number of conditions
        cond_num_truth = map(lambda x: x[2], truth_num)
        data = torch.from_numpy(np.array(cond_num_truth))
        if self.gpu:
            cond_num_truth_var = Variable(data.cuda())
        else:
            cond_num_truth_var = Variable(data)
        loss += self.CE(cond_num_score, cond_num_truth_var)
        #Evaluate the columns of conditions
        T = len(cond_col_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][3]) > 0:
                truth_prob[b][list(truth_num[b][3])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            cond_col_truth_var = Variable(data.cuda())
        else:
            cond_col_truth_var = Variable(data)

        cond_col_prob = self.sigm(cond_col_score)
        bce_loss = -torch.mean( 3*(cond_col_truth_var * \
                torch.log(cond_col_prob+1e-10)) + \
                (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) )
        loss += bce_loss
        #Evaluate the operator of conditions
        for b in range(len(truth_num)):
            if len(truth_num[b][4]) == 0:
                continue
            data = torch.from_numpy(np.array(truth_num[b][4]))
            if self.gpu:
                cond_op_truth_var = Variable(data.cuda())
            else:
                cond_op_truth_var = Variable(data)
            cond_op_pred = cond_op_score[b, :len(truth_num[b][4])]
            # print 'cond_op_truth_var', list(cond_op_truth_var.size())
            # print 'cond_op_pred', list(cond_op_pred.size())
            loss += (self.CE(cond_op_pred, cond_op_truth_var) \
                    / len(truth_num))
        # -----------loss for group_pred -------------- #
        #gby_num_score, gby_score, hv_score, hv_col_score, hv_agg_score, hv_op_score = group_score
        # Evaluate the number of group by columns
        gby_num_truth = map(lambda x: x[7], truth_num)
        data = torch.from_numpy(np.array(gby_num_truth))
        if self.gpu:
            gby_num_truth_var = Variable(data.cuda())
        else:
            gby_num_truth_var = Variable(data)
        loss += self.CE(gby_num_score, gby_num_truth_var)
        # Evaluate the group by columns
        T = len(gby_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][6]) > 0:
                truth_prob[b][list(truth_num[b][6])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            gby_col_truth_var = Variable(data.cuda())
        else:
            gby_col_truth_var = Variable(data)
        gby_col_prob = self.sigm(gby_score)
        gby_bce_loss = -torch.mean( 3*(gby_col_truth_var * \
                torch.log(gby_col_prob+1e-10)) + \
                (1-gby_col_truth_var) * torch.log(1-gby_col_prob+1e-10) )
        loss += gby_bce_loss
        # Evaluate having
        having_truth = [1 if len(x[13]) == 1 else 0 for x in truth_num]
        data = torch.from_numpy(np.array(having_truth))
        if self.gpu:
            having_truth_var = Variable(data.cuda())
        else:
            having_truth_var = Variable(data)
        loss += self.CE(hv_score, having_truth_var)
        # Evaluate having col
        T = len(hv_col_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][13]) > 0:
                truth_prob[b][truth_num[b][13]] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            hv_col_truth_var = Variable(data.cuda())
        else:
            hv_col_truth_var = Variable(data)
        hv_col_prob = self.sigm(hv_col_score)
        hv_col_bce_loss = -torch.mean( 3*(hv_col_truth_var * \
                torch.log(hv_col_prob+1e-10)) + \
                (1-hv_col_truth_var) * torch.log(1-hv_col_prob+1e-10) )
        loss += hv_col_bce_loss
        # Evaluate having agg
        T = len(hv_agg_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][12]) > 0:
                truth_prob[b][truth_num[b][12]] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            hv_agg_truth_var = Variable(data.cuda())
        else:
            hv_agg_truth_var = Variable(data)
        hv_agg_prob = self.sigm(hv_agg_truth_var)
        hv_agg_bce_loss = -torch.mean( 3*(hv_agg_truth_var * \
                torch.log(hv_agg_prob+1e-10)) + \
                (1-hv_agg_truth_var) * torch.log(1-hv_agg_prob+1e-10) )
        loss += hv_agg_bce_loss
        # Evaluate having op
        T = len(hv_op_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][14]) > 0:
                truth_prob[b][truth_num[b][14]] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            hv_op_truth_var = Variable(data.cuda())
        else:
            hv_op_truth_var = Variable(data)
        hv_op_prob = self.sigm(hv_op_truth_var)
        hv_op_bce_loss = -torch.mean( 3*(hv_op_truth_var * \
                torch.log(hv_op_prob+1e-10)) + \
                (1-hv_op_truth_var) * torch.log(1-hv_op_prob+1e-10) )
        loss += hv_op_bce_loss

        # -----------loss for order_pred -------------- #
        #ody_col_score, ody_agg_score, ody_par_score = order_score

        # Evaluate the number of order by columns
        ody_num_truth = map(lambda x: x[10], truth_num)
        data = torch.from_numpy(np.array(ody_num_truth))
        if self.gpu:
            ody_num_truth_var = Variable(data.cuda())
        else:
            ody_num_truth_var = Variable(data)
        loss += self.CE(ody_num_score, ody_num_truth_var)
        # Evaluate the order by columns
        T = len(ody_col_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][9]) > 0:
                truth_prob[b][list(truth_num[b][9])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            ody_col_truth_var = Variable(data.cuda())
        else:
            ody_col_truth_var = Variable(data)
        ody_col_prob = self.sigm(ody_col_score)
        ody_bce_loss = -torch.mean( 3*(ody_col_truth_var * \
                torch.log(ody_col_prob+1e-10)) + \
                (1-ody_col_truth_var) * torch.log(1-ody_col_prob+1e-10) )
        loss += ody_bce_loss
        # Evaluate order agg assume only one
        T = 6
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][9]) > 0:
                truth_prob[b][list(truth_num[b][8])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            ody_agg_truth_var = Variable(data.cuda())
        else:
            ody_agg_truth_var = Variable(data)
        ody_agg_prob = self.sigm(ody_agg_score)
        ody_agg_bce_loss = -torch.mean( 3*(ody_agg_truth_var * \
                torch.log(ody_agg_prob+1e-10)) + \
                (1-ody_agg_truth_var) * torch.log(1-ody_agg_prob+1e-10) )
        loss += ody_agg_bce_loss
        # Evaluate parity
        ody_par_truth = map(lambda x: x[11], truth_num)
        data = torch.from_numpy(np.array(ody_par_truth))
        if self.gpu:
            ody_par_truth_var = Variable(data.cuda())
        else:
            ody_par_truth_var = Variable(data)
        loss += self.CE(ody_par_score, ody_par_truth_var)
        return loss

    def check_acc(self,
                  vis_info,
                  pred_queries,
                  gt_queries,
                  pred_entry,
                  error_print=False):
        """Count per-clause errors between predicted and gold queries.

        :param vis_info: per-example display data, zipped with the queries.
            It is only read when printing errors: ``[0]`` question, ``[1]``
            headers, ``[2]`` query, ``[3]`` question tokens.
        :param pred_queries: query dicts as produced by ``gen_query``.
        :param gt_queries: gold query dicts with keys ``'sel'``, ``'agg'``,
            ``'group'``, ``'order'`` and ``'cond'``.
        :param pred_entry: not read here; kept for interface compatibility.
        :param error_print: when true, print details of every wrong query.
        :return: ``(np.array((sel_err, cond_err, gby_err, ody_err)), tot_err)``.
            Each entry counts queries with at least one error in that clause.
            ``tot_err`` counts queries with any error.
        """
        def pretty_print(vis_data, pred_query, gt_query):
            # Single-argument print calls behave the same in Python 2 and 3.
            print("\n----------detailed error prints-----------")
            try:
                print('question:  %s' % (vis_data[0],))
                print('question_tok:  %s' % (vis_data[3],))
                print('headers: (%s)' % (' || '.join(vis_data[1])))
                print('query: %s' % (vis_data[2],))
                print("target query:  %s" % (gt_query,))
                print("pred query:  %s" % (pred_query,))
            except Exception:
                # Best effort: display data may fail to encode/format.
                print("\n------skipping print: decoding problem ----------------------")

        tot_err = 0.0
        # Only the four clause totals are returned; the finer counters are
        # kept for inspection while debugging.
        sel_err = agg_num_err = agg_op_err = sel_num_err = sel_col_err = 0.0
        cond_err = cond_num_err = cond_col_err = cond_op_err = 0.0
        gby_err = gby_num_err = gby_col_err = hv_err = hv_col_err = hv_agg_err = hv_op_err = 0.0
        ody_err = ody_num_err = ody_col_err = ody_agg_err = ody_par_err = 0.0
        for pred_qry, gt_qry, vis_data in zip(pred_queries, gt_queries,
                                              vis_info):
            sel_flag = cond_flag = gby_flag = ody_flag = True

            # ---- select ----
            sel_gt = gt_qry['sel']
            sel_pred = pred_qry['sel']
            if pred_qry['sel_num'] != len(set(sel_gt)):
                sel_num_err += 1
                sel_flag = False
            if sorted(set(sel_pred)) != sorted(set(sel_gt)):
                sel_col_err += 1
                sel_flag = False

            # Gold aggregator count per run of identical consecutive columns.
            agg_gt = gt_qry['agg']
            curr_col = sel_gt[0]
            curr_col_num_aggs = 0
            gt_aggs_num = []
            for i, col in enumerate(sel_gt):
                if col != curr_col:
                    gt_aggs_num.append(curr_col_num_aggs)
                    curr_col = col
                    curr_col_num_aggs = 0
                if agg_gt[i] != 0:
                    curr_col_num_aggs += 1
            gt_aggs_num.append(curr_col_num_aggs)

            if pred_qry['agg_num'] != gt_aggs_num:
                agg_num_err += 1
                sel_flag = False
            if sorted(pred_qry['agg']) != sorted(agg_gt):  # order-insensitive
                agg_op_err += 1
                sel_flag = False
            if not sel_flag:
                sel_err += 1

            # ---- group by / having ----
            # 'group' lists the group-by columns followed by one trailing
            # [having_aggs, having_cols, having_ops] entry.
            gby_gt = gt_qry['group'][:-1]
            gby_num_pred = pred_qry['gby_num']
            if gby_num_pred != len(gby_gt):
                gby_num_err += 1
                gby_flag = False
            if sorted(pred_qry['group']) != sorted(gby_gt):
                gby_col_err += 1
                gby_flag = False
            gt_gby_agg = gt_qry['group'][-1][0]
            gt_gby_col = gt_qry['group'][-1][1]
            gt_gby_op = gt_qry['group'][-1][2]
            if gby_num_pred != 0 and len(gt_gby_col) != 0:
                if pred_qry['hv'] != 1:
                    hv_err += 1
                    gby_flag = False
                if pred_qry['hv_agg'] != gt_gby_agg[0]:
                    hv_agg_err += 1
                    gby_flag = False
                if pred_qry['hv_col'] != gt_gby_col[0]:
                    hv_col_err += 1
                    gby_flag = False
                if pred_qry['hv_op'] != gt_gby_op[0]:
                    hv_op_err += 1
                    gby_flag = False
            if not gby_flag:
                gby_err += 1

            # ---- order by ----
            ody_gt_aggs = gt_qry['order'][0]
            ody_gt_cols = gt_qry['order'][1]
            ody_gt_par = gt_qry['order'][2]
            if pred_qry['ody_num'] != len(ody_gt_cols):
                ody_num_err += 1
                ody_flag = False
            if len(ody_gt_cols) > 0:
                # gen_query returns the predicted columns as a numpy array.
                # `array != list` is element-wise and its truth value is
                # ambiguous for more than one column, so compare as lists.
                if list(pred_qry['order']) != list(ody_gt_cols):
                    ody_col_err += 1
                    ody_flag = False
                if pred_qry['ody_agg'] != ody_gt_aggs:
                    ody_agg_err += 1
                    ody_flag = False
                if pred_qry['parity'] != ody_gt_par:
                    ody_par_err += 1
                    ody_flag = False
            if not ody_flag:
                ody_err += 1

            # ---- where conditions (checked in order; first failure stops) ----
            cond_pred = pred_qry['conds']
            cond_gt = gt_qry['cond']
            if len(cond_pred) != len(cond_gt):
                cond_num_err += 1
                cond_flag = False
            elif set(x[0] for x in cond_pred) != set(x[0] for x in cond_gt):
                cond_col_err += 1
                cond_flag = False
            else:
                gt_cols = [x[0] for x in cond_gt]
                for p_cond in cond_pred:
                    gt_idx = gt_cols.index(p_cond[0])
                    if cond_gt[gt_idx][1] != p_cond[1]:
                        cond_op_err += 1
                        cond_flag = False
                        break
            if not cond_flag:
                cond_err += 1

            if not (sel_flag and gby_flag and ody_flag and cond_flag):
                if error_print:
                    pretty_print(vis_data, pred_qry, gt_qry)
                tot_err += 1

        return np.array((sel_err, cond_err, gby_err, ody_err)), tot_err

    def gen_query(self,
                  score,
                  q,
                  col,
                  raw_q,
                  raw_col,
                  pred_entry,
                  verbose=False):
        """Turn batched clause scores into one query dict per example.

        Every count is the argmax of its score row. The chosen items are the
        highest-scoring entries of the matching column/operator scores.

        :param score: ``(sel_score, cond_score, group_score, order_score)``,
            each a tuple of tensors (``None`` entries are passed through).
        :param q: not read by this decoder.
        :param col: not read by this decoder.
        :param raw_q: not read by this decoder.
        :param raw_col: not read by this decoder.
        :param pred_entry: must unpack into three items; otherwise unused.
        :param verbose: not read by this decoder.
        :return: list of dicts with keys ``'sel_num'``, ``'sel'``,
            ``'agg_num'``, ``'agg'``, ``'gby_num'``, ``'group'``, ``'hv'``,
            ``'hv_agg'``, ``'hv_col'``, ``'hv_op'``, ``'ody_num'``,
            ``'order'``, ``'ody_agg'``, ``'parity'`` and ``'conds'``.
        """
        _, _, _ = pred_entry

        def as_numpy(part):
            # Detach each tensor of a score tuple to numpy; keep None slots.
            return [t.data.cpu().numpy() if t is not None else None
                    for t in part]

        def top_k(scores, k):
            # Indices of the k largest scores, best first.
            return np.argsort(-scores)[:k]

        sel_score, cond_score, group_score, order_score = score
        sel_num_score, sel_col_score, agg_num_score, agg_op_score = \
            as_numpy(sel_score)
        cond_num_score, cond_col_score, cond_op_score = as_numpy(cond_score)
        (gby_num_score, gby_score, hv_score, hv_col_score, hv_agg_score,
         hv_op_score) = as_numpy(group_score)
        ody_num_score, ody_col_score, ody_agg_score, ody_par_score = \
            as_numpy(order_score)

        queries = []
        for b in range(len(sel_num_score)):
            # SELECT: the class index k means k + 1 columns.
            n_sel = np.argmax(sel_num_score[b]) + 1
            sel_cols = top_k(sel_col_score[b], n_sel)
            agg_counts, agg_ops = [], []
            for slot in range(n_sel):
                n_aggs = np.argmax(agg_num_score[b][slot])
                agg_counts.append(n_aggs)
                if n_aggs == 0:
                    agg_ops.append(0)
                else:
                    # Best non-zero aggregator ids for this slot.
                    ranked = [op for op in np.argsort(-agg_op_score[b][slot])
                              if op != 0]
                    agg_ops.extend(ranked[:n_aggs])

            # GROUP BY / HAVING
            n_gby = np.argmax(gby_num_score[b])
            gby_cols = top_k(gby_score[b], n_gby)
            hv_flag = np.argmax(hv_score[b])
            if n_gby != 0 and hv_flag != 0:
                hv_agg = np.argmax(hv_agg_score[b])
                hv_col = np.argmax(hv_col_score[b])
                hv_op = np.argmax(hv_op_score[b])
            else:
                hv_flag, hv_agg, hv_col, hv_op = 0, 0, -1, -1

            # ORDER BY
            n_ody = np.argmax(ody_num_score[b])
            ody_cols = top_k(ody_col_score[b], n_ody)
            if n_ody != 0:
                ody_agg = np.argmax(ody_agg_score[b])
                parity = np.argmax(ody_par_score[b])
            else:
                ody_agg, parity = 0, -1

            # WHERE: condition i uses the i-th best column and the argmax
            # operator of score slot i.
            n_cond = np.argmax(cond_num_score[b])
            cond_cols = top_k(cond_col_score[b], n_cond)
            conds = [[cond_cols[i], np.argmax(cond_op_score[b][i])]
                     for i in range(n_cond)]

            queries.append({
                'sel_num': n_sel,
                'sel': sel_cols,
                'agg_num': agg_counts,
                'agg': agg_ops,
                'gby_num': n_gby,
                'group': gby_cols,
                'hv': hv_flag,
                'hv_agg': hv_agg,
                'hv_col': hv_col,
                'hv_op': hv_op,
                'ody_num': n_ody,
                'order': ody_cols,
                'ody_agg': ody_agg,
                'parity': parity,
                'conds': conds,
            })

        return queries

    def find_shortest_path(self, start, end, graph):
        """Return a fewest-hop path from ``start`` to ``end`` in ``graph``.

        ``graph[node]`` is a sequence of ``(neighbour, edge)`` pairs. In
        ``gen_from`` the edge is a foreign-key column pair.

        :return: list of ``(node, edge)`` hops taken after leaving ``start``.
            This is ``[]`` when ``start == end``, and ``None`` when ``end`` is
            unreachable.
        """
        from collections import deque

        # Breadth-first search: the first time `end` is dequeued it has been
        # reached with the fewest hops. A LIFO stack would give a DFS path,
        # which need not be shortest.
        queue = deque([(start, [])])
        visited = {start}
        while queue:
            ele, history = queue.popleft()
            if ele == end:
                return history
            # .get avoids inserting empty entries into a defaultdict graph.
            for node in graph.get(ele, ()):
                if node[0] not in visited:
                    visited.add(node[0])
                    queue.append((node[0], history + [(node[0], node[1])]))
        return None

    def gen_from(self, candidate_tables, schema):
        """Build a FROM clause covering every table in `candidate_tables`.

        With zero or one candidate a plain ``from <table>`` is returned
        without aliases (schema table 0 is used when there is none).
        Otherwise the first candidate becomes ``T1``.  Every other candidate
        is joined along the foreign-key path found by
        ``self.find_shortest_path``, and intermediate tables on that path are
        joined too.  A table with no such path is joined without an ``on``
        condition.

        Args:
            candidate_tables: iterable of table indices used by the query.
            schema: schema dict; reads ``table_names_original``,
                ``column_names``, ``column_names_original``,
                ``foreign_keys`` and, on error, ``db_id``.

        Returns:
            ``(table_alias_dict, from_clause)``, where ``table_alias_dict``
            maps a table index to its alias number N (rendered ``TN``).  If
            building the joins fails, the traceback is printed and the
            partially built result is returned (best effort).
        """
        candidate_tables = list(candidate_tables)
        table_names = schema["table_names_original"]
        if len(candidate_tables) <= 1:
            # TODO: temporary setting for e.g. select count(*) with no
            # table-bound column; fall back to the first table.
            table = candidate_tables[0] if candidate_tables else 0
            return {}, "from {}".format(table_names[table])

        # Undirected FK graph over tables.  Each edge keeps the join columns
        # as (column in this table, column in the neighbor table).
        graph = defaultdict(list)
        for acol, bcol in schema["foreign_keys"]:
            t1 = schema["column_names"][acol][0]
            t2 = schema["column_names"][bcol][0]
            graph[t1].append((t2, (acol, bcol)))
            graph[t2].append((t1, (bcol, acol)))

        start = candidate_tables[0]
        table_alias_dict = {start: 1}
        next_alias = 2
        ret = "from {} as T1".format(table_names[start])
        try:
            col_names = schema["column_names_original"]
            for end in candidate_tables[1:]:
                if end in table_alias_dict:
                    continue
                path = self.find_shortest_path(start, end, graph)
                if not path:
                    # Not reachable through foreign keys: join without `on`.
                    table_alias_dict[end] = next_alias
                    next_alias += 1
                    ret = "{} join {} as T{}".format(ret, table_names[end],
                                                     table_alias_dict[end])
                    continue
                prev_table = start
                for node, (acol, bcol) in path:
                    if node not in table_alias_dict:
                        table_alias_dict[node] = next_alias
                        next_alias += 1
                        # acol lives in prev_table, bcol in node.
                        ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(
                            ret, table_names[node], table_alias_dict[node],
                            table_alias_dict[prev_table], col_names[acol][1],
                            table_alias_dict[node], col_names[bcol][1])
                    prev_table = node
        except Exception:
            # Best effort: report the failing database and return what was
            # built so far.  Only Exception is caught, so KeyboardInterrupt
            # and SystemExit still propagate.
            traceback.print_exc()
            print("db:{}".format(schema["db_id"]))
        return table_alias_dict, ret

    def gen_sql(self, score, col_org, schema_seq):
        """Decode model scores into one SQL string per batch element.

        For each example this picks the SELECT columns and their aggregators,
        the GROUP BY columns (plus an optional HAVING condition), the ORDER BY
        columns and the WHERE conditions.  It then asks ``self.gen_from`` for
        a FROM clause covering every table referenced by those columns.
        Literal values are emitted as the ``VALUE`` placeholder.

        Args:
            score: ``(sel_score, cond_score, group_score, order_score)``.
                Each is a tuple of score tensors (or ``None``) that are moved
                to CPU numpy arrays.
            col_org: per example, a sequence indexed by column id whose items
                are ``[table_index, column_name]``.
            schema_seq: per example, the schema dict handed to ``gen_from``.

        Returns:
            A list of SQL strings, one per batch element.
        """
        sel_score, cond_score, group_score, order_score = score

        def to_numpy(scores):
            # Move every score tensor to a CPU numpy array; keep None as-is.
            return [x.data.cpu().numpy() if x is not None else None
                    for x in scores]

        def use_col(cols, tables, cid):
            # Register column `cid` under its table and return the
            # [column_id, column_name] token that stands for it in a clause.
            tables[cols[cid][0]].append([cid, cols[cid][1]])
            return [cid, cols[cid][1]]

        def wrap_agg(agg_id, col_tok):
            # Tokens for `col_tok`, wrapped in AGG_OPS[agg_id] unless it is 0.
            if agg_id == 0:
                return [col_tok]
            return [AGG_OPS[agg_id], "(", col_tok, ")"]

        def render(tokens, col_map):
            # Replace column tokens by alias-qualified names.  Column id 0
            # becomes "*".  Columns whose table got no alias (gen_from may
            # bail out early) fall back to their bare name instead of being
            # dropped or raising KeyError.
            out = []
            for tok in tokens:
                if not isinstance(tok, list):
                    out.append(tok)
                elif tok[0] == 0:
                    out.append("*")
                elif tok[0] in col_map:
                    out.append(col_map[tok[0]])
                else:
                    out.append(tok[1])
            return out

        sel_num_score, sel_col_score, agg_num_score, agg_op_score = \
            to_numpy(sel_score)
        cond_num_score, cond_col_score, cond_op_score = to_numpy(cond_score)
        (gby_num_score, gby_score, hv_score, hv_col_score, hv_agg_score,
         hv_op_score) = to_numpy(group_score)
        ody_num_score, ody_col_score, ody_agg_score, ody_par_score = \
            to_numpy(order_score)

        ret_sqls = []
        for b in range(len(sel_num_score)):
            cur_cols = col_org[b]
            schema = schema_seq[b]
            # Table index -> referenced [column_id, column_name] tokens.
            cur_tables = defaultdict(list)

            # ------------ SELECT (always at least one column)
            sel_num_cols = np.argmax(sel_num_score[b]) + 1
            cur_sel = ["select"]
            for i, cid in enumerate(np.argsort(-sel_col_score[b])[:sel_num_cols]):
                if i > 0:
                    cur_sel.append(",")
                num_aggs = np.argmax(agg_num_score[b][i])
                if num_aggs == 0:
                    aggs = [0]
                else:
                    # Top-scoring operators, excluding op 0 (no aggregate).
                    aggs = [x for x in np.argsort(-agg_op_score[b][i])
                            if x != 0][:num_aggs]
                for j, agg_id in enumerate(aggs):
                    if j > 0:
                        cur_sel.append(",")
                    cur_sel.extend(
                        wrap_agg(agg_id, use_col(cur_cols, cur_tables, cid)))

            # ------------ GROUP BY / HAVING (HAVING only with GROUP BY)
            gby_num_cols = np.argmax(gby_num_score[b])
            cur_group = []
            if gby_num_cols > 0:
                cur_group.append("group by")
                for i, gid in enumerate(np.argsort(-gby_score[b])[:gby_num_cols]):
                    if i > 0:
                        cur_group.append(",")
                    cur_group.append(use_col(cur_cols, cur_tables, gid))
                if np.argmax(hv_score[b]) != 0:
                    hv_agg = np.argmax(hv_agg_score[b])
                    hv_col = np.argmax(hv_col_score[b])
                    hv_op = np.argmax(hv_op_score[b])
                    cur_group.append("having")
                    cur_group.extend(
                        wrap_agg(hv_agg, use_col(cur_cols, cur_tables, hv_col)))
                    cur_group.extend([WHERE_OPS[hv_op], VALUE])

            # ------------ ORDER BY
            ody_num_cols = np.argmax(ody_num_score[b])
            cur_order = []
            if ody_num_cols > 0:
                ody_agg = np.argmax(ody_agg_score[b])
                parity = np.argmax(ody_par_score[b])
                cur_order.append("order by")
                for i, oid in enumerate(np.argsort(-ody_col_score[b])[:ody_num_cols]):
                    if i > 0:
                        # Multiple ORDER BY keys must be comma separated.
                        cur_order.append(",")
                    cur_order.extend(
                        wrap_agg(ody_agg, use_col(cur_cols, cur_tables, oid)))
                if 0 <= parity <= 3:
                    cur_order.append(DESC_ASC_LIMIT[parity])

            # ------------ WHERE
            cond_num = np.argmax(cond_num_score[b])
            cur_conds = []
            for i, cid in enumerate(np.argsort(-cond_col_score[b])[:cond_num]):
                cur_conds.append("and" if i > 0 else "where")
                cur_conds.append(use_col(cur_cols, cur_tables, cid))
                cur_conds.extend(
                    [WHERE_OPS[np.argmax(cond_op_score[b][i])], VALUE])

            # Columns with table index -1 (presumably the `*` column) belong
            # to no table and must not reach the FROM clause.
            cur_tables.pop(-1, None)

            table_alias_dict, from_clause = self.gen_from(cur_tables.keys(),
                                                          schema)
            clauses = cur_conds + cur_group + cur_order
            if table_alias_dict:
                col_map = {}
                for tid, aid in table_alias_dict.items():
                    for cid, col in cur_tables[tid]:
                        col_map[cid] = "t" + str(aid) + "." + col
                sql_toks = (render(cur_sel, col_map) + [from_clause] +
                            render(clauses, col_map))
            else:
                # At most one table: gen_from already produced the plain
                # "from <table>" clause, so column names stay unqualified.
                sql_toks = (
                    [t[1] if isinstance(t, list) else t for t in cur_sel] +
                    [from_clause] +
                    [t[1] if isinstance(t, list) else t for t in clauses])
            ret_sqls.append(" ".join(sql_toks))

        return ret_sqls
コード例 #4
0
class Seq2SQL(nn.Module):
    """Seq2SQL baseline model.

    Predicts the aggregation operator (``agg_pred``) and the selected column
    (``sel_pred``).  It decodes the WHERE clause as a token sequence
    (``cond_pred``) over a vocabulary made of ``SQL_TOK``, the column tokens
    and the question tokens.
    """

    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=100,
                 N_depth=2,
                 gpu=False,
                 trainable_emb=False):
        super(Seq2SQL, self).__init__()
        self.trainable_emb = trainable_emb

        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth

        self.max_col_num = 45
        self.max_tok_num = 200
        # Fixed prefix of the WHERE-sequence vocabulary.  Index 0 (<UNK>) is
        # used for tokens that cannot be located.
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>'
        ]
        self.COND_OPS = ['EQL', 'GT', 'LT']

        # Word embedding: one layer per sub-task when embeddings are
        # trainable, a single shared layer otherwise.
        if trainable_emb:
            self.agg_embed_layer = WordEmbedding(word_emb,
                                                 N_word,
                                                 gpu,
                                                 self.SQL_TOK,
                                                 our_model=False,
                                                 trainable=trainable_emb)
            self.sel_embed_layer = WordEmbedding(word_emb,
                                                 N_word,
                                                 gpu,
                                                 self.SQL_TOK,
                                                 our_model=False,
                                                 trainable=trainable_emb)
            self.cond_embed_layer = WordEmbedding(word_emb,
                                                  N_word,
                                                  gpu,
                                                  self.SQL_TOK,
                                                  our_model=False,
                                                  trainable=trainable_emb)
        else:
            self.embed_layer = WordEmbedding(word_emb,
                                             N_word,
                                             gpu,
                                             self.SQL_TOK,
                                             our_model=False,
                                             trainable=trainable_emb)

        # Predict aggregator
        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=False)

        # Predict selected column
        self.sel_pred = SelPredictor(N_word,
                                     N_h,
                                     N_depth,
                                     self.max_tok_num,
                                     use_ca=False)

        # Predict the WHERE clause token sequence
        self.cond_pred = Seq2SQLCondPredictor(N_word, N_h, N_depth,
                                              self.max_col_num,
                                              self.max_tok_num, gpu)

        self.CE = nn.CrossEntropyLoss()
        # NOTE(review): built without `dim`, so PyTorch infers the softmax
        # dimension implicitly (with a deprecation warning).  Not used by
        # this class itself.
        self.softmax = nn.Softmax()
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()
        if gpu:
            self.cuda()

    def generate_gt_where_seq(self, q, col, query):
        """Convert gold WHERE clauses into vocabulary index sequences.

        Sequence format::

            <BEG> WHERE cond1_col cond1_op cond1
                    AND cond2_col cond2_op cond2
                    AND ... <END>

        The per-example vocabulary is ``SQL_TOK + column tokens (each column
        followed by ',') + [None] + question tokens + [None]``.  A token is
        mapped to its first index there, or to 0 (<UNK>) if absent.
        Examples without 'WHERE' yield just ``[<BEG>, <END>]``.

        Returns:
            A list with one index list per example.
        """
        ret_seq = []
        for cur_q, cur_col, cur_query in zip(q, col, query):
            connect_col = [
                tok for col_tok in cur_col for tok in col_tok + [',']
            ]
            all_toks = self.SQL_TOK + connect_col + [None] + cur_q + [None]
            cur_seq = [all_toks.index('<BEG>')]
            if 'WHERE' in cur_query:
                cur_where_query = cur_query[cur_query.index('WHERE'):]
                cur_seq.extend(all_toks.index(tok) if tok in all_toks else 0
                               for tok in cur_where_query)
            cur_seq.append(all_toks.index('<END>'))
            ret_seq.append(cur_seq)
        return ret_seq

    def forward(self,
                q,
                col,
                col_num,
                pred_entry,
                gt_where=None,
                gt_cond=None,
                reinforce=False,
                gt_sel=None):
        """Run the predictors enabled in `pred_entry` on a batch.

        Args:
            q: batch of tokenized questions.
            col: batch of tokenized column names.
            col_num: number of columns per example.
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags.
            gt_where, gt_cond, reinforce: forwarded to the condition
                predictor.
            gt_sel: unused; kept for interface compatibility.

        Returns:
            ``(agg_score, sel_score, cond_score)``.  Entries for disabled
            predictors are ``None``.
        """
        pred_agg, pred_sel, pred_cond = pred_entry

        agg_score = None
        sel_score = None
        cond_score = None

        if self.trainable_emb:
            if pred_agg:
                x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
                agg_score = self.agg_pred(x_emb_var, x_len)

            if pred_sel:
                x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
                batch = self.sel_embed_layer.gen_col_batch(col)
                col_inp_var, col_name_len, col_len = batch
                sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                          col_name_len, col_len, col_num)

            if pred_cond:
                x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
                batch = self.cond_embed_layer.gen_col_batch(col)
                col_inp_var, col_name_len, col_len = batch
                cond_score = self.cond_pred(x_emb_var,
                                            x_len,
                                            col_inp_var,
                                            col_name_len,
                                            col_len,
                                            col_num,
                                            gt_where,
                                            gt_cond,
                                            reinforce=reinforce)
        else:
            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
            batch = self.embed_layer.gen_col_batch(col)
            col_inp_var, col_name_len, col_len = batch
            if pred_agg:
                agg_score = self.agg_pred(x_emb_var, x_len)

            if pred_sel:
                sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                          col_name_len, col_len, col_num)

            if pred_cond:
                cond_score = self.cond_pred(x_emb_var,
                                            x_len,
                                            col_inp_var,
                                            col_name_len,
                                            col_len,
                                            col_num,
                                            gt_where,
                                            gt_cond,
                                            reinforce=reinforce)

        return (agg_score, sel_score, cond_score)

    def loss(self, score, truth_num, pred_entry, gt_where):
        """Cross-entropy loss summed over the enabled sub-tasks.

        Args:
            score: ``(agg_score, sel_score, cond_score)`` from ``forward``.
            truth_num: per example, ``x[0]`` is the gold aggregator index
                and ``x[1]`` the gold selected column index.
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags.
            gt_where: per example, the gold WHERE index sequence.  Position
                ``i`` of ``cond_score[b]`` is trained to predict token
                ``i + 1``.  The WHERE loss is averaged over the batch.

        Returns:
            The total loss (the int 0 if no sub-task is enabled).
        """
        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score
        loss = 0
        if pred_agg:
            # A list (not a `map` object) so np.array builds an int array on
            # both Python 2 and 3.
            agg_truth = [x[0] for x in truth_num]
            data = torch.from_numpy(np.array(agg_truth))
            if self.gpu:
                agg_truth_var = Variable(data.cuda())
            else:
                agg_truth_var = Variable(data)

            loss += self.CE(agg_score, agg_truth_var)

        if pred_sel:
            sel_truth = [x[1] for x in truth_num]
            data = torch.from_numpy(np.array(sel_truth))
            if self.gpu:
                sel_truth_var = Variable(data).cuda()
            else:
                sel_truth_var = Variable(data)

            loss += self.CE(sel_score, sel_truth_var)

        if pred_cond:
            for b in range(len(gt_where)):
                if self.gpu:
                    cond_truth_var = Variable(
                        torch.from_numpy(np.array(gt_where[b][1:])).cuda())
                else:
                    cond_truth_var = Variable(
                        torch.from_numpy(np.array(gt_where[b][1:])))
                cond_pred_score = cond_score[b, :len(gt_where[b]) - 1]

                loss += (self.CE(cond_pred_score, cond_truth_var) /
                         len(gt_where))

        return loss

    def reinforce_backward(self, score, rewards):
        """Back-propagate a policy-gradient loss for sampled WHERE tokens.

        For each decoding step ``t`` the loss subtracts
        ``log_prob * reward`` for the sampled tokens ``cond_score[1][t]``.
        Once an example samples ``<END>``, its reward is zeroed for the
        following steps.

        NOTE(review): assumes every sampled-token tensor carries a
        ``log_prob`` attribute set by the condition predictor, and that the
        accumulated loss is a scalar; confirm.
        """
        _, _, cond_score = score
        if len(cond_score[1]) == 0:
            # Nothing sampled: `loss` would remain the int 0, which has no
            # .backward().
            return

        cur_reward = rewards[:]
        eof = self.SQL_TOK.index('<END>')

        loss = 0
        for step in cond_score[1]:
            reward_inp = torch.FloatTensor(cur_reward).unsqueeze(1)
            if self.gpu:
                reward_inp = reward_inp.cuda()
            loss -= step.log_prob * reward_inp

            for b in range(len(rewards)):
                if step[b].data.cpu().numpy()[0] == eof:
                    cur_reward[b] = 0
        loss.backward()
        return

    def check_acc(self, vis_info, pred_queries, gt_queries, pred_entry):
        """Count prediction errors against the gold queries.

        A WHERE prediction is correct only if it has the same number of
        conditions, over the same set of columns, with matching operators
        and case-insensitively equal values.

        Args:
            vis_info: unused; kept for interface compatibility.
            pred_queries, gt_queries: parallel lists of query dicts with
                keys 'agg', 'sel' and 'conds' (each cond is
                ``[column, op, value]``).
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags.

        Returns:
            ``(np.array((agg_err, sel_err, cond_err)), tot_err)`` where
            ``tot_err`` counts examples with any enabled part wrong.
        """
        def conds_match(cond_pred, cond_gt):
            # Number of conditions and set of columns must agree first.
            if len(cond_pred) != len(cond_gt):
                return False
            gt_cols = tuple(x[0] for x in cond_gt)
            if set(x[0] for x in cond_pred) != set(gt_cols):
                return False
            # Pair each predicted condition with the gold one on its column.
            matched = [cond_gt[gt_cols.index(p[0])] for p in cond_pred]
            if any(g[1] != p[1] for g, p in zip(matched, cond_pred)):
                return False
            return all(unicode(g[2]).lower() == unicode(p[2]).lower()
                       for g, p in zip(matched, cond_pred))

        pred_agg, pred_sel, pred_cond = pred_entry

        tot_err = agg_err = sel_err = cond_err = 0.0
        for pred_qry, gt_qry in zip(pred_queries, gt_queries):
            good = True
            if pred_agg and pred_qry['agg'] != gt_qry['agg']:
                agg_err += 1
                good = False

            if pred_sel and pred_qry['sel'] != gt_qry['sel']:
                sel_err += 1
                good = False

            if pred_cond and not conds_match(pred_qry['conds'],
                                             gt_qry['conds']):
                cond_err += 1
                good = False

            if not good:
                tot_err += 1

        return np.array((agg_err, sel_err, cond_err)), tot_err

    def gen_query(self,
                  score,
                  q,
                  col,
                  raw_q,
                  raw_col,
                  pred_entry,
                  reinforce=False,
                  verbose=False):
        """Decode scores into query dicts.

        Each dict holds 'agg' and/or 'sel' (argmax indices) and/or 'conds',
        a list of ``[column_index, op_index, value_string]`` parsed from the
        decoded WHERE token sequence.  The op index refers to EQL=0, GT=1,
        LT=2.  Unrecognized columns map to column 0.

        Raises:
            ValueError: if `pred_entry` enables no prediction, since the
                batch size cannot be inferred.
        """
        def merge_tokens(tok_list, raw_tok_str):
            # Re-join tokens into a string, inserting spaces where the raw
            # text (lower-cased) suggests them.
            tok_str = raw_tok_str.lower()
            alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_appear = 0
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_appear = 1 - double_quote_appear

                if not ret:
                    pass
                elif ret + ' ' + tok in tok_str:
                    ret = ret + ' '
                elif ret + tok in tok_str:
                    pass
                elif tok == '"':
                    if double_quote_appear:
                        ret = ret + ' '
                elif tok[0] not in alphabet:
                    pass
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) and \
                     (ret[-1] != '"' or not double_quote_appear):
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()

        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score

        if pred_agg:
            B = len(agg_score)
        elif pred_sel:
            B = len(sel_score)
        elif pred_cond:
            B = len(cond_score[0]) if reinforce else len(cond_score)
        else:
            raise ValueError("gen_query: pred_entry enables no prediction, "
                             "so the batch size cannot be inferred")

        # Checked in priority order; the first operator found splits a
        # condition into column tokens and value tokens.
        op_toks = (('EQL', 0), ('GT', 1), ('LT', 2))

        ret_queries = []
        for b in range(B):
            cur_query = {}
            if pred_agg:
                cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
            if pred_sel:
                cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
            if pred_cond:
                cur_query['conds'] = []
                all_toks = self.SQL_TOK + \
                           [x for toks in col[b] for x in
                            toks+[',']] + [''] + q[b] + ['']
                cond_toks = []
                if reinforce:
                    for choices in cond_score[1]:
                        tok_id = choices[b].data.cpu().numpy()[0]
                        if tok_id < len(all_toks):
                            cond_val = all_toks[tok_id]
                        else:
                            cond_val = '<UNK>'
                        if cond_val == '<END>':
                            break
                        cond_toks.append(cond_val)
                else:
                    for where_score in cond_score[b].data.cpu().numpy():
                        cond_val = all_toks[np.argmax(where_score)]
                        if cond_val == '<END>':
                            break
                        cond_toks.append(cond_val)

                if verbose:
                    print(cond_toks)
                # Drop the first decoded token (expected to be WHERE).
                cond_toks = cond_toks[1:]

                to_idx = [x.lower() for x in raw_col[b]]
                col_context = raw_q[b] + ' || ' + ' || '.join(raw_col[b])
                st = 0
                while st < len(cond_toks):
                    cur_cond = [None, None, None]
                    rest = cond_toks[st:]
                    ed = rest.index('AND') + st if 'AND' in rest \
                        else len(cond_toks)
                    segment = cond_toks[st:ed]
                    op = st
                    cur_cond[1] = 0
                    for op_tok, op_id in op_toks:
                        if op_tok in segment:
                            op = segment.index(op_tok) + st
                            cur_cond[1] = op_id
                            break
                    pred_col = merge_tokens(cond_toks[st:op], col_context)
                    if pred_col in to_idx:
                        cur_cond[0] = to_idx.index(pred_col)
                    else:
                        cur_cond[0] = 0
                    cur_cond[2] = merge_tokens(cond_toks[op + 1:ed], raw_q[b])
                    cur_query['conds'].append(cur_cond)
                    st = ed + 1
            ret_queries.append(cur_query)

        return ret_queries
コード例 #5
0
ファイル: sqlnet.py プロジェクト: qianwenyuan/typesql_ch
class SQLNet(nn.Module):
    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=120,
                 N_depth=2,
                 use_ca=True,
                 gpu=True,
                 trainable_emb=False,
                 db_content=0):
        """Build the embedding layers and all sub-predictors of the model.

        Args:
            word_emb: word-embedding table handed to every ``WordEmbedding``.
            N_word: embedding size passed to the embeddings and predictors.
            N_h: hidden-size argument forwarded to every predictor.
            N_depth: depth argument forwarded to every predictor.
            use_ca: forwarded to ``WhereRelationPredictor``.
            gpu: when true the finished model is moved to CUDA.
            trainable_emb: trainability flag of the shared embedding
                ``self.embed_layer``.
            db_content: 0 makes the four type-embedding layers trainable;
                any other value freezes them.
        """
        super(SQLNet, self).__init__()

        # ---- configuration -------------------------------------------------
        self.trainable_emb = trainable_emb
        self.db_content = db_content
        self.use_ca = use_ca
        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth
        self.max_col_num = 45
        self.max_tok_num = 200
        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR',
                        '==', '>', '<', '!=', '<BEG>']
        self.COND_OPS = ['>', '<', '==', '!=']

        # Type embeddings are trainable only for db_content == 0 (the original
        # author notes they are not actually used when db_content == 1).
        type_trainable = bool(db_content == 0)

        def make_embedding(trainable):
            # Every embedding layer shares the same vocabulary and inputs.
            return WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK,
                                 trainable=trainable)

        # ---- embeddings (registration order kept stable) --------------------
        self.agg_type_embed_layer = make_embedding(type_trainable)
        self.sel_type_embed_layer = make_embedding(type_trainable)
        self.cond_type_embed_layer = make_embedding(type_trainable)
        self.where_rela_type_embed_layer = make_embedding(type_trainable)
        self.embed_layer = make_embedding(trainable_emb)

        # ---- predictors -----------------------------------------------------
        # Aggregator of the selected columns.
        self.agg_pred = AggPredictor(N_word, N_h, N_depth)
        # Select count + select columns + condition count and columns.
        self.selcond_pred = SelCondPredictor(N_word, N_h, N_depth, gpu,
                                             db_content)
        # Condition operators and condition value strings.
        self.op_str_pred = CondOpStrPredictor(N_word, N_h, N_depth,
                                              self.max_col_num,
                                              self.max_tok_num, gpu,
                                              db_content)
        # Relation (and / or) between conditions.
        self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth,
                                                      use_ca=use_ca)

        # ---- loss / activation helpers --------------------------------------
        self.CE = nn.CrossEntropyLoss()
        self.softmax = nn.Softmax()
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()

        if gpu:
            self.cuda()

    def get_str_index(self, all_toks, this_str):
        """Return the positions in ``all_toks`` that spell the value ``this_str``.

        ``all_toks`` is a list of tokens, each token itself a list of pieces
        (``generate_gt_where_seq`` passes
        ``[['<BEG>']] + question + [['<END>']]``). ``this_str`` is the list of
        pieces that make up one condition value.

        The result always starts with the index of ``['<BEG>']`` and ends with
        the index of ``['<END>']``. Any token not present in ``all_toks`` maps
        to index 0. Matching is attempted in this order:

        1. ``this_str`` equals a whole token -> that token's index.
        2. The first multi-piece token ``tgt`` (in ``all_toks`` order) such
           that either every piece of ``tgt`` occurs in ``this_str`` -> the
           sorted indices of ``tgt`` and of each leftover piece; or every
           piece of ``this_str`` occurs in ``tgt`` -> the index of ``tgt``.
        3. Otherwise every piece of ``this_str`` is looked up on its own.

        An empty ``this_str`` yields just ``[<BEG>, <END>]``.
        """
        def index_of(tok):
            # Tokens missing from all_toks fall back to position 0.
            return all_toks.index(tok) if tok in all_toks else 0

        beg_idx = index_of(['<BEG>'])
        end_idx = index_of(['<END>'])

        # An empty value has nothing to point at. Without this guard the subset
        # test below (the empty set is a subset of anything) would map it onto
        # the first multi-piece token of the question.
        if not this_str:
            return [beg_idx, end_idx]

        if this_str in all_toks:
            return [beg_idx, index_of(this_str), end_idx]

        for tgt in all_toks:
            if len(tgt) <= 1:
                continue
            if set(tgt).issubset(this_str):
                # The value contains the whole multi-piece token, possibly
                # plus extra pieces that are looked up individually.
                leftovers = [[x] for x in this_str if x not in tgt]
                middle = sorted(index_of(s) for s in [tgt] + leftovers)
                return [beg_idx] + middle + [end_idx]
            if set(this_str).issubset(tgt):
                # The value is a fragment of this multi-piece token.
                return [beg_idx, index_of(tgt), end_idx]

        return [beg_idx] + [index_of([x]) for x in this_str] + [end_idx]

    def generate_gt_where_seq_test(self, q, gt_cond_seq):
        """Build pointer sequences locating each condition value in its question.

        Each condition's value string ``cond[2]`` is searched for as a
        substring of the concatenated question text. A match is mapped to the
        span of question tokens that covers it.

        Args:
            q: batch of questions. Each question is a sequence of tokens, and
                each token is a sequence of string pieces that are joined to
                form the token text.
            gt_cond_seq: batch of condition lists. Only ``cond[2]`` is read.

        Returns:
            One list per question holding one index sequence per condition.
            Position 0 stands for ``<BEG>``, question token ``i`` for ``i + 1``
            and ``len(question) + 1`` for ``<END>``. A value that does not
            occur in the question text becomes ``[0, len(question) + 1]``.
        """
        def first_token_reaching(cum_lens, char_pos):
            # Index of the first token whose cumulative length reaches the
            # 1-based character position char_pos, or -1 if none does.
            for i, total in enumerate(cum_lens):
                if char_pos <= total:
                    return i
            return -1

        ret_seq = []
        for cur_q, ans in zip(q, gt_cond_seq):
            token_texts = [u"".join(toks) for toks in cur_q]
            temp_q = u"".join(token_texts)
            # Cumulative character count at the end of each token.
            q_toks_cnt = []
            running = 0
            for text in token_texts:
                running += len(text)
                q_toks_cnt.append(running)
            end_pos = len(q_toks_cnt) + 1  # position of <END>

            record_cond = []
            for cond in ans:
                value = cond[2]
                if value not in temp_q:
                    record_cond.append([0, end_pos])
                    continue
                # 1-based positions of the value's first and last characters.
                start_char = temp_q.index(value) + 1
                end_char = start_char + len(value) - 1
                start_idx = first_token_reaching(q_toks_cnt, start_char)
                last_idx = first_token_reaching(q_toks_cnt, end_char)
                end_idx = last_idx + 1 if last_idx >= 0 else end_pos
                record_cond.append([0] +
                                   list(range(start_idx + 1, end_idx + 1)) +
                                   [end_pos])
            ret_seq.append(record_cond)
        return ret_seq

    def generate_gt_where_seq(self, q, col):
        ret_seq = []
        for cur_q, cur_query in zip(q, col):
            cur_values = []
            st = cur_query.index(u'WHERE') + 1 if \
                u'WHERE' in cur_query else len(cur_query)
            all_toks = [['<BEG>']] + cur_q + [['<END>']]
            while st < len(cur_query):
                ed = len(cur_query) if 'AND' not in cur_query[st:] \
                    else cur_query[st:].index('AND') + st
                if '==' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('==') + st
                elif '>' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('>') + st
                elif '<' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('<') + st
                elif '>=' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('>=') + st
                elif '<=' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('<=') + st
                elif '!=' in cur_query[st:ed]:
                    op = cur_query[st:ed].index('!=') + st
                else:
                    raise RuntimeError("No operator in it!")

                this_str = cur_query[op + 1:ed]
                cur_seq = self.get_str_index(all_toks, this_str)
                cur_values.append(cur_seq)
                st = ed + 1
            ret_seq.append(cur_values)
        # ret_seq = []
        # for cur_q, ans in zip(q, gt_cond_seq):
        #     temp_q = u"".join(cur_q)
        #     cur_q = [u'<BEG>'] + cur_q + [u'<END>']
        #     record = []
        #     record_cond = []
        #     for cond in ans:
        #         if cond[2] not in temp_q:
        #             record.append((False, cond[2]))
        #         else:
        #             record.append((True, cond[2]))
        #     for idx, item in enumerate(record):
        #         temp_ret_seq = []
        #         if item[0]:
        #             temp_ret_seq.append(0)
        #             temp_ret_seq.extend(list(range(temp_q.index(item[1])+1,temp_q.index(item[1])+len(item[1])+1)))
        #             temp_ret_seq.append(len(cur_q)-1)
        #         else:
        #             temp_ret_seq.append([0,len(cur_q)-1])
        #         record_cond.append(temp_ret_seq)
        #     ret_seq.append(record_cond)
        return ret_seq

    def forward(self,
                q,
                col,
                col_num,
                q_type,
                col_type,
                gt_where=None,
                gt_cond=None,
                gt_sel=None,
                gt_sel_num=None):
        """Run the sub-predictors and return their raw scores.

        Only the final ``else`` branch (``trainable_emb`` false and
        ``db_content != 0``) relies solely on attributes that ``__init__``
        creates. The other two branches are flagged with NOTE(review)
        comments below.

        Args:
            q: batch of tokenized questions.
            col: batch of tokenized column headers.
            col_num: forwarded to the predictors as-is.
            q_type: per-question type input, embedded alongside ``q``.
            col_type: not used by this method.
            gt_where, gt_cond, gt_sel, gt_sel_num: optional ground truth
                forwarded to the predictors.

        Returns:
            ``(sel_cond_score, agg_score, cond_op_str_score,
            where_rela_score)``. Entries the active branch does not compute
            are ``None``.
        """
        # Every sub-task is predicted; the flags keep the optional layout.
        pred_agg = True
        pred_sel = True
        pred_cond = True
        pred_sel_num = True
        pred_where_rela = True

        agg_score = None
        sel_cond_score = None
        cond_op_str_score = None
        # Bound up front: the trainable_emb branch never assigns it, which
        # made the final return raise NameError.
        where_rela_score = None

        if self.trainable_emb:
            # NOTE(review): __init__ does not create agg_embed_layer,
            # sel_embed_layer, cond_embed_layer or cond_pred, so this branch
            # raises AttributeError as written; the sel/cond scores it
            # computes are also never returned.
            if pred_agg:
                x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
                col_inp_var, col_name_len, col_len = \
                        self.agg_embed_layer.gen_col_batch(col)
                agg_score = self.agg_pred(x_emb_var,
                                          x_len,
                                          col_inp_var,
                                          col_name_len,
                                          col_len,
                                          col_num,
                                          gt_sel=gt_sel)

            if pred_sel:
                x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
                col_inp_var, col_name_len, col_len = \
                        self.sel_embed_layer.gen_col_batch(col)
                sel_score = self.selcond_pred(x_emb_var, x_len, col_inp_var,
                                              col_name_len, col_len, col_num)

            if pred_cond:
                x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
                col_inp_var, col_name_len, col_len = \
                        self.cond_embed_layer.gen_col_batch(col)
                cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var,
                                            col_name_len, col_len, col_num,
                                            gt_where, gt_cond)
        elif self.db_content == 0:
            # NOTE(review): self.sel_num is commented out in __init__, so the
            # sel_num call below raises AttributeError as written. Also, if
            # selcond_pred returns the 4-tuple that loss() unpacks, the
            # `sel_cond_score.data` access further down would fail too --
            # confirm against SelCondPredictor before relying on this path.
            x_emb_var, x_len = self.embed_layer.gen_x_batch(q,
                                                            col,
                                                            is_list=True,
                                                            is_q=True)
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
                col)
            agg_emb_var = self.embed_layer.gen_agg_batch(q)
            if pred_sel_num and pred_agg and pred_sel:
                sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var,
                                             col_name_len, col_len, col_num)
            if gt_sel_num:
                # Previously assigned to a misspelled `pre_sel_num`, leaving
                # pr_sel_num unbound for the agg_pred call below.
                pr_sel_num = gt_sel_num
            else:
                pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(),
                                       axis=1)
            x_type_sel_emb_var, _ = self.sel_type_embed_layer.gen_xc_type_batch(
                q_type, is_list=True)
            sel_cond_score = self.selcond_pred(x_emb_var, x_len, col_inp_var,
                                               col_len, x_type_sel_emb_var,
                                               gt_sel)

            if gt_sel:
                pr_sel = gt_sel
            else:
                num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
                sel = sel_cond_score.data.cpu().numpy()
                # Keep the num[b] highest-scoring columns per example.
                pr_sel = [
                    list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))
                ]
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      agg_emb_var,
                                      col_inp_var,
                                      col_len,
                                      gt_sel=pr_sel,
                                      gt_sel_num=pr_sel_num)

            if pred_cond:
                x_type_cond_emb_var, _ = self.cond_type_embed_layer.gen_xc_type_batch(
                    q_type, is_list=True)
                cond_op_str_score = self.op_str_pred(x_emb_var, x_len,
                                                     col_inp_var, col_len,
                                                     x_type_cond_emb_var,
                                                     gt_where, gt_cond,
                                                     sel_cond_score)

            if pred_where_rela:
                where_rela_score = self.where_rela_pred(
                    x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                    col_num)
        else:
            # Question, column and question-type embeddings all come from
            # the shared embed_layer on this path.
            x_emb_var, x_len = self.embed_layer.gen_x_batch(q,
                                                            col,
                                                            is_list=True,
                                                            is_q=True)
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
                col)
            x_type_emb_var, x_type_len = self.embed_layer.gen_x_batch(
                q_type, col, is_list=True, is_q=True)
            sel_cond_score = self.selcond_pred(x_emb_var, x_len, col_inp_var,
                                               col_name_len, col_len,
                                               x_type_emb_var, gt_sel)
            agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, col_len,
                                      col_name_len, x_type_emb_var, gt_sel,
                                      sel_cond_score)
            cond_op_str_score = self.op_str_pred(x_emb_var, x_len, col_inp_var,
                                                 col_len, col_name_len,
                                                 x_type_emb_var, gt_where,
                                                 gt_cond, sel_cond_score)
            where_rela_score = self.where_rela_pred(x_emb_var, x_len,
                                                    col_inp_var, col_name_len,
                                                    col_len, col_num,
                                                    x_type_emb_var)
        return (sel_cond_score, agg_score, cond_op_str_score, where_rela_score)

    def loss(self, score, truth_num, gt_where):
        sel_cond_score, agg_score, cond_op_str_score, where_rela_score = score

        sel_num_score, cond_num_score, sel_score, cond_col_score = sel_cond_score
        cond_op_score, cond_str_score = cond_op_str_score

        B = len(truth_num)
        loss = 0

        # Evaluate select number
        sel_num_truth = map(lambda x: x[0], truth_num)
        sel_num_truth = torch.from_numpy(np.array(sel_num_truth))
        if self.gpu:
            sel_num_truth = Variable(sel_num_truth.cuda())
        else:
            sel_num_truth = Variable(sel_num_truth)
        loss += self.CE(sel_num_score, sel_num_truth)

        # Evaluate select column
        T = len(sel_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            truth_prob[b][list(truth_num[b][1])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            sel_col_truth_var = Variable(data.cuda())
        else:
            sel_col_truth_var = Variable(data)
        sigm = nn.Sigmoid()
        sel_col_prob = sigm(sel_score)
        bce_loss = -torch.mean(
            3 * (sel_col_truth_var * torch.log(sel_col_prob + 1e-10)) +
            (1 - sel_col_truth_var) * torch.log(1 - sel_col_prob + 1e-10))
        loss += bce_loss

        # Evaluate select aggregation
        for b in range(len(truth_num)):
            data = torch.from_numpy(np.array(truth_num[b][2]))
            if self.gpu:
                sel_agg_truth_var = Variable(data.cuda())
            else:
                sel_agg_truth_var = Variable(data)
            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
            loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)

        # Evaluate the number of conditions
        cond_num_truth = map(lambda x: x[3], truth_num)
        data = torch.from_numpy(np.array(cond_num_truth))
        if self.gpu:
            try:
                cond_num_truth_var = Variable(data.cuda())
            except:
                print "cond_num_truth_var error"
                print data
                exit(0)
        else:
            cond_num_truth_var = Variable(data)
        loss += self.CE(cond_num_score, cond_num_truth_var)

        # Evaluate the columns of conditions
        T = len(cond_col_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][4]) > 0:
                truth_prob[b][list(truth_num[b][4])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            cond_col_truth_var = Variable(data.cuda())
        else:
            cond_col_truth_var = Variable(data)

        sigm = nn.Sigmoid()
        cond_col_prob = sigm(cond_col_score)
        bce_loss = -torch.mean(
            3 * (cond_col_truth_var * torch.log(cond_col_prob + 1e-10)) +
            (1 - cond_col_truth_var) * torch.log(1 - cond_col_prob + 1e-10))
        loss += bce_loss

        # Evaluate the operator of conditions
        for b in range(len(truth_num)):
            if len(truth_num[b][5]) == 0:
                continue
            data = torch.from_numpy(np.array(truth_num[b][5]))
            if self.gpu:
                cond_op_truth_var = Variable(data.cuda())
            else:
                cond_op_truth_var = Variable(data)
            cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
            try:
                loss += (self.CE(cond_op_pred, cond_op_truth_var) /
                         len(truth_num))
            except:
                print cond_op_pred
                print cond_op_truth_var
                exit(0)

        # Evaluate the strings of conditions
        for b in range(len(gt_where)):
            #print(gt_where[b])
            for idx in range(len(gt_where[b])):
                cond_str_truth = gt_where[b][idx]
                #print("{}{}".format(cond_str_truth, len(cond_str_truth)))
                if len(cond_str_truth) == 2:
                    continue
                #print("cond_str_truth:{}".format(cond_str_truth[1:]))
#for tok in cond_str_truth[1:]:
#print("cond_str_tr_tok:{}".format(tok))
#print(' ')
                data = torch.from_numpy(np.array(cond_str_truth[1:]))
                #print("data:{}{}".format(data, data.shape))
                if self.gpu:
                    cond_str_truth_var = Variable(data.cuda())
                else:
                    cond_str_truth_var = Variable(data)
                str_end = len(cond_str_truth) - 1
                #print("cond_str_score:{} str_end:{}".format(cond_str_score.shape, str_end))
                cond_str_pred = cond_str_score[b, idx, :str_end]
                #print ("cond_str_score:{}cond_str_pred:{}".format(cond_str_score.shape, cond_str_pred.shape))
                #print("cond_str_pred:{}".format(cond_str_pred))
                loss += (self.CE(cond_str_pred, cond_str_truth_var) \
                                       / (len(gt_where) * len(gt_where[b])))

        # Evaluate condition relationship, and / or
        where_rela_truth = map(lambda x: x[6], truth_num)
        data = torch.from_numpy(np.array(where_rela_truth))
        if self.gpu:
            try:
                where_rela_truth = Variable(data.cuda())
            except:
                print "where_rela_truth error"
                print data
                exit(0)
        else:
            where_rela_truth = Variable(data)
        loss += self.CE(where_rela_score, where_rela_truth)
        return loss

    # def loss(self, score, truth_num, pred_entry, gt_where): #edited by qwy
    #     pred_agg, pred_sel, pred_cond = pred_entry
    #     sel_num_score, agg_score, sel_cond_score, cond_op_str_score, where_rela_score = score
    #
    #     cond_num_score, sel_score, cond_col_score = sel_cond_score
    #     cond_op_score, cond_str_score = cond_op_str_score
    #
    #     loss = 0
    #
    #     sel_num_truth = map(lambda x:x[0], truth_num)
    #     sel_num_truth = torch.from_numpy(np.array(sel_num_truth))
    #     if self.gpu:
    #         sel_num_truth = Variable(sel_num_truth.cuda())
    #     else:
    #         sel_num_truth = Variable(sel_num_truth)
    #     loss += self.CE(sel_num_score, sel_num_truth)
    #
    #     if pred_sel:
    #         sel_truth = map(lambda x:x[1], truth_num)
    #         data = torch.from_numpy(np.array(sel_truth))
    #         if self.gpu:
    #             sel_truth_var = Variable(data.cuda())
    #         else:
    #             sel_truth_var = Variable(data)
    #
    #         loss += self.CE(sel_score, sel_truth_var)
    #
    #     if pred_agg:
    #         agg_truth = map(lambda x:x[2], truth_num)
    #         data = torch.from_numpy(np.array(agg_truth))
    #         if self.gpu:
    #             agg_truth_var = Variable(data.cuda())
    #         else:
    #             agg_truth_var = Variable(data)
    #
    #         loss += self.CE(agg_score, agg_truth_var)
    #
    #     if pred_cond:
    #         B = len(truth_num)
    #         #Evaluate the number of conditions
    #         cond_num_truth = map(lambda x:x[3], truth_num)
    #         data = torch.from_numpy(np.array(cond_num_truth))
    #         if self.gpu:
    #             cond_num_truth_var = Variable(data.cuda())
    #         else:
    #             cond_num_truth_var = Variable(data)
    #         loss += self.CE(cond_num_score, cond_num_truth_var)
    #
    #         #Evaluate the columns of conditions
    #         T = len(cond_col_score[0])
    #         truth_prob = np.zeros((B, T), dtype=np.float32)
    #         for b in range(B):
    #             if len(truth_num[b][4]) > 0:
    #                 truth_prob[b][list(truth_num[b][4])] = 1
    #         data = torch.from_numpy(truth_prob)
    #         if self.gpu:
    #             cond_col_truth_var = Variable(data.cuda())
    #         else:
    #             cond_col_truth_var = Variable(data)
    #
    #         sigm = nn.Sigmoid()
    #         cond_col_prob = sigm(cond_col_score)
    #         bce_loss = -torch.mean( 3*(cond_col_truth_var * \
    #                 torch.log(cond_col_prob+1e-10)) + \
    #                 (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) )
    #         loss += bce_loss
    #
    #         #Evaluate the operator of conditions
    #         for b in range(len(truth_num)):
    #             if len(truth_num[b][5]) == 0:
    #                 continue
    #             data = torch.from_numpy(np.array(truth_num[b][5]))
    #             if self.gpu:
    #                 cond_op_truth_var = Variable(data.cuda())
    #             else:
    #                 cond_op_truth_var = Variable(data)
    #             cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
    #             loss += (self.CE(cond_op_pred, cond_op_truth_var) \
    #                     / len(truth_num))
    #
    #         #Evaluate the strings of conditions
    #         for b in range(len(gt_where)):
    #             for idx in range(len(gt_where[b])):
    #                 cond_str_truth = gt_where[b][idx]
    #                 if len(cond_str_truth) == 1:
    #                     continue
    #                 data = torch.from_numpy(np.array(cond_str_truth[1:]))
    #                 if self.gpu:
    #                     cond_str_truth_var = Variable(data.cuda())
    #                 else:
    #                     cond_str_truth_var = Variable(data)
    #                 str_end = len(cond_str_truth)-1
    #                 cond_str_pred = cond_str_score[b, idx, :str_end]
    #                 loss += (self.CE(cond_str_pred, cond_str_truth_var) \
    #                         / (len(gt_where) * len(gt_where[b])))
    #
    #         where_rela_truth = map(lambda x: x[6], truth_num)
    #         data = torch.from_numpy(np.array(where_rela_truth))
    #         if self.gpu:
    #             try:
    #                 where_rela_truth = Variable(data.cuda())
    #             except:
    #                 print "where_rela_truth error"
    #                 print data
    #                 exit(0)
    #         else:
    #             where_rela_truth = Variable(data)
    #         loss += self.CE(where_rela_score, where_rela_truth)
    #
    #     return loss

    def check_acc(self, vis_info, pred_queries, gt_queries):
        """Count component-wise mismatches between predicted and gold queries.

        Every query is a dict with keys ``'sel'`` (selected column indices),
        ``'agg'`` (aggregator ids, paired positionally with ``'sel'``),
        ``'cond_conn_op'`` (connector between conditions) and ``'conds'``
        (a list of ``[column, operator, value]`` entries).

        Selected columns and conditions are compared as unordered collections:
        aggregators are matched to their column by sorting on column index,
        and condition operators/values are matched by sorting on condition
        column.

        Args:
            vis_info: Unused; kept so existing callers keep working.
            pred_queries: Iterable of predicted query dicts.
            gt_queries: Iterable of gold query dicts, zipped with
                ``pred_queries`` (surplus items of the longer one are ignored).

        Returns:
            A pair ``(errors, tot_err)``. ``errors`` is an ``np.array`` of float
            counts ``(sel_num_err, sel_err, agg_err, cond_num_err,
            cond_col_err, cond_op_err, cond_val_err, cond_rela_err)``, and
            ``tot_err`` is the number of queries with at least one error.
            The condition column/op/value counters are only updated when the
            number of predicted and gold conditions matches.
        """
        tot_err = sel_num_err = agg_err = sel_err = 0.0
        cond_num_err = cond_col_err = cond_op_err = cond_val_err = 0.0
        cond_rela_err = 0.0

        def values_by_key(mapping):
            # Values ordered by key, so the comparison ignores the order in
            # which columns / conditions were emitted.
            return [mapping[k] for k in sorted(mapping)]

        for pred_qry, gt_qry in zip(pred_queries, gt_queries):
            good = True

            if gt_qry['cond_conn_op'] != pred_qry['cond_conn_op']:
                good = False
                cond_rela_err += 1

            sel_pred, sel_gt = pred_qry['sel'], gt_qry['sel']
            if len(sel_pred) != len(sel_gt):
                good = False
                sel_num_err += 1
            if set(sel_pred) != set(sel_gt):
                good = False
                sel_err += 1

            # Pair every aggregator with its column; a repeated column keeps
            # only its last aggregator.
            agg_pred = values_by_key(dict(zip(list(sel_pred),
                                              list(pred_qry['agg']))))
            agg_gt = values_by_key(dict(zip(sel_gt, gt_qry['agg'])))
            if agg_pred != agg_gt:
                good = False
                agg_err += 1

            cond_pred, cond_gt = pred_qry['conds'], gt_qry['conds']
            if len(cond_pred) != len(cond_gt):
                good = False
                cond_num_err += 1
            else:
                # Index conditions by column; a repeated column keeps only its
                # last condition.
                op_pred = {c[0]: c[1] for c in cond_pred}
                val_pred = {c[0]: c[2] for c in cond_pred}
                op_gt = {c[0]: c[1] for c in cond_gt}
                val_gt = {c[0]: c[2] for c in cond_gt}

                if set(op_pred) != set(op_gt):
                    cond_col_err += 1
                    good = False
                if values_by_key(op_pred) != values_by_key(op_gt):
                    cond_op_err += 1
                    good = False
                if values_by_key(val_pred) != values_by_key(val_gt):
                    cond_val_err += 1
                    good = False

            if not good:
                tot_err += 1

        return np.array(
            (sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err,
             cond_op_err, cond_val_err, cond_rela_err)), tot_err

    def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
        """Decode batched model scores into structured query dicts.

        Args:
            score: 4-tuple ``(sel_cond_score, agg_score, cond_op_str_score,
                where_rela_score)``. ``sel_cond_score`` unpacks into four
                tensors ``(sel_num, cond_num, sel_col, cond_col)`` and
                ``cond_op_str_score`` into ``(cond_op, cond_str)``; every
                tensor is moved to CPU and converted with ``.numpy()``.
            q: Per-example question tokens. Condition-value pointers index
                into ``['<BEG>'] + q[b] + ['<END>']``.
            col: Unused.
            raw_q: Per-example raw question strings, used to re-join the
                pointed-at value tokens.
            reinforce: Unused.
            verbose: Unused.

        Returns:
            One dict per example with keys ``'sel'`` and ``'agg'`` (parallel
            lists), ``'cond_conn_op'`` and ``'conds'`` (a list of
            ``[column, operator, value_string]``). All indices are plain
            Python ``int`` values, so the result can be serialized (e.g. with
            ``json``) without numpy-specific handling.
        """
        def merge_tokens(tok_list, raw_tok_str):
            """Join pointer-decoded value tokens back into a string.

            tok_list: sequence of token containers; each element is iterated
                and flattened one level (plain string tokens therefore
                contribute their characters).
            raw_tok_str: the raw question string; the merged text is matched
                against its lower-cased form to decide where spaces go.
            """
            tok_str = raw_tok_str.lower()
            alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_appear = 0
            tok_list = [x for gx in tok_list for x in gx]
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_appear = 1 - double_quote_appear

                if len(ret) == 0:
                    pass
                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                    ret = ret + ' '
                elif len(ret) > 0 and ret + tok in tok_str:
                    pass
                elif tok == '"':
                    if double_quote_appear:
                        ret = ret + ' '
                elif tok[0] not in alphabet:
                    pass
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
                        and (ret[-1] != '"' or not double_quote_appear):
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()

        sel_cond_score, agg_score, cond_op_str_score, where_rela_score = score

        sel_num_score, cond_num_score, sel_score, cond_col_score \
            = [x.data.cpu().numpy() for x in sel_cond_score]
        cond_op_score, cond_str_score = [
            x.data.cpu().numpy() for x in cond_op_str_score
        ]

        # Original example shapes (batch 64): [64,4,6], [64,14], ..., [64,4].
        # agg_score[b] is indexed row-wise below, one row per select slot.
        agg_score = agg_score.data.cpu().numpy()
        where_rela_score = where_rela_score.data.cpu().numpy()
        ret_queries = []
        B = len(agg_score)

        for b in range(B):
            cur_query = {}
            sel_num = np.argmax(sel_num_score[b])
            # Most-probable columns, best first.
            max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
            # Per select slot, the aggregator ranking; column 0 is the best.
            max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
            # Cast numpy integers to int so every field has the same plain
            # Python type (previously only 'sel' was converted).
            cur_query['sel'] = [int(i) for i in max_col_idxes]
            cur_query['agg'] = [int(i[0]) for i in max_agg_idxes]
            cur_query['cond_conn_op'] = int(np.argmax(where_rela_score[b]))
            cur_query['conds'] = []
            cond_num = np.argmax(cond_num_score[b])
            all_toks = ['<BEG>'] + q[b] + ['<END>']
            max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
            for idx in range(cond_num):
                where_col = int(max_idxes[idx])
                where_op = int(np.argmax(cond_op_score[b][idx]))
                cur_cond_str_toks = []
                for str_score in cond_str_score[b][idx]:
                    # Pointer over the question tokens (with <BEG>/<END>).
                    str_tok = np.argmax(str_score[:len(all_toks)])
                    str_val = all_toks[str_tok]
                    if str_val == '<END>':
                        break
                    cur_cond_str_toks.append(str_val)
                cur_query['conds'].append([
                    where_col, where_op,
                    merge_tokens(cur_cond_str_toks, raw_q[b])
                ])
            ret_queries.append(cur_query)
        return ret_queries
コード例 #6
0
ファイル: sqlnet.py プロジェクト: Mars-Wei/MISP
class SQLNet(nn.Module):
    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=100,
                 N_depth=2,
                 gpu=False,
                 use_ca=True,
                 trainable_emb=False,
                 dr=0.3,
                 temperature=False):
        """Assemble the embedding layer(s), the three predictors and losses.

        Args:
            word_emb: Embedding source forwarded to every WordEmbedding.
            N_word: Word-embedding dimension.
            N_h: Hidden size forwarded to every predictor.
            N_depth: Depth forwarded to every predictor.
            gpu: When true, the finished module is moved to CUDA.
            use_ca: Forwarded to every predictor.
            trainable_emb: When true, agg/sel/cond each get their own
                trainable WordEmbedding; otherwise one shared ``embed_layer``.
            dr: Forwarded to every predictor as ``dr`` (presumably a dropout
                rate -- confirm in the predictor classes).
            temperature: Stored on the model and forwarded to every predictor.
        """
        super(SQLNet, self).__init__()
        self.use_ca, self.trainable_emb, self.temperature = \
            use_ca, trainable_emb, temperature
        self.gpu, self.N_h, self.N_depth = gpu, N_h, N_depth

        # Fixed size limits handed to the predictors.
        self.max_col_num, self.max_tok_num = 45, 200
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>'
        ]
        self.COND_OPS = ['EQL', 'GT', 'LT']

        def make_embedding():
            # Every embedding layer is configured identically.
            return WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK,
                                 our_model=True, trainable=trainable_emb)

        # NOTE: submodules are created in a fixed order (embeddings, agg, sel,
        # cond, losses); that order determines parameters() ordering.
        if not trainable_emb:
            self.embed_layer = make_embedding()
        else:
            self.agg_embed_layer = make_embedding()
            self.sel_embed_layer = make_embedding()
            self.cond_embed_layer = make_embedding()

        shared_kwargs = {'dr': dr, 'temperature': temperature}
        self.agg_pred = AggPredictor(N_word, N_h, N_depth,
                                     use_ca=use_ca, **shared_kwargs)
        self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num,
                                     use_ca=use_ca, **shared_kwargs)
        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth,
                                             self.max_col_num,
                                             self.max_tok_num, use_ca, gpu,
                                             **shared_kwargs)

        # Loss functions and activations used by the rest of the model.
        self.CE = nn.CrossEntropyLoss()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()
        if gpu:
            self.cuda()

    def generate_gt_where_seq(self, q, col, query):
        """Turn gold WHERE clauses into pointer sequences over the question.

        For each example, every condition after ``WHERE`` (conditions are
        separated by ``AND``) is split at its first operator, checked in the
        priority order EQL, GT, LT. The value tokens after the operator are
        wrapped in ``<BEG>``/``<END>`` and mapped to their index in
        ``['<BEG>'] + question + ['<END>']``; unknown tokens map to 0.

        Args:
            q: Per-example question token lists.
            col: Per-example columns; only zipped for alignment, not read.
            query: Per-example gold query token lists.

        Returns:
            A list (one per example) of lists of index sequences.

        Raises:
            RuntimeError: If a condition contains none of EQL/GT/LT.
        """
        operator_priority = ('EQL', 'GT', 'LT')
        batch_seqs = []
        for question, _, tokens in zip(q, col, query):
            vocab = ['<BEG>'] + question + ['<END>']
            n_tokens = len(tokens)
            if u'WHERE' in tokens:
                start = tokens.index(u'WHERE') + 1
            else:
                start = n_tokens
            cond_seqs = []
            while start < n_tokens:
                tail = tokens[start:]
                end = tail.index('AND') + start if 'AND' in tail else n_tokens
                clause = tokens[start:end]
                op_pos = None
                for op_token in operator_priority:
                    if op_token in clause:
                        op_pos = clause.index(op_token) + start
                        break
                if op_pos is None:
                    raise RuntimeError("No operator in it!")
                value_toks = ['<BEG>'] + tokens[op_pos + 1:end] + ['<END>']
                cond_seqs.append(
                    [vocab.index(t) if t in vocab else 0 for t in value_toks])
                start = end + 1
            batch_seqs.append(cond_seqs)
        return batch_seqs

    def forward(self,
                q,
                col,
                col_num,
                pred_entry,
                gt_where=None,
                gt_cond=None,
                reinforce=False,
                gt_sel=None):
        """Score the aggregator, SELECT-column and WHERE-clause sub-tasks.

        Args:
            q: Batch of tokenized questions.
            col: Batch of tokenized column headers.
            col_num: Forwarded unchanged to every predictor.
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags choosing
                which predictors run.
            gt_where: Forwarded to the condition predictor.
            gt_cond: Forwarded to the condition predictor.
            reinforce: Forwarded to the condition predictor.
            gt_sel: Forwarded to the aggregator predictor.

        Returns:
            ``(agg_score, sel_score, cond_score)``; the entry of every
            predictor that was not requested is ``None``.
        """
        pred_agg, pred_sel, pred_cond = pred_entry

        def embed(layer):
            # Embed the questions and the column headers with ``layer``.
            x_emb_var, x_len = layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = layer.gen_col_batch(col)
            return x_emb_var, x_len, col_inp_var, col_name_len, col_len

        # A frozen embedding is shared: it is computed once, up front. With
        # trainable embeddings each predictor owns a layer and embeds lazily,
        # right before its predictor runs (same call order as before).
        shared = None if self.trainable_emb else embed(self.embed_layer)

        def inputs_for(layer_attr):
            if shared is not None:
                return shared
            return embed(getattr(self, layer_attr))

        agg_score = sel_score = cond_score = None

        if pred_agg:
            x_emb_var, x_len, col_inp_var, col_name_len, col_len = \
                inputs_for('agg_embed_layer')
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=gt_sel)

        if pred_sel:
            x_emb_var, x_len, col_inp_var, col_name_len, col_len = \
                inputs_for('sel_embed_layer')
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)

        if pred_cond:
            x_emb_var, x_len, col_inp_var, col_name_len, col_len = \
                inputs_for('cond_embed_layer')
            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)

        return (agg_score, sel_score, cond_score)

    def interaction_beam_forward(self,
                                 q,
                                 col,
                                 raw_q,
                                 raw_col,
                                 col_num,
                                 beam_size,
                                 dec_prefix,
                                 stop_step=None,
                                 avoid_items=None,
                                 confirmed_items=None,
                                 dropout_rate=0.0,
                                 bool_collect_choices=False,
                                 bool_verbal=False):
        """
        @author: Ziyu Yao
        Beam search decoding for interactive sql generation.
        Only support batch size=1 and self.trainable_emb=True.
        """
        assert self.trainable_emb, "Support trainable_emb=True only."
        assert len(q) == 1

        dec_prefix = dec_prefix[::-1]
        hypotheses = [Hypothesis(dec_prefix)]
        completed_hypotheses = []
        table_name = None

        while True:
            new_hypotheses = []

            for hyp in hypotheses:
                if hyp.stack.isEmpty():
                    # sort conds by its col idx
                    conds = hyp.sql_i['conds']
                    sorted_conds = sorted(conds, key=lambda x: x[0])
                    hyp.sql_i['conds'] = sorted_conds
                    hyp.sql = generate_sql_q1(hyp.sql_i, raw_q[0], raw_col[0])
                    if bool_verbal:
                        print("Completed %d-th hypotheses: " %
                              len(completed_hypotheses))
                        print("tag_seq:{}".format(hyp.tag_seq))
                        print("dec_seq: {}".format(hyp.dec_seq))
                        print("sql_i: {}".format(hyp.sql_i))
                        print("sql: {}".format(hyp.sql))
                    completed_hypotheses.append(hyp)  # add to completion
                else:
                    vet = hyp.stack.pop()
                    if vet[0] == "sc":
                        x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(
                            col, dropout_rate=dropout_rate)
                        sel_score = self.sel_pred(
                            x_emb_var,
                            x_len,
                            col_inp_var,
                            col_name_len,
                            col_len,
                            col_num,
                            dropout_rate=dropout_rate).view(1, -1)
                        prob_sc = self.softmax(sel_score).data.cpu().numpy()[0]
                        hyp.tag_seq.append((OUTSIDE, 'select', 1.0, None))

                        if len(hyp.dec_prefix):
                            partial_vet, sc_idx = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            sc_candidates = [sc_idx]
                        else:
                            sc_candidates = np.argsort(-prob_sc)
                            # rm avoid candidates
                            if avoid_items is not None and hyp.dec_seq_idx in avoid_items:
                                sc_candidates = [
                                    sc_idx for sc_idx in sc_candidates if
                                    sc_idx not in avoid_items[hyp.dec_seq_idx]
                                ]
                            sc_candidates = sc_candidates[:beam_size]

                        for sc_idx in sc_candidates:
                            if len(sc_candidates) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()
                            sc_name = raw_col[0][sc_idx]

                            step_hyp.sql_i['sel'] = sc_idx
                            step_hyp.dec_seq.append((vet, sc_idx))
                            step_hyp.tag_seq.append(
                                (SELECT_COL, (table_name, sc_name, sc_idx),
                                 prob_sc[sc_idx], step_hyp.dec_seq_idx))
                            step_hyp.add_logprob(np.log(prob_sc[sc_idx]))
                            step_hyp.stack.push(("sa", (sc_idx, sc_name)))
                            step_hyp.dec_seq_idx += 1

                            new_hypotheses.append(step_hyp)

                    elif vet[0] == "sa":
                        sc_idx, sc_name = vet[1]
                        x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(
                            col, dropout_rate=dropout_rate)
                        agg_score = self.agg_pred(
                            x_emb_var,
                            x_len,
                            col_inp_var,
                            col_name_len,
                            col_len,
                            col_num,
                            gt_sel=[sc_idx],
                            dropout_rate=dropout_rate).view(1, -1)
                        prob_sa = self.softmax(agg_score).data.cpu().numpy()[0]

                        if len(hyp.dec_prefix):
                            partial_vet, sa_idx = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            sa_candidates = [sa_idx]
                        else:
                            sa_candidates = np.argsort(-prob_sa)

                            if avoid_items is not None and hyp.dec_seq_idx in avoid_items:
                                sa_candidates = [
                                    sa_idx for sa_idx in sa_candidates if
                                    sa_idx not in avoid_items[hyp.dec_seq_idx]
                                ]
                            sa_candidates = sa_candidates[:beam_size]

                        for sa_idx in sa_candidates:
                            if len(sa_candidates) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()
                            sa_name = AGG_OPS[sa_idx]
                            if sa_name == 'None':
                                sa_name = 'none_agg'  # for q gen usage

                            step_hyp.sql_i['agg'] = sa_idx
                            step_hyp.dec_seq.append((vet, sa_idx))
                            step_hyp.tag_seq.append(
                                (SELECT_AGG, (table_name, sc_name,
                                              sc_idx), (sa_name, sa_idx),
                                 prob_sa[sa_idx], step_hyp.dec_seq_idx))
                            step_hyp.add_logprob(np.log(prob_sa[sa_idx]))
                            step_hyp.stack.push(("wc", None))
                            step_hyp.dec_seq_idx += 1

                            new_hypotheses.append(step_hyp)

                    elif vet[0] == "wc":
                        hyp.tag_seq.append((OUTSIDE, 'where', 1.0, None))
                        hyp.sql_i['conds'] = []

                        step_hypotheses = []

                        x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(
                            col, dropout_rate=dropout_rate)

                        # wn, wc
                        cond_num_score, cond_col_score = self.cond_pred.cols_forward(
                            x_emb_var,
                            x_len,
                            col_inp_var,
                            col_name_len,
                            col_len,
                            dropout_rate=dropout_rate)
                        prob_wn = self.softmax(cond_num_score.view(
                            1, -1)).data.cpu().numpy()[0]
                        prob_wc = self.sigmoid(cond_col_score.view(
                            1, -1)).data.cpu().numpy()[0]

                        if len(hyp.dec_prefix):
                            partial_vet, wn, wc_list = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            col_num_cols_pair = [(wn, wc_list)]
                        else:
                            col_num_cols_pair = []
                            sorted_col_num = np.argsort(-prob_wn)
                            sorted_cols = np.argsort(-prob_wc)

                            # filter avoid_items
                            if avoid_items is not None and hyp.dec_seq_idx in avoid_items:
                                sorted_cols = [
                                    col_idx for col_idx in sorted_cols if
                                    col_idx not in avoid_items[hyp.dec_seq_idx]
                                ]
                                sorted_col_num = [
                                    col_num for col_num in sorted_col_num
                                    if col_num <= len(sorted_cols)
                                ]

                            # fix confirmed items
                            if confirmed_items is not None and hyp.dec_seq_idx in confirmed_items:
                                fixed_cols = list(
                                    confirmed_items[hyp.dec_seq_idx])
                                sorted_col_num = [
                                    col_num - len(fixed_cols)
                                    for col_num in sorted_col_num
                                    if col_num >= len(fixed_cols)
                                ]
                                sorted_cols = [
                                    col_idx for col_idx in sorted_cols
                                    if col_idx not in fixed_cols
                                ]
                            else:
                                fixed_cols = []

                            if bool_collect_choices:  # fake searching to collect some choices
                                col_num_cols_pair.extend([
                                    (1, [col_idx])
                                    for col_idx in sorted_cols[:beam_size]
                                ])
                            else:
                                for col_num in sorted_col_num:  #[:beam_size]
                                    if col_num == 0:
                                        col_num_cols_pair.append(
                                            (len(fixed_cols), fixed_cols))
                                    elif col_num == 1:
                                        col_num_cols_pair.extend([
                                            (len(fixed_cols) + 1,
                                             fixed_cols + [col_idx]) for
                                            col_idx in sorted_cols[:beam_size]
                                        ])
                                    elif beam_size == 1:
                                        top_cols = list(sorted_cols[:col_num])
                                        # top_cols.sort()
                                        col_num_cols_pair.append(
                                            (len(fixed_cols) + col_num,
                                             fixed_cols + top_cols))
                                    else:
                                        combs = combinations(
                                            sorted_cols[:10], col_num
                                        )  # to reduce beam search time
                                        comb_score = []
                                        for comb in combs:
                                            score = sum([
                                                np.log(prob_wc[c_idx])
                                                for c_idx in comb
                                            ])
                                            comb_score.append((comb, score))
                                        sorted_comb_score = sorted(
                                            comb_score,
                                            key=lambda x: x[1],
                                            reverse=True)[:beam_size]
                                        for comb, _ in sorted_comb_score:
                                            comb_cols = list(comb)
                                            # comb_cols.sort()
                                            col_num_cols_pair.append(
                                                (len(fixed_cols) + col_num,
                                                 fixed_cols + comb_cols))

                        for col_num, cols in col_num_cols_pair:
                            if len(col_num_cols_pair) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()

                            step_hyp.dec_seq.append((vet, col_num, cols))
                            step_hyp.add_logprob(np.log(prob_wn[col_num]))

                            for wc_idx in cols:
                                wc_name = raw_col[0][wc_idx]
                                step_hyp.tag_seq.append(
                                    (WHERE_COL, (table_name, wc_name, wc_idx),
                                     prob_wc[wc_idx], step_hyp.dec_seq_idx))
                                step_hyp.add_logprob(np.log(prob_wc[wc_idx]))
                                step_hyp.stack.push(("wo", (wc_idx, wc_name)))

                            step_hyp.dec_seq_idx += 1
                            step_hypotheses.append(step_hyp)

                        step_hypotheses = Hypothesis.sort_hypotheses(
                            step_hypotheses, beam_size, 0.0)
                        new_hypotheses.extend(step_hypotheses)

                    elif vet[0] == "wo":
                        wc_idx, wc_name = vet[1]
                        chosen_col_gt = [[wc_idx]]

                        x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(
                            col, dropout_rate=dropout_rate)
                        cond_op_score = self.cond_pred.op_forward(
                            x_emb_var,
                            x_len,
                            col_inp_var,
                            col_name_len,
                            col_len,
                            chosen_col_gt,
                            dropout_rate=dropout_rate).view(
                                1, 4, -1)  #[B=1, 4, |OPS|]
                        prob_wo = self.softmax(
                            cond_op_score[:, 0, :]).data.cpu().numpy()[0]

                        if len(hyp.dec_prefix):
                            partial_vet, wo_idx = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                            wo_candidates = [wo_idx]
                        else:
                            wo_candidates = np.argsort(-prob_wo)

                            if avoid_items is not None and hyp.dec_seq_idx in avoid_items:
                                wo_candidates = [
                                    wo_idx for wo_idx in wo_candidates if
                                    wo_idx not in avoid_items[hyp.dec_seq_idx]
                                ]
                            wo_candidates = wo_candidates[:beam_size]

                        for wo_idx in wo_candidates:
                            if len(wo_candidates) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()
                            wo_name = COND_OPS[wo_idx]

                            step_hyp.dec_seq.append((vet, wo_idx))
                            step_hyp.tag_seq.append(
                                (WHERE_OP, ((table_name, wc_name,
                                             wc_idx), ), (wo_name, wo_idx),
                                 prob_wo[wo_idx], step_hyp.dec_seq_idx))
                            step_hyp.add_logprob(np.log(prob_wo[wo_idx]))
                            step_hyp.stack.push(
                                ("wv", (wc_idx, wc_name, wo_idx, wo_name)))
                            step_hyp.dec_seq_idx += 1

                            new_hypotheses.append(step_hyp)

                    elif vet[0] == "wv":
                        wc_idx, wc_name, wo_idx, wo_name = vet[1]
                        x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(
                            q, col, dropout_rate=dropout_rate)
                        col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(
                            col, dropout_rate=dropout_rate)

                        given_idxes, avoid_idxes_list = None, None
                        if len(hyp.dec_prefix):
                            partial_vet, given_idxes = hyp.dec_prefix.pop()
                            assert partial_vet == vet
                        elif avoid_items is not None and hyp.dec_seq_idx in avoid_items:
                            avoid_idxes_list = list(
                                avoid_items[hyp.dec_seq_idx])

                        str_idxes_prob_pairs = self.cond_pred.val_beam_search(
                            x_emb_var,
                            x_len,
                            col_inp_var,
                            col_name_len,
                            col_len, [[wc_idx]],
                            beam_size,
                            avoid_idxes_list=avoid_idxes_list,
                            given_idxes=given_idxes,
                            dropout_rate=dropout_rate)

                        all_toks = ['<BEG>'] + q[0] + ['<END>']
                        for str_idxes, logprob in str_idxes_prob_pairs:
                            if len(str_idxes_prob_pairs) == 1:
                                step_hyp = hyp
                            else:
                                step_hyp = hyp.copy()

                            # get val_str
                            cur_cond_str_toks = []
                            for wd_idx in str_idxes[1:]:
                                str_val = all_toks[wd_idx]
                                if str_val == '<END>':
                                    break
                                cur_cond_str_toks.append(str_val)
                            val_str = SQLNet.merge_tokens(
                                cur_cond_str_toks, raw_q[0])

                            step_hyp.sql_i['conds'].append(
                                [wc_idx, wo_idx, val_str])
                            step_hyp.dec_seq.append((vet, str_idxes))
                            step_hyp.tag_seq.append(
                                (WHERE_VAL, ((table_name, wc_name, wc_idx), ),
                                 (wo_name, wo_idx), (str_idxes, val_str),
                                 np.exp(logprob), hyp.dec_seq_idx))
                            step_hyp.add_logprob(logprob)
                            step_hyp.dec_seq_idx += 1

                            new_hypotheses.append(step_hyp)

            if len(new_hypotheses) == 0:
                # sort completed hypotheses
                sorted_completed_hypotheses = Hypothesis.sort_hypotheses(
                    completed_hypotheses, beam_size, 0.0)
                return sorted_completed_hypotheses

            # if bool_verbal:
            #     print("Before sorting...")
            #     Hypothesis.print_hypotheses(new_hypotheses)
            hypotheses = Hypothesis.sort_hypotheses(new_hypotheses, beam_size,
                                                    0.0)
            if bool_verbal:
                print("\nAfter sorting...")
                Hypothesis.print_hypotheses(hypotheses)

            if stop_step is not None:  # for one-step beam search; the partial_seq lengths must be the same for all hyps
                dec_seq_length = len(hypotheses[0].dec_seq)
                if dec_seq_length == stop_step + 1:
                    for hyp in hypotheses:
                        assert len(hyp.dec_seq) == dec_seq_length
                    return hypotheses

    def loss(self, score, truth_num, pred_entry, gt_where):
        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score

        loss = 0
        loss_agg, loss_sel, loss_cond = 0., 0., 0.
        if pred_agg:
            agg_truth = map(lambda x: x[0], truth_num)
            data = torch.from_numpy(np.array(agg_truth))
            if self.gpu:
                agg_truth_var = Variable(data.cuda())
            else:
                agg_truth_var = Variable(data)

            loss_agg = self.CE(agg_score, agg_truth_var)
            loss += loss_agg

        if pred_sel:
            sel_truth = map(lambda x: x[1], truth_num)
            data = torch.from_numpy(np.array(sel_truth))
            if self.gpu:
                sel_truth_var = Variable(data.cuda())
            else:
                sel_truth_var = Variable(data)

            loss_sel = self.CE(sel_score, sel_truth_var)
            loss += loss_sel

        if pred_cond:
            B = len(truth_num)
            cond_num_score, cond_col_score,\
                    cond_op_score, cond_str_score = cond_score
            #Evaluate the number of conditions
            cond_num_truth = map(lambda x: x[2], truth_num)
            data = torch.from_numpy(np.array(cond_num_truth))
            if self.gpu:
                cond_num_truth_var = Variable(data.cuda())
            else:
                cond_num_truth_var = Variable(data)

            cond_num_loss = self.CE(cond_num_score, cond_num_truth_var)
            loss_cond += cond_num_loss
            loss += cond_num_loss

            #Evaluate the columns of conditions
            T = len(cond_col_score[0])
            truth_prob = np.zeros((B, T), dtype=np.float32)
            for b in range(B):
                if len(truth_num[b][3]) > 0:
                    truth_prob[b][list(truth_num[b][3])] = 1
            data = torch.from_numpy(truth_prob)
            if self.gpu:
                cond_col_truth_var = Variable(data.cuda())
            else:
                cond_col_truth_var = Variable(data)

            sigm = nn.Sigmoid()
            cond_col_prob = sigm(cond_col_score)
            bce_loss = -torch.mean( 3*(cond_col_truth_var * \
                    torch.log(cond_col_prob+1e-10)) + \
                    (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) )
            loss_cond += bce_loss
            loss += bce_loss

            #Evaluate the operator of conditions
            for b in range(len(truth_num)):
                if len(truth_num[b][4]) == 0:
                    continue
                data = torch.from_numpy(np.array(truth_num[b][4]))
                if self.gpu:
                    cond_op_truth_var = Variable(data.cuda())
                else:
                    cond_op_truth_var = Variable(data)
                cond_op_pred = cond_op_score[b, :len(truth_num[b][4])]
                cond_op_loss_b = (self.CE(cond_op_pred, cond_op_truth_var) /
                                  len(truth_num))
                loss_cond += cond_op_loss_b
                loss += cond_op_loss_b

            #Evaluate the strings of conditions
            for b in range(len(gt_where)):
                for idx in range(len(gt_where[b])):
                    cond_str_truth = gt_where[b][idx]
                    if len(cond_str_truth) == 1:
                        continue
                    data = torch.from_numpy(np.array(cond_str_truth[1:]))
                    if self.gpu:
                        cond_str_truth_var = Variable(data.cuda())
                    else:
                        cond_str_truth_var = Variable(data)
                    str_end = len(cond_str_truth) - 1
                    cond_str_pred = cond_str_score[b, idx, :str_end]
                    cond_str_loss_b = (
                        self.CE(cond_str_pred, cond_str_truth_var) /
                        (len(gt_where) * len(gt_where[b])))
                    loss_cond += cond_str_loss_b
                    loss += cond_str_loss_b

        if self.temperature:
            return [loss, loss_sel, loss_agg, loss_cond]

        return [loss]

    def check_acc(self, vis_info, pred_queries, gt_queries, pred_entry):
        """Count prediction errors of each query component against ground truth.

        Args:
            vis_info: unused; kept so existing call sites keep working.
            pred_queries: predicted query dicts with keys ``'agg'``, ``'sel'``
                and ``'conds'``; each condition is ``[col, op, value]``.
            gt_queries: ground-truth query dicts in the same format.
            pred_entry: ``(pred_agg, pred_sel, pred_cond)`` flags selecting
                which components are compared.

        Returns:
            ``(np.array((agg_err, sel_err, cond_err)), tot_err)``, where
            ``tot_err`` counts examples with at least one wrong component.
        """
        # ``unicode`` on Python 2, ``str`` on Python 3.
        text_type = type(u'')

        def conds_match(cond_pred, cond_gt):
            # True iff both condition lists describe the same WHERE clause,
            # ignoring condition order.
            if len(cond_pred) != len(cond_gt):
                return False
            gt_cols = tuple(x[0] for x in cond_gt)
            if set(x[0] for x in cond_pred) != set(gt_cols):
                return False
            # The column sets are equal here, so .index() always succeeds.
            for pred in cond_pred:
                if cond_gt[gt_cols.index(pred[0])][1] != pred[1]:
                    return False
            # Condition values are compared as case-insensitive text.
            for pred in cond_pred:
                gt_val = cond_gt[gt_cols.index(pred[0])][2]
                if text_type(gt_val).lower() != text_type(pred[2]).lower():
                    return False
            return True

        pred_agg, pred_sel, pred_cond = pred_entry

        tot_err = agg_err = sel_err = cond_err = 0.0
        for pred_qry, gt_qry in zip(pred_queries, gt_queries):
            good = True
            if pred_agg and pred_qry['agg'] != gt_qry['agg']:
                agg_err += 1
                good = False

            if pred_sel and pred_qry['sel'] != gt_qry['sel']:
                sel_err += 1
                good = False

            if pred_cond and not conds_match(pred_qry['conds'],
                                             gt_qry['conds']):
                cond_err += 1
                good = False

            if not good:
                tot_err += 1

        return np.array((agg_err, sel_err, cond_err)), tot_err

    @staticmethod
    def merge_tokens(tok_list, raw_tok_str):
        tok_str = raw_tok_str.lower()
        alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
        special = {
            '-LRB-': '(',
            '-RRB-': ')',
            '-LSB-': '[',
            '-RSB-': ']',
            '``': '"',
            '\'\'': '"',
            '--': u'\u2013'
        }
        ret = ''
        double_quote_appear = 0
        for raw_tok in tok_list:
            if not raw_tok:
                continue
            tok = special.get(raw_tok, raw_tok)
            if tok == '"':
                double_quote_appear = 1 - double_quote_appear

            if len(ret) == 0:
                pass
            elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                ret = ret + ' '
            elif len(ret) > 0 and ret + tok in tok_str:
                pass
            elif tok == '"':
                if double_quote_appear:
                    ret = ret + ' '
            elif tok[0] not in alphabet:
                pass
            elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
                    and (ret[-1] != '"' or not double_quote_appear):
                ret = ret + ' '
            ret = ret + tok
        return ret.strip()

    def gen_query(self,
                  score,
                  q,
                  col,
                  raw_q,
                  raw_col,
                  pred_entry,
                  reinforce=False,
                  verbose=False):
        pred_agg, pred_sel, pred_cond = pred_entry
        agg_score, sel_score, cond_score = score

        ret_queries = []
        if pred_agg:
            B = len(agg_score)
        elif pred_sel:
            B = len(sel_score)
        elif pred_cond:
            B = len(cond_score[0])
        for b in range(B):
            cur_query = {}
            if pred_agg:
                cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
            if pred_sel:
                cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
            if pred_cond:
                cur_query['conds'] = []
                cond_num_score,cond_col_score,cond_op_score,cond_str_score =\
                        [x.data.cpu().numpy() for x in cond_score]
                cond_num = np.argmax(cond_num_score[b])
                all_toks = ['<BEG>'] + q[b] + ['<END>']
                max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
                for idx in range(cond_num):
                    cur_cond = []
                    cur_cond.append(max_idxes[idx])
                    cur_cond.append(np.argmax(cond_op_score[b][idx]))
                    cur_cond_str_toks = []
                    for str_score in cond_str_score[b][idx]:
                        str_tok = np.argmax(str_score[:len(all_toks)])
                        str_val = all_toks[str_tok]
                        if str_val == '<END>':
                            break
                        cur_cond_str_toks.append(str_val)
                    cur_cond.append(
                        SQLNet.merge_tokens(cur_cond_str_toks, raw_q[b]))
                    cur_query['conds'].append(cur_cond)
            ret_queries.append(cur_query)

        return ret_queries