コード例 #1
0
def _build_oldeid2topkeid(entity_db, qid2topk_eid, new_num_topk_entities):
    """Build the tensor that maps every old entity id to its id in the filtered table.

    Starts from the identity mapping over
    ``entity_db.num_entities_with_pad_and_nocand`` ids, forces the last entry to
    the last row of the new table, then overwrites the entry of each qid in
    ``qid2topk_eid`` with its new eid.

    Args:
        entity_db: EntitySymbols used to look up each qid's old eid.
        qid2topk_eid: mapping of qid -> eid in the filtered embedding table
            (as returned by ``filter_qids``).
        new_num_topk_entities: entity count returned by ``filter_qids``.

    Returns:
        A 1-D ``torch.arange``-derived tensor indexed by old eid.
    """
    oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
    # The +2 accounts for the pad and unk rows. The -1 is because indexing the
    # entity embeddings with -1 is problematic, so the last old entry is mapped
    # explicitly to the last row of the new table.
    last_row = new_num_topk_entities + 2 - 1
    oldeid2topkeid[-1] = last_row
    for qid, new_eid in qid2topk_eid.items():
        oldeid2topkeid[entity_db.get_eid(qid)] = new_eid
    # Sanity checks: no qid may have been remapped onto the first or last slot.
    assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
    assert oldeid2topkeid[-1] == last_row, "The -1 eid should still map to the last row"
    return oldeid2topkeid


def main():
    """Filter a checkpoint's entity embeddings and save the result.

    Loads the checkpoint at ``args.init_checkpoint`` and the entity symbols from
    ``<entity_dir>/<entity_map_dir>``. It then selects entities with
    ``filter_qids`` (driven by ``args.perc_emb_drop`` and the counts in
    ``args.qid2count``) and filters the embeddings with ``filter_embs``. An
    old-eid -> new-eid tensor is stored under ``ENTITY_EID_KEY``. Finally it
    writes the checkpoint to ``args.save_checkpoint`` and the qid -> new-eid map
    to ``<entity_dir>/<entity_map_dir>/<save_qid2topk_file>``.

    Raises:
        ValueError: if ``perc_emb_drop`` is not strictly between 0 and 1, or if
            the loaded model state dict has no ``ENTITY_EMB_KEY`` entry.
    """
    args = parse_args()
    # argparse returns a Namespace, which json cannot serialize directly; keep
    # dict-like args unchanged so either form prints.
    print(json.dumps(args if isinstance(args, dict) else vars(args), indent=4))
    # Validate with an explicit raise: assert statements are stripped under -O.
    if not 0 < args.perc_emb_drop < 1:
        raise ValueError(f"perc_emb_drop must be between 0 and 1, got {args.perc_emb_drop}")
    state_dict, model_state_dict = load_statedict(args.init_checkpoint)
    if ENTITY_EMB_KEY not in model_state_dict:
        raise ValueError(f"{ENTITY_EMB_KEY} not found in the model state dict of {args.init_checkpoint}")

    entity_map_path = os.path.join(args.entity_dir, args.entity_map_dir)
    print(f"Loading entity symbols from {entity_map_path}")
    entity_db = EntitySymbols(entity_map_path, args.alias_cand_map_file)
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)

    print("Filtering qids")
    qid2topk_eid, old2new_eid, toes_eid, new_num_topk_entities = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    print("Filtering embeddings")
    model_state_dict = filter_embs(new_num_topk_entities, entity_db, old2new_eid, qid2topk_eid, toes_eid, model_state_dict)

    # Store the old -> new eid remapping alongside the filtered weights.
    model_state_dict[ENTITY_EID_KEY] = _build_oldeid2topkeid(entity_db, qid2topk_eid, new_num_topk_entities)
    state_dict["model"] = model_state_dict
    print(f"Saving model at {args.save_checkpoint}")
    torch.save(state_dict, args.save_checkpoint)

    qid2topk_path = os.path.join(entity_map_path, args.save_qid2topk_file)
    print(f"Saving qid2topk_eid map at {qid2topk_path}")
    utils.dump_json_file(qid2topk_path, qid2topk_eid)
コード例 #2
0
ファイル: test_entity.py プロジェクト: syyunn/bootleg
class EntitySymbolTest(unittest.TestCase):
    """Checks EntitySymbols loaded from the on-disk test entity mappings."""

    def setUp(self):
        # Fixture directory (relative to the working directory) with the dumped mappings.
        mappings_dir = "test/data/entity_loader/entity_data/entity_mappings"
        self.entity_symbols = EntitySymbols(
            load_dir=mappings_dir,
            alias_cand_map_file="alias2qids.json",
        )

    def test_load_entites_keep_noncandidate(self):
        """Loaded sizes and internal mappings match the fixture contents."""
        expected_alias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        expected_qid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }
        # Expected eids start at 1 (Q1 -> 1 ... Q4 -> 4); per the original note,
        # the non-candidate class is included in entity_dump.
        expected_qid2eid = {f"Q{eid}": eid for eid in range(1, 5)}

        symbols = self.entity_symbols
        self.assertEqual(symbols.max_candidates, 3)
        self.assertEqual(symbols.max_alias_len, 3)
        self.assertDictEqual(symbols._alias2qids, expected_alias2qids)
        self.assertDictEqual(symbols._qid2title, expected_qid2title)
        self.assertDictEqual(symbols._qid2eid, expected_qid2eid)

    def test_getters(self):
        """Each public getter returns the value expected for the fixture data."""
        symbols = self.entity_symbols
        self.assertEqual(symbols.get_qid(1), "Q1")
        self.assertSetEqual(
            set(symbols.get_all_aliases()),
            {"alias1", "multi word alias2", "alias3", "alias4"},
        )
        self.assertEqual(symbols.get_eid("Q3"), 3)
        self.assertListEqual(symbols.get_qid_cands("alias1"), ["Q1", "Q4"])
        # With max_cand_pad=True the candidate list is padded with -1 / "-1".
        self.assertListEqual(
            symbols.get_qid_cands("alias1", max_cand_pad=True),
            ["Q1", "Q4", "-1"],
        )
        self.assertListEqual(
            symbols.get_eid_cands("alias1", max_cand_pad=True),
            [1, 4, -1],
        )
        self.assertEqual(symbols.get_title("Q1"), "alias1")
コード例 #3
0
ファイル: test_entity.py プロジェクト: lorr1/bootleg
    def test_getters(self):
        """Check every EntitySymbols getter against an in-memory alias map.

        The EntitySymbols instance is built directly from ``alias2qids`` and
        ``qid2title`` dicts (no files on disk). Each lookup method is then
        compared with the value expected for this data.
        """
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        trueqid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }

        entity_symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=truealias2qids,
            qid2title=trueqid2title,
        )

        self.assertEqual(entity_symbols.get_qid(1), "Q1")
        self.assertSetEqual(
            set(entity_symbols.get_all_aliases()),
            {"alias1", "multi word alias2", "alias3", "alias4"},
        )
        self.assertEqual(entity_symbols.get_eid("Q3"), 3)
        self.assertListEqual(entity_symbols.get_qid_cands("alias1"), ["Q1", "Q4"])
        # With max_cand_pad=True the candidate list is padded with -1 / "-1".
        self.assertListEqual(
            entity_symbols.get_qid_cands("alias1", max_cand_pad=True),
            ["Q1", "Q4", "-1"],
        )
        self.assertListEqual(
            entity_symbols.get_eid_cands("alias1", max_cand_pad=True), [1, 4, -1]
        )
        self.assertEqual(entity_symbols.get_title("Q1"), "alias1")
        self.assertEqual(entity_symbols.get_alias_idx("alias1"), 0)
        self.assertEqual(entity_symbols.get_alias_from_idx(1), "alias3")
        # Boolean checks use the dedicated unittest assertions.
        self.assertTrue(entity_symbols.alias_exists("alias3"))
        self.assertFalse(entity_symbols.alias_exists("alias5"))