Beispiel #1
0
    def func_per_iteration(self, data, device):
        """Evaluate one sample with sliding-window inference.

        Args:
            data: per-sample dict; reads the ``'data'`` (input image),
                ``'label'`` (ground truth) and ``'fn'`` (sample name) keys.
            device: forwarded unchanged to ``self.sliding_eval``.

        Returns:
            dict with the ``'hist'``, ``'labeled'`` and ``'correct'`` values
            returned by ``hist_info`` for this sample.

        When ``self.save_path`` is set, the predicted class indices are also
        remapped through a fixed id table and written as a PNG named after the
        first three ``'_'``-separated fields of the sample name.
        """
        img = data['data']
        label = data['label']
        name = data['fn']

        pred = self.sliding_eval(img, config.eval_crop_size,
                                 config.eval_stride_rate, device)
        # pred = self.whole_eval(img, None, None, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            # Lookup table: predicted class index -> id written to disk.
            # NOTE(review): this looks like the Cityscapes trainId -> labelId
            # mapping; confirm against the dataset definition.
            trans_labels = np.array([
                7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
                31, 32, 33
            ], dtype=np.uint8)
            # One vectorized fancy-indexing pass instead of a Python loop
            # over np.unique(pred); uint8 is what cv2 would convert to anyway.
            remapped = trans_labels[pred]
            fn = '_'.join(name.split('_')[:3]) + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), remapped)
            logger.info('Save the image ' + fn)

        return results_dict
Beispiel #2
0
    def func_per_iteration(self, data, device):
        """Evaluate one sample with sliding-window inference.

        Args:
            data: per-sample dict; reads the ``'data'`` (input image),
                ``'label'`` (ground truth) and ``'fn'`` (sample name) keys.
            device: forwarded to ``self.sliding_eval`` as ``device=``.

        Returns:
            dict with the ``'hist'``, ``'labeled'`` and ``'correct'`` values
            returned by ``hist_info`` for this sample.

        Side effects: writes ``<fn>.png`` into ``self.save_path`` when it is
        set, and opens a blocking OpenCV window when ``self.show_image`` is set.
        """
        img = data['data']
        label = data['label']
        name = data['fn']

        pred = self.sliding_eval(img,
                                 config.eval_crop_size,
                                 config.eval_stride_rate,
                                 device=device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            fn = name + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            # get_class_colors must be *called*; referencing it without ()
            # hands show_img a bound method instead of the colour table.
            colors = self.dataset.get_class_colors()
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean, label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            # Block until a key is pressed so each sample can be inspected.
            cv2.waitKey(0)

        return results_dict
Beispiel #3
0
    def func_per_iteration(self, data, device, iter=None):
        """Evaluate one sample and optionally save, log and render the result.

        Args:
            data: per-sample dict; reads the ``'data'`` (input image),
                ``'label'`` (ground truth) and ``'fn'`` (sample name) keys.
            device: forwarded to ``self.whole_eval`` / ``self.sliding_eval``.
            iter: optional step index; when given and ``self.logger`` is set,
                a visualisation is sent to ``self.logger.add_image``.

        Returns:
            dict with the ``'hist'``, ``'labeled'`` and ``'correct'`` values
            returned by ``hist_info`` for this sample.
        """
        # Prefer the evaluator's own config and fall back to the module-level
        # ``config``. A separate local name is required: assigning to
        # ``config`` inside the function would make it local everywhere, so
        # with ``self.config is None`` every later ``config.`` access would
        # raise UnboundLocalError instead of reaching the module-level object.
        cfg = self.config if self.config is not None else config
        img = data['data']
        label = data['label']
        name = data['fn']

        # A single evaluation scale uses whole-image inference; multiple
        # scales go through sliding-window inference.
        if len(cfg.eval_scale_array) == 1:
            pred = self.whole_eval(img, None, device)
        else:
            pred = self.sliding_eval(img, cfg.eval_crop_size,
                                     cfg.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(cfg.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            fn = name + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        # tensorboard logger does not fit multiprocess
        if self.logger is not None and iter is not None:
            colors = self.dataset.get_class_colors()
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, cfg.background, img, clean, label,
                                pred)
            # Assumes comp_img is (H, W, C); the two swaps yield (C, H, W).
            self.logger.add_image(
                'vis', np.swapaxes(np.swapaxes(comp_img, 0, 2), 1, 2), iter)

        logger.debug("self.show_prediction = %s", self.show_prediction)
        if self.show_image or self.show_prediction:
            colors = self.dataset.get_class_colors()
            clean = np.zeros(label.shape)
            if self.show_image:
                comp_img = show_img(colors, cfg.background, img, clean,
                                    label, pred)
            else:
                comp_img = show_prediction(colors, cfg.background, img,
                                           pred)
            vis_dir = os.path.join(os.path.realpath('.'), cfg.save, "eval")
            # cv2.imwrite returns False rather than raising when the target
            # directory does not exist, so make sure it is there.
            os.makedirs(vis_dir, exist_ok=True)
            # Reverse the last (channel) axis before handing it to OpenCV.
            cv2.imwrite(os.path.join(vis_dir, name + ".vis.png"),
                        comp_img[:, :, ::-1])

        return results_dict
    def func_per_iteration(self, data, device):
        """Evaluate one RGB + depth sample with sliding-window inference.

        Args:
            data: per-sample dict; reads ``'data'`` (RGB image), ``'label'``
                (ground truth), ``'hha_img'`` (second input passed to
                ``self.sliding_eval_rgbdepth``; presumably the HHA depth
                encoding, confirm against the dataset) and ``'fn'`` (name).
            device: forwarded to ``self.sliding_eval_rgbdepth``.

        Returns:
            dict with the ``'hist'``, ``'labeled'`` and ``'correct'`` values
            returned by ``hist_info`` for this sample.

        When ``self.save_path`` is set, writes a palette-coloured PNG into
        ``<save_path>_color`` and the raw prediction into ``save_path``.
        """
        img = data['data']
        label = data['label']
        hha = data['hha_img']
        name = data['fn']
        pred = self.sliding_eval_rgbdepth(img, hha, config.eval_crop_size,
                                          config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            ensure_dir(self.save_path)
            ensure_dir(self.save_path + '_color')

            fn = name + '.png'

            # Save a colour-coded copy: store the class ids as a palette
            # ('P') image so each id is rendered with its class colour.
            result_img = Image.fromarray(pred.astype(np.uint8), mode='P')
            class_colors = get_class_colors()
            palette_list = list(np.array(class_colors).flat)
            # A full PIL palette is 256 RGB triples (768 values); pad the
            # unused entries with black.
            if len(palette_list) < 768:
                palette_list += [0] * (768 - len(palette_list))
            result_img.putpalette(palette_list)
            result_img.save(os.path.join(self.save_path + '_color', fn))

            # Save the raw class-id map as well.
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            # get_class_colors must be *called*; referencing it without ()
            # hands show_img a bound method instead of the colour table.
            colors = self.dataset.get_class_colors()
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean, label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            # Block until a key is pressed so each sample can be inspected.
            cv2.waitKey(0)

        return results_dict
Beispiel #5
0
    def func_per_iteration(self, data, device):
        """Evaluate one sample with whole-image inference at a fixed size.

        The input image is resized to ``(config.image_width,
        config.image_height)`` and the label to that size divided by
        ``config.gt_down_sampling`` before ``hist_info`` is computed.

        Args:
            data: per-sample dict; reads the ``'data'`` (input image),
                ``'label'`` (ground truth) and ``'fn'`` (sample name) keys.
            device: forwarded to ``self.whole_eval``.

        Returns:
            dict with the ``'hist'``, ``'labeled'`` and ``'correct'`` values
            returned by ``hist_info`` for this sample.
        """
        img = data['data']
        label = data['label']
        name = data['fn']

        # cv2.resize takes dsize as (width, height).
        img = cv2.resize(img, (config.image_width, config.image_height),
                         interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps label values as exact class ids (no blending).
        label = cv2.resize(label,
                           (config.image_width // config.gt_down_sampling,
                            config.image_height // config.gt_down_sampling),
                           interpolation=cv2.INTER_NEAREST)

        # NOTE(review): config.eval_stride_rate is passed positionally as the
        # third argument of whole_eval; confirm it matches that signature.
        pred = self.whole_eval(img,
                               (config.image_height // config.gt_down_sampling,
                                config.image_width // config.gt_down_sampling),
                               config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            fn = name + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            # get_class_colors must be *called*; referencing it without ()
            # hands show_img a bound method instead of the colour table.
            colors = self.dataset.get_class_colors()
            # NOTE(review): img is full-size while label/pred are downsampled
            # by gt_down_sampling; confirm show_img tolerates the mismatch.
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean, label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            # Block until a key is pressed so each sample can be inspected.
            cv2.waitKey(0)

        return results_dict