def metrics_calc(all_grp, all_inter_card, all_card_gt, all_card_pred, metric_axis, pprint=False):
    """Compute per-patient Dice and volume-difference scores and summarise them.

    Rows of the three cardinality tensors are grouped by the id stored in
    ``all_grp`` and summed per class, so each score is evaluated once per
    patient instead of once per row.

    Args:
        all_grp: tensor of group (patient) ids. ``all_grp == p`` must broadcast
            against the ``(rows, C)`` cardinality tensors.
        all_inter_card: ``(rows, C)`` tensor used as the intersection term.
        all_card_gt: ``(rows, C)`` tensor used as the ground-truth term.
        all_card_pred: ``(rows, C)`` tensor used as the prediction term.
        metric_axis: class indices kept for the summary statistics.
        pprint: when true, print every patient's Dice rounded to two decimals.
            Printing no longer alters the values that are accumulated.

    Returns:
        ``(dice_mean, dice_sd, avd_mean, avd_sd)`` as Python floats. Each value
        is the average over ``metric_axis`` of the per-class mean (or standard
        deviation) across patients.
    """
    _, C = all_card_gt.shape
    unique_patients = torch.unique(all_grp)
    n_patients = len(unique_patients)
    batch_dice = torch.zeros((n_patients, C))
    batch_avd = torch.zeros((n_patients, C))

    def patient_total(values, mask):
        # Keep only the rows of this patient, then sum them class by class.
        return torch.masked_select(values, mask).reshape((-1, C)).sum(dim=0)

    for i, p in enumerate(unique_patients):
        mask = all_grp == p
        inter_card_p = patient_total(all_inter_card, mask)
        card_gt_p = patient_total(all_card_gt, mask)
        card_pred_p = patient_total(all_card_pred, mask)

        # The 1e-8 terms keep both ratios finite (and equal to 1) for empty classes.
        dice_3d = (2 * inter_card_p + 1e-8) / ((card_pred_p + card_gt_p) + 1e-8)
        avd = (card_pred_p + card_gt_p - 2 * inter_card_p + 1e-8) / (card_gt_p + 1e-8)

        if pprint:
            # Round only what is displayed; the stored metric keeps full precision.
            print(p, torch.round(dice_3d * 10**2) / (10**2))

        batch_dice[i, ...] = dice_3d
        batch_avd[i, ...] = avd

    indices = torch.tensor(metric_axis)
    dice_3d = torch.index_select(batch_dice, 1, indices)
    avd = torch.index_select(batch_avd, 1, indices)

    dice_3d_mean = dice_3d.mean(dim=0)
    avd_mean = avd.mean(dim=0)
    print('metric_axis dice', dice_3d_mean)
    dice_3d_sd = dice_3d.std(dim=0)
    avd_sd = avd.std(dim=0)

    return (dice_3d_mean.mean().item(), dice_3d_sd.mean().item(),
            avd_mean.mean().item(), avd_sd.mean().item())
def dice3d(base_folder, folder, subfoldername, grp_regex, gt_folder, C):
    """Rebuild one volume per patient from PNG slices and score it with ``dice_batch``.

    Args:
        base_folder: optional root; the empty string means ``folder`` is used alone.
        folder: folder holding the predictions.
        subfoldername: sub-folder of ``folder`` containing the ``*.png`` slices.
        grp_regex: regular expression applied to each file stem. The whole match
            is the patient key, and element 1 of ``re.split`` on the stem is
            parsed as the slice index.
        gt_folder: folder holding ground-truth PNGs with the same file names.
        C: number of columns of the per-patient score table.

    Returns:
        ``(mean, std)`` of the per-patient scores along the patient axis.
    """
    if base_folder == '':
        path_parts = (folder, subfoldername)
    else:
        path_parts = (base_folder, folder, subfoldername)
    work_folder = Path(*path_parts)

    filenames = map_(lambda p: str(p.name), work_folder.glob("*.png"))
    grouping_regex: Pattern = re.compile(grp_regex)
    # Work on stems so that the regex never sees the extension.
    stems: List[str] = [Path(name).stem for name in filenames]
    matches: List[Match] = map_(grouping_regex.match, stems)
    patients: List[str] = [m.group(0) for m in matches]
    unique_patients: List[str] = list(set(patients))

    batch_dice = torch.zeros((len(unique_patients), C))
    for row, patient in enumerate(unique_patients):
        patient_slices = [stem for stem in stems if stem.startswith(patient)]
        w, h = [256, 256]
        depth = len(patient_slices)
        seg_volume = np.empty((w, h, depth))
        gt_volume = np.empty((w, h, depth))

        for stem in patient_slices:
            slice_nb = int(grouping_regex.split(stem)[1])
            seg = imageio.imread(str(work_folder) + '/' + stem + '.png')
            gt = imageio.imread(str(gt_folder) + '/' + stem + '.png')
            # Slices that are not already (w, h) go through the project's resize helper.
            if seg.shape != (w, h):
                seg = resize_im(seg, 36)
            if gt.shape != (w, h):
                gt = resize_im(gt, 36)
            seg[seg == 255] = 1  # predictions stored as 0/255 become 0/1 labels
            seg_volume[:, :, slice_nb] = seg
            gt_volume[:, :, slice_nb] = gt

        seg_tensor = torch.from_numpy(seg_volume)
        gt_tensor = torch.from_numpy(gt_volume)
        # Only the first element returned by dice_batch is kept.
        scores = dice_batch(class2one_hot(seg_tensor, 3), class2one_hot(gt_tensor, 3))
        batch_dice[row, ...] = scores[0]

    return batch_dice.mean(dim=0), batch_dice.std(dim=0)
def __call__(self, epoch: int, optimizer: Any, loss_fns: List[List[Callable]], loss_weights: List[List[float]]) \
        -> Tuple[Any, List[List[Callable]], List[List[float]]]:
    """Add ``self.to_add`` element-wise to every list of loss weights.

    Args:
        epoch: current epoch (not used by this scheduler).
        optimizer: passed through unchanged.
        loss_fns: passed through unchanged.
        loss_weights: one list of weights per group of losses.

    Returns:
        ``(optimizer, loss_fns, new_weights)`` where ``new_weights`` has the
        same nesting as ``loss_weights``.

    Raises:
        AssertionError: if any inner weight list does not have exactly one
            increment in ``self.to_add``.
    """
    # The increments are paired with each *inner* list, so that is the length
    # that has to match. Otherwise zip would silently drop the extra weights.
    assert all(len(weights) == len(self.to_add) for weights in loss_weights)

    new_weights: List[List[float]] = [
        [weight + increment for weight, increment in zip(weights, self.to_add)]
        for weights in loss_weights
    ]

    print(f"Loss weights went from {loss_weights} to {new_weights}")

    return optimizer, loss_fns, new_weights
def run(args: argparse.Namespace) -> None:
    """Draw one box plot per folder for a saved metric, taken at its best epoch.

    For each folder the array ``<folder>/<args.filename>`` is loaded. The epoch
    whose mean over axis 1 is highest is selected for the requested column, and
    the values of that epoch are shown as a box plot with a one-line summary
    printed alongside.

    Args:
        args: namespace providing ``folders``, ``filename``, ``columns``,
            ``savefig`` and ``headless``.

    Raises:
        NotImplementedError: if more than one column is requested.
    """
    assert len(args.folders) <= len(colors)
    if len(args.columns) > 1:
        raise NotImplementedError("Only 1 columns at a time is handled for now")

    paths: List[Path] = [Path(folder, args.filename) for folder in args.folders]
    arrays: List[np.ndarray] = map_(np.load, paths)
    metric_name: str = paths[0].stem

    shapes = set(arr.shape for arr in arrays)
    assert len(shapes) == 1  # All arrays should have the same shape

    if arrays[0].ndim == 2:
        # Add a trailing axis so that column selection works uniformly.
        arrays = map_(lambda arr: arr[..., np.newaxis], arrays)

    fig = plt.figure(figsize=(14, 9))
    ax = fig.gca()
    ax.set_ylim([0, 1])
    ax.set_xlabel(metric_name)
    ax.set_ylabel("Percentage")
    ax.grid(True, axis='y')
    ax.set_title(f"{metric_name} moustaches")

    for position, (arr, _, path) in enumerate(zip(arrays, colors, paths), start=1):
        for col in args.columns:
            epoch_means = arr[..., col].mean(axis=1)
            best_epoch: int = np.argmax(epoch_means)
            values = arr[best_epoch, :, col]

            ax.boxplot(values, positions=[position], manage_xticks=False,
                       showmeans=True, meanline=True, whis=[5, 95])

            print(f"{path.parent.stem:10}: min {values.min():.03f} 25{np.percentile(values, 25):.03f} "
                  + f"avg {values.mean():.03f} 75 {np.percentile(values, 75):.03f} max {values.max():.03f} at epc {best_epoch}")

    # Tick 0 is left blank so that the labels line up with positions 1..n.
    tick_labels = [""] + map_(lambda p: p.parent.stem, paths)
    ax.set_xticklabels(tick_labels)
    ax.set_xticks(np.mgrid[0:len(args.folders) + 1])
    ax.set_yticks(np.mgrid[0:1.1:.1])

    fig.tight_layout()
    if args.savefig:
        fig.savefig(args.savefig)

    if not args.headless:
        plt.show()
def __call__(self, epoch: int, optimizer: Any, loss_fns: List[List[Callable]], loss_weights: List[List[float]]) \
        -> Tuple[float, List[List[Callable]], List[List[float]]]:
    """Multiply the ``t`` attribute of every targeted loss by ``self.mu``.

    A loss is targeted when the name of its class equals ``self.target_loss``.
    Targeted losses are modified in place; all others are left untouched.

    Returns:
        ``(optimizer, loss_fns, loss_weights)``, with ``loss_fns`` rebuilt as
        nested lists holding the same loss objects.
    """
    def scale_if_targeted(candidate: Any):
        if candidate.__class__.__name__ != self.target_loss:
            return candidate
        candidate.t *= self.mu
        return candidate

    def scale_group(group):
        return map_(scale_if_targeted, group)

    updated_fns = map_(scale_group, loss_fns)
    return optimizer, updated_fns, loss_weights
def __call__(self, epoch: int, optimizer: Any, loss_fns: list[list[Callable]], loss_weights: list[list[float]]) \
        -> Tuple[float, list[list[Callable]], list[list[float]]]:
    """Add ``self.to_add`` element-wise to the (single) list of loss weights.

    Raises:
        AssertionError: if ``self.to_add`` and ``loss_weights[0]`` differ in length.
        NotImplementedError: if more than one list of weights is supplied.

    Returns:
        ``(optimizer, loss_fns, new_weights)``. The first two are passed through.
    """
    increments = self.to_add
    assert len(increments) == len(loss_weights[0])
    if len(loss_weights) > 1:
        raise NotImplementedError

    def bump(weights):
        # Pair every weight with its increment and add the two.
        return map_(uc_(add), zip(weights, increments))

    new_weights: list[list[float]] = map_(bump, loss_weights)

    print(f"Loss weights went from {loss_weights} to {new_weights}")

    return optimizer, loss_fns, new_weights
def run(args: argparse.Namespace) -> None:
    """Overlay one histogram per folder for a saved metric, taken at its best epoch.

    For each folder the array ``<folder>/<args.filename>`` is loaded. The epoch
    whose mean over axis 1 is highest is selected for the requested column, and
    the values of that epoch are binned over ``[0, 1]``.

    Args:
        args: namespace providing ``folders``, ``filename``, ``columns``,
            ``nbins``, ``savefig`` and ``headless``.

    Raises:
        NotImplementedError: if more than one column is requested.
    """
    assert len(args.folders) <= len(colors)
    if len(args.columns) > 1:
        raise NotImplementedError("Only 1 columns at a time is handled for now")

    paths: List[Path] = [Path(folder, args.filename) for folder in args.folders]
    arrays: List[np.ndarray] = map_(np.load, paths)
    metric_name: str = paths[0].stem

    if arrays[0].ndim == 2:
        # Add a trailing axis so that column selection works uniformly.
        arrays = map_(lambda arr: arr[..., np.newaxis], arrays)

    # All arrays must agree on their first and last dimensions; the middle one may differ.
    n_epochs, _, n_classes = arrays[0].shape
    for other in arrays[1:]:
        other_epochs, _, other_classes = other.shape
        assert n_epochs == other_epochs and n_classes == other_classes

    fig = plt.figure(figsize=(14, 9))
    ax = fig.gca()
    ax.set_xlim([0, 1])
    ax.set_xlabel(metric_name)
    ax.set_ylabel("Percentage")
    ax.grid(True, axis='y')
    ax.set_title(f"{metric_name} histograms")

    bins = np.linspace(0, 1, args.nbins)
    for arr, color, path in zip(arrays, colors, paths):
        for col in args.columns:
            epoch_means = arr[..., col].mean(axis=1)
            best_epoch: int = np.argmax(epoch_means)
            values = arr[best_epoch, :, col]

            ax.hist(values, bins, alpha=0.5, label=f"{path.parent.name}-{col}", color=color)
    ax.legend()

    fig.tight_layout()
    if args.savefig:
        fig.savefig(args.savefig)

    if not args.headless:
        plt.show()
def setup(args, n_class, dtype) -> Tuple[Any, Any, Any, List[Callable], List[float], List[Callable], List[float], Callable]:
    """Build the network, optimizer, target/source losses and scheduler from ``args``.

    Args:
        args: namespace providing ``cpu``, ``model_weights``, ``network``,
            ``l_rate``, ``target_losses``, ``source_losses``, ``scheduler`` and
            ``scheduler_params``.
        n_class: number of output classes given to a freshly built network.
        dtype: tensor type applied to a fresh network and forwarded to every loss.

    Returns:
        ``(net, optimizer, device, loss_fns, loss_weights, loss_fns_source,
        loss_weights_source, scheduler)``. ``scheduler`` is the empty string
        when ``args.scheduler`` is falsy.
    """
    print(">>> Setting up")
    cpu: bool = args.cpu or not torch.cuda.is_available()
    device = torch.device("cpu" if cpu else "cuda")

    if args.model_weights:
        # On a CPU run the stored tensors must be remapped to the CPU at load time.
        load_kwargs = {'map_location': 'cpu'} if cpu else {}
        net = torch.load(args.model_weights, **load_kwargs)
    else:
        net_class = getattr(__import__('networks'), args.network)
        net = net_class(1, n_class).type(dtype).to(device)
        net.apply(weights_init)
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.l_rate, betas=(0.9, 0.999))

    def instantiate(specs) -> List[Callable]:
        # Each spec is a 6-tuple; items 0, 1 and 4 are the class name, kwargs and fn.
        return [getattr(__import__('losses'), loss_name)(**loss_params, dtype=dtype, fn=fn)
                for loss_name, loss_params, _, _, fn, _ in specs]

    # NOTE: the loss and scheduler descriptions are Python expressions that are
    # eval()-ed, so they must only ever come from trusted configuration.
    print(args.target_losses)
    losses = eval(args.target_losses)
    loss_fns: List[Callable] = instantiate(losses)
    loss_weights = map_(itemgetter(5), losses)

    print(args.source_losses)
    losses_source = eval(args.source_losses)
    loss_fns_source: List[Callable] = instantiate(losses_source)
    loss_weights_source = map_(itemgetter(5), losses_source)

    if args.scheduler:
        scheduler_class = getattr(__import__('scheduler'), args.scheduler)
        scheduler = scheduler_class(**eval(args.scheduler_params))
    else:
        scheduler = ''

    return net, optimizer, device, loss_fns, loss_weights, loss_fns_source, loss_weights_source, scheduler
def __init__(self, root: str, train_mode: bool, K=2, in_memory: bool = False, crf_batch: int = 4, debug=False, batch_size=1, img_size=(256, 256)) -> None:
    """Index the ``img``/``gt``/``box`` folders under ``root`` and build a sequential loader.

    Args:
        root: directory containing the ``img``, ``gt`` and ``box`` sub-folders.
        train_mode: stored on the instance.
        K: forwarded to the parent class and used as the first dimension of the
            zero-filled bounds tensor.
        in_memory: accepted but not read in this constructor.
        crf_batch: stored on the instance.
        debug: forwarded to the parent class.
        batch_size: stored on the instance; the internal loader uses a fixed
            batch size of 4.
        img_size: stored on the instance as a tuple.
    """
    from torch.utils.data import SequentialSampler

    img_dir = Path(root, 'img')
    filenames: List[str] = map_(lambda p: str(p.name), img_dir.glob("*"))
    folders: List[Path] = [Path(root, sub) for sub in ("img", "gt", "box")]
    are_hots: List[bool] = [False, True, True]

    # These attributes are set before the parent constructor runs.
    self.crf_batch: int = crf_batch
    self.train_mode: bool = train_mode
    self.img_size = tuple(img_size)
    self.batch_size = batch_size

    super().__init__(filenames, folders, are_hots, K=K,
                     transforms=[png_transform, gt_transform, gt_transform],
                     bounds_generators=[lambda *a: torch.zeros(K, 1, 2)],
                     debug=debug, debug_size=100)

    # Keep copies of the entries of the third file list (the "box" folder).
    self.orig_boxes: List = [copy(box) for box in self.files[2]]
    n_boxes, n_files = len(self.orig_boxes), len(self.filenames)
    assert n_boxes == n_files, (n_boxes, n_files)

    self.loader: DataLoader = DataLoader(self,
                                         shuffle=False,
                                         drop_last=False,
                                         pin_memory=True,
                                         collate_fn=custom_collate,
                                         batch_size=4)

    assert isinstance(self.loader.sampler, SequentialSampler), self.loader._index_sampler  # type: ignore
def dice3dn(all_grp, all_inter_card, all_card_gt, all_card_pred, metric_axis, pprint=False):
    """Compute the per-patient Dice score and summarise it over ``metric_axis``.

    Rows of the three cardinality tensors are grouped by the id stored in
    ``all_grp`` and summed per class before the ratio is taken. Rows whose id
    is 666 are ignored (presumably a placeholder id; confirm with the caller).

    Args:
        all_grp: tensor of group (patient) ids. ``all_grp == p`` must broadcast
            against the ``(rows, C)`` cardinality tensors.
        all_inter_card: ``(rows, C)`` tensor used as the intersection term.
        all_card_gt: ``(rows, C)`` tensor used as the ground-truth term.
        all_card_pred: ``(rows, C)`` tensor used as the prediction term.
        metric_axis: class indices kept for the summary statistics.
        pprint: when true, print every patient's Dice rounded to two decimals.
            Printing no longer alters the values that are accumulated.

    Returns:
        ``(dice_mean, dice_sd)`` as Python floats: the average over
        ``metric_axis`` of the per-class mean and standard deviation across
        patients.
    """
    _, C = all_card_gt.shape
    unique_patients = torch.unique(all_grp)
    unique_patients = unique_patients[unique_patients != 666]
    batch_dice = torch.zeros((len(unique_patients), C))

    def patient_total(values, mask):
        # Keep only the rows of this patient, then sum them class by class.
        return torch.masked_select(values, mask).reshape((-1, C)).sum(dim=0)

    for i, p in enumerate(unique_patients):
        mask = all_grp == p
        inter_card_p = patient_total(all_inter_card, mask)
        card_gt_p = patient_total(all_card_gt, mask)
        card_pred_p = patient_total(all_card_pred, mask)

        # The 1e-8 terms keep the ratio finite (and equal to 1) for empty classes.
        dice_3d = (2 * inter_card_p + 1e-8) / ((card_pred_p + card_gt_p) + 1e-8)
        if pprint:
            # Round only what is displayed; the stored metric keeps full precision.
            print(p, torch.round(dice_3d * 10**2) / (10**2))
        batch_dice[i, ...] = dice_3d

    indices = torch.tensor(metric_axis)
    dice_3d = torch.index_select(batch_dice, 1, indices)
    dice_3d_mean = dice_3d.mean(dim=0)
    print('metric_axis dice', dice_3d_mean)
    dice_3d_sd = dice_3d.std(dim=0)

    return dice_3d_mean.mean().item(), dice_3d_sd.mean().item()
def main(args: Namespace) -> None:
    """Apply a weakening strategy to every ground-truth image matched by ``args.regex``.

    Args:
        args: namespace providing ``base_folder``, ``GT_subfolder``, ``regex``,
            ``verbose`` and ``strategy``.
    """
    gt_dir = Path(args.base_folder, args.GT_subfolder)
    inputs: List[Path] = list(gt_dir.glob(args.regex))
    names: List[str] = [path.name for path in inputs]
    print(f"Found {len(names)} images to weaken")
    if args.verbose:
        pprint(names[:10])

    # NOTE: args.strategy is eval()-ed as a Python expression, so it must only
    # ever come from a trusted command line.
    strategy: Callable = eval(args.strategy)
    weaken: Callable = partial(weaken_img, strategy=strategy)

    # Every call yields an (original size, new size) pair; transpose into two arrays.
    size_pairs = mmap_(weaken, zip(inputs, names))
    orig_sizes, new_sizes = map_(np.asarray, zip(*size_pairs))
    assert len(orig_sizes) == len(new_sizes) == len(names)

    try:
        positive_sizes = orig_sizes[orig_sizes > 0]
        print("Orig sizes: (min, mean, max)", positive_sizes.min(), orig_sizes.mean(), orig_sizes.max())
        print(f"Annotated {new_sizes.sum()} pixels for {len(new_sizes)} images")
    except ValueError:
        # Raised e.g. when no size is positive; the summary is then skipped.
        pass
def __init__(self, dataset: SliceDataset, grp_regex, shuffle=False) -> None:
    """Group the dataset's file names by the first capture group of ``grp_regex``.

    Args:
        dataset: object exposing a ``filenames`` list.
        grp_regex: regular expression matched against each file stem; capture
            group 1 is used as the patient key.
        shuffle: when true, ``self.shuffle_fn`` returns a random permutation of
            its argument; otherwise it is the project's ``id_`` helper.
    """
    filenames: List[str] = dataset.filenames
    self.grp_regex = grp_regex

    # Configure the shuffling function.
    self.shuffle: bool = shuffle
    if self.shuffle:
        self.shuffle_fn: Callable = lambda x: random.sample(x, len(x))
    else:
        self.shuffle_fn = id_

    print(f"Grouping using {self.grp_regex} regex")
    grouping_regex: Pattern = re.compile(self.grp_regex)

    # Match on the stem so that the extension never takes part in the regex.
    patients: List[str] = []
    for filename in filenames:
        stem = Path(filename).stem
        patients.append(grouping_regex.match(stem).group(1))

    unique_patients: List[str] = list(set(patients))
    assert len(unique_patients) < len(filenames)
    print(f"Found {len(unique_patients)} unique patients out of {len(filenames)} images")

    # Map every patient to the dataset indices of its slices.
    self.idx_map: Dict[str, List[int]] = {patient: [] for patient in unique_patients}
    for index, patient in enumerate(patients):
        self.idx_map[patient].append(index)

    assert sum(len(self.idx_map[k]) for k in unique_patients) == len(filenames)
    print("Patient to slices mapping done")
def main(args: argparse.Namespace) -> None:
    """Mirror the sub-folders of ``args.root_dir`` into ``args.dest_dir`` and augment them.

    The ``*.png`` names found in the first sub-folder drive the work: each name
    is handed to ``process_name`` together with every source and destination
    folder.

    Args:
        args: namespace providing ``root_dir``, ``dest_dir`` and ``n_aug``. It
            is also forwarded as-is to ``process_name``.
    """
    print(f'>>> Starting data augmentation (original + {args.n_aug} new images)')

    root_dir: str = args.root_dir
    dest_dir: str = args.dest_dir

    folders: List[Path] = list(Path(root_dir).glob("*"))
    dest_folders: List[Path] = [Path(dest_dir, source.name) for source in folders]
    print(f"Will augment data from {len(folders)} folders ({map_(str, folders)})")

    # Create all the destination folders up front.
    for target in dest_folders:
        target.mkdir(parents=True, exist_ok=True)

    names: List[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))

    worker = partial(process_name, folders=folders, dest_folders=dest_folders,
                     n_aug=args.n_aug, args=args)
    mmap_(worker, names)
def get_mean_sd(x, indices):
    """Return the averaged mean and standard deviation of selected columns of ``x``.

    The columns listed in ``indices`` are selected along dimension 1. Their mean
    and standard deviation along dimension 0 are rounded to four decimals, and
    each of the two vectors is then averaged into a single tensor.

    Returns:
        ``(x_mean, x_std)`` as tensors produced by ``Tensor.mean()``.
    """
    selected = torch.index_select(x, 1, indices)
    scale = 10**4

    rounded = []
    for stat in (selected.mean(dim=0), selected.std(dim=0)):
        rounded.append(torch.round(stat * scale) / scale)

    x_mean, x_std = map_(lambda t: t.mean(), rounded)
    return x_mean, x_std
def hd3dn(all_grp, all_card_gt, all_pred, all_gt, all_pnames, metric_axis, pprint=False, do_hd=0):
    """Compute a per-patient, per-class 3D ``hd95`` and summarise it over ``metric_axis``.

    Slices are assigned to patients by parsing the integer that follows
    ``slice`` (or, failing that, ``Subj_``) in each name. That choice also
    selects the voxel spacing handed to ``hd95``. Group ids equal to 666 are
    ignored (presumably a placeholder id; confirm with the caller).

    Args:
        all_grp: tensor of group (patient) ids.
        all_card_gt: tensor whose second dimension gives the number of classes.
        all_pred: predicted label maps, indexed by slice on the first axis.
        all_gt: ground-truth label maps, indexed like ``all_pred``.
        all_pnames: array of slice names. Items must expose ``.item()`` and the
            array must support boolean-mask indexing.
        metric_axis: class indices kept for the summary statistics.
        pprint: when true (and ``do_hd > 0``), print each patient's distances.
        do_hd: distances are only computed when this is greater than 0;
            otherwise the table stays filled with zeros.

    Returns:
        ``(hd_mean, hd_sd)`` as Python floats, averaged over ``metric_axis``.
    """
    _, C = all_card_gt.shape
    unique_patients = torch.unique(all_grp)
    unique_patients = unique_patients[unique_patients != 666]
    unique_patients = [u.item() for u in unique_patients]
    batch_hd = torch.zeros((len(unique_patients), C))

    def parse_patient_ids(names):
        # The two supported naming schemes differ in the token that precedes
        # the patient number; a name without 'slice' selects the second one.
        try:
            return [int(re.split('_', re.split('slice', x.item())[1])[0]) for x in names], "whs"
        except (IndexError, ValueError):
            return [int(re.split('_', re.split('Subj_', x.item())[1])[0]) for x in names], "ivd"

    if unique_patients:
        # Parsing does not depend on the patient, so do it once for all of them.
        name_ids, data = parse_patient_ids(all_pnames)

    for i, p in enumerate(unique_patients):
        if do_hd > 0:
            bool_p = [name_id == p for name_id in name_ids]
            slices_p = all_pnames[bool_p]
            all_gt_p = all_gt[bool_p, :]
            all_pred_p = all_pred[bool_p, :]

            # Stack the slices in the order given by the second '_'-separated field.
            sn_p = [int(re.split('_', x)[1]) for x in slices_p]
            ord_p = np.argsort(sn_p)
            label_gt = all_gt_p[ord_p, ...]
            label_pred = all_pred_p[ord_p, ...]

            hd_3d_var_vec = [None] * C
            for j in range(0, C):
                # NOTE(review): for j == 0 the two in-place steps below turn the
                # whole prediction into ones, so class 0 always takes the NaN
                # branch. The ground truth is masked against the untouched
                # labels instead. Confirm this asymmetry is intended.
                label_pred_c = np.copy(label_pred)
                label_pred_c[label_pred_c != j] = 0
                label_pred_c[label_pred_c == j] = 1
                label_gt_c = np.copy(label_gt)
                label_gt_c[label_gt != j] = 0
                label_gt_c[label_gt == j] = 1

                # The distance is only computed when both masks have two distinct values.
                if len(np.unique(label_pred_c)) > 1 and len(np.unique(label_gt_c)) > 1:
                    if data == "whs":
                        hd_3d_var_vec[j] = hd95(label_pred_c, label_gt_c, [0.6, 0.44, 0.44]).item()
                    elif data == "ivd":
                        hd_3d_var_vec[j] = hd95(label_pred_c, label_gt_c, [1.25, 1.25, 2]).item()
                else:
                    hd_3d_var_vec[j] = np.nan
            hd_3d_var = torch.from_numpy(np.asarray(hd_3d_var_vec))

            if pprint:
                print(p, hd_3d_var)
            batch_hd[i, ...] = hd_3d_var

    indices = torch.tensor(metric_axis)
    hd_3d = torch.index_select(batch_hd, 1, indices)
    hd_3d_mean = hd_3d.mean(dim=0)
    hd_3d_mean = torch.round(hd_3d_mean * 10**4) / (10**4)
    hd_3d_sd = hd_3d.std(dim=0)
    hd_3d_sd = torch.round(hd_3d_sd * 10**4) / (10**4)

    def round2(t):
        # Torch-native rounding for display; avoids calling np.round on tensors.
        return torch.round(t * 10**2) / (10**2)

    print('metric_axis hd 3d', round2(hd_3d_mean), 'mean ', round2(hd_3d_mean.mean()),
          'std ', round2(hd_3d_sd), 'std mean', round2(hd_3d_sd.mean()))

    return hd_3d_mean.mean().item(), hd_3d_sd.mean().item()
def setup(args, n_class: int) -> Tuple[Any, Any, Any, List[List[Callable]], List[List[float]], Callable]:
    """Build the network, optimizer, nested loss lists and scheduler from ``args``.

    Args:
        args: namespace providing ``cpu``, ``weights``, ``network``,
            ``modalities``, ``use_sgd``, ``l_rate``, ``losses``, ``three_d``,
            ``scheduler`` and ``scheduler_params``.
        n_class: number of output classes given to a freshly built network.

    Returns:
        ``(net, optimizer, device, loss_fns, loss_weights, scheduler)`` where
        ``loss_fns`` and ``loss_weights`` hold one inner list per list of losses.
    """
    print("\n>>> Setting up")
    cpu: bool = args.cpu or not torch.cuda.is_available()
    device = torch.device("cpu" if cpu else "cuda")

    if args.weights:
        # On a CPU run the stored tensors must be remapped to the CPU at load time.
        load_kwargs = {'map_location': 'cpu'} if cpu else {}
        net = torch.load(args.weights, **load_kwargs)
        print(f">> Restored weights from {args.weights} successfully.")
    else:
        net_class = getattr(__import__('networks'), args.network)
        net = net_class(args.modalities, n_class).to(device)
        net.init_weights()
    net.to(device)

    optimizer: Any  # SGD and Adam do not share a type
    if args.use_sgd:
        optimizer = torch.optim.SGD(net.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)
    else:
        optimizer = torch.optim.Adam(net.parameters(), lr=args.l_rate, betas=(0.9, 0.99), amsgrad=False)

    # NOTE: the loss and scheduler descriptions are Python expressions that are
    # eval()-ed, so they must only ever come from trusted configuration.
    list_losses = eval(args.losses)
    if depth(list_losses) == 1:
        # Older configurations hold a flat list of losses: wrap it in one group.
        list_losses = [list_losses]

    nd: str = "whd" if args.three_d else "wh"

    loss_fns: List[List[Callable]] = []
    for i, losses in enumerate(list_losses):
        print(f">> {i}th list of losses: {losses}")
        # Each spec is a 6-tuple; items 0, 1 and 4 are the class name, kwargs and fn.
        group: List[Callable] = [
            getattr(__import__('losses'), loss_name)(**loss_params, fn=fn, nd=nd)
            for loss_name, loss_params, _, _, fn, _ in losses
        ]
        loss_fns.append(group)

    # Item 5 of every spec is its weight.
    loss_weights: List[List[float]] = [map_(itemgetter(5), losses) for losses in list_losses]

    scheduler_class = getattr(__import__('scheduler'), args.scheduler)
    scheduler = scheduler_class(**eval(args.scheduler_params))

    return net, optimizer, device, loss_fns, loss_weights, scheduler
def main(args: argparse.Namespace):
    """Split image/segmentation ``.mhd`` pairs into train/val and slice them to disk.

    Args:
        args: namespace providing ``source_dir``, ``dest_dir``, ``retain``
            (number of pairs held out for validation), ``n_augment`` and ``shape``.

    Raises:
        AssertionError: if the source is missing, the destination already
            exists, or the files do not form image/segmentation pairs.
    """
    src_path: Path = Path(args.source_dir)
    dest_path: Path = Path(args.dest_dir)

    # Assume the cleaning up is done before calling the script.
    assert src_path.exists()
    assert not dest_path.exists()

    nii_paths: List[Path] = list(src_path.rglob('*.mhd'))
    assert len(nii_paths) % 2 == 0, "Uneven number of .mhd, one+ pair is broken"

    # Sorting lines images up with their segmentation.
    img_nii_paths: List[Path] = sorted(p for p in nii_paths if "_segmentation" not in str(p))
    gt_nii_paths: List[Path] = sorted(p for p in nii_paths if "_segmentation" in str(p))
    assert len(img_nii_paths) == len(gt_nii_paths)
    paths: List[Tuple[Path, Path]] = list(zip(img_nii_paths, gt_nii_paths))

    print(f"Found {len(img_nii_paths)} pairs in total")
    pprint(paths[:5])

    validation_paths: List[Tuple[Path, Path]] = random.sample(paths, args.retain)
    held_out = set(validation_paths)
    training_paths: List[Tuple[Path, Path]] = [p for p in paths if p not in held_out]
    assert held_out.isdisjoint(training_paths)
    assert len(paths) == (len(validation_paths) + len(training_paths))

    for mode, _paths, n_augment in zip(["train", "val"], [training_paths, validation_paths], [args.n_augment, 0]):
        dest_dir = Path(dest_path, mode)
        if not _paths:
            # An empty split (e.g. retain == 0) has nothing to slice or to summarise.
            print(f"No pairs for {mode}, skipping")
            continue

        print(f"Slicing {len(_paths)} pairs to {dest_dir}")

        pfun = partial(save_slices, dest_dir=dest_dir, shape=args.shape, n_augment=n_augment)
        sizes = mmap_(uc_(pfun), _paths)

        # Each result is (3d size, smallest 2d size, largest 2d size).
        sizes_3d, sizes_2d_min, sizes_2d_max = map_(np.asarray, zip(*sizes))

        print("2d sizes: ", sizes_2d_min.min(), sizes_2d_max.max())
        print("3d sizes: ", sizes_3d.min(), sizes_3d.mean(), sizes_3d.max())
def main(args: argparse.Namespace):
    """Split six-file ``.nii`` cases into train/val, slice them and save their spacing.

    A case is made of one file per pattern ``CT.``, ``CT_CBF``, ``CT_CBV``,
    ``CT_MTT``, ``CT_Tmax`` and ``OT``. Files whose path contains ``_4D`` are
    ignored.

    Args:
        args: namespace providing ``source_dir``, ``dest_dir``, ``retain``
            (number of cases held out for validation), ``n_augment`` and ``shape``.

    Raises:
        AssertionError: if the source is missing, the destination already
            exists, or the files do not form complete cases.
    """
    src_path: Path = Path(args.source_dir)
    dest_path: Path = Path(args.dest_dir)

    # Assume the cleaning up is done before calling the script.
    assert src_path.exists()
    assert not dest_path.exists()

    # Get all the file names, avoid the temporal ones.
    all_paths: List[Path] = list(src_path.rglob('*.nii'))
    nii_paths: List[Path] = [p for p in all_paths if "_4D" not in str(p)]
    assert len(nii_paths) % 6 == 0, "Number of .nii not multiple of 6, some pairs are broken"

    def matching(token: str) -> List[Path]:
        # Sorting lines the six modalities of one case up with each other.
        return sorted(p for p in nii_paths if token in str(p))

    CT_nii_paths: List[Path] = matching("CT.")
    CBF_nii_paths: List[Path] = matching("CT_CBF")
    CBV_nii_paths: List[Path] = matching("CT_CBV")
    MTT_nii_paths: List[Path] = matching("CT_MTT")
    Tmax_nii_paths: List[Path] = matching("CT_Tmax")
    gt_nii_paths: List[Path] = matching("OT")
    assert len(CT_nii_paths) == len(CBF_nii_paths) == len(CBV_nii_paths) == len(MTT_nii_paths) \
        == len(Tmax_nii_paths) == len(gt_nii_paths)
    paths: List[Tuple[Path, ...]] = list(zip(CT_nii_paths, CBF_nii_paths, CBV_nii_paths, MTT_nii_paths,
                                             Tmax_nii_paths, gt_nii_paths))

    print(f"Found {len(CT_nii_paths)} pairs in total")
    pprint(paths[:2])

    validation_paths: List[Tuple[Path, ...]] = random.sample(paths, args.retain)
    held_out = set(validation_paths)
    training_paths: List[Tuple[Path, ...]] = [p for p in paths if p not in held_out]
    assert held_out.isdisjoint(training_paths)
    assert len(paths) == (len(validation_paths) + len(training_paths))

    for mode, _paths, n_augment in zip(["train", "val"], [training_paths, validation_paths], [args.n_augment, 0]):
        dest_dir = Path(dest_path, mode)
        if not _paths:
            # An empty split (e.g. retain == 0) has nothing to slice and no folder to write to.
            print(f"No pairs for {mode}, skipping")
            continue

        print(f"Slicing {len(_paths)} pairs to {dest_dir}")

        pfun = partial(save_slices, dest_dir=dest_dir, shape=args.shape, n_augment=n_augment)
        space_dicts = mmap_(uc_(pfun), _paths)

        # Merge the per-case dictionaries into a single one.
        final_dict = {k: v for space_dict in space_dicts for k, v in space_dict.items()}

        with open(Path(dest_dir, "spacing.pkl"), 'wb') as f:
            pickle.dump(final_dict, f, pickle.HIGHEST_PROTOCOL)
            print(f"Saved spacing dictionnary to {f}")
def slice_patient(id_: str, dest_path: Path, source_path: Path, shape: Tuple[int, int], n_augment: int):
    """Slice every ``t0*`` acquisition of one patient into padded PNGs.

    For each acquisition the T1 volume is loaded and normalised, two label
    volumes are obtained from ``fuse_labels``, and every slice along the last
    axis is centre-padded to ``shape`` and saved once as-is plus ``n_augment``
    augmented versions.

    Args:
        id_: patient identifier, used both as folder name and file-name prefix.
        dest_path: root under which the ``img``, ``gt`` and ``gt1`` folders are written.
        source_path: root containing one folder per patient.
        shape: 2D shape every slice is padded to.
        n_augment: number of augmented copies written per slice.
    """
    subfolders: List[str] = ["img", "gt", "gt1"]

    for acq in Path(source_path, id_).glob("t0*"):
        # The acquisition number is whatever follows the first character of the folder name.
        acq_id: int = int(acq.name[1:])

        nib_obj = nib.load(str(Path(acq, f"{id_}_t1w_deface_stx.nii.gz")))
        t1: np.ndarray = np.asarray(nib_obj.dataobj)
        _, _, depth = t1.shape
        assert sanity_t1(t1, *t1.shape, *nib_obj.header.get_zooms())

        gt, gt1 = fuse_labels(t1, id_, acq, nib_obj)
        norm_img: np.ndarray = norm_arr(t1)

        for idz in range(depth):
            padded: List[np.ndarray] = [center_pad(volume[:, :, idz], shape)
                                        for volume in (norm_img, gt, gt1)]
            assert padded[0].shape == padded[1].shape == shape

            for k in range(n_augment + 1):
                # k == 0 keeps the original slice; later iterations are augmentations.
                augmented_arrays: List[np.ndarray]
                if k == 0:
                    augmented_arrays = padded[:]
                else:
                    augmented_arrays = map_(np.asarray, augment(*padded))
                assert len(augmented_arrays) == len(subfolders)

                for save_subfolder, data in zip(subfolders, augmented_arrays):
                    filename = f"{id_}_{acq_id}_{idz}_{k}.png"

                    save_path: Path = Path(dest_path, save_subfolder)
                    save_path.mkdir(parents=True, exist_ok=True)

                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", category=UserWarning)
                        imsave(str(Path(save_path, filename)), data)
def main(args: argparse.Namespace):
    """Split FLAIR/T1/label ``.nii.gz`` triplets into train/val, slice them and save their spacing.

    A case is made of one file per pattern ``FLAIR``, ``T1`` and ``wmh.nii``.
    Files whose path contains ``_4D`` are ignored.

    Args:
        args: namespace providing ``source_dir``, ``dest_dir``, ``retain``
            (number of cases held out for validation), ``n_augment``, ``shape``
            and ``discard_negatives``.

    Raises:
        AssertionError: if the source is missing, the destination already
            exists, or the files do not form complete triplets.
    """
    src_path: Path = Path(args.source_dir)
    dest_path: Path = Path(args.dest_dir)

    # Assume the cleaning up is done before calling the script.
    assert src_path.exists()
    assert not dest_path.exists()

    # Get all the file names, avoid the temporal ones.
    all_paths: List[Path] = list(src_path.rglob('*.nii.gz'))
    nii_paths: List[Path] = [p for p in all_paths if "_4D" not in str(p)]
    assert len(nii_paths) % 3 == 0, "Number of .nii not multiple of 3, some triplets are broken"

    # Sorting lines the three files of one case up with each other.
    flair_nii_paths: List[Path] = sorted(p for p in nii_paths if "FLAIR" in str(p))
    t1_nii_paths: List[Path] = sorted(p for p in nii_paths if "T1" in str(p))
    gt_nii_paths: List[Path] = sorted(p for p in nii_paths if "wmh.nii" in str(p))
    assert len(flair_nii_paths) == len(t1_nii_paths) == len(gt_nii_paths)
    paths: List[Tuple[Path, ...]] = list(zip(flair_nii_paths, t1_nii_paths, gt_nii_paths))

    print(f"Found {len(flair_nii_paths)} pairs in total")
    pprint(paths[:2])

    validation_paths: List[Tuple[Path, ...]] = random.sample(paths, args.retain)
    held_out = set(validation_paths)
    training_paths: List[Tuple[Path, ...]] = [p for p in paths if p not in held_out]
    assert held_out.isdisjoint(training_paths)
    assert len(paths) == (len(validation_paths) + len(training_paths))

    for mode, _paths, n_augment in zip(["train", "val"], [training_paths, validation_paths], [args.n_augment, 0]):
        dest_dir = Path(dest_path, mode)
        if not _paths:
            # An empty split (e.g. retain == 0) has nothing to slice and no folder to write to.
            print(f"No pairs for {mode}, skipping")
            continue

        print(f"Slicing {len(_paths)} pairs to {dest_dir}")

        pfun = partial(save_slices, dest_dir=dest_dir, shape=args.shape, n_augment=n_augment,
                       discard_negatives=args.discard_negatives)
        sizes = mmap_(uc_(pfun), _paths)

        # Each result is (negative count, positive count, spacing dictionary).
        all_neg, all_pos, space_dicts = zip(*sizes)
        neg, pos = sum(all_neg), sum(all_pos)
        # No negative may be counted at all; report an infinite ratio instead of
        # dividing by zero.
        ratio = pos / neg if neg else float('inf')
        print(f"Ratio between pos/neg: {ratio} ({pos}/{neg})")

        # Merge the per-case dictionaries into a single one.
        final_dict = {k: v for space_dict in space_dicts for k, v in space_dict.items()}

        with open(Path(dest_dir, "spacing.pkl"), 'wb') as f:
            pickle.dump(final_dict, f, pickle.HIGHEST_PROTOCOL)
            print(f"Saved spacing dictionnary to {f}")
def runInference(args: argparse.Namespace):
    """Run a saved network over every PNG of a folder and save the predicted classes.

    The model runs on CUDA when it is available and falls back to the CPU
    otherwise. On the CPU path the weights are remapped while loading.

    Args:
        args: namespace providing ``model_weights``, ``batch_size``,
            ``num_classes``, ``data_folder`` and ``save_folder``.
    """
    print('>>> Loading model')
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        net = torch.load(args.model_weights)
    else:
        # Without this, weights saved from a GPU cannot be deserialised on a CPU-only host.
        net = torch.load(args.model_weights, map_location='cpu')
    device = torch.device("cuda" if cuda_available else "cpu")
    net.to(device)
    # NOTE(review): the network is never switched to eval mode here. Confirm that
    # the saved model does not rely on train-time behaviour (dropout, batch norm).

    print('>>> Loading the data')
    batch_size: int = args.batch_size
    num_classes: int = args.num_classes

    transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],  # add a leading axis
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=torch.float32)
    ])

    folders: List[Path] = [Path(args.data_folder)]
    names: List[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))
    dt_set = SliceDataset(names,
                          folders,
                          transforms=[transform],
                          debug=False,
                          C=num_classes)
    loader = DataLoader(dt_set,
                        batch_size=batch_size,
                        num_workers=batch_size + 2,
                        shuffle=False,
                        drop_last=False)

    print('>>> Starting the inference')
    savedir: str = args.save_folder
    total_iteration = len(loader)
    desc = ">> Inference"
    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    with torch.no_grad():
        for _, (filenames, image, _) in tq_iter:
            image = image.to(device)

            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(pred_logits, dim=1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")

                predicted_class: Tensor = probs2class(pred_probs)
                save_images(predicted_class, filenames, savedir, "", 0)
def setup(args, n_class, dtype) -> Tuple[Any, Any, Any, List[Callable], List[float], Any, int]:
    """Build the network, optimizer, target losses and scheduler from ``args``.

    Args:
        args: namespace providing ``cpu``, ``n_epoch``, ``model_weights``,
            ``network``, ``saveim``, ``l_rate``, ``weight_decay``, ``adamw``,
            ``target_losses``, ``scheduler`` and ``scheduler_params``.
        n_class: number of output classes given to a freshly built network.
        dtype: tensor type applied to a fresh network and forwarded to every loss.

    Returns:
        ``(net, optimizer, device, loss_fns, loss_weights, scheduler, n_epoch)``.
        ``scheduler`` is the empty string when ``args.scheduler`` is falsy, and
        ``n_epoch`` is forced to 0 when a ``CheckBounds`` call returns a falsy value.
    """
    print(">>> Setting up")
    cpu: bool = args.cpu or not torch.cuda.is_available()
    if cpu:
        print("WARNING CUDA NOT AVAILABLE")
    device = torch.device("cpu") if cpu else torch.device("cuda")
    n_epoch = args.n_epoch

    if args.model_weights:
        if cpu:
            net = torch.load(args.model_weights, map_location='cpu')
        else:
            net = torch.load(args.model_weights)
    else:
        net_class = getattr(__import__('networks'), args.network)
        net = net_class(1, n_class).type(dtype).to(device)
        net.apply(weights_init)
    net.to(device)

    if args.saveim:
        print("WARNING: Saving masks at each epc")

    if args.adamw:
        # NOTE(review): args.weight_decay is not forwarded on this path, so AdamW
        # runs with its library default. Confirm that this is intended.
        optimizer = torch.optim.AdamW(net.parameters(), lr=args.l_rate, betas=(0.9, 0.999))
    else:
        optimizer = torch.optim.Adam(net.parameters(), lr=args.l_rate, betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)

    # NOTE: the loss and scheduler descriptions are Python expressions that are
    # eval()-ed, so they must only ever come from trusted configuration.
    print(args.target_losses)
    losses = eval(args.target_losses)
    loss_fns: List[Callable] = []
    for loss_name, loss_params, _, bounds_params, fn, _ in losses:
        loss_class = getattr(__import__('losses'), loss_name)
        loss_fns.append(loss_class(**loss_params, dtype=dtype, fn=fn))
        print("bounds_params", bounds_params)
        if bounds_params is not None:
            bool_predexist = CheckBounds(**bounds_params)
            print(bool_predexist, "size predictor")
            if not bool_predexist:
                # A failed check disables training altogether.
                n_epoch = 0

    # Item 5 of every spec is its weight.
    loss_weights = map_(itemgetter(5), losses)

    if args.scheduler:
        scheduler = getattr(__import__('scheduler'), args.scheduler)(**eval(args.scheduler_params))
    else:
        scheduler = ''

    return net, optimizer, device, loss_fns, loss_weights, scheduler, n_epoch
def rotate(r_path, grp_regex, s_path):
    """Restack each patient's PNG slices and re-export them along the first axis.

    For every patient the slices are placed in a ``256 x 256 x n`` volume at
    the index parsed from their name. The planes obtained by fixing the first
    axis are then zero-padded by 110 columns on each side and written as
    ``<patient><i>.png``.

    Args:
        r_path: folder (as a string) containing the input ``.png`` slices.
        grp_regex: regular expression applied to each file stem. The whole match
            is the patient key, and element 1 of ``re.split`` on the stem is
            parsed as the slice index.
        s_path: folder (as a string) receiving the output images.
    """
    grouping_regex: Pattern = re.compile(grp_regex)

    # Work on stems so that the regex never sees the extension.
    stems: List[str] = [Path(name).stem for name in os.listdir(r_path) if name.endswith('.png')]
    matches: List[Match] = map_(grouping_regex.match, stems)
    patients: List[str] = [m.group(0) for m in matches]
    unique_patients: List[str] = list(set(patients))

    for patient in unique_patients:
        patient_slices = [stem for stem in stems if stem.startswith(patient)]
        w, h = [256, 256]
        volume = np.empty((w, h, len(patient_slices)))

        for stem in patient_slices:
            slice_nb = int(grouping_regex.split(stem)[1])
            volume[:, :, slice_nb] = imageio.imread(r_path + '/' + stem + '.png')

        for i in range(0, h):
            plane = np.pad(volume[i, :, :], [(0, 0), (110, 110)], 'constant')
            imageio.imwrite(s_path + '/' + patient + str(i) + '.png', plane)
def setup(args, n_class: int) -> Tuple[Any, Any, Any, List[Callable], List[float], Callable]:
    """Build the network, Adam optimizer, losses and scheduler from ``args``.

    Args:
        args: namespace providing ``cpu``, ``weights``, ``network``,
            ``modalities``, ``l_rate``, ``losses``, ``scheduler`` and
            ``scheduler_params``.
        n_class: number of output classes given to a freshly built network.

    Returns:
        ``(net, optimizer, device, loss_fns, loss_weights, scheduler)``.
    """
    print(">>> Setting up")
    cpu: bool = args.cpu or not torch.cuda.is_available()
    device = torch.device("cpu" if cpu else "cuda")

    if args.weights:
        # On a CPU run the stored tensors must be remapped to the CPU at load time.
        load_kwargs = {'map_location': 'cpu'} if cpu else {}
        net = torch.load(args.weights, **load_kwargs)
        print(f">>> Restored weights from {args.weights} successfully.")
    else:
        net_class = getattr(__import__('networks'), args.network)
        net = net_class(args.modalities, n_class).to(device)
        net.apply(weights_init)
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.l_rate, betas=(0.9, 0.99), amsgrad=False)

    # NOTE: the loss and scheduler descriptions are Python expressions that are
    # eval()-ed, so they must only ever come from trusted configuration.
    print(args.losses)
    losses = eval(args.losses)
    # Each spec is a 6-tuple; items 0, 1 and 4 are the class name, kwargs and fn.
    loss_fns: List[Callable] = [
        getattr(__import__('losses'), loss_name)(**loss_params, fn=fn)
        for loss_name, loss_params, _, _, fn, _ in losses
    ]
    # Item 5 of every spec is its weight.
    loss_weights = map_(itemgetter(5), losses)

    scheduler_class = getattr(__import__('scheduler'), args.scheduler)
    scheduler = scheduler_class(**eval(args.scheduler_params))

    return net, optimizer, device, loss_fns, loss_weights, scheduler
def rotate_back(r_path, grp_regex, s_path):
    """Restack each patient's PNG slices and re-export them along the second axis.

    For every patient the slices are placed in a ``256 x 36 x n`` volume at the
    index parsed from their name (slices with another shape first go through
    ``resize_im``). The planes obtained by fixing the second axis are then
    written as ``<patient><i>.png``.

    Args:
        r_path: folder containing the input ``.png`` slices.
        grp_regex: regular expression applied to each file stem. The whole match
            is the patient key, and element 1 of ``re.split`` on the stem is
            parsed as the slice index.
        s_path: folder receiving the output images.
    """
    grouping_regex: Pattern = re.compile(grp_regex)

    # Work on stems so that the regex never sees the extension.
    stems: List[str] = [Path(name).stem for name in os.listdir(r_path) if name.endswith('.png')]
    matches: List[Match] = map_(grouping_regex.match, stems)
    patients: List[str] = [m.group(0) for m in matches]
    unique_patients: List[str] = list(set(patients))
    print(unique_patients)

    for patient in unique_patients:
        patient_slices = [stem for stem in stems if stem.startswith(patient)]
        w, h = [256, 36]
        volume = np.empty((w, h, len(patient_slices)))

        for stem in patient_slices:
            slice_nb = int(grouping_regex.split(stem)[1])
            im_or = imageio.imread(f"{r_path}/{stem}.png")
            if im_or.shape != (w, h):
                im_or = resize_im(im_or, 36)
            volume[:, :, slice_nb] = im_or

        for i in range(0, h):
            plane = volume[:, i, :]
            imageio.imwrite(f"{s_path}/{patient}{i}.png", plane)
def save_slices(img_p: Path, gt_p: Path, dest_dir: Path, shape: Tuple[int], n_augment: int,
                img_dir: str = "img", gt_dir: str = "gt") -> Tuple[int, int, int]:
    """Slice a 3D image and its ground truth into 2D PNG files.

    The volume is normalised as a whole, then every slice along axis 0 is
    resized to ``shape`` and saved, followed by ``n_augment`` augmented
    copies. Files are named ``<patient id>_<k>_<slice index>.png``, where
    ``k == 0`` is the non-augmented slice.

    Args:
        img_p: path of the 3D image.
        gt_p: path of the matching 3D ground truth.
        dest_dir: root output folder.
        shape: 2D shape every slice is resized to.
        n_augment: number of augmented copies saved per slice.
        img_dir: sub-folder of ``dest_dir`` for image slices.
        gt_dir: sub-folder of ``dest_dir`` for ground-truth slices.

    Returns:
        ``(total, smallest non-zero, largest)`` of the per-slice count of
        pixels labelled 1 in the resized ground truth.

    Raises:
        AssertionError: if one of the sanity checks on the data fails.
        ValueError: if no slice contains a pixel labelled 1 (the minimum over
            an empty selection is undefined).
    """
    p_id: str = get_p_id(img_p)
    assert "Case" in p_id
    assert p_id == get_p_id(gt_p)

    # Load the data
    img = imread(str(img_p), plugin='simpleitk')
    gt = imread(str(gt_p), plugin='simpleitk')

    assert img.shape == gt.shape
    assert img.dtype in [np.int16]
    assert gt.dtype in [np.int8]

    # Normalize and check data content
    norm_img = norm_arr(img)  # We need to normalize the whole 3d img, not 2d slices
    assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(), norm_img.max())
    assert norm_img.dtype == np.uint8

    save_dir_img: Path = Path(dest_dir, img_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    for save_dir in (save_dir_img, save_dir_gt):
        save_dir.mkdir(parents=True, exist_ok=True)

    # Loop invariants, computed once.
    resize_: Callable = partial(resize, mode="constant", preserve_range=True, anti_aliasing=False)
    gt_labels = set(uniq(gt))

    # One entry per slice along axis 0, the axis the loop below walks.
    # Sizing this by img.shape[-1] overflowed when depth exceeded the last axis.
    sizes_2d: np.ndarray = np.zeros(len(img))
    for j in range(len(img)):
        img_s = norm_img[j, :, :]
        gt_s = gt[j, :, :]
        assert img_s.shape == gt_s.shape

        # Resize and check the data are still what we expect
        r_img: np.ndarray = resize_(img_s, shape).astype(np.uint8)
        r_gt: np.ndarray = resize_(gt_s, shape).astype(np.uint8)
        assert r_img.dtype == r_gt.dtype == np.uint8
        assert 0 <= r_img.min() and r_img.max() <= 255  # The range might be smaller
        assert set(uniq(r_gt)).issubset(gt_labels)
        sizes_2d[j] = r_gt[r_gt == 1].sum()

        for k in range(n_augment + 1):
            if k == 0:
                a_img, a_gt = r_img, r_gt
            else:
                a_img, a_gt = map_(np.asarray, augment(r_img, r_gt))

            filename = f"{p_id}_{k}_{j:02d}.png"
            for save_dir, data in zip([save_dir_img, save_dir_gt], [a_img, a_gt]):
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename)), data)

    return sizes_2d.sum(), sizes_2d[sizes_2d > 0].min(), sizes_2d.max()
def runInference(args: argparse.Namespace):
    """Run a saved network over a folder of PNG images and save its predictions.

    Args:
        args: parsed options. Attributes read here: ``model_weights``,
            ``batch_size``, ``num_classes``, ``data_folder``, ``debug``,
            ``save_folder``, ``mode`` and, depending on the mode,
            ``probs_class`` or ``threshold``.

            ``mode`` selects the output:
              * ``'argmax'``: class map from ``probs2class``;
              * ``'probs'``: probability of ``probs_class`` scaled to uint8;
              * ``'threshold'``: argmax over ``probs > threshold``;
              * ``'softmax'``: raw probabilities saved as ``.npy`` files.

    Raises:
        ValueError: if ``args.mode`` is not one of the modes above.
    """
    supported_modes = ('argmax', 'probs', 'threshold', 'softmax')
    if args.mode not in supported_modes:
        # Fail before any work is done; this used to surface as an unbound
        # `predicted_class` after the first batch had already been processed.
        raise ValueError(f"Unknown mode {args.mode!r}, expected one of {supported_modes}")

    print('>>> Loading model')
    # Fall back to CPU instead of crashing when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = torch.load(args.model_weights, map_location=device)
    net.to(device)
    # NOTE(review): the network is never switched to eval mode here; confirm
    # this is intended before relying on the outputs.

    print('>>> Loading the data')
    batch_size: int = args.batch_size
    num_classes: int = args.num_classes

    folders: list[Path] = [Path(args.data_folder)]
    names: list[str] = map_(lambda p: str(p.name), folders[0].glob("*.png"))
    dt_set = SliceDataset(names,
                          folders * 2,  # Duplicate for compatibility reasons
                          are_hots=[False, False],
                          transforms=[png_transform, dummy_gt_transform],  # So it is happy about the target size
                          bounds_generators=[],
                          debug=args.debug,
                          K=num_classes)
    loader = DataLoader(dt_set,
                        batch_size=batch_size,
                        num_workers=batch_size + 2,
                        shuffle=False,
                        drop_last=False,
                        collate_fn=custom_collate)

    print('>>> Starting the inference')
    savedir: Path = Path(args.save_folder)
    savedir.mkdir(parents=True, exist_ok=True)
    total_iteration = len(loader)
    desc = ">> Inference"
    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    with torch.no_grad():
        for _, data in tq_iter:
            filenames: list[str] = data["filenames"]
            image: Tensor = data["images"].to(device)

            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(pred_logits, dim=1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")

                if args.mode == 'softmax':
                    # Raw probabilities: one .npy file per input image.
                    for i, filename in enumerate(filenames):
                        np.save((savedir / filename).with_suffix(".npy"), pred_probs[i].cpu().numpy())
                    continue

                predicted_class: Tensor
                if args.mode == 'argmax':
                    predicted_class = probs2class(pred_probs)
                elif args.mode == 'probs':
                    predicted_class = (pred_probs[:, args.probs_class, ...] * 255).type(torch.uint8)
                else:  # 'threshold'
                    thresholded: Tensor = pred_probs[:, ...] > args.threshold
                    predicted_class = thresholded.argmax(dim=1)

                save_images(predicted_class, filenames, savedir)
def get_loaders(args, data_folder: str, subfolders: str, batch_size: int, n_class: int,
                debug: bool, in_memory: bool, dtype, shuffle: bool, mode: str,
                val_subfolders: str = "") -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation data loaders.

    Args:
        args: parsed options. Attributes read here: ``target_losses`` or
            ``source_losses``, ``n_class``, ``augment``, ``trainval``,
            ``valonly``, ``ontrain1``, ``ontrain19_1``, ``ontrain9_1``,
            ``ontest`` and ``ontrain``.
        data_folder: root folder containing the split sub-folders
            (``train``, ``trainval``, ``val``, ``test``).
        subfolders: Python expression (evaluated with ``eval``) yielding a
            sequence of ``(folder, transform, is_one_hot)`` triplets.
        batch_size: batch size of both loaders.
        n_class: number of classes.
        debug: forwarded to ``SliceDataset``.
        in_memory: forwarded to ``SliceDataset``.
        dtype: tensor dtype produced by the image transforms.
        shuffle: whether the training loader shuffles.
        mode: ``"target"`` selects ``args.target_losses``; anything else
            selects ``args.source_losses``.
        val_subfolders: same format as ``subfolders``, for the validation
            set. Empty (the default) means "reuse ``subfolders``".

    Returns:
        ``(train_loader, val_loader)``.
    """
    # NOTE: none of the transforms below is used directly in this function.
    # They are presumably referenced by name from the `subfolders` /
    # `val_subfolders` expressions evaluated further down, so do not rename or
    # remove them without checking the callers.
    nii_transform2 = transforms.Compose([
        lambda nd: torch.tensor(nd, dtype=torch.float32),
        lambda nd: nd[:, 0:384, 0:384],
    ])
    nii_gt_transform2 = transforms.Compose([
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=n_class),
        lambda nd: nd[:, :, 0:384, 0:384],
        itemgetter(0),
    ])
    nii_transform = transforms.Compose([
        lambda nd: torch.tensor(nd, dtype=torch.float32),
        lambda nd: (nd + 4) / 8.5,  # max <= 1
    ])
    nii_gt_transform = transforms.Compose([
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=n_class),
        itemgetter(0),
    ])
    png_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=dtype)
    ])
    imnpy_transform = transforms.Compose([
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=dtype)
    ])
    npy_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: nd / 255,  # max <= 1
        lambda nd: torch.tensor(nd, dtype=dtype)
    ])
    gtnpy_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=n_class),
        itemgetter(0)
    ])
    gt_transform = transforms.Compose([
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=n_class),
        itemgetter(0),
    ])
    gtpng_transform = transforms.Compose([
        lambda img: np.array(img)[np.newaxis, ...],
        lambda nd: torch.tensor(nd, dtype=torch.int64),
        partial(class2one_hot, C=n_class),
        itemgetter(0),
    ])

    # NOTE(review): `eval` executes these command-line strings as Python code;
    # only run this with trusted arguments.
    losses = eval(args.target_losses) if mode == "target" else eval(args.source_losses)

    bounds_generators: List[Callable] = []
    for _, _, bounds_name, bounds_params, fn, _ in losses:
        if bounds_name is None:
            # No bounds for this loss: placeholder generator returning zeros.
            bounds_generators.append(lambda *a: torch.zeros(n_class, 1, 2))
            continue

        bounds_class = getattr(__import__('bounds'), bounds_name)
        bounds_generators.append(bounds_class(C=args.n_class, fn=fn, **bounds_params))

    folders_list = eval(subfolders)
    val_folders_list = eval(val_subfolders) if val_subfolders != "" else eval(subfolders)

    folders, trans, are_hots = zip(*folders_list)
    valfolders, val_trans, val_are_hots = zip(*val_folders_list)

    # Create partial functions: Easier for readability later (see the difference between train and validation)
    gen_dataset = partial(SliceDataset,
                          transforms=trans,
                          are_hots=are_hots,
                          debug=debug,
                          C=n_class,
                          in_memory=in_memory,
                          augment=args.augment,
                          bounds_generators=bounds_generators)
    valgen_dataset = partial(SliceDataset,
                             transforms=val_trans,
                             are_hots=val_are_hots,
                             debug=debug,
                             C=n_class,
                             in_memory=in_memory,
                             augment=args.augment,
                             bounds_generators=bounds_generators)
    data_loader = partial(DataLoader,
                          num_workers=4,
                          pin_memory=True)

    def names_matching(folder: Path, pattern: str) -> List[str]:
        """Return the names of the files in `folder` matching `pattern`."""
        return map_(lambda p: str(p.name), folder.glob(pattern))

    # Prepare the datasets and dataloaders
    if args.trainval:
        train_split = "trainval"
    elif args.valonly:
        train_split = "val"
    else:
        train_split = "train"
    train_folders: List[Path] = [Path(data_folder, train_split, f) for f in folders]

    # I assume all files have the same name inside their folder: makes things much easier
    train_names: List[str] = names_matching(train_folders[0], "*.png")
    if not train_names:
        train_names = names_matching(train_folders[0], "*.nii")
    # Optional restrictions to a subset of the .nii volumes, selected by suffix.
    if args.ontrain1:
        train_names = names_matching(train_folders[0], "*1.nii")
    if mode == "target" and args.ontrain19_1:
        train_names = names_matching(train_folders[0], "*19_1.nii")
    if mode == "target" and args.ontrain9_1:
        train_names = names_matching(train_folders[0], "*9_1.nii")
    if not train_names:
        train_names = names_matching(train_folders[0], "*.npy")

    train_set = gen_dataset(train_names, train_folders)
    train_loader = data_loader(train_set,
                               batch_size=batch_size,
                               shuffle=shuffle,
                               drop_last=False)

    if args.ontest:
        print('on test')
        val_split = "test"
    elif args.ontrain:
        print('on train')
        val_split = "train"
    else:
        val_split = "val"
    val_folders: List[Path] = [Path(data_folder, val_split, f) for f in valfolders]

    # Same extension fallback as for training: png, then nii, then npy.
    val_names: List[str] = names_matching(val_folders[0], "*.png")
    if not val_names:
        val_names = names_matching(val_folders[0], "*.nii")
    if not val_names:
        val_names = names_matching(val_folders[0], "*.npy")

    val_set = valgen_dataset(val_names, val_folders)
    val_loader = data_loader(val_set,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=False)

    return train_loader, val_loader
def run(args: argparse.Namespace) -> None:
    """Plot per-epoch metric curves loaded from one ``.npy`` file per folder.

    Each array is loaded from ``<folder>/<args.filename>``; 2D arrays get a
    trailing axis so every array is 3D. Curves are the mean over axis 1,
    one curve per requested column, plus an optional mean across columns.

    Args:
        args: parsed options. Attributes read here: ``fontsize``, ``colors``,
            ``folders``, ``columns``, ``filename``, ``dynamic_third_axis``,
            ``figsize``, ``ylim``, ``ylabel``, ``title``, ``labels``,
            ``no_mean``, ``only_mean``, ``smooth``, ``curves_styles``,
            ``min``, ``hline``, ``l_line``, ``loc``, ``savefig`` and
            ``headless``.

    Raises:
        AssertionError: if more columns than line styles are requested, or if
            the arrays disagree on their number of epochs (or on their last
            axis, unless ``dynamic_third_axis`` is set).
    """
    plt.rc('font', size=args.fontsize)

    colors: List[str] = args.colors if args.colors else util_colors
    styles = ['--', '-.', ':']
    if len(args.folders) > len(colors):
        print("Warning: more folders than colors")
    assert len(args.columns) <= len(styles)

    paths: List[Path] = [Path(f, args.filename) for f in args.folders]
    arrays: List[np.ndarray] = map_(np.load, paths)
    if len(arrays[0].shape) == 2:
        arrays = map_(lambda a: a[..., np.newaxis], arrays)
    n_epoch, _, class_ = arrays[0].shape
    for a in arrays[1:]:
        ea, _, ca = a.shape
        assert n_epoch == ea, (n_epoch, class_, a.shape)
        if not args.dynamic_third_axis:  # Useful for when trainings don't have same number of losses
            assert class_ == ca, (n_epoch, class_, a.shape)

    fig = plt.figure(figsize=args.figsize)
    ax = fig.gca()
    ax.set_xlim([0, n_epoch - 2])

    ymin, ymax = args.ylim  # Tuple[int, int]
    ax.set_ylim(ymin, ymax)
    yrange: int = ymax - ymin
    ystep: float = yrange / 10
    ax.set_yticks(np.mgrid[ymin:ymax + ystep:ystep])

    ax.set_xlabel("Epoch")
    ax.set_ylabel(args.ylabel if args.ylabel else Path(args.filename).stem)
    ax.grid(True, axis='y')
    ax.set_title(args.title if args.title else f"{paths[0].stem} over epochs")

    labels = args.labels if args.labels else [p.parent.name for p in paths]

    xnew = np.linspace(0, n_epoch, n_epoch * 4)
    epcs = np.arange(n_epoch)
    # A cubic interpolant needs at least 4 samples; fall back to linear below that.
    smooth_kind = "cubic" if n_epoch >= 4 else "linear"

    for i, (a, c, p, label) in enumerate(zip(arrays, cycle(colors), paths, labels)):
        mean_a = a.mean(axis=1)
        _, n_col = mean_a.shape

        # For when more args.columns than columns (weird case with varying multiple losses)
        allowed_cols: List[int] = sorted(set(args.columns).intersection(range(n_col)))

        if len(allowed_cols) > 1 and not args.no_mean:
            mean_column = mean_a[:, allowed_cols].mean(axis=1)
            ax.plot(epcs, mean_column, color=c, linestyle='-', label=f"{label}-mean", linewidth=2)

        if args.only_mean:
            continue

        for k, s in zip(allowed_cols, styles):
            values = mean_a[..., k]

            if args.smooth:
                # interp1d returns a callable; its third positional argument is
                # `kind`, not the query points. `xnew` extends one step past
                # the last epoch, hence the extrapolation.
                smoother = interp1d(epcs, values, kind=smooth_kind, fill_value="extrapolate")
                x, y = xnew, smoother(xnew)
            else:
                x, y = epcs, values

            single_column = len(args.columns) == 1
            lab = label if single_column else f"{label}-{k}"
            sty: str
            if single_column:
                if args.curves_styles:
                    sty = args.curves_styles[i][1:]  # Have to remove the extra space
                else:
                    sty = '-'
            else:
                sty = s
            ax.plot(x, y, linestyle=sty, color=c, label=lab, linewidth=1.5)

            if args.min:
                print(f"{Path(p).parents[0]}, class {k}: {values.min():.04f}")
            else:
                print(f"{Path(p).parents[0]}, class {k}: {values.max():.04f}")

    if args.hline:
        for v, line_label, s in zip(args.hline, args.l_line, styles):
            ax.plot([0, n_epoch], [v, v], linestyle=s, linewidth=1, color='green', label=line_label)

    ax.legend(loc=args.loc)

    fig.tight_layout()
    if args.savefig:
        fig.savefig(args.savefig)

    if not args.headless:
        plt.show()
def get_loaders(args, data_folder: str, weak_subfolder: str, dict_params: Dict, batch_size: int, n_class: int,
                debug: bool, in_memory: bool, dtype) -> Tuple[DataLoader, DataLoader]:
    """Create the training and validation loaders for img / gt / weak-label folders.

    Args:
        args: parsed options; ``losses``, ``n_class`` and ``grp_regex`` are
            read here.
        data_folder: root folder holding the ``train`` and ``val`` splits.
        weak_subfolder: name of the third sub-folder read next to ``img``
            and ``gt`` in each split.
        dict_params: forwarded to ``SliceDataset``.
        batch_size: batch size of the training loader.
        n_class: number of classes.
        debug: forwarded to ``SliceDataset``.
        in_memory: forwarded to ``SliceDataset``.
        dtype: tensor dtype produced by the image transform.

    Returns:
        ``(train_loader, val_loader)``. The validation loader is batched by
        a ``PatientSampler`` built from ``args.grp_regex``.
    """
    transform = transforms.Compose([
        lambda pil_img: np.array(pil_img)[np.newaxis, ...],
        lambda arr: arr / 255,  # max <= 1
        lambda arr: torch.tensor(arr, dtype=dtype),
    ])
    gt_transform = transforms.Compose([
        lambda pil_img: np.array(pil_img)[np.newaxis, ...],
        lambda arr: torch.tensor(arr, dtype=torch.int64),
        partial(class2one_hot, C=n_class),
        itemgetter(0),
    ])

    # NOTE(review): `eval` executes the command-line string as Python code.
    losses = eval(args.losses)

    # One bounds generator per loss; losses without bounds get a zero placeholder.
    bounds_generators: List[Callable] = []
    for _, _, bounds_name, bounds_params, fn, _ in losses:
        if bounds_name is not None:
            bounds_class = getattr(__import__('bounds'), bounds_name)
            generator = bounds_class(C=args.n_class, fn=fn, **bounds_params)
        else:
            generator = lambda *a: torch.zeros(n_class, 1, 2)
        bounds_generators.append(generator)

    def split_folders(split: str) -> List[Path]:
        """Folders of one split, in the order expected by the transforms."""
        return [Path(data_folder, split, sub) for sub in ("img", "gt", weak_subfolder)]

    def png_names(folder: Path) -> List[str]:
        """Names of the PNG files found in `folder`."""
        return map_(lambda p: str(p.name), folder.glob("*.png"))

    # Shared constructors for both splits.
    make_dataset = partial(SliceDataset,
                           transforms=[transform, gt_transform, gt_transform],
                           debug=debug,
                           C=n_class,
                           dict_params=dict_params,
                           in_memory=in_memory,
                           bounds_generators=bounds_generators)
    make_loader = partial(DataLoader, num_workers=10, pin_memory=True)

    # Training split. All files are assumed to share the same name across the
    # folders of a split, so names are listed from the first folder only.
    folders_train: List[Path] = split_folders("train")
    train_set = make_dataset(png_names(folders_train[0]), folders_train)
    train_loader = make_loader(train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               drop_last=True)

    # Validation split, batched through the patient sampler.
    folders_val: List[Path] = split_folders("val")
    val_set = make_dataset(png_names(folders_val[0]), folders_val)
    val_sampler = PatientSampler(val_set, args.grp_regex, shuffle=False)
    val_loader = make_loader(val_set, batch_sampler=val_sampler)

    return train_loader, val_loader