def _create_augmentation_models(self, indexing='xy'):  # TODO: put this in a param somewhere
    # warp model for deforming full image volumes with a dense flow field
    self.vol_warp_model = networks.warp_model(
        img_shape=self.aug_img_shape,
        interp_mode='linear',
        indexing=indexing,
    )

    # segmentation warpers that take in a flow field
    if 'warpoh' in self.arch_params.keys() and self.arch_params['warpoh']:
        # warp onehot-encoded segmentations with linear interpolation
        self.seg_warp_model = networks.warp_model(
            img_shape=self.aug_img_shape[:-1] + (self.n_labels,),  # (1,),
            interp_mode='linear',
            indexing=indexing,
        )
    else:
        # warp integer label maps with nearest-neighbor interpolation
        self.seg_warp_model = networks.warp_model(
            img_shape=self.aug_img_shape[:-1] + (1,),
            interp_mode='nearest',
            indexing=indexing)

    # random spatial augmentation: sample a smooth random flow field
    if self.aug_rand:
        if self.data_params['aug_params']['randflow_type'] == 'ronneberger':
            self.flow_rand_aug_model = networks.randflow_ronneberger_model(
                img_shape=self.aug_img_shape,
                model=None,
                interp_mode='linear',
                model_name='randflow_ronneberger_model',
                flow_sigma=self.data_params['aug_params']['flow_sigma'],
                blur_sigma=self.data_params['aug_params']['blur_sigma'])
            self.logger.debug('Random flow Ronneberger augmentation model')
        else:
            self.flow_rand_aug_model = networks.randflow_model(
                img_shape=self.aug_img_shape,
                model=None,
                interp_mode='linear',
                model_name='randflow_model',
                flow_sigma=self.data_params['aug_params']['flow_sigma'],
                flow_amp=self.data_params['aug_params']['flow_amp'],
                blur_sigma=self.data_params['aug_params']['blur_sigma'],
                indexing=indexing,
            )
            self.logger.debug('Random flow augmentation model')

        self.flow_rand_aug_model.summary(print_fn=self.logger.debug)

    # learned transform models (spatial flow, and optionally appearance/color)
    # used for transform-based (tm) or single-atlas-segmentation (sas) augmentation
    if self.aug_tm or self.aug_sas:
        self.flow_aug_model = load_model(
            self.arch_params['tm_flow_model'],
            custom_objects={
                'SpatialTransformer': functools.partial(
                    nrn_layers.SpatialTransformer, indexing=indexing)
            },
            compile=False)

        if self.arch_params['tm_color_model'] is not None and self.aug_tm:
            self.flow_bck_aug_model = load_model(
                self.arch_params['tm_flow_bck_model'],
                custom_objects={
                    'SpatialTransformer': functools.partial(
                        nrn_layers.SpatialTransformer, indexing=indexing)
                },
                compile=False)

            self.color_aug_model = load_model(
                self.arch_params['tm_color_model'],
                custom_objects={
                    'SpatialTransformer': functools.partial(
                        nrn_layers.SpatialTransformer, indexing=indexing)
                },
                compile=False)

            if 'l2-tgt' in self.arch_params['tm_color_model']:
                # if the color model transforms to the target space by default,
                # create a wrapper that gets the source space
                self.color_aug_model = Model(
                    inputs=self.color_aug_model.inputs[:-1],
                    outputs=[
                        self.color_aug_model.outputs[0],
                        self.color_aug_model.get_layer('add_color_delta').output,
                        self.color_aug_model.outputs[2]
                    ],
                    name='color_model_wrapper')
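# Usage note (illustrative; mirrors the call pattern in eval_seg_sas_from_gen
# below): each warp model built above takes an [image, flow] pair and returns
# the warped image, e.g.
#
#   warped_seg = self.seg_warp_model.predict([seg_batch, flow_batch])
#
# where `seg_batch` and `flow_batch` are hypothetical names for a batch of
# segmentations and a dense flow field with matching spatial shape.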
def eval_seg_sas_from_gen(sas_model, atlas_vol, atlas_labels,
                          eval_gen, label_mapping, n_eval_examples, batch_size, logger=None):
    '''
    Evaluates a single-atlas segmentation (SAS) method on a set of evaluation volumes.

    :param sas_model: spatial transform model used for SAS. Can be voxelmorph.
    :param atlas_vol: atlas volume
    :param atlas_labels: atlas segmentations
    :param eval_gen: generator that yields vols_valid, segs_valid batches
    :param label_mapping: list of label ids that will appear in segs, ordered by how they map to channels
    :param n_eval_examples: total number of examples to evaluate
    :param batch_size: batch size to use in evaluation (unused here)
    :param logger: python logger if we want to log messages
    :return: per-example cross-entropies, dice scores per label, non-background accuracies, and example ids
    '''
    img_shape = atlas_vol.shape[1:]

    # model that warps the atlas labels with a predicted flow field
    seg_warp_model = networks.warp_model(
        img_shape=img_shape,
        interp_mode='nearest',
        indexing='xy',
    )

    from keras.models import Model
    from keras.layers import Input, Activation
    from keras.optimizers import Adam
    n_labels = len(label_mapping)

    # small model used only to compute categorical cross-entropy between the
    # (softmaxed) warped labels and the ground-truth segmentation
    warped_in = Input(img_shape[0:-1] + (n_labels,))
    warped = Activation('softmax')(warped_in)

    ce_model = Model(inputs=[warped_in], outputs=[warped], name='ce_model')
    ce_model.compile(loss='categorical_crossentropy', optimizer=Adam(0.0001))

    # test metrics: categorical cross-entropy and dice
    dice_per_label = np.zeros((n_eval_examples, len(label_mapping)))
    cces = np.zeros((n_eval_examples,))
    accs = np.zeros((n_eval_examples,))
    all_ids = []
    for bi in range(n_eval_examples):
        if logger is not None:
            logger.debug('Testing on subject {} of {}'.format(bi, n_eval_examples))
        else:
            print('Testing on subject {} of {}'.format(bi, n_eval_examples))

        X, Y, _, ids = next(eval_gen)
        Y_oh = labels_to_onehot(Y, label_mapping=label_mapping)

        # register the atlas to the subject; we only need the predicted flow
        warped, warp = sas_model.predict([atlas_vol, X])

        # warp the atlas labels according to the predicted flow field. get rid of channels
        if Y.shape[-1] == 1:
            Y = Y[..., 0]
        preds_batch = seg_warp_model.predict(
            [atlas_labels[..., np.newaxis], warp])[..., 0]
        preds_oh = labels_to_onehot(preds_batch, label_mapping=label_mapping)

        cce = np.mean(ce_model.evaluate(preds_oh, Y_oh, verbose=False))
        subject_dice_per_label = medipy_metrics.dice(
            Y, preds_batch, labels=label_mapping)

        # accuracy over non-background voxels only
        nonbkgmap = (Y > 0)
        acc = np.sum(((Y == preds_batch) * nonbkgmap).astype(int)) / np.sum(nonbkgmap).astype(float)
        if logger is not None:
            logger.debug('Non-background accuracy: {}'.format(acc))
        else:
            print('Non-background accuracy: {}'.format(acc))

        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    if logger is not None:
        logger.debug('Dice per label: {}, {}'.format(label_mapping, dice_per_label))
        logger.debug('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
        logger.debug('Mean CE: {}'.format(np.mean(cces)))
        logger.debug('Mean accuracy: {}'.format(np.mean(accs)))
    else:
        print('Dice per label: {}, {}'.format(label_mapping, dice_per_label))
        print('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
        print('Mean CE: {}'.format(np.mean(cces)))
        print('Mean accuracy: {}'.format(np.mean(accs)))

    return cces, dice_per_label, accs, all_ids
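# Minimal, self-contained sketch of the non-background accuracy metric used in
# the evaluation loop above, on tiny synthetic numpy label maps. The names
# `gt` and `pred` are illustrative and not part of this module; no Keras or
# atlas data is needed to run it.
def _example_nonbkg_accuracy():
    import numpy as np
    gt = np.array([[0, 1, 2], [0, 2, 2]])    # ground-truth labels (0 = background)
    pred = np.array([[0, 1, 1], [0, 2, 2]])  # predicted labels
    nonbkgmap = (gt > 0)                      # ignore background voxels
    acc = np.sum(((gt == pred) * nonbkgmap).astype(int)) / np.sum(nonbkgmap).astype(float)
    return acc  # 3 of the 4 non-background voxels match -> 0.75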