示例#1
0
 def _convert2onehot(self, pred: Tensor, target: Tensor):
     """Return ``pred`` and ``target`` as ``long`` one-hot tensors.

     Only two input encodings are supported, and both tensors must share the
     same shape: either ``pred`` passes ``simplex(pred, 1)`` while ``target``
     passes ``one_hot(target)``, or both are class-coded and get converted
     with ``class2one_hot`` using ``self._C``.
     """
     assert pred.shape == target.shape
     both_onehot_coded = simplex(pred, 1) and one_hot(target)
     if not both_onehot_coded:
         # At least one check failed: treat both tensors as class-coded labels.
         onehot_pred = class2one_hot(pred, self._C)
         onehot_target = class2one_hot(target, self._C)
         return onehot_pred.long(), onehot_target.long()
     # Both are probability/one-hot coded: only the prediction needs converting.
     return probs2one_hot(pred).long(), target.long()
示例#2
0
    def _run(self, *args, **kwargs) -> Tuple[EpochResultDict, float]:
        """Run the validation loop, dump images/predictions and update the meters.

        Each batch from ``self._val_loader`` is forwarded through the model in
        eval mode. The image/target pair and the logits are handed to
        ``write_img_target`` / ``write_predict`` together with
        ``self._save_dir``. The supervised loss is then recorded in the
        ``"loss"`` meter and the class predictions in the ``"dice"`` and
        ``"hd"`` meters.

        Returns:
            A tuple of the meters' tracking status and the ``"DSC_mean"`` entry
            of the ``"dice"`` meter summary.
        """
        self._model.eval()
        assert self._model.training is False, self._model.training
        for _, val_data in zip(self._indicator, self._val_loader):
            val_img, val_target, file_path, _, group = self._unzip_data(
                val_data, self._device)
            val_logits = self._model(val_img)
            # write image
            write_img_target(val_img, val_target, self._save_dir, file_path)
            write_predict(val_logits, self._save_dir, file_path)

            # Computed once and shared by the loss and both meters.
            squeezed_target = val_target.squeeze(1)
            predicted_class = val_logits.max(1)[1]

            onehot_target = class2one_hot(squeezed_target, self.num_classes)
            val_loss = self._sup_criterion(val_logits.softmax(1),
                                           onehot_target,
                                           disable_assert=True)

            self.meters["loss"].add(val_loss.item())
            self.meters["dice"].add(predicted_class,
                                    squeezed_target,
                                    group_name=group)
            # NOTE(review): presumably the "hd" meter can raise RuntimeError on
            # some batches and that failure is meant to be skipped - confirm.
            with ExceptionIgnorer(RuntimeError):
                self.meters["hd"].add(predicted_class, squeezed_target)
            self._indicator.set_postfix_dict(self.meters.tracking_status())
        # Read the status after the loop so ``report_dict`` is bound even when
        # the loader yields no batch (previously an UnboundLocalError).
        report_dict = self.meters.tracking_status()
        return report_dict, self.meters["dice"].summary()["DSC_mean"]
    def _run(self, *args, **kwargs) -> EpochResultDict:
        """Train the model for one epoch on labeled batches only.

        The current learning rate is recorded in the ``"lr"`` meter. For each
        of at most ``self._num_batches`` batches from ``self._labeled_loader``
        the supervised loss between the softmaxed logits and the one-hot
        target is back-propagated, and the ``"sup_loss"`` and ``"ds"`` meters
        are updated.

        Returns:
            The meters' tracking status after the last batch.
        """
        self._model.train()
        assert self._model.training, self._model.training
        report_dict: EpochResultDict
        self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
            for _, label_data in zip(indicator, self._labeled_loader):
                (labelimage, labeltarget), _, _, _, group_list \
                    = self._preprocess_data(label_data, self._device)
                predict_logits = self._model(labelimage)
                assert not simplex(predict_logits), predict_logits

                # Take the class count from the channel dimension of the logits
                # instead of a hard-coded 4, so the one-hot target always
                # matches the prediction.
                num_classes = predict_logits.shape[1]
                onehot_ltarget = class2one_hot(labeltarget.squeeze(1), num_classes)
                sup_loss = self._sup_criterion(predict_logits.softmax(1), onehot_ltarget)

                self._optimizer.zero_grad()
                sup_loss.backward()
                self._optimizer.step()

                with torch.no_grad():
                    self.meters["sup_loss"].add(sup_loss.item())
                    self.meters["ds"].add(predict_logits.max(1)[1], labeltarget.squeeze(1),
                                          group_name=list(group_list))
                    report_dict = self.meters.tracking_status()
                    indicator.set_postfix_dict(report_dict)
            # Re-read after the loop so the result is defined for an empty loader.
            report_dict = self.meters.tracking_status()
        return report_dict
    def _run(self, *args, **kwargs) -> EpochResultDict:
        """Train the student for one epoch with a teacher-consistency term.

        Each step draws one batch from ``self._labeled_loader`` for the
        supervised loss and one batch from ``self._tra_loader`` for the
        regularisation loss. The student sees a transformed copy of the second
        batch. The teacher sees the untransformed batch and its logits are
        transformed under the same random seed, so both outputs are compared
        under the same transform. After the optimizer step the teacher is
        updated through ``self._ema_updater``.

        Returns:
            The meters' tracking status after the last batch.

        Raises:
            RuntimeError: if the student output already passes ``simplex``
                (logits are expected).
        """
        self._model.train()
        self._teacher_model.train()
        assert self._model.training, self._model.training
        assert self._teacher_model.training, self._teacher_model.training
        self.meters["lr"].add(self._optimizer.param_groups[0]["lr"])
        self.meters["reg_weight"].add(self._reg_weight)
        report_dict: EpochResultDict

        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
            for _, label_data, all_data in zip(indicator, self._labeled_loader, self._tra_loader):
                (labelimage, labeltarget), _, _, _, group_list \
                    = self._preprocess_data(label_data, self._device)
                # The consistency branch must use the batch from
                # ``self._tra_loader``. Reading ``label_data`` again left
                # ``all_data`` unused and regularised only on labeled images.
                (unlabelimage, _), *_ = self._preprocess_data(all_data, self._device)

                # One seed is shared by the image transform and the later
                # teacher-logit transform so both apply the same random draw.
                seed = random.randint(0, int(1e6))
                with FixRandomSeed(seed):
                    unlabelimage_tf = torch.stack([self._transformer(x) for x in unlabelimage], dim=0)
                assert unlabelimage_tf.shape == unlabelimage.shape

                student_logits = self._model(torch.cat([labelimage, unlabelimage_tf], dim=0))
                if simplex(student_logits):
                    raise RuntimeError("output of the model should be logits, instead of simplex")
                num_labeled = len(labelimage)
                student_sup_logits = student_logits[:num_labeled]
                student_unlabel_logits_tf = student_logits[num_labeled:]

                with torch.no_grad():
                    teacher_unlabel_logits = self._teacher_model(unlabelimage)
                with FixRandomSeed(seed):
                    teacher_unlabel_logits_tf = torch.stack([self._transformer(x) for x in teacher_unlabel_logits])
                assert teacher_unlabel_logits.shape == teacher_unlabel_logits_tf.shape

                # Compute the losses. The class count comes from the channel
                # dimension of the logits rather than a hard-coded 4.
                num_classes = student_sup_logits.shape[1]
                onehot_ltarget = class2one_hot(labeltarget.squeeze(1), num_classes)
                sup_loss = self._sup_criterion(student_sup_logits.softmax(1), onehot_ltarget)

                reg_loss = self._reg_criterion(student_unlabel_logits_tf.softmax(1),
                                               teacher_unlabel_logits_tf.detach().softmax(1))
                total_loss = sup_loss + self._reg_weight * reg_loss

                self._optimizer.zero_grad()
                total_loss.backward()
                self._optimizer.step()

                # update ema
                self._ema_updater(ema_model=self._teacher_model, student_model=self._model)

                with torch.no_grad():
                    self.meters["sup_loss"].add(sup_loss.item())
                    self.meters["reg_loss"].add(reg_loss.item())
                    self.meters["ds"].add(student_sup_logits.max(1)[1], labeltarget.squeeze(1),
                                          group_name=list(group_list))
                    report_dict = self.meters.tracking_status()
                    indicator.set_postfix_dict(report_dict)
            # Re-read after the loop so the result is defined for empty loaders.
            report_dict = self.meters.tracking_status()
        return report_dict
def toOneHot(pred_logit, mask):
    """
    Build one-hot versions of a prediction and its ground-truth mask.

    :param pred_logit: prediction with b,c,h,w. Softmax is applied over dim 1 before
        ``probs2one_hot``, so logits, simplex predictions or one-hot inputs are accepted.
    :param mask: gt mask; ``squeeze(1)`` is applied before encoding
        (presumably b,h,w or b,1,h,w - confirm against callers).
    :return: (one-hot prediction, one-hot mask); both shapes are asserted equal.
    """
    num_classes = pred_logit.shape[1]
    probabilities = F.softmax(pred_logit, 1)
    onehot_prediction = probs2one_hot(probabilities)
    onehot_groundtruth = class2one_hot(mask.squeeze(1), C=num_classes)
    assert onehot_prediction.shape == onehot_groundtruth.shape
    return onehot_prediction, onehot_groundtruth
 def _run(self, *args, **kwargs) -> Tuple[EpochResultDict, float]:
     """Evaluate the model on ``self._val_loader`` and update the meters.

     For every batch the supervised loss between the softmaxed logits and the
     one-hot target is added to the ``"sup_loss"`` meter, and the class
     predictions are added to the ``"ds"`` meter.

     Returns:
         A tuple of the meters' tracking status and its
         ``["ds"]["DSC_mean"]`` entry.
     """
     self._model.eval()
     assert not self._model.training, self._model.training
     with tqdm(self._val_loader).set_desc_from_epocher(self) as indicator:
         for data in indicator:
             images, targets, _, _, group_list = self._preprocess_data(data, self._device)
             predict_logits = self._model(images)
             assert not simplex(predict_logits), predict_logits.shape
             # Take the class count from the channel dimension of the logits
             # instead of a hard-coded 4.
             num_classes = predict_logits.shape[1]
             squeezed_targets = targets.squeeze(1)
             onehot_targets = class2one_hot(squeezed_targets, num_classes)
             loss = self._sup_criterion(predict_logits.softmax(1), onehot_targets, disable_assert=True)
             self.meters["sup_loss"].add(loss.item())
             self.meters["ds"].add(predict_logits.max(1)[1], squeezed_targets, group_name=list(group_list))
             indicator.set_postfix_dict(self.meters.tracking_status())
     report_dict = self.meters.tracking_status()
     return report_dict, report_dict["ds"]["DSC_mean"]
示例#7
0
    def _run(self, *args, **kwargs) -> Tuple[EpochResultDict, float]:
        """Evaluate the model over ``self._val_loader``.

        Each batch contributes its supervised loss to the ``"loss"`` meter and
        its class predictions to the ``"dice"`` meter. The progress indicator
        is refreshed with the meters' tracking status after every batch.

        Returns:
            A tuple of the last tracking status (an empty ``EpochResultDict``
            when no batch was processed) and the ``"DSC_mean"`` entry of the
            ``"dice"`` meter summary.
        """
        self._model.eval()
        report_dict = EpochResultDict()
        batches = zip(self._indicator, self._val_loader)
        for _, batch in batches:
            image, target, _, _, group = self._unzip_data(batch, self._device)
            logits = self._model(image)

            labels = target.squeeze(1)
            onehot_labels = class2one_hot(labels, self.num_classes)
            probabilities = logits.softmax(1)
            loss = self._sup_criterion(probabilities, onehot_labels, disable_assert=True)
            self.meters["loss"].add(loss.item())

            predicted_class = logits.max(1)[1]
            self.meters["dice"].add(predicted_class, labels, group_name=group)

            report_dict = self.meters.tracking_status()
            self._indicator.set_postfix_dict(report_dict)
        dice_summary = self.meters["dice"].summary()
        return report_dict, dice_summary["DSC_mean"]
示例#8
0
    def _run(self, *args, **kwargs) -> EpochResultDict:
        """Train for one epoch with a supervised loss plus a regularisation term.

        Labeled and unlabeled batches are drawn in lockstep. The unlabeled
        images are affine-transformed under a per-step random seed, and the
        model is run once on the concatenation of labeled, unlabeled and
        transformed-unlabeled images. The logits of the untransformed images
        are then transformed under the same seed. Both logit sets are passed
        to ``self.regularization``, and the total loss is
        ``sup_loss + self._reg_weight * reg_loss``.

        Returns:
            The meters' tracking status after the last batch, or an empty dict
            when no batch was processed.
        """
        self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
        self._model.train()
        assert self._model.training, self._model.training
        report_dict = {}

        def seeded_affine(batch, seed):
            # Apply the affine transformer sample by sample under a fixed seed.
            with FixRandomSeed(seed):
                return torch.stack(
                    [self._affine_transformer(sample) for sample in batch], dim=0)

        extractor = FeatureExtractor(self._model, self._feature_position)
        with extractor as self._fextractor:
            batches = zip(self._indicator, self._labeled_loader, self._unlabeled_loader)
            for _, labeled_data, unlabeled_data in batches:
                seed = random.randint(0, int(1e7))
                labeled_image, labeled_target, _, _, label_group = \
                    self._unzip_data(labeled_data, self._device)
                unlabeled_image, _unused_target, *_ = self._unzip_data(
                    unlabeled_data, self._device)

                unlabeled_image_tf = seeded_affine(unlabeled_image, seed)
                assert unlabeled_image_tf.shape == unlabeled_image.shape, \
                    (unlabeled_image_tf.shape, unlabeled_image.shape)

                # A single forward pass over all three image groups.
                model_input = torch.cat(
                    [labeled_image, unlabeled_image, unlabeled_image_tf], dim=0)
                predict_logits = self._model(model_input)
                chunk_sizes = [len(labeled_image), len(unlabeled_image), len(unlabeled_image_tf)]
                label_logits, unlabel_logits, unlabel_tf_logits = torch.split(
                    predict_logits, chunk_sizes, dim=0)

                # Same seed as for the images, now applied to the logits.
                unlabel_logits_tf = seeded_affine(unlabel_logits, seed)
                assert unlabel_logits_tf.shape == unlabel_tf_logits.shape, \
                    (unlabel_logits_tf.shape, unlabel_tf_logits.shape)

                # supervised part
                squeezed_target = labeled_target.squeeze(1)
                onehot_target = class2one_hot(squeezed_target, self.num_classes)
                sup_loss = self._sup_criterion(label_logits.softmax(1), onehot_target)

                # regularized part
                reg_loss = self.regularization(
                    unlabeled_tf_logits=unlabel_tf_logits,
                    unlabeled_logits_tf=unlabel_logits_tf,
                    seed=seed,
                    unlabeled_image=unlabeled_image,
                    unlabeled_image_tf=unlabeled_image_tf,
                )
                total_loss = sup_loss + self._reg_weight * reg_loss

                # gradient backpropagation
                self._optimizer.zero_grad()
                total_loss.backward()
                self._optimizer.step()

                # recording can be here or in the regularization method
                with torch.no_grad():
                    self.meters["sup_loss"].add(sup_loss.item())
                    self.meters["sup_dice"].add(label_logits.max(1)[1],
                                                labeled_target.squeeze(1),
                                                group_name=label_group)
                    self.meters["reg_loss"].add(reg_loss.item())
                    report_dict = self.meters.tracking_status()
                    self._indicator.set_postfix_dict(report_dict)
        return report_dict