Example #1
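# Both forward methods below are excerpted from their attention classes and
# rely on module-level imports plus a `headdrop` utility defined elsewhere in
# the code base. A minimal, assumed set of imports (the headdrop path is a
# placeholder, not the actual module path):
import numpy as np
import torch

from modules.headdrop import headdrop  # placeholder path; adjust to the real location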
    def forward(self,
                key,
                value,
                query,
                mask,
                aw_prev=None,
                aw_lower=None,
                cache=False,
                mode='',
                trigger_points=None,
                eps_wait=-1,
                streaming=False):
        """Forward pass.

        Args:
            key (FloatTensor): `[B, klen, kdim]`
            value (FloatTensor): `[B, klen, vdim]`
            query (FloatTensor): `[B, qlen, qdim]`
            mask (ByteTensor): `[B, qlen, klen]`
            aw_prev: dummy interface
            aw_lower: dummy interface
            cache (bool): cache key, value, and mask
            mode: dummy interface for MoChA/MMA
            trigger_points: dummy interface for MoChA/MMA
            eps_wait: dummy interface for MMA
            streaming: dummy interface for streaming attention
        Returns:
            cv (FloatTensor): `[B, qlen, vdim]`
            aw (FloatTensor): `[B, H, qlen, klen]`
            beta: dummy interface for MoChA/MMA
            p_choose: dummy interface for MoChA/MMA

        """
        bs, klen = key.size()[:2]
        qlen = query.size(1)

        # Pre-computation of encoder-side features for computing scores
        if self.key is None or not cache:
            self.key = self.w_key(key).view(bs, -1, self.n_heads,
                                            self.d_k)  # `[B, klen, H, d_k]`
            self.value = self.w_value(value).view(
                bs, -1, self.n_heads, self.d_k)  # `[B, klen, H, d_k]`
            if mask is not None:
                self.mask = mask.unsqueeze(3).repeat([1, 1, 1, self.n_heads])
                mask_size = (bs, qlen, klen, self.n_heads)
                assert self.mask.size() == mask_size, (self.mask.size(),
                                                       mask_size)
            else:
                self.mask = None

        key = self.key
        query = self.w_query(query).view(bs, -1, self.n_heads,
                                         self.d_k)  # `[B, qlen, H, d_k]`

        if self.atype == 'scaled_dot':
            e = torch.einsum("bihd,bjhd->bijh", (query, key)) / self.scale
        elif self.atype == 'add':
            e = self.v(
                torch.tanh(key[:, None] + query[:, :, None]).view(
                    bs, qlen, klen, -1))
        # e: `[B, qlen, klen, H]`

        # Compute attention weights
        if self.mask is not None:
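            # most negative finite value for e's dtype; masked positions get ~0 weight after the softmax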
            NEG_INF = float(
                np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
            e = e.masked_fill_(self.mask == 0, NEG_INF)  # `[B, qlen, klen, H]`
        aw = torch.softmax(e, dim=2)
        aw = self.dropout_attn(aw)
        aw_masked = aw.clone()

        # mask out each head independently (HeadDrop)
        if self.dropout_head > 0 and self.training:
            aw_masked = aw_masked.permute(0, 3, 1, 2)
            aw_masked = headdrop(aw_masked, self.n_heads,
                                 self.dropout_head)  # `[B, H, qlen, klen]`
            aw_masked = aw_masked.permute(0, 2, 3, 1)

        cv = torch.einsum("bijh,bjhd->bihd",
                          (aw_masked, self.value))  # `[B, qlen, H, d_k]`
        cv = cv.contiguous().view(bs, -1, self.n_heads *
                                  self.d_k)  # `[B, qlen, H * d_k]`
        cv = self.w_out(cv)
        aw = aw.permute(0, 3, 1, 2)  # `[B, H, qlen, klen]`

        return cv, aw, None, None
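
# A self-contained sketch of the einsum-based multi-head scoring used above,
# with toy tensors standing in for the projected key/query/value (the class's
# w_key/w_query/w_value/w_out projections are omitted). It only illustrates
# the shape conventions, not the class itself. The next method is the
# relative-position (Transformer-XL style) variant.
import math
import torch

B, klen, qlen, H, d_k = 2, 7, 3, 4, 16
key = torch.randn(B, klen, H, d_k)      # as produced by w_key(...).view(...)
value = torch.randn(B, klen, H, d_k)    # as produced by w_value(...).view(...)
query = torch.randn(B, qlen, H, d_k)    # as produced by w_query(...).view(...)

e = torch.einsum("bihd,bjhd->bijh", (query, key)) / math.sqrt(d_k)  # `[B, qlen, klen, H]`
aw = torch.softmax(e, dim=2)                         # normalize over the key axis
cv = torch.einsum("bijh,bjhd->bihd", (aw, value))    # `[B, qlen, H, d_k]`
cv = cv.contiguous().view(B, qlen, H * d_k)          # concatenate the heads
print(cv.shape)                                      # torch.Size([2, 3, 64])
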
    def forward(self, key, query, pos_embs, mask, u_bias=None, v_bias=None):
        """Forward pass.

        Args:
            key (FloatTensor): `[B, mlen+qlen, kdim]` (memory already concatenated)
            query (FloatTensor): `[B, qlen, qdim]` (only its shape is used)
            pos_embs (FloatTensor): `[mlen+qlen, 1, d_model]`
            mask (ByteTensor): `[B, qlen, mlen+qlen]`
            u_bias (nn.Parameter): `[H, d_k]`
            v_bias (nn.Parameter): `[H, d_k]`
        Returns:
            cv (FloatTensor): `[B, qlen, vdim]`
            aw (FloatTensor): `[B, H, qlen, mlen+qlen]`

        """
        bs, qlen = query.size()[:2]
        mlen = key.size(1) - qlen
        # NOTE: key already contains the memory, i.e., klen = mlen + qlen

        if mask is not None:
            mask = mask.unsqueeze(3).repeat([1, 1, 1, self.n_heads])
            assert mask.size() == (bs, qlen, mlen + qlen, self.n_heads), \
                (mask.size(), (bs, qlen, mlen + qlen, self.n_heads))

        k = self.w_key(key).view(bs, -1, self.n_heads,
                                 self.d_k)  # `[B, mlen+qlen, H, d_k]`
        v = self.w_value(key).view(bs, -1, self.n_heads,
                                   self.d_k)  # `[B, mlen+qlen, H, d_k]`
        q = self.w_query(key[:, -qlen:]).view(bs, -1, self.n_heads,
                                              self.d_k)  # `[B, qlen, H, d_k]`
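        # NOTE: queries are taken from the last qlen frames of key (the current chunk)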

        if self.xl_like:
            _pos_embs = self.w_pos(pos_embs)
        else:
            _pos_embs = self.w_value(pos_embs)  # NOTE: w_value is reused here instead of a dedicated w_pos projection
        _pos_embs = _pos_embs.view(-1, self.n_heads,
                                   self.d_k)  # `[mlen+qlen, H, d_k]`

        # content-based attention term: (a) + (c)
        if u_bias is not None:
            assert self.xl_like
            AC = torch.einsum(
                "bihd,bjhd->bijh",
                (q + u_bias[None, None], k))  # `[B, qlen, mlen+qlen, H]`
        else:
            # only term (a), actually (no u_bias)
            AC = torch.einsum("bihd,bjhd->bijh",
                              (q, k))  # `[B, qlen, mlen+qlen, H]`

        # position-based attention term: (b) + (d)
        if v_bias is not None:
            assert self.xl_like
            BD = torch.einsum("bihd,jhd->bijh",
                              (q + v_bias[None, None],
                               _pos_embs))  # `[B, qlen, mlen+qlen, H]`
        else:
            # only term (b), actually (no v_bias)
            BD = torch.einsum("bihd,jhd->bijh",
                              (q, _pos_embs))  # `[B, qlen, mlen+qlen, H]`

        # Realign the position axis so that index j corresponds to the relative
        # distance between query i and key j (Transformer-XL shift trick)
        # BD = self._rel_shift_v1(BD)
        BD = self._rel_shift_v2(BD)

        # The final score is the sum of the content-based and position-based terms
        e = (AC + BD) / self.scale  # `[B, qlen, mlen+qlen, H]`

        # Compute attention weights
        if mask is not None:
            NEG_INF = float(
                np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
            e = e.masked_fill_(mask == 0, NEG_INF)  # `[B, qlen, mlen+qlen, H]`
        aw = torch.softmax(e, dim=2)
        aw = self.dropout_attn(aw)  # `[B, qlen, mlen+qlen, H]`
        aw_masked = aw.clone()

        # mask out each head independently (HeadDrop)
        if self.dropout_head > 0 and self.training:
            aw_masked = aw_masked.permute(0, 3, 1, 2)
            aw_masked = headdrop(aw_masked, self.n_heads,
                                 self.dropout_head)  # `[B, H, qlen, mlen+qlen]`
            aw_masked = aw_masked.permute(0, 2, 3, 1)

        cv = torch.einsum("bijh,bjhd->bihd", (aw_masked, v))  # `[B, qlen, H, d_k]`
        cv = cv.contiguous().view(bs, -1, self.n_heads *
                                  self.d_k)  # `[B, qlen, H * d_k]`
        cv = self.w_out(cv)
        aw = aw.permute(0, 3, 1, 2)  # `[B, H, qlen, mlen+qlen]`

        return cv, aw
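
# A minimal sketch of the relative-shift step hidden inside `_rel_shift_v2`
# above. `_rel_shift_v2` itself is not shown, so this reproduces the standard
# Transformer-XL shift trick for the `[B, qlen, klen, H]` layout; the actual
# implementation in the class may differ in detail. It assumes the positional
# embeddings are ordered from the largest relative distance (klen-1) down to 0,
# and the junk that lands on "future" positions is removed by the causal mask.
import torch


def rel_shift(x):
    """Realign `[B, qlen, klen, H]` scores from absolute to relative indexing."""
    bs, qlen, klen, n_heads = x.size()
    zero_pad = x.new_zeros((bs, qlen, 1, n_heads))
    x = torch.cat([zero_pad, x], dim=2)       # `[B, qlen, klen+1, H]`
    x = x.view(bs, klen + 1, qlen, n_heads)   # reinterpret the padded memory
    return x[:, 1:].reshape(bs, qlen, klen, n_heads)


BD = torch.randn(2, 3, 5, 4)  # `[B, qlen, mlen+qlen, H]`
print(rel_shift(BD).shape)    # torch.Size([2, 3, 5, 4])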