Example #1
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_len=None,
                col_name_len=None,
                x_type_emb_var=None,
                gt_sel=None,
                sel_cond_score=None):
        B = len(x_emb_var)
        max_x_len = max(x_len)
        x_emb_concat = torch.cat((x_emb_var, x_type_emb_var), 2)

        chosen_sel_col_gt = []
        if gt_sel is None:
            if sel_cond_score is None:
                raise Exception(
                    "In test mode, sel_cond_score must be passed "
                    "in order to predict aggregation!")
            sel_num_score, _, sel_score, _ = sel_cond_score
            sel_nums = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
            sel_col_scores = sel_score.data.cpu().numpy()
            chosen_sel_col_gt = [
                list(np.argsort(-sel_col_scores[b])[:sel_nums[b]])
                for b in range(len(sel_nums))
            ]
        else:
            chosen_sel_col_gt = [[x for x in one_gt_sel]
                                 for one_gt_sel in gt_sel]

        h_enc, _ = run_lstm(self.agg_lstm, x_emb_concat, x_len)
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        #e_col, _ = run_lstm(self.agg_col_name_enc, col_inp_var, col_len)

        sel_col_emb = []
        for b in range(B):
            cur_sel_col_emb = torch.stack(
                [e_col[b, x] for x in chosen_sel_col_gt[b]] + [e_col[b, 0]] *
                (4 - len(chosen_sel_col_gt[b]))
            )  # Pad the columns to maximum (4)
            sel_col_emb.append(cur_sel_col_emb)
        sel_col_emb = torch.stack(sel_col_emb)

        agg_att_val = torch.matmul(
            self.agg_att(h_enc).unsqueeze(1),
            sel_col_emb.unsqueeze(3)).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                agg_att_val[idx, :, num:] = -100
        agg_att = self.softmax(agg_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_agg = (h_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2)

        agg_score = self.agg_out(
            self.agg_out_K(K_agg) + self.col_out_col(sel_col_emb)).squeeze()

        return agg_score
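
The recurring trick in this snippet is gathering the chosen column encodings and padding with column 0 up to a fixed number of slots (4 here). A minimal sketch of that step, with toy shapes and `chosen` indices that are stand-ins rather than values from the original model:

import torch

B, n_cols, hid, max_slots = 2, 6, 8, 4
e_col = torch.randn(B, n_cols, hid)       # encoded column names
chosen = [[3, 1], [0, 2, 5]]              # selected column indices per example

# Gather the chosen columns, pad with column 0 up to max_slots, then batch.
sel_col_emb = torch.stack([
    torch.stack([e_col[b, x] for x in chosen[b]]
                + [e_col[b, 0]] * (max_slots - len(chosen[b])))
    for b in range(B)
])
print(sel_col_emb.shape)                  # torch.Size([2, 4, 8])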
Example #2
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb)

        # Predict agg number
        att_val_qc_num = torch.bmm(col_emb.unsqueeze(1), self.q_num_att(q_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc_num[idx, num:] = -100
        att_prob_qc_num = self.softmax(att_val_qc_num)
        q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc_num = torch.bmm(col_emb.unsqueeze(1), self.hs_num_att(hs_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc_num[idx, num:] = -100
        att_prob_hc_num = self.softmax(att_val_hc_num)
        hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1)
        # agg_num_score: (B, 4)
        agg_num_score = self.agg_num_out(self.agg_num_out_q(q_weighted_num) +
                                         int(self.use_hs) * self.agg_num_out_hs(hs_weighted_num) +
                                         self.agg_num_out_c(col_emb)) / self.T1

        # Predict aggregators
        att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, num:] = -100
        att_prob_qc = self.softmax(att_val_qc)
        q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc[idx, num:] = -100
        att_prob_hc = self.softmax(att_val_hc)
        hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)
        # agg_score: (B, 5)
        agg_score = self.agg_out(self.agg_out_q(q_weighted) + int(self.use_hs) * self.agg_out_hs(hs_weighted) +
                                 self.agg_out_c(col_emb)) / self.T2

        score = (agg_num_score, agg_score)

        return score
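
The core pattern here is column-conditioned attention over the question: a bmm between one column embedding and the projected question encodings, masking of padded tokens, then a weighted sum. A minimal sketch with toy shapes; `q_att` is a freshly initialized Linear standing in for the model's projection:

import torch
import torch.nn as nn

B, L, hid = 2, 5, 8
q_enc = torch.randn(B, L, hid)            # question token encodings
col_emb = torch.randn(B, hid)             # one chosen column per example
q_len = [5, 3]
q_att = nn.Linear(hid, hid)

# (B, 1, hid) x (B, hid, L) -> (B, L): one score per question token
att_val = torch.bmm(col_emb.unsqueeze(1), q_att(q_enc).transpose(1, 2)).view(B, -1)
for idx, num in enumerate(q_len):         # padded tokens get ~0 softmax mass
    if num < L:
        att_val[idx, num:] = -100
att_prob = torch.softmax(att_val, dim=1)
q_weighted = (q_enc * att_prob.unsqueeze(2)).sum(1)   # (B, hid)
print(q_weighted.shape)                   # torch.Size([2, 8])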
Example #3
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        # compute the E_col
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.sel_col_name_enc)

        if self.use_ca:
            # compute the hidden states output of lstm corresponding to question
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
            # multiplicative attention (unused here: the value is immediately
            # overwritten by the additive attention below)
            # att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))

            h_ext = h_enc.unsqueeze(1).unsqueeze(1)
            col_ext = e_col.unsqueeze(2).unsqueeze(2)
            # additive attention
            att_val = self.sel_att_out(
                self.sel_W1(h_ext) + self.sel_W2(col_ext)).squeeze()

            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, :, num:] = -100
            # L-dimension attention weight
            att = self.softmax(att_val.view(
                (-1, max_x_len))).view(B, -1, max_x_len)
            # compute E_Q|col
            K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        else:
            # normal method without column attention
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
            att_val = self.sel_att(h_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, num:] = -100
            att = self.softmax(att_val)
            K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
            K_sel_expand = K_sel.unsqueeze(1)

        # compute P_selcol(i|Q)
        sel_score = self.sel_out( self.sel_out_K(K_sel_expand) + \
                self.sel_out_col(e_col) ).squeeze()
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
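
The additive-attention branch scores every (column, token) pair by broadcasting the two encodings against each other. A minimal sketch with toy shapes; W1/W2/out are random stand-ins for sel_W1/sel_W2/sel_att_out, and, like the snippet, no nonlinearity is applied between the sum and the output head. The double unsqueeze in the original adds a singleton dim that squeeze() drops again; one unsqueeze per side is equivalent:

import torch
import torch.nn as nn

B, C, L, hid = 2, 3, 5, 8
h_enc = torch.randn(B, L, hid)            # question encodings
e_col = torch.randn(B, C, hid)            # column-name encodings
W1, W2, out = nn.Linear(hid, hid), nn.Linear(hid, hid), nn.Linear(hid, 1)

# (B, 1, L, hid) + (B, C, 1, hid) broadcasts to (B, C, L, hid)
att_val = out(W1(h_enc.unsqueeze(1)) + W2(e_col.unsqueeze(2))).squeeze(-1)
print(att_val.shape)                      # torch.Size([2, 3, 5])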
Example #4
    def forward(self, x_emb_var, x_len, col_inp_var,
            col_name_len, col_len, col_num, dropout_rate=0.):
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len,
                col_len, self.sel_col_name_enc, dropout_rate=dropout_rate)

        if self.use_ca:
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len, dropout_rate=dropout_rate)
            att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, :, num:] = -100
            att = self.softmax(att_val.view((-1, max_x_len))).view(
                    B, -1, max_x_len)
            K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        else:
            h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len, dropout_rate=dropout_rate)
            att_val = self.sel_att(h_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    att_val[idx, num:] = -100
            att = self.softmax(att_val)
            K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
            K_sel_expand = K_sel.unsqueeze(1)

        if dropout_rate > 0.:
            # K_sel_expand: [batch_size, col_size, hid_size]; apply an
            # inverted-dropout mask shared across the batch and column dims
            K_sel_expand_mask = torch.FloatTensor(K_sel_expand.size()[-1]).view(1, 1, -1)\
                .fill_(1. - dropout_rate).bernoulli().div_(1. - dropout_rate)
            if K_sel_expand.is_cuda:
                K_sel_expand_mask = K_sel_expand_mask.cuda()
            K_sel_expand.data = K_sel_expand.data * K_sel_expand_mask

        sel_score = self.sel_out( self.sel_out_K(K_sel_expand) + \
                self.sel_out_col(e_col) ).squeeze() / self.T
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
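
This variant applies inverted dropout with one Bernoulli mask over the hidden dimension, shared across the batch and column dimensions and rescaled by 1/(1-p) so expected activations are unchanged. A minimal sketch with toy shapes; `shared_dropout` is a hypothetical helper name:

import torch

def shared_dropout(x, p):
    if p <= 0.:
        return x
    # One mask over the last (hidden) dim, broadcast over batch and columns.
    mask = x.new_empty(1, 1, x.size(-1)).bernoulli_(1. - p).div_(1. - p)
    return x * mask

x = torch.randn(2, 3, 8)                  # [batch, cols, hid]
print(shared_dropout(x, 0.5).shape)       # torch.Size([2, 3, 8])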
Example #5
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None):
        B = len(x_emb_var)
        max_x_len = max(x_len)
        h_enc, _ = run_lstm(
            self.agg_lstm, x_emb_var,
            x_len)  # this is the hidden state for each token in the question

        if self.use_ca:
            e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                       self.agg_col_name_enc)
            # The rest acts as a decoder step: attend over the question
            # conditioned on the chosen column.
            chosen_sel_idx = torch.LongTensor(gt_sel)
            aux_range = torch.LongTensor(range(len(gt_sel)))
            if x_emb_var.is_cuda:
                chosen_sel_idx = chosen_sel_idx.cuda()
                aux_range = aux_range.cuda()
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            att_val = torch.bmm(self.agg_att(h_enc),
                                chosen_e_col.unsqueeze(2)).squeeze()
        else:
            att_val = self.agg_att(h_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                # ensure the padded positions get softmax weight ~= 0
                att_val[idx, num:] = -100
        att = self.softmax(att_val)

        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        agg_item_score = self.agg_out(K_agg)
        # agg_num_score is never computed in this method; return only the
        # aggregator scores to avoid a NameError.
        return agg_item_score
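
The column-gather step above builds LongTensor indices and picks one column encoding per batch element. A minimal sketch of the same advanced indexing with toy shapes:

import torch

B, C, hid = 3, 5, 8
e_col = torch.randn(B, C, hid)
gt_sel = [2, 0, 4]                        # ground-truth column per example
# Pair each batch index with its chosen column index.
chosen_e_col = e_col[torch.arange(B), torch.tensor(gt_sel)]
print(chosen_e_col.shape)                 # torch.Size([3, 8])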
Example #6
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None):
        B = len(x_emb_var)
        max_x_len = max(x_len)
        # compute the hidden states output of lstm corresponding to question
        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
        if self.use_ca:
            #  compute the E_col
            e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                       self.agg_col_name_enc)

            chosen_sel_idx = torch.LongTensor(gt_sel)
            aux_range = torch.LongTensor(range(len(gt_sel)))
            if x_emb_var.is_cuda:
                chosen_sel_idx = chosen_sel_idx.cuda()
                aux_range = aux_range.cuda()
            # chosen_e_col is v
            chosen_e_col = e_col[aux_range, chosen_sel_idx]
            att_val = torch.bmm(self.agg_att(h_enc),
                                chosen_e_col.unsqueeze(2)).squeeze()
        else:
            att_val = self.agg_att(h_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        # att is attention weight
        att = self.softmax(att_val)

        # K_agg is E_Q|col
        K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        agg_score = self.agg_out(K_agg)
        return agg_score
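
The -100 fill before softmax recurs in every snippet here; it pushes padded positions to near-zero probability. A minimal helper form of it, with `masked_softmax` a hypothetical name and toy inputs:

import torch

def masked_softmax(att_val, lens):
    for idx, num in enumerate(lens):
        att_val[idx, num:] = -100         # padded positions -> softmax ~= 0
    return torch.softmax(att_val, dim=-1)

att = torch.randn(2, 5)
print(masked_softmax(att, [5, 3]))        # row 2 has ~0 mass on positions 3-4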
Example #7
    def forward(self, x_emb_var, x_len, col_inp_var,
            col_name_len, col_len, col_num):
        '''
        Predict the SELECT column, conditioned on the number of selections.
        Inputs:
            x_emb_var: embedding of each question
            x_len: length of each question
            col_inp_var: embedding of each header
            col_name_len: length of each header
            col_len: number of headers in each table (array)
            col_num: number of headers in each table (list)
        '''
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.sel_col_name_enc) # [bs, col_num, hid]
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len) # [bs, seq_len, hid]; question encodings used as attention targets
        # e_col: [batch_size(16), max_num_of_col_in_train_tab, hidden_size(100)]
        # h_enc: [batch_size(16), max_len_of_question, hidden_size(100)]
        # att_val: [bs[16], max_num_of_col_in_train_tab, max_len_of_question]
        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2)) # [bs, col_num, seq_len]
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                # Mask padded question positions so they contribute ~0
                # attention weight; only real tokens matter to the columns.
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

        sel_score = self.sel_out( self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col) ).squeeze()
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
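
The scoring head combines each per-column question summary with the column encoding itself. A minimal sketch with toy shapes; the attention is a random stand-in, and the heads approximate sel_out_K / sel_out_col / sel_out (assumed here to be a small Tanh + Linear stack):

import torch
import torch.nn as nn

B, C, L, hid = 2, 4, 6, 8
h_enc, e_col = torch.randn(B, L, hid), torch.randn(B, C, hid)
att = torch.softmax(torch.randn(B, C, L), dim=2)               # stand-in attention
K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)  # (B, C, hid)

sel_out_K, sel_out_col = nn.Linear(hid, hid), nn.Linear(hid, hid)
sel_out = nn.Sequential(nn.Tanh(), nn.Linear(hid, 1))
sel_score = sel_out(sel_out_K(K_sel_expand) + sel_out_col(e_col)).squeeze(-1)
print(sel_score.shape)                    # torch.Size([2, 4])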
Example #8
    def forward(self, q_emb_var, q_len, col_emb_var, col_len, col_num,
                col_name_len, gt_cond):
        max_q_len = max(q_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                     self.col_lstm)

        # Predict column number: 0-4
        # att_val_qc_num: (B, max_col_len, max_q_len)
        att_val_qc_num = torch.bmm(col_enc,
                                   self.q_num_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_val_qc_num[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc_num[idx, :, num:] = -100
        att_prob_qc_num = self.softmax(att_val_qc_num.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted_num: (B, hid_dim)
        q_weighted_num = (q_enc.unsqueeze(1) *
                          att_prob_qc_num.unsqueeze(3)).sum(2).sum(1)
        # col_num_score: (B, 4)
        col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num))

        # Predict columns.
        att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, :, num:] = -100
        att_prob_qc = self.softmax(att_val_qc.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted: (B, max_col_len, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
        # Compute prediction scores
        # col_score: (B, max_col_len)
        col_score = self.col_out(
            self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze()
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                col_score[idx, num:] = -100
        # get select columns for op prediction
        chosen_col_gt = []
        if gt_cond is None:
            cond_nums = np.argmax(col_num_score.data.cpu().numpy(), axis=1)
            col_scores = col_score.data.cpu().numpy()
            chosen_col_gt = [
                list(np.argsort(-col_scores[b])[:cond_nums[b]])
                for b in range(len(cond_nums))
            ]
        else:
            chosen_col_gt = [[x[0] for x in one_gt_cond]
                             for one_gt_cond in gt_cond]

        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack(
                [col_enc[b, x] for x in chosen_col_gt[b]] + [col_enc[b, 0]] *
                (5 - len(chosen_col_gt[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        # Predict op
        op_att_val = torch.matmul(
            self.op_att(q_enc).unsqueeze(1), col_emb.unsqueeze(3)).squeeze()
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                op_att_val[idx, :, num:] = -100
        op_att = self.softmax(op_att_val.view(-1, max_q_len)).view(
            B, -1, max_q_len)
        q_weighted_op = (q_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)

        op_score = self.op_out(
            self.op_out_q(q_weighted_op) + self.op_out_c(col_emb)).squeeze()

        score = (col_num_score, col_score, op_score)

        return score
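
At inference, the column choice above argmaxes the number head and then keeps that many top-scoring columns per example. A minimal NumPy sketch with toy scores:

import numpy as np

col_num_score = np.array([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
col_score = np.array([[0.9, 0.1, 0.8, 0.2], [0.2, 0.7, 0.1, 0.4]])
cond_nums = np.argmax(col_num_score, axis=1)       # how many columns to keep
chosen = [list(np.argsort(-col_score[b])[:cond_nums[b]])
          for b in range(len(cond_nums))]
print(chosen)   # first example keeps its 1 best column, second keeps none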
Example #9
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                x_type_emb_var, gt_sel):
        max_x_len = max(x_len)
        max_col_len = max(col_len)
        B = len(x_len)
        print("x_size:{},x_type_size:{}".format(x_emb_var.size(),
                                                x_type_emb_var.size()))
        x_emb_concat = torch.cat((x_emb_var, x_type_emb_var), 2)
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.selcond_name_enc)
        # e_col, _ = run_lstm(self.selcond_name_enc, col_inp_var, col_len)
        h_enc, _ = run_lstm(self.selcond_lstm, x_emb_concat, x_len)

        # Predict the number of selected columns
        # att_sel_num_type_val:(B, max_col_len, max_x_len)
        att_sel_num_type_val = torch.bmm(
            e_col,
            self.sel_num_type_att(h_enc).transpose(1, 2))

        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_sel_num_type_val[idx, num:, :] = -100
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_sel_num_type_val[idx, :, num:] = -100

        # att_sel_num_type: (B, max_col_len, max_x_len)
        att_sel_num_type = self.softmax(
            att_sel_num_type_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        # h_enc.unsqueeze(1): (B, 1, max_x_len, hid_dim)
        # att_sel_num_type.unsqueeze(3): (B, max_col_len, max_x_len, 1)
        # K_sel_num_type: (B, hid_dim) after summing over tokens and columns
        K_sel_num_type = (h_enc.unsqueeze(1) *
                          att_sel_num_type.unsqueeze(3)).sum(2).sum(1)
        sel_num_score = self.sel_num_out(self.ty_sel_num_out(K_sel_num_type))

        #Predict the selection condition
        #att_val: (B, max_col_len, max_x_len)
        sel_att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                sel_att_val[idx, :, num:] = -100
        sel_att = self.softmax(sel_att_val.view(
            (-1, max_x_len))).view(B, -1, max_x_len)
        #print(sel_att.size())
        #K_sel_expand -> (B, max_number of col names in batch tables, hid_dim)
        K_sel_expand = (h_enc.unsqueeze(1) * sel_att.unsqueeze(3)).sum(2)
        sel_score = self.sel_out(self.sel_out_K(K_sel_expand) + \
                self.sel_out_col(e_col)).squeeze()

        for idx, num in enumerate(col_len):
            if num < max_col_len:
                sel_score[idx, num:] = -100

        # Predict the number of conditions
        #att_cond_num_type_val:(B, max_col_len, max_x_len)
        att_cond_num_type_val = torch.bmm(
            e_col,
            self.cond_num_type_att(h_enc).transpose(1, 2))

        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_cond_num_type_val[idx, num:, :] = -100
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_cond_num_type_val[idx, :, num:] = -100

        #att_cond_num_type: (B, max_col_len, max_x_len)
        att_cond_num_type = self.softmax(
            att_cond_num_type_val.view(
                (-1, max_x_len))).view(B, -1, max_x_len)
        #h_enc.unsqueeze(1): (B, 1, max_x_len, hid_dim)
        #att_cond_num_type.unsqueeze(3): (B, max_col_len, max_x_len, 1)
        #K_cond_num_type: (B, hid_dim) after summing over tokens and columns
        K_cond_num_type = (h_enc.unsqueeze(1) *
                           att_cond_num_type.unsqueeze(3)).sum(2).sum(1)
        cond_num_score = self.cond_num_out(
            self.ty_cond_num_out(K_cond_num_type))

        #Predict the columns of conditions
        if gt_sel is None:
            num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1) + 1
            sel = sel_score.data.cpu().numpy()
            # gt_sel = np.argmax(sel_score.data.cpu().numpy(), axis=1)
            chosen_sel_col_gt = [
                list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))
            ]
        else:
            chosen_sel_col_gt = [[x for x in one_gt_sel]
                                 for one_gt_sel in gt_sel]

        sel_col_emb = []
        for b in range(B):
            cur_sel_col_emb = torch.stack(
                [e_col[b, x] for x in chosen_sel_col_gt[b]] + [e_col[b, 0]] *
                (max_col_len - len(chosen_sel_col_gt[b]))
            )  # Pad the chosen columns out to max_col_len
            sel_col_emb.append(cur_sel_col_emb)
        sel_col_emb = torch.stack(sel_col_emb)

        # chosen_sel_idx = torch.LongTensor(gt_sel)
        #aux_range (B) (0,1,...)
        # aux_range = torch.LongTensor(range(len(gt_sel)))
        # if x_emb_var.is_cuda:
        #     chosen_sel_idx = chosen_sel_idx.cuda()
        #     aux_range = aux_range.cuda()
        #chosen_e_col: (B, hid_dim)
        # chosen_e_col = e_col[aux_range, chosen_sel_idx]
        #chosen_e_col.unsqueeze(2): (B, hid_dim, 1)
        #self.col_att(h_enc): (B, max_x_len, hid_dim)
        #att_sel_val: (B, max_x_len)

        # K_agg = (h_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2)
        #
        # agg_score = self.agg_out(self.agg_out_K(K_agg) + self.col_out_col(sel_col_emb)).squeeze()
        #
        #
        # att_sel_val = torch.bmm(self.col_att(h_enc), chosen_e_col.unsqueeze(2)).squeeze()
        att_sel_val = torch.matmul(
            self.col_att(h_enc).unsqueeze(1),
            sel_col_emb.unsqueeze(3)).squeeze()
        col_att_val = torch.bmm(e_col,
                                self.cond_col_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, :, num:] = -100
                # att_sel_val: (B, max_col_len, max_x_len); mask the padded
                # question tokens on the last axis
                att_sel_val[idx, :, num:] = -100

        sel_att = self.softmax(att_sel_val.view(B * max_col_len,
                                                -1)).view(B, max_col_len, -1)
        #K_sel_agg = (h_enc * sel_att.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_sel_agg = (h_enc.unsqueeze(1) * sel_att.unsqueeze(3)).sum(2)
        col_att = self.softmax(col_att_val.view(
            (-1, max_x_len))).view(B, -1, max_x_len)
        K_cond_col = (h_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
        cond_col_score = self.cond_col_out(
            self.cond_col_out_K(K_cond_col) + self.cond_col_out_col(e_col) +
            self.cond_col_out_sel(K_sel_agg.expand_as(K_cond_col))).squeeze()

        for b, num in enumerate(col_len):
            if num < max_col_len:
                cond_col_score[b, num:] = -100

        sel_cond_score = (sel_num_score, cond_num_score, sel_score,
                          cond_col_score)

        return sel_cond_score
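
The number heads above pool a column-conditioned attention twice: first over question tokens, then over columns, yielding one vector per example. A minimal sketch with toy shapes and a random stand-in attention:

import torch

B, C, L, hid = 2, 4, 6, 8
h_enc = torch.randn(B, L, hid)
att = torch.softmax(torch.randn(B, C, L), dim=2)
# (B, 1, L, hid) * (B, C, L, 1) -> sum over tokens, then columns -> (B, hid)
K_num_type = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2).sum(1)
print(K_num_type.shape)                   # torch.Size([2, 8])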
Example #10
    def forward(self, q_emb_var, q_len, col_emb_var, col_len, col_num, col_name_len):
        max_q_len = max(q_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)


        # Predict number
        gby_num_att = torch.bmm(col_enc, self.gby_num_h(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                gby_num_att[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                gby_num_att[idx, :, num:] = -100

        gby_num_att_val = self.softmax(gby_num_att.view((-1, max_q_len))).view(B, -1, max_q_len)
        gby_num_K = (q_enc.unsqueeze(1) * gby_num_att_val.unsqueeze(3)).sum(2).sum(1)
        ody_num_score = self.gby_num_out(self.gby_num_l(gby_num_K))

        # Predict columns.
        att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, :, num:] = -100
        att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted: (B, max_col_len, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
        # Compute prediction scores
        # col_score: (B, max_col_len)
        col_score = self.col_out(self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze()
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                col_score[idx, num:] = -100

        # Predict aggregation
        agg_att_val = torch.bmm(col_enc, self.agg_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                agg_att_val[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                agg_att_val[idx, :, num:] = -100
        agg_att = self.softmax(agg_att_val.view((-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted_agg: (B, hid_dim)
        q_weighted_agg = (q_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2).sum(1)
        # agg_score: (B, number of aggregation ops)
        agg_score = self.agg_out(self.agg_out_q(q_weighted_agg))

        # Predict desc asc limit
        dat_att_val = torch.bmm(col_enc, self.dat_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                dat_att_val[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                dat_att_val[idx, :, num:] = -100
        dat_att = self.softmax(dat_att_val.view((-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted_dat: (B, hid_dim)
        q_weighted_dat = (q_enc.unsqueeze(1) * dat_att.unsqueeze(3)).sum(2).sum(1)
        # dat_score: (B, number of desc/asc/limit options)
        dat_score = self.dat_out(self.dat_out_q(q_weighted_dat))

        score = (ody_num_score, col_score, agg_score, dat_score)

        return score
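
Several heads here mask the attention matrix along two axes, padded columns on the rows and padded question tokens on the columns, before a row-wise softmax. A minimal sketch with toy shapes and lengths:

import torch

B, C, L = 2, 4, 6
att = torch.randn(B, C, L)
col_len, q_len = [4, 2], [6, 3]
for idx, num in enumerate(col_len):
    if num < C:
        att[idx, num:, :] = -100          # mask padded columns
for idx, num in enumerate(q_len):
    if num < L:
        att[idx, :, num:] = -100          # mask padded question tokens
prob = torch.softmax(att.view(-1, L), dim=1).view(B, C, L)
print(prob.sum(-1))                       # each row still sums to 1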
Example #11
    def forward(self,
                perm,
                st,
                ed,
                q_emb_var,
                q_len,
                col_emb_var,
                col_len,
                col_num,
                col_name_len,
                q_seq,
                col_seq,
                emb_layer,
                train=True):
        max_q_len = max(q_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                     self.col_lstm)

        # Predict number
        gby_num_att = torch.bmm(col_enc, self.gby_num_h(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                gby_num_att[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                gby_num_att[idx, :, num:] = -100

        gby_num_att_val = self.softmax(gby_num_att.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        gby_num_K = (q_enc.unsqueeze(1) *
                     gby_num_att_val.unsqueeze(3)).sum(2).sum(1)
        ody_num_score = self.gby_num_out(self.gby_num_l(gby_num_K))

        # Predict columns.
        att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, :, num:] = -100
        att_prob_qc = self.softmax(att_val_qc.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted: (B, max_col_len, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
        # Compute prediction scores
        # col_score: (B, max_col_len)
        col_score = self.col_out(
            self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze()
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                col_score[idx, num:] = -100

        # Predict aggregation
        agg_att_val = torch.bmm(col_enc, self.agg_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                agg_att_val[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                agg_att_val[idx, :, num:] = -100
        agg_att = self.softmax(agg_att_val.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted_agg: (B, hid_dim)
        q_weighted_agg = (q_enc.unsqueeze(1) *
                          agg_att.unsqueeze(3)).sum(2).sum(1)
        # agg_score: (B, number of aggregation ops)
        agg_score = self.agg_out(self.agg_out_q(q_weighted_agg))

        # Predict desc asc limit
        dat_att_val = torch.bmm(col_enc, self.dat_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                dat_att_val[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                dat_att_val[idx, :, num:] = -100
        dat_att = self.softmax(dat_att_val.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted_dat: (B, hid_dim)
        q_weighted_dat = (q_enc.unsqueeze(1) *
                          dat_att.unsqueeze(3)).sum(2).sum(1)

        col_scores = col_score.data.cpu().numpy()
        chosen_col_gt = [np.argmax(col_scores[b]) for b in range(B)]

        assert B == ed - st
        dirc_vecs = torch.zeros([B, 50])
        zero_feats = torch.zeros([50])
        if self.gpu:
            dirc_vecs = dirc_vecs.cuda()
            zero_feats = zero_feats.cuda()
        dirc_vecs = Variable(dirc_vecs, requires_grad=False)
        for b in range(st, ed):
            idx = perm[b]
            gt_col = chosen_col_gt[b - st]
            dirc_feat = emb_layer.get_direction_feature(
                max_q_len, idx, gt_col, train)
            if self.feats_format == 'direct':
                # [max_len] (-1/0/1)
                mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
                mask_i = mask.cpu().data.numpy()[0]
                if mask_i > 0:
                    dirc_vec = dirc_feat[1]
                elif mask_i < 0:
                    dirc_vec = dirc_feat[2]
                else:
                    dirc_vec = zero_feats
                dirc_vec = Variable(dirc_vec, requires_grad=False)
            else:
                # [max_len, len(feats)]
                dirc_vec = torch.matmul(
                    att_prob_qc[b - st, gt_col].unsqueeze(0),
                    dirc_feat).squeeze()
            dirc_vecs[b - st] = dirc_vec


        dat_score = self.dat_out(self.dat_out_q(q_weighted_dat)) + \
                   self.dat_out_dirc_out(self.dat_out_dirc(dirc_vecs))

        score = (ody_num_score, col_score, agg_score, dat_score)

        return score
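
The non-'direct' feature branch mixes the per-token direction features with the attention row of the chosen column via a single matmul. A minimal sketch; the feature width (50) follows the snippet, and both tensors are random stand-ins:

import torch

L, F = 6, 50
att_row = torch.softmax(torch.randn(L), dim=0)   # attention over question tokens
dirc_feat = torch.randn(L, F)                    # per-token direction features
# (1, L) x (L, F) -> (1, F): attention-weighted mixture of the features
dirc_vec = torch.matmul(att_row.unsqueeze(0), dirc_feat).squeeze(0)
print(dirc_vec.shape)                            # torch.Size([50])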