def _train_loop(
        self,
        train_loader: DataLoader,
        epoch: int,
        mode=ModelMode.TRAIN,
        *args,
        **kwargs,
    ):
        """Run one training epoch over ``train_loader``.

        For every batch, ``zip(*batch)`` splits it into per-transformation
        ``images`` (a tuple). The first view is ``images[0]`` repeated
        ``len(images) - 1`` times and the second view is the concatenation of
        ``images[1:]``, so row ``i`` of both views comes from the same sample.
        Both views are passed through the model, ``_trainer_specific_loss``
        turns the predictions into a loss, and one optimizer step is taken.

        :param train_loader: iterable of batches; each batch is presumably a
            sequence of ``(image, label, ...)`` tuples, one per transformation
            — confirm against the dataset.
        :param epoch: epoch index, only used in the printed summary.
        :param mode: passed to ``self.model.set_mode``; the model must end up in
            training mode (asserted).
        """
        super()._train_loop(*args, **kwargs)
        self.model.set_mode(mode)
        assert self.model.training
        _train_loader: tqdm = tqdm_(train_loader)
        for _batch_num, images_labels_indices in enumerate(_train_loader):
            images, labels, *_ = zip(*images_labels_indices)
            num_views = len(images)
            # images[0] is repeated once per extra view so it lines up with images[1:]
            tf1_images = torch.cat([images[0]] * (num_views - 1),
                                   dim=0).to(self.device)
            tf2_images = torch.cat(images[1:], dim=0).to(self.device)
            pred_tf1_simplex = self.model(tf1_images)
            pred_tf2_simplex = self.model(tf2_images)
            assert simplex(pred_tf1_simplex[0]), pred_tf1_simplex
            assert simplex(pred_tf2_simplex[0]), pred_tf2_simplex
            total_loss = self._trainer_specific_loss(tf1_images, tf2_images,
                                                     pred_tf1_simplex,
                                                     pred_tf2_simplex)
            self.model.zero_grad()
            total_loss.backward()
            self.model.step()
            _train_loader.set_postfix(self._training_report_dict)

        # Read the report after the loop so an empty loader still prints a
        # summary instead of raising NameError on an unbound local.
        report_dict = self._training_report_dict
        report_dict_str = ", ".join(
            f"{k}:{v:.3f}" for k, v in report_dict.items())
        print(f"  Training epoch: {epoch} : {report_dict_str}")
Beispiel #2
0
    def _trainer_specific_loss(self, label_img, label_gt, unlab_img, *args,
                               **kwargs):
        """Mixup loss between labeled images and pseudo-labeled unlabeled images.

        The unlabeled batch is pseudo-labeled with the first output of
        ``self.model.torchnet``, computed in eval mode under ``torch.no_grad()``.
        ``self._mixup`` then blends each labeled image with an unlabeled one.
        The model predicts both the blended class distribution (``pred``) and
        the blending proportions (``cls``), and each is matched to its target
        with ``self.kl_criterion``.

        :param label_img: labeled image batch; same shape as ``unlab_img``.
        :param label_gt: integer class labels for ``label_img`` (presumably
            shape ``(B,)`` — confirm).
        :param unlab_img: unlabeled image batch.
        :return: ``0.1 * (KL(pred, mixup_label) + KL(cls, mix_indice))``.
        """
        # NOTE: only *args/**kwargs are forwarded; the three named inputs are
        # not passed to the parent implementation.
        super(AdaNetTrainer, self)._trainer_specific_loss(*args,
                                                          **kwargs)  # warning
        assert label_img.shape == unlab_img.shape, f"Shapes of labeled and unlabeled images should be the same, " \
            f"given {label_img.shape} and {unlab_img.shape}."
        self.model.eval()
        try:
            with torch.no_grad():
                pseudo_label = self.model.torchnet(unlab_img)[0]
        finally:
            # Restore training mode even if the pseudo-labeling pass raises.
            self.model.train()
        # Take the class count from the pseudo labels instead of a fixed
        # constant, so the one-hot targets always match their shape (the
        # mixup step asserts the two shapes are equal).
        num_classes = pseudo_label.shape[1]
        # class2one_hot presumably returns (B, C, 1, 1) for a (B, 1, 1) input —
        # confirm. flatten(start_dim=1) yields (B, C) and, unlike a bare
        # .squeeze(), keeps the batch axis when B == 1.
        label_onehot = class2one_hot(
            label_gt.unsqueeze(dim=1).unsqueeze(dim=2), num_classes
        ).flatten(start_dim=1).float()
        mixup_img, mixup_label, mix_indice = self._mixup(
            label_img, label_onehot, unlab_img, pseudo_label)

        pred, cls = self.model(mixup_img)
        assert simplex(pred) and simplex(cls)
        reg_loss1 = self.kl_criterion(pred, mixup_label)
        adv_loss = self.kl_criterion(cls, mix_indice)
        self.METERINTERFACE.tra_sup_mixup.add(reg_loss1.item())
        self.METERINTERFACE.tra_cls.add(adv_loss.item())
        self.METERINTERFACE.grl.add(self.grl_scheduler.value)

        return (reg_loss1 + adv_loss) * 0.1
Beispiel #3
0
    def forward(self, x_out: Tensor, x_tf_out: Tensor):
        """
        Return the negative mutual information between two soft cluster assignments.

        With ``P = compute_joint(x_out, x_tf_out)`` (a ``(k, k)`` joint) and row /
        column marginals ``p_i`` / ``p_j``:

        * ``loss = -sum P * (log P - lamb * log p_j - lamb * log p_i)``
        * ``loss_no_lamb`` is the same expression with ``lamb = 1``.

        :param x_out: ``(bn, k)`` per-sample simplexes for the first view.
        :param x_tf_out: ``(bn, k)`` per-sample simplexes for the second view.
        :return: ``(loss, loss_no_lamb)``.
        """
        assert simplex(x_out), f"x_out not normalized."
        assert simplex(x_tf_out), f"x_tf_out not normalized."
        _, k = x_out.size()
        p_i_j = compute_joint(x_out, x_tf_out)
        assert p_i_j.size() == (k, k)

        # Marginals broadcast to (k, k); P is symmetric, so they agree up to layout.
        p_i = p_i_j.sum(dim=1).view(k, 1).expand(k, k)
        p_j = p_i_j.sum(dim=0).view(1, k).expand(k, k)

        # Clamp tiny probabilities to eps so log() stays finite (0 * log(0) is
        # NaN). Out-of-place clamp works on every PyTorch version. In-place
        # masked writes into the expanded p_i / p_j views are rejected by newer
        # PyTorch releases. A version gate on a plain string comparison would
        # also misorder versions such as "1.10.0" < "1.3.0".
        p_i_j = p_i_j.clamp(min=self.eps)
        p_j = p_j.clamp(min=self.eps)
        p_i = p_i.clamp(min=self.eps)

        loss = -p_i_j * (torch.log(p_i_j) - self.lamb * torch.log(p_j) -
                         self.lamb * torch.log(p_i))
        loss = loss.sum()
        loss_no_lamb = -p_i_j * (torch.log(p_i_j) - torch.log(p_j) -
                                 torch.log(p_i))
        loss_no_lamb = loss_no_lamb.sum()
        return loss, loss_no_lamb
Beispiel #4
0
 def _mixup_image_pred_index(self, tf1_image, tf1_pred, tf2_image,
                             tf2_pred) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Mix two image batches and their predictions through ``self.mixup_module``.

     ``tf1_pred`` and ``tf2_pred`` are single simplex tensors here, not lists
     of per-head simplexes.

     :return: the three results of ``self.mixup_module`` as
         ``(mixed_images, mixed_targets, mixing_index)``.
     """
     assert simplex(tf1_pred) and simplex(tf2_pred)
     mixed = self.mixup_module(tf1_image, tf1_pred, tf2_image, tf2_pred)
     mixed_img, mixed_target, mixed_index = mixed
     return mixed_img, mixed_target, mixed_index
Beispiel #5
0
 def forward(self, x_out1: Tensor, x_out2: Tensor):
     """
     Negated mutual-information objective with a weighted joint-entropy term.

     With ``P = compute_joint(x_out1, x_out2)``:
     ``mi = H(P.sum(0)) + H(P.sum(1)) - c_coef * centropy``, where ``H`` is
     ``self.entropy`` and ``centropy = -sum(P * log(P + eps))``.
     ``self._update_weights(x_out1)`` is called before returning.

     :return: ``(-mi, -mi)``.
     """
     assert simplex(x_out1) and simplex(x_out2)
     p_joint = compute_joint(x_out1, x_out2)
     col_marginal = p_joint.sum(0).unsqueeze(0)
     row_marginal = p_joint.sum(1).unsqueeze(0)
     marginal = self.entropy(col_marginal) + self.entropy(row_marginal)
     log_joint = (p_joint + self.entropy._eps).log()
     centropy = -(p_joint * log_joint).sum()
     mi = marginal - self.c_coef * centropy
     self._update_weights(x_out1)
     return -mi, -mi
Beispiel #6
0
    def __call__(self, x_out1: Tensor, x_out2: Tensor):
        """
        Negated lambda-weighted mutual information between two soft assignments.

        With ``P = self.compute_joint(x_out1, x_out2)``:
        ``mi = lamda * (H(P.sum(0)) + H(P.sum(1))) - centropy``, where ``H`` is
        ``self.entropy`` and ``centropy = -sum(P * log(P + eps))``.

        :return: ``(-mi, -mi)``.
        """
        assert simplex(x_out1) and simplex(x_out2)
        p_joint = self.compute_joint(x_out1, x_out2)
        marginal_entropy = self.entropy(p_joint.sum(0).unsqueeze(0)) + \
            self.entropy(p_joint.sum(1).unsqueeze(0))
        log_joint = (p_joint + self.entropy._eps).log()
        centropy = -(p_joint * log_joint).sum()
        mi = self.lamda * marginal_entropy - centropy
        return -mi, -mi
    def _trainer_specific_loss(
        self,
        images: torch.Tensor,
        images_tf: torch.Tensor,
        pred: List[torch.Tensor],
        pred_tf: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Mutual-information loss augmented with adversarial pairs, averaged over heads.

        Adversarial images are built by ``VATLoss_Multihead`` from the first
        ``self.train_loader.batch_size`` rows of ``images``. For each head, the
        criterion compares two concatenations: the clean predictions against
        ``pred_tf``, and the first ``batch_size`` clean predictions against
        their adversarial predictions.

        :param images: first-view images; its leading rows are presumably the
            un-repeated batch — confirm against the training loop.
        :param images_tf: second-view images (unused here).
        :param pred: per-head simplex predictions for ``images``.
        :param pred_tf: per-head simplex predictions for ``images_tf``.
        :return: mean over heads of the criterion's first output.
        """
        assert simplex(pred[0]) and len(pred_tf) == len(pred)
        bs = self.train_loader.batch_size

        # adversarial versions of the first `bs` images only (no repetition)
        _, adv_images, _ = VATLoss_Multihead(xi=1, eps=10, prop_eps=0.1)(
            self.model.torchnet, images[:bs])
        adv_preds = self.model(adv_images)
        assert len(adv_preds[0]) == bs

        head_losses: List[torch.Tensor] = []
        for head in range(len(pred)):
            clean_side = torch.cat((pred[head], pred[head][:bs]), dim=0)
            other_side = torch.cat((pred_tf[head], adv_preds[head]), dim=0)
            head_loss, _ = self.criterion(clean_side, other_side)
            head_losses.append(head_loss)
        mean_loss: torch.Tensor = sum(head_losses) / len(head_losses)  # type: ignore
        self.METERINTERFACE["train_mi"].add(-mean_loss.item())
        return mean_loss
    def _trainer_specific_loss(
        self,
        images: torch.Tensor,
        images_tf: torch.Tensor,
        pred: List[torch.Tensor],
        pred_tf: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Per-head mutual-information loss plus an optional VAT term.

        The criterion's first output is averaged over heads and logged as
        ``train_mi`` (negated). When ``self.sat_weight > 0``, the
        ``VATLoss_Multihead`` loss of ``self.model.torchnet`` on ``images`` is
        logged as ``train_sat`` and added with weight ``self.sat_weight``.

        :param images: first-view images (used by the VAT term).
        :param images_tf: second-view images (unused here).
        :param pred: per-head simplex predictions for ``images``.
        :param pred_tf: per-head simplex predictions for ``images_tf``.
        :return: ``mean_mi + self.sat_weight * sat_loss``.
        """
        assert simplex(pred[0]) and len(pred_tf) == len(pred)
        head_losses = [
            head_loss
            for head_loss, _ in (self.criterion(pred[h], pred_tf[h])
                                 for h in range(len(pred)))
        ]
        mean_mi: torch.Tensor = sum(head_losses) / len(head_losses)  # type: ignore
        self.METERINTERFACE["train_mi"].add(-mean_mi.item())

        sat_loss = 0
        if self.sat_weight > 0:
            vat = VATLoss_Multihead(xi=1, eps=10, prop_eps=0.1)
            sat_loss, *_ = vat(self.model.torchnet, images)
            self.METERINTERFACE["train_sat"].add(sat_loss.item())
        return mean_mi + self.sat_weight * sat_loss
Beispiel #9
0
 def __call__(self, *args, **kwargs):
     """
     Forward to ``self._torchnet``, optionally forcing a simplex output.

     The keyword ``force_simplex`` (bool, default False) is consumed here.
     When it is True and the output is not already a simplex along dim 1,
     softmax over dim 1 is applied.
     """
     want_simplex = kwargs.pop("force_simplex", False)
     assert isinstance(want_simplex, bool), want_simplex
     output = self._torchnet(*args, **kwargs)
     needs_softmax = want_simplex and not simplex(output, 1)
     return F.softmax(output, 1) if needs_softmax else output
Beispiel #10
0
def compute_joint(x_out: Tensor, x_tf_out: Tensor) -> Tensor:
    """
    Build the symmetric, normalised joint distribution of two soft assignments.

    :param x_out: p1, ``(bn, k)`` per-sample simplexes.
    :param x_tf_out: p2, same shape as ``x_out``.
    :return: ``(k, k)`` joint probability matrix, symmetric and summing to 1.
        It carries gradients whenever the inputs do.
    """
    assert simplex(x_out), "x_out not normalized."
    assert simplex(x_tf_out), "x_tf_out not normalized."

    bn, k = x_out.shape
    assert x_tf_out.size(0) == bn and x_tf_out.size(1) == k

    # per-sample outer products (bn, k, 1) * (bn, 1, k), aggregated over the batch
    joint = (x_out.unsqueeze(2) * x_tf_out.unsqueeze(1)).sum(dim=0)
    joint = (joint + joint.t()) / 2.0  # symmetrise
    return joint / joint.sum()  # normalise to total mass 1
Beispiel #11
0
    def _mixup(self, label_img: torch.Tensor, label_onehot: torch.Tensor,
               unlab_img: torch.Tensor, unlabeled_pred: torch.Tensor):
        """
        Convexly mix labeled and unlabeled samples with per-sample Beta weights.

        For each sample ``b`` a weight ``a_b`` is drawn from ``self.beta_distr``:
        ``img_b = a_b * label_img_b + (1 - a_b) * unlab_img_b`` and
        ``label_b = a_b * label_onehot_b + (1 - a_b) * unlabeled_pred_b``.

        :param label_img: 4-D labeled image batch.
        :param label_onehot: one-hot targets for ``label_img``.
        :param unlab_img: unlabeled images, same shape as ``label_img``.
        :param unlabeled_pred: simplex predictions, same shape as ``label_onehot``.
        :return: ``(mixup_img, mixup_label, mixup_index)``, where
            ``mixup_index`` stacks ``[a_b, 1 - a_b]`` along dim 1.
        """
        assert label_img.shape == unlab_img.shape
        assert len(label_img.shape) == 4
        assert one_hot(label_onehot) and simplex(unlabeled_pred)
        assert label_onehot.shape == unlabeled_pred.shape
        bn = label_img.shape[0]
        # .squeeze(1) assumes the distribution yields (bn, 1) samples — confirm
        weight = self.beta_distr.sample((bn, )).squeeze(1).to(self.device)
        # broadcasting a (bn, 1, 1, 1) weight gives the same values as tiling it
        w_img = weight.view(bn, 1, 1, 1)
        w_lbl = weight.view(bn, 1)
        mixup_img = label_img * w_img + unlab_img * (1 - w_img)
        mixup_label = label_onehot * w_lbl + unlabeled_pred * (1 - w_lbl)
        mixup_index = torch.stack([weight, 1 - weight], dim=1).to(self.device)

        assert mixup_img.shape == label_img.shape
        assert mixup_label.shape == label_onehot.shape
        assert mixup_index.shape[0] == bn
        assert simplex(mixup_index)

        return mixup_img, mixup_label, mixup_index
Beispiel #12
0
    def _trainer_specific_loss(
        self,
        images: torch.Tensor,
        images_tf: torch.Tensor,
        pred: List[torch.Tensor],
        pred_tf: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Per-head criterion between clean predictions and VAT adversarial predictions.

        ``VATLoss_Multihead`` generates adversarial versions of ``images``.
        For each head, ``self.criterion(pred[h], adv_pred[h])`` is evaluated and
        the first outputs are averaged. ``pred_tf`` is only used for the length
        check, and ``images_tf`` is unused.

        :return: mean over heads; also logged (negated) as ``train_mi``.
        """
        assert simplex(pred[0]) and len(pred_tf) == len(pred)
        vat = VATLoss_Multihead(xi=1, eps=10, prop_eps=0.1)
        _, adv_images, _ = vat(self.model.torchnet, images)
        adv_pred = self.model(adv_images)
        assert simplex(adv_pred[0])

        per_head: List[torch.Tensor] = []
        for head in range(len(pred)):
            head_loss, _ = self.criterion(pred[head], adv_pred[head])
            per_head.append(head_loss)
        mean_loss: torch.Tensor = sum(per_head) / len(per_head)  # type: ignore
        self.METERINTERFACE["train_mi"].add(-mean_loss.item())
        return mean_loss
    def __call__(
        self, img: Tensor, gt: Optional[Tensor], net: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Generate FGSM adversarial images for ``img`` against a network.

        The perturbation follows the gradient of ``self._criterion`` between the
        network prediction (softmaxed if it is not already a simplex) and a
        one-hot target built from ``gt``:

        * ``gt`` covering the whole batch -> fully supervised targets;
        * ``gt`` covering only the first ``gt.shape[0]`` rows -> the remaining
          rows use the network's own argmax (semi-supervised);
        * ``gt is None`` -> every row uses the network's argmax (unsupervised).

        The gradient is taken with ``torch.autograd.grad`` on a detached view of
        ``img``. The caller's tensor (its ``requires_grad`` flag and ``.grad``)
        and the network parameters' ``.grad`` buffers are left untouched.

        :param img: 4-D image batch.
        :param gt: optional labels with at most ``img.shape[0]`` rows; presumably
            shape ``(n, 1)`` given the ``squeeze(1)`` below — confirm.
        :param net: network to attack; defaults to ``self._net``.
        :return: ``(adv_img, noise)`` from ``self._adversarial_fgsm``, detached.
        """
        assert img.shape.__len__() == 4
        if gt is not None:
            assert img.shape[0] >= gt.shape[0]

        # set network:
        current_net = net if net is not None else self._net
        assert current_net, current_net

        # Differentiate w.r.t. a fresh leaf sharing img's storage rather than
        # toggling img.requires_grad. Toggling mutated the caller's tensor and
        # raises for non-leaf inputs.
        x = img.detach().requires_grad_(True)
        pred = current_net(x)
        if not simplex(pred):
            pred = pred.softmax(1)

        if gt is not None:
            if img.shape[0] > gt.shape[0]:
                gt = torch.cat(
                    (gt, pred.max(1)[1][gt.shape[0] :].unsqueeze(1)), dim=0
                )  # semisupervised setting
        else:
            gt = pred.max(1)[1].unsqueeze(1)  # unsupervised setting
        assert gt.shape[0] == img.shape[0]
        loss = self._criterion(pred, class2one_hot(gt.squeeze(1), pred.shape[1]))
        # autograd.grad returns d(loss)/dx without accumulating into (or having
        # to zero) the network parameters' .grad, so callers' gradients survive.
        (grad,) = torch.autograd.grad(loss, x)

        adv_img, noise = self._adversarial_fgsm(x, grad, epsilon=self._eps)

        return adv_img.detach(), noise.detach()
Beispiel #14
0
    def forward(self, input):
        """
        Seven tanh hidden layers followed by a softmax output layer.

        :return: ``(probs, activations)``, where ``activations`` lists the
            outputs of ``hidden1`` ... ``hidden7`` followed by ``probs`` itself.
        """
        activations = []
        hidden = input
        for layer_idx in range(1, 8):
            layer = getattr(self, f"hidden{layer_idx}")
            hidden = torch.tanh(layer(hidden))
            activations.append(hidden)
        probs = F.softmax(self.output_layer(hidden), 1)
        assert simplex(probs)
        activations.append(probs)
        return probs, activations
Beispiel #15
0
    def _trainer_specific_loss(
        self,
        images: torch.Tensor,
        images_tf: torch.Tensor,
        pred: List[torch.Tensor],
        pred_tf: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Mutual-information loss plus an optional self-augmentation (SAT) term.

        For each sub-head, ``self.criterion(pred[h])`` returns
        ``(mi, (entropy, centropy))``. The negated MI, entropy and centropy are
        each averaged over sub-heads and logged. When ``self.sat_weight > 0``,
        a consistency term is added. By default it is the mean over heads of
        ``self.kl(pred_tf[h], pred[h].detach())``; when ``self.use_vat`` is set,
        it is the ``VATLoss_Multihead`` loss of ``self.model.torchnet`` on
        ``images``.

        :param images: first-view images (used only by the VAT branch).
        :param images_tf: second-view images (unused here).
        :param pred: per-head simplex predictions for ``images``.
        :param pred_tf: per-head simplex predictions for ``images_tf``.
        :return: ``mean(-mi) + self.sat_weight * sat_loss``.
        """
        assert simplex(pred[0]), pred
        mi_losses, entropy_losses, centropy_losses = [], [], []
        for subhead_num in range(self.model.arch_dict["num_sub_heads"]):
            _mi_loss, (_entropy_loss,
                       _centropy_loss) = self.criterion(pred[subhead_num])
            mi_losses.append(-_mi_loss)
            entropy_losses.append(_entropy_loss)
            centropy_losses.append(_centropy_loss)
        mi_loss = sum(mi_losses) / len(mi_losses)
        entropy_loss = sum(entropy_losses) / len(entropy_losses)
        centropy_loss = sum(centropy_losses) / len(centropy_losses)

        self.METERINTERFACE["train_mi"].add(-mi_loss.item())
        self.METERINTERFACE["train_entropy"].add(entropy_loss.item())
        self.METERINTERFACE["train_centropy"].add(centropy_loss.item())

        # A 0-d zero keeps the returned loss the same shape as mi_loss when SAT
        # is disabled. A (1,)-shaped placeholder would broadcast a 0-d mi_loss
        # up to shape (1,) in that case only.
        sat_loss = torch.zeros((), device=self.device)
        if self.sat_weight > 0:
            if not self.use_vat:
                # use transformation: transformed predictions vs detached clean ones
                _sat_loss = [
                    self.kl(p_tf, p.detach()) for p, p_tf in zip(pred, pred_tf)
                ]
                sat_loss = sum(_sat_loss) / len(_sat_loss)
            else:
                sat_loss, *_ = VATLoss_Multihead(xi=1, eps=10, prop_eps=0.1)(
                    self.model.torchnet, images)

        self.METERINTERFACE["train_sat"].add(sat_loss.item())

        total_loss = mi_loss + self.sat_weight * sat_loss
        return total_loss
Beispiel #16
0
 def test_find_arch4(self):
     # ClusterNet6cTwoHead output: print its length and first-entry shape,
     # then check the first entry is a simplex.
     model = arch.ClusterNet6cTwoHead(**arch.ClusterNet6cTwoHead_Param)
     outputs = model(self.image)
     first_head = outputs[0]
     print(len(outputs))
     print(first_head.shape)
     assert simplex(first_head)
Beispiel #17
0
 def test_find_arch(self):
     # ClusterNet5g output: print its length and first-entry shape,
     # then check the first entry is a simplex.
     model = arch.ClusterNet5g(**arch.ClusterNet5g_Param)
     outputs = model(self.image)
     first_head = outputs[0]
     print(len(outputs))
     print(first_head.shape)
     assert simplex(first_head)
 def setUp(self) -> None:
     # two random 1x10 probability vectors (softmax over dim 1)
     self.x1, self.x2 = (F.softmax(torch.randn(1, 10), 1) for _ in range(2))
     for vec in (self.x1, self.x2):
         assert simplex(vec)