Example #1
0
def cvpr2018_net(vol_size, enc_nf, dec_nf, full_size=True, indexing='ij'):
    """
    From https://github.com/voxelmorph/voxelmorph.

    Build the U-Net based voxelmorph registration model from the CVPR 2018 paper.
    You may need to adapt the layer configuration to suit your project.

    The model takes a source and a target volume (one channel each) and
    concatenates them. It runs them through ``unet3D``, which is defined
    elsewhere, and regresses an ``ndims``-channel flow field with a final
    convolution. Finally it warps the source with that flow using linear
    interpolation.

    :param vol_size: spatial volume size, e.g. (256, 256, 256); must have exactly 3 entries.
    :param enc_nf: list of encoder filters. right now it needs to be 1x4,
           e.g. [16, 32, 32, 32]
    :param dec_nf: list of decoder filters. right now it must be 1x6 (like
           voxelmorph-1) or 1x7 (voxelmorph-2)
    :param full_size: accepted for interface compatibility; not used by this implementation.
    :param indexing: grid indexing convention forwarded to SpatialTransformer (default 'ij').
    :return: the keras model, with inputs [src, tgt] and outputs [warped_src, flow]
    """
    import keras.layers as KL
    from keras.initializers import RandomNormal

    ndims = len(vol_size)
    assert ndims == 3, "ndims should be 3. found: %d" % ndims

    # Source and target are both single-channel volumes of the same size.
    in_shape = vol_size + (1,)
    moving = Input(in_shape, name='input_src')
    fixed = Input(in_shape, name='input_tgt')
    stacked = Concatenate(name='concat_inputs')([moving, fixed])

    # Core U-Net (built by unet3D, defined elsewhere) producing the feature map.
    features = unet3D(stacked, img_shape=vol_size, out_im_chans=ndims,
                      nf_enc=enc_nf, nf_dec=dec_nf)

    # A tiny-stddev kernel init keeps the initial flow predictions small.
    flow_init = RandomNormal(mean=0.0, stddev=1e-5)
    conv_cls = getattr(KL, 'Conv%dD' % ndims)
    flow_field = conv_cls(ndims, kernel_size=3, padding='same', name='flow',
                          kernel_initializer=flow_init)(features)

    # Resample the source volume along the predicted flow.
    warped = SpatialTransformer(interp_method='linear',
                                indexing=indexing)([moving, flow_field])
    return Model(inputs=[moving, fixed], outputs=[warped, flow_field])
Example #2
0
def color_delta_unet_model(img_shape,
                           n_output_chans,
                           model_name='color_delta_unet',
                           enc_params=None,
                           include_aux_input=False,
                           aux_input_shape=None,
                           do_warp_to_target_space=False):
    """Build a U-Net that predicts an additive color delta for a source image.

    The source and target images, plus the auxiliary input when requested, are
    concatenated along the channel axis. A 2D or 3D U-Net predicts
    ``color_delta``, which is added to the source image. Optionally, the result
    is then warped into target space with a supplied flow field.

    :param img_shape: shape of one image including the trailing channel axis,
        e.g. (H, W, C) for 2D or (H, W, D, C) for 3D. Expected to be a tuple.
    :param n_output_chans: number of channels of the predicted color delta
        (presumably must equal img_shape[-1] for the Add with the source to be
        valid -- confirm against callers).
    :param model_name: name of the returned keras Model.
    :param enc_params: dict with keys 'nf_enc', 'nf_dec' and 'n_convs_per_stage'
        forwarded to the U-Net builder. Required despite the None default.
    :param include_aux_input: if True, the auxiliary input is also fed to the U-Net.
    :param aux_input_shape: shape of the auxiliary input; defaults to img_shape.
    :param do_warp_to_target_space: if True, add an 'input_flow' input and warp
        the transformed source with it (SpatialTransformer, 'xy' indexing).
    :return: keras Model with inputs [src, tgt, aux] (+ [flow] when warping)
        and outputs [color_delta, transformed_out, aux].
    :raises TypeError: if enc_params is None.
    :raises ValueError: if img_shape does not describe a 2D or 3D image.
    """
    # Fail early, before any layer is built, with an explicit message instead
    # of an opaque "'NoneType' object is not subscriptable" deep in the body.
    if enc_params is None:
        raise TypeError("enc_params must be a dict with keys 'nf_enc', "
                        "'nf_dec' and 'n_convs_per_stage'")

    # Number of spatial dims; the last axis of img_shape is channels.
    n_dims = len(img_shape) - 1
    if n_dims not in (2, 3):
        raise ValueError('img_shape must describe a 2D or 3D image with a '
                         'trailing channel axis, got %r' % (img_shape,))

    x_src = Input(img_shape, name='input_src')
    x_tgt = Input(img_shape, name='input_tgt')
    if aux_input_shape is None:
        aux_input_shape = img_shape

    # The auxiliary input is always a model input/output, even when it is not
    # fed to the U-Net.
    x_seg = Input(aux_input_shape, name='input_src_aux')
    inputs = [x_src, x_tgt, x_seg]

    if do_warp_to_target_space:  # warp transformed vol to target space in the end
        flow_srctotgt = Input(img_shape[:-1] + (n_dims,), name='input_flow')
        inputs.append(flow_srctotgt)

    # Channel count of the U-Net input = sum of the concatenated channels.
    if include_aux_input:
        unet_inputs = [x_src, x_tgt, x_seg]
        unet_input_shape = img_shape[:-1] + (img_shape[-1] * 2 +
                                             aux_input_shape[-1],)
    else:
        unet_inputs = [x_src, x_tgt]
        unet_input_shape = img_shape[:-1] + (img_shape[-1] * 2,)
    x_stacked = Concatenate(axis=-1)(unet_inputs)

    # Pick the dimension-specific builder and final conv once, then make a single call.
    if n_dims == 2:
        unet_fn, conv_fn = unet2D, Conv2D
    else:
        unet_fn, conv_fn = unet3D, Conv3D

    color_delta = unet_fn(
        x_stacked,
        unet_input_shape,
        n_output_chans,
        nf_enc=enc_params['nf_enc'],
        nf_dec=enc_params['nf_dec'],
        n_convs_per_stage=enc_params['n_convs_per_stage'],
    )

    # last conv to get the output shape that we want
    color_delta = conv_fn(n_output_chans,
                          kernel_size=3,
                          padding='same',
                          name='color_delta')(color_delta)

    transformed_out = Add(name='add_color_delta')([x_src, color_delta])
    if do_warp_to_target_space:
        transformed_out = SpatialTransformer(indexing='xy')(
            [transformed_out, flow_srctotgt])

    # kind of silly, but do a reshape so keras doesn't complain about returning an input
    x_seg = Reshape(aux_input_shape, name='aux')(x_seg)

    return Model(inputs=inputs,
                 outputs=[color_delta, transformed_out, x_seg],
                 name=model_name)