Example #1
0
def color_delta_unet_model(img_shape,
                           n_output_chans,
                           model_name='color_delta_unet',
                           enc_params=None,
                           include_aux_input=False,
                           aux_input_shape=None,
                           do_warp_to_target_space=False
                           ):
    """Build a U-Net that predicts an additive color delta for a source image.

    The source and target images (and optionally the auxiliary input) are
    concatenated along the last axis and fed to ``unet2D``/``unet3D``. A final
    convolution produces ``color_delta``, which is added to the source image.

    :param img_shape: shape of one input image with a trailing channel axis,
        e.g. ``(H, W, C)`` (2D) or ``(D, H, W, C)`` (3D); must be a tuple.
    :param n_output_chans: number of channels of the predicted delta. Presumably
        must match ``img_shape[-1]`` so the ``Add`` with the source works —
        TODO confirm.
    :param model_name: name of the returned keras ``Model``.
    :param enc_params: dict with keys ``'nf_enc'``, ``'nf_dec'`` and
        ``'n_convs_per_stage'`` forwarded to the U-Net builder. Required.
    :param include_aux_input: if True, the auxiliary input is also stacked into
        the U-Net input (it is always a model input and output either way).
    :param aux_input_shape: shape of the auxiliary input; defaults to
        ``img_shape``.
    :param do_warp_to_target_space: if True, an extra ``input_flow`` input is
        added and the transformed image is warped with it before output.
    :return: keras ``Model`` with inputs ``[src, tgt, aux]`` (plus ``flow`` when
        warping) and outputs ``[transformed, color_delta, aux]``.
    :raises TypeError: if ``enc_params`` is not provided.
    """
    if enc_params is None:
        # Fail fast with a clear message instead of a "'NoneType' object is not
        # subscriptable" error after part of the graph has been built.
        raise TypeError("color_delta_unet_model requires enc_params with keys "
                        "'nf_enc', 'nf_dec' and 'n_convs_per_stage'")

    n_dims = len(img_shape) - 1  # number of spatial axes (channels last)
    if aux_input_shape is None:
        aux_input_shape = img_shape

    x_src = Input(img_shape, name='input_src')
    x_tgt = Input(img_shape, name='input_tgt')
    x_seg = Input(aux_input_shape, name='input_src_aux')
    inputs = [x_src, x_tgt, x_seg]

    if do_warp_to_target_space:
        # warp transformed vol to target space in the end
        flow_srctotgt = Input(img_shape[:-1] + (n_dims,), name='input_flow')
        inputs.append(flow_srctotgt)

    if include_aux_input:
        unet_inputs = [x_src, x_tgt, x_seg]
        unet_in_chans = img_shape[-1] * 2 + aux_input_shape[-1]
    else:
        unet_inputs = [x_src, x_tgt]
        unet_in_chans = img_shape[-1] * 2
    unet_input_shape = img_shape[:-1] + (unet_in_chans,)
    x_stacked = Concatenate(axis=-1)(unet_inputs)

    if n_dims == 2:
        unet_fn, conv_fn = unet2D, Conv2D
    else:
        unet_fn, conv_fn = unet3D, Conv3D

    color_delta = unet_fn(x_stacked, unet_input_shape, n_output_chans,
                          nf_enc=enc_params['nf_enc'],
                          nf_dec=enc_params['nf_dec'],
                          n_convs_per_stage=enc_params['n_convs_per_stage'],
                          )

    # last conv to get the output shape that we want
    color_delta = conv_fn(n_output_chans, kernel_size=3, padding='same', name='color_delta')(color_delta)

    transformed_out = Add(name='add_color_delta')([x_src, color_delta])
    if do_warp_to_target_space:
        transformed_out = SpatialTransformer(indexing='xy')([transformed_out, flow_srctotgt])

    # hacky, but do a reshape so keras doesnt complain about returning an input
    x_seg = Reshape(aux_input_shape, name='aux')(x_seg)

    return Model(inputs=inputs, outputs=[transformed_out, color_delta, x_seg], name=model_name)
Example #2
0
def warp_model(img_shape, interp_mode='linear', indexing='ij'):
    """Return a keras Model that resamples an image with a dense flow field.

    :param img_shape: image shape with a trailing channel axis (tuple).
    :param interp_mode: interpolation method handed to SpatialTransformer.
    :param indexing: coordinate indexing convention handed to SpatialTransformer.
    :return: Model named 'warp_model' mapping ``[image, flow] -> warped image``;
        the flow input has the image's spatial shape with one channel per
        spatial axis.
    """
    n_spatial = len(img_shape) - 1
    image = Input(img_shape, name='input_img')
    flow = Input(img_shape[:-1] + (n_spatial,), name='input_flow')

    warper = SpatialTransformer(interp_method=interp_mode,
                                indexing=indexing,
                                name='densespatialtransformer_img')
    warped = warper([image, flow])

    return Model(inputs=[image, flow], outputs=warped, name='warp_model')
Example #3
0
def randflow_model(
    img_shape,
    model,
    model_name='randflow_model',
    flow_sigma=None,
    flow_amp=None,
    blur_sigma=5,
    interp_mode='linear',
    indexing='xy',
):
    """Build a model that warps its input image with a random flow field.

    The flow comes from ``RandFlow`` (when ``flow_amp`` is None, including when
    neither parameter is given) or ``RandFlow_Uniform`` (when only ``flow_amp``
    is given). In 3D the flow is generated on a grid downsampled 4x by two 2x
    max-pools and upsampled back with two ``interp_upsampling`` calls; in 2D it
    is generated at full resolution.

    :param img_shape: input image shape with a trailing channel axis (tuple).
        In 3D the spatial dims presumably must be divisible by 4 for the
        upsampled flow to match — TODO confirm.
    :param model: optional keras model applied to the warped image; if None the
        outputs are ``[x_warped, flow]``.
    :param model_name: name of the returned keras ``Model``.
    :param flow_sigma: parameter for ``RandFlow``; mutually exclusive with
        ``flow_amp``.
    :param flow_amp: parameter for ``RandFlow_Uniform``; mutually exclusive with
        ``flow_sigma``.
    :param blur_sigma: blur parameter for the flow layer (divided by 4 and
        rounded up in 3D to match the downsampled grid).
    :param interp_mode: interpolation method for the SpatialTransformer.
    :param indexing: indexing convention for the SpatialTransformer.
    :return: keras ``Model`` with the single input ``img_input_randwarp``.
    :raises ValueError: if both ``flow_sigma`` and ``flow_amp`` are given.
    """
    if flow_sigma is not None and flow_amp is not None:
        # Previously neither branch ran and the image itself was reshaped and
        # used as the "flow", producing a silently wrong model or a confusing
        # Reshape error.
        raise ValueError('specify at most one of flow_sigma and flow_amp')

    n_dims = len(img_shape) - 1
    full_flow_shape = img_shape[:-1] + (n_dims, )

    x_in = Input(img_shape, name='img_input_randwarp')

    if n_dims == 3:
        # generate the flow on a grid downsampled 4x (two 2x max-pools)
        flow = MaxPooling3D(2)(x_in)
        flow = MaxPooling3D(2)(flow)
        blur_sigma = int(np.ceil(blur_sigma / 4.))
        flow_shape = tuple([int(s / 4) for s in img_shape[:-1]] + [n_dims])
    else:
        flow = x_in
        flow_shape = full_flow_shape

    # random flow field
    if flow_amp is None:
        flow = RandFlow(name='randflow',
                        img_shape=flow_shape,
                        blur_sigma=blur_sigma,
                        flow_sigma=flow_sigma)(flow)
    else:
        flow = RandFlow_Uniform(name='randflow',
                                img_shape=flow_shape,
                                blur_sigma=blur_sigma,
                                flow_amp=flow_amp)(flow)

    if n_dims == 3:
        flow = Reshape(flow_shape)(flow)
        # upsample with linear interpolation (2x twice -> back to full size)
        flow = Lambda(interp_upsampling)(flow)
        flow = Lambda(interp_upsampling, output_shape=full_flow_shape)(flow)
    flow = Reshape(full_flow_shape, name='randflow_out')(flow)

    x_warped = SpatialTransformer(interp_method=interp_mode,
                                  name='densespatialtransformer_img',
                                  indexing=indexing)([x_in, flow])

    if model is not None:
        model_outputs = model(x_warped)
        if not isinstance(model_outputs, list):
            model_outputs = [model_outputs]
    else:
        model_outputs = [x_warped, flow]
    return Model(inputs=[x_in], outputs=model_outputs, name=model_name)
Example #4
0
def interp_upsampling(V):
    """
    Upsample a dense field by a factor of 2 along every spatial axis.

    Builds the coordinate grid of the 2x-sized output and resamples ``V`` with a
    linear SpatialTransformer using the displacement ``x / 2 - x``. If
    SpatialTransformer adds that displacement to the identity grid, each output
    location x reads the input at x / 2.
    TODO: should switch this to use neuron.utils.interpn()

    :param V: field tensor, presumably (batch, *spatial, chans) with all
              non-batch dims statically known (they feed a static reshape)
              — TODO confirm.
    :return: the tensor produced by SpatialTransformer.
    """
    non_batch_dims = V.get_shape().as_list()[1:]
    V = tf.reshape(V, [-1] + non_batch_dims)

    upsampled_vol_shape = [d * 2 for d in non_batch_dims[:-1]]
    displacements = []
    for axis_coords in volshape_to_ndgrid(upsampled_vol_shape):
        axis_coords = tf.cast(axis_coords, 'float32')
        displacements.append(tf.expand_dims(axis_coords / 2 - axis_coords, 0))
    # stack per-axis displacements along a new trailing axis
    offset = tf.stack(displacements, len(displacements) + 1)

    return SpatialTransformer(interp_method='linear')([V, offset])
Example #5
0
def cvpr2018_net(vol_size,
                 enc_nf,
                 dec_nf,
                 full_size=True,
                 indexing='ij',
                 model_name=''):
    """
    From https://github.com/voxelmorph/voxelmorph.

    unet architecture for voxelmorph models presented in the CVPR 2018 paper.
    You may need to modify this code (e.g., number of layers) to suit your project needs.

    :param vol_size: volume size, e.g. (256, 256, 256). Any 3-element sequence
           is accepted; it is converted to a tuple.
    :param enc_nf: list of encoder filters. right now it needs to be 1x4.
           e.g. [16,32,32,32]
    :param dec_nf: list of decoder filters. right now it must be 1x6 (like voxelmorph-1) or 1x7 (voxelmorph-2)
    :param full_size: not used by this implementation; kept for API compatibility.
    :param indexing: indexing convention passed to SpatialTransformer.
    :param model_name: name of the returned keras Model.
    :return: the keras model, mapping [src, tgt] -> [warped src, flow]
    :raises AssertionError: if vol_size does not have exactly 3 dimensions.
    """
    # `vol_size + (1, )` below requires a tuple; accept lists too.
    vol_size = tuple(vol_size)
    ndims = len(vol_size)
    if ndims != 3:
        # Explicit raise rather than `assert` so the check is not stripped under
        # `python -O`; AssertionError kept for backward compatibility.
        raise AssertionError("ndims should be 3. found: %d" % ndims)

    # imported lazily, after input validation
    import tensorflow.keras.layers as KL
    from tensorflow.keras.initializers import RandomNormal

    src = Input(vol_size + (1, ), name='input_src')
    tgt = Input(vol_size + (1, ), name='input_tgt')

    input_stack = Concatenate(name='concat_inputs')([src, tgt])

    # get the core model
    # NOTE(review): img_shape here omits the channel axis even though
    # input_stack has 2 channels — confirm this matches unet3D's expectations.
    x = unet3D(input_stack,
               img_shape=vol_size,
               out_im_chans=ndims,
               nf_enc=enc_nf,
               nf_dec=dec_nf)

    # transform the results into a flow field; near-zero init so the initial
    # flow is close to zero
    Conv = getattr(KL, 'Conv%dD' % ndims)
    flow = Conv(ndims,
                kernel_size=3,
                padding='same',
                name='flow',
                kernel_initializer=RandomNormal(mean=0.0, stddev=1e-5))(x)

    # warp the source with the flow
    y = SpatialTransformer(interp_method='linear',
                           indexing=indexing)([src, flow])
    # prepare model
    model = Model(inputs=[src, tgt], outputs=[y, flow], name=model_name)
    return model