import torch


def forward(
    self,
    input_ids: torch.LongTensor,
    attention_mask: torch.BoolTensor = None,
    position_ids: torch.LongTensor = None,
):
    if attention_mask is None:
        attention_mask = input_ids != self.pad_token_id
    # Prepend an all-ones block so the persistent memory slots are always attendable.
    attention_mask = torch.cat(
        [
            torch.ones(
                input_ids.size(0), self.persistent_mem_size, device=input_ids.device
            ).bool(),
            attention_mask,
        ],
        dim=1,
    )
    # Expand to (batch_size, num_heads, total_length, total_length) and intersect the
    # mask with its transpose so a position only attends to other non-padding positions.
    extended_attention_mask = attention_mask.view(
        input_ids.shape[0], 1, 1, attention_mask.shape[1]
    ).repeat(1, self.num_attention_heads, attention_mask.shape[1], 1)
    extended_attention_mask = extended_attention_mask & extended_attention_mask.permute(0, 1, 3, 2)

    embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids)
    hidden_states = self.encoder(
        hidden_states=embedding_output, attention_mask=extended_attention_mask
    )
    return hidden_states

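# Usage sketch (not part of the original module): shows how the padding-aware,
# symmetric mask above behaves on a toy batch. `pad_token_id=0`,
# `persistent_mem_size=2` and a single attention head are assumptions chosen
# for illustration, not values taken from the model.
def example_persistent_memory_mask():
    pad_token_id, persistent_mem_size, num_heads = 0, 2, 1
    input_ids = torch.tensor([[5, 6, 0]])                     # last token is padding
    attention_mask = input_ids != pad_token_id                # (1, 3)
    attention_mask = torch.cat(
        [torch.ones(1, persistent_mem_size, dtype=torch.bool), attention_mask], dim=1
    )                                                         # (1, 5); memory slots always visible
    extended = attention_mask.view(1, 1, 1, 5).repeat(1, num_heads, 5, 1)
    extended = extended & extended.permute(0, 1, 3, 2)        # both query and key must be unmasked
    print(extended[0, 0].int())                               # row/column 4 (the pad token) is all zeros
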
def reset_states(self, mask: torch.BoolTensor = None) -> None:
    """
    Resets the internal states of a stateful encoder.

    # Parameters

    mask : `torch.BoolTensor`, optional.
        A tensor of shape `(batch_size,)` indicating which states should
        be reset. If not provided, all states will be reset.
    """
    if mask is None:
        self._states = None
    else:
        # state has shape (num_layers, batch_size, hidden_size). We reshape
        # mask to have shape (1, batch_size, 1) so that operations
        # broadcast properly.
        mask_batch_size = mask.size(0)
        mask = mask.view(1, mask_batch_size, 1)
        new_states = []
        for old_state in self._states:
            old_state_batch_size = old_state.size(1)
            if old_state_batch_size != mask_batch_size:
                raise ValueError(
                    f"Trying to reset states using mask with incorrect batch size. "
                    f"Expected batch size: {old_state_batch_size}. "
                    f"Provided batch size: {mask_batch_size}."
                )
            new_state = ~mask * old_state
            new_states.append(new_state.detach())
        self._states = tuple(new_states)

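# Usage sketch (not from the original class): demonstrates the broadcasting
# behind `~mask * old_state` on a fake states tuple. The shapes (2 layers,
# batch of 3, hidden size 4) are arbitrary assumptions for illustration.
def example_reset_states_masking():
    # e.g. (hidden, cell) states of an LSTM, each (num_layers, batch_size, hidden_size)
    states = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
    mask = torch.tensor([True, False, True])        # reset sequences 0 and 2, keep 1
    mask = mask.view(1, 3, 1)                        # broadcast over layers and hidden dims
    new_states = tuple((~mask * s).detach() for s in states)
    assert torch.all(new_states[0][:, 0] == 0)                    # sequence 0 reset to zero
    assert torch.equal(new_states[0][:, 1], states[0][:, 1])      # sequence 1 left untouched
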
def forward(self, input_ids, past=None, mask: torch.BoolTensor = None,
            token_type_ids=None, position_ids=None):
    """
    mask: [batch_size, seq_length] attention mask
    """
    # Determine the total sequence length (cached past + current input) and
    # initialize `past` with one empty entry per layer (12 layers hard-coded here).
    if past is None:
        past_length = input_ids.shape[1]
        past = [None] * 12
    else:
        # include the current input's length on top of the cached length
        past_length = past[0].shape[3] + input_ids.shape[1]
    if mask is None:
        mask = torch.ones(
            input_ids.shape[0], past_length, dtype=torch.bool, device=input_ids.device
        )

    # Fast way to compute the lower-triangular (causal) attention mask: broadcast the
    # padding mask into a square matrix, intersect it with its transpose, then keep
    # only the lower triangle.
    mask = mask.view(input_ids.shape[0], 1, 1, mask.shape[1]).repeat(
        1, self.num_attention_heads, mask.shape[1], 1
    )
    mask = mask & mask.permute(0, 1, 3, 2)
    mask = torch.tril(mask)

    # calculate embedding output
    embedding_output = self.embeddings(input_ids, position_ids=position_ids)

    # Transformer layers
    last_layer_output, presents = self.encoder(embedding_output, mask=mask, past=past)
    return last_layer_output, presents

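# Usage sketch (not from the original code): a toy illustration of the causal mask
# trick above with one head and a batch of one. The final position is treated as
# padding to show how the padding mask and `torch.tril` combine.
def example_causal_mask():
    mask = torch.tensor([[True, True, True, False]])        # last position is padding
    m = mask.view(1, 1, 1, 4).repeat(1, 1, 4, 1)             # (batch, heads, 4, 4)
    m = m & m.permute(0, 1, 3, 2)                             # padding-aware, symmetric
    m = torch.tril(m)                                         # causal: no attending to the future
    print(m[0, 0].int())
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [0, 0, 0, 0]], dtype=torch.int32)
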
def masked_index_fill(
    target: torch.Tensor, indices: torch.LongTensor, mask: torch.BoolTensor, fill_value: int = 1
) -> torch.Tensor:
    """
    The given `indices` in `target` will be filled with `fill_value` given a `mask`.

    # Parameters

    target : `torch.Tensor`, required.
        A 2 dimensional tensor of shape (batch_size, sequence_length).
        This is the tensor to be filled.
    indices : `torch.LongTensor`, required
        A 2 dimensional tensor of shape (batch_size, num_indices).
        These are the indices that will be filled in the original tensor.
    mask : `torch.BoolTensor`, required.
        A 2 dimensional tensor of shape (batch_size, num_indices),
        mask.sum() == `nonzero_indices`.
    fill_value : `int`, optional (default = `1`)
        The value we fill the tensor with.

    # Returns

    filled_target : `torch.Tensor`
        A tensor with shape (batch_size, sequence_length) where `indices` are filled with `fill_value`
    """
    mask = mask.bool()
    prev_shape = target.size()
    # Shape: (batch_size * num_indices)
    flattened_indices = flatten_and_batch_shift_indices(indices * mask, target.size(1))
    # Shape: (batch_size * num_indices)
    mask = mask.view(-1)
    # Shape: (batch_size * sequence_length, 1)
    flattened_target = target.view(-1, 1)
    # Shape: (nonzero_indices, 1)
    unmasked_indices = flattened_indices[mask].unsqueeze(-1)
    flattened_target = flattened_target.scatter(0, unmasked_indices, fill_value)
    filled_target = flattened_target.reshape(prev_shape)
    return filled_target

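# Usage sketch (not part of the original module). `masked_index_fill` expects
# `flatten_and_batch_shift_indices` to be in scope; the minimal stand-in below is
# an assumption, included only so the example runs on its own.
def flatten_and_batch_shift_indices(indices: torch.LongTensor, sequence_length: int) -> torch.Tensor:
    # Offset each row's indices by row_index * sequence_length and flatten, so the
    # result can index a (batch_size * sequence_length,) view of the target.
    offsets = torch.arange(indices.size(0), device=indices.device) * sequence_length
    return (indices + offsets.unsqueeze(1)).view(-1)


def example_masked_index_fill():
    target = torch.zeros(2, 5, dtype=torch.long)
    indices = torch.tensor([[1, 3, 0], [2, 4, 0]])
    mask = torch.tensor([[True, True, False], [True, False, False]])
    # Only the unmasked indices are written: (0, 1), (0, 3) and (1, 2).
    print(masked_index_fill(target, indices, mask, fill_value=1))
    # tensor([[0, 1, 0, 1, 0],
    #         [0, 0, 1, 0, 0]])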