Example 1
class Baseline(Learner):
    def __init__(self, config, **kwargs):
        """
        Baseline models: sequential and multitask setup.
        """
        super().__init__(config, **kwargs)
        self.lr = config.learner.lr
        self.type = config.learner.type
        self.n_epochs = config.training.epochs
        self.log_freq = config.training.log_freq
        self.model = TransformerClsModel(model_name=config.learner.model_name,
                                         n_classes=config.data.n_classes,
                                         max_length=config.data.max_length,
                                         device=self.device)
        self.logger.info("Loaded {} as model".format(
            self.model.__class__.__name__))
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)

    def training(self, datasets, **kwargs):
        datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset = self.prepare_data(
            datasets)
        if self.config.learner.multitask:
            data = ConcatDataset([
                ClassificationDataset(d.dataset.name,
                                      d.dataset.data.iloc[d.indices])
                for d in datas
            ])
            train_dataloader = DataLoader(data,
                                          batch_size=self.mini_batch_size,
                                          shuffle=True)
            self.train(dataloader=train_dataloader, datasets=datasets)
            return
        for i, (data, dataset_name,
                n_sample) in enumerate(zip(datas, order, n_samples)):
            self.logger.info(
                f"Observing dataset {dataset_name} for {n_sample} samples. "
                f"Evaluation={dataset_name=='evaluation'}")
            if dataset_name == "evaluation":
                self.few_shot_testing(train_dataset=eval_train_dataset,
                                      eval_dataset=eval_eval_dataset,
                                      split="test",
                                      increment_counters=False)
            else:
                train_dataloader = DataLoader(data,
                                              batch_size=self.mini_batch_size,
                                              shuffle=False)
                self.train(dataloader=train_dataloader,
                           datasets=datasets,
                           dataset_name=dataset_name,
                           max_samples=n_sample)
            if i == 0:
                self.metrics["eval_task_first_encounter_evaluation"] = \
                    self.evaluate(DataLoader(eval_dataset, batch_size=self.mini_batch_size))["accuracy"]
            if dataset_name == self.config.testing.eval_dataset:
                self.eval_task_first_encounter = False

    def train(self,
              dataloader=None,
              datasets=None,
              dataset_name=None,
              max_samples=None):
        val_datasets = datasets_dict(datasets["val"], datasets["order"])
        replay_freq, replay_steps = self.replay_parameters(metalearner=False)

        episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well
        for _ in range(self.n_epochs):
            for text, labels, batch_tasks in dataloader:
                output = self.training_step(text, labels)
                task = batch_tasks[0]

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_tracker(output, predictions, labels)

                metrics = model_utils.calculate_metrics(
                    self.tracker["predictions"], self.tracker["labels"])
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task
                }
                self.metrics["online"].append(online_metrics)
                if dataset_name is not None and dataset_name == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                if self.current_iter % self.log_freq == 0:
                    self.log()
                    self.write_metrics()
                if self.current_iter % self.validate_freq == 0:
                    self.validate(
                        val_datasets,
                        n_samples=self.config.training.n_validation_samples)
                if self.replay_rate != 0 and (self.current_iter +
                                              1) % replay_freq == 0:
                    self.replay_training_step(replay_steps,
                                              episode_samples_seen,
                                              max_samples)
                self.memory.write_batch(text, labels)
                self._examples_seen += len(text)
                episode_samples_seen += len(text)
                self.current_iter += 1
                if max_samples is not None and episode_samples_seen >= max_samples:
                    break

    def training_step(self, text, labels):
        self.set_train()
        labels = torch.tensor(labels).to(self.device)

        input_dict = self.model.encode_text(text)
        logits = self.model(input_dict)
        loss = self.loss_fn(logits, labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        loss = loss.item()

        return {"logits": logits, "loss": loss}

    def replay_training_step(self, replay_steps, episode_samples_seen,
                             max_samples):
        self.optimizer.zero_grad()
        for _ in range(replay_steps):
            text, labels = self.memory.read_batch(
                batch_size=self.mini_batch_size)
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            output = self.model(input_dict)
            loss = self.loss_fn(output, labels)
            loss.backward()
            self._examples_seen += len(text)
            self.metrics["replay_samples_seen"] += len(text)
            episode_samples_seen += len(text)
            if max_samples is not None and episode_samples_seen >= max_samples:
                break

        params = [p for p in self.model.parameters() if p.requires_grad]
        torch.nn.utils.clip_grad_norm_(params,
                                       self.config.learner.clip_grad_norm)
        self.optimizer.step()

    def log(self):
        metrics = model_utils.calculate_metrics(self.tracker["predictions"],
                                                self.tracker["labels"])
        self.logger.info(
            "Iteration {} - Metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(self.current_iter + 1,
                                       np.mean(self.tracker["losses"]),
                                       metrics["accuracy"],
                                       metrics["precision"], metrics["recall"],
                                       metrics["f1"]))
        if self.config.wandb:
            wandb.log({
                "accuracy": metrics["accuracy"],
                "precision": metrics["precision"],
                "recall": metrics["recall"],
                "f1": metrics["f1"],
                "loss": np.mean(self.tracker["losses"]),
                "examples_seen": self.examples_seen()
            })
        self.reset_tracker()

    def evaluate(self, dataloader):
        self.set_eval()
        all_losses, all_predictions, all_labels = [], [], []

        for i, (text, labels, _) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            # if i % 20 == 0:
            #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        metrics = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       metrics["accuracy"],
                                       metrics["precision"], metrics["recall"],
                                       metrics["f1"]))

        return {
            "accuracy": metrics["accuracy"],
            "precision": metrics["precision"],
            "recall": metrics["recall"],
            "f1": metrics["f1"]
        }
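
Example 1 interleaves ordinary training steps with experience replay: every batch is written to a ReplayMemory, and every replay_freq iterations the gradients of replay_steps replayed batches are accumulated before a single optimizer step. The sketch below is a minimal, self-contained illustration of that pattern; the linear model, the ToyReplayBuffer class, and the replay_freq/replay_steps values are illustrative stand-ins, not the project's TransformerClsModel, ReplayMemory, or replay_parameters().

import random

import torch
import torch.nn as nn


class ToyReplayBuffer:
    """Stand-in for ReplayMemory: stores (example, label) pairs with a write probability."""

    def __init__(self, write_prob=1.0):
        self.write_prob = write_prob
        self.buffer = []

    def write_batch(self, xs, ys):
        for x, y in zip(xs, ys):
            if random.random() < self.write_prob:
                self.buffer.append((x.detach(), int(y)))

    def read_batch(self, batch_size):
        sample = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        xs, ys = zip(*sample)
        return torch.stack(xs), torch.tensor(ys)


model = nn.Linear(16, 4)                 # stand-in for TransformerClsModel
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
memory = ToyReplayBuffer(write_prob=1.0)
replay_freq, replay_steps = 5, 2         # stand-ins for replay_parameters()

for step in range(20):
    x = torch.randn(8, 16)               # current-task mini-batch
    y = torch.randint(0, 4, (8,))
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    memory.write_batch(x, y)             # write after the gradient step, as in train()

    if (step + 1) % replay_freq == 0:
        # accumulate gradients over several replayed batches, then take a single step
        optimizer.zero_grad()
        for _ in range(replay_steps):
            rx, ry = memory.read_batch(batch_size=8)
            loss_fn(model(rx), ry).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()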
Example 2
class ANML(Learner):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.inner_lr = config.learner.inner_lr
        self.meta_lr = config.learner.meta_lr
        self.write_prob = config.learner.write_prob
        self.mini_batch_size = config.training.batch_size

        self.nm = TransformerNeuromodulator(
            model_name=config.learner.model_name, device=self.device)
        self.pn = TransformerClsModel(model_name=config.learner.model_name,
                                      n_classes=config.data.n_classes,
                                      max_length=config.data.max_length,
                                      device=self.device)
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        self.loss_fn = nn.CrossEntropyLoss()

        self.logger.info("Loaded {} as NM".format(self.nm.__class__.__name__))
        self.logger.info("Loaded {} as PN".format(self.pn.__class__.__name__))

        meta_params = [p for p in self.nm.parameters() if p.requires_grad] + \
                      [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)

    def training(self, datasets, **kwargs):
        replay_freq, replay_steps = self.replay_parameters()
        self.logger.info("Replay frequency: {}".format(replay_freq))
        self.logger.info("Replay steps: {}".format(replay_steps))

        datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset = self.prepare_data(
            datasets)
        for i, (data, dataset_name,
                n_sample) in enumerate(zip(datas, order, n_samples)):
            self.logger.info(
                f"Observing dataset {dataset_name} for {n_sample} samples. "
                f"Evaluation={dataset_name=='evaluation'}")
            if dataset_name == "evaluation":
                self.few_shot_testing(train_dataset=eval_train_dataset,
                                      eval_dataset=eval_eval_dataset,
                                      split="test",
                                      increment_counters=False)
            else:
                train_dataloader = iter(
                    DataLoader(data,
                               batch_size=self.mini_batch_size,
                               shuffle=False))
                self.episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well
                # iterate over episodes
                while True:
                    self.set_train()
                    support_set, task = self.get_support_set(
                        train_dataloader, n_sample)
                    # TODO: return flag that indicates whether the query set is from the memory. Don't include this in the online accuracy calc
                    query_set = self.get_query_set(train_dataloader,
                                                   replay_freq, replay_steps,
                                                   n_sample)
                    if support_set is None or query_set is None:
                        break

                    self.training_step(support_set, query_set, task=task)

                    self.meta_training_log()
                    self.write_metrics()
                    self.current_iter += 1
                    if self.episode_samples_seen >= n_sample:
                        break
            if i == 0:
                self.metrics["eval_task_first_encounter_evaluation"] = \
                    self.evaluate(DataLoader(eval_dataset, batch_size=self.mini_batch_size))["accuracy"]
            if dataset_name == self.config.testing.eval_dataset:
                self.eval_task_first_encounter = False

    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()
        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            # Inner loop
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text, labels, fpn)
                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)
                self.memory.write_batch(text, labels)

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_support_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)

            # Outer loop
            if query_set is not None:
                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    self.update_meta_gradients(loss, fpn)

                    predictions = model_utils.make_prediction(
                        output["logits"].detach())
                    self.update_query_tracker(loss, predictions, labels)
                    metrics = model_utils.calculate_metrics(
                        predictions.tolist(), labels.tolist())
                    online_metrics = {
                        "accuracy": metrics["accuracy"],
                        "examples_seen": self.examples_seen(),
                        "task": task if task is not None else "none"
                    }
                    self.metrics["online"].append(online_metrics)
                    if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                        self.metrics["eval_task_first_encounter"].append(
                            online_metrics)
                    self._examples_seen += len(text)

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()

    def forward(self, text, labels, prediction_network=None, no_grad=False):
        if prediction_network is None:
            prediction_network = self.pn
        if no_grad:
            with torch.no_grad():
                input_dict = self.pn.encode_text(text)
                representation = prediction_network(input_dict,
                                                    out_from="transformers")
                modulation = self.nm(input_dict)
                logits = prediction_network(representation * modulation,
                                            out_from="linear")
        else:
            input_dict = self.pn.encode_text(text)
            representation = prediction_network(input_dict,
                                                out_from="transformers")
            modulation = self.nm(input_dict)
            modulated_representation = representation * modulation
            logits = prediction_network(modulated_representation,
                                        out_from="linear")

        return {"logits": logits}

    def update_meta_gradients(self, loss, fpn):
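        """Accumulate meta-gradients of the query loss into the slow weights.

        Gradients w.r.t. the neuromodulator and the fast (functional) prediction-network
        parameters are added onto the .grad fields of self.nm and self.pn, so a subsequent
        meta_optimizer.step() applies them. With track_higher_grads=False this is an
        effectively first-order meta update.
        """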
        # NM meta gradients
        nm_params = [p for p in self.nm.parameters() if p.requires_grad]
        meta_nm_grads = torch.autograd.grad(loss,
                                            nm_params,
                                            retain_graph=True,
                                            allow_unused=True)
        for param, meta_grad in zip(nm_params, meta_nm_grads):
            if meta_grad is not None:
                if param.grad is not None:
                    param.grad += meta_grad.detach()
                else:
                    param.grad = meta_grad.detach()

        # PN meta gradients
        pn_params = [p for p in fpn.parameters() if p.requires_grad]
        meta_pn_grads = torch.autograd.grad(loss, pn_params, allow_unused=True)
        pn_params = [p for p in self.pn.parameters() if p.requires_grad]
        for param, meta_grad in zip(pn_params, meta_pn_grads):
            if meta_grad is not None:
                if param.grad is not None:
                    param.grad += meta_grad.detach()
                else:
                    param.grad = meta_grad.detach()

    def update_support_tracker(self, loss, pred, labels):
        self.tracker["support_loss"].append(loss.item())
        self.tracker["support_predictions"].extend(pred.tolist())
        self.tracker["support_labels"].extend(labels.tolist())

    def update_query_tracker(self, loss, pred, labels):
        self.tracker["query_loss"].append(loss.item())
        self.tracker["query_predictions"].extend(pred.tolist())
        self.tracker["query_labels"].extend(labels.tolist())

    def reset_tracker(self):
        self.tracker = {
            "support_loss": [],
            "support_predictions": [],
            "support_labels": [],
            "query_loss": [],
            "query_predictions": [],
            "query_labels": []
        }

    def evaluate(self, dataloader, prediction_network=None):
        if self.config.learner.evaluation_support_set:
            support_set = []
            for _ in range(self.config.learner.updates):
                text, labels = self.memory.read_batch(
                    batch_size=self.mini_batch_size)
                support_set.append((text, labels))

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            if self.config.learner.evaluation_support_set:
                self.set_train()
                support_prediction_network = fpn
                # Inner loop
                task_predictions, task_labels = [], []
                support_loss = []
                for text, labels in support_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    diffopt.step(loss)

                    pred = model_utils.make_prediction(
                        output["logits"].detach())
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(labels.tolist())
                results = model_utils.calculate_metrics(
                    task_predictions, task_labels)
                self.logger.info(
                    "Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(support_loss),
                                               results["accuracy"],
                                               results["precision"],
                                               results["recall"],
                                               results["f1"]))
                self.set_eval()
            else:
                support_prediction_network = self.pn
            if prediction_network is None:
                prediction_network = support_prediction_network

            self.set_eval()
            all_losses, all_predictions, all_labels = [], [], []
            for i, (text, labels, _) in enumerate(dataloader):
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text,
                                      labels,
                                      prediction_network,
                                      no_grad=True)
                loss = self.loss_fn(output["logits"], labels)
                loss = loss.item()
                pred = model_utils.make_prediction(output["logits"].detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                # if i % 20 == 0:
                #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       results["accuracy"],
                                       results["precision"], results["recall"],
                                       results["f1"]))
        return results

    def model_state(self):
        return {"nm": self.nm.state_dict(), "pn": self.pn.state_dict()}

    def optimizer_state(self):
        return self.meta_optimizer.state_dict()

    def load_model_state(self, checkpoint):
        self.nm.load_state_dict(checkpoint["model_state"]["nm"])
        self.pn.load_state_dict(checkpoint["model_state"]["pn"])

    def load_optimizer_state(self, checkpoint):
        self.meta_optimizer.load_state_dict(checkpoint["optimizer"])

    def save_other_state_information(self, state):
        """Any learner specific state information is added here"""
        state["memory"] = self.memory
        return state

    def load_other_state_information(self, checkpoint):
        self.memory = checkpoint["memory"]

    def set_eval(self):
        self.nm.eval()
        self.pn.eval()

    def set_train(self):
        self.nm.train()
        self.pn.train()

    def few_shot_testing(self,
                         train_dataset,
                         eval_dataset,
                         increment_counters=False,
                         split="test"):
        """
        Allow the model to train on a small number of datapoints at a time. After every training step,
        evaluate on many samples that have not been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        """
        self.logger.info(
            f"few shot testing on dataset {self.config.testing.eval_dataset} "
            f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(
            train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []
        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            self.pn.train()
            # Inner loop
            for i, (text, labels, datasets) in enumerate(train_dataloader):
                labels = torch.tensor(labels).to(self.device)
                output = self.forward(text, labels, fpn)
                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                all_predictions.extend(predictions.tolist())
                all_labels.extend(labels.tolist())
                dataset_results = self.evaluate(dataloader=eval_dataloader,
                                                prediction_network=fpn)
                self.log_few_shot(all_predictions,
                                  all_labels,
                                  datasets,
                                  dataset_results,
                                  increment_counters,
                                  text,
                                  i,
                                  split=split)
                if (i * self.config.testing.few_shot_batch_size
                    ) % self.mini_batch_size == 0 and i > 0:
                    all_predictions, all_labels = [], []
        self.few_shot_end()
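
A compact sketch of the neuromodulated forward pass and the outer-loop meta update used by ANML above: the neuromodulator produces a gate in (0, 1) that element-wise scales the prediction network's representation before the output head, and the query-set gradients are written into .grad before a single meta-optimizer step (mirroring update_meta_gradients). The nm, pn_body, and pn_head modules are toy stand-ins for TransformerNeuromodulator and TransformerClsModel, and the higher-based inner loop on the prediction network is omitted here.

import torch
import torch.nn as nn

HDIM, N_CLASSES = 32, 4

# Toy stand-ins for the transformer-based modules
nm = nn.Sequential(nn.Linear(HDIM, HDIM), nn.Sigmoid())  # neuromodulator: gate in (0, 1)
pn_body = nn.Linear(HDIM, HDIM)                          # prediction network, out_from="transformers"
pn_head = nn.Linear(HDIM, N_CLASSES)                     # prediction network, out_from="linear"


def forward(inputs):
    representation = pn_body(inputs)
    modulation = nm(inputs)                              # neuromodulatory gating signal
    return pn_head(representation * modulation)          # element-wise modulated representation


params = list(nm.parameters()) + list(pn_body.parameters()) + list(pn_head.parameters())
meta_optimizer = torch.optim.AdamW(params, lr=1e-3)

# One outer-loop ("query set") step: gradients are computed explicitly and
# accumulated into .grad, then applied by the meta optimizer.
x = torch.randn(8, HDIM)
y = torch.randint(0, N_CLASSES, (8,))
loss = nn.CrossEntropyLoss()(forward(x), y)
grads = torch.autograd.grad(loss, params, allow_unused=True)
for p, g in zip(params, grads):
    if g is not None:
        p.grad = g.detach() if p.grad is None else p.grad + g.detach()
meta_optimizer.step()
meta_optimizer.zero_grad()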
Example 3
class MemoryProtomaml(Learner):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.inner_lr = config.learner.inner_lr
        self.meta_lr = config.learner.meta_lr
        self.mini_batch_size = config.training.batch_size

        self.pn = TransformerClsModel(model_name=config.learner.model_name,
                                      n_classes=config.data.n_classes,
                                      max_length=config.data.max_length,
                                      device=self.device)

        # self.encoder = TransformerRLN(model_name=config.learner.model_name,
        #                               max_length=config.data.max_length,
        #                               device=self.device)
        # self.classifier = nn.Linear(TRANSFORMER_HDIM, config.data.n_classes).to(self.device)
        #         self.memory = MemoryStore(memory_size=config.learner.memory_size, key_dim=TRANSFORMER_HDIM, device=self.device)
        self.memory = ClassMemoryStore(
            key_dim=TRANSFORMER_HDIM,
            device=self.device,
            class_discount=config.learner.class_discount,
            n_classes=config.data.n_classes,
            discount_method=config.learner.class_discount_method)
        self.loss_fn = nn.CrossEntropyLoss()

        # self.logger.info("Loaded {} as encoder".format(self.encoder.__class__.__name__))
        # meta_params = [p for p in self.encoder.parameters() if p.requires_grad]
        # self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        # inner_params = [p for p in self.classifier.parameters() if p.requires_grad]
        # self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)

        meta_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)
        #TODO: remove below line
        self.episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well

    def training(self, datasets, **kwargs):
        representations_log = []
        replay_freq, replay_steps = self.replay_parameters()
        self.logger.info("Replay frequency: {}".format(replay_freq))
        self.logger.info("Replay steps: {}".format(replay_steps))

        datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset = self.prepare_data(
            datasets)
        for i, (data, dataset_name,
                n_sample) in enumerate(zip(datas, order, n_samples)):
            self.logger.info(
                f"Observing dataset {dataset_name} for {n_sample} samples. "
                f"Evaluation={dataset_name=='evaluation'}")
            if dataset_name == "evaluation":
                self.few_shot_testing(train_dataset=eval_train_dataset,
                                      eval_dataset=eval_eval_dataset,
                                      split="test",
                                      increment_counters=False)
            else:
                train_dataloader = iter(
                    DataLoader(data,
                               batch_size=self.mini_batch_size,
                               shuffle=False))
                self.episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well
                # iterate over episodes
                while True:
                    self.set_train()
                    support_set, task = self.get_support_set(
                        train_dataloader, n_sample)
                    # TODO: return flag that indicates whether the query set is from the memory. Don't include this in the online accuracy calc
                    query_set = self.get_query_set(train_dataloader,
                                                   replay_freq,
                                                   replay_steps,
                                                   n_sample,
                                                   write_memory=False)
                    if support_set is None or query_set is None:
                        break
                    self.training_step(support_set, query_set, task=task)

                    if self.current_iter % 5 == 0:
                        class_representations = self.memory.class_representations
                        extra_text, extra_labels, _ = next(
                            self.extra_dataloader)
                        with torch.no_grad():
                            extra_representations = self.forward(
                                extra_text, extra_labels,
                                no_grad=True)["representation"]
                            query_text, query_labels = query_set[0]
                            query_representations = self.forward(
                                query_text, query_labels,
                                no_grad=True)["representation"]
                            extra_dist, extra_dist_normalized, extra_unique_labels = model_utils.class_dists(
                                extra_representations, extra_labels,
                                class_representations)
                            query_dist, query_dist_normalized, query_unique_labels = model_utils.class_dists(
                                query_representations, query_labels,
                                class_representations)
                            class_representation_distances = model_utils.euclidean_dist(
                                class_representations, class_representations)
                            representations_log.append({
                                "query_dist":
                                query_dist.tolist(),
                                "query_dist_normalized":
                                query_dist_normalized.tolist(),
                                "query_labels":
                                query_labels.tolist(),
                                "query_unique_labels":
                                query_unique_labels.tolist(),
                                "extra_dist":
                                extra_dist.tolist(),
                                "extra_dist_normalized":
                                extra_dist_normalized.tolist(),
                                "extra_labels":
                                extra_labels.tolist(),
                                "extra_unique_labels":
                                extra_unique_labels.tolist(),
                                "class_representation_distances":
                                class_representation_distances.tolist(),
                                "class_tsne":
                                TSNE().fit_transform(
                                    class_representations.cpu()).tolist(),
                                "current_iter":
                                self.current_iter,
                                "examples_seen":
                                self.examples_seen()
                            })
                            if self.current_iter % 100 == 0:
                                with open(
                                        self.representations_dir /
                                        f"classDists_{self.current_iter}.json",
                                        "w") as f:
                                    json.dump(representations_log, f)
                                representations_log = []

                    self.meta_training_log()
                    self.write_metrics()
                    self.current_iter += 1
                    if self.episode_samples_seen >= n_sample:
                        break
            if i == 0:
                self.metrics["eval_task_first_encounter_evaluation"] = \
                    self.evaluate(DataLoader(eval_dataset, batch_size=self.mini_batch_size))["accuracy"]
            if dataset_name == self.config.testing.eval_dataset:
                self.eval_task_first_encounter = False

    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug(
            "-------------------- TRAINING STEP  -------------------")

        ### GET SUPPORT SET REPRESENTATIONS ###
        with torch.no_grad():
            representations, all_labels = self.get_representations(
                support_set[:1])
            representations_merged = torch.cat(representations)
            class_means, unique_labels = self.get_class_means(
                representations_merged, all_labels)

        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        if do_memory_update:
            ### UPDATE MEMORY ###
            updated_memory_representations = self.memory.update(
                class_means, unique_labels, logger=self.logger)

        ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
        if self.config.learner.prototypes == "class_means":
            prototypes = class_means.detach()
        elif self.config.learner.prototypes == "memory":
            prototypes = self.memory.class_representations  # doesn't track prototype gradients
        else:
            raise AssertionError(
                "Prototype type not in {'class_means', 'memory'}, fix config file."
            )

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            ### INITIALIZE LINEAR LAYER WITH PROTOYPICAL-EQUIVALENT WEIGHTS ###
            self.init_prototypical_classifier(prototypes,
                                              linear_module=fpn.linear)

            self.logger.debug(
                "----------------- SUPPORT SET ----------------- ")
            ### TRAIN LINEAR CLASSIFIER ON SUPPORT SET ###
            # Inner loop
            for i, (text, labels) in enumerate(support_set):
                self.logger.debug(
                    f"----------------- {i}th Update ----------------- ")
                labels = torch.tensor(labels).to(self.device)
                # if i == 0:
                #     output = {
                #         "representation": representations[0],
                #         "logits": fpn(representations[0], out_from="linear")
                #     }
                # else:
                output = self.forward(text, labels, fpn)

                # for logging purposes
                prototype_distances = (output["representation"].unsqueeze(1) -
                                       prototypes).norm(dim=-1)
                closest_dists, closest_classes = prototype_distances.topk(
                    3, largest=False)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 2) for z in x[2].tolist()]),
                            list(zip(labels, closest_classes,
                                     closest_dists)))))
                self.logger.debug(
                    f"True labels, closest prototypes, and distances:\n{to_print}"
                )

                topk = output["logits"].topk(5, dim=1)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 3) for z in x[2].tolist()]),
                            list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(
                    f"(label, topk_classes, topk_logits) before update:\n{to_print}"
                )

                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)

                # see how much linear classifier has changed
                with torch.no_grad():
                    topk = fpn(output["representation"],
                               out_from="linear").topk(5, dim=1)
                to_print = pprint.pformat(
                    list(
                        map(
                            lambda x: (x[0].item(), x[1].tolist(),
                                       [round(z, 3) for z in x[2].tolist()]),
                            list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(
                    f"(label, topk_classes, topk_logits) after update:\n{to_print}"
                )

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_support_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)

            self.logger.debug(
                "----------------- QUERY SET  ----------------- ")
            ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
            # Outer loop
            if query_set is not None:
                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, prediction_network=fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    self.update_meta_gradients(loss, fpn)

                    topk = output['logits'].topk(5, dim=1)
                    to_print = pprint.pformat(
                        list(
                            map(
                                lambda x: (x[0].item(), x[1].tolist(
                                ), [round(z, 3) for z in x[2].tolist()]),
                                list(zip(labels, topk[1], topk[0])))))
                    self.logger.debug(
                        f"(label, topk_classes, topk_logits):\n{to_print}")

                    predictions = model_utils.make_prediction(
                        output["logits"].detach())
                    self.update_query_tracker(loss, predictions, labels)
                    metrics = model_utils.calculate_metrics(
                        predictions.tolist(), labels.tolist())
                    online_metrics = {
                        "accuracy": metrics["accuracy"],
                        "examples_seen": self.examples_seen(),
                        "task": task if task is not None else "none"
                    }
                    self.metrics["online"].append(online_metrics)
                    if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                        self.metrics["eval_task_first_encounter"].append(
                            online_metrics)
                    self._examples_seen += len(text)

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()
        self.logger.debug(
            "-------------------- TRAINING STEP END  -------------------")

    def get_representations(self, support_set, prediction_network=None):
        representations = []
        all_labels = []
        for text, labels in support_set:
            labels = torch.tensor(labels).to(self.device)
            all_labels.extend(labels.tolist())
            # labels = labels.to(self.device)
            output = self.forward(text,
                                  labels,
                                  prediction_network=prediction_network,
                                  update_memory=False)
            representations.append(output["representation"])
        return representations, all_labels

    def forward(self,
                text,
                labels,
                prediction_network=None,
                update_memory=False,
                no_grad=False):
        if prediction_network is None:
            prediction_network = self.pn
        input_dict = self.pn.encode_text(text)
        context_manager = torch.no_grad() if no_grad else nullcontext()
        with context_manager:
            representation = prediction_network(input_dict,
                                                out_from="transformers")
            logits = prediction_network(representation, out_from="linear")

        if update_memory:
            self.memory.add_entry(embeddings=representation.detach(),
                                  labels=labels,
                                  query_result=None)
        return {"representation": representation, "logits": logits}

    def update_memory(self, class_means, unique_labels):
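        """Exponentially smooth the stored class representations towards the new class means.

        For the selected classes: r_k <- (1 - class_discount) * r_k + class_discount * c_k.
        If the selected old representations are all still zero vectors, they are set to the
        class means directly to avoid biasing the estimate towards zero.
        """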
        to_update = unique_labels
        # selection of old class representations here
        old_class_representations = self.memory.class_representations[
            to_update]
        # if old class representations haven't been given values yet, don't bias towards 0 by exponential update
        if (old_class_representations == 0).bool().all():
            new_class_representations = class_means
        else:
            # memory update rule here
            new_class_representations = (
                1 - self.config.learner.class_discount
            ) * old_class_representations + self.config.learner.class_discount * class_means
        self.logger.debug(
            f"Updating class representations for classes {unique_labels}.\n"
            f"Distance old class representations and class means: {[round(z, 2) for z in (old_class_representations - class_means).norm(dim=1).tolist()]}\n"
            f"Distance old and new class representations: {[round(z, 2) for z in (new_class_representations - old_class_representations).norm(dim=1).tolist()]}"
        )
        # update memory
        self.memory.class_representations[
            to_update] = new_class_representations.detach()

    def init_prototypical_classifier(self, prototypes, linear_module=None):
        if linear_module is None:
            linear_module = self.pn.linear
        weight = 2 * prototypes  # prototypical-equivalent weights W_k = 2 * c_k (scaling by the number of dimensions may be needed to keep logits from blowing up)
        bias = -(prototypes**2).sum(dim=1)  # b_k = -||c_k||^2
        # Classes with a zero prototype would keep a bias of 0, which is always larger than the
        # (negative) bias of classes observed in the support set and therefore favors the
        # unobserved classes. Since support-set labels are the ones most likely to appear in the
        # query set, set those entries to the minimum of the initialized biases instead.
        bias_unchanged = bias == 0
        bias[bias_unchanged] = bias.min()
        self.logger.info(
            f"Prototype is zero vector for classes {bias_unchanged.nonzero(as_tuple=True)[0].tolist()}. "
            f"Setting their bias entries to the minimum of the initialized bias entries."
        )
        # prototypical-equivalent network initialization
        linear_module.weight.data = weight
        linear_module.bias.data = bias
        self.logger.info(f"Classifier bias initialized to {bias}.")

        # a = mmaml.classifier.weight
        # # https://stackoverflow.com/questions/61279403/gradient-flow-through-torch-nn-parameter
        # # a = torch.nn.Parameter(torch.ones((10,)), requires_grad=True)
        # b = a[:] # silly hack to convert in a raw tensor including the computation graph
        # # b.retain_grad() # Otherwise backward pass will not store the gradient since it is not a leaf
        # it is necessary to do it this way to retain the gradient information on the classifier parameters
        # https://discuss.pytorch.org/t/non-leaf-variables-as-a-modules-parameters/65775
        # del self.classifier.weight
        # self.classifier.weight = 2 * prototypes
        # del self.classifier.bias
        # self.classifier.bias = bias
        # weight_copy = self.classifier.weight[:]
        # bias_copy = self.classifier.bias[:]

    def update_meta_gradients(self, loss, fpn):
        # PN meta gradients
        pn_params = [p for p in fpn.parameters() if p.requires_grad]
        meta_pn_grads = torch.autograd.grad(loss, pn_params, allow_unused=True)
        pn_params = [p for p in self.pn.parameters() if p.requires_grad]
        for param, meta_grad in zip(pn_params, meta_pn_grads):
            if meta_grad is not None:
                if param.grad is not None:
                    param.grad += meta_grad.detach()
                else:
                    param.grad = meta_grad.detach()

    def update_support_tracker(self, loss, pred, labels):
        self.tracker["support_loss"].append(loss.item())
        self.tracker["support_predictions"].extend(pred.tolist())
        self.tracker["support_labels"].extend(labels.tolist())

    def update_query_tracker(self, loss, pred, labels):
        self.tracker["query_loss"].append(loss.item())
        self.tracker["query_predictions"].extend(pred.tolist())
        self.tracker["query_labels"].extend(labels.tolist())

    def reset_tracker(self):
        self.tracker = {
            "support_loss": [],
            "support_predictions": [],
            "support_labels": [],
            "query_loss": [],
            "query_predictions": [],
            "query_labels": []
        }

    def evaluate(self, dataloader, prediction_network=None):
        if self.config.learner.evaluation_support_set:
            support_set = []
            for _ in range(self.config.learner.updates):
                text, labels = self.memory.read_batch(
                    batch_size=self.mini_batch_size)
                support_set.append((text, labels))

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            if self.config.learner.evaluation_support_set:
                self.set_train()
                support_prediction_network = fpn
                # Inner loop
                task_predictions, task_labels = [], []
                support_loss = []
                for text, labels in support_set:
                    labels = torch.tensor(labels).to(self.device)
                    # labels = labels.to(self.device)
                    output = self.forward(text, labels, fpn)
                    loss = self.loss_fn(output["logits"], labels)
                    diffopt.step(loss)

                    pred = model_utils.make_prediction(
                        output["logits"].detach())
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(labels.tolist())
                results = model_utils.calculate_metrics(
                    task_predictions, task_labels)
                self.logger.info(
                    "Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(support_loss),
                                               results["accuracy"],
                                               results["precision"],
                                               results["recall"],
                                               results["f1"]))
                self.set_eval()
            else:
                support_prediction_network = self.pn
            if prediction_network is None:
                prediction_network = support_prediction_network

            self.set_eval()
            all_losses, all_predictions, all_labels = [], [], []
            for i, (text, labels, _) in enumerate(dataloader):
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text,
                                      labels,
                                      prediction_network,
                                      no_grad=True)
                loss = self.loss_fn(output["logits"], labels)
                loss = loss.item()
                pred = model_utils.make_prediction(output["logits"].detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                # if i % 20 == 0:
                #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       results["accuracy"],
                                       results["precision"], results["recall"],
                                       results["f1"]))
        return results

    def model_state(self):
        return {"pn": self.pn.state_dict()}

    def optimizer_state(self):
        return self.meta_optimizer.state_dict()

    def load_model_state(self, checkpoint):
        self.pn.load_state_dict(checkpoint["model_state"]["pn"])

    def load_optimizer_state(self, checkpoint):
        self.meta_optimizer.load_state_dict(checkpoint["optimizer"])

    def save_other_state_information(self, state):
        """Any learner specific state information is added here"""
        state["memory"] = self.memory
        return state

    def load_other_state_information(self, checkpoint):
        self.memory = checkpoint["memory"]

    def set_eval(self):
        self.pn.eval()

    def set_train(self):
        self.pn.train()

    def few_shot_testing(self,
                         train_dataset,
                         eval_dataset,
                         increment_counters=False,
                         split="test"):
        """
        Allow the model to train on a small number of datapoints at a time. After every training step,
        evaluate on many samples that have not been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        """
        self.logger.info(
            f"few shot testing on dataset {self.config.testing.eval_dataset} "
            f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(
            train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []

        self.init_prototypical_classifier(
            prototypes=self.memory.class_representations)
        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            # Inner loop
            for i, (text, labels, datasets) in enumerate(train_dataloader):
                self.set_train()
                labels = torch.tensor(labels).to(self.device)
                output = self.forward(text, labels, fpn)
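                # Note: the subtraction below broadcasts a (batch, hdim) representation against
                # the (n_classes, hdim) class representations, so the logged per-class distances
                # are only meaningful when the few-shot batch contains a single example.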
                prototype_distances = (output['representation'] -
                                       self.memory.class_representations).norm(
                                           dim=1)
                class_distances = list(
                    map(
                        lambda x: (x[0].item(), x[1].item()),
                        list(
                            zip(torch.arange(len(prototype_distances)),
                                prototype_distances))))
                self.logger.info(
                    f"Ground truth: {labels} -- Prototype distances: {class_distances}"
                )
                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                all_predictions.extend(predictions.tolist())
                all_labels.extend(labels.tolist())
                dataset_results = self.evaluate(dataloader=eval_dataloader,
                                                prediction_network=fpn)
                self.log_few_shot(all_predictions,
                                  all_labels,
                                  datasets,
                                  dataset_results,
                                  increment_counters,
                                  text,
                                  i,
                                  split=split)
                if (i * self.config.testing.few_shot_batch_size
                    ) % self.mini_batch_size == 0 and i > 0:
                    all_predictions, all_labels = [], []
        self.few_shot_end()

    def get_class_means(self, embeddings, labels):
        """Return class means and unique labels given neighbors.
        
        Parameters
        ---
        embeddings: Tensor, shape (batch_size, embed_size)
        labels: iterable of labels for each embedding
            
        Returns
        ---
        Tuple[List[Tensor], List[Tensor]]:
            class means and unique labels
        """
        class_means = []
        unique_labels = torch.tensor(labels).unique()
        for label_ in unique_labels:
            label_ixs = (label_ == torch.tensor(labels)).nonzero(
                as_tuple=False).flatten()
            same_class_embeddings = embeddings[label_ixs]
            class_means.append(same_class_embeddings.mean(dim=0))
        return torch.stack(class_means), unique_labels
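
As a quick illustration of what `get_class_means` computes, here is a minimal standalone sketch with made-up embeddings (the helper name and the tensor values are assumptions for illustration only):

import torch


def class_means_sketch(embeddings, labels):
    # group embeddings by label and average each group, mirroring get_class_means
    labels = torch.tensor(labels)
    unique_labels = labels.unique()
    means = [embeddings[labels == label].mean(dim=0) for label in unique_labels]
    return torch.stack(means), unique_labels


embeddings = torch.tensor([[1.0, 0.0],
                           [3.0, 0.0],
                           [0.0, 2.0]])
labels = [0, 0, 1]
means, uniques = class_means_sketch(embeddings, labels)
# means   -> tensor([[2., 0.], [0., 2.]])
# uniques -> tensor([0, 1])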
Example n. 4
class PrototypicalNetwork(Learner):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.inner_lr = config.learner.inner_lr
        self.meta_lr = config.learner.meta_lr
        self.mini_batch_size = config.training.batch_size

        self.pn = TransformerClsModel(model_name=config.learner.model_name,
                                      n_classes=config.data.n_classes,
                                      max_length=config.data.max_length,
                                      device=self.device)
        if config.wandb:
            wandb.watch(self.pn, log='all')

        self.memory = ClassMemoryStore(
            key_dim=TRANSFORMER_HDIM,
            device=self.device,
            class_discount=config.learner.class_discount,
            n_classes=config.data.n_classes,
            discount_method=config.learner.class_discount_method)
        self.loss_fn = nn.CrossEntropyLoss()

        meta_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)
        #TODO: remove below line
        self.episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well

    def training(self, datasets, **kwargs):
        representations_log = []
        replay_freq, replay_steps = self.replay_parameters()
        self.logger.info("Replay frequency: {}".format(replay_freq))
        self.logger.info("Replay steps: {}".format(replay_steps))

        datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset = self.prepare_data(
            datasets)
        for i, (data, dataset_name,
                n_sample) in enumerate(zip(datas, order, n_samples)):
            self.logger.info(
                f"Observing dataset {dataset_name} for {n_sample} samples. "
                f"Evaluation={dataset_name=='evaluation'}")
            if dataset_name == "evaluation" and self.config.testing.few_shot:
                self.few_shot_testing(train_dataset=eval_train_dataset,
                                      eval_dataset=eval_eval_dataset,
                                      split="test",
                                      increment_counters=False)
            else:
                train_dataloader = iter(
                    DataLoader(data,
                               batch_size=self.mini_batch_size,
                               shuffle=False))
                self.episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well
                # iterate over episodes
                while True:
                    self.set_train()
                    support_set, task = self.get_support_set(
                        train_dataloader, n_sample)
                    # TODO: return flag that indicates whether the query set is from the memory. Don't include this in the online accuracy calc
                    query_set = self.get_query_set(train_dataloader,
                                                   replay_freq,
                                                   replay_steps,
                                                   n_sample,
                                                   write_memory=False)
                    if support_set is None or query_set is None:
                        break
                    self.training_step(support_set, query_set, task=task)

                    if self.current_iter % 5 == 0:
                        class_representations = self.memory.class_representations
                        extra_text, extra_labels, datasets = next(
                            self.extra_dataloader)
                        with torch.no_grad():
                            extra_representations = self.forward(
                                extra_text, extra_labels,
                                no_grad=True)["representation"]
                            query_text, query_labels = query_set[0]
                            query_representations = self.forward(
                                query_text, query_labels,
                                no_grad=True)["representation"]
                            extra_dist, extra_dist_normalized, extra_unique_labels = model_utils.class_dists(
                                extra_representations, extra_labels,
                                class_representations)
                            query_dist, query_dist_normalized, query_unique_labels = model_utils.class_dists(
                                query_representations, query_labels,
                                class_representations)
                            class_representation_distances = model_utils.euclidean_dist(
                                class_representations, class_representations)
                            representations_log.append({
                                "query_dist":
                                query_dist.tolist(),
                                "query_dist_normalized":
                                query_dist_normalized.tolist(),
                                "query_labels":
                                query_labels.tolist(),
                                "query_unique_labels":
                                query_unique_labels.tolist(),
                                "extra_dist":
                                extra_dist.tolist(),
                                "extra_dist_normalized":
                                extra_dist_normalized.tolist(),
                                "extra_labels":
                                extra_labels.tolist(),
                                "extra_unique_labels":
                                extra_unique_labels.tolist(),
                                "class_representation_distances":
                                class_representation_distances.tolist(),
                                "class_tsne":
                                TSNE().fit_transform(
                                    class_representations.cpu()).tolist(),
                                "current_iter":
                                self.current_iter,
                                "examples_seen":
                                self.examples_seen()
                            })
                            if self.current_iter % 100 == 0:
                                with open(
                                        self.representations_dir /
                                        f"classDists_{self.current_iter}.json",
                                        "w") as f:
                                    json.dump(representations_log, f)
                                representations_log = []

                    self.meta_training_log()
                    self.write_metrics()
                    self.current_iter += 1
                    if self.episode_samples_seen >= n_sample:
                        break
            if i == 0:
                self.metrics["eval_task_first_encounter_evaluation"] = \
                    self.evaluate(DataLoader(eval_dataset, batch_size=self.mini_batch_size))["accuracy"]
                # self.save_checkpoint("first_task_learned.pt", save_optimizer_state=True)
            if dataset_name == self.config.testing.eval_dataset:
                self.eval_task_first_encounter = False

    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug(
            "-------------------- TRAINING STEP  -------------------")
        # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
        #                           copy_initial_weights=False,
        #                           track_higher_grads=False) as (fpn, diffopt):
        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        ### GET SUPPORT SET REPRESENTATIONS ###
        self.logger.debug("----------------- SUPPORT SET ----------------- ")
        representations, all_labels = self.get_representations(support_set[:1])
        representations_merged = torch.cat(representations)
        class_means, unique_labels = model_utils.get_class_means(
            representations_merged, all_labels)
        self._examples_seen += len(representations_merged)
        self.logger.debug(
            f"Examples seen increased by {len(representations_merged)}")

        ### UPDATE MEMORY ###
        if do_memory_update:
            memory_update = self.memory.update(class_means,
                                               unique_labels,
                                               logger=self.logger)
            updated_memory_representations = memory_update[
                "new_class_representations"]
            self.log_discounts(memory_update["class_discount"], unique_labels)
        ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
        if self.config.learner.prototypes == "class_means":
            prototypes = expand_class_representations(
                self.memory.class_representations, class_means, unique_labels)
        elif self.config.learner.prototypes == "memory":
            prototypes = updated_memory_representations
        else:
            raise AssertionError(
                "Prototype type not in {'class_means', 'memory'}, fix config file."
            )

        ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
        # self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)
        weight = 2 * prototypes  # note: unlike init_prototypical_classifier, not scaled by the hidden dimension here
        bias = -(prototypes**2).sum(dim=1)
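        # With W = 2 * prototypes and b = -||p||^2, the logits x @ W.T + b equal
        # -||x - p||^2 + ||x||^2; the ||x||^2 term is identical for every class,
        # so ranking classes by these logits matches ranking by (negative) squared
        # Euclidean distance to the prototypes.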

        self.logger.debug("----------------- QUERY SET  ----------------- ")
        ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                query_representations = self.forward(text,
                                                     labels)["representation"]

                # distance query representations to prototypes (BATCH X N_PROTOTYPES)
                # distances = euclidean_dist(query_representations, prototypes)
                # logits = - distances
                logits = query_representations @ weight.T + bias
                loss = self.loss_fn(logits, labels)
                # log_probability = F.log_softmax(-distances, dim=1)
                # loss is negation of the log probability, index using the labels for each observation
                # loss = (- log_probability[torch.arange(len(log_probability)), labels]).mean()
                self.meta_optimizer.zero_grad()
                loss.backward()
                self.meta_optimizer.step()

                predictions = model_utils.make_prediction(logits.detach())
                # predictions = torch.tensor([inv_label_map[p.item()] for p in predictions])
                # to_print = pprint.pformat(list(map(lambda x: (x[0].item(), x[1].item(),
                #                         [round(z, 3) for z in x[2].tolist()]),
                #                         list(zip(labels, predictions, distances)))))
                self.logger.debug(
                    f"Unique Labels: {unique_labels.tolist()}\n"
                    # f"Labels, Indices, Predictions, Distances:\n{to_print}\n"
                    f"Loss:\n{loss.item()}\n"
                    f"Predictions:\n{predictions}\n")
                self.update_query_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(
                    predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                self._examples_seen += len(text)
                self.logger.debug(f"Examples seen increased by {len(text)}")

            # Meta optimizer step
            # self.meta_optimizer.step()
            # self.meta_optimizer.zero_grad()
        self.logger.debug(
            "-------------------- TRAINING STEP END  -------------------")

    def get_representations(self, support_set, prediction_network=None):
        """
        Parameters
        ---
        support_set: List[Tuple[batch text, batch labels]]
        prediction_network: PyTorch module, optional (defaults to self.pn)

        Returns
        ---
        Tuple[List[Tensor], List[Int]] where the first result is the hidden representation and the second
        the labels.
        """
        representations = []
        all_labels = []
        for text, labels in support_set:
            labels = torch.tensor(labels).to(self.device)
            all_labels.extend(labels.tolist())
            # labels = labels.to(self.device)
            output = self.forward(text,
                                  labels,
                                  prediction_network=prediction_network)
            representations.append(output["representation"])
        return representations, all_labels

    def forward(self, text, labels, prediction_network=None, no_grad=False):
        if prediction_network is None:
            prediction_network = self.pn
        input_dict = self.pn.encode_text(text)
        context_manager = torch.no_grad() if no_grad else nullcontext()
        with context_manager:
            representation = prediction_network(input_dict,
                                                out_from="transformers")
            logits = prediction_network(representation, out_from="linear")
        return {"representation": representation, "logits": logits}

    def update_memory(self, class_means, unique_labels):
        to_update = unique_labels
        # selection of old class representations here
        old_class_representations = self.memory.class_representations[
            to_update]
        # if old class representations haven't been given values yet, don't bias towards 0 by exponential update
        if (old_class_representations == 0).bool().all():
            new_class_representations = class_means
        else:
            # memory update rule here
            new_class_representations = (
                1 - self.config.learner.class_discount
            ) * old_class_representations + self.config.learner.class_discount * class_means
        self.logger.debug(
            f"Updating class representations for classes {unique_labels}.\n"
            f"Distance old class representations and class means: {[round(z, 2) for z in (old_class_representations - class_means).norm(dim=1).tolist()]}\n"
            f"Distance old and new class representations: {[round(z, 2) for z in (new_class_representations - old_class_representations).norm(dim=1).tolist()]}"
        )
        # for returning new class representations while keeping gradients intact
        result = torch.clone(self.memory.class_representations)
        result[to_update] = new_class_representations
        # update memory
        self.memory.class_representations[
            to_update] = new_class_representations.detach()

        return result

    def init_prototypical_classifier(self, prototypes, linear_module=None):
        if linear_module is None:
            linear_module = self.pn.linear
        weight = 2 * prototypes / TRANSFORMER_HDIM  # divide by number of dimensions, otherwise blows up
        bias = -(prototypes**2).sum(dim=1) / TRANSFORMER_HDIM
        # otherwise the bias of the classes observed in the support set is always smaller than
        # not observed ones, which favors the unobserved ones. However, it is expected that labels
        # in the support set are more likely to be in the query set.
        bias_unchanged = bias == 0
        bias[bias_unchanged] = bias.min()
        self.logger.info(
            f"Prototype is zero vector for classes {bias_unchanged.nonzero(as_tuple=True)[0].tolist()}. "
            f"Setting their bias entries to the minimum of the uninitialized bias vector."
        )
        # prototypical-equivalent network initialization
        linear_module.weight.data = weight
        linear_module.bias.data = bias
        self.logger.info(f"Classifier bias initialized to {bias}.")

        # a = mmaml.classifier.weight
        # # https://stackoverflow.com/questions/61279403/gradient-flow-through-torch-nn-parameter
        # # a = torch.nn.Parameter(torch.ones((10,)), requires_grad=True)
        # b = a[:] # silly hack to convert in a raw tensor including the computation graph
        # # b.retain_grad() # Otherwise backward pass will not store the gradient since it is not a leaf
        # it is necessary to do it this way to retain the gradient information on the classifier parameters
        # https://discuss.pytorch.org/t/non-leaf-variables-as-a-modules-parameters/65775
        # del self.classifier.weight
        # self.classifier.weight = 2 * prototypes
        # del self.classifier.bias
        # self.classifier.bias = bias
        # weight_copy = self.classifier.weight[:]
        # bias_copy = self.classifier.bias[:]

    def update_meta_gradients(self, loss, fpn):
        # PN meta gradients
        pn_params = [p for p in fpn.parameters() if p.requires_grad]
        meta_pn_grads = torch.autograd.grad(loss, pn_params, allow_unused=True)
        pn_params = [p for p in self.pn.parameters() if p.requires_grad]
        for param, meta_grad in zip(pn_params, meta_pn_grads):
            if meta_grad is not None:
                if param.grad is not None:
                    param.grad += meta_grad.detach()
                else:
                    param.grad = meta_grad.detach()

    def update_support_tracker(self, loss, pred, labels):
        self.tracker["support_loss"].append(loss.item())
        self.tracker["support_predictions"].extend(pred.tolist())
        self.tracker["support_labels"].extend(labels.tolist())

    def update_query_tracker(self, loss, pred, labels):
        self.tracker["query_loss"].append(loss.item())
        self.tracker["query_predictions"].extend(pred.tolist())
        self.tracker["query_labels"].extend(labels.tolist())

    def reset_tracker(self):
        self.tracker = {
            "support_loss": [],
            "support_predictions": [],
            "support_labels": [],
            "query_loss": [],
            "query_predictions": [],
            "query_labels": []
        }

    def evaluate(self, dataloader, prediction_network=None):
        # if self.config.learner.evaluation_support_set:
        #     support_set = []
        #     for _ in range(self.config.learner.updates):
        #         text, labels = self.memory.read_batch(batch_size=self.mini_batch_size)
        #         support_set.append((text, labels))

        # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
        #                         copy_initial_weights=False,
        #                         track_higher_grads=False) as (fpn, diffopt):
        #     if self.config.learner.evaluation_support_set:
        #         self.set_train()
        #         support_prediction_network = fpn
        #         # Inner loop
        #         task_predictions, task_labels = [], []
        #         support_loss = []
        #         for text, labels in support_set:
        #             labels = torch.tensor(labels).to(self.device)
        #             # labels = labels.to(self.device)
        #             output = self.forward(text, labels, fpn)
        #             loss = self.loss_fn(output["logits"], labels)
        #             diffopt.step(loss)

        #             pred = model_utils.make_prediction(output["logits"].detach())
        #             support_loss.append(loss.item())
        #             task_predictions.extend(pred.tolist())
        #             task_labels.extend(labels.tolist())
        #         results = model_utils.calculate_metrics(task_predictions, task_labels)
        #         self.logger.info("Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        #                     "F1 score = {:.4f}".format(np.mean(support_loss), results["accuracy"],
        #                     results["precision"], results["recall"], results["f1"]))
        #         self.set_eval()
        #     else:
        #         support_prediction_network = self.pn
        #     if prediction_network is None:
        #         prediction_network = support_prediction_network

        self.set_eval()
        prototypes = self.memory.class_representations
        weight = 2 * prototypes
        bias = -(prototypes**2).sum(dim=1)
        all_losses, all_predictions, all_labels = [], [], []
        for i, (text, labels, _) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            representations = self.forward(text, labels)["representation"]
            logits = representations @ weight.T + bias
            # labels = labels.to(self.device)
            loss = self.loss_fn(logits, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(logits.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses),
                                       results["accuracy"],
                                       results["precision"], results["recall"],
                                       results["f1"]))
        return results

    def model_state(self):
        return {"pn": self.pn.state_dict()}

    def optimizer_state(self):
        return self.meta_optimizer.state_dict()

    def load_model_state(self, checkpoint):
        self.pn.load_state_dict(checkpoint["model_state"]["pn"])

    def load_optimizer_state(self, checkpoint):
        self.meta_optimizer.load_state_dict(checkpoint["optimizer"])

    def save_other_state_information(self, state):
        """Any learner specific state information is added here"""
        state["memory"] = self.memory
        return state

    def load_other_state_information(self, checkpoint):
        self.memory = checkpoint["memory"]

    def set_eval(self):
        self.pn.eval()

    def set_train(self):
        self.pn.train()

    def few_shot_testing(self,
                         train_dataset,
                         eval_dataset,
                         increment_counters=False,
                         split="test"):
        """
        Allow the model to train on a small amount of datapoints at a time. After every training step,
        evaluate on many samples that haven't been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        """
        self.logger.info(
            f"few shot testing on dataset {self.config.testing.eval_dataset} "
            f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(
            train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []

        def add_none(iterator):
            yield None
            for x in iterator:
                yield x

        shifted_dataloader = add_none(train_dataloader)
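        # zip() pairs every query batch with the batch that preceded it, which acts
        # as the support set; the first query batch is paired with None and falls
        # back to the stored class representations below.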
        # prototypes = self.memory.class_representations
        for i, (support_set, (query_text, query_labels,
                              datasets)) in enumerate(
                                  zip(shifted_dataloader, train_dataloader)):
            query_labels = torch.tensor(query_labels).to(self.device)
            # happens on the first one
            # prototypes = self.memory.class_representations
            if support_set is None:
                prototypes = self.memory.class_representations
            else:
                support_text, support_labels, _ = support_set
                support_labels = torch.tensor(support_labels).to(self.device)
                support_representations = self.forward(
                    support_text, support_labels)["representation"]
                support_class_means, unique_labels = model_utils.get_class_means(
                    support_representations, support_labels)
                memory_update = self.memory.update(support_class_means,
                                                   unique_labels,
                                                   logger=self.logger)
                updated_memory_representations = memory_update[
                    "new_class_representations"]
                self.log_discounts(memory_update["class_discount"],
                                   unique_labels,
                                   few_shot_examples_seen=(i + 1) *
                                   self.config.testing.few_shot_batch_size)
                prototypes = updated_memory_representations
                if self.config.learner.few_shot_detach_prototypes:
                    prototypes = prototypes.detach()
            weight = 2 * prototypes
            bias = -(prototypes**2).sum(dim=1)
            query_representations = self.forward(
                query_text, query_labels)["representation"]
            logits = query_representations @ weight.T + bias

            loss = self.loss_fn(logits, query_labels)

            self.meta_optimizer.zero_grad()
            loss.backward()
            self.meta_optimizer.step()

            predictions = model_utils.make_prediction(logits.detach())
            all_predictions.extend(predictions.tolist())
            all_labels.extend(query_labels.tolist())
            dataset_results = self.evaluate(dataloader=eval_dataloader)
            self.log_few_shot(all_predictions,
                              all_labels,
                              datasets,
                              dataset_results,
                              increment_counters,
                              query_text,
                              i,
                              split=split)
            if (i * self.config.testing.few_shot_batch_size
                ) % self.mini_batch_size == 0 and i > 0:
                all_predictions, all_labels = [], []
        self.few_shot_end()

    def log_discounts(self,
                      class_discount,
                      unique_labels,
                      few_shot_examples_seen=None):
        prefix = f"few_shot_{self.few_shot_counter}_" if few_shot_examples_seen is not None else ""
        discounts = {prefix + "class_discount": {}}
        if not isinstance(class_discount, (float, int)):
            for l, discount in zip(unique_labels, class_discount):
                discounts[prefix +
                          "class_discount"][f"Class {l}"] = discount.item()
        else:
            for l in unique_labels:
                discounts[prefix + "class_discount"][f"Class {l}"] = float(
                    class_discount)
        for l in range(self.config.data.n_classes):
            if l not in unique_labels:
                discounts[prefix + "class_discount"][f"Class {l}"] = 0
        discounts["examples_seen"] = (few_shot_examples_seen
                                      if few_shot_examples_seen is not None
                                      else self.examples_seen())
        if "class_discount" not in self.metrics:
            self.metrics["class_discount"] = []
        self.metrics["class_discount"].append(discounts)
        if few_shot_examples_seen is not None:
            self.logger.debug("Logging class discounts")
            self.logger.debug(f"Examples seen: {discounts['examples_seen']}")
        if self.config.wandb:
            wandb.log(discounts)
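
The `ClassMemoryStore` used above is not shown in this example; the following standalone sketch only illustrates the exponential-discount rule that `PrototypicalNetwork.update_memory` applies (the discount value and tensors are made up for illustration):

import torch


def discounted_update(old_representations, class_means, unique_labels, discount=0.3):
    # classes whose stored representation is still the zero vector are overwritten
    # directly, so the running average is not biased towards zero
    new = old_representations.clone()
    for row, label in enumerate(unique_labels.tolist()):
        old = old_representations[label]
        if torch.all(old == 0):
            new[label] = class_means[row]
        else:
            new[label] = (1 - discount) * old + discount * class_means[row]
    return new


memory = torch.zeros(3, 2)
memory[0] = torch.tensor([1.0, 1.0])
class_means = torch.tensor([[2.0, 2.0], [4.0, 0.0]])
updated = discounted_update(memory, class_means, unique_labels=torch.tensor([0, 1]))
# class 0: 0.7 * [1, 1] + 0.3 * [2, 2] = [1.3, 1.3]
# class 1: was all zeros, so it is set to [4, 0] directly
# class 2: untouched, stays [0, 0]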
Example n. 5
class Relearner(Learner):
    def __init__(self, config, **kwargs):
        """
        Baseline models: sequential and multitask setup.
        """
        super().__init__(config, **kwargs)
        self.lr = config.learner.lr
        self.n_epochs = config.training.epochs
        self.model = TransformerClsModel(model_name=config.learner.model_name,
                                         n_classes=config.data.n_classes,
                                         max_length=config.data.max_length,
                                         device=self.device)
        self.logger.info("Loaded {} as model".format(
            self.model.__class__.__name__))
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)
        # assume for now that we only look at one task at a time
        self.relearning_task = config.learner.relearning_task
        # increments each time the relearning task is observed
        self.relearning_iter = 0
        self.relative_performance_threshold_lower = self.config.learner.relative_performance_threshold_lower
        self.relative_performance_threshold_upper = self.config.learner.relative_performance_threshold_upper

        # TODO: set through config
        # Parameter deciding at which percentage of performance relearning is evaluated
        self.relearning_evaluation_alphas = (0.75, 0.8, 0.85, 0.9, 0.95)
        self.smooth_alpha = self.config.learner.smooth_alpha
        self.first_encounter = True
        self.n_samples_slope = self.config.learner.n_samples_slope
        self.saturated_patience = self.config.learner.n_samples_saturated_patience
        self.saturated_threshold = self.config.learner.saturated_threshold

        # patience counter for saturation check relearning task
        self.not_improving = 0
        # initialize metrics
        for task in (self.relearning_task, OTHER_TASKS):
            if "performance" not in self.metrics[task]:
                self.metrics[task]["performance"] = []

    def training(self, datasets, **kwargs):
        train_datasets = datasets_dict(datasets["train"], datasets["order"])
        val_datasets = datasets_dict(datasets["val"], datasets["order"])
        self.relearning_task_dataset = {
            self.relearning_task: val_datasets[self.relearning_task]
        }

        self.dataloaders = {
            self.relearning_task:
            data.DataLoader(train_datasets[self.relearning_task],
                            batch_size=self.mini_batch_size,
                            shuffle=True),
            # for now, pile all other tasks on one stack
            OTHER_TASKS:
            data.DataLoader(data.ConcatDataset([
                dataset for task, dataset in train_datasets.items()
                if task != self.relearning_task
            ]),
                            batch_size=self.mini_batch_size,
                            shuffle=True)
        }
        self.metrics[self.relearning_task]["performance"].append([])
        # write performance of initial encounter (before training) to metrics
        self.metrics[self.relearning_task]["performance"][0].append(
            self.validate(self.relearning_task_dataset,
                          log=False,
                          n_samples=self.config.training.n_validation_samples)[
                              self.relearning_task])
        self.metrics[
            self.relearning_task]["performance"][0][0]["examples_seen"] = 0
        # first encounter relearning task
        self.train(dataloader=self.dataloaders[self.relearning_task],
                   datasets=datasets)

    def train(self, dataloader=None, datasets=None, data_length=None):
        val_datasets = datasets_dict(datasets["val"], datasets["order"])

        if data_length is None:
            data_length = len(dataloader) * self.n_epochs

        all_losses, all_predictions, all_labels = [], [], []

        for text, labels, tasks in dataloader:
            self._examples_seen += len(text)
            self.model.train()
            # assumes all data in batch is from same task
            self.current_task = self.relearning_task if tasks[
                0] == self.relearning_task else OTHER_TASKS
            loss, predictions = self._train_batch(text, labels)
            all_losses.append(loss)
            all_predictions.extend(predictions)
            all_labels.extend(labels.tolist())

            if self.current_iter % self.log_freq == 0:
                acc, prec, rec, f1 = model_utils.calculate_metrics(
                    all_predictions, all_labels)
                time_per_iteration, estimated_time_left = self.time_metrics(
                    data_length)
                self.logger.info(
                    "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\nMetrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(
                        self.current_iter + 1, data_length,
                        (self.current_iter + 1) / data_length * 100,
                        time_per_iteration, estimated_time_left,
                        np.mean(all_losses), acc, prec, rec, f1))
                if self.config.wandb:
                    wandb.log({
                        "accuracy": acc,
                        "precision": prec,
                        "recall": rec,
                        "f1": f1,
                        "loss": np.mean(all_losses),
                        "examples_seen": self.examples_seen(),
                        "task": self.current_task
                    })
                all_losses, all_predictions, all_labels = [], [], []
                self.start_time = time.time()
            if self.current_iter % self.validate_freq == 0:
                # only evaluate relearning task when training on relearning task
                validation_datasets = self.relearning_task_dataset if self.current_task == self.relearning_task else val_datasets

                validate_results = self.validate(
                    validation_datasets,
                    n_samples=self.config.training.n_validation_samples,
                    log=False)
                self.write_results(validate_results)
                relearning_task_performance = validate_results[
                    self.relearning_task]["accuracy"]
                if not self.first_encounter:
                    # TODO: make this a weighted average as well
                    relearning_task_relative_performance = self.relative_performance(
                        performance=relearning_task_performance,
                        task=self.relearning_task)
                    self.logger.info((
                        f"Examples seen: {self.examples_seen()} -- Relative performance of task '{self.relearning_task}':"
                        +
                        f"{relearning_task_relative_performance}. Thresholds: {self.relative_performance_threshold_lower}"
                        f"-{self.relative_performance_threshold_upper}"))
                    if self.config.wandb:
                        wandb.log({
                            "relative_performance":
                            relearning_task_relative_performance,
                            "examples_seen": self.examples_seen()
                        })
                if self.current_task == self.relearning_task:
                    self.logger.debug(
                        f"first encounter: {self.first_encounter}")
                    # relearning stops when either of two things happens:
                    # the relearning task is first encountered and it is saturated (doesn't improve)
                    if ((self.first_encounter and self.learning_saturated(
                            task=self.relearning_task,
                            n_samples_slope=self.n_samples_slope,
                            patience=self.saturated_patience,
                            threshold=self.saturated_threshold,
                            smooth_alpha=self.smooth_alpha)) or
                        (
                            # the relearning task is re-encountered and relative performance reaches some threshold
                            not self.first_encounter
                            and relearning_task_relative_performance >=
                            self.relative_performance_threshold_upper)):
                        # write metrics, reset, and train the other tasks
                        self.write_relearning_metrics()
                        self.logger.info(
                            f"Task {self.current_task} saturated at iteration {self.current_iter}"
                        )
                        # each list element in performance refers to one consecutive learning event of the relearning task
                        self.metrics[
                            self.relearning_task]["performance"].append([])
                        self.not_improving = 0
                        if self.first_encounter:
                            self.logger.info(
                                f"-----------FIRST ENCOUNTER RELEARNING TASK '{self.relearning_task}' FINISHED.----------\n"
                            )
                        self.first_encounter = False
                        self.logger.info("TRAINING ON OTHER TASKS")
                        self.train(dataloader=self.dataloaders[OTHER_TASKS],
                                   datasets=datasets)
                else:
                    # calculate relative performance relearning task
                    # if it reaches some threshold, train on relearning task again
                    # TODO: make performance measure attribute of relearner
                    # TODO: use moving average for relative performance check
                    # TODO: allow different task ordering
                    # TODO: measure forgetting
                    # TODO: look at adaptive mini batch size => smaller batch size when re encountering
                    if relearning_task_relative_performance <= self.relative_performance_threshold_lower:
                        self.logger.info(
                            f"Relative performance on relearning task {self.relearning_task} below threshold. Evaluating relearning.."
                        )
                        # this needs to be done because we want a fresh list of performances when we start
                        # training the relearning task again. The first item in this list is simply the zero
                        # shot performance after the relative performance threshold is reached
                        # this means that every odd list in the relearning_task metrics is when training on the relearning task
                        relearning_task_performance = self.metrics[
                            self.relearning_task]["performance"]
                        relearning_task_performance.append([])
                        # copy the last entry of the performance while training on the other tasks to the new list
                        relearning_task_performance[-1].append(
                            relearning_task_performance[-2][-1])
                        self.train(
                            dataloader=self.dataloaders[self.relearning_task],
                            datasets=datasets)
                with open(self.results_dir / METRICS_FILE, "w") as f:
                    json.dump(self.metrics, f)

            self.time_checkpoint()
            self.current_iter += 1

    def _train_batch(self, text, labels):
        labels = torch.tensor(labels).to(self.device)
        input_dict = self.model.encode_text(text)
        output = self.model(input_dict)
        loss = self.loss_fn(output, labels)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        loss = loss.item()
        self.logger.debug(f"Loss: {loss}")
        pred = model_utils.make_prediction(output.detach())
        return loss, pred.tolist()

    def relative_performance(self, performance, task, metric="accuracy"):
        """Calculate relative performance of a task compared to first encounter.
        
        For now assumes that task is the relearning task, since it assumes `max_score` and
        `initial_score` attributes in the metrics attribute.
        """
        max_score = self.metrics[task]["max_score"]
        initial_score = self.metrics[task]["initial_score"]
        return (performance - initial_score) / (max_score - initial_score)

    def write_results(self, validate_results):
        """Write validation results to self.metrics"""
        for task in validate_results.keys():
            if "performance" not in self.metrics[task]:
                self.metrics[task]["performance"] = []
            task_performance = self.metrics[task]["performance"]
            if len(task_performance) == 0:
                task_performance.append([])
            validate_results[task]["examples_seen"] = self.examples_seen()
            task_performance[-1].append(validate_results[task])

    def learning_saturated(self,
                           task,
                           n_samples_slope=500,
                           threshold=0.0002,
                           patience=800,
                           smooth_alpha=0.3,
                           metric="accuracy"):
        """Return True if validation performance of a task is deemed not to increase anymore.

        This is done by checking the slope of the performance curve over some number of samples.
        The default threshold of 0.0002 corresponds to a 0.1% increase per 500 samples.

        Parameters
        ---
        task: str
            Which task to consider.
        patience: int
            Measured in terms of examples seen.

        Returns
        ---
        bool: whether performance on task is saturated.
        """
        n_samples_per_validate = self.mini_batch_size * self.validate_freq
        self.logger.debug(
            "-----------------START SATURATION CHECK----------------")
        self.logger.debug(f"n_samples_per_validate: {n_samples_per_validate}")

        # + 1 because we are looking at intervals
        window_size = n_samples_slope // n_samples_per_validate + 1
        self.logger.debug(f"window_size: {window_size}")

        # extract specific performance metric from last recorded task performance
        performance = [
            performance_metrics[metric]
            for performance_metrics in self.metrics[task]["performance"][-1]
        ]
        self.logger.debug(f"performances: {performance}")
        if len(performance) >= window_size:
            # use moving average to smooth out noise
            moving_average = model_utils.ewma(performance, alpha=smooth_alpha)
            self.logger.debug(f"moving average: {moving_average}")
            # measured in percent points
            slope = 100 * (moving_average[-1] - moving_average[-window_size]) / \
                          ((window_size - 1) * n_samples_per_validate)
            self.logger.debug(
                f"Iteration {self.current_iter} -- Slope {slope} -- Threshold {threshold} -- not_improving: {self.not_improving}"
            )
            if slope < threshold:
                self.not_improving += 1
            else:
                self.not_improving = 0
        self.logger.debug(
            "-----------------END SATURATION CHECK----------------")
        if self.not_improving * n_samples_per_validate >= patience:
            return True
        return False

    def write_relearning_metrics(self, metric="accuracy"):
        """Perform calculations necessary to get relearning metrics and write to self.metrics"""
        self.logger.info(
            "-----------------START RELEARNING EVALUATION----------------")

        performance = [
            performance_metrics[metric] for performance_metrics in
            self.metrics[self.relearning_task]["performance"][-1]
        ]
        moving_average = model_utils.ewma(performance, alpha=self.smooth_alpha)
        if self.first_encounter:
            max_score = max(moving_average)
            self.metrics[self.relearning_task]["relearning"] = []
            self.metrics[self.relearning_task]["max_score"] = max_score
            self.metrics[
                self.relearning_task]["initial_score"] = performance[0]
            self.logger.info(
                f"first encounter max score: {max_score} -- initial score: {performance[0]}"
            )
        else:
            # for logging purposes
            max_score = self.metrics[self.relearning_task]["max_score"]
            initial_score = self.metrics[self.relearning_task]["initial_score"]

        k_zero_log = True  # to avoid duplicate log messages when k_alpha == 0
        relearning_metrics = {}
        for alpha in self.relearning_evaluation_alphas:
            # alpha_max_score = alpha * max_score
            self.logger.info(f"Evaluating with alpha {alpha}")
            # first index that reaches performance higher than alpha * max score
            try:
                i_alpha_save = next(
                    i for i, v in enumerate(moving_average)
                    if self.relative_performance(
                        performance=v, task=self.relearning_task) >= alpha)
            except StopIteration:
                self.logger.info(
                    f"This run didn't reach a relative performance of at least {alpha}, skipping.."
                )
                if k_zero_log:
                    relative_performances = [
                        self.relative_performance(performance=v,
                                                  task=self.relearning_task)
                        for v in moving_average
                    ]
                    summary = list(
                        zip(performance, moving_average,
                            relative_performances))
                    self.logger.info(f"Showing run statistics")
                    self.logger.info(
                        f"First encounter max score: {max_score} -- initial score: {initial_score}"
                    )
                    self.logger.info(
                        f"(Performance, Moving average, Relative performance): {summary}"
                    )
                    k_zero_log = False
                continue
            # number of examples seen in just this encounter, can be calculated as offset using the examples seen
            k_alpha = (self.metrics[self.relearning_task]["performance"][-1]
                       [i_alpha_save]["examples_seen"] -
                       self.metrics[self.relearning_task]["performance"][-1][0]
                       ["examples_seen"])
            if k_alpha == 0:
                # special value to avoid dividing by 0
                self.logger.info("warning: k_alpha equals 0, may skew results")
                if k_zero_log:
                    relative_performances = [
                        self.relative_performance(performance=v,
                                                  task=self.relearning_task)
                        for v in moving_average
                    ]
                    summary = list(
                        zip(performance, moving_average,
                            relative_performances))
                    self.logger.info(f"Showing run statistics")
                    self.logger.info(
                        f"First encounter max score: {max_score} -- initial score: {initial_score}"
                    )
                    self.logger.info(
                        f"(Performance, Moving average, Relative performance): {summary}"
                    )
                    k_zero_log = False
                learning_speed = np.nan
            else:
                # TODO: look at this, now the accuracy could be when relative performance is higher than alpha
                alpha_performance = self.metrics[self.relearning_task][
                    "performance"][-1][i_alpha_save]["accuracy"]
                # scale 1-100
                learning_speed = 100 * (alpha_performance - performance[0]) / \
                                    k_alpha
            relearning_metrics[f"k_alpha_{alpha}"] = k_alpha
            relearning_metrics[
                f"learning_speed_alpha_{alpha}"] = learning_speed
            self.logger.info(
                f"Reached relative performance of {alpha} after k_{alpha} = {k_alpha} examples"
            )
            self.logger.info(f"learning_speed_alpha_{alpha}: {learning_speed}")
            # we have already had the initial task encounter, now we can record relearning speed
            if not self.first_encounter:
                relearning_slope_alpha = learning_speed / self.metrics[
                    self.relearning_task]["relearning"][0][
                        f"learning_speed_alpha_{alpha}"]
                relearning_sample_alpha = (
                    self.metrics[self.relearning_task]["relearning"][0]
                    [f"k_alpha_{alpha}"] / k_alpha if k_alpha != 0 else np.nan)
                relearning_metrics[
                    f"relearning_slope_alpha_{alpha}"] = relearning_slope_alpha
                relearning_metrics[
                    f"relearning_sample_alpha_{alpha}"] = relearning_sample_alpha
                self.logger.info("Relearning metrics:")
                self.logger.info(
                    f"\trelearning_slope_alpha_{alpha}: {relearning_slope_alpha}"
                )
                self.logger.info(
                    f"\trelearning_sample_alpha_{alpha}: {relearning_sample_alpha}\n"
                )

        self.metrics[self.relearning_task]["relearning"].append(
            relearning_metrics)
        if self.config.wandb:
            wandb.log({
                "examples_seen":
                self.examples_seen(),
                f"k_{alpha}":
                k_alpha,
                f"learning_speed_alpha_{alpha}":
                learning_speed,
                f"relearning_slope_alpha_{alpha}":
                relearning_slope_alpha if not self.first_encounter else None,
                f"relearning_sample_alpha_{alpha}":
                relearning_sample_alpha if not self.first_encounter else None
            })
        self.logger.info(
            "-----------------END RELEARNING EVALUATION----------------")

    def examples_seen(self):
        return self._examples_seen

    # def testing(self, datasets, order):
    #     """
    #     Parameters
    #     ---
    #     datasets: List[Dataset]
    #         Test datasets.
    #     order: List[str]
    #         Specifies order of encountered datasets
    #     """
    #     accuracies, precisions, recalls, f1s = [], [], [], []
    #     results = {}
    #     # only have one dataset if type is single
    #     if self.type == "single":
    #         train_dataset = datasets[order.index(self.config.learner.dataset)]
    #         datasets = [train_dataset]
    #     for dataset in datasets:
    #         dataset_name = dataset.__class__.__name__
    #         self.logger.info("Testing on {}".format(dataset_name))
    #         test_dataloader = data.DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)
    #         dataset_results = self.evaluate(dataloader=test_dataloader)
    #         accuracies.append(dataset_results["accuracy"])
    #         precisions.append(dataset_results["precision"])
    #         recalls.append(dataset_results["recall"])
    #         f1s.append(dataset_results["f1"])
    #         results[dataset_name] = dataset_results

    #     mean_results = {
    #         "accuracy": np.mean(accuracies),
    #         "precision": np.mean(precisions),
    #         "recall": np.mean(recalls),
    #         "f1": np.mean(f1s)
    #     }
    #     self.logger.info("Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
    #                 "F1 score = {:.4f}".format(
    #                     mean_results["accuracy"], mean_results["precision"], mean_results["recall"],
    #                     mean_results["f1"]
    #                 ))
    #     return results, mean_results

    def evaluate(self, dataloader, **kwargs):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for i, (text, labels, task) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            if i % 20 == 0:
                self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        acc, prec, rec, f1 = model_utils.calculate_metrics(
            all_predictions, all_labels)
        self.logger.info(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses), acc, prec, rec,
                                       f1))

        return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}
Example n. 6
class AGEM(Learner):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.lr = config.learner.lr
        self.n_epochs = config.training.epochs

        self.model = TransformerClsModel(model_name=config.learner.model_name,
                                         n_classes=config.data.n_classes,
                                         max_length=config.data.max_length,
                                         device=self.device)
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        self.logger.info("Loaded {} as model".format(
            self.model.__class__.__name__))

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)

    def training(self, datasets, **kwargs):
        train_datasets = data.ConcatDataset(datasets["train"])
        dataloaders = {
            "train":
            data.DataLoader(train_datasets,
                            batch_size=self.mini_batch_size,
                            shuffle=False,
                            collate_fn=batch_encode),
        }
        self.train(dataloaders=dataloaders)

    def train(self, dataloaders):
        self.model.train()
        dataloader = dataloaders["train"]
        data_length = len(dataloader) * self.n_epochs

        for epoch in range(self.n_epochs):
            all_losses, all_predictions, all_labels = [], [], []

            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.model.encode_text(text)
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)

                self.update_parameters(loss, mini_batch_size=len(labels))

                loss = loss.item()
                pred = model_utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                self.memory.write_batch(text, labels)

                if self.current_iter % self.log_freq == 0:
                    self.write_log(all_predictions,
                                   all_labels,
                                   all_losses,
                                   data_length=data_length)
                    self.start_time = time.time()  # time from last log
                    all_losses, all_predictions, all_labels = [], [], []
                # save_freq-based condition replaced by time-based checkpointing
                self.time_checkpoint()
                self.current_iter += 1
            self.current_epoch += 1

    def update_parameters(self, loss, mini_batch_size):
        """Update parameters of model"""
        self.optimizer.zero_grad()

        params = [p for p in self.model.parameters() if p.requires_grad]
        orig_grad = torch.autograd.grad(loss, params)

        # Replay schedule: every `replay_freq` updates, sample `replay_steps`
        # reference batches from episodic memory to build the A-GEM reference gradient.
        replay_freq = self.replay_every // mini_batch_size
        replay_steps = int(self.replay_every * self.replay_rate /
                           mini_batch_size)

        if self.replay_rate != 0 and (self.current_iter +
                                      1) % replay_freq == 0:
            ref_grad_sum = None
            for _ in range(replay_steps):
                ref_text, ref_labels = self.memory.read_batch(
                    batch_size=mini_batch_size)
                ref_labels = torch.tensor(ref_labels).to(self.device)
                ref_input_dict = self.model.encode_text(ref_text)
                ref_output = self.model(ref_input_dict)
                ref_loss = self.loss_fn(ref_output, ref_labels)
                ref_grad = torch.autograd.grad(ref_loss, params)
                if ref_grad_sum is None:
                    ref_grad_sum = ref_grad
                else:
                    ref_grad_sum = [
                        x + y for (x, y) in zip(ref_grad, ref_grad_sum)
                    ]
            final_grad = self.compute_grad(orig_grad, ref_grad_sum)
        else:
            final_grad = orig_grad

        # Write the (possibly projected) gradient back onto the parameters and step.
        for param, grad in zip(params, final_grad):
            param.grad = grad.data
        self.optimizer.step()

    def write_log(self, all_predictions, all_labels, all_losses, data_length):
        acc, prec, rec, f1 = model_utils.calculate_metrics(
            all_predictions, all_labels)
        time_per_iteration, estimated_time_left = self.time_metrics(
            data_length)
        self.logger.info(
            "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\nMetrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(self.current_iter + 1, data_length,
                                       (self.current_iter + 1) / data_length *
                                       100,
                                       time_per_iteration, estimated_time_left,
                                       np.mean(all_losses), acc, prec, rec,
                                       f1))
        if self.config.wandb:
            n_examples_seen = (self.current_iter + 1) * self.mini_batch_size
            wandb.log({
                "accuracy": acc,
                "precision": prec,
                "recall": rec,
                "f1": f1,
                "loss": np.mean(all_losses),
                "examples_seen": n_examples_seen
            })

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.set_eval()

        for i, (text, labels) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            if i % 20 == 0:
                self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        acc, prec, rec, f1 = model_utils.calculate_metrics(
            all_predictions, all_labels)
        self.logger.info(
            "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
            "F1 score = {:.4f}".format(np.mean(all_losses), acc, prec, rec,
                                       f1))

        return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}

    def compute_grad(self, orig_grad, ref_grad):
        """Computes gradient according to the AGEM method"""
        with torch.no_grad():
            flat_orig_grad = torch.cat([torch.flatten(x) for x in orig_grad])
            flat_ref_grad = torch.cat([torch.flatten(x) for x in ref_grad])
            dot_product = torch.dot(flat_orig_grad, flat_ref_grad)
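            # A non-negative dot product means the proposed update does not increase
            # the loss on the memory samples, so the original gradient is kept.
            # Otherwise, the component along the reference gradient is projected out.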
            if dot_product >= 0:
                return orig_grad
            proj_component = dot_product / torch.dot(flat_ref_grad,
                                                     flat_ref_grad)
            modified_grad = [
                o - proj_component * r for (o, r) in zip(orig_grad, ref_grad)
            ]
            return modified_grad
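
As a quick sanity check of the projection used in compute_grad, the standalone sketch below applies the same rule, g_tilde = g - (g . g_ref) / (g_ref . g_ref) * g_ref, to toy tensors. The function and variable names are hypothetical and the snippet is independent of the classes above.

import torch


def agem_project(orig_grad, ref_grad):
    """Hypothetical standalone version of the projection in compute_grad."""
    flat_g = torch.cat([g.flatten() for g in orig_grad])
    flat_ref = torch.cat([g.flatten() for g in ref_grad])
    dot = torch.dot(flat_g, flat_ref)
    if dot >= 0:  # no conflict with the memory gradient: keep the original gradient
        return orig_grad
    scale = dot / torch.dot(flat_ref, flat_ref)
    return [g - scale * r for g, r in zip(orig_grad, ref_grad)]


# Toy example: the two gradients conflict (negative dot product), so the
# projected gradient ends up orthogonal to the reference gradient.
g = [torch.tensor([1.0, 0.0])]
g_ref = [torch.tensor([-1.0, 1.0])]
g_proj = agem_project(g, g_ref)
print(g_proj)                                  # [tensor([0.5000, 0.5000])]
print(torch.dot(g_proj[0], g_ref[0]).item())   # 0.0 (up to float error)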