class EntitySymbolTest(unittest.TestCase):
    """Tests for ``EntitySymbols`` loading and its lookup helpers."""

    def setUp(self):
        """Load the fixture entity dump used by the file-backed tests."""
        entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
        self.entity_symbols = EntitySymbols(
            load_dir=entity_dump_dir, alias_cand_map_file="alias2qids.json"
        )

    def test_load_entites_keep_noncandidate(self):
        """Loading the fixture dump yields the expected alias/title/eid maps."""
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        trueqid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }
        # The non-candidate entity is included in the entity dump, so it
        # still receives an eid alongside the others.
        trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}

        self.assertEqual(self.entity_symbols.max_candidates, 3)
        self.assertEqual(self.entity_symbols.max_alias_len, 3)
        self.assertDictEqual(self.entity_symbols._alias2qids, truealias2qids)
        self.assertDictEqual(self.entity_symbols._qid2title, trueqid2title)
        self.assertDictEqual(self.entity_symbols._qid2eid, trueqid2eid)

    # NOTE: this method was previously also named ``test_getters``; the later
    # definition of the same name rebound it in the class body, so these
    # assertions never ran. It is renamed so both getter tests execute.
    def test_getters_from_dump(self):
        """Getters return the expected values on symbols loaded from disk."""
        self.assertEqual(self.entity_symbols.get_qid(1), "Q1")
        self.assertSetEqual(
            set(self.entity_symbols.get_all_aliases()),
            {"alias1", "multi word alias2", "alias3", "alias4"},
        )
        self.assertEqual(self.entity_symbols.get_eid("Q3"), 3)
        self.assertListEqual(
            self.entity_symbols.get_qid_cands("alias1"), ["Q1", "Q4"]
        )
        # With max_cand_pad=True the candidate list is padded with "-1"/-1
        # up to max_candidates (3).
        self.assertListEqual(
            self.entity_symbols.get_qid_cands("alias1", max_cand_pad=True),
            ["Q1", "Q4", "-1"],
        )
        self.assertListEqual(
            self.entity_symbols.get_eid_cands("alias1", max_cand_pad=True),
            [1, 4, -1],
        )
        self.assertEqual(self.entity_symbols.get_title("Q1"), "alias1")

    def test_getters(self):
        """Getters return the expected values on symbols built from dicts."""
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        trueqid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }
        entity_symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=truealias2qids,
            qid2title=trueqid2title,
        )

        self.assertEqual(entity_symbols.get_qid(1), "Q1")
        self.assertSetEqual(
            set(entity_symbols.get_all_aliases()),
            {"alias1", "multi word alias2", "alias3", "alias4"},
        )
        self.assertEqual(entity_symbols.get_eid("Q3"), 3)
        self.assertListEqual(entity_symbols.get_qid_cands("alias1"), ["Q1", "Q4"])
        self.assertListEqual(
            entity_symbols.get_qid_cands("alias1", max_cand_pad=True),
            ["Q1", "Q4", "-1"],
        )
        self.assertListEqual(
            entity_symbols.get_eid_cands("alias1", max_cand_pad=True), [1, 4, -1]
        )
        self.assertEqual(entity_symbols.get_title("Q1"), "alias1")
        self.assertEqual(entity_symbols.get_alias_idx("alias1"), 0)
        # Index 1 maps to "alias3", not the insertion-order "multi word
        # alias2" — presumably aliases are indexed in sorted order; confirm
        # against EntitySymbols.
        self.assertEqual(entity_symbols.get_alias_from_idx(1), "alias3")
        self.assertEqual(entity_symbols.alias_exists("alias3"), True)
        self.assertEqual(entity_symbols.alias_exists("alias5"), False)