Esempio n. 1
0
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len):
        """Predict the AND/OR score from the question and the SQL history.

        Each sequence is pooled with a uniform attention over its valid
        positions: every valid position gets the same logit (1) and padded
        positions get -100, so after ``self.softmax`` the pooling is
        (approximately) a mean over the valid tokens.

        Args:
            q_emb_var: question embeddings, encoded with ``self.q_lstm``.
            q_len: per-example question lengths.
            hs_emb_var: SQL-history embeddings, encoded with ``self.hs_lstm``.
            hs_len: per-example history lengths.

        Returns:
            ``ao_score`` softmax-normalised along dim 1; (B, 2) per the
            original shape comment.
        """
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)

        # Uniform logits, allocated on the encoder's device/dtype instead of
        # a hard-coded .cuda() so the module also runs on CPU.
        att_val_q = q_enc.new_ones((B, max_q_len))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_q[idx, num:] = -100
        att_prob_q = self.softmax(att_val_q)
        q_weighted = (q_enc * att_prob_q.unsqueeze(2)).sum(1)

        # Same as above for the SQL history.
        att_val_h = hs_enc.new_ones((B, max_hs_len))
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_h[idx, num:] = -100
        att_prob_h = self.softmax(att_val_h)
        hs_weighted = (hs_enc * att_prob_h.unsqueeze(2)).sum(1)

        # ao_score: (B, 2)
        ao_score = self.ao_out(
            self.ao_out_q(q_weighted) +
            int(self.use_hs) * self.ao_out_hs(hs_weighted))

        # 06/14/2019: softmax layer. dim=1 is what the deprecated implicit-dim
        # call selected for a 2-D input; passing it explicitly keeps the result
        # and silences the warning.
        return F.softmax(ao_score, dim=1)
Esempio n. 2
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        """Score every candidate column for the SELECT clause.

        With ``self.use_ca`` each column attends over the question tokens
        (column attention); otherwise one column-independent attention
        pooling of the question is shared by all columns.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num: per-example number of real columns.

        Returns:
            ``sel_score`` with one score per column; padded columns are -100.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.sel_col_name_enc)
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)

        if self.use_ca:
            att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, :, num:] = -100
            att = self.softmax(att_val.view(-1, max_x_len)).view(
                B, -1, max_x_len)
            K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        else:
            # Drop only the trailing singleton dim: a bare squeeze() would
            # also remove the batch dim when B == 1 and break the masking.
            att_val = self.sel_att(h_enc).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, num:] = -100
            att = self.softmax(att_val)
            K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
            K_sel_expand = K_sel.unsqueeze(1)

        sel_score = self.sel_out(
            self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col):
        """Compute ``rt_score`` for the target column of each example.

        The encoding of column ``gt_col[b]`` keys an attention pooling over
        the question tokens and another over the SQL-history tokens; the two
        pooled vectors and the column encoding are combined by ``self.rt_out``.

        Args:
            q_emb_var, q_len: question embeddings and per-example lengths.
            hs_emb_var, hs_len: SQL-history embeddings and lengths.
            col_emb_var, col_len, col_name_len: column-name inputs forwarded
                to ``col_name_encode``.
            gt_col: per-example index of the target/predicted column.

        Returns:
            ``rt_score``; (B, 2) per the original shape comment.
        """
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

        # Target/predicted column's encoding per example: (B, hid_dim).
        col_emb = torch.stack([col_enc[b, gt_col[b]] for b in range(B)])

        # Question tokens weighted by attention keyed on the column.
        att_val_qc = torch.bmm(col_emb.unsqueeze(1),
                               self.q_att(q_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, num:] = -100
        att_prob_qc = self.softmax(att_val_qc)
        q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

        # Same as above for the SQL history.
        att_val_hc = torch.bmm(col_emb.unsqueeze(1),
                               self.hs_att(hs_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc[idx, num:] = -100
        att_prob_hc = self.softmax(att_val_hc)
        hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)

        # rt_score: (B, 2)
        rt_score = self.rt_out(self.rt_out_q(q_weighted) +
                               int(self.use_hs) * self.rt_out_hs(hs_weighted) +
                               self.rt_out_c(col_emb))

        return rt_score
Esempio n. 4
0
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col):
        """Predict the number of aggregators and the aggregators for a column.

        The encoding of column ``gt_col[b]`` keys four attention poolings:
        question and history for the aggregator-count head, and question and
        history again (with separate attention layers) for the aggregator head.

        Args:
            q_emb_var, q_len: question embeddings and per-example lengths.
            hs_emb_var, hs_len: SQL-history embeddings and lengths.
            col_emb_var, col_len, col_name_len: column-name inputs forwarded
                to ``col_name_encode``.
            gt_col: per-example index of the target/predicted column.

        Returns:
            ``(agg_num_score, agg_score)`` divided by ``self.T1`` and
            ``self.T2`` respectively; (B, 4) and (B, 5) per the original
            shape comments.
        """
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

        # Encoding of each example's target column: (B, hid_dim).
        col_emb = torch.stack([col_enc[b, gt_col[b]] for b in range(B)])

        # Predict agg number
        q_weighted_num = self._col_attn_pool(col_emb, q_enc, q_len, self.q_num_att)
        hs_weighted_num = self._col_attn_pool(col_emb, hs_enc, hs_len, self.hs_num_att)
        # agg_num_score: (B, 4)
        agg_num_score = self.agg_num_out(self.agg_num_out_q(q_weighted_num) +
                                         int(self.use_hs) * self.agg_num_out_hs(hs_weighted_num) +
                                         self.agg_num_out_c(col_emb)) / self.T1

        # Predict aggregators
        q_weighted = self._col_attn_pool(col_emb, q_enc, q_len, self.q_att)
        hs_weighted = self._col_attn_pool(col_emb, hs_enc, hs_len, self.hs_att)
        # agg_score: (B, 5)
        agg_score = self.agg_out(self.agg_out_q(q_weighted) +
                                 int(self.use_hs) * self.agg_out_hs(hs_weighted) +
                                 self.agg_out_c(col_emb)) / self.T2

        return agg_num_score, agg_score

    def _col_attn_pool(self, col_emb, enc, lens, att_layer):
        """Attention-pool ``enc`` over time, keyed by the column encoding.

        Positions at or beyond each example's length get logit -100 before
        ``self.softmax`` so padding receives ~zero weight.

        Args:
            col_emb: per-example column encoding, one row per example.
            enc: encoder outputs to pool (time on dim 1).
            lens: per-example valid lengths of ``enc``.
            att_layer: projection applied to ``enc`` before the dot product.

        Returns:
            The attention-weighted sum of ``enc`` over dim 1.
        """
        max_len = max(lens)
        att_val = torch.bmm(col_emb.unsqueeze(1),
                            att_layer(enc).transpose(1, 2)).view(col_emb.size(0), -1)
        for idx, num in enumerate(lens):
            if num < max_len:
                att_val[idx, num:] = -100
        att_prob = self.softmax(att_val)
        return (enc * att_prob.unsqueeze(2)).sum(1)
Esempio n. 5
0
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var,
                mkw_len):
        """Score each multi-SQL keyword against the question and the history.

        Every keyword encoding attends over the question tokens and over the
        SQL-history tokens. The two pooled vectors plus the keyword encoding
        are combined by ``self.multi_out``.

        Args:
            q_emb_var, q_len: question embeddings and per-example lengths.
            hs_emb_var, hs_len: SQL-history embeddings and lengths.
            mkw_emb_var, mkw_len: keyword embeddings and lengths.

        Returns:
            Scores of shape (B, K), where K is the number of keywords (4 per
            the shape comments below), softmax-normalised along dim 1.
        """
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)

        # q_enc: (B, max_q_len, hid_dim)
        # hs_enc: (B, max_hs_len, hid_dim)
        # mkw_enc: (B, 4, hid_dim)
        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        mkw_enc, _ = run_lstm(self.mkw_lstm, mkw_emb_var, mkw_len)

        # Attention between keywords and question tokens.
        # att_val_qmkw: (B, 4, max_q_len)
        att_val_qmkw = torch.bmm(mkw_enc, self.q_att(q_enc).transpose(1, 2))
        # Padded question positions get -100 so they receive ~0 weight.
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qmkw[idx, :, num:] = -100
        att_prob_qmkw = self.softmax(att_val_qmkw.view(-1, max_q_len)).view(
            B, -1, max_q_len)
        # (B, 1, max_q_len, hid) * (B, 4, max_q_len, 1), summed over tokens
        # -> q_weighted: (B, 4, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qmkw.unsqueeze(3)).sum(2)

        # Same as above for the SQL history.
        att_val_hsmkw = torch.bmm(mkw_enc, self.hs_att(hs_enc).transpose(1, 2))
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hsmkw[idx, :, num:] = -100
        att_prob_hsmkw = self.softmax(att_val_hsmkw.view(-1, max_hs_len)).view(
            B, -1, max_hs_len)
        hs_weighted = (hs_enc.unsqueeze(1) *
                       att_prob_hsmkw.unsqueeze(3)).sum(2)

        # self.multi_out(...): (B, 4, 1), flattened to (B, 4) by view.
        multi_score = self.multi_out(
            self.multi_out_q(q_weighted) +
            int(self.use_hs) * self.multi_out_hs(hs_weighted) +
            self.multi_out_c(mkw_enc)).view(B, -1)

        # 06/14/2019: softmax layer. dim=1 matches what the deprecated
        # implicit-dim call chose for this 2-D tensor.
        return F.softmax(multi_score, dim=1)
Esempio n. 6
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        """Score every candidate column for the SELECT clause.

        With ``self.use_ca`` each column attends over the question tokens
        using additive attention (``sel_W1``/``sel_W2``/``sel_att_out``).
        Otherwise a single column-independent attention pooling is used.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num: per-example number of real columns.

        Returns:
            ``sel_score`` with one score per column; padded columns are -100.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        # compute E_col
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.sel_col_name_enc)
        # hidden states of the question LSTM
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)

        if self.use_ca:
            # Additive attention between every column and every question token.
            # (The multiplicative score that used to be computed here was
            # immediately overwritten, so it has been dropped.)
            h_ext = h_enc.unsqueeze(1).unsqueeze(1)
            col_ext = e_col.unsqueeze(2).unsqueeze(2)
            # Reshape to (B, num_cols, max_x_len) instead of a bare squeeze()
            # so a batch or column set of size 1 keeps its dimension.
            att_val = self.sel_att_out(
                self.sel_W1(h_ext) + self.sel_W2(col_ext)).reshape(
                    B, -1, max_x_len)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, :, num:] = -100
            # L-dimension attention weight
            att = self.softmax(att_val.view(-1, max_x_len)).view(
                B, -1, max_x_len)
            # compute E_Q|col
            K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        else:
            # normal method without column attention
            att_val = self.sel_att(h_enc).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, num:] = -100
            att = self.softmax(att_val)
            K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
            K_sel_expand = K_sel.unsqueeze(1)

        # compute P_selcol(i|Q)
        sel_score = self.sel_out(
            self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
Esempio n. 7
0
    def forward(self,
                x_emb_var,
                x_len,
                agg_emb_var,
                col_inp_var=None,
                col_len=None):
        """Score each aggregation operator against the question.

        Each aggregator embedding attends over the question tokens. The pooled
        question vector and the aggregator embedding are combined by
        ``self.agg_out_f``.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            agg_emb_var: aggregator embeddings (6 per example per the shape
                comments below).
            col_inp_var, col_len: accepted for interface compatibility; unused.

        Returns:
            ``agg_score`` with the trailing singleton dim removed.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

        # agg_enc: (B, 6, hid_dim)
        agg_enc = self.agg_out_agg(agg_emb_var)
        # self.sel_att(h_enc).transpose(1, 2): (B, hid_dim, max_x_len)
        # att_val_agg: (B, 6, max_x_len)
        att_val_agg = torch.bmm(agg_enc, self.sel_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val_agg[idx, :, num:] = -100

        # att_agg: (B, 6, max_x_len)
        att_agg = self.softmax(att_val_agg.view(-1, max_x_len)).view(
            B, -1, max_x_len)
        # (B, 1, max_x_len, hid) * (B, 6, max_x_len, 1) -> K_agg_expand: (B, 6, hid_dim)
        K_agg_expand = (h_enc.unsqueeze(1) * att_agg.unsqueeze(3)).sum(2)
        # Squeeze only the trailing dim so B == 1 keeps its batch dimension.
        agg_score = self.agg_out_f(
            self.agg_out_se(agg_emb_var) +
            self.agg_out_K(K_agg_expand)).squeeze(-1)

        return agg_score
Esempio n. 8
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        """Score every candidate SELECT column using column attention.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num: per-example number of real columns.

        Returns:
            ``sel_score`` with one score per column; padded columns are -100.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.sel_col_name_enc)  # [bs, col_num, hid]
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var,
                            x_len)  # [bs, seq_len, hid]

        att_val = torch.bmm(e_col,
                            self.sel_att(h_enc).transpose(
                                1, 2))  # [bs, col_num, seq_len]
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view(-1, max_x_len)).view(B, -1, max_x_len)
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

        # Squeeze only the trailing dim: a bare squeeze() drops the batch dim
        # when B == 1 and breaks the masking below.
        sel_score = self.sel_out(
            self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        """Score the WHERE-clause relation.

        Presumably this is the connective between conditions; that is inferred
        from the ``where_rela`` naming, so confirm. An attention-pooled summary
        of the column encodings initialises the LSTM state. The LSTM then
        encodes the question, and its attention-pooled output is scored by
        ``self.where_rela_out``.

        Note: the ``col_num`` argument is ignored; it is rebound to the column
        counts returned by ``col_name_encode``.

        Returns:
            ``where_rela_score`` from ``self.where_rela_out``.
        """
        B = len(x_len)
        max_x_len = max(x_len)

        # Pool the column encodings into one vector per example.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.where_rela_lstm)
        col_att_val = self.where_rela_col_att(e_num_col).squeeze(-1)
        max_col_num = max(col_num)  # hoisted: was recomputed every iteration
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        # Initial LSTM states: (4, B, N_h // 2).
        h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h // 2).transpose(
            0, 1).contiguous()
        h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h // 2).transpose(
            0, 1).contiguous()

        h_enc, _ = run_lstm(self.where_rela_lstm,
                            x_emb_var,
                            x_len,
                            hidden=(h1, h2))

        # Attention over question tokens; padding gets a very low logit.
        att_val = self.where_rela_att(h_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -1000000
        att_val = self.softmax(att_val)

        where_rela_num = (h_enc * att_val.unsqueeze(2).expand_as(h_enc)).sum(1)
        where_rela_score = self.where_rela_out(where_rela_num)
        return where_rela_score
Esempio n. 10
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
        """Predict the number of columns in the SELECT clause.

        An attention-pooled summary of the column encodings initialises the
        LSTM's hidden and cell states. The LSTM then encodes the question, and
        its attention-pooled output is scored by ``self.sel_num_out``.

        Note: the ``col_num`` argument is ignored; it is rebound to the column
        counts returned by ``col_name_encode``.

        Returns:
            ``sel_num_score`` from ``self.sel_num_out``.
        """
        B = len(x_len)
        max_x_len = max(x_len)

        # Predict the number of select part
        # First use column embeddings to calculate the initial hidden unit
        # Then run the LSTM and predict select number
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.sel_num_lstm)
        num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)
        max_col_num = max(col_num)  # hoisted out of the loop
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                num_col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        # Integer division: N_h / 2 is a float in Python 3 and .view()
        # rejects non-integer sizes.
        sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()
        sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()

        h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                                hidden=(sel_num_h1, sel_num_h2))

        num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -1000000
        num_att = self.softmax(num_att_val)

        K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(
            h_num_enc)).sum(1)
        sel_num_score = self.sel_num_out(K_sel_num)
        return sel_num_score
Esempio n. 11
0
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None):
        """Score aggregation operators for the question.

        With ``self.use_ca`` the attention over question tokens is keyed by
        the encoding of column ``gt_sel[b]`` of each example. Otherwise a
        column-independent attention is used.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs (used only
                with ``self.use_ca``).
            col_num: accepted for interface compatibility; unused.
            gt_sel: per-example selected column index (used with ``use_ca``).

        Returns:
            ``agg_score`` from ``self.agg_out`` on the pooled encoding.
        """
        max_x_len = max(x_len)

        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
        if self.use_ca:
            e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                       self.agg_col_name_enc)
            # Pick column gt_sel[b] of example b; index tensors live on e_col's device.
            aux_range = torch.arange(len(gt_sel), device=e_col.device)
            chosen_sel_idx = torch.as_tensor(gt_sel, dtype=torch.long,
                                             device=e_col.device)
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            # Squeeze only the trailing dim so B == 1 keeps its batch dim.
            att_val = torch.bmm(self.agg_att(h_enc),
                                chosen_e_col.unsqueeze(2)).squeeze(-1)
        else:
            att_val = self.agg_att(h_enc).squeeze(-1)

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)

        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        agg_score = self.agg_out(K_agg)
        return agg_score
Esempio n. 12
0
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_len=None,
                col_name_len=None,
                x_type_emb_var=None,
                gt_sel=None,
                sel_cond_score=None):
        """Score aggregators for up to four selected columns.

        Question embeddings are concatenated with ``x_type_emb_var`` along
        dim 2 before the LSTM. During training ``gt_sel`` supplies the
        selected columns. Otherwise they are derived from ``sel_cond_score``:
        the argmax of its select-number scores, followed by the top-scoring
        columns.

        Raises:
            ValueError: if both ``gt_sel`` and ``sel_cond_score`` are None.
                ValueError subclasses Exception, so existing handlers still
                catch it.

        Returns:
            ``agg_score`` with the trailing singleton dim (if any) removed.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)
        x_emb_concat = torch.cat((x_emb_var, x_type_emb_var), 2)

        if gt_sel is None:
            if sel_cond_score is None:
                raise ValueError(
                    """In the test mode, sel_num_score and sel_col_score
                                        should be passed in order to predict aggregation!"""
                )
            sel_num_score, _, sel_score, _ = sel_cond_score
            sel_nums = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
            sel_col_scores = sel_score.data.cpu().numpy()
            chosen_sel_col_gt = [
                list(np.argsort(-sel_col_scores[b])[:sel_nums[b]])
                for b in range(len(sel_nums))
            ]
        else:
            chosen_sel_col_gt = [list(one_gt_sel) for one_gt_sel in gt_sel]

        h_enc, _ = run_lstm(self.agg_lstm, x_emb_concat, x_len)
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)

        # Pad each example's chosen columns to the maximum (4) with column 0.
        sel_col_emb = torch.stack([
            torch.stack([e_col[b, x] for x in chosen_sel_col_gt[b]] +
                        [e_col[b, 0]] * (4 - len(chosen_sel_col_gt[b])))
            for b in range(B)
        ])

        # Squeeze only the trailing dim so B == 1 keeps its batch dim and
        # the [idx, :, num:] masking below still applies to a 3-D tensor.
        agg_att_val = torch.matmul(
            self.agg_att(h_enc).unsqueeze(1),
            sel_col_emb.unsqueeze(3)).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                agg_att_val[idx, :, num:] = -100
        agg_att = self.softmax(agg_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_agg = (h_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2)

        agg_score = self.agg_out(
            self.agg_out_K(K_agg) + self.col_out_col(sel_col_emb)).squeeze(-1)

        return agg_score
Esempio n. 13
0
    def forward(self, x_emb_var, x_len, col_inp_var,
                col_name_len, col_len, col_num, dropout_rate=0.):
        """Score every candidate column for the SELECT clause.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num: per-example number of real columns.
            dropout_rate: forwarded to the encoders. If > 0, a variational
                mask shared over batch and columns is also applied to the
                hidden dimension of the pooled question encoding.

        Returns:
            ``sel_score / self.T`` with one score per column. Padded columns
            are set to -100 after the temperature division.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len,
                                   col_len, self.sel_col_name_enc,
                                   dropout_rate=dropout_rate)
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len,
                            dropout_rate=dropout_rate)

        if self.use_ca:
            att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, :, num:] = -100
            att = self.softmax(att_val.view(-1, max_x_len)).view(
                B, -1, max_x_len)
            K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        else:
            # Squeeze only the trailing dim so B == 1 keeps its batch dim.
            att_val = self.sel_att(h_enc).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, num:] = -100
            att = self.softmax(att_val)
            K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
            K_sel_expand = K_sel.unsqueeze(1)

        if dropout_rate > 0.:
            # K_sel_expand: [batch_size, col_size, hid_size]
            keep_prob = 1. - dropout_rate
            K_sel_expand_mask = K_sel_expand.new_empty(
                (1, 1, K_sel_expand.size(-1))).bernoulli_(keep_prob).div_(keep_prob)
            # Out-of-place so autograd sees the mask. Assigning to ``.data``
            # dropped units in the forward pass only, and gradients bypassed
            # the mask entirely.
            K_sel_expand = K_sel_expand * K_sel_expand_mask

        sel_score = self.sel_out(
            self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1) / self.T
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
Esempio n. 14
0
    def forward(self, x_emb, x_len, col_input, col_token_num, col_len, hidden=None):
        """Score every candidate column for the SELECT clause.

        Args:
            x_emb: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_input, col_token_num: column-name inputs forwarded to
                ``column_encode``.
            col_len: per-example number of real columns (also used for
                masking the output).
            hidden: optional initial LSTM state forwarded to ``run_lstm``.
                Previously this argument was unconditionally reset to None
                and so was ignored. It is now honoured when given.

        Returns:
            ``select_score`` with one score per column; padded columns are -100.
        """
        max_x_len = max(x_len)

        emb_col, _ = column_encode(self.select_colname_enc, col_input,
                                   col_token_num, col_len)

        # ``is None`` instead of ``not hidden``: truth-testing a tensor state
        # is ambiguous, and the caller's state must not be discarded.
        if hidden is None:
            h_enc, _ = run_lstm(self.select_lstm, x_emb, x_len)
        else:
            h_enc, _ = run_lstm(self.select_lstm, x_emb, x_len, hidden)

        # Attention score per question token; padding gets -100.
        attn_value = self.select_att(h_enc).squeeze(2)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                attn_value[idx, num:] = -100

        attention = F.softmax(attn_value, 1)

        K_select = (h_enc * attention.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_select_expand = K_select.unsqueeze(1)
        select_score = self.select_out(
            self.select_out_K(K_select_expand) +
            self.select_out_col(emb_col)).squeeze(2)
        max_col_num = max(col_len)

        for idx, num in enumerate(col_len):
            if num < max_col_num:
                select_score[idx, num:] = -100

        return select_score
		
Esempio n. 15
0
    def forward(self, q, q_len, hidden):
        """Return aggregation scores for a batch of padded questions.

        The RNN output is pooled with a learned attention. Positions at or
        beyond a question's length get logit -100 before the softmax, so they
        receive ~zero weight. The pooled vector is scored by ``self.agg_out``.
        """
        longest = max(q_len)  # every question is padded up to this length

        rnn_out, hidden = run_lstm(self.rnn, q, q_len, hidden)
        scores = self.attn(rnn_out).squeeze(2)

        for row, length in enumerate(q_len):
            if length < longest:
                # tokens past the real question must not attract attention
                scores[row, length:] = -100

        weights = F.softmax(scores, dim=1).unsqueeze(2).expand_as(rnn_out)
        pooled = (rnn_out * weights).sum(1)
        return self.agg_out(pooled)
Esempio n. 16
0
    def op_forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
                   col_len, chosen_col_gt, dropout_rate=0.):
        """Score the comparison operator for each (padded) condition column.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            chosen_col_gt: per-example list of chosen condition column indices.
                Lists are padded to 4 entries with column 0.
            dropout_rate: forwarded to the encoders. If > 0, a variational
                mask over the hidden dimension is also applied to the pooled
                question encoding.

        Returns:
            ``cond_op_score / self.T3``.
        """
        B = len(x_len)
        max_x_len = max(x_len)
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
                                        col_len, self.cond_op_name_enc,
                                        dropout_rate=dropout_rate)
        # Pad the columns to maximum (4) by repeating column 0.
        col_emb = torch.stack([
            torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
                        [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            for b in range(B)
        ])

        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len,
                               dropout_rate=dropout_rate)
        if self.use_ca:
            op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
                                      col_emb.unsqueeze(3)).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, :, num:] = -100
            op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
            K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
        else:
            # Squeeze only the trailing dim so B == 1 keeps its batch dim.
            op_att_val = self.cond_op_att(h_op_enc).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, num:] = -100
            op_att = self.softmax(op_att_val)
            K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

        if dropout_rate > 0.:
            # K_cond_op: [batch_size, op_size, hid_size]
            keep_prob = 1. - dropout_rate
            K_cond_op_mask = K_cond_op.new_empty(
                (1, 1, K_cond_op.size(-1))).bernoulli_(keep_prob).div_(keep_prob)
            # Out-of-place so autograd sees the mask. Assigning to ``.data``
            # masked the forward pass only, and gradients bypassed the mask.
            K_cond_op = K_cond_op * K_cond_op_mask

        cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
                                         self.cond_op_out_col(col_emb)).squeeze(-1) / self.T3

        return cond_op_score
Esempio n. 17
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        """Predict the number of columns in the SELECT clause.

        An attention-pooled summary of the column encodings initialises the
        LSTM's hidden and cell states. The LSTM then encodes the question, and
        its attention-pooled output is scored by ``self.sel_num_out``.

        Note: the ``col_num`` argument is ignored; it is rebound to the column
        counts returned by ``col_name_encode``.

        Returns:
            ``sel_num_score`` from ``self.sel_num_out`` (e.g. [16, 4] in the
            original debug comments).
        """
        B = len(x_len)
        max_x_len = max(x_len)

        # Encoded columns; [batch_size, max_len, embedding_size] per the
        # original comment.
        e_num_col, col_num = col_name_encode(
            col_inp_var, col_name_len, col_len, self.sel_num_lstm)
        # One attention logit per column (e.g. [16, 19]).
        num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)
        max_col_num = max(col_num)  # hoisted: was recomputed every iteration
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                # Padded column slots get a very small logit.
                num_col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(num_col_att_val)
        # Column encodings weighted by their softmax score (e.g. [16, 100]).
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        # Initial LSTM states, (4, B, N_h // 2) (e.g. [4, 16, 50]).
        sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()
        sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()

        # Encode the question (e.g. [16, 49, 100]).
        h_num_enc, _ = run_lstm(self.sel_num_lstm,
                                x_emb_var,
                                x_len,
                                hidden=(sel_num_h1, sel_num_h2))

        # One attention logit per question token, [batch_size, max_len].
        num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -1000000
        num_att = self.softmax(num_att_val)

        # Attention-pooled question encoding (e.g. [16, 100]).
        K_sel_num = (h_num_enc *
                     num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        sel_num_score = self.sel_num_out(K_sel_num)
        return sel_num_score
Esempio n. 18
0
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None):
        """Score aggregation operators for the question.

        With ``self.use_ca`` the attention over question tokens is keyed by
        the encoding of column ``gt_sel[b]`` (the vector v). Otherwise a
        column-independent attention is used.

        Args:
            x_emb_var: question token embeddings (batch first).
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs (used only
                with ``self.use_ca``).
            col_num: accepted for interface compatibility; unused.
            gt_sel: per-example selected column index (used with ``use_ca``).

        Returns:
            ``agg_score`` from ``self.agg_out`` on E_Q|col.
        """
        max_x_len = max(x_len)
        # hidden states of the question LSTM
        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
        if self.use_ca:
            # compute E_col
            e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                       self.agg_col_name_enc)

            # Index tensors live on e_col's device.
            aux_range = torch.arange(len(gt_sel), device=e_col.device)
            chosen_sel_idx = torch.as_tensor(gt_sel, dtype=torch.long,
                                             device=e_col.device)
            # chosen_e_col is v
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            # Squeeze only the trailing dim so B == 1 keeps its batch dim.
            att_val = torch.bmm(self.agg_att(h_enc),
                                chosen_e_col.unsqueeze(2)).squeeze(-1)
        else:
            att_val = self.agg_att(h_enc).squeeze(-1)

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        # att is the attention weight
        att = self.softmax(att_val)

        # K_agg is E_Q|col
        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        agg_score = self.agg_out(K_agg)
        return agg_score
Esempio n. 19
0
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None):
        """Compute aggregator-item scores; intended to return a pair.

        The question encoding is attention-pooled, keyed by the encoding of
        column ``gt_sel[b]`` when ``self.use_ca``, and scored by
        ``self.agg_out``.

        NOTE(review): the return value references ``agg_num_score``, which is
        never computed in this method. Unless that name is provided elsewhere,
        every call raises NameError. The aggregator-count head needs to be
        wired in. TODO confirm the intended computation.
        """
        max_x_len = max(x_len)
        # hidden state for each token in the question
        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

        if self.use_ca:
            e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                       self.agg_col_name_enc)
            # Index tensors live on e_col's device.
            aux_range = torch.arange(len(gt_sel), device=e_col.device)
            chosen_sel_idx = torch.as_tensor(gt_sel, dtype=torch.long,
                                             device=e_col.device)
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            # Squeeze only the trailing dim so B == 1 keeps its batch dim.
            att_val = torch.bmm(self.agg_att(h_enc),
                                chosen_e_col.unsqueeze(2)).squeeze(-1)
        else:
            att_val = self.agg_att(h_enc).squeeze(-1)

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                # make sure the padded positions have softmax ~= 0
                att_val[idx, num:] = -100
        att = self.softmax(att_val)

        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        agg_item_score = self.agg_out(K_agg)
        # NOTE(review): agg_num_score is undefined in this block (see docstring).
        return (agg_num_score, agg_item_score)
Esempio n. 20
0
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None,
                gt_sel_num=None):
        """Score the aggregation operator for each of up to four selected columns.

        Every selected column attends over the question tokens. The attended
        question summary plus the column embedding is mapped to aggregation
        logits by ``self.agg_out``.

        Args:
            x_emb_var: question token embeddings; ``len(x_emb_var)`` is the batch size.
            x_len: number of real tokens in each question.
            col_inp_var, col_name_len, col_len: column inputs for ``col_name_encode``.
            col_num: unused; kept for interface compatibility.
            gt_sel: for each example, the indices of its selected columns.
                Lists shorter than 4 are padded with column 0.
            gt_sel_num: unused; kept for interface compatibility.

        Returns:
            ``agg_score`` of shape (B, 4, n_agg), e.g. [16, 4, 6].
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        # e_col: (B, n_cols, hid), e.g. [16, 19, 100]
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        # h_enc: (B, max_x_len, hid), e.g. [16, 49, 100]
        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

        # col_emb: (B, 4, hid) -- the selected columns of each example,
        # padded with column 0 up to four entries.
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack([e_col[b, x] for x in gt_sel[b]] +
                                      [e_col[b, 0]] * (4 - len(gt_sel[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        # att_val: (B, 4, max_x_len). squeeze(-1) drops only the trailing
        # singleton of the matmul, so a batch of one keeps its batch axis.
        att_val = torch.matmul(
            self.agg_att(h_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze(-1)

        # Mask padded question tokens (last axis) for every column. Indexing
        # att_val[idx, num:] would hit the column axis instead.
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view(B * 4, -1)).view(B, 4, -1)

        # K_agg: (B, 4, hid) -- one question summary per selected column.
        K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        # No trailing .squeeze(): a single-example batch stays 3-D.
        agg_score = self.agg_out(
            self.agg_out_K(K_agg) +
            self.col_out_col(col_emb))
        return agg_score
Esempio n. 21
0
    def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None, dropout_rate=0.):
        """Score the aggregation operator from an attention-pooled question.

        With ``self.use_ca`` the attention query is the encoding of column
        ``gt_sel[b]``; otherwise ``self.agg_att`` scores the tokens directly.
        Padded question positions get -100 before the softmax. The pooled
        vector optionally receives dropout, and the logits are divided by the
        temperature ``self.T``.

        Args:
            x_emb_var: question token embeddings.
            x_len: number of real tokens in each question.
            col_inp_var, col_name_len, col_len: column inputs for
                ``col_name_encode`` (only read when ``self.use_ca``).
            col_num: unused; kept for interface compatibility.
            gt_sel: one selected-column index per example (``self.use_ca`` only).
            dropout_rate: drop probability forwarded to the encoders and
                applied to the pooled question vector; 0 disables it.

        Returns:
            ``self.agg_out(K_agg) / self.T``.
        """
        max_x_len = max(x_len)

        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len, dropout_rate=dropout_rate)
        if self.use_ca:
            e_col, _ = col_name_encode(col_inp_var, col_name_len,
                    col_len, self.agg_col_name_enc, dropout_rate=dropout_rate)
            chosen_sel_idx = torch.LongTensor(gt_sel)
            aux_range = torch.LongTensor(range(len(gt_sel)))
            if x_emb_var.is_cuda:
                chosen_sel_idx = chosen_sel_idx.cuda()
                aux_range = aux_range.cuda()
            # Column gt_sel[b] of every example b is the attention query.
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            att_val = torch.bmm(self.agg_att(h_enc),
                    chosen_e_col.unsqueeze(2)).squeeze(-1)
        else:
            # squeeze(-1) (not squeeze()) keeps the batch axis when B == 1.
            att_val = self.agg_att(h_enc).squeeze(-1)

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)

        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        if dropout_rate > 0.:
            # K_agg: [batch_size, hid_size]. A single (1, hid_size) inverted-
            # dropout mask is shared by every example in the batch.
            keep_prob = 1. - dropout_rate
            K_agg_mask = torch.FloatTensor(K_agg.size()[-1]).view(1, -1)\
                .fill_(keep_prob).bernoulli().div_(keep_prob)
            if K_agg.is_cuda:
                K_agg_mask = K_agg_mask.cuda()
            # Out-of-place multiply so autograd sees the mask. Assigning to
            # `.data` would leave the gradients unmasked. Assumes torch >= 0.4
            # (Tensor/Variable merged).
            K_agg = K_agg * K_agg_mask

        agg_score = self.agg_out(K_agg) / self.T
        return agg_score
    def forward(self, x_emb_var, x_len, col_inp_var,
            col_name_len, col_len, col_num):
        """Score every table column as the SELECT column.

        Each column name attends over the question tokens, with padded tokens
        pushed to -100 before the softmax. The attended question summary and
        the column encoding are projected, summed, and mapped by
        ``self.sel_out`` to one score per column.

        Args:
            x_emb_var: embedding of each question.
            x_len: number of real tokens in each question.
            col_inp_var: embedding of each header.
            col_name_len: length of each header.
            col_len: number of headers in each table (array).
            col_num: number of headers in each table (list).

        Returns:
            ``sel_score`` (squeezed). Entries for columns at index >= ``col_num[b]``
            are set to -100.
        """
        batch_size = len(x_emb_var)
        longest_q = max(x_len)

        # col_vecs: [bs, col_num, hid]; q_states: [bs, seq_len, hid]
        col_vecs, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.sel_col_name_enc)
        q_states, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)

        # scores[b, c, t]: affinity between column c and question token t.
        scores = torch.bmm(col_vecs, self.sel_att(q_states).transpose(1, 2))
        for b, q_tokens in enumerate(x_len):
            if q_tokens >= longest_q:
                continue
            # Padded question positions must receive no attention from any column.
            scores[b, :, q_tokens:] = -100
        weights = self.softmax(scores.view((-1, longest_q)))
        weights = weights.view(batch_size, -1, longest_q)
        attended = (q_states.unsqueeze(1) * weights.unsqueeze(3)).sum(2)

        hidden = self.sel_out_K(attended) + self.sel_out_col(col_vecs)
        sel_score = self.sel_out(hidden).squeeze()
        widest = max(col_num)
        for b, n_cols in enumerate(col_num):
            if n_cols < widest:
                sel_score[b, n_cols:] = -100

        return sel_score
Esempio n. 23
0
    def forward_mult(self, x_emb_var, x_len, col_inp_var, col_name_len,
                     col_len, col_num):
        """Predict how many columns are selected and score every column.

        1. ``sel_num_score``: column names are attention-pooled into one
           vector per example. That vector is projected into the initial
           hidden/cell states of ``self.col_num_lstm``, which then reads the
           question. An attention-pooled question summary goes to
           ``self.sel_num_out``.
        2. ``sel_col_score``: a question summary plus the column-name encoding
           gives one score per column. With ``self.use_ca`` the summary is per
           column (column attention); otherwise it is shared. Padded columns
           are set to -100.

        Args:
            x_emb_var: question token embeddings.
            x_len: number of real tokens in each question.
            col_inp_var, col_name_len, col_len: column inputs for ``col_name_encode``.
            col_num: number of columns per table. NOTE(review): it is rebound
                below to the second value returned by ``col_name_encode``.

        Returns:
            ``(sel_num_score, sel_col_score)``.
        """
        B = len(x_len)
        max_x_len = max(x_len)

        # ---- Number of selected columns ----
        # NOTE(review): rebinds the `col_num` argument to col_name_encode's
        # second output; everything below uses that value.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len,
                                             self.sel_col_num_name_enc)
        max_col_num = max(col_num)
        # squeeze(-1) rather than squeeze(): keeps the batch axis when B == 1.
        num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                num_col_att_val[idx, num:] = -100
        num_col_att = self.softmax(num_col_att_val)
        # Attention-weighted sum of column encodings: one vector per example.
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)

        # Initial (h0, c0) for col_num_lstm, reshaped to (-1, B, N_h // 2) --
        # presumably (num_layers * num_directions, B, hidden) as nn.LSTM
        # expects; confirm against col_num_lstm. `//` because view() rejects
        # the float that `/` yields under Python 3.
        sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
            B, -1, self.N_h // 2).transpose(0, 1).contiguous()
        sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
            B, -1, self.N_h // 2).transpose(0, 1).contiguous()
        h_num_enc, _ = run_lstm(self.col_num_lstm,
                                x_emb_var,
                                x_len,
                                hidden=(sel_num_h1, sel_num_h2))
        num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -100
        num_att = self.softmax(num_att_val)

        K_sel_num = (h_num_enc *
                     num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        sel_num_score = self.sel_num_out(K_sel_num)

        # ---- Per-column scores ----
        e_sel_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                       self.sel_col_name_enc)
        h_col_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
        if self.use_ca:
            # Column attention: every column attends over the question tokens.
            # sel_att_val: (B, e_sel_col.size(1), max_x_len)
            sel_att_val = torch.bmm(e_sel_col,
                                    self.sel_att(h_col_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    sel_att_val[idx, :, num:] = -100
            sel_att = self.softmax(sel_att_val.view(
                (-1, max_x_len))).view(B, -1, max_x_len)
            # K_sel_col: (B, n_cols, hid) -- one question summary per column.
            K_sel_col = (h_col_enc.unsqueeze(1) * sel_att.unsqueeze(3)).sum(2)
        else:
            sel_att_val = self.sel_col_att(h_col_enc).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    sel_att_val[idx, num:] = -100
            sel_att = self.softmax(sel_att_val)
            # Weight by the normalised probabilities. Raw scores would let
            # the -100 padding entries leak into the sum.
            # K_sel_col: (B, 1, hid), broadcast over columns below.
            K_sel_col = (h_col_enc *
                         sel_att.unsqueeze(2)).sum(1).unsqueeze(1)
        sel_col_score = self.sel_col_out(
            self.sel_out_K(K_sel_col) + self.sel_out_col(e_sel_col)).squeeze(-1)
        for b, num in enumerate(col_num):
            if num < max_col_num:
                sel_col_score[b, num:] = -100
        sel_score = (sel_num_score, sel_col_score)
        return sel_score
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var,
                col_len,
                x_type_emb_var,
                gt_where,
                gt_cond,
                sel_cond_score=None,
                x_pos_emb_var=None):
        """Predict the operator and the value string of each WHERE condition.

        Up to four condition columns per example are handled. Shorter column
        lists are padded with column 0, so every example contributes exactly 4.

        Args:
            x_emb_var: question token embeddings.
            x_len: number of real tokens in each question.
            col_inp_var, col_len: column inputs encoded with ``self.cond_name_enc``.
            x_type_emb_var: per-token type embeddings (used when ``self.types``).
            gt_where: ground-truth condition strings, turned into decoder
                inputs by ``self.gen_gt_batch`` (teacher forcing). When
                ``None``, the string decoder runs greedily for 50 steps.
            gt_cond: ground-truth conditions; the first item of each is the
                column index. ``None`` at test time.
            sel_cond_score: ``(cond_num_score, _, cond_col_score)``; required
                when ``gt_cond`` is ``None``.
            x_pos_emb_var: per-token POS embeddings (used when ``self.POS``).

        Returns:
            ``(cond_op_score, cond_str_score)``. String scores at question
            positions beyond ``x_len[b]`` are set to -100.

        Raises:
            Exception: if both ``gt_cond`` and ``sel_cond_score`` are ``None``.
        """
        max_x_len = max(x_len)
        B = len(x_len)

        # Columns whose operator / string are predicted.
        if gt_cond is None:
            if sel_cond_score is None:
                raise Exception(
                    """In the test mode, cond_num_score and cond_col_score
                                should be passed in order to predict condition op and str!"""
                )
            # Keep the cond_nums[b] highest-scoring columns of each example.
            cond_num_score, _, cond_col_score = sel_cond_score
            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
            col_scores = cond_col_score.data.cpu().numpy()
            chosen_col_gt = [
                list(np.argsort(-col_scores[b])[:cond_nums[b]])
                for b in range(len(cond_nums))
            ]
        else:
            chosen_col_gt = [[x[0] for x in one_gt_cond]
                             for one_gt_cond in gt_cond]

        # Token features: word embeddings, optionally concatenated on the
        # last axis with type and/or POS embeddings (in that order).
        token_feats = [x_emb_var]
        if self.types:
            token_feats.append(x_type_emb_var)
        if self.POS:
            token_feats.append(x_pos_emb_var)
        x_emb_concat = torch.cat(token_feats, 2) if len(token_feats) > 1 else x_emb_var

        h_enc, _ = run_lstm(self.cond_opstr_lstm, x_emb_concat, x_len)
        e_col, _ = run_lstm(self.cond_name_enc, col_inp_var, col_len)

        # col_emb: (B, 4, hid) -- chosen column encodings padded with column 0.
        # Built once and shared by the operator and string predictors.
        col_emb = torch.stack([
            torch.stack([e_col[b, x] for x in chosen_col_gt[b]] +
                        [e_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            for b in range(B)
        ])

        # ---- Operator of each condition ----
        op_att_val = torch.matmul(
            self.cond_op_att(h_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, :, num:] = -100
        op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_cond_op = (h_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)

        cond_op_score = self.cond_op_out(
            self.cond_op_out_K(K_cond_op) +
            self.cond_op_out_col(col_emb)).squeeze()

        # ---- Value string of each condition ----
        # NOTE (original author): for a BERT front end, the hidden
        # representation of type embeddings should not be used.
        xt_str_enc = self.cond_str_x_type(x_type_emb_var) if self.types else None
        xpos_str_enc = self.cond_str_x_pos(x_pos_emb_var) if self.POS else None
        h_ext = h_enc.unsqueeze(1).unsqueeze(1)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)

        if gt_where is not None:
            # Teacher forcing: run the decoder over the whole ground-truth sequence.
            gt_tok_seq, _ = self.gen_gt_batch(gt_where)
            g_str_s_flat, _ = self.cond_str_decoder(
                gt_tok_seq.view(B * 4, -1, self.max_tok_num))
            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)

            cond_str_score = self._cond_str_logits(h_ext, g_ext, col_ext,
                                                   xt_str_enc, xpos_str_enc)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100
        else:
            # Greedy decoding for 50 steps. Each step feeds the arg-max question
            # token back to the decoder as a one-hot input.
            scores = []
            # TODO: for BERT, the start token should perhaps be [CLS].
            init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:, 0, 0] = 1  # <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = None
            for _ in range(50):
                # Explicit None check: the truth value of a hidden-state tensor
                # is ambiguous, so truthiness is not a safe test.
                if cur_h is not None:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
                else:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
                g_ext = g_str_s.unsqueeze(3)

                cur_cond_str_score = self._cond_str_logits(
                    h_ext, g_ext, col_ext, xt_str_enc, xpos_str_enc)
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_str_score[b, :, num:] = -100
                scores.append(cur_cond_str_score)

                _, ans_tok_var = cur_cond_str_score.view(B * 4,
                                                         max_x_len).max(1)
                ans_tok = ans_tok_var.data.cpu()
                # One-hot encode the chosen tokens as the next decoder input.
                data = torch.zeros(B * 4, self.max_tok_num).scatter_(
                    1, ans_tok.unsqueeze(1), 1)
                if self.gpu:
                    cur_inp = Variable(data.cuda())
                else:
                    cur_inp = Variable(data)
                cur_inp = cur_inp.unsqueeze(1)

            cond_str_score = torch.stack(scores, 2)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100  #[B, IDX, T, TOK_NUM]

        cond_op_str_score = (cond_op_score, cond_str_score)

        return cond_op_str_score

    def _cond_str_logits(self, h_ext, g_ext, col_ext, xt_str_enc, xpos_str_enc):
        """Score question tokens as the next condition-string token.

        Sums the projected question, decoder and column features, then the
        type features if ``self.types`` and the POS features if ``self.POS``,
        in that order. Passes the sum through ``self.cond_str_out`` and
        squeezes the result.
        """
        combined = (self.cond_str_out_h(h_ext) +
                    self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext))
        if self.types:
            combined = combined + self.cond_str_out_ht(
                xt_str_enc.unsqueeze(1).unsqueeze(1))
        if self.POS:
            combined = combined + self.cond_str_out_pos(
                xpos_str_enc.unsqueeze(1).unsqueeze(1))
        return self.cond_str_out(combined).squeeze()
Esempio n. 25
0
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                col_len, col_name_len, gt_col):
        """Predict the operator count and the operator for a target column.

        The encoding of column ``gt_col[b]`` attends over the question and
        over the SQL history (padded positions are set to -100 before the
        softmax). The pooled summaries plus the column embedding feed two
        heads. History terms are multiplied by ``int(self.use_hs)``.

        Returns:
            ``(op_num_score, op_score)`` -- (B, 2) and (B, 10) per the shape
            notes on the output layers.
        """
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        batch = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                     self.col_lstm)

        # Embedding of the target/predicted column of each example: (B, hid_dim).
        col_emb = torch.stack([col_enc[b, gt_col[b]] for b in range(batch)])
        col_query = col_emb.unsqueeze(1)  # (B, 1, hid_dim)

        def pooled(seq_enc, seq_lens, longest, proj):
            """Column->sequence attention; returns the weighted sum (B, hid_dim)."""
            logits = torch.bmm(col_query, proj(seq_enc).transpose(1, 2)).view(batch, -1)
            for b, n in enumerate(seq_lens):
                if n < longest:
                    logits[b, n:] = -100
            probs = self.softmax(logits)
            return (seq_enc * probs.unsqueeze(2)).sum(1)

        hs_gate = int(self.use_hs)

        # Operator count.
        q_num_ctx = pooled(q_enc, q_len, max_q_len, self.q_num_att)
        hs_num_ctx = pooled(hs_enc, hs_len, max_hs_len, self.hs_num_att)
        op_num_score = self.op_num_out(
            self.op_num_out_q(q_num_ctx) +
            hs_gate * self.op_num_out_hs(hs_num_ctx) +
            self.op_num_out_c(col_emb))

        # Operator.
        q_ctx = pooled(q_enc, q_len, max_q_len, self.q_att)
        hs_ctx = pooled(hs_enc, hs_len, max_hs_len, self.hs_att)
        op_score = self.op_out(
            self.op_out_q(q_ctx) +
            hs_gate * self.op_out_hs(hs_ctx) +
            self.op_out_c(col_emb))

        return (op_num_score, op_score)
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                col_len, col_name_len):
        """Predict how many columns to use (1-3) and score each column.

        Every column encoding attends separately over the question and over
        the SQL history. For the count, the attended summaries are summed over
        all columns. For the per-column scores they stay per column. History
        terms are multiplied by ``int(self.use_hs)``.

        Returns:
            ``(col_num_score, col_score)``. ``col_score`` is ``(B, -1)`` with
            padded columns set to -100. In ``col_num_score``, entries at index
            >= ``col_len[b]`` are set to -100 when ``col_len[b] < 5``.
        """
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                     self.col_lstm)

        def attend(seq_enc, seq_lens, longest, proj, mask_cols):
            """Column->sequence attention; returns (B, n_cols, hid_dim)."""
            logits = torch.bmm(col_enc, proj(seq_enc).transpose(1, 2))
            if mask_cols:
                for b, n in enumerate(col_len):
                    if n < max_col_len:
                        logits[b, n:, :] = -100
            for b, n in enumerate(seq_lens):
                if n < longest:
                    logits[b, :, n:] = -100
            probs = self.softmax(logits.view((-1, longest))).view(B, -1, longest)
            return (seq_enc.unsqueeze(1) * probs.unsqueeze(3)).sum(2)

        hs_gate = int(self.use_hs)

        # Number of columns: summaries collapsed over all columns -> (B, hid_dim).
        q_num_ctx = attend(q_enc, q_len, max_q_len, self.q_num_att, True).sum(1)
        hs_num_ctx = attend(hs_enc, hs_len, max_hs_len, self.hs_num_att, True).sum(1)
        col_num_score = self.col_num_out(
            self.col_num_out_q(q_num_ctx) +
            hs_gate * self.col_num_out_hs(hs_num_ctx))
        for b, n in enumerate(col_len):
            if n < 5:
                col_num_score[b, n:] = -100

        # Per-column scores: summaries kept per column -> (B, n_cols, hid_dim).
        q_ctx = attend(q_enc, q_len, max_q_len, self.q_att, False)
        hs_ctx = attend(hs_enc, hs_len, max_hs_len, self.hs_att, False)
        col_score = self.col_out(
            self.col_out_q(q_ctx) +
            hs_gate * self.col_out_hs(hs_ctx) +
            self.col_out_c(col_enc)).view(B, -1)
        for b, n in enumerate(col_len):
            if n < max_col_len:
                col_score[b, n:] = -100

        return (col_num_score, col_score)
Esempio n. 27
0
    def forward(self, q_emb_var, q_len, col_emb_var, col_len, x_type_emb_var):
        """Score column count, columns, aggregation and desc/asc/limit.

        Question tokens, concatenated with their type embeddings, and columns
        are encoded with LSTMs. Column-to-question attention (padded columns
        and tokens set to -100) provides the summaries for four heads:
        ``gby_num_*`` (count), ``col_out*`` (per column), ``agg_out*`` and
        ``dat_out*``.

        Returns:
            ``(num_score, col_score, agg_score, dat_score)``. In ``col_score``,
            padded columns are set to -100.
        """
        max_q_len = max(q_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_inputs = torch.cat((q_emb_var, x_type_emb_var), 2)
        q_enc, _ = run_lstm(self.q_lstm, q_inputs, q_len)
        col_enc, _ = run_lstm(self.col_lstm, col_emb_var, col_len)

        def attention(proj, mask_cols):
            """Column->question attention probabilities: (B, n_cols, max_q_len)."""
            logits = torch.bmm(col_enc, proj(q_enc).transpose(1, 2))
            if mask_cols:
                for b, n in enumerate(col_len):
                    if n < max_col_len:
                        logits[b, n:, :] = -100
            for b, n in enumerate(q_len):
                if n < max_q_len:
                    logits[b, :, n:] = -100
            return self.softmax(logits.view((-1, max_q_len))).view(B, -1, max_q_len)

        def per_column_summary(probs):
            """Attention-weighted question states, one per column: (B, n_cols, hid)."""
            return (q_enc.unsqueeze(1) * probs.unsqueeze(3)).sum(2)

        # Count head: summaries collapsed over columns.
        num_ctx = per_column_summary(attention(self.gby_num_h, True)).sum(1)
        num_score = self.gby_num_out(self.gby_num_l(num_ctx))

        # Per-column head.
        col_ctx = per_column_summary(attention(self.q_att, False))
        col_score = self.col_out(
            self.col_out_q(col_ctx) + self.col_out_c(col_enc)).squeeze()
        for b, n in enumerate(col_len):
            if n < max_col_len:
                col_score[b, n:] = -100

        # Aggregation head.
        agg_ctx = per_column_summary(attention(self.agg_att, True)).sum(1)
        agg_score = self.agg_out(self.agg_out_q(agg_ctx))

        # Desc / asc / limit head.
        dat_ctx = per_column_summary(attention(self.dat_att, True)).sum(1)
        dat_score = self.dat_out(self.dat_out_q(dat_ctx))

        return (num_score, col_score, agg_score, dat_score)
Esempio n. 28
0
    def forward(self, q_emb_var, q_len, col_emb_var, col_len, col_num,
                col_name_len, gt_sel):
        max_q_len = max(q_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                     self.col_lstm)
        #col_enc, _ = run_lstm(self.col_lstm, col_emb_var, col_len)

        # Predict column number: 1-3
        # att_val_qc_num: (B, max_col_len, max_q_len)
        att_val_qc_num = torch.bmm(col_enc,
                                   self.q_num_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_val_qc_num[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc_num[idx, :, num:] = -100
        att_prob_qc_num = self.softmax(att_val_qc_num.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted_num: (B, hid_dim)
        q_weighted_num = (q_enc.unsqueeze(1) *
                          att_prob_qc_num.unsqueeze(3)).sum(2).sum(1)
        # self.col_num_out: (B, 4)
        col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num))

        # Predict columns.
        att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, :, num:] = -100
        att_prob_qc = self.softmax(att_val_qc.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted: (B, max_col_len, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
        # Compute prediction scores
        # self.col_out.squeeze(): (B, max_col_len)
        col_score = self.col_out(
            self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze()
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                col_score[idx, num:] = -100

        # get select columns for agg prediction
        chosen_sel_gt = []
        if gt_sel is None:
            sel_nums = [
                x + 1 for x in list(
                    np.argmax(col_num_score.data.cpu().numpy(), axis=1))
            ]
            sel_col_scores = col_score.data.cpu().numpy()
            chosen_sel_gt = [
                list(np.argsort(-sel_col_scores[b])[:sel_nums[b]])
                for b in range(len(sel_nums))
            ]
        else:
            for x in gt_sel:
                curr = x[0]
                curr_sel = [curr]
                for col in x:
                    if col != curr:
                        curr_sel.append(col)
                chosen_sel_gt.append(curr_sel)

        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack(
                [col_enc[b, x] for x in chosen_sel_gt[b]] + [col_enc[b, 0]] *
                (5 - len(chosen_sel_gt[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)  # (B, 4, hd)

        # Predict aggregation
        # q_enc.unsqueeze(1): (B, 1, max_x_len, hd)
        # col_emb.unsqueeze(3): (B, 4, hd, 1)
        # agg_num_att_val.squeeze: (B, 4, max_x_len)
        agg_num_att_val = torch.matmul(
            self.agg_num_att(q_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze()
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                agg_num_att_val[idx, :, num:] = -100
        agg_num_att = self.softmax(agg_num_att_val.view(-1, max_q_len)).view(
            B, -1, max_q_len)
        q_weighted_agg_num = (q_enc.unsqueeze(1) *
                              agg_num_att.unsqueeze(3)).sum(2)
        # (B, 4, 4)
        agg_num_score = self.agg_num_out(
            self.agg_num_out_q(q_weighted_agg_num) +
            self.agg_num_out_c(col_emb)).squeeze()

        agg_att_val = torch.matmul(
            self.agg_att(q_enc).unsqueeze(1), col_emb.unsqueeze(3)).squeeze()
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                agg_att_val[idx, :, num:] = -100
        agg_att = self.softmax(agg_att_val.view(-1, max_q_len)).view(
            B, -1, max_q_len)
        q_weighted_agg = (q_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2)

        agg_score = self.agg_out(
            self.agg_out_q(q_weighted_agg) +
            self.agg_out_c(col_emb)).squeeze()

        score = (col_num_score, col_score, agg_num_score, agg_score)

        return score
Esempio n. 29
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num, gt_where, gt_cond, reinforce):
        """Score WHERE-clause tokens, one question position per decoder step.

        The question is encoded with ``self.cond_lstm``. ``self.cond_decoder``
        starts from a state built from the encoder's final (h, c). At every
        step ``self.cond_out`` scores each question position. Positions past
        an example's length are set to -100.

        Args:
            x_emb_var: embedded question batch passed to ``run_lstm``.
            x_len: per-example question lengths; ``max(x_len)`` is the padded
                length.
            col_inp_var, col_name_len, col_len, col_num, gt_cond: accepted for
                interface compatibility; not used here.
            gt_where: ground-truth WHERE sequences. When not None, the decoder
                is teacher-forced on ``self.gen_gt_batch(gt_where,
                gen_inp=True)``. When None, it decodes step by step (at most
                100 steps, or until every example emits token index 1).
            reinforce: when true and decoding freely, the next position is
                sampled from ``self.softmax`` instead of taken by argmax, and
                the samples are returned too.

        Returns:
            ``cond_score`` with shape (B, T, max_x_len), or
            ``(cond_score, choices)`` when ``reinforce`` is true. ``choices``
            holds one (B, 1) sample per decoding step. It is empty when
            teacher forcing.
        """
        max_x_len = max(x_len)
        B = len(x_len)
        # Bound up front so the reinforce return path never hits an unbound
        # name when teacher forcing is used.
        choices = []

        h_enc, hidden = run_lstm(self.cond_lstm, x_emb_var, x_len)
        # Build the decoder's initial (h, c) by concatenating hid[:2] and
        # hid[2:] along the feature axis.
        # NOTE(review): assumes a 2-layer bidirectional encoder laid out as
        # (layers*directions, B, H); this pairs layer-0 with layer-1 states,
        # not forward with backward -- confirm this is intended.
        decoder_hidden = tuple(torch.cat((hid[:2], hid[2:]), dim=2)
                for hid in hidden)
        h_enc_expand = h_enc.unsqueeze(1)

        if gt_where is not None:
            # Teacher forcing: run the decoder over the whole gold sequence.
            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where, gen_inp=True)
            g_s, _ = run_lstm(self.cond_decoder,
                    gt_tok_seq, gt_tok_len, decoder_hidden)
            g_s_expand = g_s.unsqueeze(2)
            # Drop only the trailing singleton so B == 1 keeps its batch axis.
            cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
                    self.cond_out_g(g_s_expand)).squeeze(3)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    cond_score[idx, :, num:] = -100
        else:
            scores = []
            done_set = set()

            t = 0
            init_inp = np.zeros((B, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:,0,7] = 1  #Set the <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = decoder_hidden
            # Stop when every example has emitted the <END> index, or after
            # 100 steps.
            while len(done_set) < B and t < 100:
                g_s, cur_h = self.cond_decoder(cur_inp, cur_h)
                g_s_expand = g_s.unsqueeze(2)

                # (B, 1, max_x_len, 1) -> (B, max_x_len). Explicit dims keep
                # the batch axis even when B == 1.
                cur_cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
                        self.cond_out_g(g_s_expand)).squeeze(3).squeeze(1)
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_score[b, num:] = -100
                scores.append(cur_cond_score)

                if not reinforce:
                    _, ans_tok_var = cur_cond_score.view(B, max_x_len).max(1)
                    ans_tok_var = ans_tok_var.unsqueeze(1)
                else:
                    # Draw one sample per row; num_samples is mandatory in
                    # current PyTorch.
                    ans_tok_var = self.softmax(cur_cond_score).multinomial(1)
                    choices.append(ans_tok_var)
                ans_tok = ans_tok_var.data.cpu()

                # Feed the chosen index back as a one-hot input.
                one_hot = torch.zeros(B, self.max_tok_num).scatter_(
                        1, ans_tok, 1)
                if self.gpu:
                    one_hot = one_hot.cuda()
                cur_inp = Variable(one_hot).unsqueeze(1)

                # view(-1) rather than squeeze() so a batch of one still
                # yields an iterable 1-D tensor.
                for idx, tok in enumerate(ans_tok.view(-1)):
                    if tok == 1:  #Find the <END> token
                        done_set.add(idx)
                t += 1

            cond_score = torch.stack(scores, 1)

        if reinforce:
            return cond_score, choices
        return cond_score
Esempio n. 30
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
            col_len, col_num, gt_where, gt_cond, reinforce):
        """Predict the WHERE clause: condition count, columns, operators, values.

        Four predictions are made in sequence. Each uses its own column-name
        encoder (``self.cond_*_name_enc``) and question LSTM
        (``self.cond_*_lstm``):

        1. the number of conditions, from a question LSTM whose initial state
           is derived from an attention summary over column encodings;
        2. a score for every column;
        3. operator scores for each of 4 column slots (chosen columns, padded
           with column 0);
        4. value-token scores for the same 4 slots. These are teacher-forced
           on ``gt_where`` when given, otherwise decoded greedily for 50 steps.

        Question positions past each example's length, and column positions
        past each example's column count, are set to -100 before any
        softmax or argmax.

        Args:
            x_emb_var: embedded question batch passed to ``run_lstm``.
            x_len: per-example question lengths.
            col_inp_var, col_name_len, col_len: column-name inputs passed to
                ``col_name_encode``.
            col_num: ignored; replaced by the column counts returned from
                ``col_name_encode``.
            gt_where: ground-truth value-token sequences, or None to decode.
            gt_cond: ground-truth conditions. ``x[0]`` of each condition is
                used as its column index. When None, the top-scoring columns
                are taken, as many as the predicted count.
            reinforce: must be false.

        Returns:
            ``(cond_num_score, cond_col_score, cond_op_score, cond_str_score)``.

        Raises:
            NotImplementedError: if ``reinforce`` is true.
        """
        max_x_len = max(x_len)
        B = len(x_len)
        if reinforce:
            raise NotImplementedError('Our model doesn\'t have RL')

        def mask_tail(scores, lengths):
            # In place: scores[b, ..., n:] = -100 for every example b whose
            # length n is below the batch maximum. Returns scores.
            max_len = max(lengths)
            for b, n in enumerate(lengths):
                if n < max_len:
                    scores[b, ..., n:] = -100
            return scores

        def pad_cols(e_col, chosen):
            # Stack the chosen columns' encodings per example, padded with
            # column 0 up to 4 slots.
            return torch.stack([
                torch.stack([e_col[b, x] for x in chosen[b]] +
                            [e_col[b, 0]] * (4 - len(chosen[b])))
                for b in range(B)])

        # --- Number of conditions ---
        # The column counts returned here replace the col_num argument.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                col_len, self.cond_num_name_enc)
        num_col_att_val = mask_tail(
                self.cond_num_col_att(e_num_col).squeeze(-1), col_num)
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        # Reshaped to (4, B, N_h // 2) and used as the LSTM's initial (h, c).
        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
                B, 4, self.N_h//2).transpose(0, 1).contiguous()
        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
                B, 4, self.N_h//2).transpose(0, 1).contiguous()
        h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                hidden=(cond_num_h1, cond_num_h2))
        num_att_val = mask_tail(
                self.cond_num_att(h_num_enc).squeeze(-1), x_len)
        num_att = self.softmax(num_att_val)
        K_cond_num = (h_num_enc * num_att.unsqueeze(2)).sum(1)
        cond_num_score = self.cond_num_out(K_cond_num)

        # --- Columns of conditions ---
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                self.cond_col_name_enc)
        h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
        if self.use_ca:
            # Column attention: one distribution over question positions per
            # column.
            col_att_val = mask_tail(torch.bmm(e_cond_col,
                    self.cond_col_att(h_col_enc).transpose(1, 2)), x_len)
            col_att = self.softmax(col_att_val.view(
                    (-1, max_x_len))).view(B, -1, max_x_len)
            K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
        else:
            col_att_val = mask_tail(
                    self.cond_col_att(h_col_enc).squeeze(-1), x_len)
            col_att = self.softmax(col_att_val)
            # Weight by the softmax probabilities. The raw masked scores
            # would give padded positions a weight of -100.
            K_cond_col = (h_col_enc * col_att.unsqueeze(2)).sum(1).unsqueeze(1)

        cond_col_score = mask_tail(self.cond_col_out(
                self.cond_col_out_K(K_cond_col) +
                self.cond_col_out_col(e_cond_col)).squeeze(-1), col_num)

        # --- Operators of conditions ---
        if gt_cond is None:
            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
            col_scores = cond_col_score.data.cpu().numpy()
            chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
                    for b in range(len(cond_nums))]
        else:
            chosen_col_gt = [[x[0] for x in one_gt_cond]
                    for one_gt_cond in gt_cond]

        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
                col_len, self.cond_op_name_enc)
        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
        col_emb = pad_cols(e_cond_col, chosen_col_gt)

        if self.use_ca:
            op_att_val = mask_tail(torch.matmul(
                    self.cond_op_att(h_op_enc).unsqueeze(1),
                    col_emb.unsqueeze(3)).squeeze(-1), x_len)
            op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1)
            K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
        else:
            op_att_val = mask_tail(
                    self.cond_op_att(h_op_enc).squeeze(-1), x_len)
            op_att = self.softmax(op_att_val)
            K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

        # NOTE(review): bare squeeze() also drops the batch axis when B == 1;
        # the output width of cond_op_out is not visible here, so it is kept.
        cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
                self.cond_op_out_col(col_emb)).squeeze()

        # --- Value strings of conditions ---
        h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
                col_len, self.cond_str_name_enc)
        col_emb = pad_cols(e_cond_col, chosen_col_gt)
        h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)

        if gt_where is not None:
            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
            g_str_s_flat, _ = self.cond_str_decoder(
                    gt_tok_seq.view(B*4, -1, self.max_tok_num))
            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)
            cond_str_score = mask_tail(self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze(-1), x_len)
        else:
            scores = []
            init_inp = np.zeros((B*4, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:,0,0] = 1  #Set the <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = None
            for _ in range(50):  # fixed number of greedy decoding steps
                if cur_h is not None:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
                else:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                g_ext = g_str_s_flat.view(B, 4, 1, self.N_h).unsqueeze(3)

                # (B, 4, 1, max_x_len, 1) -> (B, 4, max_x_len). Explicit dims
                # keep the batch axis even when B == 1.
                cur_cond_str_score = mask_tail(self.cond_str_out(
                        self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext)
                        + self.cond_str_out_col(col_ext)).squeeze(4).squeeze(2),
                        x_len)
                scores.append(cur_cond_str_score)

                _, ans_tok_var = cur_cond_str_score.view(B*4, max_x_len).max(1)
                ans_tok = ans_tok_var.data.cpu()
                # Feed the argmax back as a one-hot input.
                data = torch.zeros(B*4, self.max_tok_num).scatter_(
                        1, ans_tok.unsqueeze(1), 1)
                if self.gpu:
                    data = data.cuda()
                cur_inp = Variable(data).unsqueeze(1)

            #[B, IDX, T, TOK_NUM]
            cond_str_score = mask_tail(torch.stack(scores, 2), x_len)

        cond_score = (cond_num_score,
                cond_col_score, cond_op_score, cond_str_score)

        return cond_score