Code example #1
    def train_on_batch(self, batch):
        self.model.train()

        x = torch.from_numpy(np.array(batch[0])).to(self.device_name)
        y = torch.from_numpy(np.array(batch[1])).to(self.device_name)

        logits = self.model(x)
        pruned_logits = prune_logits(logits, self.output_mask)

        cls_loss = F.cross_entropy(pruned_logits, y)
        cls_acc = compute_accuracy(pruned_logits, y)

        # Clone so that the in-place `+=` updates below do not mutate
        # `cls_loss`, which is logged after the optimization step.
        total_loss = cls_loss.clone()

        lowres_cfg = self.config.hp.lowres_training

        if lowres_cfg.loss_coef > 0 or lowres_cfg.logits_matching_loss_coef > 0:
            # Additionally run the model on a downsampled copy of the batch.
            x_lowres = self.transform_em_sample(x, no_grad=False)
            logits_lowres = self.model(x_lowres)
            pruned_logits_lowres = prune_logits(logits_lowres,
                                                self.output_mask)
            cls_loss_lowres = F.cross_entropy(pruned_logits_lowres, y)
            cls_acc_lowres = compute_accuracy(pruned_logits_lowres, y)

            self.writer.add_scalar('train/cls_loss_lowres',
                                   cls_loss_lowres.item(), self.num_iters_done)
            self.writer.add_scalar('train/cls_acc_lowres',
                                   cls_acc_lowres.item(), self.num_iters_done)

        if lowres_cfg.loss_coef > 0:
            total_loss += lowres_cfg.loss_coef * cls_loss_lowres

        if lowres_cfg.logits_matching_loss_coef > 0:
            # Pull the low-resolution logits towards the full-resolution ones.
            logits_matching_loss = F.mse_loss(logits, logits_lowres)
            total_loss += lowres_cfg.logits_matching_loss_coef * logits_matching_loss

            self.writer.add_scalar('train/logits_matching_loss',
                                   logits_matching_loss.item(),
                                   self.num_iters_done)

        if self.task_idx > 0:
            rehearsal_loss, rehearsal_acc = self.compute_rehearsal_loss()
            total_loss += self.config.hp.memory.loss_coef * rehearsal_loss

            self.writer.add_scalar('train/rehearsal_loss',
                                   rehearsal_loss.item(), self.num_iters_done)
            self.writer.add_scalar('train/rehearsal_acc', rehearsal_acc.item(),
                                   self.num_iters_done)

        self.optim.zero_grad()
        total_loss.backward()
        self.optim.step()

        self.writer.add_scalar('train/cls_loss', cls_loss.item(),
                               self.num_iters_done)
        self.writer.add_scalar('train/cls_acc', cls_acc.item(),
                               self.num_iters_done)
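
All of these snippets call `prune_logits` and `compute_accuracy` without defining them. Below is a minimal sketch of what they plausibly do, assuming `output_mask` is a boolean numpy array over the class dimension; the actual implementations in the source repository may differ.

import numpy as np
import torch
from torch import Tensor


def prune_logits(logits: Tensor, output_mask: np.ndarray) -> Tensor:
    # Fill the logits of masked-out classes with a large negative value
    # so that softmax/argmax effectively ignore them.
    pruned = torch.full_like(logits, -1e9)
    pruned[:, output_mask] = logits[:, output_mask]
    return pruned


def compute_accuracy(logits: Tensor, targets: Tensor) -> Tensor:
    # Fraction of samples whose top-scoring class matches the target.
    return (logits.argmax(dim=1) == targets).float().mean()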
Code example #2
    def compute_scores(self, dataset: str = 'val'):
        """Computes GZSL/ZSL metrics on the given split and returns them,
        in percent, as [seen_acc, unseen_acc, harmonic, zsl_acc, ausuc]."""
        self.model.eval()

        if dataset == 'val':
            # GZSL metrics
            logits = self.run_inference(self.val_dataloader,
                                        scope=self.val_scope)
            preds = logits.argmax(dim=1).numpy()
            guessed = (preds == self.val_labels)
            seen_acc = guessed[self.val_pseudo_seen_idx].mean()
            unseen_acc = guessed[self.val_pseudo_unseen_idx].mean()
            harmonic = 2 * (seen_acc * unseen_acc) / (seen_acc + unseen_acc)

            # ZSL
            zsl_logits = prune_logits(logits, self.pseudo_unseen_mask)
            zsl_preds = zsl_logits.argmax(dim=1).numpy()
            zsl_guessed = (zsl_preds == self.val_labels)
            zsl_acc = zsl_guessed[self.val_pseudo_unseen_idx].mean()

            # AUSUC
            if self.config.get('logging.compute_ausuc'):
                ausuc = compute_ausuc(logits, self.val_labels,
                                      self.train_seen_mask) * 0.01
            else:
                ausuc = 0
        elif dataset == 'test':
            # GZSL metrics
            logits = self.run_inference(self.test_dataloader, scope='all')
            preds = logits.argmax(dim=1).numpy()
            guessed = (preds == self.test_labels)
            seen_acc = guessed[self.test_seen_idx].mean()
            unseen_acc = guessed[self.test_unseen_idx].mean()
            harmonic = 2 * (seen_acc * unseen_acc) / (seen_acc + unseen_acc)

            # ZSL
            zsl_logits = prune_logits(logits, self.unseen_mask)
            zsl_preds = zsl_logits.argmax(dim=1).numpy()
            zsl_guessed = (zsl_preds == self.test_labels)
            zsl_acc = zsl_guessed[self.test_unseen_idx].mean()

            # AUSUC
            if self.config.get('logging.compute_ausuc'):
                ausuc = compute_ausuc(logits, self.test_labels,
                                      self.seen_mask) * 0.01
            else:
                ausuc = 0
        else:
            raise ValueError(f"Wrong dataset for GZSL scores: {dataset}")

        return 100 * np.array([seen_acc, unseen_acc, harmonic, zsl_acc, ausuc])
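
A hypothetical usage of `compute_scores` (`trainer` stands for an instance of the surrounding class; the unpacking order follows the return statement above):

# All five metrics come back in percent, as a single numpy array.
seen_acc, unseen_acc, harmonic, zsl_acc, ausuc = trainer.compute_scores('test')
print(f'GZSL harmonic mean: {harmonic:.2f}%, ZSL accuracy: {zsl_acc:.2f}%')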
Code example #3
    def compute_rehearsal_loss(self):
        """Samples a batch from episodic memory and computes the loss and
        accuracy over all classes learned so far."""
        x, y = self.sample_from_memory(self.config.hp.memory.batch_size)
        pruned_logits = prune_logits(self.model(x), self.learned_classes_mask)
        cls_loss = F.cross_entropy(pruned_logits, y)
        cls_acc = compute_accuracy(pruned_logits, y)

        return cls_loss, cls_acc
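
`compute_rehearsal_loss` relies on a `sample_from_memory` helper that is not shown. A plausible sketch, assuming the episodic memory is kept as a pair of numpy arrays (`memory_x` and `memory_y` are hypothetical names, not from the source):

    def sample_from_memory(self, batch_size):
        # Draw a random batch of stored exemplars from past tasks.
        idx = np.random.choice(len(self.memory_x), size=batch_size)
        x = torch.from_numpy(self.memory_x[idx]).to(self.device_name)
        y = torch.from_numpy(self.memory_y[idx]).to(self.device_name)
        return x, y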
Code example #4
    def train_on_batch(self, batch):
        self.model.train()

        x = torch.from_numpy(np.array(batch[0])).to(self.device_name)
        y = torch.from_numpy(np.array(batch[1])).to(self.device_name)

        logits = self.model(x)
        pruned_logits = prune_logits(logits, self.output_mask)

        cls_loss = F.cross_entropy(pruned_logits, y)
        cls_acc = compute_accuracy(pruned_logits, y)

        # Clone so that the in-place `+=` update below does not mutate
        # `cls_loss`, which is logged after the optimization step.
        total_loss = cls_loss.clone()

        if self.task_idx > 0:
            rehearsal_loss, rehearsal_acc = self.compute_rehearsal_loss()
            total_loss += self.config.hp.memory.loss_coef * rehearsal_loss

            self.writer.add_scalar('train/rehearsal_loss',
                                   rehearsal_loss.item(), self.num_iters_done)
            self.writer.add_scalar('train/rehearsal_acc', rehearsal_acc.item(),
                                   self.num_iters_done)

        self.optim.zero_grad()
        total_loss.backward()
        self.optim.step()

        self.writer.add_scalar('train/cls_loss', cls_loss.item(),
                               self.num_iters_done)
        self.writer.add_scalar('train/cls_acc', cls_acc.item(),
                               self.num_iters_done)
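
Example #4 is the same training step as example #1 minus the low-resolution branch. In both, `output_mask` presumably restricts the classifier head to the current task's classes; a hypothetical construction (`num_classes` and `task_classes` are illustrative names, not from the source):

# Hypothetical: mark only the current task's classes as active outputs.
output_mask = np.zeros(num_classes, dtype=bool)
output_mask[task_classes] = True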
Code example #5
    def train_on_batch(self, batch):
        self.model.train()

        x = torch.from_numpy(np.array(batch[0])).to(self.device_name)
        y = torch.from_numpy(np.array(batch[1])).to(self.device_name)

        logits = prune_logits(self.model(x), self.output_mask)
        loss = self.criterion(logits, y)

        self.optim.zero_grad()
        loss.backward()
        if self.config.hp.get('clip_grad.value', float('inf')) < float('inf'):
            grad_norm = clip_grad_norm_(self.model.parameters(),
                                        self.config.hp.clip_grad.value)
            self.writer.add_scalar('cls/grad_norm', grad_norm,
                                   self.num_iters_done)
        self.optim.step()

        self.writer.add_scalar('loss', loss.item(), self.num_iters_done)
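
`clip_grad_norm_` is used above without an import; it is presumably PyTorch's built-in gradient clipping utility:

from torch.nn.utils import clip_grad_norm_  # assumed source of clip_grad_norm_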
Code example #6
def compute_grad(model: nn.Module, criterion: nn.Module,
                 dataloader: DataLoader, output_mask: np.ndarray,
                 elementwise_grad_norm: str) -> Tensor:
    """
    Computes gradient of the given loss across the dataset

    :param model:
    :param dataloader:
    :return:
    """
    num_samples = 0
    num_params = sum(p.numel() for p in model.parameters())
    device = get_module_device(model)
    grad = torch.zeros(num_params).to(device)

    for x, y in dataloader:
        x = torch.from_numpy(np.array(x)).to(device)
        y = torch.tensor(y).to(device)
        logits = model(x)
        pruned_logits = prune_logits(logits, output_mask)
        loss = criterion(pruned_logits, y)

        model.zero_grad()
        loss.backward()
        curr_grad = torch.cat(
            [get_grad(p).view(-1) for p in model.parameters()])

        if elementwise_grad_norm == 'square':
            curr_grad = curr_grad.pow(2)
        elif elementwise_grad_norm == 'abs':
            curr_grad = curr_grad.abs()
        else:
            raise NotImplementedError(
                f'Unknown elementwise grad norm: {elementwise_grad_norm}')

        grad += curr_grad
        num_samples += len(x)

    return grad / num_samples
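
With `elementwise_grad_norm='square'`, `compute_grad` accumulates squared per-parameter gradients, i.e., a diagonal-Fisher-style importance estimate of the kind used by EWC-like regularizers. A hypothetical call (`train_dataloader` is an assumed attribute name):

# Hypothetical usage: per-parameter importance weights for the current task.
param_importance = compute_grad(
    model=self.model,
    criterion=nn.CrossEntropyLoss(),
    dataloader=self.train_dataloader,
    output_mask=self.output_mask,
    elementwise_grad_norm='square')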
Code example #7
    def compute_pruned_predictions(self, x: Tensor,
                                   output_mask: np.ndarray) -> Tensor:
        return prune_logits(self.forward(x), output_mask)
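
A hypothetical usage at evaluation time, assuming `model` is an instance of the class that defines this method:

with torch.no_grad():
    preds = model.compute_pruned_predictions(x, output_mask).argmax(dim=1)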