from typing import Tuple

import oneflow as flow


def get_extended_attention_mask(
    self,
    attention_mask: flow.Tensor,
    input_shape: Tuple[int],
    device: flow.device,
) -> flow.Tensor:
    # We can provide a self-attention mask of dimensions
    # [batch_size, from_seq_length, to_seq_length] ourselves, in which case
    # we just need to make it broadcastable to all heads.
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        # Provided a padding mask of dimensions [batch_size, seq_length]:
        # - if the model is a decoder, apply a causal mask in addition to
        #   the padding mask
        # - if the model is an encoder, make the mask broadcastable to
        #   [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder:
            batch_size, seq_length = input_shape
            seq_ids = flow.arange(seq_length, device=device)
            causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                           <= seq_ids[None, :, None])
            causal_mask = causal_mask.to(attention_mask.dtype)
            # In case past_key_values are used, prepend a prefix of ones to
            # the causal mask so cached positions remain visible.
            if causal_mask.shape[1] < attention_mask.shape[1]:
                prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                causal_mask = flow.cat(
                    [
                        flow.ones(
                            (batch_size, seq_length, prefix_seq_len),
                            device=device,
                            dtype=causal_mask.dtype,
                        ),
                        causal_mask,
                    ],
                    axis=-1,
                )
            extended_attention_mask = (causal_mask[:, None, :, :]
                                       * attention_mask[:, None, None, :])
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
            f"Wrong shape for input_ids (shape {input_shape}) "
            f"or attention_mask (shape {attention_mask.shape})"
        )
    # Convert the 0/1 mask into an additive bias: kept positions contribute
    # 0.0, masked positions a large negative value that vanishes after softmax.
    extended_attention_mask = extended_attention_mask.to(dtype=flow.float)
    extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
    return extended_attention_mask
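# A minimal standalone sketch of the decoder branch above, assuming a batch
# of 1 and a sequence length of 4 (the tensor values are illustrative, not
# from the original code). It shows how the causal mask combines with a
# padding mask and broadcasts over the head dimension.
import oneflow as flow

batch_size, seq_length = 1, 4
attention_mask = flow.tensor([[1, 1, 1, 0]])  # last position is padding

seq_ids = flow.arange(seq_length)
# Lower-triangular causal mask: query position i may attend to keys <= i.
causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
               <= seq_ids[None, :, None]).to(attention_mask.dtype)

# The elementwise product applies both constraints; shapes broadcast to
# [batch_size, 1, seq_length, seq_length].
extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended.shape)  # flow.Size([1, 1, 4, 4])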
def get_extended_attention_mask(
    self,
    attention_mask: flow.Tensor,
    input_ids: flow.Tensor,
):
    # A 3D mask [batch_size, from_seq_length, to_seq_length] only needs a
    # head dimension to become broadcastable.
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        # A 2D padding mask [batch_size, seq_length] is broadcast over both
        # the head and the query dimensions.
        extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError("Wrong shape for input_ids (shape {}) "
                         "or attention_mask (shape {})".format(
                             input_ids.shape, attention_mask.shape))
    # Match the parameter dtype for fp16 compatibility, then turn the 0/1
    # mask into an additive bias.
    extended_attention_mask = extended_attention_mask.to(
        dtype=next(self.parameters()).dtype)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    return extended_attention_mask
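# A minimal sketch, assuming the usual scaled-dot-product attention, of how
# the additive mask produced above is consumed: it is added to the raw
# attention scores before the softmax, so masked positions (bias -10000.0
# here, -1e9 in the first variant) receive effectively zero weight. The
# shapes and values below are hypothetical.
import oneflow as flow

batch_size, num_heads, seq_length = 1, 2, 4
scores = flow.randn(batch_size, num_heads, seq_length, seq_length)

attention_mask = flow.tensor([[1, 1, 1, 0]])  # [batch_size, seq_length]
extended = (1.0 - attention_mask[:, None, None, :].to(flow.float)) * -10000.0

probs = flow.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])  # last entry ~0.0: the padded key gets no attention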