Example #1
def main_VIBE(video_dict_or_file, model):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    image_folder, num_frames, img_shape = video_to_images(video_dict_or_file,
                                                          return_info=True)

    print(f'Input video number of frames {num_frames}')

    # ========= Run tracking ========= #
    bbox_scale = 1.1
    # run multi object tracker
    mot = MPT(device=device, output_format='dict', yolo_img_size=256)
    tracking_results = mot(image_folder)

    # remove tracklets if num_frames is less than MIN_NUM_FRAMES
    # keep only the tracklet with the largest bounding box height (bbox[:, 3])
    maximum = 0
    for person_id in list(tracking_results.keys()):
        if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
            del tracking_results[person_id]
        elif tracking_results[person_id]['bbox'][:, 3].max() > maximum:
            maximum = tracking_results[person_id]['bbox'][:, 3].max()
            bigbbox = person_id

    try:
        tracking_results = tracking_results[bigbbox]
    except UnboundLocalError:
        return None

    # ========= Define VIBE model ========= #
    # model = VIBE_Demo(
    #     seqlen=16,
    #     n_layers=2,
    #     hidden_size=1024,
    #     add_linear=True,
    #     use_residual=True,
    # ).to(device)

    # ========= Load pretrained weights ========= #
    # pretrained_file = download_ckpt(use_3dpw=False)
    # ckpt = torch.load(pretrained_file, map_location=device)
    # ckpt = ckpt['gen_state_dict']
    # model.load_state_dict(ckpt, strict=False)
    model.eval()
    # print(f'Loaded pretrained weights from \"{pretrained_file}\"')

    # ========= Run VIBE on each person ========= #
    joints2d = None

    bboxes = tracking_results['bbox']

    frames = tracking_results['frames']

    dataset = Inference(
        image_folder=image_folder,
        frames=frames,
        bboxes=bboxes,
        joints2d=joints2d,
        scale=bbox_scale,
    )

    dataloader = DataLoader(dataset)

    with torch.no_grad():

        pred_joints3d = []

        for batch in dataloader:
            batch = batch.unsqueeze(0)
            batch = batch.to(device)

            batch_size, seqlen = batch.shape[:2]
            output = model(batch)[-1]

            pred_joints3d.append(output['kp_3d'].reshape(
                batch_size * seqlen, -1, 3))

        pred_joints3d = torch.cat(pred_joints3d, dim=0)

        del batch

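        # kp_3d from the VIBE demo model follows SPIN's 49-joint layout; the
        # first 25 entries are assumed here to be the OpenPose-ordered body
        # joints, and transpose(1, 2) yields a (num_frames, 3, 25) tensor.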
        vibe_results = pred_joints3d[:, :25, :].transpose(1, 2)

    del model

    shutil.rmtree(image_folder)

    return vibe_results
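
# Minimal driver sketch for main_VIBE: the caller is assumed to build and load
# the VIBE model itself, mirroring the commented-out block above and Example #2
# below; 'sample_video.mp4' is a placeholder path.
if __name__ == '__main__':
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = VIBE_Demo(seqlen=16, n_layers=2, hidden_size=1024,
                      add_linear=True, use_residual=True).to(device)
    ckpt = torch.load(download_ckpt(use_3dpw=False), map_location=device)
    model.load_state_dict(ckpt['gen_state_dict'], strict=False)
    joints3d_seq = main_VIBE('sample_video.mp4', model)  # (num_frames, 3, 25) or None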
Example #2
def main(args):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    video_file = args.vid_file

    # ========= [Optional] download the youtube video ========= #
    if video_file.startswith('https://www.youtube.com'):
        print(f'Downloading YouTube video \"{video_file}\"')
        video_file = download_youtube_clip(video_file, '/tmp')

        if video_file is None:
            exit('YouTube URL is not valid!')

        print(f'YouTube Video has been downloaded to {video_file}...')

    if not os.path.isfile(video_file):
        exit(f'Input video \"{video_file}\" does not exist!')

    output_path = os.path.join(
        args.output_folder,
        os.path.basename(video_file).replace('.mp4', ''))
    os.makedirs(output_path, exist_ok=True)

    image_folder, num_frames, img_shape = video_to_images(video_file,
                                                          return_info=True)

    print(f'Input video number of frames {num_frames}')
    orig_height, orig_width = img_shape[:2]

    total_time = time.time()

    # ========= Run tracking ========= #
    bbox_scale = 1.1
    if args.tracking_method == 'pose':
        if not os.path.isabs(video_file):
            video_file = os.path.join(os.getcwd(), video_file)
        tracking_results = run_posetracker(video_file,
                                           staf_folder=args.staf_dir,
                                           display=args.display)
    else:
        # run multi object tracker
        mot = MPT(
            device=device,
            batch_size=args.tracker_batch_size,
            display=args.display,
            detector_type=args.detector,
            output_format='dict',
            yolo_img_size=args.yolo_img_size,
        )
        tracking_results = mot(image_folder)

    # remove tracklets if num_frames is less than MIN_NUM_FRAMES
    for person_id in list(tracking_results.keys()):
        if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
            del tracking_results[person_id]

    # ========= Define VIBE model ========= #
    model = VIBE_Demo(
        seqlen=16,
        n_layers=2,
        hidden_size=1024,
        add_linear=True,
        use_residual=True,
    ).to(device)

    # ========= Load pretrained weights ========= #
    pretrained_file = download_ckpt(use_3dpw=False)
    ckpt = torch.load(pretrained_file)
    print(f'Performance of pretrained model on 3DPW: {ckpt["performance"]}')
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    print(f'Loaded pretrained weights from \"{pretrained_file}\"')

    # ========= Run VIBE on each person ========= #
    print(f'Running VIBE on each tracklet...')
    vibe_time = time.time()
    vibe_results = {}
    for person_id in tqdm(list(tracking_results.keys())):
        bboxes = joints2d = None

        if args.tracking_method == 'bbox':
            bboxes = tracking_results[person_id]['bbox']
        elif args.tracking_method == 'pose':
            joints2d = tracking_results[person_id]['joints2d']

        frames = tracking_results[person_id]['frames']

        dataset = Inference(
            image_folder=image_folder,
            frames=frames,
            bboxes=bboxes,
            joints2d=joints2d,
            scale=bbox_scale,
        )

        bboxes = dataset.bboxes
        frames = dataset.frames
        has_keypoints = joints2d is not None

        dataloader = DataLoader(dataset,
                                batch_size=args.vibe_batch_size,
                                num_workers=16)

        with torch.no_grad():

            pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

            for batch in dataloader:
                if has_keypoints:
                    batch, nj2d = batch
                    norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

                batch = batch.unsqueeze(0)
                batch = batch.to(device)

                batch_size, seqlen = batch.shape[:2]
                output = model(batch)[-1]

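                # theta packs the per-frame SMPL parameters predicted by VIBE:
                # [:3] weak-perspective camera (s, tx, ty), [3:75] body pose as
                # 24 x 3 axis-angle values, [75:] the 10 shape betas.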
                pred_cam.append(output['theta'][:, :, :3].reshape(
                    batch_size * seqlen, -1))
                pred_verts.append(output['verts'].reshape(
                    batch_size * seqlen, -1, 3))
                pred_pose.append(output['theta'][:, :, 3:75].reshape(
                    batch_size * seqlen, -1))
                pred_betas.append(output['theta'][:, :, 75:].reshape(
                    batch_size * seqlen, -1))
                pred_joints3d.append(output['kp_3d'].reshape(
                    batch_size * seqlen, -1, 3))

            pred_cam = torch.cat(pred_cam, dim=0)
            pred_verts = torch.cat(pred_verts, dim=0)
            pred_pose = torch.cat(pred_pose, dim=0)
            pred_betas = torch.cat(pred_betas, dim=0)
            pred_joints3d = torch.cat(pred_joints3d, dim=0)

            del batch

        # ========= [Optional] run Temporal SMPLify to refine the results ========= #
        if args.run_smplify and args.tracking_method == 'pose':
            norm_joints2d = np.concatenate(norm_joints2d, axis=0)
            norm_joints2d = convert_kps(norm_joints2d, src='staf', dst='spin')
            norm_joints2d = torch.from_numpy(norm_joints2d).float().to(device)

            # Run Temporal SMPLify
            update, new_opt_vertices, new_opt_cam, new_opt_pose, new_opt_betas, \
            new_opt_joints3d, new_opt_joint_loss, opt_joint_loss = smplify_runner(
                pred_rotmat=pred_pose,
                pred_betas=pred_betas,
                pred_cam=pred_cam,
                j2d=norm_joints2d,
                device=device,
                batch_size=norm_joints2d.shape[0],
                pose2aa=False,
            )

            # update the parameters after refinement
            print(
                f'Update ratio after Temporal SMPLify: {update.sum()} / {norm_joints2d.shape[0]}'
            )
            pred_verts = pred_verts.cpu()
            pred_cam = pred_cam.cpu()
            pred_pose = pred_pose.cpu()
            pred_betas = pred_betas.cpu()
            pred_joints3d = pred_joints3d.cpu()
            pred_verts[update] = new_opt_vertices[update]
            pred_cam[update] = new_opt_cam[update]
            pred_pose[update] = new_opt_pose[update]
            pred_betas[update] = new_opt_betas[update]
            pred_joints3d[update] = new_opt_joints3d[update]

        elif args.run_smplify and args.tracking_method == 'bbox':
            print(
                '[WARNING] You need to enable pose tracking to run Temporal SMPLify algorithm!'
            )
            print('[WARNING] Continuing without running Temporal SMPLify!..')

        # ========= Save results to a pickle file ========= #
        pred_cam = pred_cam.cpu().numpy()
        pred_verts = pred_verts.cpu().numpy()
        pred_pose = pred_pose.cpu().numpy()
        pred_betas = pred_betas.cpu().numpy()
        pred_joints3d = pred_joints3d.cpu().numpy()

        # Runs 1 Euro Filter to smooth out the results
        if args.smooth:
            min_cutoff = args.smooth_min_cutoff  # 0.004
            beta = args.smooth_beta  # 1.5
            print(
                f'Running smoothing on person {person_id}, min_cutoff: {min_cutoff}, beta: {beta}'
            )
            pred_verts, pred_pose, pred_joints3d = smooth_pose(
                pred_pose, pred_betas, min_cutoff=min_cutoff, beta=beta)

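        # Map the weak-perspective camera predicted in crop space back to the
        # original image coordinates using the per-frame bounding boxes.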
        orig_cam = convert_crop_cam_to_orig_img(cam=pred_cam,
                                                bbox=bboxes,
                                                img_width=orig_width,
                                                img_height=orig_height)

        output_dict = {
            'pred_cam': pred_cam,
            'orig_cam': orig_cam,
            'verts': pred_verts,
            'pose': pred_pose,
            'betas': pred_betas,
            'joints3d': pred_joints3d,
            'joints2d': joints2d,
            'bboxes': bboxes,
            'frame_ids': frames,
        }

        vibe_results[person_id] = output_dict

    del model

    end = time.time()
    fps = num_frames / (end - vibe_time)

    print(f'VIBE FPS: {fps:.2f}')
    total_time = time.time() - total_time
    print(
        f'Total time spent: {total_time:.2f} seconds (including model loading time).'
    )
    print(
        f'Total FPS (including model loading time): {num_frames / total_time:.2f}.'
    )

    print(
        f'Saving output results to \"{os.path.join(output_path, "vibe_output.pkl")}\".'
    )

    joblib.dump(vibe_results, os.path.join(output_path, "vibe_output.pkl"))

    if not args.no_render:
        # ========= Render results as a single video ========= #
        renderer = Renderer(resolution=(orig_width, orig_height),
                            orig_img=True,
                            wireframe=args.wireframe)

        output_img_folder = f'{image_folder}_output'
        os.makedirs(output_img_folder, exist_ok=True)

        print(f'Rendering output video, writing frames to {output_img_folder}')

        # prepare results for rendering
        frame_results = prepare_rendering_results(vibe_results, num_frames)
        mesh_color = {
            k: colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0)
            for k in vibe_results.keys()
        }

        image_file_names = sorted([
            os.path.join(image_folder, x) for x in os.listdir(image_folder)
            if x.endswith('.png') or x.endswith('.jpg')
        ])

        for frame_idx in tqdm(range(len(image_file_names))):
            img_fname = image_file_names[frame_idx]
            img = cv2.imread(img_fname)

            if args.sideview:
                side_img = np.zeros_like(img)

            for person_id, person_data in frame_results[frame_idx].items():
                frame_verts = person_data['verts']
                frame_cam = person_data['cam']

                mc = mesh_color[person_id]

                mesh_filename = None

                if args.save_obj:
                    mesh_folder = os.path.join(output_path, 'meshes',
                                               f'{person_id:04d}')
                    os.makedirs(mesh_folder, exist_ok=True)
                    mesh_filename = os.path.join(mesh_folder,
                                                 f'{frame_idx:06d}.obj')

                img = renderer.render(
                    img,
                    frame_verts,
                    cam=frame_cam,
                    color=mc,
                    mesh_filename=mesh_filename,
                )

                if args.sideview:
                    side_img = renderer.render(
                        side_img,
                        frame_verts,
                        cam=frame_cam,
                        color=mc,
                        angle=270,
                        axis=[0, 1, 0],
                    )

            if args.sideview:
                img = np.concatenate([img, side_img], axis=1)

            cv2.imwrite(
                os.path.join(output_img_folder, f'{frame_idx:06d}.png'), img)

            if args.display:
                cv2.imshow('Video', img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if args.display:
            cv2.destroyAllWindows()

        # ========= Save rendered video ========= #
        vid_name = os.path.basename(video_file)
        save_name = f'{vid_name.replace(".mp4", "")}_vibe_result.mp4'
        save_name = os.path.join(output_path, save_name)
        print(f'Saving result video to {save_name}')
        images_to_video(img_folder=output_img_folder,
                        output_vid_file=save_name)
        shutil.rmtree(output_img_folder)

    shutil.rmtree(image_folder)
    print('================= END =================')
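
# Reading the saved results back (a sketch; the path below is a placeholder for
# whatever output_path was used above). joblib.load returns the per-person
# dictionary written by joblib.dump.
import joblib

vibe_output = joblib.load('output/sample_video/vibe_output.pkl')
for person_id, data in vibe_output.items():
    print(person_id, data['pose'].shape, data['betas'].shape, data['joints3d'].shape)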
Example #3
def estimate_pose(pose,
                  save_pkl=False,
                  image_folder=None,
                  output_path=None,
                  tracking_method='bbox',
                  vibe_batch_size=225,
                  tracker_batch_size=12,
                  mesh_out=False,
                  run_smplify=False,
                  render=False,
                  wireframe=False,
                  sideview=False,
                  display=False,
                  save_obj=False,
                  gpu_id=0,
                  output_folder='MEVA_outputs',
                  detector='yolo',
                  yolo_img_size=416,
                  exp='train_meva_2',
                  cfg='train_meva_2',
                  num_workers=None):

    #return_dir = os.getcwd()
    #os.chdir('MEVA')
    if not image_folder:
        image_folder = osp.join(PSYPOSE_DATA_DIR, pose.vid_name)

    video_file = pose.vid_path

    # set the minimum number of frames so the shortest kept track is roughly half a second long
    MIN_NUM_FRAMES = 25
    #MIN_NUM_FRAMES = round(pose.fps/2)

    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if not os.path.isfile(video_file):
        exit(f'Input video \"{video_file}\" does not exist!')

    filename = os.path.splitext(os.path.basename(video_file))[0]
    #output_path = os.path.join(output_folder, filename)
    #os.makedirs(output_path, exist_ok=True)

    image_folder, num_frames, img_shape = video_to_images(
        video_file, img_folder=image_folder, return_info=True)

    print(f'Input video number of frames {num_frames}')
    orig_height, orig_width = img_shape[:2]

    total_time = time.time()

    # ========= Run tracking ========= #

    #print("\n")

    # run multi object tracker
    mot = MPT(
        device=device,
        batch_size=tracker_batch_size,
        display=display,
        detector_type=detector,
        output_format='dict',
        yolo_img_size=yolo_img_size,
    )
    tracking_results = mot(image_folder)

    # remove tracklets if num_frames is less than MIN_NUM_FRAMES
    for person_id in list(tracking_results.keys()):
        if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
            del tracking_results[person_id]

    # print('Track lengths: /n')
    # for person_id in list(tracking_results.keys()):
    #     print(str(tracking_results[person_id]['frames'].shape[0]))

    # ========= MEVA Model ========= #
    pretrained_file = PSYPOSE_DATA_DIR.joinpath("meva_data", "results", "meva",
                                                "train_meva_2",
                                                "model_best.pth.tar")

    config_file = osp.join(dir_name, "meva", "cfg", f"{cfg}.yml")
    cfg = update_cfg(config_file)
    model = MEVA_demo(
        n_layers=cfg.MODEL.TGRU.NUM_LAYERS,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        seqlen=cfg.DATASET.SEQLEN,
        hidden_size=cfg.MODEL.TGRU.HIDDEN_SIZE,
        add_linear=cfg.MODEL.TGRU.ADD_LINEAR,
        bidirectional=cfg.MODEL.TGRU.BIDIRECTIONAL,
        use_residual=cfg.MODEL.TGRU.RESIDUAL,
        cfg=cfg.VAE_CFG,
    ).to(device)

    ckpt = torch.load(pretrained_file, map_location=device)
    # print(f'Performance of pretrained model on 3DPW: {ckpt["performance"]}')
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt)
    model.eval()
    print(f'\nLoaded pretrained weights from \"{pretrained_file}\"')
    # ========= MEVA Model ========= #

    # ========= Run MEVA on each person ========= #
    bbox_scale = 1.2
    print('\nRunning MEVA on each tracklet...\n', flush=True)
    vibe_time = time.time()
    meva_results = {}
    for person_id in tqdm(list(tracking_results.keys())):
        bboxes = joints2d = None

        bboxes = tracking_results[person_id]['bbox']
        frames = tracking_results[person_id]['frames']
        #    if len(frames) < 90:
        #        print(f"!!!tracklet < 90 frames: {len(frames)} frames")
        #        continue

        dataset = Inference(
            image_folder=image_folder,
            frames=frames,
            bboxes=bboxes,
            scale=bbox_scale,
        )

        bboxes = dataset.bboxes
        frames = dataset.frames

        if num_workers is None:
            num_workers = 16

        dataloader = DataLoader(dataset,
                                batch_size=vibe_batch_size,
                                num_workers=num_workers,
                                shuffle=False)

        with torch.no_grad():

            pred_cam, pred_pose, pred_betas, pred_joints3d = [], [], [], []
            data_chunks = dataset.iter_data()

            for idx in range(len(data_chunks)):
                batch = data_chunks[idx]
                batch_image = batch['batch'].unsqueeze(0)
                cl = batch['cl']
                batch_image = batch_image.to(device)

                batch_size, seqlen = batch_image.shape[:2]
                output = model(batch_image)[-1]

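                # 'cl' presumably holds the valid [start, end) range within this
                # fixed-length chunk, so padded frames at the tracklet boundary
                # are dropped before the predictions are collected.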
                pred_cam.append(output['theta'][0, cl[0]:cl[1], :3])
                pred_pose.append(output['theta'][0, cl[0]:cl[1], 3:75])
                pred_betas.append(output['theta'][0, cl[0]:cl[1], 75:])
                pred_joints3d.append(output['kp_3d'][0, cl[0]:cl[1]])

            pred_cam = torch.cat(pred_cam, dim=0)
            pred_pose = torch.cat(pred_pose, dim=0)
            pred_betas = torch.cat(pred_betas, dim=0)
            pred_joints3d = torch.cat(pred_joints3d, dim=0)

            del batch_image

        # ========= Save results to a pickle file ========= #
        pred_cam = pred_cam.cpu().numpy()
        pred_pose = pred_pose.cpu().numpy()
        pred_betas = pred_betas.cpu().numpy()
        pred_joints3d = pred_joints3d.cpu().numpy()

        orig_cam = convert_crop_cam_to_orig_img(cam=pred_cam,
                                                bbox=bboxes,
                                                img_width=orig_width,
                                                img_height=orig_height)

        output_dict = {
            'pred_cam': pred_cam,
            'orig_cam': orig_cam,
            'pose': pred_pose,
            'betas': pred_betas,
            'joints3d': pred_joints3d,
            'bboxes': bboxes,
            'frame_ids': frames,
        }

        meva_results[person_id] = output_dict

    del model

    end = time.time()
    fps = num_frames / (end - vibe_time)

    print(f'MEVA FPS: {fps:.2f}')
    total_time = time.time() - total_time
    print(
        f'Total time spent: {total_time:.2f} seconds (including model loading time).'
    )
    print(
        f'Total FPS (including model loading time): {num_frames / total_time:.2f}.'
    )

    # if save_pkl:
    #     print(f'Saving output results to \"{os.path.join(output_path, "meva_output.pkl")}\".')

    #     joblib.dump(meva_results, os.path.join(output_path, "meva_output.pkl"))

    # meva_results = joblib.load(os.path.join(output_path, "meva_output.pkl"))

    #if render_preview or not len(meva_results) == 0:
    if render:
        # ========= Render results as a single video ========= #
        renderer = Renderer(resolution=(orig_width, orig_height),
                            orig_img=True,
                            wireframe=wireframe)

        output_img_folder = f'{image_folder}_output'
        os.makedirs(output_img_folder, exist_ok=True)

        print(f'Rendering output video, writing frames to {output_img_folder}')

        # prepare results for rendering
        frame_results = prepare_rendering_results(meva_results, num_frames)
        mesh_color = {
            k: colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0)
            for k in meva_results.keys()
        }

        image_file_names = sorted([
            os.path.join(image_folder, x) for x in os.listdir(image_folder)
            if x.endswith('.png') or x.endswith('.jpg')
        ])

        for frame_idx in tqdm(range(len(image_file_names))):
            img_fname = image_file_names[frame_idx]
            img = cv2.imread(img_fname)
            # img = np.zeros(img.shape)

            if sideview:
                side_img = np.zeros_like(img)

            for person_id, person_data in frame_results[frame_idx].items():
                frame_verts = person_data['verts']
                frame_cam = person_data['cam']

                mc = mesh_color[person_id]

                mesh_filename = None

                if save_obj:
                    mesh_folder = os.path.join(output_path, 'meshes',
                                               f'{person_id:04d}')
                    os.makedirs(mesh_folder, exist_ok=True)
                    mesh_filename = os.path.join(mesh_folder,
                                                 f'{frame_idx:06d}.obj')

                img = renderer.render(
                    img,
                    frame_verts,
                    cam=frame_cam,
                    color=mc,
                    mesh_filename=mesh_filename,
                )

                frame_cam = np.array([0.5, 1., 0, 0])
                if sideview:
                    side_img = renderer.render(
                        side_img,
                        frame_verts,
                        cam=frame_cam,
                        color=mc,
                        mesh_filename=mesh_filename,
                        # angle=270,
                        # axis=[0,1,0],
                    )

            if sideview:
                img = np.concatenate([img, side_img], axis=1)

            cv2.imwrite(
                os.path.join(output_img_folder, f'{frame_idx:06d}.png'), img)

            if display:
                cv2.imshow('Video', img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if display:
            cv2.destroyAllWindows()

        # ========= Save rendered video ========= #
        vid_name = os.path.basename(video_file)
        save_name = f'{vid_name.replace(".mp4", "")}_meva_result.mp4'
        save_name = os.path.join(output_path, save_name)
        print(f'Saving result video to {save_name}')
        images_to_video(img_folder=output_img_folder,
                        output_vid_file=save_name)
        shutil.rmtree(output_img_folder)

    def clean_image_folder():
        if osp.exists(image_folder) and osp.isdir(image_folder):
            shutil.rmtree(image_folder)

    atexit.register(clean_image_folder)
    shutil.rmtree(image_folder)
    #os.chdir(return_dir)

    print('========FINISHED POSE ESTIMATION========')
    return meva_results
Example #4
def main(args):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    SEQ_LENGTH = args.sequence_length
    MIN_NUM_FRAMES = 1  # Don't change this
    TRACKER_BATCH_SIZE = MIN_NUM_FRAMES
    images_to_eval = []
    yolo_img_size = args.yolo_img_size

    image_folder = 'live_rendered_images'
    output_path = args.output_folder
    os.makedirs(image_folder, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)
    os.makedirs('live_imgs', exist_ok=True)

    model = VIBE_Demo(seqlen=SEQ_LENGTH,
                      n_layers=2,
                      hidden_size=1024,
                      add_linear=True,
                      use_residual=True,
                      live_inference=True).to(device)

    pretrained_file = download_ckpt(use_3dpw=False)
    ckpt = torch.load(pretrained_file)
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    print(f'Loaded pretrained weights from \"{pretrained_file}\"')

    mot = MPT(
        device=device,
        batch_size=TRACKER_BATCH_SIZE,
        display=False,
        detector_type=args.detector,
        output_format='dict',
        yolo_img_size=yolo_img_size,
    )

    # An asynchronous camera implementation to run cv2 camera in background while model is running
    cap = AsyncCamera(0, display=args.live_display)

    bbox_scale = 1.1

    i = 0
    bbox_lis, frame_lis, images_lis, joints2d_lis = [], [], [], []
    pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

    while True:
        # If q is pressed cap.stop will turn True
        if cap.stop:
            break

        ret, captured_frames = cap.read()
        if not ret:
            continue
        if len(captured_frames) < MIN_NUM_FRAMES:
            continue

        images = get_images_from_captures(captured_frames, MIN_NUM_FRAMES)

        cap.del_frame_lis()

        orig_height, orig_width = images[0].shape[:2]
        orig_dim = (orig_height, orig_width)

        saveToDir(images)
        if args.tracking_method == 'pose':
            images_to_video('./live_imgs', './live_imgs/pose_video.mp4')
            tracking_results = run_posetracker('live_imgs/pose_video.mp4',
                                               staf_folder=args.staf_dir,
                                               display=args.display)
        else:
            tracking_results = mot('./live_imgs')

        if args.live_display:
            cap.set_display_image(images[-1])

        if (len(tracking_results.keys()) == 0):
            print('Unable to detect any person')

        for image in images:
            images_lis.append(image)

        if len(tracking_results.keys()) != 0:

            person_id = list(tracking_results.keys())[0]
            print(person_id)
            frames = tracking_results[person_id]['frames']

            bboxes, joints2d = None, None

            if args.tracking_method == 'pose':
                joints2d = tracking_results[person_id]['joints2d']
                if len(joints2d_lis) == 0:
                    joints2d_lis = joints2d
                else:
                    joints2d_lis = np.vstack([joints2d_lis, joints2d])
            else:
                bboxes = tracking_results[person_id]['bbox']
                if len(bbox_lis) == 0:
                    bbox_lis = bboxes
                else:
                    bbox_lis = np.vstack([bbox_lis, bboxes])

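            # Map the tracker's per-call frame indices onto the global stream
            # index (i counts capture iterations; with MIN_NUM_FRAMES == 1 this
            # reduces to i).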
            for x in (1 + i + frames - MIN_NUM_FRAMES):
                frame_lis.append(x)

            dataset = LiveInference(
                images=images_lis[-SEQ_LENGTH:],
                frames=frame_lis[-SEQ_LENGTH:],
                bboxes=bbox_lis[-SEQ_LENGTH:],
                joints2d=joints2d_lis[-SEQ_LENGTH:]
                if joints2d is not None else None,
                scale=bbox_scale,
            )

            bboxes = dataset.bboxes

            if args.tracking_method == 'pose':
                if len(bbox_lis) == 0:
                    bbox_lis = bboxes
                else:
                    bbox_lis = np.vstack([bbox_lis, bboxes[-1:]])

            cap.set_bounding_box(bbox_lis[-1])

            has_keypoints = joints2d is not None
            norm_joints2d = []

            with torch.no_grad():

                # A manual implementation for getting data since dataloader is slow for few inputs
                tup = [dataset[x] for x in range(len(dataset))]

                if has_keypoints:
                    for j, batch in enumerate(tup):
                        tup[j], nj2d = batch
                        norm_joints2d.append(nj2d[:21, :].reshape(-1, 21, 3))

                for j, x in enumerate(tup):
                    tup[j] = x.unsqueeze(0)

                tup = tuple(tup)
                batch = torch.cat((tup), 0)

                batch = batch.unsqueeze(0)
                batch = batch.to(device)

                batch_size, seqlen = batch.shape[:2]

                # Send only latest image to hmr for faster inferencing
                output = model(batch[:, -1:, :, :, :])[-1]

                pred_cam.append(
                    output['theta'][:, -MIN_NUM_FRAMES:, :3].reshape(
                        batch_size * MIN_NUM_FRAMES, -1))
                pred_verts.append(
                    output['verts'][:, -MIN_NUM_FRAMES:, ].reshape(
                        batch_size * MIN_NUM_FRAMES, -1, 3))
                pred_pose.append(
                    output['theta'][:, -MIN_NUM_FRAMES:, ][:, :, 3:75].reshape(
                        batch_size * MIN_NUM_FRAMES, -1))
                pred_betas.append(
                    output['theta'][:, -MIN_NUM_FRAMES:, ][:, :, 75:].reshape(
                        batch_size * MIN_NUM_FRAMES, -1))
                pred_joints3d.append(
                    output['kp_3d'][:, -MIN_NUM_FRAMES:, ].reshape(
                        batch_size * MIN_NUM_FRAMES, -1, 3))

                del batch

            pred_verts[-MIN_NUM_FRAMES:], pred_cam[
                -MIN_NUM_FRAMES:], pred_pose[-MIN_NUM_FRAMES:], pred_betas[
                    -MIN_NUM_FRAMES:], pred_joints3d[
                        -MIN_NUM_FRAMES:], norm_joints2d[
                            -MIN_NUM_FRAMES:] = temporal_simplify(
                                pred_verts[-MIN_NUM_FRAMES:],
                                pred_cam[-MIN_NUM_FRAMES:],
                                pred_pose[-MIN_NUM_FRAMES:],
                                pred_betas[-MIN_NUM_FRAMES:],
                                pred_joints3d[-MIN_NUM_FRAMES:],
                                norm_joints2d[-MIN_NUM_FRAMES:], device, args)

            get_vibe_results(
                pred_cam[-MIN_NUM_FRAMES:], pred_verts[-MIN_NUM_FRAMES:],
                pred_pose[-MIN_NUM_FRAMES:], pred_betas[-MIN_NUM_FRAMES:],
                pred_joints3d[-MIN_NUM_FRAMES:],
                joints2d_lis[-MIN_NUM_FRAMES:], bbox_lis[-MIN_NUM_FRAMES:],
                frame_lis[-MIN_NUM_FRAMES:], orig_dim, 0)

        images = []
        i = i + 1

        if (i == args.max_frames):
            break

    del model

    vibe_results = get_vibe_results(pred_cam, pred_verts, pred_pose,
                                    pred_betas, pred_joints3d, joints2d_lis,
                                    bbox_lis, frame_lis, orig_dim, 0)

    if not args.no_render:
        for i, image in enumerate(images_lis):
            cv2.imwrite(f'{image_folder}/{(i):06d}.jpg', image)
        print(frame_lis)
        render(orig_dim, frame_lis, vibe_results, image_folder, output_path,
               len(images_lis), args)

    shutil.rmtree('live_imgs')
    print('================= END =================')
Example #5
def main(args):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    pose_label_map = {
        'bridge': 1,
        'childs': 2,
        'downwarddog': 3,
        'mountain': 4,
        'plank': 5,
        'seatedforwardbend': 6,
        'tree': 7,
        'trianglepose': 8,
        'warrior1': 9,
        'warrior2': 10
    }
    dir_path = '/home/ubuntu/PoseEstimation/VIBE/InputData/input_training_data/'
    output_folder = '/home/ubuntu/PoseEstimation/VIBE/OutputData/training_data/'
    output_f3d_file = '/home/ubuntu/PoseEstimation/VIBE/OutputData/train_output_joints3d.csv'
    output_pose_file = '/home/ubuntu/PoseEstimation/VIBE/OutputData/train_output_pose.csv'
    joints3D_csv = open(output_f3d_file, 'a')
    pose_csv = open(output_pose_file, 'a')

    # ========= Define VIBE model ========= #
    model = VIBE_Demo(
        seqlen=16,
        n_layers=2,
        hidden_size=1024,
        add_linear=True,
        use_residual=True,
    ).to(device)

    # ========= Load pretrained weights ========= #
    pretrained_file = download_ckpt(use_3dpw=False)
    ckpt = torch.load(pretrained_file)
    print(f'Performance of pretrained model on 3DPW: {ckpt["performance"]}')
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    print(f'Loaded pretrained weights from \"{pretrained_file}\"')

    for root, dirs, files in os.walk(dir_path):
        for file in files:
            if not file.startswith('NO_'):
                video_file = os.path.join(root, file)
                video_label = pose_label_map[os.path.basename(root)]
                if not os.path.isfile(video_file):
                    exit(f'Input video \"{video_file}\" does not exist!')

                image_folder, num_frames, img_shape = video_to_images(
                    video_file, return_info=True)

                print(f'Input video number of frames {num_frames}')
                orig_height, orig_width = img_shape[:2]

                total_time = time.time()

                # ========= Run tracking ========= #
                bbox_scale = 1.1
                if args.tracking_method == 'pose':
                    if not os.path.isabs(video_file):
                        video_file = os.path.join(os.getcwd(), video_file)
                    tracking_results = run_posetracker(
                        video_file,
                        staf_folder=args.staf_dir,
                        display=args.display)
                else:
                    # run multi object tracker
                    mot = MPT(
                        device=device,
                        batch_size=args.tracker_batch_size,
                        display=args.display,
                        detector_type=args.detector,
                        output_format='dict',
                        yolo_img_size=args.yolo_img_size,
                    )
                    tracking_results = mot(image_folder)

                # remove tracklets if num_frames is less than MIN_NUM_FRAMES
                for person_id in list(tracking_results.keys()):
                    if tracking_results[person_id]['frames'].shape[
                            0] < MIN_NUM_FRAMES:
                        del tracking_results[person_id]

                # ========= Run VIBE on each person ========= #
                print(f'Running VIBE on each tracklet...')
                vibe_time = time.time()
                vibe_results = {}
                for person_id in tqdm(list(tracking_results.keys())):
                    bboxes = joints2d = None

                    if args.tracking_method == 'bbox':
                        bboxes = tracking_results[person_id]['bbox']
                    elif args.tracking_method == 'pose':
                        joints2d = tracking_results[person_id]['joints2d']

                    frames = tracking_results[person_id]['frames']

                    dataset = Inference(
                        image_folder=image_folder,
                        frames=frames,
                        bboxes=bboxes,
                        joints2d=joints2d,
                        scale=bbox_scale,
                    )

                    bboxes = dataset.bboxes
                    frames = dataset.frames
                    has_keypoints = joints2d is not None

                    dataloader = DataLoader(dataset,
                                            batch_size=args.vibe_batch_size,
                                            num_workers=16)

                    with torch.no_grad():

                        pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

                        for batch in dataloader:
                            if has_keypoints:
                                batch, nj2d = batch
                                norm_joints2d.append(nj2d.numpy().reshape(
                                    -1, 21, 3))

                            batch = batch.unsqueeze(0)
                            batch = batch.to(device)

                            batch_size, seqlen = batch.shape[:2]
                            output = model(batch)[-1]

                            pred_cam.append(output['theta'][:, :, :3].reshape(
                                batch_size * seqlen, -1))
                            pred_verts.append(output['verts'].reshape(
                                batch_size * seqlen, -1, 3))
                            pred_pose.append(
                                output['theta'][:, :, 3:75].reshape(
                                    batch_size * seqlen, -1))
                            pred_betas.append(
                                output['theta'][:, :, 75:].reshape(
                                    batch_size * seqlen, -1))
                            pred_joints3d.append(output['kp_3d'].reshape(
                                batch_size * seqlen, -1, 3))

                        pred_cam = torch.cat(pred_cam, dim=0)
                        pred_verts = torch.cat(pred_verts, dim=0)
                        pred_pose = torch.cat(pred_pose, dim=0)
                        pred_betas = torch.cat(pred_betas, dim=0)
                        pred_joints3d = torch.cat(pred_joints3d, dim=0)

                        del batch

                    # ========= [Optional] run Temporal SMPLify to refine the results ========= #
                    if args.run_smplify and args.tracking_method == 'pose':
                        norm_joints2d = np.concatenate(norm_joints2d, axis=0)
                        norm_joints2d = convert_kps(norm_joints2d,
                                                    src='staf',
                                                    dst='spin')
                        norm_joints2d = torch.from_numpy(
                            norm_joints2d).float().to(device)

                        # Run Temporal SMPLify
                        update, new_opt_vertices, new_opt_cam, new_opt_pose, new_opt_betas, \
                        new_opt_joints3d, new_opt_joint_loss, opt_joint_loss = smplify_runner(
                            pred_rotmat=pred_pose,
                            pred_betas=pred_betas,
                            pred_cam=pred_cam,
                            j2d=norm_joints2d,
                            device=device,
                            batch_size=norm_joints2d.shape[0],
                            pose2aa=False,
                        )

                        # update the parameters after refinement
                        print(
                            f'Update ratio after Temporal SMPLify: {update.sum()} / {norm_joints2d.shape[0]}'
                        )
                        pred_verts = pred_verts.cpu()
                        pred_cam = pred_cam.cpu()
                        pred_pose = pred_pose.cpu()
                        pred_betas = pred_betas.cpu()
                        pred_joints3d = pred_joints3d.cpu()
                        pred_verts[update] = new_opt_vertices[update]
                        pred_cam[update] = new_opt_cam[update]
                        pred_pose[update] = new_opt_pose[update]
                        pred_betas[update] = new_opt_betas[update]
                        pred_joints3d[update] = new_opt_joints3d[update]

                    elif args.run_smplify and args.tracking_method == 'bbox':
                        print(
                            '[WARNING] You need to enable pose tracking to run Temporal SMPLify algorithm!'
                        )
                        print(
                            '[WARNING] Continuing without running Temporal SMPLify!..'
                        )

                    # ========= Save results to a pickle file ========= #
                    pred_cam = pred_cam.cpu().numpy()
                    pred_verts = pred_verts.cpu().numpy()
                    pred_pose = pred_pose.cpu().numpy()
                    pred_betas = pred_betas.cpu().numpy()
                    pred_joints3d = pred_joints3d.cpu().numpy()

                    # Runs 1 Euro Filter to smooth out the results
                    if args.smooth:
                        min_cutoff = args.smooth_min_cutoff  # 0.004
                        beta = args.smooth_beta  # 1.5
                        print(
                            f'Running smoothing on person {person_id}, min_cutoff: {min_cutoff}, beta: {beta}'
                        )
                        pred_verts, pred_pose, pred_joints3d = smooth_pose(
                            pred_pose,
                            pred_betas,
                            min_cutoff=min_cutoff,
                            beta=beta)

                    orig_cam = convert_crop_cam_to_orig_img(
                        cam=pred_cam,
                        bbox=bboxes,
                        img_width=orig_width,
                        img_height=orig_height)

                    output_dict = {
                        'pred_cam': pred_cam,
                        'orig_cam': orig_cam,
                        'verts': pred_verts,
                        'pose': pred_pose,
                        'betas': pred_betas,
                        'joints3d': pred_joints3d,
                        'joints2d': joints2d,
                        'bboxes': bboxes,
                        'frame_ids': frames,
                    }

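                    # Write one row per frame: the flattened joints3d (and, below,
                    # pose) vector followed by the integer class label of the video.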
                    for i in range(len(output_dict['joints3d'])):
                        flat_arr = output_dict['joints3d'][i].flatten()
                        len_N = len(flat_arr)
                        np.savetxt(joints3D_csv,
                                   [np.append(flat_arr, [video_label])],
                                   delimiter=',',
                                   fmt=' '.join(['%f'] * len_N + ['%i']))

                    for i in range(len(output_dict['pose'])):
                        pose_arr = output_dict['pose'][i].flatten()
                        len_M = len(pose_arr)
                        np.savetxt(pose_csv,
                                   [np.append(pose_arr, [video_label])],
                                   delimiter=',',
                                   fmt=' '.join(['%f'] * len_M + ['%i']))

                end = time.time()
                fps = num_frames / (end - vibe_time)

                print(f'VIBE FPS: {fps:.2f}')
                total_time = time.time() - total_time
                print(
                    f'Total time spent: {total_time:.2f} seconds (including model loading time).'
                )
                print(
                    f'Total FPS (including model loading time): {num_frames / total_time:.2f}.'
                )
Example #6
def main(args):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    map_vals = {
        'bridge': 1,
        'childs': 2,
        'downwarddog': 3,
        'mountain': 4,
        'plank': 5,
        'seatedforwardbend': 6,
        'tree': 7,
        'trianglepose': 8,
        'warrior1': 9,
        'warrior2': 10
    }

    inverse_map = {
        1: 'bridge',
        2: 'childs',
        3: 'downwarddog',
        4: 'mountain',
        5: 'plank',
        6: 'seatedforwardbend',
        7: 'tree',
        8: 'trianglepose',
        9: 'warrior1',
        10: 'warrior2'
    }

    video_file = args.vid_file
    # ========= [Optional] download the youtube video ========= #
    if video_file.startswith('https://www.youtube.com'):
        print(f'Downloading YouTube video \"{video_file}\"')
        video_file = download_youtube_clip(video_file, '/tmp')

        if video_file is None:
            exit('YouTube URL is not valid!')

        print(f'YouTube Video has been downloaded to {video_file}...')

    if not os.path.isfile(video_file):
        exit(f'Input video \"{video_file}\" does not exist!')

    dir_path = '/home/ubuntu/PoseEstimation/VIBE/InputData/input_test_set/'
    output_folder = '/home/ubuntu/PoseEstimation/VIBE/OutputData/'

    # ========= Define VIBE model ========= #
    model = VIBE_Demo(
        seqlen=16,
        n_layers=2,
        hidden_size=1024,
        add_linear=True,
        use_residual=True,
    ).to(device)

    # ========= Load Classification Model ========= #
    classification_model = pickle.load(
        open('view_classification_model.pkl', 'rb'))

    # ========= Load pretrained weights ========= #
    pretrained_file = download_ckpt(use_3dpw=False)
    ckpt = torch.load(pretrained_file)
    #print(f'Performance of pretrained model on 3DPW: {ckpt["performance"]}')
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    #print(f'Loaded pretrained weights from \"{pretrained_file}\"')

    image_folder, num_frames, img_shape = video_to_images(video_file,
                                                          return_info=True)

    print(f'Input video number of frames {num_frames}')
    orig_height, orig_width = img_shape[:2]

    total_time = time.time()

    # ========= Run tracking ========= #
    bbox_scale = 1.1
    if args.tracking_method == 'pose':
        if not os.path.isabs(video_file):
            video_file = os.path.join(os.getcwd(), video_file)
        tracking_results = run_posetracker(video_file,
                                           staf_folder=args.staf_dir,
                                           display=args.display)
    else:
        # run multi object tracker
        mot = MPT(
            device=device,
            batch_size=args.tracker_batch_size,
            display=args.display,
            detector_type=args.detector,
            output_format='dict',
            yolo_img_size=args.yolo_img_size,
        )
        tracking_results = mot(image_folder)

    # remove tracklets if num_frames is less than MIN_NUM_FRAMES
    for person_id in list(tracking_results.keys()):
        if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
            del tracking_results[person_id]

    # ========= Run VIBE on each person ========= #
    #print(f'Running VIBE on each tracklet...')
    vibe_time = time.time()
    vibe_results = {}
    for person_id in list(tracking_results.keys()):
        bboxes = joints2d = None

        if args.tracking_method == 'bbox':
            bboxes = tracking_results[person_id]['bbox']
        elif args.tracking_method == 'pose':
            joints2d = tracking_results[person_id]['joints2d']

        frames = tracking_results[person_id]['frames']

        dataset = Inference(
            image_folder=image_folder,
            frames=frames,
            bboxes=bboxes,
            joints2d=joints2d,
            scale=bbox_scale,
        )

        bboxes = dataset.bboxes
        frames = dataset.frames
        has_keypoints = joints2d is not None

        dataloader = DataLoader(dataset,
                                batch_size=args.vibe_batch_size,
                                num_workers=16)

        with torch.no_grad():

            pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

            for batch in dataloader:
                if has_keypoints:
                    batch, nj2d = batch
                    norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

                batch = batch.unsqueeze(0)
                batch = batch.to(device)

                batch_size, seqlen = batch.shape[:2]
                output = model(batch)[-1]

                pred_cam.append(output['theta'][:, :, :3].reshape(
                    batch_size * seqlen, -1))
                pred_verts.append(output['verts'].reshape(
                    batch_size * seqlen, -1, 3))
                pred_pose.append(output['theta'][:, :, 3:75].reshape(
                    batch_size * seqlen, -1))
                pred_betas.append(output['theta'][:, :, 75:].reshape(
                    batch_size * seqlen, -1))
                pred_joints3d.append(output['kp_3d'].reshape(
                    batch_size * seqlen, -1, 3))

            pred_cam = torch.cat(pred_cam, dim=0)
            pred_verts = torch.cat(pred_verts, dim=0)
            pred_pose = torch.cat(pred_pose, dim=0)
            pred_betas = torch.cat(pred_betas, dim=0)
            pred_joints3d = torch.cat(pred_joints3d, dim=0)

            del batch

        # ========= [Optional] run Temporal SMPLify to refine the results ========= #
        if args.run_smplify and args.tracking_method == 'pose':
            norm_joints2d = np.concatenate(norm_joints2d, axis=0)
            norm_joints2d = convert_kps(norm_joints2d, src='staf', dst='spin')
            norm_joints2d = torch.from_numpy(norm_joints2d).float().to(device)

            # Run Temporal SMPLify
            update, new_opt_vertices, new_opt_cam, new_opt_pose, new_opt_betas, \
            new_opt_joints3d, new_opt_joint_loss, opt_joint_loss = smplify_runner(
                pred_rotmat=pred_pose,
                pred_betas=pred_betas,
                pred_cam=pred_cam,
                j2d=norm_joints2d,
                device=device,
                batch_size=norm_joints2d.shape[0],
                pose2aa=False,
            )

            # update the parameters after refinement
            print(
                f'Update ratio after Temporal SMPLify: {update.sum()} / {norm_joints2d.shape[0]}'
            )
            pred_verts = pred_verts.cpu()
            pred_cam = pred_cam.cpu()
            pred_pose = pred_pose.cpu()
            pred_betas = pred_betas.cpu()
            pred_joints3d = pred_joints3d.cpu()
            pred_verts[update] = new_opt_vertices[update]
            pred_cam[update] = new_opt_cam[update]
            pred_pose[update] = new_opt_pose[update]
            pred_betas[update] = new_opt_betas[update]
            pred_joints3d[update] = new_opt_joints3d[update]

        elif args.run_smplify and args.tracking_method == 'bbox':
            print(
                '[WARNING] You need to enable pose tracking to run Temporal SMPLify algorithm!'
            )
            print('[WARNING] Continuing without running Temporal SMPLify!..')

        # ========= Save results to a pickle file ========= #
        pred_cam = pred_cam.cpu().numpy()
        pred_verts = pred_verts.cpu().numpy()
        pred_pose = pred_pose.cpu().numpy()
        pred_betas = pred_betas.cpu().numpy()
        pred_joints3d = pred_joints3d.cpu().numpy()

        # Runs 1 Euro Filter to smooth out the results
        if args.smooth:
            min_cutoff = args.smooth_min_cutoff  # 0.004
            beta = args.smooth_beta  # 1.5
            print(
                f'Running smoothing on person {person_id}, min_cutoff: {min_cutoff}, beta: {beta}'
            )
            pred_verts, pred_pose, pred_joints3d = smooth_pose(
                pred_pose, pred_betas, min_cutoff=min_cutoff, beta=beta)

        orig_cam = convert_crop_cam_to_orig_img(cam=pred_cam,
                                                bbox=bboxes,
                                                img_width=orig_width,
                                                img_height=orig_height)

        output_dict = {
            'pred_cam': pred_cam,
            'orig_cam': orig_cam,
            'verts': pred_verts,
            'pose': pred_pose,
            'betas': pred_betas,
            'joints3d': pred_joints3d,
            'joints2d': joints2d,
            'bboxes': bboxes,
            'frame_ids': frames,
        }
        # ========= Extract 3D joint feature for each frame ========= #
        list_val = []
        for i in range(len(output_dict['joints3d'])):
            list_val.append(output_dict['joints3d'][i].flatten().reshape(
                1, -1))

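        # Each row is a flattened per-frame joints3d vector, matching the layout
        # presumably used to train the classifier (see the CSVs written in
        # Example #5).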
        input_df = pd.DataFrame(np.concatenate(list_val))
        input_df = input_df.round(2)
        predicted_classes = classification_model.predict_classes(input_df)
        output_df = pd.DataFrame(predicted_classes)
        # ========= Printing all possible poses detected for the video ========= #
        total_frames = len(output_df)
        print(
            '\nPrinting probabilities for yoga poses predicted in different frames.'
        )
        for i, v in output_df.value_counts().items():
            val = round((v / total_frames) * 100, 2)
            print('Probability of the yoga pose being ' +
                  inverse_map[i[0]].capitalize() + " is: " + str(val))
        print('\nThe yoga pose in the given video is: ' +
              inverse_map[output_df[0].value_counts().idxmax()].capitalize())
Example #7
def main(args):
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    print(f'Loading video list {args.video_list}')
    with open(args.video_list, 'r') as f:
        video_list = [line.strip() for line in f.readlines()]
    if len(video_list) < 1:
        print('No files were found in video list')
        return

    print('Loading VIBE model')
    # ========= Define VIBE model ========= #
    model = VIBE_Demo(
        seqlen=16,
        n_layers=2,
        hidden_size=1024,
        add_linear=True,
        use_residual=True,
    ).to(device)

    # ========= Load VIBE pretrained weights ========= #
    pretrained_file = download_ckpt(use_3dpw=False)
    ckpt = torch.load(pretrained_file)
    print(f'Performance of pretrained model on 3DPW: {ckpt["performance"]}')
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    print(f'Loaded pretrained weights from \"{pretrained_file}\"')

    num_videos = len(video_list)
    print(f'Processing {num_videos} videos.')
    for video_idx, video_file in enumerate(video_list, start=1):
        if not osp.isfile(video_file):
            print(
                f'Input video \"{video_file}\" does not exist! Moving on to next file.'
            )
            continue

        filename = osp.splitext(osp.basename(video_file))[0]
        output_path = osp.join(args.output_folder, filename)
        os.makedirs(output_path, exist_ok=True)

        image_folder, num_frames, img_shape = video_to_images(video_file,
                                                              return_info=True)

        print(f'[{video_idx}/{num_videos}] Processing {num_frames} frames')
        orig_height, orig_width = img_shape[:2]

        # ========= Run tracking ========= #
        bbox_scale = 1.1
        if args.tracking_method == 'pose':
            if not osp.isabs(video_file):
                video_file = osp.join(os.getcwd(), video_file)
            tracking_results = run_posetracker(video_file,
                                               staf_folder=args.staf_dir,
                                               display=args.display)
        else:
            # run multi object tracker
            mot = MPT(
                device=device,
                batch_size=args.tracker_batch_size,
                display=args.display,
                detector_type=args.detector,
                output_format='dict',
                yolo_img_size=args.yolo_img_size,
            )
            tracking_results = mot(image_folder)

        # remove tracklets if num_frames is less than MIN_NUM_FRAMES
        for person_id in list(tracking_results.keys()):
            if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
                del tracking_results[person_id]

        # ========= Run VIBE on each person ========= #
        print(f'Running VIBE on each tracklet...')
        vibe_results = {}
        for person_id in tqdm(list(tracking_results.keys())):
            bboxes = joints2d = None

            if args.tracking_method == 'bbox':
                bboxes = tracking_results[person_id]['bbox']
            elif args.tracking_method == 'pose':
                joints2d = tracking_results[person_id]['joints2d']

            frames = tracking_results[person_id]['frames']

            dataset = Inference(
                image_folder=image_folder,
                frames=frames,
                bboxes=bboxes,
                joints2d=joints2d,
                scale=bbox_scale,
            )

            bboxes = dataset.bboxes
            frames = dataset.frames
            has_keypoints = joints2d is not None

            dataloader = DataLoader(dataset,
                                    batch_size=args.vibe_batch_size,
                                    num_workers=16)

            with torch.no_grad():

                pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

                for batch in dataloader:
                    if has_keypoints:
                        batch, nj2d = batch
                        norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

                    batch = batch.unsqueeze(0)
                    batch = batch.to(device)

                    batch_size, seqlen = batch.shape[:2]
                    output = model(batch)[-1]

                    pred_cam.append(output['theta'][:, :, :3].reshape(
                        batch_size * seqlen, -1))
                    pred_verts.append(output['verts'].reshape(
                        batch_size * seqlen, -1, 3))
                    pred_pose.append(output['theta'][:, :, 3:75].reshape(
                        batch_size * seqlen, -1))
                    pred_betas.append(output['theta'][:, :, 75:].reshape(
                        batch_size * seqlen, -1))
                    pred_joints3d.append(output['kp_3d'].reshape(
                        batch_size * seqlen, -1, 3))

                pred_cam = torch.cat(pred_cam, dim=0)
                pred_verts = torch.cat(pred_verts, dim=0)
                pred_pose = torch.cat(pred_pose, dim=0)
                pred_betas = torch.cat(pred_betas, dim=0)
                pred_joints3d = torch.cat(pred_joints3d, dim=0)

                del batch

            # ========= [Optional] run Temporal SMPLify to refine the results ========= #
            if args.run_smplify and args.tracking_method == 'pose':
                norm_joints2d = np.concatenate(norm_joints2d, axis=0)
                norm_joints2d = convert_kps(norm_joints2d,
                                            src='staf',
                                            dst='spin')
                norm_joints2d = torch.from_numpy(norm_joints2d).float().to(
                    device)

                # Run Temporal SMPLify
                update, new_opt_vertices, new_opt_cam, new_opt_pose, new_opt_betas, \
                new_opt_joints3d, new_opt_joint_loss, opt_joint_loss = smplify_runner(
                    pred_rotmat=pred_pose,
                    pred_betas=pred_betas,
                    pred_cam=pred_cam,
                    j2d=norm_joints2d,
                    device=device,
                    batch_size=norm_joints2d.shape[0],
                    pose2aa=False,
                )

                # update the parameters after refinement
                print(
                    f'Update ratio after Temporal SMPLify: {update.sum()} / {norm_joints2d.shape[0]}'
                )
                pred_verts = pred_verts.cpu()
                pred_cam = pred_cam.cpu()
                pred_pose = pred_pose.cpu()
                pred_betas = pred_betas.cpu()
                pred_joints3d = pred_joints3d.cpu()
                pred_verts[update] = new_opt_vertices[update]
                pred_cam[update] = new_opt_cam[update]
                pred_pose[update] = new_opt_pose[update]
                pred_betas[update] = new_opt_betas[update]
                pred_joints3d[update] = new_opt_joints3d[update]

            elif args.run_smplify and args.tracking_method == 'bbox':
                print(
                    '[WARNING] You need to enable pose tracking to run Temporal SMPLify algorithm!'
                )
                print(
                    '[WARNING] Continuing without running Temporal SMPLify!..')

            # ========= Save results to a pickle file ========= #
            pred_cam = pred_cam.cpu().numpy()
            pred_verts = pred_verts.cpu().numpy()
            pred_pose = pred_pose.cpu().numpy()
            pred_betas = pred_betas.cpu().numpy()
            pred_joints3d = pred_joints3d.cpu().numpy()

            orig_cam = convert_crop_cam_to_orig_img(cam=pred_cam,
                                                    bbox=bboxes,
                                                    img_width=orig_width,
                                                    img_height=orig_height)

            output_dict = {
                'pred_cam': pred_cam,
                'orig_cam': orig_cam,
                'verts': pred_verts,
                'pose': pred_pose,
                'betas': pred_betas,
                'joints3d': pred_joints3d,
                'joints2d': joints2d,
                'bboxes': bboxes,
                'frame_ids': frames,
            }

            vibe_results[person_id] = output_dict

        # Clean-up the temporal folder
        shutil.rmtree(image_folder)

        # Save the outputs to joblib pkl file. File is loaded through joblib.load(pkl_path)
        output_pkl_path = osp.join(args.output_folder, f'{filename}.pkl')
        print(f'Saving output results to \"{output_pkl_path}\".')
        joblib.dump(vibe_results, output_pkl_path)

    # Clean-up after processing
    del model

    print('================= END =================')