# Example #1
# 0
def train_on_kind(k, envmap=True):
  """Set up a NeRFLE model, optimizer and light for the renders of kind ``k``.

  NOTE(review): this definition appears truncated in this file. The nested
  ``train_sample`` below has a signature but no body, and nothing after it is
  visible, so none of the objects built here are used for training as written.

  Args:
    k: scene identifier interpolated into the image paths
      ``mitsuba_scenes/cbox_relight/{k}_{i:03}_{j:03}.png``.
    envmap: forwarded to ``NeRFLE(envmap=...)`` and echoed in the LR log line.
  """
  Rs = []; Ts = []
  exp_imgs = []
  # NOTE(review): exp_masks is initialised but never appended to in the visible code.
  exp_masks = []
  # N_VIEWS x N_VIEWS grid of cameras at distance DIST: elevations 0..45 and
  # azimuths -90..90 (presumably degrees, per look_at_view_transform - confirm).
  for i, elev in enumerate(torch.linspace(0, 45, N_VIEWS, device=device)):
    for j, azim in enumerate(torch.linspace(-90, 90, N_VIEWS, device=device)):
      R, T = look_at_view_transform(dist=DIST, elev=elev, azim=azim, device=device)
      Rs.append(R); Ts.append(T)
      img = load_image(f"mitsuba_scenes/cbox_relight/{k}_{i:03}_{j:03}.png", (SIZE, SIZE)).to(device)
      # Keep only the first three channels as the target image.
      exp_imgs.append(img[..., :3])

  nerfle = NeRFLE(envmap=envmap, device=device)

  integrator = NeRFReproduce()

  lights = PointLights(device=device, scale=10)

  surface_lr = 8e-5
  print(f"NeRF({k}, envmap={envmap}) LR is {surface_lr}")
  # Single parameter group (the NeRFLE model), no weight decay.
  opt = torch.optim.AdamW([
    { 'params': nerfle.parameters(), 'lr':surface_lr, },
  ], lr=surface_lr, weight_decay=0)

  # Move the light to 1.05x the camera centre, i.e. slightly farther from the
  # origin along the same direction as the camera.
  def light_update(cam, light):
    light.location = cam.get_camera_center()* 1.05

  # NOTE(review): nested training-loop definition whose body is missing in this
  # file (truncated); as it stands this is not valid Python.
  def train_sample(
    shape,
    integrator,
    lights,
    Rs, Ts,
    exp_imgs,
    opt, size, crop_size,
    light_update,
    N=3, iters=40_000,
    num_ckpts=5, save_freq=50,
    valid_freq=250, max_valid_size=128,
    extra_loss=lambda mi, got, exp, mask: 0,
    save_fn=lambda i: None,
    name_fn=lambda i: f"outputs/train_{i:05}.png",
    valid_name_fn=lambda i: f"outputs/valid_{i:05}.png",
    uv_select=lambda crop_size: None,
    silent=False,
    bsdf=None,
  ):
# Example #2
# 0
# Script setup: load the training views listed in transforms_train.json, then
# build the direct-lighting integrator and the SDF surface to optimise.
# Relies on module-level names defined elsewhere: dataset, SIZE, DIR,
# load_image, Direct, SDF, and F (presumably torch.nn.functional - confirm).
iters = 25_000
print(f"{dataset}, Size: {SIZE}, Iters: {iters}")

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

# Close the JSON file deterministically, and build the path with os.path.join
# (as the image paths below already do) so DIR needs no trailing separator.
with open(os.path.join(DIR, "transforms_train.json"),
          encoding="utf-8") as tfs_file:
    tfs = json.load(tfs_file)
exp_imgs = []
exp_masks = []
# Focal length in pixels from the field-of-view angle `camera_angle_x`
# (np.tan expects radians).
focal = 0.5 * SIZE / np.tan(0.5 * float(tfs['camera_angle_x']))
cam_to_worlds = []
with torch.no_grad():
    for frame in tfs["frames"]:
        img = load_image(os.path.join(DIR, frame['file_path'] + '.png'),
                         resize=(SIZE, SIZE)).to(device)
        # The first three channels are the target colour.
        exp_imgs.append(img[..., :3])
        # Binary mask from channel 3: assuming it lies in [0, 1], values above
        # 1e-5 round up to 1 and everything else to 0.
        exp_masks.append((img[..., 3] - 1e-5).ceil())
        # Keep the top 3x4 block of the pose matrix.
        tf_mat = torch.tensor(frame['transform_matrix'],
                              dtype=torch.float,
                              device=device)[:3, :4]
        # Rescale the translation column to unit length, placing every camera
        # at distance 1 from the origin while preserving its direction.
        tf_mat[:3, 3] = F.normalize(tf_mat[:3, 3], dim=-1)
        cam_to_worlds.append(tf_mat)

integrator = Direct()

# Load the TorchScript SDF saved at models/{dataset}_sdf_f.pt; the second
# positional argument of torch.jit.load is map_location.
shape = torch.jit.load(f"models/{dataset}_sdf_f.pt", device)
density_field = SDF(sdf=shape)
# Alternative analytic initialisation, kept for experimentation:
#density_field = SDF(sdf=torch.jit.script(SphereSDF(n=2<<6)))

# Presumably the per-ray cap on marching steps inside SDF - confirm.
density_field.max_steps = 64
# Example #3
# 0
# Script setup: load DTU scan `dataset` - per-view foreground masks, images,
# and the camera matrices stored in cameras.npz.
# Relies on module-level names defined elsewhere: DIR, SIZE, load_image.
dataset = "55"
DIR = os.path.join(DIR, f"scan{dataset}")
print(f"visualize DTU, Size: {SIZE}, Scan: {dataset}")

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

num_imgs = 0
exp_imgs = []
exp_masks = []
mask_dir = os.path.join(DIR, "mask")
for f in sorted(os.listdir(mask_dir)):
    # Skip "._*" files (macOS AppleDouble metadata), which are not images.
    if f.startswith("._"):
        continue
    mask = load_image(os.path.join(mask_dir, f),
                      resize=(SIZE, SIZE)).to(device)
    num_imgs += 1
    # Max over the last axis (presumably channels), then ceil: assuming values
    # in [0, 1], any nonzero channel marks the pixel as foreground (1).
    exp_masks.append(mask.max(dim=-1)[0].ceil())

image_dir = os.path.join(DIR, "image")
for f in sorted(os.listdir(image_dir)):
    if f.startswith("._"):
        continue
    img = load_image(os.path.join(image_dir, f),
                     resize=(SIZE, SIZE)).to(device)
    exp_imgs.append(img)

# Images and masks are paired by sorted filename order, so their counts must
# match. Raise explicitly: an `assert` would be stripped under `python -O`.
if len(exp_imgs) != len(exp_masks):
    raise RuntimeError(
        f"image/mask count mismatch in {DIR}: "
        f"{len(exp_imgs)} images vs {len(exp_masks)} masks")

# NOTE(review): assumes camera i in cameras.npz corresponds to the i-th mask in
# sorted filename order - confirm for scans with missing views.
tfs = np.load(os.path.join(DIR, "cameras.npz"))
Ps = [tfs[f"world_mat_{i}"] @ tfs[f"scale_mat_{i}"] for i in range(num_imgs)]
# Example #4
# 0
def train_on_kind(k, sphere_init=False, load_bsdf=False):
    """Fit and evaluate a colocated-light reconstruction for scene kind ``k``.

    Loads an N_VIEWS x N_VIEWS grid of renders. Jointly optimises, via
    ``train_sample``, an SDF surface, a spatially-varying BSDF, point-light
    intensities and an MLP passed as ``w_isect``. Then renders the training
    views and the held-out views returned by ``test_colocate_resources``.

    Args:
        k: Scene identifier interpolated into every input/output path, e.g.
            ``mitsuba_scenes/cbox_relight/{k}_{i:03}_{j:03}.png`` and
            ``models/col_{k}_sdf.pt``.
        sphere_init: If True, start from a scripted ``SphereSDF`` instead of
            loading ``models/col_{k}_sdf.pt``.
        load_bsdf: If True, load ``models/col_{k}_bsdf.pt`` instead of
            building a freshly randomised ``ComposeSpatialVarying`` BSDF.

    The defaults reproduce the original configuration: load the saved SDF and
    build a new BSDF. When the module-level ``iters`` is positive, the trained
    SDF and BSDF are saved back to ``models/col_{k}_sdf.pt`` and
    ``models/col_{k}_bsdf.pt``.
    """
    # --- Training views: cameras on an elevation x azimuth grid ------------
    # Angles presumably in degrees (look_at_view_transform) - confirm.
    Rs = []
    Ts = []
    exp_imgs = []
    exp_masks = []
    elevations = torch.linspace(0, 45, N_VIEWS, device=device)
    azimuths = torch.linspace(-90, 90, N_VIEWS, device=device)
    for i, elev in enumerate(elevations):
        for j, azim in enumerate(azimuths):
            R, T = look_at_view_transform(dist=DIST,
                                          elev=elev,
                                          azim=azim,
                                          device=device)
            Rs.append(R)
            Ts.append(T)
            img = load_image(
                f"mitsuba_scenes/cbox_relight/{k}_{i:03}_{j:03}.png",
                (SIZE, SIZE)).to(device)
            # First three channels are the target colour; channel 3 the mask.
            exp_imgs.append(img[..., :3])
            exp_masks.append(img[..., 3])

    # --- Surface ------------------------------------------------------------
    if sphere_init:
        density_field = SDF(sdf=torch.jit.script(SphereSDF(n=2 << 5)))
    else:
        # map_location puts the module on `device` even if it was saved from
        # another device (e.g. a CUDA checkpoint on a CPU-only host).
        sdf = torch.jit.load(f"models/col_{k}_sdf.pt", map_location=device)
        density_field = SDF(sdf=sdf)
        density_field.max_steps = 64

    # --- Material -----------------------------------------------------------
    if load_bsdf:
        # NOTE: torch.load unpickles arbitrary objects - only load trusted
        # checkpoints. Newer PyTorch releases default to weights_only=True,
        # which may reject a whole pickled module; confirm for your version.
        learned_bsdf = torch.load(f"models/col_{k}_bsdf.pt",
                                  map_location=device)
    else:
        learned_bsdf = ComposeSpatialVarying([
            *[NeuralBSDF() for _ in range(2)],
            Diffuse(preprocess=nn.Softplus()).random(),
            Conductor(activation=nn.Softplus(), device=device).random(),
        ])

    integrator = Direct()

    lights = PointLights(device=device, scale=5)

    # Passed to train_sample as `w_isect` (presumably an occlusion term).
    occ_mlp = SkipConnMLP(
        in_size=5,
        out=1,
        device=device,
    ).to(device)

    # --- Optimiser: one AdamW group per component, no weight decay ---------
    surface_lr = 8e-5
    bsdf_lr = 8e-5
    light_lr = 8e-5
    occ_lr = 8e-5
    print(
        f"Surface LR for {k} is {surface_lr}, BSDF LR is {bsdf_lr}, L LR is {light_lr}"
    )
    opt = torch.optim.AdamW(
        [
            {'params': density_field.parameters(), 'lr': surface_lr},
            {'params': learned_bsdf.parameters(), 'lr': bsdf_lr},
            {'params': lights.intensity_parameters(), 'lr': light_lr},
            {'params': occ_mlp.parameters(), 'lr': occ_lr},
        ],
        lr=surface_lr,
        weight_decay=0,
    )

    def extra_loss(mi, got, exp, mask):
        """Extra regularisation handed to train_sample as ``extra_loss``.

        Adds an eikonal loss on ``mi.raw_normals`` when that attribute exists,
        plus 1e-2 times the mean standard deviation (over the last axis) of
        ``mi.normalized_weights`` when that attribute exists.
        """
        # TODO: might need to add in something for eikonal loss over all space
        loss = 0
        raw_n = getattr(mi, "raw_normals", None)
        if raw_n is not None:
            loss = loss + eikonal_loss(raw_n)
        raw_w = getattr(mi, 'normalized_weights', None)
        if raw_w is not None:
            loss = loss + 1e-2 * raw_w.std(dim=-1).mean()
        return loss

    def light_update(cam, light):
        # Colocated light: 1.05x the camera centre, i.e. slightly farther from
        # the origin along the same direction as the camera.
        light.location = cam.get_camera_center() * 1.05

    train_sample(
        density_field,
        bsdf=learned_bsdf,
        integrator=integrator,
        lights=lights,
        Rs=Rs,
        Ts=Ts,
        exp_imgs=exp_imgs,
        exp_masks=exp_masks,
        opt=opt,
        size=SIZE,
        crop_size=128,
        save_freq=7500,
        valid_freq=4000,
        max_valid_size=128,
        iters=iters,
        N=4,
        extra_loss=extra_loss,
        uv_select=lambda _, crop_size: rand_uv(SIZE, SIZE, crop_size),
        light_update=light_update,
        name_fn=lambda i: f"outputs/train_{k}_{i:06}.png",
        valid_name_fn=lambda i: f"outputs/valid_{k}_{i:06}.png",
        silent=True,
        really_silent=True,
        w_isect=occ_mlp,
    )

    if iters > 0:
        torch.jit.save(density_field.sdf, f"models/col_{k}_sdf.pt")
        torch.save(learned_bsdf, f"models/col_{k}_bsdf.pt")

    print("Checking train set")

    # Training set
    # NOTE(review): training passes occ_mlp as `w_isect`, this call passes
    # True, and the test-set call below passes nothing - confirm this is the
    # intended way for `test` to handle intersections.
    test(
        density_field,
        integrator=integrator,
        bsdf=learned_bsdf,
        lights=lights,
        Rs=Rs,
        Ts=Ts,
        exp_imgs=exp_imgs,
        size=SIZE,
        light_update=light_update,
        name_fn=lambda i: f"outputs/col_final_{k}_{i:03}.png",
        w_isect=True,
    )

    test_Rs, test_Ts, test_imgs, _test_masks, xyzs = test_colocate_resources(
        k, SIZE, dist=DIST, device=device)

    # Held-out views use their own light positions, taken one per call
    # (presumably one per rendered view).
    xyzs_iter = iter(xyzs)

    def test_light_update(_, light):
        light.location = next(xyzs_iter).unsqueeze(0)

    print("Starting test set")

    # Test set
    test(
        density_field,
        integrator=integrator,
        bsdf=learned_bsdf,
        lights=lights,
        Rs=test_Rs,
        Ts=test_Ts,
        exp_imgs=test_imgs,
        size=SIZE,
        light_update=test_light_update,
        name_fn=lambda i: f"outputs/col_test_{k}_{i:03}.png",
    )