Beispiel #1
0
    def _generate_slices_batchs_from_vols(
        self,
        vols,
        segs,
        ids,
        vol_gen=None,
        randomize=False,
        batch_size=16,
        vol_batch_size=1,
        convert_onehot=False,
    ):
        """Endlessly yield batches of randomly chosen slices from volume batches.

        Each iteration pulls one ``(X, Y, ids)`` batch of volumes from
        ``vol_gen``, draws ``self.batch_size`` slice indices (with
        replacement) along axis 3, and folds the chosen slices into the batch
        axis. The indexing and transpose below require 5 axes on ``X`` and
        ``Y``.

        Args:
            vols, segs, ids: Only used to build a default volume generator
                via ``utils.gen_batch`` when ``vol_gen`` is None.
            vol_gen: Optional generator yielding ``(X, Y, ids_batch)``.
                ``Y`` may be None, in which case None is yielded for the
                segmentations.
            randomize: Forwarded to ``utils.gen_batch``.
            batch_size: NOTE(review): accepted but not used; the number of
                slices per batch is taken from ``self.batch_size``. Confirm
                whether callers expect this argument to take effect.
            vol_batch_size: Number of volumes per batch requested from
                ``utils.gen_batch``.
            convert_onehot: If True, pass the sliced segmentations through
                ``utils.labels_to_onehot`` with ``self.label_mapping``.

        Yields:
            Tuple ``(X_slices, Y_slices_or_None, ids_batch)``.
        """
        if vol_gen is None:
            vol_gen = utils.gen_batch(vols, [segs, ids],
                                      randomize=randomize,
                                      batch_size=vol_batch_size)

        while True:
            X, Y, ids_batch = next(vol_gen)

            # slices live on the second-to-last axis
            n_slices = X.shape[-2]
            slice_idxs = np.random.choice(n_slices,
                                          self.batch_size,
                                          replace=True)

            # move the slice axis next to the batch axis, then merge the two
            X_slices = np.reshape(
                np.transpose(X[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                (-1, ) + self.pred_img_shape)

            if Y is not None:
                # Previously the non-onehot branch also ran for Y = None and
                # crashed on the subscript; segs are now only sliced when
                # they were actually provided.
                if self.arch_params['warpoh']:
                    # we used the onehot warper in these cases
                    n_seg_channels = self.n_labels
                else:
                    n_seg_channels = 1

                Y = np.reshape(
                    np.transpose(Y[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1, ) + self.pred_img_shape[:-1] + (n_seg_channels, ))

                if convert_onehot:
                    Y = utils.labels_to_onehot(
                        Y, label_mapping=self.label_mapping)

            yield X_slices, Y, ids_batch
Beispiel #2
0
    def _generate_augmented_batch(
        self,
        source_gen,
        use_single_atlas=False,
        aug_by=None,
        return_transforms=False,
        convert_onehot=False,
        do_slice_z=False,
    ):
        """Endlessly yield source examples together with augmented versions.

        For every batch one augmentation scheme is picked from ``aug_by``:

        * ``'rand'``: ``self.flow_rand_aug_model`` produces a warped volume
          and flow; optional additive/multiplicative intensity jitter is
          applied according to ``self.data_params['aug_params']``.
        * ``'tm'``: target volumes are drawn from ``self.unlabeled_gen_raw``
          and the flow / color models are applied; ``self.aug_target`` and
          ``self.aug_colored`` are set as a side effect.
        * anything else (or an empty/None ``aug_by``): no augmentation, the
          source volume and segs are passed through.

        Args:
            source_gen: Generator yielding
                ``(X_source, source_segs, source_contours, source_ids)``.
            use_single_atlas: If True, ``source_gen`` is read once and the
                same source is reused for every batch.
            aug_by: Sequence of scheme names. None or empty means "no
                augmentation".
            return_transforms: If True, the flow and color delta are
                included in the yielded tuple (None when the chosen scheme
                did not produce them).
            convert_onehot: If True, segs are converted with
                ``utils.labels_to_onehot`` after slicing.
            do_slice_z: If True, ``self.batch_size`` random slices along
                axis 3 are taken and folded into the batch axis.

        Yields:
            ``(X, Y, flow, color_delta, X_aug, Y_aug, ul_ids)`` when
            ``return_transforms`` is True, otherwise
            ``(X, Y, X_aug, Y_aug, ul_ids)``.
        """

        if use_single_atlas:  # single atlas
            # single atlas, dont bother with generator
            self.logger.debug(
                'Single atlas, not using source generator for augmenter')
            X_source, source_segs, source_contours, source_ids = next(
                source_gen)
        else:
            self.logger.debug(
                'Multiple atlases, sampling source vols from generator')

        while True:
            if not use_single_atlas:
                # randomly select an atlas from our source volume generator
                X_source, source_segs, source_contours, source_ids = next(
                    source_gen)

            # keep track of which unlabeled subjects we are using in training
            ul_ids = []

            # Per-batch state. Resetting here keeps a batch from seeing the
            # (possibly already sliced) targets or transforms of the
            # previous batch, and guarantees `flow` is bound when no
            # augmentation runs.
            flow = None
            color_delta = None
            self.aug_target = None
            self.aug_colored = None

            if not aug_by:
                # nothing requested: fall through to the "no aug" branch
                aug_batch_by = None
            elif len(aug_by) > 1:
                # multiple augmentation schemes: pick one uniformly.
                # For two schemes this matches the previous rounded coin flip.
                do_aug_by_idx = int(np.random.rand(1)[0] * len(aug_by))
                aug_batch_by = aug_by[min(do_aug_by_idx, len(aug_by) - 1)]
            else:
                aug_batch_by = aug_by[0]

            start = time.time()
            if aug_batch_by == 'rand':
                X_aug, flow = self.flow_rand_aug_model.predict(X_source)

                # color augmentation by additive or multiplicative factor.
                # one random factor per example, broadcast over all other axes
                aug_params = self.data_params['aug_params']
                factor_shape = (X_aug.shape[0], ) + (1, ) * (X_aug.ndim - 1)
                if 'offset_amp' in aug_params:
                    X_aug += (np.random.rand(*factor_shape) * 2. -
                              1.) * aug_params['offset_amp']
                    X_aug = np.clip(X_aug, 0., 1.)
                if 'mult_amp' in aug_params:
                    X_aug *= 1. + (np.random.rand(*factor_shape) * 2. -
                                   1.) * aug_params['mult_amp']
                    X_aug = np.clip(X_aug, 0., 1.)

                Y_aug = self.seg_warp_model.predict([source_segs, flow])
            elif aug_batch_by == 'tm':
                # timer for the profiling messages below; previously read
                # before its first assignment
                st = time.time()
                if self.arch_params['do_coupled_sampling']:
                    # use the same target for flow and color
                    X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw)
                    X_colortgt = X_flowtgt
                    ul_ids += ul_flow_ids
                else:
                    X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw)
                    X_colortgt, _, _, ul_color_ids = next(
                        self.unlabeled_gen_raw)
                    ul_ids += ul_flow_ids + ul_color_ids

                    if self.do_profile:
                        self.profiler_logger.info(
                            'Sampling aug tgt took {}'.format(time.time() -
                                                              st))

                self.aug_target = X_flowtgt

                st = time.time()
                # compute forward flow, which we will use for the spatial transformation
                _, flow = self.flow_aug_model.predict([X_source, X_flowtgt])

                # warp color target back to the atlas space so that we can compute the color transformation
                X_colortgt_src, _ = self.flow_bck_aug_model.predict(
                    [X_colortgt, X_source])
                colored_vol, color_delta, _ = self.color_aug_model.predict(
                    [X_source, X_colortgt_src, source_contours, flow])

                self.aug_colored = colored_vol

                # color the source volume, and then apply the spatial transformation
                X_aug = self.vol_warp_model.predict([colored_vol, flow])

                if self.do_profile:
                    self.profiler_logger.info(
                        'Running color and flow aug took {}'.format(
                            time.time() - st))

                st = time.time()
                # now warp the labels to match
                Y_aug = self.seg_warp_model.predict([source_segs, flow])
                if self.do_profile:
                    self.profiler_logger.info(
                        'Warping labels took {}'.format(time.time() - st))
            else:
                # no aug
                X_aug = X_source
                Y_aug = source_segs

            if self.do_profile:
                self.profiler_logger.info(
                    'Augmenting input batch took {}'.format(time.time() -
                                                            start))

            # NOTE(review): `aug_by` is indexed like a sequence above, so
            # comparing it to the strings 'rand'/'tm' is False for a list and
            # the onehot-warper branches below are never taken. This may have
            # been meant to test `aug_batch_by`; left unchanged until the
            # intended shapes with 'warpoh' are confirmed.
            if do_slice_z:
                # get a random slice in the z dimension
                start = time.time()
                n_total_slices = X_source.shape[-2]

                # always take random slices in the z dimension
                slice_idxs = np.random.choice(n_total_slices,
                                              self.batch_size,
                                              replace=True)

                X_to_aug_slices = np.reshape(
                    np.transpose(  # roll z-slices into batch
                        X_source[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1, ) + tuple(self.pred_img_shape))

                X_aug = np.reshape(
                    np.transpose(X_aug[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                    (-1, ) + self.pred_img_shape)

                if self.aug_target is not None:
                    self.aug_target = np.reshape(
                        np.transpose(self.aug_target[..., slice_idxs, :],
                                     (0, 3, 1, 2, 4)),
                        (-1, ) + self.pred_img_shape)

                    if self.aug_colored is not None:
                        # not relevant if we arent using a color transform model
                        self.aug_colored = np.reshape(
                            np.transpose(self.aug_colored[..., slice_idxs, :],
                                         (0, 3, 1, 2, 4)),
                            (-1, ) + self.pred_img_shape)

                if (aug_by == 'rand'
                        or aug_by == 'tm') and self.arch_params['warpoh']:
                    # we used the onehot warper in these cases
                    Y_to_aug_slices = np.reshape(
                        np.transpose(source_segs[:, :, :, slice_idxs],
                                     (0, 3, 1, 2, 4)),
                        (-1, ) + self.pred_segs_oh_shape)

                    Y_aug = np.reshape(
                        np.transpose(Y_aug[:, :, :, slice_idxs],
                                     (0, 3, 1, 2, 4)),
                        (-1, ) + self.pred_segs_oh_shape)
                else:
                    Y_to_aug_slices = np.reshape(
                        np.transpose(source_segs[:, :, :, slice_idxs],
                                     (0, 3, 1, 2, 4)),
                        (-1, ) + self.pred_segs_shape)
                    Y_aug = np.reshape(
                        np.transpose(Y_aug[:, :, :, slice_idxs],
                                     (0, 3, 1, 2, 4)),
                        (-1, ) + self.pred_segs_shape)

                if self.do_profile:
                    self.profiler_logger.info(
                        'Slicing batch took {}'.format(time.time() - start))

                # only slice the flow when the chosen scheme produced one
                if return_transforms and flow is not None:
                    st = time.time()
                    flow = np.reshape(
                        np.transpose(flow[:, :, :, slice_idxs, :],
                                     (0, 3, 1, 2, 4)), (-1, ) +
                        self.pred_img_shape[:-1] + (self.n_aug_dims, ))
                    if self.do_profile:
                        self.profiler_logger.info(
                            'Slicing flow took {}'.format(time.time() - st))

                if color_delta is not None:
                    color_delta = np.reshape(
                        np.transpose(color_delta[:, :, :, slice_idxs, :],
                                     (0, 3, 1, 2, 4)),
                        (-1, ) + self.pred_img_shape)
            else:
                # no slicing, no aug?
                X_to_aug_slices = X_source
                Y_to_aug_slices = source_segs

            if convert_onehot and not ((aug_by == 'rand' or aug_by == 'tm')
                                       and self.arch_params['warpoh']):
                # if we don't have onehot segs already, convert them after slicing
                start = time.time()
                Y_to_aug_slices = utils.labels_to_onehot(
                    Y_to_aug_slices, label_mapping=self.label_mapping)
                Y_aug = utils.labels_to_onehot(
                    Y_aug, label_mapping=self.label_mapping)
                if self.do_profile:
                    self.profiler_logger.info(
                        'Converting onehot took {}'.format(time.time() -
                                                           start))

            if return_transforms:
                yield X_to_aug_slices, Y_to_aug_slices, flow, color_delta, X_aug, Y_aug, ul_ids
            else:
                yield X_to_aug_slices, Y_to_aug_slices, X_aug, Y_aug, ul_ids
Beispiel #3
0
    def create_generators(self, batch_size):
        """Set up the training/validation generators and a fixed validation batch.

        Stores ``batch_size`` on the instance and creates:

        * ``self.source_gen``: single-volume batches from 'labeled_train'
          with segmentations (and contours when ``self.aug_tm`` is set).
        * ``self.unlabeled_gen_raw``: single-volume batches drawn from
          'labeled_train' and 'unlabeled_train' without segs or contours.
        * ``self.aug_train_gen``: per-batch augmentation generator, only
          when ``self.n_aug`` is None and ``self.aug_tm`` or
          ``self.aug_rand`` is set; otherwise None.
        * ``self.train_gen``: slice batches from the labeled training volumes.
        * ``self.eval_valid_gen``: sequential single-volume batches from
          'labeled_valid'.
        * ``self.X_valid_batch`` / ``self.Y_valid_batch``: ``batch_size``
          randomly chosen validation slices (segs converted to one-hot).

        Args:
            batch_size: Number of slices per batch.
        """
        self.batch_size = batch_size

        # generator for labeled training examples that we wish to augment
        self.source_gen = self.dataset.gen_vols_batch(
            dataset_splits=['labeled_train'],
            batch_size=1,
            # always load segmentations since this is our labeled set
            load_segs=True,
            # only load contours if we need them as aux data for our appearance model
            load_contours=self.aug_tm,
            randomize=True,
            return_ids=True,
        )

        # actually more like a target generator
        self.unlabeled_gen_raw = self.dataset.gen_vols_batch(
            dataset_splits=['labeled_train', 'unlabeled_train'],
            batch_size=1,
            load_segs=False,
            load_contours=False,
            randomize=True,
            return_ids=True,
        )

        self._create_augmentation_models()

        # simply append augmented examples to training set
        # NOTE: we can only do this if we are not synthesizing that many examples (i.e. <= 1000)
        self.aug_train_gen = None
        if self.n_aug is not None:
            self._create_augmented_examples()
            self.logger.debug(
                'Augmented classifier training set: vols {}, segs {}'.format(
                    self.X_labeled_train.shape, self.segs_labeled_train.shape))
        elif self.aug_tm or self.aug_rand:
            # augmentation by transform model or random flow+intensity requires synthesizing
            # a lot of examples, so we'll just do this per-batch
            aug_by = []
            if self.aug_tm:
                aug_by += ['tm']
            if self.aug_rand:
                aug_by += ['rand']

            # these need to be done in the generator
            self.aug_train_gen = self._generate_augmented_batch(
                source_gen=self.source_gen,
                use_single_atlas=self.X_labeled_train.shape[0] == 1,
                aug_by=aug_by,
                convert_onehot=True,
                return_transforms=True,
                # if we are predicting on slices, then get slices
                do_slice_z=(self.n_seg_dims == 2),
            )

        # generates slices from the training volumes
        self.train_gen = self._generate_slices_batchs_from_vols(
            self.X_labeled_train,
            self.segs_labeled_train,
            self.ids_labeled_train,
            vol_gen=None,
            convert_onehot=True,
            batch_size=self.batch_size,
            randomize=True)

        # load each subject, then evaluate each slice
        self.eval_valid_gen = self.dataset.gen_vols_batch(
            ['labeled_valid'],
            batch_size=1,
            randomize=False,
            convert_onehot=False,
            load_segs=True,
            label_mapping=self.label_mapping,
            return_ids=True,
        )

        # we will compute validation losses on all validation volumes
        # just pick some random slices to display for the validation set later
        rand_subjs = np.random.choice(self.X_labeled_valid.shape[0],
                                      batch_size)
        # Prefer distinct slices, but sampling without replacement raises
        # when batch_size exceeds the number of slices; only in that case
        # allow repeats.
        n_valid_slices = self.aug_img_shape[2]
        rand_slices = np.random.choice(n_valid_slices,
                                       batch_size,
                                       replace=batch_size > n_valid_slices)

        self.X_valid_batch = np.zeros((batch_size, ) + self.pred_img_shape)
        self.Y_valid_batch = np.zeros((batch_size, ) + self.pred_segs_shape)
        for ei in range(batch_size):
            self.X_valid_batch[ei] = self.X_labeled_valid[rand_subjs[ei], :, :,
                                                          rand_slices[ei]]
            self.Y_valid_batch[ei] = self.segs_labeled_valid[
                rand_subjs[ei], :, :, rand_slices[ei]]

        self.Y_valid_batch = utils.labels_to_onehot(
            self.Y_valid_batch, label_mapping=self.label_mapping)
Beispiel #4
0
    def gen_vols_batch(self,
                       dataset_splits=['labeled_train'],
                       batch_size=1,
                       randomize=True,
                       load_segs=False,
                       load_contours=False,
                       convert_onehot=False,
                       label_mapping=None,
                       return_ids=False):
        """Endlessly yield batches of volumes from the requested dataset splits.

        If the number of pre-loaded volumes equals the number of files in the
        selected splits, batches are sampled from the in-memory arrays;
        otherwise every batch is loaded from disk with ``load_vol_and_seg``.

        Args:
            dataset_splits: Split name or list of names among
                'labeled_train', 'labeled_valid', 'unlabeled_train' and
                'labeled_test'. Other names are ignored. The default list is
                never mutated here.
            batch_size: Number of volumes per batch.
            randomize: If True, indices are drawn with replacement for every
                batch; otherwise the files are walked in order, wrapping
                around at the end.
            load_segs: If True, segmentations are returned (else None).
            load_contours: If True, contours are returned (else None).
            convert_onehot: If True (and ``load_segs``), segs are converted
                with ``utils.labels_to_onehot`` using ``label_mapping``.
            label_mapping: Passed to ``utils.labels_to_onehot``. Note that
                loading from file uses ``self.label_mapping`` instead.
            return_ids: If True, the list of file names for the batch is
                appended to each yielded tuple.

        Yields:
            ``(X, Y, contours)`` or ``(X, Y, contours, batch_files)``.
        """

        if not isinstance(dataset_splits, list):
            dataset_splits = [dataset_splits]

        X_all = []
        Y_all = []
        contours_all = []
        files_list = []
        for ds in dataset_splits:
            if ds == 'labeled_train':
                X_all.append(self.vols_labeled_train)
                Y_all.append(self.segs_labeled_train)
                contours_all.append(self.contours_labeled_train)
                files_list += self.files_labeled_train
            elif ds == 'labeled_valid':
                X_all.append(self.vols_labeled_valid)
                Y_all.append(self.segs_labeled_valid)
                contours_all.append(self.contours_labeled_valid)
                files_list += self.files_labeled_valid
            elif ds == 'unlabeled_train':
                X_all.append(self.vols_unlabeled_train)
                Y_all.append(self.segs_unlabeled_train)
                contours_all.append(self.contours_unlabeled_train)
                files_list += self.files_unlabeled_train
            elif ds == 'labeled_test':
                if self.logger is not None:
                    self.logger.debug('LOOKING FOR FINAL TEST SET')
                else:
                    print('LOOKING FOR FINAL TEST SET')
                X_all.append(self.vols_labeled_test)
                Y_all.append(self.segs_labeled_test)
                contours_all.append(self.contours_labeled_test)
                files_list += self.files_labeled_test

        n_files = len(files_list)
        n_loaded_vols = np.sum([x.shape[0] for x in X_all])
        # if all of the vols are loaded, so we can sample from vols instead of loading from file
        if n_loaded_vols == n_files:
            load_from_files = False

            X_all = np.concatenate(X_all, axis=0)
            if load_segs and len(Y_all) > 0:
                Y_all = np.concatenate(Y_all, axis=0)
            else:
                Y_all = None

            if load_contours and len(contours_all) > 0:
                contours_all = np.concatenate(contours_all, axis=0)
            else:
                contours_all = None
        else:
            load_from_files = True

        if load_from_files:
            self._print('Sampling size {} batches from {} files!'.format(
                batch_size, n_files))
        else:
            self._print('Sampling size {} batches from {} volumes!'.format(
                batch_size, n_files))

        if randomize:
            idxs = np.random.choice(n_files, batch_size, replace=True)
        else:
            idxs = np.linspace(0,
                               min(n_files, batch_size),
                               batch_size,
                               endpoint=False,
                               dtype=int)

        while True:
            start = time.time()

            if not load_from_files:
                # if vols are pre-loaded, simply sample them
                X = X_all[idxs]

                if load_segs and Y_all is not None:
                    Y = Y_all[idxs]
                else:
                    Y = None

                if load_contours and contours_all is not None:
                    contours = contours_all[idxs]
                else:
                    contours = None

                batch_files = [files_list[i] for i in idxs]
            else:
                X = [None] * batch_size

                if load_segs:
                    Y = [None] * batch_size
                else:
                    Y = None

                if load_contours:
                    contours = [None] * batch_size
                else:
                    contours = None

                batch_files = []
                # load from files as we go
                for i, idx in enumerate(idxs.tolist()):
                    x, y, curr_contours = load_vol_and_seg(
                        files_list[idx],
                        load_seg=load_segs,
                        load_contours=load_contours,
                        do_mask_vol=True,
                        keep_labels=self.label_mapping,
                    )
                    batch_files.append(files_list[idx])
                    X[i] = x[np.newaxis, ...]

                    if load_segs:
                        Y[i] = y[np.newaxis, ...]

                    if load_contours:
                        contours[i] = curr_contours[np.newaxis]

            if self.profiler_logger is not None:
                self.profiler_logger.info(
                    'Loading vol took {}'.format(time.time() - start))

            # if we loaded these as lists, turn them into ndarrays
            if isinstance(X, list):
                X = np.concatenate(X, axis=0)

            if load_segs and isinstance(Y, list):
                Y = np.concatenate(Y, axis=0)

            if load_contours and isinstance(contours, list):
                contours = np.concatenate(contours, axis=0)

            # pick idxs for the next batch
            if randomize:
                idxs = np.random.choice(n_files, batch_size, replace=True)
            else:
                # Wrap with a modulo: subtracting n_files once is not enough
                # when batch_size > n_files and left indices out of range.
                # Identical to the single subtraction when batch_size <= n_files.
                idxs = (idxs + batch_size) % n_files

            if load_segs and convert_onehot:
                start = time.time()
                Y = utils.labels_to_onehot(Y, label_mapping=label_mapping)
                if self.profiler_logger is not None:
                    self.profiler_logger.info(
                        'Converting vol onehot took {}'.format(time.time() -
                                                               start))
            elif load_segs and not convert_onehot and not Y.shape[
                    -1] == 1:  # make sure we have a channels dim
                Y = Y[..., np.newaxis]

            if not return_ids:
                yield X, Y, contours
            else:
                yield X, Y, contours, batch_files