Пример #1
0
    def __call__(self, start_logits: torch.Tensor, end_logits: torch.Tensor,
                 match_logits: torch.Tensor,
                 mask: torch.BoolTensor) -> torch.LongTensor:
        """
        Decode (start, end) pair predictions from start, end and match logits.

        A pair `[b, i, j]` is predicted only when its match logit, the start logit at
        `i` and the end logit at `j` are all strictly positive and both positions are
        unmasked.

        # Parameters

        start_logits : `torch.Tensor`, required.
            A tensor of shape (batch_size, seq_len) scoring each position as a start.
        end_logits : `torch.Tensor`, required.
            A tensor of shape (batch_size, seq_len) scoring each position as an end.
        match_logits : `torch.Tensor`, required.
            A tensor of shape (batch_size, seq_len, seq_len) scoring each
            (start, end) pairing.
        mask : `torch.BoolTensor`, required.
            A tensor of shape (batch_size, seq_len); it is converted with `.bool()`
            before use.

        # Returns

        match_preds : `torch.Tensor`
            A tensor of shape (batch_size, seq_len, seq_len). It is produced by
            comparisons combined with `&`, so its dtype is boolean even though the
            signature is annotated with `torch.LongTensor`.
        """
        mask = mask.bool()
        _, seq_len = start_logits.size()

        def along_rows(flags: torch.Tensor) -> torch.Tensor:
            # (batch_size, seq_len) -> (batch_size, seq_len, seq_len),
            # where entry [b, i, j] is flags[b, i].
            return flags.unsqueeze(-1).expand(-1, -1, seq_len)

        def along_cols(flags: torch.Tensor) -> torch.Tensor:
            # (batch_size, seq_len) -> (batch_size, seq_len, seq_len),
            # where entry [b, i, j] is flags[b, j].
            return flags.unsqueeze(1).expand(-1, seq_len, -1)

        # Masking either the pair predictions or the start/end predictions would be
        # enough on its own; both are masked here.
        pair_hits = (match_logits > 0) & along_rows(mask) & along_cols(mask)

        # Each of shape (batch_size, seq_len).
        start_hits = (start_logits > 0) & mask
        end_hits = (end_logits > 0) & mask

        # Final pair prediction: the pairing, its start and its end must all agree.
        return pair_hits & along_rows(start_hits) & along_cols(end_hits)
    def _joint_likelihood(
        self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor
    ) -> torch.Tensor:
        """
        Computes the numerator term for the log-likelihood, which is just score(inputs, tags).

        # Parameters

        logits : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, num_tags) holding the
            per-position score of every tag.
        tags : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length) of integer tag ids, used to
            index `logits` and the transition parameters.
        mask : `torch.BoolTensor`, required.
            A tensor of shape (batch_size, sequence_length); converted with `.bool()`.
            The last-tag lookup uses `mask.sum() - 1` as a position, so this assumes the
            unmasked positions form a prefix of each sequence and that every sequence
            has at least one unmasked position.

        # Returns

        score : `torch.Tensor`
            A tensor of shape (batch_size,) with the score of each tag sequence.
        """
        mask = mask.bool()
        batch_size, sequence_length, _ = logits.shape

        # Transpose batch size and sequence dimensions so that indexing by `i`
        # selects one timestep for the whole batch.
        logits = logits.transpose(0, 1).contiguous()
        mask = mask.transpose(0, 1).contiguous()
        tags = tags.transpose(0, 1).contiguous()

        # Start with the transition scores from start_tag to the first tag in each input
        if self.include_start_end_transitions:
            score = self.start_transitions.index_select(0, tags[0])
        else:
            score = 0.0

        # Add up the scores for the observed transitions and all the inputs but the last
        for i in range(sequence_length - 1):
            # Each is shape (batch_size,)
            current_tag, next_tag = tags[i], tags[i + 1]

            # The scores for transitioning from current_tag to next_tag
            transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)]

            # The score for using current_tag
            emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

            # Include transition score if next element is unmasked,
            # input_score if this element is unmasked.
            score = score + transition_score * mask[i + 1] + emit_score * mask[i]

        # Transition from last state to "stop" state. To start with, we need to find the last tag
        # for each instance: the tag at position (number of unmasked elements - 1).
        last_tag_index = mask.sum(0).long() - 1
        last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0)

        # Compute score of transitioning to `stop_tag` from each "last tag".
        if self.include_start_end_transitions:
            last_transition_score = self.end_transitions.index_select(0, last_tags)
        else:
            last_transition_score = 0.0

        # Add the last input if it's not masked. The loop above stops one step short of
        # the final timestep, so its emission score is added here.
        last_inputs = logits[-1]  # (batch_size, num_tags)
        last_input_score = last_inputs.gather(1, last_tags.view(-1, 1))  # (batch_size, 1)
        # Squeeze only the gathered axis so the batch axis survives when batch_size == 1.
        last_input_score = last_input_score.squeeze(1)  # (batch_size,)

        score = score + last_transition_score + last_input_score * mask[-1]

        return score
Пример #3
0
def masked_index_fill(
    target: torch.Tensor, indices: torch.LongTensor, mask: torch.BoolTensor, fill_value: int = 1
) -> torch.Tensor:
    """
    Fill the positions of `target` named by `indices` with `fill_value`, skipping every
    index whose `mask` entry is false.

    # Parameters

    target : `torch.Tensor`, required.
        A 2 dimensional tensor of shape (batch_size, sequence_length).
        This is the tensor to be filled.
    indices : `torch.LongTensor`, required
        A 2 dimensional tensor of shape (batch_size, num_indices),
        These are the indices that will be filled in the original tensor.
    mask : `torch.Tensor`, required.
        A 2 dimensional tensor of shape (batch_size, num_indices), mask.sum() == `nonzero_indices`.
        It is converted with `.bool()` before use.
    fill_value : `int`, optional (default = `1`)
        The value we fill the tensor with.

    # Returns

    filled_target : `torch.Tensor`
        A tensor with shape (batch_size, sequence_length) where 'indices' are filled with `fill_value`
    """
    keep = mask.bool()
    original_shape = target.size()

    # Masked-out entries are multiplied down to index 0 before being handed to the
    # helper; they are filtered out again below, so they never reach `scatter`.
    # Shape: (batch_size * num_indices)
    # NOTE(review): presumably these are offsets into the flattened target -- confirm
    # against `flatten_and_batch_shift_indices`.
    flat_positions = flatten_and_batch_shift_indices(indices * keep, target.size(1))
    # Shape: (batch_size * num_indices)
    flat_keep = keep.view(-1)
    # Shape: (batch_size * sequence_length, 1)
    flat_target = target.view(-1, 1)
    # Shape: (nonzero_indices, 1)
    chosen = flat_positions[flat_keep].unsqueeze(-1)

    # Write `fill_value` at the chosen rows, then restore the caller's shape.
    return flat_target.scatter(0, chosen, fill_value).reshape(original_shape)
Пример #4
0
 def masked_mean(x: torch.FloatTensor, m: torch.BoolTensor):
     """
     Compute the mean of the entries of `x` selected by `m`.

     # Parameters

     x : `torch.FloatTensor`, required.
         The values to average.
     m : `torch.BoolTensor`, required.
         The selection mask; it is converted with `.bool()` and used to index `x`.

     # Returns

     The mean of `x[m.bool()]`, except that a one-element tensor holding `inf` (on the
     device of `x`) is returned when the number of selected entries equals `len(m)`.
     """
     selector = m.bool()
     # NOTE(review): the `inf` sentinel is returned when the selected count equals
     # `len(m)` (for a 1-D mask: every entry selected), while an empty selection falls
     # through to `mean()` -- confirm this is the intended guard.
     if selector.sum() != len(m):
         return x[selector].mean()
     return torch.full((1, ), fill_value=float('inf'), device=x.device)