Example #1
def test_vocabulary_collate(vocabulary):
    """
    Test the vocabulary collate.
    :param vocabulary: the vocabulary fixture from conftest.py
    :return: None
    """

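    # The commented-out block below documents how the `vocabulary` fixture is
    # presumably built in conftest.py: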
    # data_loader = DataLoader(dataset=conll2003_dataset,
    #                          batch_size=2,
    #                          shuffle=False,
    #                          num_workers=0,
    #                          collate_fn=VocabularyCollate())
    #
    # batch_tokens = list()
    # batch_sequence_labels = list()
    #
    # for collate_dict in data_loader:
    #     batch_tokens.extend(collate_dict["tokens"])
    #     batch_sequence_labels.extend(collate_dict["sequence_labels"])
    #
    # word_vocabulary = Vocabulary(tokens=batch_tokens,
    #                         padding=Vocabulary.PADDING,
    #                         unk=Vocabulary.UNK,
    #                         special_first=True)

    word_vocabulary = vocabulary["token_vocabulary"]
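    # 13 + 2: presumably 13 distinct tokens in the dataset plus the two
    # special tokens (Vocabulary.PADDING and Vocabulary.UNK) added because
    # special_first=True.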
    ASSERT.assertEqual(13 + 2, word_vocabulary.size)

    label_vocabulary = vocabulary["label_vocabulary"]
    ASSERT.assertEqual(3, label_vocabulary.label_size)
Example #2
def test_gaz_vocabulary_collate(lattice_ner_demo_dataset,
                                gaz_pretrained_embedding_loader):

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    words_list = gaz_vocabulary_collate(lattice_ner_demo_dataset)

    logging.info(json2str(words_list))

    # The corresponding sentence is "陈元呼吁加强国际合作推动世界经济发展"; the gaz words obtained are:
    expect_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]
    gaz_words_0 = words_list[0]
    ASSERT.assertListEqual(expect_0, gaz_words_0)

    # For the sentence "新华社华盛顿4月28日电(记者翟景升)", the gaz words are:
    expect_1 = [
        "新华社", "新华", "华社", "社华", "华盛顿", "华盛", "盛顿", "记者", "翟景升", "翟景", "景升"
    ]
    gaz_words_1 = words_list[1]
    ASSERT.assertListEqual(expect_1, gaz_words_1)
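
The overlapping pairs above suggest that GazVocabularyCollate emits, for each
start position in the sentence, every lexicon word beginning there, longest
first. A minimal sketch of that idea (hypothetical helper, not the project's
Gazetteer API; the lexicon contents are assumed):

def enumerate_gaz_words(sentence, lexicon):
    """For each start position, collect the lexicon words beginning there, longest first."""
    words = []
    for begin in range(len(sentence)):
        # candidates run from the longest suffix down to two characters
        for end in range(len(sentence), begin + 1, -1):
            candidate = sentence[begin:end]
            if candidate in lexicon:
                words.append(candidate)
    return words

lexicon = {"陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动",
           "世界", "经济", "发展"}
assert enumerate_gaz_words("陈元呼吁加强国际合作推动世界经济发展", lexicon) == [
    "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
]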
Example #3
def test_vocabulary_collate(vocabulary):
    """
    Test the vocabulary collate.
    :param vocabulary: the vocabulary fixture from conftest.py
    :return: None
    """

    word_vocabulary = vocabulary["token_vocabulary"]
    ASSERT.assertEqual(13 + 2, word_vocabulary.size)

    label_vocabulary = vocabulary["label_vocabulary"]
    ASSERT.assertEqual(3, label_vocabulary.label_size)
Example #4
def test_lattice_ner_demo_dataset(lattice_ner_demo_dataset):
    """
    Test the lattice NER demo dataset.
    :param lattice_ner_demo_dataset: the lattice NER demo dataset
    :return: None
    """

    instance = lattice_ner_demo_dataset[0]

    tokens: List[Token] = instance["tokens"]

    sentence = "".join([t.text for t in tokens])
    expect_sentence = "陈元呼吁加强国际合作推动世界经济发展"
    expect_sequence_label = ["B-PER", "I-PER"] + ["O"] * 16

    ASSERT.assertEqual(expect_sentence, sentence)
    ASSERT.assertListEqual(expect_sequence_label, instance["sequence_label"])
Example #5
def test_ner_model_collate(conll2003_dataset, vocabulary):
    """
    Test the NER model collate.
    :param conll2003_dataset: the CoNLL2003 dataset
    :param vocabulary: the vocabulary dict returned by the fixture in conftest.py
    :return: None
    """

    token_vocab = vocabulary["token_vocabulary"]
    label_vocab = vocabulary["label_vocabulary"]

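    # Construct an "empty" VocabularyBuilder (no dataset, not training) and
    # inject the prebuilt vocabularies directly, bypassing dataset-driven
    # construction.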
    vocabulary_builder = VocabularyBuilder(
        is_training=False,
        dataset=None,
        vocabulary_collate=None,
        token_vocabulary_dir=None,
        label_vocabulary_dir=None,
        is_build_token_vocabulary=True,
        pretrained_word_embedding_loader=None)
    vocabulary_builder.token_vocabulary = token_vocab
    vocabulary_builder.label_vocabulary = label_vocab

    sequence_max_len = 5
    model_collate = NerModelCollate(vocabulary_builder=vocabulary_builder,
                                    sequence_max_len=sequence_max_len)

    data_loader = DataLoader(dataset=conll2003_dataset,
                             batch_size=2,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=model_collate)

    for model_inputs in data_loader:

        model_inputs: ModelInputs = model_inputs

        logging.info(f"model inputs: {json2str(model_inputs)}")

        ASSERT.assertEqual(2, model_inputs.batch_size)

        ASSERT.assertEqual((2, sequence_max_len), model_inputs.labels.size())

        tokens = model_inputs.model_inputs["tokens"]

        expect_tokens = [[2, 3, 4, 5, 6], [13, 14, 0, 0, 0]]
        ASSERT.assertListEqual(expect_tokens, tokens.tolist())

        mask = (tokens != token_vocab.padding_index).long()

        sequence_lengths = mask.sum(dim=-1).tolist()

        ASSERT.assertListEqual([5, 2], sequence_lengths)
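
The expected tensors above follow from fixed-length batching. A minimal
sketch of the truncate-then-pad behaviour the assertions rely on
(hypothetical helper, not the NerModelCollate API; padding index assumed
to be 0):

def pad_or_truncate(indices, max_len, padding_index=0):
    """Truncate a token-index list to max_len, then right-pad with padding_index."""
    indices = indices[:max_len]
    return indices + [padding_index] * (max_len - len(indices))

assert pad_or_truncate([2, 3, 4, 5, 6, 7, 8], 5) == [2, 3, 4, 5, 6]  # truncated
assert pad_or_truncate([13, 14], 5) == [13, 14, 0, 0, 0]             # padded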
Example #6
def test_msra_dataset(msra_dataset):
    """
    Test the MSRA dataset.
    :param msra_dataset: the MSRA dataset
    :return: None
    """

    ASSERT.assertEqual(5, len(msra_dataset))

    instance2 = msra_dataset[2]

    ASSERT.assertEqual(22, len(instance2["tokens"]))

    expect_labels = ["O"] * 22
    expect_labels[6] = "B-LOC"
    expect_labels[7] = "I-LOC"

    ASSERT.assertListEqual(expect_labels, instance2["sequence_label"])

    instance4 = msra_dataset[4]
    expect_labels = [
        "O", "B-LOC", "I-LOC", "B-ORG", "I-ORG", "B-PER", "I-PER", "B-LOC"
    ]
    ASSERT.assertListEqual(expect_labels, instance4["sequence_label"])
Example #7
def test_conll2003_dataset(conll2003_dataset):
    """
    Test the CoNLL2003 dataset.
    :param conll2003_dataset: the CoNLL2003 dataset
    :return: None
    """

    ASSERT.assertEqual(2, len(conll2003_dataset))

    instance0 = conll2003_dataset[0]

    ASSERT.assertEqual(11, len(instance0["tokens"]))

    instance1 = conll2003_dataset[1]

    expect_labels = ["B-LOC", "O"]

    ASSERT.assertListEqual(expect_labels, instance1["sequence_label"])
Example #8
def test_flat_model_collate(lattice_ner_demo_dataset,
                            character_pretrained_embedding_loader,
                            gaz_pretrained_embedding_loader):
    """
    测试 flat model collate
    :return:
    """
    # 仅仅取前两个作为测试
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    characters = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    character_vocabulary = Vocabulary(tokens=characters,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)
    character_vocabulary = PretrainedVocabulary(
        vocabulary=character_vocabulary,
        pretrained_word_embedding_loader=character_pretrained_embedding_loader)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    flat_vocabulary = FlatPretrainedVocabulary(
        character_pretrained_vocabulary=character_vocabulary,
        gaz_word_pretrained_vocabulary=gaz_vocabulary)

    flat_model_collate = FLATModelCollate(token_vocabulary=flat_vocabulary,
                                          gazetter=gazetter,
                                          label_vocabulary=label_vocabulary)

    model_inputs = flat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    sentence = "陈元呼吁加强国际合作推动世界经济发展"

    # The gaz words matched in "陈元呼吁加强国际合作推动世界经济发展":
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    squeeze_gaz_words_0 = metadata_0["squeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_tokens = [character
                     for character in sentence] + expect_squeeze_gaz_words_0

    tokens = metadata_0["tokens"]

    ASSERT.assertListEqual(expect_tokens, tokens)

    character_pos_begin = list(range(len(sentence)))
    character_pos_end = list(range(len(sentence)))

    squeeze_gaz_words_begin = list()
    squeeze_gaz_words_end = list()

    for squeeze_gaz_word in squeeze_gaz_words_0:
        index = sentence.find(squeeze_gaz_word)

        squeeze_gaz_words_begin.append(index)
        squeeze_gaz_words_end.append(index + len(squeeze_gaz_word) - 1)

    pos_begin = model_inputs.model_inputs["pos_begin"][0]
    pos_end = model_inputs.model_inputs["pos_end"][0]

    expect_pos_begin = character_pos_begin + squeeze_gaz_words_begin
    expect_pos_begin += [0] * (pos_begin.size(0) - len(expect_pos_begin))
    expect_pos_begin = torch.tensor(expect_pos_begin)

    expect_pos_end = character_pos_end + squeeze_gaz_words_end
    expect_pos_end += [0] * (pos_end.size(0) - len(expect_pos_end))
    expect_pos_end = torch.tensor(expect_pos_end)

    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_begin, pos_begin))
    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_end, pos_end))

    expect_character_length = len(sentence)
    expect_squeeze_gaz_word_length = len(expect_squeeze_gaz_words_0)

    character_length = model_inputs.model_inputs["sequence_length"][0]
    squeeze_word_length = model_inputs.model_inputs["squeeze_gaz_word_length"][
        0]

    ASSERT.assertEqual(expect_character_length, character_length.item())
    ASSERT.assertEqual(expect_squeeze_gaz_word_length,
                       squeeze_word_length.item())
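
Taken together, the assertions show the FLAT (flat-lattice) layout: the model
input is the character sequence followed by the matched lexicon words, and
pos_begin/pos_end record each unit's span in the original sentence (a
character spans itself; a word spans its first through last character), with
zero-padding past the used length.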
Example #9
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset,
                                  gaz_pretrained_embedding_loader):
    """
    测试 bilstm gat model collate
    :return:
    """
    # 仅仅取前两个作为测试
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    bilstm_gat_model_collate = BiLstmGATModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = bilstm_gat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    t_graph_0 = model_inputs.model_inputs["t_graph"][0]
    c_graph_0 = model_inputs.model_inputs["c_graph"][0]
    l_graph_0 = model_inputs.model_inputs["l_graph"][0]

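    # expect_t_graph, expect_c_graph and expect_l_graph are presumably
    # module-level constants (the expected adjacency matrices) defined
    # elsewhere in the test file, not shown in this snippet.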
    expect_t_graph_tensor = torch.tensor(expect_t_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_t_graph_tensor, t_graph_0))

    expect_c_graph_tensor = torch.tensor(expect_c_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_c_graph_tensor, c_graph_0))

    expect_l_graph_tensor = torch.tensor(expect_l_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_l_graph_tensor, l_graph_0))

    gaz_words_indices = model_inputs.model_inputs["gaz_words"]

    ASSERT.assertEqual((2, 11), gaz_words_indices.size())

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # The gaz words matched in "陈元呼吁加强国际合作推动世界经济发展":
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    # note: the metadata key keeps the "sequeeze" spelling used by the collate
    squeeze_gaz_words_0 = metadata_0["sequeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_squeeze_gaz_words_indices_0 = torch.tensor(
        [gaz_vocabulary.index(word) for word in expect_squeeze_gaz_words_0],
        dtype=torch.long)

    ASSERT.assertTrue(
        is_tensor_equal(expect_squeeze_gaz_words_indices_0,
                        gaz_words_indices[0]))
Example #10
def test_gaz_model_collate(lattice_ner_demo_dataset,
                           gaz_pretrained_embedding_loader):
    # Only the first two instances are used for testing
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    lattice_model_collate = LatticeModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = lattice_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # The gaz words starting at each character of "陈元呼吁加强国际合作推动世界经济发展":
    expect_gaz_words_0 = [["陈元"], [], ["呼吁"], ["吁加"], ["加强"], ["强国"], ["国际"],
                          [], ["合作"], [], ["推动"], [], ["世界"], [], ["经济"], [],
                          ["发展"], []]

    gaz_words_0 = metadata_0["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_words_0, gaz_words_0)

    gaz_list_0 = model_inputs.model_inputs["gaz_list"][0]

    expect_gaz_list_0 = list()

    for expect_gaz_word in expect_gaz_words_0:

        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]

            expect_gaz_list_0.append([indices, lengths])

        else:
            expect_gaz_list_0.append([])

    logging.debug(
        f"expect_gaz_list_0: {json2str(expect_gaz_list_0)}\n gaz_list_0:{json2str(gaz_list_0)}"
    )
    ASSERT.assertListEqual(expect_gaz_list_0, gaz_list_0)

    tokens_0 = model_inputs.model_inputs["tokens"]
    ASSERT.assertEqual((2, 19), tokens_0.size())
    sequence_label_0 = model_inputs.labels
    ASSERT.assertEqual((2, 19), sequence_label_0.size())
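    # 19 is presumably the length of the longer sentence in the batch
    # ("新华社华盛顿4月28日电(记者翟景升)" has 19 characters), to which the
    # shorter sentence is padded.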

    # The gaz words starting at each character of "新华社华盛顿4月28日电(记者翟景升)":
    expect_gaz_word_1 = [
        ["新华社", "新华"],  # 新
        ["华社"],  # 华
        ["社华"],  # 社
        ["华盛顿", "华盛"],  # 华
        ["盛顿"],  # 盛
        [],  # 顿
        [],  # 4
        [],  # 月
        [],  # 2
        [],  # 8
        [],  # 日
        [],  # 电
        [],  # (
        ["记者"],  # 记
        [],  # 者
        ["翟景升", "翟景"],  # 翟
        ["景升"],  # 景
        [],  # 升
        []
    ]  # )

    metadata_1 = model_inputs.model_inputs["metadata"][1]
    gaz_words_1 = metadata_1["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_word_1, gaz_words_1)

    expect_gaz_list_1 = list()

    for expect_gaz_word in expect_gaz_word_1:

        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]

            expect_gaz_list_1.append([indices, lengths])

        else:
            expect_gaz_list_1.append([])

    gaz_list_1 = model_inputs.model_inputs["gaz_list"][1]
    ASSERT.assertListEqual(expect_gaz_list_1, gaz_list_1)
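
Each non-empty gaz_list entry therefore pairs, for the character at that
position, the gazetteer-vocabulary indices of the words starting there with
their lengths; for example, position 0 of the second sentence ("新") yields
[[index("新华社"), index("新华")], [3, 2]], while positions with no match
contribute an empty list.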