def _regulaze(self, images: Tensor, tf_images: Tensor, img_pred_simplex: List[Tensor],
              head_name: str = "B") -> Tensor:
    """
    Parent-class (gaussian) regularization plus a GEO consistency regularization.

    :param images: images on which ``img_pred_simplex`` was computed (forwarded to the parent)
    :param tf_images: transformed counterparts of ``images``; predicted here with ``head_name``
    :param img_pred_simplex: list of per-subhead simplex predictions on ``images``
    :param head_name: head used for inference on ``tf_images``
    :return: parent regularization + GEO regularization
    :raises AssertionError: if predictions are not lists of simplexes or sub-head counts differ
    """
    gaussian_reg = super()._regulaze(images, tf_images, img_pred_simplex, head_name)
    # Infer on the same head as the caller's predictions so that the per-subhead
    # pairs compared by the GEO term come from the same head.
    tf_pred_simplex = self.model.torchnet(tf_images, head=head_name)
    # dimension check: both sides must be lists of simplexes, one entry per sub-head
    assert assert_list(simplex, tf_pred_simplex), "Prediction must be a list of simplexes."
    assert assert_list(simplex, img_pred_simplex), "Prediction must be a list of simplexes."
    assert len(tf_pred_simplex) == len(img_pred_simplex), (
        f"Sub-head count mismatch: {len(img_pred_simplex)} vs {len(tf_pred_simplex)}."
    )
    geo_reg = self._geo_regularization(img_pred_simplex, tf_pred_simplex)
    self.METERINTERFACE["train_geo"].add(geo_reg.item())
    return gaussian_reg + geo_reg
 def _trainer_specific_loss(
         self, tf1_images: Tensor, tf2_images: Tensor, head_name: str
 ) -> Tensor:
     """
     IMSAT loss: down-weighted, negated mutual information on ``tf1_images`` plus regularization.

     The regularization term comes from ``self._regulaze``, which child classes
     override (VAT, Mixup, GEO, ...).

     :param tf1_images: basic transformed images, already on ``self.device``
     :param tf2_images: advanced transformed images, already on ``self.device``
     :param head_name: head used for inference; only ``"B"`` is accepted
     :return: scalar loss tensor on which ``.backward()`` is called
     """
     assert (head_name == "B"), "Only head B is supported in IMSAT, try to set head_control_parameter as {`B`:1}"
     # The MI term is computed on the basic transformations only.
     head_preds = self.model.torchnet(tf1_images, head=head_name)
     assert assert_list(simplex, head_preds), "Prediction must be a list of simplexes."

     mi_per_head: List[Tensor] = []
     entropy_per_head: List[Tensor] = []
     centropy_per_head: List[Tensor] = []
     for head_pred in head_preds:
         mi_value, (entropy_value, centropy_value) = self.criterion(head_pred)
         mi_per_head.append(mi_value)
         entropy_per_head.append(entropy_value)
         centropy_per_head.append(centropy_value)

     # Average each quantity over sub-heads; the MI objective is to be maximized.
     mean_mi = sum(mi_per_head) / len(mi_per_head)  # type: ignore
     mean_entropy = sum(entropy_per_head) / len(entropy_per_head)  # type: ignore
     mean_centropy = sum(centropy_per_head) / len(centropy_per_head)  # type: ignore
     for meter_name, meter_value in (
             ("train_mi", mean_mi),
             ("train_entropy", mean_entropy),
             ("train_centropy", mean_centropy),
     ):
         self.METERINTERFACE[meter_name].add(meter_value.item())  # type: ignore

     # Regularization (VAT, Mixup, GEO, ...) supplied by the concrete trainer.
     reg_loss = self._regulaze(tf1_images, tf2_images, head_preds, head_name)
     # MI is down-weighted by 0.1, following the IMSAT chainer implementation.
     return -mean_mi * 0.1 + reg_loss
    def _eval_loop(
        self,
        val_loader: DataLoader = None,
        epoch: int = 0,
        mode: ModelMode = ModelMode.EVAL,
        **kwargs,
    ) -> float:
        """
        Score every sub-head on ``val_loader`` and record the accuracies in the meters.

        Hard (argmax) predictions of each sub-head over the whole dataset are
        matched to the labels with ``hungarian_match`` and scored with
        ``flat_acc``. Each sub-head accuracy is added to ``val_avg_acc``; the
        best and worst go to ``val_best_acc`` / ``val_worst_acc``.

        :param val_loader: validation loader; each batch unzips into (images, labels, ...)
        :param epoch: epoch number shown in the progress bar and the printed report
        :param mode: mode applied to the model; it must leave the model out of training mode
        :return: mean of the ``val_best_acc`` meter
        """
        self.model.set_mode(mode)
        assert (
            not self.model.training
        ), f"Model should be in eval model in _eval_loop, given {self.model.training}."
        num_heads = self.model.arch_dict["num_sub_heads"]
        num_classes = self.model.arch_dict["output_k_B"]
        num_samples = len(val_loader.dataset)
        progress: tqdm = tqdm_(val_loader)
        # one row of hard predictions per sub-head, one column per sample
        head_preds = torch.zeros(num_heads, num_samples, dtype=torch.long, device=self.device)
        labels = torch.zeros(num_samples, dtype=torch.long, device=self.device)
        cursor = 0
        head_accs = []
        progress.set_description(f"Validating epoch: {epoch}")
        for image_labels in progress:
            images, gt, *_ = list(zip(*image_labels))
            # only the first entry of the image and label groups is evaluated
            images, gt = images[0].to(self.device), gt[0].to(self.device)
            outputs = self.model.torchnet(images)
            assert (assert_list(simplex, outputs) and len(outputs) == num_heads)
            window = slice(cursor, cursor + images.shape[0])
            for head_idx in range(num_heads):
                head_preds[head_idx][window] = outputs[head_idx].max(1)[1]
            labels[window] = gt
            cursor += gt.shape[0]
        # fails if the loader skipped samples (e.g. drop_last=True)
        assert cursor == num_samples, "Slice not completed."

        for head_idx in range(num_heads):
            matched_pred, _ = hungarian_match(
                flat_preds=head_preds[head_idx],
                flat_targets=labels,
                preds_k=num_classes,
                targets_k=num_classes,
            )
            head_acc = flat_acc(matched_pred, labels)
            head_accs.append(head_acc)
            self.METERINTERFACE.val_avg_acc.add(head_acc)
        self.METERINTERFACE.val_best_acc.add(max(head_accs))
        self.METERINTERFACE.val_worst_acc.add(min(head_accs))
        report_dict = self._eval_report_dict

        report_str = ", ".join(f"{k}:{v:.3f}" for k, v in report_dict.items())
        print(f"Validating epoch: {epoch} : {report_str}")
        return self.METERINTERFACE.val_best_acc.summary()["mean"]
示例#4
0
 def _geo_regularization(self, tf1_pred_simplex, tf2_pred_simplex) -> Tensor:
     """
     Mean KL consistency over sub-heads between advanced and basic predictions.

     For each sub-head, ``self.kl_div`` compares the advanced prediction with the
     detached basic prediction, so the basic prediction acts as a fixed target.

     :param tf1_pred_simplex: per-subhead simplex predictions on basic transforms
     :param tf2_pred_simplex: per-subhead simplex predictions on advanced transforms
     :return: KL term averaged over sub-heads
     """
     assert (assert_list(simplex, tf1_pred_simplex)
             and assert_list(simplex, tf2_pred_simplex)
             and len(tf1_pred_simplex) == len(tf2_pred_simplex)
             ), "Error on tf1 and tf2 predictions."
     per_head_kl: List[torch.Tensor] = [
         self.kl_div(tf2_pred_simplex[idx], tf1_pred_simplex[idx].detach())
         for idx in range(len(tf1_pred_simplex))
     ]
     mean_kl: torch.Tensor = sum(per_head_kl) / len(per_head_kl)  # type: ignore
     return mean_kl
    def _train_loop(self,
                    train_loader=None,
                    epoch=0,
                    mode: ModelMode = ModelMode.TRAIN,
                    **kwargs):
        """
        Run one training epoch optimizing ``vat_loss - 0.1 * MI``.

        Each batch provides one basic-transform group ``images[0]`` and several
        advanced-transform groups ``images[1:]``. The basic group is repeated once
        per advanced group so that ``tf1_images`` and ``tf2_images`` have the same
        shape. The MI term (``self.MI_criterion``) is averaged over sub-heads on the
        ``tf1_images`` predictions. The VAT term comes from
        ``VATLoss_Multihead(eps=10)`` applied to ``self.model.torchnet``.

        :param train_loader: loader whose batches unzip into (images, labels, (index, ...))
        :param epoch: epoch number shown in the progress bar
        :param mode: mode applied to the model; it must leave the model in training mode
        """
        self.model.set_mode(mode)
        assert (
            self.model.training
        ), f"Model should be in train() model, given {self.model.training}."
        train_loader_: tqdm = tqdm_(train_loader)
        train_loader_.set_description(f"Training epoch: {epoch}")
        for image_labels in train_loader_:
            # The sample index (third field) would only be needed by a per-sample-eps
            # VAT variant (eps=self.nearest_dict[index]), which is not enabled.
            images, _, (_index, *_) = list(zip(*image_labels))
            num_tf2_groups = len(images) - 1
            tf1_images = torch.cat([images[0]] * num_tf2_groups, dim=0).to(self.device)
            tf2_images = torch.cat(images[1:], dim=0).to(self.device)

            assert tf1_images.shape == tf2_images.shape
            tf1_pred_logit = self.model.torchnet(tf1_images)
            # The tf2 forward pass is kept. It feeds the shape check below, and removing it
            # would change how often the network runs in train mode (e.g. normalization stats).
            tf2_pred_logit = self.model.torchnet(tf2_images)
            assert (assert_list(simplex, tf1_pred_logit)
                    and tf1_pred_logit[0].shape == tf2_pred_logit[0].shape)

            # MI per sub-head on tf1 predictions. The SAT consistency term
            # (self.SAT_criterion(tf2_pred, tf1_pred.detach())) is disabled, so it is
            # not computed here only to be discarded.
            ml_losses = []
            for tf1_pred, _tf2_pred in zip(tf1_pred_logit, tf2_pred_logit):
                ml_loss, *_ = self.MI_criterion(tf1_pred)
                ml_losses.append(ml_loss)
            mean_ml_loss = sum(ml_losses) / len(ml_losses)

            vat_generator = VATLoss_Multihead(eps=10)
            vat_loss, _adv_tf1_images, _ = vat_generator(
                self.model.torchnet, tf1_images)

            # MI is maximized, hence subtracted; it is down-weighted by 0.1.
            batch_loss: torch.Tensor = vat_loss - 0.1 * mean_ml_loss

            self.METERINTERFACE["train_mi_loss"].add(mean_ml_loss.item())
            self.METERINTERFACE["train_adv_loss"].add(vat_loss.item())
            self.model.zero_grad()
            batch_loss.backward()
            self.model.step()
            train_loader_.set_postfix(self._training_report_dict)
    def _regulaze(
            self,
            images: Tensor,
            tf_images: Tensor,
            img_pred_simplex: List[Tensor],
            head_name: str = "B",
    ) -> Tensor:
        """
        Parent-class regularization plus an IIC term between predictions on ``images`` and ``tf_images``.

        :param images: images on which ``img_pred_simplex`` was computed (forwarded to the parent)
        :param tf_images: transformed images; predicted here with ``head_name``
        :param img_pred_simplex: list of per-subhead simplex predictions on ``images``
        :param head_name: head used for inference on ``tf_images``
        :return: mean IIC loss over sub-heads plus the parent's regularization
        :raises AssertionError: if the ``tf_images`` predictions are not a list of simplexes
            or their sub-head count differs from ``img_pred_simplex``
        """
        # parent-class regularization (named as the VAT term here)
        vat_loss = super()._regulaze(images, tf_images, img_pred_simplex, head_name)
        tf_img_pred_simplex = self.model.torchnet(tf_images, head=head_name)
        # assert_list only returns a flag; it has to be asserted for the check to fire.
        assert assert_list(simplex, tf_img_pred_simplex), "Prediction must be a list of simplexes."
        # the per-subhead pairing below needs one tf prediction per original prediction
        assert len(tf_img_pred_simplex) == len(img_pred_simplex), (
            f"Sub-head count mismatch: {len(img_pred_simplex)} vs {len(tf_img_pred_simplex)}."
        )

        # IIC loss per sub-head; the second value returned by IIC_loss is not used.
        iic_losses: List[Tensor] = []
        for img_pred, tf_pred in zip(img_pred_simplex, tf_img_pred_simplex):
            iic_loss, _ = self.IIC_loss(img_pred, tf_pred)
            iic_losses.append(iic_loss)
        batch_loss: Tensor = sum(iic_losses) / len(iic_losses)  # type: ignore
        # the meter records the negated IIC loss
        self.METERINTERFACE[f"train_head_{head_name}"].add(-batch_loss.item())  # type: ignore
        return batch_loss + vat_loss
示例#7
0
 def _gaussian_regularization(self,
                              model: Model,
                              tf1_images,
                              tf1_pred_simplex: List[Tensor],
                              head_name="B") -> Tensor:
     """
     KL consistency between predictions on noise-perturbed images and the clean predictions.

     ``self.gaussian_adder`` perturbs ``tf1_images`` and ``model.torchnet`` predicts
     on the perturbed batch with ``head_name``. For every sub-head, ``self.kl_div``
     compares the perturbed prediction against the detached clean prediction.

     :param model: model whose ``torchnet`` runs the perturbed forward pass
     :param tf1_images: tf1-transformed images
     :param tf1_pred_simplex: list of per-subhead simplex predictions on ``tf1_images``
     :param head_name: head used for the perturbed forward pass
     :return: KL term averaged over sub-heads
     """
     noisy_images = self.gaussian_adder(tf1_images)
     noisy_pred_simplex = model.torchnet(noisy_images, head=head_name)
     assert assert_list(simplex, tf1_pred_simplex)
     assert assert_list(simplex, noisy_pred_simplex)
     assert len(tf1_pred_simplex) == len(noisy_pred_simplex)
     kl_terms = [
         self.kl_div(noisy_pred, clean_pred.detach())
         for clean_pred, noisy_pred in zip(tf1_pred_simplex, noisy_pred_simplex)
     ]
     return sum(kl_terms) / len(kl_terms)  # type: ignore
 def _regulaze(
         self,
         images: Tensor,
         tf_images: Tensor,
         img_pred_simplex: List[Tensor],
         head_name: str = "B",
 ) -> Tensor:
     """
     GEO regularization between predictions on ``tf_images`` and ``img_pred_simplex``.

     ``images`` is not used by this implementation.

     :param images: original images (unused here)
     :param tf_images: advanced-transformed images; predicted here with ``head_name``
     :param img_pred_simplex: list of per-subhead simplex predictions on the original images
     :param head_name: head used for inference on ``tf_images``
     :return: GEO loss, which is also logged to the ``train_geo`` meter
     """
     # predictions on the advanced-transformed images
     tf_predictions = self.model.torchnet(tf_images, head=head_name)
     assert assert_list(simplex, tf_predictions) and len(tf_predictions) == len(img_pred_simplex)
     loss = self._geo_regularization(img_pred_simplex, tf_predictions)
     self.METERINTERFACE["train_geo"].add(loss.item())
     # The GEO term is returned unweighted (kept 1:1 for simplicity).
     return loss
示例#9
0
 def __init__(
     self,
     root_dir: str,
     mode: str,
     subfolders: List[str],
     transforms: SequentialWrapper = None,
     patient_pattern: str = None,
     verbose=True,
 ) -> None:
     """
     Build the dataset for one split (``mode``) of ``root_dir``.

     :param root_dir: main folder path of the dataset
     :param mode: sub-folder of ``root_dir`` for this split, usually train, val, test, etc.
     :param subfolders: unique sub-sub-folder names of this root, usually img, gt, etc.
     :param transforms: synchronized transformation for all the subfolders; when falsy,
         ``default_transform(subfolders)`` is used instead
     :param patient_pattern: forwarded to ``self._set_patient_pattern``
         (presumably a pattern grouping files per patient — TODO confirm)
     :param verbose: print build progress when truthy
     """
     unique_count = len(set(subfolders))
     assert (len(subfolders) == unique_count
             ), f"subfolders must be unique, given {subfolders}."
     assert assert_list(
         lambda item: isinstance(item, str), subfolders
     ), f"`subfolder` elements should be str, given {subfolders}"
     self._name: str = f"{mode}_dataset"
     self._mode: str = mode
     self._root_dir = root_dir
     self._subfolders: List[str] = subfolders
     # the default transform is always built; a truthy `transforms` overrides it
     fallback_transform = default_transform(self._subfolders)
     self._transform = transforms or fallback_transform
     self._verbose = verbose
     if verbose:
         print(f"->> Building {self._name}:\t")
     self._filenames = self._make_dataset(
         self._root_dir, self._mode, self._subfolders, verbose=verbose
     )
     # debug flag driven by the PYDEBUG environment variable
     self._debug = os.environ.get("PYDEBUG", "0") == "1"
     self._set_patient_pattern(patient_pattern)
示例#10
0
    def _eval_loop(
        self,
        val_loader: DataLoader = None,
        epoch: int = 0,
        mode: ModelMode = ModelMode.EVAL,
        return_soft_predict=False,
        *args,
        **kwargs,
    ) -> float:
        """
        Evaluate every sub-head of head ``B`` on ``val_loader``.

        Hard predictions for all samples are collected per sub-head, matched to the
        labels with ``hungarian_match`` and scored with ``flat_acc``. The average,
        best and worst sub-head accuracies are recorded in the meters. The report is
        printed and written to ``self.writer``, and ``pred_histgram`` logs the
        hard predictions.

        :param val_loader: validation ``DataLoader`` (required)
        :param epoch: epoch number used for logging
        :param mode: mode applied to the model; it must leave the model out of training mode
        :param return_soft_predict: also return soft predictions of the best sub-head
        :return: mean of the ``val_best_acc`` meter. When ``return_soft_predict`` is truthy,
            a tuple ``(score, (targets_on_cpu, soft_preds_of_best_subhead))`` is returned;
            the soft predictions' class columns are permuted by the Hungarian mapping.
        """
        assert isinstance(
            val_loader, DataLoader)  # a validation loader must be passed
        self.model.set_mode(mode)
        assert (
            not self.model.training
        ), f"Model should be in eval model in _eval_loop, given {self.model.training}."
        n_heads = self.model.arch_dict["num_sub_heads"]
        n_classes = self.model.arch_dict["output_k_B"]
        n_samples = len(val_loader.dataset)
        progress: tqdm = tqdm_(val_loader)
        # hard predictions: (num_sub_heads, num_samples)
        hard_preds = torch.zeros(n_heads, n_samples, dtype=torch.long, device=self.device)
        if return_soft_predict:
            # soft predictions, kept on CPU: (num_sub_heads, num_samples, num_classes)
            soft_preds = torch.zeros(n_heads,
                                     n_samples,
                                     n_classes,
                                     dtype=torch.float,
                                     device=torch.device("cpu"))
        # targets: (num_samples,)
        labels = torch.zeros(n_samples, dtype=torch.long, device=self.device)
        offset = 0
        head_accs = []
        progress.set_description(f"Validating epoch: {epoch}")
        for image_labels in progress:
            images, gt, *_ = list(zip(*image_labels))
            # only the first image/label group is evaluated
            images, gt = images[0].to(self.device), gt[0].to(self.device)
            if self.use_sobel:
                images = self.sobel(images)
            # head B inference; expected to yield one simplex per sub-head
            head_outputs = self.model.torchnet(images, head="B")
            assert assert_list(simplex,
                               head_outputs), "pred should be a list of simplexes."
            assert len(head_outputs) == n_heads
            window = slice(offset, offset + images.shape[0])
            for h in range(n_heads):
                hard_preds[h][window] = head_outputs[h].max(1)[1]
                if return_soft_predict:
                    soft_preds[h][window] = head_outputs[h]
            labels[window] = gt
            offset += gt.shape[0]
        # fails if the loader skipped samples (e.g. drop_last=True)
        assert offset == n_samples, "Slice not completed."

        for h in range(n_heads):
            # remap each sub-head's clusters onto labels before scoring
            matched_pred, remap = hungarian_match(
                flat_preds=hard_preds[h],
                flat_targets=labels,
                preds_k=n_classes,
                targets_k=n_classes,
            )
            acc = flat_acc(matched_pred, labels)
            head_accs.append(acc)
            self.METERINTERFACE.val_average_acc.add(acc)
            if return_soft_predict:
                # permute class columns so the soft argmax agrees with matched_pred
                permuted_cols = soft_preds[h][:, list(remap.keys())]
                soft_preds[h][:, list(remap.values())] = permuted_cols
                assert torch.allclose(soft_preds[h].max(1)[1], matched_pred.cpu())

        self.METERINTERFACE.val_best_acc.add(max(head_accs))
        self.METERINTERFACE.val_worst_acc.add(min(head_accs))
        report_dict = self._eval_report_dict
        print(f"Validating epoch: {epoch} : {nice_dict(report_dict)}")
        self.writer.add_scalar_with_tag("val", report_dict, epoch)
        pred_histgram(self.writer, hard_preds, epoch=epoch)
        # score used by the caller to keep the best checkpoint
        best_score = self.METERINTERFACE.val_best_acc.summary()["mean"]
        if not return_soft_predict:
            return best_score
        return best_score, (labels.cpu(), soft_preds[np.argmax(head_accs)])  # type: ignore