def transfer_model():
    """Train an MSD model on the walnut dataset via the transfer() routine."""

    model_params = {
        'c_in': 1,
        'c_out': 1,
        'depth': 30,
        'width': 1,
        'dilations': [1, 2, 4, 8, 16],
        'loss': 'L2'
    }
    model = MSDRegressionModel(**model_params)
    # state_dict = torch.load('model_weights/radial_msd_depth80_it5_epoch59_copy.pytorch')['state_dict']
    # model.net.load_state_dict(state_dict)

    agd_ims, fdk_ims = utils.load_walnut_ds()
    random.seed(0)
    val_id = random.randrange(len(agd_ims))
    input_val, target_val = [fdk_ims.pop(val_id)], [agd_ims.pop(val_id)]

    # train_id = random.randrange(len(agd_ims))
    # input_tr, target_tr = [fdk_ims.pop(train_id)], [agd_ims.pop(train_id)]
    input_tr, target_tr = fdk_ims, agd_ims

    val_ds = MultiOrbitDataset(input_val, target_val, device='cuda')
    train_dl = DataLoader(MultiOrbitDataset(input_tr, target_tr,
                                            device='cuda'),
                          batch_size=8,
                          shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=8, sampler=ValSampler(len(val_ds)))

    model.set_normalization(train_dl)

    transfer(model, (train_dl, val_dl))
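
# `ValSampler` is project code used throughout these snippets; presumably it
# draws a fixed random subset of dataset indices so that validation always
# sees the same slices. A minimal sketch of such a sampler (an assumption,
# not the project's implementation):
from torch.utils.data import Sampler

class ValSamplerSketch(Sampler):
    def __init__(self, ds_len, n_samples=None, seed=0):
        # draw a reproducible subset of indices (full permutation by default)
        self.ids = random.Random(seed).sample(
            range(ds_len), n_samples if n_samples is not None else ds_len)

    def __iter__(self):
        return iter(self.ids)

    def __len__(self):
        return len(self.ids)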

# Example 2

def svcca_over_training():
    """Evolution of SVCCA similarity matrix over training"""

    target_ims, input_ims = utils.load_phantom_ds('PhantomsRadialNoisy')

    ds = MultiOrbitDataset(input_ims, target_ims, data_augmentation=False, vert_sym=False)
    norm_dl = DataLoader(ds, batch_size=50, sampler=ValSampler(len(ds), 500))

    patches = get_patches(ds, 50) 

    for depth in [80]:

        model_params = {'c_in': 1, 'c_out': 1, 'width': 1, 'depth': depth, 'dilations': [1, 2, 4, 8, 16]}

        fig, axes = plt.subplots(1,4, figsize=(22,5))

        ref_model, comp_model = MSDRegressionModel(**model_params), MSDRegressionModel(**model_params)
        ref_model.set_normalization(norm_dl)
        comp_model.set_normalization(norm_dl)

        init_state_dict = torch.load(sorted(models_dir.glob(f'MSD_baseline/MSD_d{depth}_W_CV01_*/best_*.h5'), key=_nat_sort)[0], map_location='cpu')
        first_state_dict = torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{depth}_P_scratch_CV01_*/model_*.h5'), key=_nat_sort)[0], map_location='cpu')
        best_state_dict = torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{depth}_P_scratch_CV01_*/best_*.h5'), key=_nat_sort)[0], map_location='cpu')
        last_state_dict = torch.load(sorted(models_dir.glob(f'MSD_phantoms/MSD_d{depth}_P_scratch_CV01_*/model_*.h5'), key=_nat_sort)[-1], map_location='cpu')

        ref_model.msd.load_state_dict(best_state_dict)

        for i, state_dict in enumerate([init_state_dict, first_state_dict, best_state_dict, last_state_dict]):

            comp_model.msd.load_state_dict(state_dict)
            # Optional control for the INIT panel:
            # if i == 0: shuffle_weights(comp_model.msd)

            # reps = get_model_representation([ref_model, comp_model], patches,
            #                                 [get_unet_layers(ref_model), get_unet_layers(comp_model)], sample_rate=10)

            reps = get_model_representation([ref_model, comp_model], patches,
                                            [range(1, depth + 1), range(1, depth + 1)], sample_rate=10)
        
            mat = axes[i].matshow(get_svcca_matrix(*reps), vmin=.2, vmax=1)
            if i == 0: axes[i].set_ylabel('BEST')
            axes[i].set_xlabel(['INIT', 'FIRST', 'BEST', 'LAST'][i])
            axes[i].set_title(['(a)', '(b)', '(c)', '(d)'][i], loc='left')

        # axes[0].set_yticks(rang)
        for ax in axes:
            ax.set_xticks([])
            ax.yaxis.tick_right()
            # ax.set_yticks(range(9))
            # ax.set_yticklabels(unet_layers)

        cax = fig.add_axes([.931, 0.1, .02, 0.8])
        plt.colorbar(mat, cax=cax)
        plt.savefig(f'outputs/svcca_training_MSD_d{depth}_transfer_CV01.png', bbox_inches='tight')
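
# `get_svcca_matrix` is project code; as a rough reference, the pairwise
# SVCCA similarity it plots is presumably computed along these lines
# (a minimal sketch after Raghu et al. 2017, not the project's implementation):
def svcca_similarity(x, y, var_frac=0.99):
    """Mean SVCCA correlation between activation matrices of shape (samples, neurons)."""
    import numpy as np

    def svd_reduce(a):
        a = a - a.mean(0)  # center each neuron's activations
        u, s, _ = np.linalg.svd(a, full_matrices=False)
        k = int(np.searchsorted(np.cumsum(s**2) / np.sum(s**2), var_frac)) + 1
        return u[:, :k] * s[:k]  # keep directions covering var_frac of the variance

    x, y = svd_reduce(x), svd_reduce(y)
    qx, _ = np.linalg.qr(x)  # orthonormal bases of the reduced subspaces
    qy, _ = np.linalg.qr(y)
    # canonical correlations are the singular values of the cross-projection
    return np.linalg.svd(qx.T @ qy, compute_uv=False).mean()
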
def train_model_CV():
    """Train MSD models on cross-validation splits of the noisy phantom dataset."""

    model_params = {
        'c_in': 1,
        'c_out': 1,
        'depth': 30,
        'width': 1,
        'dilations': [1, 2, 4, 8, 16],
        'loss': 'L2'
    }
    batch_size = 32

    transforms = TransformList([RandomHFlip()])

    target_ims, input_ims = utils.load_phantom_ds(
        folder_path='PhantomsRadialNoisy/')
    cv_split_generator = utils.split_data_CV(input_ims,
                                             target_ims,
                                             frac=(1 / 7, 2 / 7))

    for i, (_, val_set, train_set) in enumerate(cv_split_generator):
        # Keep only even-numbered splits 0, 2 and 4 (three CV folds)
        if i % 2 or i // 2 not in [0, 1, 2]:
            continue

        model = MSDRegressionModel(**model_params)

        train_ds = MultiOrbitDataset(*train_set,
                                     device='cuda',
                                     transforms=transforms)
        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_ds = MultiOrbitDataset(*val_set, device='cuda')
        val_dl = DataLoader(val_ds,
                            batch_size=batch_size,
                            sampler=ValSampler(len(val_ds)))

        train_params = {
            'epochs': 100,
            'lr': 2e-3,
            'regularization': None,
            'cutoff_epoch': 1
        }
        if MODEL_NAME is not None:
            train_params['save_folder'] = \
                f"model_weights/{MODEL_NAME}_CV{i+1:0>2d}_{datetime.now().strftime('%m%d%H%M%S')}"

        model.set_normalization(
            DataLoader(train_ds,
                       batch_size=100,
                       sampler=ValSampler(len(train_ds),
                                          min(len(train_ds), 5000))))
        train(model, (train_dl, val_dl), nn.MSELoss(), **train_params)
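
# `utils.split_data_CV` is project code; a hypothetical generator with the
# same interface (yields (test, val, train) tuples of (inputs, targets) lists,
# sliding the held-out block one test-fraction at a time) might look like this
# sketch:
def split_data_CV_sketch(inputs, targets, frac=(1 / 7, 2 / 7)):
    n = len(inputs)
    n_test, n_val = round(frac[0] * n), round(frac[1] * n)
    pick = lambda ids: ([inputs[j] for j in ids], [targets[j] for j in ids])
    for start in range(0, n - n_test - n_val + 1, n_test):
        held = list(range(start, start + n_test + n_val))
        rest = [j for j in range(n) if j not in held]
        yield pick(held[:n_test]), pick(held[n_test:]), pick(rest)
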

# Example 4

def sub_model_hypothesis():
    """Compare PWCCA distances between replicate UNets trained on different CV folds."""

    # NOTE: `ds`, `patches` and `colors` are assumed to be defined at module
    # scope (cf. svcca_over_training above).
    fig, ax = plt.subplots()

    for i, width in enumerate([8, 16, 32, 64]):
        # The checkpoints below are UNets (see the glob pattern), so UNet
        # models are instantiated here.
        model_params = {'c_in': 1, 'c_out': 1, 'width': width, 'dilations': [1, 2, 4, 8, 16]}
        models = [UNetRegressionModel(**model_params) for _ in range(3)]
        for model in models:
            model.set_normalization(ds)

        for model, cv in zip(models, [1, 3, 5]):
            state_dict = torch.load(
                sorted(models_dir.glob(f'UNet_phantoms/UNet_f{width}_P_transfer_CV01_CV0{cv}_*/best_*.h5'),
                       key=_nat_sort)[0],
                map_location='cpu')
            model.msd.load_state_dict(state_dict)

        layers = [get_unet_layers(model) for model in models]
        reps = get_model_representation(models, patches, layers, sample_rate=10)
        

        dists = [get_pwcca_dist(reps[0], reps[1]), get_pwcca_dist(reps[0], reps[2])]

        ax.plot(np.array(dists).mean(0), c=colors[i], label=f'UNet_f{width}')
        ax.fill_between(range(len(dists[0])), np.array(dists).min(0), np.array(dists).max(0), color=colors[i], alpha=.5)

    ax.set_ylabel('PWCCA distance')
    # ax.set_xticks(range(9))
    # ax.set_xticklabels(unet_layers)
    plt.legend()
    plt.savefig('outputs/sub_model_UNet.png')    
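
# `get_pwcca_dist` is project code; projection-weighted CCA (Morcos et al.
# 2018) weights each canonical correlation by how much of the first
# representation it accounts for. A minimal sketch, not the project's code:
def pwcca_dist_sketch(x, y):
    """PWCCA distance between activation matrices of shape (samples, neurons)."""
    import numpy as np
    x, y = x - x.mean(0), y - y.mean(0)
    qx, _ = np.linalg.qr(x)
    qy, _ = np.linalg.qr(y)
    u, rho, _ = np.linalg.svd(qx.T @ qy, full_matrices=False)  # rho: canonical correlations
    cv = qx @ u                           # canonical variates of x
    alpha = np.abs(cv.T @ x).sum(axis=1)  # weight: mass of x explained per direction
    alpha /= alpha.sum()
    return 1 - (alpha * rho).sum()
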
def train_model():
    """Train an MSD model on a fixed split of the noisy phantom dataset."""

    model_params = {
        'c_in': 1,
        'c_out': 1,
        'depth': 30,
        'width': 1,
        'dilations': [1, 2, 4, 8, 16],
        'loss': 'L2'
    }
    batch_size = 32

    model = MSDRegressionModel(**model_params)
    # model= UNetRegressionModel(**model_params)

    target_ims, input_ims = utils.load_phantom_ds('PhantomsRadialNoisy/')
    test_set, val_set, train_set = split_data(input_ims, target_ims, 1 / 7)

    train_ds = MultiOrbitDataset(*train_set,
                                 device='cuda',
                                 transforms=RandomHFlip())
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = MultiOrbitDataset(*val_set, device='cuda')
    val_dl = DataLoader(val_ds,
                        batch_size=batch_size,
                        sampler=ValSampler(len(val_ds)))

    train_params = {
        'epochs': 30,
        'lr': 2e-3,
        'regularization': None,
        'cutoff_epoch': 5
    }
    if MODEL_NAME is not None:
        train_params['save_folder'] = \
            f"model_weights/{MODEL_NAME}_{datetime.now().strftime('%m%d%H%M%S')}"

    model.set_normalization(
        DataLoader(train_ds,
                   batch_size=100,
                   sampler=ValSampler(len(train_ds), min(len(train_ds),
                                                         5000))))
    train(model, (train_dl, val_dl), nn.MSELoss(), **train_params)

# Example 6

def get_models():
    """Easy way to get all the models and their names for tests"""

    model_params = {'c_in': 1, 'c_out': 1,
                    'width': 1, 'dilations': [1,2,4,8,16]}

    model_d30 = MSDRegressionModel(depth=30, **model_params)
    model_d80 = MSDRegressionModel(depth=80, **model_params) 
    
    del model_params['width']
    model_f8 = UNetRegressionModel(width=8, **model_params)
    model_f16 = UNetRegressionModel(width=16, **model_params)
    model_f32 = UNetRegressionModel(width=32, **model_params)
    model_f64 = UNetRegressionModel(width=64, **model_params)

    models = (model_d30, model_d80, model_f8, model_f16, model_f32, model_f64)
    model_names = [f'MSD_d{d}' for d in [30, 80]] + [f'UNet_f{f}' for f in [8, 16, 32, 64]]

    return models, model_names
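
# Hypothetical usage of get_models() in a test, e.g. reporting parameter
# counts (assumes each wrapper exposes its network as `.msd`, as elsewhere
# in these snippets):
if __name__ == '__main__':
    models, model_names = get_models()
    for model, name in zip(models, model_names):
        n_params = sum(p.numel() for p in model.msd.parameters())
        print(f'{name}: {n_params:,} parameters')
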
def train_model(seed=0):
    """Train an MSD model on the walnut dataset, holding one sample out for validation."""

    model_params = {
        'c_in': 1,
        'c_out': 1,
        'depth': 30,
        'width': 1,
        'dilations': [1, 2, 4, 8, 16],
        'loss': 'L2'
    }
    model = MSDRegressionModel(**model_params)
    # model.net.load_state_dict(torch.load('model_weights/radial_msd_depth80_it5_epoch59_copy.pytorch')['state_dict'])
    # regularization = TVRegularization(scaling=1e-3)
    regularization = None

    # agd_ims, fdk_ims = utils.load_phantom_ds()
    agd_ims, fdk_ims = utils.load_walnut_ds()
    random.seed(seed)
    test_id = random.randrange(len(agd_ims))
    print(f"Using sample {test_id} as validation")
    input_val, target_val = [fdk_ims.pop(test_id)], [agd_ims.pop(test_id)]
    # train_id = random.randrange(len(agd_ims))
    # input_tr, target_tr = [fdk_ims.pop(train_id)], [agd_ims.pop(train_id)]
    input_tr, target_tr = fdk_ims, agd_ims

    batch_size = 32
    train_dl = DataLoader(MultiOrbitDataset(input_tr, target_tr,
                                            device='cuda'),
                          batch_size=batch_size,
                          shuffle=True)
    val_ds = MultiOrbitDataset(input_val, target_val, device='cuda')
    val_dl = DataLoader(val_ds,
                        batch_size=batch_size,
                        sampler=ValSampler(len(val_ds)))

    kwargs = {}
    # kwargs = {'save_folder':
    #           f"model_weights/MSD_d80_walnuts_finetuned_{datetime.now().strftime('%m%d%H%M%S')}"}
    train(model, (train_dl, val_dl),
          nn.MSELoss(),
          20,  # epochs
          regularization,
          lr=2e-3,
          **kwargs)
def test_model():
    """Evaluate fine-tuned MSD d80 and d30 models on a held-out walnut sample."""

    model_params = {
        'c_in': 1,
        'c_out': 1,
        'depth': 80,
        'width': 1,
        'dilations': [1, 2, 4, 8, 16],
        'loss': 'L2'
    }
    model_80 = MSDRegressionModel(**model_params)
    state_dicts = sorted(glob.glob(
        'model_weights/MSD_d80_walnuts_finetuned_1114125135/best*.h5'),
                         key=_nat_sort)
    model_80.msd.load_state_dict(torch.load(state_dicts[-1]))

    model_params = {
        'c_in': 1,
        'c_out': 1,
        'depth': 30,
        'width': 1,
        'dilations': [1, 2, 4, 8, 16],
        'loss': 'L2'
    }
    model_30 = MSDRegressionModel(**model_params)
    state_dicts = sorted(
        glob.glob('model_weights/MSD_d30_walnuts1113135028/best*.h5'),
        key=_nat_sort)
    model_30.msd.load_state_dict(torch.load(state_dicts[-1]))

    agd_ims, fdk_ims = utils.load_walnut_ds()
    # agd_ims, fdk_ims = utils.load_phantom_ds()
    random.seed(0)
    test_id = random.randrange(len(agd_ims))
    input_te, target_te = [fdk_ims.pop(test_id)], [agd_ims.pop(test_id)]

    te_ds = MultiOrbitDataset(input_te, target_te, data_augmentation=False)
    te_dl = DataLoader(te_ds, batch_size=8, sampler=ValSampler(len(te_ds)))

    model_80.set_normalization(te_dl)
    model_30.set_normalization(te_dl)

    mean, std = test(model_80, te_ds)
    print(
        f"Model d80 \n\tMSE: {mean[0]:.4e} +-{std[0]:.4e}, \n\tSSIM: {mean[1]:.4f} +-{std[1]:.4e}, \n\tDSC: {mean[2]:.4f} +-{std[2]:.4e}"
    )

    mean, std = test(model_30, te_ds)
    print(
        f"Model d30 \n\tMSE: {mean[0]:.4e} +-{std[0]:.4e}, \n\tSSIM: {mean[1]:.4f} +-{std[1]:.4e}, \n\tDSC: {mean[2]:.4f} +-{std[2]:.4e}"
    )

    sys.exit()

    # NOTE: everything below is unreachable scratch code; `model`, `epochs`
    # and `losses` are not defined in this function.
    model.msd.load_state_dict(torch.load(state_dicts[-1]))
    with evaluate(model):
        for i, (input_, target) in enumerate(te_dl):
            pred = model(input_)
            print(
                f"MSE: {mse(pred, target):.4e}, SSIM: {ssim(pred, target):.4f}, DSC: {dsc(pred, target):.4f}"
            )

            imsave(f'outputs/test_pred_{i+1}.tif',
                   np.clip(np.concatenate([input_[0, 0].cpu().numpy(),
                                           pred[0, 0].cpu().numpy(),
                                           target[0, 0].cpu().numpy()], axis=-1),
                           0, None))

    plt.figure(figsize=(10, 10))
    plt.plot(epochs, losses)
    plt.savefig('outputs/training.png')
        
        # Create an additional list to process 2D slices for evaluation.
        # This list is necessary to remember which slices correspond to each walnut.
        walnut_ds = ImageDataset(inp_imgs, tgt_imgs)
        test_ds.append(walnut_ds)
        test_size += len(walnut_ds)

    print('Test set size', test_size)


    #########################################################################################
    #                                      Create Model                                     #
    #########################################################################################

    # Create MS-D Net
    if architecture == 'msd':
        model = MSDRegressionModel(in_channels, out_channels, depth, width,
                                   dilations=dilations, loss=loss_f, parallel=True)
        print(model)  # TODO: find another way to print the model
    # Create U-Net
    elif architecture == 'unet':
        model = UNetRegressionModel(run_network_path, in_channels, out_channels, depth, width,
                                    loss_function=loss_f, lr=lr, opt=opt, dilation=dilations,
                                    reflect=True, conv3d=False)
    elif architecture == 'unet_jordi':
        model = UNetRegressionModelJordi(run_network_path, in_channels, out_channels, depth, width,
                                         loss_function=loss_f, dilation=dilations, reflect=True,
                                         conv3d=False)
    #########################################################################################
    #                                      Train Model                                      #
    #########################################################################################

    if train:
        print('Training model..')
        # Define dataloaders