def event_type_vocabulary():
    """Build a small event-type vocabulary fixture and sanity-check its indices.

    The unk token ("Negative") is placed first, so regular event types start
    at index 1.
    """
    token_batches = [["A", "B", "C"], ["A", "B"], ["A"]]
    vocab = Vocabulary(tokens=token_batches,
                       padding="",
                       unk="Negative",
                       special_first=True)
    # 3 event types + the unk token.
    ASSERT.assertEqual(4, vocab.size)
    ASSERT.assertEqual(0, vocab.index(vocab.unk))
    for expected_index, token in enumerate(["A", "B", "C"], start=1):
        ASSERT.assertEqual(expected_index, vocab.index(token))
    return vocab
def test_speical_last():
    """Verify special tokens are appended after regular tokens when special_first=False."""
    token_batches = [["我", "和", "你"], ["在", "我"]]
    vocab = Vocabulary(token_batches,
                       padding=Vocabulary.PADDING,
                       unk=Vocabulary.UNK,
                       special_first=False,
                       other_special_tokens=["<Start>", "<End>"],
                       min_frequency=1,
                       max_size=None)
    # 4 distinct regular tokens + 4 special tokens.
    ASSERT.assertEqual(vocab.size, 8)
    ASSERT.assertEqual(vocab.padding, vocab.PADDING)
    ASSERT.assertEqual(vocab.unk, vocab.UNK)
    # Special tokens occupy the trailing index range 4..7.
    for offset, token in enumerate([vocab.padding, vocab.unk, "<Start>", "<End>"]):
        ASSERT.assertEqual(vocab.index(token), 4 + offset)
def test_vocabulary_speical_first():
    """Verify padding and unk occupy the first indices when special_first=True.

    :return: None
    """
    token_batches = [["我", "和", "你"], ["在", "我"]]
    vocab = Vocabulary(token_batches,
                       padding=Vocabulary.PADDING,
                       unk=Vocabulary.UNK,
                       special_first=True,
                       min_frequency=1,
                       max_size=None)
    # 4 distinct regular tokens + padding + unk.
    ASSERT.assertEqual(vocab.size, 6)
    ASSERT.assertEqual(vocab.padding, vocab.PADDING)
    ASSERT.assertEqual(vocab.unk, vocab.UNK)
    ASSERT.assertEqual(vocab.index(vocab.padding), 0)
    ASSERT.assertEqual(vocab.index(vocab.unk), 1)
def test_vocabulary():
    """Verify a vocabulary with empty padding/unk indexes regular tokens from zero.

    :return: None
    """
    token_batches = [["我", "和", "你"], ["在", "我"]]
    vocab = Vocabulary(token_batches,
                       padding="",
                       unk="",
                       special_first=True,
                       min_frequency=1,
                       max_size=None)
    # Only the 4 distinct regular tokens — no special slots are reserved.
    ASSERT.assertEqual(vocab.size, 4)
    ASSERT.assertFalse(vocab.padding)
    ASSERT.assertFalse(vocab.unk)
    # "我" is most frequent (appears twice) so it gets the first index.
    ASSERT.assertEqual(vocab.index("我"), 0)
    ASSERT.assertEqual(vocab.index("和"), 1)
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset, gaz_pretrained_embedding_loader):
    """
    Test the BiLstm GAT model collate: build token/label/gazetteer vocabularies
    from the first two demo instances, run the collate, and check the produced
    t/c/l graphs and gazetteer word indices against module-level fixtures
    (expect_t_graph, expect_c_graph, expect_l_graph).
    :return: None
    """
    # Use only the first two instances as the test batch.
    batch_instances = lattice_ner_demo_dataset[0:2]
    vocabulary_collate = VocabularyCollate()
    collate_result = vocabulary_collate(batch_instances)
    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]
    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)
    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)
    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)
    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)
    gaz_words = gaz_vocabulary_collate(batch_instances)
    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)
    # Wrap with the pretrained loader so indices align with the embedding matrix.
    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)
    bilstm_gat_model_collate = BiLstmGATModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)
    model_inputs = bilstm_gat_model_collate(batch_instances)
    logging.debug(json2str(model_inputs.model_inputs["metadata"]))
    # The three adjacency graphs of the first instance must match the fixtures.
    t_graph_0 = model_inputs.model_inputs["t_graph"][0]
    c_graph_0 = model_inputs.model_inputs["c_graph"][0]
    l_graph_0 = model_inputs.model_inputs["l_graph"][0]
    expect_t_graph_tensor = torch.tensor(expect_t_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_t_graph_tensor, t_graph_0))
    expect_c_graph_tensor = torch.tensor(expect_c_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_c_graph_tensor, c_graph_0))
    expect_l_graph_tensor = torch.tensor(expect_l_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_l_graph_tensor, l_graph_0))
    gaz_words_indices = model_inputs.model_inputs["gaz_words"]
    # Batch of 2 instances, padded to 11 gazetteer words (the longer instance).
    ASSERT.assertEqual((2, 11),
                       gaz_words_indices.size())
    metadata_0 = model_inputs.model_inputs["metadata"][0]
    # First instance sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]
    sequeeze_gaz_words_0 = metadata_0["sequeeze_gaz_words"]
    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, sequeeze_gaz_words_0)
    # The index row for instance 0 must equal the gaz vocabulary lookup of each word.
    expect_squeeze_gaz_words_indices_0 = torch.tensor(
        [gaz_vocabulary.index(word) for word in expect_squeeze_gaz_words_0],
        dtype=torch.long)
    ASSERT.assertTrue(
        is_tensor_equal(expect_squeeze_gaz_words_indices_0,
                        gaz_words_indices[0]))
def test_save_and_load():
    """
    Test saving a vocabulary to disk and loading it back.

    Builds a vocabulary with special tokens — including a token containing a
    newline, to exercise serialization — saves it, reloads it, and checks that
    the loaded vocabulary reproduces the original size and token indices.
    :return: None
    """
    batch_tokens = [["我", "和", "你"], ["在", "我"], ["newline\nnewline"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)
    # 5 distinct regular tokens + padding + unk + 2 other specials.
    ASSERT.assertEqual(vocabulary.size, 9)
    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 0)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 1)
    ASSERT.assertEqual(vocabulary.index("<Start>"), 2)
    ASSERT.assertEqual(vocabulary.index("<End>"), 3)
    ASSERT.assertEqual(vocabulary.index("我"), 4)
    ASSERT.assertEqual(vocabulary.index("newline\nnewline"), 8)
    # Out-of-vocabulary tokens fall back to the unk index.
    ASSERT.assertEqual(vocabulary.index("哈哈"), vocabulary.index(vocabulary.unk))
    vocab_dir = os.path.join(ROOT_PATH, "data/easytext/tests")
    # exist_ok makes the prior isdir() check unnecessary.
    os.makedirs(vocab_dir, exist_ok=True)
    vocabulary.save_to_file(vocab_dir)
    loaded_vocab = Vocabulary.from_file(directory=vocab_dir)
    # BUG FIX: these assertions previously re-checked `vocabulary` instead of
    # `loaded_vocab`, so round-trip persistence of the size and of the regular
    # tokens was never actually verified.
    ASSERT.assertEqual(loaded_vocab.size, 9)
    ASSERT.assertEqual(loaded_vocab.padding, vocabulary.PADDING)
    ASSERT.assertEqual(loaded_vocab.unk, vocabulary.UNK)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.padding), 0)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.unk), 1)
    ASSERT.assertEqual(loaded_vocab.index("<Start>"), 2)
    ASSERT.assertEqual(loaded_vocab.index("<End>"), 3)
    ASSERT.assertEqual(loaded_vocab.index("我"), 4)
    ASSERT.assertEqual(loaded_vocab.index("newline\nnewline"), 8)
    ASSERT.assertEqual(loaded_vocab.index("哈哈"), loaded_vocab.index(loaded_vocab.unk))
def test_gaz_model_collate(lattice_ner_demo_dataset, gaz_pretrained_embedding_loader):
    """
    Test the lattice model collate: build token/label/gazetteer vocabularies
    from the first two demo instances, run LatticeModelCollate, and check the
    per-character gazetteer word lists and their [indices, lengths] encodings
    for both instances.
    :return: None
    """
    # Use only the first two instances as the test batch.
    batch_instances = lattice_ner_demo_dataset[0:2]
    vocabulary_collate = VocabularyCollate()
    collate_result = vocabulary_collate(batch_instances)
    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]
    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)
    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)
    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)
    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)
    gaz_words = gaz_vocabulary_collate(batch_instances)
    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)
    # Wrap with the pretrained loader so indices align with the embedding matrix.
    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)
    lattice_model_collate = LatticeModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)
    model_inputs = lattice_model_collate(batch_instances)
    logging.debug(json2str(model_inputs.model_inputs["metadata"]))
    metadata_0 = model_inputs.model_inputs["metadata"][0]
    # First instance sentence: 陈元呼吁加强国际合作推动世界经济发展
    # One sub-list per character; each holds the gazetteer words starting there.
    expect_gaz_words_0 = [["陈元"], [], ["呼吁"], ["吁加"], ["加强"], ["强国"],
                          ["国际"], [], ["合作"], [], ["推动"], [], ["世界"], [],
                          ["经济"], [], ["发展"], []]
    gaz_words_0 = metadata_0["gaz_words"]
    ASSERT.assertListEqual(expect_gaz_words_0, gaz_words_0)
    gaz_list_0 = model_inputs.model_inputs["gaz_list"][0]
    # Each non-empty position encodes [word indices, word lengths]; empty -> [].
    expect_gaz_list_0 = list()
    for expect_gaz_word in expect_gaz_words_0:
        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengthes = [len(word) for word in expect_gaz_word]
            expect_gaz_list_0.append([indices, lengthes])
        else:
            expect_gaz_list_0.append([])
    logging.debug(
        f"expect_gaz_list_0: {json2str(expect_gaz_list_0)}\n gaz_list_0:{json2str(gaz_list_0)}"
    )
    ASSERT.assertListEqual(expect_gaz_list_0, gaz_list_0)
    tokens_0 = model_inputs.model_inputs["tokens"]
    # Batch of 2 instances, padded to 19 characters.
    ASSERT.assertEqual((2, 19), tokens_0.size())
    sequence_label_0 = model_inputs.labels
    ASSERT.assertEqual((2, 19), sequence_label_0.size())
    # Second instance sentence: 新华社华盛顿4月28日电(记者翟景升)
    # Per-character gazetteer words; the trailing comment marks the character.
    expect_gaz_word_1 = [
        ["新华社", "新华"],  # 新
        ["华社"],  # 华
        ["社华"],  # 社
        ["华盛顿", "华盛"],  # 华
        ["盛顿"],  # 盛
        [],  # 顿
        [],  # 4
        [],  # 月
        [],  # 2
        [],  # 8
        [],  # 日
        [],  # 电
        [],  # (
        ["记者"],  # 记
        [],  # 者
        ["翟景升", "翟景"],  # 翟
        ["景升"],  # 景
        [],  # 升
        []
    ]  # )
    metadata_1 = model_inputs.model_inputs["metadata"][1]
    gaz_words_1 = metadata_1["gaz_words"]
    ASSERT.assertListEqual(expect_gaz_word_1, gaz_words_1)
    # Same [indices, lengths] encoding check for the second instance.
    expect_gaz_list_1 = list()
    for expect_gaz_word in expect_gaz_word_1:
        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengthes = [len(word) for word in expect_gaz_word]
            expect_gaz_list_1.append([indices, lengthes])
        else:
            expect_gaz_list_1.append([])
    gaz_list_1 = model_inputs.model_inputs["gaz_list"][1]
    ASSERT.assertListEqual(expect_gaz_list_1, gaz_list_1)