예제 #1
0
def main():
    """Build and save ``EntitySymbols`` from an alias candidate map and titles.

    Reads the JSON files ``args.alias2qids`` (alias -> list of candidates,
    each indexable as ``cand[0]`` / ``cand[1]``; presumably ``[qid, score]``
    pairs -- confirm against the producer) and ``args.qid2title``. Keeps at
    most ``args.max_candidates`` candidates per alias, then saves the
    resulting entity mappings under ``<entity_dir>/<entity_map_dir>``.

    Raises:
        ValueError: if any alias in the candidate map is not lowercase.
    """
    args = parse_args()

    # JSON is UTF-8 by spec; decode explicitly rather than relying on the
    # platform default encoding, which breaks non-ASCII aliases/titles.
    with open(args.alias2qids, encoding="utf-8") as f:
        alias2qids = ujson.load(f)

    with open(args.qid2title, encoding="utf-8") as f:
        qid2title = ujson.load(f)

    for alias in alias2qids:
        # Raise instead of assert so the check also runs under ``python -O``.
        if alias.lower() != alias:
            raise ValueError(
                f"bootleg assumes lowercase aliases in alias candidate maps: {alias}"
            )
        # Keep only the top max_candidates: ordered by the second element
        # descending, ties broken by the first element (also descending).
        qids = sorted(alias2qids[alias], key=lambda x: (x[1], x[0]), reverse=True)
        alias2qids[alias] = qids[: args.max_candidates]

    entity_mappings = EntitySymbols(
        max_candidates=args.max_candidates,
        alias2qids=alias2qids,
        qid2title=qid2title,
        alias_cand_map_file=args.alias_cand_map_file,
    )

    entity_mappings.save(os.path.join(args.entity_dir, args.entity_map_dir))
    print("entity mappings exported.")
예제 #2
0
    def test_create_entities(self):
        """Check EntitySymbols construction, save/load round-trip, and edit mode."""
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        trueqid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }

        # the non-candidate class is included in entity_dump (hence the qid
        # eids below start at 1)
        trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        truealiastrie = {"multi word alias2": 0, "alias1": 1, "alias3": 2, "alias4": 3}

        def trie_as_dict(symbols):
            # Materialize the alias trie so assertDictEqual can compare it.
            return {k: symbols._alias_trie[k] for k in symbols._alias_trie}

        def assert_matches_truth(symbols):
            # Shared checks for both a freshly built and a reloaded object.
            self.assertEqual(symbols.max_candidates, 3)
            self.assertEqual(symbols.max_eid, 4)
            self.assertEqual(symbols.max_alid, 3)
            self.assertDictEqual(symbols._alias2qids, truealias2qids)
            self.assertDictEqual(symbols._qid2title, trueqid2title)
            self.assertDictEqual(symbols._qid2eid, trueqid2eid)
            # Recomputed from *this* object, so the reloaded trie is really checked.
            self.assertDictEqual(trie_as_dict(symbols), truealiastrie)
            self.assertDictEqual(symbols._alias2id, truealias2id)
            self.assertIsNone(symbols._qid2aliases)

        entity_symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=truealias2qids,
            qid2title=trueqid2title,
        )
        assert_matches_truth(entity_symbols)

        # Test load from dump
        temp_save_dir = "test/data/entity_loader_test"
        try:
            entity_symbols.save(temp_save_dir)
            entity_symbols = EntitySymbols.load_from_cache(temp_save_dir)
            assert_matches_truth(entity_symbols)
        finally:
            # Clean up the dump even when an assertion above fails.
            shutil.rmtree(temp_save_dir, ignore_errors=True)

        # Test edit mode
        entity_symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=truealias2qids,
            qid2title=trueqid2title,
            edit_mode=True,
        )
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias4"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }

        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
예제 #3
0
def _is_same_or_ancestor(path, other):
    """Return True if ``path`` is ``other`` or one of its parent directories."""
    path = os.path.abspath(path)
    other = os.path.abspath(other)
    try:
        return os.path.commonpath([path, other]) == path
    except ValueError:  # e.g. paths on different drives on Windows
        return False


def _find_cand_file(tag, cand_files):
    """Return the candidate file paired with data-file ``tag``, or None.

    A file whose stem equals ``tag`` wins; otherwise the first file (in sorted
    order, for determinism) whose *basename* contains ``tag`` is used. Only
    the basename is inspected so a tag appearing in the candidate directory
    name cannot make every file match.
    """
    ordered = sorted(cand_files)
    for cand in ordered:
        if os.path.splitext(os.path.basename(cand))[0] == tag:
            return cand
    for cand in ordered:
        if tag in os.path.basename(cand):
            return cand
    return None


def _pair_input_files(train_files, cand_files, out_dir):
    """Build ``[data_file, cand_file, out_file, is_train]`` entries for merge_data.

    Raises:
        ValueError: if a data file has no matching candidate file.
    """
    in_files = []
    for file in train_files:
        file_name = os.path.basename(file)
        tag = os.path.splitext(file_name)[0]
        is_train = "train" in tag
        if is_train:
            print(f"{file_name} is a training dataset...will be processed as such")
        cand_file = _find_cand_file(tag, cand_files)
        if cand_file is None:
            raise ValueError(
                f"No contextual candidate file found for {file_name} (tag {tag})"
            )
        in_files.append([file, cand_file, os.path.join(out_dir, file_name), is_train])
    return in_files


def main():
    """Merge contextual candidate maps into each dataset and rebuild entity symbols.

    Every ``*.jsonl`` file in ``args.data_dir`` is paired with a candidate file
    from ``args.contextual_cand_data``. Each pair is merged by ``merge_data`` into
    ``<data_dir>/<out_subdir>/``. The union of the returned alias-to-candidate
    maps is saved as a new ``EntitySymbols`` under
    ``<out_dir>/entity_db/entity_mappings``.

    Raises:
        ValueError: if the output directory is the data directory or one of its
            ancestors, if no data or candidate files are found, if a data file
            has no candidate file, or if an alias occurs in more than one map.
    """
    gl_start = time.time()
    multiprocessing.set_start_method("spawn")
    args = get_arg_parser().parse_args()
    print(json.dumps(vars(args), indent=4))
    utils.ensure_dir(args.data_dir)

    out_dir = os.path.join(args.data_dir, args.out_subdir)
    # out_dir is wiped below; refuse paths that would delete the input data
    # (e.g. out_subdir of "", "." or "..").
    if _is_same_or_ancestor(out_dir, args.data_dir):
        raise ValueError(
            f"Output dir {out_dir} would overwrite the data dir {args.data_dir}"
        )
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # Reading in files
    in_files_train = glob.glob(os.path.join(args.data_dir, "*.jsonl"))
    in_files_cand = glob.glob(os.path.join(args.contextual_cand_data, "*.jsonl"))
    if not in_files_train:
        raise ValueError(f"We didn't find any train files at {args.data_dir}")
    if not in_files_cand:
        raise ValueError(
            f"We didn't find any contextual files at {args.contextual_cand_data}"
        )
    in_files = _pair_input_files(in_files_train, in_files_cand, out_dir)

    final_cand_map = {}
    max_cands = 0
    for pair in in_files:
        print(
            f"Reading in {pair[0]} with cand maps {pair[1]} and dumping to {pair[2]}"
        )
        new_alias2qids = merge_data(
            args.processes,
            args.train_in_candidates,
            args.keep_orig,
            args.max_candidates,
            pair,
            args.entity_dump,
        )
        for al in new_alias2qids:
            if al in final_cand_map:
                raise ValueError(f"{al} is already in final_cand_map")
            final_cand_map[al] = new_alias2qids[al]
            max_cands = max(max_cands, len(final_cand_map[al]))

    print("Building new entity symbols")
    entity_dump = EntitySymbols.load_from_cache(load_dir=args.entity_dump)
    entity_dump_new = EntitySymbols(
        max_candidates=max_cands,
        alias2qids=final_cand_map,
        qid2title=entity_dump.get_qid2title(),
    )
    entity_out_dir = os.path.join(out_dir, "entity_db", "entity_mappings")
    entity_dump_new.save(entity_out_dir)
    print(f"Finished in {time.time() - gl_start}s")