def z_to_pcl_CC(z, camera):
    viewport = np.array(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = W / H

    fovy = camera['fovy']
    focal_length = camera['focal_length']
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    ##### Find (X, Y) in the Camera's view frustum
    # Force the caller to set the z coordinate with the correct sign
    Z = -torch.nn.functional.relu(-z)

    x, y = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
    x *= w / 2
    y *= h / 2

    x = tch_var_f(x.ravel())
    y = tch_var_f(y.ravel())

    X = -Z * x / focal_length
    Y = -Z * y / focal_length

    return torch.stack((X, Y, Z), dim=1)
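# Usage sketch (not from the original code): convert a flat depth buffer into a
# camera-space point cloud with z_to_pcl_CC. The minimal camera dict below is a
# hypothetical example modeled on SCENE_BASIC; only the keys read by
# z_to_pcl_CC are filled in.
def _example_z_to_pcl_CC():
    camera = {'viewport': [0, 0, 4, 4],
              'fovy': np.deg2rad(90.),
              'focal_length': 1.}
    # 4x4 depth buffer; the camera looks down -z, so depths are non-positive
    z = tch_var_f(-np.ones(16))
    pcl = z_to_pcl_CC(z, camera)
    print(pcl.shape)  # (16, 3): one (X, Y, Z) point per pixel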
def lf_renderer(pos, normal, lfnet, num_samples=20):
    """This is a simpler version of lf_renderer_v0 where the same direction samples
    are used for all surfels. The samples are on a uniform sphere and so this
    renderer also supports transmissive media.

    Args:
        pos: Surfel positions, shape [..., 3].
        normal: Surfel normals, shape [..., 3].
        lfnet: Light-field network mapping (position, direction) to radiance.
        num_samples: Number of direction samples shared by all surfels.

    Returns:
        Rendered image with the same spatial shape as `pos`.
    """
    pos_all = pos.reshape((-1, 3))
    normal_all = tch_var_f(normal.reshape((-1, 3)))
    spherical_samples = uniform_sample_sphere(radius=1.0, num_samples=num_samples)
    inp = tch_var_f(
        np.concatenate((np.tile(pos_all[:, np.newaxis, :], (1, num_samples, 1)),
                        np.tile(spherical_samples[np.newaxis, :, :],
                                (pos_all.shape[0], 1, 1))),
                       axis=-1))
    Li = lfnet(inp)

    # Cosine foreshortening; directions below the surface do not contribute
    cos_theta = torch.sum(inp[:, :, 3:6] * normal_all[:, np.newaxis, :], dim=-1)
    nonzero_mask = (cos_theta > 0).float()
    pos_cos_theta = cos_theta * nonzero_mask

    im = torch.sum(pos_cos_theta[..., np.newaxis] * Li, dim=1).reshape(pos.shape)
    return im
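# Usage sketch (assumption: a stub network standing in for a trained lfnet).
# The stub returns constant white radiance for every (position, direction)
# query, so the output image is just the integrated cosine foreshortening.
def _example_lf_renderer():
    H, W = 2, 2
    pos = np.zeros((H, W, 3))
    normal = np.tile([0., 0., 1.], (H, W, 1))
    lfnet = lambda inp: torch.ones(inp.shape[0], inp.shape[1], 3,
                                   device=inp.device)  # hypothetical stub
    im = lf_renderer(pos, normal, lfnet, num_samples=8)
    print(im.shape)  # same spatial shape as pos: (2, 2, 3)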
def projection_renderer_differentiable(surfels, rgb, camera, rotated_image=None, blur_size=0.15):
    """Project surfels given in world coordinates to the camera's projection plane
    in a way that is differentiable w.r.t. depth. This is achieved by interpolating
    the surfel values using a Gaussian filter.

    Args:
        surfels: [batch_size, num_surfels, pos]
        rgb: [batch_size, num_surfels, D-channel data] or [batch_size, H, W, D-channel data]
        camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...],
                  'up': [num_batches,...], 'viewport': [0, 0, W, H],
                  'fovy': <radians>}]
        rotated_image: [batch_size, num_surfels, D-channel data] or
            [batch_size, H, W, D-channel data]. Image to mix in with the result
            of the rotation.
        blur_size: Controls the std (sigma) of the Gaussian used for filtering.
            As a rule of thumb, surfels in a radius of 3*sigma around a pixel
            contribute to that pixel in the final image.

    Returns:
        RGB image of dimensions [batch_size, H, W, 3] from projected surfels
    """
    px_idx, px_coord = project_image_coordinates(surfels, camera)
    viewport = make_list2np(camera['viewport'])
    W = int(viewport[2] - viewport[0])
    H = int(viewport[3] - viewport[1])

    rgb_reshaped = rgb.view(rgb.size(0), -1, rgb.size(-1))

    # Perform a weighted average of points surrounding a pixel using a Gaussian filter
    # Very similar to the idea in this paper: https://arxiv.org/pdf/1810.09381.pdf
    x, y = np.meshgrid(np.linspace(0, W - 1, W) + 0.5,
                       np.linspace(0, H - 1, H) + 0.5)
    x, y = (tch_var_f(x.ravel()).repeat(surfels.size(0), 1),
            tch_var_f(y.ravel()).repeat(surfels.size(0), 1))
    x, y = x.unsqueeze(-1), y.unsqueeze(-1)
    xp, yp = px_coord[..., 0].unsqueeze(-2), px_coord[..., 1].unsqueeze(-2)

    sigma = blur_size * rgb.size(-2) / 6
    scale = torch.exp((-(xp - x)**2 - (yp - y)**2) / (2 * sigma**2))
    mask = scale.sum(-1)

    if rotated_image is not None:
        rotated_image = rotated_image.view(*rgb_reshaped.size())
        # out = (rotated_image_weight * rotated_image + torch.sum(scale.unsqueeze(-1) * rgb_reshaped.unsqueeze(-3), -2)) / (scale.sum(-1) + rotated_image_weight + 1e-10).unsqueeze(-1)
        # Note: the mask is broadcast over the channel dimension
        out = (torch.sum(scale.unsqueeze(-1) * rgb_reshaped.unsqueeze(-3), -2) +
               rotated_image * (1 - mask).unsqueeze(-1))
    else:
        out = (torch.sum(scale.unsqueeze(-1) * rgb_reshaped.unsqueeze(-3), -2) /
               (mask + 1e-10).unsqueeze(-1))

    return out.view(*rgb.size()), mask.view(*rgb.size()[:-1], 1)
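# A minimal numeric sketch of the weighting used above: each pixel center
# accumulates surfel colors with weight exp(-((xp - x)^2 + (yp - y)^2) / (2 * sigma^2)),
# so gradients flow back to the projected surfel coordinates (xp, yp). The
# values below are made up for illustration.
def _example_gaussian_weight():
    sigma = 1.0
    x, y = 2.5, 3.5     # pixel center
    xp, yp = 2.9, 3.1   # projected surfel coordinate
    weight = np.exp(-((xp - x) ** 2 + (yp - y) ** 2) / (2 * sigma ** 2))
    print(weight)  # this surfel's (unnormalized) contribution to the pixel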
def perspective_RH_NO(fovy, aspect, near, far):
    """Right-handed camera with all coords mapped to [-1, 1]."""
    mat_00, mat_11, mat_22, mat_23 = perspective_NO_params(fovy, aspect, near, far)
    return tch_var_f([[mat_00, 0, 0, 0],
                      [0, mat_11, 0, 0],
                      [0, 0, -mat_22, mat_23],
                      [0, 0, -1, 0]])
def inv_perspective_RH_NO(fovy, aspect, near, far):
    """Inverse perspective for right-handed camera with all coords mapped from [-1, 1]."""
    mat_00, mat_11, mat_22, mat_23 = perspective_NO_params(fovy, aspect, near, far)
    return tch_var_f([[1 / mat_00, 0, 0, 0],
                      [0, 1 / mat_11, 0, 0],
                      [0, 0, 0, -1],
                      [0, 0, 1 / mat_23, -mat_22 / mat_23]])
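# Sanity-check sketch: the analytic inverse above should undo perspective_RH_NO,
# i.e. their product should be (close to) the 4x4 identity. Parameter values
# are arbitrary.
def _check_perspective_RH_NO_inverse():
    fovy, aspect, near, far = np.deg2rad(45.), 1.0, 0.1, 1000.0
    M = perspective_RH_NO(fovy, aspect, near, far)
    Minv = inv_perspective_RH_NO(fovy, aspect, near, far)
    print(torch.mm(M, Minv))  # expect ~identity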
def test_NEstNet():
    import numpy as np
    pos = tch_var_f(list(np.random.rand(1, 3, 5, 5)))
    y = NEstNetV0(sph=False).cuda()(pos)
    print(y.shape, y.norm(dim=1))
    y = NEstNetAffine(kernel_size=3).cuda()(pos)
    print(y.shape, y.norm(dim=1))
def test_LFNet():
    from diffrend.torch.utils import tch_var_f
    import numpy as np
    pos = tch_var_f(list(np.random.rand(1, 10, 8)))
    y = LFNetV0(in_ch=8, out_ch=3).cuda()(pos)
    print(y.shape, y)
def _init_rays(self, camera):
    viewport = np.array(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = W / H

    x, y = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
    n_pixels = x.size

    fovy = np.array(camera['fovy'])
    focal_length = np.array(camera['focal_length'])
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    x *= w / 2
    y *= h / 2

    x = tch_var_f(x.ravel())
    y = tch_var_f(y.ravel())

    eye = camera['eye'][:3]
    at = camera['at'][:3]
    up = camera['up'][:3]
    proj_type = camera['proj_type']
    if proj_type == 'ortho' or proj_type == 'orthographic':
        ray_dir = normalize(at - eye)[:, np.newaxis]
        ray_orig = torch.stack((x, y, tch_var_f(np.zeros(n_pixels)),
                                tch_var_f(np.ones(n_pixels))), dim=0)
        # inv_view_matrix = lookat_inv(eye=eye, at=at, up=up)
        # ray_orig = torch.mm(inv_view_matrix, ray_orig)
        # ray_orig = (ray_orig[:3] / ray_orig[3][np.newaxis, :]).permute(1, 0)
    elif proj_type == 'persp' or proj_type == 'perspective':
        ray_orig = eye[np.newaxis, :]
        ray_dir = torch.stack((x, y, tch_var_f(-np.ones(n_pixels) * focal_length)), dim=0)
        # inv_view_matrix = lookat_rot_inv(eye=eye, at=at, up=up)
        # ray_dir = torch.mm(inv_view_matrix, ray_dir)

        # normalize ray direction
        ray_dir /= torch.sqrt(torch.sum(ray_dir ** 2, dim=0))
    else:
        raise ValueError("Invalid projection type")

    self.ray_orig = ray_orig
    self.ray_dir = ray_dir
    self.H = H
    self.W = W
    return ray_orig, ray_dir, H, W
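# Standalone sketch of the perspective branch above, without the class state:
# pixel rays in camera coordinates go from the eye through points
# (x, y, -focal_length) on the projection plane. Pure numpy for clarity.
def _example_perspective_rays(W=4, H=4, fovy=np.deg2rad(90.), focal_length=1.):
    aspect_ratio = W / H
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio
    x, y = np.meshgrid(np.linspace(-1, 1, W) * w / 2,
                       np.linspace(1, -1, H) * h / 2)
    ray_dir = np.stack((x.ravel(), y.ravel(), -np.ones(W * H) * focal_length))
    ray_dir /= np.sqrt(np.sum(ray_dir ** 2, axis=0))  # normalize each column
    return ray_dir  # 3 x (W*H) unit directions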
def project_image_coordinates(surfels, camera):
    """Project surfels given in world coordinates to the camera's projection plane.

    Args:
        surfels: [batch_size, pos]
        camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...],
                  'up': [num_batches,...], 'viewport': [0, 0, W, H],
                  'fovy': <radians>}]

    Returns:
        Image of destination indices of dimensions [batch_size, H*W].
        Note that the range of possible coordinates is restricted to be between
        0 and W*H (inclusive). This is inclusive because we use the last index
        as a "dump" for any index that falls outside of the camera's field of view.
    """
    surfels_plane = project_surfels(surfels, camera)

    # Rasterize
    viewport = make_list2np(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = float(W) / float(H)

    fovy = make_list2np(camera['fovy'])
    focal_length = make_list2np(camera['focal_length'])
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    px_coord = torch.zeros_like(surfels_plane)
    px_coord[..., 2] = surfels_plane[..., 2]  # Make sure to also transmit the new depth
    px_coord[..., :2] = (surfels_plane[..., :2] *
                         tch_var_f([-(W - 1) / w, (H - 1) / h]).unsqueeze(-2) +
                         tch_var_f([W / 2., H / 2.]).unsqueeze(-2))
    px_coord_idx = torch.round(px_coord - 0.5).long()
    px_idx = px_coord_idx[..., 1] * W + px_coord_idx[..., 0]

    # Index used if the indices are out of bounds of the camera
    max_idx = W * H
    max_idx_tensor = tch_var_l([max_idx])

    # Map out-of-bounds pixels to the last (extra) index
    mask = ((px_coord_idx[..., 1] < 0) | (px_coord_idx[..., 0] < 0) |
            (px_coord_idx[..., 1] >= H) | (px_coord_idx[..., 0] >= W))
    px_idx = torch.where(mask, max_idx_tensor, px_idx)

    return px_idx, px_coord
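# Usage sketch (single batch assumed): scatter per-surfel colors into a flat
# image buffer using px_idx. The buffer has one extra "dump" cell so that
# out-of-view surfels (mapped to index W*H above) are accumulated there and
# then discarded.
def _example_scatter_surfels(surfels, rgb, camera):
    viewport = make_list2np(camera['viewport'])
    W = int(viewport[2] - viewport[0])
    H = int(viewport[3] - viewport[1])
    px_idx, _ = project_image_coordinates(surfels, camera)
    buf = torch.zeros(W * H + 1, 3, device=rgb.device)  # +1 dump cell
    buf.index_add_(0, px_idx.view(-1), rgb.view(-1, 3))
    return buf[:-1].view(H, W, 3)  # drop the dump cell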
def test_render_splat_NDC_0():
    fovy = np.deg2rad(45)
    aspect_ratio = 1
    near = 0.1
    far = 1000
    M = perspective(fovy, aspect_ratio, near, far)
    Minv = inv_perspective(fovy, aspect_ratio, near, far)

    pos_NDC = tch_var_f([[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 0.0, 1.0]])
    normals_SLC = tch_var_f([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
    num_objects = pos_NDC.size()[0]

    # Transform params to the camera's view frustum
    if pos_NDC.size()[-1] == 3:
        pos_NDC = torch.cat((pos_NDC, tch_var_f(np.ones((num_objects, 1)))), dim=1)
    pos_CC = torch.matmul(pos_NDC, Minv.transpose(1, 0))
    pos_CC = pos_CC / pos_CC[..., 3][:, np.newaxis]
    pixel_dist = norm_p(pos_CC[..., :3])
def lf_renderer_v0(pos, normal, lfnet, num_samples=10):
    """Per-surfel version of lf_renderer: directions are sampled separately for
    each surfel on the hemisphere around its normal."""
    pos_all = pos.reshape((-1, 3))
    normal_all = normal.reshape((-1, 3))

    pixel_colors = []
    for idx in range(pos_all.shape[0]):
        dir_sample = uniform_sample_sphere(radius=1.0, num_samples=num_samples,
                                           axis=normal_all[idx], angle=np.pi / 2)
        inp = tch_var_f(np.concatenate((np.tile(pos_all[idx], (num_samples, 1)),
                                        dir_sample), axis=-1))
        Li = lfnet(inp)
        cos_theta = torch.sum(inp[:, 3:6] * tch_var_f(normal_all[idx]), dim=-1)
        rgb = torch.sum(cos_theta[:, np.newaxis] * Li, dim=0)
        pixel_colors.append(rgb)

    im = torch.cat(pixel_colors, dim=0).reshape(pos.shape)
    return im
def batch_render_random_camera(filename, cam_dist, num_views, width, height,
                               fovy, focal_length, theta_range=None, phi_range=None,
                               axis=None, angle=None, cam_pos=None, cam_lookat=None,
                               double_sided=False, use_quartic=False, b_shadow=True,
                               tile_size=None, save_image_queue=None):
    rendering_time = []

    obj = load_model(filename)
    # normalize the vertices
    v = obj['v']
    axis_range = np.max(v, axis=0) - np.min(v, axis=0)
    v = (v - np.mean(v, axis=0)) / max(axis_range)  # Normalize to make the largest spread 1
    obj['v'] = v

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length

    mesh = obj_to_triangle_spec(obj)
    faces = mesh['face']
    normals = mesh['normal']
    num_tri = faces.shape[0]

    if 'disk' in scene['objects']:
        del scene['objects']['disk']
    scene['objects'].update({'triangle': {'face': None, 'normal': None, 'material_idx': None}})
    scene['objects']['triangle']['face'] = tch_var_f(faces.tolist())
    scene['objects']['triangle']['normal'] = tch_var_f(normals.tolist())
    scene['objects']['triangle']['material_idx'] = tch_var_l(np.zeros(num_tri, dtype=int).tolist())
    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist, num_samples=num_views,
                                        axis=axis, angle=angle,
                                        theta_range=theta_range, phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else np.mean(v, axis=0)
    scene['camera']['at'] = tch_var_f(lookat)

    for idx in range(cam_pos.shape[0]):
        scene['camera']['eye'] = tch_var_f(cam_pos[idx])

        # main render run
        start_time = time()
        res = render(scene, tile_size=tile_size, tiled=tile_size is not None,
                     shadow=b_shadow, double_sided=double_sided, use_quartic=use_quartic)
        res['suffix'] = '_{}'.format(idx)
        res['camera_far'] = scene['camera']['far']
        save_image_queue.put_nowait(get_data(res))
        rendering_time.append(time() - start_time)

    # Timing statistics
    print('Rendering time mean: {}s, std: {}s'.format(np.mean(rendering_time),
                                                      np.std(rendering_time)))
def test_scalability(filename, out_dir='./test_scale'):
    # GTX 980 8GB
    # 320 x 240 250 objs
    # 64 x 64 5000 objs
    # 32 x 32 20000 objs
    # 16 x 16 75000 objs (slow)
    from diffrend.model import load_model

    splats = load_model(filename)
    v = splats['v']
    # normalize the vertices
    v = (v - np.mean(v, axis=0)) / (v.max() - v.min())

    print(np.min(splats['v'], axis=0))
    print(np.max(splats['v'], axis=0))
    print(np.min(v, axis=0))
    print(np.max(v, axis=0))

    rand_idx = np.arange(v.shape[0])  # np.random.randint(0, splats['v'].shape[0], 4000)

    large_scene = copy.deepcopy(SCENE_BASIC)
    large_scene['camera']['viewport'] = [0, 0, 64, 64]  # [0, 0, 320, 240]
    large_scene['camera']['fovy'] = np.deg2rad(5.)
    large_scene['camera']['focal_length'] = 2.
    # large_scene['camera']['eye'] = tch_var_f([0.0, 1.0, 5.0, 1.0]),
    large_scene['objects']['disk']['pos'] = tch_var_f(v[rand_idx])
    large_scene['objects']['disk']['normal'] = tch_var_f(splats['vn'][rand_idx])
    large_scene['objects']['disk']['radius'] = tch_var_f(splats['r'][rand_idx].ravel() * 2)
    large_scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(rand_idx.size, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])

    render_scene(large_scene, out_dir, plot_res=True)
def forward(self, x):
    x = self.net(x)
    if self.sph_out:
        # Map the 2-channel output to spherical angles (phi, theta), then to a unit vector
        x = F.sigmoid(x) * tch_var_f(
            [2 * np.pi, np.pi / 2])[np.newaxis, :, np.newaxis, np.newaxis]
        x = sph2cart_unit(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    else:
        # Force the z-component to be non-negative, then normalize to unit length
        x = torch.cat([
            x[:, 0, :, :][:, np.newaxis, ...],
            x[:, 1, :, :][:, np.newaxis, ...],
            torch.abs(x[:, 2, :, :][:, np.newaxis, ...])
        ], dim=1)
        sum_squared = torch.sum(x**2, dim=1, keepdim=True)
        x = x / torch.sqrt(sum_squared + 1e-12)
    return x
def rotate_cameras(camera, theta=0, phi=0):
    # Get the current camera rotation (relative to the 'lookat' position)
    camera_eye = cartesian_to_spherical(get_data(camera['eye']) - get_data(camera['at']))

    # Rotate the camera
    new_thetas = camera_eye[..., 0] + theta
    new_phis = camera_eye[..., 1] + phi

    # Go back to cartesian coordinates and place the camera back relative to the 'lookat' position
    camera_eye = spherical_to_cartesian(new_thetas, new_phis,
                                        radius=np.expand_dims(camera_eye[..., 2], -1))
    if camera['at'].shape[-1] == 4:
        zeros = np.zeros((camera_eye.shape[0], 1))
        camera_eye = np.concatenate((camera_eye, zeros), axis=-1)
    camera['eye'] = tch_var_f(camera_eye) + camera['at']
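# Usage sketch: orbit a batch of cameras 10 degrees in theta around their 'at'
# points. The camera dict is assumed to follow the batched layout implied
# above ('eye' and 'at' as [batch, 4] tensors); the eye-to-at distance is
# preserved by construction, since only the spherical angles change.
def _example_rotate_camera(camera):
    rotate_cameras(camera, theta=np.deg2rad(10), phi=0)
    print(get_data(camera['eye']))  # the new camera positions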
import numpy as np
from diffrend.torch.utils import tch_var_f, tch_var_l

# Starter scene for rendering splats
SCENE_BASIC = {
    'camera': {
        'proj_type': 'perspective',
        'viewport': [0, 0, 320, 240],
        'fovy': np.deg2rad(90.),
        'focal_length': 1.,
        'eye': tch_var_f([0.0, 1.0, 10.0, 1.0]),
        'up': tch_var_f([0.0, 1.0, 0.0, 0.0]),
        'at': tch_var_f([0.0, 0.0, 0.0, 1.0]),
        'near': 0.1,
        'far': 1000.0,
    },
    'lights': {
        'pos': tch_var_f([
            [10., 0., 0., 1.0],
            [-10, 0., 0., 1.0],
            [0, 10., 0., 1.0],
            [0, -10., 0., 1.0],
            [0, 0., 10., 1.0],
            [0, 0., -10., 1.0],
            [20, 20, 20, 1.0],
        ]),
        'color_idx': tch_var_l([1, 3, 4, 5, 6, 7, 1]),
        # Light attenuation factors have the form (kc, kl, kq) and eq: 1/(kc + kl * d + kq * d^2)
        'attenuation':
scene = SCENE_1
if args.ortho:
    scene['camera']['proj_type'] = 'ortho'
if args.render:
    res = render_scene(scene, args.out_dir, args.norm_depth_image_only,
                       backface_culling=args.backface_culling, plot_res=args.display)
if args.opt:
    input_scene = copy.deepcopy(SCENE_BASIC)
    input_scene['materials']['albedo'] = tch_var_f([
        [0.0, 0.0, 0.0],
        [0.1, 0.1, 0.1],
        [0.2, 0.2, 0.2],
        [0.1, 0.8, 0.9],
        [0.1, 0.8, 0.9],
        [0.9, 0.1, 0.1],
    ])
    optimize_scene(input_scene, scene, args.out_dir, max_iter=args.max_iter,
                   lr=args.lr, print_interval=args.print_interval)
if args.test_scale:
    test_scalability(filename=args.model_filename, out_dir=args.out_dir)
if args.opt_ndc_test:
    optimize_NDC_test(out_dir=args.out_dir, width=args.width,
def optimize_NDC_test(out_dir, width=32, height=32, max_iter=100, lr=1e-3, scale=10,
                      print_interval=10, imsave_interval=10):
    """A demo function to check if the differentiable renderer can optimize splats in NDC.

    :param out_dir: Output directory for figures.
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f([2, 1, 2, 1])
    scene['camera']['at'] = tch_var_f([0, 0.8, 0, 1])

    target_res = render(SCENE_SPHERE_HALFBOX)
    target_im = target_res['image'].detach()  # the target must not receive gradients
    target_im_ = get_data(target_res['image'])

    criterion = nn.L1Loss()  # nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_)
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']

    num_splats = width * height
    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    z = tch_var_f(2 * np.random.rand(num_splats) - 1)
    z.requires_grad = True
    pos = torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), z), dim=1)
    normals = tch_var_f(np.ones((num_splats, 4)) * np.array([0, 0, 1, 0]))
    normals.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    input_scene['objects'] = {
        'disk': {
            'pos': pos,
            'normal': normals,
            'material_idx': material_idx
        }
    }
    optimizer = optim.Adam((z, normals), lr=lr)

    h0 = plt.figure()
    h1 = plt.figure()
    loss_per_iter = []
    for iter in range(max_iter):
        res = render_splats_NDC(input_scene)
        im_out = res['image']

        optimizer.zero_grad()
        loss = criterion(scale * im_out, scale * target_im)

        im_out_ = get_data(im_out)
        loss_ = get_data(loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            print('%d. loss= %f' % (iter, loss_))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            plt.figure(h1.number)
            plt.imshow(im_out_)
            plt.title('%d. loss= %f' % (iter, loss_))
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
def train(self):
    """Train network."""
    # Load pretrained models if required
    if self.opt.gen_model_path is not None:
        print("Reloading networks from")
        print(' > Generator', self.opt.gen_model_path)
        self.netG.load_state_dict(torch.load(open(self.opt.gen_model_path, 'rb')))
        print(' > Generator2', self.opt.gen_model_path2)
        self.netG2.load_state_dict(torch.load(open(self.opt.gen_model_path2, 'rb')))
        print(' > Discriminator', self.opt.dis_model_path)
        self.netD.load_state_dict(torch.load(open(self.opt.dis_model_path, 'rb')))
        print(' > Discriminator2', self.opt.dis_model_path2)
        self.netD2.load_state_dict(torch.load(open(self.opt.dis_model_path2, 'rb')))

    # Start training
    train_stream = Iterator(batch_size=self.opt.batchSize)
    file_name = os.path.join(self.opt.out_dir, 'L2.txt')
    dsize = len(train_stream)
    for epoch in range(self.opt.n_iter):
        self.critic_iter = 0
        # Train the discriminator critic_iters times per generator update
        for cnt, batch in enumerate(train_stream):
            # Train with real
            #################
            self.iterationa_no = epoch * dsize + cnt
            iteration = self.iterationa_no
            x, cp, lp = batch
            real_data = tch_var_f(x)
            cam_pos = tch_var_f(cp)
            light_pos = tch_var_f(lp)
            # real_data = real_data.cuda()
            # cam_pos = cam_pos.cuda()
            # light_pos = light_pos.cuda()
            self.in_critic = 1
            self.netD.zero_grad()
            real_data = real_data.permute(0, 3, 1, 2)
            # input_D = torch.cat([self.inputv, self.inputv_depth], 1)
            real_output = self.netD(real_data, cam_pos)
            if self.opt.criterion == 'GAN':
                errD_real = self.criterion(real_output, self.labelv)
                errD_real.backward()
            elif self.opt.criterion == 'WGAN':
                errD_real = real_output.mean()
                errD_real.backward(self.mone)
            else:
                raise ValueError('Unknown GAN criterion')

            # Train with fake
            #################
            self.generate_noise_vector()
            fake_z = self.netG(self.noisev, cam_pos)
            # The normal generator is dependent on z
            fake_n = self.generate_normals(fake_z, cam_pos, self.scene['camera'])
            fake = torch.cat([fake_z, fake_n], 2)
            fake_rendered, fd, loss = self.render_batch(fake, cam_pos, lp)
            # Do not backprop through the generator
            outD_fake = self.netD(fake_rendered.detach(), cam_pos.detach())
            if self.opt.criterion == 'GAN':
                labelv = Variable(self.label.fill_(self.fake_label))
                errD_fake = self.criterion(outD_fake, labelv)
                errD_fake.backward()
                errD = errD_real + errD_fake
            elif self.opt.criterion == 'WGAN':
                errD_fake = outD_fake.mean()
                errD_fake.backward(self.one)
                errD = errD_fake - errD_real
            else:
                raise ValueError('Unknown GAN criterion')

            # Compute gradient penalty
            if self.opt.gp != 'None':
                gradient_penalty = calc_gradient_penalty(
                    self.netD, real_data.data, fake_rendered.data,
                    cam_pos.data, self.opt.gp_lambda)
                gradient_penalty.backward()
                errD += gradient_penalty

            gnorm_D = torch.nn.utils.clip_grad_norm(self.netD.parameters(),
                                                    self.opt.max_gnorm)  # TODO

            # Update weights
            self.optimizerD.step()

            # Clamp critic weights if WGAN without gradient penalty
            if self.opt.criterion == 'WGAN' and self.opt.gp == 'None':
                for p in self.netD.parameters():
                    p.data.clamp_(-self.opt.clamp, self.opt.clamp)
            self.critic_iter += 1

            ############################
            # (2) Update G network
            ############################
            # To avoid computation
            # for p in self.netD.parameters():
            #     p.requires_grad = False
            if cnt % self.opt.critic_iters == 0 and cnt > 0:
                self.netG.zero_grad()
                self.in_critic = 0
                self.generate_noise_vector()
                fake_z = self.netG(self.noisev, cam_pos)
                if iteration % self.opt.print_interval * 4 == 0:
                    fake_z.register_hook(self.tensorboard_hook)
                fake_n = self.generate_normals(fake_z, cam_pos, self.scene['camera'])
                fake = torch.cat([fake_z, fake_n], 2)
                fake_rendered, fd, loss = self.render_batch(fake, cam_pos, lp)
                outG_fake = self.netD(fake_rendered, cam_pos)

                if self.opt.criterion == 'GAN':
                    # Fake labels are real for generator cost
                    labelv = Variable(self.label.fill_(self.real_label))
                    errG = self.criterion(outG_fake, labelv)
                    errG.backward()
                elif self.opt.criterion == 'WGAN':
                    errG = outG_fake.mean() + loss
                    errG.backward(self.mone)
                else:
                    raise ValueError('Unknown GAN criterion')
                gnorm_G = torch.nn.utils.clip_grad_norm(self.netG.parameters(),
                                                        self.opt.max_gnorm)  # TODO

                if (self.opt.alt_opt_zn_interval is not None and
                        iteration >= self.opt.alt_opt_zn_start):
                    # update one of the generators
                    if (((iteration - self.opt.alt_opt_zn_start) %
                         self.opt.alt_opt_zn_interval) == 0):
                        # switch generator vars to optimize
                        curr_generator_idx = (1 - curr_generator_idx)
                    if iteration < self.opt.lr_iter:
                        self.LR_SCHED_MAP[curr_generator_idx].step()
                    self.OPT_MAP[curr_generator_idx].step()
                else:
                    if iteration < self.opt.lr_iter:
                        self.optG_z_lr_scheduler.step()
                    self.optimizerG.step()

            mse_criterion = nn.MSELoss().cuda()

            # Log print
            if iteration % (self.opt.print_interval * 5) == 0 and cnt > 0:
                Wasserstein_D = (errD_real.data[0] - errD_fake.data[0])
                self.writer.add_scalar("Loss_G", errG.data[0], self.iterationa_no)
                self.writer.add_scalar("Loss_D", errD.data[0], self.iterationa_no)
                self.writer.add_scalar("Wasserstein_D", Wasserstein_D, self.iterationa_no)
                self.writer.add_scalar("Disc_grad_norm", gnorm_D, self.iterationa_no)
                self.writer.add_scalar("Gen_grad_norm", gnorm_G, self.iterationa_no)

                print('\n[%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_D_real: %.4f'
                      ' Loss_D_fake: %.4f Wasserstein_D: %.4f'
                      ' L2_loss: %.4f z_lr: %.8f, n_lr: %.8f,'
                      ' Disc_grad_norm: %.8f, Gen_grad_norm: %.8f' %
                      (iteration, self.opt.n_iter, errD.data[0], errG.data[0],
                       errD_real.data[0], errD_fake.data[0], Wasserstein_D,
                       loss.data[0], self.optG_z_lr_scheduler.get_lr()[0],
                       self.optG2_normal_lr_scheduler.get_lr()[0], gnorm_D, gnorm_G))

            # Save output images
            if iteration % (self.opt.save_image_interval * 5) == 0 and cnt > 0:
                cs = tch_var_f(contrast_stretch_percentile(
                    get_data(fd), 200, [fd.data.min(), fd.data.max()]))
                torchvision.utils.save_image(
                    fake_rendered.data,
                    os.path.join(self.opt.vis_images, 'output_%d.png' % iteration),
                    nrow=2, normalize=True, scale_each=True)

            # Save input images
            if iteration % (self.opt.save_image_interval * 5) == 0:
                cs = tch_var_f(contrast_stretch_percentile(
                    get_data(fd), 200, [fd.data.min(), fd.data.max()]))
                torchvision.utils.save_image(
                    real_data.data,
                    os.path.join(self.opt.vis_images, 'input_%d.png' % iteration),
                    nrow=2, normalize=True, scale_each=True)

            # Do checkpointing
            if iteration % (self.opt.save_interval * 2) == 0:
                self.save_networks(iteration)
def render_splats_along_ray(scene, **params):
    """Render splats specified in the camera's coordinate system.

    For now, assume the number of splats to be the number of pixels.
    This would be relaxed later to allow subpixel rendering.

    :param scene: Scene description
    :return: [H, W, 3] image
    """
    # TODO (fmannan): reuse z_to_pcl_CC
    camera = scene['camera']
    viewport = np.array(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = W / H

    eye = camera['eye'][:3]
    at = camera['at'][:3]
    up = camera['up'][:3]

    Mcam = lookat(eye=eye, at=at, up=up)
    #M = perspective(fovy, aspect_ratio, near, far)
    #Minv = inv_perspective(fovy, aspect_ratio, near, far)

    splats = scene['objects']['disk']
    pos_ray = splats['pos']
    normals_CC = get_param_value('normal', splats, None)
    #num_objects = pos_ray.size()[0]

    fovy = camera['fovy']
    focal_length = camera['focal_length']
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    ##### Find (X, Y) in the Camera's view frustum
    # Force the caller to set the z coordinate with the correct sign
    if pos_ray.dim() == 1:
        Z = -torch.nn.functional.relu(-pos_ray)  # -torch.abs(pos_ray[:, 2])
    else:
        Z = -torch.nn.functional.relu(-pos_ray[:, 2])  # -torch.abs(pos_ray[:, 2])

    x, y = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
    x *= w / 2
    y *= h / 2

    x = tch_var_f(x.ravel())
    y = tch_var_f(y.ravel())

    #sgn = 1 if get_param_value('use_old_sign', params, False) else -1
    X = -Z * x / focal_length
    Y = -Z * y / focal_length

    pos_CC = torch.stack((X, Y, Z), dim=1)

    if get_param_value('orient_splats', params, False) and normals_CC is not None:
        # TODO (fmannan): Orient splats so that [0, 0, 1] maps to the camera direction.
        # Perform this operation only when splat normals are generated by the caller in CC.
        # This should help with splats that are at the edge of the view-frustum when the
        # camera has a large fov.
        pass

    # Estimate normals from splats/point-cloud if no normals were provided
    if normals_CC is None:
        normal_est_method = get_param_value('normal_estimation_method', params, 'plane')
        kernel_size = get_param_value('normal_estimation_kernel_size', params, 3)
        normals_CC = estimate_surface_normals(pos_CC.view(H, W, 3), kernel_size,
                                              normal_est_method)[..., :3].view(-1, 3)

    material_idx = scene['objects']['disk']['material_idx']
    light_visibility = None
    if 'light_vis' in scene['objects']['disk']:
        light_visibility = scene['objects']['disk']['light_vis']

    # Samples per pixel (supersampling)
    samples = get_param_value('samples', params, 1)
    if samples > 1:
        """There are three variables that need to be upsampled:
        1. positions, 2. normals, and 3. shadow maps (light visibility)

        The idea here is to generate an x-y grid in the original resolution,
        shift it to find the subpixels, find the plane parameters for the splat
        bounded within the pixel frustum (i.e., a frustum projected into the
        scene by a pixel), and then for each subpixel find the ray-plane
        intersection with that splat plane.

        The subpixel rays are generated by taking the mesh on the projection
        plane and shifting it by the appropriate amount to get the pixel
        coordinate that the ray should go through, then finding the position in
        the 3D camera space. The normal of the splat is copied to all those
        surface samples. For a splat through (x0, y0, z0) with normal n and a
        ray t * u from the origin:

            n_x (x - x0) + n_y (y - y0) + n_z (z - z0) = 0
            n_x x0 + n_y y0 + n_z z0 = d0
            n_x t u_x + n_y t u_y + n_z t u_z = d0
            t = d0 / dot(n, ray)
        """
        # plane parameter
        d = torch.sum(pos_CC * normals_CC[:, :3], dim=1)
        z = tch_var_f(np.ones(x.shape) * -focal_length)
        # # Test consistency
        # pos_CC_projplane = torch.stack((x, y, z), dim=1)
        # dot_ray_normal = torch.sum(pos_CC_projplane * normals_CC[:, :3], dim=1)
        # t = d / dot_ray_normal
        # pos_CC_test = t[:, np.newaxis] * pos_CC_projplane
        # diff = torch.mean(torch.abs(pos_CC_test - pos_CC))
        # print(diff)
        # # End of consistency check

        # Find the ray-plane intersection for the plane bounded by the frustum.
        # The width and height of the projection plane are w and h.
        dx = w / (samples * W - 1)  # subpixel width
        dy = h / (samples * H - 1)  # subpixel height
        pos_CC_supersampled = []
        normals_CC_supersampled = []
        material_idx_supersampled = []
        if light_visibility is not None:
            light_visibility_supersampled = []
            light_visibility = light_visibility.transpose(1, 0)
        for c, deltax in enumerate(np.linspace(-1, 1, samples)):
            # TODO (fmannan): generalize (the div by 2) for samples > 3
            xx = x + deltax * dx / 2  # Shift by half of the subpixel size
            for r, deltay in enumerate(np.linspace(1, -1, samples)):
                yy = y + deltay * dy / 2
                # unit ray going through sub-pixels
                pos_CC_projplane = normalize(torch.stack((xx, yy, z), dim=1))
                dot_ray_normal = torch.sum(pos_CC_projplane * normals_CC[:, :3], dim=1)
                t = d / dot_ray_normal
                pos_CC_supersampled.append(t[:, np.newaxis] * pos_CC_projplane)
                normals_CC_supersampled.append(normals_CC[:, :3])
                material_idx_supersampled.append(material_idx[:, np.newaxis])
                if light_visibility is not None:
                    light_visibility_supersampled.append(light_visibility)
        pos_CC_supersampled = torch.stack(pos_CC_supersampled, dim=2)
        normals_CC_supersampled = torch.stack(normals_CC_supersampled, dim=2)
        material_idx_supersampled = torch.stack(material_idx_supersampled, dim=2)
        if light_visibility is not None:
            light_visibility_supersampled = torch.stack(light_visibility_supersampled, dim=2)

        pos_CC = reshape_upsampled_data(pos_CC_supersampled, H, W, 3, samples)
        normals_CC = reshape_upsampled_data(normals_CC_supersampled, H, W, 3, samples)
        material_idx = reshape_upsampled_data(material_idx_supersampled, H, W, 1,
                                              samples).view(-1)
        if light_visibility is not None:
            light_visibility = reshape_upsampled_data(light_visibility_supersampled, H, W,
                                                      light_visibility.shape[1],
                                                      samples).transpose(1, 0)
        H *= samples
        W *= samples
    ####
    im_depth = norm_p(pos_CC[..., :3]).view(H, W)

    if get_param_value('norm_depth_image_only', params, False):
        min_depth = torch.min(im_depth)
        norm_depth_image = where(im_depth >= camera['far'], min_depth, im_depth)
        norm_depth_image = (norm_depth_image - min_depth) / (torch.max(im_depth) - min_depth)
        return {
            'image': norm_depth_image,
            'depth': im_depth,
            'pos': pos_CC,
            'normal': normals_CC
        }

    ##############################
    # Fragment processing
    # -------------------
    # We can either perform the operations in world coordinates or in camera
    # coordinates. Since the inputs are in NDC and converted to CC, converting
    # to world coordinates would require more operations. There are fewer
    # lights than splats, so converting light positions and directions to CC
    # is more efficient.
    ##############################
    # Lighting
    color_table = scene['colors']
    light_pos = scene['lights']['pos']
    light_clr_idx = scene['lights']['color_idx']
    light_colors = color_table[light_clr_idx]
    light_attenuation_coeffs = scene['lights']['attenuation']
    ambient_light = scene['lights']['ambient']
    material_albedo = scene['materials']['albedo']
    material_coeffs = scene['materials']['coeffs']

    light_pos_CC = torch.mm(light_pos, Mcam.transpose(1, 0))

    # Generate the fragments
    """
    Get the normal and material for the visible objects.
    """
    frag_normals = normals_CC[:, :3]
    frag_pos = pos_CC[:, :3]
    frag_albedo = torch.index_select(material_albedo, 0, material_idx)
    frag_coeffs = torch.index_select(material_coeffs, 0, material_idx)

    im_color = fragment_shader(frag_normals=frag_normals,
                               light_dir=light_pos_CC[:, np.newaxis, :3] - frag_pos[:, :3],
                               cam_dir=-normalize(frag_pos[np.newaxis, :, :3]),
                               light_attenuation_coeffs=light_attenuation_coeffs,
                               frag_coeffs=frag_coeffs,
                               light_colors=light_colors,
                               ambient_light=ambient_light,
                               frag_albedo=frag_albedo,
                               double_sided=False,
                               use_quartic=get_param_value('use_quartic', params, False),
                               light_visibility=light_visibility)
    im = torch.sum(im_color, dim=0).view(int(H), int(W), 3)

    # clip to non-negative
    im = torch.nn.functional.relu(im)

    # Tonemapping
    #if 'tonemap' in scene:
    #    im = tonemap(im, **scene['tonemap'])

    return {
        'image': im,
        'depth': im_depth,
        'pos': pos_CC.view(H, W, 3),
        'normal': normals_CC.contiguous().view(H, W, 3)
    }
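# Numeric check (sketch) of the ray-plane intersection used in the
# supersampling path above: a plane with normal n passing through p0 satisfies
# dot(n, p) = d0 with d0 = dot(n, p0), so a ray t * u from the origin hits it
# at t = d0 / dot(n, u). The values below are arbitrary.
def _check_ray_plane_intersection():
    n = np.array([0., 0., 1.])
    p0 = np.array([0.3, -0.2, -2.0])     # point on the plane
    u = np.array([0.1, 0.1, -1.0])
    u = u / np.linalg.norm(u)            # unit ray direction
    t = np.dot(n, p0) / np.dot(n, u)
    hit = t * u
    print(np.dot(n, hit - p0))           # ~0: the hit point lies on the plane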
def get_real_samples(self):
    """Get a real sample."""
    # Define the camera poses
    if not self.opt.same_view:
        if self.opt.full_sphere_sampling:
            self.cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                theta_range=self.opt.theta, phi_range=self.opt.phi)
        else:
            self.cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=self.opt.angle,
                theta_range=np.deg2rad(self.opt.theta),
                phi_range=np.deg2rad(self.opt.phi))
    if self.opt.full_sphere_sampling_light:
        self.light_pos1 = uniform_sample_sphere(
            radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
            axis=self.opt.axis, angle=np.deg2rad(44),
            theta_range=self.opt.theta, phi_range=self.opt.phi)
        # self.light_pos2 = uniform_sample_sphere(radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
        #                                         axis=self.opt.axis, angle=np.deg2rad(40),
        #                                         theta_range=self.opt.theta, phi_range=self.opt.phi)
    else:
        print("inbox")
        light_eps = 0.15
        self.light_pos1 = np.random.rand(self.opt.batchSize, 3) * self.opt.cam_dist + light_eps
        self.light_pos2 = np.random.rand(self.opt.batchSize, 3) * self.opt.cam_dist + light_eps

    # TODO: deg2rad in all the angles????

    # Create a splats rendering scene
    large_scene = create_scene(self.opt.width, self.opt.height, self.opt.fovy,
                               self.opt.focal_length, self.opt.n_splats)
    lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
    large_scene['camera']['at'] = tch_var_f(lookat)

    # Render scenes
    data, data_depth, data_normal, data_cond = [], [], [], []
    inpath = self.opt.vis_images + '/'
    inpath2 = self.opt.vis_input + '/'
    for idx in range(self.opt.batchSize):
        # Save the splats into the rendering scene
        if self.opt.use_mesh:
            if 'sphere' in large_scene['objects']:
                del large_scene['objects']['sphere']
            if 'disk' in large_scene['objects']:
                del large_scene['objects']['disk']
            if 'triangle' not in large_scene['objects']:
                large_scene['objects'] = {
                    'triangle': {'face': None, 'normal': None, 'material_idx': None}
                }
            samples = self.get_samples()
            large_scene['objects']['triangle']['material_idx'] = tch_var_l(
                np.zeros(samples['mesh']['face'][0].shape[0], dtype=int).tolist())
            large_scene['objects']['triangle']['face'] = Variable(
                samples['mesh']['face'][0].cuda(), requires_grad=False)
            large_scene['objects']['triangle']['normal'] = Variable(
                samples['mesh']['normal'][0].cuda(), requires_grad=False)
        else:
            if 'sphere' in large_scene['objects']:
                del large_scene['objects']['sphere']
            if 'triangle' in large_scene['objects']:
                del large_scene['objects']['triangle']
            if 'disk' not in large_scene['objects']:
                large_scene['objects'] = {
                    'disk': {'pos': None, 'normal': None, 'material_idx': None}
                }
            large_scene['objects']['disk']['radius'] = tch_var_f(
                np.ones(self.opt.n_splats) * self.opt.splats_radius)
            large_scene['objects']['disk']['material_idx'] = tch_var_l(
                np.zeros(self.opt.n_splats, dtype=int).tolist())
            large_scene['objects']['disk']['pos'] = Variable(
                samples['splats']['pos'][idx].cuda(), requires_grad=False)
            large_scene['objects']['disk']['normal'] = Variable(
                samples['splats']['normal'][idx].cuda(), requires_grad=False)

        # Set camera position
        if not self.opt.same_view:
            large_scene['camera']['eye'] = tch_var_f(self.cam_pos[idx])
        else:
            large_scene['camera']['eye'] = tch_var_f(self.cam_pos[0])

        large_scene['lights']['pos'][0, :3] = tch_var_f(self.light_pos1[idx])
        # large_scene['lights']['pos'][1, :3] = tch_var_f(self.light_pos2[idx])

        # Render scene
        res = render(large_scene,
                     norm_depth_image_only=self.opt.norm_depth_image_only,
                     double_sided=True, use_quartic=self.opt.use_quartic)

        # Get rendered output
        if self.opt.render_img_nc == 1:
            depth = res['depth']
            im_d = depth.unsqueeze(0)
        else:
            depth = res['depth']
            im_d = depth.unsqueeze(0)
            im = res['image'].permute(2, 0, 1)
            im_ = get_data(res['image'])
            # im_img_ = get_normalmap_image(im_)
            target_normal_ = get_data(res['normal'])
            target_normalmap_img_ = get_normalmap_image(target_normal_)
            im_n = tch_var_f(target_normalmap_img_).view(
                im.shape[1], im.shape[2], 3).permute(2, 0, 1)

        # Add depth image to the output structure
        file_name = (inpath2 + str(self.iterationa_no) + "_" +
                     str(self.critic_iter) + 'input_{:05d}.txt'.format(idx))
        text_file = open(file_name, "w")
        text_file.write('%s\n' % str(large_scene['camera']['eye'].data))
        text_file.close()
        out_file_name = (inpath2 + str(self.iterationa_no) + "_" +
                         str(self.critic_iter) + 'input_{:05d}.npy'.format(idx))
        np.save(out_file_name, self.cam_pos[idx])
        out_file_name2 = (inpath2 + str(self.iterationa_no) + "_" +
                          str(self.critic_iter) + 'input_light{:05d}.npy'.format(idx))
        np.save(out_file_name2, self.light_pos1[idx])
        out_file_name3 = (inpath2 + str(self.iterationa_no) + "_" +
                          str(self.critic_iter) + 'input_im{:05d}.npy'.format(idx))
        np.save(out_file_name3, get_data(res['image']))
        out_file_name4 = (inpath2 + str(self.iterationa_no) + "_" +
                          str(self.critic_iter) + 'input_depth{:05d}.npy'.format(idx))
        np.save(out_file_name4, get_data(res['depth']))
        out_file_name5 = (inpath2 + str(self.iterationa_no) + "_" +
                          str(self.critic_iter) + 'input_normal{:05d}.npy'.format(idx))
        np.save(out_file_name5, get_data(res['normal']))

        if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
            imsave((inpath + str(self.iterationa_no) +
                    'real_normalmap_{:05d}.png'.format(idx)),
                   target_normalmap_img_)
            imsave((inpath + str(self.iterationa_no) +
                    'real_depth_{:05d}.png'.format(idx)), get_data(depth))
            # imsave(inpath + str(self.iterationa_no) + 'real_depthmap_{:05d}.png'.format(idx), im_d)
            # imsave(inpath + str(self.iterationa_no) + 'world_normalmap_{:05d}.png'.format(idx), target_worldnormalmap_img_)
        data.append(im)
        data_depth.append(im_d)
        data_normal.append(im_n)
        data_cond.append(large_scene['camera']['eye'])

    # Stack real samples
    real_samples = torch.stack(data)
    real_samples_depth = torch.stack(data_depth)
    real_samples_normal = torch.stack(data_normal)
    real_samples_cond = torch.stack(data_cond)
    self.batch_size = real_samples.size(0)
    if not self.opt.no_cuda:
        real_samples = real_samples.cuda()
        real_samples_depth = real_samples_depth.cuda()
        real_samples_normal = real_samples_normal.cuda()
        real_samples_cond = real_samples_cond.cuda()

    # Set input/output variables
    self.input.resize_as_(real_samples.data).copy_(real_samples.data)
    self.input_depth.resize_as_(real_samples_depth.data).copy_(real_samples_depth.data)
    self.input_normal.resize_as_(real_samples_normal.data).copy_(real_samples_normal.data)
    self.input_cond.resize_as_(real_samples_cond.data).copy_(real_samples_cond.data)
    self.label.resize_(self.batch_size).fill_(self.real_label)
    # TODO: Remove Variables
    self.inputv = Variable(self.input)
    self.inputv_depth = Variable(self.input_depth)
    self.inputv_normal = Variable(self.input_normal)
    self.inputv_cond = Variable(self.input_cond)
    self.labelv = Variable(self.label)
def render_sphere_world(out_dir, cam_pos, radius, width, height, fovy, focal_length,
                        b_display=False):
    """
    Generate z positions on a grid fixed inside the view frustum in the world
    coordinate system. Place the camera and choose the camera's field of view
    so that the side of the square touches the frustum.
    """
    import copy
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height
    r = np.ones(num_samples) * radius

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['radius'] = tch_var_f(r)
    large_scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    large_scene['objects']['disk']['pos'] = tch_var_f(pos)
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    large_scene['camera']['eye'] = tch_var_f(cam_pos)

    # main render run
    start_time = time()
    res = render(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
def test_sphere_splat_render_along_ray(out_dir, cam_pos, width, height, fovy,
                                       focal_length, use_quartic, b_display=False):
    """Create a sphere on a square as in render_sphere_world, convert it to the
    camera's coordinate system, and render using render_splats_along_ray.
    """
    import copy
    print('render sphere along ray')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel() - 5, np.ones(num_samples)), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(),
                                     np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.subplot(131)
        plt.imshow(pos[..., 0].reshape((height, width)))
        plt.subplot(132)
        plt.imshow(pos[..., 1].reshape((height, width)))
        plt.subplot(133)
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    ## Convert to the camera's coordinate system
    #Mcam = lookat(eye=large_scene['camera']['eye'], at=large_scene['camera']['at'], up=large_scene['camera']['up'])
    pos_CC = tch_var_f(pos)  # torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = pos_CC
    large_scene['objects']['disk']['normal'] = None  # Estimate the normals: tch_var_f(normals)
    # large_scene['camera']['eye'] = tch_var_f([-10., 0., 10.])
    # large_scene['camera']['eye'] = tch_var_f([2., 0., 10.])
    large_scene['camera']['eye'] = tch_var_f([-5., 0., 0.])

    # main render run
    start_time = time()
    res = render_splats_along_ray(large_scene, use_quartic=use_quartic)
    rendering_time.append(time() - start_time)

    # Test cam_to_world
    res_world = cam_to_world(res['pos'].reshape(-1, 3), res['normal'].reshape(-1, 3),
                             large_scene['camera'])

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = large_scene['camera']['far']

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(depth, interpolation='none')
        plt.title('Depth Image')
        #plt.savefig(out_dir + '/fig_depth_orig.png')

        plt.figure()
        pos_world = get_data(res_world['pos'])
        posx_world = pos_world[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_world = pos_world[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_world = pos_world[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_world)
        plt.title('x_world')
        plt.subplot(132)
        plt.imshow(posy_world)
        plt.title('y_world')
        plt.subplot(133)
        plt.imshow(posz_world)
        plt.title('z_world')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pos_world[:, 0], pos_world[:, 1], pos_world[:, 2], s=1.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')

        plt.figure()
        pos_world = get_data(res['pos'].reshape(-1, 3))
        posx_world = pos_world[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_world = pos_world[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_world = pos_world[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_world)
        plt.title('x_CC')
        plt.subplot(132)
        plt.imshow(posy_world)
        plt.title('y_CC')
        plt.subplot(133)
        plt.imshow(posz_world)
        plt.title('z_CC')

    imsave(out_dir + '/img_orig.png', im)
    #imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
from diffrend.torch.renderer import render_splats_NDC, render, render_splats_along_ray
from diffrend.torch.ops import perspective, inv_perspective
from diffrend.numpy.ops import normalize as np_normalize
from imageio import imsave
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
from time import time

SCENE_TEST = {
    'camera': {
        'proj_type': 'perspective',
        'viewport': [0, 0, 2, 2],
        'fovy': np.deg2rad(90.),
        'focal_length': 1.,
        'eye': tch_var_f([0.0, 1.0, 10.0, 1.0]),
        'up': tch_var_f([0.0, 1.0, 0.0, 0.0]),
        'at': tch_var_f([0.0, 0.0, 0.0, 1.0]),
        'near': 1.0,
        'far': 1000.0,
    },
    'lights': {
        'pos': tch_var_f([
            [0., 0., -10., 1.0],
            [-15, 3, 15, 1.0],
            [0, 0., 10., 1.0],
        ]),
        'color_idx': tch_var_l([2, 1, 3]),
        # Light attenuation factors have the form (kc, kl, kq) and eq: 1/(kc + kl * d + kq * d^2)
        'attenuation': tch_var_f([
            [1., 0., 0.0],
import numpy as np
from diffrend.torch.utils import get_data, tch_var_f, cam_to_world
from diffrend.torch.renderer import render_splats_along_ray
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt

#scene = np.load('scene_output_twogans.npy')
scene = np.load('scene_input_twogans_unnorm.npy')
#scene = np.load('scene_output.npy')

#pos = get_data(scene[0]['objects']['disk']['pos'])
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(scene[:, 0], scene[:, 1], scene[:, 2], s=1.3)

for idx in range(0, len(scene), 20):
    print(idx)
    scene[idx]['lights']['attenuation'] = tch_var_f([[1.0, 0.0, 0.0],
                                                     [1.0, 0.0, 0.0]])
    res = render_splats_along_ray(scene[idx], use_old_sign=False)
    im = get_data(res['image'])
    depth = get_data(res['depth'])

    plt.figure()
    plt.imshow(im)
    plt.figure()
    plt.imshow(depth)

    pos = get_data(res['pos'])
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], s=1.3)
def render_splats_NDC(scene, **params):
    """Render splats specified in the camera's normalized coordinate system.

    For now, assume the number of splats to be the number of pixels.
    This would be relaxed later to allow subpixel rendering.

    :param scene: Scene description
    :return: [H, W, 3] image
    """
    camera = scene['camera']
    viewport = np.array(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = W / H

    fovy = camera['fovy']
    near = camera['near']
    far = camera['far']

    eye = camera['eye'][:3]
    at = camera['at'][:3]
    up = camera['up'][:3]

    Mcam = lookat(eye=eye, at=at, up=up)
    #M = perspective(fovy, aspect_ratio, near, far)
    Minv = inv_perspective(fovy, aspect_ratio, near, far)

    splats = scene['objects']['disk']
    pos_NDC = splats['pos']
    normals_SLC = splats['normal']
    num_objects = pos_NDC.size()[0]

    # Transform params to the camera's view frustum
    if pos_NDC.size()[-1] == 3:
        pos_NDC = torch.cat((pos_NDC, tch_var_f(np.ones((num_objects, 1)))), dim=1)
    pos_CC = torch.matmul(pos_NDC, Minv.transpose(1, 0))
    pos_CC = pos_CC / pos_CC[..., 3][:, np.newaxis]

    im_depth = norm_p(pos_CC[..., :3]).view(H, W)

    if get_param_value('norm_depth_image_only', params, False):
        min_depth = torch.min(im_depth)
        norm_depth_image = where(im_depth >= camera['far'], min_depth, im_depth)
        norm_depth_image = (norm_depth_image - min_depth) / (torch.max(im_depth) - min_depth)
        return {
            'image': norm_depth_image,
            'depth': im_depth,
            'pos': pos_CC,
            'normal': normals_SLC
        }

    ##############################
    # Fragment processing
    # -------------------
    # We can either perform the operations in world coordinates or in camera
    # coordinates. Since the inputs are in NDC and converted to CC, converting
    # to world coordinates would require more operations. There are fewer
    # lights than splats, so converting light positions and directions to CC
    # is more efficient.
    ##############################
    # Lighting
    color_table = scene['colors']
    light_pos = scene['lights']['pos']
    light_clr_idx = scene['lights']['color_idx']
    light_colors = color_table[light_clr_idx]
    light_attenuation_coeffs = scene['lights']['attenuation']
    ambient_light = scene['lights']['ambient']
    material_albedo = scene['materials']['albedo']
    material_coeffs = scene['materials']['coeffs']
    material_idx = scene['objects']['disk']['material_idx']

    light_pos_CC = torch.mm(light_pos, Mcam.transpose(1, 0))

    # Generate the fragments
    """
    Get the normal and material for the visible objects.
    """
    normals_CC = normals_SLC  # TODO: Transform to CC, or assume SLC is CC
    frag_normals = normals_CC[:, :3]
    frag_pos = pos_CC[:, :3]
    frag_albedo = torch.index_select(material_albedo, 0, material_idx)
    frag_coeffs = torch.index_select(material_coeffs, 0, material_idx)
    light_visibility = None

    # TODO: CHECK fragment_shader call
    im_color = fragment_shader(frag_normals=frag_normals,
                               light_dir=light_pos_CC[:, np.newaxis, :3] - frag_pos[:, :3],
                               cam_dir=-frag_pos[:, :3],
                               light_attenuation_coeffs=light_attenuation_coeffs,
                               frag_coeffs=frag_coeffs,
                               light_colors=light_colors,
                               ambient_light=ambient_light,
                               frag_albedo=frag_albedo,
                               double_sided=get_param_value('double_sided', params, False),
                               use_quartic=get_param_value('use_quartic', params, False),
                               light_visibility=light_visibility)

    # # Fragment shading (reference implementation)
    # light_dir = light_pos_CC[:, np.newaxis, :3] - frag_pos[:, :3]
    # light_dir_norm = torch.sqrt(torch.sum(light_dir ** 2, dim=-1))[:, :, np.newaxis]
    # light_dir /= light_dir_norm  # TODO: nonzero_divide
    # # Attenuate the lights
    # per_frag_att_factor = 1 / (light_attenuation_coeffs[:, 0][:, np.newaxis, np.newaxis] +
    #                            light_dir_norm * light_attenuation_coeffs[:, 1][:, np.newaxis, np.newaxis] +
    #                            (light_dir_norm ** 2) * light_attenuation_coeffs[:, 2][:, np.newaxis, np.newaxis])
    #
    # frag_normal_dot_light = tensor_dot(frag_normals, per_frag_att_factor * light_dir, axis=-1)
    # frag_normal_dot_light = torch.nn.functional.relu(frag_normal_dot_light)
    # im_color = frag_normal_dot_light[:, :, np.newaxis] * \
    #            light_colors[:, np.newaxis, :] * frag_albedo[np.newaxis, :, :]

    im = torch.sum(im_color, dim=0).view(int(H), int(W), 3)

    # clip to non-negative
    im = torch.nn.functional.relu(im)

    # # Tonemapping
    # if 'tonemap' in scene:
    #     im = tonemap(im, **scene['tonemap'])

    return {
        'image': im,
        'depth': im_depth,
        'pos': pos_CC[:, :3].view(H, W, 3),
        'normal': normals_CC[:, :3].view(H, W, 3)
    }
def inv_perspective_LH_NO(fovy, aspect, near, far):
    """Inverse perspective for left-handed camera with all coords mapped from [-1, 1]."""
    mat_00, mat_11, mat_22, mat_23 = perspective_NO_params(fovy, aspect, near, far)
    return tch_var_f([[1 / mat_00, 0, 0, 0],
                      [0, 1 / mat_11, 0, 0],
                      [0, 0, 0, 1],
                      [0, 0, 1 / mat_23, -mat_22 / mat_23]])
def optimize_splats_along_ray_shadow_with_normalest_test( out_dir, width, height, max_iter=100, lr=1e-3, scale=10, shadow=True, vis_only=False, samples=1, est_normals=False, b_generate_normals=False, print_interval=10, imsave_interval=10, xyz_save_interval=100): """A demo function to check if the differentiable renderer can optimize splats rendered along ray. :param scene: :param out_dir: :return: """ import torch import copy from diffrend.torch.params import SCENE_SPHERE_HALFBOX_0 if not os.path.exists(out_dir): os.mkdir(out_dir) scene = SCENE_SPHERE_HALFBOX_0 scene['camera']['viewport'] = [0, 0, width, height] scene['camera']['fovy'] = np.deg2rad(45) scene['camera']['focal_length'] = 1 scene['camera']['eye'] = tch_var_f( [2, 1, 2, 1]) # tch_var_f([1, 1, 1, 1]) # tch_var_f([2, 2, 2, 1]) # scene['camera']['at'] = tch_var_f( [0, 0.8, 0, 1]) # tch_var_f([0, 1, 0, 1]) # tch_var_f([2, 2, 0, 1]) # scene['lights']['attenuation'] = tch_var_f([ [0., 0.0, 0.01], [0., 0.0, 0.01], [0., 0.0, 0.01], ]) scene['materials']['coeffs'] = tch_var_f([ [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.2, 8.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], ]) target_res = render(scene, tiled=True, shadow=shadow) target_im = normalize_maxmin(target_res['image']) target_im.require_grad = False target_im_ = get_data(target_im) target_pos_ = get_data(target_res['pos']) target_normal_ = get_data(target_res['normal']) target_normalmap_img_ = get_normalmap_image(target_normal_) target_depth_ = get_data(target_res['depth']) print('[z_min, z_max] = [%f, %f]' % (np.min(target_pos_[..., 2]), np.max(target_pos_[..., 2]))) print('[depth_min, depth_max] = [%f, %f]' % (np.min(target_depth_), np.max(target_depth_))) # world -> cam -> render_splats_along_ray cc_tform = world_to_cam(target_res['pos'].view( (-1, 3)), target_res['normal'].view((-1, 3)), scene['camera']) wc_cc_tform = cam_to_world(cc_tform['pos'], cc_tform['normal'], scene['camera']) # Check normal estimation in camera space pos_cc = cc_tform['pos'][:, :3].contiguous().view(target_im.shape) normal_cc = cc_tform['normal'][:, :3].contiguous().view(target_im.shape) plane_fit_est = estimate_surface_normals_plane_fit(pos_cc, None) normal_cc_normalmap = get_normalmap_image(get_data(normal_cc)) plane_fit_est_normalmap = get_normalmap_image(get_data(plane_fit_est)) pos_diff = torch.abs(wc_cc_tform['pos'][:, :3] - target_res['pos'].view((-1, 3))) mean_pos_diff = torch.mean(pos_diff) normal_diff = torch.abs(wc_cc_tform['normal'][:, :3] - target_res['normal'].view(-1, 3)) mean_normal_diff = torch.mean(normal_diff) print('mean_pos_diff', mean_pos_diff, 'mean_normal_diff', mean_normal_diff) wc_cc_normal = wc_cc_tform['normal'].view(target_im_.shape) wc_cc_normal_img = get_normalmap_image(get_data(wc_cc_normal)) material_idx = tch_var_l(np.ones(cc_tform['pos'].shape[0]) * 3) input_scene = copy.deepcopy(scene) del input_scene['objects']['sphere'] del input_scene['objects']['triangle'] light_vis = tch_var_f( np.ones( (input_scene['lights']['pos'].shape[0], cc_tform['pos'].shape[0]))) input_scene['objects'] = { 'disk': { 'pos': cc_tform['pos'], 'normal': cc_tform['normal'], 'material_idx': material_idx, 'light_vis': light_vis, } } target_res_noshadow = render(scene, tiled=True, shadow=False) res = render_splats_along_ray(input_scene) test_img_ = get_data(normalize_maxmin(res['image'])) test_depth_ = get_data(res['depth']) test_normal_ = get_data(res['normal']).reshape(test_img_.shape) test_normalmap_ = get_normalmap_image(test_normal_) im_diff = np.abs(test_img_ - 
get_data(normalize_maxmin(target_res_noshadow['image']))) print('mean image diff: {}'.format(np.mean(im_diff))) #### PLOT plt.ion() plt.figure() plt.imshow(test_img_, interpolation='none') plt.title('Test Image') plt.savefig(out_dir + '/test_img.png') plt.figure() plt.imshow(test_depth_, interpolation='none') plt.title('Test Depth') plt.savefig(out_dir + '/test_depth.png') plt.figure() plt.imshow(test_normalmap_, interpolation='none') plt.title('Test Normals') plt.savefig(out_dir + '/test_normal.png') #### criterion = nn.L1Loss() #nn.MSELoss() criterion = criterion.cuda() plt.ion() plt.figure() plt.imshow(target_im_, interpolation='none') plt.title('Target Image') plt.savefig(out_dir + '/target.png') plt.figure() plt.imshow(target_normalmap_img_, interpolation='none') plt.title('Normals') plt.savefig(out_dir + '/normal.png') plt.figure() plt.imshow(wc_cc_normal_img, interpolation='none') plt.title('WC_CC Normals') plt.savefig(out_dir + '/wc_cc_normal.png') plt.figure() plt.imshow(normal_cc_normalmap, interpolation='none') plt.title('Normal CC GT') plt.savefig(out_dir + '/normal_cc.png') plt.figure() plt.imshow(plane_fit_est_normalmap, interpolation='none') plt.title('Plane fit CC') plt.savefig(out_dir + '/est_normal_cc.png') plt.figure() plt.subplot(121) plt.imshow(normal_cc_normalmap, interpolation='none') plt.title('Normal CC GT') plt.subplot(122) plt.imshow(plane_fit_est_normalmap, interpolation='none') plt.title('Plane fit CC') plt.savefig(out_dir + '/normal_and_estnormal_cc_comparison.png') input_scene = copy.deepcopy(scene) del input_scene['objects']['sphere'] del input_scene['objects']['triangle'] input_scene['camera']['viewport'] = [ 0, 0, int(width / samples), int(height / samples) ] num_splats = int(width * height / (samples * samples)) #x, y = np.meshgrid(np.linspace(-1, 1, int(width / samples)), np.linspace(-1, 1, int(height / samples))) z_min = scene['camera']['focal_length'] z_max = 3 z = -tch_var_f( np.ones(num_splats) * (z_min + z_max) / 2 ) # -torch.clamp(tch_var_f(2 * np.random.rand(num_splats)), z_min, z_max) z.requires_grad = True normal_angles = tch_var_f(np.random.rand(num_splats, 2)) normal_angles.requires_grad = True material_idx = tch_var_l(np.ones(num_splats) * 3) light_vis = tch_var_f( np.ones((input_scene['lights']['pos'].shape[0], num_splats))) light_vis.requires_grad = True if vis_only: assert shadow is True opt_vars = [light_vis] z = cc_tform['pos'][:, 2] # FIXME: sph2cart #normals = cc_tform['normal'] else: opt_vars = [z, normal_angles] if shadow: opt_vars += [light_vis] optimizer = optim.Adam(opt_vars, lr=lr) lr_scheduler = StepLR(optimizer, step_size=10000, gamma=0.8) h0 = plt.figure() h1 = plt.figure() h2 = plt.figure() h3 = plt.figure() h4 = plt.figure() gs1 = gridspec.GridSpec(3, 3) gs1.update(wspace=0.0025, hspace=0.02) # Two options for z_norm_consistency # 1. start after N iterations # 2. start at the beginning and decay # 3. 
    # 3. Start after N iterations and decay to 0.
    no_decay = lambda x: x
    exp_decay = lambda x, scale: torch.exp(-x / scale)
    linear_decay = lambda x, scale: scale / (x + 1e-6)

    spatial_var_loss_weight = 10.0  # 0.0
    normal_away_from_cam_loss_weight = 0.0
    grad_img_depth_loss_weight = 1.0
    spatial_loss_weight = 2
    z_norm_weight_init = 1  # 1e-5
    z_norm_activate_iter = 0  # 1000
    decay_fn = lambda x: linear_decay(x, 100)

    loss_per_iter = []

    if b_generate_normals:
        est_normals = False
        normal_est_network = NEstNetAffine(kernel_size=3, sph=False)
        print(normal_est_network)
        normal_est_network.cuda()

    for iter in range(max_iter):
        lr_scheduler.step()
        zz = -F.relu(-z) - z_min  # torch.clamp(z, -z_max, -z_min)
        if b_generate_normals:
            normals = generate_normals(zz, scene['camera'],
                                       normal_est_network)
            # if iter > 100 and iter % 10 == 0:
            #     print(normals)
        elif not est_normals:
            phi = torch.sigmoid(normal_angles[:, 0]) * 2 * np.pi
            theta = torch.sigmoid(normal_angles[:, 1]) * np.pi / 2  # F.tanh(normal_angles[:, 1]) * np.pi / 2
            normals = sph2cart_unit(torch.stack((phi, theta), dim=1))

        pos = zz  # torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), zz), dim=1)

        input_scene['objects'] = {
            'disk': {
                'pos': pos,
                'normal': normalize(normals) if not est_normals else None,
                'material_idx': material_idx,
                'light_vis': torch.sigmoid(light_vis),
            }
        }
        res = render_splats_along_ray(input_scene, samples=samples,
                                      normal_estimation_method='plane')
        res_pos = res['pos']
        res_normal = res['normal']
        spatial_loss = spatial_3x3(res_pos)
        depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
        grad_img = grad_spatial2d(
            torch.mean(res['image'], dim=-1)[..., np.newaxis])
        grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)
        normal_away_from_cam_loss = away_from_camera_penalty(
            res_pos, res_normal)
        z_pos = res_pos[..., 2]
        z_loss = torch.mean((10 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                            (10 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() +
                                 res_pos[..., 2].var())
        spatial_var_loss = (1 / (spatial_var + 1e-4))

        im_out = normalize_maxmin(res['image'])
        res_depth_ = get_data(res['depth'])

        optimizer.zero_grad()
        z_norm_weight = z_norm_weight_init * float(
            iter > z_norm_activate_iter) * decay_fn(iter - z_norm_activate_iter)
        loss = criterion(scale * im_out, scale * target_im) + z_loss + \
            unit_normal_loss + \
            z_norm_weight * z_norm_loss + \
            spatial_var_loss_weight * spatial_var_loss + \
            grad_img_depth_loss_weight * image_depth_consistency_loss
        # + normal_away_from_cam_loss_weight * normal_away_from_cam_loss
        # + spatial_loss_weight * spatial_loss

        im_out_ = get_data(im_out)
        im_out_normal_ = get_data(res['normal'])
        pos_out_ = get_data(res['pos'])

        loss_ = get_data(loss)
        z_loss_ = get_data(z_loss)
        z_norm_loss_ = get_data(z_norm_loss)
        spatial_loss_ = get_data(spatial_loss)
        spatial_var_loss_ = get_data(spatial_var_loss)
        unit_normal_loss_ = get_data(unit_normal_loss)
        normal_away_from_cam_loss_ = get_data(normal_away_from_cam_loss)
        normals_ = get_data(res_normal)
        image_depth_consistency_loss_ = get_data(image_depth_consistency_loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            z__ = pos_out_[..., 2]
            print('%d. loss= %f nloss=%f z_loss=%f [%f, %f] [%f, %f], '
                  'z_normal_loss: %f, spatial_var_loss: %f, '
                  'normal_away_loss: %f, nz_range: [%f, %f], '
                  'spatial_loss: %f, imd_loss: %f' %
                  (iter, loss_, unit_normal_loss_, z_loss_,
                   np.min(z_), np.max(z_), np.min(z__), np.max(z__),
                   z_norm_loss_, spatial_var_loss_,
                   normal_away_from_cam_loss_,
                   normals_[..., 2].min(), normals_[..., 2].max(),
                   spatial_loss_, image_depth_consistency_loss_))

        if iter % xyz_save_interval == 0 or iter == max_iter - 1:
            save_xyz(out_dir + '/res_{:05d}.xyz'.format(iter),
                     get_data(res_pos), get_data(res_normal))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            plt.figure(h4.number)
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(121)
            # plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.title('Output')
            plt.subplot(122)
            # plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.title('Ground truth')
            # plt.subplot(223)
            # plt.plot(loss_per_iter, linewidth=2)
            # plt.xlabel('Iteration', fontsize=14)
            # plt.title('Loss', fontsize=12)
            # plt.grid(True)
            plt.savefig(out_dir + '/fig_im_gt_loss_%05d.png' % iter)

            plt.figure(h1.number, figsize=(4, 4))
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(gs1[0])
            plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.subplot(gs1[1])
            plt.axis('off')
            plt.imshow(get_normalmap_image(im_out_normal_),
                       interpolation='none')
            ax = plt.subplot(gs1[2])
            plt.axis('off')
            im_tmp = ax.imshow(res_depth_, interpolation='none')
            # Create an axes on the right side of ax. The width of cax will
            # be 5% of ax, with a fixed 0.05 inch padding between them.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            plt.subplot(gs1[3])
            plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.subplot(gs1[4])
            plt.axis('off')
            plt.imshow(test_normalmap_, interpolation='none')
            ax = plt.subplot(gs1[5])
            plt.axis('off')
            im_tmp = ax.imshow(test_depth_, interpolation='none')
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            W, H = input_scene['camera']['viewport'][2:]
            light_vis_ = get_data(torch.sigmoid(light_vis))
            plt.subplot(gs1[6])
            plt.axis('off')
            plt.imshow(light_vis_[0].reshape((H, W)), interpolation='none')
            if light_vis_.shape[0] > 1:
                plt.subplot(gs1[7])
                plt.axis('off')
                plt.imshow(light_vis_[1].reshape((H, W)),
                           interpolation='none')
            if light_vis_.shape[0] > 2:
                plt.subplot(gs1[8])
                plt.axis('off')
                plt.imshow(light_vis_[2].reshape((H, W)),
                           interpolation='none')
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

            plt.figure(h2.number)
            plt.clf()
            plt.imshow(res_depth_)
            plt.colorbar()
            plt.savefig(out_dir + '/fig_depth_%05d.png' % iter)

            plt.figure(h3.number)
            plt.clf()
            plt.imshow(z_.reshape(H, W))
            plt.colorbar()
            plt.savefig(out_dir + '/fig_z_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
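# NOTE: Illustrative sketch only. It reproduces the angle-to-normal mapping
# used in the optimization loop above under the usual spherical convention;
# the project's sph2cart_unit may order axes or angles differently. The
# sketch is self-contained (plain torch/numpy, no diffrend helpers).
def _sketch_sigmoid_angles_to_unit_normals():
    """Map unconstrained parameters to unit normals via bounded angles."""
    import numpy as np
    import torch

    def sph2cart_unit_sketch(angles):
        # angles[:, 0] = phi (azimuth, in [0, 2*pi)),
        # angles[:, 1] = theta (polar, in [0, pi/2] so every normal keeps a
        # non-negative z component, i.e. faces the camera half-space).
        phi, theta = angles[:, 0], angles[:, 1]
        x = torch.sin(theta) * torch.cos(phi)
        y = torch.sin(theta) * torch.sin(phi)
        z = torch.cos(theta)
        return torch.stack((x, y, z), dim=1)

    # Unconstrained parameters -> bounded angles -> unit normals, mirroring
    # the sigmoid squashing of normal_angles in the loop above.
    normal_angles = torch.rand(4, 2, requires_grad=True)
    phi = torch.sigmoid(normal_angles[:, 0]) * 2 * np.pi
    theta = torch.sigmoid(normal_angles[:, 1]) * np.pi / 2
    normals = sph2cart_unit_sketch(torch.stack((phi, theta), dim=1))
    print(torch.norm(normals, dim=1))  # all ones, up to float error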
def test_sphere_splat_NDC(out_dir, cam_pos, width, height, fovy, focal_length,
                          b_display=False):
    """Create a sphere on a square as in render_sphere_world, convert it to
    the camera's coordinate system and to NDC, and render with
    render_splats_NDC.
    """
    import copy

    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    # z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)),
    #                                 axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemisphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel(), np.ones(num_samples)),
                   axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1
    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(),
                                     np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    # Convert to the camera's coordinate system
    Mcam = lookat(eye=large_scene['camera']['eye'],
                  at=large_scene['camera']['at'],
                  up=large_scene['camera']['up'])
    Mproj = perspective(fovy=large_scene['camera']['fovy'],
                        aspect=width / height,
                        near=large_scene['camera']['near'],
                        far=large_scene['camera']['far'])

    pos_CC = torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))
    pos_NDC = torch.matmul(pos_CC, Mproj.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = \
        pos_NDC / pos_NDC[..., 3][:, np.newaxis]
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    # Main render run
    start_time = time()
    res = render_splats_NDC(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) /
                        (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # Hold the matplotlib figures
    plt.ioff()
    plt.show()
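# NOTE: Illustrative sketch of the world -> camera -> NDC chain used in
# test_sphere_splat_NDC above. The matrices here are toy stand-ins (identity
# view matrix, a single hand-written perspective row), not the project's
# lookat/perspective; only the matrix-product-then-divide structure matches.
def _sketch_ndc_perspective_divide():
    """Push a homogeneous position through view/projection, then divide by w."""
    import torch

    Mcam = torch.eye(4)   # toy view matrix (camera at the origin)
    Mproj = torch.eye(4)
    Mproj[3, 2] = -1.0    # make w depend on -z, as in a perspective matrix
    Mproj[3, 3] = 0.0

    pos = torch.tensor([[0.5, 0.5, -2.0, 1.0]])        # homogeneous position
    pos_CC = torch.matmul(pos, Mcam.transpose(1, 0))    # world -> camera
    pos_clip = torch.matmul(pos_CC, Mproj.transpose(1, 0))
    pos_NDC = pos_clip / pos_clip[..., 3:4]             # perspective divide
    print(pos_NDC)  # x and y shrink with distance; w normalizes to 1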
def render_batch(self, batch, batch_cond, light_pos):
    """Render a batch of splats."""
    batch_size = batch.size()[0]

    # Generate camera positions on a sphere
    if batch_cond is None:
        if self.opt.full_sphere_sampling:
            cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                theta_range=self.opt.theta,
                phi_range=self.opt.phi)  # TODO: deg2rad!!
        else:
            cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=self.opt.angle,
                theta_range=np.deg2rad(self.opt.theta),
                phi_range=np.deg2rad(self.opt.phi))  # TODO: deg2rad!!

    rendered_data = []
    rendered_data_depth = []
    rendered_data_cond = []
    scenes = []
    inpath = self.opt.vis_images + '/'
    inpath_xyz = self.opt.vis_xyz + '/'
    z_min = self.scene['camera']['focal_length']
    z_max = z_min + 3

    # TODO (fmannan): Move this into __init__. It only needs to be done once!
    # Set splats into the rendering scene
    if 'sphere' in self.scene['objects']:
        del self.scene['objects']['sphere']
    if 'triangle' in self.scene['objects']:
        del self.scene['objects']['triangle']
    if 'disk' not in self.scene['objects']:
        self.scene['objects'] = {'disk': {'pos': None,
                                          'normal': None,
                                          'material_idx': None}}
    lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
    self.scene['camera']['at'] = tch_var_f(lookat)
    self.scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(self.opt.splats_img_size * self.opt.splats_img_size))

    loss = 0.0
    loss_ = 0.0
    z_loss_ = 0.0
    z_norm_loss_ = 0.0
    spatial_loss_ = 0.0
    spatial_var_loss_ = 0.0
    unit_normal_loss_ = 0.0
    normal_away_from_cam_loss_ = 0.0
    image_depth_consistency_loss_ = 0.0
    for idx in range(batch_size):
        # Get splat positions and normals
        eps = 1e-3
        if self.opt.rescaled:
            z = F.relu(-batch[idx][:, 0]) + z_min
            z = ((z - z.min()) / (z.max() - z.min() + eps) *
                 (z_max - z_min) + z_min)
            pos = -z
        else:
            z = F.relu(-batch[idx][:, 0]) + z_min
            pos = -F.relu(-batch[idx][:, 0]) - z_min
        normals = batch[idx][:, 1:]

        self.scene['objects']['disk']['pos'] = pos

        # A normal-estimation network and est_normals don't go together
        self.scene['objects']['disk']['normal'] = \
            normals if self.opt.est_normals is False else None

        # Set camera position
        if batch_cond is None:
            if not self.opt.same_view:
                self.scene['camera']['eye'] = tch_var_f(cam_pos[idx])
            else:
                self.scene['camera']['eye'] = tch_var_f(cam_pos[0])
        else:
            if not self.opt.same_view:
                self.scene['camera']['eye'] = batch_cond[idx]
            else:
                self.scene['camera']['eye'] = batch_cond[0]

        self.scene['lights']['pos'][0, :3] = tch_var_f(light_pos[idx])
        # self.scene['lights']['pos'][1, :3] = tch_var_f(self.light_pos2[idx])

        # Render scene
        # res = render_splats_NDC(self.scene)
        res = render_splats_along_ray(
            self.scene, samples=self.opt.pixel_samples,
            normal_estimation_method='plane')

        world_tform = cam_to_world(res['pos'].view((-1, 3)),
                                   res['normal'].view((-1, 3)),
                                   self.scene['camera'])

        # Get rendered output
        res_pos = res['pos'].contiguous()
        res_pos_2D = res_pos.view(res['image'].shape)

        # The z_loss needs to be applied after supersampling
        # TODO: Enable this (currently the final loss becomes NaN!!)
        # loss += torch.mean(
        #     (10 * F.relu(z_min - torch.abs(res_pos[..., 2]))) ** 2 +
        #     (10 * F.relu(torch.abs(res_pos[..., 2]) - z_max)) ** 2)

        res_normal = res['normal']
        # depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
        # grad_img = grad_spatial2d(torch.mean(res['image'], dim=-1)[..., np.newaxis])
        # grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)  # TODO: MN
        normal_away_from_cam_loss = away_from_camera_penalty(
            res_pos, res_normal)
        z_pos = res_pos[..., 2]
        z_loss = torch.mean((2 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                            (2 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
        z_norm_loss = normal_consistency_cost(res_pos, res['normal'], norm=1)
        spatial_loss = spatial_3x3(res_pos_2D)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() +
                                 res_pos[..., 2].var())
        spatial_var_loss = (1 / (spatial_var + 1e-4))

        # Accumulate the per-sample regularizers over the batch; the total
        # is averaged by batchSize on return.
        loss = loss + (self.opt.zloss * z_loss +
                       self.opt.unit_normalloss * unit_normal_loss +
                       self.opt.normal_consistency_loss_weight * z_norm_loss +
                       self.opt.spatial_var_loss_weight * spatial_var_loss +
                       self.opt.grad_img_depth_loss *
                       image_depth_consistency_loss +
                       self.opt.spatial_loss_weight * spatial_loss)

        pos_out_ = get_data(res['pos'])
        loss_ += get_data(loss)
        z_loss_ += get_data(z_loss)
        z_norm_loss_ += get_data(z_norm_loss)
        spatial_loss_ += get_data(spatial_loss)
        spatial_var_loss_ += get_data(spatial_var_loss)
        unit_normal_loss_ += get_data(unit_normal_loss)
        normal_away_from_cam_loss_ += get_data(normal_away_from_cam_loss)
        image_depth_consistency_loss_ += get_data(
            image_depth_consistency_loss)
        normals_ = get_data(res_normal)

        if self.opt.render_img_nc == 1:
            depth = res['depth']
            im = depth.unsqueeze(0)
            im_d = depth.unsqueeze(0)  # depth doubles as the image here
        else:
            depth = res['depth']
            im_d = depth.unsqueeze(0)
            im = res['image'].permute(2, 0, 1)
            H, W = im.shape[1:]
            target_normal_ = get_data(res['normal']).reshape((H, W, 3))
            target_normalmap_img_ = get_normalmap_image(target_normal_)
            target_worldnormal_ = get_data(world_tform['normal']).reshape(
                (H, W, 3))
            target_worldnormalmap_img_ = get_normalmap_image(
                target_worldnormal_)

        if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
            imsave((inpath + str(self.iterationa_no) +
                    'normalmap_{:05d}.png'.format(idx)),
                   target_normalmap_img_)
            imsave((inpath + str(self.iterationa_no) +
                    'depthmap_{:05d}.png'.format(idx)),
                   get_data(res['depth']))
            imsave((inpath + str(self.iterationa_no) +
                    'world_normalmap_{:05d}.png'.format(idx)),
                   target_worldnormalmap_img_)
        if self.iterationa_no % 1000 == 0:
            im2 = get_data(res['image'])
            depth2 = get_data(res['depth'])
            pos = get_data(res['pos'])

            np.save(inpath_xyz + 'pos.npy', pos)
            np.save(inpath_xyz + 'im.npy', im2)
            np.save(inpath_xyz + 'depth.npy', depth2)

            # Save xyz file
            save_xyz((inpath_xyz + str(self.iterationa_no) +
                      'withnormal_{:05d}.xyz'.format(idx)),
                     pos=get_data(res['pos']),
                     normal=get_data(res['normal']))
            # Save xyz file in world coordinates
            save_xyz((inpath_xyz + str(self.iterationa_no) +
                      'withnormal_world_{:05d}.xyz'.format(idx)),
                     pos=get_data(world_tform['pos']),
                     normal=get_data(world_tform['normal']))

        if self.opt.gz_gi_loss is not None and self.opt.gz_gi_loss > 0:
            gradZ = grad_spatial2d(res_pos_2D[:, :, 2][:, :, np.newaxis])
            gradImg = grad_spatial2d(torch.mean(im, dim=0)[:, :, np.newaxis])
            for (gZ, gI) in zip(gradZ, gradImg):
                loss += (self.opt.gz_gi_loss *
                         torch.mean(torch.abs(torch.abs(gZ) -
                                              torch.abs(gI))))

        # Store the normalized depth in the rendered data
        rendered_data.append(im)
        rendered_data_depth.append(im_d)
        rendered_data_cond.append(self.scene['camera']['eye'])
        scenes.append(self.scene)

    rendered_data = torch.stack(rendered_data)
    rendered_data_depth = torch.stack(rendered_data_depth)

    return rendered_data, rendered_data_depth, loss / self.opt.batchSize
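# NOTE: Minimal sketch of the gz_gi_loss idea from render_batch above:
# penalize a mismatch between depth-gradient and image-gradient magnitudes.
# The finite difference below is a hand-rolled stand-in for grad_spatial2d;
# names and shapes here are illustrative, not the project's API.
def _sketch_depth_image_gradient_penalty():
    """L1 distance between |grad z| and |grad I|, per gradient direction."""
    import torch

    def finite_diff_2d(img):
        # Forward differences of an HxWx1 tensor along x and y.
        dx = img[:, 1:, :] - img[:, :-1, :]
        dy = img[1:, :, :] - img[:-1, :, :]
        return dx, dy

    z_map = torch.rand(8, 8, 1)  # depth (or z) channel
    i_map = torch.rand(8, 8, 1)  # mean image intensity

    loss = 0.0
    for gZ, gI in zip(finite_diff_2d(z_map), finite_diff_2d(i_map)):
        # Mirror the gz_gi_loss term: edges in depth should line up with
        # edges in the image, in either gradient direction.
        loss = loss + torch.mean(torch.abs(torch.abs(gZ) - torch.abs(gI)))
    print(loss)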