def _generate_slices_batchs_from_vols(
        self, vols, segs, ids,
        vol_gen=None,
        randomize=False,
        batch_size=16,
        vol_batch_size=1,
        convert_onehot=False,
):
    if vol_gen is None:
        vol_gen = utils.gen_batch(vols, [segs, ids],
                                  randomize=randomize,
                                  batch_size=vol_batch_size)

    while True:
        X, Y, ids_batch = next(vol_gen)
        n_slices = X.shape[-2]
        slice_idxs = np.random.choice(n_slices, self.batch_size, replace=True)

        # roll the selected z-slices into the batch dimension
        X_slices = np.reshape(
            np.transpose(X[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
            (-1,) + self.pred_img_shape)

        if Y is not None:
            if self.arch_params['warpoh']:
                # we used the onehot warper in these cases
                Y = np.reshape(
                    np.transpose(Y[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_img_shape[:-1] + (self.n_labels,))
            else:
                Y = np.reshape(
                    np.transpose(Y[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_img_shape[:-1] + (1,))

            if convert_onehot:
                Y = utils.labels_to_onehot(Y, label_mapping=self.label_mapping)

        yield X_slices, Y, ids_batch
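
# The slice generator above relies on utils.gen_batch, which is project-internal and
# not shown here. The sketch below is a minimal, assumed version of its contract
# (sample the same indices from the volume array and each extra array, looping
# forever); it is illustrative only and may differ from the project's actual code.
def _gen_batch_sketch(vols, extra_arrays, randomize=False, batch_size=1):
    n = vols.shape[0]
    pos = 0
    while True:
        if randomize:
            idxs = np.random.choice(n, batch_size, replace=True)
        else:
            idxs = (np.arange(batch_size) + pos) % n
            pos = (pos + batch_size) % n
        # numpy arrays are fancy-indexed; lists (e.g. subject ids) are indexed per element
        extras = tuple(
            a[idxs] if isinstance(a, np.ndarray) else [a[i] for i in idxs]
            for a in extra_arrays)
        yield (vols[idxs],) + extras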
def _generate_augmented_batch(
        self, source_gen,
        use_single_atlas=False,
        aug_by=None,
        return_transforms=False,
        convert_onehot=False,
        do_slice_z=False,
):
    if use_single_atlas:
        # single atlas, don't bother with the generator
        self.logger.debug('Single atlas, not using source generator for augmenter')
        X_source, source_segs, source_contours, source_ids = next(source_gen)
    else:
        self.logger.debug('Multiple atlases, sampling source vols from generator')

    while True:
        if not use_single_atlas:
            # randomly select an atlas from our source volume generator
            X_source, source_segs, source_contours, source_ids = next(source_gen)

        # keep track of which unlabeled subjects we are using in training
        ul_ids = []

        if len(aug_by) > 1:
            # multiple augmentation schemes: randomly pick one for this batch
            do_aug_by_idx = np.random.rand(1)[0] * float(len(aug_by) - 1)
            aug_batch_by = aug_by[int(round(do_aug_by_idx))]
        else:
            aug_batch_by = aug_by[0]

        color_delta = None
        start = time.time()
        if aug_batch_by == 'rand':
            X_aug, flow = self.flow_rand_aug_model.predict(X_source)

            # color augmentation by an additive or multiplicative factor:
            # randomize the factor and then tile it to the correct shape
            if 'offset_amp' in self.data_params['aug_params']:
                X_aug += np.tile(
                    (np.random.rand(X_aug.shape[0], 1, 1, 1) * 2. - 1.)
                    * self.data_params['aug_params']['offset_amp'],
                    (1,) + X_aug.shape[1:])
                X_aug = np.clip(X_aug, 0., 1.)
            if 'mult_amp' in self.data_params['aug_params']:
                X_aug *= np.tile(
                    1. + (np.random.rand(X_aug.shape[0], 1, 1, 1) * 2. - 1.)
                    * self.data_params['aug_params']['mult_amp'],
                    (1,) + X_aug.shape[1:])
                X_aug = np.clip(X_aug, 0., 1.)

            # Y_to_aug = classification_utils.labels_to_onehot(Y_to_aug, label_mapping=self.label_mapping)
            Y_aug = self.seg_warp_model.predict([source_segs, flow])
        elif aug_batch_by == 'tm':
            st = time.time()
            if self.arch_params['do_coupled_sampling']:
                # use the same target for flow and color
                X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw)
                X_colortgt = X_flowtgt
                ul_ids += ul_flow_ids
            else:
                X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw)
                X_colortgt, _, _, ul_color_ids = next(self.unlabeled_gen_raw)
                ul_ids += ul_flow_ids + ul_color_ids

            if self.do_profile:
                self.profiler_logger.info(
                    'Sampling aug tgt took {}'.format(time.time() - st))

            self.aug_target = X_flowtgt

            st = time.time()
            # compute forward flow, which we will use for the spatial transformation
            _, flow = self.flow_aug_model.predict([X_source, X_flowtgt])

            # warp the color target back to atlas space so that we can compute the color transformation
            X_colortgt_src, _ = self.flow_bck_aug_model.predict([X_colortgt, X_source])
            colored_vol, color_delta, _ = self.color_aug_model.predict(
                [X_source, X_colortgt_src, source_contours, flow])
            self.aug_colored = colored_vol

            # color the source volume, and then apply the spatial transformation
            X_aug = self.vol_warp_model.predict([colored_vol, flow])

            if self.do_profile:
                self.profiler_logger.info(
                    'Running color and flow aug took {}'.format(time.time() - st))

            st = time.time()
            # now warp the labels to match
            Y_aug = self.seg_warp_model.predict([source_segs, flow])
            if self.do_profile:
                self.profiler_logger.info(
                    'Warping labels took {}'.format(time.time() - st))
        else:
            # no aug
            X_aug = X_source
            Y_aug = source_segs

        if self.do_profile:
            self.profiler_logger.info(
                'Augmenting input batch took {}'.format(time.time() - start))

        if do_slice_z:
            # get random slices in the z dimension
            start = time.time()
            n_total_slices = X_source.shape[-2]

            # always take random slices in the z dimension
            slice_idxs = np.random.choice(n_total_slices, self.batch_size, replace=True)
            # roll z-slices into the batch dimension
            X_to_aug_slices = np.reshape(
                np.transpose(
                    X_source[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                (-1,) + tuple(self.pred_img_shape))

            X_aug = np.reshape(
                np.transpose(X_aug[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                (-1,) + self.pred_img_shape)

            if self.aug_target is not None:
                self.aug_target = np.reshape(
                    np.transpose(self.aug_target[..., slice_idxs, :], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_img_shape)

            if self.aug_colored is not None:
                # not relevant if we aren't using a color transform model
                self.aug_colored = np.reshape(
                    np.transpose(self.aug_colored[..., slice_idxs, :], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_img_shape)

            if (aug_by == 'rand' or aug_by == 'tm') and self.arch_params['warpoh']:
                # we used the onehot warper in these cases
                Y_to_aug_slices = np.reshape(
                    np.transpose(source_segs[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_segs_oh_shape)
                Y_aug = np.reshape(
                    np.transpose(Y_aug[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_segs_oh_shape)
            else:
                Y_to_aug_slices = np.reshape(
                    np.transpose(source_segs[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_segs_shape)
                Y_aug = np.reshape(
                    np.transpose(Y_aug[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_segs_shape)

            if self.do_profile:
                self.profiler_logger.info(
                    'Slicing batch took {}'.format(time.time() - start))

            if return_transforms:
                st = time.time()
                flow = np.reshape(
                    np.transpose(flow[:, :, :, slice_idxs, :], (0, 3, 1, 2, 4)),
                    (-1,) + self.pred_img_shape[:-1] + (self.n_aug_dims,))
                if self.do_profile:
                    self.profiler_logger.info(
                        'Slicing flow took {}'.format(time.time() - st))

                if color_delta is not None:
                    color_delta = np.reshape(
                        np.transpose(color_delta[:, :, :, slice_idxs, :], (0, 3, 1, 2, 4)),
                        (-1,) + self.pred_img_shape)
        else:
            # no slicing, no aug?
            X_to_aug_slices = X_source
            Y_to_aug_slices = source_segs

        if convert_onehot and not ((aug_by == 'rand' or aug_by == 'tm')
                                   and self.arch_params['warpoh']):
            # if we don't have onehot segs already, convert them after slicing
            start = time.time()
            Y_to_aug_slices = utils.labels_to_onehot(
                Y_to_aug_slices, label_mapping=self.label_mapping)
            Y_aug = utils.labels_to_onehot(
                Y_aug, label_mapping=self.label_mapping)
            if self.do_profile:
                self.profiler_logger.info(
                    'Converting onehot took {}'.format(time.time() - start))

        if return_transforms:
            yield X_to_aug_slices, Y_to_aug_slices, flow, color_delta, X_aug, Y_aug, ul_ids
        else:
            yield X_to_aug_slices, Y_to_aug_slices, X_aug, Y_aug, ul_ids
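
# utils.labels_to_onehot is used throughout the generators above but is not shown
# here. The sketch below captures the assumed behavior (map integer label values to
# channels in the order given by label_mapping); it is illustrative only, not the
# project's actual implementation.
def _labels_to_onehot_sketch(labels, label_mapping):
    # labels: (..., 1) integer array; returns (..., len(label_mapping)) float array
    labels = np.squeeze(labels, axis=-1)
    onehot = np.zeros(labels.shape + (len(label_mapping),), dtype=np.float32)
    for ci, label_val in enumerate(label_mapping):
        onehot[..., ci] = (labels == label_val).astype(np.float32)
    return onehot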
def create_generators(self, batch_size):
    self.batch_size = batch_size

    # generator for labeled training examples that we wish to augment
    self.source_gen = self.dataset.gen_vols_batch(
        dataset_splits=['labeled_train'],
        batch_size=1,
        load_segs=True,  # always load segmentations since this is our labeled set
        load_contours=self.aug_tm,  # only load contours if we need them as aux data for our appearance model
        randomize=True,
        return_ids=True,
    )

    # actually more like a target generator
    self.unlabeled_gen_raw = self.dataset.gen_vols_batch(
        dataset_splits=['labeled_train', 'unlabeled_train'],
        batch_size=1,
        load_segs=False,
        load_contours=False,
        randomize=True,
        return_ids=True,
    )

    self._create_augmentation_models()

    # simply append augmented examples to the training set
    # NOTE: we can only do this if we are not synthesizing that many examples (i.e. <= 1000)
    self.aug_train_gen = None
    if self.n_aug is not None:
        self._create_augmented_examples()
        self.logger.debug(
            'Augmented classifier training set: vols {}, segs {}'.format(
                self.X_labeled_train.shape, self.segs_labeled_train.shape))
    elif self.aug_tm or self.aug_rand:
        # augmentation by transform model or random flow+intensity requires synthesizing
        # a lot of examples, so we'll just do this per-batch
        aug_by = []
        if self.aug_tm:
            aug_by += ['tm']
        if self.aug_rand:
            aug_by += ['rand']

        # these need to be done in the generator
        self.aug_train_gen = self._generate_augmented_batch(
            source_gen=self.source_gen,
            use_single_atlas=self.X_labeled_train.shape[0] == 1,
            aug_by=aug_by,
            convert_onehot=True,
            return_transforms=True,
            do_slice_z=(self.n_seg_dims == 2),  # if we are predicting on slices, then get slices
        )

    # generates slices from the training volumes
    self.train_gen = self._generate_slices_batchs_from_vols(
        self.X_labeled_train, self.segs_labeled_train, self.ids_labeled_train,
        vol_gen=None,
        convert_onehot=True,
        batch_size=self.batch_size,
        randomize=True,
    )

    # load each subject, then evaluate each slice
    # (we will compute validation losses on all validation volumes)
    self.eval_valid_gen = self.dataset.gen_vols_batch(
        ['labeled_valid'],
        batch_size=1,
        randomize=False,
        convert_onehot=False,
        load_segs=True,
        label_mapping=self.label_mapping,
        return_ids=True,
    )

    # just pick some random slices to display for the validation set later
    rand_subjs = np.random.choice(self.X_labeled_valid.shape[0], batch_size)
    rand_slices = np.random.choice(self.aug_img_shape[2], batch_size, replace=False)
    self.X_valid_batch = np.zeros((batch_size,) + self.pred_img_shape)
    self.Y_valid_batch = np.zeros((batch_size,) + self.pred_segs_shape)
    for ei in range(batch_size):
        self.X_valid_batch[ei] = self.X_labeled_valid[rand_subjs[ei], :, :, rand_slices[ei]]
        self.Y_valid_batch[ei] = self.segs_labeled_valid[rand_subjs[ei], :, :, rand_slices[ei]]

    self.Y_valid_batch = utils.labels_to_onehot(
        self.Y_valid_batch, label_mapping=self.label_mapping)
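
# Illustrative only: one possible way a training loop might consume the generators
# set up by create_generators. `segmenter` stands for an instance of this class and
# `n_batches` is a placeholder; neither name is part of the original code.
def _example_training_batches(segmenter, n_batches):
    for _ in range(n_batches):
        # slices from the real labeled training volumes
        X_batch, Y_batch, _ = next(segmenter.train_gen)

        if segmenter.aug_train_gen is not None:
            # per-batch augmented examples (transform-model and/or random aug)
            _, _, _, _, X_aug, Y_aug, _ = next(segmenter.aug_train_gen)
            X_batch = np.concatenate([X_batch, X_aug], axis=0)
            Y_batch = np.concatenate([Y_batch, Y_aug], axis=0)

        yield X_batch, Y_batch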
def gen_vols_batch(self, dataset_splits=['labeled_train'],
                   batch_size=1, randomize=True,
                   load_segs=False, load_contours=False,
                   convert_onehot=False, label_mapping=None,
                   return_ids=False):
    if not isinstance(dataset_splits, list):
        dataset_splits = [dataset_splits]

    X_all = []
    Y_all = []
    contours_all = []
    files_list = []
    for ds in dataset_splits:
        if ds == 'labeled_train':
            X_all.append(self.vols_labeled_train)
            Y_all.append(self.segs_labeled_train)
            contours_all.append(self.contours_labeled_train)
            files_list += self.files_labeled_train
        elif ds == 'labeled_valid':
            X_all.append(self.vols_labeled_valid)
            Y_all.append(self.segs_labeled_valid)
            contours_all.append(self.contours_labeled_valid)
            files_list += self.files_labeled_valid
        elif ds == 'unlabeled_train':
            X_all.append(self.vols_unlabeled_train)
            Y_all.append(self.segs_unlabeled_train)
            contours_all.append(self.contours_unlabeled_train)
            files_list += self.files_unlabeled_train
        elif ds == 'labeled_test':
            if self.logger is not None:
                self.logger.debug('LOOKING FOR FINAL TEST SET')
            else:
                print('LOOKING FOR FINAL TEST SET')
            X_all.append(self.vols_labeled_test)
            Y_all.append(self.segs_labeled_test)
            contours_all.append(self.contours_labeled_test)
            files_list += self.files_labeled_test

    n_files = len(files_list)
    n_loaded_vols = np.sum([x.shape[0] for x in X_all])

    # if all of the vols are loaded, we can sample from vols instead of loading from file
    if n_loaded_vols == n_files:
        load_from_files = False

        X_all = np.concatenate(X_all, axis=0)

        if load_segs and len(Y_all) > 0:
            Y_all = np.concatenate(Y_all, axis=0)
        else:
            Y_all = None

        if load_contours and len(contours_all) > 0:
            contours_all = np.concatenate(contours_all, axis=0)
        else:
            contours_all = None
    else:
        load_from_files = True

    if load_from_files:
        self._print('Sampling size {} batches from {} files!'.format(
            batch_size, n_files))
    else:
        self._print('Sampling size {} batches from {} volumes!'.format(
            batch_size, n_files))

    if randomize:
        idxs = np.random.choice(n_files, batch_size, replace=True)
    else:
        idxs = np.linspace(0, min(n_files, batch_size), batch_size,
                           endpoint=False, dtype=int)

    while True:
        start = time.time()
        if not load_from_files:
            # if vols are pre-loaded, simply sample them
            X = X_all[idxs]

            if load_segs and Y_all is not None:
                Y = Y_all[idxs]
            else:
                Y = None

            if load_contours and contours_all is not None:
                contours = contours_all[idxs]
            else:
                contours = None

            batch_files = [files_list[i] for i in idxs]
        else:
            X = [None] * batch_size
            if load_segs:
                Y = [None] * batch_size
            else:
                Y = None
            if load_contours:
                contours = [None] * batch_size
            else:
                contours = None

            batch_files = []
            # load from files as we go
            for i, idx in enumerate(idxs.tolist()):
                x, y, curr_contours = load_vol_and_seg(
                    files_list[idx],
                    load_seg=load_segs,
                    load_contours=load_contours,
                    do_mask_vol=True,
                    keep_labels=self.label_mapping,
                )
                batch_files.append(files_list[idx])
                X[i] = x[np.newaxis, ...]
                if load_segs:
                    Y[i] = y[np.newaxis, ...]
                if load_contours:
                    contours[i] = curr_contours[np.newaxis]

                if self.profiler_logger is not None:
                    self.profiler_logger.info(
                        'Loading vol took {}'.format(time.time() - start))

        # if we loaded these as lists, turn them into ndarrays
        if isinstance(X, list):
            X = np.concatenate(X, axis=0)
        if load_segs and isinstance(Y, list):
            Y = np.concatenate(Y, axis=0)
        if load_contours and isinstance(contours, list):
            contours = np.concatenate(contours, axis=0)

        # pick idxs for the next batch
        if randomize:
            idxs = np.random.choice(n_files, batch_size, replace=True)
        else:
            idxs += batch_size
            idxs[idxs > n_files - 1] -= n_files

        if load_segs and convert_onehot:
            start = time.time()
            Y = utils.labels_to_onehot(Y, label_mapping=label_mapping)
            if self.profiler_logger is not None:
                self.profiler_logger.info(
                    'Converting vol onehot took {}'.format(time.time() - start))
        elif load_segs and not convert_onehot and not Y.shape[-1] == 1:
            # make sure we have a channels dim
            Y = Y[..., np.newaxis]

        if not return_ids:
            yield X, Y, contours
        else:
            yield X, Y, contours, batch_files
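
# Illustrative only: a sketch of how gen_vols_batch might be consumed to evaluate one
# validation volume at a time, rolling z-slices into the batch dimension the same way
# the training code does. `predict_slices_fn` and `n_valid_vols` are placeholders,
# not part of the original code.
def _example_eval_volumes(dataset, n_valid_vols, predict_slices_fn):
    eval_gen = dataset.gen_vols_batch(
        ['labeled_valid'], batch_size=1, randomize=False,
        load_segs=True, convert_onehot=False, return_ids=True)
    for _ in range(n_valid_vols):
        X, Y, _, vol_ids = next(eval_gen)
        # X, Y: (1, H, W, Z, C) -> (Z, H, W, C)
        X_slices = np.transpose(X[0], (2, 0, 1, 3))
        Y_slices = np.transpose(Y[0], (2, 0, 1, 3))
        yield vol_ids[0], predict_slices_fn(X_slices), Y_slices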