def __init__(self, configer):
    """Store the configuration, build the helper objects and load the model.

    Args:
        configer: Project configuration object. It is handed to every helper
            constructor and queried with ``configer.get('gpu')`` to choose
            the torch device.
    """
    self.configer = configer

    # Helper objects, each constructed from the same configuration.
    self.blob_helper = BlobHelper(configer)
    self.seg_visualizer = SegVisualizer(configer)
    self.seg_parser = SegParser(configer)
    self.seg_model_manager = ModelManager(configer)
    self.seg_data_loader = DataLoader(configer)

    # No 'gpu' entry in the configuration means inference runs on the CPU.
    if self.configer.get('gpu') is None:
        device_name = 'cpu'
    else:
        device_name = 'cuda'
    self.device = torch.device(device_name)

    # The network is created and loaded by _init_model().
    self.seg_net = None
    self._init_model()
class FCNSegmentorTest(object):
    """Run a segmentation network over a test data loader and save the results.

    For every sample the class writes a colorized visualization to
    ``<out_dir>/vis/<name>.png`` and a palette-mode label image to
    ``<out_dir>/label/<name>.png``.

    NOTE(review): this file defines ``FCNSegmentorTest`` more than once; the
    definition that appears last is the one bound to the name at import time.
    """

    def __init__(self, configer):
        """Store the configuration, build the helper objects and load the model.

        Args:
            configer: Project configuration object, passed to every helper
                and queried through ``get`` / ``exists``.
        """
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.test_loader = TestDataLoader(configer)
        # No 'gpu' entry in the configuration means inference runs on the CPU.
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.seg_net = None
        self._init_model()

    def _init_model(self):
        """Build the network, load its weights and switch it to eval mode."""
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)
        self.seg_net.eval()

    def test(self, test_dir, out_dir):
        """Predict every batch from the test loader and save the outputs.

        Args:
            test_dir: Directory forwarded to ``get_testloader``.
            out_dir: Root directory for the ``vis/`` and ``label/`` outputs.

        Raises:
            SystemExit: With code 1 when the configured ``('test', 'mode')``
                is not one of ``ss_test``, ``sscrop_test``, ``ms_test`` or
                ``mscrop_test``.
        """
        predictors = {
            'ss_test': self.ss_test,
            'sscrop_test': self.sscrop_test,
            'ms_test': self.ms_test,
            'mscrop_test': self.mscrop_test,
        }
        for data_dict in self.test_loader.get_testloader(test_dir=test_dir):
            mode = self.configer.get('test', 'mode')
            predictor = predictors.get(mode)
            if predictor is None:
                Log.error('Invalid test mode:{}'.format(mode))
                # The builtin exit() is installed by the site module and is
                # not guaranteed to exist; SystemExit always is.
                raise SystemExit(1)

            total_logits = predictor(data_dict)
            meta_list = DCHelper.tolist(data_dict['meta'])
            img_list = DCHelper.tolist(data_dict['img'])
            for idx, meta in enumerate(meta_list):
                self._save_prediction(total_logits[idx], img_list[idx][0],
                                      meta, out_dir)

    def _save_prediction(self, logits, image_tensor, meta, out_dir):
        """Write the visualization and the label image for one sample."""
        # splitext keeps dots inside the stem ('a.b.png' -> 'a.b'), so two
        # differently named inputs cannot collapse onto the same output file.
        filename = os.path.splitext(os.path.basename(meta['img_path']))[0]
        label_map = np.argmax(logits, axis=-1)
        label_img = np.array(label_map, dtype=np.uint8)

        ori_img_bgr = self.blob_helper.tensor2bgr(image_tensor)
        ori_img_bgr = ImageHelper.resize(ori_img_bgr,
                                         target_size=meta['ori_img_size'],
                                         interpolation='linear')
        image_canvas = self.seg_parser.colorize(label_img,
                                                image_canvas=ori_img_bgr)
        ImageHelper.save(image_canvas,
                         save_path=os.path.join(
                             out_dir, 'vis/{}.png'.format(filename)))

        if self.configer.exists('data', 'label_list'):
            label_img = self.__relabel(label_img)

        if self.configer.exists('data', 'reduce_zero_label') \
                and self.configer.get('data', 'reduce_zero_label'):
            label_img = label_img + 1
            label_img = label_img.astype(np.uint8)

        label_img = Image.fromarray(label_img, 'P')
        label_path = os.path.join(out_dir, 'label/{}.png'.format(filename))
        Log.info('Label Path: {}'.format(label_path))
        ImageHelper.save(label_img, label_path)

    def ss_test(self, in_data_dict):
        """Single-scale prediction on the whole image (scale 1.0)."""
        data_dict = self.blob_helper.get_blob(in_data_dict, scale=1.0)
        return self._predict(data_dict)

    def sscrop_test(self, in_data_dict):
        """Single-scale prediction, using sliding crops when images allow it."""
        data_dict = self.blob_helper.get_blob(in_data_dict, scale=1.0)
        return self._predict_whole_or_crops(data_dict)

    def mscrop_test(self, in_data_dict):
        """Multi-scale + horizontal-flip prediction with sliding crops.

        Returns:
            list of float32 arrays, one per sample, holding the summed logits
            at the original image size.
        """
        return self._accumulate_multi_scale(in_data_dict,
                                            self._predict_whole_or_crops)

    def ms_test(self, in_data_dict):
        """Multi-scale + horizontal-flip prediction on whole images.

        Returns:
            list of float32 arrays, one per sample, holding the summed logits
            at the original image size.
        """
        return self._accumulate_multi_scale(in_data_dict, self._predict)

    def _predict_whole_or_crops(self, data_dict):
        """Use crop prediction unless some image is smaller than the crop."""
        crop_size = self.configer.get('test', 'crop_size')
        # crop_size[0] is compared with the last tensor axis and crop_size[1]
        # with the one before it.
        has_small_image = any(
            image.size()[3] < crop_size[0] or image.size()[2] < crop_size[1]
            for image in DCHelper.tolist(data_dict['img']))
        if has_small_image:
            return self._predict(data_dict)
        return self._crop_predict(data_dict, crop_size)

    def _accumulate_multi_scale(self, in_data_dict, predict_fn):
        """Sum ``predict_fn`` outputs over all scales, plain then mirrored.

        Args:
            in_data_dict: Batch dictionary with 'meta' (and image) entries.
            predict_fn: Callable mapping a blob dictionary to a list of
                per-sample logit arrays at the original image size.
        """
        num_classes = self.configer.get('data', 'num_classes')
        # 'ori_img_size' is indexed [1] for rows and [0] for columns.
        total_logits = [
            np.zeros((meta['ori_img_size'][1], meta['ori_img_size'][0],
                      num_classes), np.float32)
            for meta in DCHelper.tolist(in_data_dict['meta'])
        ]
        scales = self.configer.get('test', 'scale_search')

        for scale in scales:
            data_dict = self.blob_helper.get_blob(in_data_dict, scale=scale)
            results = predict_fn(data_dict)
            for idx, accumulated in enumerate(total_logits):
                accumulated += results[idx]

        for scale in scales:
            data_dict = self.blob_helper.get_blob(in_data_dict, scale=scale,
                                                  flip=True)
            results = predict_fn(data_dict)
            for idx, accumulated in enumerate(total_logits):
                # Undo the horizontal flip before accumulating.
                accumulated += results[idx][:, ::-1]

        return total_logits

    def _resize_to_original(self, logits, meta):
        """Drop the padded border and resize logits to the original size."""
        valid = logits[:meta['border_hw'][0], :meta['border_hw'][1]]
        return cv2.resize(valid,
                          (meta['ori_img_size'][0], meta['ori_img_size'][1]),
                          interpolation=cv2.INTER_CUBIC)

    def _crop_predict(self, data_dict, crop_size):
        """Predict each image from overlapping crops and average the overlaps.

        Args:
            data_dict: Blob dictionary with 'img' and 'meta' entries.
            crop_size: Pair used as ``crop_size[0]`` along the width axis and
                ``crop_size[1]`` along the height axis.

        Returns:
            list of per-sample logit arrays resized to the original image size.
        """
        split_batch = []
        height_starts_list = []
        width_starts_list = []
        hw_list = []
        for image in DCHelper.tolist(data_dict['img']):
            height, width = image.size()[2:]
            hw_list.append([height, width])
            np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            height_starts = self._decide_intersection(height, crop_size[1])
            width_starts = self._decide_intersection(width, crop_size[0])
            crops = [
                np_image[top:top + crop_size[1],
                         left:left + crop_size[0]][np.newaxis, :]
                for top in height_starts
                for left in width_starts
            ]
            height_starts_list.append(height_starts)
            width_starts_list.append(width_starts)
            # (n, crop_height, crop_width, channels)
            stacked = np.concatenate(crops, axis=0)
            inputs = torch.from_numpy(stacked).permute(0, 3, 1,
                                                       2).to(self.device)
            split_batch.append(inputs)

        out_list = []
        with torch.no_grad():
            results = self.seg_net.forward(
                DCHelper.todc(split_batch, stack=True, samples_per_gpu=1))
            for res in results:
                out_list.append(res[-1].permute(0, 2, 3, 1).cpu().numpy())

        num_classes = self.configer.get('data', 'num_classes')
        total_logits = []
        for idx, (height, width) in enumerate(hw_list):
            summed = np.zeros((height, width, num_classes), np.float32)
            # One count per pixel; it broadcasts over the class axis.
            counts = np.zeros((height, width, 1), np.float32)
            crop_index = 0
            for top in height_starts_list[idx]:
                for left in width_starts_list[idx]:
                    summed[top:top + crop_size[1],
                           left:left + crop_size[0]] += out_list[idx][crop_index]
                    counts[top:top + crop_size[1],
                           left:left + crop_size[0]] += 1
                    crop_index += 1
            # Pixels no crop covered (possible when the stride exceeds the
            # crop) keep a zero logit instead of becoming NaN.
            summed /= np.maximum(counts, 1.0)
            total_logits.append(summed)

        for idx, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
            total_logits[idx] = self._resize_to_original(total_logits[idx],
                                                         meta)

        return total_logits

    def _decide_intersection(self, total_length, crop_length):
        """Return the crop start offsets that cover ``total_length``.

        Offsets advance by ``int(crop_length * crop_stride_ratio)``; a final
        offset ``total_length - crop_length`` is appended when the regular
        grid leaves the tail uncovered.

        Args:
            total_length: Length of the axis to cover.
            crop_length: Length of one crop along that axis.

        Returns:
            list of int start offsets. ``[0]`` when the axis is not longer
            than one crop.

        Raises:
            ValueError: If the configured ratio yields a non-positive stride.
        """
        stride = int(crop_length *
                     self.configer.get('test', 'crop_stride_ratio'))
        if stride <= 0:
            raise ValueError(
                'crop_stride_ratio yields a non-positive stride '
                '({}) for crop length {}'.format(stride, crop_length))
        if total_length <= crop_length:
            # A single crop anchored at the origin already spans the axis.
            return [0]

        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * step for step in range(times)]
        if total_length - cropped_starting[-1] > crop_length:
            # Must cover the total image.
            cropped_starting.append(total_length - crop_length)

        return cropped_starting

    def _predict(self, data_dict):
        """Predict whole images and resize the logits to the original size."""
        with torch.no_grad():
            total_logits = []
            results = self.seg_net.forward(data_dict['img'])
            for res in results:
                total_logits.append(res[-1].squeeze(0).permute(
                    1, 2, 0).cpu().numpy())

            for idx, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
                total_logits[idx] = self._resize_to_original(
                    total_logits[idx], meta)

        return total_logits

    def __relabel(self, label_map):
        """Map class indices to the ids configured in ``('data', 'label_list')``.

        Pixels whose value is not below ``num_classes`` are left at 0.
        """
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        label_list = self.configer.get('data', 'label_list')
        for class_index in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == class_index] = label_list[class_index]

        return label_dst

    def debug(self, vis_dir):
        """Write and display up to 20 colorized training samples, then exit.

        Args:
            vis_dir: Directory that receives the ``<batch>_<index>_vis.png``
                files.

        Raises:
            SystemExit: With code 1 once more than 20 samples have been seen.
        """
        count = 0
        for i, data_dict in enumerate(self.seg_data_loader.get_trainloader()):
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            for j in range(inputs.size(0)):
                count = count + 1
                if count > 20:
                    raise SystemExit(1)

                image_bgr = self.blob_helper.tensor2bgr(inputs[j])
                label_map = targets[j].numpy()
                image_canvas = self.seg_parser.colorize(label_map,
                                                        image_canvas=image_bgr)
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_{}_vis.png'.format(i, j)),
                    image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()
class FCNSegmentorTest(object):
    """Run a segmentation network on single images and save the results.

    ``test_img`` reads one image, predicts per-pixel logits with the
    configured test mode, and writes a visualization, the raw image and a
    palette-mode label image.

    NOTE(review): this file defines ``FCNSegmentorTest`` more than once; the
    definition that appears last is the one bound to the name at import time.
    """

    def __init__(self, configer):
        """Store the configuration, build the helper objects and load the model.

        Args:
            configer: Project configuration object, passed to every helper
                and queried through ``get`` / ``exists``.
        """
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        # No 'gpu' entry in the configuration means inference runs on the CPU.
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.seg_net = None
        self._init_model()

    def _init_model(self):
        """Build the network, load its weights and switch it to eval mode."""
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)
        self.seg_net.eval()

    def _get_blob(self, ori_image, scale=None):
        """Convert an image to a network input tensor, padded to the stride.

        Args:
            ori_image: Image as returned by ``ImageHelper.read_image``.
            scale: Scale factor forwarded to ``make_input``; must not be None.

        Returns:
            tuple ``(image, border_hw)`` where ``border_hw`` is the
            ``[height, width]`` of the tensor before any stride padding.

        Raises:
            SystemExit: With code 1 when none of ``input_size``,
                ``min_side_length`` or ``max_side_length`` is configured
                under 'test'.
        """
        assert scale is not None
        if self.configer.exists('test', 'input_size'):
            image = self.blob_helper.make_input(
                image=ori_image,
                input_size=self.configer.get('test', 'input_size'),
                scale=scale)
        else:
            # Forward whichever side-length limits are configured.
            side_kwargs = {}
            for key in ('min_side_length', 'max_side_length'):
                if self.configer.exists('test', key):
                    side_kwargs[key] = self.configer.get('test', key)

            if not side_kwargs:
                Log.error('Test setting error')
                # The builtin exit() is installed by the site module and is
                # not guaranteed to exist; SystemExit always is.
                raise SystemExit(1)

            image = self.blob_helper.make_input(image=ori_image,
                                                scale=scale,
                                                **side_kwargs)

        b, c, h, w = image.size()
        border_hw = [h, w]
        if self.configer.exists('test', 'fit_stride'):
            stride = self.configer.get('test', 'fit_stride')
            # Zero-pad on the right and bottom up to the next stride multiple.
            pad_w = 0 if (w % stride == 0) else stride - (w % stride)
            pad_h = 0 if (h % stride == 0) else stride - (h % stride)
            expand_image = torch.zeros(
                (b, c, h + pad_h, w + pad_w)).to(image.device)
            expand_image[:, :, 0:h, 0:w] = image
            image = expand_image

        return image, border_hw

    def test_img(self, image_path, label_path, vis_path, raw_path):
        """Predict one image and save visualization, raw image and labels.

        Args:
            image_path: Path of the image to read.
            label_path: Destination of the palette-mode label image.
            vis_path: Destination of the colorized visualization.
            raw_path: Destination of the copy of the input image.

        Raises:
            SystemExit: With code 1 when the configured ``('test', 'mode')``
                is not one of ``ss_test``, ``sscrop_test``, ``ms_test`` or
                ``mscrop_test``.
        """
        Log.info('Image Path: {}'.format(image_path))
        ori_image = ImageHelper.read_image(
            image_path,
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        predictors = {
            'ss_test': self.ss_test,
            'sscrop_test': self.sscrop_test,
            'ms_test': self.ms_test,
            'mscrop_test': self.mscrop_test,
        }
        mode = self.configer.get('test', 'mode')
        predictor = predictors.get(mode)
        if predictor is None:
            Log.error('Invalid test mode:{}'.format(mode))
            raise SystemExit(1)

        total_logits = predictor(ori_image)
        label_map = np.argmax(total_logits, axis=-1)
        label_img = np.array(label_map, dtype=np.uint8)
        ori_img_bgr = ImageHelper.get_cv2_bgr(
            ori_image, mode=self.configer.get('data', 'input_mode'))
        image_canvas = self.seg_parser.colorize(label_img,
                                                image_canvas=ori_img_bgr)
        ImageHelper.save(image_canvas, save_path=vis_path)
        ImageHelper.save(ori_image, save_path=raw_path)

        if self.configer.exists('data', 'label_list'):
            label_img = self.__relabel(label_img)

        if self.configer.exists('data', 'reduce_zero_label') \
                and self.configer.get('data', 'reduce_zero_label'):
            label_img = label_img + 1
            label_img = label_img.astype(np.uint8)

        label_img = Image.fromarray(label_img, 'P')
        Log.info('Label Path: {}'.format(label_path))
        ImageHelper.save(label_img, label_path)

    def _empty_logits(self, ori_image):
        """Return a zeroed float32 accumulator and the (width, height) size."""
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        total_logits = np.zeros(
            (ori_height, ori_width, self.configer.get('data', 'num_classes')),
            np.float32)
        return total_logits, (ori_width, ori_height)

    def _whole_logits(self, image_source, scale, out_size, mirrored=False):
        """Predict the whole image at ``scale`` and resize to ``out_size``.

        When ``mirrored`` is true, ``image_source`` is expected to be the
        horizontally flipped image and the logits are flipped back.
        """
        image, border_hw = self._get_blob(image_source, scale=scale)
        results = self._predict(image)
        # Drop the stride padding before resizing.
        results = results[:border_hw[0], :border_hw[1]]
        if mirrored:
            results = results[:, ::-1]
        return cv2.resize(results, out_size, interpolation=cv2.INTER_CUBIC)

    def _crop_logits(self, ori_image, scale, out_size):
        """Predict at ``scale`` with sliding crops when the image is larger
        than the crop in both dimensions, otherwise on the whole image."""
        image, border_hw = self._get_blob(ori_image, scale=scale)
        crop_size = self.configer.get('test', 'crop_size')
        # crop_size[0] is compared with the last tensor axis and crop_size[1]
        # with the one before it.
        if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
            results = self._crop_predict(image, crop_size)
        else:
            results = self._predict(image)

        # Drop the stride padding so it is not stretched into the output,
        # matching what ss_test / ms_test do.
        results = results[:border_hw[0], :border_hw[1]]
        return cv2.resize(results, out_size, interpolation=cv2.INTER_CUBIC)

    def ss_test(self, ori_image):
        """Single-scale prediction on the whole image (scale 1.0)."""
        total_logits, out_size = self._empty_logits(ori_image)
        total_logits += self._whole_logits(ori_image, 1.0, out_size)
        return total_logits

    def sscrop_test(self, ori_image):
        """Single-scale prediction using sliding crops when possible."""
        total_logits, out_size = self._empty_logits(ori_image)
        total_logits += self._crop_logits(ori_image, 1.0, out_size)
        return total_logits

    def mscrop_test(self, ori_image):
        """Sum sliding-crop predictions over every configured scale."""
        total_logits, out_size = self._empty_logits(ori_image)
        for scale in self.configer.get('test', 'scale_search'):
            total_logits += self._crop_logits(ori_image, scale, out_size)

        return total_logits

    def ms_test(self, ori_image):
        """Sum whole-image predictions over every scale, plain and mirrored."""
        total_logits, out_size = self._empty_logits(ori_image)

        # The mirrored image does not depend on the scale; build it once.
        if self.configer.get('data', 'image_tool') == 'cv2':
            mirror_image = cv2.flip(ori_image, 1)
        else:
            mirror_image = ori_image.transpose(Image.FLIP_LEFT_RIGHT)

        for scale in self.configer.get('test', 'scale_search'):
            total_logits += self._whole_logits(ori_image, scale, out_size)
            # The mirrored pass uses the same scale as the plain pass rather
            # than a fixed 1.0, so each scale contributes both orientations.
            total_logits += self._whole_logits(mirror_image, scale, out_size,
                                               mirrored=True)

        return total_logits

    def _crop_predict(self, image, crop_size):
        """Predict from overlapping crops and average the overlapping pixels.

        Args:
            image: Input tensor; the first axis is squeezed, so a batch of one
                is expected.
            crop_size: Pair used as ``crop_size[0]`` along the width axis and
                ``crop_size[1]`` along the height axis.

        Returns:
            float32 array of logits with the spatial size of ``image``.
        """
        height, width = image.size()[2:]
        np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        height_starts = self._decide_intersection(height, crop_size[1])
        width_starts = self._decide_intersection(width, crop_size[0])
        crops = [
            np_image[top:top + crop_size[1],
                     left:left + crop_size[0]][np.newaxis, :]
            for top in height_starts
            for left in width_starts
        ]
        # (n, crop_height, crop_width, channels)
        stacked = np.concatenate(crops, axis=0)
        inputs = torch.from_numpy(stacked).permute(0, 3, 1, 2).to(self.device)
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].permute(0, 2, 3, 1).cpu().numpy()

        reassemble = np.zeros(
            (np_image.shape[0], np_image.shape[1], results.shape[-1]),
            np.float32)
        # One count per pixel; it broadcasts over the class axis.
        counts = np.zeros((np_image.shape[0], np_image.shape[1], 1),
                          np.float32)
        crop_index = 0
        for top in height_starts:
            for left in width_starts:
                reassemble[top:top + crop_size[1],
                           left:left + crop_size[0]] += results[crop_index]
                counts[top:top + crop_size[1],
                       left:left + crop_size[0]] += 1
                crop_index += 1

        # Average instead of summing so overlapped pixels are not weighted
        # more heavily; uncovered pixels (stride larger than the crop) keep a
        # zero logit instead of becoming NaN.
        reassemble /= np.maximum(counts, 1.0)
        return reassemble

    def _decide_intersection(self, total_length, crop_length):
        """Return the crop start offsets that cover ``total_length``.

        Offsets advance by ``int(crop_length * crop_stride_ratio)``; a final
        offset ``total_length - crop_length`` is appended when the regular
        grid leaves the tail uncovered.

        Args:
            total_length: Length of the axis to cover.
            crop_length: Length of one crop along that axis.

        Returns:
            list of int start offsets. ``[0]`` when the axis is not longer
            than one crop.

        Raises:
            ValueError: If the configured ratio yields a non-positive stride.
        """
        stride = int(crop_length *
                     self.configer.get('test', 'crop_stride_ratio'))
        if stride <= 0:
            raise ValueError(
                'crop_stride_ratio yields a non-positive stride '
                '({}) for crop length {}'.format(stride, crop_length))
        if total_length <= crop_length:
            # A single crop anchored at the origin already spans the axis.
            return [0]

        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * step for step in range(times)]
        if total_length - cropped_starting[-1] > crop_length:
            # Must cover the total image.
            cropped_starting.append(total_length - crop_length)

        return cropped_starting

    def _predict(self, inputs):
        """Run the network on ``inputs`` and return channel-last logits.

        The last element of the network output is used; its first axis is
        squeezed, so a batch of one is expected.
        """
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].squeeze(0).permute(1, 2, 0).cpu().numpy()

        return results

    def __relabel(self, label_map):
        """Map class indices to the ids configured in ``('data', 'label_list')``.

        Pixels whose value is not below ``num_classes`` are left at 0.
        """
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        label_list = self.configer.get('data', 'label_list')
        for class_index in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == class_index] = label_list[class_index]

        return label_dst

    def debug(self, vis_dir):
        """Write and display up to 20 colorized training samples, then exit.

        Args:
            vis_dir: Directory that receives the ``<batch>_<index>_vis.png``
                files.

        Raises:
            SystemExit: With code 1 once more than 20 samples have been seen.
        """
        count = 0
        for i, data_dict in enumerate(self.seg_data_loader.get_trainloader()):
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            for j in range(inputs.size(0)):
                count = count + 1
                if count > 20:
                    raise SystemExit(1)

                image_bgr = self.blob_helper.tensor2bgr(inputs[j])
                label_map = targets[j].numpy()
                image_canvas = self.seg_parser.colorize(label_map,
                                                        image_canvas=image_bgr)
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_{}_vis.png'.format(i, j)),
                    image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()