def evaluate(self, pred_dir, gt_dir):
        """Score every prediction in ``pred_dir`` against its ground truth.

        Reads each prediction/label pair as 'P'-mode (palette) images,
        remaps labels for datasets that require it, and accumulates the
        running segmentation score, logging mIoU and pixel accuracy at
        the end.
        """
        num_images = 0
        for name in os.listdir(pred_dir):
            print(name)

            pred = ImageHelper.img2np(
                ImageHelper.read_image(os.path.join(pred_dir, name),
                                       tool='pil', mode='P'))
            gt = ImageHelper.img2np(
                ImageHelper.read_image(os.path.join(gt_dir, name),
                                       tool='pil', mode='P'))

            # These datasets store label ids that must be remapped first.
            if any(tag in gt_dir for tag in ("pascal_context", "ade", "coco_stuff")):
                pred = self.relabel(pred)
                gt = self.relabel(gt)

            # COCO-Stuff: class 0 is "unlabeled" and must be ignored (255).
            if "coco_stuff" in gt_dir:
                gt[gt == 0] = 255

            self.seg_running_score.update(pred[np.newaxis, :, :],
                                          gt[np.newaxis, :, :])
            num_images += 1

        Log.info('Evaluate {} images'.format(num_images))
        Log.info('mIOU: {}'.format(self.seg_running_score.get_mean_iou()))
        Log.info('Pixel ACC: {}'.format(
            self.seg_running_score.get_pixel_acc()))
Exemplo n.º 2
0
    def __getitem__(self, index):
        """Return one test sample, skipping unreadable images.

        If the image at ``index`` cannot be loaded, the sample is marked
        invalid and subsequent indices (wrapping around) are tried until a
        readable image is found; ``meta['valid']`` reports the outcome.
        """
        img = None
        valid = True
        while img is None:
            try:
                img = ImageHelper.read_image(self.item_list[index][0],
                                             tool=self.configer.get('data', 'image_tool'),
                                             mode=self.configer.get('data', 'input_mode'))
                assert isinstance(img, (np.ndarray, Image.Image))
            except Exception:
                # Narrowed from a bare `except:` so SystemExit and
                # KeyboardInterrupt propagate instead of being swallowed
                # inside this retry loop.
                Log.warn('Invalid image path: {}'.format(self.item_list[index][0]))
                img = None
                valid = False
                index = (index + 1) % len(self.item_list)

        ori_img_size = ImageHelper.get_size(img)
        if self.aug_transform is not None:
            img = self.aug_transform(img)

        # Size after augmentation, reversed to (height, width).
        border_hw = ImageHelper.get_size(img)[::-1]
        if self.img_transform is not None:
            img = self.img_transform(img)

        meta = dict(
            valid=valid,
            ori_img_size=ori_img_size,
            border_hw=border_hw,
            img_path=self.item_list[index][0],
            filename=self.item_list[index][1],
            label=self.item_list[index][2]
        )
        return dict(
            img=DataContainer(img, stack=True),
            meta=DataContainer(meta, stack=False, cpu_only=True)
        )
Exemplo n.º 3
0
    def _get_batch_per_gpu(self, cur_index):
        """Assemble one per-GPU batch of images/labels of matching orientation.

        Starts with the sample at ``cur_index`` and keeps drawing random
        other indices until ``train.batch_per_gpu`` images are collected,
        accepting only images whose orientation (landscape vs. portrait,
        from the cached (width, height) in ``self.size_list``) matches the
        first image's.

        NOTE(review): if no other image shares the first image's
        orientation, the inner ``while True`` loop never terminates —
        confirm the dataset always contains at least one match.
        """
        img = ImageHelper.read_image(
            self.img_list[cur_index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        labelmap = ImageHelper.read_image(self.label_list[cur_index],
                                          tool=self.configer.get(
                                              'data', 'image_tool'),
                                          mode='P')
        img_size = self.size_list[cur_index]
        img_out = [img]
        label_out = [labelmap]
        for i in range(self.configer.get('train', 'batch_per_gpu') - 1):
            while True:
                # Jump a random distance forward (mod list length) so the
                # same index is never drawn twice in a row.
                cur_index = (cur_index + random.randint(
                    1,
                    len(self.img_list) - 1)) % len(self.img_list)
                now_img_size = self.size_list[cur_index]
                # mark 0 = landscape (w > h), 1 = portrait/square.
                now_mark = 0 if now_img_size[0] > now_img_size[1] else 1
                mark = 0 if img_size[0] > img_size[1] else 1
                if now_mark == mark:
                    img = ImageHelper.read_image(
                        self.img_list[cur_index],
                        tool=self.configer.get('data', 'image_tool'),
                        mode=self.configer.get('data', 'input_mode'))
                    img_out.append(img)
                    labelmap = ImageHelper.read_image(
                        self.label_list[cur_index],
                        tool=self.configer.get('data', 'image_tool'),
                        mode='P')
                    label_out.append(labelmap)
                    break

        return img_out, label_out
Exemplo n.º 4
0
 def __init__(self, test_dir=None, aug_transform=None, img_transform=None, configer=None):
     """Index every image file directly under ``test_dir``.

     Each entry of ``item_list`` is (absolute image path, bare filename).
     """
     super(TestDefaultDataset, self).__init__()
     self.configer = configer
     self.aug_transform=aug_transform
     self.img_transform = img_transform
     # Keep only directory entries that look like images.
     self.item_list = [(os.path.abspath(os.path.join(test_dir, filename)), filename)
                       for filename in FileHelper.list_dir(test_dir) if ImageHelper.is_img(filename)]
Exemplo n.º 5
0
    def __getitem__(self, index):
        """Return one (image, label) sample, skipping unreadable images.

        Unreadable paths are logged and the next index (wrapping around)
        is tried; the returned ``valid`` flag reports whether the original
        index was usable.
        """
        img = None
        valid = True
        while img is None:
            try:
                img = ImageHelper.read_image(
                    self.img_list[index],
                    tool=self.configer.get('data', 'image_tool'),
                    mode=self.configer.get('data', 'input_mode'))
                assert isinstance(img, (np.ndarray, Image.Image))
            except Exception:
                # Narrowed from a bare `except:` so SystemExit and
                # KeyboardInterrupt propagate instead of being swallowed
                # inside this retry loop.
                Log.warn('Invalid image path: {}'.format(self.img_list[index]))
                img = None
                valid = False
                index = (index + 1) % len(self.img_list)

        label = torch.from_numpy(np.array(self.label_list[index]))
        if self.aug_transform is not None:
            img = self.aug_transform(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(valid=valid,
                    img=DataContainer(img, stack=True),
                    label=DataContainer(label, stack=True))
Exemplo n.º 6
0
    def mscrop_test(self, ori_image):
        """Multi-scale + horizontal-flip sliding-crop test.

        For every scale in ``test.scale_search``, predicts the image
        (using cropped prediction when it exceeds ``test.crop_size``) and
        its mirrored version, resizes both logit maps back to the original
        resolution, and accumulates them.

        Returns the summed logits of shape (ori_height, ori_width, num_classes).
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        crop_size = self.configer.get('test', 'crop_size')
        total_logits = np.zeros((ori_height, ori_width, self.configer.get('data', 'num_classes')), np.float32)
        for scale in self.configer.get('test', 'scale_search'):
            image, border_hw = self._get_blob(ori_image, scale=scale)
            if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
                results = self._crop_predict(image, crop_size)
            else:
                results = self._predict(image)

            results = cv2.resize(results[:border_hw[0], :border_hw[1]],
                                 (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
            total_logits += results

            if self.configer.get('data', 'image_tool') == 'cv2':
                mirror_image = cv2.flip(ori_image, 1)
            else:
                mirror_image = ori_image.transpose(Image.FLIP_LEFT_RIGHT)

            # BUGFIX: the mirrored pass previously always used scale=1.0,
            # which accumulated len(scale_search) identical mirror
            # predictions instead of one mirror prediction per scale.
            image, border_hw = self._get_blob(mirror_image, scale=scale)
            if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
                results = self._crop_predict(image, crop_size)
            else:
                results = self._predict(image)

            results = results[:border_hw[0], :border_hw[1]]
            # Un-mirror the logits (flip the width axis) before accumulating.
            results = cv2.resize(results[:, ::-1], (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
            total_logits += results

        return total_logits
Exemplo n.º 7
0
    def __read_file(self, data_dir, dataset):
        """Read '<relative path> <label...>' lines for ``dataset``.

        Returns (img_list, mlabel_list): absolute image paths and their
        integer multi-label lists. Lines whose image is missing or not an
        image file are warned about and skipped; blank lines are ignored.
        """
        img_list = list()
        mlabel_list = list()
        all_img_list = []
        with open(self.configer.get('data.{}_label_path'.format(dataset)),
                  'r') as file_stream:
            all_img_list += file_stream.readlines()

        # Optionally fold the validation split into the training list.
        if dataset == 'train' and self.configer.get('data.include_val',
                                                    default=False):
            with open(self.configer.get('data.val_label_path'),
                      'r') as file_stream:
                all_img_list += file_stream.readlines()

        # Iterate lines directly (was an index loop); the unused img_dict
        # local has been removed.
        for line in all_img_list:
            line_items = line.strip().split()
            if len(line_items) == 0:
                continue

            path = line_items[0]
            full_path = os.path.join(data_dir, path)
            if not os.path.exists(full_path) or not ImageHelper.is_img(path):
                Log.warn('Invalid Image Path: {}'.format(full_path))
                continue

            img_list.append(full_path)
            mlabel_list.append([int(item) for item in line_items[1:]])

        assert len(img_list) > 0
        Log.info('Length of {} imgs is {}...'.format(dataset, len(img_list)))
        return img_list, mlabel_list
Exemplo n.º 8
0
    def __list_dirs(self, root_dir, dataset):
        """Collect (image path, label path, image size) lists for ``dataset``.

        Pairs every file in '<root>/<dataset>/label' with the same-named
        image in '<root>/<dataset>/image' (extension inferred from the
        first image file). When ``data.include_val`` is set for 'train',
        the 'val' split is appended as well.
        """
        img_list = list()
        label_list = list()
        size_list = list()

        def _scan(image_dir, label_dir, img_extension):
            # Collect entries from one split directory into the shared lists.
            for file_name in os.listdir(label_dir):
                image_name = '.'.join(file_name.split('.')[:-1])
                img_path = os.path.join(
                    image_dir, '{}.{}'.format(image_name, img_extension))
                label_path = os.path.join(label_dir, file_name)
                if not os.path.exists(label_path) or not os.path.exists(img_path):
                    Log.error('Label Path: {} not exists.'.format(label_path))
                    continue

                img_list.append(img_path)
                label_list.append(label_path)
                # Image sizes are cached up front for batch assembly later.
                img = ImageHelper.read_image(
                    img_path,
                    tool=self.configer.get('data', 'image_tool'),
                    mode=self.configer.get('data', 'input_mode'))
                size_list.append(ImageHelper.get_size(img))

        image_dir = os.path.join(root_dir, dataset, 'image')
        label_dir = os.path.join(root_dir, dataset, 'label')
        # Infer the image extension from the first file in the image dir.
        img_extension = os.listdir(image_dir)[0].split('.')[-1]
        _scan(image_dir, label_dir, img_extension)

        if dataset == 'train' and self.configer.get('data', 'include_val'):
            _scan(os.path.join(root_dir, 'val/image'),
                  os.path.join(root_dir, 'val/label'),
                  img_extension)

        return img_list, label_list, size_list
Exemplo n.º 9
0
    def __getitem__(self, index):
        """Return one training sample as DataContainers.

        Loads the image and its palette ('P'-mode) label map, optionally
        re-encodes labels via ``data.label_list`` and shifts them with
        ``reduce_zero_label``, then applies (in order) the torchvision-style
        transform, joint augmentation, image transform, and label transform.
        """
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        img_size = ImageHelper.get_size(img)
        labelmap = ImageHelper.read_image(self.label_list[index],
                                          tool=self.configer.get(
                                              'data', 'image_tool'),
                                          mode='P')
        if self.configer.exists('data', 'label_list'):
            labelmap = self._encode_label(labelmap)

        if self.configer.exists('data', 'reduce_zero_label'):
            labelmap = self._reduce_zero_label(labelmap)

        # Keep an untransformed copy of the target; ignore label 255 -> -1.
        ori_target = ImageHelper.tonp(labelmap)
        ori_target[ori_target == 255] = -1

        # PIL round-trip because torch_img_transform expects a PIL image.
        if self.torch_img_transform is not None:
            img = Image.fromarray(img)
            img = self.torch_img_transform(img)
            img = np.array(img).astype(np.uint8)

        if self.aug_transform is not None:
            img, labelmap = self.aug_transform(img, labelmap=labelmap)

        # Size after augmentation (may differ from img_size due to padding).
        border_size = ImageHelper.get_size(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.label_transform is not None:
            labelmap = self.label_transform(labelmap)

        meta = dict(ori_img_size=img_size,
                    border_size=border_size,
                    ori_target=ori_target)
        return dict(
            img=DataContainer(img, stack=self.is_stack),
            labelmap=DataContainer(labelmap, stack=self.is_stack),
            meta=DataContainer(meta, stack=False, cpu_only=True),
            name=DataContainer(self.name_list[index],
                               stack=False,
                               cpu_only=True),
        )
Exemplo n.º 10
0
 def ss_test(self, ori_image):
     """Single-scale test: predict once at scale 1.0 and resize the logits
     back to the original image resolution.

     Returns logits of shape (ori_height, ori_width, num_classes).
     """
     width, height = ImageHelper.get_size(ori_image)
     num_classes = self.configer.get('data', 'num_classes')
     logits_sum = np.zeros((height, width, num_classes), np.float32)
     blob, border_hw = self._get_blob(ori_image, scale=1.0)
     raw = self._predict(blob)
     # Drop padding beyond the valid border, then resize to original size.
     valid = raw[:border_hw[0], :border_hw[1]]
     logits_sum += cv2.resize(valid, (width, height),
                              interpolation=cv2.INTER_CUBIC)
     return logits_sum
Exemplo n.º 11
0
 def __getitem__(self, index):
     """Return one inference-only sample (no label) as DataContainers.

     ``border_size`` equals the original size since no augmentation runs here.
     """
     img = ImageHelper.read_image(
         self.img_list[index],
         tool=self.configer.get('data', 'image_tool'),
         mode=self.configer.get('data', 'input_mode'))
     img_size = ImageHelper.get_size(img)
     if self.img_transform is not None:
         img = self.img_transform(img)
     meta = dict(
         ori_img_size=img_size,
         border_size=img_size,
     )
     return dict(
         img=DataContainer(img, stack=self.is_stack),
         meta=DataContainer(meta, stack=False, cpu_only=True),
         name=DataContainer(self.name_list[index],
                            stack=False,
                            cpu_only=True),
     )
Exemplo n.º 12
0
    def _reduce_zero_label(self, labelmap):
        """Shift every label down by one so class 0 becomes the ignore value.

        With uint8 label maps, label 0 wraps around to 255 (the ignore
        index). Returns the input unchanged when 'reduce_zero_label' is
        disabled; converts back to a PIL image when the image tool is 'pil'.
        """
        if not self.configer.get('data', 'reduce_zero_label'):
            return labelmap

        shifted = np.array(labelmap) - 1
        if self.configer.get('data', 'image_tool') == 'pil':
            return ImageHelper.np2img(shifted.astype(np.uint8))
        return shifted
    def _mp_target(self, inp):
        """Worker for multiprocess evaluation of a single prediction file.

        ``inp`` is a (filename, pred_dir, gt_dir) tuple. Returns the
        confusion histogram from ``seg_running_score.hist``, or ``0.`` when
        either image fails to load/encode (best-effort: the error is
        printed and the file contributes nothing to the score).
        """
        filename, pred_dir, gt_dir = inp
        print(filename)

        pred_path = os.path.join(pred_dir, filename)
        gt_path = os.path.join(gt_dir, filename)
        try:
            predmap = self._encode_label(
                ImageHelper.img2np(
                    ImageHelper.read_image(pred_path, tool='pil', mode='P')))
            gtmap = self._encode_label(
                ImageHelper.img2np(
                    ImageHelper.read_image(gt_path, tool='pil', mode='P')))
        except Exception as e:
            print(e)
            return 0.

        # These datasets store label ids that must be remapped before scoring.
        if "pascal_context" in gt_dir or "ADE" in gt_dir:
            predmap = self.relabel(predmap)
            gtmap = self.relabel(gtmap)

        return self.seg_running_score.hist(predmap[np.newaxis, :, :],
                                           gtmap[np.newaxis, :, :])
Exemplo n.º 14
0
    def _encode_label(self, labelmap):
        """Map raw class ids to contiguous training ids via ``data.label_list``.

        Pixels whose value does not appear in ``label_list`` stay at 255
        (the ignore index). When the image tool is 'pil', the encoded map
        is converted back to a PIL image.
        """
        labelmap = np.array(labelmap)

        shape = labelmap.shape
        encoded_labelmap = np.ones(shape=(shape[0], shape[1]),
                                   dtype=np.float32) * 255
        # Hoist the loop-invariant config lookup (was fetched twice per
        # iteration) and enumerate instead of indexing by range(len(...)).
        label_list = self.configer.get('data', 'label_list')
        for i, class_id in enumerate(label_list):
            encoded_labelmap[labelmap == class_id] = i

        if self.configer.get('data', 'image_tool') == 'pil':
            encoded_labelmap = ImageHelper.np2img(
                encoded_labelmap.astype(np.uint8))

        return encoded_labelmap
Exemplo n.º 15
0
    def __getitem__(self, index):
        """Return one sample with segmentation label and edge/boundary mask.

        Loads the image, palette label map, and edge map; binarizes and
        aligns the edge map to the label resolution; applies label
        re-encoding, joint augmentation, and the image/label transforms;
        packs everything into DataContainers.
        """
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        img_size = ImageHelper.get_size(img)
        labelmap = ImageHelper.read_image(self.label_list[index],
                                          tool=self.configer.get(
                                              'data', 'image_tool'),
                                          mode='P')
        edgemap = ImageHelper.read_image(self.edge_list[index],
                                         tool=self.configer.get(
                                             'data', 'image_tool'),
                                         mode='P')

        # Binarize the edge map (255 -> 1), then resize it to the label
        # map's (width, height). Assumes numpy arrays here — TODO confirm
        # this path is only reached with image_tool == 'cv2'.
        edgemap[edgemap == 255] = 1
        edgemap = cv2.resize(edgemap, (labelmap.shape[-1], labelmap.shape[-2]),
                             interpolation=cv2.INTER_NEAREST)

        if self.configer.exists('data', 'label_list'):
            labelmap = self._encode_label(labelmap)

        # NOTE(review): compares against the *string* 'True'; sibling
        # loaders in this project compare truthily or with `== True` —
        # confirm the config value's actual type.
        if self.configer.exists('data',
                                'reduce_zero_label') and self.configer.get(
                                    'data', 'reduce_zero_label') == 'True':
            labelmap = self._reduce_zero_label(labelmap)

        # Keep an untransformed target copy; ignore label 255 becomes -1.
        ori_target = ImageHelper.tonp(labelmap)
        ori_target[ori_target == 255] = -1

        if self.aug_transform is not None:
            img, labelmap, edgemap = self.aug_transform(img,
                                                        labelmap=labelmap,
                                                        maskmap=edgemap)

        # Size after augmentation (may differ from img_size due to padding).
        border_size = ImageHelper.get_size(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.label_transform is not None:
            labelmap = self.label_transform(labelmap)
            edgemap = self.label_transform(edgemap)

        meta = dict(ori_img_size=img_size,
                    border_size=border_size,
                    ori_target=ori_target)
        return dict(
            img=DataContainer(img, stack=True),
            labelmap=DataContainer(labelmap, stack=True),
            maskmap=DataContainer(edgemap, stack=True),
            meta=DataContainer(meta, stack=False, cpu_only=True),
            name=DataContainer(self.name_list[index],
                               stack=False,
                               cpu_only=True),
        )
Exemplo n.º 16
0
    def __read_list(self, data_dir, list_path):
        """Parse '<filename> [label]' lines from ``list_path``.

        Returns a list of (absolute image path, filename, label-or-None)
        tuples. Exits the process on any invalid image path (list files
        are expected to be consistent with the data directory); blank
        lines are skipped instead of crashing.
        """
        item_list = []
        with open(list_path, 'r') as fr:
            for line in fr.readlines():
                # Split once instead of re-splitting the line three times.
                parts = line.strip().split()
                if not parts:
                    continue

                filename = parts[0]
                label = parts[1] if len(parts) > 1 else None
                img_path = os.path.join(data_dir, filename)
                if not os.path.exists(img_path) or not ImageHelper.is_img(img_path):
                    Log.error('Image Path: {} is Invalid.'.format(img_path))
                    exit(1)

                item_list.append((img_path, filename, label))

        Log.info('There are {} images..'.format(len(item_list)))
        return item_list
Exemplo n.º 17
0
    def load_boundary(self, fn):
        """Load a boundary map from ``fn``.

        Two formats are supported:
        - '.mat' files: if the file contains a 'depth' key, derive the
          boundary from the distance maps via DTOffsetHelper; otherwise use
          the stored 'mat' array (transposed channel-first -> HWC).
        - any other file: read as a palette image and normalize to [0, 1].
        """
        if fn.endswith('mat'):
            mat = io.loadmat(fn)
            if 'depth' in mat:
                dist_map, _ = self._load_maps(fn, None)
                boundary_map = DTOffsetHelper.distance_to_mask_label(
                    dist_map, np.zeros_like(dist_map)).astype(np.float32)
            else:
                # CHW -> HWC; assumes the stored array is channel-first —
                # TODO confirm against the .mat writer.
                boundary_map = mat['mat'].transpose(1, 2, 0)
        else:
            boundary_map = ImageHelper.read_image(fn,
                                                  tool=self.configer.get(
                                                      'data', 'image_tool'),
                                                  mode='P')
            # Scale 0..255 palette values down to 0..1 floats.
            boundary_map = boundary_map.astype(np.float32) / 255

        return boundary_map
Exemplo n.º 18
0
    def __read_file(self, root_dir, dataset, label_path):
        """Read '<relative path> <label...>' entries from ``label_path``.

        Returns (img_list, mlabel_list): absolute image paths and their
        integer multi-label lists. Missing or non-image paths are warned
        about and skipped.
        """
        img_list = list()
        mlabel_list = list()

        with open(label_path, 'r') as file_stream:
            for line in file_stream.readlines():
                items = line.rstrip().split()
                rel_path = items[0]
                abs_path = os.path.join(root_dir, rel_path)
                if not os.path.exists(abs_path) or not ImageHelper.is_img(rel_path):
                    Log.warn('Invalid Image Path: {}'.format(abs_path))
                    continue

                img_list.append(abs_path)
                mlabel_list.append([int(v) for v in items[1:]])

        assert len(img_list) > 0
        Log.info('Length of {} imgs is {}...'.format(dataset, len(img_list)))
        return img_list, mlabel_list
Exemplo n.º 19
0
    def __test_img(self, image_path, label_path, vis_path, raw_path):
        """Run one image through the configured test mode and save results.

        Saves three artifacts: a colorized overlay to ``vis_path``, the raw
        input image to ``raw_path``, and the palette-mode label map to
        ``label_path``. Exits the process on an unknown test mode.
        """
        Log.info('Image Path: {}'.format(image_path))
        ori_image = ImageHelper.read_image(image_path,
                                           tool=self.configer.get('data', 'image_tool'),
                                           mode=self.configer.get('data', 'input_mode'))
        total_logits = None
        # Dispatch on test mode: single/multi scale, optionally cropped.
        if self.configer.get('test', 'mode') == 'ss_test':
            total_logits = self.ss_test(ori_image)

        elif self.configer.get('test', 'mode') == 'sscrop_test':
            total_logits = self.sscrop_test(ori_image)

        elif self.configer.get('test', 'mode') == 'ms_test':
            total_logits = self.ms_test(ori_image)

        elif self.configer.get('test', 'mode') == 'mscrop_test':
            total_logits = self.mscrop_test(ori_image)

        else:
            Log.error('Invalid test mode:{}'.format(self.configer.get('test', 'mode')))
            exit(1)

        # Per-pixel argmax over the class axis gives the predicted labels.
        label_map = np.argmax(total_logits, axis=-1)
        label_img = np.array(label_map, dtype=np.uint8)
        ori_img_bgr = ImageHelper.get_cv2_bgr(ori_image, mode=self.configer.get('data', 'input_mode'))
        image_canvas = self.seg_parser.colorize(label_img, image_canvas=ori_img_bgr)
        ImageHelper.save(image_canvas, save_path=vis_path)
        ImageHelper.save(ori_image, save_path=raw_path)

        if self.configer.exists('data', 'label_list'):
            label_img = self.__relabel(label_img)

        # Undo the reduce_zero_label shift so saved labels match the dataset.
        if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
            label_img = label_img + 1
            label_img = label_img.astype(np.uint8)

        label_img = Image.fromarray(label_img, 'P')
        Log.info('Label Path: {}'.format(label_path))
        ImageHelper.save(label_img, label_path)
Exemplo n.º 20
0
    def __read_and_split_file(self, root_dir, dataset, label_path):
        """Deterministically split one labeled list file into train/val.

        Lines are grouped by class label and re-ordered by label so the
        split is stratified and stable. Every ``select_interval``-th line
        of that ordering goes to 'val'; the rest go to 'train' (unless
        ``data.include_val`` keeps all lines in 'train').

        Returns (img_list, mlabel_list) of absolute paths and integer
        multi-label lists.
        """
        img_list = list()
        mlabel_list = list()
        select_interval = int(1 / self.configer.get('data', 'val_ratio'))

        # Group lines by class label; blank/malformed lines are skipped
        # (previously they crashed with an IndexError).
        img_dict = dict()
        with open(label_path, 'r') as file_stream:
            for line in file_stream.readlines():
                line_items = line.strip().split()
                if len(line_items) < 2:
                    continue
                img_dict.setdefault(int(line_items[1]), []).append(line)

        all_img_list = []
        for key in sorted(img_dict.keys()):
            all_img_list += img_dict[key]

        # Hoist the loop-invariant config lookup out of the per-line loop.
        include_val = self.configer.get('data', 'include_val')
        for line_cnt, line in enumerate(all_img_list):
            is_val_line = (line_cnt % select_interval == 0)
            if is_val_line and dataset == 'train' and not include_val:
                continue

            if not is_val_line and dataset == 'val':
                continue

            line_items = line.strip().split()
            path = line_items[0]
            full_path = os.path.join(root_dir, path)
            if not os.path.exists(full_path) or not ImageHelper.is_img(path):
                Log.warn('Invalid Image Path: {}'.format(full_path))
                continue

            img_list.append(full_path)
            mlabel_list.append([int(item) for item in line_items[1:]])

        assert len(img_list) > 0
        Log.info('Length of {} imgs is {} after split trainval...'.format(dataset, len(img_list)))
        return img_list, mlabel_list
Exemplo n.º 21
0
    def __getitem__(self, index):
        """Return one sample with segmentation label and DT-offset maps.

        Loads the image and palette label map, optionally re-encodes
        labels, loads the (distance, angle) offset maps, applies joint
        augmentation and the image/label transforms, and packages
        everything into DataContainers. When the val transform sequence
        crops, the "original" targets are replaced by the cropped ones so
        evaluation compares like with like.
        """
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        img_size = ImageHelper.get_size(img)
        labelmap = ImageHelper.read_image(self.label_list[index],
                                          tool=self.configer.get(
                                              'data', 'image_tool'),
                                          mode='P')
        if self.configer.exists('data', 'label_list'):
            labelmap = self._encode_label(labelmap)
        distance_map, angle_map = self._load_maps(self.offset_list[index],
                                                  labelmap)

        # NOTE(review): sibling loaders test this flag truthily or against
        # the string 'True'; confirm the config value's type.
        if self.configer.exists('data',
                                'reduce_zero_label') and self.configer.get(
                                    'data', 'reduce_zero_label') == True:
            labelmap = self._reduce_zero_label(labelmap)

        # `np.int` was removed in NumPy 1.24; builtin `int` is the exact
        # dtype spec it used to alias.
        ori_target = ImageHelper.tonp(labelmap).astype(int)
        ori_target[ori_target == 255] = -1
        ori_distance_map = np.array(distance_map)
        ori_angle_map = np.array(angle_map)

        if self.aug_transform is not None:
            img, labelmap, distance_map, angle_map = self.aug_transform(
                img,
                labelmap=labelmap,
                distance_map=distance_map,
                angle_map=angle_map)

        old_img = img
        border_size = ImageHelper.get_size(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.label_transform is not None:
            labelmap = self.label_transform(labelmap)
            distance_map = torch.from_numpy(distance_map)
            angle_map = torch.from_numpy(angle_map)

        # If the val pipeline crops, the pre-crop "originals" no longer
        # align with the model input — swap in the cropped versions.
        if set(self.configer.get('val_trans', 'trans_seq')) & set(
            ['random_crop', 'crop']):
            ori_target = labelmap.numpy()
            ori_distance_map = distance_map.numpy()
            ori_angle_map = angle_map.numpy()
            img_size = ori_target.shape[:2][::-1]

        meta = dict(ori_img_size=img_size,
                    border_size=border_size,
                    ori_target=ori_target,
                    ori_distance_map=ori_distance_map,
                    ori_angle_map=ori_angle_map,
                    basename=os.path.basename(self.label_list[index]))

        return dict(
            img=DataContainer(img, stack=self.is_stack),
            labelmap=DataContainer(labelmap, stack=self.is_stack),
            distance_map=DataContainer(distance_map, stack=self.is_stack),
            angle_map=DataContainer(angle_map, stack=self.is_stack),
            meta=DataContainer(meta, stack=False, cpu_only=True),
            name=DataContainer(self.name_list[index],
                               stack=False,
                               cpu_only=True),
        )
Exemplo n.º 22
0
    def test(self, img_path=None, output_dir=None, data_loader=None):
        """Run single-scale inference on one image and save its label map.

        Reads the image from ``img_path`` (or the 'input_image' config
        entry), applies the validation augmentation + normalization
        pipeline, forward-passes through ``self.seg_net`` via ``ss_test``,
        and writes the argmax label image as 'label_<name>.png' under
        ``output_dir`` (or ``self.save_dir``).

        ``data_loader`` is unused; kept for interface compatibility.
        Unused locals (image_id, colors, pre_*, ori_img, border_size, n)
        and commented-out debug code have been removed.
        """
        print("test!!!")
        self.seg_net.eval()
        start_time = time.time()

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        # An explicit path argument wins over the configured input image.
        if img_path is not None:
            input_path = img_path
        else:
            input_path = self.configer.get('input_image')

        input_image = cv2.imread(input_path)

        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(div_value=self.configer.get('normalize', 'div_value'),
                            mean=self.configer.get('normalize', 'mean'),
                            std=self.configer.get('normalize', 'std')), ])

        aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer, split='val')

        h, w, _ = input_image.shape
        ori_img_size = [w, h]

        # The aug compose returns a sequence; the image is its first element.
        input_image = aug_val_transform(input_image)[0]
        input_image = transform(input_image)

        with torch.no_grad():
            # Forward pass at a single scale.
            outputs = self.ss_test([input_image])

            if isinstance(outputs, torch.Tensor):
                outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
            else:
                outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                           for output in outputs]

            logits = cv2.resize(outputs[0],
                                tuple(ori_img_size), interpolation=cv2.INTER_CUBIC)
            label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
            # Undo the reduce_zero_label shift / label_list remap so the
            # saved labels match the dataset's raw ids.
            if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
                label_img = label_img + 1
                label_img = label_img.astype(np.uint8)
            if self.configer.exists('data', 'label_list'):
                label_img_ = self.__relabel(label_img)
            else:
                label_img_ = label_img
            label_img_ = Image.fromarray(label_img_, 'P')

            input_name = '.'.join(os.path.basename(input_path).split('.')[:-1])
            save_root = self.save_dir if output_dir is None else output_dir
            label_path = os.path.join(save_root, 'label_{}.png'.format(input_name))
            FileHelper.make_dirs(label_path, is_file=True)
            ImageHelper.save(label_img_, label_path)

        self.batch_time.update(time.time() - start_time)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
Exemplo n.º 23
0
    def test(self, data_loader=None):
        """
        Inference pass over ``self.test_loader``: predict a label map for
        every image and write label maps, colorized visualizations and
        (optionally) probability maps under ``self.save_dir``.

        NOTE(review): the ``data_loader`` argument is accepted but never
        used -- iteration always runs over ``self.test_loader``.
        """
        self.seg_net.eval()  # inference mode: freeze dropout / batch-norm
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        # Pick the color palette matching the configured dataset.
        if self.configer.get('dataset') in ['cityscapes', 'gta5']:
            colors = get_cityscapes_colors()
        elif self.configer.get('dataset') == 'ade20k':
            colors = get_ade_colors()
        elif self.configer.get('dataset') == 'lip':
            colors = get_lip_colors()
        elif self.configer.get('dataset') == 'pascal_context':
            colors = get_pascal_context_colors()
        elif self.configer.get('dataset') == 'pascal_voc':
            colors = get_pascal_voc_colors()
        elif self.configer.get('dataset') == 'coco_stuff':
            colors = get_cocostuff_colors()
        else:
            raise RuntimeError("Unsupport colors")

        save_prob = False
        if self.configer.get('test', 'save_prob'):
            # NOTE(review): ``save_prob`` is assigned here but the save-time
            # check below re-reads the configer instead of this local.
            save_prob = self.configer.get('test', 'save_prob')

            # Numerically stable softmax over ``axis`` (shifts by the max
            # before exponentiating); mutates ``X`` in place.
            def softmax(X, axis=0):
                max_prob = np.max(X, axis=axis, keepdims=True)
                X -= max_prob
                X = np.exp(X)
                sum_prob = np.sum(X, axis=axis, keepdims=True)
                X /= sum_prob
                return X

        for j, data_dict in enumerate(self.test_loader):
            inputs = data_dict['img']
            names = data_dict['name']
            metas = data_dict['meta']

            # Ground-truth maps are only loaded when colorizing GT below.
            # NOTE(review): the GT branch further down only checks the env
            # var, so ``labels`` can be unbound if 'val' is not in save_dir.
            if 'val' in self.save_dir and os.environ.get('save_gt_label'):
                labels = data_dict['labelmap']

            with torch.no_grad():
                # Forward pass.
                # Dispatch on the configured test mode.
                # NOTE(review): if no mode below matches, ``outputs`` stays
                # unbound and the code after raises NameError.
                if self.configer.exists('data',
                                        'use_offset') and self.configer.get(
                                            'data', 'use_offset') == 'offline':
                    offset_h_maps = data_dict['offsetmap_h']
                    offset_w_maps = data_dict['offsetmap_w']
                    outputs = self.offset_test(inputs, offset_h_maps,
                                               offset_w_maps)
                elif self.configer.get('test', 'mode') == 'ss_test':
                    outputs = self.ss_test(inputs)
                elif self.configer.get('test', 'mode') == 'ms_test':
                    outputs = self.ms_test(inputs)
                elif self.configer.get('test', 'mode') == 'ms_test_depth':
                    outputs = self.ms_test_depth(inputs, names)
                elif self.configer.get('test', 'mode') == 'sscrop_test':
                    crop_size = self.configer.get('test', 'crop_size')
                    outputs = self.sscrop_test(inputs, crop_size)
                elif self.configer.get('test', 'mode') == 'mscrop_test':
                    crop_size = self.configer.get('test', 'crop_size')
                    outputs = self.mscrop_test(inputs, crop_size)
                elif self.configer.get('test', 'mode') == 'crf_ss_test':
                    outputs = self.ss_test(inputs)
                    outputs = self.dense_crf_process(inputs, outputs)

                # Move predictions to NHWC numpy; ``n`` is the batch size.
                if isinstance(outputs, torch.Tensor):
                    outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                    n = outputs.shape[0]
                else:
                    outputs = [
                        output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                        for output in outputs
                    ]
                    n = len(outputs)

                for k in range(n):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    # Crop away padding, then resize logits back to the
                    # original image resolution.
                    logits = cv2.resize(
                        outputs[k][:border_size[1], :border_size[0]],
                        tuple(ori_img_size),
                        interpolation=cv2.INTER_CUBIC)

                    # save the logits map
                    if self.configer.get('test', 'save_prob'):
                        prob_path = os.path.join(self.save_dir, "prob/",
                                                 '{}.npy'.format(names[k]))
                        FileHelper.make_dirs(prob_path, is_file=True)
                        np.save(prob_path, softmax(logits, axis=-1))

                    # Hard label map: argmax over the class channel.
                    label_img = np.asarray(np.argmax(logits, axis=-1),
                                           dtype=np.uint8)
                    if self.configer.exists(
                            'data', 'reduce_zero_label') and self.configer.get(
                                'data', 'reduce_zero_label'):
                        # Undo the train-time "reduce zero label" shift.
                        label_img = label_img + 1
                        label_img = label_img.astype(np.uint8)
                    if self.configer.exists('data', 'label_list'):
                        # Map internal train ids back to dataset label ids.
                        label_img_ = self.__relabel(label_img)
                    else:
                        label_img_ = label_img
                    # 'P' mode: 8-bit palettized PNG.
                    label_img_ = Image.fromarray(label_img_, 'P')
                    Log.info('{:4d}/{:4d} label map generated'.format(
                        image_id, self.test_size))
                    label_path = os.path.join(self.save_dir, "label/",
                                              '{}.png'.format(names[k]))
                    FileHelper.make_dirs(label_path, is_file=True)
                    ImageHelper.save(label_img_, label_path)

                    # colorize the label-map
                    if os.environ.get('save_gt_label'):
                        # Colorize the ground-truth map instead of the
                        # prediction when the env flag is set.
                        if self.configer.exists(
                                'data',
                                'reduce_zero_label') and self.configer.get(
                                    'data', 'reduce_zero_label'):
                            label_img = labels[k] + 1
                            label_img = np.asarray(label_img, dtype=np.uint8)
                        color_img_ = Image.fromarray(label_img)
                        color_img_.putpalette(colors)
                        vis_path = os.path.join(self.save_dir, "gt_vis/",
                                                '{}.png'.format(names[k]))
                        FileHelper.make_dirs(vis_path, is_file=True)
                        ImageHelper.save(color_img_, save_path=vis_path)
                    else:
                        color_img_ = Image.fromarray(label_img)
                        color_img_.putpalette(colors)
                        vis_path = os.path.join(self.save_dir, "vis/",
                                                '{}.png'.format(names[k]))
                        FileHelper.make_dirs(vis_path, is_file=True)
                        ImageHelper.save(color_img_, save_path=vis_path)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))
Exemplo n.º 24
0
    def make_input(self,
                   image=None,
                   input_size=None,
                   min_side_length=None,
                   max_side_length=None,
                   scale=None):
        """
        Resize and normalize ``image`` into a batched network input tensor.

        Exactly one sizing strategy applies, chosen by which arguments are
        given:
          * ``input_size`` only: fixed (w, h); a ``-1`` entry means "derive
            that side from the image's aspect ratio" (both ``-1`` keeps the
            original size).
          * ``min_side_length`` only: scale so the shorter side matches it.
          * ``max_side_length`` only: scale so the longer side matches it.
          * both min and max: scale to the shorter side, but never let the
            longer side exceed ``max_side_length``.
          * none of the above: keep the original size.

        Args:
            image: input image (PIL or ndarray, per ImageHelper).
            input_size: (width, height) tuple or None.
            min_side_length: target for the shorter side, or None.
            max_side_length: cap for the longer side, or None.
            scale: extra multiplier applied on top of the chosen size.
                Bug fix: ``None`` (the default) now means 1.0 -- previously
                a call without ``scale`` raised TypeError at the resize.

        Returns:
            torch.Tensor: normalized (1, C, H, W) tensor on CPU, or CUDA
            when a GPU is configured.
        """

        def _rescaled(width, height, ratio):
            # Scale both sides uniformly, rounding to whole pixels.
            return int(round(width * ratio)), int(round(height * ratio))

        if input_size is not None and min_side_length is None and max_side_length is None:
            if input_size[0] == -1 and input_size[1] == -1:
                # Keep the native resolution.
                in_width, in_height = ImageHelper.get_size(image)

            elif input_size[0] != -1 and input_size[1] != -1:
                # Both sides fixed explicitly.
                in_width, in_height = input_size

            elif input_size[0] == -1 and input_size[1] != -1:
                # Fix the height, keep the aspect ratio.
                width, height = ImageHelper.get_size(image)
                in_width, in_height = _rescaled(width, height,
                                                input_size[1] / height)

            else:
                # Fix the width, keep the aspect ratio.
                assert input_size[0] != -1 and input_size[1] == -1
                width, height = ImageHelper.get_size(image)
                in_width, in_height = _rescaled(width, height,
                                                input_size[0] / width)

        elif input_size is None and min_side_length is not None and max_side_length is None:
            width, height = ImageHelper.get_size(image)
            in_width, in_height = _rescaled(width, height,
                                            min_side_length / min(width, height))

        elif input_size is None and min_side_length is None and max_side_length is not None:
            width, height = ImageHelper.get_size(image)
            in_width, in_height = _rescaled(width, height,
                                            max_side_length / max(width, height))

        elif input_size is None and min_side_length is not None and max_side_length is not None:
            # Match the short side, but clamp so the long side stays within
            # ``max_side_length``.
            width, height = ImageHelper.get_size(image)
            scale_ratio = min(min_side_length / min(width, height),
                              max_side_length / max(width, height))
            in_width, in_height = _rescaled(width, height, scale_ratio)

        else:
            in_width, in_height = ImageHelper.get_size(image)

        # Treat a missing ``scale`` as "no extra scaling" (see docstring).
        scale = 1.0 if scale is None else scale
        image = ImageHelper.resize(
            image, (int(in_width * scale), int(in_height * scale)),
            interpolation='cubic')
        img_tensor = ToTensor()(image)
        img_tensor = Normalize(
            div_value=self.configer.get('normalize', 'div_value'),
            mean=self.configer.get('normalize', 'mean'),
            std=self.configer.get('normalize', 'std'))(img_tensor)
        img_tensor = img_tensor.unsqueeze(0).to(
            torch.device(
                'cpu' if self.configer.get('gpu') is None else 'cuda'))

        return img_tensor
Exemplo n.º 25
0
    def __getitem__(self, index):
        """
        Load one sample: image, label map and offline offset maps.

        Args:
            index (int): sample index into the parallel path lists
                (``img_list`` / ``label_list`` / ``offset_h_list`` /
                ``offset_w_list`` / ``name_list``).

        Returns:
            dict: ``img``, ``labelmap``, ``offsetmap_h`` and ``offsetmap_w``
            wrapped in ``DataContainer``, plus ``meta`` (original image size,
            border size after augmentation, and the pre-augmentation target
            and offsets) and ``name``.
        """
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        img_size = ImageHelper.get_size(img)
        # 'P' mode: labels are stored as 8-bit palettized PNGs.
        labelmap = ImageHelper.read_image(self.label_list[index],
                                          tool=self.configer.get(
                                              'data', 'image_tool'),
                                          mode='P')
        offsetmap_h = self._load_mat(self.offset_h_list[index])
        offsetmap_w = self._load_mat(self.offset_w_list[index])

        # Env switch for ablation: train without offset supervision.
        if os.environ.get('train_no_offset') and self.dataset == 'train':
            offsetmap_h = np.zeros_like(offsetmap_h)
            offsetmap_w = np.zeros_like(offsetmap_w)

        if self.configer.exists('data', 'label_list'):
            labelmap = self._encode_label(labelmap)

        if self.configer.exists('data', 'reduce_zero_label') \
                and self.configer.get('data', 'reduce_zero_label'):
            labelmap = self._reduce_zero_label(labelmap)

        # Keep pre-augmentation copies for evaluation at the original scale.
        # Bug fix: ``np.int`` was removed in NumPy 1.24; it was an alias of
        # the builtin ``int``, so the cast behavior is unchanged.
        ori_target = ImageHelper.tonp(labelmap).astype(int)
        ori_target[ori_target == 255] = -1  # 255 is the ignore label
        ori_offset_h = np.array(offsetmap_h)
        ori_offset_w = np.array(offsetmap_w)

        # Joint augmentation keeps image, labels and offsets aligned.
        if self.aug_transform is not None:
            img, labelmap, offsetmap_h, offsetmap_w = self.aug_transform(
                img,
                labelmap=labelmap,
                offset_h_map=offsetmap_h,
                offset_w_map=offsetmap_w)

        border_size = ImageHelper.get_size(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.label_transform is not None:
            labelmap = self.label_transform(labelmap)
            offsetmap_h = torch.from_numpy(np.array(offsetmap_h)).long()
            offsetmap_w = torch.from_numpy(np.array(offsetmap_w)).long()

        meta = dict(
            ori_img_size=img_size,
            border_size=border_size,
            ori_target=ori_target,
            ori_offset_h=ori_offset_h,
            ori_offset_w=ori_offset_w,
        )

        return dict(
            img=DataContainer(img, stack=self.is_stack),
            labelmap=DataContainer(labelmap, stack=self.is_stack),
            offsetmap_h=DataContainer(offsetmap_h, stack=self.is_stack),
            offsetmap_w=DataContainer(offsetmap_w, stack=self.is_stack),
            meta=DataContainer(meta, stack=False, cpu_only=True),
            name=DataContainer(self.name_list[index],
                               stack=False,
                               cpu_only=True),
        )