Example No. 1
def test_decode():
    """
    Test decoding the batch logits output by the model
    :return:
    """

    # [[O, B, I], [B, B, I], [B, I, I], [B, I, O]]
    batch_sequence_logits = torch.tensor([[[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                                          [[0.8, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                                          [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.1]],
                                          [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.5]]],
                                         dtype=torch.float)

    expect = [[{"label": "T", "begin": 1, "end": 3}],
              [{"label": "T", "begin": 0, "end": 1}, {"label": "T", "begin": 1, "end": 3}],
              [{"label": "T", "begin": 0, "end": 3}],
              [{"label": "T", "begin": 0, "end": 2}]]

    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)

    b_index = vocabulary.index("B-T")
    ASSERT.assertEqual(0, b_index)
    i_index = vocabulary.index("I-T")
    ASSERT.assertEqual(1, i_index)
    o_index = vocabulary.index("O")
    ASSERT.assertEqual(2, o_index)

    spans = BIO.decode(batch_sequence_logits=batch_sequence_logits,
                       mask=None,
                       vocabulary=vocabulary)

    ASSERT.assertListEqual(expect, spans)
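
As a quick sanity check, taking the per-step argmax of the logits tensor reproduces the label sequences in the comment above (a minimal sketch using only the tensor defined in this test):

# argmax over the last (label) dimension, with B-T=0, I-T=1, O=2:
# tensor([[2, 0, 1],    -> [O,   B-T, I-T]
#         [0, 0, 1],    -> [B-T, B-T, I-T]
#         [0, 1, 1],    -> [B-T, I-T, I-T]
#         [0, 1, 2]])   -> [B-T, I-T, O  ]
print(batch_sequence_logits.argmax(dim=-1))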
Example No. 2
def test_decode_decode_label_index_to_span():
    """
    Test decoding golden label indices
    :return:
    """

    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)

    b_index = vocabulary.index("B-T")
    ASSERT.assertEqual(0, b_index)
    i_index = vocabulary.index("I-T")
    ASSERT.assertEqual(1, i_index)
    o_index = vocabulary.index("O")
    ASSERT.assertEqual(2, o_index)

    golden_labels = torch.tensor([[0, 1, 2, 0],
                                  [2, 0, 1, 1]])

    expect = [[{"label": "T", "begin": 0, "end": 2}, {"label": "T", "begin": 3, "end": 4}],
              [{"label": "T", "begin": 1, "end": 4}]]

    spans = BIO.decode_label_index_to_span(batch_sequence_label_index=golden_labels,
                                           mask=None,
                                           vocabulary=vocabulary)

    ASSERT.assertListEqual(expect, spans)
Example No. 3
def vocabulary(ace_dataset):
    vocab_collate_fn = EventVocabularyCollate()
    data_loader = DataLoader(ace_dataset, collate_fn=vocab_collate_fn)

    event_types_list: List[List[str]] = list()
    tokens_list: List[List[str]] = list()
    entity_tags_list: List[List[str]] = list()

    for collate_dict in data_loader:
        event_types_list.extend(collate_dict["event_types"])
        tokens_list.extend(collate_dict["tokens"])
        entity_tags_list.extend(collate_dict["entity_tags"])

    negative_event_type = "Negative"
    event_type_vocab = Vocabulary(tokens=event_types_list,
                                  unk=negative_event_type,
                                  padding="",
                                  special_first=True)

    word_vocab = Vocabulary(tokens=tokens_list,
                            unk=Vocabulary.UNK,
                            padding=Vocabulary.PADDING,
                            special_first=True)

    entity_tag_vocab = LabelVocabulary(entity_tags_list,
                                       padding=LabelVocabulary.PADDING)
    return {
        "event_type_vocab": event_type_vocab,
        "word_vocab": word_vocab,
        "entity_tag_vocab": entity_tag_vocab
    }
Example No. 4
def vocabulary(
        conll2003_dataset
) -> Dict[str, Union[Vocabulary, PretrainedVocabulary]]:
    data_loader = DataLoader(dataset=conll2003_dataset,
                             batch_size=2,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=VocabularyCollate())

    batch_tokens = list()
    batch_sequence_labels = list()

    for collate_dict in data_loader:
        batch_tokens.extend(collate_dict["tokens"])
        batch_sequence_labels.extend(collate_dict["sequence_labels"])

    token_vocabulary = Vocabulary(tokens=batch_tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                       padding=LabelVocabulary.PADDING)
    return {
        "token_vocabulary": token_vocabulary,
        "label_vocabulary": label_vocabulary
    }
Example No. 5
    def build_vocabulary(self):
        training_dataset_file_path = self.config["training_dataset_file_path"]

        dataset = ACSASemEvalDataset(
            dataset_file_path=training_dataset_file_path)

        collate_fn = VocabularyCollate()
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=10,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=collate_fn)

        tokens = list()
        categories = list()
        labels = list()
        for collate_dict in data_loader:
            tokens.append(collate_dict["tokens"])
            categories.append(collate_dict["categories"])
            labels.append(collate_dict["labels"])

        token_vocabulary = Vocabulary(tokens=tokens,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)

        if not self.config["debug"]:
            pretrained_file_path = self.config["pretrained_file_path"]
            pretrained_loader = GloveLoader(
                embedding_dim=300, pretrained_file_path=pretrained_file_path)
            pretrained_token_vocabulary = PretrainedVocabulary(
                vocabulary=token_vocabulary,
                pretrained_word_embedding_loader=pretrained_loader)

            token_vocabulary = pretrained_token_vocabulary

        category_vocabulary = LabelVocabulary(labels=categories, padding=None)
        label_vocabulary = LabelVocabulary(labels=labels, padding=None)

        return {
            "token_vocabulary": token_vocabulary,
            "category_vocabulary": category_vocabulary,
            "label_vocabulary": label_vocabulary
        }
Example No. 6
def test_sequence_max_label_index_decoder():
    label_vocabulary = LabelVocabulary(
        [["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
        padding=LabelVocabulary.PADDING)

    b_index = label_vocabulary.index("B-T")
    ASSERT.assertEqual(0, b_index)
    i_index = label_vocabulary.index("I-T")
    ASSERT.assertEqual(1, i_index)
    o_index = label_vocabulary.index("O")
    ASSERT.assertEqual(2, o_index)

    # [[O, B, I], [B, B, I], [B, I, I], [B, I, O]]
    batch_sequence_logits = torch.tensor(
        [[[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
         [[0.8, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
         [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.1]],
         [[0.8, 0.3, 0.4], [0.1, 0.7, 0.3], [0.2, 0.3, 0.5]]],
        dtype=torch.float)

    expect_sequence_labels = [["O", "B-T", "I-T"], ["B-T", "B-T", "I-T"],
                              ["B-T", "I-T", "I-T"], ["B-T", "I-T", "O"]]

    expect = list()

    for expect_sequence_label in expect_sequence_labels:
        expect.append(
            [label_vocabulary.index(label) for label in expect_sequence_label])

    decoder = SequenceMaxLabelIndexDecoder(label_vocabulary=label_vocabulary)

    label_indices = decoder(logits=batch_sequence_logits, mask=None)

    ASSERT.assertEqual(expect, label_indices.tolist())
Example No. 7
    def __init__(self):
        bio_labels = [["O", "I-X", "B-X", "I-Y", "B-Y"]]

        self.label_vocabulary = LabelVocabulary(
            labels=bio_labels, padding=LabelVocabulary.PADDING)

        self.logits = torch.tensor([
            [[0, 0, .5, .5, .2], [0, 0, .3, .3, .1], [0, 0, .9, 10, 1]],
            [[0, 0, .2, .5, .2], [0, 0, 3, .3, .1], [0, 0, .9, 1, 1]],
        ],
                                   dtype=torch.float)

        self.tags = torch.tensor([[2, 3, 4], [3, 2, 2]], dtype=torch.long)

        self.transitions = torch.tensor(
            [[0.1, 0.2, 0.3, 0.4, 0.5], [0.8, 0.3, 0.1, 0.7, 0.9],
             [-0.3, 2.1, -5.6, 3.4, 4.0], [0.2, 0.4, 0.6, -0.3, -0.4],
             [1.0, 1.0, 1.0, 1.0, 1.0]],
            dtype=torch.float)

        self.transitions_from_start = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.6],
                                                   dtype=torch.float)
        self.transitions_to_end = torch.tensor([-0.1, -0.2, 0.3, -0.4, -0.4],
                                               dtype=torch.float)

        # Use the CRF Module with fixed transitions to compute the log_likelihood
        self.crf = ConditionalRandomField(5)
        self.crf.transitions = torch.nn.Parameter(self.transitions)
        self.crf.start_transitions = torch.nn.Parameter(
            self.transitions_from_start)
        self.crf.end_transitions = torch.nn.Parameter(self.transitions_to_end)

        # constraint crf
        constraints = {(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3), (3, 3),
                       (3, 4), (4, 4), (4, 0)}

        # Add the transitions to the end tag
        # and from the start tag.
        for i in range(5):
            constraints.add((5, i))
            constraints.add((i, 6))

        constraint_crf = ConditionalRandomField(num_tags=5,
                                                constraints=constraints)
        constraint_crf.transitions = torch.nn.Parameter(self.transitions)
        constraint_crf.start_transitions = torch.nn.Parameter(
            self.transitions_from_start)
        constraint_crf.end_transitions = torch.nn.Parameter(
            self.transitions_to_end)
        self.constraint_crf = constraint_crf
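
In practice the constraint set would typically come from BIO.allowed_transitions rather than being enumerated by hand. A minimal sketch of that wiring (note the hand-written constraints above are arbitrary test values, not BIO constraints, so the resulting CRF differs):

# Build BIO constraints from a matching label vocabulary. START and END
# take indices label_size and label_size + 1, i.e. 5 and 6 here.
label_vocabulary = LabelVocabulary(labels=[["O", "I-X", "B-X", "I-Y", "B-Y"]],
                                   padding=LabelVocabulary.PADDING)
bio_constraints = set(BIO.allowed_transitions(label_vocabulary=label_vocabulary))
bio_constraint_crf = ConditionalRandomField(num_tags=5,
                                            constraints=bio_constraints)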
Example No. 8
    def test_allowed_transitions(self):
        bio_labels = ['O', 'B-X', 'I-X', 'B-Y', 'I-Y']  # START and END will take indices 5 and 6

        label_vocabulary = LabelVocabulary(labels=[bio_labels],
                                           padding=LabelVocabulary.PADDING)
        #              0     1      2      3      4         5          6
        allowed = BIO.allowed_transitions(label_vocabulary=label_vocabulary)

        # Pairs absent from this set are disallowed transitions.
        assert set(allowed) == {  # Extra column for end tag.
            (0, 0), (0, 1), (0, 3), (0, 6),
            (1, 0), (1, 1), (1, 2), (1, 3), (1, 6),
            (2, 0), (2, 1), (2, 2), (2, 3), (2, 6),
            (3, 0), (3, 1), (3, 3), (3, 4), (3, 6),
            (4, 0), (4, 1), (4, 3), (4, 4), (4, 6),
            (5, 0), (5, 1), (5, 3)  # Extra row for start tag
        }
Example No. 9
def decode_label_index_to_span(
        batch_sequence_label_index: torch.Tensor, mask: torch.ByteTensor,
        vocabulary: LabelVocabulary) -> List[List[Dict]]:
    """
    Decode label indices into spans.

    batch_sequence_label_index shape: (B, seq_len)  (B-T: 0, I-T: 1, O: 2)
    [[0, 1, 2],
     [2, 0, 1]]

     The corresponding label sequences are:
     [[B, I, O],
      [O, B, I]]

     which decode into:

     [[{"label": T, "begin": 0, "end": 2}],
      [{"label": T, "begin": 1, "end": 3}]]

    :param batch_sequence_label_index: shape: (B, seq_len), the label index sequences
    :param mask: mask for batch_sequence_label_index
    :param vocabulary: the label vocabulary
    :return: the list of decoded spans
    """

    spans = list()
    batch_size = batch_sequence_label_index.size(0)

    if mask is None:
        mask = torch.ones(size=(batch_sequence_label_index.shape[0],
                                batch_sequence_label_index.shape[1]),
                          dtype=torch.long)

    sequence_lengths = mask.sum(dim=-1)

    for i in range(batch_size):
        label_indices = batch_sequence_label_index[
            i, :sequence_lengths[i]].tolist()

        sequence_label = [vocabulary.token(index) for index in label_indices]

        span = decode_one_sequence_label_to_span(sequence_label=sequence_label)
        spans.append(span)

    return spans
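
decode_one_sequence_label_to_span itself is not shown in this listing; below is a minimal sketch of the behavior the docstring and tests imply (spans are half-open [begin, end), and the entity label is the text after the "B-"/"I-" prefix). This is an inference from the examples, not the library's actual implementation:

def decode_one_sequence_label_to_span(sequence_label: List[str]) -> List[Dict]:
    spans = list()
    begin, label = None, None
    for i, bio_label in enumerate(sequence_label):
        if bio_label.startswith("B-"):
            if begin is not None:  # close any open span first
                spans.append({"label": label, "begin": begin, "end": i})
            begin, label = i, bio_label[2:]
        elif bio_label.startswith("I-") and label == bio_label[2:]:
            continue  # extend the open span
        else:  # "O", or an "I-" that does not continue the open span
            if begin is not None:
                spans.append({"label": label, "begin": begin, "end": i})
            begin, label = None, None
    if begin is not None:  # close a span that runs to the very end
        spans.append({"label": label, "begin": begin,
                      "end": len(sequence_label)})
    return spans

On the docstring's example, ["B-T", "I-T", "O"] yields [{"label": "T", "begin": 0, "end": 2}], matching the spans above.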
Example No. 10
def allowed_transitions(
        label_vocabulary: LabelVocabulary) -> List[Tuple[int, int]]:
    """
    Given a label vocabulary, compute the transition mask under the BIO schema.
    Under BIO, transitions such as ["O", "I-Per"] or ["B-Per", "I-Loc"] are not
    allowed. This function computes that mask, returning a pair
    (from_label_id, to_label_id) for every allowed transition.

    Note in particular that the returned allowed pairs include START and STOP,
    whose indices are label_vocabulary.label_size and
    label_vocabulary.label_size + 1 respectively.

    Also note that label_vocabulary.padding_index equals
    label_vocabulary.label_size. Although the two coincide, START is only used
    in the specific context of the transition matrix, so no conflict arises.

    :param label_vocabulary: the label vocabulary
    :return: the list of all allowed (from_label_id, to_label_id) transition pairs.
    """

    num_labels = label_vocabulary.label_size
    start_index = num_labels
    end_index = num_labels + 1

    labels = [(i, label_vocabulary.token(i))
              for i in range(label_vocabulary.label_size)]
    labels_with_boundaries = labels + [(start_index, "START"),
                                       (end_index, "END")]

    allowed = []
    for from_label_index, from_label in labels_with_boundaries:
        if from_label in ("START", "END"):
            from_tag = from_label
            from_entity = ""
        else:
            from_tag = from_label[0]
            from_entity = from_label[1:]
        for to_label_index, to_label in labels_with_boundaries:
            if to_label in ("START", "END"):
                to_tag = to_label
                to_entity = ""
            else:
                to_tag = to_label[0]
                to_entity = to_label[1:]
            if _is_transition_allowed(from_tag, from_entity, to_tag,
                                      to_entity):
                allowed.append((from_label_index, to_label_index))
    return allowed
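
_is_transition_allowed is referenced but not shown; a minimal sketch of the BIO rules it has to encode to produce the pairs verified by the test_allowed_transitions tests in this listing (again an inference, not the library's actual code):

def _is_transition_allowed(from_tag: str, from_entity: str,
                           to_tag: str, to_entity: str) -> bool:
    if to_tag == "START" or from_tag == "END":
        return False  # nothing precedes START, nothing follows END
    if from_tag == "START":
        return to_tag in ("O", "B")  # a sequence may not open with I-
    if to_tag == "END":
        return from_tag in ("O", "B", "I")  # any real tag may end a sequence
    if to_tag == "I":
        # I-X may only continue a span opened by B-X or I-X
        return from_tag in ("B", "I") and from_entity == to_entity
    return to_tag in ("O", "B")  # O and B- may follow any real tag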
Example No. 11
def test_allowed_transitions():
    """
    Test the allowed-transition mask pairs
    :return:
    """

    label_vocabulary = LabelVocabulary(labels=[["B-L1", "I-L1", "B-L2", "I-L2", "O"]],
                                       padding=LabelVocabulary.PADDING)

    allowed_pairs = BIO.allowed_transitions(label_vocabulary=label_vocabulary)

    for from_idx, to_idx in allowed_pairs:

        if from_idx == label_vocabulary.label_size:
            from_label = "START"
        else:
            from_label = label_vocabulary.token(from_idx)

        if to_idx == label_vocabulary.label_size + 1:
            to_label = "STOP"
        else:
            to_label = label_vocabulary.token(to_idx)
        print(f"(\"{from_label}\", \"{to_label}\"),")

    expect_transition_labels = [
        ("B-L1", "B-L1"), ("B-L1", "I-L1"), ("B-L1", "B-L2"), ("B-L1", "O"), ("B-L1", "STOP"),
        ("I-L1", "B-L1"), ("I-L1", "I-L1"), ("I-L1", "B-L2"), ("I-L1", "O"), ("I-L1", "STOP"),
        ("B-L2", "B-L1"), ("B-L2", "B-L2"), ("B-L2", "I-L2"), ("B-L2", "O"), ("B-L2", "STOP"),
        ("I-L2", "B-L1"), ("I-L2", "B-L2"), ("I-L2", "I-L2"), ("I-L2", "O"), ("I-L2", "STOP"),
        ("O", "B-L1"), ("O", "B-L2"), ("O", "O"), ("O", "STOP"),
        ("START", "B-L1"), ("START", "B-L2"), ("START", "O")]


    expect = list()

    for from_label, to_label in expect_transition_labels:
        if from_label == "START":
            from_idx = label_vocabulary.label_size
        else:
            from_idx = label_vocabulary.index(from_label)

        if to_label == "STOP":
            to_idx = label_vocabulary.label_size + 1
        else:
            to_idx = label_vocabulary.index(to_label)

        expect.append((from_idx, to_idx))

    ASSERT.assertSetEqual(set(expect), set(allowed_pairs))
Example No. 12
    def __init__(self, label_vocabulary: LabelVocabulary) -> None:
        """
        Initialization
        :param label_vocabulary: the label vocabulary
        """
        labels = set()

        # Extract Label from B-Label and I-Label
        for index in range(label_vocabulary.label_size):
            bio_label: str = label_vocabulary.token(index)

            if bio_label == "O":
                continue

            label = bio_label.split("-")[1]
            labels.add(label)

        labels = list(labels)

        super().__init__(labels=labels)
        self.label_vocabulary = label_vocabulary
Example No. 13
def test_label_vocabulary():
    """
    Test the label vocabulary
    :return:
    """
    vocabulary = LabelVocabulary([["A", "B", "C"], ["D", "E"]], padding="")
    ASSERT.assertEqual(vocabulary.size, 5)

    vocabulary = LabelVocabulary([["A", "B", "C"], ["D", "E"]],
                                 padding=LabelVocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.size, 6)
    ASSERT.assertEqual(vocabulary.label_size, 5)

    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 5)

    for index, w in enumerate(["A", "B", "C", "D", "E"]):
        ASSERT.assertEqual(vocabulary.index(w), index)
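
The contract this test pins down, as a minimal sketch of the observed behavior (an illustration, not the library's implementation): labels are indexed in the order they are first seen across the input lists, and the padding label, when given, is appended last, so label_size excludes it while size includes it. The "[PADDING]" string below is a hypothetical stand-in for LabelVocabulary.PADDING, whose actual value is not shown here:

label_lists = [["A", "B", "C"], ["D", "E"]]
label_to_index = {}
for label_list in label_lists:
    for label in label_list:              # first-seen order: A=0, ..., E=4
        label_to_index.setdefault(label, len(label_to_index))
label_size = len(label_to_index)          # 5, padding excluded
label_to_index["[PADDING]"] = label_size  # padding gets the last index, 5
size = len(label_to_index)                # 6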
Example No. 14
def test_decode_one_sequence_logits_to_label():
    """
    Test decoding a sequence label
    :return:
    """

    sequence_logits_list = list()
    expect_list = list()

    sequence_logits = torch.tensor([[0.2, 0.3, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                                   dtype=torch.float)  # O B I, the normal case
    expect = ["O", "B-T", "I-T"]
    sequence_logits_list.append(sequence_logits)
    expect_list.append(expect)

    sequence_logits = torch.tensor([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.1]],
                                   dtype=torch.float)
    expect = ["B-T", "I-T", "I-T"]

    sequence_logits_list.append(sequence_logits)
    expect_list.append(expect)

    sequence_logits = torch.tensor([[0.9, 0.3, 0.4], [0.2, 0.8, 0.3], [0.2, 0.3, 0.9]],
                                   dtype=torch.float)
    expect = ["B-T", "I-T", "O"]
    sequence_logits_list.append(sequence_logits)
    expect_list.append(expect)

    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)

    b_index = vocabulary.index("B-T")
    ASSERT.assertEqual(0, b_index)
    i_index = vocabulary.index("I-T")
    ASSERT.assertEqual(1, i_index)
    o_index = vocabulary.index("O")
    ASSERT.assertEqual(2, o_index)

    for sequence_logits, expect in zip(sequence_logits_list, expect_list):
        sequence_label, sequence_label_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=sequence_logits,
            vocabulary=vocabulary)

        ASSERT.assertListEqual(sequence_label, expect)
        expect_indices = [vocabulary.index(label) for label in expect]
        ASSERT.assertListEqual(sequence_label_indices, expect_indices)
Example No. 15
    def build_vocabulary(self, dataset: Dataset):

        data_loader = DataLoader(dataset=dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=VocabularyCollate())
        batch_tokens = list()
        batch_sequence_labels = list()

        for collate_dict in data_loader:
            batch_tokens.extend(collate_dict["tokens"])
            batch_sequence_labels.extend(collate_dict["sequence_labels"])

        token_vocabulary = Vocabulary(tokens=batch_tokens,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)
        model_name = self.config["model_name"]

        if model_name in {ModelName.NER_V2, ModelName.NER_V3}:
            pretrained_word_embedding_file_path = self.config[
                "pretrained_word_embedding_file_path"]
            glove_loader = GloveLoader(
                embedding_dim=100,
                pretrained_file_path=pretrained_word_embedding_file_path)

            token_vocabulary = PretrainedVocabulary(
                vocabulary=token_vocabulary,
                pretrained_word_embedding_loader=glove_loader)

        label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                           padding=LabelVocabulary.PADDING)

        return {
            "token_vocabulary": token_vocabulary,
            "label_vocabulary": label_vocabulary
        }
Example No. 16
def test_decode_one_sequence_logits_to_label_abnormal():
    """
    Test abnormal cases
    :return:
    """

    # At step 0, [0.2, 0.5, 0.4] argmax-decodes to I, which is invalid at the start
    # of a sequence; the raw argmax sequence is I B I.
    # decode_one_sequence_logits_to_label instead emits the label with probability
    # 0.4, i.e. O, repairing the error
    sequence_logits = torch.tensor([[0.2, 0.5, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                                   dtype=torch.float)

    vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                                 padding=LabelVocabulary.PADDING)

    b_index = vocabulary.index("B-T")
    ASSERT.assertEqual(0, b_index)
    i_index = vocabulary.index("I-T")
    ASSERT.assertEqual(1, i_index)
    o_index = vocabulary.index("O")
    ASSERT.assertEqual(2, o_index)

    sequence_label, sequence_label_indices = BIO.decode_one_sequence_logits_to_label(
        sequence_logits=sequence_logits, vocabulary=vocabulary)

    expect = ["O", "B-T", "I-T"]
    expect_indices = [vocabulary.index(label) for label in expect]
    ASSERT.assertListEqual(expect, sequence_label)
    ASSERT.assertListEqual(expect_indices, sequence_label_indices)

    # argmax decodes to I I I; after repair this becomes O O B
    sequence_logits = torch.tensor([[0.2, 0.5, 0.4], [0.2, 0.9, 0.3], [0.2, 0.3, 0.1]],
                                   dtype=torch.float)

    sequence_label, sequence_label_indices = BIO.decode_one_sequence_logits_to_label(
        sequence_logits=sequence_logits, vocabulary=vocabulary)
    expect = ["O", "O", "B-T"]
    expect_indices = [vocabulary.index(label) for label in expect]
    ASSERT.assertListEqual(expect, sequence_label)
    ASSERT.assertListEqual(expect_indices, sequence_label_indices)
Example No. 17
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset,
                                  gaz_pretrained_embedding_loader):
    """
    Test the bilstm gat model collate
    :return:
    """
    # Take only the first two instances for this test
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    bilstm_gat_model_collate = BiLstmGATModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = bilstm_gat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    t_graph_0 = model_inputs.model_inputs["t_graph"][0]
    c_graph_0 = model_inputs.model_inputs["c_graph"][0]
    l_graph_0 = model_inputs.model_inputs["l_graph"][0]

    expect_t_graph_tensor = torch.tensor(expect_t_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_t_graph_tensor, t_graph_0))

    expect_c_graph_tensor = torch.tensor(expect_c_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_c_graph_tensor, c_graph_0))

    expect_l_graph_tensor = torch.tensor(expect_l_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_l_graph_tensor, l_graph_0))

    gaz_words_indices = model_inputs.model_inputs["gaz_words"]

    ASSERT.assertEqual((2, 11), gaz_words_indices.size())

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # Sentence: 陈元呼吁加强国际合作推动世界经济发展 ("Chen Yuan calls for strengthening international cooperation to promote world economic development")
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    # the metadata key keeps the collate's own spelling
    squeeze_gaz_words_0 = metadata_0["sequeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_squeeze_gaz_words_indices_0 = torch.tensor(
        [gaz_vocabulary.index(word) for word in expect_squeeze_gaz_words_0],
        dtype=torch.long)

    ASSERT.assertTrue(
        is_tensor_equal(expect_squeeze_gaz_words_indices_0,
                        gaz_words_indices[0]))
Example No. 18
    def __init__(self,
                 is_training: bool,
                 dataset: Dataset,
                 vocabulary_collate,
                 token_vocabulary_dir: str,
                 label_vocabulary_dir: str,
                 is_build_token_vocabulary: bool,
                 pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader):
        """
        Vocabulary builder
        :param is_training: vocabulary construction differs between training and
                            non-training runs; for training it generally has to be
                            rebuilt, while non-training runs reuse the previously
                            built vocabulary.
        :param dataset: the dataset
        :param vocabulary_collate: the vocabulary collate
        :param token_vocabulary_dir: directory where the token vocabulary is stored
        :param label_vocabulary_dir: directory where the label vocabulary is stored
        :param is_build_token_vocabulary: whether to build the token vocabulary; when
                                          Bert or another pretrained model provides
                                          the embeddings, none is needed.
        :param pretrained_word_embedding_loader: the pretrained word-embedding loader
        """

        super().__init__(is_training=is_training)

        token_vocabulary = None
        label_vocabulary = None

        if is_training:
            data_loader = DataLoader(dataset=dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     num_workers=0,
                                     collate_fn=vocabulary_collate)
            batch_tokens = list()
            batch_sequence_labels = list()

            for collate_dict in data_loader:
                batch_tokens.extend(collate_dict["tokens"])
                batch_sequence_labels.extend(collate_dict["sequence_labels"])

            if is_build_token_vocabulary:
                token_vocabulary = Vocabulary(tokens=batch_tokens,
                                              padding=Vocabulary.PADDING,
                                              unk=Vocabulary.UNK,
                                              special_first=True)

                if pretrained_word_embedding_loader is not None:
                    token_vocabulary = \
                        PretrainedVocabulary(vocabulary=token_vocabulary,
                                             pretrained_word_embedding_loader=pretrained_word_embedding_loader)

                if token_vocabulary_dir:
                    token_vocabulary.save_to_file(token_vocabulary_dir)

            label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                               padding=LabelVocabulary.PADDING)

            if label_vocabulary_dir:
                label_vocabulary.save_to_file(label_vocabulary_dir)
        else:
            if is_build_token_vocabulary and token_vocabulary_dir:
                token_vocabulary = Vocabulary.from_file(token_vocabulary_dir)

            if label_vocabulary_dir:
                label_vocabulary = LabelVocabulary.from_file(label_vocabulary_dir)

        self.token_vocabulary = token_vocabulary
        self.label_vocabulary = label_vocabulary
Example No. 19
def decode_one_sequence_logits_to_label(
        sequence_logits: torch.Tensor,
        vocabulary: LabelVocabulary) -> Tuple[List[str], List[int]]:
    """
    Decode the output sequence logits. This decodes a single sequence only, not a
    batch of sequences; decoding a batch requires looping over it.
    :param sequence_logits: shape: (seq_len, label_num),
    the valid (already masked) sequence logits, not logits that still include padding.
    :param vocabulary: the label vocabulary
    :return: the sequence label as a list of B, I, O labels, together with the list
    of corresponding label indices
    """

    if len(sequence_logits.shape) != 2:
        raise RuntimeError(
            f"sequence_logits shape 是 (seq_len, label_num), 现在是 {sequence_logits.shape}"
        )

    idle_state, span_state = 0, 1

    sequence_length = sequence_logits.size(0)

    state = idle_state

    # sort labels by score; indices shape: (seq_len, label_num)
    sorted_sequence_indices = torch.argsort(sequence_logits,
                                            dim=-1,
                                            descending=True)

    sequence_label = list()
    sequence_label_indices = list()

    for i in range(sequence_length):

        indices = sorted_sequence_indices[i, :].tolist()

        if state == idle_state:

            # search until a valid label is found
            for index in indices:

                label = vocabulary.token(index)

                if label[0] == "O":

                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = idle_state
                    break
                elif label[0] == "B":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = span_state
                    break
                else:
                    # any other label ("I-") is invalid in the idle state,
                    # so keep searching for a valid one
                    pass

        elif state == span_state:
            for index in indices:

                label = vocabulary.token(index)

                if label[0] == "B":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = span_state
                    break
                elif label[0] == "O":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = idle_state
                    break
                elif label[0] == "I":
                    sequence_label.append(label)
                    sequence_label_indices.append(index)
                    state = span_state
                    break
                else:
                    raise RuntimeError(f"{label} 不符合 BIO 格式")
        else:
            raise RuntimeError(f"state is error: {state}")

    return sequence_label, sequence_label_indices
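
A short usage sketch of the state machine above, reusing the abnormal case from Example No. 16 (same vocabulary, so B-T=0, I-T=1, O=2):

vocabulary = LabelVocabulary([["B-T", "B-T", "B-T", "I-T", "I-T", "O"]],
                             padding=LabelVocabulary.PADDING)
# plain argmax would yield I-T, B-T, I-T: invalid, since a sequence cannot
# open with I-T
logits = torch.tensor([[0.2, 0.5, 0.4], [0.7, 0.2, 0.3], [0.2, 0.3, 0.1]],
                      dtype=torch.float)
labels, indices = decode_one_sequence_logits_to_label(sequence_logits=logits,
                                                      vocabulary=vocabulary)
print(labels)   # ['O', 'B-T', 'I-T']: the leading I-T is repaired to O
print(indices)  # [2, 0, 1]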
Example No. 20
    def __call__(self, config: Dict, train_type: int):
        serialize_dir = config["serialize_dir"]
        vocabulary_dir = config["vocabulary_dir"]
        pretrained_embedding_file_path = config["pretrained_embedding_file_path"]
        word_embedding_dim = config["word_embedding_dim"]
        pretrained_embedding_max_size = config["pretrained_embedding_max_size"]
        is_fine_tuning = config["fine_tuning"]

        word_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "word_vocabulary")
        event_type_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "event_type_vocabulary")
        entity_tag_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "entity_tag_vocabulary")

        if train_type == Train.NEW_TRAIN:

            if os.path.isdir(serialize_dir):
                shutil.rmtree(serialize_dir)

            os.makedirs(serialize_dir)

            if os.path.isdir(vocabulary_dir):
                shutil.rmtree(vocabulary_dir)
            os.makedirs(vocabulary_dir)
            os.makedirs(word_vocab_dir)
            os.makedirs(event_type_vocab_dir)
            os.makedirs(entity_tag_vocab_dir)

        elif train_type == Train.RECOVERY_TRAIN:
            pass
        else:
            assert False, f"train_type: {train_type} error!"

        train_dataset_file_path = config["train_dataset_file_path"]
        validation_dataset_file_path = config["validation_dataset_file_path"]

        num_epoch = config["epoch"]
        batch_size = config["batch_size"]

        if train_type == Train.NEW_TRAIN:
            # Build the vocabularies
            ace_dataset = ACEDataset(train_dataset_file_path)
            vocab_data_loader = DataLoader(dataset=ace_dataset,
                                           batch_size=10,
                                           shuffle=False, num_workers=0,
                                           collate_fn=EventVocabularyCollate())

            tokens: List[List[str]] = list()
            event_types: List[List[str]] = list()
            entity_tags: List[List[str]] = list()

            for collate_dict in vocab_data_loader:
                tokens.extend(collate_dict["tokens"])
                event_types.extend(collate_dict["event_types"])
                entity_tags.extend(collate_dict["entity_tags"])

            word_vocabulary = Vocabulary(tokens=tokens,
                                         padding=Vocabulary.PADDING,
                                         unk=Vocabulary.UNK,
                                         special_first=True)

            glove_loader = GloveLoader(embedding_dim=word_embedding_dim,
                                       pretrained_file_path=pretrained_embedding_file_path,
                                       max_size=pretrained_embedding_max_size)

            pretrained_word_vocabulary = PretrainedVocabulary(vocabulary=word_vocabulary,
                                                              pretrained_word_embedding_loader=glove_loader)

            pretrained_word_vocabulary.save_to_file(word_vocab_dir)

            event_type_vocabulary = Vocabulary(tokens=event_types,
                                               padding="",
                                               unk="Negative",
                                               special_first=True)
            event_type_vocabulary.save_to_file(event_type_vocab_dir)

            entity_tag_vocabulary = LabelVocabulary(labels=entity_tags,
                                                    padding=LabelVocabulary.PADDING)
            entity_tag_vocabulary.save_to_file(entity_tag_vocab_dir)
        else:
            pretrained_word_vocabulary = PretrainedVocabulary.from_file(word_vocab_dir)
            event_type_vocabulary = Vocabulary.from_file(event_type_vocab_dir)
            entity_tag_vocabulary = LabelVocabulary.from_file(entity_tag_vocab_dir)

        model = EventModel(alpha=0.5,
                           activate_score=True,
                           sentence_vocab=pretrained_word_vocabulary,
                           sentence_embedding_dim=word_embedding_dim,
                           entity_tag_vocab=entity_tag_vocabulary,
                           entity_tag_embedding_dim=50,
                           event_type_vocab=event_type_vocabulary,
                           event_type_embedding_dim=300,
                           lstm_hidden_size=300,
                           lstm_encoder_num_layer=1,
                           lstm_encoder_droupout=0.4)

        trainer = Trainer(
            serialize_dir=serialize_dir,
            num_epoch=num_epoch,
            model=model,
            loss=EventLoss(),
            optimizer_factory=EventOptimizerFactory(is_fine_tuning=is_fine_tuning),
            metrics=EventF1MetricAdapter(event_type_vocabulary=event_type_vocabulary),
            patient=10,
            num_check_point_keep=5,
            devices=None
        )

        train_dataset = EventDataset(dataset_file_path=train_dataset_file_path,
                                     event_type_vocabulary=event_type_vocabulary)
        validation_dataset = EventDataset(dataset_file_path=validation_dataset_file_path,
                                          event_type_vocabulary=event_type_vocabulary)

        event_collate = EventCollate(word_vocabulary=pretrained_word_vocabulary,
                                     event_type_vocabulary=event_type_vocabulary,
                                     entity_tag_vocabulary=entity_tag_vocabulary,
                                     sentence_max_len=512)
        train_data_loader = DataLoader(dataset=train_dataset,
                                       batch_size=batch_size,
                                       num_workers=0,
                                       collate_fn=event_collate)

        validation_data_loader = DataLoader(dataset=validation_dataset,
                                            batch_size=batch_size,
                                            num_workers=0,
                                            collate_fn=event_collate)

        if train_type == Train.NEW_TRAIN:
            trainer.train(train_data_loader=train_data_loader,
                          validation_data_loader=validation_data_loader)
        else:
            trainer.recovery_train(train_data_loader=train_data_loader,
                                   validation_data_loader=validation_data_loader)
Example No. 21
def test_flat_model_collate(lattice_ner_demo_dataset,
                            character_pretrained_embedding_loader,
                            gaz_pretrained_embedding_loader):
    """
    Test the flat model collate
    :return:
    """
    # Take only the first two instances for this test
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    characters = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    character_vocabulary = Vocabulary(tokens=characters,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)
    character_vocabulary = PretrainedVocabulary(
        vocabulary=character_vocabulary,
        pretrained_word_embedding_loader=character_pretrained_embedding_loader)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    flat_vocabulary = FlatPretrainedVocabulary(
        character_pretrained_vocabulary=character_vocabulary,
        gaz_word_pretrained_vocabulary=gaz_vocabulary)

    flat_model_collate = FLATModelCollate(token_vocabulary=flat_vocabulary,
                                          gazetter=gazetter,
                                          label_vocabulary=label_vocabulary)

    model_inputs = flat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    sentence = "陈元呼吁加强国际合作推动世界经济发展"

    # Sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    squeeze_gaz_words_0 = metadata_0["squeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_tokens = [character
                     for character in sentence] + expect_squeeze_gaz_words_0

    tokens = metadata_0["tokens"]

    ASSERT.assertListEqual(expect_tokens, tokens)

    character_pos_begin = list(range(len(sentence)))
    character_pos_end = list(range(len(sentence)))

    squeeze_gaz_words_begin = list()
    squeeze_gaz_words_end = list()

    for squeeze_gaz_word in squeeze_gaz_words_0:
        index = sentence.find(squeeze_gaz_word)

        squeeze_gaz_words_begin.append(index)
        squeeze_gaz_words_end.append(index + len(squeeze_gaz_word) - 1)

    pos_begin = model_inputs.model_inputs["pos_begin"][0]
    pos_end = model_inputs.model_inputs["pos_end"][0]

    expect_pos_begin = character_pos_begin + squeeze_gaz_words_begin
    expect_pos_begin += [0] * (pos_begin.size(0) - len(expect_pos_begin))
    expect_pos_begin = torch.tensor(expect_pos_begin)

    expect_pos_end = character_pos_end + squeeze_gaz_words_end
    expect_pos_end += [0] * (pos_end.size(0) - len(expect_pos_end))
    expect_pos_end = torch.tensor(expect_pos_end)

    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_begin, pos_begin))
    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_end, pos_end))

    expect_character_length = len(sentence)
    expect_squeeze_gaz_word_length = len(expect_squeeze_gaz_words_0)

    character_length = model_inputs.model_inputs["sequence_length"][0]
    squeeze_word_length = model_inputs.model_inputs["squeeze_gaz_word_length"][
        0]

    ASSERT.assertEqual(expect_character_length, character_length.item())
    ASSERT.assertEqual(expect_squeeze_gaz_word_length,
                       squeeze_word_length.item())
Example No. 22
def test_sequence_label_decoder():
    """
    Test the sequence label decoder
    :return:
    """
    sequence_label_list = list()
    expect_spans = list()

    sequence_label = ["B-T", "I-T", "O-T"]
    expect = [{"label": "T", "begin": 0, "end": 2}]

    sequence_label_list.append(sequence_label)
    expect_spans.append(expect)

    sequence_label = ["B-T", "I-T", "I-T"]
    expect = [{"label": "T", "begin": 0, "end": 3}]

    sequence_label_list.append(sequence_label)
    expect_spans.append(expect)

    sequence_label = ["B-T", "I-T", "I-T", "B-T"]
    expect = [{
        "label": "T",
        "begin": 0,
        "end": 3
    }, {
        "label": "T",
        "begin": 3,
        "end": 4
    }]

    sequence_label_list.append(sequence_label)
    expect_spans.append(expect)

    label_vocabulary = LabelVocabulary(sequence_label_list,
                                       padding=LabelVocabulary.PADDING)

    sequence_label_indices = list()
    mask_list = list()

    max_sequence_len = 4
    for sequence_labels in sequence_label_list:
        sequence_label_index = [
            label_vocabulary.index(label) for label in sequence_labels
        ]

        mask = [1] * len(sequence_label_index) + [0] * (
            max_sequence_len - len(sequence_label_index))

        sequence_label_index.extend(
            [label_vocabulary.padding_index] *
            (max_sequence_len - len(sequence_label_index)))

        sequence_label_indices.append(sequence_label_index)
        mask_list.append(mask)

    sequence_label_indices = torch.tensor(sequence_label_indices,
                                          dtype=torch.long)
    mask = torch.tensor(mask_list, dtype=torch.uint8)

    decoder = SequenceLabelDecoder(label_vocabulary=label_vocabulary)

    spans = decoder(label_indices=sequence_label_indices, mask=mask)

    ASSERT.assertListEqual(expect_spans, spans)
Example No. 23
def test_gaz_model_collate(lattice_ner_demo_dataset,
                           gaz_pretrained_embedding_loader):
    # Take only the first two instances for this test
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    lattice_model_collate = LatticeModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = lattice_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # Sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_gaz_words_0 = [["陈元"], [], ["呼吁"], ["吁加"], ["加强"], ["强国"], ["国际"],
                          [], ["合作"], [], ["推动"], [], ["世界"], [], ["经济"], [],
                          ["发展"], []]

    gaz_words_0 = metadata_0["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_words_0, gaz_words_0)

    gaz_list_0 = model_inputs.model_inputs["gaz_list"][0]

    expect_gaz_list_0 = list()

    for expect_gaz_word in expect_gaz_words_0:

        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]

            expect_gaz_list_0.append([indices, lengths])

        else:
            expect_gaz_list_0.append([])

    logging.debug(
        f"expect_gaz_list_0: {json2str(expect_gaz_list_0)}\n gaz_list_0:{json2str(gaz_list_0)}"
    )
    ASSERT.assertListEqual(expect_gaz_list_0, gaz_list_0)

    batch_tokens = model_inputs.model_inputs["tokens"]
    ASSERT.assertEqual((2, 19), batch_tokens.size())
    batch_sequence_labels = model_inputs.labels
    ASSERT.assertEqual((2, 19), batch_sequence_labels.size())

    # Sentence: 新华社华盛顿4月28日电(记者翟景升) ("Xinhua News Agency, Washington, April 28 (by reporter Zhai Jingsheng)")
    expect_gaz_word_1 = [
        ["新华社", "新华"],  # 新
        ["华社"],  # 华
        ["社华"],  # 社
        ["华盛顿", "华盛"],  # 华
        ["盛顿"],  # 盛
        [],  # 顿
        [],  # 4
        [],  # 月
        [],  # 2
        [],  # 8
        [],  # 日
        [],  # 电
        [],  # (
        ["记者"],  # 记
        [],  # 者
        ["翟景升", "翟景"],  # 翟
        ["景升"],  # 景
        [],  # 升
        []
    ]  # )

    metadata_1 = model_inputs.model_inputs["metadata"][1]
    gaz_words_1 = metadata_1["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_word_1, gaz_words_1)

    expect_gaz_list_1 = list()

    for expect_gaz_word in expect_gaz_word_1:

        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]

            expect_gaz_list_1.append([indices, lengths])

        else:
            expect_gaz_list_1.append([])

    gaz_list_1 = model_inputs.model_inputs["gaz_list"][1]
    ASSERT.assertListEqual(expect_gaz_list_1, gaz_list_1)