Example #1
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.sel_col_name_enc)

        if self.use_ca:
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
            att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, :, num:] = -100
            att = self.softmax(att_val.view(
                (-1, max_x_len))).view(B, -1, max_x_len)
            K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        else:
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
            att_val = self.sel_att(h_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, num:] = -100
            att = self.softmax(att_val)
            K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
            K_sel_expand = K_sel.unsqueeze(1)

        sel_score = self.sel_out(self.sel_out_K(K_sel_expand) +
                                 self.sel_out_col(e_col)).squeeze()
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
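
All of these modules share one masking idiom: positions past a sequence's true length are set to a large negative value before the softmax, so padding receives near-zero attention weight. A minimal, runnable sketch of that step (the helper name masked_softmax is ours, not part of the module above):

    import torch

    def masked_softmax(att_val, lengths):
        # att_val: [B, max_len] raw attention logits; lengths: true length per example
        max_len = att_val.size(1)
        for idx, num in enumerate(lengths):
            if num < max_len:
                att_val[idx, num:] = -100  # padded positions get ~0 weight
        return torch.softmax(att_val, dim=-1)

    # Batch of two sequences with true lengths 3 and 2, padded to 4
    att = masked_softmax(torch.randn(2, 4), [3, 2])
    assert torch.allclose(att.sum(-1), torch.ones(2))
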
Example #2
    def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
                col_len=None, col_num=None, gt_sel=None):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
        if self.use_ca:
            e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                       self.agg_col_name_enc)
            chosen_sel_idx = torch.LongTensor(gt_sel)
            aux_range = torch.LongTensor(range(len(gt_sel)))
            if x_emb_var.is_cuda:
                chosen_sel_idx = chosen_sel_idx.cuda()
                aux_range = aux_range.cuda()
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            att_val = torch.bmm(self.agg_att(h_enc),
                                chosen_e_col.unsqueeze(2)).squeeze()
        else:
            att_val = self.agg_att(h_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)

        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        agg_score = self.agg_out(K_agg)
        return agg_score
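
Example #2 conditions aggregation on the ground-truth SELECT column: pairing a 0..B-1 range tensor with the per-example column index makes e_col[aux_range, chosen_sel_idx] pick exactly one column embedding per batch row. A small sketch of that gather (shapes chosen for illustration):

    import torch

    B, col_n, hid = 3, 5, 8
    e_col = torch.randn(B, col_n, hid)      # column-name encodings
    gt_sel = [4, 0, 2]                      # chosen column index per example

    chosen_sel_idx = torch.tensor(gt_sel, dtype=torch.long)
    aux_range = torch.arange(B)
    chosen_e_col = e_col[aux_range, chosen_sel_idx]   # [B, hid]
    assert torch.equal(chosen_e_col[1], e_col[1, 0])
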
Example #3
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        # Predict the GROUP BY column from the question and column encodings
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.grp_col_name_enc)  # [bs, col_num, hid]
        h_enc, _ = run_lstm(self.grp_lstm, x_emb_var,
                            x_len)  # [bs, seq_len, hid]

        att_val = torch.bmm(e_col,
                            self.grp_att(h_enc).transpose(
                                1, 2))  # [bs, col_num, seq_len]
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view(
            (-1, max_x_len))).view(B, -1, max_x_len)
        K_grp_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

        grp_score = self.grp_out(
            self.grp_out_K(K_grp_expand) + self.grp_out_col(e_col)).squeeze()
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                grp_score[idx, num:] = -100

        return grp_score
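
The column-conditional attention above scores every (column, question-token) pair with one batched matrix product, then takes a per-column weighted sum over the question encoding. A shape-only sketch under assumed dimensions (the grp_att projection is omitted for brevity):

    import torch

    B, col_n, seq, hid = 2, 6, 10, 16
    e_col = torch.randn(B, col_n, hid)     # [bs, col_num, hid]
    h_enc = torch.randn(B, seq, hid)       # [bs, seq_len, hid]

    att_val = torch.bmm(e_col, h_enc.transpose(1, 2))   # [bs, col_num, seq_len]
    att = torch.softmax(att_val, dim=-1)
    # one attended question summary per column
    K = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)  # [bs, col_num, hid]
    assert K.shape == (B, col_n, hid)
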
Example #4
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        B = len(x_len)
        max_x_len = max(x_len)

        # Predict the condition relationship (e.g. AND/OR) part
        # First use column embeddings to compute the initial hidden state,
        # then run the LSTM and predict the relationship
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.where_rela_lstm)
        col_att_val = self.where_rela_col_att(e_num_col).squeeze()
        for idx, num in enumerate(col_num):
            if num < max(col_num):
                col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h // 2).transpose(
            0, 1).contiguous()
        h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h // 2).transpose(
            0, 1).contiguous()

        h_enc, _ = run_lstm(self.where_rela_lstm,
                            x_emb_var,
                            x_len,
                            hidden=(h1, h2))

        att_val = self.where_rela_att(h_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -1000000
        att_val = self.softmax(att_val)

        where_rela_num = (h_enc * att_val.unsqueeze(2).expand_as(h_enc)).sum(1)
        where_rela_score = self.where_rela_out(where_rela_num)
        return where_rela_score
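
The view(B, 4, self.N_h // 2).transpose(0, 1) reshape produces the (num_layers * num_directions, B, hidden_size) layout that nn.LSTM expects for its initial states; the 4 matches a 2-layer bidirectional encoder with N_h // 2 units per direction. A minimal sketch under those assumptions (layer names are ours):

    import torch
    import torch.nn as nn

    B, N_h = 3, 32
    K_num_col = torch.randn(B, N_h)             # attended column summary
    col2hid = nn.Linear(N_h, 4 * (N_h // 2))    # one projection for all 4 states

    h0 = col2hid(K_num_col).view(B, 4, N_h // 2).transpose(0, 1).contiguous()
    lstm = nn.LSTM(N_h, N_h // 2, num_layers=2, bidirectional=True,
                   batch_first=True)
    out, _ = lstm(torch.randn(B, 10, N_h), (h0, torch.zeros_like(h0)))
    assert out.shape == (B, 10, N_h)            # 2 directions * N_h // 2
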
Example #5
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
        B = len(x_len)
        max_x_len = max(x_len)

        # Predict the number of columns in the SELECT clause
        # First use column embeddings to compute the initial hidden state,
        # then run the LSTM and predict the select number
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.sel_num_lstm)
        num_col_att_val = self.sel_num_col_att(e_num_col).squeeze()
        for idx, num in enumerate(col_num):
            if num < max(col_num):
                num_col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()
        sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()

        h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                                hidden=(sel_num_h1, sel_num_h2))

        num_att_val = self.sel_num_att(h_num_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -1000000
        num_att = self.softmax(num_att_val)

        K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        sel_num_score = self.sel_num_out(K_sel_num)
        return sel_num_score
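
sel_num_score is a vector of per-class logits, one class per possible SELECT-column count, so at inference the predicted count is simply its argmax. A usage sketch (the class count is assumed for illustration):

    import torch

    B, num_classes = 2, 4                          # e.g. 1..4 selected columns
    sel_num_score = torch.randn(B, num_classes)    # output of sel_num_out
    sel_num = torch.argmax(sel_num_score, dim=1)   # predicted class per query
    assert sel_num.shape == (B,)
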
Example #6
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        B = len(x_len)
        max_x_len = max(x_len)

        # Predict the number of columns in the SELECT clause
        # First use column embeddings to compute the initial hidden state,
        # then run the LSTM and predict the select number

        # Run the column names through the LSTM to get column-name embeddings
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.sel_num_lstm)
        # Squeeze out the singleton dimension
        num_col_att_val = self.sel_num_col_att(e_num_col).squeeze()
        # Set positions beyond each sequence's true length to a large negative
        # number so their softmax weights are effectively zero
        for idx, num in enumerate(col_num):
            if num < max(col_num):
                num_col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(num_col_att_val)
        # Padded positions now contribute almost nothing to the weighted sum
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)

        # Project through a fully connected layer and reshape; the attended
        # column summary becomes the initial hidden state of the question BiLSTM
        sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
            (B, 4, self.N_h // 2)).transpose(0, 1).contiguous()
        sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
            (B, 4, self.N_h // 2)).transpose(0, 1).contiguous()
        # [bs, max_x_len, N_h]
        h_num_enc, _ = run_lstm(self.sel_num_lstm,
                                x_emb_var,
                                x_len,
                                hidden=(sel_num_h1, sel_num_h2))
        # [bs, max_x_len]
        num_att_val = self.sel_num_att(h_num_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -1000000
        num_att = self.softmax(num_att_val)

        K_sel_num = (h_num_enc *
                     num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        sel_num_score = self.sel_num_out(K_sel_num)
        return sel_num_score
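
One sharp edge in these modules is the bare .squeeze() after the attention projections: it drops every singleton dimension, so a batch of size 1 loses its batch axis as well. squeeze(-1) removes only the projection dimension; a quick demonstration:

    import torch

    att = torch.randn(1, 10, 1)      # batch of one question, seq_len 10
    print(att.squeeze().shape)       # torch.Size([10])    -- batch dim lost
    print(att.squeeze(-1).shape)     # torch.Size([1, 10]) -- intended shape
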
Example #7
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None,
                gt_sel_num=None):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack([e_col[b, x] for x in gt_sel[b]] +
                                      [e_col[b, 0]] * (4 - len(gt_sel[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)
        att_val = torch.matmul(
            self.agg_att(h_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val.view(B * 4, -1)).view(B, 4, -1)

        K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        agg_score = self.agg_out(
            self.agg_out_K(K_agg) + self.col_out_col(col_emb)).squeeze()
        return agg_score
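
Both this aggregator and the condition predictor below pad the per-example list of chosen columns to a fixed maximum of four by repeating column 0's embedding, so the batch stacks into a regular [B, 4, hid] tensor. A standalone sketch of that padding:

    import torch

    B, col_n, hid = 2, 6, 8
    e_col = torch.randn(B, col_n, hid)
    chosen = [[3, 1], [2]]           # chosen column indices per example

    col_emb = torch.stack([
        torch.stack([e_col[b, x] for x in chosen[b]] +
                    [e_col[b, 0]] * (4 - len(chosen[b])))
        for b in range(B)
    ])
    assert col_emb.shape == (B, 4, hid)
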
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num, gt_where, gt_cond, reinforce):
        max_x_len = max(x_len)
        B = len(x_len)
        if reinforce:
            raise NotImplementedError('Our model doesn\'t have RL')

        # Predict the number of conditions
        # First use column embeddings to compute the initial hidden state,
        # then run the LSTM and predict the condition number.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.cond_num_name_enc)
        num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
        for idx, num in enumerate(col_num):
            if num < max(col_num):
                num_col_att_val[idx, num:] = -100
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()
        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()

        h_num_enc, _ = run_lstm(self.cond_num_lstm,
                                x_emb_var,
                                x_len,
                                hidden=(cond_num_h1, cond_num_h2))

        num_att_val = self.cond_num_att(h_num_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -100
        num_att = self.softmax(num_att_val)

        K_cond_num = (h_num_enc *
                      num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        cond_num_score = self.cond_num_out(K_cond_num)

        #Predict the columns of conditions
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_col_name_enc)
        h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)

        if self.use_ca:
            col_att_val = torch.bmm(
                e_cond_col,
                self.cond_col_att(h_col_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    col_att_val[idx, :, num:] = -100
            col_att = self.softmax(col_att_val.view(
                (-1, max_x_len))).view(B, -1, max_x_len)
            K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
        else:
            col_att_val = self.cond_col_att(h_col_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    col_att_val[idx, num:] = -100
            col_att = self.softmax(col_att_val)
            K_cond_col = (h_col_enc *
                          col_att.unsqueeze(2)).sum(1).unsqueeze(1)

        cond_col_score = self.cond_col_out(
            self.cond_col_out_K(K_cond_col) +
            self.cond_col_out_col(e_cond_col)).squeeze()
        max_col_num = max(col_num)
        for b, num in enumerate(col_num):
            if num < max_col_num:
                cond_col_score[b, num:] = -100

        #Predict the operator of conditions
        chosen_col_gt = []
        if gt_cond is None:
            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
            col_scores = cond_col_score.data.cpu().numpy()
            chosen_col_gt = [
                list(np.argsort(-col_scores[b])[:cond_nums[b]])
                for b in range(len(cond_nums))
            ]
        else:
            chosen_col_gt = [[x[0] for x in one_gt_cond]
                             for one_gt_cond in gt_cond]

        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_op_name_enc)
        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack(
                [e_cond_col[b, x]
                 for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
                (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        if self.use_ca:
            op_att_val = torch.matmul(
                self.cond_op_att(h_op_enc).unsqueeze(1),
                col_emb.unsqueeze(3)).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, :, num:] = -100
            op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
            K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
        else:
            op_att_val = self.cond_op_att(h_op_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, num:] = -100
            op_att = self.softmax(op_att_val)
            K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

        cond_op_score = self.cond_op_out(
            self.cond_op_out_K(K_cond_op) +
            self.cond_op_out_col(col_emb)).squeeze()

        #Predict the string of conditions
        h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_str_name_enc)
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack(
                [e_cond_col[b, x] for x in chosen_col_gt[b]] +
                [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        if gt_where is not None:
            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
            g_str_s_flat, _ = self.cond_str_decoder(
                gt_tok_seq.view(B * 4, -1, self.max_tok_num))
            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)

            h_ext = h_str_enc.unsqueeze(1).unsqueeze(
                1)  # question encoding matrix
            g_ext = g_str_s.unsqueeze(3)  # value-string encoding
            col_ext = col_emb.unsqueeze(2).unsqueeze(2)  # column encoding

            cond_str_score = self.cond_str_out(
                self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                self.cond_str_out_col(col_ext)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100
        else:
            h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
            col_ext = col_emb.unsqueeze(2).unsqueeze(2)
            scores = []

            t = 0
            init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:, 0, 0] = 1  #Set the <BEG> token
            if self.gpu:
                cur_inp = torch.from_numpy(init_inp).cuda()
            else:
                cur_inp = torch.from_numpy(init_inp)
            cur_h = None
            while t < 50:
                if cur_h:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
                else:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
                g_ext = g_str_s.unsqueeze(3)

                cur_cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze()
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_str_score[b, :, num:] = -100
                scores.append(cur_cond_str_score)

                _, ans_tok_var = cur_cond_str_score.view(B * 4,
                                                         max_x_len).max(1)
                ans_tok = ans_tok_var.data.cpu()
                # Convert the greedy prediction to a one-hot input for the next step
                cur_inp = torch.zeros(B * 4, self.max_tok_num).scatter_(
                    1, ans_tok.unsqueeze(1), 1)
                if self.gpu:
                    cur_inp = cur_inp.cuda()
                cur_inp = cur_inp.unsqueeze(1)

                t += 1

            cond_str_score = torch.stack(scores, 2)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100  #[B, IDX, T, TOK_NUM]

        cond_score = (cond_num_score, cond_col_score, cond_op_score,
                      cond_str_score)

        return cond_score
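
When no ground-truth string is supplied, the condition-value decoder above runs greedily for up to 50 steps, feeding each step's argmax token back in as a one-hot vector. A stripped-down sketch of that feedback loop, with a plain nn.LSTM and a linear head standing in for cond_str_decoder and cond_str_out (names and sizes are ours):

    import torch
    import torch.nn as nn

    B4, tok_num, N_h, steps = 8, 20, 16, 5
    decoder = nn.LSTM(tok_num, N_h, batch_first=True)   # stand-in decoder
    to_logits = nn.Linear(N_h, tok_num)                 # stand-in output head

    cur_inp = torch.zeros(B4, 1, tok_num)
    cur_inp[:, 0, 0] = 1                                # <BEG> token one-hot
    cur_h, scores = None, []
    for _ in range(steps):
        out, cur_h = decoder(cur_inp) if cur_h is None else decoder(cur_inp, cur_h)
        logits = to_logits(out.squeeze(1))              # [B4, tok_num]
        scores.append(logits)
        ans_tok = logits.argmax(1)                      # greedy pick
        cur_inp = torch.zeros(B4, tok_num).scatter_(    # back to one-hot
            1, ans_tok.unsqueeze(1), 1).unsqueeze(1)
    score_seq = torch.stack(scores, 1)                  # [B4, steps, tok_num]
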
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                table_content, q_seq, gt_where, gt_cond):
        max_x_len = max(x_len)

        # batch size; x_len is a list of token counts for all queries
        B = len(x_len)

        # Predict the number of conditions
        # First use column embeddings to compute the initial hidden state
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.cond_num_name_enc)
        num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
        for idx, num in enumerate(col_num):
            if num < max(col_num):
                num_col_att_val[idx, num:] = -100
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
            B, 2 * self.N_depth, self.N_h // 2).transpose(0, 1).contiguous()
        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
            B, 2 * self.N_depth, self.N_h // 2).transpose(0, 1).contiguous()

        # Then run the LSTM and predict condition number.
        h_num_enc, _ = run_lstm(self.cond_num_lstm,
                                x_emb_var,
                                x_len,
                                hidden=(cond_num_h1, cond_num_h2))

        num_att_val = self.cond_num_att(h_num_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -100
        num_att = self.softmax(num_att_val)

        K_cond_num = (h_num_enc *
                      num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        cond_num_score = self.cond_num_out(K_cond_num)
        cond_num = torch.argmax(cond_num_score, dim=1)

        assert cond_num_score.shape == (B, 5)
        assert cond_num.shape[-1] == B

        # Predict the columns of conditions
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_col_name_enc)
        h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)

        col_att_val = torch.bmm(e_cond_col,
                                self.cond_col_att(h_col_enc).transpose(1, 2))

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, :, num:] = -100

        col_att = self.softmax(col_att_val.view(
            (-1, max_x_len))).view(B, -1, max_x_len)
        K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)

        cond_col_score = self.cond_col_out(
            self.cond_col_out_K(K_cond_col) +
            self.cond_col_out_col(e_cond_col)).squeeze()

        # padding -100 for different sequence length
        max_col_num = max(col_num)
        for b, num in enumerate(col_num):
            if num < max_col_num:
                cond_col_score[b, num:] = -100

        # Use gt_cond during training; use the predicted columns at test time
        # chosen_col_gt = [[col_1, col_2, ...], ...]: chosen columns per query
        if gt_cond is None:
            # select num of condition from 0 to 4
            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
            col_scores = cond_col_score.data.cpu().numpy()

            # select the first cond_nums columns as prediction
            chosen_col_gt = [
                list(np.argsort(-col_scores[b])[:cond_nums[b]])
                for b in range(len(cond_nums))
            ]
        else:
            chosen_col_gt = [[x[0] for x in one_gt_cond]
                             for one_gt_cond in gt_cond]

        # force max column index is under the table length
        assert len(chosen_col_gt) == B
        for b in range(B):
            table_length = len(table_content[b])
            for c_index, x in enumerate(chosen_col_gt[b]):
                if x > table_length - 1:
                    chosen_col_gt[b][c_index] = table_length - 1

        # Predict the operator of conditions
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_op_name_enc)
        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
        col_emb = []

        # Pad the columns to 4, stack([col_1_emb, col_2_emb, pad_0_emb, pad_0_emb])
        for b in range(B):
            cur_col_emb = torch.stack(
                [e_cond_col[b, x] for x in chosen_col_gt[b]] +
                [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        op_att_val = torch.matmul(
            self.cond_op_att(h_op_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze()

        # add pad for query embedding
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, :, num:] = -100

        op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)

        cond_op_score = self.cond_op_out(
            self.cond_op_out_K(K_cond_op) +
            self.cond_op_out_col(col_emb)).squeeze()


        # cond_op_score=[batch size, max condition num for one query, operation num]
        assert cond_op_score.shape == (B, 4, 4)

        # Predict the value of conditions
        h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_str_name_enc)
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack(
                [e_cond_col[b, x] for x in chosen_col_gt[b]] +
                [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        if self.use_table:

            if gt_where is not None:
                assert len(gt_where) == 4
                gt_index, gt_value, condition_num, max_value_length = gt_where
                # condition num for one batch queries
                assert len(gt_value) == gt_index.shape[0]

                # gt_value_embd = [condition num, max value num, hidden state]
                gt_value_embd = self.emb_layer.condition_value_batch(
                    gt_value, max_value_length)

                query_embed = torch.mean(x_emb_var, dim=1)
                queries_embed = torch.zeros([len(gt_index), 768])
                one_cond = 0
                for c, cond in enumerate(condition_num):
                    for _ in range(cond):
                        queries_embed[one_cond, :] = query_embed[c]
                        one_cond += 1
                assert one_cond == gt_index.shape[0]

                gt_value_embd = gt_value_embd.transpose(0, 1)
                value_score = torch.matmul(gt_value_embd,
                                           queries_embed.transpose(0, 1))
                value_score = torch.diagonal(value_score,
                                             offset=0,
                                             dim1=-2,
                                             dim2=-1).transpose(0, 1)

                # value_score=[condition num, max value length]
                for l, length in enumerate([len(x) for x in gt_value]):
                    if length < max_value_length:
                        value_score[l, length:] = -100

                value_utils = self.softmax(value_score)

            else:
                cond_value, max_value_length = self.cond_value_batch(
                    table_content, q_seq, cond_num, chosen_col_gt,
                    cond_op_score)

                cond_value_embd = self.emb_layer.condition_value_batch(
                    cond_value, max_value_length)

                query_embed = torch.mean(x_emb_var, dim=1)
                queries_embed = torch.zeros([len(cond_value), 768])

                one_cond = 0
                for c, cond in enumerate(cond_num):
                    for _ in range(cond):
                        queries_embed[one_cond, :] = query_embed[c]
                        one_cond += 1
                assert one_cond == len(cond_value)

                cond_value_embd = cond_value_embd.transpose(0, 1)
                value_score = torch.matmul(cond_value_embd,
                                           queries_embed.transpose(0, 1))
                value_score = torch.diagonal(value_score,
                                             offset=0,
                                             dim1=-2,
                                             dim2=-1).transpose(0, 1)
                for l, length in enumerate([len(x) for x in cond_value]):
                    if length < max_value_length:
                        value_score[l, length:] = -100

                softmax_score = self.softmax(value_score)
                value_utils = (softmax_score, cond_value, cond_num)

        else:
            # for training
            if gt_where is not None:
                # gt_where = [[[index range for value in query], column 2, ...],table 2, ...]
                gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
                g_str_s_flat, _ = self.cond_str_decoder(
                    gt_tok_seq.view(B * 4, -1, self.max_tok_num))
                g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)

                h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
                g_ext = g_str_s.unsqueeze(3)
                col_ext = col_emb.unsqueeze(2).unsqueeze(2)

                cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze()

                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        # [batch_size, condition num, T, TOK_NUM]
                        cond_str_score[b, :, :, num:] = -100

            # for validation and inference
            else:
                h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
                col_ext = col_emb.unsqueeze(2).unsqueeze(2)
                scores = []

                t = 0
                init_inp = np.zeros((B * 4, 1, self.max_tok_num),
                                    dtype=np.float32)
                # Set the <BEG> token
                init_inp[:, 0, 0] = 1
                if self.gpu:
                    cur_inp = torch.from_numpy(init_inp).to('cuda')
                else:
                    cur_inp = torch.from_numpy(init_inp)
                cur_h = None
                while t < 50:
                    if cur_h:
                        g_str_s_flat, cur_h = self.cond_str_decoder(
                            cur_inp, cur_h)
                    else:
                        g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                    g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
                    g_ext = g_str_s.unsqueeze(3)

                    cur_cond_str_score = self.cond_str_out(
                        self.cond_str_out_h(h_ext) +
                        self.cond_str_out_g(g_ext) +
                        self.cond_str_out_col(col_ext)).squeeze()
                    for b, num in enumerate(x_len):
                        if num < max_x_len:
                            cur_cond_str_score[b, :, num:] = -100
                    scores.append(cur_cond_str_score)

                    _, ans_tok_var = cur_cond_str_score.view(B * 4,
                                                             max_x_len).max(1)
                    ans_tok = ans_tok_var.data.cpu()
                    cur_inp = torch.zeros(B * 4, self.max_tok_num).scatter_(
                        1, ans_tok.unsqueeze(1), 1)
                    if self.gpu:
                        cur_inp = cur_inp.to('cuda')

                    cur_inp = cur_inp.unsqueeze(1)

                    t += 1

                cond_str_score = torch.stack(scores, 2)
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cond_str_score[b, :, :, num:] = -100

        if self.use_table:
            cond_score = (cond_num_score, cond_col_score, cond_op_score,
                          value_utils)
        else:
            cond_score = (cond_num_score, cond_col_score, cond_op_score,
                          cond_str_score)

        return cond_score
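
In the table-value branch, each candidate value is scored by the dot product between its token embeddings and its (mean-pooled) question embedding; the matmul followed by torch.diagonal keeps exactly those per-row products. The same scores come from an elementwise product and sum, as this equivalence sketch shows (dimensions assumed):

    import torch

    n_cond, max_len, dim = 3, 5, 768
    value_embd = torch.randn(n_cond, max_len, dim)   # per-condition value tokens
    queries_embed = torch.randn(n_cond, dim)         # matching query embeddings

    v = value_embd.transpose(0, 1)                               # [max_len, n_cond, dim]
    score_a = torch.diagonal(torch.matmul(v, queries_embed.transpose(0, 1)),
                             dim1=-2, dim2=-1).transpose(0, 1)   # [n_cond, max_len]
    score_b = (value_embd * queries_embed.unsqueeze(1)).sum(-1)  # same result, cheaper
    assert torch.allclose(score_a, score_b, atol=1e-4)
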