Example #1
0
def test_transformation_consistency(scene, batch_size):
    """Round-trip rendered surfels through cam_to_world and back via world_to_cam.

    Renders ``scene``, flattens the rendered positions and normals so each row
    holds one pixel's vector, maps them with ``cam_to_world`` followed by
    ``world_to_cam`` using the loaded scene's camera, and asserts that the first
    three components of the round-tripped values match the inputs.
    ``batch_size`` is accepted for signature compatibility but not used.
    """
    print('test_transformation_consistency')
    rendered = render_scene(scene)
    camera = make_torch_var(load_scene(scene))['camera']

    def _flatten(key):
        # Collapse every leading axis so each row is one pixel's vector.
        arr = rendered[key]
        return arr.reshape(-1, arr.shape[-1])

    pos_in = _flatten('pos')
    normal_in = _flatten('normal')
    world = cam_to_world(pos_in, normal_in, camera)
    back = world_to_cam(world['pos'], world['normal'], camera)

    # Positions first, then normals (same order as the assertions they replace).
    for original, roundtrip in ((pos_in, back['pos']), (normal_in, back['normal'])):
        np.testing.assert_array_almost_equal(get_data(original),
                                             get_data(roundtrip[:, :3]))
Example #2
0
def test_depth_to_world_consistency(scene, batch_size):
    """Check that the camera-space Z component alone recovers full positions.

    Runs two passes. The first uses the un-batched helpers (world_to_cam,
    z_to_pcl_CC, cam_to_world). The second repeats the camera's eye/at/up
    ``batch_size`` times and uses the ``*_batched`` variants. Each pass asserts
    that (a) the (X, Y, Z) rebuilt from Z equals the original camera-space
    position and (b) mapping it back to world space reproduces the rendered
    world position.
    """
    res = render_scene(scene)
    scene = make_torch_var(load_scene(scene))
    camera = scene['camera']

    def _check(cc_ref, cc_rebuilt, wc_ref, wc_rebuilt):
        # Test Z -> (X, Y, Z)
        np.testing.assert_array_almost_equal(get_data(cc_ref[..., :3]),
                                             get_data(cc_rebuilt[..., :3]))
        # Test world -> camera -> Z -> (X, Y, Z) in camera -> world
        np.testing.assert_array_almost_equal(get_data(wc_ref[..., :3]),
                                             get_data(wc_rebuilt[..., :3]))

    # Un-batched pass. NOTE: z_to_pcl_CC takes the Z component in camera
    # coordinates and returns the full (X, Y, Z) in camera coordinates.
    world_pos = res['pos'].reshape(-1, res['pos'].shape[-1])
    cam_pos = world_to_cam(world_pos, None, camera)['pos']
    cam_pos_from_z = z_to_pcl_CC(cam_pos[:, 2], camera)
    world_pos_back = cam_to_world(cam_pos_from_z, None, camera)['pos']
    _check(cam_pos, cam_pos_from_z, world_pos, world_pos_back)

    # Batched pass: repeat the camera vectors batch_size times.
    # `camera` is the same dict as scene['camera'], so this updates the scene.
    for key in ('eye', 'at', 'up'):
        camera[key] = camera[key].repeat(batch_size, 1)

    world_pos = world_pos.repeat(batch_size, 1, 1)
    cam_pos = world_to_cam_batched(world_pos, None, camera)['pos']
    cam_pos_from_z = z_to_pcl_CC_batched(cam_pos[..., 2], camera)  # NOTE: z = -depth
    world_pos_back = cam_to_world_batched(cam_pos_from_z, None, camera)['pos']
    _check(cam_pos, cam_pos_from_z, world_pos, world_pos_back)
Example #3
0
def test_sphere_splat_render_along_ray(out_dir, cam_pos, width, height, fovy, focal_length, use_quartic,
                                       b_display=False):
    """Render a hemisphere of splats with render_splats_along_ray.

    Builds a ``height`` x ``width`` grid of splats over [-1, 1]^2. Points inside
    the unit disk lie on a unit hemisphere, points outside lie on the plane
    z = 0, and the whole surface is shifted by -5 along z. The positions are
    passed to ``render_splats_along_ray`` with ``normal=None``. The rendered
    positions/normals are also mapped through ``cam_to_world``. The rendered
    image is saved to ``out_dir + '/img_orig.png'``.

    Args:
        out_dir: directory for the saved image (and figures when displaying).
        cam_pos: camera eye position (converted with ``tch_var_f``).
        width, height: viewport size in pixels; also the splat grid size.
        fovy: vertical field of view in degrees.
        focal_length: camera focal length.
        use_quartic: forwarded to ``render_splats_along_ray``.
        b_display: if True, show intermediate matplotlib figures and block on
            ``plt.show()`` at the end.
    """
    import copy
    print('render sphere along ray')

    num_samples = width * height

    def _show_xyz_channels(points, shape, suffix=None):
        # One figure with the x / y / z columns of `points` shown as images.
        plt.figure()
        for col, axis_name in enumerate('xyz'):
            plt.subplot(131 + col)
            plt.imshow(points[:, col].reshape(shape))
            if suffix is not None:
                plt.title('%s_%s' % (axis_name, suffix))

    large_scene = copy.deepcopy(SCENE_TEST)
    camera = large_scene['camera']
    camera['viewport'] = [0, 0, width, height]
    # The caller's eye position is used for rendering and for cam_to_world
    # below; it is not overridden later.
    camera['eye'] = tch_var_f(cam_pos)
    camera['fovy'] = np.deg2rad(fovy)
    camera['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    # Hemisphere bulging out of the xy-plane inside the unit disk, flat
    # (z = 0) outside it, then shifted to z - 5.
    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    r2 = x ** 2 + y ** 2
    unit_disk_mask = r2 <= 1
    z = np.sqrt(1 - unit_disk_mask * r2)
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel() - 5, np.ones(num_samples)), axis=1)

    if b_display:
        # Analytic normals are only used for this visualization: (x, y, z)
        # on the hemisphere, [0, 0, 1] outside the sphere.
        x[~unit_disk_mask] = 0
        y[~unit_disk_mask] = 0
        z[~unit_disk_mask] = 1
        normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(), np.zeros(num_samples)), axis=1))

        plt.ion()
        _show_xyz_channels(pos, (height, width))
        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    large_scene['objects']['disk']['pos'] = tch_var_f(pos)
    # normal=None: normals are left for the renderer to estimate.
    large_scene['objects']['disk']['normal'] = None

    # main render run
    start_time = time()
    res = render_splats_along_ray(large_scene, use_quartic=use_quartic)
    print('rendering time: %f s' % (time() - start_time))

    # Test cam_to_world
    res_world = cam_to_world(res['pos'].reshape(-1, 3), res['normal'].reshape(-1, 3), camera)

    im = np.uint8(255. * get_data(res['image']))

    depth = get_data(res['depth'])
    # Clamp depths at the far plane for display.
    depth[depth >= camera['far']] = camera['far']

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(depth, interpolation='none')
        plt.title('Depth Image')

        img_shape = (im.shape[0], im.shape[1])
        pos_world = get_data(res_world['pos'])
        _show_xyz_channels(pos_world, img_shape, 'world')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pos_world[:, 0], pos_world[:, 1], pos_world[:, 2], s=1.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')

        _show_xyz_channels(get_data(res['pos'].reshape(-1, 3)), img_shape, 'CC')

    imsave(out_dir + '/img_orig.png', im)

    if b_display:
        # hold matplotlib figures until the user closes them
        plt.ioff()
        plt.show()
Example #4
0
    im = get_data(res['image'])
    depth = get_data(res['depth'])

    plt.figure()
    plt.imshow(im)

    plt.figure()
    plt.imshow(depth)

    pos = get_data(res['pos'])
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], s=1.3)

    res_world = cam_to_world(pos=res['pos'], normal=res['normal'], camera=scene[idx]['camera'])
    pos = get_data(res_world['pos'])
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], s=1.3)

# Leave matplotlib interactive mode and call plt.show() so the figures opened
# above are displayed (in non-interactive mode this blocks until they close).
plt.ioff()
plt.show()
# data = np.load('res_world_twogans.npy')
#
# # pos0 = get_data(data[0]['pos'])
# # fig = plt.figure()
# # ax = fig.add_subplot(111, projection='3d')
# # ax.scatter(pos0[:, 0], pos0[:, 1], pos0[:, 2], s=1.3)
#
# fig = plt.figure()
Example #5
0
    def render_batch(self, batch, batch_cond, light_pos):
        """Render a batch of splats and accumulate the geometric regularisers.

        For every sample ``idx`` the splats are placed into
        ``self.scene['objects']['disk']``:

        - Depth comes from ``batch[idx][:, 0]`` and normals from
          ``batch[idx][:, 1:]``.
        - The camera eye is sampled with ``uniform_sample_sphere`` when
          ``batch_cond`` is None, otherwise taken from ``batch_cond``.
        - The first light's position is set from ``light_pos[idx]``.

        The sample is then rendered with ``render_splats_along_ray`` and its
        weighted regularisation losses are added to the running total.

        Args:
            batch: tensor indexed as ``batch[idx][:, k]``. Column 0 drives the
                splat depth; columns 1: are used as normals.
            batch_cond: optional per-sample camera eye positions; if None, eye
                positions are sampled on a sphere of radius ``opt.cam_dist``.
            light_pos: per-sample positions written to
                ``self.scene['lights']['pos'][0, :3]``.

        Returns:
            ``(rendered_data, rendered_data_depth, loss)``:

            - ``rendered_data`` stacks, per sample, ``depth.unsqueeze(0)`` when
              ``opt.render_img_nc == 1``, else the image permuted to
              channels-first.
            - ``rendered_data_depth`` stacks ``depth.unsqueeze(0)``.
            - ``loss`` is the regularisation loss summed over samples and
              divided by the batch size.
        """
        batch_size = batch.size()[0]

        # Generate camera positions on a sphere, one per sample actually in
        # the batch (so cam_pos[idx] is always in range).
        if batch_cond is None:
            # TODO: deg2rad -- the two branches convert different arguments
            # (`angle` vs `theta`/`phi`); confirm the units
            # uniform_sample_sphere expects.
            if self.opt.full_sphere_sampling:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=batch_size,
                    axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta, phi_range=self.opt.phi)
            else:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=batch_size,
                    axis=self.opt.axis, angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))

        inpath = self.opt.vis_images + '/'
        inpath_xyz = self.opt.vis_xyz + '/'
        z_min = self.scene['camera']['focal_length']
        z_max = z_min + 3
        eps = 1e-3  # keeps the min-max rescaling finite when all depths are equal

        # Which debug dumps this call produces (iterationa_no is not changed here).
        save_maps = self.iterationa_no % (self.opt.save_image_interval * 5) == 0
        save_xyz_files = self.iterationa_no % 1000 == 0

        # TODO (fmannan): Move this in init. This only needs to be done once!
        # Set splats into rendering scene
        if 'sphere' in self.scene['objects']:
            del self.scene['objects']['sphere']
        if 'triangle' in self.scene['objects']:
            del self.scene['objects']['triangle']
        if 'disk' not in self.scene['objects']:
            self.scene['objects'] = {'disk': {'pos': None, 'normal': None,
                                              'material_idx': None}}
        lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
        self.scene['camera']['at'] = tch_var_f(lookat)
        self.scene['objects']['disk']['material_idx'] = tch_var_l(
            np.zeros(self.opt.splats_img_size * self.opt.splats_img_size))

        rendered_data = []
        rendered_data_depth = []
        loss = 0.0
        for idx in range(batch_size):
            # Depth is clamped to at least z_min; positions are stored as -z.
            z = F.relu(-batch[idx][:, 0]) + z_min
            if self.opt.rescaled:
                # Min-max rescale the depths to (approximately) [z_min, z_max].
                z = ((z - z.min()) / (z.max() - z.min() + eps) *
                     (z_max - z_min) + z_min)
            pos = -z
            normals = batch[idx][:, 1:]

            self.scene['objects']['disk']['pos'] = pos
            # Normal estimation network and est_normals don't go together
            self.scene['objects']['disk']['normal'] = normals if self.opt.est_normals is False else None

            # Set camera position (sample 0's camera for every sample when same_view).
            eye_idx = 0 if self.opt.same_view else idx
            if batch_cond is None:
                self.scene['camera']['eye'] = tch_var_f(cam_pos[eye_idx])
            else:
                self.scene['camera']['eye'] = batch_cond[eye_idx]

            self.scene['lights']['pos'][0, :3] = tch_var_f(light_pos[idx])

            # Render scene
            res = render_splats_along_ray(self.scene,
                                          samples=self.opt.pixel_samples,
                                          normal_estimation_method='plane')

            res_pos = res['pos'].contiguous()
            res_pos_2D = res_pos.view(res['image'].shape)
            res_normal = res['normal']
            # TODO: a z_loss applied after supersampling (10x weights) currently
            # makes the final loss NaN, so only the 2x version below is used.

            image_depth_consistency_loss = depth_rgb_gradient_consistency(
                res['image'], res['depth'])
            unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)  # TODO: MN
            z_pos = res_pos[..., 2]
            # Penalise |z| outside [z_min, z_max].
            z_loss = torch.mean((2 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                                (2 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
            z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
            spatial_loss = spatial_3x3(res_pos_2D)
            spatial_var = torch.mean(res_pos[..., 0].var() +
                                     res_pos[..., 1].var() +
                                     res_pos[..., 2].var())
            # Inverse variance: discourages all splats collapsing to one point.
            spatial_var_loss = 1 / (spatial_var + 1e-4)

            # Accumulate over samples; the return value divides by batch_size.
            loss = loss + (
                self.opt.zloss * z_loss +
                self.opt.unit_normalloss * unit_normal_loss +
                self.opt.normal_consistency_loss_weight * z_norm_loss +
                self.opt.spatial_var_loss_weight * spatial_var_loss +
                self.opt.grad_img_depth_loss * image_depth_consistency_loss +
                self.opt.spatial_loss_weight * spatial_loss)

            # The depth channel is needed for the output in both modes.
            depth = res['depth']
            im_d = depth.unsqueeze(0)
            if self.opt.render_img_nc == 1:
                im = im_d
            else:
                im = res['image'].permute(2, 0, 1)

            if save_maps or save_xyz_files:
                # World-space splats are only needed for the debug dumps.
                world_tform = cam_to_world(res['pos'].view((-1, 3)),
                                           res['normal'].view((-1, 3)),
                                           self.scene['camera'])

            if save_maps:
                H, W = res['image'].shape[:2]
                normalmap_img = get_normalmap_image(
                    get_data(res['normal']).reshape((H, W, 3)))
                world_normalmap_img = get_normalmap_image(
                    get_data(world_tform['normal']).reshape((H, W, 3)))
                prefix = inpath + str(self.iterationa_no)
                imsave(prefix + 'normalmap_{:05d}.png'.format(idx), normalmap_img)
                imsave(prefix + 'depthmap_{:05d}.png'.format(idx),
                       get_data(res['depth']))
                imsave(prefix + 'world_normalmap_{:05d}.png'.format(idx),
                       world_normalmap_img)

            if save_xyz_files:
                # NOTE: these fixed names are overwritten by every sample, so
                # only the last sample of the batch survives on disk.
                np.save(inpath_xyz + 'pos.npy', get_data(res['pos']))
                np.save(inpath_xyz + 'im.npy', get_data(res['image']))
                np.save(inpath_xyz + 'depth.npy', get_data(res['depth']))

                prefix_xyz = inpath_xyz + str(self.iterationa_no)
                # Save xyz file
                save_xyz(prefix_xyz + 'withnormal_{:05d}.xyz'.format(idx),
                         pos=get_data(res['pos']),
                         normal=get_data(res['normal']))
                # Save xyz file in world coordinates
                save_xyz(prefix_xyz + 'withnormal_world_{:05d}.xyz'.format(idx),
                         pos=get_data(world_tform['pos']),
                         normal=get_data(world_tform['normal']))

            if self.opt.gz_gi_loss is not None and self.opt.gz_gi_loss > 0:
                # Match depth gradients to image-intensity gradients.
                grad_z = grad_spatial2d(res_pos_2D[:, :, 2][:, :, np.newaxis])
                grad_img = grad_spatial2d(torch.mean(im, dim=0)[:, :, np.newaxis])
                for g_z, g_i in zip(grad_z, grad_img):
                    # Out-of-place add: avoid in-place ops on an autograd tensor.
                    loss = loss + self.opt.gz_gi_loss * torch.mean(
                        torch.abs(torch.abs(g_z) - torch.abs(g_i)))

            rendered_data.append(im)
            rendered_data_depth.append(im_d)

        rendered_data = torch.stack(rendered_data)
        rendered_data_depth = torch.stack(rendered_data_depth)

        return rendered_data, rendered_data_depth, loss / batch_size
Example #6
0
def optimize_splats_along_ray_shadow_with_normalest_test(
        out_dir,
        width,
        height,
        max_iter=100,
        lr=1e-3,
        scale=10,
        shadow=True,
        vis_only=False,
        samples=1,
        est_normals=False,
        b_generate_normals=False,
        print_interval=10,
        imsave_interval=10,
        xyz_save_interval=100):
    """Demo: check that the differentiable renderer can optimize splats rendered along rays.

    The demo runs in three stages:

    1. Render a target image from a private deep copy of
       ``SCENE_SPHERE_HALFBOX_0``; the shared params dict is not modified.
    2. Transform the target's positions/normals to camera space and back, to
       sanity-check the transforms and plane-fit normal estimation.
    3. Optimize a grid of splats with Adam so that its render matches the
       target. The free variables are per-splat depth plus normal angles, and
       per-light visibility when ``shadow`` is set.

    Args:
        out_dir: directory for figures and .xyz dumps; created if missing.
        width, height: target render size in pixels.
        max_iter: number of optimization iterations.
        lr: Adam learning rate.
        scale: factor applied to both images inside the L1 image loss.
        shadow: render the target with shadows and, unless ``vis_only``,
            also optimize per-light visibility.
        vis_only: optimize only light visibility, with depth taken from the
            target's camera-space positions; requires ``shadow``.
        samples: forwarded to ``render_splats_along_ray`` during optimization;
            the splat grid is ``width / samples`` x ``height / samples``.
        est_normals: pass ``normal=None`` so the renderer estimates normals.
        b_generate_normals: generate normals from depth with ``NEstNetAffine``
            (forces ``est_normals`` to False).
        print_interval, imsave_interval, xyz_save_interval: iteration
            intervals for logging, figure saving and .xyz dumps. The last
            iteration always triggers all three.

    Raises:
        ValueError: if ``vis_only`` is set without ``shadow``.
    """
    import copy

    import torch
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX_0

    os.makedirs(out_dir, exist_ok=True)

    def _show_and_save(img, title, fname):
        # Show `img` in a new figure and save it as out_dir/fname.
        plt.figure()
        plt.imshow(img, interpolation='none')
        plt.title(title)
        plt.savefig(out_dir + '/' + fname)

    def _imshow_with_colorbar(ax, img):
        # Draw `img` on `ax` with a colorbar in a 5%-wide axis on its right,
        # padded 0.05 inch from `ax`.
        im_tmp = ax.imshow(img, interpolation='none')
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im_tmp, cax=cax)

    # Work on a private copy: assigning into SCENE_SPHERE_HALFBOX_0 directly
    # would leak these settings to every other user of the params module.
    scene = copy.deepcopy(SCENE_SPHERE_HALFBOX_0)
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f([2, 1, 2, 1])
    scene['camera']['at'] = tch_var_f([0, 0.8, 0, 1])
    scene['lights']['attenuation'] = tch_var_f([
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
    ])
    scene['materials']['coeffs'] = tch_var_f([
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 0.2, 8.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
    ])

    # ---- Target render ----
    target_res = render(scene, tiled=True, shadow=shadow)
    # The target is a fixed reference: cut it out of the autograd graph.
    target_im = normalize_maxmin(target_res['image']).detach()
    target_im_ = get_data(target_im)
    target_pos_ = get_data(target_res['pos'])
    target_normal_ = get_data(target_res['normal'])
    target_normalmap_img_ = get_normalmap_image(target_normal_)
    target_depth_ = get_data(target_res['depth'])
    print('[z_min, z_max] = [%f, %f]' %
          (np.min(target_pos_[..., 2]), np.max(target_pos_[..., 2])))
    print('[depth_min, depth_max] = [%f, %f]' %
          (np.min(target_depth_), np.max(target_depth_)))

    # ---- world -> cam -> world round trip ----
    cc_tform = world_to_cam(target_res['pos'].view((-1, 3)),
                            target_res['normal'].view((-1, 3)), scene['camera'])
    wc_cc_tform = cam_to_world(cc_tform['pos'], cc_tform['normal'],
                               scene['camera'])

    # Check normal estimation in camera space
    pos_cc = cc_tform['pos'][:, :3].contiguous().view(target_im.shape)
    normal_cc = cc_tform['normal'][:, :3].contiguous().view(target_im.shape)
    plane_fit_est = estimate_surface_normals_plane_fit(pos_cc, None)
    normal_cc_normalmap = get_normalmap_image(get_data(normal_cc))
    plane_fit_est_normalmap = get_normalmap_image(get_data(plane_fit_est))

    mean_pos_diff = torch.mean(torch.abs(wc_cc_tform['pos'][:, :3] -
                                         target_res['pos'].view((-1, 3))))
    mean_normal_diff = torch.mean(torch.abs(wc_cc_tform['normal'][:, :3] -
                                            target_res['normal'].view(-1, 3)))
    print('mean_pos_diff', mean_pos_diff, 'mean_normal_diff', mean_normal_diff)

    wc_cc_normal = wc_cc_tform['normal'].view(target_im_.shape)
    wc_cc_normal_img = get_normalmap_image(get_data(wc_cc_normal))

    # ---- Render the target's camera-space points as splats ----
    material_idx = tch_var_l(np.ones(cc_tform['pos'].shape[0]) * 3)
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0], cc_tform['pos'].shape[0])))
    input_scene['objects'] = {
        'disk': {
            'pos': cc_tform['pos'],
            'normal': cc_tform['normal'],
            'material_idx': material_idx,
            'light_vis': light_vis,
        }
    }
    target_res_noshadow = render(scene, tiled=True, shadow=False)
    res = render_splats_along_ray(input_scene)
    test_img_ = get_data(normalize_maxmin(res['image']))
    test_depth_ = get_data(res['depth'])
    test_normal_ = get_data(res['normal']).reshape(test_img_.shape)
    test_normalmap_ = get_normalmap_image(test_normal_)
    im_diff = np.abs(test_img_ -
                     get_data(normalize_maxmin(target_res_noshadow['image'])))
    print('mean image diff: {}'.format(np.mean(im_diff)))

    #### PLOT
    plt.ion()
    _show_and_save(test_img_, 'Test Image', 'test_img.png')
    _show_and_save(test_depth_, 'Test Depth', 'test_depth.png')
    _show_and_save(test_normalmap_, 'Test Normals', 'test_normal.png')

    criterion = nn.L1Loss().cuda()

    _show_and_save(target_im_, 'Target Image', 'target.png')
    _show_and_save(target_normalmap_img_, 'Normals', 'normal.png')
    _show_and_save(wc_cc_normal_img, 'WC_CC Normals', 'wc_cc_normal.png')
    _show_and_save(normal_cc_normalmap, 'Normal CC GT', 'normal_cc.png')
    _show_and_save(plane_fit_est_normalmap, 'Plane fit CC', 'est_normal_cc.png')

    plt.figure()
    plt.subplot(121)
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.subplot(122)
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/normal_and_estnormal_cc_comparison.png')

    # ---- Optimization problem: splat grid at 1/samples resolution ----
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    input_scene['camera']['viewport'] = [
        0, 0, int(width / samples),
        int(height / samples)
    ]

    num_splats = int(width * height / (samples * samples))
    z_min = scene['camera']['focal_length']
    z_max = 3

    # Every splat starts at the middle of [z_min, z_max], stored as negative z.
    z = -tch_var_f(np.ones(num_splats) * (z_min + z_max) / 2)
    z.requires_grad = True

    normal_angles = tch_var_f(np.random.rand(num_splats, 2))
    normal_angles.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0], num_splats)))
    light_vis.requires_grad = True

    if vis_only:
        if not shadow:
            raise ValueError('vis_only=True requires shadow=True')
        opt_vars = [light_vis]
        z = cc_tform['pos'][:, 2]
        # FIXME: sph2cart
        # normals = cc_tform['normal']
    else:
        opt_vars = [z, normal_angles]
        if shadow:
            opt_vars += [light_vis]

    optimizer = optim.Adam(opt_vars, lr=lr)
    lr_scheduler = StepLR(optimizer, step_size=10000, gamma=0.8)

    h0 = plt.figure()
    h1 = plt.figure()
    h2 = plt.figure()
    h3 = plt.figure()
    h4 = plt.figure()

    gs1 = gridspec.GridSpec(3, 3)
    gs1.update(wspace=0.0025, hspace=0.02)

    # Loss weights. normal_away_from_cam_loss (weight 0.0) and spatial_loss
    # (weight 2) are computed and logged but left out of the total loss.
    spatial_var_loss_weight = 10.0
    grad_img_depth_loss_weight = 1.0

    # z_norm_consistency schedule: zero until z_norm_activate_iter, then
    # z_norm_weight_init * decay_fn(iterations since activation).
    z_norm_weight_init = 1
    z_norm_activate_iter = 0

    def decay_fn(x):
        # Linear (1 / x) decay with scale 100.
        return 100 / (x + 1e-6)

    loss_per_iter = []
    if b_generate_normals:
        est_normals = False
        normal_est_network = NEstNetAffine(kernel_size=3, sph=False)
        print(normal_est_network)
        normal_est_network.cuda()

    for it in range(max_iter):
        is_last = it == max_iter - 1
        zz = -F.relu(-z) - z_min  # zz <= -z_min for any z
        if b_generate_normals:
            normals = generate_normals(zz, scene['camera'], normal_est_network)
        elif not est_normals:
            # Map the unconstrained angles to phi in [0, 2pi], theta in [0, pi/2].
            phi = torch.sigmoid(normal_angles[:, 0]) * 2 * np.pi
            theta = torch.sigmoid(normal_angles[:, 1]) * np.pi / 2
            normals = sph2cart_unit(torch.stack((phi, theta), dim=1))

        input_scene['objects'] = {
            'disk': {
                'pos': zz,
                'normal': normalize(normals) if not est_normals else None,
                'material_idx': material_idx,
                'light_vis': torch.sigmoid(light_vis),
            }
        }
        res = render_splats_along_ray(input_scene,
                                      samples=samples,
                                      normal_estimation_method='plane')
        res_pos = res['pos']
        res_normal = res['normal']
        spatial_loss = spatial_3x3(res_pos)
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)
        normal_away_from_cam_loss = away_from_camera_penalty(
            res_pos, res_normal)
        z_pos = res_pos[..., 2]
        # Penalise |z| outside [z_min, z_max].
        z_loss = torch.mean((10 * F.relu(z_min - torch.abs(z_pos)))**2 +
                            (10 * F.relu(torch.abs(z_pos) - z_max))**2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() + res_pos[..., 2].var())
        spatial_var_loss = 1 / (spatial_var + 1e-4)
        im_out = normalize_maxmin(res['image'])
        res_depth_ = get_data(res['depth'])

        optimizer.zero_grad()
        z_norm_weight = (z_norm_weight_init * float(it > z_norm_activate_iter) *
                         decay_fn(it - z_norm_activate_iter))
        loss = (criterion(scale * im_out, scale * target_im) + z_loss +
                unit_normal_loss +
                z_norm_weight * z_norm_loss +
                spatial_var_loss_weight * spatial_var_loss +
                grad_img_depth_loss_weight * image_depth_consistency_loss)

        im_out_ = get_data(im_out)
        im_out_normal_ = get_data(res['normal'])
        pos_out_ = get_data(res['pos'])

        loss_ = get_data(loss)
        z_loss_ = get_data(z_loss)
        z_norm_loss_ = get_data(z_norm_loss)
        spatial_loss_ = get_data(spatial_loss)
        spatial_var_loss_ = get_data(spatial_var_loss)
        unit_normal_loss_ = get_data(unit_normal_loss)
        normal_away_from_cam_loss_ = get_data(normal_away_from_cam_loss)
        normals_ = get_data(res_normal)
        image_depth_consistency_loss_ = get_data(image_depth_consistency_loss)
        loss_per_iter.append(loss_)

        if it == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if it % print_interval == 0 or is_last:
            z_ = get_data(z)
            z__ = pos_out_[..., 2]
            print(
                '%d. loss= %f nloss=%f z_loss=%f [%f, %f] [%f, %f], z_normal_loss: %f,'
                ' spatial_var_loss: %f, normal_away_loss: %f'
                ' nz_range: [%f, %f], spatial_loss: %f, imd_loss: %f' %
                (it, loss_, unit_normal_loss_, z_loss_, np.min(z_),
                 np.max(z_), np.min(z__), np.max(z__), z_norm_loss_,
                 spatial_var_loss_, normal_away_from_cam_loss_,
                 normals_[..., 2].min(), normals_[..., 2].max(), spatial_loss_,
                 image_depth_consistency_loss_))

        if it % xyz_save_interval == 0 or is_last:
            save_xyz(out_dir + '/res_{:05d}.xyz'.format(it),
                     get_data(res_pos), get_data(res_normal))

        if it % imsave_interval == 0 or is_last:
            z_ = get_data(z)
            title = '%d. loss= %f [%f, %f]' % (it, loss_, np.min(z_), np.max(z_))

            # Output vs ground truth side by side.
            plt.figure(h4.number)
            plt.clf()
            plt.suptitle(title)
            plt.subplot(121)
            plt.imshow(im_out_, interpolation='none')
            plt.title('Output')
            plt.subplot(122)
            plt.imshow(target_im_, interpolation='none')
            plt.title('Ground truth')
            plt.savefig(out_dir + '/fig_im_gt_loss_%05d.png' % it)

            # 3x3 grid: row 0 = output image/normals/depth, row 1 = test
            # render image/normals/depth, row 2 = per-light visibility.
            plt.figure(h1.number, figsize=(4, 4))
            plt.clf()
            plt.suptitle(title)
            plt.subplot(gs1[0])
            plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.subplot(gs1[1])
            plt.axis('off')
            plt.imshow(get_normalmap_image(im_out_normal_),
                       interpolation='none')
            ax = plt.subplot(gs1[2])
            plt.axis('off')
            _imshow_with_colorbar(ax, res_depth_)

            plt.subplot(gs1[3])
            plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.subplot(gs1[4])
            plt.axis('off')
            plt.imshow(test_normalmap_, interpolation='none')
            ax = plt.subplot(gs1[5])
            plt.axis('off')
            _imshow_with_colorbar(ax, test_depth_)

            W, H = input_scene['camera']['viewport'][2:]
            light_vis_ = get_data(torch.sigmoid(light_vis))
            # Up to three lights fit in the bottom row.
            for light_idx in range(min(light_vis_.shape[0], 3)):
                plt.subplot(gs1[6 + light_idx])
                plt.axis('off')
                plt.imshow(light_vis_[light_idx].reshape((H, W)),
                           interpolation='none')

            plt.savefig(out_dir + '/fig_%05d.png' % it)

            plt.figure(h2.number)
            plt.clf()
            plt.imshow(res_depth_)
            plt.colorbar()
            plt.savefig(out_dir + '/fig_depth_%05d.png' % it)

            plt.figure(h3.number)
            plt.clf()
            plt.imshow(z_.reshape(H, W))
            plt.colorbar()
            plt.savefig(out_dir + '/fig_z_%05d.png' % it)

        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler must step after the optimizer.
        lr_scheduler.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()