def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
):
    '''
    Compute windowed (local) self-attention plus optional global attention.

    The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
        -ve: no attention
          0: local attention
        +ve: global attention

    Args:
        hidden_states: tensor of size (bsz, seq_len, embed_dim); `embed_dim`
            must equal `self.embed_dim`.
        attention_mask: optional mask using the sign convention above. It is
            squeezed on dims 2 and 1 down to (bsz, seq_len) -- presumably it
            arrives as (bsz, 1, 1, seq_len); confirm against the caller.
        head_mask: accepted but not read by this method.
        encoder_hidden_states: must be None (not supported).
        encoder_attention_mask: must be None (not supported).
        output_attentions: when true, the attention tensor is appended to the
            returned tuple.

    Returns:
        `(context_layer,)` or `(context_layer, attn_weights)`, where
        `context_layer` has size (bsz, seq_len, embed_dim).

    Raises:
        AssertionError: if an encoder argument is given or a shape check fails.
        ValueError: if `self.attention_mode` is not one of 'tvm',
            'sliding_chunks' or 'sliding_chunks_no_overlap'.
    '''
    assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
    assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and shiould be None"

    if attention_mask is not None:
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
        key_padding_mask = attention_mask < 0
        extra_attention_mask = attention_mask > 0
        remove_from_windowed_attention_mask = attention_mask != 0

        num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
        max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
        if max_num_extra_indices_per_batch <= 0:
            # No row of the batch has a global-attention token.
            extra_attention_mask = None
        else:
            # To support the case of variable number of global attention in the rows of a batch,
            # we use the following three selection masks to select global attention embeddings
            # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
            # 1) selecting embeddings that correspond to global attention
            extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
            zero_to_max_range = torch.arange(
                0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device
            )
            # mask indicating which values are actually going to be padding
            selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
            # 2) location of the non-padding values in the selected global attention
            selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
            # 3) location of the padding values in the selected global attention
            selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
    else:
        remove_from_windowed_attention_mask = None
        extra_attention_mask = None
        key_padding_mask = None

    hidden_states = hidden_states.transpose(0, 1)
    seq_len, bsz, embed_dim = hidden_states.size()
    assert embed_dim == self.embed_dim
    q = self.query(hidden_states)
    k = self.key(hidden_states)
    v = self.value(hidden_states)
    q /= math.sqrt(self.head_dim)

    q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
    k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
    # attn_weights = (bsz, seq_len, num_heads, window*2+1)
    if self.attention_mode == 'tvm':
        q = q.float().contiguous()
        k = k.float().contiguous()
        attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
    elif self.attention_mode == "sliding_chunks":
        attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
    elif self.attention_mode == "sliding_chunks_no_overlap":
        attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
    else:
        # `raise False` would itself fail with an unrelated TypeError; name the bad mode instead.
        raise ValueError(f"Unsupported attention_mode: {self.attention_mode!r}")
    mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
    if remove_from_windowed_attention_mask is not None:
        # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
        # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
        remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
        # cast to float/half then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(
            remove_from_windowed_attention_mask, -10000.0
        )
        repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
        float_mask = float_mask.repeat(1, 1, repeat_size, 1)
        ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
        # diagonal mask with zeros everywhere and -inf inplace of padding
        # (the mode was already validated by the dispatch above)
        if self.attention_mode == 'tvm':
            d_mask = diagonaled_mm_tvm(
                ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False
            )
        elif self.attention_mode == "sliding_chunks":
            d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
        elif self.attention_mode == "sliding_chunks_no_overlap":
            d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)

        attn_weights += d_mask
    assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
    assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]

    # the extra attention
    if extra_attention_mask is not None:
        selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
        selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
        # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
        selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
        # padded global slots must not receive attention
        selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
        # concat to attn_weights
        # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
        attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
    attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
    if key_padding_mask is not None:
        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_weights_float = torch.masked_fill(
            attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0
        )
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
    v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
    attn = 0
    if extra_attention_mask is not None:
        # The first `max_num_extra_indices_per_batch` columns are the global-attention probabilities.
        selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
        selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
        selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
        # use `matmul` because `einsum` crashes sometimes with fp16
        # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
        attn = torch.matmul(
            selected_attn_probs.transpose(1, 2),
            selected_v.transpose(1, 2).type_as(selected_attn_probs),
        ).transpose(1, 2)
        # The remaining columns are the local (windowed) probabilities.
        attn_probs = attn_probs.narrow(
            -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
        ).contiguous()

    if self.attention_mode == 'tvm':
        v = v.float().contiguous()
        attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
    elif self.attention_mode == "sliding_chunks":
        attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
    elif self.attention_mode == "sliding_chunks_no_overlap":
        attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
    else:
        raise ValueError(f"Unsupported attention_mode: {self.attention_mode!r}")

    attn = attn.type_as(hidden_states)
    assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
    attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()

    # For this case, we'll just recompute the attention for these indices
    # and overwrite the attn tensor. TODO: remove the redundant computation
    if extra_attention_mask is not None:
        selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
        # The index tuples are (batch, position); reverse them because these tensors are sequence-first.
        selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[
            extra_attention_mask_nonzeros[::-1]
        ]

        q = self.query_global(selected_hidden_states)
        k = self.key_global(hidden_states)
        v = self.value_global(hidden_states)
        q /= math.sqrt(self.head_dim)

        # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
        q = q.contiguous().view(
            max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim
        ).transpose(0, 1)
        # (bsz * self.num_heads, seq_len, head_dim)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        # (bsz * self.num_heads, seq_len, head_dim)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]

        attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
        attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
        if key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -10000.0,
            )
        attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
        selected_attn = torch.bmm(attn_probs, v)
        assert list(selected_attn.size()) == [
            bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim
        ]

        selected_attn_4d = selected_attn.view(
            bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim
        )
        nonzero_selected_attn = selected_attn_4d[
            selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]
        ]
        # Overwrite the rows of the global-attention tokens with the recomputed values.
        attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
            len(selection_padding_mask_nonzeros[0]), -1
        ).type_as(hidden_states)

    context_layer = attn.transpose(0, 1)
    if output_attentions:
        if extra_attention_mask is not None:
            # With global attention, return global attention probabilities only
            # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
            # which is the attention weights from tokens with global attention to all tokens
            # It doesn't not return local attention
            # In case of variable number of global attantion in the rows of a batch,
            # attn_weights are padded with -10000.0 attention scores
            # NOTE(review): `attn_weights` here is the masked score tensor that is fed to
            # softmax above, not the softmax output -- confirm this is intended.
            attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
        else:
            # without global attention, return local attention probabilities
            # batch_size x num_heads x sequence_length x window_size
            # which is the attention weights of every token attending to its neighbours
            attn_weights = attn_weights.permute(0, 2, 1, 3)
    outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
    return outputs
def test_tvm_equal_sliding_chunks(self):
    '''
    Compare the TVM attention kernel with the sliding-chunks implementation.

    Runs 50 forward/backward iterations of each implementation on random CUDA
    tensors, counts the iterations whose attention scores or context outputs
    disagree, prints the accumulated timings, and fails if any iteration
    disagreed. Timings accumulated during the first five iterations are
    discarded.
    '''
    for seed_everything in (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_everything(3)

    torch.set_printoptions(sci_mode=False)

    seq_len = 4096  # * 16
    hidden_size = 64
    one_sided_window = 256  # actual window size = 2 * one_sided_window + 1
    batch_size = 3
    dilation = 1  # no dilation
    num_heads = 12
    autoregressive = False
    shape = (batch_size, seq_len, num_heads, hidden_size)
    numel = batch_size * seq_len * num_heads * hidden_size
    randn_kwargs = {'requires_grad': True, 'device': 'cuda', 'dtype': torch.float32}
    softmax = torch.nn.functional.softmax

    failed_tests = 0
    tvm_seconds = chunks_seconds = 0
    for i in range(50):
        if i < 5:
            # don't include the first few iterations because of high variance
            tvm_seconds = chunks_seconds = 0

        query = torch.randn(numel, **randn_kwargs).view(shape)
        key = torch.randn(numel, **randn_kwargs).flip(dims=(0,)).view(shape)
        value = torch.randn(numel, **randn_kwargs).view(shape)

        # TVM implementation: scores -> mask -> softmax -> context -> backward
        torch.cuda.synchronize()
        tic = time.time()
        attention1 = diagonaled_mm_tvm(query, key, one_sided_window, dilation, False, 0, autoregressive)
        mask_invalid_locations(attention1, one_sided_window, dilation, autoregressive)
        attention_probs1 = softmax(attention1, dim=-1)
        context1 = diagonaled_mm_tvm(
            attention_probs1, value, one_sided_window, dilation, True, 0, autoregressive
        )
        context1.sum().backward()
        torch.cuda.synchronize()
        tvm_seconds += time.time() - tic
        torch.cuda.empty_cache()

        # query = query.half()  # uncomment to profile the fp16 performance
        # key = key.half()
        # value = value.half()
        assert dilation == 1
        assert not autoregressive

        # Sliding-chunks implementation on the same inputs
        torch.cuda.synchronize()
        tic = time.time()
        attention2 = sliding_chunks_matmul_qk(query, key, one_sided_window, float('-inf'))
        attention_probs2 = softmax(attention2, dim=-1)
        context2 = sliding_chunks_matmul_pv(attention_probs2, value, one_sided_window)
        context2.sum().backward()
        torch.cuda.synchronize()
        chunks_seconds += time.time() - tic
        torch.cuda.empty_cache()

        try:
            assert torch.allclose(attention1, attention2.float(), atol=1e-4, rtol=1e-5)
            assert torch.allclose(context1, context2.float(), atol=1e-4, rtol=1e-5)
        except AssertionError:
            failed_tests += 1

    print('Time tvm: {0:.5f} s'.format(tvm_seconds))
    print('Time pytorch sliding chunks: {0:.5f} s'.format(chunks_seconds))
    print('Sliding chunks vs. TVM speedup: {0:.5f}x'.format(tvm_seconds/chunks_seconds))
    print(f'Failed tests: {failed_tests}/{i+1}')
    assert failed_tests == 0
def forward(self, input_ids, attention_mask=None, head_mask=None):
    '''
    Self-attention followed by the output projection, dropout and LayerNorm.

    When `self.attention_band` is set, attention is computed with the
    sliding-chunks kernels over a band of that width; otherwise full
    attention is used.

    Args:
        input_ids: input tensor fed to the query/key/value projections and
            added back as the residual before `self.LayerNorm` (despite the
            name it is used as the layer's hidden states).
        attention_mask: optional mask. In band mode it is squeezed on dims 2
            and 1 and every non-zero position is removed from the windowed
            attention; in full mode it is added to the raw scores.
        head_mask: optional multiplier applied to the attention probabilities.
            NOTE(review): it is only applied on the full-attention path and is
            ignored in band mode -- confirm this is intended.

    Returns:
        `(layernormed_context_layer, attention_probs)` if
        `self.output_attentions` is set, else `(layernormed_context_layer,)`.
    '''
    mixed_query_layer = self.query(input_ids)
    mixed_key_layer = self.key(input_ids)
    mixed_value_layer = self.value(input_ids)

    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)

    if self.attention_band is not None:
        # Swap dims 1 and 2 of each projection before calling the sliding-chunks kernels.
        query_layer = query_layer.permute(0, 2, 1, 3)
        key_layer = key_layer.permute(0, 2, 1, 3)
        value_layer = value_layer.permute(0, 2, 1, 3)

        attn_band = self.attention_band
        query_layer /= math.sqrt(self.attention_head_size)
        query_layer = query_layer.float().contiguous()
        key_layer = key_layer.float().contiguous()
        attention_scores = sliding_chunks_matmul_qk(query_layer, key_layer, attn_band, padding_value=0)
        mask_invalid_locations(attention_scores, attn_band, 1, False)
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
            remove_from_windowed_attention_mask = (attention_mask != 0)
            remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
                dim=-1).unsqueeze(dim=-1)
            # -10000.0 at every masked position, 0 elsewhere
            float_mask = remove_from_windowed_attention_mask.type_as(query_layer).masked_fill(
                remove_from_windowed_attention_mask, -10000.0)
            ones = float_mask.new_ones(size=float_mask.size())
            # Run the mask through the same kernel so it lines up with the banded score layout.
            d_mask = sliding_chunks_matmul_qk(ones, float_mask, attn_band, padding_value=0)
            attention_scores += d_mask

        attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32)
        attention_probs = self.dropout(attention_probs)

        value_layer = value_layer.float().contiguous()
        context_layer = sliding_chunks_matmul_pv(attention_probs, value_layer, attn_band)
    else:
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        if VERBOSE:
            # print(attention_probs[0, :8, :8])
            print(torch.max(attention_probs), torch.min(attention_probs))

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Swap dims 1 and 2 so the layout matches the "bfnd" operand of the einsum below.
        context_layer = context_layer.permute(0, 2, 1, 3)
    context_layer = context_layer.contiguous()

    # Should find a better way to do this
    w = (self.dense.weight.t().view(
        self.num_attention_heads, self.attention_head_size,
        self.hidden_size).to(context_layer.dtype))
    b = self.dense.bias.to(context_layer.dtype)

    projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
    projected_context_layer_dropout = self.dropout(projected_context_layer)
    layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
    if self.output_attentions:
        return (layernormed_context_layer, attention_probs)
    return (layernormed_context_layer,)
def forward(self, hidden_states, attention_mask=None, head_mask=None):
    '''
    Compute TVM-based windowed self-attention plus optional global attention.

    The `attention_mask` is changed in BertModel.forward from 0, 1, 2 to
        -ve: no attention
          0: local attention
        +ve: global attention

    Args:
        hidden_states: tensor of size (bsz, seq_len, embed_dim); `embed_dim`
            must equal `self.embed_dim`.
        attention_mask: optional mask using the sign convention above. It is
            squeezed on dims 2 and 1 down to (bsz, seq_len) -- presumably it
            arrives as (bsz, 1, 1, seq_len); confirm against the caller.
            When None, no masking and no global attention are applied.
        head_mask: accepted but not read by this method.

    Returns:
        `(context_layer, attn_weights)` if `self.output_attentions` is set,
        else `(context_layer,)`; `context_layer` has size
        (bsz, seq_len, embed_dim).

    Raises:
        AssertionError: if a shape check fails.
        ValueError: if global attention is requested with a different number
            of global tokens per batch row (recomputation not implemented).
    '''
    if attention_mask is not None:
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
        key_padding_mask = attention_mask < 0
        extra_attention_mask = attention_mask > 0
        remove_from_windowed_attention_mask = attention_mask != 0

        num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
        max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
        has_same_length_extra_indices = (
            num_extra_indices_per_batch == max_num_extra_indices_per_batch
        ).all()
    else:
        # Without a mask these names would otherwise be unbound when they are
        # tested against None further down.
        remove_from_windowed_attention_mask = None
        extra_attention_mask = None
        key_padding_mask = None

    hidden_states = hidden_states.transpose(0, 1)
    seq_len, bsz, embed_dim = hidden_states.size()
    assert embed_dim == self.embed_dim
    q = self.query(hidden_states)
    k = self.key(hidden_states)
    v = self.value(hidden_states)
    q /= math.sqrt(self.head_dim)

    q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1).contiguous().float()
    k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1).contiguous().float()
    # attn_weights = (bsz, seq_len, num_heads, window*2+1)
    attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
    mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
    if remove_from_windowed_attention_mask is not None:
        # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
        # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
        remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
            dim=-1).unsqueeze(dim=-1)
        # cast to float/half then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(
            remove_from_windowed_attention_mask, -10000.0)
        repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
        float_mask = float_mask.repeat(1, 1, repeat_size, 1)
        ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
        # diagonal mask with zeros everywhere and -inf inplace of padding
        d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window,
                                   self.attention_dilation, False, 0, False)
        attn_weights += d_mask
    assert list(attn_weights.size()) == [
        bsz, seq_len, self.num_heads, self.attention_window * 2 + 1
    ]

    # the extra attention
    if extra_attention_mask is not None:
        if has_same_length_extra_indices:
            # a simplier implementation for efficiency
            # k = (bsz, seq_len, num_heads, head_dim)
            selected_k = k.masked_select(
                extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                    bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
            # selected_k = (bsz, extra_attention_count, num_heads, head_dim)
            # selected_attn_weights = (bsz, seq_len, num_heads, extra_attention_count)
            selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
        else:
            # since the number of extra attention indices varies across
            # the batch, we need to process each element of the batch
            # individually
            flat_selected_k = k.masked_select(extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
            selected_attn_weights = torch.ones(
                bsz, seq_len, self.num_heads, max_num_extra_indices_per_batch,
                device=k.device, dtype=k.dtype)
            # padded slots keep -10000.0 so they get ~zero probability after softmax
            selected_attn_weights.fill_(-10000.0)
            start = 0
            for i in range(bsz):
                end = start + num_extra_indices_per_batch[i] * self.num_heads * self.head_dim
                # the selected entries for this batch element
                i_selected_k = flat_selected_k[start:end].view(-1, self.num_heads, self.head_dim)
                # (seq_len, num_heads, num extra indices)
                i_selected_attn_weights = torch.einsum('lhd,shd->lhs', (q[i, :, :, :], i_selected_k))
                selected_attn_weights[i, :, :, :num_extra_indices_per_batch[i]] = i_selected_attn_weights
                start = end
        # concat to attn_weights
        # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
        attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
    attn_weights_float = F.softmax(attn_weights, dim=-1)
    if key_padding_mask is not None:
        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_weights_float = torch.masked_fill(
            attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
    v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1).contiguous().float()
    attn = 0
    if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
        # The first `max_num_extra_indices_per_batch` columns are the global-attention probabilities.
        selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
        if has_same_length_extra_indices:
            selected_v = v.masked_select(
                extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                    bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
        else:
            flat_selected_v = v.masked_select(extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
            # don't worry about masking since this is multiplied by attn_probs, and masking above
            # before softmax will remove masked entries
            selected_v = torch.zeros(
                bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim,
                device=v.device, dtype=v.dtype)
            start = 0
            for i in range(bsz):
                end = start + num_extra_indices_per_batch[i] * self.num_heads * self.head_dim
                i_selected_v = flat_selected_v[start:end].view(-1, self.num_heads, self.head_dim)
                selected_v[i, :num_extra_indices_per_batch[i], :, :] = i_selected_v
                start = end
        attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
        # The remaining columns are the local (windowed) probabilities.
        attn_probs = attn_probs.narrow(
            -1, max_num_extra_indices_per_batch,
            attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()

    attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
    attn = attn.type_as(hidden_states)
    assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
    attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()

    # For this case, we'll just recompute the attention for these indices
    # and overwrite the attn tensor. TODO: remove the redundant computation
    if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
        if has_same_length_extra_indices:
            # query = (seq_len, bsz, dim)
            # extra_attention_mask = (bsz, seq_len)
            # selected_query = (max_num_extra_indices_per_batch, bsz, embed_dim)
            selected_hidden_states = hidden_states.masked_select(
                extra_attention_mask.transpose(0, 1).unsqueeze(-1)).view(
                    max_num_extra_indices_per_batch, bsz, embed_dim)
            # if *_proj_full exists use them, otherwise default to *_proj
            q = self.query_global(selected_hidden_states)
            k = self.key_global(hidden_states)
            v = self.value_global(hidden_states)
            q /= math.sqrt(self.head_dim)

            # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
            q = q.contiguous().view(
                max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            # (bsz * self.num_heads, seq_len, head_dim)
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            # (bsz * self.num_heads, seq_len, head_dim)
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            attn_weights = torch.bmm(q, k.transpose(1, 2))
            assert list(attn_weights.size()) == [
                bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len
            ]
            if key_padding_mask is not None:
                attn_weights = attn_weights.view(
                    bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                )
                attn_weights = attn_weights.view(
                    bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
            attn_weights_float = F.softmax(attn_weights, dim=-1)
            attn_probs = F.dropout(
                attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
            selected_attn = torch.bmm(attn_probs, v)
            assert list(selected_attn.size()) == [
                bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim
            ]
            selected_attn = selected_attn.transpose(0, 1).contiguous().view(
                max_num_extra_indices_per_batch * bsz * embed_dim)

            # now update attn by filling in the relevant indices with selected_attn
            # masked_fill_ only allows floats as values so this doesn't work
            # attn.masked_fill_(extra_attention_mask.transpose(0, 1).unsqueeze(-1), selected_attn)
            attn[extra_attention_mask.transpose(0, 1).unsqueeze(-1).repeat(
                (1, 1, embed_dim))] = selected_attn
        else:
            # not implemented
            raise ValueError(
                "global attention with a different number of global tokens "
                "per batch row is not implemented"
            )

    context_layer = attn.transpose(0, 1)
    if self.output_attentions:
        if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
            # With global attention, return global attention probabilities only
            # batch_size x num_heads x num_global_attention_tokens x sequence_length
            # which is the attention weights from tokens with global attention to all tokens
            # It doesn't not return local attention
            # NOTE(review): `attn_weights` here is the score tensor that is fed to
            # softmax above, not the softmax output -- confirm this is intended.
            attn_weights = attn_weights.view(
                bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
        else:
            # without global attention, return local attention probabilities
            # batch_size x num_heads x sequence_length x window_size
            # which is the attention weights of every token attending to its neighbours
            attn_weights = attn_weights.permute(0, 2, 1, 3)
    outputs = (context_layer, attn_weights) if self.output_attentions else (context_layer,)
    return outputs