Example #1
def test_allowed_transitions():
    """
    Test the allowed-transition mask pairs.
    :return:
    """

    label_vocabulary = LabelVocabulary(labels=[["B-L1", "I-L1", "B-L2", "I-L2", "O"]],
                                       padding=LabelVocabulary.PADDING)

    allowed_pairs = BIO.allowed_transitions(label_vocabulary=label_vocabulary)

    for from_idx, to_idx in allowed_pairs:

        if from_idx == label_vocabulary.label_size:
            from_label = "START"
        else:
            from_label = label_vocabulary.token(from_idx)

        if to_idx == label_vocabulary.label_size + 1:
            to_label = "STOP"
        else:
            to_label = label_vocabulary.token(to_idx)
        print(f"(\"{from_label}\", \"{to_label}\"),")

    expect_transition_labels = [
        ("B-L1", "B-L1"), ("B-L1", "I-L1"), ("B-L1", "B-L2"), ("B-L1", "O"), ("B-L1", "STOP"),
        ("I-L1", "B-L1"), ("I-L1", "I-L1"), ("I-L1", "B-L2"), ("I-L1", "O"), ("I-L1", "STOP"),
        ("B-L2", "B-L1"), ("B-L2", "B-L2"), ("B-L2", "I-L2"), ("B-L2", "O"), ("B-L2", "STOP"),
        ("I-L2", "B-L1"), ("I-L2", "B-L2"), ("I-L2", "I-L2"), ("I-L2", "O"), ("I-L2", "STOP"),
        ("O", "B-L1"), ("O", "B-L2"), ("O", "O"), ("O", "STOP"),
        ("START", "B-L1"), ("START", "B-L2"), ("START", "O")]


    expect = list()

    for from_label, to_label in expect_transition_labels:
        if from_label == "START":
            from_idx = label_vocabulary.label_size
        else:
            from_idx = label_vocabulary.index(from_label)

        if to_label == "STOP":
            to_idx = label_vocabulary.label_size + 1
        else:
            to_idx = label_vocabulary.index(to_label)

        expect.append((from_idx, to_idx))

    ASSERT.assertSetEqual(set(expect), set(allowed_pairs))
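
As a quick follow-up (a sketch only, reusing the same LabelVocabulary and BIO.allowed_transitions calls as in the test above), one can also check that a transition forbidden under BIO, such as "O" -> "I-L1", is absent from the returned pairs:

label_vocabulary = LabelVocabulary(labels=[["B-L1", "I-L1", "B-L2", "I-L2", "O"]],
                                   padding=LabelVocabulary.PADDING)
allowed_pairs = BIO.allowed_transitions(label_vocabulary=label_vocabulary)

# "O" -> "I-L1" breaks the BIO scheme, so its index pair must not appear.
forbidden_pair = (label_vocabulary.index("O"), label_vocabulary.index("I-L1"))
assert forbidden_pair not in set(allowed_pairs)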
Example #2
def allowed_transitions(
        label_vocabulary: LabelVocabulary) -> List[Tuple[int, int]]:
    """
    Given a label vocabulary, compute the transition-matrix mask for the BIO scheme.
    Under BIO, transitions such as ["O", "I-Per"] or ["B-Per", "I-Loc"] are not
    allowed. This function computes that mask and returns a
    (from_label_id, to_label_id) pair for every allowed transition.

    Note that the returned allowed pairs include START and STOP; their indices are
    label_vocabulary.label_size and label_vocabulary.label_size + 1, respectively.

    Also note that label_vocabulary.padding_index equals label_vocabulary.label_size.
    Although the two values coincide, START is only used in this transition-matrix
    context, so there is no conflict.

    :param label_vocabulary: the label vocabulary
    :return: list of all allowed (from_label_id, to_label_id) transition pairs.
    """

    num_labels = label_vocabulary.label_size
    start_index = num_labels
    end_index = num_labels + 1

    labels = [(i, label_vocabulary.token(i))
              for i in range(label_vocabulary.label_size)]
    labels_with_boundaries = labels + [(start_index, "START"),
                                       (end_index, "END")]

    allowed = []
    for from_label_index, from_label in labels_with_boundaries:
        if from_label in ("START", "END"):
            from_tag = from_label
            from_entity = ""
        else:
            from_tag = from_label[0]
            from_entity = from_label[1:]
        for to_label_index, to_label in labels_with_boundaries:
            if to_label in ("START", "END"):
                to_tag = to_label
                to_entity = ""
            else:
                to_tag = to_label[0]
                to_entity = to_label[1:]
            if _is_transition_allowed(from_tag, from_entity, to_tag,
                                      to_entity):
                allowed.append((from_label_index, to_label_index))
    return allowed
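
_is_transition_allowed is not shown in this example. Below is a minimal sketch of the BIO rule it has to encode; the signature follows the call above, but the body is an assumption, not the library's actual implementation:

def _is_transition_allowed(from_tag: str, from_entity: str,
                           to_tag: str, to_entity: str) -> bool:
    # Sketch only: the BIO transition rule, inferred from the caller above.
    if to_tag == "START" or from_tag == "END":
        # Nothing transitions into START or out of END.
        return False
    if from_tag == "START":
        # A sequence may start with "O" or "B-*", never "I-*".
        return to_tag in ("O", "B")
    if to_tag == "END":
        # Any real tag may end the sequence.
        return from_tag in ("O", "B", "I")
    if to_tag in ("O", "B"):
        # "O" and "B-*" may follow any real tag.
        return True
    # to_tag == "I": the current span must be continued with the same entity.
    return from_tag in ("B", "I") and from_entity == to_entity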
Example #3
def decode_label_index_to_span(
        batch_sequence_label_index: torch.Tensor, mask: torch.ByteTensor,
        vocabulary: LabelVocabulary) -> List[List[Dict]]:
    """
    Decode label indices into spans.

    batch_sequence_label_index shape: (B, seq_len)  (B-T: 0, I-T: 1, O: 2)
    [[0, 1, 2],
     [2, 0, 1]]

    The corresponding label sequences are:
    [[B, I, O],
     [O, B, I]]

    which decode to:

    [[{"label": T, "begin": 0, "end": 2}],
     [{"label": T, "begin": 1, "end": 3}]]

    :param batch_sequence_label_index: shape: (B, seq_len), sequence of label indices
    :param mask: mask over batch_sequence_label_index
    :param vocabulary: the label vocabulary
    :return: the decoded list of spans
    """

    spans = list()
    batch_size = batch_sequence_label_index.size(0)

    if mask is None:
        mask = torch.ones(size=(batch_sequence_label_index.shape[0],
                                batch_sequence_label_index.shape[1]),
                          dtype=torch.long)

    sequence_lengths = mask.sum(dim=-1)

    for i in range(batch_size):
        label_indices = batch_sequence_label_index[
            i, :sequence_lengths[i]].tolist()

        sequence_label = [vocabulary.token(index) for index in label_indices]

        span = decode_one_sequence_label_to_span(sequence_label=sequence_label)
        spans.append(span)

    return spans
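
A usage sketch matching the docstring example, under the assumption that the LabelVocabulary built below assigns "B-T", "I-T", "O" the indices 0, 1, 2:

import torch

vocabulary = LabelVocabulary(labels=[["B-T", "I-T", "O"]],
                             padding=LabelVocabulary.PADDING)

batch_sequence_label_index = torch.tensor([[0, 1, 2],
                                           [2, 0, 1]])
mask = torch.ones_like(batch_sequence_label_index, dtype=torch.uint8)

spans = decode_label_index_to_span(
    batch_sequence_label_index=batch_sequence_label_index,
    mask=mask,
    vocabulary=vocabulary)

# Expected, per the docstring:
# [[{"label": "T", "begin": 0, "end": 2}],
#  [{"label": "T", "begin": 1, "end": 3}]]
print(spans)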
Example #4
    def __init__(self, label_vocabulary: LabelVocabulary) -> None:
        """
        Initialize.
        :param label_vocabulary: the label vocabulary
        """
        labels = set()

        # Extract "Label" from the "B-Label" / "I-Label" tags
        for index in range(label_vocabulary.label_size):
            bio_label: str = label_vocabulary.token(index)

            if bio_label == "O":
                continue

            label = bio_label.split("-")[1]
            labels.add(label)

        labels = list(labels)

        super().__init__(labels=labels)
        self.label_vocabulary = label_vocabulary
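
For illustration, the loop above reduces a BIO tag set to its entity names; the same extraction written standalone, without a LabelVocabulary:

bio_labels = ["B-Per", "I-Per", "B-Loc", "I-Loc", "O"]

# Drop "O" and strip the "B-"/"I-" prefixes, keeping only the entity names.
entity_labels = sorted({label.split("-")[1] for label in bio_labels if label != "O"})
print(entity_labels)  # ['Loc', 'Per']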
Example #5
def decode_one_sequence_logits_to_label(
        sequence_logits: torch.Tensor,
        vocabulary: LabelVocabulary) -> Tuple[List[str], List[int]]:
    """
    Decode the output sequence logits. This decodes a single sequence only, not a
    batch of sequences; to decode a batch, loop over its sequences.
    :param sequence_logits: shape: (seq_len, label_num), the valid (already masked)
        sequence logits, not the padded sequence logits.
    :param vocabulary: the label vocabulary
    :return: the sequence labels (a list of B/I/O tags) and the corresponding label index list
    """

    if len(sequence_logits.shape) != 2:
        raise RuntimeError(
            f"sequence_logits shape 是 (seq_len, label_num), 现在是 {sequence_logits.shape}"
        )

    idle_state, span_state = 0, 1

    sequence_length = sequence_logits.size(0)

    state = idle_state

    # Sort label indices by score; sorted_sequence_indices shape: (seq_len, label_num)
    sorted_sequence_indices = torch.argsort(sequence_logits,
                                            dim=-1,
                                            descending=True)

    sequence_label = list()
    sequence_label_indices = list()

    for i in range(sequence_length):

        indices = sorted_sequence_indices[i, :].tolist()

        if state == idle_state:

            # Scan the sorted candidates until a valid label is found
            for index in indices:

                label = vocabulary.token(index)

                if label[0] == "O":

                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = idle_state
                    break
                elif label[0] == "B":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = span_state
                    break
                else:
                    # Any other label ("I-*") is invalid in the idle state, so keep scanning
                    pass

        elif state == span_state:
            for index in indices:

                label = vocabulary.token(index)

                if label[0] == "B":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = span_state
                    break
                elif label[0] == "O":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = idle_state
                    break
                elif label[0] == "I":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = span_state
                    break
                else:
                    raise RuntimeError(f"{label} 不符合 BIO 格式")
        else:
            raise RuntimeError(f"state is error: {state}")

    return sequence_label, sequence_label_indices
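
A usage sketch, assuming a vocabulary whose indices 0, 1, 2 map to "B-T", "I-T", "O". At the second position "I-T" has the highest score, which the state machine accepts only because the first position opened a span:

import torch

vocabulary = LabelVocabulary(labels=[["B-T", "I-T", "O"]],
                             padding=LabelVocabulary.PADDING)

sequence_logits = torch.tensor([[2.0, 0.5, 1.0],   # highest score on "B-T"
                                [0.3, 1.8, 0.9],   # highest score on "I-T", valid after "B-T"
                                [0.1, 0.2, 2.5]])  # highest score on "O"

labels, indices = decode_one_sequence_logits_to_label(
    sequence_logits=sequence_logits, vocabulary=vocabulary)
print(labels)   # expected: ["B-T", "I-T", "O"]
print(indices)  # expected: [0, 1, 2]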