Exemplo n.º 1
0
def test(*, render_integrator=False):
  """Render per-view diagnostic images for every (camera, light) pair.

  Iterates over ``zip(cam_to_worlds, light_locs)`` and, for view ``i``, writes:

  * ``outputs/nerv_weights_{i:04}.png``: a grid of the ``BasisBRDF`` weight
    images, one subplot per weight (only if ``bsdf`` is a
    ``ComposeSpatialVarying``);
  * ``outputs/nerv_wm01_{i:04}.png``: the first two weights rescaled so they
    sum to one per pixel (same condition; skipped if fewer than two weights);
  * ``outputs/nerv_normals_{i:04}.png``: the ``Debug()`` integrator output;
  * ``outputs/got_{i:04}.png``: the render with the global ``integrator``,
    clamped to [0, 1] (only if ``render_integrator`` and ``integrator`` is set).

  Args:
    render_integrator: opt in to the last render. It was previously disabled
      with a hard-coded ``and False``; the default keeps it off.
  """
  import numpy as np

  with torch.no_grad():
    for i, (c2w, lp) in enumerate(zip(tqdm(cam_to_worlds), light_locs)):
      cameras = NeRFCamera(cam_to_world=c2w.unsqueeze(0), focal=focal, device=device)
      lights = PointLights(intensity=[1,1,1], location=lp[None,...], scale=100, device=device)

      if isinstance(bsdf, ComposeSpatialVarying):
        got = pt.pathtrace(
          shape,
          size=SIZE, chunk_size=SIZE, bundle_size=1, bsdf=bsdf, integrator=BasisBRDF(bsdf),
          cameras=cameras, lights=lights, device=device, silent=True,
        )[0].clamp(min=0, max=1)
        # Repeat each value over a new trailing axis of size 3 so imshow draws
        # it as a gray RGB image, then split along dim -2 (presumably the
        # basis-weight axis -- TODO confirm against BasisBRDF) into one image
        # per weight.
        got = got.unsqueeze(-1).expand(got.shape + (3,))
        weights = [w.squeeze(-2).cpu().numpy() for w in got.split(1, dim=-2)]

        f, axes = plt.subplots(r, c)
        f.set_figheight(10)
        f.set_figwidth(10)
        for k, img in enumerate(weights):
          axes[unroll(k, c)].imshow(img)
          axes[unroll(k, c)].axis('off')
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.savefig(f"outputs/nerv_weights_{i:04}.png", bbox_inches="tight")
        plt.close(f)

        # Render the first two weights normalised against each other for an
        # easy-to-read figure. Pixels where both are zero would give 0/0 = NaN
        # (plus a RuntimeWarning), so they are left at zero instead. With fewer
        # than two weights there is nothing to compare, so the figure is skipped.
        if len(weights) >= 2:
          wm_0, wm_1 = weights[0], weights[1]
          total = wm_0 + wm_1
          nonzero = total > 0
          wm_0 = np.divide(wm_0, total, out=np.zeros_like(wm_0), where=nonzero)
          wm_1 = np.divide(wm_1, total, out=np.zeros_like(wm_1), where=nonzero)

          f, axes = plt.subplots(2)
          f.set_figheight(10)
          f.set_figwidth(10)
          axes[0].imshow(wm_0)
          axes[0].axis('off')
          axes[1].imshow(wm_1)
          axes[1].axis('off')
          plt.subplots_adjust(wspace=0, hspace=0)
          plt.savefig(f"outputs/nerv_wm01_{i:04}.png", bbox_inches="tight")
          plt.close(f)

      normals = pt.pathtrace(
        shape,
        size=SIZE, chunk_size=SIZE, bundle_size=1, bsdf=bsdf, integrator=Debug(),
        cameras=cameras, lights=lights, device=device, silent=True,
      )[0]
      save_image(f"outputs/nerv_normals_{i:04}.png", normals)

      if render_integrator and integrator is not None:
        got = pt.pathtrace(
          shape,
          size=SIZE, chunk_size=SIZE, bundle_size=1, bsdf=bsdf, integrator=integrator,
          cameras=cameras, lights=lights, device=device, silent=True,
        )[0].clamp(min=0, max=1)
        save_image(f"outputs/got_{i:04}.png", got)
Exemplo n.º 2
0
def run_mitsuba(min_elev, max_elev, min_azim, max_azim, radius, *, spp=32):
  """Render every scene kind over a PER x PER elevation/azimuth grid.

  For each ``k`` in ``kinds`` and each grid point ``(elev, azim)``:

  * the camera is placed at ``elaz_to_xyz(elev, azim, radius)``;
  * the light is placed along the same direction at ``1.05 * radius``.

  The scene is rendered ``N`` times at ``spp`` samples per pixel and the
  renders are averaged. Channel 3 is then replaced by a 0/1 mask of
  ``channel > 0``. The result is saved to ``{k}_{i:03}_{j:03}.png``.

  Args:
    min_elev, max_elev: elevation range, sampled with ``np.linspace(..., PER)``.
    min_azim, max_azim: azimuth range, sampled with ``np.linspace(..., PER)``.
    radius: camera distance passed to ``elaz_to_xyz``.
    spp: samples per pixel for each individual render (default 32).

  Raises:
    ValueError: if the global ``N`` is less than 1, since there would be no
      renders to average (previously this divided by zero).
  """
  if N < 1:
    raise ValueError(f"N must be >= 1 to average renders, got {N}")
  # The sample grids do not depend on the scene kind; compute them once.
  elevs = np.linspace(min_elev, max_elev, PER)
  azims = np.linspace(min_azim, max_azim, PER)
  for k in kinds:
    for i, elev in enumerate(elevs):
      for j, azim in enumerate(azims):
        ox, oy, oz = elaz_to_xyz(elev, azim, radius)
        lx, ly, lz = elaz_to_xyz(elev, azim, 1.05 * radius)
        scene = get_scene(ox, oy, oz, lx, ly, lz, k)
        out = render_torch(scene, spp=spp)
        for _ in range(N - 1):
          out += render_torch(scene, spp=spp)
        out /= N
        # Keep the first three channels and turn channel 3 into a 0/1 mask.
        # Cast the boolean comparison to out's dtype so torch.cat never sees
        # mixed dtypes (older PyTorch releases reject that instead of promoting).
        mask = (out[..., 3:4] > 0).to(out.dtype)
        out = torch.cat([out[..., :3], mask], dim=-1)
        save_image(f"{k}_{i:03}_{j:03}.png", out)
Exemplo n.º 3
0
def test():
    """Render per-view diagnostic images for each camera in ``zip(poses, intrinsics)``.

    For view ``i`` this writes:

    * ``outputs/weights_{i:04}.png``: a grid of the ``BasisBRDF`` weight
      images, clamped to [0, 1] (only if ``bsdf`` is a ``ComposeSpatialVarying``);
    * ``outputs/normals_{i:04}.png``: the ``Debug()`` integrator output,
      rendered with ``background=1``;
    * ``outputs/got_{i:04}.png``: the render with the global ``integrator``,
      clamped to [0, 1], with ``background=1`` (only if ``integrator`` is set).
    """

    def render(cameras, integ, **extra):
        # All renders here share these pathtrace settings; only the integrator
        # and the optional extra kwargs (e.g. background) differ. Returns the
        # first element of pathtrace's result.
        return pt.pathtrace(
            shape,
            size=SIZE,
            chunk_size=SIZE,
            bundle_size=1,
            bsdf=bsdf,
            integrator=integ,
            cameras=cameras,
            lights=lights,
            device=device,
            silent=True,
            **extra,
        )[0]

    with torch.no_grad():
        for i, (pose, intrinsic) in enumerate(zip(tqdm(poses), intrinsics)):
            cameras = DTUCamera(pose=pose[None, ...],
                                intrinsic=intrinsic[None, ...],
                                device=device)
            if isinstance(bsdf, ComposeSpatialVarying):
                # Clamp to [0, 1] like the final render. For out-of-range RGB
                # floats imshow would clip anyway, and warn on every subplot.
                got = render(cameras, BasisBRDF(bsdf)).clamp(min=0, max=1)
                f, axes = plt.subplots(r, c)
                f.set_figheight(10)
                f.set_figwidth(10)
                # Repeat each value over a new trailing axis of size 3 so imshow
                # shows a gray RGB image, then split along dim -2 (presumably
                # the basis-weight axis -- TODO confirm) into one image each.
                got = got.unsqueeze(-1).expand(got.shape + (3, ))
                for k, img in enumerate(got.split(1, dim=-2)):
                    ax = axes[unroll(k, c)]
                    ax.imshow(img.squeeze(-2).cpu().numpy())
                    ax.axis('off')
                plt.subplots_adjust(wspace=0, hspace=0)
                plt.savefig(f"outputs/weights_{i:04}.png", bbox_inches="tight")
                plt.close(f)

            normals = render(cameras, Debug(), background=1)
            save_image(f"outputs/normals_{i:04}.png", normals)

            if integrator is not None:
                got = render(cameras, integrator, background=1).clamp(min=0, max=1)
                save_image(f"outputs/got_{i:04}.png", got)
Exemplo n.º 4
0
      #if ((i % ckpt_freq) == 0) and (i != 0): save_fn(i)

      if (i % valid_freq) == 0:
        with torch.no_grad():
          R = R[0].unsqueeze(0)
          T = T[0].unsqueeze(0)
          cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
          light_update(cameras, lights)
          validate, _ = pt.pathtrace(
            shape, size=size, chunk_size=min(size, max_valid_size),
            bundle_size=1,
            bsdf=bsdf, integrator=integrator,
            cameras=cameras, lights=lights, device=device, silent=True,
          )
          save_image(valid_name_fn(i), validate)
    return losses

  losses = train_sample(
    nerfle,
    integrator=integrator,
    lights=lights,
    Rs=Rs, Ts=Ts,
    exp_imgs=exp_imgs,
    opt=opt,
    size=SIZE,
    crop_size=16,
    save_freq=20000,
    valid_freq=10000,
    max_valid_size=64,
    iters=300_000,
Exemplo n.º 5
0
def train_gan(
    nerf,
    nerf_optim,
    disc,
    disc_optim,
    dataloader,
    batch_size=3,
    iters=80,
    device="cuda",
    valid_freq=250,
):
    """Adversarially train ``nerf`` (the generator) against ``disc``.

    Each step has two phases:

    1. Update ``disc``. A real batch gets target 1. A batch rendered by
       ``nerf`` from random latents and random ``(Rs[v], Ts[v])`` poses gets
       target 0.
    2. Update ``nerf`` so that ``disc`` scores its render as real.

    Batches whose size differs from ``batch_size`` are skipped. Every
    ``valid_freq`` global steps, the first fake image is saved to
    ``outputs/gan_valid_{step:05}.png``.

    Args:
        nerf: generator; must provide ``assign_latent`` and be renderable by
            ``pt.pathtrace``.
        nerf_optim: optimizer stepped after the generator loss backward.
        disc: discriminator returning logits (fed to BCE-with-logits).
        disc_optim: optimizer stepped after the discriminator losses backward.
        dataloader: yields ``(data, _target)`` pairs; only ``data`` is used.
        batch_size: required batch size; batches of any other size are skipped.
        iters: number of passes over ``dataloader``.
        device: device for data, latents and rendering.
        valid_freq: interval, in global steps, between saved validation images.
    """
    integrator = NeRFReproduce()
    steps_per_epoch = len(dataloader)
    with trange(iters * steps_per_epoch) as t:
        for j in range(iters):
            for i, (data, _tgt) in enumerate(dataloader):
                if data.shape[0] != batch_size:
                    # Still advance the bar so its total stays accurate.
                    t.update()
                    continue
                data = data.to(device)

                # train discriminator
                # (targets are built with *_like so they match the
                # discriminator's output shape and dtype, e.g. (B,) or (B, 1))
                disc.zero_grad()

                # real data:
                pred = disc(data)
                real_loss = F.binary_cross_entropy_with_logits(
                    pred, torch.ones_like(pred))
                real_loss.backward()
                real_loss = real_loss.item()

                # fake data: fresh latents and randomly chosen training poses
                nerf.assign_latent(
                    torch.randn(batch_size, latent_size, device=device))
                v = random.sample(range(Rs.shape[0]), batch_size)
                R, T = Rs[v], Ts[v]
                cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
                fake = pt.pathtrace(nerf,
                                    size=64,
                                    chunk_size=8,
                                    bundle_size=1,
                                    integrator=integrator,
                                    cameras=cameras,
                                    background=1,
                                    bsdf=None,
                                    lights=None,
                                    silent=True,
                                    with_noise=False,
                                    device=device)[0].permute(0, 3, 1, 2)

                # Detached so the discriminator loss does not backprop into nerf.
                pred = disc(fake.detach().clone())
                fake_loss = F.binary_cross_entropy_with_logits(
                    pred, torch.zeros_like(pred))
                fake_loss.backward()
                fake_loss = fake_loss.item()

                disc_optim.step()

                # train generator/nerf: push disc to classify the fake as real.
                # This backward also fills disc's .grad; they are cleared by
                # disc.zero_grad() at the start of the next step.
                nerf.zero_grad()
                pred = disc(fake)
                gen_loss = F.binary_cross_entropy_with_logits(
                    pred, torch.ones_like(pred))
                gen_loss.backward()
                gen_loss = gen_loss.item()
                nerf_optim.step()

                t.set_postfix(Dreal=f"{real_loss:.05}",
                              Dfake=f"{fake_loss:.05}",
                              G=f"{gen_loss:.05}")
                t.update()

                # Global step index across all passes over the dataloader.
                ij = i + j * steps_per_epoch
                if ij % valid_freq == 0:
                    save_image(f"outputs/gan_valid_{ij:05}.png",
                               fake[0].permute(1, 2, 0))
                ...
            ...
        ...
    ...