def train_on_batch(self, batch):
    self.model.train()

    x = torch.from_numpy(np.array(batch[0])).to(self.device_name)
    y = torch.from_numpy(np.array(batch[1])).to(self.device_name)

    logits = self.model(x)
    pruned_logits = prune_logits(logits, self.output_mask)
    cls_loss = F.cross_entropy(pruned_logits, y)
    cls_acc = compute_accuracy(pruned_logits, y)
    total_loss = cls_loss

    if self.config.hp.lowres_training.loss_coef > 0 or self.config.hp.lowres_training.logits_matching_loss_coef > 0:
        # Second forward pass on a downsampled copy of the batch.
        x_lowres = self.transform_em_sample(x, no_grad=False)
        logits_lowres = self.model(x_lowres)
        pruned_logits_lowres = prune_logits(logits_lowres, self.output_mask)
        cls_loss_lowres = F.cross_entropy(pruned_logits_lowres, y)
        cls_acc_lowres = compute_accuracy(pruned_logits_lowres, y)

        self.writer.add_scalar('train/cls_loss_lowres', cls_loss_lowres.item(), self.num_iters_done)
        self.writer.add_scalar('train/cls_acc_lowres', cls_acc_lowres.item(), self.num_iters_done)

        if self.config.hp.lowres_training.loss_coef > 0:
            total_loss += self.config.hp.lowres_training.loss_coef * cls_loss_lowres

        if self.config.hp.lowres_training.logits_matching_loss_coef > 0:
            logits_matching_loss = F.mse_loss(logits, logits_lowres)
            total_loss += self.config.hp.lowres_training.logits_matching_loss_coef * logits_matching_loss
            self.writer.add_scalar('train/logits_matching_loss', logits_matching_loss.item(), self.num_iters_done)

    if self.task_idx > 0:
        # Rehearsal on examples replayed from episodic memory.
        rehearsal_loss, rehearsal_acc = self.compute_rehearsal_loss()
        total_loss += self.config.hp.memory.loss_coef * rehearsal_loss
        self.writer.add_scalar('train/rehearsal_loss', rehearsal_loss.item(), self.num_iters_done)
        self.writer.add_scalar('train/rehearsal_acc', rehearsal_acc.item(), self.num_iters_done)

    self.optim.zero_grad()
    total_loss.backward()
    self.optim.step()

    self.writer.add_scalar('train/cls_loss', cls_loss.item(), self.num_iters_done)
    self.writer.add_scalar('train/cls_acc', cls_acc.item(), self.num_iters_done)
def compute_scores(self, dataset: str = 'val'):
    """
    Computes GZSL metrics (seen/unseen/harmonic accuracy), ZSL accuracy and AUSUC
    on the given split, returned as percentages.
    """
    self.model.eval()

    if dataset == 'val':
        # GZSL metrics
        logits = self.run_inference(self.val_dataloader, scope=self.val_scope)
        preds = logits.argmax(dim=1).numpy()
        guessed = (preds == self.val_labels)
        seen_acc = guessed[self.val_pseudo_seen_idx].mean()
        unseen_acc = guessed[self.val_pseudo_unseen_idx].mean()
        harmonic = 2 * (seen_acc * unseen_acc) / (seen_acc + unseen_acc)

        # ZSL
        zsl_logits = prune_logits(logits, self.pseudo_unseen_mask)
        zsl_preds = zsl_logits.argmax(dim=1).numpy()
        zsl_acc = (zsl_preds == self.val_labels)[self.val_pseudo_unseen_idx].mean()

        # AUSUC
        if self.config.get('logging.compute_ausuc'):
            ausuc = compute_ausuc(logits, self.val_labels, self.train_seen_mask) * 0.01
        else:
            ausuc = 0
    elif dataset == 'test':
        # GZSL metrics
        logits = self.run_inference(self.test_dataloader, scope='all')
        preds = logits.argmax(dim=1).numpy()
        guessed = (preds == self.test_labels)
        seen_acc = guessed[self.test_seen_idx].mean()
        unseen_acc = guessed[self.test_unseen_idx].mean()
        harmonic = 2 * (seen_acc * unseen_acc) / (seen_acc + unseen_acc)

        # ZSL
        zsl_logits = prune_logits(logits, self.unseen_mask)
        zsl_preds = zsl_logits.argmax(dim=1).numpy()
        zsl_acc = (zsl_preds == self.test_labels)[self.test_unseen_idx].mean()

        # AUSUC
        if self.config.get('logging.compute_ausuc'):
            ausuc = compute_ausuc(logits, self.test_labels, self.seen_mask) * 0.01
        else:
            ausuc = 0
    else:
        raise ValueError(f"Wrong dataset for GZSL scores: {dataset}")

    return 100 * np.array([seen_acc, unseen_acc, harmonic, zsl_acc, ausuc])
def compute_rehearsal_loss(self):
    x, y = self.sample_from_memory(self.config.hp.memory.batch_size)
    pruned_logits = prune_logits(self.model(x), self.learned_classes_mask)
    cls_loss = F.cross_entropy(pruned_logits, y)
    cls_acc = compute_accuracy(pruned_logits, y)

    return cls_loss, cls_acc
def train_on_batch(self, batch):
    self.model.train()

    x = torch.from_numpy(np.array(batch[0])).to(self.device_name)
    y = torch.from_numpy(np.array(batch[1])).to(self.device_name)

    logits = self.model(x)
    pruned_logits = prune_logits(logits, self.output_mask)
    cls_loss = F.cross_entropy(pruned_logits, y)
    cls_acc = compute_accuracy(pruned_logits, y)
    total_loss = cls_loss

    if self.task_idx > 0:
        rehearsal_loss, rehearsal_acc = self.compute_rehearsal_loss()
        total_loss += self.config.hp.memory.loss_coef * rehearsal_loss
        self.writer.add_scalar('train/rehearsal_loss', rehearsal_loss.item(), self.num_iters_done)
        self.writer.add_scalar('train/rehearsal_acc', rehearsal_acc.item(), self.num_iters_done)

    self.optim.zero_grad()
    total_loss.backward()
    self.optim.step()

    self.writer.add_scalar('train/cls_loss', cls_loss.item(), self.num_iters_done)
    self.writer.add_scalar('train/cls_acc', cls_acc.item(), self.num_iters_done)
def train_on_batch(self, batch):
    self.model.train()

    x = torch.from_numpy(np.array(batch[0])).to(self.device_name)
    y = torch.from_numpy(np.array(batch[1])).to(self.device_name)

    logits = prune_logits(self.model(x), self.output_mask)
    loss = self.criterion(logits, y)

    self.optim.zero_grad()
    loss.backward()

    # Optional gradient clipping; `clip_grad_norm_` returns the total norm
    # computed before clipping, which we log for monitoring.
    if self.config.hp.get('clip_grad.value', float('inf')) < float('inf'):
        grad_norm = clip_grad_norm_(self.model.parameters(), self.config.hp.clip_grad.value)
        self.writer.add_scalar('cls/grad_norm', grad_norm, self.num_iters_done)

    self.optim.step()

    self.writer.add_scalar('loss', loss.item(), self.num_iters_done)
def compute_grad(model: nn.Module, criterion: nn.Module, dataloader: DataLoader,
                 output_mask: np.ndarray, elementwise_grad_norm: str) -> Tensor:
    """
    Computes the gradient of the given loss accumulated over the whole dataset.

    :param model: model whose parameter gradients are accumulated
    :param criterion: loss function applied to the pruned logits
    :param dataloader: dataloader yielding (x, y) batches to iterate over
    :param output_mask: boolean mask selecting which output logits are active
    :param elementwise_grad_norm: elementwise norm applied to each batch gradient, either 'square' or 'abs'
    :return: flat tensor of accumulated per-parameter gradient norms, averaged over the number of samples
    """
    num_samples = 0
    num_params = sum(p.numel() for p in model.parameters())
    device = get_module_device(model)
    grad = torch.zeros(num_params).to(device)

    for x, y in dataloader:
        x = torch.from_numpy(np.array(x)).to(device)
        y = torch.tensor(y).to(device)

        logits = model(x)
        pruned_logits = prune_logits(logits, output_mask)
        loss = criterion(pruned_logits, y)

        model.zero_grad()
        loss.backward()

        curr_grad = torch.cat([get_grad(p).view(-1) for p in model.parameters()])

        if elementwise_grad_norm == 'square':
            curr_grad = curr_grad.pow(2)
        elif elementwise_grad_norm == 'abs':
            curr_grad = curr_grad.abs()
        else:
            raise NotImplementedError(f'Unknown elementwise grad norm: {elementwise_grad_norm}')

        grad += curr_grad
        num_samples += len(x)

    return grad / num_samples
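# A hypothetical usage sketch for `compute_grad` (the names `model`, `train_loader`
# and `mask` are illustrative, not from this repo): with the 'square' elementwise
# norm, the result is a diagonal-Fisher-style per-parameter importance estimate,
# of the kind used by EWC-like regularizers.
#
#     param_importance = compute_grad(model, F.cross_entropy, train_loader, mask, 'square')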
def compute_pruned_predictions(self, x: Tensor, output_mask: np.ndarray) -> Tensor:
    return prune_logits(self.forward(x), output_mask)
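# The helpers `prune_logits` and `compute_accuracy` are used throughout this file
# but defined elsewhere in the repo. The sketches below are assumptions about their
# behavior (hence the `_sketch` suffix), not the actual implementations: logits of
# inactive classes are presumably pushed to a large negative value so they cannot
# win the argmax or receive probability mass, and accuracy is presumably mean top-1.
def prune_logits_sketch(logits: Tensor, output_mask: np.ndarray) -> Tensor:
    # Broadcast the [num_classes] boolean mask over the [batch, num_classes] logits.
    mask = torch.from_numpy(output_mask.astype(bool)).to(logits.device)
    return logits.masked_fill(~mask, -1e9)


def compute_accuracy_sketch(logits: Tensor, targets: Tensor) -> Tensor:
    # Mean top-1 accuracy over the batch.
    return (logits.argmax(dim=1) == targets).float().mean()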