def evaluate_ncpdnet(
        multicoil=False,
        dcomp=False,
        normalize_image=False,
        n_iter=10,
        n_filters=32,
        n_primal=5,
        non_linearity='relu',
        **eval_kwargs,
):
    """Build a 2D NCPDNet and hand it to ``evaluate_nc``.

    The network is sized with the module-level ``IM_SIZE`` constant.

    NOTE(review): another ``evaluate_ncpdnet`` with ``three_d`` and
    ``refine_smaps`` parameters also exists; if both live in the same module,
    the later definition shadows this one — confirm which is intended.

    Args:
        multicoil: build/evaluate the multi-coil variant.
        dcomp: whether density compensation is used; forwarded to both the
            model and ``evaluate_nc``.
        normalize_image: forwarded to the model constructor.
        n_iter: number of unrolled iterations.
        n_filters: number of convolution filters.
        n_primal: size of the primal buffer.
        non_linearity: activation name, passed as ``activation``.
        **eval_kwargs: extra keyword arguments forwarded to ``evaluate_nc``.

    Returns:
        Whatever ``evaluate_nc`` returns.
    """
    model_config = dict(
        n_primal=n_primal,
        multicoil=multicoil,
        activation=non_linearity,
        n_iter=n_iter,
        n_filters=n_filters,
        im_size=IM_SIZE,
        dcomp=dcomp,
        normalize_image=normalize_image,
    )
    network = NCPDNet(**model_config)
    return evaluate_nc(network, multicoil=multicoil, dcomp=dcomp, **eval_kwargs)
def evaluate_ncpdnet(
        multicoil=False,
        three_d=False,
        dcomp=False,
        normalize_image=False,
        n_iter=10,
        n_filters=32,
        n_primal=5,
        non_linearity='relu',
        refine_smaps=True,
        **eval_kwargs,
):
    """Build an NCPDNet (2D or 3D) and evaluate it with ``evaluate_nc``.

    The model is constructed with the same configuration keys the training
    routine uses, so that the evaluated network matches the trained one.

    Args:
        multicoil: build/evaluate the multi-coil variant.
        three_d: build a 3D network sized with ``VOLUME_SIZE`` instead of the
            2D ``IM_SIZE``; also forwarded to ``evaluate_nc``.
        dcomp: whether density compensation is used; forwarded to both the
            model and ``evaluate_nc``.
        normalize_image: forwarded to the model constructor.
        n_iter: number of unrolled iterations.
        n_filters: number of convolution filters.
        n_primal: size of the primal buffer.
        non_linearity: activation name, passed as ``activation``.
        refine_smaps: forwarded to the model constructor.
        **eval_kwargs: extra keyword arguments forwarded to ``evaluate_nc``.

    Returns:
        Whatever ``evaluate_nc`` returns.
    """
    image_size = VOLUME_SIZE if three_d else IM_SIZE
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'three_d': three_d,
        'activation': non_linearity,
        'n_iter': n_iter,
        'n_filters': n_filters,
        'im_size': image_size,
        'dcomp': dcomp,
        'normalize_image': normalize_image,
        'refine_smaps': refine_smaps,
        # Training (``train_ncpdnet``) and qualitative validation build the
        # model with ``fastmri=not three_d``. Without this key a 3D model would
        # be evaluated with NCPDNet's default ``fastmri`` setting rather than
        # the one it was trained with.
        'fastmri': not three_d,
    }
    model = NCPDNet(**run_params)
    return evaluate_nc(
        model,
        multicoil=multicoil,
        dcomp=dcomp,
        three_d=three_d,
        **eval_kwargs,
    )
def train_ncpdnet(
        multicoil=False,
        three_d=False,
        dcomp=False,
        normalize_image=False,
        n_iter=10,
        n_filters=32,
        n_primal=5,
        non_linearity='relu',
        refine_smaps=True,
        **train_kwargs,
):
    """Build an NCPDNet and train it via ``train_ncnet_block``.

    A run identifier is derived from the configuration: a prefix chosen by
    ``multicoil`` (checked first) then ``three_d``, followed by suffixes for a
    non-default iteration count, a non-default activation, and smap
    refinement (multi-coil only).

    Args:
        multicoil: build/train the multi-coil ("sense") variant.
        three_d: build a 3D network sized with ``VOLUME_SIZE`` (otherwise
            ``IM_SIZE``); also disables the ``fastmri`` model option.
        dcomp: whether density compensation is used; forwarded to both the
            model and the trainer.
        normalize_image: forwarded to the model constructor.
        n_iter: number of unrolled iterations (also forwarded to the trainer).
        n_filters: number of convolution filters.
        n_primal: size of the primal buffer.
        non_linearity: activation name, passed as ``activation``.
        refine_smaps: forwarded to the model; tagged in the run id only for
            multi-coil runs.
        **train_kwargs: extra keyword arguments forwarded to
            ``train_ncnet_block``.

    Returns:
        Whatever ``train_ncnet_block`` returns.
    """
    model = NCPDNet(
        n_primal=n_primal,
        multicoil=multicoil,
        three_d=three_d,
        activation=non_linearity,
        n_iter=n_iter,
        n_filters=n_filters,
        im_size=VOLUME_SIZE if three_d else IM_SIZE,
        dcomp=dcomp,
        normalize_image=normalize_image,
        refine_smaps=refine_smaps,
        fastmri=not three_d,
    )

    if multicoil:
        prefix = 'ncpdnet_sense_'
    elif three_d:
        prefix = 'ncpdnet_3d_'
    else:
        prefix = 'ncpdnet_singlecoil_'

    tags = []
    if n_iter != 10:
        tags.append(f'_i{n_iter}')
    if non_linearity != 'relu':
        tags.append(f'_{non_linearity}')
    if multicoil and refine_smaps:
        tags.append('_rfs')
    suffix = ''.join(tags)

    # NOTE(review): the prefix already ends with '_' and the format below adds
    # another, so ids contain '__'. Kept as-is so run names stay unchanged.
    run_id = f'{prefix}_{suffix}'

    return train_ncnet_block(
        model,
        n_iter=n_iter,
        run_id=run_id,
        multicoil=multicoil,
        dcomp=dcomp,
        three_d=three_d,
        **train_kwargs,
    )
def test_ncpdnet_init_and_call(ktraj):
    """Smoke test: a small single-coil NCPDNet can be called on radial data.

    Uses the ``ktraj`` fixture to build the trajectory; only checks that the
    forward call does not raise.
    """
    model = NCPDNet(n_iter=3, n_primal=3, n_filters=8)
    height, width = 640, 372
    n_spokes = 15
    trajectory = ktraj((height, width), n_spokes)
    # Spoke length is taken as twice the image width.
    n_samples = (width * 2) * n_spokes
    kspace = tf.zeros([1, 1, n_samples, 1], dtype=tf.complex64)
    extra_args = (tf.constant([width]),)
    model([kspace, trajectory, extra_args])
def ncpdnet_qualitative_validation(
        multicoil=False,
        three_d=False,
        dcomp=False,
        normalize_image=False,
        n_iter=10,
        n_filters=32,
        n_primal=5,
        non_linearity='relu',
        refine_smaps=True,
        brain=False,
        **eval_kwargs,
):
    """Build an NCPDNet and run ``ncnet_qualitative_validation`` on it.

    The display name starts with ``'pdnet'``, with ``'-dcomp'`` and/or
    ``'-norm'`` appended when density compensation / image normalization are
    enabled.

    Args:
        multicoil: build the multi-coil variant.
        three_d: build a 3D network sized with ``VOLUME_SIZE`` (otherwise
            ``IM_SIZE``); also disables the ``fastmri`` model option.
        dcomp: whether density compensation is used.
        normalize_image: forwarded to the model constructor.
        n_iter: number of unrolled iterations.
        n_filters: number of convolution filters.
        n_primal: size of the primal buffer.
        non_linearity: activation name, passed as ``activation``.
        refine_smaps: forwarded to the model constructor.
        brain: passed to the model as ``output_shape_spec`` and forwarded to
            the validation routine as ``brain``.
        **eval_kwargs: extra keyword arguments forwarded to
            ``ncnet_qualitative_validation``.

    Returns:
        Whatever ``ncnet_qualitative_validation`` returns.
    """
    network = NCPDNet(
        n_primal=n_primal,
        multicoil=multicoil,
        three_d=three_d,
        activation=non_linearity,
        n_iter=n_iter,
        n_filters=n_filters,
        im_size=VOLUME_SIZE if three_d else IM_SIZE,
        dcomp=dcomp,
        normalize_image=normalize_image,
        refine_smaps=refine_smaps,
        output_shape_spec=brain,
        fastmri=not three_d,
    )

    name_parts = ['pdnet']
    if dcomp:
        name_parts.append('-dcomp')
    if normalize_image:
        name_parts.append('-norm')
    display_name = ''.join(name_parts)

    return ncnet_qualitative_validation(
        network,
        display_name,
        multicoil=multicoil,
        dcomp=dcomp,
        three_d=three_d,
        brain=brain,
        **eval_kwargs,
    )
def test_ncpdnet_init_and_call_3d(dcomp, volume_shape):
    """Smoke test: a tiny 3D NCPDNet runs on stack-of-radial data.

    ``dcomp`` and ``volume_shape`` are test fixtures/parameters. When ``dcomp``
    is truthy, density-compensation weights are computed and appended to the
    extra model inputs. The test checks that output dims 1..3 equal
    ``volume_shape``.
    """
    model = NCPDNet(
        n_iter=1,
        n_primal=2,
        n_filters=2,
        multicoil=False,
        im_size=volume_shape,
        three_d=True,
        dcomp=dcomp,
        fastmri=False,
    )
    acceleration = 16
    traj = get_stacks_of_radial_trajectory(volume_shape, af=acceleration)
    n_stacks = volume_shape[0]
    spoke_length = volume_shape[-2]
    n_spokes = volume_shape[-1] // acceleration
    n_samples = n_stacks * spoke_length * n_spokes

    extra_args = (tf.constant([volume_shape]),)
    if dcomp:
        nufft_ob = KbNufftModule(
            im_size=volume_shape,
            grid_size=None,
            norm='ortho',
        )
        interpob = nufft_ob._extract_nufft_interpob()
        forward_op = kbnufft_forward(interpob)
        adjoint_op = kbnufft_adjoint(interpob)
        # Separate name so the boolean ``dcomp`` parameter is not rebound.
        dcomp_weights = calculate_radial_dcomp_tf(
            interpob,
            forward_op,
            adjoint_op,
            traj[0],
            stacks=True,
        )
        # Prepend a size-1 leading axis (assumes the weights are 1D — TODO
        # confirm against calculate_radial_dcomp_tf).
        dcomp_weights = (
            tf.ones([1, tf.shape(dcomp_weights)[0]], dtype=dcomp_weights.dtype)
            * dcomp_weights[None, :]
        )
        extra_args += (dcomp_weights,)

    kspace = tf.zeros([1, 1, n_samples, 1], dtype=tf.complex64)
    res = model([kspace, traj, extra_args])
    assert res.shape[1:4] == volume_shape