# Example #1
# 0
def predict_on_frames(args, device=None):
    """Run CMR on every ``.png`` frame in ``args.in_folder`` and pickle the predictions.

    Frames are processed one at a time in sorted filename order. The per-frame
    outputs are concatenated along axis 0 and written to
    ``<out_folder>/cmr_results.pkl`` as a dict with keys ``'verts'``,
    ``'verts_smpl'`` and ``'pred_cam'``.

    Args:
        args: Namespace providing ``in_folder``, ``out_folder`` and ``checkpoint``.
            If ``out_folder == 'dataset'``, the output folder is derived from
            ``in_folder`` by replacing ``'cropped_frames'`` with ``'cmr_results'``.
        device: torch device for the model and the inputs. Defaults to CUDA when
            available, otherwise CPU.

    Raises:
        ValueError: if ``args.in_folder`` contains no ``.png`` files.
    """
    if device is None:
        # Resolve the device here instead of reading an outer ``device`` name,
        # which is fragile (and becomes an UnboundLocalError if ``device`` is
        # ever assigned later in this function body).
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Collect the inputs before loading the model so an empty folder fails fast
    # with a clear message instead of an opaque np.concatenate error.
    image_paths = [os.path.join(args.in_folder, f)
                   for f in sorted(os.listdir(args.in_folder))
                   if f.endswith('.png')]
    if not image_paths:
        raise ValueError('No png images found in {}'.format(args.in_folder))

    # Load model
    mesh = Mesh(device=device)
    # Our pretrained networks have 5 residual blocks with 256 channels.
    # You might want to change this if you use a different architecture.
    model = CMR(mesh, 5, 256, pretrained_checkpoint=args.checkpoint, device=device)
    model.to(device)
    model.eval()

    print('Predicting on all png images in {}'.format(args.in_folder))

    all_vertices = []
    all_vertices_smpl = []
    all_cams = []

    for image_path in image_paths:
        print("Image: ", image_path)
        # Preprocess input image and generate predictions
        _, norm_img = process_image(image_path, input_res=cfg.INPUT_RES)
        norm_img = norm_img.to(device)
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, pred_camera, _, _ = model(norm_img)

        # Convert to NumPy immediately so device tensors are not kept alive
        # across iterations.
        all_vertices.append(pred_vertices.cpu().numpy())
        all_vertices_smpl.append(pred_vertices_smpl.cpu().numpy())
        all_cams.append(pred_camera.cpu().numpy())

    # Save predictions as pkl
    pred_dict = {'verts': np.concatenate(all_vertices, axis=0),
                 'verts_smpl': np.concatenate(all_vertices_smpl, axis=0),
                 'pred_cam': np.concatenate(all_cams, axis=0)}

    if args.out_folder == 'dataset':
        out_folder = args.in_folder.replace('cropped_frames', 'cmr_results')
    else:
        out_folder = args.out_folder
    out_path = os.path.join(out_folder, 'cmr_results.pkl')
    print('Saving to', out_path)
    # exist_ok=True: re-running on the same folder must not crash with
    # FileExistsError before the results are written.
    os.makedirs(out_folder, exist_ok=True)
    for value in pred_dict.values():
        print(value.shape)
    with open(out_path, 'wb') as f:
        pickle.dump(pred_dict, f)
    # NOTE(review): From here to the end of this view, the statements appear to
    # be a fragment of a different script (its own argument parsing, device
    # setup, model loading and evaluation-dataset setup). The fragment sits,
    # by indentation, inside the body of predict_on_frames; its original
    # enclosing header is not visible.
    # TODO: confirm, then move it into its own function / `__main__` guard.
    # `parser` is not defined in this function, so the next line raises
    # NameError unless a module-level `parser` exists.
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument('--num_workers', default=4, type=int, help='Number of processes for data loading')
    parser.add_argument('--path_correction', action='store_true')
    # Rebinds the `args` parameter of the enclosing function.
    args = parser.parse_args()

    # Device
    # Order GPUs by PCI bus id and restrict visibility to the id(s) given by --gpu.
    # NOTE(review): CUDA only honours CUDA_VISIBLE_DEVICES if CUDA has not
    # already been initialised in this process — confirm nothing above does so.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # NOTE(review): this assignment makes `device` a local name for the whole
    # enclosing function. Unless `device` is also a parameter of that function,
    # any read of `device` earlier in the function raises UnboundLocalError.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load model
    # NOTE(review): this repeats the model construction done earlier in the
    # enclosing function.
    mesh = Mesh(device=device)
    # Our pretrained networks have 5 residual blocks with 256 channels.
    # You might want to change this if you use a different architecture.
    model = CMR(mesh, 5, 256, pretrained_checkpoint=args.checkpoint, device=device)
    model.to(device)
    model.eval()

    # Setup evaluation dataset
    # NOTE(review): hard-coded, machine-specific absolute path; consider making
    # it a command-line argument.
    dataset_path = '/scratch/as2562/datasets/sports_videos_smpl/final_dataset'
    # NOTE(review): uses `config.INPUT_RES`, while the earlier part of this
    # function uses `cfg.INPUT_RES` — confirm both names are in scope.
    dataset = SportsVideosEvalDataset(dataset_path, img_wh=config.INPUT_RES,
                                      path_correction=args.path_correction)
    print("Eval examples found:", len(dataset))

    # Metrics
    # Names of the metrics to evaluate; presumably consumed by evaluation code
    # beyond this view.
    metrics = ['pve', 'pve_scale_corrected', 'pve_pa', 'pve-t', 'pve-t_scale_corrected',
               'silhouette_iou', 'j2d_l2e']

    # Output directory for evaluation results (hard-coded, machine-specific path).
    # NOTE(review): os.makedirs(save_path, exist_ok=True) would avoid the
    # check-then-create race between the exists() test and makedirs().
    save_path = '/data/cvfs/as2562/GraphCMR/evaluations/sports_videos_final_dataset'
    if not os.path.exists(save_path):
        os.makedirs(save_path)