コード例 #1
0
 def convert_disparity_to_saveable(disparity, normalize=True):
     """Convert a disparity map into the 8-bit form produced by ``to8b``.

     Can be saved via:
     imageio.imwrite(filename, convert_disparity_to_saveable(disparity))

     Args:
         disparity: array of disparities (the original note describes a single
             map of shape height x width). Must support ``np.max`` and
             ``.copy()``.
         normalize: if True, divide by the maximum disparity before the
             conversion; if False, convert a copy of the raw values.

     Returns:
         The output of ``to8b``; per the original note: height x width with
         int values in [0,255].
     """
     if normalize:
         max_disparity = np.max(disparity)
         # An all-zero map would be divided by zero and turn into NaNs, so it
         # is passed through unnormalized instead.
         converted_disparity = (disparity / max_disparity
                                if max_disparity != 0 else disparity.copy())
     else:
         converted_disparity = disparity.copy()
     return to8b(converted_disparity)
コード例 #2
0
ファイル: train.py プロジェクト: mackopes/nerf-pytorch
def save_depthmap(filename, depth, H, W):
    """Write a depth tensor to ``filename`` as an 8-bit image.

    The depth values are clamped to at least 1e-10, mapped to
    ``1 - 1 / depth``, reshaped to ``(H, W)``, converted with ``to8b`` and
    written with ``imageio.imwrite``.

    Args:
        filename: output image path; its parent folder is created if needed.
        depth: torch tensor holding ``H * W`` depth values.
        H: image height used for the reshape.
        W: image width used for the reshape.
    """
    # os.makedirs("") raises, so only create a folder when the path names one.
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # ones_like() already lives on depth's device, so no explicit .cuda() is
    # needed; this also lets CPU tensors through instead of failing on a
    # device mismatch.
    disp_map_label = 1 - 1. / torch.max(
        torch.ones_like(depth) * 1e-10, depth)
    disp_map_label = torch.reshape(disp_map_label, (H, W))

    rgb8_label = to8b(disp_map_label.cpu().detach().numpy())
    imageio.imwrite(filename, rgb8_label)
コード例 #3
0
    def update_canvas(self, images, disps=None):
        """Refresh the per-image widgets from ``images`` and, optionally, ``disps``.

        For every image a version resized to ``self.size`` x ``self.size`` is
        pushed (as URL via ``renormalize.as_url``) into
        ``self.real_images_array`` and ``self.real_canvas_array``; the
        full-size image goes to the canvas as well. When ``disps`` is given,
        the matching disparity image is handled the same way. Finally the
        editing canvas is re-pointed at the currently selected image, if any.
        """
        for idx, img in enumerate(images):
            thumbnail = F.interpolate(
                img.unsqueeze(dim=0),
                size=(self.size, self.size),
            ).squeeze(dim=0)
            self.real_images_array[idx].src = renormalize.as_url(thumbnail)

            canvas = self.real_canvas_array[idx]
            canvas.image = renormalize.as_url(img)
            canvas.resized_image = renormalize.as_url(thumbnail)
            if disps is None:
                continue

            # 8-bit disparity image, brought back to floats in [0, 1] with a
            # leading channel dimension.
            disp_bytes = to8b(to_disp_img(disps[idx]))
            disp_tensor = torch.from_numpy(disp_bytes).unsqueeze(dim=0) / 255.
            disp_thumbnail = F.interpolate(
                disp_tensor.unsqueeze(dim=0),
                size=(self.size, self.size),
            ).squeeze(dim=0)
            canvas.resized_disp = renormalize.as_url(disp_thumbnail)
            canvas.disp = renormalize.as_url(disp_tensor)

        # A negative index means no image is currently selected for editing.
        selected = self.editing_canvas.index
        if selected >= 0:
            self.editing_canvas.image = self.real_canvas_array[selected].image
コード例 #4
0
 def convert_disparity_to_phong(disparity, normalize=True):
     """Render a disparity map with Blinn-Phong shading as an 8-bit image.

     Args:
         disparity: array of disparities; must support ``np.max`` and
             ``.copy()``.
         normalize: if True, divide by the maximum disparity before shading;
             if False, shade a copy of the raw values.

     Returns:
         ``to8b`` applied to ``visualize_disparity_with_blinn_phong(...)``; per
         the original note: height x width x 3 with int values in [0,255].
     """
     if normalize:
         max_disparity = np.max(disparity)
         # An all-zero map would be divided by zero and turn into NaNs, so it
         # is passed through unnormalized instead.
         converted_disparity = (disparity / max_disparity
                                if max_disparity != 0 else disparity.copy())
     else:
         converted_disparity = disparity.copy()
     return to8b(visualize_disparity_with_blinn_phong(converted_disparity))
コード例 #5
0
 def convert_disparity_to_jet(disparity, normalize=True):
     """Render a disparity map with the jet color scheme as an 8-bit image.

     Args:
         disparity: array of disparities; must support ``np.max`` and
             ``.copy()``.
         normalize: if True, divide by the maximum disparity before coloring;
             if False, color a copy of the raw values.

     Returns:
         ``to8b`` applied to ``visualize_disparity_with_jet_color_scheme(...)``;
         per the original note: height x width x 3 with int values in [0,255].
     """
     if normalize:
         max_disparity = np.max(disparity)
         # An all-zero map would be divided by zero and turn into NaNs, so it
         # is passed through unnormalized instead.
         converted_disparity = (disparity / max_disparity
                                if max_disparity != 0 else disparity.copy())
     else:
         converted_disparity = disparity.copy()
     return to8b(visualize_disparity_with_jet_color_scheme(converted_disparity))
コード例 #6
0
 def convert_rgb_to_saveable(rgb):
     """Turn an RGB image into its saveable 8-bit form by delegating to ``to8b``.

     Per the original notes, the input holds float values in [0,1] and the
     output holds int values in [0,255].
     """
     saveable_rgb = to8b(rgb)
     return saveable_rgb
コード例 #7
0
def free_viewpoint_rendering(args):
    """Render a trained non-rigid NeRF along a camera path and store the results.

    The network and dataset are obtained through
    ``_setup_nonrigid_nerf_network``. Frames are rendered in batches of
    ``frames_at_a_time``; per-frame PNGs and MP4 videos are written to
    ``<args.input>/output/<output_name>``. Depending on ``args.camera_path``
    the function additionally writes a background-stability image ("fixed")
    or per-image PSNR/SSIM/LPIPS scores and error maps
    ("input_reconstruction").

    Args:
        args: object providing the attributes read below: ``input``,
            ``camera_path`` ("spiral", "fixed" or "input_reconstruction"),
            ``deformations`` ("train", "test" or "all"), ``fixed_view``,
            ``forced_background_stabilization``, ``motion_factor``,
            ``foreground_removal``, ``render_canonical`` and
            ``output_video_fps``.

    Raises:
        RuntimeError: if ``args.camera_path`` or ``args.deformations`` has an
            unsupported value.
    """

    # memory vs. speed and quality
    frames_at_a_time = 10  # set to 1 to reduce memory requirements
    only_rgb = False  # set to True to reduce memory requirements. Needs to be False for some scene editing to work.

    # determine output name
    if args.camera_path == "spiral":
        output_name = args.deformations + "_" + args.camera_path
    elif args.camera_path == "fixed":
        output_name = (args.deformations + "_" + args.camera_path + "_" +
                       str(args.fixed_view))
    elif args.camera_path == "input_reconstruction":
        output_name = args.deformations + "_" + args.camera_path
    else:
        raise RuntimeError("invalid --camera_path argument")

    if args.forced_background_stabilization is not None:
        output_name += "_fbs_" + str(args.forced_background_stabilization)
    if args.motion_factor is not None:
        output_name += "_exaggeration_" + str(args.motion_factor)
    if args.foreground_removal is not None:
        output_name += "_removal_" + str(args.foreground_removal)
    if args.render_canonical:
        output_name += "_canonical"

    output_folder = os.path.join(args.input, "output", output_name)
    create_folder(output_folder)

    # load Nerf network
    (
        render_kwargs_train,
        render_kwargs_test,
        start,
        grad_vars,
        load_weights_into_network,
        checkpoint_dict,
        get_training_ray_bending_latents,
        load_llff_dataset,
        raw_render_path,
        render_convenient,
        convert_rgb_to_saveable,
        convert_disparity_to_saveable,
        convert_disparity_to_jet,
        convert_disparity_to_phong,
        store_ray_bending_mesh_visualization,
        to8b,
    ) = _setup_nonrigid_nerf_network(args.input)
    print("sucessfully loaded nerf network", flush=True)

    # load dataset
    ray_bending_latents = (get_training_ray_bending_latents()
                           )  # shape: frames x latent_size
    (
        images,
        poses,
        all_rotations,
        all_translations,
        bds,
        render_poses,
        render_rotations,
        render_translations,
        i_train,
        i_val,
        i_test,
        near,
        far,
        dataset_extras,
    ) = load_llff_dataset(render_kwargs_train_=render_kwargs_train,
                          render_kwargs_test_=render_kwargs_test
                          )  # load dataset that this nerf was trained on
    print("sucessfully loaded dataset", flush=True)

    # determine subset
    if args.deformations == "train":
        poses = poses[i_train]
        ray_bending_latents = ray_bending_latents[i_train]
        images = images[i_train]
        print("rendering training set")
    elif args.deformations == "test":
        poses = poses[i_test]
        ray_bending_latents = ray_bending_latents[i_test]
        images = images[i_test]
        print("rendering test set")
    elif args.deformations == "all":
        print("rendering training and test set")
    else:
        raise RuntimeError("invalid --deformations argument")

    copy_over_groundtruth_images = False
    if copy_over_groundtruth_images:
        groundtruth_images_folder = os.path.join(output_folder, "groundtruth")
        create_folder(groundtruth_images_folder)
        for i, rgb in enumerate(images):
            rgb = convert_rgb_to_saveable(rgb)
            file_prefix = os.path.join(groundtruth_images_folder,
                                       str(i).zfill(6))
            imageio.imwrite(file_prefix + ".png", rgb)

    # determine camera poses and latent codes
    num_poses = poses.shape[0]
    intrinsics = dataset_extras["intrinsics"]
    if args.camera_path == "input_reconstruction":
        # poses are used unchanged; only the intrinsics are re-indexed per image
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][i]]
            for i in range(num_poses)
        ]
    elif args.camera_path == "fixed":
        poses = torch.stack(
            [torch.Tensor(poses[args.fixed_view]) for _ in range(num_poses)],
            0)  # N x 3 x 4
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][args.fixed_view]]
            for _ in range(num_poses)
        ]
    elif args.camera_path == "spiral":
        # poses = np.stack(_spiral_poses(poses, bds, num_poses), axis=0)
        # repeat the precomputed render poses until there is one per frame
        poses = []
        while len(poses) < num_poses:
            poses += [render_pose for render_pose in render_poses]
        poses = np.stack(poses, axis=0)[:num_poses]
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][0]]
            for _ in range(num_poses)
        ]
    else:
        # poses has shape N x ... and ray_bending_latents has shape N x ...
        # Can design custom camera paths here.
        # poses is indexed with imageid
        # ray_bending_latents is indexed with timestepid
        # intrinsics is indexed with viewid
        # images is indexed with imageid
        raise RuntimeError

        # NOTE: everything below in this branch is an unreachable template
        # (it follows the raise above); it is kept as a starting point for
        # custom camera paths.
        # example with time interpolation from a fixed camera view
        num_target_frames = 500
        latent_indices = np.linspace(0,
                                     ray_bending_latents.shape[0] - 1,
                                     num=num_target_frames)
        # builtin int instead of the np.int alias, which NumPy has removed
        start_indices = np.floor(latent_indices).astype(int)
        end_indices = np.ceil(latent_indices).astype(int)
        start_latents = ray_bending_latents[
            start_indices]  # num_target_frames x latent_size
        end_latents = ray_bending_latents[
            end_indices]  # num_target_frames x latent_size
        interpolation_factors = latent_indices - start_indices  # shape: num_target_frames. should be in [0,1]
        interpolation_factors = torch.Tensor(interpolation_factors).reshape(
            -1, 1)  # num_target_frames x 1
        ray_bending_latents = end_latents * interpolation_factors + start_latents * (
            1. - interpolation_factors)

        fixed_camera = 0
        poses = torch.stack([
            torch.Tensor(poses[fixed_camera]) for _ in range(num_target_frames)
        ], 0)  # N x 3 x 4
        intrinsics = [
            intrinsics[dataset_extras["imageid_to_viewid"][fixed_camera]]
            for _ in range(num_target_frames)
        ]

    latents = ray_bending_latents

    latents = latents.detach().cuda()

    # rendering
    correspondence_rgbs = []
    rigidities = []
    rgbs = []
    disps = []

    num_output_frames = poses.shape[0]
    for start_index in range(0, num_output_frames, frames_at_a_time):

        end_index = min(start_index + frames_at_a_time, num_output_frames)
        print(
            "rendering " + str(start_index) + " to " + str(end_index) +
            " out of " + str(num_output_frames),
            flush=True,
        )

        subposes = poses[start_index:end_index]
        sublatents = [latents[i] for i in range(start_index, end_index)]

        # render
        # NOTE(review): poses and latents are sliced per batch but intrinsics
        # is passed as the full list -- confirm that render_convenient
        # expects this.
        returned = render_convenient(
            poses=subposes,
            ray_bending_latents=sublatents,
            intrinsics=intrinsics,
            with_ray_bending=not args.render_canonical,
            detailed_output=not only_rgb,
            rigidity_test_time_cutoff=args.forced_background_stabilization,
            motion_factor=args.motion_factor,
            foreground_removal=args.foreground_removal)
        if only_rgb:
            subrgbs, subdisps = returned
        else:
            subrgbs, subdisps, details_and_rest = returned
        print("finished rendering", flush=True)

        rgbs.extend(subrgbs)
        disps.extend(subdisps)
        if only_rgb:
            # keep the per-frame lists aligned with rgbs
            correspondence_rgbs += [None for _ in subrgbs]
            rigidities += [None for _ in subrgbs]
            continue

        # determine correspondences
        # details_and_rest: list, one entry per image. each image has first two dimensions height x width.
        min_point = np.array(
            checkpoint_dict["scripts_dict"]["min_nerf_volume_point"]).reshape(
                1, 1, 3)
        max_point = np.array(
            checkpoint_dict["scripts_dict"]["max_nerf_volume_point"]).reshape(
                1, 1, 3)
        for i, image_details in enumerate(details_and_rest):
            # visibility_weight is the weight of the influence that each sample has on the final rgb value. so they sum to at most 1.
            accumulated_visibility = torch.cumsum(
                torch.Tensor(image_details["fine_visibility_weights"]).cuda(),
                dim=-1)  # height x width x point samples
            median_indices = torch.min(
                torch.abs(accumulated_visibility - 0.5), dim=-1
            )[1]  # height x width. visibility goes from 0 to 1. 0.5 is the median, so treat it as "most likely to be on the actually visible surface"
            # visualize canonical correspondences as RGB
            height, width = median_indices.shape
            surface_pixels = (
                image_details["fine_input_pts"].reshape(
                    height * width, -1,
                    3)[np.arange(height * width),
                       median_indices.cpu().reshape(-1), :].reshape(
                           height, width, 3)
            )  # height x width x 3. median_indices contains the index of one ray sample for each pixel. this ray sample is selected in this line of code.
            correspondence_rgb = (surface_pixels - min_point) / (max_point -
                                                                 min_point)
            number_of_small_rgb_voxels = 100  # break the canonical space into smaller voxels. each voxel covers the entire RGB space [0,1]^3. makes it easier to visualize small changes. leads to a 3D checkerboard pattern
            if number_of_small_rgb_voxels > 1:
                correspondence_rgb *= number_of_small_rgb_voxels
                correspondence_rgb = correspondence_rgb - correspondence_rgb.astype(
                    int)
            correspondence_rgbs.append(correspondence_rgb)

            # visualize rigidity
            if "fine_rigidity_mask" in image_details:
                rigidity = (image_details["fine_rigidity_mask"].reshape(
                    height * width,
                    -1)[np.arange(height * width),
                        median_indices.cpu().reshape(-1)].reshape(
                            height, width))  # height x width. values in [0,1]
                rigidities.append(rigidity)
            else:
                rigidities.append(None)

    rgbs = np.stack(rgbs, axis=0)
    disps = np.stack(disps, axis=0)
    correspondence_rgbs = np.stack(correspondence_rgbs, axis=0)
    use_rigidity = rigidities[0] is not None

    # store results
    for i, (rgb, disp, correspondence_rgb, rigidity) in enumerate(
            zip(rgbs, disps, correspondence_rgbs, rigidities)):
        print("storing image " + str(i) + " / " + str(rgbs.shape[0]),
              flush=True)
        rgb = convert_rgb_to_saveable(rgb)
        disp_saveable = convert_disparity_to_saveable(disp)
        disp_jet = convert_disparity_to_jet(disp)
        disp_phong = convert_disparity_to_phong(disp)
        if not only_rgb:
            correspondence_rgb = convert_rgb_to_saveable(correspondence_rgb)
        if use_rigidity:
            rigidity_saveable = convert_disparity_to_saveable(rigidity,
                                                              normalize=False)
            rigidity_jet = convert_disparity_to_jet(rigidity, normalize=False)

        file_postfix = "_" + str(i).zfill(6) + ".png"
        imageio.imwrite(os.path.join(output_folder, "rgb" + file_postfix), rgb)
        if not only_rgb:
            imageio.imwrite(
                os.path.join(output_folder, "correspondences" + file_postfix),
                correspondence_rgb,
            )
        if use_rigidity:
            imageio.imwrite(
                os.path.join(output_folder, "rigidity" + file_postfix),
                rigidity_saveable,
            )
            imageio.imwrite(
                os.path.join(output_folder, "rigidity_jet" + file_postfix),
                rigidity_jet)
        imageio.imwrite(os.path.join(output_folder, "disp" + file_postfix),
                        disp_saveable)
        imageio.imwrite(os.path.join(output_folder, "disp_jet" + file_postfix),
                        disp_jet)
        imageio.imwrite(
            os.path.join(output_folder, "disp_phong" + file_postfix),
            disp_phong)

    # movies
    file_prefix = os.path.join(output_folder, "video_")
    try:
        print("storing RGB video...", flush=True)
        imageio.mimwrite(
            file_prefix + "rgb.mp4",
            convert_rgb_to_saveable(rgbs),
            fps=args.output_video_fps,
            quality=9,
        )
        if not only_rgb:
            print("storing correspondence RGB video...", flush=True)
            imageio.mimwrite(
                file_prefix + "correspondences.mp4",
                convert_rgb_to_saveable(correspondence_rgbs),
                fps=args.output_video_fps,
                quality=9,
            )
        print("storing disparity video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp.mp4",
            convert_disparity_to_saveable(disps),
            fps=args.output_video_fps,
            quality=9,
        )
        print("storing disparity jet video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp_jet.mp4",
            np.stack([convert_disparity_to_jet(disp) for disp in disps],
                     axis=0),
            fps=args.output_video_fps,
            quality=9,
        )
        print("storing disparity phong video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp_phong.mp4",
            np.stack([convert_disparity_to_phong(disp) for disp in disps],
                     axis=0),
            fps=args.output_video_fps,
            quality=9,
        )
        if use_rigidity:
            rigidities = np.stack(rigidities, axis=0)
            print("storing rigidity video...", flush=True)
            imageio.mimwrite(
                file_prefix + "rigidity.mp4",
                convert_disparity_to_saveable(rigidities, normalize=False),
                fps=args.output_video_fps,
                quality=9,
            )
            print("storing rigidity jet video...", flush=True)
            imageio.mimwrite(
                file_prefix + "rigidity_jet.mp4",
                np.stack(
                    [
                        convert_disparity_to_jet(rigidity, normalize=False)
                        for rigidity in rigidities
                    ],
                    axis=0,
                ),
                fps=args.output_video_fps,
                quality=9,
            )
    except Exception as error:
        # Video export stays best-effort, but Ctrl-C / SystemExit are no
        # longer swallowed and the actual cause is reported.
        print(
            "imageio.mimwrite() failed. maybe ffmpeg is not installed properly?"
        )
        print("reason: " + repr(error), flush=True)

    # evaluation of background stability
    if args.camera_path == "fixed":
        standard_deviations = np.std(rgbs, axis=0)
        averaged_standard_devations = 10 * np.mean(standard_deviations,
                                                   axis=-1)

        from matplotlib import cm
        color_mapping = np.array([cm.jet(i)[:3] for i in range(256)])
        max_value = 1
        min_value = 0
        averaged_standard_devations = np.clip(
            averaged_standard_devations, a_max=max_value, a_min=min_value
        ) / max_value  # cut off above max_value. result is normalized to [0,1]
        averaged_standard_devations = (
            255. * averaged_standard_devations).astype(
                'uint8')  # now contains int in [0,255]
        original_shape = averaged_standard_devations.shape
        averaged_standard_devations = color_mapping[
            averaged_standard_devations.flatten()]
        averaged_standard_devations = averaged_standard_devations.reshape(
            original_shape + (3, ))

        imageio.imwrite(os.path.join(output_folder, "standard_deviations.png"),
                        averaged_standard_devations)

    # quantitative evaluation
    if args.camera_path == "input_reconstruction":
        try:
            from PerceptualSimilarity import lpips
            perceptual_metric = lpips.LPIPS(net='alex')
        except Exception as error:
            # LPIPS is optional: fall back to a constant score below.
            print(
                "Perceptual LPIPS metric not found. Please see the README for installation instructions"
            )
            print("reason: " + repr(error), flush=True)
            perceptual_metric = None

        create_error_maps = True  # whether to write out error images instead of just computing scores

        naive_error_folder = os.path.join(output_folder, "naive_errors/")
        create_folder(naive_error_folder)
        ssim_error_folder = os.path.join(output_folder, "ssim_errors/")
        create_folder(ssim_error_folder)

        def to8b(x):
            # clip to [0,1] and scale to uint8 values in [0,255]
            return (255 * np.clip(x, 0, 1)).astype(np.uint8)

        def visualize_with_jet_color_scheme(image):
            from matplotlib import cm
            color_mapping = np.array([cm.jet(i)[:3] for i in range(256)])
            max_value = 1.0
            min_value = 0.0
            intermediate = np.clip(
                image, a_max=max_value, a_min=min_value
            ) / max_value  # cut off above max_value. result is normalized to [0,1]
            intermediate = (255. * intermediate).astype(
                'uint8')  # now contains int in [0,255]
            original_shape = intermediate.shape
            intermediate = color_mapping[intermediate.flatten()]
            intermediate = intermediate.reshape(original_shape + (3, ))
            return intermediate

        def numpy_to_pytorch(np_image):
            torch_image = 2 * torch.from_numpy(
                np_image) - 1  # height x width x 3. must be in [-1,+1]
            torch_image = torch_image.permute(2, 0,
                                              1)  # 3 x height x width
            return torch_image.unsqueeze(0)  # 1 x 3 x height x width

        mask = None
        scores = {}
        from skimage.metrics import structural_similarity as ssim
        for i, (groundtruth, generated) in enumerate(zip(images, rgbs)):

            # the mask is computed once, from the first groundtruth image, and reused
            if mask is None:  # undistortion leads to masked-out black pixels in groundtruth
                mask = (np.sum(groundtruth, axis=-1) == 0.)
            groundtruth[mask] = 0.
            generated[mask] = 0.

            # PSNR
            mse = np.mean((groundtruth - generated)**2)
            psnr = -10. * np.log10(mse)

            # SSIM
            # https://scikit-image.org/docs/dev/api/skimage.metrics.html#skimage.metrics.structural_similarity
            returned = ssim(groundtruth,
                            generated,
                            data_range=1.0,
                            multichannel=True,
                            gaussian_weights=True,
                            sigma=1.5,
                            use_sample_covariance=False,
                            full=create_error_maps)
            if create_error_maps:
                ssim_error, ssim_error_image = returned
            else:
                ssim_error = returned

            # perceptual metric
            # (named lpips_score so the imported lpips module is not shadowed)
            if perceptual_metric is None:
                lpips_score = 1.
            else:
                lpips_score = perceptual_metric.forward(
                    numpy_to_pytorch(groundtruth), numpy_to_pytorch(generated))
                lpips_score = float(
                    lpips_score.detach().reshape(1).numpy()[0])

            # plain floats keep json.dump working even for float32 inputs
            scores[i] = {
                "psnr": float(psnr),
                "ssim": float(ssim_error),
                "lpips": lpips_score,
            }

            if create_error_maps:
                # MSE-style
                error = np.linalg.norm(groundtruth - generated,
                                       axis=-1) / np.sqrt(
                                           1 + 1 + 1)  # height x width
                error *= 10.  # exaggerate error
                error = np.clip(error, 0.0, 1.0)
                error = to8b(visualize_with_jet_color_scheme(
                    error))  # height x width x 3. int values in [0,255]
                filename = os.path.join(naive_error_folder,
                                        'error_{:03d}.png'.format(i))
                imageio.imwrite(filename, error)

                # SSIM
                filename = os.path.join(ssim_error_folder,
                                        'error_{:03d}.png'.format(i))
                ssim_error_image = to8b(
                    visualize_with_jet_color_scheme(
                        1. - np.mean(ssim_error_image, axis=-1)))
                imageio.imwrite(filename, ssim_error_image)

        averaged_scores = {}
        averaged_scores["average_psnr"] = float(
            np.mean([score["psnr"] for score in scores.values()]))
        averaged_scores["average_ssim"] = float(
            np.mean([score["ssim"] for score in scores.values()]))
        averaged_scores["average_lpips"] = float(
            np.mean([score["lpips"] for score in scores.values()]))

        print(averaged_scores, flush=True)

        scores.update(averaged_scores)

        import json
        with open(os.path.join(output_folder, "scores.json"),
                  "w") as json_file:
            json.dump(scores, json_file, indent=4)
コード例 #8
0
def render_path(render_poses,
                styles,
                hwfs,
                chunk,
                render_kwargs,
                nfs=None,
                gt_imgs=None,
                alpha_cache=None,
                feature_cache=None,
                weights_cache=None,
                savedir=None,
                maximum=1000,
                get_cached=None,
                get_rgbs=False,
                verbose=True,
                cb=None,
                update_cb=None):
    """Render one image per pose/style pair and collect the results.

    If more than ``maximum`` poses are given, poses, styles, hwfs, gt_imgs
    and nfs are subsampled with stride ``len(render_poses) // maximum``.

    Args:
        render_poses: camera poses; each entry is sliced with ``[:3, :4]``.
        styles: per-pose style inputs passed to ``render``.
        hwfs: per-pose (height, width, focal) triples.
        chunk: forwarded to ``render``.
        render_kwargs: extra keyword arguments for ``render``; its
            ``'network_fine'`` entry gets ``get_cached`` assigned, and
            ``'near'``/``'far'`` are updated when ``nfs`` is given.
        nfs: optional per-pose (near, far) pairs.
        gt_imgs: optional ground-truth images used for PSNR and saved output.
        alpha_cache, feature_cache, weights_cache: optional per-image values
            forwarded to ``render`` (indexed by the position in the
            subsampled sequence).
        savedir: optional folder for per-image PNGs and ``log.txt``.
        maximum: upper bound used to derive the subsampling stride.
        get_cached: when truthy, raw alphas, features and weights from
            ``render`` are collected and returned as well.
        get_rgbs: unused in this function.
        verbose: wrap the iteration in a ``tqdm`` progress bar.
        cb: optional callback invoked with the image index before rendering.
        update_cb: optional callback invoked with the index and rendered RGB.

    Returns:
        ``(rgbs, disps, total_psnr)`` with stacked numpy arrays and the mean
        PSNR, extended by ``(alphas, features, weights)`` when ``get_cached``
        is truthy.
    """
    render_kwargs['network_fine'].get_cached = get_cached
    rgbs = []
    disps = []
    alphas = []
    features = []
    weights = []

    total_psnr = 0.
    total_psnr0 = 0

    N = len(render_poses)
    s = N // maximum if N > maximum else 1
    if gt_imgs is not None:
        gt_imgs = gt_imgs[::s].cuda()
    render_poses = render_poses[::s].cuda()
    styles = styles[::s].cuda()
    hwfs = hwfs[::s].cuda()
    if nfs is not None and s > 1:
        # keep the near/far bounds aligned with the subsampled poses; without
        # this, nfs[i] would belong to a different pose than render_poses[i]
        nfs = nfs[::s]

    iterator = zip(render_poses, styles)
    if verbose:
        iterator = tqdm(iterator, total=len(styles))

    for i, (c2w, style) in enumerate(iterator):
        if cb is not None:
            cb(i)
        H, W, focal = hwfs[i]
        H, W = int(H), int(W)
        if nfs is not None:
            near, far = nfs[i]
            render_kwargs.update({'near': near, 'far': far})

        alpha = alpha_cache[i] if alpha_cache is not None else None
        feature = feature_cache[i] if feature_cache is not None else None
        weight = weights_cache[i] if weights_cache is not None else None
        rgb, disp, acc, additional = render(H,
                                            W,
                                            focal,
                                            style=style,
                                            chunk=chunk,
                                            weights=weight,
                                            c2w=c2w[:3, :4],
                                            alpha=alpha,
                                            feature=feature,
                                            **render_kwargs)

        if 'rgb0' in additional:
            rgb0 = additional['rgb0']

        if gt_imgs is not None:
            gt_img = gt_imgs[i]
            mse_loss = img2mse(rgb, gt_img)
            psnr = mse2psnr(mse_loss)
            total_psnr += psnr.item()
            if 'rgb0' in additional:
                mse_loss0 = img2mse(rgb0, gt_img)
                psnr0 = mse2psnr(mse_loss0)
                total_psnr0 += psnr0.item()

        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())

        if update_cb:
            update_cb(i, rgbs[-1])

        if get_cached:
            alphas.append(additional['raw_alpha'])
            features.append(additional['features'])
            weights.append(additional['weights0'])

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            disp8 = to8b(to_disp_img(disps[-1]))
            imageio.imwrite(os.path.join(savedir, '{:04d}_rgb.png'.format(i)),
                            rgb8)
            imageio.imwrite(os.path.join(savedir, '{:04d}_disp.png'.format(i)),
                            disp8)
            if gt_imgs is not None:
                gt_img = to8b((gt_imgs[i]).cpu().numpy())
                imageio.imwrite(
                    os.path.join(savedir, '{:04d}_gt.png'.format(i)), gt_img)

    if gt_imgs is not None:
        # np.stack + from_numpy avoids the slow, warning-emitting
        # torch.tensor(list_of_ndarrays) conversion.
        torchvision.utils.save_image(
            torch.from_numpy(np.stack(rgbs, 0)).permute(0, 3, 1, 2),
            'rgbs.png')
        torchvision.utils.save_image(
            gt_imgs.detach().cpu().permute(0, 3, 1, 2), 'gt.png')
        msg = f'psnr0, psnr1, {total_psnr0/len(render_poses)}, {total_psnr/len(render_poses)}'
        # The log file lives in savedir; without one, only print the summary
        # instead of crashing in os.path.join(None, ...).
        if savedir is not None:
            with open(os.path.join(savedir, 'log.txt'), 'a+') as f:
                f.write(msg + '\n')
        print(msg)

    total_psnr = total_psnr / len(rgbs)
    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    if get_cached:
        render_kwargs['network_fine'].get_cached = None
        return rgbs, disps, total_psnr, alphas, features, weights
    else:
        return rgbs, disps, total_psnr
コード例 #9
0
ファイル: nerf_reptile.py プロジェクト: huy-ha/nerf
def log_qualitative_results(writer,
                            metalearning_iter,
                            scene_id,
                            save_dir,
                            render_poses,
                            poses,
                            i_split,
                            hwf,
                            chunk,
                            render_kwargs_test,
                            images,
                            N_importance,
                            timesteps,
                            use_viewdirs=True,
                            render_test_set=False):
    """Render qualitative results for one scene and log them.

    Three things happen, in order:

    1. Every view in ``i_split`` is rendered through ``render_path`` with the
       matching ground-truth images, and the outputs are written under
       ``<save_dir>/testset_iter/<iter>/scene<scene_id>``.
    2. One view is drawn at random from ``i_split``, rendered with
       ``render``, saved as a PNG under ``<save_dir>/tboard_val_imgs`` and
       logged to ``writer`` (rgb / disp / acc / holdout, plus the coarse
       outputs when ``N_importance > 0``).
    3. When ``render_test_set`` is true, a temporal sweep over the sorted
       unique timesteps is rendered from the first split pose and written as
       mp4 videos; with ``use_viewdirs`` an additional static-camera video
       is produced.

    Args:
        writer: Summary writer exposing ``add_image(tag, img, step,
            dataformats=...)``.
        metalearning_iter: Current iteration; used as the logging step and in
            output file names.
        scene_id: Scene identifier used in tags and file names.
        save_dir: Root directory for all written outputs.
        render_poses: Poses for the static-camera video; the first one
            supplies ``c2w_staticcam``.
        poses: Camera poses, indexed as ``poses[i, :3, :4]``.
        i_split: Indices of the views to evaluate.
        hwf: ``(H, W, focal)`` triple.
        chunk: Chunk size forwarded to the renderers.
        render_kwargs_test: Keyword arguments forwarded to the renderers.
            ``'c2w_staticcam'`` is set temporarily and reset to ``None``.
        images: Ground-truth images, indexed like ``poses``.
        N_importance: When positive, the ``rgb0`` / ``disp0`` / ``z_std``
            extras are logged as well.
        timesteps: Per-view timestep values, indexed like ``poses``.
        use_viewdirs: Whether to also render the static-camera video.
        render_test_set: Whether to render and save the videos at all.
    """
    H, W, focal = hwf
    testsavedir = os.path.join(
        save_dir, 'testset_iter/{:06d}/scene{}'.format(metalearning_iter,
                                                       scene_id))
    os.makedirs(testsavedir, exist_ok=True)

    # One timestep entry per pixel of every view in the split, built in a
    # single NumPy call instead of growing a Python list.
    split_timesteps = np.repeat(np.asarray([timesteps[i] for i in i_split]),
                                H * W,
                                axis=0)

    render_path(poses[i_split],
                hwf,
                split_timesteps,
                chunk,
                render_kwargs_test,
                gt_imgs=images[i_split],
                savedir=testsavedir)

    # Log a rendered validation view to Tensorboard
    img_i = np.random.choice(i_split)
    target = images[img_i]
    pose = poses[img_i, :3, :4]
    split_timestep = [timesteps[img_i]] * (H * W)

    rgb, disp, acc, extras = render(H,
                                    W,
                                    focal,
                                    split_timestep,
                                    chunk=chunk,
                                    c2w=pose,
                                    **render_kwargs_test)

    # Save out the validation image for Tensorboard-free monitoring
    testimgdir = os.path.join(save_dir, 'tboard_val_imgs')
    os.makedirs(testimgdir, exist_ok=True)
    imageio.imwrite(
        os.path.join(testimgdir,
                     '{:06d}_{}.png'.format(metalearning_iter, scene_id)),
        to8b(rgb))

    def log_image(tag, image):
        # All images share the same step and HWC layout.
        writer.add_image(f'{tag}/{scene_id}',
                         image,
                         metalearning_iter,
                         dataformats='HWC')

    log_image('rgb', np.squeeze(to8b(rgb)[tf.newaxis], axis=0))
    # Single-channel maps get a trailing channel axis so they fit 'HWC'.
    log_image('disp', np.squeeze(disp[tf.newaxis, ..., tf.newaxis], axis=0))
    log_image('acc', np.squeeze(acc[tf.newaxis, ..., tf.newaxis], axis=0))
    log_image('rgb_holdout', np.squeeze(target[tf.newaxis], axis=0))
    if N_importance > 0:
        log_image('rgb0',
                  np.squeeze(to8b(extras['rgb0'])[tf.newaxis], axis=0))
        log_image(
            'disp0',
            np.squeeze(extras['disp0'][tf.newaxis, ..., tf.newaxis], axis=0))
        log_image(
            'z_std',
            np.squeeze(extras['z_std'][tf.newaxis, ..., tf.newaxis], axis=0))

    # Save videos
    if render_test_set:
        sorted_timesteps = sorted(set(timesteps))
        rgbs, disps = render_timesteps(poses[i_split[0]],
                                       hwf,
                                       sorted_timesteps,
                                       chunk,
                                       render_kwargs_test,
                                       savedir=testsavedir)
        moviebase = os.path.join(
            save_dir, '{}_temporal_{:06d}_'.format(scene_id,
                                                   metalearning_iter))
        imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
        imageio.mimwrite(moviebase + 'disp.mp4',
                         to8b(disps / np.max(disps)),
                         fps=30,
                         quality=8)

        if use_viewdirs:
            render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
            try:
                rgbs_still, _ = render_path(render_poses, hwf,
                                            sorted_timesteps, chunk,
                                            render_kwargs_test)
            finally:
                # Always undo the override so a failed render does not leave
                # the caller's kwargs dict in static-camera mode.
                render_kwargs_test['c2w_staticcam'] = None
            imageio.mimwrite(moviebase + 'rgb_still.mp4',
                             to8b(rgbs_still),
                             fps=30,
                             quality=8)
コード例 #10
0
    def __init__(self,
                 instance,
                 config,
                 use_cached=True,
                 expname=None,
                 edit_type=None,
                 num_canvases=9,
                 shape_params='fusion_shape_branch',
                 color_params='color_branch',
                 randneg=8192,
                 device='cuda:0'):
        """Build the editing UI for one instance and render its initial views.

        Creates the paint canvases, action buttons and colour palette, loads
        the model and dataset for ``instance``, renders every loaded pose once
        and fills the per-view image/canvas widgets with the results.

        Args:
            instance: Index of the instance to edit; selects the style row
                ``styles[instance]`` and is forwarded to the loaders.
            config: Configuration handle forwarded to ``load_model``,
                ``load_dataset`` and ``load_config``.
            use_cached: Forwarded to ``load_dataset``.
            expname: Experiment name forwarded to the loaders; also the root
                of the save directory. Falls back to the config's ``expname``
                when falsy.
            edit_type: Stored on the instance as ``self.edit_type``.
            num_canvases: Stored and forwarded to ``load_dataset``.
            shape_params: Stored on the instance as ``self.shape_params``.
            color_params: Stored on the instance as ``self.color_params``.
            randneg: Stored on the instance as ``self.randneg``.
            device: Device that poses, intrinsics and the instance style are
                moved to. ``'cuda:0'`` also switches the default tensor type
                to CUDA floats; anything else uses CPU floats.
        """
        super().__init__(style=dict(
            border="3px solid gray", padding="8px", display="inline-block"))
        # set_default_tensor_type expects a tensor type (or its dotted name);
        # the plain string 'cpu' is not one and raises, so name the CPU float
        # tensor type explicitly for the non-CUDA case.
        torch.set_default_tensor_type('torch.cuda.FloatTensor' if device ==
                                      'cuda:0' else 'torch.FloatTensor')
        self.edit_type = edit_type
        self.instance = instance
        self.num_canvases = num_canvases
        self.shape_params = shape_params
        self.color_params = color_params
        self.size = IMG_SIZE
        self.randneg = randneg
        self.device = device
        self.msg_out = labwidget.Div()

        # Main editing canvas and the smaller canvas used for copy operations.
        self.editing_canvas = paintwidget.PaintWidget(
            image='', width=self.size * 3,
            height=self.size * 3).on('mask', self.change_mask)
        # -1 marks "no view selected yet".
        self.editing_canvas.index = -1
        self.copy_canvas = paintwidget.PaintWidget(
            image='', width=self.size * 2,
            height=self.size * 2).on('mask', self.copy)
        self.copy_mask = None
        inline = dict(display='inline', border="2px solid gray")

        # Action buttons.
        self.toggle_rgbs_disps_btn = labwidget.Button(
            'show depth', style=inline).on('click', self.toggle_rgb_disps)
        self.positive_mask_btn = labwidget.Button(
            self.pad('edit color'), style=inline).on('click',
                                                     self.positive_mask)
        self.addition_mask_btn = labwidget.Button(
            self.pad('add shape'), style=inline).on('click', self.add)
        self.sigma_mask_btn = labwidget.Button(
            self.pad('remove shape'), style=inline).on('click',
                                                       self.sigma_mask)
        self.color_from_btn = labwidget.Button(
            self.pad('transfer color'), style=inline).on('click',
                                                         self.color_from)
        self.shape_from_btn = labwidget.Button(
            self.pad('transfer shape'), style=inline).on('click',
                                                         self.shape_from)
        self.execute_btn = labwidget.Button(
            self.pad('execute'), style=inline).on('click', self.execute_edit)
        self.brushsize_textbox = labwidget.Textbox(
            5, desc='brushsize: ', size=3).on('value', self.change_brushsize)

        self.target = None
        self.use_color_cache = True

        # Colour palette: entry 0 is the background swatch loaded from disk,
        # followed by one flat swatch per entry of mean_colors.colors.
        self.color_style = dict(display='inline', border="2px solid white")
        trn = transforms.Compose(
            [transforms.Resize(32),
             transforms.ToTensor()])
        # Close the image file as soon as it has been converted to a tensor.
        with Image.open('bg.png') as bg_file:
            bg_img = trn(bg_file.convert('RGB'))
        # Rescale from [0, 1] to [-1, 1] before encoding as a URL.
        bg_img = renormalize.as_url(bg_img * 2 - 1)
        self.color_pallete = [
            labwidget.Image(src=bg_img,
                            style=self.color_style).on('click', self.set_color)
        ]
        self.color_pallete[-1].index = 0
        self.color_pallete[-1].color_type = 'bg'

        for color in mean_colors.colors.values():
            image = torch.zeros(3, 32, 32)
            for channel in range(3):
                image[channel, :, :] = color[channel]
            # Same [-1, 1] rescale as the background swatch, from 0-255 input.
            image = image / 255. * 2 - 1
            self.color_pallete.append(
                labwidget.Image(src=renormalize.as_url(image),
                                style=self.color_style).on(
                                    'click', self.set_color))
            self.color_pallete[-1].index = len(self.color_pallete) - 1
            self.color_pallete[-1].color_type = 'color'
            # TODO: Highlight the white box with black for clarity

        self.color = None
        self.mask_type = None
        self.real_canvas_array = []
        self.real_images_array = []
        self.positive_masks = []

        # Load the model and the data for this instance.
        train, test, optimizer, styles = load_model(instance,
                                                    config,
                                                    expname=expname)
        poses, hwfs, cache, args = load_dataset(instance,
                                                config,
                                                num_canvases=num_canvases,
                                                N_instances=styles.shape[0],
                                                expname=expname,
                                                use_cached=use_cached)
        self.parentdir = load_config(config).expname
        self.expname = expname if expname else self.parentdir
        self.savedir = os.path.join(self.expname, str(instance))
        os.makedirs(self.savedir, exist_ok=True)
        self.poses = poses.to(device)
        self.cache = cache
        self.chunk = args.chunk
        self.near = args.blender_near
        self.far = args.blender_far
        # One [near, far] pair per pose.
        self.nfs = [[self.near, self.far]] * self.poses.shape[0]
        self.hwfs = hwfs.to(device)
        # Parameters of a deep copy of the fine network, taken before any
        # editing touches the live one.
        self.old_fine_network = dict(
            copy.deepcopy(test['network_fine']).named_parameters())
        self.train_kwargs = train
        self.test_kwargs = test
        self.optimizer = None
        self.all_instance_styles = styles
        self.instance_style = styles[instance].unsqueeze(dim=0).to(device)

        if cache is not None:
            self.weights = cache['weights']
            self.alphas = cache['alphas']
            self.features = cache['features']
        else:
            self.weights = None
            self.alphas = None
            self.features = None

        self.trn = transforms.Compose(
            [transforms.Resize(128),
             transforms.ToTensor()])
        self.transfer_instances_array = [
            labwidget.Image(src='').on('click', self.change_target)
            for _ in range(12)
        ]
        self.addition_instances_array = [
            labwidget.Image(src='').on('click', self.change_target)
            for _ in range(12)
        ]

        # Render every pose once and create the thumbnail and paint canvas
        # for each resulting view.
        images, disps = self.render(self.poses,
                                    self.instance_style,
                                    verbose=False,
                                    get_disps=True)
        for i, image in enumerate(images):
            resized = F.interpolate(image.unsqueeze(dim=0),
                                    size=(self.size, self.size)).squeeze(dim=0)
            # NOTE(review): presumably to8b yields 0-255 values, hence the
            # division by 255 -- confirm against its definition.
            disp_img = torch.from_numpy(to8b(to_disp_img(
                disps[i]))).unsqueeze(dim=0) / 255.
            resized_disp = F.interpolate(disp_img.unsqueeze(dim=0),
                                         size=(self.size,
                                               self.size)).squeeze(dim=0)
            self.real_images_array.append(
                labwidget.Image(src=renormalize.as_url(resized)).on(
                    'click', self.set_editing_canvas))
            self.real_images_array[-1].index = i
            self.real_canvas_array.append(
                paintwidget.PaintWidget(image=renormalize.as_url(image),
                                        width=self.size * 3,
                                        height=self.size * 3).on(
                                            'mask', self.change_mask))
            canvas = self.real_canvas_array[-1]
            canvas.index = i
            canvas.negative_mask = ''
            canvas.resized_image = renormalize.as_url(resized)
            canvas.resized_disp = renormalize.as_url(resized_disp)
            canvas.disp = renormalize.as_url(disp_img)
            canvas.orig = renormalize.as_url(image)
            self.positive_masks.append(torch.zeros(image.shape).cpu())
        self.show_rgbs = True

        self.change_brushsize()
        self.editname_textbox = labwidget.Datalist(choices=self.saved_names(),
                                                   style=inline)
        self.save_btn = labwidget.Button('save',
                                         style=inline).on('click', self.save)
        self.load_btn = labwidget.Button('load',
                                         style=inline).on('click', self.load)
コード例 #11
0
                                                  1) * N_per_instance
        all_poses = all_poses[ps:pe]

        if args.video:
            all_poses, all_hwfs = generate_flythrough(all_poses[0].cpu(),
                                                      all_hwfs[0],
                                                      num_poses=100)

        nfs = [[writer.near, writer.far]] * all_poses.shape[0]
        styles = writer.instance_style.repeat((all_poses.shape[0], 1))

        with torch.no_grad():
            print(f'Saving samples in {savedir}')
            rgbs, disps, psnr = render_path(all_poses,
                                            styles,
                                            all_hwfs,
                                            writer.chunk,
                                            writer.test_kwargs,
                                            nfs=nfs,
                                            savedir=savedir,
                                            verbose=True)
            if args.video:
                imageio.mimwrite(os.path.join(savedir, 'video.mp4'),
                                 to8b(rgbs),
                                 fps=30,
                                 quality=8)
                imageio.mimwrite(os.path.join(savedir, 'disps.mp4'),
                                 to8b(disps / np.max(disps)),
                                 fps=30,
                                 quality=8)