예제 #1
0
 def _trainer_specific_loss(self, unlab_img: Tensor, unlab_gt: Tensor,
                            **kwargs) -> Tensor:
     """Consistency regularizer on an unlabeled batch plus an entropy term.

     The untransformed images and their affine-transformed copies go through
     the model in one concatenated forward pass. ``self.kl_criterion`` then
     compares the prediction on the transformed view with the detached
     prediction on the untransformed view, and 0.1 times
     ``self.entropy_entropy`` of the untransformed-view prediction is added.

     The ``"uda_reg"`` meter records the combined value (entropy included).

     :param unlab_img: unlabeled image batch; moved to ``self._device`` here.
     :param unlab_gt: labels for the batch, used only by the ``"unl_acc"``
         meter, never by the returned loss.
     :param kwargs: ignored.
     :return: the combined regularization tensor.
     """
     batch = unlab_img.to(self._device)
     # only the transformed images are needed from affine_transform
     transformed, _ = self.affine_transform(batch)
     # a single forward pass for both views; chunk splits it back in two halves
     stacked = torch.cat([batch, transformed], dim=0)
     base_pred, transformed_pred = torch.chunk(self._model(stacked), 2)
     assert simplex(base_pred) and simplex(transformed_pred)
     # the untransformed-view prediction is a fixed target (no gradient)
     loss = self.kl_criterion(transformed_pred, base_pred.detach())
     ent = self.entropy_entropy(base_pred)
     loss += ent * 0.1
     self.METERINTERFACE["uda_reg"].add(loss.item())
     # NOTE(review): unlab_gt is not moved to self._device while the argmax
     # predictions live there — confirm the meter tolerates the mismatch.
     predicted_labels = base_pred.max(1)[1]
     self.METERINTERFACE["unl_acc"].add(predicted_labels, unlab_gt)
     self.METERINTERFACE["entropy"].add(ent.item())
     return loss
예제 #2
0
 def _trainer_specific_loss(self, unlab_img: Tensor, **kwargs) -> Tensor:
     """Model-side Lagrangian term for pulling the batch marginal to the prior.

     With the multiplier ``self.mu`` held fixed (detached), computes
     ``sum(prior * (marginal * mu + 1 + log(-mu)))``, where ``marginal`` is
     the mean prediction over the batch. It then adds 0.1 times
     ``self.entropy`` of the predictions.

     ``log(-mu)`` is only finite when every entry of ``mu`` is negative.

     :param unlab_img: unlabeled image batch; moved to ``self._device`` here.
     :param kwargs: ignored.
     :return: scalar loss tensor.
     """
     preds = self._model(unlab_img.to(self._device))
     assert simplex(preds, 1)
     # mu is a constant for this loss; this method does not update it
     fixed_mu = self.mu.detach()
     batch_marginal = preds.mean(0)
     per_class = self.prior * (
         batch_marginal * fixed_mu + 1 + (-fixed_mu).log()
     )
     loss = per_class.sum()
     cond_entropy = self.entropy(preds)
     self.METERINTERFACE["centropy"].add(cond_entropy.item())
     loss += cond_entropy * 0.1
     return loss
예제 #3
0
    def _update_mu(self, unlab_img: Tensor):
        """Take one optimizer step on the Lagrange multiplier ``self.mu``.

        The model's predictions are treated as constants, so only ``self.mu``
        receives gradients. ``self.mu_optim`` minimizes the negated Lagrangian
        ``-sum(prior * (marginal * mu + 1 + log(-mu)))``, so the step moves
        ``mu`` in the direction that increases the Lagrangian.

        Side effects:
            * updates ``self.mu``;
            * records the L1 norm of the gradient w.r.t. ``mu`` in the
              ``"residual"`` meter;
            * records ``self.kl_criterion(marginal, prior)`` in the
              ``"marginal"`` meter, for monitoring only.

        :param unlab_img: unlabeled image batch; moved to ``self._device`` here.
        """
        import torch  # local import keeps this method self-contained

        self.mu_optim.zero_grad()
        unlab_img = unlab_img.to(self._device)
        # The predictions are constants for this update. Skip recording the
        # autograd graph rather than building it and discarding it with
        # .detach(); the values are identical, and the activations are not
        # kept alive.
        with torch.no_grad():
            unlabeled_preds = self._model(unlab_img)
        assert simplex(unlabeled_preds, 1)
        marginal = unlabeled_preds.mean(0)
        # negated so that the optimizer's descent step increases the Lagrangian
        lagrangian = -1 * (
            self.prior * (marginal * self.mu + 1 + (-self.mu).log())
        ).sum()
        lagrangian.backward()
        self.mu_optim.step()

        # step() does not clear .grad, so this reads this step's gradient
        self.METERINTERFACE["residual"].add(self.mu.grad.abs().sum().item())
        # to quantify how far the batch marginal is from the prior:
        marginal_loss = self.kl_criterion(marginal.unsqueeze(0),
                                          self.prior.unsqueeze(0),
                                          disable_assert=True)
        self.METERINTERFACE["marginal"].add(marginal_loss.item())
예제 #4
0
    def _trainer_specific_loss(self, unlab_img: Tensor, *args,
                               **kwargs) -> Tensor:
        """Marginal-vs-prior loss on an unlabeled batch plus an entropy term.

        Compares the batch-mean prediction with ``self.prior`` through
        ``self.kl_criterion``:
            * by default, as ``kl_criterion(marginal, prior)``;
            * when ``self.inverse_kl`` is truthy, as
              ``kl_criterion(prior, marginal)`` with the criterion's assertion
              disabled.

        That value is logged to the ``"marginal"`` meter before 0.1 times
        ``self.entropy`` of the predictions is added to it.

        :param unlab_img: unlabeled image batch; moved to ``self._device`` here.
        :param args: ignored.
        :param kwargs: ignored.
        :return: the combined loss tensor.
        """
        preds = self._model(unlab_img.to(self._device))
        assert simplex(preds, 1)
        # add a leading batch dimension so both operands are (1, C)-like
        mean_pred = preds.mean(0).unsqueeze(0)
        target = self.prior.unsqueeze(0)
        if self.inverse_kl:
            loss = self.kl_criterion(target, mean_pred, disable_assert=True)
        else:
            loss = self.kl_criterion(mean_pred, target)

        self.METERINTERFACE["marginal"].add(loss.item())
        cond_entropy = self.entropy(preds)
        loss += cond_entropy * 0.1
        self.METERINTERFACE["centropy"].add(cond_entropy.item())
        return loss
예제 #5
0
 def __init__(
     self,
     model: Model,
     labeled_loader: DataLoader,
     unlabeled_loader: DataLoader,
     val_loader: DataLoader,
     max_epoch: int = 100,
     save_dir: str = "base",
     checkpoint_path: str = None,
     device="cpu",
     config: dict = None,
     max_iter: int = 100,
     prior: Tensor = None,
     inverse_kl=False,
     **kwargs,
 ) -> None:
     """Trainer whose unlabeled-data loss is built around a fixed class prior.

     All arguments except ``prior`` and ``inverse_kl`` are passed to the base
     class constructor: the named ones positionally, ``kwargs`` by keyword.

     :param prior: the predefined prior. It is required despite the ``None``
         default. It must be a ``Tensor`` accepted by ``simplex(prior, 0)``
         (presumably non-negative entries summing to one along dim 0). It is
         stored on the trainer's device as ``self.prior``.
     :param inverse_kl: stored as ``self.inverse_kl``. Presumably it selects
         the argument order of the KL term in the training loss; confirm in
         the loss code.
     :param kwargs: forwarded unchanged to the base class constructor.
     :raises AssertionError: if ``prior`` is not a ``Tensor`` or fails the
         simplex check.
     """
     # Validate the required prior before the base class does any setup
     # work, so a missing or invalid prior fails fast with a clear message.
     assert isinstance(prior, Tensor), (
         f"`prior` must be provided as a Tensor, got {type(prior).__name__}."
     )
     assert simplex(prior, 0), "`prior` provided must be simplex."
     super().__init__(
         model,
         labeled_loader,
         unlabeled_loader,
         val_loader,
         max_epoch,
         save_dir,
         checkpoint_path,
         device,
         config,
         max_iter,
         **kwargs,
     )
     # self._device is only available once the base constructor has run
     self.prior = prior.to(self._device)
     self.entropy = Entropy()
     self.inverse_kl = inverse_kl