Exemplo n.º 1
0
def evaluate_ncpdnet(multicoil=False,
                     dcomp=False,
                     normalize_image=False,
                     n_iter=10,
                     n_filters=32,
                     n_primal=5,
                     non_linearity='relu',
                     **eval_kwargs):
    """Instantiate an NCPDNet and run it through ``evaluate_nc``.

    The image size is fixed to the module-level ``IM_SIZE``.

    Args:
        multicoil: forwarded both to the model and to ``evaluate_nc``.
        dcomp: forwarded both to the model and to ``evaluate_nc``.
        normalize_image: forwarded to the model only.
        n_iter: forwarded to the model as ``n_iter``.
        n_filters: forwarded to the model as ``n_filters``.
        n_primal: forwarded to the model as ``n_primal``.
        non_linearity: forwarded to the model as ``activation``.
        **eval_kwargs: extra keyword arguments passed through unchanged to
            ``evaluate_nc``.

    Returns:
        Whatever ``evaluate_nc`` returns.
    """
    model = NCPDNet(
        n_primal=n_primal,
        multicoil=multicoil,
        activation=non_linearity,
        n_iter=n_iter,
        n_filters=n_filters,
        im_size=IM_SIZE,
        dcomp=dcomp,
        normalize_image=normalize_image,
    )
    return evaluate_nc(model, multicoil=multicoil, dcomp=dcomp, **eval_kwargs)
Exemplo n.º 2
0
def evaluate_ncpdnet(multicoil=False,
                     three_d=False,
                     dcomp=False,
                     normalize_image=False,
                     n_iter=10,
                     n_filters=32,
                     n_primal=5,
                     non_linearity='relu',
                     refine_smaps=True,
                     **eval_kwargs):
    """Build an NCPDNet and evaluate it with ``evaluate_nc``.

    The model is configured with the same hyper-parameters that
    ``train_ncpdnet`` uses, including ``fastmri=not three_d``. Without that
    flag, a 3D evaluation model would be set up differently from the model
    that was trained.

    Args:
        multicoil: forwarded to the model and to ``evaluate_nc``.
        three_d: selects ``VOLUME_SIZE`` (True) or ``IM_SIZE`` (False) as the
            model's ``im_size``; also forwarded to ``evaluate_nc``.
        dcomp: forwarded to the model and to ``evaluate_nc``.
        normalize_image: forwarded to the model only.
        n_iter: forwarded to the model as ``n_iter``.
        n_filters: forwarded to the model as ``n_filters``.
        n_primal: forwarded to the model as ``n_primal``.
        non_linearity: forwarded to the model as ``activation``.
        refine_smaps: forwarded to the model only.
        **eval_kwargs: extra keyword arguments passed through unchanged to
            ``evaluate_nc``.

    Returns:
        Whatever ``evaluate_nc`` returns.
    """
    if three_d:
        image_size = VOLUME_SIZE
    else:
        image_size = IM_SIZE
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'three_d': three_d,
        'activation': non_linearity,
        'n_iter': n_iter,
        'n_filters': n_filters,
        'im_size': image_size,
        'dcomp': dcomp,
        'normalize_image': normalize_image,
        'refine_smaps': refine_smaps,
        # Keep in sync with train_ncpdnet / ncpdnet_qualitative_validation,
        # which both build the model with this flag.
        'fastmri': not three_d,
    }

    model = NCPDNet(**run_params)
    return evaluate_nc(
        model,
        multicoil=multicoil,
        dcomp=dcomp,
        three_d=three_d,
        **eval_kwargs,
    )
def train_ncpdnet(
    multicoil=False,
    three_d=False,
    dcomp=False,
    normalize_image=False,
    n_iter=10,
    n_filters=32,
    n_primal=5,
    non_linearity='relu',
    refine_smaps=True,
    **train_kwargs,
):
    """Build an NCPDNet from hyper-parameters and train it.

    A run identifier is derived from the model family and from the
    hyper-parameters that differ from their defaults:

    * ``_i{n_iter}`` when ``n_iter != 10``
    * ``_{non_linearity}`` when ``non_linearity != 'relu'``
    * ``_rfs`` when both ``multicoil`` and ``refine_smaps`` are set

    The family prefix is ``ncpdnet_sense_`` for multicoil, otherwise
    ``ncpdnet_3d_`` for 3D, otherwise ``ncpdnet_singlecoil_``. The prefix and
    the suffix are joined with an extra ``_``.

    Args:
        multicoil, three_d, dcomp: forwarded to the model and to
            ``train_ncnet_block``.
        normalize_image, n_filters, n_primal, refine_smaps: forwarded to the
            model only.
        n_iter: forwarded to the model and to ``train_ncnet_block``.
        non_linearity: forwarded to the model as ``activation``.
        **train_kwargs: extra keyword arguments passed through unchanged to
            ``train_ncnet_block``.

    Returns:
        Whatever ``train_ncnet_block`` returns.
    """
    run_params = dict(
        n_primal=n_primal,
        multicoil=multicoil,
        three_d=three_d,
        activation=non_linearity,
        n_iter=n_iter,
        n_filters=n_filters,
        im_size=VOLUME_SIZE if three_d else IM_SIZE,
        dcomp=dcomp,
        normalize_image=normalize_image,
        refine_smaps=refine_smaps,
        fastmri=not three_d,
    )

    if multicoil:
        prefix = 'ncpdnet_sense_'
    elif three_d:
        prefix = 'ncpdnet_3d_'
    else:
        prefix = 'ncpdnet_singlecoil_'

    suffix_parts = []
    if n_iter != 10:
        suffix_parts.append(f'_i{n_iter}')
    if non_linearity != 'relu':
        suffix_parts.append(f'_{non_linearity}')
    if multicoil and refine_smaps:
        suffix_parts.append('_rfs')
    run_id = prefix + '_' + ''.join(suffix_parts)

    model = NCPDNet(**run_params)
    return train_ncnet_block(
        model,
        n_iter=n_iter,
        run_id=run_id,
        multicoil=multicoil,
        dcomp=dcomp,
        three_d=three_d,
        **train_kwargs,
    )
def test_ncpdnet_init_and_call(ktraj):
    """Smoke test: a small NCPDNet can be built and called on zero k-space.

    Only checks that the forward pass runs without raising; the output is
    not inspected.
    """
    model = NCPDNet(n_iter=3, n_primal=3, n_filters=8)
    height, width = 640, 372
    n_spokes = 15
    trajectory = ktraj((height, width), n_spokes)
    # Spoke length is taken as twice the last image dimension; presumably
    # this matches the convention of the ktraj fixture.
    samples_per_spoke = 2 * width
    n_samples = n_spokes * samples_per_spoke
    kspace = tf.zeros([1, 1, n_samples, 1], dtype=tf.complex64)
    shape_info = (tf.constant([width]),)
    model([kspace, trajectory, shape_info])
Exemplo n.º 5
0
def ncpdnet_qualitative_validation(multicoil=False,
                                   three_d=False,
                                   dcomp=False,
                                   normalize_image=False,
                                   n_iter=10,
                                   n_filters=32,
                                   n_primal=5,
                                   non_linearity='relu',
                                   refine_smaps=True,
                                   brain=False,
                                   **eval_kwargs):
    """Build an NCPDNet and run ``ncnet_qualitative_validation`` on it.

    The display name starts with ``'pdnet'``. It gets ``'-dcomp'`` appended
    when ``dcomp`` is truthy and ``'-norm'`` appended when
    ``normalize_image`` is truthy, in that order.

    Args:
        multicoil, three_d, dcomp, brain: forwarded to the model (``brain``
            as ``output_shape_spec``) and to ``ncnet_qualitative_validation``.
        normalize_image, n_iter, n_filters, n_primal, refine_smaps: forwarded
            to the model only.
        non_linearity: forwarded to the model as ``activation``.
        **eval_kwargs: extra keyword arguments passed through unchanged to
            ``ncnet_qualitative_validation``.

    Returns:
        Whatever ``ncnet_qualitative_validation`` returns.
    """
    model = NCPDNet(
        n_primal=n_primal,
        multicoil=multicoil,
        three_d=three_d,
        activation=non_linearity,
        n_iter=n_iter,
        n_filters=n_filters,
        im_size=VOLUME_SIZE if three_d else IM_SIZE,
        dcomp=dcomp,
        normalize_image=normalize_image,
        refine_smaps=refine_smaps,
        output_shape_spec=brain,
        fastmri=not three_d,
    )
    name_flags = ((dcomp, '-dcomp'), (normalize_image, '-norm'))
    name = 'pdnet' + ''.join(tag for enabled, tag in name_flags if enabled)
    return ncnet_qualitative_validation(
        model,
        name,
        multicoil=multicoil,
        dcomp=dcomp,
        three_d=three_d,
        brain=brain,
        **eval_kwargs,
    )
def test_ncpdnet_init_and_call_3d(dcomp, volume_shape):
    """Smoke test for the 3D single-coil NCPDNet on a stack-of-radial trajectory.

    Builds a tiny 3D model and feeds it all-zero k-space. When the ``dcomp``
    fixture is truthy, density-compensation weights computed with
    ``calculate_radial_dcomp_tf`` are also passed. The test asserts that
    ``output.shape[1:4]`` equals ``volume_shape``.
    """
    model = NCPDNet(
        n_iter=1,
        n_primal=2,
        n_filters=2,
        multicoil=False,
        im_size=volume_shape,
        three_d=True,
        dcomp=dcomp,
        fastmri=False,
    )
    acceleration = 16
    traj = get_stacks_of_radial_trajectory(volume_shape, af=acceleration)
    n_stacks = volume_shape[0]
    samples_per_spoke = volume_shape[-2]
    spokes_per_stack = volume_shape[-1] // acceleration
    n_samples = n_stacks * spokes_per_stack * samples_per_spoke
    extra_args = (tf.constant([volume_shape]),)
    if dcomp:
        interpob = KbNufftModule(
            im_size=volume_shape,
            grid_size=None,
            norm='ortho',
        )._extract_nufft_interpob()
        # A distinct name is used so the boolean ``dcomp`` fixture is not
        # shadowed by the weight tensor.
        weights = calculate_radial_dcomp_tf(
            interpob,
            kbnufft_forward(interpob),
            kbnufft_adjoint(interpob),
            traj[0],
            stacks=True,
        )
        # Prepend an axis of size 1 (presumably the batch axis) -> [1, N].
        ones_row = tf.ones([1, tf.shape(weights)[0]], dtype=weights.dtype)
        extra_args = extra_args + (ones_row * weights[None, :],)
    kspace = tf.zeros([1, 1, n_samples, 1], dtype=tf.complex64)
    output = model([kspace, traj, extra_args])
    assert output.shape[1:4] == volume_shape