def train_model_CV():
    model_params = {'c_in': 1, 'c_out': 1, 'depth': 30, 'width': 1,
                    'dilations': [1, 2, 4, 8, 16], 'loss': 'L2'}
    batch_size = 32
    transforms = TransformList([RandomHFlip()])

    target_ims, input_ims = utils.load_phantom_ds(folder_path='PhantomsRadialNoisy/')
    cv_split_generator = utils.split_data_CV(input_ims, target_ims, frac=(1/7, 2/7))

    for i, (_, val_set, train_set) in enumerate(cv_split_generator):
        # Only train every other fold (i = 0, 2, 4, i.e. CV01, CV03, CV05)
        if i % 2 or i // 2 not in [0, 1, 2]:
            continue

        model = MSDRegressionModel(**model_params)

        train_ds = MultiOrbitDataset(*train_set, device='cuda', transforms=transforms)
        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_ds = MultiOrbitDataset(*val_set, device='cuda')
        val_dl = DataLoader(val_ds, batch_size=batch_size,
                            sampler=ValSampler(len(val_ds)))

        train_params = {'epochs': 100, 'lr': 2e-3,
                        'regularization': None, 'cutoff_epoch': 1}
        if MODEL_NAME is not None:
            train_params['save_folder'] = \
                f"model_weights/{MODEL_NAME}_CV{i+1:0>2d}_{datetime.now().strftime('%m%d%H%M%S')}"

        # Normalization statistics from (at most 5000) training samples
        model.set_normalization(
            DataLoader(train_ds, batch_size=100,
                       sampler=ValSampler(len(train_ds), min(len(train_ds), 5000))))

        train(model, (train_dl, val_dl), nn.MSELoss(), **train_params)
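
# For reference, a minimal sketch of what `utils.split_data_CV` is assumed to
# do (hypothetical, not the actual utils implementation): rotate a held-out
# test/validation window through the data and yield (test, val, train) tuples,
# each a pair of (input, target) lists so callers can unpack them with `*`.
def _split_data_CV_sketch(input_ims, target_ims, frac=(1/7, 2/7)):
    n = len(input_ims)
    n_te, n_val = int(n * frac[0]), int(n * frac[1])
    pick = lambda idx: ([input_ims[i] for i in idx], [target_ims[i] for i in idx])
    for start in range(0, n - n_te - n_val + 1, n_te):
        te_idx = list(range(start, start + n_te))
        val_idx = list(range(start + n_te, start + n_te + n_val))
        held = set(te_idx) | set(val_idx)
        tr_idx = [i for i in range(n) if i not in held]
        yield pick(te_idx), pick(val_idx), pick(tr_idx)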
def eval_metrics_CV():
    """Computes metrics over different inits with cross-validation; creates
    box plots, useful for checking generalization/robustness."""
    target_ims, input_ims = utils.load_phantom_ds(folder_path='PhantomsRadialNoisy/')
    # target_ims, input_ims = utils.load_walnut_ds()

    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    models_dir = Path('/media/beta/florian/model_weights')
    metrics = ['SSIM', 'DSC', 'PSNR']

    models, model_names = get_models()
    [model.set_normalization(norm_dl) for model in models]
    print(eval_init_metrics(metrics, ds))

    metrics_te = []
    for cv in ['01', '03', '05']:
        task = 'meanVarInit'
        [model.msd.load_state_dict(
            torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/best_*.h5'),
                              key=_nat_sort)[0], map_location='cpu'))
         for model, d in zip(models[:2], [30, 80])]
        [model.msd.load_state_dict(
            torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/best_*.h5'),
                              key=_nat_sort)[0], map_location='cpu'))
         for model, f in zip(models[-4:], [8, 16, 32, 64])]

        metrics_te.append(eval_metrics(models, metrics, ds, 50))

    # Note: `cv` here is the last fold of the loop above ('05')
    plot_metrics_CV(metrics_te, model_names, metrics,
                    ref_metrics=eval_init_metrics(metrics, ds),
                    filename=f'metrics_P_CV{cv}_best.png')
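
# A minimal sketch of the baseline that `eval_init_metrics` is assumed to
# provide (hypothetical implementation, for illustration only): score the
# uncorrected inputs directly against the targets, which gives the dashed
# reference lines drawn in the plots below.
def _eval_init_metrics_sketch(metric_fns, ds):
    # `metric_fns` maps a metric name to a callable(input_im, target_im) -> float
    scores = {name: [] for name in metric_fns}
    for input_im, target_im in ds:
        for name, fn in metric_fns.items():
            scores[name].append(fn(input_im, target_im))
    return [float(np.mean(scores[name])) for name in metric_fns]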
def train_model():
    model_params = {'c_in': 1, 'c_out': 1, 'depth': 30, 'width': 1,
                    'dilations': [1, 2, 4, 8, 16], 'loss': 'L2'}
    batch_size = 32

    model = MSDRegressionModel(**model_params)
    # model = UNetRegressionModel(**model_params)

    target_ims, input_ims = utils.load_phantom_ds('PhantomsRadialNoisy/')
    test_set, val_set, train_set = split_data(input_ims, target_ims, 1/7)

    train_ds = MultiOrbitDataset(*train_set, device='cuda', transforms=RandomHFlip())
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = MultiOrbitDataset(*val_set, device='cuda')
    val_dl = DataLoader(val_ds, batch_size=batch_size, sampler=ValSampler(len(val_ds)))

    train_params = {'epochs': 30, 'lr': 2e-3, 'regularization': None, 'cutoff_epoch': 5}
    if MODEL_NAME is not None:
        train_params['save_folder'] = \
            f"model_weights/{MODEL_NAME}_{datetime.now().strftime('%m%d%H%M%S')}"

    model.set_normalization(
        DataLoader(train_ds, batch_size=100,
                   sampler=ValSampler(len(train_ds), min(len(train_ds), 5000))))

    train(model, (train_dl, val_dl), nn.MSELoss(), **train_params)
def svcca_over_training():
    """Evolution of the SVCCA similarity matrix over training."""
    target_ims, input_ims = utils.load_phantom_ds('PhantomsRadialNoisy')

    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False, vert_sym=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    # `models_dir` was undefined in this function; assumed to be the same
    # checkpoint directory the other evaluation functions use
    models_dir = Path('/media/beta/florian/model_weights')
    patches = get_patches(ds, 50)

    for width in [80]:  # `width` is the MSD depth here
        model_params = {'c_in': 1, 'c_out': 1, 'width': 1, 'depth': width,
                        'dilations': [1, 2, 4, 8, 16]}
        fig, axes = plt.subplots(1, 4, figsize=(22, 5))

        ref_model = MSDRegressionModel(**model_params)
        comp_model = MSDRegressionModel(**model_params)
        ref_model.set_normalization(norm_dl)
        comp_model.set_normalization(norm_dl)

        # Checkpoints: baseline init, first epoch, best epoch, last epoch
        init_state_dict = torch.load(
            sorted(models_dir.glob(f'MSD_baseline/MSD_d{width}_W_CV01_*/best_*.h5'),
                   key=_nat_sort)[0], map_location='cpu')
        first_state_dict = torch.load(
            sorted(models_dir.glob(f'MSD_phantoms/MSD_d{width}_P_scratch_CV01_*/model_*.h5'),
                   key=_nat_sort)[0], map_location='cpu')
        best_state_dict = torch.load(
            sorted(models_dir.glob(f'MSD_phantoms/MSD_d{width}_P_scratch_CV01_*/best_*.h5'),
                   key=_nat_sort)[0], map_location='cpu')
        last_state_dict = torch.load(
            sorted(models_dir.glob(f'MSD_phantoms/MSD_d{width}_P_scratch_CV01_*/model_*.h5'),
                   key=_nat_sort)[-1], map_location='cpu')

        ref_model.msd.load_state_dict(best_state_dict)
        for i, state_dict in enumerate([init_state_dict, first_state_dict,
                                        best_state_dict, last_state_dict]):
            comp_model.msd.load_state_dict(state_dict)
            # if i == 0: shuffle_weights(comp_model.msd)
            # reps = get_model_representation([ref_model, comp_model], patches,
            #     [get_unet_layers(ref_model), get_unet_layers(comp_model)], sample_rate=10)
            reps = get_model_representation([ref_model, comp_model], patches,
                                            [range(1, 81), range(1, 81)], sample_rate=10)

            mat = axes[i].matshow(get_svcca_matrix(*reps), vmin=.2, vmax=1)
            if i == 0:
                axes[i].set_ylabel('BEST')
            axes[i].set_xlabel(['INIT', 'FIRST', 'BEST', 'LAST'][i])
            axes[i].set_title(['(a)', '(b)', '(c)', '(d)'][i], loc='left')

        for ax in axes:
            ax.set_xticks([])
            ax.yaxis.tick_right()

        cax = fig.add_axes([.931, .1, .02, .8])
        plt.colorbar(mat, cax=cax)
        plt.savefig(f'outputs/svcca_training_MSD_d{width}_transfer_CV01.png',
                    bbox_inches='tight')
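
# A minimal, self-contained sketch of the SVCCA score assumed to underlie
# `get_svcca_matrix` (hypothetical implementation for illustration; relies on
# the module-level numpy import): each entry of the plotted matrix would be
# this score for one pair of layer representations X, Y of shape
# (n_samples, n_features).
def _svcca_score_sketch(X, Y, var_kept=0.99):
    X = X - X.mean(0)
    Y = Y - Y.mean(0)

    def _svd_reduce(Z):
        # Keep the top singular directions covering `var_kept` of the variance
        U, s, _ = np.linalg.svd(Z, full_matrices=False)
        k = int(np.searchsorted(np.cumsum(s**2) / np.sum(s**2), var_kept)) + 1
        return U[:, :k] * s[:k]

    Qx, _ = np.linalg.qr(_svd_reduce(X))
    Qy, _ = np.linalg.qr(_svd_reduce(Y))
    # Canonical correlations are the singular values of Qx^T Qy
    return np.linalg.svd(Qx.T @ Qy, compute_uv=False).mean()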
def eval_metrics_samples():
    """Evaluates models for different training setups, here with varying numbers
    of training samples."""
    global colors
    target_ims, input_ims = utils.load_phantom_ds(folder_path='PhantomsRadialNoisy/')
    # target_ims, input_ims = utils.load_walnut_ds()

    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    models_dir = Path('/media/beta/florian/model_weights')
    metrics = ['SSIM', 'DSC', 'PSNR']

    models, model_names = get_models()
    [model.set_normalization(norm_dl) for model in models]

    for cv in ['01', '03', '05']:
        metrics_samples = []
        for n_samples in [16, 128, 1024]:
            task = '_shuffle'
            [model.msd.load_state_dict(
                torch.load(sorted(models_dir.glob(f'models_Pn_CV{cv}/MSD_d{d}_Pn{n_samples}{task}_CV{cv}_*/best_*.h5'),
                                  key=_nat_sort)[0], map_location='cpu'))
             for model, d in zip(models[:2], [30, 80])]
            [model.msd.load_state_dict(
                torch.load(sorted(models_dir.glob(f'models_Pn_CV{cv}/UNet_f{f}_Pn{n_samples}{task}_CV{cv}_*/best_*.h5'),
                                  key=_nat_sort)[0], map_location='cpu'))
             for model, f in zip(models[-4:], [8, 16, 32, 64])]
            metrics_samples.append(eval_metrics(models, metrics, ds))

        # Models trained on the full phantom training set
        task = 'shuffle'
        [model.msd.load_state_dict(
            torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/best_*.h5'),
                              key=_nat_sort)[0], map_location='cpu'))
         for model, d in zip(models[:2], [30, 80])]
        [model.msd.load_state_dict(
            torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/best_*.h5'),
                              key=_nat_sort)[0], map_location='cpu'))
         for model, f in zip(models[-4:], [8, 16, 32, 64])]
        metrics_samples.append(eval_metrics(models, metrics, ds))

        fig, axes = plt.subplots(1, len(metrics), figsize=(len(metrics)*10, 10))
        cmap = cm.get_cmap('coolwarm', 4)(range(4))
        # x positions [0,1, 3,4,5,6] leave a gap at 2 between the two MSD
        # and the four UNet models
        box_plot_handles = [None for _ in range(len(metrics_samples))]
        for j in range(len(metrics_samples)):
            for i, name in enumerate(metrics):
                if i == 1:
                    # Keep one box handle per sample-size setting for the legend
                    box_plot_handles[j] = axes[i].boxplot(
                        metrics_samples[j][:, i].T,
                        positions=np.array([0, 1, 3, 4, 5, 6]) - .21 + j*.14, widths=.1,
                        boxprops={'color': cmap[j], 'linewidth': 2},
                        whiskerprops={'color': cmap[j], 'linewidth': 2},
                        capprops={'color': cmap[j], 'linewidth': 2},
                        medianprops={'color': 'k'}, showfliers=False)["boxes"][0]
                else:
                    axes[i].boxplot(
                        metrics_samples[j][:, i].T,
                        positions=np.array([0, 1, 3, 4, 5, 6]) + (j-1)*.15, widths=.1,
                        boxprops={'color': cmap[j], 'linewidth': 2},
                        whiskerprops={'color': cmap[j], 'linewidth': 2},
                        capprops={'color': cmap[j], 'linewidth': 2},
                        medianprops={'color': 'k'}, showfliers=False)

        # Connect the medians across sample sizes for each model
        metrics_samples = np.array(metrics_samples)
        for i, ax in enumerate(axes):
            for j in range(metrics_samples.shape[1]):
                ax.plot(np.array([-.21, -.07, .07, .21]) + np.array([0, 1, 3, 4, 5, 6])[j] + .05,
                        np.median(metrics_samples[:, j, i], -1), '-k', linewidth=.7, alpha=.7)

        ref_metrics = eval_init_metrics(metrics, ds)
        for i, name in enumerate(metrics):
            axes[i].plot([-.5, len(model_names)+.5], (ref_metrics[i],)*2,
                         '--k', linewidth=4, alpha=.5)
            axes[i].set_title(name)
            axes[i].set_xticks(np.array([0, 1, 3, 4, 5, 6]))
            axes[i].set_xticklabels(model_names, rotation=30)

        axes[1].legend(box_plot_handles, ['16', '128', '1024', f'{4*709} '],
                       loc='lower right')
        plt.savefig(Path('outputs/') / f'metrics_P_CV{cv}_samples_shuffle_best.png')
def eval_metrics_mat():
    """Plots the results of all metrics with respect to different init setups as a matrix."""
    target_ims, input_ims = utils.load_phantom_ds(folder_path='PhantomsRadialNoisy/')

    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    models_dir = Path('/media/beta/florian/model_weights')
    metrics = ['SSIM', 'DSC', 'PSNR']

    models, model_names = get_models()
    [model.set_normalization(norm_dl) for model in models]

    # 15 rows: for each of the 3 metrics, one spacer row followed by 4 init tasks
    metrics_mat, metrics_std_te = np.zeros((15, 6)), np.zeros((15, 6))
    fig, ax = plt.subplots(figsize=(5*6, 5))
    tasks = ['scratch', 'transfer', 'shuffle', 'mean-var']
    ref_metrics = [0.7825, 0.965, 33.1472258]
    print(ref_metrics)

    for i, task in enumerate(['scratch', 'transfer_CV01', 'shuffle', 'meanVarInit']):
        metrics_te = []
        for cv in ['01', '03', '05']:
            [model.msd.load_state_dict(
                torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/best_*.h5'),
                                  key=_nat_sort)[0], map_location='cpu'))
             for model, d in zip(models[:2], [30, 80])]
            [model.msd.load_state_dict(
                torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/best_*.h5'),
                                  key=_nat_sort)[0], map_location='cpu'))
             for model, f in zip(models[-4:], [8, 16, 32, 64])]
            metrics_te.append(eval_metrics(models, metrics, ds, 50))

        metrics_te = np.array(metrics_te)
        for j in range(len(metrics)):
            if j == 2:
                # PSNR is divided by 38 so all metrics share one color scale
                metrics_mat[j*5+i+1] = metrics_te.mean((0, -1))[:, j] / 38
                metrics_std_te[j*5+i+1] = metrics_te.std((0, -1))[:, j] / 38
            else:
                metrics_mat[j*5+i+1] = metrics_te.mean((0, -1))[:, j]
                metrics_std_te[j*5+i+1] = metrics_te.std((0, -1))[:, j]

    # Blank out the spacer rows; compute the mask before overwriting metrics_mat,
    # otherwise the std matrix is compared against NaNs and never masked
    mask = metrics_mat == 0
    metrics_mat[mask] = np.nan
    metrics_std_te[mask] = np.nan

    ax.matshow(metrics_mat, cmap='twilight', aspect=1/6, vmin=.64, vmax=1)
    current_cmap = cm.get_cmap()
    current_cmap.set_bad(color='white')

    for (i, j), z in np.ndenumerate(metrics_mat):
        if i % 5 == 0:
            continue
        if i < 10:
            if z < ref_metrics[i//5]:
                # Worse than the uncorrected baseline
                ax.text(j, i, f'{z:0.3f}', ha='center', va='center', c='red')
            elif z >= metrics_mat[i//5*5+1:i//5*5+5, j].max():
                # Best init setup for this model and metric
                ax.text(j, i, f'{z:0.3f}', ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z > .92)], fontweight='bold')
            else:
                ax.text(j, i, f'{z:0.3f}', ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z > .92)])
        else:
            # PSNR rows: scale back to dB for display
            if z >= metrics_mat[i//5*5+1:i//5*5+5, j].max():
                ax.text(j, i, f'{z*38:0.3f}', ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z*38 > 35)], fontweight='bold')
            else:
                ax.text(j, i, f'{z*38:0.3f}', ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z*38 > 35)])

    ax.xaxis.tick_top()
    ax.set_xticks(range(6))
    fontproperties = {'weight': 'bold', 'size': 10}
    ax.set_xticklabels(model_names, fontproperties)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    ax.yaxis.set_major_locator(FixedLocator([0, 5, 10]))
    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'{metrics[pos]:<7s}'))
    ax.yaxis.set_minor_locator(MultipleLocator(1))
    ax.yaxis.set_minor_formatter(FuncFormatter(lambda x, pos: f'{tasks[pos%4-1]}'))
    plt.savefig('outputs/metrics_table.png', dpi=300, bbox_inches='tight')

    # Second figure: the same table, annotated with mean ± std
    fig, ax = plt.subplots(figsize=(5*6, 5))
    ax.matshow(metrics_mat, cmap='twilight', aspect=1/6, vmin=.64, vmax=1)
    current_cmap = cm.get_cmap()
    current_cmap.set_bad(color='white')

    for (i, j), z in np.ndenumerate(metrics_mat):
        if i % 5 == 0:
            continue
        if i < 10:
            if z < ref_metrics[i//5]:
                ax.text(j, i, f'{z:0.3f} ± {metrics_std_te[i,j]:0.3f}',
                        ha='center', va='center', c='red')
            elif z >= metrics_mat[i//5*5+1:i//5*5+5, j].max():
                ax.text(j, i, f'{z:0.3f} ± {metrics_std_te[i,j]:0.3f}',
                        ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z > .92)], fontweight='bold')
            else:
                ax.text(j, i, f'{z:0.3f} ± {metrics_std_te[i,j]:0.3f}',
                        ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z > .92)])
        else:
            if z >= metrics_mat[i//5*5+1:i//5*5+5, j].max():
                ax.text(j, i, f'{z*38:0.3f} ± {metrics_std_te[i,j]*38:0.3f}',
                        ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z*38 > 35)], fontweight='bold')
            else:
                ax.text(j, i, f'{z*38:0.3f} ± {metrics_std_te[i,j]*38:0.3f}',
                        ha='center', va='center',
                        c=[(0, 0, 0, 1), (1, 1, 1, 1)][int(not z*38 > 35)])

    ax.xaxis.tick_top()
    ax.set_xticks(range(6))
    ax.set_xticklabels(model_names, fontproperties)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    ax.yaxis.set_major_locator(FixedLocator([0, 5, 10]))
    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'{metrics[pos]:<7s}'))
    ax.yaxis.set_minor_locator(MultipleLocator(1))
    ax.yaxis.set_minor_formatter(FuncFormatter(lambda x, pos: f'{tasks[pos%4-1]}'))
    plt.savefig('outputs/metrics_table_std.png', dpi=300, bbox_inches='tight')
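
# Worked check of the shared color scale above: the PSNR reference
# 33.147 / 38 ≈ 0.872 lands inside the matshow range (vmin=.64, vmax=1)
# alongside SSIM (0.7825) and DSC (0.965); the annotations multiply PSNR
# entries by 38 again so the displayed numbers stay in dB.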
def memory_generalization():
    """Generalization of models to the walnut dataset while training on the phantom one."""
    target_ims, input_ims = utils.load_phantom_ds(folder_path='PhantomsRadialNoisy/')
    # target_ims, input_ims = utils.load_walnut_ds()

    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    models_dir = Path('/media/beta/florian/model_weights')
    metrics = ['SSIM', 'DSC', 'PSNR']

    models, model_names = get_models()
    [model.set_normalization(norm_dl) for model in models]

    cv = '01'
    task = 'transfer_CV01'
    msd_cmap = cm.get_cmap('Purples', 4)(range(2, 4, 1))
    unet_cmap = cm.get_cmap('Oranges', 8)(range(4, 8, 1))
    cmap = np.concatenate([msd_cmap, unet_cmap], 0)

    # Evaluate on the phantom dataset at four checkpoints:
    # walnut baseline (init), first epoch, best epoch, last epoch
    metrics_phantoms = np.zeros((len(metrics), len(models), 4))
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_baseline/MSD_d{d}_W_CV01_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_baseline/UNet_f{f}_W_CV01_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_phantoms[..., 0] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_phantoms[..., 1] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_phantoms[..., 2] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[-1], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[-1], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_phantoms[..., 3] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    fig, axes = plt.subplots(1, 3, figsize=(25, 7.5))
    for i, ax in enumerate(axes):
        # Shade the min-max envelope over all models on the phantom data
        ax.fill_between(range(4), metrics_phantoms[i].min(0), metrics_phantoms[i].max(0),
                        color='k', alpha=.2)

    # Re-evaluate the same checkpoints on the walnut dataset
    models, model_names = get_models()
    target_ims, input_ims = utils.load_walnut_ds()
    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    [model.set_normalization(norm_dl) for model in models]

    ref_metrics = eval_init_metrics(metrics, ds)
    for i, (ax, ref) in enumerate(zip(axes, ref_metrics)):
        ax.plot([0, 3], (ref,)*2, c='k', alpha=.7, linestyle='dashed', linewidth=3)

    metrics_walnuts = np.zeros((len(metrics), len(models), 4))
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_baseline/MSD_d{d}_W_CV01_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_baseline/UNet_f{f}_W_CV01_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_walnuts[..., 0] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_walnuts[..., 1] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/best_*.h5'),
                          key=_nat_sort)[0], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_walnuts[..., 2] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{d}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[-1], map_location='cpu'))
     for model, d in zip(models[:2], [30, 80])]
    [model.msd.load_state_dict(
        torch.load(sorted(models_dir.glob(f'UNet_phantoms/UNet_f{f}_P_{task}_CV{cv}_*/model_*.h5'),
                          key=_nat_sort)[-1], map_location='cpu'))
     for model, f in zip(models[-4:], [8, 16, 32, 64])]
    metrics_walnuts[..., 3] = eval_metrics(models, metrics, ds, 4).mean(-1).T

    for i, ax in enumerate(axes):
        for j in range(metrics_phantoms.shape[1]):
            if i == 1:
                ax.plot(metrics_walnuts[i, j], c=cmap[j], linewidth=2,
                        label=['MSD_d30', 'MSD_d80', 'UNet_f8', 'UNet_f16',
                               'UNet_f32', 'UNet_f64'][j])
            else:
                ax.plot(metrics_walnuts[i, j], c=cmap[j], linewidth=2)
        ax.set_title(metrics[i])
        ax.set_xticks([0, 1, 2, 3])
        ax.set_xticklabels(['random init', 'first epoch', 'best epoch', 'last epoch'],
                           rotation=20)

    axes[1].legend(loc='lower right')
    plt.savefig(f'outputs/generalization_{task}_CV{cv}.png', bbox_inches='tight')
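
# The eight nearly identical load blocks above could be collapsed with a small
# helper along these lines (a sketch under this file's assumptions: `torch` and
# `_nat_sort` in scope; `idx` picks the first (0) or last (-1) checkpoint from
# the naturally sorted glob; `{d}`/`{f}` are filled per model size):
def _load_checkpoints_sketch(models, models_dir, msd_glob, unet_glob, idx=0):
    for model, d in zip(models[:2], [30, 80]):
        path = sorted(models_dir.glob(msd_glob.format(d=d)), key=_nat_sort)[idx]
        model.msd.load_state_dict(torch.load(path, map_location='cpu'))
    for model, f in zip(models[-4:], [8, 16, 32, 64]):
        path = sorted(models_dir.glob(unet_glob.format(f=f)), key=_nat_sort)[idx]
        model.msd.load_state_dict(torch.load(path, map_location='cpu'))
# e.g. _load_checkpoints_sketch(models, models_dir,
#          'MSD_phantoms/MSD_d{d}_P_transfer_CV01_CV01_*/best_*.h5',
#          'UNet_phantoms/UNet_f{f}_P_transfer_CV01_CV01_*/best_*.h5')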
def fig_convergence():
    """Studies convergence by observing how models evolve from the first to the last epoch."""
    global colors
    target_ims, input_ims = utils.load_phantom_ds(folder_path='PhantomsRadialNoisy/')

    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    models_dir = Path('/media/beta/florian/model_weights')
    metrics = ['SSIM', 'DSC', 'PSNR']

    models, model_names = get_models()
    [model.set_normalization(norm_dl) for model in models]

    msd_cmap = cm.get_cmap('Purples', 6)(range(3, 6, 1))
    unet_cmap = cm.get_cmap('Oranges', 6)(range(3, 6, 1))

    for cv in ['01', '03', '05']:
        n_samples = 16
        mean_metrics = []
        ref_metrics = eval_init_metrics(metrics, ds)

        for e in range(5):
            fig, axes = plt.subplots(1, len(metrics), figsize=(len(metrics)*7.5, 7.5))
            for i, name in enumerate(metrics):
                axes[i].plot([-.5, len(model_names)+.5], (ref_metrics[i],)*2,
                             '--k', linewidth=3, alpha=.5)
                axes[i].set_title(name)
                axes[i].set_xticks(np.array([0, 1, 3, 4, 5, 6]))
                axes[i].set_xticklabels(model_names, rotation=30)

            for k, task in enumerate(['', '_transfer', '_shuffle']):
                if k == 0:
                    legend_handles = []

                # Load the models' state dicts after epoch e
                [model.msd.load_state_dict(
                    torch.load(sorted(models_dir.glob(f'models_Pn_CV{cv}/MSD_d{d}_Pn{n_samples}{task}_CV{cv}_*/model_*.h5'),
                                      key=_nat_sort)[e], map_location='cpu'))
                 for model, d in zip(models[:2], [30, 80])]
                [model.msd.load_state_dict(
                    torch.load(sorted(models_dir.glob(f'models_Pn_CV{cv}/UNet_f{f}_Pn{n_samples}{task}_CV{cv}_*/model_*.h5'),
                                      key=_nat_sort)[e], map_location='cpu'))
                 for model, f in zip(models[-4:], [8, 16, 32, 64])]

                mean_metrics.append(np.array(eval_metrics(models, metrics, ds)).mean(-1))

                # Draw an arrow from the first-epoch value to the current one
                for j in range(len(models)):
                    for i in range(len(metrics)):
                        cmap = msd_cmap if j < 2 else unet_cmap
                        if e == 0:
                            axes[i].arrow(np.array([0, 1, 3, 4, 5, 6])[j] - .21 + .21*k,
                                          mean_metrics[k][j, i], 0, 0,
                                          length_includes_head=False, width=.19, head_width=.3,
                                          head_length=[.05, .05, 1.5][i], color=cmap[k], alpha=.8)
                        else:
                            axes[i].arrow(np.array([0, 1, 3, 4, 5, 6])[j] - .21 + .21*k,
                                          mean_metrics[k][j, i], 0,
                                          mean_metrics[e*3+k][j, i] - mean_metrics[k][j, i],
                                          length_includes_head=False, width=.19, head_width=.3,
                                          head_length=[.05, .05, 1.5][i], color=cmap[k], alpha=.8)

            cmap = cm.get_cmap('gray', 6)(range(2, -1, -1))
            legend_handles = [axes[2].arrow(0, 0, 0, 0, width=0, head_width=0,
                                            color=cmap[i], alpha=.8) for i in range(3)]
            axes[1].legend(legend_handles, ['scratch', 'transfer', 'shuffle'], loc='lower left')
            for i in range(len(axes)):
                axes[i].set_ylim([[0, 1], [0, 1], [15, 40]][i])
            plt.savefig(f'outputs/metrics_delta_Pn{n_samples}_CV{cv}_{e+1}e.png',
                        bbox_inches='tight')

        # Assemble the per-epoch frames into a GIF
        frames = sorted(Path('outputs/').glob(f'metrics_delta_Pn{n_samples}_CV{cv}_[0-9]e.png'),
                        key=_nat_sort)
        print(frames)
        frames = [imread(frame) for frame in frames]
        mimsave('outputs/convergence.gif', np.array(frames), fps=3)
        # Stop after the first CV fold
        sys.exit()