Example #1
    def evaluate(self, dataloader, update_memory=False):
        self.set_eval()
        all_losses, all_predictions, all_labels = [], [], []

        self.logger.info("Starting evaluation...")
        for i, (text, labels, datasets) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            with torch.no_grad():
                logits, key_logits = self.forward(text,
                                                  labels,
                                                  update_memory=update_memory)
                loss = self.loss_fn(logits, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(logits.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            if i % 20 == 0:
                self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.info(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       results["accuracy"],
                                       results["precision"], results["recall"],
                                       results["f1"]))

        return results
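
These evaluation and logging loops all lean on two helpers from model_utils. Below is a minimal sketch of what they are assumed to do (arg-max predictions plus sklearn metrics); the repository's own implementations may differ, and note that Examples #11 and #16 unpack calculate_metrics as a four-element tuple, so those snippets appear to target a different revision of the helper.

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


def make_prediction(logits):
    # Predicted class = arg-max over the logit dimension.
    return logits.argmax(dim=-1)


def calculate_metrics(predictions, labels, average="macro"):
    # Accuracy / precision / recall / F1 packaged as the dict the loggers expect.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average=average, zero_division=0)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
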
Example #2
    def evaluate(self, dataloader):
        self.set_eval()
        all_losses, all_predictions, all_labels = [], [], []

        for i, (text, labels, _) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            # if i % 20 == 0:
            #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        metrics = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       metrics["accuracy"],
                                       metrics["precision"], metrics["recall"],
                                       metrics["f1"]))

        return {
            "accuracy": metrics["accuracy"],
            "precision": metrics["precision"],
            "recall": metrics["recall"],
            "f1": metrics["f1"]
        }
Example #3
 def log_few_shot(self, all_predictions, all_labels, datasets, dataset_results, increment_counters, text,
                  few_shot_batch, split="test"):
     """Few shot preparation code that isn't specific to any learner"""
     metrics_entry = split + "_evaluation"
     test_results = {
         "examples_seen": few_shot_batch * self.config.testing.few_shot_batch_size,
         "examples_seen_total": self.examples_seen(),
         "accuracy": dataset_results["accuracy"],
         "task": datasets[0]
     }
     if (few_shot_batch * self.config.testing.few_shot_batch_size) % self.mini_batch_size == 0 and few_shot_batch > 0:
         online_metrics = model_utils.calculate_metrics(all_predictions, all_labels)
         train_results = {
             "examples_seen": few_shot_batch * self.config.testing.few_shot_batch_size,
             "examples_seen_total": self.examples_seen(),
             "accuracy": online_metrics["accuracy"],
             "task": datasets[0]  # assume whole batch is from same task
         }
         self.metrics[metrics_entry]["few_shot_training"][-1].append(train_results)
         if increment_counters:
             self.metrics["online"].append({
                 "accuracy": online_metrics["accuracy"],
                 "examples_seen": self.examples_seen(),
                 "task": datasets[0]
             })
     if increment_counters:
         self._examples_seen += len(text)
     self.metrics[metrics_entry]["few_shot"][-1].append(test_results)
     self.write_metrics()
     if self.config.wandb:
         # rename the accuracy key so each few-shot evaluation gets its own wandb metric
         test_results = test_results.copy()
         test_results[f"few_shot_{split}_accuracy_{self.few_shot_counter}"] = test_results.pop("accuracy")
         wandb.log(test_results)
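
The nested indexing in log_few_shot implies a particular layout for self.metrics. A hedged reconstruction of that layout, inferred only from the keys used above (the real structure may hold more fields):

# Inferred from the call sites, not taken from the repository:
metrics = {
    "online": [],                       # one entry per training batch
    "test_evaluation": {
        "few_shot": [[]],               # one inner list per few-shot evaluation run
        "few_shot_training": [[]],
    },
}
metrics["test_evaluation"]["few_shot"][-1].append(
    {"examples_seen": 16, "examples_seen_total": 1024, "accuracy": 0.8, "task": "task_name"}
)
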
Example #4
    def log(self):
        """Log results during training to console and optionally other outputs

        Parameters
        ---
        metrics: dict mapping metric names to their values
        """
        metrics = model_utils.calculate_metrics(self.tracker["predictions"],
                                                self.tracker["labels"])
        self.logger.info(
            "Iteration {} - Metrics: Loss = {:.4f}, key loss: {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(self.current_iter + 1,
                                       np.mean(self.tracker["losses"]),
                                       np.mean(self.tracker["key_losses"]),
                                       metrics["accuracy"],
                                       metrics["precision"], metrics["recall"],
                                       metrics["f1"]))
        if self.config.wandb:
            wandb.log({
                "accuracy": metrics["accuracy"],
                "precision": metrics["precision"],
                "recall": metrics["recall"],
                "f1": metrics["f1"],
                "loss": np.mean(self.tracker["losses"]),
                "key_loss": np.mean(self.tracker["key_losses"]),
                "examples_seen": self.examples_seen()
            })
        self.reset_tracker()
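
log() reads from a tracker that update_tracker fills and reset_tracker clears. A hypothetical sketch of that pair, matching the keys referenced above; the real learner may track more quantities (Example #8, for instance, also tracks per-encoder key predictions and reconstruction errors).

class TrackerMixin:
    """Hypothetical tracker helpers matching the keys used by log() above."""

    def reset_tracker(self):
        self.tracker = {"losses": [], "key_losses": [], "predictions": [], "labels": []}

    def update_tracker(self, output, predictions, labels):
        # store plain floats so np.mean works directly on the lists
        self.tracker["losses"].append(float(output["loss"]))
        self.tracker["key_losses"].append(float(output["key_loss"]))
        self.tracker["predictions"].extend(predictions.tolist())
        self.tracker["labels"].extend(labels.tolist())
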
Example #5
    def training(self, datasets, **kwargs):
        # train_datasets = {dataset_name: dataset for dataset_name, dataset in zip(datasets["order"], datasets["train"])}
        train_datasets = datasets_dict(datasets["train"], datasets["order"])
        val_datasets = datasets_dict(datasets["val"], datasets["order"])

        samples_per_task = self.config.learner.samples_per_task
        order = self.config.task_order if self.config.task_order is not None else datasets["order"]
        n_samples = [samples_per_task] * len(order) if samples_per_task is None else samples_per_task
        dataset = get_continuum(train_datasets, order=order, n_samples=n_samples)
        dataloader = DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)

        for text, labels, tasks in dataloader:  # "tasks", to avoid shadowing the datasets argument
            output = self.training_step(text, labels)
            predictions = model_utils.make_prediction(output["logits"].detach())
            # for logging
            key_predictions = [
                model_utils.make_prediction(key_logits.detach()) for key_logits in output["key_logits"]
            ]
            # self.logger.debug(f"accuracy prediction from key embedding: {key_metrics['accuracy']}")

            self.update_tracker(output, predictions, key_predictions, labels)
            online_metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            self.metrics["online"].append({
                "accuracy": online_metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": datasets[0]  # assumes whole batch is from same task
            })
            if self.current_iter % self.log_freq == 0:
                self.log()
                self.write_metrics()
            if self.current_iter % self.validate_freq == 0:
                self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
            self.current_iter += 1
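
datasets_dict and get_continuum come from the repository's data utilities. A plausible minimal version of both, assuming each entry of datasets["train"] is a torch Dataset, that "order" lists task names in training order, and that a None entry in n_samples means "use the whole task" (which is what the [samples_per_task] * len(order) branch above produces when samples_per_task is None):

from torch.utils.data import ConcatDataset, Subset


def datasets_dict(datasets, order):
    # Pair task names with their datasets positionally.
    return {name: dataset for name, dataset in zip(order, datasets)}


def get_continuum(train_datasets, order, n_samples):
    # Concatenate the tasks in the given order, optionally truncating each one.
    parts = []
    for name, n in zip(order, n_samples):
        dataset = train_datasets[name]
        if n is not None:
            dataset = Subset(dataset, range(min(n, len(dataset))))
        parts.append(dataset)
    return ConcatDataset(parts)
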
Example #6
    def evaluate(self, dataloader, prediction_network=None):
        # NOTE: an optional support-set adaptation step (inner-loop updates on samples
        # drawn from memory via higher.innerloop_ctx, as in the evaluate() variant of
        # Example #13 below) was previously used here and is currently disabled.

        self.set_eval()
        prototypes = self.memory.class_representations
        # Linear layer equivalent to nearest-prototype classification:
        # logits = x @ (2 P).T - ||P||^2 = ||x||^2 - ||x - P||^2 (row-constant term dropped)
        weight = 2 * prototypes
        bias = -(prototypes ** 2).sum(dim=1)
        all_losses, all_predictions, all_labels = [], [], []
        for i, (text, labels, _) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            with torch.no_grad():
                representations = self.forward(text, labels)["representation"]
                logits = representations @ weight.T + bias
                loss = self.loss_fn(logits, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(logits.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug("Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"],
                    results["precision"], results["recall"], results["f1"]))
        return results
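
The weight/bias construction above turns nearest-prototype classification into a single linear layer: x @ (2P).T - ||P||^2 = ||x||^2 - ||x - P||^2, and since ||x||^2 is the same for every class the arg-max is unchanged. A short self-contained check of that equivalence (illustration only, not repository code):

import torch

torch.manual_seed(0)
x = torch.randn(4, 16)            # batch of representations
prototypes = torch.randn(5, 16)   # one prototype per class

logits = x @ (2 * prototypes).T - (prototypes ** 2).sum(dim=1)
squared_distances = torch.cdist(x, prototypes) ** 2

# Same class ranking as nearest-prototype classification.
assert torch.equal(logits.argmax(dim=1), squared_distances.argmin(dim=1))
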
Example #7
    def meta_training_log(self):
        """Logs data during training for meta learners."""
        if len(self.tracker["support_loss"]) > 0:
            support_loss = np.mean(self.tracker["support_loss"])
        else:
            support_loss = np.nan
        query_loss = np.mean(self.tracker["query_loss"])
        if len(self.tracker["support_predictions"]) > 0:
            support_metrics = model_utils.calculate_metrics(self.tracker["support_predictions"], self.tracker["support_labels"])
        else:
            support_metrics = collections.defaultdict(lambda: np.nan)
        query_metrics = model_utils.calculate_metrics(self.tracker["query_predictions"], self.tracker["query_labels"])

        self.logger.debug(
            f"Episode {self.current_iter + 1} Support set: Loss = {support_loss:.4f}, "
            f"accuracy = {support_metrics['accuracy']:.4f}, precision = {support_metrics['precision']:.4f}, "
            f"recall = {support_metrics['recall']:.4f}, F1 score = {support_metrics['f1']:.4f}"
        )
        self.logger.debug(
            f"Episode {self.current_iter + 1} -- Examples seen: {self.examples_seen()} -- Query set: Loss = {query_loss:.4f}, "
            f"accuracy = {query_metrics['accuracy']:.4f}, precision = {query_metrics['precision']:.4f}, "
            f"recall = {query_metrics['recall']:.4f}, F1 score = {query_metrics['f1']:.4f}"
        )
        if self.config.wandb:
            wandb.log({
                "support_accuracy": support_metrics['accuracy'],
                "support_precision": support_metrics['precision'],
                "support_recall": support_metrics['recall'],
                "support_f1": support_metrics['f1'],
                "support_loss": support_loss,
                "query_accuracy": query_metrics['accuracy'],
                "query_precision": query_metrics['precision'],
                "query_recall": query_metrics['recall'],
                "query_f1": query_metrics['f1'],
                "query_loss": query_loss,
                "examples_seen": self.examples_seen()
            })
        self.reset_tracker()
Example #8
    def log(self):
        """Log results during training to console and optionally other outputs

        Parameters
        ---
        metrics: dict mapping metric names to their values
        """
        loss = np.mean(self.tracker["losses"])
        key_losses = [np.mean(key_losses) for key_losses in self.tracker["key_losses"]]
        reconstruction_errors = [np.mean(reconstruction_errors) for reconstruction_errors in self.tracker["reconstruction_errors"]]
        metrics = model_utils.calculate_metrics(self.tracker["predictions"], self.tracker["labels"])
        key_metrics = [
            model_utils.calculate_metrics(key_predictions, self.tracker["labels"])
            for key_predictions in self.tracker["key_predictions"]
        ]
        key_accuracy_str = [f'{km["accuracy"]:.4f}' for km in key_metrics]
        self.logger.info(
            f"Iteration {self.current_iter + 1} - Task = {self.metrics[-1]['task']} - Metrics: Loss = {loss:.4f}, "
            f"key loss = {[f'{key_loss:.4f}' for key_loss in key_losses]}, "
            f"reconstruction error = {[f'{reconstruction_error:.4f}' for reconstruction_error in reconstruction_errors]}, "
            f"accuracy = {metrics['accuracy']:.4f} - "
            f"key accuracy = {key_accuracy_str}"
        )
        if self.config.wandb:
            log = {
                "accuracy": metrics["accuracy"],
                "precision": metrics["precision"],
                "recall": metrics["recall"],
                "f1": metrics["f1"],
                "loss": loss,
                "examples_seen": self.examples_seen()
            }
            for i, dim in enumerate(self.key_dim):
                log[f"key_accuracy_encoder_{i}_dim_{dim}"] = key_metrics[i]["accuracy"]
                log[f"key_loss_encoder_{i}_dim_{dim}"] = key_losses[i]
                log[f"reconstruction_error_encoder_{i}_dim_{dim}"] = reconstruction_errors[i]
            wandb.log(log)
        self.reset_tracker()
Example #9
    def train(self,
              dataloader=None,
              datasets=None,
              dataset_name=None,
              max_samples=None):
        val_datasets = datasets_dict(datasets["val"], datasets["order"])
        replay_freq, replay_steps = self.replay_parameters(metalearner=False)

        episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well
        for _ in range(self.n_epochs):
            if max_samples is not None and episode_samples_seen >= max_samples:
                break  # the break at the bottom only exits the dataloader loop
            for text, labels, tasks in dataloader:
                output = self.training_step(text, labels)
                task = tasks[0]  # assumes whole batch is from same task

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_tracker(output, predictions, labels)

                metrics = model_utils.calculate_metrics(
                    self.tracker["predictions"], self.tracker["labels"])
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task
                }
                self.metrics["online"].append(online_metrics)
                if dataset_name is not None and dataset_name == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                if self.current_iter % self.log_freq == 0:
                    self.log()
                    self.write_metrics()
                if self.current_iter % self.validate_freq == 0:
                    self.validate(
                        val_datasets,
                        n_samples=self.config.training.n_validation_samples)
                if self.replay_rate != 0 and (self.current_iter +
                                              1) % replay_freq == 0:
                    self.replay_training_step(replay_steps,
                                              episode_samples_seen,
                                              max_samples)
                self.memory.write_batch(text, labels)
                self._examples_seen += len(text)
                episode_samples_seen += len(text)
                self.current_iter += 1
                if max_samples is not None and episode_samples_seen >= max_samples:
                    break
Example #10
 def log(self):
     metrics = model_utils.calculate_metrics(self.tracker["predictions"],
                                             self.tracker["labels"])
     self.logger.info(
         "Iteration {} - Metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
         "F1 score = {:.4f}".format(self.current_iter + 1,
                                    np.mean(self.tracker["losses"]),
                                    metrics["accuracy"],
                                    metrics["precision"], metrics["recall"],
                                    metrics["f1"]))
     if self.config.wandb:
         wandb.log({
             "accuracy": metrics["accuracy"],
             "precision": metrics["precision"],
             "recall": metrics["recall"],
             "f1": metrics["f1"],
             "loss": np.mean(self.tracker["losses"]),
             "examples_seen": self.examples_seen()
         })
     self.reset_tracker()
Example #11
 def write_log(self, all_predictions, all_labels, all_losses, data_length):
     acc, prec, rec, f1 = model_utils.calculate_metrics(
         all_predictions, all_labels)
     time_per_iteration, estimated_time_left = self.time_metrics(
         data_length)
     self.logger.info(
         "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\nMetrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
         "F1 score = {:.4f}".format(self.current_iter + 1, data_length,
                                    (self.current_iter + 1) / data_length *
                                    100,
                                    time_per_iteration, estimated_time_left,
                                    np.mean(all_losses), acc, prec, rec,
                                    f1))
     if self.config.wandb:
         n_examples_seen = (self.current_iter + 1) * self.mini_batch_size
         wandb.log({
             "accuracy": acc,
             "precision": prec,
             "recall": rec,
             "f1": f1,
             "loss": np.mean(all_losses),
             "examples_seen": n_examples_seen
         })
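
time_metrics is assumed to return seconds per iteration and a human-readable estimate of remaining time. A hedged sketch consistent with how it is called above; the actual helper may track timing differently (Example #16, for instance, resets self.start_time at every logging interval):

import datetime
import time


class TimingMixin:
    """Hypothetical timing helper; not the repository's implementation."""

    def __init__(self):
        self.start_time = time.time()
        self.current_iter = 0

    def time_metrics(self, data_length):
        # average seconds per iteration so far, plus an ETA string for the remainder
        iterations_done = self.current_iter + 1
        time_per_iteration = (time.time() - self.start_time) / iterations_done
        seconds_left = max(data_length - iterations_done, 0) * time_per_iteration
        return time_per_iteration, str(datetime.timedelta(seconds=int(seconds_left)))
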
Example #12
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()
        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            # Inner loop
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text, labels, fpn)
                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)
                self.memory.write_batch(text, labels)

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_support_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)

            # Outer loop
            if query_set is not None:
                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    self.update_meta_gradients(loss, fpn)

                    predictions = model_utils.make_prediction(
                        output["logits"].detach())
                    self.update_query_tracker(loss, predictions, labels)
                    metrics = model_utils.calculate_metrics(
                        predictions.tolist(), labels.tolist())
                    online_metrics = {
                        "accuracy": metrics["accuracy"],
                        "examples_seen": self.examples_seen(),
                        "task": task if task is not None else "none"
                    }
                    self.metrics["online"].append(online_metrics)
                    if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                        self.metrics["eval_task_first_encounter"].append(
                            online_metrics)
                    self._examples_seen += len(text)

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()
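
The inner/outer structure above follows the usual higher pattern: diffopt.step adapts a functional copy of the network on the support set, and the query loss then drives the meta update (here via a custom update_meta_gradients, since track_higher_grads=False makes this a first-order variant). Below is a minimal self-contained illustration of the default second-order version of that pattern, using a toy linear model rather than the repository's networks:

import torch
import torch.nn as nn
import higher

model = nn.Linear(4, 2)
inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

support_x, support_y = torch.randn(8, 4), torch.randint(0, 2, (8,))
query_x, query_y = torch.randn(8, 4), torch.randint(0, 2, (8,))

meta_opt.zero_grad()
with higher.innerloop_ctx(model, inner_opt,
                          copy_initial_weights=False) as (fmodel, diffopt):
    # Inner loop: adapt the functional copy on the support set.
    diffopt.step(loss_fn(fmodel(support_x), support_y))
    # Outer loop: the query loss backpropagates through the inner update
    # and accumulates gradients in the original parameters of `model`.
    loss_fn(fmodel(query_x), query_y).backward()
meta_opt.step()
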
Example #13
    def evaluate(self, dataloader, prediction_network=None):
        if self.config.learner.evaluation_support_set:
            support_set = []
            for _ in range(self.config.learner.updates):
                text, labels = self.memory.read_batch(
                    batch_size=self.mini_batch_size)
                support_set.append((text, labels))

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            if self.config.learner.evaluation_support_set:
                self.set_train()
                support_prediction_network = fpn
                # Inner loop
                task_predictions, task_labels = [], []
                support_loss = []
                for text, labels in support_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    diffopt.step(loss)

                    pred = model_utils.make_prediction(
                        output["logits"].detach())
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(labels.tolist())
                results = model_utils.calculate_metrics(
                    task_predictions, task_labels)
                self.logger.info(
                    "Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(support_loss),
                                               results["accuracy"],
                                               results["precision"],
                                               results["recall"],
                                               results["f1"]))
                self.set_eval()
            else:
                support_prediction_network = self.pn
            if prediction_network is None:
                prediction_network = support_prediction_network

            self.set_eval()
            all_losses, all_predictions, all_labels = [], [], []
            for i, (text, labels, datasets) in enumerate(dataloader):
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text,
                                      labels,
                                      prediction_network,
                                      no_grad=True)
                loss = self.loss_fn(output["logits"], labels)
                loss = loss.item()
                pred = model_utils.make_prediction(output["logits"].detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                # if i % 20 == 0:
                #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       results["accuracy"],
                                       results["precision"], results["recall"],
                                       results["f1"]))
        return results
Example #14
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug(
            "-------------------- TRAINING STEP  -------------------")

        ### GET SUPPORT SET REPRESENTATIONS ###
        with torch.no_grad():
            representations, all_labels = self.get_representations(
                support_set[:1])
            representations_merged = torch.cat(representations)
            class_means, unique_labels = self.get_class_means(
                representations_merged, all_labels)

        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        if do_memory_update:
            ### UPDATE MEMORY ###
            updated_memory_representations = self.memory.update(
                class_means, unique_labels, logger=self.logger)

        ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
        if self.config.learner.prototypes == "class_means":
            prototypes = class_means.detach()
        elif self.config.learner.prototypes == "memory":
            prototypes = self.memory.class_representations  # doesn't track prototype gradients
        else:
            raise AssertionError(
                "Prototype type not in {'class_means', 'memory'}, fix config file."
            )

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
            self.init_prototypical_classifier(prototypes,
                                              linear_module=fpn.linear)

            self.logger.debug(
                "----------------- SUPPORT SET ----------------- ")
            ### TRAIN LINEAR CLASSIFIER ON SUPPORT SET ###
            # Inner loop
            for i, (text, labels) in enumerate(support_set):
                self.logger.debug(
                    f"----------------- {i}th Update ----------------- ")
                labels = torch.tensor(labels).to(self.device)
                # if i == 0:
                #     output = {
                #         "representation": representations[0],
                #         "logits": fpn(representations[0], out_from="linear")
                #     }
                # else:
                output = self.forward(text, labels, fpn)

                # for logging purposes
                prototype_distances = (output["representation"].unsqueeze(1) -
                                       prototypes).norm(dim=-1)
                closest_dists, closest_classes = prototype_distances.topk(
                    3, largest=False)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 2) for z in x[2].tolist()]),
                            list(zip(labels, closest_classes,
                                     closest_dists)))))
                self.logger.debug(
                    f"True labels, closest prototypes, and distances:\n{to_print}"
                )

                topk = output["logits"].topk(5, dim=1)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 3) for z in x[2].tolist()]),
                            list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(
                    f"(label, topk_classes, topk_logits) before update:\n{to_print}"
                )

                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)

                # see how much linear classifier has changed
                with torch.no_grad():
                    topk = fpn(output["representation"],
                               out_from="linear").topk(5, dim=1)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 3) for z in x[2].tolist()]),
                            list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(
                    f"(label, topk_classes, topk_logits) after update:\n{to_print}"
                )

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_support_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)

            self.logger.debug(
                "----------------- QUERY SET  ----------------- ")
            ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
            # Outer loop
            if query_set is not None:
                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, prediction_network=fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    self.update_meta_gradients(loss, fpn)

                    topk = output['logits'].topk(5, dim=1)
                    to_print = pprint.pformat(
                        list(
                            map(
                                lambda x: (x[0].item(), x[1].tolist(
                                ), [round(z, 3) for z in x[2].tolist()]),
                                list(zip(labels, topk[1], topk[0])))))
                    self.logger.debug(
                        f"(label, topk_classes, topk_logits):\n{to_print}")

                    predictions = model_utils.make_prediction(
                        output["logits"].detach())
                    self.update_query_tracker(loss, predictions, labels)
                    metrics = model_utils.calculate_metrics(
                        predictions.tolist(), labels.tolist())
                    online_metrics = {
                        "accuracy": metrics["accuracy"],
                        "examples_seen": self.examples_seen(),
                        "task": task if task is not None else "none"
                    }
                    self.metrics["online"].append(online_metrics)
                    if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                        self.metrics["eval_task_first_encounter"].append(
                            online_metrics)
                    self._examples_seen += len(text)

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()
        self.logger.debug(
            "-------------------- TRAINING STEP END  -------------------")
Example #15
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug(
            "-------------------- TRAINING STEP  -------------------")
        # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
        #                           copy_initial_weights=False,
        #                           track_higher_grads=False) as (fpn, diffopt):
        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        ### GET SUPPORT SET REPRESENTATIONS ###
        self.logger.debug("----------------- SUPPORT SET ----------------- ")
        representations, all_labels = self.get_representations(support_set[:1])
        representations_merged = torch.cat(representations)
        class_means, unique_labels = model_utils.get_class_means(
            representations_merged, all_labels)
        self._examples_seen += len(representations_merged)
        self.logger.debug(
            f"Examples seen increased by {len(representations_merged)}")

        ### UPDATE MEMORY ###
        if do_memory_update:
            memory_update = self.memory.update(class_means,
                                               unique_labels,
                                               logger=self.logger)
            updated_memory_representations = memory_update[
                "new_class_representations"]
            self.log_discounts(memory_update["class_discount"], unique_labels)
        ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
        if self.config.learner.prototypes == "class_means":
            prototypes = expand_class_representations(
                self.memory.class_representations, class_means, unique_labels)
        elif self.config.learner.prototypes == "memory":
            prototypes = updated_memory_representations
        else:
            raise AssertionError(
                "Prototype type not in {'class_means', 'memory'}, fix config file."
            )

        ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
        # self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)
        weight = 2 * prototypes  # note: without scaling (e.g. by the number of dimensions) these logits can blow up
        bias = -(prototypes**2).sum(dim=1)

        self.logger.debug("----------------- QUERY SET  ----------------- ")
        ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                query_representations = self.forward(text,
                                                     labels)["representation"]

                # distance query representations to prototypes (BATCH X N_PROTOTYPES)
                # distances = euclidean_dist(query_representations, prototypes)
                # logits = - distances
                logits = query_representations @ weight.T + bias
                loss = self.loss_fn(logits, labels)
                # log_probability = F.log_softmax(-distances, dim=1)
                # loss is negation of the log probability, index using the labels for each observation
                # loss = (- log_probability[torch.arange(len(log_probability)), labels]).mean()
                self.meta_optimizer.zero_grad()
                loss.backward()
                self.meta_optimizer.step()

                predictions = model_utils.make_prediction(logits.detach())
                # predictions = torch.tensor([inv_label_map[p.item()] for p in predictions])
                # to_print = pprint.pformat(list(map(lambda x: (x[0].item(), x[1].item(),
                #                         [round(z, 3) for z in x[2].tolist()]),
                #                         list(zip(labels, predictions, distances)))))
                self.logger.debug(
                    f"Unique Labels: {unique_labels.tolist()}\n"
                    # f"Labels, Indices, Predictions, Distances:\n{to_print}\n"
                    f"Loss:\n{loss.item()}\n"
                    f"Predictions:\n{predictions}\n")
                self.update_query_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)
                self.logger.debug(f"Examples seen increased by {len(text)}")

            # Meta optimizer step
            # self.meta_optimizer.step()
            # self.meta_optimizer.zero_grad()
        self.logger.debug(
            "-------------------- TRAINING STEP END  -------------------")
Example #16
    def train(self, dataloader=None, datasets=None, data_length=None):
        val_datasets = datasets_dict(datasets["val"], datasets["order"])

        if data_length is None:
            data_length = len(dataloader) * self.n_epochs

        all_losses, all_predictions, all_labels = [], [], []

        for text, labels, tasks in dataloader:
            self._examples_seen += len(text)
            self.model.train()
            # assumes all data in batch is from same task
            self.current_task = self.relearning_task if tasks[
                0] == self.relearning_task else OTHER_TASKS
            loss, predictions = self._train_batch(text, labels)
            all_losses.append(loss)
            all_predictions.extend(predictions)
            all_labels.extend(labels.tolist())

            if self.current_iter % self.log_freq == 0:
                acc, prec, rec, f1 = model_utils.calculate_metrics(
                    all_predictions, all_labels)
                time_per_iteration, estimated_time_left = self.time_metrics(
                    data_length)
                self.logger.info(
                    "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\nMetrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(
                        self.current_iter + 1, data_length,
                        (self.current_iter + 1) / data_length * 100,
                        time_per_iteration, estimated_time_left,
                        np.mean(all_losses), acc, prec, rec, f1))
                if self.config.wandb:
                    wandb.log({
                        "accuracy": acc,
                        "precision": prec,
                        "recall": rec,
                        "f1": f1,
                        "loss": np.mean(all_losses),
                        "examples_seen": self.examples_seen(),
                        "task": self.current_task
                    })
                all_losses, all_predictions, all_labels = [], [], []
                self.start_time = time.time()
            if self.current_iter % self.validate_freq == 0:
                # only evaluate relearning task when training on relearning task
                validation_datasets = self.relearning_task_dataset if self.current_task == self.relearning_task else val_datasets

                validate_results = self.validate(
                    validation_datasets,
                    n_samples=self.config.training.n_validation_samples,
                    log=False)
                self.write_results(validate_results)
                relearning_task_performance = validate_results[
                    self.relearning_task]["accuracy"]
                if not self.first_encounter:
                    # TODO: make this a weighted average as well
                    relearning_task_relative_performance = self.relative_performance(
                        performance=relearning_task_performance,
                        task=self.relearning_task)
                    self.logger.info((
                        f"Examples seen: {self.examples_seen()} -- Relative performance of task '{self.relearning_task}':"
                        +
                        f"{relearning_task_relative_performance}. Thresholds: {self.relative_performance_threshold_lower}"
                        f"-{self.relative_performance_threshold_upper}"))
                    if self.config.wandb:
                        wandb.log({
                            "relative_performance":
                            relearning_task_relative_performance,
                            "examples_seen": self.examples_seen()
                        })
                if self.current_task == self.relearning_task:
                    self.logger.debug(
                        f"first encounter: {self.first_encounter}")
                    # relearning stops when either of two things happens:
                    # the relearning task is first encountered and it is saturated (stops improving)
                    if ((self.first_encounter and self.learning_saturated(
                            task=self.relearning_task,
                            n_samples_slope=self.n_samples_slope,
                            patience=self.saturated_patience,
                            threshold=self.saturated_threshold,
                            smooth_alpha=self.smooth_alpha)) or
                        (
                            # the relearning task is re-encountered and relative performance reaches some threshold
                            not self.first_encounter
                            and relearning_task_relative_performance >=
                            self.relative_performance_threshold_upper)):
                        # write metrics, reset, and train the other tasks
                        self.write_relearning_metrics()
                        self.logger.info(
                            f"Task {self.current_task} saturated at iteration {self.current_iter}"
                        )
                        # each list element in performance refers to one consecutive learning event of the relearning task
                        self.metrics[
                            self.relearning_task]["performance"].append([])
                        self.not_improving = 0
                        if self.first_encounter:
                            self.logger.info(
                                f"-----------FIRST ENCOUNTER RELEARNING TASK '{self.relearning_task}' FINISHED.----------\n"
                            )
                        self.first_encounter = False
                        self.logger.info("TRAINING ON OTHER TASKS")
                        self.train(dataloader=self.dataloaders[OTHER_TASKS],
                                   datasets=datasets)
                else:
                    # calculate relative performance relearning task
                    # if it reaches some threshold, train on relearning task again
                    # TODO: make performance measure attribute of relearner
                    # TODO: use moving average for relative performance check
                    # TODO: allow different task ordering
                    # TODO: measure forgetting
                    # TODO: look at adaptive mini batch size => smaller batch size when re encountering
                    if relearning_task_relative_performance <= self.relative_performance_threshold_lower:
                        self.logger.info(
                            f"Relative performance on relearning task {self.relearning_task} below threshold. Evaluating relearning.."
                        )
                        # this needs to be done because we want a fresh list of performances when we start
                        # training the relearning task again. The first item in this list is simply the zero
                        # shot performance after the relative performance threshold is reached
                        # this means that every odd list in the relearning_task metrics is when training on the relearning task
                        relearning_task_performance = self.metrics[
                            self.relearning_task]["performance"]
                        relearning_task_performance.append([])
                        # copy the last entry of the performance while training on the other tasks to the new list
                        relearning_task_performance[-1].append(
                            relearning_task_performance[-2][-1])
                        self.train(
                            dataloader=self.dataloaders[self.relearning_task],
                            datasets=datasets)
                with open(self.results_dir / METRICS_FILE, "w") as f:
                    json.dump(self.metrics, f)

            self.time_checkpoint()
            self.current_iter += 1
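
learning_saturated and relative_performance are not shown in these examples. One plausible reading of the saturation check, given the smooth_alpha and threshold parameters passed above, is an exponentially smoothed accuracy curve whose slope (accuracy gain per example seen) has dropped below a threshold; the sketch below is purely illustrative and is not the repository's logic:

def learning_saturated(accuracies, examples_seen, threshold, smooth_alpha):
    # accuracies / examples_seen: parallel lists of validation accuracy and sample counts.
    smoothed = []
    for acc in accuracies:
        smoothed.append(acc if not smoothed else smooth_alpha * acc + (1 - smooth_alpha) * smoothed[-1])
    if len(smoothed) < 2:
        return False
    slope = (smoothed[-1] - smoothed[0]) / max(examples_seen[-1] - examples_seen[0], 1)
    return slope < threshold
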