def training(self, datasets, **kwargs):
    """
    Train online over a continuum of tasks, logging and validating periodically.

    Parameters
    ---
    datasets: dict
        Must provide the keys "train", "val" and "order". The train and val
        entries are paired with the names in "order" via `datasets_dict`.
    **kwargs:
        Accepted but not used by this method.
    """
    train_by_task = datasets_dict(datasets["train"], datasets["order"])
    val_by_task = datasets_dict(datasets["val"], datasets["order"])

    samples_per_task = self.config.learner.samples_per_task
    # An explicitly configured task order wins over the order shipped with the data.
    task_order = self.config.task_order
    if task_order is None:
        task_order = datasets["order"]
    if samples_per_task is None:
        # One `None` entry per task.
        n_samples = [samples_per_task] * len(task_order)
    else:
        n_samples = samples_per_task

    continuum = get_continuum(train_by_task, order=task_order, n_samples=n_samples)
    loader = DataLoader(continuum, batch_size=self.mini_batch_size, shuffle=False)

    for text, labels, batch_tasks in loader:
        output = self.training_step(text, labels)
        logits = output["logits"].detach()
        predictions = model_utils.make_prediction(logits)
        # Predictions from each set of key logits are only used by the tracker.
        key_predictions = []
        for key_logits in output["key_logits"]:
            key_predictions.append(model_utils.make_prediction(key_logits.detach()))

        self.update_tracker(output, predictions, key_predictions, labels)

        batch_metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
        record = {
            "accuracy": batch_metrics["accuracy"],
            "examples_seen": self.examples_seen(),
            # assumes whole batch is from same task
            "task": batch_tasks[0],
        }
        self.metrics["online"].append(record)

        if self.current_iter % self.log_freq == 0:
            self.log()
            self.write_metrics()
        if self.current_iter % self.validate_freq == 0:
            self.validate(val_by_task,
                          n_samples=self.config.training.n_validation_samples)
        self.current_iter += 1
def evaluate(self, dataloader):
    """
    Run the model over `dataloader` without gradient tracking and report metrics.

    Parameters
    ---
    dataloader: iterable
        Yields `(text, labels, _)` triples; the third element is ignored.

    Returns
    ---
    dict with the keys "accuracy", "precision", "recall" and "f1", taken from
    `model_utils.calculate_metrics` over all batches.
    """
    self.set_eval()
    losses, predicted, gold = [], [], []

    for text, labels, _ in dataloader:
        labels = torch.tensor(labels).to(self.device)
        encoded = self.model.encode_text(text)
        with torch.no_grad():
            logits = self.model(encoded)
            batch_loss = self.loss_fn(logits, labels)
        losses.append(batch_loss.item())
        predicted.extend(model_utils.make_prediction(logits.detach()).tolist())
        gold.extend(labels.tolist())

    metrics = model_utils.calculate_metrics(predicted, gold)
    message = ("Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
               "F1 score = {:.4f}")
    self.logger.debug(
        message.format(np.mean(losses), metrics["accuracy"], metrics["precision"],
                       metrics["recall"], metrics["f1"]))

    # Only these four entries are exposed to the caller.
    return {name: metrics[name] for name in ("accuracy", "precision", "recall", "f1")}
def train(self, dataloaders):
    """
    Train the model for `self.n_epochs` passes over the "train" dataloader.

    Every batch is written to `self.memory` after the parameter update, and a
    log entry is emitted every `self.log_freq` iterations over the batches
    accumulated since the previous entry.

    Parameters
    ---
    dataloaders: dict
        Must contain a "train" entry yielding `(text, labels)` pairs.
    """
    self.model.train()
    train_loader = dataloaders["train"]
    total_batches = len(train_loader) * self.n_epochs

    for _ in range(self.n_epochs):
        # Running buffers, flushed every time a log entry is written.
        losses, predicted, gold = [], [], []

        for text, labels in train_loader:
            labels = torch.tensor(labels).to(self.device)
            logits = self.model(self.model.encode_text(text))
            loss = self.loss_fn(logits, labels)
            self.update_parameters(loss, mini_batch_size=len(labels))

            losses.append(loss.item())
            predicted.extend(model_utils.make_prediction(logits.detach()).tolist())
            gold.extend(labels.tolist())
            self.memory.write_batch(text, labels)

            if self.current_iter % self.log_freq == 0:
                self.write_log(predicted, gold, losses, data_length=total_batches)
                self.start_time = time.time()  # time from last log
                losses, predicted, gold = [], [], []

            # Checkpoint timing runs on every iteration; the save_freq gate is disabled.
            self.time_checkpoint()
            self.current_iter += 1

        self.current_epoch += 1
def evaluate(self, dataloader, update_memory=False):
    """
    Evaluate on `dataloader` without gradient tracking and return the metrics.

    Parameters
    ---
    dataloader: iterable
        Yields `(text, labels, datasets)` triples; `datasets` is not used.
    update_memory: bool
        Forwarded unchanged to `self.forward`.

    Returns
    ---
    The result of `model_utils.calculate_metrics` over all batches.
    """
    self.set_eval()
    losses, predicted, gold = [], [], []
    self.logger.info("Starting evaluation...")

    for batch_idx, (text, labels, _) in enumerate(dataloader):
        labels = torch.tensor(labels).to(self.device)
        with torch.no_grad():
            # The key logits returned alongside are not needed here.
            logits, _ = self.forward(text, labels, update_memory=update_memory)
            batch_loss = self.loss_fn(logits, labels)

        losses.append(batch_loss.item())
        predicted.extend(model_utils.make_prediction(logits.detach()).tolist())
        gold.extend(labels.tolist())

        if batch_idx % 20 == 0:
            self.logger.info(f"Batch {batch_idx + 1}/{len(dataloader)} processed")

    results = model_utils.calculate_metrics(predicted, gold)
    message = ("Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
               "F1 score = {:.4f}")
    self.logger.info(
        message.format(np.mean(losses), results["accuracy"], results["precision"],
                       results["recall"], results["f1"]))
    return results
def evaluate(self, dataloader, prediction_network=None):
    """
    Evaluate by classifying each example against the prototypes held in memory.

    Logits are `representation @ (2 * prototypes).T - ||prototype||^2`, which
    equals the negative squared Euclidean distance to each prototype up to a
    per-example constant.

    A variant that first fine-tuned a prediction network on a support set read
    from memory used to live here as commented-out code; it has been removed.

    Parameters
    ---
    dataloader: iterable
        Yields `(text, labels, _)` triples; the third element is ignored.
    prediction_network:
        Unused; kept so the signature matches the other learners' `evaluate`.

    Returns
    ---
    The result of `model_utils.calculate_metrics` over all batches.
    """
    self.set_eval()

    all_losses, all_predictions, all_labels = [], [], []
    # Nothing here is back-propagated, so skip building the autograd graph
    # (consistent with the other evaluate implementations).
    with torch.no_grad():
        prototypes = self.memory.class_representations
        weight = 2 * prototypes
        bias = -(prototypes ** 2).sum(dim=1)

        for text, labels, _ in dataloader:
            labels = torch.tensor(labels).to(self.device)
            representations = self.forward(text, labels)["representation"]
            logits = representations @ weight.T + bias
            loss = self.loss_fn(logits, labels).item()
            pred = model_utils.make_prediction(logits.detach())

            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

    results = model_utils.calculate_metrics(all_predictions, all_labels)
    self.logger.debug("Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                      "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"],
                                                 results["precision"], results["recall"], results["f1"]))
    return results
def few_shot_testing(self, train_dataset, eval_dataset, increment_counters=False, split="test"):
    """
    Adapt on a few examples at a time and evaluate after every adaptation step.

    The adaptation happens on the patched network yielded by
    `higher.innerloop_ctx`; that network is also the one handed to `evaluate`.
    Results are recorded through `self.log_few_shot`.

    Parameters
    ---
    train_dataset: Dataset
        Contains examples on which the model is trained before being evaluated
    eval_dataset: Dataset
        Contains examples on which the model is evaluated
    increment_counters: bool
        Forwarded to `log_few_shot`.
    split: str
        Forwarded to `few_shot_preparation` and `log_few_shot`.
    """
    self.logger.info(f"few shot testing on dataset {self.config.testing.eval_dataset} "
                     f"with {len(train_dataset)} samples")
    train_loader, eval_loader = self.few_shot_preparation(train_dataset, eval_dataset, split=split)

    seen_predictions, seen_labels = [], []
    inner_ctx = higher.innerloop_ctx(self.pln,
                                     self.inner_optimizer,
                                     copy_initial_weights=False,
                                     track_higher_grads=False)
    with inner_ctx as (fpln, diffopt):
        self.pln.train()
        self.rln.eval()

        # Inner loop
        for step, (text, labels, datasets) in enumerate(train_loader):
            labels = torch.tensor(labels).to(self.device)
            logits = self.forward(text, labels, fpln)["logits"]
            diffopt.step(self.loss_fn(logits, labels))

            seen_predictions.extend(model_utils.make_prediction(logits.detach()).tolist())
            seen_labels.extend(labels.tolist())

            eval_results = self.evaluate(dataloader=eval_loader, prediction_network=fpln)
            self.log_few_shot(seen_predictions, seen_labels, datasets, eval_results,
                              increment_counters, text, step, split=split)

            # Start fresh buffers whenever the examples consumed so far line up
            # with a multiple of the regular mini batch size.
            consumed = step * self.config.testing.few_shot_batch_size
            if consumed % self.mini_batch_size == 0 and step > 0:
                seen_predictions, seen_labels = [], []

    self.few_shot_end()
def _train_batch(self, text, labels):
    """
    Perform one optimisation step on a single batch.

    Parameters
    ---
    text:
        Batch of inputs, passed to `self.model.encode_text`.
    labels:
        Batch labels; converted to a tensor on `self.device`.

    Returns
    ---
    (loss, predictions): the loss as a Python number and the predictions as a list.
    """
    targets = torch.tensor(labels).to(self.device)
    logits = self.model(self.model.encode_text(text))
    batch_loss = self.loss_fn(logits, targets)

    # Standard step: clear stale gradients, back-propagate, update.
    self.optimizer.zero_grad()
    batch_loss.backward()
    self.optimizer.step()

    loss_value = batch_loss.item()
    self.logger.debug(f"Loss: {loss_value}")
    predictions = model_utils.make_prediction(logits.detach())
    return loss_value, predictions.tolist()
def train(self, dataloader=None, datasets=None, dataset_name=None, max_samples=None):
    """
    Train on one task's dataloader, with periodic logging, validation and replay.

    Parameters
    ---
    dataloader: iterable
        Yields `(text, labels, datasets)` triples. The first entry of the third
        element is recorded as the task of the batch.
    datasets: dict
        Must provide "val" and "order"; used to build the validation sets.
    dataset_name: str, optional
        When equal to `self.config.testing.eval_dataset` (and
        `self.eval_task_first_encounter` is truthy) the online metrics are also
        appended to `self.metrics["eval_task_first_encounter"]`.
    max_samples: int, optional
        Upper bound on the number of examples consumed by this call. Once
        reached, training stops entirely.
    """
    val_datasets = datasets_dict(datasets["val"], datasets["order"])
    replay_freq, replay_steps = self.replay_parameters(metalearner=False)

    # have to keep track of per-task samples seen as we might use replay as well
    episode_samples_seen = 0
    for _ in range(self.n_epochs):
        for text, labels, batch_tasks in dataloader:
            output = self.training_step(text, labels)
            task = batch_tasks[0]

            predictions = model_utils.make_prediction(output["logits"].detach())
            self.update_tracker(output, predictions, labels)

            # Metrics are computed over everything currently held by the tracker.
            metrics = model_utils.calculate_metrics(self.tracker["predictions"],
                                                    self.tracker["labels"])
            online_metrics = {
                "accuracy": metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": task
            }
            self.metrics["online"].append(online_metrics)
            if dataset_name is not None and dataset_name == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                self.metrics["eval_task_first_encounter"].append(online_metrics)

            if self.current_iter % self.log_freq == 0:
                self.log()
                self.write_metrics()
            if self.current_iter % self.validate_freq == 0:
                self.validate(val_datasets,
                              n_samples=self.config.training.n_validation_samples)
            if self.replay_rate != 0 and (self.current_iter + 1) % replay_freq == 0:
                self.replay_training_step(replay_steps, episode_samples_seen, max_samples)

            self.memory.write_batch(text, labels)
            self._examples_seen += len(text)
            episode_samples_seen += len(text)
            self.current_iter += 1

            if max_samples is not None and episode_samples_seen >= max_samples:
                # Budget exhausted. A plain `break` would only leave the batch
                # loop, and every remaining epoch would still train on one more
                # batch before noticing the budget again.
                return
def few_shot_testing(self, train_dataset, eval_dataset, increment_counters=False, split="test"):
    """
    Alternate between one training step on a few examples and a full evaluation.

    Results are recorded through `self.log_few_shot`.

    Parameters
    ---
    train_dataset: Dataset
        Contains examples on which the model is trained before being evaluated
    eval_dataset: Dataset
        Contains examples on which the model is evaluated
    increment_counters: bool
        Forwarded to `log_few_shot`.
    split: str, one of {"val", "test"}.
        Which data split is used. For logging purposes.
    """
    self.logger.info(
        f"few shot testing on dataset {self.config.testing.eval_dataset} "
        f"with {len(train_dataset)} samples")

    # whenever we do few shot evaluation, we reset the learning to before the evaluation started
    train_loader, eval_loader = self.few_shot_preparation(
        train_dataset, eval_dataset, split=split)

    seen_predictions, seen_labels = [], []
    for step, (text, labels, datasets) in enumerate(train_loader):
        logits = self.training_step(text, labels)["logits"]
        batch_predictions = model_utils.make_prediction(logits.detach())

        seen_predictions.extend(batch_predictions.tolist())
        seen_labels.extend(labels.tolist())

        eval_results = self.evaluate(dataloader=eval_loader)
        self.log_few_shot(seen_predictions, seen_labels, datasets, eval_results,
                          increment_counters, text, step, split=split)

        # Start fresh buffers whenever the examples consumed so far line up
        # with a multiple of the regular mini batch size.
        consumed = step * self.config.testing.few_shot_batch_size
        if consumed % self.mini_batch_size == 0 and step > 0:
            seen_predictions, seen_labels = [], []

    self.few_shot_end()
def training_step(self, support_set, query_set=None, task=None):
    """
    One meta-training step: adapt on the support set, then meta-update on the query set.

    Parameters
    ---
    support_set: iterable of (text, labels)
        Batches used for the inner-loop updates; each is also written to memory.
    query_set: iterable of (text, labels), optional
        Batches whose loss feeds `update_meta_gradients`. When None the outer
        loop and the meta optimizer step are skipped.
    task: str, optional
        Task name stored with the online metrics ("none" when not given).
    """
    task_name = task if task is not None else "none"

    def record(track, loss, logits, labels, n_examples):
        # Shared bookkeeping for support and query batches.
        predictions = model_utils.make_prediction(logits.detach())
        track(loss, predictions, labels)
        scores = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
        entry = {
            "accuracy": scores["accuracy"],
            "examples_seen": self.examples_seen(),
            "task": task_name
        }
        self.metrics["online"].append(entry)
        if task is not None and task == self.config.testing.eval_dataset and \
                self.eval_task_first_encounter:
            self.metrics["eval_task_first_encounter"].append(entry)
        # The counter is bumped only after the entry has been recorded.
        self._examples_seen += n_examples

    self.inner_optimizer.zero_grad()

    inner_ctx = higher.innerloop_ctx(self.pn,
                                     self.inner_optimizer,
                                     copy_initial_weights=False,
                                     track_higher_grads=False)
    with inner_ctx as (fpn, diffopt):
        # Inner loop
        for text, labels in support_set:
            labels = torch.tensor(labels).to(self.device)
            output = self.forward(text, labels, fpn)
            loss = self.loss_fn(output["logits"], labels)
            diffopt.step(loss)
            self.memory.write_batch(text, labels)
            record(self.update_support_tracker, loss, output["logits"], labels, len(text))

        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                output = self.forward(text, labels, fpn)
                loss = self.loss_fn(output["logits"], labels)
                self.update_meta_gradients(loss, fpn)
                record(self.update_query_tracker, loss, output["logits"], labels, len(text))

            # Meta optimizer step
            self.meta_optimizer.step()
            self.meta_optimizer.zero_grad()
def evaluate(self, dataloader, prediction_network=None):
    """
    Evaluate on `dataloader`, optionally after adapting on batches read from memory.

    When `config.learner.evaluation_support_set` is set, `config.learner.updates`
    batches are read from memory and used for inner-loop updates of the patched
    network before evaluation.

    Parameters
    ---
    dataloader: iterable
        Yields `(text, labels, datasets)` triples; `datasets` is not used.
    prediction_network: optional
        Network to evaluate with. When None, the adapted network is used if a
        support set was applied, otherwise `self.pn`.

    Returns
    ---
    The result of `model_utils.calculate_metrics` over the evaluation batches.
    """
    metric_template = ("{} metrics: Loss = {{:.4f}}, accuracy = {{:.4f}}, precision = {{:.4f}}, "
                       "recall = {{:.4f}}, F1 score = {{:.4f}}")

    if self.config.learner.evaluation_support_set:
        support_set = []
        for _ in range(self.config.learner.updates):
            text, labels = self.memory.read_batch(batch_size=self.mini_batch_size)
            support_set.append((text, labels))

    inner_ctx = higher.innerloop_ctx(self.pn,
                                     self.inner_optimizer,
                                     copy_initial_weights=False,
                                     track_higher_grads=False)
    with inner_ctx as (fpn, diffopt):
        if not self.config.learner.evaluation_support_set:
            support_prediction_network = self.pn
        else:
            self.set_train()
            support_prediction_network = fpn

            # Inner loop
            support_predictions, support_labels, support_losses = [], [], []
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                logits = self.forward(text, labels, fpn)["logits"]
                loss = self.loss_fn(logits, labels)
                diffopt.step(loss)

                batch_pred = model_utils.make_prediction(logits.detach())
                support_losses.append(loss.item())
                support_predictions.extend(batch_pred.tolist())
                support_labels.extend(labels.tolist())

            results = model_utils.calculate_metrics(support_predictions, support_labels)
            self.logger.info(
                metric_template.format("Support set").format(
                    np.mean(support_losses), results["accuracy"], results["precision"],
                    results["recall"], results["f1"]))
            self.set_eval()

        # An explicitly passed network takes precedence over the adapted one.
        if prediction_network is None:
            prediction_network = support_prediction_network

        self.set_eval()
        losses, predicted, gold = [], [], []
        for text, labels, _ in dataloader:
            labels = torch.tensor(labels).to(self.device)
            logits = self.forward(text, labels, prediction_network, no_grad=True)["logits"]
            losses.append(self.loss_fn(logits, labels).item())
            predicted.extend(model_utils.make_prediction(logits.detach()).tolist())
            gold.extend(labels.tolist())

    results = model_utils.calculate_metrics(predicted, gold)
    self.logger.debug(
        metric_template.format("Test").format(
            np.mean(losses), results["accuracy"], results["precision"],
            results["recall"], results["f1"]))
    return results
def training_step(self, support_set, query_set=None, task=None):
    """
    One meta-training step with a prototype-initialised linear classifier.

    Class means of the first support batch (computed without gradients)
    optionally update the memory. The chosen prototypes initialise the linear
    layer of the patched network, which is adapted on the support set and then
    used to compute meta gradients on the query set.

    Parameters
    ---
    support_set: sequence of (text, labels)
        Batches for the inner loop; only the first one feeds the class means.
    query_set: iterable of (text, labels), optional
        Batches for the outer loop. When None, the outer loop and the meta
        optimizer step are skipped.
    task: str, optional
        Task name stored with the online metrics ("none" when not given).

    Raises
    ---
    AssertionError
        If `config.learner.prototypes` is neither "class_means" nor "memory".
    """
    task_name = task if task is not None else "none"

    def describe(labels, classes, values, ndigits):
        # One (label, classes, rounded values) tuple per example, pretty-printed.
        rows = [(label.item(), cls.tolist(), [round(v, ndigits) for v in vals.tolist()])
                for label, cls, vals in zip(labels, classes, values)]
        return pprint.pformat(rows)

    def record(track, loss, logits, labels, n_examples):
        # Shared bookkeeping for support and query batches.
        predictions = model_utils.make_prediction(logits.detach())
        track(loss, predictions, labels)
        scores = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
        entry = {
            "accuracy": scores["accuracy"],
            "examples_seen": self.examples_seen(),
            "task": task_name
        }
        self.metrics["online"].append(entry)
        if task is not None and task == self.config.testing.eval_dataset and \
                self.eval_task_first_encounter:
            self.metrics["eval_task_first_encounter"].append(entry)
        self._examples_seen += n_examples

    self.inner_optimizer.zero_grad()
    self.logger.debug("-------------------- TRAINING STEP -------------------")

    ### GET SUPPORT SET REPRESENTATIONS ###
    with torch.no_grad():
        representations, all_labels = self.get_representations(support_set[:1])
        representations_merged = torch.cat(representations)
        class_means, unique_labels = self.get_class_means(representations_merged, all_labels)

    update_freq = self.config.learner.prototype_update_freq
    do_memory_update = update_freq > 0 and \
        (self.current_iter % self.config.learner.prototype_update_freq) == 0
    if do_memory_update:
        ### UPDATE MEMORY ###
        updated_memory_representations = self.memory.update(class_means,
                                                            unique_labels,
                                                            logger=self.logger)

    ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
    prototype_source = self.config.learner.prototypes
    if prototype_source == "class_means":
        prototypes = class_means.detach()
    elif prototype_source == "memory":
        # doesn't track prototype gradients
        prototypes = self.memory.class_representations
    else:
        raise AssertionError(
            "Prototype type not in {'class_means', 'memory'}, fix config file.")

    inner_ctx = higher.innerloop_ctx(self.pn,
                                     self.inner_optimizer,
                                     copy_initial_weights=False,
                                     track_higher_grads=False)
    with inner_ctx as (fpn, diffopt):
        ### INITIALIZE LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
        self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)

        self.logger.debug("----------------- SUPPORT SET ----------------- ")
        ### TRAIN LINEAR CLASSIFIER ON SUPPORT SET ###
        # Inner loop
        for i, (text, labels) in enumerate(support_set):
            self.logger.debug(f"----------------- {i}th Update ----------------- ")
            labels = torch.tensor(labels).to(self.device)
            output = self.forward(text, labels, fpn)

            # for logging purposes
            distances = (output["representation"].unsqueeze(1) - prototypes).norm(dim=-1)
            closest_dists, closest_classes = distances.topk(3, largest=False)
            to_print = describe(labels, closest_classes, closest_dists, 2)
            self.logger.debug(f"True labels, closest prototypes, and distances:\n{to_print}")

            top_values, top_classes = output["logits"].topk(5, dim=1)
            to_print = describe(labels, top_classes, top_values, 3)
            self.logger.debug(f"(label, topk_classes, topk_logits) before update:\n{to_print}")

            loss = self.loss_fn(output["logits"], labels)
            diffopt.step(loss)

            # see how much linear classifier has changed
            with torch.no_grad():
                top_values, top_classes = fpn(output["representation"],
                                              out_from="linear").topk(5, dim=1)
            to_print = describe(labels, top_classes, top_values, 3)
            self.logger.debug(f"(label, topk_classes, topk_logits) after update:\n{to_print}")

            record(self.update_support_tracker, loss, output["logits"], labels, len(text))

        self.logger.debug("----------------- QUERY SET ----------------- ")
        ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                output = self.forward(text, labels, prediction_network=fpn)
                loss = self.loss_fn(output["logits"], labels)
                self.update_meta_gradients(loss, fpn)

                top_values, top_classes = output['logits'].topk(5, dim=1)
                to_print = describe(labels, top_classes, top_values, 3)
                self.logger.debug(f"(label, topk_classes, topk_logits):\n{to_print}")

                record(self.update_query_tracker, loss, output["logits"], labels, len(text))

            # Meta optimizer step
            self.meta_optimizer.step()
            self.meta_optimizer.zero_grad()

    self.logger.debug("-------------------- TRAINING STEP END -------------------")
def few_shot_testing(self, train_dataset, eval_dataset, increment_counters=False, split="test"):
    """
    Few-shot evaluation driven by memory prototypes.

    Each training batch is used as a query batch for one meta-optimizer step.
    From the second step on, the previous batch first acts as a support batch
    whose class means update the memory. After every step the model is
    evaluated and the results are recorded through `self.log_few_shot`.

    Parameters
    ---
    train_dataset: Dataset
        Contains examples on which the model is trained before being evaluated
    eval_dataset: Dataset
        Contains examples on which the model is evaluated
    increment_counters: bool
        Forwarded to `log_few_shot`.
    split: str
        Forwarded to `few_shot_preparation` and `log_few_shot`.
    """
    self.logger.info(f"few shot testing on dataset {self.config.testing.eval_dataset} "
                     f"with {len(train_dataset)} samples")
    train_loader, eval_loader = self.few_shot_preparation(train_dataset, eval_dataset, split=split)
    seen_predictions, seen_labels = [], []

    def add_none(iterator):
        # Prefix the stream with None so it lags one batch behind when zipped.
        yield None
        yield from iterator

    lagged_loader = add_none(train_loader)

    for step, (support_batch, query_batch) in enumerate(zip(lagged_loader, train_loader)):
        query_text, query_labels, datasets = query_batch
        query_labels = torch.tensor(query_labels).to(self.device)

        if support_batch is None:
            # happens on the first one
            prototypes = self.memory.class_representations
        else:
            support_text, support_labels, _ = support_batch
            support_labels = torch.tensor(support_labels).to(self.device)
            support_representations = self.forward(support_text,
                                                   support_labels)["representation"]
            support_class_means, unique_labels = model_utils.get_class_means(
                support_representations, support_labels)
            memory_update = self.memory.update(support_class_means,
                                               unique_labels,
                                               logger=self.logger)
            updated_memory_representations = memory_update["new_class_representations"]
            examples_so_far = (step + 1) * self.config.testing.few_shot_batch_size
            self.log_discounts(memory_update["class_discount"],
                               unique_labels,
                               few_shot_examples_seen=examples_so_far)
            prototypes = updated_memory_representations

        if self.config.learner.few_shot_detach_prototypes:
            prototypes = prototypes.detach()

        # Linear form of the prototype classifier.
        weight = 2 * prototypes
        bias = -(prototypes**2).sum(dim=1)
        query_representations = self.forward(query_text, query_labels)["representation"]
        logits = query_representations @ weight.T + bias
        loss = self.loss_fn(logits, query_labels)

        self.meta_optimizer.zero_grad()
        loss.backward()
        self.meta_optimizer.step()

        seen_predictions.extend(model_utils.make_prediction(logits.detach()).tolist())
        seen_labels.extend(query_labels.tolist())

        eval_results = self.evaluate(dataloader=eval_loader)
        self.log_few_shot(seen_predictions, seen_labels, datasets, eval_results,
                          increment_counters, query_text, step, split=split)

        # Start fresh buffers whenever the examples consumed so far line up
        # with a multiple of the regular mini batch size.
        consumed = step * self.config.testing.few_shot_batch_size
        if consumed % self.mini_batch_size == 0 and step > 0:
            seen_predictions, seen_labels = [], []

    self.few_shot_end()
def training_step(self, support_set, query_set=None, task=None):
    """
    One training step of the prototype learner.

    Class means of the first support batch optionally update the memory. The
    chosen prototypes define a linear classifier, and the encoder receives one
    meta-optimizer update per query batch.

    Parameters
    ---
    support_set: sequence of (text, labels)
        Only the first batch is used, to compute class means.
    query_set: iterable of (text, labels), optional
        Batches on which the loss is computed and back-propagated. When None,
        no parameter update happens.
    task: str, optional
        Task name stored with the online metrics ("none" when not given).

    Raises
    ---
    AssertionError
        If `config.learner.prototypes` is neither "class_means" nor "memory".
    """
    self.inner_optimizer.zero_grad()
    self.logger.debug("-------------------- TRAINING STEP -------------------")

    do_memory_update = self.config.learner.prototype_update_freq > 0 and \
        (self.current_iter % self.config.learner.prototype_update_freq) == 0

    ### GET SUPPORT SET REPRESENTATIONS ###
    self.logger.debug("----------------- SUPPORT SET ----------------- ")
    representations, all_labels = self.get_representations(support_set[:1])
    representations_merged = torch.cat(representations)
    class_means, unique_labels = model_utils.get_class_means(representations_merged, all_labels)
    self._examples_seen += len(representations_merged)
    self.logger.debug(f"Examples seen increased by {len(representations_merged)}")

    ### UPDATE MEMORY ###
    updated_memory_representations = None
    if do_memory_update:
        memory_update = self.memory.update(class_means, unique_labels, logger=self.logger)
        updated_memory_representations = memory_update["new_class_representations"]
        self.log_discounts(memory_update["class_discount"], unique_labels)

    ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
    if self.config.learner.prototypes == "class_means":
        prototypes = expand_class_representations(self.memory.class_representations,
                                                  class_means, unique_labels)
    elif self.config.learner.prototypes == "memory":
        if do_memory_update:
            prototypes = updated_memory_representations
        else:
            # No update happened on this iteration, so there are no fresh
            # representations; use what the memory currently stores instead of
            # failing on an unbound variable.
            prototypes = self.memory.class_representations
    else:
        raise AssertionError(
            "Prototype type not in {'class_means', 'memory'}, fix config file.")

    ### LINEAR LAYER WITH PROTOTYPICAL-EQUIVALENT WEIGHTS ###
    # NOTE(review): an earlier comment mentioned dividing by the number of
    # dimensions to keep the values from blowing up; no such scaling is applied.
    weight = 2 * prototypes
    bias = -(prototypes**2).sum(dim=1)

    self.logger.debug("----------------- QUERY SET ----------------- ")
    ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
    # Outer loop
    if query_set is not None:
        task_name = task if task is not None else "none"
        for text, labels in query_set:
            labels = torch.tensor(labels).to(self.device)
            query_representations = self.forward(text, labels)["representation"]

            # Matches the negative squared distance to each prototype up to a
            # per-example constant (BATCH X N_PROTOTYPES).
            logits = query_representations @ weight.T + bias
            loss = self.loss_fn(logits, labels)

            # The meta optimizer is stepped once per query batch.
            self.meta_optimizer.zero_grad()
            loss.backward()
            self.meta_optimizer.step()

            predictions = model_utils.make_prediction(logits.detach())
            self.logger.debug(f"Unique Labels: {unique_labels.tolist()}\n"
                              f"Loss:\n{loss.item()}\n"
                              f"Predictions:\n{predictions}\n")

            self.update_query_tracker(loss, predictions, labels)
            metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            online_metrics = {
                "accuracy": metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": task_name
            }
            self.metrics["online"].append(online_metrics)
            if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                self.metrics["eval_task_first_encounter"].append(online_metrics)
            self._examples_seen += len(text)
            self.logger.debug(f"Examples seen increased by {len(text)}")

    self.logger.debug("-------------------- TRAINING STEP END -------------------")