Exemplo n.º 1
0
        target = self.to_onehot(target, self.smoothing)
        s_loss = self.loss_f[0](output, target)
        return output, s_loss

    def unlabeled(self,
                  input: torch.Tensor) -> (None, torch.Tensor, torch.Tensor):
        """Compute the VAT and entropy-minimisation losses for an unlabeled batch.

        Everything runs inside ``disable_bn_stats(self.model)``.

        Returns:
            ``(None, vat_loss, entropy_loss)``:
            - ``vat_loss`` is ``self.vat_loss`` evaluated against a detached copy
              of the model's logits.
            - ``entropy_loss`` is the mean entropy of the categorical
              distribution defined by those logits.
        """
        with disable_bn_stats(self.model):
            predictions = self.model(input)
            # The VAT target must not receive gradients, so hand over a
            # detached copy of the clean logits.
            frozen_logits = predictions.clone().detach()
            consistency = self.vat_loss(input, frozen_logits)
            entropy = Categorical(logits=predictions).entropy().mean()
        return None, consistency, entropy

    def vat_loss(self, input: torch.Tensor,
                 logits: torch.Tensor) -> torch.Tensor:
        """Virtual adversarial loss using one power-iteration step.

        1. Draw a random direction shaped like ``input`` and normalize it.
        2. Probe the model at ``input + self.xi * direction``.
        3. Take the gradient of the divergence to ``logits`` w.r.t. the
           direction, normalize it, and evaluate the model again at
           ``input + self.eps * direction``.

        ``normalize`` and ``kl_div`` are defined elsewhere in the module. The
        assumption is that ``normalize`` rescales per sample, and that
        ``kl_div(target_logits, pred_logits)`` returns a scalar divergence.
        TODO confirm.

        Returns:
            The divergence between ``logits`` and the prediction at the
            adversarially perturbed input.
        """
        noise = input.clone().normal_()
        direction = normalize(noise)
        direction.requires_grad_(True)

        # Power-iteration probe: a small step of size xi along the random direction.
        probe_logits = self.model(input + self.xi * direction)
        probe_divergence = kl_div(logits, probe_logits)
        (direction_grad,) = torch.autograd.grad([probe_divergence], [direction])

        adversarial_direction = normalize(direction_grad)
        # Clear any parameter gradients before the final adversarial pass.
        self.model.zero_grad()
        adversarial_logits = self.model(input + self.eps * adversarial_direction)
        return kl_div(logits, adversarial_logits)


if __name__ == "__main__":
    import hydra

    # Build the hydra entry point, wrap the VAT task (with its two loss
    # functions) and run it.
    _entry = hydra.main('config/vat.yaml')
    _task = get_task(VATTrainer,
                     [cross_entropy_with_softlabels, F.cross_entropy])
    _entry(_task)()
Exemplo n.º 2
0
import torch
from homura.modules import cross_entropy_with_softlabels
from torch.nn import functional as F

from backends.utils import SSLTrainerBase, disable_bn_stats, get_task


class PseudoLabelTrainer(SSLTrainerBase):
    """Pseudo-label trainer.

    On unlabeled data, the model's own argmax prediction is used as a hard
    target, kept only for samples whose predicted probability exceeds
    ``self.threshold``.
    """

    def labeled(self, input: torch.Tensor,
                target: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(logits, supervised_loss)`` for a labeled batch.

        ``target`` is converted with ``self.to_onehot(target, self.smoothing)``.
        The converted target and the logits are passed to ``self.loss_f``.
        """
        target = self.to_onehot(target, self.smoothing)
        output = self.model(input)
        loss = self.loss_f(output, target)
        return output, loss

    def unlabeled(self, input: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(logits, pseudo_label_loss)`` for an unlabeled batch.

        The forward pass runs inside ``disable_bn_stats(self.model)``. Each
        sample's cross-entropy against its own argmax label is counted only
        when some class probability exceeds ``self.threshold``. The masked
        per-sample losses are averaged over the whole batch.
        """
        with disable_bn_stats(self.model):
            u_output = self.model(input)
        u_loss = F.cross_entropy(u_output,
                                 u_output.argmax(dim=1),
                                 reduction='none')
        # Confidence mask: threshold the predicted probability, NOT the loss.
        # Thresholding the loss (the previous behaviour) kept exactly the
        # least-confident predictions, which inverts the pseudo-label idea.
        confident = (u_output.softmax(dim=1) > self.threshold).any(dim=1)
        u_loss = (confident.float() * u_loss).mean()
        return u_output, u_loss


if __name__ == "__main__":
    import hydra

    # Hydra entry point for the pseudo-label task, which uses the soft-label
    # cross-entropy as its loss.
    _entry = hydra.main('config/pseudo_label.yaml')
    _task = get_task(PseudoLabelTrainer, cross_entropy_with_softlabels)
    _entry(_task)()
Exemplo n.º 3
0
import torch
from homura.modules import cross_entropy_with_softlabels, to_onehot
from torch.nn import functional as F

from backends.utils import SSLTrainerBase, disable_bn_stats, get_task


class MeanTeacherTrainer(SSLTrainerBase):
    """Mean Teacher trainer.

    The student (``self.model``) is pulled towards the predictions of
    ``self.ema`` on unlabeled data.
    """

    def labeled(self, input: torch.Tensor,
                target: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(logits, supervised_loss)`` using label-smoothed one-hot targets.

        The one-hot target ``t`` becomes ``(1 - s) * t + s / K``, where
        ``s = self.smoothing`` and ``K = self.num_classes``.
        """
        logits = self.model(input)
        smoothed = to_onehot(target, self.num_classes)
        uniform = 1 / self.num_classes
        smoothed -= self.smoothing * (smoothed - uniform)
        return logits, self.loss_f(logits, smoothed)

    def unlabeled(self, input1: torch.Tensor,
                  input2: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(student_logits, consistency_loss)``.

        The consistency loss is the MSE between the softmax of the student
        output on ``input1`` and the softmax of the teacher (``self.ema``)
        output on ``input2``. The teacher is evaluated without gradients.
        """
        with disable_bn_stats(self.model):
            student_logits = self.model(input1)
            with torch.no_grad():
                teacher_logits = self.ema(input2)
        consistency = F.mse_loss(student_logits.softmax(dim=1),
                                 teacher_logits.softmax(dim=1))
        return student_logits, consistency


if __name__ == "__main__":
    import hydra

    # Hydra entry point for the Mean Teacher task.
    _entry = hydra.main('config/mean_teacher.yaml')
    _task = get_task(MeanTeacherTrainer, cross_entropy_with_softlabels)
    _entry(_task)()
Exemplo n.º 4
0
        loss = self.loss_f[0](output, target)
        return output, loss

    def unlabeled(self, input: torch.Tensor):
        """ICT-style interpolation consistency on an unlabeled batch.

        Without gradients:
        - compute the teacher (``self.ema``) class probabilities;
        - mix both the input and those probabilities with ``self.mixup``.

        The student (``self.model``) then predicts on the mixed input.

        Returns:
            ``(student_logits, loss)``, where ``loss`` is
            ``self.loss_f[1](student_logits, mixed_probabilities)``.
        """
        with torch.no_grad():
            teacher_probs = self.ema(input).softmax(dim=-1)
            # NOTE: the mixup in this class mixes its arguments in place.
            mixed_input, mixed_probs = self.mixup(input, teacher_probs)
        student_logits = self.model(mixed_input)
        return student_logits, self.loss_f[1](student_logits, mixed_probs)

    def mixup(self, input: torch.Tensor, target: torch.Tensor):
        """Mix every sample with a randomly permuted partner from the same batch.

        A per-sample coefficient ``gamma ~ Beta(self.beta, self.beta)`` is
        drawn, and both tensors are updated as
        ``x <- gamma * x + (1 - gamma) * x[perm]``.

        The coefficient is broadcast to the rank of each tensor, so any input
        shape ``(B, ...)`` and target shape ``(B, ...)`` works. For example,
        ``(B, C, H, W)`` images and ``(B, K)`` targets get the same shapes as
        before.

        NOTE: ``input`` and ``target`` are modified IN PLACE and then returned.

        Returns:
            The mixed ``(input, target)``; these are the same tensor objects
            that were passed in.
        """
        if not torch.is_tensor(self.beta):
            # Very important for speed up: build the tensor on the device once
            # and cache it. Force a floating dtype, because an integer beta
            # (e.g. ``beta: 1`` in a config) would give an int64 tensor that
            # Beta cannot sample from.
            self.beta = torch.tensor(self.beta,
                                     dtype=torch.get_default_dtype()).to(self.device)
        batch_size = input.size(0)
        gamma = Beta(self.beta, self.beta).sample((batch_size,))
        perm = torch.randperm(batch_size)
        # Advanced indexing copies, so the partners are captured before the
        # in-place updates below.
        perm_input = input[perm]
        perm_target = target[perm]
        input_gamma = gamma.view(-1, *([1] * (input.dim() - 1)))
        input.mul_(input_gamma).add_(perm_input.mul_(1 - input_gamma))
        target_gamma = gamma.view(-1, *([1] * (target.dim() - 1)))
        target.mul_(target_gamma).add_(perm_target.mul_(1 - target_gamma))
        return input, target


if __name__ == '__main__':
    import hydra

    # Hydra entry point for the ICT task (two loss functions).
    _entry = hydra.main('config/ict.yaml')
    _task = get_task(ICTTrainer,
                     [cross_entropy_with_softlabels, mse_with_logits])
    _entry(_task)()
Exemplo n.º 5
0
from torch.nn import functional as F

from backends.utils import SSLTrainerBase, disable_bn_stats, get_task


class PseudoLabelTrainer(SSLTrainerBase):
    """Pseudo-label trainer.

    On unlabeled data, the model's argmax predictions act as hard targets,
    counted only where some class probability exceeds ``self.threshold``.
    """

    def labeled(self,
                input: torch.Tensor,
                target: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(logits, supervised_loss)``.

        The target is converted with ``self.to_onehot(target, self.smoothing)``
        before being passed to ``self.loss_f``.
        """
        converted = self.to_onehot(target, self.smoothing)
        logits = self.model(input)
        return logits, self.loss_f(logits, converted)

    def unlabeled(self,
                  input: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(logits, masked_pseudo_label_loss)``.

        The forward pass runs inside ``disable_bn_stats(self.model)``. Samples
        with no class probability above ``self.threshold`` contribute zero.
        The mean is taken over the full batch.
        """
        with disable_bn_stats(self.model):
            logits = self.model(input)
        pseudo_labels = logits.argmax(dim=1)
        per_sample = F.cross_entropy(logits, pseudo_labels, reduction='none')
        is_confident = (logits.softmax(dim=1) > self.threshold).any(dim=1)
        return logits, (is_confident.float() * per_sample).mean()


if __name__ == "__main__":
    import hydra

    # Hydra entry point for the pseudo-label task.
    _entry = hydra.main('config/pseudo_label.yaml')
    _task = get_task(PseudoLabelTrainer, cross_entropy_with_softlabels)
    _entry(_task)()
Exemplo n.º 6
0
    def data_handle(self,
                    data: Tuple) -> Tuple:
        """Split a MixMatch batch into labeled data and guessed unlabeled data.

        Two batch layouts are accepted:
        - 5 items (Pillow augmentation):
          ``(input, target, view1, view2, _)``; the two views are paired
          before being passed to ``self.sharpen``.
        - 4 items: ``(input, target, unlabeled, _)``; ``unlabeled`` is passed
          to ``self.sharpen`` as-is.

        Returns:
            ``(l_x, onehot(l_y), u_x, u_y)``.
        """
        if len(data) == 5:
            # Pillow augmentation: the two unlabeled views arrive separately.
            input, target, first_view, second_view, _ = data
            unlabeled = (first_view, second_view)
        else:
            input, target, unlabeled, _ = data
        u_x, u_y = self.sharpen(unlabeled)
        # l_x, l_y, u_x, u_y
        return input, self.to_onehot(target), u_x, u_y

    def sharpen(self,
                input: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Guess sharpened soft labels for a sequence of augmented unlabeled views.

        Args:
            input: sequence of K augmented views of the same unlabeled batch.
                The callers in this class pass two views.

        Returns:
            ``(u_b, q_b)``:
            - ``u_b`` is the views concatenated along dim 0.
            - ``q_b`` is the guessed label, repeated K times so its rows align
              with ``u_b``. The guess is the mean softmax over the views,
              raised to ``1 / self.temperature`` and renormalized per row.
        """
        u_b = torch.cat(input, dim=0)
        # Guessed labels are targets, so gradients must not flow through them
        # (stop-gradient, as in MixMatch). BN statistics are also left
        # untouched by these extra forward passes.
        with torch.no_grad(), disable_bn_stats(self.model):
            # Average over ALL views; previously only views 0 and 1 were used
            # while every view was concatenated into u_b.
            q_b = sum(self.model(view).softmax(dim=1) for view in input) / len(input)
        # pow_ runs before the sum in div_'s argument is evaluated, so the
        # normalizer is the row sum of the sharpened values.
        q_b.pow_(1 / self.temperature).div_(q_b.sum(dim=1, keepdim=True))
        return u_b, q_b.repeat(len(input), 1)


if __name__ == "__main__":
    import hydra

    # Hydra entry point for the MixMatch task (two loss functions).
    _entry = hydra.main('config/mixmatch.yaml')
    _task = get_task(MixmatchTrainer,
                     [cross_entropy_with_softlabels, F.mse_loss])
    _entry(_task)()
Exemplo n.º 7
0
from backends.utils import SSLTrainerBase, disable_bn_stats, get_task


class MeanTeacherTrainer(SSLTrainerBase):
    """Mean Teacher trainer.

    The student (``self.model``) is pulled towards the predictions of
    ``self.ema`` on unlabeled data.
    """

    def labeled(self,
                input: torch.Tensor,
                target: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(logits, supervised_loss)`` using label-smoothed one-hot targets.

        The one-hot target ``t`` becomes ``(1 - s) * t + s / K``, where
        ``s = self.smoothing`` and ``K = self.num_classes``.
        """
        output = self.model(input)
        target = to_onehot(target, self.num_classes)
        target -= self.smoothing * (target - 1 / self.num_classes)
        loss = self.loss_f(output, target)
        return output, loss

    def unlabeled(self,
                  input1: torch.Tensor,
                  input2: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(student_logits, consistency_loss)``.

        The consistency loss is the MSE between the softmax of the student
        output on ``input1`` and the softmax of the teacher (``self.ema``)
        output on ``input2``.
        """
        with disable_bn_stats(self.model):
            o1 = self.model(input1)
        # The teacher prediction is a fixed target: evaluate it without
        # autograd so the consistency loss cannot backpropagate through it.
        with torch.no_grad():
            o2 = self.ema(input2)
        return o1, F.mse_loss(o1.softmax(dim=1), o2.softmax(dim=1))


if __name__ == "__main__":
    import hydra

    # Hydra entry point for the Mean Teacher task.
    _entry = hydra.main('config/mean_teacher.yaml')
    _task = get_task(MeanTeacherTrainer, cross_entropy_with_softlabels)
    _entry(_task)()