Example #1
    def fit(self, keypoints3d, dtype='coco', verbose=True):
        """Run fitting to optimize the SMPL parameters."""
        assert dtype == 'coco', 'only support coco format for now.'
        assert len(
            keypoints3d.shape) == 3, 'input shape should be [N, njoints, 3]'
        mapping_target = unify_joint_mappings(dataset=dtype)
        keypoints3d = keypoints3d[:, mapping_target, :]
        keypoints3d = torch.from_numpy(keypoints3d).float().to(self.device)
        batch_size, njoints = keypoints3d.shape[0:2]

        # Init learnable smpl model
        smpl = SMPL(model_path=self.smpl_model_path,
                    gender=self.smpl_model_gender,
                    batch_size=batch_size).to(self.device)

        # Start fitting
        for step in range(self.niter):
            optimizer = self.get_optimizer(smpl, step, self.base_lr)

            output = smpl.forward()
            joints = output.joints[:, self.joints_mapping_smpl[:njoints], :]
            loss = self.metric(joints, keypoints3d)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose and step % 10 == 0:
                logging.info(f'step {step:03d}; loss {loss.item():.3f};')

        # Return results
        return smpl, loss.item()
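A minimal usage sketch for the fit method above, assuming it lives on a fitter class (SMPLRegressor is a hypothetical name here) constructed with an SMPL model path and gender:

# Hypothetical usage sketch; `SMPLRegressor`, its constructor arguments, and the input file are assumed.
import numpy as np

keypoints3d = np.load('keypoints3d_coco.npy')  # assumed input file, shape (N, njoints, 3), COCO joint order
fitter = SMPLRegressor(smpl_model_path='data/smpl',
                       smpl_model_gender='neutral')
smpl_model, final_loss = fitter.fit(keypoints3d, dtype='coco', verbose=True)
print(f'final fitting loss: {final_loss:.3f}')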
Example #2
def compute_pve_neutral_pose_scale_corrected(predicted_smpl_shape,
                                             target_smpl_shape, gender):
    """
    Given predicted and target SMPL shape parameters, computes neutral-pose per-vertex error
    after scale-correction (to account for scale vs camera depth ambiguity).
    :param predicted_smpl_shape: predicted SMPL shape parameters tensor with shape (1, 10)
    :param target_smpl_shape: target SMPL shape parameters tensor with shape (1, 10)
    :param gender: gender of target
    """
    smpl_male = SMPL(config.SMPL_MODEL_DIR, batch_size=1, gender='male')
    smpl_female = SMPL(config.SMPL_MODEL_DIR, batch_size=1, gender='female')

    # Get neutral pose vertices
    if gender == 'm':
        pred_smpl_neutral_pose_output = smpl_male(betas=predicted_smpl_shape)
        target_smpl_neutral_pose_output = smpl_male(betas=target_smpl_shape)
    elif gender == 'f':
        pred_smpl_neutral_pose_output = smpl_female(betas=predicted_smpl_shape)
        target_smpl_neutral_pose_output = smpl_female(betas=target_smpl_shape)
    else:
        raise ValueError(f"gender must be 'm' or 'f', got: {gender}")

    pred_smpl_neutral_pose_vertices = pred_smpl_neutral_pose_output.vertices
    target_smpl_neutral_pose_vertices = target_smpl_neutral_pose_output.vertices

    # Rescale such that RMSD of predicted vertex mesh is the same as RMSD of target mesh.
    # This is done to combat scale vs camera depth ambiguity.
    pred_smpl_neutral_pose_vertices_rescale = scale_and_translation_transform_batch(
        pred_smpl_neutral_pose_vertices, target_smpl_neutral_pose_vertices)

    # Compute PVE-T-SC
    pve_neutral_pose_scale_corrected = np.linalg.norm(
        pred_smpl_neutral_pose_vertices_rescale -
        target_smpl_neutral_pose_vertices,
        axis=-1)  # (1, 6890)

    return pve_neutral_pose_scale_corrected
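The helper scale_and_translation_transform_batch is not shown here; a minimal sketch consistent with the comment above (rescale and translate each predicted vertex set so its RMS spread and centroid match the target's) might look like this, assuming (B, N, 3) numpy inputs:

def scale_and_translation_transform_batch(P, T):
    """Sketch only: align each point set P[i] to T[i] in scale and translation.

    P, T: np.ndarray of shape (B, N, 3).
    """
    P_mean = P.mean(axis=1, keepdims=True)
    T_mean = T.mean(axis=1, keepdims=True)
    P_centered = P - P_mean
    T_centered = T - T_mean
    # Ratio of RMS distances from the centroid gives the scale correction.
    P_rms = np.sqrt((P_centered ** 2).sum(axis=(1, 2), keepdims=True) / P.shape[1])
    T_rms = np.sqrt((T_centered ** 2).sum(axis=(1, 2), keepdims=True) / T.shape[1])
    return P_centered * (T_rms / P_rms) + T_mean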
Example #3
    def __init__(self, smpl_path, joints_regressor):
        super().__init__()

        assert has_smpl, 'Please install smplx to use SMPL.'

        self.smpl_neutral = SMPL_(
            model_path=smpl_path,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='neutral')

        self.smpl_male = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='male')

        self.smpl_female = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='female')

        joints_regressor = torch.tensor(
            np.load(joints_regressor), dtype=torch.float)[None, ...]
        self.register_buffer('joints_regressor', joints_regressor)

        self.num_verts = self.smpl_neutral.get_num_verts()
        self.num_joints = self.joints_regressor.shape[1]
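For reference, the registered joints_regressor buffer is later applied to the mesh vertices to regress 3D joints, as Example #12 does verbatim; in sketch form:

# Sketch: joints_regressor has shape (1, K, V), vertices has shape (B, V, 3);
# broadcasting over the batch dimension yields joints of shape (B, K, 3).
joints = torch.matmul(joints_regressor, vertices)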
Example #4
def main(_):
  # Parsing data info.
  aist_dataset = AISTDataset(FLAGS.anno_dir)
  video_path = os.path.join(FLAGS.video_dir, f'{FLAGS.video_name}.mp4')
  seq_name, view = AISTDataset.get_seq_name(FLAGS.video_name)
  view_idx = AISTDataset.VIEWS.index(view)

  # Parsing keypoints.
  if FLAGS.mode == '2D':  # raw keypoints detection results.
    keypoints2d, _, _ = AISTDataset.load_keypoint2d(
        aist_dataset.keypoint2d_dir, seq_name)
    keypoints2d = keypoints2d[view_idx, :, :, 0:2]

  elif FLAGS.mode == '3D':  # 3D keypoints with temporal optimization.
    keypoints3d = AISTDataset.load_keypoint3d(
        aist_dataset.keypoint3d_dir, seq_name, use_optim=True)
    nframes, njoints, _ = keypoints3d.shape
    env_name = aist_dataset.mapping_seq2env[seq_name]
    cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir, env_name)
    keypoints2d = cgroup.project(keypoints3d)
    keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

  elif FLAGS.mode == 'SMPL':  # SMPL joints
    smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
        aist_dataset.motion_dir, seq_name)
    smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
    keypoints3d = smpl.forward(
        global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
        body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
        transl=torch.from_numpy(smpl_trans).float(),
        scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
        ).joints.detach().numpy()

    nframes, njoints, _ = keypoints3d.shape
    env_name = aist_dataset.mapping_seq2env[seq_name]
    cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir, env_name)
    keypoints2d = cgroup.project(keypoints3d)
    keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

  # Visualize.
  os.makedirs(FLAGS.save_dir, exist_ok=True)
  save_path = os.path.join(FLAGS.save_dir, f'{FLAGS.video_name}.mp4')
  plot_on_video(keypoints2d, video_path, save_path, fps=60)
Example #5
    def fit(self, keypoints3d, dtype='coco', verbose=True):
        """Run fitting to optimize the SMPL parameters."""
        assert dtype == 'coco', 'only support coco format for now.'
        assert len(
            keypoints3d.shape) == 3, 'input shape should be [N, njoints, 3]'
        mapping_target = unify_joint_mappings(dataset=dtype)
        keypoints3d = keypoints3d[:, mapping_target, :]
        keypoints3d = torch.from_numpy(keypoints3d).float().to(self.device)
        batch_size, njoints = keypoints3d.shape[0:2]

        # Init learnable smpl model
        smpl = SMPL(model_path=self.smpl_model_path,
                    gender=self.smpl_model_gender,
                    batch_size=batch_size).to(self.device)

        # Start fitting
        for step in range(self.niter):
            optimizer = self.get_optimizer(smpl, step, self.base_lr)

            output = smpl.forward()
            joints = output.joints[:, self.joints_mapping_smpl[:njoints], :]
            loss = self.metric(joints, keypoints3d)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose and step % 10 == 0:
                logging.info(f'step {step:03d}; loss {loss.item():.3f};')

            if FLAGS.visualize:
                vertices = output.vertices[0].detach().cpu().numpy()  # first frame
                mesh = trimesh.Trimesh(vertices, smpl.faces)
                mesh.visual.face_colors = [200, 200, 250, 100]
                pts = vedo.Points(keypoints3d[0].detach().cpu().numpy(),
                                  r=20)  # first frame
                vedo.show(mesh, pts, interactive=False)

        # Return results
        return smpl, loss.item()
Example #6
def load_dance_data(dance_dir):
    print('---------- Loading pose keypoints ----------')
    aist_dataset = AISTDataset(dance_dir)
    seq_names = list(aist_dataset.mapping_seq2env.keys())
    print(seq_names)

    dances = {}

    for seq_name in tqdm(seq_names):
        print(f'Process -> {seq_name}')
        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
            aist_dataset.motion_dir, seq_name)
        smpl = SMPL(model_path=args.smpl_dir, gender='MALE', batch_size=1)
        keypoints3d = smpl.forward(
            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
            transl=torch.from_numpy(smpl_trans / smpl_scaling).float(),
        ).joints.detach().numpy()[:, 0:24, :]
        nframes = keypoints3d.shape[0]
        dances[seq_name] = keypoints3d.reshape(nframes, -1).tolist()
        print(np.shape(dances[seq_name]))  # (nframes, 72)

    return dances
Example #7
def get_verts(theta, smpl_path):
    device = 'cpu'
    smpl = SMPL(smpl_path, batch_size=1).to(device)

    pose, betas = theta[:, :72], theta[:, 72:]

    verts = []
    b_ = torch.split(betas, 500)
    p_ = torch.split(pose, 500)

    for b, p in zip(b_, p_):
        output = smpl(betas=b,
                      body_pose=p[:, 3:],
                      global_orient=p[:, :3],
                      pose2rot=True)
        verts.append(output.vertices.detach().cpu().numpy())

    verts = np.concatenate(verts, axis=0)
    del smpl
    return verts
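A brief usage sketch for get_verts, based on the split above (each theta row is assumed to hold 72 axis-angle pose values followed by 10 shape values):

# Sketch only: the model path below is a placeholder.
theta = torch.zeros(4, 82)                       # 72 pose + 10 betas per row
verts = get_verts(theta, smpl_path='data/smpl')  # assumed SMPL model directory
print(verts.shape)                               # expected (4, 6890, 3) for the standard SMPL mesh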
Example #8
    def __init__(self,
                 backbone,
                 mesh_head,
                 smpl,
                 disc=None,
                 loss_gan=None,
                 loss_mesh=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()

        assert has_smpl, 'Please install smplx to use SMPL.'

        self.backbone = builder.build_backbone(backbone)
        self.mesh_head = builder.build_head(mesh_head)
        self.generator = torch.nn.Sequential(self.backbone, self.mesh_head)

        self.smpl = SMPL(
            model_path=smpl['smpl_path'],
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False)

        joints_regressor = torch.tensor(
            np.load(smpl['joints_regressor']), dtype=torch.float).unsqueeze(0)
        self.register_buffer('joints_regressor', joints_regressor)

        self.with_gan = disc is not None and loss_gan is not None
        if self.with_gan:
            self.discriminator = SMPLDiscriminator(**disc)
            self.loss_gan = builder.build_loss(loss_gan)
        self.disc_step_count = 0

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.loss_mesh = builder.build_loss(loss_mesh)
        self.init_weights(pretrained=pretrained)
Example #9
def visEFT_singleSubject(renderer):
    inputDir = args.fit_dir
    imgDir = args.img_dir

    smplModelDir = args.smpl_dir
    smpl = SMPL(smplModelDir, batch_size=1, create_transl=False)

    # outputFolder = os.path.basename(inputDir) + '_dpOut'
    # outputFolder =os.path.join('/run/media/hjoo/disk/data/eftout/',outputFolder)
    
    eft_fileList  = listdir(inputDir)       #Check all fitting files
    print(">> Found {} files in the fitting folder {}".format(len(eft_fileList), inputDir))
    totalCnt =0
    erroneousCnt =0

    for idx, f in enumerate(sorted(eft_fileList)):
        
        #Load EFT data
        fileFullPath = join(inputDir, f)
        with open(fileFullPath,'rb') as f:
            eft_data = pickle.load(f)
        totalCnt += 1  # count every processed fitting file (used in the summary print below)

        #Get raw image path
        imgFullPath = eft_data['imageName'][0]
        imgName = os.path.basename(imgFullPath)
        imgFullPath =os.path.join(imgDir, os.path.basename(imgFullPath) )
        assert os.path.exists(imgFullPath)
        rawImg = cv2.imread(imgFullPath)
        print(f'Input image: {imgFullPath}')

        #EFT data
        bbox_scale = eft_data['scale'][0]
        bbox_center = eft_data['center'][0]

        pred_camera = eft_data['pred_camera']
        pred_betas = torch.from_numpy(eft_data['pred_shape'])
        pred_pose_rotmat = torch.from_numpy(eft_data['pred_pose_rotmat'])        

        #COCO only. Annotation index
        print("COCO annotId: {}".format(eft_data['annotId']))

        #Obtain skeleton and smpl data
        smpl_output = smpl(betas=pred_betas, body_pose=pred_pose_rotmat[:,1:], global_orient=pred_pose_rotmat[:,0].unsqueeze(1), pose2rot=False )
        smpl_vertices = smpl_output.vertices.detach().cpu().numpy() 
        smpl_joints_3d = smpl_output.joints.detach().cpu().numpy() 

        #Crop image
        croppedImg, boxScale_o2n, bboxTopLeft = crop_bboxInfo(rawImg, bbox_center, bbox_scale, (BBOX_IMG_RES, BBOX_IMG_RES) )

        ########################
        # Visualize
        if False:
            #Compute 2D reprojection error
            # if not (data['loss_keypoints_2d']<0.0001 or data['loss_keypoints_2d']>0.001 :
            #     continue
            maxBeta = abs(torch.max( abs(pred_betas)).item())
            if eft_data['loss_keypoints_2d']>0.0005 or maxBeta>3:
                erroneousCnt +=1
            print(">>> loss2d: {}, maxBeta: {}".format( eft_data['loss_keypoints_2d'],maxBeta) )
        
        # Visualize 2D image
        if False:
            viewer2D.ImShow(rawImg, name='rawImg', waitTime=1)      #You should press any key 
            viewer2D.ImShow(croppedImg, name='croppedImg', waitTime=1)

        # Visualization Mesh
        if True:    
            b=0
            camParam_scale = pred_camera[b,0]
            camParam_trans = pred_camera[b,1:]
            pred_vert_vis = smpl_vertices[b]
            smpl_joints_3d_vis = smpl_joints_3d[b]

            if args.onbbox:
                pred_vert_vis = convert_smpl_to_bbox(pred_vert_vis, camParam_scale, camParam_trans)
                smpl_joints_3d_vis = convert_smpl_to_bbox(smpl_joints_3d_vis, camParam_scale, camParam_trans)
                renderer.setBackgroundTexture(croppedImg)
                renderer.setViewportSize(croppedImg.shape[1], croppedImg.shape[0])
            else:
                #Convert SMPL to BBox first
                pred_vert_vis = convert_smpl_to_bbox(pred_vert_vis, camParam_scale, camParam_trans)
                smpl_joints_3d_vis = convert_smpl_to_bbox(smpl_joints_3d_vis, camParam_scale, camParam_trans)

                #From cropped space to original
                pred_vert_vis = convert_bbox_to_oriIm(pred_vert_vis, boxScale_o2n, bboxTopLeft, rawImg.shape[1], rawImg.shape[0]) 
                smpl_joints_3d_vis = convert_bbox_to_oriIm(smpl_joints_3d_vis, boxScale_o2n, bboxTopLeft, rawImg.shape[1], rawImg.shape[0])
                renderer.setBackgroundTexture(rawImg)
                renderer.setViewportSize(rawImg.shape[1], rawImg.shape[0])

            pred_meshes = {'ver': pred_vert_vis, 'f': smpl.faces}
            v = pred_meshes['ver'] 
            f = pred_meshes['f']

            #Visualize in the original image space
            renderer.set_mesh(v,f)
            renderer.showBackground(True)
            renderer.setWorldCenterBySceneCenter()
            renderer.setCameraViewMode("cam")

            renderer.setViewportSize(rawImg.shape[1], rawImg.shape[0])
            renderer.display()
            renderImg = renderer.get_screen_color_ibgr()
            viewer2D.ImShow(renderImg,waitTime=1)

            # out_all_f = render.get_z_value()

        # Visualization Mesh on side view
        if True:
            # renderer.set_viewpoint()
            renderer.showBackground(False)
            renderer.setWorldCenterBySceneCenter()
            renderer.setCameraViewMode("side")

            renderer.setViewportSize(rawImg.shape[1], rawImg.shape[0])
            renderer.display()
            sideImg = renderer.get_screen_color_ibgr()        #Overwrite on rawImg
            viewer2D.ImShow(sideImg,waitTime=1)
            
            sideImg = cv2.resize(sideImg, (renderImg.shape[1], renderImg.shape[0]) )

        saveImg = np.concatenate( (renderImg,sideImg), axis =1)
        viewer2D.ImShow(saveImg,waitTime=0)

        if True:    #Save the rendered image to files
            if os.path.exists(args.render_dir) == False:
                os.mkdir(args.render_dir)
            render_output_path = args.render_dir + '/render_{:08d}.jpg'.format(idx)
            print(f"Save to {render_output_path}")
            cv2.imwrite(render_output_path, saveImg)

    print("erroneous Num : {}/{} ({} percent)".format(erroneousCnt,totalCnt, float(erroneousCnt)*100/totalCnt))
Example #10
def visEFT_multiSubjects(renderer):
    inputDir = args.fit_dir
    imgDir = args.img_dir

    smplModelDir = args.smpl_dir
    smpl = SMPL(smplModelDir, batch_size=1, create_transl=False)
    
    eft_fileList  = listdir(inputDir)       #Check all fitting files
    print(">> Found {} files in the fitting folder {}".format(len(eft_fileList), inputDir))

    #Aggregate all EFT files per image
    eft_perimage ={}
    for f in sorted(eft_fileList):
        #Load
        imageName = f[:f.rfind('_')]
        if imageName not in eft_perimage.keys():
            eft_perimage[imageName] =[]

        eft_perimage[imageName].append(f)


    for imgName in eft_perimage:
        eftFiles_perimage = eft_perimage[imgName]
        
        renderer.clear_mesh()

        for idx,f in enumerate(eftFiles_perimage):
            
            #Load EFT data
            fileFullPath = join(inputDir, f)
            with open(fileFullPath,'rb') as f:
                eft_data = pickle.load(f)

            #Get raw image path
            if idx==0:
                imgFullPath = eft_data['imageName'][0]
                imgFullPath =os.path.join(imgDir, os.path.basename(imgFullPath) )
                assert os.path.exists(imgFullPath)
                rawImg = cv2.imread(imgFullPath)
                print(f'Input image: {imgFullPath}')

            #EFT data
            bbox_scale = eft_data['scale'][0]
            bbox_center = eft_data['center'][0]

            pred_camera = eft_data['pred_camera']
            pred_betas = torch.from_numpy(eft_data['pred_shape'])
            pred_pose_rotmat = torch.from_numpy(eft_data['pred_pose_rotmat'])        

            #COCO only. Annotation index
            print("COCO annotId: {}".format(eft_data['annotId']))

            #Obtain skeleton and smpl data
            smpl_output = smpl(betas=pred_betas, body_pose=pred_pose_rotmat[:,1:], global_orient=pred_pose_rotmat[:,0].unsqueeze(1), pose2rot=False )
            smpl_vertices = smpl_output.vertices.detach().cpu().numpy() 
            smpl_joints_3d = smpl_output.joints.detach().cpu().numpy() 

            #Crop image
            croppedImg, boxScale_o2n, bboxTopLeft = crop_bboxInfo(rawImg.copy(), bbox_center, bbox_scale, (BBOX_IMG_RES, BBOX_IMG_RES) )

            ########################
            # Visualize
            # Visualize 2D image
            if False:
                viewer2D.ImShow(rawImg, name='rawImg', waitTime=1)      #You should press any key 
                viewer2D.ImShow(croppedImg, name='croppedImg', waitTime=0)

            # Visualization Mesh on raw images
            if True:    
                b=0
                camParam_scale = pred_camera[b,0]
                camParam_trans = pred_camera[b,1:]
                pred_vert_vis = smpl_vertices[b]
                smpl_joints_3d_vis = smpl_joints_3d[b]

                if False:#args.onbbox:      #Always in the original image
                    pred_vert_vis = convert_smpl_to_bbox(pred_vert_vis, camParam_scale, camParam_trans)
                    smpl_joints_3d_vis = convert_smpl_to_bbox(smpl_joints_3d_vis, camParam_scale, camParam_trans)
                    renderer.setBackgroundTexture(croppedImg)
                    renderer.setViewportSize(croppedImg.shape[1], croppedImg.shape[0])
                else:
                    #Convert SMPL to BBox first
                    pred_vert_vis = convert_smpl_to_bbox(pred_vert_vis, camParam_scale, camParam_trans)
                    smpl_joints_3d_vis = convert_smpl_to_bbox(smpl_joints_3d_vis, camParam_scale, camParam_trans)

                    #From cropped space to original
                    pred_vert_vis = convert_bbox_to_oriIm(pred_vert_vis, boxScale_o2n, bboxTopLeft, rawImg.shape[1], rawImg.shape[0]) 
                    smpl_joints_3d_vis = convert_bbox_to_oriIm(smpl_joints_3d_vis, boxScale_o2n, bboxTopLeft, rawImg.shape[1], rawImg.shape[0])
                    renderer.setBackgroundTexture(rawImg)
                    renderer.setViewportSize(rawImg.shape[1], rawImg.shape[0])

                pred_meshes = {'ver': pred_vert_vis, 'f': smpl.faces}
                v = pred_meshes['ver'] 
                f = pred_meshes['f']

                #Visualize in the original image space
                # renderer.set_mesh(v,f)
                renderer.add_mesh(v,f)

        #Render Mesh on the camera view
        renderer.showBackground(True)
        renderer.setWorldCenterBySceneCenter()
        renderer.setCameraViewMode("cam")
        renderer.display()
        overlaid = renderer.get_screen_color_ibgr()        #Overwrite on rawImg
        viewer2D.ImShow(overlaid,waitTime=1,name="overlaid")

        #Render Mesh on the rotating view
        renderer.showBackground(False)
        renderer.setWorldCenterBySceneCenter()
        renderer.setCameraViewMode("free")
        for i in range(90):
            renderer.setViewAngle(i*4,0)
            renderer.display()
            sideImg = renderer.get_screen_color_ibgr()        #Overwrite on rawImg
            viewer2D.ImShow(sideImg,waitTime=1,name="otherviews")
            
        if True:    #Save the rendered image to files
            if os.path.exists(args.render_dir) == False:
                os.mkdir(args.render_dir)
            render_output_path = args.render_dir + '/render_{}.jpg'.format(imgName)
            print(f"Save to {render_output_path}")
            cv2.imwrite(render_output_path, overlaid)  # save the rendered camera-view overlay
Example #11
 def __init__(self, focal_length=1000, height=512, width=512):
     self.renderer = pyrender.OffscreenRenderer(height, width)
     smpl = SMPL('data/smpl')
     self.faces = smpl.faces
     self.focal_length = focal_length
Example #12
class SMPL(nn.Module):
    """SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned
    multi-person linear model''. This module is based on the smplx project
    (https://github.com/vchoutas/smplx).

    Args:
        smpl_path (str): The path to the folder where the model weights are
            stored.
        joints_regressor (str): The path to the file where the joints
            regressor weights are stored.
    """

    def __init__(self, smpl_path, joints_regressor):
        super().__init__()

        assert has_smpl, 'Please install smplx to use SMPL.'

        self.smpl_neutral = SMPL_(
            model_path=smpl_path,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='neutral')

        self.smpl_male = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='male')

        self.smpl_female = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='female')

        joints_regressor = torch.tensor(
            np.load(joints_regressor), dtype=torch.float)[None, ...]
        self.register_buffer('joints_regressor', joints_regressor)

        self.num_verts = self.smpl_neutral.get_num_verts()
        self.num_joints = self.joints_regressor.shape[1]

    def smpl_forward(self, model, **kwargs):
        """Apply a specific SMPL model with given model parameters.

        Note:
            B: batch size
            V: number of vertices
            K: number of joints

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed
                    from mesh vertices.
        """

        betas = kwargs['betas']
        batch_size = betas.shape[0]
        device = betas.device
        output = {}
        if batch_size == 0:
            output['vertices'] = betas.new_zeros([0, self.num_verts, 3])
            output['joints'] = betas.new_zeros([0, self.num_joints, 3])
        else:
            smpl_out = model(**kwargs)
            output['vertices'] = smpl_out.vertices
            output['joints'] = torch.matmul(
                self.joints_regressor.to(device), output['vertices'])
        return output

    def get_faces(self):
        """Return mesh faces.

        Note:
            F: number of faces

        Returns:
            faces: np.ndarray([F, 3]), mesh faces
        """
        return self.smpl_neutral.faces

    def forward(self,
                betas,
                body_pose,
                global_orient,
                transl=None,
                gender=None):
        """Forward function.

        Note:
            B: batch size
            J: number of controllable joints of model, for smpl model J=23
            K: number of joints

        Args:
            betas: Tensor([B, 10]), human body shape parameters of SMPL model.
            body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
                parameters of SMPL model. It should be axis-angle vector
                ([B, J*3]) or rotation matrix ([B, J, 3, 3]).
            global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
                of human body. It should be axis-angle vector ([B, 3]) or
                rotation matrix ([B, 1, 3, 3]).
            transl: Tensor([B, 3]), global translation of human body.
            gender: Tensor([B]), gender parameters of human body. -1 for
                neutral, 0 for male, 1 for female.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed from
                    mesh vertices.
        """

        batch_size = betas.shape[0]
        pose2rot = True if body_pose.dim() == 2 else False
        if batch_size > 0 and gender is not None:
            output = {
                'vertices': betas.new_zeros([batch_size, self.num_verts, 3]),
                'joints': betas.new_zeros([batch_size, self.num_joints, 3])
            }

            mask = gender < 0
            _out = self.smpl_forward(
                self.smpl_neutral,
                betas=betas[mask],
                body_pose=body_pose[mask],
                global_orient=global_orient[mask],
                transl=transl[mask] if transl is not None else None,
                pose2rot=pose2rot)
            output['vertices'][mask] = _out['vertices']
            output['joints'][mask] = _out['joints']

            mask = gender == 0
            _out = self.smpl_forward(
                self.smpl_male,
                betas=betas[mask],
                body_pose=body_pose[mask],
                global_orient=global_orient[mask],
                transl=transl[mask] if transl is not None else None,
                pose2rot=pose2rot)
            output['vertices'][mask] = _out['vertices']
            output['joints'][mask] = _out['joints']

            mask = gender == 1
            _out = self.smpl_forward(
                self.smpl_female,
                betas=betas[mask],
                body_pose=body_pose[mask],
                global_orient=global_orient[mask],
                transl=transl[mask] if transl is not None else None,
                pose2rot=pose2rot)
            output['vertices'][mask] = _out['vertices']
            output['joints'][mask] = _out['joints']
        else:
            return self.smpl_forward(
                self.smpl_neutral,
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl,
                pose2rot=pose2rot)

        return output
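A minimal usage sketch for the gendered forward pass above; the model paths are placeholders and a real joints regressor file is required:

# Sketch only: paths below are placeholders.
model = SMPL(smpl_path='data/smpl', joints_regressor='data/J_regressor.npy')
betas = torch.zeros(2, 10)
body_pose = torch.zeros(2, 69)     # axis-angle, 23 joints * 3
global_orient = torch.zeros(2, 3)
gender = torch.tensor([-1, 1])     # -1: neutral, 1: female (per the docstring)
out = model(betas, body_pose, global_orient, gender=gender)
print(out['vertices'].shape, out['joints'].shape)  # (2, 6890, 3), (2, K, 3)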
Example #13
import torch
# import matplotlib
# matplotlib.use('MACOSX')
import matplotlib.pyplot as plt
from smplx import SMPL

from utils.renderer import Renderer
from utils.cam_utils import perspective_project_torch
from data.ssp3d_dataset import SSP3DDataset
import config

# SMPL models in torch
smpl_male = SMPL(config.SMPL_MODEL_DIR, batch_size=1, gender='male')
smpl_female = SMPL(config.SMPL_MODEL_DIR, batch_size=1, gender='female')

# Pyrender renderer
renderer = Renderer(faces=smpl_male.faces, img_res=512)

# SSP-3D dataset class
ssp3d_dataset = SSP3DDataset(config.SSP_3D_PATH)

indices_to_plot = [11, 60, 199]  # Visualising 3 examples from SSP-3D

for i in indices_to_plot:
    data = ssp3d_dataset[i]

    fname = data['fname']
    image = data['image']
    cropped_image = data['cropped_image']
    silhouette = data['silhouette']
    joints2D = data['joints2D']
Example #14
def main():
    args = get_args()
    if args.aist:
        import vedo

    if not os.path.exists(args.json_dir):
        os.makedirs(args.json_dir)

    if args.aist:
        print ("test with AIST++ dataset!")
        music_data, dance_data, dance_names = load_data_aist(
            args.input_dir, interval=None, rotmat=args.rotmat)
    else:    
        music_data, dance_data, dance_names = load_data(
            args.input_dir, interval=None)

    device = torch.device('cuda' if args.cuda else 'cpu')

    test_loader = torch.utils.data.DataLoader(
        DanceDataset(music_data, dance_data),
        batch_size=args.batch_size,
        collate_fn=paired_collate_fn
    )

    generator = Generator(args.model, device)
    
    if args.aist and args.rotmat:
        from smplx import SMPL
        smpl = SMPL(model_path="/media/ruilongli/hd1/Data/smpl/", gender='MALE', batch_size=1)

    results = []
    random_id = 0  # np.random.randint(0, 1e4)
    for i, batch in enumerate(tqdm(test_loader, desc='Generating dance poses')):
        # Prepare data
        src_seq, src_pos, tgt_pose = map(lambda x: x.to(device), batch)
        pose_seq = generator.generate(src_seq[:, :1200], src_pos[:, :1200])  # first 20 secs
        results.append(pose_seq)

        if args.aist:
            np_dance = pose_seq[0].data.cpu().numpy()
            if args.rotmat:
                root = np_dance[:, :3]
                rotmat = np_dance[:, 3:].reshape([-1, 3, 3])
                rotmat = get_closest_rotmat(rotmat)
                smpl_poses = rotmat2aa(rotmat).reshape(-1, 24, 3)
                np_dance = smpl.forward(
                    global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
                    body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
                    transl=torch.from_numpy(root).float(),
                ).joints.detach().numpy()[:, 0:24, :]
            else:
                root = np_dance[:, :3]
                np_dance = np_dance + np.tile(root, (1, 24))
                np_dance[:, :3] = root
                np_dance = np_dance.reshape(np_dance.shape[0], -1, 3)
            print (np_dance.shape)
            # save
            save_path = os.path.join(args.json_dir, dance_names[i]+f"_{random_id:04d}")
            np.save(save_path, np_dance)
            # visualize
            for frame in np_dance:
                pts = vedo.Points(frame, r=20)
                vedo.show(pts, interactive=False)
                time.sleep(0.1)
            exit()

    if args.aist:
        pass

    else:
        # Visualize generated dance poses
        np_dances = []
        for i in range(len(results)):
            np_dance = results[i][0].data.cpu().numpy()
            root = np_dance[:, 2*8:2*9]
            np_dance = np_dance + np.tile(root, (1, 25))
            np_dance[:, 2*8:2*9] = root
            np_dances.append(np_dance)
        write2json(np_dances, dance_names, args)
        visualize(args)
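The helpers get_closest_rotmat and rotmat2aa used in the rotation-matrix branch above are not shown. A sketch of an axis-angle conversion consistent with how rotmat2aa is called there, assuming SciPy is available (the repository's own helper may be implemented differently):

from scipy.spatial.transform import Rotation as R

def rotmat2aa(rotmats):
    """Sketch only: convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3)."""
    flat = rotmats.reshape(-1, 3, 3)
    aa = R.from_matrix(flat).as_rotvec()
    return aa.reshape(rotmats.shape[:-2] + (3,))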
Example #15
def main(_):
    # Parsing data info.
    aist_dataset = AISTDataset(FLAGS.anno_dir)
    video_path = os.path.join(FLAGS.video_dir, f'{FLAGS.video_name}.mp4')
    seq_name, view = AISTDataset.get_seq_name(FLAGS.video_name)
    view_idx = AISTDataset.VIEWS.index(view)

    # Parsing keypoints.
    if FLAGS.mode == '2D':  # raw keypoints detection results.
        keypoints2d, _, _ = AISTDataset.load_keypoint2d(
            aist_dataset.keypoint2d_dir, seq_name)
        keypoints2d = keypoints2d[view_idx, :, :, 0:2]

    elif FLAGS.mode == '3D':  # 3D keypoints with temporal optimization.
        keypoints3d = AISTDataset.load_keypoint3d(aist_dataset.keypoint3d_dir,
                                                  seq_name,
                                                  use_optim=True)
        nframes, njoints, _ = keypoints3d.shape
        env_name = aist_dataset.mapping_seq2env[seq_name]
        cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir,
                                               env_name)
        keypoints2d = cgroup.project(keypoints3d)
        keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

    elif FLAGS.mode == 'SMPL':  # SMPL joints
        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
            aist_dataset.motion_dir, seq_name)
        smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
        keypoints3d = smpl.forward(
            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
            transl=torch.from_numpy(smpl_trans).float(),
            scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
        ).joints.detach().numpy()

        nframes, njoints, _ = keypoints3d.shape
        env_name = aist_dataset.mapping_seq2env[seq_name]
        cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir,
                                               env_name)
        keypoints2d = cgroup.project(keypoints3d)
        keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

    elif FLAGS.mode == 'SMPLMesh':  # SMPL Mesh
        import trimesh  # install by `pip install trimesh`
        import vedo  # install by `pip install vedo`
        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
            aist_dataset.motion_dir, seq_name)
        smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
        vertices = smpl.forward(
            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
            transl=torch.from_numpy(smpl_trans).float(),
            scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
        ).vertices.detach().numpy()[0]  # first frame
        faces = smpl.faces
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.visual.face_colors = [200, 200, 250, 100]

        keypoints3d = AISTDataset.load_keypoint3d(aist_dataset.keypoint3d_dir,
                                                  seq_name,
                                                  use_optim=True)
        pts = vedo.Points(keypoints3d[0], r=20)  # first frame

        vedo.show(mesh, pts, interactive=True)
        exit()

    # Visualize.
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    save_path = os.path.join(FLAGS.save_dir, f'{FLAGS.video_name}.mp4')
    plot_on_video(keypoints2d, video_path, save_path, fps=60)
Example #16
def main(args):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # ========= Define VIBE model ========= #
    model = VIBE_Demo(
        seqlen=16,
        device=device,
        n_layers=2,
        hidden_size=1024,
        add_linear=True,
        use_residual=True,
    ).to(device)

    # ========= Load pretrained weights ========= #
    pretrained_file = download_ckpt(use_3dpw=False)
    ckpt = torch.load(pretrained_file, map_location=device)
    print(f'Performance of pretrained model on 3DPW: {ckpt["performance"]}')
    ckpt = ckpt['gen_state_dict']
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    print(f'Loaded pretrained weights from \"{pretrained_file}\"')

    total_time = time.time()
    # ========= Run VIBE on crops ========= #
    print(f'Running VIBE on crops...')
    vibe_time = time.time()
    image_folder = args.input_folder

    dataset = InferenceFromCrops(image_folder=image_folder)
    orig_height = orig_width = 512

    dataloader = DataLoader(dataset,
                            batch_size=args.vibe_batch_size,
                            num_workers=0)

    with torch.no_grad():

        pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

        for batch_num, batch in enumerate(dataloader):
            print("BATCH:", batch_num)
            batch = batch.unsqueeze(0)
            batch = batch.to(device)

            batch_size, seqlen = batch.shape[:2]
            output = model(batch)[-1]

            pred_cam.append(output['theta'][:, :, :3].reshape(
                batch_size * seqlen, -1))
            pred_verts.append(output['verts'].reshape(batch_size * seqlen, -1,
                                                      3))
            pred_pose.append(output['theta'][:, :, 3:75].reshape(
                batch_size * seqlen, -1))
            pred_betas.append(output['theta'][:, :, 75:].reshape(
                batch_size * seqlen, -1))
            pred_joints3d.append(output['kp_3d'].reshape(
                batch_size * seqlen, -1, 3))

        pred_cam = torch.cat(pred_cam, dim=0)
        pred_verts = torch.cat(pred_verts, dim=0)
        pred_pose = torch.cat(pred_pose, dim=0)
        pred_betas = torch.cat(pred_betas, dim=0)
        pred_joints3d = torch.cat(pred_joints3d, dim=0)

        del batch

    # ========= [Optional] run Temporal SMPLify to refine the results ========= #
    if args.run_smplify and args.tracking_method == 'pose':
        norm_joints2d = np.concatenate(norm_joints2d, axis=0)
        norm_joints2d = convert_kps(norm_joints2d, src='staf', dst='spin')
        norm_joints2d = torch.from_numpy(norm_joints2d).float().to(device)

        # Run Temporal SMPLify
        update, new_opt_vertices, new_opt_cam, new_opt_pose, new_opt_betas, \
        new_opt_joints3d, new_opt_joint_loss, opt_joint_loss = smplify_runner(
            pred_rotmat=pred_pose,
            pred_betas=pred_betas,
            pred_cam=pred_cam,
            j2d=norm_joints2d,
            device=device,
            batch_size=norm_joints2d.shape[0],
            pose2aa=False,
        )

        # update the parameters after refinement
        print(
            f'Update ratio after Temporal SMPLify: {update.sum()} / {norm_joints2d.shape[0]}'
        )
        pred_verts = pred_verts.cpu()
        pred_cam = pred_cam.cpu()
        pred_pose = pred_pose.cpu()
        pred_betas = pred_betas.cpu()
        pred_joints3d = pred_joints3d.cpu()
        pred_verts[update] = new_opt_vertices[update]
        pred_cam[update] = new_opt_cam[update]
        pred_pose[update] = new_opt_pose[update]
        pred_betas[update] = new_opt_betas[update]
        pred_joints3d[update] = new_opt_joints3d[update]

    elif args.run_smplify and args.tracking_method == 'bbox':
        print(
            '[WARNING] You need to enable pose tracking to run Temporal SMPLify algorithm!'
        )
        print('[WARNING] Continuing without running Temporal SMPLify!..')

    # ========= Save results to a pickle file ========= #
    output_path = image_folder.replace('cropped_frames', 'vibe_results')
    os.makedirs(output_path, exist_ok=True)

    pred_cam = pred_cam.cpu().numpy()
    pred_verts = pred_verts.cpu().numpy()
    pred_pose = pred_pose.cpu().numpy()
    pred_betas = pred_betas.cpu().numpy()
    pred_joints3d = pred_joints3d.cpu().numpy()

    vibe_results = {
        'pred_cam': pred_cam,
        'verts': pred_verts,
        'pose': pred_pose,
        'betas': pred_betas,
        'joints3d': pred_joints3d,
    }

    del model
    end = time.time()
    fps = len(dataset) / (end - vibe_time)

    print(f'VIBE FPS: {fps:.2f}')
    total_time = time.time() - total_time
    print(
        f'Total time spent: {total_time:.2f} seconds (including model loading time).'
    )
    print(
        f'Total FPS (including model loading time): {len(dataset) / total_time:.2f}.'
    )

    print(
        f'Saving vibe results to \"{os.path.join(output_path, "vibe_results.pkl")}\".'
    )

    with open(os.path.join(output_path, "vibe_results.pkl"), 'wb') as f_save:
        pickle.dump(vibe_results, f_save)

    if not args.no_render:
        # ========= Render results as a single video ========= #
        renderer = Renderer(resolution=(orig_width, orig_height),
                            orig_img=True,
                            wireframe=args.wireframe)

        output_img_folder = os.path.join(output_path, 'vibe_images')
        os.makedirs(output_img_folder, exist_ok=True)

        print(f'Rendering output video, writing frames to {output_img_folder}')

        image_file_names = sorted([
            os.path.join(image_folder, x) for x in os.listdir(image_folder)
            if x.endswith('.png') or x.endswith('.jpg')
        ])

        for frame_idx in tqdm(range(len(image_file_names))):
            img_fname = image_file_names[frame_idx]
            img = cv2.imread(img_fname)

            frame_verts = vibe_results['verts'][frame_idx]
            frame_cam = vibe_results['pred_cam'][frame_idx]

            mesh_filename = None

            if args.save_obj:
                mesh_folder = os.path.join(output_path, 'vibe_meshes')
                os.makedirs(mesh_folder, exist_ok=True)
                mesh_filename = os.path.join(mesh_folder,
                                             f'{frame_idx:06d}.obj')

            rend_img = renderer.render(
                img,
                frame_verts,
                cam=frame_cam,
                mesh_filename=mesh_filename,
            )

            whole_img = rend_img

            if args.sideview:
                side_img_bg = np.zeros_like(img)
                side_rend_img90 = renderer.render(
                    side_img_bg,
                    frame_verts,
                    cam=frame_cam,
                    angle=90,
                    axis=[0, 1, 0],
                )
                side_rend_img270 = renderer.render(
                    side_img_bg,
                    frame_verts,
                    cam=frame_cam,
                    angle=270,
                    axis=[0, 1, 0],
                )
                if args.reposed_render:
                    smpl = SMPL('data/vibe_data', batch_size=1)
                    zero_pose = torch.from_numpy(
                        np.zeros((1, pred_pose.shape[-1]))).float()
                    zero_pose[:, 0] = np.pi
                    pred_frame_betas = torch.from_numpy(
                        pred_betas[frame_idx][None, :]).float()
                    with torch.no_grad():
                        reposed_smpl_output = smpl(
                            betas=pred_frame_betas,
                            body_pose=zero_pose[:, 3:],
                            global_orient=zero_pose[:, :3])
                        reposed_verts = reposed_smpl_output.vertices
                        reposed_verts = reposed_verts.cpu().detach().numpy()

                    reposed_cam = np.array([0.9, 0, 0])
                    reposed_rend_img = renderer.render(side_img_bg,
                                                       reposed_verts[0],
                                                       cam=reposed_cam)
                    reposed_rend_img90 = renderer.render(side_img_bg,
                                                         reposed_verts[0],
                                                         cam=reposed_cam,
                                                         angle=90,
                                                         axis=[0, 1, 0])

                    top_row = np.concatenate(
                        [img, reposed_rend_img, reposed_rend_img90], axis=1)
                    bot_row = np.concatenate(
                        [rend_img, side_rend_img90, side_rend_img270], axis=1)
                    whole_img = np.concatenate([top_row, bot_row], axis=0)

                else:
                    top_row = np.concatenate([img, side_img_bg, side_img_bg],
                                             axis=1)
                    bot_row = np.concatenate(
                        [rend_img, side_rend_img90, side_rend_img270], axis=1)
                    whole_img = np.concatenate([top_row, bot_row], axis=0)

            # cv2.imwrite(os.path.join(output_img_folder, f'{frame_idx:06d}.png'), whole_img)
            cv2.imwrite(
                os.path.join(output_img_folder, os.path.basename(img_fname)),
                whole_img)

        # ========= Save rendered video ========= #
        save_vid_path = os.path.join(output_path, 'vibe_video.mp4')
        print(f'Saving result video to {save_vid_path}')
        images_to_video(img_folder=output_img_folder,
                        output_vid_file=save_vid_path)

    print('================= END =================')