def get_pose_estimation_prediction(cfg, model, image, vis_thre, transforms):
    """Run multi-scale pose inference on one image (demo mode, no labels).

    Args:
        cfg: experiment config. Reads DATASET.INPUT_SIZE, DATASET.NUM_JOINTS,
            TEST.SCALE_FACTOR and TEST.FLIP_TEST.
        model: pose network; it is only invoked through
            ``get_multi_stage_outputs``.
        image: input image. Its first two dimensions size the placeholder
            mask (presumably an HxWxC numpy array -- confirm with caller).
        vis_thre: a person is kept only if its score is strictly greater
            than this threshold.
        transforms: callable that turns the resized image into a tensor.

    Returns:
        list: the per-person entries produced by ``get_final_preds`` whose
        score exceeds ``vis_thre``. Empty when nobody is found.
    """
    input_size = cfg.DATASET.INPUT_SIZE
    # size at scale 1.0
    base_size, center, scale = get_multi_scale_size(image, input_size, 1.0, 1.0)
    parser = HeatmapRegParser(cfg)

    with torch.no_grad():
        heatmap_fuse = 0
        final_heatmaps = None
        final_kpts = None
        for s in sorted(cfg.TEST.SCALE_FACTOR, reverse=True):
            # joints and mask are placeholders: they are not used in demo mode
            joints = np.zeros((0, cfg.DATASET.NUM_JOINTS, 3))
            mask = np.zeros((image.shape[0], image.shape[1]))
            image_resized, _, _, center, scale = resize_align_multi_scale(
                image, joints, mask, input_size, s, 1.0
            )
            image_resized = transforms(image_resized).unsqueeze(0).cuda()
            _, heatmaps, kpts = get_multi_stage_outputs(
                cfg, model, image_resized, cfg.TEST.FLIP_TEST
            )
            # NOTE(review): this call passes 5 arguments and unpacks
            # (heatmaps, kpts); an aggregate_results taking 6 parameters
            # (..., posemap, scale) is also defined in this file -- confirm
            # which implementation is actually in scope here.
            final_heatmaps, final_kpts = aggregate_results(
                cfg, final_heatmaps, final_kpts, heatmaps, kpts
            )

        # Average every aggregated heatmap after resizing it to base_size.
        for heatmap in final_heatmaps:
            heatmap_fuse += up_interpolate(
                heatmap, size=(base_size[1], base_size[0]), mode='bilinear'
            )
        heatmap_fuse = heatmap_fuse / float(len(final_heatmaps))

        # Group using only the regressed keypoints (no heatmap refinement).
        grouped, scores = parser.parse(
            final_heatmaps, final_kpts, heatmap_fuse[0], use_heatmap=False
        )
        if len(scores) == 0:
            return []

        results = get_final_preds(
            grouped, center, scale,
            [heatmap_fuse.size(-1), heatmap_fuse.size(-2)]
        )
        return [results[i] for i, score in enumerate(scores) if score > vis_thre]
def aggregate_results(cfg, heatmap_sum, poses, heatmap, posemap, scale):
    """Fold one test scale into the running heatmap sum and pose proposals.

    The heatmap is resized by ``(INPUT_SIZE / OUTPUT_SIZE) / scale`` and
    added to ``heatmap_sum``. Peaks of the last heatmap channel (the center
    map) then select which per-pixel pose regressions become proposals.

    Args:
        cfg: config. Reads DATASET.INPUT_SIZE and DATASET.OUTPUT_SIZE, and
            divides them as scalars.
            NOTE(review): get_multi_stage_outputs in this file calls
            len(cfg.DATASET.OUTPUT_SIZE); if OUTPUT_SIZE is a list there,
            this division fails -- confirm which config this is used with.
        heatmap_sum: running sum of resized heatmaps (0 before the first
            scale, after which it becomes a tensor).
        poses (list): receives this scale's proposals; it is appended to
            in place.
        heatmap (Tensor): heatmap at this scale. Batch entry 0 is used, and
            its last channel is the center map.
        posemap (Tensor): pose regression at this scale. Batch entry 0 is
            moved channels-last and viewed as (rows * cols, -1, 2).
        scale (float): test scale factor of this pass.

    Returns:
        tuple: ``(heatmap_sum, poses)``. Each appended proposal is the
        rescaled pose concatenated with the center score along dim 2.
    """
    reverse_scale = (cfg.DATASET.INPUT_SIZE * 1.0 / cfg.DATASET.OUTPUT_SIZE) / scale

    # the two trailing (spatial) dimensions of the map
    rows, cols = heatmap[0].size(-2), heatmap[0].size(-1)
    target_size = (int(reverse_scale * rows), int(reverse_scale * cols))
    heatmap_sum += up_interpolate(heatmap, size=target_size, mode='bilinear')

    pose_ind, ctr_score = get_maximum_from_heatmap(cfg, heatmap[0, -1:])

    flat_posemap = posemap[0].permute(1, 2, 0).view(rows * cols, -1, 2)
    pose = reverse_scale * flat_posemap[pose_ind]

    # Broadcast each person's center score to every joint as a 3rd column.
    per_joint_score = ctr_score[:, None].expand(-1, pose.shape[-2])[:, :, None]
    poses.append(torch.cat([pose, per_joint_score], dim=2))

    return heatmap_sum, poses
def _finalize_predictions(cfg, parser, final_heatmaps, final_kpts,
                          heatmap_fuse, center, scale, use_heatmap):
    """Group one image's keypoints and map them back to image coordinates.

    Args:
        use_heatmap: forwarded to ``parser.parse``. False groups regressed
            keypoints only; True lets the parser also use the heatmaps.

    Returns:
        tuple: ``(preds, scores)``. Both are fresh empty lists when the
        parser returns no scores. Otherwise scores are optionally rescored
        when cfg.RESCORE.USE is set.
    """
    grouped, scores = parser.parse(
        final_heatmaps, final_kpts, heatmap_fuse[0], use_heatmap=use_heatmap
    )
    if len(scores) == 0:
        return [], []
    preds = get_final_preds(
        grouped, center, scale, [heatmap_fuse.size(-1), heatmap_fuse.size(-2)]
    )
    if cfg.RESCORE.USE:
        scores = rescore_valid(cfg, preds, scores)
    return preds, scores


def main():
    """Evaluate a pose model on the test split and log the metrics.

    Predictions are produced twice per image: once from the regressed
    keypoints alone ("regression") and once with heatmap refinement
    ("final"). Each set is then scored with ``test_dataset.evaluate``.
    """
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, _ = create_logger(cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # NOTE(review): eval() executes text built from the config file; a
    # crafted MODEL.NAME could run arbitrary code. If model names never
    # contain dots, getattr(models, cfg.MODEL.NAME).get_pose_net is safer.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    # An explicit TEST.MODEL_FILE wins; otherwise use the best checkpoint
    # from the output directory. Both paths load strictly (the default).
    model_state_file = cfg.TEST.MODEL_FILE or os.path.join(
        final_output_dir, 'model_best.pth.tar'
    )
    logger.info('=> loading model from {}'.format(model_state_file))
    model.load_state_dict(torch.load(model_state_file), strict=True)

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    data_loader, test_dataset = make_test_dataloader(cfg)

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
    else:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )

    parser = HeatmapRegParser(cfg)
    input_size = cfg.DATASET.INPUT_SIZE
    scale_factors = sorted(cfg.TEST.SCALE_FACTOR, reverse=True)

    # regression-only keypoints
    all_reg_preds = []
    all_reg_scores = []
    # regressed keypoints refined with the predicted heatmaps
    all_preds = []
    all_scores = []

    pbar = tqdm(total=len(test_dataset)) if cfg.TEST.LOG_PROGRESS else None
    for images, joints, masks, _areas in data_loader:
        assert 1 == images.size(0), 'Test batch size should be 1'

        image = images[0].cpu().numpy()
        joints = joints[0].cpu().numpy()
        mask = masks[0].cpu().numpy()

        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, input_size, 1.0, 1.0
        )

        with torch.no_grad():
            heatmap_fuse = 0
            final_heatmaps = None
            final_kpts = None
            for s in scale_factors:
                image_resized, _, _, center, scale = resize_align_multi_scale(
                    image, joints, mask, input_size, s, 1.0
                )
                image_resized = transforms(image_resized).unsqueeze(0).cuda()
                _, heatmaps, kpts = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST
                )
                final_heatmaps, final_kpts = aggregate_results(
                    cfg, final_heatmaps, final_kpts, heatmaps, kpts
                )

            for heatmap in final_heatmaps:
                heatmap_fuse += up_interpolate(
                    heatmap,
                    size=(base_size[1], base_size[0]),
                    mode='bilinear'
                )
            heatmap_fuse = heatmap_fuse / float(len(final_heatmaps))

            # Regression-only pass first, then the heatmap-refined pass.
            for use_heatmap, preds_store, scores_store in (
                (False, all_reg_preds, all_reg_scores),
                (True, all_preds, all_scores),
            ):
                preds, scores = _finalize_predictions(
                    cfg, parser, final_heatmaps, final_kpts,
                    heatmap_fuse, center, scale, use_heatmap
                )
                preds_store.append(preds)
                scores_store.append(scores)

        if cfg.TEST.LOG_PROGRESS:
            pbar.update()

    if cfg.TEST.LOG_PROGRESS:
        pbar.close()

    for name, preds, scores in zip(
        ['regression', 'final'],
        [all_reg_preds, all_preds],
        [all_reg_scores, all_scores],
    ):
        print('Testing '+name)
        name_values, _ = test_dataset.evaluate(
            cfg, preds, scores, final_output_dir, name
        )
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(logger, name_value, cfg.MODEL.NAME)
        else:
            _print_name_value(logger, name_values, cfg.MODEL.NAME)
def _forward_model(cfg, model, image):
    """Run ``model`` on ``image`` and return ``(heatmap_outputs, offsets)``.

    When cfg.LOSS.HEATMAP_MIDDLE_LOSS is set the model returns a third
    value, which is discarded.
    """
    if cfg.LOSS.HEATMAP_MIDDLE_LOSS:
        heat_outputs, offsets, _ = model(image)
    else:
        heat_outputs, offsets = model(image)
    return heat_outputs, offsets


def _lookup_flip_index(dataset, with_center=False):
    """Return the FLIP_CONFIG channel permutation for ``dataset``.

    Raises:
        ValueError: if the dataset name contains neither 'coco' nor
            'crowd_pose'.
    """
    if 'coco' in dataset:
        key = 'COCO'
    elif 'crowd_pose' in dataset:
        key = 'CROWDPOSE'
    else:
        raise ValueError(
            'Please implement flip_index for new dataset: %s.' % dataset)
    if with_center:
        key += '_WITH_CENTER'
    return FLIP_CONFIG[key]


def _reg_kpts_to_map(reg_kpts, num_joints, h, w):
    """Reshape per-pixel regressed keypoints into a (1, 2*num_joints, h, w) map."""
    return reg_kpts.contiguous().view(h * w, 2 * num_joints).permute(
        1, 0).contiguous().view(1, -1, h, w)


def _average_stage_heatmaps(cfg, stage_outputs, channel_index=None):
    """Average the stage outputs enabled for heatmap testing.

    Stage ``i`` contributes only when both cfg.LOSS.WITH_HEATMAPS_LOSS[i]
    and cfg.TEST.WITH_HEATMAPS[i] are truthy. Later contributions are added
    into the leading ``c`` channels of the first one. After the loop the
    leading ``c`` channels (``c`` taken from the last stage seen) are
    divided by the number of contributions.

    Args:
        channel_index: optional callable mapping a stage's channel count to
            an index used to reorder that stage's channels before adding.

    Returns:
        The averaged tensor, or None when no stage is enabled.
    """
    heatmaps_avg = 0
    num_heatmaps = 0
    c = None
    for i, output in enumerate(stage_outputs):
        c = output.shape[1]
        if not (cfg.LOSS.WITH_HEATMAPS_LOSS[i] and cfg.TEST.WITH_HEATMAPS[i]):
            continue
        num_heatmaps += 1
        if channel_index is not None:
            output = output[:, channel_index(c), :, :]
        if num_heatmaps > 1:
            heatmaps_avg[:, :c] += output
        else:
            # 0 + tensor allocates a new tensor, so stage outputs are never
            # modified by the in-place ops that follow.
            heatmaps_avg += output
    if num_heatmaps == 0:
        return None
    heatmaps_avg[:, :c] /= num_heatmaps
    return heatmaps_avg


def get_multi_stage_outputs(cfg, model, image, with_flip=False):
    """Run the model (optionally with flip test) and collect its outputs.

    Args:
        cfg: config. Reads DATASET.NUM_JOINTS, DATASET.DATASET,
            DATASET.OUTPUT_SIZE, LOSS.HEATMAP_MIDDLE_LOSS,
            LOSS.WITH_HEATMAPS_LOSS and TEST.WITH_HEATMAPS.
        model: network returning (heatmap outputs, offsets[, extra]).
        image (Tensor): input batch. The flip test flips dim 3.
        with_flip (bool): also run horizontally flipped inputs and add their
            flipped-back results.

    Returns:
        tuple: ``(outputs, heatmaps, reg_kpts_list)``:
            - ``outputs``: per-stage outputs, with flipped-back ones
              appended when ``with_flip``.
            - ``heatmaps``: up to two averaged heatmaps (plain and flipped).
            - ``reg_kpts_list``: one or two regressed-keypoint maps.

    Raises:
        ValueError: with ``with_flip`` and an unsupported dataset name.
    """
    num_joints = cfg.DATASET.NUM_JOINTS - 1
    dataset = cfg.DATASET.DATASET
    heatmaps = []
    reg_kpts_list = []

    all_outputs, all_offsets = _forward_model(cfg, model, image)
    outputs = [get_one_stage_outputs(out) for out in all_outputs]

    offset = all_offsets[0][-1]
    h, w = offset.shape[2:]
    reg_kpts_list.append(
        _reg_kpts_to_map(get_reg_kpts(offset[0], num_joints), num_joints, h, w))

    outputs_flip = None
    if with_flip:
        # Also validates the dataset name before the flipped forward pass.
        flip_index_offset = _lookup_flip_index(dataset)

        flipped_image = torch.flip(image, [3])
        # Shift the flipped image left by 3 (first pass) or by 1 (2x pass,
        # below). The shifted-out columns on the right stay zero.
        new_image = torch.zeros_like(image)
        new_image[:, :, :, :-3] = flipped_image[:, :, :, 3:]
        all_outputs_flip, all_offsets_flip = _forward_model(cfg, model, new_image)
        outputs_flip = [get_one_stage_outputs(all_outputs_flip[0])]

        if len(cfg.DATASET.OUTPUT_SIZE) > 1:
            new_image_2x = torch.zeros_like(image)
            new_image_2x[:, :, :, :-1] = flipped_image[:, :, :, 1:]
            all_outputs_flip_2x, _ = _forward_model(cfg, model, new_image_2x)
            outputs_flip.append(get_one_stage_outputs(all_outputs_flip_2x[1]))

        # Offsets come from the first flipped pass. Permute joints, mirror
        # the x coordinate, then flip the map back.
        offset_flip = all_offsets_flip[0][-1]
        reg_kpts_flip = get_reg_kpts(offset_flip[0], num_joints)
        reg_kpts_flip = reg_kpts_flip[:, flip_index_offset, :]
        reg_kpts_flip[:, :, 0] = w - reg_kpts_flip[:, :, 0] - 1
        reg_kpts_list.append(
            torch.flip(_reg_kpts_to_map(reg_kpts_flip, num_joints, h, w), [3]))

    # Non-final stages are upsampled to the final stage's resolution.
    last = len(outputs) - 1
    aligned = [
        up_interpolate(out, size=(outputs[-1].size(2), outputs[-1].size(3)))
        if last > 0 and i != last else out
        for i, out in enumerate(outputs)
    ]
    heatmaps_avg = _average_stage_heatmaps(cfg, aligned)
    if heatmaps_avg is not None:
        heatmaps.append(heatmaps_avg)

    if with_flip:
        last_flip = len(outputs_flip) - 1
        flipped_back = []
        for i, out in enumerate(outputs_flip):
            if last_flip > 0 and i != last_flip:
                out = up_interpolate(
                    out,
                    size=(outputs_flip[-1].size(2), outputs_flip[-1].size(3)))
            flipped_back.append(torch.flip(out, [3]))
        outputs.extend(flipped_back)

        # The center-channel permutation is used only for stages that carry
        # num_joints + 1 channels.
        heatmaps_avg = _average_stage_heatmaps(
            cfg, flipped_back,
            channel_index=lambda c: _lookup_flip_index(
                dataset, with_center=(c == num_joints + 1)),
        )
        if heatmaps_avg is not None:
            heatmaps.append(heatmaps_avg)

    return outputs, heatmaps, reg_kpts_list