from typing import Optional

import torch
from torch import Tensor


def expected_alignment_from_p_choose(
    p_choose: Tensor,
    padding_mask: Optional[Tensor] = None,
    eps: float = 1e-6,
):
    """
    Calculating expected alignment from stepwise probability

    Reference:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments
    https://arxiv.org/pdf/1704.00784.pdf

    q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
    a_{i,j} = p_{i,j} * q_{i,j}

    Parallel solution:
    a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz, tgt_len, src_len
    """
    prob_check(p_choose)

    # p_choose: bsz, tgt_len, src_len
    bsz, tgt_len, src_len = p_choose.size()
    dtype = p_choose.dtype

    p_choose = p_choose.float()

    if padding_mask is not None:
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # cumprod_1mp: bsz, tgt_len, src_len
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)

    alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
    alpha_0[:, :, 0] = 1.0

    previous_alpha = [alpha_0]

    for i in range(tgt_len):
        # p_choose: bsz, tgt_len, src_len
        # cumprod_1mp_clamp: bsz, tgt_len, src_len
        # previous_alpha[i]: bsz, 1, src_len
        # alpha_i: bsz, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i],
                dim=1,
            )
        ).clamp(0, 1.0)

        previous_alpha.append(alpha_i.unsqueeze(1))

    # alpha: bsz, tgt_len, src_len
    alpha = torch.cat(previous_alpha[1:], dim=1)

    # Computed in fp32; cast back to the input dtype to prevent fp16 overflow
    alpha = alpha.type(dtype)

    prob_check(alpha)

    return alpha
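# The function above relies on two helpers, exclusive_cumprod and prob_check,
# that are not shown in this excerpt.  The sketches below are hypothetical
# stand-ins describing the behaviour the code assumes, not necessarily the
# implementations used in the surrounding codebase.


def exclusive_cumprod(x: Tensor, dim: int, eps: float = 1e-10) -> Tensor:
    """Exclusive cumulative product along `dim`:
    output[..., j] = prod_{k < j} x[..., k], with 1 at j = 0.
    Computed in log space with clamping for numerical stability."""
    log_x = torch.log(torch.clamp(x, min=eps, max=1.0))
    inclusive = torch.exp(torch.cumsum(log_x, dim=dim))
    # Shift right by one and prepend ones to make the product exclusive
    ones_shape = list(x.size())
    ones_shape[dim] = 1
    ones = x.new_ones(ones_shape)
    return torch.cat([ones, inclusive], dim=dim).narrow(dim, 0, x.size(dim))


def prob_check(tensor: Tensor, eps: float = 1e-10) -> None:
    """Assert that every entry is a valid probability in [0, 1] (up to eps)."""
    assert not torch.isnan(tensor).any(), "NaN in probability tensor."
    assert bool(tensor.ge(-eps).all()) and bool(tensor.le(1.0 + eps).all()), (
        "Probabilities must lie in [0, 1]."
    )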
def _test_custom_alignment_train_ref(self, p_choose, eps):
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)

    bsz = p_choose.size(0)
    tgt_len = p_choose.size(1)
    src_len = p_choose.size(2)

    alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
    alpha_0[:, :, 0] = 1.0

    previous_alpha = [alpha_0]

    for i in range(tgt_len):
        # p_choose: bsz, tgt_len, src_len
        # cumprod_1mp_clamp: bsz, tgt_len, src_len
        # previous_alpha[i]: bsz, 1, src_len
        # alpha_i: bsz, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i],
                dim=1,
            )
        ).clamp(0, 1.0)

        previous_alpha.append(alpha_i.unsqueeze(1))

    # alpha: bsz, tgt_len, src_len
    alpha = torch.cat(previous_alpha[1:], dim=1)

    return alpha
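# A minimal usage sketch for the reference implementation above (hypothetical
# test method; it assumes the reference lives on a unittest.TestCase together
# with an importable expected_alignment_from_p_choose, and that both share the
# same exclusive_cumprod helper):
def _test_alignment_matches_reference_sketch(self):
    eps = 1e-6
    bsz, tgt_len, src_len = 2, 4, 8
    # Random stepwise selection probabilities in [0, 1)
    p_choose = torch.rand(bsz, tgt_len, src_len)
    alpha_ref = self._test_custom_alignment_train_ref(p_choose, eps)
    alpha = expected_alignment_from_p_choose(p_choose, padding_mask=None, eps=eps)
    self.assertTrue(torch.allclose(alpha, alpha_ref, atol=1e-5))
    # Each target step assigns at most total probability mass 1 to the source
    self.assertTrue(bool((alpha.sum(dim=-1) <= 1.0 + eps).all()))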
def expected_alignment_train(self, p_choose, key_padding_mask):
    """
    Calculating expected alignment for MMA
    Mask is not needed because p_choose will be 0 if masked

    q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
    a_{i,j} = p_{i,j} * q_{i,j}

    Parallel solution:
    a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz * num_heads, tgt_len, src_len
    """
    # p_choose: bsz * num_heads, tgt_len, src_len
    bsz_num_heads, tgt_len, src_len = p_choose.size()

    # cumprod_1mp: bsz * num_heads, tgt_len, src_len
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)

    init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
    init_attention[:, :, 0] = 1.0

    previous_attn = [init_attention]

    for i in range(tgt_len):
        # p_choose: bsz * num_heads, tgt_len, src_len
        # cumprod_1mp_clamp: bsz * num_heads, tgt_len, src_len
        # previous_attn[i]: bsz * num_heads, 1, src_len
        # alpha_i: bsz * num_heads, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
                dim=1,
            )
        ).clamp(0, 1.0)

        previous_attn.append(alpha_i.unsqueeze(1))

    # alpha: bsz * num_heads, tgt_len, src_len
    alpha = torch.cat(previous_attn[1:], dim=1)

    if self.mass_preservation:
        # The last source token takes the residual probability mass
        alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)

    assert not torch.isnan(alpha).any(), "NaN detected in alpha."

    return alpha
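# To make the recurrence in the docstring above concrete, the following naive,
# loop-based computation of the same expected alignment mirrors
#     q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
#     a_{i,j} = p_{i,j} * q_{i,j}
# directly, without the cumprod/cumsum parallelization.  This is an
# illustrative sketch (a hypothetical helper, not part of the module), useful
# only for checking small examples against the parallel version.
def naive_expected_alignment(p_choose: Tensor) -> Tensor:
    bsz, tgt_len, src_len = p_choose.size()
    alpha = p_choose.new_zeros(bsz, tgt_len, src_len)
    # a_{-1,j}: before any target step, all mass sits on the first source token
    prev_a = p_choose.new_zeros(bsz, src_len)
    prev_a[:, 0] = 1.0
    for i in range(tgt_len):
        q = p_choose.new_zeros(bsz, src_len)
        for j in range(src_len):
            stay = (1 - p_choose[:, i, j - 1]) * q[:, j - 1] if j > 0 else 0.0
            q[:, j] = stay + prev_a[:, j]
        alpha[:, i] = p_choose[:, i] * q
        prev_a = alpha[:, i]
    return alpha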
def expected_alignment_train(
    self, p_choose, key_padding_mask: Optional[Tensor]
):
    """
    Calculating expected alignment for MMA
    Mask is not needed because p_choose will be 0 if masked

    q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
    a_{i,j} = p_{i,j} * q_{i,j}

    Parallel solution:
    a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz * num_heads, tgt_len, src_len
    """
    # p_choose: bsz * num_heads, tgt_len, src_len
    bsz_num_heads, tgt_len, src_len = p_choose.size()

    # cumprod_1mp: bsz * num_heads, tgt_len, src_len
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)

    init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
    init_attention[:, :, 0] = 1.0

    previous_attn = [init_attention]

    for i in range(tgt_len):
        # p_choose: bsz * num_heads, tgt_len, src_len
        # cumprod_1mp_clamp: bsz * num_heads, tgt_len, src_len
        # previous_attn[i]: bsz * num_heads, 1, src_len
        # alpha_i: bsz * num_heads, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
                dim=1,
            )
        ).clamp(0, 1.0)

        previous_attn.append(alpha_i.unsqueeze(1))

    # alpha: bsz * num_heads, tgt_len, src_len
    alpha = torch.cat(previous_attn[1:], dim=1)

    if self.mass_preservation:
        # The last unmasked source token takes the residual probability mass
        if key_padding_mask is not None and key_padding_mask[:, -1].any():
            # right padding
            batch_size = key_padding_mask.size(0)
            residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0)
            src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True)
            src_lens = src_lens.expand(
                batch_size, self.num_heads
            ).contiguous().view(-1, 1)
            src_lens = src_lens.expand(-1, tgt_len).contiguous()
            # add back the value at the last unmasked position
            residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1)
            alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals)
        else:
            residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
            alpha[:, :, -1] = residuals

    if torch.isnan(alpha).any():
        # Something went wrong
        raise RuntimeError("NaN in alpha.")

    return alpha
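# A small, self-contained sketch isolating the right-padding branch of the
# mass-preservation step above (hypothetical shapes, with num_heads = 1 so that
# bsz * num_heads == bsz).  It shows how the residual probability mass is moved
# onto the last unmasked source position of each sentence.
import torch

bsz, num_heads, tgt_len, src_len = 2, 1, 3, 5
alpha = torch.full((bsz * num_heads, tgt_len, src_len), 0.1)
# Second sentence has its last two source positions right-padded
key_padding_mask = torch.tensor([
    [False, False, False, False, False],
    [False, False, False, True, True],
])
# Padded positions would already carry zero mass in the real computation
alpha = alpha.masked_fill(key_padding_mask.unsqueeze(1), 0.0)

residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0)
src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True)  # [[5], [3]]
src_lens = src_lens.expand(bsz, num_heads).contiguous().view(-1, 1)
src_lens = src_lens.expand(-1, tgt_len).contiguous()

# Add back the mass already at the last unmasked position, then write the
# combined residual there so each target step sums to 1
residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1)
alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals)

print(alpha.sum(dim=-1))  # all ones
print(alpha[1])           # residual mass placed at source index 2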