Example #1: evaluate a trained ImageTranslator on a single held-out test view
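All four snippets omit their imports and the @hydra.main entry point. A plausible preamble for Example #1 is sketched below; the module paths for the project-local helpers (CowMultiViews, ImageTranslator, Camera, Rasterizer, Stats, visualize_image_outputs) are assumptions and will vary with the repository layout.

import collections
import math
import os

import hydra
import numpy as np
import torch
from omegaconf import DictConfig
from pytorch3d.ops import interpolate_face_attributes
from torch.utils.data import Subset
from visdom import Visdom

# Hypothetical project-local modules; adjust to the actual layout.
# from dataset import CowMultiViews
# from renderer import Camera, Rasterizer
# from translator import ImageTranslator
# from utils import Stats, visualize_image_outputs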
def main(cfg: DictConfig):
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    obj_path = cfg.data.obj_path
    texture_path = cfg.data.texture_path
    views_folder = cfg.data.views_folder
    params_file = os.path.join(views_folder, "params.json")
    dataset = CowMultiViews(obj_path,
                            views_folder,
                            texture_path,
                            params_file=params_file)

    train_dataset, validation_dataset, test_dataset = CowMultiViews.random_split_dataset(
        dataset, train_fraction=0.7, validation_fraction=0.2)

    del dataset
    train_dataset.unit_normalize()
    validation_dataset.unit_normalize()
    test_dataset.unit_normalize()

    mesh_verts = test_dataset.get_verts()
    mesh_edges = test_dataset.get_edges()
    mesh_vert_normals = test_dataset.get_vert_normals()
    mesh_texture = test_dataset.get_texture()
    pytorch_mesh = test_dataset.pytorch_mesh.cuda()
    face_attrs = test_dataset.get_faces_as_vertex_matrices()

    feature_size = test_dataset.param_vectors.shape[1]

    torch_verts = torch.from_numpy(np.array(mesh_verts)).float().cuda()
    torch_edges = torch.from_numpy(np.array(mesh_edges)).long().cuda()
    torch_normals = torch.from_numpy(
        np.array(mesh_vert_normals)).float().cuda()
    torch_texture = torch.from_numpy(np.array(mesh_texture)).float().cuda()
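    # Reorder the texture from (H, W, C) to (1, C, H, W) for the translator.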
    torch_texture = torch.unsqueeze(torch_texture.permute(2, 0, 1), 0)
    torch_face_attrs = torch.from_numpy(np.array(face_attrs)).float().cuda()

    # Evaluate a single fixed view; the commented expression picks one at random.
    subset_indices = [82]  # random.sample(list(range(len(test_dataset))), 1)
    test_subset = Subset(test_dataset, subset_indices)
    print(subset_indices, len(test_subset))
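    # Iterating the Subset directly (rather than a DataLoader) yields raw
    # dataset items, so the loop below converts from numpy and adds the batch
    # dimension itself.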

    image_translator = ImageTranslator(input_dim=6,
                                       output_dim=3,
                                       image_size=tuple(cfg.data.image_size)).cuda()

    mse_loss = torch.nn.MSELoss()

    # Initialize the optimizer (vestigial here: this script only runs
    # inference, so the optimizer is never stepped).
    optimizer = torch.optim.Adam(
        image_translator.parameters(),
        lr=cfg.optimizer.lr,
    )

    stats = None
    start_epoch = 0
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(),
                                   cfg.checkpoint_path)

    # Init the stats object.
    if stats is None:
        stats = Stats(["mse_loss", "sec/it"])

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma**(
            epoch / cfg.optimizer.lr_scheduler_step_size)
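    # e.g. with gamma = 0.1 and step_size = 1000, the learning rate falls to
    # 0.1x the base rate at epoch 1000 and 0.01x at epoch 2000.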

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda,
                                                     last_epoch=start_epoch -
                                                     1,
                                                     verbose=False)

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    loaded_data = torch.load(checkpoint_path)

    image_translator.load_state_dict(loaded_data["model"], strict=False)
    image_translator.eval()
    stats.new_epoch()

    image_list = []
    for iteration, data in enumerate(test_subset):
        print(iteration)

        views, param_vectors = data
        views = torch.unsqueeze(torch.from_numpy(views), 0)
        param_vectors = torch.unsqueeze(torch.from_numpy(param_vectors), 0)
        views = views.float().cuda()
        param_vectors = param_vectors.float().cuda()
        camera_instance = Camera()
        camera_instance.lookAt(param_vectors[0][0],
                               math.degrees(param_vectors[0][1]),
                               math.degrees(param_vectors[0][2]))

        rasterizer_instance = Rasterizer()
        rasterizer_instance.init_rasterizer(camera_instance.camera)
        fragments = rasterizer_instance.rasterizer(pytorch_mesh)
        pix_to_face = fragments.pix_to_face
        bary_coords = fragments.bary_coords

        pix_features = torch.squeeze(
            interpolate_face_attributes(pix_to_face, bary_coords,
                                        torch_face_attrs), 3)
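        # interpolate_face_attributes returns (N, H, W, faces_per_pixel, D);
        # squeezing dim 3 (valid when faces_per_pixel == 1) leaves a dense
        # (N, H, W, D) per-pixel feature image.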
        # Broadcast the per-view parameter vector to every pixel location. Note
        # that the concatenation below is commented out, so the translator
        # currently sees only the interpolated face attributes.
        param_matrix = torch.zeros(pix_features.size()[0],
                                   pix_features.size()[1],
                                   pix_features.size()[2],
                                   param_vectors.size()[1]).float().cuda()
        param_matrix[:, :, :, :] = param_vectors
        image_features = pix_features  # torch.cat([pix_features, param_matrix], 3)
        predicted_render = image_translator(image_features, torch_texture)

        # Keep the ground-truth view and the prediction (as CHW images) for
        # visualization; with a single-element subset this holds one pair.
        image_list = [
            views[0].permute(2, 0, 1), predicted_render[0].permute(2, 0, 1)
        ]

    if viz is not None:
        visualize_image_outputs(validation_images=image_list,
                                viz=viz,
                                visdom_env=cfg.visualization.visdom_env)
Example #2: train an ImageTranslator with learnable per-face features
def main(cfg: DictConfig):
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    obj_path = cfg.data.obj_path
    texture_path = cfg.data.texture_path
    views_folder = cfg.data.views_folder
    params_file = os.path.join(views_folder, "params.json")
    dataset = CowMultiViews(obj_path, views_folder, texture_path, params_file=params_file)

    train_dataset, validation_dataset, test_dataset = CowMultiViews.random_split_dataset(dataset,
                                                                                         train_fraction=0.7,
                                                                                         validation_fraction=0.2)

    del dataset
    train_dataset.unit_normalize()
    validation_dataset.unit_normalize()

    mesh_verts = train_dataset.get_verts()
    mesh_edges = train_dataset.get_edges()
    mesh_vert_normals = train_dataset.get_vert_normals()
    mesh_texture = train_dataset.get_texture()
    pytorch_mesh = train_dataset.pytorch_mesh.cuda()

    random_face_attrs = train_dataset.get_faces_as_vertex_matrices(
        features_list=['random'], num_random_dims=cfg.training.feature_dim)
    coord_face_attrs = train_dataset.get_faces_as_vertex_matrices(
        features_list=['coord'], num_random_dims=cfg.training.feature_dim)
    normal_face_attrs = train_dataset.get_faces_as_vertex_matrices(
        features_list=['normal'], num_random_dims=cfg.training.feature_dim)

    torch_verts = torch.from_numpy(np.array(mesh_verts)).float().cuda()
    torch_edges = torch.from_numpy(np.array(mesh_edges)).long().cuda()
    torch_normals = torch.from_numpy(np.array(mesh_vert_normals)).float().cuda()
    torch_texture = torch.from_numpy(np.array(mesh_texture)).float().cuda()
    torch_texture = torch.unsqueeze(torch_texture, 0)
    # Wrap the learnable random per-face features in a Parameter built from a
    # leaf tensor; chaining .float().cuda() onto a requires_grad tensor, as the
    # original did, produces a non-leaf intermediate that cannot be optimized.
    torch_random_face_attrs = torch.nn.Parameter(
        torch.from_numpy(np.array(random_face_attrs)).float().cuda())
    torch_coord_face_attrs = torch.from_numpy(np.array(coord_face_attrs)).float().cuda()
    torch_normal_face_attrs = torch.from_numpy(np.array(normal_face_attrs)).float().cuda()

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.training.batch_size,
                                  shuffle=True, num_workers=4)
    validation_dataloader = DataLoader(validation_dataset, batch_size=cfg.training.batch_size,
                                       shuffle=True, num_workers=4)

    image_translator = ImageTranslator(input_dim=cfg.training.feature_dim + 9,
                                       output_dim=3,
                                       image_size=tuple(cfg.data.image_size)).cuda()
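    # Input channels: camera-relative coordinates (3), vertex normals (3) and
    # light-relative coordinates (3) account for the "+9" in input_dim above;
    # the remaining feature_dim channels are the learned per-face features.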

    mse_loss = torch.nn.MSELoss()

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        list(image_translator.parameters()) + [torch_random_face_attrs],
        lr=cfg.optimizer.lr,
    )
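    # The learned per-face features receive gradients through
    # interpolate_face_attributes and are optimized jointly with the network
    # weights, in the spirit of deferred neural rendering's neural textures.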

    stats = None
    start_epoch = 0
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)

    # Init the stats object.
    if stats is None:
        stats = Stats(["mse_loss", "sec/it"])

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
                epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    for epoch in range(cfg.optimizer.max_epochs):
        image_translator.train()
        stats.new_epoch()
        for iteration, data in enumerate(train_dataloader):
            optimizer.zero_grad()

            views, param_vectors = data
            views = views.float().cuda()
            param_vectors = param_vectors.float().cuda()
            camera_instance = Camera()
            camera_instance.lookAt(param_vectors[0][0],
                                   math.degrees(param_vectors[0][1]),
                                   math.degrees(param_vectors[0][2]))
            camera_location = param_vectors[0, 3:6]
            light_location = param_vectors[0, 6:9]
            torch_camera_face_attrs = torch_coord_face_attrs - camera_location
            torch_light_face_attrs = torch_coord_face_attrs - light_location
            torch_face_attrs = torch.cat([torch_camera_face_attrs,
                                          torch_normal_face_attrs,
                                          torch_light_face_attrs,
                                          torch_random_face_attrs], 2)

            rasterizer_instance = Rasterizer()
            rasterizer_instance.init_rasterizer(camera_instance.camera)
            fragments = rasterizer_instance.rasterizer(pytorch_mesh)
            pix_to_face = fragments.pix_to_face
            bary_coords = fragments.bary_coords

            pix_features = torch.squeeze(
                interpolate_face_attributes(pix_to_face, bary_coords, torch_face_attrs), 3)
            predicted_render = image_translator(pix_features, torch_texture)

            loss = 1000 * mse_loss(predicted_render, views)
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {"mse_loss": float(loss)},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

        # Learning rate adjustment is disabled; uncomment to apply the decay
        # schedule defined above.
        # lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0: # and epoch > 0:

            # Sample a validation camera/image.
            val_batch = next(iter(validation_dataloader))
            views, param_vectors = val_batch
            views = views.float().cuda()
            param_vectors = param_vectors.float().cuda()

            # Activate eval mode of the model (allows to do a full rendering pass).
            image_translator.eval()
            with torch.no_grad():
                camera_instance = Camera()
                camera_instance.lookAt(param_vectors[0][0],
                                       math.degrees(param_vectors[0][1]),
                                       math.degrees(param_vectors[0][2]))
                camera_location = param_vectors[0, 3:6]
                light_location = param_vectors[0, 6:9]
                torch_camera_face_attrs = torch_coord_face_attrs - camera_location
                torch_light_face_attrs = torch_coord_face_attrs - light_location
                torch_face_attrs = torch.cat([torch_camera_face_attrs,
                                              torch_normal_face_attrs,
                                              torch_light_face_attrs,
                                              torch_random_face_attrs], 2)

                rasterizer_instance = Rasterizer()
                rasterizer_instance.init_rasterizer(camera_instance.camera)
                fragments = rasterizer_instance.rasterizer(pytorch_mesh)
                pix_to_face = fragments.pix_to_face
                bary_coords = fragments.bary_coords

                pix_features = torch.squeeze(
                    interpolate_face_attributes(pix_to_face, bary_coords, torch_face_attrs), 3)
                # pix_features = pix_features.permute(0, 3, 1, 2)
                predicted_render = image_translator(pix_features, torch_texture)
                loss = 1000 * mse_loss(predicted_render, views)

            # Update stats with the validation metrics.
            stats.update({"mse_loss":loss}, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot that loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                visualize_image_outputs(
                    validation_images=[views[0].permute(2, 0, 1),
                                       predicted_render[0].permute(2, 0, 1)],
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                )

            # Set the model back to train mode.
            image_translator.train()

        # Checkpoint.
        if (epoch % cfg.checkpoint_epoch_interval == 0
                and len(cfg.checkpoint_path) > 0 and epoch > 0):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": image_translator.state_dict(),
                "features": torch_face_attrs,
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)
Example #3: train a NeRF RadianceFieldRenderer
def main(cfg: DictConfig):

    # Set the relevant seeds for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, " +
            "the training is unlikely to finish in reasonable time.")
        device = "cpu"

    # Initialize the Radiance Field model.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray_fine,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
    )
    # Move the model to the relevant device.
    model.to(device)

    # Init stats to None before loading.
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(),
                                   cfg.checkpoint_path)
    if len(cfg.checkpoint_path) > 0:
        # Make the root of the experiment directory.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Resume training if requested.
        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            loaded_data = torch.load(checkpoint_path)
            model.load_state_dict(loaded_data["model"])
            stats = pickle.loads(loaded_data["stats"])
            print(f"   => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )

    # Load the optimizer state dict in case we are resuming.
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
        optimizer.last_epoch = start_epoch

    # Init the stats object.
    if stats is None:
        stats = Stats([
            "loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine",
            "sec/it"
        ])

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma**(
            epoch / cfg.optimizer.lr_scheduler_step_size)

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Precache the projection rays.
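        # Rays are generated once per training camera and cached keyed by
        # camera_idx, trading memory for per-iteration raysampling cost.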
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # The validation dataloader is just an endless stream of random samples.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )
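    # replacement=True with num_samples=max_epochs lets the sampler draw a
    # fresh random validation image every time a new iterator is created.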

    # Set the model to the training mode.
    model.train()

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()  # Init a new epoch.
        for iteration, batch in enumerate(train_dataloader):
            image, camera, camera_idx = batch[0].values()
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

            # Run the forward pass of the model.
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # The loss is a sum of coarse and fine MSEs
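            # The coarse MSE supervises the network that guides importance
            # sampling along each ray; the fine MSE supervises the final
            # render, so both terms are needed to train the two-pass model.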
            loss = metrics["mse_coarse"] + metrics["mse_fine"]

            # Take the training step.
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {
                    "loss": float(loss),
                    **metrics
                },
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            # Update the visualization cache.
            visuals_cache.append({
                "camera": camera.cpu(),
                "camera_idx": camera_idx,
                "image": image.cpu().detach(),
                "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
            })

        # Adjust the learning rate.
        lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:

            # Sample a validation camera/image.
            val_batch = next(iter(val_dataloader))
            val_image, val_camera, camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            # Activate eval mode of the model (allows to do a full rendering pass).
            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            # Update stats with the validation metrics.
            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot that loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                visualize_nerf_outputs(val_nerf_out, visuals_cache, viz,
                                       cfg.visualization.visdom_env)

            # Set the model back to train mode.
            model.train()

        # Checkpoint.
        if (epoch % cfg.checkpoint_epoch_interval == 0
                and len(cfg.checkpoint_path) > 0 and epoch > 0):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)
Example #4: evaluate a trained NeRF, or render a video along a generated camera trajectory
def main(cfg: DictConfig):

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, " +
            "the testing is unlikely to finish in reasonable time.")
        device = "cpu"

    # Initialize the Radiance Field model.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray_fine,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
    )

    # Move the model to the relevant device.
    model.to(device)

    # Resume from the checkpoint.
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(),
                                   cfg.checkpoint_path)
    if not os.path.isfile(checkpoint_path):
        raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")

    print(f"Loading checkpoint {checkpoint_path}.")
    loaded_data = torch.load(checkpoint_path)
    # Do not load the cached xy grid.
    # - this allows to set an arbitrary evaluation image size.
    state_dict = {
        k: v
        for k, v in loaded_data["model"].items()
        if "_grid_raysampler._xy_grid" not in k
    }
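    # strict=False below tolerates the raysampler grid keys filtered out above.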
    model.load_state_dict(state_dict, strict=False)

    # Load the test data.
    if cfg.test.mode == "evaluation":
        _, _, test_dataset = get_nerf_datasets(
            dataset_name=cfg.data.dataset_name,
            image_size=cfg.data.image_size,
        )
    elif cfg.test.mode == "export_video":
        train_dataset, _, _ = get_nerf_datasets(
            dataset_name=cfg.data.dataset_name,
            image_size=cfg.data.image_size,
        )
        test_dataset = generate_eval_video_cameras(
            train_dataset,
            trajectory_type=cfg.test.trajectory_type,
            up=cfg.test.up,
            scene_center=cfg.test.scene_center,
            n_eval_cams=cfg.test.n_frames,
            trajectory_scale=cfg.test.trajectory_scale,
        )
        # store the video in directory (checkpoint_file - extension + '_video')
        export_dir = os.path.splitext(checkpoint_path)[0] + "_video"
        os.makedirs(export_dir, exist_ok=True)
    else:
        raise ValueError(f"Unknown test mode {cfg.test_mode}.")

    # Init the test dataloader.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    if cfg.test.mode == "evaluation":
        # Init the test stats object.
        eval_stats = [
            "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"
        ]
        stats = Stats(eval_stats)
        stats.new_epoch()
    elif cfg.test.mode == "export_video":
        # Init the frame buffer.
        frame_paths = []

    # Set the model to the eval mode.
    model.eval()

    # Run the main testing loop.
    for batch_idx, test_batch in enumerate(test_dataloader):
        print(batch_idx, len(test_dataloader))
        test_image, test_camera, camera_idx = test_batch[0].values()
        if test_image is not None:
            test_image = test_image.to(device)
        test_camera = test_camera.to(device)

        # Activate eval mode of the model (allows to do a full rendering pass).
        model.eval()
        with torch.no_grad():
            test_nerf_out, test_metrics = model(
                None,  # we do not use pre-cached cameras
                test_camera,
                test_image,
            )

        if cfg.test.mode == "evaluation":
            # Update stats with the validation metrics.
            stats.update(test_metrics, stat_set="test")
            stats.print(stat_set="test")

        elif cfg.test.mode == "export_video":
            # Store the video frame.
            frame = test_nerf_out["rgb_fine"][0].detach().cpu()
            frame_path = os.path.join(export_dir, f"frame_{batch_idx:05d}.png")
            print(f"Writing {frame_path}.")
            Image.fromarray(
                (frame.numpy() * 255.0).astype(np.uint8)).save(frame_path)
            frame_paths.append(frame_path)

    if cfg.test.mode == "evaluation":
        print(f"Final evaluation metrics on '{cfg.data.dataset_name}':")
        for stat in eval_stats:
            stat_value = stats.stats["test"][stat].get_epoch_averages()[0]
            print(f"{stat:15s}: {stat_value:1.4f}")

    elif cfg.test.mode == "export_video":
        # Convert the exported frames to a video.
        video_path = os.path.join(export_dir, "video.mp4")
        ffmpeg_bin = "ffmpeg"
        frame_regexp = os.path.join(export_dir, "frame_%05d.png")
        ffmcmd = (
            "%s -r %d -i %s -vcodec h264 -f mp4 -y -b 2000k -pix_fmt yuv420p %s"
            % (ffmpeg_bin, cfg.test.fps, frame_regexp, video_path))
        ret = os.system(ffmcmd)
        if ret != 0:
            raise RuntimeError("ffmpeg failed!")
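
os.system reports only an exit code when ffmpeg fails. A minimal alternative sketch using subprocess with the same flags, so ffmpeg's stderr is surfaced in the error (it reuses the variables defined above):

import subprocess

ffmpeg_args = [
    ffmpeg_bin, "-r", str(cfg.test.fps), "-i", frame_regexp,
    "-vcodec", "h264", "-f", "mp4", "-y", "-b", "2000k",
    "-pix_fmt", "yuv420p", video_path,
]
# capture_output=True collects stdout/stderr so failures are diagnosable.
result = subprocess.run(ffmpeg_args, capture_output=True, text=True)
if result.returncode != 0:
    raise RuntimeError(f"ffmpeg failed:\n{result.stderr}")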