Example #1
0
 def before_eval_epoch(self, *args, **kwargs):
     """Set up the tqdm progress bar shown during an evaluation epoch."""
     # One tick per validation batch.
     self.tqdm_indicator = tqdm_(
         range(self._val_batches), total=self._val_batches)
     self._epoch = kwargs.get("epoch")
     # Label the bar with the epoch number when the caller supplied one.
     if self._epoch is not None:
         description = f"Evaluating Epoch {self._epoch}"
         self.tqdm_indicator.set_description(description)
Example #2
0
 def _eval_loop(
     self,
     val_loader: DataLoader = None,
     epoch: int = 0,
     mode=ModelMode.EVAL,
     *args,
     **kwargs,
 ) -> float:
     """Run one validation pass over `val_loader` and return the accuracy.

     :param val_loader: validation loader yielding (images, targets) batches.
     :param epoch: epoch number, used only for progress/report labels.
     :param mode: model mode to switch to (EVAL by default).
     :return: the "acc" entry of the val_acc meter summary.
     """
     # set model mode
     self._model.set_mode(mode)
     # idiomatic truthiness check instead of `== False`
     assert not self._model.torchnet.training, self._model.training
     # set tqdm-based trainer
     _val_loader = tqdm_(val_loader)
     _val_loader.set_description(f"Validating epoch {epoch}: ")
     for batch_id, (imgs, targets) in enumerate(_val_loader):
         imgs, targets = imgs.to(self._device), targets.to(self._device)
         preds = self._model(imgs)
         loss = self.ce_criterion(preds, targets)
         self.METERINTERFACE["val_loss"].add(loss.item())
         self.METERINTERFACE["val_acc"].add(preds.max(1)[1], targets)
         report_dict = self._eval_report_dict
         _val_loader.set_postfix(report_dict)
     # NOTE(review): `report_dict` is unbound if the loader is empty — TODO confirm
     print(
         colored(f"Validating epoch {epoch}: {nice_dict(report_dict)}",
                 "green"))
     return self.METERINTERFACE["val_acc"].summary()["acc"]
Example #3
0
    def _train_loop(
        self,
        train_loader: DataLoader,
        epoch: int,
        mode=ModelMode.TRAIN,
        *args,
        **kwargs,
    ):
        """Train for one epoch on pairs of transformed image batches.

        Each loader element is a zip-able batch of (images, labels, ...)
        tuples; the first image view is replicated to pair with every
        remaining view, and a trainer-specific loss drives the update.
        """
        super()._train_loop(*args, **kwargs)
        self.model.set_mode(mode)
        assert self.model.training
        _train_loader: tqdm = tqdm_(train_loader)
        for _batch_num, images_labels_indices in enumerate(_train_loader):
            images, labels, *_ = zip(*images_labels_indices)
            # replicate the first view once per remaining view so the two
            # stacked batches line up 1:1 (len() instead of __len__()).
            tf1_images = torch.cat(
                [images[0] for _ in range(len(images) - 1)],
                dim=0).to(self.device)
            # images[1:] is already a tuple; torch.cat accepts it directly
            tf2_images = torch.cat(images[1:], dim=0).to(self.device)
            pred_tf1_simplex = self.model(tf1_images)
            pred_tf2_simplex = self.model(tf2_images)
            assert simplex(pred_tf1_simplex[0]), pred_tf1_simplex
            assert simplex(pred_tf2_simplex[0]), pred_tf2_simplex
            total_loss = self._trainer_specific_loss(tf1_images, tf2_images,
                                                     pred_tf1_simplex,
                                                     pred_tf2_simplex)
            self.model.zero_grad()
            total_loss.backward()
            self.model.step()
            report_dict = self._training_report_dict
            _train_loader.set_postfix(report_dict)

        report_dict_str = ", ".join(
            f"{k}:{v:.3f}" for k, v in report_dict.items())
        print(f"  Training epoch: {epoch} : {report_dict_str}")
Example #4
0
    def _eval_loop(self,
                   val_loader: DataLoader,
                   epoch: int,
                   mode=ModelMode.EVAL,
                   *args,
                   **kwargs) -> float:
        """Evaluate every sub-head, remap clusters via Hungarian matching,
        and return the running mean of the best sub-head accuracy.

        :param val_loader: loader yielding zip-able (images, labels, ...)
            batches; only the first view/label of each batch is used.
        :param epoch: epoch number (unused here beyond the signature).
        :param mode: model mode to switch to (EVAL by default).
        :return: the "mean" summary of the val_best_acc meter.
        """
        super(IMSAT_Trainer, self)._eval_loop(*args, **kwargs)
        self.model.set_mode(mode)
        assert not self.model.training
        _val_loader = tqdm_(val_loader)
        # hoist the dataset size; len() instead of __len__()
        num_samples = len(val_loader.dataset)
        # hard predictions per sub-head: (num_sub_heads, num_samples)
        preds = torch.zeros(
            self.model.arch_dict["num_sub_heads"],
            num_samples,
            dtype=torch.long,
            device=self.device,
        )
        # soft predictions per sub-head: (num_sub_heads, num_samples, output_k)
        probas = torch.zeros(
            self.model.arch_dict["num_sub_heads"],
            num_samples,
            self.model.arch_dict["output_k"],
            dtype=torch.float,
            device=self.device,
        )
        gts = torch.zeros(num_samples, dtype=torch.long, device=self.device)
        _batch_done = 0
        for _batch_num, images_labels_indices in enumerate(_val_loader):
            images, labels, *_ = zip(*images_labels_indices)
            # only the first view and its labels are used for evaluation
            images, labels = images[0].to(self.device), labels[0].to(
                self.device)
            pred = self.model(images)
            _bSlice = slice(_batch_done, _batch_done + images.shape[0])
            gts[_bSlice] = labels
            for subhead in range(len(pred)):
                preds[subhead][_bSlice] = pred[subhead].max(1)[1]
                probas[subhead][_bSlice] = pred[subhead]
            _batch_done += images.shape[0]
        # every sample must be covered (fails if dataloader.drop_last=True)
        assert _batch_done == num_samples, _batch_done

        # record
        subhead_accs = []
        for subhead in range(self.model.arch_dict["num_sub_heads"]):
            reorder_pred, remap = hungarian_match(
                flat_preds=preds[subhead],
                flat_targets=gts,
                preds_k=self.model.arch_dict["output_k"],
                targets_k=self.model.arch_dict["output_k"],
            )
            _acc = flat_acc(reorder_pred, gts)
            subhead_accs.append(_acc)
            # record average acc
            self.METERINTERFACE.val_average_acc.add(_acc)
        self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
        report_dict = self._eval_report_dict

        report_dict_str = ", ".join(
            f"{k}:{v:.3f}" for k, v in report_dict.items())
        print(f"Validating epoch: {epoch} : {report_dict_str}")
        return self.METERINTERFACE.val_best_acc.summary()["mean"]
        def _sup_train_loop(train_loader, epoch):
            """One supervised training epoch over (image, gt) batch pairs."""
            self.model.train()
            loader = tqdm_(train_loader)
            for _num, image_gt in enumerate(loader):
                image, gt = zip(*image_gt)
                image, gt = image[0].to(self.device), gt[0].to(self.device)

                # optional sobel pre-filtering of the inputs
                if self.use_sobel:
                    image = self.sobel(image)

                pred = self.model.torchnet(image)[0]
                loss = self.kl(pred, class2one_hot(gt, 10).float())
                self.model.zero_grad()
                loss.backward()
                self.model.step()
                linear_meters["train_loss"].add(loss.item())
                linear_meters["train_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "tra_acc": linear_meters["train_acc"].summary()["acc"],
                    "loss": linear_meters["train_loss"].summary()["mean"],
                }
                loader.set_postfix(report_dict)

            print(f"  Training epoch {epoch}: {nice_dict(report_dict)} ")
    def plot_cluster_average_images(val_loader, soft_pred):
        """Build per-cluster average images (10 clusters, 24x24), each sample
        weighted by its soft-prediction confidence.

        :param val_loader: loader yielding zip-able (images, gt, ...) batches.
        :param soft_pred: per-sample soft assignments; soft_pred[i].argmax()
            selects the cluster bucket and .max() the weighting confidence.
        :return: list of 10 averaged 24x24 CPU tensors.
        """
        # assert val_loader.dataset_name == "mnist", \
        #     f"save tsne plot is only implemented for MNIST dataset, given {val_loader.dataset_name}."
        from deepclustering.augment.tensor_augment import Resize
        import warnings
        resize_call = Resize((24, 24), interpolation='bilinear')

        average_images = [torch.zeros(24, 24) for _ in range(10)]

        counter = 0
        for image_labels in tqdm_(val_loader):
            images, gt, *_ = list(zip(*image_labels))
            # only take the tf3 image and gts, put them to self.device
            images, gt = images[0].cuda(), gt[0].cuda()
            for i, img in enumerate(images):
                with warnings.catch_warnings():
                    # silence interpolation warnings emitted by Resize
                    warnings.simplefilter("ignore")
                    img = resize_call(img.unsqueeze(0))
                # accumulate into the argmax cluster, weighted by confidence
                average_images[soft_pred[counter + i].argmax(
                )] += img.squeeze().cpu() * soft_pred[counter + i].max()

            counter += len(images)
        # every sample must have been visited exactly once
        assert counter == val_loader.dataset.__len__()
        # normalization assumes roughly balanced clusters (counter / 10 each)
        average_images = [
            average_image / (counter / 10) for average_image in average_images
        ]
        return average_images
Example #7
0
    def _train_loop(
        self,
        train_loader: DataLoader = None,
        epoch: int = 0,
        mode=ModelMode.TRAIN,
        *args,
        **kwargs,
    ):
        """Supervised training epoch: cross-entropy loss with tqdm reporting.

        :param train_loader: loader yielding (images, targets) batches.
        :param epoch: epoch number, used for progress/report labels.
        :param mode: model mode to switch to (TRAIN by default).
        """
        # set model mode
        self.model.set_mode(mode)
        # idiomatic truthiness check instead of `== True`
        assert self.model.torchnet.training
        # set tqdm-based trainer
        self.METERINTERFACE["lr"].add(self.model.get_lr()[0])
        _train_loader = tqdm_(train_loader)
        _train_loader.set_description(
            f" Training epoch {epoch}: lr={self.METERINTERFACE['lr'].summary()['value']:.5f}"
        )
        for batch_id, (imgs, targets) in enumerate(_train_loader):
            imgs, targets = imgs.to(self.device), targets.to(self.device)
            preds = self.model(imgs)
            loss = self.ce_criterion(preds, targets)
            # the context manager handles zero_grad / backward / step
            with ZeroGradientBackwardStep(loss, self.model) as scaled_loss:
                scaled_loss.backward()

            self.METERINTERFACE["train_loss"].add(loss.item())
            self.METERINTERFACE["train_acc"].add(preds.max(1)[1], targets)
            report_dict = self._training_report_dict
            _train_loader.set_postfix(report_dict)
        print(colored(f"  Training epoch {epoch}: {nice_dict(report_dict)}", "red"))
Example #8
0
    def _train_loop(
        self,
        labeled_loader: DataLoader = None,
        unlabeled_loader: DataLoader = None,
        epoch: int = 0,
        mode=ModelMode.TRAIN,
        *args,
        **kwargs,
    ):
        """Semi-supervised training epoch: supervised KL loss on labeled
        batches plus a trainer-specific regularizer on unlabeled batches.

        NOTE(review): zip() stops at the shortest of (max_iter,
        labeled_loader, unlabeled_loader) — confirm max_iter is the
        intended epoch length.
        """
        self._model.set_mode(mode)
        # progress bar doubles as the iteration counter (self.max_iter steps)
        _max_iter = tqdm_(range(self.max_iter))
        _max_iter.set_description(f"Training Epoch {epoch}")
        self.METERINTERFACE["lr"].add(self._model.get_lr()[0])
        for batch_num, (lab_img, lab_gt), (unlab_img, unlab_gt) in zip(
                _max_iter, labeled_loader, unlabeled_loader):
            lab_img, lab_gt = lab_img.to(self._device), lab_gt.to(self._device)
            lab_preds = self._model(lab_img)
            # supervised term: KL divergence against one-hot ground truth
            sup_loss = self.kl_criterion(
                lab_preds,
                class2one_hot(lab_gt,
                              C=self._model.torchnet.num_classes).float(),
            )
            reg_loss = self._trainer_specific_loss(unlab_img, unlab_gt)
            self.METERINTERFACE["traloss"].add(sup_loss.item())
            self.METERINTERFACE["traconf"].add(lab_preds.max(1)[1], lab_gt)

            # presumably handles zero_grad / backward scaling / step —
            # verify against ZeroGradientBackwardStep's implementation
            with ZeroGradientBackwardStep(sup_loss + reg_loss,
                                          self._model) as total_loss:
                total_loss.backward()
            report_dict = self._training_report_dict
            _max_iter.set_postfix(report_dict)
        print(f"Training Epoch {epoch}: {nice_dict(report_dict)}")
        self.writer.add_scalar_with_tag("train",
                                        report_dict,
                                        global_step=epoch)
Example #9
0
 def _eval_loop(
     self,
     val_loader: DataLoader = None,
     epoch: int = 0,
     mode=ModelMode.EVAL,
     *args,
     **kwargs,
 ) -> float:
     """Validate for one epoch: KL loss against one-hot targets plus a
     confusion/accuracy meter; returns the accuracy summary."""
     self._model.set_mode(mode)
     loader = tqdm_(val_loader)
     loader.set_description(f"Validating Epoch {epoch}")
     for _num, (image, gt) in enumerate(loader):
         image, gt = image.to(self._device), gt.to(self._device)
         prediction = self._model(image)
         onehot_gt = class2one_hot(
             gt, C=self._model.torchnet.num_classes).float()
         loss = self.kl_criterion(prediction, onehot_gt, disable_assert=True)
         self.METERINTERFACE["valloss"].add(loss.item())
         self.METERINTERFACE["valconf"].add(prediction.max(1)[1], gt)
         report_dict = self._eval_report_dict
         loader.set_postfix(report_dict)
     print(f"Validating Epoch {epoch}: {nice_dict(report_dict)}")
     self.writer.add_scalar_with_tag(tag="eval",
                                     tag_scalar_dict=report_dict,
                                     global_step=epoch)
     return self.METERINTERFACE["valconf"].summary()["acc"]
Example #10
0
 def _eval_loop(
     self,
     val_loader: DataLoader = None,
     epoch: int = 0,
     mode=ModelMode.EVAL,
     *args,
     **kwargs,
 ) -> float:
     """Segmentation validation epoch: KL loss plus dice meters; returns
     the first entry of the batch-dice meter's value."""
     self.model.set_mode(mode)
     loader = tqdm_(val_loader)
     loader.set_description(f"Validating Epoch {epoch}")
     for _num, ((image, gt), _path) in enumerate(loader):
         image, gt = image.to(self.device), gt.to(self.device)
         prediction = self.model(image, force_simplex=True)
         onehot_gt = class2one_hot(
             gt.squeeze(1), C=self.model.arch_dict["num_classes"]).float()
         loss = self.kl_criterion(prediction, onehot_gt, disable_assert=True)
         self.METERINTERFACE["valloss"].add(loss.item())
         self.METERINTERFACE["valdice"].add(prediction, gt)
         self.METERINTERFACE["valbdice"].add(prediction, gt)
         report_dict = self._eval_report_dict
         loader.set_postfix(report_dict)
     print(f"Validating Epoch {epoch}: {nice_dict(report_dict)}")
     self.writer.add_scalar_with_tag(tag="eval",
                                     tag_scalar_dict=report_dict,
                                     global_step=epoch)
     return self.METERINTERFACE["valbdice"].value()[0][0].item()
    def _eval_loop(
        self,
        val_loader: DataLoader = None,
        epoch: int = 0,
        mode: ModelMode = ModelMode.EVAL,
        **kwargs,
    ) -> float:
        """Cluster-evaluation epoch: collect per-sub-head hard predictions
        over the whole dataset, remap each sub-head's clusters to targets
        via Hungarian matching, and return the running mean of the best
        sub-head accuracy.
        """
        self.model.set_mode(mode)
        assert (
            not self.model.training
        ), f"Model should be in eval model in _eval_loop, given {self.model.training}."
        val_loader_: tqdm = tqdm_(val_loader)
        # hard predictions per sub-head: (num_sub_heads, num_samples)
        preds = torch.zeros(
            self.model.arch_dict["num_sub_heads"],
            val_loader.dataset.__len__(),
            dtype=torch.long,
            device=self.device,
        )
        # ground-truth labels: (num_samples,)
        target = torch.zeros(val_loader.dataset.__len__(),
                             dtype=torch.long,
                             device=self.device)
        slice_done = 0
        subhead_accs = []
        val_loader_.set_description(f"Validating epoch: {epoch}")
        for batch, image_labels in enumerate(val_loader_):
            images, gt, *_ = list(zip(*image_labels))
            # only the first view/label of each zipped batch is used
            images, gt = images[0].to(self.device), gt[0].to(self.device)
            _pred = self.model.torchnet(images)
            assert (assert_list(simplex, _pred) and _pred.__len__()
                    == self.model.arch_dict["num_sub_heads"])
            # window of rows this batch fills in preds/target
            bSlicer = slice(slice_done, slice_done + images.shape[0])
            for subhead in range(self.model.arch_dict["num_sub_heads"]):
                preds[subhead][bSlicer] = _pred[subhead].max(1)[1]
            target[bSlicer] = gt
            slice_done += gt.shape[0]
        # fails when dataloader.drop_last=True leaves samples unvisited
        assert slice_done == val_loader.dataset.__len__(
        ), "Slice not completed."

        for subhead in range(self.model.arch_dict["num_sub_heads"]):
            # NOTE(review): both sides of the match use output_k_B —
            # confirm this is intended for the head being evaluated.
            reorder_pred, remap = hungarian_match(
                flat_preds=preds[subhead],
                flat_targets=target,
                preds_k=self.model.arch_dict["output_k_B"],
                targets_k=self.model.arch_dict["output_k_B"],
            )
            _acc = flat_acc(reorder_pred, target)
            subhead_accs.append(_acc)
            # record average acc
            self.METERINTERFACE.val_avg_acc.add(_acc)
        # record best acc
        self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
        self.METERINTERFACE.val_worst_acc.add(min(subhead_accs))
        report_dict = self._eval_report_dict

        report_dict_str = ", ".join(
            [f"{k}:{v:.3f}" for k, v in report_dict.items()])
        print(f"Validating epoch: {epoch} : {report_dict_str}")
        return self.METERINTERFACE.val_best_acc.summary()["mean"]
    def _train_loop(self,
                    train_loader=None,
                    epoch=0,
                    mode: ModelMode = ModelMode.TRAIN,
                    **kwargs):
        """IMSAT-style training epoch combining a VAT adversarial loss with
        a negated, weighted mutual-information loss across sub-heads.

        NOTE(review): `sat_loss` is computed per sub-head but never used —
        its accumulation is commented out. Confirm whether the SAT term was
        disabled intentionally before cleaning it up.
        """
        self.model.set_mode(mode)
        assert (
            self.model.training
        ), f"Model should be in train() model, given {self.model.training}."
        train_loader_: tqdm = tqdm_(train_loader)
        train_loader_.set_description(f"Training epoch: {epoch}")
        for batch, image_labels in enumerate(train_loader_):
            images, _, (index, *_) = list(zip(*image_labels))
            # replicate the first view once per remaining view so both
            # stacked batches have matching sizes
            tf1_images = torch.cat(
                [images[0] for _ in range(images.__len__() - 1)],
                dim=0).to(self.device)
            tf2_images = torch.cat(images[1:],
                                   dim=0).to(self.device).to(self.device)
            index = torch.cat([index for _ in range(images.__len__() - 1)],
                              dim=0)

            assert tf1_images.shape == tf2_images.shape
            tf1_pred_logit = self.model.torchnet(tf1_images)
            tf2_pred_logit = self.model.torchnet(tf2_images)
            assert (assert_list(simplex, tf1_pred_logit)
                    and tf1_pred_logit[0].shape == tf2_pred_logit[0].shape)

            sat_losses = []
            ml_losses = []
            for subhead_num, (tf1_pred, tf2_pred) in enumerate(
                    zip(tf1_pred_logit, tf2_pred_logit)):
                sat_loss = self.SAT_criterion(tf2_pred, tf1_pred.detach())
                ml_loss, *_ = self.MI_criterion(tf1_pred)
                # sat_losses.append(sat_loss)
                ml_losses.append(ml_loss)
            # average the MI loss over sub-heads
            ml_losses = sum(ml_losses) / len(ml_losses)
            # sat_losses = sum(sat_losses) / len(sat_losses)

            # VAT_generator = VATLoss_Multihead(eps=self.nearest_dict[index])
            VAT_generator = VATLoss_Multihead(eps=10)
            vat_loss, adv_tf1_images, _ = VAT_generator(
                self.model.torchnet, tf1_images)

            # total objective: adversarial term minus 0.1-weighted MI term
            batch_loss: torch.Tensor = vat_loss - 0.1 * ml_losses

            # self.METERINTERFACE["train_sat_loss"].add(sat_losses.item())
            self.METERINTERFACE["train_mi_loss"].add(ml_losses.item())
            self.METERINTERFACE["train_adv_loss"].add(vat_loss.item())
            self.model.zero_grad()
            batch_loss.backward()
            self.model.step()
            report_dict = self._training_report_dict
            train_loader_.set_postfix(report_dict)
 def _linear_eval_loop(val_loader, epoch) -> Tensor:
     """Evaluate the linear probe on extracted features; returns accuracy."""
     loader = tqdm_(val_loader)
     for _num, (feature, gt) in enumerate(loader):
         feature, gt = feature.to(self.device), gt.to(self.device)
         prediction = linearnet(feature)
         linear_meters["val_acc"].add(prediction.max(1)[1], gt)
         report_dict = {
             "val_acc": linear_meters["val_acc"].summary()["acc"]
         }
         loader.set_postfix(report_dict)
     print(f"Validating epoch {epoch}: {nice_dict(report_dict)} ")
     return linear_meters["val_acc"].summary()["acc"]
Example #14
0
 def _train_loop(
     self, train_loader=None, epoch=0, mode=ModelMode.TRAIN, *args, **kwargs
 ):
     """One autoencoder training epoch minimizing reconstruction loss."""
     self.model.train()
     loader: tqdm = tqdm_(train_loader)
     for _num, batch in enumerate(loader):
         img, _ = batch
         img = img.to(self.device)
         # ===================forward=====================
         reconstruction = self.model(img)
         loss = self.criterion(reconstruction, img)
         # ===================backward====================
         self.model.zero_grad()
         loss.backward()
         self.model.step()
         self.METERINTERFACE.rec_loss.add(loss.item())
         loader.set_postfix(self._training_report_dict())
        def _linear_train_loop(train_loader, epoch):
            """Train the linear probe for one epoch on extracted features."""
            loader = tqdm_(train_loader)
            for _num, (feature, gt) in enumerate(loader):
                feature, gt = feature.to(self.device), gt.to(self.device)
                pred = linearnet(feature)
                loss = self.criterion(pred, gt)
                linearOptim.zero_grad()
                loss.backward()
                linearOptim.step()
                linear_meters["train_loss"].add(loss.item())
                linear_meters["train_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "tra_acc": linear_meters["train_acc"].summary()["acc"],
                    "loss": linear_meters["train_loss"].summary()["mean"],
                }
                loader.set_postfix(report_dict)

            print(f"  Training epoch {epoch}: {nice_dict(report_dict)} ")
Example #16
0
    def _eval_loop(self,
                   val_loader: DataLoader = None,
                   epoch: int = 0,
                   mode=ModelMode.EVAL,
                   *args,
                   **kwargs) -> float:
        """Classification validation epoch; returns the accuracy summary."""
        self.model.set_mode(mode)
        assert not self.model.training
        loader: tqdm = tqdm_(val_loader)

        for _num, (image, label) in enumerate(loader):
            image, label = image.to(self.device), label.to(self.device)
            prediction, _ = self.model(image)
            self.METERINTERFACE.val_conf.add(prediction.max(1)[1], label)
            report_dict = self._eval_report_dict
            loader.set_postfix(report_dict)
        print(f'Validating epoch {epoch}: {nice_dict(report_dict)}')

        return self.METERINTERFACE.val_conf.summary()['acc']
        def _sup_eval_loop(val_loader, epoch) -> Tensor:
            """Supervised evaluation pass; returns the accuracy summary."""
            self.model.eval()
            loader = tqdm_(val_loader)
            for _num, image_gt in enumerate(loader):
                image, gt = zip(*image_gt)
                image, gt = image[0].to(self.device), gt[0].to(self.device)

                # optional sobel pre-filtering of the inputs
                if self.use_sobel:
                    image = self.sobel(image)

                pred = self.model.torchnet(image)[0]
                linear_meters["val_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "val_acc": linear_meters["val_acc"].summary()["acc"]
                }
                loader.set_postfix(report_dict)
            print(f"Validating epoch {epoch}: {nice_dict(report_dict)} ")
            return linear_meters["val_acc"].summary()["acc"]
Example #18
0
    def training(self):
        """Toy training loop on a synthetic 10-class problem; monitors both
        Hungarian-matched and raw accuracies while minimizing the loss."""
        x1, y = make_classification(1000, n_features=10, n_informative=5, n_classes=10)
        x1 = torch.from_numpy(x1).cuda().float()
        y = torch.from_numpy(y).cuda().long()
        progress: tqdm = tqdm_(range(100000))
        for step in progress:
            # perturbed copy of the inputs
            noise = torch.randn_like(x1).cuda()
            x2 = x1 + 0.1 * noise
            p1, p2 = self.model(x1), self.model(x2)
            loss = self._loss_function(x1, p1, x2, p2)
            reordered, _ = hungarian_match(p1.max(1)[1], y, 10, 10)
            print(reordered.unique())
            acc = flat_acc(y, reordered)
            acc2 = flat_acc(y, p1.max(1)[1])

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # periodic visualization every 10 steps
            if step % 10 == 0:
                self.show(p1, p2)
            progress.set_postfix({"loss": loss.item(), "acc": acc, "acc2": acc2})
Example #19
0
    def _train_loop(self,
                    labeled_loader: DataLoader = None,
                    unlabeled_loader: DataLoader = None,
                    epoch: int = 0,
                    mode=ModelMode.TRAIN,
                    *args,
                    **kwargs):
        """Semi-supervised training epoch: cross-entropy on labeled batches
        plus a trainer-specific regularizer using labeled and unlabeled
        batches together.

        NOTE(review): epoch length is fixed by len(unlabeled_loader) via the
        tqdm range; DataIter presumably re-cycles each loader — confirm.
        """
        super(AdaNetTrainer, self)._train_loop(*args, **kwargs)  # warnings
        self.model.set_mode(mode)
        assert self.model.training

        labeled_loader_ = DataIter(labeled_loader)
        unlabeled_loader_ = DataIter(unlabeled_loader)
        # tqdm over the iteration count doubles as the progress bar
        batch_num: tqdm = tqdm_(range(unlabeled_loader.__len__()))

        for _batch_num, ((label_img, label_gt), (unlabel_img, _),
                         _) in enumerate(
                             zip(labeled_loader_, unlabeled_loader_,
                                 batch_num)):
            label_img, label_gt, unlabel_img = label_img.to(self.device), \
                                               label_gt.to(self.device), unlabel_img.to(self.device)

            label_pred, _ = self.model(label_img)
            self.METERINTERFACE.tra_conf.add(label_pred.max(1)[1], label_gt)
            # supervised term on the labeled batch
            sup_loss = self.ce_loss(label_pred, label_gt.squeeze())
            self.METERINTERFACE.tra_sup_label.add(sup_loss.item())

            reg_loss = self._trainer_specific_loss(label_img, label_gt,
                                                   unlabel_img)
            self.METERINTERFACE.tra_reg_total.add(reg_loss.item())
            # context manager drives zero_grad / backward / step
            with ZeroGradientBackwardStep(sup_loss + reg_loss,
                                          self.model) as loss:
                loss.backward()
            report_dict = self._training_report_dict
            batch_num.set_postfix(report_dict)
        print(f'  Training epoch {epoch}: {nice_dict(report_dict)}')
Example #20
0
    def _train_loop(
        self,
        train_loader_A: DataLoader = None,
        train_loader_B: DataLoader = None,
        epoch: int = None,
        mode: ModelMode = ModelMode.TRAIN,
        head_control_param: OrderedDict = None,
        *args,
        **kwargs,
    ) -> None:
        """Train head A and/or head B for one epoch.

        :param train_loader_A: loader used when training head "A".
        :param train_loader_B: loader used when training head "B".
        :param epoch: epoch number, used for progress/report labels.
        :param mode: model mode to switch to (TRAIN by default).
        :param head_control_param: ordered mapping {"A"|"B": iterations >= 0}
            controlling how many passes each head receives.
        :return: None
        """
        # robustness asserts
        assert isinstance(train_loader_B, DataLoader) and isinstance(
            train_loader_A, DataLoader)
        assert (head_control_param and head_control_param.__len__() > 0), \
            f"`head_control_param` must be provided, given {head_control_param}."
        assert set(head_control_param.keys()) <= {"A", "B", }, \
            f"`head_control_param` key must be in `A` or `B`, given {set(head_control_param.keys())}"
        for k, v in head_control_param.items():
            assert k in ("A", "B"), (
                f"`head_control_param` key must be in `A` or `B`,"
                f" given{set(head_control_param.keys())}")
            assert isinstance(
                v, int) and v >= 0, f"Iteration for {k} must be >= 0."
        # set training mode
        self.model.set_mode(mode)
        assert (
            self.model.training
        ), f"Model should be in train() model, given {self.model.training}."
        assert len(train_loader_B) == len(train_loader_A), (
            f'The length of the train_loaders should be the same,"'
            f"given `len(train_loader_A)`:{len(train_loader_A)} and `len(train_loader_B)`:{len(train_loader_B)}."
        )

        # explicit head -> loader mapping replaces the original eval()
        # lookup: no dynamic evaluation, trivially auditable.
        head_loaders = {"A": train_loader_A, "B": train_loader_B}
        for head_name, head_iterations in head_control_param.items():
            assert head_name in ("A", "B"), head_name
            # change the dataset for different head
            train_loader = head_loaders[head_name]
            for head_epoch in range(head_iterations):
                # given one head, one iteration in this head, and one train_loader.
                train_loader_: tqdm = tqdm_(
                    train_loader)  # reinitialize the train_loader
                train_loader_.set_description(
                    f"Training epoch: {epoch} head:{head_name}, head_epoch:{head_epoch + 1}/{head_iterations}"
                )
                for batch, image_labels in enumerate(train_loader_):
                    images, *_ = list(zip(*image_labels))
                    # extract tf1_images, tf2_images and put then to self.device
                    tf1_images = torch.cat(tuple(
                        [images[0] for _ in range(len(images) - 1)]),
                                           dim=0).to(self.device)
                    tf2_images = torch.cat(tuple(images[1:]),
                                           dim=0).to(self.device)
                    assert tf1_images.shape == tf2_images.shape, f"`tf1_images` should have the same size as `tf2_images`," \
                        f"given {tf1_images.shape} and {tf2_images.shape}."
                    # if images are processed with sobel filters
                    if self.use_sobel:
                        tf1_images = self.sobel(tf1_images)
                        tf2_images = self.sobel(tf2_images)
                        assert tf1_images.shape == tf2_images.shape
                    # Here you have two kinds of geometric transformations
                    # todo: functions to be overwritten
                    batch_loss = self._trainer_specific_loss(
                        tf1_images, tf2_images, head_name)
                    # update model with self-defined context manager support Apex module
                    with ZeroGradientBackwardStep(batch_loss,
                                                  self.model) as loss:
                        loss.backward()
                    # write value to tqdm module for system monitoring
                    report_dict = self._training_report_dict
                    train_loader_.set_postfix(report_dict)
        # for tensorboard recording
        self.writer.add_scalar_with_tag("train", report_dict, epoch)
        # for std recording
        print(f"Training epoch: {epoch} : {nice_dict(report_dict)}")
Example #21
0
    def _eval_loop(
        self,
        val_loader: DataLoader = None,
        epoch: int = 0,
        mode: ModelMode = ModelMode.EVAL,
        return_soft_predict=False,
        *args,
        **kwargs,
    ) -> float:
        """Cluster-evaluation epoch over head "B" with Hungarian remapping.

        :param val_loader: loader yielding zip-able (images, gt, ...) batches.
        :param epoch: epoch number for progress/report labels.
        :param mode: model mode to switch to (EVAL by default).
        :param return_soft_predict: when True, additionally collects remapped
            soft predictions on CPU.
        :return: the val_best_acc "mean" summary; when return_soft_predict is
            True, a (mean, (targets, best-head soft predictions)) tuple —
            note this diverges from the `-> float` annotation.
        """
        assert isinstance(
            val_loader, DataLoader)  # make sure a validation loader is passed.
        self.model.set_mode(mode)  # set model to be eval mode, by default.
        # make sure the model is in eval mode.
        assert (
            not self.model.training
        ), f"Model should be in eval model in _eval_loop, given {self.model.training}."
        val_loader_: tqdm = tqdm_(val_loader)
        # prediction initialization with shape: (num_sub_heads, num_samples)
        preds = torch.zeros(self.model.arch_dict["num_sub_heads"],
                            val_loader.dataset.__len__(),
                            dtype=torch.long,
                            device=self.device)
        # soft_prediction initialization with shape (num_sub_heads, num_sample, num_classes)
        if return_soft_predict:
            soft_preds = torch.zeros(
                self.model.arch_dict["num_sub_heads"],
                val_loader.dataset.__len__(),
                self.model.arch_dict["output_k_B"],
                dtype=torch.float,
                device=torch.device("cpu"))  # I put it into cpu
        # target initialization with shape: (num_samples)
        target = torch.zeros(val_loader.dataset.__len__(),
                             dtype=torch.long,
                             device=self.device)
        # begin index
        slice_done = 0
        subhead_accs = []
        val_loader_.set_description(f"Validating epoch: {epoch}")
        for batch, image_labels in enumerate(val_loader_):
            images, gt, *_ = list(zip(*image_labels))
            # only take the tf3 image and gts, put them to self.device
            images, gt = images[0].to(self.device), gt[0].to(self.device)
            # if use sobel filter
            if self.use_sobel:
                images = self.sobel(images)
            # using default head_B for inference, _pred should be a list of simplex by default.
            _pred = self.model.torchnet(images, head="B")
            assert assert_list(simplex,
                               _pred), "pred should be a list of simplexes."
            assert _pred.__len__() == self.model.arch_dict["num_sub_heads"]
            # slice window definition
            bSlicer = slice(slice_done, slice_done + images.shape[0])
            for subhead in range(self.model.arch_dict["num_sub_heads"]):
                # save predictions for each subhead for each batch
                preds[subhead][bSlicer] = _pred[subhead].max(1)[1]
                if return_soft_predict:
                    soft_preds[subhead][bSlicer] = _pred[subhead]
            # save target for each batch
            target[bSlicer] = gt
            # update slice index
            slice_done += gt.shape[0]
        # make sure that all the dataset has been done. Errors will raise if dataloader.drop_last=True
        assert slice_done == val_loader.dataset.__len__(
        ), "Slice not completed."
        for subhead in range(self.model.arch_dict["num_sub_heads"]):
            # remap pred for each head and compare with target to get subhead_acc
            reorder_pred, remap = hungarian_match(
                flat_preds=preds[subhead],
                flat_targets=target,
                preds_k=self.model.arch_dict["output_k_B"],
                targets_k=self.model.arch_dict["output_k_B"],
            )
            _acc = flat_acc(reorder_pred, target)
            subhead_accs.append(_acc)
            # record average acc
            self.METERINTERFACE.val_average_acc.add(_acc)

            if return_soft_predict:
                # permute soft-prediction columns so they follow the remap
                soft_preds[subhead][:, list(remap.values(
                ))] = soft_preds[subhead][:, list(remap.keys())]
                assert torch.allclose(soft_preds[subhead].max(1)[1],
                                      reorder_pred.cpu())

        # record best acc
        self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
        # record worst acc
        self.METERINTERFACE.val_worst_acc.add(min(subhead_accs))
        report_dict = self._eval_report_dict
        # record results for std
        print(f"Validating epoch: {epoch} : {nice_dict(report_dict)}")
        # record results for tensorboard
        self.writer.add_scalar_with_tag("val", report_dict, epoch)
        # using multithreads to call histogram interface of tensorboard.
        pred_histgram(self.writer, preds, epoch=epoch)
        # return the current score to save the best checkpoint.
        if return_soft_predict:
            return self.METERINTERFACE.val_best_acc.summary()["mean"], (
                target.cpu(), soft_preds[np.argmax(subhead_accs)]
            )  # type ignore