def transfer_model():
    """Run ``transfer`` on a freshly built depth-30 MSD model with walnut data.

    A single sample, chosen with a fixed random seed, is removed from the
    loaded lists and used for validation; everything that remains is used
    for training. ``set_normalization`` is called with the training loader
    before ``transfer`` is invoked.
    """
    model = MSDRegressionModel(
        c_in=1,
        c_out=1,
        depth=30,
        width=1,
        dilations=[1, 2, 4, 8, 16],
        loss='L2',
    )

    targets, inputs = utils.load_walnut_ds()

    # Fixed seed so the same sample is held out on every run.
    random.seed(0)
    held_out = random.randrange(len(targets))

    # pop() removes the held-out sample in place; what is left in both lists
    # afterwards is the training data.
    val_inputs = [inputs.pop(held_out)]
    val_targets = [targets.pop(held_out)]

    loader_batch = 8
    val_ds = MultiOrbitDataset(val_inputs, val_targets, device='cuda')
    train_ds = MultiOrbitDataset(inputs, targets, device='cuda')
    train_dl = DataLoader(train_ds, batch_size=loader_batch, shuffle=True)
    val_dl = DataLoader(
        val_ds,
        batch_size=loader_batch,
        sampler=ValSampler(len(val_ds)),
    )

    model.set_normalization(train_dl)
    transfer(model, (train_dl, val_dl))
def svcca_over_training():
    """Evolution of SVCCA similarity matrix over training.

    For every network depth in the (currently single-element) depth list, a
    reference model is loaded from the first ``best_*`` checkpoint found under
    ``MSD_phantoms/MSD_d{depth}_P_scratch_CV01_*`` and compared against four
    checkpoints labelled INIT, FIRST, BEST and LAST. Each comparison yields
    one SVCCA matrix, drawn side by side in a 1x4 figure that is saved to
    ``outputs/svcca_training_MSD_d{depth}_transfer_CV01.png``.

    Raises:
        IndexError: if one of the checkpoint glob patterns matches no file.
    """
    target_ims, input_ims = utils.load_phantom_ds('PhantomsRadialNoisy')
    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False, vert_sym=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))
    patches = get_patches(ds, 50)

    def sorted_checkpoints(pattern):
        """Return the paths under ``models_dir`` matching *pattern*, naturally sorted."""
        return sorted(models_dir.glob(pattern), key=_nat_sort)

    x_labels = ['INIT', 'FIRST', 'BEST', 'LAST']
    panel_tags = ['(a)', '(b)', '(c)', '(d)']

    for depth in [80]:
        model_params = {
            'c_in': 1, 'c_out': 1, 'width': 1, 'depth': depth,
            'dilations': [1, 2, 4, 8, 16],
        }
        fig, axes = plt.subplots(1, 4, figsize=(22, 5))

        ref_model = MSDRegressionModel(**model_params)
        comp_model = MSDRegressionModel(**model_params)
        ref_model.set_normalization(norm_dl)
        comp_model.set_normalization(norm_dl)

        baseline_best = sorted_checkpoints(f'MSD_baseline/MSD_d{depth}_W_CV01_*/best_*.h5')
        scratch_best = sorted_checkpoints(f'MSD_phantoms/MSD_d{depth}_P_scratch_CV01_*/best_*.h5')
        # Globbed and sorted once; both the first and the last entry are used.
        scratch_all = sorted_checkpoints(f'MSD_phantoms/MSD_d{depth}_P_scratch_CV01_*/model_*.h5')

        init_state_dict = torch.load(baseline_best[0], map_location='cpu')
        first_state_dict = torch.load(scratch_all[0], map_location='cpu')
        best_state_dict = torch.load(scratch_best[0], map_location='cpu')
        last_state_dict = torch.load(scratch_all[-1], map_location='cpu')

        ref_model.msd.load_state_dict(best_state_dict)

        # Layer indices follow the network depth instead of a hard-coded 81.
        layer_ids = range(1, depth + 1)

        stages = [init_state_dict, first_state_dict, best_state_dict, last_state_dict]
        for i, state_dict in enumerate(stages):
            comp_model.msd.load_state_dict(state_dict)
            reps = get_model_representation(
                [ref_model, comp_model], patches,
                [layer_ids, layer_ids], sample_rate=10)
            mat = axes[i].matshow(get_svcca_matrix(*reps), vmin=.2, vmax=1)
            if i == 0:
                # Only the left-most panel names the reference checkpoint.
                axes[i].set_ylabel('BEST')
            axes[i].set_xlabel(x_labels[i])
            axes[i].set_title(panel_tags[i], loc='left')

        for ax in axes:
            ax.set_xticks([])
            ax.yaxis.tick_right()

        # The colorbar gets its own axes and uses the last drawn matrix; all
        # panels share the same vmin/vmax.
        cax = fig.add_axes([.931, 0.1, .02, 0.8])
        plt.colorbar(mat, cax=cax)
        plt.savefig(f'outputs/svcca_training_MSD_d{depth}_transfer_CV01.png', bbox_inches='tight')
def train_model_CV():
    """Train one depth-30 MSD model per selected cross-validation split.

    Splits come from ``utils.split_data_CV``; only the splits at positions
    0, 2 and 4 are trained, every other split is skipped. When ``MODEL_NAME``
    is set, each run receives its own timestamped ``save_folder``.
    """
    model_params = dict(
        c_in=1,
        c_out=1,
        depth=30,
        width=1,
        dilations=[1, 2, 4, 8, 16],
        loss='L2',
    )
    batch_size = 32
    # Split positions that are actually trained (even positions below 6).
    selected_splits = (0, 2, 4)

    transforms = TransformList([RandomHFlip()])
    target_ims, input_ims = utils.load_phantom_ds(
        folder_path='PhantomsRadialNoisy/')
    splits = utils.split_data_CV(input_ims, target_ims, frac=(1 / 7, 2 / 7))

    for split_idx, (_, val_set, train_set) in enumerate(splits):
        if split_idx not in selected_splits:
            continue

        model = MSDRegressionModel(**model_params)

        train_ds = MultiOrbitDataset(*train_set, device='cuda', transforms=transforms)
        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_ds = MultiOrbitDataset(*val_set, device='cuda')
        val_dl = DataLoader(val_ds, batch_size=batch_size,
                            sampler=ValSampler(len(val_ds)))

        train_params = dict(epochs=100, lr=2e-3, regularization=None, cutoff_epoch=1)
        if MODEL_NAME is not None:
            stamp = datetime.now().strftime('%m%d%H%M%S')
            train_params['save_folder'] = (
                f"model_weights/{MODEL_NAME}_CV{split_idx + 1:0>2d}_{stamp}")

        # The normalization loader is sampled from the training set and is
        # capped at 5000 samples.
        n_norm_samples = min(len(train_ds), 5000)
        norm_dl = DataLoader(train_ds, batch_size=100,
                             sampler=ValSampler(len(train_ds), n_norm_samples))
        model.set_normalization(norm_dl)

        train(model, (train_dl, val_dl), nn.MSELoss(), **train_params)
def sub_model_hypothesis():
    """Plot PWCCA distances between three loaded models for several widths.

    For each width in ``[8, 16, 32, 64]`` three models are built, normalized
    and loaded from the checkpoints of CV runs 1, 3 and 5. The PWCCA distances
    of the first model to the other two are drawn as a mean curve with a
    min/max band. The figure is saved to ``outputs/sub_model_UNet.png``.

    The names ``ds``, ``patches``, ``colors`` and ``models_dir`` are not
    defined in this function and must be available from an enclosing scope.

    Raises:
        IndexError: if a checkpoint glob pattern matches no file.
    """
    depth = 1
    cv_runs = [1, 3, 5]
    fig, ax = plt.subplots()

    for i, width in enumerate([8, 16, 32, 64]):
        model_params = {
            'c_in': 1, 'c_out': 1, 'width': width, 'depth': depth,
            'dilations': [1, 2, 4, 8, 16],
        }
        # NOTE(review): the models are MSDRegressionModel instances, but the
        # checkpoints and plot labels are named UNet_* — confirm the class.
        models = [MSDRegressionModel(**model_params) for _ in cv_runs]

        for model in models:
            model.set_normalization(ds)

        for model, cv in zip(models, cv_runs):
            checkpoint = sorted(
                models_dir.glob(
                    f'UNet_phantoms/UNet_f{width}_P_transfer_CV01_CV0{cv}_*/best_*.h5'),
                key=_nat_sort)[0]
            model.msd.load_state_dict(torch.load(checkpoint, map_location='cpu'))

        layers = [range(1, depth + 1) for _ in models]
        reps = get_model_representation(models, patches, layers, sample_rate=10)

        dists = [get_pwcca_dist(reps[0], reps[1]), get_pwcca_dist(reps[0], reps[2])]
        # Convert once and reuse for the mean curve and the min/max band.
        dist_arr = np.array(dists)
        ax.plot(dist_arr.mean(0), c=colors[i], label=f'UNet_f{width}')
        ax.fill_between(range(len(dists[0])), dist_arr.min(0), dist_arr.max(0),
                        color=colors[i], alpha=.5)

    ax.set_ylabel('PWCCA distance')
    plt.legend()
    plt.savefig('outputs/sub_model_UNet.png')
def train_model():
    """Train a depth-30 MSD regression model on the noisy radial phantoms.

    The data is divided by ``split_data`` into test, validation and training
    parts; only the latter two are used here. When ``MODEL_NAME`` is set, a
    timestamped ``save_folder`` is passed on to ``train``.
    """
    batch_size = 32
    model = MSDRegressionModel(
        c_in=1,
        c_out=1,
        depth=30,
        width=1,
        dilations=[1, 2, 4, 8, 16],
        loss='L2',
    )

    target_ims, input_ims = utils.load_phantom_ds('PhantomsRadialNoisy/')
    # The test part of the split is not used by this routine.
    _test_set, val_set, train_set = split_data(input_ims, target_ims, 1 / 7)

    train_ds = MultiOrbitDataset(*train_set, device='cuda', transforms=RandomHFlip())
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = MultiOrbitDataset(*val_set, device='cuda')
    val_dl = DataLoader(val_ds, batch_size=batch_size,
                        sampler=ValSampler(len(val_ds)))

    train_params = dict(epochs=30, lr=2e-3, regularization=None, cutoff_epoch=5)
    if MODEL_NAME is not None:
        stamp = datetime.now().strftime('%m%d%H%M%S')
        train_params['save_folder'] = f"model_weights/{MODEL_NAME}_{stamp}"

    # The normalization loader is sampled from the training set and is
    # capped at 5000 samples.
    n_norm_samples = min(len(train_ds), 5000)
    norm_dl = DataLoader(train_ds, batch_size=100,
                         sampler=ValSampler(len(train_ds), n_norm_samples))
    model.set_normalization(norm_dl)

    train(model, (train_dl, val_dl), nn.MSELoss(), **train_params)
def get_models():
    """Build every model variant used in tests, together with its name.

    Returns:
        A pair ``(models, model_names)``: ``models`` is a tuple holding the
        MSD models (depth 30 and 80) followed by the U-Net models (width 8,
        16, 32 and 64); ``model_names`` is a list of matching names in the
        same order.
    """
    msd_depths = [30, 80]
    unet_widths = [8, 16, 32, 64]

    # Arguments common to both architectures; the same dilations list object
    # is handed to every constructor.
    shared = {'c_in': 1, 'c_out': 1, 'dilations': [1, 2, 4, 8, 16]}

    msd_models = tuple(
        MSDRegressionModel(depth=d, width=1, **shared) for d in msd_depths
    )
    unet_models = tuple(
        UNetRegressionModel(width=w, **shared) for w in unet_widths
    )

    model_names = [f'MSD_d{d}' for d in msd_depths]
    model_names.extend(f'UNet_f{w}' for w in unet_widths)

    return msd_models + unet_models, model_names
def train_model(seed=0):
    """Train a depth-30 MSD regression model on the walnut dataset.

    Args:
        seed: Seed passed to ``random.seed`` before the validation sample is
            drawn, so the same seed always holds out the same sample.

    One sample is removed from the loaded lists for validation; the rest is
    used for training. Training runs through ``train`` with an MSE loss, the
    positional value 20, no regularization and a learning rate of 2e-3.
    """
    # NOTE(review): unlike the other training routines in this file, no
    # set_normalization call is made here — confirm that this is intended.
    model = MSDRegressionModel(
        c_in=1,
        c_out=1,
        depth=30,
        width=1,
        dilations=[1, 2, 4, 8, 16],
        loss='L2',
    )
    regularization = None

    targets, inputs = utils.load_walnut_ds()

    random.seed(seed)
    held_out = random.randrange(len(targets))
    print(f"Using sample {held_out} as validation")

    # pop() shrinks both lists in place, leaving only the training samples.
    val_inputs = [inputs.pop(held_out)]
    val_targets = [targets.pop(held_out)]

    batch_size = 32
    train_ds = MultiOrbitDataset(inputs, targets, device='cuda')
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = MultiOrbitDataset(val_inputs, val_targets, device='cuda')
    val_dl = DataLoader(val_ds, batch_size=batch_size,
                        sampler=ValSampler(len(val_ds)))

    # Hook for optional keyword arguments to train(); currently empty.
    extra_train_kwargs = {}
    train(model, (train_dl, val_dl), nn.MSELoss(), 20, regularization,
          lr=2e-3, **extra_train_kwargs)
def test_model():
    """Evaluate the depth-80 and depth-30 MSD checkpoints on one walnut.

    For each depth the last ``best*.h5`` file (natural sort order) of its
    run folder is loaded. The sample drawn with seed 0 is used as the test
    set; both models are normalized with it and then passed to ``test``.
    Mean and standard deviation of MSE, SSIM and DSC are printed per model.

    The function ends the interpreter with ``sys.exit()`` after printing, so
    it never returns to its caller.

    Raises:
        IndexError: if a checkpoint glob pattern matches no file.
        SystemExit: always, once both reports have been printed.
    """
    # (depth, checkpoint glob) in the order the models are built and reported.
    runs = [
        (80, 'model_weights/MSD_d80_walnuts_finetuned_1114125135/best*.h5'),
        (30, 'model_weights/MSD_d30_walnuts1113135028/best*.h5'),
    ]

    models = []
    for depth, pattern in runs:
        model = MSDRegressionModel(
            c_in=1, c_out=1, depth=depth, width=1,
            dilations=[1, 2, 4, 8, 16], loss='L2')
        state_dicts = sorted(glob.glob(pattern), key=_nat_sort)
        model.msd.load_state_dict(torch.load(state_dicts[-1]))
        models.append((depth, model))

    agd_ims, fdk_ims = utils.load_walnut_ds()
    random.seed(0)
    test_id = random.randrange(len(agd_ims))
    input_te, target_te = [fdk_ims.pop(test_id)], [agd_ims.pop(test_id)]

    te_ds = MultiOrbitDataset(input_te, target_te, data_augmentation=False)
    te_dl = DataLoader(te_ds, batch_size=8, sampler=ValSampler(len(te_ds)))

    for _, model in models:
        model.set_normalization(te_dl)

    for depth, model in models:
        mean, std = test(model, te_ds)
        print(
            f"Model d{depth} \n\tMSE: {mean[0]:.4e} +-{std[0]:.4e}, \n\tSSIM: {mean[1]:.4f} +-{std[1]:.4e}, \n\tDSC: {mean[2]:.4f} +-{std[2]:.4e}"
        )

    sys.exit()
# Create an additional list in order to process 2D slices for evaluation. # This list is necessary to remember which slices correspond to each walnut. test_ds.append(ImageDataset(inp_imgs, tgt_imgs)) test_size += len(ImageDataset(inp_imgs, tgt_imgs)) print('Test set size', str(len(ImageDataset(inp_imgs, tgt_imgs)))) ######################################################################################### # Create Model # ######################################################################################### # Create MS-D Net if architecture== 'msd': model = MSDRegressionModel(in_channels, out_channels, depth, width, dilations = dilations, loss = loss_f, parallel=True) #find another way to print model print(model) # Create U-Net elif architecture== 'unet': model = UNetRegressionModel(run_network_path, in_channels, out_channels, depth, width, loss_function=loss_f, lr=lr, opt=opt, dilation=dilations, reflect=True, conv3d=False) elif architecture== 'unet_jordi': model = UNetRegressionModelJordi(run_network_path, in_channels, out_channels, depth, width, loss_function=loss_f, dilation=dilations, reflect=True, conv3d=False) ######################################################################################### # Train Model # ######################################################################################### if train==True: print('Training model..') # Define dataloaders