Beispiel #1
0
    def fitting(self, input_data_file):
        """Refine a pickled body configuration by iterative optimization.

        The body parameters stored in ``input_data_file`` are parsed,
        converted to the 6D-rotation form and copied into ``self.xhr_rec``.
        ``self.optimizer`` is then stepped ``self.num_iter`` times on the sum
        of the four terms returned by ``self.cal_loss``.

        Side effects: overwrites ``self.cam_ext``, ``self.cam_int`` and
        ``self.xhr_rec.data``.

        Args:
            input_data_file: path of the pickle file holding the body
                parameters to start from.

        Returns:
            ``self.xhr_rec`` after the last step, converted back with
            ``GeometryTransformer.convert_to_3D_rot``.
        """
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(input_data_file, 'rb') as handle:
            raw_params = pickle.load(handle)

        parsed = BodyParamParser.body_params_parse_fitting(raw_params)
        xh, self.cam_ext, self.cam_int = parsed
        xhr = GeometryTransformer.convert_to_6D_rot(xh)

        # Seed the optimized variable with a copy of the input so that the
        # reconstruction term compares against the untouched original.
        # (presumably self.optimizer was built over self.xhr_rec -- confirm)
        self.xhr_rec.data = xhr.clone()

        log_template = ('[INFO][fitting] iter={:d}, l_rec={:f}, '
                        'l_vposer={:f}, l_contact={:f}, l_collision={:f}')

        for step in range(self.num_iter):
            self.optimizer.zero_grad()

            terms = self.cal_loss(xhr, self.cam_ext)
            l_rec, l_vposer, l_contact, l_collision = terms
            total = l_rec + l_vposer + l_contact + l_collision

            if self.verbose:
                print(log_template.format(step,
                                          l_rec.item(),
                                          l_vposer.item(),
                                          l_contact.item(),
                                          l_collision.item()))

            total.backward(retain_graph=True)
            self.optimizer.step()

        print('[INFO][fitting] fitting finish, returning optimal value')

        return GeometryTransformer.convert_to_3D_rot(self.xhr_rec)
Beispiel #2
0
    def cal_loss(self, xhr, cam_ext):
        """Compute reconstruction/vposer terms plus contact and collision scores.

        The body described by ``self.xhr_rec`` is turned into a mesh, moved
        with ``cam_ext`` and looked up in the scene SDF grid
        (``self.s_sdf_batch``).

        Args:
            xhr: reference body representation compared against
                ``self.xhr_rec`` with an L1 loss.
            cam_ext: transform handed to
                ``GeometryTransformer.verts_transform`` for the mesh vertices.

        Returns:
            Tuple ``(loss_rec, loss_vposer, loss_contact, loss_collision)``:
            ``loss_contact`` is 1.0 when at least one sampled SDF value is
            negative and 0.0 otherwise; ``loss_collision`` is the number of
            strictly positive SDF samples divided by 10475, or 1.0 when no
            sample is negative.
        """
        ### reconstruction term
        loss_rec = self.weight_loss_rec * F.l1_loss(xhr, self.xhr_rec)
        xh_rec = GeometryTransformer.convert_to_3D_rot(self.xhr_rec)

        ### vposer term: mean squared value of columns 16..47
        loss_vposer = self.weight_loss_vposer * torch.mean(xh_rec[:, 16:48] ** 2)

        ### body mesh
        body_param_rec = BodyParamParser.body_params_encapsulate_batch(xh_rec)
        decoded_pose = self.vposer.decode(body_param_rec['body_pose_vp'],
                                          output_type='aa')
        joint_rot_batch = decoded_pose.view(self.batch_size, -1)

        # Everything except the vposer latent is forwarded to the body model;
        # the decoded pose takes its place as ``body_pose``.
        mesh_kwargs = {key: body_param_rec[key]
                       for key in body_param_rec.keys()
                       if key != 'body_pose_vp'}

        smplx_output = self.body_mesh_model(return_verts=True,
                                            body_pose=joint_rot_batch,
                                            **mesh_kwargs)
        body_verts_batch = GeometryTransformer.verts_transform(
            smplx_output.vertices,  # [b, 10475,3]
            cam_ext)

        ### SDF lookup: map vertices into the [-1, 1] range of the grid
        grid_lo = self.s_grid_min_batch.unsqueeze(1)
        grid_hi = self.s_grid_max_batch.unsqueeze(1)
        norm_verts_batch = (body_verts_batch - grid_lo) / (grid_hi - grid_lo) * 2 - 1

        n_verts = norm_verts_batch.shape[1]
        # The coordinate order is reversed before sampling.
        sample_points = norm_verts_batch[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3)
        body_sdf_batch = F.grid_sample(self.s_sdf_batch.unsqueeze(1),
                                       sample_points,
                                       padding_mode='border')

        ### contact score and non-collision score
        any_negative = body_sdf_batch.lt(0).sum().item() >= 1
        if any_negative:
            loss_contact = torch.tensor(1.0, dtype=torch.float32, device=self.device)
            positive_count = (body_sdf_batch > 0).sum()
        else:
            # No negative SDF sample at all: no contact, full non-collision.
            loss_contact = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            positive_count = torch.tensor(10475, dtype=torch.float32, device=self.device)

        # NOTE(review): the count covers the whole batch but is divided by a
        # fixed 10475 -- looks like batch size 1 is assumed; confirm.
        loss_collision = positive_count.float() / 10475.0

        return loss_rec, loss_vposer, loss_contact, loss_collision
Beispiel #3
0
    def eval_colllision(self):
        """Collect collision and contact scores for the generated bodies.

        Looks for ``body_gen_000000.pkl`` .. ``body_gen_007999.pkl`` inside
        ``self.input_data_file`` (used here as a directory), skips the indices
        that do not exist, and evaluates ``self.cal_loss`` for each remaining
        file.

        Side effects: overwrites ``self.cam_ext``, ``self.cam_int`` and
        rebinds ``self.xhr_rec`` for every processed file.

        Returns:
            Tuple ``(coll_list, cont_list)`` of Python floats: the fourth and
            third value returned by ``self.cal_loss`` for each file, in index
            order.
        """
        collision_scores = []
        contact_scores = []

        for index in range(8000):
            pkl_path = os.path.join(self.input_data_file,
                                    'body_gen_{:06d}.pkl'.format(index))
            if not os.path.exists(pkl_path):
                continue

            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with open(pkl_path, 'rb') as handle:
                raw_params = pickle.load(handle)

            parsed = BodyParamParser.body_params_parse_fitting(raw_params)
            xh, self.cam_ext, self.cam_int = parsed
            xhr = GeometryTransformer.convert_to_6D_rot(xh)
            self.xhr_rec = xhr

            # Homogeneous transform negating the 2nd and 3rd axes, applied on
            # the right of the first camera extrinsic matrix.
            axis_flip = torch.tensor(np.diag([1.0, -1.0, -1.0, 1.0]),
                                     dtype=torch.float32,
                                     device=self.device).unsqueeze(0)
            trans = torch.matmul(self.cam_ext[:1], axis_flip)

            _, _, score_contact, score_collision = self.cal_loss(xhr, trans)

            collision_scores.append(score_collision.item())
            contact_scores.append(score_contact.item())

        return collision_scores, contact_scores
Beispiel #4
0
    def fitting(self, input_data_file):
        """Refine a pickled body configuration under axis-flipped extrinsics.

        The body parameters in ``input_data_file`` are parsed, converted to
        the 6D-rotation form and copied into ``self.xhr_rec``. The first
        camera extrinsic matrix is multiplied on the right by a transform that
        negates the 2nd and 3rd axes, and the result is passed to
        ``self.cal_loss`` for ``self.num_iter`` optimizer steps.

        Side effects: overwrites ``self.cam_ext``, ``self.cam_int`` and
        ``self.xhr_rec.data``.

        Args:
            input_data_file: path of the pickle file holding the body
                parameters to start from.

        Returns:
            ``self.xhr_rec`` after the last step, converted back with
            ``GeometryTransformer.convert_to_3D_rot``.
        """
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(input_data_file, 'rb') as handle:
            raw_params = pickle.load(handle)

        parsed = BodyParamParser.body_params_parse_fitting(raw_params)
        xh, self.cam_ext, self.cam_int = parsed
        xhr = GeometryTransformer.convert_to_6D_rot(xh)
        # (presumably self.optimizer was built over self.xhr_rec -- confirm)
        self.xhr_rec.data = xhr.clone()

        # diag(1, -1, -1, 1) as a [1, 4, 4] float32 tensor on the working device.
        axis_flip = torch.tensor(np.diag([1.0, -1.0, -1.0, 1.0]),
                                 dtype=torch.float32,
                                 device=self.device).unsqueeze(0)
        trans = torch.matmul(self.cam_ext[:1], axis_flip)

        log_template = ('[INFO][fitting] iter={:d}, l_rec={:f}, '
                        'l_vposer={:f}, l_contact={:f}, l_collision={:f}')

        for step in range(self.num_iter):
            self.optimizer.zero_grad()

            terms = self.cal_loss(xhr, trans)
            l_rec, l_vposer, l_contact, l_collision = terms
            total = l_rec + l_vposer + l_contact + l_collision

            if self.verbose:
                print(log_template.format(step,
                                          l_rec.item(),
                                          l_vposer.item(),
                                          l_contact.item(),
                                          l_collision.item()))

            total.backward(retain_graph=True)
            self.optimizer.step()

        print('[INFO][fitting] fitting finish, returning optimal value')

        return GeometryTransformer.convert_to_3D_rot(self.xhr_rec)
Beispiel #5
0
    def cal_loss(self, xs, xh, eps_g, eps_l, cam_ext, cam_int, max_d,
                 scene_verts, scene_face,
                 s_grid_min_batch, s_grid_max_batch, s_grid_sdf_batch,
                 ep):
        """Compute all training loss terms for one batch.

        ``xh`` is normalized, converted to the 6D-rotation form and passed
        through ``self.model_h`` together with ``eps_g``, ``eps_l`` and
        ``xs``; the reconstruction is converted back and compared with the
        input. The reconstructed body mesh is then scored against the scene
        vertices (contact) and the scene SDF grid (penetration).

        Args:
            xs: scene input forwarded to ``self.model_h``.
            xh: ground-truth body representation.
            eps_g, eps_l: extra inputs forwarded to ``self.model_h``
                (presumably latent noise samples -- confirm).
            cam_ext: transform applied to the body mesh vertices.
            cam_int, max_d: arguments of the global-translation
                normalization and its inverse.
            scene_verts: scene point set used for the chamfer distance.
            scene_face: not used by this method.
            s_grid_min_batch, s_grid_max_batch: bounds used to map vertices
                into the [-1, 1] range of the SDF grid.
            s_grid_sdf_batch: SDF grid sampled at the body vertices.
            ep: current epoch; drives the KL annealing factor and switches
                the contact/penetration terms on once
                ``ep > 0.75 * self.epoch``.

        Returns:
            Tuple ``(loss_rec_t, loss_rec_p, loss_KL_g, loss_KL_l,
            loss_contact, loss_vposer, loss_sdf_pene)``.
        """
        # normalize global trans, then switch to the 6D rotation form
        xhn = GeometryTransformer.normalize_global_T(xh, cam_int, max_d)
        xhnr = GeometryTransformer.convert_to_6D_rot(xhn)

        model_out = self.model_h(xhnr, eps_g, eps_l, xs)
        [xhnr_rec, mu_g, logsigma2_g, mu_l, logsigma2_l] = model_out

        # back to the 3D rotation form and recover global trans
        xhn_rec = GeometryTransformer.convert_to_3D_rot(xhnr_rec)
        xh_rec = GeometryTransformer.recover_global_T(xhn_rec, cam_int, max_d)

        ### reconstruction losses: first three columns vs. the rest
        rec_t_normalized = F.l1_loss(xhnr_rec[:, :3], xhnr[:, :3])
        rec_t_recovered = F.l1_loss(xh_rec[:, :3], xh[:, :3])
        loss_rec_t = self.weight_loss_rec_h * (0.5 * rec_t_normalized
                                               + 0.5 * rec_t_recovered)
        loss_rec_p = self.weight_loss_rec_h * F.l1_loss(xhnr_rec[:, 3:], xhnr[:, 3:])

        ### kl divergence losses, optionally annealed over the first 75% of epochs
        if self.loss_weight_anealing:
            fca = min(1.0, max(float(ep) / (self.epoch * 0.75), 0))
        else:
            fca = 1.0
        kl_scale = fca**2 * self.weight_loss_kl * 0.5

        def kl_term(mu, logsigma2):
            # Mean over all entries of exp(logsigma2) + mu^2 - 1 - logsigma2.
            return kl_scale * torch.mean(torch.exp(logsigma2) + mu**2 - 1.0 - logsigma2)

        loss_KL_g = kl_term(mu_g, logsigma2_g)
        loss_KL_l = kl_term(mu_l, logsigma2_l)

        ### Vposer loss: mean squared value of columns 16..47
        loss_vposer = self.weight_loss_vposer * torch.mean(xh_rec[:, 16:48] ** 2)

        ### contact loss
        ## (1) get the reconstructed body mesh
        body_param_rec = BodyParamParser.body_params_encapsulate_batch(xh_rec)
        decoded_pose = self.vposer.decode(body_param_rec['body_pose_vp'],
                                          output_type='aa')
        joint_rot_batch = decoded_pose.view(self.batch_size, -1)

        # Everything except the vposer latent is forwarded to the body model.
        mesh_kwargs = {key: body_param_rec[key]
                       for key in body_param_rec.keys()
                       if key != 'body_pose_vp'}

        smplx_output = self.body_mesh_model(return_verts=True,
                                            body_pose=joint_rot_batch,
                                            **mesh_kwargs)
        body_verts_batch = GeometryTransformer.verts_transform(
            smplx_output.vertices,  # [b, 10475,3]
            cam_ext)

        ## (2) select contact vertex according to prox annotation
        vid, _ = GeometryTransformer.get_contact_id(
            body_segments_folder=self.contact_id_folder,
            contact_body_parts=self.contact_part)
        body_verts_contact_batch = body_verts_batch[:, vid, :]

        ## (3) chamfer distance from the contact vertices to the scene
        chamfer = ext.chamferDist()
        contact_dist, _ = chamfer(body_verts_contact_batch.contiguous(),
                                  scene_verts.contiguous())

        # Contact and penetration terms are only active late in training.
        late_gate = 1.0 if ep > 0.75 * self.epoch else 0.0

        loss_contact = late_gate * self.weight_contact * torch.mean(
            torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0))

        ### SDF scene penetration loss
        grid_lo = s_grid_min_batch.unsqueeze(1)
        grid_hi = s_grid_max_batch.unsqueeze(1)
        norm_verts_batch = (body_verts_batch - grid_lo) / (grid_hi - grid_lo) * 2 - 1

        n_verts = norm_verts_batch.shape[1]
        # The coordinate order is reversed before sampling.
        sample_points = norm_verts_batch[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3)
        body_sdf_batch = F.grid_sample(s_grid_sdf_batch.unsqueeze(1),
                                       sample_points,
                                       padding_mode='border')

        any_negative = body_sdf_batch.lt(0).sum().item() >= 1
        if any_negative:
            raw_penetration = body_sdf_batch[body_sdf_batch < 0].abs().mean()
        else:
            # no penetrating vertices: the penetration loss is exactly zero
            raw_penetration = torch.tensor(0.0, dtype=torch.float32, device=self.device)

        loss_sdf_pene = late_gate * self.weight_collision * raw_penetration

        return (loss_rec_t, loss_rec_p, loss_KL_g, loss_KL_l,
                loss_contact, loss_vposer, loss_sdf_pene)
Beispiel #6
0
    def test(self):
        """Sample bodies for every camera file and pickle them to disk.

        Loads the most recently modified ``epoch-*.ckp`` checkpoint from
        ``self.ckpt_dir`` into ``self.model_h``. For each ``cam_*`` file under
        ``self.test_data_path`` it reads the matching depth and segmentation
        arrays (same file name with ``cam`` replaced), draws
        ``self.n_samples`` bodies from the model and writes each of them to
        ``self.outdir`` as ``body_gen_{index:06d}.pkl``.

        Raises:
            IndexError: if no checkpoint matches ``epoch-*.ckp``.
        """
        self.model_h.eval()
        self.model_h.to(self.device)
        self.vposer.to(self.device)
        self.body_mesh_model.to(self.device)

        ## load checkpoints (newest by modification time)
        ckp_list = sorted(glob.glob(os.path.join(self.ckpt_dir,
                                                 'epoch-*.ckp')),
                          key=os.path.getmtime)
        ckp_path = ckp_list[-1]
        checkpoint = torch.load(ckp_path)
        print('[INFO] load checkpoints: ' + ckp_path)
        self.model_h.load_state_dict(checkpoint['model_h_state_dict'])

        ## read data from here!
        cam_file_list = sorted(glob.glob(self.test_data_path + '/cam_*'))

        for ii, cam_file in enumerate(cam_file_list):

            # The camera file stores a pickled dict inside a 0-d array.
            cam_params = np.load(cam_file,
                                 allow_pickle=True,
                                 encoding='latin1').item()

            depth0 = torch.tensor(np.load(cam_file.replace('cam', 'depth')),
                                  dtype=torch.float32,
                                  device=self.device)
            seg0 = torch.tensor(np.load(cam_file.replace('cam', 'seg')),
                                dtype=torch.float32,
                                device=self.device)

            cam_ext = torch.tensor(cam_params['cam_ext'],
                                   dtype=torch.float32,
                                   device=self.device).unsqueeze(0)  #[1,4,4]
            cam_int = torch.tensor(cam_params['cam_int'],
                                   dtype=torch.float32,
                                   device=self.device).unsqueeze(0)  # [1,3,3]

            depth, _, max_d = self.data_preprocessing(
                depth0, 'depth', target_domain_size=[128, 128])  #[1,1,128,128]
            max_d = max_d.view(1)
            # NOTE(review): the segmentation map is also preprocessed in
            # 'depth' mode -- confirm this is intended.
            seg, _, _ = self.data_preprocessing(seg0,
                                                'depth',
                                                target_domain_size=[128, 128])

            # Replicate the single scene so that n_samples bodies are drawn.
            xs = torch.cat([depth, seg], dim=1)
            xs_batch = xs.repeat(self.n_samples, 1, 1, 1)
            max_d_batch = max_d.repeat(self.n_samples)
            cam_int_batch = cam_int.repeat(self.n_samples, 1, 1)
            cam_ext_batch = cam_ext.repeat(self.n_samples, 1, 1)

            xhnr_gen = self.model_h.sample(xs_batch)
            xhn_gen = GeometryTransformer.convert_to_3D_rot(xhnr_gen)
            xh_gen = GeometryTransformer.recover_global_T(
                xhn_gen, cam_int_batch, max_d_batch)

            body_param_list = BodyParamParser.body_params_encapsulate(xh_gen)

            # exist_ok avoids the check-then-create race of the exists() test.
            os.makedirs(self.outdir, exist_ok=True)

            print('[INFO] save results to: ' + self.outdir)
            for jj, body_param in enumerate(body_param_list):
                body_param['cam_ext'] = cam_ext_batch.detach().cpu().numpy()
                body_param['cam_int'] = cam_int_batch.detach().cpu().numpy()
                outfilename = os.path.join(
                    self.outdir,
                    'body_gen_{:06d}.pkl'.format(self.n_samples * ii + jj))
                # The context manager closes the file even if dumping fails.
                with open(outfilename, 'wb') as outfile:
                    pickle.dump(body_param, outfile)
Beispiel #7
0
    def test(self, batch_gen):
        """Sample bodies for one test scene and pickle them to disk.

        Loads the most recently modified ``epoch-*.ckp`` checkpoint from
        ``self.ckpt_dir`` into ``self.model_h``, takes a single batch from
        ``batch_gen``, draws ``self.n_samples`` bodies with fresh Gaussian
        noise and writes them as ``body_gen_{index:06d}.pkl`` (indices start
        at 900) into ``<self.output_dir>/<scene_name>``.

        Args:
            batch_gen: data source providing ``reset()`` and
                ``next_batch(batch_size=...)``; the batch is indexed as
                (depth, seg, max_d, cam_int, cam_ext).

        Raises:
            IndexError: if no checkpoint matches ``epoch-*.ckp``.
        """
        self.model_h.eval()
        self.model_h.to(self.device)

        self.vposer.to(self.device)
        self.body_mesh_model.to(self.device)

        ## load checkpoints (newest by modification time)
        ckp_list = sorted(glob.glob(os.path.join(self.ckpt_dir,
                                                 'epoch-*.ckp')),
                          key=os.path.getmtime)
        ckp_path = ckp_list[-1]
        checkpoint = torch.load(ckp_path)
        print('[INFO] load checkpoints: ' + ckp_path)
        self.model_h.load_state_dict(checkpoint['model_h_state_dict'])

        ## get a batch of data for testing
        batch_gen.reset()

        test_data = batch_gen.next_batch(batch_size=1)

        depth_batch = test_data[0]
        seg_batch = test_data[1]
        max_d_batch = test_data[2]
        cam_int_batch = test_data[3]
        cam_ext_batch = test_data[4]

        ## pass data to network
        xs = torch.cat([depth_batch, seg_batch], dim=1)

        # Replicate the single scene so that n_samples bodies are drawn.
        xs_n = xs.repeat(self.n_samples, 1, 1, 1)

        noise_batch_g = torch.randn([self.n_samples, self.model_h_latentD],
                                    dtype=torch.float32,
                                    device=self.device)
        noise_batch_l = torch.randn([self.n_samples, self.model_h_latentD],
                                    dtype=torch.float32,
                                    device=self.device)

        xhnr_gen = self.model_h.sample(xs_n, noise_batch_g, noise_batch_l)
        if self.use_cont_rot:
            xhn_gen = GeometryTransformer.convert_to_3D_rot(xhnr_gen)
        else:
            # Previously xhn_gen was left undefined on this path, which made
            # the recover_global_T call below fail with NameError.
            # NOTE(review): assumes the sampler already outputs the 3D
            # rotation form when use_cont_rot is off -- confirm.
            xhn_gen = xhnr_gen

        xh_gen = GeometryTransformer.recover_global_T(xhn_gen, cam_int_batch,
                                                      max_d_batch)

        body_param_list = BodyParamParser.body_params_encapsulate(xh_gen)

        # Scene name = parent directory name of the scene file, up to the
        # first underscore. os.path is used so this is separator-independent.
        scene_dir = os.path.dirname(os.path.abspath(self.scene_file_path))
        scene_name = os.path.basename(scene_dir).split("_")[0]
        outdir = os.path.join(self.output_dir, scene_name)
        # exist_ok avoids the check-then-create race of the exists() test.
        os.makedirs(outdir, exist_ok=True)

        print('[INFO] save results to: ' + outdir)
        for ii, body_param in enumerate(body_param_list):
            body_param['cam_ext'] = cam_ext_batch.detach().cpu().numpy()
            body_param['cam_int'] = cam_int_batch.detach().cpu().numpy()
            outfilename = os.path.join(outdir,
                                       'body_gen_{:06d}.pkl'.format(ii + 900))
            # The context manager closes the file even if dumping fails.
            with open(outfilename, 'wb') as outfile:
                pickle.dump(body_param, outfile)
Beispiel #8
0
    def cal_loss(self, xhr, cam_ext):
        """Compute the four loss terms used while fitting ``self.xhr_rec``.

        The body described by ``self.xhr_rec`` is turned into a mesh and
        moved with ``cam_ext``. Its contact vertices are compared with
        ``self.s_verts_batch`` through a chamfer distance, and all vertices
        are looked up in the scene SDF grid (``self.s_sdf_batch``).

        Args:
            xhr: reference body representation compared against
                ``self.xhr_rec`` with an L1 loss.
            cam_ext: transform handed to
                ``GeometryTransformer.verts_transform`` for the mesh vertices.

        Returns:
            Tuple ``(loss_rec, loss_vposer, loss_contact, loss_collision)``
            of weighted loss terms.
        """
        ### reconstruction loss
        loss_rec = self.weight_loss_rec * F.l1_loss(xhr, self.xhr_rec)
        xh_rec = GeometryTransformer.convert_to_3D_rot(self.xhr_rec)

        ### vposer loss: mean squared value of columns 16..47
        loss_vposer = self.weight_loss_vposer * torch.mean(xh_rec[:, 16:48] ** 2)

        ### contact loss
        body_param_rec = BodyParamParser.body_params_encapsulate_batch(xh_rec)
        decoded_pose = self.vposer.decode(body_param_rec['body_pose_vp'],
                                          output_type='aa')
        joint_rot_batch = decoded_pose.view(self.batch_size, -1)

        # Everything except the vposer latent is forwarded to the body model;
        # the decoded pose takes its place as ``body_pose``.
        mesh_kwargs = {key: body_param_rec[key]
                       for key in body_param_rec.keys()
                       if key != 'body_pose_vp'}

        smplx_output = self.body_mesh_model(return_verts=True,
                                            body_pose=joint_rot_batch,
                                            **mesh_kwargs)
        body_verts_batch = GeometryTransformer.verts_transform(
            smplx_output.vertices,  # [b, 10475,3]
            cam_ext)

        vid, _ = GeometryTransformer.get_contact_id(
            body_segments_folder=self.contact_id_folder,
            contact_body_parts=self.contact_part)
        body_verts_contact_batch = body_verts_batch[:, vid, :]

        chamfer = ext.chamferDist()
        contact_dist, _ = chamfer(body_verts_contact_batch.contiguous(),
                                  self.s_verts_batch.contiguous())

        # Each distance d is mapped to sqrt(d + 1e-4) / (sqrt(d + 1e-4) + 1)
        # before averaging.
        loss_contact = self.weight_contact * torch.mean(
            torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0))

        ### sdf collision loss
        grid_lo = self.s_grid_min_batch.unsqueeze(1)
        grid_hi = self.s_grid_max_batch.unsqueeze(1)
        norm_verts_batch = (body_verts_batch - grid_lo) / (grid_hi - grid_lo) * 2 - 1

        n_verts = norm_verts_batch.shape[1]
        # The coordinate order is reversed before sampling.
        sample_points = norm_verts_batch[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3)
        body_sdf_batch = F.grid_sample(self.s_sdf_batch.unsqueeze(1),
                                       sample_points,
                                       padding_mode='border')

        any_negative = body_sdf_batch.lt(0).sum().item() >= 1
        if any_negative:
            raw_penetration = body_sdf_batch[body_sdf_batch < 0].abs().mean()
        else:
            # no penetrating vertices: the penetration loss is exactly zero
            raw_penetration = torch.tensor(0.0, dtype=torch.float32, device=self.device)

        loss_collision = self.weight_collision * raw_penetration

        return loss_rec, loss_vposer, loss_contact, loss_collision