def __getitem__(self, idx):
        """Build batch number ``idx`` of paired RGB / persistence inputs.

        Slices ``self.train_files`` and ``self.train_outputs`` with
        ``self.batch_size``, loads the ``.jpg`` image and the pickled ``.pkl``
        persistence image for every sample, and stacks them.

        Args:
            idx: Zero-based batch index.

        Returns:
            ``([X_RGB, X_Per], Y)`` where ``X_RGB`` has shape
            ``(n, 256, 256, 3)``, ``X_Per`` has shape ``(n, 32, 32, 1)`` and
            ``Y`` has shape ``(n, 2)`` (output of ``to_categorical``).  ``n``
            is the number of samples in the slice, which can be smaller than
            ``self.batch_size`` for the final batch.

        Raises:
            ValueError: If a sample's label is neither
                ``self.label['malignant']`` nor ``self.label['benign']``.
            IndexError: If no ``<digits>.<digits>`` token is found in the
                RGB image path.
        """
        lo = idx * self.batch_size
        hi = (idx + 1) * self.batch_size
        batchx = self.train_files[lo:hi]
        batchy = self.train_outputs[lo:hi]

        # Size the arrays by the actual slice length: the last batch may be
        # short, and indexing up to batch_size would raise IndexError.
        n_samples = len(batchx)
        X_RGB = np.zeros((n_samples, 256, 256, 3))
        X_Per = np.zeros((n_samples, 32, 32, 1))
        Y = np.zeros((n_samples, 2))

        id_pattern = re.compile(r"\d+\.\d+")

        for i, stem in enumerate(batchx):
            target = batchy[i]
            fnameRGB = stem + '.jpg'
            fnamePer = stem + '.pkl'

            if target == self.label['malignant']:
                pathRGB = os.path.join(self.path_mal_train, fnameRGB)
                pathPer = os.path.join(self.path_mal_train_per, fnamePer)
            elif target == self.label['benign']:
                pathRGB = os.path.join(self.path_ben_train, fnameRGB)
                pathPer = os.path.join(self.path_ben_train_per, fnamePer)
            else:
                # Without this branch the paths of the previous sample would
                # be reused silently.
                raise ValueError('Unknown label %r for sample %r' %
                                 (target, stem))

            img = skimage.io.imread(pathRGB)
            if img.shape == (1024, 1024, 3):
                # Keep every 4th pixel on both spatial axes -> 256 x 256 x 3.
                img = img[::4, ::4, :]

            # The image id is the first "<digits>.<digits>" token of the path.
            image_id = int(float(id_pattern.findall(pathRGB)[0]))

            # NOTE(review): membership is tested on self.config['stats'] but
            # the values are read from self.stats -- presumably the same
            # mapping; confirm against the constructor.
            if image_id in self.config['stats']:
                src_mu, src_sigma = self.stats[image_id]
                img_nmzd = htk_cnorm.reinhard(
                    img,
                    self.ref_mu_lab,
                    self.ref_std_lab,
                    src_mu=src_mu,
                    src_sigma=src_sigma).astype('float')
            else:
                # No stored source statistics: call reinhard without them.
                print('#### stats for %d not present' % (image_id))
                img_nmzd = htk_cnorm.reinhard(img, self.ref_mu_lab,
                                              self.ref_std_lab).astype('float')

            # NOTE: pickle.load can execute arbitrary code; only point this
            # loader at trusted dataset files.
            with open(pathPer, 'rb') as f:
                persistence = pickle.load(f)

            X_RGB[i] = preprocess_resnet(img_nmzd)
            X_Per[i] = self.preprocess_persistence(persistence)
            Y[i] = to_categorical(target, num_classes=2)

        return ([X_RGB, X_Per], Y)
Esempio n. 2
0
    def preprocess(self, img_path):
        """Load one RGB image, colour-normalise it and apply ResNet preprocessing.

        Args:
            img_path: Path of the image file.  The first
                ``<digits>.<digits>`` token found in the path is used as the
                image id for the ``self.stats`` lookup.

        Returns:
            The result of ``preprocess_resnet`` applied to the normalised
            image (cast to float).

        Raises:
            IndexError: If ``img_path`` contains no ``<digits>.<digits>``
                token.
        """
        img = skimage.io.imread(img_path)
        if img.shape == (1024, 1024, 3):
            # Keep every 4th pixel on both spatial axes -> 256 x 256 x 3.
            img = img[::4, ::4, :]

        image_id = int(float(re.findall(r"\d+\.\d+", img_path)[0]))

        # Test membership on the mapping itself; ``.keys()`` builds a list on
        # Python 2 and makes every lookup linear.
        if image_id in self.stats:
            src_mu, src_sigma = self.stats[image_id]
            img_nmzd = htk_cnorm.reinhard(img,
                                          self.ref_mu_lab,
                                          self.ref_std_lab,
                                          src_mu=src_mu,
                                          src_sigma=src_sigma).astype('float')
        else:
            # No stored source statistics: call reinhard without them.
            print('#### stats for %d not present' % (image_id))
            img_nmzd = htk_cnorm.reinhard(img, self.ref_mu_lab,
                                          self.ref_std_lab).astype('float')

        return preprocess_resnet(img_nmzd)
def CombinedTestData(config):
    """Load the whole test split as paired RGB / persistence-image arrays.

    Samples are enumerated from the files found in
    ``<config.test_dir>/{malignant,benign}/persistence_images``; the matching
    RGB image is expected under ``.../rgb`` with the same stem and a ``.jpg``
    extension.

    Args:
        config: Configuration object providing ``test_dir``, ``label``
            (mapping with ``'malignant'`` and ``'benign'``) and
            ``trainer.percentile_factor``.  When ``configs/stats.pkl``
            exists, its content is also stored under ``config['stats']``.

    Returns:
        ``[X_RGB, X_Per, Y]`` where ``X_RGB`` has shape
        ``(n, 256, 256, 3)``, ``X_Per`` has shape ``(n, 32, 32, 1)`` and
        ``Y`` is a list of ``n`` labels (malignant samples first).

    Raises:
        StopIteration: If one of the four data directories does not exist.
        IndexError: If an image path contains no ``<digits>.<digits>`` token.
    """
    print('Loading combined data')

    # next(os.walk(d)) yields ``d`` itself and fails when ``d`` is missing.
    path_mal_test, _, _ = next(
        os.walk(os.path.join(config.test_dir, 'malignant', 'rgb')))
    path_ben_test, _, _ = next(
        os.walk(os.path.join(config.test_dir, 'benign', 'rgb')))
    path_mal_test_per, _, _ = next(
        os.walk(
            os.path.join(config.test_dir, 'malignant', 'persistence_images')))
    path_ben_test_per, _, _ = next(
        os.walk(os.path.join(config.test_dir, 'benign', 'persistence_images')))

    label = config.label

    mal_paths_test = glob.glob(os.path.join(path_mal_test_per, '*'))
    ben_paths_test = glob.glob(os.path.join(path_ben_test_per, '*'))

    test_paths = mal_paths_test + ben_paths_test
    test_outputs = ([label['malignant']] * len(mal_paths_test) +
                    [label['benign']] * len(ben_paths_test))

    # File stems shared by the ``.pkl`` and ``.jpg`` files of one sample.
    test_files = [
        os.path.basename(elem).replace('.pkl', '') for elem in test_paths
    ]

    # Reference statistics handed to htk_cnorm.reinhard.
    ref_std_lab = (0.57506023, 0.10403329, 0.01364062)
    ref_mu_lab = (8.63234435, -0.11501964, 0.03868433)

    # Default to an empty mapping so a missing stats file falls back to
    # normalisation without source statistics instead of raising NameError.
    stats = {}
    if os.path.isfile('configs/stats.pkl'):
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open('configs/stats.pkl', 'rb') as f:
            stats = pickle.load(f)
        print('Stats loaded')
        config['stats'] = stats
    else:
        print('No stats file found (To obtain Mu and Sigma from original whole image).')

    len_test = len(test_outputs)

    X_RGB = np.zeros((len_test, 256, 256, 3))
    X_Per = np.zeros((len_test, 32, 32, 1))

    id_pattern = re.compile(r"\d+\.\d+")

    for i, (stem, target) in enumerate(zip(test_files, test_outputs)):
        fnameRGB = stem + '.jpg'
        fnamePer = stem + '.pkl'

        # ``target`` is always one of the two labels assigned above.
        if target == label['malignant']:
            pathRGB = os.path.join(path_mal_test, fnameRGB)
            pathPer = os.path.join(path_mal_test_per, fnamePer)
        else:
            pathRGB = os.path.join(path_ben_test, fnameRGB)
            pathPer = os.path.join(path_ben_test_per, fnamePer)

        img = skimage.io.imread(pathRGB)
        if img.shape == (1024, 1024, 3):
            # Keep every 4th pixel on both spatial axes -> 256 x 256 x 3.
            img = img[::4, ::4, :]

        # The image id is the first "<digits>.<digits>" token of the path.
        image_id = int(float(id_pattern.findall(pathRGB)[0]))

        if image_id in stats:
            src_mu, src_sigma = stats[image_id]
            img_nmzd = htk_cnorm.reinhard(img,
                                          ref_mu_lab,
                                          ref_std_lab,
                                          src_mu=src_mu,
                                          src_sigma=src_sigma).astype('float')
        else:
            print('#### stats for %d not present' % (image_id))
            img_nmzd = htk_cnorm.reinhard(img, ref_mu_lab,
                                          ref_std_lab).astype('float')

        with open(pathPer, 'rb') as f:
            persistence = pickle.load(f)
        persistence = persistence / config.trainer.percentile_factor
        # Add a leading axis, then move it last to get a trailing channel
        # axis of size 1.
        persistence = np.moveaxis(np.array([persistence]), 0, 2)

        X_RGB[i] = preprocess_resnet(img_nmzd)
        X_Per[i] = persistence

    Y = list(test_outputs)

    # Two spaces after the colon reproduce the output of the former
    # ``print 'RGB : ', value`` statements.
    print('RGB :  %s' % (X_RGB.shape,))
    print('Per :  %s' % (X_Per.shape,))
    print('len(Y) :  %d' % len(Y))

    return [X_RGB, X_Per, Y]
Esempio n. 4
0
 def predict(self, x):
     """Run ``self.model`` on ``x`` after ResNet input preprocessing.

     When ``self.data_format`` equals ``"channels_first"`` the axes of ``x``
     are first reordered with ``transpose(0, 3, 1, 2)``.  The array is then
     cast to the Keras default float type and passed through
     ``preprocess_resnet``.

     Args:
         x: Input array supporting ``transpose`` and ``astype`` --
             presumably a 4-D channels-last image batch; confirm with the
             caller.

     Returns:
         Whatever ``self.model.predict`` returns for the prepared batch,
         evaluated with ``self.batch_size``.
     """
     channels_first = self.data_format == "channels_first"
     batch = x.transpose(0, 3, 1, 2) if channels_first else x
     as_float = batch.astype(K.floatx())
     prepared = preprocess_resnet(as_float)
     return self.model.predict(prepared, batch_size=self.batch_size)
Esempio n. 5
0
def RGBTestData(config):
    """Load the whole RGB test split, in random order, as one array.

    Images are read from ``<config.test_dir>/{malignant,benign}/rgb``,
    shuffled with the module-level ``random`` generator, colour-normalised
    and passed through ``preprocess_resnet``.

    Args:
        config: Configuration object providing ``test_dir`` and ``label``
            (with ``malignant`` and ``benign`` attributes).  When
            ``configs/stats.pkl`` exists, its content is also stored under
            ``config['stats']``.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(n, 256, 256, 3)`` and ``Y`` is a
        list of the ``n`` matching labels.  Both are empty when no test
        image is found.

    Raises:
        StopIteration: If one of the two ``rgb`` directories does not exist.
        IndexError: If an image path contains no ``<digits>.<digits>`` token.
    """
    print('loading RGB data')

    # next(os.walk(d)) yields ``d`` itself and fails when ``d`` is missing.
    path_mal_test, _, _ = next(
        os.walk(os.path.join(config.test_dir, 'malignant', 'rgb')))
    path_ben_test, _, _ = next(
        os.walk(os.path.join(config.test_dir, 'benign', 'rgb')))

    mal_paths_test = glob.glob(os.path.join(path_mal_test, '*'))
    ben_paths_test = glob.glob(os.path.join(path_ben_test, '*'))

    labelled = ([(p, config.label.malignant) for p in mal_paths_test] +
                [(p, config.label.benign) for p in ben_paths_test])

    # ``labelled`` is a real list, so the in-place shuffle works on both
    # Python 2 and Python 3 (where ``zip`` returns an iterator).
    random.shuffle(labelled)

    # Reference statistics handed to htk_cnorm.reinhard.
    ref_std_lab = (0.57506023, 0.10403329, 0.01364062)
    ref_mu_lab = (8.63234435, -0.11501964, 0.03868433)

    # Default to an empty mapping so a missing stats file falls back to
    # normalisation without source statistics instead of raising NameError.
    stats = {}
    if os.path.isfile('configs/stats.pkl'):
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open('configs/stats.pkl', 'rb') as f:
            stats = pickle.load(f)
        print('###################  Stats loaded Test ####################')
        config['stats'] = stats
    else:
        print('No stats file found (To obtain Mu and Sigma from original whole image).')

    # Iterating the pairs directly (instead of unzipping them) also handles
    # an empty test set, which previously failed on tuple unpacking.
    X = np.zeros((len(labelled), 256, 256, 3))
    Y = [target for _, target in labelled]

    id_pattern = re.compile(r"\d+\.\d+")

    for i, (img_path, _) in enumerate(labelled):
        img = skimage.io.imread(img_path)
        if img.shape == (1024, 1024, 3):
            # Keep every 4th pixel on both spatial axes -> 256 x 256 x 3.
            img = img[::4, ::4, :]

        # The image id is the first "<digits>.<digits>" token of the path.
        image_id = int(float(id_pattern.findall(img_path)[0]))

        if image_id in stats:
            src_mu, src_sigma = stats[image_id]
            img_nmzd = htk_cnorm.reinhard(img,
                                          ref_mu_lab,
                                          ref_std_lab,
                                          src_mu=src_mu,
                                          src_sigma=src_sigma).astype('float')
        else:
            print('#### stats for %d not present' % (image_id))
            img_nmzd = htk_cnorm.reinhard(img, ref_mu_lab,
                                          ref_std_lab).astype('float')

        X[i] = preprocess_resnet(img_nmzd)

    return (X, Y)