示例#1
0
    def forward(self, decoder_state, source_hids, src_lengths):
        """Attend over the source states with a tanh-scored projection.

        Each source position is scored as
        ``to_scores(tanh(encoder_proj(source_hid) + decoder_proj(decoder_state)))``;
        the scores are normalized by ``attention_utils.masked_softmax`` and
        used to take a weighted sum of ``source_hids``.

        The expected input dimensions are:

        decoder_state: bsz x decoder_hidden_state_dim
        source_hids: src_len x bsz x context_dim
        src_lengths: bsz

        Returns a tuple of:

        attn_weighted_context: bsz x context_dim
        normalized_masked_attn_scores: src_len x bsz
        """
        src_len, bsz, _ = source_hids.size()
        # (src_len*bsz) x context_dim (to feed through linear).
        # .contiguous() returns the tensor itself when it is already
        # contiguous, and otherwise makes the copy that .view() needs, so a
        # transposed or sliced source_hids no longer raises here.
        flat_source_hids = source_hids.contiguous().view(-1, self.context_dim)
        # (src_len*bsz) x attention_dim
        encoder_component = self.encoder_proj(flat_source_hids)
        # src_len x bsz x attention_dim
        encoder_component = encoder_component.view(src_len, bsz, self.attention_dim)
        # 1 x bsz x attention_dim
        decoder_component = self.decoder_proj(decoder_state).unsqueeze(0)
        # Sum with broadcasting and apply the non linearity.
        # The tensor method is used instead of the deprecated F.tanh alias;
        # the result is the same.
        # (src_len*bsz) x attention_dim
        hidden_att = (
            (decoder_component + encoder_component)
            .view(-1, self.attention_dim)
            .tanh()
        )
        # Project onto the reals to get attentions scores (bsz x src_len)
        attn_scores = self.to_scores(hidden_att).view(src_len, bsz).t()

        # Mask + softmax, then transpose back (src_len x bsz)
        normalized_masked_attn_scores = attention_utils.masked_softmax(
            attn_scores, src_lengths, self.src_length_masking
        ).t()

        # Sum weighted sources over the src_len axis (bsz x context_dim)
        attn_weighted_context = (
            source_hids * normalized_masked_attn_scores.unsqueeze(2)
        ).sum(0)

        return attn_weighted_context, normalized_masked_attn_scores
示例#2
0
    def test_masked_softmax(self):
        """Every row returned by masked_softmax must sum to 1.

        A square matrix of random scores is paired with lengths 1 through
        ``size``, so each row is normalized under a different length.
        """
        size = 20
        raw_scores = torch.rand(size, size)
        row_lengths = torch.arange(start=1, end=size + 1)

        normalized = attention_utils.masked_softmax(
            raw_scores, row_lengths, src_length_masking=True
        )

        for row_idx in range(size):
            row_total = normalized[row_idx].numpy().sum()
            self.assertAlmostEqual(row_total, 1, places=6)
示例#3
0
    def forward(self, decoder_state, source_hids, src_lengths):
        """Dot-product attention of the decoder state over the source states.

        The expected input dimensions are:

        decoder_state: bsz x context_dim (after the optional input_proj)
        source_hids: src_len x bsz x context_dim
        src_lengths: passed through unchanged to masked_softmax

        Returns the attention-weighted context (bsz x context_dim) and the
        normalized attention weights transposed to src_len x bsz.
        """
        # Batch-first view of the sources: bsz x src_len x context_dim
        batch_first_hids = source_hids.transpose(0, 1)

        # Optionally project the decoder state before scoring
        query = decoder_state
        if self.input_proj is not None:
            query = self.input_proj(query)

        # (bsz x src_len x context_dim) * (bsz x context_dim x 1) -> bsz x src_len
        raw_scores = torch.bmm(batch_first_hids, query.unsqueeze(-1)).squeeze(-1)

        # Mask + softmax (bsz x src_len)
        attn_weights = attention_utils.masked_softmax(
            raw_scores, src_lengths, self.src_length_masking
        )

        # Scale every source state by its weight, then sum over src_len
        weighted_hids = batch_first_hids * attn_weights.unsqueeze(-1)
        context = weighted_hids.sum(dim=1)

        return context, attn_weights.t()