Example #1
# NOTE(review): this snippet appears truncated — `sys`, `main`, and
# `mri_loader` must be imported above for it to run.
sys.path.append('../neuron')
import neuron.layers as nrn_layers

sys.path.append('../voxelmorph')
import src.losses as vm_losses

# load validation or test dataset
do_final_test = True
ds_key = 'mri-csts2-test'
# label mapping used for evaluation (defined in the project's main module)
label_mapping = main.voxelmorph_labels

eval_data_params = main.named_data_params[ds_key]
# presumably loads full volumes into memory instead of lazily — confirm
# against mri_loader.MRIDataset
eval_data_params['load_vols'] = True

eval_ds = mri_loader.MRIDataset(eval_data_params)
_ = eval_ds.load_dataset()

# sanity check: print the labeled validation files that were found
for f in eval_ds.files_labeled_valid:
    print(f)

#gpu_ids = [3]
# set gpu id and tf settings
#os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(g) for g in gpu_ids])
#config = tf.ConfigProto()
#config.gpu_options.allow_growth = True
#K.tensorflow_backend.set_session(tf.Session(config=config))

# load trained segmenters
model_files = [
    ### PUT YOUR TRAINED .h5 MODEL FILES HERE ####
Example #2
    def __init__(self,
                 data_params,
                 arch_params,
                 debug=False,
                 prompt_delete=True):
        """Configure a segmenter trainer from dataset and architecture params.

        Args:
            data_params: dict of dataset options (img_shape, aug_* flags,
                n_*_aug counts, aug_in_gen, ...).
            arch_params: dict of architecture options (n_seg_dims, n_aug_dims,
                loss selection keys, tm_flow_model / tm_color_model paths,
                patience, ...).
            debug: unused in this initializer; kept for interface
                compatibility.
            prompt_delete: forwarded to the base trainer as
                prompt_delete_existing.
        """
        self.logger = None
        self.profiler_logger = None

        self.arch_params = arch_params
        self.data_params = data_params

        # i.e. are we segmenting/augmenting slices or volumes
        self.n_seg_dims = arch_params['n_seg_dims']
        self.n_aug_dims = arch_params['n_aug_dims']

        # append a trailing channel dimension of 1
        self.pred_img_shape = data_params['img_shape'][:self.n_seg_dims] + (
            1, )
        self.aug_img_shape = data_params['img_shape'][:self.n_aug_dims] + (1, )

        self.display_slice_idx = 112

        self.do_profile = True
        self.profiled_iters = 0

        self.epoch_count = 0
        self.batch_count = 0

        # pretraining uses an L2 reconstruction loss; otherwise train the
        # segmenter with categorical cross-entropy
        if 'pretrain_l2' in self.arch_params:
            self.loss_fn = keras_metrics.mean_squared_error
            self.loss_name = 'l2'
        else:
            self.loss_fn = keras_metrics.categorical_crossentropy
            self.loss_name = 'CE'

        # warp the onehot representation of labels instead of labels themselves
        if 'warpoh' not in self.arch_params:
            self.arch_params['warpoh'] = False

        self.n_aug = None  # do augmentation through the generator by default

        # augmentation-mode flags; at most one ends up True
        self.aug_tm = False  # flow and appearance transform models
        self.aug_sas = False  # flow model for single-atlas segmentation
        self.aug_rand = False  # random flow and multiplicative intensity
        # BUGFIX: previously only assigned inside the elif branch below, so
        # the attribute could be missing entirely (AttributeError on access)
        self.aug_randmult = False

        if data_params.get('aug_rand'):
            self.aug_rand = True
            if data_params.get('n_flow_aug') is not None:
                self.n_aug = data_params['n_flow_aug']
        elif data_params.get('aug_randmult'):
            self.aug_randmult = True

        if data_params['aug_tm']:
            self.aug_tm = True
            self.aug_sas = False
            if data_params.get('n_tm_aug') is not None:
                self.n_aug = data_params['n_tm_aug']
        elif data_params['aug_sas']:
            self.aug_sas = True
            self.aug_tm = False
            if data_params.get('n_sas_aug') is not None:
                self.n_aug = data_params['n_sas_aug']

        # come up with a short name for our flow and color models so we can
        # put them in this model name. Default to None so the attribute
        # always exists even when neither naming branch applies.
        self.aug_model_name = None
        if self.arch_params['tm_flow_model'] is not None and self.arch_params['tm_color_model'] is not None \
                and self.aug_tm:
            if 'epoch' in self.arch_params['tm_flow_model']:
                flow_epoch = re.search(
                    '(?<=_epoch)[0-9]*',
                    self.arch_params['tm_flow_model']).group(0)
            else:
                # filename only records an iteration count; convert to an
                # approximate epoch count (assumes 100 iters/epoch — confirm)
                flow_epoch = int(
                    int(
                        re.search('(?<=_iter)[0-9]*',
                                  self.arch_params['tm_flow_model']).group(0))
                    / 100)
            color_epoch = re.search(
                '(?<=_epoch)[0-9]*',
                self.arch_params['tm_color_model']).group(0)

            # only include the color model in the name if we are doing both
            # flow and color aug
            self.aug_model_name = 'tmflow-e{}-colore{}'.format(
                flow_epoch, color_epoch)
        elif self.arch_params['tm_flow_model'] is not None:
            self.aug_model_name = 'tmflow-{}'.format(
                os.path.basename(
                    self.arch_params['tm_flow_model'].split('/models/')[0]))

        # do augmentation through generator, or pre-augment training set
        if 'aug_in_gen' not in data_params:
            self.data_params['aug_in_gen'] = False
        if self.data_params['aug_in_gen']:
            self.n_aug = None

        # let dataset loader figure out short name
        self.data_params['n_dims'] = self.n_aug_dims
        self.dataset = mri_loader.MRIDataset(self.data_params, self.logger)
        self.dataset_name = self.dataset.create_display_name()

        # automatic early stopping based on validation loss
        validation_losses_buff_len = arch_params.get('patience', 10)

        super(SegmenterTrainer,
              self).__init__(data_params,
                             arch_params,
                             prompt_delete_existing=prompt_delete)

        # seed the early-stopping buffer with NaNs until real losses arrive
        self.validation_losses_buffer = [np.nan] * validation_losses_buff_len

        # keep track of all ids the network sees as a sanity check
        self.all_ul_ids = []
        self.all_train_ids = []
Example #3
    def __init__(self, data_params, arch_params):
        """Configure a transform-model (flow and/or color) trainer.

        Args:
            data_params: dict of dataset options; must include 'img_shape'
                (spatial dims + trailing channel dim).
            arch_params: dict of architecture options; 'model_arch' selects
                the flow and/or color portions, and the recon_loss_* /
                transform_reg_* keys configure the loss terms.
        """
        self.data_params = data_params
        self.arch_params = arch_params

        # if we are profiling our model, only do it for a few iterations
        # since there is some overhead that will slow us down
        self.do_profile = True
        self.profiled_iters = 0

        self.epoch_count = 0

        self.img_shape = data_params['img_shape']
        self.n_chans = data_params['img_shape'][-1]
        self.n_dims = len(self.img_shape) - 1

        # name our source domain according to our dataset parameters
        self.logger = None

        # initialize our dataset
        self.dataset = mri_loader.MRIDataset(self.data_params, self.logger)

        if 'use_aux_reg' not in arch_params:
            self.arch_params['use_aux_reg'] = None

        # BUGFIX: default every loss/regularization attribute up front so
        # they always exist, even if 'model_arch' enables neither branch
        # below or a loss name is unrecognized (previously this surfaced
        # later as an AttributeError). A zero weight matches the existing
        # "keep the output node but don't weight it" convention.
        self.transform_reg_name = None
        self.transform_reg_fn = None
        self.transform_reg_wt = 0.
        self.recon_loss_name = None
        self.recon_loss_fn = keras_metrics.mean_squared_error
        self.recon_loss_wt = 0

        # enc/dec architecture
        # parse params for flow portion of network
        if 'flow' in self.arch_params['model_arch']:
            self.transform_reg_name = self.arch_params['transform_reg_flow']

            if 'grad_l2' in self.transform_reg_name:
                # penalize spatial gradients of the predicted flow field
                self.transform_reg_fn = my_metrics.gradient_loss_l2(n_dims=self.n_dims)
                self.transform_reg_wt = self.arch_params['transform_reg_lambda_flow']
            else:
                self.transform_reg_fn = None
                self.transform_reg_wt = 0.

            self.recon_loss_name = self.arch_params['recon_loss_Iw']
            if self.recon_loss_name is None:  # still have this output node, but don't weight it
                self.recon_loss_fn = keras_metrics.mean_squared_error
                self.recon_loss_wt = 0
            elif 'cc_vm' in self.recon_loss_name:
                # windowed normalized cross-correlation, voxelmorph-style
                self.cc_loss_weight = self.arch_params['cc_loss_weight']
                self.cc_win_size_Iw = self.arch_params['cc_win_size_Iw']
                self.recon_loss_fn = my_metrics.NCC().loss
                self.recon_loss_wt = self.cc_loss_weight

        # parse params for color portion of network
        # (note: deliberately overrides the flow settings above when
        # 'model_arch' contains both 'flow' and 'color', as before)
        if 'color' in self.arch_params['model_arch']:
            self.recon_loss_name = self.arch_params['recon_loss_I']
            self.transform_reg_name = self.arch_params['transform_reg_color']

            if 'seg-l2' in self.transform_reg_name:
                # smoothness of the color transform within segments
                self.transform_reg_wt = self.arch_params['transform_reg_lambda_color']
                self.transform_reg_fn = utils.SpatialSegmentSmoothness(
                    n_dims=self.n_dims,
                    n_chans=self.n_chans,
                ).compute_loss
            else:
                self.transform_reg_fn = None
                self.transform_reg_wt = 0.

            if self.recon_loss_name is None:  # still have this output node, but don't weight it
                self.recon_loss_fn = keras_metrics.mean_squared_error
                self.recon_loss_wt = 0
            elif 'l2' in self.recon_loss_name:
                self.recon_loss_fn = keras_metrics.mean_squared_error

                # set a constant weight for reconstruction
                self.recon_loss_wt = self.arch_params['recon_loss_wt']

        # epoch to resume training from, if any
        self.latest_epoch = arch_params.get('latest_epoch', 0)

        super(TransformModelTrainer, self).__init__(
            data_params=self.data_params, arch_params=self.arch_params,
            prompt_delete_existing=True, prompt_update_name=True)