Example #1
def test_is_tensor_equal():
    """
    Test whether two tensors are equal.
    :return:
    """

    x = torch.tensor([1, 2, 3])
    y = torch.tensor([1, 2, 3])

    equal = tensor_util.is_tensor_equal(tensor1=x, tensor2=y, epsilon=0)

    ASSERT.assertTrue(equal)

    x = torch.tensor([1, 2, 3])
    y = torch.tensor([2, 2, 3])

    equal = tensor_util.is_tensor_equal(tensor1=x, tensor2=y, epsilon=0)

    ASSERT.assertFalse(equal)

    x = torch.tensor([1.0001, 2.0001, 3.0001])
    y = torch.tensor([1., 2., 3.])

    equal = tensor_util.is_tensor_equal(tensor1=x, tensor2=y, epsilon=1e-3)

    ASSERT.assertTrue(equal)

    equal = tensor_util.is_tensor_equal(tensor1=x, tensor2=y, epsilon=1e-4)

    ASSERT.assertFalse(equal)
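
The assertions above pin down the contract of tensor_util.is_tensor_equal: the shapes must match, and the largest element-wise difference must not exceed epsilon, with epsilon=0 meaning exact equality. A minimal sketch consistent with these assertions (hypothetical; the actual tensor_util implementation may differ):

import torch

def is_tensor_equal(tensor1: torch.Tensor, tensor2: torch.Tensor, epsilon: float = 0.0) -> bool:
    # shapes must match exactly
    if tensor1.size() != tensor2.size():
        return False
    # the maximum absolute element-wise difference must not exceed epsilon;
    # epsilon=0 degenerates to exact equality
    max_diff = (tensor1.float() - tensor2.float()).abs().max().item()
    return max_diff <= epsilon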
Example #2
def test_bert_model_collate(mrc_msra_ner_dataset, paper_mrc_msra_ner_dataset):
    max_length = 128

    bert_dir = "data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch"
    bert_dir = os.path.join(ROOT_PATH, bert_dir)

    bert_model_collate = BertModelCollate(tokenizer=bert_tokenizer(bert_dir),
                                          max_length=max_length)

    instances = list(mrc_msra_ner_dataset)
    model_inputs = bert_model_collate(instances=instances)

    inputs = model_inputs.model_inputs

    paper_instances = list(paper_mrc_msra_ner_dataset)
    paper_model_inputs = collate_to_max_length(paper_instances)

    paper_token_ids = paper_model_inputs[0]
    token_ids = inputs["input_ids"]

    ASSERT.assertTrue(is_tensor_equal(paper_token_ids, token_ids, epsilon=0))

    paper_type_ids = paper_model_inputs[1]
    type_ids = inputs["token_type_ids"]

    ASSERT.assertTrue(is_tensor_equal(paper_type_ids, type_ids, epsilon=0))

    paper_start_label_indices = paper_model_inputs[2]

    start_label_indices = model_inputs.labels["start_position_labels"]

    ASSERT.assertTrue(
        is_tensor_equal(paper_start_label_indices,
                        start_label_indices,
                        epsilon=0))

    paper_end_label_indices = paper_model_inputs[3]

    end_label_indices = model_inputs.labels["end_position_labels"]

    ASSERT.assertTrue(
        is_tensor_equal(paper_end_label_indices, end_label_indices, epsilon=0))

    paper_start_label_mask = paper_model_inputs[4]
    sequence_mask = inputs["sequence_mask"].long()

    ASSERT.assertTrue(
        is_tensor_equal(paper_start_label_mask, sequence_mask, epsilon=0))

    paper_end_label_mask = paper_model_inputs[5]
    sequence_mask = inputs["sequence_mask"].long()

    ASSERT.assertTrue(
        is_tensor_equal(paper_end_label_mask, sequence_mask, epsilon=0))

    paper_match_labels = paper_model_inputs[6]
    match_labels = model_inputs.labels["match_position_labels"]

    ASSERT.assertTrue(
        is_tensor_equal(paper_match_labels, match_labels, epsilon=0))
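
The test relies on collate_to_max_length (from the paper's reference code) only to pad every field to the longest sequence in the batch and stack the fields in a fixed order, so that the seven indexed outputs line up with BertModelCollate's named outputs. A rough standalone sketch of that padding behavior (hypothetical field layout, not the paper's actual code):

import torch
from typing import List

def pad_fields_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    # batch[i][k] is the k-th field of instance i as a 1-D tensor;
    # pad each field to the batch-wide max length and stack along dim 0
    padded_fields = []
    for field in zip(*batch):
        max_len = max(t.size(0) for t in field)
        stacked = torch.zeros(len(field), max_len, dtype=field[0].dtype)
        for row, t in enumerate(field):
            stacked[row, : t.size(0)] = t
        padded_fields.append(stacked)
    return padded_fields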
Example #3
def test_gat_without_hidden():
    """
    测试 gat
    :return:
    """

    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    in_features = 2
    out_features = 4

    gat = GAT(in_features=in_features,
              out_features=out_features,
              dropout=0.,
              alpha=0.1,
              num_heads=3,
              hidden_size=None)

    nodes = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                          [[0.7, 0.8], [0.9, 0.10], [0.11, 0.12]]],
                         dtype=torch.float)

    adj = torch.tensor(
        [[[0, 1, 0], [1, 0, 0], [0, 0, 0]], [[0, 1, 1], [1, 0, 1], [1, 1, 0]]],
        dtype=torch.long)

    output_nodes: torch.Tensor = gat(nodes=nodes, adj=adj)

    expect_size = (nodes.size(0), nodes.size(1), out_features)
    ASSERT.assertEqual(expect_size, output_nodes.size())

    expect = torch.tensor([[[-1.6478, -0.3935, -2.6613, -2.7653],
                            [-1.3204, -0.8394, -1.8519, -1.9375],
                            [-1.6478, -0.3935, -2.6613, -2.7653]],
                           [[-1.9897, -0.4203, -2.4447, -2.1232],
                            [-2.1944, -0.1897, -3.4053, -3.5697],
                            [-2.9364, -0.0878, -4.1695, -4.1617]]],
                          dtype=torch.float)

    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(expect, output_nodes, epsilon=1e-4))
Example #4
def test_gat_with_hidden():
    """
    测试 gat
    :return:
    """

    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    in_features = 2
    out_features = 4

    gat = GAT(in_features=in_features,
              out_features=out_features,
              dropout=0.,
              alpha=0.1,
              num_heads=3,
              hidden_size=3)

    nodes = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                          [[0.7, 0.8], [0.9, 0.10], [0.11, 0.12]]],
                         dtype=torch.float)

    adj = torch.tensor(
        [[[0, 1, 0], [1, 0, 0], [0, 0, 0]], [[0, 1, 1], [1, 0, 1], [1, 1, 0]]],
        dtype=torch.long)

    output_nodes: torch.Tensor = gat(nodes=nodes, adj=adj)

    expect_size = (nodes.size(0), nodes.size(1), out_features)
    ASSERT.assertEqual(expect_size, output_nodes.size())

    expect = torch.tensor([[[-1.3835, -1.4764, -1.2033, -1.5113],
                            [-1.3316, -1.5785, -1.1564, -1.5368],
                            [-1.3475, -1.5467, -1.1706, -1.5279]],
                           [[-1.3388, -1.6693, -1.4427, -1.1610],
                            [-1.4288, -1.6525, -1.6607, -0.9707],
                            [-1.4320, -1.4422, -1.6465, -1.1025]]])

    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(expect, output_nodes, epsilon=1e-4))
Example #5
def test_graph_attention_layer():
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    in_features = 2
    out_features = 4

    gat_layer = GraphAttentionLayer(in_features=in_features,
                                    out_features=out_features,
                                    dropout=0.0,
                                    alpha=0.1)

    nodes = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                          [[0.7, 0.8], [0.9, 0.10], [0.11, 0.12]]],
                         dtype=torch.float)

    adj = torch.tensor(
        [[[0, 1, 0], [1, 0, 0], [0, 0, 0]], [[0, 1, 1], [1, 0, 1], [1, 1, 0]]],
        dtype=torch.long)

    outputs: torch.Tensor = gat_layer(input=nodes, adj=adj)

    expect_size = (nodes.size(0), nodes.size(1), out_features)

    ASSERT.assertEqual(expect_size, outputs.size())

    # the expect values below were obtained from the original paper's
    # implementation and are used here as-is
    expect = torch.tensor([[[0.2831, 0.3588, -0.5131, -0.2058],
                            [0.1606, 0.1292, -0.2264, -0.0951],
                            [0.2831, 0.3588, -0.5131, -0.2058]],
                           [[-0.0748, 0.5025, -0.3840, -0.1192],
                            [0.2959, 0.4624, -0.6123, -0.2405],
                            [0.1505, 0.8668, -0.8609, -0.3059]]],
                          dtype=torch.float)

    ASSERT.assertTrue(
        tensor_util.is_tensor_equal(expect, outputs, epsilon=1e-4))
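
Note that the expect values here are unnormalized, while every row of the expect tensors in Examples #3 and #4 exponentiates to a sum of 1: the multi-head GAT module apparently finishes with a log-softmax over out_features, whereas the single layer returns raw attention-weighted features. A self-contained sketch of such a batched layer in the spirit of Veličković et al. (2018) (hypothetical; the tested GraphAttentionLayer may differ in details):

import torch
import torch.nn as nn

class GraphAttentionLayerSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, alpha: float = 0.1):
        super().__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.a = nn.Linear(2 * out_features, 1, bias=False)
        self.leaky_relu = nn.LeakyReLU(alpha)

    def forward(self, nodes: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # nodes: (batch, num_nodes, in_features), adj: (batch, num_nodes, num_nodes)
        h = self.W(nodes)                                   # (B, N, F_out)
        num_nodes = h.size(1)
        # score every (i, j) pair with a shared attention vector
        h_i = h.unsqueeze(2).expand(-1, -1, num_nodes, -1)  # (B, N, N, F_out)
        h_j = h.unsqueeze(1).expand(-1, num_nodes, -1, -1)  # (B, N, N, F_out)
        e = self.leaky_relu(self.a(torch.cat([h_i, h_j], dim=-1))).squeeze(-1)
        # keep only edges present in adj; -9e15 acts as -inf but avoids NaN
        # rows for isolated nodes (they fall back to uniform attention)
        e = e.masked_fill(adj == 0, -9e15)
        attention = torch.softmax(e, dim=-1)                # (B, N, N)
        return torch.matmul(attention, h)                   # (B, N, F_out)

In the original paper's multi-head arrangement, hidden layers concatenate their heads' outputs and the final layer averages them, which plausibly corresponds to the hidden_size=3 and hidden_size=None variants exercised in Examples #4 and #3.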
Example #6
def test_flat_model_collate(lattice_ner_demo_dataset,
                            character_pretrained_embedding_loader,
                            gaz_pretrained_embedding_loader):
    """
    测试 flat model collate
    :return:
    """
    # 仅仅取前两个作为测试
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    characters = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    character_vocabulary = Vocabulary(tokens=characters,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)
    character_vocabulary = PretrainedVocabulary(
        vocabulary=character_vocabulary,
        pretrained_word_embedding_loader=character_pretrained_embedding_loader)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    flat_vocabulary = FlatPretrainedVocabulary(
        character_pretrained_vocabulary=character_vocabulary,
        gaz_word_pretrained_vocabulary=gaz_vocabulary)

    flat_model_collate = FLATModelCollate(token_vocabulary=flat_vocabulary,
                                          gazetter=gazetter,
                                          label_vocabulary=label_vocabulary)

    model_inputs = flat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    sentence = "陈元呼吁加强国际合作推动世界经济发展"

    # gaz words matched in the sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    squeeze_gaz_words_0 = metadata_0["squeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_tokens = list(sentence) + expect_squeeze_gaz_words_0

    tokens = metadata_0["tokens"]

    ASSERT.assertListEqual(expect_tokens, tokens)

    character_pos_begin = list(range(len(sentence)))
    character_pos_end = list(range(len(sentence)))

    squeeze_gaz_words_begin = list()
    squeeze_gaz_words_end = list()

    for squeeze_gaz_word in squeeze_gaz_words_0:
        index = sentence.find(squeeze_gaz_word)

        squeeze_gaz_words_begin.append(index)
        squeeze_gaz_words_end.append(index + len(squeeze_gaz_word) - 1)

    pos_begin = model_inputs.model_inputs["pos_begin"][0]
    pos_end = model_inputs.model_inputs["pos_end"][0]

    expect_pos_begin = character_pos_begin + squeeze_gaz_words_begin
    expect_pos_begin += [0] * (pos_begin.size(0) - len(expect_pos_begin))
    expect_pos_begin = torch.tensor(expect_pos_begin)

    expect_pos_end = character_pos_end + squeeze_gaz_words_end
    expect_pos_end += [0] * (pos_end.size(0) - len(expect_pos_end))
    expect_pos_end = torch.tensor(expect_pos_end)

    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_begin, pos_begin))
    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_end, pos_end))

    expect_character_length = len(sentence)
    expect_squeeze_gaz_word_length = len(expect_squeeze_gaz_words_0)

    character_length = model_inputs.model_inputs["sequence_length"][0]
    squeeze_word_length = model_inputs.model_inputs["squeeze_gaz_word_length"][0]

    ASSERT.assertEqual(expect_character_length, character_length.item())
    ASSERT.assertEqual(expect_squeeze_gaz_word_length,
                       squeeze_word_length.item())
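
Stripped of the collate machinery, the position scheme this test reconstructs is the core of the FLAT lattice encoding: each character occupies the single-position span [i, i], and each matched gazetteer word occupies the span of the characters it covers (padding positions are filled with 0). A standalone sketch using the same sentence and an illustrative subset of the matched words:

# characters get spans [i, i]; matched words get [begin, begin + len(word) - 1]
sentence = "陈元呼吁加强国际合作推动世界经济发展"
matched_words = ["陈元", "呼吁", "国际"]  # illustrative subset

pos_begin = list(range(len(sentence)))
pos_end = list(range(len(sentence)))

for word in matched_words:
    begin = sentence.find(word)  # first occurrence, as in the test
    pos_begin.append(begin)
    pos_end.append(begin + len(word) - 1)

# e.g. "陈元" covers characters 0-1, contributing begin=0 and end=1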
Example #7
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset,
                                  gaz_pretrained_embedding_loader):
    """
    测试 bilstm gat model collate
    :return:
    """
    # 仅仅取前两个作为测试
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    bilstm_gat_model_collate = BiLstmGATModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = bilstm_gat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    t_graph_0 = model_inputs.model_inputs["t_graph"][0]
    c_graph_0 = model_inputs.model_inputs["c_graph"][0]
    l_graph_0 = model_inputs.model_inputs["l_graph"][0]
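    # expect_t_graph, expect_c_graph and expect_l_graph are the expected
    # adjacency matrices; they are presumably defined at module level and
    # are omitted from this excerpt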

    expect_t_graph_tensor = torch.tensor(expect_t_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_t_graph_tensor, t_graph_0))

    expect_c_graph_tensor = torch.tensor(expect_c_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_c_graph_tensor, c_graph_0))

    expect_l_graph_tensor = torch.tensor(expect_l_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_l_graph_tensor, l_graph_0))

    gaz_words_indices = model_inputs.model_inputs["gaz_words"]

    ASSERT.assertEqual((2, 11), gaz_words_indices.size())

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # gaz words matched in the sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    # note: the metadata key keeps the library's "sequeeze" spelling
    squeeze_gaz_words_0 = metadata_0["sequeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_squeeze_gaz_words_indices_0 = torch.tensor(
        [gaz_vocabulary.index(word) for word in expect_squeeze_gaz_words_0],
        dtype=torch.long)

    ASSERT.assertTrue(
        is_tensor_equal(expect_squeeze_gaz_words_indices_0,
                        gaz_words_indices[0]))
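
Both this test and Example #6 expect the squeezed gaz words to list every lexicon word found in the sentence, ordered by start position. A plausible sketch of that matching step (hypothetical; the real Gazetteer is likely trie-based over the pretrained gaz embedding vocabulary):

from typing import List, Set

def enumerate_gaz_words(sentence: str, lexicon: Set[str], max_word_len: int = 4) -> List[str]:
    # scan left to right; at each start position try every candidate span of
    # length >= 2 and keep the ones present in the lexicon
    words = []
    for begin in range(len(sentence)):
        for end in range(begin + 2, min(begin + max_word_len, len(sentence)) + 1):
            candidate = sentence[begin:end]
            if candidate in lexicon:
                words.append(candidate)
    return words

# with a lexicon containing exactly the eleven two-character words above,
# enumerate_gaz_words("陈元呼吁加强国际合作推动世界经济发展", lexicon)
# reproduces expect_squeeze_gaz_words_0 in order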