def save_result(self, xh_rec, output_data_file):
    """Encapsulate recovered body parameters and pickle them to disk.

    Args:
        xh_rec: recovered body-configuration tensor; decoded by
            BodyParamParser.body_params_encapsulate into a list of
            per-sample parameter dicts.
        output_data_file: destination path of the pickle file.
    """
    dirname = os.path.dirname(output_data_file)
    # exist_ok=True avoids the check-then-create race of the old
    # os.path.exists() guard; the `if dirname` guard matters because
    # os.makedirs('') raises when the path has no directory component.
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    body_param_list = BodyParamParser.body_params_encapsulate(xh_rec)
    print('[INFO] save results to: ' + output_data_file)

    # Camera parameters are loop-invariant: convert them to numpy once.
    cam_ext = self.cam_ext.detach().cpu().numpy()
    cam_int = self.cam_int.detach().cpu().numpy()

    # NOTE(review): each iteration reopens the SAME file in 'wb' mode, so
    # when body_param_list has more than one entry only the last one
    # survives on disk. Behavior kept as-is; presumably xh_rec always
    # holds a single sample here — confirm against callers.
    for body_param in body_param_list:
        body_param['cam_ext'] = cam_ext
        body_param['cam_int'] = cam_int
        # 'with' guarantees the handle is closed even if pickle.dump raises.
        with open(output_data_file, 'wb') as outfile:
            pickle.dump(body_param, outfile)
def test(self):
    """Sample bodies for every test scene found on disk and pickle them.

    Loads the newest 'epoch-*.ckp' checkpoint from self.ckpt_dir, then for
    each 'cam_*' file under self.test_data_path reads the matching depth
    and segmentation maps, draws self.n_samples body samples from
    self.model_h, recovers their global translation, and writes one
    pickle per sample into self.outdir as 'body_gen_{:06d}.pkl'.
    """
    self.model_h.eval()
    self.model_h.to(self.device)
    self.vposer.to(self.device)
    self.body_mesh_model.to(self.device)

    ## load checkpoints: newest file by modification time
    ckp_list = sorted(glob.glob(os.path.join(self.ckpt_dir, 'epoch-*.ckp')),
                      key=os.path.getmtime)
    ckp_path = ckp_list[-1]
    checkpoint = torch.load(ckp_path)
    print('[INFO] load checkpoints: ' + ckp_path)
    self.model_h.load_state_dict(checkpoint['model_h_state_dict'])

    ## read data from here!
    cam_file_list = sorted(glob.glob(self.test_data_path + '/cam_*'))

    # Output directory setup is loop-invariant: create/announce it once.
    # exist_ok=True replaces the racy os.path.exists() check.
    os.makedirs(self.outdir, exist_ok=True)
    print('[INFO] save results to: ' + self.outdir)

    for ii, cam_file in enumerate(cam_file_list):
        cam_params = np.load(cam_file, allow_pickle=True,
                             encoding='latin1').item()
        # depth/seg arrays share the cam file's naming scheme
        depth0 = torch.tensor(np.load(cam_file.replace('cam', 'depth')),
                              dtype=torch.float32, device=self.device)
        seg0 = torch.tensor(np.load(cam_file.replace('cam', 'seg')),
                            dtype=torch.float32, device=self.device)
        cam_ext = torch.tensor(cam_params['cam_ext'], dtype=torch.float32,
                               device=self.device).unsqueeze(0)  # [1,4,4]
        cam_int = torch.tensor(cam_params['cam_int'], dtype=torch.float32,
                               device=self.device).unsqueeze(0)  # [1,3,3]

        depth, _, max_d = self.data_preprocessing(
            depth0, 'depth', target_domain_size=[128, 128])  # [1,1,128,128]
        max_d = max_d.view(1)
        seg, _, _ = self.data_preprocessing(
            seg0, 'depth', target_domain_size=[128, 128])

        # Tile the single scene to a batch of n_samples draws.
        xs = torch.cat([depth, seg], dim=1)
        xs_batch = xs.repeat(self.n_samples, 1, 1, 1)
        max_d_batch = max_d.repeat(self.n_samples)
        cam_int_batch = cam_int.repeat(self.n_samples, 1, 1)
        cam_ext_batch = cam_ext.repeat(self.n_samples, 1, 1)

        xhnr_gen = self.model_h.sample(xs_batch)
        xhn_gen = GeometryTransformer.convert_to_3D_rot(xhnr_gen)
        xh_gen = GeometryTransformer.recover_global_T(
            xhn_gen, cam_int_batch, max_d_batch)

        body_param_list = BodyParamParser.body_params_encapsulate(xh_gen)

        # Loop-invariant numpy conversions, hoisted out of the inner loop.
        cam_ext_np = cam_ext_batch.detach().cpu().numpy()
        cam_int_np = cam_int_batch.detach().cpu().numpy()
        for jj, body_param in enumerate(body_param_list):
            body_param['cam_ext'] = cam_ext_np
            body_param['cam_int'] = cam_int_np
            outfilename = os.path.join(
                self.outdir,
                'body_gen_{:06d}.pkl'.format(self.n_samples * ii + jj))
            # 'with' closes the handle even if pickle.dump raises.
            with open(outfilename, 'wb') as outfile:
                pickle.dump(body_param, outfile)
def test(self, batch_gen):
    """Sample bodies for one batch from `batch_gen` and pickle the results.

    Loads the newest 'epoch-*.ckp' checkpoint from self.ckpt_dir, pulls a
    single test example from `batch_gen`, draws self.n_samples body samples
    from self.model_h conditioned on depth+segmentation, recovers global
    translation, and writes one pickle per sample into
    self.output_dir/<scene_name>/ as 'body_gen_{:06d}.pkl' (offset by 900).

    Args:
        batch_gen: data generator with reset() and next_batch(batch_size)
            returning (depth, seg, max_d, cam_int, cam_ext) batches.
    """
    self.model_h.eval()
    self.model_h.to(self.device)
    self.vposer.to(self.device)
    self.body_mesh_model.to(self.device)

    ## load checkpoints: newest file by modification time
    ckp_list = sorted(glob.glob(os.path.join(self.ckpt_dir, 'epoch-*.ckp')),
                      key=os.path.getmtime)
    ckp_path = ckp_list[-1]
    checkpoint = torch.load(ckp_path)
    print('[INFO] load checkpoints: ' + ckp_path)
    self.model_h.load_state_dict(checkpoint['model_h_state_dict'])

    ## get a batch of data for testing
    batch_gen.reset()
    test_data = batch_gen.next_batch(batch_size=1)
    depth_batch = test_data[0]
    seg_batch = test_data[1]
    max_d_batch = test_data[2]
    cam_int_batch = test_data[3]
    cam_ext_batch = test_data[4]

    ## pass data to network
    xs = torch.cat([depth_batch, seg_batch], dim=1)
    xs_n = xs.repeat(self.n_samples, 1, 1, 1)

    noise_batch_g = torch.randn([self.n_samples, self.model_h_latentD],
                                dtype=torch.float32, device=self.device)
    noise_batch_l = torch.randn([self.n_samples, self.model_h_latentD],
                                dtype=torch.float32, device=self.device)

    # Both branches issued the identical sample() call; hoisted out.
    xhnr_gen = self.model_h.sample(xs_n, noise_batch_g, noise_batch_l)
    if self.use_cont_rot:
        xhn_gen = GeometryTransformer.convert_to_3D_rot(xhnr_gen)
    else:
        # BUGFIX: the original else-branch never assigned xhn_gen, so the
        # recover_global_T call below raised NameError whenever
        # use_cont_rot was False. Without the continuous-rotation
        # representation the raw sample is used directly.
        xhn_gen = xhnr_gen
    xh_gen = GeometryTransformer.recover_global_T(
        xhn_gen, cam_int_batch, max_d_batch)

    body_param_list = BodyParamParser.body_params_encapsulate(xh_gen)

    # e.g. '.../SceneName_xxx/file' -> 'SceneName'
    scene_name = os.path.abspath(
        self.scene_file_path).split("/")[-2].split("_")[0]
    outdir = os.path.join(self.output_dir, scene_name)
    # exist_ok=True replaces the racy os.path.exists() check.
    os.makedirs(outdir, exist_ok=True)
    print('[INFO] save results to: ' + outdir)

    # Loop-invariant numpy conversions, hoisted out of the loop.
    cam_ext_np = cam_ext_batch.detach().cpu().numpy()
    cam_int_np = cam_int_batch.detach().cpu().numpy()
    for ii, body_param in enumerate(body_param_list):
        body_param['cam_ext'] = cam_ext_np
        body_param['cam_int'] = cam_int_np
        outfilename = os.path.join(
            outdir, 'body_gen_{:06d}.pkl'.format(ii + 900))
        # 'with' closes the handle even if pickle.dump raises.
        with open(outfilename, 'wb') as outfile:
            pickle.dump(body_param, outfile)