Beispiel #1
0
    def test_load_regularization(self):
        """Load per-type regularization weights from a CSV and check the row tensor.

        The CSV written here holds weights only for type ids 0, 1, 5, 6, 7 and 8;
        the expected tensor is 0.0 in every row that receives no weight from it.
        """
        csv_rows = [
            ("typeid", "regularization"),
            (0, 0.4),
            (1, 0.8),
            (5, 0.2),
            (6, 0.1),
            (7, 0.6),
            (8, 0.7),
        ]
        with open(self.regularization_csv, "w") as out_f:
            out_f.writelines(
                ",".join(str(cell) for cell in row) + "\n" for row in csv_rows
            )

        # Type id t (0..8) maps to row t + 1 in an 11-row table; the name suggests
        # the extra rows are pad/unk slots (not verifiable from here).
        num_types_with_pad_and_unk = 11
        type2row_dict = {typeid: typeid + 1 for typeid in range(9)}
        typeid2reg = TypeEmb.load_regularization_mapping(
            self.args.data_config,
            num_types_with_pad_and_unk,
            type2row_dict,
            self.regularization_csv,
        )

        # Row r carries the weight of the type id mapped to it; rows 0, 3, 4, 5
        # and 10 have no CSV entry and are expected to stay 0.0.
        expected_reg = torch.tensor(
            [0.0, 0.4, 0.8, 0.0, 0.0, 0.0, 0.2, 0.1, 0.6, 0.7, 0.0]
        )
        assert torch.equal(expected_reg.float(), typeid2reg.float())
Beispiel #2
0
    def test_build_type_table_too_many_types(self):
        """Check that build_type_table keeps at most ``max_types`` types per entity.

        Three of the four entities list three types and Q3 lists none; with
        ``max_types=1`` the expected table keeps only the first type of each.
        """
        type_data = {"Q1": [1, 2, 3], "Q2": [4, 5, 6], "Q3": [], "Q4": [7, 8, 9]}
        # Vocabulary "T1".."T9" -> 1..9, in that insertion order.
        type_vocab = {f"T{idx}": idx for idx in range(1, 10)}
        utils.dump_json_file(self.type_file, type_data)
        utils.dump_json_file(self.type_vocab_file, type_vocab)

        pred_type_table, type2row, max_labels = TypeEmb.build_type_table(
            self.type_file,
            self.type_vocab_file,
            max_types=1,
            entity_symbols=self.entity_symbols,
        )

        # Rows with no types are zero-filled; the row order presumably follows
        # entity_symbols (not visible here).
        expected_table = torch.tensor([[0], [1], [4], [0], [7], [0]]).long()
        expected_type2row = {idx: idx for idx in range(1, 10)}
        assert torch.equal(pred_type_table, expected_table)
        self.assertDictEqual(expected_type2row, type2row)
        # 9 real types plus unk and pad -> type indices run up to 10.
        assert max_labels == 10
Beispiel #3
0
 def test_build_type_table_too_many_types(self):
     """Check that build_type_table truncates each entity to ``max_types`` types.

     With ``max_types=1`` each non-empty row of the expected table holds a single
     type row index (presumably the first type of each entity in
     ``self.type_labels`` -- confirm against the fixture). Type ids 0..8 map to
     rows 1..9.
     """
     true_type_table = torch.tensor([[0], [1], [4], [0], [7], [0]]).long()
     true_type2row = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9}
     pred_type_table, type2row, max_labels = TypeEmb.build_type_table(
         self.type_labels, max_types=1, entity_symbols=self.entity_symbols)
     # torch.equal requires identical shape and values; on failure the assertion
     # itself reports the mismatch, so no debug print is needed.
     assert torch.equal(pred_type_table, true_type_table)
     self.assertDictEqual(true_type2row, type2row)
     # there are 9 real types so we expect (including unk and pad) there to be type indices up to 10
     assert max_labels == 10
Beispiel #4
0
 def test_build_type_table(self):
     """Build the type table with ``max_types=3`` and compare it to the expected table.

     Rows for entities without types are expected to be all zeros; the three
     populated rows hold type row indices 1..9.
     """
     pred_type_table, type2row, max_labels = TypeEmb.build_type_table(
         self.type_labels, max_types=3, entity_symbols=self.entity_symbols
     )

     zero_row = [0, 0, 0]
     expected_table = torch.tensor(
         [zero_row, [1, 2, 3], [4, 5, 6], zero_row, [7, 8, 9], zero_row]
     ).long()
     expected_type2row = np.arange(1, 10)

     assert torch.equal(pred_type_table, expected_table)
     np.testing.assert_array_equal(expected_type2row, type2row)
     # 9 real types plus unk and pad -> type indices run up to 10.
     assert max_labels == 10