Code Example #1
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

        # get target/predicted column's embedding
        # col_emb: (B, hid_dim)
        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb) # [B, dim]
        # self.q_att(q_enc).transpose(1, 2): [B, dim, max_q_len]
        att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, num:] = -100
        att_prob_qc = self.softmax(att_val_qc)
        q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc[idx, num:] = -100
        att_prob_hc = self.softmax(att_val_hc)
        hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)
        # dat_score: (B, 4)
        dat_score = self.dat_out(self.dat_out_q(q_weighted) + int(self.use_hs)* self.dat_out_hs(hs_weighted) + self.dat_out_c(col_emb))

        return dat_score
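
The pattern used throughout Code Example #1 (score every question token against the target column, overwrite the padded positions with -100 so the softmax gives them near-zero weight, then take the probability-weighted sum) is the same computation that later examples call through plain_conditional_weighted_num. Below is a minimal, self-contained sketch of that pooling step; the function name and shape comments are assumptions inferred from the code above, not the project's actual helper.

import torch

def column_conditioned_pool(q_enc, q_len, col_emb, att_proj):
    # q_enc:    (B, max_q_len, hid_dim) encoded question tokens
    # q_len:    true question length per batch entry
    # col_emb:  (B, hid_dim) embedding of the target/predicted column
    # att_proj: an nn.Linear(hid_dim, hid_dim) such as self.q_att
    B, max_q_len, _ = q_enc.size()
    # (B, 1, hid_dim) x (B, hid_dim, max_q_len) -> (B, max_q_len)
    att_val = torch.bmm(col_emb.unsqueeze(1),
                        att_proj(q_enc).transpose(1, 2)).view(B, -1)
    # mask padded positions so the softmax assigns them ~zero probability
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val[idx, num:] = -100
    att_prob = torch.softmax(att_val, dim=1)
    # probability-weighted sum over question tokens: (B, hid_dim)
    return (q_enc * att_prob.unsqueeze(2)).sum(1)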
Code Example #2
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var,
                mkw_len):
        # print("q_emb_shape:{} hs_emb_shape:{}".format(q_emb_var.size(), hs_emb_var.size()))
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)

        # q_enc: (B, max_q_len, hid_dim)
        # hs_enc: (B, max_hs_len, hid_dim)
        # mkw: (B, 4, hid_dim)
        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        mkw_enc, _ = run_lstm(self.mkw_lstm, mkw_emb_var, mkw_len)

        # Compute attention values between multi SQL key words and question tokens.
        q_weighted = seq_conditional_weighted_num(self.q_att, q_enc, q_len,
                                                  mkw_enc)
        SIZE_CHECK(q_weighted, [B, 4, self.N_h])

        # Same as the above, compute SQL history embedding weighted by key words attentions
        hs_weighted = seq_conditional_weighted_num(self.hs_att, hs_enc, hs_len,
                                                   mkw_enc)

        # Compute prediction scores
        mulit_score = self.multi_out(
            self.multi_out_q(q_weighted) +
            int(self.use_hs) * self.multi_out_hs(hs_weighted) +
            self.multi_out_c(mkw_enc)).view(B, 4)

        return mulit_score
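
Code Example #2 (and several later ones) call SIZE_CHECK to assert tensor shapes. That helper is not shown on this page; a plausible minimal version, consistent with the plain assert used in Code Example #15, might look like the sketch below. Treat it as an assumption about the project's utility code rather than its actual definition.

def SIZE_CHECK(tensor, expected_size):
    # Assert that a tensor has the expected shape,
    # e.g. SIZE_CHECK(q_weighted, [B, 4, N_h]).
    actual = list(tensor.size())
    assert actual == list(expected_size), \
        "size mismatch: got {}, expected {}".format(actual, expected_size)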
Code Example #3
File: andor_predictor.py  Project: ishay2b/syntaxSQL
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)

        att_np_q = np.ones((B, max_q_len))
        att_val_q = torch.from_numpy(att_np_q).float()
        att_val_q = Variable(att_val_q.cuda())
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_q[idx, num:] = -100
        att_prob_q = self.softmax(att_val_q)
        q_weighted = (q_enc * att_prob_q.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_np_h = np.ones((B, max_hs_len))
        att_val_h = torch.from_numpy(att_np_h).float()
        att_val_h = Variable(att_val_h.cuda())
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_h[idx, num:] = -100
        att_prob_h = self.softmax(att_val_h)
        hs_weighted = (hs_enc * att_prob_h.unsqueeze(2)).sum(1)
        # ao_score: (B, 2)
        ao_score = self.ao_out(
            self.ao_out_q(q_weighted) +
            int(self.use_hs) * self.ao_out_hs(hs_weighted))

        return ao_score
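
In Code Example #3 the attention scores are all ones, so after the -100 masking the softmax is (numerically) a uniform average over the valid positions; the Variable wrapper and the explicit .cuda() call reflect pre-0.4 PyTorch. A sketch of the same pooling in current PyTorch, taking the device from the encoder output instead of hard-coding CUDA, could look like this:

import torch

def mean_pool_valid(enc, lengths):
    # enc: (B, max_len, hid_dim); lengths: true length per batch entry.
    # All-ones scores plus -100 masking make the softmax an (almost exact)
    # uniform average over the first `num` positions of each sequence.
    B, max_len, _ = enc.size()
    att_val = torch.ones(B, max_len, device=enc.device)
    for idx, num in enumerate(lengths):
        if num < max_len:
            att_val[idx, num:] = -100
    att_prob = torch.softmax(att_val, dim=1)
    return (enc * att_prob.unsqueeze(2)).sum(1)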
Code Example #4
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                col_len, col_name_len, gt_col):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_tab_name_encode(col_emb_var, col_name_len, col_len,
                                         self.col_lstm)

        # get target/predicted column's embedding
        # col_emb: (B, hid_dim)
        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb)  # [B, dim]
        q_weighted = plain_conditional_weighted_num(self.q_att, q_enc, q_len,
                                                    col_emb)

        # Same as the above, compute SQL history embedding weighted by column attentions
        hs_weighted = plain_conditional_weighted_num(self.hs_att, hs_enc,
                                                     hs_len, col_emb)
        # dat_score: (B, 4)
        dat_score = self.dat_out(
            self.dat_out_q(q_weighted) +
            int(self.use_hs) * self.dat_out_hs(hs_weighted) +
            self.dat_out_c(col_emb))

        return dat_score
Code Example #5
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var, kw_len):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)
        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        kw_enc, _ = run_lstm(self.kw_lstm, kw_emb_var, kw_len)

        # Predict key words number: 0-3
        q_weighted_num = seq_conditional_weighted_num(self.q_num_att, q_enc, q_len, kw_enc).sum(1)
        # Same as the above, compute SQL history embedding weighted by key words attentions
        hs_weighted_num = seq_conditional_weighted_num(self.hs_num_att, hs_enc, hs_len, kw_enc).sum(1)
        # Compute prediction scores
        kw_num_score = self.kw_num_out(self.kw_num_out_q(q_weighted_num) + int(self.use_hs)* self.kw_num_out_hs(hs_weighted_num))
        SIZE_CHECK(kw_num_score, [B, 4])

        # Predict key words: WHERE, GROUP BY, ORDER BY.
        q_weighted = seq_conditional_weighted_num(self.q_att, q_enc, q_len, kw_enc)
        SIZE_CHECK(q_weighted, [B, 3, self.N_h])

        # Same as the above, compute SQL history embedding weighted by key words attentions
        hs_weighted = seq_conditional_weighted_num(self.hs_att, hs_enc, hs_len, kw_enc)
        # Compute prediction scores
        kw_score = self.kw_out(self.kw_out_q(q_weighted) + int(self.use_hs)* self.kw_out_hs(hs_weighted) + self.kw_out_kw(kw_enc)).view(B,3)

        score = (kw_num_score, kw_score)

        return score
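
Code Example #5 returns a (kw_num_score, kw_score) pair with shapes (B, 4) and (B, 3). One hypothetical way to turn that pair into a set of key words (take the arg-max key-word count, then the top-scoring key words) is sketched below with random scores; the actual syntaxSQL decoding logic may differ.

import torch

B = 2
kw_num_score = torch.randn(B, 4)   # scores for predicting 0-3 key words
kw_score = torch.randn(B, 3)       # scores for WHERE, GROUP BY, ORDER BY
keywords = ["WHERE", "GROUP BY", "ORDER BY"]
kw_num = kw_num_score.argmax(dim=1)
for b in range(B):
    chosen = kw_score[b].topk(int(kw_num[b])).indices.tolist()
    print("example {}: {}".format(b, [keywords[i] for i in chosen]))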
Code Example #6
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb)

        # Predict agg number
        att_val_qc_num = torch.bmm(col_emb.unsqueeze(1), self.q_num_att(q_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc_num[idx, num:] = -100
        att_prob_qc_num = self.softmax(att_val_qc_num)
        q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc_num = torch.bmm(col_emb.unsqueeze(1), self.hs_num_att(hs_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc_num[idx, num:] = -100
        att_prob_hc_num = self.softmax(att_val_hc_num)
        hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1)
        # agg_num_score: (B, 4)
        agg_num_score = self.agg_num_out(self.agg_num_out_q(q_weighted_num) + int(self.use_hs)* self.agg_num_out_hs(hs_weighted_num) + self.agg_num_out_c(col_emb))

        # Predict aggregators
        att_val_qc = torch.bmm(col_emb.unsqueeze(1), self.q_att(q_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, num:] = -100
        att_prob_qc = self.softmax(att_val_qc)
        q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc = torch.bmm(col_emb.unsqueeze(1), self.hs_att(hs_enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc[idx, num:] = -100
        att_prob_hc = self.softmax(att_val_hc)
        hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)
        # agg_score: (B, 5)
        agg_score = self.agg_out(self.agg_out_q(q_weighted) + int(self.use_hs)* self.agg_out_hs(hs_weighted) + self.agg_out_c(col_emb))

        score = (agg_num_score, agg_score)

        return score
Code Example #7
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                col_len, col_name_len, gt_col):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)
        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_tab_name_encode(col_emb_var, col_name_len, col_len,
                                         self.col_lstm)

        # get target/predicted column's embedding
        # col_emb: (B, hid_dim)
        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb)

        # Predict op number
        q_weighted_num = plain_conditional_weighted_num(
            self.q_num_att, q_enc, q_len, col_emb)

        # Same as the above, compute SQL history embedding weighted by column attentions
        hs_weighted_num = plain_conditional_weighted_num(
            self.hs_num_att, hs_enc, hs_len, col_emb)
        # op_num_score: (B, 2)
        op_num_score = self.op_num_out(
            self.op_num_out_q(q_weighted_num) +
            int(self.use_hs) * self.op_num_out_hs(hs_weighted_num) +
            self.op_num_out_c(col_emb))
        SIZE_CHECK(op_num_score, [B, 2])

        # Compute attention values between selected column and question tokens.
        q_weighted = plain_conditional_weighted_num(self.q_att, q_enc, q_len,
                                                    col_emb)

        # Same as the above, compute SQL history embedding weighted by column attentions
        hs_weighted = plain_conditional_weighted_num(self.hs_att, hs_enc,
                                                     hs_len, col_emb)

        # Compute prediction scores
        # op_score: (B, 11)
        op_score = self.op_out(
            self.op_out_q(q_weighted) +
            int(self.use_hs) * self.op_out_hs(hs_weighted) +
            self.op_out_c(col_emb))
        SIZE_CHECK(op_score, [B, 11])

        score = (op_num_score, op_score)

        return score
Code Example #8
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var,
                mkw_len):
        # print("q_emb_shape:{} hs_emb_shape:{}".format(q_emb_var.size(), hs_emb_var.size()))
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)

        # q_enc: (B, max_q_len, hid_dim)
        # hs_enc: (B, max_hs_len, hid_dim)
        # mkw: (B, 4, hid_dim)
        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        mkw_enc, _ = run_lstm(self.mkw_lstm, mkw_emb_var, mkw_len)

        # Compute attention values between multi SQL key words and question tokens.
        # qmkw_att(q_enc).transpose(1, 2): (B, hid_dim, max_q_len)
        # att_val_qmkw: (B, 4, max_q_len)
        # print("mkw_enc {} q_enc {}".format(mkw_enc.size(), self.q_att(q_enc).transpose(1, 2).size()))
        att_val_qmkw = torch.bmm(mkw_enc, self.q_att(q_enc).transpose(1, 2))
        # assign appended positions values -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qmkw[idx, :, num:] = -100
        # att_prob_qmkw: (B, 4, max_q_len)
        att_prob_qmkw = self.softmax(att_val_qmkw.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_enc.unsqueeze(1): (B, 1, max_q_len, hid_dim)
        # att_prob_qmkw.unsqueeze(3): (B, 4, max_q_len, 1)
        # q_weighted: (B, 4, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qmkw.unsqueeze(3)).sum(2)

        # Same as the above, compute SQL history embedding weighted by key words attentions
        att_val_hsmkw = torch.bmm(mkw_enc, self.hs_att(hs_enc).transpose(1, 2))
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hsmkw[idx, :, num:] = -100
        att_prob_hsmkw = self.softmax(att_val_hsmkw.view(
            (-1, max_hs_len))).view(B, -1, max_hs_len)
        hs_weighted = (hs_enc.unsqueeze(1) *
                       att_prob_hsmkw.unsqueeze(3)).sum(2)

        # Compute prediction scores
        # self.multi_out.squeeze(): (B, 4, 1) -> (B, 4)
        mulit_score = self.multi_out(
            self.multi_out_q(q_weighted) +
            int(self.use_hs) * self.multi_out_hs(hs_weighted) +
            self.multi_out_c(mkw_enc)).view(B, -1)

        return mulit_score
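
Code Example #8 is the fully inlined form of what Code Example #2 delegates to seq_conditional_weighted_num: score every question (or history) token against each key-word encoding, mask the padded token positions with -100, softmax over tokens, and take one weighted sum per key word. A self-contained sketch of such a helper, inferred from the code above (the real helper and its optional length argument may differ), is:

import torch

def seq_conditional_pool(att_proj, seq_enc, seq_len, cond_enc, cond_len=None):
    # seq_enc:  (B, max_seq_len, hid_dim) encoded question or history tokens
    # cond_enc: (B, n_cond, hid_dim) conditioning items (key words, columns, ...)
    # Returns (B, n_cond, hid_dim): one pooled sequence vector per conditioning item.
    B, max_seq_len, _ = seq_enc.size()
    att_val = torch.bmm(cond_enc, att_proj(seq_enc).transpose(1, 2))
    for idx, num in enumerate(seq_len):
        if num < max_seq_len:
            att_val[idx, :, num:] = -100
    if cond_len is not None:  # optionally mask padded conditioning items too
        for idx, num in enumerate(cond_len):
            if num < cond_enc.size(1):
                att_val[idx, num:, :] = -100
    att_prob = torch.softmax(att_val, dim=2)
    return (seq_enc.unsqueeze(1) * att_prob.unsqueeze(3)).sum(2)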
Code Example #9
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                col_len, col_name_len, gt_col):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_tab_name_encode(col_emb_var, col_name_len, col_len,
                                         self.col_lstm)

        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb)

        # Predict agg number
        q_weighted_num = plain_conditional_weighted_num(
            self.q_num_att, q_enc, q_len, col_emb)

        # Same as the above, compute SQL history embedding weighted by column attentions
        hs_weighted_num = plain_conditional_weighted_num(
            self.hs_num_att, hs_enc, hs_len, col_emb)
        agg_num_score = self.agg_num_out(
            self.agg_num_out_q(q_weighted_num) +
            int(self.use_hs) * self.agg_num_out_hs(hs_weighted_num) +
            self.agg_num_out_c(col_emb))
        SIZE_CHECK(agg_num_score, [B, 4])

        # Predict aggregators
        q_weighted = plain_conditional_weighted_num(self.q_att, q_enc, q_len,
                                                    col_emb)

        # Same as the above, compute SQL history embedding weighted by column attentions
        hs_weighted = plain_conditional_weighted_num(self.hs_att, hs_enc,
                                                     hs_len, col_emb)
        # agg_score: (B, 5)
        agg_score = self.agg_out(
            self.agg_out_q(q_weighted) +
            int(self.use_hs) * self.agg_out_hs(hs_weighted) +
            self.agg_out_c(col_emb))

        score = (agg_num_score, agg_score)

        return score
Code Example #10
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, col_candidates=None):

        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)
        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_tab_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

        # Predict column number: 1-3
        q_weighted_num = seq_conditional_weighted_num(self.q_num_att, q_enc, q_len, col_enc, col_len).sum(1)
        SIZE_CHECK(q_weighted_num, [B, self.N_h])

        # Same as the above, compute SQL history embedding weighted by column attentions
        hs_weighted_num = seq_conditional_weighted_num(self.hs_num_att, hs_enc, hs_len, col_enc, col_len).sum(1)
        SIZE_CHECK(hs_weighted_num, [B, self.N_h])
        # self.col_num_out: (B, 3)
        col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num) + int(self.use_hs) * self.col_num_out_hs(hs_weighted_num))

        # Predict columns.
        q_weighted = seq_conditional_weighted_num(self.q_att, q_enc, q_len, col_enc)

        # Same as the above, compute SQL history embedding weighted by column attentions
        hs_weighted = seq_conditional_weighted_num(self.hs_att, hs_enc, hs_len, col_enc)
        # Compute prediction scores
        # self.col_out.squeeze(): (B, max_col_len)
        col_score = self.col_out(self.col_out_q(q_weighted) + int(self.use_hs)* self.col_out_hs(hs_weighted) + self.col_out_c(col_enc)).view(B,-1)

        for idx, num in enumerate(col_len):
            if num < max_col_len:
                col_score[idx, num:] = -100
            if col_candidates is not None:
                for col_num in range(num):
                    if col_num not in col_candidates[idx]:
                        col_score[idx, col_num] = -100

        score = (col_num_score, col_score)

        return score
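
Unlike Code Example #12, this variant can restrict prediction to a caller-supplied col_candidates list by pushing every non-candidate column's score down to -100 before decoding. A small illustration of the effect of that masking, with hypothetical scores and candidate indices:

import torch

col_score = torch.tensor([[0.9, 0.2, 0.7, 0.4]])   # (B=1, max_col_len=4)
col_candidates = [[0, 2]]                          # only columns 0 and 2 allowed
for idx, cands in enumerate(col_candidates):
    for col_num in range(col_score.size(1)):
        if col_num not in cands:
            col_score[idx, col_num] = -100
print(col_score.topk(2).indices.tolist())          # -> [[0, 2]]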
Code Example #11
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, gt_col):
        B = len(q_len)
        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_tab_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

        # get target/predicted column's embedding
        # col_emb: (B, hid_dim)
        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb)
        q_weighted = plain_conditional_weighted_num(self.q_att, q_enc, q_len, col_emb)
        hs_weighted = plain_conditional_weighted_num(self.hs_att, hs_enc, hs_len, col_emb)
        hv_score = self.hv_out(self.hv_out_q(q_weighted) + int(self.use_hs)* self.hv_out_hs(hs_weighted) + self.hv_out_c(col_emb))
        SIZE_CHECK(hv_score, [B, 2])

        return hv_score
Code Example #12
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                col_len, col_name_len):

        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                     self.col_lstm)

        # Predict column number: 1-3
        # att_val_qc_num: (B, max_col_len, max_q_len)
        att_val_qc_num = torch.bmm(col_enc,
                                   self.q_num_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_val_qc_num[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc_num[idx, :, num:] = -100
        att_prob_qc_num = self.softmax(att_val_qc_num.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted_num: (B, hid_dim)
        q_weighted_num = (q_enc.unsqueeze(1) *
                          att_prob_qc_num.unsqueeze(3)).sum(2).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        # att_val_hc_num: (B, max_col_len, max_hs_len)
        att_val_hc_num = torch.bmm(col_enc,
                                   self.hs_num_att(hs_enc).transpose(1, 2))
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc_num[idx, :, num:] = -100
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_val_hc_num[idx, num:, :] = -100
        att_prob_hc_num = self.softmax(att_val_hc_num.view(
            (-1, max_hs_len))).view(B, -1, max_hs_len)
        hs_weighted_num = (hs_enc.unsqueeze(1) *
                           att_prob_hc_num.unsqueeze(3)).sum(2).sum(1)
        # self.col_num_out: (B, 3)
        col_num_score = self.col_num_out(
            self.col_num_out_q(q_weighted_num) +
            int(self.use_hs) * self.col_num_out_hs(hs_weighted_num))

        # Predict columns.
        att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, :, num:] = -100
        att_prob_qc = self.softmax(att_val_qc.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted: (B, max_col_len, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc = torch.bmm(col_enc, self.hs_att(hs_enc).transpose(1, 2))
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc[idx, :, num:] = -100
        att_prob_hc = self.softmax(att_val_hc.view(
            (-1, max_hs_len))).view(B, -1, max_hs_len)
        hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hc.unsqueeze(3)).sum(2)
        # Compute prediction scores
        # self.col_out.squeeze(): (B, max_col_len)
        col_score = self.col_out(
            self.col_out_q(q_weighted) +
            int(self.use_hs) * self.col_out_hs(hs_weighted) +
            self.col_out_c(col_enc)).view(B, -1)

        for idx, num in enumerate(col_len):
            if num < max_col_len:
                col_score[idx, num:] = -100

        score = (col_num_score, col_score)

        return score
Code Example #13
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var,
                kw_len):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        kw_enc, _ = run_lstm(self.kw_lstm, kw_emb_var, kw_len)

        # Predict key words number: 0-3
        att_val_qkw_num = torch.bmm(kw_enc,
                                    self.q_num_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qkw_num[idx, :, num:] = -100
        att_prob_qkw_num = self.softmax(att_val_qkw_num.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted: (B, hid_dim)
        q_weighted_num = (q_enc.unsqueeze(1) *
                          att_prob_qkw_num.unsqueeze(3)).sum(2).sum(1)

        # Same as the above, compute SQL history embedding weighted by key words attentions
        att_val_hskw_num = torch.bmm(kw_enc,
                                     self.hs_num_att(hs_enc).transpose(1, 2))
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hskw_num[idx, :, num:] = -100
        att_prob_hskw_num = self.softmax(
            att_val_hskw_num.view((-1, max_hs_len))).view(B, -1, max_hs_len)
        hs_weighted_num = (hs_enc.unsqueeze(1) *
                           att_prob_hskw_num.unsqueeze(3)).sum(2).sum(1)
        # Compute prediction scores
        # self.kw_num_out: (B, 4)
        kw_num_score = self.kw_num_out(
            self.kw_num_out_q(q_weighted_num) +
            int(self.use_hs) * self.kw_num_out_hs(hs_weighted_num))

        # Predict key words: WHERE, GROUP BY, ORDER BY.
        att_val_qkw = torch.bmm(kw_enc, self.q_att(q_enc).transpose(1, 2))
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qkw[idx, :, num:] = -100
        att_prob_qkw = self.softmax(att_val_qkw.view(
            (-1, max_q_len))).view(B, -1, max_q_len)
        # q_weighted: (B, 3, hid_dim)
        q_weighted = (q_enc.unsqueeze(1) * att_prob_qkw.unsqueeze(3)).sum(2)

        # Same as the above, compute SQL history embedding weighted by key words attentions
        att_val_hskw = torch.bmm(kw_enc, self.hs_att(hs_enc).transpose(1, 2))
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hskw[idx, :, num:] = -100
        att_prob_hskw = self.softmax(att_val_hskw.view(
            (-1, max_hs_len))).view(B, -1, max_hs_len)
        hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hskw.unsqueeze(3)).sum(2)
        # Compute prediction scores
        # self.kw_out.squeeze(): (B, 3)
        kw_score = self.kw_out(
            self.kw_out_q(q_weighted) +
            int(self.use_hs) * self.kw_out_hs(hs_weighted) +
            self.kw_out_kw(kw_enc)).view(B, -1)

        score = (kw_num_score, kw_score)

        return score
Code Example #14
File: op_predictor.py  Project: ygan/syntaxSQL
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                col_len, col_name_len, gt_col):
        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        B = len(q_len)

        q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                     self.col_lstm)

        # get target/predicted column's embedding
        # col_emb: (B, hid_dim)
        col_emb = []
        for b in range(B):
            col_emb.append(col_enc[b, gt_col[b]])
        col_emb = torch.stack(col_emb)

        # Predict op number
        att_val_qc_num = torch.bmm(col_emb.unsqueeze(1),
                                   self.q_num_att(q_enc).transpose(1, 2)).view(
                                       B, -1)
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc_num[idx, num:] = -100
        att_prob_qc_num = self.softmax(att_val_qc_num)
        q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc_num = torch.bmm(col_emb.unsqueeze(1),
                                   self.hs_num_att(hs_enc).transpose(1,
                                                                     2)).view(
                                                                         B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc_num[idx, num:] = -100
        att_prob_hc_num = self.softmax(att_val_hc_num)
        hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1)
        # op_num_score: (B, 2)
        op_num_score = self.op_num_out(
            self.op_num_out_q(q_weighted_num) +
            int(self.use_hs) * self.op_num_out_hs(hs_weighted_num) +
            self.op_num_out_c(col_emb))

        # Compute attention values between selected column and question tokens.
        # q_enc.transpose(1, 2): (B, hid_dim, max_q_len)
        # col_emb.unsqueeze(1): (B, 1, hid_dim)
        # att_val_qc: (B, max_q_len)
        # print("col_emb {} q_enc {}".format(col_emb.unsqueeze(1).size(),self.q_att(q_enc).transpose(1, 2).size()))
        att_val_qc = torch.bmm(col_emb.unsqueeze(1),
                               self.q_att(q_enc).transpose(1, 2)).view(B, -1)
        # assign appended positions values -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val_qc[idx, num:] = -100
        # att_prob_qc: (B, max_q_len)
        att_prob_qc = self.softmax(att_val_qc)
        # q_enc: (B, max_q_len, hid_dim)
        # att_prob_qc.unsqueeze(2): (B, max_q_len, 1)
        # q_weighted: (B, hid_dim)
        q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

        # Same as the above, compute SQL history embedding weighted by column attentions
        att_val_hc = torch.bmm(col_emb.unsqueeze(1),
                               self.hs_att(hs_enc).transpose(1,
                                                             2)).view(B, -1)
        for idx, num in enumerate(hs_len):
            if num < max_hs_len:
                att_val_hc[idx, num:] = -100
        att_prob_hc = self.softmax(att_val_hc)
        hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)

        # Compute prediction scores
        # op_score: (B, 11)
        op_score = self.op_out(
            self.op_out_q(q_weighted) +
            int(self.use_hs) * self.op_out_hs(hs_weighted) +
            self.op_out_c(col_emb))

        score = (op_num_score, op_score)

        return score
Code Example #15
    def forward(self, parent_tables, foreign_keys, q_emb_var, q_len,
                hs_emb_var, hs_len, col_emb_var, col_len, col_name_len,
                table_emb_var, table_len, table_name_len):

        max_q_len = max(q_len)
        max_hs_len = max(hs_len)
        max_col_len = max(col_len)
        max_table_len = max(table_len)
        B = len(q_len)
        if self.use_bert:
            q_enc = self.q_bert(q_emb_var, q_len)
        else:
            q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
        assert list(q_enc.size()) == [B, max_q_len, self.encoded_num]
        hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
        table_tensors, col_tensors, batch_graph = self.schema_encoder(
            parent_tables, foreign_keys, col_emb_var, col_name_len, col_len,
            table_emb_var, table_name_len, table_len)
        aggregated_schema = self.schema_aggregator(batch_graph)
        SIZE_CHECK(table_tensors, [B, max_table_len, self.N_h])
        SIZE_CHECK(col_tensors, [B, max_col_len, self.N_h])

        q_table_weighted_num_num = seq_conditional_weighted_num(
            self.q_table_num_att, q_enc, q_len, table_tensors,
            table_len).sum(1)
        hs_table_weighted_num_num = seq_conditional_weighted_num(
            self.hs_table_num_att, hs_enc, hs_len, table_tensors,
            table_len).sum(1)
        q_col_weighted_num_num = seq_conditional_weighted_num(
            self.q_col_num_att, q_enc, q_len, col_tensors, col_len).sum(1)
        hs_col_weighted_num_num = seq_conditional_weighted_num(
            self.hs_col_num_att, hs_enc, hs_len, col_tensors, col_len).sum(1)

        x = self.schema_out(F.relu(aggregated_schema))
        x = x + self.q_table_out(q_table_weighted_num_num)
        x = x + int(self.use_hs) * self.hs_table_out(hs_table_weighted_num_num)
        x = x + self.q_col_out(q_col_weighted_num_num)
        x = x + int(self.use_hs) * self.hs_col_out(hs_col_weighted_num_num)
        table_num_score = self.table_num_out(x)

        q_table_weighted_num = seq_conditional_weighted_num(
            self.q_table_att, q_enc, q_len, table_tensors, table_len).sum(1)
        hs_table_weighted_num = seq_conditional_weighted_num(
            self.hs_table_att, hs_enc, hs_len, table_tensors, table_len).sum(1)
        q_col_weighted_num = seq_conditional_weighted_num(
            self.q_col_att, q_enc, q_len, col_tensors, col_len).sum(1)
        hs_col_weighted_num = seq_conditional_weighted_num(
            self.hs_col_att, hs_enc, hs_len, col_tensors, col_len).sum(1)

        x = self.schema_out(F.relu(aggregated_schema))
        x = x + self.q_table_out(q_table_weighted_num)
        x = x + int(self.use_hs) * self.hs_table_out(hs_table_weighted_num)
        x = x + self.q_col_out(q_col_weighted_num)
        x = x + int(self.use_hs) * self.hs_col_out(hs_col_weighted_num)

        SIZE_CHECK(x, [B, self.N_h])
        table_score = (self.table_att(table_tensors) * x.unsqueeze(1)).sum(2)
        SIZE_CHECK(table_score, [B, max_table_len])
        for idx, num in enumerate(table_len.tolist()):
            if num < max_table_len:
                table_score[idx, num:] = -100

        return table_num_score, table_score