def dice3d(base_folder, folder, subfoldername, grp_regex, gt_folder, C):
    """Compute the mean and standard deviation of a per-patient volumetric Dice.

    The ``*.png`` slices found in ``<base_folder>/<folder>/<subfoldername>``
    (or ``<folder>/<subfoldername>`` when ``base_folder`` is ``''``) are
    grouped per patient with ``grp_regex``, stacked into one volume per
    patient and compared with the identically named slices of ``gt_folder``.

    Args:
        base_folder: optional root folder; ``''`` means "no root".
        folder: folder holding the predictions.
        subfoldername: sub-folder of ``folder`` containing the PNG slices.
        grp_regex: pattern matched at the start of every file stem. The
            matched text identifies the patient, and the second element of
            ``re.split(grp_regex, stem)`` must parse as the slice index.
        gt_folder: folder containing the ground-truth slices.
        C: number of columns of the per-patient score tensor.

    Returns:
        A ``(mean, std)`` pair, both reduced over the patient dimension.
    """
    if base_folder == '':
        work_folder = Path(folder, subfoldername)
    else:
        work_folder = Path(base_folder, folder, subfoldername)
    gt_dir = Path(gt_folder)

    filenames = map_(lambda p: str(p.name), work_folder.glob("*.png"))
    grouping_regex: Pattern = re.compile(grp_regex)

    stems: List[str] = [Path(filename).stem for filename in filenames]  # avoid matching the extension
    matches: List[Match] = map_(grouping_regex.match, stems)
    patients: List[str] = [match.group(0) for match in matches]
    unique_patients: List[str] = list(set(patients))

    # Expected in-plane size of every slice; other sizes go through resize_im.
    w, h = 256, 256

    batch_dice = torch.zeros((len(unique_patients), C))

    for i, patient in enumerate(unique_patients):
        # NOTE(review): prefix matching -- a patient id that is itself a prefix
        # of another id would also collect that patient's slices; confirm
        # against the file naming scheme.
        patient_slices = [stem for stem in stems if stem.startswith(patient)]
        n = len(patient_slices)

        # Zero-initialised on purpose: np.ndarray(shape=...) returns
        # uninitialised memory, so a slice index that is never written would
        # otherwise feed garbage into the Dice computation.
        t_seg = np.zeros((w, h, n))
        t_gt = np.zeros((w, h, n))

        for slice_name in patient_slices:
            slice_nb = int(grouping_regex.split(slice_name)[1])

            seg = imageio.imread(str(work_folder / f"{slice_name}.png"))
            gt = imageio.imread(str(gt_dir / f"{slice_name}.png"))

            if seg.shape != (w, h):
                seg = resize_im(seg, 36)
            if gt.shape != (w, h):
                gt = resize_im(gt, 36)

            # Predictions saved with 255 as foreground are mapped to label 1.
            seg[seg == 255] = 1

            t_seg[:, :, slice_nb] = seg
            t_gt[:, :, slice_nb] = gt

        t_seg = torch.from_numpy(t_seg)
        t_gt = torch.from_numpy(t_gt)

        # NOTE(review): the one-hot encoding hard-codes 3 while the result has
        # C columns -- confirm that callers always pass C == 3.
        # Only element [0] of the result is kept ("do not save the interclasses").
        batch_dice[i, ...] = dice_batch(class2one_hot(t_seg, 3), class2one_hot(t_gt, 3))[0]

    return batch_dice.mean(dim=0), batch_dice.std(dim=0)
def runInference(args: argparse.Namespace, pred_folder: str):
    """Print Dice statistics for the PNG predictions stored in ``pred_folder``.

    The predictions are loaded twice (once as a dummy image, once as one-hot
    prediction) next to the ground truth of ``args.gt_folder``, batched with a
    ``PatientSampler`` built from ``args.grp_regex``, and the column-wise mean
    of every collected metric is printed.

    Args:
        args: namespace providing ``cpu``, ``num_classes``, ``gt_folder`` and
            ``grp_regex``.
        pred_folder: folder containing the predicted ``*.png`` files.
    """
    # print('>>> Loading the data')
    gpu_requested = torch.cuda.is_available() and not args.cpu
    device = torch.device("cuda") if gpu_requested else torch.device("cpu")
    n_classes: int = args.num_classes

    # Let's just reuse some code
    def add_channel_axis(img):
        return np.array(img)[np.newaxis, ...]

    def to_unit_range(nd):
        return nd / 255  # max <= 1

    def to_float_tensor(nd):
        return torch.tensor(nd, dtype=torch.float32)

    def to_long_tensor(nd):
        return torch.tensor(nd, dtype=torch.int64)

    def zero_bounds(*_):
        return torch.zeros(n_classes, 1, 2)

    png_transform = transforms.Compose([add_channel_axis,
                                        to_unit_range,
                                        to_float_tensor])
    gt_transform = transforms.Compose([add_channel_axis,
                                       to_long_tensor,
                                       partial(class2one_hot, C=n_classes),
                                       itemgetter(0)])
    bounds_gen = [zero_bounds, zero_bounds]

    prediction_dir = Path(pred_folder)
    # First one is dummy
    folders: List[Path] = [prediction_dir, prediction_dir, Path(args.gt_folder)]
    names: List[str] = map_(lambda p: str(p.name), prediction_dir.glob("*.png"))

    dt_set = SliceDataset(names,
                          folders,
                          transforms=[png_transform, gt_transform, gt_transform],
                          debug=False,
                          C=n_classes,
                          are_hots=[False, True, True],
                          in_memory=False,
                          bounds_generators=bounds_gen)
    sampler = PatientSampler(dt_set, args.grp_regex)
    loader = DataLoader(dt_set, batch_sampler=sampler, num_workers=11)

    # print('>>> Computing the metrics')
    total_iteration = len(loader)
    total_images = len(loader.dataset)

    def zeros(rows: int, cols: int) -> Tensor:
        return torch.zeros((rows, cols), dtype=torch.float64, device=device)

    # Insertion order matters: it is the order of the final report.
    metrics = {"all_dices": zeros(total_images, n_classes),
               "batch_dices": zeros(total_iteration, n_classes),
               "sizes": zeros(total_images, 1)}

    progress = tqdm_(enumerate(loader), total=total_iteration, desc=">> Computing")
    first: int = 0
    for j, (_names, _, pred, gt, _) in progress:
        batch_size = len(pred)

        pred = pred.to(device)
        gt = gt.to(device)
        assert simplex(pred) and sset(pred, [0, 1])
        assert simplex(gt) and sset(gt, [0, 1])

        dices: Tensor = dice_coef(pred, gt)
        b_dices: Tensor = dice_batch(pred, gt)
        assert dices.shape == (batch_size, n_classes)
        assert b_dices.shape == (n_classes, ), b_dices.shape

        # Rows [first, last) belong to the current batch
        last = first + batch_size
        metrics["all_dices"][first:last, ...] = dices
        metrics["sizes"][first:last, :] = torch.einsum("bwh->b", gt[:, 1, ...])[..., None]
        metrics["batch_dices"][j] = b_dices
        first = last

    print(f">>> {pred_folder}")
    for key, values in metrics.items():
        print(key, map_("{:.4f}".format, values.mean(dim=0)))
def do_epoch(mode: str, net: Any, device: Any, loader: DataLoader, epc: int,
             loss_fns: List[Callable], loss_weights: List[float], C: int,
             savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1], compute_haussdorf: bool = False) \
        -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Run one training or validation pass over ``loader``.

    Every batch is a list ``[filenames, image, target, *labels, *bounds]`` with
    one label and one bound per loss function. The weighted losses are summed,
    back-propagated when an ``optimizer`` is given, and Dice (optionally
    Hausdorff) scores are logged.

    Args:
        mode: ``"train"`` (``net.train()``) or ``"val"`` (``net.eval()``).
        net: the network, called as ``net(image)``.
        device: device on which tensors and logs are placed.
        loader: the data loader to iterate.
        epc: epoch number, used for display and image saving.
        loss_fns: loss callables, each called as ``fn(probs, label, bound)``.
        loss_weights: one weight per loss function.
        C: number of classes (columns of the metric logs).
        savedir: when non-empty, predicted classes are saved there.
        optimizer: when truthy, gradients are reset, computed and applied.
        metric_axis: class indices to display (read-only, never mutated).
        compute_haussdorf: also compute the Hausdorff distance.

    Returns:
        ``(loss_log, all_dices, batch_dices, haussdorf_log)``.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration), dtype=torch.float32, device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    nice_dict = {}  # Stays empty (instead of unbound) when the loader yields nothing
    done: int = 0
    for j, data in tq_iter:
        data[1:] = [e.to(device) for e in data[1:]]  # Move all tensors to device
        filenames, image, target = data[:3]
        labels = data[3:3 + L]
        bounds = data[3 + L:]
        assert len(labels) == len(bounds)
        B = len(image)

        # Reset gradients
        if optimizer:
            optimizer.zero_grad()

        # Forward
        pred_logits: Tensor = net(image)
        pred_probs: Tensor = F.softmax(pred_logits, dim=1)
        predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation

        assert len(bounds) == len(loss_fns) == len(loss_weights)
        ziped = zip(loss_fns, labels, loss_weights, bounds)
        losses = [w * loss_fn(pred_probs, label, bound) for loss_fn, label, w, bound in ziped]
        loss = reduce(add, losses)
        assert loss.shape == (), loss.shape

        # Backward
        if optimizer:
            loss.backward()
            optimizer.step()

        # Compute and log metrics
        loss_log[j] = loss.detach()

        sm_slice = slice(done, done + B)  # Values only for current batch

        dices: Tensor = dice_coef(predicted_mask, target.detach())
        assert dices.shape == (B, C), (dices.shape, B, C)
        all_dices[sm_slice, ...] = dices

        if B > 1 and mode == "val":
            batch_dice: Tensor = dice_batch(predicted_mask, target.detach())
            assert batch_dice.shape == (C, ), (batch_dice.shape, B, C)
            batch_dices[j] = batch_dice

        if compute_haussdorf:
            haussdorf_res: Tensor = haussdorf(predicted_mask.detach(), target.detach())
            assert haussdorf_res.shape == (B, C)
            haussdorf_log[sm_slice] = haussdorf_res

        # Save images
        if savedir:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                predicted_class: Tensor = probs2class(pred_probs)
                save_images(predicted_class, filenames, savedir, mode, epc)

        # Logging
        big_slice = slice(0, done + B)  # Value for current and previous batches
        # Per-iteration logs include the current batch: slicing with [:j]
        # would drop it and average an empty tensor (NaN) on the first one.
        seen_batches = j + 1

        dsc_dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis}
        hauss_dict = {f"HD{n}": haussdorf_log[big_slice, n].mean() for n in metric_axis} \
            if compute_haussdorf else {}
        batch_dict = {f"bDSC{n}": batch_dices[:seen_batches, n].mean() for n in metric_axis} \
            if B > 1 and mode == "val" else {}

        mean_dict = {}
        if len(metric_axis) > 1:
            mean_dict["DSC"] = all_dices[big_slice, metric_axis].mean()
            # Only report the Hausdorff mean when it was actually computed
            if compute_haussdorf:
                mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()

        stat_dict = {**dsc_dict, **hauss_dict, **mean_dict, **batch_dict,
                     "loss": loss_log[:seen_batches].mean()}
        nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

        tq_iter.set_postfix(nice_dict)
        done += B

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return loss_log, all_dices, batch_dices, haussdorf_log
def do_epoch(mode: str, net: Any, device: Any, loaders: List[DataLoader], epc: int,
             list_loss_fns: List[List[Callable]], list_loss_weights: List[List[float]], C: int,
             savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1], compute_haussdorf: bool = False, compute_miou: bool = False,
             temperature: float = 1) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tuple[None, Tensor]]:
    """Run one training or validation pass over several loaders in sequence.

    ``loaders``, ``list_loss_fns`` and ``list_loss_weights`` are iterated in
    lockstep. Every batch is a list ``[filenames, image, target, *labels,
    *bounds]`` with one label and one bound per loss of the current loader.

    Args:
        mode: ``"train"`` (``net.train()``) or ``"val"`` (``net.eval()``).
        net: the network, called as ``net(image)``.
        device: device on which tensors and logs are placed.
        loaders: the data loaders, processed one after the other.
        epc: epoch number, used for display and image saving.
        list_loss_fns: per loader, the losses called as ``fn(probs, label, bound)``.
        list_loss_weights: per loader, one weight per loss.
        C: number of classes (columns of the metric logs).
        savedir: when non-empty, predicted classes are saved there.
        optimizer: when truthy, gradients are reset, computed and applied.
        metric_axis: class indices to display (read-only, never mutated).
        compute_haussdorf: also compute the Hausdorff distance.
        compute_miou: also compute the IoU statistics.
        temperature: factor applied to the logits before the softmax.

    Returns:
        ``(loss_log, all_dices, batch_dices, haussdorf_log, mIoUs)`` where
        ``mIoUs`` is a ``(C,)`` tensor, or ``None`` when ``compute_miou`` is
        false.
    """
    assert mode in ["train", "val"]

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    iiou_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    intersections: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    unions: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)

    # Per-class entries are only displayed when there are few enough of them
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    nice_dict = {}  # Stays empty (instead of unbound) when no batch is processed
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        L: int = len(loss_fns)

        for data in loader:
            data[1:] = [e.to(device) for e in data[1:]]  # Move all tensors to device
            filenames, image, target = data[:3]
            assert not target.requires_grad
            labels = data[3:3 + L]
            bounds = data[3 + L:]
            assert len(labels) == len(bounds)
            B = len(image)

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
            ziped = zip(loss_fns, labels, loss_weights, bounds)
            losses = [w * loss_fn(pred_probs, label, bound) for loss_fn, label, w, bound in ziped]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j, weighted_loss in enumerate(losses):
                loss_log[done_batch, j] = weighted_loss.detach()

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, C), (dices.shape, B, C)
            all_dices[sm_slice, ...] = dices

            if B > 1 and mode == "val":
                batch_dice: Tensor = dice_batch(predicted_mask, target)
                assert batch_dice.shape == (C,), (batch_dice.shape, B, C)
                batch_dices[done_batch] = batch_dice

            if compute_haussdorf:
                haussdorf_res: Tensor = haussdorf(predicted_mask, target)
                assert haussdorf_res.shape == (B, C)
                haussdorf_log[sm_slice] = haussdorf_res

            if compute_miou:
                IoUs: Tensor = iIoU(predicted_mask, target)
                assert IoUs.shape == (B, C), IoUs.shape
                iiou_log[sm_slice] = IoUs
                intersections[sm_slice] = inter_sum(predicted_mask, target)
                unions[sm_slice] = union_sum(predicted_mask, target)

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            # Per-iteration logs include the current batch: slicing with
            # [:done_batch] would drop it and yield NaN on the first one.
            seen_batches = done_batch + 1

            dsc_dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis} if few_axis else {}
            hauss_dict = {f"HD{n}": haussdorf_log[big_slice, n].mean() for n in metric_axis} \
                if compute_haussdorf and few_axis else {}
            batch_dict = {f"bDSC{n}": batch_dices[:seen_batches, n].mean() for n in metric_axis} \
                if B > 1 and mode == "val" and few_axis else {}
            miou_dict = {"iIoU": iiou_log[big_slice, metric_axis].mean(),
                         "mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()} \
                if compute_miou else {}

            mean_dict = {}
            if len(metric_axis) > 1:
                mean_dict["DSC"] = all_dices[big_slice, metric_axis].mean()
                if compute_haussdorf:
                    mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()

            stat_dict = {**miou_dict, **dsc_dict, **hauss_dict, **mean_dict, **batch_dict,
                         "loss": loss_log[:seen_batches].mean()}
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    mIoUs = None
    if compute_miou:
        # The epsilon guards against classes absent from every image
        mIoUs = intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)
        assert mIoUs.shape == (C,), mIoUs.shape

    return loss_log, all_dices, batch_dices, haussdorf_log, mIoUs
def do_epoch(
        mode: str, net: Any, device: Any, loaders: list[DataLoader], epc: int,
        list_loss_fns: list[list[Callable]], list_loss_weights: list[list[float]],
        K: int, savedir: str = "", optimizer: Any = None,
        metric_axis: list[int] = [1], compute_3d_dice: bool = False,
        temperature: float = 1) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Run one pass over several loaders whose batches are dictionaries.

    ``loaders``, ``list_loss_fns`` and ``list_loss_weights`` are iterated in
    lockstep. Every batch provides the keys ``"images"``, ``"gt"``,
    ``"filenames"`` and ``"labels"`` (one label per loss of the loader).

    Args:
        mode: ``"train"`` (``net.train()``), ``"val"`` or ``"dual"`` (both
            ``net.eval()``).
        net: the network, called as ``net(image)``.
        device: device on which tensors and logs are placed.
        loaders: the data loaders, processed one after the other.
        epc: epoch number, used for display and image saving.
        list_loss_fns: per loader, the losses called as ``fn(probs, label)``.
        list_loss_weights: per loader, one weight per loss.
        K: number of classes (columns of the metric logs).
        savedir: when non-empty, predicted classes are saved there.
        optimizer: when truthy, gradients are reset, computed and applied.
        metric_axis: class indices to display (read-only, never mutated).
        compute_3d_dice: also log the batch-wise Dice of every iteration.
        temperature: factor applied to the logits before the softmax.

    Returns:
        ``(loss_log, all_dices, three_d_dices)``, all detached and on CPU;
        ``three_d_dices`` is ``None`` when ``compute_3d_dice`` is false.
    """
    assert mode in ["train", "val", "dual"]

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"
    elif mode == "dual":
        # "dual" is accepted by the assertion above; without this branch
        # `desc` would be unbound and the progress bar creation would crash.
        net.eval()
        desc = f">> Dual ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)

    three_d_dices: Optional[Tensor] = None
    if compute_3d_dice:
        three_d_dices = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)

    done_img: int = 0
    done_batch: int = 0
    nice_dict: dict = {}  # Stays empty (instead of unbound) when no batch is processed
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            filenames: list[str] = data["filenames"]
            assert not target.requires_grad
            labels: list[Tensor] = [e.to(device) for e in data["labels"]]
            B: int = image.shape[0]

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(loss_fns) == len(loss_weights) == len(labels)
            ziped = zip(loss_fns, labels, loss_weights)
            losses = [w * loss_fn(pred_probs, label) for loss_fn, label, w in ziped]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j, weighted_loss in enumerate(losses):
                loss_log[done_batch, j] = weighted_loss.detach()

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if compute_3d_dice:
                three_d_DSC: Tensor = dice_batch(predicted_mask, target)
                assert three_d_DSC.shape == (K, )
                three_d_dices[done_batch] = three_d_DSC  # type: ignore

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            # Per-iteration logs include the current batch: slicing with
            # [:done_batch] would drop it and yield NaN on the first one.
            seen_batches = done_batch + 1

            dsc_dict: dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis} | \
                ({f"3d_DSC{n}": three_d_dices[:seen_batches, n].mean() for n in metric_axis}
                 if three_d_dices is not None else {})

            mean_losses: Tensor = loss_log[:seen_batches].mean(dim=0)
            loss_dict = {f"loss_{l_idx}": mean_losses[l_idx] for l_idx in range(n_loss)}

            stat_dict = dsc_dict | loss_dict
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return (loss_log.detach().cpu(),
            all_dices.detach().cpu(),
            three_d_dices.detach().cpu() if three_d_dices is not None else None)
def do_epoch(mode: str, args, net, device, use_cuda, loader, optimizer, num_classes, epoch):
    """Run one training or validation pass on six-channel MRI batches.

    Every batch is an 8-tuple ``(image_f, image_i, image_d, image_o, image_w,
    image_c, labels, img_names)``. The six images are concatenated on the
    channel axis and divided by 65535; the target is ``cat(1 - labels,
    labels)``. The loss is ``crossEntropy_f`` plus ``args.lam`` times ``kl``
    between the normalised class sizes of the target and of the prediction.

    Args:
        mode: ``"train"`` (``net.train()``, optimisation enabled) or ``"val"``
            (``net.eval()``).
        args: namespace providing ``dtype`` (an expression evaluated to a
            torch dtype) and ``lam`` (weight of the KL term).
        net: the network, called as ``net(MRI)``.
        device: device of the logs, and of the batch when ``use_cuda`` is set.
        use_cuda: whether inputs and targets are moved to ``device``.
        loader: the data loader to iterate.
        optimizer: optimizer used in ``"train"`` mode.
        num_classes: number of columns of the Dice logs.
        epoch: epoch number, used for display.

    Returns:
        ``(loss_log, entropy_log, KL_log, all_dices, batch_dices)``.

    Raises:
        ValueError: if ``mode`` is neither ``"train"`` nor ``"val"``.
    """
    if mode == "train":
        net.train()
        desc = f">> Training ({epoch})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epoch})"
    else:
        raise ValueError(f"Unknown mode {mode!r}: expected 'train' or 'val'")

    # NOTE(security): eval() of a configuration string; this is only safe as
    # long as args.dtype comes from a trusted source. Evaluated once here
    # instead of once per allocated tensor.
    dtype = eval(args.dtype)

    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, num_classes), dtype=dtype, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, num_classes), dtype=dtype, device=device)
    loss_log: Tensor = torch.zeros((total_images), dtype=dtype, device=device)
    entropy_log: Tensor = torch.zeros((total_images), dtype=dtype, device=device)
    KL_log: Tensor = torch.zeros((total_images), dtype=dtype, device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    done: int = 0
    for j, data in tq_iter:
        image_f, image_i, image_d, image_o, image_w, image_c, labels, img_names = data

        MRI: Tensor = torch.cat((image_f, image_i, image_d, image_o, image_w, image_c), dim=1)
        # presumably 16-bit intensities rescaled to [0, 1] -- TODO confirm
        MRI = MRI.type(torch.FloatTensor) / 65535.0
        targets = torch.cat((1 - labels, labels), dim=1)
        B = len(image_f)

        if use_cuda:
            MRI, targets = MRI.to(device), targets.to(device)

        # forward
        outputs = net(MRI)
        pred_probs = F.softmax(outputs, dim=1)
        predicted_mask = probs2one_hot(pred_probs)

        entropy = crossEntropy_f(pred_probs, targets)

        # Both class-size vectors are normalised by the total target mass
        total_targets = torch.sum(targets).float()
        pred_probs_aver: Tensor = torch.sum(pred_probs, dim=(2, 3)) / total_targets
        target_aver: Tensor = torch.sum(targets, dim=(2, 3)).float() / total_targets

        KL_loss = args.lam * kl(target_aver, pred_probs_aver)
        loss = entropy + KL_loss

        if mode == "train":
            # zero the parameter gradients
            optimizer.zero_grad()
            # backward + optimize
            loss.backward()
            optimizer.step()

        # Compute and log metrics
        # Cast in place on the current device: forcing torch.cuda.IntTensor
        # fails on CPU-only runs.
        int_targets: Tensor = targets.detach().to(torch.int32)
        hard_mask: Tensor = predicted_mask.detach()
        dices: Tensor = dice_coef(hard_mask, int_targets)
        batch_dice: Tensor = dice_batch(hard_mask, int_targets)
        assert batch_dice.shape == (num_classes, ) and dices.shape == (B, num_classes), \
            (batch_dice.shape, dices.shape, B, num_classes)

        sm_slice = slice(done, done + B)  # Values only for current batch
        all_dices[sm_slice, ...] = dices
        entropy_log[sm_slice] = entropy.detach()
        loss_log[sm_slice] = loss.detach()
        KL_log[sm_slice] = KL_loss.detach()
        batch_dices[j] = batch_dice

        # Logging
        big_slice = slice(0, done + B)  # Value for current and previous batches
        stat_dict = {"dice": all_dices[big_slice, -1].mean(),
                     "total loss": loss_log[big_slice].mean(),
                     "entropy loss": entropy_log[big_slice].mean(),
                     "KL loss": KL_log[big_slice].mean(),
                     "b dice": batch_dices[:j + 1, -1].mean()}
        nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}

        done += B
        tq_iter.set_postfix(nice_dict)

    return loss_log, entropy_log, KL_log, all_dices, batch_dices
def main() -> None:
    """Compute 3D validation metrics for every saved epoch of a training run.

    The ``iter*`` folders of ``args.basefolder`` are visited in sorted order;
    the PNG predictions of their ``val`` sub-folder are grouped per patient
    (``args.grp_regex``) and compared with ``args.gt_folder``. Each metric
    requested in ``args.metrics`` (``'3d_dsc'``, ``'3d_hausdorff'``) is stored
    in an ``(epochs, patients, classes)`` tensor and saved as
    ``val_<metric>.npy`` in ``args.basefolder``.
    """
    args = get_args()

    # NOTE(review): lexicographic order -- assumes zero-padded folder names
    # (otherwise "iter10" sorts before "iter2"); confirm against the writer.
    iterations_paths: List[Path] = sorted(Path(args.basefolder).glob("iter*"))
    print(f">>> Found {len(iterations_paths)} epoch folders")

    # Handle gracefully if not all folders are there (early stop)
    EPC: int = args.n_epoch if args.n_epoch >= 0 else len(iterations_paths)
    K: int = args.num_classes

    # Get the patient number, and image names, from the GT folder
    gt_path: Path = Path(args.gt_folder)
    names: List[str] = map_(lambda p: str(p.name), gt_path.glob("*"))
    n_img: int = len(names)

    grouping_regex: Pattern = re.compile(args.grp_regex)
    stems: List[str] = [Path(filename).stem for filename in names]  # avoid matching the extension
    matches: List[Match] = map_(grouping_regex.match, stems)  # type: ignore
    patients: List[str] = [match.group(1) for match in matches]

    unique_patients: List[str] = list(set(patients))
    n_patients: int = len(unique_patients)

    print(f">>> Found {len(unique_patients)} unique patients out of {n_img} images ; regex: {args.grp_regex}")

    # First, quickly assert all folders have the same numbers of predicted images
    n_img_epoc: List[int] = [len(list(Path(p, "val").glob("*.png"))) for p in iterations_paths]
    assert len(set(n_img_epoc)) == 1
    # Reuse the counts gathered above rather than globbing every folder again
    assert all(count == n_img for count in n_img_epoc)

    metrics: Dict[str, Tensor] = {}
    if '3d_dsc' in args.metrics:
        metrics['3d_dsc'] = torch.zeros((EPC, n_patients, K), dtype=torch.float32)
        print(">> Will compute 3d_dsc metric")
    if '3d_hausdorff' in args.metrics:
        metrics['3d_hausdorff'] = torch.zeros((EPC, n_patients, K), dtype=torch.float32)
        print(">> Will compute 3d_hausdorff metric")

    gen_dataset = partial(SliceDataset,
                          transforms=[png_transform, gt_transform, gt_transform],
                          are_hots=[False, True, True],
                          K=K,
                          in_memory=False,
                          bounds_generators=[(lambda *a: torch.zeros(K, 1, 2)) for _ in range(1)],
                          box_prior=False,
                          box_priors_arg='{}',
                          dimensions=2)
    data_loader = partial(DataLoader,
                          num_workers=cpu_count(),
                          pin_memory=False,
                          collate_fn=custom_collate)

    # Will replace live dataset.folders and call again load_images to update dataset.files
    print(gt_path, gt_path, Path(iterations_paths[0], 'val'))
    dataset: SliceDataset = gen_dataset(names, [gt_path, gt_path, Path(iterations_paths[0], 'val')])
    sampler: PatientSampler = PatientSampler(dataset, args.grp_regex, shuffle=False)
    dataloader: DataLoader = data_loader(dataset, batch_sampler=sampler)

    current_path: Path
    # The metric tensors only have EPC rows: never visit more folders than
    # that, otherwise writing row `e` raises an IndexError.
    for e, current_path in enumerate(iterations_paths[:EPC]):
        dataset.folders = [gt_path, gt_path, Path(current_path, 'val')]
        dataset.files = SliceDataset.load_images(dataset.folders, dataset.filenames, False)

        print(f">>> Doing epoch {str(current_path)}")

        # One batch per patient (batches come from the PatientSampler)
        for i, data in enumerate(tqdm(dataloader, leave=None)):
            target: Tensor = data["gt"]
            prediction: Tensor = data["labels"][0]

            assert target.shape == prediction.shape

            if '3d_dsc' in args.metrics:
                dsc: Tensor = dice_batch(target, prediction)
                assert dsc.shape == (K, )

                metrics['3d_dsc'][e, i, :] = dsc

            if '3d_hausdorff' in args.metrics:
                np_pred: np.ndarray = prediction[:, 1, :, :].cpu().numpy()
                np_target: np.ndarray = target[:, 1, :, :].cpu().numpy()

                if np_pred.sum() > 0:
                    hd_: float = hd(np_pred, np_target)

                    metrics["3d_hausdorff"][e, i, 1] = hd_
                else:
                    # Empty prediction: fall back to the diagonal of the volume
                    x, y, z = np_pred.shape
                    metrics["3d_hausdorff"][e, i, 1] = (x**2 + y**2 + z**2)**0.5

        for metric in args.metrics:
            # For now, hardcode the fact we care about class 1 only
            print(f">> {metric}: {metrics[metric][e].mean(dim=0)[1]:.04f}")

    k: str
    el: Tensor
    for k, el in metrics.items():
        np.save(Path(args.basefolder, f"val_{k}.npy"), el.cpu().numpy())
def do_epoch(
        mode: str, net: Any, device: Any, loaders: List[DataLoader], epc: int,
        list_loss_fns: List[List[Callable]], list_loss_weights: List[List[float]],
        K: int, savedir: str = "", optimizer: Any = None,
        metric_axis: List[int] = [1], compute_hausdorff: bool = False,
        compute_miou: bool = False, compute_3d_dice: bool = False,
        temperature: float = 1
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Run one pass over several loaders, with optional patch sub-sampling.

    ``loaders``, ``list_loss_fns`` and ``list_loss_weights`` are iterated in
    lockstep. Every batch is a dictionary with the keys ``"images"``,
    ``"gt"``, ``"spacings"``, ``"labels"``, ``"bounds"``, ``"box_priors"``,
    ``"samplings"`` and ``"filenames"``. The network is applied to each
    sampling window of the first batch item in turn (one optimisation step
    per window); the predictions are gathered in full-size receptacles on
    which the metrics are computed.

    Args:
        mode: ``"train"`` (``net.train()``) or ``"val"`` (``net.eval()``).
        net: the network, called as ``net(sub_img)``.
        device: device on which tensors and logs are placed.
        loaders: the data loaders, processed one after the other.
        epc: epoch number, used for display and image saving.
        list_loss_fns: per loader, the losses called as
            ``fn(probs, label, bound, box_priors)``.
        list_loss_weights: per loader, one weight per loss.
        K: number of classes (columns of the metric logs).
        savedir: when non-empty, predicted classes are saved there.
        optimizer: when truthy, gradients are reset, computed and applied.
        metric_axis: class indices to display (read-only, never mutated).
        compute_hausdorff: also compute the Hausdorff distance.
        compute_miou: also compute the IoU statistics.
        compute_3d_dice: also log the batch-wise Dice of every iteration.
        temperature: factor applied to the logits before the softmax.

    Returns:
        ``(loss_log, all_dices, hausdorff_log, mIoUs, three_d_dices)``, all
        detached and on CPU; the last three are ``None`` when the matching
        ``compute_*`` flag is false.
    """
    assert mode in ["train", "val"]

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)

    iiou_log: Optional[Tensor] = None
    intersections: Optional[Tensor] = None
    unions: Optional[Tensor] = None
    if compute_miou:
        iiou_log = torch.zeros((total_images, K), dtype=torch.float32, device=device)
        intersections = torch.zeros((total_images, K), dtype=torch.float32, device=device)
        unions = torch.zeros((total_images, K), dtype=torch.float32, device=device)

    three_d_dices: Optional[Tensor] = None
    if compute_3d_dice:
        three_d_dices = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)

    hausdorff_log: Optional[Tensor] = None
    if compute_hausdorff:
        hausdorff_log = torch.zeros((total_images, K), dtype=torch.float32, device=device)

    # Per-class entries are only displayed when there are few enough of them
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    nice_dict: Dict[str, str] = {}  # Stays empty (instead of unbound) when no batch is processed
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            spacings: Tensor = data["spacings"]  # Keep that one on CPU
            assert not target.requires_grad
            labels: List[Tensor] = [e.to(device) for e in data["labels"]]
            bounds: List[Tensor] = [e.to(device) for e in data["bounds"]]
            box_priors: List[List[Tuple[Tensor, Tensor]]]  # one more level for the batch
            box_priors = [[(m.to(device), b.to(device)) for (m, b) in item_priors]
                          for item_priors in data["box_priors"]]
            assert len(labels) == len(bounds)

            B, C, *_ = image.shape

            samplings: List[List[Tuple[slice]]] = data["samplings"]
            assert len(samplings) == B
            assert len(samplings[0][0]) == len(image[0, 0].shape), (samplings[0][0], image[0, 0].shape)

            probs_receptacle: Tensor = -torch.ones_like(target, dtype=torch.float32)  # -1 for unfilled
            mask_receptacle: Tensor = -torch.ones_like(target, dtype=torch.int32)  # -1 for unfilled

            # Use the sampling coordinates of the first batch item
            assert not (len(samplings[0]) > 1 and B > 1), samplings  # No subsampling if batch size > 1
            loss_sub_log: Tensor = torch.zeros((len(samplings[0]), len(loss_fns)),
                                               dtype=torch.float32, device=device)
            for k, sampling in enumerate(samplings[0]):
                # Tuples of slices: indexing a tensor with a *list* of slices
                # relies on legacy sequence-as-tuple behaviour.
                img_sampling = (slice(0, B), slice(0, C)) + tuple(sampling)
                label_sampling = (slice(0, B), slice(0, K)) + tuple(sampling)
                assert len(img_sampling) == len(image.shape), (img_sampling, image.shape)
                sub_img = image[img_sampling]

                # Reset gradients
                if optimizer:
                    optimizer.zero_grad()

                # Forward
                pred_logits: Tensor = net(sub_img)
                pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
                predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
                assert not predicted_mask.requires_grad

                probs_receptacle[label_sampling] = pred_probs[...]
                mask_receptacle[label_sampling] = predicted_mask[...]

                assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
                ziped = zip(loss_fns, labels, loss_weights, bounds)
                losses = [w * loss_fn(pred_probs, label[label_sampling], bound, box_priors)
                          for loss_fn, label, w, bound in ziped]
                loss = reduce(add, losses)
                assert loss.shape == (), loss.shape

                # Backward
                if optimizer:
                    loss.backward()
                    optimizer.step()

                # Compute and log metrics
                for j, weighted_loss in enumerate(losses):
                    loss_sub_log[k, j] = weighted_loss.detach()

            # The loss of a batch is the sum over its sampling windows
            reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
            assert reduced_loss_sublog.shape == (len(loss_fns), ), (reduced_loss_sublog.shape, len(loss_fns))
            loss_log[done_batch, ...] = reduced_loss_sublog[...]
            del loss_sub_log

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(mask_receptacle, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if compute_3d_dice:
                three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                assert three_d_DSC.shape == (K, )
                three_d_dices[done_batch] = three_d_DSC  # type: ignore

            if compute_hausdorff:
                hausdorff_res: Tensor
                try:
                    hausdorff_res = hausdorff(mask_receptacle, target, spacings)
                except RuntimeError:
                    # Deliberate best effort: a failed computation is logged as zeros
                    hausdorff_res = torch.zeros((B, K), device=device)
                assert hausdorff_res.shape == (B, K)
                hausdorff_log[sm_slice] = hausdorff_res  # type: ignore

            if compute_miou:
                IoUs: Tensor = iIoU(mask_receptacle, target)
                assert IoUs.shape == (B, K), IoUs.shape
                iiou_log[sm_slice] = IoUs  # type: ignore
                intersections[sm_slice] = inter_sum(mask_receptacle, target)  # type: ignore
                unions[sm_slice] = union_sum(mask_receptacle, target)  # type: ignore

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    # NOTE(review): uses the probabilities of the last sampling
                    # window, not probs_receptacle -- confirm this is intended
                    # when a batch has several windows.
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, data["filenames"], savedir, mode, epc)

            # Logging
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            # Per-iteration logs include the current batch: slicing with
            # [:done_batch] would drop it and yield NaN on the first one.
            seen_batches = done_batch + 1

            dsc_dict: Dict
            if few_axis:
                dsc_dict = {**{f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis},
                            **({f"3d_DSC{n}": three_d_dices[:seen_batches, n].mean() for n in metric_axis}
                               if three_d_dices is not None else {})}
            else:
                dsc_dict = {}

            hauss_dict = {f"HD{n}": hausdorff_log[big_slice, n].mean() for n in metric_axis} \
                if hausdorff_log is not None and few_axis else {}

            miou_dict = {"iIoU": iiou_log[big_slice, metric_axis].mean(),
                         "mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()} \
                if iiou_log is not None and intersections is not None and unions is not None else {}

            if len(metric_axis) > 1:
                mean_dict = {"DSC": all_dices[big_slice, metric_axis].mean()}
                # `is not None`: the truth value of a multi-element tensor is
                # ambiguous and raises a RuntimeError.
                if hausdorff_log is not None:
                    mean_dict["HD"] = hausdorff_log[big_slice, metric_axis].mean()
            else:
                mean_dict = {}

            stat_dict = {**miou_dict, **dsc_dict, **hauss_dict, **mean_dict,
                         "loss": loss_log[:seen_batches].mean()}
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    mIoUs: Optional[Tensor]
    # Explicit None checks, for the same reason as above
    if intersections is not None and unions is not None:
        mIoUs = (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10))
        assert mIoUs.shape == (K, ), mIoUs.shape
    else:
        mIoUs = None

    return (loss_log.detach().cpu(),
            all_dices.detach().cpu(),
            hausdorff_log.detach().cpu() if hausdorff_log is not None else None,
            mIoUs.detach().cpu() if mIoUs is not None else None,
            three_d_dices.detach().cpu() if three_d_dices is not None else None)
def do_epoch(mode: str, net: Any, device: Any, loaders: list[DataLoader], epc: int,
             list_loss_fns: list[list[Callable]], list_loss_weights: list[list[float]], K: int,
             savedir: Path = None, optimizer: Any = None,
             metric_axis: list[int] = None, requested_metrics: list[str] = None,
             temperature: float = 1) -> dict[str, Tensor]:
    """Run one epoch over every loader and collect per-image / per-batch metrics.

    Args:
        mode: One of "train", "val" or "dual". "train" calls `net.train()`,
            the two others call `net.eval()`.
        net: Callable model; invoked on (sub-sampled) image batches and
            expected to return logits (softmax is applied on dim 1).
        device: Device on which inputs and metric buffers are placed.
        loaders: Data loaders. Each batch is a dict read through the keys
            "images", "gt", "filenames", "labels", "bounds" and "samplings".
        epc: Epoch number (used in the progress description and save path).
        list_loss_fns: For each loader, the loss callables; each is called as
            `loss_fn(pred_probs, label, bound, filenames)`.
        list_loss_weights: For each loader, one weight per loss callable.
        K: Size of the class axis of the targets / predictions.
        savedir: When set, predicted class maps are written under
            `savedir / f"iter{epc:03d}" / mode`.
        optimizer: When set, gradients are reset, back-propagated and applied
            for every sub-sampling of every batch.
        metric_axis: Class indices reported in the progress bar.
            Defaults to `[1]`.
        requested_metrics: Optional extra metrics; only "3d_dsc" is looked up
            here.
        temperature: Multiplier applied to the logits before the softmax.

    Returns:
        Dict of detached CPU tensors: "dice" with shape (total_images, K),
        "loss" with shape (total_iteration, n_loss) and, if requested,
        "3d_dsc" with shape (total_iteration, K).
    """
    assert mode in ["train", "val", "dual"]
    if requested_metrics is None:
        requested_metrics = []
    if metric_axis is None:
        # Resolved here rather than in the signature to avoid a shared mutable default.
        metric_axis = [1]

    if mode == "train":
        net.train()
        desc = f">> Training ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"
    elif mode == "dual":
        net.eval()
        desc = f">> Dual ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    # default=0 keeps an empty list of loaders from raising ValueError.
    n_loss: int = max(map(len, list_loss_fns), default=0)

    want_3d_dsc: bool = "3d_dsc" in requested_metrics

    epoch_metrics: dict[str, Tensor]
    epoch_metrics = {"dice": torch.zeros((total_images, K), dtype=torch.float32, device=device),
                     "loss": torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)}

    if want_3d_dsc:
        epoch_metrics["3d_dsc"] = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)

    few_axis: bool = len(metric_axis) <= 4

    done_img: int = 0
    done_batch: int = 0
    # Defined up-front so the final print works even when no batch was processed.
    nice_dict: dict[str, str] = {}
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for loader_idx, (loader, loss_fns, loss_weights) in enumerate(zip(loaders,
                                                                      list_loss_fns,
                                                                      list_loss_weights)):
        n_fns: int = len(loss_fns)
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            filenames: list[str] = data["filenames"]
            assert not target.requires_grad

            labels: list[Tensor] = [e.to(device) for e in data["labels"]]
            bounds: list[Tensor] = [e.to(device) for e in data["bounds"]]
            assert len(labels) == len(bounds)

            B, C, *_ = image.shape

            samplings: list[list[tuple[slice, ...]]] = data["samplings"]
            assert len(samplings) == B
            assert len(samplings[0][0]) == len(image[0, 0].shape), (samplings[0][0], image[0, 0].shape)

            probs_receptacle: Tensor = - torch.ones_like(target, dtype=torch.float32)  # -1 for unfilled
            mask_receptacle: Tensor = - torch.ones_like(target, dtype=torch.int32)  # -1 for unfilled

            # Use the sampling coordinates of the first batch item
            assert not (len(samplings[0]) > 1 and B > 1), samplings  # No subsampling if batch size > 1
            loss_sub_log: Tensor = torch.zeros((len(samplings[0]), n_fns),
                                               dtype=torch.float32, device=device)
            for k, sampling in enumerate(samplings[0]):
                # Tuples (not lists) of slices: indexing a tensor with a list of
                # slices relies on a legacy "treat sequence as tuple" rule.
                img_sampling = (slice(0, B), slice(0, C), *sampling)
                label_sampling = (slice(0, B), slice(0, K), *sampling)
                assert len(img_sampling) == len(image.shape), (img_sampling, image.shape)
                sub_img = image[img_sampling]

                # Reset gradients
                if optimizer:
                    optimizer.zero_grad()

                # Forward
                pred_logits: Tensor = net(sub_img)
                pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)

                # Used only for dice computation:
                predicted_mask: Tensor = probs2one_hot(pred_probs.detach())
                assert not predicted_mask.requires_grad

                # Detached: this buffer is only a record of the probabilities and
                # must not keep the autograd graph of every sub-sampling alive.
                probs_receptacle[label_sampling] = pred_probs.detach()
                mask_receptacle[label_sampling] = predicted_mask[...]

                assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
                ziped = zip(loss_fns, labels, loss_weights, bounds)
                losses = [w * loss_fn(pred_probs, label[label_sampling], bound, filenames)
                          for loss_fn, label, w, bound in ziped]
                loss = reduce(add, losses)
                assert loss.shape == (), loss.shape

                # Backward
                if optimizer:
                    loss.backward()
                    optimizer.step()

                # Compute and log metrics
                for j, sub_loss in enumerate(losses):
                    loss_sub_log[k, j] = sub_loss.detach()

            reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
            assert reduced_loss_sublog.shape == (n_fns,), (reduced_loss_sublog.shape, n_fns)
            # Only the first n_fns columns belong to this loader: the buffer is
            # n_loss wide (max over loaders), and a plain row assignment would
            # either fail or silently broadcast a single loss to every column.
            epoch_metrics["loss"][done_batch, :n_fns] = reduced_loss_sublog
            del loss_sub_log

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(mask_receptacle, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            epoch_metrics["dice"][sm_slice, ...] = dices

            if want_3d_dsc:
                three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                assert three_d_DSC.shape == (K,)

                epoch_metrics["3d_dsc"][done_batch] = three_d_DSC  # type: ignore

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    # NOTE(review): this uses the probabilities of the last
                    # sub-sampling only, not the filled receptacle -- confirm
                    # this is intended when a batch has several samplings.
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class,
                                filenames,
                                savedir / f"iter{epc:03d}" / mode)

            # Logging
            seen_imgs = slice(0, done_img + B)  # Value for current and previous batches
            # done_batch is incremented only after logging, hence the +1 so that
            # the batch just processed is part of the running means (otherwise
            # the first batch shows NaN and the final summary drops the last one).
            seen_batches = slice(0, done_batch + 1)

            stat_dict: dict[str, Any] = {}
            # The order matters for the final display -- it is easy to change

            if few_axis:
                stat_dict |= {f"DSC{n}": epoch_metrics["dice"][seen_imgs, n].mean()
                              for n in metric_axis}

                if want_3d_dsc:
                    stat_dict |= {f"3d_DSC{n}": epoch_metrics["3d_dsc"][seen_batches, n].mean()
                                  for n in metric_axis}

            if len(metric_axis) > 1:
                stat_dict |= {"DSC": epoch_metrics["dice"][seen_imgs, metric_axis].mean()}

            mean_losses: Tensor = epoch_metrics["loss"][seen_batches].mean(dim=0)
            stat_dict |= {f"loss_{loss_idx}": mean_losses[loss_idx]
                          for loss_idx in range(n_loss)}

            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(loader_idx)})
            tq_iter.update(1)
    tq_iter.close()

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return {k: v.detach().cpu() for (k, v) in epoch_metrics.items()}