    def _run(self, *args, **kwargs) -> EpochResultDict:
        self._model.train()
        self._teacher_model.train()
        assert self._model.training, self._model.training
        assert self._teacher_model.training, self._teacher_model.training
        self.meters["lr"].add(self._optimizer.param_groups[0]["lr"])
        self.meters["reg_weight"].add(self._reg_weight)
        report_dict: EpochResultDict

        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
            for i, label_data, all_data in zip(indicator, self._labeled_loader, self._tra_loader):
                (labelimage, labeltarget), _, filename, partition_list, group_list \
                    = self._preprocess_data(label_data, self._device)
                (unlabelimage, _), *_ = self._preprocess_data(all_data, self._device)

                seed = random.randint(0, int(1e6))
                with FixRandomSeed(seed):
                    unlabelimage_tf = torch.stack([self._transformer(x) for x in unlabelimage], dim=0)
                assert unlabelimage_tf.shape == unlabelimage.shape

                student_logits = self._model(torch.cat([labelimage, unlabelimage_tf], dim=0))
                if simplex(student_logits):
                    raise RuntimeError("model output should be logits, not a simplex")
                student_sup_logits, student_unlabel_logits_tf = student_logits[:len(labelimage)], \
                                                                student_logits[len(labelimage):]

                with torch.no_grad():
                    teacher_unlabel_logits = self._teacher_model(unlabelimage)
                with FixRandomSeed(seed):
                    teacher_unlabel_logits_tf = torch.stack([self._transformer(x) for x in teacher_unlabel_logits])
                assert teacher_unlabel_logits.shape == teacher_unlabel_logits_tf.shape

                # calculate the loss
                onehot_ltarget = class2one_hot(labeltarget.squeeze(1), 4)
                sup_loss = self._sup_criterion(student_sup_logits.softmax(1), onehot_ltarget)

                reg_loss = self._reg_criterion(student_unlabel_logits_tf.softmax(1),
                                               teacher_unlabel_logits_tf.detach().softmax(1))
                total_loss = sup_loss + self._reg_weight * reg_loss

                self._optimizer.zero_grad()
                total_loss.backward()
                self._optimizer.step()

                # update ema
                self._ema_updater(ema_model=self._teacher_model, student_model=self._model)

                with torch.no_grad():
                    self.meters["sup_loss"].add(sup_loss.item())
                    self.meters["reg_loss"].add(reg_loss.item())
                    self.meters["ds"].add(student_sup_logits.max(1)[1], labeltarget.squeeze(1),
                                          group_name=list(group_list))
                    report_dict = self.meters.tracking_status()
                    indicator.set_postfix_dict(report_dict)
            report_dict = self.meters.tracking_status()
        return report_dict
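Every example below relies on FixRandomSeed to replay the exact same stochastic transform, here on the student's input and on the teacher's prediction. A minimal sketch of the contract it has to satisfy, assuming it seeds and then restores the Python, NumPy and torch RNG states (the project's own implementation may differ):

import random

import numpy as np
import torch


class FixRandomSeed:
    """Seed all RNGs on entry, restore the previous RNG state on exit,
    so the same stochastic transform can be replayed deterministically."""

    def __init__(self, seed: int = 0):
        self.seed = seed

    def __enter__(self):
        # save the current RNG states (CUDA RNG is ignored in this sketch)
        self._py_state = random.getstate()
        self._np_state = np.random.get_state()
        self._torch_state = torch.random.get_rng_state()
        # reseed everything with the fixed seed
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        return self

    def __exit__(self, *exc):
        # restore the saved states so surrounding code keeps its randomness
        random.setstate(self._py_state)
        np.random.set_state(self._np_state)
        torch.random.set_rng_state(self._torch_state)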
Example #2
    def _run(self, *args, **kwargs) -> EpochResultDict:
        self._model.train()
        assert self._model.training, self._model.training
        self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])

        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
            for i, data in zip(indicator, self._pretrain_decoder_loader):
                (img, _), (img_ctf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device)
                seed = random.randint(0, int(1e5))
                with FixRandomSeed(seed):
                    img_gtf = torch.stack([self._transformer(x) for x in img], dim=0)
                assert img_gtf.shape == img.shape, (img_gtf.shape, img.shape)
                _, *features = self._model(torch.cat([img_gtf, img_ctf], dim=0), return_features=True)
                dn = self._feature_extractor(features)[0]
                dn_gtf, dn_ctf = torch.chunk(dn, chunks=2, dim=0)
                with FixRandomSeed(seed):
                    dn_ctf_gtf = torch.stack([self._transformer(x) for x in dn_ctf], dim=0)
                assert dn_ctf_gtf.shape == dn_ctf.shape, (dn_ctf_gtf.shape, dn_ctf.shape)
                dn_tf = torch.cat([dn_gtf, dn_ctf_gtf])
                local_enc_tf, local_enc_tf_ctf = torch.chunk(self._projection_head(dn_tf), chunks=2, dim=0)
                # todo: convert representation to distance
                local_enc_unfold, _ = unfold_position(local_enc_tf, partition_num=(2, 2))
                local_tf_enc_unfold, _fold_partition = unfold_position(local_enc_tf_ctf, partition_num=(2, 2))
                b, *_ = local_enc_unfold.shape
                local_enc_unfold_norm = F.normalize(local_enc_unfold.view(b, -1), p=2, dim=1)
                local_tf_enc_unfold_norm = F.normalize(local_tf_enc_unfold.view(b, -1), p=2, dim=1)

                labels = self._label_generation(partition_list, group_list, _fold_partition)
                contrastive_loss = self._contrastive_criterion(
                    torch.stack([local_enc_unfold_norm, local_tf_enc_unfold_norm], dim=1),
                    labels=labels
                )
                if torch.isnan(contrastive_loss):
                    raise RuntimeError(contrastive_loss)
                self._optimizer.zero_grad()
                contrastive_loss.backward()
                self._optimizer.step()
                # todo: meter recording.
                with torch.no_grad():
                    self.meters["contrastive_loss"].add(contrastive_loss.item())
                    report_dict = self.meters.tracking_status()
                    indicator.set_postfix_dict(report_dict)
        return report_dict
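In the decoder pretraining above, unfold_position cuts each projected feature map into 2 x 2 spatial patches and reports each patch's position for label generation. A hypothetical sketch of that assumed behaviour (the signature and return format are guesses, not the project's code):

from typing import List, Tuple

import torch
from torch import Tensor


def unfold_position(features: Tensor, partition_num: Tuple[int, int] = (2, 2)) -> Tuple[Tensor, List[int]]:
    # Split a (B, C, H, W) map into partition_num[0] x partition_num[1]
    # spatial patches stacked along the batch axis, and return the
    # patch-position index of every resulting sample (hypothetical).
    b, c, h, w = features.shape
    ph, pw = h // partition_num[0], w // partition_num[1]
    patches: List[Tensor] = []
    positions: List[int] = []
    for pos in range(partition_num[0] * partition_num[1]):
        i, j = divmod(pos, partition_num[1])
        patches.append(features[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw])
        positions.extend([pos] * b)
    return torch.cat(patches, dim=0), positions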
    def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, global_seed=None, **kwargs):
        global_seed = int(random.randint(0, int(1e5))) if global_seed is None else int(global_seed)  # type: ignore
        with FixRandomSeed(global_seed):
            comm_seed1, comm_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5)))
            img_seed1, img_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5)))
            target_seed1, target_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5)))
            if self._total_freedom:
                return [
                    super().__call__(imgs, targets, comm_seed1, img_seed1, target_seed1),
                    super().__call__(imgs, targets, comm_seed2, img_seed2, target_seed2),
                ]
            return [
                super().__call__(imgs, targets, comm_seed1, img_seed1, target_seed1),
                super().__call__(imgs, targets, comm_seed1, img_seed2, target_seed1),
            ]
    def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, comm_seed=None, img_seed=None,
                 target_seed=None):
        _comm_seed: int = int(random.randint(0, int(1e5))) if comm_seed is None else int(comm_seed)  # type: ignore
        imgs_after_comm, targets_after_comm = imgs, targets
        if self._comm_transform:
            imgs_after_comm, targets_after_comm = [], []
            for img in imgs:
                with FixRandomSeed(_comm_seed):
                    img_ = self._comm_transform(img)
                    imgs_after_comm.append(img_)
            if targets:
                for target in targets:
                    with FixRandomSeed(_comm_seed):
                        target_ = self._comm_transform(target)
                        targets_after_comm.append(target_)
        imgs_after_img_transform = []
        targets_after_target_transform = []
        _img_seed: int = int(random.randint(0, int(1e5))) if img_seed is None else int(img_seed)  # type: ignore
        for img in imgs_after_comm:
            with FixRandomSeed(_img_seed):
                img_ = self._img_transform(img)
                imgs_after_img_transform.append(img_)

        _target_seed: int = int(random.randint(0, int(1e5))) if target_seed is None else int(target_seed)  # type: ignore
        if targets_after_comm:
            for target in targets_after_comm:
                with FixRandomSeed(_target_seed):
                    target_ = self._target_transform(target)
                    targets_after_target_transform.append(target_)

        if targets is None:
            targets_after_target_transform = None

        if targets_after_target_transform is None:
            return imgs_after_img_transform
        return [*imgs_after_img_transform, *targets_after_target_transform]
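Both __call__ variants depend on the same guarantee: running a stochastic transform twice under the same seed yields identical outputs. A quick illustrative self-check, with torchvision's RandomCrop standing in for the project's transforms and FixRandomSeed assumed to seed torch's RNG as sketched earlier:

import random

import torch
from torchvision import transforms

crop = transforms.RandomCrop(24)
img = torch.randn(3, 32, 32)

seed = random.randint(0, int(1e5))
with FixRandomSeed(seed):
    out1 = crop(img)
with FixRandomSeed(seed):
    out2 = crop(img)
# identical seed -> identical crop parameters -> identical output
assert torch.equal(out1, out2)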
Example #5
def _draw_indices(
    targets: np.ndarray,
    labeled_sample_num: int,
    class_nums: int = 10,
    validation_num: int = 5000,
    verbose: bool = True,
    seed: int = 1,
) -> Tuple[List[int], List[int], List[int]]:
    """
    draw indices for labeled and unlabeled dataset separations.
    :param targets: `torch.utils.data.Dataset.targets`-like numpy ndarray with all labels, used to split into labeled, unlabeled and validation dataset.
    :param labeled_sample_num: labeled sample number
    :param class_nums: num of classes in the target.
    :param validation_num: num of validation set, usually we split the big training set into `labeled`, `unlabeled`, `validation` sets, the `test` set is taken directly from the big test set.
    :param verbose: whether to print information while running.
    :param seed: random seed to draw indices
    :return: labeled indices and unlabeled indices
    """
    labeled_sample_per_class = int(labeled_sample_num / class_nums)
    validation_sample_per_class = int(validation_num / class_nums) if class_nums else 0
    targets = np.array(targets)
    train_labeled_idxs: List[int] = []
    train_unlabeled_idxs: List[int] = []
    val_idxs: List[int] = []
    with FixRandomSeed(seed):
        for i in range(class_nums):
            idxs = np.where(targets == i)[0]
            np.random.shuffle(idxs)
            train_labeled_idxs.extend(idxs[:labeled_sample_per_class])
            train_unlabeled_idxs.extend(
                idxs[labeled_sample_per_class:-validation_sample_per_class]
            )
            val_idxs.extend(idxs[-validation_sample_per_class:])
        np.random.shuffle(train_labeled_idxs)
        np.random.shuffle(val_idxs)

    # highlight: to match the UDA paper, the unlabeled set would be the true unlabeled_data + labeled_data, with no val_data:
    # train_unlabeled_idxs = train_labeled_idxs + train_unlabeled_idxs + val_idxs
    # highlight: that setting led to bad performance, so we use unlabeled = unlabeled + val instead
    train_unlabeled_idxs = train_unlabeled_idxs + val_idxs
    np.random.shuffle(train_unlabeled_idxs)
    # assert train_unlabeled_idxs.__len__() == len(targets)
    assert len(train_labeled_idxs) == labeled_sample_num
    if verbose:
        print(
            f">>>Generating {len(train_labeled_idxs)} labeled data, {len(train_unlabeled_idxs)} unlabeled data, and {len(val_idxs)} validation data."
        )
    return train_labeled_idxs, train_unlabeled_idxs, val_idxs
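A usage sketch on synthetic CIFAR-10-style labels (the numbers are illustrative):

import numpy as np

targets = np.random.randint(0, 10, size=50000)  # fake CIFAR-10 targets
labeled, unlabeled, val = _draw_indices(
    targets, labeled_sample_num=4000, class_nums=10,
    validation_num=5000, verbose=True, seed=1,
)
assert len(labeled) == 4000
assert not set(labeled) & set(unlabeled)  # labeled and unlabeled are disjoint

Note that the val indices are also folded into the unlabeled pool, as the highlighted comments above explain.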
Example #6
    def regularization(self, unlabeled_tf_logits: Tensor,
                       unlabeled_logits_tf: Tensor, seed: int, *args,
                       **kwargs):
        # todo: add projectors here.
        feature_names = self._fextractor._feature_names  # noqa
        unlabeled_length = len(unlabeled_tf_logits) * 2
        iic_losses_for_features = []

        for i, (inter_feature, projector, criterion) in enumerate(
                zip(self._fextractor, self._projectors_wrapper, self._IIDSegCriterionWrapper)):
            unlabeled_features = inter_feature[len(inter_feature) - unlabeled_length:]
            unlabeled_features, unlabeled_tf_features = torch.chunk(unlabeled_features, 2, dim=0)

            if isinstance(projector, ClusterHead):  # features from the encoder
                unlabeled_features_tf = unlabeled_features
            else:
                with FixRandomSeed(seed):
                    unlabeled_features_tf = torch.stack(
                        [self._affine_transformer(x) for x in unlabeled_features], dim=0)
            assert unlabeled_features_tf.shape == unlabeled_tf_features.shape, \
                (unlabeled_features_tf.shape, unlabeled_tf_features.shape)
            prob1, prob2 = list(zip(*[
                torch.chunk(x, 2, 0) for x in
                projector(torch.cat([unlabeled_features_tf, unlabeled_tf_features], dim=0))
            ]))
            _iic_loss_list = [criterion(x, y) for x, y in zip(prob1, prob2)]
            _iic_loss = average_iter(_iic_loss_list)
            iic_losses_for_features.append(_iic_loss)
        reg_loss = weighted_average_iter(iic_losses_for_features, self._feature_importance)
        self.meters["mi"].add(-reg_loss.item())
        self.meters["individual_mis"].add(**dict(
            zip(self._feature_position, [-x.item() for x in iic_losses_for_features])))

        return reg_loss
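The two reduction helpers are assumed to behave as a plain and a weighted mean over the per-feature losses; a minimal sketch (the project ships its own implementations):

from typing import List

from torch import Tensor


def average_iter(values: List[Tensor]) -> Tensor:
    # unweighted mean over a list of scalar losses (assumed behaviour)
    return sum(values) / len(values)


def weighted_average_iter(values: List[Tensor], weights: List[float]) -> Tensor:
    # mean weighted by feature importance, normalized by total weight (assumed behaviour)
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)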
Example #7
    def _run(self, *args, **kwargs) -> EpochResultDict:
        self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
        self._model.train()
        assert self._model.training, self._model.training
        report_dict = {}
        with FeatureExtractor(self._model,
                              self._feature_position) as self._fextractor:
            for i, labeled_data, unlabeled_data in zip(self._indicator,
                                                       self._labeled_loader,
                                                       self._unlabeled_loader):
                seed = random.randint(0, int(1e7))
                labeled_image, labeled_target, labeled_filename, _, label_group = \
                    self._unzip_data(labeled_data, self._device)
                unlabeled_image, unlabeled_target, *_ = self._unzip_data(
                    unlabeled_data, self._device)
                with FixRandomSeed(seed):
                    unlabeled_image_tf = torch.stack(
                        [self._affine_transformer(x) for x in unlabeled_image],
                        dim=0)
                assert unlabeled_image_tf.shape == unlabeled_image.shape, \
                    (unlabeled_image_tf.shape, unlabeled_image.shape)

                predict_logits = self._model(
                    torch.cat(
                        [labeled_image, unlabeled_image, unlabeled_image_tf],
                        dim=0))
                label_logits, unlabel_logits, unlabel_tf_logits = \
                    torch.split(
                        predict_logits,
                        [len(labeled_image), len(unlabeled_image), len(unlabeled_image_tf)],
                        dim=0
                    )
                with FixRandomSeed(seed):
                    unlabel_logits_tf = torch.stack(
                        [self._affine_transformer(x) for x in unlabel_logits],
                        dim=0)
                assert unlabel_logits_tf.shape == unlabel_tf_logits.shape, \
                    (unlabel_logits_tf.shape, unlabel_tf_logits.shape)
                # supervised part
                onehot_target = class2one_hot(labeled_target.squeeze(1),
                                              self.num_classes)
                sup_loss = self._sup_criterion(label_logits.softmax(1),
                                               onehot_target)
                # regularized part
                reg_loss = self.regularization(
                    unlabeled_tf_logits=unlabel_tf_logits,
                    unlabeled_logits_tf=unlabel_logits_tf,
                    seed=seed,
                    unlabeled_image=unlabeled_image,
                    unlabeled_image_tf=unlabeled_image_tf,
                )
                total_loss = sup_loss + self._reg_weight * reg_loss
                # gradient backpropagation
                self._optimizer.zero_grad()
                total_loss.backward()
                self._optimizer.step()
                # recording can be here or in the regularization method
                with torch.no_grad():
                    self.meters["sup_loss"].add(sup_loss.item())
                    self.meters["sup_dice"].add(label_logits.max(1)[1],
                                                labeled_target.squeeze(1),
                                                group_name=label_group)
                    self.meters["reg_loss"].add(reg_loss.item())
                    report_dict = self.meters.tracking_status()
                    self._indicator.set_postfix_dict(report_dict)
        return report_dict
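class2one_hot, used by both supervised branches above, is assumed to follow the usual conversion from a (B, H, W) integer label map to a (B, C, H, W) one-hot tensor; a minimal sketch:

import torch
from torch import Tensor


def class2one_hot(seg: Tensor, num_classes: int) -> Tensor:
    # (B, H, W) integer label map -> (B, C, H, W) float one-hot tensor,
    # matching how the softmax outputs above are compared against targets
    b, *spatial = seg.shape
    one_hot = torch.zeros(b, num_classes, *spatial, dtype=torch.float32, device=seg.device)
    return one_hot.scatter_(1, seg.unsqueeze(1).long(), 1.0)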
Example #8
    def _create_semi_supervised_datasets(
        self,
        labeled_transform: SequentialWrapper = None,
        unlabeled_transform: SequentialWrapper = None,
        val_transform: SequentialWrapper = None,
    ) -> Tuple[MedicalImageSegmentationDataset,
               MedicalImageSegmentationDataset,
               MedicalImageSegmentationDataset]:
        train_set = self.DataClass(
            root_dir=self.root_dir,
            modality=self.modality,
            mode="train",
            subfolders=["img", "gt"],
            transforms=None,
            verbose=self.verbose,
        )
        val_set = self.DataClass(
            root_dir=self.root_dir,
            modality=self.modality,
            mode="val",
            subfolders=["img", "gt"],
            transforms=None,
            verbose=self.verbose,
        )
        if self.labeled_ratio == 1 or self.unlabeled_ratio == 1:
            import warnings

            warnings.warn(
                f"given self.labeled_ratio == 1 or self.unlabeled_ratio == 1, {self.__class__.__name__} returns "
                f"train_set as the labeled and unlabeled datasets",
                UserWarning,
            )
            labeled_set = train_set
            unlabeled_set = deepcopy(train_set)
            if labeled_transform:
                labeled_set.set_transform(labeled_transform)
            if unlabeled_transform:
                unlabeled_set.set_transform(unlabeled_transform)
            if val_transform:
                val_set.set_transform(val_transform)
            return labeled_set, unlabeled_set, val_set
        with FixRandomSeed(random_seed=self.seed):
            shuffled_patients = train_set.get_group_list()[:]
            random.shuffle(shuffled_patients)
            labeled_patients, unlabeled_patients = (
                shuffled_patients[:int(
                    len(shuffled_patients) * self.labeled_ratio)],
                shuffled_patients[-int(
                    math.ceil(len(shuffled_patients) *
                              self.unlabeled_ratio)):],
            )

        labeled_set = SubMedicalDatasetBasedOnIndex(train_set,
                                                    labeled_patients)
        unlabeled_set = SubMedicalDatasetBasedOnIndex(train_set,
                                                      unlabeled_patients)
        assert len(labeled_set) + len(unlabeled_set) == len(
            train_set), "wrong on labeled/unlabeled split."
        del train_set
        if labeled_transform:
            labeled_set.set_transform(labeled_transform)
        if unlabeled_transform:
            unlabeled_set.set_transform(unlabeled_transform)
        if val_transform:
            val_set.set_transform(val_transform)
        return labeled_set, unlabeled_set, val_set
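The head/tail slicing above only forms a clean partition when labeled_ratio + unlabeled_ratio == 1, which the length assert enforces; a toy check of that arithmetic with illustrative values:

import math
import random

patients = [f"patient_{i:03d}" for i in range(100)]
labeled_ratio, unlabeled_ratio = 0.1, 0.9

random.seed(7)
shuffled = patients[:]
random.shuffle(shuffled)
labeled = shuffled[:int(len(shuffled) * labeled_ratio)]
unlabeled = shuffled[-int(math.ceil(len(shuffled) * unlabeled_ratio)):]
# disjoint and exhaustive only because the ratios sum to one
assert len(labeled) + len(unlabeled) == len(patients)
assert not set(labeled) & set(unlabeled)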