コード例 #1
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num):
        """Score every table column as a SELECT-column candidate.

        Each column embedding attends over the encoded question; the attended
        question summary is combined with the column embedding and passed
        through ``self.sel_out`` to get the column scores.

        Args:
            x_emb_var: question token embeddings, one entry per batch element
                (presumably (B, T, N_word) -- TODO confirm).
            x_len: unpadded question length of each batch element.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                unchanged to ``col_name_encode``.
            col_num: number of real columns of each batch element.

        Returns:
            Column scores indexed as ``[batch, column]`` (assuming
            ``self.sel_out`` emits one value per column); entries past
            ``col_num[b]`` are set to -100.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.sel_col_name_enc)  # [bs, col_num, hid]
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var,
                            x_len)  # [bs, seq_len, hid]

        # Column-to-question attention logits: [bs, col_num, seq_len].
        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
        # Give padded question positions a large negative logit before
        # normalizing with self.softmax over the token axis.
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        # Attention-weighted question summary per column: [bs, col_num, hid].
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

        # squeeze(-1) removes only the trailing score axis. A bare squeeze()
        # would also drop the batch axis when B == 1 (or the column axis when
        # there is a single column), changing the rank of the result.
        sel_score = self.sel_out(
            self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                sel_score[idx, num:] = -100

        return sel_score
コード例 #2
0
ファイル: where_relation.py プロジェクト: bearblog/nl2sql
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
        """Predict the relationship between WHERE conditions.

        Column names are encoded and attended into a single vector that seeds
        the question LSTM's initial (h, c) state. The encoded question is then
        attended into a summary that ``self.where_rela_out`` scores.

        Args:
            x_emb_var: question token embeddings (presumably (B, T, N_word) --
                TODO confirm).
            x_len: unpadded question length of each batch element.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num: not read; replaced by the column counts returned by
                ``col_name_encode``.

        Returns:
            Output of ``self.where_rela_out``, one row per batch element.
        """
        B = len(x_len)
        max_x_len = max(x_len)

        # NOTE: the same LSTM (self.where_rela_lstm) encodes both the column
        # names and, below, the question.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.where_rela_lstm)
        # squeeze(-1) drops only the trailing score axis. A bare squeeze()
        # also drops the batch axis when B == 1 (or the column axis for a
        # single column), after which unsqueeze(2) below raises.
        col_att_val = self.where_rela_col_att(e_num_col).squeeze(-1)
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        # Reshape to (4, B, N_h // 2) for use as the LSTM initial state
        # (presumably num_layers * num_directions == 4 -- TODO confirm).
        h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h // 2).transpose(0, 1).contiguous()
        h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h // 2).transpose(0, 1).contiguous()

        h_enc, _ = run_lstm(self.where_rela_lstm, x_emb_var, x_len, hidden=(h1, h2))

        att_val = self.where_rela_att(h_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -1000000
        att = self.softmax(att_val)

        # Attention-weighted question summary: [B, hid].
        where_rela_num = (h_enc * att.unsqueeze(2)).sum(1)
        where_rela_score = self.where_rela_out(where_rela_num)
        return where_rela_score
コード例 #3
0
ファイル: select_number.py プロジェクト: bearblog/nl2sql
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
        """Predict how many columns the SELECT clause has.

        Column names are encoded and attended into a single vector that seeds
        the question LSTM's initial (h, c) state. The encoded question is then
        attended into a summary that ``self.sel_num_out`` scores.

        Args:
            x_emb_var: question token embeddings (presumably (B, T, N_word) --
                TODO confirm).
            x_len: unpadded question length of each batch element.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num: not read; replaced by the column counts returned by
                ``col_name_encode``.

        Returns:
            Output of ``self.sel_num_out``, one row per batch element.
        """
        B = len(x_len)
        max_x_len = max(x_len)

        # NOTE: the same LSTM (self.sel_num_lstm) encodes both the column
        # names and, below, the question.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.sel_num_lstm)
        # squeeze(-1) drops only the trailing score axis. A bare squeeze()
        # also drops the batch axis when B == 1 (or the column axis for a
        # single column), after which unsqueeze(2) below raises.
        num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)
        max_col_num = max(col_num)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                num_col_att_val[idx, num:] = -1000000
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        # Reshape to (4, B, N_h // 2) for use as the LSTM initial state
        # (presumably num_layers * num_directions == 4 -- TODO confirm).
        sel_num_h1 = self.sel_num_col2hid1(K_num_col).view((B, 4, self.N_h // 2)).transpose(0, 1).contiguous()
        sel_num_h2 = self.sel_num_col2hid2(K_num_col).view((B, 4, self.N_h // 2)).transpose(0, 1).contiguous()

        h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                                hidden=(sel_num_h1, sel_num_h2))

        num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -1000000
        num_att = self.softmax(num_att_val)

        # Attention-weighted question summary: [B, hid].
        K_sel_num = (h_num_enc * num_att.unsqueeze(2)).sum(1)
        sel_num_score = self.sel_num_out(K_sel_num)
        return sel_num_score
コード例 #4
0
ファイル: aggregator_predict.py プロジェクト: bearblog/nl2sql
    def forward(self,
                x_emb_var,
                x_len,
                col_inp_var=None,
                col_name_len=None,
                col_len=None,
                col_num=None,
                gt_sel=None,
                gt_sel_num=None):
        """Score aggregation operators for each selected column.

        The embeddings of the selected columns (``gt_sel``, padded to 4 slots)
        each attend over the encoded question. Each slot's attended summary,
        plus its column embedding, is scored by ``self.agg_out``.

        Args:
            x_emb_var: question token embeddings (presumably (B, T, N_word) --
                TODO confirm).
            x_len: unpadded question length of each batch element.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num, gt_sel_num: accepted for interface compatibility; not
                read here.
            gt_sel: per-example list of selected column indices (at most 4).
                Required despite its ``None`` default.

        Returns:
            ``self.agg_out`` scores indexed as ``[batch, slot, ...]``; slots at
            or beyond ``len(gt_sel[b])`` are padding built from column 0.
        """
        B = len(x_emb_var)
        max_x_len = max(x_len)

        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

        # Embeddings of the selected columns, padded to 4 slots by repeating
        # column 0: [B, 4, hid].
        col_emb = torch.stack([
            torch.stack([e_col[b, x] for x in gt_sel[b]] +
                        [e_col[b, 0]] * (4 - len(gt_sel[b])))
            for b in range(B)
        ])

        # [B, 1, T, hid] @ [B, 4, hid, 1] -> [B, 4, T, 1] -> [B, 4, T].
        # squeeze(-1) keeps the batch axis even when B == 1.
        att_val = torch.matmul(
            self.agg_att(h_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze(-1)

        # Mask padded question positions along the token axis (dim 2).
        # Indexing att_val[idx, num:] would hit the slot axis (dim 1) instead
        # and leave the padded tokens unmasked.
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view(B * 4, -1)).view(B, 4, -1)

        # Attention-weighted question summary per slot: [B, 4, hid].
        K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
        # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis.
        agg_score = self.agg_out(
            self.agg_out_K(K_agg) + self.col_out_col(col_emb)).squeeze(-1)
        return agg_score
コード例 #5
0
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num, gt_where, gt_cond, reinforce):
        """Predict the WHERE clause: condition count, columns, operators, values.

        The four sub-predictions run in sequence. The operator and value parts
        condition on a set of columns (padded to 4 slots). Those columns come
        from ``gt_cond`` when it is given, otherwise from the top-scoring
        predicted columns, taking as many as the predicted condition count.

        Args:
            x_emb_var: question token embeddings (presumably (B, T, N_word) --
                TODO confirm).
            x_len: unpadded question length of each batch element.
            col_inp_var, col_name_len, col_len: column-name inputs forwarded
                to ``col_name_encode``.
            col_num: not read; replaced by the column counts returned by
                ``col_name_encode``.
            gt_where: ground truth passed to ``self.gen_gt_batch`` for a
                teacher-forced value decode, or None to decode greedily for
                50 steps.
            gt_cond: per-example ground-truth conditions whose first element
                is a column index, or None to use predicted columns.
            reinforce: must be falsy; RL training is not implemented.

        Returns:
            Tuple ``(cond_num_score, cond_col_score, cond_op_score,
            cond_str_score)``. Padded columns and padded question positions
            are filled with -100.

        Raises:
            NotImplementedError: if ``reinforce`` is truthy.
        """
        max_x_len = max(x_len)
        B = len(x_len)
        if reinforce:
            raise NotImplementedError('Our model doesn\'t have RL')

        # squeeze(-1) is used instead of a bare squeeze() so that only the
        # trailing size-1 score axis is removed. A bare squeeze() also drops
        # the batch axis when B == 1, which breaks the unsqueeze(2) calls and
        # the [b, ...] indexing below.

        # ---- Number of conditions ----------------------------------------
        # Column embeddings are attended into one vector that seeds the
        # LSTM's initial (h, c) state; the encoded question is then attended
        # and scored.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
                                             col_len, self.cond_num_name_enc)
        max_col_num = max(col_num)
        num_col_att_val = self.cond_num_col_att(e_num_col).squeeze(-1)
        for idx, num in enumerate(col_num):
            if num < max_col_num:
                num_col_att_val[idx, num:] = -100
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        # Reshape to (4, B, N_h // 2) for use as the LSTM initial state
        # (presumably num_layers * num_directions == 4 -- TODO confirm).
        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()
        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
            B, 4, self.N_h // 2).transpose(0, 1).contiguous()

        h_num_enc, _ = run_lstm(self.cond_num_lstm,
                                x_emb_var,
                                x_len,
                                hidden=(cond_num_h1, cond_num_h2))

        num_att_val = self.cond_num_att(h_num_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -100
        num_att = self.softmax(num_att_val)

        K_cond_num = (h_num_enc * num_att.unsqueeze(2)).sum(1)
        cond_num_score = self.cond_num_out(K_cond_num)

        # ---- Condition columns --------------------------------------------
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_col_name_enc)
        h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)

        if self.use_ca:
            # Column attention: each column attends over the question.
            col_att_val = torch.bmm(
                e_cond_col,
                self.cond_col_att(h_col_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    col_att_val[idx, :, num:] = -100
            col_att = self.softmax(col_att_val.view(
                (-1, max_x_len))).view(B, -1, max_x_len)
            K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
        else:
            # One question summary shared by every column.
            col_att_val = self.cond_col_att(h_col_enc).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    col_att_val[idx, num:] = -100
            col_att = self.softmax(col_att_val)
            # Weight by the normalized attention col_att (as the operator
            # branch does with op_att), not by the raw masked logits.
            K_cond_col = (h_col_enc *
                          col_att.unsqueeze(2)).sum(1).unsqueeze(1)

        cond_col_score = self.cond_col_out(
            self.cond_col_out_K(K_cond_col) +
            self.cond_col_out_col(e_cond_col)).squeeze(-1)
        for b, num in enumerate(col_num):
            if num < max_col_num:
                cond_col_score[b, num:] = -100

        # ---- Columns the operator/value parts condition on ---------------
        if gt_cond is None:
            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
            col_scores = cond_col_score.data.cpu().numpy()
            chosen_col_gt = [
                list(np.argsort(-col_scores[b])[:cond_nums[b]])
                for b in range(len(cond_nums))
            ]
        else:
            chosen_col_gt = [[x[0] for x in one_gt_cond]
                             for one_gt_cond in gt_cond]

        def pad_chosen_col_emb(e_col):
            # Gather each example's chosen column embeddings and pad to 4
            # slots by repeating column 0.
            return torch.stack([
                torch.stack([e_col[b, x] for x in chosen_col_gt[b]] +
                            [e_col[b, 0]] * (4 - len(chosen_col_gt[b])))
                for b in range(B)
            ])

        # ---- Condition operators -----------------------------------------
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_op_name_enc)
        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
        col_emb = pad_chosen_col_emb(e_cond_col)

        if self.use_ca:
            # [B, 1, T, hid] @ [B, 4, hid, 1] -> [B, 4, T, 1] -> [B, 4, T]
            op_att_val = torch.matmul(
                self.cond_op_att(h_op_enc).unsqueeze(1),
                col_emb.unsqueeze(3)).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, :, num:] = -100
            op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
            K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
        else:
            op_att_val = self.cond_op_att(h_op_enc).squeeze(-1)
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, num:] = -100
            op_att = self.softmax(op_att_val)
            K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

        cond_op_score = self.cond_op_out(
            self.cond_op_out_K(K_cond_op) +
            self.cond_op_out_col(col_emb)).squeeze(-1)

        # ---- Condition values (scores over question positions) -----------
        h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                        self.cond_str_name_enc)
        col_emb = pad_chosen_col_emb(e_cond_col)
        h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)

        if gt_where is not None:
            # Teacher forcing: decode the ground-truth token sequence at once.
            gt_tok_seq, _ = self.gen_gt_batch(gt_where)
            g_str_s_flat, _ = self.cond_str_decoder(
                gt_tok_seq.view(B * 4, -1, self.max_tok_num))
            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)

            cond_str_score = self.cond_str_out(
                self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                self.cond_str_out_col(col_ext)).squeeze(-1)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100
        else:
            # Greedy decoding: start from <BEG> and feed the argmax position
            # back in as a one-hot input, for a fixed 50 steps.
            scores = []
            init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:, 0, 0] = 1  # Set the <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = None
            for _ in range(50):
                # `is not None` rather than truthiness: a multi-element tensor
                # state cannot be evaluated as a bool.
                if cur_h is not None:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
                else:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
                g_ext = g_str_s.unsqueeze(3)

                # Drop the trailing score axis and the single decode-step
                # axis: [B, 4, 1, T, 1] -> [B, 4, T].
                cur_cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze(-1).squeeze(2)
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_str_score[b, :, num:] = -100
                scores.append(cur_cond_str_score)

                _, ans_tok_var = cur_cond_str_score.view(B * 4,
                                                         max_x_len).max(1)
                ans_tok = ans_tok_var.data.cpu()
                # One-hot encode the chosen position as the next input.
                data = torch.zeros(B * 4, self.max_tok_num).scatter_(
                    1, ans_tok.unsqueeze(1), 1)
                if self.gpu:
                    cur_inp = Variable(data.cuda())
                else:
                    cur_inp = Variable(data)
                cur_inp = cur_inp.unsqueeze(1)

            cond_str_score = torch.stack(scores, 2)  # [B, 4, steps, T]
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100

        cond_score = (cond_num_score, cond_col_score, cond_op_score,
                      cond_str_score)

        return cond_score