def test_load_torch(self):
     """Smoke-test the slice viewer on torch tensors.

     Shows, in order: one image, an image pair with no ground truths, one image
     with three ground-truth volumes, and an image pair sharing those three
     ground truths. Every window is opened with ``block=True``.
     """
     gt_volumes = (self._l1, self._l2, self._l3)
     # Single image, no ground truth.
     multi_slice_viewer_debug(self._t1, block=True)
     # Image pair given as a list, no ground truth.
     multi_slice_viewer_debug([self._t1, self._t2], block=True)
     # Single image with the three ground-truth volumes.
     multi_slice_viewer_debug(self._t1, *gt_volumes, block=True)
     # Image pair sharing the same three ground-truth volumes.
     multi_slice_viewer_debug([self._t1, self._t2], *gt_volumes, block=True)
    def test_semi_dataloader(self):
        """Build the iSeg-2017 semi-supervised loaders and display one batch of each.

        Each batch unpacks as ``((T1, T2, Labels), filename)``. ``T1`` and ``T2``
        are image modalities and ``Labels`` is the ground-truth volume handed to
        the viewer. All windows are opened with ``block=False``.
        """
        iseg_handler = ISeg2017SemiInterface(
            root_dir=self.dataset_root,
            labeled_data_ratio=0.8,
            unlabeled_data_ratio=0.2,
            seed=0,
            verbose=True,
        )
        iseg_handler.compile_dataloader_params(
            batch_size=20, shuffle=True, drop_last=True
        )

        (
            labeled_loader,
            unlabeled_loader,
            val_loader,
        ) = iseg_handler.SemiSupervisedDataLoaders(
            labeled_transform=None,
            unlabeled_transform=None,
            val_transform=None,
            group_labeled=True,
            group_unlabeled=True,
            group_val=True,
        )

        from deepclustering.viewer import multi_slice_viewer_debug

        # Labeled batch: view each modality separately against the labels.
        (T1, T2, Labels), _ = next(iter(labeled_loader))
        multi_slice_viewer_debug(T1.squeeze(), Labels.squeeze(), block=False)
        multi_slice_viewer_debug(T2.squeeze(), Labels.squeeze(), block=False)

        # Unlabeled and validation batches: T2 is an image, not a label map, so it
        # goes into the image list next to T1 and only ``Labels`` fills a
        # ground-truth slot (viewer signature: ``(images, *gt_volumes, ...)``).
        for loader in (unlabeled_loader, val_loader):
            (T1, T2, Labels), _ = next(iter(loader))
            multi_slice_viewer_debug(
                [T1.squeeze(), T2.squeeze()], Labels.squeeze(), block=False
            )
    def test_gt(self):
        """Visually compare network predictions on several variants of ``self._img``.

        Displays the argmax prediction (over dim 1) for the original input, for a
        deep copy of ``self._network``, for an FSGM adversarial input, and for the
        input plus random sign noise of magnitude 0.01. Asserts that the copied
        network reproduces the original predictions and that the adversarial
        image keeps the input's shape. Only the final window blocks.
        """

        def show(prediction, block=False):
            # Display image: swap dims 1 and 3, then dims 1 and 2; recomputed on
            # every call so each viewer call gets a fresh view of ``self._img``.
            multi_slice_viewer_debug(
                self._img.transpose(1, 3).transpose(1, 2),
                prediction.max(1)[1],
                block=block,
                no_contour=False,
            )

        original_logits = self._network(self._img)
        show(original_logits)

        from copy import deepcopy

        copied_logits = deepcopy(self._network)(self._img)
        show(copied_logits)
        assert torch.allclose(original_logits, copied_logits)

        # ``eplision`` is the generator's own keyword name (spelling kept as-is).
        adversary = FSGMGenerator(net=self._network, eplision=0.01)
        adversarial_img, _ = adversary(self._img, gt=None)
        assert adversarial_img.shape == self._img.shape
        show(self._network(adversarial_img))

        sign_noise = torch.randn_like(self._img).sign() * 0.01
        show(self._network(self._img + sign_noise), block=True)
    def test_semi_dataloader(self):
        """Build the WMH semi-supervised loaders and display one batch of each.

        Each batch unpacks as ``((t1, flair, gt), filename)``. The T1 and FLAIR
        volumes are passed together as the viewer's image argument and ``gt`` as
        the single ground-truth volume. Only the validation window blocks.
        """
        # NOTE(review): another method named ``test_semi_dataloader`` appears
        # alongside this one; if both live in the same class, the later
        # definition silently replaces the earlier and only one of them runs.
        # Confirm they belong to separate test classes/files.
        wmh_handler = WMHSemiInterface(
            root_dir=self.dataset_root,
            labeled_data_ratio=0.8,
            unlabeled_data_ratio=0.2,
            seed=0,
            verbose=True,
        )
        wmh_handler.compile_dataloader_params(
            batch_size=20, shuffle=True, drop_last=True
        )

        (
            labeled_loader,
            unlabeled_loader,
            val_loader,
        ) = wmh_handler.SemiSupervisedDataLoaders(
            group_labeled=False,
            group_unlabeled=True,
            group_val=True,
        )

        from deepclustering.viewer import multi_slice_viewer_debug

        for loader, block in (
            (labeled_loader, False),
            (unlabeled_loader, False),
            (val_loader, True),
        ):
            (t1, flair, gt), _ = next(iter(loader))
            # FLAIR is an image modality, not a label map: pass it next to T1 as
            # the image argument so only ``gt`` fills a ground-truth slot
            # (viewer signature: ``(images, *gt_volumes, ...)``).
            multi_slice_viewer_debug(
                [t1.squeeze(), flair.squeeze()], gt.squeeze(), block=block
            )
    def test_load_numpy(self):
        """Smoke-test the slice viewer on NumPy arrays converted from the torch fixtures.

        Each image container (a single image, a list pair, a tuple pair) is shown
        first without and then with the three ground-truth volumes. A final call
        repeats the tuple pair with ``no_contour=True``. Every window blocks.
        """
        t1, t2, l1, l2, l3 = (
            torch2numpy(tensor)
            for tensor in (self._t1, self._t2, self._l1, self._l2, self._l3)
        )

        # Without ground truth.
        for images in (t1, [t1, t2], (t1, t2)):
            multi_slice_viewer_debug(images, block=True)

        # With the same three ground-truth volumes.
        for images in (t1, [t1, t2], (t1, t2)):
            multi_slice_viewer_debug(images, l1, l2, l3, block=True)

        multi_slice_viewer_debug(
            (t1, t2), l1, l2, l3, block=True, no_contour=True
        )