Beispiel #1
0
def event_type_vocabulary():
    """
    Build a small event-type vocabulary and sanity-check its indices.

    The vocabulary is created with an empty padding string, ``"Negative"`` as
    the unknown token, and ``special_first=True``.

    :return: the constructed ``Vocabulary``
    """
    vocabulary = Vocabulary(
        tokens=[["A", "B", "C"], ["A", "B"], ["A"]],
        padding="",
        unk="Negative",
        special_first=True)

    ASSERT.assertEqual(4, vocabulary.size)

    # The unk token is expected at index 0, followed by A, B and C in order.
    ASSERT.assertEqual(0, vocabulary.index(vocabulary.unk))
    for expected_index, event_type in enumerate(("A", "B", "C"), start=1):
        ASSERT.assertEqual(expected_index, vocabulary.index(event_type))

    return vocabulary
Beispiel #2
0
def test_speical_last():
    """
    Check index assignment when special tokens are placed last.

    With ``special_first=False`` the four distinct regular tokens are expected
    to occupy indices 0-3, and padding, unk, ``<Start>`` and ``<End>`` the
    four indices that follow.

    :return:
    """
    sentences = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(sentences,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=False,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 8)

    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)

    # Index of the last regular token; special tokens come right after it.
    last_regular_index = 3
    specials_in_order = (vocabulary.padding, vocabulary.unk, "<Start>", "<End>")
    for offset, special_token in enumerate(specials_in_order, start=1):
        ASSERT.assertEqual(vocabulary.index(special_token),
                           last_regular_index + offset)
Beispiel #3
0
def test_vocabulary_speical_first():
    """
    Test a vocabulary built with special tokens placed first.

    Padding is expected at index 0 and unk at index 1; together with the four
    distinct regular tokens the vocabulary size is expected to be 6.

    :return:
    """
    sentences = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(sentences,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 6)

    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)

    # Special tokens lead the index space: padding first, then unk.
    for expected_index, special_token in enumerate(
            (vocabulary.padding, vocabulary.unk)):
        ASSERT.assertEqual(vocabulary.index(special_token), expected_index)
Beispiel #4
0
def test_vocabulary():
    """
    Test a vocabulary built with empty padding and unk tokens.

    With both special tokens given as empty strings, the vocabulary is
    expected to hold only the four distinct regular tokens, indexed from 0.

    :return:
    """
    sentences = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(sentences,
                            padding="",
                            unk="",
                            special_first=True,
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 4)

    # Neither special token should be set.
    ASSERT.assertFalse(vocabulary.padding)
    ASSERT.assertFalse(vocabulary.unk)

    # Regular tokens start at index 0 because no special token precedes them.
    for expected_index, token in enumerate(("我", "和")):
        ASSERT.assertEqual(vocabulary.index(token), expected_index)
Beispiel #5
0
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset,
                                  gaz_pretrained_embedding_loader):
    """
    Test the BiLSTM-GAT model collate function.

    Builds token, label and gazetteer vocabularies from the first two dataset
    instances, runs ``BiLstmGATModelCollate`` on them, and compares the
    resulting graphs and gaz-word indices of the first instance against the
    expected values.

    :param lattice_ner_demo_dataset: dataset providing the demo instances
    :param gaz_pretrained_embedding_loader: pretrained embedding loader used
        for the gazetteer and the gaz vocabulary
    :return:
    """
    # Only the first two instances are used for this test.
    batch_instances = lattice_ner_demo_dataset[0:2]

    collate_result = VocabularyCollate()(batch_instances)
    tokens, sequence_labels = (collate_result["tokens"],
                               collate_result["sequence_labels"])

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_labels,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_words = GazVocabularyCollate(gazetteer=gazetter)(batch_instances)

    # Plain vocabulary over the gaz words, wrapped with pretrained embeddings.
    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=Vocabulary(tokens=gaz_words,
                              padding=Vocabulary.PADDING,
                              unk=Vocabulary.UNK,
                              special_first=True),
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    bilstm_gat_model_collate = BiLstmGATModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = bilstm_gat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    # Graphs of the first instance, fetched up front in t / c / l order.
    actual_graphs = [
        model_inputs.model_inputs[graph_key][0]
        for graph_key in ("t_graph", "c_graph", "l_graph")
    ]
    expected_graphs = (expect_t_graph, expect_c_graph, expect_l_graph)

    for expected_graph, actual_graph in zip(expected_graphs, actual_graphs):
        expected_tensor = torch.tensor(expected_graph, dtype=torch.uint8)
        ASSERT.assertTrue(is_tensor_equal(expected_tensor, actual_graph))

    gaz_words_indices = model_inputs.model_inputs["gaz_words"]

    ASSERT.assertEqual((2, 11), gaz_words_indices.size())

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # Sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0,
                           metadata_0["sequeeze_gaz_words"])

    # The index tensor of the first instance must match a vocabulary lookup
    # of the expected gaz words.
    expected_indices = torch.tensor(
        list(map(gaz_vocabulary.index, expect_squeeze_gaz_words_0)),
        dtype=torch.long)

    ASSERT.assertTrue(is_tensor_equal(expected_indices, gaz_words_indices[0]))
Beispiel #6
0
def test_save_and_load():
    """
    Test saving a vocabulary to disk and loading it back.

    The same set of assertions is applied to the freshly built vocabulary and
    to the one restored from disk, so the round trip is actually verified
    (including a token that contains a newline character).

    :return:
    """

    def assert_expected_content(vocab):
        # Shared checks for both the original and the reloaded vocabulary.
        ASSERT.assertEqual(vocab.size, 9)

        ASSERT.assertEqual(vocab.padding, Vocabulary.PADDING)
        ASSERT.assertEqual(vocab.unk, Vocabulary.UNK)
        ASSERT.assertEqual(vocab.index(vocab.padding), 0)
        ASSERT.assertEqual(vocab.index(vocab.unk), 1)
        ASSERT.assertEqual(vocab.index("<Start>"), 2)
        ASSERT.assertEqual(vocab.index("<End>"), 3)
        ASSERT.assertEqual(vocab.index("我"), 4)
        ASSERT.assertEqual(vocab.index("newline\nnewline"), 8)
        # An out-of-vocabulary token must map to the unk index.
        ASSERT.assertEqual(vocab.index("哈哈"), vocab.index(vocab.unk))

    batch_tokens = [["我", "和", "你"], ["在", "我"], ["newline\nnewline"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)

    assert_expected_content(vocabulary)

    vocab_dir = os.path.join(ROOT_PATH, "data/easytext/tests")

    # exist_ok=True already tolerates an existing directory.
    os.makedirs(vocab_dir, exist_ok=True)

    vocabulary.save_to_file(vocab_dir)

    loaded_vocab = Vocabulary.from_file(directory=vocab_dir)

    # Previously several of these checks were run against the original
    # vocabulary again, leaving the loaded one largely unverified.
    assert_expected_content(loaded_vocab)
def test_gaz_model_collate(lattice_ner_demo_dataset,
                           gaz_pretrained_embedding_loader):
    """
    Test the lattice model collate function.

    Builds token, label and gazetteer vocabularies from the first two dataset
    instances, runs ``LatticeModelCollate`` on them, and checks the per-character
    gaz words, the derived gaz lists, and the sizes of the token and label
    tensors.

    :param lattice_ner_demo_dataset: dataset providing the demo instances
    :param gaz_pretrained_embedding_loader: pretrained embedding loader used
        for the gazetteer and the gaz vocabulary
    :return:
    """
    # Only the first two instances are used for this test.
    batch_instances = lattice_ner_demo_dataset[0:2]

    collate_result = VocabularyCollate()(batch_instances)
    tokens, sequence_labels = (collate_result["tokens"],
                               collate_result["sequence_labels"])

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_labels,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_words = GazVocabularyCollate(gazetteer=gazetter)(batch_instances)

    # Plain vocabulary over the gaz words, wrapped with pretrained embeddings.
    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=Vocabulary(tokens=gaz_words,
                              padding=Vocabulary.PADDING,
                              unk=Vocabulary.UNK,
                              special_first=True),
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    def build_expected_gaz_list(gaz_words_per_char):
        # For every character: [indices, lengths] of its gaz words, or an
        # empty list when the character has no gaz word.
        expected = []
        for words in gaz_words_per_char:
            if words:
                expected.append([
                    [gaz_vocabulary.index(word) for word in words],
                    [len(word) for word in words],
                ])
            else:
                expected.append([])
        return expected

    lattice_model_collate = LatticeModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = lattice_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # Sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_gaz_words_0 = [["陈元"], [], ["呼吁"], ["吁加"], ["加强"], ["强国"], ["国际"],
                          [], ["合作"], [], ["推动"], [], ["世界"], [], ["经济"], [],
                          ["发展"], []]

    ASSERT.assertListEqual(expect_gaz_words_0, metadata_0["gaz_words"])

    gaz_list_0 = model_inputs.model_inputs["gaz_list"][0]
    expect_gaz_list_0 = build_expected_gaz_list(expect_gaz_words_0)

    logging.debug(
        f"expect_gaz_list_0: {json2str(expect_gaz_list_0)}\n gaz_list_0:{json2str(gaz_list_0)}"
    )
    ASSERT.assertListEqual(expect_gaz_list_0, gaz_list_0)

    # Both tensors are expected to have size (2, 19).
    ASSERT.assertEqual((2, 19), model_inputs.model_inputs["tokens"].size())
    ASSERT.assertEqual((2, 19), model_inputs.labels.size())

    # Sentence: 新华社华盛顿4月28日电(记者翟景升)
    expect_gaz_word_1 = [
        ["新华社", "新华"],  # 新
        ["华社"],  # 华
        ["社华"],  # 社
        ["华盛顿", "华盛"],  # 华
        ["盛顿"],  # 盛
        [],  # 顿
        [],  # 4
        [],  # 月
        [],  # 2
        [],  # 8
        [],  # 日
        [],  # 电
        [],  # (
        ["记者"],  # 记
        [],  # 者
        ["翟景升", "翟景"],  # 翟
        ["景升"],  # 景
        [],  # 升
        [],  # )
    ]

    metadata_1 = model_inputs.model_inputs["metadata"][1]

    ASSERT.assertListEqual(expect_gaz_word_1, metadata_1["gaz_words"])

    expect_gaz_list_1 = build_expected_gaz_list(expect_gaz_word_1)

    gaz_list_1 = model_inputs.model_inputs["gaz_list"][1]
    ASSERT.assertListEqual(expect_gaz_list_1, gaz_list_1)