import copy
from pathlib import Path

import imageio
import torch
import torch.nn.functional as F

# Project-local helpers (get_rays_shapenet, get_rays_tourism, get_360_poses,
# sample_points, volume_render) are assumed importable from the repo's
# rendering/pose utilities.


def report_result(args, model, imgs, poses, hwf, bound):
    """
    report view-synthesis result on heldout views
    """
    ray_origins, ray_directions = get_rays_shapenet(hwf, poses)

    view_psnrs = []
    for img, rays_o, rays_d in zip(imgs, ray_origins, ray_directions):
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
        t_vals, xyz = sample_points(rays_o, rays_d, bound[0], bound[1],
                                    args.num_samples, perturb=False)

        synth = []
        num_rays = rays_d.shape[0]
        with torch.no_grad():
            # Render in chunks so a full image never has to fit in memory.
            for i in range(0, num_rays, args.test_batchsize):
                rgbs_batch, sigmas_batch = model(xyz[i:i+args.test_batchsize])
                color_batch = volume_render(rgbs_batch, sigmas_batch,
                                            t_vals[i:i+args.test_batchsize],
                                            white_bkgd=True)
                synth.append(color_batch)
            synth = torch.cat(synth, dim=0).reshape_as(img)
            error = F.mse_loss(img, synth)
            psnr = -10 * torch.log10(error)
            view_psnrs.append(psnr)

    scene_psnr = torch.stack(view_psnrs).mean()
    return scene_psnr

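# The PSNR above assumes pixel values in [0, 1], so the peak signal is 1 and
# PSNR reduces to -10 * log10(MSE). A minimal, self-contained sanity check
# (the random tensors are illustrative, not repo data):
def _psnr_sanity_check():
    img = torch.rand(64, 64, 3)
    synth = (img + 0.01 * torch.randn_like(img)).clamp(0, 1)
    mse = F.mse_loss(synth, img)
    psnr = -10 * torch.log10(mse)
    print(f"PSNR: {psnr.item():.2f} dB")  # ~40 dB for noise with std 0.01
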
def test_time_optimize(args, model, optim, imgs, poses, hwf, bound):
    """
    test-time-optimize the meta trained model on available views
    """
    pixels = imgs.reshape(-1, 3)

    rays_o, rays_d = get_rays_shapenet(hwf, poses)
    rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    num_rays = rays_d.shape[0]
    for step in range(args.tto_steps):
        # Sample a random batch of rays across all available views.
        indices = torch.randint(num_rays, size=[args.tto_batchsize])
        raybatch_o, raybatch_d = rays_o[indices], rays_d[indices]
        pixelbatch = pixels[indices]
        t_vals, xyz = sample_points(raybatch_o, raybatch_d, bound[0], bound[1],
                                    args.num_samples, perturb=True)

        optim.zero_grad()
        rgbs, sigmas = model(xyz)
        colors = volume_render(rgbs, sigmas, t_vals, white_bkgd=True)
        loss = F.mse_loss(colors, pixelbatch)
        loss.backward()
        optim.step()

def inner_loop(model, optim, imgs, poses, hwf, bound, num_samples,
               raybatch_size, inner_steps):
    """
    train the inner model for a specified number of iterations
    """
    pixels = imgs.reshape(-1, 3)

    rays_o, rays_d = get_rays_shapenet(hwf, poses)
    rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    num_rays = rays_d.shape[0]
    for step in range(inner_steps):
        indices = torch.randint(num_rays, size=[raybatch_size])
        raybatch_o, raybatch_d = rays_o[indices], rays_d[indices]
        pixelbatch = pixels[indices]
        t_vals, xyz = sample_points(raybatch_o, raybatch_d, bound[0], bound[1],
                                    num_samples, perturb=True)

        optim.zero_grad()
        rgbs, sigmas = model(xyz)
        colors = volume_render(rgbs, sigmas, t_vals, white_bkgd=True)
        loss = F.mse_loss(colors, pixelbatch)
        loss.backward()
        optim.step()

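# inner_loop is meant to be driven by an outer meta-training loop. A minimal
# sketch of one plausible driver follows, assuming a Reptile-style update;
# meta_lr, inner_lr, and the scene tuple layout are illustrative assumptions,
# not part of this file.
def _reptile_step_sketch(meta_model, scene, meta_lr, inner_lr,
                         num_samples, raybatch_size, inner_steps):
    imgs, poses, hwf, bound = scene  # one training scene

    # Adapt a throwaway copy so the meta weights stay untouched until the
    # explicit Reptile update below.
    inner_model = copy.deepcopy(meta_model)
    inner_optim = torch.optim.SGD(inner_model.parameters(), inner_lr)
    inner_loop(inner_model, inner_optim, imgs, poses, hwf, bound,
               num_samples, raybatch_size, inner_steps)

    # Reptile update: move the meta weights a fraction of the way toward
    # the scene-adapted weights.
    with torch.no_grad():
        for meta_p, inner_p in zip(meta_model.parameters(),
                                   inner_model.parameters()):
            meta_p.add_(meta_lr * (inner_p - meta_p))
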
# Photo Tourism counterpart of test_time_optimize above (different signature,
# rays from get_rays_tourism, no white-background compositing).
def test_time_optimize(args, model, meta_state_dict, tto_view):
    """
    quickly optimize the meta trained model to a target appearance
    and return the corresponding network weights
    """
    model.load_state_dict(meta_state_dict)
    optim = torch.optim.SGD(model.parameters(), args.tto_lr)

    pixels = tto_view['img'].reshape(-1, 3)
    rays_o, rays_d = get_rays_tourism(tto_view['H'], tto_view['W'],
                                      tto_view['kinv'], tto_view['pose'])
    rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    num_rays = rays_d.shape[0]
    for step in range(args.tto_steps):
        indices = torch.randint(num_rays, size=[args.tto_batchsize])
        raybatch_o, raybatch_d = rays_o[indices], rays_d[indices]
        pixelbatch = pixels[indices]
        t_vals, xyz = sample_points(raybatch_o, raybatch_d,
                                    tto_view['bound'][0], tto_view['bound'][1],
                                    args.num_samples, perturb=True)

        optim.zero_grad()
        rgbs, sigmas = model(xyz)
        colors = volume_render(rgbs, sigmas, t_vals)
        loss = F.mse_loss(colors, pixelbatch)
        loss.backward()
        optim.step()

    state_dict = copy.deepcopy(model.state_dict())
    return state_dict

def create_360_video(args, model, hwf, bound, device, scene_id, savedir):
    """
    create 360 video of a specific shape
    """
    video_frames = []
    poses_360 = get_360_poses(args.radius).to(device)
    ray_origins, ray_directions = get_rays_shapenet(hwf, poses_360)

    for rays_o, rays_d in zip(ray_origins, ray_directions):
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
        t_vals, xyz = sample_points(rays_o, rays_d, bound[0], bound[1],
                                    args.num_samples, perturb=False)

        synth = []
        num_rays = rays_d.shape[0]
        with torch.no_grad():
            for i in range(0, num_rays, args.test_batchsize):
                rgbs_batch, sigmas_batch = model(xyz[i:i+args.test_batchsize])
                color_batch = volume_render(rgbs_batch, sigmas_batch,
                                            t_vals[i:i+args.test_batchsize],
                                            white_bkgd=True)
                synth.append(color_batch)
            synth = torch.cat(synth, dim=0).reshape(int(hwf[0]),
                                                    int(hwf[1]), 3)
            # Convert the [0, 1] float render to an 8-bit video frame.
            synth = torch.clip(synth, min=0, max=1)
            synth = (255 * synth).to(torch.uint8)
            video_frames.append(synth)

    video_frames = torch.stack(video_frames, dim=0)
    video_frames = video_frames.cpu().numpy()

    video_path = savedir.joinpath(f"{scene_id}.mp4")
    imageio.mimwrite(video_path, video_frames, fps=30)

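# Hypothetical end-to-end evaluation sketch tying the ShapeNet helpers
# together: adapt the meta weights on a few views of each test scene, report
# PSNR on the held-out views, then render a 360 video. meta_state_dict,
# test_scenes, and the args fields tto_views / tto_lr / savedir are assumed
# names for illustration; report_result and test_time_optimize here refer to
# the ShapeNet variants defined earlier in this section.
def _shapenet_eval_sketch(args, model, meta_state_dict, test_scenes, device):
    savedir = Path(args.savedir)
    savedir.mkdir(exist_ok=True)

    for scene_id, (imgs, poses, hwf, bound) in enumerate(test_scenes):
        model.load_state_dict(meta_state_dict)
        optim = torch.optim.SGD(model.parameters(), args.tto_lr)

        # Optimize on the first few views, evaluate on the rest.
        tto_imgs, test_imgs = imgs[:args.tto_views], imgs[args.tto_views:]
        tto_poses, test_poses = poses[:args.tto_views], poses[args.tto_views:]

        test_time_optimize(args, model, optim, tto_imgs, tto_poses, hwf, bound)
        scene_psnr = report_result(args, model, test_imgs, test_poses,
                                   hwf, bound)
        print(f"scene {scene_id}: PSNR {scene_psnr:.3f}")

        create_360_video(args, model, hwf, bound, device, scene_id, savedir)
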
# Photo Tourism counterpart of report_result above: a single held-out view,
# rays precomputed by the caller, no white-background compositing.
def report_result(model, img, rays_o, rays_d, bound, num_samples,
                  raybatch_size):
    """
    report synthesis result on heldout view
    """
    pixels = img.reshape(-1, 3)

    rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
    t_vals, xyz = sample_points(rays_o, rays_d, bound[0], bound[1],
                                num_samples, perturb=False)

    synth = []
    num_rays = rays_d.shape[0]
    with torch.no_grad():
        for i in range(0, num_rays, raybatch_size):
            rgbs_batch, sigmas_batch = model(xyz[i:i+raybatch_size])
            color_batch = volume_render(rgbs_batch, sigmas_batch,
                                        t_vals[i:i+raybatch_size])
            synth.append(color_batch)
        synth = torch.cat(synth, dim=0)
        error = F.mse_loss(synth, pixels)
        psnr = -10 * torch.log10(error)

    return psnr

def synthesize_view(args, model, H, W, kinv, pose, bound):
    """
    given camera intrinsics and camera pose, synthesize a novel view
    """
    rays_o, rays_d = get_rays_tourism(H, W, kinv, pose)
    rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    t_vals, xyz = sample_points(rays_o, rays_d, bound[0], bound[1],
                                args.num_samples, perturb=False)

    synth = []
    num_rays = rays_d.shape[0]
    with torch.no_grad():
        for i in range(0, num_rays, args.test_batchsize):
            rgbs_batch, sigmas_batch = model(xyz[i:i+args.test_batchsize])
            color_batch = volume_render(rgbs_batch, sigmas_batch,
                                        t_vals[i:i+args.test_batchsize])
            synth.append(color_batch)
        synth = torch.cat(synth, dim=0).reshape(H, W, 3)

    return synth

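# Hedged usage sketch for the Photo Tourism pair above: adapt the meta weights
# to one view's appearance, then synthesize a held-out view with the adapted
# weights. test_view mirrors the tto_view dict layout; the output filename is
# an assumption.
def _tourism_eval_sketch(args, model, meta_state_dict, tto_view, test_view):
    optimized_weights = test_time_optimize(args, model, meta_state_dict,
                                           tto_view)
    model.load_state_dict(optimized_weights)

    synth = synthesize_view(args, model, test_view['H'], test_view['W'],
                            test_view['kinv'], test_view['pose'],
                            test_view['bound'])
    frame = (255 * synth.clamp(0, 1)).to(torch.uint8).cpu().numpy()
    imageio.imwrite("synthesized_view.png", frame)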