コード例 #1
0
ファイル: test_entity.py プロジェクト: syyunn/bootleg
class EntitySymbolTest(unittest.TestCase):
    """Tests for an EntitySymbols instance loaded from the test mapping directory."""

    def setUp(self):
        """Create the EntitySymbols fixture used by every test in this class."""
        mappings_dir = "test/data/entity_loader/entity_data/entity_mappings"
        self.entity_symbols = EntitySymbols(
            load_dir=mappings_dir,
            alias_cand_map_file="alias2qids.json",
        )

    def test_load_entites_keep_noncandidate(self):
        """Compare the loaded size limits and lookup tables with expected values."""
        symbols = self.entity_symbols

        # alias -> list of [qid, score] candidate pairs
        expected_alias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        # qid -> title
        expected_qid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }

        # qid -> eid; the non-candidate class is included in entity_dump
        expected_qid2eid = {
            "Q1": 1,
            "Q2": 2,
            "Q3": 3,
            "Q4": 4,
        }

        # Size limits first, then the three private lookup tables.
        self.assertEqual(symbols.max_candidates, 3)
        self.assertEqual(symbols.max_alias_len, 3)
        self.assertDictEqual(symbols._alias2qids, expected_alias2qids)
        self.assertDictEqual(symbols._qid2title, expected_qid2title)
        self.assertDictEqual(symbols._qid2eid, expected_qid2eid)

    def test_getters(self):
        """Check the public accessor methods against the loaded fixture."""
        symbols = self.entity_symbols

        self.assertEqual(symbols.get_qid(1), "Q1")

        # Alias ordering is not asserted, so compare as sets.
        known_aliases = set(symbols.get_all_aliases())
        self.assertSetEqual(
            known_aliases,
            {"alias1", "multi word alias2", "alias3", "alias4"},
        )

        self.assertEqual(symbols.get_eid("Q3"), 3)

        # Candidate lookups for "alias1", without and with max_cand_pad=True.
        self.assertListEqual(symbols.get_qid_cands("alias1"), ["Q1", "Q4"])
        padded_qids = symbols.get_qid_cands("alias1", max_cand_pad=True)
        self.assertListEqual(padded_qids, ["Q1", "Q4", "-1"])
        padded_eids = symbols.get_eid_cands("alias1", max_cand_pad=True)
        self.assertListEqual(padded_eids, [1, 4, -1])

        self.assertEqual(symbols.get_title("Q1"), "alias1")
コード例 #2
0
ファイル: test_entity.py プロジェクト: lorr1/bootleg
    def test_getters(self):
        """Exercise the EntitySymbols accessors on a table built from in-memory dicts."""
        # alias -> list of [qid, score] candidate pairs
        alias_candidates = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        # qid -> title
        qid_titles = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }

        symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=alias_candidates,
            qid2title=qid_titles,
        )

        self.assertEqual(symbols.get_qid(1), "Q1")

        # Alias ordering is not asserted here, so compare as sets.
        known_aliases = set(symbols.get_all_aliases())
        self.assertSetEqual(
            known_aliases,
            {"alias1", "multi word alias2", "alias3", "alias4"},
        )

        self.assertEqual(symbols.get_eid("Q3"), 3)

        # Candidate lookups for "alias1", without and with max_cand_pad=True.
        plain_qids = symbols.get_qid_cands("alias1")
        self.assertListEqual(plain_qids, ["Q1", "Q4"])
        padded_qids = symbols.get_qid_cands("alias1", max_cand_pad=True)
        self.assertListEqual(padded_qids, ["Q1", "Q4", "-1"])
        padded_eids = symbols.get_eid_cands("alias1", max_cand_pad=True)
        self.assertListEqual(padded_eids, [1, 4, -1])

        self.assertEqual(symbols.get_title("Q1"), "alias1")

        # Alias <-> index mapping and membership checks.
        self.assertEqual(symbols.get_alias_idx("alias1"), 0)
        self.assertEqual(symbols.get_alias_from_idx(1), "alias3")
        self.assertEqual(symbols.alias_exists("alias3"), True)
        self.assertEqual(symbols.alias_exists("alias5"), False)