Example #1
    def training(self, datasets, **kwargs):
        # train_datasets = {dataset_name: dataset for dataset_name, dataset in zip(datasets["order"], datasets["train"])}
        train_datasets = datasets_dict(datasets["train"], datasets["order"])
        val_datasets = datasets_dict(datasets["val"], datasets["order"])

        samples_per_task = self.config.learner.samples_per_task
        order = self.config.task_order if self.config.task_order is not None else datasets["order"]
        n_samples = [samples_per_task] * len(order) if samples_per_task is None else samples_per_task
        dataset = get_continuum(train_datasets, order=order, n_samples=n_samples)
        dataloader = DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)

        for text, labels, dataset_names in dataloader:
            output = self.training_step(text, labels)
            predictions = model_utils.make_prediction(output["logits"].detach())
            # for logging
            key_predictions = [
                model_utils.make_prediction(key_logits.detach()) for key_logits in output["key_logits"]
            ]
            # self.logger.debug(f"accuracy prediction from key embedding: {key_metrics['accuracy']}")

            self.update_tracker(output, predictions, key_predictions, labels)
            online_metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            self.metrics["online"].append({
                "accuracy": online_metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": dataset_names[0]  # assumes the whole batch comes from the same task
            })
            if self.current_iter % self.log_freq == 0:
                self.log()
                self.write_metrics()
            if self.current_iter % self.validate_freq == 0:
                self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
            self.current_iter += 1
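
The commented-out line at the top of training() documents what datasets_dict is expected to do: pair each dataset object with its name, following the task order. A minimal sketch of such a helper, assuming it does nothing beyond that dict construction:

def datasets_dict(datasets, order):
    # Pair each dataset with its name, following the task order
    # (mirrors the commented-out dict comprehension in training() above).
    return {name: dataset for name, dataset in zip(order, datasets)}
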
Example #2
    def evaluate(self, dataloader):
        self.set_eval()
        all_losses, all_predictions, all_labels = [], [], []

        for i, (text, labels, _) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            # if i % 20 == 0:
            #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        metrics = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       metrics["accuracy"],
                                       metrics["precision"], metrics["recall"],
                                       metrics["f1"]))

        return {
            "accuracy": metrics["accuracy"],
            "precision": metrics["precision"],
            "recall": metrics["recall"],
            "f1": metrics["f1"]
        }
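
model_utils.calculate_metrics is not shown in these examples; judging by its use here, it takes flat prediction and label lists and returns accuracy, precision, recall, and F1. A plausible sketch built on scikit-learn, with the averaging mode as an assumption:

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def calculate_metrics(predictions, labels, average="macro"):
    # Aggregate standard classification metrics from flat lists of class ids.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average=average, zero_division=0)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
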
Example #3
    def train(self, dataloaders):
        self.model.train()
        dataloader = dataloaders["train"]
        data_length = len(dataloader) * self.n_epochs

        for epoch in range(self.n_epochs):
            all_losses, all_predictions, all_labels = [], [], []

            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.model.encode_text(text)
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)

                self.update_parameters(loss, mini_batch_size=len(labels))

                loss = loss.item()
                pred = model_utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                self.memory.write_batch(text, labels)

                if self.current_iter % self.log_freq == 0:
                    self.write_log(all_predictions,
                                   all_labels,
                                   all_losses,
                                   data_length=data_length)
                    self.start_time = time.time()  # time from last log
                    all_losses, all_predictions, all_labels = [], [], []
                # if self.current_iter % self.config.training.save_freq == 0:
                self.time_checkpoint()
                self.current_iter += 1
            self.current_epoch += 1
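
self.memory.write_batch stores each incoming batch for later replay, and read_batch (used for evaluation support sets in later examples) samples stored examples back out. The memory module itself is not included here; a minimal capacity-bounded sketch with uniform sampling, as an illustration only:

import random

class ReplayMemory:
    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.buffer = []  # list of (text, label) pairs

    def write_batch(self, text, labels):
        # Append a batch (labels may be a list or a 1-D tensor);
        # keep only the most recent `capacity` examples.
        self.buffer.extend(zip(text, labels))
        if len(self.buffer) > self.capacity:
            self.buffer = self.buffer[-self.capacity:]

    def read_batch(self, batch_size):
        # Uniformly sample stored examples for replay.
        sample = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        if not sample:
            return [], []
        text, labels = zip(*sample)
        return list(text), list(labels)
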
Example #4
    def evaluate(self, dataloader, update_memory=False):
        self.set_eval()
        all_losses, all_predictions, all_labels = [], [], []

        self.logger.info("Starting evaluation...")
        for i, (text, labels, datasets) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            with torch.no_grad():
                logits, key_logits = self.forward(text,
                                                  labels,
                                                  update_memory=update_memory)
                loss = self.loss_fn(logits, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(logits.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            if i % 20 == 0:
                self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.info(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       results["accuracy"],
                                       results["precision"], results["recall"],
                                       results["f1"]))

        return results
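
model_utils.make_prediction is called on detached logits throughout these examples, and its output is passed to calculate_metrics as a list of class ids, so it presumably reduces to an argmax over the class dimension. A one-line sketch under that assumption:

def make_prediction(logits):
    # Map raw logits (batch x n_classes) to predicted class ids.
    return logits.argmax(dim=-1)
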
Example #5
    def evaluate(self, dataloader, prediction_network=None):
        self.set_eval()
        prototypes = self.memory.class_representations
        weight = 2 * prototypes
        bias = - (prototypes ** 2).sum(dim=1)
        all_losses, all_predictions, all_labels = [], [], []
        for i, (text, labels, _) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            with torch.no_grad():
                representations = self.forward(text, labels)["representation"]
                logits = representations @ weight.T + bias
                loss = self.loss_fn(logits, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(logits.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug("Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"],
                    results["precision"], results["recall"], results["f1"]))
        return results
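
The weight = 2 * prototypes and bias = -(prototypes ** 2).sum(dim=1) construction makes the matrix product equivalent to nearest-prototype classification: -||x - p||^2 = 2*x.p - ||p||^2 - ||x||^2, and the ||x||^2 term is the same for every class, so it does not affect the argmax or the softmax ordering. A small self-contained check of that identity (shapes and seed are arbitrary):

import torch

torch.manual_seed(0)
x = torch.randn(4, 16)            # batch of representations
prototypes = torch.randn(7, 16)   # one prototype per class

# Linear form used in evaluate() above.
weight = 2 * prototypes
bias = -(prototypes ** 2).sum(dim=1)
logits = x @ weight.T + bias

# Negative squared Euclidean distance to each prototype.
neg_sq_dist = -((x.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=-1)

# The two differ only by the per-example constant ||x||^2,
# so the predicted class is identical.
assert torch.allclose(logits - neg_sq_dist, (x ** 2).sum(dim=1, keepdim=True), atol=1e-4)
assert torch.equal(logits.argmax(dim=1), neg_sq_dist.argmax(dim=1))
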
Example #6
    def few_shot_testing(self,
                         train_dataset,
                         eval_dataset,
                         increment_counters=False,
                         split="test"):
        """
        Allow the model to train on a small number of datapoints at a time. After every training step,
        evaluate on many samples that haven't been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        split: str, one of {"val", "test"}.
            Which data split is used. For logging purposes.
        """
        self.logger.info(
            f"few shot testing on dataset {self.config.testing.eval_dataset} "
            f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(
            train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []
        with higher.innerloop_ctx(self.pln,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpln, diffopt):
            self.pln.train()
            self.rln.eval()
            # Inner loop
            for i, (text, labels, datasets) in enumerate(train_dataloader):
                labels = torch.tensor(labels).to(self.device)
                output = self.forward(text, labels, fpln)
                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                all_predictions.extend(predictions.tolist())
                all_labels.extend(labels.tolist())
                dataset_results = self.evaluate(dataloader=eval_dataloader,
                                                prediction_network=fpln)
                self.log_few_shot(all_predictions,
                                  all_labels,
                                  datasets,
                                  dataset_results,
                                  increment_counters,
                                  text,
                                  i,
                                  split=split)
                if (i * self.config.testing.few_shot_batch_size
                    ) % self.mini_batch_size == 0 and i > 0:
                    all_predictions, all_labels = [], []
        self.few_shot_end()
Example #7
    def _train_batch(self, text, labels):
        labels = torch.tensor(labels).to(self.device)
        input_dict = self.model.encode_text(text)
        output = self.model(input_dict)
        loss = self.loss_fn(output, labels)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        loss = loss.item()
        self.logger.debug(f"Loss: {loss}")
        pred = model_utils.make_prediction(output.detach())
        return loss, pred.tolist()
Example #8
    def train(self,
              dataloader=None,
              datasets=None,
              dataset_name=None,
              max_samples=None):
        val_datasets = datasets_dict(datasets["val"], datasets["order"])
        replay_freq, replay_steps = self.replay_parameters(metalearner=False)

        episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well
        for _ in range(self.n_epochs):
            for text, labels, dataset_names in dataloader:
                output = self.training_step(text, labels)
                task = dataset_names[0]  # assumes the whole batch comes from the same task

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_tracker(output, predictions, labels)

                metrics = model_utils.calculate_metrics(
                    self.tracker["predictions"], self.tracker["labels"])
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task
                }
                self.metrics["online"].append(online_metrics)
                if dataset_name is not None and dataset_name == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                if self.current_iter % self.log_freq == 0:
                    self.log()
                    self.write_metrics()
                if self.current_iter % self.validate_freq == 0:
                    self.validate(
                        val_datasets,
                        n_samples=self.config.training.n_validation_samples)
                if self.replay_rate != 0 and (self.current_iter +
                                              1) % replay_freq == 0:
                    self.replay_training_step(replay_steps,
                                              episode_samples_seen,
                                              max_samples)
                self.memory.write_batch(text, labels)
                self._examples_seen += len(text)
                episode_samples_seen += len(text)
                self.current_iter += 1
                if max_samples is not None and episode_samples_seen >= max_samples:
                    break
Example #9
    def few_shot_testing(self, train_dataset, eval_dataset, increment_counters=False, split="test"):
        """
        Allow the model to train on a small number of datapoints at a time. After every training step,
        evaluate on many samples that haven't been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        split: str, one of {"val", "test"}.
            Which data split is used. For logging purposes.
        """
        self.logger.info(f"few shot testing on dataset {self.config.testing.eval_dataset} "
                         f"with {len(train_dataset)} samples")
        # whenever we do few-shot evaluation, the learner is reset to its state from before the evaluation started
        train_dataloader, eval_dataloader = self.few_shot_preparation(train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []

        for i, (text, labels, datasets) in enumerate(train_dataloader):
            output = self.training_step(text, labels)
            predictions = model_utils.make_prediction(output["logits"].detach())
            all_predictions.extend(predictions.tolist())
            all_labels.extend(labels.tolist())

            dataset_results = self.evaluate(dataloader=eval_dataloader)
            self.log_few_shot(all_predictions, all_labels, datasets, dataset_results, increment_counters,
                              text, i, split=split)
            if (i * self.config.testing.few_shot_batch_size) % self.mini_batch_size == 0 and i > 0:
                all_predictions, all_labels = [], []
        self.few_shot_end()
Example #10
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()
        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            # Inner loop
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text, labels, fpn)
                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)
                self.memory.write_batch(text, labels)

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_support_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)

            # Outer loop
            if query_set is not None:
                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    self.update_meta_gradients(loss, fpn)

                    predictions = model_utils.make_prediction(
                        output["logits"].detach())
                    self.update_query_tracker(loss, predictions, labels)
                    metrics = model_utils.calculate_metrics(
                        predictions.tolist(), labels.tolist())
                    online_metrics = {
                        "accuracy": metrics["accuracy"],
                        "examples_seen": self.examples_seen(),
                        "task": task if task is not None else "none"
                    }
                    self.metrics["online"].append(online_metrics)
                    if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                        self.metrics["eval_task_first_encounter"].append(
                            online_metrics)
                    self._examples_seen += len(text)

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()
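
Both loops above follow higher's differentiable-optimizer pattern: diffopt.step(loss) adapts a functional copy fpn of self.pn on the support set, and the query-set loss then drives the meta update (here via the custom update_meta_gradients helper, with track_higher_grads=False). A minimal standalone sketch of the same pattern using track_higher_grads=True, so the query loss backpropagates through the inner steps automatically; the toy model, data, and hyperparameters are illustrative, not the project's:

import torch
import higher

model = torch.nn.Linear(8, 3)
inner_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

support_x, support_y = torch.randn(16, 8), torch.randint(0, 3, (16,))
query_x, query_y = torch.randn(16, 8), torch.randint(0, 3, (16,))

meta_optimizer.zero_grad()
with higher.innerloop_ctx(model, inner_optimizer,
                          copy_initial_weights=False,
                          track_higher_grads=True) as (fmodel, diffopt):
    # Inner loop: adapt the functional copy on the support set.
    for _ in range(3):
        diffopt.step(loss_fn(fmodel(support_x), support_y))

    # Outer loop: the query loss backpropagates through the inner updates
    # into the original parameters held by meta_optimizer.
    loss_fn(fmodel(query_x), query_y).backward()
meta_optimizer.step()
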
Example #11
    def evaluate(self, dataloader, prediction_network=None):
        if self.config.learner.evaluation_support_set:
            support_set = []
            for _ in range(self.config.learner.updates):
                text, labels = self.memory.read_batch(
                    batch_size=self.mini_batch_size)
                support_set.append((text, labels))

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            if self.config.learner.evaluation_support_set:
                self.set_train()
                support_prediction_network = fpn
                # Inner loop
                task_predictions, task_labels = [], []
                support_loss = []
                for text, labels in support_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    diffopt.step(loss)

                    pred = model_utils.make_prediction(
                        output["logits"].detach())
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(labels.tolist())
                results = model_utils.calculate_metrics(
                    task_predictions, task_labels)
                self.logger.info(
                    "Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(support_loss),
                                               results["accuracy"],
                                               results["precision"],
                                               results["recall"],
                                               results["f1"]))
                self.set_eval()
            else:
                support_prediction_network = self.pn
            if prediction_network is None:
                prediction_network = support_prediction_network

            self.set_eval()
            all_losses, all_predictions, all_labels = [], [], []
            for i, (text, labels, datasets) in enumerate(dataloader):
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text,
                                      labels,
                                      prediction_network,
                                      no_grad=True)
                loss = self.loss_fn(output["logits"], labels)
                loss = loss.item()
                pred = model_utils.make_prediction(output["logits"].detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                # if i % 20 == 0:
                #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       results["accuracy"],
                                       results["precision"], results["recall"],
                                       results["f1"]))
        return results
Example #12
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug(
            "-------------------- TRAINING STEP  -------------------")

        ### GET SUPPORT SET REPRESENTATIONS ###
        with torch.no_grad():
            representations, all_labels = self.get_representations(
                support_set[:1])
            representations_merged = torch.cat(representations)
            class_means, unique_labels = self.get_class_means(
                representations_merged, all_labels)

        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        if do_memory_update:
            ### UPDATE MEMORY ###
            updated_memory_representations = self.memory.update(
                class_means, unique_labels, logger=self.logger)

        ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
        if self.config.learner.prototypes == "class_means":
            prototypes = class_means.detach()
        elif self.config.learner.prototypes == "memory":
            prototypes = self.memory.class_representations  # doesn't track prototype gradients
        else:
            raise AssertionError(
                "Prototype type not in {'class_means', 'memory'}, fix config file."
            )

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
            self.init_prototypical_classifier(prototypes,
                                              linear_module=fpn.linear)

            self.logger.debug(
                "----------------- SUPPORT SET ----------------- ")
            ### TRAIN LINEAR CLASSIFIER ON SUPPORT SET ###
            # Inner loop
            for i, (text, labels) in enumerate(support_set):
                self.logger.debug(
                    f"----------------- {i}th Update ----------------- ")
                labels = torch.tensor(labels).to(self.device)
                # if i == 0:
                #     output = {
                #         "representation": representations[0],
                #         "logits": fpn(representations[0], out_from="linear")
                #     }
                # else:
                output = self.forward(text, labels, fpn)

                # for logging purposes
                prototype_distances = (output["representation"].unsqueeze(1) -
                                       prototypes).norm(dim=-1)
                closest_dists, closest_classes = prototype_distances.topk(
                    3, largest=False)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 2) for z in x[2].tolist()]),
                            list(zip(labels, closest_classes,
                                     closest_dists)))))
                self.logger.debug(
                    f"True labels, closest prototypes, and distances:\n{to_print}"
                )

                topk = output["logits"].topk(5, dim=1)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 3) for z in x[2].tolist()]),
                            list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(
                    f"(label, topk_classes, topk_logits) before update:\n{to_print}"
                )

                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)

                # see how much linear classifier has changed
                with torch.no_grad():
                    topk = fpn(output["representation"],
                               out_from="linear").topk(5, dim=1)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 3) for z in x[2].tolist()]),
                            list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(
                    f"(label, topk_classes, topk_logits) after update:\n{to_print}"
                )

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_support_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)

            self.logger.debug(
                "----------------- QUERY SET  ----------------- ")
            ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
            # Outer loop
            if query_set is not None:
                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, prediction_network=fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    self.update_meta_gradients(loss, fpn)

                    topk = output['logits'].topk(5, dim=1)
                    to_print = pprint.pformat(
                        list(
                            map(
                                lambda x: (x[0].item(), x[1].tolist(
                                ), [round(z, 3) for z in x[2].tolist()]),
                                list(zip(labels, topk[1], topk[0])))))
                    self.logger.debug(
                        f"(label, topk_classes, topk_logits):\n{to_print}")

                    predictions = model_utils.make_prediction(
                        output["logits"].detach())
                    self.update_query_tracker(loss, predictions, labels)
                    metrics = model_utils.calculate_metrics(
                        predictions.tolist(), labels.tolist())
                    online_metrics = {
                        "accuracy": metrics["accuracy"],
                        "examples_seen": self.examples_seen(),
                        "task": task if task is not None else "none"
                    }
                    self.metrics["online"].append(online_metrics)
                    if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                        self.metrics["eval_task_first_encounter"].append(
                            online_metrics)
                    self._examples_seen += len(text)

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()
        self.logger.debug(
            "-------------------- TRAINING STEP END  -------------------")
Example #13
    def few_shot_testing(self,
                         train_dataset,
                         eval_dataset,
                         increment_counters=False,
                         split="test"):
        """
        Allow the model to train on a small number of datapoints at a time. After every training step,
        evaluate on many samples that haven't been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        split: str, one of {"val", "test"}.
            Which data split is used. For logging purposes.
        """
        self.logger.info(
            f"few shot testing on dataset {self.config.testing.eval_dataset} "
            f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(
            train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []

        def add_none(iterator):
            yield None
            for x in iterator:
                yield x

        shifted_dataloader = add_none(train_dataloader)
        # prototypes = self.memory.class_representations
        for i, (support_set, (query_text, query_labels,
                              datasets)) in enumerate(
                                  zip(shifted_dataloader, train_dataloader)):
            query_labels = torch.tensor(query_labels).to(self.device)
            # support_set is None only on the first iteration (yielded by add_none)
            if support_set is None:
                prototypes = self.memory.class_representations
            else:
                support_text, support_labels, _ = support_set
                support_labels = torch.tensor(support_labels).to(self.device)
                support_representations = self.forward(
                    support_text, support_labels)["representation"]
                support_class_means, unique_labels = model_utils.get_class_means(
                    support_representations, support_labels)
                memory_update = self.memory.update(support_class_means,
                                                   unique_labels,
                                                   logger=self.logger)
                updated_memory_representations = memory_update[
                    "new_class_representations"]
                self.log_discounts(memory_update["class_discount"],
                                   unique_labels,
                                   few_shot_examples_seen=(i + 1) *
                                   self.config.testing.few_shot_batch_size)
                prototypes = updated_memory_representations
                if self.config.learner.few_shot_detach_prototypes:
                    prototypes = prototypes.detach()
            weight = 2 * prototypes
            bias = -(prototypes**2).sum(dim=1)
            query_representations = self.forward(
                query_text, query_labels)["representation"]
            logits = query_representations @ weight.T + bias

            loss = self.loss_fn(logits, query_labels)

            self.meta_optimizer.zero_grad()
            loss.backward()
            self.meta_optimizer.step()

            predictions = model_utils.make_prediction(logits.detach())
            all_predictions.extend(predictions.tolist())
            all_labels.extend(query_labels.tolist())
            dataset_results = self.evaluate(dataloader=eval_dataloader)
            self.log_few_shot(all_predictions,
                              all_labels,
                              datasets,
                              dataset_results,
                              increment_counters,
                              query_text,
                              i,
                              split=split)
            if (i * self.config.testing.few_shot_batch_size
                ) % self.mini_batch_size == 0 and i > 0:
                all_predictions, all_labels = [], []
        self.few_shot_end()
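
model_utils.get_class_means, used here and in the surrounding examples, reduces a batch of representations to one mean vector per class present in the batch and returns those means together with the corresponding unique labels; the means then drive the memory/prototype update. The helper itself is not included, so the following is a plausible sketch under that assumption:

import torch

def get_class_means(representations, labels):
    # representations: (N, D) float tensor; labels: (N,) tensor of class ids.
    unique_labels = labels.unique()
    class_means = torch.stack([
        representations[labels == label].mean(dim=0)
        for label in unique_labels
    ])
    return class_means, unique_labels
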
Example #14
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug(
            "-------------------- TRAINING STEP  -------------------")
        # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
        #                           copy_initial_weights=False,
        #                           track_higher_grads=False) as (fpn, diffopt):
        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        ### GET SUPPORT SET REPRESENTATIONS ###
        self.logger.debug("----------------- SUPPORT SET ----------------- ")
        representations, all_labels = self.get_representations(support_set[:1])
        representations_merged = torch.cat(representations)
        class_means, unique_labels = model_utils.get_class_means(
            representations_merged, all_labels)
        self._examples_seen += len(representations_merged)
        self.logger.debug(
            f"Examples seen increased by {len(representations_merged)}")

        ### UPDATE MEMORY ###
        if do_memory_update:
            memory_update = self.memory.update(class_means,
                                               unique_labels,
                                               logger=self.logger)
            updated_memory_representations = memory_update[
                "new_class_representations"]
            self.log_discounts(memory_update["class_discount"], unique_labels)
        ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
        if self.config.learner.prototypes == "class_means":
            prototypes = expand_class_representations(
                self.memory.class_representations, class_means, unique_labels)
        elif self.config.learner.prototypes == "memory":
            # assumes do_memory_update was True above; otherwise updated_memory_representations is undefined
            prototypes = updated_memory_representations
        else:
            raise AssertionError(
                "Prototype type not in {'class_means', 'memory'}, fix config file."
            )

        ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
        # self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)
        weight = 2 * prototypes  # divide by number of dimensions, otherwise blows up
        bias = -(prototypes**2).sum(dim=1)

        self.logger.debug("----------------- QUERY SET  ----------------- ")
        ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                query_representations = self.forward(text,
                                                     labels)["representation"]

                # distance query representations to prototypes (BATCH X N_PROTOTYPES)
                # distances = euclidean_dist(query_representations, prototypes)
                # logits = - distances
                logits = query_representations @ weight.T + bias
                loss = self.loss_fn(logits, labels)
                # log_probability = F.log_softmax(-distances, dim=1)
                # loss is negation of the log probability, index using the labels for each observation
                # loss = (- log_probability[torch.arange(len(log_probability)), labels]).mean()
                self.meta_optimizer.zero_grad()
                loss.backward()
                self.meta_optimizer.step()

                predictions = model_utils.make_prediction(logits.detach())
                # predictions = torch.tensor([inv_label_map[p.item()] for p in predictions])
                # to_print = pprint.pformat(list(map(lambda x: (x[0].item(), x[1].item(),
                #                         [round(z, 3) for z in x[2].tolist()]),
                #                         list(zip(labels, predictions, distances)))))
                self.logger.debug(
                    f"Unique Labels: {unique_labels.tolist()}\n"
                    # f"Labels, Indices, Predictions, Distances:\n{to_print}\n"
                    f"Loss:\n{loss.item()}\n"
                    f"Predictions:\n{predictions}\n")
                self.update_query_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)
                self.logger.debug(f"Examples seen increased by {len(text)}")

            # Meta optimizer step
            # self.meta_optimizer.step()
            # self.meta_optimizer.zero_grad()
        self.logger.debug(
            "-------------------- TRAINING STEP END  -------------------")