def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
    """Return the distance-map weighted squared error between ``probs`` and ``target``.

    For the classes listed in ``self.idc``, the squared error
    ``(probs - target) ** 2`` is weighted element-wise by ``tdm ** 2 + pdm ** 2``.
    ``tdm`` and ``pdm`` are the maps returned by ``one_hot2hd_dist`` for the
    target and for the one-hot prediction (presumably distance maps, computed
    per sample on the CPU; they carry no gradient). The weighted error is
    averaged over all elements.

    Args:
        probs: class probabilities shaped ``(B, K, *spatial)``; must satisfy ``simplex``.
        target: one-hot target with the same shape as ``probs``.

    Returns:
        A scalar tensor.
    """
    assert simplex(probs)
    assert simplex(target)
    assert probs.shape == target.shape

    B, K, *xyz = probs.shape  # type: ignore

    pc = cast(Tensor, probs[:, self.idc, ...].type(torch.float32))
    tc = cast(Tensor, target[:, self.idc, ...].type(torch.float32))
    assert pc.shape == tc.shape == (B, len(self.idc), *xyz)

    # Maps are computed in NumPy, one sample at a time, then moved back to probs.device.
    target_dm_npy: np.ndarray = np.stack(
        [one_hot2hd_dist(tc[b].cpu().detach().numpy()) for b in range(B)], axis=0)
    assert target_dm_npy.shape == tc.shape == pc.shape
    tdm: Tensor = torch.tensor(target_dm_npy, device=probs.device, dtype=torch.float32)

    pred_segmentation: Tensor = probs2one_hot(probs).cpu().detach()
    pred_dm_npy: np.ndarray = np.stack(
        [one_hot2hd_dist(pred_segmentation[b, self.idc, ...].numpy()) for b in range(B)],
        axis=0)
    assert pred_dm_npy.shape == tc.shape == pc.shape
    pdm: Tensor = torch.tensor(pred_dm_npy, device=probs.device, dtype=torch.float32)

    delta = (pc - tc) ** 2
    dtm = tdm ** 2 + pdm ** 2

    # Element-wise product: identical to einsum("bkwh,bkwh->bkwh", ...) for 2-D
    # inputs, but also valid for 3-D (or any rank) spatial shapes.
    multiplied = delta * dtm

    return multiplied.mean()
def compute_metrics(pred_probs, gt, labels):
    """Compute per-image metrics for one batch of predictions.

    Args:
        pred_probs: predicted class probabilities; converted to one-hot with
            ``probs2one_hot`` and expected to be 4-D (``b, c, w, h``).
        gt: one-hot ground truth with the same layout.
        labels: weak labels whose Dice against ``gt`` is reported as a baseline.

    Returns:
        ``(dices, baseline_dices, posim, haussdorf_res)`` as NumPy arrays.
        ``haussdorf_res`` has shape ``(b, c)``. ``posim`` is True for every image
        whose ``gt`` has at least one set pixel outside channel 0.
    """
    gt_detached = gt.detach()

    one_hot_pred: Tensor = probs2one_hot(pred_probs)
    n_images, n_classes, _, _ = one_hot_pred.shape
    pred_detached = one_hot_pred.detach()

    dices = dice_coef(pred_detached, gt_detached).cpu().numpy()
    baseline_dices = dice_coef(labels.detach(), gt_detached).cpu().numpy()

    haussdorf_res = haussdorf(pred_detached, gt_detached, dtype=pred_probs.dtype).cpu().numpy()
    assert haussdorf_res.shape == (n_images, n_classes)

    # Sum of every non-background channel per image; > 0 marks a "positive" image.
    foreground_mass = torch.einsum("bcwh->b", [gt[:, 1:, :, :]]).detach()
    posim = (foreground_mass > 0).cpu().numpy()

    return dices, baseline_dices, posim, haussdorf_res
def __call__(self, probs: Tensor, target: Tensor, bounds) -> "Tuple[Tensor, Tensor, Tensor]":
    """Return the weighted entropy term, the KL size-prior term and the estimated proportions.

    Args:
        probs: class probabilities shaped ``(b, c, w, h)``; must satisfy ``simplex``.
        target: only used when ``self.curi`` is falsy, in which case the target
            proportions are computed from it with ``self.__fn__``. It does not
            have to be on the simplex.
        bounds: size bounds used when ``self.curi`` is truthy. With ``self.ivd``
            only ``bounds[:, :, 0]`` is kept. Presumably pixel counts per class,
            since they are divided by ``w * h`` — TODO confirm.

    Returns:
        ``(lamb_se * loss_se, lamb_consprior * loss_cons_prior, est_prop)`` where:
        - ``loss_se`` is the class-weighted entropy ``-sum w_c p log p``, normalised
          by the total probability mass;
        - ``loss_cons_prior`` is ``KL(est_prop || gt_prop)`` divided by the batch size;
        - ``est_prop`` is the ``(b, c)`` output of ``self.__fn__`` after ``squeeze(2)``.
    """
    assert simplex(probs)  # target is deliberately not checked
    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]

    # NOTE: the previous version also built a one-hot mask and its proportions
    # (``est_prop_mask``/``log_est_prop_mask``) here, but never used them.
    est_prop: Tensor = self.__fn__(probs, self.power)
    if self.curi:
        if self.ivd:
            bounds = bounds[:, :, 0]
            bounds = bounds.unsqueeze(2)
        gt_prop = torch.ones_like(est_prop) * bounds / (w * h)
        gt_prop = gt_prop[:, :, 0]
    else:
        # The power is irrelevant when target holds 0/1 labels.
        gt_prop = self.__fn__(target, self.power)
        gt_prop = gt_prop.squeeze(2)
    est_prop = est_prop.squeeze(2)

    log_est_prop: Tensor = (est_prop + 1e-10).log()
    log_gt_prop: Tensor = (gt_prop + 1e-10).log()

    # sum(est * log est) - sum(est * log gt) == KL(est_prop || gt_prop) over batch and classes.
    loss_cons_prior = -torch.einsum("bc,bc->", [est_prop, log_gt_prop]) \
        + torch.einsum("bc,bc->", [est_prop, log_est_prop])
    # Normalise by batch size.
    loss_cons_prior /= b

    log_p: Tensor = (probs + 1e-10).log()
    mask: Tensor = probs.type(torch.float32)
    mask_weighted = torch.einsum("bcwh,c->bcwh", [mask, Tensor(self.weights).to(mask.device)])
    loss_se = -torch.einsum("bcwh,bcwh->", [mask_weighted, log_p])
    loss_se /= mask.sum() + 1e-10

    # In validation probs carry no gradient, so neither may the loss.
    assert loss_se.requires_grad == probs.requires_grad

    return self.lamb_se * loss_se, self.lamb_consprior * loss_cons_prior, est_prop
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Build per-sample pseudo-labels and delegate to the parent loss.

    For every sample, ``self.pathak_generator`` is called with the detached
    probabilities and target plus the ``(low, high)`` pair read from
    ``bounds[i, self.mask_idc][0, 0, 0:2]``. Each result must be a ``(c, w, h)``
    simplex. The stacked results are converted to one-hot with
    ``probs2one_hot`` and passed to ``super().__call__`` in place of ``target``.

    Only a single entry in ``self.mask_idc`` is supported.
    """
    assert simplex(probs)
    assert probs.shape == target.shape
    assert len(self.mask_idc) == 1, "Cannot handle more at the time, I guess"

    _, c, w, h = probs.shape
    fake_probs: Tensor = torch.zeros_like(probs, dtype=torch.float32)

    for idx, (sample_probs, sample_target) in enumerate(zip(probs, target)):
        sample_bounds: Tensor = bounds[idx, self.mask_idc]
        low: Tensor = sample_bounds[0, 0, 0]
        high: Tensor = sample_bounds[0, 0, 1]

        generated = self.pathak_generator(sample_probs.detach(), sample_target.detach(), low, high)
        assert simplex(generated, axis=0)
        assert generated.shape == (c, w, h)
        fake_probs[idx] = generated

    fake_mask: Tensor = probs2one_hot(fake_probs)
    assert fake_mask.shape == probs.shape == target.shape

    return super().__call__(probs, fake_mask, bounds)
def pathak_generator(self, probs: Tensor, target: Tensor, bounds, upper=None) -> Tensor:
    """Generate a one-hot pseudo-label from ``probs`` under a lower/upper size bound.

    Runs ``max_iter`` steps of projected gradient ascent on two non-negative
    multipliers ``beta`` (one per bound). It then returns ``probs2one_hot`` of the
    last re-weighted distribution, normalised per pixel over the class axis.

    Args:
        probs: per-class probabilities indexed as ``(c, w, h)``.
        target: labels with the same layout. Channels listed in ``self.ignore``
            are treated as unlabeled. Not modified by this method.
        bounds: either an object whose last element unpacks to
            ``(lower, upper)`` (original calling convention), or the lower
            bound itself when ``upper`` is given.
        upper: optional upper bound. It allows the call form
            ``pathak_generator(probs, target, low, high)``.

    Returns:
        ``probs2one_hot`` of the final distribution.
    """
    _, w, h = probs.shape

    # Replace the probabilities with certainty for the few weak labels that we have.
    # clone(): ``target[...]`` is a view, so zeroing the ignored channels on it
    # used to mutate the caller's ``target`` (and break the simplex check below).
    weak_labels = target.clone()
    weak_labels[self.ignore, ...] = 0
    assert not simplex(weak_labels) and simplex(target)

    if upper is None:
        lower, upper = bounds[-1]
    else:
        lower = bounds

    labeled_pixels = weak_labels.any(axis=0)
    assert w * h == (labeled_pixels.sum() + (~labeled_pixels).sum())  # make sure all pixels are covered

    # einsum needs operands of one dtype; the boolean mask is cast to probs' dtype.
    unlabeled = (~labeled_pixels).type_as(probs)
    scribbled_probs = weak_labels + einsum("cwh,wh->cwh", probs, unlabeled)
    assert simplex(scribbled_probs)
    # NOTE(review): scribbled_probs is only sanity-checked; the loop below works on
    # ``probs``. In addition, ``f`` is -1 / +1 uniformly over (c, w, h), so
    # exp(-beta . f) is a single scalar that cancels in the per-pixel
    # normalisation: ``u`` ends up equal to ``probs`` whatever ``beta`` is. Confirm
    # whether the constraint features and/or scribbled_probs were meant to be used.

    max_iter: int = 100
    lr: float = 0.00005
    device = probs.device
    # All loop tensors live on probs' device; previously they were always on the CPU.
    b: Tensor = torch.tensor([-float(lower), float(upper)], dtype=torch.float32, device=device)
    beta: Tensor = torch.zeros(2, dtype=torch.float32, device=device)
    f: Tensor = torch.zeros(2, *probs.shape, dtype=torch.float32, device=device)
    f[0, ...] = -1
    f[1, ...] = 1

    u: Tensor = probs
    for _ in range(max_iter):
        # Negate the exponent, not the exponential (old code: ``- x.exp()``).
        exped = (-einsum("i,icwh->cwh", beta, f)).exp()
        u_star = einsum("cwh,cwh->cwh", probs, exped)
        u_star /= u_star.sum(axis=0)
        assert simplex(u_star)

        d_beta = einsum("cwh,icwh->i", u_star, f) - b
        # Projected ascent step: the multipliers are kept non-negative.
        beta = torch.clamp(beta + lr * d_beta, min=0)
        u = u_star

    return probs2one_hot(u)
def extra_eval_model(model, dataloaders, log_dir="./log/", logger=None, opt=None):
    """Evaluate infection classification on ``dataloaders["tgt_cls_extra_val"]``.

    For every batch the model output is unpacked as a 5-tuple
    ``(cls_logits, _, seg_logits, _, aux_seg_logits)``.

    - With ``opt.do_seg``: the segmentation output (and, with ``opt.use_aux``,
      the auxiliary one) is turned into an image-level prediction that is
      positive when any pixel argmaxes to the last class. The predicted mask is
      written as a PNG under ``opt.logs/tgt_cls_extra_val[_au]/eval``.
    - With ``opt.do_cls``: the classification head is thresholded at 0.5.

    Results are scored with ``get_results`` and logged through ``logger``.
    ``log_dir`` is unused. The image name comes from
    ``dataset.annotations[batch_idx]``, which presumably assumes ``batch_size == 1``.
    """
    start_time = time.time()
    loader = dataloaders["tgt_cls_extra_val"]

    def _segment(logits, sub_dir, img_name):
        # Returns (max last-class probability per image, seg-derived 0/1 label)
        # and saves the argmax mask as a grey-level PNG.
        seg_probs = torch.softmax(logits, dim=1)
        max_prob = seg_probs[:, -1:].detach().cpu().view(seg_probs.shape[0], 1, -1).max(-1)[0]
        onehot = probs2one_hot(seg_probs.detach())
        mask = onehot.squeeze().cpu().numpy()  # 3xwxh
        mask_img = (np.zeros_like(mask[0]) + mask[1] * 128 + mask[2] * 255).astype(np.uint8)
        out_dir = os.path.join(opt.logs, sub_dir, "eval")
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        cv2.imwrite(os.path.join(out_dir, img_name), mask_img)
        cls_pred = (onehot[:, -1:].sum(-1).sum(-1) > 0).cpu().numpy().astype(np.uint8)
        return max_prob, cls_pred

    logger.info("-" * 8 + "extra eval infection cls" + "-" * 8)
    model.eval()

    gts, cls_preds, cls_probs = [], [], []
    seg_preds, seg_probs_list = [], []
    seg_preds_au, seg_probs_list_au = [], []

    for batch_idx, (inputs, labels) in enumerate(loader, 0):
        inputs = inputs.to(device)
        gts.append(labels.cpu().data.numpy())
        with torch.set_grad_enabled(False):
            annotation = loader.dataset.annotations[batch_idx]
            img_name = Path(annotation.strip().split(',')[0]).name
            cls_logits, _, seg_logits, _, aux_logits = model(inputs)

            if opt.do_seg:
                prob, pred = _segment(seg_logits, "tgt_cls_extra_val", img_name)
                seg_probs_list.append(prob)
                seg_preds.append(pred)

            if opt.do_seg and opt.use_aux:
                prob, pred = _segment(aux_logits, "tgt_cls_extra_val_au", img_name)
                seg_probs_list_au.append(prob)
                seg_preds_au.append(pred)

            if opt.do_cls:
                cls_softmax = torch.softmax(cls_logits, dim=1)
                cls_probs.append(cls_softmax[..., 1:].detach().cpu().numpy())
                cls_preds.append((cls_softmax[..., 1:] > 0.5).type(torch.long).cpu().data.numpy())

    cf_dir = os.path.join(opt.logs, "cf")
    if not os.path.exists(cf_dir):
        os.makedirs(cf_dir)
    gt_all = np.concatenate(gts, axis=0)

    def _score(preds, probs, cf_name, metric_name):
        # Concatenate the per-batch lists and hand them to get_results.
        return get_results(gt_all,
                           np.concatenate(preds, axis=0),
                           np.concatenate(probs, axis=0),
                           os.path.join(opt.logs, "cf", cf_name),
                           os.path.join(opt.logs, metric_name))

    if opt.do_cls:
        result_str = _score(cls_preds, cls_probs, "extra_eval_cls_cf.png", "extra_eval_metric_cls.txt")
        logger.info("tgt_cls_extra_val:[cls]: %s" % (result_str))

    if opt.do_seg:
        result_str = _score(seg_preds, seg_probs_list, "extra_eval_seg_cf.png", "extra_eval_metric_seg.txt")
        logger.info("tgt_seg_extra_val:[seg2cls]: %s" % (result_str))

    if opt.do_seg and opt.use_aux:
        result_str_au = _score(seg_preds_au, seg_probs_list_au,
                               "extra_eval_seg_au_cf.png", "extra_eval_metric_seg_au.txt")
        logger.info("tgt_seg_au_extra_val:[seg2cls]: %s" % (result_str_au))

    elapsed = time.time() - start_time
    logger.info("Extra_Eval complete in {:.0f}m {:.0f}s".format(elapsed // 60, elapsed % 60))
def eval_model(model, dataloaders, log_dir="./log/", logger=None, opt=None, eval_lung_seg=False):
    """Evaluate ``model`` on the target validation splits and log the metrics.

    Two evaluations are available:

    * Lung segmentation on ``dataloaders["tgt_lung_seg_val"]``. The infection
      channel is folded into the lung channel and Dice is computed against the
      lung mask. It only runs when ``eval_lung_seg`` is True and ``opt.do_seg`` is set.
    * Infection classification on ``dataloaders["tgt_cls_val"]``. It uses the
      classification head (``opt.do_cls``) and an image label derived from the
      segmentation output (``opt.do_seg``; the auxiliary head too with ``opt.use_aux``).

    Predicted masks are written as PNGs under ``opt.logs``. Confusion-matrix
    images go to ``opt.logs/cf``, metric text files to ``opt.logs``.
    The image name comes from ``dataset.annotations[batch_idx]``, which
    presumably assumes ``batch_size == 1`` — TODO confirm.

    Args:
        model: returns a 5-tuple ``(cls_logits, _, seg_logits, _, aux_seg_logits)``.
        dataloaders: dict of data loaders keyed by split name.
        log_dir: unused; kept for interface compatibility.
        logger: logger that receives the results.
        opt: options namespace (``logs``, ``do_seg``, ``do_cls``, ``use_aux``,
            ``xray_mask_value_dict``).
        eval_lung_seg: enable the lung segmentation evaluation. It was previously
            hard-disabled with ``if False:``; the default keeps it off.
    """
    since = time.time()

    def _img_name(loader, batch_idx):
        # File name of the image behind batch ``batch_idx``.
        annotation = loader.dataset.annotations[batch_idx]
        return Path(annotation.strip().split(',')[0]).name

    def _save_mask(mask_img, sub_dir, img_name):
        # Write one grey-level mask PNG into opt.logs/<sub_dir>/eval.
        save_dir = os.path.join(opt.logs, sub_dir, "eval")
        os.makedirs(save_dir, exist_ok=True)
        cv2.imwrite(os.path.join(save_dir, img_name), mask_img)

    def _lung_dice(seg_logits, labels, sub_dir, img_name):
        # Fold the last (infection) channel into the previous (lung) one, score and save it.
        predicted_mask = probs2one_hot(torch.softmax(seg_logits, dim=1).detach())
        predicted_mask_lung = predicted_mask[:, :-1]
        predicted_mask_lung[:, -1] += predicted_mask[:, -1]
        dices = dice_coef(predicted_mask_lung, labels.detach().type_as(predicted_mask)).cpu().numpy()
        lung = predicted_mask_lung.squeeze().cpu().numpy()
        _save_mask((np.zeros_like(lung[0]) + lung[1] * 255).astype(np.uint8), sub_dir, img_name)
        return dices  # (B, C)

    def _seg_to_cls(seg_logits, sub_dir, img_name):
        # Returns (max last-class probability per image, 1 if any pixel argmaxes to the last class).
        seg_probs = torch.softmax(seg_logits, dim=1)
        max_prob = seg_probs[:, -1:].detach().cpu().view(seg_probs.shape[0], 1, -1).max(-1)[0]
        onehot = probs2one_hot(seg_probs.detach())
        mask = onehot.squeeze().cpu().numpy()  # 3xwxh
        _save_mask((np.zeros_like(mask[0]) + mask[1] * 128 + mask[2] * 255).astype(np.uint8),
                   sub_dir, img_name)
        cls_pred = (onehot[:, -1:].sum(-1).sum(-1) > 0).cpu().numpy().astype(np.uint8)
        return max_prob, cls_pred

    # ------------------------------------------------------------ lung segmentation
    if eval_lung_seg and opt.do_seg:
        logger.info("-" * 8 + "eval lung segmentation" + "-" * 8)
        model.eval()
        lung_loader = dataloaders["tgt_lung_seg_val"]
        all_dices, all_dices_au = [], []
        batch_idx, batch_size = 0, 1
        for batch_idx, (inputs, labels) in enumerate(lung_loader, 0):
            img_name = _img_name(lung_loader, batch_idx)
            inputs = inputs.to(device)
            batch_size = inputs.shape[0]
            # Map the raw lung value to 1 and one-hot encode as [background, lung].
            labels[labels == opt.xray_mask_value_dict["lung"]] = 1
            labels = labels[:, -1].to(device)
            labels = torch.stack([labels == c for c in range(2)], dim=1)
            with torch.set_grad_enabled(False):
                _, _, seg_logits, _, seg_logits_au = model(inputs)
                all_dices.append(_lung_dice(seg_logits, labels, "tgt_lung_seg_val", img_name))
                if opt.use_aux:
                    all_dices_au.append(
                        _lung_dice(seg_logits_au, labels, "tgt_lung_seg_val_au", img_name))

        n_steps = len(lung_loader.dataset) // batch_size
        for tag, dices_list in (("tgt_lung_seg_val", all_dices), ("tgt_lung_seg_val_au", all_dices_au)):
            if not dices_list:  # empty loader, or aux head disabled
                continue
            stacked = np.concatenate(dices_list, 0)
            avg_dice = np.mean(stacked, 0)
            logger.info("%s:[%d/%d],dice0:%.03f,dice1:%.03f,dice:%.03f"
                        % (tag, batch_idx, n_steps, avg_dice[0], avg_dice[1], np.mean(stacked)))

    # ------------------------------------------------ infection classification
    logger.info("-" * 8 + "eval infection cls" + "-" * 8)
    model.eval()
    cls_loader = dataloaders["tgt_cls_val"]
    val_gt, val_cls_pred, val_cls_probs = [], [], []
    val_seg_pred, val_seg_probs = [], []
    val_seg_pred_au, val_seg_probs_au = [], []

    for batch_idx, (inputs, labels) in enumerate(cls_loader, 0):
        inputs = inputs.to(device)
        val_gt.append(labels.cpu().data.numpy())
        with torch.set_grad_enabled(False):
            img_name = _img_name(cls_loader, batch_idx)
            cls_logits, _, seg_logits, _, seg_logits_au = model(inputs)

            if opt.do_seg:
                prob, pred = _seg_to_cls(seg_logits, "tgt_cls_val", img_name)
                val_seg_probs.append(prob)
                val_seg_pred.append(pred)

            if opt.do_seg and opt.use_aux:
                prob, pred = _seg_to_cls(seg_logits_au, "tgt_cls_val_au", img_name)
                val_seg_probs_au.append(prob)
                val_seg_pred_au.append(pred)

            if opt.do_cls:
                probs_cls = torch.softmax(cls_logits, dim=1)
                val_cls_probs.append(probs_cls[..., 1:].detach().cpu().numpy())
                preds_cls = (probs_cls[..., 1:] > 0.5).type(torch.long)
                val_cls_pred.append(preds_cls.cpu().data.numpy())

    os.makedirs(os.path.join(opt.logs, "cf"), exist_ok=True)
    val_gt = np.concatenate(val_gt, axis=0)

    def _score(preds, probs, cf_name, metric_name):
        # Concatenate the per-batch lists and hand them to get_results.
        return get_results(val_gt,
                           np.concatenate(preds, axis=0),
                           np.concatenate(probs, axis=0),
                           os.path.join(opt.logs, "cf", cf_name),
                           os.path.join(opt.logs, metric_name))

    if opt.do_cls:
        result_str = _score(val_cls_pred, val_cls_probs, "eval_cls_cf.png", "eval_metric_cls.txt")
        logger.info("tgt_cls_val:[cls]: %s" % (result_str))
    if opt.do_seg:
        result_str = _score(val_seg_pred, val_seg_probs, "eval_seg_cf.png", "eval_metric_seg.txt")
        logger.info("tgt_seg_val:[seg2cls]: %s" % (result_str))
    if opt.do_seg and opt.use_aux:
        result_str_au = _score(val_seg_pred_au, val_seg_probs_au,
                               "eval_seg_au_cf.png", "eval_metric_seg_au.txt")
        logger.info("tgt_seg_au_val:[seg2cls]: %s" % (result_str_au))

    time_elapsed = time.time() - since
    logger.info("Eval complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False,
                log_dir="./log/", scheduler=None, writer=None, logger=None, opt=None):
    """Train the joint classification/segmentation model with optional MMD alignment.

    Each epoch iterates over ``dataloaders["train"]``. Raw mask values are mapped
    to 1 (lung) and 2 (infection), and an image-level label is derived as
    "any infection pixel". When ``opt.do_cls_mmd`` or ``opt.do_seg_mmd`` is set,
    a batch from ``dataloaders["tgt_cls_train"]`` is drawn alongside and passed
    to the model as ``tgt_img``; that iterator restarts when exhausted.

    The MMD terms are scaled by ``lambd = 2 / (1 + exp(-10 * epoch / num_epochs)) - 1``.
    This is 0 at epoch 0 and approaches 1 as ``epoch`` approaches ``num_epochs``.

    After every epoch:
    - ``scheduler.step()`` is called;
    - the weights are saved to ``log_dir/latest.pth``;
    - lung segmentation is evaluated every ``opt.eval_times`` epochs (if ``opt.do_seg``);
    - infection classification is evaluated every ``opt.eval_cls_times`` epochs.
    Both evaluations also run on the final epoch.

    ``is_inception`` is unused; it is kept for interface compatibility.
    """
    print(opt)
    since = time.time()
    print_step = 5  # log training losses every ``print_step`` batches
    tgt_cls_train_iter = iter(dataloaders['tgt_cls_train'])

    for epoch in range(num_epochs):
        logger.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
        logger.info('-' * 10)
        learning_rate = get_learning_rate(optimizer)
        writer.add_scalar("lr", learning_rate, epoch)

        # ------------------------------------------------------------- training
        model.train()
        for batch_idx, (inputs, labels) in enumerate(dataloaders["train"], 0):
            inputs = inputs.to(device)
            # Map raw mask values to class ids: 1 = lung, 2 = infection.
            labels[labels == opt.drr_mask_value_dict["lung"]] = 1
            labels[labels == opt.drr_mask_value_dict["infection"]] = 2
            labels = labels[:, -1].to(device)
            # Image-level label: 1 if any pixel is infection.
            tag_labels = ((labels == 2).sum(-1).sum(-1) > 0).type(torch.long).to(device)
            c_labels = tag_labels if opt.do_cls_mmd else None
            s_labels = labels if opt.do_seg_mmd else None

            if opt.do_cls_mmd or opt.do_seg_mmd:
                # Use the builtin next(): DataLoader iterators in Python 3 /
                # current PyTorch have no ``.next()`` method.
                try:
                    tgt_inputs, _ = next(tgt_cls_train_iter)
                except StopIteration:
                    tgt_cls_train_iter = iter(dataloaders['tgt_cls_train'])
                    tgt_inputs, _ = next(tgt_cls_train_iter)
                tgt_inputs = tgt_inputs.to(device)
            else:
                tgt_inputs = None

            optimizer.zero_grad()
            model.zero_grad()

            with torch.set_grad_enabled(True):
                src_cls_logits, loss_cls_lmmd, src_seg_logits, loss_seg_lmmd, _ = model(
                    inputs, tgt_img=tgt_inputs, c_label=c_labels, s_label=s_labels)
                lambd = 2 / (1 + math.exp(-10 * epoch / num_epochs)) - 1

                if opt.do_cls and opt.do_cls_mmd:
                    loss_cls_lmmd = lambd * loss_cls_lmmd * opt.lambda_cls_mmd
                    loss_cls_lmmd_item = loss_cls_lmmd.item()
                else:
                    loss_cls_lmmd = 0
                    loss_cls_lmmd_item = 0

                if opt.do_seg and opt.do_seg_mmd:
                    # The option was historically read as the misspelled
                    # ``lmabda_seg_mmd``; fall back to ``lambda_seg_mmd``.
                    seg_mmd_weight = getattr(opt, "lmabda_seg_mmd", None)
                    if seg_mmd_weight is None:
                        seg_mmd_weight = opt.lambda_seg_mmd
                    loss_seg_lmmd = lambd * loss_seg_lmmd * seg_mmd_weight
                    loss_seg_lmmd_item = loss_seg_lmmd.item()
                else:
                    loss_seg_lmmd = 0
                    loss_seg_lmmd_item = 0

                if opt.do_seg:
                    loss_seg = criterion(labels, src_seg_logits,
                                         class_weights=opt.seg_class_weights) * opt.lambda_seg
                    loss_seg_item = loss_seg.item()
                else:
                    loss_seg = 0
                    loss_seg_item = 0

                if opt.do_cls:
                    loss_cls = F.cross_entropy(src_cls_logits, tag_labels) * opt.lambda_cls
                    loss_cls_item = loss_cls.item()
                else:
                    loss_cls = 0
                    loss_cls_item = 0

                loss = loss_seg + loss_cls + loss_seg_lmmd + loss_cls_lmmd
                loss_item = loss.item()
                loss.backward()
                optimizer.step()

            if batch_idx % print_step == 0:
                logger.info(
                    "Train E{:>03} B{:>05} LR:{:.8f} Loss: {:.4f} LSeg: {:.4f} SegMmd: {:.4f} LCls: {:.4f} ClsMmd: {:.4f}"
                    .format(epoch, batch_idx, learning_rate, loss_item, loss_seg_item,
                            loss_seg_lmmd_item, loss_cls_item, loss_cls_lmmd_item))

        scheduler.step()
        weight_path = os.path.join(log_dir, "latest.pth")
        torch.save(model.state_dict(), weight_path)

        # ------------------------------------------------- eval lung segmentation
        if ((epoch + 1) % opt.eval_times == 0 or epoch + 1 == num_epochs) and opt.do_seg:
            logger.info("-" * 8 + "eval lung segmentation" + "-" * 8)
            model.eval()
            lung_loader = dataloaders["tgt_lung_seg_val"]
            all_dices = []
            for batch_idx, (inputs, labels) in enumerate(lung_loader, 0):
                annotation = lung_loader.dataset.annotations[batch_idx]
                img_name = Path(annotation.strip().split(',')[0]).name
                inputs = inputs.to(device)
                # Map the raw lung value to 1 and one-hot encode as [background, lung].
                labels[labels == opt.xray_mask_value_dict["lung"]] = 1
                labels = labels[:, -1].to(device)
                labels = torch.stack([labels == c for c in range(2)], dim=1)
                with torch.set_grad_enabled(False):
                    _, _, seg_logits, _, _ = model(inputs)
                    seg_probs = torch.softmax(seg_logits, dim=1)
                    predicted_mask = probs2one_hot(seg_probs.detach())
                    # Fold the infection channel into the lung channel.
                    predicted_mask_lung = predicted_mask[:, :-1]
                    predicted_mask_lung[:, -1] += predicted_mask[:, -1]
                    dices = dice_coef(predicted_mask_lung,
                                      labels.detach().type_as(predicted_mask)).cpu().numpy()
                    all_dices.append(dices)  # [(B, C)]
                    lung_np = predicted_mask_lung.squeeze().cpu().numpy()
                    mask_inone = (np.zeros_like(lung_np[0]) + lung_np[1] * 255).astype(np.uint8)
                    save_dir = os.path.join(opt.logs, "tgt_lung_seg_val", "ep%03d" % epoch)
                    os.makedirs(save_dir, exist_ok=True)
                    cv2.imwrite(os.path.join(save_dir, img_name), mask_inone)
            if all_dices:  # guard: batch_idx/inputs are undefined for an empty loader
                stacked = np.concatenate(all_dices, 0)
                avg_dice = np.mean(stacked, 0)
                logger.info(
                    "tgt_lung_seg_val:EP%03d,[%d/%d],dice0:%.03f,dice1:%.03f,dice:%.03f"
                    % (epoch, batch_idx, len(lung_loader.dataset) // inputs.shape[0],
                       avg_dice[0], avg_dice[1], np.mean(stacked)))

        # ------------------------------------------ eval infection classification
        if (epoch + 1) % opt.eval_cls_times == 0 or epoch + 1 == num_epochs:
            logger.info("-" * 8 + "eval infection cls" + "-" * 8)
            model.eval()
            cls_loader = dataloaders["tgt_cls_val"]
            val_gt, val_cls_pred, val_cls_probs = [], [], []
            val_seg_pred, val_seg_probs = [], []
            for batch_idx, (inputs, labels) in enumerate(cls_loader, 0):
                inputs = inputs.to(device)
                val_gt.append(labels.cpu().data.numpy())
                with torch.set_grad_enabled(False):
                    annotation = cls_loader.dataset.annotations[batch_idx]
                    img_name = Path(annotation.strip().split(',')[0]).name
                    cls_logits, _, seg_logits, _, _ = model(inputs)
                    if opt.do_seg:
                        seg_probs = torch.softmax(seg_logits, dim=1)
                        val_seg_probs.append(
                            seg_probs[:, -1:].detach().cpu().view(seg_probs.shape[0], 1, -1).max(-1)[0])
                        predicted_mask_onehot = probs2one_hot(seg_probs.detach())
                        predicted_mask = predicted_mask_onehot.squeeze().cpu().numpy()  # 3xwxh
                        mask_inone = (np.zeros_like(predicted_mask[0]) + predicted_mask[1] * 128
                                      + predicted_mask[2] * 255).astype(np.uint8)
                        save_dir = os.path.join(opt.logs, "tgt_cls_val", "ep%03d" % epoch)
                        os.makedirs(save_dir, exist_ok=True)
                        cv2.imwrite(os.path.join(save_dir, img_name), mask_inone)
                        # seg2cls: positive if any pixel argmaxes to infection.
                        preds_cls_seg = (predicted_mask_onehot[:, -1:].sum(-1).sum(-1) > 0
                                         ).cpu().numpy().astype(np.uint8)
                        val_seg_pred.append(preds_cls_seg)
                    if opt.do_cls:
                        probs_cls = torch.softmax(cls_logits, dim=1)
                        val_cls_probs.append(probs_cls[..., 1:].detach().cpu().numpy())
                        preds_cls = (probs_cls[..., 1:] > 0.5).type(torch.long)
                        val_cls_pred.append(preds_cls.cpu().data.numpy())

            os.makedirs(os.path.join(opt.logs, "cf"), exist_ok=True)
            val_gt = np.concatenate(val_gt, axis=0)
            # get_results is called as (gt, pred, probs, cf_png, metric_txt), matching
            # eval_model/extra_eval_model. The old call here omitted probs.
            if opt.do_cls:
                save_cf_png_dir = os.path.join(opt.logs, "cf", "ep%03d_cls_cf.png" % epoch)
                save_metric_dir = os.path.join(opt.logs, "metric_cls.txt")
                result_str = get_results(val_gt, np.concatenate(val_cls_pred, axis=0),
                                         np.concatenate(val_cls_probs, axis=0),
                                         save_cf_png_dir, save_metric_dir)
                logger.info("tgt_cls_val:EP%03d,[cls]: %s" % (epoch, result_str))
            if opt.do_seg:
                save_cf_png_dir = os.path.join(opt.logs, "cf", "ep%03d_seg_cf.png" % epoch)
                save_metric_dir = os.path.join(opt.logs, "metric_seg.txt")
                result_str = get_results(val_gt, np.concatenate(val_seg_pred, axis=0),
                                         np.concatenate(val_seg_probs, axis=0),
                                         save_cf_png_dir, save_metric_dir)
                logger.info("tgt_seg_val:EP%03d,[seg2cls]: %s" % (epoch, result_str))

    time_elapsed = time.time() - since
    logger.info("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
def do_epoch(args, mode: str, net: Any, device: Any, loader: DataLoader, epc: int,
             loss_fns: List[Callable], loss_weights: List[float],
             loss_fns_source: List[Callable], loss_weights_source: List[float],
             new_w: int, num_steps: int, C: int, metric_axis: List[int],
             savedir: str = "", optimizer: Any = None, target_loader: Any = None):
    """Run one train/val epoch over paired (source, target) batches.

    Batches of ``loader`` (source) and ``target_loader`` are zipped, so the epoch
    stops at the shorter of the two. Each ``*_data`` batch is
    ``[filenames, image, gt, *labels(L), *bounds]`` with ``L = len(loss_fns)``.
    Target losses use ``loss_fns``/``loss_weights`` and source losses use
    ``loss_fns_source``/``loss_weights_source``. A loss whose weight is not
    positive is skipped.

    Returns:
        ``(losses_vec, target_vec, source_vec)`` where:
        - ``losses_vec`` holds the mean of the first three weighted loss terms
          (0 where a term is absent);
        - ``target_vec = [3-D dice, 3-D dice std-like second output of dice3dn, 2-D dice]``
          on the target;
        - ``source_vec`` is the 3-D dice pair on the source, or ``[0, 0]`` when
          ``args.source_metrics`` is not True.
    ``num_steps`` is unused.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)
    indices = torch.tensor(metric_axis, device=device)

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    # zip() stops at the shorter loader. Buffers are sized for the larger dataset,
    # so only the first ``done`` rows are ever written; results are sliced to
    # ``done`` below so unfilled zero rows do not bias the metrics.
    total_iteration = min(len(loader), len(target_loader))
    total_images = max(len(loader.dataset), len(target_loader.dataset))

    pho = 1
    # NOTE(review): eval() on a command-line value; args.dtype is presumably a string
    # like "torch.float32". Do not feed untrusted input here.
    dtype = eval(args.dtype)
    want_source = args.source_metrics == True  # noqa: E712 (kept: same test as before)

    def _zeros(*shape) -> Tensor:
        return torch.zeros(shape, dtype=dtype, device=device)

    def _subject_ids(filenames) -> Tensor:
        # Group id = second "_"-separated token of each file name (presumably the
        # patient id), one per image and broadcast over the C class columns. The old
        # code used filenames[0] for the whole batch, which mislabels batches that
        # span several patients.
        ids = [int(re.split('_', name)[1]) for name in filenames]
        return torch.tensor(ids, dtype=dtype, device=device).unsqueeze(1).expand(-1, C)

    all_dices: Tensor = _zeros(total_images, C)
    all_inter_card: Tensor = _zeros(total_images, C)
    all_card_gt: Tensor = _zeros(total_images, C)
    all_card_pred: Tensor = _zeros(total_images, C)
    all_grp: Tensor = _zeros(total_images, C)
    loss_log: Tensor = _zeros(total_images)
    loss_inf: Tensor = _zeros(total_images)
    loss_cons: Tensor = _zeros(total_images)
    loss_fs: Tensor = _zeros(total_images)
    if want_source:
        all_inter_card_s: Tensor = _zeros(total_images, C)
        all_card_gt_s: Tensor = _zeros(total_images, C)
        all_card_pred_s: Tensor = _zeros(total_images, C)
        all_grp_s: Tensor = _zeros(total_images, C)

    n_warmup = 0
    mult_lw = [pho ** (epc - n_warmup + 1)] * len(loss_weights)
    mult_lw[0] = 1
    loss_weights = [a * b for a, b in zip(loss_weights, mult_lw)]

    tq_iter = tqdm_(enumerate(zip(loader, target_loader)), total=total_iteration, desc=desc)
    done: int = 0
    nice_dict = {}
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for j, (source_data, target_data) in tq_iter:
            source_data[1:] = [e.to(device) for e in source_data[1:]]  # Move all tensors to device
            filenames_source, source_image, source_gt = source_data[:3]
            target_data[1:] = [e.to(device) for e in target_data[1:]]  # Move all tensors to device
            filenames_target, target_image, target_gt = target_data[:3]
            labels = target_data[3:3 + L]
            labels_source = source_data[3:3 + L]
            bounds = target_data[3 + L:]
            bounds_source = source_data[3 + L:]
            assert len(labels) == len(bounds), len(bounds)
            if args.mix == False:  # noqa: E712
                assert filenames_source == filenames_target
            B = len(target_image)

            if optimizer:
                optimizer.zero_grad()

            with torch.set_grad_enabled(mode == "train"):
                pred_logits: Tensor = net(target_image)
                pred_logits_source: Tensor = net(source_image)
                pred_probs: Tensor = F.softmax(pred_logits, dim=1)
                pred_probs_source: Tensor = F.softmax(pred_logits_source, dim=1)
                if new_w > 0:
                    pred_probs = resize(pred_probs, new_w)
                    labels = [resize(label, new_w) for label in labels]
                    # Was ``target = resize(target, new_w)``: ``target`` is undefined here.
                    target_gt = resize(target_gt, new_w)
                predicted_mask: Tensor = probs2one_hot(pred_probs)  # Used only for dice computation
                predicted_mask_source: Tensor = probs2one_hot(pred_probs_source)

                assert len(bounds) == len(loss_fns) == len(loss_weights)
                if epc < n_warmup:
                    loss_weights = [0] * len(loss_weights)

                loss: Tensor = torch.zeros(1, requires_grad=True).to(device)
                loss_k = []
                for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds):
                    if w > 0:
                        if args.lin_aug_w and epc < 70:
                            w = w * (epc + 1) / 70
                        loss_b = loss_fn(pred_probs, label, bound)
                        loss = loss + w * loss_b
                        loss_k.append(w * loss_b.detach())
                for loss_fn, label, w, bound in zip(loss_fns_source, labels_source,
                                                    loss_weights_source, bounds_source):
                    if w > 0:
                        loss_b = loss_fn(pred_probs_source, label, bound)
                        loss = loss + w * loss_b
                        loss_k.append(w * loss_b.detach())

            if optimizer:
                loss.backward()
                optimizer.step()

            # ---- metrics
            dices, inter_card, card_gt, card_pred = dice_coef(predicted_mask.detach(), target_gt.detach())
            assert dices.shape == (B, C), (dices.shape, B, C)

            sm_slice = slice(done, done + B)  # Values only for current batch
            all_dices[sm_slice, ...] = dices
            all_grp[sm_slice, ...] = _subject_ids(filenames_target)  # for 3D dice
            all_inter_card[sm_slice, ...] = inter_card
            all_card_gt[sm_slice, ...] = card_gt
            all_card_pred[sm_slice, ...] = card_pred

            if want_source:
                _, inter_card_s, card_gt_s, card_pred_s = dice_coef(predicted_mask_source.detach(),
                                                                    source_gt.detach())
                all_grp_s[sm_slice, ...] = _subject_ids(filenames_source)
                all_inter_card_s[sm_slice, ...] = inter_card_s
                all_card_gt_s[sm_slice, ...] = card_gt_s
                all_card_pred_s[sm_slice, ...] = card_pred_s

            # Previously commented out, which left the displayed "loss" stuck at 0.
            loss_log[sm_slice] = loss.detach()
            loss_inf[sm_slice] = loss_k[0] if loss_k else 0
            loss_cons[sm_slice] = loss_k[1] if len(loss_k) > 1 else 0
            loss_fs[sm_slice] = loss_k[2] if len(loss_k) > 2 else 0

            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    warnings.simplefilter("ignore")
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames_target, savedir, mode, epc, True)

            # ---- logging (running values over current and previous batches only)
            big_slice = slice(0, done + B)
            stat_dict = {"dice": torch.index_select(all_dices[big_slice], 1, indices).mean(),
                         "loss": loss_log[big_slice].mean()}
            nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}
            done += B
            tq_iter.set_postfix(nice_dict)

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    filled = slice(0, done)
    dice_3d_log, dice_3d_sd_log = dice3dn(all_grp[filled], all_inter_card[filled],
                                          all_card_gt[filled], all_card_pred[filled],
                                          metric_axis, True)
    if want_source:
        dice_3d_s_log, dice_3d_s_sd_log = dice3dn(all_grp_s[filled], all_inter_card_s[filled],
                                                  all_card_gt_s[filled], all_card_pred_s[filled],
                                                  metric_axis, True)
    print("mean 3d_dice over all patients:", dice_3d_log)

    dice_2d = torch.index_select(all_dices[filled], 1, indices).mean().cpu().numpy()
    target_vec = [dice_3d_log, dice_3d_sd_log, dice_2d]
    source_vec = [dice_3d_s_log, dice_3d_s_sd_log] if want_source else [0, 0]
    losses_vec = [loss_inf[filled].mean().item(),
                  loss_cons[filled].mean().item(),
                  loss_fs[filled].mean().item()]
    return losses_vec, target_vec, source_vec
def do_epoch(mode: str, net: Any, device: Any, loader: DataLoader, epc: int,
             loss_fns: List[Callable], loss_weights: List[float], C: int,
             savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1],
             compute_haussdorf: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Run one training or validation epoch over a single loader.

    Each batch is ``[filenames, image, target, *labels, *bounds]`` with one
    label and one bound per loss function; every entry after ``filenames`` is
    moved to ``device``. The total loss is ``sum(w * loss_fn(probs, label,
    bound))``; when ``optimizer`` is given it is back-propagated and a step is
    taken.

    Args:
        mode: ``"train"`` (calls ``net.train()``) or ``"val"`` (``net.eval()``).
        metric_axis: class indices reported in the progress bar (never mutated).
        compute_haussdorf: also fill ``haussdorf_log``.

    Returns:
        ``(loss_log, all_dices, batch_dices, haussdorf_log)``: total loss per
        iteration ``(n_iter,)``, dice per image ``(n_images, C)``, batch dice per
        iteration ``(n_iter, C)`` (only written in val mode for batches with more
        than one image) and Hausdorff per image ``(n_images, C)`` (only written
        when ``compute_haussdorf``).
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration), dtype=torch.float32, device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    done: int = 0
    nice_dict = {}  # Defined up front so the final print works on an empty loader.
    for j, data in tq_iter:
        data[1:] = [e.to(device) for e in data[1:]]  # Move all tensors to device
        filenames, image, target = data[:3]
        labels = data[3:3 + L]
        bounds = data[3 + L:]
        assert len(labels) == len(bounds)

        B = len(image)

        # Reset gradients
        if optimizer:
            optimizer.zero_grad()

        # Forward
        pred_logits: Tensor = net(image)
        pred_probs: Tensor = F.softmax(pred_logits, dim=1)
        predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation

        assert len(bounds) == len(loss_fns) == len(loss_weights)
        losses = [w * loss_fn(pred_probs, label, bound)
                  for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds)]
        loss = reduce(add, losses)
        assert loss.shape == (), loss.shape

        # Backward
        if optimizer:
            loss.backward()
            optimizer.step()

        # Compute and log metrics
        loss_log[j] = loss.detach()

        sm_slice = slice(done, done + B)  # Values only for current batch

        dices: Tensor = dice_coef(predicted_mask, target.detach())
        assert dices.shape == (B, C), (dices.shape, B, C)
        all_dices[sm_slice, ...] = dices

        if B > 1 and mode == "val":
            batch_dice: Tensor = dice_batch(predicted_mask, target.detach())
            assert batch_dice.shape == (C, ), (batch_dice.shape, B, C)
            batch_dices[j] = batch_dice

        if compute_haussdorf:
            haussdorf_res: Tensor = haussdorf(predicted_mask.detach(), target.detach())
            assert haussdorf_res.shape == (B, C)
            haussdorf_log[sm_slice] = haussdorf_res

        # Save images
        if savedir:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                predicted_class: Tensor = probs2class(pred_probs)
                save_images(predicted_class, filenames, savedir, mode, epc)

        # Logging: running means over every batch seen so far, current one included
        # (the previous ``[:j]`` slices were empty -> NaN on the first batch).
        big_slice = slice(0, done + B)
        seen_batches = slice(0, j + 1)
        dsc_dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis}
        hauss_dict = ({f"HD{n}": haussdorf_log[big_slice, n].mean() for n in metric_axis}
                      if compute_haussdorf else {})
        batch_dict = ({f"bDSC{n}": batch_dices[seen_batches, n].mean() for n in metric_axis}
                      if B > 1 and mode == "val" else {})
        mean_dict = {}
        if len(metric_axis) > 1:
            mean_dict["DSC"] = all_dices[big_slice, metric_axis].mean()
            # haussdorf_log is all zeros unless it was actually computed.
            if compute_haussdorf:
                mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()

        stat_dict = {**dsc_dict, **hauss_dict, **mean_dict, **batch_dict,
                     "loss": loss_log[seen_batches].mean()}
        nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

        tq_iter.set_postfix(nice_dict)
        done += B
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return loss_log, all_dices, batch_dices, haussdorf_log
def do_epoch(mode: str, net: Any, device: Any, loaders: List[DataLoader], epc: int,
             list_loss_fns: List[List[Callable]], list_loss_weights: List[List[float]],
             C: int, savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1], compute_haussdorf: bool = False,
             compute_miou: bool = False,
             temperature: float = 1) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
    """Run one training or validation epoch over several loaders in sequence.

    Loader ``i`` is paired with ``list_loss_fns[i]`` and ``list_loss_weights[i]``.
    Each batch is ``[filenames, image, target, *labels, *bounds]`` with one label
    and one bound per loss of that loader. Predictions are
    ``softmax(temperature * logits)``.

    Returns:
        ``(loss_log, all_dices, batch_dices, haussdorf_log, mIoUs)``:
        - weighted loss per iteration and per loss ``(n_iter, max_n_loss)``;
        - dice per image ``(n_images, C)``;
        - batch dice per iteration ``(n_iter, C)``, written only in val mode for
          batches with more than one image;
        - Hausdorff per image ``(n_images, C)``, written only if ``compute_haussdorf``;
        - the dataset-level IoU per class ``(C,)`` if ``compute_miou``, else ``None``.
    """
    assert mode in ["train", "val"]
    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    iiou_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    intersections: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    unions: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)

    # Per-class entries are only displayed when there are few metric classes.
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    nice_dict = {}  # Defined up front so the final print works with empty loaders.
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        L: int = len(loss_fns)
        for data in loader:
            data[1:] = [e.to(device) for e in data[1:]]  # Move all tensors to device
            filenames, image, target = data[:3]
            assert not target.requires_grad
            labels = data[3:3 + L]
            bounds = data[3 + L:]
            assert len(labels) == len(bounds)

            B = len(image)

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
            losses = [w * loss_fn(pred_probs, label, bound)
                      for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds)]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j, partial_loss in enumerate(losses):
                loss_log[done_batch, j] = partial_loss.detach()

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, C), (dices.shape, B, C)
            all_dices[sm_slice, ...] = dices

            if B > 1 and mode == "val":
                batch_dice: Tensor = dice_batch(predicted_mask, target)
                assert batch_dice.shape == (C,), (batch_dice.shape, B, C)
                batch_dices[done_batch] = batch_dice

            if compute_haussdorf:
                haussdorf_res: Tensor = haussdorf(predicted_mask, target)
                assert haussdorf_res.shape == (B, C)
                haussdorf_log[sm_slice] = haussdorf_res

            if compute_miou:
                IoUs: Tensor = iIoU(predicted_mask, target)
                assert IoUs.shape == (B, C), IoUs.shape
                iiou_log[sm_slice] = IoUs
                intersections[sm_slice] = inter_sum(predicted_mask, target)
                unions[sm_slice] = union_sum(predicted_mask, target)

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging: running means over everything seen so far, current batch included
            # (``[:done_batch]`` was empty on the first batch -> NaN).
            big_slice = slice(0, done_img + B)
            seen_batches = slice(0, done_batch + 1)
            dsc_dict = ({f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis}
                        if few_axis else {})
            hauss_dict = ({f"HD{n}": haussdorf_log[big_slice, n].mean() for n in metric_axis}
                          if compute_haussdorf and few_axis else {})
            batch_dict = ({f"bDSC{n}": batch_dices[seen_batches, n].mean() for n in metric_axis}
                          if B > 1 and mode == "val" and few_axis else {})
            miou_dict = ({"iIoU": iiou_log[big_slice, metric_axis].mean(),
                          "mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()}
                         if compute_miou else {})
            mean_dict = {}
            if len(metric_axis) > 1:
                mean_dict["DSC"] = all_dices[big_slice, metric_axis].mean()
                if compute_haussdorf:
                    mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()

            stat_dict = {**miou_dict, **dsc_dict, **hauss_dict, **mean_dict, **batch_dict,
                         "loss": loss_log[seen_batches].mean()}
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    mIoUs: Optional[Tensor]
    if compute_miou:
        mIoUs = intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)
        assert mIoUs.shape == (C,), mIoUs.shape
    else:
        mIoUs = None

    if not few_axis and False:  # Disabled per-class summary; remove "and False" to enable.
        print(f"DSC: {[f'{all_dices[:, n].mean():.3f}' for n in metric_axis]}")
        print(f"iIoU: {[f'{iiou_log[:, n].mean():.3f}' for n in metric_axis]}")
        # ``is not None``: truth-testing a multi-element tensor raises RuntimeError.
        if mIoUs is not None:
            print(f"mIoU: {[f'{mIoUs[n]:.3f}' for n in metric_axis]}")

    return loss_log, all_dices, batch_dices, haussdorf_log, mIoUs
def do_epoch(
    mode: str,
    net: Any,
    device: Any,
    loaders: List[DataLoader],
    epc: int,
    list_loss_fns: List[List[Callable]],
    list_loss_weights: List[List[float]],
    K: int,
    savedir: str = "",
    optimizer: Any = None,
    metric_axis: List[int] = [1],
    compute_hausdorff: bool = False,
    compute_miou: bool = False,
    compute_3d_dice: bool = False,
    temperature: float = 1,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Run one epoch over several dict-style loaders, with optional sub-sampling.

    Each batch is a dict with keys ``images``, ``gt``, ``spacings`` (kept on
    CPU), ``labels``, ``bounds``, ``box_priors``, ``samplings`` and
    ``filenames``. For every sampling window of the first batch item, the
    network is run on the cropped image, an optimizer step is taken (if
    ``optimizer``), and the predictions are written into full-size receptacles.
    Metrics are then computed on the reassembled mask.

    Returns:
        ``(loss_log, all_dices, hausdorff_log, mIoUs, three_d_dices)``, all on
        CPU and detached:
        - weighted loss per iteration and per loss ``(n_iter, max_n_loss)``,
          summed over sampling windows;
        - dice per image ``(n_images, K)``;
        - Hausdorff per image, or ``None`` when not requested;
        - dataset-level IoU per class ``(K,)``, or ``None`` when not requested;
        - batch ("3D") dice per iteration, or ``None`` when not requested.
    """
    assert mode in ["train", "val"]
    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)

    iiou_log: Optional[Tensor] = None
    intersections: Optional[Tensor] = None
    unions: Optional[Tensor] = None
    if compute_miou:
        iiou_log = torch.zeros((total_images, K), dtype=torch.float32, device=device)
        intersections = torch.zeros((total_images, K), dtype=torch.float32, device=device)
        unions = torch.zeros((total_images, K), dtype=torch.float32, device=device)

    three_d_dices: Optional[Tensor] = None
    if compute_3d_dice:
        three_d_dices = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)

    hausdorff_log: Optional[Tensor] = None
    if compute_hausdorff:
        hausdorff_log = torch.zeros((total_images, K), dtype=torch.float32, device=device)

    # Per-class entries are only displayed when there are few metric classes.
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    nice_dict = {}  # Defined up front so the final print works with empty loaders.
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            spacings: Tensor = data["spacings"]  # Keep that one on CPU
            assert not target.requires_grad
            labels: List[Tensor] = [e.to(device) for e in data["labels"]]
            bounds: List[Tensor] = [e.to(device) for e in data["bounds"]]
            # One more nesting level than labels/bounds: one list of (m, b) pairs per batch item.
            box_priors: List[List[Tuple[Tensor, Tensor]]] = [
                [(m.to(device), b.to(device)) for (m, b) in item_priors]
                for item_priors in data["box_priors"]
            ]
            assert len(labels) == len(bounds)

            B, C, *_ = image.shape

            samplings: List[List[Tuple[slice]]] = data["samplings"]
            assert len(samplings) == B
            assert len(samplings[0][0]) == len(image[0, 0].shape), (samplings[0][0], image[0, 0].shape)

            probs_receptacle: Tensor = -torch.ones_like(target, dtype=torch.float32)  # -1 for unfilled
            mask_receptacle: Tensor = -torch.ones_like(target, dtype=torch.int32)  # -1 for unfilled

            # Use the sampling coordinates of the first batch item
            assert not (len(samplings[0]) > 1 and B > 1), samplings  # No subsampling if batch size > 1

            loss_sub_log: Tensor = torch.zeros((len(samplings[0]), len(loss_fns)),
                                               dtype=torch.float32, device=device)
            for k, sampling in enumerate(samplings[0]):
                # Tuples, not lists: indexing with a list of slices is deprecated legacy behaviour.
                img_sampling = (slice(0, B), slice(0, C), *sampling)
                label_sampling = (slice(0, B), slice(0, K), *sampling)
                assert len(img_sampling) == len(image.shape), (img_sampling, image.shape)
                sub_img = image[img_sampling]

                # Reset gradients
                if optimizer:
                    optimizer.zero_grad()

                # Forward
                pred_logits: Tensor = net(sub_img)
                pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
                predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
                assert not predicted_mask.requires_grad

                # Detached so the receptacle does not keep this window's autograd graph alive.
                probs_receptacle[label_sampling] = pred_probs.detach()
                mask_receptacle[label_sampling] = predicted_mask

                assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
                losses = [w * loss_fn(pred_probs, label[label_sampling], bound, box_priors)
                          for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds)]
                loss = reduce(add, losses)
                assert loss.shape == (), loss.shape

                # Backward
                if optimizer:
                    loss.backward()
                    optimizer.step()

                # Compute and log metrics
                for j, partial_loss in enumerate(losses):
                    loss_sub_log[k, j] = partial_loss.detach()

            reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
            assert reduced_loss_sublog.shape == (len(loss_fns),), (reduced_loss_sublog.shape, len(loss_fns))
            # A loader may have fewer losses than n_loss; its extra columns stay zero.
            loss_log[done_batch, :len(loss_fns)] = reduced_loss_sublog
            del loss_sub_log

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(mask_receptacle, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if three_d_dices is not None:
                three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                assert three_d_DSC.shape == (K,)
                three_d_dices[done_batch] = three_d_DSC

            if hausdorff_log is not None:
                hausdorff_res: Tensor
                try:
                    hausdorff_res = hausdorff(mask_receptacle, target, spacings)
                except RuntimeError:
                    # Best effort: a failed computation is logged as zeros instead of aborting.
                    hausdorff_res = torch.zeros((B, K), device=device)
                assert hausdorff_res.shape == (B, K)
                hausdorff_log[sm_slice] = hausdorff_res

            if iiou_log is not None and intersections is not None and unions is not None:
                IoUs: Tensor = iIoU(mask_receptacle, target)
                assert IoUs.shape == (B, K), IoUs.shape
                iiou_log[sm_slice] = IoUs
                intersections[sm_slice] = inter_sum(mask_receptacle, target)
                unions[sm_slice] = union_sum(mask_receptacle, target)

            # Save images
            # NOTE(review): this saves the prediction of the last sampling window only;
            # probs_receptacle may have been intended here — confirm.
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, data["filenames"], savedir, mode, epc)

            # Logging: running means over everything seen so far, current batch included.
            big_slice = slice(0, done_img + B)
            seen_batches = slice(0, done_batch + 1)
            dsc_dict = {}
            if few_axis:
                dsc_dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis}
                if three_d_dices is not None:
                    dsc_dict.update({f"3d_DSC{n}": three_d_dices[seen_batches, n].mean()
                                     for n in metric_axis})
            hauss_dict = ({f"HD{n}": hausdorff_log[big_slice, n].mean() for n in metric_axis}
                          if hausdorff_log is not None and few_axis else {})
            miou_dict = ({"iIoU": iiou_log[big_slice, metric_axis].mean(),
                          "mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()}
                         if iiou_log is not None and intersections is not None and unions is not None
                         else {})
            mean_dict = {}
            if len(metric_axis) > 1:
                mean_dict["DSC"] = all_dices[big_slice, metric_axis].mean()
                # ``is not None``: truth-testing a multi-element tensor raises RuntimeError.
                if hausdorff_log is not None:
                    mean_dict["HD"] = hausdorff_log[big_slice, metric_axis].mean()

            stat_dict = {**miou_dict, **dsc_dict, **hauss_dict, **mean_dict,
                         "loss": loss_log[seen_batches].mean()}
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    mIoUs: Optional[Tensor] = None
    if intersections is not None and unions is not None:
        mIoUs = intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)
        assert mIoUs.shape == (K, ), mIoUs.shape

    if not few_axis and False:  # Disabled per-class summary; remove "and False" to enable.
        print(f"DSC: {[f'{all_dices[:, n].mean():.3f}' for n in metric_axis]}")
        print(f"iIoU: {[f'{iiou_log[:, n].mean():.3f}' for n in metric_axis]}")
        if mIoUs is not None:
            print(f"mIoU: {[f'{mIoUs[n]:.3f}' for n in metric_axis]}")

    return (
        loss_log.detach().cpu(),
        all_dices.detach().cpu(),
        hausdorff_log.detach().cpu() if hausdorff_log is not None else None,
        mIoUs.detach().cpu() if mIoUs is not None else None,
        three_d_dices.detach().cpu() if three_d_dices is not None else None,
    )
def do_epoch(mode: str, args, net, device, use_cuda, loader, optimizer, num_classes, epoch):
    """Run one epoch on a multi-image MRI loader with a cross-entropy + KL loss.

    Each batch unpacks as ``(image_f, image_i, image_d, image_o, image_w,
    image_c, labels, img_names)``. The six images are concatenated along
    ``dim=1`` and divided by 65535. The target is ``cat((1 - labels, labels))``.
    The loss is ``crossEntropy_f(probs, targets) + args.lam * kl(target_aver,
    pred_probs_aver)``, where both averages are spatial sums divided by the total
    target mass of the batch. The optimizer is only stepped in train mode.

    Returns:
        ``(loss_log, entropy_log, KL_log, all_dices, batch_dices)``. The first
        three are per image ``(n_images,)``: every image of a batch receives that
        batch's scalar. Dices are ``(n_images, num_classes)`` and batch dices are
        ``(n_iter, num_classes)``.
    """
    assert mode in ["train", "val"]
    if mode == "train":
        net.train()
        desc = f">> Training ({epoch})"
    else:
        net.eval()
        desc = f">> Validation ({epoch})"

    # NOTE(review): eval() of a CLI string such as "torch.float32"; only safe for trusted args.
    dtype = eval(args.dtype)
    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, num_classes), dtype=dtype, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, num_classes), dtype=dtype, device=device)
    loss_log: Tensor = torch.zeros((total_images), dtype=dtype, device=device)
    entropy_log: Tensor = torch.zeros((total_images), dtype=dtype, device=device)
    KL_log: Tensor = torch.zeros((total_images), dtype=dtype, device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    done: int = 0
    for j, data in tq_iter:
        image_f, image_i, image_d, image_o, image_w, image_c, labels, img_names = data

        MRI = torch.cat((image_f, image_i, image_d, image_o, image_w, image_c), dim=1)
        # Division by 65535 presumably maps 16-bit intensities to [0, 1].
        MRI = MRI.type(torch.FloatTensor) / 65535.0
        targets = torch.cat((1 - labels, labels), dim=1)
        B = len(image_f)
        if use_cuda:
            MRI, targets = MRI.to(device), targets.to(device)

        # Only record the autograd graph when it will be used for an update.
        with torch.set_grad_enabled(mode == "train"):
            outputs = net(MRI)
            pred_probs = F.softmax(outputs, dim=1)
            predicted_mask = probs2one_hot(pred_probs)
            entropy = crossEntropy_f(pred_probs, targets)

            target_mass = torch.sum(targets).float()
            pred_probs_aver: Tensor = torch.sum(pred_probs, dim=(2, 3)) / target_mass
            target_aver: Tensor = torch.sum(targets, dim=(2, 3)).float() / target_mass
            KL_loss = args.lam * kl(target_aver, pred_probs_aver)
            loss = entropy + KL_loss

        if mode == "train":
            # zero the parameter gradients
            optimizer.zero_grad()
            # backward + optimize
            loss.backward()
            optimizer.step()

        # Integer targets on the prediction's device. The original cast to
        # torch.cuda.IntTensor, which failed whenever the model ran on CPU.
        int_targets = targets.detach().to(device=predicted_mask.device, dtype=torch.int32)
        dices: Tensor = dice_coef(predicted_mask.detach(), int_targets)
        batch_dice: Tensor = dice_batch(predicted_mask.detach(), int_targets)
        assert batch_dice.shape == (num_classes, ) and dices.shape == (B, num_classes), \
            (batch_dice.shape, dices.shape, B, num_classes)

        sm_slice = slice(done, done + B)  # Values only for current batch
        all_dices[sm_slice, ...] = dices
        entropy_log[sm_slice] = entropy.detach()
        loss_log[sm_slice] = loss.detach()
        KL_log[sm_slice] = KL_loss.detach()
        batch_dices[j] = batch_dice

        # Logging
        big_slice = slice(0, done + B)  # Value for current and previous batches
        stat_dict = {
            "dice": all_dices[big_slice, -1].mean(),
            "total loss": loss_log[big_slice].mean(),
            "entropy loss": entropy_log[big_slice].mean(),
            "KL loss": KL_log[big_slice].mean(),
            "b dice": batch_dices[:j + 1, -1].mean()
        }
        nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}

        done += B
        tq_iter.set_postfix(nice_dict)

    return loss_log, entropy_log, KL_log, all_dices, batch_dices
def do_epoch(args, mode: str, net: Any, device: Any, epc: int, loss_fns: List[Callable],
             loss_weights: List[float], new_w: int, C: int, metric_axis: List[int],
             savedir: str = "", optimizer: Any = None, target_loader: Any = None,
             best_dice3d_val: Any = None):
    """Run one target-only (source-free adaptation) epoch over ``target_loader``.

    Each batch is ``[filenames, image, gt, *labels, *bounds]`` with one label and
    one bound per loss. Losses with a positive weight are summed and
    back-propagated. During warm-up (``epc < args.n_warmup``) every weight is 0,
    so no loss is active. When the first configured target loss is
    ``"EntKLProp"``, each loss returns ``(loss_se, loss_cons_prior, est_prop)``;
    these are added without ``w``, and ``est_prop`` provides the size estimates
    logged in ``all_sizes``.

    Returns:
        ``(losses_vec, target_vec, source_vec)``:
        - ``losses_vec``: means of the first logged loss, the second logged loss
          and their total, followed by four size statistics over ``metric_axis``;
        - ``target_vec``: 3D dice/ASD/HD means and standard deviations (zeros
          unless ``args.dice_3d`` and ``mode == "val"``), followed by the mean 2D
          dice over ``metric_axis``;
        - ``source_vec``: always an empty list here.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)
    indices = torch.tensor(metric_axis, device=device)
    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration = len(target_loader)
    total_images = len(target_loader.dataset)
    if args.debug:
        # Only shrinks the progress-bar total; the loader is still fully iterated.
        total_iteration = 10
    pho = 1
    # NOTE(review): eval() of CLI strings; acceptable only for trusted arguments.
    dtype = eval(args.dtype)
    # Hoisted: this string used to be re-evaluated for every loss of every batch.
    is_entklprop: bool = eval(args.target_losses)[0][0] == "EntKLProp"

    all_dices: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_sizes: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_gt_sizes: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_sizes2: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_inter_card: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_card_gt: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_card_pred: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_grp: Tensor = torch.zeros((total_images, C), dtype=dtype, device=device)
    all_gt: Any = []
    all_pred: Any = []
    if args.do_hd:
        # Class maps kept on CPU for the 3D distance metrics.
        all_gt = torch.zeros((total_images, args.wh, args.wh), dtype=dtype)
        all_pred = torch.zeros((total_images, args.wh, args.wh), dtype=dtype)
    loss_cons: Tensor = torch.zeros((total_images), dtype=dtype, device=device)
    loss_se: Tensor = torch.zeros((total_images), dtype=dtype, device=device)
    loss_tot: Tensor = torch.zeros((total_images), dtype=dtype, device=device)
    all_pnames = np.zeros([total_images]).astype('U256')

    dice_3d_log, dice_3d_sd_log = 0, 0
    hd_3d_log, asd_3d_log, hd_3d_sd_log, asd_3d_sd_log = 0, 0, 0, 0

    n_warmup = args.n_warmup
    mult_lw = [pho ** (epc - n_warmup + 1)] * len(loss_weights)
    mult_lw[0] = 1
    loss_weights = [a * b for a, b in zip(loss_weights, mult_lw)]
    source_vec: list = []

    tq_iter = tqdm_(enumerate(target_loader), total=total_iteration, desc=desc)
    done: int = 0
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for j, target_data in tq_iter:
            target_data[1:] = [e.to(device) for e in target_data[1:]]  # Move all tensors to device
            filenames_target, target_image, target_gt = target_data[:3]
            labels = target_data[3:3 + L]
            bounds = target_data[3 + L:]
            filenames_target = [f.split('.nii')[0] for f in filenames_target]
            assert len(labels) == len(bounds), len(bounds)
            B = len(target_image)

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            with torch.set_grad_enabled(mode == "train"):
                pred_logits: Tensor = net(target_image)
                pred_probs: Tensor = F.softmax(pred_logits, dim=1)
                if new_w > 0:
                    pred_probs = resize(pred_probs, new_w)
                    labels = [resize(label, new_w) for label in labels]
                    # Was ``target = resize(target, new_w)``: an undefined name (NameError).
                    target_gt = resize(target_gt, new_w)
                predicted_mask: Tensor = probs2one_hot(pred_probs)  # Used only for dice computation

                assert len(bounds) == len(loss_fns) == len(loss_weights)
                if epc < n_warmup:
                    loss_weights = [0] * len(loss_weights)

                # Every active loss is accumulated (previously only the last one
                # was back-propagated because ``loss`` was overwritten).
                loss: Tensor = torch.zeros(1, requires_grad=True).to(device)
                loss_kw: List[Tensor] = []
                est_prop: Optional[Tensor] = None
                for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds):
                    if w <= 0:
                        continue
                    if is_entklprop:
                        loss_1, loss_cons_prior, est_prop = loss_fn(pred_probs, label, bound)
                        term = loss_1 + loss_cons_prior
                    else:
                        term = w * loss_fn(pred_probs, label, bound)
                        loss_1 = term
                    loss = loss + term
                    loss_kw.append(loss_1.detach())

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            dices, inter_card, card_gt, card_pred = dice_coef(predicted_mask.detach(), target_gt.detach())
            assert dices.shape == (B, C), (dices.shape, B, C)

            sm_slice = slice(done, done + B)  # Values only for current batch
            all_dices[sm_slice, ...] = dices
            # est_prop is unset when no loss was active (warm-up).
            if is_entklprop and est_prop is not None:
                # Proportion times image area -> estimated size in pixels.
                all_sizes[sm_slice, ...] = torch.round(
                    est_prop.detach() * target_image.shape[2] * target_image.shape[3])
            all_sizes2[sm_slice, ...] = torch.sum(predicted_mask, dim=(2, 3))
            all_gt_sizes[sm_slice, ...] = torch.sum(target_gt, dim=(2, 3))
            all_grp[sm_slice, ...] = torch.FloatTensor(get_subj_nb(filenames_target)).unsqueeze(1).repeat(1, C)
            all_pnames[sm_slice] = filenames_target
            all_inter_card[sm_slice, ...] = inter_card
            all_card_gt[sm_slice, ...] = card_gt
            all_card_pred[sm_slice, ...] = card_pred
            if args.do_hd:
                all_pred[sm_slice, ...] = probs2class(predicted_mask).cpu().detach()
                all_gt[sm_slice, ...] = probs2class(target_gt).detach()

            # With no active loss (warm-up) the loss rows keep their initial zeros.
            if loss_kw:
                loss_se[sm_slice] = loss_kw[0]
                if len(loss_kw) > 1:
                    loss_cons[sm_slice] = loss_kw[1]
                    loss_tot[sm_slice] = loss_kw[1] + loss_kw[0]
                else:
                    loss_cons[sm_slice] = 0
                    loss_tot[sm_slice] = loss_kw[0]

            # Save images
            if savedir and args.saveim and mode == "val":
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    warnings.simplefilter("ignore")
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames_target, savedir, mode, epc, False)
                    if args.entmap:
                        ent_map = torch.einsum("bcwh,bcwh->bwh", [-pred_probs, (pred_probs + 1e-10).log()])
                        save_images_ent(ent_map, filenames_target, savedir, 'ent_map', epc)

            # Logging (running means over all images seen so far).
            # The former "DSC_source" entry read an undefined ``all_dices_s`` (NameError).
            big_slice = slice(0, done + B)
            stat_dict = {**{f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis},
                         **{f"SZ{n}": all_sizes[big_slice, n].mean() for n in metric_axis}}
            nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}
            done += B
            tq_iter.set_postfix(nice_dict)

        if args.dice_3d and mode == 'val':
            (dice_3d_log, dice_3d_sd_log, asd_3d_log, asd_3d_sd_log, hd_3d_log, hd_3d_sd_log) = dice3d(
                all_grp, all_inter_card, all_card_gt, all_card_pred, all_pred, all_gt, all_pnames,
                metric_axis, args.pprint, args.do_hd, args.do_asd, best_dice3d_val)
        dice_2d = torch.index_select(all_dices, 1, indices).mean().cpu().numpy().item()
        target_vec = [dice_3d_log, dice_3d_sd_log, asd_3d_log, asd_3d_sd_log, hd_3d_log, hd_3d_sd_log, dice_2d]

        pred_sizes = torch.index_select(all_sizes2, 1, indices)
        gt_sizes = torch.index_select(all_gt_sizes, 1, indices)
        size_mean = pred_sizes.mean(dim=0).cpu().numpy()
        size_gt_mean = gt_sizes.mean(dim=0).cpu().numpy()
        # Means restricted to images where the class is present (division by 0 gives nan/inf).
        mask_pos = pred_sizes != 0
        gt_pos = gt_sizes != 0
        size_mean_pos = pred_sizes.sum(dim=0).cpu().numpy() / mask_pos.sum(dim=0).cpu().numpy()
        gt_size_mean_pos = gt_sizes.sum(dim=0).cpu().numpy() / gt_pos.sum(dim=0).cpu().numpy()
        losses_vec = [loss_se.mean().item(), loss_cons.mean().item(), loss_tot.mean().item(),
                      size_mean.mean(), size_mean_pos.mean(), size_gt_mean.mean(), gt_size_mean_pos.mean()]

        if not epc % 50:
            df_t = pd.DataFrame({
                "val_ids": all_pnames,
                "proposal_size": all_sizes2.cpu()})
            df_t.to_csv(Path(savedir, mode + str(epc) + "sizes.csv"), float_format="%.4f", index_label="epoch")

    return losses_vec, target_vec, source_vec
def for_back_step_comb(optimizer, mode, source_image, target_image, gt_source, labels, net,
                       loss_fns, loss_weights, loss_fns_source, loss_weights_source, new_w,
                       device, bounds, model_D, optimizer_D, lambda_adv_target):
    """Perform one forward/backward step of adversarial source + target training.

    The segmentation network ``net`` is updated on
    ``loss_source + loss_adv_target + loss_target``. When the adversarial term is
    active, the discriminator ``model_D`` is then updated on detached source
    predictions (label ``0``) and detached target predictions (label ``1``).
    Parameters are only updated when ``optimizer`` is truthy.

    Returns:
        ``(probs_source, probs_target, loss, *loss_vec[:5])``. ``loss_vec`` is
        laid out as:
        - one entry per source loss;
        - the adversarial target loss;
        - one entry per target loss;
        - the discriminator loss.
        The five returned entries therefore assume exactly one source loss and
        one target loss. The returned probabilities are detached when the
        discriminator was trained.
    """
    source_label = 0
    target_label = 1

    # Reset gradients
    if optimizer:
        optimizer.zero_grad()
        optimizer_D.zero_grad()

    # Freeze D while the segmenter is trained so no gradient accumulates in it.
    for param in model_D.parameters():
        param.requires_grad = False

    with torch.set_grad_enabled(mode == "train"):
        # Forward
        pred_logits_source: Tensor = net(source_image)
        probs_source: Tensor = F.softmax(pred_logits_source, dim=1)
        pred_logits_target: Tensor = net(target_image)
        probs_target: Tensor = F.softmax(pred_logits_target, dim=1)
        predicted_mask_target: Tensor = probs2one_hot(probs_target)
        if new_w > 0:
            probs_source = resize(probs_source, new_w)
            probs_target = resize(probs_target, new_w)
            if labels[0].shape[3] != new_w:
                labels = [resize(label, new_w) for label in labels]
                gt_source = resize(gt_source, new_w)
        assert len(bounds) == len(loss_fns) == len(loss_weights)

        # The adversarial terms are only used when the first target image has a
        # non-empty predicted class-1 mask. Computed once: predicted_mask_target
        # does not change during this step.
        # NOTE(review): the original evaluated max(x, x) with identical operands;
        # the comment "not a pair of negative images" hints that the source gt was
        # meant as the second operand — confirm before changing the condition.
        use_adversarial: bool = (lambda_adv_target > 0
                                 and predicted_mask_target[0, 1, ...].sum().item() > 0)

        loss_adv_target = torch.zeros(1, requires_grad=True).to(device)
        loss_vec = []

        # Losses on source (a random dummy tensor stands in for the bound)
        source_losses = [w * loss_fn(probs_source, label, torch.randn(1))
                         for loss_fn, label, w in zip(loss_fns_source, [gt_source], loss_weights_source)]
        loss_vec.extend(term.item() for term in source_losses)
        loss_source = reduce(add, source_losses)

        # Adversarial loss: target predictions are scored by D against the source label
        if use_adversarial:
            D_out = model_D(probs_target)
            loss_adv_target = d_loss_calc(D_out, source_label).to(device=device) * lambda_adv_target
            loss_vec.append(loss_adv_target.item())
        else:
            loss_vec.append(0)

        # Constraint loss on target images (and eventually also cross entropy with FGT)
        target_losses = [w * loss_fn(probs_target, label, bound)
                         for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds)]
        loss_vec.extend(term.item() for term in target_losses)
        loss_target = reduce(add, target_losses)

        loss = loss_source + loss_adv_target + loss_target

        # Backward
        if optimizer:
            loss.backward()
            optimizer.step()

        # Train D on detached predictions
        if use_adversarial:
            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True
            # train with source
            probs_source = probs_source.detach()
            D_out = model_D(probs_source)
            loss_D_s = d_loss_calc(D_out, source_label).to(device=device) / 2
            if optimizer:
                loss_D_s.backward()
            # train with target
            probs_target = probs_target.detach()
            D_out_t = model_D(probs_target)
            loss_D_t = d_loss_calc(D_out_t, target_label).to(device=device) / 2
            if optimizer:
                loss_D_t.backward()
                optimizer_D.step()
            loss_vec.append(loss_D_s.item() + loss_D_t.item())
        else:
            loss_vec.append(0)

    return probs_source, probs_target, loss, loss_vec[0], loss_vec[1], loss_vec[2], loss_vec[3], loss_vec[4]
def do_epoch(mode: str, net: Any, device: Any, loaders: list[DataLoader], epc: int,
             list_loss_fns: list[list[Callable]], list_loss_weights: list[list[float]], K: int,
             savedir: str = "", optimizer: Any = None, metric_axis: list[int] = [1],
             compute_3d_dice: bool = False,
             temperature: float = 1) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Run one pass over ``loaders``, optionally training ``net``, and log losses and Dice scores.

    Each loader ``i`` is paired with ``list_loss_fns[i]`` / ``list_loss_weights[i]``. Each loss is
    called as ``loss_fn(pred_probs, label)`` on the matching entry of ``data["labels"]``. When
    ``optimizer`` is truthy, the weighted sum is back-propagated and the optimizer steps once per
    batch.

    Args:
        mode: "train" puts ``net`` in train mode; "val" and "dual" put it in eval mode.
        temperature: multiplies the logits before the softmax.
        savedir: when non-empty, predicted class maps are saved through ``save_images``.
        metric_axis: class indices whose Dice is shown in the progress bar.
        compute_3d_dice: also compute a per-batch ``dice_batch`` score.

    Returns:
        ``(loss_log, all_dices, three_d_dices)`` on CPU, with shapes
        ``(total_iteration, n_loss)``, ``(total_images, K)``, and ``(total_iteration, K)`` or ``None``.
    """
    assert mode in ["train", "val", "dual"]
    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"
    else:
        # "dual" is accepted by the assert above; without this branch `desc` was unbound
        # and creating the progress bar raised UnboundLocalError.
        net.eval()
        desc = f">> Dual ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)

    three_d_dices: Optional[Tensor]
    if compute_3d_dice:
        three_d_dices = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)
    else:
        three_d_dices = None

    done_img: int = 0
    done_batch: int = 0
    # Initialised so the final summary print works even if every loader is empty.
    nice_dict: dict[str, str] = {}
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            filenames: list[str] = data["filenames"]
            assert not target.requires_grad

            labels: list[Tensor] = [e.to(device) for e in data["labels"]]
            B, C, *_ = image.shape

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(loss_fns) == len(loss_weights) == len(labels)
            ziped = zip(loss_fns, labels, loss_weights)
            losses = [w * loss_fn(pred_probs, label) for loss_fn, label, w in ziped]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j, loss_j in enumerate(losses):
                loss_log[done_batch, j] = loss_j.detach()

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if compute_3d_dice:
                three_d_DSC: Tensor = dice_batch(predicted_mask, target)
                assert three_d_DSC.shape == (K, )

                three_d_dices[done_batch] = three_d_DSC  # type: ignore

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging. Running means include the current batch, which is row `done_batch`
            # (not yet incremented), matching the image-level `big_slice` used for Dice.
            # Using `[:done_batch]` excluded it and produced NaN on the first batch.
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            seen_batches = done_batch + 1

            dsc_dict: dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis} | \
                ({f"3d_DSC{n}": three_d_dices[:seen_batches, n].mean() for n in metric_axis}
                 if three_d_dices is not None else {})

            loss_dict = {f"loss_{k}": loss_log[:seen_batches].mean(dim=0)[k] for k in range(n_loss)}

            stat_dict = dsc_dict | loss_dict
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return (loss_log.detach().cpu(),
            all_dices.detach().cpu(),
            three_d_dices.detach().cpu() if three_d_dices is not None else None)
# Per-file inference: for every .npy image under <root>/in_npy, predict a segmentation with `net`
# and report Dice, Hausdorff and the connected-component-count error for classes 1 and 2.
fold_all_H2.write(f"file, dice, haussdorf,connecterror \n")
fold_clean_H2.write(f"file, dice, haussdorf,connecterror \n")
# NOTE(review): the CSV headers above go to fold_all_H2 / fold_clean_H2, but the per-file rows
# below are only printed. Confirm whether they should also be written to those files.
for _, _, files in os.walk(os.path.join(root, 'in_npy')):
    # NOTE(review): os.walk recurses, but paths below are always built from the top-level
    # in_npy / gt_npy directories, so files in sub-directories would not be found.
    for file in sorted(files):  # sorted: deterministic output order
        image = np.load(os.path.join(root, 'in_npy', file))
        gt = np.load(os.path.join(root, 'gt_npy', file))

        image = image.reshape(-1, 1, 256, 256)
        image = torch.tensor(image, dtype=torch.float)
        # Inference only: no autograd graph is needed. The deprecated
        # Variable(..., requires_grad=True) wrapper used to build one.
        with torch.no_grad():
            pred = net(image)
            pred = F.softmax(pred, dim=1)
        predicted_output = probs2one_hot(pred.detach())

        # One-hot ground truth, computed once and reused for every metric below.
        gt_onehot = class2one_hot(torch.tensor(gt), n_classes)
        dice = dice_coef(predicted_output, gt_onehot)
        hauss = haussdorf(predicted_output, gt_onehot)

        # Absolute difference in the number of unique labels returned by `label` for class 1 / 2.
        pred_label = len(np.unique(label(predicted_output[0][1])))
        gt_label = len(np.unique(label(gt_onehot[0][1])))
        error = np.abs(pred_label - gt_label)
        pred_label2 = len(np.unique(label(predicted_output[0][2])))
        gt_label2 = len(np.unique(label(gt_onehot[0][2])))
        error2 = np.abs(pred_label2 - gt_label2)

        # np.float was an alias of the builtin float and was removed in NumPy 1.24.
        print(f"{file}, {float(dice[0][1])}, {float(hauss[0][1])},{float(error)} \n")
        print(f"{file}, {float(dice[0][2])}, {float(hauss[0][2])},{float(error2)} \n")
def do_epoch(mode: str, net: Any, device: Any, loaders: list[DataLoader], epc: int,
             list_loss_fns: list[list[Callable]], list_loss_weights: list[list[float]], K: int,
             savedir: Optional[Path] = None, optimizer: Any = None, metric_axis: list[int] = [1],
             requested_metrics: Optional[list[str]] = None, temperature: float = 1) -> dict[str, Tensor]:
    """Run one pass over ``loaders`` with optional patch-wise sampling, training and metric logging.

    Each batch may carry several spatial samplings in ``data["samplings"]``, given as lists of
    slice tuples per batch item. Only the samplings of the first item are used. The network runs
    on each sub-image, and the loss is computed and (when ``optimizer`` is truthy) optimised per
    sub-image. The per-patch predictions are written into full-size receptacles; unfilled
    positions keep the value -1. Dice is computed on the stitched mask. Losses are called as
    ``loss_fn(pred_probs, label_patch, bound, filenames)``.

    Args:
        mode: "train" puts ``net`` in train mode; "val" and "dual" put it in eval mode.
        savedir: when given, stitched predicted class maps are saved under
            ``savedir / f"iter{epc:03d}" / mode``.
        metric_axis: class indices whose Dice is shown in the progress bar.
        requested_metrics: optional extra metrics; currently only "3d_dsc" is recognised.
        temperature: multiplies the logits before the softmax.

    Returns:
        Dict of CPU tensors: "dice" ``(total_images, K)``, "loss" ``(total_iteration, n_loss)``,
        and, if requested, "3d_dsc" ``(total_iteration, K)``.
    """
    assert mode in ["train", "val", "dual"]
    if requested_metrics is None:
        requested_metrics = []

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"
    elif mode == "dual":
        net.eval()
        desc = f">> Dual ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    epoch_metrics: dict[str, Tensor]
    epoch_metrics = {"dice": torch.zeros((total_images, K), dtype=torch.float32, device=device),
                     "loss": torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)}

    if "3d_dsc" in requested_metrics:
        epoch_metrics["3d_dsc"] = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)

    # Per-class Dice entries are shown in the progress bar only when there are few of them.
    few_axis: bool = len(metric_axis) <= 4

    done_img: int = 0
    done_batch: int = 0
    # Initialised so the final summary print works even if every loader is empty.
    nice_dict: dict[str, str] = {}
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            filenames: list[str] = data["filenames"]
            assert not target.requires_grad

            labels: list[Tensor] = [e.to(device) for e in data["labels"]]
            bounds: list[Tensor] = [e.to(device) for e in data["bounds"]]
            assert len(labels) == len(bounds)

            B, C, *_ = image.shape

            samplings: list[list[Tuple[slice]]] = data["samplings"]
            assert len(samplings) == B
            assert len(samplings[0][0]) == len(image[0, 0].shape), (samplings[0][0], image[0, 0].shape)

            probs_receptacle: Tensor = - torch.ones_like(target, dtype=torch.float32)  # -1 for unfilled
            mask_receptacle: Tensor = - torch.ones_like(target, dtype=torch.int32)  # -1 for unfilled

            # Use the sampling coordinates of the first batch item
            assert not (len(samplings[0]) > 1 and B > 1), samplings  # No subsampling if batch size > 1

            loss_sub_log: Tensor = torch.zeros((len(samplings[0]), len(loss_fns)),
                                               dtype=torch.float32, device=device)
            for k, sampling in enumerate(samplings[0]):
                # Tuples of slices: indexing with a *list* of slices relies on legacy
                # sequence-as-tuple behaviour.
                img_sampling = (slice(0, B), slice(0, C), *sampling)
                label_sampling = (slice(0, B), slice(0, K), *sampling)
                assert len(img_sampling) == len(image.shape), (img_sampling, image.shape)
                sub_img = image[img_sampling]

                # Reset gradients
                if optimizer:
                    optimizer.zero_grad()

                # Forward
                pred_logits: Tensor = net(sub_img)
                pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
                # Used only for dice computation:
                predicted_mask: Tensor = probs2one_hot(pred_probs.detach())
                assert not predicted_mask.requires_grad

                # Detached: the receptacle is only used for saving, and copying a grad-requiring
                # tensor into it would keep every patch's autograd graph alive.
                probs_receptacle[label_sampling] = pred_probs.detach()
                mask_receptacle[label_sampling] = predicted_mask

                assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
                ziped = zip(loss_fns, labels, loss_weights, bounds)
                losses = [w * loss_fn(pred_probs, label[label_sampling], bound, filenames)
                          for loss_fn, label, w, bound in ziped]
                loss = reduce(add, losses)
                assert loss.shape == (), loss.shape

                # Backward
                if optimizer:
                    loss.backward()
                    optimizer.step()

                # Compute and log metrics
                for j, loss_j in enumerate(losses):
                    loss_sub_log[k, j] = loss_j.detach()

            reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
            assert reduced_loss_sublog.shape == (len(loss_fns),), (reduced_loss_sublog.shape, len(loss_fns))
            # Slice to this loader's loss count: loaders with fewer losses than n_loss would
            # otherwise fail to broadcast into the (n_loss,) row.
            epoch_metrics["loss"][done_batch, :len(loss_fns)] = reduced_loss_sublog
            del loss_sub_log

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(mask_receptacle, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            epoch_metrics["dice"][sm_slice, ...] = dices

            if "3d_dsc" in requested_metrics:
                three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                assert three_d_DSC.shape == (K,)

                epoch_metrics["3d_dsc"][done_batch] = three_d_DSC  # type: ignore

            # Save images: the stitched full-size prediction, matching the whole-image filenames.
            # Previously only the last patch's probabilities were saved.
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(probs_receptacle)
                    save_images(predicted_class, data["filenames"], savedir / f"iter{epc:03d}" / mode)

            # Logging. Running means include the current batch, which is row `done_batch`
            # (not yet incremented), matching the image-level `big_slice` used for Dice.
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            seen_batches = done_batch + 1

            stat_dict: dict[str, Any] = {}
            # The order matters for the final display -- it is easy to change
            if few_axis:
                stat_dict |= {f"DSC{n}": epoch_metrics["dice"][big_slice, n].mean()
                              for n in metric_axis}

                if "3d_dsc" in requested_metrics:
                    stat_dict |= {f"3d_DSC{n}": epoch_metrics["3d_dsc"][:seen_batches, n].mean()
                                  for n in metric_axis}

            if len(metric_axis) > 1:
                stat_dict |= {"DSC": epoch_metrics["dice"][big_slice, metric_axis].mean()}

            stat_dict |= {f"loss_{n}": epoch_metrics["loss"][:seen_batches].mean(dim=0)[n]
                          for n in range(n_loss)}

            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return {k: v.detach().cpu() for (k, v) in epoch_metrics.items()}