Ejemplo n.º 1
0
    def _make_results_im(self, data_gen, max_batch_size=32):
        """Render one batch of inputs and model predictions as a labeled image.

        Pulls the next batch from ``data_gen``, runs ``self.transform_model``
        on it and builds one ``utils.label_ims`` panel per displayed batch.
        The panels are, in order: the first two model inputs (labeled with the
        basenames of the source / target ids), optionally the last channel of
        the third input (labeled ``'aux_contours'``, shown normalized) when
        ``arch_params['use_aux_reg']`` contains ``'contours'``, then the
        predicted transform and the transformed image.

        Args:
            data_gen: generator yielding
                ``(inputs, targets, ids_source, ids_target)``; ``targets`` is
                not used here.
            max_batch_size: maximum number of examples to display.

        Returns:
            The panels concatenated along axis 1.
        """
        inputs, _, ids_source, ids_target = next(data_gen)
        preds = self.transform_model.predict(inputs)

        # list() so the panel list can be extended even if inputs is a tuple
        input_im_batches = list(inputs[:2])
        labels = [
            [os.path.basename(ids) for ids in ids_source],
            [os.path.basename(idt) for idt in ids_target],
        ]
        do_normalize = [False, False]

        aux_reg = self.arch_params['use_aux_reg']
        if aux_reg is not None and 'contours' in aux_reg:
            # only the last channel of the auxiliary input is displayed
            input_im_batches.append(inputs[2][..., [-1]])
            labels.append('aux_contours')
            do_normalize.append(True)

        model_arch = self.arch_params['model_arch']
        if 'bidir' in model_arch:
            # fwd flow, fwd transformed im
            transform, transformed = preds[2], preds[0]
        else:
            # spatial and appearance transform models both output [transformed, transform, ...]
            transform, transformed = preds[1], preds[0]
        input_im_batches += [transform, transformed]
        labels += ['transform', 'transformed']
        # if we are learning a color transform, normalize it for display purposes
        do_normalize += ['color' in model_arch, False]

        display_batch_size = min(max_batch_size, self.batch_size)
        if display_batch_size < self.batch_size:
            input_im_batches = [batch[:display_batch_size] for batch in input_im_batches]

        if self.n_dims != 2:
            # volumes: show a single randomly chosen slice (index on axis 3)
            # taken from the middle half of the range given by img_shape[-2]
            depth = self.img_shape[-2]
            slice_idx = np.random.choice(
                range(int(round(depth * 0.25)), int(round(depth * 0.75))),
                1, replace=False)[0]
            input_im_batches = [batch[:, :, :, slice_idx] for batch in input_im_batches]

        return np.concatenate([
            utils.label_ims(batch, label, normalize=normalize)
            for batch, label, normalize in zip(input_im_batches, labels, do_normalize)
        ], axis=1)
Ejemplo n.º 2
0
    def _make_results_im(self,
                         input_im_batches,
                         labels,
                         overlay_on_ims=None,
                         do_normalize=None,
                         is_seg=None,
                         max_batch_size=32,
                         show_label_idx=12):
        """Build a labeled results image from a list of image / segmentation batches.

        Each non-``None`` entry of ``input_im_batches`` becomes one panel, and
        the panels are concatenated along axis 1. A regular batch is rendered
        with ``utils.label_ims(batch, labels[i])``. A batch flagged in
        ``is_seg`` is rendered as two panels: the label map obtained with
        ``utils.onehot_to_labels`` drawn as contours over
        ``overlay_on_ims[i]``, followed by the single channel
        ``show_label_idx`` of the batch shown normalized.

        Args:
            input_im_batches: list of batches; ``None`` entries are skipped.
            labels: one label (or list of labels) per batch; only used for
                batches that are not segmentations.
            overlay_on_ims: per-batch images to overlay segmentations on.
                Required (non-``None`` entry) for every batch flagged in
                ``is_seg``; may be ``None`` when there are no segmentations.
            do_normalize: accepted for interface compatibility.
                NOTE(review): this flag is not forwarded to
                ``utils.label_ims`` -- confirm whether callers expect it to be.
            is_seg: per-batch flags marking one-hot segmentation batches;
                defaults to all ``False``.
            max_batch_size: maximum number of examples to display.
            show_label_idx: channel of a segmentation batch shown on its own;
                the default 12 was annotated "cerebral wm" in the original code.

        Returns:
            The panels concatenated along axis 1.

        Raises:
            ValueError: if a batch is flagged as a segmentation but has no
                image to overlay on.
        """
        n_batches = len(input_im_batches)
        if is_seg is None:
            is_seg = [False] * n_batches
        if overlay_on_ims is None:
            # keep the list indexable so the truncation below cannot fail
            overlay_on_ims = [None] * n_batches

        display_batch_size = min(max_batch_size, self.batch_size)
        if display_batch_size < self.batch_size:
            # None entries are tolerated here exactly as in the panel loop
            input_im_batches = [
                batch[:display_batch_size] if batch is not None else None
                for batch in input_im_batches
            ]
            overlay_on_ims = [
                im[:display_batch_size] if im is not None else None
                for im in overlay_on_ims
            ]

        panels = []
        for i, batch in enumerate(input_im_batches):
            if batch is None:
                continue

            if not is_seg[i]:
                panels.append(utils.label_ims(batch, labels[i]))
                continue

            if overlay_on_ims[i] is None:
                raise ValueError(
                    'batch {} is a segmentation but overlay_on_ims has no '
                    'image for it'.format(i))

            # we want two images here: overlay and a single label
            seg_labels = utils.onehot_to_labels(
                batch, label_mapping=self.label_mapping)
            # the overlay helper is given subjects on the last axis, so move
            # the batch axis to the end and back again afterwards
            overlay = utils.overlay_segs_on_ims_batch(
                ims=np.transpose(overlay_on_ims[i], (1, 2, 3, 0)),
                segs=np.transpose(seg_labels, (1, 2, 0)),
                include_labels=self.label_mapping,
                draw_contours=True,
            )
            panels.append(np.concatenate([
                utils.label_ims(np.transpose(overlay, (3, 0, 1, 2)), []),
                utils.label_ims(
                    batch[..., [show_label_idx]],
                    'label {}'.format(self.label_mapping[show_label_idx]),
                    normalize=True),
            ], axis=1))

        return np.concatenate(panels, axis=1)
Ejemplo n.º 3
0
    def _create_augmented_examples(self):
        """Pseudo-label unlabeled volumes with the SAS models and add them to the training set.

        Does nothing unless ``self.aug_sas`` is set. Otherwise, for each of
        ``self.n_aug`` batches (size 1) drawn from the ``'unlabeled_train'``
        split:

        * ``self.flow_aug_model`` is run on ``[labeled, unlabeled]`` and its
          second output (the flow) is kept;
        * ``self.seg_warp_model`` is run on ``[labeled segs, flow]`` to get
          the segmentation for the unlabeled volume.

        The unlabeled volumes, their predicted segmentations and ids prefixed
        with ``'sas_'`` are appended to ``self.X_labeled_train``,
        ``self.segs_labeled_train`` and ``self.ids_labeled_train``. Finally a
        preview image ``aug_SAS_examples.jpg`` is written to ``self.exp_dir``
        (skipped when no examples were generated).
        """
        if not self.aug_sas:
            return

        preview_augmented_examples = True
        aug_name = 'SAS'

        # just label a bunch of examples using our SAS model, and then append them to the training set
        X_source = self.X_labeled_train
        Y_source = self.segs_labeled_train

        unlabeled_labeler_gen = self.dataset.gen_vols_batch(
            dataset_splits=['unlabeled_train'],
            batch_size=1,
            randomize=False,
            return_ids=True)

        X_target = np.zeros((self.n_aug, ) + self.aug_img_shape)
        X_train_aug = np.zeros((self.n_aug, ) + self.aug_img_shape)
        Y_train_aug = np.zeros(
            (self.n_aug, ) + self.aug_img_shape[:-1] + (1, ))
        ids_train_aug = []
        for i in range(self.n_aug):
            self.logger.debug(
                'Pseudo-labeling UL example {} of {} using SAS!'.format(
                    i, self.n_aug))
            X_unlabeled, _, _, ul_ids = next(unlabeled_labeler_gen)

            # warp labeled example to unlabeled example; only the flow is needed
            _, flow = self.flow_aug_model.predict([X_source, X_unlabeled])

            # warp labeled segs similarly
            Y_aug = self.seg_warp_model.predict([Y_source, flow])

            X_target[i] = X_unlabeled
            # when using SAS, we use the predicted segmentations to "label"
            # the original target volume
            X_train_aug[i] = X_unlabeled
            Y_train_aug[i] = Y_aug
            ids_train_aug += ['sas_{}'.format(ul_id) for ul_id in ul_ids]

        self.X_labeled_train = np.concatenate(
            [self.X_labeled_train, X_train_aug], axis=0)
        self.segs_labeled_train = np.concatenate(
            [self.segs_labeled_train, Y_train_aug], axis=0)
        self.ids_labeled_train += ids_train_aug
        self.logger.debug(
            'Added {} {}-augmented batches to training set!'.format(
                len(X_train_aug), aug_name))

        n_examples = X_train_aug.shape[0]
        if not preview_augmented_examples or n_examples == 0:
            # nothing to draw: concatenating zero preview rows would raise
            return

        print_batch_size = 10
        max_preview_batches = 20
        # slice 112 is used whenever it exists; smaller volumes fall back to
        # their middle slice instead of raising IndexError
        n_slices = X_train_aug.shape[-2]
        show_slice_idx = 112 if n_slices > 112 else n_slices // 2

        n_aug_batches = int(np.ceil(n_examples / float(print_batch_size)))
        aug_out_im = []
        for bi in range(min(max_preview_batches, n_aug_batches)):
            rows = slice(bi * print_batch_size,
                         min(n_examples, (bi + 1) * print_batch_size))
            X_target_batch = X_target[rows, ..., show_slice_idx, :]
            X_aug_batch = X_train_aug[rows, ..., show_slice_idx, :]
            Y_aug_batch = Y_train_aug[rows, ..., show_slice_idx, :]

            # repeat the source slice so it lines up with every row of the batch
            X_source_tiled = np.tile(
                X_source[..., show_slice_idx, :],
                (X_target_batch.shape[0], ) + (1, ) * (len(X_source.shape) - 2))

            aug_im = utils.concatenate_with_pad([
                utils.label_ims(X_source_tiled),
                utils.label_ims(X_target_batch),
                utils.label_ims(X_aug_batch),
                utils.label_ims(
                    utils.overlay_segs_on_ims_batch(
                        ims=X_aug_batch,
                        segs=Y_aug_batch,
                        include_labels=self.label_mapping,
                        draw_contours=True,
                        subjects_axis=0)),
            ], axis=1)
            aug_out_im.append(aug_im)

        aug_out_im = np.concatenate(aug_out_im, axis=0)
        cv2.imwrite(
            os.path.join(self.exp_dir,
                         'aug_{}_examples.jpg'.format(aug_name)),
            aug_out_im)
Ejemplo n.º 4
0
    def _create_augmented_examples(self):
        """Pre-generate augmented examples and append them to the labeled training set.

        Two modes are supported:

        * ``self.aug_sas``: for each of ``self.n_aug`` batches (size 1) from
          the ``'unlabeled_train'`` split, ``self.flow_aug_model`` provides a
          flow from the labeled data to the unlabeled volume and
          ``self.seg_warp_model`` warps the labeled segmentation with it. The
          unlabeled volume and the warped segmentation form the new example.
        * ``self.aug_tm`` with ``data_params['n_tm_aug']`` set and at most
          100: batches are drawn from ``self._generate_augmented_batch`` and
          the first ``n_tm_aug`` examples are kept (also stored in
          ``self.X_tm_aug`` / ``self.Y_tm_aug``).

        In either mode the new examples are appended to
        ``self.X_labeled_train``, ``self.segs_labeled_train`` and
        ``self.ids_labeled_train``, and a preview image
        ``aug_<name>_examples.jpg`` is written to ``self.exp_dir``. When
        neither mode applies, or no examples were produced, the method returns
        without changing anything further.
        """
        preview_augmented_examples = True

        if self.aug_sas:
            aug_name = 'SAS'
            # just label a bunch of examples using our SAS model, and then append them to the training set
            source_X = self.X_labeled_train
            source_Y = self.segs_labeled_train

            unlabeled_labeler_gen = self.dataset.gen_vols_batch(
                dataset_splits=['unlabeled_train'],
                batch_size=1,
                randomize=False,
                return_ids=True)

            X_train_aug = np.zeros((self.n_aug, ) + self.aug_img_shape)
            Y_train_aug = np.zeros(
                (self.n_aug, ) + self.aug_img_shape[:-1] + (1, ))
            ids_train_aug = []
            for i in range(self.n_aug):
                self.logger.debug(
                    'Pseudo-labeling UL example {} of {} using SAS!'.format(
                        i, self.n_aug))
                unlabeled_X, _, _, ul_ids = next(unlabeled_labeler_gen)

                # warp labeled example to unlabeled example; only the flow is needed
                _, flow = self.flow_aug_model.predict(
                    [source_X, unlabeled_X])
                # warp labeled segs similarly
                Y_aug = self.seg_warp_model.predict([source_Y, flow])

                X_train_aug[i] = unlabeled_X
                Y_train_aug[i] = Y_aug
                ids_train_aug += ['sas_{}'.format(ul_id) for ul_id in ul_ids]

        elif (self.aug_tm and self.data_params['n_tm_aug'] is not None
                and self.data_params['n_tm_aug'] <= 100):
            # TODO: this section is not well tested since it is no longer really used, we do even coupled
            # augmentation in the generator
            aug_name = 'tm'
            n_tm_aug = self.data_params['n_tm_aug']
            source_train_gen_tmaug = self._generate_augmented_batch(
                self.source_gen,
                aug_by='tm',
                use_single_atlas=self.X_labeled_train.shape[0] == 1)

            n_aug_batches = int(np.ceil(n_tm_aug / float(self.batch_size)))

            X_batches = []
            Y_batches = []
            ids_train_aug = []
            for _ in range(n_aug_batches):
                # the generator also yields the pre-augmentation examples,
                # which are not needed here
                _, _, X_batch, Y_batch, aug_ids_batch = next(
                    source_train_gen_tmaug)
                X_batches.append(X_batch)
                Y_batches.append(Y_batch)
                ids_train_aug += aug_ids_batch

            if not X_batches:
                # n_tm_aug <= 0: nothing was generated, nothing to append
                return

            # merge the batches into arrays before they are appended to the
            # training set, and keep ids in step with the truncated arrays
            self.X_tm_aug = np.concatenate(X_batches, axis=0)[:n_tm_aug]
            self.Y_tm_aug = np.concatenate(Y_batches, axis=0)[:n_tm_aug]
            X_train_aug = self.X_tm_aug
            Y_train_aug = self.Y_tm_aug
            ids_train_aug = ids_train_aug[:n_tm_aug]

        else:
            # no pre-generated augmentation is configured
            return

        self.X_labeled_train = np.concatenate(
            [self.X_labeled_train, X_train_aug], axis=0)
        self.segs_labeled_train = np.concatenate(
            [self.segs_labeled_train, Y_train_aug], axis=0)
        self.ids_labeled_train += ids_train_aug
        self.logger.debug(
            'Added {} {}-augmented batches to training set!'.format(
                len(X_train_aug), aug_name))

        n_examples = X_train_aug.shape[0]
        if not preview_augmented_examples or n_examples == 0:
            # nothing to draw: concatenating zero preview rows would raise
            return

        print_batch_size = 10
        max_preview_batches = 20
        # slice 112 is used whenever it exists; smaller volumes fall back to
        # their middle slice instead of raising IndexError
        n_slices = X_train_aug.shape[-2]
        show_slice_idx = 112 if n_slices > 112 else n_slices // 2

        n_preview_batches = int(np.ceil(n_examples / float(print_batch_size)))
        aug_out_im = []
        for bi in range(min(max_preview_batches, n_preview_batches)):
            rows = slice(bi * print_batch_size,
                         min(n_examples, (bi + 1) * print_batch_size))
            aug_im = utils.concatenate_with_pad([
                utils.label_ims(
                    X_train_aug[rows, ..., show_slice_idx, :], []),
                utils.label_ims(
                    Y_train_aug[rows, ..., show_slice_idx, :], []),
            ], axis=1)
            aug_out_im.append(aug_im)

        aug_out_im = np.concatenate(aug_out_im, axis=0)
        cv2.imwrite(
            os.path.join(self.exp_dir,
                         'aug_{}_examples.jpg'.format(aug_name)),
            aug_out_im)