示例#1
0
    def parse(self, config_path = None, log_path = None, run_name = LOG_RUN_NAME, checkpoint_name = CHECKPOINT_NAME_LAST, create_logger = False):
        """Load the experiment config and optionally set up a TensorBoard logger.

        Exactly one of ``config_path`` / ``log_path`` must be supplied.

        * ``log_path`` (resume mode): the last three path components are taken as
          ``exp_name / log_name / log_version``. The config is read from the
          hparams file that ``TensorBoardLogger`` stores in that directory, and
          ``self.checkpoint_path`` is pointed at ``checkpoint_name`` inside
          ``<log_dir>/checkpoints``.
        * ``config_path`` (fresh run): the config is read from that file;
          ``exp_name`` comes from ``cfg.experiment.id`` and ``log_name`` from
          ``run_name``.

        When ``create_logger`` is true, a ``TensorBoardLogger`` rooted at
        ``<cfg.experiment.logdir>/<exp_name>`` is created, and ``self.log_dir``
        is replaced by the logger's own directory.

        Returns:
            tuple: ``(cfg, logger)``, where ``cfg`` is a ``CfgNode`` built from
            the nested config dict and ``logger`` is the ``TensorBoardLogger``
            or ``None``.
        """
        resuming = log_path is not None
        assert resuming != (config_path is not None), \
            "Either config or log with checkpoints must be provided, append option --help for more information."

        if not resuming:
            self.config_path = config_path
        else:
            # Layout: .../<exp_name>/<log_name>/<log_version>
            parts = os.path.normpath(log_path).split(os.path.sep)
            self.exp_name, self.log_name, self.log_version = parts[-3:]
            self.log_dir = Path(log_path)
            # The hparams file written by TensorBoardLogger doubles as the config.
            hparams_file = self.log_dir / TensorBoardLogger.NAME_HPARAMS_FILE
            self.config_path = str(hparams_file)

        # Config keys are dotted ("a.b.c"); nest them before wrapping in CfgNode.
        with open(self.config_path, "r") as cfg_file:
            flat_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
            cfg = CfgNode(nest_dict(flat_cfg, sep="."))

        self.root_path = Path(cfg.experiment.logdir)
        if not resuming:
            self.exp_name = cfg.experiment.id
            self.log_name = run_name

        self.log_root_dir = str(self.root_path / self.exp_name)

        logger = None
        if create_logger:
            run_dir = Path(self.log_root_dir) / self.log_name
            os.makedirs(run_dir, exist_ok=True)
            # NOTE(review): on the config_path path self.log_version is not set
            # here; presumably the class provides a default — confirm.
            logger = TensorBoardLogger(self.log_root_dir, self.log_name, version=self.log_version)
            print("Logger initiated...")
            # The logger may choose a different version dir; trust its choice.
            self.log_dir = Path(logger.log_dir)

        # NOTE(review): without log_path and without create_logger, self.log_dir
        # must already exist on the instance for the lines below to work.
        print(f"Current log dir {self.log_dir}")
        self.checkpoint_dir = self.log_dir / "checkpoints"
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        if resuming:
            self.checkpoint_path = str(self.checkpoint_dir / checkpoint_name)

        return cfg, logger
示例#2
0
    def __init__(self, cfg, *args, **kwargs):
        """Build the module state from a flat, dot-separated config dict.

        Args:
            cfg: config mapping whose keys use "." as a nesting separator. It is
                stored both nested (``self.cfg``, a ``CfgNode``) and flattened
                (``self.hparams``).
            *args, **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.cfg = CfgNode(nest_dict(cfg, sep="."))
        self.hparams = flatten_dict(cfg, sep=".")

        # Loss used for optimisation, plus the MSE -> PSNR conversion helper.
        self.loss = torch.nn.MSELoss()
        self.criterion_psnr = mse2psnr

        # The renderer receives separate noise std values for train and
        # validation, and the dataset's white-background flag.
        nerf_cfg = self.cfg.nerf
        train_noise = nerf_cfg.train.radiance_field_noise_std
        val_noise = nerf_cfg.validation.radiance_field_noise_std
        self.volume_renderer = VolumeRenderer(
            train_noise,
            val_noise,
            self.cfg.dataset.white_background,
            attenuation_threshold=1e-5,
        )

        # Datasets are attached later; start with neither set.
        self.train_dataset = None
        self.val_dataset = None
示例#3
0
def main():
    """Load a trained NeRF checkpoint and pass the models to ``export_ray_trace``.

    Command-line arguments:
        --config: path to the (.yml) config used to build the models (required).
        --base-dir: optional override of the default base dir. It is not read
            here; presumably ``export_ray_trace`` consumes it via ``config_args``.
        --checkpoint: checkpoint / pre-trained model to evaluate (required).
        --save-dir: directory to save the mesh to, if specified.
        --iso-level: iso-level to query (default 32).
        --cache-mesh / --no-cache-mesh: sets ``config_args.cache_mesh``
            (default True).

    The coarse model, plus the fine model when ``cfg.models`` defines one, is
    built from the config and loaded from the checkpoint onto the active
    device. Both are switched to eval mode before export.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--base-dir",
        type=str,
        required=False,
        help="Override the default base dir.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--save-dir",
                        type=str,
                        help="Save mesh to this directory, if specified.")

    parser.add_argument("--iso-level",
                        type=float,
                        help="Iso-Level to be queried",
                        default=32)

    parser.add_argument('--cache-mesh', dest='cache_mesh', action='store_true')
    parser.add_argument('--no-cache-mesh',
                        dest='cache_mesh',
                        action='store_false')
    parser.set_defaults(cache_mesh=True)

    config_args = parser.parse_args()

    # Read config file.
    with open(config_args.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(config_args.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    fine_state_dict = checkpoint.get("model_fine_state_dict")
    if fine_state_dict:
        if model_fine is None:
            print("The checkpoint has a fine-level model, but the config "
                  "does not define one; ignoring it.")
        else:
            try:
                model_fine.load_state_dict(fine_state_dict)
            except RuntimeError:
                # load_state_dict raises RuntimeError on missing/unexpected
                # keys or mismatched tensor shapes.
                print("The checkpoint has a fine-level model, but it could "
                      "not be loaded (possibly due to a mismatched config file).")

    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    export_ray_trace(model_coarse, model_fine, config_args, cfg,
                     encode_position_fn, encode_direction_fn, device)
示例#4
0
def main():
    """Train a coarse (and optional fine) NeRF model from a YAML config.

    Command-line arguments:
        --config: path to the (.yml) experiment config (required).
        --load-checkpoint: checkpoint to resume from. Training starts at
            iteration 0 when the path does not exist.

    Each iteration samples ``cfg.nerf.train.num_random_rays`` rays. They come
    from a pre-cached ray dataset when ``cfg.dataset.cachedir`` exists, and
    otherwise from a freshly loaded Blender/LLFF dataset. The loop:

    * optimises the summed coarse+fine MSE;
    * logs to TensorBoard under ``<cfg.experiment.logdir>/<cfg.experiment.id>``;
    * validates every ``cfg.experiment.validate_every`` iterations;
    * writes a checkpoint every ``cfg.experiment.save_every`` iterations.

    Validation and checkpointing also run on the final iteration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split = None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(
            cfg.dataset.cachedir):
        train_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data"))
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        dataset_type = cfg.dataset.type.lower()
        if dataset_type == "blender":
            images, poses, render_poses, hwf, i_split = load_blender_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                # Composite RGBA onto a white background using the alpha channel.
                images = images[..., :3] * images[..., -1:] + (
                    1. - images[..., -1:])
        elif dataset_type == "llff":
            images, poses, bds, render_poses, i_test = load_llff_data(
                cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
            hwf = poses[0, :3, -1]
            poses = poses[:, :3, :4]
            if not isinstance(i_test, list):
                i_test = [i_test]
            if cfg.dataset.llffhold > 0:
                i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
            i_val = i_test
            i_train = np.array([
                i for i in np.arange(images.shape[0])
                if (i not in i_test and i not in i_val)
            ])
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            images = torch.from_numpy(images)
            poses = torch.from_numpy(poses)
        else:
            # Fail early instead of crashing inside the training loop on None data.
            raise ValueError(f"Unsupported dataset type: {cfg.dataset.type}")

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)
    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        # map_location allows resuming a GPU checkpoint on a CPU-only host.
        checkpoint = torch.load(configargs.load_checkpoint, map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        fine_state_dict = checkpoint.get("model_fine_state_dict")
        if fine_state_dict:
            if model_fine is None:
                raise ValueError(
                    "The checkpoint has a fine-level model, but the config "
                    "does not define cfg.models.fine.")
            model_fine.load_state_dict(fine_state_dict)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # # TODO: Prepare raybatch tensor if batching random rays

    last_iter = cfg.experiment.train_iters - 1
    # Exponential decay: the lr is scaled by lr_decay_factor every
    # num_decay_steps iterations.
    num_decay_steps = cfg.scheduler.lr_decay * 1000

    for i in trange(start_iter, cfg.experiment.train_iters):

        # Validation switches both models to eval mode; switch them back.
        model_coarse.train()
        if model_fine is not None:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        if USE_CACHED_DATASET:
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(
                    torch.arange(H).to(device),
                    torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))
            select_inds = np.random.choice(
                coords.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False)
            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0],
                                            select_inds[:, 1], :]
            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(
                rgb_fine[..., :3], target_ray_values[..., :3])
        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        loss.backward()
        # PSNR is reported from the summed coarse+fine loss.
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == last_iter:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation
        if i % cfg.experiment.validate_every == 0 or i == last_iter:
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine is not None:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = cache_dict["target"].to(device)
                else:
                    img_idx = np.random.choice(i_val)
                    img_target = images[img_idx].to(device)
                    pose_target = poses[img_idx, :3, :4].to(device)
                    ray_origins, ray_directions = get_ray_bundle(
                        H, W, focal, pose_target)
                    rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                        H,
                        W,
                        focal,
                        model_coarse,
                        model_fine,
                        ray_origins,
                        ray_directions,
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                    )
                    target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = 0.0
                if rgb_fine is not None:
                    fine_loss = img2mse(rgb_fine[..., :3],
                                        target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(),
                                  i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                if rgb_fine is not None:
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_scalar("validation/fine_loss", fine_loss.item(),
                                      i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == last_iter:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict":
                None if model_fine is None else model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    # Flush pending TensorBoard events to disk.
    writer.close()
    print("Done!")
示例#5
0
def main():
    """Render images from a trained NeRF checkpoint along the dataset's render poses.

    Command-line arguments:
        --config: path to the (.yml) config (required).
        --checkpoint: checkpoint / pre-trained model to evaluate (required).
        --savedir: directory to write ``NNNN.png`` renders to. When omitted,
            images are rendered and timed but not written.
        --save-disparity-image: also write disparity maps to
            ``<savedir>/disparity``.

    Height, width and focal length come from the dataset. They are overridden
    by ``height`` / ``width`` / ``focal_length`` entries in the checkpoint
    when those are present. The fine model's output is preferred over the
    coarse one whenever the fine output is not None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir",
                        type=str,
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image",
                        action="store_true",
                        help="Save disparity images too.")
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    dataset_type = cfg.dataset.type.lower()
    if dataset_type == "blender":
        # Load blender dataset; only the render poses and intrinsics are needed.
        _, _, render_poses, hwf, _ = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
    elif dataset_type == "llff":
        # Load LLFF dataset
        _, poses, _, render_poses, _ = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        H, W, focal = poses[0, :3, -1]
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)
    else:
        # Fail early rather than crashing later on render_poses/hwf being None.
        raise ValueError(f"Unsupported dataset type: {cfg.dataset.type}")

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # map_location lets a GPU checkpoint be evaluated on a CPU-only machine.
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    fine_state_dict = checkpoint.get("model_fine_state_dict")
    if fine_state_dict:
        if model_fine is None:
            print("The checkpoint has a fine-level model, but the config "
                  "does not define one; ignoring it.")
        else:
            try:
                model_fine.load_state_dict(fine_state_dict)
            except RuntimeError:
                # load_state_dict raises RuntimeError on key/shape mismatches.
                print("The checkpoint has a fine-level model, but it could "
                      "not be loaded (possibly due to a mismatched config file).")
    if "height" in checkpoint:
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint:
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint:
        hwf[2] = checkpoint["focal_length"]

    model_coarse.eval()
    if model_fine is not None:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    # Create directory to save images to (--savedir is optional).
    if configargs.savedir:
        os.makedirs(configargs.savedir, exist_ok=True)
        if configargs.save_disparity_image:
            os.makedirs(os.path.join(configargs.savedir, "disparity"),
                        exist_ok=True)

    # Evaluation loop
    times_per_image = []
    for i, pose in enumerate(tqdm(render_poses)):
        start = time.time()
        disp = None
        with torch.no_grad():
            pose = pose[:3, :4]
            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _ = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
            )
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            if configargs.save_disparity_image:
                disp = disp_fine if disp_fine is not None else disp_coarse
        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3], dataset_type))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp))
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
示例#6
0
def main():
    """Export the coarse (and optional fine) NeRF model from a checkpoint to ONNX.

    Command-line arguments:
        --config: path to the (.yml) config used to build the models (required).
        --load-checkpoint: checkpoint holding the trained model weights.

    ``coarse_model.onnx`` and, when a fine model is configured,
    ``fine_model.onnx`` are written to the directory containing the config
    file. The config is also dumped to
    ``<cfg.experiment.logdir>/<cfg.experiment.id>/config.yml``. If the
    checkpoint is not an existing file, a message is printed and nothing is
    exported.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to (.yml) config file."
    )
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run; used consistently for models, weights and the
    # dummy export input (no hard-coded CUDA device).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)
    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
        )
        model_fine.to(device)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())

    # isfile rather than exists: the default "" resolves (via abspath) to the
    # current working directory, which exists but is not a checkpoint.
    checkpoint_path = os.path.abspath(configargs.load_checkpoint)
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        fine_state_dict = checkpoint.get("model_fine_state_dict")
        if fine_state_dict:
            if model_fine is None:
                print("The checkpoint has a fine-level model, but the config "
                      "does not define one; exporting the coarse model only.")
            else:
                model_fine.load_state_dict(fine_state_dict)

        # Export inference behaviour, not training behaviour.
        model_coarse.eval()
        if model_fine is not None:
            model_fine.eval()

        # Dummy input width follows cfg.models.coarse's encoding settings. The
        # leading 3 presumably is the raw coordinate kept when include_input_*
        # is enabled — TODO confirm for configs with include_input disabled.
        dim_xyz = 3 + 2 * 3 * cfg.models.coarse.num_encoding_fn_xyz
        dim_dir = 3 + 2 * 3 * cfg.models.coarse.num_encoding_fn_dir
        dummy_input = torch.zeros((1, dim_dir + dim_xyz), dtype=torch.float).to(device)

        out_folder, _ = ntpath.split(configargs.config)
        torch.onnx.export(model_coarse, dummy_input, os.path.join(out_folder, "coarse_model.onnx"),
                          verbose=False, input_names=["input"])
        if model_fine is not None:
            torch.onnx.export(model_fine, dummy_input, os.path.join(out_folder, "fine_model.onnx"),
                              input_names=["input"])

    else:
        print("Couldn't find the checkpoint file at {}".format(checkpoint_path))

    print("Done!")
示例#7
0
def _load_background_image(path, H, W, device):
    """Load a background image as a float32 tensor scaled to [0, 1].

    The image is shrunk in place with ``PIL.Image.thumbnail`` so that it fits
    inside a ``W`` x ``H`` box (PIL sizes are ``(width, height)``). Thumbnail
    never enlarges and keeps the aspect ratio, so callers should check the
    resulting shape.

    Args:
        path: Image file path.
        H: Target height in pixels.
        W: Target width in pixels.
        device: Torch device for the returned tensor.

    Returns:
        ``np.array(image)`` converted to float32, moved to ``device`` and
        divided by 255.
    """
    from PIL import Image

    # Context manager closes the underlying file handle once pixels are read.
    with Image.open(path) as image:
        image.thumbnail((W, H))
        pixels = np.array(image).astype(np.float32)
    return torch.from_numpy(pixels).to(device) / 255


def _bbox_importance_sampling_maps(bboxs, indices, H, W, p=0.9):
    """Build flattened per-image pixel sampling distributions favouring a bbox.

    For every index in ``indices`` an ``(H, W)`` map is filled with ``1 - p``;
    the region ``[bbox[0]:bbox[1], bbox[2]:bbox[3]]`` of ``bboxs[idx]`` is set
    to ``p``; the map is then normalized to sum to 1 and flattened (row-major)
    so it can be passed as ``p=`` to ``np.random.choice``.

    Returns:
        List of 1-D float arrays of length ``H * W``, in the order of ``indices``.
    """
    maps = []
    for idx in indices:
        bbox = bboxs[idx]
        probs = np.full((H, W), 1 - p)
        probs[bbox[0]:bbox[1], bbox[2]:bbox[3]] = p
        maps.append(((1 / probs.sum()) * probs).reshape(-1))
    return maps


def main():
    """Train a coarse (+ optional fine) expression-conditioned NeRF.

    Command-line arguments:
        --config: Path to the YAML experiment config.
        --load-checkpoint: Optional checkpoint to resume from (model weights,
            optimizer state, background, latent codes and iteration counter).

    Writes TensorBoard summaries, a copy of the config (``config.yml``) and
    periodic ``checkpointNNNNN.ckpt`` files to
    ``<cfg.experiment.logdir>/<cfg.experiment.id>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # If a pre-cached dataset is available, skip the dataloader.
    # NOTE(review): the cached path looks unmaintained -- i_train, images and
    # bboxs stay None there, yet later code indexes them unconditionally.
    USE_CACHED_DATASET = False
    train_paths, validation_paths = None, None
    images, poses, render_poses, hwf, i_split, expressions = None, None, None, None, None, None
    H, W, focal, i_train, i_val, i_test = None, None, None, None, None, None
    bboxs = None
    if hasattr(cfg.dataset, "cachedir") and os.path.exists(
            cfg.dataset.cachedir):
        train_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "train", "*.data"))
        validation_paths = glob.glob(
            os.path.join(cfg.dataset.cachedir, "val", "*.data"))
        USE_CACHED_DATASET = True
    else:
        # Load dataset
        if cfg.dataset.type.lower() == "blender":
            images, poses, render_poses, hwf, i_split, expressions, _, bboxs = load_flame_data(
                cfg.dataset.basedir,
                half_res=cfg.dataset.half_res,
                testskip=cfg.dataset.testskip,
            )
            i_train, i_val, i_test = i_split
            H, W, focal = hwf
            H, W = int(H), int(W)
            hwf = [H, W, focal]
            if cfg.nerf.train.white_background:
                images = images[..., :3] * images[..., -1:] + (
                    1.0 - images[..., -1:])
    print("done loading data")
    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse-resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)
    # If a fine-resolution model is specified, initialize it.
    # (num_layers / hidden_size are taken from the coarse config, as before.)
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    ###################################
    ###################################
    train_background = False
    supervised_train_background = False
    blur_background = False

    train_latent_codes = True
    disable_expressions = False  # True to disable expressions
    disable_latent_codes = False  # True to disable latent codes
    fixed_background = True  # Do False to disable BG
    regularize_latent_codes = True  # True to add latent code LOSS, false for most experiments
    ###################################
    ###################################

    supervised_train_background = train_background and supervised_train_background

    background = None
    if train_background:
        # Initialize the learnable background with the mean training image.
        with torch.no_grad():
            avg_img = torch.mean(images[i_train], axis=0)
            if blur_background:
                avg_img = avg_img.permute(2, 0, 1).unsqueeze(0)
                smoother = GaussianSmoothing(channels=3,
                                             kernel_size=11,
                                             sigma=11)
                print("smoothed background initialization. shape ",
                      avg_img.shape)
                avg_img = smoother(avg_img).squeeze(0).permute(1, 2, 0)
            background = avg_img.detach().clone().to(device)
        background.requires_grad = True

    if fixed_background:  # load GT background (overrides a trained one)
        print("loading GT background to condition on")
        background = _load_background_image(
            os.path.join(cfg.dataset.basedir, 'bg', '00050.png'), H, W,
            device)
        # Index a single image instead of materializing images[i_train].
        expected_shape = images[i_train[0]].shape
        print("bg shape", background.shape)
        print("should be ", expected_shape)
        assert background.shape == expected_shape

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters())
    if model_fine is not None:
        trainable_parameters += list(model_fine.parameters())
    if train_background:
        # The background gets its own param group below (separate lr).
        print("background.is_leaf ", background.is_leaf, background.device)

    if train_latent_codes:
        # NOTE(review): rows are indexed by the raw image index img_idx drawn
        # from i_train, which assumes i_train == range(len(i_train)) -- confirm.
        latent_codes = torch.zeros(len(i_train), 32, device=device)
        print("initialized latent codes with shape %d X %d" %
              (latent_codes.shape[0], latent_codes.shape[1]))
        if not disable_latent_codes:
            trainable_parameters.append(latent_codes)
            latent_codes.requires_grad = True

    # The background keeps its own param group even when it is not trained,
    # so the group layout matches previously saved optimizer states.
    param_groups = [{'params': trainable_parameters}]
    if background is not None:
        param_groups.append({'params': background, 'lr': cfg.optimizer.lr})
    optimizer = getattr(torch.optim, cfg.optimizer.type)(param_groups,
                                                         lr=cfg.optimizer.lr)
    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())  # cfg, f, default_flow_style=False)

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint,
                                map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        if checkpoint["model_fine_state_dict"]:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        if checkpoint.get("background") is not None:
            print("loaded bg from checkpoint")
            if train_background:
                # Copy in place: the optimizer holds a reference to this tensor,
                # rebinding the name would leave the restored values untrained.
                with torch.no_grad():
                    background.copy_(checkpoint["background"].to(device))
            else:
                background = checkpoint["background"].to(device)
        if checkpoint.get("latent_codes") is not None:
            print("loaded latent codes from checkpoint")
            if train_latent_codes:
                # Same reasoning as for the background: keep the optimized tensor.
                with torch.no_grad():
                    latent_codes.copy_(checkpoint["latent_codes"].to(device))
            else:
                latent_codes = checkpoint["latent_codes"].to(device)

        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # # TODO: Prepare raybatch tensor if batching random rays

    # Prepare importance sampling maps (list order follows i_train; looked up
    # by img_idx below, with the same index assumption as latent_codes).
    print("computing boundix boxes probability maps")
    ray_importance_sampling_maps = _bbox_importance_sampling_maps(
        bboxs, i_train, H, W, p=0.9)

    num_decay_steps = cfg.scheduler.lr_decay * 1000

    print("Starting loop")
    for i in trange(start_iter, cfg.experiment.train_iters):

        model_coarse.train()
        if model_fine is not None:
            model_fine.train()

        rgb_coarse, rgb_fine = None, None
        target_ray_values = None
        background_ray_values = None
        if USE_CACHED_DATASET:
            # NOTE(review): unpacks 6 return values while the non-cached path
            # unpacks 7 from run_one_iter_of_nerf -- verify before using.
            datafile = np.random.choice(train_paths)
            cache_dict = torch.load(datafile)
            ray_bundle = cache_dict["ray_bundle"].to(device)
            ray_origins, ray_directions = (
                ray_bundle[0].reshape((-1, 3)),
                ray_bundle[1].reshape((-1, 3)),
            )
            target_ray_values = cache_dict["target"][..., :3].reshape((-1, 3))
            select_inds = np.random.choice(
                ray_origins.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
            )
            ray_origins, ray_directions = (
                ray_origins[select_inds],
                ray_directions[select_inds],
            )
            target_ray_values = target_ray_values[select_inds].to(device)

            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                cache_dict["height"],
                cache_dict["width"],
                cache_dict["focal_length"],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expressions)
        else:
            img_idx = np.random.choice(i_train)
            img_target = images[img_idx].to(device)
            pose_target = poses[img_idx, :3, :4].to(device)
            if disable_expressions:  # zero expr
                expression_target = torch.zeros(76, device=device)
            else:
                expression_target = expressions[img_idx].to(device)  # vector
            if disable_latent_codes:
                latent_code = torch.zeros(32, device=device)
            else:
                latent_code = latent_codes[img_idx].to(
                    device) if train_latent_codes else None
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            coords = torch.stack(
                meshgrid_xy(
                    torch.arange(H).to(device),
                    torch.arange(W).to(device)),
                dim=-1,
            )
            coords = coords.reshape((-1, 2))

            # Use importance sampling to sample mainly in the bbox with prob p
            select_inds = np.random.choice(
                coords.shape[0],
                size=(cfg.nerf.train.num_random_rays),
                replace=False,
                p=ray_importance_sampling_maps[img_idx])

            select_inds = coords[select_inds]
            ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
            ray_directions = ray_directions[select_inds[:, 0],
                                            select_inds[:, 1], :]

            target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]
            if train_background or fixed_background:
                background_ray_values = background[select_inds[:, 0],
                                                   select_inds[:, 1], :]
            rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="train",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression_target,
                background_prior=background_ray_values,
                latent_code=latent_code)
            target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = None
        if rgb_fine is not None:
            fine_loss = torch.nn.functional.mse_loss(
                rgb_fine[..., :3], target_ray_values[..., :3])

        latent_code_loss = torch.zeros(1, device=device)
        if train_latent_codes and not disable_latent_codes:
            latent_code_loss = torch.norm(latent_code) * 0.0005

        background_loss = torch.zeros(1, device=device)
        if supervised_train_background:
            background_loss = torch.nn.functional.mse_loss(
                background_ray_values[..., :3],
                target_ray_values[..., :3],
                reduction='none').sum(1)
            background_loss = torch.mean(background_loss * weights) * 0.001

        # PSNR is reported on the photometric loss only (before regularizers).
        loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
        psnr = mse2psnr(loss.item())

        if regularize_latent_codes:
            loss = loss + latent_code_loss * 10
        loss_total = loss + background_loss if supervised_train_background else loss
        loss_total.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates (exponential decay over num_decay_steps).
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " BG Loss: " +
                       str(background_loss.item()) + " PSNR: " + str(psnr) +
                       " LatentReg: " + str(latent_code_loss.item()))
        if train_latent_codes:
            writer.add_scalar("train/code_loss", latent_code_loss.item(), i)
        if supervised_train_background:
            writer.add_scalar("train/bg_loss", background_loss.item(), i)

        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        if rgb_fine is not None:
            writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation (an extra trigger on the last iteration was deliberately
        # disabled in the original with `and False`).
        if i % cfg.experiment.validate_every == 0:
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            if model_fine is not None:
                model_fine.eval()

            start = time.time()
            with torch.no_grad():
                rgb_coarse, rgb_fine = None, None
                target_ray_values = None
                if USE_CACHED_DATASET:
                    datafile = np.random.choice(validation_paths)
                    cache_dict = torch.load(datafile)
                    rgb_coarse, _, _, rgb_fine, _, weights = run_one_iter_of_nerf(
                        cache_dict["height"],
                        cache_dict["width"],
                        cache_dict["focal_length"],
                        model_coarse,
                        model_fine,
                        cache_dict["ray_origins"].to(device),
                        cache_dict["ray_directions"].to(device),
                        cfg,
                        mode="validation",
                        encode_position_fn=encode_position_fn,
                        encode_direction_fn=encode_direction_fn,
                        expressions=expression_target,
                        latent_code=torch.zeros(32, device=device))
                    target_ray_values = cache_dict["target"].to(device)
                    # No separate validation losses are computed on this path.
                    val_fine_loss = fine_loss
                else:
                    # Only the first two validation views are rendered.
                    val_indices = i_val[:2]
                    loss = 0
                    val_fine_loss = 0.0
                    for img_idx in val_indices:
                        img_target = images[img_idx].to(device)
                        pose_target = poses[img_idx, :3, :4].to(device)
                        # Condition on this view's own expression, not the last
                        # training sample's.
                        if disable_expressions:
                            val_expression = torch.zeros(76, device=device)
                        else:
                            val_expression = expressions[img_idx].to(device)
                        ray_origins, ray_directions = get_ray_bundle(
                            H, W, focal, pose_target)
                        rgb_coarse, _, _, rgb_fine, _, _, weights = run_one_iter_of_nerf(
                            H,
                            W,
                            focal,
                            model_coarse,
                            model_fine,
                            ray_origins,
                            ray_directions,
                            cfg,
                            mode="validation",
                            encode_position_fn=encode_position_fn,
                            encode_direction_fn=encode_direction_fn,
                            expressions=val_expression,
                            background_prior=background.view(-1, 3) if
                            (train_background or fixed_background) else None,
                            latent_code=torch.zeros(32).to(device)
                            if train_latent_codes or disable_latent_codes else
                            None,
                        )
                        target_ray_values = img_target
                        coarse_loss = img2mse(rgb_coarse[..., :3],
                                              target_ray_values[..., :3])
                        curr_fine_loss = 0.0
                        if rgb_fine is not None:
                            curr_fine_loss = img2mse(
                                rgb_fine[..., :3], target_ray_values[..., :3])
                        # Total = coarse + fine, matching the training loss.
                        loss += coarse_loss + curr_fine_loss
                        val_fine_loss += curr_fine_loss
                    # Average over the views actually rendered, not all of i_val.
                    loss /= len(val_indices)
                    val_fine_loss /= len(val_indices)

                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(),
                                  i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                if rgb_fine is not None:
                    writer.add_image("validation/rgb_fine",
                                     cast_to_image(rgb_fine[..., :3]), i)
                    writer.add_scalar("validation/fine_loss",
                                      float(val_fine_loss), i)

                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                if train_background or fixed_background:
                    writer.add_image("validation/background",
                                     cast_to_image(background[..., :3]), i)
                    writer.add_image("validation/weights",
                                     (weights.detach().cpu().numpy()),
                                     i,
                                     dataformats='HW')
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter":
                i,
                "model_coarse_state_dict":
                model_coarse.state_dict(),
                "model_fine_state_dict":
                None if not model_fine else model_fine.state_dict(),
                "optimizer_state_dict":
                optimizer.state_dict(),
                "loss":
                loss,
                "psnr":
                psnr,
                "background":
                None if not (train_background or fixed_background) else
                background.data,
                "latent_codes":
                None if not train_latent_codes else latent_codes.data
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    # Flush pending TensorBoard events.
    writer.close()
    print("Done!")
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint / pre-trained model to evaluate.",
    )
    parser.add_argument("--savedir",
                        type=str,
                        default='./renders/',
                        help="Save images to this directory, if specified.")
    parser.add_argument("--save-disparity-image",
                        action="store_true",
                        help="Save disparity images too.")
    parser.add_argument("--save-error-image",
                        action="store_true",
                        help="Save photometric error visualization")
    configargs = parser.parse_args()

    # Read config file.
    cfg = None
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    if cfg.dataset.type.lower() == "blender":
        # Load blender dataset
        images, poses, render_poses, hwf, i_split, expressions, _, _ = load_flame_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
            test=True)
        #i_train, i_val, i_test = i_split
        i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
    elif cfg.dataset.type.lower() == "llff":
        # Load LLFF dataset
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)

    # Device on which to run.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True)
    model_coarse.to(device)

    # If a fine-resolution model is specified, initialize it.
    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs,
            num_layers=cfg.models.coarse.num_layers,
            hidden_size=cfg.models.coarse.hidden_size,
            include_expression=True)
        model_fine.to(device)

    checkpoint = torch.load(configargs.checkpoint)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    if checkpoint["model_fine_state_dict"]:
        try:
            model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        except:
            print("The checkpoint has a fine-level model, but it could "
                  "not be loaded (possibly due to a mismatched config file.")
    if "height" in checkpoint.keys():
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint.keys():
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint.keys():
        hwf[2] = checkpoint["focal_length"]
    if "background" in checkpoint.keys():
        background = checkpoint["background"]
        if background is not None:
            print("loaded background with shape ", background.shape)
            background.to(device)
    if "latent_codes" in checkpoint.keys():
        latent_codes = checkpoint["latent_codes"]
        use_latent_code = False
        if latent_codes is not None:
            use_latent_code = True
            latent_codes.to(device)
            print("loading index map for latent codes...")
            idx_map = np.load(cfg.dataset.basedir +
                              "/index_map.npy").astype(int)
            print("loaded latent codes from checkpoint, with shape ",
                  latent_codes.shape)
    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    replace_background = True
    if replace_background:
        from PIL import Image
        #background = Image.open('./view.png')
        background = Image.open(cfg.dataset.basedir + '/bg/00050.png')
        #background = Image.open("./real_data/andrei_dvp/" + '/bg/00050.png')
        background.thumbnail((H, W))
        background = torch.from_numpy(
            np.array(background).astype(float)).to(device)
        background = background / 255
        print('loaded custom background of shape', background.shape)

        #background = torch.ones_like(background)
        #background.permute(2,0,1)

    render_poses = render_poses.float().to(device)

    # Create directory to save images to.
    os.makedirs(configargs.savedir, exist_ok=True)
    if configargs.save_disparity_image:
        os.makedirs(os.path.join(configargs.savedir, "disparity"),
                    exist_ok=True)
    if configargs.save_error_image:
        os.makedirs(os.path.join(configargs.savedir, "error"), exist_ok=True)
    os.makedirs(os.path.join(configargs.savedir, "normals"), exist_ok=True)
    # Evaluation loop
    times_per_image = []

    #render_poses = render_poses.float().to(device)
    render_poses = poses[i_test].float().to(device)
    #expressions = torch.arange(-6,6,0.5).float().to(device)
    render_expressions = expressions[i_test].float().to(device)
    #avg_img = torch.mean(images[i_train],axis=0)
    #avg_img = torch.ones_like(avg_img)

    #pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
    #for i, pose in enumerate(tqdm(render_poses)):
    index_of_image_after_train_shuffle = 0
    # render_expressions = render_expressions[[300]] ### TODO render specific expression

    #######################
    no_background = False
    no_expressions = False
    no_lcode = False
    nerf = False
    frontalize = False
    interpolate_mouth = False

    #######################
    if nerf:
        no_background = True
        no_expressions = True
        no_lcode = True
    if no_background: background = None
    if no_expressions:
        render_expressions = torch.zeros_like(render_expressions,
                                              device=render_expressions.device)
    if no_lcode:
        use_latent_code = True
        latent_codes = torch.zeros(5000, 32, device=device)

    for i, expression in enumerate(tqdm(render_expressions)):
        #for i in range(75,151):

        #if i%25 != 0: ### TODO generate only every 25th im
        #if i != 511: ### TODO generate only every 25th im
        #    continue
        start = time.time()
        rgb = None, None
        disp = None, None
        with torch.no_grad():
            pose = render_poses[i]

            if interpolate_mouth:
                frame_id = 241
                num_images = 150
                pose = render_poses[241]
                expression = render_expressions[241].clone()
                expression[68] = torch.arange(-1, 1, 2 / 150, device=device)[i]

            if frontalize:
                pose = render_poses[0]
            #pose = render_poses[300] ### TODO fixes pose
            #expression = render_expressions[0] ### TODO fixes expr
            #expression = torch.zeros_like(expression).to(device)

            ablate = 'view_dir'

            if ablate == 'expression':
                pose = render_poses[100]
            elif ablate == 'latent_code':
                pose = render_poses[100]
                expression = render_expressions[100]
                if idx_map[100 + i, 1] >= 0:
                    #print("found latent code for this image")
                    index_of_image_after_train_shuffle = idx_map[100 + i, 1]
            elif ablate == 'view_dir':
                pose = render_poses[100]
                expression = render_expressions[100]
                _, ray_directions_ablation = get_ray_bundle(
                    hwf[0], hwf[1], hwf[2], render_poses[240 + i][:3, :4])

            pose = pose[:3, :4]

            #pose = torch.from_numpy(np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]))
            if use_latent_code:
                if idx_map[i, 1] >= 0:
                    #print("found latent code for this image")
                    index_of_image_after_train_shuffle = idx_map[i, 1]
            #index_of_image_after_train_shuffle = 10 ## TODO Fixes latent code
            #index_of_image_after_train_shuffle = idx_map[84,1] ## TODO Fixes latent code v2 for andrei
            index_of_image_after_train_shuffle = idx_map[
                10, 1]  ## TODO Fixes latent code - USE THIS if not ablating!

            latent_code = latent_codes[index_of_image_after_train_shuffle].to(
                device) if use_latent_code else None

            #latent_code = torch.mean(latent_codes)
            ray_origins, ray_directions = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], pose)
            rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _, weights = run_one_iter_of_nerf(
                hwf[0],
                hwf[1],
                hwf[2],
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                expressions=expression,
                background_prior=background.view(-1, 3) if
                (background is not None) else None,
                #background_prior = torch.ones_like(background).view(-1,3),  # White background
                latent_code=latent_code,
                ray_directions_ablation=ray_directions_ablation)
            rgb = rgb_fine if rgb_fine is not None else rgb_coarse
            normals = torch_normal_map(disp_fine, focal, weights, clean=True)
            #normals = normal_map_from_depth_map_backproject(disp_fine.cpu().numpy())
            save_plt_image(
                normals.cpu().numpy().astype('uint8'),
                os.path.join(configargs.savedir, 'normals', f"{i:04d}.png"))
            #if configargs.save_disparity_image:
            if False:
                disp = disp_fine if disp_fine is not None else disp_coarse
                #normals = normal_map_from_depth_map_backproject(disp.cpu().numpy())
                normals = normal_map_from_depth_map_backproject(
                    disp_fine.cpu().numpy())
                save_plt_image(
                    normals.astype('uint8'),
                    os.path.join(configargs.savedir, 'normals',
                                 f"{i:04d}.png"))

            #if configargs.save_normal_image:
            #    normal_map_from_depth_map_backproject(disp_fine.cpu().numpy())
        #rgb[torch.where(weights>0.25)]=1.0
        #rgb[torch.where(weights>0.1)] = (rgb * weights + (torch.ones_like(weights)-weights)*torch.ones_like(weights))
        times_per_image.append(time.time() - start)
        if configargs.savedir:
            savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
            imageio.imwrite(
                savefile, cast_to_image(rgb[..., :3],
                                        cfg.dataset.type.lower()))
            if configargs.save_disparity_image:
                savefile = os.path.join(configargs.savedir, "disparity",
                                        f"{i:04d}.png")
                imageio.imwrite(savefile, cast_to_disparity_image(disp_fine))
            if configargs.save_error_image:
                savefile = os.path.join(configargs.savedir, "error",
                                        f"{i:04d}.png")
                GT = images[i_test][i]
                fig = error_image(GT, rgb.cpu().numpy())
                #imageio.imwrite(savefile, cast_to_disparity_image(disp))
                plt.savefig(savefile,
                            pad_inches=0,
                            bbox_inches='tight',
                            dpi=54)
        tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
示例#9
0
def main():
    """Train a coarse/fine NeRF pair on an LLFF dataset described by a YAML config.

    Command-line arguments:
        --config: path to the (.yml) config file.
        --load-checkpoint: path of a checkpoint to resume from. If the path
            does not exist (the default is ""), training starts at iteration 0.

    If ``cfg.dataset.relative_lr_factor`` is present, a lower-resolution copy
    of the dataset is also loaded: iterations that are not a multiple of
    ``cfg.dataset.hr_frequency`` sample rays from it, and only every
    ``cfg.dataset.hr_fps``-th high-resolution training view is kept.

    Side effects: creates ``<cfg.experiment.logdir>/<cfg.experiment.id>`` and
    writes ``config.yml``, TensorBoard summaries and periodic
    ``checkpointNNNNN.ckpt`` files into it.

    Raises:
        ValueError: if ``cfg.dataset.type`` is not ``"llff"``, the only dataset
            type this script knows how to load.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    # NOTE(review): yaml.FullLoader should only be used on trusted, locally
    # authored config files.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # # (Optional:) enable this to track autograd issues when debugging
    # torch.autograd.set_detect_anomaly(True)

    # Load dataset. Only LLFF is supported here; fail fast instead of hitting
    # an UnboundLocalError on the first training iteration.
    if cfg.dataset.type.lower() != "llff":
        raise ValueError(f"Unsupported dataset type {cfg.dataset.type!r}; "
                         "this script only supports 'llff'.")

    images, poses, _, _, i_test = load_llff_data(
        cfg.dataset.basedir, factor=cfg.dataset.downsample_factor)
    hwf = poses[0, :3, -1]
    poses = poses[:, :3, :4]
    if not isinstance(i_test, list):
        i_test = [i_test]
    if cfg.dataset.llffhold > 0:
        i_test = np.arange(images.shape[0])[::cfg.dataset.llffhold]
    # Validation reuses the held-out test views.
    i_val = i_test
    i_train = np.array(
        [i for i in np.arange(images.shape[0]) if i not in i_test])
    H, W, focal = hwf
    H, W = int(H), int(W)
    images = torch.from_numpy(images)
    poses = torch.from_numpy(poses)
    USE_HR_LR = False

    # Load LR images
    if hasattr(cfg.dataset, "relative_lr_factor"):
        assert hasattr(cfg.dataset, "hr_fps")
        assert hasattr(cfg.dataset, "hr_frequency")
        USE_HR_LR = True
        images_lr, poses_lr, _, _, i_test_lr = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor *
            cfg.dataset.relative_lr_factor)
        H_lr, W_lr, focal_lr = poses_lr[0, :3, -1]
        H_lr, W_lr = int(H_lr), int(W_lr)
        poses_lr = poses_lr[:, :3, :4]
        if not isinstance(i_test_lr, list):
            i_test_lr = [i_test_lr]
        if cfg.dataset.llffhold > 0:
            i_test_lr = np.arange(
                images_lr.shape[0])[::cfg.dataset.llffhold]
        i_train_lr = np.array([
            i for i in np.arange(images_lr.shape[0]) if i not in i_test_lr
        ])
        images_lr = torch.from_numpy(images_lr)
        poses_lr = torch.from_numpy(poses_lr)

        # Expose only some HR images
        i_train = i_train[::cfg.dataset.hr_fps]

        print(f'LR summary: N={i_train_lr.shape[0]}, '
              f'resolution={H_lr}x{W_lr}, focal-length={focal_lr}')
        print(f'HR summary: N={i_train.shape[0]}, '
              f'resolution={H}x{W}, focal-length={focal}')

    # Seed experiment for repeatability
    seed = cfg.experiment.randomseed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Device on which to run.
    device = f"cuda:{cfg.experiment.gpu_id}"

    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz,
    )

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir,
        )

    # Initialize a coarse and fine resolution model.
    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs,
    )
    model_coarse.to(device)

    model_fine = getattr(models, cfg.models.fine.type)(
        num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
        include_input_xyz=cfg.models.fine.include_input_xyz,
        include_input_dir=cfg.models.fine.include_input_dir,
        use_viewdirs=cfg.models.fine.use_viewdirs,
    )
    model_fine.to(device)

    # Initialize optimizer.
    trainable_parameters = list(model_coarse.parameters()) + list(
        model_fine.parameters())
    optimizer = getattr(torch.optim, cfg.optimizer.type)(trainable_parameters,
                                                         lr=cfg.optimizer.lr)

    # Setup logging.
    logdir = os.path.join(cfg.experiment.logdir, cfg.experiment.id)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    # Write out config parameters.
    with open(os.path.join(logdir, "config.yml"), "w") as f:
        f.write(cfg.dump())  # cfg, f, default_flow_style=False)

    # By default, start at iteration 0 (unless a checkpoint is specified).
    start_iter = 0

    # Load an existing checkpoint, if a path is specified. map_location keeps
    # the restored tensors on this run's device regardless of where they were
    # saved from.
    if os.path.exists(configargs.load_checkpoint):
        checkpoint = torch.load(configargs.load_checkpoint,
                                map_location=device)
        model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
        model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_iter = checkpoint["iter"]

    # # TODO: Prepare raybatch tensor if batching random rays

    # Exponential LR decay: the LR is multiplied by lr_decay_factor every
    # num_decay_steps iterations (applied continuously).
    num_decay_steps = cfg.scheduler.lr_decay * 1000

    for i in trange(start_iter, cfg.experiment.train_iters):

        model_coarse.train()
        model_fine.train()

        # Alternate between LR and HR data: HR only every hr_frequency iters.
        if USE_HR_LR and i % cfg.dataset.hr_frequency != 0:
            H_iter, W_iter, focal_iter = H_lr, W_lr, focal_lr
            i_train_iter = i_train_lr
            images_iter, poses_iter = images_lr, poses_lr
        else:
            H_iter, W_iter, focal_iter = H, W, focal
            i_train_iter = i_train
            images_iter, poses_iter = images, poses

        img_idx = np.random.choice(i_train_iter)
        img_target = images_iter[img_idx].to(device)
        pose_target = poses_iter[img_idx, :3, :4].to(device)
        ray_origins, ray_directions = get_ray_bundle(H_iter, W_iter,
                                                     focal_iter, pose_target)
        # Sample num_random_rays distinct pixels of the target view.
        coords = torch.stack(
            meshgrid_xy(
                torch.arange(H_iter).to(device),
                torch.arange(W_iter).to(device)),
            dim=-1,
        )
        coords = coords.reshape((-1, 2))
        select_inds = np.random.choice(coords.shape[0],
                                       size=(cfg.nerf.train.num_random_rays),
                                       replace=False)
        select_inds = coords[select_inds]
        ray_origins = ray_origins[select_inds[:, 0], select_inds[:, 1], :]
        ray_directions = ray_directions[select_inds[:, 0],
                                        select_inds[:, 1], :]
        target_s = img_target[select_inds[:, 0], select_inds[:, 1], :]

        rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
            H_iter,
            W_iter,
            focal_iter,
            model_coarse,
            model_fine,
            ray_origins,
            ray_directions,
            cfg,
            mode="train",
            encode_position_fn=encode_position_fn,
            encode_direction_fn=encode_direction_fn,
        )
        target_ray_values = target_s

        coarse_loss = torch.nn.functional.mse_loss(rgb_coarse[..., :3],
                                                   target_ray_values[..., :3])
        fine_loss = torch.nn.functional.mse_loss(rgb_fine[..., :3],
                                                 target_ray_values[..., :3])
        loss = coarse_loss + fine_loss
        loss.backward()
        psnr = mse2psnr(loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Learning rate updates
        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor
                                     **(i / num_decay_steps))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_new

        if i % cfg.experiment.print_every == 0 or i == cfg.experiment.train_iters - 1:
            tqdm.write("[TRAIN] Iter: " + str(i) + " Loss: " +
                       str(loss.item()) + " PSNR: " + str(psnr))
        writer.add_scalar("train/loss", loss.item(), i)
        writer.add_scalar("train/coarse_loss", coarse_loss.item(), i)
        writer.add_scalar("train/fine_loss", fine_loss.item(), i)
        writer.add_scalar("train/psnr", psnr, i)

        # Validation: render one full held-out HR view.
        if (i % cfg.experiment.validate_every == 0
                or i == cfg.experiment.train_iters - 1):
            tqdm.write("[VAL] =======> Iter: " + str(i))
            model_coarse.eval()
            model_fine.eval()

            start = time.time()
            with torch.no_grad():
                img_idx = np.random.choice(i_val)
                img_target = images[img_idx].to(device)
                pose_target = poses[img_idx, :3, :4].to(device)
                ray_origins, ray_directions = get_ray_bundle(
                    H, W, focal, pose_target)
                rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                    H,
                    W,
                    focal,
                    model_coarse,
                    model_fine,
                    ray_origins,
                    ray_directions,
                    cfg,
                    mode="validation",
                    encode_position_fn=encode_position_fn,
                    encode_direction_fn=encode_direction_fn,
                )
                target_ray_values = img_target
                coarse_loss = img2mse(rgb_coarse[..., :3],
                                      target_ray_values[..., :3])
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                loss = coarse_loss + fine_loss
                psnr = mse2psnr(loss.item())
                writer.add_scalar("validation/loss", loss.item(), i)
                writer.add_scalar("validation/coarse_loss", coarse_loss.item(),
                                  i)
                writer.add_scalar("validation/fine_loss", fine_loss.item(), i)
                writer.add_scalar("validation/psnr", psnr, i)
                writer.add_image("validation/rgb_coarse",
                                 cast_to_image(rgb_coarse[..., :3]), i)
                writer.add_image("validation/rgb_fine",
                                 cast_to_image(rgb_fine[..., :3]), i)
                writer.add_image(
                    "validation/img_target",
                    cast_to_image(target_ray_values[..., :3]),
                    i,
                )
                tqdm.write("Validation loss: " + str(loss.item()) +
                           " Validation PSNR: " + str(psnr) + " Time: " +
                           str(time.time() - start))

        if i % cfg.experiment.save_every == 0 or i == cfg.experiment.train_iters - 1:
            checkpoint_dict = {
                "iter": i,
                "model_coarse_state_dict": model_coarse.state_dict(),
                "model_fine_state_dict": model_fine.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "psnr": psnr,
            }
            torch.save(
                checkpoint_dict,
                os.path.join(logdir, "checkpoint" + str(i).zfill(5) + ".ckpt"),
            )
            tqdm.write("================== Saved Checkpoint =================")

    print("Done!")
示例#10
0
def main():
    """Evaluate a trained NeRF checkpoint on the held-out test views.

    Command-line arguments:
        --config: path to the (.yml) config the model was trained with.
        --checkpoint: checkpoint file holding ``model_coarse_state_dict`` and,
            optionally, ``model_fine_state_dict``. Its ``height``, ``width`` and
            ``focal_length`` entries, when present, override the dataset
            intrinsics used for rendering.
        --dual-render: forwarded to ``run_one_iter_of_nerf`` as ``dual_render``.
        --gpu-id: CUDA device index used when CUDA is available.

    For every test view this prints the PSNR of the final rendering (the fine
    level if a fine model produced output, otherwise the coarse level) plus the
    per-level PSNRs, and finally the mean of the final-rendering PSNRs.

    Raises:
        ValueError: if ``cfg.dataset.type`` is neither ``"blender"`` nor
            ``"llff"``.
    """

    # Config options:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="Checkpoint / pre-trained model to evaluate.")
    parser.add_argument("--dual-render", action='store_true', default=False)
    parser.add_argument('--gpu-id',
                        type=int,
                        default=0,
                        help="id of the CUDA GPU to use (default: 0)")
    configargs = parser.parse_args()

    cfg = None
    # NOTE(review): yaml.FullLoader should only be used on trusted configs.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = CfgNode(cfg_dict)

    # Dataset:
    images, poses, render_poses, hwf = None, None, None, None
    i_train, i_val, i_test = None, None, None
    dataset_type = cfg.dataset.type.lower()
    if dataset_type == "blender":
        images, poses, render_poses, hwf, i_split = load_blender_data(
            cfg.dataset.basedir,
            half_res=cfg.dataset.half_res,
            testskip=cfg.dataset.testskip,
        )
        i_train, i_val, i_test = i_split
        H, W, focal = hwf
        H, W = int(H), int(W)
        hwf = [H, W, focal]
        if cfg.nerf.train.white_background:
            # Composite RGBA onto a white background using the alpha channel.
            images = images[..., :3] * images[..., -1:] + (1.0 -
                                                           images[..., -1:])
    elif dataset_type == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            cfg.dataset.basedir,
            factor=cfg.dataset.downsample_factor,
        )
        hwf = poses[0, :3, -1]
        H, W, focal = hwf
        hwf = [int(H), int(W), focal]
        render_poses = torch.from_numpy(render_poses)
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)
    else:
        raise ValueError(f"Unsupported dataset type {cfg.dataset.type!r}; "
                         "expected 'blender' or 'llff'.")

    # Hardware
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(configargs.gpu_id)
        device = "cuda"

    # Model
    encode_position_fn = get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
        include_input=cfg.models.coarse.include_input_xyz,
        log_sampling=cfg.models.coarse.log_sampling_xyz)

    encode_direction_fn = None
    if cfg.models.coarse.use_viewdirs:
        encode_direction_fn = get_embedding_function(
            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
            include_input=cfg.models.coarse.include_input_dir,
            log_sampling=cfg.models.coarse.log_sampling_dir)

    model_coarse = getattr(models, cfg.models.coarse.type)(
        num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
        include_input_xyz=cfg.models.coarse.include_input_xyz,
        include_input_dir=cfg.models.coarse.include_input_dir,
        use_viewdirs=cfg.models.coarse.use_viewdirs)
    model_coarse.to(device)

    model_fine = None
    if hasattr(cfg.models, "fine"):
        model_fine = getattr(models, cfg.models.fine.type)(
            num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
            num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
            include_input_xyz=cfg.models.fine.include_input_xyz,
            include_input_dir=cfg.models.fine.include_input_dir,
            use_viewdirs=cfg.models.fine.use_viewdirs)
        model_fine.to(device)

    # Load checkpoint
    checkpoint = torch.load(configargs.checkpoint, map_location=device)
    model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
    fine_state = checkpoint.get("model_fine_state_dict")
    if fine_state:
        if model_fine is None:
            print("The checkpoint has a fine-level model, but the config "
                  "defines none; ignoring it.")
        else:
            try:
                model_fine.load_state_dict(fine_state)
            except RuntimeError:
                # load_state_dict raises RuntimeError on key/shape mismatches.
                print("The checkpoint has a fine-level model, but it could "
                      "not be loaded (possibly due to a mismatched config "
                      "file).")
    if "height" in checkpoint:
        hwf[0] = checkpoint["height"]
    if "width" in checkpoint:
        hwf[1] = checkpoint["width"]
    if "focal_length" in checkpoint:
        hwf[2] = checkpoint["focal_length"]
    # Intrinsics actually used for rendering: dataset values, possibly
    # overridden by the checkpoint above.
    H, W, focal = int(hwf[0]), int(hwf[1]), hwf[2]

    # Prepare model and data
    model_coarse.eval()
    if model_fine:
        model_fine.eval()

    render_poses = render_poses.float().to(device)

    print("Dual render?", configargs.dual_render)

    # Evaluation loop
    with torch.no_grad():
        # PSNR of the final rendering per test view (fine level when a fine
        # model produced output, otherwise coarse).
        fine_psnrs = []
        # load_llff_data yields a single index, load_blender_data an index
        # array; normalise both so each test view is evaluated individually.
        i_test = np.atleast_1d(i_test)
        for sample_num, i in enumerate(i_test, start=1):
            print(f"Test sample {sample_num}/{len(i_test)}...")

            img_target = images[i].to(device)
            pose_target = poses[i, :3, :4].to(device)
            ray_origins, ray_directions = get_ray_bundle(
                H, W, focal, pose_target)
            rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                H,
                W,
                focal,
                model_coarse,
                model_fine,
                ray_origins,
                ray_directions,
                cfg,
                mode="validation",
                encode_position_fn=encode_position_fn,
                encode_direction_fn=encode_direction_fn,
                dual_render=configargs.dual_render)
            target_ray_values = img_target
            coarse_loss = img2mse(rgb_coarse[..., :3],
                                  target_ray_values[..., :3])
            psnr_coarse = mse2psnr(coarse_loss.item())
            if rgb_fine is not None:
                fine_loss = img2mse(rgb_fine[..., :3],
                                    target_ray_values[..., :3])
                psnr_fine = mse2psnr(fine_loss.item())
                loss = fine_loss
            else:
                # No fine output: do not report a PSNR for a zero "loss".
                psnr_fine = None
                loss = coarse_loss
            psnr = mse2psnr(loss.item())
            print(
                f"\t Loss at sample: {psnr} (f:{psnr_fine}, c:{psnr_coarse})")

            fine_psnrs.append(psnr)

        if fine_psnrs:
            print(f"Validation PSNR: {sum(fine_psnrs) / len(fine_psnrs)}.")
        else:
            print("No test samples to evaluate.")