Example #1
0
class Seq2SQL(nn.Module):
    """Sequence-to-SQL model with aggregation, SELECT and WHERE heads.

    Most methods take a ``classif_flag`` triple ``(agg, sel, whr)`` of
    booleans that switches each of the three heads on or off.
    """

    def __init__(self,
                 word_emb,
                 num_words,
                 num_hidden=100,
                 num_layers=2,
                 use_gpu=True):
        """Create the embedding layer and the three sub-classifiers.

        :param word_emb: word embeddings handed to ``WordEmbedding``
        :param num_words: forwarded to ``WordEmbedding`` and to every
            classifier (presumably the embedding width -- TODO confirm)
        :param num_hidden: hidden size forwarded to the classifiers
        :param num_layers: layer count forwarded to the classifiers
        :param use_gpu: when true the module is moved to CUDA
        """
        super(Seq2SQL, self).__init__()

        self.word_emb = word_emb
        self.num_words = num_words
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        self.use_gpu = use_gpu

        self.max_col_num = 45
        self.max_tok_num = 200
        self.COND_OPS = ['EQL', 'GT', 'LT']
        self.SQL_TOK = ['<UNK>', '<BEG>', '<END>', 'WHERE', 'AND'
                        ] + self.COND_OPS

        # GloVe Word Embedding
        self.embed_layer = WordEmbedding(word_emb, num_words, self.SQL_TOK,
                                         use_gpu)

        # Aggregation Classifier
        self.agg_classifier = AggregationClassifier(num_words, num_hidden,
                                                    num_layers)

        # SELECT Column(s)
        self.sel_classifier = SelectClassifier(num_words, num_hidden,
                                               num_layers, self.max_tok_num)

        # WHERE Clause
        self.whr_classifier = WhereClassifier(num_words, num_hidden,
                                              num_layers, self.max_col_num,
                                              self.max_tok_num, use_gpu)

        # run on GPU
        if use_gpu:
            self.cuda()

    def generate_g_s(self, q, col, query):
        """Encode the WHERE part of each gold query as token indices.

        The index space is ``SQL_TOK + column tokens (each column followed
        by ',') + [None] + question tokens + [None]``.  Every sequence has
        the form ``<BEG> WHERE col op val AND ... <END>``; a query without
        ``WHERE`` becomes ``[<BEG>, <END>]``.  Tokens missing from the index
        space map to 0 (``<UNK>``).

        :param q: batch of tokenized questions (list of token lists)
        :param col: batch of tokenized column names
        :param query: batch of tokenized gold SQL queries
        :return: list with one index sequence per example
        """
        ret_seq = []
        for cur_q, cur_col, cur_query in zip(q, col, query):
            connect_col = [
                tok for col_tok in cur_col for tok in col_tok + [',']
            ]
            all_toks = self.SQL_TOK + connect_col + [None] + cur_q + [None]
            # First-occurrence lookup table; same result as list.index()
            # without rescanning the list for every token.
            tok_to_idx = {}
            for idx, tok in enumerate(all_toks):
                tok_to_idx.setdefault(tok, idx)

            cur_seq = [tok_to_idx['<BEG>']]
            if 'WHERE' in cur_query:
                where_toks = cur_query[cur_query.index('WHERE'):]
                # A comprehension instead of `list + map(...)`, which is a
                # TypeError on Python 3 where map() returns an iterator.
                cur_seq.extend(tok_to_idx.get(tok, 0) for tok in where_toks)
            cur_seq.append(tok_to_idx['<END>'])
            ret_seq.append(cur_seq)
        return ret_seq

    def forward(self,
                q,
                col,
                col_num,
                classif_flag,
                g_s=None,
                reinforce=False):
        """Run the enabled heads on a batch.

        :param q: batch of tokenized questions
        :param col: batch of tokenized column names
        :param col_num: forwarded to the SELECT classifier
        :param classif_flag: ``(agg, sel, whr)`` booleans
        :param g_s: gold WHERE sequences, forwarded to the WHERE classifier
        :param reinforce: forwarded to the WHERE classifier
        :return: ``(agg_score, sel_score, whr_score)``; disabled heads
            yield ``None``
        """
        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = None, None, None

        x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)

        if agg_classif:
            agg_score = self.agg_classifier(x_emb_var, x_len)

        if sel_classif:
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
                col)
            sel_score = self.sel_classifier(x_emb_var, x_len, col_inp_var,
                                            col_name_len, col_len, col_num)

        if whr_classif:
            whr_score = self.whr_classifier(x_emb_var,
                                            x_len,
                                            g_s,
                                            reinforce=reinforce)

        return (agg_score, sel_score, whr_score)

    def _to_var(self, values):
        """Wrap a sequence of class indices as a long Variable on the
        configured device."""
        var = Variable(torch.from_numpy(np.array(values)).long())
        if self.use_gpu:
            var = var.cuda()
        return var

    def loss(self, score, ref_score, classif_flag, g_s):
        """Sum the cross-entropy losses of the enabled heads.

        :param score: ``(agg_score, sel_score, whr_score)`` from forward
        :param ref_score: per-example references; item 0 is the gold
            aggregation index and item 1 the gold select index
        :param classif_flag: ``(agg, sel, whr)`` booleans
        :param g_s: gold WHERE index sequences (see ``generate_g_s``)
        :return: the summed loss (the int 0 when every head is disabled)
        """
        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = score
        criterion = nn.CrossEntropyLoss()
        loss = 0
        if agg_classif:
            # List comprehensions replace np.array(map(...)), which yields
            # a 0-d object array on Python 3 and breaks torch.from_numpy.
            agg_ref_var = self._to_var([x[0] for x in ref_score])
            loss += criterion(agg_score, agg_ref_var)

        if sel_classif:
            sel_ref_var = self._to_var([x[1] for x in ref_score])
            loss += criterion(sel_score, sel_ref_var)

        if whr_classif:
            g_s_len = len(g_s)
            for s, g_s_i in enumerate(g_s):
                # The target at step t is the gold token at t + 1, so the
                # leading <BEG> is dropped from the reference.
                whr_ref_var = self._to_var(g_s_i[1:])
                loss += (criterion(whr_score[s, :len(g_s_i) - 1],
                                   whr_ref_var) / g_s_len)

        return loss

    def reinforce_backward(self, score, rewards):
        """Feed per-example rewards to the sampled WHERE tokens and
        back-propagate.

        NOTE(review): this relies on the legacy stochastic-node API
        (``.reinforce``), which newer PyTorch releases no longer provide --
        confirm the targeted PyTorch version before using this path.

        :param score: ``(agg_score, sel_score, whr_score)``; only
            ``whr_score[1]`` (the per-step sampled choices) is used
        :param rewards: one reward per example in the batch
        """
        agg_score, sel_score, whr_score = score

        cur_reward = rewards[:]
        eof = self.SQL_TOK.index('<END>')
        for whr_score_t in whr_score[1]:
            reward_inp = torch.FloatTensor(cur_reward).unsqueeze(1)
            if self.use_gpu:
                reward_inp = reward_inp.cuda()
            whr_score_t.reinforce(reward_inp)

            # Once an example has emitted <END>, later steps get no reward.
            for b, _ in enumerate(rewards):
                if whr_score_t[b].data.cpu().numpy()[0] == eof:
                    cur_reward[b] = 0
        torch.autograd.backward(whr_score[1], [None for _ in whr_score[1]])
        return

    @staticmethod
    def _conds_match(pred_conds, gold_conds, text_type):
        """Return True when predicted and gold WHERE conditions agree.

        Checks run in this order: condition count, set of columns, operator
        of every condition, then the value of every condition compared
        case-insensitively as text.
        """
        if len(pred_conds) != len(gold_conds):
            return False
        gold_cols = [cond[0] for cond in gold_conds]
        if set(cond[0] for cond in pred_conds) != set(gold_cols):
            return False
        # Pair each prediction with the first gold condition on its column.
        pairs = [(pred, gold_conds[gold_cols.index(pred[0])])
                 for pred in pred_conds]
        if any(gold[1] != pred[1] for pred, gold in pairs):
            return False
        return all(
            text_type(gold[2]).lower() == text_type(pred[2]).lower()
            for pred, gold in pairs)

    def check_acc(self, classif_queries, g_s_queries, classif_flag):
        """Count prediction errors per head and per whole query.

        :param classif_queries: predicted queries (dicts with the keys
            ``'agg'``, ``'sel'`` and ``'conds'`` of the enabled heads)
        :param g_s_queries: gold queries in the same format
        :param classif_flag: ``(agg, sel, whr)`` booleans; disabled heads
            are never counted as wrong
        :return: ``(np.array((agg_err, sel_err, whr_err)), tot_err)``
        """
        try:
            text_type = unicode  # Python 2
        except NameError:
            text_type = str  # Python 3

        agg_classif, sel_classif, whr_classif = classif_flag
        tot_err = agg_err = sel_err = whr_err = 0.0
        for classif_qry, g_s_qry in zip(classif_queries, g_s_queries):
            # All three start as "correct" so that a disabled WHERE head
            # cannot leave its state undefined (previously a NameError).
            agg_wrong = sel_wrong = whr_wrong = False

            if agg_classif and classif_qry['agg'] != g_s_qry['agg']:
                agg_wrong = True
                agg_err += 1

            if sel_classif and classif_qry['sel'] != g_s_qry['sel']:
                sel_wrong = True
                sel_err += 1

            if whr_classif and not self._conds_match(
                    classif_qry['conds'], g_s_qry['conds'], text_type):
                whr_wrong = True
                whr_err += 1

            if agg_wrong or sel_wrong or whr_wrong:
                tot_err += 1

        return np.array((agg_err, sel_err, whr_err)), tot_err

    def gen_query(self,
                  score,
                  q,
                  col,
                  raw_q,
                  raw_col,
                  classif_flag,
                  reinforce=False,
                  verbose=False):
        """Decode model scores into query dictionaries.

        :param score: ``(agg_score, sel_score, whr_score)`` from forward
        :param q: batch of tokenized questions
        :param col: batch of tokenized column names
        :param raw_q: batch of raw question strings
        :param raw_col: batch of raw column-name lists
        :param classif_flag: ``(agg, sel, whr)`` booleans
        :param reinforce: when true ``whr_score[1]`` holds per-step sampled
            token choices instead of per-token scores
        :param verbose: print the decoded WHERE tokens of every example
        :return: list of dicts with ``'agg'``, ``'sel'`` and ``'conds'``
            (a list of ``[column index, operator index, value string]``)
            for the enabled heads
        """
        def merge_tokens(tok_list, raw_tok_str):
            # Re-join tokens into a string, using the raw text to decide
            # where spaces belong.
            tok_str = raw_tok_str.lower()
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_pair_track = 0
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_pair_track = 1 - double_quote_pair_track
                    if double_quote_pair_track:
                        ret = ret + ' '
                if len(ret) == 0:
                    pass
                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                    ret = ret + ' '
                elif len(ret) > 0 and ret + tok in tok_str:
                    pass
                elif (tok[0] not in string.ascii_lowercase) and (
                        tok[0] not in string.digits) and (tok[0] not in '$('):
                    pass
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) and \
                     (ret[-1] != '"' or not double_quote_pair_track):
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()

        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = score

        # Take the batch size from the first available score.
        if agg_classif:
            batch_len = len(agg_score)
        elif sel_classif:
            batch_len = len(sel_score)
        elif reinforce:
            batch_len = len(whr_score[0])
        else:
            batch_len = len(whr_score)

        ret_queries = []
        for b in range(batch_len):
            cur_query = {}
            if agg_classif:
                cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
            if sel_classif:
                cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
            if whr_classif:
                cur_query['conds'] = []
                all_toks = self.SQL_TOK + [
                    x for toks in col[b] for x in toks + [',']
                ] + [''] + q[b] + ['']

                if reinforce:
                    tok_ids = (int(choices[b].data.cpu().numpy()[0])
                               for choices in whr_score[1])
                else:
                    tok_ids = (int(np.argmax(where_score))
                               for where_score in
                               whr_score[b].data.cpu().numpy())

                whr_toks = []
                for tok_id in tok_ids:
                    # Ids beyond the vocabulary of this example decode to
                    # <UNK> in both modes instead of raising IndexError.
                    if tok_id < len(all_toks):
                        whr_val = all_toks[tok_id]
                    else:
                        whr_val = '<UNK>'
                    if whr_val == '<END>':
                        break
                    whr_toks.append(whr_val)

                if verbose:
                    # Parenthesised form is valid on Python 2 and 3.
                    print(whr_toks)

                # Drop the leading token ('WHERE' in the expected format).
                whr_toks = whr_toks[1:]
                if whr_toks:
                    to_idx = [x.lower() for x in raw_col[b]]
                    col_context = raw_q[b] + ' || ' + ' || '.join(raw_col[b])

                st = 0
                while st < len(whr_toks):
                    cur_cond = [None, None, None]
                    # A condition runs up to the next 'AND' or to the end.
                    try:
                        ed = whr_toks.index('AND', st)
                    except ValueError:
                        ed = len(whr_toks)
                    segment = whr_toks[st:ed]

                    # The first operator found, in COND_OPS priority order,
                    # splits the column from the value.
                    for op_idx, op_tok in enumerate(self.COND_OPS):
                        if op_tok in segment:
                            op = segment.index(op_tok) + st
                            cur_cond[1] = op_idx
                            break
                    else:
                        # No operator: empty column, default operator 0.
                        op = st
                        cur_cond[1] = 0

                    classif_col = merge_tokens(whr_toks[st:op], col_context)
                    if classif_col in to_idx:
                        cur_cond[0] = to_idx.index(classif_col)
                    else:
                        cur_cond[0] = 0
                    cur_cond[2] = merge_tokens(whr_toks[op + 1:ed], raw_q[b])
                    cur_query['conds'].append(cur_cond)
                    st = ed + 1
            ret_queries.append(cur_query)

        return ret_queries
Example #2
0
class SQLNet(nn.Module):
    """SQLNet-style model predicting SELECT columns, aggregations, WHERE
    conditions and the relation joining the conditions."""

    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=100,
                 N_depth=2,
                 gpu=False,
                 use_ca=True,
                 trainable_emb=False):
        """Create the embedding layer and every sub-predictor.

        :param word_emb: word embeddings handed to ``WordEmbedding``
        :param N_word: forwarded to ``WordEmbedding`` and every predictor
            (presumably the embedding width -- TODO confirm)
        :param N_h: hidden size forwarded to the predictors
        :param N_depth: depth forwarded to the predictors
        :param gpu: when true the module and loss targets are moved to CUDA
        :param use_ca: forwarded to the predictors
        :param trainable_emb: forwarded to ``WordEmbedding``; also selects
            the alternative branch of ``forward``
        """
        super(SQLNet, self).__init__()
        self.use_ca = use_ca
        self.trainable_emb = trainable_emb

        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth

        self.max_col_num = 45
        self.max_tok_num = 200
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=',
            '<BEG>'
        ]
        self.COND_OPS = ['>', '<', '==', '!=']

        # Word embeddings: they can be trained here or loaded pre-trained;
        # the pre-loaded vectors are used.
        self.embed_layer = WordEmbedding(word_emb,
                                         N_word,
                                         gpu,
                                         self.SQL_TOK,
                                         our_model=True,
                                         trainable=trainable_emb)

        # Predicts how many columns are selected.
        self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)

        # Predicts which columns are selected.
        self.sel_pred = SelPredictor(N_word,
                                     N_h,
                                     N_depth,
                                     self.max_tok_num,
                                     use_ca=use_ca)

        # Predicts the aggregation function of each selected column.
        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)

        # Predicts the number, column, operator and value of conditions.
        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth,
                                             self.max_col_num,
                                             self.max_tok_num, use_ca, gpu)

        # Predicts the relation between conditions, e.g. "and" / "or".
        self.where_rela_pred = WhereRelationPredictor(N_word,
                                                      N_h,
                                                      N_depth,
                                                      use_ca=use_ca)

        self.CE = nn.CrossEntropyLoss()  # cross-entropy loss
        self.softmax = nn.Softmax(dim=-1)
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()
        if gpu:
            self.cuda()

    def generate_gt_where_seq_test(self, q, gt_cond_seq):
        """Build the gold value-token index sequence of every condition.

        Positions refer to ``['<BEG>'] + question + ['<END>']``.  The
        question tokens are joined without separator and the value is
        located with ``str.index``, so the offsets equal token positions
        only if each token is one character -- TODO confirm tokenization.

        :param q: batch of tokenized questions
        :param gt_cond_seq: per question, the gold conditions; item 2 of a
            condition is its value string
        :return: per question, one entry per condition (same order):
            ``[0, value positions..., <END> position]`` when the value
            occurs in the question, otherwise the length-1 marker
            ``[[0, <END> position]]`` (``loss`` skips length-1 entries)
        """
        ret_seq = []
        for cur_q, ans in zip(q, gt_cond_seq):
            temp_q = u"".join(cur_q)
            # Index of <END> once <BEG> and <END> surround the question.
            end_idx = len(cur_q) + 1
            record_cond = []
            for cond in ans:
                value = cond[2]
                if value in temp_q:
                    start = temp_q.index(value)
                    # +1 accounts for the leading <BEG>.
                    cur_seq = [0] + list(
                        range(start + 1, start + len(value) + 1)) + [end_idx]
                else:
                    cur_seq = [[0, end_idx]]
                # Every condition is recorded.  Appending only in the
                # not-found branch dropped all values found in the question.
                record_cond.append(cur_seq)
            ret_seq.append(record_cond)
        return ret_seq

    def forward(self,
                q,
                col,
                col_num,
                gt_where=None,
                gt_cond=None,
                reinforce=False,
                gt_sel=None,
                gt_sel_num=None):
        """Score every part of the query for a batch.

        :param q: batch of tokenized questions
        :param col: batch of tokenized headers
        :param col_num: number of headers of each table
        :param gt_where: gold value sequences, forwarded to ``cond_pred``
        :param gt_cond: gold conditions, forwarded to ``cond_pred``
        :param reinforce: forwarded to ``cond_pred``
        :param gt_sel: gold selected columns; when falsy the predicted
            columns are used for the aggregation head
        :param gt_sel_num: gold number of selected columns; when falsy the
            predicted number is used
        :return: ``(sel_num_score, sel_score, agg_score, cond_score,
            where_rela_score)``
        """
        sel_num_score = None
        agg_score = None
        sel_score = None
        cond_score = None
        if self.trainable_emb:
            # NOTE(review): agg_embed_layer, sel_embed_layer and
            # cond_embed_layer are not created in __init__ as shown here,
            # so this branch looks like it raises AttributeError -- confirm.
            # It also leaves sel_num_score and where_rela_score as None.
            x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(
                col)
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=gt_sel)

            x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(
                col)
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)

            x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(
                col)
            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)
            where_rela_score = None
        else:
            # x_emb_var: embedding of each question
            # x_len: length of each question
            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
            # col_inp_var: embedding of each header
            # col_name_len: length of each header
            # col_len: number of headers in each table, array type
            # (col_num holds the same counts as a list)
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
                col)
            sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var,
                                         col_name_len, col_len, col_num)
            if gt_sel_num:
                pr_sel_num = gt_sel_num
            else:
                pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(),
                                       axis=1)
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)

            if gt_sel:
                pr_sel = gt_sel
            else:
                # Keep the top-`num` scoring columns, `num` being the
                # predicted column count.
                num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
                sel = sel_score.data.cpu().numpy()
                pr_sel = [
                    list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))
                ]
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=pr_sel,
                                      gt_sel_num=pr_sel_num)

            where_rela_score = self.where_rela_pred(x_emb_var, x_len,
                                                    col_inp_var, col_name_len,
                                                    col_len, col_num)

            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)

        return (sel_num_score, sel_score, agg_score, cond_score,
                where_rela_score)

    def _as_var(self, tensor):
        """Move a tensor to the configured device and wrap it as Variable."""
        if self.gpu:
            tensor = tensor.cuda()
        return Variable(tensor)

    def _long_var(self, values):
        """Turn a sequence of class indices into a long target Variable."""
        return self._as_var(torch.from_numpy(np.array(values)).long())

    def _multi_hot_var(self, index_lists, width):
        """Build a float ``(len(index_lists), width)`` target with ones at
        the listed column indices of every row."""
        truth = np.zeros((len(index_lists), width), dtype=np.float32)
        for row, idxs in enumerate(index_lists):
            if len(idxs) > 0:
                truth[row][list(idxs)] = 1
        return self._as_var(torch.from_numpy(truth))

    @staticmethod
    def _weighted_bce(score, truth_var):
        """Binary cross-entropy with positives weighted by 3:
        ``-mean(3 * y * log(p) + (1 - y) * log(1 - p))``."""
        prob = nn.Sigmoid()(score)
        return -torch.mean(
            3 * (truth_var * torch.log(prob + 1e-10)) +
            (1 - truth_var) * torch.log(1 - prob + 1e-10))

    def loss(self, score, truth_num, gt_where):
        """Sum the losses of all sub-tasks.

        Tensor conversion errors now propagate to the caller instead of
        being caught and turned into ``exit(0)``.

        :param score: the tuple returned by ``forward``
        :param truth_num: per example a sequence whose items are, by index:
            0 number of selected columns, 1 selected columns,
            2 aggregation of each selected column, 3 number of conditions,
            4 condition columns, 5 condition operators,
            6 relation between the conditions
        :param gt_where: gold value sequences (see
            ``generate_gt_where_seq_test``)
        :return: the summed loss
        """
        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
        cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score

        batch_size = len(truth_num)
        loss = 0

        # Number of selected columns.
        loss += self.CE(sel_num_score,
                        self._long_var([x[0] for x in truth_num]))

        # Which columns are selected (multi-label).
        sel_truth = self._multi_hot_var([x[1] for x in truth_num],
                                        len(sel_score[0]))
        loss += self._weighted_bce(sel_score, sel_truth)

        # Aggregation of every selected column, averaged over the batch.
        for b in range(batch_size):
            sel_agg_truth_var = self._long_var(truth_num[b][2])
            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
            loss += self.CE(sel_agg_pred, sel_agg_truth_var) / batch_size

        # Number of conditions.
        loss += self.CE(cond_num_score,
                        self._long_var([x[3] for x in truth_num]))

        # Condition columns (multi-label).
        cond_col_truth = self._multi_hot_var([x[4] for x in truth_num],
                                             len(cond_col_score[0]))
        loss += self._weighted_bce(cond_col_score, cond_col_truth)

        # Condition operators; examples without conditions are skipped.
        for b in range(batch_size):
            if len(truth_num[b][5]) == 0:
                continue
            cond_op_truth_var = self._long_var(truth_num[b][5])
            cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
            loss += self.CE(cond_op_pred, cond_op_truth_var) / batch_size

        # Condition value strings.
        for b in range(len(gt_where)):
            for idx in range(len(gt_where[b])):
                cond_str_truth = gt_where[b][idx]
                # Length-1 entries mark values absent from the question.
                if len(cond_str_truth) == 1:
                    continue
                # The target at step t is the gold position at t + 1.
                cond_str_truth_var = self._long_var(cond_str_truth[1:])
                str_end = len(cond_str_truth) - 1
                cond_str_pred = cond_str_score[b, idx, :str_end]
                loss += (self.CE(cond_str_pred, cond_str_truth_var) /
                         (len(gt_where) * len(gt_where[b])))

        # Relation between conditions (and / or).
        loss += self.CE(where_rela_score,
                        self._long_var([x[6] for x in truth_num]))
        return loss

    def check_acc(self, vis_info, pred_queries, gt_queries):
        """Count prediction errors per sub-task and per whole query.

        :param vis_info: unused
        :param pred_queries: predicted queries (dicts with the keys
            ``'sel'``, ``'agg'``, ``'cond_conn_op'`` and ``'conds'``)
        :param gt_queries: gold queries in the same format
        :return: ``(np.array((sel_num_err, sel_err, agg_err, cond_num_err,
            cond_col_err, cond_op_err, cond_val_err, cond_rela_err)),
            tot_err)``
        """
        tot_err = sel_num_err = agg_err = sel_err = 0.0
        cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0
        for pred_qry, gt_qry in zip(pred_queries, gt_queries):
            good = True
            sel_pred, agg_pred = pred_qry['sel'], pred_qry['agg']
            sel_gt, agg_gt = gt_qry['sel'], gt_qry['agg']

            if gt_qry['cond_conn_op'] != pred_qry['cond_conn_op']:
                good = False
                cond_rela_err += 1

            if len(sel_pred) != len(sel_gt):
                good = False
                sel_num_err += 1

            if set(sel_pred) != set(sel_gt):
                good = False
                sel_err += 1

            # Compare aggregations in column order so that the order in
            # which the columns were emitted does not matter.
            pred_sel_dict = dict(zip(sel_pred, agg_pred))
            gt_sel_dict = dict(zip(sel_gt, agg_gt))
            agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict)]
            agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict)]
            if agg_pred != agg_gt:
                good = False
                agg_err += 1

            cond_pred = pred_qry['conds']
            cond_gt = gt_qry['conds']
            if len(cond_pred) != len(cond_gt):
                good = False
                cond_num_err += 1
            else:
                # column -> (operator, value); a repeated column keeps its
                # last condition.
                pred_by_col = {c[0]: (c[1], c[2]) for c in cond_pred}
                gt_by_col = {c[0]: (c[1], c[2]) for c in cond_gt}
                pred_cols = sorted(pred_by_col)
                gt_cols = sorted(gt_by_col)

                if set(pred_cols) != set(gt_cols):
                    cond_col_err += 1
                    good = False

                if [pred_by_col[x][0] for x in pred_cols] != \
                        [gt_by_col[x][0] for x in gt_cols]:
                    cond_op_err += 1
                    good = False

                if [pred_by_col[x][1] for x in pred_cols] != \
                        [gt_by_col[x][1] for x in gt_cols]:
                    cond_val_err += 1
                    good = False

            if not good:
                tot_err += 1

        return np.array(
            (sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err,
             cond_op_err, cond_val_err, cond_rela_err)), tot_err

    def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
        """Decode model scores into query dictionaries.

        :param score: the tuple returned by ``forward``
        :param q: token-questions
        :param col: token-headers (unused)
        :param raw_q: original question sequence
        :param reinforce: unused
        :param verbose: unused
        :return: list of dicts with ``'sel'``, ``'agg'``,
            ``'cond_conn_op'`` and ``'conds'`` (a list of
            ``[column, operator, value string]``)
        """
        def merge_tokens(tok_list, raw_tok_str):
            # Re-join tokens into a string, using the raw text to decide
            # where spaces belong.
            tok_str = raw_tok_str
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_appear = 0
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_appear = 1 - double_quote_appear
                if len(ret) == 0:
                    pass
                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                    ret = ret + ' '
                elif len(ret) > 0 and ret + tok in tok_str:
                    pass
                elif tok == '"':
                    if double_quote_appear:
                        ret = ret + ' '
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
                        and (ret[-1] != '"' or not double_quote_appear):
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()

        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
        sel_num_score = sel_num_score.data.cpu().numpy()
        sel_score = sel_score.data.cpu().numpy()
        agg_score = agg_score.data.cpu().numpy()
        where_rela_score = where_rela_score.data.cpu().numpy()
        cond_num_score, cond_col_score, cond_op_score, cond_str_score = \
            [x.data.cpu().numpy() for x in cond_score]

        ret_queries = []
        for b in range(len(agg_score)):
            cur_query = {}
            sel_num = np.argmax(sel_num_score[b])
            # Indexes of the `sel_num` highest-scoring columns.
            max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
            # Per selection slot, aggregations ordered best first.
            max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
            cur_query['sel'] = [int(i) for i in max_col_idxes]
            cur_query['agg'] = [i[0] for i in max_agg_idxes]
            cur_query['cond_conn_op'] = np.argmax(where_rela_score[b])
            cur_query['conds'] = []

            cond_num = np.argmax(cond_num_score[b])
            all_toks = ['<BEG>'] + q[b] + ['<END>']
            max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
            for idx in range(cond_num):
                cur_cond = [
                    max_idxes[idx],  # where-col
                    np.argmax(cond_op_score[b][idx]),  # where-op
                ]
                cur_cond_str_toks = []
                for str_score in cond_str_score[b][idx]:
                    # Only positions inside this question are candidates.
                    str_tok = np.argmax(str_score[:len(all_toks)])
                    str_val = all_toks[str_tok]
                    if str_val == '<END>':
                        break
                    cur_cond_str_toks.append(str_val)
                cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))
                cur_query['conds'].append(cur_cond)
            ret_queries.append(cur_query)
        return ret_queries