Exemplo n.º 1
0
 def __init__(self, conf):
     """
     Store the configuration and, when it names a dataset, create its loader.

     :param conf: configuration object; ``conf.dataset_name`` is optional.
         When it is missing, None or empty, ``self.loader`` stays None.
     """
     self.model = None
     self.conf = conf
     self.loader = None
     # getattr with a default also tolerates dataset_name being None
     # (len(None) would raise TypeError).
     dataset_name = getattr(self.conf, 'dataset_name', None)
     if dataset_name:
         self.loader = loader_factory.init_loader(dataset_name)
Exemplo n.º 2
0
 def __init__(self, conf, model):
     """
     Store ``conf`` and ``model`` and initialise the bookkeeping state.

     :param conf: configuration object; ``dataset_name`` selects the data
         loader and ``folder`` is the base of ``models_folder``
     :param model: model instance this object operates on
     """
     self.conf, self.model = conf, model
     self.loader = loader_factory.init_loader(self.conf.dataset_name)
     self.epoch = 0
     # '<conf.folder>/models' -- presumably where model weights are stored.
     self.models_folder = self.conf.folder + '/models'
     # Data sets and the training folder start out unset.
     self.train_data = self.valid_data = self.train_folder = None
Exemplo n.º 3
0
    def test(self):
        """
        Evaluate a model on the test data.

        Loads the 'test' split of ``self.conf.dataset_name`` (cross-validation
        split ``self.conf.split``), predicts each volume with
        ``self.sdnet.Decomposer`` and scores it with ``costs.dice``.

        Outputs, under ``<conf.folder>/test_results_<dataset_name>/``:
          * ``results.csv`` -- a ``Vol, Dice`` header and one row per volume;
          * ``samples/vol_<id>/test_vol<id>_sl<slice>.png`` -- for every slice,
            channel 0 of image, prediction and ground-truth mask side by side.

        The mean Dice over all volumes is printed at the end.
        """
        log.info('Evaluating model on test data')
        folder = os.path.join(self.conf.folder,
                              'test_results_%s' % self.conf.dataset_name)
        # exist_ok=True avoids the race between an exists() check and
        # makedirs() when another run creates the folder concurrently.
        os.makedirs(folder, exist_ok=True)

        test_loader = loader_factory.init_loader(self.conf.dataset_name)
        test_data = test_loader.load_labelled_data(self.conf.split, 'test')

        samples = os.path.join(folder, 'samples')
        os.makedirs(samples, exist_ok=True)

        im_dice = {}
        # `with` guarantees results.csv is closed even if prediction or
        # image saving raises part-way through.
        with open(os.path.join(folder, 'results.csv'), 'w') as f:
            f.write('Vol, Dice\n')

            for vol_i in test_data.volumes():
                vol_folder = os.path.join(samples, 'vol_%s' % str(vol_i))
                os.makedirs(vol_folder, exist_ok=True)

                vol_image = test_data.get_volume_image(vol_i)
                vol_mask = test_data.get_volume_mask(vol_i)
                assert vol_image.shape[0] > 0 and vol_image.shape == vol_mask.shape
                # predict() returns a pair; only the first element is used.
                # It is scored against vol_mask, so presumably it is the
                # predicted mask.
                pred, _ = self.sdnet.Decomposer.predict(vol_image)

                im_dice[vol_i] = costs.dice(vol_mask, pred)
                f.write('%s, %.3f\n' % (str(vol_i), im_dice[vol_i]))

                for i in range(vol_image.shape[0]):
                    # Image | prediction | ground truth (channel 0), side by side.
                    im = np.concatenate([vol_image[i, :, :, 0],
                                         pred[i, :, :, 0],
                                         vol_mask[i, :, :, 0]],
                                        axis=1)
                    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2;
                    # newer SciPy needs a different image writer -- confirm
                    # the SciPy version this project pins.
                    scipy.misc.imsave(
                        os.path.join(vol_folder,
                                     'test_vol%d_sl%d.png' % (vol_i, i)), im)

        if im_dice:
            print('Dice score: %.3f' % np.mean(list(im_dice.values())))
        else:
            # np.mean of an empty list would only yield nan plus a warning.
            log.warning('No test volumes found; Dice score not computed')
Exemplo n.º 4
0
    def __init__(self, sdnet, conf):
        """
        Set up training state for ``sdnet`` using configuration ``conf``.

        Creates the dataset loader for ``conf.dataset_name``, resets the data
        iterators, sample pools and counters, and makes sure ``conf.folder``
        exists.

        :param sdnet: the SDNet instance this object works with
        :param conf: configuration object (reads ``dataset_name`` and ``folder``)
        """
        self.sdnet = sdnet
        self.conf = conf
        self.loader = loader_factory.init_loader(self.conf.dataset_name)

        # Data iterators
        self.gen_X_L = None  # labelled data: (image, mask) pairs
        self.gen_X_U = None  # unlabelled data
        self.other_masks = None  # real masks to use for discriminator training

        self.fake_image_pool = []
        self.fake_mask_pool = []
        self.batch = 0
        self.epoch = 0

        # exist_ok=True avoids the race between an exists() check and
        # makedirs() when another process creates the folder concurrently.
        os.makedirs(self.conf.folder, exist_ok=True)
Exemplo n.º 5
0
    def __init__(self, conf):
        """
        Build an SDNet wrapper around configuration ``conf``.

        Only the configuration and the dataset loader are created here. Every
        trainer and sub-network attribute starts out as ``None``.

        :param conf: configuration object; ``conf.dataset_name`` selects the loader
        """
        super(SDNet, self).__init__()
        self.other_masks = None
        self.conf = conf
        self.loader = loader_factory.init_loader(self.conf.dataset_name)

        # Placeholders for the trainers and sub-networks, in declaration order.
        for attr_name in ('D_model',             # Discriminator trainer
                          'G_model',             # Unsupervised generator trainer
                          'G_supervised_model',  # Supervised generator trainer
                          'Decomposer',          # Decomposer
                          'Reconstructor',       # Reconstructor
                          'ImageDiscriminator',  # Image discriminator
                          'MaskDiscriminator'):  # Mask discriminator
            setattr(self, attr_name, None)
    def test_modality(self, modality, modality_index):
        """
        Evaluate model on a given modality.

        Loads the 'test' split of ``self.conf.test_dataset`` with all
        configured modalities concatenated, crops it to the input shape, and
        runs ``test_modality_type`` for each evaluation type ('simple', 'def',
        'max'). It does this first on the pairs as loaded, then again after
        ``randomise_pairs(length=2, seed=self.conf.seed)``. The randomised
        runs write to folders whose suffix ends in ``_rand``.

        :param modality: the modality being evaluated; used here only to name
            the output folders via ``make_test_folder``
        :param modality_index: forwarded unchanged to ``test_modality_type``
        """
        test_loader = loader_factory.init_loader(self.conf.test_dataset)
        # NOTE(review): the loader is configured from self.conf.modality, not
        # from the `modality` argument -- confirm this is intended.
        test_loader.modalities = self.conf.modality
        test_data = test_loader.load_all_modalities_concatenated(
            self.conf.split, 'test', self.conf.image_downsample)
        test_data.crop(self.conf.input_shape[:2])  # crop data to input shape

        eval_types = ('simple', 'def', 'max')

        def _run_all_types(suffix_tail):
            # One evaluation per type; output folder suffix is type + tail.
            for eval_type in eval_types:
                folder = self.make_test_folder(modality, suffix=eval_type + suffix_tail)
                self.test_modality_type(folder, modality_index, eval_type,
                                        test_loader, test_data)

        _run_all_types('')
        # randomise_pairs mutates test_data, so the second pass sees new pairs.
        test_data.randomise_pairs(length=2, seed=self.conf.seed)
        _run_all_types('_rand')
Exemplo n.º 7
0
    log.info('---- Setting up experiment at ' + config.folder + '----')


if __name__ == '__main__':
    def _str2bool(value):
        """
        Parse a boolean command-line value ('true'/'false', 'yes'/'no', '1'/'0').

        argparse's ``type=bool`` cannot be used for this: ``bool('False')`` is
        True, so any non-empty value would switch the option on.
        """
        normalised = value.strip().lower()
        if normalised in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalised in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)

    parser = argparse.ArgumentParser(description='Run SDNet')
    parser.add_argument('--epochs', help='Number of epochs to train', type=int)
    parser.add_argument('--dataset', help='Dataset to use', choices=['acdc'], required=True)
    parser.add_argument('--description', help='Experiment description')
    parser.add_argument('--split', help='Split for Cross Validation', type=int, required=True)
    # `--test` alone or `--test true` enables it; `--test false` disables it.
    parser.add_argument('--test', help='Test', type=_str2bool, nargs='?', const=True)
    parser.add_argument('--ul_mix', help='Percentage of unlabelled data to mix', type=float, required=True)
    parser.add_argument('--l_mix', help='Percentage of labelled data to mix', type=float, required=True)
    args = parser.parse_args()

    # Create configuration object from parameters
    # The training data is loaded only to read its size and shape for the
    # Configuration, then released.
    loader = loader_factory.init_loader(args.dataset)
    data = loader.load_labelled_data(args.split, 'training')

    folder = 'sdnet_%s_ul_%.3f_l_%.3f_split%d' % (args.dataset, args.ul_mix, args.l_mix, args.split)
    conf = Configuration(folder, data.size(), data.shape()[1:])
    del data

    conf.description = args.description if args.description else ''
    if args.epochs:
        conf.epochs = args.epochs
    conf.dataset_name = args.dataset
    conf.ul_mix = args.ul_mix
    conf.l_mix = args.l_mix
    conf.split = args.split
    conf.save()
Exemplo n.º 8
0
    def test_modality(self, folder, modality, group, save_figs=True):
        """
        Evaluate the anatomy and pathology segmentors on one modality.

        Loads the ``group`` split of ``self.conf.test_dataset`` for
        ``modality``. For every volume it predicts anatomy and pathology masks
        and scores them with ``dice`` and ``calculate_false_negative``, both
        overall and per mask.

        Per-volume scores are written to ``<folder>/results.csv``. A summary
        line follows them, and the same line is also printed. When
        ``save_figs`` is true, one PNG per slice is written to ``folder``,
        stacking the anatomy overlay above the pathology overlay.

        :param folder: output directory; must already exist (it is not created here)
        :param modality: modality passed to ``load_labelled_data``
        :param group: data split name passed to ``load_labelled_data``
        :param save_figs: whether to save per-slice segmentation figures
        :return: mean pathology Dice over all volumes
        """
        test_loader = loader_factory.init_loader(self.conf.test_dataset)
        test_data = test_loader.load_labelled_data(
            self.conf.split,
            group,
            modality=modality,
            downsample=self.conf.image_downsample)

        anatomy_segmentor = self.model.get_anatomy_segmentor()
        pathology_segmentor = self.model.get_pathology_encoder()

        anato_names = test_data.anato_mask_names
        patho_names = test_data.patho_mask_names
        anato_mask_num = len(anato_names)
        patho_mask_num = len(patho_names)

        # Overall per-volume scores, keyed by volume id.
        im_dice_anato, im_false_negative_anato = {}, {}
        im_dice_patho, im_false_negative_patho = {}, {}
        # Per-mask scores across volumes, indexed [mask][volume].
        sep_dice_list_anato = [[] for _ in range(anato_mask_num)]
        sep_false_negative_list_anato = [[] for _ in range(anato_mask_num)]
        sep_dice_list_patho = [[] for _ in range(patho_mask_num)]
        sep_false_negative_list_patho = [[] for _ in range(patho_mask_num)]

        def _pad_to_width(image, width):
            # Zero-pad `image` on the right (axis 1) up to `width` columns;
            # any trailing axes (e.g. colour channels) are preserved.
            missing = width - image.shape[1]
            if missing <= 0:
                return image
            padding = np.zeros((image.shape[0], missing) + image.shape[2:],
                               dtype=image.dtype)
            return np.concatenate([image, padding], axis=1)

        with open(os.path.join(folder, 'results.csv'), 'w') as f:
            for vol_i in test_data.volumes():
                vol_image = test_data.get_images(vol_i)
                vol_anato_mask = test_data.get_anato_masks(vol_i)
                vol_patho_mask = test_data.get_patho_masks(vol_i)
                vol_slice = test_data.get_slice(vol_i)
                # Image and both mask stacks must agree on every axis but the last.
                assert (vol_image.shape[0] > 0
                        and vol_image.shape[:-1] == vol_anato_mask.shape[:-1]
                        and vol_image.shape[:-1] == vol_patho_mask.shape[:-1])
                anato_pred = anatomy_segmentor.predict(vol_image)
                patho_pred = pathology_segmentor.predict(vol_image)

                # Each metric is unpacked as (overall score, one score per mask).
                im_dice_anato[vol_i], sep_dice_anato \
                    = dice(vol_anato_mask, anato_pred)
                im_false_negative_anato[vol_i], sep_false_negative_anato \
                    = calculate_false_negative(vol_anato_mask, anato_pred)
                im_dice_patho[vol_i], sep_dice_patho \
                    = dice(vol_patho_mask, patho_pred)
                im_false_negative_patho[vol_i], sep_false_negative_patho \
                    = calculate_false_negative(vol_patho_mask, patho_pred)

                # harric added to specify dice scores across different masks
                assert anato_mask_num == len(
                    sep_dice_anato), 'Incorrect mask num !'
                assert patho_mask_num == len(
                    sep_dice_patho), 'Incorrect mask num !'
                for ii in range(anato_mask_num):
                    sep_dice_list_anato[ii].append(sep_dice_anato[ii])
                    sep_false_negative_list_anato[ii].append(
                        sep_false_negative_anato[ii])
                for ii in range(patho_mask_num):
                    sep_dice_list_patho[ii].append(sep_dice_patho[ii])
                    sep_false_negative_list_patho[ii].append(
                        sep_false_negative_patho[ii])

                # One row per volume, every field followed by ', '. Per-mask
                # groups are separate fields, so no empty columns appear when
                # a dataset has several masks.
                row = ['Volume:%s' % str(vol_i),
                       'AnatomyDice:%.3f' % im_dice_anato[vol_i],
                       'AnatomyFN:%.3f' % im_false_negative_anato[vol_i],
                       'PathologyDice:%.3f' % im_dice_patho[vol_i],
                       'PathologyFN:%.3f' % im_false_negative_patho[vol_i]]
                row += ['%s, %.3f, %.3f' % (anato_names[ii], sep_dice_anato[ii],
                                            sep_false_negative_anato[ii])
                        for ii in range(anato_mask_num)]
                row += ['%s, %.3f, %.3f' % (patho_names[ii], sep_dice_patho[ii],
                                            sep_false_negative_patho[ii])
                        for ii in range(patho_mask_num)]
                f.write(''.join(field + ', ' for field in row) + '\n')

                if save_figs:
                    for i in range(vol_image.shape[0]):
                        image_i = vol_image[i]
                        im1 = save_segmentation(anato_pred[i, :, :, :], image_i,
                                                vol_anato_mask[i])
                        im2 = save_segmentation(patho_pred[i, :, :, :], image_i,
                                                vol_patho_mask[i])
                        # Pad the narrower figure so both can be stacked vertically.
                        width = max(im1.shape[1], im2.shape[1])
                        im = np.concatenate([_pad_to_width(im1, width),
                                             _pad_to_width(im2, width)],
                                            axis=0)
                        imsave(
                            os.path.join(
                                folder,
                                "vol%s_slice%s" % (str(vol_i), vol_slice[i]) + '.png'), im)

            # harric added to specify dice scores across different masks
            summary = [group,
                       'AnatomyDice:%.3f' % np.mean(list(im_dice_anato.values())),
                       'AnatoFN:%.3f' % np.mean(list(im_false_negative_anato.values())),
                       'PathoDice:%.3f' % np.mean(list(im_dice_patho.values())),
                       'PathoFN:%.3f' % np.mean(list(im_false_negative_patho.values()))]
            summary += ['%s, %.3f, %.3f' % (anato_names[ii],
                                            np.mean(sep_dice_list_anato[ii]),
                                            np.mean(sep_false_negative_list_anato[ii]))
                        for ii in range(anato_mask_num)]
            summary += ['%s, %.3f, %.3f' % (patho_names[ii],
                                            np.mean(sep_dice_list_patho[ii]),
                                            np.mean(sep_false_negative_list_patho[ii]))
                        for ii in range(patho_mask_num)]
            print_info = ', '.join(summary)
            print(print_info)
            f.write(print_info + '\n')

        return np.mean(list(im_dice_patho.values()))