Esempio n. 1
0
    def render_batch(self, batch, batch_cond, light_pos):
        """Render a batch of splats and compute the averaged regularization loss.

        For every sample, the splat depths/normals are written into
        ``self.scene``, the camera and the first light are positioned, the
        scene is rendered with ``render_splats_along_ray`` and a weighted sum
        of geometric regularizers (weights taken from ``self.opt``) is added
        to the batch loss.

        Args:
            batch: Per-sample splat parameters; ``batch[idx][:, 0]`` drives
                the splat depth and ``batch[idx][:, 1:]`` holds the normals.
            batch_cond: Per-sample camera eye positions, or ``None`` to sample
                camera positions on a sphere from the ``self.opt`` settings.
            light_pos: Per-sample light positions; ``light_pos[idx]`` is
                written into the first three components of the first light.

        Returns:
            Tuple ``(rendered_data, rendered_data_depth, loss)``:
            the stacked rendered images (the depth map with a leading channel
            axis when ``opt.render_img_nc == 1``), the stacked depth maps with
            a leading channel axis, and the mean of the per-sample losses.
        """
        batch_size = batch.size()[0]

        # Generate camera positions on a sphere
        if batch_cond is None:
            if self.opt.full_sphere_sampling:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta, phi_range=self.opt.phi)
            else:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
            # TODO: deg2rad!! The two branches convert different arguments;
            # confirm which units uniform_sample_sphere expects.

        rendered_data = []
        rendered_data_depth = []
        inpath = self.opt.vis_images + '/'
        inpath_xyz = self.opt.vis_xyz + '/'
        z_min = self.scene['camera']['focal_length']
        z_max = z_min + 3
        eps = 1e-3  # keeps the depth rescaling finite when all depths match

        # TODO (fmannan): Move this in init. This only needs to be done once!
        # Keep only the disk (splat) primitive in the rendering scene.
        if 'sphere' in self.scene['objects']:
            del self.scene['objects']['sphere']
        if 'triangle' in self.scene['objects']:
            del self.scene['objects']['triangle']
        if 'disk' not in self.scene['objects']:
            self.scene['objects'] = {'disk': {'pos': None, 'normal': None,
                                              'material_idx': None}}
        lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
        self.scene['camera']['at'] = tch_var_f(lookat)
        self.scene['objects']['disk']['material_idx'] = tch_var_l(
            np.zeros(self.opt.splats_img_size * self.opt.splats_img_size))

        # Sum of the per-sample losses; averaged over the batch on return.
        loss = 0.0
        for idx in range(batch_size):
            # Splat depths: a non-negative (ReLU) offset in front of z_min.
            z = F.relu(-batch[idx][:, 0]) + z_min
            if self.opt.rescaled:
                # Min-max rescale the depths into [z_min, z_max].
                z = ((z - z.min()) / (z.max() - z.min() + eps) *
                     (z_max - z_min) + z_min)
            pos = -z
            normals = batch[idx][:, 1:]

            self.scene['objects']['disk']['pos'] = pos
            # Normal estimation network and est_normals don't go together
            self.scene['objects']['disk']['normal'] = (
                normals if self.opt.est_normals is False else None)

            # Set camera position
            if batch_cond is None:
                eye = cam_pos[0] if self.opt.same_view else cam_pos[idx]
                self.scene['camera']['eye'] = tch_var_f(eye)
            else:
                self.scene['camera']['eye'] = (
                    batch_cond[0] if self.opt.same_view else batch_cond[idx])

            self.scene['lights']['pos'][0, :3] = tch_var_f(light_pos[idx])

            # Render scene
            res = render_splats_along_ray(self.scene,
                                          samples=self.opt.pixel_samples,
                                          normal_estimation_method='plane')
            world_tform = cam_to_world(res['pos'].view((-1, 3)),
                                       res['normal'].view((-1, 3)),
                                       self.scene['camera'])

            res_pos = res['pos'].contiguous()
            res_pos_2D = res_pos.view(res['image'].shape)
            res_normal = res['normal']
            depth = res['depth']
            im_d = depth.unsqueeze(0)
            if self.opt.render_img_nc == 1:
                im = im_d
            else:
                im = res['image'].permute(2, 0, 1)

            # Geometric regularizers.
            image_depth_consistency_loss = depth_rgb_gradient_consistency(
                res['image'], depth)
            unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)  # TODO: MN
            z_pos = res_pos[..., 2]
            # Penalize |z| falling outside [z_min, z_max].
            z_loss = torch.mean((2 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                                (2 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
            z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
            spatial_loss = spatial_3x3(res_pos_2D)
            spatial_var = torch.mean(res_pos[..., 0].var() +
                                     res_pos[..., 1].var() +
                                     res_pos[..., 2].var())
            # Inverse variance: grows as the rendered positions' variance shrinks.
            spatial_var_loss = 1 / (spatial_var + 1e-4)

            sample_loss = (
                self.opt.zloss * z_loss +
                self.opt.unit_normalloss * unit_normal_loss +
                self.opt.normal_consistency_loss_weight * z_norm_loss +
                self.opt.spatial_var_loss_weight * spatial_var_loss +
                self.opt.grad_img_depth_loss * image_depth_consistency_loss +
                self.opt.spatial_loss_weight * spatial_loss)
            if self.opt.gz_gi_loss is not None and self.opt.gz_gi_loss > 0:
                # Compare depth-gradient magnitudes with those of the
                # channel-averaged image.
                gradZ = grad_spatial2d(res_pos_2D[:, :, 2][:, :, np.newaxis])
                gradImg = grad_spatial2d(
                    torch.mean(im, dim=0)[:, :, np.newaxis])
                for gZ, gI in zip(gradZ, gradImg):
                    sample_loss = sample_loss + self.opt.gz_gi_loss * torch.mean(
                        torch.abs(torch.abs(gZ) - torch.abs(gI)))
            # Out-of-place add so autograd never sees an in-place update.
            loss = loss + sample_loss

            if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
                # Normal maps are only needed for dumping, so build them here
                # (for both single- and multi-channel rendering).
                H, W = res['image'].shape[:2]
                normalmap_img = get_normalmap_image(
                    get_data(res_normal).reshape((H, W, 3)))
                world_normalmap_img = get_normalmap_image(
                    get_data(world_tform['normal']).reshape((H, W, 3)))
                img_prefix = inpath + str(self.iterationa_no)
                imsave(img_prefix + 'normalmap_{:05d}.png'.format(idx),
                       normalmap_img)
                imsave(img_prefix + 'depthmap_{:05d}.png'.format(idx),
                       get_data(depth))
                imsave(img_prefix + 'world_normalmap_{:05d}.png'.format(idx),
                       world_normalmap_img)
            if self.iterationa_no % 1000 == 0:
                # Raw render buffers; these fixed names are overwritten.
                np.save(inpath_xyz + 'pos.npy', get_data(res['pos']))
                np.save(inpath_xyz + 'im.npy', get_data(res['image']))
                np.save(inpath_xyz + 'depth.npy', get_data(depth))

                xyz_prefix = inpath_xyz + str(self.iterationa_no)
                # Save xyz file (camera coordinates)
                save_xyz(xyz_prefix + 'withnormal_{:05d}.xyz'.format(idx),
                         pos=get_data(res['pos']),
                         normal=get_data(res['normal']))
                # Save xyz file in world coordinates
                save_xyz(xyz_prefix + 'withnormal_world_{:05d}.xyz'.format(idx),
                         pos=get_data(world_tform['pos']),
                         normal=get_data(world_tform['normal']))

            rendered_data.append(im)
            rendered_data_depth.append(im_d)

        rendered_data = torch.stack(rendered_data)
        rendered_data_depth = torch.stack(rendered_data_depth)

        # Mean over the samples that were actually rendered.
        return rendered_data, rendered_data_depth, loss / batch_size
Esempio n. 2
0
    def get_real_samples(self):
        """Render a batch of real samples and load it into the input buffers.

        For each of ``opt.batchSize`` elements this places either the mesh
        returned by ``self.get_samples()`` (``opt.use_mesh``) or the splats
        into a scene built with ``create_scene``. It then positions the camera
        and the first light, renders, and dumps the render inputs/outputs
        under ``opt.vis_input``. Finally it copies the stacked images, depth
        maps, normal maps and camera eyes into ``self.input``,
        ``self.input_depth``, ``self.input_normal`` and ``self.input_cond``,
        and fills ``self.label`` with ``self.real_label``.

        Side effects: sets ``self.cam_pos`` (unless reused under
        ``opt.same_view``), ``self.light_pos1`` (and ``self.light_pos2`` when
        lights are sampled in a box), ``self.batch_size`` and the
        ``Variable`` wrappers ``self.inputv*`` / ``self.labelv``.
        """
        # Define the camera poses. With same_view the poses from a previous
        # call are reused; sample them if none exist yet.
        if not self.opt.same_view or getattr(self, 'cam_pos', None) is None:
            if self.opt.full_sphere_sampling:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta,
                    phi_range=self.opt.phi)
            else:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
        if self.opt.full_sphere_sampling_light:
            self.light_pos1 = uniform_sample_sphere(
                radius=self.opt.cam_dist,
                num_samples=self.opt.batchSize,
                axis=self.opt.axis,
                angle=np.deg2rad(44),
                theta_range=self.opt.theta,
                phi_range=self.opt.phi)
        else:
            print("inbox")
            light_eps = 0.15
            self.light_pos1 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps
            self.light_pos2 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps
        # TODO: deg2rad in all the angles????

        # Create a splats rendering scene
        large_scene = create_scene(self.opt.width, self.opt.height,
                                   self.opt.fovy, self.opt.focal_length,
                                   self.opt.n_splats)
        lookat = self.opt.at if self.opt.at is not None else [
            0.0, 0.0, 0.0, 1.0
        ]
        large_scene['camera']['at'] = tch_var_f(lookat)

        # NOTE(review): the splat branch indexes samples['splats'][...][idx],
        # so get_samples() presumably returns a whole batch; fetch it once.
        # The mesh branch fetches a new sample per element below.
        samples = None if self.opt.use_mesh else self.get_samples()

        # Render scenes
        data, data_depth, data_normal, data_cond = [], [], [], []
        inpath = self.opt.vis_images + '/'
        inpath2 = self.opt.vis_input + '/'
        for idx in range(self.opt.batchSize):
            # Save the geometry into the rendering scene
            if self.opt.use_mesh:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'disk' in large_scene['objects']:
                    del large_scene['objects']['disk']
                if 'triangle' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'triangle': {
                            'face': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }
                samples = self.get_samples()

                large_scene['objects']['triangle']['material_idx'] = tch_var_l(
                    np.zeros(samples['mesh']['face'][0].shape[0],
                             dtype=int).tolist())
                large_scene['objects']['triangle']['face'] = Variable(
                    samples['mesh']['face'][0].cuda(), requires_grad=False)
                large_scene['objects']['triangle']['normal'] = Variable(
                    samples['mesh']['normal'][0].cuda(), requires_grad=False)
            else:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'triangle' in large_scene['objects']:
                    del large_scene['objects']['triangle']
                if 'disk' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'disk': {
                            'pos': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }
                large_scene['objects']['disk']['radius'] = tch_var_f(
                    np.ones(self.opt.n_splats) * self.opt.splats_radius)
                large_scene['objects']['disk']['material_idx'] = tch_var_l(
                    np.zeros(self.opt.n_splats, dtype=int).tolist())
                large_scene['objects']['disk']['pos'] = Variable(
                    samples['splats']['pos'][idx].cuda(), requires_grad=False)
                large_scene['objects']['disk']['normal'] = Variable(
                    samples['splats']['normal'][idx].cuda(),
                    requires_grad=False)

            # Set camera position (a single shared pose under same_view)
            cam_idx = 0 if self.opt.same_view else idx
            large_scene['camera']['eye'] = tch_var_f(self.cam_pos[cam_idx])

            large_scene['lights']['pos'][0, :3] = tch_var_f(
                self.light_pos1[idx])

            # Render scene
            res = render(large_scene,
                         norm_depth_image_only=self.opt.norm_depth_image_only,
                         double_sided=True,
                         use_quartic=self.opt.use_quartic)

            # Get rendered output; the depth map doubles as the image in
            # single-channel mode.
            depth = res['depth']
            im_d = depth.unsqueeze(0)
            if self.opt.render_img_nc == 1:
                im = im_d
            else:
                im = res['image'].permute(2, 0, 1)
            target_normalmap_img_ = get_normalmap_image(get_data(res['normal']))
            im_n = tch_var_f(target_normalmap_img_).view(
                im.shape[1], im.shape[2], 3).permute(2, 0, 1)

            # Dump the render inputs and outputs for this element.
            prefix = (inpath2 + str(self.iterationa_no) + "_" +
                      str(self.critic_iter))
            with open(prefix + 'input_{:05d}.txt'.format(idx), "w") as text_file:
                text_file.write('%s\n' % (str(large_scene['camera']['eye'].data)))
            # Save the camera position that was actually used for rendering.
            np.save(prefix + 'input_{:05d}.npy'.format(idx),
                    self.cam_pos[cam_idx])
            np.save(prefix + 'input_light{:05d}.npy'.format(idx),
                    self.light_pos1[idx])
            np.save(prefix + 'input_im{:05d}.npy'.format(idx),
                    get_data(res['image']))
            np.save(prefix + 'input_depth{:05d}.npy'.format(idx),
                    get_data(res['depth']))
            np.save(prefix + 'input_normal{:05d}.npy'.format(idx),
                    get_data(res['normal']))

            if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
                imsave((inpath + str(self.iterationa_no) +
                        'real_normalmap_{:05d}.png'.format(idx)),
                       target_normalmap_img_)
                imsave((inpath + str(self.iterationa_no) +
                        'real_depth_{:05d}.png'.format(idx)), get_data(depth))
            data.append(im)
            data_depth.append(im_d)
            data_normal.append(im_n)
            data_cond.append(large_scene['camera']['eye'])
        # Stack real samples
        real_samples = torch.stack(data)
        real_samples_depth = torch.stack(data_depth)
        real_samples_normal = torch.stack(data_normal)
        real_samples_cond = torch.stack(data_cond)
        self.batch_size = real_samples.size(0)
        if not self.opt.no_cuda:
            real_samples = real_samples.cuda()
            real_samples_depth = real_samples_depth.cuda()
            real_samples_normal = real_samples_normal.cuda()
            real_samples_cond = real_samples_cond.cuda()

        # Set input/output variables
        self.input.resize_as_(real_samples.data).copy_(real_samples.data)
        self.input_depth.resize_as_(real_samples_depth.data).copy_(
            real_samples_depth.data)
        self.input_normal.resize_as_(real_samples_normal.data).copy_(
            real_samples_normal.data)
        self.input_cond.resize_as_(real_samples_cond.data).copy_(
            real_samples_cond.data)
        self.label.resize_(self.batch_size).fill_(self.real_label)
        # TODO: Remove Variables
        self.inputv = Variable(self.input)
        self.inputv_depth = Variable(self.input_depth)
        self.inputv_normal = Variable(self.input_normal)
        self.inputv_cond = Variable(self.input_cond)
        self.labelv = Variable(self.label)
Esempio n. 3
0
def optimize_splats_along_ray_shadow_with_normalest_test(
        out_dir,
        width,
        height,
        max_iter=100,
        lr=1e-3,
        scale=10,
        shadow=True,
        vis_only=False,
        samples=1,
        est_normals=False,
        b_generate_normals=False,
        print_interval=10,
        imsave_interval=10,
        xyz_save_interval=100):
    """A demo function to check if the differentiable renderer can optimize splats rendered along ray.
    :param scene:
    :param out_dir:
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX_0

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX_0
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f(
        [2, 1, 2, 1])  # tch_var_f([1, 1, 1, 1]) # tch_var_f([2, 2, 2, 1]) #
    scene['camera']['at'] = tch_var_f(
        [0, 0.8, 0, 1])  # tch_var_f([0, 1, 0, 1]) # tch_var_f([2, 2, 0, 1])  #
    scene['lights']['attenuation'] = tch_var_f([
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
    ])
    scene['materials']['coeffs'] = tch_var_f([
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 0.2, 8.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
    ])

    target_res = render(scene, tiled=True, shadow=shadow)
    target_im = normalize_maxmin(target_res['image'])
    target_im.require_grad = False
    target_im_ = get_data(target_im)
    target_pos_ = get_data(target_res['pos'])
    target_normal_ = get_data(target_res['normal'])
    target_normalmap_img_ = get_normalmap_image(target_normal_)
    target_depth_ = get_data(target_res['depth'])
    print('[z_min, z_max] = [%f, %f]' %
          (np.min(target_pos_[..., 2]), np.max(target_pos_[..., 2])))
    print('[depth_min, depth_max] = [%f, %f]' %
          (np.min(target_depth_), np.max(target_depth_)))

    # world -> cam -> render_splats_along_ray
    cc_tform = world_to_cam(target_res['pos'].view(
        (-1, 3)), target_res['normal'].view((-1, 3)), scene['camera'])
    wc_cc_tform = cam_to_world(cc_tform['pos'], cc_tform['normal'],
                               scene['camera'])

    # Check normal estimation in camera space
    pos_cc = cc_tform['pos'][:, :3].contiguous().view(target_im.shape)
    normal_cc = cc_tform['normal'][:, :3].contiguous().view(target_im.shape)
    plane_fit_est = estimate_surface_normals_plane_fit(pos_cc, None)
    normal_cc_normalmap = get_normalmap_image(get_data(normal_cc))
    plane_fit_est_normalmap = get_normalmap_image(get_data(plane_fit_est))

    pos_diff = torch.abs(wc_cc_tform['pos'][:, :3] -
                         target_res['pos'].view((-1, 3)))
    mean_pos_diff = torch.mean(pos_diff)
    normal_diff = torch.abs(wc_cc_tform['normal'][:, :3] -
                            target_res['normal'].view(-1, 3))
    mean_normal_diff = torch.mean(normal_diff)
    print('mean_pos_diff', mean_pos_diff, 'mean_normal_diff', mean_normal_diff)

    wc_cc_normal = wc_cc_tform['normal'].view(target_im_.shape)
    wc_cc_normal_img = get_normalmap_image(get_data(wc_cc_normal))

    material_idx = tch_var_l(np.ones(cc_tform['pos'].shape[0]) * 3)
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    light_vis = tch_var_f(
        np.ones(
            (input_scene['lights']['pos'].shape[0], cc_tform['pos'].shape[0])))
    input_scene['objects'] = {
        'disk': {
            'pos': cc_tform['pos'],
            'normal': cc_tform['normal'],
            'material_idx': material_idx,
            'light_vis': light_vis,
        }
    }
    target_res_noshadow = render(scene, tiled=True, shadow=False)
    res = render_splats_along_ray(input_scene)
    test_img_ = get_data(normalize_maxmin(res['image']))
    test_depth_ = get_data(res['depth'])
    test_normal_ = get_data(res['normal']).reshape(test_img_.shape)
    test_normalmap_ = get_normalmap_image(test_normal_)
    im_diff = np.abs(test_img_ -
                     get_data(normalize_maxmin(target_res_noshadow['image'])))
    print('mean image diff: {}'.format(np.mean(im_diff)))
    #### PLOT
    plt.ion()
    plt.figure()
    plt.imshow(test_img_, interpolation='none')
    plt.title('Test Image')
    plt.savefig(out_dir + '/test_img.png')
    plt.figure()
    plt.imshow(test_depth_, interpolation='none')
    plt.title('Test Depth')
    plt.savefig(out_dir + '/test_depth.png')

    plt.figure()
    plt.imshow(test_normalmap_, interpolation='none')
    plt.title('Test Normals')
    plt.savefig(out_dir + '/test_normal.png')

    ####
    criterion = nn.L1Loss()  #nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_, interpolation='none')
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    plt.figure()
    plt.imshow(target_normalmap_img_, interpolation='none')
    plt.title('Normals')
    plt.savefig(out_dir + '/normal.png')

    plt.figure()
    plt.imshow(wc_cc_normal_img, interpolation='none')
    plt.title('WC_CC Normals')
    plt.savefig(out_dir + '/wc_cc_normal.png')

    plt.figure()
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.savefig(out_dir + '/normal_cc.png')

    plt.figure()
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/est_normal_cc.png')

    plt.figure()
    plt.subplot(121)
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.subplot(122)
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/normal_and_estnormal_cc_comparison.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    input_scene['camera']['viewport'] = [
        0, 0, int(width / samples),
        int(height / samples)
    ]

    num_splats = int(width * height / (samples * samples))
    #x, y = np.meshgrid(np.linspace(-1, 1, int(width / samples)), np.linspace(-1, 1, int(height / samples)))
    z_min = scene['camera']['focal_length']
    z_max = 3

    z = -tch_var_f(
        np.ones(num_splats) * (z_min + z_max) / 2
    )  # -torch.clamp(tch_var_f(2 * np.random.rand(num_splats)), z_min, z_max)
    z.requires_grad = True

    normal_angles = tch_var_f(np.random.rand(num_splats, 2))
    normal_angles.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0], num_splats)))
    light_vis.requires_grad = True

    if vis_only:
        assert shadow is True
        opt_vars = [light_vis]
        z = cc_tform['pos'][:, 2]
        # FIXME: sph2cart
        #normals = cc_tform['normal']
    else:
        opt_vars = [z, normal_angles]
        if shadow:
            opt_vars += [light_vis]

    optimizer = optim.Adam(opt_vars, lr=lr)
    lr_scheduler = StepLR(optimizer, step_size=10000, gamma=0.8)

    h0 = plt.figure()
    h1 = plt.figure()
    h2 = plt.figure()
    h3 = plt.figure()
    h4 = plt.figure()

    gs1 = gridspec.GridSpec(3, 3)
    gs1.update(wspace=0.0025, hspace=0.02)

    # Two options for z_norm_consistency
    # 1. start after N iterations
    # 2. start at the beginning and decay
    # 3. start after N iterations and decay to 0
    no_decay = lambda x: x
    exp_decay = lambda x, scale: torch.exp(-x / scale)
    linear_decay = lambda x, scale: scale / (x + 1e-6)

    spatial_var_loss_weight = 10.0  #0.0
    normal_away_from_cam_loss_weight = 0.0
    grad_img_depth_loss_weight = 1.0
    spatial_loss_weight = 2

    z_norm_weight_init = 1  # 1e-5
    z_norm_activate_iter = 0  # 1000
    decay_fn = lambda x: linear_decay(x, 100)
    loss_per_iter = []
    if b_generate_normals:
        est_normals = False
        normal_est_network = NEstNetAffine(kernel_size=3, sph=False)
        print(normal_est_network)
        normal_est_network.cuda()
    for iter in range(max_iter):
        lr_scheduler.step()
        zz = -F.relu(-z) - z_min  # torch.clamp(z, -z_max, -z_min)
        if b_generate_normals:
            normals = generate_normals(zz, scene['camera'], normal_est_network)
            #if iter > 100 and iter % 10 == 0:
            #    print(normals)
        elif not est_normals:
            phi = F.sigmoid(normal_angles[:, 0]) * 2 * np.pi
            theta = F.sigmoid(
                normal_angles[:, 1]
            ) * np.pi / 2  # F.tanh(normal_angles[:, 1]) * np.pi / 2
            normals = sph2cart_unit(torch.stack((phi, theta), dim=1))

        pos = zz  # torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), zz), dim=1)

        input_scene['objects'] = {
            'disk': {
                'pos': pos,
                'normal': normalize(normals) if not est_normals else None,
                'material_idx': material_idx,
                'light_vis': torch.sigmoid(light_vis),
            }
        }
        res = render_splats_along_ray(input_scene,
                                      samples=samples,
                                      normal_estimation_method='plane')
        res_pos = res['pos']
        res_normal = res['normal']
        spatial_loss = spatial_3x3(res_pos)
        depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
        grad_img = grad_spatial2d(
            torch.mean(res['image'], dim=-1)[..., np.newaxis])
        grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)
        normal_away_from_cam_loss = away_from_camera_penalty(
            res_pos, res_normal)
        z_pos = res_pos[..., 2]
        z_loss = torch.mean((10 * F.relu(z_min - torch.abs(z_pos)))**2 +
                            (10 * F.relu(torch.abs(z_pos) - z_max))**2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() + res_pos[..., 2].var())
        spatial_var_loss = (1 / (spatial_var + 1e-4))
        im_out = normalize_maxmin(res['image'])
        res_depth_ = get_data(res['depth'])

        optimizer.zero_grad()
        z_norm_weight = z_norm_weight_init * float(
            iter > z_norm_activate_iter) * decay_fn(iter -
                                                    z_norm_activate_iter)
        loss = criterion(scale * im_out, scale * target_im) + z_loss + unit_normal_loss + \
            z_norm_weight * z_norm_loss + \
            spatial_var_loss_weight * spatial_var_loss + \
            grad_img_depth_loss_weight * image_depth_consistency_loss
        #normal_away_from_cam_loss_weight * normal_away_from_cam_loss + \
        #spatial_loss_weight * spatial_loss

        im_out_ = get_data(im_out)
        im_out_normal_ = get_data(res['normal'])
        pos_out_ = get_data(res['pos'])

        loss_ = get_data(loss)
        z_loss_ = get_data(z_loss)
        z_norm_loss_ = get_data(z_norm_loss)
        spatial_loss_ = get_data(spatial_loss)
        spatial_var_loss_ = get_data(spatial_var_loss)
        unit_normal_loss_ = get_data(unit_normal_loss)
        normal_away_from_cam_loss_ = get_data(normal_away_from_cam_loss)
        normals_ = get_data(res_normal)
        image_depth_consistency_loss_ = get_data(image_depth_consistency_loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            z__ = pos_out_[..., 2]
            print(
                '%d. loss= %f nloss=%f z_loss=%f [%f, %f] [%f, %f], z_normal_loss: %f,'
                ' spatial_var_loss: %f, normal_away_loss: %f'
                ' nz_range: [%f, %f], spatial_loss: %f, imd_loss: %f' %
                (iter, loss_, unit_normal_loss_, z_loss_, np.min(z_),
                 np.max(z_), np.min(z__), np.max(z__), z_norm_loss_,
                 spatial_var_loss_, normal_away_from_cam_loss_,
                 normals_[..., 2].min(), normals_[..., 2].max(), spatial_loss_,
                 image_depth_consistency_loss_))

        if iter % xyz_save_interval == 0 or iter == max_iter - 1:
            save_xyz(out_dir + '/res_{:05d}.xyz'.format(iter),
                     get_data(res_pos), get_data(res_normal))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            plt.figure(h4.number)
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(121)
            #plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.title('Output')
            plt.subplot(122)
            #plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.title('Ground truth')
            # plt.subplot(223)
            # plt.plot(loss_per_iter, linewidth=2)
            # plt.xlabel('Iteration', fontsize=14)
            # plt.title('Loss', fontsize=12)
            # plt.grid(True)
            plt.savefig(out_dir + '/fig_im_gt_loss_%05d.png' % iter)

            plt.figure(h1.number, figsize=(4, 4))
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(gs1[0])
            plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.subplot(gs1[1])
            plt.axis('off')
            plt.imshow(get_normalmap_image(im_out_normal_),
                       interpolation='none')
            ax = plt.subplot(gs1[2])
            plt.axis('off')
            im_tmp = ax.imshow(res_depth_, interpolation='none')
            # create an axes on the right side of ax. The width of cax will be 5%
            # of ax and the padding between cax and ax will be fixed at 0.05 inch.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            plt.subplot(gs1[3])
            plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.subplot(gs1[4])
            plt.axis('off')
            plt.imshow(test_normalmap_, interpolation='none')
            ax = plt.subplot(gs1[5])
            plt.axis('off')
            im_tmp = ax.imshow(test_depth_, interpolation='none')
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            W, H = input_scene['camera']['viewport'][2:]
            light_vis_ = get_data(torch.sigmoid(light_vis))
            plt.subplot(gs1[6])
            plt.axis('off')
            plt.imshow(light_vis_[0].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 1):
                plt.subplot(gs1[7])
                plt.axis('off')
                plt.imshow(light_vis_[1].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 2):
                plt.subplot(gs1[8])
                plt.axis('off')
                plt.imshow(light_vis_[2].reshape((H, W)), interpolation='none')

            plt.savefig(out_dir + '/fig_%05d.png' % iter)

            plt.figure(h2.number)
            plt.clf()
            plt.imshow(res_depth_)
            plt.colorbar()
            plt.savefig(out_dir + '/fig_depth_%05d.png' % iter)

            plt.figure(h3.number)
            plt.clf()
            plt.imshow(z_.reshape(H, W))
            plt.colorbar()
            plt.savefig(out_dir + '/fig_z_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()