def forward(self, decoder_state, source_hids, src_lengths):
    """The expected input dimensions are:

    decoder_state: bsz x decoder_hidden_state_dim
    source_hids: src_len x bsz x context_dim
    src_lengths: bsz
    """
    src_len, bsz, _ = source_hids.size()
    # (src_len*bsz) x context_dim (to feed through linear)
    flat_source_hids = source_hids.view(-1, self.context_dim)
    # (src_len*bsz) x attention_dim
    encoder_component = self.encoder_proj(flat_source_hids)
    # src_len x bsz x attention_dim
    encoder_component = encoder_component.view(src_len, bsz, self.attention_dim)
    # 1 x bsz x attention_dim
    decoder_component = self.decoder_proj(decoder_state).unsqueeze(0)
    # Sum with broadcasting and apply the non-linearity
    # src_len x bsz x attention_dim
    hidden_att = torch.tanh(
        (decoder_component + encoder_component).view(-1, self.attention_dim)
    )
    # Project onto the reals to get attention scores (bsz x src_len)
    attn_scores = self.to_scores(hidden_att).view(src_len, bsz).t()
    # Mask + softmax (src_len x bsz)
    normalized_masked_attn_scores = attention_utils.masked_softmax(
        attn_scores, src_lengths, self.src_length_masking
    ).t()
    # Sum weighted sources (bsz x context_dim)
    attn_weighted_context = (
        source_hids * normalized_masked_attn_scores.unsqueeze(2)
    ).sum(0)
    return attn_weighted_context, normalized_masked_attn_scores
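To trace the shapes end to end, the following self-contained sketch reproduces the same MLP (Bahdanau-style) scoring outside the class. Length masking is omitted for brevity, and the layer names, tensor names, and dimensions here are illustrative assumptions, not part of the code above.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative dimensions only
src_len, bsz, context_dim, decoder_dim, attention_dim = 7, 3, 16, 16, 8
encoder_proj = nn.Linear(context_dim, attention_dim, bias=False)
decoder_proj = nn.Linear(decoder_dim, attention_dim, bias=False)
to_scores = nn.Linear(attention_dim, 1, bias=False)

source_hids = torch.rand(src_len, bsz, context_dim)  # src_len x bsz x context_dim
decoder_state = torch.rand(bsz, decoder_dim)         # bsz x decoder_dim

encoder_component = encoder_proj(source_hids)                  # src_len x bsz x attention_dim
decoder_component = decoder_proj(decoder_state).unsqueeze(0)   # 1 x bsz x attention_dim
hidden_att = torch.tanh(encoder_component + decoder_component) # broadcast sum + non-linearity
attn_scores = to_scores(hidden_att).squeeze(2)                 # src_len x bsz
attn_weights = F.softmax(attn_scores, dim=0)                   # softmax over src_len (no masking here)
context = (source_hids * attn_weights.unsqueeze(2)).sum(0)     # bsz x context_dim
print(context.shape)  # torch.Size([3, 16])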
def test_masked_softmax(self):
    scores = torch.rand(20, 20)
    lengths = torch.arange(start=1, end=21)
    masked_normalized_scores = attention_utils.masked_softmax(
        scores, lengths, src_length_masking=True
    )
    for i in range(20):
        scores_sum = masked_normalized_scores[i].numpy().sum()
        self.assertAlmostEqual(scores_sum, 1, places=6)
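attention_utils.masked_softmax itself is not shown here. As one illustration of the behavior the test asserts (positions past a sequence's length get zero weight, yet each row still sums to 1), here is a hypothetical helper written for this section; it is not the library's actual implementation.

import torch
import torch.nn.functional as F

def masked_softmax_sketch(scores, src_lengths):
    """Hypothetical masked softmax consistent with the test above.
    scores: bsz x src_len, src_lengths: bsz
    """
    bsz, src_len = scores.size()
    # mask[i, j] is True for valid positions j < src_lengths[i]
    mask = torch.arange(src_len).unsqueeze(0) < src_lengths.unsqueeze(1)
    # Invalid positions get -inf so softmax assigns them zero probability
    scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1)

weights = masked_softmax_sketch(torch.rand(20, 20), torch.arange(start=1, end=21))
print(weights.sum(dim=1))  # every row sums to 1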
def forward(self, decoder_state, source_hids, src_lengths):
    # Reshape to bsz x src_len x context_dim
    source_hids = source_hids.transpose(0, 1)
    # decoder_state: bsz x context_dim
    if self.input_proj is not None:
        decoder_state = self.input_proj(decoder_state)
    # Compute attention scores: (bsz x src_len x context_dim) bmm (bsz x context_dim x 1)
    attn_scores = torch.bmm(source_hids, decoder_state.unsqueeze(2)).squeeze(2)
    # Mask + softmax (bsz x src_len)
    normalized_masked_attn_scores = attention_utils.masked_softmax(
        attn_scores, src_lengths, self.src_length_masking
    )
    # Sum weighted sources (bsz x context_dim)
    attn_weighted_context = (
        source_hids * normalized_masked_attn_scores.unsqueeze(2)
    ).sum(1)
    return attn_weighted_context, normalized_masked_attn_scores.t()
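For comparison, a minimal standalone sketch of the dot-product variant above, again with masking omitted and with illustrative names and dimensions that are assumptions rather than part of the original code.

import torch
import torch.nn.functional as F

src_len, bsz, context_dim = 7, 3, 16
source_hids = torch.rand(src_len, bsz, context_dim).transpose(0, 1)  # bsz x src_len x context_dim
decoder_state = torch.rand(bsz, context_dim)                         # bsz x context_dim

# Dot product of each source position with the decoder state
attn_scores = torch.bmm(source_hids, decoder_state.unsqueeze(2)).squeeze(2)  # bsz x src_len
attn_weights = F.softmax(attn_scores, dim=1)                                 # softmax over src_len (no masking here)
context = (source_hids * attn_weights.unsqueeze(2)).sum(1)                   # bsz x context_dim
print(context.shape)  # torch.Size([3, 16])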