def test():
    """Render per-view diagnostic images for the NeRV data into ``outputs/``.

    Iterates over the module-level ``cam_to_worlds`` / ``light_locs`` pairs
    under ``torch.no_grad()`` and, for view index ``i``, writes:

    * ``outputs/nerv_weights_{i:04}.png`` -- an ``r`` x ``c`` grid with one
      panel per slice of the ``BasisBRDF`` render (only when ``bsdf`` is a
      ``ComposeSpatialVarying``).
    * ``outputs/nerv_wm01_{i:04}.png`` -- the first two of those slices, each
      divided by their sum (only when at least two slices exist).
    * ``outputs/nerv_normals_{i:04}.png`` -- the render produced with the
      ``Debug`` integrator.

    Relies on module-level ``shape``, ``bsdf``, ``integrator``, ``focal``,
    ``device``, ``SIZE``, ``r``, ``c`` and ``unroll``. Returns ``None``.
    """
    with torch.no_grad():
        for i, (c2w, lp) in enumerate(zip(tqdm(cam_to_worlds), light_locs)):
            cameras = NeRFCamera(
                cam_to_world=c2w.unsqueeze(0), focal=focal, device=device,
            )
            lights = PointLights(
                intensity=[1, 1, 1], location=lp[None, ...], scale=100,
                device=device,
            )
            # Arguments shared by every render of this view; only the
            # integrator differs between the calls below.
            render_kwargs = dict(
                size=SIZE, chunk_size=SIZE, bundle_size=1, bsdf=bsdf,
                cameras=cameras, lights=lights, device=device, silent=True,
            )

            if isinstance(bsdf, ComposeSpatialVarying):
                got = pt.pathtrace(
                    shape, integrator=BasisBRDF(bsdf), **render_kwargs,
                )[0].clamp(min=0, max=1)

                f, axes = plt.subplots(r, c)
                f.set_figheight(10)
                f.set_figwidth(10)
                # Repeat every value over a new trailing axis of size 3 so
                # each slice can be handed to imshow as a 3-channel image.
                got = got.unsqueeze(-1).expand(got.shape + (3,))
                first_two = []
                for k, img in enumerate(got.split(1, dim=-2)):
                    img = img.squeeze(-2).cpu().numpy()
                    ax = axes[unroll(k, c)]
                    ax.imshow(img)
                    ax.axis('off')
                    if k < 2:
                        first_two.append(img)
                plt.subplots_adjust(wspace=0, hspace=0)
                plt.savefig(f"outputs/nerv_weights_{i:04}.png", bbox_inches="tight")
                plt.clf()
                plt.close(f)

                # Render the first two maps normalised by their sum for an
                # easy figure. Skipped when fewer than two maps exist, since
                # there is nothing to normalise against.
                if len(first_two) == 2:
                    wm_0, wm_1 = first_two
                    total = wm_0 + wm_1
                    # Both maps were clamped to [0, 1], so a zero sum means
                    # both entries are zero; dividing by 1 there yields 0
                    # instead of the NaN that 0/0 would give.
                    total[total == 0] = 1

                    f, axes = plt.subplots(2)
                    f.set_figheight(10)
                    f.set_figwidth(10)
                    axes[0].imshow(wm_0 / total)
                    axes[0].axis('off')
                    axes[1].imshow(wm_1 / total)
                    axes[1].axis('off')
                    plt.subplots_adjust(wspace=0, hspace=0)
                    plt.savefig(f"outputs/nerv_wm01_{i:04}.png", bbox_inches="tight")
                    plt.clf()
                    plt.close(f)

            normals = pt.pathtrace(
                shape, integrator=Debug(), **render_kwargs,
            )[0]
            save_image(f"outputs/nerv_normals_{i:04}.png", normals)

            # NOTE: the trailing ``and False`` keeps this final render switched
            # off; drop it to write outputs/got_XXXX.png again.
            if (integrator is not None) and False:
                got = pt.pathtrace(
                    shape, integrator=integrator, **render_kwargs,
                )[0].clamp(min=0, max=1)
                save_image(f"outputs/got_{i:04}.png", got)
def run_mitsuba(min_elev, max_elev, min_azim, max_azim, radius):
    """Render each scene kind over a ``PER`` x ``PER`` elevation/azimuth grid.

    For every entry of the module-level ``kinds`` and every
    (elevation, azimuth) pair sampled with ``np.linspace``, a scene is built
    with ``get_scene``, rendered ``N`` times at 32 spp with ``render_torch``,
    averaged, and saved as ``{kind}_{i:03}_{j:03}.png`` where ``i`` / ``j``
    are the elevation / azimuth grid indices.

    Args:
        min_elev, max_elev: Inclusive endpoints of the elevation sweep.
        min_azim, max_azim: Inclusive endpoints of the azimuth sweep.
        radius: Distance passed to ``elaz_to_xyz`` for the first position;
            the second position uses ``1.05 * radius`` in the same direction.
            (Presumably camera origin and light position -- confirm against
            ``get_scene``.)
    """
    for kind in kinds:
        for elev_idx, elevation in enumerate(np.linspace(min_elev, max_elev, PER)):
            for azim_idx, azimuth in enumerate(np.linspace(min_azim, max_azim, PER)):
                ox, oy, oz = elaz_to_xyz(elevation, azimuth, radius)
                lx, ly, lz = elaz_to_xyz(elevation, azimuth, 1.05 * radius)
                scene = get_scene(ox, oy, oz, lx, ly, lz, kind)

                # Average N independent renders, accumulating in place into
                # the first one.
                accum = render_torch(scene, spp=32)
                for _ in range(N - 1):
                    accum += render_torch(scene, spp=32)
                accum /= N

                # Keep the first three channels untouched and replace the
                # fourth with a "> 0" mask.
                colour = accum[..., :3]
                mask = accum[..., 3].unsqueeze(-1) > 0
                image = torch.cat([colour, mask], dim=-1)
                save_image(f"{kind}_{elev_idx:03}_{azim_idx:03}.png", image)
def test():
    """Render per-view diagnostic images for the DTU cameras into ``outputs/``.

    Walks the module-level ``poses`` / ``intrinsics`` pairs under
    ``torch.no_grad()`` and, for view index ``i``, writes:

    * ``outputs/weights_{i:04}.png`` -- an ``r`` x ``c`` grid with one panel
      per slice of the ``BasisBRDF`` render (only when ``bsdf`` is a
      ``ComposeSpatialVarying``).
    * ``outputs/normals_{i:04}.png`` -- the render produced with the
      ``Debug`` integrator.
    * ``outputs/got_{i:04}.png`` -- the render with the module-level
      ``integrator``, clamped to [0, 1] (only when it is not ``None``).

    Relies on module-level ``shape``, ``bsdf``, ``lights``, ``integrator``,
    ``device``, ``SIZE``, ``r``, ``c`` and ``unroll``. Returns ``None``.
    """
    with torch.no_grad():
        view_iter = zip(tqdm(poses), intrinsics)
        for i, (pose, intrinsic) in enumerate(view_iter):
            cameras = DTUCamera(
                pose=pose[None, ...],
                intrinsic=intrinsic[None, ...],
                device=device,
            )
            # Settings common to all three renders of this view.
            shared = dict(
                size=SIZE,
                chunk_size=SIZE,
                bundle_size=1,
                bsdf=bsdf,
                cameras=cameras,
                lights=lights,
                device=device,
                silent=True,
            )

            if isinstance(bsdf, ComposeSpatialVarying):
                weights, _ = pt.pathtrace(
                    shape, integrator=BasisBRDF(bsdf), **shared,
                )
                fig, axes = plt.subplots(r, c)
                fig.set_figheight(10)
                fig.set_figwidth(10)
                # Repeat each value over a new trailing axis of size 3 so
                # every slice can be shown as a 3-channel image.
                tiled = weights.unsqueeze(-1).expand((*weights.shape, 3))
                for k, piece in enumerate(tiled.split(1, dim=-2)):
                    panel = axes[unroll(k, c)]
                    panel.imshow(piece.squeeze(-2).cpu().numpy())
                    panel.axis('off')
                plt.subplots_adjust(wspace=0, hspace=0)
                plt.savefig(f"outputs/weights_{i:04}.png", bbox_inches="tight")
                plt.clf()
                plt.close(fig)

            # The two renders below additionally pass background=1.
            normals, _ = pt.pathtrace(
                shape, integrator=Debug(), background=1, **shared,
            )
            save_image(f"outputs/normals_{i:04}.png", normals)

            if integrator is None:
                continue
            rendered = pt.pathtrace(
                shape, integrator=integrator, background=1, **shared,
            )[0]
            save_image(f"outputs/got_{i:04}.png", rendered.clamp(min=0, max=1))
#if ((i % ckpt_freq) == 0) and (i != 0): save_fn(i) if (i % valid_freq) == 0: with torch.no_grad(): R = R[0].unsqueeze(0) T = T[0].unsqueeze(0) cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) light_update(cameras, lights) validate, _ = pt.pathtrace( shape, size=size, chunk_size=min(size, max_valid_size), bundle_size=1, bsdf=bsdf, integrator=integrator, cameras=cameras, lights=lights, device=device, silent=True, ) save_image(valid_name_fn(i), validate) return losses losses = train_sample( nerfle, integrator=integrator, lights=lights, Rs=Rs, Ts=Ts, exp_imgs=exp_imgs, opt=opt, size=SIZE, crop_size=16, save_freq=20000, valid_freq=10000, max_valid_size=64, iters=300_000,
def train_gan(
    nerf,
    nerf_optim,
    disc,
    disc_optim,
    dataloader,
    batch_size=3,
    iters=80,
    device="cuda",
    valid_freq=250,
):
    """Alternate discriminator and generator (NeRF) updates over ``dataloader``.

    Each step trains ``disc`` on one real batch (target 1) and one rendered
    batch (target 0), then trains ``nerf`` so that ``disc`` scores the same
    rendered batch as real. Losses are binary cross-entropy on logits.

    Args:
        nerf: Generator; must provide ``assign_latent`` and ``zero_grad`` and
            be renderable by ``pt.pathtrace``.
        nerf_optim: Optimizer stepped after the generator loss.
        disc: Discriminator, called on image batches; must provide ``zero_grad``.
        disc_optim: Optimizer stepped after both discriminator losses.
        dataloader: Yields ``(data, target)`` pairs; the target is ignored.
        batch_size: Required batch size; batches of any other size are
            skipped (the progress bar still advances).
        iters: Number of passes over ``dataloader``.
        device: Device for the data, labels, latents and cameras.
        valid_freq: Every this many global steps the first rendered image is
            written to ``outputs/gan_valid_{step:05}.png``.

    Relies on module-level ``Rs``, ``Ts`` and ``latent_size``. Returns ``None``.
    """
    integrator = NeRFReproduce()
    steps_per_pass = len(dataloader)

    with trange(iters * steps_per_pass) as t:
        for j in range(iters):
            for i, (data, _tgt) in enumerate(dataloader):
                # Labels below are sized by batch_size, so a batch of any
                # other size is skipped.
                if data.shape[0] != batch_size:
                    t.update()
                    continue
                data = data.to(device)

                # --- discriminator, real batch ---
                disc.zero_grad()
                real_logits = disc(data)
                real_target = torch.ones(batch_size, device=device)
                real_loss = F.binary_cross_entropy_with_logits(real_logits, real_target)
                real_loss.backward()
                real_loss = real_loss.item()

                # --- discriminator, rendered batch ---
                latent = torch.randn(batch_size, latent_size, device=device)
                nerf.assign_latent(latent)
                # Pick batch_size distinct camera indices.
                picked = random.sample(range(Rs.shape[0]), batch_size)
                cameras = OpenGLPerspectiveCameras(
                    R=Rs[picked], T=Ts[picked], device=device,
                )
                rendered = pt.pathtrace(
                    nerf,
                    size=64,
                    chunk_size=8,
                    bundle_size=1,
                    integrator=integrator,
                    cameras=cameras,
                    background=1,
                    bsdf=None,
                    lights=None,
                    silent=True,
                    with_noise=False,
                    device=device,
                )
                # Move the last axis to position 1 before feeding ``disc``.
                fake = rendered[0].permute(0, 3, 1, 2)

                # A detached copy keeps this backward pass out of the
                # generator's graph.
                fake_logits = disc(fake.detach().clone())
                fake_target = torch.zeros(batch_size, device=device)
                fake_loss = F.binary_cross_entropy_with_logits(fake_logits, fake_target)
                fake_loss.backward()
                fake_loss = fake_loss.item()
                disc_optim.step()

                # --- generator / nerf ---
                nerf.zero_grad()
                gen_logits = disc(fake)
                gen_loss = F.binary_cross_entropy_with_logits(
                    gen_logits, torch.ones_like(fake_target),
                )
                gen_loss.backward()
                gen_loss = gen_loss.item()
                nerf_optim.step()

                t.set_postfix(
                    Dreal=f"{real_loss:.05}",
                    Dfake=f"{fake_loss:.05}",
                    G=f"{gen_loss:.05}",
                )
                t.update()

                # Global step counter across passes.
                ij = i + j * steps_per_pass
                if ij % valid_freq == 0:
                    save_image(f"outputs/gan_valid_{ij:05}.png", fake[0].permute(1, 2, 0))
                    # Reference image dump, currently disabled:
                    # save_image(f"outputs/ref_{ij:05}.png", data[0].permute(1,2,0))