def __call__(self, start_logits: torch.Tensor, end_logits: torch.Tensor,
             match_logits: torch.Tensor, mask: torch.BoolTensor) -> torch.LongTensor:
    """
    Turn start / end / match logits into a table of predicted (start, end) pairs.

    A logit counts as a positive prediction when it is strictly greater than 0.
    Entry ``[b, i, j]`` of the result is true only when position ``i`` is a
    predicted start, position ``j`` is a predicted end, the pair ``(i, j)`` is a
    predicted match, and both positions are unmasked.

    # Parameters

    start_logits : `torch.Tensor`
        Start scores, shape (batch_size, seq_len).
    end_logits : `torch.Tensor`
        End scores, shape (batch_size, seq_len).
    match_logits : `torch.Tensor`
        Pairwise match scores, shape (batch_size, seq_len, seq_len).
    mask : `torch.BoolTensor`
        Shape (batch_size, seq_len); converted with `.bool()` before use.

    # Returns

    A tensor of shape (batch_size, seq_len, seq_len). It is built only from
    comparisons and `&`, so the values are boolean even though the signature
    is annotated as `torch.LongTensor`.
    """
    valid = mask.bool()
    # Unpacking into exactly two names requires `start_logits` to be 2-D.
    _, seq_len = start_logits.size()
    pair_shape = (-1, seq_len, seq_len)

    # Masking only needs to be applied once: either on the match table or on
    # the start / end predictions. It is applied to the latter here.
    # [batch_size, seq_len]
    start_ok = (start_logits > 0) & valid
    end_ok = (end_logits > 0) & valid

    # [batch_size, seq_len, seq_len]: rows index the start, columns the end.
    start_grid = start_ok.unsqueeze(-1).expand(*pair_shape)
    end_grid = end_ok.unsqueeze(1).expand(*pair_shape)

    # Final match decision: the pair itself must be positive as well.
    return (match_logits > 0) & start_grid & end_grid
def _joint_likelihood(
    self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor
) -> torch.Tensor:
    """
    Computes the numerator term for the log-likelihood, which is just score(inputs, tags).

    # Parameters

    logits : `torch.Tensor`
        Emission scores, shape (batch_size, sequence_length, num_tags).
    tags : `torch.Tensor`
        Integer tag indices, shape (batch_size, sequence_length).
    mask : `torch.BoolTensor`
        Shape (batch_size, sequence_length); converted with `.bool()` before use.

    # Returns

    A tensor of shape (batch_size,) holding, per instance, the sum of the
    emission scores at unmasked positions, the transition scores into unmasked
    positions, and (when `self.include_start_end_transitions` is set) the
    start and end transition scores.

    NOTE(review): the last tag is looked up at index `mask.sum(0) - 1`, so each
    instance is assumed to have at least one unmasked position, with the unmasked
    positions forming a prefix of the sequence -- confirm against callers.
    """
    mask = mask.bool()
    batch_size, sequence_length, _ = logits.shape

    # Transpose batch size and sequence dimensions so that indexing by
    # timestep yields one (batch_size, ...) slice.
    logits = logits.transpose(0, 1).contiguous()
    mask = mask.transpose(0, 1).contiguous()
    tags = tags.transpose(0, 1).contiguous()

    # Start with the transition scores from start_tag to the first tag in each input
    if self.include_start_end_transitions:
        score = self.start_transitions.index_select(0, tags[0])
    else:
        score = 0.0

    # Add up the scores for the observed transitions and all the inputs but the last
    for i in range(sequence_length - 1):
        # Each is shape (batch_size,)
        current_tag, next_tag = tags[i], tags[i + 1]

        # The scores for transitioning from current_tag to next_tag
        transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)]

        # The score for using current_tag
        emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

        # Include transition score if next element is unmasked,
        # input_score if this element is unmasked.
        score = score + transition_score * mask[i + 1] + emit_score * mask[i]

    # Transition from last state to "stop" state. To start with, we need to find the last tag
    # for each instance.
    last_tag_index = mask.sum(0).long() - 1
    last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0)

    # Compute score of transitioning to `stop_tag` from each "last tag".
    if self.include_start_end_transitions:
        last_transition_score = self.end_transitions.index_select(0, last_tags)
    else:
        last_transition_score = 0.0

    # Add the last input if it's not masked.
    last_inputs = logits[-1]  # (batch_size, num_tags)
    last_input_score = last_inputs.gather(1, last_tags.view(-1, 1))  # (batch_size, 1)
    # Squeeze only the gathered dimension; a bare squeeze() would also drop
    # the batch dimension when batch_size == 1.
    last_input_score = last_input_score.squeeze(1)  # (batch_size,)

    score = score + last_transition_score + last_input_score * mask[-1]

    return score
def masked_index_fill(
    target: torch.Tensor, indices: torch.LongTensor, mask: torch.BoolTensor, fill_value: int = 1
) -> torch.Tensor:
    """
    Fill the positions of `target` named by `indices` with `fill_value`, skipping
    every index whose `mask` entry is false.

    # Parameters

    target : `torch.Tensor`, required.
        A 2 dimensional tensor of shape (batch_size, sequence_length).
        This is the tensor to be filled. It is not modified; the out-of-place
        `scatter` is used.
    indices : `torch.LongTensor`, required
        A 2 dimensional tensor of shape (batch_size, num_indices),
        These are the indices that will be filled in the original tensor.
    mask : `torch.Tensor`, required.
        A 2 dimensional tensor of shape (batch_size, num_indices); converted with
        `.bool()` before use. Only indices whose mask entry is true are filled.
    fill_value : `int`, optional (default = `1`)
        The value we fill the tensor with.

    # Returns

    filled_target : `torch.Tensor`
        A tensor with shape (batch_size, sequence_length) where the unmasked
        'indices' are filled with `fill_value`.
    """
    keep = mask.bool()
    original_shape = target.size()

    # Masked-out indices are multiplied to 0 before shifting; they are dropped
    # again by the boolean selection below, so they never reach `scatter`.
    # NOTE(review): `flatten_and_batch_shift_indices` is defined elsewhere --
    # presumably it offsets each row's indices so they address the flattened
    # target and returns shape (batch_size * num_indices,); confirm.
    shifted = flatten_and_batch_shift_indices(indices * keep, target.size(1))

    # Shape: (number of true mask entries, 1)
    selected = shifted[keep.view(-1)].unsqueeze(-1)

    # Shape: (batch_size * sequence_length, 1)
    column = target.view(-1, 1)

    return column.scatter(0, selected, fill_value).reshape(original_shape)
def masked_mean(x: torch.FloatTensor, m: torch.BoolTensor):
    """
    Compute the mean of the entries of `x` where `m` is true.

    `m` is converted with `.bool()` before use. When the number of true entries
    equals `len(m)` (the size of `m`'s first dimension), a one-element tensor
    holding `inf` is returned on `x`'s device instead of the mean.

    NOTE(review): returning `inf` for that case is reproduced from the existing
    behaviour; confirm with callers that this is the intended sentinel.
    """
    selector = m.bool()
    selected_count = selector.sum()
    if selected_count == len(m):
        return torch.full((1, ), fill_value=float('inf'), device=x.device)
    picked = x[selector]
    return picked.mean()