def main(args):
    """Train and/or evaluate a single-task Emmental model with augmentation.

    Steps:
      1. Initialize Emmental from ``args``.
      2. Log the command line and config into the log directory.
      3. Build the dataloaders.
      4. If ``args.augment_policy`` is set, attach an ``Augmentation`` transform
         to every dataloader whose split is in ``args.train_split``.
      5. Install an ``AugScheduler`` as the task scheduler.
      6. Build the model and optionally load weights from
         ``model_config.model_path``.
      7. Train if ``args.train`` is set.
      8. Strip all transforms, score the model, and save the metrics and the
         final model.

    Args:
        args: Parsed command-line arguments.
    """
    # Initialize Emmental
    config = parse_args_to_config(args)
    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    # Log configuration into files
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)

    logger.info(f"Config: {emmental.Meta.config}")
    write_to_file(f"{emmental.Meta.log_path}/config.txt", emmental.Meta.config)

    # Create dataloaders
    dataloaders = get_dataloaders(args)

    # Assign augmentation transforms to the training dataloaders only
    if args.augment_policy:
        for dataloader in dataloaders:
            if dataloader.split in args.train_split:
                dataloader.dataset.transform_cls = Augmentation(args=args)

    # Install the augmentation scheduler in both the local config and the live
    # Emmental config. Both are updated, presumably because Meta.config is a
    # separate object from ``config`` after init (TODO confirm).
    task_scheduler = AugScheduler(
        augment_k=args.augment_k, enlarge=args.augment_enlarge
    )
    config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = task_scheduler
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = task_scheduler

    # Create tasks
    model = EmmentalModel(name=f"{args.task}_task")
    model.add_task(create_task(args))

    # Set cudnn benchmark
    cudnn.benchmark = True

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # Remove all extra augmentation policy before final scoring
    for dataloader in dataloaders:
        dataloader.dataset.transform_cls = None

    scores = model.score(dataloaders)

    # Save metrics and models
    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
    model.save(f"{emmental.Meta.log_path}/last_model.pth")
def load_best_model(self, model: EmmentalModel) -> EmmentalModel:
    """Return ``model``, loading the best checkpoint's weights if one exists.

    Only the first key of ``self.checkpoint_metric`` is consulted. If that
    metric has no entry in ``self.best_metric_dict``, the model is returned
    as-is. Otherwise, weights are loaded from
    ``{checkpoint_path}/best_model_{metric}.model.pth``, with every ``/`` in
    the metric name replaced by ``_``.

    Args:
      model: The current model.

    Returns:
      The same model object, with the best weights loaded when available.
    """
    primary_metric = list(self.checkpoint_metric.keys())[0]

    if primary_metric in self.best_metric_dict:
        safe_metric_name = primary_metric.replace('/', '_')
        checkpoint_file = (f"{self.checkpoint_path}/best_model_"
                           f"{safe_metric_name}.model.pth")
        model.load(checkpoint_file, verbose=False)
        logger.info(f"Loading the best model from {checkpoint_file}.")
    else:
        logger.info("No best model found, use the original model.")

    return model
def predict_and_write(task_name, path, data_dir, submit_subdir, batch_size):
    """Load a saved SuperGLUE model and write its submission file.

    The BERT model name and max sequence length are recovered from ``path``
    via ``extract_from_cmd``. As a sanity check, the validation score is
    logged. Predictions on the test dataloader are then written to
    ``{submit_subdir}/{task_name}.jsonl``.

    Args:
        task_name: SuperGLUE task name (a key of ``build_model``).
        path: Path to the saved model state.
        data_dir: Directory containing the task data.
        submit_subdir: Directory where the ``.jsonl`` submission is written.
        batch_size: Batch size for the dataloaders.

    Raises:
        UnboundLocalError: Re-raised when the model state cannot be loaded.
    """
    bert_model_name, max_seq_len = extract_from_cmd(path)
    msg = (
        f"Using {bert_model_name} and max_sequence_len={max_seq_len} for task "
        f"{task_name}")
    logger.info(msg)

    # Build model
    task = build_model[task_name](bert_model_name)
    model = EmmentalModel(name=f"SuperGLUE_{task_name}", tasks=[task])
    try:
        model.load(path)
    except UnboundLocalError:
        msg = (
            "Failed to load state dict; confirm that your model was saved with "
            "a command such as 'torch.save(model.state_dict(), PATH)'")
        # Use the module logger, consistent with the rest of this function.
        logger.error(msg)
        raise

    # Build dataloaders
    # NOTE(review): [0] is assumed to be "val" and [-1] "test" below, i.e.
    # get_dataloaders is assumed to preserve the order of ``splits``.
    dataloaders = get_dataloaders(
        data_dir,
        task_name=task_name,
        splits=["val", "test"],  # TODO: replace with ['split'] and update below
        max_data_samples=None,
        max_sequence_length=max_seq_len,
        tokenizer_name=bert_model_name,
        batch_size=batch_size,
        uid="uids",
    )

    # TEMP: Sanity check val performance
    logger.info(f"Valid score: {model.score(dataloaders[0])}")
    # TEMP

    filename = f"{task_name}.jsonl"
    filepath = os.path.join(submit_subdir, filename)
    make_submission_file(model, dataloaders[-1], task_name, filepath)
def main(args):
    """Train and evaluate a (optionally slice-aware) multi-task SuperGLUE model.

    Steps:
      1. Reset global Emmental state and initialize Emmental from ``args``.
      2. Save the command line and config into the log directory.
      3. Build dataloaders and tasks for every task in ``args.task``, adding
         slice tasks when ``args.slices`` is set.
      4. Optionally load a pretrained model.
      5. Train if ``args.train`` is set.
      6. Score the model; slice scores are computed manually when the model
         is not slice-aware.
      7. Save the metrics.
      8. Write one ``{task_name}.jsonl`` submission file per task, using that
         task's own test dataloader.

    Args:
        args: Parsed command-line arguments. ``args.task`` is read but never
            modified.
    """
    # Ensure that global state is fresh
    Meta.reset()

    # Initialize Emmental
    config = parse_arg_to_config(args)
    emmental.init(config["meta_config"]["log_path"], config=config)

    # Save command line argument into file
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(Meta.log_path, "cmd.txt", cmd_msg)

    # Save Emmental config into file
    logger.info(f"Config: {Meta.config}")
    write_to_file(Meta.log_path, "config.txt", Meta.config)

    Meta.config["learner_config"]["global_evaluation_metric_dict"] = {
        f"model/SuperGLUE/{split}/score": partial(superglue_scorer, split=split)
        for split in ["val"]
    }

    # Construct dataloaders and tasks and load slices
    dataloaders = []
    tasks = []
    # Test dataloaders of each task, used for the submission files at the end.
    test_dataloaders = {}
    for task_name in args.task:
        task_dataloaders = get_dataloaders(
            data_dir=args.data_dir,
            task_name=task_name,
            splits=["train", "val", "test"],
            max_sequence_length=args.max_sequence_length,
            max_data_samples=args.max_data_samples,
            tokenizer_name=args.bert_model,
            batch_size=args.batch_size,
            augment=args.augmentations,
        )
        task = models.model[task_name](
            args.bert_model,
            last_hidden_dropout_prob=args.last_hidden_dropout_prob)

        if args.slices:
            logger.info("Initializing task-specific slices")
            # Copy so that merging the general slices below does not mutate
            # the shared registry entry in ``slicing.slice_func_dict``.
            slice_func_dict = dict(slicing.slice_func_dict[task_name])
            # Include general purpose slices
            if args.general_slices:
                logger.info("Including general slices")
                slice_func_dict.update(slicing.slice_func_dict["general"])

            task_dataloaders = slicing.add_slice_labels(
                task_name, task_dataloaders, slice_func_dict)

            slice_tasks = slicing.add_slice_tasks(task_name, task,
                                                  slice_func_dict,
                                                  args.slice_hidden_dim)
            tasks.extend(slice_tasks)
        else:
            tasks.append(task)

        dataloaders.extend(task_dataloaders)
        test_dataloaders[task_name] = [
            d for d in task_dataloaders if d.split == "test"
        ]

    # Build Emmental model
    model = EmmentalModel(name="SuperGLUE", tasks=tasks)

    # Load pretrained model if necessary
    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])

    # Training
    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # If model is slice-aware, slice scores will be calculated from slice heads
    # If model is not slice-aware, manually calculate performance on slices
    if not args.slices:
        # Copy so that appending "general" does not mutate ``args.task``.
        # ``args.task`` is reused for scoring and for the submission loop.
        slice_keys = list(args.task)
        if args.general_slices:
            slice_keys.append("general")
        slice_func_dict = {}
        for k in slice_keys:
            slice_func_dict.update(slicing.slice_func_dict[k])
        scores = slicing.score_slices(model, dataloaders, args.task,
                                      slice_func_dict)
    else:
        scores = model.score(dataloaders)

    # Save metrics into file
    logger.info(f"Metrics: {scores}")
    write_to_file(Meta.log_path, "metrics.txt", scores)

    # Save best metrics into file
    if args.train:
        best_metric_dict = (
            emmental_learner.logging_manager.checkpointer.best_metric_dict)
        logger.info(f"Best metrics: {best_metric_dict}")
        write_to_file(Meta.log_path, "best_metrics.txt", best_metric_dict)

    # Save submission file: one per task, from that task's own test dataloader
    for task_name in args.task:
        task_test_dataloaders = test_dataloaders[task_name]
        assert len(task_test_dataloaders) == 1
        filepath = os.path.join(Meta.log_path, f"{task_name}.jsonl")
        make_submission_file(model, task_test_dataloaders[0], task_name,
                             filepath)
def main(args):
    """Train and/or evaluate a single-task model with BERT-style weight decay.

    Steps:
      1. Initialize Emmental and record the command line and config in the
         log directory.
      2. Install an ``AugScheduler``.
      3. Register a callable that splits parameters into weight-decayed and
         non-decayed groups for the optimizer.
      4. Build the model, optionally loading weights from
         ``model_config.model_path``.
      5. Train if ``args.train`` is set.
      6. Disable augmentation (no transform, ``k = 1``) on every dataset.
      7. Score the model and save the metrics and the final model.

    Args:
        args: Parsed command-line arguments.
    """
    config = parse_args_to_config(args)
    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    # Record how this run was launched and with which configuration
    command_line = " ".join(sys.argv)
    logger.info(f"COMMAND: {command_line}")
    write_to_file(f"{emmental.Meta.log_path}/cmd.txt", command_line)
    logger.info(f"Config: {emmental.Meta.config}")
    write_to_file(f"{emmental.Meta.log_path}/config.txt", emmental.Meta.config)

    dataloaders = get_dataloaders(args)

    # Same scheduler object goes into the local and the live Emmental config
    local_scheduler_cfg = config["learner_config"]["task_scheduler_config"]
    local_scheduler_cfg["task_scheduler"] = AugScheduler(
        augment_k=args.augment_k, enlarge=args.augment_enlarge)
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = local_scheduler_cfg["task_scheduler"]

    def grouped_parameters(model):
        """Split parameters into a decayed group and a zero-decay group.

        Parameters whose name contains "bias" or "LayerNorm.weight" get
        weight decay 0.0. All others use the configured ``l2`` value.
        """
        exempt_markers = ["bias", "LayerNorm.weight"]
        decayed, undecayed = [], []
        for param_name, param in model.named_parameters():
            if any(marker in param_name for marker in exempt_markers):
                undecayed.append(param)
            else:
                decayed.append(param)
        l2_value = emmental.Meta.config["learner_config"]["optimizer_config"][
            "l2"]
        return [
            {"params": decayed, "weight_decay": l2_value},
            {"params": undecayed, "weight_decay": 0.0},
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"][
        "parameters"] = grouped_parameters

    model = EmmentalModel(name=f"{args.task}_task")
    model.add_task(create_task(args))

    pretrained_path = config["model_config"]["model_path"]
    if pretrained_path is not None:
        model.load(pretrained_path)

    if args.train:
        EmmentalLearner().learn(model, dataloaders)

    # Evaluate on the raw data: no transform and a single view per example
    for loader in dataloaders:
        loader.dataset.transform_cls = None
        loader.dataset.k = 1

    scores = model.score(dataloaders)

    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
    model.save(f"{emmental.Meta.log_path}/last_model.pth")
for task_name in args.task: dataloaders += create_dataloaders( task_name, datasets[task_name], args.batch_size, emb_layer.word2id ) tasks = { task_name: create_task( task_name, args, datasets[task_name]["nclasses"], emb_layer ) for task_name in args.task } model = EmmentalModel(name="TC_task") if Meta.config["model_config"]["model_path"]: model.load(Meta.config["model_config"]["model_path"]) else: for task_name, task in tasks.items(): model.add_task(task) emmental_learner = EmmentalLearner() emmental_learner.learn(model, dataloaders) scores = model.score(dataloaders) logger.info(f"Metrics: {scores}") write_to_json_file(f"{Meta.log_path}/metrics.txt", scores) if args.checkpointing: logger.info( f"Best metrics: " f"{emmental_learner.logging_manager.checkpointer.best_metric_dict}"
def test_model(caplog):
    """Unit test of model.

    Covers:
      - construction;
      - update_task with a matching task and with a non-matching task;
      - add_task / remove_task, including removing an absent task;
      - save / load;
      - add_tasks.

    The temporary log directory is removed even if an assertion fails.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model"

    Meta.reset()
    emmental.init(dirpath)

    try:

        def ce_loss(module_name, immediate_output_dict, Y, active):
            return F.cross_entropy(
                immediate_output_dict[module_name][0][active],
                (Y.view(-1))[active])

        def output(module_name, immediate_output_dict):
            return F.softmax(immediate_output_dict[module_name][0], dim=1)

        def make_task(name, hidden_dim):
            """Build a task m1: Linear(10, hidden_dim) -> m2: Linear(hidden_dim, 2)."""
            return EmmentalTask(
                name=name,
                module_pool=nn.ModuleDict({
                    "m1": nn.Linear(10, hidden_dim, bias=False),
                    "m2": nn.Linear(hidden_dim, 2, bias=False),
                }),
                task_flow=[
                    {
                        "name": "m1",
                        "module": "m1",
                        "inputs": [("_input_", "data")]
                    },
                    {
                        "name": "m2",
                        "module": "m2",
                        "inputs": [("m1", 0)]
                    },
                ],
                loss_func=partial(ce_loss, "m2"),
                output_func=partial(output, "m2"),
                scorer=Scorer(metrics=["accuracy"]),
            )

        task1 = make_task("task_1", 10)
        new_task1 = make_task("task_1", 5)
        task2 = make_task("task_2", 5)

        config = {"model_config": {"dataparallel": False}}
        emmental.Meta.update_config(config)

        model = EmmentalModel(name="test", tasks=task1)

        assert repr(model) == "EmmentalModel(name=test)"
        assert model.name == "test"
        assert model.task_names == set(["task_1"])
        assert model.module_pool["m1"].weight.data.size() == (10, 10)
        assert model.module_pool["m2"].weight.data.size() == (2, 10)

        # Updating an existing task replaces its modules
        model.update_task(new_task1)

        assert model.module_pool["m1"].weight.data.size() == (5, 10)
        assert model.module_pool["m2"].weight.data.size() == (2, 5)

        # Updating a task that is not in the model leaves the task set as-is
        model.update_task(task2)

        assert model.task_names == set(["task_1"])

        model.add_task(task2)

        assert model.task_names == set(["task_1", "task_2"])

        model.remove_task("task_1")
        assert model.task_names == set(["task_2"])

        # Removing an absent task is a no-op
        model.remove_task("task_1")
        assert model.task_names == set(["task_2"])

        model.save(f"{dirpath}/saved_model.pth")

        model.load(f"{dirpath}/saved_model.pth")

        # Test add_tasks
        model = EmmentalModel(name="test")

        model.add_tasks([task1, task2])
        assert model.task_names == set(["task_1", "task_2"])
    finally:
        # ignore_errors so a missing dir does not mask the original failure
        shutil.rmtree(dirpath, ignore_errors=True)
], "weight_decay": 0.0, }, ] emmental.Meta.config["learner_config"]["optimizer_config"][ "parameters"] = grouped_parameters # Create tasks model = EmmentalModel(name="TACRED_task") model.add_task(create_task(args)) # Load the best model from the pretrained model if config["model_config"]["model_path"] is not None: model.load(config["model_config"]["model_path"]) if args.train: emmental_learner = EmmentalLearner() emmental_learner.learn(model, dataloaders) # Remove all extra augmentation policy for idx in range(len(dataloaders)): dataloaders[idx].dataset.transform_cls = None scores = model.score(dataloaders) # Save metrics and models logger.info(f"Metrics: {scores}") scores["log_path"] = emmental.Meta.log_path write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
class BootlegAnnotator(object):
    """BootlegAnnotator class: convenient wrapper of preprocessing and model
    eval to allow for annotating single sentences at a time for quick
    experimentation, e.g. in notebooks.

    Args:
        config: model config (default None). When falsy, model and data
            sources are created under ``cache_dir`` and a config is built
            for ``model_name``.
        device: model device, -1 for CPU (default None). When falsy, device 0
            is used if CUDA is available, otherwise -1 (CPU).
        max_alias_len: maximum alias length passed to the mention extraction
            function (default 6)
        cand_map: path to a JSON alias candidate map (default None, which uses
            the entity database's alias-to-QID map)
        threshold: probability threshold; a prediction is kept only if its
            probability is strictly greater than this value (default 0.0)
        cache_dir: cache directory (default None, which uses
            ``get_default_cache()``)
        model_name: model name (default None, which means "bootleg_uncased")
        verbose: verbose boolean (default False)
    """

    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        # Maximum alias length forwarded to ``label_func`` in extract_mentions
        self.max_alias_len = max_alias_len
        self.verbose = verbose
        # Minimum probability of prediction to return mention
        self.threshold = threshold

        self.cache_dir = Path(cache_dir) if cache_dir else get_default_cache()
        self.model_path = self.cache_dir / "models"
        self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        # NOTE(review): ``not device`` is also true for device=0, which is then
        # re-derived (and becomes -1 when CUDA is unavailable).
        if not device:
            device = 0 if torch.cuda.is_available() else -1

        self.config.run_config.log_level = "DEBUG" if self.verbose else "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        bert_model = self.config.data_config.word_embedding.bert_model
        self.tokenizer = BertTokenizer.from_pretrained(
            bert_model,
            do_lower_case="uncased" in bert_model,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred is True:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Create the model and add the tasks to it
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the best model from the pretrained model
        assert (
            self.config["model_config"]["model_path"] is not None
        ), "Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()

        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug("Loading candidate map")
            # Close the file handle once the JSON has been parsed
            with open(cand_map) as cand_map_file:
                alias_map = ujson.load(cand_map_file)

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config,
            entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)

    def extract_mentions(self, text, label_func):
        """Wrapper function for mention extraction.

        Args:
            text: text to extract mentions from
            label_func: function that performs extraction (input is (text,
                alias trie, max alias length) -> output is list of found
                aliases and found spans)

        Returns: JSON object of sentence to be used in eval
        """
        found_aliases, found_spans = label_func(text, self.all_aliases_trie,
                                                self.max_alias_len)
        num_found = len(found_aliases)
        return {
            "sentence": text,
            "aliases": found_aliases,
            "spans": found_spans,
            # we don't know the true QID
            "qids": ["Q-1"] * num_found,
            "gold": [True] * num_found,
        }

    def set_threshold(self, value):
        """Sets threshold.

        Args:
            value: threshold value

        Returns:
        """
        self.threshold = value

    def label_mentions(self, text_list,
                       label_func=find_aliases_in_sentence_tag):
        """Extracts mentions and runs disambiguation.

        Args:
            text_list: list of text to disambiguate (or single sentence)
            label_func: mention extraction function (optional)

        Returns: Dict of

            * ``qids``: final predicted QIDs,
            * ``probs``: final predicted probs,
            * ``titles``: final predicted titles,
            * ``cands``: all entity candidates,
            * ``cand_probs``: probabilities of all candidates,
            * ``spans``: final extracted word spans,
            * ``aliases``: final extracted aliases,

        Each value is a list holding one inner list per input text.
        """
        if isinstance(text_list, str):
            text_list = [text_list]
        else:
            assert (isinstance(text_list, list) and len(text_list) > 0
                    and isinstance(text_list[0], str)
                    ), "We only accept inputs of strings and lists of strings"

        ebs = int(self.config.run_config.eval_batch_size)
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases)
        max_aliases = self.config.data_config.max_aliases
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        # NOTE(review): computed but not included in the returned dict.
        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(
                enumerate(text_list),
                desc="Prepping data",
                total=len(text_list),
                disable=not self.verbose,
        ):
            sample = self.extract_mentions(text, label_func)
            total_start_exs += len(sample["aliases"])
            char_spans = self.get_char_spans(sample["spans"], text)

            final_char_spans.append(char_spans)

            (
                idxs_arr,
                aliases_to_predict_per_split,
                spans_arr,
                phrase_tokens_arr,
                pos_idxs,
            ) = sentence_utils.split_sentence(
                max_aliases=max_aliases,
                phrase=sample["sentence"],
                spans=sample["spans"],
                aliases=sample["aliases"],
                aliases_seen_by_model=list(range(len(sample["aliases"]))),
                seq_len=self.config.data_config.max_seq_len,
                is_bert=True,
                tokenizer=self.tokenizer,
            )
            aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                           for idxs in idxs_arr]
            old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                             for idxs in idxs_arr]
            qids_arr = [[sample["qids"][idx] for idx in idxs]
                        for idxs in idxs_arr]
            word_indices_arr = [
                self.tokenizer.convert_tokens_to_ids(pt)
                for pt in phrase_tokens_arr
            ]
            # iterate over each sample in the split
            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                # NOTE(review): ``>= 0`` is always true, so this assert never
                # fires even though its message describes an empty set.
                assert (
                    len(aliases_to_predict_arr) >= 0
                ), "There are no aliases to predict for an example. This should not happen at this point."
                assert (
                    len(aliases_arr[sub_idx]) <= max_aliases
                ), f"{sample} should have no more than {max_aliases} aliases."

                # Float arrays filled with PAD_ID; real mentions overwrite slots
                example_aliases = np.ones(max_aliases) * PAD_ID
                example_aliases_locs_start = np.ones(max_aliases) * PAD_ID
                example_aliases_locs_end = np.ones(max_aliases) * PAD_ID
                example_alias_list_pos = np.ones(max_aliases) * PAD_ID
                example_true_entities = np.ones(max_aliases) * PAD_ID

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][
                        mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    gold_qid = qids_arr[sub_idx][mention_idx]
                    if gold_qid not in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.config.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Class label = position of the gold QID among the
                        # candidates. When train_in_candidates is False the
                        # labels shift up by one, reserving label 0 for the
                        # "not in candidate set" class.
                        true_entity_idx = np.nonzero(
                            alias_qids == gold_qid)[0][0] + (
                                not self.config.data_config.train_in_candidates)
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So
                    # subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][
                        mention_idx]
                    # leave as PAD_ID if it's not an alias we want to predict;
                    # we get these if we split a sentence and need to only
                    # predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(
                    example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t
                # BERT subword token
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(
            batch_example_aliases_locs_start, device=self.torch_device)
        batch_example_aliases_locs_end = torch.tensor(
            batch_example_aliases_locs_end, device=self.torch_device)
        batch_example_true_entities = torch.tensor(
            batch_example_true_entities, device=self.torch_device)
        batch_word_indices = torch.tensor(batch_word_indices,
                                          device=self.torch_device)

        num_texts = len(text_list)
        final_pred_cands = [[] for _ in range(num_texts)]
        final_all_cands = [[] for _ in range(num_texts)]
        final_cand_probs = [[] for _ in range(num_texts)]
        final_pred_probs = [[] for _ in range(num_texts)]
        final_titles = [[] for _ in range(num_texts)]
        final_spans = [[] for _ in range(num_texts)]
        final_aliases = [[] for _ in range(num_texts)]

        # Fetched once rather than per example; presumed to be a read-only
        # getter on the entity database (TODO confirm).
        alias2qids = self.entity_db.get_alias2qids()
        for b_i in tqdm(
                range(0, batch_example_aliases.shape[0], ebs),
                desc="Evaluating model",
                disable=not self.verbose,
        ):
            start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
            end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                            word_indices, alias_indices)
            # The last batch may hold fewer than ebs examples; size the guids
            # to the actual batch length.
            x_dict["guid"] = torch.arange(b_i,
                                          b_i + alias_indices.shape[0],
                                          device=self.torch_device)

            (uid_bdict, _, prob_bdict, _) = self.model(  # type: ignore
                uids=x_dict["guid"],
                X_dict=x_dict,
                Y_dict=None,
                task_to_label_dict=self.task_to_label_dict,
                return_action_outputs=False,
            )
            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            # recover predictions
            probs = prob_bdict[NED_TASK]
            max_probs = probs.max(2)
            max_probs_indices = probs.argmax(2)
            for ex_i in range(probs.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(
                    self.config.data_config.train_in_candidates,
                    self.config.data_config.max_aliases,
                    alias2qids,
                    batch_aliases_arr[b_i + ex_i],
                )
                # batch size is 1 so we can reshape
                probs_ex = probs[ex_i].reshape(
                    self.config.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(
                        batch_example_true_entities[b_i + ex_i]):
                    # Only slots that hold a mention to predict are non-PAD
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(
                                probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(
                                batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(
                                batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid)
                                if pred_qid != "NC" else "NC")
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, (
            f"Something went wrong and we have predicted fewer mentions than extracted. "
            f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        )
        res_dict = {
            "qids": final_pred_cands,
            "probs": final_pred_probs,
            "titles": final_titles,
            "cands": final_all_cands,
            "cand_probs": final_cand_probs,
            "spans": final_spans,
            "aliases": final_aliases,
        }
        return res_dict

    def get_forward_batch(self, start_span_idx, end_span_idx, token_ids,
                          alias_idx):
        """Preps the forward batch for disambiguation.

        Args:
            start_span_idx: start span tensor
            end_span_idx: end span tensor
            token_ids: word token tensor
            alias_idx: alias index used for extracting candidate eids

        Returns: X_dict used in Emmental
        """
        entity_cand_eid = self.alias2cands(alias_idx).long()
        # -1 marks an empty candidate slot
        entity_cand_eid_mask = entity_cand_eid == -1
        # Replace negative eids with the last entity index so they can be
        # used as indices; the mask above records which slots were empty.
        entity_cand_eid_noneg = torch.where(
            entity_cand_eid >= 0,
            entity_cand_eid,
            (torch.ones_like(entity_cand_eid, dtype=torch.long) *
             (self.entity_db.num_entities_with_pad_and_nocand - 1)),
        )

        kg_prepped_embs = {}
        for emb_key in self.batch_on_the_fly_embs:
            kg_adj = self.batch_on_the_fly_embs[emb_key]["kg_adj"]
            prep_func = self.batch_on_the_fly_embs[emb_key][
                "kg_adj_process_func"]
            batch_prep = []
            for j in range(entity_cand_eid_noneg.shape[0]):
                batch_prep.append(
                    prep_func(entity_cand_eid_noneg[j].cpu(),
                              kg_adj).reshape(1, -1))
            kg_prepped_embs[emb_key] = torch.tensor(batch_prep,
                                                    device=self.torch_device)

        X_dict = {
            "guids": [],
            "start_span_idx": start_span_idx,
            "end_span_idx": end_span_idx,
            "token_ids": token_ids,
            "entity_cand_eid": entity_cand_eid_noneg,
            "entity_cand_eid_mask": entity_cand_eid_mask,
            "batch_on_the_fly_kg_adj": kg_prepped_embs,
        }
        return X_dict

    def get_char_spans(self, spans, text):
        """Helper function to get character spans instead of default word spans.

        Words are obtained with ``text.split()``, and offsets are computed as
        if they were joined by single spaces.

        Args:
            spans: word spans ([start, end) word indices)
            text: text

        Returns: character spans
        """
        query_toks = text.split()
        char_spans = []
        for span in spans:
            # 1 when the span does not start at the first token, else 0
            space_btwn_toks = (len(" ".join(query_toks[0:span[0] + 1])) -
                               len(" ".join(query_toks[0:span[0]])) -
                               len(query_toks[span[0]]))
            char_b = len(" ".join(query_toks[0:span[0]])) + space_btwn_toks
            char_e = char_b + len(" ".join(query_toks[span[0]:span[1]]))
            char_spans.append([char_b, char_e])
        return char_spans
def run_model(mode, config, run_config_path=None):
    """Main run method for Emmental Bootleg models.

    Depending on ``mode`` this trains then scores a model ("train"), scores a
    model on the test split ("eval"), or dumps per-mention predictions and,
    optionally, embeddings for the test split ("dump_preds" / "dump_embs").

    Args:
        mode: run mode, one of "train", "eval", "dump_preds", "dump_embs".
        config: parsed model config.
        run_config_path: original config path (for saving).

    Returns:
        For "train" and "eval": the scores returned by ``model.score`` with an
        added "log_path" entry.
        For "dump_preds" and "dump_embs": a tuple
        ``(final_result_file, final_out_emb_file)``. Both are None on processes
        whose local rank is not 0 or -1; ``final_out_emb_file`` is also None
        for "dump_preds" and when no embedding batches were produced.

    Raises:
        ValueError: if ``mode`` is not a supported run mode, or if a dump mode
            is requested while the test dataset's ``use_weak_label`` is False.
    """
    # Validate the mode before doing any expensive setup, data loading or
    # model construction (previously an unknown mode was only caught by an
    # assert at the very end, which is stripped under ``python -O``).
    valid_modes = ("train", "eval", "dump_preds", "dump_embs")
    if mode not in valid_modes:
        raise ValueError(f"Mode must be one of {valid_modes}, got {mode!r}.")

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, "Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )

    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    "When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, "Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is False
                ), "NED-Base does not support type prediction"
        assert (config.data_config.word_embedding.use_sent_proj is False
                ), "NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, "Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier: only local rank 0 waits here. NOTE(review): presumably the other
    # ranks reach a matching barrier elsewhere (e.g. inside setup) — confirm.
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs.
    # Only "dump_preds" / "dump_embs" can reach this point (mode validated above).
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), "We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, "Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            "Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels: concatenate the per-batch label files byte-for-byte.
        # Every handle is closed deterministically, even if a copy fails.
        with open(final_result_file, "wb") as output:
            for result_path in all_result_files:
                with open(result_path, "rb") as result_in:
                    shutil.copyfileobj(result_in, output)
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"If that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = [np.load(npfile) for npfile in all_out_emb_files]
            if all_arrays:
                np.save(final_out_emb_file, np.concatenate(all_arrays))
                log_rank_0_info(
                    logger, f"Bootleg embeddings saved at {final_out_emb_file}")
            else:
                # np.concatenate raises on an empty list; with no prediction
                # batches there is nothing to merge, so no embedding file exists.
                final_out_emb_file = None
                log_rank_0_info(
                    logger,
                    "No embedding batches were produced; no embedding file saved.")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file