def setUp(self):
    entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
    self.entity_symbols = EntitySymbols.load_from_cache(
        entity_dump_dir, alias_cand_map_file="alias2qids.json"
    )
    self.config = {
        "data_config": {
            "train_in_candidates": False,
            "entity_dir": "test/data/entity_loader/entity_data",
            "entity_prep_dir": "prep",
            "alias_cand_map": "alias2qids.json",
            "max_aliases": 3,
            "data_dir": "test/data/entity_loader",
            "overwrite_preprocessed_data": True,
        },
        "run_config": {"distributed": False},
    }
def main():
    args = parse_args()
    logging.info(json.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir)
    )
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(
        save_dir,
        train_file,
        entity_symbols,
        args.lower,
        args.strip,
        num_workers=args.num_workers,
    )
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )
    print("DO LOWERCASE IS", "uncased" in args.bert_model)
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case="uncased" in args.bert_model,
        cache_dir=args.word_model_cache,
    )
    model = BertModel.from_pretrained(
        args.bert_model,
        cache_dir=args.word_model_cache,
        output_attentions=False,
        output_hidden_states=False,
    )
    if not args.cpu:
        model = model.to("cuda")
    model.eval()

    entity2avgtitle = build_title_table(
        args.cpu, args.batch_size, model, tokenizer, entity_symbols
    )
    save_fold = os.path.dirname(args.save_file)
    if len(save_fold) > 0:
        if not os.path.exists(save_fold):
            os.makedirs(save_fold)

    if args.output_method == "pt":
        save_obj = (entity_symbols.get_qid2eid(), entity2avgtitle)
        torch.save(obj=save_obj, f=args.save_file)
    else:
        res = {}
        for qid in tqdm(entity_symbols.get_all_qids(), desc="Building final json"):
            eid = entity_symbols.get_eid(qid)
            res[qid] = entity2avgtitle[eid].tolist()
        with open(args.save_file, "w") as out_f:
            ujson.dump(res, out_f)
    print("Done!")
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    num_processes = min(args.processes, int(0.8 * multiprocessing.cpu_count()))
    print("Loading entity symbols")
    entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )
    in_file = os.path.join(args.data_dir, args.train_file)
    print(f"Getting slice counts from {in_file}")
    qid_cnts = get_counts(num_processes, in_file)
    with open(os.path.join(args.data_dir, "qid_cnts_train.json"), "w") as out_f:
        ujson.dump(qid_cnts, out_f)
    df = build_reg_csv(qid_cnts, entity_symbols)
    df.to_csv(args.out_file, index=False)
    print(f"Saved file to {args.out_file}")
def setUp(self): """ENTITY SYMBOLS { "multi word alias2":[["Q2",5.0],["Q1",3.0],["Q4",2.0]], "alias1":[["Q1",10.0],["Q4",6.0]], "alias3":[["Q1",30.0]], "alias4":[["Q4",20.0],["Q3",15.0],["Q2",1.0]] } EMBEDDINGS { "key": "learned", "freeze": false, "load_class": "LearnedEntityEmb", "args": { "learned_embedding_size": 10, } }, { "key": "learned_type", "load_class": "LearnedTypeEmb", "freeze": false, "args": { "type_labels": "type_pred_mapping.json", "max_types": 1, "type_dim": 5, "merge_func": "addattn", "attn_hidden_size": 5 } } """ self.args = parser_utils.parse_boot_and_emm_args( "test/run_args/test_model_training.json") self.entity_symbols = EntitySymbols.load_from_cache( os.path.join(self.args.data_config.entity_dir, self.args.data_config.entity_map_dir), alias_cand_map_file=self.args.data_config.alias_cand_map, ) emmental.init(log_dir="test/temp_log") if not os.path.exists(emmental.Meta.log_path): os.makedirs(emmental.Meta.log_path)
def setUp(self):
    # Tests that the sampling is done correctly on indices
    # Load data from directory
    self.args = parser_utils.parse_boot_and_emm_args(
        "test/run_args/test_type_data.json"
    )
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-cased",
        do_lower_case=False,
        cache_dir="test/data/emb_data/pretrained_bert_models",
    )
    self.is_bert = True
    self.entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(
            self.args.data_config.entity_dir, self.args.data_config.entity_map_dir
        ),
        alias_cand_map_file=self.args.data_config.alias_cand_map,
    )
    self.temp_file_name = "test/data/data_loader/test_data.jsonl"
    self.guid_dtype = lambda max_aliases: np.dtype(
        [
            ("sent_idx", "i8", 1),
            ("subsent_idx", "i8", 1),
            ("alias_orig_list_pos", "i8", max_aliases),
        ]
    )
def test_filter_qids(self):
    entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
    entity_db = EntitySymbols.load_from_cache(
        load_dir=entity_dump_dir, alias_cand_map_file="alias2qids.json"
    )
    qid2count = {"Q1": 10, "Q2": 20, "Q3": 2, "Q4": 4}
    perc_emb_drop = 0.8

    gold_qid2topk_eid = {"Q1": 2, "Q2": 1, "Q3": 2, "Q4": 2}
    gold_old2new_eid = {0: 0, -1: -1, 2: 1, 3: 2}
    gold_new_toes_eid = 2
    gold_num_topk_entities = 2

    (
        qid2topk_eid,
        old2new_eid,
        new_toes_eid,
        num_topk_entities,
    ) = filter_qids(perc_emb_drop, entity_db, qid2count)
    self.assertEqual(gold_qid2topk_eid, qid2topk_eid)
    self.assertEqual(gold_old2new_eid, old2new_eid)
    self.assertEqual(gold_new_toes_eid, new_toes_eid)
    self.assertEqual(gold_num_topk_entities, num_topk_entities)
def test_create_entities(self):
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    trueqid2title = {
        "Q1": "alias1",
        "Q2": "multi alias2",
        "Q3": "word alias3",
        "Q4": "nonalias4",
    }
    # The non-candidate class is included in entity_dump
    trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    truealiastrie = {"multi word alias2": 0, "alias1": 1, "alias3": 2, "alias4": 3}
    entity_symbols = EntitySymbols(
        max_candidates=3,
        alias2qids=truealias2qids,
        qid2title=trueqid2title,
    )
    tri_as_dict = {}
    for k in entity_symbols._alias_trie:
        tri_as_dict[k] = entity_symbols._alias_trie[k]

    self.assertEqual(entity_symbols.max_candidates, 3)
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertDictEqual(entity_symbols._qid2title, trueqid2title)
    self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
    self.assertDictEqual(tri_as_dict, truealiastrie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertIsNone(entity_symbols._qid2aliases)

    # Test load from dump
    temp_save_dir = "test/data/entity_loader_test"
    entity_symbols.save(temp_save_dir)
    entity_symbols = EntitySymbols.load_from_cache(temp_save_dir)

    self.assertEqual(entity_symbols.max_candidates, 3)
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertDictEqual(entity_symbols._qid2title, trueqid2title)
    self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
    self.assertDictEqual(tri_as_dict, truealiastrie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertIsNone(entity_symbols._qid2aliases)
    shutil.rmtree(temp_save_dir)

    # Test edit mode
    entity_symbols = EntitySymbols(
        max_candidates=3,
        alias2qids=truealias2qids,
        qid2title=trueqid2title,
        edit_mode=True,
    )
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3"},
        "Q2": {"multi word alias2", "alias4"},
        "Q3": {"alias4"},
        "Q4": {"alias1", "multi word alias2", "alias4"},
    }
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
def create_task(args, entity_symbols=None, slice_datasets=None):
    """Returns an EmmentalTask for named entity disambiguation (NED).

    Args:
        args: args
        entity_symbols: entity symbols (default None)
        slice_datasets: slice datasets used in scorer (default None)

    Returns: EmmentalTask for NED
    """
    if entity_symbols is None:
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir=os.path.join(
                args.data_config.entity_dir, args.data_config.entity_map_dir
            ),
            alias_cand_map_file=args.data_config.alias_cand_map,
            alias_idx_file=args.data_config.alias_idx_map,
        )

    # Create sentence encoder
    bert_model = BertEncoder(
        args.data_config.word_embedding, output_size=args.model_config.hidden_size
    )

    # Gets the tasks that query for the individual embeddings (e.g., word, entity, type, kg)
    # The device dict will store which embedding modules we want on the cpu
    (
        embedding_task_flows,  # task flows for standard embeddings (e.g., kg, type, entity)
        embedding_module_pool,  # module for standard embeddings
        embedding_module_device_dict,  # module device dict for standard embeddings
        # some embeddings output indices for BERT so we handle these embeddings in our BERT layer
        # (see comments in get_through_bert_embedding_tasks)
        extra_bert_embedding_layers,
        embedding_payload_inputs,  # the layers that are fed into the payload
        embedding_total_sizes,  # total size of all embeddings
    ) = get_embedding_tasks(args, entity_symbols)

    # Add the extra embedding layers to BERT module
    for emb_obj in extra_bert_embedding_layers:
        bert_model.add_embedding(emb_obj)

    # Create the embedding payload, attention network, and prediction layer modules
    if args.model_config.attn_class == "BootlegM2E":
        embedding_payload = EmbeddingPayload(args, entity_symbols, embedding_total_sizes)
        attn_network = BootlegM2E(args, entity_symbols)
        pred_layer = PredictionLayer(args)
    elif args.model_config.attn_class == "Bootleg":
        embedding_payload = EmbeddingPayload(args, entity_symbols, embedding_total_sizes)
        attn_network = Bootleg(args, entity_symbols)
        pred_layer = PredictionLayer(args)
    elif args.model_config.attn_class == "BERTNED":
        # Baseline model
        embedding_payload = EmbeddingPayloadBase(args, entity_symbols, embedding_total_sizes)
        attn_network = BERTNED(args, entity_symbols)
        pred_layer = NoopPredictionLayer(args)
    else:
        raise ValueError(f"{args.model_config.attn_class} is not supported.")

    sliced_scorer = BootlegSlicedScorer(
        args.data_config.train_in_candidates, slice_datasets
    )

    # Create module pool and combine with embedding module pool
    module_pool = nn.ModuleDict(
        {
            BERT_MODEL_NAME: bert_model,
            "embedding_payload": embedding_payload,
            "attn_network": attn_network,
            PRED_LAYER: pred_layer,
        }
    )
    module_pool.update(embedding_module_pool)

    # Create task flow
    task_flow = [
        {
            "name": BERT_MODEL_NAME,
            "module": BERT_MODEL_NAME,
            # We pass the entity_cand_eids to BERT in case of embeddings that require word information
            "inputs": [
                ("_input_", "entity_cand_eid"),
                ("_input_", "token_ids"),
            ],
        },
        *embedding_task_flows,  # Add task flows to create embedding inputs
        {
            "name": "embedding_payload",
            "module": "embedding_payload",  # outputs: embedding_tensor
            "inputs": [
                ("_input_", "start_span_idx"),
                ("_input_", "end_span_idx"),
                *embedding_payload_inputs,  # all embeddings
            ],
        },
        {
            "name": "attn_network",
            "module": "attn_network",  # output: predictions from layers, output entity embeddings
            "inputs": [
                (BERT_MODEL_NAME, 0),  # sentence embedding
                (BERT_MODEL_NAME, 1),  # sentence embedding mask
                ("embedding_payload", 0),
                ("_input_", "entity_cand_eid_mask"),
                ("_input_", "start_span_idx"),
                ("_input_", "end_span_idx"),
                # special kg adjacency embedding prepped in dataloader
                ("_input_", "batch_on_the_fly_kg_adj"),
            ],
        },
        {
            "name": PRED_LAYER,
            "module": PRED_LAYER,
            "inputs": [
                # output predictions from intermediate layers from the model
                ("attn_network", "intermed_scores"),
                # output entity embeddings (from all KG modules)
                ("attn_network", "ent_embs"),
                # score (empty except for baseline model)
                ("attn_network", "final_scores"),
            ],
        },
    ]
    return EmmentalTask(
        name=NED_TASK,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=disambig_loss,
        output_func=disambig_output,
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        # action_outputs are used to stitch together sentence fragments
        action_outputs=[
            ("_input_", "sent_idx"),
            ("_input_", "subsent_idx"),
            ("_input_", "alias_orig_list_pos"),
            ("_input_", "for_dump_gold_cand_K_idx_train"),
            (PRED_LAYER, "ent_embs"),  # entity embeddings
        ],
        scorer=Scorer(
            customize_metric_funcs={f"{NED_TASK}_scorer": sliced_scorer.bootleg_score}
        ),
        module_device=embedding_module_device_dict,
    )
def init_process(entity_dump_f):
    global ed_global
    ed_global = EntitySymbols.load_from_cache(load_dir=entity_dump_f)
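# Sketch of the intended pattern (assumed from the initializer above): each
# pool worker loads EntitySymbols once at startup via the Pool initializer,
# rather than receiving a pickled copy with every task. The entity dump path
# is the one used by the tests in this section.
import multiprocessing

pool = multiprocessing.Pool(
    processes=4,
    initializer=init_process,
    initargs=("test/data/entity_loader/entity_data/entity_mappings",),
)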
def compress_topk_embeddings(args):
    assert 0 < args.perc_emb_drop < 1, "perc_emb_drop must be between 0 and 1"
    print(
        f"Loading entity symbols from {os.path.join(args.entity_dir, 'entity_mappings')}"
    )
    entity_db = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, "entity_mappings")
    )
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print("Filtering qids")
    (
        qid2topk_eid,
        old2new_eid,
        toes_eid,
        new_num_topk_entities,
    ) = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    if len(args.model_path) > 0:
        assert (
            args.save_model_path is not None and len(args.save_model_path) > 0
        ), "If you give a model path, you must give a save model path"
        print("Filtering embeddings")
        state_dict, model_state_dict = load_statedict(args.model_path)
        try:
            get_nested_item(model_state_dict, ENTITY_EMB_KEYS)
        except Exception:
            print(f"ERROR: Not all of {ENTITY_EMB_KEYS} are in model_state_dict")
            raise
        model_state_dict = filter_embs(
            new_num_topk_entities,
            entity_db,
            old2new_eid,
            qid2topk_eid,
            toes_eid,
            model_state_dict,
        )
        # Generate the new old2new_eid weight vector to save in model_state_dict
        oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
        # The +2 is to account for pads and unks. The -1 is as there are issues with -1 in the indexing
        # for entity embeddings. So we must manually make it the last entry
        oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
        for qid, new_eid in tqdm(qid2topk_eid.items(), desc="Setting new ids"):
            old_eid = entity_db.get_eid(qid)
            oldeid2topkeid[old_eid] = new_eid
        assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
        assert (
            oldeid2topkeid[-1] == new_num_topk_entities + 2 - 1
        ), "The -1 eid should still map to the last row"
        model_state_dict = set_nested_item(
            model_state_dict, ENTITY_EID_KEYS, oldeid2topkeid
        )
        # Remove the eid2reg value as that was with the old entity id mapping
        try:
            model_state_dict = set_nested_item(model_state_dict, ENTITY_REG_KEYS, None)
        except Exception:
            print(
                "Could not remove regularization. If your model was trained with regularization mapping on "
                "the learned entity embedding, this should not happen."
            )
        print(model_state_dict["module_pool"]["learned"].keys())
        state_dict["model"] = model_state_dict
        print(f"Saving model at {args.save_model_path}")
        torch.save(state_dict, args.save_model_path)
    print(
        f"Saving topk to eid at {os.path.join(args.entity_dir, 'entity_mappings', args.save_qid2topk_file)}"
    )
    utils.dump_json_file(
        os.path.join(args.entity_dir, "entity_mappings", args.save_qid2topk_file),
        qid2topk_eid,
    )
    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            os.path.join(args.entity_dir, "entity_mappings", args.save_qid2topk_file),
            args.perc_emb_drop,
        )
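# Worked example of the remapping above, using the values from
# test_filter_qids in this section: with the 4-entity test dump, qid2count
# {"Q1": 10, "Q2": 20, "Q3": 2, "Q4": 4}, and perc_emb_drop=0.8, filter_qids
# keeps 2 rows. Q2 (the most frequent qid) keeps a dedicated row (eid 1),
# while Q1, Q3, and Q4 collapse onto the shared "toes" row (eid 2); old eids
# are remapped as {0: 0, -1: -1, 2: 1, 3: 2}.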
def run_model(mode, config, run_config_path=None):
    """Main run method for Emmental Bootleg models.

    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns: scores for train/eval modes; (final_result_file, final_out_emb_file)
    for dump modes
    """
    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, "Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(
            config.data_config.entity_dir, config.data_config.entity_map_dir
        ),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )

    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    "When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, "Starting NED-Base Model")
        assert (
            config.data_config.type_prediction.use_type_pred is False
        ), "NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), "NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, "Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols, slice_datasets)
            )
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model, requires_grad=True, logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model, requires_grad=False, logger=logger)
        log_rank_0_info(logger, f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (
            config["learner_config"]["local_rank"] == -1
            and config["model_config"]["dataparallel"]
        ):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt", scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv", scores
            )
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), "We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename)
        )
        log_rank_0_debug(logger, "Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)

        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
            eval_utils.batched_pred_iter(
                model,
                dataloaders[0],
                config.run_config.eval_accumulation_steps,
                sentidx2num_mentions,
            )
        ):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0
            for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (
            f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
        )
        empty_sentidx2row = {
            k: v for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(
            number_dumped_batches, subeval_folder
        )
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger, "Finished dumping. Merging results across accumulation steps."
        )
        # Final result files for labels and embeddings
        final_result_file = os.path.join(
            eval_folder, config.run_config.result_label_file
        )
        # Copy labels
        output = open(final_result_file, "wb")
        for file in all_result_files:
            shutil.copyfileobj(open(file, "rb"), output)
        output.close()
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file
            )
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"If that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for i, npfile in enumerate(all_out_emb_files):
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
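# Usage sketch (an assumption, not part of the source): dumping predictions
# and entity embeddings with a parsed config. The config path is hypothetical;
# parse_boot_and_emm_args is the parser used by the tests in this section.
config = parser_utils.parse_boot_and_emm_args("configs/my_run_config.json")
result_file, emb_file = run_model("dump_embs", config)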
def test_filter_embs(self):
    entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
    entity_db = EntitySymbols.load_from_cache(
        load_dir=entity_dump_dir,
        alias_cand_map_file="alias2qids.json",
        alias_idx_file="alias2id.json",
    )
    num_topk_entities = 2
    old2new_eid = {0: 0, -1: -1, 2: 1, 3: 2}
    qid2topk_eid = {"Q1": 2, "Q2": 1, "Q3": 2, "Q4": 2}
    toes_eid = 2
    state_dict = {
        "module_pool": {
            "learned": {
                "learned_entity_embedding.weight": torch.Tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [2, 2, 2, 2, 2],
                        [3, 3, 3, 3, 3],
                        [4, 4, 4, 4, 4],
                        [5, 5, 5, 5, 5],
                        [0, 0, 0, 0, 0],
                    ]
                )
            }
        }
    }
    gold_state_dict = {
        "module_pool": {
            "learned": {
                "learned_entity_embedding.weight": torch.Tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [3, 3, 3, 3, 3],
                        [4, 4, 4, 4, 4],
                        [0, 0, 0, 0, 0],
                    ]
                )
            }
        }
    }
    new_state_dict = filter_embs(
        num_topk_entities,
        entity_db,
        old2new_eid,
        qid2topk_eid,
        toes_eid,
        state_dict,
    )
    # Walk down the nested dicts and compare the leaf tensors
    gld = gold_state_dict
    nsd = new_state_dict
    keys_to_check = ["module_pool", "learned", "learned_entity_embedding.weight"]
    for k in keys_to_check:
        assert k in nsd
        assert k in gld
        if type(gld[k]) is dict:
            gld = gld[k]
            nsd = nsd[k]
        else:
            assert torch.equal(nsd[k], gld[k])
def load_from_cache(
    cls,
    load_dir,
    edit_mode=False,
    verbose=False,
    no_kg=False,
    no_type=False,
    type_systems_to_load=None,
):
    """Loads a pre-saved profile.

    Args:
        load_dir: load directory
        edit_mode: edit mode flag, default False
        verbose: verbose flag, default False
        no_kg: whether to skip loading KG data, default False
        no_type: whether to skip loading type data, default False.
            If True, this will ignore type_systems_to_load.
        type_systems_to_load: list of type systems to load; default None loads all type systems

    Returns: entity profile object
    """
    # Check type system input
    load_dir = Path(load_dir)
    type_subfolder = load_dir / TYPE_SUBFOLDER
    if type_systems_to_load is not None:
        if not isinstance(type_systems_to_load, list):
            raise ValueError(
                f"`type_systems` must be a list of subfolders in {type_subfolder}"
            )
        for sys in type_systems_to_load:
            if sys not in [p.name for p in type_subfolder.iterdir()]:
                raise ValueError(
                    f"`type_systems` must be a list of subfolders in {type_subfolder}. {sys} is not one."
                )

    if verbose:
        print("Loading Entity Symbols")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir / ENTITY_SUBFOLDER,
        edit_mode=edit_mode,
        verbose=verbose,
    )
    if no_type:
        print(
            "Not loading type information. We will act as if there are no types associated with any entity "
            "and will not modify the types in any way, even if calling `add`."
        )
    type_sys_dict = {}
    for fold in type_subfolder.iterdir():
        if (
            (not no_type)
            and (type_systems_to_load is None or fold.name in type_systems_to_load)
            and (fold.is_dir())
        ):
            if verbose:
                print(f"Loading Type Symbols from {fold}")
            type_sys_dict[fold.name] = TypeSymbols.load_from_cache(
                type_subfolder / fold.name,
                edit_mode=edit_mode,
                verbose=verbose,
            )
    if verbose:
        print("Loading KG Symbols")
    if no_kg:
        print(
            "Not loading KG information. We will act as if there are no KG connections between entities. "
            "We will not modify the KG information in any way, even if calling `add`."
        )
    kg_symbols = None
    if not no_kg:
        kg_symbols = KGSymbols.load_from_cache(
            load_dir / KG_SUBFOLDER,
            edit_mode=edit_mode,
            verbose=verbose,
        )
    return cls(entity_symbols, type_sys_dict, kg_symbols, edit_mode, verbose)
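# Usage sketch (assumptions: this classmethod lives on an EntityProfile class,
# and "data/entity_db" is a hypothetical profile directory): load a profile
# with no KG information and only the "wiki" type system.
profile = EntityProfile.load_from_cache(
    "data/entity_db",
    no_kg=True,
    type_systems_to_load=["wiki"],
)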
def __init__(
    self,
    config=None,
    device=None,
    max_alias_len=6,
    cand_map=None,
    threshold=0.0,
    cache_dir=None,
    model_name=None,
    verbose=False,
):
    self.max_alias_len = max_alias_len
    self.verbose = verbose
    # Minimum probability a prediction must have for the mention to be returned
    self.threshold = threshold

    if not cache_dir:
        self.cache_dir = get_default_cache()
    else:
        self.cache_dir = Path(cache_dir)
    self.model_path = self.cache_dir / "models"
    self.data_path = self.cache_dir / "data"

    if not model_name:
        model_name = "bootleg_uncased"

    assert model_name in {
        "bootleg_cased",
        "bootleg_cased_mini",
        "bootleg_uncased",
        "bootleg_uncased_mini",
    }, (
        f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
        f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}."
    )

    if not config:
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.model_path.mkdir(parents=True, exist_ok=True)
        self.data_path.mkdir(parents=True, exist_ok=True)
        create_sources(self.model_path, self.data_path, model_name)
        self.config = create_config(self.model_path, self.data_path, model_name)
    else:
        if "emmental" in config:
            config = parse_boot_and_emm_args(config)
        self.config = config
        # Ensure some of the critical annotator args are the correct type
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases
        )
        self.config.run_config.eval_batch_size = int(
            self.config.run_config.eval_batch_size
        )
        self.config.data_config.max_seq_len = int(
            self.config.data_config.max_seq_len
        )
        self.config.data_config.train_in_candidates = bool(
            self.config.data_config.train_in_candidates
        )

    if not device:
        device = 0 if torch.cuda.is_available() else -1

    if self.verbose:
        self.config.run_config.log_level = "DEBUG"
    else:
        self.config.run_config.log_level = "INFO"

    self.torch_device = torch.device(device) if device != -1 else torch.device("cpu")
    self.config.model_config.device = device

    log_level = logging.getLevelName(self.config["run_config"]["log_level"].upper())
    emmental.init(
        log_dir=self.config["meta_config"]["log_path"],
        config=self.config,
        use_exact_log_path=self.config["meta_config"]["use_exact_log_path"],
        level=log_level,
    )

    logger.debug("Reading entity database")
    self.entity_db = EntitySymbols.load_from_cache(
        os.path.join(
            self.config.data_config.entity_dir,
            self.config.data_config.entity_map_dir,
        ),
        alias_cand_map_file=self.config.data_config.alias_cand_map,
        alias_idx_file=self.config.data_config.alias_idx_map,
    )
    logger.debug("Reading word tokenizers")
    self.tokenizer = BertTokenizer.from_pretrained(
        self.config.data_config.word_embedding.bert_model,
        do_lower_case="uncased" in self.config.data_config.word_embedding.bert_model,
        cache_dir=self.config.data_config.word_embedding.cache_dir,
    )

    # Create tasks
    tasks = [NED_TASK]
    if self.config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)
    self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

    # Create model and add tasks
    self.model = EmmentalModel(name="Bootleg")
    self.model.add_task(ned_task.create_task(self.config, self.entity_db))
    if TYPE_PRED_TASK in tasks:
        self.model.add_task(type_pred_task.create_task(self.config, self.entity_db))
        # Add the mention type embedding to the embedding payload
        type_pred_task.update_ned_task(self.model)

    logger.debug("Loading model")
    # Load the best model from the pretrained model
    assert (
        self.config["model_config"]["model_path"] is not None
    ), "Must have a model to load in the model_path for the BootlegAnnotator"
    self.model.load(self.config["model_config"]["model_path"])
    self.model.eval()

    if cand_map is None:
        alias_map = self.entity_db.get_alias2qids()
    else:
        logger.debug("Loading candidate map")
        alias_map = ujson.load(open(cand_map))
    self.all_aliases_trie = get_all_aliases(alias_map, verbose)

    logger.debug("Reading in alias table")
    self.alias2cands = AliasEntityTable(
        data_config=self.config.data_config, entity_symbols=self.entity_db
    )

    # get batch_on_the_fly embeddings
    self.batch_on_the_fly_embs = get_dataloader_embeddings(self.config, self.entity_db)
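# Usage sketch for the annotator above (the class name BootlegAnnotator comes
# from its own assert message; label_mentions is assumed to be its labeling
# entry point): download/cache the default uncased model and label a sentence.
ann = BootlegAnnotator(model_name="bootleg_uncased")
preds = ann.label_mentions("Abraham Lincoln was the 16th president")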
def create_task(args, entity_symbols=None, slice_datasets=None):
    """Creates a type prediction task.

    Args:
        args: args
        entity_symbols: entity symbols
        slice_datasets: slice datasets used in scorer (default None)

    Returns: EmmentalTask for type prediction
    """
    if entity_symbols is None:
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir=os.path.join(
                args.data_config.entity_dir, args.data_config.entity_map_dir
            ),
            alias_cand_map_file=args.data_config.alias_cand_map,
            alias_idx_file=args.data_config.alias_idx_map,
        )

    # Create sentence encoder
    bert_model = BertEncoder(
        args.data_config.word_embedding, output_size=args.model_config.hidden_size
    )

    # Create type prediction module
    # Add 1 for pad type
    type_prediction = TypePred(
        args.model_config.hidden_size,
        args.data_config.type_prediction.dim,
        args.data_config.type_prediction.num_types + 1,
        embedding_utils.get_max_candidates(entity_symbols, args.data_config),
    )

    # Create scorer
    sliced_scorer = BootlegSlicedScorer(
        args.data_config.train_in_candidates, slice_datasets
    )

    # Create module pool
    # BERT model will be shared across tasks as long as the name matches
    module_pool = nn.ModuleDict(
        {BERT_MODEL_NAME: bert_model, "type_prediction": type_prediction}
    )

    # Create task flow
    task_flow = [
        {
            "name": BERT_MODEL_NAME,
            "module": BERT_MODEL_NAME,
            # We pass the entity_cand_eids to BERT in case of embeddings that require word information
            "inputs": [
                ("_input_", "entity_cand_eid"),
                ("_input_", "token_ids"),
            ],
        },
        {
            "name": "type_prediction",
            "module": "type_prediction",  # output: embedding_dict, batch_type_pred
            "inputs": [
                (BERT_MODEL_NAME, 0),  # sentence embedding
                ("_input_", "start_span_idx"),
            ],
        },
    ]
    return EmmentalTask(
        name=TYPE_PRED_TASK,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=partial(type_loss, "type_prediction"),
        output_func=partial(type_output, "type_prediction"),
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        scorer=Scorer(
            customize_metric_funcs={
                f"{TYPE_PRED_TASK}_scorer": sliced_scorer.type_pred_score
            }
        ),
    )
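# Sketch mirroring run_model above (config, entity_symbols, and slice_datasets
# prepared as in that function): the type prediction task shares the BERT
# encoder with the NED task because both register it under the same
# BERT_MODEL_NAME module key.
model = EmmentalModel(name="Bootleg")
model.add_task(ned_task.create_task(config, entity_symbols, slice_datasets))
model.add_task(create_task(config, entity_symbols, slice_datasets))
# Add the mention type embedding to the embedding payload
type_pred_task.update_ned_task(model)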
def main():
    gl_start = time.time()
    multiprocessing.set_start_method("spawn")
    args = get_arg_parser().parse_args()
    print(json.dumps(vars(args), indent=4))
    utils.ensure_dir(args.data_dir)

    out_dir = os.path.join(args.data_dir, args.out_subdir)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # Reading in files
    in_files_train = glob.glob(os.path.join(args.data_dir, "*.jsonl"))
    in_files_cand = glob.glob(os.path.join(args.contextual_cand_data, "*.jsonl"))
    assert len(in_files_train) > 0, f"We didn't find any train files at {args.data_dir}"
    assert (
        len(in_files_cand) > 0
    ), f"We didn't find any contextual files at {args.contextual_cand_data}"
    in_files = []
    for file in in_files_train:
        file_name = os.path.basename(file)
        tag = os.path.splitext(file_name)[0]
        is_train = "train" in tag
        if is_train:
            print(f"{file_name} is a training dataset...will be processed as such")
        pair = None
        for f in in_files_cand:
            if tag in f:
                pair = f
                break
        assert pair is not None, f"{file_name} name, {tag} tag"
        out_file = os.path.join(out_dir, file_name)
        in_files.append([file, pair, out_file, is_train])

    final_cand_map = {}
    max_cands = 0
    for pair in in_files:
        print(f"Reading in {pair[0]} with cand maps {pair[1]} and dumping to {pair[2]}")
        new_alias2qids = merge_data(
            args.processes,
            args.train_in_candidates,
            args.keep_orig,
            args.max_candidates,
            pair,
            args.entity_dump,
        )
        for al in new_alias2qids:
            assert al not in final_cand_map, f"{al} is already in final_cand_map"
            final_cand_map[al] = new_alias2qids[al]
            max_cands = max(max_cands, len(final_cand_map[al]))

    print("Building new entity symbols")
    entity_dump = EntitySymbols.load_from_cache(load_dir=args.entity_dump)
    entity_dump_new = EntitySymbols(
        max_candidates=max_cands,
        alias2qids=final_cand_map,
        qid2title=entity_dump.get_qid2title(),
    )
    out_dir = os.path.join(out_dir, "entity_db/entity_mappings")
    entity_dump_new.save(out_dir)
    print(f"Finished in {time.time() - gl_start}s")