Example #1
    def forward(self,
                key,
                value,
                query,
                mask=None,
                aw_prev=None,
                cache=False,
                mode='hard',
                trigger_points=None,
                eps_wait=-1,
                efficient_decoding=False):
        """Forward pass.

        Args:
            key (FloatTensor): `[B, klen, kdim]`
            value (FloatTensor): `[B, klen, vdim]`
            query (FloatTensor): `[B, qlen, qdim]`
            mask (ByteTensor): `[B, qlen, klen]`
            aw_prev (FloatTensor): `[B, H_ma, 1, klen]`
            cache (bool): cache key and mask
            mode (str): recursive/parallel/hard
            trigger_points (IntTensor): `[B, qlen]`
            eps_wait (int): wait time delay for head-synchronous decoding in MMA
        Returns:
            cv (FloatTensor): `[B, qlen, vdim]`
            alpha (FloatTensor): `[B, H_ma, qlen, klen]`
            beta (FloatTensor): `[B, H_ma * H_ca, qlen, klen]`
            p_choose (FloatTensor): `[B, H_ma, qlen, klen]`

        """
        bs, klen = key.size()[:2]
        qlen = query.size(1)
        tail_len = self.key_prev_tail.size(1) if self.key_prev_tail is not None else 0

        if aw_prev is None:
            # aw_prev = [1, 0, 0 ... 0]
            aw_prev = key.new_zeros(bs, self.n_heads_ma, 1, klen)
            aw_prev[:, :, :, 0:1] = key.new_ones(bs, self.n_heads_ma, 1, 1)

        # Compute monotonic energy
        e_ma = self.monotonic_energy(
            key, query, mask, cache=cache,
            boundary_leftmost=self.bd_offset)  # `[B, H_ma, qlen, klen]`
        assert e_ma.size(3) + self.bd_offset == key.size(1)

        if mode == 'recursive':  # training
            alpha, p_choose = self.recursive(e_ma, aw_prev)
            alpha_masked = alpha.clone()

        elif mode == 'parallel':  # training (efficient)
            alpha, p_choose = self.parallel(e_ma, aw_prev, trigger_points)

            # mask out each head independently (HeadDrop)
            if self.dropout_head > 0 and self.training:
                alpha_masked = headdrop(alpha.clone(), self.n_heads_ma,
                                        self.dropout_head)
            else:
                alpha_masked = alpha.clone()

        elif mode == 'hard':  # inference
            alpha, p_choose = self.hard(e_ma, aw_prev, eps_wait)
            alpha_masked = alpha.clone()

        else:
            raise ValueError(
                "mode must be 'recursive', 'parallel', or 'hard'.")

        # Compute chunk energy
        beta = None
        if self.chunk_energy is not None:
            bd_leftmost = 0
            bd_rightmost = klen - 1 - self.bd_offset
            if efficient_decoding and mode == 'hard' and alpha.sum() > 0:
                bd_leftmost = alpha[:, :, 0].nonzero()[:, -1].min().item()
                bd_rightmost = alpha[:, :, 0].nonzero()[:, -1].max().item()
                if bd_leftmost == bd_rightmost:
                    alpha_masked = alpha_masked[:, :, :,
                                                bd_leftmost:bd_leftmost + 1]
                else:
                    alpha_masked = alpha_masked[:, :, :,
                                                bd_leftmost:bd_rightmost]

            if mode == 'hard':
                if self.key_prev_tail is not None:
                    key_ = torch.cat(
                        [self.key_prev_tail[0:1].repeat([bs, 1, 1]), key],
                        dim=1)
                else:
                    key_ = key
                e_ca = self.chunk_energy(
                    key_,
                    query,
                    mask,
                    cache=cache,
                    boundary_leftmost=0 if self.milk else max(
                        0, self.bd_offset + bd_leftmost - self.w + 1),
                    boundary_rightmost=self.bd_offset + bd_rightmost + 1 +
                    tail_len)  # `[B, (H_ma*)H_ca, qlen, klen]`
            else:
                e_ca = self.chunk_energy(
                    key, query, mask,
                    cache=cache)  # `[B, (H_ma*)H_ca, qlen, klen]`

            # padding for chunkwise attention over adjacent input segments
            additional = e_ca.size(3) - alpha_masked.size(3)
            if efficient_decoding and mode == 'hard':
                alpha = torch.cat(
                    [alpha.new_zeros(bs, alpha.size(1), 1, klen - alpha.size(3)), alpha],
                    dim=3)
                if additional > 0:
                    alpha_masked = torch.cat(
                        [alpha_masked.new_zeros(bs, alpha_masked.size(1), 1, additional),
                         alpha_masked],
                        dim=3)

            if mode == 'hard':
                if self.key_prev_tail is not None:
                    alpha_masked = torch.cat(
                        [alpha_masked.new_zeros(bs, self.n_heads_ma, qlen, tail_len),
                         alpha_masked],
                        dim=3)
                beta = hard_chunkwise_attention(alpha_masked, e_ca, mask,
                                                self.w, self.n_heads_ca,
                                                self.sharpening_factor,
                                                self.share_ca)
            else:
                beta = efficient_chunkwise_attention(alpha_masked, e_ca, mask,
                                                     self.w, self.n_heads_ca,
                                                     self.sharpening_factor,
                                                     self.share_ca)
            beta = self.dropout_attn(beta)  # `[B, H_ma * H_ca, qlen, klen]`

            if efficient_decoding and mode == 'hard':
                value = value[:,
                              max(0, self.bd_offset + bd_leftmost - self.w +
                                  1):self.bd_offset + bd_rightmost + 1]

        # Update after calculating beta
        bd_offset_prev = self.bd_offset
        if efficient_decoding and mode == 'hard' and alpha.sum() > 0:
            self.bd_offset += alpha[:, :, 0, self.bd_offset:].nonzero()[:, -1].min().item()

        # Compute context vector
        if self.n_heads_ma * self.n_heads_ca > 1:
            value = self.w_value(value).view(bs, -1,
                                             self.n_heads_ma * self.n_heads_ca,
                                             self.d_k)
            value = value.transpose(
                2, 1).contiguous()  # `[B, H_ma * H_ca, klen, d_k]`
            cv = torch.matmul(alpha if self.w == 1 else beta,
                              value)  # `[B, H_ma * H_ca, qlen, d_k]`
            cv = cv.transpose(2, 1).contiguous().view(
                bs, -1, self.n_heads_ma * self.n_heads_ca * self.d_k)
            cv = self.w_out(cv)  # `[B, qlen, adim]`
        else:
            if self.w == 1:
                cv = torch.bmm(alpha.squeeze(1), value)  # `[B, 1, adim]`
            else:
                if self.key_prev_tail is not None:
                    value_ = torch.cat(
                        [self.key_prev_tail[0:1].repeat([bs, 1, 1]), value],
                        dim=1)
                    cv = torch.bmm(beta.squeeze(1), value_)  # `[B, 1, adim]`
                else:
                    cv = torch.bmm(beta.squeeze(1), value)  # `[B, 1, adim]`

        assert alpha.size() == (bs, self.n_heads_ma, qlen, klen), \
            (alpha.size(), (bs, self.n_heads_ma, qlen, klen))
        if self.w > 1:
            _w = max(1, (bd_offset_prev + bd_rightmost + 1) -
                     max(0, bd_offset_prev + bd_leftmost - self.w + 1))
            assert beta.size() == (bs, self.n_heads_ma * self.n_heads_ca, qlen, _w + tail_len), \
                (beta.size(), (bs, self.n_heads_ma * self.n_heads_ca, qlen, _w + tail_len))
        elif self.milk:
            assert beta.size() == (bs, self.n_heads_ma * self.n_heads_ca, qlen, klen), \
                (beta.size(), (bs, self.n_heads_ma * self.n_heads_ca, qlen, klen))

        return cv, alpha, beta, p_choose
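This example (and the ones below) mask out whole attention heads during training via a `headdrop` helper. As a rough, self-contained sketch of that idea (the name `headdrop_sketch` and the exact masking policy are assumptions, not the repository's implementation), dropping each head independently with probability `p` could look like this:

import torch


def headdrop_sketch(aw, n_heads, p):
    # aw: `[B, H, qlen, klen]` attention weights with H == n_heads
    # Draw a Bernoulli keep-mask per (batch, head) and zero out dropped heads.
    bs = aw.size(0)
    keep = (torch.rand(bs, n_heads, 1, 1, device=aw.device) >= p).to(aw.dtype)
    return aw * keep


# shape check with dummy monotonic attention weights
alpha = torch.softmax(torch.randn(2, 4, 1, 10), dim=-1)
print(headdrop_sketch(alpha, n_heads=4, p=0.5).size())  # torch.Size([2, 4, 1, 10])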
Example #2
    def forward(self, key, query, pos_embs, mask, u_bias=None, v_bias=None):
        """Forward pass.

        Args:
            key (FloatTensor): `[B, mlen+qlen, kdim]`; memory and current frames are already concatenated
            query (FloatTensor): `[B, qlen, qdim]` (only its length is used here)
            pos_embs (LongTensor): `[mlen+qlen, 1, d_model]`
            mask (ByteTensor): `[B, qlen, mlen+qlen]`
            u_bias (nn.Parameter): `[H, d_k]`
            v_bias (nn.Parameter): `[H, d_k]`
        Returns:
            cv (FloatTensor): `[B, qlen, vdim]`
            aw (FloatTensor): `[B, H, qlen, mlen+qlen]`

        """
        bs, qlen = query.size()[:2]
        mlen = key.size(1) - qlen
        # NOTE: key already includes the memory, i.e., klen = mlen + qlen

        if mask is not None:
            mask = mask.unsqueeze(3).repeat([1, 1, 1, self.n_heads])
            assert mask.size() == (bs, qlen, mlen + qlen, self.n_heads), \
                (mask.size(), (bs, qlen, mlen + qlen, self.n_heads))

        k = self.w_key(key).view(bs, -1, self.n_heads, self.d_k)  # `[B, mlen+qlen, H, d_k]`
        v = self.w_value(key).view(bs, -1, self.n_heads, self.d_k)  # `[B, mlen+qlen, H, d_k]`
        q = self.w_query(key[:, -qlen:]).view(bs, -1, self.n_heads, self.d_k)  # `[B, qlen, H, d_k]`

        if self.xl_like:
            _pos_embs = self.w_pos(pos_embs)
        else:
            _pos_embs = self.w_value(pos_embs)  # NOTE: w_value is reused when no dedicated w_pos projection is used
        _pos_embs = _pos_embs.view(-1, self.n_heads, self.d_k)  # `[mlen+qlen, H, d_k]`

        # content-based attention term: (a) + (c)
        if u_bias is not None:
            assert self.xl_like
            AC = torch.einsum("bihd,bjhd->bijh", (q + u_bias[None, None], k))  # `[B, qlen, mlen+qlen, H]`
        else:
            # actually only the (a) term (no u_bias)
            AC = torch.einsum("bihd,bjhd->bijh", (q, k))  # `[B, qlen, mlen+qlen, H]`

        # position-based attention term: (b) + (d)
        if v_bias is not None:
            assert self.xl_like
            BD = torch.einsum("bihd,jhd->bijh", (q + v_bias[None, None], _pos_embs))  # `[B, qlen, mlen+qlen, H]`
        else:
            # actually only the (b) term (no v_bias)
            BD = torch.einsum("bihd,jhd->bijh", (q, _pos_embs))  # `[B, qlen, mlen+qlen, H]`

        # Compute positional attention efficiently
        BD = self._rel_shift(BD)

        # the attention is the sum of content-based and position-based attention
        e = (AC + BD) / self.scale  # `[B, qlen, mlen+qlen, H]`

        # Compute attention weights
        if mask is not None:
            NEG_INF = float(np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
            e = e.masked_fill_(mask == 0, NEG_INF)  # `[B, qlen, mlen+qlen, H]`
        aw = torch.softmax(e, dim=2)
        aw = self.dropout_attn(aw)  # `[B, qlen, mlen+qlen, H]`
        aw_masked = aw.clone()

        # mask out each head independently (HeadDrop)
        if self.dropout_head > 0 and self.training:
            aw_masked = aw_masked.permute(0, 3, 1, 2)
            aw_masked = headdrop(aw_masked, self.n_heads, self.dropout_head)  # `[B, H, qlen, klen]`
            aw_masked = aw_masked.permute(0, 2, 3, 1)

        cv = torch.einsum("bijh,bjhd->bihd", (aw_masked, v))  # `[B, qlen, H, d_k]`
        cv = cv.contiguous().view(bs, -1, self.n_heads * self.d_k)  # `[B, qlen, H * d_k]`
        cv = self.w_out(cv)
        aw = aw.permute(0, 3, 1, 2)  # `[B, H, qlen, mlen+qlen]`

        return cv, aw
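Example #2 relies on `self._rel_shift(BD)` to bring scores computed against the stack of relative position embeddings into alignment with the absolute key positions. A minimal sketch of the standard Transformer-XL shift trick, adapted here to the `[B, qlen, klen, H]` layout used above (an assumption for illustration, not the repository's `_rel_shift`):

import torch


def rel_shift_sketch(x):
    # x: `[B, qlen, klen, H]` position-based scores
    # Pad one zero column on the key axis, fold the pad into the row axis,
    # and drop the first row; each query row ends up shifted by one position.
    bs, qlen, klen, n_heads = x.size()
    zero_pad = x.new_zeros(bs, qlen, 1, n_heads)
    x_padded = torch.cat([zero_pad, x], dim=2)              # `[B, qlen, klen+1, H]`
    x_padded = x_padded.view(bs, klen + 1, qlen, n_heads)   # reinterpret the padded block
    return x_padded[:, 1:].reshape(bs, qlen, klen, n_heads)


# shape check with dummy scores
print(rel_shift_sketch(torch.randn(2, 3, 7, 4)).size())  # torch.Size([2, 3, 7, 4])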
Example #3
    def forward(self,
                key,
                value,
                query,
                mask,
                aw_prev=None,
                cache=False,
                mode='hard',
                trigger_points=None,
                eps_wait=-1,
                linear_decoding=False,
                streaming=False):
        """Forward pass.

        Args:
            key (FloatTensor): `[B, klen, kdim]`
            value (FloatTensor): `[B, klen, vdim]`
            query (FloatTensor): `[B, qlen, qdim]`
            mask (ByteTensor): `[B, qlen, klen]`
            aw_prev (FloatTensor): `[B, H_ma, 1, klen]`
            cache (bool): cache key and mask
            mode (str): recursive/parallel/hard
            trigger_points (IntTensor): `[B, qlen]`
            eps_wait (int): wait time delay for head-synchronous decoding in MMA
            linear_decoding (bool): linear-time decoding mode
            streaming (bool): streaming mode (use self.key_prev_tail)
        Returns:
            cv (FloatTensor): `[B, qlen, vdim]`
            alpha (FloatTensor): `[B, H_ma, qlen, klen]`
            attn_state (dict):
                beta (FloatTensor): `[B, H_ma * H_ca, qlen, klen]`
                p_choose (FloatTensor): `[B, H_ma, qlen, klen]`

        """
        klen = key.size(1)
        bs, qlen = query.size()[:2]
        tail_len = self.key_prev_tail.size(1) if self.key_prev_tail is not None else 0
        bd_L = self.bd_L_prev
        bd_R = klen - 1
        attn_state = {}

        if aw_prev is None:
            # aw_prev = [1, 0, 0 ... 0]
            aw_prev = key.new_zeros(bs, self.H_ma, 1, klen)
            aw_prev[:, :, :, 0:1] = key.new_ones(bs, self.H_ma, 1, 1)

        # Compute monotonic energy
        e_ma = self.monotonic_energy(key, query, mask, cache, bd_L,
                                     bd_R)  # `[B, H_ma, qlen, klen]`
        assert e_ma.size(3) + bd_L == klen, (e_ma.size(), self.bd_L_prev,
                                             key.size())

        if mode == 'recursive':  # training (incremental)
            alpha, p_choose = self.recursive(e_ma, aw_prev)
            alpha_masked = alpha.clone()
        elif mode == 'parallel':  # training (efficient)
            alpha, p_choose = self.parallel(e_ma, aw_prev, trigger_points)
            # mask out each head independently (HeadDrop)
            if self.dropout_head > 0 and self.training:
                alpha_masked = headdrop(alpha.clone(), self.H_ma,
                                        self.dropout_head)
            else:
                alpha_masked = alpha.clone()
        elif mode == 'hard':  # inference
            aw_prev = aw_prev[:, :, :, -e_ma.size(3):]
            alpha, p_choose = self.hard(e_ma, aw_prev, eps_wait)
            alpha_masked = alpha.clone()
        else:
            raise ValueError(
                "mode must be 'recursive', 'parallel', or 'hard'.")

        is_boundary = (alpha.sum().item() > 0)

        # leftmost/rightmost boundary offsets detected at the current step
        if linear_decoding and mode == 'hard' and is_boundary:
            bd_L = self.bd_L_prev + alpha[:, :, -1].nonzero()[:, -1].min().item()
            bd_R = self.bd_L_prev + alpha[:, :, -1].nonzero()[:, -1].max().item()
        bd_L_ca = max(0, bd_L + 1 - self.w) if not self.milk else 0
        use_tail = streaming and is_boundary and (bd_L + 1 < self.w) and tail_len > 0

        # Compute chunk energy
        beta = None
        if self.chunk_energy is not None:
            if mode == 'hard':
                if not is_boundary:
                    # No boundary detected
                    beta = alpha.new_zeros(bs, self.H_total, qlen,
                                           value.size(1))
                else:
                    if use_tail:
                        key = torch.cat([self.key_prev_tail[0:1], key[0:1]],
                                        dim=1)
                        bd_L += tail_len
                        bd_R += tail_len
                    bd_L_ca = max(0, bd_L + 1 - self.w) if not self.milk else 0

                    e_ca = self.chunk_energy(
                        key, query, mask, cache, bd_L_ca,
                        bd_R)  # `[B, (H_ma*)H_ca, qlen, klen]`
                    assert e_ca.size(3) == bd_R - bd_L_ca + 1, (e_ca.size(),
                                                                bd_L_ca, bd_R,
                                                                key.size())

                    if alpha_masked.size(3) < klen:
                        # back to the original shape
                        alpha_masked = torch.cat(
                            [alpha.new_zeros(bs, self.H_ma, qlen, klen - alpha_masked.size(3)),
                             alpha_masked],
                            dim=3)
                    if use_tail:
                        alpha_masked = torch.cat(
                            [alpha.new_zeros(bs, self.H_ma, qlen, tail_len), alpha_masked],
                            dim=3)
                        value = torch.cat([self.key_prev_tail[0:1], value[0:1]], dim=1)

                    alpha_masked = alpha_masked[:, :, :, bd_L_ca:bd_R + 1]
                    value = value[:, bd_L_ca:bd_R + 1]
                    # NOTE: alpha_masked must have the same shape as beta

                    beta = hard_chunkwise_attention(
                        alpha_masked, e_ca, mask, self.w, self.H_ca,
                        self.sharpening_factor,
                        self.share_ca)  # `[B, H_ma * H_ca, qlen, klen]`
                    beta = self.dropout_attn(beta)

                    assert beta.size() == (bs, self.H_total, qlen, bd_R - bd_L_ca + 1), \
                        (beta.size(), (bs, self.H_total, qlen, bd_L_ca, bd_R))
            else:
                e_ca = self.chunk_energy(key, query, mask, cache, 0,
                                         bd_R)  # `[B, (H_ma*)H_ca, qlen, klen]`

                beta = soft_chunkwise_attention(
                    alpha_masked, e_ca, mask, self.w, self.H_ca,
                    self.sharpening_factor,
                    self.share_ca)  # `[B, H_ma * H_ca, qlen, klen]`
                beta = self.dropout_attn(beta)

                assert beta.size() == (bs, self.H_total, qlen, klen), \
                    (beta.size(), (bs, self.H_total, qlen, klen))

        if value.size(0) != bs:  # for inference
            value = value[0:1].repeat([bs, 1, 1])

        # Compute context vector
        if self.H_total > 1:
            v = self.w_value(value).view(bs, -1, self.H_total, self.d_k)
            # TODO: cache at test time
            v = v.transpose(2, 1).contiguous()  # `[B, H_ma * H_ca, klen, d_k]`
            cv = torch.matmul(alpha_masked if self.w == 1 else beta,
                              v)  # `[B, H_ma * H_ca, qlen, d_k]`
            cv = cv.transpose(2, 1).contiguous().view(bs, -1,
                                                      self.H_total * self.d_k)
            cv = self.w_out(cv)  # `[B, qlen, adim]`
        else:
            cv = torch.bmm(
                alpha_masked.squeeze(1) if self.w == 1 else beta.squeeze(1),
                value)  # `[B, 1, adim]`

        if mode == 'hard' and use_tail:
            bd_L -= tail_len
            bd_R -= tail_len
            alpha_masked = alpha_masked[:, :, :, -klen:]
        self.bd_L_prev = bd_L

        # padding for the next step
        if mode == 'hard':
            alpha = alpha.new_zeros(bs, alpha.size(1), qlen, klen)
            if is_boundary:
                alpha[:, :, :, bd_L:bd_R + 1] = alpha_masked[:, :, :, -(bd_R - bd_L + 1):]

        assert alpha.size() == (bs, self.H_ma, qlen, klen), \
            (alpha.size(), (bs, self.H_ma, qlen, klen, bd_L, bd_R))

        # cache encoder outputs when moving to the next block
        if mode == 'hard' and streaming and self.key_cur_tail is None:
            if not is_boundary:
                self.key_cur_tail = key.detach()[:, -(self.w - 1):]
            elif bd_L + 1 < self.w:
                n_rest = self.w - (bd_L + 1)
                if n_rest < klen:
                    self.key_cur_tail = key.detach()[:, -n_rest:]
                elif self.key_prev_tail is not None:
                    # concatenate multiple blocks (>=3)
                    self.key_cur_tail = torch.cat(
                        [self.key_prev_tail[:, -(klen - n_rest):], key.detach()],
                        dim=1)[:, -n_rest:]

        attn_state['beta'] = beta
        attn_state['p_choose'] = p_choose

        return cv, alpha, attn_state
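In the `hard` branch of this example, `hard_chunkwise_attention` spreads each detected boundary over a window of `w` frames ending at that boundary. A simplified single-head sketch of that behaviour (the real helper additionally handles multiple MA/CA heads, the sharpening factor, and shared CA heads, so this is only an illustration):

import torch


def hard_chunkwise_attention_sketch(alpha, e_ca, w):
    # alpha: `[B, 1, qlen, klen]` hard (0/1) monotonic attention, at most one 1 per query
    # e_ca:  `[B, 1, qlen, klen]` chunk energies
    # w:     chunk width; attend over the w frames ending at the boundary
    bs, _, qlen, klen = alpha.size()
    has_bd = (alpha.sum(-1, keepdim=True) > 0).to(e_ca.dtype)  # `[B, 1, qlen, 1]`
    bd = alpha.argmax(-1, keepdim=True)                        # boundary index per query
    ks = torch.arange(klen, device=alpha.device).view(1, 1, 1, klen)
    window = (ks <= bd) & (ks > bd - w)
    beta = torch.softmax(e_ca.masked_fill(~window, float('-inf')), dim=-1)
    return beta * has_bd  # zero rows for queries without a boundary


# boundary at frame 6, chunk width 4 -> nonzero weights over frames 3..6
alpha = torch.zeros(1, 1, 1, 10)
alpha[..., 6] = 1.0
beta = hard_chunkwise_attention_sketch(alpha, torch.randn(1, 1, 1, 10), w=4)
print(beta.nonzero()[:, -1].tolist())  # [3, 4, 5, 6]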
Example #4
    def forward(self,
                key,
                value,
                query,
                mask,
                aw_prev=None,
                aw_lower=None,
                cache=False,
                mode='',
                trigger_points=None,
                eps_wait=-1,
                streaming=False):
        """Forward pass.

        Args:
            key (FloatTensor): `[B, klen, kdim]`
            value (FloatTensor): `[B, klen, vdim]`
            query (FloatTensor): `[B, qlen, qdim]`
            mask (ByteTensor): `[B, qlen, klen]`
            aw_prev: dummy interface
            aw_lower: dummy interface
            cache (bool): cache key, value, and mask
            mode: dummy interface for MoChA/MMA
            trigger_points: dummy interface for MoChA/MMA
            eps_wait: dummy interface for MMA
            streaming: dummy interface for streaming attention
        Returns:
            cv (FloatTensor): `[B, qlen, vdim]`
            aw (FloatTensor): `[B, H, qlen, klen]`
            attn_state (dict): dummy interface

        """
        bs, klen = key.size()[:2]
        qlen = query.size(1)
        attn_state = {}

        # Pre-computation of encoder-side features for computing scores
        if self.key is None or not cache:
            self.key = self.w_key(key).view(bs, -1, self.n_heads,
                                            self.d_k)  # `[B, klen, H, d_k]`
            self.value = self.w_value(value).view(
                bs, -1, self.n_heads, self.d_k)  # `[B, klen, H, d_k]`
            if mask is not None:
                self.mask = mask.unsqueeze(3).repeat([1, 1, 1, self.n_heads])
                mask_size = (bs, qlen, klen, self.n_heads)
                assert self.mask.size() == mask_size, (self.mask.size(),
                                                       mask_size)
            else:
                self.mask = None

        key = self.key
        query = self.w_query(query).view(bs, -1, self.n_heads,
                                         self.d_k)  # `[B, qlen, H, d_k]`

        if self.atype == 'scaled_dot':
            e = torch.einsum("bihd,bjhd->bijh", (query, key)) / self.scale
        elif self.atype == 'add':
            e = self.v(
                torch.tanh(key[:, None] + query[:, :, None]).view(
                    bs, qlen, klen, -1))
        # e: `[B, qlen, klen, H]`

        # Compute attention weights
        if self.mask is not None:
            NEG_INF = float(
                np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
            e = e.masked_fill_(self.mask == 0, NEG_INF)  # `[B, qlen, klen, H]`
        aw = torch.softmax(e, dim=2)
        aw = self.dropout_attn(aw)
        aw_masked = aw.clone()

        # mask out each head independently (HeadDrop)
        if self.dropout_head > 0 and self.training:
            aw_masked = aw_masked.permute(0, 3, 1, 2)
            aw_masked = headdrop(aw_masked, self.n_heads,
                                 self.dropout_head)  # `[B, H, qlen, klen]`
            aw_masked = aw_masked.permute(0, 2, 3, 1)

        cv = torch.einsum("bijh,bjhd->bihd",
                          (aw_masked, self.value))  # `[B, qlen, H, d_k]`
        cv = cv.contiguous().view(bs, -1, self.n_heads *
                                  self.d_k)  # `[B, qlen, H * d_k]`
        cv = self.w_out(cv)
        aw = aw.permute(0, 3, 1, 2)  # `[B, H, qlen, klen]`

        return cv, aw, attn_state
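Example #4 interleaves projection caching, masking, softmax, and HeadDrop; stripped of the caching and the learned projections, the scaled-dot core reduces to the following self-contained sketch (function name and dummy shapes are illustrative assumptions):

import torch


def scaled_dot_attention_sketch(q, k, v, mask=None):
    # q: `[B, qlen, H, d_k]`, k/v: `[B, klen, H, d_k]`, mask: `[B, qlen, klen, H]` (1 = keep)
    e = torch.einsum("bihd,bjhd->bijh", (q, k)) / (q.size(-1) ** 0.5)  # `[B, qlen, klen, H]`
    if mask is not None:
        e = e.masked_fill(mask == 0, torch.finfo(e.dtype).min)
    aw = torch.softmax(e, dim=2)                    # normalize over the key axis
    cv = torch.einsum("bijh,bjhd->bihd", (aw, v))   # `[B, qlen, H, d_k]`
    return cv, aw


# dummy shapes: B=2, qlen=3, klen=5, H=4, d_k=8
q = torch.randn(2, 3, 4, 8)
k = torch.randn(2, 5, 4, 8)
v = torch.randn(2, 5, 4, 8)
cv, aw = scaled_dot_attention_sketch(q, k, v)
print(cv.size(), aw.size())  # torch.Size([2, 3, 4, 8]) torch.Size([2, 3, 5, 4])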