def test_time_optimize(args, model, meta_state_dict, tto_view):
    """
    Quickly adapt the meta-trained model to one target view and return its weights.

    The model is first reset to ``meta_state_dict``. It is then fine-tuned with
    plain SGD (learning rate ``args.tto_lr``) for ``args.tto_steps`` steps. Each
    step uses ``args.tto_batchsize`` ray indices drawn independently at random
    (i.e. with replacement) from all rays of the view, and minimizes the MSE
    between rendered colors and the view's pixels.

    Args:
        args: config namespace; reads ``tto_lr``, ``tto_steps``,
            ``tto_batchsize`` and ``num_samples``.
        model: network to adapt; its weights are overwritten in place.
        meta_state_dict: meta-learned initial weights loaded before adapting.
        tto_view: dict with keys ``'img'``, ``'H'``, ``'W'``, ``'kinv'``,
            ``'pose'`` and ``'bound'`` (``bound[0]`` / ``bound[1]`` are passed
            to ``sample_points`` as the sampling range).

    Returns:
        A deep copy of the adapted model's ``state_dict()``, independent of
        later changes to ``model``.
    """
    model.load_state_dict(meta_state_dict)
    optimizer = torch.optim.SGD(model.parameters(), args.tto_lr)

    # Flatten to one RGB triplet / one ray per row so they can be indexed jointly.
    target_rgb = tto_view['img'].reshape(-1, 3)
    origins, directions = get_rays_tourism(
        tto_view['H'], tto_view['W'], tto_view['kinv'], tto_view['pose'])
    origins = origins.reshape(-1, 3)
    directions = directions.reshape(-1, 3)
    ray_count = directions.shape[0]

    for _ in range(args.tto_steps):
        # Random mini-batch of rays and their ground-truth pixel colors.
        batch_idx = torch.randint(ray_count, size=[args.tto_batchsize])
        batch_o = origins[batch_idx]
        batch_d = directions[batch_idx]
        batch_rgb = target_rgb[batch_idx]

        bound = tto_view['bound']
        t_vals, xyz = sample_points(batch_o, batch_d, bound[0], bound[1],
                                    args.num_samples, perturb=True)

        optimizer.zero_grad()
        rgbs, sigmas = model(xyz)
        rendered = volume_render(rgbs, sigmas, t_vals)
        step_loss = F.mse_loss(rendered, batch_rgb)
        step_loss.backward()
        optimizer.step()

    return copy.deepcopy(model.state_dict())
def val_meta(args, model, val_loader, device):
    """
    Validate the meta-trained model on phototourism validation images.

    For every image yielded by ``val_loader``:

    1. A working copy of ``model`` is reset to the current meta weights.
    2. The copy is optimized via ``inner_loop`` (SGD with ``args.inner_lr``,
       ``args.train_batchsize`` and ``args.inner_steps``) on the left half of
       the image's columns.
    3. ``report_result`` scores the copy on the right half.

    ``model`` itself is never trained here.

    Returns:
        The mean of the per-image values returned by ``report_result``
        (PSNR), as a tensor.
    """
    # state_dict() returns references to model's tensors, not copies; this is
    # safe because only the deep copy below is ever optimized.
    meta_weights = model.state_dict()
    adapted = copy.deepcopy(model)

    psnr_list = []
    for img, pose, kinv, bound in val_loader:
        img, pose, kinv, bound = (
            t.to(device).squeeze() for t in (img, pose, kinv, bound))
        rays_o, rays_d = get_rays_tourism(img.shape[0], img.shape[1], kinv, pose)

        # Adapt on the left half of the columns, evaluate on the right half.
        width = img.shape[1]
        halves = [width // 2, width - width // 2]
        tto_img, test_img = torch.split(img, halves, dim=1)
        tto_rays_o, test_rays_o = torch.split(rays_o, halves, dim=1)
        tto_rays_d, test_rays_d = torch.split(rays_d, halves, dim=1)

        adapted.load_state_dict(meta_weights)
        inner_optim = torch.optim.SGD(adapted.parameters(), args.inner_lr)
        inner_loop(adapted, inner_optim, tto_img, tto_rays_o, tto_rays_d,
                   bound, args.num_samples, args.train_batchsize, args.inner_steps)

        psnr_list.append(
            report_result(adapted, test_img, test_rays_o, test_rays_d,
                          bound, args.num_samples, args.test_batchsize))

    return torch.stack(psnr_list).mean()
def test():
    """
    Command-line entry point: evaluate a meta-trained phototourism model.

    Command-line arguments:
        --config: JSON file whose top-level keys are copied onto the parsed
            args, overriding attributes of the same name.
        --weight-path: checkpoint loaded with ``torch.load``; its
            ``'meta_model_state_dict'`` entry holds the meta-learned weights.

    For every image of the "test" split:

    1. The model is reset to the meta weights.
    2. It is optimized via ``inner_loop`` (SGD with ``args.tto_lr``,
       ``args.tto_batchsize`` and ``args.tto_steps``) on the left half of the
       image's columns.
    3. ``report_result`` scores it on the right half, and the per-view PSNR
       is printed.

    Finally, the mean PSNR is printed and ``create_interpolation_video`` is
    called.
    """
    parser = argparse.ArgumentParser(description='phototourism with meta-learning')
    parser.add_argument('--config', type=str, required=True,
                        help='config file for the scene')
    parser.add_argument('--weight-path', type=str, required=True,
                        help='path to the meta-trained weight file')
    args = parser.parse_args()

    # JSON is UTF-8 by specification (RFC 8259); an explicit encoding keeps
    # decoding independent of the platform's locale default.
    with open(args.config, encoding='utf-8') as config:
        info = json.load(config)
    # Config values override CLI attributes with the same name.
    vars(args).update(info)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_set = build_tourism(image_set="test", args=args)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    model = build_nerf(args)
    model.to(device)

    checkpoint = torch.load(args.weight_path, map_location=device)
    meta_state_dict = checkpoint['meta_model_state_dict']

    test_psnrs = []
    for idx, (img, pose, kinv, bound) in enumerate(test_loader):
        img, pose, kinv, bound = img.to(device), pose.to(device), kinv.to(device), bound.to(device)
        # Drop the batch dimension added by the DataLoader (batch_size=1).
        img, pose, kinv, bound = img.squeeze(), pose.squeeze(), kinv.squeeze(), bound.squeeze()
        rays_o, rays_d = get_rays_tourism(img.shape[0], img.shape[1], kinv, pose)

        # optimize on the left half, test on the right half
        left_width = img.shape[1] // 2
        right_width = img.shape[1] - left_width
        tto_img, test_img = torch.split(img, [left_width, right_width], dim=1)
        tto_rays_o, test_rays_o = torch.split(rays_o, [left_width, right_width], dim=1)
        tto_rays_d, test_rays_d = torch.split(rays_d, [left_width, right_width], dim=1)

        # Every view starts from the same meta-learned initialization.
        model.load_state_dict(meta_state_dict)
        optim = torch.optim.SGD(model.parameters(), args.tto_lr)

        inner_loop(model, optim, tto_img, tto_rays_o, tto_rays_d, bound,
                   args.num_samples, args.tto_batchsize, args.tto_steps)

        psnr = report_result(model, test_img, test_rays_o, test_rays_d, bound,
                             args.num_samples, args.test_batchsize)

        print(f"test view {idx+1}, psnr:{psnr:.3f}")
        test_psnrs.append(psnr)

    test_psnrs = torch.stack(test_psnrs)
    print("----------------------------------")
    print(f"test dataset mean psnr: {test_psnrs.mean():.3f}")

    print("\ncreating interpolation video ...\n")
    create_interpolation_video(args, model, meta_state_dict, test_set, device)
    print("\ninterpolation video created!")
def train_meta(args, meta_model, meta_optim, data_loader, device):
    """
    Train ``meta_model`` for one epoch with Reptile meta-learning.

    Reference: https://arxiv.org/abs/1803.02999

    For each image from ``data_loader``:

    1. A throw-away deep copy of ``meta_model`` is optimized via
       ``inner_loop`` (SGD with ``args.inner_lr``, ``args.train_batchsize``
       and ``args.inner_steps``) on the whole image.
    2. The difference ``meta_param - adapted_param`` is written into each
       meta parameter's ``.grad``.
    3. ``meta_optim`` takes one step. A descent step on this pseudo-gradient
       moves the meta weights toward the adapted weights.
    """
    for img, pose, kinv, bound in data_loader:
        img, pose, kinv, bound = (
            t.to(device).squeeze() for t in (img, pose, kinv, bound))
        rays_o, rays_d = get_rays_tourism(img.shape[0], img.shape[1], kinv, pose)

        meta_optim.zero_grad()

        learner = copy.deepcopy(meta_model)
        learner_optim = torch.optim.SGD(learner.parameters(), args.inner_lr)
        inner_loop(learner, learner_optim, img, rays_o, rays_d, bound,
                   args.num_samples, args.train_batchsize, args.inner_steps)

        # Reptile: use (meta - adapted) as the meta "gradient"; computed without
        # autograd tracking since it is assigned directly to .grad.
        with torch.no_grad():
            paired = zip(meta_model.parameters(), learner.parameters())
            for meta_p, learner_p in paired:
                meta_p.grad = meta_p - learner_p
        meta_optim.step()
def synthesize_view(args, model, H, W, kinv, pose, bound):
    """
    Render a novel view for the given camera intrinsics and pose.

    Rays for an H x W image are generated with ``get_rays_tourism``. Points
    along them are sampled once with ``perturb=False`` between ``bound[0]``
    and ``bound[1]``. The model is then evaluated without gradient tracking
    in chunks of ``args.test_batchsize`` rays, and each chunk is composited
    with ``volume_render``.

    Returns:
        The rendered colors reshaped to ``(H, W, 3)``.
    """
    rays_o, rays_d = get_rays_tourism(H, W, kinv, pose)
    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)
    t_vals, xyz = sample_points(rays_o, rays_d, bound[0], bound[1],
                                args.num_samples, perturb=False)

    chunk = args.test_batchsize
    total = rays_d.shape[0]
    pieces = []
    with torch.no_grad():
        # Chunking bounds how many rays go through the model at once.
        for start in range(0, total, chunk):
            stop = start + chunk
            rgbs, sigmas = model(xyz[start:stop])
            pieces.append(volume_render(rgbs, sigmas, t_vals[start:stop]))

    image = torch.cat(pieces, dim=0)
    return image.reshape(H, W, 3)