def test_is_tensor_equal():
    """
    Exercise tensor_util.is_tensor_equal on integer and float tensors.

    Cases covered, in order: identical integer tensors and integer tensors
    that differ in one element (both with epsilon=0), then float tensors
    nominally 1e-4 apart, expected equal under epsilon=1e-3 and not equal
    under epsilon=1e-4.
    :return:
    """
    # Identical integer tensors must compare equal with zero tolerance.
    left = torch.tensor([1, 2, 3])
    right = torch.tensor([1, 2, 3])
    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(tensor1=left, tensor2=right, epsilon=0))

    # A single differing element must be detected with zero tolerance.
    left = torch.tensor([1, 2, 3])
    right = torch.tensor([2, 2, 3])
    ASSERT.assertFalse(
        tensor_util.is_tensor_equal(tensor1=left, tensor2=right, epsilon=0))

    # The same float pair is checked against a loose and a tight tolerance.
    left = torch.tensor([1.0001, 2.0001, 3.0001])
    right = torch.tensor([1., 2., 3.])
    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(tensor1=left, tensor2=right, epsilon=1e-3))
    ASSERT.assertFalse(
        tensor_util.is_tensor_equal(tensor1=left, tensor2=right, epsilon=1e-4))
def test_bert_model_collate(mrc_msra_ner_dataset, paper_mrc_msra_ner_dataset):
    """
    Check that BertModelCollate agrees with collate_to_max_length.

    Both datasets are collated in full and the resulting tensors are compared
    element by element with epsilon=0. Elements 0..6 of the tuple returned by
    collate_to_max_length are matched against input_ids, token_type_ids,
    start/end position labels, the sequence mask (twice) and the match
    position labels produced by BertModelCollate.

    :param mrc_msra_ner_dataset: dataset fed to BertModelCollate
    :param paper_mrc_msra_ner_dataset: dataset fed to collate_to_max_length
    :return:
    """

    def assert_identical(paper_tensor, own_tensor):
        # Exact comparison: no tolerance is allowed between the two outputs.
        ASSERT.assertTrue(is_tensor_equal(paper_tensor, own_tensor, epsilon=0))

    max_length = 128
    pretrained_dir = os.path.join(
        ROOT_PATH,
        "data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch")

    collate = BertModelCollate(tokenizer=bert_tokenizer(pretrained_dir),
                               max_length=max_length)
    collated = collate(instances=[item for item in mrc_msra_ner_dataset])
    inputs = collated.model_inputs

    paper_batch = collate_to_max_length(
        [item for item in paper_mrc_msra_ner_dataset])

    # Token ids and token type ids.
    assert_identical(paper_batch[0], inputs["input_ids"])
    assert_identical(paper_batch[1], inputs["token_type_ids"])

    # Start / end label indices.
    assert_identical(paper_batch[2], collated.labels["start_position_labels"])
    assert_identical(paper_batch[3], collated.labels["end_position_labels"])

    # Elements 4 and 5 (start and end label masks on the paper side) are both
    # compared against the same sequence mask, converted to long.
    assert_identical(paper_batch[4], inputs["sequence_mask"].long())
    assert_identical(paper_batch[5], inputs["sequence_mask"].long())

    # Match labels.
    assert_identical(paper_batch[6], collated.labels["match_position_labels"])
def test_gat_without_hidden():
    """
    Run GAT with hidden_size=None on a fixed batch and compare the output
    shape and values against stored expectations (tolerance 1e-4).
    :return:
    """
    # Seeds are fixed before the model is built; the stored expected values
    # presumably depend on this initialisation.
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    in_features = 2
    out_features = 4

    model = GAT(in_features=in_features,
                out_features=out_features,
                dropout=0.,
                alpha=0.1,
                num_heads=3,
                hidden_size=None)

    # Two graphs with three nodes each, two features per node.
    nodes = torch.tensor(
        [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
         [[0.7, 0.8], [0.9, 0.10], [0.11, 0.12]]],
        dtype=torch.float)

    # One 3x3 adjacency matrix per graph.
    adj = torch.tensor(
        [[[0, 1, 0], [1, 0, 0], [0, 0, 0]],
         [[0, 1, 1], [1, 0, 1], [1, 1, 0]]],
        dtype=torch.long)

    output_nodes: torch.Tensor = model(nodes=nodes, adj=adj)

    batch_size, node_count = nodes.size(0), nodes.size(1)
    ASSERT.assertEqual((batch_size, node_count, out_features),
                       output_nodes.size())

    expected_values = torch.tensor(
        [[[-1.6478, -0.3935, -2.6613, -2.7653],
          [-1.3204, -0.8394, -1.8519, -1.9375],
          [-1.6478, -0.3935, -2.6613, -2.7653]],
         [[-1.9897, -0.4203, -2.4447, -2.1232],
          [-2.1944, -0.1897, -3.4053, -3.5697],
          [-2.9364, -0.0878, -4.1695, -4.1617]]],
        dtype=torch.float)

    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(expected_values, output_nodes,
                                    epsilon=1e-4))
def test_gat_with_hidden():
    """
    Run GAT with hidden_size=3 on a fixed batch and compare the output
    shape and values against stored expectations (tolerance 1e-4).
    :return:
    """
    # Seeds are fixed before the model is built; the stored expected values
    # presumably depend on this initialisation.
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    in_features = 2
    out_features = 4

    model = GAT(in_features=in_features,
                out_features=out_features,
                dropout=0.,
                alpha=0.1,
                num_heads=3,
                hidden_size=3)

    # Two graphs with three nodes each, two features per node.
    nodes = torch.tensor(
        [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
         [[0.7, 0.8], [0.9, 0.10], [0.11, 0.12]]],
        dtype=torch.float)

    # One 3x3 adjacency matrix per graph.
    adj = torch.tensor(
        [[[0, 1, 0], [1, 0, 0], [0, 0, 0]],
         [[0, 1, 1], [1, 0, 1], [1, 1, 0]]],
        dtype=torch.long)

    output_nodes: torch.Tensor = model(nodes=nodes, adj=adj)

    batch_size, node_count = nodes.size(0), nodes.size(1)
    ASSERT.assertEqual((batch_size, node_count, out_features),
                       output_nodes.size())

    expected_values = torch.tensor(
        [[[-1.3835, -1.4764, -1.2033, -1.5113],
          [-1.3316, -1.5785, -1.1564, -1.5368],
          [-1.3475, -1.5467, -1.1706, -1.5279]],
         [[-1.3388, -1.6693, -1.4427, -1.1610],
          [-1.4288, -1.6525, -1.6607, -0.9707],
          [-1.4320, -1.4422, -1.6465, -1.1025]]])

    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(expected_values, output_nodes,
                                    epsilon=1e-4))
def test_graph_attention_layer():
    """
    Run a single GraphAttentionLayer on a fixed batch and compare the output
    shape and values against stored expectations (tolerance 1e-4).
    :return:
    """
    # Seeds are fixed before the layer is built; the stored expected values
    # presumably depend on this initialisation.
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    in_features = 2
    out_features = 4

    layer = GraphAttentionLayer(in_features=in_features,
                                out_features=out_features,
                                dropout=0.0,
                                alpha=0.1)

    # Two graphs with three nodes each, two features per node.
    nodes = torch.tensor(
        [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
         [[0.7, 0.8], [0.9, 0.10], [0.11, 0.12]]],
        dtype=torch.float)

    # One 3x3 adjacency matrix per graph.
    adj = torch.tensor(
        [[[0, 1, 0], [1, 0, 0], [0, 0, 0]],
         [[0, 1, 1], [1, 0, 1], [1, 1, 0]]],
        dtype=torch.long)

    outputs: torch.Tensor = layer(input=nodes, adj=adj)

    batch_size, node_count = nodes.size(0), nodes.size(1)
    ASSERT.assertEqual((batch_size, node_count, out_features), outputs.size())

    # The expected values below were obtained by running the original
    # paper's implementation and are used here as-is.
    expected_values = torch.tensor(
        [[[0.2831, 0.3588, -0.5131, -0.2058],
          [0.1606, 0.1292, -0.2264, -0.0951],
          [0.2831, 0.3588, -0.5131, -0.2058]],
         [[-0.0748, 0.5025, -0.3840, -0.1192],
          [0.2959, 0.4624, -0.6123, -0.2405],
          [0.1505, 0.8668, -0.8609, -0.3059]]],
        dtype=torch.float)

    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(expected_values, outputs, epsilon=1e-4))
def test_flat_model_collate(lattice_ner_demo_dataset,
                            character_pretrained_embedding_loader,
                            gaz_pretrained_embedding_loader):
    """
    Check the output of FLATModelCollate for the first demo instance.

    Builds the character, label and gazetteer vocabularies from the first two
    instances of the dataset, collates them, and then verifies for instance 0:
    the squeezed gazetteer words, the token list (characters followed by
    gazetteer words), the zero-padded begin/end position tensors, and the
    character / gazetteer-word lengths.

    :param lattice_ner_demo_dataset: dataset the batch is sliced from
    :param character_pretrained_embedding_loader: loader for the character
        pretrained vocabulary
    :param gaz_pretrained_embedding_loader: loader for the gazetteer and the
        gazetteer-word pretrained vocabulary
    :return:
    """
    # Only the first two instances are used for this test.
    batch = lattice_ner_demo_dataset[0:2]

    vocab_fields = VocabularyCollate()(batch)
    characters = vocab_fields["tokens"]
    sequence_label = vocab_fields["sequence_labels"]

    character_vocabulary = PretrainedVocabulary(
        vocabulary=Vocabulary(tokens=characters,
                              padding=Vocabulary.PADDING,
                              unk=Vocabulary.UNK,
                              special_first=True),
        pretrained_word_embedding_loader=character_pretrained_embedding_loader)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_words = GazVocabularyCollate(gazetteer=gazetter)(batch)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=Vocabulary(tokens=gaz_words,
                              padding=Vocabulary.PADDING,
                              unk=Vocabulary.UNK,
                              special_first=True),
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    flat_vocabulary = FlatPretrainedVocabulary(
        character_pretrained_vocabulary=character_vocabulary,
        gaz_word_pretrained_vocabulary=gaz_vocabulary)

    collate = FLATModelCollate(token_vocabulary=flat_vocabulary,
                               gazetter=gazetter,
                               label_vocabulary=label_vocabulary)

    collated = collate(batch)

    logging.debug(json2str(collated.model_inputs["metadata"]))

    first_metadata = collated.model_inputs["metadata"][0]

    sentence = "陈元呼吁加强国际合作推动世界经济发展"

    # Gazetteer words expected for the sentence above.
    expected_gaz_words = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际",
        "合作", "推动", "世界", "经济", "发展"
    ]

    actual_gaz_words = first_metadata["squeeze_gaz_words"]
    ASSERT.assertListEqual(expected_gaz_words, actual_gaz_words)

    # Tokens are the sentence characters followed by the gazetteer words.
    ASSERT.assertListEqual(list(sentence) + expected_gaz_words,
                           first_metadata["tokens"])

    # A character starts and ends at its own index.
    char_positions = list(range(len(sentence)))

    # A gazetteer word spans from its first occurrence in the sentence to
    # that position plus its length minus one.
    word_starts = [sentence.find(word) for word in actual_gaz_words]
    word_ends = [start + len(word) - 1
                 for start, word in zip(word_starts, actual_gaz_words)]

    pos_begin = collated.model_inputs["pos_begin"][0]
    pos_end = collated.model_inputs["pos_end"][0]

    # Pad the expected positions with zeros up to the collated length.
    begin_values = char_positions + word_starts
    begin_values.extend([0] * (pos_begin.size(0) - len(begin_values)))
    expect_pos_begin = torch.tensor(begin_values)

    end_values = char_positions + word_ends
    end_values.extend([0] * (pos_end.size(0) - len(end_values)))
    expect_pos_end = torch.tensor(end_values)

    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_begin, pos_begin))
    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_end, pos_end))

    character_length = collated.model_inputs["sequence_length"][0]
    squeeze_word_length = collated.model_inputs["squeeze_gaz_word_length"][0]

    ASSERT.assertEqual(len(sentence), character_length.item())
    ASSERT.assertEqual(len(expected_gaz_words), squeeze_word_length.item())
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset,
                                  gaz_pretrained_embedding_loader):
    """
    Check the output of BiLstmGATModelCollate for the first demo instance.

    Builds the token, label and gazetteer vocabularies from the first two
    instances of the dataset, collates them, and then verifies for instance 0:
    the t / c / l graphs, the size of the gazetteer-word index tensor, the
    squeezed gazetteer words in the metadata, and their vocabulary indices.

    :param lattice_ner_demo_dataset: dataset the batch is sliced from
    :param gaz_pretrained_embedding_loader: loader for the gazetteer and the
        gazetteer-word pretrained vocabulary
    :return:
    """

    def assert_graph(expected_rows, actual_graph):
        # Expected graphs are converted to uint8 tensors before comparison.
        expected_tensor = torch.tensor(expected_rows, dtype=torch.uint8)
        ASSERT.assertTrue(is_tensor_equal(expected_tensor, actual_graph))

    # Only the first two instances are used for this test.
    batch = lattice_ner_demo_dataset[0:2]

    vocab_fields = VocabularyCollate()(batch)
    tokens = vocab_fields["tokens"]
    sequence_label = vocab_fields["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_words = GazVocabularyCollate(gazetteer=gazetter)(batch)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=Vocabulary(tokens=gaz_words,
                              padding=Vocabulary.PADDING,
                              unk=Vocabulary.UNK,
                              special_first=True),
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    collate = BiLstmGATModelCollate(token_vocabulary=token_vocabulary,
                                    gazetter=gazetter,
                                    gaz_vocabulary=gaz_vocabulary,
                                    label_vocabulary=label_vocabulary)

    collated = collate(batch)

    logging.debug(json2str(collated.model_inputs["metadata"]))

    t_graph_0 = collated.model_inputs["t_graph"][0]
    c_graph_0 = collated.model_inputs["c_graph"][0]
    l_graph_0 = collated.model_inputs["l_graph"][0]

    # expect_t_graph / expect_c_graph / expect_l_graph are not defined inside
    # this function; presumably they are module-level expected adjacency data
    # defined elsewhere in this file.
    assert_graph(expect_t_graph, t_graph_0)
    assert_graph(expect_c_graph, c_graph_0)
    assert_graph(expect_l_graph, l_graph_0)

    gaz_words_indices = collated.model_inputs["gaz_words"]
    ASSERT.assertEqual((2, 11), gaz_words_indices.size())

    first_metadata = collated.model_inputs["metadata"][0]

    # Gazetteer words expected for the sentence
    # "陈元呼吁加强国际合作推动世界经济发展".
    expected_gaz_words = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际",
        "合作", "推动", "世界", "经济", "发展"
    ]

    # NOTE(review): the key is spelled "sequeeze_gaz_words" here while
    # test_flat_model_collate reads "squeeze_gaz_words"; it is kept verbatim
    # because it has to match what the collate emits - confirm.
    actual_gaz_words = first_metadata["sequeeze_gaz_words"]
    ASSERT.assertListEqual(expected_gaz_words, actual_gaz_words)

    expected_indices = torch.tensor(
        [gaz_vocabulary.index(word) for word in expected_gaz_words],
        dtype=torch.long)

    ASSERT.assertTrue(is_tensor_equal(expected_indices, gaz_words_indices[0]))