Beispiel #1
0
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3,
                J_regressor=None):
        """Iteratively regress SMPL parameters from features and decode them.

        Starting from the given initial estimates (or ``self.init_pose`` /
        ``self.init_shape`` / ``self.init_cam`` expanded to the batch), the
        estimates are refined ``n_iter`` times by adding residuals predicted
        from the concatenation of ``x`` with the current estimates. The final
        6D pose is converted to rotation matrices and fed to ``self.smpl``.

        Args:
            x: input features; dim 0 is the batch. Concatenated along dim 1
                with the parameter estimates, so presumably 2-D
                (batch, feat_dim) — TODO confirm.
            init_pose, init_shape, init_cam: optional initial estimates;
                ``None`` falls back to the module's stored initial values.
            n_iter: number of residual refinement steps.
            J_regressor: optional 2-D regressor matrix. When given, 3D joints
                are regressed from the SMPL vertices with it and reindexed
                with ``H36M_TO_J14`` instead of using ``self.smpl``'s joints.

        Returns:
            A one-element list holding a dict with keys ``'theta'`` (camera,
            flattened 72-value pose from ``rotation_matrix_to_angle_axis`` and
            shape, concatenated along dim 1), ``'verts'``, ``'kp_2d'``,
            ``'kp_3d'`` and ``'rotmat'`` of shape (batch, 24, 3, 3).
        """
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for _ in range(n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            # Residual updates: each head predicts a correction to the
            # current estimate.
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        # Rotation 0 is the global orientation; the remaining 23 are the
        # body pose. Matrices are passed directly (pose2rot=False).
        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        if J_regressor is not None:
            # Move the 2-D matrix to the vertices' device *before* expanding:
            # copying an expanded (stride-0) tensor to another device
            # materializes batch_size dense copies of it.
            J_regressor = J_regressor.to(pred_vertices.device)
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        # 24 rotation matrices per sample -> 24 * 3 = 72 values per sample.
        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3,
                                                                 3)).reshape(
                                                                     -1, 72)

        output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'rotmat': pred_rotmat
        }]
        return output
Beispiel #2
0
 def forward(self,
             x,
             init_pose=None,
             init_shape=None,
             init_cam=None,
             n_iter=3,
             J_regressor=None):
     """Refine SMPL parameters and decode them into keypoints.

     Refinement is delegated to ``self.iter_refine`` (which receives every
     argument unchanged). The refined 6D pose is turned into per-joint
     rotation matrices of shape (batch, 24, 3, 3). The final decoding is
     delegated to ``self.smpl_to_kpts``, whose result is returned as-is.
     """
     n_samples = x.shape[0]
     refined = self.iter_refine(x,
                                init_pose=init_pose,
                                init_shape=init_shape,
                                init_cam=init_cam,
                                n_iter=n_iter,
                                J_regressor=J_regressor)
     pose_6d, betas, camera = refined
     # NOTE: convert_orth_6d_to_mat(...).view(...) was an alternative
     # conversion considered here; rot6d_to_rotmat is the one in use.
     rotmats = rot6d_to_rotmat(pose_6d).reshape(n_samples, 24, 3, 3)
     return self.smpl_to_kpts(rotmats, betas, camera, J_regressor)
Beispiel #3
0
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3,
                return_features=False):
        """Extract image features, then iteratively regress SMPL parameters.

        The input runs through conv1 -> bn1 -> relu -> maxpool ->
        layer1..layer4 -> avgpool, and is flattened per sample. The pose,
        shape and camera estimates (from the arguments, or ``self.init_*``
        expanded to the batch) are then refined ``n_iter`` times with
        residual heads. The result is decoded through ``self.smpl``.

        Returns:
            A one-element list with a dict holding ``'theta'``, ``'verts'``,
            ``'kp_2d'`` and ``'kp_3d'``. When ``return_features`` is true, a
            ``(features, list)`` tuple is returned instead.
        """
        n = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(n, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(n, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(n, -1)

        # Stem, then the four residual stages in order.
        feat = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)

        pooled = self.avgpool(feat)
        pooled = pooled.view(pooled.shape[0], -1)

        pose, betas, cam = init_pose, init_shape, init_cam
        for _ in range(n_iter):
            joint_in = torch.cat([pooled, pose, betas, cam], 1)
            # Same call order as fc1 -> drop1 -> fc2 -> drop2 (dropout RNG).
            hidden = self.drop2(self.fc2(self.drop1(self.fc1(joint_in))))
            pose = self.decpose(hidden) + pose
            betas = self.decshape(hidden) + betas
            cam = self.deccam(hidden) + cam

        rotmats = rot6d_to_rotmat(pose).view(n, 24, 3, 3)

        smpl_out = self.smpl(betas=betas,
                             body_pose=rotmats[:, 1:],
                             global_orient=rotmats[:, 0].unsqueeze(1),
                             pose2rot=False)
        verts = smpl_out.vertices
        joints3d = smpl_out.joints

        joints2d = projection(joints3d, cam)

        flat_rotmats = rotmats.reshape(-1, 3, 3)
        axis_angle = rotation_matrix_to_angle_axis(flat_rotmats).reshape(-1, 72)

        result = [{
            'theta': torch.cat([cam, axis_angle, betas], dim=1),
            'verts': verts,
            'kp_2d': joints2d,
            'kp_3d': joints3d,
        }]

        return (pooled, result) if return_features else result
Beispiel #4
0
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3,
                J_regressor=None):
        """Iteratively regress SMPL parameters and decode 17 (or 14) joints.

        Starting from the given initial estimates (or ``self.init_pose`` /
        ``self.init_shape`` / ``self.init_cam`` expanded to the batch), the
        estimates are refined ``n_iter`` times by adding residuals predicted
        from the concatenation of ``x`` with the current estimates. The final
        6D pose is converted to rotation matrices and fed to ``self.smpl``.

        Args:
            x: input features; dim 0 is the batch. Concatenated along dim 1
                with the parameter estimates, so presumably 2-D
                (batch, feat_dim) — TODO confirm.
            init_pose, init_shape, init_cam: optional initial estimates;
                ``None`` falls back to the module's stored initial values.
            n_iter: number of residual refinement steps.
            J_regressor: optional 2-D regressor matrix. When given, joints
                are regressed from the SMPL vertices and reindexed with
                ``H36M_TO_J14``. Otherwise a fixed 17-joint subset of
                ``self.smpl``'s joints is used.

        Returns:
            A one-element list holding a dict with keys ``'theta'`` (camera,
            flattened 72-value pose from ``rotation_matrix_to_angle_axis`` and
            shape, concatenated along dim 1), ``'verts'``, ``'kp_2d'``,
            ``'kp_3d'`` and ``'rotmat'`` of shape (batch, 24, 3, 3).
        """
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for _ in range(n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            # Residual updates: each head predicts a correction to the
            # current estimate.
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        # Rotation 0 is the global orientation; the remaining 23 are the
        # body pose. Matrices are passed directly (pose2rot=False).
        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices

        if J_regressor is not None:
            # Move the 2-D matrix to the vertices' device *before* expanding:
            # copying an expanded (stride-0) tensor to another device
            # materializes batch_size dense copies of it.
            J_regressor = J_regressor.to(pred_vertices.device)
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]
        else:
            # Fixed 17-joint subset of the joints returned by self.smpl.
            # Only needed on this path; the regressor path above replaces
            # the joints entirely.
            # NOTE(review): the meaning of these indices depends on
            # self.smpl's joint ordering, which is not visible here — confirm.
            H36M_TO_J17 = [
                25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 41, 40, 44
            ]
            pred_joints = pred_output.joints[:, H36M_TO_J17, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        # 24 rotation matrices per sample -> 24 * 3 = 72 values per sample.
        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3,
                                                                 3)).reshape(
                                                                     -1, 72)

        output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'rotmat': pred_rotmat
        }]
        return output