# NOTE(review): the statements from `mask = ...` through `return` are the tail
# of a method (they read `self.img` and end in `return`) whose `def` line lies
# before this view; in the real file they belong at that method's body
# indentation. Everything after the `return` is top-level script code.
#
# Method tail: read per-point values from the volume `self.img` at the integer
# coordinates in `xyz`, zeroing points that fall outside the volume.
# Presumably xyz is a (3, N) integer index tensor, and x_lim/y_lim/z_lim are
# defined outside this view -- TODO confirm.
# NOTE(review): `>` lets a coordinate equal to x_lim/y_lim/z_lim through. That
# is only in range if the limits are the largest valid indices (size - 1); if
# they are the axis sizes, these comparisons should be `>=`.
# True for every column whose x, y or z coordinate is negative or above its limit.
mask = (xyz[0,:]<0) | (xyz[1,:]<0) | (xyz[2,:]<0) | (xyz[0,:]>x_lim) | (xyz[1,:]>y_lim) | (xyz[2,:]>z_lim)
# Redirect flagged columns to index 0 so the lookup below stays in range.
# This assignment is in place: it modifies xyz itself.
xyz[:,mask] = 0
# tuple(xyz) yields one index array per row of xyz, i.e. one lookup per column.
sigma = self.img[tuple(xyz)]
# Anything out of bounds set back to 0 (overrides the value read at index 0).
sigma[mask] = 0.0
sigma = sigma.reshape(1, -1, 1)
# Constant colour of 1 for every point, shape (1, sigma.size(1), 3).
rgb = torch.ones(sigma.size(0), sigma.size(1), 3).to(device)
# Last dim holds the 3 colour channels followed by the looked-up value.
return torch.cat((rgb, sigma), dim=-1).to(device)

# --- Render frames from the CT volume ---
image = CTImage(torch.tensor(arr).to(device))
renderer = NeRFRenderer(
    n_coarse=64, n_fine=32, n_fine_depth=16, depth_std=0.01, sched=[], white_bkgd=False, eval_batch_size=50000
).to(device=device)
# Bind the renderer to the volume on device list [0] and switch to eval mode.
render_par = renderer.bind_parallel(image, [0], simple_output=True).eval()
render_rays = util.gen_rays(render_poses, W, H, focal, z_near, z_far).to(device=device)
# NOTE(review): this loop is not under torch.no_grad(); .eval() does not
# disable autograd. If any tensor in the render path requires grad, every
# chunk's graph stays alive in all_rgb_fine -- consider wrapping in no_grad.
all_rgb_fine = []
# Flatten to 8 values per ray and render in chunks of 80000 rays.
for rays in tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)):
    # rays[None] adds a leading batch dim of 1; rgb[0] strips it again.
    rgb, _depth = render_par(rays[None])
    all_rgb_fine.append(rgb[0])
# Drop the reference to the last chunk's depth output; only colours are kept.
_depth = None
rgb_fine = torch.cat(all_rgb_fine)
# Assumes rgb_fine holds num_views * H * W colours in [0, 1] -- TODO confirm.
# astype(np.uint8) truncates rather than rounds and does not clip, so values
# outside [0, 1] would not map cleanly into 0..255.
frames = (rgb_fine.view(num_views, H, W, 3).cpu().numpy() * 255).astype(
    np.uint8
)
im_name = "raw_data"
focal = torch.tensor(focal, dtype=torch.float32, device=device)

# --- Render the training (ground-truth) projections of the CT volume ---
# NOTE: this always re-renders. A cached path that loaded
# os.path.join(output, f'training_ct_{H}.pkl') with torch.load was disabled;
# restore it here if re-rendering becomes a bottleneck.
image = CTImage(torch.tensor(arr).to(device))
renderer = NeRFRenderer(
    n_coarse=512,
    depth_std=0.01,
    sched=[],
    white_bkgd=False,
    composite_x_ray=False,
    eval_batch_size=50000,
    lindisp=True,
).to(device=device)
# Bind the renderer to the volume on device list [0] and switch to eval mode.
render_par = renderer.bind_parallel(image, [0], simple_output=True).eval()
render_rays = util.gen_rays_variable_sensor(
    render_poses, width_pixels, height_pixels, width, height, focal, z_near, z_far
).to(device)

# This render produces fixed training targets, so nothing needs to be
# differentiated through it. .eval() alone does not disable autograd; without
# no_grad, any grad-requiring tensor in the render path would keep every
# chunk's autograd graph alive while the chunks accumulate in all_rgb_fine.
all_rgb_fine = []
with torch.no_grad():
    # Flatten to 8 values per ray and render in chunks of 80000 rays.
    for rays in tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)):
        # rays[None] adds a leading batch dim of 1; rgb[0] strips it again.
        rgb, _depth = render_par(rays[None])
        all_rgb_fine.append(rgb[0])
# Drop the reference to the last chunk's depth output; only colours are kept.
_depth = None
rgb_fine = torch.cat(all_rgb_fine)
# Value range of the rendered targets (0-dim tensors).
ct_gt_min = rgb_fine.min()
ct_gt_max = rgb_fine.max()