def update_classes_info(self, classes_in_split_to_idx_map=None, is_seen_class_indicator=None):
    """
    Updates the class information tracked by this metric, creating a per-class F1 sub-metric for every class
    that is not yet being tracked.

    :param classes_in_split_to_idx_map: A dictionary that maps class names to their corresponding indices, as
                                        defined in the dataset being used.

    :param is_seen_class_indicator    : A dictionary that indicates whether a class belongs to the 'seen' set
                                        or the 'unseen' set.

    :return: Nothing
    """

    # Get the correct load_path for the sub-F1 metrics.
    if (self._optional_seen_unseen != ''):
        load_path = None if self._load_path is None else join_path(
            get_path_components(self._load_path)[:-1])
        load_path = None if self._load_path is None else join_path(
            [load_path] + [self._optional_seen_unseen + '_' + self._name + '_sub_metrics' + os.sep])
    else:
        load_path = None if self._load_path is None else join_path(
            [self._load_path] + [self._name + '_sub_metrics' + os.sep])

    class_overlap = False
    for new_class in classes_in_split_to_idx_map:
        if (new_class not in self._classes_in_split_to_idx_map):
            # Build a filesystem-safe file name for the class and instantiate its F1 sub-metric.
            class_file_name = re.sub('/', '-', re.sub(' ', '_', new_class))
            class_load_path = None if self._load_path is None else join_path(
                [load_path] + [class_file_name + '_'])
            self._classes_in_split_to_idx_map[new_class] = classes_in_split_to_idx_map[new_class]
            self._classes_individual_F1_scores[new_class] = F1(
                self._evaluated_on_batches, self._max_num_batches, self._accumulate_n_batch_grads,
                self._training_batch_sizes, class_load_path, self._load_epoch_or_best,
                len(self._batch_history_x) + 1, new_class)
            # self._is_seen_class_indicator[new_class] = is_seen_class_indicator[new_class]
        else:
            class_overlap = True

    # Rebuild the inverse map (index -> class name), sorted by class index.
    self._idx_to_classes_in_split_map = [
        _class for _class, idx in sorted(self._classes_in_split_to_idx_map.items(), key=lambda x: x[1])
    ]

    if (class_overlap):
        print_warning("When adding new classes to the 'MacroF1' metric, some of these new classes were already " +
                      "being tracked. The new setting WILL BE IGNORED.")
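# Usage sketch (illustrative only; the class names and indices below are hypothetical, not taken from this
# codebase): once a dataset split exposes its class-to-index map, the metric can be told about any newly
# appearing classes, e.g.:
#
#     macro_f1.update_classes_info(
#         classes_in_split_to_idx_map={'place_of_birth': 0, 'employer': 1},
#         is_seen_class_indicator={'place_of_birth': True, 'employer': False})
#
# Classes that are already tracked are left untouched and only trigger the overlap warning above.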
def __init__(self, short_name, evaluated_on_batches, max_num_batches, accumulate_n_batch_grads,
             training_batch_sizes, lower_is_better, load_path, load_epoch_or_best, starting_batch_num):
    """
    Instantiates a Metric (or child of Metric) object.

    :param short_name              : The short name by which the metric is identified.

    :param evaluated_on_batches    : Determines whether the metric is evaluated on minibatches. Used for
                                     plotting.

    :param max_num_batches         : The maximum number of batches in an epoch. Used to aggregate epoch
                                     information. See specific uses in child classes.

    :param accumulate_n_batch_grads: The number of batches over which gradients are accumulated before an
                                     optimisation step is performed.

    :param training_batch_sizes    : A sequence whose first element is the regular training batch size and
                                     whose second element is the size of the last (possibly smaller) batch.

    :param lower_is_better         : Identifies whether a lower value of the metric indicates that the model
                                     performs better.

    :param load_path               : The path that leads to a saved state of an instance of this metric.

    :param load_epoch_or_best      : Identifies whether to load the state associated with a specific epoch or
                                     the state of the epoch in which the model performed the best.

    :param starting_batch_num      : Some metrics can be added to the MetricsManager after the model has
                                     already been trained for a few epochs. This parameter identifies the
                                     batch number at which the metric was started.
    """

    # Names.
    self._name = self.__class__.__name__
    self._short_name = short_name

    # Load previously saved metric state.
    self._load_path = load_path
    self._load_epoch_or_best = load_epoch_or_best
    extra_name = ''
    if (load_path is not None):
        if (load_path[-1] != os.sep):
            path_components = get_path_components(load_path)
            extra_name = path_components[-1]
            load_path = join_path(path_components[:-1]) + os.sep
    self._metric_state = trainer_helpers.load_checkpoint_state(load_path, extra_name + self._name,
                                                               load_epoch_or_best)

    # Evaluation history.
    self._batch_history = [] if self._metric_state is None else self._metric_state['batch_history']
    self._batch_history_x = [] if self._metric_state is None else self._metric_state['batch_history_x']
    self._epoch_history = [] if self._metric_state is None else self._metric_state['epoch_history']
    self._epoch_history_x = [] if self._metric_state is None else self._metric_state['epoch_history_x']

    # Used to (only) evaluate the model's performance, so that no state of the singular evaluation is kept.
    self._eval_only_history = None

    # Batch information.
    self._evaluated_on_batches = evaluated_on_batches
    self._max_num_batches = max_num_batches
    self._accumulate_n_batch_grads = accumulate_n_batch_grads
    self._training_batch_sizes = training_batch_sizes
    if (training_batch_sizes[0] is None):
        self._last_acc_batch_start_num = None
        self._normal_acc_batch_size = None
        self._last_acc_batch_size = None
    else:
        max_b = max_num_batches
        acc = accumulate_n_batch_grads
        # The 1-indexed batch number at which the last gradient-accumulation group starts. If the epoch's
        # batch count is not a multiple of 'acc', the last group holds the remainder; otherwise it is a
        # full group of 'acc' batches.
        self._last_acc_batch_start_num = max_b - max_b % acc + (1 if max_b % acc >= 1 else -acc + 1)
        self._normal_acc_batch_size = acc * training_batch_sizes[0]
        # The last accumulated batch covers the remaining full-size batches plus the final, possibly
        # smaller, batch.
        self._last_acc_batch_size = (max_b - self._last_acc_batch_start_num) * training_batch_sizes[0]
        self._last_acc_batch_size += training_batch_sizes[1]
    self._starting_batch_num = starting_batch_num  # TODO: Should load from memory for existing metrics

    # Early stopping and save best parameters.
    self._lower_is_better = lower_is_better
    self._early_stop_tolerance = None
    self._patience = None

    # Variables used to determine the length of the information to be printed.
    self._print_info_max_length = 0

    # Class information parameters. Used in metrics such as F1, MacroF1 and HarmonicMacroF1.
    self._classes_in_split_to_idx_map = {}
    self._idx_to_classes_in_split_map = []
    self._is_seen_class_indicator = {}
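# Worked example of the gradient-accumulation bookkeeping above (illustrative values, not from the codebase):
# with max_num_batches = 10, accumulate_n_batch_grads = 4 and training_batch_sizes = (32, 8):
#   max_b % acc = 2, so _last_acc_batch_start_num = 10 - 2 + 1 = 9 (batches 9 and 10 form the last group),
#   _normal_acc_batch_size = 4 * 32 = 128 instances per full accumulation group,
#   _last_acc_batch_size   = (10 - 9) * 32 + 8 = 40 instances in the final group.
# When max_b is a multiple of acc (e.g. max_b = 8, acc = 4), the start becomes 8 - 0 + (-4 + 1) = 5,
# i.e. the last group is simply the final full group of 4 batches.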
def load_checkpoint_state(model_to_load_path, file_type, checkpoint_or_best, device=None):
    """
    load_checkpoint_state() allows loading a previously saved state of any element of the trainer, be it the
    ModelEvaluationModule, the MetricsManager or any Metric. The element to be loaded is defined by the input
    parameters.

    :param model_to_load_path: The path that leads to the directory of the element to be loaded.

    :param file_type         : Identifies the type of element to be loaded.

    :param checkpoint_or_best: Identifies the saved state's specific epoch to load (-1 loads the state of the
                               epoch in which the model performed the best).

    :param device            : The device onto which the loaded tensors are mapped (passed to torch.load as
                               map_location).

    :return: Returns the corresponding state_dict, or None if no valid model_to_load_path and
             checkpoint_or_best were provided.
    """

    loaded_state = None
    if (model_to_load_path is not None and checkpoint_or_best is not None and
            isinstance(checkpoint_or_best, int)):
        # Concatenates the path of the specific element to load with the general path of where models are saved.
        load_path = model_to_load_path

        # Gets the names of the existing files in the load_path directory.
        existing_files = os.listdir(load_path)

        # Loads the state associated with the best model.
        if (checkpoint_or_best == -1 and
                ".bst" in {existing_file[-4:] for existing_file in existing_files}):
            path = load_path
            path += [file for file in existing_files
                     if (file.endswith(".bst") and file.startswith(file_type))][0]
            loaded_state = torch.load(path, map_location=device)

        # Loads the state associated with a specific checkpoint.
        elif (checkpoint_or_best >= 0 and
              os.path.exists(load_path + file_type + "." + str(checkpoint_or_best) + ".ckp")):
            loaded_state = torch.load(load_path + file_type + "." + str(checkpoint_or_best) + ".ckp",
                                      map_location=device)

        # The desired epoch has not yet been computed. Issue an error.
        else:
            split_path = get_path_components(load_path)
            if (split_path[-1] in list(_names.METRICS) + ["Loss"]):
                # This means that the state of a metric was meant to be loaded.
                raise ValueError("Tried to load " + split_path[-1] + "'s (" + split_path[-2] + ") state " +
                                 "dict for an invalid or nonexistent epoch (epoch num: " +
                                 str(checkpoint_or_best) + ").")
            else:
                # This means that the state of either the MEM or the MMa was meant to be loaded.
                raise ValueError("Tried to load " + split_path[-1] + "'s state dict for an invalid or " +
                                 "nonexistent epoch (epoch num: " + str(checkpoint_or_best) + ").")

    return loaded_state
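# Usage sketch (the paths and metric name below are hypothetical, for illustration only):
#
#     # Load the best saved state of a MacroF1 metric (expects a 'MacroF1*.bst' file in the directory).
#     state = load_checkpoint_state("saves/run_1/MacroF1/", "MacroF1", -1)
#
#     # Load the state saved at epoch 3 instead (expects 'saves/run_1/MacroF1/MacroF1.3.ckp').
#     state = load_checkpoint_state("saves/run_1/MacroF1/", "MacroF1", 3)
#
# Note that model_to_load_path must end with the path separator, since file names are appended to it by
# plain string concatenation.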
def __init__(self, file_name, setting, batch_size_xs, labels_to_load, num_workers_xs=0,
             eval_all_descriptions=True, dataset_debug=None, full_prob_interp=False):
    """
    Instantiates a UW_RE_UVA dataset object for the specific split associated with 'file_name'.

    :param file_name            : The (almost complete) path to the data, in storage, pertaining to this
                                  dataset.

    :param setting              : The experimental setting (e.g. Normal, Zero-Shot, Few-Shot or their
                                  Generalised variants) this dataset is used in.

    :param batch_size_xs        : The number of data instances (sentences x) (and respective labels) to load.

    :param labels_to_load       : A list that indicates what kind of labels are meant to be loaded by this
                                  dataset.

    :param num_workers_xs       : The number of sub-processes loading data instances from storage.

    :param eval_all_descriptions: Determines whether to evaluate against all relation descriptions or not.

    :param dataset_debug        : Identifies the debug experiment type ('instances', 'classes' or None).

    :param full_prob_interp     : Determines whether a full probabilistic interpretation is used.
    """

    # Input variables.
    self._file_name = file_name
    self._batch_size_xs = batch_size_xs
    self._labels_to_load = labels_to_load
    self._num_workers_xs = num_workers_xs
    self._eval_all_descriptions = eval_all_descriptions

    # TODO: Always load the supervised labels, as they are necessary for evaluation in any of the splits.
    # TODO: Instead change the labels_to_load to refer uniquely to which unsupervised labels are meant to be loaded.
    # Instantiate the sub-dataset that will load the data instances and the corresponding labels.
    self._x_sentences_dataset = UW_RE_UVA_xs(file_name, setting, labels_to_load, dataset_debug,
                                             full_prob_interp)

    # Determine if this pertains to a train split. If so, we sample one relation description per relation in
    # the split.
    self._train_split = self._x_sentences_dataset.train_split

    if (self._batch_size_xs == -1 or self._batch_size_xs > self._x_sentences_dataset.num_instances):
        if (self._batch_size_xs > self._x_sentences_dataset.num_instances):
            print_warning("Split: " + self._x_sentences_dataset.split + " | Requested Batch Size is bigger " +
                          "than available number of instances. Setting Batch Size to be equal to the " +
                          "available number of instances.")
        self._batch_size_xs = self._x_sentences_dataset.num_instances

    self._x_sentences_dataloader = DataLoader(self._x_sentences_dataset, batch_size=self._batch_size_xs,
                                              shuffle=self._train_split, num_workers=num_workers_xs,
                                              collate_fn=PadCollate(self._x_sentences_dataset.valid_labels))
    self._x_sentences_dataloader_iter = iter(self._x_sentences_dataloader)

    # We create a generator that allows us to sample from the x_sentences_dataset indefinitely.
    def get_subdataset_batch(dl):
        while (True):
            for batch in dl:
                yield batch

    self._infinite_x_sentences_dl = get_subdataset_batch(self._x_sentences_dataloader)

    # Determine the correct length of this meta-dataset.
    num_full_x_batches = len(self._x_sentences_dataset) // self._batch_size_xs
    equal_size_last_x_batch = len(self._x_sentences_dataset) % self._batch_size_xs == 0
    self._len_this_meta_dataset = num_full_x_batches + (0 if equal_size_last_x_batch else 1)

    # Load the list of relations involved in this split (in the file there is one relation per line).
    self._relations_in_split = self._x_sentences_dataset.relations_in_split

    # Load the map between relations and the indices of relation_descriptions.
    path_components = get_path_components(file_name)
    with open(join_path(path_components[:-6] + ["ELMO_rel_descs_idxs.map"]), 'rb') as f:
        self._relation_to_idxs_rel_description = pickle.load(f)

    # The map above always conveys the same information for a specific instance of a dataset split, so:
    self._descs_idxs = [self._relation_to_idxs_rel_description[rel] for rel in self._relations_in_split]

    # Remove duplicate indices (for relation descriptions that might describe more than one relation).
    self._descs_idxs_all = []
    for idxs_set in self._descs_idxs:
        for idx in idxs_set:
            if (idx not in self._descs_idxs_all):
                self._descs_idxs_all.append(idx)

    # Indicates which indices are meant to be aggregated together, when comparing against multiple relation
    # descriptions for the same relation.
    self._aggregate = []
    for idxs_set in self._descs_idxs:
        map_to_output_idxs = []
        for idx in idxs_set:
            map_to_output_idxs.append(self._descs_idxs_all.index(idx))
        self._aggregate.append(torch.LongTensor(map_to_output_idxs))
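# Worked example of the deduplication and aggregation maps above (hypothetical indices, for illustration
# only): if the split contains two relations whose description indices are [3, 7] and [7, 9], then:
#   _descs_idxs_all = [3, 7, 9]                  (duplicates removed, order of first appearance kept)
#   _aggregate      = [LongTensor([0, 1]),       (relation 1 aggregates output positions 0 and 1)
#                      LongTensor([1, 2])]       (relation 2 aggregates output positions 1 and 2)
# so scores computed once per unique description can be gathered per relation downstream.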
def __init__(self, file_name, setting, labels_to_load, dataset_debug=None, full_prob_interp=False):
    """
    Instantiates a UW_RE_UVA_xs dataset object for the specific split associated with 'file_name'.

    :param file_name       : The (almost complete) path to the data, in storage, pertaining to this dataset.

    :param setting         : The experimental setting (e.g. Normal, Zero-Shot, Few-Shot or their Generalised
                             variants) this dataset is used in.

    :param labels_to_load  : A list that indicates what kind of labels are meant to be loaded by this dataset.

    :param dataset_debug   : Identifies the debug experiment type ('instances', 'classes' or None).

    :param full_prob_interp: Determines whether a full probabilistic interpretation is used.
    """

    self._file_name = file_name
    self._labels_to_load = labels_to_load
    self._full_prob_interp = full_prob_interp
    path_components = get_path_components(file_name)
    debug_file = "DEBUG_" if path_components[-4] == 'DEBUG' else ""
    self._sentences_hdf5_file = join_path(path_components[:-4] + [debug_file + "sentences_to_ELMo.hdf5"])
    self._setting = setting
    self._split = path_components[-1]

    # Catch undesired configurations from the start.
    if (self._setting != 'N' and self._setting[0] != 'G' and full_prob_interp):
        raise RuntimeError("A full probabilistic interpretation is desired, but the chosen setting (" +
                           self._setting + ") is incompatible with such specification. Compatible settings " +
                           "are: Normal (N); Generalised Any-Shot settings (GZS-(O/C), GFS-1, GFS-2, GFS-5, " +
                           "GFS-10).")

    # print('\n\n\n')
    DEBUG_EXPERIMENT = dataset_debug
    # print('DEBUG_EXPERIMENT:', dataset_debug)
    if (DEBUG_EXPERIMENT is not None and self._split == 'test'):
        self._file_name = join_path(get_path_components(self._file_name)[:-1] + ['val'])
        # print('FILE NAME:', self._file_name)

    # Get the instances' indices (on the main hdf5 file) and determine the number of instances in this split.
    with open(self._file_name + ".idxs", 'rb') as f:
        loaded_instances = pickle.load(f)
    if (self._split != 'train'):
        # print(self._split, ' ORIGINAL INSTANCE INDICES', loaded_instances)
        if (DEBUG_EXPERIMENT == 'instances'):
            # Keep every other instance (even positions for 'val', odd positions for 'test').
            self._instances_indices = loaded_instances[(0 if self._split == 'val' else 1)::2]
        elif (DEBUG_EXPERIMENT == 'classes'):
            # Split the unseen classes between 'val' and 'test', keeping all seen-class instances in both.
            seen_classes = set()
            seen_instances = []
            seen_labels = []
            if (self._setting[0] == 'G'):
                with open(join_path(path_components[:-1] + ["train_relations.txt"]), 'r',
                          encoding='utf-8') as f:
                    for line in f:
                        seen_classes.add(line.strip())
            with open(self._file_name + '.lbs', 'rb') as f:
                temp_labels = pickle.load(f)
            count_class_elements = {}
            for inst_num, instance in enumerate(loaded_instances):
                if (temp_labels[inst_num] not in seen_classes):
                    if (temp_labels[inst_num] not in count_class_elements):
                        count_class_elements[temp_labels[inst_num]] = [instance]
                    else:
                        count_class_elements[temp_labels[inst_num]].append(instance)
                else:
                    seen_instances.append(instance)
                    seen_labels.append(temp_labels[inst_num])
            sorted_by_class_num_insts = sorted(count_class_elements.items(), key=lambda kv: len(kv[1]),
                                               reverse=True)
            classes_and_instances = sorted_by_class_num_insts[(1 if self._split == 'val' else 0)::2]
            self._labels = [_class for _class_set, inst_nums in classes_and_instances
                            for _class in [_class_set] * len(inst_nums)] + seen_labels
            self._instances_indices = [inst_num for _, inst_nums in classes_and_instances
                                       for inst_num in inst_nums] + seen_instances
        elif (DEBUG_EXPERIMENT is None):
            self._instances_indices = loaded_instances
        else:
            raise ValueError('Wrong DEBUG_EXPERIMENT value.')
    else:
        self._instances_indices = loaded_instances
    self._num_instances = len(self._instances_indices)

    self._valid_labels = []

    # Load the corresponding supervised labels (text based label, i.e. the actual relation name).
    if ('supervised_lbls' in labels_to_load):
        with open(self._file_name + '.lbs', 'rb') as f:
            loaded_labels = pickle.load(f)
        if (self._split != 'train'):
            # print(self._split, ' ORIGINAL LABELS', loaded_labels)
            if (DEBUG_EXPERIMENT == 'instances'):
                self._labels = loaded_labels[(0 if self._split == 'val' else 1)::2]
            elif (DEBUG_EXPERIMENT == 'classes'):
                # The labels have already been built above, together with the instances' indices.
                pass
            elif (DEBUG_EXPERIMENT is None):
                self._labels = loaded_labels
            else:
                raise ValueError('Wrong DEBUG_EXPERIMENT value.')
        else:
            self._labels = loaded_labels
        self._valid_labels.append('supervised_lbls')

    # TODO: remove this when done debugging weird results on test set.
    set_labels = set(self._labels)

    # Load the list of classes involved in this split. For some settings (like Few-Shot settings or
    # Generalised Any-Shot settings) the classes of another split might be present in this specific split.
    self._relations_in_split = []
    self._relation_in_split_to_idx_map = {}
    self._is_seen_class_indicator = {}
    self._exclusive_split_relations = []
    with open(self._file_name + "_relations.txt", 'r', encoding='utf-8') as f:
        r_num = 0
        for line in f:  # for r_num, line in enumerate(f):
            relation = line.strip()
            if (relation in set_labels):
                self._exclusive_split_relations.append(relation)
                self._relations_in_split.append(relation)
                self._relation_in_split_to_idx_map[relation] = r_num
                r_num += 1

    if (self._setting[0] == 'G' or self._setting[0] == 'F' or self._setting == 'ZS-C'):
        # If this is a Generalised setting and a validation or test split, get the train classes.
        if (self._setting[0] == 'G' and self._split != 'train'):
            for relation in self._exclusive_split_relations:
                self._is_seen_class_indicator[relation] = False
            self._relations_from_another_split = []
            with open(join_path(path_components[:-1] + ["train_relations.txt"]), 'r',
                      encoding='utf-8') as f:
                for r_num, line in enumerate(f):
                    relation = line.strip()
                    self._relations_from_another_split.append(relation)
                    self._relations_in_split.append(relation)
                    self._relation_in_split_to_idx_map[relation] = r_num + len(self._exclusive_split_relations)
                    self._is_seen_class_indicator[relation] = True

        # If this is a Few-Shot setting and a train split, get the validation and test classes. Also, if this
        # is a (Generalised) Zero-Shot Closed setting, get the validation and test classes.
        if ((self._setting[0] == 'F' or self._setting[1] == 'F' or self._setting == 'ZS-C' or
             self._setting == 'GZS-C') and self._split == 'train'):
            for relation in self._exclusive_split_relations:
                self._is_seen_class_indicator[relation] = True
            # TODO: This assumes there's a validation split, which might not be the case.
            self._relations_from_another_split = []
            for split in (['val', 'test'] if DEBUG_EXPERIMENT is None else ['val']):
                with open(join_path(path_components[:-1] + [split + "_relations.txt"]), 'r',
                          encoding='utf-8') as f:
                    for r_num, line in enumerate(f):
                        relation = line.strip()
                        self._relation_in_split_to_idx_map[relation] = (
                            len(self._relations_from_another_split) + len(self._exclusive_split_relations))
                        self._relations_from_another_split.append(relation)
                        self._relations_in_split.append(relation)
                        self._is_seen_class_indicator[relation] = False

    # If this pertains to a train split, load the unsupervised sentence labels, if they have been requested.
    # TODO: maybe some of these checks can be done when reading labels_to_load with argparse.
    if ('u_sentence_lbls' in labels_to_load and self._split == 'train'):
        with open(self._file_name + '.ulbs', 'rb') as f:
            self._data_as_idxs = pickle.load(f)
        self._valid_labels.append('u_sentence_lbls')
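# Illustrative example of the resulting maps (hypothetical relation names, not from the dataset): for a
# Generalised setting's 'val' split with exclusive relations ['capital_of', 'spouse'] and train relations
# ['employer'], the code above produces:
#   _relations_in_split           = ['capital_of', 'spouse', 'employer']
#   _relation_in_split_to_idx_map = {'capital_of': 0, 'spouse': 1, 'employer': 2}
#   _is_seen_class_indicator      = {'capital_of': False, 'spouse': False, 'employer': True}
# i.e. the split's own (unseen) classes take the lowest indices, followed by the seen train classes.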