Example #1
    def forward_init(self,
                     x,
                     init_pose=None,
                     init_shape=None,
                     init_cam=None,
                     n_iter=1,
                     J_regressor=None):
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose.contiguous()).view(
            batch_size, 24, 3, 3)

        pred_output = self.smpl(betas=pred_shape,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints
        pred_smpl_joints = pred_output.smpl_joints
        pred_keypoints_2d = projection(pred_joints, pred_cam)
        pose = rotation_matrix_to_angle_axis(
            pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)

        if J_regressor is not None:
            pred_joints = torch.matmul(J_regressor, pred_vertices)
            pred_pelvis = pred_joints[:, [0], :].clone()
            pred_joints = pred_joints[:, H36M_TO_J14, :]
            pred_joints = pred_joints - pred_pelvis

        output = {
            'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'smpl_kp_3d': pred_smpl_joints,
            'rotmat': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
            'pred_pose': pred_pose,
        }
        return output
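These examples all assume a weak-perspective `projection` where `pred_cam` holds (s, tx, ty). A minimal sketch under that assumption follows; each project ships its own variant, possibly with a focal length or normalization:

    import torch

    def projection(joints_3d, cam):
        # Weak-perspective sketch assuming cam = (s, tx, ty); the actual
        # helper in each repository may differ.
        s = cam[:, 0].view(-1, 1, 1)          # (B, 1, 1) scale
        t = cam[:, 1:3].unsqueeze(1)          # (B, 1, 2) in-plane translation
        return s * (joints_3d[:, :, :2] + t)  # (B, J, 2) projected keypoints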
Example #2
 def ssl(self, init_pose, init_shape, init_cam, xf, n_iter, batch_size):
     pred_pose = init_pose
     pred_shape = init_shape
     pred_cam = init_cam
     for i in range(n_iter):
         xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
         xc = self.ssl_head(xc)
         pred_pose = self.ssl_decpose(xc) + pred_pose
         pred_shape = self.ssl_decshape(xc) + pred_shape
         pred_cam = self.ssl_deccam(xc) + pred_cam
     pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
     return pred_rotmat, pred_shape, pred_cam
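Every example here leans on `rot6d_to_rotmat`. A minimal sketch of the standard conversion (Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019), which orthonormalizes each 6D vector via Gram-Schmidt; individual repos differ slightly in how they reshape the input:

    import torch
    import torch.nn.functional as F

    def rot6d_to_rotmat(x):
        # Sketch: map (N, 24*6) 6D vectors to (N*24, 3, 3) rotation
        # matrices by Gram-Schmidt orthonormalization of two 3D columns.
        x = x.reshape(-1, 3, 2)
        a1, a2 = x[:, :, 0], x[:, :, 1]
        b1 = F.normalize(a1, dim=1)
        b2 = F.normalize(a2 - (b1 * a2).sum(dim=1, keepdim=True) * b1, dim=1)
        b3 = torch.cross(b1, b2, dim=1)
        return torch.stack((b1, b2, b3), dim=-1)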
Example #3
    def __init__(self,
                 options,
                 orig_size=224,
                 feat_in_dim=None,
                 smpl_mean_params=None,
                 pretrained=True):
        super(SMPL_Regressor, self).__init__()

        self.mapping_to_detectron = None
        self.orphans_in_detectron = None

        self.focal_length = 5000.
        self.options = options

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.orig_size = orig_size

        mean_params = np.load(smpl_mean_params)
        init_pose_6d = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        if cfg.DANET.USE_6D_ROT:
            init_pose = init_pose_6d
        else:
            init_pose_rotmat = rot6d_to_rotmat(init_pose_6d)
            init_pose = init_pose_rotmat.reshape(-1).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)

        init_params = (init_cam, init_shape, init_pose)

        self.smpl = SMPL(path_config.SMPL_MODEL_DIR,
                         batch_size=self.options.batch_size,
                         create_transl=False)

        if cfg.DANET.DECOMPOSED:
            print('using decomposed predictor.')
            self.smpl_para_Outs = DecomposedPredictor(feat_in_dim, init_params,
                                                      pretrained)
        else:
            print('using global predictor.')
            self.smpl_para_Outs = GlobalPredictor(feat_in_dim, pretrained)

        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
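For reference, the `smpl_mean_params` file loaded above is conventionally an .npz archive; a hedged sketch of its layout (shapes follow the common SPIN/HMR convention and should be verified per project):

    import numpy as np

    mean = np.load('smpl_mean_params.npz')  # hypothetical path
    print(mean['pose'].shape)   # (144,): 24 joints x 6D rotation
    print(mean['shape'].shape)  # (10,): SMPL betas
    print(mean['cam'].shape)    # (3,): weak-perspective (s, tx, ty)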
Example #4
    def postprocess(self, p_est, b_est):
        """
        Convert 6d rotation to 9d rotation
        Input:
            p_est: pose_tran from forward()
            b_est: bone from forward()
        """
        pose_6d = p_est[:, :-3].contiguous()
        p_est_rot = rot6d_to_rotmat(pose_6d).view(-1, 25 * 9)

        pose = p_est_rot
        tran = p_est[:, -3:]
        bone = b_est

        return pose, tran, bone
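A hedged usage sketch, inferring from the slicing above that `p_est` packs 25 joints in 6D plus a 3D translation (25*6 + 3 = 153 channels); `model` and the bone vector size are assumptions:

    import torch

    p_est = torch.randn(8, 25 * 6 + 3)  # hypothetical pose + translation
    b_est = torch.randn(8, 24)          # hypothetical bone vector
    pose, tran, bone = model.postprocess(p_est, b_est)
    print(pose.shape)  # torch.Size([8, 225]): 25 flattened 3x3 rotations
    print(tran.shape)  # torch.Size([8, 3])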
Example #5
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3):

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        # output the predicted rotation, shape, and cam parameters
        return pred_rotmat, pred_shape, pred_cam
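A hedged call sketch for this HMR-style regressor, assuming a ResNet-50 backbone trained on 224x224 crops (the usual setup for this architecture):

    import torch

    images = torch.randn(4, 3, 224, 224)  # hypothetical batch of crops
    pred_rotmat, pred_shape, pred_cam = model(images, n_iter=3)
    print(pred_rotmat.shape)  # torch.Size([4, 24, 3, 3])
    print(pred_shape.shape)   # torch.Size([4, 10])
    print(pred_cam.shape)     # torch.Size([4, 3])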
Example #6
File: hmr.py  Project: xjwxjw/BOA
    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        # mean_params = np.load(smpl_mean_params)
        # self.init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        # self.init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        # self.init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam
        
        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam
Example #7
    def forward(self, x, scale, init_pose=None, init_shape=None, init_cam=None, n_iter=3):

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x, scale)
        x2 = self.layer2(x1, scale)
        x3 = self.layer3(x2, scale)
        x4 = self.layer4(x3, scale)
        feat_layer4 = x4.view(batch_size, x4.size(1), -1).mean(dim=-1)
        feat_layer4 = self.layer4_mlp(feat_layer4)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.relu(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.relu(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
        feat_list = [feat_layer4]

        return pred_rotmat, pred_shape, pred_cam, feat_list
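A hedged call sketch; unlike Example #5, each residual stage here is additionally conditioned on a `scale` input whose exact semantics are project-specific (the shapes below are assumptions):

    import torch

    images = torch.randn(2, 3, 224, 224)  # hypothetical inputs
    scale = torch.ones(2, 1)              # assumed conditioning signal
    pred_rotmat, pred_shape, pred_cam, feat_list = model(images, scale)
    print(len(feat_list))  # 1: the pooled layer4 feature from layer4_mlp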
Example #8
    def forward(self, body_iuv, limb_iuv):

        return_dict = {}
        return_dict['visualization'] = {}
        return_dict['losses'] = {}

        if cfg.DANET.INPUT_MODE == 'rgb':
            map_channels = body_iuv.size(1) // 3
            body_u = body_iuv[:, :map_channels]
            body_v = body_iuv[:, map_channels:2 * map_channels]
            body_i = body_iuv[:, 2 * map_channels:3 * map_channels]
            global_para, global_feat = self.body_net(
                iuv_map2img(body_u, body_v, body_i))
        elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
            global_para, global_feat = self.body_net(body_iuv)
        elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
            global_para, global_feat = self.body_net(
                torch.cat([body_iuv['iuv'], body_iuv['feat']], dim=1))
        elif cfg.DANET.INPUT_MODE == 'feat':
            global_para, global_feat = self.body_net(body_iuv['feat'])
        elif cfg.DANET.INPUT_MODE == 'seg':
            global_para, global_feat = self.body_net(body_iuv['index'])

        global_para += self.mean_cam_shape

        if cfg.DANET.INPUT_MODE in ['iuv_feat', 'feat', 'iuv_gt_feat']:
            nbs = limb_iuv['pfeat'].size(0)
            limb_mapsize = limb_iuv['pfeat'].size(-1)
        elif cfg.DANET.INPUT_MODE in ['seg']:
            nbs = limb_iuv['pindex'].size(0)
            limb_mapsize = limb_iuv['pindex'].size(-1)
        else:
            nbs = limb_iuv.size(0)
            limb_mapsize = limb_iuv.size(-1)

        if cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
            limb_iuv_stacked = limb_iuv['piuv'].view(nbs * 24, -1,
                                                     limb_mapsize,
                                                     limb_mapsize)
            limb_feat_stacked = limb_iuv['pfeat'].view(nbs * 24, -1,
                                                       limb_mapsize,
                                                       limb_mapsize)
            _, limb_feat = self.limb_net(
                torch.cat([limb_iuv_stacked, limb_feat_stacked], dim=1))
        elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
            limb_iuv_stacked = limb_iuv.view(nbs * 24, -1, limb_mapsize,
                                             limb_mapsize)
            _, limb_feat = self.limb_net(limb_iuv_stacked)
        elif cfg.DANET.INPUT_MODE == 'feat':
            limb_feat_stacked = limb_iuv['pfeat'].view(nbs * 24, -1,
                                                       limb_mapsize,
                                                       limb_mapsize)
            _, limb_feat = self.limb_net(limb_feat_stacked)
        elif cfg.DANET.INPUT_MODE == 'seg':
            limb_feat_stacked = limb_iuv['pindex'].view(
                nbs * 24, -1, limb_mapsize, limb_mapsize)
            _, limb_feat = self.limb_net(limb_feat_stacked)

        limb_feat = limb_feat['x4']
        limb_feat = self.limb_reslayer(
            limb_feat.view(nbs, -1, limb_feat.size(-2), limb_feat.size(-1)))

        rot_feats = limb_feat.view(nbs, 24, -1, limb_feat.size(-2),
                                   limb_feat.size(-1))

        if cfg.DANET.REFINE_STRATEGY == 'lstm_direct':

            return_dict['joint_rotation'] = []

            local_para = self.pose_regressors[0](rot_feats.view(
                rot_feats.size(0), 24 * rot_feats.size(2), 1,
                1)).view(nbs, 24, -1)
            smpl_pose = local_para.view(local_para.size(0), -1)
            return_dict['joint_rotation'].append(smpl_pose)

            for s_i in range(cfg.DANET.REFINEMENT.STACK_NUM):
                pos_feats = {}
                for i in range(24):
                    pos_feats[i] = rot_feats[:, i]

                pos_feats_refined = {}
                for br in range(len(self.limb_branch_lstm)):
                    pos_feat_in = torch.stack(
                        [pos_feats[ind] for ind in self.limb_branch_lstm[br]],
                        dim=1)
                    pos_feat_in = pos_feat_in.squeeze(-1).squeeze(-1)
                    if br == 0:
                        lstm_out, hidden_feat = self.limb_lstm[s_i][0](
                            pos_feat_in)
                    elif br == 1:
                        lstm_out, _ = self.limb_lstm[s_i][0](pos_feat_in,
                                                             hidden_feat)
                    elif br in [2, 3]:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in,
                                                                  hidden_feat)
                    else:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in)
                    for i, ind in enumerate(self.limb_branch_lstm[br]):
                        if ind == 0 and br != 0:
                            continue
                        pos_feats_refined[ind] = lstm_out[:, i].unsqueeze(
                            -1).unsqueeze(-1)

                # update
                for i in range(24):
                    pos_feats[i] = pos_feats[i].repeat(
                        1, 2, 1, 1) + pos_feats_refined[i]

                refined_feat = torch.stack([pos_feats[i] for i in range(24)],
                                           dim=1)
                part_feats = refined_feat.view(refined_feat.size(0),
                                               24 * refined_feat.size(2), 1, 1)

                local_para = self.pose_regressors[s_i + 1](part_feats).view(
                    nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)

        elif cfg.DANET.REFINE_STRATEGY == 'lstm':

            return_dict['joint_position'] = []
            return_dict['joint_rotation'] = []

            if self.training:
                local_para = self.pose_regressors[0](rot_feats.view(
                    rot_feats.size(0), 24 * rot_feats.size(2), 1,
                    1)).view(nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)
                smpl_pose += self.mean_pose
                if cfg.DANET.USE_6D_ROT:
                    smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                        local_para.size(0), -1)
                return_dict['joint_rotation'].append(smpl_pose)

            rot_feats_before = rot_feats

            for s_i in range(cfg.DANET.REFINEMENT.STACK_NUM):
                pos_feats = {}

                pos_feats[0] = rot_feats_before[:, 0]
                for br in range(len(self.limb_branch)):
                    for ind in self.limb_branch[br]:
                        p_ind = self.smpl_parents[0][ind]
                        pos_rot_feat_cat = torch.cat(
                            [pos_feats[p_ind], rot_feats_before[:, p_ind]],
                            dim=1)
                        pos_feats[ind] = self.rot2pos[s_i][ind](
                            pos_rot_feat_cat)

                if self.training:
                    if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                        coord_feats = torch.cat(
                            [pos_feats[i] for i in range(24)], dim=1)
                        smpl_coord = self.coord_regressors[s_i](
                            coord_feats).view(nbs, 24, -1)
                        return_dict['joint_position'].append(smpl_coord)

                pos_feats_refined = {}
                for br in range(len(self.limb_branch_lstm)):
                    pos_feat_in = torch.stack(
                        [pos_feats[ind] for ind in self.limb_branch_lstm[br]],
                        dim=1)
                    pos_feat_in = pos_feat_in.squeeze(-1).squeeze(-1)
                    if br == 0:
                        lstm_out, hidden_feat = self.limb_lstm[s_i][0](
                            pos_feat_in)
                    elif br == 1:
                        lstm_out, _ = self.limb_lstm[s_i][0](pos_feat_in,
                                                             hidden_feat)
                    elif br in [2, 3]:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in,
                                                                  hidden_feat)
                    else:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in)
                    for i, ind in enumerate(self.limb_branch_lstm[br]):
                        if ind == 0 and br != 0:
                            continue
                        pos_feats_refined[ind] = lstm_out[:, i].unsqueeze(
                            -1).unsqueeze(-1)

                # update
                for i in range(24):
                    pos_feats[i] = pos_feats[i].repeat(
                        1, 2, 1, 1) + pos_feats_refined[i]

                if self.training:
                    if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                        coord_feats = torch.cat(
                            [pos_feats[i] for i in range(24)], dim=1)
                        smpl_coord = self.coord_regressors[s_i + 1](
                            coord_feats).view(nbs, 24, -1)
                        return_dict['joint_position'].append(smpl_coord)

                tri_pos_feats = [
                    torch.cat([
                        pos_feats[self.smpl_parents[0][i]], pos_feats[i],
                        pos_feats[self.smpl_children[1][i]]
                    ],
                              dim=1) for i in range(24)
                ]
                tri_pos_feats = torch.cat(tri_pos_feats, dim=0)
                tran_rot_feats = self.pos2rot[s_i](tri_pos_feats)
                tran_rot_feats = tran_rot_feats.view(24, nbs, -1,
                                                     tran_rot_feats.size(-2),
                                                     tran_rot_feats.size(-1))
                tran_rot_feats = tran_rot_feats.transpose(0, 1)

                part_feats = tran_rot_feats.contiguous().view(
                    tran_rot_feats.size(0), 24 * tran_rot_feats.size(2), 1, 1)

                local_para = self.pose_regressors[s_i + 1](part_feats).view(
                    nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)
                smpl_pose += self.mean_pose
                if cfg.DANET.USE_6D_ROT:
                    smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                        local_para.size(0), -1)

        elif cfg.DANET.REFINE_STRATEGY == 'gcn':

            return_dict['joint_position'] = []
            return_dict['joint_rotation'] = []

            if self.training:
                local_para = self.pose_regressors[0](rot_feats.view(
                    rot_feats.size(0), 24 * rot_feats.size(2), 1,
                    1)).view(nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)
                smpl_pose += self.mean_pose
                if cfg.DANET.USE_6D_ROT:
                    smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                        local_para.size(0), -1)
                return_dict['joint_rotation'].append(smpl_pose)

            rot_feats_before = rot_feats

            rot_feats_init = rot_feats_before.squeeze(-1).squeeze(-1)
            pos_feats_init = self.r2p_gcn(rot_feats_init, self.r2p_A[0])

            if self.training:
                if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                    coord_feats0 = pos_feats_init.unsqueeze(2).view(
                        pos_feats_init.size(0),
                        pos_feats_init.size(-1) * 24, 1, 1)
                    smpl_coord0 = self.coord_regressors[0](coord_feats0).view(
                        nbs, 24, -1)
                    return_dict['joint_position'].append(smpl_coord0)

            if cfg.DANET.REFINEMENT.REFINE_ON:
                graph_A = self.A_mask * self.edge_act(self.edge_importance)
                norm_graph_A = normalize_undigraph(self.I_n[0] + graph_A)[0]
                l_pos_feat = self.refine_gcn(pos_feats_init, norm_graph_A)
                l_pos_feat = pos_feats_init + l_pos_feat

                pos_feats_refined = l_pos_feat

                if self.training:
                    if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                        coord_feats1 = pos_feats_refined.unsqueeze(2).view(
                            pos_feats_refined.size(0),
                            pos_feats_refined.size(-1) * 24, 1, 1)
                        smpl_coord1 = self.coord_regressors[1](
                            coord_feats1).view(nbs, 24, -1)
                        return_dict['joint_position'].append(smpl_coord1)
            else:
                pos_feats_refined = pos_feats_init

            rot_feats_refined = self.p2r_gcn(pos_feats_refined, self.p2r_A[0])

            tran_rot_feats = rot_feats_refined.unsqueeze(-1).unsqueeze(-1)
            part_feats = tran_rot_feats.view(tran_rot_feats.size(0),
                                             24 * tran_rot_feats.size(2), 1, 1)

            local_para = self.pose_regressors[-1](part_feats).view(nbs, 24, -1)
            smpl_pose = local_para.view(local_para.size(0), -1)
            smpl_pose += self.mean_pose
            if cfg.DANET.USE_6D_ROT:
                smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                    local_para.size(0), -1)

        elif cfg.DANET.REFINE_STRATEGY == 'gcn_direct':

            return_dict['joint_position'] = []
            return_dict['joint_rotation'] = []

            local_para = self.pose_regressors[0](rot_feats.view(
                rot_feats.size(0), 24 * rot_feats.size(2), 1,
                1)).view(nbs, 24, -1)
            smpl_pose = local_para.view(local_para.size(0), -1)

            if cfg.DANET.REFINEMENT.REFINE_ON:
                return_dict['joint_rotation'].append(smpl_pose)

                pos_feats_init = rot_feats.squeeze(-1).squeeze(-1)

                graph_A = self.A_mask * self.edge_act(self.edge_importance)
                norm_graph_A = normalize_undigraph(self.I_n[0] + graph_A)[0]
                l_pos_feat = self.refine_gcn(pos_feats_init, norm_graph_A)
                l_pos_feat = pos_feats_init + l_pos_feat

                pos_feats_refined = l_pos_feat
                tran_rot_feats = pos_feats_refined.unsqueeze(-1).unsqueeze(-1)

                part_feats = tran_rot_feats.view(tran_rot_feats.size(0),
                                                 24 * tran_rot_feats.size(2),
                                                 1, 1)

                local_para = self.pose_regressors[-1](part_feats).view(
                    nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)

        para = torch.cat([global_para, smpl_pose], dim=1)

        return_dict['para'] = para

        return return_dict
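The GCN branches above call `normalize_undigraph` on the learned adjacency. A minimal sketch of the usual symmetric normalization D^{-1/2} A D^{-1/2}, returned with a leading batch dimension to match the `[0]` indexing in the caller (the project's own helper may differ):

    import torch

    def normalize_undigraph(A):
        # Sketch: symmetric normalization of an undirected joint graph.
        deg = A.sum(dim=-1).clamp(min=1e-6)
        d_inv_sqrt = deg.pow(-0.5)
        DAD = d_inv_sqrt.unsqueeze(-1) * A * d_inv_sqrt.unsqueeze(-2)
        return DAD.unsqueeze(0)  # leading dim so callers can index [0]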
Example #9
    dtype = torch.float32
    batch_size = args.batch_size

    fitting_evaluator = build_evaluator(args, device)
    body_model = build_body_model(args, device)
    smplify = build_smplify3d(args)
    train_dataloader = setup_human36m_dloader(args)

    mean_params = np.load(args.mean_params)
    init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
    init_betas = torch.from_numpy(
        mean_params['shape'][:].astype('float32')).unsqueeze(0)
    init_pose = init_pose.to(device=device).expand(batch_size, -1)
    init_betas = init_betas.to(device=device).expand(batch_size, -1)
    init_pose = rot6d_to_rotmat(init_pose).view(batch_size, 24, 3, 3)
    init_orient = init_pose[:, 0].unsqueeze(1)
    init_pose = init_pose[:, 1:]

    J_regressor_single_batch = torch.from_numpy(np.load(
        args.regressor)).float()

    if os.path.isfile(args.outfile):
        output = np.load(args.outfile, allow_pickle=True)
        opt_pose_output = output['pose']
        opt_betas_output = output['betas']
        opt_orient_output = output['orient']
    else:
        opt_pose_output, opt_betas_output, opt_orient_output = None, None, None

    iterator = enumerate(train_dataloader)
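The resume branch above reads `pose`, `betas`, and `orient` arrays from `args.outfile`; a hedged sketch of the matching save call the fitting loop would presumably issue:

    # Array names are taken from the loading code above; everything
    # else about the checkpointing scheme is an assumption.
    np.savez(args.outfile,
             pose=opt_pose_output,
             betas=opt_betas_output,
             orient=opt_orient_output)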
Example #10
    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3,
                J_regressor=None):
        # here, batch_size = n * T
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size,
                                              -1)  # batch_size X 24*6
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size,
                                                -1)  # batch_size X 10
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)  # batch_size X 3
        # initialize the predicted pose, shape, and cam
        pred_pose = init_pose  # batch_size X 24*6
        pred_shape = init_shape  # batch_size X 10
        pred_cam = init_cam  # batch_size X 3
        # n_iter: number of refinement iterations
        for i in range(n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose  # batch_size X 24*6
            pred_shape = self.decshape(xc) + pred_shape  # batch_size X 10
            pred_cam = self.deccam(xc) + pred_cam  # batch_size X 3
        # rot6d_to_rotmat returns shape (batch_size * 24, 3, 3)
        pred_rotmat = rot6d_to_rotmat(pred_pose).view(
            batch_size, 24, 3, 3)  # reshaped to (batch_size, 24, 3, 3)
        # run the SMPL model to obtain the predicted body
        pred_output = self.smpl(
            betas=pred_shape,
            body_pose=pred_rotmat[:, 1:],  # batch_size X 23 X 3 X 3
            global_orient=pred_rotmat[:, 0].unsqueeze(
                1),  # batch_size X 1 X 3 X 3
            pose2rot=False)

        pred_vertices = pred_output.vertices  # predicted vertices
        pred_joints = pred_output.joints  # predicted joints

        # optionally re-regress joints with an external J_regressor
        if J_regressor is not None:
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
            pred_joints = torch.matmul(J_regressor_batch, pred_vertices)
            pred_joints = pred_joints[:, H36M_TO_J14, :]

        pred_keypoints_2d = projection(pred_joints, pred_cam)

        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(
            -1, 3, 3)).reshape(-1, 72)  # pose -> (batch_size, 72)

        output = [{
            'theta': torch.cat([pred_cam, pose, pred_shape],
                               dim=1),  # -> (batch_size, 85)
            'verts': pred_vertices,
            'kp_2d': pred_keypoints_2d,
            'kp_3d': pred_joints,
            'rotmat': pred_rotmat
        }]
        return output
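A hedged usage sketch for the optional `J_regressor` path, assuming an H3.6M joint regressor of shape (17, 6890) and a pooled 2048-d feature input (both the path and the feature size are assumptions):

    import numpy as np
    import torch

    J_regressor = torch.from_numpy(
        np.load('J_regressor_h36m.npy')).float()  # hypothetical path
    feats = torch.randn(16, 2048)                 # pooled backbone features
    output = regressor(feats, J_regressor=J_regressor)
    print(output[0]['kp_3d'].shape)  # torch.Size([16, 14, 3]) via H36M_TO_J14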