Ejemplo n.º 1
0
    def __init__(self,
                 seqlen,
                 batch_size=64,
                 n_layers=1,
                 hidden_size=2048,
                 pretrained='data/vibe_data/spin_model_checkpoint.pth.tar',
                 add_linear=False,
                 bidirectional=False,
                 attention=False,
                 attention_cfg=None,
                 use_residual=True,
                 disable_temporal=False):
        """Build the temporal encoder, the HMR backbone and the regressor.

        Args:
            seqlen: Stored as ``self.seqlen``.
            batch_size: Stored as ``self.batch_size``.
            n_layers: Forwarded to ``TemporalEncoder``; not used when
                ``attention`` is true.
            hidden_size: Forwarded to the selected encoder.
            pretrained: Path of the checkpoint whose ``'model'`` entry is
                loaded (non-strictly) into both ``self.hmr`` and
                ``self.regressor``.
            add_linear: Forwarded to the selected encoder.
            bidirectional: Forwarded to the selected encoder.
            attention: When true, ``TemporalEncoderWAttention`` is used
                instead of ``TemporalEncoder``.
            attention_cfg: Object exposing ``SIZE``, ``LAYERS`` and
                ``DROPOUT``; only read when ``attention`` is true.
            use_residual: Forwarded to the selected encoder.
            disable_temporal: Stored as ``self.disable_temporal``.
        """
        super(VIBE_Demo, self).__init__()

        self.seqlen = seqlen
        self.batch_size = batch_size
        self.disable_temporal = disable_temporal

        if attention:
            cfg = attention_cfg
            self.encoder = TemporalEncoderWAttention(
                hidden_size=hidden_size,
                bidirectional=bidirectional,
                add_linear=add_linear,
                attention_size=cfg.SIZE,
                attention_layers=cfg.LAYERS,
                attention_dropout=cfg.DROPOUT,
                use_residual=use_residual,
            )
        else:
            self.encoder = TemporalEncoder(
                n_layers=n_layers,
                hidden_size=hidden_size,
                bidirectional=bidirectional,
                add_linear=add_linear,
                use_residual=use_residual,
            )

        self.hmr = hmr()

        # Read the checkpoint from disk exactly once and share its 'model'
        # state dict between the backbone and the regressor. Without CUDA the
        # tensors are remapped to the CPU so the load does not fail.
        map_location = (None if torch.cuda.is_available() else
                        torch.device('cpu'))
        pretrained_dict = torch.load(pretrained,
                                     map_location=map_location)['model']

        self.hmr.load_state_dict(pretrained_dict, strict=False)

        # regressor can predict cam, pose and shape params in an iterative way
        self.regressor = Regressor()

        if pretrained and os.path.isfile(pretrained):
            self.regressor.load_state_dict(pretrained_dict, strict=False)
            print(f'=> loaded pretrained model from \'{pretrained}\'')
Ejemplo n.º 2
0
    def __init__(
            self,
            seqlen,
            batch_size=64,
            n_layers=1,
            hidden_size=2048,
            add_linear=False,
            bidirectional=False,
            use_residual=True,
            pretrained=osp.join(VIBE_DATA_DIR,
                                'spin_model_checkpoint.pth.tar'),
    ):
        """Create the temporal encoder, the HMR backbone and the regressor.

        ``seqlen`` and ``batch_size`` are kept on the instance; ``n_layers``,
        ``hidden_size``, ``bidirectional``, ``add_linear`` and
        ``use_residual`` are handed to ``TemporalEncoder``. ``pretrained`` is
        accepted but not read anywhere in this constructor.
        """
        super(e2e_VIBE, self).__init__()
        self.seqlen, self.batch_size = seqlen, batch_size

        # Sub-modules are created in the order encoder -> hmr -> regressor.
        encoder_options = {
            'n_layers': n_layers,
            'hidden_size': hidden_size,
            'bidirectional': bidirectional,
            'add_linear': add_linear,
            'use_residual': use_residual,
        }
        self.encoder = TemporalEncoder(**encoder_options)
        self.hmr = hmr()
        self.regressor = Regressor()
Ejemplo n.º 3
0
 def __init__(
         self,
         vibe,
         cfg="zen_vis_5",
         pretrained=osp.join(VIBE_DATA_DIR, 'spin_model_checkpoint.pth.tar'),
 ):
     """Initialise the parent class and attach a pretrained HMR backbone.

     Args:
         vibe: Forwarded unchanged to the parent constructor.
         cfg: Forwarded unchanged to the parent constructor.
         pretrained: Path of the checkpoint whose ``'model'`` entry is
             loaded (non-strictly) into ``self.hmr``.
     """
     super().__init__(vibe, cfg)
     self.hmr = hmr()
     # Without CUDA, remap the stored tensors to the CPU; otherwise a
     # checkpoint saved from a GPU cannot be deserialized.
     map_location = (None if torch.cuda.is_available() else
                     torch.device('cpu'))
     checkpoint = torch.load(pretrained, map_location=map_location)
     self.hmr.load_state_dict(checkpoint['model'], strict=False)
Ejemplo n.º 4
0
Archivo: vibe.py Proyecto: L4zyy/CVIBE
    def __init__(
            self,
            seqlen,
            batch_size=64,
            n_layers=1,
            hidden_size=2048,
            add_linear=False,
            bidirectional=False,
            use_residual=True,
            pretrained=osp.join(VIBE_DATA_DIR, 'spin_model_checkpoint.pth.tar'),
            temporal_type='gru',
            n_head=0,
    ):
        """Build the temporal encoder, the HMR backbone and the regressor.

        Args:
            seqlen: Stored as ``self.seqlen`` and forwarded to
                ``TemporalEncoder``.
            batch_size: Stored as ``self.batch_size``.
            n_layers: Forwarded to ``TemporalEncoder``.
            hidden_size: Forwarded to ``TemporalEncoder``.
            add_linear: Forwarded to ``TemporalEncoder``.
            bidirectional: Forwarded to ``TemporalEncoder``.
            use_residual: Forwarded to ``TemporalEncoder``.
            pretrained: Path of the checkpoint whose ``'model'`` entry is
                loaded (non-strictly) into both ``self.hmr`` and
                ``self.regressor``.
            temporal_type: Forwarded to ``TemporalEncoder``.
            n_head: Forwarded to ``TemporalEncoder``.
        """
        super(SA_VIBE_Demo, self).__init__()

        self.seqlen = seqlen
        self.batch_size = batch_size

        self.encoder = TemporalEncoder(
            n_layers=n_layers,
            hidden_size=hidden_size,
            bidirectional=bidirectional,
            add_linear=add_linear,
            use_residual=use_residual,
            temporal_type=temporal_type,
            seqlen=self.seqlen,
            n_head=n_head,
        )

        self.hmr = hmr()

        # Read the checkpoint from disk exactly once and share its 'model'
        # state dict between the backbone and the regressor. Without CUDA the
        # tensors are remapped to the CPU so the load does not fail.
        map_location = (None if torch.cuda.is_available() else
                        torch.device('cpu'))
        pretrained_dict = torch.load(pretrained,
                                     map_location=map_location)['model']

        self.hmr.load_state_dict(pretrained_dict, strict=False)

        # regressor can predict cam, pose and shape params in an iterative way
        self.regressor = Regressor()

        if pretrained and os.path.isfile(pretrained):
            self.regressor.load_state_dict(pretrained_dict, strict=False)
            print(f'=> loaded pretrained model from \'{pretrained}\'')
Ejemplo n.º 5
0
    def forward(self, input, J_regressor=None):
        """Extract per-frame features, decode an initial pose and regress SMPL.

        Args:
            input: 5-D tensor unpacked as ``(batch_size, seqlen, nc, h, w)``.
            J_regressor: Passed through to ``self.regressor``.

        Returns:
            The sequence returned by ``self.regressor``; in every element the
            ``'theta'``, ``'verts'``, ``'kp_2d'``, ``'kp_3d'`` and
            ``'rotmat'`` entries are reshaped so that their two leading
            dimensions are ``(batch_size, seqlen)``.
        """
        # NOTE(review): the backbone is still rebuilt on every call, exactly
        # as before; only the checkpoint read is cached. Consider creating the
        # backbone once in the constructor instead.
        self.hmr_model = hmr()

        # Deserialize the checkpoint once per path instead of once per call.
        # It is mapped to the CPU so that loading does not depend on the
        # device it was saved from; the module is moved to the input device
        # below.
        state_cache = getattr(self, '_hmr_state_cache', None)
        if state_cache is None or state_cache[0] != self.pretrained:
            checkpoint = torch.load(self.pretrained,
                                    map_location=torch.device('cpu'))
            state_cache = (self.pretrained, checkpoint['model'])
            self._hmr_state_cache = state_cache
        self.hmr_model.load_state_dict(state_cache[1], strict=False)
        self.hmr_model.to(input.device)

        # input size NTF
        batch_size, seqlen, nc, h, w = input.shape
        # Fold time into the batch dimension for the per-frame backbone, then
        # restore the (batch, time, feature) layout.
        feature = self.hmr_model.feature_extractor(input.reshape(-1, nc, h, w))
        input = feature.reshape(batch_size, seqlen, -1)

        feature = self.feat_encoder(input)
        # Average the motion encoding over dimension 1.
        motion_z = self.motion_encoder(feature).mean(dim=1)
        vae_init_pose = self.vae_init_mlp(motion_z)
        X_r = self.vae_model.decode(vae_init_pose[None, :, :], motion_z)
        # Swap the two leading axes and keep at most ``seqlen`` steps.
        X_r = X_r.permute(1, 0, 2)[:, :seqlen, :]

        feature = feature.reshape(-1, feature.size(-1))

        init_pose = rotmat_to_6d(convert_orth_6d_to_mat(X_r)).reshape(
            -1, X_r.shape[-1])

        smpl_output = self.regressor(feature,
                                     J_regressor=J_regressor,
                                     init_pose=init_pose)

        # Unflatten every regressor output back to (batch, time, ...).
        for s in smpl_output:
            s['theta'] = s['theta'].reshape(batch_size, seqlen, -1)
            s['verts'] = s['verts'].reshape(batch_size, seqlen, -1, 3)
            s['kp_2d'] = s['kp_2d'].reshape(batch_size, seqlen, -1, 2)
            s['kp_3d'] = s['kp_3d'].reshape(batch_size, seqlen, -1, 3)
            s['rotmat'] = s['rotmat'].reshape(batch_size, seqlen, -1, 3, 3)

        return smpl_output
Ejemplo n.º 6
0
def main(args):
    """Run the TCMR demo on one video: track, regress meshes, render.

    The function reads these attributes from ``args``: ``vid_file``,
    ``tracker_batch_size``, ``display``, ``detector``, ``yolo_img_size``,
    ``model``, ``gender``, ``render_plain``, ``save_pkl``, ``wireframe``,
    ``sideview`` and ``save_obj``.

    Results are written below ``./output/demo_output/<video name>``; the
    temporary frame folders are removed at the end. Nothing is returned. The
    process exits early when the YouTube download fails or the input file
    does not exist.
    """
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # ========= Prepare input video (images) ========= #
    video_file = args.vid_file
    if video_file.startswith('https://www.youtube.com'):
        print(f"Donwloading YouTube video \'{video_file}\'")
        video_file = download_youtube_clip(video_file, '/tmp')
        if video_file is None:
            exit('Youtube url is not valid!')
        print(f"YouTube Video has been downloaded to {video_file}...")

    if not os.path.isfile(video_file):
        exit(f"Input video \'{video_file}\' does not exist!")

    output_path = osp.join('./output/demo_output',
                           os.path.basename(video_file).replace('.mp4', ''))
    Path(output_path).mkdir(parents=True, exist_ok=True)
    image_folder, num_frames, img_shape = video_to_images(video_file,
                                                          return_info=True)

    print(f"Input video number of frames {num_frames}\n")
    orig_height, orig_width = img_shape[:2]

    # ========= Run tracking ========= #
    total_start = time.time()
    bbox_scale = 1.2
    # run multi object tracker
    mot = MPT(
        device=device,
        batch_size=args.tracker_batch_size,
        display=args.display,
        detector_type=args.detector,
        output_format='dict',
        yolo_img_size=args.yolo_img_size,
    )
    tracking_results = mot(image_folder)

    # remove tracklets if num_frames is less than MIN_NUM_FRAMES
    for person_id in list(tracking_results.keys()):
        if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
            del tracking_results[person_id]

    # ========= Get TCMR model ========= #
    seq_len = 16
    model = TCMR(seqlen=seq_len, n_layers=2, hidden_size=1024).to(device)

    # Load pretrained weights. map_location keeps this working on CPU-only
    # hosts even when the checkpoint was saved from a GPU.
    pretrained_file = args.model
    ckpt = torch.load(pretrained_file, map_location=device)
    print(f"Load pretrained weights from \'{pretrained_file}\'")
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt, strict=False)

    # Change mesh gender
    gender = args.gender  # 'neutral', 'male', 'female'
    # Place the SMPL layer on the selected device instead of forcing CUDA.
    model.regressor.smpl = SMPL(SMPL_MODEL_DIR,
                                batch_size=64,
                                create_transl=False,
                                gender=gender).to(device)

    model.eval()

    # Get feature_extractor
    from lib.models.spin import hmr
    hmr_model = hmr().to(device)
    checkpoint = torch.load(
        osp.join(BASE_DATA_DIR, 'spin_model_checkpoint.pth.tar'),
        map_location=device)
    hmr_model.load_state_dict(checkpoint['model'], strict=False)
    hmr_model.eval()

    # ========= Run TCMR on each person ========= #
    print("\nRunning TCMR on each person tracklet...")
    tcmr_time = time.time()
    tcmr_results = {}
    for person_id in tqdm(list(tracking_results.keys())):
        # Only bounding boxes are available here; 2D joints are never set.
        joints2d = None
        bboxes = tracking_results[person_id]['bbox']
        frames = tracking_results[person_id]['frames']

        # Prepare static image features
        dataset = CropDataset(
            image_folder=image_folder,
            frames=frames,
            bboxes=bboxes,
            joints2d=joints2d,
            scale=bbox_scale,
        )

        bboxes = dataset.bboxes
        frames = dataset.frames
        # Always False in this function because joints2d is fixed to None.
        has_keypoints = joints2d is not None

        crop_dataloader = DataLoader(dataset, batch_size=256, num_workers=16)

        # Defined before the first loop so the keypoint branch below can
        # never hit an unbound name.
        norm_joints2d = []

        with torch.no_grad():
            feature_list = []
            for batch in crop_dataloader:
                if has_keypoints:
                    batch, nj2d = batch
                    norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

                batch = batch.to(device)
                feature = hmr_model.feature_extractor(
                    batch.reshape(-1, 3, 224, 224))
                # Keep the accumulated features on the CPU.
                feature_list.append(feature.cpu())

            del batch

            feature_list = torch.cat(feature_list, dim=0)

        # Encode temporal features and estimate 3D human mesh
        dataset = FeatureDataset(
            image_folder=image_folder,
            frames=frames,
            seq_len=seq_len,
        )
        dataset.feature_list = feature_list

        dataloader = DataLoader(dataset, batch_size=64, num_workers=32)
        with torch.no_grad():
            pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

            for batch in dataloader:
                if has_keypoints:
                    batch, nj2d = batch
                    norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

                batch = batch.to(device)
                output = model(batch)[0][-1]

                # theta is sliced as [0:3] cam, [3:75] pose, [75:] betas.
                pred_cam.append(output['theta'][:, :3])
                pred_verts.append(output['verts'])
                pred_pose.append(output['theta'][:, 3:75])
                pred_betas.append(output['theta'][:, 75:])
                pred_joints3d.append(output['kp_3d'])

            pred_cam = torch.cat(pred_cam, dim=0)
            pred_verts = torch.cat(pred_verts, dim=0)
            pred_pose = torch.cat(pred_pose, dim=0)
            pred_betas = torch.cat(pred_betas, dim=0)
            pred_joints3d = torch.cat(pred_joints3d, dim=0)

            del batch

        # ========= Save results to a pickle file ========= #
        pred_cam = pred_cam.cpu().numpy()
        pred_verts = pred_verts.cpu().numpy()
        pred_pose = pred_pose.cpu().numpy()
        pred_betas = pred_betas.cpu().numpy()
        pred_joints3d = pred_joints3d.cpu().numpy()

        # Scale columns 2+ of the boxes by the same factor that was passed to
        # CropDataset.
        bboxes[:, 2:] = bboxes[:, 2:] * bbox_scale
        if args.render_plain:
            pred_cam[:, 0], pred_cam[:, 1:] = 1, 0  # np.array([[1, 0, 0]])
        orig_cam = convert_crop_cam_to_orig_img(cam=pred_cam,
                                                bbox=bboxes,
                                                img_width=orig_width,
                                                img_height=orig_height)

        output_dict = {
            'pred_cam': pred_cam,
            'orig_cam': orig_cam,
            'verts': pred_verts,
            'pose': pred_pose,
            'betas': pred_betas,
            'joints3d': pred_joints3d,
            'joints2d': joints2d,
            'bboxes': bboxes,
            'frame_ids': frames,
        }

        tcmr_results[person_id] = output_dict

    del model

    end = time.time()
    fps = num_frames / (end - tcmr_time)
    print(f'TCMR FPS: {fps:.2f}')
    total_time = time.time() - total_start
    print(
        f'Total time spent: {total_time:.2f} seconds (including model loading time).'
    )
    print(
        f'Total FPS (including model loading time): {num_frames / total_time:.2f}.'
    )

    if args.save_pkl:
        pkl_path = os.path.join(output_path, 'tcmr_output.pkl')
        print(f"Saving output results to \'{pkl_path}\'.")
        joblib.dump(tcmr_results, pkl_path)

    # ========= Render results as a single video ========= #
    renderer = Renderer(resolution=(orig_width, orig_height),
                        orig_img=True,
                        wireframe=args.wireframe)

    output_img_folder = f'{image_folder}_output'
    input_img_folder = f'{image_folder}_input'
    os.makedirs(output_img_folder, exist_ok=True)
    os.makedirs(input_img_folder, exist_ok=True)

    print(f"\nRendering output video, writing frames to {output_img_folder}")
    # prepare results for rendering
    frame_results = prepare_rendering_results(tcmr_results, num_frames)
    # One random hue per tracked person.
    mesh_color = {
        k: colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0)
        for k in tcmr_results
    }

    image_file_names = sorted([
        os.path.join(image_folder, x) for x in os.listdir(image_folder)
        if x.endswith('.png') or x.endswith('.jpg')
    ])

    for frame_idx, img_fname in enumerate(tqdm(image_file_names)):
        img = cv2.imread(img_fname)
        input_img = img.copy()
        if args.render_plain:
            img[:] = 0

        if args.sideview:
            side_img = np.zeros_like(img)

        for person_id, person_data in frame_results[frame_idx].items():
            frame_verts = person_data['verts']
            frame_cam = person_data['cam']

            mesh_filename = None
            if args.save_obj:
                mesh_folder = os.path.join(output_path, 'meshes',
                                           f'{person_id:04d}')
                Path(mesh_folder).mkdir(parents=True, exist_ok=True)
                mesh_filename = os.path.join(mesh_folder,
                                             f'{frame_idx:06d}.obj')

            mc = mesh_color[person_id]

            img = renderer.render(
                img,
                frame_verts,
                cam=frame_cam,
                color=mc,
                mesh_filename=mesh_filename,
            )
            if args.sideview:
                side_img = renderer.render(
                    side_img,
                    frame_verts,
                    cam=frame_cam,
                    color=mc,
                    angle=270,
                    axis=[0, 1, 0],
                )

        if args.sideview:
            img = np.concatenate([img, side_img], axis=1)

        # save output frames
        cv2.imwrite(os.path.join(output_img_folder, f'{frame_idx:06d}.jpg'),
                    img)
        cv2.imwrite(os.path.join(input_img_folder, f'{frame_idx:06d}.jpg'),
                    input_img)

        if args.display:
            cv2.imshow('Video', img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    if args.display:
        cv2.destroyAllWindows()

    # ========= Save rendered video ========= #
    vid_name = os.path.basename(video_file)
    save_name = f'tcmr_{vid_name.replace(".mp4", "")}_output.mp4'
    save_path = os.path.join(output_path, save_name)

    images_to_video(img_folder=output_img_folder, output_vid_file=save_path)
    images_to_video(img_folder=input_img_folder,
                    output_vid_file=os.path.join(output_path, vid_name))
    print(f"Saving result video to {os.path.abspath(save_path)}")
    # Remove the temporary frame folders.
    shutil.rmtree(output_img_folder)
    shutil.rmtree(input_img_folder)
    shutil.rmtree(image_folder)