Example #1
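# Module-level imports assumed by these examples (not shown on the original
# listing); the WordEmbedding and predictor classes come from each example's
# own repository.
import string
import time
import traceback
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
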
class Seq2SQL(nn.Module):
    def __init__(self,
                 word_emb,
                 num_words,
                 num_hidden=100,
                 num_layers=2,
                 use_gpu=True):
        super(Seq2SQL, self).__init__()

        self.word_emb = word_emb
        self.num_words = num_words
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        self.use_gpu = use_gpu

        self.max_col_num = 45
        self.max_tok_num = 200
        self.COND_OPS = ['EQL', 'GT', 'LT']
        self.SQL_TOK = ['<UNK>', '<BEG>', '<END>', 'WHERE', 'AND'
                        ] + self.COND_OPS

        # GloVe Word Embedding
        self.embed_layer = WordEmbedding(word_emb, num_words, self.SQL_TOK,
                                         use_gpu)

        # Aggregation Classifier
        self.agg_classifier = AggregationClassifier(num_words, num_hidden,
                                                    num_layers)

        # SELECT Column(s)
        self.sel_classifier = SelectClassifier(num_words, num_hidden,
                                               num_layers, self.max_tok_num)

        # WHERE Clause
        self.whr_classifier = WhereClassifier(num_words, num_hidden,
                                              num_layers, self.max_col_num,
                                              self.max_tok_num, use_gpu)

        # run on GPU
        if use_gpu:
            self.cuda()

    def generate_g_s(self, q, col, query):
        # data format
        # <BEG> WHERE cond1_col cond1_op cond1
        #         AND cond2_col cond2_op cond2
        #         AND ... <END>

        ret_seq = []
        for cur_q, cur_col, cur_query in zip(q, col, query):
            connect_col = [
                tok for col_tok in cur_col for tok in col_tok + [',']
            ]
            all_toks = self.SQL_TOK + connect_col + [None] + cur_q + [None]
            cur_seq = [all_toks.index('<BEG>')]
            if 'WHERE' in cur_query:
                cur_where_query = cur_query[cur_query.index('WHERE'):]
                cur_seq = cur_seq + [
                    all_toks.index(tok) if tok in all_toks else 0
                    for tok in cur_where_query
                ]
            cur_seq.append(all_toks.index('<END>'))
            ret_seq.append(cur_seq)
        return ret_seq
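
    # Worked example (hypothetical tokens): SQL_TOK yields indices <UNK>=0,
    # <BEG>=1, <END>=2, WHERE=3, AND=4, EQL=5, GT=6, LT=7, so a gold query
    # [..., 'WHERE', 'name', 'EQL', 'bob'] is encoded as
    # [1, 3, all_toks.index('name'), 5, all_toks.index('bob'), 2].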

    def forward(self,
                q,
                col,
                col_num,
                classif_flag,
                g_s=None,
                reinforce=False):

        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = None, None, None

        x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)

        if agg_classif:
            agg_score = self.agg_classifier(x_emb_var, x_len)

        if sel_classif:
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
                col)
            sel_score = self.sel_classifier(x_emb_var, x_len, col_inp_var,
                                            col_name_len, col_len, col_num)

        if whr_classif:
            whr_score = self.whr_classifier(x_emb_var,
                                            x_len,
                                            g_s,
                                            reinforce=reinforce)

        return (agg_score, sel_score, whr_score)

    def loss(self, score, ref_score, classif_flag, g_s):
        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = score
        loss = 0
        if agg_classif:
            agg_ref = torch.from_numpy(
                np.array([x[0] for x in ref_score]))
            agg_ref_var = Variable(agg_ref)
            if self.use_gpu:
                agg_ref_var = agg_ref_var.cuda()
            loss += nn.CrossEntropyLoss()(agg_score, agg_ref_var)

        if sel_classif:
            sel_ref = torch.from_numpy(
                np.array([x[1] for x in ref_score]))
            sel_ref_var = Variable(sel_ref)
            if self.use_gpu:
                sel_ref_var = sel_ref_var.cuda()
            loss += nn.CrossEntropyLoss()(sel_score, sel_ref_var)

        if whr_classif:
            g_s_len = len(g_s)
            for s, g_s_i in enumerate(g_s):
                whr_ref_var = Variable(torch.from_numpy(np.array(g_s_i[1:])))
                if self.use_gpu:
                    whr_ref_var = whr_ref_var.cuda()
                loss += (nn.CrossEntropyLoss()(whr_score[s, :len(g_s_i) - 1],
                                               whr_ref_var) / g_s_len)

        return loss

    def reinforce_backward(self, score, rewards):
        agg_score, sel_score, whr_score = score

        cur_reward = rewards[:]
        eof = self.SQL_TOK.index('<END>')
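        # NOTE: Tensor.reinforce() below is the legacy stochastic-function API
        # (PyTorch <= 0.3); it was removed in later PyTorch releases.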
        for whr_score_t in whr_score[1]:
            reward_inp = torch.FloatTensor(cur_reward).unsqueeze(1)
            if self.use_gpu:
                reward_inp = reward_inp.cuda()
            whr_score_t.reinforce(reward_inp)

            for b, _ in enumerate(rewards):
                if whr_score_t[b].data.cpu().numpy()[0] == eof:
                    cur_reward[b] = 0
        torch.autograd.backward(whr_score[1], [None for _ in whr_score[1]])
        return

    def check_acc(self, classif_queries, g_s_queries, classif_flag):

        agg_classif, sel_classif, whr_classif = classif_flag
        tot_err = agg_err = sel_err = whr_err = 0.0
        whr_num_err = whr_col_err = whr_op_err = whr_val_err = 0.0
        for classif_qry, g_s_qry in zip(classif_queries, g_s_queries):

            agg_err_inc = 1 if agg_classif and classif_qry['agg'] != g_s_qry[
                'agg'] else 0
            agg_err += agg_err_inc

            sel_err_inc = 1 if sel_classif and classif_qry['sel'] != g_s_qry[
                'sel'] else 0
            sel_err += sel_err_inc

            flag = True  # stays True when the WHERE clause is not checked
            if whr_classif:
                whr_classifier = classif_qry['conds']
                whr_g_s = g_s_qry['conds']
                if len(whr_classifier) != len(whr_g_s):
                    flag = False
                    whr_num_err += 1
                elif set(x[0]
                         for x in whr_classifier) != set(x[0]
                                                         for x in whr_g_s):
                    flag = False
                    whr_col_err += 1
                if flag:
                    for whr_class_i in whr_classifier:
                        g_s_idx = tuple(x[0]
                                        for x in whr_g_s).index(whr_class_i[0])
                        if flag and whr_g_s[g_s_idx][1] != whr_class_i[1]:
                            flag = False
                            whr_op_err += 1
                            break
                if flag:
                    for whr_class_i in whr_classifier:
                        g_s_idx = tuple(x[0]
                                        for x in whr_g_s).index(whr_class_i[0])
                        if flag and str(whr_g_s[g_s_idx][2]).lower() != \
                                str(whr_class_i[2]).lower():
                            flag = False
                            whr_val_err += 1
                            break

                if not flag:
                    whr_err += 1

            if agg_err_inc > 0 or sel_err_inc > 0 or not flag:
                tot_err += 1

        return np.array((agg_err, sel_err, whr_err)), tot_err

    def gen_query(self,
                  score,
                  q,
                  col,
                  raw_q,
                  raw_col,
                  classif_flag,
                  reinforce=False,
                  verbose=False):
        def merge_tokens(tok_list, raw_tok_str):
            tok_str = raw_tok_str.lower()
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_pair_track = 0
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_pair_track = 1 - double_quote_pair_track
                    if double_quote_pair_track:
                        ret = ret + ' '
                if len(ret) == 0:
                    pass
                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                    ret = ret + ' '
                elif len(ret) > 0 and ret + tok in tok_str:
                    pass
                elif (tok[0] not in string.ascii_lowercase) and (
                        tok[0] not in string.digits) and (tok[0] not in '$('):
                    pass
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) and \
                     (ret[-1] != '"' or not double_quote_pair_track):
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()
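
        # e.g. merge_tokens(['what', "'s", 'the', 'name'], "what's the name")
        # returns "what's the name": tokens are re-joined with or without
        # spaces by checking substrings of the lower-cased raw string.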

        agg_classif, sel_classif, whr_classif = classif_flag
        agg_score, sel_score, whr_score = score

        ret_queries = []
        if agg_classif:
            batch_len = len(agg_score)
        elif sel_classif:
            batch_len = len(sel_score)
        elif reinforce:
            batch_len = len(whr_score[0])
        else:
            batch_len = len(whr_score)
        for b in range(batch_len):
            cur_query = {}
            if agg_classif:
                cur_query['agg'] = np.argmax(agg_score[b].data.cpu().numpy())
            if sel_classif:
                cur_query['sel'] = np.argmax(sel_score[b].data.cpu().numpy())
            if whr_classif:
                cur_query['conds'] = []
                all_toks = self.SQL_TOK + [
                    x for toks in col[b] for x in toks + [',']
                ] + [''] + q[b] + ['']
                whr_toks = []
                if reinforce:
                    for choices in whr_score[1]:
                        if choices[b].data.cpu().numpy()[0] < len(all_toks):
                            whr_val = all_toks[choices[b].data.cpu().numpy()
                                               [0]]
                        else:
                            whr_val = '<UNK>'
                        if whr_val == '<END>':
                            break
                        whr_toks.append(whr_val)
                else:
                    for where_score in whr_score[b].data.cpu().numpy():
                        whr_tok = np.argmax(where_score)
                        whr_val = all_toks[whr_tok]
                        if whr_val == '<END>':
                            break
                        whr_toks.append(whr_val)

                if verbose:
                    print(whr_toks)
                if len(whr_toks) > 0:
                    whr_toks = whr_toks[1:]
                st = 0
                while st < len(whr_toks):
                    cur_cond = [None, None, None]
                    ed = len(whr_toks) if 'AND' not in whr_toks[st:] \
                         else whr_toks[st:].index('AND') + st
                    if 'EQL' in whr_toks[st:ed]:
                        op = whr_toks[st:ed].index('EQL') + st
                        cur_cond[1] = 0
                    elif 'GT' in whr_toks[st:ed]:
                        op = whr_toks[st:ed].index('GT') + st
                        cur_cond[1] = 1
                    elif 'LT' in whr_toks[st:ed]:
                        op = whr_toks[st:ed].index('LT') + st
                        cur_cond[1] = 2
                    else:
                        op = st
                        cur_cond[1] = 0
                    sel_col = whr_toks[st:op]
                    to_idx = [x.lower() for x in raw_col[b]]
                    classif_col = merge_tokens(sel_col, raw_q[b] + ' || ' + \
                                            ' || '.join(raw_col[b]))
                    if classif_col in to_idx:
                        cur_cond[0] = to_idx.index(classif_col)
                    else:
                        cur_cond[0] = 0
                    cur_cond[2] = merge_tokens(whr_toks[op + 1:ed], raw_q[b])
                    cur_query['conds'].append(cur_cond)
                    st = ed + 1
            ret_queries.append(cur_query)

        return ret_queries
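
A minimal usage sketch for this example (hypothetical data: the batch variables and ref_score below are assumptions; WordEmbedding and the classifier modules come from the surrounding repo):

# Minimal sketch, assuming a WikiSQL-style batch where q/col/query are
# tokenized lists and word_emb is the GloVe lookup WordEmbedding expects.
model = Seq2SQL(word_emb, num_words=300, use_gpu=False)
g_s = model.generate_g_s(q, col, query)   # gold WHERE token index sequences
flags = (True, True, True)                # (agg, sel, where) classifiers on
scores = model.forward(q, col, col_num, flags, g_s=g_s)
loss = model.loss(scores, ref_score, flags, g_s)  # ref_score: (agg, sel) truths
loss.backward()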
Example #2
class SuperModel(nn.Module):
    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=300,
                 N_depth=2,
                 gpu=True,
                 trainable_emb=False,
                 table_type="std",
                 use_hs=True):
        super(SuperModel, self).__init__()
        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth
        self.trainable_emb = trainable_emb
        self.table_type = table_type
        self.use_hs = use_hs
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>'
        ]

        # word embedding layer
        self.embed_layer = WordEmbedding(word_emb,
                                         N_word,
                                         gpu,
                                         self.SQL_TOK,
                                         trainable=trainable_emb)

        # initialize all sub-modules
        self.multi_sql = MultiSqlPredictor(N_word=N_word,
                                           N_h=N_h,
                                           N_depth=N_depth,
                                           gpu=gpu,
                                           use_hs=use_hs)
        self.multi_sql.eval()

        self.key_word = KeyWordPredictor(N_word=N_word,
                                         N_h=N_h,
                                         N_depth=N_depth,
                                         gpu=gpu,
                                         use_hs=use_hs)
        self.key_word.eval()

        self.col = ColPredictor(N_word=N_word,
                                N_h=N_h,
                                N_depth=N_depth,
                                gpu=gpu,
                                use_hs=use_hs)
        self.col.eval()

        self.op = OpPredictor(N_word=N_word,
                              N_h=N_h,
                              N_depth=N_depth,
                              gpu=gpu,
                              use_hs=use_hs)
        self.op.eval()

        self.agg = AggPredictor(N_word=N_word,
                                N_h=N_h,
                                N_depth=N_depth,
                                gpu=gpu,
                                use_hs=use_hs)
        self.agg.eval()

        self.root_teminal = RootTeminalPredictor(N_word=N_word,
                                                 N_h=N_h,
                                                 N_depth=N_depth,
                                                 gpu=gpu,
                                                 use_hs=use_hs)
        self.root_teminal.eval()

        self.des_asc = DesAscLimitPredictor(N_word=N_word,
                                            N_h=N_h,
                                            N_depth=N_depth,
                                            gpu=gpu,
                                            use_hs=use_hs)
        self.des_asc.eval()

        self.having = HavingPredictor(N_word=N_word,
                                      N_h=N_h,
                                      N_depth=N_depth,
                                      gpu=gpu,
                                      use_hs=use_hs)
        self.having.eval()

        self.andor = AndOrPredictor(N_word=N_word,
                                    N_h=N_h,
                                    N_depth=N_depth,
                                    gpu=gpu,
                                    use_hs=use_hs)
        self.andor.eval()

        self.softmax = nn.Softmax(dim=1)
        self.CE = nn.CrossEntropyLoss()
        self.log_softmax = nn.LogSoftmax()
        self.mlsml = nn.MultiLabelSoftMarginLoss()
        self.bce_logit = nn.BCEWithLogitsLoss()
        self.sigm = nn.Sigmoid()
        if gpu:
            self.cuda()
        self.path_not_found = 0

    def forward(self, q_seq, history, tables):
        # if self.part:
        #     return self.part_forward(q_seq,history,tables)
        # else:
        return self.full_forward(q_seq, history, tables)

    def full_forward(self, q_seq, history, tables):
        B = len(q_seq)
        # print("q_seq:{}".format(q_seq))
        # print("Batch size:{}".format(B))
        q_emb_var, q_len = self.embed_layer.gen_x_q_batch(q_seq)
        col_seq = to_batch_tables(tables, B, self.table_type)
        col_emb_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
            col_seq)

        mkw_emb_var = self.embed_layer.gen_word_list_embedding(
            ["none", "except", "intersect", "union"], (B))
        mkw_len = np.full(q_len.shape, 4, dtype=np.int64)
        kw_emb_var = self.embed_layer.gen_word_list_embedding(
            ["where", "group by", "order by"], (B))
        kw_len = np.full(q_len.shape, 3, dtype=np.int64)

        stack = Stack()
        stack.push(("root", None))
        history = [["root"]] * B
        andor_cond = ""
        has_limit = False
        # sql = {}
        current_sql = {}
        sql_stack = []
        idx_stack = []
        kw_stack = []
        kw = ""
        nested_label = ""
        has_having = False

        # set a 2-second timer to prevent infinite recursion in SQL generation
        timeout = time.time() + 2
        failed = False
        while not stack.isEmpty():
            if time.time() > timeout:
                failed = True
                break
            vet = stack.pop()
            # print(vet)
            hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history)
            if len(idx_stack) > 0 and stack.size() < idx_stack[-1]:
                # print("pop!!!!!!!!!!!!!!!!!!!!!!")
                idx_stack.pop()
                current_sql = sql_stack.pop()
                kw = kw_stack.pop()
                # current_sql = current_sql["sql"]
            # history.append(vet)
            # print("hs_emb:{} hs_len:{}".format(hs_emb_var.size(),hs_len.size()))
            if isinstance(vet, tuple) and vet[0] == "root":
                if history[0][-1] != "root":
                    history[0].append("root")
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                        history)
                if vet[1] != "original":
                    idx_stack.append(stack.size())
                    sql_stack.append(current_sql)
                    kw_stack.append(kw)
                else:
                    idx_stack.append(stack.size())
                    sql_stack.append(sql_stack[-1])
                    kw_stack.append(kw)
                if "sql" in current_sql:
                    current_sql["nested_sql"] = {}
                    current_sql["nested_label"] = nested_label
                    current_sql = current_sql["nested_sql"]
                elif isinstance(vet[1], dict):
                    vet[1]["sql"] = {}
                    current_sql = vet[1]["sql"]
                elif vet[1] != "original":
                    current_sql["sql"] = {}
                    current_sql = current_sql["sql"]
                # print("q_emb_var:{} hs_emb_var:{} mkw_emb_var:{}".format(q_emb_var.size(),hs_emb_var.size(),mkw_emb_var.size()))
                if vet[1] == "nested" or vet[1] == "original":
                    stack.push("none")
                    history[0].append("none")
                else:
                    score = self.multi_sql.forward(q_emb_var, q_len,
                                                   hs_emb_var, hs_len,
                                                   mkw_emb_var, mkw_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    label = SQL_OPS[label]
                    history[0].append(label)
                    stack.push(label)
                if label != "none":
                    nested_label = label

            elif vet in ('intersect', 'except', 'union'):
                stack.push(("root", "nested"))
                stack.push(("root", "original"))
                # history[0].append("root")
            elif vet == "none":
                score = self.key_word.forward(q_emb_var, q_len, hs_emb_var,
                                              hs_len, kw_emb_var, kw_len)
                kw_num_score, kw_score = [x.data.cpu().numpy() for x in score]
                # print("kw_num_score:{}".format(kw_num_score))
                # print("kw_score:{}".format(kw_score))
                num_kw = np.argmax(kw_num_score[0])
                kw_score = list(np.argsort(-kw_score[0])[:num_kw])
                kw_score.sort(reverse=True)
                # print("num_kw:{}".format(num_kw))
                for kw in kw_score:
                    stack.push(KW_OPS[kw])
                stack.push("select")
            elif vet in ("select", "orderBy", "where", "groupBy", "having"):
                kw = vet
                current_sql[kw] = []
                history[0].append(vet)
                stack.push(("col", vet))
                # score = self.andor.forward(q_emb_var,q_len,hs_emb_var,hs_len)
                # label = score[0].data.cpu().numpy()
                # andor_cond = COND_OPS[label]
                # history.append("")
            # elif vet == "groupBy":
            #     score = self.having.forward(q_emb_var,q_len,hs_emb_var,hs_len,col_emb_var,col_len,)
            elif isinstance(vet, tuple) and vet[0] == "col":
                # print("q_emb_var:{} hs_emb_var:{} col_emb_var:{}".format(q_emb_var.size(), hs_emb_var.size(),col_emb_var.size()))
                score = self.col.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                         col_emb_var, col_len, col_name_len)
                col_num_score, col_score = [
                    x.data.cpu().numpy() for x in score
                ]
                col_num = np.argmax(col_num_score[0]) + 1  # double check
                cols = np.argsort(-col_score[0])[:col_num]
                # print(col_num)
                # print("col_num_score:{}".format(col_num_score))
                # print("col_score:{}".format(col_score))
                for col in cols:
                    if vet[1] == "where":
                        stack.push(("op", "where", col))
                    elif vet[1] != "groupBy":
                        stack.push(("agg", vet[1], col))
                    elif vet[1] == "groupBy":
                        history[0].append(index_to_column_name(col, tables))
                        current_sql[kw].append(
                            index_to_column_name(col, tables))
                # predict AND/OR when there are multiple columns in the WHERE condition
                if col_num > 1 and vet[1] == "where":
                    score = self.andor.forward(q_emb_var, q_len, hs_emb_var,
                                               hs_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    andor_cond = COND_OPS[label]
                    current_sql[kw].append(andor_cond)
                if vet[1] == "groupBy" and col_num > 0:
                    score = self.having.forward(
                        q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                        col_len, col_name_len,
                        np.full(B, cols[0], dtype=np.int64))
                    label = np.argmax(score[0].data.cpu().numpy())
                    if label == 1:
                        has_having = True
                        # stack.insert(-col_num,"having")
                        stack.push("having")
                # history.append(index_to_column_name(cols[-1], tables[0]))
            elif isinstance(vet, tuple) and vet[0] == "agg":
                history[0].append(index_to_column_name(vet[2], tables))
                if vet[1] not in ("having", "orderBy"):  #DEBUG-ed 20180817
                    try:
                        current_sql[kw].append(
                            index_to_column_name(vet[2], tables))
                    except Exception as e:
                        # print(e)
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(
                            history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                    history)

                score = self.agg.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                         col_emb_var, col_len, col_name_len,
                                         np.full(B, vet[2], dtype=np.int64))
                agg_num_score, agg_score = [
                    x.data.cpu().numpy() for x in score
                ]
                agg_num = np.argmax(agg_num_score[0])  # double check
                agg_idxs = np.argsort(-agg_score[0])[:agg_num]
                # print("agg:{}".format([AGG_OPS[agg] for agg in agg_idxs]))
                if len(agg_idxs) > 0:
                    history[0].append(AGG_OPS[agg_idxs[0]])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(AGG_OPS[agg_idxs[0]])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2],
                                    AGG_OPS[agg_idxs[0]]))  #DEBUG-ed 20180817
                    else:
                        stack.push(
                            ("op", "having", vet[2], AGG_OPS[agg_idxs[0]]))
                for agg in agg_idxs[1:]:
                    history[0].append(index_to_column_name(vet[2], tables))
                    history[0].append(AGG_OPS[agg])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(
                            index_to_column_name(vet[2], tables))
                        current_sql[kw].append(AGG_OPS[agg])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], AGG_OPS[agg]))
                    else:
                        stack.push(("op", "having", vet[2], agg_idxs))
                if len(agg_idxs) == 0:
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append("none_agg")
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], "none_agg"))
                    else:
                        stack.push(("op", "having", vet[2], "none_agg"))
                # current_sql[kw].append([AGG_OPS[agg] for agg in agg_idxs])
                # if vet[1] == "having":
                #     stack.push(("op","having",vet[2],agg_idxs))
                # if vet[1] == "orderBy":
                #     stack.push(("des_asc",vet[2],agg_idxs))
                # if vet[1] == "groupBy" and has_having:
                #     stack.push("having")
            elif isinstance(vet, tuple) and vet[0] == "op":
                if vet[1] == "where":
                    # current_sql[kw].append(index_to_column_name(vet[2], tables))
                    history[0].append(index_to_column_name(vet[2], tables))
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                        history)

                score = self.op.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                        col_emb_var, col_len, col_name_len,
                                        np.full(B, vet[2], dtype=np.int64))

                op_num_score, op_score = [x.data.cpu().numpy() for x in score]
                # num_score 0 maps to 1 in truth; there must be at least one op
                op_num = np.argmax(op_num_score[0]) + 1
                ops = np.argsort(-op_score[0])[:op_num]
                # current_sql[kw].append([NEW_WHERE_OPS[op] for op in ops])
                if op_num > 0:
                    history[0].append(NEW_WHERE_OPS[ops[0]])
                    if vet[1] == "having":
                        stack.push(("root_teminal", vet[2], vet[3], ops[0]))
                    else:
                        stack.push(("root_teminal", vet[2], ops[0]))
                    # current_sql[kw].append(NEW_WHERE_OPS[ops[0]])
                for op in ops[1:]:
                    history[0].append(index_to_column_name(vet[2], tables))
                    history[0].append(NEW_WHERE_OPS[op])
                    # current_sql[kw].append(index_to_column_name(vet[2], tables))
                    # current_sql[kw].append(NEW_WHERE_OPS[op])
                    if vet[1] == "having":
                        stack.push(("root_teminal", vet[2], vet[3], op))
                    else:
                        stack.push(("root_teminal", vet[2], op))
                # stack.push(("root_teminal",vet[2]))
            elif isinstance(vet, tuple) and vet[0] == "root_teminal":
                score = self.root_teminal.forward(
                    q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
                    col_name_len, np.full(B, vet[1], dtype=np.int64))

                label = np.argmax(score[0].data.cpu().numpy())
                label = ROOT_TERM_OPS[label]
                if len(vet) == 4:
                    current_sql[kw].append(index_to_column_name(
                        vet[1], tables))
                    current_sql[kw].append(vet[2])
                    current_sql[kw].append(NEW_WHERE_OPS[vet[3]])
                else:
                    # print("kw:{}".format(kw))
                    try:
                        current_sql[kw].append(
                            index_to_column_name(vet[1], tables))
                    except Exception as e:
                        # print(e)
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(
                            history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                    current_sql[kw].append(NEW_WHERE_OPS[vet[2]])
                if label == "root":
                    history[0].append("root")
                    current_sql[kw].append({})
                    # current_sql = current_sql[kw][-1]
                    stack.push(("root", current_sql[kw][-1]))
                else:
                    current_sql[kw].append("terminal")
            elif isinstance(vet, tuple) and vet[0] == "des_asc":
                current_sql[kw].append(index_to_column_name(vet[1], tables))
                current_sql[kw].append(vet[2])
                score = self.des_asc.forward(
                    q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
                    col_name_len, np.full(B, vet[1], dtype=np.int64))
                label = np.argmax(score[0].data.cpu().numpy())
                dec_asc, has_limit = DEC_ASC_OPS[label]
                history[0].append(dec_asc)
                current_sql[kw].append(dec_asc)
                current_sql[kw].append(has_limit)
        # print("{}".format(current_sql))

        if failed: return None
        print("history:{}".format(history[0]))
        if len(sql_stack) > 0:
            current_sql = sql_stack[0]
        # print("{}".format(current_sql))
        return current_sql

    def gen_col(self, col, table, table_alias_dict):
        colname = table["column_names_original"][col[2]][1]
        table_idx = table["column_names_original"][col[2]][0]
        if table_idx not in table_alias_dict:
            return colname
        return "T{}.{}".format(table_alias_dict[table_idx], colname)

    def gen_group_by(self, sql, kw, table, table_alias_dict):
        ret = []
        for i in range(0, len(sql)):
            # if len(sql[i+1]) == 0:
            # if sql[i+1] == "none_agg":
            ret.append(self.gen_col(sql[i], table, table_alias_dict))
            # else:
            #     ret.append("{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict)))
            # for agg in sql[i+1]:
            #     ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict)))
        return "{} {}".format(kw, ",".join(ret))

    def gen_select(self, sql, kw, table, table_alias_dict):
        ret = []
        for i in range(0, len(sql), 2):
            # if len(sql[i+1]) == 0:
            if sql[i + 1] == "none_agg" or not isinstance(
                    sql[i + 1], basestring):  #DEBUG-ed 20180817
                ret.append(self.gen_col(sql[i], table, table_alias_dict))
            else:
                ret.append("{}({})".format(
                    sql[i + 1], self.gen_col(sql[i], table, table_alias_dict)))
            # for agg in sql[i+1]:
            #     ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict)))
        return "{} {}".format(kw, ",".join(ret))

    def gen_where(self, sql, table, table_alias_dict):
        if len(sql) == 0:
            return ""
        start_idx = 0
        andor = "and"
        if isinstance(sql[0], str):
            start_idx += 1
            andor = sql[0]
        ret = []
        for i in range(start_idx, len(sql), 3):
            col = self.gen_col(sql[i], table, table_alias_dict)
            op = sql[i + 1]
            val = sql[i + 2]
            where_item = ""
            if val == "terminal":
                where_item = "{} {} '{}'".format(col, op, val)
            else:
                val = self.gen_sql(val, table)
                where_item = "{} {} ({})".format(col, op, val)
            if op == "between":
                #TODO temprarily fixed
                where_item += " and 'terminal'"
            ret.append(where_item)
        return "where {}".format(" {} ".format(andor).join(ret))

    def gen_orderby(self, sql, table, table_alias_dict):
        ret = []
        limit = ""
        if sql[-1]:
            limit = "limit 1"
        for i in range(0, len(sql), 4):
            if sql[i + 1] == "none_agg" or not isinstance(
                    sql[i + 1], basestring):  #DEBUG-ed 20180817
                ret.append("{} {}".format(
                    self.gen_col(sql[i], table, table_alias_dict), sql[i + 2]))
            else:
                ret.append("{}({}) {}".format(
                    sql[i + 1], self.gen_col(sql[i], table, table_alias_dict),
                    sql[i + 2]))
        return "order by {} {}".format(",".join(ret), limit)

    def gen_having(self, sql, table, table_alias_dict):
        ret = []
        for i in range(0, len(sql), 4):
            if sql[i + 1] == "none_agg":
                col = self.gen_col(sql[i], table, table_alias_dict)
            else:
                col = "{}({})".format(
                    sql[i + 1], self.gen_col(sql[i], table, table_alias_dict))
            op = sql[i + 2]
            val = sql[i + 3]
            if val == "terminal":
                ret.append("{} {} '{}'".format(col, op, val))
            else:
                val = self.gen_sql(val, table)
                ret.append("{} {} ({})".format(col, op, val))
        return "having {}".format(",".join(ret))

    def find_shortest_path(self, start, end, graph):
        stack = [[start, []]]
        visited = set()
        while len(stack) > 0:
            ele, history = stack.pop()
            if ele == end:
                return history
            for node in graph[ele]:
                if node[0] not in visited:
                    stack.append((node[0], history + [(node[0], node[1])]))
                    visited.add(node[0])
        print("table {} table {}".format(start, end))
        # print("could not find path!!!!!{}".format(self.path_not_found))
        self.path_not_found += 1
        # return []

    def gen_from(self, candidate_tables, table):
        def find(d, col):
            if d[col] == -1:
                return col
            return find(d, d[col])

        def union(d, c1, c2):
            r1 = find(d, c1)
            r2 = find(d, c2)
            if r1 == r2:
                return
            d[r1] = r2

        ret = ""
        if len(candidate_tables) <= 1:
            if len(candidate_tables) == 1:
                ret = "from {}".format(
                    table["table_names_original"][list(candidate_tables)[0]])
            else:
                ret = "from {}".format(table["table_names_original"][0])
            # TODO: temporary setting
            return {}, ret
        # print("candidate:{}".format(candidate_tables))
        table_alias_dict = {}
        uf_dict = {}
        for t in candidate_tables:
            uf_dict[t] = -1
        idx = 1
        graph = defaultdict(list)
        for acol, bcol in table["foreign_keys"]:
            t1 = table["column_names"][acol][0]
            t2 = table["column_names"][bcol][0]
            graph[t1].append((t2, (acol, bcol)))
            graph[t2].append((t1, (bcol, acol)))
            # if t1 in candidate_tables and t2 in candidate_tables:
            #     r1 = find(uf_dict,t1)
            #     r2 = find(uf_dict,t2)
            #     if r1 == r2:
            #         continue
            #     union(uf_dict,t1,t2)
            #     if len(ret) == 0:
            #         ret = "from {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(table["table_names"][t1],idx,table["table_names"][t2],
            #                                                                       idx+1,idx,table["column_names_original"][acol][1],idx+1,
            #                                                                       table["column_names_original"][bcol][1])
            #         table_alias_dict[t1] = idx
            #         table_alias_dict[t2] = idx+1
            #         idx += 2
            #     else:
            #         if t1 in table_alias_dict:
            #             old_t = t1
            #             new_t = t2
            #             acol,bcol = bcol,acol
            #         elif t2 in table_alias_dict:
            #             old_t = t2
            #             new_t = t1
            #         else:
            #             ret = "{} join {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(ret,table["table_names"][t1], idx,
            #                                                                           table["table_names"][t2],
            #                                                                           idx + 1, idx,
            #                                                                           table["column_names_original"][acol][1],
            #                                                                           idx + 1,
            #                                                                           table["column_names_original"][bcol][1])
            #             table_alias_dict[t1] = idx
            #             table_alias_dict[t2] = idx + 1
            #             idx += 2
            #             continue
            #         ret = "{} join {} as T{} on T{}.{}=T{}.{}".format(ret,new_t,idx,idx,table["column_names_original"][acol][1],
            #                                                        table_alias_dict[old_t],table["column_names_original"][bcol][1])
            #         table_alias_dict[new_t] = idx
            #         idx += 1
        # visited = set()
        candidate_tables = list(candidate_tables)
        start = candidate_tables[0]
        table_alias_dict[start] = idx
        idx += 1
        ret = "from {} as T1".format(table["table_names_original"][start])
        try:
            for end in candidate_tables[1:]:
                if end in table_alias_dict:
                    continue
                path = self.find_shortest_path(start, end, graph)
                prev_table = start
                if not path:
                    table_alias_dict[end] = idx
                    idx += 1
                    ret = "{} join {} as T{}".format(
                        ret,
                        table["table_names_original"][end],
                        table_alias_dict[end],
                    )
                    continue
                for node, (acol, bcol) in path:
                    if node in table_alias_dict:
                        prev_table = node
                        continue
                    table_alias_dict[node] = idx
                    idx += 1
                    ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(
                        ret, table["table_names_original"][node],
                        table_alias_dict[node], table_alias_dict[prev_table],
                        table["column_names_original"][acol][1],
                        table_alias_dict[node],
                        table["column_names_original"][bcol][1])
                    prev_table = node
        except Exception:
            traceback.print_exc()
            print("db:{}".format(table["db_id"]))
            # print(table["db_id"])
            return table_alias_dict, ret
        # if len(candidate_tables) != len(table_alias_dict):
        #     print("error in generate from clause!!!!!")
        return table_alias_dict, ret

    def gen_sql(self, sql, table):
        select_clause = ""
        from_clause = ""
        groupby_clause = ""
        orderby_clause = ""
        having_clause = ""
        where_clause = ""
        nested_clause = ""
        cols = {}
        candidate_tables = set()
        nested_sql = {}
        nested_label = ""
        parent_sql = sql
        # if "sql" in sql:
        #     sql = sql["sql"]
        if "nested_label" in sql:
            nested_label = sql["nested_label"]
            nested_sql = sql["nested_sql"]
            sql = sql["sql"]
        elif "sql" in sql:
            sql = sql["sql"]
        for key in sql:
            if key not in KW_WITH_COL:
                continue
            for item in sql[key]:
                if isinstance(item, tuple) and len(item) == 3:
                    if table["column_names"][item[2]][0] != -1:
                        candidate_tables.add(table["column_names"][item[2]][0])
        table_alias_dict, from_clause = self.gen_from(candidate_tables, table)
        ret = []
        if "select" in sql:
            select_clause = self.gen_select(sql["select"], "select", table,
                                            table_alias_dict)
            if len(select_clause) > 0:
                ret.append(select_clause)
            else:
                print("select not found:{}".format(parent_sql))
        else:
            print("select not found:{}".format(parent_sql))
        if len(from_clause) > 0:
            ret.append(from_clause)
        if "where" in sql:
            where_clause = self.gen_where(sql["where"], table,
                                          table_alias_dict)
            if len(where_clause) > 0:
                ret.append(where_clause)
        if "groupBy" in sql:  ## DEBUG-ed order
            groupby_clause = self.gen_group_by(sql["groupBy"], "group by",
                                               table, table_alias_dict)
            if len(groupby_clause) > 0:
                ret.append(groupby_clause)
        if "orderBy" in sql:
            orderby_clause = self.gen_orderby(sql["orderBy"], table,
                                              table_alias_dict)
            if len(orderby_clause) > 0:
                ret.append(orderby_clause)
        if "having" in sql:
            having_clause = self.gen_having(sql["having"], table,
                                            table_alias_dict)
            if len(having_clause) > 0:
                ret.append(having_clause)
        if len(nested_label) > 0:
            nested_clause = "{} {}".format(nested_label,
                                           self.gen_sql(nested_sql, table))
            if len(nested_clause) > 0:
                ret.append(nested_clause)
        return " ".join(ret)

    def check_acc(self, pred_sql, gt_sql):
        pass
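
A minimal sketch of how this example might be driven (hypothetical inputs: q_seq and tables below are assumptions; the predictor modules, to_batch_tables and index_to_column_name come from the surrounding repo):

# Minimal sketch, assuming SyntaxSQL-style inputs: tokenized questions and a
# database schema dict of the shape used by to_batch_tables.
model = SuperModel(word_emb, N_word=300, gpu=False)
sql_tree = model.full_forward(q_seq, history=None, tables=tables)
if sql_tree is not None:  # None means the 2-second timeout fired
    print(model.gen_sql(sql_tree, tables))  # render the tree as a SQL string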
Example #3
class SQLNet(nn.Module):
    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=100,
                 N_depth=2,
                 gpu=False,
                 use_ca=True,
                 trainable_emb=False):
        super(SQLNet, self).__init__()
        self.use_ca = use_ca
        self.trainable_emb = trainable_emb

        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth

        self.max_col_num = 45
        self.max_tok_num = 200
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=',
            '<BEG>'
        ]
        self.COND_OPS = ['>', '<', '==', '!=']

        # Word embeddings: we could train our own or use pre-trained vectors; here we load pre-trained ones
        self.embed_layer = WordEmbedding(word_emb,
                                         N_word,
                                         gpu,
                                         self.SQL_TOK,
                                         our_model=True,
                                         trainable=trainable_emb)

        # predict the number of selected columns
        self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)

        # predict which columns are selected
        self.sel_pred = SelPredictor(N_word,
                                     N_h,
                                     N_depth,
                                     self.max_tok_num,
                                     use_ca=use_ca)

        # predict the aggregation function for each selected column
        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)

        # predict the number of conditions, condition columns, operators and values
        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth,
                                             self.max_col_num,
                                             self.max_tok_num, use_ca, gpu)

        # predict the condition relation, e.g. "and" / "or"
        self.where_rela_pred = WhereRelationPredictor(N_word,
                                                      N_h,
                                                      N_depth,
                                                      use_ca=use_ca)

        self.CE = nn.CrossEntropyLoss()  # cross-entropy loss
        self.softmax = nn.Softmax(dim=-1)
        self.log_softmax = nn.LogSoftmax()
        self.bce_logit = nn.BCEWithLogitsLoss()
        if gpu:
            self.cuda()

    # q: questions; gt_cond_seq: condition triples (col, op, value); purpose: locate the condition values
    def generate_gt_where_seq_test(self, q, gt_cond_seq):
        ret_seq = []
        for cur_q, ans in zip(q, gt_cond_seq):
            temp_q = u"".join(cur_q)
            cur_q = [u'<BEG>'] + cur_q + [u'<END>']  # add <BEG> before and <END> after each question
            record = []  # mark (True, value) if the condition value appears in the question
            record_cond = []
            for cond in ans:
                if cond[2] not in temp_q:
                    record.append((False, cond[2]))
                else:
                    record.append((True, cond[2]))
            for idx, item in enumerate(record):
                temp_ret_seq = []
                if item[0]:
                    temp_ret_seq.append(0)
                    temp_ret_seq.extend(
                        list(
                            range(
                                temp_q.index(item[1]) + 1,
                                temp_q.index(item[1]) + len(item[1]) +
                                1)))  # indices of the condition value in the question
                    temp_ret_seq.append(len(cur_q) - 1)
                else:
                    temp_ret_seq.append([0, len(cur_q) - 1])
                record_cond.append(temp_ret_seq)
            ret_seq.append(record_cond)
        return ret_seq
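
    # Example: cur_q = ['a', 'b', 'c'] with condition value 'b' gives
    # temp_q = 'abc' and cur_q = ['<BEG>', 'a', 'b', 'c', '<END>'], so the
    # emitted index sequence is [0, 2, 4] (start flag, value span, <END>).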
    # q: questions; col: header names; col_num: number of header columns;
    # gt_where: conds whose values do not appear in the question;
    # gt_conds: conds; gt_sel: which columns are selected;
    # gt_sel_num: how many columns are selected
    def forward(self,
                q,
                col,
                col_num,
                gt_where=None,
                gt_cond=None,
                reinforce=False,
                gt_sel=None,
                gt_sel_num=None):
        B = len(q)  # batch size

        sel_num_score = None
        agg_score = None
        sel_score = None
        cond_score = None
        # predict the aggregation functions
        if self.trainable_emb:
            x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(
                col)
            max_x_len = max(x_len)
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=gt_sel)

            x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(
                col)
            max_x_len = max(x_len)
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)

            x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
            col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(
                col)
            max_x_len = max(x_len)
            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)
            where_rela_score = None
        else:
            # x_len: length of each question in the batch;
            # x_emb_var: [batch_size, max_seq_len, word_embedding_size]
            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
            # column-name embeddings, their lengths, and the number of columns
            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
                col)
            # [16, 4]: the question encoding goes through LSTM, linear and
            # softmax layers and is then multiplied with the encoding
            sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var,
                                         col_name_len, col_len, col_num)
            # x_emb_var: embedding of each question
            # x_len: length of each question
            # col_inp_var: embedding of each header
            # col_name_len: length of each header
            # col_len: number of headers in each table, array type
            # col_num: number of headers in each table, list type
            if gt_sel_num:
                pr_sel_num = gt_sel_num
            else:
                pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(),
                                       axis=1)
            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
                                      col_name_len, col_len, col_num)  # [16, 19]

            if gt_sel:
                pr_sel = gt_sel
            else:
                num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
                sel = sel_score.data.cpu().numpy()
                pr_sel = [
                    list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))
                ]
            agg_score = self.agg_pred(x_emb_var,
                                      x_len,
                                      col_inp_var,
                                      col_name_len,
                                      col_len,
                                      col_num,
                                      gt_sel=pr_sel,
                                      gt_sel_num=pr_sel_num)  # [16, 4, 6]

            where_rela_score = self.where_rela_pred(x_emb_var, x_len,
                                                    col_inp_var, col_name_len,
                                                    col_len, col_num)  # [16, 3]

            cond_score = self.cond_pred(x_emb_var,
                                        x_len,
                                        col_inp_var,
                                        col_name_len,
                                        col_len,
                                        col_num,
                                        gt_where,
                                        gt_cond,
                                        reinforce=reinforce)  # 4 => [16, 5]

        return (sel_num_score, sel_score, agg_score, cond_score,
                where_rela_score)

    def loss(self, score, truth_num, gt_where):
        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score

        B = len(truth_num)
        loss = 0

        # Evaluate select number
        sel_num_truth = list(map(lambda x: x[0],
                                 truth_num))  # number of aggregation functions
        sel_num_truth = torch.from_numpy(np.array(sel_num_truth)).long()
        if self.gpu:
            sel_num_truth = Variable(sel_num_truth.cuda())
        else:
            sel_num_truth = Variable(sel_num_truth)
        # loss for the number of selected columns
        loss += self.CE(sel_num_score, sel_num_truth)

        # Evaluate select column: loss for which columns are selected
        T = len(sel_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            truth_prob[b][list(truth_num[b][1])] = 1
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            sel_col_truth_var = Variable(data.cuda())
        else:
            sel_col_truth_var = Variable(data)
        sigm = nn.Sigmoid()
        sel_col_prob = sigm(sel_score)
        # weighted BCE loss: -w * [y*log(x) + (1-y)*log(1-x)]
        bce_loss = -torch.mean(
            3 * (sel_col_truth_var * torch.log(sel_col_prob + 1e-10)) +
            (1 - sel_col_truth_var) * torch.log(1 - sel_col_prob + 1e-10))
        loss += bce_loss

        # Evaluate select aggregation: cross-entropy loss for the aggregation functions
        for b in range(len(truth_num)):
            data = torch.from_numpy(np.array(truth_num[b][2]))  # ground-truth aggregation functions
            if self.gpu:
                sel_agg_truth_var = Variable(data.cuda())
            else:
                sel_agg_truth_var = Variable(data.long())
            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]  # six aggregation functions in total
            loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)

        cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score

        # Evaluate the number of conditions: cross-entropy loss for how many conds are predicted
        cond_num_truth = list(map(lambda x: x[3], truth_num))
        data = torch.from_numpy(np.array(cond_num_truth).astype(float)).long()
        if self.gpu:
            try:
                cond_num_truth_var = Variable(data.cuda())
            except:
                print("cond_num_truth_var error")
                print(data)
                exit(0)
        else:
            cond_num_truth_var = Variable(data)
        loss += self.CE(cond_num_score, cond_num_truth_var)

        # Evaluate the columns of conditions
        T = len(cond_col_score[0])
        truth_prob = np.zeros((B, T), dtype=np.float32)
        for b in range(B):
            if len(truth_num[b][4]) > 0:
                truth_prob[b][list(truth_num[b][4])] = 1  # condition columns
        data = torch.from_numpy(truth_prob)
        if self.gpu:
            cond_col_truth_var = Variable(data.cuda())
        else:
            cond_col_truth_var = Variable(data)

        sigm = nn.Sigmoid()
        cond_col_prob = sigm(cond_col_score)
        bce_loss = -torch.mean(
            3 * (cond_col_truth_var * torch.log(cond_col_prob + 1e-10)) +
            (1 - cond_col_truth_var) * torch.log(1 - cond_col_prob + 1e-10))
        loss += bce_loss

        # Evaluate the operators of conditions
        for b in range(len(truth_num)):
            if len(truth_num[b][5]) == 0:  # condition operators
                continue
            data = torch.from_numpy(np.array(truth_num[b][5])).long()
            if self.gpu:
                cond_op_truth_var = Variable(data.cuda())
            else:
                cond_op_truth_var = Variable(data)
            cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
            # try:
            loss += (self.CE(cond_op_pred, cond_op_truth_var) / len(truth_num))
            # except:
            #     print(cond_op_pred)
            #     print(cond_op_truth_var)
            #     exit(0)

        # Evaluate the strings of conditions
        for b in range(len(gt_where)):
            for idx in range(len(gt_where[b])):
                cond_str_truth = gt_where[b][idx]
                if len(cond_str_truth) == 1:
                    continue
                data = torch.from_numpy(np.array(cond_str_truth[1:])).long()
                if self.gpu:
                    cond_str_truth_var = Variable(data.cuda())
                else:
                    cond_str_truth_var = Variable(data)
                str_end = len(cond_str_truth) - 1
                cond_str_pred = cond_str_score[b, idx, :str_end]
                loss += (self.CE(cond_str_pred, cond_str_truth_var) \
                        / (len(gt_where) * len(gt_where[b])))
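                # teacher forcing: the score at step t is trained against
                # ground-truth token t+1 (<BEG> dropped), normalised by batch
                # size times the number of conditions in this sample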

        # Evaluate the condition relationship (AND / OR): cross-entropy loss
        where_rela_truth = list(map(lambda x: x[6], truth_num))
        data = torch.from_numpy(np.array(where_rela_truth)).long()
        if self.gpu:
            try:
                where_rela_truth = Variable(data.cuda())
            except Exception:
                print("where_rela_truth error")
                print(data)
                exit(0)
        else:
            where_rela_truth = Variable(data)
        loss += self.CE(where_rela_score, where_rela_truth)
        return loss
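    # Note: the hand-rolled weighted BCE terms in loss() above are, up to the
    # 1e-10 stabiliser, equivalent to nn.BCEWithLogitsLoss(pos_weight=3)
    # applied to the raw scores; a standalone sketch follows gen_query() below.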

    def check_acc(self, vis_info, pred_queries, gt_queries):
        def gen_cond_str(conds, header):
            if len(conds) == 0:
                return 'None'
            cond_str = []
            for cond in conds:
                cond_str.append(header[cond[0]] + ' ' +
                                self.COND_OPS[cond[1]] + ' ' +
                                str(cond[2]).lower())
            return 'WHERE ' + ' AND '.join(cond_str)

        tot_err = sel_num_err = agg_err = sel_err = 0.0
        cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0
        for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)):
            good = True
            sel_pred, agg_pred, where_rela_pred = (
                pred_qry['sel'], pred_qry['agg'], pred_qry['cond_conn_op'])
            sel_gt, agg_gt, where_rela_gt = (
                gt_qry['sel'], gt_qry['agg'], gt_qry['cond_conn_op'])

            if where_rela_gt != where_rela_pred:
                good = False
                cond_rela_err += 1

            if len(sel_pred) != len(sel_gt):
                good = False
                sel_num_err += 1

            pred_sel_dict = dict(zip(sel_pred, agg_pred))
            gt_sel_dict = dict(zip(sel_gt, agg_gt))
            if set(sel_pred) != set(sel_gt):
                good = False
                sel_err += 1
            agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict.keys())]
            agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict.keys())]
            if agg_pred != agg_gt:
                good = False
                agg_err += 1

            cond_pred = pred_qry['conds']
            cond_gt = gt_qry['conds']
            if len(cond_pred) != len(cond_gt):
                good = False
                cond_num_err += 1
            else:
                cond_op_pred, cond_op_gt = {}, {}
                cond_val_pred, cond_val_gt = {}, {}
                for p, g in zip(cond_pred, cond_gt):
                    cond_op_pred[p[0]] = p[1]
                    cond_val_pred[p[0]] = p[2]
                    cond_op_gt[g[0]] = g[1]
                    cond_val_gt[g[0]] = g[2]

                if set(cond_op_pred.keys()) != set(cond_op_gt.keys()):
                    cond_col_err += 1
                    good = False

                where_op_pred = [
                    cond_op_pred[x] for x in sorted(cond_op_pred.keys())
                ]
                where_op_gt = [
                    cond_op_gt[x] for x in sorted(cond_op_gt.keys())
                ]
                if where_op_pred != where_op_gt:
                    cond_op_err += 1
                    good = False

                where_val_pred = [
                    cond_val_pred[x] for x in sorted(cond_val_pred.keys())
                ]
                where_val_gt = [
                    cond_val_gt[x] for x in sorted(cond_val_gt.keys())
                ]
                if where_val_pred != where_val_gt:
                    cond_val_err += 1
                    good = False

            if not good:
                tot_err += 1

        return np.array(
            (sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err,
             cond_op_err, cond_val_err, cond_rela_err)), tot_err
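    # check_acc returns (component_errors, tot_err): component_errors counts
    # mistakes per sub-task in the order (sel_num, sel_col, agg, cond_num,
    # cond_col, cond_op, cond_val, cond_rela); tot_err counts queries with
    # at least one mistake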

    def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
        """
        :param score:
        :param q: token-questions
        :param col: token-headers
        :param raw_q: original question sequence
        :return:
        """
        def merge_tokens(tok_list, raw_tok_str):
            tok_str = raw_tok_str  # .lower()
            special = {
                '-LRB-': '(',
                '-RRB-': ')',
                '-LSB-': '[',
                '-RSB-': ']',
                '``': '"',
                '\'\'': '"',
                '--': u'\u2013'
            }
            ret = ''
            double_quote_appear = 0
            for raw_tok in tok_list:
                if not raw_tok:
                    continue
                tok = special.get(raw_tok, raw_tok)
                if tok == '"':
                    double_quote_appear = 1 - double_quote_appear
                if len(ret) == 0:
                    pass
                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
                    ret = ret + ' '
                elif len(ret) > 0 and ret + tok in tok_str:
                    pass
                elif tok == '"':
                    if double_quote_appear:
                        ret = ret + ' '
                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
                        and (ret[-1] != '"' or not double_quote_appear):
                    ret = ret + ' '
                ret = ret + tok
            return ret.strip()
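        # e.g. (assumed tokenisation) merge_tokens(['what', "'s", 'the'],
        # "what's the") -> "what's the": a space is inserted only when the
        # spaced variant still matches a substring of the raw question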

        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
        # [64,4,6], [64,14], ..., [64,4]
        sel_num_score = sel_num_score.data.cpu().numpy()
        sel_score = sel_score.data.cpu().numpy()
        agg_score = agg_score.data.cpu().numpy()
        where_rela_score = where_rela_score.data.cpu().numpy()
        ret_queries = []
        B = len(agg_score)
        cond_num_score,cond_col_score,cond_op_score,cond_str_score =\
            [x.data.cpu().numpy() for x in cond_score]
        for b in range(B):
            cur_query = {}
            cur_query['sel'] = []
            cur_query['agg'] = []
            sel_num = np.argmax(sel_num_score[b])
            # the sel_num most-probable column indexes
            max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
            # per column slot, aggregation indexes sorted by score
            max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
            cur_query['sel'].extend([int(i) for i in max_col_idxes])
            cur_query['agg'].extend([i[0] for i in max_agg_idxes])
            cur_query['cond_conn_op'] = np.argmax(where_rela_score[b])
            cur_query['conds'] = []
            cond_num = np.argmax(cond_num_score[b])
            all_toks = ['<BEG>'] + q[b] + ['<END>']
            max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
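            # the cond_num highest-scoring columns become the condition columns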
            for idx in range(cond_num):
                cur_cond = []
                cur_cond.append(max_idxes[idx])  # where-col
                cur_cond.append(np.argmax(cond_op_score[b][idx]))  # where-op
                cur_cond_str_toks = []
                for str_score in cond_str_score[b][idx]:
                    str_tok = np.argmax(str_score[:len(all_toks)])
                    str_val = all_toks[str_tok]
                    if str_val == '<END>':
                        break
                    cur_cond_str_toks.append(str_val)
                cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))
                cur_query['conds'].append(cur_cond)
            ret_queries.append(cur_query)
        return ret_queries
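
# Side note: a minimal, self-contained sketch (an illustration, not part of the
# original model) showing that the hand-written weighted BCE used in loss()
# matches nn.BCEWithLogitsLoss with pos_weight=3, up to the 1e-10 stabiliser.
# The shapes and the weight 3 below are assumptions taken from the code above.
import torch
import torch.nn as nn

torch.manual_seed(0)
scores = torch.randn(4, 10)   # raw column scores: batch of 4, 10 candidate columns
targets = torch.zeros(4, 10)
targets[0, [1, 3]] = 1        # e.g. columns 1 and 3 are true in sample 0

# hand-written form, as in loss()
prob = torch.sigmoid(scores)
manual = -torch.mean(3 * targets * torch.log(prob + 1e-10) +
                     (1 - targets) * torch.log(1 - prob + 1e-10))

# built-in equivalent operating directly on the raw scores
builtin = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0))(scores, targets)

print(manual.item(), builtin.item())  # the two values agree closely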