Esempio n. 1
0
    def loss_fn(f_rng, cp, fp, image_id):
        """Compute the photometric MSE loss for one training image.

        Samples rays from training image ``image_id``, renders them with the
        coarse and fine models, and compares the rendered RGB with the sampled
        target colours.

        Args:
            f_rng: Indexable pair of random keys (presumably JAX PRNG keys):
                ``f_rng[0]`` is passed to ``sampler`` and ``f_rng[1]`` to
                ``run_one_iter_of_nerf``.
            cp: Coarse-model parameters, bound into ``model_coarse.apply``.
            fp: Fine-model parameters, bound into ``model_fine.apply``.
            image_id: Index into ``images["train"]`` / ``poses["train"]``.

        Returns:
            ``(loss, Losses(coarse_loss=..., fine_loss=...))`` where ``loss`` is
            the coarse MSE plus, when ``config.nerf.train.num_fine > 0``, the
            fine MSE. Without fine samples the reported fine loss is zero.
        """
        train_intrinsics = intrinsics["train"]
        H, W, focal = (
            train_intrinsics.height,
            train_intrinsics.width,
            train_intrinsics.focal_length,
        )

        ray_origins, ray_directions, target_s = sampler(
            images["train"][image_id],
            poses["train"][image_id],
            train_intrinsics,
            f_rng[0],
            config.dataset.sampler,
        )

        _, rendered_images = run_one_iter_of_nerf(
            H,
            W,
            focal,
            functools.partial(model_coarse.apply, cp),
            functools.partial(model_fine.apply, fp),
            ray_origins,
            ray_directions,
            config.nerf.train,
            config.nerf.model,
            config.dataset.projection,
            f_rng[1],
            False,
        )

        # Channel layout of ``rendered_images``: [0:3] coarse RGB, [5:8] fine
        # RGB. Channels 3:5 and 8:10 are other per-level outputs not used here.
        rgb_coarse = rendered_images[..., :3]
        rgb_fine = rendered_images[..., 5:8]
        target_rgb = target_s[..., :3]

        # jnp.mean reduces over all elements, so no explicit flatten is needed.
        coarse_loss = jnp.mean((target_rgb - rgb_coarse)**2.0)
        # Default the fine term so ``Losses`` is well-defined even when no fine
        # samples are configured (previously an UnboundLocalError).
        fine_loss = jnp.zeros_like(coarse_loss)
        loss = coarse_loss
        if config.nerf.train.num_fine > 0:
            fine_loss = jnp.mean((target_rgb - rgb_fine)**2.0)
            loss = loss + fine_loss
        return loss, Losses(coarse_loss=coarse_loss, fine_loss=fine_loss)
Esempio n. 2
0
def main():
    """Train coarse and fine NeRF models on an LLFF scene.

    Command-line arguments:
        --config: path to the YAML experiment config.
        --load-checkpoint: optional checkpoint to resume from; ignored when the
            path does not exist.

    Runs ``cfg.experiment.train_iters`` iterations. Each iteration renders
    ``cfg.nerf.train.num_random_rays`` random rays of one training image and
    minimises the sum of the coarse and fine MSE. Metrics are written to
    TensorBoard under ``cfg.experiment.logdir/cfg.experiment.id``. A random
    held-out view is rendered every ``validate_every`` iterations, and a
    checkpoint is saved every ``save_every`` iterations (both also run on the
    last iteration).

    If ``cfg.dataset.relative_lr_factor`` is set, a lower-resolution copy of
    the scene is loaded as well. Iterations with ``i % hr_frequency == 0``
    train on the high-resolution images (only every ``hr_fps``-th one is
    exposed); all other iterations train on the low-resolution images.

    Raises:
        ValueError: if ``cfg.dataset.type`` is not ``"llff"`` or the HR/LR
            options are only partially configured.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Only LLFF is supported. Without this check any other type crashed later
    # with a NameError on ``images`` / ``USE_HR_LR``.
    if cfg.dataset.type.lower() != "llff":
        raise ValueError(f"Unsupported dataset type {cfg.dataset.type!r}; "
                         "this training script only supports 'llff'.")

    # Load the (high-resolution) dataset.
    images, poses, _, _, i_test = load_llff_data(
        cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
    hwf = poses[0, :3, -1]
    poses = poses[:, :3, :4]
    if not isinstance(i_test, list):
        i_test = [i_test]
    if cfg.dataset.llffhold > 0:
        i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
    # Validation draws from the held-out test views; all others train.
    i_val = i_test
    i_train = np.setdiff1d(np.arange(images.shape[0]), i_test)
    H, W, focal = hwf
    H, W = int(H), int(W)
    images = torch.from_numpy(images)
    poses = torch.from_numpy(poses)

    # Optionally load a low-resolution copy of the scene for HR/LR training.
    USE_HR_LR = hasattr(cfg.dataset, "relative_lr_factor")
    if USE_HR_LR:
        for required in ("hr_fps", "hr_frequency"):
            if not hasattr(cfg.dataset, required):
                raise ValueError("cfg.dataset.relative_lr_factor requires "
                                 f"cfg.dataset.{required} to be set.")
        images_lr, poses_lr, _, _, i_test_lr = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor *
            cfg.dataset.relative_lr_factor)
        hwf_lr = poses_lr[0, :3, -1]
        poses_lr = poses_lr[:, :3, :4]
        if not isinstance(i_test_lr, list):
            i_test_lr = [i_test_lr]
        if cfg.dataset.llffhold > 0:
            i_test_lr = np.arange(images_lr.shape[0])[::cfg.dataset.llffhold]
        i_train_lr = np.setdiff1d(np.arange(images_lr.shape[0]), i_test_lr)
        H_lr, W_lr, focal_lr = hwf_lr
        H_lr, W_lr = int(H_lr), int(W_lr)
        images_lr = torch.from_numpy(images_lr)
        poses_lr = torch.from_numpy(poses_lr)

        # Expose only every ``hr_fps``-th HR training image.
        i_train = i_train[::cfg.dataset.hr_fps]

        print(f'LR summary: N={i_train_lr.shape[0]}, '
              f'resolution={H_lr}x{W_lr}, focal-length={focal_lr}')
        print(f'HR summary: N={i_train.shape[0]}, '
              f'resolution={H}x{W}, focal-length={focal}')

    # Seed experiment for repeatability.
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = f"cuda:{cfg.experiment.gpu_id}"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse and fine resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    model_fine = getattr(models, cfg.models.fine.type)(
        num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
        include_input_xyz=cfg.models.fine.include_input_xyz,
        include_input_dir=cfg.models.fine.include_input_dir,
        use_viewdirs=cfg.models.fine.use_viewdirs,
    )
    model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = (list(model_coarse.parameters()) +
                            list(model_fine.parameters()))
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified. map_location puts
    # all stored tensors on ``device`` even if saved from another GPU.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint,
                                map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # Exponential learning-rate decay: lr * factor ** (i / num_decay_steps).
    num_decay_steps = cfg.scheduler.lr_decay * 1000

    for i in trange(start_iter, cfg.experiment.train_iters):

        model_coarse.train()
        model_fine.train()

        # Choose the resolution for this iteration.
        if USE_HR_LR and i % cfg.dataset.hr_frequency != 0:
            H_iter, W_iter, focal_iter = H_lr, W_lr, focal_lr
            i_train_iter = i_train_lr
            images_iter, poses_iter = images_lr, poses_lr
        else:
            H_iter, W_iter, focal_iter = H, W, focal
            i_train_iter = i_train
            images_iter, poses_iter = images, poses

        img_idx = np.random.choice(i_train_iter)
        img_target = images_iter[img_idx].to(device)
        pose_target = poses_iter[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H_iter, W_iter,
                                                     focal_iter, pose_target)

        # Pick ``num_random_rays`` distinct pixel coordinates.
        coords = torch.stack(
            meshgrid_xy(
                torch.arange(H_iter).to(device),
                torch.arange(W_iter).to(device)),
            dim=-1,
        )
        coords = coords.reshape((-1, 2))
        select_inds = np.random.choice(coords.shape[0],
                                       size=(cfg.nerf.train.num_random_rays),
                                       replace=False)
        select_inds = coords[select_inds]
        rows, cols = select_inds[:, 0], select_inds[:, 1]
        ray_origins = ray_origins[rows, cols, :]
        ray_directions = ray_directions[rows, cols, :]
        target_s = img_target[rows, cols, :]

        rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
            H_iter,
            W_iter,
            focal_iter,
            model_coarse,
            model_fine,
            ray_origins,
            ray_directions,
            cfg,
            mode="train",
            encode_position_fn=encode_position_fn,
            encode_direction_fn=encode_direction_fn,
        )

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_s[..., :3])
        fine_loss = torch.nn.functional.mse_loss(rgb_fine[..., :3],
                                                 target_s[..., :3])
        loss = coarse_loss + fine_loss
        loss.backward()
        # NOTE: PSNR is computed from the summed coarse + fine MSE.
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates.
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        is_last_iter = i == cfg.experiment.train_iters - 1
        if i % cfg.experiment.print_every == 0 or is_last_iter:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation: render one random held-out view at full HR resolution.
        if i % cfg.experiment.validate_every == 0 or is_last_iter:
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            model_fine.eval()

            start = time.time()
            with torch.no_grad():
                img_idx = np.random.choice(i_val)
                img_target = images[img_idx].to(device)
                pose_target = poses[img_idx, :3, :4].to(device)
                ray_origins, ray_directions = get_ray_bundle(
                    H, W, focal, pose_target)
                rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                    H,
                    W,
                    focal,
                    model_coarse,
                    model_fine,
                    ray_origins,
                    ray_directions,
                    cfg,
                    mode="validation",
                    encode_position_fn=encode_position_fn,
                    encode_direction_fn=encode_direction_fn,
                )
                target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(),
                                  i)
                writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                # Tag was previously misspelled "validataion/psnr".
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                writer.add_image("validation/rgb_fine",
                                 cast_to_image(rgb_fine[..., :3]), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or is_last_iter:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict": model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
def main():
    """Render the test frames of a trained expression-conditioned NeRF.

    Command-line arguments:
        --config: YAML experiment config.
        --checkpoint: trained model checkpoint to render from.
        --savedir: output directory for rendered RGB frames (default
            ``./renders/``). Normal maps always go to ``<savedir>/normals``.
        --save-disparity-image / --save-error-image: additionally write
            disparity maps / photometric-error plots into the ``disparity`` /
            ``error`` subfolders.

    For each test frame an RGB image is rendered (conditioned on an expression,
    an optional latent code and a background prior) and saved together with a
    normal map. The hand-edited switches before the render loop override pose,
    expression, latent code and view direction for ablation studies. With the
    current defaults (``ablate = 'view_dir'``), every frame uses pose and
    expression 100, only the viewing directions follow
    ``render_poses[240 + i]``, and the latent code is pinned to index-map
    row 10.

    Raises:
        ValueError: if ``cfg.dataset.type`` is not ``"blender"``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir",
                        type=str,
                        default='./renders/',
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image",
                        action="store_true",
                        help="Save disparity images too.")
    parser.add_argument("--save-error-image",
                        action="store_true",
                        help="Save photometric error visualization")
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Only the FLAME loader (dataset type "blender") returns the per-frame
    # expressions the render loop needs. Other types used to crash later with
    # a NameError on ``expressions``.
    if cfg.dataset.type.lower() != "blender":
        raise ValueError(
            f"Unsupported dataset type {cfg.dataset.type!r}; this renderer "
            "needs per-frame expressions and only supports 'blender'.")
    images, poses, _, hwf, i_split, expressions, _, _ = load_flame_data(
        cfg.dataset.basedir,
        half_res=cfg.dataset.half_res,
        testskip=cfg.dataset.testskip,
        test=True)
    # The returned split is used directly as the set of frames to render.
    i_test = i_split

    # Device on which to run.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        # NOTE(review): num_layers / hidden_size are read from the *coarse*
        # config for the fine model; confirm this is intended.
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    # map_location loads every stored tensor directly onto ``device``.
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint.get("model_fine_state_dict"):
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except (AttributeError, RuntimeError) as err:
            # AttributeError: no fine model configured (model_fine is None).
            # RuntimeError: state-dict keys/shapes do not match the model.
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file.")
            print(err)
    if "height" in checkpoint:
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint:
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint:
        hwf[2] = checkpoint["focal_length"]

    # Optional background prior stored in the checkpoint.
    background = checkpoint.get("background")
    if background is not None:
        print("loaded background with shape ", background.shape)
        # Tensor.to() is not in-place, so keep the returned tensor.
        background = background.to(device)

    # Optional per-image latent codes. Column 1 of ``idx_map`` is used below
    # as the row index into ``latent_codes``; negative entries are skipped.
    latent_codes = checkpoint.get("latent_codes")
    use_latent_code = latent_codes is not None
    idx_map = None
    if use_latent_code:
        latent_codes = latent_codes.to(device)
        print("loading index map for latent codes...")
        idx_map = np.load(os.path.join(cfg.dataset.basedir,
                                       "index_map.npy")).astype(int)
        print("loaded latent codes from checkpoint, with shape ",
              latent_codes.shape)
    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    # Replace any checkpoint background with a fixed frame from the dataset's
    # ``bg`` folder.
    replace_background = True
    if replace_background:
        from PIL import Image
        background = Image.open(
            os.path.join(cfg.dataset.basedir, "bg", "00050.png"))
        # PIL's thumbnail() takes (width, height); it only shrinks and keeps
        # the aspect ratio. Match the render resolution in hwf.
        background.thumbnail((int(hwf[1]), int(hwf[0])))
        background = torch.from_numpy(
            np.array(background).astype(float)).to(device)
        background = background / 255
        print('loaded custom background of shape', background.shape)

    # Create directories to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"),
                    exist_ok=True)
    if configargs.save_error_image:
        os.makedirs(os.path.join(configargs.savedir, "error"), exist_ok=True)
    os.makedirs(os.path.join(configargs.savedir, "normals"), exist_ok=True)

    # Poses and expressions of the frames to render.
    render_poses = poses[i_test].float().to(device)
    render_expressions = expressions[i_test].float().to(device)

    # ---- Hand-edited experiment switches ------------------------------------
    no_background = False      # render without a background prior
    no_expressions = False     # zero out every expression vector
    no_lcode = False           # use all-zero latent codes
    nerf = False               # plain NeRF: implies all three no_* switches
    frontalize = False         # always render with the first test pose
    interpolate_mouth = False  # sweep expression[68] over [-1, 1) on frame 241
    # Conditioning input to ablate: 'expression', 'latent_code' or 'view_dir'.
    # Any other value disables the ablation.
    ablate = 'view_dir'

    if nerf:
        no_background = True
        no_expressions = True
        no_lcode = True
    if no_background:
        background = None
    if no_expressions:
        render_expressions = torch.zeros_like(render_expressions,
                                              device=render_expressions.device)
    if no_lcode:
        use_latent_code = True
        latent_codes = torch.zeros(5000, 32, device=device)

    # The background does not change inside the loop, so flatten it once.
    background_prior = (background.view(-1, 3)
                        if background is not None else None)

    # Evaluation loop.
    times_per_image = []
    index_of_image_after_train_shuffle = 0
    for i, expression in enumerate(tqdm(render_expressions)):
        start = time.time()
        with torch.no_grad():
            pose = render_poses[i]

            if interpolate_mouth:
                frame_id = 241
                num_images = 150
                pose = render_poses[frame_id]
                expression = render_expressions[frame_id].clone()
                expression[68] = torch.arange(-1, 1, 2 / num_images,
                                              device=device)[i]

            if frontalize:
                pose = render_poses[0]

            # Stays None unless the 'view_dir' ablation sets it (previously
            # unbound in that case). NOTE(review): assumes
            # run_one_iter_of_nerf accepts None here.
            ray_directions_ablation = None
            if ablate == 'expression':
                pose = render_poses[100]
            elif ablate == 'latent_code':
                pose = render_poses[100]
                expression = render_expressions[100]
                if idx_map is not None and idx_map[100 + i, 1] >= 0:
                    index_of_image_after_train_shuffle = idx_map[100 + i, 1]
            elif ablate == 'view_dir':
                # Fixed pose and expression; only the ray directions follow
                # render_poses[240 + i].
                pose = render_poses[100]
                expression = render_expressions[100]
                _, ray_directions_ablation = get_ray_bundle(
                    hwf[0], hwf[1], hwf[2], render_poses[240 + i][:3, :4])

            pose = pose[:3, :4]

            if (use_latent_code and idx_map is not None
                    and idx_map[i, 1] >= 0):
                index_of_image_after_train_shuffle = idx_map[i, 1]
            # Author's hard-coded override pinning the latent code to index-map
            # row 10 ("USE THIS if not ablating!").
            if idx_map is not None:
                index_of_image_after_train_shuffle = idx_map[10, 1]

            latent_code = (
                latent_codes[index_of_image_after_train_shuffle].to(device)
                if use_latent_code else None)

            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, _, _, rgb_fine, disp_fine, _, weights = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression,
                background_prior=background_prior,
                latent_code=latent_code,
                ray_directions_ablation=ray_directions_ablation)
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            # Use the same focal length the rays were generated with.
            normals = torch_normal_map(disp_fine, hwf[2], weights, clean=True)
            save_plt_image(
                normals.cpu().numpy().astype('uint8'),
                os.path.join(configargs.savedir, 'normals', f"{i:04d}.png"))

        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3],
                                        cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp_fine))
            if configargs.save_error_image:
                savefile = os.path.join(configargs.savedir, "error",
                                        f"{i:04d}.png")
                error_image(images[i_test][i], rgb.cpu().numpy())
                plt.savefig(savefile,
                            pad_inches=0,
                            bbox_inches='tight',
                            dpi=54)
                # Close the current figure (presumably the one error_image
                # drew) so figures do not pile up across frames.
                plt.close()
        tqdm.write(
            f"Avg time per image: {sum(times_per_image) / len(times_per_image)}")
Esempio n. 4
0
def main():
    """Evaluate a trained NeRF checkpoint on a dataset's test views.

    Command-line arguments:
        --config: YAML experiment config.
        --checkpoint: checkpoint holding the trained coarse (and optionally
            fine) model.
        --dual-render: forwarded to ``run_one_iter_of_nerf`` as
            ``dual_render``.
        --gpu-id: CUDA device to use when CUDA is available (default 0).

    Renders every test view and prints the PSNR of the final prediction (fine
    model if it produced output, otherwise coarse) next to the fine and coarse
    PSNRs. Finally prints the mean final-prediction PSNR over all views.

    Raises:
        ValueError: if ``cfg.dataset.type`` is neither ``"blender"`` nor
            ``"llff"``.
    """
    # Config options:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Checkpoint / pre-trained model to evaluate.")
    parser.add_argument("--dual-render", action='store_true', default=False)
    parser.add_argument('--gpu-id',
                        type=int,
                        default=0,
                        help="id of the CUDA GPU to use (default: 0)")
    configargs = parser.parse_args()

    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Dataset:
    dataset_type = cfg.dataset.type.lower()
    if dataset_type == "blender":
        images, poses, _, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        _, _, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        if cfg.nerf.train.white_background:
            # Composite RGBA over a white background using the alpha channel.
            images = images[..., :3] * images[..., -1:] + (1.0 -
                                                           images[..., -1:])
    elif dataset_type == "llff":
        images, poses, _, _, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        # Integer image size, as in the blender branch.
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)
    else:
        raise ValueError(f"Unsupported dataset type {cfg.dataset.type!r}; "
                         "expected 'blender' or 'llff'.")

    # Hardware
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(configargs.gpu_id)
        device = "cuda"

    # Model
    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir)

    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs)
    model_coarse.to(device)

    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs)
        model_fine.to(device)

    # Load checkpoint
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint.get("model_fine_state_dict"):
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except (AttributeError, RuntimeError) as err:
            # AttributeError: no fine model configured (model_fine is None).
            # RuntimeError: state-dict keys/shapes do not match the model.
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file.")
            print(err)
    # NOTE(review): hwf is updated from the checkpoint, but the loop below
    # renders with the dataset's H, W, focal so outputs match the targets.
    if "height" in checkpoint:
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint:
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint:
        hwf[2] = checkpoint["focal_length"]

    # Prepare model and data
    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    print("Dual render?", configargs.dual_render)

    # i_test may be a single scalar index (LLFF) or an array/list of indices
    # (blender). The old ``type(i_test) != list`` test wrapped numpy arrays
    # into a one-element list, evaluating all views as one batch.
    try:
        test_indices = list(i_test)
    except TypeError:
        test_indices = [i_test]

    # Evaluation loop
    with torch.no_grad():
        sample_psnrs = []
        for sample_num, i in enumerate(test_indices):
            print(f"Test sample {sample_num + 1}/{len(test_indices)}...")

            img_target = images[i].to(device)
            pose_target = poses[i, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                dual_render=configargs.dual_render)
            target_ray_values = img_target
            coarse_loss = img2mse(rgb_coarse[..., :3],
                                  target_ray_values[..., :3])
            psnr_coarse = mse2psnr(coarse_loss)
            # Score the final prediction: fine output when present, else
            # coarse. The old code overwrote this choice with
            # coarse + fine, and reported mse2psnr(0.0) without a fine model.
            if rgb_fine is not None:
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                psnr_fine = mse2psnr(fine_loss)
                loss = fine_loss
            else:
                psnr_fine = None
                loss = coarse_loss
            psnr = mse2psnr(loss.item())
            print(
                f"\t Loss at sample: {psnr} (f:{psnr_fine}, c:{psnr_coarse})")

            sample_psnrs.append(psnr)

        if sample_psnrs:
            print(f"Validation PSNR: {sum(sample_psnrs) / len(sample_psnrs)}.")
        else:
            print("No test samples to evaluate.")