Example no. 1
    def forward(self, x, label=None, x_mask=None):
        if label is None:
            # no label given: score every position against every label row at once
            x_alpha = F.linear(x, self.weight, self.bias).transpose(
                1, 2)  # (batch, label_size, len)
            x_mask = x_mask.unsqueeze(1).expand(-1, x_alpha.size(1), -1)
            x_prop = masked_softmax(x_alpha, x_mask,
                                    dim=-1)  # (batch, label_size, len)
            x_att_rep = torch.bmm(x_prop,
                                  x)  # (batch, label_size, hidden_size)
        else:
            # a label is provided: combine the weight rows (and bias entries)
            # according to each sample's label vector before scoring
            cur_weight = torch.mm(label.float(), self.weight).unsqueeze(
                -1)  # (batch, in_features, 1)
            cur_bias = None
            if self.bias is not None:
                cur_bias = torch.mm(label.float(),
                                    self.bias.unsqueeze(-1))  # (batch, 1)

            x_alpha = torch.bmm(x, cur_weight).squeeze(-1)  # (batch, len)
            if cur_bias is not None:
                x_alpha += cur_bias

            x_prop = masked_softmax(x_alpha, x_mask, dim=-1)
            x_att_rep = torch.bmm(x_prop.unsqueeze(1), x) \
                .squeeze(1)  # (batch, hidden_size)

        return x_att_rep, x_prop
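
Every snippet in this listing calls a `masked_softmax` helper that is never shown. Below is a minimal sketch of such a helper, assuming a mask with 1 for real tokens and 0 for padding; the body and the 1e-13 epsilon are illustrative, not the listing's actual implementation.

import torch


def masked_softmax(x, m=None, dim=-1):
    """Softmax over `dim` that gives zero weight wherever the mask is 0.

    Sketch only: `m` is assumed to be a 0/1 (or bool) tensor broadcastable to
    `x`, with 1 marking valid positions; the epsilon keeps fully masked rows
    from dividing by zero.
    """
    e_x = torch.exp(x - x.max(dim=dim, keepdim=True)[0])  # numerically stable exponentials
    if m is not None:
        e_x = e_x * m.float()                             # drop padded positions
    return e_x / (e_x.sum(dim=dim, keepdim=True) + 1e-13)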
Example no. 2
    def forward(self, x, x_mask):
        h = F.relu(self.linear_h(x))  # x: (seq_len, batch, input_size)
        h = self.dropout(h)
        o = self.linear_o(h)
        o = o.squeeze(2).transpose(0, 1)  # (batch, seq_len)

        beta = masked_softmax(o, x_mask, dim=1)
        return beta
Example no. 3
    def forward(self, x, y, x_mask):
        x = self.dropout(x)
        y = self.dropout(y)

        Wy = self.linear(y)  # project y so it can be scored against each position of x
        xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)  # (batch, x_len)
        alpha = masked_softmax(xWy, x_mask, dim=1)

        return alpha
Example no. 4
    def forward(self, x, x_mask=None):
        x_alpha = self.alpha_linear(x) \
            .squeeze(-1)  # (batch, len)

        x_prop = masked_softmax(x_alpha, x_mask, dim=-1)
        x_att_rep = torch.bmm(x_prop.unsqueeze(1), x) \
            .squeeze(1)  # (batch, hidden_size)

        return x_att_rep, x_prop
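
None of these snippets include the module's `__init__`. For the simplest pattern, Example no. 4, a self-contained reconstruction might look like the following; the class name `LinearSelfAttention`, the `hidden_size` argument, and the usage shapes are assumptions, and it reuses the `masked_softmax` sketch above.

import torch
import torch.nn as nn


class LinearSelfAttention(nn.Module):
    """Hypothetical container for the Example no. 4 forward pass."""

    def __init__(self, hidden_size):
        super().__init__()
        # one scalar score per position; squeezed away in forward()
        self.alpha_linear = nn.Linear(hidden_size, 1)

    def forward(self, x, x_mask=None):
        x_alpha = self.alpha_linear(x).squeeze(-1)  # (batch, len)
        x_prop = masked_softmax(x_alpha, x_mask, dim=-1)  # (batch, len)
        x_att_rep = torch.bmm(x_prop.unsqueeze(1), x).squeeze(1)  # (batch, hidden_size)
        return x_att_rep, x_prop


# Usage sketch: pool padded encoder states (batch of 4, 20 steps, 128 dims).
att = LinearSelfAttention(hidden_size=128)
x = torch.randn(4, 20, 128)    # (batch, len, hidden_size)
x_mask = torch.ones(4, 20)     # 1 = real token, 0 = padding
rep, weights = att(x, x_mask)  # rep: (4, 128), weights: (4, 20)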
Example no. 5
    def forward(self, Hpi, Hq, Hr_last, Hq_mask):
        wq_hq = self.linear_wq(Hq)  # (question_len, batch, hidden_size)
        wp_hp = self.linear_wp(Hpi).unsqueeze(0)  # (1, batch, hidden_size)
        wr_hr = self.linear_wr(Hr_last).unsqueeze(0)  # (1, batch, hidden_size)
        G = torch.tanh(wq_hq + wp_hp + wr_hr)  # (question_len, batch, hidden_size), broadcasts over dim 0
        wg_g = self.linear_wg(G) \
            .squeeze(2) \
            .transpose(0, 1)  # (batch, question_len)
        alpha = masked_softmax(wg_g, m=Hq_mask, dim=1)  # (batch, question_len)
        return alpha
Example no. 6
    def forward(self, x, x_mask):
        g_tanh = torch.tanh(self.linear_g(x))
        gt = self.linear_t(g_tanh) \
            .squeeze(2) \
            .transpose(0, 1)  # (batch, seq_len)

        gt_prop = masked_softmax(gt, x_mask, dim=1)
        gt_prop = gt_prop.transpose(0, 1).unsqueeze(2)  # (seq_len, batch, 1)
        x_gt = x * gt_prop

        return x_gt
Example no. 7
    def forward(self, Hr, Hr_mask, Hk_pre):
        wr_hr = self.linear_wr(Hr)  # (context_len, batch, hidden_size)
        wa_ha = self.linear_wa(Hk_pre).unsqueeze(0)  # (1, batch, hidden_size)
        f = torch.tanh(wr_hr + wa_ha)  # (context_len, batch, hidden_size)

        beta_tmp = self.linear_wf(f) \
            .squeeze(2) \
            .transpose(0, 1)  # (batch, context_len)

        beta = masked_softmax(beta_tmp, m=Hr_mask, dim=1)
        return beta
Example no. 8
    def forward(self, uq, mask):
        q_tanh = torch.tanh(self.linear_u(uq))
        q_s = self.linear_t(q_tanh) \
            .squeeze(2) \
            .transpose(0, 1)  # (batch, seq_len)

        alpha = masked_softmax(q_s, mask, dim=1)  # (batch, seq_len)
        rq = torch.bmm(alpha.unsqueeze(1), uq.transpose(0, 1)) \
            .squeeze(1)  # (batch, input_size)

        return rq
Example no. 9
    def forward(self, h1, h2, h2_mask):
        # h1, h2 arrive sequence-first; switch to (batch, seq_len, hidden_size)
        h1 = h1.transpose(0, 1)
        h2 = h2.transpose(0, 1)

        alpha = h1.bmm(h2.transpose(1, 2))  # (batch, seq1_len, seq2_len)
        alpha = masked_softmax(alpha, h2_mask.unsqueeze(1),
                               dim=2)  # (batch, seq1_len, seq2_len)

        alpha_seq2 = alpha.bmm(h2)  # (batch, seq1_len, hidden_size)
        alpha_seq2 = alpha_seq2.transpose(0, 1)

        return alpha_seq2, alpha
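
Example no. 9 is the only snippet that aligns two different sequences. A parameter-free sketch of the same data flow, with illustrative sizes and the `masked_softmax` helper from the first sketch, makes the sequence-first shapes explicit.

import torch


def align_sketch(h1, h2, h2_mask):
    # h1: (seq1_len, batch, hidden), h2: (seq2_len, batch, hidden), sequence-first
    h1b, h2b = h1.transpose(0, 1), h2.transpose(0, 1)  # batch-first views
    scores = h1b.bmm(h2b.transpose(1, 2))              # (batch, seq1_len, seq2_len)
    alpha = masked_softmax(scores, h2_mask.unsqueeze(1), dim=2)
    aligned = alpha.bmm(h2b).transpose(0, 1)           # back to (seq1_len, batch, hidden)
    return aligned, alpha


h1 = torch.randn(15, 4, 64)      # (seq1_len, batch, hidden_size)
h2 = torch.randn(10, 4, 64)      # (seq2_len, batch, hidden_size)
h2_mask = torch.ones(4, 10)      # 1 = real token, 0 = padding
aligned, alpha = align_sketch(h1, h2, h2_mask)  # aligned: (15, 4, 64), alpha: (4, 15, 10)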
Example no. 10
    def forward(self, h, h_mask):
        h = h.transpose(0, 1)
        batch, seq_len, _ = h.shape

        alpha = h.bmm(h.transpose(1, 2))  # (batch, seq_len, seq_len)

        # zero out the diagonal so a position does not score itself
        mask = torch.eye(seq_len, dtype=torch.bool, device=h.device)
        mask = mask.unsqueeze(0)  # broadcast over the batch dimension
        alpha.masked_fill_(mask, 0.)

        alpha = masked_softmax(alpha, h_mask.unsqueeze(1), dim=2)
        alpha_seq = alpha.bmm(h)

        alpha_seq = alpha_seq.transpose(0, 1)
        return alpha_seq, alpha
Example no. 11
    def forward(self, x, x_mask=None, q=None):
        if q is not None:
            # (batch, 1, hidden_size)
            q_rep = self.query_linear(q).unsqueeze(1)
            x_tanh = torch.tanh(self.self_linear(x) + q_rep)
        else:
            x_tanh = torch.tanh(self.self_linear(x))

        x_tanh = self.dropout_layer(x_tanh)

        x_alpha = self.alpha_linear(x_tanh) \
            .squeeze(-1)  # (batch, len)

        x_prop = masked_softmax(x_alpha, x_mask, dim=-1)
        x_att_rep = torch.bmm(x_prop.unsqueeze(1), x) \
            .squeeze(1)  # (batch, hidden_size)

        return x_att_rep, x_prop