Example no. 1
            target_view_mask = target_view_mask_init.clone()
            if not args.include_src:
                target_view_mask *= ~src_view_mask  # exclude source views from targets

            novel_view_idxs = target_view_mask.nonzero(as_tuple=False).reshape(-1)

            poses = poses[target_view_mask]  # (NV[-NS], 4, 4)

            all_rays = (
                util.gen_rays(
                    poses.reshape(-1, 4, 4),
                    W,
                    H,
                    focal * args.scale,
                    z_near,
                    z_far,
                    c=c * args.scale if c is not None else None,
                )
                .reshape(-1, 8)
                .to(device=device)
            )  # ((NV[-NS])*H*W, 8)

            poses = None
            focal = focal.to(device=device)

        rays_spl = torch.split(all_rays, args.ray_batch_size, dim=0)  # views into all_rays, not copies

        n_gen_views = len(novel_view_idxs)

        net.encode(
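            # The excerpt breaks off mid-call. A plausible completion, modeled
            # on the encode calls in Examples 3, 4, and 6 -- every argument
            # name below is an assumption, not the original source:
            images[src_view_mask].to(device=device).unsqueeze(0),  # (1, NS, 3, H, W)
            src_poses.unsqueeze(0),  # (1, NS, 4, 4)
            focal,
            c=c,
        )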
Example no. 2
        radius = args.radius

    # Use the 360-degree pose sequence from NeRF
    render_poses = torch.stack(
        [
            util.pose_spherical(angle, args.elevation, radius)
            for angle in np.linspace(-180, 180, args.num_views + 1)[:-1]  # drop the duplicate +180 endpoint
        ],
        0,
    )  # (NV, 4, 4)
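    # pose_spherical(angle, elevation, radius) is assumed to return a 4x4
    # camera-to-world matrix on a sphere looking at the origin, so the stack
    # above sweeps a full 360-degree orbit at a fixed elevation.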

render_rays = util.gen_rays(
    render_poses,
    W,
    H,
    focal * args.scale,
    z_near,
    z_far,
    c=c * args.scale if c is not None else None,
).to(device=device)
# (NV, H, W, 8)

focal = focal.to(device=device)

source = torch.tensor(list(map(int, args.source.split())), dtype=torch.long)
NS = len(source)
random_source = NS == 1 and source[0] == -1  # "-1" requests a random source view
assert not (source >= NV).any(), "source view index out of range"

if renderer.n_coarse < 64:
    # Ensure decent sampling resolution
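    # (The excerpt is cut off here; a plausible body -- the exact values are
    # an assumption -- raises the sample counts to a usable resolution:)
    renderer.n_coarse = 64
    renderer.n_fine = 128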
Example no. 3
    def vis_step(self, data, global_step, idx=None):
        if "images" not in data:
            return {}
        if idx is None:
            batch_idx = np.random.randint(0, data["images"].shape[0])
        else:
            print("Using fixed batch index:", idx)
            batch_idx = idx
        images = data["images"][batch_idx].to(device=device)  # (NV, 3, H, W)
        poses = data["poses"][batch_idx].to(device=device)  # (NV, 4, 4)
        focal = data["focal"][batch_idx:batch_idx + 1]  # (1)
        c = data.get("c")
        if c is not None:
            c = c[batch_idx:batch_idx + 1]  # (1)
        NV, _, H, W = images.shape
        cam_rays = util.gen_rays(poses,
                                 W,
                                 H,
                                 focal,
                                 self.z_near,
                                 self.z_far,
                                 c=c)  # (NV, H, W, 8)
        images_0to1 = images * 0.5 + 0.5  # (NV, 3, H, W)

        curr_nviews = nviews[torch.randint(0, len(nviews), (1, )).item()]
        views_src = np.sort(np.random.choice(NV, curr_nviews, replace=False))
        view_dest = np.random.randint(0, NV - curr_nviews)
        for vs in range(curr_nviews):
            view_dest += view_dest >= views_src[vs]
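        # Worked example of the shift above: with NV=4 and views_src=[1],
        # a draw of view_dest=1 from [0, NV - curr_nviews) becomes 2, so the
        # destination index skips each source view exactly once.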
        views_src = torch.from_numpy(views_src)

        # set renderer net to eval mode
        renderer.eval()
        source_views = (images_0to1[views_src].permute(
            0, 2, 3, 1).cpu().numpy().reshape(-1, H, W, 3))

        gt = images_0to1[view_dest].permute(1, 2,
                                            0).cpu().numpy().reshape(H, W, 3)
        with torch.no_grad():
            test_rays = cam_rays[view_dest]  # (H, W, 8)
            test_images = images[views_src]  # (NS, 3, H, W)
            net.encode(
                test_images.unsqueeze(0),
                poses[views_src].unsqueeze(0),
                focal.to(device=device),
                c=c.to(device=device) if c is not None else None,
            )
            test_rays = test_rays.reshape(1, H * W, -1)
            render_dict = DotMap(render_par(test_rays, want_weights=True))
            coarse = render_dict.coarse
            fine = render_dict.fine

            using_fine = len(fine) > 0

            alpha_coarse_np = coarse.weights[0].sum(
                dim=-1).cpu().numpy().reshape(H, W)
            rgb_coarse_np = coarse.rgb[0].cpu().numpy().reshape(H, W, 3)
            depth_coarse_np = coarse.depth[0].cpu().numpy().reshape(H, W)

            if using_fine:
                alpha_fine_np = fine.weights[0].sum(
                    dim=-1).cpu().numpy().reshape(H, W)
                depth_fine_np = fine.depth[0].cpu().numpy().reshape(H, W)
                rgb_fine_np = fine.rgb[0].cpu().numpy().reshape(H, W, 3)

        print("c rgb min {} max {}".format(rgb_coarse_np.min(),
                                           rgb_coarse_np.max()))
        print("c alpha min {}, max {}".format(alpha_coarse_np.min(),
                                              alpha_coarse_np.max()))
        alpha_coarse_cmap = util.cmap(alpha_coarse_np) / 255
        depth_coarse_cmap = util.cmap(depth_coarse_np) / 255
        vis_list = [
            *source_views,
            gt,
            depth_coarse_cmap,
            rgb_coarse_np,
            alpha_coarse_cmap,
        ]

        vis_coarse = np.hstack(vis_list)
        vis = vis_coarse

        if using_fine:
            print("f rgb min {} max {}".format(rgb_fine_np.min(),
                                               rgb_fine_np.max()))
            print("f alpha min {}, max {}".format(alpha_fine_np.min(),
                                                  alpha_fine_np.max()))
            depth_fine_cmap = util.cmap(depth_fine_np) / 255
            alpha_fine_cmap = util.cmap(alpha_fine_np) / 255
            vis_list = [
                *source_views,
                gt,
                depth_fine_cmap,
                rgb_fine_np,
                alpha_fine_cmap,
            ]

            vis_fine = np.hstack(vis_list)
            vis = np.vstack((vis_coarse, vis_fine))
            rgb_psnr = rgb_fine_np
        else:
            rgb_psnr = rgb_coarse_np

        psnr = util.psnr(rgb_psnr, gt)
        vals = {"psnr": psnr}
        print("psnr", psnr)

        # set the renderer network back to train mode
        renderer.train()
        return vis, vals
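For reference, a minimal PSNR helper consistent with how util.psnr is used above; a sketch assuming both images are scaled to [0, 1], not the library's actual implementation:

import numpy as np

def psnr(pred, gt):
    # Peak signal-to-noise ratio between two [0, 1]-scaled images.
    mse = np.mean((pred - gt) ** 2)
    return 10.0 * np.log10(1.0 / mse)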
Example no. 4
    def calc_losses(self, data, is_train=True, global_step=0):
        if "images" not in data:
            return {}
        all_images = data["images"].to(device=device)  # (SB, NV, 3, H, W)

        SB, NV, _, H, W = all_images.shape
        all_poses = data["poses"].to(device=device)  # (SB, NV, 4, 4)
        all_bboxes = data.get("bbox")  # (SB, NV, 4)  cmin rmin cmax rmax
        all_focals = data["focal"]  # (SB)
        all_c = data.get("c")  # (SB)

        if self.use_bbox and global_step >= args.no_bbox_step:
            self.use_bbox = False
            print(">>> Stopped using bbox sampling @ iter", global_step)

        if not is_train or not self.use_bbox:
            all_bboxes = None

        all_rgb_gt = []
        all_rays = []

        curr_nviews = nviews[torch.randint(0, len(nviews), ()).item()]
        if curr_nviews == 1:
            image_ord = torch.randint(0, NV, (SB, 1))
        else:
            image_ord = torch.empty((SB, curr_nviews), dtype=torch.long)
        for obj_idx in range(SB):
            if all_bboxes is not None:
                bboxes = all_bboxes[obj_idx]
            images = all_images[obj_idx]  # (NV, 3, H, W)
            poses = all_poses[obj_idx]  # (NV, 4, 4)
            focal = all_focals[obj_idx]
            c = None
            if "c" in data:
                c = data["c"][obj_idx]
            if curr_nviews > 1:
                # Somewhat inefficient, don't know better way
                image_ord[obj_idx] = torch.from_numpy(
                    np.random.choice(NV, curr_nviews, replace=False))
            images_0to1 = images * 0.5 + 0.5

            cam_rays = util.gen_rays(poses,
                                     W,
                                     H,
                                     focal,
                                     self.z_near,
                                     self.z_far,
                                     c=c)  # (NV, H, W, 8)
            rgb_gt_all = (images_0to1.permute(
                0, 2, 3, 1).contiguous().reshape(-1, 3))  # (NV*H*W, 3)

            if all_bboxes is not None:
                pix = util.bbox_sample(bboxes, args.ray_batch_size)
                # Flatten (view, row, col) indices into an index over (NV*H*W,)
                pix_inds = pix[..., 0] * H * W + pix[..., 1] * W + pix[..., 2]
            else:
                pix_inds = torch.randint(0, NV * H * W,
                                         (args.ray_batch_size, ))

            rgb_gt = rgb_gt_all[pix_inds]  # (ray_batch_size, 3)
            rays = cam_rays.view(-1, cam_rays.shape[-1])[pix_inds].to(
                device=device)  # (ray_batch_size, 8)

            all_rgb_gt.append(rgb_gt)
            all_rays.append(rays)

        all_rgb_gt = torch.stack(all_rgb_gt)  # (SB, ray_batch_size, 3)
        all_rays = torch.stack(all_rays)  # (SB, ray_batch_size, 8)

        image_ord = image_ord.to(device)
        src_images = util.batched_index_select_nd(
            all_images, image_ord)  # (SB, NS, 3, H, W)
        src_poses = util.batched_index_select_nd(all_poses,
                                                 image_ord)  # (SB, NS, 4, 4)

        all_bboxes = all_poses = all_images = None  # drop references to free memory

        net.encode(
            src_images,
            src_poses,
            all_focals.to(device=device),
            c=all_c.to(device=device) if all_c is not None else None,
        )

        render_dict = DotMap(render_par(
            all_rays,
            want_weights=True,
        ))
        coarse = render_dict.coarse
        fine = render_dict.fine
        using_fine = len(fine) > 0

        loss_dict = {}

        rgb_loss = self.rgb_coarse_crit(coarse.rgb, all_rgb_gt)
        loss_dict["rc"] = rgb_loss.item() * self.lambda_coarse
        if using_fine:
            fine_loss = self.rgb_fine_crit(fine.rgb, all_rgb_gt)
            rgb_loss = rgb_loss * self.lambda_coarse + fine_loss * self.lambda_fine
            loss_dict["rf"] = fine_loss.item() * self.lambda_fine

        loss = rgb_loss
        if is_train:
            loss.backward()
        loss_dict["t"] = loss.item()

        return loss_dict
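util.batched_index_select_nd gathers per-object rows from a batched tensor: here it picks the NS source views out of NV for each of the SB objects. A minimal sketch consistent with its use above (shapes inferred from the comments; this is not the library's actual implementation):

import torch

def batched_index_select_nd(t, inds):
    # t: (SB, NV, ...), inds: (SB, K) -> (SB, K, ...)
    dummy = inds.reshape(*inds.shape, *([1] * (t.dim() - 2)))
    dummy = dummy.expand(*inds.shape, *t.shape[2:])
    return t.gather(1, dummy)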
Example no. 5
        sigma = self.img[tuple(xyz)]
        # Anything out of bounds set back to 0
        sigma[mask] = 0.0
        sigma = sigma.reshape(1, -1, 1)
        rgb = torch.ones(sigma.size(0), sigma.size(1), 3).to(device)  # constant white color
        return torch.cat((rgb, sigma), dim=-1).to(device)  # (1, N, 4): RGB + density


image = CTImage(torch.tensor(arr).to(device))
renderer = NeRFRenderer(
    n_coarse=64, n_fine=32, n_fine_depth=16, depth_std=0.01, sched=[], white_bkgd=False, eval_batch_size=50000
).to(device=device)
render_par = renderer.bind_parallel(image, [0], simple_output=True).eval()

render_rays = util.gen_rays(render_poses, W, H, focal, z_near, z_far).to(device=device)

all_rgb_fine = []
for rays in tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)):
    rgb, _depth = render_par(rays[None])
    all_rgb_fine.append(rgb[0])
_depth = None
rgb_fine = torch.cat(all_rgb_fine)
frames = (rgb_fine.view(num_views, H, W, 3).cpu().numpy() * 255).astype(
    np.uint8
)

im_name = "raw_data"

frames_dir_name = os.path.join(output, im_name + "_frames")
os.makedirs(frames_dir_name, exist_ok=True)
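All rays produced by util.gen_rays in these snippets are 8-vectors. Given a tensor rays of shape (..., 8), the layout implied by the gen_rays(poses, W, H, focal, z_near, z_far) signature is assumed to unpack as:

origins = rays[..., 0:3]  # per-pixel camera origins
viewdirs = rays[..., 3:6]  # per-pixel ray directions
near = rays[..., 6:7]  # near bound, from z_near
far = rays[..., 7:8]  # far bound, from z_far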
Example no. 6
            poses = data["poses"][0]  # (NV, 4, 4)
            pri_poses = poses[src_view_mask]  # (NS, 4, 4)
            pri_poses = pri_poses.to(device=device)
            if not args.include_src:
                target_view_mask *= ~src_view_mask

            novel_view_idxs = target_view_mask.nonzero(
                as_tuple=False).reshape(-1)
            poses = poses[target_view_mask]  # (NV[-NS], 4, 4)

            all_rays = (util.gen_rays(
                poses.reshape(-1, 4, 4),
                W,
                H,
                focal * args.scale,
                z_near,
                z_far,
                c=c * args.scale if c is not None else None).reshape(
                    -1, 8).to(device=device))  # ((NV-1)*H*W, 8)
            rays_spl = torch.split(all_rays, args.ray_batch_size,
                                   dim=0)  # views into all_rays, not copies
            poses = _poses_a = _poses_b = None
            focal = focal.to(device=device)

        n_gen_views = len(novel_view_idxs)

        util.get_module(renderer).net.encode(
            images[src_view_mask].to(device=device).unsqueeze(0),
            pri_poses.unsqueeze(0),
            focal, (z_near, z_far),
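The encode call above is cut off at the excerpt boundary. The masking logic earlier in the snippet (shared with Example 1) can be checked on a toy case; a runnable sketch assuming NV=4:

import torch

src_view_mask = torch.tensor([True, False, False, False])
target_view_mask = torch.ones(4, dtype=torch.bool)
target_view_mask *= ~src_view_mask  # exclude source views
novel_view_idxs = target_view_mask.nonzero(as_tuple=False).reshape(-1)
print(novel_view_idxs.tolist())  # [1, 2, 3]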
Example no. 7
        images_0to1 = images * 0.5 + 0.5  # (SB, NV, 3, H, W)

        SB, NV, _, H, W = images.shape

        if random_source:
            src_view = torch.randint(0, NV, (SB, 1))
        else:
            src_view = source.unsqueeze(0).expand(SB, -1)

        dest_view = torch.randint(0, NV - NS, (SB, 1))
        # Shift past each source index (assumed ascending) so that dest_view
        # never collides with a source view
        for i in range(NS):
            dest_view += dest_view >= src_view[:, i:i + 1]

        dest_poses = util.batched_index_select_nd(poses, dest_view)
        all_rays = util.gen_rays(dest_poses.reshape(-1, 4, 4), W, H, focal,
                                 z_near, z_far).reshape(SB, -1, 8)

        pri_images = util.batched_index_select_nd(
            images, src_view)  # (SB, NS, 3, H, W)
        pri_poses = util.batched_index_select_nd(poses,
                                                 src_view)  # (SB, NS, 4, 4)

        net.encode(
            pri_images.to(device=device),
            pri_poses.to(device=device),
            focal.to(device=device),
        )

        rgb_fine, _depth = render_par(all_rays.to(device=device))
        _depth = None
        rgb_fine = rgb_fine.reshape(SB, H, W, 3).cpu().numpy()
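The batched destination-view shift above generalizes the scalar version in Example 3. A runnable toy check (NV, NS, SB, and the source set are assumed values; the shift requires the sources in ascending order):

import torch

NV, NS, SB = 5, 2, 3
src_view = torch.tensor([1, 3]).unsqueeze(0).expand(SB, -1)  # (SB, NS), sorted
dest_view = torch.randint(0, NV - NS, (SB, 1))
for i in range(NS):
    dest_view += dest_view >= src_view[:, i:i + 1]
# Every destination index now lands outside the source set.
assert not (dest_view == src_view).any()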