Example #1
    def test_construction_single(self):
        entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
        entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
            alias_cand_map_file="alias2qids.json")
        fmt_types = {"trie1": "qid_cand_with_score"}
        max_values = {"trie1": 3}
        input_dicts = {"trie1": entity_symbols.get_alias2qids()}
        record_trie = RecordTrieCollection(load_dir=None, input_dicts=input_dicts,
                                           vocabulary=entity_symbols.get_qid2eid(),
                                           fmt_types=fmt_types, max_values=max_values)
        truealias2qids = {
                        'alias1': [["Q1", 10], ["Q4", 6]],
                        'multi word alias2': [["Q2", 5], ["Q1", 3], ["Q4", 2]],
                        'alias3': [["Q1", 30]],
                        'alias4': [["Q4", 20], ["Q3", 15], ["Q2", 1]]
                        }

        for al in truealias2qids:
            self.assertEqual(truealias2qids[al], record_trie.get_value("trie1", al))

        for al in truealias2qids:
            self.assertEqual([x[0] for x in truealias2qids[al]], record_trie.get_value("trie1", al, getter=lambda x: x[0]))

        for al in truealias2qids:
            self.assertEqual([x[1] for x in truealias2qids[al]], record_trie.get_value("trie1", al, getter=lambda x: x[1]))

        self.assertEqual(set(truealias2qids.keys()), set(record_trie.get_keys("trie1")))
Example #2
def main():
    args = parse_args()
    print(json.dumps(vars(args), indent=4))
    assert 0 < args.perc_emb_drop < 1, "perc_emb_drop must be between 0 and 1"
    state_dict, model_state_dict = load_statedict(args.init_checkpoint)
    assert ENTITY_EMB_KEY in model_state_dict
    print(f"Loading entity symbols from {os.path.join(args.entity_dir, args.entity_map_dir)}")
    entity_db = EntitySymbols(os.path.join(args.entity_dir, args.entity_map_dir), args.alias_cand_map_file)
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print(f"Filtering qids")
    qid2topk_eid, old2new_eid, toes_eid, new_num_topk_entities = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    print(f"Filtering embeddings")
    model_state_dict = filter_embs(new_num_topk_entities, entity_db, old2new_eid, qid2topk_eid, toes_eid, model_state_dict)
    # Generate the new old2new_eid weight vector to save in model_state_dict
    oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
    # The +2 is to account for pads and unks. The -1 is as there are issues with -1 in the indexing for entity embeddings. So we must manually make it the last entry
    oldeid2topkeid[-1] = new_num_topk_entities+2-1
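    # Worked example with hypothetical numbers: if new_num_topk_entities == 2, the
    # compressed embedding has 2 + 2 == 4 rows (the extra two being the no-candidate
    # and pad rows), so the old pad eid of -1 must point at row 2 + 2 - 1 == 3.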
    for qid, new_eid in qid2topk_eid.items():
        old_eid = entity_db.get_eid(qid)
        oldeid2topkeid[old_eid] = new_eid
    assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
    assert oldeid2topkeid[-1] == new_num_topk_entities+2-1, "The -1 eid should still map to the last row"
    model_state_dict[ENTITY_EID_KEY] = oldeid2topkeid
    state_dict["model"] = model_state_dict
    print(f"Saving model at {args.save_checkpoint}")
    torch.save(state_dict, args.save_checkpoint)
    print(f"Saving entity_db at {os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file)}")
    utils.dump_json_file(os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file), qid2topk_eid)
Example #3
def main():
    args = parse_args()

    # compute the max alias len in alias2qids
    with open(args.alias2qids) as f:
        alias2qids = ujson.load(f)

    with open(args.qid2title) as f:
        qid2title = ujson.load(f)

    max_alias_len = -1
    for alias in alias2qids:
        assert (
            alias.lower() == alias
        ), f'bootleg assumes lowercase aliases in alias candidate maps: {alias}'
        # ensure only max_candidates per alias
        qids = sorted(alias2qids[alias],
                      key=lambda x: (x[1], x[0]),
                      reverse=True)
        alias2qids[alias] = qids[:args.max_candidates]
        # keep track of the maximum number of words in an alias
        max_alias_len = max(max_alias_len, len(alias.split()))

    entity_mappings = EntitySymbols(
        max_candidates=args.max_candidates,
        max_alias_len=max_alias_len,
        alias2qids=alias2qids,
        qid2title=qid2title,
        alias_cand_map_file=args.alias_cand_map_file)

    entity_mappings.dump(os.path.join(args.entity_dir, args.entity_map_dir))
    print('entity mappings exported.')
Example #4
def main():
    args = parse_args()

    # compute the max alias len in alias2qids
    with open(args.alias2qids) as f:
        alias2qids = ujson.load(f)

    with open(args.qid2title) as f:
        qid2title = ujson.load(f)

    for alias in alias2qids:
        assert (
            alias.lower() == alias
        ), f"bootleg assumes lowercase aliases in alias candidate maps: {alias}"
        # ensure only max_candidates per alias
        qids = sorted(alias2qids[alias],
                      key=lambda x: (x[1], x[0]),
                      reverse=True)
        alias2qids[alias] = qids[:args.max_candidates]

    entity_mappings = EntitySymbols(
        max_candidates=args.max_candidates,
        alias2qids=alias2qids,
        qid2title=qid2title,
        alias_cand_map_file=args.alias_cand_map_file,
    )

    entity_mappings.save(os.path.join(args.entity_dir, args.entity_map_dir))
    print("entity mappings exported.")
Example #5
    def test_reidentify_entity(self):
        alias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2], ["Q3", 1]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        qid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }
        max_candidates = 3

        entity_symbols = EntitySymbols(
            max_candidates=max_candidates,
            alias2qids=alias2qids,
            qid2title=qid2title,
            edit_mode=True,
        )
        entity_symbols.reidentify_entity("Q1", "Q7")
        trueqid2aliases = {
            "Q7": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias4"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }
        truealias2qids = {
            "alias1": [["Q7", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q7", 3], ["Q4", 2]],
            "alias3": [["Q7", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        trueqid2eid = {"Q7": 1, "Q2": 2, "Q3": 3, "Q4": 4}
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
        truemax_eid = 4
        truenum_entities = 4
        truenum_entities_with_pad_and_nocand = 6
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
        self.assertDictEqual(
            entity_symbols._eid2qid, {v: i for i, v in trueqid2eid.items()}
        )
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
        self.assertEqual(entity_symbols.max_eid, truemax_eid)
        self.assertEqual(entity_symbols.num_entities, truenum_entities)
        self.assertEqual(
            entity_symbols.num_entities_with_pad_and_nocand,
            truenum_entities_with_pad_and_nocand,
        )
Example #6
    def __init__(self,
                 config_args,
                 device='cuda',
                 max_alias_len=6,
                 cand_map=None,
                 threshold=0.0):
        self.args = config_args
        self.device = device
        self.entity_db = EntitySymbols(
            os.path.join(self.args.data_config.entity_dir,
                         self.args.data_config.entity_map_dir),
            alias_cand_map_file=self.args.data_config.alias_cand_map)
        self.word_db = data_utils.load_wordsymbols(self.args.data_config,
                                                   is_writer=True,
                                                   distributed=False)
        self.model = self._load_model()
        self.max_alias_len = max_alias_len
        if cand_map is None:
            alias_map = self.entity_db._alias2qids
        else:
            alias_map = ujson.load(open(cand_map))
        self.all_aliases_trie = get_all_aliases(alias_map,
                                                logger=logging.getLogger())
        self.alias_table = AliasEntityTable(args=self.args,
                                            entity_symbols=self.entity_db)

        # minimum probability of prediction to return mention
        self.threshold = threshold

        # get batch_on_the_fly embeddings _and_ the batch_prep embeddings
        self.batch_on_the_fly_embs = {}
        for i, emb in enumerate(self.args.data_config.ent_embeddings):
            if 'batch_prep' in emb and emb['batch_prep'] is True:
                self.args.data_config.ent_embeddings[i][
                    'batch_on_the_fly'] = True
                del self.args.data_config.ent_embeddings[i]['batch_prep']
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=self.args,
                                         emb_args=emb['args'],
                                         entity_symbols=self.entity_db,
                                         model_device=None,
                                         word_symbols=None)
                except AttributeError as e:
                    print(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)
Example #7
def main():
    args = parse_args()
    logging.info(json.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(save_dir,
                        train_file,
                        entity_symbols,
                        args.lower,
                        args.strip,
                        num_workers=args.num_workers)
    if not args.no_types:
        type_symbols = TypeSymbols(
            entity_symbols=entity_symbols,
            emb_dir=args.emb_dir,
            max_types=args.max_types,
            emb_file="hyena_type_emb.pkl",
            type_vocab_file="hyena_type_graph.vocab.pkl",
            type_file="hyena_types.txt")
        compute_type_occurrences(save_dir, "orig", entity_symbols,
                                 type_symbols.qid2typenames, train_file)
Example #8
    def test_filter_embs(self):
        entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
        entity_db = EntitySymbols(load_dir=entity_dump_dir,
            alias_cand_map_file="alias2qids.json")
        num_topk_entities = 2
        old2new_eid = {0:0,-1:-1,2:1,3:2}
        qid2topk_eid = {"Q1":2,"Q2":1,"Q3":2,"Q4":2}
        toes_eid = 2
        state_dict = {}
        state_dict["emb_layer.entity_embs.learned.learned_entity_embedding.weight"] = torch.Tensor([
            [1.0,2,3,4,5],
            [2,2,2,2,2],
            [3,3,3,3,3],
            [4,4,4,4,4],
            [5,5,5,5,5],
            [0,0,0,0,0]
        ])

        gold_state_dict = {}
        gold_state_dict["emb_layer.entity_embs.learned.learned_entity_embedding.weight"] = torch.Tensor([
            [1.0,2,3,4,5],
            [3,3,3,3,3],
            [4,4,4,4,4],
            [0,0,0,0,0]
        ])

        new_state_dict = filter_embs(num_topk_entities, entity_db, old2new_eid, qid2topk_eid, toes_eid, state_dict)
        for k in gold_state_dict:
            assert k in new_state_dict
            assert torch.equal(new_state_dict[k], gold_state_dict[k])
Example #9
 def setUp(self):
     self.args = parser_utils.get_full_config("test/run_args/test_model_training.json")
     train_utils.setup_train_heads_and_eval_slices(self.args)
     self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
     self.entity_symbols = EntitySymbols(os.path.join(
         self.args.data_config.entity_dir, self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map)
     slices = WikiSlices(
         args=self.args,
         use_weak_label=False,
         input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
         dataset_name=os.path.join(self.args.data_config.data_dir, data_utils.generate_save_data_name(
             data_args=self.args.data_config, use_weak_label=True, split_name="slice_train")),
         is_writer=True,
         distributed=self.args.run_config.distributed,
         dataset_is_eval=False
     )
     self.data = WikiDataset(
         args=self.args,
         use_weak_label=False,
         input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
         dataset_name=os.path.join(self.args.data_config.data_dir, data_utils.generate_save_data_name(
             data_args=self.args.data_config, use_weak_label=False, split_name="train")),
         is_writer=True,
         distributed=self.args.run_config.distributed,
         word_symbols=self.word_symbols,
         entity_symbols=self.entity_symbols,
         slice_dataset=slices,
         dataset_is_eval=False
     )
     self.trainer = Trainer(self.args, self.entity_symbols, self.word_symbols)
Example #10
    def test_load_and_save(self):
        entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
        entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
            alias_cand_map_file="alias2qids.json")
        fmt_types = {"trie1": "qid_cand_with_score"}
        max_values = {"trie1": 3}
        input_dicts = {"trie1": entity_symbols.get_alias2qids()}
        record_trie = RecordTrieCollection(load_dir=None, input_dicts=input_dicts,
                                           vocabulary=entity_symbols.get_qid2eid(),
                                           fmt_types=fmt_types, max_values=max_values)

        record_trie.dump(save_dir=os.path.join(entity_dump_dir, "record_trie"))
        record_trie_loaded = RecordTrieCollection(load_dir=os.path.join(entity_dump_dir, "record_trie"))

        self.assertEqual(record_trie._fmt_types, record_trie_loaded._fmt_types)
        self.assertEqual(record_trie._max_values, record_trie_loaded._max_values)
        self.assertEqual(record_trie._stoi, record_trie_loaded._stoi)
        np.testing.assert_array_equal(record_trie._itos, record_trie_loaded._itos)
        self.assertEqual(record_trie._record_tris, record_trie_loaded._record_tris)
Example #11
    def load_from_jsonl(
        cls,
        profile_file,
        max_candidates=30,
        max_types=10,
        max_kg_connections=100,
        edit_mode=False,
    ):
        """Loads an entity profile from the raw jsonl file. Each line is a JSON
        object with entity metadata.

        Example object::

            {
                "entity_id": "C000",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            }

        Args:
            profile_file: file where jsonl data lives
            max_candidates: maximum entity candidates
            max_types: maximum types per entity
            max_kg_connections: maximum KG connections per entity
            edit_mode: edit mode

        Returns: entity profile object
        """
        qid2title, alias2qids, type_systems, qid2relations = cls._read_profile_file(
            profile_file
        )
        entity_symbols = EntitySymbols(
            alias2qids=alias2qids,
            qid2title=qid2title,
            max_candidates=max_candidates,
            edit_mode=edit_mode,
        )

        all_type_symbols = {
            ty_name: TypeSymbols(
                qid2typenames=type_map, max_types=max_types, edit_mode=edit_mode
            )
            for ty_name, type_map in type_systems.items()
        }
        kg_symbols = KGSymbols(
            qid2relations, max_connections=max_kg_connections, edit_mode=edit_mode
        )
        return cls(entity_symbols, all_type_symbols, kg_symbols, edit_mode)
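
A minimal usage sketch of the classmethod above. The enclosing class name (EntityProfile) and the toy file path are assumptions; each jsonl line follows the object layout shown in the docstring:

import ujson

line = {
    "entity_id": "C000",
    "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
    "title": "Dog",
    "types": {"hyena": ["animal"], "wiki": ["dog"]},
    "relations": [{"relation": "sibling", "object": "Q345"}],
}
# write a one-line profile file, then load it in edit mode
with open("entity_profile.jsonl", "w") as out_f:
    out_f.write(ujson.dumps(line) + "\n")

profile = EntityProfile.load_from_jsonl(
    "entity_profile.jsonl", max_candidates=30, edit_mode=True
)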
Example #12
    def test_construction_double(self):
        truealias2qids = {
                        'alias1': [["Q1", 10], ["Q4", 6]],
                        'multi word alias2': [["Q2", 5], ["Q1", 3], ["Q4", 2]],
                        'alias3': [["Q1", 30]],
                        'alias4': [["Q4", 20], ["Q3", 15], ["Q2", 1]]
                        }
        truealias2typeids = {
                        'Q1': [1, 2, 3],
                        'Q2': [8],
                        'Q3': [4],
                        'Q4': [2,4]
                        }
        truerelations = {
                        'Q1': set(["Q2", "Q3"]),
                        'Q2': set(["Q1"]),
                        'Q3': set(["Q2"]),
                        'Q4': set(["Q1", "Q2", "Q3"])
                        }
        entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
        entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
            alias_cand_map_file="alias2qids.json")
        fmt_types = {"trie1": "qid_cand_with_score", "trie2": "type_ids", "trie3": "kg_relations"}
        max_values = {"trie1": 3, "trie2": 3, "trie3": 3}
        input_dicts = {"trie1": entity_symbols.get_alias2qids(), "trie2": truealias2typeids, "trie3": truerelations}
        record_trie = RecordTrieCollection(load_dir=None, input_dicts=input_dicts,
                                           vocabulary=entity_symbols.get_qid2eid(),
                                           fmt_types=fmt_types, max_values=max_values)

        for al in truealias2qids:
            self.assertEqual(truealias2qids[al], record_trie.get_value("trie1", al))

        for qid in truealias2typeids:
            self.assertEqual(truealias2typeids[qid], record_trie.get_value("trie2", qid))

        for qid in truerelations:
            self.assertEqual(truerelations[qid], record_trie.get_value("trie3", qid))
Example #13
 def setUp(self):
     entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
     self.entity_symbols = EntitySymbols(entity_dump_dir, alias_cand_map_file="alias2qids.json")
     self.config = {
         'data_config':
             {'train_in_candidates': False,
              'entity_dir': 'test/data/entity_loader/entity_data',
              'entity_prep_dir': 'prep',
              'alias_cand_map': 'alias2qids.json',
              'max_aliases': 3,
              'data_dir': 'test/data/entity_loader',
              'overwrite_preprocessed_data': True},
         'run_config':
             {'distributed': False}
         }
Example #14
 def setUp(self):
     entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
     self.entity_symbols = EntitySymbols.load_from_cache(
         entity_dump_dir, alias_cand_map_file="alias2qids.json"
     )
     self.config = {
         "data_config": {
             "train_in_candidates": False,
             "entity_dir": "test/data/entity_loader/entity_data",
             "entity_prep_dir": "prep",
             "alias_cand_map": "alias2qids.json",
             "max_aliases": 3,
             "data_dir": "test/data/entity_loader",
             "overwrite_preprocessed_data": True,
         },
         "run_config": {"distributed": False},
     }
Example #15
    def test_filter_qids(self):
        entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
        entity_db = EntitySymbols(load_dir=entity_dump_dir,
            alias_cand_map_file="alias2qids.json")
        qid2count = {"Q1":10,"Q2":20,"Q3":2,"Q4":4}
        perc_emb_drop = 0.8

        gold_qid2topk_eid = {"Q1":2,"Q2":1,"Q3":2,"Q4":2}
        gold_old2new_eid = {0:0,-1:-1,2:1,3:2}
        gold_new_toes_eid = 2
        gold_num_topk_entities = 2

        qid2topk_eid, old2new_eid, new_toes_eid, num_topk_entities = filter_qids(perc_emb_drop, entity_db, qid2count)
        self.assertEqual(gold_qid2topk_eid, qid2topk_eid)
        self.assertEqual(gold_old2new_eid, old2new_eid)
        self.assertEqual(gold_new_toes_eid, new_toes_eid)
        self.assertEqual(gold_num_topk_entities, num_topk_entities)
Example #16
def write_data_labels_initializer(merged_entity_emb_file, merged_storage_type,
                                  sent_idx_map_file, train_in_candidates,
                                  dump_embs, data_config):
    global filt_emb_data_global
    filt_emb_data_global = np.memmap(merged_entity_emb_file,
                                     dtype=merged_storage_type,
                                     mode="r+")
    global sent_idx_map_global
    sent_idx_map_global = utils.load_single_item_trie(sent_idx_map_file)
    global train_in_candidates_global
    train_in_candidates_global = train_in_candidates
    global dump_embs_global
    dump_embs_global = dump_embs
    global entity_dump_global
    entity_dump_global = EntitySymbols(
        load_dir=os.path.join(data_config.entity_dir,
                              data_config.entity_map_dir),
        alias_cand_map_file=data_config.alias_cand_map)
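
The globals above follow the standard multiprocessing initializer pattern: each worker process runs the initializer once, so subsequent map calls can read the shared state without re-loading it. A hedged wiring sketch (the pool size is arbitrary and the initargs are whatever values the caller already has in scope):

import multiprocessing

pool = multiprocessing.Pool(
    processes=4,
    initializer=write_data_labels_initializer,
    initargs=(merged_entity_emb_file, merged_storage_type, sent_idx_map_file,
              train_in_candidates, dump_embs, data_config),
)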
Example #17
def main():
    args = parse_args()
    logging.info(json.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(
        save_dir,
        train_file,
        entity_symbols,
        args.lower,
        args.strip,
        num_workers=args.num_workers,
    )
Example #18
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )
    print("DO LOWERCASE IS", "uncased" in args.bert_model)
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case="uncased" in args.bert_model,
        cache_dir=args.word_model_cache,
    )
    model = BertModel.from_pretrained(
        args.bert_model,
        cache_dir=args.word_model_cache,
        output_attentions=False,
        output_hidden_states=False,
    )
    if not args.cpu:
        model = model.to("cuda")
    model.eval()

    entity2avgtitle = build_title_table(
        args.cpu, args.batch_size, model, tokenizer, entity_symbols
    )
    save_fold = os.path.dirname(args.save_file)
    if save_fold:
        os.makedirs(save_fold, exist_ok=True)
    if args.output_method == "pt":
        save_obj = (entity_symbols.get_qid2eid(), entity2avgtitle)
        torch.save(obj=save_obj, f=args.save_file)
    else:
        res = {}
        for qid in tqdm(entity_symbols.get_all_qids(), desc="Building final json"):
            eid = entity_symbols.get_eid(qid)
            res[qid] = entity2avgtitle[eid].tolist()
        with open(args.save_file, "w") as out_f:
            ujson.dump(res, out_f)
    print(f"Done!")
Example #19
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    num_processes = min(args.processes, int(0.8 * multiprocessing.cpu_count()))
    print("Loading entity symbols")
    entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )

    in_file = os.path.join(args.data_dir, args.train_file)
    print(f"Getting slice counts from {in_file}")
    qid_cnts = get_counts(num_processes, in_file)
    with open(os.path.join(args.data_dir, "qid_cnts_train.json"), "w") as out_f:
        ujson.dump(qid_cnts, out_f)
    df = build_reg_csv(qid_cnts, entity_symbols)

    df.to_csv(args.out_file, index=False)
    print(f"Saved file to {args.out_file}")
Example #20
 def setUp(self):
     """ENTITY SYMBOLS
      {
        "multi word alias2":[["Q2",5.0],["Q1",3.0],["Q4",2.0]],
        "alias1":[["Q1",10.0],["Q4",6.0]],
        "alias3":[["Q1",30.0]],
        "alias4":[["Q4",20.0],["Q3",15.0],["Q2",1.0]]
      }
      EMBEDDINGS
      {
          "key": "learned",
          "freeze": false,
          "load_class": "LearnedEntityEmb",
          "args":
          {
            "learned_embedding_size": 10,
          }
      },
      {
         "key": "learned_type",
         "load_class": "LearnedTypeEmb",
         "freeze": false,
         "args": {
             "type_labels": "type_pred_mapping.json",
             "max_types": 1,
             "type_dim": 5,
             "merge_func": "addattn",
             "attn_hidden_size": 5
         }
     }
     """
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_model_training.json")
     self.entity_symbols = EntitySymbols.load_from_cache(
         os.path.join(self.args.data_config.entity_dir,
                      self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map,
     )
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
Example #21
class EntitySymbolTest(unittest.TestCase):
    def setUp(self):
        entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
        self.entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
            alias_cand_map_file="alias2qids.json")

    def test_load_entites_keep_noncandidate(self):
        truealias2qids = {
                        'alias1': [["Q1", 10.0], ["Q4", 6]],
                        'multi word alias2': [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
                        'alias3': [["Q1", 30.0]],
                        'alias4': [["Q4", 20], ["Q3", 15.0], ["Q2", 1]]
                        }

        trueqid2title = {
                        'Q1': "alias1",
                        'Q2': "multi alias2",
                        'Q3': "word alias3",
                        'Q4': "nonalias4"
                        }

        # the non-candidate class is included in entity_dump
        trueqid2eid = {
                 'Q1': 1,
                 'Q2': 2,
                 'Q3': 3,
                 'Q4': 4
                 }
        self.assertEqual(self.entity_symbols.max_candidates, 3)
        self.assertEqual(self.entity_symbols.max_alias_len, 3)
        self.assertDictEqual(self.entity_symbols._alias2qids, truealias2qids)
        self.assertDictEqual(self.entity_symbols._qid2title, trueqid2title)
        self.assertDictEqual(self.entity_symbols._qid2eid, trueqid2eid)

    def test_getters(self):
        self.assertEqual(self.entity_symbols.get_qid(1), 'Q1')
        self.assertSetEqual(set(self.entity_symbols.get_all_aliases()),
                            {'alias1', 'multi word alias2', 'alias3', 'alias4'})
        self.assertEqual(self.entity_symbols.get_eid('Q3'), 3)
        self.assertListEqual(self.entity_symbols.get_qid_cands('alias1'), ['Q1', 'Q4'])
        self.assertListEqual(self.entity_symbols.get_qid_cands('alias1', max_cand_pad=True), ['Q1', 'Q4', '-1'])
        self.assertListEqual(self.entity_symbols.get_eid_cands('alias1', max_cand_pad=True), [1, 4, -1])
        self.assertEqual(self.entity_symbols.get_title('Q1'), 'alias1')
Example #22
 def setUp(self):
     # tests that the sampling is done correctly on indices
     # load data from directory
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_type_data.json")
     self.tokenizer = BertTokenizer.from_pretrained(
         "bert-base-cased",
         do_lower_case=False,
         cache_dir="test/data/emb_data/pretrained_bert_models",
     )
     self.is_bert = True
     self.entity_symbols = EntitySymbols.load_from_cache(
         os.path.join(self.args.data_config.entity_dir,
                      self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map,
     )
     self.temp_file_name = "test/data/data_loader/test_data.jsonl"
     self.guid_dtype = lambda max_aliases: np.dtype([
         ("sent_idx", "i8", 1),
         ("subsent_idx", "i8", 1),
         ("alias_orig_list_pos", "i8", max_aliases),
     ])
Example #23
class Annotator():
    """
    Annotator class: convenient wrapper of preprocessing and model eval to allow for
    annotating single sentences at a time for quick experimentation, e.g. in notebooks.
    """
    def __init__(self,
                 config_args,
                 device='cuda',
                 max_alias_len=6,
                 cand_map=None,
                 threshold=0.0):
        self.args = config_args
        self.device = device
        self.entity_db = EntitySymbols(
            os.path.join(self.args.data_config.entity_dir,
                         self.args.data_config.entity_map_dir),
            alias_cand_map_file=self.args.data_config.alias_cand_map)
        self.word_db = data_utils.load_wordsymbols(self.args.data_config,
                                                   is_writer=True,
                                                   distributed=False)
        self.model = self._load_model()
        self.max_alias_len = max_alias_len
        if cand_map is None:
            alias_map = self.entity_db._alias2qids
        else:
            alias_map = ujson.load(open(cand_map))
        self.all_aliases_trie = get_all_aliases(alias_map,
                                                logger=logging.getLogger())
        self.alias_table = AliasEntityTable(args=self.args,
                                            entity_symbols=self.entity_db)

        # minimum probability of prediction to return mention
        self.threshold = threshold

        # get batch_on_the_fly embeddings _and_ the batch_prep embeddings
        self.batch_on_the_fly_embs = {}
        for i, emb in enumerate(self.args.data_config.ent_embeddings):
            if 'batch_prep' in emb and emb['batch_prep'] is True:
                self.args.data_config.ent_embeddings[i][
                    'batch_on_the_fly'] = True
                del self.args.data_config.ent_embeddings[i]['batch_prep']
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=self.args,
                                         emb_args=emb['args'],
                                         entity_symbols=self.entity_db,
                                         model_device=None,
                                         word_symbols=None)
                except AttributeError as e:
                    print(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)

    def _load_model(self):
        model_state_dict = torch.load(
            self.args.run_config.init_checkpoint,
            map_location=lambda storage, loc: storage)['model']
        model = Model(args=self.args,
                      model_device=self.device,
                      entity_symbols=self.entity_db,
                      word_symbols=self.word_db).to(self.device)
        # remove distributed naming if it exists
        if not self.args.run_config.distributed:
            new_state_dict = OrderedDict()
            for k, v in model_state_dict.items():
                if 'module.' in k and k[:len('module.')] == 'module.':
                    name = k[len('module.'):]
                    new_state_dict[name] = v
            # we renamed all layers due to distributed naming
            if len(new_state_dict) == len(model_state_dict):
                model_state_dict = new_state_dict
            else:
                assert len(new_state_dict) == 0
        model.load_state_dict(model_state_dict, strict=True)
        # must set model in eval mode
        model.eval()
        return model

    def extract_mentions(self, text):
        found_aliases, found_spans = find_aliases_in_sentence_tag(
            text, self.all_aliases_trie, self.max_alias_len)
        return {
            'sentence': text,
            'aliases': found_aliases,
            'spans': found_spans,
            # we don't know the true QID
            'qids': ['Q-1' for i in range(len(found_aliases))],
            'gold': [True for i in range(len(found_aliases))]
        }

    def set_threshold(self, value):
        self.threshold = value

    def label_mentions(self, text):
        sample = self.extract_mentions(text)
        idxs_arr, aliases_to_predict_per_split, spans_arr, phrase_tokens_arr = sentence_utils.split_sentence(
            max_aliases=self.args.data_config.max_aliases,
            phrase=sample['sentence'],
            spans=sample['spans'],
            aliases=sample['aliases'],
            aliases_seen_by_model=[i for i in range(len(sample['aliases']))],
            seq_len=self.args.data_config.max_word_token_len,
            word_symbols=self.word_db)
        aliases_arr = [[sample['aliases'][idx] for idx in idxs]
                       for idxs in idxs_arr]
        qids_arr = [[sample['qids'][idx] for idx in idxs] for idxs in idxs_arr]
        word_indices_arr = [
            self.word_db.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr
        ]

        if len(idxs_arr) > 1:
            #TODO: support sentences that overflow due to long sequence length or too many mentions
            raise ValueError(
                'Overflowing sentences not currently supported in Annotator')

        # iterate over each sample in the split
        for sub_idx in range(len(idxs_arr)):
            example_aliases = np.ones(self.args.data_config.max_aliases,
                                      dtype=int) * PAD_ID
            example_true_entities = np.ones(
                self.args.data_config.max_aliases) * PAD_ID
            example_aliases_locs_start = np.ones(
                self.args.data_config.max_aliases) * PAD_ID
            example_aliases_locs_end = np.ones(
                self.args.data_config.max_aliases) * PAD_ID

            aliases = aliases_arr[sub_idx]
            for mention_idx, alias in enumerate(aliases):
                # get aliases
                alias_trie_idx = self.entity_db.get_alias_idx(alias)
                alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                example_aliases[mention_idx] = alias_trie_idx

                # alias_idx_pair
                span_idx = spans_arr[sub_idx][mention_idx]
                span_start_idx, span_end_idx = span_idx
                example_aliases_locs_start[mention_idx] = span_start_idx
                example_aliases_locs_end[mention_idx] = span_end_idx

            # get word indices
            word_indices = word_indices_arr[sub_idx]

            # entity indices from alias table (these are the candidates)
            entity_indices = self.alias_table(example_aliases)

            # all CPU embs have to retrieved on the fly
            batch_on_the_fly_data = {}
            for emb_name, emb in self.batch_on_the_fly_embs.items():
                batch_on_the_fly_data[emb_name] = torch.tensor(
                    emb.batch_prep(example_aliases, entity_indices),
                    device=self.device)

            outs, entity_pack, _ = self.model(
                alias_idx_pair_sent=[
                    torch.tensor(example_aliases_locs_start,
                                 device=self.device).unsqueeze(0),
                    torch.tensor(example_aliases_locs_end,
                                 device=self.device).unsqueeze(0)
                ],
                word_indices=torch.tensor([word_indices], device=self.device),
                alias_indices=torch.tensor(example_aliases,
                                           device=self.device).unsqueeze(0),
                entity_indices=torch.tensor(entity_indices,
                                            device=self.device).unsqueeze(0),
                batch_prepped_data={},
                batch_on_the_fly_data=batch_on_the_fly_data)

            entity_cands = eval_utils.map_aliases_to_candidates(
                self.args.data_config.train_in_candidates, self.entity_db,
                aliases)
            # recover predictions
            probs = torch.exp(
                eval_utils.masked_class_logsoftmax(
                    pred=outs[DISAMBIG][FINAL_LOSS],
                    mask=~entity_pack.mask,
                    dim=2))
            max_probs, max_probs_indices = probs.max(2)

            pred_cands = []
            pred_probs = []
            titles = []
            for alias_idx in range(len(aliases)):
                pred_idx = max_probs_indices[0][alias_idx]
                pred_prob = max_probs[0][alias_idx].item()
                pred_qid = entity_cands[alias_idx][pred_idx]
                if pred_prob > self.threshold:
                    pred_cands.append(pred_qid)
                    pred_probs.append(pred_prob)
                    titles.append(
                        self.entity_db.get_title(pred_qid)
                        if pred_qid != 'NC' else 'NC')

            return pred_cands, pred_probs, titles
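
A hedged usage sketch of the Annotator above; the config path is hypothetical, and parser_utils.get_full_config mirrors the config loader used in Example #9:

args = parser_utils.get_full_config("test/run_args/test_model_training.json")
annotator = Annotator(config_args=args, device="cuda", threshold=0.5)
pred_cands, pred_probs, titles = annotator.label_mentions(
    "Abraham Lincoln was a president")
annotator.set_threshold(0.0)  # return every predicted mention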
Example #24
def create_task(args, entity_symbols=None, slice_datasets=None):
    """Creates a type prediction task.

    Args:
        args: args
        entity_symbols: entity symbols
        slice_datasets: slice datasets used in scorer (default None)

    Returns: EmmentalTask for type prediction
    """
    if entity_symbols is None:
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir=os.path.join(
                args.data_config.entity_dir, args.data_config.entity_map_dir
            ),
            alias_cand_map_file=args.data_config.alias_cand_map,
            alias_idx_file=args.data_config.alias_idx_map,
        )

    # Create sentence encoder
    bert_model = BertEncoder(
        args.data_config.word_embedding, output_size=args.model_config.hidden_size
    )

    # Create type prediction module
    # Add 1 for pad type
    type_prediction = TypePred(
        args.model_config.hidden_size,
        args.data_config.type_prediction.dim,
        args.data_config.type_prediction.num_types + 1,
        embedding_utils.get_max_candidates(entity_symbols, args.data_config),
    )

    # Create scorer
    sliced_scorer = BootlegSlicedScorer(
        args.data_config.train_in_candidates, slice_datasets
    )

    # Create module pool
    # BERT model will be shared across tasks as long as the name matches
    module_pool = nn.ModuleDict(
        {BERT_MODEL_NAME: bert_model, "type_prediction": type_prediction}
    )

    # Create task flow
    task_flow = [
        {
            "name": BERT_MODEL_NAME,
            "module": BERT_MODEL_NAME,
            "inputs": [
                ("_input_", "entity_cand_eid"),
                ("_input_", "token_ids"),
            ],  # We pass the entity_cand_eids to BERT in case of embeddings that require word information
        },
        {
            "name": "type_prediction",
            "module": "type_prediction",  # output: embedding_dict, batch_type_pred
            "inputs": [
                (BERT_MODEL_NAME, 0),  # sentence embedding
                ("_input_", "start_span_idx"),
            ],
        },
    ]

    return EmmentalTask(
        name=TYPE_PRED_TASK,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=partial(type_loss, "type_prediction"),
        output_func=partial(type_output, "type_prediction"),
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        scorer=Scorer(
            customize_metric_funcs={
                f"{TYPE_PRED_TASK}_scorer": sliced_scorer.type_pred_score
            }
        ),
    )
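
A sketch of wiring the returned task into Emmental, assuming the standard EmmentalModel API; emmental.init mirrors Example #20:

import emmental
from emmental.model import EmmentalModel

emmental.init(log_dir="test/temp_log")
type_task = create_task(args)  # args: the parsed config namespace from above
model = EmmentalModel(name="bootleg", tasks=[type_task])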
Example #25
    def load_from_cache(
        cls,
        load_dir,
        edit_mode=False,
        verbose=False,
        no_kg=False,
        no_type=False,
        type_systems_to_load=None,
    ):
        """Loaded a pre-saved profile.

        Args:
            load_dir: load directory
            edit_mode: edit mode flag, default False
            verbose: verbose flag, default False
            no_kg: if True, do not load KG connections (default False)
            no_type: if True, do not load any type information (default False); overrides type_systems_to_load
            type_systems_to_load: list of type systems to load, default is None which means all types systems

        Returns: entity profile object
        """
        # Check type system input
        load_dir = Path(load_dir)
        type_subfolder = load_dir / TYPE_SUBFOLDER
        if type_systems_to_load is not None:
            if not isinstance(type_systems_to_load, list):
                raise ValueError(
                    f"`type_systems` must be a list of subfolders in {type_subfolder}"
                )
            for sys in type_systems_to_load:
                if sys not in [p.name for p in type_subfolder.iterdir()]:
                    raise ValueError(
                        f"`type_systems` must be a list of subfolders in {type_subfolder}. {sys} is not one."
                    )

        if verbose:
            print("Loading Entity Symbols")
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir / ENTITY_SUBFOLDER,
            edit_mode=edit_mode,
            verbose=verbose,
        )
        if no_type:
            print(
                "Not loading type information. We will act as if there are no types "
                "associated with any entity and will not modify the types in any way, "
                "even if calling `add`."
            )
        type_sys_dict = {}
        for fold in type_subfolder.iterdir():
            if ((not no_type) and (type_systems_to_load is None
                                   or fold.name in type_systems_to_load)
                    and (fold.is_dir())):
                if verbose:
                    print(f"Loading Type Symbols from {fold}")
                type_sys_dict[fold.name] = TypeSymbols.load_from_cache(
                    type_subfolder / fold.name,
                    edit_mode=edit_mode,
                    verbose=verbose,
                )
        if verbose:
            print(f"Loading KG Symbols")
        if no_kg:
            print(
                "Not loading KG information. We will act as if there are no KG "
                "connections between entities and will not modify the KG information "
                "in any way, even if calling `add`."
            )
        kg_symbols = None
        if not no_kg:
            kg_symbols = KGSymbols.load_from_cache(
                load_dir / KG_SUBFOLDER,
                edit_mode=edit_mode,
                verbose=verbose,
            )
        return cls(entity_symbols, type_sys_dict, kg_symbols, edit_mode,
                   verbose)
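
A short usage sketch of load_from_cache; the directory is hypothetical and the enclosing profile class is assumed to be EntityProfile:

profile = EntityProfile.load_from_cache(
    "data/entity_db",
    edit_mode=False,
    no_kg=True,                     # act as if there are no KG connections
    type_systems_to_load=["wiki"],  # only load the "wiki" type subfolder
)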
Example #26
File: run.py  Project: paper2code/bootleg
def model_eval(args, mode, is_writer, logger, world_size=1, rank=0):
    assert args.run_config.init_checkpoint != "", "You can't have an empty model file to do eval"
    # this is in main but call again in case eval is called directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, mode)

    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    eval_slice_names = args.run_config.eval_slices
    test_dataset_collection = {}
    test_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.test_dataset, is_writer, dataset_is_eval=True)
    test_dataset = data_utils.create_dataset(args,
                                             args.data_config.test_dataset,
                                             is_writer,
                                             word_symbols,
                                             entity_symbols,
                                             slice_dataset=test_slice_dataset,
                                             dataset_is_eval=True)
    test_dataloader, test_sampler = data_utils.create_dataloader(
        args,
        test_dataset,
        eval_slice_dataset=test_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.test_dataset,
                                           args.data_config.test_dataset.file,
                                           test_dataset, test_dataloader,
                                           test_slice_dataset, test_sampler)
    test_dataset_collection[
        args.data_config.test_dataset.file] = dataset_collection

    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      resume_model_file=args.run_config.init_checkpoint,
                      eval_slice_names=eval_slice_names,
                      model_eval=True)

    # Run evaluation numbers without dumping predictions (quick, batched)
    if mode == 'eval':
        status_reporter = StatusReporter(args,
                                         logger,
                                         is_writer,
                                         max_epochs=None,
                                         total_steps_per_epoch=None,
                                         is_eval=True)
        # results are written to json file
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************RUNNING EVAL {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            # is_test=True marks these as test batches; global_step=None since we are not training
            eval_utils.run_batched_eval(args=args,
                                        is_test=True,
                                        global_step=None,
                                        logger=logger,
                                        trainer=trainer,
                                        dataloader=test_dataloader,
                                        status_reporter=status_reporter,
                                        file=test_data_file)

    elif mode == 'dump_preds' or mode == 'dump_embs':
        # get predictions and optionally dump the corresponding contextual entity embeddings
        # TODO: support dumping ids for other embeddings as well (static entity embeddings, type embeddings, relation embeddings)
        # TODO: remove collection abstraction
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************DUMPING PREDICTIONS FOR {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            pred_file, emb_file = eval_utils.run_dump_preds(
                args=args,
                entity_symbols=entity_symbols,
                test_data_file=test_data_file,
                logger=logger,
                trainer=trainer,
                dataloader=test_dataloader,
                dump_embs=(mode == 'dump_embs'))
            return pred_file, emb_file
    return
Example #27
File: run.py  Project: paper2code/bootleg
def train(args, is_writer, logger, world_size, rank):
    # This is main but call again in case train is called directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, "train")

    # Load word symbols (like tokenizers) and entity symbols (aka entity profiles)
    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    # Get train dataset
    train_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.train_dataset, is_writer, dataset_is_eval=False)
    train_dataset = data_utils.create_dataset(
        args,
        args.data_config.train_dataset,
        is_writer,
        word_symbols,
        entity_symbols,
        slice_dataset=train_slice_dataset,
        dataset_is_eval=False)
    train_dataloader, train_sampler = data_utils.create_dataloader(
        args,
        train_dataset,
        eval_slice_dataset=None,
        world_size=world_size,
        rank=rank,
        batch_size=args.train_config.batch_size)

    # Repeat for dev
    dev_dataset_collection = {}
    dev_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.dev_dataset, is_writer, dataset_is_eval=True)
    dev_dataset = data_utils.create_dataset(args,
                                            args.data_config.dev_dataset,
                                            is_writer,
                                            word_symbols,
                                            entity_symbols,
                                            slice_dataset=dev_slice_dataset,
                                            dataset_is_eval=True)
    dev_dataloader, dev_sampler = data_utils.create_dataloader(
        args,
        dev_dataset,
        eval_slice_dataset=dev_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.dev_dataset,
                                           args.data_config.dev_dataset.file,
                                           dev_dataset, dev_dataloader,
                                           dev_slice_dataset, dev_sampler)
    dev_dataset_collection[
        args.data_config.dev_dataset.file] = dataset_collection

    eval_slice_names = args.run_config.eval_slices

    total_steps_per_epoch = len(train_dataloader)
    # Create trainer---model, optimizer, and scorer
    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      total_steps_per_epoch=total_steps_per_epoch,
                      eval_slice_names=eval_slice_names,
                      resume_model_file=args.run_config.init_checkpoint)

    # Set up epochs and intervals for saving and evaluating
    max_epochs = int(args.run_config.max_epochs)
    eval_steps = int(args.run_config.eval_steps)
    log_steps = int(args.run_config.log_steps)
    save_steps = max(int(args.run_config.save_every_k_eval * eval_steps), 1)
    logger.info(
        f"Eval steps {eval_steps}, Log steps {log_steps}, Save steps {save_steps}, Total training examples per epoch {len(train_dataset)}"
    )
    status_reporter = StatusReporter(args,
                                     logger,
                                     is_writer,
                                     max_epochs,
                                     total_steps_per_epoch,
                                     is_eval=False)
    global_step = 0
    for epoch in range(trainer.start_epoch, trainer.start_epoch + max_epochs):
        # this is to fix having to save/restore the RNG state for checkpointing
        torch.manual_seed(args.train_config.seed + epoch)
        np.random.seed(args.train_config.seed + epoch)
        if args.run_config.distributed:
            # for determinism across runs https://github.com/pytorch/examples/issues/501
            train_sampler.set_epoch(epoch)

        start_time_load = time.time()
        for i, batch in enumerate(train_dataloader):
            load_time = time.time() - start_time_load
            start_time = time.time()
            _, loss_pack, _, _ = trainer.update(batch)
            # Log progress
            if (global_step + 1) % log_steps == 0:
                duration = time.time() - start_time
                status_reporter.step_status(epoch=epoch,
                                            step=global_step,
                                            loss_pack=loss_pack,
                                            time=duration,
                                            load_time=load_time,
                                            lr=trainer.get_lr())
            # Save model
            if (global_step + 1) % save_steps == 0 and is_writer:
                logger.info("Saving model...")
                trainer.save(save_dir=train_utils.get_save_folder(
                    args.run_config),
                             epoch=epoch,
                             step=global_step,
                             step_in_batch=i,
                             suffix=args.run_config.model_suffix)
            # Run evaluation
            if (global_step + 1) % eval_steps == 0:
                eval_utils.run_eval_all_dev_sets(args, global_step,
                                                 dev_dataset_collection,
                                                 logger, status_reporter,
                                                 trainer)
            if args.run_config.distributed:
                dist.barrier()
            global_step += 1
            # Time loading new batch
            start_time_load = time.time()
        ######### END OF EPOCH
        if is_writer:
            logger.info(f"Saving model end of epoch {epoch}...")
            trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                         epoch=epoch,
                         step=global_step,
                         step_in_batch=i,
                         end_of_epoch=True,
                         suffix=args.run_config.model_suffix)
        # Always run eval when saving -- if this coincided with eval_step, then don't need to rerun eval
        if (global_step + 1) % eval_steps != 0:
            eval_utils.run_eval_all_dev_sets(args, global_step,
                                             dev_dataset_collection, logger,
                                             status_reporter, trainer)
    if is_writer:
        logger.info("Saving model...")
        trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                     epoch=epoch,
                     step=global_step,
                     step_in_batch=i,
                     end_of_epoch=True,
                     last_epoch=True,
                     suffix=args.run_config.model_suffix)
    if args.run_config.distributed:
        dist.barrier()
Example #28
    def test_add_entity(self):
        alias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2], ["Q3", 1]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        qid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }
        max_candidates = 3

        entity_symbols = EntitySymbols(
            max_candidates=max_candidates,
            alias2qids=alias2qids,
            qid2title=qid2title,
            edit_mode=True,
        )

        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias4"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
        truemax_eid = 4
        truemax_alid = 3
        truenum_entities = 4
        truenum_entities_with_pad_and_nocand = 6
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
        self.assertDictEqual(
            entity_symbols._eid2qid, {v: i for i, v in trueqid2eid.items()}
        )
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
        self.assertEqual(entity_symbols.max_eid, truemax_eid)
        self.assertEqual(entity_symbols.max_alid, truemax_alid)
        self.assertEqual(entity_symbols.num_entities, truenum_entities)
        self.assertEqual(
            entity_symbols.num_entities_with_pad_and_nocand,
            truenum_entities_with_pad_and_nocand,
        )

        # Add entity
        entity_symbols.add_entity(
            "Q5", [["multi word alias2", 1.5], ["alias5", 20.0]], "Snake"
        )
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias4"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "alias4"},
            "Q5": {"multi word alias2", "alias5"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [
                ["Q2", 5.0],
                ["Q1", 3],
                ["Q5", 1.5],
            ],  # the new entity-mention pair is force-inserted regardless of its score, so Q4 (the lowest-scored candidate) is dropped to stay within max_candidates
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
            "alias5": [["Q5", 20]],
        }
        trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4, "Q5": 5}
        truealias2id = {
            "alias1": 0,
            "alias3": 1,
            "alias4": 2,
            "multi word alias2": 3,
            "alias5": 4,
        }
        trueid2alias = {
            0: "alias1",
            1: "alias3",
            2: "alias4",
            3: "multi word alias2",
            4: "alias5",
        }
        truemax_eid = 5
        truemax_alid = 4
        truenum_entities = 5
        truenum_entities_with_pad_and_nocand = 7
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
        self.assertDictEqual(
            entity_symbols._eid2qid, {v: i for i, v in trueqid2eid.items()}
        )
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
        self.assertEqual(entity_symbols.max_eid, truemax_eid)
        self.assertEqual(entity_symbols.max_alid, truemax_alid)
        self.assertEqual(entity_symbols.num_entities, truenum_entities)
        self.assertEqual(
            entity_symbols.num_entities_with_pad_and_nocand,
            truenum_entities_with_pad_and_nocand,
        )
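
For reference, a minimal standalone sketch of the add_entity behavior exercised above, with hypothetical toy data and EntitySymbols imported as in the other examples. An alias already at max_candidates has its lowest-scored candidate dropped so the new pair can be force-inserted, mirroring how Q4 was displaced by Q5:

entity_symbols = EntitySymbols(
    max_candidates=2,
    alias2qids={"alias1": [["Q1", 10.0], ["Q4", 6.0]]},
    qid2title={"Q1": "alias1", "Q4": "nonalias4"},
    edit_mode=True,  # required for add_entity and the other mutators
)
# "alias1" is full, so Q4 (lowest score) is dropped even though the new
# pair scores lower -- the new entity-mention pair always gets a slot.
entity_symbols.add_entity("Q5", [["alias1", 1.0], ["alias5", 20.0]], "Snake")
print(entity_symbols.get_qid_cands("alias1"))  # ["Q1", "Q5"]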
Example #29
    def test_add_remove_mention(self):
        alias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2], ["Q3", 1]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        qid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }

        max_candidates = 3

        # eids start at 1: eid 0 is reserved for the non-candidate class
        trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        truealiastrie = {"multi word alias2": 0, "alias1": 1, "alias3": 2, "alias4": 3}
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        entity_symbols = EntitySymbols(
            max_candidates=max_candidates,
            alias2qids=alias2qids,
            qid2title=qid2title,
        )
        tri_as_dict = {}
        for k in entity_symbols._alias_trie:
            tri_as_dict[k] = entity_symbols._alias_trie[k]

        self.assertEqual(entity_symbols.max_candidates, 3)
        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertDictEqual(entity_symbols._qid2title, qid2title)
        self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
        self.assertDictEqual(tri_as_dict, truealiastrie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertIsNone(entity_symbols._qid2aliases)

        # Check that mutation fails when edit_mode=False (the default)
        with self.assertRaises(AttributeError) as context:
            entity_symbols.add_mention("Q2", "alias3", 31.0)
        assert type(context.exception) is AttributeError

        entity_symbols = EntitySymbols(
            max_candidates=max_candidates,
            alias2qids=alias2qids,
            qid2title=qid2title,
            edit_mode=True,
        )

        # Check nothing changes if pair doesn't exist
        entity_symbols.remove_mention("Q3", "alias1")

        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias4"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }

        self.assertEqual(entity_symbols.max_candidates, 3)
        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._qid2title, qid2title)
        self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)

        # ADD Q2 ALIAS 3
        entity_symbols.add_mention("Q2", "alias3", 31.0)
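        # the pair is inserted in score order; "alias3" had room, so nothing is dropped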
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias4", "alias3"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q2", 31.0], ["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}

        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)

        # ADD Q1 ALIAS 4
        entity_symbols.add_mention("Q1", "alias4", 31.0)
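        # "alias4" was already at max_candidates, so the lowest-scored candidate (Q2, 1) is dropped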
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3", "alias4"},
            "Q2": {"multi word alias2", "alias3"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q2", 31.0], ["Q1", 30.0]],
            "alias4": [["Q1", 31.0], ["Q4", 20], ["Q3", 15.0]],
        }
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}

        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)

        # REMOVE Q3 ALIAS 4
        entity_symbols.remove_mention("Q3", "alias4")
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3", "alias4"},
            "Q2": {"multi word alias2", "alias3"},
            "Q3": set(),
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q2", 31.0], ["Q1", 30.0]],
            "alias4": [["Q1", 31.0], ["Q4", 20]],
        }
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}

        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)

        # REMOVE Q4 ALIAS 4
        entity_symbols.remove_mention("Q4", "alias4")
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3", "alias4"},
            "Q2": {"multi word alias2", "alias3"},
            "Q3": set(),
            "Q4": {"alias1", "multi word alias2"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q2", 31.0], ["Q1", 30.0]],
            "alias4": [["Q1", 31.0]],
        }
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}

        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)

        # REMOVE Q1 ALIAS 4
        entity_symbols.remove_mention("Q1", "alias4")
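        # "alias4" now has no candidates left, so the alias itself is deleted and its id (2) is freed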
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias3"},
            "Q3": set(),
            "Q4": {"alias1", "multi word alias2"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q2", 31.0], ["Q1", 30.0]],
        }
        truealias2id = {"alias1": 0, "alias3": 1, "multi word alias2": 3}
        trueid2alias = {0: "alias1", 1: "alias3", 3: "multi word alias2"}

        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)

        # ADD Q1 BLIAS 0
        entity_symbols.add_mention("Q1", "blias0", 11)
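        # a brand-new alias takes the next id (max_alid + 1 = 4); the freed id 2 is not reused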
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3", "blias0"},
            "Q2": {"multi word alias2", "alias3"},
            "Q3": set(),
            "Q4": {"alias1", "multi word alias2"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q2", 31.0], ["Q1", 30.0]],
            "blias0": [["Q1", 11.0]],
        }
        truealias2id = {"alias1": 0, "alias3": 1, "multi word alias2": 3, "blias0": 4}
        trueid2alias = {0: "alias1", 1: "alias3", 3: "multi word alias2", 4: "blias0"}

        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 4)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)

        # SET SCORE Q2 ALIAS3
        # Check that set_score fails when the entity-mention pair does not exist
        with self.assertRaises(ValueError) as context:
            entity_symbols.set_score("Q2", "alias1", 2)
        assert type(context.exception) is ValueError

        entity_symbols.set_score("Q2", "alias3", 2)
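        # re-scoring re-sorts the candidate list: Q2 (now 2) falls behind Q1 (30.0)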
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3", "blias0"},
            "Q2": {"multi word alias2", "alias3"},
            "Q3": set(),
            "Q4": {"alias1", "multi word alias2"},
        }
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0], ["Q2", 2]],
            "blias0": [["Q1", 11.0]],
        }
        truealias2id = {"alias1": 0, "alias3": 1, "multi word alias2": 3, "blias0": 4}
        trueid2alias = {0: "alias1", 1: "alias3", 3: "multi word alias2", 4: "blias0"}

        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 4)
        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertIsNone(entity_symbols._alias_trie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
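
As the AttributeError check at the top of this test shows, add_mention (and, as used throughout these tests, the other mutators) is only available when the symbols are built with edit_mode=True. A small sketch with hypothetical toy data:

read_only = EntitySymbols(
    max_candidates=3,
    alias2qids={"alias1": [["Q1", 10.0]]},
    qid2title={"Q1": "alias1"},
)
try:
    read_only.add_mention("Q1", "alias2", 1.0)  # raises: not in edit mode
except AttributeError:
    # Rebuild in edit mode before mutating.
    entity_symbols = EntitySymbols(
        max_candidates=3,
        alias2qids={"alias1": [["Q1", 10.0]]},
        qid2title={"Q1": "alias1"},
        edit_mode=True,
    )
    entity_symbols.add_mention("Q1", "alias2", 1.0)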
Example #30
    def test_getters(self):
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        trueqid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }

        entity_symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=truealias2qids,
            qid2title=trueqid2title,
        )

        self.assertEqual(entity_symbols.get_qid(1), "Q1")
        self.assertSetEqual(
            set(entity_symbols.get_all_aliases()),
            {"alias1", "multi word alias2", "alias3", "alias4"},
        )
        self.assertEqual(entity_symbols.get_eid("Q3"), 3)
        self.assertListEqual(entity_symbols.get_qid_cands("alias1"), ["Q1", "Q4"])
        self.assertListEqual(
            entity_symbols.get_qid_cands("alias1", max_cand_pad=True),
            ["Q1", "Q4", "-1"],
        )
        self.assertListEqual(
            entity_symbols.get_eid_cands("alias1", max_cand_pad=True), [1, 4, -1]
        )
        self.assertEqual(entity_symbols.get_title("Q1"), "alias1")
        self.assertEqual(entity_symbols.get_alias_idx("alias1"), 0)
        self.assertEqual(entity_symbols.get_alias_from_idx(1), "alias3")
        self.assertEqual(entity_symbols.alias_exists("alias3"), True)
        self.assertEqual(entity_symbols.alias_exists("alias5"), False)
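
The max_cand_pad flag pads short candidate lists out to max_candidates width, with "-1" for QIDs and -1 for eids. A hypothetical standalone helper (not the library's implementation) showing the padding logic:

def pad_cands(cands, max_candidates, pad_value):
    # Pad a candidate list to fixed width, as max_cand_pad=True does above.
    return cands + [pad_value] * (max_candidates - len(cands))

print(pad_cands(["Q1", "Q4"], 3, "-1"))  # ["Q1", "Q4", "-1"]
print(pad_cands([1, 4], 3, -1))          # [1, 4, -1]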