def training(self, datasets, **kwargs):
    train_datasets = datasets_dict(datasets["train"], datasets["order"])
    val_datasets = datasets_dict(datasets["val"], datasets["order"])
    self.relearning_task_dataset = {
        self.relearning_task: val_datasets[self.relearning_task]
    }
    self.dataloaders = {
        self.relearning_task: data.DataLoader(train_datasets[self.relearning_task],
                                              batch_size=self.mini_batch_size,
                                              shuffle=True),
        # for now, pile all other tasks on one stack
        OTHER_TASKS: data.DataLoader(data.ConcatDataset([
            dataset for task, dataset in train_datasets.items()
            if task != self.relearning_task
        ]), batch_size=self.mini_batch_size, shuffle=True)
    }
    self.metrics[self.relearning_task]["performance"].append([])
    # write performance of the initial encounter (before training) to metrics
    self.metrics[self.relearning_task]["performance"][0].append(
        self.validate(self.relearning_task_dataset, log=False,
                      n_samples=self.config.training.n_validation_samples)[self.relearning_task])
    self.metrics[self.relearning_task]["performance"][0][0]["examples_seen"] = 0
    # first encounter of the relearning task
    self.train(dataloader=self.dataloaders[self.relearning_task], datasets=datasets)
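# `datasets_dict` is imported from elsewhere in the repository. A minimal sketch of
# what it is assumed to do, matching the inline dict comprehension the original file
# carried in comments next to its call sites (hypothetical reconstruction, not the
# canonical implementation):
def datasets_dict(datasets, order):
    """Map each task name in `order` to its corresponding dataset."""
    return {dataset_name: dataset for dataset_name, dataset in zip(order, datasets)}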
def prepare_data(self, datasets):
    """Make the data ready for consumption.

    Parameters
    ---
    datasets: Dict[str, List]
        Maps split names ("train", "test", ...) to lists of datasets; "order" holds
        the dataset names.

    Returns
    ---
    tuple:
        (datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset)
    """
    train_datasets = datasets_dict(datasets["train"], datasets["order"])
    val_datasets = datasets_dict(datasets["test"], datasets["order"])

    eval_dataset = val_datasets[self.config.testing.eval_dataset]
    # split into training and testing points; assumes there is no meaningful difference in dataset order
    eval_train_dataset = eval_dataset.new(0, self.config.testing.n_samples)
    eval_eval_dataset = eval_dataset.new(self.config.testing.n_samples, -1)
    # sample a subset so validation doesn't take too long
    eval_eval_dataset = eval_eval_dataset.sample(min(self.config.testing.few_shot_validation_size,
                                                     len(eval_eval_dataset)))

    if self.config.data.alternating_order:
        order, n_samples = alternating_order(train_datasets,
                                             tasks=self.config.data.alternating_tasks,
                                             n_samples_per_switch=self.config.data.alternating_n_samples_per_switch,
                                             relative_frequencies=self.config.data.alternating_relative_frequencies)
    else:
        n_samples, order = n_samples_order(self.config.learner.samples_per_task,
                                           self.config.task_order, datasets["order"])
    datas = get_continuum(train_datasets, order=order, n_samples=n_samples,
                          eval_dataset=self.config.testing.eval_dataset, merge=False)
    # for logging extra things
    self.extra_dataloader = iter(DataLoader(ConcatDataset(train_datasets.values()),
                                            batch_size=self.mini_batch_size, shuffle=True))
    return datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset
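# `n_samples_order` is also imported. A sketch, assuming it mirrors the inline logic
# of the training() method further below: resolve the task order, then broadcast a
# scalar per-task sample count into a list (hypothetical reconstruction):
def n_samples_order(samples_per_task, task_order, default_order):
    """Resolve the task order and per-task sample counts from config values."""
    order = task_order if task_order is not None else default_order
    # a scalar count applies to every task; a list already gives one count per task
    n_samples = [samples_per_task] * len(order) if isinstance(samples_per_task, int) else samples_per_task
    return n_samples, order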
def testing(self, datasets, order):
    """Evaluate the learner after training.

    Parameters
    ---
    datasets: Dict[str, List[Dataset]]
        Test datasets.
    order: List[str]
        Specifies the order of the encountered datasets.
    """
    self.logger.info("Testing..")
    train_datasets = datasets_dict(datasets["train"], order)
    for split in ("test", "val"):
        self.logger.info(f"Validating on split {split}")
        eval_datasets = datasets_dict(datasets[split], order)
        self.set_eval()
        if self.config.testing.average_accuracy:
            self.logger.info("Getting average accuracy")
            self.average_accuracy(eval_datasets, split=split, train_datasets=train_datasets)
        if self.config.testing.few_shot:
            # split into training and testing points; assumes there is no meaningful difference in dataset order
            dataset = eval_datasets[self.config.testing.eval_dataset]
            train_dataset = dataset.new(0, self.config.testing.n_samples)
            eval_dataset = dataset.new(self.config.testing.n_samples, -1)
            # sample a subset so validation doesn't take too long
            eval_dataset = eval_dataset.sample(min(self.config.testing.few_shot_validation_size,
                                                   len(eval_dataset)))
            self.logger.info(f"Few-shot eval dataset size: {len(eval_dataset)}")
            self.few_shot_testing(train_dataset=train_dataset, eval_dataset=eval_dataset, split=split)
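# Both prepare_data() and testing() rely on dataset helpers `.new(start, end)` and
# `.sample(k)` that are not defined in this file. A minimal sketch of the assumed
# semantics (a slice view and a random subset), written as a standalone wrapper;
# the class name and `end == -1` convention are assumptions:
import random
from torch.utils.data import Dataset

class SliceableDataset(Dataset):
    """Hypothetical wrapper illustrating the assumed `.new`/`.sample` semantics."""

    def __init__(self, examples):
        self.examples = list(examples)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def new(self, start, end):
        # `end == -1` is assumed to mean "until the end of the dataset"
        end = len(self.examples) if end == -1 else end
        return SliceableDataset(self.examples[start:end])

    def sample(self, k):
        # uniform random subset of size k (callers cap k at len(dataset))
        return SliceableDataset(random.sample(self.examples, k))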
def training(self, datasets, **kwargs):
    train_datasets = datasets_dict(datasets["train"], datasets["order"])
    val_datasets = datasets_dict(datasets["val"], datasets["order"])

    samples_per_task = self.config.learner.samples_per_task
    order = self.config.task_order if self.config.task_order is not None else datasets["order"]
    # broadcast a scalar sample count to every task; a per-task list is used as-is
    n_samples = [samples_per_task] * len(order) if isinstance(samples_per_task, int) else samples_per_task
    dataset = get_continuum(train_datasets, order=order, n_samples=n_samples)
    dataloader = DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)

    for text, labels, tasks in dataloader:
        output = self.training_step(text, labels)
        predictions = model_utils.make_prediction(output["logits"].detach())
        # for logging
        key_predictions = [
            model_utils.make_prediction(key_logits.detach())
            for key_logits in output["key_logits"]
        ]
        self.update_tracker(output, predictions, key_predictions, labels)
        online_metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
        self.metrics["online"].append({
            "accuracy": online_metrics["accuracy"],
            "examples_seen": self.examples_seen(),
            "task": tasks[0]  # assumes the whole batch is from the same task
        })
        if self.current_iter % self.log_freq == 0:
            self.log()
            self.write_metrics()
        if self.current_iter % self.validate_freq == 0:
            self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
        self.current_iter += 1
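# `model_utils.make_prediction` and `model_utils.calculate_metrics` are imported
# helpers. Sketches under the assumption that predictions are the argmax over the
# logits and that the metrics come from scikit-learn; names and signatures are
# inferred from the call sites, not from the actual module. Note that one call site
# below unpacks a tuple (acc, prec, rec, f1), so the real helper may have a second
# variant; this sketch returns a dict matching the usage above:
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def make_prediction(logits):
    """Class predictions as the argmax over the class dimension."""
    return torch.argmax(logits, dim=-1)

def calculate_metrics(predictions, labels):
    """Accuracy/precision/recall/F1 for lists of predictions and gold labels."""
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions,
                                                               average="macro",
                                                               zero_division=0)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }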
def train(self, dataloader=None, datasets=None, dataset_name=None, max_samples=None):
    val_datasets = datasets_dict(datasets["val"], datasets["order"])
    replay_freq, replay_steps = self.replay_parameters(metalearner=False)
    # have to keep track of per-task samples seen, as we might use replay as well
    episode_samples_seen = 0
    for _ in range(self.n_epochs):
        for text, labels, tasks in dataloader:
            output = self.training_step(text, labels)
            task = tasks[0]
            predictions = model_utils.make_prediction(output["logits"].detach())
            self.update_tracker(output, predictions, labels)
            metrics = model_utils.calculate_metrics(self.tracker["predictions"], self.tracker["labels"])
            online_metrics = {
                "accuracy": metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": task
            }
            self.metrics["online"].append(online_metrics)
            if dataset_name is not None and dataset_name == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                self.metrics["eval_task_first_encounter"].append(online_metrics)
            if self.current_iter % self.log_freq == 0:
                self.log()
                self.write_metrics()
            if self.current_iter % self.validate_freq == 0:
                self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
            if self.replay_rate != 0 and (self.current_iter + 1) % replay_freq == 0:
                self.replay_training_step(replay_steps, episode_samples_seen, max_samples)
            self.memory.write_batch(text, labels)
            self._examples_seen += len(text)
            episode_samples_seen += len(text)
            self.current_iter += 1
            if max_samples is not None and episode_samples_seen >= max_samples:
                break
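# `self.memory` only shows `write_batch` here; `replay_training_step` presumably
# samples from it. A minimal sketch of a replay buffer with those two operations,
# assuming unbounded in-memory storage and uniform sampling; the `read_batch` name
# is hypothetical (the real buffer may be bounded or use reservoir sampling):
import random

class ReplayMemory:
    def __init__(self):
        self.texts, self.labels = [], []

    def write_batch(self, text, labels):
        """Store a mini-batch of examples for later replay."""
        self.texts.extend(text)
        self.labels.extend(labels.tolist())

    def read_batch(self, batch_size):
        """Sample a uniform random mini-batch of stored examples."""
        indices = random.sample(range(len(self.texts)), min(batch_size, len(self.texts)))
        return [self.texts[i] for i in indices], [self.labels[i] for i in indices]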
def train(self, dataloader=None, datasets=None, data_length=None):
    val_datasets = datasets_dict(datasets["val"], datasets["order"])
    if data_length is None:
        data_length = len(dataloader) * self.n_epochs
    all_losses, all_predictions, all_labels = [], [], []

    for text, labels, tasks in dataloader:
        self._examples_seen += len(text)
        self.model.train()
        # assumes all data in the batch is from the same task
        self.current_task = self.relearning_task if tasks[0] == self.relearning_task else OTHER_TASKS
        loss, predictions = self._train_batch(text, labels)
        all_losses.append(loss)
        all_predictions.extend(predictions)
        all_labels.extend(labels.tolist())

        if self.current_iter % self.log_freq == 0:
            acc, prec, rec, f1 = model_utils.calculate_metrics(all_predictions, all_labels)
            time_per_iteration, estimated_time_left = self.time_metrics(data_length)
            self.logger.info(
                "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\n"
                "Metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                "F1 score = {:.4f}".format(self.current_iter + 1, data_length,
                                           (self.current_iter + 1) / data_length * 100,
                                           time_per_iteration, estimated_time_left,
                                           np.mean(all_losses), acc, prec, rec, f1))
            if self.config.wandb:
                wandb.log({
                    "accuracy": acc,
                    "precision": prec,
                    "recall": rec,
                    "f1": f1,
                    "loss": np.mean(all_losses),
                    "examples_seen": self.examples_seen(),
                    "task": self.current_task
                })
            all_losses, all_predictions, all_labels = [], [], []
            self.start_time = time.time()

        if self.current_iter % self.validate_freq == 0:
            # only evaluate the relearning task while training on the relearning task
            validation_datasets = (self.relearning_task_dataset
                                   if self.current_task == self.relearning_task else val_datasets)
            validate_results = self.validate(validation_datasets,
                                             n_samples=self.config.training.n_validation_samples,
                                             log=False)
            self.write_results(validate_results)
            relearning_task_performance = validate_results[self.relearning_task]["accuracy"]
            if not self.first_encounter:
                # TODO: make this a weighted average as well
                relearning_task_relative_performance = self.relative_performance(
                    performance=relearning_task_performance, task=self.relearning_task)
                self.logger.info(
                    f"Examples seen: {self.examples_seen()} -- "
                    f"Relative performance of task '{self.relearning_task}': "
                    f"{relearning_task_relative_performance}. "
                    f"Thresholds: {self.relative_performance_threshold_lower}"
                    f"-{self.relative_performance_threshold_upper}")
                if self.config.wandb:
                    wandb.log({
                        "relative_performance": relearning_task_relative_performance,
                        "examples_seen": self.examples_seen()
                    })
            if self.current_task == self.relearning_task:
                self.logger.debug(f"first encounter: {self.first_encounter}")
                # relearning stops when either of two things happens:
                # 1) the relearning task is first encountered and it is saturated (stops improving), or
                # 2) the relearning task is re-encountered and its relative performance reaches the upper threshold
                if ((self.first_encounter and
                     self.learning_saturated(task=self.relearning_task,
                                             n_samples_slope=self.n_samples_slope,
                                             patience=self.saturated_patience,
                                             threshold=self.saturated_threshold,
                                             smooth_alpha=self.smooth_alpha)) or
                        (not self.first_encounter and
                         relearning_task_relative_performance >= self.relative_performance_threshold_upper)):
                    # write metrics, reset, and train on the other tasks
                    self.write_relearning_metrics()
                    self.logger.info(f"Task {self.current_task} saturated at iteration {self.current_iter}")
                    # each list element in performance refers to one consecutive learning event of the relearning task
                    self.metrics[self.relearning_task]["performance"].append([])
                    self.not_improving = 0
                    if self.first_encounter:
                        self.logger.info(
                            f"-----------FIRST ENCOUNTER RELEARNING TASK '{self.relearning_task}' FINISHED.----------\n")
                        self.first_encounter = False
                    self.logger.info("TRAINING ON OTHER TASKS")
                    self.train(dataloader=self.dataloaders[OTHER_TASKS], datasets=datasets)
            else:
                # calculate the relative performance of the relearning task;
                # if it drops below the lower threshold, train on the relearning task again
                # TODO: make the performance measure an attribute of the relearner
                # TODO: use a moving average for the relative performance check
                # TODO: allow different task orderings
                # TODO: measure forgetting
                # TODO: look at adaptive mini-batch size => smaller batch size when re-encountering
                if relearning_task_relative_performance <= self.relative_performance_threshold_lower:
                    self.logger.info(
                        f"Relative performance on relearning task {self.relearning_task} below threshold. "
                        "Evaluating relearning..")
                    # This is needed because we want a fresh list of performances when we start training
                    # the relearning task again. The first item in this list is simply the zero-shot
                    # performance after the relative performance threshold is reached. This means that
                    # every odd list in the relearning task metrics is from training on the relearning task.
                    relearning_task_performance = self.metrics[self.relearning_task]["performance"]
                    relearning_task_performance.append([])
                    # copy the last performance entry from training on the other tasks to the new list
                    relearning_task_performance[-1].append(relearning_task_performance[-2][-1])
                    self.train(dataloader=self.dataloaders[self.relearning_task], datasets=datasets)
            # persist metrics and timing after each validation round
            with open(self.results_dir / METRICS_FILE, "w") as f:
                json.dump(self.metrics, f)
            self.time_checkpoint()
        self.current_iter += 1
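# `relative_performance` and `learning_saturated` are methods of this learner that
# are not shown in this file. A sketch of one plausible `relative_performance`,
# assuming it compares the current score against the best score recorded during the
# task's first encounter (hypothetical; the real definition may differ):
def relative_performance(self, performance, task):
    """Current performance as a fraction of the best first-encounter performance."""
    first_encounter = self.metrics[task]["performance"][0]
    best = max(entry["accuracy"] for entry in first_encounter)
    return performance / best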