Example #1
File: main.py  Project: danielgolf/nerf
def evaluation(cfg, mlp, data):
    expid = cfg.experiment.id
    logdir = cfg.experiment.logdir
    save_dir = os.path.join(logdir, expid, 'rendered')
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for i in trange(data.render_poses.shape[0]):
            pose = data.render_poses[i]
            _, ray_ori, ray_dir = get_ray_bundle(data.height, data.width,
                                                 data.focal, pose)
            img_shape = ray_ori.shape
            ray_ori = ray_ori.reshape((-1, ray_ori.shape[-1]))
            ray_dir = ray_dir.reshape((-1, ray_dir.shape[-1]))

            rgb_coarse, rgb_fine = nerf_iteration(mlp,
                                                  cfg,
                                                  torch.cat([ray_ori, ray_dir],
                                                            dim=-1),
                                                  data.near,
                                                  data.far,
                                                  mode='validation')

            if rgb_fine is not None:
                rgb = rgb_fine.reshape(img_shape)
            else:
                rgb = rgb_coarse.reshape(img_shape)

            save_file = os.path.join(save_dir, f"{i:04d}.png")
            img = np.array(to_pil_image(rgb.permute(2, 0, 1).cpu()))
            imageio.imwrite(save_file, img)
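Every example on this page calls get_ray_bundle, which none of the snippets define. Below is a minimal pinhole-camera sketch of what such a function typically computes (an assumption based on how the return values are used here, not the exact code of these projects); note that the danielgolf/nerf variant additionally returns a tensor of pixel coordinates as its first value.

import torch


def get_ray_bundle_sketch(height, width, focal, tform_cam2world):
    """Per-pixel ray origins and directions, both of shape (height, width, 3)."""
    # Pixel grid; x varies along the width axis, y along the height axis.
    ii, jj = torch.meshgrid(
        torch.arange(width, dtype=torch.float32),
        torch.arange(height, dtype=torch.float32),
        indexing="xy",
    )
    # Pinhole model: back-project pixels into camera-space viewing directions.
    directions = torch.stack(
        [
            (ii - width * 0.5) / focal,
            -(jj - height * 0.5) / focal,
            -torch.ones_like(ii),
        ],
        dim=-1,
    )
    # Rotate into world space and broadcast the camera position to every pixel.
    ray_directions = torch.sum(
        directions[..., None, :] * tform_cam2world[:3, :3], dim=-1
    )
    ray_origins = tform_cam2world[:3, -1].expand(ray_directions.shape)
    return ray_origins, ray_directions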
Example #2
File: loader.py  Project: danielgolf/nerf
    def get_valid_img(self):
        idx = np.random.choice(self.i_val)
        img = self.images[idx]
        pose = self.poses[idx]

        _, ray_ori, ray_dir = get_ray_bundle(self.height, self.width,
                                             self.focal, pose)

        ray_ori = ray_ori.reshape((-1, ray_ori.shape[-1]))
        ray_dir = ray_dir.reshape((-1, ray_dir.shape[-1]))

        rays = torch.cat([ray_ori, ray_dir], dim=-1)
        return rays, img
Example #3
def convert_poses_to_rays(poses, H, W, focal):
    ray_origins = []
    ray_directions = []
    for pose in poses:
        chunk_ray_origins, chunk_ray_directions = get_ray_bundle(
            H, W, focal, pose)

        ray_origins.append(chunk_ray_origins)
        ray_directions.append(chunk_ray_directions)

    ray_origins = torch.stack(ray_origins, 0)
    ray_directions = torch.stack(ray_directions, 0)

    return ray_origins, ray_directions
Example #4
def run_one_iter_of_tinynerf(
    height,
    width,
    focal_length,
    tform_cam2world,
    near_thresh,
    far_thresh,
    depth_samples_per_ray,
    encoding_function,
    get_minibatches_function,
    chunksize,
    model,
    encoding_function_args,
):

    # Get the "bundle" of rays through all image pixels.
    ray_origins, ray_directions = get_ray_bundle(
        height, width, focal_length, tform_cam2world
    )

    # Sample query points along each ray
    query_points, depth_values = compute_query_points_from_rays(
        ray_origins, ray_directions, near_thresh, far_thresh, depth_samples_per_ray
    )

    # "Flatten" the query points.
    flattened_query_points = query_points.reshape((-1, 3))

    # Encode the query points (default: positional encoding).
    encoded_points = encoding_function(flattened_query_points, encoding_function_args)

    # Split the encoded points into "chunks", run the model on all chunks, and
    # concatenate the results (to avoid out-of-memory issues).
    batches = get_minibatches_function(encoded_points, chunksize=chunksize)
    predictions = []
    for batch in batches:
        predictions.append(model(batch))
    radiance_field_flattened = torch.cat(predictions, dim=0)

    # "Unflatten" to obtain the radiance field.
    unflattened_shape = list(query_points.shape[:-1]) + [4]
    radiance_field = torch.reshape(radiance_field_flattened, unflattened_shape)

    # Perform differentiable volume rendering to re-synthesize the RGB image.
    rgb_predicted, _, _ = render_volume_density(
        radiance_field, ray_origins, depth_values
    )

    return rgb_predicted
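run_one_iter_of_tinynerf relies on three helpers that are not shown here. The sketches below are assumptions modeled on the standard NeRF formulation rather than these projects' exact code: uniform depth sampling along each ray, minibatch chunking, and the alpha-compositing volume-rendering quadrature described in the comments above.

import torch


def compute_query_points_from_rays_sketch(ray_origins, ray_directions,
                                          near_thresh, far_thresh, num_samples):
    # Evenly spaced depths between the near and far planes, shared by all rays.
    depth_values = torch.linspace(
        near_thresh, far_thresh, num_samples).to(ray_origins)
    # Point along each ray: o + t * d, giving shape (H, W, num_samples, 3).
    query_points = (ray_origins[..., None, :] +
                    ray_directions[..., None, :] * depth_values[..., :, None])
    return query_points, depth_values


def get_minibatches_sketch(inputs, chunksize=1024 * 8):
    # Split a flattened tensor into chunks along the first dimension.
    return [inputs[i:i + chunksize] for i in range(0, inputs.shape[0], chunksize)]


def render_volume_density_sketch(radiance_field, ray_origins, depth_values):
    # (ray_origins is kept for signature parity; it is unused in this sketch.)
    # Network output is interpreted as RGB (sigmoid) plus density (ReLU).
    rgb = torch.sigmoid(radiance_field[..., :3])
    sigma = torch.relu(radiance_field[..., 3])
    # Distances between consecutive samples; the last interval is "infinite".
    dists = depth_values[..., 1:] - depth_values[..., :-1]
    dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)
    # Alpha compositing: per-sample opacity and transmittance up to each sample.
    alpha = 1.0 - torch.exp(-sigma * dists)
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1]), 1.0 - alpha + 1e-10],
                  dim=-1), dim=-1)[..., :-1]
    weights = alpha * transmittance
    rgb_map = (weights[..., None] * rgb).sum(dim=-2)
    depth_map = (weights * depth_values).sum(dim=-1)
    acc_map = weights.sum(dim=-1)
    return rgb_map, depth_map, acc_map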
Example #5
def test_get_ray_bundle():
    tform_cam2world = np.eye(4, 4, dtype=np.float32)
    tform_cam2world[0, 3] = 2.0
    tform_cam2world[1, 3] = -3.0
    tform_cam2world[2, 3] = 5.0

    for i in range(0, 2):
        jax_fn = lambda x: get_ray_bundle(10, 10, 0.3, x)[i]
        torch_fn = lambda x: get_ray_bundle_torch(10, 10, 0.3, x)[i]

        jo, to, djos, dtos = run_and_grad(jax_fn, torch_fn, (0, ),
                                          tform_cam2world)

        assert np.allclose(jo, to, rtol=1e-3, atol=1e-5)
        assert all(
            np.allclose(djo, dto, rtol=1e-3, atol=1e-5)
            for djo, dto in zip(djos, dtos))
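The test above depends on a run_and_grad helper that is not shown. A rough sketch of what such a JAX-versus-PyTorch comparison helper might do is given below; the name, argument order, and return convention are all assumptions.

import jax
import jax.numpy as jnp
import numpy as np
import torch


def run_and_grad_sketch(jax_fn, torch_fn, argnums, *args):
    # Forward pass and gradients (of the summed output) on the JAX side.
    jax_args = [jnp.asarray(a) for a in args]
    jax_out = jax_fn(*jax_args)
    jax_grads = jax.grad(lambda *a: jax_fn(*a).sum(), argnums=argnums)(*jax_args)

    # Same computation on the PyTorch side, using autograd.
    torch_args = [torch.tensor(a, requires_grad=(k in argnums))
                  for k, a in enumerate(args)]
    torch_out = torch_fn(*torch_args)
    torch_out.sum().backward()
    torch_grads = [torch_args[k].grad.numpy() for k in argnums]

    return (np.asarray(jax_out), torch_out.detach().numpy(),
            [np.asarray(g) for g in jax_grads], torch_grads)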
Example #6
File: loader.py  Project: danielgolf/nerf
    def get_train_batch(self, size=1024):
        idx = np.random.choice(self.i_train)
        img = self.images[idx]
        pose = self.poses[idx]

        coords, ray_ori, ray_dir = get_ray_bundle(self.height, self.width,
                                                  self.focal, pose)

        ray_idx = np.random.choice(  # take a subset of all rays
            coords.shape[0],
            size=size,
            replace=False)

        coords = coords[ray_idx]
        ray_ori = ray_ori[coords[:, 0], coords[:, 1]]
        ray_dir = ray_dir[coords[:, 0], coords[:, 1]]
        img = img[coords[:, 0], coords[:, 1]]

        rays = torch.cat([ray_ori, ray_dir], dim=-1)
        return rays, img
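In this loader variant get_ray_bundle also returns a flat list of pixel coordinates, and the fancy indexing above is what turns the full (H, W, 3) ray grids into a random training batch. A small self-contained illustration of that indexing pattern follows; the shapes and names are made up for the demo, and the coordinate grid is built explicitly rather than returned by get_ray_bundle.

import numpy as np
import torch

H, W, batch_size = 4, 5, 6
ray_ori = torch.randn(H, W, 3)                       # stand-in for full ray origins
coords = torch.stack(
    torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij"),
    dim=-1).reshape(-1, 2)                           # all (row, col) pixel coords
ray_idx = torch.from_numpy(
    np.random.choice(coords.shape[0], size=batch_size, replace=False))
picked = coords[ray_idx]                             # (batch_size, 2)
batch = ray_ori[picked[:, 0], picked[:, 1]]          # (batch_size, 3)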
Example #7
    def cache_dataset(self):
        # TODO(0) testskip = args.blender_stride, offset for a small dataset
        # Unpacking data
        bundle = self.load_dataset().to(self.device)

        # Coordinates to sample from
        self.init_sampling(bundle.hwf)

        for img_idx in trange(bundle.size):
            # Create data chunk bundle
            sample = bundle[img_idx]
            sample.ray_origins, sample.ray_directions = get_ray_bundle(
                *sample.hwf, sample.poses)
            if self.cfg.dataset.use_ndc:
                # Use normalized device coordinates
                sample.ndc()

            if self.cfg.dataset.caching.sample_all or self.type == DatasetType.VALIDATION:
                self.save_dataset(sample, img_idx)
            else:
                raise NotImplementedError
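sample.ndc() presumably applies the normalized-device-coordinate warp used for forward-facing LLFF scenes. A sketch of that standard conversion, following the original NeRF formulation (the actual method on sample may differ), looks like this:

import torch


def ndc_rays_sketch(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins onto the near plane, then map origins and directions
    # into normalized device coordinates.
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    o0 = -1.0 / (W / (2.0 * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1.0 / (H / (2.0 * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1.0 + 2.0 * near / rays_o[..., 2]

    d0 = -1.0 / (W / (2.0 * focal)) * (rays_d[..., 0] / rays_d[..., 2]
                                       - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1.0 / (H / (2.0 * focal)) * (rays_d[..., 1] / rays_d[..., 2]
                                       - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2.0 * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], dim=-1)
    rays_d = torch.stack([d0, d1, d2], dim=-1)
    return rays_o, rays_d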
Example #8
def cache_nerf_dataset(args):

    images, poses, render_poses, hwf = (
        None,
        None,
        None,
        None,
    )
    i_train, i_val, i_test = None, None, None

    if args.type == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datapath,
            half_res=args.blender_half_res,
            testskip=args.blender_stride)

        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
    elif args.type == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datapath, factor=args.llff_downsample_factor)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        if not isinstance(i_test, list):
            i_test = [i_test]
        if args.llffhold > 0:
            i_test = np.arange(images.shape[0])[::args.llffhold]
        i_val = i_test
        i_train = np.array([
            i for i in np.arange(images.shape[0])
            if (i not in i_test and i not in i_val)
        ])
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    os.makedirs(os.path.join(args.savedir, "train"), exist_ok=True)
    os.makedirs(os.path.join(args.savedir, "val"), exist_ok=True)
    os.makedirs(os.path.join(args.savedir, "test"), exist_ok=True)
    np.random.seed(args.randomseed)

    for img_idx in tqdm(i_train):
        for j in range(args.num_variations):
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(
                    torch.arange(H).to(device),
                    torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            target_s = None
            save_path = None
            if not args.sample_all:
                select_inds = np.random.choice(coords.shape[0],
                                               size=(args.num_random_rays),
                                               replace=False)
                select_inds = coords[select_inds]
                ray_origins = ray_origins[select_inds[:, 0], select_inds[:,
                                                                         1], :]
                ray_directions = ray_directions[select_inds[:, 0],
                                                select_inds[:, 1], :]
                target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]
                # One file per (image, variation) pair, so variations do not
                # overwrite each other.
                save_path = os.path.join(
                    args.savedir,
                    "train",
                    str(img_idx).zfill(4) + "_" + str(j).zfill(4) + ".data",
                )
            else:
                target_s = img_target
                save_path = os.path.join(args.savedir, "train",
                                         str(img_idx).zfill(4) + ".data")

            batch_rays = torch.stack([ray_origins, ray_directions], dim=0)

            cache_dict = {
                "height": H,
                "width": W,
                "focal_length": focal,
                "ray_bundle": batch_rays.detach().cpu(),
                "target": target_s.detach().cpu(),
            }

            torch.save(cache_dict, save_path)

            if args.sample_all:
                break

    for img_idx in tqdm(i_val):
        img_target = images[img_idx].to(device)
        pose_target = poses[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H, W, focal, pose_target)

        cache_dict = {
            "height": H,
            "width": W,
            "focal_length": focal,
            "ray_origins": ray_origins.detach().cpu(),
            "ray_directions": ray_directions.detach().cpu(),
            "target": img_target.detach().cpu(),
        }

        save_path = os.path.join(args.savedir, "val",
                                 str(img_idx).zfill(4) + ".data")
        torch.save(cache_dict, save_path)
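For reference, a cached sample written by the loop above can be read back with torch.load; the path below is illustrative, and the keys match the cache_dict written for training images.

import torch

cache = torch.load("cache/train/0000.data")          # illustrative path
ray_origins, ray_directions = cache["ray_bundle"][0], cache["ray_bundle"][1]
target = cache["target"]
print(cache["height"], cache["width"], cache["focal_length"],
      ray_origins.shape, target.shape)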
Example #9
def export_ray_trace(model_coarse, model_fine, config_args, cfg,
                     encode_position_fn, encode_direction_fn, device):

    # Mesh Extraction
    samples_dimen_y = 8
    samples_dimen_x = 4
    plane_near = 0
    plane_far = 4.0
    img_size = 800
    step_size = 2
    dist_threshold = 0.002
    prob_threshold = 0.6

    # Data
    vertices, triangles, normals, diffuse = [], [], [], []
    render_poses = torch.stack([
        torch.from_numpy(pose_spherical(angleY, angleX, plane_far)).float()
        for angleY in np.linspace(-180, 180, samples_dimen_y, endpoint=False)
        for angleX in np.linspace(-90, 90, samples_dimen_x, endpoint=True)
    ],
                               dim=0)

    hwf = [img_size, img_size, 1111.1111]

    grid = get_grid(img_size)
    for i, pose in enumerate(tqdm(render_poses)):
        pose = pose[:3, :4].to(device)

        # Ray origins & directions
        ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2],
                                                     pose)

        # cfg.nerf['validation']['num_coarse'] = 64
        # cfg.nerf['validation']['num_fine'] = 64
        with torch.no_grad():
            _, _, _, rgb_fine, _, depth_fine = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )

            # Apply nn mask
            surface_points = ray_origins + ray_directions * depth_fine[...,
                                                                       None]

            acc = []
            initial_values = surface_points[grid.T[0], grid.T[1]].view(
                img_size, img_size, -1)
            size = step_size * 2 + 1
            size_samples = size**2 - 1
            for a in range(-step_size, step_size + 1):
                for b in range(-step_size, step_size + 1):
                    offset = torch.tensor([a, b])

                    new_grid = grid + offset
                    new_grid = new_grid.clamp(0, img_size - 1)

                    new_grid = surface_points[new_grid.T[0],
                                              new_grid.T[1]].view(
                                                  img_size, img_size, -1)
                    new_grid_s = ((new_grid - initial_values)**
                                  2).sum(-1) < dist_threshold

                    acc.append(new_grid_s)

            new_mask = torch.stack(
                acc, -1).sum(-1).squeeze(-1) > size_samples * prob_threshold

            dep_mask = (depth_fine > 0)
            mask = new_mask * dep_mask

            ray_origins, ray_directions, depth_fine = ray_origins[
                mask], ray_directions[mask], depth_fine[mask]
            rgb_fine = rgb_fine[mask]

            surface_points = ray_origins + ray_directions * depth_fine[...,
                                                                       None]

            vertices.append(surface_points.view(-1, 3).cpu().detach())
            normals.append((-ray_directions).view(-1, 3).cpu().detach())
            diffuse.append(rgb_fine.view(-1, 3).cpu().detach())

    # Query the whole diffuse map
    diffuse_fine = torch.cat(diffuse, dim=0).numpy()
    vertices_fine = torch.cat(vertices, dim=0).numpy()
    normals_fine = torch.cat(normals, dim=0).numpy()

    # Export model
    # export_obj(vertices_fine, [], diffuse_fine, normals_fine, "lego-sampling.obj")
    export_ply(vertices_fine, diffuse_fine, normals_fine, "lego-sampling.ply")
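export_ply is not defined in this snippet. A minimal ASCII PLY writer that would accept the (N, 3) vertex, color, and normal arrays assembled above might look like the following; this is an assumption, not the project's actual exporter.

import numpy as np


def export_ply_sketch(vertices, diffuse, normals, filename):
    # Write vertices with normals and 8-bit RGB colors as an ASCII PLY file.
    colors = (np.clip(diffuse, 0.0, 1.0) * 255).astype(np.uint8)
    with open(filename, "w") as f:
        f.write("ply\nformat ascii 1.0\n")
        f.write(f"element vertex {len(vertices)}\n")
        f.write("property float x\nproperty float y\nproperty float z\n")
        f.write("property float nx\nproperty float ny\nproperty float nz\n")
        f.write("property uchar red\nproperty uchar green\nproperty uchar blue\n")
        f.write("end_header\n")
        for v, n, c in zip(vertices, normals, colors):
            f.write(f"{v[0]} {v[1]} {v[2]} {n[0]} {n[1]} {n[2]} "
                    f"{c[0]} {c[1]} {c[2]}\n")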
Example #10
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split = None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(
            cfg.dataset.cachedir):
        train_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data"))
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        images, poses, render_poses, hwf = None, None, None, None
        if cfg.dataset.type.lower() == "blender":
            images, poses, render_poses, hwf, i_split = load_blender_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                images = images[..., :3] * images[..., -1:] + (
                    1. - images[..., -1:])
        elif cfg.dataset.type.lower() == "llff":
            images, poses, bds, render_poses, i_test = load_llff_data(
                cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
            hwf = poses[0, :3, -1]
            poses = poses[:, :3, :4]
            if not isinstance(i_test, list):
                i_test = [i_test]
            if cfg.dataset.llffhold > 0:
                i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
            i_val = i_test
            i_train = np.array([
                i for i in np.arange(images.shape[0])
                if (i not in i_test and i not in i_val)
            ])
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            images = torch.from_numpy(images)
            poses = torch.from_numpy(poses)

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)
    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())  # cfg, f, default_flow_style=False)

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint["model_fine_state_dict"]:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # # TODO: Prepare raybatch tensor if batching random rays

    for i in trange(start_iter, cfg.experiment.train_iters):

        model_coarse.train()
        if model_fine:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        if USE_CACHED_DATASET:
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)
            # ray_bundle = torch.stack([ray_origins, ray_directions], dim=0).to(device)

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(
                    torch.arange(H).to(device),
                    torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            select_inds = np.random.choice(
                coords.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False)
            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0],
                                            select_inds[:, 1], :]
            # batch_rays = torch.stack([ray_origins, ray_directions], dim=0)
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

            then = time.time()
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(
                rgb_fine[..., :3], target_ray_values[..., :3])
        # loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3])
        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if (i % cfg.experiment.validate_every == 0
                or i == cfg.experiment.train_iters - 1):
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = cache_dict["target"].to(device)
                else:
                    img_idx = np.random.choice(i_val)
                    img_target = images[img_idx].to(device)
                    pose_target = poses[img_idx, :3, :4].to(device)
                    ray_origins, ray_directions = get_ray_bundle(
                        H, W, focal, pose_target)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        H,
                        W,
                        focal,
                        model_coarse,
                        model_fine,
                        ray_origins,
                        ray_directions,
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = 0.0
                if rgb_fine is not None:
                    fine_loss = img2mse(rgb_fine[..., :3],
                                        target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(),
                                  i)
                writer.add_scalar("validataion/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                if rgb_fine is not None:
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_scalar("validation/fine_loss", fine_loss.item(),
                                      i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter":
                i,
                "model_coarse_state_dict":
                model_coarse.state_dict(),
                "model_fine_state_dict":
                None if not model_fine else model_fine.state_dict(),
                "optimizer_state_dict":
                optimizer.state_dict(),
                "loss":
                loss,
                "psnr":
                psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
Example #11
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir",
                        type=str,
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image",
                        action="store_true",
                        help="Save disparity images too.")
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        # Load blender dataset
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
    elif cfg.dataset.type.lower() == "llff":
        # Load LLFF dataset
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)

    # Device on which to run.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    checkpoint = torch.load(configargs.checkpoint)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]

    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    # Create directory to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"),
                    exist_ok=True)

    # Evaluation loop
    times_per_image = []
    for i, pose in enumerate(tqdm(render_poses)):
        start = time.time()
        rgb = None
        disp = None
        with torch.no_grad():
            pose = pose[:3, :4]
            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _ = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            if configargs.save_disparity_image:
                disp = disp_fine if disp_fine is not None else disp_coarse
        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3],
                                        cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp))
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
Example #12
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split, expressions = None, None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(
            cfg.dataset.cachedir):
        train_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data"))
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        images, poses, render_poses, hwf, expressions = None, None, None, None, None
        if cfg.dataset.type.lower() == "blender":
            images, poses, render_poses, hwf, i_split, expressions, _, bboxs = load_flame_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                images = images[..., :3] * images[..., -1:] + (
                    1.0 - images[..., -1:])
    print("done loading data")
    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"  #+ ":" + str(cfg.experiment.device)
    else:
        device = "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)
    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    ###################################
    ###################################
    train_background = False
    supervised_train_background = False
    blur_background = False

    train_latent_codes = True
    disable_expressions = False  # True to disable expressions
    disable_latent_codes = False  # True to disable latent codes
    fixed_background = True  # Set to False to disable the fixed background
    regularize_latent_codes = True  # True adds the latent-code regularization loss; False for most experiments
    ###################################
    ###################################

    supervised_train_background = train_background and supervised_train_background
    # Avg background
    #images[i_train]
    if train_background:
        with torch.no_grad():
            avg_img = torch.mean(images[i_train], axis=0)
            # Blur Background:
            if blur_background:
                avg_img = avg_img.permute(2, 0, 1)
                avg_img = avg_img.unsqueeze(0)
                smoother = GaussianSmoothing(channels=3,
                                             kernel_size=11,
                                             sigma=11)
                print("smoothed background initialization. shape ",
                      avg_img.shape)
                avg_img = smoother(avg_img).squeeze(0).permute(1, 2, 0)
            #avg_img = torch.zeros(H,W,3)
            #avg_img = torch.rand(H,W,3)
            #avg_img = 0.5*(torch.rand(H,W,3) + torch.mean(images[i_train],axis=0))
            background = avg_img.clone().detach().to(device)
        background.requires_grad = True

    if fixed_background:  # load GT background
        print("loading GT background to condition on")
        from PIL import Image
        background = Image.open(
            os.path.join(cfg.dataset.basedir, 'bg', '00050.png'))
        background.thumbnail((H, W))
        background = torch.from_numpy(np.array(background).astype(
            np.float32)).to(device)
        background = background / 255
        print("bg shape", background.shape)
        print("should be ", images[i_train][0].shape)
        assert background.shape == images[i_train][0].shape
    else:
        background = None

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    if train_background:
        #background.requires_grad = True
        #trainable_parameters.append(background) # add it later when init optimizer for different lr
        print("background.is_leaf ", background.is_leaf, background.device)

    if train_latent_codes:
        latent_codes = torch.zeros(len(i_train), 32, device=device)
        print("initialized latent codes with shape %d X %d" %
              (latent_codes.shape[0], latent_codes.shape[1]))
        if not disable_latent_codes:
            trainable_parameters.append(latent_codes)
            latent_codes.requires_grad = True

    if train_background:
        optimizer = getattr(torch.optim,
                            cfg.optimizer.type)([{
                                'params': trainable_parameters
                            }, {
                                'params': background,
                                'lr': cfg.optimizer.lr
                            }],
                                                lr=cfg.optimizer.lr)
    else:
        optimizer = getattr(torch.optim, cfg.optimizer.type)(
            [{
                'params': trainable_parameters
            }, {
                'params': background,
                'lr': cfg.optimizer.lr
            }],  # this is obsolete but need for continuing training
            lr=cfg.optimizer.lr)
    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())  # cfg, f, default_flow_style=False)

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint["model_fine_state_dict"]:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        if checkpoint["background"] is not None:
            print("loaded bg from checkpoint")
            background = torch.nn.Parameter(
                checkpoint['background'].to(device))
        if checkpoint["latent_codes"] is not None:
            print("loaded latent codes from checkpoint")
            latent_codes = torch.nn.Parameter(
                checkpoint['latent_codes'].to(device))

        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # # TODO: Prepare raybatch tensor if batching random rays

    # Prepare importance sampling maps
    ray_importance_sampling_maps = []
    p = 0.9
    print("computing boundix boxes probability maps")
    for i in i_train:
        bbox = bboxs[i]
        probs = np.zeros((H, W))
        probs.fill(1 - p)
        probs[bbox[0]:bbox[1], bbox[2]:bbox[3]] = p
        probs = (1 / probs.sum()) * probs
        ray_importance_sampling_maps.append(probs.reshape(-1))

    print("Starting loop")
    for i in trange(start_iter, cfg.experiment.train_iters):

        model_coarse.train()
        if model_fine:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        background_ray_values = None
        if USE_CACHED_DATASET:
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)
            #target_ray_values = target_ray_values[select_inds].to(device)
            # ray_bundle = torch.stack([ray_origins, ray_directions], dim=0).to(device)

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expressions)
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            if not disable_expressions:
                expression_target = expressions[img_idx].to(device)  # vector
            else:  # zero expr
                expression_target = torch.zeros(76, device=device)
            #bbox = bboxs[img_idx]
            if not disable_latent_codes:
                latent_code = latent_codes[img_idx].to(
                    device) if train_latent_codes else None
            else:
                latent_code = torch.zeros(32, device=device)
            #latent_code = torch.zeros(32).to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(
                    torch.arange(H).to(device),
                    torch.arange(W).to(device)),
                dim=-1,
            )

            # Only randomly choose rays that are in the bounding box !
            # coords = torch.stack(
            #     meshgrid_xy(torch.arange(bbox[0],bbox[1]).to(device), torch.arange(bbox[2],bbox[3]).to(device)),
            #     dim=-1,
            # )

            coords = coords.reshape((-1, 2))
            # select_inds = np.random.choice(
            #     coords.shape[0], size=(cfg.nerf.train.num_random_rays), replace=False
            # )

            # Use importance sampling to sample mainly in the bbox with prob p
            select_inds = np.random.choice(
                coords.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
                p=ray_importance_sampling_maps[img_idx])

            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0],
                                            select_inds[:, 1], :]
            #dump_rays(ray_origins, ray_directions)

            # batch_rays = torch.stack([ray_origins, ray_directions], dim=0)
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]
            background_ray_values = background[select_inds[:, 0],
                                               select_inds[:, 1], :] if (
                                                   train_background or
                                                   fixed_background) else None
            #if i<10000:
            #   background_ray_values = None
            #background_ray_values = None
            then = time.time()
            rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression_target,
                background_prior=background_ray_values,
                latent_code=latent_code if not disable_latent_codes else
                torch.zeros(32, device=device))
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(
                rgb_fine[..., :3], target_ray_values[..., :3])
        # loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3])
        loss = 0.0
        # if fine_loss is not None:
        #     loss = fine_loss
        # else:
        #     loss = coarse_loss

        latent_code_loss = torch.zeros(1, device=device)
        if train_latent_codes and not disable_latent_codes:
            latent_code_loss = torch.norm(latent_code) * 0.0005
            #latent_code_loss = torch.zeros(1)

        background_loss = torch.zeros(1, device=device)
        if supervised_train_background:
            background_loss = torch.nn.functional.mse_loss(
                background_ray_values[..., :3],
                target_ray_values[..., :3],
                reduction='none').sum(1)
            background_loss = torch.mean(background_loss * weights) * 0.001

        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        psnr = mse2psnr(loss.item())

        #loss_total = loss #+ (latent_code_loss if latent_code_loss is not None else 0.0)
        loss = loss + (latent_code_loss *
                       10 if regularize_latent_codes else 0.0)
        loss_total = loss + (background_loss
                             if supervised_train_background else 0.0)
        #loss.backward()
        loss_total.backward()
        #psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " BG Loss: " +
                       str(background_loss.item()) + " PSNR: " + str(psnr) +
                       " LatentReg: " + str(latent_code_loss.item()))
        #writer.add_scalar("train/loss", loss.item(), i)
        if train_latent_codes:
            writer.add_scalar("train/code_loss", latent_code_loss.item(), i)
        if supervised_train_background:
            writer.add_scalar("train/bg_loss", background_loss.item(), i)

        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if (i % cfg.experiment.validate_every == 0
                or (i == cfg.experiment.train_iters - 1 and False)):  # `and False` disables last-iteration validation
            #torch.cuda.empty_cache()
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, weights = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                        expressions=expression_target,
                        latent_code=torch.zeros(32, device=device))
                    target_ray_values = cache_dict["target"].to(device)
                else:
                    # Do all validation set...
                    loss = 0
                    for img_idx in i_val[:2]:
                        img_target = images[img_idx].to(device)
                        #tqdm.set_description('val im %d' % img_idx)
                        #tqdm.refresh()  # to show immediately the update

                        # # save val image for debug ### DEBUG ####
                        # #GT = target_ray_values[..., :3]
                        # import PIL.Image
                        # #img = GT.permute(2, 0, 1)
                        # # Conver to PIL Image and then np.array (output shape: (H, W, 3))
                        # #im_numpy = img_target.detach().cpu().numpy()
                        # #im_numpy = np.array(torchvision.transforms.ToPILImage()(img_target.detach().cpu()))
                        #
                        # #                   im = PIL.Image.fromarray(im_numpy)
                        # im = img_target
                        # im = im.permute(2, 0, 1)
                        # img = np.array(torchvision.transforms.ToPILImage()(im.detach().cpu()))
                        # im = PIL.Image.fromarray(img)
                        # im.save('val_im_target_debug.png')
                        # ### DEBUG #### END

                        pose_target = poses[img_idx, :3, :4].to(device)
                        ray_origins, ray_directions = get_ray_bundle(
                            H, W, focal, pose_target)
                        rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf(
                            H,
                            W,
                            focal,
                            model_coarse,
                            model_fine,
                            ray_origins,
                            ray_directions,
                            cfg,
                            mode="validation",
                            encode_position_fn=encode_position_fn,
                            encode_direction_fn=encode_direction_fn,
                            expressions=expression_target,
                            background_prior=background.view(-1, 3) if
                            (train_background or fixed_background) else None,
                            latent_code=torch.zeros(32).to(device)
                            if train_latent_codes or disable_latent_codes else
                            None,
                        )
                        #print("did one val")
                        target_ray_values = img_target
                        coarse_loss = img2mse(rgb_coarse[..., :3],
                                              target_ray_values[..., :3])
                        curr_fine_loss = 0.0
                        if rgb_fine is not None:
                            curr_fine_loss = img2mse(
                                rgb_fine[..., :3], target_ray_values[..., :3])
                        # Total validation loss for this image: coarse + fine.
                        loss += coarse_loss + curr_fine_loss

                # Average over the images actually validated.
                loss /= len(i_val[:2])
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(),
                                  i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                if rgb_fine is not None:
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_scalar("validation/fine_loss", fine_loss.item(),
                                      i)

                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                if train_background or fixed_background:
                    writer.add_image("validation/background",
                                     cast_to_image(background[..., :3]), i)
                    writer.add_image("validation/weights",
                                     (weights.detach().cpu().numpy()),
                                     i,
                                     dataformats='HW')
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        #gpu_profile(frame=sys._getframe(), event='line', arg=None)

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict":
                    None if not model_fine else model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
                "background":
                    None if not (train_background or fixed_background)
                    else background.data,
                "latent_codes":
                    None if not train_latent_codes else latent_codes.data,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir",
                        type=str,
                        default='./renders/',
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image",
                        action="store_true",
                        help="Save disparity images too.")
    parser.add_argument("--save-error-image",
                        action="store_true",
                        help="Save photometric error visualization")
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        # Load blender dataset
        images, poses, render_poses, hwf, i_split, expressions, _, _ = load_flame_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
            test=True)
        #i_train, i_val, i_test = i_split
        i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
    elif cfg.dataset.type.lower() == "llff":
        # Load LLFF dataset
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        render_poses = torch.from_numpy(render_poses)

    # Device on which to run.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    checkpoint = torch.load(configargs.checkpoint)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]
    if "background" in checkpoint.keys():
        background = checkpoint["background"]
        if background is not None:
            print("loaded background with shape ", background.shape)
            background = background.to(device)
    if "latent_codes" in checkpoint.keys():
        latent_codes = checkpoint["latent_codes"]
        use_latent_code = False
        if latent_codes is not None:
            use_latent_code = True
            latent_codes = latent_codes.to(device)
            print("loading index map for latent codes...")
            idx_map = np.load(cfg.dataset.basedir +
                              "/index_map.npy").astype(int)
            print("loaded latent codes from checkpoint, with shape ",
                  latent_codes.shape)
    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    replace_background = True
    if replace_background:
        from PIL import Image
        #background = Image.open('./view.png')
        background = Image.open(cfg.dataset.basedir + '/bg/00050.png')
        #background = Image.open("./real_data/andrei_dvp/" + '/bg/00050.png')
        background.thumbnail((W, H))  # PIL thumbnail expects (width, height)
        background = torch.from_numpy(
            np.array(background).astype(float)).to(device)
        background = background / 255
        print('loaded custom background of shape', background.shape)

        #background = torch.ones_like(background)
        #background.permute(2,0,1)

    render_poses = render_poses.float().to(device)

    # Create directory to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"),
                    exist_ok=True)
    if configargs.save_error_image:
        os.makedirs(os.path.join(configargs.savedir, "error"), exist_ok=True)
    os.makedirs(os.path.join(configargs.savedir, "normals"), exist_ok=True)
    # Evaluation loop
    times_per_image = []

    #render_poses = render_poses.float().to(device)
    render_poses = poses[i_test].float().to(device)
    #expressions = torch.arange(-6,6,0.5).float().to(device)
    render_expressions = expressions[i_test].float().to(device)
    #avg_img = torch.mean(images[i_train],axis=0)
    #avg_img = torch.ones_like(avg_img)

    #pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
    #for i, pose in enumerate(tqdm(render_poses)):
    index_of_image_after_train_shuffle = 0
    # render_expressions = render_expressions[[300]] ### TODO render specific expression

    #######################
    no_background = False
    no_expressions = False
    no_lcode = False
    nerf = False
    frontalize = False
    interpolate_mouth = False

    #######################
    if nerf:
        no_background = True
        no_expressions = True
        no_lcode = True
    if no_background: background = None
    if no_expressions:
        render_expressions = torch.zeros_like(render_expressions,
                                              device=render_expressions.device)
    if no_lcode:
        use_latent_code = True
        latent_codes = torch.zeros(5000, 32, device=device)

    for i, expression in enumerate(tqdm(render_expressions)):
        #for i in range(75,151):

        #if i%25 != 0: ### TODO generate only every 25th im
        #if i != 511: ### TODO generate only every 25th im
        #    continue
        start = time.time()
        rgb = None
        disp = None
        with torch.no_grad():
            pose = render_poses[i]

            if interpolate_mouth:
                frame_id = 241
                num_images = 150
                pose = render_poses[frame_id]
                expression = render_expressions[frame_id].clone()
                # Sweep expression component 68 (mouth) from -1 to 1 across the frames.
                expression[68] = torch.arange(-1, 1, 2 / num_images,
                                              device=device)[i]

            if frontalize:
                pose = render_poses[0]
            #pose = render_poses[300] ### TODO fixes pose
            #expression = render_expressions[0] ### TODO fixes expr
            #expression = torch.zeros_like(expression).to(device)

            ablate = 'view_dir'
            ray_directions_ablation = None  # set below only for the view-direction ablation

            if ablate == 'expression':
                pose = render_poses[100]
            elif ablate == 'latent_code':
                pose = render_poses[100]
                expression = render_expressions[100]
                if idx_map[100 + i, 1] >= 0:
                    #print("found latent code for this image")
                    index_of_image_after_train_shuffle = idx_map[100 + i, 1]
            elif ablate == 'view_dir':
                pose = render_poses[100]
                expression = render_expressions[100]
                _, ray_directions_ablation = get_ray_bundle(
                    hwf[0], hwf[1], hwf[2], render_poses[240 + i][:3, :4])

            pose = pose[:3, :4]

            #pose = torch.from_numpy(np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]))
            if use_latent_code:
                if idx_map[i, 1] >= 0:
                    #print("found latent code for this image")
                    index_of_image_after_train_shuffle = idx_map[i, 1]
            #index_of_image_after_train_shuffle = 10 ## TODO Fixes latent code
            #index_of_image_after_train_shuffle = idx_map[84,1] ## TODO Fixes latent code v2 for andrei
            index_of_image_after_train_shuffle = idx_map[
                10, 1]  ## TODO Fixes latent code - USE THIS if not ablating!

            latent_code = latent_codes[index_of_image_after_train_shuffle].to(
                device) if use_latent_code else None

            #latent_code = torch.mean(latent_codes)
            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _, weights = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression,
                background_prior=background.view(-1, 3) if
                (background is not None) else None,
                #background_prior = torch.ones_like(background).view(-1,3),  # White background
                latent_code=latent_code,
                ray_directions_ablation=ray_directions_ablation)
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            normals = torch_normal_map(disp_fine, focal, weights, clean=True)
            #normals = normal_map_from_depth_map_backproject(disp_fine.cpu().numpy())
            save_plt_image(
                normals.cpu().numpy().astype('uint8'),
                os.path.join(configargs.savedir, 'normals', f"{i:04d}.png"))
            #if configargs.save_disparity_image:
            if False:
                disp = disp_fine if disp_fine is not None else disp_coarse
                #normals = normal_map_from_depth_map_backproject(disp.cpu().numpy())
                normals = normal_map_from_depth_map_backproject(
                    disp_fine.cpu().numpy())
                save_plt_image(
                    normals.astype('uint8'),
                    os.path.join(configargs.savedir, 'normals',
                                 f"{i:04d}.png"))

            #if configargs.save_normal_image:
            #    normal_map_from_depth_map_backproject(disp_fine.cpu().numpy())
        #rgb[torch.where(weights>0.25)]=1.0
        #rgb[torch.where(weights>0.1)] = (rgb * weights + (torch.ones_like(weights)-weights)*torch.ones_like(weights))
        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3],
                                        cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp_fine))
            if configargs.save_error_image:
                savefile = os.path.join(configargs.savedir, "error",
                                        f"{i:04d}.png")
                GT = images[i_test][i]
                fig = error_image(GT, rgb.cpu().numpy())
                #imageio.imwrite(savefile, cast_to_disparity_image(disp))
                plt.savefig(savefile,
                            pad_inches=0,
                            bbox_inches='tight',
                            dpi=54)
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
Example #14
0
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # Load dataset
    if cfg.dataset.type.lower() == "llff":
        images, poses, _, _, i_test = load_llff_data(
            cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        if not isinstance(i_test, list):
            i_test = [i_test]
        if cfg.dataset.llffhold > 0:
            i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
        i_val = i_test
        i_train = np.array([
            i for i in np.arange(images.shape[0])
            if (i not in i_test and i not in i_val)
        ])
        H, W, focal = hwf
        H, W = int(H), int(W)
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)
        USE_HR_LR = False

        # Load LR images
        if hasattr(cfg.dataset, "relative_lr_factor"):
            assert hasattr(cfg.dataset, "hr_fps")
            assert hasattr(cfg.dataset, "hr_frequency")
            USE_HR_LR = True
            images_lr, poses_lr, _, _, i_test = load_llff_data(
                cfg.dataset.basedir,
                factor=cfg.dataset.downsample_factor *
                cfg.dataset.relative_lr_factor)
            hwf = poses_lr[0, :3, -1]
            poses_lr = poses_lr[:, :3, :4]
            if not isinstance(i_test, list):
                i_test = [i_test]
            if cfg.dataset.llffhold > 0:
                i_test = np.arange(images_lr.shape[0])[::cfg.dataset.llffhold]
            i_train_lr = np.array([
                i for i in np.arange(images_lr.shape[0]) if (i not in i_test)
            ])
            H_lr, W_lr, focal_lr = hwf
            H_lr, W_lr = int(H_lr), int(W_lr)
            images_lr = torch.from_numpy(images_lr)
            poses_lr = torch.from_numpy(poses_lr)

            # Expose only some HR images
            i_train = i_train[::cfg.dataset.hr_fps]

            print(f'LR summary: N={i_train_lr.shape[0]}, '
                  f'resolution={H_lr}x{W_lr}, focal-length={focal_lr}')
            print(f'HR summary: N={i_train.shape[0]}, '
                  f'resolution={H}x{W}, focal-length={focal}')

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = f"cuda:{cfg.experiment.gpu_id}"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse and fine resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    model_fine = getattr(models, cfg.models.fine.type)(
        num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
        include_input_xyz=cfg.models.fine.include_input_xyz,
        include_input_dir=cfg.models.fine.include_input_dir,
        use_viewdirs=cfg.models.fine.use_viewdirs,
    )
    model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters()) + list(
        model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # # TODO: Prepare raybatch tensor if batching random rays

    for i in trange(start_iter, cfg.experiment.train_iters):

        model_coarse.train()
        model_fine.train()

        if USE_HR_LR and i % cfg.dataset.hr_frequency != 0:
            H_iter, W_iter, focal_iter = H_lr, W_lr, focal_lr
            i_train_iter = i_train_lr
            images_iter, poses_iter = images_lr, poses_lr
        else:
            H_iter, W_iter, focal_iter = H, W, focal
            i_train_iter = i_train
            images_iter, poses_iter = images, poses

        img_idx = np.random.choice(i_train_iter)
        img_target = images_iter[img_idx].to(device)
        pose_target = poses_iter[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H_iter, W_iter,
                                                     focal_iter, pose_target)
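        # Randomly sample num_random_rays pixel coordinates from the full
        # H x W grid and keep only those rays and target colours for this step.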
        coords = torch.stack(
            meshgrid_xy(
                torch.arange(H_iter).to(device),
                torch.arange(W_iter).to(device)),
            dim=-1,
        )
        coords = coords.reshape((-1, 2))
        select_inds = np.random.choice(coords.shape[0],
                                       size=(cfg.nerf.train.num_random_rays),
                                       replace=False)
        select_inds = coords[select_inds]
        ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
        ray_directions = ray_directions[select_inds[:, 0], select_inds[:, 1], :]
        target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

        rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
            H_iter,
            W_iter,
            focal_iter,
            model_coarse,
            model_fine,
            ray_origins,
            ray_directions,
            cfg,
            mode="train",
            encode_position_fn=encode_position_fn,
            encode_direction_fn=encode_direction_fn,
        )
        target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = torch.nn.functional.mse_loss(rgb_fine[..., :3],
                                                 target_ray_values[..., :3])
        loss = coarse_loss + fine_loss
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates
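        # Exponential decay: lr_new = lr0 * decay_factor ** (i / (lr_decay * 1000)).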
        num_decay_steps = cfg.scheduler.lr_decay * 1000
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if (i % cfg.experiment.validate_every == 0
                or i == cfg.experiment.train_iters - 1):
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            model_fine.eval()

            start = time.time()
            with torch.no_grad():
                img_idx = np.random.choice(i_val)
                img_target = images[img_idx].to(device)
                pose_target = poses[img_idx, :3, :4].to(device)
                ray_origins, ray_directions = get_ray_bundle(
                    H, W, focal, pose_target)
                rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                    H,
                    W,
                    focal,
                    model_coarse,
                    model_fine,
                    ray_origins,
                    ray_directions,
                    cfg,
                    mode="validation",
                    encode_position_fn=encode_position_fn,
                    encode_direction_fn=encode_direction_fn,
                )
                target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(),
                                  i)
                writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                writer.add_scalar("validataion/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                writer.add_image("validation/rgb_fine",
                                 cast_to_image(rgb_fine[..., :3]), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict": model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
Example #15
0
def main():

    # Config options:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Checkpoint / pre-trained model to evaluate.")
    parser.add_argument("--dual-render", action='store_true', default=False)
    parser.add_argument('--gpu-id',
                        type=int,
                        default=0,
                        help="id of the CUDA GPU to use (default: 0)")
    configargs = parser.parse_args()

    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Dataset:
    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        if cfg.nerf.train.white_background:
            images = images[..., :3] * images[..., -1:] + (1.0 -
                                                           images[..., -1:])
    elif cfg.dataset.type.lower() == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)

    # Hardware
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(configargs.gpu_id)
        device = "cuda"

    # Model
    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir)

    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs)
    model_coarse.to(device)

    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs)
        model_fine.to(device)

    # Load checkpoint
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except Exception:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]

    # Prepare model and data
    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    print("Dual render?", configargs.dual_render)

    # Evaluation loop
    with torch.no_grad():
        fine_psnrs = []
        if not isinstance(i_test, list):
            i_test = [i_test]
        for i in i_test:
            print(
                f"Test sample {i + 1 - i_test[0]}/{i_test[-1] - i_test[0]}...")

            img_target = images[i].to(device)
            pose_target = poses[i, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                dual_render=configargs.dual_render)
            target_ray_values = img_target
            coarse_loss = img2mse(rgb_coarse[..., :3],
                                  target_ray_values[..., :3])
            fine_loss = 0.0
            if rgb_fine is not None:
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
            loss = coarse_loss + fine_loss
            psnr = mse2psnr(loss.item())
            psnr_coarse = mse2psnr(coarse_loss)
            psnr_fine = mse2psnr(fine_loss)
            print(
                f"\t Loss at sample: {psnr} (f:{psnr_fine}, c:{psnr_coarse})")

            fine_psnrs.append(psnr_fine)

        print(f"Validation PSNR: {sum(fine_psnrs) / len(fine_psnrs)}.")