    def forward(self,
                sample,
                scaletrans=None,
                scale=None,
                trans=None,
                rotaxisang=None):
        """
        Args:
            scaletrans: torch.Tensor of shape [batch_size, channels] with channels == 6
                with in first position the predicted scale values and in 2,3 the 
                predicted translation values, and global rotation encoded as axis-angles
                in channel positions 4,5,6
        """
        if scaletrans is None:
            batch_size = scale.shape[0]
        else:
            batch_size = scaletrans.shape[0]
        if scale is None:
            scale = scaletrans[:, :1]
        if trans is None:
            trans = scaletrans[:, 1:3]
        if rotaxisang is None:
            rotaxisang = scaletrans[:, 3:]
        # Get rotation matrices from axis-angles
        rotmat = rodrigues_layer.batch_rodrigues(rotaxisang).view(
            rotaxisang.shape[0], 3, 3)
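        # Rotate the canonical object vertices by the predicted global rotation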
        canobjverts = sample[BaseQueries.OBJCANVERTS].cuda()
        rotobjverts = rotmat.bmm(
            canobjverts.float().transpose(1, 2)).transpose(1, 2)

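        # Map raw network outputs to metric translation and scale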
        final_trans = trans.unsqueeze(1) * self.trans_factor
        final_scale = scale.view(batch_size, 1, 1) * self.scale_factor
        height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
        camintr = sample[TransQueries.CAMINTR].cuda()
        objverts3d, center3d = project.recover_3d_proj(rotobjverts,
                                                       camintr,
                                                       final_scale,
                                                       final_trans,
                                                       input_res=(width,
                                                                  height))
        # Recover 2D positions given camera intrinsic parameters and object vertex
        # coordinates in camera coordinate reference
        pred_objverts2d = camproject.batch_proj2d(objverts3d, camintr)
        if BaseQueries.OBJCORNERS3D in sample:
            canobjcorners = sample[BaseQueries.OBJCANCORNERS].cuda()
            rotobjcorners = rotmat.bmm(
                canobjcorners.float().transpose(1, 2)).transpose(1, 2)
            recov_objcorners3d = rotobjcorners + center3d
            pred_objcorners2d = camproject.batch_proj2d(
                recov_objcorners3d, camintr)
        else:
            pred_objcorners2d = None
            recov_objcorners3d = None
            rotobjcorners = None
        return {
            "obj_verts2d": pred_objverts2d,
            "obj_verts3d": rotobjverts,
            "recov_objverts3d": objverts3d,
            "recov_objcorners3d": recov_objcorners3d,
            "obj_scale": final_scale,
            "obj_prescale": scale,
            "obj_prerot": rotaxisang,
            "obj_trans": final_trans,
            "obj_pretrans": trans,
            "obj_corners2d": pred_objcorners2d,
            "obj_corners3d": rotobjcorners,
        }
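For reference, here is a minimal, self-contained sketch of the batched axis-angle conversion that rodrigues_layer.batch_rodrigues is assumed to perform above. The name batch_rodrigues_sketch and the eps clamp are illustrative additions, and the sketch returns [batch_size, 3, 3] directly, whereas the call above reshapes its output with .view:

import torch

def batch_rodrigues_sketch(axisang, eps=1e-8):
    # axisang: [batch_size, 3]; the angle is the norm, the axis its direction
    angle = axisang.norm(dim=1, keepdim=True).clamp(min=eps)  # [B, 1]
    axis = axisang / angle                                    # [B, 3]
    cos = torch.cos(angle).unsqueeze(2)                       # [B, 1, 1]
    sin = torch.sin(angle).unsqueeze(2)
    # Skew-symmetric cross-product matrix K of each rotation axis
    zeros = torch.zeros_like(axis[:, 0])
    kx, ky, kz = axis[:, 0], axis[:, 1], axis[:, 2]
    K = torch.stack(
        [zeros, -kz, ky, kz, zeros, -kx, -ky, kx, zeros],
        dim=1).view(-1, 3, 3)
    eye = torch.eye(3, device=axisang.device).unsqueeze(0)
    # Rodrigues formula: R = I + sin(theta) * K + (1 - cos(theta)) * K @ K
    return eye + sin * K + (1 - cos) * K.bmm(K)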
Example #2
    def recover_mano(
        self,
        sample,
        encoder_output=None,
        pose=None,
        shape=None,
        no_loss=False,
        total_loss=None,
        scale=None,
        trans=None,
    ):
        # Get hand projection, centered
        mano_results = self.mano_branch(encoder_output,
                                        sides=sample[BaseQueries.SIDE],
                                        pose=pose,
                                        shape=shape)
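        # Optionally refine the MANO joints with the adaptor and re-center
        # joints and vertices on the reference joint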
        if self.adaptor:
            adapt_joints, _ = self.adaptor(mano_results["verts3d"])
            adapt_joints = adapt_joints.transpose(1, 2)
            center = adapt_joints[:, self.mano_center_idx].unsqueeze(1)
            mano_results["joints3d"] = adapt_joints - center
            mano_results["verts3d"] = mano_results["verts3d"] - center
        # mano_losses must exist even when loss computation is skipped
        mano_losses = {}
        if not no_loss:
            mano_total_loss, mano_losses = self.mano_loss.compute_loss(
                mano_results, sample)
            if total_loss is None:
                total_loss = mano_total_loss
            else:
                total_loss += mano_total_loss
            mano_losses["mano_total_loss"] = mano_total_loss.clone()

        # Recover hand position in camera coordinates
        if (self.mano_lambda_joints2d or self.mano_lambda_verts2d
                or self.mano_lambda_recov_joints3d
                or self.mano_lambda_recov_verts3d):
            if scale is None or trans is None:
                scaletrans = self.scaletrans_branch(encoder_output)
                if trans is None:
                    trans = scaletrans[:, 1:]
                if scale is None:
                    scale = scaletrans[:, :1]
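            # Convert raw scale/translation predictions to metric units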
            final_trans = trans.unsqueeze(1) * self.obj_trans_factor
            final_scale = scale.view(scale.shape[0], 1,
                                     1) * self.obj_scale_factor
            height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
            camintr = sample[TransQueries.CAMINTR].cuda()
            recov_joints3d, center3d = project.recover_3d_proj(
                mano_results["joints3d"],
                camintr,
                final_scale,
                final_trans,
                input_res=(width, height))
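            # Shift the centered hand vertices by the recovered center to
            # express them in camera coordinates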
            recov_hand_verts3d = mano_results["verts3d"] + center3d
            proj_joints2d = camproject.batch_proj2d(recov_joints3d, camintr)
            proj_verts2d = camproject.batch_proj2d(recov_hand_verts3d, camintr)

            mano_results["joints2d"] = proj_joints2d
            mano_results["recov_joints3d"] = recov_joints3d
            mano_results["recov_handverts3d"] = recov_hand_verts3d
            mano_results["hand_center3d"] = center3d
            mano_results["verts2d"] = proj_verts2d
            mano_results["hand_pretrans"] = trans
            mano_results["hand_prescale"] = scale
            mano_results["hand_trans"] = final_trans
            mano_results["hand_scale"] = final_scale
            if not no_loss:
                # Compute hand losses in pixel space and camera coordinates
                if self.mano_lambda_joints2d is not None and TransQueries.JOINTS2D in sample:
                    gt_joints2d = sample[TransQueries.JOINTS2D].cuda().float()
                    if self.criterion2d == "l2":
                        # Normalize predictions in pixel space so that results are roughly centered
                        # and have magnitude ~1
                        norm_joints2d_pred = normalize_pixel_out(proj_joints2d)
                        norm_joints2d_gt = normalize_pixel_out(gt_joints2d)
                        joints2d_loss = torch_f.mse_loss(
                            norm_joints2d_pred, norm_joints2d_gt)
                    elif self.criterion2d == "l1":
                        joints2d_loss = torch_f.l1_loss(
                            proj_joints2d, gt_joints2d)
                    elif self.criterion2d == "smoothl1":
                        joints2d_loss = torch_f.smooth_l1_loss(
                            proj_joints2d, gt_joints2d)
                    total_loss += self.mano_lambda_joints2d * joints2d_loss
                    mano_losses["joints2d"] = joints2d_loss
                if self.mano_lambda_verts2d is not None and TransQueries.HANDVERTS2D in sample:
                    gt_verts2d = sample[
                        TransQueries.HANDVERTS2D].cuda().float()
                    verts2d_loss = torch_f.mse_loss(
                        normalize_pixel_out(proj_verts2d, self.inp_res),
                        normalize_pixel_out(gt_verts2d, self.inp_res),
                    )
                    total_loss += self.mano_lambda_verts2d * verts2d_loss
                    mano_losses["verts2d"] = verts2d_loss
                if self.mano_lambda_recov_joints3d is not None and BaseQueries.JOINTS3D in sample:
                    joints3d_gt = sample[BaseQueries.JOINTS3D].cuda()
                    recov_loss = torch_f.mse_loss(recov_joints3d, joints3d_gt)
                    total_loss += self.mano_lambda_recov_joints3d * recov_loss
                    mano_losses["recov_joint3d"] = recov_loss
                if self.mano_lambda_recov_verts3d is not None and BaseQueries.HANDVERTS3D in sample:
                    hand_verts3d_gt = sample[BaseQueries.HANDVERTS3D].cuda()
                    recov_loss = torch_f.mse_loss(recov_hand_verts3d,
                                                  hand_verts3d_gt)
                    total_loss += self.mano_lambda_recov_verts3d * recov_loss
        return mano_results, total_loss, mano_losses
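Both examples rely on camproject.batch_proj2d to map camera-frame points to pixel coordinates. Below is a minimal sketch of the pinhole projection it presumably implements; batch_proj2d_sketch is a hypothetical name, and the convention (points as [batch_size, n_pts, 3], intrinsics as [batch_size, 3, 3]) is inferred from the calls above:

import torch

def batch_proj2d_sketch(points3d, camintr):
    # points3d: [B, N, 3] camera-frame points, camintr: [B, 3, 3] intrinsics
    # Apply the intrinsics to every point: K @ X
    hom2d = camintr.bmm(points3d.transpose(1, 2)).transpose(1, 2)  # [B, N, 3]
    # Perspective divide by depth z to obtain pixel coordinates
    return hom2d[:, :, :2] / hom2d[:, :, 2:]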