# Example #1 (0)
# Train an FBPUNetReconstructor on the 'lodopab' dataset, feeding it cached
# FBP reconstructions instead of computing them on the fly.
# NOTE(review): IMPL, LOG_DIR and SAVE_BEST_LEARNED_PARAMS_PATH are used below
# but are not assigned anywhere before this point in the visible text (they
# only appear later, at the start of the next example). The snippet presumably
# lost its header lines -- confirm before running.

# Cache file per dataset part, passed to get_cached_fbp_dataset /
# generate_fbp_cache_files. The meaning of the second tuple element (None here)
# is defined by those helpers -- TODO confirm.
CACHE_FILES = {
    'train': ('./cache_lodopab_train_fbp.npy', None),
    'validation': ('./cache_lodopab_validation_fbp.npy', None)
}

# Dataset, its ray transform, and 100 pairs from the 'test' part.
# NOTE(review): test_data is not used in the visible part of this snippet.
dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = FBPUNetReconstructor(
    ray_trafo,
    log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# Download the reference parameters only when check_for_params reports them
# missing. include_learned=False presumably restricts this to the hyper
# parameters (no trained weights), since the model is trained below.
if not check_for_params('fbpunet', 'lodopab', include_learned=False):
    download_params('fbpunet', 'lodopab', include_learned=False)
hyper_params_path = get_hyper_params_path('fbpunet', 'lodopab')
reconstructor.load_hyper_params(hyper_params_path)

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
# uncomment the next line to generate the cache files (~20 GB)
# generate_fbp_cache_files(dataset, ray_trafo, CACHE_FILES)
cached_fbp_dataset = get_cached_fbp_dataset(dataset, ray_trafo, CACHE_FILES)
dataset.fbp_dataset = cached_fbp_dataset

#%% train
# reduce the batch size here if the model does not fit into GPU memory
# reconstructor.batch_size = 16
reconstructor.train(dataset)
# Example #2 (its scraped header line is missing): train an N2SelfReconstructor
# on the 'ellipses' dataset, then start the evaluation.
IMPL = 'astra_cpu'

LOG_DIR = './logs/ellipses_n2self'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/ellipses_n2self'

dataset = get_standard_dataset('ellipses', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = N2SelfReconstructor(
    ray_trafo,
    log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# NOTE(review): these are the *fbpunet* reference hyper parameters, reused as a
# starting point for the N2Self reconstructor -- confirm this is intended.
if not check_for_params('fbpunet', 'ellipses', include_learned=False):
    download_params('fbpunet', 'ellipses', include_learned=False)
hyper_params_path = get_hyper_params_path('fbpunet', 'ellipses')
reconstructor.load_hyper_params(hyper_params_path)
# Show the hyper parameter values now in effect. Per dival's Reconstructor API,
# the class attribute `HYPER_PARAMS` is only the specification (defaults etc.)
# and is not updated by `load_hyper_params`. The current values live in the
# instance dict `hyper_params`.
print(reconstructor.hyper_params)
# To override a value, assign it on the instance, as done for `batch_size`
# below, e.g. `reconstructor.lr = 0.0001`. Do not edit the class-level
# HYPER_PARAMS spec: that would affect every instance of the class.

#%% train
# reduce the batch size here if the model does not fit into GPU memory
reconstructor.batch_size = 16
reconstructor.train(dataset)

#%% evaluate
recos = []
psnrs = []
# Example #3 (0)
def callback_func(iteration, reconstruction, loss):
    """Plot the current DIP iterate next to the ground truth, then show it.

    Passed to the reconstructor as ``callback_func``. Besides its arguments it
    reads the script-level names ``gt`` and ``TEST_SAMPLE``.

    Args:
        iteration: iteration counter, rendered with ``{:d}`` in the title.
        reconstruction: current reconstruction, drawn in the first panel.
        loss: current loss value, rendered with ``{:f}`` as the x-label.
    """
    _, axes = plot_images([reconstruction, gt], fig_size=(10, 4))
    reco_ax, truth_ax = axes[0], axes[1]
    reco_ax.set_xlabel('loss: {:f}'.format(loss))
    reco_ax.set_title('DIP iteration {:d}'.format(iteration))
    truth_ax.set_title('ground truth')
    figure = reco_ax.figure
    figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
    plt.show()

# Build a DeepImagePriorCTReconstructor for the ray transform of `dataset`,
# registering the `callback_func` defined above with callback_func_interval=100.
# NOTE(review): `dataset`, `IMPL`, `obs`, `gt` and `TEST_SAMPLE` are not
# assigned in the visible snippet -- presumably set by lines the scrape dropped.
reconstructor = DeepImagePriorCTReconstructor(
    dataset.get_ray_trafo(impl=IMPL),
    callback_func=callback_func, callback_func_interval=100)

#%% obtain reference hyper parameters
# Download the reference 'diptv' parameters for 'lodopab' when check_for_params
# reports them missing, then load them into the reconstructor.
if not check_for_params('diptv', 'lodopab'):
    download_params('diptv', 'lodopab')
params_path = get_params_path('diptv', 'lodopab')
reconstructor.load_params(params_path)

#%% evaluate
# Reconstruct the single observation `obs` and score it against `gt`.
reco = reconstructor.reconstruct(obs)
psnr = PSNR(reco, gt)

print('psnr: {:f}'.format(psnr))
# Plot the reconstruction (first panel, labelled with its PSNR) next to the
# ground truth.
_, ax = plot_images([reco, gt],
                    fig_size=(10, 4))
ax[0].set_xlabel('PSNR: {:.2f}'.format(psnr))
ax[0].set_title('DeepImagePriorCTReconstructor')
ax[1].set_title('ground truth')
ax[0].figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
# Example #4 (0)
# Train an IRadonMapReconstructor on the 'lodopab' dataset with the reference
# hyper parameters, then reconstruct 100 test pairs and collect their PSNRs.
IMPL = 'astra_cuda'

LOG_DIR = './logs/lodopab_iradonmap'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/lodopab_iradonmap'

dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = IRadonMapReconstructor(
    ray_trafo, log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# Fetch the reference hyper parameters when check_for_params reports them
# missing. include_learned=False presumably skips trained weights, since the
# model is trained below.
if not check_for_params('iradonmap', 'lodopab', include_learned=False):
    download_params('iradonmap', 'lodopab', include_learned=False)
hyper_params_path = get_hyper_params_path('iradonmap', 'lodopab')
reconstructor.load_hyper_params(hyper_params_path)

#%% train
reconstructor.train(dataset)

#%% evaluate
# One reconstruction and one PSNR value per test pair, in test_data order.
# NOTE(review): the results are not reported in the visible snippet (it ends
# here).
recos = []
psnrs = []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))
# Example #5 (0)
# Train a LearnedPDReconstructor on the 'lodopab' dataset with the reference
# hyper parameters, then reconstruct 100 test pairs and collect their PSNRs.
IMPL = 'astra_cuda'

LOG_DIR = './logs/lodopab_learnedpd'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/lodopab_learnedpd'

dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = LearnedPDReconstructor(
    ray_trafo,
    log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# Fetch the reference hyper parameters when check_for_params reports them
# missing. include_learned=False presumably skips trained weights, since the
# model is trained below.
if not check_for_params('learnedpd', 'lodopab', include_learned=False):
    download_params('learnedpd', 'lodopab', include_learned=False)
hyper_params_path = get_hyper_params_path('learnedpd', 'lodopab')
reconstructor.load_hyper_params(hyper_params_path)

#%% train
reconstructor.train(dataset)

#%% evaluate
# One reconstruction and one PSNR value per test pair, in test_data order.
# NOTE(review): the results are not reported in the visible snippet (it ends
# here).
recos = []
psnrs = []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))
# Example #6 (0)
def callback_func(iteration, reconstruction, loss):
    """Plot the current TV iterate next to the ground truth, then show it.

    Passed to the reconstructor as ``callback_func``. Besides its arguments it
    reads the script-level names ``gt`` and ``TEST_SAMPLE``.

    Args:
        iteration: iteration counter, rendered with ``{:d}`` in the title.
        reconstruction: current reconstruction, drawn in the first panel.
        loss: current loss value, rendered with ``{:f}`` as the x-label.
    """
    _, axes = plot_images([reconstruction, gt], fig_size=(10, 4))
    reco_ax, truth_ax = axes[0], axes[1]
    reco_ax.set_xlabel('loss: {:f}'.format(loss))
    reco_ax.set_title('TV iteration {:d}'.format(iteration))
    truth_ax.set_title('ground truth')
    figure = reco_ax.figure
    figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
    plt.show()


# Build a TVAdamCTReconstructor for the ray transform of `dataset`, registering
# the `callback_func` defined above with callback_func_interval=100.
# NOTE(review): `dataset`, `IMPL`, `obs`, `gt` and `TEST_SAMPLE` are not
# assigned in the visible snippet -- presumably set by lines the scrape dropped.
reconstructor = TVAdamCTReconstructor(dataset.get_ray_trafo(impl=IMPL),
                                      callback_func=callback_func,
                                      callback_func_interval=100)

#%% obtain reference hyper parameters
# Download the reference 'tvadam' parameters for 'lodopab' when
# check_for_params reports them missing, then load them.
if not check_for_params('tvadam', 'lodopab'):
    download_params('tvadam', 'lodopab')
params_path = get_params_path('tvadam', 'lodopab')
reconstructor.load_params(params_path)

#%% evaluate
# Reconstruct the single observation `obs` and score it against `gt`.
reco = reconstructor.reconstruct(obs)
psnr = PSNR(reco, gt)

print('psnr: {:f}'.format(psnr))
# Plot the reconstruction (first panel, labelled with its PSNR) next to the
# ground truth.
_, ax = plot_images([reco, gt], fig_size=(10, 4))
ax[0].set_xlabel('PSNR: {:.2f}'.format(psnr))
# Label the panel with the class actually used above. It was previously
# 'TVAdamReconstructor', which is not the reconstructor class constructed here.
ax[0].set_title('TVAdamCTReconstructor')
ax[1].set_title('ground truth')
ax[0].figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))