def evaluate(self, dataloader, update_memory=False):
    self.set_eval()
    all_losses, all_predictions, all_labels = [], [], []

    self.logger.info("Starting evaluation...")
    for i, (text, labels, datasets) in enumerate(dataloader):
        labels = torch.tensor(labels).to(self.device)
        with torch.no_grad():
            logits, key_logits = self.forward(text, labels, update_memory=update_memory)
            loss = self.loss_fn(logits, labels)
        loss = loss.item()
        pred = model_utils.make_prediction(logits.detach())
        all_losses.append(loss)
        all_predictions.extend(pred.tolist())
        all_labels.extend(labels.tolist())
        if i % 20 == 0:
            self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

    results = model_utils.calculate_metrics(all_predictions, all_labels)
    self.logger.info(
        "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"], results["precision"],
                                   results["recall"], results["f1"]))
    return results

def evaluate(self, dataloader):
    self.set_eval()
    all_losses, all_predictions, all_labels = [], [], []

    for i, (text, labels, _) in enumerate(dataloader):
        labels = torch.tensor(labels).to(self.device)
        input_dict = self.model.encode_text(text)
        with torch.no_grad():
            output = self.model(input_dict)
            loss = self.loss_fn(output, labels)
        loss = loss.item()
        pred = model_utils.make_prediction(output.detach())
        all_losses.append(loss)
        all_predictions.extend(pred.tolist())
        all_labels.extend(labels.tolist())
        # if i % 20 == 0:
        #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

    metrics = model_utils.calculate_metrics(all_predictions, all_labels)
    self.logger.debug(
        "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        "F1 score = {:.4f}".format(np.mean(all_losses), metrics["accuracy"], metrics["precision"],
                                   metrics["recall"], metrics["f1"]))
    return {
        "accuracy": metrics["accuracy"],
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "f1": metrics["f1"]
    }

def log_few_shot(self, all_predictions, all_labels, datasets, dataset_results,
                 increment_counters, text, few_shot_batch, split="test"):
    """Log few-shot evaluation results; not specific to any learner."""
    metrics_entry = split + "_evaluation"
    test_results = {
        "examples_seen": few_shot_batch * self.config.testing.few_shot_batch_size,
        "examples_seen_total": self.examples_seen(),
        "accuracy": dataset_results["accuracy"],
        "task": datasets[0]
    }
    if (few_shot_batch * self.config.testing.few_shot_batch_size) % self.mini_batch_size == 0 \
            and few_shot_batch > 0:
        online_metrics = model_utils.calculate_metrics(all_predictions, all_labels)
        train_results = {
            "examples_seen": few_shot_batch * self.config.testing.few_shot_batch_size,
            "examples_seen_total": self.examples_seen(),
            "accuracy": online_metrics["accuracy"],
            "task": datasets[0]  # assume whole batch is from same task
        }
        self.metrics[metrics_entry]["few_shot_training"][-1].append(train_results)
        if increment_counters:
            self.metrics["online"].append({
                "accuracy": online_metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": datasets[0]
            })
    if increment_counters:
        self._examples_seen += len(text)
    self.metrics[metrics_entry]["few_shot"][-1].append(test_results)
    self.write_metrics()
    if self.config.wandb:
        # replace accuracy key with the split- and counter-specific name
        test_results = test_results.copy()
        test_results[f"few_shot_{split}_accuracy_{self.few_shot_counter}"] = test_results.pop("accuracy")
        wandb.log(test_results)

def log(self):
    """Log intermediate training results to the console and, optionally, to wandb."""
    metrics = model_utils.calculate_metrics(self.tracker["predictions"], self.tracker["labels"])
    self.logger.info(
        "Iteration {} - Metrics: Loss = {:.4f}, key loss = {:.4f}, accuracy = {:.4f}, "
        "precision = {:.4f}, recall = {:.4f}, F1 score = {:.4f}".format(
            self.current_iter + 1, np.mean(self.tracker["losses"]),
            np.mean(self.tracker["key_losses"]), metrics["accuracy"], metrics["precision"],
            metrics["recall"], metrics["f1"]))
    if self.config.wandb:
        wandb.log({
            "accuracy": metrics["accuracy"],
            "precision": metrics["precision"],
            "recall": metrics["recall"],
            "f1": metrics["f1"],
            "loss": np.mean(self.tracker["losses"]),
            "key_loss": np.mean(self.tracker["key_losses"]),
            "examples_seen": self.examples_seen()
        })
    self.reset_tracker()

def training(self, datasets, **kwargs):
    # train_datasets = {dataset_name: dataset for dataset_name, dataset in zip(datasets["order"], datasets["train"])}
    train_datasets = datasets_dict(datasets["train"], datasets["order"])
    val_datasets = datasets_dict(datasets["val"], datasets["order"])

    samples_per_task = self.config.learner.samples_per_task
    order = self.config.task_order if self.config.task_order is not None else datasets["order"]
    n_samples = [samples_per_task] * len(order) if samples_per_task is None else samples_per_task
    dataset = get_continuum(train_datasets, order=order, n_samples=n_samples)
    dataloader = DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)

    for text, labels, datasets in dataloader:
        output = self.training_step(text, labels)
        predictions = model_utils.make_prediction(output["logits"].detach())  # for logging
        key_predictions = [
            model_utils.make_prediction(key_logits.detach())
            for key_logits in output["key_logits"]
        ]
        # self.logger.debug(f"accuracy prediction from key embedding: {key_metrics['accuracy']}")
        self.update_tracker(output, predictions, key_predictions, labels)

        online_metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
        self.metrics["online"].append({
            "accuracy": online_metrics["accuracy"],
            "examples_seen": self.examples_seen(),
            "task": datasets[0]  # assumes whole batch is from same task
        })
        if self.current_iter % self.log_freq == 0:
            self.log()
            self.write_metrics()
        if self.current_iter % self.validate_freq == 0:
            self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
        self.current_iter += 1

def evaluate(self, dataloader, prediction_network=None):
    self.set_eval()

    # use the class prototypes stored in memory as a linear classifier
    prototypes = self.memory.class_representations
    weight = 2 * prototypes
    bias = -(prototypes ** 2).sum(dim=1)

    all_losses, all_predictions, all_labels = [], [], []
    for i, (text, labels, _) in enumerate(dataloader):
        labels = torch.tensor(labels).to(self.device)
        representations = self.forward(text, labels)["representation"]
        logits = representations @ weight.T + bias
        loss = self.loss_fn(logits, labels)
        loss = loss.item()
        pred = model_utils.make_prediction(logits.detach())
        all_losses.append(loss)
        all_predictions.extend(pred.tolist())
        all_labels.extend(labels.tolist())

    results = model_utils.calculate_metrics(all_predictions, all_labels)
    self.logger.debug(
        "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"], results["precision"],
                                   results["recall"], results["f1"]))
    return results

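# Note on the prototype-as-linear-layer trick used above (an explanatory sketch, not part
# of the original code): for a prototype p_c and representation x, the squared Euclidean
# distance expands as ||x - p_c||^2 = ||x||^2 - 2 x.p_c + ||p_c||^2. The ||x||^2 term is
# the same for every class, so picking the nearest prototype is equivalent to picking the
# largest (2 p_c).x - ||p_c||^2, i.e. a linear layer with weight = 2 * prototypes and
# bias = -(prototypes ** 2).sum(dim=1). A minimal, self-contained check with toy tensors
# (hypothetical names, assumes only PyTorch):
#
#     import torch
#     x = torch.randn(4, 16)            # batch of representations
#     protos = torch.randn(10, 16)      # one prototype per class
#     sq_dists = torch.cdist(x, protos) ** 2
#     logits = x @ (2 * protos).T - (protos ** 2).sum(dim=1)
#     assert torch.equal(sq_dists.argmin(dim=1), logits.argmax(dim=1))
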
def meta_training_log(self):
    """Log data during training for meta learners."""
    if len(self.tracker["support_loss"]) > 0:
        support_loss = np.mean(self.tracker["support_loss"])
    else:
        support_loss = np.nan
    query_loss = np.mean(self.tracker["query_loss"])

    if len(self.tracker["support_predictions"]) > 0:
        support_metrics = model_utils.calculate_metrics(self.tracker["support_predictions"],
                                                        self.tracker["support_labels"])
    else:
        support_metrics = collections.defaultdict(lambda: np.nan)
    query_metrics = model_utils.calculate_metrics(self.tracker["query_predictions"],
                                                  self.tracker["query_labels"])

    self.logger.debug(
        f"Episode {self.current_iter + 1} Support set: Loss = {support_loss:.4f}, "
        f"accuracy = {support_metrics['accuracy']:.4f}, precision = {support_metrics['precision']:.4f}, "
        f"recall = {support_metrics['recall']:.4f}, F1 score = {support_metrics['f1']:.4f}"
    )
    self.logger.debug(
        f"Episode {self.current_iter + 1} -- Examples seen: {self.examples_seen()} -- Query set: Loss = {query_loss:.4f}, "
        f"accuracy = {query_metrics['accuracy']:.4f}, precision = {query_metrics['precision']:.4f}, "
        f"recall = {query_metrics['recall']:.4f}, F1 score = {query_metrics['f1']:.4f}"
    )
    if self.config.wandb:
        wandb.log({
            "support_accuracy": support_metrics["accuracy"],
            "support_precision": support_metrics["precision"],
            "support_recall": support_metrics["recall"],
            "support_f1": support_metrics["f1"],
            "support_loss": support_loss,
            "query_accuracy": query_metrics["accuracy"],
            "query_precision": query_metrics["precision"],
            "query_recall": query_metrics["recall"],
            "query_f1": query_metrics["f1"],
            "query_loss": query_loss,
            "examples_seen": self.examples_seen()
        })
    self.reset_tracker()

def log(self):
    """Log intermediate training results to the console and, optionally, to wandb."""
    loss = np.mean(self.tracker["losses"])
    key_losses = [np.mean(losses) for losses in self.tracker["key_losses"]]
    reconstruction_errors = [np.mean(errors) for errors in self.tracker["reconstruction_errors"]]
    metrics = model_utils.calculate_metrics(self.tracker["predictions"], self.tracker["labels"])
    key_metrics = [
        model_utils.calculate_metrics(key_predictions, self.tracker["labels"])
        for key_predictions in self.tracker["key_predictions"]
    ]
    key_accuracy_str = [f'{km["accuracy"]:.4f}' for km in key_metrics]
    self.logger.info(
        f"Iteration {self.current_iter + 1} - Task = {self.metrics[-1]['task']} - Metrics: Loss = {loss:.4f}, "
        f"key loss = {[f'{key_loss:.4f}' for key_loss in key_losses]}, "
        f"reconstruction error = {[f'{reconstruction_error:.4f}' for reconstruction_error in reconstruction_errors]}, "
        f"accuracy = {metrics['accuracy']:.4f} - "
        f"key accuracy = {key_accuracy_str}"
    )
    if self.config.wandb:
        log = {
            "accuracy": metrics["accuracy"],
            "precision": metrics["precision"],
            "recall": metrics["recall"],
            "f1": metrics["f1"],
            "loss": loss,
            "examples_seen": self.examples_seen()
        }
        for i, dim in enumerate(self.key_dim):
            log[f"key_accuracy_encoder_{i}_dim_{dim}"] = key_metrics[i]["accuracy"]
            log[f"key_loss_encoder_{i}_dim_{dim}"] = key_losses[i]
            log[f"reconstruction_error_encoder_{i}_dim_{dim}"] = reconstruction_errors[i]
        wandb.log(log)
    self.reset_tracker()

def train(self, dataloader=None, datasets=None, dataset_name=None, max_samples=None):
    val_datasets = datasets_dict(datasets["val"], datasets["order"])
    replay_freq, replay_steps = self.replay_parameters(metalearner=False)
    # have to keep track of per-task samples seen, as we might use replay as well
    episode_samples_seen = 0

    for _ in range(self.n_epochs):
        for text, labels, datasets in dataloader:
            output = self.training_step(text, labels)
            task = datasets[0]

            predictions = model_utils.make_prediction(output["logits"].detach())
            self.update_tracker(output, predictions, labels)
            metrics = model_utils.calculate_metrics(self.tracker["predictions"],
                                                    self.tracker["labels"])
            online_metrics = {
                "accuracy": metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": task
            }
            self.metrics["online"].append(online_metrics)
            if dataset_name is not None and dataset_name == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                self.metrics["eval_task_first_encounter"].append(online_metrics)

            if self.current_iter % self.log_freq == 0:
                self.log()
                self.write_metrics()
            if self.current_iter % self.validate_freq == 0:
                self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
            if self.replay_rate != 0 and (self.current_iter + 1) % replay_freq == 0:
                self.replay_training_step(replay_steps, episode_samples_seen, max_samples)

            self.memory.write_batch(text, labels)
            self._examples_seen += len(text)
            episode_samples_seen += len(text)
            self.current_iter += 1

            if max_samples is not None and episode_samples_seen >= max_samples:
                break

def log(self):
    metrics = model_utils.calculate_metrics(self.tracker["predictions"], self.tracker["labels"])
    self.logger.info(
        "Iteration {} - Metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        "F1 score = {:.4f}".format(self.current_iter + 1, np.mean(self.tracker["losses"]),
                                   metrics["accuracy"], metrics["precision"], metrics["recall"],
                                   metrics["f1"]))
    if self.config.wandb:
        wandb.log({
            "accuracy": metrics["accuracy"],
            "precision": metrics["precision"],
            "recall": metrics["recall"],
            "f1": metrics["f1"],
            "loss": np.mean(self.tracker["losses"]),
            "examples_seen": self.examples_seen()
        })
    self.reset_tracker()

def write_log(self, all_predictions, all_labels, all_losses, data_length):
    acc, prec, rec, f1 = model_utils.calculate_metrics(all_predictions, all_labels)
    time_per_iteration, estimated_time_left = self.time_metrics(data_length)
    self.logger.info(
        "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\n"
        "Metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        "F1 score = {:.4f}".format(self.current_iter + 1, data_length,
                                   (self.current_iter + 1) / data_length * 100,
                                   time_per_iteration, estimated_time_left,
                                   np.mean(all_losses), acc, prec, rec, f1))
    if self.config.wandb:
        n_examples_seen = (self.current_iter + 1) * self.mini_batch_size
        wandb.log({
            "accuracy": acc,
            "precision": prec,
            "recall": rec,
            "f1": f1,
            "loss": np.mean(all_losses),
            "examples_seen": n_examples_seen
        })

def training_step(self, support_set, query_set=None, task=None):
    self.inner_optimizer.zero_grad()

    with higher.innerloop_ctx(self.pn, self.inner_optimizer,
                              copy_initial_weights=False,
                              track_higher_grads=False) as (fpn, diffopt):
        # Inner loop
        for text, labels in support_set:
            labels = torch.tensor(labels).to(self.device)
            # labels = labels.to(self.device)
            output = self.forward(text, labels, fpn)
            loss = self.loss_fn(output["logits"], labels)
            diffopt.step(loss)

            self.memory.write_batch(text, labels)

            predictions = model_utils.make_prediction(output["logits"].detach())
            self.update_support_tracker(loss, predictions, labels)
            metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            online_metrics = {
                "accuracy": metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": task if task is not None else "none"
            }
            self.metrics["online"].append(online_metrics)
            if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                self.metrics["eval_task_first_encounter"].append(online_metrics)
            self._examples_seen += len(text)

        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text, labels, fpn)
                loss = self.loss_fn(output["logits"], labels)
                self.update_meta_gradients(loss, fpn)

                predictions = model_utils.make_prediction(output["logits"].detach())
                self.update_query_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(online_metrics)
                self._examples_seen += len(text)

    # Meta optimizer step
    self.meta_optimizer.step()
    self.meta_optimizer.zero_grad()

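# Note on the inner/outer loop above (an explanatory sketch, not part of the original code):
# track_higher_grads=False tells `higher` not to build a graph through the inner-loop
# updates, so the query loss cannot be backpropagated through the adaptation itself; the
# meta-gradients here are instead accumulated explicitly via update_meta_gradients(loss, fpn).
# A minimal, self-contained illustration of the generic `higher` pattern with a toy model
# (hypothetical tensors and learning rates; shown with track_higher_grads=True, where
# backward() differentiates through the inner update):
#
#     import torch, higher
#     import torch.nn.functional as F
#     model = torch.nn.Linear(4, 2)
#     inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2)
#     meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
#     x_s, y_s = torch.randn(8, 4), torch.randint(0, 2, (8,))   # support batch
#     x_q, y_q = torch.randn(8, 4), torch.randint(0, 2, (8,))   # query batch
#     with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False,
#                               track_higher_grads=True) as (fmodel, diffopt):
#         diffopt.step(F.cross_entropy(fmodel(x_s), y_s))        # inner (fast) update
#         F.cross_entropy(fmodel(x_q), y_q).backward()           # outer (meta) gradient
#     meta_opt.step()
#     meta_opt.zero_grad()
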
def evaluate(self, dataloader, prediction_network=None):
    if self.config.learner.evaluation_support_set:
        support_set = []
        for _ in range(self.config.learner.updates):
            text, labels = self.memory.read_batch(batch_size=self.mini_batch_size)
            support_set.append((text, labels))

    with higher.innerloop_ctx(self.pn, self.inner_optimizer,
                              copy_initial_weights=False,
                              track_higher_grads=False) as (fpn, diffopt):
        if self.config.learner.evaluation_support_set:
            self.set_train()
            support_prediction_network = fpn
            # Inner loop
            task_predictions, task_labels = [], []
            support_loss = []
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text, labels, fpn)
                loss = self.loss_fn(output["logits"], labels)
                diffopt.step(loss)

                pred = model_utils.make_prediction(output["logits"].detach())
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(labels.tolist())

            results = model_utils.calculate_metrics(task_predictions, task_labels)
            self.logger.info(
                "Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                "F1 score = {:.4f}".format(np.mean(support_loss), results["accuracy"],
                                           results["precision"], results["recall"], results["f1"]))
            self.set_eval()
        else:
            support_prediction_network = self.pn

        if prediction_network is None:
            prediction_network = support_prediction_network

        self.set_eval()
        all_losses, all_predictions, all_labels = [], [], []
        for i, (text, labels, datasets) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            # labels = labels.to(self.device)
            output = self.forward(text, labels, prediction_network, no_grad=True)
            loss = self.loss_fn(output["logits"], labels)
            loss = loss.item()
            pred = model_utils.make_prediction(output["logits"].detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            # if i % 20 == 0:
            #     self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

    results = model_utils.calculate_metrics(all_predictions, all_labels)
    self.logger.debug(
        "Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"], results["precision"],
                                   results["recall"], results["f1"]))
    return results

def training_step(self, support_set, query_set=None, task=None):
    self.inner_optimizer.zero_grad()
    self.logger.debug("-------------------- TRAINING STEP -------------------")

    ### GET SUPPORT SET REPRESENTATIONS ###
    with torch.no_grad():
        representations, all_labels = self.get_representations(support_set[:1])
        representations_merged = torch.cat(representations)
        class_means, unique_labels = self.get_class_means(representations_merged, all_labels)

    do_memory_update = self.config.learner.prototype_update_freq > 0 and \
        (self.current_iter % self.config.learner.prototype_update_freq) == 0
    if do_memory_update:
        ### UPDATE MEMORY ###
        updated_memory_representations = self.memory.update(class_means, unique_labels,
                                                            logger=self.logger)

    ### DETERMINE WHAT'S USED AS PROTOTYPE ###
    if self.config.learner.prototypes == "class_means":
        prototypes = class_means.detach()
    elif self.config.learner.prototypes == "memory":
        prototypes = self.memory.class_representations  # doesn't track prototype gradients
    else:
        raise AssertionError("Prototype type not in {'class_means', 'memory'}, fix config file.")

    with higher.innerloop_ctx(self.pn, self.inner_optimizer,
                              copy_initial_weights=False,
                              track_higher_grads=False) as (fpn, diffopt):
        ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
        self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)

        self.logger.debug("----------------- SUPPORT SET ----------------- ")
        ### TRAIN LINEAR CLASSIFIER ON SUPPORT SET ###
        # Inner loop
        for i, (text, labels) in enumerate(support_set):
            self.logger.debug(f"----------------- {i}th Update ----------------- ")
            labels = torch.tensor(labels).to(self.device)
            # if i == 0:
            #     output = {
            #         "representation": representations[0],
            #         "logits": fpn(representations[0], out_from="linear")
            #     }
            # else:
            output = self.forward(text, labels, fpn)

            # for logging purposes
            prototype_distances = (output["representation"].unsqueeze(1) - prototypes).norm(dim=-1)
            closest_dists, closest_classes = prototype_distances.topk(3, largest=False)
            to_print = pprint.pformat(
                list(map(lambda x: (x[0].item(), x[1].tolist(),
                                    [round(z, 2) for z in x[2].tolist()]),
                         list(zip(labels, closest_classes, closest_dists)))))
            self.logger.debug(f"True labels, closest prototypes, and distances:\n{to_print}")
            topk = output["logits"].topk(5, dim=1)
            to_print = pprint.pformat(
                list(map(lambda x: (x[0].item(), x[1].tolist(),
                                    [round(z, 3) for z in x[2].tolist()]),
                         list(zip(labels, topk[1], topk[0])))))
            self.logger.debug(f"(label, topk_classes, topk_logits) before update:\n{to_print}")

            loss = self.loss_fn(output["logits"], labels)
            diffopt.step(loss)

            # see how much the linear classifier has changed
            with torch.no_grad():
                topk = fpn(output["representation"], out_from="linear").topk(5, dim=1)
                to_print = pprint.pformat(
                    list(map(lambda x: (x[0].item(), x[1].tolist(),
                                        [round(z, 3) for z in x[2].tolist()]),
                             list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(f"(label, topk_classes, topk_logits) after update:\n{to_print}")

            predictions = model_utils.make_prediction(output["logits"].detach())
            self.update_support_tracker(loss, predictions, labels)
            metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            online_metrics = {
                "accuracy": metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": task if task is not None else "none"
            }
            self.metrics["online"].append(online_metrics)
            if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                self.metrics["eval_task_first_encounter"].append(online_metrics)
            self._examples_seen += len(text)

        self.logger.debug("----------------- QUERY SET ----------------- ")
        ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                # labels = labels.to(self.device)
                output = self.forward(text, labels, prediction_network=fpn)
                loss = self.loss_fn(output["logits"], labels)
                self.update_meta_gradients(loss, fpn)

                topk = output["logits"].topk(5, dim=1)
                to_print = pprint.pformat(
                    list(map(lambda x: (x[0].item(), x[1].tolist(),
                                        [round(z, 3) for z in x[2].tolist()]),
                             list(zip(labels, topk[1], topk[0])))))
                self.logger.debug(f"(label, topk_classes, topk_logits):\n{to_print}")

                predictions = model_utils.make_prediction(output["logits"].detach())
                self.update_query_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                        self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(online_metrics)
                self._examples_seen += len(text)

            # Meta optimizer step
            self.meta_optimizer.step()
            self.meta_optimizer.zero_grad()

    self.logger.debug("-------------------- TRAINING STEP END -------------------")

def training_step(self, support_set, query_set=None, task=None):
    self.inner_optimizer.zero_grad()
    self.logger.debug("-------------------- TRAINING STEP -------------------")
    # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
    #                           copy_initial_weights=False,
    #                           track_higher_grads=False) as (fpn, diffopt):
    do_memory_update = self.config.learner.prototype_update_freq > 0 and \
        (self.current_iter % self.config.learner.prototype_update_freq) == 0

    ### GET SUPPORT SET REPRESENTATIONS ###
    self.logger.debug("----------------- SUPPORT SET ----------------- ")
    representations, all_labels = self.get_representations(support_set[:1])
    representations_merged = torch.cat(representations)
    class_means, unique_labels = model_utils.get_class_means(representations_merged, all_labels)
    self._examples_seen += len(representations_merged)
    self.logger.debug(f"Examples seen increased by {len(representations_merged)}")

    ### UPDATE MEMORY ###
    if do_memory_update:
        memory_update = self.memory.update(class_means, unique_labels, logger=self.logger)
        updated_memory_representations = memory_update["new_class_representations"]
        self.log_discounts(memory_update["class_discount"], unique_labels)

    ### DETERMINE WHAT'S USED AS PROTOTYPE ###
    if self.config.learner.prototypes == "class_means":
        prototypes = expand_class_representations(self.memory.class_representations,
                                                  class_means, unique_labels)
    elif self.config.learner.prototypes == "memory":
        # assumes a memory update happened this step; otherwise updated_memory_representations
        # is undefined
        prototypes = updated_memory_representations
    else:
        raise AssertionError("Prototype type not in {'class_means', 'memory'}, fix config file.")

    ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
    # self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)
    weight = 2 * prototypes  # divide by number of dimensions, otherwise blows up
    bias = -(prototypes ** 2).sum(dim=1)

    self.logger.debug("----------------- QUERY SET ----------------- ")
    ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
    # Outer loop
    if query_set is not None:
        for text, labels in query_set:
            labels = torch.tensor(labels).to(self.device)
            query_representations = self.forward(text, labels)["representation"]

            # distance of query representations to prototypes (BATCH x N_PROTOTYPES)
            # distances = euclidean_dist(query_representations, prototypes)
            # logits = - distances
            logits = query_representations @ weight.T + bias
            loss = self.loss_fn(logits, labels)
            # log_probability = F.log_softmax(-distances, dim=1)
            # loss is the negation of the log probability, indexed by the labels of each observation
            # loss = (- log_probability[torch.arange(len(log_probability)), labels]).mean()

            self.meta_optimizer.zero_grad()
            loss.backward()
            self.meta_optimizer.step()

            predictions = model_utils.make_prediction(logits.detach())
            # predictions = torch.tensor([inv_label_map[p.item()] for p in predictions])
            # to_print = pprint.pformat(list(map(lambda x: (x[0].item(), x[1].item(),
            #                                               [round(z, 3) for z in x[2].tolist()]),
            #                                    list(zip(labels, predictions, distances)))))
            self.logger.debug(
                f"Unique Labels: {unique_labels.tolist()}\n"
                # f"Labels, Indices, Predictions, Distances:\n{to_print}\n"
                f"Loss:\n{loss.item()}\n"
                f"Predictions:\n{predictions}\n")
            self.update_query_tracker(loss, predictions, labels)
            metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            online_metrics = {
                "accuracy": metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": task if task is not None else "none"
            }
            self.metrics["online"].append(online_metrics)
            if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                self.metrics["eval_task_first_encounter"].append(online_metrics)
            self._examples_seen += len(text)
            self.logger.debug(f"Examples seen increased by {len(text)}")

    # Meta optimizer step (now performed per query batch above)
    # self.meta_optimizer.step()
    # self.meta_optimizer.zero_grad()
    self.logger.debug("-------------------- TRAINING STEP END -------------------")

def train(self, dataloader=None, datasets=None, data_length=None):
    val_datasets = datasets_dict(datasets["val"], datasets["order"])
    if data_length is None:
        data_length = len(dataloader) * self.n_epochs

    all_losses, all_predictions, all_labels = [], [], []
    for text, labels, tasks in dataloader:
        self._examples_seen += len(text)
        self.model.train()
        # assumes all data in a batch is from the same task
        self.current_task = self.relearning_task if tasks[0] == self.relearning_task else OTHER_TASKS
        loss, predictions = self._train_batch(text, labels)
        all_losses.append(loss)
        all_predictions.extend(predictions)
        all_labels.extend(labels.tolist())

        if self.current_iter % self.log_freq == 0:
            acc, prec, rec, f1 = model_utils.calculate_metrics(all_predictions, all_labels)
            time_per_iteration, estimated_time_left = self.time_metrics(data_length)
            self.logger.info(
                "Iteration {}/{} ({:.2f}%) -- {:.3f} (sec/it) -- Time Left: {}\n"
                "Metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                "F1 score = {:.4f}".format(self.current_iter + 1, data_length,
                                           (self.current_iter + 1) / data_length * 100,
                                           time_per_iteration, estimated_time_left,
                                           np.mean(all_losses), acc, prec, rec, f1))
            if self.config.wandb:
                wandb.log({
                    "accuracy": acc,
                    "precision": prec,
                    "recall": rec,
                    "f1": f1,
                    "loss": np.mean(all_losses),
                    "examples_seen": self.examples_seen(),
                    "task": self.current_task
                })
            all_losses, all_predictions, all_labels = [], [], []
            self.start_time = time.time()

        if self.current_iter % self.validate_freq == 0:
            # only evaluate the relearning task while training on the relearning task
            validation_datasets = self.relearning_task_dataset \
                if self.current_task == self.relearning_task else val_datasets
            validate_results = self.validate(
                validation_datasets,
                n_samples=self.config.training.n_validation_samples,
                log=False)
            self.write_results(validate_results)
            relearning_task_performance = validate_results[self.relearning_task]["accuracy"]
            if not self.first_encounter:
                # TODO: make this a weighted average as well
                relearning_task_relative_performance = self.relative_performance(
                    performance=relearning_task_performance, task=self.relearning_task)
                self.logger.info(
                    f"Examples seen: {self.examples_seen()} -- Relative performance of task "
                    f"'{self.relearning_task}': {relearning_task_relative_performance}. "
                    f"Thresholds: {self.relative_performance_threshold_lower}"
                    f"-{self.relative_performance_threshold_upper}")
                if self.config.wandb:
                    wandb.log({
                        "relative_performance": relearning_task_relative_performance,
                        "examples_seen": self.examples_seen()
                    })

            if self.current_task == self.relearning_task:
                self.logger.debug(f"first encounter: {self.first_encounter}")
                # relearning stops when either one of two things happens:
                # 1) the relearning task is first encountered and it saturates (stops improving), or
                if ((self.first_encounter and self.learning_saturated(
                        task=self.relearning_task,
                        n_samples_slope=self.n_samples_slope,
                        patience=self.saturated_patience,
                        threshold=self.saturated_threshold,
                        smooth_alpha=self.smooth_alpha)) or (
                        # 2) the relearning task is re-encountered and its relative performance reaches some threshold
                        not self.first_encounter and
                        relearning_task_relative_performance >= self.relative_performance_threshold_upper)):
                    # write metrics, reset, and train on the other tasks
                    self.write_relearning_metrics()
                    self.logger.info(
                        f"Task {self.current_task} saturated at iteration {self.current_iter}")
                    # each list element in performance refers to one consecutive learning event of the relearning task
                    self.metrics[self.relearning_task]["performance"].append([])
                    self.not_improving = 0
                    if self.first_encounter:
                        self.logger.info(
                            f"-----------FIRST ENCOUNTER RELEARNING TASK '{self.relearning_task}' FINISHED.----------\n")
                        self.first_encounter = False
                    self.logger.info("TRAINING ON OTHER TASKS")
                    self.train(dataloader=self.dataloaders[OTHER_TASKS], datasets=datasets)
            else:
                # calculate relative performance of the relearning task;
                # if it drops below some threshold, train on the relearning task again
                # TODO: make performance measure an attribute of the relearner
                # TODO: use moving average for relative performance check
                # TODO: allow different task ordering
                # TODO: measure forgetting
                # TODO: look at adaptive mini batch size => smaller batch size when re-encountering
                if relearning_task_relative_performance <= self.relative_performance_threshold_lower:
                    self.logger.info(
                        f"Relative performance on relearning task {self.relearning_task} below threshold. "
                        "Evaluating relearning...")
                    # this needs to be done because we want a fresh list of performances when we start
                    # training the relearning task again. The first item in this list is simply the
                    # zero-shot performance after the relative performance threshold is reached; this
                    # means that every odd list in the relearning_task metrics comes from training on
                    # the relearning task.
                    relearning_task_performance = self.metrics[self.relearning_task]["performance"]
                    relearning_task_performance.append([])
                    # copy the last entry recorded while training on the other tasks to the new list
                    relearning_task_performance[-1].append(relearning_task_performance[-2][-1])
                    self.train(dataloader=self.dataloaders[self.relearning_task], datasets=datasets)

        with open(self.results_dir / METRICS_FILE, "w") as f:
            json.dump(self.metrics, f)
        self.time_checkpoint()
        self.current_iter += 1
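
# Summary of the relearning loop above (a reading of the code, not an authoritative spec):
# the learner first trains on the relearning task until learning_saturated() reports no
# further improvement ("first encounter"), then recursively calls train() on the other
# tasks. While training on other tasks, it periodically validates the relearning task;
# once relative performance falls below relative_performance_threshold_lower it switches
# back to the relearning task, and that re-encounter ends when relative performance climbs
# above relative_performance_threshold_upper. Each switch starts a new sub-list in
# self.metrics[relearning_task]["performance"], so consecutive sub-lists alternate between
# phases of training on the relearning task and training on the other tasks.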