def img_atlas_diff_model(vol_shape, nf_enc, nf_dec,
                         atl_mult=1.0,
                         bidir=True,
                         smooth_pen_layer='diffflow',
                         atl_int_steps=3,
                         vel_resize=1 / 2,
                         int_steps=3,
                         mean_cap=100,
                         atl_layer_name='atlas',
                         **kwargs):
    """Stack an atlas-image model in front of a diffeomorphic registration net.

    The registration network comes from ``diff_net`` and the atlas model from
    ``atl_img_model``. The two are joined with ``nrn_utils.stack_models`` so
    that output 0 of the atlas model feeds input 0 of the registration net.

    Args:
        vol_shape: volume shape, forwarded to ``diff_net`` and ``atl_img_model``.
        nf_enc, nf_dec: encoder / decoder feature specs for ``diff_net``.
        atl_mult: forwarded to ``atl_img_model`` as ``mult``.
        bidir: forwarded to ``diff_net``.
        smooth_pen_layer: name of the registration-net layer whose latest
            output is returned last (presumably the flow used for a smoothness
            penalty -- confirm against the loss configuration).
        atl_int_steps: accepted for API compatibility; not used here.
        vel_resize, int_steps: forwarded to ``diff_net``.
        mean_cap: ``cap`` argument of the ``MeanStream`` layer.
        atl_layer_name: forwarded to ``atl_img_model``.
        **kwargs: extra keyword arguments forwarded to ``diff_net``.

    Returns:
        A ``keras.models.Model`` whose inputs are the registration net's
        inputs. Its outputs are
        ``[warped_src, warped_tgt, MeanStream(neg_diffflow), <smooth_pen_layer>]``.
    """
    reg_net = diff_net(vol_shape, nf_enc, nf_dec,
                       int_steps=int_steps,
                       bidir=bidir,
                       vel_resize=vel_resize,
                       **kwargs)

    # NOTE(review): the atlas model is handed the registration net's first
    # input tensor as `src`; the original author flagged this wiring as
    # confusing -- confirm it is intended.
    atlas_net = atl_img_model(vol_shape,
                              mult=atl_mult,
                              src=reg_net.inputs[0],
                              atl_layer_name=atl_layer_name)

    # After stacking, `.outputs` may be reordered, so every tensor below is
    # fetched by layer name instead of by position.
    stacked = nrn_utils.stack_models([atlas_net, reg_net], [[0]])

    def latest_output(model, layer_name):
        # Most recent output node of the named layer (i.e. the stacked graph).
        return model.get_layer(layer_name).get_output_at(-1)

    # TODO: unverified whether the running mean should track the negative
    # flow (as done here) or the forward flow.
    neg_flow = latest_output(stacked, 'neg_diffflow')
    flow_mean = nrn_layers.MeanStream(name='mean_stream', cap=mean_cap)(neg_flow)

    warped_src = latest_output(stacked, 'warped_src')
    warped_tgt = latest_output(stacked, 'warped_tgt')
    smooth_out = latest_output(reg_net, smooth_pen_layer)

    model_outputs = [warped_src, warped_tgt, flow_mean, smooth_out]
    return keras.models.Model(reg_net.inputs, model_outputs)
def cond_img_atlas_diff_model(vol_shape, nf_enc, nf_dec,
                              atl_mult=1.0,
                              bidir=True,
                              smooth_pen_layer='diffflow',
                              vel_resize=1 / 2,
                              int_steps=5,
                              nb_conv_features=32,
                              cond_im_input_shape=None,
                              cond_nb_levels=5,
                              cond_conv_size=None,
                              use_stack=True,
                              do_mean_layer=True,
                              pheno_input_shape=None,
                              atlas_feats=1,
                              name='cond_model',
                              mean_cap=100,
                              templcondsi=False,
                              templcondsi_init=None,
                              full_size=False,
                              ret_vm=False,
                              extra_conv_layers=0,
                              **kwargs):
    """Build a registration model whose atlas is conditioned on a phenotype.

    A phenotype vector is decoded into an additive offset. The offset is added
    to a provided ``atlas_input`` volume to form the atlas. That atlas is fed
    as input 0 of a ``diff_net`` registration network; input 1 of that network
    is kept as a model input.

    Args:
        vol_shape: volume shape; its length selects ``Conv1D/2D/3D``.
        nf_enc, nf_dec: encoder / decoder feature specs for ``diff_net``.
        atl_mult: accepted for API compatibility; not used here.
        bidir, vel_resize, int_steps, full_size, **kwargs: forwarded to ``diff_net``.
        smooth_pen_layer: name of the ``diff_net`` layer whose latest output
            is returned as the flow output (stacked mode only).
        nb_conv_features: feature count of the conditional decoder.
        cond_im_input_shape: shape the dense phenotype embedding is reshaped
            to before decoding. Defaults to ``[10, 12, 14, 1]``.
        cond_nb_levels: number of levels of ``nrn_models.conv_dec``.
        cond_conv_size: kernel size of the decoder / extra convs. Defaults to
            ``[3, 3, 3]``.
        use_stack: if True, join the atlas generator and ``diff_net`` via
            ``nrn_utils.stack_models``. Otherwise call ``diff_net`` directly;
            this requires ``bidir=True`` and ``smooth_pen_layer='diffflow'``.
        do_mean_layer: if True, append a ``MeanStream`` of the negative flow
            to the outputs.
        pheno_input_shape: shape of the phenotype input. Defaults to ``[1]``.
        atlas_feats: number of channels of the atlas.
        name: name of the returned model.
        mean_cap: ``cap`` argument of ``MeanStream``.
        templcondsi: if True, channel 0 of the atlas is replaced by a
            bias-free 1x1 conv applied to ``softmax`` of channels ``1:``.
        templcondsi_init: optional array reshaped into that 1x1 conv kernel.
        ret_vm: if True, also return the ``diff_net`` model.
        extra_conv_layers: number of additional convs after the decoder.

    Returns:
        ``model`` or, if ``ret_vm``, ``(model, mn)``. The model's inputs are
        ``[pheno_input, atlas_input, <diff_net input 1>]``. Its outputs are
        ``[warped_src, warped_tgt, mean_stream, flow]``, without
        ``mean_stream`` when ``do_mean_layer`` is False.

    Raises:
        ValueError: if ``use_stack`` is False while ``bidir`` is False or
            ``smooth_pen_layer != 'diffflow'``.
    """
    # Build fresh default lists per call instead of sharing mutable defaults.
    if cond_im_input_shape is None:
        cond_im_input_shape = [10, 12, 14, 1]
    if cond_conv_size is None:
        cond_conv_size = [3, 3, 3]
    if pheno_input_shape is None:
        pheno_input_shape = [1]

    # The non-stacked path unpacks diff_net's direct outputs and hard-codes
    # their meaning, so it only supports this configuration. Check up front
    # (assert would be stripped under -O).
    if not use_stack:
        if not bidir:
            raise ValueError('use_stack=False requires bidir=True')
        if smooth_pen_layer != 'diffflow':
            raise ValueError("use_stack=False requires smooth_pen_layer='diffflow'")

    # conv layer class matching the volume dimensionality
    Conv = getattr(KL, 'Conv%dD' % len(vol_shape))

    # Registration network. Its input 0 (the source) is replaced by the
    # generated atlas below. Without stacking it is asked to return its flow
    # fields as extra outputs (ret_flows).
    mn = diff_net(vol_shape, nf_enc, nf_dec, int_steps=int_steps, bidir=bidir,
                  src_feats=atlas_feats, full_size=full_size,
                  vel_resize=vel_resize, ret_flows=(not use_stack), **kwargs)

    # Conditional atlas generator:
    # phenotype -> dense (elu) -> reshape -> conv decoder -> optional convs.
    pheno_input = KL.Input(pheno_input_shape, name='pheno_input')
    dense_tensor = KL.Dense(np.prod(cond_im_input_shape), activation='elu')(pheno_input)
    reshape_tensor = KL.Reshape(cond_im_input_shape)(dense_tensor)
    pheno_init_model = keras.models.Model(pheno_input, reshape_tensor)
    pheno_tmp_model = nrn_models.conv_dec(nb_conv_features, cond_im_input_shape,
                                          cond_nb_levels, cond_conv_size,
                                          nb_labels=nb_conv_features,
                                          final_pred_activation='linear',
                                          input_model=pheno_init_model,
                                          name='atlasmodel')
    last_tensor = pheno_tmp_model.output
    for i in range(extra_conv_layers):
        last_tensor = Conv(nb_conv_features, kernel_size=cond_conv_size,
                           padding='same', name='atlas_ec_%d' % i)(last_tensor)

    # Tiny (stddev 1e-7) initialisation keeps the conditional offset very
    # small at the start of training.
    pout = Conv(atlas_feats, kernel_size=3, padding='same', name='atlasmodel_c',
                kernel_initializer=RandomNormal(mean=0.0, stddev=1e-7),
                bias_initializer=RandomNormal(mean=0.0, stddev=1e-7))(last_tensor)

    atlas_input = KL.Input([*vol_shape, atlas_feats], name='atlas_input')
    if not templcondsi:
        atlas_tensor = KL.Add(name='atlas')([atlas_input, pout])
    else:
        atlas_tensor = KL.Add(name='atlas_tmp')([atlas_input, pout])

        # Replace channel 0 with a 1x1 conv over softmax(channels 1:).
        # tmp_layer is only an intermediate; do not use it elsewhere.
        tmp_layer = KL.Lambda(lambda x: K.softmax(x[..., 1:]))(atlas_tensor)
        cl = Conv(1, kernel_size=1, padding='same', use_bias=False, name='atlas_gen',
                  kernel_initializer=RandomNormal(mean=0, stddev=1e-5))
        ximg = cl(tmp_layer)
        if templcondsi_init is not None:
            w = cl.get_weights()
            w[0] = templcondsi_init.reshape(w[0].shape)
            cl.set_weights(w)
        atlas_tensor = KL.Lambda(lambda x: K.concatenate([x[0], x[1][..., 1:]]),
                                 name='atlas')([ximg, atlas_tensor])

    pheno_model = keras.models.Model([pheno_tmp_model.input, atlas_input], atlas_tensor)

    # Full model inputs: phenotype, atlas volume, and diff_net's second input.
    inputs = pheno_model.inputs + [mn.inputs[1]]

    if use_stack:
        # Feed the generated atlas into input 0 of diff_net; fetch tensors by
        # layer name because stacking may reorder outputs.
        sm = nrn_utils.stack_models([pheno_model, mn], [[0]])
        neg_diffflow_out = sm.get_layer('neg_diffflow').get_output_at(-1)
        diffflow_out = mn.get_layer(smooth_pen_layer).get_output_at(-1)
        warped_src = sm.get_layer('warped_src').get_output_at(-1)
        warped_tgt = sm.get_layer('warped_tgt').get_output_at(-1)
    else:
        # Assumes diff_net(ret_flows=True) returns these five outputs in this
        # order -- verify against diff_net.
        warped_src, warped_tgt, _, diffflow_out, neg_diffflow_out = mn(
            pheno_model.outputs + [mn.inputs[1]])

    if do_mean_layer:
        mean_layer = nrn_layers.MeanStream(name='mean_stream', cap=mean_cap)(neg_diffflow_out)
        outputs = [warped_src, warped_tgt, mean_layer, diffflow_out]
    else:
        outputs = [warped_src, warped_tgt, diffflow_out]

    model = keras.models.Model(inputs, outputs, name=name)
    if ret_vm:
        return model, mn
    return model