示例#1
0
                                    nf_dec,
                                    int_steps=4,
                                    full_size=False,
                                    use_miccai_int=False)

# Choose which registration loss to train with.
#   False (what this example originally ended up doing): replace the model
#         built by the preceding call with a CVPR 2018 network and train it
#         with MSE + L2 gradient smoothness.
#   True: keep the model built by the preceding call and train it with the
#         MICCAI 2018 loss.
# Previously both configurations were written out one after the other, so the
# MICCAI 2018 loss was constructed and then silently discarded.
use_miccai2018_loss = False

# Hyper-parameters passed to losses.Miccai2018.
image_sigma = 0.02
prior_lambda = 10

if use_miccai2018_loss:
    # Drop the first and last axes of the flow output
    # (presumably batch and channels — confirm against the model).
    flow_vol_shape = vxm_model.outputs[-1].shape[1:-1]
    loss_class = losses.Miccai2018(image_sigma,
                                   prior_lambda,
                                   flow_vol_shape=flow_vol_shape)
    model_losses = [loss_class.recon_loss, loss_class.kl_loss]
    loss_weights = [1, 1]
else:
    vxm_model = networks.cvpr2018_net(vol_shape, nf_enc, nf_dec, full_size=False)
    model_losses = ['mse', losses.Grad('l2').loss]
    # usually, we have to balance the two losses by a hyper-parameter.
    lambda_param = 0.01
    loss_weights = [10, lambda_param]

lr = 1e-4
decay_rate = lr / 10
momentum = 0.8
# NOTE: `lr=` / `decay=` are the Keras 2.x SGD argument names; newer Keras
# releases expect `learning_rate=` and no longer accept `decay`.
vxm_model.compile(optimizer=keras.optimizers.SGD(lr=lr,
                                                 momentum=momentum,
                                                 decay=decay_rate,
                                                 nesterov=False),
                  loss=model_losses,
                  loss_weights=loss_weights)
示例#2
0
    def _create_flow_model(self):
        # parse the flow architecture name to create the correct model
        if 'flow_fwd' in self.arch_params['model_arch'] \
                or 'flow_bck' in self.arch_params['model_arch']:
            # train a fwd model only
            nf_enc = [16, 32, 32, 32]
            nf_dec = [32, 32, 32, 32, 32, 16, 16]

            self.transform_model = networks.cvpr2018_net(
                vol_size=(160, 192, 224),
                enc_nf=nf_enc,
                dec_nf=nf_dec,
                indexing='xy'
            )
            self.transform_model.name = self.arch_params['model_arch']

            self.models = [self.transform_model]
        elif 'bidir_separate' in self.arch_params['model_arch']:
            # train a fwd model and back model
            nf_enc = [16, 32, 32, 32]
            nf_dec = [32, 32, 32, 32, 32, 16, 16]

            self.flow_bck_model = networks.cvpr2018_net(
                vol_size=(160, 192, 224),
                enc_nf=nf_enc,
                dec_nf=nf_dec,
                indexing='xy'
            )
            self.flow_bck_model.name = 'vm_bidir_bck_model'
            self.flow_models = [self.flow_bck_model]

            # vm2 model
            self.flow_fwd_model = networks.cvpr2018_net(
                vol_size=(160, 192, 224),
                enc_nf=nf_enc,
                dec_nf=nf_dec,
                indexing='xy'
            )
            self.flow_fwd_model.name = 'vm_bidir_fwd_model'

            self.transform_model = networks.bidir_wrapper(
                img_shape=self.img_shape,
                fwd_model=self.flow_fwd_model,
                bck_model=self.flow_bck_model,
            )

            self.models += [self.flow_fwd_model, self.flow_bck_model, self.transform_model]
        else:
            raise NotImplementedError('Only separate bidirectional spatial transform models are implemented in this version!')

        if 'init_weights_from' in self.arch_params.keys():
            from keras.models import load_model
            # this is not the right indexing, but it doesnt matter since we are only loading conv weights
            init_weights_from_models = [
                load_model(
                    m,
                    custom_objects={
                        'SpatialTransformer': nrn_layers.SpatialTransformer
                    },
                    compile=False
                    ) if m is not None else None for m in self.arch_params['init_weights_from']
            ]

            for mi, m in enumerate(self.models):
                # nothing to load from for this model, skip it
                if mi >= len(init_weights_from_models) or init_weights_from_models[mi] is None:
                    continue

                for li, l in enumerate(m.layers):
                    if li >= len(init_weights_from_models[mi].layers):
                        break

                    # TODO: this assumes matching layer nums, roughly...
                    init_from_layer = init_weights_from_models[mi].layers[li]
                    if 'conv' in l.name.lower()	and 'conv' in init_from_layer.name.lower():
                        our_weights = l.get_weights()
                        init_from_weights = init_from_layer.get_weights()

                        if np.all(our_weights[0].shape == init_from_weights[0].shape):
                            m.layers[li].set_weights(init_from_weights)
                            self.logger.debug('Copying weights from {} layer {} to {} layer {}'.format(
                                init_weights_from_models[mi].name,
                                init_from_layer.name,
                                m.name,
                                l.name))
                        else:
                            self.logger.debug('Unable to copy weights from {} layer {} to {} layer {}, shapes {} and {}'.format(
                                init_weights_from_models[mi].name,
                                init_from_layer.name,
                                m.name,
                                l.name,
                                our_weights[0].shape,
                                init_from_weights[0].shape
                            ))
            #self.flow_fwd_model, self.flow_bck_model = self.models[:2]
            if 'bidir_separate' in self.arch_params['model_arch']:
                # recreate wrapper?
                self.transform_model = networks.bidir_wrapper(
                    img_shape=self.img_shape,
                    fwd_model=self.models[0],
                    bck_model=self.models[1],
                )
                self.models[-1] = self.transform_model