Example #1
    def smpl_to_kpts(self, pred_rotmat, pred_shape, pred_cam, J_regressor):
        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        if J_regressor is not None:
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'rotmat': pred_rotmat
        }]
        return output
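
A minimal call sketch for smpl_to_kpts, assuming a hypothetical model instance `regressor` that owns the `smpl` layer; the 24 rotation matrices, 10 shape coefficients, and 3 weak-perspective camera parameters follow the SMPL conventions used throughout these examples:

import torch

N = 8  # hypothetical batch size
pred_rotmat = torch.eye(3).repeat(N, 24, 1, 1)         # identity rotations, (N, 24, 3, 3)
pred_shape = torch.zeros(N, 10)                        # mean SMPL shape
pred_cam = torch.tensor([1.0, 0.0, 0.0]).repeat(N, 1)  # weak-perspective (s, tx, ty)

out = regressor.smpl_to_kpts(pred_rotmat, pred_shape, pred_cam, J_regressor=None)
print(out[0]['theta'].shape)  # (N, 85) = 3 cam + 72 pose + 10 shape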
Example #2
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3,
                J_regressor=None):
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        if J_regressor is not None:
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'rotmat': pred_rotmat
        }]
        return output
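
The regressor above predicts pose in the 6D rotation representation and maps it back with rot6d_to_rotmat. Below is a sketch of the standard Gram-Schmidt construction (Zhou et al., CVPR 2019); the repository's own rot6d_to_rotmat may differ in minor layout conventions:

import torch
import torch.nn.functional as F

def rot6d_to_rotmat_sketch(x):
    # x: (N, 24 * 6); two 3D vectors per joint rotation.
    a = x.reshape(-1, 3, 2)
    a1, a2 = a[..., 0], a[..., 1]
    b1 = F.normalize(a1, dim=-1)                                         # first column
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)  # orthogonalize
    b3 = torch.cross(b1, b2, dim=-1)                                     # right-handed frame
    return torch.stack([b1, b2, b3], dim=-1)                             # (N * 24, 3, 3)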
Example #3
    def forward(self, input, J_regressor=None):
        batch_size, seqlen, nc, h, w = input.shape

        feature = self.hmr.feature_extractor(input.reshape(-1, nc, h, w))

        # Flatten per-frame HMR features back into a sequence: (N, T, F).
        input = feature.reshape(batch_size, seqlen, -1)

        vibe_output = self.vibe(input)[-1]

        vibe_theta = vibe_output['theta']
        vibe_pose = vibe_theta[:, :, 3:75]
        pred_cam = vibe_theta[:, :, :3].reshape(batch_size * seqlen, 3)
        pred_shape = vibe_theta[:, :, 75:].reshape(batch_size * seqlen, 10)
        vibe_pose_6d = convert_aa_to_orth6d(vibe_pose).reshape(
            vibe_pose.shape[0], vibe_pose.shape[1], -1)

        X_r = self.refiner_model(vibe_pose_6d.permute(1, 0, 2),
                                 input.permute(1, 0, 2))
        X_r = X_r.permute(1, 0, 2).contiguous()[:, :seqlen, :]

        pred_rotmat = convert_orth_6d_to_mat(X_r).reshape(
            batch_size * seqlen, 24, 3, 3)

        pred_output = self.smpl(
            betas=pred_shape,
            body_pose=pred_rotmat[:, 1:],
            global_orient=pred_rotmat[:, 0:1],
            pose2rot=False
        )

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        if J_regressor is not None:
            J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        smpl_output = [{
            'theta'  : torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts'  : pred_vertices,
            'kp_2d'  : pred_keypoints_2d,
            'kp_3d'  : pred_joints,
            'rotmat' : pred_rotmat
        }]

        for s in smpl_output:
            s['theta'] = s['theta'].reshape(batch_size, seqlen, -1)
            s['verts'] = s['verts'].reshape(batch_size, seqlen, -1, 3)
            s['kp_2d'] = s['kp_2d'].reshape(batch_size, seqlen, -1, 2)
            s['kp_3d'] = s['kp_3d'].reshape(batch_size, seqlen, -1, 3)
            s['rotmat'] = s['rotmat'].reshape(batch_size, seqlen, -1, 3, 3)

        return smpl_output
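
A hedged call sketch for this video model (`model` is a hypothetical instance); the clip size is illustrative, and the joint count in kp_3d depends on whether a J_regressor is supplied (14 H36M joints with one, the full SMPL joint set without):

clip = torch.randn(2, 16, 3, 224, 224)  # (batch, seqlen, channels, height, width)
preds = model(clip, J_regressor=None)
print(preds[0]['theta'].shape)   # (2, 16, 85)
print(preds[0]['rotmat'].shape)  # (2, 16, 24, 3, 3)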
Example #4
    def forward(self, input, J_regressor=None):
        # input: (N, T, C, H, W) video clip
        batch_size, seqlen, nc, h, w = input.shape
        feature = self.hmr.feature_extractor(input.reshape(-1, nc, h, w))
        input = feature.reshape(batch_size, seqlen, -1)

        X_r, pred_cam, pred_shape = self.meva_model(input.permute(1, 0, 2))

        X_r = X_r.permute(1, 0, 2)[:, :seqlen, :].reshape(batch_size * seqlen, -1)
        pred_cam = pred_cam.permute(1, 0, 2)[:, :seqlen, :].reshape(batch_size * seqlen, -1)
        pred_shape = pred_shape.permute(1, 0, 2)[:, :seqlen, :].reshape(batch_size * seqlen, -1)

        pred_rotmat = convert_orth_6d_to_mat(X_r).view(
            batch_size * seqlen, 24, 3, 3)

        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        if J_regressor is not None:
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        smpl_output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'rotmat': pred_rotmat
        }]

        for s in smpl_output:
            s['theta'] = s['theta'].reshape(batch_size, seqlen, -1)
            s['verts'] = s['verts'].reshape(batch_size, seqlen, -1, 3)
            s['kp_2d'] = s['kp_2d'].reshape(batch_size, seqlen, -1, 2)
            s['kp_3d'] = s['kp_3d'].reshape(batch_size, seqlen, -1, 3)
            s['rotmat'] = s['rotmat'].reshape(batch_size, seqlen, -1, 3, 3)

        return smpl_output
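
Every forward pass above ends with projection(pred_joints, pred_cam). As an assumption about that helper (the repository may instead lift (s, tx, ty) to a 3D translation and run a full perspective projection with a fixed focal length, as Example #6 suggests), here is the common weak-perspective convention:

def weak_perspective_sketch(joints_3d, cam):
    # cam: (N, 3) = (scale s, translation tx, ty); joints_3d: (N, J, 3).
    s = cam[:, 0].reshape(-1, 1, 1)
    t = cam[:, 1:].reshape(-1, 1, 2)
    return s * (joints_3d[:, :, :2] + t)  # (N, J, 2) image-plane keypoints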
Example #5
def read_data(folder, set, debug=False):

    dataset = {
        'vid_name': [],
        'frame_id': [],
        'joints3D': [],
        'joints2D': [],
        'shape': [],
        'pose': [],
        'bbox': [],
        'img_name': [],
        'features': [],
        'valid': [],
    }

    model = spin.get_pretrained_hmr()

    if set == 'val':
        set = 'test'
    sequences = [
        x.split('.')[0]
        for x in os.listdir(osp.join(folder, 'sequenceFiles', set))
    ]

    J_regressor = None

    smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False)
    if set == 'test':
        J_regressor = torch.from_numpy(
            np.load(osp.join(VIBE_DATA_DIR, 'J_regressor_h36m.npy'))).float()

    for i, seq in tqdm(enumerate(sequences)):

        data_file = osp.join(folder, 'sequenceFiles', set, seq + '.pkl')

        data = pkl.load(open(data_file, 'rb'), encoding='latin1')

        img_dir = osp.join(folder, 'imageFiles', seq)

        num_people = len(data['poses'])
        num_frames = len(data['img_frame_ids'])
        assert data['poses2d'][0].shape[0] == num_frames

        for p_id in range(num_people):
            pose = torch.from_numpy(data['poses'][p_id]).float()
            shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(
                pose.size(0), 1)
            trans = torch.from_numpy(data['trans'][p_id]).float()
            j2d = data['poses2d'][p_id].transpose(0, 2, 1)
            cam_pose = data['cam_poses']
            campose_valid = data['campose_valid'][p_id]

            # ======== Align the mesh params ======== #
            rot = pose[:, :3]
            rot_mat = batch_rodrigues(rot)

            Rc = torch.from_numpy(cam_pose[:, :3, :3]).float()
            Rs = torch.bmm(Rc, rot_mat.reshape(-1, 3, 3))
            rot = rotation_matrix_to_angle_axis(Rs)
            pose[:, :3] = rot
            # ======== Align the mesh params ======== #

            output = smpl(betas=shape,
                          body_pose=pose[:, 3:],
                          global_orient=pose[:, :3],
                          transl=trans)
            # verts = output.vertices
            j3d = output.joints

            if J_regressor is not None:
                vertices = output.vertices
                J_regressor_batch = J_regressor[None, :].expand(
                    vertices.shape[0], -1, -1).to(vertices.device)
                j3d = torch.matmul(J_regressor_batch, vertices)
                j3d = j3d[:, H36M_TO_J14, :]

            img_paths = []
            for i_frame in range(num_frames):
                img_path = osp.join(img_dir,
                                    'image_{:05d}.jpg'.format(i_frame))
                img_paths.append(img_path)

            bbox_params, time_pt1, time_pt2 = get_smooth_bbox_params(
                j2d, vis_thresh=VIS_THRESH, sigma=8)

            # process bbox_params
            c_x = bbox_params[:, 0]
            c_y = bbox_params[:, 1]
            scale = bbox_params[:, 2]
            w = h = 150. / scale  # square box side from the bbox scale
            w = h = h * 1.1       # enlarge the box by 10%
            bbox = np.vstack([c_x, c_y, w, h]).T

            # process keypoints
            j2d[:, :, 2] = j2d[:, :, 2] > 0.3  # set the visibility flags
            # Convert to common 2d keypoint format
            perm_idxs = get_perm_idxs('3dpw', 'common')
            perm_idxs += [0, 0]  # no neck, top head
            j2d = j2d[:, perm_idxs]
            j2d[:, 12:, 2] = 0.0

            img_paths_array = np.array(img_paths)[time_pt1:time_pt2]
            dataset['vid_name'].append(
                np.array([f'{seq}_{p_id}'] * num_frames)[time_pt1:time_pt2])
            dataset['frame_id'].append(
                np.arange(0, num_frames)[time_pt1:time_pt2])
            dataset['img_name'].append(img_paths_array)
            dataset['joints3D'].append(j3d.numpy()[time_pt1:time_pt2])
            dataset['joints2D'].append(j2d[time_pt1:time_pt2])
            dataset['shape'].append(shape.numpy()[time_pt1:time_pt2])
            dataset['pose'].append(pose.numpy()[time_pt1:time_pt2])
            dataset['bbox'].append(bbox)
            dataset['valid'].append(campose_valid[time_pt1:time_pt2])

            features = extract_features(model,
                                        img_paths_array,
                                        bbox,
                                        kp_2d=j2d[time_pt1:time_pt2],
                                        debug=debug,
                                        dataset='3dpw',
                                        scale=1.2)
            dataset['features'].append(features)

    for k in dataset.keys():
        dataset[k] = np.concatenate(dataset[k])
        print(k, dataset[k].shape)

    # Filter out keypoints
    indices_to_use = np.where(
        (dataset['joints2D'][:, :, 2] > VIS_THRESH).sum(-1) > MIN_KP)[0]
    for k in dataset.keys():
        dataset[k] = dataset[k][indices_to_use]

    return dataset
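
A usage sketch, assuming the standard 3DPW layout with sequenceFiles/<set>/<seq>.pkl annotation pickles and imageFiles/<seq>/image_XXXXX.jpg frames (the root path below is hypothetical):

db = read_data('/data/3dpw', 'train')
print(db['features'].shape)  # (num_valid_frames, F) per-frame HMR backbone features
# For set='test' (and 'val', which is remapped to 'test'), joints3D holds
# 14 H36M joints regressed from the SMPL vertices via J_regressor_h36m.npy.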
Example #6
def smplify_runner(pred_rotmat,
                   pred_betas,
                   pred_cam,
                   j2d,
                   device,
                   batch_size,
                   lr=1.0,
                   opt_steps=1,
                   use_lbfgs=True,
                   pose2aa=True):
    smplify = TemporalSMPLify(
        step_size=lr,
        batch_size=batch_size,
        num_iters=opt_steps,
        focal_length=5000.,
        use_lbfgs=use_lbfgs,
        device=device,
        # max_iter=10,
    )
    # Convert predicted rotation matrices to axis-angle
    if pose2aa:
        pred_pose = rotation_matrix_to_angle_axis(
            pred_rotmat.detach()).reshape(batch_size, -1)
    else:
        pred_pose = pred_rotmat

    # Convert the weak-perspective camera (s, tx, ty) into a 3D translation for
    # SMPLify; the depth follows from the fixed 5000 px focal length and
    # 224 px crop: tz = 2 * f / (crop_size * s).
    pred_cam_t = torch.stack([
        pred_cam[:, 1],
        pred_cam[:, 2],
        2 * 5000 / (224 * pred_cam[:, 0] + 1e-9),
    ], dim=-1)

    gt_keypoints_2d_orig = j2d
    # Before running compute reprojection error of the network
    opt_joint_loss = smplify.get_fitting_loss(
        pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(),
        0.5 * 224 * torch.ones(batch_size, 2, device=device),
        gt_keypoints_2d_orig).mean(dim=-1)

    best_prediction_id = torch.argmin(opt_joint_loss).item()
    pred_betas = pred_betas[best_prediction_id].unsqueeze(0)

    # Run SMPLify optimization initialized from the network prediction
    output, new_opt_joint_loss = smplify(
        pred_pose.detach(),
        pred_betas.detach(),
        pred_cam_t.detach(),
        0.5 * 224 * torch.ones(batch_size, 2, device=device),
        gt_keypoints_2d_orig,
    )
    new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)
    # Keep the refit only for examples where the new loss beats the network's.
    update = (new_opt_joint_loss < opt_joint_loss)

    new_opt_vertices = output['verts']
    new_opt_cam_t = output['theta'][:, :3]
    new_opt_pose = output['theta'][:, 3:75]
    new_opt_betas = output['theta'][:, 75:]

    return_val = [
        update,
        new_opt_vertices.cpu(),
        new_opt_cam_t.cpu(),
        new_opt_pose.cpu(),
        new_opt_betas.cpu(),
        new_opt_joint_loss,
        opt_joint_loss,
    ]

    return return_val
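
A hedged call sketch (tensor shapes follow the SMPL conventions of the earlier examples; j2d carries a per-joint confidence in its last column):

update, verts, cam_t, pose, betas, new_loss, old_loss = smplify_runner(
    pred_rotmat.detach(),  # (N, 24, 3, 3) network rotations
    pred_betas.detach(),   # (N, 10) shape coefficients
    pred_cam.detach(),     # (N, 3) weak-perspective (s, tx, ty)
    j2d,                   # (N, J, 3) 2D keypoints + confidence
    device='cuda',
    batch_size=pred_rotmat.shape[0],
    lr=1.0,
    opt_steps=10,
    use_lbfgs=True,
)
# `update` masks the examples where the refit loss beat the network prediction.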
Example #7
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3,
                return_features=False):

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
        }]

        if return_features:
            return xf, output
        else:
            return output
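
A call sketch for this end-to-end network (`hmr` is a hypothetical instance built around a ResNet-50 backbone, hence the 2048-dim pooled feature):

images = torch.randn(4, 3, 224, 224)  # normalized RGB crops
xf, out = hmr(images, return_features=True)
print(xf.shape)               # (4, 2048) pooled backbone features
print(out[0]['kp_2d'].shape)  # (4, J, 2) projected keypoints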
Example #8
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3,
                J_regressor=None):
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        # print("Inputto spin model pose--",pred_pose.shape)

        # print("Inputto spin model ---- ",pred_shape.shape,pred_shape,pred_rotmat[:, 1:].shape,pred_rotmat[:, 0].unsqueeze(1).shape)

        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Select 17 H36M-style joints from the SMPL joint set.
        H36M_TO_J17 = [
            25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 41, 40, 44
        ]
        pred_joints = pred_joints[:, H36M_TO_J17, :]

        if J_regressor is not None:
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'rotmat': pred_rotmat
        }]
        return output
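
A sketch of the J_regressor path used for H36M-style evaluation; the .npy path, `spin_regressor`, and `feat` are hypothetical, and the regressor maps the 6890 SMPL vertices to 17 H36M joints before H36M_TO_J14 keeps 14 of them:

import numpy as np
import torch

J_reg = torch.from_numpy(np.load('data/J_regressor_h36m.npy')).float()  # (17, 6890)
out = spin_regressor(feat, J_regressor=J_reg)
print(out[0]['kp_3d'].shape)  # (N, 14, 3) after the H36M_TO_J14 selection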