def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run one training epoch of the student model with a teacher-consistency term.

    For every step a labeled batch is drawn from ``self._labeled_loader`` and
    an unlabeled batch from ``self._tra_loader``.  The student sees the labeled
    images plus a randomly transformed copy of the unlabeled images; the
    teacher sees the untransformed unlabeled images and its output is
    transformed with the *same* seed, so both predictions are aligned.  The
    total loss is ``sup_loss + self._reg_weight * reg_loss``; after the
    optimizer step ``self._ema_updater`` is called with the teacher and the
    student model.

    :return: the meters' tracking status after the last processed batch.
    :raises RuntimeError: if the student output already looks like a simplex
        (i.e. probabilities instead of logits).
    """
    self._model.train()
    self._teacher_model.train()
    assert self._model.training, self._model.training
    assert self._teacher_model.training, self._teacher_model.training
    self.meters["lr"].add(self._optimizer.param_groups[0]["lr"])
    self.meters["reg_weight"].add(self._reg_weight)
    report_dict: EpochResultDict

    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
        for i, label_data, all_data in zip(indicator, self._labeled_loader, self._tra_loader):
            (labelimage, labeltarget), _, filename, partition_list, group_list \
                = self._preprocess_data(label_data, self._device)
            # The unlabeled images must come from the batch of `self._tra_loader`
            # (`all_data`); previously `label_data` was decoded a second time, so
            # the consistency term never saw any unlabeled sample.
            (unlabelimage, _), *_ = self._preprocess_data(all_data, self._device)

            # One seed per step: reused below so that the teacher prediction
            # receives the same random transform as the student input.
            seed = random.randint(0, int(1e6))
            with FixRandomSeed(seed):
                unlabelimage_tf = torch.stack([self._transformer(x) for x in unlabelimage], dim=0)
            assert unlabelimage_tf.shape == unlabelimage.shape

            student_logits = self._model(torch.cat([labelimage, unlabelimage_tf], dim=0))
            if simplex(student_logits):
                raise RuntimeError("output of the model should be logits, instead of simplex")
            num_labeled = len(labelimage)
            student_sup_logits = student_logits[:num_labeled]
            student_unlabel_logits_tf = student_logits[num_labeled:]

            with torch.no_grad():
                teacher_unlabel_logits = self._teacher_model(unlabelimage)
                with FixRandomSeed(seed):
                    teacher_unlabel_logits_tf = torch.stack(
                        [self._transformer(x) for x in teacher_unlabel_logits])
            assert teacher_unlabel_logits.shape == teacher_unlabel_logits_tf.shape

            # compute the losses
            # NOTE(review): the class count is hard-coded to 4 here -- confirm it
            # matches the number of output channels of the model.
            onehot_ltarget = class2one_hot(labeltarget.squeeze(1), 4)
            sup_loss = self._sup_criterion(student_sup_logits.softmax(1), onehot_ltarget)
            reg_loss = self._reg_criterion(student_unlabel_logits_tf.softmax(1),
                                           teacher_unlabel_logits_tf.detach().softmax(1))
            total_loss = sup_loss + self._reg_weight * reg_loss

            self._optimizer.zero_grad()
            total_loss.backward()
            self._optimizer.step()

            # update ema
            self._ema_updater(ema_model=self._teacher_model, student_model=self._model)

            with torch.no_grad():
                self.meters["sup_loss"].add(sup_loss.item())
                self.meters["reg_loss"].add(reg_loss.item())
                self.meters["ds"].add(student_sup_logits.max(1)[1], labeltarget.squeeze(1),
                                      group_name=list(group_list))
                report_dict = self.meters.tracking_status()
                indicator.set_postfix_dict(report_dict)

    # Also covers the case where the loop body never ran.
    report_dict = self.meters.tracking_status()
    return report_dict
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run one epoch of contrastive training on projected feature patches.

    Each batch from ``self._pretrain_decoder_loader`` provides an image and a
    second view of it (``img_ctf``).  The first image is additionally passed
    through ``self._transformer`` under a per-step seed; the same seeded
    transform is later applied to the features of the second view so both
    feature maps are aligned before projection.  The projected maps are split
    into a 2x2 grid of patches by ``unfold_position``, flattened, L2-normalised
    and fed to ``self._contrastive_criterion`` together with the labels from
    ``self._label_generation``.

    :return: the meters' tracking status after the last batch, or an empty
        dict when the loader yields no batch.
    :raises RuntimeError: if the contrastive loss is NaN.
    """
    self._model.train()
    assert self._model.training, self._model.training
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
    # Bound up-front so that an empty loader returns an empty dict instead of
    # failing with UnboundLocalError on the final `return`.
    report_dict = {}

    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
        for i, data in zip(indicator, self._pretrain_decoder_loader):
            (img, _), (img_ctf, _), filename, partition_list, group_list = \
                self._preprocess_data(data, self._device)

            # The seed is reused for the feature-level transform further down.
            seed = random.randint(0, int(1e5))
            with FixRandomSeed(seed):
                img_gtf = torch.stack([self._transformer(x) for x in img], dim=0)
            assert img_gtf.shape == img.shape, (img_gtf.shape, img.shape)

            _, *features = self._model(torch.cat([img_gtf, img_ctf], dim=0), return_features=True)
            dn = self._feature_extractor(features)[0]
            dn_gtf, dn_ctf = torch.chunk(dn, chunks=2, dim=0)
            with FixRandomSeed(seed):
                dn_ctf_gtf = torch.stack([self._transformer(x) for x in dn_ctf], dim=0)
            assert dn_ctf_gtf.shape == dn_ctf.shape, (dn_ctf_gtf.shape, dn_ctf.shape)

            dn_tf = torch.cat([dn_gtf, dn_ctf_gtf])
            local_enc_tf, local_enc_tf_ctf = torch.chunk(self._projection_head(dn_tf), chunks=2, dim=0)

            # todo: convert representation to distance
            local_enc_unfold, _ = unfold_position(local_enc_tf, partition_num=(2, 2))
            local_tf_enc_unfold, _fold_partition = unfold_position(local_enc_tf_ctf, partition_num=(2, 2))
            b, *_ = local_enc_unfold.shape
            local_enc_unfold_norm = F.normalize(local_enc_unfold.view(b, -1), p=2, dim=1)
            local_tf_enc_unfold_norm = F.normalize(local_tf_enc_unfold.view(b, -1), p=2, dim=1)

            labels = self._label_generation(partition_list, group_list, _fold_partition)
            contrastive_loss = self._contrastive_criterion(
                torch.stack([local_enc_unfold_norm, local_tf_enc_unfold_norm], dim=1),
                labels=labels
            )
            if torch.isnan(contrastive_loss):
                raise RuntimeError(contrastive_loss)

            self._optimizer.zero_grad()
            contrastive_loss.backward()
            self._optimizer.step()

            # todo: meter recording.
            with torch.no_grad():
                self.meters["contrastive_loss"].add(contrastive_loss.item())
                report_dict = self.meters.tracking_status()
                indicator.set_postfix_dict(report_dict)
    return report_dict
def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, global_seed=None, **kwargs):
    """Produce two augmented views of the same inputs.

    Six seeds (two "common", two image, two target) are drawn under
    ``global_seed`` -- itself drawn at random when not given -- and the parent
    ``__call__`` is invoked twice.  With ``self._total_freedom`` the second
    view uses its own three seeds; otherwise it shares the common and target
    seeds of the first view and only the image seed differs.

    :return: a two-element list with the results of both parent calls.
    """
    if global_seed is None:
        global_seed = int(random.randint(0, int(1e5)))
    else:
        global_seed = int(global_seed)

    seed_upper = int(1e5)
    with FixRandomSeed(global_seed):
        # Draw order matters for reproducibility: common, image, target.
        comm_a = int(random.randint(0, seed_upper))
        comm_b = int(random.randint(0, seed_upper))
        img_a = int(random.randint(0, seed_upper))
        img_b = int(random.randint(0, seed_upper))
        target_a = int(random.randint(0, seed_upper))
        target_b = int(random.randint(0, seed_upper))

        if self._total_freedom:
            second_view_seeds = (comm_b, img_b, target_b)
        else:
            # Geometry (common) and target seeds are shared between the views.
            second_view_seeds = (comm_a, img_b, target_a)

        return [
            super().__call__(imgs, targets, comm_a, img_a, target_a),
            super().__call__(imgs, targets, *second_view_seeds),
        ]
def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, comm_seed=None, img_seed=None,
             target_seed=None):
    """Apply the common, image-only and target-only transforms to the inputs.

    Every element is transformed under a freshly re-seeded ``FixRandomSeed``
    context, so all elements of one stage are transformed with the same seed.
    Seeds that are not supplied are drawn with ``random.randint``.

    :return: the transformed images when ``targets`` is ``None``; otherwise a
        flat list of the transformed images followed by the transformed
        targets.
    """

    def _seeded_map(transform, items, seed):
        # Re-enter the seeded context for each item individually.
        collected = []
        for item in items:
            with FixRandomSeed(seed):
                converted = transform(item)
            collected.append(converted)
        return collected

    if comm_seed is None:
        shared_seed = int(random.randint(0, int(1e5)))
    else:
        shared_seed = int(comm_seed)

    staged_imgs, staged_targets = imgs, targets
    if self._comm_transform:
        staged_imgs = _seeded_map(self._comm_transform, imgs, shared_seed)
        staged_targets = _seeded_map(self._comm_transform, targets, shared_seed) if targets else []

    # The image seed is drawn only after the common stage has run.
    if img_seed is None:
        image_seed = int(random.randint(0, int(1e5)))
    else:
        image_seed = int(img_seed)
    final_imgs = _seeded_map(self._img_transform, staged_imgs, image_seed)

    # The target seed is drawn even when there are no targets to transform.
    if target_seed is None:
        label_seed = int(random.randint(0, int(1e5)))
    else:
        label_seed = int(target_seed)
    final_targets = []
    if staged_targets:
        final_targets = _seeded_map(self._target_transform, staged_targets, label_seed)

    if targets is None:
        return final_imgs
    return [*final_imgs, *final_targets]
def _draw_indices(
    targets: np.ndarray,
    labeled_sample_num: int,
    class_nums: int = 10,
    validation_num: int = 5000,
    verbose: bool = True,
    seed: int = 1,
) -> Tuple[List[int], List[int], List[int]]:
    """
    draw indices for labeled and unlabeled dataset separations.
    :param targets: `torch.utils.data.Dataset.targets`-like numpy ndarray with all labels, used to split into labeled,
        unlabeled and validation dataset.
    :param labeled_sample_num: labeled sample number
    :param class_nums: num of classes in the target.
    :param validation_num: num of validation set, usually we split the big training set into `labeled`, `unlabeled`,
        `validation` sets, the `test` set is taken directly from the big test set. ``0`` yields an empty validation
        set.
    :param verbose: whether to print information while running.
    :param seed: random seed to draw indices
    :return: labeled indices, unlabeled indices (which also include the validation indices, see the note in the
        body) and validation indices
    """
    labeled_sample_per_class = int(labeled_sample_num / class_nums)
    validation_sample_per_class = int(validation_num / class_nums) if class_nums else 0

    targets = np.array(targets)
    train_labeled_idxs: List[int] = []
    train_unlabeled_idxs: List[int] = []
    val_idxs: List[int] = []

    with FixRandomSeed(seed):
        for i in range(class_nums):
            idxs = np.where(targets == i)[0]
            np.random.shuffle(idxs)
            # Use an explicit split position instead of a negative slice bound:
            # with zero validation samples `idxs[a:-0]` is empty and `idxs[-0:]`
            # is the whole array, which moved every sample into validation.
            val_start = max(len(idxs) - validation_sample_per_class, 0)
            train_labeled_idxs.extend(idxs[:labeled_sample_per_class])
            train_unlabeled_idxs.extend(idxs[labeled_sample_per_class:val_start])
            val_idxs.extend(idxs[val_start:])
        np.random.shuffle(train_labeled_idxs)
        np.random.shuffle(val_idxs)
        # highlight: this is to meet the UDA paper: unlabeled data is the true unlabeled_data + labeled_data,
        # and there is no val_data
        # train_unlabeled_idxs = train_labeled_idxs + train_unlabeled_idxs + val_idxs
        # highlight: this leads to bad performance, using unlabeled = unlabeled + val
        train_unlabeled_idxs = train_unlabeled_idxs + val_idxs
        np.random.shuffle(train_unlabeled_idxs)

    # assert train_unlabeled_idxs.__len__() == len(targets)
    assert len(train_labeled_idxs) == labeled_sample_num
    if verbose:
        print(
            f">>>Generating {len(train_labeled_idxs)} labeled data, {len(train_unlabeled_idxs)} unlabeled data, and {len(val_idxs)} validation data."
        )
    return train_labeled_idxs, train_unlabeled_idxs, val_idxs
def regularization(self, unlabeled_tf_logits: Tensor, unlabeled_logits_tf: Tensor, seed: int, *args, **kwargs):
    """Compute the weighted regularization loss over the extracted features.

    For every (feature, projector, criterion) triple the trailing
    ``2 * len(unlabeled_tf_logits)`` samples of the feature map are taken and
    split into two halves: features of the plain unlabeled images and features
    of the transformed unlabeled images.  Unless the projector is a
    ``ClusterHead``, the first half is passed through
    ``self._affine_transformer`` under ``seed`` so that both halves are
    aligned.  Both halves are projected together and each pair of projector
    outputs is scored by the criterion.

    :param unlabeled_tf_logits: only its length (number of unlabeled samples)
        is used here.
    :param unlabeled_logits_tf: unused in this implementation.
    :param seed: seed for the affine transform of the features.
    :return: the per-feature losses combined by ``weighted_average_iter`` with
        ``self._feature_importance``.  The negated values are also recorded in
        the ``"mi"`` and ``"individual_mis"`` meters.
    """
    # todo: adding projectors here.
    unlabeled_length = len(unlabeled_tf_logits) * 2
    iic_losses_for_features = []

    for inter_feature, projector, criterion in zip(
            self._fextractor, self._projectors_wrapper, self._IIDSegCriterionWrapper):
        # Only the trailing samples of the batch are considered here.
        unlabeled_features = inter_feature[len(inter_feature) - unlabeled_length:]
        unlabeled_features, unlabeled_tf_features = torch.chunk(unlabeled_features, 2, dim=0)

        if isinstance(projector, ClusterHead):  # features from encoder
            unlabeled_features_tf = unlabeled_features
        else:
            with FixRandomSeed(seed):
                unlabeled_features_tf = torch.stack(
                    [self._affine_transformer(x) for x in unlabeled_features], dim=0)

        # Compare the two *different* tensors; the previous check compared
        # `unlabeled_tf_features.shape` with itself and could never fail.
        assert unlabeled_features_tf.shape == unlabeled_tf_features.shape, \
            (unlabeled_features_tf.shape, unlabeled_tf_features.shape)

        projected = projector(torch.cat([unlabeled_features_tf, unlabeled_tf_features], dim=0))
        prob1, prob2 = list(zip(*[torch.chunk(x, 2, 0) for x in projected]))
        _iic_loss_list = [criterion(x, y) for x, y in zip(prob1, prob2)]
        _iic_loss = average_iter(_iic_loss_list)
        iic_losses_for_features.append(_iic_loss)

    reg_loss = weighted_average_iter(iic_losses_for_features, self._feature_importance)
    self.meters["mi"].add(-reg_loss.item())
    self.meters["individual_mis"].add(**dict(
        zip(self._feature_position, [-x.item() for x in iic_losses_for_features])))
    return reg_loss
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run one semi-supervised training epoch.

    Each step concatenates a labeled batch, an unlabeled batch and a seeded
    affine-transformed copy of the unlabeled batch into one forward pass.  The
    supervised loss is computed on the labeled logits, the regularization loss
    comes from ``self.regularization`` and the optimizer minimises
    ``sup_loss + self._reg_weight * reg_loss``.  ``self._fextractor`` is bound
    to a ``FeatureExtractor`` context for the duration of the epoch.

    :return: the meters' tracking status after the last batch (an empty dict
        if no batch was processed).
    """
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
    self._model.train()
    assert self._model.training, self._model.training
    report_dict = {}

    with FeatureExtractor(self._model, self._feature_position) as self._fextractor:
        batches = zip(self._indicator, self._labeled_loader, self._unlabeled_loader)
        for i, labeled_data, unlabeled_data in batches:
            # A single seed per step keeps the image-level and logit-level
            # affine transforms identical.
            seed = random.randint(0, int(1e7))

            lab_img, lab_gt, _lab_name, _, lab_group = self._unzip_data(labeled_data, self._device)
            unlab_img, _unlab_gt, *_ = self._unzip_data(unlabeled_data, self._device)

            with FixRandomSeed(seed):
                warped_imgs = [self._affine_transformer(sample) for sample in unlab_img]
                unlab_img_tf = torch.stack(warped_imgs, dim=0)
            assert unlab_img_tf.shape == unlab_img.shape, \
                (unlab_img_tf.shape, unlab_img.shape)

            model_input = torch.cat([lab_img, unlab_img, unlab_img_tf], dim=0)
            all_logits = self._model(model_input)
            section_sizes = [len(lab_img), len(unlab_img), len(unlab_img_tf)]
            lab_logits, unlab_logits, unlab_tf_logits = torch.split(all_logits, section_sizes, dim=0)

            with FixRandomSeed(seed):
                warped_logits = [self._affine_transformer(sample) for sample in unlab_logits]
                unlab_logits_tf = torch.stack(warped_logits, dim=0)
            assert unlab_logits_tf.shape == unlab_tf_logits.shape, \
                (unlab_logits_tf.shape, unlab_tf_logits.shape)

            # supervised term
            onehot_gt = class2one_hot(lab_gt.squeeze(1), self.num_classes)
            sup_loss = self._sup_criterion(lab_logits.softmax(1), onehot_gt)

            # regularization term
            reg_loss = self.regularization(
                unlabeled_tf_logits=unlab_tf_logits,
                unlabeled_logits_tf=unlab_logits_tf,
                seed=seed,
                unlabeled_image=unlab_img,
                unlabeled_image_tf=unlab_img_tf,
            )
            total_loss = sup_loss + self._reg_weight * reg_loss

            # optimisation step
            self._optimizer.zero_grad()
            total_loss.backward()
            self._optimizer.step()

            # recording can be here or in the regularization method
            with torch.no_grad():
                self.meters["sup_loss"].add(sup_loss.item())
                self.meters["sup_dice"].add(lab_logits.max(1)[1], lab_gt.squeeze(1),
                                            group_name=lab_group)
                self.meters["reg_loss"].add(reg_loss.item())
                report_dict = self.meters.tracking_status()
                self._indicator.set_postfix_dict(report_dict)
    return report_dict
def _create_semi_supervised_datasets(
    self,
    labeled_transform: SequentialWrapper = None,
    unlabeled_transform: SequentialWrapper = None,
    val_transform: SequentialWrapper = None,
) -> Tuple[MedicalImageSegmentationDataset, MedicalImageSegmentationDataset, MedicalImageSegmentationDataset, ]:
    """Build the labeled, unlabeled and validation datasets.

    The "train" split is divided per group (as returned by
    ``get_group_list``) after a shuffle seeded with ``self.seed``: the leading
    ``labeled_ratio`` share becomes the labeled set and the trailing
    ``unlabeled_ratio`` share the unlabeled set.  If either ratio equals 1, a
    warning is issued and the full training set is returned as the labeled set
    together with a deep copy of it as the unlabeled set.  Transforms are
    attached only when given.

    :return: ``(labeled_set, unlabeled_set, val_set)``
    """

    def _load_split(mode):
        # Both splits are created with identical options except for `mode`.
        return self.DataClass(
            root_dir=self.root_dir,
            modality=self.modality,
            mode=mode,
            subfolders=["img", "gt"],
            transforms=None,
            verbose=self.verbose,
        )

    def _with_transforms(labeled, unlabeled, validation):
        # Attach the optional transforms in labeled / unlabeled / val order.
        pairs = (
            (labeled, labeled_transform),
            (unlabeled, unlabeled_transform),
            (validation, val_transform),
        )
        for dataset, transform in pairs:
            if transform:
                dataset.set_transform(transform)
        return labeled, unlabeled, validation

    train_set = _load_split("train")
    val_set = _load_split("val")

    if self.labeled_ratio == 1 or self.unlabeled_ratio == 1:
        import warnings

        warnings.warn(
            f"given self.labeled_ratio == 1 or self.unlabeled_ratio == 1, {self.__class__.__name__} returns "
            f"train_set as the labeled and unlabeled datasets",
            UserWarning,
        )
        return _with_transforms(train_set, deepcopy(train_set), val_set)

    with FixRandomSeed(random_seed=self.seed):
        patients = train_set.get_group_list()[:]
        random.shuffle(patients)

    patient_count = len(patients)
    labeled_count = int(patient_count * self.labeled_ratio)
    unlabeled_count = int(math.ceil(patient_count * self.unlabeled_ratio))
    labeled_patients = patients[:labeled_count]
    unlabeled_patients = patients[-unlabeled_count:]

    labeled_set = SubMedicalDatasetBasedOnIndex(train_set, labeled_patients)
    unlabeled_set = SubMedicalDatasetBasedOnIndex(train_set, unlabeled_patients)
    assert len(labeled_set) + len(unlabeled_set) == len(train_set), \
        "wrong on labeled/unlabeled split."
    del train_set

    return _with_transforms(labeled_set, unlabeled_set, val_set)