def render_batch(self, batch, batch_cond, light_pos):
    """Render a batch of splats and return images, depths and the mean loss.

    For every sample, the splats derived from ``batch[idx]`` are written into
    ``self.scene``, the camera eye and light 0 are positioned, the scene is
    rendered with ``render_splats_along_ray`` and a weighted sum of geometry
    regularisers (weights taken from ``self.opt``) is computed.

    :param batch: Indexable of per-sample splat tensors. Column 0 of
        ``batch[idx]`` drives the splat depth and columns ``1:`` are used as
        normals. NOTE(review): exact shape/dtype is not visible here -- confirm.
    :param batch_cond: ``None`` to sample camera eyes on a sphere; otherwise
        an indexable of per-sample camera eye tensors.
    :param light_pos: Indexable of per-sample light positions.
        ``light_pos[idx]`` is written into the first three components of
        light 0.
    :return: ``(rendered_data, rendered_data_depth, loss)``:
        - ``rendered_data``: the stacked channels-first images (the depth map
          when ``opt.render_img_nc == 1``);
        - ``rendered_data_depth``: the stacked depth maps with a leading
          channel axis;
        - ``loss``: the per-sample loss averaged over the batch.
    """
    batch_size = batch.size()[0]

    # Generate camera positions on a sphere (only needed when no camera
    # conditioning is supplied).
    cam_pos = None
    if batch_cond is None:
        if self.opt.full_sphere_sampling:
            # TODO: deg2rad!! (only `angle` is converted on this path)
            cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                theta_range=self.opt.theta, phi_range=self.opt.phi)
        else:
            # TODO: deg2rad!! (only theta/phi are converted on this path)
            cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=self.opt.angle,
                theta_range=np.deg2rad(self.opt.theta),
                phi_range=np.deg2rad(self.opt.phi))

    rendered_data = []
    rendered_data_depth = []
    inpath = self.opt.vis_images + '/'
    inpath_xyz = self.opt.vis_xyz + '/'
    z_min = self.scene['camera']['focal_length']
    z_max = z_min + 3
    eps = 1e-3  # guards the min-max rescaling against a zero depth range

    # TODO (fmannan): Move this in init. This only needs to be done once!
    # Keep only disk (splat) primitives in the rendering scene.
    if 'sphere' in self.scene['objects']:
        del self.scene['objects']['sphere']
    if 'triangle' in self.scene['objects']:
        del self.scene['objects']['triangle']
    if 'disk' not in self.scene['objects']:
        self.scene['objects'] = {
            'disk': {'pos': None, 'normal': None, 'material_idx': None}}
    lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
    self.scene['camera']['at'] = tch_var_f(lookat)
    self.scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(self.opt.splats_img_size * self.opt.splats_img_size))

    # Sum of per-sample losses; averaged over the batch on return.
    loss = 0.0
    for idx in range(batch_size):
        # Splat depth is at least z_min; optionally min-max rescaled into
        # [z_min, z_max]. Positions are stored as negative depth.
        z = F.relu(-batch[idx][:, 0]) + z_min
        if self.opt.rescaled:
            z = ((z - z.min()) / (z.max() - z.min() + eps) *
                 (z_max - z_min) + z_min)
        pos = -z
        normals = batch[idx][:, 1:]

        self.scene['objects']['disk']['pos'] = pos
        # Normal estimation network and est_normals don't go together
        self.scene['objects']['disk']['normal'] = (
            normals if self.opt.est_normals is False else None)

        # Set camera position (the first one for every sample in same_view mode)
        cam_idx = 0 if self.opt.same_view else idx
        if batch_cond is None:
            self.scene['camera']['eye'] = tch_var_f(cam_pos[cam_idx])
        else:
            self.scene['camera']['eye'] = batch_cond[cam_idx]

        self.scene['lights']['pos'][0, :3] = tch_var_f(light_pos[idx])

        # Render scene
        res = render_splats_along_ray(self.scene,
                                      samples=self.opt.pixel_samples,
                                      normal_estimation_method='plane')
        world_tform = cam_to_world(res['pos'].view((-1, 3)),
                                   res['normal'].view((-1, 3)),
                                   self.scene['camera'])

        # Get rendered output
        res_pos = res['pos'].contiguous()
        res_pos_2D = res_pos.view(res['image'].shape)
        res_normal = res['normal']

        # Geometry regularisers on the rendered (supersampled) output.
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)  # TODO: MN
        z_pos = res_pos[..., 2]
        # Penalise |z| outside [z_min, z_max].
        z_loss = torch.mean((2 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                            (2 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_loss = spatial_3x3(res_pos_2D)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() +
                                 res_pos[..., 2].var())
        # Inverse positional variance: minimising it spreads the splats out.
        spatial_var_loss = 1 / (spatial_var + 1e-4)

        sample_loss = (
            self.opt.zloss * z_loss +
            self.opt.unit_normalloss * unit_normal_loss +
            self.opt.normal_consistency_loss_weight * z_norm_loss +
            self.opt.spatial_var_loss_weight * spatial_var_loss +
            self.opt.grad_img_depth_loss * image_depth_consistency_loss +
            self.opt.spatial_loss_weight * spatial_loss)

        # Output image: the depth map (1 channel) or the render permuted with
        # (2, 0, 1) -- assumes res['image'] is HWC. im_d is always defined.
        depth = res['depth']
        im_d = depth.unsqueeze(0)
        if self.opt.render_img_nc == 1:
            im = im_d
        else:
            im = res['image'].permute(2, 0, 1)
        H, W = im.shape[1:]

        if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
            target_normalmap_img_ = get_normalmap_image(
                get_data(res_normal).reshape((H, W, 3)))
            target_worldnormalmap_img_ = get_normalmap_image(
                get_data(world_tform['normal']).reshape((H, W, 3)))
            img_prefix = inpath + str(self.iterationa_no)
            imsave(img_prefix + 'normalmap_{:05d}.png'.format(idx),
                   target_normalmap_img_)
            imsave(img_prefix + 'depthmap_{:05d}.png'.format(idx),
                   get_data(res['depth']))
            imsave(img_prefix + 'world_normalmap_{:05d}.png'.format(idx),
                   target_worldnormalmap_img_)

        if self.iterationa_no % 1000 == 0:
            # These three files are overwritten for every sample, so only the
            # last sample of the batch survives on disk.
            np.save(inpath_xyz + "pos.npy", get_data(res['pos']))
            np.save(inpath_xyz + "im.npy", get_data(res['image']))
            np.save(inpath_xyz + "depth.npy", get_data(res['depth']))
            xyz_prefix = inpath_xyz + str(self.iterationa_no)
            # Save xyz file (renderer output, before cam_to_world)
            save_xyz(xyz_prefix + 'withnormal_{:05d}.xyz'.format(idx),
                     pos=get_data(res['pos']),
                     normal=get_data(res['normal']))
            # Save xyz file in world coordinates
            save_xyz(xyz_prefix + 'withnormal_world_{:05d}.xyz'.format(idx),
                     pos=get_data(world_tform['pos']),
                     normal=get_data(world_tform['normal']))

        if self.opt.gz_gi_loss is not None and self.opt.gz_gi_loss > 0:
            # Match gradient magnitudes of the depth (z) and intensity images.
            gradZ = grad_spatial2d(res_pos_2D[:, :, 2][:, :, np.newaxis])
            gradImg = grad_spatial2d(
                torch.mean(im, dim=0)[:, :, np.newaxis])
            for gZ, gI in zip(gradZ, gradImg):
                sample_loss = sample_loss + self.opt.gz_gi_loss * torch.mean(
                    torch.abs(torch.abs(gZ) - torch.abs(gI)))

        loss = loss + sample_loss

        # Store normalized depth into the data
        rendered_data.append(im)
        rendered_data_depth.append(im_d)

    rendered_data = torch.stack(rendered_data)
    rendered_data_depth = torch.stack(rendered_data_depth)
    return rendered_data, rendered_data_depth, loss / batch_size
def get_real_samples(self):
    """Render a batch of "real" samples and load them into the input buffers.

    Camera positions are resampled unless ``opt.same_view`` is set. Light
    positions are always resampled. One scene per batch element is rendered
    from either a mesh (``opt.use_mesh``) or a set of splats.

    For every sample, the camera/light/image/depth/normal data are dumped
    under ``opt.vis_input``. The stacked images, depth maps, normal-map
    images and camera eyes are then copied into ``self.input*`` and wrapped
    in ``self.inputv*``. ``self.label`` is filled with ``self.real_label``.
    """
    # Define the camera poses
    if not self.opt.same_view:
        if self.opt.full_sphere_sampling:
            self.cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                theta_range=self.opt.theta, phi_range=self.opt.phi)
        else:
            self.cam_pos = uniform_sample_sphere(
                radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                axis=self.opt.axis, angle=self.opt.angle,
                theta_range=np.deg2rad(self.opt.theta),
                phi_range=np.deg2rad(self.opt.phi))

    # Define the light poses
    if self.opt.full_sphere_sampling_light:
        self.light_pos1 = uniform_sample_sphere(
            radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
            axis=self.opt.axis, angle=np.deg2rad(44),
            theta_range=self.opt.theta, phi_range=self.opt.phi)
    else:
        print("inbox")
        light_eps = 0.15
        self.light_pos1 = (np.random.rand(self.opt.batchSize, 3) *
                           self.opt.cam_dist + light_eps)
        self.light_pos2 = (np.random.rand(self.opt.batchSize, 3) *
                           self.opt.cam_dist + light_eps)
    # TODO: deg2rad in all the angles????

    # Create a splats rendering scene
    large_scene = create_scene(self.opt.width, self.opt.height,
                               self.opt.fovy, self.opt.focal_length,
                               self.opt.n_splats)
    lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
    large_scene['camera']['at'] = tch_var_f(lookat)

    # The splat path indexes one batch of samples by `idx`, so fetch it once.
    # Previously `samples` was only assigned on the mesh path, which made the
    # splat path fail with UnboundLocalError.
    splat_samples = None if self.opt.use_mesh else self.get_samples()

    # Render scenes
    data, data_depth, data_normal, data_cond = [], [], [], []
    inpath = self.opt.vis_images + '/'
    inpath2 = self.opt.vis_input + '/'
    for idx in range(self.opt.batchSize):
        # Save the splats (or mesh) into the rendering scene
        if self.opt.use_mesh:
            if 'sphere' in large_scene['objects']:
                del large_scene['objects']['sphere']
            if 'disk' in large_scene['objects']:
                del large_scene['objects']['disk']
            if 'triangle' not in large_scene['objects']:
                large_scene['objects'] = {
                    'triangle': {'face': None, 'normal': None,
                                 'material_idx': None}}
            mesh = self.get_samples()['mesh']
            triangle = large_scene['objects']['triangle']
            triangle['material_idx'] = tch_var_l(
                np.zeros(mesh['face'][0].shape[0], dtype=int).tolist())
            triangle['face'] = Variable(mesh['face'][0].cuda(),
                                        requires_grad=False)
            triangle['normal'] = Variable(mesh['normal'][0].cuda(),
                                          requires_grad=False)
        else:
            if 'sphere' in large_scene['objects']:
                del large_scene['objects']['sphere']
            if 'triangle' in large_scene['objects']:
                del large_scene['objects']['triangle']
            if 'disk' not in large_scene['objects']:
                large_scene['objects'] = {
                    'disk': {'pos': None, 'normal': None,
                             'material_idx': None}}
            disk = large_scene['objects']['disk']
            disk['radius'] = tch_var_f(
                np.ones(self.opt.n_splats) * self.opt.splats_radius)
            disk['material_idx'] = tch_var_l(
                np.zeros(self.opt.n_splats, dtype=int).tolist())
            disk['pos'] = Variable(splat_samples['splats']['pos'][idx].cuda(),
                                   requires_grad=False)
            disk['normal'] = Variable(
                splat_samples['splats']['normal'][idx].cuda(),
                requires_grad=False)

        # Set camera position (the first one for every sample in same_view mode)
        cam_idx = 0 if self.opt.same_view else idx
        large_scene['camera']['eye'] = tch_var_f(self.cam_pos[cam_idx])
        large_scene['lights']['pos'][0, :3] = tch_var_f(
            self.light_pos1[idx])

        # Render scene
        res = render(large_scene,
                     norm_depth_image_only=self.opt.norm_depth_image_only,
                     double_sided=True, use_quartic=self.opt.use_quartic)

        # Get rendered output: the depth map (1 channel) or the render
        # permuted with (2, 0, 1). Both `im` and `im_d` are always defined.
        depth = res['depth']
        im_d = depth.unsqueeze(0)
        if self.opt.render_img_nc == 1:
            im = im_d
        else:
            im = res['image'].permute(2, 0, 1)
        target_normalmap_img_ = get_normalmap_image(get_data(res['normal']))
        im_n = tch_var_f(target_normalmap_img_).view(
            im.shape[1], im.shape[2], 3).permute(2, 0, 1)

        # Dump this sample's camera, light and rendered buffers.
        out_prefix = (inpath2 + str(self.iterationa_no) + "_" +
                      str(self.critic_iter))
        with open(out_prefix + 'input_{:05d}.txt'.format(idx), "w") as text_file:
            text_file.write('%s\n' % (str(large_scene['camera']['eye'].data)))
        # Save the camera position actually used for rendering.
        np.save(out_prefix + 'input_{:05d}.npy'.format(idx),
                self.cam_pos[cam_idx])
        np.save(out_prefix + 'input_light{:05d}.npy'.format(idx),
                self.light_pos1[idx])
        np.save(out_prefix + 'input_im{:05d}.npy'.format(idx),
                get_data(res['image']))
        np.save(out_prefix + 'input_depth{:05d}.npy'.format(idx),
                get_data(res['depth']))
        np.save(out_prefix + 'input_normal{:05d}.npy'.format(idx),
                get_data(res['normal']))

        if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
            imsave((inpath + str(self.iterationa_no) +
                    'real_normalmap_{:05d}.png'.format(idx)),
                   target_normalmap_img_)
            imsave((inpath + str(self.iterationa_no) +
                    'real_depth_{:05d}.png'.format(idx)), get_data(depth))

        data.append(im)
        data_depth.append(im_d)
        data_normal.append(im_n)
        data_cond.append(large_scene['camera']['eye'])

    # Stack real samples
    real_samples = torch.stack(data)
    real_samples_depth = torch.stack(data_depth)
    real_samples_normal = torch.stack(data_normal)
    real_samples_cond = torch.stack(data_cond)
    self.batch_size = real_samples.size(0)
    if not self.opt.no_cuda:
        real_samples = real_samples.cuda()
        real_samples_depth = real_samples_depth.cuda()
        real_samples_normal = real_samples_normal.cuda()
        real_samples_cond = real_samples_cond.cuda()

    # Set input/output variables
    self.input.resize_as_(real_samples.data).copy_(real_samples.data)
    self.input_depth.resize_as_(real_samples_depth.data).copy_(
        real_samples_depth.data)
    self.input_normal.resize_as_(real_samples_normal.data).copy_(
        real_samples_normal.data)
    self.input_cond.resize_as_(real_samples_cond.data).copy_(
        real_samples_cond.data)
    self.label.resize_(self.batch_size).fill_(self.real_label)
    # TODO: Remove Variables
    self.inputv = Variable(self.input)
    self.inputv_depth = Variable(self.input_depth)
    self.inputv_normal = Variable(self.input_normal)
    self.inputv_cond = Variable(self.input_cond)
    self.labelv = Variable(self.label)
def optimize_splats_along_ray_shadow_with_normalest_test(
        out_dir, width, height, max_iter=100, lr=1e-3, scale=10, shadow=True,
        vis_only=False, samples=1, est_normals=False,
        b_generate_normals=False, print_interval=10, imsave_interval=10,
        xyz_save_interval=100):
    """Check that the differentiable renderer can optimise splats rendered along rays.

    A target image of ``SCENE_SPHERE_HALFBOX_0`` is rendered. The function
    then sanity-checks the world <-> camera transforms and splat rendering
    against that target. Finally it optimises splat depth and normal angles
    (and, with shadows, per-light visibility) with Adam so that the splat
    rendering matches the target. Figures and ``.xyz`` dumps are written to
    ``out_dir``.

    :param out_dir: Output directory (created if missing).
    :param width: Target render width in pixels.
    :param height: Target render height in pixels.
    :param max_iter: Number of optimisation iterations.
    :param lr: Adam learning rate.
    :param scale: Factor applied to both output and target images inside the
        L1 image loss.
    :param shadow: Render the target with shadows. When not ``vis_only``,
        also optimise the light visibility.
    :param vis_only: Optimise only the light visibility, using the target's
        camera-space z as splat depth (requires ``shadow``).
    :param samples: One splat per ``samples x samples`` pixel block; also
        forwarded to ``render_splats_along_ray``.
    :param est_normals: If True, no normals are passed to the renderer.
    :param b_generate_normals: Produce normals from depth with
        ``NEstNetAffine`` (forces ``est_normals`` to False).
    :param print_interval: Iteration period for logging.
    :param imsave_interval: Iteration period for saving figures.
    :param xyz_save_interval: Iteration period for point-cloud dumps.
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX_0

    os.makedirs(out_dir, exist_ok=True)

    # Work on a private copy: the scene is mutated below, and
    # SCENE_SPHERE_HALFBOX_0 is a shared module-level constant.
    scene = copy.deepcopy(SCENE_SPHERE_HALFBOX_0)
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f([2, 1, 2, 1])
    scene['camera']['at'] = tch_var_f([0, 0.8, 0, 1])
    scene['lights']['attenuation'] = tch_var_f([
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
    ])
    scene['materials']['coeffs'] = tch_var_f([
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 0.2, 8.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
    ])

    target_res = render(scene, tiled=True, shadow=shadow)
    # The target is a constant of the optimisation. The original
    # `require_grad = False` was a misspelled no-op.
    target_im = normalize_maxmin(target_res['image']).detach()
    target_im_ = get_data(target_im)
    target_pos_ = get_data(target_res['pos'])
    target_normal_ = get_data(target_res['normal'])
    target_normalmap_img_ = get_normalmap_image(target_normal_)
    target_depth_ = get_data(target_res['depth'])
    print('[z_min, z_max] = [%f, %f]' % (np.min(target_pos_[..., 2]),
                                         np.max(target_pos_[..., 2])))
    print('[depth_min, depth_max] = [%f, %f]' % (np.min(target_depth_),
                                                 np.max(target_depth_)))

    # world -> cam -> render_splats_along_ray
    cc_tform = world_to_cam(target_res['pos'].view((-1, 3)),
                            target_res['normal'].view((-1, 3)),
                            scene['camera'])
    wc_cc_tform = cam_to_world(cc_tform['pos'], cc_tform['normal'],
                               scene['camera'])

    # Check normal estimation in camera space
    pos_cc = cc_tform['pos'][:, :3].contiguous().view(target_im.shape)
    normal_cc = cc_tform['normal'][:, :3].contiguous().view(target_im.shape)
    plane_fit_est = estimate_surface_normals_plane_fit(pos_cc, None)
    normal_cc_normalmap = get_normalmap_image(get_data(normal_cc))
    plane_fit_est_normalmap = get_normalmap_image(get_data(plane_fit_est))

    # Round trip (world -> cam -> world) should reproduce the target.
    pos_diff = torch.abs(wc_cc_tform['pos'][:, :3] -
                         target_res['pos'].view((-1, 3)))
    mean_pos_diff = torch.mean(pos_diff)
    normal_diff = torch.abs(wc_cc_tform['normal'][:, :3] -
                            target_res['normal'].view(-1, 3))
    mean_normal_diff = torch.mean(normal_diff)
    print('mean_pos_diff', mean_pos_diff, 'mean_normal_diff',
          mean_normal_diff)

    wc_cc_normal = wc_cc_tform['normal'].view(target_im_.shape)
    wc_cc_normal_img = get_normalmap_image(get_data(wc_cc_normal))

    # Render the target's own camera-space points as splats (no shadows) and
    # compare against the shadow-free target.
    material_idx = tch_var_l(np.ones(cc_tform['pos'].shape[0]) * 3)
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0],
                 cc_tform['pos'].shape[0])))
    input_scene['objects'] = {
        'disk': {
            'pos': cc_tform['pos'],
            'normal': cc_tform['normal'],
            'material_idx': material_idx,
            'light_vis': light_vis,
        }
    }
    target_res_noshadow = render(scene, tiled=True, shadow=False)
    res = render_splats_along_ray(input_scene)
    test_img_ = get_data(normalize_maxmin(res['image']))
    test_depth_ = get_data(res['depth'])
    test_normal_ = get_data(res['normal']).reshape(test_img_.shape)
    test_normalmap_ = get_normalmap_image(test_normal_)
    im_diff = np.abs(test_img_ -
                     get_data(normalize_maxmin(target_res_noshadow['image'])))
    print('mean image diff: {}'.format(np.mean(im_diff)))

    def save_figure(img, title, file_name):
        """Show `img` in a new figure titled `title` and save it in out_dir."""
        plt.figure()
        plt.imshow(img, interpolation='none')
        plt.title(title)
        plt.savefig(out_dir + '/' + file_name)

    #### PLOT
    plt.ion()
    save_figure(test_img_, 'Test Image', 'test_img.png')
    save_figure(test_depth_, 'Test Depth', 'test_depth.png')
    save_figure(test_normalmap_, 'Test Normals', 'test_normal.png')
    ####

    criterion = nn.L1Loss()  # nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    save_figure(target_im_, 'Target Image', 'target.png')
    save_figure(target_normalmap_img_, 'Normals', 'normal.png')
    save_figure(wc_cc_normal_img, 'WC_CC Normals', 'wc_cc_normal.png')
    save_figure(normal_cc_normalmap, 'Normal CC GT', 'normal_cc.png')
    save_figure(plane_fit_est_normalmap, 'Plane fit CC', 'est_normal_cc.png')

    plt.figure()
    plt.subplot(121)
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.subplot(122)
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/normal_and_estnormal_cc_comparison.png')

    # Optimisation scene: one splat per `samples x samples` pixel block.
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    input_scene['camera']['viewport'] = [
        0, 0, int(width / samples), int(height / samples)
    ]

    num_splats = int(width * height / (samples * samples))
    z_min = scene['camera']['focal_length']
    z_max = 3

    # All splats start at the middle of [z_min, z_max] (stored as -depth).
    z = -tch_var_f(np.ones(num_splats) * (z_min + z_max) / 2)
    z.requires_grad = True
    normal_angles = tch_var_f(np.random.rand(num_splats, 2))
    normal_angles.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)
    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0], num_splats)))
    light_vis.requires_grad = True

    if vis_only:
        assert shadow is True
        opt_vars = [light_vis]
        z = cc_tform['pos'][:, 2]
        # FIXME: sph2cart
        # normals = cc_tform['normal']
    else:
        opt_vars = [z, normal_angles]
        if shadow:
            opt_vars += [light_vis]

    optimizer = optim.Adam(opt_vars, lr=lr)
    lr_scheduler = StepLR(optimizer, step_size=10000, gamma=0.8)

    h0 = plt.figure()
    h1 = plt.figure()
    h2 = plt.figure()
    h3 = plt.figure()
    h4 = plt.figure()

    gs1 = gridspec.GridSpec(3, 3)
    gs1.update(wspace=0.0025, hspace=0.02)

    # Three options for the z_norm_consistency weight schedule:
    # 1. start after N iterations
    # 2. start at the beginning and decay
    # 3. start after N iterations and decay to 0
    def decay_fn(x):
        """Inverse (1/x) decay with scale 100."""
        return 100 / (x + 1e-6)

    spatial_var_loss_weight = 10.0  # 0.0
    grad_img_depth_loss_weight = 1.0
    z_norm_weight_init = 1  # 1e-5
    z_norm_activate_iter = 0  # 1000
    loss_per_iter = []

    if b_generate_normals:
        est_normals = False
        # NOTE(review): the network's parameters are not part of `opt_vars`,
        # so the optimiser never updates them -- confirm this is intended.
        normal_est_network = NEstNetAffine(kernel_size=3, sph=False)
        print(normal_est_network)
        normal_est_network.cuda()

    for it in range(max_iter):
        # Keep depth at or beyond the focal plane: zz <= -z_min.
        zz = -F.relu(-z) - z_min  # torch.clamp(z, -z_max, -z_min)
        if b_generate_normals:
            normals = generate_normals(zz, scene['camera'],
                                       normal_est_network)
        elif not est_normals:
            # phi in [0, 2*pi], theta in [0, pi/2]
            phi = F.sigmoid(normal_angles[:, 0]) * 2 * np.pi
            theta = F.sigmoid(normal_angles[:, 1]) * np.pi / 2
            normals = sph2cart_unit(torch.stack((phi, theta), dim=1))

        pos = zz
        input_scene['objects'] = {
            'disk': {
                'pos': pos,
                'normal': normalize(normals) if not est_normals else None,
                'material_idx': material_idx,
                'light_vis': torch.sigmoid(light_vis),
            }
        }
        res = render_splats_along_ray(input_scene, samples=samples,
                                      normal_estimation_method='plane')
        res_pos = res['pos']
        res_normal = res['normal']

        spatial_loss = spatial_3x3(res_pos)
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)
        normal_away_from_cam_loss = away_from_camera_penalty(res_pos,
                                                             res_normal)
        z_pos = res_pos[..., 2]
        z_loss = torch.mean((10 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                            (10 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() +
                                 res_pos[..., 2].var())
        spatial_var_loss = 1 / (spatial_var + 1e-4)
        im_out = normalize_maxmin(res['image'])
        res_depth_ = get_data(res['depth'])

        optimizer.zero_grad()
        z_norm_weight = (z_norm_weight_init * float(it > z_norm_activate_iter) *
                         decay_fn(it - z_norm_activate_iter))
        # normal_away_from_cam_loss and spatial_loss are only logged; they
        # are not part of the objective.
        loss = (criterion(scale * im_out, scale * target_im) + z_loss +
                unit_normal_loss +
                z_norm_weight * z_norm_loss +
                spatial_var_loss_weight * spatial_var_loss +
                grad_img_depth_loss_weight * image_depth_consistency_loss)

        im_out_ = get_data(im_out)
        im_out_normal_ = get_data(res['normal'])
        pos_out_ = get_data(res['pos'])
        loss_ = get_data(loss)
        z_loss_ = get_data(z_loss)
        z_norm_loss_ = get_data(z_norm_loss)
        spatial_loss_ = get_data(spatial_loss)
        spatial_var_loss_ = get_data(spatial_var_loss)
        unit_normal_loss_ = get_data(unit_normal_loss)
        normal_away_from_cam_loss_ = get_data(normal_away_from_cam_loss)
        normals_ = get_data(res_normal)
        image_depth_consistency_loss_ = get_data(image_depth_consistency_loss)
        loss_per_iter.append(loss_)

        if it == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if it % print_interval == 0 or it == max_iter - 1:
            z_ = get_data(z)
            z__ = pos_out_[..., 2]
            print('%d. loss= %f nloss=%f z_loss=%f [%f, %f] [%f, %f], z_normal_loss: %f,'
                  ' spatial_var_loss: %f, normal_away_loss: %f'
                  ' nz_range: [%f, %f], spatial_loss: %f, imd_loss: %f' %
                  (it, loss_, unit_normal_loss_, z_loss_, np.min(z_),
                   np.max(z_), np.min(z__), np.max(z__), z_norm_loss_,
                   spatial_var_loss_, normal_away_from_cam_loss_,
                   normals_[..., 2].min(), normals_[..., 2].max(),
                   spatial_loss_, image_depth_consistency_loss_))

        if it % xyz_save_interval == 0 or it == max_iter - 1:
            save_xyz(out_dir + '/res_{:05d}.xyz'.format(it),
                     get_data(res_pos), get_data(res_normal))

        if it % imsave_interval == 0 or it == max_iter - 1:
            z_ = get_data(z)
            plt.figure(h4.number)
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (it, loss_, np.min(z_), np.max(z_)))
            plt.subplot(121)
            plt.imshow(im_out_, interpolation='none')
            plt.title('Output')
            plt.subplot(122)
            plt.imshow(target_im_, interpolation='none')
            plt.title('Ground truth')
            plt.savefig(out_dir + '/fig_im_gt_loss_%05d.png' % it)

            plt.figure(h1.number, figsize=(4, 4))
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (it, loss_, np.min(z_), np.max(z_)))
            # Row 1: output image / normals / depth
            plt.subplot(gs1[0])
            plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.subplot(gs1[1])
            plt.axis('off')
            plt.imshow(get_normalmap_image(im_out_normal_),
                       interpolation='none')
            ax = plt.subplot(gs1[2])
            plt.axis('off')
            im_tmp = ax.imshow(res_depth_, interpolation='none')
            # create an axes on the right side of ax. The width of cax will
            # be 5% of ax and the padding between cax and ax will be fixed at
            # 0.05 inch.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            # Row 2: target image / test normals / test depth
            plt.subplot(gs1[3])
            plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.subplot(gs1[4])
            plt.axis('off')
            plt.imshow(test_normalmap_, interpolation='none')
            ax = plt.subplot(gs1[5])
            plt.axis('off')
            im_tmp = ax.imshow(test_depth_, interpolation='none')
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            # Row 3: visibility of up to three lights
            W, H = input_scene['camera']['viewport'][2:]
            light_vis_ = get_data(torch.sigmoid(light_vis))
            plt.subplot(gs1[6])
            plt.axis('off')
            plt.imshow(light_vis_[0].reshape((H, W)), interpolation='none')
            if light_vis_.shape[0] > 1:
                plt.subplot(gs1[7])
                plt.axis('off')
                plt.imshow(light_vis_[1].reshape((H, W)),
                           interpolation='none')
            if light_vis_.shape[0] > 2:
                plt.subplot(gs1[8])
                plt.axis('off')
                plt.imshow(light_vis_[2].reshape((H, W)),
                           interpolation='none')
            plt.savefig(out_dir + '/fig_%05d.png' % it)

            plt.figure(h2.number)
            plt.clf()
            plt.imshow(res_depth_)
            plt.colorbar()
            plt.savefig(out_dir + '/fig_depth_%05d.png' % it)

            plt.figure(h3.number)
            plt.clf()
            plt.imshow(z_.reshape(H, W))
            plt.colorbar()
            plt.savefig(out_dir + '/fig_z_%05d.png' % it)

        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler must step after the optimizer.
        lr_scheduler.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()