def get_model(args):
    """Build an HGPIFuMRNet from the checkpoint named in the parsed ``args``.

    ``args`` is parsed with the module-level ``parser``. The checkpoint path is
    taken from ``opt.load_netMR_checkpoint_path``. The options stored in the
    checkpoint replace the parsed ones, except that ``resolution`` and
    ``loadSize`` are forced to 256 and 1024.

    Args:
        args: Arguments forwarded to ``parser.parse``.

    Returns:
        The HGPIFuMRNet with the checkpoint's ``model_state_dict`` loaded. It is
        placed on CUDA when available and on the CPU otherwise. Only the inner
        ``netG`` is switched to eval mode here; ``netMR.eval()`` is not called.

    Raises:
        Exception: ``('failed loading state dict!', path)`` if no checkpoint
            path was given or the file does not exist.
    """
    opt = parser.parse(args)

    state_dict_path = opt.load_netMR_checkpoint_path
    if state_dict_path is None or not os.path.exists(state_dict_path):
        # Fail early with the same error as recon() instead of crashing later
        # with a TypeError on ``None['opt_netG']``.
        raise Exception('failed loading state dict!', state_dict_path)

    # Fall back to the CPU so the model can be built without a GPU.
    # map_location lets a checkpoint saved on GPU be read on either device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Resuming from ', state_dict_path)
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    state_dict = torch.load(state_dict_path, map_location=device)

    # Use the options saved with the checkpoint, overriding only the
    # evaluation sizes.
    opt = state_dict['opt']
    opt.resolution = 256
    opt.loadSize = 1024

    netG = HGPIFuNetwNML(state_dict['opt_netG']).to(device=device)
    netMR = HGPIFuMRNet(opt, netG).to(device=device)
    netG.eval()

    netMR.load_state_dict(state_dict['model_state_dict'])
    return netMR
def recon(opt, use_rect=False):
    """Reconstruct meshes for a range of test samples from a saved HGPIFuMRNet checkpoint.

    The checkpoint file is chosen as follows:

    * ``opt.load_netMR_checkpoint_path`` when it is set;
    * otherwise ``<checkpoints_path>/<name>_train_latest`` when
      ``opt.resume_epoch < 0``, which also resets ``opt.resume_epoch`` to 0
      on the caller's object;
    * otherwise ``<checkpoints_path>/<name>_train_epoch_<resume_epoch>``.

    After loading, the options stored in the checkpoint replace ``opt``. Only
    ``dataroot``, ``resolution``, ``results_path`` and ``loadSize`` are copied
    over from the caller. ``start_id``/``end_id`` are read from the caller's
    ``opt`` before that swap. A negative ``start_id`` means 0, and a negative
    ``end_id`` means the dataset length.

    For each selected sample, ``gen_mesh`` is called with an output path of
    the form ``<results_path>/<name>/recon/result_<sample name>_<resolution>.obj``.

    Args:
        opt: Parsed options object (accessed via attributes).
        use_rect: Evaluate with ``EvalDataset`` if True, otherwise with
            ``EvalWPoseDataset``.

    Raises:
        Exception: ``('failed loading state dict!', path)`` when the chosen
            checkpoint file does not exist.
    """
    # Decide which checkpoint file to read.
    if opt.load_netMR_checkpoint_path is not None:
        ckpt_path = opt.load_netMR_checkpoint_path
    elif opt.resume_epoch < 0:
        ckpt_path = '%s/%s_train_latest' % (opt.checkpoints_path, opt.name)
        opt.resume_epoch = 0
    else:
        ckpt_path = '%s/%s_train_epoch_%d' % (
            opt.checkpoints_path, opt.name, opt.resume_epoch)

    # The sample range comes from the caller's options, before they are replaced.
    first_idx, last_idx = opt.start_id, opt.end_id

    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % opt.gpu_id)
    else:
        device = torch.device('cpu')

    if ckpt_path is None or not os.path.exists(ckpt_path):
        raise Exception('failed loading state dict!', ckpt_path)

    print('Resuming from ', ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location=device)
    print('Warning: opt is overwritten.')

    # Remember the run-specific values, swap in the checkpoint's options,
    # then restore those values on top.
    kept = [(key, getattr(opt, key))
            for key in ('dataroot', 'resolution', 'results_path', 'loadSize')]
    opt = checkpoint['opt']
    for key, value in kept:
        setattr(opt, key, value)

    dataset = EvalDataset(opt) if use_rect else EvalWPoseDataset(opt)
    print('test data size: ', len(dataset))
    projection_mode = dataset.projection_mode

    netG = HGPIFuNetwNML(checkpoint['opt_netG'], projection_mode).to(device=device)
    netMR = HGPIFuMRNet(opt, netG, projection_mode).to(device=device)
    netMR.load_state_dict(checkpoint['model_state_dict'])

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    recon_dir = '%s/%s/recon' % (opt.results_path, opt.name)
    os.makedirs(recon_dir, exist_ok=True)

    first_idx = 0 if first_idx < 0 else first_idx
    last_idx = len(dataset) if last_idx < 0 else last_idx

    # Set to False for multi-person processing. That path iterates over
    # dataset.get_n_person(idx) entries.
    single_person = True

    with torch.no_grad():
        netG.eval()

        print('generate mesh (test) ...')
        for idx in tqdm(range(first_idx, last_idx)):
            if idx >= len(dataset):
                break

            if single_person:
                sample = dataset[idx]
                out_path = '%s/result_%s_%d.obj' % (
                    recon_dir, sample['name'], opt.resolution)
                print(out_path)
                gen_mesh(opt.resolution, netMR, device, sample, out_path,
                         components=opt.use_compose)
                continue

            for person in range(dataset.get_n_person(idx)):
                dataset.person_id = person
                sample = dataset[idx]
                out_path = '%s/result_%s_%d.obj' % (
                    recon_dir, sample['name'], person)
                gen_mesh(opt.resolution, netMR, device, sample, out_path,
                         components=opt.use_compose)