def _convert2onehot(self, pred: Tensor, target: Tensor):
    """Bring ``pred`` and ``target`` to a common one-hot ``long`` encoding.

    Both tensors must have the same shape. Two layouts are supported:

    * ``pred`` is a simplex along dim 1 and ``target`` is already one-hot:
      ``pred`` is hardened with ``probs2one_hot`` and ``target`` is only cast.
    * otherwise both are treated as class-index maps and expanded to
      ``self._C`` channels with ``class2one_hot``.

    :return: ``(pred_onehot, target_onehot)``, both cast to ``long``.
    """
    assert pred.shape == target.shape
    both_onehot = simplex(pred, 1) and one_hot(target)
    if not both_onehot:
        # class-index coded inputs: expand each one to self._C channels
        return (
            class2one_hot(pred, self._C).long(),
            class2one_hot(target, self._C).long(),
        )
    return probs2one_hot(pred).long(), target.long()
def _run(self, *args, **kwargs) -> Tuple[EpochResultDict, float]:
    """Run one evaluation epoch over ``self._val_loader``.

    For every batch the input images/targets and the network predictions are
    written under ``self._save_dir``; then the supervised loss, the per-group
    Dice score and (best-effort) the ``"hd"`` meter are updated.

    :return: ``(report_dict, dsc_mean)`` where ``report_dict`` is the meters'
        tracking status after the epoch and ``dsc_mean`` is the ``"DSC_mean"``
        entry of the Dice meter summary.
    """
    self._model.eval()
    assert self._model.training is False, self._model.training
    # evaluation only: skip building autograd graphs for every forward pass
    with torch.no_grad():
        for _, val_data in zip(self._indicator, self._val_loader):
            val_img, val_target, file_path, _, group = self._unzip_data(
                val_data, self._device)
            val_logits = self._model(val_img)
            # write image
            write_img_target(val_img, val_target, self._save_dir, file_path)
            write_predict(val_logits, self._save_dir, file_path)

            val_label = val_target.squeeze(1)
            val_pred = val_logits.max(1)[1]
            onehot_target = class2one_hot(val_label, self.num_classes)
            val_loss = self._sup_criterion(val_logits.softmax(1), onehot_target, disable_assert=True)
            self.meters["loss"].add(val_loss.item())
            self.meters["dice"].add(val_pred, val_label, group_name=group)
            # RuntimeErrors raised by the "hd" meter are deliberately ignored (best-effort metric)
            with ExceptionIgnorer(RuntimeError):
                self.meters["hd"].add(val_pred, val_label)
            report_dict = self.meters.tracking_status()
            self._indicator.set_postfix_dict(report_dict)
    # recomputed after the loop so report_dict is bound even when the loader is empty
    report_dict = self.meters.tracking_status()
    return report_dict, self.meters["dice"].summary()["DSC_mean"]
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run one fully-supervised training epoch on ``self._labeled_loader``.

    Records the current learning rate, then for at most ``self._num_batches``
    batches: forward pass, supervised loss between the softmax prediction and
    the one-hot target, optimizer step, and updates of the ``"sup_loss"`` and
    ``"ds"`` meters.

    :return: the meters' tracking status at the end of the epoch.
    """
    self._model.train()
    assert self._model.training, self._model.training
    report_dict: EpochResultDict
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
        for _, label_data in zip(indicator, self._labeled_loader):
            (labelimage, labeltarget), _, _, _, group_list = \
                self._preprocess_data(label_data, self._device)
            predict_logits = self._model(labelimage)
            # softmax is applied below, so the model must output raw logits
            assert not simplex(predict_logits), predict_logits.shape

            labeltarget = labeltarget.squeeze(1)
            # one-hot width follows the model's channel count instead of a hard-coded 4
            onehot_ltarget = class2one_hot(labeltarget, predict_logits.shape[1])
            sup_loss = self._sup_criterion(predict_logits.softmax(1), onehot_ltarget)

            self._optimizer.zero_grad()
            sup_loss.backward()
            self._optimizer.step()

            with torch.no_grad():
                self.meters["sup_loss"].add(sup_loss.item())
                self.meters["ds"].add(predict_logits.max(1)[1], labeltarget,
                                      group_name=list(group_list))
                report_dict = self.meters.tracking_status()
                indicator.set_postfix_dict(report_dict)
    report_dict = self.meters.tracking_status()
    return report_dict
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Run one mean-teacher training epoch.

    Each step pairs a labeled batch from ``self._labeled_loader`` with a batch
    from ``self._tra_loader`` used as unlabeled data. The student is trained on
    ``sup_loss + self._reg_weight * reg_loss``, where:

    * ``sup_loss`` compares the student's softmax on the labeled images with
      the one-hot targets;
    * ``reg_loss`` compares the student's softmax on the transformed unlabeled
      images with the softmax of the teacher's prediction on the original
      unlabeled images, transformed under the same seed.

    After each optimizer step the teacher is updated via ``self._ema_updater``.

    :return: the meters' tracking status at the end of the epoch.
    """
    self._model.train()
    self._teacher_model.train()
    assert self._model.training, self._model.training
    assert self._teacher_model.training, self._teacher_model.training
    self.meters["lr"].add(self._optimizer.param_groups[0]["lr"])
    self.meters["reg_weight"].add(self._reg_weight)
    report_dict: EpochResultDict
    with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
        for _, label_data, all_data in zip(indicator, self._labeled_loader, self._tra_loader):
            (labelimage, labeltarget), _, _, _, group_list = \
                self._preprocess_data(label_data, self._device)
            # the unlabeled branch is fed from the self._tra_loader batch
            # (assumes it yields the same batch layout as the labeled loader — TODO confirm)
            (unlabelimage, _), *_ = self._preprocess_data(all_data, self._device)

            seed = random.randint(0, int(1e6))
            with FixRandomSeed(seed):
                unlabelimage_tf = torch.stack([self._transformer(x) for x in unlabelimage], dim=0)
            assert unlabelimage_tf.shape == unlabelimage.shape

            student_logits = self._model(torch.cat([labelimage, unlabelimage_tf], dim=0))
            if simplex(student_logits):
                raise RuntimeError("output of the model should be logits, instead of simplex")
            n_labeled = len(labelimage)
            student_sup_logits = student_logits[:n_labeled]
            student_unlabel_logits_tf = student_logits[n_labeled:]

            with torch.no_grad():
                teacher_unlabel_logits = self._teacher_model(unlabelimage)
                # same seed as for unlabelimage_tf, intended to draw the same random transform
                with FixRandomSeed(seed):
                    teacher_unlabel_logits_tf = torch.stack(
                        [self._transformer(x) for x in teacher_unlabel_logits])
            assert teacher_unlabel_logits.shape == teacher_unlabel_logits_tf.shape

            # calculate the loss
            labeltarget = labeltarget.squeeze(1)
            # one-hot width follows the model's channel count instead of a hard-coded 4
            onehot_ltarget = class2one_hot(labeltarget, student_sup_logits.shape[1])
            sup_loss = self._sup_criterion(student_sup_logits.softmax(1), onehot_ltarget)
            reg_loss = self._reg_criterion(student_unlabel_logits_tf.softmax(1),
                                           teacher_unlabel_logits_tf.detach().softmax(1))
            total_loss = sup_loss + self._reg_weight * reg_loss

            self._optimizer.zero_grad()
            total_loss.backward()
            self._optimizer.step()
            # update ema
            self._ema_updater(ema_model=self._teacher_model, student_model=self._model)

            with torch.no_grad():
                self.meters["sup_loss"].add(sup_loss.item())
                self.meters["reg_loss"].add(reg_loss.item())
                self.meters["ds"].add(student_sup_logits.max(1)[1], labeltarget,
                                      group_name=list(group_list))
                report_dict = self.meters.tracking_status()
                indicator.set_postfix_dict(report_dict)
    report_dict = self.meters.tracking_status()
    return report_dict
def toOneHot(pred_logit, mask):
    """
    Convert a prediction and a ground-truth mask into one-hot tensors of the same shape.

    :param pred_logit: prediction with shape b, c, h, w. Logits, simplex predictions and
        one-hot tensors are all accepted: softmax preserves the per-pixel ordering along dim 1.
    :param mask: ground-truth class-index mask with shape b, h, w or b, 1, h, w.
    :return: onehot presentation of prediction and mask, pred.shape == mask.shape == b, c, h, w
        with c = pred_logit.shape[1].
    """
    oh_predmask = probs2one_hot(F.softmax(pred_logit, 1))
    # drop only a singleton channel axis (b,1,h,w); squeezing a b,h,w mask would
    # remove the height axis whenever h == 1
    if mask.dim() == pred_logit.dim():
        mask = mask.squeeze(1)
    oh_mask = class2one_hot(mask, C=pred_logit.shape[1])
    assert oh_predmask.shape == oh_mask.shape
    return oh_predmask, oh_mask
def _run(self, *args, **kwargs) -> Tuple[EpochResultDict, float]:
    """Run one evaluation epoch over ``self._val_loader``.

    For each batch, computes the supervised loss between the softmax prediction
    and the one-hot target and updates the ``"sup_loss"`` and ``"ds"`` meters.

    :return: ``(report_dict, report_dict["ds"]["DSC_mean"])`` where
        ``report_dict`` is the meters' tracking status after the epoch.
    """
    self._model.eval()
    assert not self._model.training, self._model.training
    # evaluation only: skip building autograd graphs for every forward pass
    with tqdm(self._val_loader).set_desc_from_epocher(self) as indicator, torch.no_grad():
        for data in indicator:
            images, targets, _, _, group_list = self._preprocess_data(data, self._device)
            predict_logits = self._model(images)
            assert not simplex(predict_logits), predict_logits.shape

            targets = targets.squeeze(1)
            # one-hot width follows the model's channel count instead of a hard-coded 4
            onehot_targets = class2one_hot(targets, predict_logits.shape[1])
            loss = self._sup_criterion(predict_logits.softmax(1), onehot_targets, disable_assert=True)

            self.meters["sup_loss"].add(loss.item())
            self.meters["ds"].add(predict_logits.max(1)[1], targets, group_name=list(group_list))
            report_dict = self.meters.tracking_status()
            indicator.set_postfix_dict(report_dict)
    report_dict = self.meters.tracking_status()
    return report_dict, report_dict["ds"]["DSC_mean"]
def _run(self, *args, **kwargs) -> Tuple[EpochResultDict, float]:
    """Run one evaluation epoch over ``self._val_loader``.

    For each batch, computes the supervised loss between the softmax prediction
    and the one-hot target (``self.num_classes`` channels) and updates the
    ``"loss"`` and per-group ``"dice"`` meters.

    :return: ``(report_dict, dsc_mean)``. ``report_dict`` is the last tracking
        status, or an empty ``EpochResultDict`` if no batch was processed.
        ``dsc_mean`` is the ``"DSC_mean"`` entry of the Dice meter summary.
    """
    self._model.eval()
    report_dict = EpochResultDict()
    # evaluation only: skip building autograd graphs for every forward pass
    with torch.no_grad():
        for _, val_data in zip(self._indicator, self._val_loader):
            val_img, val_target, _, _, group = self._unzip_data(val_data, self._device)
            val_logits = self._model(val_img)

            val_label = val_target.squeeze(1)
            onehot_target = class2one_hot(val_label, self.num_classes)
            val_loss = self._sup_criterion(val_logits.softmax(1), onehot_target, disable_assert=True)

            self.meters["loss"].add(val_loss.item())
            self.meters["dice"].add(val_logits.max(1)[1], val_label, group_name=group)
            report_dict = self.meters.tracking_status()
            self._indicator.set_postfix_dict(report_dict)
    return report_dict, self.meters["dice"].summary()["DSC_mean"]
def _run(self, *args, **kwargs) -> EpochResultDict:
    """Train for one epoch with a supervised loss plus a transformation-consistency term.

    Each step draws a labeled and an unlabeled batch. The unlabeled images are
    transformed by ``self._affine_transformer`` under a per-step seed, and the
    model runs once on ``[labeled, unlabeled, unlabeled_tf]``. The logits of the
    untransformed unlabeled images are then transformed under the same seed,
    and both views are passed to ``self.regularization``. The optimized
    objective is ``sup_loss + self._reg_weight * reg_loss``.

    The epoch runs inside ``FeatureExtractor(self._model, self._feature_position)``,
    which is bound to ``self._fextractor`` for the duration of the loop.

    :return: the last tracking status recorded during the epoch, or ``{}`` if
        no step ran.
    """
    self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
    self._model.train()
    assert self._model.training, self._model.training
    report_dict = {}
    with FeatureExtractor(self._model, self._feature_position) as self._fextractor:
        for _, lab_batch, unlab_batch in zip(self._indicator, self._labeled_loader, self._unlabeled_loader):
            seed = random.randint(0, 10 ** 7)
            lab_img, lab_target, _, _, lab_group = self._unzip_data(lab_batch, self._device)
            unlab_img, _, *_ = self._unzip_data(unlab_batch, self._device)

            with FixRandomSeed(seed):
                unlab_img_tf = torch.stack([self._affine_transformer(img) for img in unlab_img], dim=0)
            assert unlab_img_tf.shape == unlab_img.shape, (unlab_img_tf.shape, unlab_img.shape)

            # a single forward pass over the three stacked views
            all_logits = self._model(torch.cat([lab_img, unlab_img, unlab_img_tf], dim=0))
            split_sizes = [len(lab_img), len(unlab_img), len(unlab_img_tf)]
            lab_logits, unlab_logits, logits_of_tf = torch.split(all_logits, split_sizes, dim=0)

            # re-use the seed so the logits go through the transform paired with unlab_img_tf
            with FixRandomSeed(seed):
                tf_of_logits = torch.stack([self._affine_transformer(lg) for lg in unlab_logits], dim=0)
            assert tf_of_logits.shape == logits_of_tf.shape, (tf_of_logits.shape, logits_of_tf.shape)

            # supervised part
            lab_onehot = class2one_hot(lab_target.squeeze(1), self.num_classes)
            sup_loss = self._sup_criterion(lab_logits.softmax(1), lab_onehot)

            # regularized part
            reg_loss = self.regularization(
                unlabeled_tf_logits=logits_of_tf,
                unlabeled_logits_tf=tf_of_logits,
                seed=seed,
                unlabeled_image=unlab_img,
                unlabeled_image_tf=unlab_img_tf,
            )
            total_loss = sup_loss + self._reg_weight * reg_loss

            # gradient backpropagation
            self._optimizer.zero_grad()
            total_loss.backward()
            self._optimizer.step()

            # bookkeeping (could also live in the regularization method)
            with torch.no_grad():
                self.meters["sup_loss"].add(sup_loss.item())
                self.meters["sup_dice"].add(lab_logits.max(1)[1], lab_target.squeeze(1), group_name=lab_group)
                self.meters["reg_loss"].add(reg_loss.item())
                report_dict = self.meters.tracking_status()
                self._indicator.set_postfix_dict(report_dict)
    return report_dict