def evaluate(self, pred_dir, gt_dir):
    """Score every prediction in *pred_dir* against *gt_dir* and log mIoU / pixel acc.

    Both maps are read as 'P'-mode images; datasets whose labels are stored
    shifted are passed through ``self.relabel`` first.
    """
    num_images = 0
    for filename in os.listdir(pred_dir):
        print(filename)
        pred = ImageHelper.img2np(
            ImageHelper.read_image(os.path.join(pred_dir, filename), tool='pil', mode='P'))
        gt = ImageHelper.img2np(
            ImageHelper.read_image(os.path.join(gt_dir, filename), tool='pil', mode='P'))
        # These datasets need both maps remapped into the same label space.
        if any(tag in gt_dir for tag in ("pascal_context", "ade", "coco_stuff")):
            pred = self.relabel(pred)
            gt = self.relabel(gt)
        if "coco_stuff" in gt_dir:
            gt[gt == 0] = 255  # background becomes the ignore index
        self.seg_running_score.update(pred[np.newaxis, :, :], gt[np.newaxis, :, :])
        num_images += 1
    Log.info('Evaluate {} images'.format(num_images))
    Log.info('mIOU: {}'.format(self.seg_running_score.get_mean_iou()))
    Log.info('Pixel ACC: {}'.format(self.seg_running_score.get_pixel_acc()))
def __getitem__(self, index):
    """Fetch one test item, skipping forward past unreadable images.

    Returns a dict with the transformed image and a cpu-only ``meta``
    DataContainer; ``meta['valid']`` is False when at least one path failed
    to load and a later index was substituted.
    """
    img = None
    valid = True
    while img is None:
        try:
            img = ImageHelper.read_image(self.item_list[index][0],
                                         tool=self.configer.get('data', 'image_tool'),
                                         mode=self.configer.get('data', 'input_mode'))
            assert isinstance(img, np.ndarray) or isinstance(img, Image.Image)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate instead of being swallowed by the retry loop.
            Log.warn('Invalid image path: {}'.format(self.item_list[index][0]))
            img = None
            valid = False
            index = (index + 1) % len(self.item_list)
    ori_img_size = ImageHelper.get_size(img)
    if self.aug_transform is not None:
        img = self.aug_transform(img)
    border_hw = ImageHelper.get_size(img)[::-1]  # (w, h) -> (h, w)
    if self.img_transform is not None:
        img = self.img_transform(img)
    meta = dict(
        valid=valid,
        ori_img_size=ori_img_size,
        border_hw=border_hw,
        img_path=self.item_list[index][0],
        filename=self.item_list[index][1],
        label=self.item_list[index][2]
    )
    return dict(
        img=DataContainer(img, stack=True),
        meta=DataContainer(meta, stack=False, cpu_only=True)
    )
def _get_batch_per_gpu(self, cur_index):
    """Assemble one per-GPU batch of images and label maps starting at *cur_index*.

    The first sample is taken directly; each remaining slot is filled by
    randomly probing other indices until one with the same orientation
    (landscape vs. portrait) as the first sample is found, so the batch can
    be collated without mixing aspect classes.

    Returns:
        (img_out, label_out): lists of length ``train.batch_per_gpu``.
    """
    img = ImageHelper.read_image(
        self.img_list[cur_index],
        tool=self.configer.get('data', 'image_tool'),
        mode=self.configer.get('data', 'input_mode'))
    labelmap = ImageHelper.read_image(self.label_list[cur_index],
                                      tool=self.configer.get('data', 'image_tool'),
                                      mode='P')
    img_size = self.size_list[cur_index]
    img_out = [img]
    label_out = [labelmap]
    for i in range(self.configer.get('train', 'batch_per_gpu') - 1):
        while True:
            # Random step of at least 1, so the same index is never drawn twice in a row.
            cur_index = (cur_index + random.randint(
                1, len(self.img_list) - 1)) % len(self.img_list)
            now_img_size = self.size_list[cur_index]
            # mark: 0 = landscape, 1 = portrait — assumes size_list holds (w, h); TODO confirm
            now_mark = 0 if now_img_size[0] > now_img_size[1] else 1
            mark = 0 if img_size[0] > img_size[1] else 1
            if now_mark == mark:
                img = ImageHelper.read_image(
                    self.img_list[cur_index],
                    tool=self.configer.get('data', 'image_tool'),
                    mode=self.configer.get('data', 'input_mode'))
                img_out.append(img)
                labelmap = ImageHelper.read_image(
                    self.label_list[cur_index],
                    tool=self.configer.get('data', 'image_tool'),
                    mode='P')
                label_out.append(labelmap)
                break
    return img_out, label_out
def __init__(self, test_dir=None, aug_transform=None, img_transform=None, configer=None):
    """Index every image file directly under *test_dir* for test-time loading."""
    super(TestDefaultDataset, self).__init__()
    self.configer = configer
    self.aug_transform = aug_transform
    self.img_transform = img_transform
    # Collect (absolute path, bare filename) pairs for every recognized image.
    entries = []
    for filename in FileHelper.list_dir(test_dir):
        if not ImageHelper.is_img(filename):
            continue
        entries.append((os.path.abspath(os.path.join(test_dir, filename)), filename))
    self.item_list = entries
def __getitem__(self, index):
    """Fetch one (image, label) classification sample, skipping unreadable images.

    ``valid`` in the returned dict is False when the originally requested
    image could not be read and a later index was substituted.
    """
    img = None
    valid = True
    while img is None:
        try:
            img = ImageHelper.read_image(
                self.img_list[index],
                tool=self.configer.get('data', 'image_tool'),
                mode=self.configer.get('data', 'input_mode'))
            assert isinstance(img, np.ndarray) or isinstance(img, Image.Image)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate instead of being swallowed by the retry loop.
            Log.warn('Invalid image path: {}'.format(self.img_list[index]))
            img = None
            valid = False
            index = (index + 1) % len(self.img_list)
    label = torch.from_numpy(np.array(self.label_list[index]))
    if self.aug_transform is not None:
        img = self.aug_transform(img)
    if self.img_transform is not None:
        img = self.img_transform(img)
    return dict(valid=valid,
                img=DataContainer(img, stack=True),
                label=DataContainer(label, stack=True))
def mscrop_test(self, ori_image):
    """Multi-scale + horizontal-flip sliding-crop inference.

    For every scale in ``test.scale_search`` the blob is predicted (sliding
    crops when it exceeds the crop window) and the logits, resized back to
    the original resolution, are accumulated; finally the logits of the
    horizontally mirrored image at scale 1.0 are added.

    Returns:
        (H, W, num_classes) float32 array of accumulated logits.
    """
    ori_width, ori_height = ImageHelper.get_size(ori_image)
    crop_size = self.configer.get('test', 'crop_size')
    total_logits = np.zeros((ori_height, ori_width,
                             self.configer.get('data', 'num_classes')), np.float32)
    for scale in self.configer.get('test', 'scale_search'):
        image, border_hw = self._get_blob(ori_image, scale=scale)
        # NCHW tensor: size()[3] is width, size()[2] is height; slide crops
        # only when the blob is strictly larger than the crop window.
        if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
            results = self._crop_predict(image, crop_size)
        else:
            results = self._predict(image)
        results = cv2.resize(results[:border_hw[0], :border_hw[1]],
                             (ori_width, ori_height),
                             interpolation=cv2.INTER_CUBIC)
        total_logits += results
    if self.configer.get('data', 'image_tool') == 'cv2':
        mirror_image = cv2.flip(ori_image, 1)
    else:
        mirror_image = ori_image.transpose(Image.FLIP_LEFT_RIGHT)
    image, border_hw = self._get_blob(mirror_image, scale=1.0)
    if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
        results = self._crop_predict(image, crop_size)
    else:
        results = self._predict(image)
    results = results[:border_hw[0], :border_hw[1]]
    # Flip the logits back so they align with the unmirrored image.
    results = cv2.resize(results[:, ::-1],
                         (ori_width, ori_height),
                         interpolation=cv2.INTER_CUBIC)
    total_logits += results
    return total_logits
def __read_file(self, data_dir, dataset):
    """Read image paths and multi-label annotations for *dataset*.

    Reads 'data.<dataset>_label_path'; when dataset == 'train' and
    ``data.include_val`` is set, the validation list is appended. Each line is
    '<relpath> <label> [<label> ...]'; entries with missing or non-image paths
    are skipped with a warning.

    Returns:
        (img_list, mlabel_list): absolute image paths and int label lists.

    Raises:
        AssertionError: if no valid image was found.
    """
    img_list = list()
    mlabel_list = list()
    all_img_list = []
    with open(self.configer.get('data.{}_label_path'.format(dataset)), 'r') as file_stream:
        all_img_list += file_stream.readlines()
    if dataset == 'train' and self.configer.get('data.include_val', default=False):
        with open(self.configer.get('data.val_label_path'), 'r') as file_stream:
            all_img_list += file_stream.readlines()
    # Iterate lines directly (was range(len(...)) indexing); the unused
    # img_dict local from the original has been dropped.
    for line in all_img_list:
        line_items = line.strip().split()
        if len(line_items) == 0:
            continue
        path = line_items[0]
        if not os.path.exists(os.path.join(
                data_dir, path)) or not ImageHelper.is_img(path):
            Log.warn('Invalid Image Path: {}'.format(
                os.path.join(data_dir, path)))
            continue
        img_list.append(os.path.join(data_dir, path))
        mlabel_list.append([int(item) for item in line_items[1:]])
    assert len(img_list) > 0
    Log.info('Length of {} imgs is {}...'.format(dataset, len(img_list)))
    return img_list, mlabel_list
def __list_dirs(self, root_dir, dataset):
    """Collect (image path, label path, image size) lists for *dataset*.

    Pairs every file in the label directory with the same-named image (the
    image extension is taken from the first file in the image directory).
    When dataset == 'train' and ``data.include_val`` is set, the val split is
    appended as well. The previously duplicated scan loop is factored into a
    single nested helper.

    Returns:
        (img_list, label_list, size_list)
    """
    img_list = list()
    label_list = list()
    size_list = list()
    image_dir = os.path.join(root_dir, dataset, 'image')
    label_dir = os.path.join(root_dir, dataset, 'label')
    # Assumes every image shares the extension of the first listed file.
    img_extension = os.listdir(image_dir)[0].split('.')[-1]

    def _scan(image_dir, label_dir):
        # Walk the label dir, pairing each label with its same-named image.
        for file_name in os.listdir(label_dir):
            image_name = '.'.join(file_name.split('.')[:-1])
            img_path = os.path.join(image_dir,
                                    '{}.{}'.format(image_name, img_extension))
            label_path = os.path.join(label_dir, file_name)
            if not os.path.exists(label_path) or not os.path.exists(img_path):
                Log.error('Label Path: {} not exists.'.format(label_path))
                continue
            img_list.append(img_path)
            label_list.append(label_path)
            img = ImageHelper.read_image(
                img_path,
                tool=self.configer.get('data', 'image_tool'),
                mode=self.configer.get('data', 'input_mode'))
            size_list.append(ImageHelper.get_size(img))

    _scan(image_dir, label_dir)
    if dataset == 'train' and self.configer.get('data', 'include_val'):
        _scan(os.path.join(root_dir, 'val/image'),
              os.path.join(root_dir, 'val/label'))
    return img_list, label_list, size_list
def __getitem__(self, index):
    """Load one (image, label) segmentation sample.

    Applies optional torchvision-style, augmentation, and tensor transforms;
    the untransformed target is kept in ``meta`` for original-size scoring.
    """
    img = ImageHelper.read_image(
        self.img_list[index],
        tool=self.configer.get('data', 'image_tool'),
        mode=self.configer.get('data', 'input_mode'))
    img_size = ImageHelper.get_size(img)
    labelmap = ImageHelper.read_image(self.label_list[index],
                                      tool=self.configer.get('data', 'image_tool'),
                                      mode='P')
    if self.configer.exists('data', 'label_list'):
        labelmap = self._encode_label(labelmap)
    if self.configer.exists('data', 'reduce_zero_label'):
        labelmap = self._reduce_zero_label(labelmap)
    ori_target = ImageHelper.tonp(labelmap)
    ori_target[ori_target == 255] = -1  # 255 is the ignore index; scoring expects -1
    # torch_img_transform works on PIL images, so round-trip through PIL
    # (assumes img is a uint8 ndarray here — TODO confirm image_tool is cv2).
    if self.torch_img_transform is not None:
        img = Image.fromarray(img)
        img = self.torch_img_transform(img)
        img = np.array(img).astype(np.uint8)
    if self.aug_transform is not None:
        img, labelmap = self.aug_transform(img, labelmap=labelmap)
    border_size = ImageHelper.get_size(img)
    if self.img_transform is not None:
        img = self.img_transform(img)
    if self.label_transform is not None:
        labelmap = self.label_transform(labelmap)
    meta = dict(ori_img_size=img_size,
                border_size=border_size,
                ori_target=ori_target)
    return dict(
        img=DataContainer(img, stack=self.is_stack),
        labelmap=DataContainer(labelmap, stack=self.is_stack),
        meta=DataContainer(meta, stack=False, cpu_only=True),
        name=DataContainer(self.name_list[index], stack=False, cpu_only=True),
    )
def ss_test(self, ori_image):
    """Single-scale inference: predict once at scale 1.0 and return the logits
    resized to the original image resolution as a (H, W, num_classes) array."""
    ori_width, ori_height = ImageHelper.get_size(ori_image)
    num_classes = self.configer.get('data', 'num_classes')
    total_logits = np.zeros((ori_height, ori_width, num_classes), np.float32)
    blob, border_hw = self._get_blob(ori_image, scale=1.0)
    logits = self._predict(blob)
    # Drop padding, then map the logits back onto the original resolution.
    logits = cv2.resize(logits[:border_hw[0], :border_hw[1]],
                        (ori_width, ori_height),
                        interpolation=cv2.INTER_CUBIC)
    total_logits += logits
    return total_logits
def __getitem__(self, index):
    """Load a single test image plus size metadata; no label is available."""
    image = ImageHelper.read_image(
        self.img_list[index],
        tool=self.configer.get('data', 'image_tool'),
        mode=self.configer.get('data', 'input_mode'))
    size = ImageHelper.get_size(image)
    if self.img_transform is not None:
        image = self.img_transform(image)
    # Without augmentation the border size equals the original size.
    meta = dict(ori_img_size=size, border_size=size)
    return dict(
        img=DataContainer(image, stack=self.is_stack),
        meta=DataContainer(meta, stack=False, cpu_only=True),
        name=DataContainer(self.name_list[index], stack=False, cpu_only=True),
    )
def _reduce_zero_label(self, labelmap):
    """Shift every label down by one when ``data.reduce_zero_label`` is set.

    In the PIL path the uint8 cast wraps the former class 0 (now -1) around
    to 255, the ignore index.
    """
    if not self.configer.get('data', 'reduce_zero_label'):
        return labelmap
    shifted = np.array(labelmap) - 1
    if self.configer.get('data', 'image_tool') == 'pil':
        shifted = ImageHelper.np2img(shifted.astype(np.uint8))
    return shifted
def _mp_target(self, inp):
    """Multiprocessing worker: confusion histogram for one prediction file.

    *inp* is a (filename, pred_dir, gt_dir) tuple. Unreadable files yield 0.
    so the pool reduction is unaffected.
    """
    filename, pred_dir, gt_dir = inp
    print(filename)
    pred_path = os.path.join(pred_dir, filename)
    gt_path = os.path.join(gt_dir, filename)
    try:
        pred = self._encode_label(ImageHelper.img2np(
            ImageHelper.read_image(pred_path, tool='pil', mode='P')))
        gt = self._encode_label(ImageHelper.img2np(
            ImageHelper.read_image(gt_path, tool='pil', mode='P')))
    except Exception as e:
        print(e)
        return 0.
    # These datasets need both maps remapped into the same label space.
    if any(tag in gt_dir for tag in ("pascal_context", "ADE")):
        pred = self.relabel(pred)
        gt = self.relabel(gt)
    return self.seg_running_score.hist(pred[np.newaxis, :, :],
                                       gt[np.newaxis, :, :])
def _encode_label(self, labelmap):
    """Map raw dataset class ids onto contiguous training ids.

    Pixels whose value is not in ``data.label_list`` stay at 255 (ignore).
    Returns a float32 array, or a 'P'-mode image when the PIL tool is used.
    """
    raw = np.array(labelmap)
    shape = raw.shape
    # Start everything at the ignore value, then stamp in each known class.
    encoded = np.full((shape[0], shape[1]), 255, dtype=np.float32)
    for train_id, class_id in enumerate(self.configer.get('data', 'label_list')):
        encoded[raw == class_id] = train_id
    if self.configer.get('data', 'image_tool') == 'pil':
        encoded = ImageHelper.np2img(encoded.astype(np.uint8))
    return encoded
def __getitem__(self, index):
    """Load one (image, label, edge-mask) training sample.

    The edge map is binarized (255 -> 1) and nearest-neighbor resized to the
    label map's spatial size, then all three are augmented and transformed
    together. ``meta`` keeps the untransformed target for original-size scoring.
    """
    img = ImageHelper.read_image(
        self.img_list[index],
        tool=self.configer.get('data', 'image_tool'),
        mode=self.configer.get('data', 'input_mode'))
    img_size = ImageHelper.get_size(img)
    labelmap = ImageHelper.read_image(self.label_list[index],
                                      tool=self.configer.get('data', 'image_tool'),
                                      mode='P')
    edgemap = ImageHelper.read_image(self.edge_list[index],
                                     tool=self.configer.get('data', 'image_tool'),
                                     mode='P')
    edgemap[edgemap == 255] = 1  # boundary pixels stored as 255 -> binary 1
    # assumes labelmap is an ndarray here (cv2 tool), last two dims H, W — TODO confirm
    edgemap = cv2.resize(edgemap,
                         (labelmap.shape[-1], labelmap.shape[-2]),
                         interpolation=cv2.INTER_NEAREST)
    if self.configer.exists('data', 'label_list'):
        labelmap = self._encode_label(labelmap)
    # NOTE(review): compares against the string 'True', unlike sibling loaders
    # that test the raw config value — confirm the config stores a string here.
    if self.configer.exists('data', 'reduce_zero_label') and self.configer.get(
            'data', 'reduce_zero_label') == 'True':
        labelmap = self._reduce_zero_label(labelmap)
    ori_target = ImageHelper.tonp(labelmap)
    ori_target[ori_target == 255] = -1  # 255 is the ignore index; scoring expects -1
    if self.aug_transform is not None:
        img, labelmap, edgemap = self.aug_transform(img,
                                                    labelmap=labelmap,
                                                    maskmap=edgemap)
    border_size = ImageHelper.get_size(img)
    if self.img_transform is not None:
        img = self.img_transform(img)
    if self.label_transform is not None:
        labelmap = self.label_transform(labelmap)
        edgemap = self.label_transform(edgemap)
    meta = dict(ori_img_size=img_size,
                border_size=border_size,
                ori_target=ori_target)
    return dict(
        img=DataContainer(img, stack=True),
        labelmap=DataContainer(labelmap, stack=True),
        maskmap=DataContainer(edgemap, stack=True),
        meta=DataContainer(meta, stack=False, cpu_only=True),
        name=DataContainer(self.name_list[index], stack=False, cpu_only=True),
    )
def __read_list(self, data_dir, list_path):
    """Parse a list file of '<filename> [label]' rows.

    Returns a list of (absolute image path, filename, label-or-None) tuples;
    aborts the process on the first invalid image path.
    """
    item_list = []
    with open(list_path, 'r') as fr:
        for line in fr.readlines():
            parts = line.strip().split()
            filename = parts[0]
            label = parts[1] if len(parts) > 1 else None
            img_path = os.path.join(data_dir, filename)
            if not os.path.exists(img_path) or not ImageHelper.is_img(img_path):
                Log.error('Image Path: {} is Invalid.'.format(img_path))
                exit(1)
            item_list.append((img_path, filename, label))
    Log.info('There are {} images..'.format(len(item_list)))
    return item_list
def load_boundary(self, fn):
    """Load a boundary map from *fn*.

    .mat files either carry a 'depth' key (boundaries are then derived from
    the distance map) or a raw channel-first 'mat' array; any other file is
    read as a 'P'-mode image and rescaled to [0, 1].
    """
    if fn.endswith('mat'):
        mat = io.loadmat(fn)
        if 'depth' in mat:
            # Derive a binary boundary mask from the distance transform.
            dist_map, _ = self._load_maps(fn, None)
            boundary_map = DTOffsetHelper.distance_to_mask_label(
                dist_map, np.zeros_like(dist_map)).astype(np.float32)
        else:
            boundary_map = mat['mat'].transpose(1, 2, 0)  # CHW -> HWC
    else:
        boundary_map = ImageHelper.read_image(fn,
                                              tool=self.configer.get('data', 'image_tool'),
                                              mode='P')
        # NOTE(review): reconstructed nesting — the /255 normalization is
        # assumed to apply only to the image branch; confirm against history.
        boundary_map = boundary_map.astype(np.float32) / 255
    return boundary_map
def __read_file(self, root_dir, dataset, label_path):
    """Read '<relpath> <label> [...]' rows, skipping entries whose image is missing.

    Returns (img_list, mlabel_list); asserts at least one valid entry exists.
    """
    img_list = list()
    mlabel_list = list()
    with open(label_path, 'r') as file_stream:
        for line in file_stream.readlines():
            line_items = line.rstrip().split()
            path = line_items[0]
            full_path = os.path.join(root_dir, path)
            if not os.path.exists(full_path) or not ImageHelper.is_img(path):
                Log.warn('Invalid Image Path: {}'.format(full_path))
                continue
            img_list.append(full_path)
            mlabel_list.append([int(item) for item in line_items[1:]])
    assert len(img_list) > 0
    Log.info('Length of {} imgs is {}...'.format(dataset, len(img_list)))
    return img_list, mlabel_list
def __test_img(self, image_path, label_path, vis_path, raw_path):
    """Run inference on a single image and save label / visualization / raw outputs.

    Args:
        image_path: input image to segment.
        label_path: destination for the 'P'-mode label map.
        vis_path:   destination for the colorized overlay.
        raw_path:   destination for the untouched input image.
    """
    Log.info('Image Path: {}'.format(image_path))
    ori_image = ImageHelper.read_image(image_path,
                                       tool=self.configer.get('data', 'image_tool'),
                                       mode=self.configer.get('data', 'input_mode'))
    total_logits = None
    # Dispatch on the configured test mode; unknown modes abort the run.
    if self.configer.get('test', 'mode') == 'ss_test':
        total_logits = self.ss_test(ori_image)
    elif self.configer.get('test', 'mode') == 'sscrop_test':
        total_logits = self.sscrop_test(ori_image)
    elif self.configer.get('test', 'mode') == 'ms_test':
        total_logits = self.ms_test(ori_image)
    elif self.configer.get('test', 'mode') == 'mscrop_test':
        total_logits = self.mscrop_test(ori_image)
    else:
        Log.error('Invalid test mode:{}'.format(self.configer.get('test', 'mode')))
        exit(1)
    # Per-pixel argmax over the class dimension gives the hard label map.
    label_map = np.argmax(total_logits, axis=-1)
    label_img = np.array(label_map, dtype=np.uint8)
    ori_img_bgr = ImageHelper.get_cv2_bgr(ori_image,
                                          mode=self.configer.get('data', 'input_mode'))
    image_canvas = self.seg_parser.colorize(label_img, image_canvas=ori_img_bgr)
    ImageHelper.save(image_canvas, save_path=vis_path)
    ImageHelper.save(ori_image, save_path=raw_path)
    if self.configer.exists('data', 'label_list'):
        label_img = self.__relabel(label_img)
    # Undo reduce_zero_label so the saved map is back in the dataset id space.
    if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
        label_img = label_img + 1
        label_img = label_img.astype(np.uint8)
    label_img = Image.fromarray(label_img, 'P')
    Log.info('Label Path: {}'.format(label_path))
    ImageHelper.save(label_img, label_path)
def __read_and_split_file(self, root_dir, dataset, label_path):
    """Read a label file and deterministically split it into train/val portions.

    Lines are grouped by class label, concatenated in sorted-label order, and
    every ``select_interval``-th line is assigned to 'val' while the rest go
    to 'train' (unless ``data.include_val`` keeps everything in train).
    """
    img_list = list()
    mlabel_list = list()
    # e.g. val_ratio 0.1 -> every 10th sample becomes validation
    select_interval = int(1 / self.configer.get('data', 'val_ratio'))
    img_dict = dict()
    with open(label_path, 'r') as file_stream:
        for line in file_stream.readlines():
            label = line.strip().split()[1]
            if int(label) in img_dict:
                img_dict[int(label)].append(line)
            else:
                img_dict[int(label)] = [line]
    all_img_list = []
    # Sorted label order keeps the split deterministic across runs.
    for i in sorted(img_dict.keys()):
        all_img_list += img_dict[i]
    for line_cnt in range(len(all_img_list)):
        # Skip val-assigned lines when building train (unless include_val), and vice versa.
        if line_cnt % select_interval == 0 and dataset == 'train' and not self.configer.get('data', 'include_val'):
            continue
        if line_cnt % select_interval != 0 and dataset == 'val':
            continue
        line_items = all_img_list[line_cnt].strip().split()
        path = line_items[0]
        if not os.path.exists(os.path.join(root_dir, path)) or not ImageHelper.is_img(path):
            Log.warn('Invalid Image Path: {}'.format(os.path.join(root_dir, path)))
            continue
        img_list.append(os.path.join(root_dir, path))
        mlabel_list.append([int(item) for item in line_items[1:]])
    assert len(img_list) > 0
    Log.info('Length of {} imgs is {} after split trainval...'.format(dataset, len(img_list)))
    return img_list, mlabel_list
def __getitem__(self, index):
    """Load one (image, label, distance-map, angle-map) training sample.

    Reads the image and 'P'-mode label map, derives distance/angle offset maps,
    applies augmentation and tensor transforms, and packs everything into
    DataContainers. ``meta`` keeps the untransformed targets so validation can
    score at the original resolution; after a crop transform the cropped
    tensors are substituted so meta stays consistent with the network input.
    """
    img = ImageHelper.read_image(
        self.img_list[index],
        tool=self.configer.get('data', 'image_tool'),
        mode=self.configer.get('data', 'input_mode'))
    img_size = ImageHelper.get_size(img)
    labelmap = ImageHelper.read_image(self.label_list[index],
                                      tool=self.configer.get('data', 'image_tool'),
                                      mode='P')
    if self.configer.exists('data', 'label_list'):
        labelmap = self._encode_label(labelmap)
    distance_map, angle_map = self._load_maps(self.offset_list[index], labelmap)
    # NOTE: compares with == True because the config value may be a non-bool.
    if self.configer.exists('data', 'reduce_zero_label') and self.configer.get(
            'data', 'reduce_zero_label') == True:
        labelmap = self._reduce_zero_label(labelmap)
    # np.int was removed in NumPy 1.24; the builtin int is the exact drop-in alias.
    ori_target = ImageHelper.tonp(labelmap).astype(int)
    ori_target[ori_target == 255] = -1  # 255 is the ignore index; scoring expects -1
    ori_distance_map = np.array(distance_map)
    ori_angle_map = np.array(angle_map)
    if self.aug_transform is not None:
        img, labelmap, distance_map, angle_map = self.aug_transform(
            img, labelmap=labelmap, distance_map=distance_map,
            angle_map=angle_map)
    border_size = ImageHelper.get_size(img)
    if self.img_transform is not None:
        img = self.img_transform(img)
    if self.label_transform is not None:
        labelmap = self.label_transform(labelmap)
    distance_map = torch.from_numpy(distance_map)
    angle_map = torch.from_numpy(angle_map)
    # After a crop the original-resolution targets no longer match; fall back
    # to the cropped tensors so meta stays aligned with the input.
    if set(self.configer.get('val_trans', 'trans_seq')) & set(
            ['random_crop', 'crop']):
        ori_target = labelmap.numpy()
        ori_distance_map = distance_map.numpy()
        ori_angle_map = angle_map.numpy()
        img_size = ori_target.shape[:2][::-1]
    meta = dict(ori_img_size=img_size,
                border_size=border_size,
                ori_target=ori_target,
                ori_distance_map=ori_distance_map,
                ori_angle_map=ori_angle_map,
                basename=os.path.basename(self.label_list[index]))
    return dict(
        img=DataContainer(img, stack=self.is_stack),
        labelmap=DataContainer(labelmap, stack=self.is_stack),
        distance_map=DataContainer(distance_map, stack=self.is_stack),
        angle_map=DataContainer(angle_map, stack=self.is_stack),
        meta=DataContainer(meta, stack=False, cpu_only=True),
        name=DataContainer(self.name_list[index], stack=False, cpu_only=True),
    )
def test(self, img_path=None, output_dir=None, data_loader=None):
    """ Validation function during the train phase. """
    # Runs single-scale inference on one image (img_path or the configured
    # 'input_image') and writes the resulting 'P'-mode label map to disk.
    print("test!!!")
    self.seg_net.eval()
    start_time = time.time()
    image_id = 0  # NOTE(review): set but never used in this method
    Log.info('save dir {}'.format(self.save_dir))
    FileHelper.make_dirs(self.save_dir, is_file=False)
    colors = get_ade_colors()  # NOTE(review): unused here — no palette is applied below
    # Reader.
    if img_path is not None:
        input_path = img_path
    else:
        input_path = self.configer.get('input_image')
    input_image = cv2.imread(input_path)
    # Tensor conversion + normalization matching training-time preprocessing.
    transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize(div_value=self.configer.get('normalize', 'div_value'),
                        mean=self.configer.get('normalize', 'mean'),
                        std=self.configer.get('normalize', 'std')), ])
    aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer, split='val')
    pre_vis_img = None
    pre_lines = None
    pre_target_img = None
    ori_img = input_image.copy()
    h, w, _ = input_image.shape
    ori_img_size = [w, h]
    # print(img.shape)
    input_image = aug_val_transform(input_image)
    input_image = input_image[0]  # CV2AugCompose returns a sequence; take the image
    h, w, _ = input_image.shape
    border_size = [w, h]
    input_image = transform(input_image)
    # print(img)
    # print(img.shape)
    # inputs = data_dict['img']
    # names = data_dict['name']
    # metas = data_dict['meta']
    # print(inputs)
    with torch.no_grad():
        # Forward pass.
        outputs = self.ss_test([input_image])
        if isinstance(outputs, torch.Tensor):
            outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
            n = outputs.shape[0]
        else:
            outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                       for output in outputs]
            n = len(outputs)
        logits = cv2.resize(outputs[0], tuple(ori_img_size),
                            interpolation=cv2.INTER_CUBIC)
        label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
        # Undo reduce_zero_label so the saved map is back in the dataset id space.
        if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
            label_img = label_img + 1
            label_img = label_img.astype(np.uint8)
        if self.configer.exists('data', 'label_list'):
            label_img_ = self.__relabel(label_img)
        else:
            label_img_ = label_img
        label_img_ = Image.fromarray(label_img_, 'P')
        input_name = '.'.join(os.path.basename(input_path).split('.')[:-1])
        if output_dir is None:
            label_path = os.path.join(self.save_dir, 'label_{}.png'.format(input_name))
        else:
            label_path = os.path.join(output_dir, 'label_{}.png'.format(input_name))
        FileHelper.make_dirs(label_path, is_file=True)
        # print(f"{label_path}")
        ImageHelper.save(label_img_, label_path)
    self.batch_time.update(time.time() - start_time)
    # Print the log info & reset the states.
    Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
def test(self, data_loader=None):
    """ Validation function during the train phase. """
    # Iterates the test loader, dispatches to the configured inference mode,
    # and writes label maps, optional softmax probabilities, and colorized
    # visualizations (or ground-truth visualizations when the 'save_gt_label'
    # env var is set) under self.save_dir.
    self.seg_net.eval()
    start_time = time.time()
    image_id = 0
    Log.info('save dir {}'.format(self.save_dir))
    FileHelper.make_dirs(self.save_dir, is_file=False)
    # Pick the visualization palette matching the dataset.
    if self.configer.get('dataset') in ['cityscapes', 'gta5']:
        colors = get_cityscapes_colors()
    elif self.configer.get('dataset') == 'ade20k':
        colors = get_ade_colors()
    elif self.configer.get('dataset') == 'lip':
        colors = get_lip_colors()
    elif self.configer.get('dataset') == 'pascal_context':
        colors = get_pascal_context_colors()
    elif self.configer.get('dataset') == 'pascal_voc':
        colors = get_pascal_voc_colors()
    elif self.configer.get('dataset') == 'coco_stuff':
        colors = get_cocostuff_colors()
    else:
        raise RuntimeError("Unsupport colors")

    save_prob = False
    if self.configer.get('test', 'save_prob'):
        save_prob = self.configer.get('test', 'save_prob')

    def softmax(X, axis=0):
        # Numerically stable softmax: subtract the max before exponentiating.
        max_prob = np.max(X, axis=axis, keepdims=True)
        X -= max_prob
        X = np.exp(X)
        sum_prob = np.sum(X, axis=axis, keepdims=True)
        X /= sum_prob
        return X

    for j, data_dict in enumerate(self.test_loader):
        inputs = data_dict['img']
        names = data_dict['name']
        metas = data_dict['meta']
        if 'val' in self.save_dir and os.environ.get('save_gt_label'):
            labels = data_dict['labelmap']
        with torch.no_grad():
            # Forward pass.
            if self.configer.exists('data', 'use_offset') and self.configer.get(
                    'data', 'use_offset') == 'offline':
                offset_h_maps = data_dict['offsetmap_h']
                offset_w_maps = data_dict['offsetmap_w']
                outputs = self.offset_test(inputs, offset_h_maps, offset_w_maps)
            elif self.configer.get('test', 'mode') == 'ss_test':
                outputs = self.ss_test(inputs)
            elif self.configer.get('test', 'mode') == 'ms_test':
                outputs = self.ms_test(inputs)
            elif self.configer.get('test', 'mode') == 'ms_test_depth':
                outputs = self.ms_test_depth(inputs, names)
            elif self.configer.get('test', 'mode') == 'sscrop_test':
                crop_size = self.configer.get('test', 'crop_size')
                outputs = self.sscrop_test(inputs, crop_size)
            elif self.configer.get('test', 'mode') == 'mscrop_test':
                crop_size = self.configer.get('test', 'crop_size')
                outputs = self.mscrop_test(inputs, crop_size)
            elif self.configer.get('test', 'mode') == 'crf_ss_test':
                outputs = self.ss_test(inputs)
                outputs = self.dense_crf_process(inputs, outputs)

            # NHWC numpy logits either as one batch tensor or a per-image list.
            if isinstance(outputs, torch.Tensor):
                outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                n = outputs.shape[0]
            else:
                outputs = [
                    output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                    for output in outputs
                ]
                n = len(outputs)

            for k in range(n):
                image_id += 1
                ori_img_size = metas[k]['ori_img_size']
                border_size = metas[k]['border_size']
                # Strip padding, then resize logits back to the original resolution.
                logits = cv2.resize(
                    outputs[k][:border_size[1], :border_size[0]],
                    tuple(ori_img_size),
                    interpolation=cv2.INTER_CUBIC)

                # save the logits map
                if self.configer.get('test', 'save_prob'):
                    prob_path = os.path.join(self.save_dir, "prob/",
                                             '{}.npy'.format(names[k]))
                    FileHelper.make_dirs(prob_path, is_file=True)
                    np.save(prob_path, softmax(logits, axis=-1))

                label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
                # Undo reduce_zero_label so saved maps use the dataset id space.
                if self.configer.exists(
                        'data', 'reduce_zero_label') and self.configer.get(
                        'data', 'reduce_zero_label'):
                    label_img = label_img + 1
                    label_img = label_img.astype(np.uint8)
                if self.configer.exists('data', 'label_list'):
                    label_img_ = self.__relabel(label_img)
                else:
                    label_img_ = label_img
                label_img_ = Image.fromarray(label_img_, 'P')
                Log.info('{:4d}/{:4d} label map generated'.format(
                    image_id, self.test_size))
                label_path = os.path.join(self.save_dir, "label/",
                                          '{}.png'.format(names[k]))
                FileHelper.make_dirs(label_path, is_file=True)
                ImageHelper.save(label_img_, label_path)

                # colorize the label-map
                if os.environ.get('save_gt_label'):
                    if self.configer.exists(
                            'data', 'reduce_zero_label') and self.configer.get(
                            'data', 'reduce_zero_label'):
                        label_img = labels[k] + 1
                        label_img = np.asarray(label_img, dtype=np.uint8)
                    color_img_ = Image.fromarray(label_img)
                    color_img_.putpalette(colors)
                    vis_path = os.path.join(self.save_dir, "gt_vis/",
                                            '{}.png'.format(names[k]))
                    FileHelper.make_dirs(vis_path, is_file=True)
                    ImageHelper.save(color_img_, save_path=vis_path)
                else:
                    color_img_ = Image.fromarray(label_img)
                    color_img_.putpalette(colors)
                    vis_path = os.path.join(self.save_dir, "vis/",
                                            '{}.png'.format(names[k]))
                    FileHelper.make_dirs(vis_path, is_file=True)
                    ImageHelper.save(color_img_, save_path=vis_path)

        self.batch_time.update(time.time() - start_time)
        start_time = time.time()
    # Print the log info & reset the states.
    Log.info('Test Time {batch_time.sum:.3f}s'.format(
        batch_time=self.batch_time))
def make_input(self, image=None, input_size=None, min_side_length=None, max_side_length=None, scale=None):
    """Resize and normalize *image* into a (1, C, H, W) tensor for the network.

    Exactly one sizing policy applies, chosen by which arguments are given:
    an explicit input_size (-1 in either slot means 'derive from the aspect
    ratio'), a minimum side length, a maximum side length, or both (the min
    target bounded by the max). The result is further multiplied by *scale*.
    """
    if input_size is not None and min_side_length is None and max_side_length is None:
        if input_size[0] == -1 and input_size[1] == -1:
            # Keep the original size.
            in_width, in_height = ImageHelper.get_size(image)
        elif input_size[0] != -1 and input_size[1] != -1:
            # Fully explicit target size.
            in_width, in_height = input_size
        elif input_size[0] == -1 and input_size[1] != -1:
            # Fixed height; width follows the aspect ratio.
            width, height = ImageHelper.get_size(image)
            scale_ratio = input_size[1] / height
            w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
            in_width, in_height = int(round(width * w_scale_ratio)), int(
                round(height * h_scale_ratio))
        else:
            # Fixed width; height follows the aspect ratio.
            assert input_size[0] != -1 and input_size[1] == -1
            width, height = ImageHelper.get_size(image)
            scale_ratio = input_size[0] / width
            w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
            in_width, in_height = int(round(width * w_scale_ratio)), int(
                round(height * h_scale_ratio))
    elif input_size is None and min_side_length is not None and max_side_length is None:
        # Scale so the shorter side equals min_side_length.
        width, height = ImageHelper.get_size(image)
        scale_ratio = min_side_length / min(width, height)
        w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
        in_width, in_height = int(round(width * w_scale_ratio)), int(
            round(height * h_scale_ratio))
    elif input_size is None and min_side_length is None and max_side_length is not None:
        # Scale so the longer side equals max_side_length.
        width, height = ImageHelper.get_size(image)
        scale_ratio = max_side_length / max(width, height)
        w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
        in_width, in_height = int(round(width * w_scale_ratio)), int(
            round(height * h_scale_ratio))
    elif input_size is None and min_side_length is not None and max_side_length is not None:
        # Shorter side -> min_side_length, without the longer side exceeding max_side_length.
        width, height = ImageHelper.get_size(image)
        scale_ratio = min_side_length / min(width, height)
        bound_scale_ratio = max_side_length / max(width, height)
        scale_ratio = min(scale_ratio, bound_scale_ratio)
        w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
        in_width, in_height = int(round(width * w_scale_ratio)), int(
            round(height * h_scale_ratio))
    else:
        in_width, in_height = ImageHelper.get_size(image)
    # Apply the extra test-time scale, then normalize and add the batch dim.
    image = ImageHelper.resize(
        image, (int(in_width * scale), int(in_height * scale)),
        interpolation='cubic')
    img_tensor = ToTensor()(image)
    img_tensor = Normalize(div_value=self.configer.get('normalize', 'div_value'),
                           mean=self.configer.get('normalize', 'mean'),
                           std=self.configer.get('normalize', 'std'))(img_tensor)
    img_tensor = img_tensor.unsqueeze(0).to(
        torch.device('cpu' if self.configer.get('gpu') is None else 'cuda'))
    return img_tensor
def __getitem__(self, index):
    """Load one (image, label, offset-h, offset-w) training sample.

    Offset maps come from precomputed .mat files; when the 'train_no_offset'
    env var is set during training, the offsets are zeroed (ablation switch).
    ``meta`` keeps the untransformed targets for original-resolution scoring.
    """
    img = ImageHelper.read_image(
        self.img_list[index],
        tool=self.configer.get('data', 'image_tool'),
        mode=self.configer.get('data', 'input_mode'))
    img_size = ImageHelper.get_size(img)
    labelmap = ImageHelper.read_image(self.label_list[index],
                                      tool=self.configer.get('data', 'image_tool'),
                                      mode='P')
    offsetmap_h = self._load_mat(self.offset_h_list[index])
    offsetmap_w = self._load_mat(self.offset_w_list[index])
    # Ablation switch: train without offset supervision.
    if os.environ.get('train_no_offset') and self.dataset == 'train':
        offsetmap_h = np.zeros_like(offsetmap_h)
        offsetmap_w = np.zeros_like(offsetmap_w)
    if self.configer.exists('data', 'label_list'):
        labelmap = self._encode_label(labelmap)
    # NOTE: compares with == True because the config value may be a non-bool.
    if self.configer.exists('data', 'reduce_zero_label') and self.configer.get(
            'data', 'reduce_zero_label') == True:
        labelmap = self._reduce_zero_label(labelmap)
    # Log.info('use dataset {}'.format(self.configer.get('dataset')))
    # np.int was removed in NumPy 1.24; the builtin int is the exact drop-in alias.
    ori_target = ImageHelper.tonp(labelmap).astype(int)
    ori_target[ori_target == 255] = -1  # 255 is the ignore index; scoring expects -1
    ori_offset_h = np.array(offsetmap_h)
    ori_offset_w = np.array(offsetmap_w)
    if self.aug_transform is not None:
        img, labelmap, offsetmap_h, offsetmap_w = self.aug_transform(
            img, labelmap=labelmap, offset_h_map=offsetmap_h,
            offset_w_map=offsetmap_w)
    border_size = ImageHelper.get_size(img)
    if self.img_transform is not None:
        img = self.img_transform(img)
    if self.label_transform is not None:
        labelmap = self.label_transform(labelmap)
    offsetmap_h = torch.from_numpy(np.array(offsetmap_h)).long()
    offsetmap_w = torch.from_numpy(np.array(offsetmap_w)).long()
    meta = dict(
        ori_img_size=img_size,
        border_size=border_size,
        ori_target=ori_target,
        ori_offset_h=ori_offset_h,
        ori_offset_w=ori_offset_w,
    )
    return dict(
        img=DataContainer(img, stack=self.is_stack),
        labelmap=DataContainer(labelmap, stack=self.is_stack),
        offsetmap_h=DataContainer(offsetmap_h, stack=self.is_stack),
        offsetmap_w=DataContainer(offsetmap_w, stack=self.is_stack),
        meta=DataContainer(meta, stack=False, cpu_only=True),
        name=DataContainer(self.name_list[index], stack=False, cpu_only=True),
    )