def __call__(self, logits: torch.Tensor,
             mask: torch.ByteTensor) -> torch.LongTensor:
    """
    Decode sequence logits into label indices, one sequence at a time.

    For each sequence, the unmasked timesteps of ``logits`` are passed to
    ``BIO.decode_one_sequence_logits_to_label`` together with
    ``self._label_vocabulary``. The returned label indices are scattered back
    into a (seq_len,) row whose masked positions hold
    ``self._label_vocabulary.padding_index``. (The original author describes
    the decoding as a per-timestep max; the exact rule lives in the BIO
    helper, which is not visible here.)

    :param logits: shape (batch_size, seq_len, num_label).
    :param mask: shape (batch_size, seq_len), values 0 or 1. ``None`` means
        every timestep is valid.
    :return: label indices of shape (batch_size, seq_len). Masked positions
        contain the padding index, so use ``mask`` to extract the real labels.
    :raises RuntimeError: if ``logits`` is not 3-D, ``mask`` is not 2-D, or
        ``mask`` does not match the first two dimensions of ``logits``.
    """
    if logits.dim() != 3:
        raise RuntimeError(
            f"logits shape 错误, 应该是 (B, seq_len, num_label), "
            f"而现在是 {logits.shape}")

    if (mask is not None) and (mask.dim() != 2):
        raise RuntimeError(f"mask shape 错误, 应该是 (B, seq_len), "
                           f"而现在是 {mask.shape}")

    if mask is None:
        # Build the default mask on the logits' device; a CPU mask would make
        # masked_select fail for GPU logits.
        mask = torch.ones(size=(logits.size(0), logits.size(1)),
                          dtype=torch.long, device=logits.device)
    elif mask.shape != logits.shape[:2]:
        # Without this check, zip() below would silently drop batch items
        # when the batch sizes differ.
        raise RuntimeError(
            f"mask shape {tuple(mask.shape)} does not match logits "
            f"(B, seq_len) {tuple(logits.shape[:2])}")

    # Keep every per-sequence tensor on one device so masked_select and
    # masked_scatter below never mix devices.
    mask_bool = mask.to(device=logits.device).bool()
    num_label = logits.size(-1)

    batch_indices = []
    for sequence_logits, sequence_mask1d in zip(logits, mask_bool):
        # (seq_len,) -> (seq_len, 1) so the mask broadcasts over the label dim.
        sequence_mask2d = sequence_mask1d.unsqueeze(-1)

        # Keep only the valid timesteps: (num_valid, num_label).
        valid_logits = torch.masked_select(
            sequence_logits, mask=sequence_mask2d).view(-1, num_label)

        _, label_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=valid_logits,
            vocabulary=self._label_vocabulary)

        label_indices = torch.tensor(label_indices,
                                     dtype=torch.long,
                                     device=logits.device)

        # Start from an all-padding row and write the decoded indices into
        # the valid positions, in order.
        padding = torch.full_like(
            sequence_mask1d,
            fill_value=self._label_vocabulary.padding_index,
            dtype=torch.long)
        batch_indices.append(
            padding.masked_scatter(sequence_mask1d, label_indices))

    if not batch_indices:
        # torch.stack([]) raises; an empty batch decodes to an empty result.
        return torch.empty((0, logits.size(1)), dtype=torch.long,
                           device=logits.device)

    return torch.stack(batch_indices, dim=0)
예제 #2
0
    def _compute_normalizer(self, emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].bool().all()
        mask = mask.bool()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)
예제 #3
0
파일: bio.py 프로젝트: freedomkite/easytext
def decode_label_index_to_span(
        batch_sequence_label_index: torch.Tensor, mask: torch.ByteTensor,
        vocabulary: LabelVocabulary) -> List[List[Dict]]:
    """
    Decode a batch of label-index sequences into spans.

    Example with vocabulary (B-T: 0, I-T: 1, O: 2) and
    ``batch_sequence_label_index`` of shape (B, seq_len)::

        [[0, 1, 2],
         [2, 0, 1]]

    corresponds to the label sequences::

        [[B, I, O],
         [O, B, I]]

    and decodes to::

        [[{"label": T, "begin": 0, "end": 2}],
         [{"label": T, "begin": 1, "end": 3}]]

    (in the example, ``end`` is one past the last token of the span).

    :param batch_sequence_label_index: shape (B, seq_len), label indices.
    :param mask: shape (B, seq_len) mask for ``batch_sequence_label_index``.
        ``None`` means every position is valid.
    :param vocabulary: label vocabulary used to map indices to label strings.
    :return: one list of spans per batch item, as produced by
        ``decode_one_sequence_label_to_span``.
    :raises RuntimeError: if ``mask`` does not match the label tensor's shape.
    """
    label_shape = batch_sequence_label_index.shape[:2]

    if mask is None:
        # Create the default mask on the labels' device so masked_select
        # also works for GPU input.
        mask = torch.ones(size=(label_shape[0], label_shape[1]),
                          dtype=torch.long,
                          device=batch_sequence_label_index.device)
    elif mask.shape[:2] != label_shape:
        # zip() below would otherwise silently drop batch items when the
        # batch sizes differ.
        raise RuntimeError(
            f"mask shape {tuple(mask.shape)} does not match label index "
            f"shape {tuple(batch_sequence_label_index.shape)}")

    mask = mask.to(device=batch_sequence_label_index.device).bool()

    spans = []
    for sequence_label_index, mask1d in zip(batch_sequence_label_index, mask):
        # Keep only the valid positions, then map indices to label strings.
        label_indices = torch.masked_select(sequence_label_index,
                                            mask=mask1d).tolist()
        sequence_label = [vocabulary.token(index) for index in label_indices]
        spans.append(
            decode_one_sequence_label_to_span(sequence_label=sequence_label))

    return spans
예제 #4
0
    def _greedy_search(
            self, query: torch.FloatTensor, key: torch.FloatTensor,
            edge_head_score: torch.FloatTensor,
            edge_head_mask: torch.ByteTensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict edge heads and labels.
        :param query: [batch_size, query_length, query_vector_dim]
        :param key:  [batch_size, key_length, key_vector_dim]
        :param edge_head_score:  [batch_size, query_length, key_length]
        :param edge_head_mask:  None or [batch_size, query_length, key_length]
        :return:
            edge_head: [batch_size, query_length]
            edge_type: [batch_size, query_length]
        """
        edge_head_score = edge_head_score.masked_fill_(~edge_head_mask.bool(),
                                                       self._minus_inf)
        _, edge_head = edge_head_score.max(dim=2)

        edge_type_score = self._get_edge_type_score(query, key, edge_head)
        _, edge_type = edge_type_score.max(dim=2)

        return edge_head, edge_type
예제 #5
0
    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].bool().all()
        mask = mask.bool()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)
        best_tags_list = [
            item + [-1] * (seq_length - len(item)) for item in best_tags_list
        ]
        best_tags_list = torch.from_numpy(np.array(best_tags_list)).long()
        return torch.LongTensor(best_tags_list).cuda()