Example No. 1
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        '''
        The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
            -ve: no attention
              0: local attention
            +ve: global attention
        '''
        assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
        assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None"

        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
            key_padding_mask = attention_mask < 0
            extra_attention_mask = attention_mask > 0
            remove_from_windowed_attention_mask = attention_mask != 0

            num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
            max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
            if max_num_extra_indices_per_batch <= 0:
                extra_attention_mask = None
            else:
                # To support the case of variable number of global attention in the rows of a batch,
                # we use the following three selection masks to select global attention embeddings
                # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
                # 1) selecting embeddings that correspond to global attention
                extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
                zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch,
                                                 device=num_extra_indices_per_batch.device)
                # mask indicating which values are actually going to be padding
                selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
                # 2) location of the non-padding values in the selected global attention
                selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
                # 3) location of the padding values in the selected global attention
                selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
        else:
            remove_from_windowed_attention_mask = None
            extra_attention_mask = None
            key_padding_mask = None

        hidden_states = hidden_states.transpose(0, 1)
        seq_len, bsz, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        q /= math.sqrt(self.head_dim)

        q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        # attn_weights = (bsz, seq_len, num_heads, window*2+1)
        if self.attention_mode == 'tvm':
            q = q.float().contiguous()
            k = k.float().contiguous()
            attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
        elif self.attention_mode == "sliding_chunks":
            attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
        elif self.attention_mode == "sliding_chunks_no_overlap":
            attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
        else:
            raise ValueError(f"Unrecognized attention_mode: {self.attention_mode}")
        mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
        if remove_from_windowed_attention_mask is not None:
            # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
            # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
            remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # cast to float/half then replace 1's with -inf
            float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
            repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
            float_mask = float_mask.repeat(1, 1, repeat_size, 1)
            ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
            # diagonal mask with zeros everywhere and -inf in place of padding
            if self.attention_mode == 'tvm':
                d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
            elif self.attention_mode == "sliding_chunks":
                d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
            elif self.attention_mode == "sliding_chunks_no_overlap":
                d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)

            attn_weights += d_mask
        assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
        assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]

        # the extra attention
        if extra_attention_mask is not None:
            selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
            selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
            # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
            selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
            # concat to attn_weights
            # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
            attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
        if key_padding_mask is not None:
            # softmax sometimes inserts NaN if all positions are masked, replace them with 0
            attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)

        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
        v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        attn = 0
        if extra_attention_mask is not None:
            selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
            selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
            selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
            # use `matmul` because `einsum` crashes sometimes with fp16
            # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
            attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
            attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()

        if self.attention_mode == 'tvm':
            v = v.float().contiguous()
            attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
        elif self.attention_mode == "sliding_chunks":
            attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
        elif self.attention_mode == "sliding_chunks_no_overlap":
            attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
        else:
            raise ValueError(f"Unrecognized attention_mode: {self.attention_mode}")

        attn = attn.type_as(hidden_states)
        assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
        attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()

        # For this case, we'll just recompute the attention for these indices
        # and overwrite the attn tensor. TODO: remove the redundant computation
        if extra_attention_mask is not None:
            selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
            selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]]

            q = self.query_global(selected_hidden_states)
            k = self.key_global(hidden_states)
            v = self.value_global(hidden_states)
            q /= math.sqrt(self.head_dim)

            q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1)  # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)  # (bsz * self.num_heads, seq_len, head_dim)
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)  # (bsz * self.num_heads, seq_len, head_dim)
            attn_weights = torch.bmm(q, k.transpose(1, 2))
            assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]

            attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
            attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
            if key_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -10000.0,
                )
            attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
            attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
            attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
            selected_attn = torch.bmm(attn_probs, v)
            assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]

            selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
            nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
            attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)

        context_layer = attn.transpose(0, 1)
        if output_attentions:
            if extra_attention_mask is not None:
                # With global attention, return global attention probabilities only
                # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
                # which is the attention weights from tokens with global attention to all tokens
                # It does not return local attention.
                # In case of a variable number of global attention tokens in the rows of a batch,
                # attn_weights are padded with -10000.0 attention scores
                attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
            else:
                # without global attention, return local attention probabilities
                # batch_size x num_heads x sequence_length x window_size
                # which is the attention weights of every token attending to its neighbours
                attn_weights = attn_weights.permute(0, 2, 1, 3)
        outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
        return outputs
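
A minimal, self-contained sketch of the selection-mask trick used above to gather a variable number of global-attention embeddings per batch row and pad them to a common length. All sizes and mask values are invented toy data; only plain PyTorch is used.

    import torch

    # toy data: batch of 2 rows, 5 tokens, 3-dim embeddings
    k = torch.arange(2 * 5 * 3, dtype=torch.float32).view(2, 5, 3)
    extra_attention_mask = torch.tensor([[1, 0, 1, 0, 0],
                                         [0, 0, 0, 0, 1]]).bool()

    num_extra = extra_attention_mask.long().sum(dim=1)    # global tokens per row: [2, 1]
    max_num_extra = int(num_extra.max())                  # pad every row to this length

    # 1) positions of global-attention tokens in the original tensor
    extra_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
    # 2) which slots of the padded output hold real values (the rest stays zero padding)
    selection_padding_mask = torch.arange(max_num_extra) < num_extra.unsqueeze(-1)
    sel_nonzeros = selection_padding_mask.nonzero(as_tuple=True)

    selected_k = k.new_zeros(2, max_num_extra, 3)
    selected_k[sel_nonzeros] = k[extra_nonzeros]          # scatter the real embeddings
    print(selected_k.shape)                               # torch.Size([2, 2, 3])
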
Example No. 2
    def test_tvm_equal_sliding_chunks(self):
        np.random.seed(3)
        random.seed(3)
        torch.manual_seed(3)
        torch.cuda.manual_seed(3)
        torch.cuda.manual_seed_all(3)

        torch.set_printoptions(sci_mode=False)
        N = 4096  # * 16
        M = 64  # hidden size
        W = 256  # one sided. Actual window size = 2w+1
        B = 3
        D = 1  # no dilation
        H = 12  # number of heads
        autoregressive = False  # not autoregressive
        device = 'cuda'
        dtype = torch.float32

        failed_tests = 0
        time1 = time2 = 0
        for i in range(50):
            if i < 5:
                time1 = time2 = 0  # don't include the first few iterations because of high variance

            query = torch.randn(B * N * H * M, requires_grad=True, device=device, dtype=dtype).view(B, N, H, M)
            key = torch.randn(B * N * H * M, requires_grad=True, device=device, dtype=dtype).flip(dims=(0,)).view(B, N, H, M)
            value = torch.randn(B * N * H * M, requires_grad=True, device=device, dtype=dtype).view(B, N, H, M)

            # TVM MM
            torch.cuda.synchronize()
            start = time.time()
            attention1 = diagonaled_mm_tvm(query, key, W, D, False, 0, autoregressive)
            mask_invalid_locations(attention1, W, D, autoregressive)
            attention_probs1 = torch.nn.functional.softmax(attention1, dim=-1)
            context1 = diagonaled_mm_tvm(attention_probs1, value, W, D, True, 0, autoregressive)
            context1.sum().backward()
            torch.cuda.synchronize()
            time1 += time.time() - start
            torch.cuda.empty_cache()

            # query = query.half()  # uncomment to profile the fp16 performance
            # key = key.half()
            # value = value.half()
            assert D == 1
            assert not autoregressive
            torch.cuda.synchronize()
            start = time.time()
            attention2 = sliding_chunks_matmul_qk(query, key, W, float('-inf'))
            attention_probs2 = torch.nn.functional.softmax(attention2, dim=-1)
            context2 = sliding_chunks_matmul_pv(attention_probs2, value, W)
            context2.sum().backward()
            torch.cuda.synchronize()
            time2 += time.time() - start
            torch.cuda.empty_cache()

            try:
                assert torch.allclose(attention1, attention2.float(), atol=1e-4, rtol=1e-5)
                assert torch.allclose(context1, context2.float(), atol=1e-4, rtol=1e-5)
            except AssertionError:
                failed_tests += 1

        print('Time tvm: {0:.5f} s'.format(time1))
        print('Time pytorch sliding chunks: {0:.5f} s'.format(time2))
        print('Sliding chunks vs. TVM speedup: {0:.5f}x'.format(time1/time2))
        print(f'Failed tests: {failed_tests}/{i+1}')
        assert failed_tests == 0
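
For reference, a short dense O(n²) sketch of the same banded (local-window) attention that the TVM and sliding-chunks kernels compute in O(n·w). It is not the library's implementation, just a toy ground truth useful on short sequences; all names and dimensions are invented.

    import torch
    import torch.nn.functional as F

    def banded_attention_reference(q, k, v, w):
        # q, k, v: (bsz, seq_len, heads, head_dim); each query attends only to keys within +/- w positions
        scores = torch.einsum('bihd,bjhd->bhij', q, k)
        idx = torch.arange(q.size(1), device=q.device)
        band = (idx[None, :] - idx[:, None]).abs() <= w   # (seq_len, seq_len) band mask
        scores = scores.masked_fill(~band, float('-inf'))
        probs = F.softmax(scores, dim=-1)
        return torch.einsum('bhij,bjhd->bihd', probs, v)

    # toy sizes: B=2, N=16, H=4, M=8, one-sided window W=3
    q, k, v = [torch.randn(2, 16, 4, 8) for _ in range(3)]
    out = banded_attention_reference(q, k, v, w=3)
    print(out.shape)  # torch.Size([2, 16, 4, 8])
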
Example No. 3
    def forward(self, input_ids, attention_mask=None, head_mask=None):
        mixed_query_layer = self.query(input_ids)
        mixed_key_layer = self.key(input_ids)
        mixed_value_layer = self.value(input_ids)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        if self.attention_band is not None:
            query_layer = query_layer.permute(0, 2, 1, 3)
            key_layer = key_layer.permute(0, 2, 1, 3)
            value_layer = value_layer.permute(0, 2, 1, 3)

            attn_band = self.attention_band
            if attention_mask is not None:
                attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
                remove_from_windowed_attention_mask = (attention_mask != 0)
            query_layer /= math.sqrt(self.attention_head_size)
            query_layer = query_layer.float().contiguous()
            key_layer = key_layer.float().contiguous()
            if False:  # TVM kernel path disabled; the sliding-chunks implementation below is used instead
                attention_scores = diagonaled_mm_tvm(
                    query_layer,
                    key_layer,
                    attn_band,
                    1,
                    False,
                    0,
                    False  # dilation, is_t1_diag, padding, autoregressive
                )
            else:
                attention_scores = sliding_chunks_matmul_qk(query_layer,
                                                            key_layer,
                                                            attn_band,
                                                            padding_value=0)
            mask_invalid_locations(attention_scores, attn_band, 1, False)
            if attention_mask is not None:
                remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
                    dim=-1).unsqueeze(dim=-1)
                float_mask = remove_from_windowed_attention_mask.type_as(
                    query_layer).masked_fill(
                        remove_from_windowed_attention_mask, -10000.0)
                float_mask = float_mask.repeat(1, 1, 1, 1)  # repeat(1, 1, 1, 1) is a no-op here; likely unnecessary
                ones = float_mask.new_ones(size=float_mask.size())
                if False:  # TVM path disabled
                    d_mask = diagonaled_mm_tvm(ones, float_mask, attn_band, 1,
                                               False, 0, False)
                else:
                    d_mask = sliding_chunks_matmul_qk(ones,
                                                      float_mask,
                                                      attn_band,
                                                      padding_value=0)
                attention_scores += d_mask

            attention_probs = F.softmax(attention_scores,
                                        dim=-1,
                                        dtype=torch.float32)
            attention_probs = self.dropout(attention_probs)

            value_layer = value_layer.float().contiguous()
            if False:  # TVM path disabled
                context_layer = diagonaled_mm_tvm(attention_probs, value_layer,
                                                  attn_band, 1, True, 0, False)
            else:
                context_layer = sliding_chunks_matmul_pv(
                    attention_probs, value_layer, attn_band)

        else:
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer,
                                            key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(
                self.attention_head_size)
            if attention_mask is not None:
                # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
                attention_scores = attention_scores + attention_mask

            # Normalize the attention scores to probabilities.
            attention_probs = nn.Softmax(dim=-1)(attention_scores)
            if VERBOSE:
                # print(attention_probs[0, :8, :8])
                print(torch.max(attention_probs), torch.min(attention_probs))

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)

            # Mask heads if we want to
            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            context_layer = torch.matmul(attention_probs, value_layer)

            context_layer = context_layer.permute(0, 2, 1, 3)

        context_layer = context_layer.contiguous()

        # Should find a better way to do this
        w = (self.dense.weight.t().view(
            self.num_attention_heads, self.attention_head_size,
            self.hidden_size).to(context_layer.dtype))
        b = self.dense.bias.to(context_layer.dtype)

        projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer,
                                               w) + b
        projected_context_layer_dropout = self.dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(
            input_ids + projected_context_layer_dropout)
        return (layernormed_context_layer,
                attention_probs) if self.output_attentions else (
                    layernormed_context_layer, )
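
The final projection above contracts the context tensor against `self.dense.weight` reshaped to (num_heads, head_size, hidden_size). The following sketch (toy sizes, standalone tensors, not the class's own code) checks that this einsum is equivalent to flattening the heads and applying the linear layer directly.

    import torch
    import torch.nn as nn

    num_heads, head_size = 4, 8
    hidden = num_heads * head_size
    dense = nn.Linear(hidden, hidden)

    # context as produced above: (batch, seq_len, num_heads, head_size)
    context = torch.randn(2, 16, num_heads, head_size)

    # einsum path: weight transposed and reshaped to (num_heads, head_size, hidden)
    w = dense.weight.t().view(num_heads, head_size, hidden)
    out_einsum = torch.einsum('bfnd,ndh->bfh', context, w) + dense.bias

    # equivalent: flatten the heads and apply the linear layer
    out_linear = dense(context.reshape(2, 16, hidden))

    print(torch.allclose(out_einsum, out_linear, atol=1e-5))  # True
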
Example No. 4
    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        '''
        The `attention_mask` is changed in BertModel.forward from 0, 1, 2 to
            -ve: no attention
              0: local attention
            +ve: global attention
        '''
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
            key_padding_mask = attention_mask < 0
            extra_attention_mask = attention_mask > 0
            remove_from_windowed_attention_mask = attention_mask != 0

            num_extra_indices_per_batch = extra_attention_mask.long().sum(
                dim=1)
            max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
            has_same_length_extra_indices = (
                num_extra_indices_per_batch == max_num_extra_indices_per_batch
            ).all()
        else:
            # without a mask there is no padding, no global attention, and nothing
            # to remove from the local window (mirrors the other implementation above)
            key_padding_mask = None
            extra_attention_mask = None
            remove_from_windowed_attention_mask = None
            max_num_extra_indices_per_batch = 0
        hidden_states = hidden_states.transpose(0, 1)
        seq_len, bsz, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        q /= math.sqrt(self.head_dim)

        q = q.view(seq_len, bsz, self.num_heads,
                   self.head_dim).transpose(0, 1).contiguous().float()
        k = k.view(seq_len, bsz, self.num_heads,
                   self.head_dim).transpose(0, 1).contiguous().float()
        # attn_weights = (bsz, seq_len, num_heads, window*2+1)
        attn_weights = diagonaled_mm_tvm(q, k, self.attention_window,
                                         self.attention_dilation, False, 0,
                                         False)
        mask_invalid_locations(attn_weights, self.attention_window,
                               self.attention_dilation, False)
        if remove_from_windowed_attention_mask is not None:
            # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
            # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
            remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
                dim=-1).unsqueeze(dim=-1)
            # cast to float/half then replace 1's with -inf
            float_mask = remove_from_windowed_attention_mask.type_as(
                q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
            repeat_size = 1 if isinstance(
                self.attention_dilation, int) else len(self.attention_dilation)
            float_mask = float_mask.repeat(1, 1, repeat_size, 1)
            ones = float_mask.new_ones(
                size=float_mask.size())  # tensor of ones
            # diagonal mask with zeros everywhere and -inf in place of padding
            d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window,
                                       self.attention_dilation, False, 0,
                                       False)
            attn_weights += d_mask
        assert list(attn_weights.size()) == [
            bsz, seq_len, self.num_heads, self.attention_window * 2 + 1
        ]

        # the extra attention
        if extra_attention_mask is not None:
            if has_same_length_extra_indices:
                # a simpler implementation for efficiency
                # k = (bsz, seq_len, num_heads, head_dim)
                selected_k = k.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                        bsz, max_num_extra_indices_per_batch, self.num_heads,
                        self.head_dim)
                # selected_k = (bsz, extra_attention_count, num_heads, head_dim)
                # selected_attn_weights = (bsz, seq_len, num_heads, extra_attention_count)
                selected_attn_weights = torch.einsum('blhd,bshd->blhs',
                                                     (q, selected_k))
            else:
                # since the number of extra attention indices varies across
                # the batch, we need to process each element of the batch
                # individually
                flat_selected_k = k.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
                selected_attn_weights = torch.ones(
                    bsz,
                    seq_len,
                    self.num_heads,
                    max_num_extra_indices_per_batch,
                    device=k.device,
                    dtype=k.dtype)
                selected_attn_weights.fill_(-10000.0)
                start = 0
                for i in range(bsz):
                    end = start + num_extra_indices_per_batch[
                        i] * self.num_heads * self.head_dim
                    # the selected entries for this batch element
                    i_selected_k = flat_selected_k[start:end].view(
                        -1, self.num_heads, self.head_dim)
                    # (seq_len, num_heads, num extra indices)
                    i_selected_attn_weights = torch.einsum(
                        'lhd,shd->lhs', (q[i, :, :, :], i_selected_k))
                    selected_attn_weights[
                        i, :, :, :num_extra_indices_per_batch[
                            i]] = i_selected_attn_weights
                    start = end

            # concat to attn_weights
            # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
            attn_weights = torch.cat((selected_attn_weights, attn_weights),
                                     dim=-1)

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        if key_padding_mask is not None:
            # softmax sometimes inserts NaN if all positions are masked, replace them with 0
            attn_weights_float = torch.masked_fill(
                attn_weights_float,
                key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)

        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                               p=self.dropout,
                               training=self.training)
        v = v.view(seq_len, bsz, self.num_heads,
                   self.head_dim).transpose(0, 1).contiguous().float()
        attn = 0
        if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
            selected_attn_probs = attn_probs.narrow(
                -1, 0, max_num_extra_indices_per_batch)
            if has_same_length_extra_indices:
                selected_v = v.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                        bsz, max_num_extra_indices_per_batch, self.num_heads,
                        self.head_dim)
            else:
                flat_selected_v = v.masked_select(
                    extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
                # don't worry about masking since this is multiplied by attn_probs, and masking above
                # before softmax will remove masked entries
                selected_v = torch.zeros(bsz,
                                         max_num_extra_indices_per_batch,
                                         self.num_heads,
                                         self.head_dim,
                                         device=v.device,
                                         dtype=v.dtype)
                start = 0
                for i in range(bsz):
                    end = start + num_extra_indices_per_batch[
                        i] * self.num_heads * self.head_dim
                    i_selected_v = flat_selected_v[start:end].view(
                        -1, self.num_heads, self.head_dim)
                    selected_v[
                        i, :
                        num_extra_indices_per_batch[i], :, :] = i_selected_v
                    start = end
            attn = torch.einsum('blhs,bshd->blhd',
                                (selected_attn_probs, selected_v))
            attn_probs = attn_probs.narrow(
                -1, max_num_extra_indices_per_batch,
                attn_probs.size(-1) -
                max_num_extra_indices_per_batch).contiguous()

        attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window,
                                  self.attention_dilation, True, 0, False)
        attn = attn.type_as(hidden_states)
        assert list(
            attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
        attn = attn.transpose(0, 1).reshape(seq_len, bsz,
                                            embed_dim).contiguous()

        # For this case, we'll just recompute the attention for these indices
        # and overwrite the attn tensor. TODO: remove the redundant computation
        if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
            if has_same_length_extra_indices:
                # query = (seq_len, bsz, dim)
                # extra_attention_mask = (bsz, seq_len)
                # selected_query = (max_num_extra_indices_per_batch, bsz, embed_dim)
                selected_hidden_states = hidden_states.masked_select(
                    extra_attention_mask.transpose(0, 1).unsqueeze(-1)).view(
                        max_num_extra_indices_per_batch, bsz, embed_dim)
                # if *_proj_full exists use them, otherwise default to *_proj
                q = self.query_global(selected_hidden_states)
                k = self.key_global(hidden_states)
                v = self.value_global(hidden_states)
                q /= math.sqrt(self.head_dim)

                q = q.contiguous().view(
                    max_num_extra_indices_per_batch, bsz * self.num_heads,
                    self.head_dim
                ).transpose(
                    0, 1
                )  # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
                k = k.contiguous().view(
                    -1, bsz * self.num_heads, self.head_dim).transpose(
                        0, 1)  # (bsz * self.num_heads, seq_len, head_dim)
                v = v.contiguous().view(
                    -1, bsz * self.num_heads, self.head_dim).transpose(
                        0, 1)  # (bsz * self.num_heads, seq_len, head_dim)
                attn_weights = torch.bmm(q, k.transpose(1, 2))
                assert list(attn_weights.size()) == [
                    bsz * self.num_heads, max_num_extra_indices_per_batch,
                    seq_len
                ]
                if key_padding_mask is not None:
                    attn_weights = attn_weights.view(
                        bsz, self.num_heads, max_num_extra_indices_per_batch,
                        seq_len)
                    attn_weights = attn_weights.masked_fill(
                        key_padding_mask.unsqueeze(1).unsqueeze(2),
                        float('-inf'),
                    )
                    attn_weights = attn_weights.view(
                        bsz * self.num_heads, max_num_extra_indices_per_batch,
                        seq_len)
                attn_weights_float = F.softmax(attn_weights, dim=-1)
                attn_probs = F.dropout(
                    attn_weights_float.type_as(attn_weights),
                    p=self.dropout,
                    training=self.training)
                selected_attn = torch.bmm(attn_probs, v)
                assert list(selected_attn.size()) == [
                    bsz * self.num_heads, max_num_extra_indices_per_batch,
                    self.head_dim
                ]
                selected_attn = selected_attn.transpose(
                    0, 1).contiguous().view(max_num_extra_indices_per_batch *
                                            bsz * embed_dim)

                # now update attn by filling in the relevant indices with selected_attn
                # masked_fill_ only allows floats as values so this doesn't work
                # attn.masked_fill_(extra_attention_mask.transpose(0, 1).unsqueeze(-1), selected_attn)
                attn[extra_attention_mask.transpose(0, 1).unsqueeze(-1).repeat(
                    (1, 1, embed_dim))] = selected_attn
            else:
                raise NotImplementedError("a variable number of global attention tokens per row is not supported here")

        context_layer = attn.transpose(0, 1)
        if self.output_attentions:
            if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
                # With global attention, return global attention probabilities only
                # batch_size x num_heads x num_global_attention_tokens x sequence_length
                # which is the attention weights from tokens with global attention to all tokens
                # It does not return local attention.
                attn_weights = attn_weights.view(
                    bsz, self.num_heads, max_num_extra_indices_per_batch,
                    seq_len)
            else:
                # without global attention, return local attention probabilities
                # batch_size x num_heads x sequence_length x window_size
                # which is the attention weights of every token attending to its neighbours
                attn_weights = attn_weights.permute(0, 2, 1, 3)
        outputs = (context_layer,
                   attn_weights) if self.output_attentions else (
                       context_layer, )
        return outputs
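
A small sketch of the mask convention both `forward` implementations above expect. The exact values produced upstream in `BertModel.forward` only matter by sign (negative = padding, zero = local, positive = global), so the simple mapping below is just one illustration with invented toy values.

    import torch

    # toy HuggingFace-style mask: 0 = padding, 1 = local attention, 2 = global attention
    hf_mask = torch.tensor([[2, 1, 1, 0, 0],
                            [1, 1, 2, 2, 0]])

    # one simple mapping to the sign convention used above: -1 (padding), 0 (local), +1 (global)
    attention_mask = hf_mask - 1

    key_padding_mask = attention_mask < 0                         # tokens nobody should attend to
    extra_attention_mask = attention_mask > 0                     # tokens with global attention
    remove_from_windowed_attention_mask = attention_mask != 0     # drop padding and global tokens from the local window

    print(key_padding_mask.long())
    print(extra_attention_mask.long())
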