Example 1
0
def expected_alignment_from_p_choose(p_choose: Tensor,
                                     padding_mask: Optional[Tensor] = None,
                                     eps: float = 1e-6):
    """
    Compute the expected (soft) monotonic alignment from stepwise
    selection probabilities.

    Reference:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments
    https://arxiv.org/pdf/1704.00784.pdf

    Recurrence:
    q_ij = (1 - p_{i,j-1}) q_{i,j-1} + a_{i-1,j}
    a_ij = p_ij q_ij

    Parallel solution (per target step i):
    a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz, tgt_len, src_len
    """
    prob_check(p_choose)

    # p_choose: bsz, tgt_len, src_len
    bsz, tgt_len, src_len = p_choose.size()
    orig_dtype = p_choose.dtype

    # Run the recurrence in fp32 so cumprod/cumsum do not overflow in fp16.
    p_choose = p_choose.float()

    if padding_mask is not None:
        # Padded source positions can never be selected.
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # Exclusive cumulative product of (1 - p) along the source axis.
    one_minus_p_cumprod = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    # Clamped copy used only as a divisor, to avoid dividing by ~0.
    safe_divisor = torch.clamp(one_minus_p_cumprod, eps, 1.0)

    # Initial alignment: all mass on the first source position.
    init_alpha = p_choose.new_zeros([bsz, 1, src_len])
    init_alpha[:, :, 0] = 1.0

    alpha_list = [init_alpha]

    for step in range(tgt_len):
        # alpha_list[step]: bsz, 1, src_len -> prev: bsz, src_len
        prev = alpha_list[step][:, 0]
        cumsum_term = torch.cumsum(prev / safe_divisor[:, step], dim=1)
        step_alpha = (
            p_choose[:, step] * one_minus_p_cumprod[:, step] * cumsum_term
        ).clamp(0, 1.0)
        alpha_list.append(step_alpha.unsqueeze(1))

    # Drop the initializer and stack per-step results: bsz, tgt_len, src_len
    alpha = torch.cat(alpha_list[1:], dim=1)

    # Cast back to the caller's dtype (mixed-precision support).
    alpha = alpha.type(orig_dtype)

    prob_check(alpha)

    return alpha
Example 2
0
    def _test_custom_alignment_train_ref(self, p_choose, eps):
        """Loop-based reference expected alignment.

        Used to validate the custom/vectorized implementation against the
        sequential recurrence from Raffel et al. (2017).

        p_choose: bsz, tgt_len, src_len
        Returns alpha of the same shape.
        """
        one_minus_p_cumprod = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
        # Clamped divisor to avoid division by values near zero.
        safe_divisor = torch.clamp(one_minus_p_cumprod, eps, 1.0)

        bsz, tgt_len, src_len = p_choose.size()

        # All alignment mass starts on the first source position.
        init_alpha = p_choose.new_zeros([bsz, 1, src_len])
        init_alpha[:, :, 0] = 1.0

        alphas = [init_alpha]

        for step in range(tgt_len):
            # alphas[step]: bsz, 1, src_len -> prev: bsz, src_len
            prev = alphas[step][:, 0]
            cumsum_term = torch.cumsum(prev / safe_divisor[:, step], dim=1)
            step_alpha = (
                p_choose[:, step] * one_minus_p_cumprod[:, step] * cumsum_term
            ).clamp(0, 1.0)
            alphas.append(step_alpha.unsqueeze(1))

        # Initializer dropped; result: bsz, tgt_len, src_len
        return torch.cat(alphas[1:], dim=1)
    def expected_alignment_train(self, p_choose, key_padding_mask):
        """
        Calculate expected alignment for MMA during training.

        key_padding_mask is accepted for interface compatibility but is not
        used directly: p_choose is expected to already be 0 at masked
        positions, so no extra masking is applied here.

        Recurrence:
        q_ij = (1 - p_{i,j-1}) q_{i,j-1} + a_{i-1,j}
        a_ij = p_ij q_ij

        Parallel solution:
        a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        """

        # p_choose: bsz * num_heads, tgt_len, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()

        # cumprod_1mp : bsz * num_heads, tgt_len, src_len
        cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
        # Clamped copy used as a divisor to avoid division by ~0.
        cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)

        # Initial alignment: all mass on the first source position.
        init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
        init_attention[:, :, 0] = 1.0

        previous_attn = [init_attention]

        for i in range(tgt_len):
            # p_choose: bsz * num_heads, tgt_len, src_len
            # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
            # previous_attn[i]: bsz * num_heads, 1, src_len
            # alpha_i: bsz * num_heads, src_len
            alpha_i = (
                p_choose[:, i]
                * cumprod_1mp[:, i]
                * torch.cumsum(
                    previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
                    dim=1
                )
            ).clamp(0, 1.0)
            previous_attn.append(alpha_i.unsqueeze(1))

        # alpha: bsz * num_heads, tgt_len, src_len
        alpha = torch.cat(previous_attn[1:], dim=1)

        if self.mass_preservation:
            # Last token absorbs the residual probability mass so each
            # target step's alignment sums to 1.
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)

        # Raise (instead of assert) so the check survives `python -O`,
        # consistent with the padding-aware variant of this method.
        if torch.isnan(alpha).any():
            raise RuntimeError("NaN detected in alpha.")

        return alpha
    def expected_alignment_train(self, p_choose,
                                 key_padding_mask: Optional[Tensor]):
        """
        Calculating expected alignment for MMA.
        No extra masking is needed because p_choose is expected to be 0
        at masked positions; key_padding_mask is only consulted for the
        mass-preservation step below.

        Recurrence:
        q_ij = (1 - p_{i,j-1}) q_{i,j-1} + a_{i-1,j}
        a_ij = p_ij q_ij

        Parallel solution:
        a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        """

        # p_choose: bsz * num_heads, tgt_len, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()

        # cumprod_1mp : bsz * num_heads, tgt_len, src_len
        # Exclusive cumulative product of (1 - p) along the source axis.
        cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
        # Clamped copy used only as a divisor, to avoid dividing by ~0.
        cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)

        # Initial alignment: all mass on the first source position.
        init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
        init_attention[:, :, 0] = 1.0

        previous_attn = [init_attention]

        for i in range(tgt_len):
            # p_choose: bsz * num_heads, tgt_len, src_len
            # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
            # previous_attn[i]: bsz * num_heads, 1, src_len
            # alpha_i: bsz * num_heads, src_len
            alpha_i = (
                p_choose[:, i] * cumprod_1mp[:, i] *
                torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
                             dim=1)).clamp(0, 1.0)
            previous_attn.append(alpha_i.unsqueeze(1))

        # alpha: bsz * num_heads, tgt_len, src_len (initializer dropped)
        alpha = torch.cat(previous_attn[1:], dim=1)

        if self.mass_preservation:
            # Last token has the residual probabilities
            if key_padding_mask is not None and key_padding_mask[:, -1].any():
                # right padding: route the residual mass to each sequence's
                # actual last valid source token instead of column -1.
                batch_size = key_padding_mask.size(0)
                # 1 - sum over ALL positions; the current value at the last
                # valid position is added back via gather below, so the final
                # value there becomes 1 - sum(all other positions).
                residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0)
                # True (unpadded) source length per batch element: (bsz, 1).
                src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True)
                # Replicate per head: (bsz, 1) -> (bsz * num_heads, 1).
                src_lens = src_lens.expand(batch_size,
                                           self.num_heads).contiguous().view(
                                               -1, 1)
                # Broadcast over target steps: (bsz * num_heads, tgt_len).
                src_lens = src_lens.expand(-1, tgt_len).contiguous()
                # add back the last value (index src_lens - 1 = last valid pos)
                residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1)
                alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals)
            else:
                # No right padding: residual mass goes to the final column.
                residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
                alpha[:, :, -1] = residuals

        if torch.isnan(alpha).any():
            # Something is wrong
            raise RuntimeError("NaN in alpha.")

        return alpha