def test_load_regularization(self):
    """Check the tensor built from a ``typeid,regularization`` CSV file.

    A small CSV is written to ``self.regularization_csv`` and handed to
    ``TypeEmb.load_regularization_mapping`` together with a type-id -> row
    mapping. The expected result has 11 entries: each value from the CSV
    sits at index ``row_of_type[typeid]`` and every other index is 0.0.
    """
    header = ["typeid", "regularization"]
    weights_by_typeid = [[0, 0.4], [1, 0.8], [5, 0.2], [6, 0.1], [7, 0.6], [8, 0.7]]
    csv_lines = [
        ",".join(str(cell) for cell in row) + "\n"
        for row in [header, *weights_by_typeid]
    ]
    with open(self.regularization_csv, "w") as out_f:
        out_f.writelines(csv_lines)

    # Table size counting the pad and unk slots in addition to the types.
    total_rows = 11
    # Type ids 0..8 map to rows 1..9.
    row_of_type = {type_id: type_id + 1 for type_id in range(9)}

    actual = TypeEmb.load_regularization_mapping(
        self.args.data_config,
        total_rows,
        row_of_type,
        self.regularization_csv,
    )

    # Type ids 2, 3 and 4 are absent from the CSV, so rows 3-5 stay at 0.0,
    # as do the first and last rows.
    expected = torch.tensor(
        [0.0, 0.4, 0.8, 0.0, 0.0, 0.0, 0.2, 0.1, 0.6, 0.7, 0.0]
    )
    assert torch.equal(expected.float(), actual.float())
def test_build_type_table_too_many_types(self):
    """Check ``TypeEmb.build_type_table`` when entities list more types than ``max_types``.

    Type assignments and the type vocabulary are dumped to ``self.type_file``
    and ``self.type_vocab_file``; the table is then built with ``max_types=1``.
    Each expected row therefore holds a single id: 1, 4 and 7 are the first
    ids listed for Q1, Q2 and Q4, and all remaining rows are 0.
    """
    # NOTE(review): another method with this exact name appears later in the
    # visible source. If both live in the same class, the later definition
    # replaces this one and this test never runs -- confirm and rename one.
    entity_types = {"Q1": [1, 2, 3], "Q2": [4, 5, 6], "Q3": [], "Q4": [7, 8, 9]}
    vocab = {f"T{idx}": idx for idx in range(1, 10)}  # "T1" -> 1 ... "T9" -> 9
    utils.dump_json_file(self.type_file, entity_types)
    utils.dump_json_file(self.type_vocab_file, vocab)

    expected_table = torch.tensor([[0], [1], [4], [0], [7], [0]]).long()
    expected_type2row = {idx: idx for idx in range(1, 10)}

    actual_table, actual_type2row, max_labels = TypeEmb.build_type_table(
        self.type_file,
        self.type_vocab_file,
        max_types=1,
        entity_symbols=self.entity_symbols,
    )

    assert torch.equal(actual_table, expected_table)
    self.assertDictEqual(expected_type2row, actual_type2row)
    # there are 9 real types so we expect (including unk and pad) there to be type indices up to 10
    assert max_labels == 10
def test_build_type_table_too_many_types(self):
    """Check ``TypeEmb.build_type_table`` on ``self.type_labels`` with ``max_types=1``.

    Every row of the expected table holds a single type id, and the expected
    type-to-row mapping sends ids 0..8 to rows 1..9.
    """
    # NOTE(review): an earlier method in the visible source has this exact
    # name. If both live in the same class, only this later definition is
    # kept -- confirm and rename one of them.
    true_type_table = torch.tensor([[0], [1], [4], [0], [7], [0]]).long()
    true_type2row = {type_id: type_id + 1 for type_id in range(9)}

    pred_type_table, type2row, max_labels = TypeEmb.build_type_table(
        self.type_labels, max_types=1, entity_symbols=self.entity_symbols
    )

    # Report both tensors only on failure instead of printing on every run.
    assert torch.equal(pred_type_table, true_type_table), (
        f"expected {true_type_table}, got {pred_type_table}"
    )
    self.assertDictEqual(true_type2row, type2row)
    # there are 9 real types so we expect (including unk and pad) there to be type indices up to 10
    assert max_labels == 10
def test_build_type_table(self):
    """Check ``TypeEmb.build_type_table`` on ``self.type_labels`` with ``max_types=3``.

    The expected table has six rows of three type ids each: three rows carry
    the ids 1-3, 4-6 and 7-9, and the other three rows are all zeros. The
    expected type-to-row array is ``[1, 2, ..., 9]``.
    """
    no_types = [0, 0, 0]
    expected_table = torch.tensor(
        [
            no_types,
            [1, 2, 3],
            [4, 5, 6],
            no_types,
            [7, 8, 9],
            no_types,
        ],
        dtype=torch.long,
    )
    expected_type2row = np.arange(1, 10)

    actual_table, actual_type2row, max_labels = TypeEmb.build_type_table(
        self.type_labels,
        max_types=3,
        entity_symbols=self.entity_symbols,
    )

    assert torch.equal(actual_table, expected_table)
    np.testing.assert_array_equal(expected_type2row, actual_type2row)
    # there are 9 real types so we expect (including unk and pad) there to be type indices up to 10
    assert max_labels == 10