class FCNSegmentor(object):
    """
      Runner for training and validating a segmentation network.

      Builds the network, optimizer/scheduler, data loaders and loss from
      the given configer, then exposes `train` (one pass over the train
      loader) and `val` (one pass over a validation loader).
    """

    def __init__(self, configer):
        """
          Create the meters and helpers, then build the model.

          Args:
              configer: project configuration object; only its `get(...)`
                  accessor is used by this class.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.loss = None
        # train() increments both counters with `+= 1`, so they must exist
        # before the first iteration.
        # NOTE(review): RunnerHelper.load_net receives this runner and
        # presumably overwrites runner_state when resuming - confirm.
        self.runner_state = dict(iters=0, epoch=0)

        self._init_model()

    def _init_model(self):
        """
          Build the network, optimizer, scheduler, data loaders and loss.
        """
        self.seg_net = self.seg_model_manager.get_seg_model()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)

        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.configer.get('solver'))

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

        self.loss = self.seg_model_manager.get_seg_loss()

    def _get_parameters(self):
        """
          Split the network parameters into two optimizer groups.

          Parameters whose name contains 'backbone' form the first group,
          all others the second.  Both groups currently get `base_lr`
          (the second group's multiplier is 1.0).

          Returns:
              list of two dicts with the keys 'params' and 'lr'.
        """
        backbone_params = []
        other_params = []
        for name, param in self.seg_net.named_parameters():
            if 'backbone' in name:
                backbone_params.append(param)
            else:
                other_params.append(param)

        base_lr = self.configer.get('solver', 'lr')['base_lr']
        return [
            {'params': backbone_params, 'lr': base_lr},
            {'params': other_params, 'lr': base_lr * 1.0},
        ]

    def train(self):
        """
          Train function of every epoch during train phase.

          Runs one pass over the train loader.  After every iteration it
          may log, save a checkpoint, stop early once `max_iters` is
          reached (iteration-based schedules only), or trigger validation.
        """
        self.seg_net.train()
        start_time = time.time()
        for data_dict in self.train_loader:
            # Let the trainer adjust the learning rate for this iteration.
            Trainer.update(self, warm_list=(0, ),
                           solver_dict=self.configer.get('solver'))
            self.data_time.update(time.time() - start_time)

            # Forward pass.
            data_dict = RunnerHelper.to_device(self, data_dict)
            out = self.seg_net(data_dict)

            # Compute the loss of the train batch & backward.
            loss_dict = self.loss(out)
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['img'].size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1
            iters = self.runner_state['iters']

            # Print the log info & reset the states.
            display_iter = self.configer.get('solver', 'display_iter')
            if iters % display_iter == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        iters,
                        display_iter,
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Only the process with local rank 0 writes checkpoints.
            if iters % self.configer.get('solver.save_iters') == 0 \
                    and self.configer.get('local_rank') == 0:
                RunnerHelper.save_net(self, self.seg_net)

            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and iters == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model (skipped when distributed).
            if iters % self.configer.get('solver', 'test_interval') == 0 \
                    and not self.configer.get('network.distributed'):
                self.val()

        self.runner_state['epoch'] += 1

    def val(self, data_loader=None):
        """
          Validation function during the train phase.

          Args:
              data_loader: loader to evaluate on; defaults to the
                  validation loader built in `_init_model`.

          Stores the mean IoU and the average 'loss' in `runner_state`,
          hands both to `RunnerHelper.save_net`, logs the results, resets
          the validation meters and puts the network back in train mode.
        """
        self.seg_net.eval()
        start_time = time.time()

        data_loader = self.val_loader if data_loader is None else data_loader
        for data_dict in data_loader:
            data_dict = RunnerHelper.to_device(self, data_dict)
            with torch.no_grad():
                # Forward pass.
                out = self.seg_net(data_dict)
                # Compute the loss of the val batch.
                loss_dict = self.loss(out)
                out_dict, _ = RunnerHelper.gather(self, out)

            self.val_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['img'].size(0))
            self._update_running_score(out_dict['out'],
                                       DCHelper.tolist(data_dict['meta']))

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        mean_iou = self.seg_running_score.get_mean_iou()
        val_loss = self.val_losses.avg['loss']
        self.runner_state['performance'] = mean_iou
        self.runner_state['val_loss'] = val_loss
        RunnerHelper.save_net(self, self.seg_net,
                              performance=mean_iou, val_loss=val_loss)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss = {0}\n'.format(self.val_losses.info(),
                                       batch_time=self.batch_time))
        Log.info('Mean IOU: {}\n'.format(mean_iou))
        Log.info('Pixel ACC: {}\n'.format(
            self.seg_running_score.get_pixel_acc()))
        # NOTE: batch_time is shared with train(), so the train timing
        # window is reset here as well.
        self.batch_time.reset()
        self.val_losses.reset()
        self.seg_running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Feed a batch of predictions into the running score.

          Args:
              pred: network output; permuted here from (N, C, H, W) to
                  (N, H, W, C).
              metas: per-sample dicts providing 'border_wh', 'ori_target'
                  and 'ori_img_wh'.
        """
        pred = pred.permute(0, 2, 3, 1)
        for i in range(pred.size(0)):
            border_wh = metas[i]['border_wh']
            ori_target = metas[i]['ori_target']
            # Keep only the region inside the border, then resize the
            # logits to the original image size before taking the argmax.
            valid_logits = pred[i, :border_wh[1], :border_wh[0]].cpu().numpy()
            total_logits = cv2.resize(valid_logits,
                                      tuple(metas[i]['ori_img_wh']),
                                      interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            self.seg_running_score.update(labelmap[None], ori_target[None])
class FCNSegmentorTest(object):
    """
      Inference runner for a segmentation network.

      Predicts every batch from the test loader with the configured test
      mode (single/multi scale, with or without sliding-window crops) and
      writes a colorized visualization and a label image per sample.
    """

    def __init__(self, configer):
        """
          Create the helpers and load the network in eval mode.

          Args:
              configer: project configuration object; only its `get(...)`
                  accessor is used by this class.
        """
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.test_loader = TestDataLoader(configer)
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.seg_net = None

        self._init_model()

    def _init_model(self):
        """
          Build the network, load its weights and switch to eval mode.
        """
        self.seg_net = self.seg_model_manager.get_seg_model()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)
        self.seg_net.eval()

    def test(self, test_dir, out_dir):
        """
          Run inference on `test_dir` and write the results to `out_dir`.

          Writes 'vis/<filename>.png' (colorized overlay) and
          'label/<filename>.png' (label indices) for every sample.

          Raises:
              SystemExit: with code 1 if the configured test mode is not
                  one of 'ss_test', 'sscrop_test', 'ms_test', 'mscrop_test'.
        """
        # Modes that read their parameters from the config entry
        # ('test', <mode name>).
        param_modes = {
            'sscrop_test': self.sscrop_test,
            'ms_test': self.ms_test,
            'mscrop_test': self.mscrop_test,
        }
        for data_dict in self.test_loader.get_testloader(test_dir=test_dir):
            mode = self.configer.get('test', 'mode')
            if mode == 'ss_test':
                total_logits = self.ss_test(data_dict)
            elif mode in param_modes:
                total_logits = param_modes[mode](
                    data_dict, params_dict=self.configer.get('test', mode))
            else:
                Log.error('Invalid test mode:{}'.format(mode))
                # Same effect as exit(1) without relying on the `site`
                # module having injected the `exit` builtin.
                raise SystemExit(1)

            meta_list = DCHelper.tolist(data_dict['meta'])
            for i, meta in enumerate(meta_list):
                label_map = np.argmax(total_logits[i], axis=-1)
                label_img = np.array(label_map, dtype=np.uint8)
                ori_img_bgr = ImageHelper.read_image(meta['img_path'],
                                                     tool='cv2', mode='BGR')
                image_canvas = self.seg_parser.colorize(
                    label_img, image_canvas=ori_img_bgr)
                ImageHelper.save(image_canvas,
                                 save_path=os.path.join(
                                     out_dir,
                                     'vis/{}.png'.format(meta['filename'])))

                if self.configer.get('data.label_list',
                                     default=None) is not None:
                    label_img = self.__relabel(label_img)

                if self.configer.get('data.reduce_zero_label', default=False):
                    label_img = label_img + 1
                    label_img = label_img.astype(np.uint8)

                label_img = Image.fromarray(label_img, 'P')
                label_path = os.path.join(
                    out_dir, 'label/{}.png'.format(meta['filename']))
                Log.info('Label Path: {}'.format(label_path))
                ImageHelper.save(label_img, label_path)

    def ss_test(self, in_data_dict):
        """
          Single-scale test: predict once at scale 1.0.

          Returns:
              list with one logits array per sample.
        """
        data_dict = self.blob_helper.get_blob(in_data_dict, scale=1.0)
        return self._predict(data_dict)

    def ms_test(self, in_data_dict, params_dict):
        """
          Multi-scale test: sum full-image logits over every scale in
          `params_dict['scale_search']`, unflipped and flipped.
        """
        return self._multi_scale_logits(
            in_data_dict, params_dict['scale_search'], self._predict)

    def sscrop_test(self, in_data_dict, params_dict):
        """
          Single-scale test at scale 1.0 with sliding-window crops.

          Falls back to a full-image prediction when any image is smaller
          than `params_dict['crop_size']`.
        """
        data_dict = self.blob_helper.get_blob(in_data_dict, scale=1.0)
        return self._crop_or_predict(data_dict, params_dict)

    def mscrop_test(self, in_data_dict, params_dict):
        """
          Multi-scale test with sliding-window crops, unflipped and
          flipped, summed over `params_dict['scale_search']`.
        """
        return self._multi_scale_logits(
            in_data_dict, params_dict['scale_search'],
            lambda data_dict: self._crop_or_predict(data_dict, params_dict))

    def _multi_scale_logits(self, in_data_dict, scale_search, predict_fn):
        """
          Sum `predict_fn` logits over all scales, plain then flipped.
        """
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = [
            np.zeros((meta['ori_img_size'][1], meta['ori_img_size'][0],
                      num_classes), np.float32)
            for meta in DCHelper.tolist(in_data_dict['meta'])
        ]

        for scale in scale_search:
            data_dict = self.blob_helper.get_blob(in_data_dict, scale=scale)
            results = predict_fn(data_dict)
            for i in range(len(total_logits)):
                total_logits[i] += results[i]

        for scale in scale_search:
            data_dict = self.blob_helper.get_blob(in_data_dict, scale=scale,
                                                  flip=True)
            results = predict_fn(data_dict)
            for i in range(len(total_logits)):
                # Reverse the second axis to undo the flip before summing.
                total_logits[i] += results[i][:, ::-1]

        return total_logits

    def _crop_or_predict(self, data_dict, params_dict):
        """
          Use sliding-window crops unless an image is smaller than the crop.
        """
        crop_size = params_dict['crop_size']
        # crop_size[0] is compared with image dim 2 and crop_size[1] with
        # image dim 1.
        too_small = any(
            image.size()[2] < crop_size[0] or image.size()[1] < crop_size[1]
            for image in DCHelper.tolist(data_dict['img']))
        if too_small:
            return self._predict(data_dict)

        return self._crop_predict(data_dict, crop_size,
                                  params_dict['crop_stride_ratio'])

    def _crop_predict(self, data_dict, crop_size, crop_stride_ratio):
        """
          Predict with overlapping crops and average the overlaps.

          Args:
              data_dict: blob holding 'img' and 'meta'.
              crop_size: crop extent; index 0 is applied along image dim 2
                  and index 1 along image dim 1.
              crop_stride_ratio: stride as a fraction of the crop length.

          Returns:
              list with one logits array per sample, resized to the
              sample's 'ori_img_size'.
        """
        num_classes = self.configer.get('data', 'num_classes')
        split_batch = list()
        height_starts_list = list()
        width_starts_list = list()
        hw_list = list()
        for image in DCHelper.tolist(data_dict['img']):
            height, width = image.size()[1:]
            hw_list.append([height, width])
            np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            height_starts = self._decide_intersection(height, crop_size[1],
                                                      crop_stride_ratio)
            width_starts = self._decide_intersection(width, crop_size[0],
                                                     crop_stride_ratio)
            # Crops are collected row-major; the write-back loop below
            # walks them in the same order.
            split_crops = [
                np_image[top:top + crop_size[1], left:left + crop_size[0]]
                for top in height_starts
                for left in width_starts
            ]
            height_starts_list.append(height_starts)
            width_starts_list.append(width_starts)
            # (n, crop_image_size, crop_image_size, 3)
            split_crops = np.stack(split_crops, axis=0)
            inputs = torch.from_numpy(split_crops).permute(
                0, 3, 1, 2).to(self.device)
            split_batch.extend(list(inputs))

        out_list = list()
        with torch.no_grad():
            results = self.seg_net(
                dict(img=DCHelper.todc(
                    split_batch, stack=True, samples_per_gpu=True)))
            # NOTE(review): out_list is indexed per image below, so this
            # presumably expects one result per image holding all of its
            # crops - confirm for batches with more than one image.
            for res in results:
                out_list.append(res['out'].permute(0, 2, 3, 1).cpu().numpy())

        total_logits = [
            np.zeros((hw[0], hw[1], num_classes), np.float32)
            for hw in hw_list
        ]
        count_predictions = [
            np.zeros((hw[0], hw[1], num_classes), np.float32)
            for hw in hw_list
        ]
        for i in range(len(height_starts_list)):
            index = 0
            for top in height_starts_list[i]:
                for left in width_starts_list[i]:
                    total_logits[i][top:top + crop_size[1],
                                    left:left + crop_size[0]] += out_list[i][index]
                    count_predictions[i][top:top + crop_size[1],
                                         left:left + crop_size[0]] += 1
                    index += 1

        # Average the positions that were covered by several crops.
        for i in range(len(total_logits)):
            total_logits[i] /= count_predictions[i]

        for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
            total_logits[i] = cv2.resize(
                total_logits[i][:meta['border_hw'][0], :meta['border_hw'][1]],
                (meta['ori_img_size'][0], meta['ori_img_size'][1]),
                interpolation=cv2.INTER_CUBIC)

        return total_logits

    def _decide_intersection(self, total_length, crop_length,
                             crop_stride_ratio):
        """
          Compute the crop start offsets along one axis.

          Starts are spaced `int(crop_length * crop_stride_ratio)` apart
          (at least 1); a final start at `total_length - crop_length` is
          appended when the regular grid does not reach the end.

          Returns:
              list of int offsets; `[0]` when the axis is not longer than
              one crop.
        """
        # A ratio that truncates to 0 would otherwise divide by zero.
        stride = max(1, int(crop_length * crop_stride_ratio))
        # An axis shorter than the crop used to index an empty list.
        if total_length <= crop_length:
            return [0]

        cropped_starting = list(
            range(0, total_length - crop_length + 1, stride))
        if total_length - cropped_starting[-1] > crop_length:
            # must cover the total image
            cropped_starting.append(total_length - crop_length)

        return cropped_starting

    def _predict(self, data_dict):
        """
          Predict full images and resize the logits to the original size.

          Returns:
              list with one logits array per sample, cropped to the
              sample's 'border_hw' and resized to its 'ori_img_size'.
        """
        with torch.no_grad():
            total_logits = list()
            results = self.seg_net(data_dict)
            for res in results:
                total_logits.append(res['out'].squeeze(0).permute(
                    1, 2, 0).cpu().numpy())

            for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
                total_logits[i] = cv2.resize(
                    total_logits[i][:meta['border_hw'][0],
                                    :meta['border_hw'][1]],
                    (meta['ori_img_size'][0], meta['ori_img_size'][1]),
                    interpolation=cv2.INTER_CUBIC)

        return total_logits

    def __relabel(self, label_map):
        """
          Map predicted class indices to the configured label values.

          Returns:
              uint8 array with the shape of `label_map`, where class `i`
              is replaced by entry `i` of the configured label list.
        """
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        label_list = self.configer.get('data', 'label_list')
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = label_list[i]

        return label_dst