Exemplo n.º 1
0
def report_result(args, model, imgs, poses, hwf, bound):
    """
    Compute the average PSNR of the model's renderings over held-out views.

    For every pose, one ray per pixel is generated, points are sampled
    without perturbation, and the rays are rendered (white background) in
    chunks of ``args.test_batchsize`` under ``torch.no_grad()``. Each rendered
    view is compared against the matching entry of ``imgs``.

    Args:
        args: config namespace; ``num_samples`` and ``test_batchsize`` are read.
        model: callable returning ``(rgbs, sigmas)`` for a batch of points.
        imgs: ground-truth images, iterated in lockstep with the poses; each
            rendered view is reshaped to the shape of its image.
        poses: camera poses of the held-out views.
        hwf: camera parameters forwarded to ``get_rays_shapenet``
            (presumably height, width, focal).
        bound: ``bound[0]`` / ``bound[1]`` are passed to ``sample_points``
            as the sampling range (presumably near / far).

    Returns:
        Scalar tensor: mean over views of ``-10 * log10(MSE)``.
    """
    all_origins, all_directions = get_rays_shapenet(hwf, poses)

    per_view_psnr = []
    for target, origins, directions in zip(imgs, all_origins, all_directions):
        origins = origins.reshape(-1, 3)
        directions = directions.reshape(-1, 3)
        t_vals, xyz = sample_points(origins,
                                    directions,
                                    bound[0],
                                    bound[1],
                                    args.num_samples,
                                    perturb=False)

        total_rays = directions.shape[0]
        chunk = args.test_batchsize
        with torch.no_grad():
            pieces = []
            for start in range(0, total_rays, chunk):
                window = slice(start, start + chunk)
                rgbs, sigmas = model(xyz[window])
                pieces.append(
                    volume_render(rgbs, sigmas, t_vals[window],
                                  white_bkgd=True))
            rendered = torch.cat(pieces, dim=0).reshape_as(target)
            mse = F.mse_loss(target, rendered)
            per_view_psnr.append(-10 * torch.log10(mse))

    return torch.stack(per_view_psnr).mean()
Exemplo n.º 2
0
def test_time_optimize(args, model, optim, imgs, poses, hwf, bound):
    """
    Fine-tune the meta-trained ``model`` in place on the available views.

    Performs ``args.tto_steps`` optimizer steps. Each step draws
    ``args.tto_batchsize`` ray indices with ``torch.randint`` (repeats are
    possible), samples points along those rays with perturbation, renders
    them against a white background and minimises the MSE to the matching
    pixels of ``imgs``.

    Args:
        args: config namespace; ``tto_steps``, ``tto_batchsize`` and
            ``num_samples`` are read.
        model: callable returning ``(rgbs, sigmas)`` for a batch of points.
        optim: optimizer stepped once per iteration (presumably over
            ``model``'s parameters).
        imgs: training images; flattened to rows of 3 colour values.
        poses: camera poses matching ``imgs``.
        hwf: camera parameters forwarded to ``get_rays_shapenet``.
        bound: ``bound[0]`` / ``bound[1]`` are the sampling range passed to
            ``sample_points``.

    Returns:
        None; ``model`` and ``optim`` are updated in place.
    """
    target_rgb = imgs.reshape(-1, 3)

    origins, directions = get_rays_shapenet(hwf, poses)
    origins = origins.reshape(-1, 3)
    directions = directions.reshape(-1, 3)
    total_rays = directions.shape[0]

    for _ in range(args.tto_steps):
        # Index selection must precede sample_points: both consume RNG state.
        picks = torch.randint(total_rays, size=[args.tto_batchsize])
        batch_rgb = target_rgb[picks]
        t_vals, xyz = sample_points(origins[picks],
                                    directions[picks],
                                    bound[0],
                                    bound[1],
                                    args.num_samples,
                                    perturb=True)

        optim.zero_grad()
        rgbs, sigmas = model(xyz)
        predicted = volume_render(rgbs, sigmas, t_vals, white_bkgd=True)
        loss = F.mse_loss(predicted, batch_rgb)
        loss.backward()
        optim.step()


def inner_loop(model, optim, imgs, poses, hwf, bound, num_samples,
               raybatch_size, inner_steps):
    """
    Train the inner model for ``inner_steps`` optimisation steps.

    Every step picks ``raybatch_size`` random ray indices via
    ``torch.randint`` (so a ray may be picked more than once), samples
    points along those rays with perturbation, renders them on a white
    background and takes one ``optim`` step on the MSE to the true colours.

    Args:
        model: callable returning ``(rgbs, sigmas)`` for a batch of points.
        optim: optimizer stepped once per iteration (presumably over
            ``model``'s parameters).
        imgs: training images; flattened to rows of 3 colour values.
        poses: camera poses matching ``imgs``.
        hwf: camera parameters forwarded to ``get_rays_shapenet``.
        bound: ``bound[0]`` / ``bound[1]`` are the sampling range passed to
            ``sample_points``.
        num_samples: number of points sampled per ray.
        raybatch_size: number of rays per optimisation step.
        inner_steps: number of optimisation steps.

    Returns:
        None; ``model`` and ``optim`` are updated in place.
    """
    flat_rgb = imgs.reshape(-1, 3)

    ray_o_all, ray_d_all = get_rays_shapenet(hwf, poses)
    ray_o_all = ray_o_all.reshape(-1, 3)
    ray_d_all = ray_d_all.reshape(-1, 3)
    n_rays = ray_d_all.shape[0]

    for _ in range(inner_steps):
        # Draw indices before sample_points so the RNG sequence is unchanged.
        chosen = torch.randint(n_rays, size=[raybatch_size])
        gt_colors = flat_rgb[chosen]
        t_vals, xyz = sample_points(ray_o_all[chosen],
                                    ray_d_all[chosen],
                                    bound[0],
                                    bound[1],
                                    num_samples,
                                    perturb=True)

        optim.zero_grad()
        rgbs, sigmas = model(xyz)
        rendered = volume_render(rgbs, sigmas, t_vals, white_bkgd=True)
        batch_loss = F.mse_loss(rendered, gt_colors)
        batch_loss.backward()
        optim.step()
Exemplo n.º 4
0
def test_time_optimize(args, model, meta_state_dict, tto_view):
    """
    Adapt the meta-trained weights to a single view and return the result.

    ``model`` is reset to ``meta_state_dict`` and then optimised with plain
    SGD (learning rate ``args.tto_lr``) for ``args.tto_steps`` steps on
    random ray batches of ``args.tto_batchsize`` drawn from ``tto_view``.
    Rendering here uses ``volume_render``'s default background setting.

    Args:
        args: config namespace; ``tto_lr``, ``tto_steps``, ``tto_batchsize``
            and ``num_samples`` are read.
        model: network whose weights are overwritten and then optimised.
        meta_state_dict: state dict loaded into ``model`` before optimising.
        tto_view: dict with keys ``'img'``, ``'H'``, ``'W'``, ``'kinv'``,
            ``'pose'`` and ``'bound'`` (``bound[0]`` / ``bound[1]`` are the
            sampling range).

    Returns:
        A deep copy of ``model.state_dict()`` after optimisation.
    """
    model.load_state_dict(meta_state_dict)
    sgd = torch.optim.SGD(model.parameters(), args.tto_lr)

    target_rgb = tto_view['img'].reshape(-1, 3)
    origins, directions = get_rays_tourism(tto_view['H'], tto_view['W'],
                                           tto_view['kinv'], tto_view['pose'])
    origins = origins.reshape(-1, 3)
    directions = directions.reshape(-1, 3)
    total_rays = directions.shape[0]

    for _ in range(args.tto_steps):
        # Index draw precedes sample_points to keep the RNG order.
        picks = torch.randint(total_rays, size=[args.tto_batchsize])
        batch_rgb = target_rgb[picks]
        t_vals, xyz = sample_points(origins[picks],
                                    directions[picks],
                                    tto_view['bound'][0],
                                    tto_view['bound'][1],
                                    args.num_samples,
                                    perturb=True)

        sgd.zero_grad()
        rgbs, sigmas = model(xyz)
        predicted = volume_render(rgbs, sigmas, t_vals)
        loss = F.mse_loss(predicted, batch_rgb)
        loss.backward()
        sgd.step()

    # Deep copy so later training of `model` cannot mutate the returned weights.
    return copy.deepcopy(model.state_dict())
Exemplo n.º 5
0
def create_360_video(args, model, hwf, bound, device, scene_id, savedir):
    """
    Render an orbit of views around the current shape and save it as MP4.

    Poses come from ``get_360_poses(args.radius)`` (moved to ``device``).
    Each pose is rendered without sampling perturbation on a white background
    in chunks of ``args.test_batchsize`` rays, clipped to [0, 1], scaled to
    0-255 and cast to ``uint8``. All frames are written to
    ``savedir / f"{scene_id}.mp4"`` at 30 fps with ``imageio.mimwrite``.

    Args:
        args: config namespace; ``radius``, ``num_samples`` and
            ``test_batchsize`` are read.
        model: callable returning ``(rgbs, sigmas)`` for a batch of points.
        hwf: camera parameters; ``int(hwf[0])`` / ``int(hwf[1])`` give the
            frame height / width.
        bound: ``bound[0]`` / ``bound[1]`` are the sampling range.
        device: device the orbit poses are moved to.
        scene_id: used as the output file stem.
        savedir: ``pathlib``-style directory (``joinpath`` is called on it).

    Returns:
        None.
    """
    orbit_poses = get_360_poses(args.radius).to(device)
    all_origins, all_directions = get_rays_shapenet(hwf, orbit_poses)

    frames = []
    for origins, directions in zip(all_origins, all_directions):
        origins = origins.reshape(-1, 3)
        directions = directions.reshape(-1, 3)
        t_vals, xyz = sample_points(origins,
                                    directions,
                                    bound[0],
                                    bound[1],
                                    args.num_samples,
                                    perturb=False)

        total_rays = directions.shape[0]
        chunk = args.test_batchsize
        with torch.no_grad():
            pieces = []
            for start in range(0, total_rays, chunk):
                window = slice(start, start + chunk)
                rgbs, sigmas = model(xyz[window])
                pieces.append(
                    volume_render(rgbs, sigmas, t_vals[window],
                                  white_bkgd=True))
            image = torch.cat(pieces, dim=0).reshape(int(hwf[0]),
                                                     int(hwf[1]), 3)
            image = torch.clip(image, min=0, max=1)
            frames.append((255 * image).to(torch.uint8))

    clip_array = torch.stack(frames, dim=0).cpu().numpy()
    imageio.mimwrite(savedir.joinpath(f"{scene_id}.mp4"), clip_array, fps=30)
Exemplo n.º 6
0
def report_result(model, img, rays_o, rays_d, bound, num_samples, raybatch_size):
    """
    Compute the PSNR of the model's rendering of one held-out view.

    All rays are sampled without perturbation, then rendered in chunks of
    ``raybatch_size`` under ``torch.no_grad()`` using ``volume_render``'s
    default background setting.

    Args:
        model: callable returning ``(rgbs, sigmas)`` for a batch of points.
        img: ground-truth view; flattened to rows of 3 colour values.
        rays_o: ray origins; flattened to rows of 3.
        rays_d: ray directions; flattened to rows of 3.
        bound: ``bound[0]`` / ``bound[1]`` are the sampling range.
        num_samples: number of points sampled per ray.
        raybatch_size: number of rays rendered per chunk.

    Returns:
        Scalar tensor ``-10 * log10(MSE)`` between rendering and ``img``.
    """
    target_rgb = img.reshape(-1, 3)
    origins = rays_o.reshape(-1, 3)
    directions = rays_d.reshape(-1, 3)

    t_vals, xyz = sample_points(origins, directions, bound[0], bound[1],
                                num_samples, perturb=False)

    total_rays = directions.shape[0]
    with torch.no_grad():
        rendered = []
        for start in range(0, total_rays, raybatch_size):
            stop = start + raybatch_size
            rgbs, sigmas = model(xyz[start:stop])
            rendered.append(volume_render(rgbs, sigmas, t_vals[start:stop]))
        mse = F.mse_loss(torch.cat(rendered, dim=0), target_rgb)
        psnr = -10 * torch.log10(mse)

    return psnr
Exemplo n.º 7
0
def synthesize_view(args, model, H, W, kinv, pose, bound):
    """
    Render a novel view from the given camera intrinsics and pose.

    Rays come from ``get_rays_tourism(H, W, kinv, pose)``. Points are sampled
    without perturbation, and the rays are rendered in chunks of
    ``args.test_batchsize`` under ``torch.no_grad()`` using
    ``volume_render``'s default background setting.

    Args:
        args: config namespace; ``num_samples`` and ``test_batchsize`` are read.
        model: callable returning ``(rgbs, sigmas)`` for a batch of points.
        H: image height (output is reshaped to ``(H, W, 3)``).
        W: image width.
        kinv: forwarded to ``get_rays_tourism`` (presumably the inverse
            intrinsics matrix).
        pose: camera pose forwarded to ``get_rays_tourism``.
        bound: ``bound[0]`` / ``bound[1]`` are the sampling range.

    Returns:
        Rendered image tensor of shape ``(H, W, 3)``.
    """
    origins, directions = get_rays_tourism(H, W, kinv, pose)
    origins = origins.reshape(-1, 3)
    directions = directions.reshape(-1, 3)

    t_vals, xyz = sample_points(origins,
                                directions,
                                bound[0],
                                bound[1],
                                args.num_samples,
                                perturb=False)

    total_rays = directions.shape[0]
    chunk = args.test_batchsize
    with torch.no_grad():
        pieces = []
        for start in range(0, total_rays, chunk):
            window = slice(start, start + chunk)
            rgbs, sigmas = model(xyz[window])
            pieces.append(volume_render(rgbs, sigmas, t_vals[window]))
        image = torch.cat(pieces, dim=0).reshape(H, W, 3)
    return image