Example #1
    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, mask=True):

        # [bs, max_q_len, d_h]
        q_enc = encode_question(self.q_encoder, q_emb, q_lens)
        # [bs, max_h_num, d_h]
        h_pooling = encode_header(self.h_encoder,
                                  h_emb, h_lens, h_nums, pooling_type=self.pooling_type)

        # [bs, max_h_num, max_q_len]
        # torch.bmm: bs * ([max_h_num, d_h], [d_h, max_q_len])
        att_weights = torch.bmm(h_pooling, self.W_att(q_enc).transpose(1, 2))
        # mask padded question tokens before the softmax over max_q_len
        att_mask = build_mask(att_weights, q_lens, dim=-1)
        att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))
        att_weights = self.softmax(att_weights)

        # att_weights: -> [bs, max_h_num, max_q_len, 1]
        # q_enc: -> [bs, 1, max_q_len, d_h]
        # [bs, max_h_num, d_h]
        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)
        # [bs, max_h_num, d_h * 2]
        comb_context = torch.cat([self.W_question(q_context), self.W_header(h_pooling)], dim=-1)

        # [bs, max_h_num]
        score_ord_col = self.W_out(comb_context).squeeze(2)

        if mask:
            for b, h_num in enumerate(h_nums):
                score_ord_col[b, h_num:] = -float('inf')
        return score_ord_col
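The helper `build_mask` used throughout these examples is not shown here. As a minimal sketch of how it can be read (an assumption for illustration, not the project's actual implementation), it builds a 0/1 mask over `tensor` in which positions along `dim` that fall beyond the per-example length are 0:

import torch

def build_mask(tensor, lengths, dim=-1):
    # Hypothetical sketch: mark positions whose index along `dim` is at or
    # beyond the corresponding per-example length with 0, everything else
    # with 1, so that `masked_fill(mask == 0, -inf)` removes the padded
    # positions from the subsequent softmax.
    mask = torch.ones_like(tensor)
    dim = dim % tensor.dim()
    for b, length in enumerate(lengths):
        index = [b] + [slice(None)] * (tensor.dim() - 1)
        index[dim] = slice(length, None)   # padded positions of example b
        mask[tuple(index)] = 0
    return mask

With this reading, `build_mask(att_weights, q_lens, dim=-1)` zeroes the attention logits of padded question tokens before the softmax; the real helper may resolve `dim` differently.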
Example #2
    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums):
        # [bs, max_q_len, d_h]
        q_enc = encode_question(self.q_encoder, q_emb, q_lens)

        # [bs, max_h_num, d_h]
        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)

        # [bs, max_h_num, max_q_len]
        att_weights = torch.bmm(h_pooling, self.W_att(q_enc).transpose(1, 2))
        att_mask = build_mask(att_weights, q_lens, dim=-1)
        att_weights = self.softmax(att_weights.masked_fill(att_mask==0, -float("inf")))

        # att_weights: -> [bs, max_h_num, max_q_len, 1]
        # q_enc: -> [bs, 1, max_q_len, d_h]
        # [bs, max_h_num, d_h]
        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)

        # [bs, max_h_num, d_h * 2]
        comb_context = torch.cat([self.W_q(q_context), self.W_h(h_pooling)], dim=-1)

        # [bs, max_h_num]
        score_sel_col = self.W_out(comb_context).squeeze(2)

        # mask scores of padded header columns
        for b, h_num in enumerate(h_nums):
            score_sel_col[b, h_num:] = -float("inf")
        return score_sel_col
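The encoders `encode_question` and `encode_header` are likewise not shown. A minimal sketch of what they might do, assuming a batch_first bidirectional LSTM encoder, header embeddings flattened over the batch, and pooling types "last"/"avg" (all assumptions; the real project's implementations may differ):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def encode_question(encoder, q_emb, q_lens, init_states=None):
    # Hypothetical sketch: run the LSTM `encoder` over the padded question
    # embeddings and return per-token encodings of shape [bs, max_q_len, d_h].
    packed = pack_padded_sequence(q_emb, list(q_lens),
                                  batch_first=True, enforce_sorted=False)
    if init_states is not None:
        output, _ = encoder(packed, init_states)
    else:
        output, _ = encoder(packed)
    q_enc, _ = pad_packed_sequence(output, batch_first=True)
    return q_enc

def encode_header(encoder, h_emb, h_lens, h_nums, pooling_type="last"):
    # Hypothetical sketch: encode every header token sequence (all headers of
    # the batch stacked along dim 0), pool each header into one vector, then
    # regroup per table into [bs, max_h_num, d_h] with zero padding.
    h_enc = encode_question(encoder, h_emb, h_lens)   # [n_headers, max_h_len, d_h]
    if pooling_type == "avg":
        lens = torch.tensor(h_lens, dtype=torch.float, device=h_enc.device)
        pooled = h_enc.sum(dim=1) / lens.unsqueeze(1)
    else:  # "last": encoding at each header's final real token
        last = torch.tensor(h_lens, device=h_enc.device) - 1
        pooled = h_enc[torch.arange(h_enc.size(0)), last]
    h_pooling = pooled.new_zeros(len(h_nums), max(h_nums), pooled.size(-1))
    start = 0
    for b, n in enumerate(h_nums):
        h_pooling[b, :n] = pooled[start:start + n]
        start += n
    return h_pooling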
Example #3
    def forward(self, q_emb, q_lens):
        # [bs, max_q_len, d_h]
        q_enc = encode_question(self.q_encoder, q_emb, q_lens)

        # [bs, max_q_len, 1]
        att_weights = self.W_att(q_enc)
        att_mask = build_mask(att_weights, q_lens, dim=-2)
        att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))
        # [bs, max_q_len, 1]
        att_weights = self.softmax(att_weights)

        # [bs, d_h]
        q_context = torch.bmm(q_enc.transpose(1, 2), att_weights).squeeze(2)
        score_limit = self.W_out(q_context)
        return score_limit
Example #4
    def forward(self, q_emb, q_lens):

        q_enc = encode_question(self.q_encoder, q_emb, q_lens)

        # self-attention over the question
        # [bs, max_q_len]
        att_weights_q = self.W_att(q_enc).squeeze(2)
        att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)
        att_weights_q = att_weights_q.masked_fill(att_mask_q == 0, -float('inf'))
        att_weights_q = self.softmax(att_weights_q)

        #  [bs, d_h]
        q_context = torch.mul(
            q_enc,
            att_weights_q.unsqueeze(2).expand_as(q_enc)
        ).sum(dim=1)
        where_op_logits = self.W_out(q_context)
        return where_op_logits
Example #5
    def get_context(self, q_emb, q_lens, h_emb, h_lens, h_nums,
                    q_encoder, h_encoder, W_att, W_q, W_h):
        # [bs, max_q_len, d_h]
        q_enc = encode_question(q_encoder, q_emb, q_lens)

        # [bs, max_h_num, d_h]
        h_pooling = encode_header(h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)

        # [bs, max_h_num, max_q_len]
        att_weights = torch.bmm(h_pooling, W_att(q_enc).transpose(1, 2))
        att_mask = build_mask(att_weights, q_lens, dim=-1)
        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))

        # att_weights: -> [bs, max_h_num, max_q_len, 1]
        # q_enc: -> [bs, 1, max_q_len, d_h]
        # [bs, max_h_num, d_h]
        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)

        return W_q(q_context), W_h(h_pooling)
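`get_context` only returns the two projected tensors; the calling module is not shown. Assuming it is used the same way Examples #1 and #2 combine their projections (an assumption, with `self.W_out` and the masking step supplied by the caller), a forward pass built on top of it could look like:

    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums):
        # Hypothetical caller of get_context, mirroring Examples #1 and #2.
        q_part, h_part = self.get_context(q_emb, q_lens, h_emb, h_lens, h_nums,
                                          self.q_encoder, self.h_encoder,
                                          self.W_att, self.W_q, self.W_h)
        # [bs, max_h_num, d_h * 2] -> [bs, max_h_num]
        scores = self.W_out(torch.cat([q_part, h_part], dim=-1)).squeeze(2)
        # mask scores of padded header columns
        for b, h_num in enumerate(h_nums):
            scores[b, h_num:] = -float("inf")
        return scores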
Example #6
    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, sel_col):
        # [bs, max_q_len, d_h]
        q_enc = encode_question(self.q_encoder, q_emb, q_lens)

        # [bs, max_h_num, d_h]
        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)

        bs = len(q_emb)
        # header vector of the selected column for each example: [bs, d_h]
        h_pooling_sel = h_pooling[list(range(bs)), sel_col]

        # [bs, max_q_len]
        att_weights = torch.bmm(self.W_att(q_enc), h_pooling_sel.unsqueeze(2)).squeeze(2)
        att_mask = build_mask(att_weights, q_lens, dim=-2)
        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))

        # att_weights: [bs, max_q_len] -> [bs, max_q_len, 1], expanded to [bs, max_q_len, d_h]
        # q_context: [bs, d_h]
        q_context = torch.mul(q_enc, att_weights.unsqueeze(2).expand_as(q_enc)).sum(dim=1)

        # [bs, n_agg]
        score_sel_agg = self.W_out(q_context)
        return score_sel_agg
Example #7
    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums):

        # [bs, max_h_num, d_h]
        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)

        bs = len(q_lens)

        # self-attention over the headers
        # [bs, max_h_num]
        att_weights_h = self.W_att_h(h_pooling).squeeze(2)
        # mask padded header positions before the softmax
        att_mask_h = build_mask(att_weights_h, h_nums, dim=-2)
        att_weights_h = self.softmax(att_weights_h.masked_fill(att_mask_h == 0, -float("inf")))

        # [bs, d_h]
        h_context = torch.mul(h_pooling, att_weights_h.unsqueeze(2)).sum(1)

        # [bs, d_h] -> [bs, n_layers * d_h], then reshaped to
        # [n_layers * 2, bs, d_h / 2]: one initial state per layer and
        # direction of the bidirectional question encoder.
        hidden = self.W_hidden(h_context)
        hidden = hidden.view(bs, self.n_layers * 2, self.d_h // 2)
        hidden = hidden.transpose(0, 1).contiguous()

        cell = self.W_cell(h_context)
        cell = cell.view(bs, self.n_layers * 2, self.d_h // 2)
        cell = cell.transpose(0, 1).contiguous()

        # [bs, max_q_len, d_h]
        q_enc = encode_question(self.q_encoder, q_emb, q_lens, init_states=(hidden, cell))

        # self-attention for question
        # [bs, max_q_len]
        att_weights_q = self.W_att_q(q_enc).squeeze(2)
        att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)
        att_weights_q = self.softmax(att_weights_q.masked_fill(att_mask_q == 0, -float("inf")))

        q_context = torch.mul(q_enc, att_weights_q.unsqueeze(2).expand_as(q_enc)).sum(dim=1)

        # [bs, max_select_num + 1]
        score_sel_num = self.W_out(q_context)
        return score_sel_num
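The view/transpose above produces states of shape [n_layers * 2, bs, d_h / 2], which is exactly what a bidirectional nn.LSTM with hidden_size d_h / 2 expects as (h_0, c_0). A quick standalone check (the concrete sizes here are made up for illustration):

import torch
import torch.nn as nn

d_h, n_layers, bs, max_q_len, d_emb = 100, 2, 4, 7, 300
lstm = nn.LSTM(input_size=d_emb, hidden_size=d_h // 2, num_layers=n_layers,
               batch_first=True, bidirectional=True)
h0 = torch.zeros(n_layers * 2, bs, d_h // 2)   # [num_layers * num_directions, bs, hidden]
c0 = torch.zeros(n_layers * 2, bs, d_h // 2)
out, _ = lstm(torch.randn(bs, max_q_len, d_emb), (h0, c0))
print(out.shape)   # torch.Size([4, 7, 100]) == [bs, max_q_len, d_h]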
Example #8
    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, where_cols):
        # [bs, max_q_len, d_h]
        q_enc = encode_question(self.q_encoder, q_emb, q_lens)

        # [bs, max_h_num, d_h]
        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)

        # zero vector used to pad the selected headers up to max_where_num: [1, d_h]
        padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)

        h_pooling_where = []
        for b, cols in enumerate(where_cols):
            if len(cols) > 0:
                h_tmp = [h_pooling[b][cols, :]]
            else:
                h_tmp = []
            h_tmp += [padding_t] * (self.max_where_num - len(cols))
            h_tmp = torch.cat(h_tmp, dim=0)
            h_pooling_where.append(h_tmp)
        # [bs, max_where_num, d_h]
        h_pooling_where = torch.stack(h_pooling_where)

        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]
        # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]
        # [bs, max_where_num, max_q_len]
        att_weights = torch.matmul(
            self.W_att(q_enc).unsqueeze(1),
            h_pooling_where.unsqueeze(3)
        ).squeeze(3)
        att_mask = build_mask(att_weights, q_lens, dim=-1)
        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))

        # att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]
        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]
        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)

        # [bs, max_where_num, n_agg]
        score_where_agg = self.W_out(torch.cat([self.W_q(q_context), self.W_h(h_pooling_where)], dim=2))
        return score_where_agg
Example #9
    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums,
                q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch, where_cols, where_ops,
                q_feature):
        bs = len(q_emb)
        max_q_len = max(q_lens)

        # [bs, max_q_len, d_h]
        q_enc = encode_question(self.q_encoder, q_emb, q_lens)
        # pad the per-token feature lists to max_q_len before tensorizing
        for f in q_feature:
            while len(f) < max_q_len:
                f.append(0)
        q_feature = torch.tensor(q_feature)
        if q_enc.is_cuda:
            q_feature = q_feature.to(q_enc.device)

        q_feature_enc = self.q_feature_embed(q_feature)

        q_enc = torch.cat([q_enc, q_feature_enc], -1)

        # [bs, max_h_num, d_h]
        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)

        padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)
        h_pooling_where = []
        for b, cols in enumerate(where_cols):
            if len(cols) > 0:
                h_tmp = [h_pooling[b][cols, :]]
            else:
                h_tmp = []
            h_tmp += [padding_t] * (self.max_where_num - len(cols))
            h_tmp = torch.cat(h_tmp, dim=0)
            h_pooling_where.append(h_tmp)
        # [bs, max_where_num, d_h]
        h_pooling_where = torch.stack(h_pooling_where)

        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]
        # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]
        # [bs, max_where_num, max_q_len]
        att_weights = torch.matmul(
            self.W_att(q_enc).unsqueeze(1),
            h_pooling_where.unsqueeze(3)
        ).squeeze(3)
        att_mask = build_mask(att_weights, q_lens, dim=-1)
        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))

        # att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]
        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]
        # [bs, max_where_num, d_h]
        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)

        q_enc_ch = encode_question(self.q_encoder_ch, q_emb_ch, q_lens_ch)

        # [bs, max_h_num, d_h]
        h_pooling_ch = encode_header(self.h_encoder_ch, h_emb_ch, h_lens_ch,
                                     h_nums, pooling_type=self.pooling_type)

        padding_t_ch = torch.zeros_like(h_pooling_ch[0][0]).unsqueeze(0)

        h_pooling_where_ch = []
        for b, cols in enumerate(where_cols):
            if len(cols) > 0:
                h_tmp = [h_pooling_ch[b][cols, :]]
            else:
                h_tmp = []
            h_tmp += [padding_t_ch] * (self.max_where_num - len(cols))
            h_tmp = torch.cat(h_tmp, dim=0)
            h_pooling_where_ch.append(h_tmp)
        h_pooling_where_ch = torch.stack(h_pooling_where_ch)

        att_weights_ch = torch.matmul(
            self.W_att_ch(q_enc_ch).unsqueeze(1),
            h_pooling_where_ch.unsqueeze(3)
        ).squeeze(3)
        att_mask_ch = build_mask(att_weights_ch, q_lens_ch, dim=-1)
        att_weights_ch = self.softmax(att_weights_ch.masked_fill(att_mask_ch == 0, -float("inf")))
        q_context_ch = torch.mul(att_weights_ch.unsqueeze(3), q_enc_ch.unsqueeze(1)).sum(dim=2)

        # one-hot encode each example's where operators (padded with operator 0
        # up to max_where_num): [bs, max_where_num, n_op]
        op_enc = []
        for b in range(bs):
            op_enc_tmp = torch.zeros(self.max_where_num, self.n_op)
            op = where_ops[b]
            idx_scatter = []
            op_len = len(op)
            for i in range(self.max_where_num):
                if i < op_len:
                    idx_scatter.append([op[i]])
                else:
                    idx_scatter.append([0])
            op_enc_tmp = op_enc_tmp.scatter(1, torch.tensor(idx_scatter), 1)
            op_enc.append(op_enc_tmp)
        op_enc = torch.stack(op_enc)
        if q_context.is_cuda:
            op_enc = op_enc.to(q_context.device)

        comb_context = torch.cat(
            [self.W_q(q_context),
             self.W_h(h_pooling_where),
             self.W_q_ch(q_context_ch),
             self.W_h_ch(h_pooling_where_ch),
             self.W_op(op_enc)],
            dim=2
        )
        comb_context = comb_context.unsqueeze(2).expand(-1, -1, q_enc.size(1), -1)
        q_enc = q_enc.unsqueeze(1).expand(-1, comb_context.size(1), -1, -1)

        # [bs, max_where_num, max_q_len, 2]
        score_where_val = self.W_out(torch.cat([comb_context, q_enc], dim=3))

        for b, l in enumerate(q_lens):
            if l < max_q_len:
                score_where_val[b, :, l:, :] = -float("inf")
        return score_where_val
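The operator one-hot loop above can also be written with torch.nn.functional.one_hot; a small equivalent sketch (the helper name `encode_ops` is illustrative, and it pads missing conditions with operator 0 just as the loop does):

import torch
import torch.nn.functional as F

def encode_ops(where_ops, max_where_num, n_op):
    # pad each example's operator list with operator 0 up to max_where_num,
    # then one-hot encode: [bs, max_where_num, n_op]
    padded = [ops + [0] * (max_where_num - len(ops)) for ops in where_ops]
    return F.one_hot(torch.tensor(padded), num_classes=n_op).float()

# encode_ops([[2, 0], [1]], max_where_num=4, n_op=4).shape -> [2, 4, 4]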