def load_regularization_mapping(cls, data_config, entity_symbols, reg_file):
    """Reads in a csv file with columns [qid, regularization].

    In the forward pass, the entity id with associated qid will be
    regularized with probability ``regularization``. The resulting mapping is
    cached as a ``.pt`` file in the data prep directory and reused on later
    runs unless ``data_config.overwrite_preprocessed_data`` is set.

    Args:
        data_config: data config
        entity_symbols: entity symbols used to map QIDs to EIDs; QIDs in the
            csv that ``entity_symbols`` does not know are skipped
        reg_file: regularization csv file (must have ``qid`` and
            ``regularization`` columns)

    Returns:
        Tensor of length ``entity_symbols.num_entities_with_pad_and_nocand``
        where each value is the regularization value for that EID (0.0 for
        entities not listed in ``reg_file``)
    """
    reg_str = os.path.splitext(os.path.basename(reg_file.replace("/", "_")))[0]
    prep_dir = data_utils.get_data_prep_dir(data_config)
    prep_file = os.path.join(prep_dir, f"entity_regularization_mapping_{reg_str}.pt")
    utils.ensure_dir(os.path.dirname(prep_file))
    log_rank_0_debug(logger, f"Looking for regularization mapping in {prep_file}")
    if not data_config.overwrite_preprocessed_data and os.path.exists(prep_file):
        log_rank_0_debug(
            logger,
            f"Loading existing entity regularization mapping from {prep_file}",
        )
        start = time.time()
        eid2reg = torch.load(prep_file)
        log_rank_0_debug(
            logger,
            f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
        )
    else:
        start = time.time()
        log_rank_0_info(
            logger, f"Building entity regularization mapping from {reg_file}"
        )
        qid2reg = pd.read_csv(reg_file)
        assert (
            "qid" in qid2reg.columns and "regularization" in qid2reg.columns
        ), f"Expected qid and regularization as the column names for {reg_file}"
        # default of no mask
        eid2reg_arr = [0.0] * entity_symbols.num_entities_with_pad_and_nocand
        # Walk the two needed columns directly: DataFrame.iterrows builds a
        # full Series per row, which is very slow for large regularization files.
        for qid, reg in zip(qid2reg["qid"], qid2reg["regularization"]):
            if entity_symbols.qid_exists(qid):
                eid2reg_arr[entity_symbols.get_eid(qid)] = reg
        eid2reg = torch.tensor(eid2reg_arr)
        torch.save(eid2reg, prep_file)
        log_rank_0_debug(
            logger,
            f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
        )
    return eid2reg
def unfreeze_params(self):
    """Make every parameter of this module trainable again.

    Sets ``requires_grad`` to True on each entry of ``named_parameters()``
    and logs the name of each parameter as it is unfrozen.

    Returns: None
    """
    for param_name, parameter in self.named_parameters():
        parameter.requires_grad = True
        log_rank_0_info(logger, f"Unfreezing {param_name}")
def prep(
    cls,
    data_config,
    emb_args,
    entity_symbols,
    threshold,
    log_weight,
):
    """Return the KG adjacency matrix, loading it from the prep cache when possible.

    The cache file lives in the embedding prep directory and is named after
    ``emb_args.kg_adj``. It is rebuilt (via ``cls.build_kg_adj``) and re-saved
    when ``data_config.overwrite_preprocessed_data`` is set or the file is
    missing.

    Args:
        data_config: data config
        emb_args: embedding args (``kg_adj`` names the KG adjacency file
            relative to ``data_config.emb_dir``)
        entity_symbols: entity symbols
        threshold: weight threshold for counting an edge
        log_weight: whether to take the log of the weight value after the threshold

    Returns:
        tuple of (scipy sparse KG adjacency matrix, path of the prep file)
    """
    kg_tag = os.path.splitext(emb_args.kg_adj.replace("/", "_"))[0]
    prep_file = os.path.join(
        data_utils.get_emb_prep_dir(data_config), f"kg_adj_file_{kg_tag}.npz"
    )
    utils.ensure_dir(os.path.dirname(prep_file))

    use_cached = not data_config.overwrite_preprocessed_data and os.path.exists(
        prep_file
    )
    if use_cached:
        log_rank_0_debug(logger, f"Loading existing KG adj from {prep_file}")
        t_begin = time.time()
        kg_adj = scipy.sparse.load_npz(prep_file)
        log_rank_0_debug(
            logger,
            f"Loaded existing KG adj in {round(time.time() - t_begin, 2)}s",
        )
        return kg_adj, prep_file

    t_begin = time.time()
    source_file = os.path.join(data_config.emb_dir, emb_args.kg_adj)
    log_rank_0_info(logger, f"Building KG adj from {source_file}")
    kg_adj = cls.build_kg_adj(source_file, entity_symbols, threshold, log_weight)
    scipy.sparse.save_npz(prep_file, kg_adj)
    log_rank_0_debug(
        logger,
        f"Finished building and saving KG adj in {round(time.time() - t_begin, 2)}s.",
    )
    return kg_adj, prep_file
def setup(config, run_config_path=None):
    """
    Setup distributed backend and save configuration files.

    Also initializes emmental logging, replaces plain stream handlers on the
    root logger with a RichHandler, and logs the command line and git hash.

    Args:
        config: config
        run_config_path: path for original run config; when not None it is
            re-dumped as ``run_config.yaml`` in the log directory

    Returns: None
    """
    # torch.multiprocessing.set_sharing_strategy("file_system")
    # spawn method must be fork to work with Meta.config
    torch.multiprocessing.set_start_method("fork", force=True)
    # Example distributed launch:
    #   ulimit -n 500000
    #   python3 -m torch.distributed.launch --nproc_per_node=2 bootleg/run.py --config_script ...
    log_level = logging.getLevelName(config.run_config.log_level.upper())
    emmental.init(
        log_dir=config["meta_config"]["log_path"],
        config=config,
        use_exact_log_path=config["meta_config"]["use_exact_log_path"],
        local_rank=config.learner_config.local_rank,
        level=log_level,
    )
    log = logging.getLogger()
    # Remove streaming handlers and use rich. The exact type check is
    # deliberate: FileHandler subclasses StreamHandler and must be kept.
    log.handlers = [
        h for h in log.handlers if type(h) is not logging.StreamHandler
    ]
    log.addHandler(RichHandler())
    # Set up distributed backend
    emmental.Meta.init_distributed_backend()

    cmd_msg = " ".join(sys.argv)
    # Log configuration into files
    if config.learner_config.local_rank in [0, -1]:
        write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)
        dump_yaml_file(
            f"{emmental.Meta.log_path}/parsed_config.yaml", emmental.Meta.config
        )
        # Dump the run config (does not contain defaults)
        if run_config_path is not None:
            dump_yaml_file(
                f"{emmental.Meta.log_path}/run_config.yaml",
                load_yaml_file(run_config_path),
            )

    log_rank_0_info(logger, f"COMMAND: {cmd_msg}")
    log_rank_0_info(
        logger, f"Saving config to {emmental.Meta.log_path}/parsed_config.yaml"
    )

    git_hash = "Not able to retrieve git hash"
    try:
        # Decode so the hash is logged as text rather than a bytes repr, and
        # silence git's own error output when not inside a repository.
        git_hash = (
            subprocess.check_output(
                ["git", "log", "-n", "1", "--pretty=tformat:%h-%ad", "--date=short"],
                stderr=subprocess.DEVNULL,
            )
            .decode("utf-8", errors="replace")
            .strip()
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing git executable (FileNotFoundError); the
        # git hash is informational only, so never fail setup over it.
        pass
    log_rank_0_info(logger, f"Git Hash: {git_hash}")
def try_rmtree(rm_dir):
    """Remove ``rm_dir``, retrying when removal fails (e.g. a resource is still open).

    Up to 5 attempts are made, sleeping 1 second after each failed attempt
    (any ``OSError`` counts as a failure). If every attempt fails, a message
    is logged and the function returns without raising.

    Args:
        rm_dir: directory to remove

    Returns: None
    """
    attempts_allowed = 5
    for _attempt in range(attempts_allowed):
        try:
            shutil.rmtree(rm_dir)
        except OSError:
            time.sleep(1)
        else:
            return
    log_rank_0_info(
        logger,
        f"{rm_dir} was not able to be deleted. This is okay but will have to manually be removed.",
    )
def __init__(
    self,
    main_args,
    dataset,
    use_weak_label,
    entity_symbols,
    dataset_threads,
    split="train",
):
    """Build (or load from the prep cache) the slice dataset for ``split``.

    Slice data is prepped in two steps (``create_examples`` then
    ``convert_examples_to_features_and_save``) into a memmap file in the
    save-data folder. That file is reused on later runs unless
    ``data_config.overwrite_preprocessed_data`` is set or it is missing.

    The multiprocessing start method is switched to
    ``main_args.run_config.spawn_method`` while prepping and is always
    restored afterwards, even when prepping raises.

    Args:
        main_args: main args
        dataset: dataset file
        use_weak_label: whether to use weak labeling or not
        entity_symbols: entity symbols (not used directly by this constructor)
        dataset_threads: number of threads used for data prep
        split: data split
    """
    global_start = time.time()
    log_rank_0_info(logger, f"Building slice dataset for {split} from {dataset}.")
    spawn_method = main_args.run_config.spawn_method
    data_config = main_args.data_config
    orig_spawn = multiprocessing.get_start_method()
    multiprocessing.set_start_method(spawn_method, force=True)
    try:
        self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
        self.get_slice_dt = lambda max_a2p: np.dtype(
            [
                ("sent_idx", int),
                ("subslice_idx", int),
                ("alias_slice_incidence", int, (max_a2p,)),
                ("prob_labels", float, (max_a2p,)),
            ]
        )
        self.get_storage = lambda max_a2p: np.dtype(
            [
                (slice_name, self.get_slice_dt(max_a2p))
                for slice_name in self.slice_names
            ]
        )
        # Folder for all mmap saved files
        save_dataset_folder = data_utils.get_save_data_folder(
            data_config, use_weak_label, dataset
        )
        utils.ensure_dir(save_dataset_folder)
        # Folder for temporary output files
        temp_output_folder = os.path.join(
            data_config.data_dir,
            data_config.data_prep_dir,
            f"prep_{split}_slice_files",
        )
        utils.ensure_dir(temp_output_folder)
        # Input step 1
        create_ex_indir = os.path.join(temp_output_folder, "create_examples_input")
        utils.ensure_dir(create_ex_indir)
        # Input step 2
        create_ex_outdir = os.path.join(
            temp_output_folder, "create_examples_output"
        )
        utils.ensure_dir(create_ex_outdir)
        # Meta data saved files
        meta_file = os.path.join(temp_output_folder, "meta_data.json")
        # File for standard training data; keyed on the slice names so that
        # different slice sets do not share a cache file.
        slice_hash = hashlib.sha1(
            str(self.slice_names).encode("UTF-8")
        ).hexdigest()[:10]
        self.save_dataset_name = os.path.join(
            save_dataset_folder, f"ned_slices_{slice_hash}.bin"
        )
        self.save_data_config_name = os.path.join(
            save_dataset_folder, "ned_slices_config.json"
        )

        # =======================================================================================
        # SLICE DATA
        # =======================================================================================
        log_rank_0_debug(logger, "Loading dataset...")
        log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
        if data_config.overwrite_preprocessed_data or (
            not os.path.exists(self.save_dataset_name)
        ):
            st_time = time.time()
            try:
                log_rank_0_info(
                    logger,
                    f"Building dataset from scratch. Saving to {save_dataset_folder}",
                )
                create_examples(
                    dataset,
                    create_ex_indir,
                    create_ex_outdir,
                    meta_file,
                    data_config,
                    dataset_threads,
                    self.slice_names,
                    use_weak_label,
                    split,
                )
                max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
                convert_examples_to_features_and_save(
                    meta_file,
                    dataset_threads,
                    self.slice_names,
                    self.save_dataset_name,
                    self.get_storage(max_alias2pred),
                )
                utils.dump_json_file(
                    self.save_data_config_name, {"max_alias2pred": max_alias2pred}
                )
                log_rank_0_debug(
                    logger, f"Finished prepping data in {time.time() - st_time}"
                )
            except Exception as e:
                tb = traceback.TracebackException.from_exception(e)
                logger.error(e)
                logger.error("\n".join(tb.stack.format()))
                # Do not leave a half-written cache behind
                shutil.rmtree(save_dataset_folder, ignore_errors=True)
                raise

        log_rank_0_info(
            logger,
            f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
        )
        max_alias2pred = utils.load_json_file(self.save_data_config_name)[
            "max_alias2pred"
        ]
        self.data, self.sent_to_row_id_dict = self.build_data_dict(
            self.save_dataset_name, self.get_storage(max_alias2pred)
        )
        assert len(self.data) > 0
        assert len(self.sent_to_row_id_dict) > 0

        log_rank_0_debug(logger, "Removing temporary output files")
        shutil.rmtree(temp_output_folder, ignore_errors=True)
    finally:
        # Set spawn back to original/default, which is "fork" or "spawn". This is needed for the
        # Meta.config to be correctly passed in the collate_fn. Done in `finally` so a failed
        # prep does not leave the process-wide start method changed.
        multiprocessing.set_start_method(orig_spawn, force=True)
    log_rank_0_info(
        logger,
        f"Final slice data initialization time from {split} is {time.time() - global_start}s",
    )
def convert_examples_to_features_and_save(
    meta_file, dataset_threads, slice_names, save_dataset_name, storage
):
    """Converts the prepped examples into input features and saves in memmap files.

    These are used in the __get_item__ method. After writing, the memmap is
    re-opened read-only and every row is checked to have a set ``sent_idx``
    and a unique ``(sent_idx, subslice_idx)`` pair.

    Args:
        meta_file: metadata file where input file paths are saved (JSON with
            ``num_mentions``, ``max_alias2pred`` and ``files_and_counts``)
        dataset_threads: number of threads
        slice_names: list of slice names to evaluation on
        save_dataset_name: data file name to save
        storage: data storage type (for memmap)

    Returns: None
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    # int(0.8 * cpu_count()) is 0 on a single-CPU machine and Pool(processes=0)
    # raises ValueError, so always use at least one process.
    num_processes = max(
        1, min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
    )
    log_rank_0_info(
        logger, f"Starting to build and save features with {num_processes} threads"
    )

    log_rank_0_debug(logger, "Counting lines")
    # Read the meta data once instead of re-parsing the file per field
    meta_data = utils.load_json_file(meta_file)
    total_input = meta_data["num_mentions"]
    max_alias2pred = meta_data["max_alias2pred"]
    files_and_counts = meta_data["files_and_counts"]

    # IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode before
    # being opened in r+ mode by workers
    memmap_file = np.memmap(
        save_dataset_name, dtype=storage, mode="w+", shape=(total_input,), order="C"
    )
    # Save -1 in sent_idx to check that things are loaded correctly later
    memmap_file[slice_names[0]]["sent_idx"][:] = -1

    ex_print_mod = int(np.ceil(total_input / 20))
    input_args = []
    # Saves where in memap file to start writing
    offset = 0
    for in_file_name, in_file_lines in files_and_counts.items():
        input_args.append(
            {
                "file_name": in_file_name,
                "in_file_lines": in_file_lines,
                "save_file_offset": offset,
                "ex_print_mod": ex_print_mod,
                "slice_names": slice_names,
                "max_alias2pred": max_alias2pred,
            }
        )
        offset += in_file_lines

    if num_processes == 1:
        assert len(input_args) == 1
        total_output = convert_examples_to_features_and_save_single(
            input_args[0], memmap_file
        )
    else:
        log_rank_0_debug(
            logger,
            "Initializing pool. This make take a few minutes.",
        )
        # The context manager terminates the workers if anything below raises;
        # on success we close and join so no worker outlives this call.
        with multiprocessing.Pool(
            processes=num_processes,
            initializer=convert_examples_to_features_and_save_initializer,
            initargs=[save_dataset_name, storage],
        ) as pool:
            total_output = sum(
                pool.imap_unordered(
                    convert_examples_to_features_and_save_hlp, input_args, chunksize=1
                )
            )
            pool.close()
            pool.join()

    # Verify that sentences are unique and saved correctly
    mmap_file = np.memmap(save_dataset_name, dtype=storage, mode="r")
    # Hoist the structured-field views out of the per-row loop
    sent_idxs = mmap_file[slice_names[0]]["sent_idx"]
    subslice_idxs = mmap_file[slice_names[0]]["subslice_idx"]
    all_uniq_ids = set()
    for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
        assert sent_idxs[i] != -1, f"Index {i} has -1 sent idx"
        uniq_id = f"{sent_idxs[i]}.{subslice_idxs[i]}"
        assert (
            uniq_id not in all_uniq_ids
        ), f"Idx {uniq_id} is not unique and already in data"
        all_uniq_ids.add(uniq_id)

    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}",
    )
    return
def create_examples(
    dataset,
    create_ex_indir,
    create_ex_outdir,
    meta_file,
    data_config,
    dataset_threads,
    slice_names,
    use_weak_label,
    split,
):
    """Creates examples from the raw input data.

    With one process the dataset is processed directly; otherwise it is
    chunked into ``create_ex_indir`` and processed by a worker pool. The
    total number of output lines, the output files with their line counts,
    and the max ``max_alias2pred`` seen are dumped to ``meta_file``.

    Args:
        dataset: dataset file
        create_ex_indir: temporary directory where input files are stored
        create_ex_outdir: temporary directory to store output files from method
        meta_file: metadata file to save the file names/paths for the next step in prep pipeline
        data_config: data config
        dataset_threads: number of threads
        slice_names: list of slices to evaluate on
        use_weak_label: whether to use weak labeling or not
        split: data split

    Returns: None
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    # int(0.8 * cpu_count()) is 0 on a single-CPU machine and Pool(processes=0)
    # raises ValueError, so always use at least one process.
    num_processes = max(
        1, min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
    )

    log_rank_0_debug(logger, "Counting lines")
    # Close the handle after counting instead of leaking it
    with open(dataset) as in_f:
        total_input = sum(1 for _ in in_f)

    if num_processes == 1:
        out_file_name = os.path.join(create_ex_outdir, os.path.basename(dataset))
        constants_dict = {
            "slice_names": slice_names,
            "use_weak_label": use_weak_label,
            "max_aliases": data_config.max_aliases,
            "split": split,
            "train_in_candidates": data_config.train_in_candidates,
        }
        res = create_examples_single(
            dataset, total_input, out_file_name, constants_dict
        )
        total_output = res["total_lines"]
        max_alias2pred = res["max_alias2pred"]
        files_and_counts = {res["output_filename"]: res["total_lines"]}
    else:
        log_rank_0_info(
            logger, f"Strating to extract examples with {num_processes} threads"
        )
        log_rank_0_debug(
            logger, "Parallelizing with " + str(num_processes) + " threads."
        )
        chunk_input = int(np.ceil(total_input / num_processes))
        log_rank_0_debug(
            logger,
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
        )
        total_input_from_chunks, input_files_dict = utils.chunk_file(
            dataset, create_ex_indir, chunk_input
        )

        input_files = list(input_files_dict.keys())
        input_file_lines = [input_files_dict[k] for k in input_files]
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        assert (
            total_input == total_input_from_chunks
        ), f"Lengths of files {total_input} doesn't mathc {total_input_from_chunks}"
        log_rank_0_debug(logger, "Done chunking files")

        total_output = 0
        max_alias2pred = 0
        input_args = list(zip(input_files, input_file_lines, output_files))
        # Store output files and counts for saving in next step
        files_and_counts = {}
        # The context manager terminates the workers if anything below raises;
        # on success we close and join so no worker outlives this call.
        with multiprocessing.Pool(
            processes=num_processes,
            initializer=create_examples_initializer,
            initargs=[
                data_config,
                slice_names,
                use_weak_label,
                data_config.max_aliases,
                split,
                data_config.train_in_candidates,
            ],
        ) as pool:
            for res in pool.imap_unordered(
                create_examples_hlp, input_args, chunksize=1
            ):
                total_output += res["total_lines"]
                max_alias2pred = max(max_alias2pred, res["max_alias2pred"])
                files_and_counts[res["output_filename"]] = res["total_lines"]
            pool.close()
            pool.join()

    utils.dump_json_file(
        meta_file,
        {
            "num_mentions": total_output,
            "files_and_counts": files_and_counts,
            "max_alias2pred": max_alias2pred,
        },
    )
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time()-start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}.",
    )
    return
def get_dataloaders(
    args,
    tasks,
    splits,
    entity_symbols,
    batch_on_the_fly_kg_adj,
):
    """Gets the dataloaders.

    Training splits (those in ``args.learner_config.train_split``) use a
    RandomSampler, or a DistributedSampler when running distributed, and the
    training batch size. Other splits use no sampler and
    ``args.run_config.eval_batch_size`` when it is set.

    Args:
        args: main args
        tasks: task names
        splits: data splits to generate dataloaders for
        entity_symbols: entity symbols
        batch_on_the_fly_kg_adj: kg embeddings metadata for the __get_item__ method (see get_dataloader_embeddings)

    Returns: list of dataloaders
    """
    task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}
    bert_model = args.data_config.word_embedding.bert_model
    is_bert = len(bert_model) > 0
    tokenizer = BertTokenizer.from_pretrained(
        bert_model,
        do_lower_case="uncased" in bert_model,
        cache_dir=args.data_config.word_embedding.cache_dir,
    )

    datasets = {}
    for split in splits:
        dataset_path = os.path.join(
            args.data_config.data_dir, args.data_config[f"{split}_dataset"].file
        )
        datasets[split] = BootlegDataset(
            main_args=args,
            name="Bootleg",
            dataset=dataset_path,
            use_weak_label=args.data_config[f"{split}_dataset"].use_weak_label,
            tokenizer=tokenizer,
            entity_symbols=entity_symbols,
            dataset_threads=args.run_config.dataset_threads,
            split=split,
            is_bert=is_bert,
            batch_on_the_fly_kg_adj=batch_on_the_fly_kg_adj,
        )

    dataloaders = []
    for split, dataset in datasets.items():
        is_train_split = split in args.learner_config.train_split
        if is_train_split:
            dataset_sampler = (
                RandomSampler(dataset)
                if Meta.config["learner_config"]["local_rank"] == -1
                else DistributedSampler(dataset)
            )
        else:
            dataset_sampler = None
            if Meta.config["learner_config"]["local_rank"] != -1:
                log_rank_0_info(
                    logger,
                    f"You are using distributed computing for eval. We are not using a distributed sampler. "
                    f"Please use DataParallel and not DDP.",
                )
        # Eval splits fall back to the training batch size when no eval size is configured
        if is_train_split or args.run_config.eval_batch_size is None:
            batch_size = args.train_config.batch_size
        else:
            batch_size = args.run_config.eval_batch_size
        dataloaders.append(
            EmmentalDataLoader(
                task_to_label_dict=task_to_label_dict,
                dataset=dataset,
                sampler=dataset_sampler,
                split=split,
                collate_fn=bootleg_collate_fn,
                batch_size=batch_size,
                num_workers=args.run_config.dataloader_threads,
                pin_memory=False,
            )
        )
        log_rank_0_info(
            logger,
            f"Built dataloader for {split} set with {len(dataset)} samples and "
            f"{args.run_config.dataloader_threads} threads "
            f"(Shuffle={is_train_split}, "
            f"Batch size={dataloaders[-1].batch_size}).",
        )

    return dataloaders
def __init__(
    self,
    main_args,
    emb_args,
    entity_symbols,
    key,
    cpu,
    normalize,
    dropout1d_perc,
    dropout2d_perc,
):
    """Type embedding: looks up the types of each entity and merges them.

    Args:
        main_args: main args (``main_args.data_config`` is used for prep)
        emb_args: embedding args. Allowed keys are ``max_types``,
            ``type_dim``, ``type_labels`` and ``type_vocab`` (all required),
            plus optional ``merge_func`` ("average" or "addattn"; default
            average), ``attn_hidden_size`` (only read for "addattn"; default
            100) and ``regularize_mapping``.
        entity_symbols: entity symbols
        key: embedding key (used in log and assertion messages)
        cpu: CPU placement flag; ``self.cpu`` must be False (presumably set
            from this argument by the base class — TODO confirm)
        normalize: passed through to the base class
        dropout1d_perc: passed through to the base class
        dropout2d_perc: passed through to the base class

    Raises:
        ValueError: if ``emb_args`` contains a key outside the allowed set.
        AssertionError: if a required key is missing, ``self.cpu`` is not
            False, ``merge_func`` is not allowed, or the loaded type table's
            second dimension does not equal ``max_types``.
    """
    super(TypeEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
        key=key,
        cpu=cpu,
        normalize=normalize,
        dropout1d_perc=dropout1d_perc,
        dropout2d_perc=dropout2d_perc,
    )
    # Reject any embedding arg this class does not understand
    allowable_keys = {
        "max_types",
        "type_dim",
        "type_labels",
        "type_vocab",
        "merge_func",
        "attn_hidden_size",
        "regularize_mapping",
    }
    correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
    if not correct:
        raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
    # Required keys
    assert (
        "max_types" in emb_args), "Type embedding requires max_types to be set in args"
    assert (
        "type_dim" in emb_args), "Type embedding requires type_dim to be set in args"
    assert (
        "type_labels" in emb_args
    ), "Type embedding requires type_labels to be set in args. A Dict from QID -> TypeId or TypeName"
    assert (
        "type_vocab" in emb_args
    ), "Type embedding requires type_vocab to be set in args. A Dict from TypeName -> TypeId"
    assert (self.cpu is False
            ), f"We don't support putting type embeddings on CPU right now"
    # Defaults: average the type embeddings, no attention module
    self.merge_func = self.average_types
    self.orig_dim = emb_args.type_dim
    self.add_attn = None
    # Function for merging multiple types
    if "merge_func" in emb_args:
        assert emb_args.merge_func in [
            "average",
            "addattn",
        ], (f"{key}: You have set the type merge_func to be {emb_args.merge_func} but"
            f" that is not in the allowable list of [average, addattn]")
        if emb_args.merge_func == "addattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            # Softmax of types using the sentence context
            self.add_attn = PositionAwareAttention(
                input_size=self.orig_dim,
                attn_size=attn_hidden_size,
                feature_size=0)
            self.merge_func = self.add_attn_merge
    self.max_types = emb_args.max_types
    # Build or load the EID -> type-id table and the type -> row mapping
    (
        eid2typeids_table,
        self.type2row_dict,
        num_types_with_unk,
        self.prep_file,
    ) = self.prep(
        data_config=main_args.data_config,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
    )
    # Non-persistent buffer: the table follows the module across devices but
    # is not written to the state dict.
    self.register_buffer("eid2typeids_table",
                         eid2typeids_table,
                         persistent=False)
    # self.eid2typeids_table.requires_grad = False
    # One extra slot on top of num_types_with_unk for padding
    self.num_types_with_pad_and_unk = num_types_with_unk + 1

    # Regularization mapping goes from typeid to 2d dropout percent
    if "regularize_mapping" in emb_args:
        # Zero placeholder; when from_pretrained is set it is not rebuilt here
        # (presumably the values come from the loaded checkpoint — TODO confirm)
        typeid2reg = torch.zeros(self.num_types_with_pad_and_unk)
    else:
        typeid2reg = None
    if not self.from_pretrained:
        if "regularize_mapping" in emb_args:
            if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                logger.warning(
                    f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?"
                )
            log_rank_0_info(
                logger,
                f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
            )
            typeid2reg = self.load_regularization_mapping(
                main_args.data_config,
                self.num_types_with_pad_and_unk,
                self.type2row_dict,
                emb_args.regularize_mapping,
            )
    # Registered even when None so the attribute always exists
    self.register_buffer("typeid2reg", typeid2reg)
    # Sanity check: the table must have exactly max_types columns per entity
    assert self.eid2typeids_table.shape[1] == emb_args.max_types, (
        f"Something went wrong with loading type file."
        f" The given max types {emb_args.max_types} does not match that "
        f"of type table {self.eid2typeids_table.shape[1]}")
    log_rank_0_debug(
        logger,
        f"{key}: Type embedding with {self.max_types} types with dim {self.orig_dim}. "
        f"Setting merge_func to be {self.merge_func.__name__} in type emb.",
    )
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.

    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns:
        For ``train`` and ``eval``: the scores dict from ``model.score`` with
        ``log_path`` added. For ``dump_preds`` / ``dump_embs``: a tuple
        ``(final_result_file, final_out_emb_file)``. Both are None on local
        ranks other than 0/-1, and the embedding file is None unless
        ``mode == "dump_embs"``.

    Raises:
        ValueError: for dump modes when the test dataset does not use weak labels.
    """
    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(
            config.data_config.entity_dir, config.data_config.entity_map_dir
        ),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
    # We only do dumping if weak labels is True
    if mode in ["dump_preds", "dump_embs"]:
        if config.data_config[f"{TEST_SPLIT}_dataset"].use_weak_label is False:
            raise ValueError(
                f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
            )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (
            config.data_config.type_prediction.use_type_pred is False
        ), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols, slice_datasets)
            )
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model, requires_grad=True, logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model, requires_grad=False, logger=logger)
        log_rank_0_info(logger, f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (
            config["learner_config"]["local_rank"] == -1
            and config["model_config"]["dataparallel"]
        ):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt", scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv", scores
            )
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename)
        )
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
            eval_utils.batched_pred_iter(
                model,
                dataloaders[0],
                config.run_config.eval_accumulation_steps,
                sentidx2num_mentions,
            )
        ):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0
            for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (
            f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
        )
        empty_sentidx2row = {
            k: v for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(
            number_dumped_batches, subeval_folder
        )
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger, f"Finished dumping. Merging results across accumulation steps."
        )
        # Final result files for labels and embeddings
        final_result_file = os.path.join(
            eval_folder, config.run_config.result_label_file
        )
        # Copy labels; every per-batch handle is closed as soon as it is
        # copied so large dumps do not exhaust file descriptors.
        with open(final_result_file, "wb") as output:
            for batch_file in all_result_files:
                with open(batch_file, "rb") as batch_in:
                    shutil.copyfileobj(batch_in, output)
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file
            )
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = [np.load(npfile) for npfile in all_out_emb_files]
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file