Example #1
def replace_inf_with_zero(x):
    return th.masked_fill(x, th.isinf(x), 0)
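A minimal usage sketch (assuming `th` is the usual `import torch as th` alias, which the snippet above implies but does not show):

import torch as th

x = th.tensor([1.0, float('inf'), -float('inf'), 2.0])
print(replace_inf_with_zero(x))   # tensor([1., 0., 0., 2.])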
Example #2
    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        '''
        The `attention_mask` is changed in BertModel.forward from 0, 1, 2 to
            -ve: no attention
              0: local attention
            +ve: global attention
        '''
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
            key_padding_mask = attention_mask < 0
            extra_attention_mask = attention_mask > 0
            remove_from_windowed_attention_mask = attention_mask != 0

            num_extra_indices_per_batch = extra_attention_mask.long().sum(
                dim=1)
            max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
            has_same_length_extra_indices = (
                num_extra_indices_per_batch == max_num_extra_indices_per_batch
            ).all()
        else:
            # avoid NameError below when no attention_mask is passed
            remove_from_windowed_attention_mask = None
            extra_attention_mask = None
            key_padding_mask = None
        hidden_states = hidden_states.transpose(0, 1)
        seq_len, bsz, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        q /= math.sqrt(self.head_dim)

        q = q.view(seq_len, bsz, self.num_heads,
                   self.head_dim).transpose(0, 1).contiguous().float()
        k = k.view(seq_len, bsz, self.num_heads,
                   self.head_dim).transpose(0, 1).contiguous().float()
        # attn_weights = (bsz, seq_len, num_heads, window*2+1)
        attn_weights = diagonaled_mm_tvm(q, k, self.attention_window,
                                         self.attention_dilation, False, 0,
                                         False)
        mask_invalid_locations(attn_weights, self.attention_window,
                               self.attention_dilation, False)
        if remove_from_windowed_attention_mask is not None:
            # This implementation is fast and takes very little memory because the added
            # num_heads and hidden_size dims have size 1: (bsz, seq_len) -> (bsz, seq_len, 1, 1)
            remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
                dim=-1).unsqueeze(dim=-1)
            # cast to float/half then replace 1's with -inf
            float_mask = remove_from_windowed_attention_mask.type_as(
                q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
            repeat_size = 1 if isinstance(
                self.attention_dilation, int) else len(self.attention_dilation)
            float_mask = float_mask.repeat(1, 1, repeat_size, 1)
            ones = float_mask.new_ones(
                size=float_mask.size())  # tensor of ones
            # diagonal mask with zeros everywhere and -inf in place of padding
            d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window,
                                       self.attention_dilation, False, 0,
                                       False)
            attn_weights += d_mask
        assert list(attn_weights.size()) == [
            bsz, seq_len, self.num_heads, self.attention_window * 2 + 1
        ]

        # the extra attention
        if extra_attention_mask is not None:
            if has_same_length_extra_indices:
                # a simpler implementation for efficiency
                # k = (bsz, seq_len, num_heads, head_dim)
                selected_k = k.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                        bsz, max_num_extra_indices_per_batch, self.num_heads,
                        self.head_dim)
                # selected_k = (bsz, extra_attention_count, num_heads, head_dim)
                # selected_attn_weights = (bsz, seq_len, num_heads, extra_attention_count)
                selected_attn_weights = torch.einsum('blhd,bshd->blhs',
                                                     (q, selected_k))
            else:
                # since the number of extra attention indices varies across
                # the batch, we need to process each element of the batch
                # individually
                flat_selected_k = k.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
                selected_attn_weights = torch.ones(
                    bsz,
                    seq_len,
                    self.num_heads,
                    max_num_extra_indices_per_batch,
                    device=k.device,
                    dtype=k.dtype)
                selected_attn_weights.fill_(-10000.0)
                start = 0
                for i in range(bsz):
                    end = start + num_extra_indices_per_batch[
                        i] * self.num_heads * self.head_dim
                    # the selected entries for this batch element
                    i_selected_k = flat_selected_k[start:end].view(
                        -1, self.num_heads, self.head_dim)
                    # (seq_len, num_heads, num extra indices)
                    i_selected_attn_weights = torch.einsum(
                        'lhd,shd->lhs', (q[i, :, :, :], i_selected_k))
                    selected_attn_weights[
                        i, :, :, :num_extra_indices_per_batch[
                            i]] = i_selected_attn_weights
                    start = end

            # concat to attn_weights
            # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
            attn_weights = torch.cat((selected_attn_weights, attn_weights),
                                     dim=-1)

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        if key_padding_mask is not None:
            # softmax sometimes inserts NaN if all positions are masked, replace them with 0
            attn_weights_float = torch.masked_fill(
                attn_weights_float,
                key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)

        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                               p=self.dropout,
                               training=self.training)
        v = v.view(seq_len, bsz, self.num_heads,
                   self.head_dim).transpose(0, 1).contiguous().float()
        attn = 0
        if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
            selected_attn_probs = attn_probs.narrow(
                -1, 0, max_num_extra_indices_per_batch)
            if has_same_length_extra_indices:
                selected_v = v.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                        bsz, max_num_extra_indices_per_batch, self.num_heads,
                        self.head_dim)
            else:
                flat_selected_v = v.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
                # don't worry about masking since this is multiplied by attn_probs, and masking above
                # before softmax will remove masked entries
                selected_v = torch.zeros(bsz,
                                         max_num_extra_indices_per_batch,
                                         self.num_heads,
                                         self.head_dim,
                                         device=v.device,
                                         dtype=v.dtype)
                start = 0
                for i in range(bsz):
                    end = start + num_extra_indices_per_batch[
                        i] * self.num_heads * self.head_dim
                    i_selected_v = flat_selected_v[start:end].view(
                        -1, self.num_heads, self.head_dim)
                    selected_v[
                        i, :
                        num_extra_indices_per_batch[i], :, :] = i_selected_v
                    start = end
            attn = torch.einsum('blhs,bshd->blhd',
                                (selected_attn_probs, selected_v))
            attn_probs = attn_probs.narrow(
                -1, max_num_extra_indices_per_batch,
                attn_probs.size(-1) -
                max_num_extra_indices_per_batch).contiguous()

        attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window,
                                  self.attention_dilation, True, 0, False)
        attn = attn.type_as(hidden_states)
        assert list(
            attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
        attn = attn.transpose(0, 1).reshape(seq_len, bsz,
                                            embed_dim).contiguous()

        # For this case, we'll just recompute the attention for these indices
        # and overwrite the attn tensor. TODO: remove the redundant computation
        if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
            if has_same_length_extra_indices:
                # query = (seq_len, bsz, dim)
                # extra_attention_mask = (bsz, seq_len)
                # selected_query = (max_num_extra_indices_per_batch, bsz, embed_dim)
                selected_hidden_states = hidden_states.masked_select(
                    extra_attention_mask.transpose(0, 1).unsqueeze(-1)).view(
                        max_num_extra_indices_per_batch, bsz, embed_dim)
                # if *_proj_full exists use them, otherwise default to *_proj
                q = self.query_global(selected_hidden_states)
                k = self.key_global(hidden_states)
                v = self.value_global(hidden_states)
                q /= math.sqrt(self.head_dim)

                q = q.contiguous().view(
                    max_num_extra_indices_per_batch, bsz * self.num_heads,
                    self.head_dim
                ).transpose(
                    0, 1
                )  # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
                k = k.contiguous().view(
                    -1, bsz * self.num_heads, self.head_dim).transpose(
                        0, 1)  # bsz * self.num_heads, seq_len, head_dim)
                v = v.contiguous().view(
                    -1, bsz * self.num_heads, self.head_dim).transpose(
                        0, 1)  # bsz * self.num_heads, seq_len, head_dim)
                attn_weights = torch.bmm(q, k.transpose(1, 2))
                assert list(attn_weights.size()) == [
                    bsz * self.num_heads, max_num_extra_indices_per_batch,
                    seq_len
                ]
                if key_padding_mask is not None:
                    attn_weights = attn_weights.view(
                        bsz, self.num_heads, max_num_extra_indices_per_batch,
                        seq_len)
                    attn_weights = attn_weights.masked_fill(
                        key_padding_mask.unsqueeze(1).unsqueeze(2),
                        float('-inf'),
                    )
                    attn_weights = attn_weights.view(
                        bsz * self.num_heads, max_num_extra_indices_per_batch,
                        seq_len)
                attn_weights_float = F.softmax(attn_weights, dim=-1)
                attn_probs = F.dropout(
                    attn_weights_float.type_as(attn_weights),
                    p=self.dropout,
                    training=self.training)
                selected_attn = torch.bmm(attn_probs, v)
                assert list(selected_attn.size()) == [
                    bsz * self.num_heads, max_num_extra_indices_per_batch,
                    self.head_dim
                ]
                selected_attn = selected_attn.transpose(
                    0, 1).contiguous().view(max_num_extra_indices_per_batch *
                                            bsz * embed_dim)

                # now update attn by filling in the relevant indices with selected_attn
                # masked_fill_ only allows floats as values so this doesn't work
                # attn.masked_fill_(extra_attention_mask.transpose(0, 1).unsqueeze(-1), selected_attn)
                attn[extra_attention_mask.transpose(0, 1).unsqueeze(-1).repeat(
                    (1, 1, embed_dim))] = selected_attn
            else:
                raise ValueError  # not implemented

        context_layer = attn.transpose(0, 1)
        if self.output_attentions:
            if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
                # With global attention, return global attention probabilities only
                # batch_size x num_heads x num_global_attention_tokens x sequence_length
                # which is the attention weights from tokens with global attention to all tokens
                # It does not return local attention
                attn_weights = attn_weights.view(
                    bsz, self.num_heads, max_num_extra_indices_per_batch,
                    seq_len)
            else:
                # without global attention, return local attention probabilities
                # batch_size x num_heads x sequence_length x window_size
                # which is the attention weights of every token attending to its neighbours
                attn_weights = attn_weights.permute(0, 2, 1, 3)
        outputs = (context_layer,
                   attn_weights) if self.output_attentions else (
                       context_layer, )
        return outputs
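The core masking idiom of this example, reduced to a standalone sketch (toy shapes, not the Longformer kernels themselves): a boolean padding mask is converted into an additive float mask of -10000.0 with `masked_fill`, added to the raw attention scores, and the softmax then assigns the masked keys (near-)zero probability.

import torch
import torch.nn.functional as F

scores = torch.randn(2, 4, 4)                             # (bsz, query_len, key_len) raw attention scores
padding = torch.tensor([[False, False, False, True],
                        [False, False, True, True]])      # True = padded key position
float_mask = torch.zeros_like(scores).masked_fill(padding[:, None, :], -10000.0)
probs = F.softmax(scores + float_mask, dim=-1)            # padded keys get ~0 probability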
Example #3
    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        '''
        The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
            -ve: no attention
              0: local attention
            +ve: global attention
        '''
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
            key_padding_mask = attention_mask < 0
            extra_attention_mask = attention_mask > 0
            remove_from_windowed_attention_mask = attention_mask != 0

            num_extra_indices_per_batch = extra_attention_mask.long().sum(
                dim=1)
            max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
            if max_num_extra_indices_per_batch <= 0:
                extra_attention_mask = None
            else:
                # To support the case of variable number of global attention in the rows of a batch,
                # we use the following three selection masks to select global attention embeddings
                # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
                # 1) selecting embeddings that correspond to global attention
                extra_attention_mask_nonzeros = extra_attention_mask.nonzero(
                    as_tuple=True)
                zero_to_max_range = torch.arange(
                    0,
                    max_num_extra_indices_per_batch,
                    device=num_extra_indices_per_batch.device)
                # mask indicating which values are actually going to be padding
                selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(
                    dim=-1)
                # 2) location of the non-padding values in the selected global attention
                selection_padding_mask_nonzeros = selection_padding_mask.nonzero(
                    as_tuple=True)
                # 3) location of the padding values in the selected global attention
                selection_padding_mask_zeros = (
                    selection_padding_mask == 0).nonzero(as_tuple=True)
        else:
            remove_from_windowed_attention_mask = None
            extra_attention_mask = None
            key_padding_mask = None

        hidden_states = hidden_states.transpose(0, 1)
        seq_len, bsz, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        q /= math.sqrt(self.head_dim)

        q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        # attn_weights = (bsz, seq_len, num_heads, window*2+1)
        if self.attention_mode == 'tvm':
            q = q.float().contiguous()
            k = k.float().contiguous()
            attn_weights = diagonaled_mm_tvm(q, k, self.attention_window,
                                             self.attention_dilation, False, 0,
                                             False)
        else:  # "sliding_chunks"
            attn_weights = sliding_chunks_matmul_qk(q,
                                                    k,
                                                    self.attention_window,
                                                    padding_value=0)
        mask_invalid_locations(attn_weights, self.attention_window,
                               self.attention_dilation, False)
        if remove_from_windowed_attention_mask is not None:
            # This implementation is fast and takes very little memory because the added
            # num_heads and hidden_size dims have size 1: (bsz, seq_len) -> (bsz, seq_len, 1, 1)
            remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
                dim=-1).unsqueeze(dim=-1)
            # cast to float/half then replace 1's with -inf
            float_mask = remove_from_windowed_attention_mask.type_as(
                q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
            repeat_size = 1 if isinstance(
                self.attention_dilation, int) else len(self.attention_dilation)
            float_mask = float_mask.repeat(1, 1, repeat_size, 1)
            ones = float_mask.new_ones(
                size=float_mask.size())  # tensor of ones
            # diagonal mask with zeros everywhere and -inf in place of padding
            if self.attention_mode == 'tvm':
                d_mask = diagonaled_mm_tvm(ones, float_mask,
                                           self.attention_window,
                                           self.attention_dilation, False, 0,
                                           False)
            else:
                d_mask = sliding_chunks_matmul_qk(ones,
                                                  float_mask,
                                                  self.attention_window,
                                                  padding_value=0)
            attn_weights += d_mask
        assert list(attn_weights.size()) == [
            bsz, seq_len, self.num_heads, self.attention_window * 2 + 1
        ]

        # the extra attention
        if extra_attention_mask is not None:
            selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch,
                                     self.num_heads, self.head_dim)
            selected_k[selection_padding_mask_nonzeros] = k[
                extra_attention_mask_nonzeros]
            # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
            selected_attn_weights = torch.einsum('blhd,bshd->blhs',
                                                 (q, selected_k))
            selected_attn_weights[selection_padding_mask_zeros[0], :, :,
                                  selection_padding_mask_zeros[1]] = -10000
            # concat to attn_weights
            # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
            attn_weights = torch.cat((selected_attn_weights, attn_weights),
                                     dim=-1)

        attn_weights_float = F.softmax(
            attn_weights, dim=-1,
            dtype=torch.float32)  # use fp32 for numerical stability
        if key_padding_mask is not None:
            # softmax sometimes inserts NaN if all positions are masked, replace them with 0
            attn_weights_float = torch.masked_fill(
                attn_weights_float,
                key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)

        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                               p=self.dropout,
                               training=self.training)
        v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        attn = 0
        if extra_attention_mask is not None:
            selected_attn_probs = attn_probs.narrow(
                -1, 0, max_num_extra_indices_per_batch)
            selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch,
                                     self.num_heads, self.head_dim)
            selected_v[selection_padding_mask_nonzeros] = v[
                extra_attention_mask_nonzeros]
            # use `matmul` because `einsum` crashes sometimes with fp16
            # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
            attn = torch.matmul(
                selected_attn_probs.transpose(1, 2),
                selected_v.transpose(
                    1, 2).type_as(selected_attn_probs)).transpose(1, 2)
            attn_probs = attn_probs.narrow(
                -1, max_num_extra_indices_per_batch,
                attn_probs.size(-1) -
                max_num_extra_indices_per_batch).contiguous()

        if self.attention_mode == 'tvm':
            v = v.float().contiguous()
            attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window,
                                      self.attention_dilation, True, 0, False)
        else:  # "sliding_chunks"
            attn += sliding_chunks_matmul_pv(attn_probs, v,
                                             self.attention_window)

        attn = attn.type_as(hidden_states)
        assert list(
            attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
        attn = attn.transpose(0, 1).reshape(seq_len, bsz,
                                            embed_dim).contiguous()

        # For this case, we'll just recompute the attention for these indices
        # and overwrite the attn tensor. TODO: remove the redundant computation
        if extra_attention_mask is not None:
            selected_hidden_states = hidden_states.new_zeros(
                max_num_extra_indices_per_batch, bsz, embed_dim)
            selected_hidden_states[
                selection_padding_mask_nonzeros[::-1]] = hidden_states[
                    extra_attention_mask_nonzeros[::-1]]

            q = self.query_global(selected_hidden_states)
            k = self.key_global(hidden_states)
            v = self.value_global(hidden_states)
            q /= math.sqrt(self.head_dim)

            q = q.contiguous().view(
                max_num_extra_indices_per_batch, bsz * self.num_heads,
                self.head_dim
            ).transpose(
                0, 1
            )  # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
            k = k.contiguous().view(
                -1, bsz * self.num_heads, self.head_dim).transpose(
                    0, 1)  # bsz * self.num_heads, seq_len, head_dim)
            v = v.contiguous().view(
                -1, bsz * self.num_heads, self.head_dim).transpose(
                    0, 1)  # bsz * self.num_heads, seq_len, head_dim)
            attn_weights = torch.bmm(q, k.transpose(1, 2))
            assert list(attn_weights.size()) == [
                bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len
            ]

            attn_weights = attn_weights.view(bsz, self.num_heads,
                                             max_num_extra_indices_per_batch,
                                             seq_len)
            attn_weights[selection_padding_mask_zeros[0], :,
                         selection_padding_mask_zeros[1], :] = -10000.0
            if key_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -10000.0,
                )
            attn_weights = attn_weights.view(bsz * self.num_heads,
                                             max_num_extra_indices_per_batch,
                                             seq_len)
            attn_weights_float = F.softmax(
                attn_weights, dim=-1,
                dtype=torch.float32)  # use fp32 for numerical stability
            attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                                   p=self.dropout,
                                   training=self.training)
            selected_attn = torch.bmm(attn_probs, v)
            assert list(selected_attn.size()) == [
                bsz * self.num_heads, max_num_extra_indices_per_batch,
                self.head_dim
            ]

            selected_attn_4d = selected_attn.view(
                bsz, self.num_heads, max_num_extra_indices_per_batch,
                self.head_dim)
            nonzero_selected_attn = selected_attn_4d[
                selection_padding_mask_nonzeros[0], :,
                selection_padding_mask_nonzeros[1]]
            attn[
                extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
                    len(selection_padding_mask_nonzeros[0]),
                    -1).type_as(hidden_states)

        context_layer = attn.transpose(0, 1)
        if self.output_attentions:
            if extra_attention_mask is not None:
                # With global attention, return global attention probabilities only
                # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
                # which is the attention weights from tokens with global attention to all tokens
                # It does not return local attention
                # In case of a variable number of global attention tokens in the rows of a batch,
                # attn_weights are padded with -10000.0 attention scores
                attn_weights = attn_weights.view(
                    bsz, self.num_heads, max_num_extra_indices_per_batch,
                    seq_len)
            else:
                # without global attention, return local attention probabilities
                # batch_size x num_heads x sequence_length x window_size
                # which is the attention weights of every token attending to its neighbours
                attn_weights = attn_weights.permute(0, 2, 1, 3)
        outputs = (context_layer,
                   attn_weights) if self.output_attentions else (
                       context_layer, )
        return outputs
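A small sketch of the padding-selection trick used above (hypothetical toy shapes): comparing an arange against the per-row counts yields a boolean mask whose nonzero indices scatter a variable number of selected vectors per row into one zero-padded tensor.

import torch

k = torch.randn(2, 5, 3)                                    # (bsz, seq_len, head_dim), toy sizes
extra_attention_mask = torch.tensor([[True, False, True, False, False],
                                     [True, False, False, False, False]])
num_extra = extra_attention_mask.long().sum(dim=1)           # tensor([2, 1])
max_extra = int(num_extra.max())
selection_padding_mask = torch.arange(max_extra) < num_extra.unsqueeze(dim=-1)

selected_k = k.new_zeros(2, max_extra, 3)
selected_k[selection_padding_mask.nonzero(as_tuple=True)] = k[
    extra_attention_mask.nonzero(as_tuple=True)]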
Example #4
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=False,
    ):
        """
        :class:`LongformerSelfAttention` expects `len(hidden_states)` to be multiple of `attention_band`. Padding to
        `attention_band` happens in :meth:`LongformerModel.forward` to avoid redoing the padding on each layer.

        The `attention_mask` is changed in :meth:`LongformerModel.forward` from 0, 1, 2 to:

            * -10000: no attention
            * 0: local attention
            * +10000: global attention
        """
        hidden_states = hidden_states.transpose(0, 1)

        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)

        seq_len, batch_size, embed_dim = hidden_states.size()
        assert (
            embed_dim == self.embed_dim
        ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

        # normalize query
        query_vectors /= math.sqrt(self.head_dim)

        query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads,
                                           self.head_dim).transpose(0, 1)
        key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads,
                                       self.head_dim).transpose(0, 1)

        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size)

        # values to pad for attention probs
        remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None,
                                                                    None]

        # cast to fp32/fp16 then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(
            query_vectors).masked_fill(remove_from_windowed_attention_mask,
                                       -10000.0)
        # diagonal mask with zeros everywhere and -inf inplace of padding
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_mask.new_ones(size=float_mask.size()), float_mask,
            self.one_sided_attn_window_size)

        # pad local attention probs
        attn_scores += diagonal_mask

        assert list(attn_scores.size()) == [
            batch_size,
            seq_len,
            self.num_heads,
            self.one_sided_attn_window_size * 2 + 1,
        ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"

        # compute local attention probs from global attention keys and concat over window dim
        if is_global_attn:
            # compute global attn indices required throughout the forward fn
            (
                max_num_global_attn_indices,
                is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero,
            ) = self._get_global_attn_indices(is_index_global_attn)
            # calculate global attn probs from global key

            global_key_attn_scores = self._concat_with_global_key_attn_probs(
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=
                is_local_index_no_global_attn_nonzero,
            )
            # concat to local_attn_probs
            # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
            attn_scores = torch.cat((global_key_attn_scores, attn_scores),
                                    dim=-1)

            # free memory
            del global_key_attn_scores

        attn_probs = F.softmax(
            attn_scores, dim=-1,
            dtype=torch.float32)  # use fp32 for numerical stability

        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None,
                                                                   None], 0.0)
        attn_probs = attn_probs.type_as(attn_scores)

        # free memory
        del attn_scores

        # apply dropout
        attn_probs = F.dropout(attn_probs,
                               p=self.config.attention_probs_dropout_prob,
                               training=self.training)

        value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads,
                                           self.head_dim).transpose(0, 1)

        # compute local attention output with global attention value and add
        if is_global_attn:
            # compute sum of global and local attn
            attn_output = self._compute_attn_output_with_global_indices(
                value_vectors=value_vectors,
                attn_probs=attn_probs,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=
                is_local_index_global_attn_nonzero,
            )
        else:
            # compute local attn only
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size)

        assert attn_output.size() == (batch_size, seq_len, self.num_heads,
                                      self.head_dim), "Unexpected size"
        attn_output = attn_output.transpose(0,
                                            1).reshape(seq_len, batch_size,
                                                       embed_dim).contiguous()

        # compute value for global attention and overwrite to attention output
        # TODO: remove the redundant computation
        if is_global_attn:
            global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                hidden_states=hidden_states,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_local_index_global_attn_nonzero=
                is_local_index_global_attn_nonzero,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=
                is_local_index_no_global_attn_nonzero,
                is_index_masked=is_index_masked,
            )

            # get only non zero global attn output
            nonzero_global_attn_output = global_attn_output[
                is_local_index_global_attn_nonzero[0], :,
                is_local_index_global_attn_nonzero[1]]

            # overwrite values with global attention
            attn_output[
                is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
                    len(is_local_index_global_attn_nonzero[0]), -1)
            # The attention weights for tokens with global attention are
            # just filler values, they were never used to compute the output.
            # Fill with 0 now, the correct values are in 'global_attn_probs'.
            attn_probs[is_index_global_attn_nonzero] = 0

        outputs = (attn_output.transpose(0, 1), )

        if output_attentions:
            outputs += (attn_probs, )

        return outputs + (global_attn_probs, ) if (
            is_global_attn and output_attentions) else outputs
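A standalone sketch of the post-softmax cleanup used above (toy shapes, assumed names): query positions that are themselves padding end up with fully masked rows, whose softmax is NaN, and `masked_fill` overwrites those probabilities with 0 so the NaNs cannot propagate.

import torch
import torch.nn.functional as F

scores = torch.randn(1, 3, 2, 5)                           # (batch, seq_len, num_heads, window)
is_index_masked = torch.tensor([[False, False, True]])     # the last query position is padding
scores[0, 2] = float('-inf')                               # its row is fully masked -> softmax yields NaN
probs = F.softmax(scores, dim=-1, dtype=torch.float32)
probs = torch.masked_fill(probs, is_index_masked[:, :, None, None], 0.0)   # NaNs replaced by 0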
Example #5
    def forward(
        self,
        query: torch.Tensor,
        keys: List[torch.Tensor],
        ref_point: torch.Tensor,
        query_mask: Optional[torch.Tensor] = None,
        key_masks: Optional[List[torch.Tensor]] = None,
    ):
        """
        :param key_masks:
        :param query_mask:
        :param query: B, H, W, C
        :param keys: List[B, H, W, C]
        :param ref_point: B, H, W, 2
        :return:
        """
        if key_masks is None:
            key_masks = [None] * len(keys)

        assert len(keys) == self.scales

        attns = {'attns': None, 'offsets': None}

        nbatches, query_height, query_width, _ = query.shape

        # B, H, W, C
        query = self.q_proj(query)

        # B, H, W, 2MLK
        offset = self.offset_proj(query)
        # B, H, W, M, 2LK
        offset = offset.view(nbatches, query_height, query_width, self.h, -1)

        # B, H, W, MLK
        A = self.A_proj(query)

        # B, H, W, 1, mask before softmax
        if query_mask is not None:
            query_mask_ = query_mask.unsqueeze(dim=-1)
            _, _, _, mlk = A.shape
            query_mask_ = query_mask_.expand(nbatches, query_height,
                                             query_width, mlk)
            A = torch.masked_fill(A, mask=query_mask_, value=float('-inf'))

        # B, H, W, M, LK
        A = A.view(nbatches, query_height, query_width, self.h, -1)
        A = F.softmax(A, dim=-1)

        # mask nan position
        if query_mask is not None:
            # B, H, W, 1, 1
            query_mask_ = query_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
            A = torch.masked_fill(A, query_mask_.expand_as(A), 0.0)

        if self.need_attn:
            attns['attns'] = A
            attns['offsets'] = offset

        offset = offset.view(nbatches, query_height, query_width, self.h,
                             self.scales, self.k, 2)
        offset = offset.permute(0, 3, 4, 5, 1, 2, 6).contiguous()
        # B*M, L, K, H, W, 2
        offset = offset.view(nbatches * self.h, self.scales, self.k,
                             query_height, query_width, 2)

        A = A.permute(0, 3, 1, 2, 4).contiguous()
        # B*M, H*W, LK
        A = A.view(nbatches * self.h, query_height * query_width, -1)

        scale_features = []
        for l in range(self.scales):
            feat_map = keys[l]
            _, h, w, _ = feat_map.shape

            key_mask = key_masks[l]

            # B, H, W, 2
            reversed_ref_point = restore_scale(height=h,
                                               width=w,
                                               ref_point=ref_point)

            # B, H, W, 2 -> B*M, H, W, 2
            reversed_ref_point = reversed_ref_point.repeat(self.h, 1, 1, 1)

            # B, h, w, M, C_v
            scale_feature = self.k_proj(feat_map).view(nbatches, h, w, self.h,
                                                       self.d_k)

            if key_mask is not None:
                # B, h, w, 1, 1
                key_mask = key_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
                key_mask = key_mask.expand(nbatches, h, w, self.h, self.d_k)
                scale_feature = torch.masked_fill(scale_feature,
                                                  mask=key_mask,
                                                  value=0)

            # B, M, C_v, h, w
            scale_feature = scale_feature.permute(0, 3, 4, 1, 2).contiguous()
            # B*M, C_v, h, w
            scale_feature = scale_feature.view(-1, self.d_k, h, w)

            k_features = []

            for k in range(self.k):
                points = reversed_ref_point + offset[:, l, k, :, :, :]
                vgrid_x = 2.0 * points[:, :, :, 0] / max(w - 1, 1) - 1.0
                vgrid_y = 2.0 * points[:, :, :, 1] / max(h - 1, 1) - 1.0
                vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)

                # B*M, C_v, H, W
                feat = F.grid_sample(scale_feature,
                                     vgrid_scaled,
                                     mode='bilinear',
                                     padding_mode='zeros')
                k_features.append(feat)
            # B*M, k, C_v, H, W
            k_features = torch.stack(k_features, dim=1)
            scale_features.append(k_features)

        # B*M, L, K, C_v, H, W
        scale_features = torch.stack(scale_features, dim=1)

        # B*M, H*W, C_v, LK
        scale_features = scale_features.permute(0, 4, 5, 3, 1, 2).contiguous()
        scale_features = scale_features.view(nbatches * self.h,
                                             query_height * query_width,
                                             self.d_k, -1)

        # B*M, H*W, C_v
        feat = torch.einsum('nlds, nls -> nld', scale_features, A)

        # B*M, H*W, C_v -> B, M, H, W, C_v
        feat = feat.view(nbatches, self.h, query_height, query_width, self.d_k)
        # B, M, H, W, C_v -> B, H, W, M, C_v
        feat = feat.permute(0, 2, 3, 1, 4).contiguous()
        # B, H, W, M, C_v -> B, H, W, M * C_v
        feat = feat.view(nbatches, query_height, query_width,
                         self.d_k * self.h)

        feat = self.wm_proj(feat)
        if self.dropout:
            feat = self.dropout(feat)

        return feat, attns
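A reduced sketch of the sampling step above (toy sizes, not the full deformable-attention module): absolute pixel coordinates are normalized to [-1, 1] and passed to `F.grid_sample` to gather bilinearly interpolated features; `align_corners=True` matches the (w - 1)/(h - 1) normalization used in the example.

import torch
import torch.nn.functional as F

feat_map = torch.randn(1, 8, 16, 16)                 # (B*M, C_v, h, w) toy feature map
points = torch.rand(1, 16, 16, 2) * 15.0             # absolute (x, y) sampling locations
h, w = feat_map.shape[2], feat_map.shape[3]
vgrid_x = 2.0 * points[..., 0] / max(w - 1, 1) - 1.0
vgrid_y = 2.0 * points[..., 1] / max(h - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)   # (B*M, H, W, 2) grid in [-1, 1]
sampled = F.grid_sample(feat_map, vgrid_scaled, mode='bilinear',
                        padding_mode='zeros', align_corners=True)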
Example #6
    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, q_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=2)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, _, kv_length = key_layer.shape

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        # [batch_size * num_heads, q_length, kv_length]
        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
        matmul_result = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs
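The masking idiom above in isolation (a sketch with toy shapes, not the Bloom layer itself): scores are optionally upcast from fp16, masked with the minimum representable value of their dtype instead of a hard-coded -inf, and softmaxed in fp32 before casting back.

import torch
import torch.nn.functional as F

attention_scores = torch.randn(2, 4, 4).to(torch.float16)
attention_mask = torch.zeros(2, 4, 4, dtype=torch.bool)
attention_mask[:, :, -1] = True                      # mask out the last key position

input_dtype = attention_scores.dtype
if input_dtype == torch.float16:
    attention_scores = attention_scores.to(torch.float)    # avoid fp16 range issues in softmax
attn_weights = torch.masked_fill(attention_scores, attention_mask,
                                 torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)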
Example #7
def _zero_padding_tokens(response_tokens):
    mask = response_tokens == PADDING_TOKEN
    assert (not (mask[:, :, 1:] < mask[:, :, :-1]).any().item()
            ), f"Padding tokens not a suffix {to_numpy(response_tokens)}"
    return mask, torch.masked_fill(response_tokens, mask, 0)
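A usage sketch under the assumption that `PADDING_TOKEN` is an integer id defined at module level (the value 99 here is purely illustrative) and that padding forms a suffix along the last dimension; `to_numpy` is only needed if the assertion fails.

import torch

PADDING_TOKEN = 99                                    # hypothetical padding id

response_tokens = torch.tensor([[[5, 7, 99, 99]],
                                [[3, 99, 99, 99]]])   # padding is a suffix
mask, zeroed = _zero_padding_tokens(response_tokens)
# `mask` marks the padding positions; `zeroed` has them replaced by 0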
Example #8
    def forward(self,
                x,
                target_domain,
                mix_output: bool = False,
                used_domain_list: list = None,
                mix_weight: torch.Tensor = None,
                domain_mask: torch.Tensor = None):

        """

        Make sure the order of the weights in `mix_weight` matches the order of `used_domain_list`.

        :param x: [B, L, H]
        :param target_domain: a string, like 'news', 'book', 'bible'
        :param mix_output:
        :param used_domain_list: used domain list, like ['book', 'iwslt']
        :param mix_weight: [B, L, D]
        :param domain_mask: [D]
        :return:
        """

        # if we only use one adapter
        if mix_output is False:
            if target_domain == 'news':
                return x
            else:
                return self.sublayer_connection_for_adapter[target_domain](x, self.adapter_layers[target_domain])

        # else we should mix the current adapters output
        else:

            assert self.check_domain_list_order(used_domain_list) is True

            # first, produce the mixing weight from the input unless it was provided explicitly
            if mix_weight is None:

                if domain_mask is None:
                    domain_mask = [0] * self.max_domain_num
                    for domain in used_domain_list:
                        domain_mask[self.domain_dict[domain]] = 1
                    domain_mask = torch.Tensor(domain_mask).to(x.device)

                mix_weight = self.inner_gate[target_domain](x)  # [B, S, D_Max]
                mix_weight = torch.masked_fill(mix_weight, domain_mask == 0, -1e9)
                mix_weight = torch.softmax(mix_weight, dim=-1)

                # print(mix_weight)
                used_domain_idx = domain_mask.nonzero().flatten()
                select_mix_weight = mix_weight.index_select(dim=-1,
                                                            index=used_domain_idx)  # [B, S, D_Max] -> [B, S, D_Used]
                # print(mix_weight)

            else:
                select_mix_weight = mix_weight

            # calculate the adapter outputs separately
            adapter_outputs = []
            for domain in used_domain_list:
                if domain == 'news':  # this may not be used, because we use a residual connection
                    adapter_outputs.append(x)
                else:
                    # x is not added here (the residual is added at the end)
                    domain_adapter_output = self.sublayer_connection_for_adapter[domain]. \
                        wo_residual_forward(x, self.adapter_layers[domain])
                    adapter_outputs.append(domain_adapter_output)

            # mix the outputs
            adapter_outputs = torch.stack(adapter_outputs, dim=-1)  # [B, L, H, D]
            select_mix_weight = select_mix_weight.unsqueeze(-1)  # [B, L, D, 1]
            mix_adapter_outputs = torch.matmul(adapter_outputs, select_mix_weight).squeeze(-1)

            # residual
            return x + mix_adapter_outputs, mix_weight
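A compact sketch of the gating pattern above (hypothetical sizes, all domains kept instead of index-selecting the used ones): logits of inactive domains are pushed to -1e9 with `masked_fill`, softmax concentrates the weight on the active domains, and the stacked adapter outputs are mixed with a matmul.

import torch

B, L, H, D = 2, 5, 8, 4                               # batch, length, hidden size, max domains
gate_logits = torch.randn(B, L, D)                    # per-token domain logits
domain_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])      # only domains 0 and 2 are active
mix_weight = torch.masked_fill(gate_logits, domain_mask == 0, -1e9)
mix_weight = torch.softmax(mix_weight, dim=-1)        # ~0 weight on the inactive domains

adapter_outputs = torch.randn(B, L, H, D)             # stacked per-domain adapter outputs
mixed = torch.matmul(adapter_outputs, mix_weight.unsqueeze(-1)).squeeze(-1)   # [B, L, H]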