Beispiel #1
0
    def _generate_slices_batchs_from_vols(self, vols, segs, ids, vol_gen=None,
                                          randomize=False, batch_size=16,
                                          vol_batch_size=1,
                                          convert_onehot=False,
                                          ):
        """Infinitely yield batches of random slices taken from batches of volumes.

        :param vols: volumes to sample from; only used to build the default volume generator
        :param segs: segmentations matching ``vols``; only used for the default volume generator
        :param ids: subject ids matching ``vols``; only used for the default volume generator
        :param vol_gen: optional generator yielding ``(X, Y, ids_batch)`` volume batches.
            When None, one is created with ``batch_utils.gen_batch``.
        :param randomize: whether the default volume generator samples volumes randomly
        :param batch_size: NOTE(review): currently unused -- the number of slices drawn per
            batch is ``self.batch_size``. Kept for interface compatibility.
        :param vol_batch_size: number of volumes per batch of the default volume generator
        :param convert_onehot: convert the sliced segmentations to onehot using ``self.label_mapping``
        :return: yields ``(X_slices, Y, ids_batch)``; ``Y`` is None whenever the volume
            generator yields None for the segmentations
        """
        if vol_gen is None:
            vol_gen = batch_utils.gen_batch(vols, [segs, ids],
                                            randomize=randomize,
                                            batch_size=vol_batch_size)

        def _roll_slices_into_batch(vol, slice_idxs, out_shape):
            # assumes vol is (batch, h, w, z, c): keep the chosen z-slices, move the slice axis
            # next to the batch axis and merge both into a single batch axis
            return np.reshape(
                np.transpose(vol[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                (-1,) + tuple(out_shape))

        while True:
            X, Y, ids_batch = next(vol_gen)

            # slices are drawn (with replacement) along the second-to-last axis
            n_slices = X.shape[-2]
            slice_idxs = np.random.choice(n_slices, self.batch_size, replace=True)
            X_slices = _roll_slices_into_batch(X, slice_idxs, self.pred_img_shape)

            # the None check guards every use of Y, not just the onehot-warper branch
            if Y is not None:
                if self.arch_params['warpoh']:
                    # we used the onehot warper in these cases
                    n_seg_channels = self.n_labels
                else:
                    n_seg_channels = 1
                Y = _roll_slices_into_batch(
                    Y, slice_idxs, tuple(self.pred_img_shape[:-1]) + (n_seg_channels,))

                if convert_onehot:
                    Y = classification_utils.labels_to_onehot(Y, label_mapping=self.label_mapping)

            yield X_slices, Y, ids_batch
Beispiel #2
0
    def gen_slices_batch(self, files_list, batch_size=1, randomize=True,
                       convert_onehot=False,
                       label_mapping=None):
        """Infinitely yield batches of slices loaded from a list of volume files.

        Random mode: every batch loads the next file of a shuffled copy of ``files_list`` and
        draws ``batch_size`` random slices (with replacement) from it.

        Sequential mode: the slices of each file are visited in order, ``batch_size`` at a
        time (the last batch of a file may be smaller). Once a file is exhausted, the next
        file is loaded, and after the last file the generator starts over at the first one.

        :param files_list: files readable by ``_load_vol_and_seg``. Shuffling is done on a
            copy, so the caller's list is not modified.
        :param batch_size: number of slices per batch
        :param randomize: random vs. sequential mode (see above)
        :param convert_onehot: convert the label slices to onehot using ``label_mapping``
        :param label_mapping: label mapping passed to ``classification_utils.labels_to_onehot``
        :return: yields ``(X, Y)``; the loaded vol is indexed as 4D and the seg as 3D, and the
            sliced axis (2) is moved to the front of each
        """
        # work on a copy so shuffling never reorders the caller's list
        files_list = list(files_list)
        if randomize:
            np.random.shuffle(files_list)

        n_files = len(files_list)
        file_idx = 0
        slice_start = 0  # first slice of the next sequential batch
        while True:
            start = time.time()
            X, Y = _load_vol_and_seg(
                files_list[file_idx], do_mask_vol=True,
            )
            # TODO: add scaling here since we no longer do it per example

            if self.profiler_logger is not None:
                self.profiler_logger.info('Loading vol took {}'.format(time.time() - start))

            start = time.time()
            # slice the z-axis by default
            n_slices = X.shape[-2]
            if randomize:
                slice_idxs = np.random.choice(n_slices, batch_size, replace=True)
                # randomize the shuffled files every batch
                file_idx = (file_idx + 1) % n_files
            else:
                # if we are going through the dataset sequentially, the last batch of a file is smaller
                slice_stop = min(slice_start + batch_size, n_slices)
                slice_idxs = np.arange(slice_start, slice_stop)
                if slice_stop >= n_slices:
                    # reached the end of this file's slices: start the slices again on the next file
                    slice_start = 0
                    file_idx = (file_idx + 1) % n_files
                else:
                    slice_start = slice_stop

            X = np.transpose(X[:, :, slice_idxs], (2, 0, 1, 3))
            Y = np.transpose(Y[:, :, slice_idxs], (2, 0, 1))
            if self.profiler_logger is not None:
                self.profiler_logger.info('Slicing took {}'.format(time.time() - start))

            if convert_onehot:
                start = time.time()
                Y = classification_utils.labels_to_onehot(Y, label_mapping=label_mapping)
                if self.profiler_logger is not None:
                    self.profiler_logger.info('Converting slices onehot took {}'.format(time.time() - start))

            yield X, Y
Beispiel #3
0
    def _generate_augmented_batch(
            self,
            source_gen,
            use_single_atlas=False,
            aug_by=None,
            return_transforms=False,
            convert_onehot=False,
            do_slice_z=False,
    ):
        """Infinitely yield (source, augmented) training batches.

        Each iteration picks one augmentation scheme uniformly at random from ``aug_by``:

        * ``'rand'``: random flow from ``self.flow_rand_aug_model`` plus optional random
          additive (``offset_amp``) / multiplicative (``mult_amp``) intensity changes
        * ``'tm'``: flow and color transforms predicted towards volumes sampled from
          ``self.unlabeled_gen_raw``
        * anything else: no augmentation (the source batch is passed through)

        :param source_gen: generator yielding ``(vols, segs, contours, ids)`` of labeled atlases
        :param use_single_atlas: if True, draw one atlas up front and reuse it for every batch
        :param aug_by: list of scheme names (a single name is also accepted);
            None or empty means no augmentation
        :param return_transforms: also yield the flow field and color delta (None when not produced)
        :param convert_onehot: convert label segmentations to onehot, unless the onehot
            warper already produced onehot segmentations
        :param do_slice_z: roll random z-slices (``self.batch_size`` of them) into the batch axis
        :return: yields ``(X_src, Y_src, flow, color_delta, X_aug, Y_aug, ul_ids)`` when
            ``return_transforms`` is set, else ``(X_src, Y_src, X_aug, Y_aug, ul_ids)``
        """
        # accept a single scheme name, and treat None/empty as "no augmentation"
        if isinstance(aug_by, str):
            aug_by = [aug_by]
        if not aug_by:
            aug_by = [None]

        def _roll_z_into_batch(vol, slice_idxs, out_shape):
            # assumes vol is (batch, h, w, z, c): keep the chosen z-slices, move the slice axis
            # next to the batch axis and merge both into a single batch axis
            return np.reshape(
                np.transpose(vol[:, :, :, slice_idxs], (0, 3, 1, 2, 4)),
                (-1,) + tuple(out_shape))

        if use_single_atlas:
            # single atlas, dont bother with generator
            self.logger.debug('Single atlas, not using source generator for augmenter')
            source_X, source_segs, source_contours, source_ids = next(source_gen)
        else:
            self.logger.debug('Multiple atlases, sampling source vols from generator')

        while True:
            if not use_single_atlas:
                # randomly select an atlas from our source volume generator
                source_X, source_segs, source_contours, source_ids = next(source_gen)

            # keep track of which unlabeled subjects we are using in training
            ul_ids = []

            if len(aug_by) > 1:
                # multiple augmentation schemes: pick one uniformly at random
                aug_batch_by = aug_by[np.random.randint(len(aug_by))]
            else:
                aug_batch_by = aug_by[0]

            # the onehot warper is only used by the 'rand'/'tm' schemes; its segs are already onehot
            used_onehot_warper = aug_batch_by in ('rand', 'tm') and self.arch_params['warpoh']

            # per-batch outputs; reset every iteration so nothing leaks in from a previous batch
            flow = None
            color_delta = None
            self.aug_target = None
            self.aug_colored = None

            start = time.time()
            if aug_batch_by == 'rand':
                X_aug, flow = self.flow_rand_aug_model.predict(source_X)

                # color augmentation by additive or multiplicative factor: one random factor per
                # example, broadcast over all non-batch axes
                aug_params = self.data_params['aug_params']
                factor_shape = (X_aug.shape[0],) + (1,) * (X_aug.ndim - 1)
                if 'offset_amp' in aug_params:
                    X_aug += (np.random.rand(*factor_shape) * 2. - 1.) * aug_params['offset_amp']
                    X_aug = np.clip(X_aug, 0., 1.)
                if 'mult_amp' in aug_params:
                    X_aug *= 1. + (np.random.rand(*factor_shape) * 2. - 1.) * aug_params['mult_amp']
                    X_aug = np.clip(X_aug, 0., 1.)

                Y_aug = self.seg_warp_model.predict([source_segs, flow])
            elif aug_batch_by == 'tm':
                st = time.time()
                if self.arch_params['do_coupled_sampling']:
                    # use the same target for flow and color
                    X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw)
                    X_colortgt = X_flowtgt
                    ul_ids += ul_flow_ids
                else:
                    X_flowtgt, _, _, ul_flow_ids = next(self.unlabeled_gen_raw)
                    X_colortgt, _, _, ul_color_ids = next(self.unlabeled_gen_raw)
                    ul_ids += ul_flow_ids + ul_color_ids
                if self.do_profile:
                    self.profiler_logger.info('Sampling aug tgt took {}'.format(time.time() - st))

                self.aug_target = X_flowtgt

                st = time.time()
                # presumably warps the color target onto the source before predicting the color change
                X_colortgt_src, _ = self.flow_bck_aug_model.predict([X_colortgt, source_X])
                color_delta, colored_vol, _ = self.color_aug_model.predict(
                    [source_X, X_colortgt_src, source_contours])
                self.aug_colored = colored_vol

                _, flow = self.flow_aug_model.predict([source_X, X_flowtgt])
                X_aug = self.vol_warp_model.predict([colored_vol, flow])

                if self.do_profile:
                    self.profiler_logger.info('Running color and flow aug took {}'.format(time.time() - st))

                st = time.time()
                # now warp the labels to match
                Y_aug = self.seg_warp_model.predict([source_segs, flow])
                if self.do_profile:
                    self.profiler_logger.info('Warping labels took {}'.format(time.time() - st))
            else:
                # no aug
                X_aug = source_X
                Y_aug = source_segs

            if self.do_profile:
                self.profiler_logger.info('Augmenting input batch took {}'.format(time.time() - start))

            if do_slice_z:
                start = time.time()
                # always take random slices in the z dimension
                slice_idxs = np.random.choice(source_X.shape[-2], self.batch_size, replace=True)

                X_to_aug_slices = _roll_z_into_batch(source_X, slice_idxs, self.pred_img_shape)
                X_aug = _roll_z_into_batch(X_aug, slice_idxs, self.pred_img_shape)

                if self.aug_target is not None:
                    self.aug_target = _roll_z_into_batch(self.aug_target, slice_idxs, self.pred_img_shape)
                if self.aug_colored is not None:
                    # not relevant if we arent using a color transform model
                    self.aug_colored = _roll_z_into_batch(self.aug_colored, slice_idxs, self.pred_img_shape)

                seg_shape = self.pred_segs_oh_shape if used_onehot_warper else self.pred_segs_shape
                Y_to_aug_slices = _roll_z_into_batch(source_segs, slice_idxs, seg_shape)
                Y_aug = _roll_z_into_batch(Y_aug, slice_idxs, seg_shape)

                if self.do_profile:
                    self.profiler_logger.info('Slicing batch took {}'.format(time.time() - start))

                if return_transforms and flow is not None:
                    st = time.time()
                    flow = _roll_z_into_batch(
                        flow, slice_idxs, tuple(self.pred_img_shape[:-1]) + (self.n_aug_dims,))
                    if self.do_profile:
                        self.profiler_logger.info('Slicing flow took {}'.format(time.time() - st))

                if color_delta is not None:
                    color_delta = _roll_z_into_batch(color_delta, slice_idxs, self.pred_img_shape)
            else:
                # no slicing: pass the full source volumes through
                X_to_aug_slices = source_X
                Y_to_aug_slices = source_segs

            if convert_onehot and not used_onehot_warper:
                # if we don't have onehot segs already, convert them after slicing
                start = time.time()
                Y_to_aug_slices = classification_utils.labels_to_onehot(
                    Y_to_aug_slices, label_mapping=self.label_mapping)
                Y_aug = classification_utils.labels_to_onehot(
                    Y_aug, label_mapping=self.label_mapping)
                if self.do_profile:
                    self.profiler_logger.info('Converting onehot took {}'.format(time.time() - start))

            if return_transforms:
                yield X_to_aug_slices, Y_to_aug_slices, flow, color_delta, X_aug, Y_aug, ul_ids
            else:
                yield X_to_aug_slices, Y_to_aug_slices, X_aug, Y_aug, ul_ids
Beispiel #4
0
    def create_generators(self, batch_size):
        """Create the training, augmentation and validation generators.

        Sets ``self.batch_size``, ``self.source_gen``, ``self.unlabeled_gen_raw``,
        ``self.aug_train_gen`` (None unless per-batch augmentation is used),
        ``self.train_gen`` and ``self.eval_valid_gen``. It also builds a fixed batch of random
        validation slices for display in ``self.X_valid_batch`` / ``self.Y_valid_batch``.

        :param batch_size: number of slices per training batch, also used as the size of
            the validation display batch
        """
        self.batch_size = batch_size

        # generator for labeled training examples that we wish to augment
        self.source_gen = self.dataset.gen_vols_batch(
            dataset_splits=['labeled_train'], batch_size=1,
            load_segs=True,  # always load segmentations since this is our labeled set
            load_contours=self.aug_tm,  # only load contours if we need them as aux data for our appearance model
            randomize=True,
            return_ids=True,
        )

        # actually more like a target generator
        self.unlabeled_gen_raw = self.dataset.gen_vols_batch(
            dataset_splits=['labeled_train', 'unlabeled_train'], batch_size=1,
            load_segs=False, load_contours=False,
            randomize=True,
            return_ids=True,
        )

        self._create_augmentation_models()

        # simply append augmented examples to training set
        # NOTE: we can only do this if we are not synthesizing that many examples (i.e. <= 1000)
        self.aug_train_gen = None
        if self.n_aug is not None:
            self._create_augmented_examples()
            self.logger.debug('Augmented classifier training set: vols {}, segs {}'.format(
                self.X_labeled_train.shape, self.segs_labeled_train.shape))
        elif self.aug_tm or self.aug_rand:
            # augmentation by transform model or random flow+intensity requires synthesizing
            # a lot of examples, so we'll just do this per-batch
            aug_by = [scheme for scheme, enabled in (('tm', self.aug_tm), ('rand', self.aug_rand))
                      if enabled]

            # these need to be done in the generator
            self.aug_train_gen = self._generate_augmented_batch(
                source_gen=self.source_gen, use_single_atlas=self.X_labeled_train.shape[0] == 1,
                aug_by=aug_by,
                convert_onehot=True, return_transforms=True,
                do_slice_z=(self.n_pred_dims == 2)  # if we are predicting on slices, then get slices
            )

        # generates slices from the training volumes
        self.train_gen = self._generate_slices_batchs_from_vols(
            self.X_labeled_train, self.segs_labeled_train, self.ids_labeled_train,
            vol_gen=None,
            convert_onehot=True,
            batch_size=self.batch_size, randomize=True
        )

        # load each subject, then evaluate each slice
        self.eval_valid_gen = self.dataset.gen_vols_batch(
            ['labeled_valid'],
            batch_size=1, randomize=False,
            convert_onehot=False,
            load_segs=True,
            label_mapping=self.label_mapping,
            return_ids=True,
        )

        # we will compute validation losses on all validation volumes
        # just pick some random slices to display for the validation set later
        rand_subjs = np.random.choice(self.X_labeled_valid.shape[0], batch_size)
        n_slices = self.aug_img_shape[2]
        # distinct slices when there are enough of them; sample with replacement instead of
        # raising when the batch is larger than the number of slices
        rand_slices = np.random.choice(n_slices, batch_size, replace=batch_size > n_slices)

        self.X_valid_batch = np.zeros((batch_size,) + self.pred_img_shape)
        self.Y_valid_batch = np.zeros((batch_size,) + self.pred_segs_shape)
        # index one (subject, slice) pair at a time; this also works for array-likes that do
        # not support paired fancy indexing
        for ei, (subj_idx, slice_idx) in enumerate(zip(rand_subjs, rand_slices)):
            self.X_valid_batch[ei] = self.X_labeled_valid[subj_idx, :, :, slice_idx]
            self.Y_valid_batch[ei] = self.segs_labeled_valid[subj_idx, :, :, slice_idx]

        self.Y_valid_batch = classification_utils.labels_to_onehot(self.Y_valid_batch, label_mapping=self.label_mapping)
Beispiel #5
0
def gen_batch(ims_data, labels_data,
              batch_size, randomize=False,
              pad_or_crop_to_size=None, normalize_tanh=False,
              convert_onehot=False, labels_to_onehot_mapping=None,
              aug_model=None, aug_params=None,
              yield_aug_params=False, yield_idxs=False,
              random_seed=None):
    '''

    :param ims_data: list of images, or an image.
    If a single image, it will be automatically converted to a list

    :param labels_data: list of other data (e.g. labels) that do not require
    image normalization or augmentation, but might need to be converted to onehot

    :param batch_size:
    :param randomize: bool to randomize indices per batch

    :param pad_or_crop_to_size: pad or crop each image in ims_data to the specified size. Default pad value is 0
    :param normalize_tanh: normalize image to range [-1, 1], good for synthesis with a tanh activation
    :param aug_params:

    :param convert_onehot: convert labels to a onehot representation using the mapping below
    :param labels_to_onehot_mapping: list of labels e.g. [0, 3, 5] indicating the mapping of label values to channel indices

    :param yield_aug_params: include the random augmentation params used on the batch in the return values
    :param yield_idxs: include the indices that comprise this batch in the return values
    :param random_seed:
    :return:
    '''
    if random_seed:
        np.random.seed(random_seed)

    # make sure everything is a list
    if not isinstance(ims_data, list):
        ims_data = [ims_data]

    if not isinstance(normalize_tanh, list):
        normalize_tanh = [normalize_tanh] * len(ims_data)
    else:
        assert len(normalize_tanh) == len(ims_data)

    if aug_params is not None:
        if not isinstance(aug_params, list):
            aug_params = [aug_params] * len(ims_data)
        else:
            assert len(aug_params) == len(ims_data)
        out_aug_params = aug_params[:]

    if pad_or_crop_to_size is not None:
        if not isinstance(pad_or_crop_to_size, list):
            pad_or_crop_to_size = [pad_or_crop_to_size] * len(ims_data)
        else:
            assert len(pad_or_crop_to_size) == len(ims_data)


    # if we have labels that we want to generate from,
    # put everything into a list for consistency
    # (useful if we have labels and aux data)
    if labels_data is not None:
        if not isinstance(labels_data, list):
            labels_data = [labels_data]

        # each entry should correspond to an entry in labels_data
        if not isinstance(convert_onehot, list):
            convert_onehot = [convert_onehot] * len(labels_data)
        else:
            assert len(convert_onehot) == len(labels_data)


    idxs = [-1]

    n_ims = ims_data[0].shape[0]
    h = ims_data[0].shape[1]
    w = ims_data[0].shape[2]

    if pad_or_crop_to_size is not None:
        # pad each image and then re-concatenate
        ims_data = [np.concatenate([
            image_utils.pad_or_crop_to_shape(x, pad_or_crop_to_size)[np.newaxis]
            for x in im_data], axis=0) for im_data in ims_data]

    while True:
        if randomize:
            idxs = np.random.choice(n_ims, batch_size, replace=True)
        else:
            idxs = np.linspace(idxs[-1] + 1, idxs[-1] + 1 + batch_size - 1, batch_size, dtype=int)
            restart_idxs = False
            while np.any(idxs >= n_ims):
                idxs[np.where(idxs >= n_ims)] = idxs[np.where(idxs >= n_ims)] - n_ims
                restart_idxs = True

        ims_batches = []
        for i, im_data in enumerate(ims_data):
            X_batch = im_data[idxs]

            if not X_batch.dtype == np.float32 and not X_batch.dtype == np.float64:
                X_batch = X_batch.astype(np.float32) / 255.

            if normalize_tanh[i]:
                X_batch = image_utils.normalize(X_batch)

            if aug_params is not None and aug_params[i] is not None:
                if aug_model is not None:
                    # use the gpu aug model instead
                    T, _ = aug_utils.aug_params_to_transform_matrices(
                        batch_size=X_batch.shape[0], add_last_row=True,
                        **aug_params[i]
                    )
                    X_batch = aug_model.predict([X_batch, T])
                    out_aug_params[i] = T
                else:
                    X_batch, out_aug_params[i] = aug_utils.aug_im_batch(X_batch, **aug_params[i])
            ims_batches.append(X_batch)

        if labels_data is not None:
            labels_batches = []
            for li, Y in enumerate(labels_data):
                if Y is None:
                    Y_batch = None
                else:
                    if convert_onehot[li]:
                        Y_batch = classification_utils.labels_to_onehot(
                            Y[idxs],
                            label_mapping=labels_to_onehot_mapping)
                    else:
                        if isinstance(Y, np.ndarray):
                            Y_batch = Y[idxs]
                        else: # in case it's a list
                            Y_batch = [Y[idx] for idx in idxs]
                labels_batches.append(Y_batch)
        else:
            labels_batches = None

        if not randomize and restart_idxs:
            idxs[-1] = -1

        if yield_aug_params and yield_idxs:
            yield tuple(ims_batches) +  tuple(labels_batches) + (out_aug_params, idxs)
        elif yield_aug_params:
            yield tuple(ims_batches) + tuple(labels_batches) + (out_aug_params, )
        elif yield_idxs:
            yield tuple(ims_batches) + tuple(labels_batches) + (idxs, )
        else:
            yield tuple(ims_batches) + tuple(labels_batches)
Beispiel #6
0
def gen_batch(ims_data,
              labels_data,
              batch_size,
              randomize=False,
              pad_or_crop_to_size=None,
              normalize_tanh=False,
              convert_onehot=False,
              labels_to_onehot_mapping=None,
              aug_model=None,
              aug_params=None,
              yield_aug_params=False,
              yield_idxs=False,
              random_seed=None):
    '''

    :param ims_data: list of images, or an image.
    If a single image, it will be automatically converted to a list

    :param labels_data: list of other data (e.g. labels) that do not require
    image normalization or augmentation, but might need to be converted to onehot

    :param batch_size:
    :param randomize: bool to randomize indices per batch

    :param pad_or_crop_to_size: pad or crop each image in ims_data to the specified size. Default pad value is 0
    :param normalize_tanh: normalize image to range [-1, 1], good for synthesis with a tanh activation
    :param aug_params:

    :param convert_onehot: convert labels to a onehot representation using the mapping below
    :param labels_to_onehot_mapping: list of labels e.g. [0, 3, 5] indicating the mapping of label values to channel indices

    :param yield_aug_params: include the random augmentation params used on the batch in the return values
    :param yield_idxs: include the indices that comprise this batch in the return values
    :param random_seed:
    :return:
    '''
    if random_seed:
        np.random.seed(random_seed)

    # make sure everything is a list
    if not isinstance(ims_data, list):
        ims_data = [ims_data]

    if not isinstance(normalize_tanh, list):
        normalize_tanh = [normalize_tanh] * len(ims_data)
    else:
        assert len(normalize_tanh) == len(ims_data)

    if aug_params is not None:
        if not isinstance(aug_params, list):
            aug_params = [aug_params] * len(ims_data)
        else:
            assert len(aug_params) == len(ims_data)
        out_aug_params = aug_params[:]

    if pad_or_crop_to_size is not None:
        if not isinstance(pad_or_crop_to_size, list):
            pad_or_crop_to_size = [pad_or_crop_to_size] * len(ims_data)
        else:
            assert len(pad_or_crop_to_size) == len(ims_data)

    # if we have labels that we want to generate from,
    # put everything into a list for consistency
    # (useful if we have labels and aux data)
    if labels_data is not None:
        if not isinstance(labels_data, list):
            labels_data = [labels_data]

        # each entry should correspond to an entry in labels_data
        if not isinstance(convert_onehot, list):
            convert_onehot = [convert_onehot] * len(labels_data)
        else:
            assert len(convert_onehot) == len(labels_data)

    idxs = [-1]

    n_ims = ims_data[0].shape[0]
    h = ims_data[0].shape[1]
    w = ims_data[0].shape[2]

    if pad_or_crop_to_size is not None:
        # pad each image and then re-concatenate
        ims_data = [
            np.concatenate([
                image_utils.pad_or_crop_to_shape(
                    x, pad_or_crop_to_size)[np.newaxis] for x in im_data
            ],
                           axis=0) for im_data in ims_data
        ]

    while True:
        if randomize:
            idxs = np.random.choice(n_ims, batch_size, replace=True)
        else:
            idxs = np.linspace(idxs[-1] + 1,
                               idxs[-1] + 1 + batch_size - 1,
                               batch_size,
                               dtype=int)
            restart_idxs = False
            while np.any(idxs >= n_ims):
                idxs[np.where(
                    idxs >= n_ims)] = idxs[np.where(idxs >= n_ims)] - n_ims
                restart_idxs = True

        ims_batches = []
        for i, im_data in enumerate(ims_data):
            X_batch = im_data[idxs]

            if not X_batch.dtype == np.float32 and not X_batch.dtype == np.float64:
                X_batch = X_batch.astype(np.float32) / 255.

            if normalize_tanh[i]:
                X_batch = image_utils.normalize(X_batch)

            if aug_params is not None and aug_params[i] is not None:
                if aug_model is not None:
                    # use the gpu aug model instead
                    T, _ = aug_utils.aug_params_to_transform_matrices(
                        batch_size=X_batch.shape[0],
                        add_last_row=True,
                        **aug_params[i])
                    X_batch = aug_model.predict([X_batch, T])
                    out_aug_params[i] = T
                else:
                    X_batch, out_aug_params[i] = aug_utils.aug_im_batch(
                        X_batch, **aug_params[i])
            ims_batches.append(X_batch)

        if labels_data is not None:
            labels_batches = []
            for li, Y in enumerate(labels_data):
                if Y is None:
                    Y_batch = None
                else:
                    if convert_onehot[li]:
                        Y_batch = classification_utils.labels_to_onehot(
                            Y[idxs], label_mapping=labels_to_onehot_mapping)
                    else:
                        if isinstance(Y, np.ndarray):
                            Y_batch = Y[idxs]
                        else:  # in case it's a list
                            Y_batch = [Y[idx] for idx in idxs]
                labels_batches.append(Y_batch)
        else:
            labels_batches = None

        if not randomize and restart_idxs:
            idxs[-1] = -1

        if yield_aug_params and yield_idxs:
            yield tuple(ims_batches) + tuple(labels_batches) + (out_aug_params,
                                                                idxs)
        elif yield_aug_params:
            yield tuple(ims_batches) + tuple(labels_batches) + (
                out_aug_params, )
        elif yield_idxs:
            yield tuple(ims_batches) + tuple(labels_batches) + (idxs, )
        elif labels_data is not None:
            yield tuple(ims_batches) + tuple(labels_batches)
        else:
            yield tuple(ims_batches) + (None, )
Beispiel #7
0
def eval_seg_sas_from_gen(sas_model,
                          atlas_vol,
                          atlas_labels,
                          eval_gen,
                          label_mapping,
                          n_eval_examples,
                          batch_size,
                          logger=None):
    '''
    Evaluates a single-atlas segmentation method on a bunch of evaluation volumes.

    For every evaluation volume, sas_model predicts a flow from the atlas to that volume.
    The atlas labels are warped with that flow (nearest interpolation) and scored against
    the ground truth.

    :param sas_model: spatial transform model used for SAS. Can be voxelmorph.
        Called as sas_model.predict([atlas_vol, X]); its second output is used as the flow.
    :param atlas_vol: atlas volume, with a leading batch axis
    :param atlas_labels: atlas segmentations; a trailing channel axis is added before warping
    :param eval_gen: generator that yields (vols_valid, segs_valid, _, ids) batches
    :param label_mapping: list of label ids that will appear in segs, ordered by how they map to channels
    :param n_eval_examples: total number of examples to evaluate
    :param batch_size: unused; kept for interface compatibility
    :param logger: python logger if we want to log messages; messages are printed otherwise
    :return: (cces, dice_per_label, accs, all_ids). These are the per-subject cross-entropy,
        the per-subject per-label dice of shape (n_eval_examples, len(label_mapping)), the
        per-subject accuracy over non-background voxels (nan if a subject has none), and
        the evaluated ids.
    '''
    # send every message to the logger when one is given, otherwise print it
    log = logger.debug if logger is not None else print

    img_shape = atlas_vol.shape[1:]

    seg_warp_model = networks.warp_model(
        img_shape=img_shape,
        interp_mode='nearest',
        indexing='xy',
    )

    from keras.models import Model
    from keras.layers import Input, Activation
    from keras.optimizers import Adam
    n_labels = len(label_mapping)

    # softmax + categorical cross-entropy model, only used to score the warped onehot labels
    warped_in = Input(img_shape[0:-1] + (n_labels, ))
    warped = Activation('softmax')(warped_in)

    ce_model = Model(inputs=[warped_in], outputs=[warped], name='ce_model')
    ce_model.compile(loss='categorical_crossentropy', optimizer=Adam(0.0001))

    # test metrics: categorical cross-entropy and dice
    dice_per_label = np.zeros((n_eval_examples, n_labels))
    cces = np.zeros((n_eval_examples, ))
    accs = np.zeros((n_eval_examples, ))
    all_ids = []
    for bi in range(n_eval_examples):
        log('Testing on subject {} of {}'.format(bi, n_eval_examples))
        X, Y, _, ids = next(eval_gen)
        Y_oh = classification_utils.labels_to_onehot(
            Y, label_mapping=label_mapping)

        _, warp = sas_model.predict([atlas_vol, X])

        # drop a singleton channel axis from the ground truth
        if Y.shape[-1] == 1:
            Y = Y[..., 0]
        # warp the atlas labels according to the predicted flow field, then get rid of channels
        preds_batch = seg_warp_model.predict(
            [atlas_labels[..., np.newaxis], warp])[..., 0]
        preds_oh = classification_utils.labels_to_onehot(
            preds_batch, label_mapping=label_mapping)

        cce = np.mean(ce_model.evaluate(preds_oh, Y_oh, verbose=False))
        subject_dice_per_label = medipy_metrics.dice(Y,
                                                     preds_batch,
                                                     labels=label_mapping)

        # only consider voxels whose gt label is not bkg (counting bkg would inflate accuracy)
        nonbkgmap = (Y > 0)
        n_nonbkg = np.sum(nonbkgmap)
        if n_nonbkg > 0:
            acc = np.sum((Y == preds_batch) & nonbkgmap) / float(n_nonbkg)
        else:
            # no foreground voxels: accuracy is undefined for this subject
            acc = np.nan
        log('Accuracy (no bkg): {}'.format(acc))

        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    log('Dice per label: {}, {}'.format(label_mapping, dice_per_label))
    log('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
    log('Mean CE: {}'.format(np.mean(cces)))
    log('Mean accuracy: {}'.format(np.mean(accs)))
    return cces, dice_per_label, accs, all_ids
Beispiel #8
0
def eval_seg_from_gen(segmenter_model,
                      eval_gen,
                      label_mapping,
                      n_eval_examples,
                      batch_size,
                      logger=None):
    '''
    Evaluates a segmentation CNN on examples drawn from a generator.

    :param segmenter_model: keras model for segmenter
    :param eval_gen: generator that yields (vols, segs, <unused>, ids) tuples;
        it is advanced once per evaluated example
    :param label_mapping: list of label ids, ordered by how they map to channels
    :param n_eval_examples: total number of volumes to evaluate
    :param batch_size: batch size (number of slices per batch)
    :param logger: python logger (optional); if None, messages are printed
    :return: (cces, dice_per_label, accs, all_ids), where cces and accs have
        shape (n_eval_examples,), dice_per_label has shape
        (n_eval_examples, len(label_mapping)), and all_ids is a list extended
        with the ids yielded by eval_gen
    '''
    # route every message through the logger when one is given, so that nothing
    # is written to stdout in that case
    log = logger.debug if logger is not None else print

    # test metrics: categorical cross-entropy, dice and accuracy
    dice_per_label = np.zeros((n_eval_examples, len(label_mapping)))
    cces = np.zeros((n_eval_examples, ))
    accs = np.zeros((n_eval_examples, ))
    all_ids = []
    for bi in range(n_eval_examples):
        log('Testing on subject {} of {}'.format(bi, n_eval_examples))
        X, Y, _, ids = next(eval_gen)
        Y_oh = classification_utils.labels_to_onehot(
            Y, label_mapping=label_mapping)
        preds_batch, cce = segment_vol_by_slice(
            segmenter_model,
            X,
            label_mapping=label_mapping,
            batch_size=batch_size,
            Y_oh=Y_oh,
            compute_cce=True,
        )
        subject_dice_per_label = medipy_metrics.dice(Y,
                                                     preds_batch,
                                                     labels=label_mapping)

        # only consider pixels where the gt label is not bkg (if we count bkg, accuracy will be very high)
        nonbkgmap = (Y > 0)
        n_nonbkg = np.sum(nonbkgmap)
        if n_nonbkg > 0:
            acc = np.sum(((Y == preds_batch) *
                          nonbkgmap).astype(int)) / float(n_nonbkg)
        else:
            # no foreground voxels: accuracy is undefined (avoids a 0/0 warning)
            acc = np.nan

        log('Accuracy: {}'.format(acc))
        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    log('Dice per label: {}, {}'.format(
        label_mapping,
        np.mean(dice_per_label, axis=0).tolist()))
    # column 0 is treated as the background label here
    log('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
    log('Mean CE: {}'.format(np.mean(cces)))
    log('Mean accuracy: {}'.format(np.mean(accs)))
    return cces, dice_per_label, accs, all_ids
Beispiel #9
0
    def gen_vols_batch(self,
                       dataset_splits=None,
                       batch_size=1,
                       randomize=True,
                       load_segs=False,
                       load_contours=False,
                       convert_onehot=False,
                       label_mapping=None,
                       return_ids=False):
        '''
        Infinite generator of volume batches drawn from one or more dataset splits.

        If every file in the requested splits already has a loaded volume, batches
        are sliced from the in-memory arrays; otherwise each volume is loaded from
        its file with _load_vol_and_seg while the batch is assembled.

        :param dataset_splits: split name or list of split names; recognized names
            are 'labeled_train', 'labeled_valid', 'unlabeled_train' and
            'labeled_test' (other names are ignored). Defaults to ['labeled_train'].
        :param batch_size: number of volumes per batch
        :param randomize: if True, sample file indices uniformly with replacement;
            otherwise walk through the files in order, wrapping around at the end
        :param load_segs: whether to also yield segmentations
        :param load_contours: whether to also yield contours
        :param convert_onehot: if True (and load_segs), convert segs to onehot
            using label_mapping
        :param label_mapping: label ids used for the onehot conversion
        :param return_ids: if True, also yield the file names of the batch
        :return: yields (X, Y, contours) or (X, Y, contours, batch_files).
            Y is None unless load_segs; contours is None unless load_contours.
            When segs are not converted to onehot, Y gets a trailing channel axis
            if its last axis is not already of size 1.
        '''
        if dataset_splits is None:
            dataset_splits = ['labeled_train']
        elif not isinstance(dataset_splits, list):
            dataset_splits = [dataset_splits]

        X_all = []
        Y_all = []
        contours_all = []
        files_list = []
        for ds in dataset_splits:
            if ds not in ('labeled_train', 'labeled_valid',
                          'unlabeled_train', 'labeled_test'):
                continue

            if ds == 'labeled_test':
                if self.logger is not None:
                    self.logger.debug('LOOKING FOR FINAL TEST SET')
                else:
                    print('LOOKING FOR FINAL TEST SET')

            # each split is stored in attributes named
            # vols_<split>, segs_<split>, contours_<split> and files_<split>
            X_all.append(getattr(self, 'vols_' + ds))
            Y_all.append(getattr(self, 'segs_' + ds))
            contours_all.append(getattr(self, 'contours_' + ds))
            files_list += getattr(self, 'files_' + ds)

        n_files = len(files_list)
        n_loaded_vols = np.sum([x.shape[0] for x in X_all])
        # if all of the vols are loaded, we can sample from vols instead of loading from file
        load_from_files = n_loaded_vols != n_files
        if not load_from_files:
            X_all = np.concatenate(X_all, axis=0)
            if load_segs and len(Y_all) > 0:
                Y_all = np.concatenate(Y_all, axis=0)
            else:
                Y_all = None

            if load_contours and len(contours_all) > 0:
                contours_all = np.concatenate(contours_all, axis=0)
            else:
                contours_all = None

        if load_from_files:
            self._print('Sampling size {} batches from {} files!'.format(
                batch_size, n_files))
        else:
            self._print('Sampling size {} batches from {} volumes!'.format(
                batch_size, n_files))

        if randomize:
            idxs = np.random.choice(n_files, batch_size, replace=True)
        else:
            idxs = np.linspace(0,
                               min(n_files, batch_size),
                               batch_size,
                               endpoint=False,
                               dtype=int)

        while True:
            start = time.time()

            if not load_from_files:
                # if vols are pre-loaded, simply sample them
                X = X_all[idxs]

                if load_segs and Y_all is not None:
                    Y = Y_all[idxs]
                else:
                    Y = None

                if load_contours and contours_all is not None:
                    contours = contours_all[idxs]
                else:
                    contours = None

                batch_files = [files_list[i] for i in idxs]
            else:
                X = [None] * batch_size
                Y = [None] * batch_size if load_segs else None
                contours = [None] * batch_size if load_contours else None

                batch_files = []
                # load from files as we go
                for i, idx in enumerate(idxs.tolist()):
                    x, y, curr_contours = _load_vol_and_seg(
                        files_list[idx],
                        load_seg=load_segs,
                        load_contours=load_contours,
                        do_mask_vol=True,
                        keep_labels=self.label_mapping,
                    )
                    batch_files.append(files_list[idx])
                    X[i] = x[np.newaxis, ...]

                    if load_segs:
                        Y[i] = y[np.newaxis, ...]

                    if load_contours:
                        contours[i] = curr_contours[np.newaxis]

            if self.profiler_logger is not None:
                self.profiler_logger.info(
                    'Loading vol took {}'.format(time.time() - start))

            # if we loaded these as lists, turn them into ndarrays
            if isinstance(X, list):
                X = np.concatenate(X, axis=0)

            if load_segs and isinstance(Y, list):
                Y = np.concatenate(Y, axis=0)

            if load_contours and isinstance(contours, list):
                contours = np.concatenate(contours, axis=0)

            # pick idxs for the next batch
            if randomize:
                idxs = np.random.choice(n_files, batch_size, replace=True)
            else:
                # advance sequentially and wrap around the end of the file list;
                # a modulo (rather than a single subtraction) keeps every index
                # valid even when batch_size > n_files
                idxs = (idxs + batch_size) % n_files

            if load_segs and convert_onehot:
                start = time.time()
                Y = classification_utils.labels_to_onehot(
                    Y, label_mapping=label_mapping)
                if self.profiler_logger is not None:
                    self.profiler_logger.info(
                        'Converting vol onehot took {}'.format(time.time() -
                                                               start))
            elif load_segs and not convert_onehot and not Y.shape[-1] == 1:
                # make sure we have a channels dim
                Y = Y[..., np.newaxis]

            if not return_ids:
                yield X, Y, contours
            else:
                yield X, Y, contours, batch_files
Beispiel #10
0
    def _generate_source_target_pairs(self, batch_size, source_vol_gen=None, target_vol_gen=None, return_ids=False):
        '''
        Infinite generator of (inputs, targets) pairs built from source and target volumes.

        With a single loaded source volume (self.X_source_train.shape[0] == 1),
        that volume (with its segs, contours and id) is reused for every batch;
        with more than one, a source batch is drawn from source_vol_gen on every
        iteration. A target batch is drawn from target_vol_gen on every iteration.

        :param batch_size: not used by this method
        :param source_vol_gen: generator yielding (vols, segs, contours, ids);
            only read when more than one source volume is loaded
        :param target_vol_gen: generator yielding (vols, _, _, ids)
        :param return_ids: if True, also yield the source and target ids
        :return: yields (inputs, targets) or (inputs, targets, id_source, id_target).
            inputs is [X_source, X_target_srcspace], followed by the source aux
            inputs when arch_params['do_aux_reg'] is not None, and by the forward
            flow when recon_loss_name == 'l2-tgt'. targets is
            [X_target, X_source, X_target, X_source] for 'bidir' architectures,
            otherwise [X_target] * 3.
        '''
        def _make_source_aux_inputs(segs, contours):
            # builds the auxiliary source input requested by arch_params['do_aux_reg']:
            # onehot segs ('segs_oh') or label segs ('segs'), optionally with
            # contours ('contours') concatenated along the last axis
            do_aux_reg = self.arch_params['do_aux_reg']
            if 'segs_oh' in do_aux_reg:
                # first channel will be segs in label form
                aux_inputs = classification_utils.labels_to_onehot(
                    segs[..., [0]], label_mapping=self.label_mapping)
            elif 'segs' in do_aux_reg:
                aux_inputs = segs
            else:
                aux_inputs = None

            if 'contours' in do_aux_reg:
                if aux_inputs is not None:
                    aux_inputs = np.concatenate([aux_inputs, contours], axis=-1)
                else:
                    aux_inputs = contours
            return aux_inputs

        if self.X_source_train.shape[0] == 1:
            # single atlas, no need to sample from generator
            X_source = self.X_source_train
            segs_source = self.segs_source_train
            contours_source = self.contours_source_train
            id_source = [self.source_train_files[0]]

            # create source aux input here in case we need it later
            if self.arch_params['do_aux_reg'] is not None:
                source_aux_inputs = _make_source_aux_inputs(segs_source, contours_source)

        while True:
            if self.X_source_train.shape[0] > 1:
                X_source, segs_source, contours_source, id_source = next(source_vol_gen)
                # create source aux input here in case we need it later
                if self.arch_params['do_aux_reg'] is not None:
                    source_aux_inputs = _make_source_aux_inputs(segs_source, contours_source)

            X_target, _, _, id_target = next(target_vol_gen)

            # unless it is back-warped below, the target is passed to the model as-is
            # (previously this name was left unbound in that case, raising
            # UnboundLocalError). NOTE(review): assumes targets are then already in
            # source space or the architecture takes them unwarped -- confirm.
            X_target_srcspace = X_target
            if 'color' in self.arch_params['model_arch']:
                if (self.X_source_train.shape[0] > 1 and self.recon_loss_name == 'l2-src') \
                        or self.recon_loss_name == 'l2-tgt':
                    # more than one atlas, so we need to back-warp depending on our atlas
                    # OR, if we are computing reconstruction loss in the target space,
                    # we still need to give the color model the src-space target
                    X_target_srcspace = self.flow_bck_model.predict([X_target, X_source])[0]

            if self.arch_params['do_aux_reg'] is not None:
                inputs = [X_source, X_target_srcspace, source_aux_inputs]
            else:
                inputs = [X_source, X_target_srcspace]

            if self.recon_loss_name == 'l2-tgt':
                _, flow_batch = self.flow_fwd_model.predict([X_source, X_target])
                # reconstruction loss in the target space
                inputs += [flow_batch]

            if 'bidir' in self.arch_params['model_arch']:
                # forward target, backward target, forward flow reg, backward flow reg
                targets = [X_target, X_source, X_target, X_source]
            else:
                targets = [X_target] * 3  # one dummy input at the end for the aux labels

            if not return_ids:
                yield inputs, targets
            else:
                yield inputs, targets, id_source, id_target