Exemplo n.º 1
0
    def cal_loss(self, xhr, cam_ext):
        """Compute reconstruction/VPoser losses plus contact and non-collision scores.

        Args:
            xhr: target parameters, compared with an L1 loss against ``self.xhr_rec``.
            cam_ext: camera extrinsics forwarded to ``GeometryTransformer.verts_transform``.

        Returns:
            Tuple ``(loss_rec, loss_vposer, loss_contact, loss_collision)``:
              * ``loss_rec``: ``self.weight_loss_rec`` * L1(xhr, self.xhr_rec).
              * ``loss_vposer``: ``self.weight_loss_vposer`` * mean of squares of
                columns 16:48 of the 3D-rotation parameters (presumably the VPoser
                latent -- confirm against ``BodyParamParser``).
              * ``loss_contact``: 1.0 if any body vertex samples a negative scene
                SDF value, else 0.0.
              * ``loss_collision``: despite the name, a *non*-collision score --
                the fraction of sampled SDF values that are strictly positive
                (1.0 when no vertex has a negative SDF).
        """
        ### reconstruction loss
        loss_rec = self.weight_loss_rec * F.l1_loss(xhr, self.xhr_rec)
        xh_rec = GeometryTransformer.convert_to_3D_rot(self.xhr_rec)

        ### vposer loss
        vposer_pose = xh_rec[:, 16:48]
        loss_vposer = self.weight_loss_vposer * torch.mean(vposer_pose ** 2)

        ### build the body mesh from the reconstructed parameters
        body_param_rec = BodyParamParser.body_params_encapsulate_batch(xh_rec)
        joint_rot_batch = self.vposer.decode(body_param_rec['body_pose_vp'],
                                             output_type='aa').view(self.batch_size, -1)
        # The decoded axis-angle pose is passed as body_pose, so the latent entry
        # itself must not be forwarded to the mesh model.
        body_param_ = {key: body_param_rec[key] for key in body_param_rec.keys()
                       if key != 'body_pose_vp'}

        smplx_output = self.body_mesh_model(return_verts=True,
                                            body_pose=joint_rot_batch,
                                            **body_param_)
        body_verts_batch = smplx_output.vertices  # [b, n_verts, 3]
        body_verts_batch = GeometryTransformer.verts_transform(body_verts_batch, cam_ext)

        ### sample the scene SDF at every body vertex
        s_grid_min_batch = self.s_grid_min_batch.unsqueeze(1)
        s_grid_max_batch = self.s_grid_max_batch.unsqueeze(1)

        # Map vertices into the [-1, 1] normalized coordinates used by grid_sample.
        norm_verts_batch = (body_verts_batch - s_grid_min_batch) / (s_grid_max_batch - s_grid_min_batch) * 2 - 1
        n_verts = norm_verts_batch.shape[1]
        # grid_sample reads 5-D grid coords in (W, H, D) order; reversing xyz makes
        # vertex x index the volume's first spatial axis (assumed SDF layout -- confirm).
        body_sdf_batch = F.grid_sample(self.s_sdf_batch.unsqueeze(1),
                                       norm_verts_batch[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3),
                                       padding_mode='border')

        ### contact score and non-collision score
        # Normalize by the real number of samples (batch * vertices) rather than a
        # hard-coded SMPL-X vertex count, so the score stays in [0, 1] for any
        # batch size or mesh resolution.
        n_samples = body_sdf_batch.numel()
        if body_sdf_batch.lt(0).sum().item() < 1:
            # No vertex lies inside the scene: no contact, fully collision-free.
            loss_contact = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            n_free = torch.tensor(n_samples, dtype=torch.float32, device=self.device)
        else:
            loss_contact = torch.tensor(1.0, dtype=torch.float32, device=self.device)
            n_free = (body_sdf_batch > 0).sum()

        loss_collision = n_free.float() / float(n_samples)

        return loss_rec, loss_vposer, loss_contact, loss_collision
Exemplo n.º 2
0
    def cal_loss(self, xs, xh, eps_g, eps_l, cam_ext, cam_int, max_d,
                 scene_verts, scene_face,
                 s_grid_min_batch, s_grid_max_batch, s_grid_sdf_batch,
                 ep):
        """Training losses for the body generator conditioned on the scene ``xs``.

        Returns a 7-tuple ``(loss_rec_t, loss_rec_p, loss_KL_g, loss_KL_l,
        loss_contact, loss_vposer, loss_sdf_pene)``:
          * translation reconstruction: ``self.weight_loss_rec_h`` times an equal
            blend of the L1 error in normalized space and after recovering the
            global translation;
          * remaining-parameter reconstruction: weighted L1 on columns 3: in the
            normalized 6D-rotation space;
          * global and local KL terms, each scaled by ``fca**2`` (linearly
            annealed over the first 75% of ``self.epoch`` when
            ``self.loss_weight_anealing`` is set, else 1);
          * contact and SDF-penetration terms, which are only active (factor 1.0)
            once ``ep > 0.75 * self.epoch``;
          * a weighted L2 penalty on columns 16:48 of the recovered parameters.

        ``cam_ext`` transforms the generated vertices; ``scene_face`` is unused here.
        """
        # --- encode / decode in normalized-translation, 6D-rotation space
        xhnr = GeometryTransformer.convert_to_6D_rot(
            GeometryTransformer.normalize_global_T(xh, cam_int, max_d))
        xhnr_rec, mu_g, logsigma2_g, mu_l, logsigma2_l = self.model_h(xhnr, eps_g, eps_l, xs)
        xh_rec = GeometryTransformer.recover_global_T(
            GeometryTransformer.convert_to_3D_rot(xhnr_rec), cam_int, max_d)

        # --- reconstruction terms
        trans_err_norm = F.l1_loss(xhnr_rec[:, :3], xhnr[:, :3])
        trans_err_abs = F.l1_loss(xh_rec[:, :3], xh[:, :3])
        loss_rec_t = self.weight_loss_rec_h * (0.5 * trans_err_norm + 0.5 * trans_err_abs)
        loss_rec_p = self.weight_loss_rec_h * F.l1_loss(xhnr_rec[:, 3:], xhnr[:, 3:])

        # --- KL divergence terms (optionally annealed)
        if self.loss_weight_anealing:
            fca = min(1.0, max(float(ep) / (self.epoch * 0.75), 0))
        else:
            fca = 1.0
        kl_scale = fca**2 * self.weight_loss_kl * 0.5

        def _kl(mu, logsigma2):
            # KL(N(mu, sigma^2) || N(0, 1)), averaged over all elements.
            return kl_scale * torch.mean(torch.exp(logsigma2) + mu**2 - 1.0 - logsigma2)

        loss_KL_g = _kl(mu_g, logsigma2_g)
        loss_KL_l = _kl(mu_l, logsigma2_l)

        # --- pose prior
        loss_vposer = self.weight_loss_vposer * torch.mean(xh_rec[:, 16:48] ** 2)

        # --- posed, camera-transformed body mesh
        param_dict = BodyParamParser.body_params_encapsulate_batch(xh_rec)
        pose_aa = self.vposer.decode(param_dict['body_pose_vp'],
                                     output_type='aa').view(self.batch_size, -1)
        mesh_kwargs = {name: param_dict[name] for name in param_dict.keys()
                       if name not in ['body_pose_vp']}
        verts = self.body_mesh_model(return_verts=True,
                                     body_pose=pose_aa,
                                     **mesh_kwargs).vertices  # [b, 10475, 3]
        verts = GeometryTransformer.verts_transform(verts, cam_ext)

        # Contact and penetration terms switch on in the last quarter of training.
        late_stage = 1.0 if ep > 0.75 * self.epoch else 0.0

        # --- contact: robust chamfer distance from annotated contact vertices to the scene
        vid, _ = GeometryTransformer.get_contact_id(body_segments_folder=self.contact_id_folder,
                                                    contact_body_parts=self.contact_part)
        chamfer = ext.chamferDist()
        contact_dist, _ = chamfer(verts[:, vid, :].contiguous(),
                                  scene_verts.contiguous())
        loss_contact = late_stage * self.weight_contact * torch.mean(
            torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0))

        # --- scene penetration: sample the SDF grid at every vertex
        grid_lo = s_grid_min_batch.unsqueeze(1)
        grid_hi = s_grid_max_batch.unsqueeze(1)
        grid_coords = (verts - grid_lo) / (grid_hi - grid_lo) * 2 - 1
        num_verts = grid_coords.shape[1]
        sdf_at_verts = F.grid_sample(s_grid_sdf_batch.unsqueeze(1),
                                     grid_coords[:, :, [2, 1, 0]].view(-1, num_verts, 1, 1, 3),
                                     padding_mode='border')

        inside = sdf_at_verts < 0
        if inside.sum().item() >= 1:
            pene = sdf_at_verts[inside].abs().mean()
        else:
            # nothing penetrates: zero penalty
            pene = torch.tensor(0.0, dtype=torch.float32, device=self.device)

        loss_sdf_pene = late_stage * self.weight_collision * pene

        return loss_rec_t, loss_rec_p, loss_KL_g, loss_KL_l, loss_contact, loss_vposer, loss_sdf_pene
Exemplo n.º 3
0
    def cal_loss(self, xhr, cam_ext):
        """Losses used while refining ``self.xhr_rec`` against the scene.

        Returns ``(loss_rec, loss_vposer, loss_contact, loss_collision)``:
          * weighted L1 between ``xhr`` and ``self.xhr_rec``;
          * weighted mean-square of columns 16:48 of the 3D-rotation parameters;
          * weighted robust chamfer term ``sqrt(d+1e-4) / (sqrt(d+1e-4) + 1)``
            between the contact-part vertices and ``self.s_verts_batch``;
          * ``self.weight_collision`` times the mean |SDF| over vertices with
            negative scene SDF (0 when none penetrate).
        """
        loss_rec = self.weight_loss_rec * F.l1_loss(xhr, self.xhr_rec)
        params_3d = GeometryTransformer.convert_to_3D_rot(self.xhr_rec)

        latent = params_3d[:, 16:48]
        loss_vposer = self.weight_loss_vposer * torch.mean(latent ** 2)

        # --- posed, camera-transformed body mesh
        param_dict = BodyParamParser.body_params_encapsulate_batch(params_3d)
        pose_aa = self.vposer.decode(param_dict['body_pose_vp'],
                                     output_type='aa').view(self.batch_size, -1)
        mesh_kwargs = {name: param_dict[name] for name in param_dict.keys()
                       if name not in ['body_pose_vp']}
        verts = self.body_mesh_model(return_verts=True,
                                     body_pose=pose_aa,
                                     **mesh_kwargs).vertices  # [b, 10475, 3]
        verts = GeometryTransformer.verts_transform(verts, cam_ext)

        # --- contact term
        vid, _ = GeometryTransformer.get_contact_id(
            body_segments_folder=self.contact_id_folder,
            contact_body_parts=self.contact_part)
        chamfer = ext.chamferDist()
        contact_dist, _ = chamfer(verts[:, vid, :].contiguous(),
                                  self.s_verts_batch.contiguous())
        loss_contact = self.weight_contact * torch.mean(
            torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0))

        # --- SDF collision term
        grid_lo = self.s_grid_min_batch.unsqueeze(1)
        grid_hi = self.s_grid_max_batch.unsqueeze(1)
        grid_coords = (verts - grid_lo) / (grid_hi - grid_lo) * 2 - 1
        num_verts = grid_coords.shape[1]
        sdf_at_verts = F.grid_sample(self.s_sdf_batch.unsqueeze(1),
                                     grid_coords[:, :, [2, 1, 0]].view(-1, num_verts, 1, 1, 3),
                                     padding_mode='border')

        inside = sdf_at_verts < 0
        if inside.sum().item() >= 1:
            pene = sdf_at_verts[inside].abs().mean()
        else:
            # nothing penetrates: zero penalty
            pene = torch.tensor(0.0, dtype=torch.float32, device=self.device)

        return loss_rec, loss_vposer, loss_contact, self.weight_collision * pene