def replace_inf_with_zero(x):
    return th.masked_fill(x, th.isinf(x), 0)
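# Hedged usage sketch for `replace_inf_with_zero`: `th` is assumed to be this codebase's
# alias for the `torch` module. Note `masked_fill` is out-of-place; the input is unchanged.
import torch as th

scores = th.tensor([0.5, float('inf'), float('-inf')])
print(replace_inf_with_zero(scores))  # tensor([0.5000, 0.0000, 0.0000])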
def forward(self, hidden_states, attention_mask=None, head_mask=None):
    '''
    The `attention_mask` is changed in BertModel.forward from 0, 1, 2 to:
        -ve: no attention
          0: local attention
        +ve: global attention
    '''
    if attention_mask is not None:
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
        key_padding_mask = attention_mask < 0
        extra_attention_mask = attention_mask > 0
        remove_from_windowed_attention_mask = attention_mask != 0

        num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
        max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
        has_same_length_extra_indices = (
            num_extra_indices_per_batch == max_num_extra_indices_per_batch).all()
    else:
        # guard the later `is not None` checks when no mask is provided
        key_padding_mask = None
        extra_attention_mask = None
        remove_from_windowed_attention_mask = None

    hidden_states = hidden_states.transpose(0, 1)
    seq_len, bsz, embed_dim = hidden_states.size()
    assert embed_dim == self.embed_dim
    q = self.query(hidden_states)
    k = self.key(hidden_states)
    v = self.value(hidden_states)
    q /= math.sqrt(self.head_dim)

    q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1).contiguous().float()
    k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1).contiguous().float()
    # attn_weights = (bsz, seq_len, num_heads, window*2+1)
    attn_weights = diagonaled_mm_tvm(q, k, self.attention_window,
                                     self.attention_dilation, False, 0, False)
    mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
    if remove_from_windowed_attention_mask is not None:
        # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
        # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
        remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
            dim=-1).unsqueeze(dim=-1)
        # cast to float/half then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(
            remove_from_windowed_attention_mask, -10000.0)
        repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
        float_mask = float_mask.repeat(1, 1, repeat_size, 1)
        ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
        # diagonal mask with zeros everywhere and -inf in place of padding
        d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window,
                                   self.attention_dilation, False, 0, False)
        attn_weights += d_mask
    assert list(attn_weights.size()) == [bsz, seq_len, self.num_heads,
                                         self.attention_window * 2 + 1]

    # the extra attention
    if extra_attention_mask is not None:
        if has_same_length_extra_indices:
            # a simpler implementation for efficiency
            # k = (bsz, seq_len, num_heads, head_dim)
            selected_k = k.masked_select(
                extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                    bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
            # selected_k = (bsz, extra_attention_count, num_heads, head_dim)
            # selected_attn_weights = (bsz, seq_len, num_heads, extra_attention_count)
            selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
        else:
            # since the number of extra attention indices varies across
            # the batch, we need to process each element of the batch
            # individually
            flat_selected_k = k.masked_select(
                extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
            selected_attn_weights = torch.ones(
                bsz, seq_len, self.num_heads, max_num_extra_indices_per_batch,
                device=k.device, dtype=k.dtype)
            selected_attn_weights.fill_(-10000.0)
            start = 0
            for i in range(bsz):
                end = start + num_extra_indices_per_batch[i] * self.num_heads * self.head_dim
                # the selected entries for this batch element
                i_selected_k = flat_selected_k[start:end].view(-1, self.num_heads, self.head_dim)
                # (seq_len, num_heads, num extra indices)
                i_selected_attn_weights = torch.einsum(
                    'lhd,shd->lhs', (q[i, :, :, :], i_selected_k))
                selected_attn_weights[i, :, :, :num_extra_indices_per_batch[i]] = \
                    i_selected_attn_weights
                start = end
        # concat to attn_weights
        # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
        attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

    attn_weights_float = F.softmax(attn_weights, dim=-1)
    if key_padding_mask is not None:
        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_weights_float = torch.masked_fill(
            attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                           p=self.dropout, training=self.training)
    v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1).contiguous().float()
    attn = 0
    if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
        selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
        if has_same_length_extra_indices:
            selected_v = v.masked_select(
                extra_attention_mask.unsqueeze(-1).unsqueeze(-1)).view(
                    bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
        else:
            flat_selected_v = v.masked_select(
                extra_attention_mask.unsqueeze(-1).unsqueeze(-1))
            # don't worry about masking since this is multiplied by attn_probs, and masking above
            # before softmax will remove masked entries
            selected_v = torch.zeros(
                bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim,
                device=v.device, dtype=v.dtype)
            start = 0
            for i in range(bsz):
                end = start + num_extra_indices_per_batch[i] * self.num_heads * self.head_dim
                i_selected_v = flat_selected_v[start:end].view(-1, self.num_heads, self.head_dim)
                selected_v[i, :num_extra_indices_per_batch[i], :, :] = i_selected_v
                start = end
        attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
        attn_probs = attn_probs.narrow(
            -1, max_num_extra_indices_per_batch,
            attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()

    attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window,
                              self.attention_dilation, True, 0, False)
    attn = attn.type_as(hidden_states)
    assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
    attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()

    # For this case, we'll just recompute the attention for these indices
    # and overwrite the attn tensor.
    # TODO: remove the redundant computation
    if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
        if has_same_length_extra_indices:
            # query = (seq_len, bsz, dim)
            # extra_attention_mask = (bsz, seq_len)
            # selected_query = (max_num_extra_indices_per_batch, bsz, embed_dim)
            selected_hidden_states = hidden_states.masked_select(
                extra_attention_mask.transpose(0, 1).unsqueeze(-1)).view(
                    max_num_extra_indices_per_batch, bsz, embed_dim)
            # if *_proj_full exists use them, otherwise default to *_proj
            q = self.query_global(selected_hidden_states)
            k = self.key_global(hidden_states)
            v = self.value_global(hidden_states)
            q /= math.sqrt(self.head_dim)

            # (bsz * self.num_heads, max_num_extra_indices_per_batch, head_dim)
            q = q.contiguous().view(
                max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            # (bsz * self.num_heads, seq_len, head_dim)
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            # (bsz * self.num_heads, seq_len, head_dim)
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            attn_weights = torch.bmm(q, k.transpose(1, 2))
            assert list(attn_weights.size()) == [
                bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]
            if key_padding_mask is not None:
                attn_weights = attn_weights.view(
                    bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
                attn_weights = attn_weights.view(
                    bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
            attn_weights_float = F.softmax(attn_weights, dim=-1)
            attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                                   p=self.dropout, training=self.training)
            selected_attn = torch.bmm(attn_probs, v)
            assert list(selected_attn.size()) == [
                bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]

            selected_attn = selected_attn.transpose(0, 1).contiguous().view(
                max_num_extra_indices_per_batch * bsz * embed_dim)
            # now update attn by filling in the relevant indices with selected_attn
            # masked_fill_ only allows floats as values so this doesn't work
            # attn.masked_fill_(extra_attention_mask.transpose(0, 1).unsqueeze(-1), selected_attn)
            attn[extra_attention_mask.transpose(0, 1).unsqueeze(-1).repeat(
                (1, 1, embed_dim))] = selected_attn
        else:
            raise ValueError  # not implemented

    context_layer = attn.transpose(0, 1)
    if self.output_attentions:
        if extra_attention_mask is not None and max_num_extra_indices_per_batch > 0:
            # With global attention, return global attention probabilities only
            # batch_size x num_heads x num_global_attention_tokens x sequence_length
            # which is the attention weights from tokens with global attention to all tokens
            # It does not return local attention
            attn_weights = attn_weights.view(
                bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
        else:
            # without global attention, return local attention probabilities
            # batch_size x num_heads x sequence_length x window_size
            # which is the attention weights of every token attending to its neighbours
            attn_weights = attn_weights.permute(0, 2, 1, 3)
    outputs = (context_layer, attn_weights) if self.output_attentions else (context_layer,)
    return outputs
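# Illustrative sketch (not from the Longformer repo) of the mask convention the docstring
# above describes: callers pass 0/1/2 per token, and the model turns that into the
# -ve / 0 / +ve extended mask of shape (bsz, 1, 1, seq_len) that `forward` squeezes.
# The exact transform inside BertModel.forward may differ; the sign convention is what matters.
import torch

attention_mask = torch.tensor([[2, 1, 1, 0, 0]])  # 2 = global, 1 = local, 0 = padding
extended_mask = (attention_mask[:, None, None, :].float() - 1.0) * 10000.0
m = extended_mask.squeeze(dim=2).squeeze(dim=1)
print(m > 0)   # global attention positions
print(m == 0)  # local attention positions
print(m < 0)   # padding (no attention) positions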
def forward(self, hidden_states, attention_mask=None, head_mask=None):
    '''
    The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to:
        -ve: no attention
          0: local attention
        +ve: global attention
    '''
    if attention_mask is not None:
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
        key_padding_mask = attention_mask < 0
        extra_attention_mask = attention_mask > 0
        remove_from_windowed_attention_mask = attention_mask != 0

        num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
        max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
        if max_num_extra_indices_per_batch <= 0:
            extra_attention_mask = None
        else:
            # To support the case of variable number of global attention in the rows of a batch,
            # we use the following three selection masks to select global attention embeddings
            # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
            # 1) selecting embeddings that correspond to global attention
            extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
            zero_to_max_range = torch.arange(
                0, max_num_extra_indices_per_batch,
                device=num_extra_indices_per_batch.device)
            # mask indicating which values are actually going to be padding
            selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
            # 2) location of the non-padding values in the selected global attention
            selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
            # 3) location of the padding values in the selected global attention
            selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
    else:
        remove_from_windowed_attention_mask = None
        extra_attention_mask = None
        key_padding_mask = None

    hidden_states = hidden_states.transpose(0, 1)
    seq_len, bsz, embed_dim = hidden_states.size()
    assert embed_dim == self.embed_dim
    q = self.query(hidden_states)
    k = self.key(hidden_states)
    v = self.value(hidden_states)
    q /= math.sqrt(self.head_dim)

    q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
    k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
    # attn_weights = (bsz, seq_len, num_heads, window*2+1)
    if self.attention_mode == 'tvm':
        q = q.float().contiguous()
        k = k.float().contiguous()
        attn_weights = diagonaled_mm_tvm(q, k, self.attention_window,
                                         self.attention_dilation, False, 0, False)
    else:  # "sliding_chunks"
        attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
    mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
    if remove_from_windowed_attention_mask is not None:
        # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
        # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
        remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(
            dim=-1).unsqueeze(dim=-1)
        # cast to float/half then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(
            remove_from_windowed_attention_mask, -10000.0)
        repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
        float_mask = float_mask.repeat(1, 1, repeat_size, 1)
        ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
        # diagonal mask with zeros everywhere and -inf in place of padding
        if self.attention_mode == 'tvm':
            d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window,
                                       self.attention_dilation, False, 0, False)
        else:
            d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window,
                                              padding_value=0)
        attn_weights += d_mask
    assert list(attn_weights.size()) == [bsz, seq_len, self.num_heads,
                                         self.attention_window * 2 + 1]

    # the extra attention
    if extra_attention_mask is not None:
        selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch,
                                 self.num_heads, self.head_dim)
        selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
        # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
        selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
        selected_attn_weights[selection_padding_mask_zeros[0], :, :,
                              selection_padding_mask_zeros[1]] = -10000
        # concat to attn_weights
        # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
        attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

    attn_weights_float = F.softmax(attn_weights, dim=-1,
                                   dtype=torch.float32)  # use fp32 for numerical stability
    if key_padding_mask is not None:
        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_weights_float = torch.masked_fill(
            attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                           p=self.dropout, training=self.training)
    v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
    attn = 0
    if extra_attention_mask is not None:
        selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
        selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch,
                                 self.num_heads, self.head_dim)
        selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
        # use `matmul` because `einsum` crashes sometimes with fp16
        # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
        attn = torch.matmul(
            selected_attn_probs.transpose(1, 2),
            selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
        attn_probs = attn_probs.narrow(
            -1, max_num_extra_indices_per_batch,
            attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()

    if self.attention_mode == 'tvm':
        v = v.float().contiguous()
        attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window,
                                  self.attention_dilation, True, 0, False)
    else:  # "sliding_chunks"
        attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)

    attn = attn.type_as(hidden_states)
    assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
    attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()

    # For this case, we'll just recompute the attention for these indices
    # and overwrite the attn tensor.
    # TODO: remove the redundant computation
    if extra_attention_mask is not None:
        selected_hidden_states = hidden_states.new_zeros(
            max_num_extra_indices_per_batch, bsz, embed_dim)
        selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = \
            hidden_states[extra_attention_mask_nonzeros[::-1]]
        q = self.query_global(selected_hidden_states)
        k = self.key_global(hidden_states)
        v = self.value_global(hidden_states)
        q /= math.sqrt(self.head_dim)

        # (bsz * self.num_heads, max_num_extra_indices_per_batch, head_dim)
        q = q.contiguous().view(
            max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        # (bsz * self.num_heads, seq_len, head_dim)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        # (bsz * self.num_heads, seq_len, head_dim)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [
            bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]

        attn_weights = attn_weights.view(
            bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
        attn_weights[selection_padding_mask_zeros[0], :,
                     selection_padding_mask_zeros[1], :] = -10000.0
        if key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0)
        attn_weights = attn_weights.view(
            bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
        attn_weights_float = F.softmax(attn_weights, dim=-1,
                                       dtype=torch.float32)  # use fp32 for numerical stability
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                               p=self.dropout, training=self.training)
        selected_attn = torch.bmm(attn_probs, v)
        assert list(selected_attn.size()) == [
            bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]

        selected_attn_4d = selected_attn.view(
            bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
        nonzero_selected_attn = selected_attn_4d[
            selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
        attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
            len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)

    context_layer = attn.transpose(0, 1)
    if self.output_attentions:
        if extra_attention_mask is not None:
            # With global attention, return global attention probabilities only
            # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
            # which is the attention weights from tokens with global attention to all tokens
            # It does not return local attention
            # In case of a variable number of global attention tokens in the rows of a batch,
            # attn_weights are padded with -10000.0 attention scores
            attn_weights = attn_weights.view(
                bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
        else:
            # without global attention, return local attention probabilities
            # batch_size x num_heads x sequence_length x window_size
            # which is the attention weights of every token attending to its neighbours
            attn_weights = attn_weights.permute(0, 2, 1, 3)
    outputs = (context_layer, attn_weights) if self.output_attentions else (context_layer,)
    return outputs
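# Reference semantics sketch (a naive double loop of my own, not repo code) for the banded
# score layout that `diagonaled_mm_tvm(q, k, w, 1, False, 0, False)` and
# `sliding_chunks_matmul_qk(q, k, w, padding_value=0)` are assumed to produce: entry
# [b, i, h, w + d] holds the dot product of query i with key i + d, for d in [-w, w].
import torch

def banded_qk_reference(q, k, w, padding_value=0.0):
    # q, k: (bsz, seq_len, num_heads, head_dim) -> (bsz, seq_len, num_heads, 2*w+1)
    bsz, seq_len, num_heads, head_dim = q.shape
    out = torch.full((bsz, seq_len, num_heads, 2 * w + 1), padding_value, dtype=q.dtype)
    for i in range(seq_len):
        for d in range(-w, w + 1):
            j = i + d
            if 0 <= j < seq_len:
                out[:, i, :, d + w] = (q[:, i] * k[:, j]).sum(-1)  # per-head dot product
    return out

q = torch.randn(1, 8, 2, 4)
k = torch.randn(1, 8, 2, 4)
print(banded_qk_reference(q, k, w=2).shape)  # torch.Size([1, 8, 2, 5])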
def forward(
    self,
    hidden_states,
    attention_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
    output_attentions=False,
):
    """
    :class:`LongformerSelfAttention` expects `len(hidden_states)` to be a multiple of
    `attention_band`. Padding to `attention_band` happens in :meth:`LongformerModel.forward`
    to avoid redoing the padding on each layer.

    The `attention_mask` is changed in :meth:`LongformerModel.forward` from 0, 1, 2 to:

        * -10000: no attention
        * 0: local attention
        * +10000: global attention
    """
    hidden_states = hidden_states.transpose(0, 1)

    # project hidden states
    query_vectors = self.query(hidden_states)
    key_vectors = self.key(hidden_states)
    value_vectors = self.value(hidden_states)

    seq_len, batch_size, embed_dim = hidden_states.size()
    assert (
        embed_dim == self.embed_dim
    ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

    # normalize query
    query_vectors /= math.sqrt(self.head_dim)

    query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
    key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

    attn_scores = self._sliding_chunks_query_key_matmul(
        query_vectors, key_vectors, self.one_sided_attn_window_size)

    # values to pad for attention probs
    remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]

    # cast to fp32/fp16 then replace 1's with -inf
    float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
        remove_from_windowed_attention_mask, -10000.0)

    # diagonal mask with zeros everywhere and -inf in place of padding
    diagonal_mask = self._sliding_chunks_query_key_matmul(
        float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size)

    # pad local attention probs
    attn_scores += diagonal_mask

    assert list(attn_scores.size()) == [
        batch_size,
        seq_len,
        self.num_heads,
        self.one_sided_attn_window_size * 2 + 1,
    ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"

    # compute local attention probs from global attention keys and concat over window dim
    if is_global_attn:
        # compute global attn indices required throughout the forward fn
        (
            max_num_global_attn_indices,
            is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero,
        ) = self._get_global_attn_indices(is_index_global_attn)
        # calculate global attn probs from global key
        global_key_attn_scores = self._concat_with_global_key_attn_probs(
            query_vectors=query_vectors,
            key_vectors=key_vectors,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
        )
        # concat to local_attn_probs
        # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
        attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)

        # free memory
        del global_key_attn_scores

    attn_probs = F.softmax(attn_scores, dim=-1,
                           dtype=torch.float32)  # use fp32 for numerical stability

    # softmax sometimes inserts NaN if all positions are masked, replace them with 0
    attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
    attn_probs = attn_probs.type_as(attn_scores)

    # free memory
    del attn_scores

    # apply dropout
    attn_probs = F.dropout(attn_probs, p=self.config.attention_probs_dropout_prob,
                           training=self.training)

    value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

    # compute local attention output with global attention value and add
    if is_global_attn:
        # compute sum of global and local attn
        attn_output = self._compute_attn_output_with_global_indices(
            value_vectors=value_vectors,
            attn_probs=attn_probs,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
        )
    else:
        # compute local attn only
        attn_output = self._sliding_chunks_matmul_attn_probs_value(
            attn_probs, value_vectors, self.one_sided_attn_window_size)

    assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
    attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

    # compute value for global attention and overwrite to attention output
    # TODO: remove the redundant computation
    if is_global_attn:
        global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
            hidden_states=hidden_states,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            is_index_masked=is_index_masked,
        )

        # get only non zero global attn output
        nonzero_global_attn_output = global_attn_output[
            is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]]

        # overwrite values with global attention
        attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
            len(is_local_index_global_attn_nonzero[0]), -1)
        # The attention weights for tokens with global attention are
        # just filler values, they were never used to compute the output.
        # Fill with 0 now, the correct values are in 'global_attn_probs'.
        attn_probs[is_index_global_attn_nonzero] = 0

    outputs = (attn_output.transpose(0, 1),)

    if output_attentions:
        outputs += (attn_probs,)

    return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs
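# Sketch of how the boolean inputs to the forward above are derived upstream in the
# HuggingFace model (approximately; the exact call site varies by version) from the
# -10000 / 0 / +10000 mask described in the docstring:
import torch

attention_mask = torch.tensor([[10000.0, 0.0, 0.0, -10000.0]])
is_index_masked = attention_mask < 0                           # padding positions
is_index_global_attn = attention_mask > 0                      # global-attention positions
is_global_attn = is_index_global_attn.flatten().any().item()   # any global token at all?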
def forward(
    self,
    query: torch.Tensor,
    keys: List[torch.Tensor],
    ref_point: torch.Tensor,
    query_mask: torch.Tensor = None,
    key_masks: Optional[torch.Tensor] = None,
):
    """
    :param key_masks:
    :param query_mask:
    :param query: B, H, W, C
    :param keys: List[B, H, W, C]
    :param ref_point: B, H, W, 2
    :return:
    """
    if key_masks is None:
        key_masks = [None] * len(keys)
    assert len(keys) == self.scales

    attns = {'attns': None, 'offsets': None}

    nbatches, query_height, query_width, _ = query.shape

    # B, H, W, C
    query = self.q_proj(query)

    # B, H, W, 2MLK
    offset = self.offset_proj(query)
    # B, H, W, M, 2LK
    offset = offset.view(nbatches, query_height, query_width, self.h, -1)

    # B, H, W, MLK
    A = self.A_proj(query)

    # mask before softmax
    if query_mask is not None:
        # B, H, W, 1
        query_mask_ = query_mask.unsqueeze(dim=-1)
        _, _, _, mlk = A.shape
        query_mask_ = query_mask_.expand(nbatches, query_height, query_width, mlk)
        A = torch.masked_fill(A, mask=query_mask_, value=float('-inf'))

    # B, H, W, M, LK
    A = A.view(nbatches, query_height, query_width, self.h, -1)
    A = F.softmax(A, dim=-1)

    # mask nan position
    if query_mask is not None:
        # B, H, W, 1, 1
        query_mask_ = query_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
        A = torch.masked_fill(A, query_mask_.expand_as(A), 0.0)

    if self.need_attn:
        attns['attns'] = A
        attns['offsets'] = offset

    offset = offset.view(nbatches, query_height, query_width, self.h, self.scales, self.k, 2)
    offset = offset.permute(0, 3, 4, 5, 1, 2, 6).contiguous()
    # B*M, L, K, H, W, 2
    offset = offset.view(nbatches * self.h, self.scales, self.k, query_height, query_width, 2)

    A = A.permute(0, 3, 1, 2, 4).contiguous()
    # B*M, H*W, LK
    A = A.view(nbatches * self.h, query_height * query_width, -1)

    scale_features = []
    for l in range(self.scales):
        feat_map = keys[l]
        _, h, w, _ = feat_map.shape

        key_mask = key_masks[l]

        # B, H, W, 2
        reversed_ref_point = restore_scale(height=h, width=w, ref_point=ref_point)
        # B, H, W, 2 -> B*M, H, W, 2
        reversed_ref_point = reversed_ref_point.repeat(self.h, 1, 1, 1)

        # B, h, w, M, C_v
        scale_feature = self.k_proj(feat_map).view(nbatches, h, w, self.h, self.d_k)

        if key_mask is not None:
            # B, h, w, 1, 1
            key_mask = key_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
            key_mask = key_mask.expand(nbatches, h, w, self.h, self.d_k)
            scale_feature = torch.masked_fill(scale_feature, mask=key_mask, value=0)

        # B, M, C_v, h, w
        scale_feature = scale_feature.permute(0, 3, 4, 1, 2).contiguous()
        # B*M, C_v, h, w
        scale_feature = scale_feature.view(-1, self.d_k, h, w)

        k_features = []
        for k in range(self.k):
            points = reversed_ref_point + offset[:, l, k, :, :, :]
            vgrid_x = 2.0 * points[:, :, :, 0] / max(w - 1, 1) - 1.0
            vgrid_y = 2.0 * points[:, :, :, 1] / max(h - 1, 1) - 1.0
            vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)

            # B*M, C_v, H, W
            feat = F.grid_sample(scale_feature, vgrid_scaled,
                                 mode='bilinear', padding_mode='zeros')
            k_features.append(feat)
        # B*M, K, C_v, H, W
        k_features = torch.stack(k_features, dim=1)
        scale_features.append(k_features)

    # B*M, L, K, C_v, H, W
    scale_features = torch.stack(scale_features, dim=1)

    # B*M, H*W, C_v, LK
    scale_features = scale_features.permute(0, 4, 5, 3, 1, 2).contiguous()
    scale_features = scale_features.view(nbatches * self.h,
                                         query_height * query_width, self.d_k, -1)

    # B*M, H*W, C_v
    feat = torch.einsum('nlds, nls -> nld', scale_features, A)

    # B*M, H*W, C_v -> B, M, H, W, C_v
    feat = feat.view(nbatches, self.h, query_height, query_width, self.d_k)
    # B, M, H, W, C_v -> B, H, W, M, C_v
    feat = feat.permute(0, 2, 3, 1, 4).contiguous()
    # B, H, W, M, C_v -> B, H, W, M * C_v
    feat = feat.view(nbatches, query_height, query_width, self.d_k * self.h)

    feat = self.wm_proj(feat)
    if self.dropout:
        feat = self.dropout(feat)

    return feat, attns
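# Quick check (illustrative, not repo code) of the `vgrid` normalization used in the loop
# above: F.grid_sample takes sampling locations in [-1, 1], and dividing by (size - 1)
# matches align_corners=True semantics (-1 -> pixel 0, +1 -> pixel size-1). The repo call
# omits align_corners, so it presumably assumes the pre-PyTorch-1.3 default of True.
import torch
import torch.nn.functional as F

feat = torch.arange(16.0).view(1, 1, 4, 4)   # one 4x4 feature map
points = torch.tensor([[[[3.0, 0.0]]]])      # sample pixel (x=3, y=0)
vgrid_x = 2.0 * points[..., 0] / max(4 - 1, 1) - 1.0
vgrid_y = 2.0 * points[..., 1] / max(4 - 1, 1) - 1.0
vgrid = torch.stack((vgrid_x, vgrid_y), dim=-1)
print(F.grid_sample(feat, vgrid, mode='bilinear', align_corners=True))  # tensor([[[[3.]]]])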
def forward(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
):
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
    value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension:
        #  - key: [batch_size * self.num_heads, head_dim, kv_length]
        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
        key_layer = torch.cat((past_key, key_layer), dim=2)
        value_layer = torch.cat((past_value, value_layer), dim=1)

    _, _, kv_length = key_layer.shape

    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    # [batch_size * num_heads, q_length, kv_length]
    # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
    matmul_result = alibi.baddbmm(
        batch1=query_layer,
        batch2=key_layer,
        beta=self.beta,
        alpha=self.inv_norm_factor,
    )

    # change view to [batch_size, num_heads, q_length, kv_length]
    attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

    # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
    # [batch_size, num_heads, q_length, kv_length]
    input_dtype = attention_scores.dtype
    # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
    if input_dtype == torch.float16:
        attention_scores = attention_scores.to(torch.float)
    attn_weights = torch.masked_fill(attention_scores, attention_mask,
                                     torch.finfo(attention_scores.dtype).min)
    attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

    # [batch_size, num_heads, q_length, kv_length]
    attention_probs = self.attention_dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # change view [batch_size x num_heads, q_length, kv_length]
    attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

    # matmul: [batch_size * num_heads, q_length, head_dim]
    context_layer = torch.bmm(attention_probs_reshaped, value_layer)

    # change view [batch_size, num_heads, q_length, head_dim]
    context_layer = self._merge_heads(context_layer)

    # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
    if self.pretraining_tp > 1 and self.slow_but_exact:
        slices = self.hidden_size / self.pretraining_tp
        output_tensor = torch.zeros_like(context_layer)
        for i in range(self.pretraining_tp):
            output_tensor = output_tensor + F.linear(
                context_layer[:, :, int(i * slices): int((i + 1) * slices)],
                self.dense.weight[:, int(i * slices): int((i + 1) * slices)],
            )
    else:
        output_tensor = self.dense(context_layer)

    output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

    outputs = (output_tensor, present)
    if output_attentions:
        outputs += (attention_probs,)

    return outputs
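# Sanity sketch of the fused `baddbmm` call above: it computes
# beta * alibi + alpha * (query @ key), i.e. the ALiBi positional bias is folded into the
# same kernel as the QK^T matmul. Shapes below are illustrative.
import torch

bh, q_len, kv_len, d = 2, 3, 3, 4
alibi = torch.randn(bh, q_len, kv_len)
query = torch.randn(bh, q_len, d)
key_t = torch.randn(bh, d, kv_len)
beta, alpha = 1.0, 1.0 / d ** 0.5
fused = alibi.baddbmm(batch1=query, batch2=key_t, beta=beta, alpha=alpha)
manual = beta * alibi + alpha * torch.bmm(query, key_t)
assert torch.allclose(fused, manual, atol=1e-5)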
def _zero_padding_tokens(response_tokens):
    mask = response_tokens == PADDING_TOKEN
    assert not (mask[:, :, 1:] < mask[:, :, :-1]).any().item(), \
        f"Padding tokens not a suffix {to_numpy(response_tokens)}"
    return mask, torch.masked_fill(response_tokens, mask, 0)
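# Hedged usage sketch for `_zero_padding_tokens`: PADDING_TOKEN is a module-level constant
# in the original code; the value 2 here is made up for illustration. The assert enforces
# that padding only ever appears as a suffix along the last dimension.
import torch

PADDING_TOKEN = 2
response_tokens = torch.tensor([[[5, 7, 2, 2],
                                 [9, 2, 2, 2]]])  # (batch, responses, seq_len)
mask, zeroed = _zero_padding_tokens(response_tokens)
print(zeroed)  # tensor([[[5, 7, 0, 0], [9, 0, 0, 0]]])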
def forward(self, x, target_domain, mix_output: bool = False,
            used_domain_list: list = None, mix_weight: torch.Tensor = None,
            domain_mask: torch.Tensor = None):
    """
    confirm the order in domain_weight and used_domain_list is the same
    :param x: [B, L, H]
    :param target_domain: a string, like 'news', 'book', 'bible'
    :param mix_output:
    :param used_domain_list: used domain list, like ['book', 'iwslt']
    :param mix_weight: [B, L, D]
    :param domain_mask: [D]
    :return:
    """
    # if we only use one adapter
    if mix_output is False:
        if target_domain == 'news':
            return x
        else:
            return self.sublayer_connection_for_adapter[target_domain](
                x, self.adapter_layers[target_domain])
    # else we should mix the outputs of the current adapters
    else:
        assert self.check_domain_list_order(used_domain_list) is True

        # first, produce the weight based on the input, or use the provided weight
        if mix_weight is None:
            if domain_mask is None:
                domain_mask = [0] * self.max_domain_num
                for domain in used_domain_list:
                    domain_mask[self.domain_dict[domain]] = 1
                domain_mask = torch.Tensor(domain_mask).to(x.device)

            mix_weight = self.inner_gate[target_domain](x)  # [B, S, D_Max]
            mix_weight = torch.masked_fill(mix_weight, domain_mask == 0, -1e9)
            mix_weight = torch.softmax(mix_weight, dim=-1)

            used_domain_idx = domain_mask.nonzero().flatten()
            # [B, S, D_Max] -> [B, S, D_Used]
            select_mix_weight = mix_weight.index_select(dim=-1, index=used_domain_idx)
        else:
            select_mix_weight = mix_weight

        # calculate the adapter outputs separately
        adapter_outputs = []
        for domain in used_domain_list:
            if domain == 'news':
                # this may not be used, because we use a residual connection
                adapter_outputs.append(x)
            else:
                # note: x is not added back here
                domain_adapter_output = self.sublayer_connection_for_adapter[domain]. \
                    wo_residual_forward(x, self.adapter_layers[domain])
                adapter_outputs.append(domain_adapter_output)

        # mix the outputs
        adapter_outputs = torch.stack(adapter_outputs, dim=-1)  # [B, L, H, D]
        select_mix_weight = select_mix_weight.unsqueeze(-1)  # [B, L, D, 1]
        mix_adapter_outputs = torch.matmul(adapter_outputs, select_mix_weight).squeeze(-1)

        # residual
        return x + mix_adapter_outputs, mix_weight
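# Minimal sketch (made-up values) of the masked gating inside the forward above: logits of
# unused domains are pushed to -1e9 before the softmax, so their mix weights collapse to ~0
# and the remaining mass renormalizes over the active domains only.
import torch

logits = torch.tensor([[1.0, 2.0, 0.5, 3.0]])  # gate output over D_Max domains
domain_mask = torch.tensor([1, 0, 1, 0])       # only domains 0 and 2 are in use
weights = torch.softmax(torch.masked_fill(logits, domain_mask == 0, -1e9), dim=-1)
print(weights)  # ~[[0.62, 0.00, 0.38, 0.00]]
used_idx = domain_mask.nonzero().flatten()
print(weights.index_select(dim=-1, index=used_idx))  # the D_Used columns that get mixed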