def __test_img(self, image_path, label_path, vis_path, raw_path):
    """Run inference on one image file and write its three outputs.

    The image at ``image_path`` is read, passed through the inference
    routine selected by the ``test.mode`` config entry, and then:

    * the colorized prediction drawn by ``self.seg_parser`` is saved to
      ``vis_path``;
    * the image as read is saved to ``raw_path``;
    * the (optionally relabelled / shifted) label map is saved to
      ``label_path`` as a palette-mode image.

    An unsupported ``test.mode`` is logged and terminates the process
    with exit status 1.
    """
    Log.info('Image Path: {}'.format(image_path))
    image_tool = self.configer.get('data', 'image_tool')
    input_mode = self.configer.get('data', 'input_mode')
    ori_image = ImageHelper.read_image(image_path, tool=image_tool, mode=input_mode)

    # Every supported mode name is also the name of the method implementing it.
    known_modes = ('ss_test', 'sscrop_test', 'ms_test', 'mscrop_test')
    mode = self.configer.get('test', 'mode')
    total_logits = None
    if mode in known_modes:
        total_logits = getattr(self, mode)(ori_image)
    else:
        Log.error('Invalid test mode:{}'.format(mode))
        exit(1)

    # Per-pixel class index = argmax over the last (class) axis.
    label_img = np.array(np.argmax(total_logits, axis=-1), dtype=np.uint8)

    bgr_canvas = ImageHelper.get_cv2_bgr(ori_image, mode=input_mode)
    overlay = self.seg_parser.colorize(label_img, image_canvas=bgr_canvas)
    ImageHelper.save(overlay, save_path=vis_path)
    ImageHelper.save(ori_image, save_path=raw_path)

    if self.configer.exists('data', 'label_list'):
        label_img = self.__relabel(label_img)

    shift_by_one = (self.configer.exists('data', 'reduce_zero_label')
                    and self.configer.get('data', 'reduce_zero_label'))
    if shift_by_one:
        label_img = (label_img + 1).astype(np.uint8)

    palette_img = Image.fromarray(label_img, 'P')
    Log.info('Label Path: {}'.format(label_path))
    ImageHelper.save(palette_img, label_path)
def test(self, img_path=None, output_dir=None, data_loader=None):
    """Segment a single image and save the predicted label map as a PNG.

    The image is read with OpenCV, run through the validation-split
    augmentation and tensor normalization, fed to ``self.ss_test`` and the
    resulting logits are resized back to the original image size before
    the per-pixel argmax is taken.

    Args:
        img_path: Path of the image to segment. When ``None`` the path is
            taken from the ``input_image`` config entry.
        output_dir: Directory that receives ``label_<name>.png``. When
            ``None`` the file is written under ``self.save_dir``.
        data_loader: Unused by this method.

    Raises:
        RuntimeError: If the input image cannot be read.

    NOTE(review): a second ``test`` definition follows this one in the
    file; if both live in the same class the later one replaces this
    method -- confirm which one is intended to be reachable.
    """
    self.seg_net.eval()
    start_time = time.time()

    Log.info('save dir {}'.format(self.save_dir))
    FileHelper.make_dirs(self.save_dir, is_file=False)

    input_path = img_path if img_path is not None else self.configer.get('input_image')
    input_image = cv2.imread(input_path)
    if input_image is None:
        # cv2.imread signals a missing/undecodable file by returning None
        # instead of raising; fail here with a message naming the path.
        raise RuntimeError('Failed to read image: {}'.format(input_path))

    transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize(div_value=self.configer.get('normalize', 'div_value'),
                        mean=self.configer.get('normalize', 'mean'),
                        std=self.configer.get('normalize', 'std')),
    ])
    aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer, split='val')

    # cv2.resize takes the target size as (width, height).
    height, width = input_image.shape[:2]
    ori_img_size = (width, height)

    # The augmentation pipeline returns a sequence; the image is its first item.
    input_image = aug_val_transform(input_image)[0]
    input_image = transform(input_image)

    with torch.no_grad():
        # Forward pass on a one-element batch.
        outputs = self.ss_test([input_image])
        # Move the channel axis last so outputs[0] is (H, W, C) for OpenCV.
        if isinstance(outputs, torch.Tensor):
            outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
        else:
            outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                       for output in outputs]

    logits = cv2.resize(outputs[0], ori_img_size, interpolation=cv2.INTER_CUBIC)
    label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)

    if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
        label_img = (label_img + 1).astype(np.uint8)

    if self.configer.exists('data', 'label_list'):
        label_img = self.__relabel(label_img)
    label_pil = Image.fromarray(label_img, 'P')

    # splitext keeps the full stem for names with several dots and does not
    # produce an empty name for files without an extension.
    input_name = os.path.splitext(os.path.basename(input_path))[0]
    target_dir = self.save_dir if output_dir is None else output_dir
    label_path = os.path.join(target_dir, 'label_{}.png'.format(input_name))
    FileHelper.make_dirs(label_path, is_file=True)
    ImageHelper.save(label_pil, label_path)

    self.batch_time.update(time.time() - start_time)

    # Print the log info.
    Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
def test(self, data_loader=None):
    """Run inference over ``self.test_loader`` and save the predictions.

    For every image in every batch the logits are cropped to the
    ``border_size`` recorded in the sample's meta, resized to
    ``ori_img_size`` and converted to a label map, which is written to
    ``<save_dir>/label/<name>.png``. A palette-colorized copy goes to
    ``<save_dir>/vis/`` (or ``<save_dir>/gt_vis/`` when the
    ``save_gt_label`` environment variable is set). When the
    ``test.save_prob`` config entry is truthy, per-pixel softmax
    probabilities are also stored under ``<save_dir>/prob/``.

    Args:
        data_loader: Unused; batches are always drawn from
            ``self.test_loader``.

    Raises:
        RuntimeError: If the configured dataset has no color palette, if
            ``test.mode`` is not a supported mode, or if ground-truth
            labels are requested for visualization but were not loaded.
    """
    self.seg_net.eval()
    start_time = time.time()
    image_id = 0

    Log.info('save dir {}'.format(self.save_dir))
    FileHelper.make_dirs(self.save_dir, is_file=False)

    dataset = self.configer.get('dataset')
    if dataset in ['cityscapes', 'gta5']:
        colors = get_cityscapes_colors()
    elif dataset == 'ade20k':
        colors = get_ade_colors()
    elif dataset == 'lip':
        colors = get_lip_colors()
    elif dataset == 'pascal_context':
        colors = get_pascal_context_colors()
    elif dataset == 'pascal_voc':
        colors = get_pascal_voc_colors()
    elif dataset == 'coco_stuff':
        colors = get_cocostuff_colors()
    else:
        raise RuntimeError("Unsupport colors")

    # Loop-invariant switches, looked up once instead of once per image.
    save_prob = bool(self.configer.get('test', 'save_prob'))
    save_gt_label = os.environ.get('save_gt_label')
    reduce_zero_label = (self.configer.exists('data', 'reduce_zero_label')
                         and self.configer.get('data', 'reduce_zero_label'))
    has_label_list = self.configer.exists('data', 'label_list')
    use_offline_offset = (self.configer.exists('data', 'use_offset')
                          and self.configer.get('data', 'use_offset') == 'offline')
    load_gt_labels = 'val' in self.save_dir and save_gt_label

    def softmax(X, axis=0):
        # Subtract the max for numerical stability; works on a copy so the
        # caller's logits are left untouched.
        shifted = X - np.max(X, axis=axis, keepdims=True)
        exp = np.exp(shifted)
        return exp / np.sum(exp, axis=axis, keepdims=True)

    for data_dict in self.test_loader:
        inputs = data_dict['img']
        names = data_dict['name']
        metas = data_dict['meta']
        labels = data_dict['labelmap'] if load_gt_labels else None

        with torch.no_grad():
            # Forward pass.
            if use_offline_offset:
                offset_h_maps = data_dict['offsetmap_h']
                offset_w_maps = data_dict['offsetmap_w']
                outputs = self.offset_test(inputs, offset_h_maps, offset_w_maps)
            else:
                test_mode = self.configer.get('test', 'mode')
                if test_mode == 'ss_test':
                    outputs = self.ss_test(inputs)
                elif test_mode == 'ms_test':
                    outputs = self.ms_test(inputs)
                elif test_mode == 'ms_test_depth':
                    outputs = self.ms_test_depth(inputs, names)
                elif test_mode == 'sscrop_test':
                    crop_size = self.configer.get('test', 'crop_size')
                    outputs = self.sscrop_test(inputs, crop_size)
                elif test_mode == 'mscrop_test':
                    crop_size = self.configer.get('test', 'crop_size')
                    outputs = self.mscrop_test(inputs, crop_size)
                elif test_mode == 'crf_ss_test':
                    outputs = self.ss_test(inputs)
                    outputs = self.dense_crf_process(inputs, outputs)
                else:
                    # Without this branch an unknown mode left `outputs`
                    # unbound and failed later with an unrelated error.
                    raise RuntimeError('Invalid test mode:{}'.format(test_mode))

            # Move the channel axis last so each item is (H, W, C) for OpenCV.
            if isinstance(outputs, torch.Tensor):
                outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                n = outputs.shape[0]
            else:
                outputs = [
                    output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                    for output in outputs
                ]
                n = len(outputs)

        for k in range(n):
            image_id += 1
            ori_img_size = metas[k]['ori_img_size']
            border_size = metas[k]['border_size']
            # border_size is indexed as (width, height): crop rows by [1]
            # and columns by [0] before resizing to the original size.
            logits = cv2.resize(
                outputs[k][:border_size[1], :border_size[0]],
                tuple(ori_img_size),
                interpolation=cv2.INTER_CUBIC)

            # save the probability map
            if save_prob:
                prob_path = os.path.join(self.save_dir, "prob/",
                                         '{}.npy'.format(names[k]))
                FileHelper.make_dirs(prob_path, is_file=True)
                np.save(prob_path, softmax(logits, axis=-1))

            label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
            if reduce_zero_label:
                label_img = (label_img + 1).astype(np.uint8)

            if has_label_list:
                label_img_ = self.__relabel(label_img)
            else:
                label_img_ = label_img
            label_img_ = Image.fromarray(label_img_, 'P')
            Log.info('{:4d}/{:4d} label map generated'.format(
                image_id, self.test_size))
            label_path = os.path.join(self.save_dir, "label/",
                                      '{}.png'.format(names[k]))
            FileHelper.make_dirs(label_path, is_file=True)
            ImageHelper.save(label_img_, label_path)

            # colorize the label-map
            if save_gt_label:
                if reduce_zero_label:
                    if labels is None:
                        # Labels are only loaded when 'val' is part of the
                        # save dir; name the cause instead of failing on an
                        # unbound variable.
                        raise RuntimeError(
                            "save_gt_label is set but ground-truth labels "
                            "were not loaded ('val' not in save dir)")
                    label_img = labels[k] + 1
                # NOTE(review): when reduce_zero_label is off, the predicted
                # label map (not the ground truth) is what gets written to
                # gt_vis/ -- confirm this is intended.
                label_img = np.asarray(label_img, dtype=np.uint8)
                vis_dir = "gt_vis/"
            else:
                vis_dir = "vis/"
            color_img_ = Image.fromarray(label_img)
            color_img_.putpalette(colors)
            vis_path = os.path.join(self.save_dir, vis_dir,
                                    '{}.png'.format(names[k]))
            FileHelper.make_dirs(vis_path, is_file=True)
            ImageHelper.save(color_img_, save_path=vis_path)

        self.batch_time.update(time.time() - start_time)
        start_time = time.time()

    # Print the log info.
    Log.info('Test Time {batch_time.sum:.3f}s'.format(
        batch_time=self.batch_time))