Example #1
0
    def visualize_pred(self, dataset_reader):
        """
        Predict segmentation of images randomly selected from dataset_reader.

        Runs the model on one random batch and, for every sample, writes the
        input image, the reconstructed image and the colorized prediction to
        ``self.flags.logs_dir`` as ``inp_<i>``, ``recon_<i>`` and
        ``pred_<i>``.

        Args:
            dataset_reader: object whose ``get_random_batch(batch_size)``
                returns an ``(images, annotations)`` pair; only the images
                are used here.

        Returns:
            Tuple ``(valid_images, pred)``: the input batch and the colorized
            predictions (first three colour channels only).
        """
        valid_images, _ = dataset_reader.get_random_batch(
            self.flags.batch_size)

        feed_dict = {
            self.image: valid_images,
            # presumably a dropout keep-probability; 1.0 keeps every unit
            self.keep_probability: 1.0,
            self.phase_train: False
        }
        reconst_image, pred = self.sess.run(
            [self.reconstruct_image, self.pred_annotation],
            feed_dict=feed_dict)
        # Remove the size-1 axis 3 so each sample is a 2-D label map.
        pred = np.squeeze(pred, axis=3)
        # Colorize class ids in [0, num_class) and keep only 3 channels.
        pred = utils.batch_colorize_ndarray(pred, 0, self.flags.num_class,
                                            self.flags.cmap)[:, :, :, :3]

        # Iterate over the batch actually returned, not the requested size,
        # so a shorter batch from the reader cannot raise IndexError.
        for itr in range(len(valid_images)):
            for prefix, batch in (("inp_", valid_images),
                                  ("recon_", reconst_image),
                                  ("pred_", pred)):
                utils.save_image(batch[itr].astype(np.uint8),
                                 self.flags.logs_dir,
                                 name=prefix + str(itr))
            print("Saved image: %d" % itr)

        return valid_images, pred
Example #2
0
    def plot_segmentation_under_test_dir(self):
        """
        Predict and plot segmentations for every file in ``flags.test_dir``.

        Each file is read, resized to ``image_size x image_size`` with
        bilinear interpolation, and segmented with ``predict_segmentation``.
        For every image, a side-by-side figure (input | colorized prediction)
        is saved to ``flags.logs_dir`` as ``Figure_<i>.png``.

        Returns:
            ``(test_images, test_preds)``: the stacked resized inputs and the
            predictions from ``predict_segmentation``. Returns
            ``(None, None)`` when the directory contains no files.
        """
        image_pattern = os.path.join(self.flags.test_dir, '*')
        # glob() yields entries in arbitrary order; sort so that the
        # Figure_<i> numbering (and returned order) is reproducible.
        image_lst = sorted(glob(image_pattern))
        if not image_lst:
            print('No files found')
            # Early return: the arrays below are never built in this case.
            return None, None

        test_images = np.stack([
            misc.imresize(imageio.imread(path),
                          [self.flags.image_size, self.flags.image_size],
                          interp='bilinear') for path in image_lst
        ])
        test_preds = self.predict_segmentation(test_images)
        colorized_test_preds = utils.batch_colorize_ndarray(
            test_preds, 0, self.flags.num_class,
            self.flags.cmap)[:, :, :, :3]
        for i, (imag, pred) in enumerate(zip(test_images,
                                             colorized_test_preds)):
            fig, axes = plt.subplots(1, 2)
            axes[0].imshow(imag)
            axes[1].imshow(pred)
            axes[0].axis('off')
            axes[1].axis('off')
            filename = os.path.join(self.flags.logs_dir, 'Figure_%d.png' % i)
            fig.savefig(filename, dpi=300, format="png", transparent=False)
            # pyplot keeps every figure alive until closed; close it so a
            # large test directory does not accumulate open figures.
            plt.close(fig)
        return test_images, test_preds
Example #3
0
    def plot_segmentation_on_test(self, test_dataset_reader):
        """
        Plot input / prediction / ground-truth triplets for a test set.

        Segments ``test_dataset_reader.images`` with ``predict_segmentation``
        and, for every sample, saves a three-panel figure titled IMG, PRED
        and GT to ``flags.logs_dir`` as ``Figure_<i>.png``.

        Args:
            test_dataset_reader: object exposing ``images`` and
                ``annotations``; a falsy value means "no test data".

        Returns:
            ``(test_images, test_preds)``, or ``(None, None)`` when
            ``test_dataset_reader`` is falsy.
        """
        if not test_dataset_reader:
            print('No test data found')
            # Early return: nothing to predict, so the results are undefined.
            return None, None

        test_images = test_dataset_reader.images
        test_gts = test_dataset_reader.annotations
        test_preds = self.predict_segmentation(test_images)
        colorized_test_preds = utils.batch_colorize_ndarray(
            test_preds, 0, self.flags.num_class, self.flags.cmap)[:, :, :, :3]
        titles = ('IMG', 'PRED', 'GT')
        for i, (imag, pred, gt) in enumerate(
                zip(test_images, colorized_test_preds, test_gts)):
            fig, axes = plt.subplots(1, 3)
            # NOTE(review): gt is displayed uncolorized; if annotations are
            # one-hot or (H, W, 1) arrays imshow may reject or misrender
            # them — confirm the annotation format.
            for ax, panel, title in zip(axes, (imag, pred, gt), titles):
                ax.imshow(panel)
                ax.axis('off')
                ax.set_title(title)
            filename = os.path.join(self.flags.logs_dir, 'Figure_%d.png' % i)
            fig.savefig(filename, dpi=300, format="png", transparent=False)
            # Close the figure so iterating a large test set does not leak.
            plt.close(fig)
        return test_images, test_preds
Example #4
0
    def visaulize_pred(self, dataset_reader):
        """
        Predict segmentation of images randomly selected from dataset_reader.

        Runs the model on one random batch and saves, for every sample, the
        input image, the colorized ground truth and the colorized prediction
        to ``self.flags.logs_dir`` as ``inp_<k>``, ``gt_<k>`` and
        ``pred_<k>``, where ``k = 5 + sample index``.

        Args:
            dataset_reader: object whose ``get_random_batch(batch_size)``
                returns an ``(images, annotations)`` pair.

        Returns:
            Tuple ``(valid_images, pred)``: the input batch and the colorized
            predictions (first three colour channels only).
        """
        valid_images, valid_annotations = dataset_reader.get_random_batch(
            self.flags.batch_size)
        feed_dict = {
            self.image: valid_images,
            # presumably a dropout keep-probability; 1.0 keeps every unit
            self.keep_probability: 1.0,
            self.phase_train: False
        }

        # sess.run with a one-element fetch list returns a one-element list.
        # Unpack it so the squeeze below sees the prediction array itself;
        # squeezing the list would add a leading axis and make axis=3 the
        # wrong (non-singleton) axis.
        (pred,) = self.sess.run([self.pred_annotation], feed_dict=feed_dict)
        pred = np.squeeze(pred, axis=3)
        pred = utils.batch_colorize_ndarray(pred, 0, self.flags.num_class,
                                            self.flags.cmap)[:, :, :, :3]
        # Reduce the last axis to a class id per pixel (presumably the
        # annotations are one-hot / per-class scores — confirm), then colorize.
        valid_annotations = np.argmax(valid_annotations, axis=-1)
        valid_annotations = utils.batch_colorize_ndarray(
            valid_annotations, 0, self.flags.num_class,
            self.flags.cmap)[:, :, :, :3]

        # NOTE(review): saved file indices start at 5; the reason for this
        # offset is not visible here — confirm before changing it.
        name_offset = 5
        # Iterate over the batch actually returned, not the requested size.
        for itr in range(len(valid_images)):
            for prefix, batch in (("inp_", valid_images),
                                  ("gt_", valid_annotations),
                                  ("pred_", pred)):
                utils.save_image(batch[itr].astype(np.uint8),
                                 self.flags.logs_dir,
                                 name=prefix + str(name_offset + itr),
                                 mean=1)
            print("Saved image: %d" % itr)

        return valid_images, pred