Example #1
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug(
            "-------------------- TRAINING STEP  -------------------")
        # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
        #                           copy_initial_weights=False,
        #                           track_higher_grads=False) as (fpn, diffopt):
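        # the prototype memory is refreshed only every `prototype_update_freq` iterations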
        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        ### GET SUPPORT SET REPRESENTATIONS ###
        self.logger.debug("----------------- SUPPORT SET ----------------- ")
        representations, all_labels = self.get_representations(support_set[:1])
        representations_merged = torch.cat(representations)
        class_means, unique_labels = model_utils.get_class_means(
            representations_merged, all_labels)
        self._examples_seen += len(representations_merged)
        self.logger.debug(
            f"Examples seen increased by {len(representations_merged)}")

        ### UPDATE MEMORY ###
        if do_memory_update:
            memory_update = self.memory.update(class_means,
                                               unique_labels,
                                               logger=self.logger)
            updated_memory_representations = memory_update[
                "new_class_representations"]
            self.log_discounts(memory_update["class_discount"], unique_labels)
        ### DETERMINE WHAT IS USED AS THE PROTOTYPES ###
        if self.config.learner.prototypes == "class_means":
            prototypes = expand_class_representations(
                self.memory.class_representations, class_means, unique_labels)
        elif self.config.learner.prototypes == "memory":
            prototypes = updated_memory_representations
        else:
            raise AssertionError(
                "Prototype type not in {'class_means', 'memory'}, fix config file."
            )

        ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
        # self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)
        weight = 2 * prototypes  # NOTE: consider dividing by the number of dimensions, otherwise the logits can blow up
        bias = -(prototypes**2).sum(dim=1)
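        # x @ (2p) - ||p||^2 = -||x - p||^2 + ||x||^2, so up to a per-query constant these
        # logits equal the negative squared Euclidean distance to each prototype.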

        self.logger.debug("----------------- QUERY SET  ----------------- ")
        ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                query_representations = self.forward(text,
                                                     labels)["representation"]

                # distance query representations to prototypes (BATCH X N_PROTOTYPES)
                # distances = euclidean_dist(query_representations, prototypes)
                # logits = - distances
                logits = query_representations @ weight.T + bias
                loss = self.loss_fn(logits, labels)
                # log_probability = F.log_softmax(-distances, dim=1)
                # loss is negation of the log probability, index using the labels for each observation
                # loss = (- log_probability[torch.arange(len(log_probability)), labels]).mean()
                self.meta_optimizer.zero_grad()
                loss.backward()
                self.meta_optimizer.step()

                predictions = model_utils.make_prediction(logits.detach())
                # predictions = torch.tensor([inv_label_map[p.item()] for p in predictions])
                # to_print = pprint.pformat(list(map(lambda x: (x[0].item(), x[1].item(),
                #                         [round(z, 3) for z in x[2].tolist()]),
                #                         list(zip(labels, predictions, distances)))))
                self.logger.debug(
                    f"Unique Labels: {unique_labels.tolist()}\n"
                    # f"Labels, Indices, Predictions, Distances:\n{to_print}\n"
                    f"Loss:\n{loss.item()}\n"
                    f"Predictions:\n{predictions}\n")
                self.update_query_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)
                self.logger.debug(f"Examples seen increased by {len(text)}")

            # Meta optimizer step
            # self.meta_optimizer.step()
            # self.meta_optimizer.zero_grad()
        self.logger.debug(
            "-------------------- TRAINING STEP END  -------------------")
Example #2
    def few_shot_testing(self,
                         train_dataset,
                         eval_dataset,
                         increment_counters=False,
                         split="test"):
        """
        Allow the model to train on a small amount of datapoints at a time. After every training step,
        evaluate on many samples that haven't been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        """
        self.logger.info(
            f"few shot testing on dataset {self.config.testing.eval_dataset} "
            f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(
            train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []

        def add_none(iterator):
            yield None
            for x in iterator:
                yield x

        shifted_dataloader = add_none(train_dataloader)
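        # With a re-iterable DataLoader, zipping the shifted and original streams pairs each
        # query batch with the preceding batch as support: (None, batch_1), (batch_1, batch_2), ...,
        # so the prototypes never see the current query batch.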
        # prototypes = self.memory.class_representations
        for i, (support_set, (query_text, query_labels,
                              datasets)) in enumerate(
                                  zip(shifted_dataloader, train_dataloader)):
            query_labels = torch.tensor(query_labels).to(self.device)
            # happens on the first one
            # prototypes = self.memory.class_representations
            if support_set is None:
                prototypes = self.memory.class_representations
            else:
                support_text, support_labels, _ = support_set
                support_labels = torch.tensor(support_labels).to(self.device)
                support_representations = self.forward(
                    support_text, support_labels)["representation"]
                support_class_means, unique_labels = model_utils.get_class_means(
                    support_representations, support_labels)
                memory_update = self.memory.update(support_class_means,
                                                   unique_labels,
                                                   logger=self.logger)
                updated_memory_representations = memory_update[
                    "new_class_representations"]
                self.log_discounts(memory_update["class_discount"],
                                   unique_labels,
                                   few_shot_examples_seen=(i + 1) *
                                   self.config.testing.few_shot_batch_size)
                prototypes = updated_memory_representations
                if self.config.learner.few_shot_detach_prototypes:
                    prototypes = prototypes.detach()
            weight = 2 * prototypes
            bias = -(prototypes**2).sum(dim=1)
            query_representations = self.forward(
                query_text, query_labels)["representation"]
            logits = query_representations @ weight.T + bias

            loss = self.loss_fn(logits, query_labels)

            self.meta_optimizer.zero_grad()
            loss.backward()
            self.meta_optimizer.step()

            predictions = model_utils.make_prediction(logits.detach())
            all_predictions.extend(predictions.tolist())
            all_labels.extend(query_labels.tolist())
            dataset_results = self.evaluate(dataloader=eval_dataloader)
            self.log_few_shot(all_predictions,
                              all_labels,
                              datasets,
                              dataset_results,
                              increment_counters,
                              query_text,
                              i,
                              split=split)
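            # clear the running prediction/label buffers every `mini_batch_size` few-shot examples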
            if (i * self.config.testing.few_shot_batch_size
                ) % self.mini_batch_size == 0 and i > 0:
                all_predictions, all_labels = [], []
        self.few_shot_end()
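
The `add_none` / `zip` pattern above pairs every query batch with the batch that preceded it, so the prototypes used to score a query batch were updated only from earlier data. A minimal standalone illustration of the pairing, assuming each `zip` argument gets its own fresh iteration over the batches (as with a standard `DataLoader`):

def add_none(iterator):
    yield None
    for x in iterator:
        yield x

batches = ["batch_1", "batch_2", "batch_3"]
pairs = list(zip(add_none(iter(batches)), iter(batches)))
print(pairs)  # [(None, 'batch_1'), ('batch_1', 'batch_2'), ('batch_2', 'batch_3')]
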
Example #3
    def average_accuracy(self, datasets, split, train_datasets=None):
        results = {}
        accuracies, precisions, recalls, f1s = [], [], [], []

        if self.config.testing.n_samples_before_average_evaluate is None:
            self.config.testing.n_samples_before_average_evaluate = 80
        training_amounts = (0, self.config.testing.n_samples_before_average_evaluate)
        metrics_name = split + "_evaluation"
        for training_amount in training_amounts:
            # reset the per-pass accumulators so the second pass does not average in results from the first
            results = {}
            accuracies, precisions, recalls, f1s = [], [], [], []
            self.save_checkpoint(file_name=TEMP_CHECKPOINT, save_optimizer_state=True, delete_previous=False)
            metrics_key_average = "average" + f"_{training_amount}" * (training_amount > 0)
            # this has already been evaluated, no need to do it again
            # if metrics_key_average in self.metrics[metrics_name]:
            #     continue
            # Before evaluating average accuracy, allow the model to see n examples again
            if train_datasets is not None and training_amount > 0:
                train_dataloader = iter(DataLoader(ConcatDataset(train_datasets.values()), shuffle=True,
                                            batch_size=self.mini_batch_size))
                n_batches = max(training_amount // self.mini_batch_size, 1)
                self.metrics["n_samples_before_average_evaluate"] = n_batches * self.mini_batch_size
                self.logger.info(f"Before evaluating average accuracy, train on {n_batches * self.mini_batch_size}"
                                " samples from all datasets.")
                if self.type in BASE_METALEARNERS:
                    support_set, _ = self.get_support_set(train_dataloader, n_updates=n_batches)
                    self.inner_optimizer.zero_grad()
                    for text, labels in support_set:
                        labels = torch.tensor(labels).to(self.device)
                        # labels = labels.to(self.device)
                        output = self.forward(text, labels)
                        loss = self.loss_fn(output["logits"], labels)
                        loss.backward()
                        self.inner_optimizer.step()
                elif self.type == "memory_protomaml":
                    self.inner_optimizer.zero_grad()
                    support_set, _ = self.get_support_set(train_dataloader, n_updates=n_batches)
                    # ### GET SUPPORT SET REPRESENTATIONS ###
                    # representations, all_labels = self.get_representations(support_set)
                    # representations_merged = torch.cat(representations)
                    # class_means, unique_labels = self.get_class_means(representations_merged, all_labels)
                    # ### UPDATE MEMORY ###
                    # self.update_memory(class_means, unique_labels)
                    ### DETERMINE WHAT IS USED AS THE PROTOTYPES ###
                    ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
                    prototypes = self.memory.class_representations
                    self.init_prototypical_classifier(prototypes)
                    for text, labels in support_set:
                        labels = torch.tensor(labels).to(self.device)
                        # labels = labels.to(self.device)
                        output = self.forward(text, labels)
                        loss = self.loss_fn(output["logits"], labels)
                        loss.backward()
                        self.inner_optimizer.step()
                elif self.type == "prototypical":
                    # support_set, _ = self.get_support_set(train_dataloader, n_updates=n_batches)
                    def add_none(iterator):
                        yield None
                        for x in iterator:
                            yield x
                    shifted_dataloader = add_none(train_dataloader)
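                    # note: train_dataloader here is a single iterator (wrapped in iter(...) above) shared by both
                    # zip arguments, so batches are drawn alternately: (None, b1), (b2, b3), (b4, b5), ...,
                    # i.e. each support batch and query batch are consecutive, disjoint batches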
                    prototypes = self.memory.class_representations
                    for i, (support_set, (query_text, query_labels, _)) in enumerate(zip(shifted_dataloader, train_dataloader)):
                        query_labels = torch.tensor(query_labels).to(self.device)
                        # happens on the first one
                        if support_set is None:
                            prototypes = self.memory.class_representations
                        else:
                            support_text, support_labels, _ = support_set
                            support_labels = torch.tensor(support_labels).to(self.device)
                            support_representations = self.forward(support_text, support_labels)["representation"]
                            support_class_means, unique_labels = model_utils.get_class_means(support_representations, support_labels)
                            updated_memory_representations = self.memory.update(support_class_means, unique_labels, logger=self.logger)
                            prototypes = updated_memory_representations["new_class_representations"]
                        weight = 2 * prototypes
                        bias = - (prototypes ** 2).sum(dim=1)
                        query_representations = self.forward(query_text, query_labels)["representation"]
                        logits = query_representations @ weight.T + bias
                        loss = self.loss_fn(logits, query_labels)

                        self.meta_optimizer.zero_grad()
                        loss.backward()
                        self.meta_optimizer.step()
                        if i == n_batches - 1:
                            break
                else:
                    for _ in range(n_batches):
                        text, labels, _ = next(train_dataloader)
                        self.training_step(text, labels)
            for dataset_name, dataset in datasets.items():
                self.logger.info("Testing on {}".format(dataset_name))
                if self.config.testing.average_validation_size is not None:
                    dataset = dataset.sample(min(len(dataset), self.config.testing.average_validation_size))
                test_dataloader = DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)
                dataset_results = self.evaluate(dataloader=test_dataloader)
                accuracies.append(dataset_results["accuracy"])
                precisions.append(dataset_results["precision"])
                recalls.append(dataset_results["recall"])
                f1s.append(dataset_results["f1"])
                results[dataset_name] = dataset_results

            mean_results = {
                "accuracy": np.mean(accuracies),
                "precision": np.mean(precisions),
                "recall": np.mean(recalls),
                "f1": np.mean(f1s)
            }
            self.logger.info("Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                        "F1 score = {:.4f}".format(
                            mean_results["accuracy"], mean_results["precision"], mean_results["recall"],
                            mean_results["f1"]
                        ))
            self.metrics[metrics_name]["individual" + f"_{training_amount}" * (training_amount > 0)] = results
            self.metrics[metrics_name][metrics_key_average] = mean_results
            if self.config.wandb:
                wandb.log({
                    split + "_testing_average_accuracy" + f"_{training_amount}" * (training_amount > 0): mean_results["accuracy"]
                })
            if train_datasets is not None and self.config.testing.n_samples_before_average_evaluate > 0:
                self.load_checkpoint(TEMP_CHECKPOINT, load_optimizer_state=True)
                # delete temp checkpoint
                (self.checkpoint_dir / TEMP_CHECKPOINT).unlink()