from typing import Tuple, Union

import torch

# These helpers are assumed to come from `allennlp.nn.util`; drop this import
# if they are defined locally in this module.
from allennlp.nn.util import (
    get_mask_from_sequence_lengths,
    min_value_of_dtype,
    replace_masked_values,
)


def masked_topk(
    input_: torch.FloatTensor,
    mask: torch.BoolTensor,
    k: Union[int, torch.LongTensor],
    dim: int = -1,
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor]:
    """
    Selects the top-`k` items of `input_` along `dim`, ignoring positions where
    `mask` is False. `k` may be a single int or a tensor of per-slice values
    (same shape as `input_` with dimension `dim` removed). Returns the top
    values, a mask marking which of the `max(k)` slots are valid, and the
    corresponding indices into the original `input_`, sorted in ascending
    order so the selected items keep their original order along `dim`.
    """
    if input_.size() != mask.size():
        raise ValueError("`input_` and `mask` must have the same shape.")
    if not -input_.dim() <= dim < input_.dim():
        raise ValueError("`dim` must be in `[-input_.dim(), input_.dim())`")
    dim = (dim + input_.dim()) % input_.dim()

    max_k = k if isinstance(k, int) else k.max()

    # Move `dim` to the last position so we can work on a 2-D view.
    permutation = list(range(input_.dim()))
    permutation.pop(dim)
    permutation += [dim]

    # Permutation that moves the last dimension back to `dim` afterwards.
    reverse_permutation = list(range(input_.dim() - 1))
    reverse_permutation.insert(dim, -1)

    other_dims_size = list(input_.size())
    other_dims_size.pop(dim)
    permuted_size = other_dims_size + [max_k]  # for restoration

    if isinstance(k, int):
        # Broadcast a scalar `k` to one value per slice.
        k = k * torch.ones(*other_dims_size, dtype=torch.long, device=mask.device)
    else:
        if list(k.size()) != other_dims_size:
            raise ValueError(
                "`k` must have the same shape as `input_` with dimension `dim` removed."
            )

    num_items = input_.size(dim)
    input_ = input_.permute(*permutation).reshape(-1, num_items)
    mask = mask.permute(*permutation).reshape(-1, num_items)
    k = k.reshape(-1)

    # Push masked positions down to the smallest representable value so they
    # never make it into the top-k.
    input_ = replace_masked_values(input_, mask, min_value_of_dtype(input_.dtype))

    _, top_indices = input_.topk(max_k, 1)

    # Mark which of the `max_k` slots fall within each row's own `k`, and fill
    # the remaining slots with that row's largest selected index.
    top_indices_mask = get_mask_from_sequence_lengths(k, max_k).bool()
    fill_value, _ = top_indices.max(dim=1, keepdim=True)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)

    # Restore the original ordering of the selected items.
    top_indices, _ = top_indices.sort(1)

    # Combine "within this row's k" with the original padding mask.
    sequence_mask = mask.gather(1, top_indices)
    top_mask = top_indices_mask & sequence_mask

    top_input = input_.gather(1, top_indices)

    return (
        top_input.reshape(*permuted_size).permute(*reverse_permutation),
        top_mask.reshape(*permuted_size).permute(*reverse_permutation),
        top_indices.reshape(*permuted_size).permute(*reverse_permutation),
    )
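
# Usage sketch (illustrative; not part of the original source).
# `masked_topk` picks the top-k scores per row while ignoring masked positions
# and returns the values, a validity mask, and the order-preserving indices.
# The scores, mask, and helper name `_masked_topk_example` are made up for
# demonstration.
def _masked_topk_example():
    scores = torch.tensor([[0.1, 0.9, 0.3, 0.7, 0.2],
                           [0.8, 0.4, 0.6, 0.0, 0.0]])
    mask = torch.tensor([[True, True, True, True, True],
                         [True, True, True, False, False]])
    top_scores, top_mask, top_indices = masked_topk(scores, mask, k=3)
    # All three outputs have shape [2, 3]; `top_mask` is all True here because
    # both rows contain at least three unmasked items.
    return top_scores, top_mask, top_indices
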
def forward(self, input_ids, past=None, mask: torch.BoolTensor = None, token_type_ids=None, position_ids=None):
    """
    mask: [batch_size, total_length] boolean attention mask over the full
    sequence (cached past plus current input); True marks positions that may
    be attended to.
    """
    # Work out the total sequence length, including any cached `past` states.
    if past is None:
        past_length = input_ids.shape[1]
        past = [None] * 12  # one placeholder per transformer layer (hard-coded to 12 here)
    else:
        # The cached past plus the current input tokens.
        past_length = past[0].shape[3] + input_ids.shape[1]
    if mask is None:
        # Default: every position may be attended to.
        mask = torch.ones(input_ids.shape[0], past_length, dtype=torch.bool, device=input_ids.device)

    # Build the causal (lower-triangular) attention mask: broadcast to
    # [batch, heads, length, length], intersect with its transpose so padded
    # tokens are excluded as both queries and keys, then keep only the lower
    # triangle so each token attends to itself and earlier tokens.
    mask = mask.view(input_ids.shape[0], 1, 1, mask.shape[1]).repeat(1, self.num_attention_heads, mask.shape[1], 1)
    mask = mask & mask.permute(0, 1, 3, 2)
    mask = torch.tril(mask)

    # Token (and optional position) embeddings.
    embedding_output = self.embeddings(input_ids, position_ids=position_ids)

    # Transformer layers.
    last_layer_output, presents = self.encoder(embedding_output, mask=mask, past=past)

    return last_layer_output, presents
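
# Mask-construction sketch (illustrative; not part of the original source).
# The same causal-mask logic as in `forward` above, shown in isolation: a
# [batch, length] padding mask is broadcast to [batch, heads, length, length],
# intersected with its transpose, and lower-triangled so each token attends
# only to itself and earlier, non-padded tokens. The helper name and example
# values are made up for demonstration.
def _causal_mask_example(num_attention_heads: int = 2) -> torch.BoolTensor:
    padding_mask = torch.tensor([[True, True, True, False]])  # last token is padding
    length = padding_mask.shape[1]
    mask = padding_mask.view(1, 1, 1, length).repeat(1, num_attention_heads, length, 1)
    mask = mask & mask.permute(0, 1, 3, 2)
    return torch.tril(mask)  # [1, num_attention_heads, 4, 4], True where attention is allowed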