Example #1
    def training(self, datasets, **kwargs):
        train_datasets = datasets_dict(datasets["train"], datasets["order"])
        val_datasets = datasets_dict(datasets["val"], datasets["order"])
        self.relearning_task_dataset = {
            self.relearning_task: val_datasets[self.relearning_task]
        }

        self.dataloaders = {
            self.relearning_task:
            data.DataLoader(train_datasets[self.relearning_task],
                            batch_size=self.mini_batch_size,
                            shuffle=True),
            # for now, pile all other tasks on one stack
            OTHER_TASKS:
            data.DataLoader(data.ConcatDataset([
                dataset for task, dataset in train_datasets.items()
                if task != self.relearning_task
            ]),
                            batch_size=self.mini_batch_size,
                            shuffle=True)
        }
        self.metrics[self.relearning_task]["performance"].append([])
        # write performance of initial encounter (before training) to metrics
        self.metrics[self.relearning_task]["performance"][0].append(
            self.validate(self.relearning_task_dataset,
                          log=False,
                          n_samples=self.config.training.n_validation_samples)[
                              self.relearning_task])
        self.metrics[
            self.relearning_task]["performance"][0][0]["examples_seen"] = 0
        # first encounter relearning task
        self.train(dataloader=self.dataloaders[self.relearning_task],
                   datasets=datasets)
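For orientation, the `performance` metric assembled here is a list of lists: each inner list collects the validation results of one consecutive learning event of the relearning task, starting with the entry recorded before training. A sketch of the expected shape (numbers are placeholders; entries contain at least `accuracy` and `examples_seen`):

# Shape sketch only; the accuracy values are illustrative, not real results.
metrics = {
    "<relearning task>": {
        "performance": [
            [   # first encounter: the zero-shot entry written above, then
                # validation results appended while training on the task
                {"accuracy": 0.25, "examples_seen": 0},
                {"accuracy": 0.60, "examples_seen": 640},
            ],
            # re-encounters append further inner lists (see example #6)
        ],
    },
}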
Example #2
    def prepare_data(self, datasets):
        """Deal with making data ready for consumption.
        
        Parameters
        ---
        datasets: Dict[str, List of dataset names]

        Returns:
            tuple:

        """
        # train_datasets = {dataset_name: dataset for dataset_name, dataset in zip(datasets["order"], datasets["train"])}
        train_datasets = datasets_dict(datasets["train"], datasets["order"])
        val_datasets = datasets_dict(datasets["test"], datasets["order"])
        eval_dataset = val_datasets[self.config.testing.eval_dataset]
        # split into training and testing portions; assumes there is no meaningful difference in dataset order
        eval_train_dataset = eval_dataset.new(0, self.config.testing.n_samples)
        eval_eval_dataset = eval_dataset.new(self.config.testing.n_samples, -1)
        # sample a subset so validation doesn't take too long
        eval_eval_dataset = eval_eval_dataset.sample(min(self.config.testing.few_shot_validation_size, len(eval_dataset)))

        if self.config.data.alternating_order:
            order, n_samples = alternating_order(train_datasets, tasks=self.config.data.alternating_tasks,
                                                 n_samples_per_switch=self.config.data.alternating_n_samples_per_switch,
                                                 relative_frequencies=self.config.data.alternating_relative_frequencies)
        else:
            n_samples, order = n_samples_order(self.config.learner.samples_per_task, self.config.task_order, datasets["order"])
        datas = get_continuum(train_datasets, order=order, n_samples=n_samples,
                             eval_dataset=self.config.testing.eval_dataset, merge=False)
        # for logging extra things
        self.extra_dataloader = iter(DataLoader(ConcatDataset(train_datasets.values()), batch_size=self.mini_batch_size, shuffle=True))
        return datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset
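`datasets_dict` itself is not shown; the commented-out dict comprehension at the top of this method suggests it simply pairs dataset names with datasets. A minimal sketch under that assumption:

def datasets_dict(datasets, order):
    """Map each dataset name in `order` to the dataset at the same position."""
    # Equivalent to the commented-out dict comprehension above.
    return {name: dataset for name, dataset in zip(order, datasets)}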
Example #3
    def testing(self, datasets, order):
        """
        Evaluate the learner after training.

        Parameters
        ---
        datasets: Dict[str, List[Dataset]]
            Test datasets.
        order: List[str]
            Specifies the order in which the datasets were encountered.
        """
        self.logger.info("Testing..")
        train_datasets = datasets_dict(datasets["train"], order)
        for split in ("test", "val"):
            self.logger.info(f"Validating on split {split}")
            eval_datasets = datasets[split]
            eval_datasets = datasets_dict(eval_datasets, order)
            self.set_eval()

            if self.config.testing.average_accuracy:
                self.logger.info("Getting average accuracy")
                self.average_accuracy(eval_datasets, split=split, train_datasets=train_datasets)

            if self.config.testing.few_shot:
                # split into training and testing portions; assumes there is no meaningful difference in dataset order
                dataset = eval_datasets[self.config.testing.eval_dataset]
                train_dataset = dataset.new(0, self.config.testing.n_samples)
                eval_dataset = dataset.new(self.config.testing.n_samples, -1)
                # sample a subset so validation doesn't take too long
                eval_dataset = eval_dataset.sample(min(self.config.testing.few_shot_validation_size, len(eval_dataset)))
                self.logger.info(f"Few shot eval dataset size: {len(eval_dataset)}")
                self.few_shot_testing(train_dataset=train_dataset, eval_dataset=eval_dataset, split=split)
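The `new` and `sample` methods used for the few-shot split are not defined here. A minimal sketch of a list-backed dataset that would support this pattern, assuming `new(start, end)` returns a copy restricted to that index range (with `end=-1` read as "to the end") and `sample(k)` draws `k` examples at random:

import random

from torch.utils.data import Dataset


class SliceableDataset(Dataset):
    """Hypothetical stand-in for the project's dataset class."""

    def __init__(self, examples):
        self.examples = list(examples)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def new(self, start, end):
        # assumption: end == -1 means "until the end of the dataset"
        end = len(self.examples) if end == -1 else end
        return SliceableDataset(self.examples[start:end])

    def sample(self, k):
        return SliceableDataset(random.sample(self.examples, k))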
Example #4
    def training(self, datasets, **kwargs):
        # train_datasets = {dataset_name: dataset for dataset_name, dataset in zip(datasets["order"], datasets["train"])}
        train_datasets = datasets_dict(datasets["train"], datasets["order"])
        val_datasets = datasets_dict(datasets["val"], datasets["order"])

        samples_per_task = self.config.learner.samples_per_task
        order = self.config.task_order if self.config.task_order is not None else datasets["order"]
        n_samples = [samples_per_task] * len(order) if samples_per_task is None else samples_per_task
        dataset = get_continuum(train_datasets, order=order, n_samples=n_samples)
        dataloader = DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)

        for text, labels, tasks in dataloader:
            output = self.training_step(text, labels)
            predictions = model_utils.make_prediction(output["logits"].detach())
            # for logging
            key_predictions = [
                model_utils.make_prediction(key_logits.detach()) for key_logits in output["key_logits"]
            ]
            # self.logger.debug(f"accuracy prediction from key embedding: {key_metrics['accuracy']}")

            self.update_tracker(output, predictions, key_predictions, labels)
            online_metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            self.metrics["online"].append({
                "accuracy": online_metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": datasets[0]  # assumes whole batch is from same task
            })
            if self.current_iter % self.log_freq == 0:
                self.log()
                self.write_metrics()
            if self.current_iter % self.validate_freq == 0:
                self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
            self.current_iter += 1
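`get_continuum` is assumed here to build one long dataset by taking up to `n_samples[i]` examples of each task in the given order (with `None` meaning "use the full task"); the real helper also accepts `eval_dataset` and `merge` arguments (see example #2) and yields the task name with every item. A simplified sketch under those assumptions:

from torch.utils.data import ConcatDataset, Subset


def get_continuum(train_datasets, order, n_samples):
    # Simplified: slice off the first n examples of each task and chain the
    # slices in the specified task order.
    parts = []
    for task, n in zip(order, n_samples):
        dataset = train_datasets[task]
        n = len(dataset) if n is None else min(n, len(dataset))
        parts.append(Subset(dataset, range(n)))
    return ConcatDataset(parts)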
Example #5
    def train(self,
              dataloader=None,
              datasets=None,
              dataset_name=None,
              max_samples=None):
        val_datasets = datasets_dict(datasets["val"], datasets["order"])
        replay_freq, replay_steps = self.replay_parameters(metalearner=False)

        episode_samples_seen = 0  # have to keep track of per-task samples seen as we might use replay as well
        for _ in range(self.n_epochs):
            for text, labels, tasks in dataloader:
                output = self.training_step(text, labels)
                task = tasks[0]

                predictions = model_utils.make_prediction(
                    output["logits"].detach())
                self.update_tracker(output, predictions, labels)

                metrics = model_utils.calculate_metrics(
                    self.tracker["predictions"], self.tracker["labels"])
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task
                }
                self.metrics["online"].append(online_metrics)
                if dataset_name is not None and dataset_name == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(
                        online_metrics)
                if self.current_iter % self.log_freq == 0:
                    self.log()
                    self.write_metrics()
                if self.current_iter % self.validate_freq == 0:
                    self.validate(
                        val_datasets,
                        n_samples=self.config.training.n_validation_samples)
                if self.replay_rate != 0 and (self.current_iter +
                                              1) % replay_freq == 0:
                    self.replay_training_step(replay_steps,
                                              episode_samples_seen,
                                              max_samples)
                self.memory.write_batch(text, labels)
                self._examples_seen += len(text)
                episode_samples_seen += len(text)
                self.current_iter += 1
                if max_samples is not None and episode_samples_seen >= max_samples:
                    break
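`self.memory` only appears here through `write_batch`. A sketch of a reservoir-sampling replay buffer that would fit this interface; the `read_batch` name and the capacity default are assumptions, not taken from the source:

import random


class EpisodicMemory:
    """Hypothetical replay buffer matching the write_batch call above."""

    def __init__(self, capacity=100_000):
        self.capacity = capacity
        self.buffer = []
        self.seen = 0

    def write_batch(self, text, labels):
        labels = labels.tolist() if hasattr(labels, "tolist") else list(labels)
        # Reservoir sampling keeps a uniform sample over everything seen so far.
        for x, y in zip(text, labels):
            self.seen += 1
            if len(self.buffer) < self.capacity:
                self.buffer.append((x, y))
            else:
                j = random.randrange(self.seen)
                if j < self.capacity:
                    self.buffer[j] = (x, y)

    def read_batch(self, batch_size):
        if not self.buffer:
            return [], []
        batch = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        text, labels = zip(*batch)
        return list(text), list(labels)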
Example #6
    def train(self, dataloader=None, datasets=None, data_length=None):
        val_datasets = datasets_dict(datasets["val"], datasets["order"])

        if data_length is None:
            data_length = len(dataloader) * self.n_epochs

        all_losses, all_predictions, all_labels = [], [], []

        for text, labels, tasks in dataloader:
            self._examples_seen += len(text)
            self.model.train()
            # assumes all data in batch is from same task
            self.current_task = self.relearning_task if tasks[
                0] == self.relearning_task else OTHER_TASKS
            loss, predictions = self._train_batch(text, labels)
            all_losses.append(loss)
            all_predictions.extend(predictions)
            all_labels.extend(labels.tolist())

            if self.current_iter % self.log_freq == 0:
                acc, prec, rec, f1 = model_utils.calculate_metrics(
                    all_predictions, all_labels)
                time_per_iteration, estimated_time_left = self.time_metrics(
                    data_length)
                self.logger.info(
                    "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\nMetrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(
                        self.current_iter + 1, data_length,
                        (self.current_iter + 1) / data_length * 100,
                        time_per_iteration, estimated_time_left,
                        np.mean(all_losses), acc, prec, rec, f1))
                if self.config.wandb:
                    wandb.log({
                        "accuracy": acc,
                        "precision": prec,
                        "recall": rec,
                        "f1": f1,
                        "loss": np.mean(all_losses),
                        "examples_seen": self.examples_seen(),
                        "task": self.current_task
                    })
                all_losses, all_predictions, all_labels = [], [], []
                self.start_time = time.time()
            if self.current_iter % self.validate_freq == 0:
                # only evaluate relearning task when training on relearning task
                validation_datasets = self.relearning_task_dataset if self.current_task == self.relearning_task else val_datasets

                validate_results = self.validate(
                    validation_datasets,
                    n_samples=self.config.training.n_validation_samples,
                    log=False)
                self.write_results(validate_results)
                relearning_task_performance = validate_results[
                    self.relearning_task]["accuracy"]
                if not self.first_encounter:
                    # TODO: make this a weighted average as well
                    relearning_task_relative_performance = self.relative_performance(
                        performance=relearning_task_performance,
                        task=self.relearning_task)
                    self.logger.info(
                        f"Examples seen: {self.examples_seen()} -- Relative performance of task '{self.relearning_task}': "
                        f"{relearning_task_relative_performance}. Thresholds: {self.relative_performance_threshold_lower}"
                        f"-{self.relative_performance_threshold_upper}")
                    if self.config.wandb:
                        wandb.log({
                            "relative_performance":
                            relearning_task_relative_performance,
                            "examples_seen": self.examples_seen()
                        })
                if self.current_task == self.relearning_task:
                    self.logger.debug(
                        f"first encounter: {self.first_encounter}")
                    # relearning stops when one of two things happens:
                    # the relearning task is first encountered and it is saturated (doesn't improve)
                    if ((self.first_encounter and self.learning_saturated(
                            task=self.relearning_task,
                            n_samples_slope=self.n_samples_slope,
                            patience=self.saturated_patience,
                            threshold=self.saturated_threshold,
                            smooth_alpha=self.smooth_alpha)) or
                        (
                            # the relearning task is re-encountered and relative performance reaches some threshold
                            not self.first_encounter
                            and relearning_task_relative_performance >=
                            self.relative_performance_threshold_upper)):
                        # write metrics, reset, and train the other tasks
                        self.write_relearning_metrics()
                        self.logger.info(
                            f"Task {self.current_task} saturated at iteration {self.current_iter}"
                        )
                        # each list element in performance refers to one consecutive learning event of the relearning task
                        self.metrics[
                            self.relearning_task]["performance"].append([])
                        self.not_improving = 0
                        if self.first_encounter:
                            self.logger.info(
                                f"-----------FIRST ENCOUNTER RELEARNING TASK '{self.relearning_task}' FINISHED.----------\n"
                            )
                        self.first_encounter = False
                        self.logger.info("TRAINING ON OTHER TASKS")
                        self.train(dataloader=self.dataloaders[OTHER_TASKS],
                                   datasets=datasets)
                else:
                    # calculate the relative performance of the relearning task;
                    # if it drops to the lower threshold, train on the relearning task again
                    # TODO: make performance measure attribute of relearner
                    # TODO: use moving average for relative performance check
                    # TODO: allow different task ordering
                    # TODO: measure forgetting
                    # TODO: look at adaptive mini batch size => smaller batch size when re encountering
                    if relearning_task_relative_performance <= self.relative_performance_threshold_lower:
                        self.logger.info(
                            f"Relative performance on relearning task {self.relearning_task} below threshold. Evaluating relearning.."
                        )
                        # We need a fresh list of performances when we start training the
                        # relearning task again. The first item in this new list is simply the
                        # zero-shot performance recorded once the relative performance threshold
                        # is reached, which means every odd list (first, third, ...) in the
                        # relearning_task metrics covers a period of training on the relearning task.
                        relearning_task_performance = self.metrics[
                            self.relearning_task]["performance"]
                        relearning_task_performance.append([])
                        # copy the last entry of the performance while training on the other tasks to the new list
                        relearning_task_performance[-1].append(
                            relearning_task_performance[-2][-1])
                        self.train(
                            dataloader=self.dataloaders[self.relearning_task],
                            datasets=datasets)
                with open(self.results_dir / METRICS_FILE, "w") as f:
                    json.dump(self.metrics, f)

            self.time_checkpoint()
            self.current_iter += 1
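The `learning_saturated` check is not shown in these examples. Given its arguments and the `self.not_improving` counter that is reset above, one plausible reading smooths the accuracy curve of the current learning event with an exponential moving average, estimates its slope over roughly the last `n_samples_slope` examples, and reports saturation once that slope has stayed below `threshold` for `patience` consecutive checks. A sketch under those assumptions (every detail is hypothetical):

    def learning_saturated(self, task, n_samples_slope, patience, threshold,
                           smooth_alpha):
        """Hypothetical sketch of the saturation check used above."""
        history = self.metrics[task]["performance"][-1]
        if len(history) < 2:
            return False
        # exponentially smooth the accuracy curve of the current learning event
        smoothed, ema = [], history[0]["accuracy"]
        for entry in history:
            ema = smooth_alpha * entry["accuracy"] + (1 - smooth_alpha) * ema
            smoothed.append((entry["examples_seen"], ema))
        # restrict to the window covering the last n_samples_slope examples
        last_seen = smoothed[-1][0]
        window = [(x, y) for x, y in smoothed if last_seen - x <= n_samples_slope]
        if len(window) < 2:
            return False
        slope = (window[-1][1] - window[0][1]) / max(window[-1][0] - window[0][0], 1)
        if slope < threshold:
            self.not_improving += 1
        else:
            self.not_improving = 0
        return self.not_improving >= patience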