Example #1
    def update_classes_info(self,
                            classes_in_split_to_idx_map=None,
                            is_seen_class_indicator=None):
        """
        Sets the parameters used to determine if early stopping should occur.


        :param classes_in_split_to_idx_map: A dictionary that maps classes names to their corresponding index, as
                                            defined in the dataset being used.

        :param is_seen_class_indicator    : A dictionary that indicates whether a class belongs to the 'seen' set or
                                            the 'unseen' set.


        :return: Nothing
        """

        # Get the correct load_path for the sub-F1 metrics.
        if (self._load_path is None):
            load_path = None
        elif (self._optional_seen_unseen != ''):
            load_path = join_path(get_path_components(self._load_path)[:-1])
            load_path = join_path([load_path] + [
                self._optional_seen_unseen + '_' + self._name +
                '_sub_metrics' + os.sep
            ])
        else:
            load_path = join_path(
                [self._load_path] + [self._name + '_sub_metrics' + os.sep])
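        # Illustrative sketch (hypothetical values, assuming join_path simply joins its components with the OS
        # separator): with self._load_path = '/runs/exp1/MacroF1', self._name = 'MacroF1' and
        # self._optional_seen_unseen = 'seen', the sub-metrics' states would load from
        # '/runs/exp1/seen_MacroF1_sub_metrics/'.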

        class_overlap = False
        for new_class in classes_in_split_to_idx_map:
            if (new_class not in self._classes_in_split_to_idx_map):
                class_file_name = re.sub('/', '-', re.sub(' ', '_', new_class))
                class_load_path = None if self._load_path is None else join_path(
                    [load_path] + [class_file_name + '_'])

                self._classes_in_split_to_idx_map[
                    new_class] = classes_in_split_to_idx_map[new_class]
                self._classes_individual_F1_scores[new_class] = F1(
                    self._evaluated_on_batches, self._max_num_batches,
                    self._accumulate_n_batch_grads, self._training_batch_sizes,
                    class_load_path, self._load_epoch_or_best,
                    len(self._batch_history_x) + 1, new_class)

                # self._is_seen_class_indicator[new_class] = is_seen_class_indicator[new_class]

            else:
                class_overlap = True

        self._idx_to_classes_in_split_map = [
            _class for _class, idx in sorted(
                self._classes_in_split_to_idx_map.items(), key=lambda x: x[1])
        ]
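        # E.g. (hypothetical): {'works_for': 1, 'born_in': 0} yields ['born_in', 'works_for'], i.e. the class
        # names ordered by their dataset indices.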

        if (class_overlap):
            print_warning(
                "When adding new classes to the 'MacroF1' metric, some of these new classes were already "
                + "being tracked. The new setting WILL BE IGNORED.")

    def __init__(self, short_name, evaluated_on_batches, max_num_batches,
                 accumulate_n_batch_grads, training_batch_sizes,
                 lower_is_better, load_path, load_epoch_or_best,
                 starting_batch_num):
        """
        Instantiates a Metric (or child of Metric) object.TODO: Missing parameters


        :param short_name          : The short name by which the metric is identified.

        :param evaluated_on_batches: Determines whether the metric is evaluated for minibatches. Used for plotting.

        :param max_num_batches     : The maximum number of batches in an epoch. Used to aggregate epoch information. See
                                     specific uses in child classes.

        :param lower_is_better     : Identifies whether a lower value of the metric indicates that the model performs
                                     better.

        :param load_path           : The path that leads to a saved state of an instance of this metric.

        :param load_epoch_or_best  : This parameter identifies whether to load the state associated with a specific
                                     epoch or the state of the epoch in which the model performed the best.

        :param starting_batch_num  : Some metrics can be added to the MetricsManager after the model has already been
                                     trained for a few epochs. This parameter identifies at which batch num the metric
                                     was started.
        """

        # Names.
        self._name = self.__class__.__name__
        self._short_name = short_name

        # Load previously saved metric state.
        self._load_path = load_path
        self._load_epoch_or_best = load_epoch_or_best

        extra_name = ''
        if (load_path is not None):
            if (load_path[-1] != os.sep):
                path_components = get_path_components(load_path)
                extra_name = path_components[-1]
                load_path = join_path(path_components[:-1]) + os.sep
        self._metric_state = trainer_helpers.load_checkpoint_state(
            load_path, extra_name + self._name, load_epoch_or_best)

        # Evaluation history.
        self._batch_history = [] if self._metric_state is None else self._metric_state[
            'batch_history']
        self._batch_history_x = [] if self._metric_state is None else self._metric_state[
            'batch_history_x']
        self._epoch_history = [] if self._metric_state is None else self._metric_state[
            'epoch_history']
        self._epoch_history_x = [] if self._metric_state is None else self._metric_state[
            'epoch_history_x']

        # Used only for one-off evaluations of the model's performance; no state from such evaluations is kept.
        self._eval_only_history = None

        # Batch information.
        self._evaluated_on_batches = evaluated_on_batches
        self._max_num_batches = max_num_batches
        self._accumulate_n_batch_grads = accumulate_n_batch_grads
        self._training_batch_sizes = training_batch_sizes
        if (training_batch_sizes[0] is None):
            self._last_acc_batch_start_num = None
            self._normal_acc_batch_size = None
            self._last_acc_batch_size = None
        else:
            max_b = max_num_batches
            acc = accumulate_n_batch_grads
            self._last_acc_batch_start_num = max_b - max_b % acc + (
                1 if max_b % acc >= 1 else -acc + 1)
            self._normal_acc_batch_size = acc * training_batch_sizes[0]
            self._last_acc_batch_size = (max_b - self._last_acc_batch_start_num
                                         ) * training_batch_sizes[0]
            self._last_acc_batch_size += training_batch_sizes[1]
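            # Worked example (hypothetical numbers): with max_num_batches = 10, accumulate_n_batch_grads = 4 and
            # training_batch_sizes = (32, 16), the last accumulation group starts at batch 9 (= 10 - 10 % 4 + 1),
            # a regular accumulation group covers 4 * 32 = 128 instances, and the last group covers
            # (10 - 9) * 32 + 16 = 48 instances.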
        self._starting_batch_num = starting_batch_num  #TODO: Should load from memory for existing metrics

        # Early stopping and save best parameters
        self._lower_is_better = lower_is_better
        self._early_stop_tolerance = None
        self._patience = None

        # Variables used to determine the length of the information to be printed.
        self._print_info_max_length = 0

        # Class information parameters. Used in metrics such as F1, MacroF1 and HarmonicMacroF1.
        self._classes_in_split_to_idx_map = {}
        self._idx_to_classes_in_split_map = []
        self._is_seen_class_indicator = {}

def load_checkpoint_state(model_to_load_path,
                          file_type,
                          checkpoint_or_best,
                          device=None):
    """
    load_check_point_state() allows loading a previously saved state of any element of the trainer, be it the
    ModelEvaluationModule, the MetricsManager or any Metric. The element to be loaded is defined by the input
    parameters.


    :param model_to_load_path  : This parameter is the path that leads to the directory of the element to be loaded.

    :param element_to_load_type: This parameter identifies the type of element to be loaded.

    :param checkpoint_or_best  : This parameter identifies the saved state's specific epoch to load.


    :return: Returns the corresponding state_dict or None, if no valid model_to_load_path and checkpoint_or_best was
             provided.
    """

    loaded_state = None

    if (model_to_load_path is not None and checkpoint_or_best is not None
            and isinstance(checkpoint_or_best, int)):

        # The directory that holds the saved states of the element to be loaded.
        load_path = model_to_load_path

        # Returns the names of the existing files in the load_path directory
        existing_files = os.listdir(load_path)

        # Loads the state associated with the best model, if a '.bst' file exists for this specific element.
        best_file_names = [
            file for file in existing_files
            if (file.endswith(".bst") and file.startswith(file_type))
        ]
        if (checkpoint_or_best == -1 and len(best_file_names) > 0):
            loaded_state = torch.load(load_path + best_file_names[0],
                                      map_location=device)

        # Loads the state associated with a specific checkpoint
        elif (checkpoint_or_best >= 0
              and os.path.exists(load_path + file_type + "." +
                                 str(checkpoint_or_best) + ".ckp")):
            loaded_state = torch.load(load_path + file_type + "." +
                                      str(checkpoint_or_best) + ".ckp",
                                      map_location=device)

        # The desired epoch has not yet been computed. Issue an error.
        else:
            split_path = get_path_components(load_path)
            if (split_path[-1] in list(_names.METRICS) + ["Loss"]):
                # This means that the state of a metric was meant to be loaded.
                raise ValueError("Tried to load " + split_path[-1] + "'s' (" +
                                 split_path[-2] + ") state dict for " +
                                 "invalid or inexisting epoch (epoch num: " +
                                 str(checkpoint_or_best) + ").")
            else:
                # This means that the state of either the MEM or the MMa was meant to be loaded.
                raise ValueError(
                    "Tried to load " + split_path[-1] +
                    "'s state dict for an invalid or nonexistent epoch " +
                    "(epoch num: " + str(checkpoint_or_best) + ").")

    return loaded_state
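
# Usage sketch (hypothetical paths and names): assuming checkpoints are saved as '<dir>/<element>.<epoch>.ckp'
# and best states as '<dir>/<element>*.bst', as the lookups above imply:
#   state = load_checkpoint_state('/runs/exp1/', 'MacroF1', -1)  # loads the best saved state
#   state = load_checkpoint_state('/runs/exp1/', 'MacroF1', 7)   # loads the state saved at epoch 7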
Example #4
    def __init__(self, file_name, setting, batch_size_xs, labels_to_load, num_workers_xs=0,
                 eval_all_descriptions=True, dataset_debug=None, full_prob_interp=False):
        """
        Instantiates a UW_RE_UVA dataset object for the specific split associated with 'file_name'.
        TODO: Missing variables.


        :param file_name            : The (almost complete) path to the data, in storage, pertaining this dataset.

        :param batch_size_xs        : The number of data instances (sentences x) (and respective labels) to load.

        :param labels_to_load       : A list that indicates what kind of labels are meant to be loaded by this dataset.

        :param num_workers_xs       : The number of sub-processes loading data instances from storage.

        :param eval_all_descriptions: Determines whether to evaluate against all relation descriptions or not.
        """

        # Input variables
        self._file_name             = file_name
        self._batch_size_xs         = batch_size_xs
        self._labels_to_load        = labels_to_load
        self._num_workers_xs        = num_workers_xs
        self._eval_all_descriptions = eval_all_descriptions


        # TODO: Always load the supervised labels, as they are necessary for evaluation in any of the splits.
        # TODO: Instead change the labels_to_load to refer uniquely to which unsupervised labels are meant to be loaded.

        # Instantiate the sub-dataset that will load the data instances and the corresponding labels.
        self._x_sentences_dataset = UW_RE_UVA_xs(file_name, setting, labels_to_load, dataset_debug, full_prob_interp)

        # Determine if this pertains to a train split. If so, we sample one relation description per relation in the split.
        self._train_split = self._x_sentences_dataset.train_split

        if (self._batch_size_xs == -1 or self._batch_size_xs > self._x_sentences_dataset.num_instances):
            if (self._batch_size_xs > self._x_sentences_dataset.num_instances):
                print_warning("Split: " + self._x_sentences_dataset.split + " | Requested Batch Size is bigger than " +
                              "available number of instances. Setting Batch Size to be equal to the available " +
                              "number of instances.")
            self._batch_size_xs = self._x_sentences_dataset.num_instances
        self._x_sentences_dataloader = DataLoader(self._x_sentences_dataset, batch_size=self._batch_size_xs,
                                                  shuffle=self._train_split, num_workers=num_workers_xs,
                                                  collate_fn=PadCollate(self._x_sentences_dataset.valid_labels))
        self._x_sentences_dataloader_iter = iter(self._x_sentences_dataloader)
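        # Note: shuffling is enabled only for the train split; PadCollate is assumed to pad variable-length
        # instances within each batch so they can be collated into tensors.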

        # We create a generator that allows us to sample from the x_sentences_dataset indefinitely.
        def get_subdataset_batch(dl):
            while (True):
                for batch in dl:
                    yield batch
        self._infinite_x_sentences_dl = get_subdataset_batch(self._x_sentences_dataloader)

        # Determine the correct length of this meta-dataset
        num_full_x_batches = len(self._x_sentences_dataset) // self._batch_size_xs
        equal_size_last_x_batch = len(self._x_sentences_dataset) % self._batch_size_xs == 0
        self._len_this_meta_dataset = num_full_x_batches + (0 if equal_size_last_x_batch else 1)
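        # E.g. (hypothetical numbers): 103 instances with a batch size of 10 yield 10 full batches plus 1 partial
        # batch, so this meta-dataset has length 11.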

        # Load the list of relations involved in this split (in the file there is one relation per line)
        self._relations_in_split = self._x_sentences_dataset.relations_in_split

        # Load the map between relations and the indices of relation_descriptions
        path_components = get_path_components(file_name)
        with open(join_path(path_components[:-6] + ["ELMO_rel_descs_idxs.map"]), 'rb') as f:
            self._relation_to_idxs_rel_description = pickle.load(f)

        # The map above always conveys the same information for a specific instance of a dataset split, so:
        self._descs_idxs = [self._relation_to_idxs_rel_description[rel] for rel in self._relations_in_split]

        # Remove duplicate indices (for relation descriptions that might describe more than one relation).
        self._descs_idxs_all = []
        for idxs_set in self._descs_idxs:
            for idx in idxs_set:
                if (idx not in self._descs_idxs_all):
                    self._descs_idxs_all.append(idx)

        # Indicates which indices are meant to be aggregated together, when comparing against multiple relation
        # descriptions for the same relation.
        self._aggregate = []
        for idxs_set in self._descs_idxs:
            map_to_output_idxs = []
            for idx in idxs_set:
                map_to_output_idxs.append(self._descs_idxs_all.index(idx))
            self._aggregate.append(torch.LongTensor(map_to_output_idxs))
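        # Illustrative sketch (hypothetical indices): with self._descs_idxs = [[3, 7], [7, 9]],
        # self._descs_idxs_all becomes [3, 7, 9] and self._aggregate becomes
        # [LongTensor([0, 1]), LongTensor([1, 2])], i.e. each relation's descriptions are mapped to their
        # positions in the deduplicated list.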
Example #5
    def __init__(self, file_name, setting, labels_to_load, dataset_debug=None, full_prob_interp=False):
        """
        Instantiates a UW_RE_UVA_xs dataset object for the specific split associated with 'file_name'.


        :param file_name     : The (almost complete) path to the data, in storage, pertaining this dataset.

        :param labels_to_load: A list that indicates what kind of labels are meant to be loaded by this dataset.
        """

        self._file_name        = file_name
        self._labels_to_load   = labels_to_load
        self._full_prob_interp = full_prob_interp

        path_components = get_path_components(file_name)
        debug_file = "DEBUG_" if path_components[-4] == 'DEBUG' else ""
        self._sentences_hdf5_file = join_path(path_components[:-4] + [debug_file + "sentences_to_ELMo.hdf5"])
        self._setting = setting
        self._split = path_components[-1]


        # Catch undesired configurations from the start.
        if (self._setting != 'N' and self._setting[0] != 'G' and full_prob_interp):
            raise RuntimeError("A full probabilistic interpretation is desired, but the chosen setting (" +
                               self._setting + ') is incompatible with such specification. Compatible settings are: ' +
                               "Normal (N); Generalised Any-Shot settings (GZS-(O/C), GFS-1, GFS-2, GFS-5, GFS-10).")


        # print('\n\n\n')
        DEBUG_EXPERIMENT = dataset_debug
        # print('DEBUG_EXPERIMENT:', dataset_debug)
        if (DEBUG_EXPERIMENT is not None and self._split == 'test'):
            self._file_name = join_path(get_path_components(self._file_name)[:-1] + ['val'])
            # print('FILE NAME:', self._file_name)

        # Get the instance indices (in the main hdf5 file) and determine the number of instances in this split.
        with open(self._file_name + ".idxs", 'rb') as f:
            loaded_instances = pickle.load(f)
            if (self._split != 'train'):
                # print(self._split, ' ORIGINAL INSTANCE INDICES', loaded_instances)
                if (DEBUG_EXPERIMENT == 'instances'):
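                    # Alternate the original instances between 'val' and 'test' ('val' takes the even positions,
                    # 'test' the odd ones).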
                    self._instances_indices = loaded_instances[(0 if self._split == 'val' else 1)::2]
                elif (DEBUG_EXPERIMENT == 'classes'):
                    seen_classes = set()
                    seen_instances = []
                    seen_labels = []
                    if (self._setting[0] == 'G'):
                        with open(join_path(path_components[:-1] + ["train_relations.txt"]), 'r',
                                  encoding='utf-8') as rels_f:
                            for line in rels_f:
                                seen_classes.add(line.strip())
                    with open(self._file_name + '.lbs', 'rb') as lbs_f:
                        temp_labels = pickle.load(lbs_f)
                    count_class_elements = {}
                    for inst_num, instance in enumerate(loaded_instances):
                        if (temp_labels[inst_num] not in seen_classes):
                            if (temp_labels[inst_num] not in count_class_elements):
                                count_class_elements[temp_labels[inst_num]] = [instance]
                            else:
                                count_class_elements[temp_labels[inst_num]].append(instance)
                        else:
                            seen_instances.append(instance)
                            seen_labels.append(temp_labels[inst_num])
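                    # Alternate the unseen classes (sorted by number of instances, descending) between 'val' and
                    # 'test', while keeping all seen-class instances in both splits.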
                    sorted_by_class_num_insts = sorted(count_class_elements.items(), key=lambda kv: len(kv[1]),
                                                       reverse=True)
                    classes_and_instances = sorted_by_class_num_insts[(1 if self._split == 'val' else 0)::2]

                    self._labels = [_class for _class, inst_nums in classes_and_instances
                                               for _ in inst_nums] + seen_labels
                    self._instances_indices = [inst_num for _, inst_nums in classes_and_instances
                                                            for inst_num in inst_nums] + seen_instances

                elif (DEBUG_EXPERIMENT is None):
                    self._instances_indices = loaded_instances
                else:
                    raise ValueError('Wrong DEBUG_EXPERIMENT value.')
            else:
                self._instances_indices = loaded_instances
            self._num_instances = len(self._instances_indices)

        self._valid_labels = []

        # Load the corresponding supervised labels (text-based labels, i.e. the actual relation names).
        if ('supervised_lbls' in labels_to_load):
            with open(self._file_name + '.lbs', 'rb') as f:
                loaded_labels = pickle.load(f)
                if (self._split != 'train'):
                    # print(self._split, ' ORIGINAL LABELS', loaded_labels)
                    if (DEBUG_EXPERIMENT == 'instances'):
                        self._labels = loaded_labels[(0 if self._split == 'val' else 1)::2]
                    elif (DEBUG_EXPERIMENT == 'classes'):
                        pass
                    elif (DEBUG_EXPERIMENT is None):
                        self._labels = loaded_labels
                    else:
                        raise ValueError('Wrong DEBUG_EXPERIMENT value.')
                else:
                    self._labels = loaded_labels
            self._valid_labels.append('supervised_lbls')

            # TODO: remove this when done debugging weird results on test set.
            set_labels = set(self._labels)

            # Load the list of classes involved in this split. For some settings (like Few-Shot settings or Generalised
            # Any-Shot settings) the classes of another split might be present in this specific split.
            self._relations_in_split = []
            self._relation_in_split_to_idx_map = {}
            self._is_seen_class_indicator = {}

            self._exclusive_split_relations = []
            with open(self._file_name + "_relations.txt", 'r', encoding='utf-8') as f:
                r_num = 0
                for line in f:
                # for r_num, line in enumerate(f):
                    relation = line.strip()
                    if (relation in set_labels):
                        self._exclusive_split_relations.append(relation)
                        self._relations_in_split.append(relation)
                        self._relation_in_split_to_idx_map[relation] = r_num
                        r_num += 1

            if (self._setting[0] == 'G' or self._setting[0] == 'F' or self._setting == 'ZS-C'):
                # If this is a Generalised setting and a validation or test split, get the train classes.
                if (self._setting[0] == 'G' and self._split != 'train'):
                    for relation in self._exclusive_split_relations:
                        self._is_seen_class_indicator[relation] = False

                    self._relations_from_another_split = []
                    with open(join_path(path_components[:-1] + ["train_relations.txt"]), 'r', encoding='utf-8') as f:
                        for r_num, line in enumerate(f):
                            relation = line.strip()
                            self._relations_from_another_split.append(relation)
                            self._relations_in_split.append(relation)
                            self._relation_in_split_to_idx_map[relation] = r_num + len(self._exclusive_split_relations)
                            self._is_seen_class_indicator[relation] = True
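                    # At this point the split's own 'unseen' relations occupy indices
                    # 0..len(self._exclusive_split_relations)-1 and the train ('seen') relations follow them.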

                # If this is a Few-Shot setting and a train split, get the validation and test classes.
                # Also, if this is a (Generalised) Zero-Shot Closed setting, get the validation and test classes.
                if ((self._setting[0] == 'F' or self._setting[1] == 'F' or
                     self._setting == 'ZS-C' or self._setting == 'GZS-C') and self._split == 'train'):
                    for relation in self._exclusive_split_relations:
                        self._is_seen_class_indicator[relation] = True

                    # TODO: This assumes there's a validation split, which might not be the case.
                    self._relations_from_another_split = []
                    for split in (['val', 'test'] if DEBUG_EXPERIMENT is None else ['val']):
                        with open(join_path(path_components[:-1] + [split + "_relations.txt"]), 'r',
                                  encoding='utf-8') as f:
                            for r_num, line in enumerate(f):
                                relation = line.strip()
                                self._relation_in_split_to_idx_map[relation] = (
                                    len(self._relations_from_another_split) + len(self._exclusive_split_relations))
                                self._relations_from_another_split.append(relation)
                                self._relations_in_split.append(relation)
                                self._is_seen_class_indicator[relation] = False


        # If this pertains to a train split, load the unsupervised sentence labels, if they have been requested.
        # TODO: maybe some of these checks can be done when reading labels_to_load with argparse.
        if ('u_sentence_lbls' in labels_to_load and self._split == 'train'):
            with open(self._file_name + '.ulbs', 'rb') as f:
                self._data_as_idxs = pickle.load(f)
            self._valid_labels.append('u_sentence_lbls')