def _generate_slices_batchs_from_vols(self, vols, segs, ids, vol_gen=None, randomize=False, batch_size=16, vol_batch_size=1, convert_onehot=False, ): if vol_gen is None: vol_gen = batch_utils.gen_batch(vols, [segs, ids], randomize=randomize, batch_size=vol_batch_size) while True: X, Y, ids_batch = next(vol_gen) n_slices = X.shape[-2] slice_idxs = np.random.choice(n_slices, self.batch_size, replace=True) X_slices = np.reshape( np.transpose(X[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + self.pred_img_shape) if Y is not None and self.arch_params['warpoh']: # we used the onehot warper in these cases Y = np.reshape(np.transpose(Y[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + self.pred_img_shape[:-1] + (self.n_labels,)) else: Y = np.reshape(np.transpose(Y[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + self.pred_img_shape[:-1] + (1,)) if Y is not None and convert_onehot: Y = classification_utils.labels_to_onehot(Y, label_mapping=self.label_mapping) yield X_slices, Y, ids_batch
def gen_slices_batch(self, files_list, batch_size=1, randomize=True, convert_onehot=False, label_mapping=None):
    '''
    Endlessly yield batches of z-slices loaded from volume files.

    In randomize mode, the file list is shuffled once. Each batch samples batch_size random slices (with
    replacement) from the next file in that shuffled order.
    In sequential mode, the slices of each file are visited in order, batch_size at a time. The last batch of
    a file may be smaller. After the last slice the generator moves on to the next file, wrapping around
    after the last file.

    :param files_list: list of volume files; a copy is made, so the caller's list is never reordered
    :param batch_size: number of slices per batch
    :param randomize: random vs sequential sampling (see above)
    :param convert_onehot: convert the label slices to onehot using label_mapping
    :param label_mapping: list of label ids, ordered by how they map to onehot channels
    :return: generator of (X, Y) with the slice axis moved to the front
    '''
    # shuffle a private copy so that the caller's list is left untouched
    files_list = list(files_list)
    if randomize:
        np.random.shuffle(files_list)
    n_files = len(files_list)

    file_idx = 0
    next_slice = 0  # first slice of the next sequential batch
    while True:
        start = time.time()
        X, Y = _load_vol_and_seg(
            files_list[file_idx],
            do_mask_vol=True,
        )
        # TODO: add scaling here since we no longer do it per example
        if self.profiler_logger is not None:
            self.profiler_logger.info('Loading vol took {}'.format(time.time() - start))

        start = time.time()
        # slice the z-axis by default
        n_slices = X.shape[-2]
        if randomize:
            slice_idxs = np.random.choice(n_slices, batch_size, replace=True)
            # sample from a different (shuffled) file every batch
            file_idx = (file_idx + 1) % n_files
        else:
            # walk through the slices in order; the last batch of a file is truncated at the last slice
            slice_idxs = np.arange(next_slice, min(next_slice + batch_size, n_slices))
            next_slice += batch_size
            if next_slice >= n_slices:
                # reached the end of this file's slices: restart the slices on the next file
                next_slice = 0
                file_idx = (file_idx + 1) % n_files

        X = np.transpose(X[:, :, slice_idxs], (2, 0, 1, 3))
        Y = np.transpose(Y[:, :, slice_idxs], (2, 0, 1))
        if self.profiler_logger is not None:
            self.profiler_logger.info('Slicing took {}'.format(time.time() - start))

        if convert_onehot:
            start = time.time()
            Y = classification_utils.labels_to_onehot(Y, label_mapping=label_mapping)
            if self.profiler_logger is not None:
                self.profiler_logger.info('Converting slices onehot took {}'.format(time.time() - start))
        yield X, Y
def _generate_augmented_batch( self, source_gen, use_single_atlas=False, aug_by=None, return_transforms=False, convert_onehot=False, do_slice_z=False, ): if use_single_atlas: # single atlas # single atlas, dont bother with generator self.logger.debug('Single atlas, not using source generator for augmenter') source_X, source_segs, source_contours, source_ids = next(source_gen) else: self.logger.debug('Multiple atlases, sampling source vols from generator') while True: if not use_single_atlas: # randomly select an atlas from our source volume generator source_X, source_segs, source_contours, source_ids = next(source_gen) # keep track of which unlabeled subjects we are using in training ul_ids = [] if len(aug_by) > 1: # multiple augmentation schemes. let's flip a coin do_aug_by_idx = np.random.rand(1)[0] * float(len(aug_by) - 1) aug_batch_by = aug_by[int(round(do_aug_by_idx))] else: aug_batch_by = aug_by[0] color_delta = None start = time.time() if aug_batch_by == 'rand': X_aug, flow = self.flow_rand_aug_model.predict(source_X) # color augmentation by additive or multiplicative factor. randomize the factor and then tile to correct shape if 'offset_amp' in self.data_params['aug_params']: X_aug += np.tile( (np.random.rand(X_aug.shape[0], 1, 1, 1) * 2. - 1.) * self.data_params['aug_params'][ 'offset_amp'], (1,) + X_aug.shape[1:]) X_aug = np.clip(X_aug, 0., 1.) if 'mult_amp' in self.data_params['aug_params']: X_aug *= np.tile( 1. + (np.random.rand(X_aug.shape[0], 1, 1, 1) * 2. - 1.) * self.data_params['aug_params'][ 'mult_amp'], (1,) + X_aug.shape[1:]) X_aug = np.clip(X_aug, 0., 1.) 
# Y_to_aug = classification_utils.labels_to_onehot(Y_to_aug, label_mapping=self.label_mapping) Y_aug = self.seg_warp_model.predict([source_segs, flow]) elif aug_batch_by == 'tm': if self.arch_params['do_coupled_sampling']: # use the same target for flow and color X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw) X_colortgt = X_flowtgt ul_ids += ul_flow_ids else: X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw) X_colortgt, _, _, ul_color_ids = next(self.unlabeled_gen_raw) ul_ids += ul_flow_ids + ul_color_ids if self.do_profile: self.profiler_logger.info('Sampling aug tgt took {}'.format(time.time() - st)) self.aug_target = X_flowtgt #if 'l2-tgt' in self.arch_params['tm_color_model']: X_colortgt_src, _ = self.flow_bck_aug_model.predict([X_colortgt, source_X]) color_delta, colored_vol, _ = self.color_aug_model.predict([source_X, X_colortgt_src, source_contours]) self.aug_colored = colored_vol _, flow = self.flow_aug_model.predict([source_X, X_flowtgt]) X_aug = self.vol_warp_model.predict([colored_vol, flow]) if self.do_profile: self.profiler_logger.info('Running color and flow aug took {}'.format(time.time() - st)) st = time.time() # now warp the labels to match Y_aug = self.seg_warp_model.predict([source_segs, flow]) if self.do_profile: self.profiler_logger.info('Warping labels took {}'.format(time.time() - st)) else: # no aug X_aug = source_X Y_aug = source_segs if self.do_profile: self.profiler_logger.info('Augmenting input batch took {}'.format(time.time() - start)) if do_slice_z: # get a random slice in the z dimension start = time.time() n_total_slices = source_X.shape[-2] # always take random slices in the z dimension slice_idxs = np.random.choice(n_total_slices, self.batch_size, replace=True) X_to_aug_slices = np.reshape( np.transpose( # roll z-slices into batch source_X[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + tuple(self.pred_img_shape)) X_aug = np.reshape( np.transpose( X_aug[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + 
self.pred_img_shape) if self.aug_target is not None: self.aug_target = np.reshape( np.transpose(self.aug_target[..., slice_idxs, :], (0, 3, 1, 2, 4)), (-1,) + self.pred_img_shape) if self.aug_colored is not None: # not relevant if we arent using a color transform model self.aug_colored = np.reshape( np.transpose(self.aug_colored[..., slice_idxs, :], (0, 3, 1, 2, 4)), (-1,) + self.pred_img_shape) if (aug_by == 'rand' or aug_by == 'tm') and self.arch_params['warpoh']: # we used the onehot warper in these cases Y_to_aug_slices = np.reshape( np.transpose( source_segs[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + self.pred_segs_oh_shape) Y_aug = np.reshape(np.transpose(Y_aug[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + self.pred_segs_oh_shape) else: Y_to_aug_slices = np.reshape(np.transpose( source_segs[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + self.pred_segs_shape) Y_aug = np.reshape(np.transpose( Y_aug[:, :, :, slice_idxs], (0, 3, 1, 2, 4)), (-1,) + self.pred_segs_shape) if self.do_profile: self.profiler_logger.info('Slicing batch took {}'.format(time.time() - start)) if return_transforms: st = time.time() flow = np.reshape( np.transpose( flow[:, :, :, slice_idxs, :], (0, 3, 1, 2, 4)), (-1,) + self.pred_img_shape[:-1] + (self.n_aug_dims,)) if self.do_profile: self.profiler_logger.info('Slicing flow took {}'.format(time.time() - st)) if color_delta is not None: color_delta = np.reshape( np.transpose( color_delta[:, :, :, slice_idxs, :], (0, 3, 1, 2, 4)), (-1,) + self.pred_img_shape) else: # no slicing, no aug? 
X_to_aug_slices = source_X Y_to_aug_slices = source_segs if convert_onehot and not ((aug_by == 'rand' or aug_by == 'tm') and self.arch_params['warpoh']): # if we don't have onehot segs already, convert them after slicing start = time.time() Y_to_aug_slices = classification_utils.labels_to_onehot( Y_to_aug_slices, label_mapping=self.label_mapping) Y_aug = classification_utils.labels_to_onehot( Y_aug, label_mapping=self.label_mapping) if self.do_profile: self.profiler_logger.info('Converting onehot took {}'.format(time.time() - start)) if return_transforms: yield X_to_aug_slices, Y_to_aug_slices, flow, color_delta, X_aug, Y_aug, ul_ids else: yield X_to_aug_slices, Y_to_aug_slices, X_aug, Y_aug, ul_ids
def create_generators(self, batch_size):
    '''
    Create the training / validation generators and a fixed batch of validation slices for display.

    :param batch_size: number of examples per training batch. It is stored as self.batch_size and is also
        the number of random validation slices placed in self.X_valid_batch / self.Y_valid_batch.
    '''
    self.batch_size = batch_size

    # labeled training volumes that we wish to augment
    self.source_gen = self.dataset.gen_vols_batch(
        dataset_splits=['labeled_train'],
        batch_size=1,
        load_segs=True,  # this is our labeled set, so segmentations are always available
        load_contours=self.aug_tm,  # contours are only needed as aux data for the appearance model
        randomize=True,
        return_ids=True,
    )

    # source of augmentation targets (despite the name, it includes labeled training vols too)
    self.unlabeled_gen_raw = self.dataset.gen_vols_batch(
        dataset_splits=['labeled_train', 'unlabeled_train'],
        batch_size=1,
        load_segs=False,
        load_contours=False,
        randomize=True,
        return_ids=True,
    )

    self._create_augmentation_models()

    self.aug_train_gen = None
    if self.n_aug is not None:
        # a fixed number of augmented examples is simply appended to the training set
        # NOTE: only feasible when we are not synthesizing that many examples (i.e. <= 1000)
        self._create_augmented_examples()
        self.logger.debug('Augmented classifier training set: vols {}, segs {}'.format(
            self.X_labeled_train.shape, self.segs_labeled_train.shape))
    elif self.aug_tm or self.aug_rand:
        # transform-model or random flow+intensity augmentation synthesizes many examples, so it is done
        # per batch inside a generator
        scheme_flags = (('tm', self.aug_tm), ('rand', self.aug_rand))
        enabled_schemes = [scheme for scheme, enabled in scheme_flags if enabled]
        self.aug_train_gen = self._generate_augmented_batch(
            source_gen=self.source_gen,
            use_single_atlas=self.X_labeled_train.shape[0] == 1,
            aug_by=enabled_schemes,
            convert_onehot=True,
            return_transforms=True,
            do_slice_z=(self.n_pred_dims == 2),  # predicting on slices -> generate slices
        )

    # slices drawn from the (non-augmented) training volumes
    self.train_gen = self._generate_slices_batchs_from_vols(
        self.X_labeled_train, self.segs_labeled_train, self.ids_labeled_train,
        vol_gen=None,
        convert_onehot=True,
        batch_size=self.batch_size,
        randomize=True,
    )

    # validation subjects are loaded one at a time and evaluated slice by slice
    self.eval_valid_gen = self.dataset.gen_vols_batch(
        ['labeled_valid'],
        batch_size=1,
        randomize=False,
        convert_onehot=False,
        load_segs=True,
        label_mapping=self.label_mapping,
        return_ids=True,
    )

    # validation losses use every validation volume; this fixed batch of random slices is only for display
    picked_subjects = np.random.choice(self.X_labeled_valid.shape[0], batch_size)
    picked_slices = np.random.choice(self.aug_img_shape[2], batch_size, replace=False)
    self.X_valid_batch = np.zeros((batch_size,) + self.pred_img_shape)
    self.Y_valid_batch = np.zeros((batch_size,) + self.pred_segs_shape)
    for row, (subject, z) in enumerate(zip(picked_subjects, picked_slices)):
        self.X_valid_batch[row] = self.X_labeled_valid[subject, :, :, z]
        self.Y_valid_batch[row] = self.segs_labeled_valid[subject, :, :, z]
    self.Y_valid_batch = classification_utils.labels_to_onehot(
        self.Y_valid_batch, label_mapping=self.label_mapping)
def gen_batch(ims_data, labels_data, batch_size, randomize=False,
              pad_or_crop_to_size=None, normalize_tanh=False,
              convert_onehot=False, labels_to_onehot_mapping=None,
              aug_model=None, aug_params=None,
              yield_aug_params=False, yield_idxs=False, random_seed=None):
    '''
    Endlessly yield batches sampled along the first axis of in-memory arrays.

    :param ims_data: list of images, or an image. If a single image, it will be automatically converted to a list.
        Batches that are not float32/float64 are cast to float32 and divided by 255.
    :param labels_data: list of other data (e.g. labels) that do not require image normalization or augmentation,
        but might need to be converted to onehot. Entries may be None (yielded as None), ndarrays or indexable
        sequences. If labels_data itself is None, a single None is yielded in place of the label batches.
    :param batch_size: number of examples per batch
    :param randomize: if True, sample indices uniformly with replacement every batch. Otherwise walk through the
        examples in order; a batch that runs past the end wraps around, and the next batch restarts at index 0.
    :param pad_or_crop_to_size: pad or crop each image in ims_data to the specified size (one size for all
        entries, or a list with one size per entry of ims_data). Default pad value is 0
    :param normalize_tanh: normalize image to range [-1, 1], good for synthesis with a tanh activation
        (one bool for all entries, or a list with one bool per entry of ims_data)
    :param convert_onehot: convert labels to a onehot representation using the mapping below
        (one bool for all entries, or a list with one bool per entry of labels_data)
    :param labels_to_onehot_mapping: list of labels e.g. [0, 3, 5] indicating the mapping of label values to channel indices
    :param aug_model: optional augmentation model (the "gpu aug model") used instead of aug_utils.aug_im_batch
    :param aug_params: augmentation parameters (one for all entries, or a list with one per entry of ims_data;
        None entries are not augmented)
    :param yield_aug_params: include the random augmentation params used on the batch in the return values
        (None when aug_params is None)
    :param yield_idxs: include the indices that comprise this batch in the return values
    :param random_seed: if not None, seed numpy's global RNG with it before sampling
    :return: generator of tuples (*image batches, *label batches[, aug params][, idxs])
    '''
    if random_seed is not None:
        np.random.seed(random_seed)

    # make sure everything is a list
    if not isinstance(ims_data, list):
        ims_data = [ims_data]
    n_im_sets = len(ims_data)

    if not isinstance(normalize_tanh, list):
        normalize_tanh = [normalize_tanh] * n_im_sets
    else:
        assert len(normalize_tanh) == n_im_sets

    out_aug_params = None
    if aug_params is not None:
        if not isinstance(aug_params, list):
            aug_params = [aug_params] * n_im_sets
        else:
            assert len(aug_params) == n_im_sets
        out_aug_params = aug_params[:]

    if pad_or_crop_to_size is not None:
        if not isinstance(pad_or_crop_to_size, list):
            pad_or_crop_to_size = [pad_or_crop_to_size] * n_im_sets
        else:
            assert len(pad_or_crop_to_size) == n_im_sets
        # pad each image of each set to that set's target size, then re-concatenate
        ims_data = [
            np.concatenate([
                image_utils.pad_or_crop_to_shape(x, target_size)[np.newaxis]
                for x in im_data
            ], axis=0)
            for im_data, target_size in zip(ims_data, pad_or_crop_to_size)
        ]

    # if we have labels that we want to generate from, put everything into a list for consistency
    # (useful if we have labels and aux data)
    if labels_data is not None:
        if not isinstance(labels_data, list):
            labels_data = [labels_data]
        # each entry should correspond to an entry in labels_data
        if not isinstance(convert_onehot, list):
            convert_onehot = [convert_onehot] * len(labels_data)
        else:
            assert len(convert_onehot) == len(labels_data)

    n_ims = ims_data[0].shape[0]
    next_start = 0  # first index of the next sequential batch
    while True:
        if randomize:
            idxs = np.random.choice(n_ims, batch_size, replace=True)
        else:
            # consecutive indices, wrapped around the end of the data
            idxs = np.arange(next_start, next_start + batch_size) % n_ims
            next_start += batch_size
            if next_start >= n_ims:
                # this batch reached (or ran past) the end: restart from the beginning
                next_start = 0

        ims_batches = []
        for i, im_data in enumerate(ims_data):
            X_batch = im_data[idxs]
            if X_batch.dtype not in (np.float32, np.float64):
                X_batch = X_batch.astype(np.float32) / 255.

            if normalize_tanh[i]:
                X_batch = image_utils.normalize(X_batch)

            if aug_params is not None and aug_params[i] is not None:
                if aug_model is not None:
                    # use the gpu aug model instead
                    T, _ = aug_utils.aug_params_to_transform_matrices(
                        batch_size=X_batch.shape[0], add_last_row=True, **aug_params[i])
                    X_batch = aug_model.predict([X_batch, T])
                    out_aug_params[i] = T
                else:
                    X_batch, out_aug_params[i] = aug_utils.aug_im_batch(X_batch, **aug_params[i])
            ims_batches.append(X_batch)

        if labels_data is not None:
            labels_batches = []
            for li, Y in enumerate(labels_data):
                if Y is None:
                    Y_batch = None
                elif convert_onehot[li]:
                    Y_batch = classification_utils.labels_to_onehot(
                        Y[idxs], label_mapping=labels_to_onehot_mapping)
                elif isinstance(Y, np.ndarray):
                    Y_batch = Y[idxs]
                else:
                    # in case it's a list
                    Y_batch = [Y[idx] for idx in idxs]
                labels_batches.append(Y_batch)
            batch = tuple(ims_batches) + tuple(labels_batches)
        else:
            batch = tuple(ims_batches) + (None,)

        if yield_aug_params:
            batch += (out_aug_params,)
        if yield_idxs:
            batch += (idxs,)
        yield batch
def gen_batch(ims_data, labels_data, batch_size, randomize=False,
              pad_or_crop_to_size=None, normalize_tanh=False,
              convert_onehot=False, labels_to_onehot_mapping=None,
              aug_model=None, aug_params=None,
              yield_aug_params=False, yield_idxs=False, random_seed=None):
    '''
    Endlessly yield batches sampled along the first axis of in-memory arrays.

    :param ims_data: list of images, or an image. If a single image, it will be automatically converted to a list.
        Batches that are not float32/float64 are cast to float32 and divided by 255.
    :param labels_data: list of other data (e.g. labels) that do not require image normalization or augmentation,
        but might need to be converted to onehot. Entries may be None (yielded as None), ndarrays or indexable
        sequences. If labels_data itself is None, a single None is yielded in place of the label batches.
    :param batch_size: number of examples per batch
    :param randomize: if True, sample indices uniformly with replacement every batch. Otherwise walk through the
        examples in order; a batch that runs past the end wraps around, and the next batch restarts at index 0.
    :param pad_or_crop_to_size: pad or crop each image in ims_data to the specified size (one size for all
        entries, or a list with one size per entry of ims_data). Default pad value is 0
    :param normalize_tanh: normalize image to range [-1, 1], good for synthesis with a tanh activation
        (one bool for all entries, or a list with one bool per entry of ims_data)
    :param convert_onehot: convert labels to a onehot representation using the mapping below
        (one bool for all entries, or a list with one bool per entry of labels_data)
    :param labels_to_onehot_mapping: list of labels e.g. [0, 3, 5] indicating the mapping of label values to channel indices
    :param aug_model: optional augmentation model (the "gpu aug model") used instead of aug_utils.aug_im_batch
    :param aug_params: augmentation parameters (one for all entries, or a list with one per entry of ims_data;
        None entries are not augmented)
    :param yield_aug_params: include the random augmentation params used on the batch in the return values
        (None when aug_params is None)
    :param yield_idxs: include the indices that comprise this batch in the return values
    :param random_seed: if not None, seed numpy's global RNG with it before sampling
    :return: generator of tuples (*image batches, *label batches[, aug params][, idxs])
    '''
    if random_seed is not None:
        np.random.seed(random_seed)

    # make sure everything is a list
    if not isinstance(ims_data, list):
        ims_data = [ims_data]
    n_im_sets = len(ims_data)

    if not isinstance(normalize_tanh, list):
        normalize_tanh = [normalize_tanh] * n_im_sets
    else:
        assert len(normalize_tanh) == n_im_sets

    out_aug_params = None
    if aug_params is not None:
        if not isinstance(aug_params, list):
            aug_params = [aug_params] * n_im_sets
        else:
            assert len(aug_params) == n_im_sets
        out_aug_params = aug_params[:]

    if pad_or_crop_to_size is not None:
        if not isinstance(pad_or_crop_to_size, list):
            pad_or_crop_to_size = [pad_or_crop_to_size] * n_im_sets
        else:
            assert len(pad_or_crop_to_size) == n_im_sets
        # pad each image of each set to that set's target size, then re-concatenate
        ims_data = [
            np.concatenate([
                image_utils.pad_or_crop_to_shape(x, target_size)[np.newaxis]
                for x in im_data
            ], axis=0)
            for im_data, target_size in zip(ims_data, pad_or_crop_to_size)
        ]

    # if we have labels that we want to generate from, put everything into a list for consistency
    # (useful if we have labels and aux data)
    if labels_data is not None:
        if not isinstance(labels_data, list):
            labels_data = [labels_data]
        # each entry should correspond to an entry in labels_data
        if not isinstance(convert_onehot, list):
            convert_onehot = [convert_onehot] * len(labels_data)
        else:
            assert len(convert_onehot) == len(labels_data)

    n_ims = ims_data[0].shape[0]
    next_start = 0  # first index of the next sequential batch
    while True:
        if randomize:
            idxs = np.random.choice(n_ims, batch_size, replace=True)
        else:
            # consecutive indices, wrapped around the end of the data
            idxs = np.arange(next_start, next_start + batch_size) % n_ims
            next_start += batch_size
            if next_start >= n_ims:
                # this batch reached (or ran past) the end: restart from the beginning
                next_start = 0

        ims_batches = []
        for i, im_data in enumerate(ims_data):
            X_batch = im_data[idxs]
            if X_batch.dtype not in (np.float32, np.float64):
                X_batch = X_batch.astype(np.float32) / 255.

            if normalize_tanh[i]:
                X_batch = image_utils.normalize(X_batch)

            if aug_params is not None and aug_params[i] is not None:
                if aug_model is not None:
                    # use the gpu aug model instead
                    T, _ = aug_utils.aug_params_to_transform_matrices(
                        batch_size=X_batch.shape[0], add_last_row=True, **aug_params[i])
                    X_batch = aug_model.predict([X_batch, T])
                    out_aug_params[i] = T
                else:
                    X_batch, out_aug_params[i] = aug_utils.aug_im_batch(X_batch, **aug_params[i])
            ims_batches.append(X_batch)

        if labels_data is not None:
            labels_batches = []
            for li, Y in enumerate(labels_data):
                if Y is None:
                    Y_batch = None
                elif convert_onehot[li]:
                    Y_batch = classification_utils.labels_to_onehot(
                        Y[idxs], label_mapping=labels_to_onehot_mapping)
                elif isinstance(Y, np.ndarray):
                    Y_batch = Y[idxs]
                else:
                    # in case it's a list
                    Y_batch = [Y[idx] for idx in idxs]
                labels_batches.append(Y_batch)
            batch = tuple(ims_batches) + tuple(labels_batches)
        else:
            batch = tuple(ims_batches) + (None,)

        if yield_aug_params:
            batch += (out_aug_params,)
        if yield_idxs:
            batch += (idxs,)
        yield batch
def eval_seg_sas_from_gen(sas_model, atlas_vol, atlas_labels, eval_gen, label_mapping, n_eval_examples,
                          batch_size, logger=None):
    '''
    Evaluates a single-atlas segmentation method on a bunch of evaluation volumes.

    For each evaluation batch, sas_model predicts a flow between the atlas and the volume. The atlas labels are
    warped with that flow (nearest-neighbor interpolation), and the warped labels are scored against the
    ground truth.

    :param sas_model: spatial transform model used for SAS. Can be voxelmorph. It is called as
        sas_model.predict([atlas_vol, X]) and its second output is used as the flow field
    :param atlas_vol: atlas volume, with a leading batch axis
    :param atlas_labels: atlas segmentations without a channel axis (one is added before warping)
    :param eval_gen: generator that yields (vols, segs, <unused>, ids) batches
    :param label_mapping: list of label ids that will appear in segs, ordered by how they map to channels
    :param n_eval_examples: total number of examples (batches) to evaluate
    :param batch_size: unused here; kept for interface parity with eval_seg_from_gen
    :param logger: python logger if we want to log messages; messages are printed when it is None
    :return: (cces, dice_per_label, accs, all_ids): per-example cross-entropy, per-example dice per label with
        shape (n_eval_examples, len(label_mapping)), per-example accuracy over non-background voxels, and the
        list of evaluated ids
    '''
    from keras.models import Model
    from keras.layers import Input, Activation
    from keras.optimizers import Adam

    # send messages to the logger when one is given, otherwise to stdout
    log = logger.debug if logger is not None else print

    img_shape = atlas_vol.shape[1:]
    seg_warp_model = networks.warp_model(
        img_shape=img_shape,
        interp_mode='nearest',
        indexing='xy',
    )

    n_labels = len(label_mapping)
    warped_in = Input(img_shape[0:-1] + (n_labels, ))
    warped = Activation('softmax')(warped_in)
    ce_model = Model(inputs=[warped_in], outputs=[warped], name='ce_model')
    ce_model.compile(loss='categorical_crossentropy', optimizer=Adam(0.0001))

    # test metrics: categorical cross-entropy and dice
    dice_per_label = np.zeros((n_eval_examples, n_labels))
    cces = np.zeros((n_eval_examples, ))
    accs = np.zeros((n_eval_examples, ))
    all_ids = []
    for bi in range(n_eval_examples):
        log('Testing on subject {} of {}'.format(bi, n_eval_examples))

        X, Y, _, ids = next(eval_gen)
        Y_oh = classification_utils.labels_to_onehot(Y, label_mapping=label_mapping)

        _, warp = sas_model.predict([atlas_vol, X])

        # warp the atlas labels according to the predicted flow field. get rid of channels
        if Y.shape[-1] == 1:
            Y = Y[..., 0]
        preds_batch = seg_warp_model.predict([atlas_labels[..., np.newaxis], warp])[..., 0]
        preds_oh = classification_utils.labels_to_onehot(preds_batch, label_mapping=label_mapping)

        cce = np.mean(ce_model.evaluate(preds_oh, Y_oh, verbose=False))
        subject_dice_per_label = medipy_metrics.dice(Y, preds_batch, labels=label_mapping)

        # only consider voxels where the gt label is not bkg (if we count bkg, accuracy will be very high)
        nonbkgmap = (Y > 0)
        acc = np.sum(((Y == preds_batch) * nonbkgmap).astype(int)) / np.sum(nonbkgmap).astype(float)
        log('Subject {} accuracy: {}'.format(bi, acc))

        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    log('Dice per label: {}, {}'.format(label_mapping, np.mean(dice_per_label, axis=0).tolist()))
    log('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
    log('Mean CE: {}'.format(np.mean(cces)))
    log('Mean accuracy: {}'.format(np.mean(accs)))
    return cces, dice_per_label, accs, all_ids
def eval_seg_from_gen(segmenter_model, eval_gen, label_mapping, n_eval_examples, batch_size, logger=None):
    '''
    Evaluates accuracy of a segmentation CNN.

    Each volume from eval_gen is segmented slice by slice with segment_vol_by_slice, and the result is scored
    against the ground-truth labels.

    :param segmenter_model: keras model for segmenter
    :param eval_gen: generator that yields (vols, segs, <unused>, ids) batches
    :param label_mapping: list of label ids, ordered by how they map to channels
    :param n_eval_examples: total number of volumes to evaluate
    :param batch_size: batch size (number of slices per batch) passed to segment_vol_by_slice
    :param logger: python logger (optional); messages are printed when it is None
    :return: (cces, dice_per_label, accs, all_ids): per-volume cross-entropy, per-volume dice per label with
        shape (n_eval_examples, len(label_mapping)), per-volume accuracy over non-background voxels, and the
        list of evaluated ids
    '''
    # send messages to the logger when one is given, otherwise to stdout
    log = logger.debug if logger is not None else print

    # test metrics: categorical cross-entropy and dice
    dice_per_label = np.zeros((n_eval_examples, len(label_mapping)))
    cces = np.zeros((n_eval_examples, ))
    accs = np.zeros((n_eval_examples, ))
    all_ids = []
    for bi in range(n_eval_examples):
        log('Testing on subject {} of {}'.format(bi, n_eval_examples))

        X, Y, _, ids = next(eval_gen)
        Y_oh = classification_utils.labels_to_onehot(Y, label_mapping=label_mapping)

        preds_batch, cce = segment_vol_by_slice(
            segmenter_model, X,
            label_mapping=label_mapping,
            batch_size=batch_size,
            Y_oh=Y_oh,
            compute_cce=True,
        )
        subject_dice_per_label = medipy_metrics.dice(Y, preds_batch, labels=label_mapping)

        # only consider pixels where the gt label is not bkg (if we count bkg, accuracy will be very high)
        nonbkgmap = (Y > 0)
        acc = np.sum(((Y == preds_batch) * nonbkgmap).astype(int)) / np.sum(nonbkgmap).astype(float)
        log('Subject {} accuracy: {}'.format(bi, acc))

        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    log('Dice per label: {}, {}'.format(label_mapping, np.mean(dice_per_label, axis=0).tolist()))
    log('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
    log('Mean CE: {}'.format(np.mean(cces)))
    log('Mean accuracy: {}'.format(np.mean(accs)))
    return cces, dice_per_label, accs, all_ids
def gen_vols_batch(self, dataset_splits=None, batch_size=1, randomize=True,
                   load_segs=False, load_contours=False,
                   convert_onehot=False, label_mapping=None, return_ids=False):
    '''
    Endlessly yield batches of volumes (and optionally segmentations / contours) from one or more dataset splits.

    If every volume of the requested splits is already loaded in memory, batches are indexed from the loaded
    arrays. Otherwise each volume is loaded from its file with _load_vol_and_seg as it is needed.

    :param dataset_splits: split name or list of split names among 'labeled_train', 'labeled_valid',
        'unlabeled_train' and 'labeled_test' (default: ['labeled_train']). Unknown names are skipped with a warning.
    :param batch_size: number of volumes per batch
    :param randomize: sample files uniformly with replacement; otherwise cycle through the files in order,
        wrapping around at the end
    :param load_segs: also yield segmentations
    :param load_contours: also yield contours
    :param convert_onehot: convert segmentations to onehot using label_mapping; otherwise make sure they
        have a trailing channel axis
    :param label_mapping: list of label ids, ordered by how they map to onehot channels
    :param return_ids: also yield the file names of the volumes in the batch
    :return: generator of (X, Y, contours), or (X, Y, contours, batch_files) when return_ids is True;
        Y / contours are None when not requested
    '''
    if dataset_splits is None:
        dataset_splits = ['labeled_train']
    elif not isinstance(dataset_splits, list):
        dataset_splits = [dataset_splits]

    # every known split stores its data in self.vols_<split>, segs_<split>, contours_<split>, files_<split>
    known_splits = ('labeled_train', 'labeled_valid', 'unlabeled_train', 'labeled_test')
    X_all = []
    Y_all = []
    contours_all = []
    files_list = []
    for ds in dataset_splits:
        if ds not in known_splits:
            self._print('Unknown dataset split {}, skipping it!'.format(ds))
            continue
        if ds == 'labeled_test':
            if self.logger is not None:
                self.logger.debug('LOOKING FOR FINAL TEST SET')
            else:
                print('LOOKING FOR FINAL TEST SET')
        X_all.append(getattr(self, 'vols_' + ds))
        Y_all.append(getattr(self, 'segs_' + ds))
        contours_all.append(getattr(self, 'contours_' + ds))
        files_list += getattr(self, 'files_' + ds)

    n_files = len(files_list)
    n_loaded_vols = np.sum([x.shape[0] for x in X_all])

    # if all of the vols are loaded, we can sample from vols instead of loading from file
    load_from_files = n_loaded_vols != n_files
    if not load_from_files:
        X_all = np.concatenate(X_all, axis=0)
        if load_segs and len(Y_all) > 0:
            Y_all = np.concatenate(Y_all, axis=0)
        else:
            Y_all = None
        if load_contours and len(contours_all) > 0:
            contours_all = np.concatenate(contours_all, axis=0)
        else:
            contours_all = None
        self._print('Sampling size {} batches from {} volumes!'.format(batch_size, n_files))
    else:
        self._print('Sampling size {} batches from {} files!'.format(batch_size, n_files))

    if randomize:
        idxs = np.random.choice(n_files, batch_size, replace=True)
    else:
        # modulo keeps the indices valid even when there are fewer files than batch_size
        idxs = np.arange(batch_size) % n_files

    while True:
        start = time.time()
        if not load_from_files:
            # if vols are pre-loaded, simply sample them
            X = X_all[idxs]
            Y = Y_all[idxs] if load_segs and Y_all is not None else None
            contours = contours_all[idxs] if load_contours and contours_all is not None else None
            batch_files = [files_list[i] for i in idxs]
        else:
            # load from files as we go
            X = [None] * batch_size
            Y = [None] * batch_size if load_segs else None
            contours = [None] * batch_size if load_contours else None
            batch_files = []
            for i, idx in enumerate(idxs.tolist()):
                x, y, curr_contours = _load_vol_and_seg(
                    files_list[idx],
                    load_seg=load_segs, load_contours=load_contours,
                    do_mask_vol=True,
                    keep_labels=self.label_mapping,
                )
                batch_files.append(files_list[idx])
                X[i] = x[np.newaxis, ...]
                if load_segs:
                    Y[i] = y[np.newaxis, ...]
                if load_contours:
                    contours[i] = curr_contours[np.newaxis]

        if self.profiler_logger is not None:
            self.profiler_logger.info('Loading vol took {}'.format(time.time() - start))

        # if we loaded these as lists, turn them into ndarrays
        if isinstance(X, list):
            X = np.concatenate(X, axis=0)
        if load_segs and isinstance(Y, list):
            Y = np.concatenate(Y, axis=0)
        if load_contours and isinstance(contours, list):
            contours = np.concatenate(contours, axis=0)

        # pick idxs for the next batch
        if randomize:
            idxs = np.random.choice(n_files, batch_size, replace=True)
        else:
            idxs = (idxs + batch_size) % n_files

        if load_segs and convert_onehot:
            start = time.time()
            Y = classification_utils.labels_to_onehot(Y, label_mapping=label_mapping)
            if self.profiler_logger is not None:
                self.profiler_logger.info('Converting vol onehot took {}'.format(time.time() - start))
        elif load_segs and Y.shape[-1] != 1:
            # make sure we have a channels dim
            Y = Y[..., np.newaxis]

        if not return_ids:
            yield X, Y, contours
        else:
            yield X, Y, contours, batch_files
def _generate_source_target_pairs(self, batch_size, source_vol_gen=None, target_vol_gen=None, return_ids=False):
    '''
    Endlessly yield (inputs, targets) training pairs for the transform model.

    With a single atlas (self.X_source_train holds one volume), that atlas is the source of every pair.
    Otherwise a new source is drawn from source_vol_gen for each pair. A target is drawn from target_vol_gen
    for every pair.

    :param batch_size: unused here; kept for interface compatibility
    :param source_vol_gen: generator yielding (vols, segs, contours, ids); only used with more than one atlas
    :param target_vol_gen: generator yielding (vols, _, _, ids)
    :param return_ids: also yield the source and target ids
    :return: generator of (inputs, targets) or (inputs, targets, id_source, id_target)
    '''
    def _make_source_aux_inputs(segs, contours):
        # Build the aux inputs requested by self.arch_params['do_aux_reg'] from the source segs/contours;
        # returns None when no aux input is requested.
        do_aux_reg = self.arch_params['do_aux_reg']
        if do_aux_reg is None:
            return None
        if 'segs_oh' in do_aux_reg:
            # first channel will be segs in label form
            aux_inputs = classification_utils.labels_to_onehot(segs[..., [0]], label_mapping=self.label_mapping)
        elif 'segs' in do_aux_reg:
            aux_inputs = segs
        else:
            aux_inputs = None
        if 'contours' in do_aux_reg:
            aux_inputs = contours if aux_inputs is None else np.concatenate([aux_inputs, contours], axis=-1)
        return aux_inputs

    n_atlases = self.X_source_train.shape[0]
    if n_atlases == 1:
        # single atlas, no need to sample from generator
        X_source = self.X_source_train
        segs_source = self.segs_source_train
        contours_source = self.contours_source_train
        id_source = [self.source_train_files[0]]
        # create source aux input here in case we need it later
        source_aux_inputs = _make_source_aux_inputs(segs_source, contours_source)

    while True:
        if n_atlases > 1:
            X_source, segs_source, contours_source, id_source = next(source_vol_gen)
            source_aux_inputs = _make_source_aux_inputs(segs_source, contours_source)

        X_target, _, _, id_target = next(target_vol_gen)

        # NOTE(review): `inputs` is only built for 'color' architectures; any other arch fails at the yield below.
        if 'color' in self.arch_params['model_arch']:
            if (n_atlases > 1 and self.recon_loss_name == 'l2-src') or self.recon_loss_name == 'l2-tgt':
                # more than one atlas, so we need to back-warp depending on our atlas
                # OR, if we are computing reconstruction loss in the target space,
                # we still need to give the color model the src-space target
                X_target_srcspace = self.flow_bck_model.predict([X_target, X_source])[0]
            else:
                # the target is never back-warped in this configuration, so there is no
                # source-space target to feed the color model
                raise ValueError(
                    'No source-space target available for recon loss {} with a single atlas'.format(
                        self.recon_loss_name))

            if self.arch_params['do_aux_reg'] is not None:
                inputs = [X_source, X_target_srcspace, source_aux_inputs]
            else:
                inputs = [X_source, X_target_srcspace]

            if self.recon_loss_name == 'l2-tgt':
                _, flow_batch = self.flow_fwd_model.predict([X_source, X_target])
                # reconstruction loss in the target space
                inputs += [flow_batch]

        if 'bidir' in self.arch_params['model_arch']:
            # forward target, backward target, forward flow reg, backward flow reg
            targets = [X_target, X_source, X_target, X_source]
        else:
            targets = [X_target] * 3

        if not return_ids:
            yield inputs, targets
        else:
            yield inputs, targets, id_source, id_target