def forward(self, decoder_state, source_hids, src_lengths):
    assert self.decoder_hidden_state_dim == self.context_dim
    max_src_len = source_hids.size()[0]
    assert max_src_len == src_lengths.data.max()
    batch_size = source_hids.size()[1]

    src_mask = (
        attention_utils.create_src_lengths_mask(batch_size, src_lengths)
        .type_as(source_hids)
        .t()
        .unsqueeze(2)
    )

    if self.pool_type == "mean":
        # need to make src_lengths a 3-D tensor to normalize masked_hiddens
        denom = src_lengths.view(1, batch_size, 1).type_as(source_hids)
        masked_hiddens = source_hids * src_mask
        context = (masked_hiddens / denom).sum(dim=0)
    elif self.pool_type == "max":
        masked_hiddens = source_hids - 10e6 * (1 - src_mask)
        context = masked_hiddens.max(dim=0)[0]
    else:
        raise ValueError(f"Pooling type {self.pool_type} is not supported.")

    attn_scores = Variable(
        torch.ones(src_mask.shape[1], src_mask.shape[0]).type_as(source_hids.data),
        requires_grad=False,
    ).t()
    return context, attn_scores
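
# A minimal, self-contained sketch of the length-masked mean pooling performed
# above. The concrete tensor shapes and the inline mask construction are
# illustrative assumptions; the module itself builds the mask via attention_utils.
import torch

srclen, bsz, dim = 4, 2, 3
source_hids = torch.randn(srclen, bsz, dim)   # srclen x bsz x context_dim
src_lengths = torch.tensor([4, 2])

# srclen x bsz x 1 mask of valid (non-padded) positions
mask = (torch.arange(srclen).unsqueeze(1) < src_lengths.unsqueeze(0)).unsqueeze(2)
mask = mask.type_as(source_hids)

denom = src_lengths.view(1, bsz, 1).type_as(source_hids)
context = (source_hids * mask / denom).sum(dim=0)  # bsz x dim, masked mean over time
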
def forward(self, decoder_state, source_hids, src_lengths):
    # decoder_state: bsz x context_dim
    if self.input_proj is not None:
        decoder_state = self.input_proj(decoder_state)

    # compute attention scores: bsz x srclen
    attn_scores = (source_hids * decoder_state.unsqueeze(0)).sum(dim=2).t()

    if self.src_length_masking:
        max_src_len = source_hids.size()[0]
        assert max_src_len == src_lengths.data.max()
        batch_size = source_hids.size()[1]

        src_mask = attention_utils.create_src_lengths_mask(
            batch_size, src_lengths
        )
        masked_attn_scores = attn_scores.masked_fill(src_mask == 0, -np.inf)

        # Since the inputs have varying lengths, make sure the attn_scores
        # for each sentence sum to one
        attn_scores = F.softmax(masked_attn_scores, dim=-1)  # bsz x srclen
        score_denom = torch.sum(attn_scores, dim=1).unsqueeze(dim=1).expand(
            batch_size, max_src_len
        )
        normalized_masked_attn_scores = torch.div(attn_scores, score_denom).t()
    else:
        normalized_masked_attn_scores = F.softmax(attn_scores, dim=-1).t()

    # sum weighted sources
    attn_weighted_context = (
        source_hids * normalized_masked_attn_scores.unsqueeze(2)
    ).sum(dim=0)

    return attn_weighted_context, normalized_masked_attn_scores
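
# Standalone sketch of the dot-product scoring, length masking, and context
# computation above, using plain torch tensors (the concrete shapes and the
# inline mask are illustrative assumptions, not the module's API).
import torch
import torch.nn.functional as F

srclen, bsz, dim = 5, 2, 8
source_hids = torch.randn(srclen, bsz, dim)     # srclen x bsz x context_dim
decoder_state = torch.randn(bsz, dim)           # bsz x context_dim
src_lengths = torch.tensor([5, 3])

# bsz x srclen: one dot-product score per source position
scores = (source_hids * decoder_state.unsqueeze(0)).sum(dim=2).t()

# mask out padded positions before the softmax so each row sums to one
mask = torch.arange(srclen).unsqueeze(0) < src_lengths.unsqueeze(1)  # bsz x srclen
attn = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)

context = (source_hids * attn.t().unsqueeze(2)).sum(dim=0)  # bsz x context_dim
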
def forward(self, decoder_state, source_hids, src_lengths, squeeze=True):
    """
    Computes MultiheadAttention with respect to either a vector or a tensor

    Inputs:
        decoder_state: (bsz x decoder_hidden_state_dim) or
            (bsz x T x decoder_hidden_state_dim)
        source_hids: srclen x bsz x context_dim
        src_lengths: bsz x 1, actual sequence lengths
        squeeze: Whether or not to squeeze on the time dimension.
            Even if decoder_state.dim() == 2, an explicit time step
            dimension will be unsqueezed.

    Outputs:
        [batch_size, max_src_len] if decoder_state.dim() == 2 & squeeze
        or
        [batch_size, 1, max_src_len] if decoder_state.dim() == 2 & !squeeze
        or
        [batch_size, T, max_src_len] if decoder_state.dim() == 3 & !squeeze
        or
        [batch_size, T, max_src_len] if decoder_state.dim() == 3 & squeeze & T != 1
        or
        [batch_size, max_src_len] if decoder_state.dim() == 3 & squeeze & T == 1
    """
    batch_size = decoder_state.shape[0]
    if decoder_state.dim() == 3:
        query = decoder_state
    elif decoder_state.dim() == 2:
        query = decoder_state.unsqueeze(1)
    else:
        raise ValueError("decoder state must be either 2 or 3 dimensional")
    query = query.transpose(0, 1)

    value = key = source_hids

    src_len_mask = None
    if src_lengths is not None and self.use_src_length_mask:
        # [batch_size, seq_len]
        src_len_mask_int = attention_utils.create_src_lengths_mask(
            batch_size=batch_size, src_lengths=src_lengths
        )
        src_len_mask = src_len_mask_int != 1

    attn, attn_weights = self._fair_attn.forward(
        query, key, value, key_padding_mask=src_len_mask, need_weights=True
    )
    # attn.shape = T X bsz X embed_dim
    # attn_weights.shape = bsz X T X src_len

    attn_weights = attn_weights.transpose(0, 2)
    # attn_weights.shape = src_len X T X bsz

    if squeeze:
        attn = attn.squeeze(0)
        # attn.shape = squeeze(T) X bsz X embed_dim
        attn_weights = attn_weights.squeeze(1)
        # attn_weights.shape = src_len X squeeze(T) X bsz
        return attn, attn_weights

    return attn, attn_weights
def apply_masks(scores, batch_size, unseen_mask, src_lengths):
    seq_len = scores.shape[-1]

    # [1, seq_len, seq_len]
    sequence_mask = torch.ones(seq_len, seq_len).unsqueeze(0).int()
    if unseen_mask:
        # [1, seq_len, seq_len]: lower-triangular mask hides future positions
        sequence_mask = (
            torch.tril(torch.ones(seq_len, seq_len), diagonal=0).unsqueeze(0).int()
        )

    if src_lengths is not None:
        # [batch_size, 1, seq_len]
        src_lengths_mask = attention_utils.create_src_lengths_mask(
            batch_size=batch_size, src_lengths=src_lengths
        ).unsqueeze(-2)

        # [batch_size, seq_len, seq_len]
        sequence_mask = sequence_mask & src_lengths_mask

    # [batch_size, 1, seq_len, seq_len] so the mask broadcasts over attention heads
    sequence_mask = sequence_mask.unsqueeze(1)

    scores = scores.masked_fill(sequence_mask == 0, -np.inf)
    return scores
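
# Illustrative sketch of how a source-length mask broadcasts over a
# [batch, heads, tgt_len, src_len] score tensor when applied as above
# (all concrete shapes here are assumptions for the example).
import numpy as np
import torch

bsz, heads, tgt_len, src_len = 2, 4, 3, 5
scores = torch.randn(bsz, heads, tgt_len, src_len)
src_lengths = torch.tensor([5, 3])

# [batch_size, 1, src_len]: 1 for real tokens, 0 for padding
seq_mask = (torch.arange(src_len).unsqueeze(0) < src_lengths.unsqueeze(1)).unsqueeze(-2)
# [batch_size, 1, 1, src_len] broadcasts across heads and target positions
seq_mask = seq_mask.unsqueeze(1)

masked_scores = scores.masked_fill(seq_mask == 0, -np.inf)
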
def forward(self, decoder_state, source_hids, src_lengths):
    batch_size = decoder_state.shape[0]
    query = decoder_state.unsqueeze(1).transpose(0, 1)
    value = key = source_hids

    src_len_mask = None
    if src_lengths is not None and self.use_src_length_mask:
        # [batch_size, seq_len]
        src_len_mask_int = attention_utils.create_src_lengths_mask(
            batch_size=batch_size, src_lengths=src_lengths
        )
        src_len_mask = src_len_mask_int != 1

    attn, attn_weights = self._fair_attn.forward(
        query, key, value, key_padding_mask=src_len_mask, need_weights=True
    )
    # attn.shape = tgt_len X bsz X embed_dim
    # attn_weights.shape = bsz X tgt_len X src_len

    # squeeze the singleton time step; attn_weights becomes src_len X bsz
    return attn.squeeze(0), attn_weights.transpose(0, 2).squeeze(1)
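
# Hedged usage sketch for the forward above, using torch.nn.MultiheadAttention as a
# stand-in for self._fair_attn (an assumption; the module wraps a fairseq-style
# attention, but the shape flow is the same). All concrete sizes are illustrative.
import torch
import torch.nn as nn

srclen, bsz, dim, heads = 6, 2, 16, 4
source_hids = torch.randn(srclen, bsz, dim)      # srclen x bsz x context_dim
decoder_state = torch.randn(bsz, dim)            # bsz x decoder_hidden_state_dim
src_lengths = torch.tensor([6, 4])

query = decoder_state.unsqueeze(1).transpose(0, 1)   # 1 x bsz x dim
key_padding_mask = ~(torch.arange(srclen).unsqueeze(0) < src_lengths.unsqueeze(1))

mha = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
attn, attn_weights = mha(
    query, source_hids, source_hids,
    key_padding_mask=key_padding_mask, need_weights=True,
)
# attn: 1 x bsz x dim;  attn_weights: bsz x 1 x srclen
context = attn.squeeze(0)                             # bsz x dim
scores = attn_weights.transpose(0, 2).squeeze(1)      # srclen x bsz
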