Beispiel #1
0
def img_atlas_diff_model(vol_shape,
                         nf_enc,
                         nf_dec,
                         atl_mult=1.0,
                         bidir=True,
                         smooth_pen_layer='diffflow',
                         atl_int_steps=3,
                         vel_resize=1 / 2,
                         int_steps=3,
                         mean_cap=100,
                         atl_layer_name='atlas',
                         **kwargs):
    """Stack an atlas model in front of a `diff_net` model and wrap the result.

    Args:
        vol_shape: forwarded to both `diff_net` and `atl_img_model`.
        nf_enc: forwarded to `diff_net`.
        nf_dec: forwarded to `diff_net`.
        atl_mult: forwarded to `atl_img_model` as `mult`.
        bidir: forwarded to `diff_net`. The layers 'neg_diffflow' and
            'warped_tgt' are looked up by name below, so the network built
            by `diff_net` must contain them.
        smooth_pen_layer: name of the `diff_net` layer whose output becomes
            the last model output.
        atl_int_steps: accepted but not referenced anywhere in this function.
        vel_resize: forwarded to `diff_net`.
        int_steps: forwarded to `diff_net`.
        mean_cap: forwarded to `nrn_layers.MeanStream` as `cap`.
        atl_layer_name: forwarded to `atl_img_model`.
        **kwargs: forwarded to `diff_net`.

    Returns:
        A `keras.models.Model` whose inputs are the `diff_net` model's inputs
        and whose outputs are, in order: the 'warped_src' layer output, the
        'warped_tgt' layer output, the 'mean_stream' layer output (computed
        from the 'neg_diffflow' layer output), and the `smooth_pen_layer`
        output.
    """

    def _latest_output(net, layer_name):
        # Take the most recently created output node of the named layer.
        return net.get_layer(layer_name).get_output_at(-1)

    # registration (vm) model
    registration_net = diff_net(vol_shape,
                                nf_enc,
                                nf_dec,
                                int_steps=int_steps,
                                bidir=bidir,
                                vel_resize=vel_resize,
                                **kwargs)

    # pre-warp (atlas) model, built on the registration net's first input
    # NOTE(review): the original author flagged uncertainty about this call.
    atlas_net = atl_img_model(vol_shape,
                              mult=atl_mult,
                              src=registration_net.inputs[0],
                              atl_layer_name=atl_layer_name)

    # Stack the two models. The stacked model's own `outputs` list might be
    # out of order, so every output below is fetched by layer name instead.
    stacked = nrn_utils.stack_models([atlas_net, registration_net], [[0]])

    # TODO: I'm not sure the mean layer is the right direction
    mean_stream = nrn_layers.MeanStream(name='mean_stream', cap=mean_cap)
    mean_out = mean_stream(_latest_output(stacked, 'neg_diffflow'))

    warped = [
        _latest_output(stacked, layer_name)
        for layer_name in ('warped_src', 'warped_tgt')
    ]
    smooth_out = _latest_output(registration_net, smooth_pen_layer)

    return keras.models.Model(registration_net.inputs,
                              [*warped, mean_out, smooth_out])
Beispiel #2
0
def cond_img_atlas_diff_model(vol_shape,
                              nf_enc,
                              nf_dec,
                              atl_mult=1.0,
                              bidir=True,
                              smooth_pen_layer='diffflow',
                              vel_resize=1 / 2,
                              int_steps=5,
                              nb_conv_features=32,
                              cond_im_input_shape=[10, 12, 14, 1],
                              cond_nb_levels=5,
                              cond_conv_size=[3, 3, 3],
                              use_stack=True,
                              do_mean_layer=True,
                              pheno_input_shape=[1],
                              atlas_feats=1,
                              name='cond_model',
                              mean_cap=100,
                              templcondsi=False,
                              templcondsi_init=None,
                              full_size=False,
                              ret_vm=False,
                              extra_conv_layers=0,
                              **kwargs):
    """Build a conditional atlas model feeding a `diff_net` model.

    A small decoder turns a "pheno" input into a volume that is added to an
    atlas input; the sum (layer 'atlas') replaces the first input of the
    `diff_net` model.

    Args:
        vol_shape: spatial shape; its length selects `Conv%dD` and it is
            combined with `atlas_feats` to form the atlas input shape.
        nf_enc: forwarded to `diff_net`.
        nf_dec: forwarded to `diff_net`.
        atl_mult: accepted but not referenced anywhere in this function.
        bidir: forwarded to `diff_net`; must be True when `use_stack` is
            False.
        smooth_pen_layer: name of the `diff_net` layer used as the last
            output when stacking; must be 'diffflow' when `use_stack` is
            False.
        vel_resize: forwarded to `diff_net`.
        int_steps: forwarded to `diff_net`.
        nb_conv_features: feature count for the decoder and for the extra
            conv layers.
        cond_im_input_shape: shape the dense projection of the pheno input
            is reshaped to before decoding.
        cond_nb_levels: forwarded to `nrn_models.conv_dec`.
        cond_conv_size: kernel size for the decoder and the extra conv
            layers.
        use_stack: if True, combine the models with
            `nrn_utils.stack_models`; otherwise call the `diff_net` model
            directly on the atlas tensor.
        do_mean_layer: if True, add a 'mean_stream' output computed from the
            negative flow.
        pheno_input_shape: shape of the 'pheno_input' layer.
        atlas_feats: number of atlas channels (last dim of 'atlas_input',
            filters of 'atlasmodel_c', `src_feats` of `diff_net`).
        name: name of the returned model.
        mean_cap: forwarded to `nrn_layers.MeanStream` as `cap`.
        templcondsi: if True, channel 0 of the atlas is replaced by a 1x1
            bias-free conv ('atlas_gen') applied to a softmax of the
            remaining channels.
        templcondsi_init: optional array; when given it is reshaped to, and
            installed as, the first weight array of 'atlas_gen'.
        full_size: forwarded to `diff_net`.
        ret_vm: if True, also return the `diff_net` model.
        extra_conv_layers: number of additional conv layers ('atlas_ec_%d')
            inserted after the decoder.
        **kwargs: forwarded to `diff_net`.

    Returns:
        The assembled `keras.models.Model`, or `(model, mn)` when `ret_vm`
        is True. Model inputs are the atlas sub-model's inputs followed by
        the second input of the `diff_net` model. Outputs are
        `[warped_src, warped_tgt, mean_stream, diffflow]`, without
        `mean_stream` when `do_mean_layer` is False.

    Raises:
        ValueError: if `use_stack` is False and either `bidir` is False or
            `smooth_pen_layer` is not 'diffflow'.
    """
    # The non-stacked path unpacks exactly five outputs from `mn` and always
    # uses the 'diffflow' output, so its preconditions are checked up front
    # with real exceptions (a bare `assert` is stripped under `python -O`),
    # before any layer is created.
    if not use_stack:
        if not bidir:
            raise ValueError('use_stack=False requires bidir=True')
        if smooth_pen_layer != 'diffflow':
            raise ValueError(
                "use_stack=False requires smooth_pen_layer='diffflow', "
                'got %r' % (smooth_pen_layer,))

    # conv layer class
    Conv = getattr(KL, 'Conv%dD' % len(vol_shape))

    # vm model. inputs: "atlas" (replaced below by the generated atlas
    # tensor) and a second input that is passed through unchanged as
    # mn.inputs[1].
    mn = diff_net(vol_shape,
                  nf_enc,
                  nf_dec,
                  int_steps=int_steps,
                  bidir=bidir,
                  src_feats=atlas_feats,
                  full_size=full_size,
                  vel_resize=vel_resize,
                  ret_flows=(not use_stack),
                  **kwargs)

    # pre-warp model (atlas model): pheno vector -> dense -> reshape ->
    # conv decoder
    pheno_input = KL.Input(pheno_input_shape, name='pheno_input')
    dense_tensor = KL.Dense(np.prod(cond_im_input_shape),
                            activation='elu')(pheno_input)
    reshape_tensor = KL.Reshape(cond_im_input_shape)(dense_tensor)
    pheno_init_model = keras.models.Model(pheno_input, reshape_tensor)
    pheno_tmp_model = nrn_models.conv_dec(nb_conv_features,
                                          cond_im_input_shape,
                                          cond_nb_levels,
                                          cond_conv_size,
                                          nb_labels=nb_conv_features,
                                          final_pred_activation='linear',
                                          input_model=pheno_init_model,
                                          name='atlasmodel')

    # optional extra conv layers on top of the decoder output
    last_tensor = pheno_tmp_model.output
    for i in range(extra_conv_layers):
        last_tensor = Conv(nb_conv_features,
                           kernel_size=cond_conv_size,
                           padding='same',
                           name='atlas_ec_%d' % i)(last_tensor)

    # Final projection to the atlas channels. Kernel and bias start from a
    # very small random normal (stddev 1e-7).
    pout = Conv(atlas_feats,
                kernel_size=3,
                padding='same',
                name='atlasmodel_c',
                kernel_initializer=RandomNormal(mean=0.0, stddev=1e-7),
                bias_initializer=RandomNormal(mean=0.0,
                                              stddev=1e-7))(last_tensor)
    atlas_input = KL.Input([*vol_shape, atlas_feats], name='atlas_input')
    if not templcondsi:
        atlas_tensor = KL.Add(name='atlas')([atlas_input, pout])
    else:
        atlas_tensor = KL.Add(name='atlas_tmp')([atlas_input, pout])

        # change first channel to be result from seg with another add layer
        tmp_layer = KL.Lambda(lambda x: K.softmax(x[..., 1:]))(
            atlas_tensor)  # this is just tmp. Do not use me.
        cl = Conv(1,
                  kernel_size=1,
                  padding='same',
                  use_bias=False,
                  name='atlas_gen',
                  kernel_initializer=RandomNormal(mean=0, stddev=1e-5))
        ximg = cl(tmp_layer)
        if templcondsi_init is not None:
            # overwrite the first weight array of 'atlas_gen' with the
            # caller-supplied values
            w = cl.get_weights()
            w[0] = templcondsi_init.reshape(w[0].shape)
            cl.set_weights(w)
        # generated channel first, then channels 1: of the summed atlas
        atlas_tensor = KL.Lambda(
            lambda x: K.concatenate([x[0], x[1][..., 1:]]),
            name='atlas')([ximg, atlas_tensor])

    pheno_model = keras.models.Model([pheno_tmp_model.input, atlas_input],
                                     atlas_tensor)

    # stack models
    inputs = pheno_model.inputs + [mn.inputs[1]]

    if use_stack:
        sm = nrn_utils.stack_models([pheno_model, mn], [[0]])
        neg_diffflow_out = sm.get_layer('neg_diffflow').get_output_at(-1)
        diffflow_out = mn.get_layer(smooth_pen_layer).get_output_at(-1)
        warped_src = sm.get_layer('warped_src').get_output_at(-1)
        warped_tgt = sm.get_layer('warped_tgt').get_output_at(-1)

    else:
        # preconditions for this path were validated at function entry
        warped_src, warped_tgt, _, diffflow_out, neg_diffflow_out = mn(
            pheno_model.outputs + [mn.inputs[1]])
        # NOTE(review): `sm` is not used after this point; the construction
        # is retained so the build sequence matches the original.
        sm = keras.models.Model(inputs, [warped_src, warped_tgt])

    if do_mean_layer:
        mean_layer = nrn_layers.MeanStream(name='mean_stream',
                                           cap=mean_cap)(neg_diffflow_out)
        outputs = [warped_src, warped_tgt, mean_layer, diffflow_out]
    else:
        outputs = [warped_src, warped_tgt, diffflow_out]

    model = keras.models.Model(inputs, outputs, name=name)
    if ret_vm:
        return model, mn
    else:
        return model