예제 #1
0
    def _create_augmentation_models(self, indexing='xy'):
        """Build the warping / augmentation models used during training.

        Always sets ``self.vol_warp_model`` (linear interpolation) and
        ``self.seg_warp_model``. Depending on configuration flags it also sets:

        * ``self.flow_rand_aug_model`` when ``self.aug_rand`` is truthy; the
          variant is chosen by ``data_params['aug_params']['randflow_type']``
          (``'ronneberger'`` or anything else for the default random flow);
        * ``self.flow_aug_model``, loaded from ``arch_params['tm_flow_model']``,
          when ``self.aug_tm`` or ``self.aug_sas`` is truthy;
        * ``self.flow_bck_aug_model`` and ``self.color_aug_model`` when
          ``self.aug_tm`` is truthy and ``arch_params['tm_color_model']`` is not
          None. If the color model path contains ``'l2-tgt'`` the color model is
          replaced by a wrapper exposing the ``add_color_delta`` layer output.

        :param indexing: grid indexing convention passed to the warp models and
            to the ``SpatialTransformer`` layers of every loaded model.
        """
        def _load_with_spatial_transformer(model_path):
            # Every load gets its own custom_objects mapping whose
            # SpatialTransformer is bound to the requested indexing.
            transformer_factory = functools.partial(
                nrn_layers.SpatialTransformer, indexing=indexing)
            return load_model(
                model_path,
                custom_objects={'SpatialTransformer': transformer_factory},
                compile=False)

        # Image volumes are always warped with linear interpolation.
        self.vol_warp_model = networks.warp_model(
            img_shape=self.aug_img_shape,
            interp_mode='linear',
            indexing=indexing,
        )

        # Segmentation warper driven by a flow field: one-hot maps
        # (n_labels channels, linear) when 'warpoh' is enabled, otherwise a
        # single label channel with nearest-neighbour interpolation.
        warp_onehot = 'warpoh' in self.arch_params and self.arch_params['warpoh']
        if warp_onehot:
            seg_channels, seg_interp = self.n_labels, 'linear'
        else:
            seg_channels, seg_interp = 1, 'nearest'
        self.seg_warp_model = networks.warp_model(
            img_shape=self.aug_img_shape[:-1] + (seg_channels, ),
            interp_mode=seg_interp,
            indexing=indexing,
        )

        if self.aug_rand:
            aug_params = self.data_params['aug_params']
            if aug_params['randflow_type'] == 'ronneberger':
                # NOTE(review): `indexing` is not forwarded here, unlike the
                # default random-flow model below — confirm this is intended.
                self.flow_rand_aug_model = networks.randflow_ronneberger_model(
                    img_shape=self.aug_img_shape,
                    model=None,
                    interp_mode='linear',
                    model_name='randflow_ronneberger_model',
                    flow_sigma=aug_params['flow_sigma'],
                    blur_sigma=aug_params['blur_sigma'])
                self.logger.debug('Random flow Ronneberger augmentation model')
            else:
                self.flow_rand_aug_model = networks.randflow_model(
                    img_shape=self.aug_img_shape,
                    model=None,
                    interp_mode='linear',
                    model_name='randflow_model',
                    flow_sigma=aug_params['flow_sigma'],
                    flow_amp=aug_params['flow_amp'],
                    blur_sigma=aug_params['blur_sigma'],
                    indexing=indexing,
                )
                self.logger.debug('Random flow augmentation model')

            self.flow_rand_aug_model.summary(print_fn=self.logger.debug)

        # Everything below only applies to transform-model (tm) or
        # single-atlas-segmentation (sas) augmentation.
        if not (self.aug_tm or self.aug_sas):
            return

        self.flow_aug_model = _load_with_spatial_transformer(
            self.arch_params['tm_flow_model'])

        color_model_path = self.arch_params['tm_color_model']
        if color_model_path is None or not self.aug_tm:
            return

        self.flow_bck_aug_model = _load_with_spatial_transformer(
            self.arch_params['tm_flow_bck_model'])
        self.color_aug_model = _load_with_spatial_transformer(color_model_path)

        if 'l2-tgt' in color_model_path:
            # The color model maps to the target space by default; wrap it so
            # the source-space color delta is exposed as the second output and
            # the last input is no longer required.
            base_color_model = self.color_aug_model
            self.color_aug_model = Model(
                inputs=base_color_model.inputs[:-1],
                outputs=[
                    base_color_model.outputs[0],
                    base_color_model.get_layer('add_color_delta').output,
                    base_color_model.outputs[2],
                ],
                name='color_model_wrapper')
예제 #2
0
def eval_seg_sas_from_gen(sas_model,
                          atlas_vol,
                          atlas_labels,
                          eval_gen,
                          label_mapping,
                          n_eval_examples,
                          batch_size,
                          logger=None):
    """Evaluate single-atlas segmentation (SAS) on a set of evaluation volumes.

    For each draw from ``eval_gen`` the atlas is registered to the drawn volume
    with ``sas_model``. The atlas labels are then warped with the predicted flow
    using nearest-neighbour interpolation and compared against the drawn labels
    using cross-entropy, per-label Dice and foreground voxel accuracy.

    :param sas_model: spatial transform model used for SAS (can be voxelmorph).
        Called as ``sas_model.predict([atlas_vol, X])`` and unpacked into
        ``(warped_volume, flow)``.
    :param atlas_vol: atlas volume; ``atlas_vol.shape[1:]`` is used as the image
        shape, so it presumably carries a leading batch axis.
    :param atlas_labels: atlas segmentation; a trailing channel axis is appended
        before warping.
    :param eval_gen: generator yielding ``(vols, segs, _, ids)`` tuples.
    :param label_mapping: list of label ids that appear in segs, ordered by how
        they map to one-hot channels.
    :param n_eval_examples: number of draws from ``eval_gen`` to evaluate.
    :param batch_size: currently unused; kept for interface compatibility.
        NOTE(review): metrics are stored one row per draw, which presumably
        assumes each draw holds a single subject — confirm.
    :param logger: optional python logger. Messages go to ``logger.debug`` when
        given, otherwise to stdout via ``print``.
    :return: ``(cces, dice_per_label, accs, all_ids)``.

        * ``cces`` and ``accs`` have shape ``(n_eval_examples,)``.
        * ``dice_per_label`` has shape ``(n_eval_examples, len(label_mapping))``.
        * ``all_ids`` concatenates the ids yielded by ``eval_gen``.
    """
    # Single output sink so no message bypasses the logger.
    log = logger.debug if logger is not None else print

    img_shape = atlas_vol.shape[1:]

    seg_warp_model = networks.warp_model(
        img_shape=img_shape,
        interp_mode='nearest',
        indexing='xy',
    )

    from keras.models import Model
    from keras.layers import Input, Activation
    from keras.optimizers import Adam
    n_labels = len(label_mapping)

    # Scoring-only model: softmax over the one-hot channels of the warped atlas
    # labels, evaluated with categorical cross-entropy against the true labels.
    ce_in = Input(img_shape[0:-1] + (n_labels, ))
    ce_out = Activation('softmax')(ce_in)
    ce_model = Model(inputs=[ce_in], outputs=[ce_out], name='ce_model')
    ce_model.compile(loss='categorical_crossentropy', optimizer=Adam(0.0001))

    # test metrics: categorical cross-entropy, dice and foreground accuracy
    dice_per_label = np.zeros((n_eval_examples, n_labels))
    cces = np.zeros((n_eval_examples, ))
    accs = np.zeros((n_eval_examples, ))
    all_ids = []
    for bi in range(n_eval_examples):
        log('Testing on subject {} of {}'.format(bi, n_eval_examples))
        X, Y, _, ids = next(eval_gen)

        # Drop a trailing singleton channel *before* one-hot encoding so Y_oh is
        # built from the same (channel-less) layout as preds_oh below; otherwise
        # the two arrays handed to ce_model.evaluate can disagree in shape.
        if Y.shape[-1] == 1:
            Y = Y[..., 0]
        Y_oh = labels_to_onehot(Y, label_mapping=label_mapping)

        # Only the flow field is needed; the warped intensity volume is unused.
        _, warp = sas_model.predict([atlas_vol, X])

        # Warp the atlas labels with the predicted flow, then drop the channel.
        preds_batch = seg_warp_model.predict(
            [atlas_labels[..., np.newaxis], warp])[..., 0]
        preds_oh = labels_to_onehot(preds_batch, label_mapping=label_mapping)

        cce = np.mean(ce_model.evaluate(preds_oh, Y_oh, verbose=False))
        subject_dice_per_label = medipy_metrics.dice(Y,
                                                     preds_batch,
                                                     labels=label_mapping)

        # Voxel accuracy restricted to foreground (label > 0) voxels of Y.
        nonbkgmap = (Y > 0)
        n_nonbkg = np.sum(nonbkgmap)
        if n_nonbkg > 0:
            acc = np.sum((Y == preds_batch) & nonbkgmap) / float(n_nonbkg)
        else:
            # No foreground voxels: accuracy is undefined; report nan explicitly
            # instead of dividing 0 by 0.
            acc = np.nan
        log('Accuracy (no bkg): {}'.format(acc))

        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    log('Dice per label: {}, {}'.format(label_mapping, dice_per_label))
    # Column 0 is excluded as background; assumes label_mapping[0] is the
    # background label — TODO confirm with callers.
    log('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
    log('Mean CE: {}'.format(np.mean(cces)))
    log('Mean accuracy: {}'.format(np.mean(accs)))
    return cces, dice_per_label, accs, all_ids