def cvpr2018_net(vol_size, enc_nf, dec_nf, full_size=True, indexing='ij'):
    """Build the VoxelMorph CVPR 2018 registration network.

    Adapted from https://github.com/voxelmorph/voxelmorph. A U-Net core
    consumes the stacked (source, target) pair and its features are mapped by
    a final convolution to a dense flow field; a spatial transformer then
    warps the source volume by that flow.

    :param vol_size: volume size, e.g. (256, 256, 256); must be 3-D.
    :param enc_nf: list of encoder filter counts (1x4), e.g. [16, 32, 32, 32].
    :param dec_nf: list of decoder filter counts (1x6 for voxelmorph-1 or
        1x7 for voxelmorph-2).
    :param full_size: kept for interface compatibility; not used in this body.
    :param indexing: 'ij' or 'xy' indexing convention for the transformer.
    :return: a keras Model mapping [src, tgt] -> [warped src, flow].
    """
    import keras.layers as KL
    from keras.initializers import RandomNormal

    n_dims = len(vol_size)
    assert n_dims == 3, "ndims should be 3. found: %d" % n_dims

    # Stack the moving and fixed volumes along the channel axis.
    src = Input(vol_size + (1,), name='input_src')
    tgt = Input(vol_size + (1,), name='input_tgt')
    stacked = Concatenate(name='concat_inputs')([src, tgt])

    # U-Net core producing per-voxel features.
    features = unet3D(stacked, img_shape=vol_size, out_im_chans=n_dims,
                      nf_enc=enc_nf, nf_dec=dec_nf)

    # Final conv maps features to an n_dims-channel flow field; the tiny
    # stddev initializer keeps the initial deformation near identity.
    conv_layer = getattr(KL, 'Conv%dD' % n_dims)
    flow = conv_layer(n_dims, kernel_size=3, padding='same', name='flow',
                      kernel_initializer=RandomNormal(mean=0.0, stddev=1e-5))(features)

    # Warp the source volume with the predicted flow.
    warped = SpatialTransformer(interp_method='linear',
                                indexing=indexing)([src, flow])

    return Model(inputs=[src, tgt], outputs=[warped, flow])
def color_delta_unet_model(img_shape, n_output_chans, model_name='color_delta_unet',
                           enc_params=None, include_aux_input=False,
                           aux_input_shape=None, do_warp_to_target_space=False):
    """Build a U-Net that predicts an additive per-pixel color delta.

    The network takes a source image, a target image, and an auxiliary input
    (e.g. a segmentation), predicts a color delta, and adds it to the source.
    Optionally the result is warped to target space with a supplied flow.

    :param img_shape: input image shape including channels, e.g. (H, W, C).
    :param n_output_chans: number of channels in the predicted color delta.
    :param model_name: name given to the returned keras Model.
    :param enc_params: dict with keys 'nf_enc', 'nf_dec' and
        'n_convs_per_stage' configuring the U-Net. Required.
    :param include_aux_input: if True, the aux input is concatenated into the
        U-Net input stack as well.
    :param aux_input_shape: shape of the aux input; defaults to img_shape.
    :param do_warp_to_target_space: if True, an extra flow input is taken and
        the transformed output is warped into target space at the end.
    :return: Model mapping inputs -> [color_delta, transformed_out, aux].
    :raises ValueError: if enc_params is not provided.
    """
    if enc_params is None:
        # Fail fast with a clear message; previously this fell through to an
        # opaque "'NoneType' object is not subscriptable" TypeError below.
        raise ValueError("enc_params must be a dict with keys "
                         "'nf_enc', 'nf_dec' and 'n_convs_per_stage'")

    n_dims = len(img_shape) - 1  # spatial dimensionality (2 or 3)

    x_src = Input(img_shape, name='input_src')
    x_tgt = Input(img_shape, name='input_tgt')

    if aux_input_shape is None:
        aux_input_shape = img_shape
    x_seg = Input(aux_input_shape, name='input_src_aux')

    inputs = [x_src, x_tgt, x_seg]

    if do_warp_to_target_space:
        # warp transformed vol to target space in the end
        flow_srctotgt = Input(img_shape[:-1] + (n_dims,), name='input_flow')
        inputs += [flow_srctotgt]

    # The U-Net sees src and tgt stacked channel-wise, optionally with aux.
    if include_aux_input:
        unet_inputs = [x_src, x_tgt, x_seg]
        unet_input_shape = img_shape[:-1] + (img_shape[-1] * 2 + aux_input_shape[-1],)
    else:
        unet_inputs = [x_src, x_tgt]
        unet_input_shape = img_shape[:-1] + (img_shape[-1] * 2,)
    x_stacked = Concatenate(axis=-1)(unet_inputs)

    # Pick the dimensionality-appropriate U-Net and conv layer once, then
    # make a single configuration call (previously duplicated per branch).
    if n_dims == 2:
        unet_fn, conv_fn = unet2D, Conv2D
    else:
        unet_fn, conv_fn = unet3D, Conv3D

    color_delta = unet_fn(
        x_stacked, unet_input_shape, n_output_chans,
        nf_enc=enc_params['nf_enc'],
        nf_dec=enc_params['nf_dec'],
        n_convs_per_stage=enc_params['n_convs_per_stage'],
    )

    # last conv to get the output shape that we want
    color_delta = conv_fn(n_output_chans, kernel_size=3, padding='same',
                          name='color_delta')(color_delta)

    transformed_out = Add(name='add_color_delta')([x_src, color_delta])

    if do_warp_to_target_space:
        transformed_out = SpatialTransformer(indexing='xy')(
            [transformed_out, flow_srctotgt])

    # kind of silly, but do a reshape so keras doesnt complain about returning an input
    x_seg = Reshape(aux_input_shape, name='aux')(x_seg)

    return Model(inputs=inputs,
                 outputs=[color_delta, transformed_out, x_seg],
                 name=model_name)