Example #1
def randflow_ronneberger_model(img_shape,
                               model,
                               model_name='randflow_ronneberger_model',
                               flow_sigma=None,
                               flow_amp=None,
                               blur_sigma=5,
                               interp_mode='linear',
                               indexing='xy',
                               ):
    n_dims = len(img_shape) - 1

    x_in = Input(img_shape, name='img_input_randwarp')

    if n_dims == 3:
        n_pools = 5

        flow = MaxPooling3D(2)(x_in)
        for i in range(n_pools-1):
            flow = MaxPooling3D(2)(flow)
        # downsample the flow by a factor of 2**n_pools (32) until it is roughly 3x3x3
        flow_shape = tuple([int(s / (2 ** n_pools)) for s in img_shape[:-1]] + [n_dims])
        print('Smallest flow shape: {}'.format(flow_shape))
    else:
        flow = x_in
        flow_shape = img_shape[:-1] + (n_dims,)
    # random flow field
    if flow_amp is None:
        # sigma and blurring are hand-tuned so that, after smooth upsampling,
        # the result resembles a Gaussian flow with stddev = 10
        flow = RandFlow(name='randflow', img_shape=flow_shape, blur_sigma=0., flow_sigma=flow_sigma * 8)(flow)

    if n_dims == 3:
        print(flow_shape)
        print(flow.get_shape())
        flow = Reshape(flow_shape)(flow)
        flow_shape = flow_shape[:-1]
        for i in range(n_pools):
            flow_shape = [fs * 2 for fs in flow_shape]
            flow = Lambda(interp_upsampling, output_shape=tuple(flow_shape) + (n_dims,))(flow)
            if 0 < i < 4:
                print(flow_shape)
                flow = BlurFlow(img_shape=tuple(flow_shape) + (n_dims,),
                                blur_sigma=5)(flow)

        flow = basic_networks._pad_or_crop_to_shape(flow, flow_shape, img_shape)
        flow = BlurFlow(img_shape[:-1] + (n_dims,), blur_sigma=3)(flow)
        flow = Reshape(img_shape[:-1] + (n_dims,), name='randflow_out')(flow)
    else:
        flow = Reshape(img_shape[:-1] + (n_dims,), name='randflow_out')(flow)

    x_warped = SpatialTransformer(
        indexing=indexing, interp_method=interp_mode, name='densespatialtransformer_img')([x_in, flow])

    if model is not None:
        model_outputs = model(x_warped)
        if not isinstance(model_outputs, list):
            model_outputs = [model_outputs]
    else:
        model_outputs = [x_warped, flow]
    return Model(inputs=[x_in], outputs=model_outputs, name=model_name)
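A usage sketch for the wrapper above, assuming RandFlow, BlurFlow, SpatialTransformer and the other helpers are importable from the surrounding codebase; the shape and sigma values are illustrative assumptions:

import numpy as np

# Hypothetical usage: generate randomly warped volumes for augmentation.
aug = randflow_ronneberger_model(
    img_shape=(160, 192, 224, 1),  # 3-D volume, one channel (assumed)
    model=None,                    # None -> the model outputs [warped image, flow]
    flow_sigma=2.,                 # multiplied by 8 internally (see above)
)
vol = np.zeros((1, 160, 192, 224, 1), dtype='float32')
warped, flow = aug.predict(vol)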
Example #2
def generator_net(vol_size, enc_nf, dec_nf, vel_resize, ti_flow, int_steps):

    ndims = len(vol_size)
    assert ndims in [
        1, 2, 3
    ], "ndims should be one of 1, 2, or 3. found: {}".format(ndims)

    unet = unet_core(vol_size, enc_nf, dec_nf, vel_resize, ti_flow)

    # target delta in binary representation
    b_in = Input(shape=(16, ))

    x_in = unet.inputs[0]
    x_out, x_ti = unet.outputs

    inputs = [x_in, b_in]

    Conv = getattr(KL, 'Conv{}D'.format(ndims))

    # time dependent flow component (i.e. aging)
    flow_mean = Conv(ndims,
                     kernel_size=3,
                     padding='same',
                     name='flow_mean',
                     kernel_initializer=RandomNormal(mean=0.0,
                                                     stddev=1e-5))(x_out)

    flow_log_sigma = Conv(
        ndims,
        kernel_size=3,
        padding='same',
        name='flow_log_sigma',
        kernel_initializer=RandomNormal(mean=0.0, stddev=1e-10),
        bias_initializer=keras.initializers.Constant(value=-10))(x_out)

    flow_params = concatenate([flow_mean, flow_log_sigma])

    # sample velocity field (using reparameterization)
    z_sample = Sample(name='z_sample')([flow_mean, flow_log_sigma])

    # integrate flow
    flow = VecInt(method='ss', name='flow_int',
                  int_steps=int_steps)([z_sample, b_in])

    # resize to full size
    if vel_resize != 1.0:
        flow = trf_resize(flow, vel_resize, name='flow_resize')

    # time independent flow component
    flow_ti = Conv(ndims, kernel_size=3, padding='same', name='flow_ti')(x_ti)
    flow_ti = trf_resize(flow_ti, 0.25, name='flow_ti_resize')

    if ti_flow:
        flow = add([flow, flow_ti])

    y = SpatialTransformer(interp_method='linear', indexing='ij')([x_in, flow])

    outputs = [y, flow_params, flow_ti]

    return Model(inputs=inputs, outputs=outputs)
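The Sample layer is not defined in this snippet. In VoxelMorph-style code it typically implements the reparameterization trick, z = mu + exp(log_sigma / 2) * eps; a minimal Keras sketch under that assumption:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class Sample(Layer):
    # Reparameterized sampling: z = mu + exp(log_sigma / 2) * eps, eps ~ N(0, I).
    def call(self, inputs):
        mu, log_sigma = inputs
        eps = tf.random.normal(tf.shape(mu))
        return mu + tf.exp(log_sigma / 2.0) * eps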
Example #3
def warp_model(img_shape, interp_mode='linear', indexing='ij'):
    n_dims = len(img_shape) - 1
    img_in = Input(img_shape, name='input_img')
    flow_in = Input(img_shape[:-1] + (n_dims,), name='input_flow')

    img_warped = SpatialTransformer(
        interp_method=interp_mode, indexing=indexing, name='densespatialtransformer_img')([img_in, flow_in])

    return Model(inputs=[img_in, flow_in], outputs=img_warped, name='warp_model')
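A quick usage check: with an all-zero flow the dense SpatialTransformer should act as the identity (assuming it interprets the flow as per-pixel displacements, as in neuron); the shapes are illustrative:

import numpy as np

model = warp_model(img_shape=(64, 64, 1))
img = np.random.rand(1, 64, 64, 1).astype('float32')
zero_flow = np.zeros((1, 64, 64, 2), dtype='float32')
warped = model.predict([img, zero_flow])
assert np.allclose(warped, img, atol=1e-4)  # identity up to interpolation error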
Example #4
def color_delta_unet_model(img_shape,
                           n_output_chans,
                           model_name='color_delta_unet',
                           enc_params=None,
                           include_aux_input=False,
                           aux_input_shape=None,
                           do_warp_to_target_space=False
                           ):
    x_src = Input(img_shape, name='input_src')
    x_tgt = Input(img_shape, name='input_tgt')
    if aux_input_shape is None:
        aux_input_shape = img_shape

    x_seg = Input(aux_input_shape, name='input_src_aux')
    inputs = [x_src, x_tgt, x_seg]

    if do_warp_to_target_space: # warp transformed vol to target space in the end
        n_dims = len(img_shape) - 1
        flow_srctotgt = Input(img_shape[:-1] + (n_dims,), name='input_flow')
        inputs += [flow_srctotgt]

    if include_aux_input:
        unet_inputs = [x_src, x_tgt, x_seg]
        unet_input_shape = img_shape[:-1] + (img_shape[-1] * 2 + aux_input_shape[-1],)
    else:
        unet_inputs = [x_src, x_tgt]
        unet_input_shape = img_shape[:-1] + (img_shape[-1] * 2,)
    x_stacked = Concatenate(axis=-1)(unet_inputs)

    n_dims = len(img_shape) - 1

    if n_dims == 2:
        color_delta = basic_networks.unet2D(x_stacked, unet_input_shape, n_output_chans,
                                            nf_enc=enc_params['nf_enc'],
                                            nf_dec=enc_params['nf_dec'],
                                            n_convs_per_stage=enc_params['n_convs_per_stage'],
                                            include_residual=False)
        conv_fn = Conv2D
    else:
        color_delta = basic_networks.unet3D(x_stacked, unet_input_shape, n_output_chans,
                                            nf_enc=enc_params['nf_enc'],
                                            nf_dec=enc_params['nf_dec'],
                                            n_convs_per_stage=enc_params['n_convs_per_stage'],
                                            include_residual=False)
        conv_fn = Conv3D

    # last conv to get the output shape that we want
    color_delta = conv_fn(n_output_chans, kernel_size=3, padding='same', name='color_delta')(color_delta)

    transformed_out = Add(name='add_color_delta')([x_src, color_delta])
    if do_warp_to_target_space:
        transformed_out = SpatialTransformer(indexing='xy')([transformed_out, flow_srctotgt])

    # a bit silly, but reshape so Keras doesn't complain about returning an input unchanged
    x_seg = Reshape(aux_input_shape, name='aux')(x_seg)

    return Model(inputs=inputs, outputs=[color_delta, transformed_out, x_seg], name=model_name)
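A hypothetical call; the encoder/decoder widths and image shape below are illustrative assumptions, not values from the source:

enc_params = {
    'nf_enc': [32, 64, 64, 64],
    'nf_dec': [64, 64, 64, 32],
    'n_convs_per_stage': 2,
}
model = color_delta_unet_model(
    img_shape=(128, 128, 3),
    n_output_chans=3,
    enc_params=enc_params,
)
# outputs: [color_delta, transformed_out, aux segmentation passthrough]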
Example #5
def perform_test_affine_transform(arr):
    """
    perform an identity affine transformation for testing
    :param arr: image array
    :return:
    """
    tx = create_identity_transform_stn()
    ST = SpatialTransformer(interp_method='linear', indexing='ij')
    x = ST([tf.convert_to_tensor(arr), tf.convert_to_tensor(tx)])
    plot_side2side(arr, x, 'before after')
    return 1
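create_identity_transform_stn is not shown here. Assuming neuron's SpatialTransformer accepts a (batch, ndims, ndims + 1) affine matrix, one plausible implementation of the missing helper for the 3-D identity:

import numpy as np

def create_identity_transform_stn(n_dims=3):
    # identity affine [I | 0] with shape (1, n_dims, n_dims + 1);
    # assumes the transformer treats a matrix-shaped input as affine
    mat = np.concatenate([np.eye(n_dims), np.zeros((n_dims, 1))], axis=1)
    return mat[np.newaxis, ...].astype('float32')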
Example #6
    def build_morph(self, enc_nf, dec_nf):
        # instantiate the morph main body
        morph = self.morph_main_body(enc_nf, dec_nf)
        # unpack the inputs
        moving_images, fixed_images = morph.input
        # instantiate the spatial transformer and compute the warped output
        moved_images = SpatialTransformer(interp_method='linear',
                                          indexing='ij',
                                          name='moved_images')(
                                              [moving_images, morph.output])
        # return the model
        return Model([moving_images, fixed_images],
                     [moved_images, morph.output])
Example #7
def interp_upsampling(V):
    """
    upsample a field by a factor of 2
    TODO: should switch this to use neuron.utils.interpn()
    """
    V = tf.reshape(V, [-1] + V.get_shape().as_list()[1:])
    grid = volshape_to_ndgrid([f*2 for f in V.get_shape().as_list()[1:-1]])
    grid = [tf.cast(f, 'float32') for f in grid]
    grid = [tf.expand_dims(f/2 - f, 0) for f in grid]
    offset = tf.stack(grid, len(grid) + 1)

    V = SpatialTransformer(interp_method='linear')([V, offset])
    return V
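A quick shape check, assuming interp_upsampling and its dependencies (volshape_to_ndgrid, SpatialTransformer) are in scope; the field size is illustrative:

import numpy as np
import tensorflow as tf

field = tf.convert_to_tensor(np.random.rand(1, 8, 8, 8, 3).astype('float32'))
up = interp_upsampling(field)
print(up.shape)  # expected (1, 16, 16, 16, 3): spatial dims doubled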
Example #8
    def create_morph(self, enc_nf, dec_nf):
        # input layers
        moving_images = Input(shape=[*self.vol_shape, self.channels],
                              name='moving_images')
        fixed_images = Input(shape=[*self.vol_shape, self.channels],
                             name='fixed_images')
        # concatenate moving and fixed images
        x_in = concatenate([moving_images, fixed_images])
        # downsampling via strided convolutions
        x_enc = [x_in]

        for i in range(len(enc_nf)):
            if i == 0:
                x_enc.append(self.conv_block(x_enc[-1], enc_nf[i]))
            else:
                x_enc.append(self.conv_block(x_enc[-1], enc_nf[i], 2))

        # upsampling path
        x = UpSampling2D()(x_enc[-1])
        x = concatenate([x, x_enc[-2]])
        x = self.conv_block(x, dec_nf[0])

        x = UpSampling2D()(x)
        x = concatenate([x, x_enc[-3]])
        x = self.conv_block(x, dec_nf[1])

        x = UpSampling2D()(x)
        x = concatenate([x, x_enc[-4]])
        x = self.conv_block(x, dec_nf[2])

        x = self.conv_block(x, dec_nf[3])

        x = UpSampling2D()(x)
        x = concatenate([x, x_enc[-5]])
        x = self.conv_block(x, dec_nf[4])

        # convert features into the deformation field
        Fai = Conv2D(filters=2,
                     kernel_size=3,
                     padding='same',
                     kernel_initializer='he_normal',
                     strides=1,
                     name='Fai')(x)

        # warp the moving images
        moved_images = SpatialTransformer(interp_method='linear',
                                          indexing='ij',
                                          name='moved_images')(
                                              [moving_images, Fai])
        return Model([moving_images, fixed_images], [moved_images, Fai])
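Note that the decoder path above consumes five dec_nf stages and indexes x_enc[-5], so enc_nf needs at least four entries. A hypothetical call from inside the class (vol_shape, channels, and filter counts follow common VoxelMorph 2-D settings and are assumptions):

# inside the class, with self.vol_shape = (128, 128) and self.channels = 1
enc_nf = [16, 32, 32, 32]        # four encoder stages -> len(x_enc) == 5
dec_nf = [32, 32, 32, 32, 16]    # five decoder stages are consumed above
model = self.create_morph(enc_nf, dec_nf)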
Example #9
def randflow_model(img_shape,
                   model,
                   model_name='randflow_model',
                   flow_sigma=None,
                   flow_amp=None,
                   blur_sigma=5,
                   interp_mode='linear',
                   indexing='xy',
                   ):
    n_dims = len(img_shape) - 1

    x_in = Input(img_shape, name='img_input_randwarp')

    if n_dims == 3:
        flow = MaxPooling3D(2)(x_in)
        flow = MaxPooling3D(2)(flow)
        blur_sigma = int(np.ceil(blur_sigma / 4.))
        flow_shape = tuple([int(s/4) for s in img_shape[:-1]] + [n_dims])
    else:
        flow = x_in
        flow_shape = img_shape[:-1] + (n_dims,)

    # random flow field
    if flow_amp is None:
        flow = RandFlow(name='randflow', img_shape=flow_shape, blur_sigma=blur_sigma, flow_sigma=flow_sigma)(flow)
    elif flow_sigma is None:
        flow = RandFlow_Uniform(name='randflow', img_shape=flow_shape, blur_sigma=blur_sigma, flow_amp=flow_amp)(flow)

    if n_dims == 3:
        flow = Reshape(flow_shape)(flow)
        # upsample with linear interpolation
        flow = Lambda(interp_upsampling)(flow)
        flow = Lambda(interp_upsampling, output_shape=img_shape[:-1] + (n_dims,))(flow)
        flow = Reshape(img_shape[:-1] + (n_dims,), name='randflow_out')(flow)
    else:
        flow = Reshape(img_shape[:-1] + (n_dims,), name='randflow_out')(flow)

    x_warped = SpatialTransformer(interp_method=interp_mode, name='densespatialtransformer_img', indexing=indexing)(
        [x_in, flow])

    if model is not None:
        model_outputs = model(x_warped)
        if not isinstance(model_outputs, list):
            model_outputs = [model_outputs]
    else:
        model_outputs = [x_warped, flow]
    return Model(inputs=[x_in], outputs=model_outputs, name=model_name)
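A usage sketch; the two branches above are mutually exclusive, so set exactly one of flow_sigma or flow_amp. Shapes and magnitudes are illustrative assumptions:

# Gaussian-noise flow:
aug = randflow_model(img_shape=(64, 64, 64, 1), model=None, flow_sigma=5.)
# or a uniform-amplitude flow instead (leave flow_sigma=None):
# aug = randflow_model(img_shape=(64, 64, 64, 1), model=None, flow_amp=10.)
# with model=None the outputs are [warped image, dense flow]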
Example #10
    def __init__(self, fixed, moving, x, y, z, masked_loss=True):
        """
        Parameters of the class:
        :param fixed: fixed image array
        :param moving: moving image array to be registered
        :param x: translation in X as a trainable tf variable
        :param y: translation in Y as a trainable tf variable
        :param z: translation in Z as a trainable tf variable
        :param masked_loss: whether to compute the loss w.r.t. the masked head only (remove background)
        """
        self.stn = SpatialTransformer(interp_method='linear', indexing='ij')
        self.iteration = 0
        self.fixed = fixed
        self.moving = moving
        self.masked_loss = masked_loss
        self.x = x
        self.y = y
        self.z = z
        self.losses = []
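The class's optimization step is not shown. One plausible TF2 sketch, assuming the three scalars are broadcast into a constant dense flow (the function name, shapes, and learning rate are assumptions):

import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=0.1)

def translation_step(reg):
    # reg: an instance of the class above, with fixed/moving as (1, D, H, W, 1) tensors
    with tf.GradientTape() as tape:
        # broadcast the three scalar variables into a dense (1, D, H, W, 3) flow
        spatial = tf.shape(reg.moving)[1:-1]
        ones = tf.ones(tf.concat([[1], spatial], axis=0))
        flow = tf.stack([ones * reg.x, ones * reg.y, ones * reg.z], axis=-1)
        moved = reg.stn([reg.moving, flow])
        loss = tf.reduce_mean(tf.square(reg.fixed - moved))
    grads = tape.gradient(loss, [reg.x, reg.y, reg.z])
    opt.apply_gradients(zip(grads, [reg.x, reg.y, reg.z]))
    reg.losses.append(float(loss))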
Example #11
def apply_transformation(
    x_source,
    x_transformation,
    output_shape,
    conditioning_input_shape,
    transform_name,
    apply_flow_transform=True,
    apply_color_transform=False,
    flow_indexing='xy',
    color_transform_type='WB',
):
    n_dims = len(conditioning_input_shape) - 1

    transformation_shape = x_transformation.get_shape().as_list()[1:]
    x_transformation = Reshape(
        transformation_shape,
        name='{}_dec_out'.format(transform_name))(x_transformation)
    if apply_flow_transform:
        # apply flow transform
        im_out = SpatialTransformer(name='spatial_transformer',
                                    indexing=flow_indexing)(
                                        [x_source, x_transformation])

    elif apply_color_transform:
        # apply color transform
        print('Applying color transform {}'.format(color_transform_type))
        if color_transform_type == 'delta':
            x_color_out = Add()([x_source, x_transformation])
        elif color_transform_type == 'mult':
            x_color_out = Multiply()([x_source, x_transformation])
        else:
            raise NotImplementedError(
                'Only color transform types delta and mult are supported!')
        im_out = Reshape(output_shape, name='color_transformer')(x_color_out)
    else:
        im_out = x_transformation

    return im_out, x_transformation
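Illustrative wiring: warp a 2-D RGB image by a dense flow input (shapes and names below are assumptions):

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

src = Input((64, 64, 3), name='src')
flow = Input((64, 64, 2), name='flow')
im_out, _ = apply_transformation(
    src, flow,
    output_shape=(64, 64, 3),
    conditioning_input_shape=(64, 64, 3),
    transform_name='demo',
)
warper = Model([src, flow], im_out)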
Example #12
def cvae_learned_prior_trainer_wrapper(
        ae_input_shapes,
        ae_input_names,
        conditioning_input_shapes,
        conditioning_input_names,
        latent_distribution='normal',
        output_shape=None,
        model_name='transformer_trainer',
        seg_model=None,
        transform_encoder_model=None,
        transformer_model=None,
        prior_encoder_model=None,
        transform_type='flow',
        transform_latent_shape=(50, ),
        include_aug_matrix=False,
        n_outputs=1,
):
    '''
    VTE transformer trainer model
        - takes I, I+J as input
        - encodes I+J to z
        - condition_on_image = True means the transform is decoded from the transform+image embedding;
          otherwise it is decoded from only the transform embedding
        - decodes the latent embedding into a transform and applies it
    '''
    ae_inputs, ae_stack, conditioning_inputs, cond_stack = _collect_inputs(
        ae_input_shapes, ae_input_names, conditioning_input_shapes,
        conditioning_input_names)
    conditioning_input_shape = cond_stack.get_shape().as_list()[1:]

    inputs = ae_inputs + conditioning_inputs

    if include_aug_matrix:
        T_in = Input((3, 3), name='transform_input')
        inputs += [T_in]
        ae_inputs = [
            SpatialTransformer(name='st_affine_{}'.format(ae_input_names[ii]))(
                [ae_input, T_in]) for ii, ae_input in enumerate(ae_inputs)
        ]

        conditioning_inputs = [
            SpatialTransformer(
                name='st_affine_{}'.format(conditioning_input_names[ii]))(
                    [cond_input, T_in])
            for ii, cond_input in enumerate(conditioning_inputs)
        ]

    segs = seg_model(conditioning_inputs[0])

    if latent_distribution == 'normal':
        z_mean_prior, z_logvar_prior = prior_encoder_model(
            conditioning_inputs)  # + [segs])
        z_mean_prior = Reshape(transform_latent_shape,
                               name='latent_mean_prior')(z_mean_prior)
        z_logvar_prior = Reshape(transform_latent_shape,
                                 name='latent_logvar_prior')(z_logvar_prior)

        # encode x_stacked into z
        z_mean, z_logvar = transform_encoder_model(ae_inputs)  # + [segs])

        z_mean = Reshape(transform_latent_shape, name='latent_mean')(z_mean)
        z_logvar = Reshape(transform_latent_shape,
                           name='latent_logvar')(z_logvar)

        z_sampled = Lambda(transform_network_utils.sampling,
                           output_shape=transform_latent_shape,
                           name='lambda_sampling')([z_mean, z_logvar])

        z_outputs = [z_mean, z_logvar, z_mean_prior, z_logvar_prior]
    elif latent_distribution == 'categorical':
        z_prior = prior_encoder_model(conditioning_inputs)  # + [segs])
        z_post = transform_encoder_model(ae_inputs)

        temperature = 5.
        z_sampled = Lambda(
            lambda x: gumbel_softmax(x, temperature=temperature, hard=False),
            name='lambda_sampling_gumbel')(z_post)

        # normalize outputs to probabilities, since that is what the KL loss expects
        z_prior = Activation('softmax',
                             name='latent_categorical_prior')(z_prior)
        z_post = Activation('softmax', name='latent_categorical')(z_post)

        z_outputs = [z_post, z_prior]
    decoder_out = transformer_model(conditioning_inputs + [segs, z_sampled])

    if transform_type == 'flow':
        im_out, transform_out = decoder_out
        transform_shape = transform_out.get_shape().as_list()[1:]

        transform_out = Reshape(transform_shape,
                                name='decoder_flow_out')(transform_out)
        im_out = Reshape(output_shape, name='spatial_transformer')(im_out)
    elif transform_type == 'color':
        im_out, transform_out = decoder_out

        transform_out = Reshape(output_shape,
                                name='decoder_color_out')(transform_out)
        im_out = Reshape(output_shape, name='color_transformer')(im_out)
    else:
        im_out = decoder_out

    if transform_type is not None:
        return Model(inputs=inputs,
                     outputs=[im_out] * n_outputs + [transform_out] +
                     z_outputs,
                     name=model_name)
    else:
        return Model(inputs=inputs,
                     outputs=[im_out] * n_outputs + z_outputs,
                     name=model_name)
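gumbel_softmax is not defined in this snippet. The standard Gumbel-softmax relaxation adds Gumbel(0, 1) noise to the logits and anneals with a temperature; a minimal sketch under that assumption:

import tensorflow as tf

def gumbel_softmax(logits, temperature=1.0, hard=False):
    # sample Gumbel(0, 1) noise and apply the softmax relaxation
    u = tf.random.uniform(tf.shape(logits), minval=1e-8, maxval=1.0)
    g = -tf.math.log(-tf.math.log(u))
    y = tf.nn.softmax((logits + g) / temperature)
    if hard:
        # straight-through estimator: one-hot forward pass, soft gradients
        y_hard = tf.one_hot(tf.argmax(y, axis=-1), tf.shape(y)[-1])
        y = tf.stop_gradient(y_hard - y) + y
    return y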
Example #13
def cvae_trainer_wrapper(
        ae_input_shapes,
        ae_input_names,
        conditioning_input_shapes,
        conditioning_input_names,
        output_shape=None,
        model_name='transformer_trainer',
        transform_encoder_model=None,
        transformer_model=None,
        transform_type='flow',
        transform_latent_shape=(50, ),
        include_aug_matrix=False,
        n_outputs=1,
):
    '''
    VTE transformer trainer model
        - takes I, I+J as input
        - encodes I+J to z
        - condition_on_image = True means the transform is decoded from the transform+image embedding;
          otherwise it is decoded from only the transform embedding
        - decodes the latent embedding into a transform and applies it
    '''
    ae_inputs, ae_stack, conditioning_inputs, cond_stack = _collect_inputs(
        ae_input_shapes, ae_input_names, conditioning_input_shapes,
        conditioning_input_names)
    conditioning_input_shape = cond_stack.get_shape().as_list()[1:]

    inputs = ae_inputs + conditioning_inputs

    if include_aug_matrix:
        T_in = Input((3, 3), name='transform_input')
        inputs += [T_in]
        ae_inputs = [
            SpatialTransformer(name='st_affine_{}'.format(ae_input_names[ii]))(
                [ae_input, T_in]) for ii, ae_input in enumerate(ae_inputs)
        ]

        conditioning_inputs = [
            SpatialTransformer(
                name='st_affine_{}'.format(conditioning_input_names[ii]))(
                    [cond_input, T_in])
            for ii, cond_input in enumerate(conditioning_inputs)
        ]
    # encode x_stacked into z
    z_mean, z_logvar = transform_encoder_model(ae_inputs)

    z_mean = Reshape(transform_latent_shape, name='latent_mean')(z_mean)
    z_logvar = Reshape(transform_latent_shape, name='latent_logvar')(z_logvar)

    z_sampled = Lambda(transform_network_utils.sampling,
                       output_shape=transform_latent_shape,
                       name='lambda_sampling')([z_mean, z_logvar])

    decoder_out = transformer_model(conditioning_inputs + [z_sampled])

    if transform_type == 'flow':
        im_out, transform_out = decoder_out
        transform_shape = transform_out.get_shape().as_list()[1:]

        transform_out = Reshape(transform_shape,
                                name='decoder_flow_out')(transform_out)
        im_out = Reshape(output_shape, name='spatial_transformer')(im_out)
    elif transform_type == 'color':
        im_out, transform_out = decoder_out

        transform_out = Reshape(output_shape,
                                name='decoder_color_out')(transform_out)
        im_out = Reshape(output_shape, name='color_transformer')(im_out)
    else:
        im_out = decoder_out

    if transform_type is not None:
        return Model(inputs=inputs,
                     outputs=[im_out] * n_outputs +
                     [transform_out, z_mean, z_logvar],
                     name=model_name)
    else:
        return Model(inputs=inputs,
                     outputs=[im_out] * n_outputs + [z_mean, z_logvar],
                     name=model_name)
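transform_network_utils.sampling is not shown; the Lambda layer above presumably uses the standard VAE reparameterization. A minimal sketch under that assumption:

import tensorflow as tf

def sampling(args):
    # z = mu + exp(logvar / 2) * eps, with eps ~ N(0, I)
    z_mean, z_logvar = args
    eps = tf.random.normal(tf.shape(z_mean))
    return z_mean + tf.exp(z_logvar / 2.0) * eps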