Example No. 1
def test_optimization(scene,
                      batch_size,
                      print_interval=20,
                      imsave_interval=20,
                      max_iter=100,
                      out_dir='./proj_tmp/'):
    """ First render using the full renderer to get the surfel position and color
    and then render using the projection layer for testing

    Returns:
        None
    """
    from torch import optim
    import os
    import matplotlib.pyplot as plt
    plt.ion()

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    res = render_scene(scene)

    scene = make_torch_var(load_scene(scene))
    pos_wc = res['pos'].reshape(-1,
                                res['pos'].shape[-1]).repeat(batch_size, 1, 1)

    camera = scene['camera']
    camera['eye'] = camera['eye'].repeat(batch_size, 1)
    camera['at'] = camera['at'].repeat(batch_size, 1)
    camera['up'] = camera['up'].repeat(batch_size, 1)

    target_image = res['image'].repeat(batch_size, 1, 1, 1)

    input_image = target_image + 0.1 * torch.randn(target_image.size(),
                                                   device=target_image.device)
    input_image.requires_grad = True

    criterion = torch.nn.MSELoss(reduction='mean').cuda()  # 'size_average' is deprecated
    optimizer = optim.Adam([input_image], lr=1e-2)

    h1 = plt.figure()
    loss_per_iter = []
    for iter in range(max_iter):
        im_est, mask = projection_renderer(pos_wc, input_image, camera)
        optimizer.zero_grad()
        loss = criterion(im_est * 255, target_image * 255)
        loss_ = get_data(loss)
        loss_per_iter.append(loss_)
        if iter % print_interval == 0 or iter == max_iter - 1:
            print('{}. Loss: {}'.format(iter, loss_))
        if iter % imsave_interval == 0 or iter == max_iter - 1:
            im_out_ = get_data(input_image)
            im_out_ = np.uint8(255 * im_out_ / im_out_.max())
            plt.figure(h1.number)
            plt.imshow(im_out_[0].squeeze())
            plt.title('%d. loss= %f' % (iter, loss_))
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

        loss.backward()
        optimizer.step()
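
The loop above is ordinary gradient descent on the input pixels; only the projection_renderer call is repository-specific. A minimal, self-contained sketch of the same optimize-an-image pattern in plain PyTorch, with the projection step replaced by an identity mapping for illustration:

import torch
from torch import optim

target = torch.rand(1, 32, 32, 3)                       # stand-in for the rendered target
estimate = (target + 0.1 * torch.randn_like(target)).detach()
estimate.requires_grad = True                           # optimize the image pixels directly

criterion = torch.nn.MSELoss()
optimizer = optim.Adam([estimate], lr=1e-2)

for it in range(100):
    optimizer.zero_grad()
    loss = criterion(estimate * 255, target * 255)      # same scaled-MSE objective as above
    loss.backward()
    optimizer.step()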
Example No. 2
    def tensorboard_normal_hook(self, grad):

        self.writer.add_image("normal_gradient_im",
                              torch.sqrt(torch.sum(grad**2, dim=-1)),
                              self.iterationa_no)
        self.writer.add_scalar("normal_gradient_mean_channel1",
                               get_data(torch.mean(torch.abs(grad[:, :, 0]))),
                               self.iterationa_no)
        self.writer.add_scalar("normal_gradient_mean_channel2",
                               get_data(torch.mean(torch.abs(grad[:, :, 1]))),
                               self.iterationa_no)
        self.writer.add_scalar("normal_gradient_mean_channel3",
                               get_data(torch.mean(torch.abs(grad[:, :, 2]))),
                               self.iterationa_no)
        self.writer.add_scalar("normal_gradient_mean",
                               get_data(torch.mean(grad)), self.iterationa_no)
        self.writer.add_histogram("normal_gradient_hist_channel1",
                                  grad[:, :, 0].clone().cpu().data.numpy(),
                                  self.iterationa_no)
        self.writer.add_histogram("normal_gradient_hist_channel2",
                                  grad[:, :, 1].clone().cpu().data.numpy(),
                                  self.iterationa_no)
        self.writer.add_histogram("normal_gradient_hist_channel3",
                                  grad[:, :, 2].clone().cpu().data.numpy(),
                                  self.iterationa_no)
        self.writer.add_histogram(
            "normal_gradient_hist_norm",
            torch.sqrt(torch.sum(grad**2, dim=-1)).clone().cpu().data.numpy(),
            self.iterationa_no)
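
This method is meant to be registered as a backward hook on the normal tensor, with self.writer a tensorboard SummaryWriter. A minimal sketch of how such a hook is attached (the tensor shape, log directory, and the log_normal_grad helper are illustrative, not from the repository):

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./runs/normal_grads')   # assumed log directory
normals = torch.randn(128, 128, 3, requires_grad=True)

def log_normal_grad(grad, step=0):
    # Mirrors tensorboard_normal_hook: log gradient statistics at each backward pass
    writer.add_scalar("normal_gradient_mean", grad.abs().mean().item(), step)
    writer.add_histogram("normal_gradient_hist", grad.cpu().numpy(), step)

normals.register_hook(log_normal_grad)   # hook receives d(loss)/d(normals)
loss = (normals ** 2).sum()
loss.backward()                          # triggers the hook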
Example No. 3
def test_raster_coordinates(scene, batch_size):
    """Test if the projected raster coordinates are correct

    Args:
        scene: Path to scene file
        batch_size: Number of copies stacked along the batch dimension

    Returns:
        None

    """
    res = render_scene(scene)
    scene = make_torch_var(load_scene(scene))
    pos_cc = res['pos'].reshape(1, -1, res['pos'].shape[-1])
    pos_cc = pos_cc.repeat(batch_size, 1, 1)

    camera = scene['camera']
    camera['eye'] = camera['eye'].repeat(batch_size, 1)
    camera['at'] = camera['at'].repeat(batch_size, 1)
    camera['up'] = camera['up'].repeat(batch_size, 1)

    viewport = make_list2np(camera['viewport'])
    W, H = float(viewport[2] - viewport[0]), float(viewport[3] - viewport[1])
    px_coord_idx, px_coord = project_image_coordinates(pos_cc, camera)
    xp, yp = np.meshgrid(np.linspace(0, W - 1, int(W)),
                         np.linspace(0, H - 1, int(H)))
    xp = xp.ravel()[None, ...].repeat(batch_size, axis=0)
    yp = yp.ravel()[None, ...].repeat(batch_size, axis=0)

    px_coord = torch.round(px_coord - 0.5).long()

    np.testing.assert_array_almost_equal(xp, get_data(px_coord[..., 0]))
    np.testing.assert_array_almost_equal(yp, get_data(px_coord[..., 1]))
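
The torch.round(px_coord - 0.5) step relies on the convention that pixel i spans [i, i + 1) with its center at i + 0.5, so rounding recovers integer indices from center coordinates. A small numpy illustration of that convention (independent of project_image_coordinates):

import numpy as np

# Continuous coordinates at the pixel centers of a 4-pixel-wide image
centers = np.arange(4) + 0.5              # [0.5, 1.5, 2.5, 3.5]
indices = np.round(centers - 0.5).astype(np.int64)
assert (indices == np.arange(4)).all()    # recovers the integer pixel indices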
Example No. 4
def test_render(scene, lfnet, num_samples=200):
    res = render_scene(scene)

    pos = get_data(res['pos'])
    normal = get_data(res['normal'])

    im = lf_renderer(pos, normal, lfnet, num_samples=num_samples)
    im_ = get_data(im)
    im_ = im_ / im_.max()

    plt.figure()
    plt.imshow(im_)
    plt.show()
Example No. 5
def test_transformation_consistency(scene, batch_size):
    print('test_transformation_consistency')
    res = render_scene(scene)
    scene = make_torch_var(load_scene(scene))
    pos_cc = res['pos'].reshape(-1, res['pos'].shape[-1])
    normal_cc = res['normal'].reshape(-1, res['normal'].shape[-1])
    surfels = cam_to_world(pos_cc, normal_cc, scene['camera'])
    surfels_cc = world_to_cam(surfels['pos'], surfels['normal'],
                              scene['camera'])

    np.testing.assert_array_almost_equal(get_data(pos_cc),
                                         get_data(surfels_cc['pos'][:, :3]))
    np.testing.assert_array_almost_equal(get_data(normal_cc),
                                         get_data(surfels_cc['normal'][:, :3]))
Example No. 6
def render_scene(scene,
                 output_folder,
                 norm_depth_image_only=False,
                 backface_culling=False,
                 plot_res=True):
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    # main render run
    res = render(scene,
                 norm_depth_image_only=norm_depth_image_only,
                 backface_culling=backface_culling)
    im = get_data(res['image'])
    im_nearest = get_data(res['nearest'])
    obj_pixel_count = get_data(
        res['obj_pixel_count']) if 'obj_pixel_count' in res else None

    if plot_res:
        plt.ion()
        plt.figure()
        plt.imshow(im)
        plt.title('Final Rendered Image')
        plt.savefig(output_folder + '/img_torch.png')

        plt.figure()
        plt.imshow(im_nearest)
        plt.title('Nearest Object Index')
        plt.colorbar()
        plt.savefig(output_folder + '/img_nearest.png')

        if obj_pixel_count is not None:  # guard: key may be absent from res
            plt.figure()
            plt.plot(obj_pixel_count, 'r-+')
            plt.xlabel('Object Index')
            plt.ylabel('Number of Pixels')

    depth = get_data(res['depth'])
    depth[depth >= scene['camera']['far']] = np.inf
    print(depth.min(), depth.max())
    if plot_res and depth.min() != np.inf:
        plt.figure()
        plt.imshow(depth)
        plt.title('Depth Image')
        plt.savefig(output_folder + '/img_depth_torch.png')

    if plot_res:
        plt.ioff()
        plt.show()

    return res
Example No. 7
def uniformly_rotate_cameras(camera,
                             theta_range=[-np.pi / 2, np.pi / 2],
                             phi_range=[-np.pi, np.pi]):
    """
    Given a batch of camera positions, rotate the 'eye' properties around the 'lookat' position uniformly
    for the given angle ranges (equiangular samples)
    :param camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...], 'up': [num_batches,...],
                    'viewport': [0, 0, W, H], 'fovy': <radians>}]

    Modifies the camera object in place

    ASSUMES A SQUARE NUMBER OF CAMERAS if both theta_range and phi_range are not None
    """
    num_cameras = get_data(camera['eye']).shape[0]

    # Sample a theta and phi to add to the current camera rotation
    if theta_range is not None and phi_range is not None:
        width = int(np.sqrt(num_cameras))
        phi_samples, theta_samples = np.meshgrid(
            np.linspace(*phi_range, width), np.linspace(*theta_range, width))
        theta_samples, phi_samples = theta_samples.ravel(), phi_samples.ravel()
    elif theta_range is not None:
        theta_samples = np.linspace(*theta_range, num_cameras)
        phi_samples = np.zeros(theta_samples.shape)
    elif phi_range is not None:
        phi_samples = np.linspace(*phi_range, num_cameras)
        theta_samples = np.zeros(phi_samples.shape)
    else:
        # Neither range given: no rotation to apply
        theta_samples = np.zeros(num_cameras)
        phi_samples = np.zeros(num_cameras)

    rotate_cameras(camera, theta=theta_samples, phi=phi_samples)
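
When both ranges are given, the samples form an equiangular theta-phi grid, which is why the camera count must be a perfect square. A sketch of that layout for 9 cameras, using plain numpy only:

import numpy as np

num_cameras = 9                           # must be a perfect square here
width = int(np.sqrt(num_cameras))         # 3 samples per angle
phi, theta = np.meshgrid(np.linspace(-np.pi, np.pi, width),
                         np.linspace(-np.pi / 2, np.pi / 2, width))
theta, phi = theta.ravel(), phi.ravel()   # one (theta, phi) pair per camera
print(theta.shape)                        # (9,)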
Example No. 8
def batch_render_random_camera(filename, cam_dist, num_views, width, height,
                               fovy, focal_length, theta_range=None, phi_range=None,
                               axis=None, angle=None, cam_pos=None, cam_lookat=None,
                               double_sided=False, use_quartic=False, b_shadow=True,
                               tile_size=None, save_image_queue=None):
    rendering_time = []

    obj = load_model(filename)
    # normalize the vertices
    v = obj['v']
    axis_range = np.max(v, axis=0) - np.min(v, axis=0)
    v = (v - np.mean(v, axis=0)) / max(axis_range)  # Normalize to make the largest spread 1
    obj['v'] = v

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length

    mesh = obj_to_triangle_spec(obj)
    faces = mesh['face']
    normals = mesh['normal']
    num_tri = faces.shape[0]

    if 'disk' in scene['objects']:
        del scene['objects']['disk']
    scene['objects'].update({'triangle': {'face': None, 'normal': None, 'material_idx': None}})
    scene['objects']['triangle']['face'] = tch_var_f(faces.tolist())
    scene['objects']['triangle']['normal'] = tch_var_f(normals.tolist())
    scene['objects']['triangle']['material_idx'] = tch_var_l(np.zeros(num_tri, dtype=int).tolist())

    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist, num_samples=num_views,
                                        axis=axis, angle=angle,
                                        theta_range=theta_range, phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else np.mean(v, axis=0)
    scene['camera']['at'] = tch_var_f(lookat)

    for idx in range(cam_pos.shape[0]):
        scene['camera']['eye'] = tch_var_f(cam_pos[idx])

        # main render run
        start_time = time()
        res = render(scene, tile_size=tile_size, tiled=tile_size is not None,
                     shadow=b_shadow, double_sided=double_sided,
                     use_quartic=use_quartic)
        res['suffix'] = '_{}'.format(idx)
        res['camera_far'] = scene['camera']['far']
        save_image_queue.put_nowait(get_data(res))
        rendering_time.append(time() - start_time)

    # Timing statistics
    print('Rendering time mean: {}s, std: {}s'.format(np.mean(rendering_time), np.std(rendering_time)))
Example No. 9
    def tensorboard_hook(self, grad):
        self.writer.add_scalar("z_gradient_mean",
                               get_data(torch.mean(grad[0])),
                               self.iterationa_no)
        self.writer.add_histogram("z_gradient_hist_channel",
                                  grad[0].clone().cpu().data.numpy(),
                                  self.iterationa_no)

        self.writer.add_image("z_gradient_im",
                              grad[0].view(self.opt.splats_img_size,
                                           self.opt.splats_img_size),
                              self.iterationa_no)
Example No. 10
def save_to_file(out_dir, queue):
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    while True:
        res = queue.get()  # blocks until the producer puts an item (no busy-wait)
        if res is None:
            print("terminating")
            break
        suffix = res['suffix']
        im = np.uint8(255. * get_data(res['image']))
        depth = get_data(res['depth'])

        depth[depth >= res['camera_far']] = depth.min()
        im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

        imsave(out_dir + '/img' + suffix + '.png', im)
        imsave(out_dir + '/depth' + suffix + '.png', im_depth)
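
save_to_file is written to run as a consumer process: a render loop pushes result dicts onto the queue and finally pushes None as a termination sentinel. A hedged usage sketch (the output path is illustrative; the producer side is elided):

import multiprocessing as mp

queue = mp.Queue()
writer = mp.Process(target=save_to_file, args=('./out', queue))
writer.start()

# ... for each rendered view the producer does: queue.put_nowait(get_data(res)) ...

queue.put(None)   # sentinel: tells save_to_file to terminate
writer.join()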
Example No. 11
    def tensorboard_z_hook(self, grad):

        self.writer.add_scalar("z_gradient_mean",
                               get_data(torch.mean(torch.abs(grad))),
                               self.iterationa_no)
        self.writer.add_histogram("z_gradient_hist_channel",
                                  grad.clone().cpu().data.numpy(),
                                  self.iterationa_no)

        self.writer.add_image("z_gradient_im", grad, self.iterationa_no)
Example No. 12
def rotate_cameras(camera, theta=0, phi=0):
    # Get the current camera rotation (relative to the 'lookat' position)
    camera_eye = cartesian_to_spherical(
        get_data(camera['eye']) - get_data(camera['at']))

    # Rotate the camera
    new_thetas = camera_eye[..., 0] + theta
    new_phis = camera_eye[..., 1] + phi

    # Go back to cartesian coordinates and place the camera back relative to the 'lookat' position
    camera_eye = spherical_to_cartesian(new_thetas,
                                        new_phis,
                                        radius=np.expand_dims(
                                            camera_eye[..., 2], -1))

    if camera['at'].shape[-1] == 4:
        zeros = np.zeros((camera_eye.shape[0], 1))
        camera_eye = np.concatenate((camera_eye, zeros), axis=-1)

    camera['eye'] = tch_var_f(camera_eye) + camera['at']
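
The rotation converts the eye offset to spherical coordinates, adds the angle deltas, and converts back, so the radius is preserved. A minimal numpy sketch of that roundtrip under one common (theta, phi, r) convention; the repository's cartesian_to_spherical may order components differently:

import numpy as np

def to_spherical(v):
    r = np.linalg.norm(v)
    theta = np.arccos(v[2] / r)        # polar angle from +z
    phi = np.arctan2(v[1], v[0])       # azimuth in the xy-plane
    return theta, phi, r

def to_cartesian(theta, phi, r):
    return r * np.array([np.sin(theta) * np.cos(phi),
                         np.sin(theta) * np.sin(phi),
                         np.cos(theta)])

eye, at = np.array([2.0, 1.0, 2.0]), np.zeros(3)
theta, phi, r = to_spherical(eye - at)
new_eye = to_cartesian(theta, phi + np.pi / 12, r) + at  # rotate azimuth by 15 degrees
print(np.linalg.norm(new_eye - at), r)                   # radius is preserved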
Example No. 13
def test_render_projection_consistency(scene, batch_size):
    """ First render using the full renderer to get the surfel position and color
    and then render using the projection layer for testing

    Returns:
        None
    """
    res = render_scene(scene)

    scene = make_torch_var(load_scene(scene))
    pos_cc = res['pos'].reshape(-1,
                                res['pos'].shape[-1]).repeat(batch_size, 1, 1)

    camera = scene['camera']
    camera['eye'] = camera['eye'].repeat(batch_size, 1)
    camera['at'] = camera['at'].repeat(batch_size, 1)
    camera['up'] = camera['up'].repeat(batch_size, 1)

    image = res['image'].repeat(batch_size, 1, 1, 1)

    im, mask = projection_renderer(pos_cc, image, camera)
    diff = np.abs(get_data(image) - get_data(im))
    np.testing.assert_(diff.sum() < 1e-10, 'Non-zero difference.')
Example No. 14
def randomly_rotate_cameras(camera,
                            theta_range=[-np.pi / 2, np.pi / 2],
                            phi_range=[-np.pi, np.pi]):
    """
    Given a batch of camera positions, rotate the 'eye' properties around the 'lookat' position
    :param camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...], 'up': [num_batches,...],
                    'viewport': [0, 0, W, H], 'fovy': <radians>}]
    
    Modifies the camera object in place
    """

    # Sample a theta and phi to add to the current camera rotation
    theta_samples, phi_samples = uniform_sample_sphere_patch(
        get_data(camera['eye']).shape[0], theta_range, phi_range)

    # Flip the theta samples to allow users to enter a positive value for theta range
    rotate_cameras(camera, theta=-theta_samples, phi=phi_samples)
Example No. 15
def test_depth_to_world_consistency(scene, batch_size):
    res = render_scene(scene)

    scene = make_torch_var(load_scene(scene))

    pos_wc1 = res['pos'].reshape(-1, res['pos'].shape[-1])

    pos_cc1 = world_to_cam(pos_wc1, None, scene['camera'])['pos']
    # First test the non-batched z_to_pcl_CC method:
    # NOTE: z_to_pcl_CC takes as input the Z dimension in the camera coordinate
    # and gets the full (X, Y, Z) in the camera coordinate.
    pos_cc2 = z_to_pcl_CC(pos_cc1[:, 2], scene['camera'])
    pos_wc2 = cam_to_world(pos_cc2, None, scene['camera'])['pos']

    # Test Z -> (X, Y, Z)
    np.testing.assert_array_almost_equal(get_data(pos_cc1[..., :3]),
                                         get_data(pos_cc2[..., :3]))
    # Test world -> camera -> Z -> (X, Y, Z) in camera -> world
    np.testing.assert_array_almost_equal(get_data(pos_wc1[..., :3]),
                                         get_data(pos_wc2[..., :3]))

    # Then test the batched version:
    camera = scene['camera']
    camera['eye'] = camera['eye'].repeat(batch_size, 1)
    camera['at'] = camera['at'].repeat(batch_size, 1)
    camera['up'] = camera['up'].repeat(batch_size, 1)

    pos_wc1 = pos_wc1.repeat(batch_size, 1, 1)
    pos_cc1 = world_to_cam_batched(pos_wc1, None, scene['camera'])['pos']
    pos_cc2 = z_to_pcl_CC_batched(pos_cc1[..., 2], camera)  # NOTE: z = -depth
    pos_wc2 = cam_to_world_batched(pos_cc2, None, camera)['pos']

    # Test Z -> (X, Y, Z)
    np.testing.assert_array_almost_equal(get_data(pos_cc1[..., :3]),
                                         get_data(pos_cc2[..., :3]))
    # Test world -> camera -> Z -> (X, Y, Z) in camera -> world
    np.testing.assert_array_almost_equal(get_data(pos_wc1[..., :3]),
                                         get_data(pos_wc2[..., :3]))
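
z_to_pcl_CC can recover full camera-space points from Z alone because, under a pinhole model, X and Y are fixed multiples of Z along each pixel's ray. A one-pixel numpy sketch of that idea, assuming a focal length f and normalized image-plane coordinates (not the repository's exact conventions):

f = 1.0                       # focal length (assumed)
x_img, y_img = 0.25, -0.1     # pixel position on the image plane
z = -3.0                      # camera-space Z (negative: in front of the camera)

# The ray through this pixel is t * (x_img, y_img, -f); choose t so the third
# component equals z, then X and Y follow from similar triangles.
X = x_img * (-z) / f
Y = y_img * (-z) / f
print(X, Y, z)                # full (X, Y, Z) recovered from Z alone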
Example No. 16
def test_depth_optimization(scene,
                            batch_size,
                            print_interval=20,
                            imsave_interval=20,
                            max_iter=100,
                            out_dir='./proj_tmp_depth-fast/'):
    """ First render using the full renderer to get the surfel position and color
    and then render using the projection layer for testing

    Returns:
        None
    """
    from torch import optim
    import torchvision
    import os
    import matplotlib
    matplotlib.use('agg')
    import matplotlib.pyplot as plt
    import imageio
    from PIL import Image
    plt.ion()

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    use_chair = True
    use_fast_projection = True
    use_masked_loss = True
    use_same_render_method_for_target = False
    lr = 1e-2

    res = render_scene(scene)

    scene = make_torch_var(load_scene(scene))
    # true_pos_wc = res['pos'].reshape(-1, res['pos'].shape[-1]).repeat(batch_size, 1, 1)
    true_input_img = res['image'].unsqueeze(0).repeat(batch_size, 1, 1, 1)

    if use_chair:
        camera = scene['camera']
        camera['eye'] = tch_var_f([0, 0, 4, 1]).repeat(batch_size, 1)
        camera['at'] = tch_var_f([0, 0, 0, 1]).repeat(batch_size, 1)
        camera['up'] = tch_var_f([0, 1, 0, 0]).repeat(batch_size, 1)
    else:
        camera = scene['camera']
        camera['eye'] = camera['eye'].repeat(batch_size, 1)
        camera['at'] = camera['at'].repeat(batch_size, 1)
        camera['up'] = camera['up'].repeat(batch_size, 1)

    if use_chair:
        chair_0 = Image.open('object-0-azimuth-000.006-rgb.png')
        true_input_img = torchvision.transforms.ToTensor()(chair_0).to(
            true_input_img.device).unsqueeze(0)
        true_input_img = true_input_img.permute(0, 2, 3, 1)
        camera['viewport'] = [0, 0, 128, 128]  # TODO don't hardcode

    true_depth = res['depth'].repeat(batch_size, 1, 1).reshape(
        batch_size, -1)  # Not relevant if 'use_chair' is True
    # depth = true_depth.clone() + 0.1 * torch.randn_like(true_depth)
    depth = 0.1 * torch.randn(
        batch_size,
        true_input_img.size(-2) * true_input_img.size(-3),
        device=true_input_img.device,
        dtype=torch.float)
    depth.requires_grad = True

    if use_chair:
        target_angle = np.deg2rad(20)
    else:
        target_angle = -np.pi / 12

    rotated_camera = copy.deepcopy(camera)
    # randomly_rotate_cameras(rotated_camera, theta_range=[-np.pi / 16, np.pi / 16], phi_range=[-np.pi / 8, np.pi / 8])
    randomly_rotate_cameras(rotated_camera,
                            theta_range=[0, 1e-10],
                            phi_range=[target_angle, target_angle + 1e-10])

    if use_chair:
        target_image = Image.open('object-0-azimuth-020.006-rgb.png')
        target_image = torchvision.transforms.ToTensor()(target_image).to(
            true_input_img.device).unsqueeze(0)
        target_image = target_image.permute(0, 2, 3, 1)
        target_mask = torch.ones(*target_image.size()[:-1],
                                 1,
                                 device=target_image.device,
                                 dtype=torch.float)
    else:
        true_pos_cc = z_to_pcl_CC_batched(-true_depth,
                                          camera)  # NOTE: z = -depth
        true_pos_wc = cam_to_world_batched(true_pos_cc, None, camera)['pos']

        if use_same_render_method_for_target:
            if use_fast_projection:
                target_image, proj_out = projection_renderer_differentiable_fast(
                    true_pos_wc, true_input_img, rotated_camera)
                target_mask = proj_out['mask']
            else:
                target_image, target_mask = projection_renderer_differentiable(
                    true_pos_wc, true_input_img, rotated_camera)
            # target_image, _ = projection_renderer(true_pos_wc, true_input_img, rotated_camera)
        else:
            scene2 = copy.deepcopy(scene)
            scene['camera'] = copy.deepcopy(rotated_camera)
            scene['camera']['eye'] = scene['camera']['eye'][0]
            scene['camera']['at'] = scene['camera']['at'][0]
            scene['camera']['up'] = scene['camera']['up'][0]
            target_image = render(scene)['image'].unsqueeze(0).repeat(
                batch_size, 1, 1, 1)
            target_mask = torch.ones(*target_image.size()[:-1],
                                     1,
                                     device=target_image.device,
                                     dtype=torch.float)

    input_image = true_input_img  # + 0.1 * torch.randn(target_image.size(), device=target_image.device)

    criterion = torch.nn.MSELoss(reduction='none').cuda()
    optimizer = optim.Adam([depth], lr=lr)

    h1 = plt.figure()
    # fig_imgs = []
    depth_imgs = []
    out_imgs = []

    imageio.imsave(out_dir + '/optimization_input_image.png',
                   input_image[0].cpu().numpy())
    imageio.imsave(out_dir + '/optimization_target_image.png',
                   target_image[0].cpu().numpy())
    if not use_chair:
        imageio.imsave(
            out_dir + '/optimization_target_depth.png',
            true_depth.view(*input_image.size()[:-1], 1)[0].cpu().numpy())

    loss_per_iter = []
    for iter in range(max_iter):
        optimizer.zero_grad()
        # depth_in = torch.nn.functional.softplus(depth + 3)
        depth_in = depth + 4
        pos_cc = z_to_pcl_CC_batched(-depth_in, camera)  # NOTE: z = -depth
        pos_wc = cam_to_world_batched(pos_cc, None, camera)['pos']
        if use_fast_projection:
            im_est, proj_out = projection_renderer_differentiable_fast(
                pos_wc, input_image, rotated_camera)
            im_mask = proj_out['mask']
        else:
            im_est, im_mask = projection_renderer_differentiable(
                pos_wc, input_image, rotated_camera)
        # im_est, mask = projection_renderer(pos_wc, input_image, rotated_camera)
        if use_masked_loss:
            loss = torch.sum(target_mask * im_mask * criterion(
                im_est * 255, target_image * 255)) / torch.sum(
                    target_mask * im_mask)
        else:
            loss = criterion(im_est * 255, target_image * 255).mean()
        loss_ = get_data(loss)
        loss_per_iter.append(loss_)
        if iter % print_interval == 0 or iter == max_iter - 1:
            print('{}. Loss: {}'.format(iter, loss_))
        if iter % imsave_interval == 0 or iter == max_iter - 1:
            # Input image
            # im_out_ = get_data(input_image.detach())
            # im_out_ = np.uint8(255 * im_out_ / im_out_.max())
            # fig = plt.figure(h1.number)
            # plot = fig.add_subplot(111)
            # plot.imshow(im_out_[0].squeeze())
            # plot.set_title('%d. loss= %f' % (iter, loss_))
            # # plt.savefig(out_dir + '/fig_%05d.png' % iter)
            # fig_data = np.array(fig.canvas.renderer._renderer)
            # fig_imgs.append(fig_data)

            # Depth
            im_out_ = get_data(
                depth_in.view(*input_image.size()[:-1], 1).detach())
            im_out_ = np.uint8(255 * im_out_ / im_out_.max())
            fig = plt.figure(h1.number)
            plot = fig.add_subplot(111)
            plot.imshow(im_out_[0].squeeze())
            plot.set_title('%d. loss= %f' % (iter, loss_))
            # plt.savefig(out_dir + '/fig_%05d.png' % iter)
            depth_data = np.array(fig.canvas.renderer._renderer)
            depth_imgs.append(depth_data)

            # Output image
            im_out_ = get_data(im_est.detach())
            im_out_ = np.uint8(255 * im_out_ / im_out_.max())
            fig = plt.figure(h1.number)
            plot = fig.add_subplot(111)
            plot.imshow(im_out_[0].squeeze())
            plot.set_title('%d. loss= %f' % (iter, loss_))
            # plt.savefig(out_dir + '/fig_%05d.png' % iter)
            out_data = np.array(fig.canvas.renderer._renderer)
            out_imgs.append(out_data)

        loss.backward()
        optimizer.step()

    # imageio.mimsave(out_dir + '/optimization_anim_in.gif', fig_imgs)
    imageio.mimsave(out_dir + '/optimization_anim_depth.gif', depth_imgs)
    imageio.mimsave(out_dir + '/optimization_anim_out.gif', out_imgs)
Example No. 17
def optimize_splats_along_ray_shadow_with_normalest_test(
        out_dir,
        width,
        height,
        max_iter=100,
        lr=1e-3,
        scale=10,
        shadow=True,
        vis_only=False,
        samples=1,
        est_normals=False,
        b_generate_normals=False,
        print_interval=10,
        imsave_interval=10,
        xyz_save_interval=100):
    """A demo function to check if the differentiable renderer can optimize splats rendered along ray.
    :param scene:
    :param out_dir:
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX_0

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX_0
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f(
        [2, 1, 2, 1])  # tch_var_f([1, 1, 1, 1]) # tch_var_f([2, 2, 2, 1]) #
    scene['camera']['at'] = tch_var_f(
        [0, 0.8, 0, 1])  # tch_var_f([0, 1, 0, 1]) # tch_var_f([2, 2, 0, 1])  #
    scene['lights']['attenuation'] = tch_var_f([
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
    ])
    scene['materials']['coeffs'] = tch_var_f([
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 0.2, 8.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
    ])

    target_res = render(scene, tiled=True, shadow=shadow)
    target_im = normalize_maxmin(target_res['image'])
    target_im = target_im.detach()  # 'require_grad' was a typo with no effect
    target_im_ = get_data(target_im)
    target_pos_ = get_data(target_res['pos'])
    target_normal_ = get_data(target_res['normal'])
    target_normalmap_img_ = get_normalmap_image(target_normal_)
    target_depth_ = get_data(target_res['depth'])
    print('[z_min, z_max] = [%f, %f]' %
          (np.min(target_pos_[..., 2]), np.max(target_pos_[..., 2])))
    print('[depth_min, depth_max] = [%f, %f]' %
          (np.min(target_depth_), np.max(target_depth_)))

    # world -> cam -> render_splats_along_ray
    cc_tform = world_to_cam(target_res['pos'].view(
        (-1, 3)), target_res['normal'].view((-1, 3)), scene['camera'])
    wc_cc_tform = cam_to_world(cc_tform['pos'], cc_tform['normal'],
                               scene['camera'])

    # Check normal estimation in camera space
    pos_cc = cc_tform['pos'][:, :3].contiguous().view(target_im.shape)
    normal_cc = cc_tform['normal'][:, :3].contiguous().view(target_im.shape)
    plane_fit_est = estimate_surface_normals_plane_fit(pos_cc, None)
    normal_cc_normalmap = get_normalmap_image(get_data(normal_cc))
    plane_fit_est_normalmap = get_normalmap_image(get_data(plane_fit_est))

    pos_diff = torch.abs(wc_cc_tform['pos'][:, :3] -
                         target_res['pos'].view((-1, 3)))
    mean_pos_diff = torch.mean(pos_diff)
    normal_diff = torch.abs(wc_cc_tform['normal'][:, :3] -
                            target_res['normal'].view(-1, 3))
    mean_normal_diff = torch.mean(normal_diff)
    print('mean_pos_diff', mean_pos_diff, 'mean_normal_diff', mean_normal_diff)

    wc_cc_normal = wc_cc_tform['normal'].view(target_im_.shape)
    wc_cc_normal_img = get_normalmap_image(get_data(wc_cc_normal))

    material_idx = tch_var_l(np.ones(cc_tform['pos'].shape[0]) * 3)
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    light_vis = tch_var_f(
        np.ones(
            (input_scene['lights']['pos'].shape[0], cc_tform['pos'].shape[0])))
    input_scene['objects'] = {
        'disk': {
            'pos': cc_tform['pos'],
            'normal': cc_tform['normal'],
            'material_idx': material_idx,
            'light_vis': light_vis,
        }
    }
    target_res_noshadow = render(scene, tiled=True, shadow=False)
    res = render_splats_along_ray(input_scene)
    test_img_ = get_data(normalize_maxmin(res['image']))
    test_depth_ = get_data(res['depth'])
    test_normal_ = get_data(res['normal']).reshape(test_img_.shape)
    test_normalmap_ = get_normalmap_image(test_normal_)
    im_diff = np.abs(test_img_ -
                     get_data(normalize_maxmin(target_res_noshadow['image'])))
    print('mean image diff: {}'.format(np.mean(im_diff)))
    #### PLOT
    plt.ion()
    plt.figure()
    plt.imshow(test_img_, interpolation='none')
    plt.title('Test Image')
    plt.savefig(out_dir + '/test_img.png')
    plt.figure()
    plt.imshow(test_depth_, interpolation='none')
    plt.title('Test Depth')
    plt.savefig(out_dir + '/test_depth.png')

    plt.figure()
    plt.imshow(test_normalmap_, interpolation='none')
    plt.title('Test Normals')
    plt.savefig(out_dir + '/test_normal.png')

    ####
    criterion = nn.L1Loss()  #nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_, interpolation='none')
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    plt.figure()
    plt.imshow(target_normalmap_img_, interpolation='none')
    plt.title('Normals')
    plt.savefig(out_dir + '/normal.png')

    plt.figure()
    plt.imshow(wc_cc_normal_img, interpolation='none')
    plt.title('WC_CC Normals')
    plt.savefig(out_dir + '/wc_cc_normal.png')

    plt.figure()
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.savefig(out_dir + '/normal_cc.png')

    plt.figure()
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/est_normal_cc.png')

    plt.figure()
    plt.subplot(121)
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.subplot(122)
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/normal_and_estnormal_cc_comparison.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    input_scene['camera']['viewport'] = [
        0, 0, int(width / samples),
        int(height / samples)
    ]

    num_splats = int(width * height / (samples * samples))
    #x, y = np.meshgrid(np.linspace(-1, 1, int(width / samples)), np.linspace(-1, 1, int(height / samples)))
    z_min = scene['camera']['focal_length']
    z_max = 3

    z = -tch_var_f(
        np.ones(num_splats) * (z_min + z_max) / 2
    )  # -torch.clamp(tch_var_f(2 * np.random.rand(num_splats)), z_min, z_max)
    z.requires_grad = True

    normal_angles = tch_var_f(np.random.rand(num_splats, 2))
    normal_angles.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0], num_splats)))
    light_vis.requires_grad = True

    if vis_only:
        assert shadow is True
        opt_vars = [light_vis]
        z = cc_tform['pos'][:, 2]
        # FIXME: sph2cart
        #normals = cc_tform['normal']
    else:
        opt_vars = [z, normal_angles]
        if shadow:
            opt_vars += [light_vis]

    optimizer = optim.Adam(opt_vars, lr=lr)
    lr_scheduler = StepLR(optimizer, step_size=10000, gamma=0.8)

    h0 = plt.figure()
    h1 = plt.figure()
    h2 = plt.figure()
    h3 = plt.figure()
    h4 = plt.figure()

    gs1 = gridspec.GridSpec(3, 3)
    gs1.update(wspace=0.0025, hspace=0.02)

    # Two options for z_norm_consistency
    # 1. start after N iterations
    # 2. start at the beginning and decay
    # 3. start after N iterations and decay to 0
    no_decay = lambda x: x
    exp_decay = lambda x, scale: torch.exp(-x / scale)
    linear_decay = lambda x, scale: scale / (x + 1e-6)

    spatial_var_loss_weight = 10.0  #0.0
    normal_away_from_cam_loss_weight = 0.0
    grad_img_depth_loss_weight = 1.0
    spatial_loss_weight = 2

    z_norm_weight_init = 1  # 1e-5
    z_norm_activate_iter = 0  # 1000
    decay_fn = lambda x: linear_decay(x, 100)
    loss_per_iter = []
    if b_generate_normals:
        est_normals = False
        normal_est_network = NEstNetAffine(kernel_size=3, sph=False)
        print(normal_est_network)
        normal_est_network.cuda()
    for iter in range(max_iter):
        zz = -F.relu(-z) - z_min  # torch.clamp(z, -z_max, -z_min)
        if b_generate_normals:
            normals = generate_normals(zz, scene['camera'], normal_est_network)
            #if iter > 100 and iter % 10 == 0:
            #    print(normals)
        elif not est_normals:
            phi = torch.sigmoid(normal_angles[:, 0]) * 2 * np.pi
            theta = torch.sigmoid(
                normal_angles[:, 1]
            ) * np.pi / 2  # alternative: torch.tanh(normal_angles[:, 1]) * np.pi / 2
            normals = sph2cart_unit(torch.stack((phi, theta), dim=1))

        pos = zz  # torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), zz), dim=1)

        input_scene['objects'] = {
            'disk': {
                'pos': pos,
                'normal': normalize(normals) if not est_normals else None,
                'material_idx': material_idx,
                'light_vis': torch.sigmoid(light_vis),
            }
        }
        res = render_splats_along_ray(input_scene,
                                      samples=samples,
                                      normal_estimation_method='plane')
        res_pos = res['pos']
        res_normal = res['normal']
        spatial_loss = spatial_3x3(res_pos)
        depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
        grad_img = grad_spatial2d(
            torch.mean(res['image'], dim=-1)[..., np.newaxis])
        grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)
        normal_away_from_cam_loss = away_from_camera_penalty(
            res_pos, res_normal)
        z_pos = res_pos[..., 2]
        z_loss = torch.mean((10 * F.relu(z_min - torch.abs(z_pos)))**2 +
                            (10 * F.relu(torch.abs(z_pos) - z_max))**2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() + res_pos[..., 2].var())
        spatial_var_loss = (1 / (spatial_var + 1e-4))
        im_out = normalize_maxmin(res['image'])
        res_depth_ = get_data(res['depth'])

        optimizer.zero_grad()
        z_norm_weight = z_norm_weight_init * float(
            iter > z_norm_activate_iter) * decay_fn(iter -
                                                    z_norm_activate_iter)
        loss = criterion(scale * im_out, scale * target_im) + z_loss + unit_normal_loss + \
            z_norm_weight * z_norm_loss + \
            spatial_var_loss_weight * spatial_var_loss + \
            grad_img_depth_loss_weight * image_depth_consistency_loss
        #normal_away_from_cam_loss_weight * normal_away_from_cam_loss + \
        #spatial_loss_weight * spatial_loss

        im_out_ = get_data(im_out)
        im_out_normal_ = get_data(res['normal'])
        pos_out_ = get_data(res['pos'])

        loss_ = get_data(loss)
        z_loss_ = get_data(z_loss)
        z_norm_loss_ = get_data(z_norm_loss)
        spatial_loss_ = get_data(spatial_loss)
        spatial_var_loss_ = get_data(spatial_var_loss)
        unit_normal_loss_ = get_data(unit_normal_loss)
        normal_away_from_cam_loss_ = get_data(normal_away_from_cam_loss)
        normals_ = get_data(res_normal)
        image_depth_consistency_loss_ = get_data(image_depth_consistency_loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            z__ = pos_out_[..., 2]
            print(
                '%d. loss= %f nloss=%f z_loss=%f [%f, %f] [%f, %f], z_normal_loss: %f,'
                ' spatial_var_loss: %f, normal_away_loss: %f'
                ' nz_range: [%f, %f], spatial_loss: %f, imd_loss: %f' %
                (iter, loss_, unit_normal_loss_, z_loss_, np.min(z_),
                 np.max(z_), np.min(z__), np.max(z__), z_norm_loss_,
                 spatial_var_loss_, normal_away_from_cam_loss_,
                 normals_[..., 2].min(), normals_[..., 2].max(), spatial_loss_,
                 image_depth_consistency_loss_))

        if iter % xyz_save_interval == 0 or iter == max_iter - 1:
            save_xyz(out_dir + '/res_{:05d}.xyz'.format(iter),
                     get_data(res_pos), get_data(res_normal))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            plt.figure(h4.number)
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(121)
            #plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.title('Output')
            plt.subplot(122)
            #plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.title('Ground truth')
            # plt.subplot(223)
            # plt.plot(loss_per_iter, linewidth=2)
            # plt.xlabel('Iteration', fontsize=14)
            # plt.title('Loss', fontsize=12)
            # plt.grid(True)
            plt.savefig(out_dir + '/fig_im_gt_loss_%05d.png' % iter)

            plt.figure(h1.number, figsize=(4, 4))
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(gs1[0])
            plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.subplot(gs1[1])
            plt.axis('off')
            plt.imshow(get_normalmap_image(im_out_normal_),
                       interpolation='none')
            ax = plt.subplot(gs1[2])
            plt.axis('off')
            im_tmp = ax.imshow(res_depth_, interpolation='none')
            # create an axes on the right side of ax. The width of cax will be 5%
            # of ax and the padding between cax and ax will be fixed at 0.05 inch.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            plt.subplot(gs1[3])
            plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.subplot(gs1[4])
            plt.axis('off')
            plt.imshow(test_normalmap_, interpolation='none')
            ax = plt.subplot(gs1[5])
            plt.axis('off')
            im_tmp = ax.imshow(test_depth_, interpolation='none')
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            W, H = input_scene['camera']['viewport'][2:]
            light_vis_ = get_data(torch.sigmoid(light_vis))
            plt.subplot(gs1[6])
            plt.axis('off')
            plt.imshow(light_vis_[0].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 1):
                plt.subplot(gs1[7])
                plt.axis('off')
                plt.imshow(light_vis_[1].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 2):
                plt.subplot(gs1[8])
                plt.axis('off')
                plt.imshow(light_vis_[2].reshape((H, W)), interpolation='none')

            plt.savefig(out_dir + '/fig_%05d.png' % iter)

            plt.figure(h2.number)
            plt.clf()
            plt.imshow(res_depth_)
            plt.colorbar()
            plt.savefig(out_dir + '/fig_depth_%05d.png' % iter)

            plt.figure(h3.number)
            plt.clf()
            plt.imshow(z_.reshape(H, W))
            plt.colorbar()
            plt.savefig(out_dir + '/fig_z_%05d.png' % iter)

        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # step the LR schedule after the optimizer update

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
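
The z-normal consistency term above is gated by an activation iteration and a decay schedule. A standalone sketch of that weighting logic, reusing the same decay functions as the snippet:

import numpy as np

no_decay = lambda x: x
exp_decay = lambda x, scale: np.exp(-x / scale)
linear_decay = lambda x, scale: scale / (x + 1e-6)

z_norm_weight_init = 1.0
z_norm_activate_iter = 0
decay_fn = lambda x: linear_decay(x, 100)

for it in [0, 10, 100, 1000]:
    w = z_norm_weight_init * float(it > z_norm_activate_iter) \
        * decay_fn(it - z_norm_activate_iter)
    print(it, w)   # weight is 0 until activation, then decays as 100 / iter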
Example No. 18
def render_sphere_world(out_dir, cam_pos, radius, width, height, fovy, focal_length,
                        b_display=False):
    """
    Generate z positions on a grid fixed inside the view frustum in the world coordinate system. Place the camera and
    choose the camera's field of view so that the side of the square touches the frustum.
    """
    import copy
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height
    r = np.ones(num_samples) * radius

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['radius'] = tch_var_f(r)
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    large_scene['objects']['disk']['pos'] = tch_var_f(pos)
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    large_scene['camera']['eye'] = tch_var_f(cam_pos)

    # main render run
    start_time = time()
    res = render(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
Example No. 19
    def get_real_samples(self):
        """Get a real sample."""
        # Define the camera poses
        if not self.opt.same_view:
            if self.opt.full_sphere_sampling:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta,
                    phi_range=self.opt.phi)
            else:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
        if self.opt.full_sphere_sampling_light:
            self.light_pos1 = uniform_sample_sphere(
                radius=self.opt.cam_dist,
                num_samples=self.opt.batchSize,
                axis=self.opt.axis,
                angle=np.deg2rad(44),
                theta_range=self.opt.theta,
                phi_range=self.opt.phi)
            # self.light_pos2 = uniform_sample_sphere(radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
            #                                      axis=self.opt.axis, angle=np.deg2rad(40),
            #                                      theta_range=self.opt.theta, phi_range=self.opt.phi)
        else:
            print("inbox")
            light_eps = 0.15
            self.light_pos1 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps
            self.light_pos2 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps

            # TODO: deg2rad in all the angles????

        # Create a splats rendering scene
        large_scene = create_scene(self.opt.width, self.opt.height,
                                   self.opt.fovy, self.opt.focal_length,
                                   self.opt.n_splats)
        lookat = self.opt.at if self.opt.at is not None else [
            0.0, 0.0, 0.0, 1.0
        ]
        large_scene['camera']['at'] = tch_var_f(lookat)

        # Render scenes
        data, data_depth, data_normal, data_cond = [], [], [], []
        inpath = self.opt.vis_images + '/'
        inpath2 = self.opt.vis_input + '/'
        for idx in range(self.opt.batchSize):
            # Save the splats into the rendering scene
            samples = self.get_samples()  # needed by both branches below
            if self.opt.use_mesh:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'disk' in large_scene['objects']:
                    del large_scene['objects']['disk']
                if 'triangle' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'triangle': {
                            'face': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }

                large_scene['objects']['triangle']['material_idx'] = tch_var_l(
                    np.zeros(samples['mesh']['face'][0].shape[0],
                             dtype=int).tolist())
                large_scene['objects']['triangle']['face'] = Variable(
                    samples['mesh']['face'][0].cuda(), requires_grad=False)
                large_scene['objects']['triangle']['normal'] = Variable(
                    samples['mesh']['normal'][0].cuda(), requires_grad=False)
            else:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'triangle' in large_scene['objects']:
                    del large_scene['objects']['triangle']
                if 'disk' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'disk': {
                            'pos': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }
                large_scene['objects']['disk']['radius'] = tch_var_f(
                    np.ones(self.opt.n_splats) * self.opt.splats_radius)
                large_scene['objects']['disk']['material_idx'] = tch_var_l(
                    np.zeros(self.opt.n_splats, dtype=int).tolist())
                large_scene['objects']['disk']['pos'] = Variable(
                    samples['splats']['pos'][idx].cuda(), requires_grad=False)
                large_scene['objects']['disk']['normal'] = Variable(
                    samples['splats']['normal'][idx].cuda(),
                    requires_grad=False)

            # Set camera position
            if not self.opt.same_view:
                large_scene['camera']['eye'] = tch_var_f(self.cam_pos[idx])
            else:
                large_scene['camera']['eye'] = tch_var_f(self.cam_pos[0])

            large_scene['lights']['pos'][0, :3] = tch_var_f(
                self.light_pos1[idx])
            #large_scene['lights']['pos'][1,:3]=tch_var_f(self.light_pos2[idx])

            # Render scene
            res = render(large_scene,
                         norm_depth_image_only=self.opt.norm_depth_image_only,
                         double_sided=True,
                         use_quartic=self.opt.use_quartic)

            # Get rendered output
            if self.opt.render_img_nc == 1:
                depth = res['depth']
                im_d = depth.unsqueeze(0)
            else:
                depth = res['depth']
                im_d = depth.unsqueeze(0)
                im = res['image'].permute(2, 0, 1)
                im_ = get_data(res['image'])
                #im_img_ = get_normalmap_image(im_)
                target_normal_ = get_data(res['normal'])
                target_normalmap_img_ = get_normalmap_image(target_normal_)
                im_n = tch_var_f(target_normalmap_img_).view(
                    im.shape[1], im.shape[2], 3).permute(2, 0, 1)

            # Add depth image to the output structure
            file_name = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_{:05d}.txt'.format(idx)
            with open(file_name, "w") as text_file:
                text_file.write('%s\n' % str(large_scene['camera']['eye'].data))
            out_file_name = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_{:05d}.npy'.format(idx)
            np.save(out_file_name, self.cam_pos[idx])
            out_file_name2 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_light{:05d}.npy'.format(idx)
            np.save(out_file_name2, self.light_pos1[idx])
            out_file_name3 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_im{:05d}.npy'.format(idx)
            np.save(out_file_name3, get_data(res['image']))
            out_file_name4 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_depth{:05d}.npy'.format(idx)
            np.save(out_file_name4, get_data(res['depth']))
            out_file_name5 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_normal{:05d}.npy'.format(idx)
            np.save(out_file_name5, get_data(res['normal']))

            if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
                imsave((inpath + str(self.iterationa_no) +
                        'real_normalmap_{:05d}.png'.format(idx)),
                       target_normalmap_img_)
                imsave((inpath + str(self.iterationa_no) +
                        'real_depth_{:05d}.png'.format(idx)), get_data(depth))
                # imsave(inpath + str(self.iterationa_no) + 'real_depthmap_{:05d}.png'.format(idx), im_d)
                # imsave(inpath + str(self.iterationa_no) + 'world_normalmap_{:05d}.png'.format(idx), target_worldnormalmap_img_)
            data.append(im)
            data_depth.append(im_d)
            data_normal.append(im_n)
            data_cond.append(large_scene['camera']['eye'])
        # Stack real samples
        real_samples = torch.stack(data)
        real_samples_depth = torch.stack(data_depth)
        real_samples_normal = torch.stack(data_normal)
        real_samples_cond = torch.stack(data_cond)
        self.batch_size = real_samples.size(0)
        if not self.opt.no_cuda:
            real_samples = real_samples.cuda()
            real_samples_depth = real_samples_depth.cuda()
            real_samples_normal = real_samples_normal.cuda()
            real_samples_cond = real_samples_cond.cuda()

        # Set input/output variables

        self.input.resize_as_(real_samples.data).copy_(real_samples.data)
        self.input_depth.resize_as_(real_samples_depth.data).copy_(
            real_samples_depth.data)
        self.input_normal.resize_as_(real_samples_normal.data).copy_(
            real_samples_normal.data)
        self.input_cond.resize_as_(real_samples_cond.data).copy_(
            real_samples_cond.data)
        self.label.resize_(self.batch_size).fill_(self.real_label)
        # TODO: Remove Variables
        self.inputv = Variable(self.input)
        self.inputv_depth = Variable(self.input_depth)
        self.inputv_normal = Variable(self.input_normal)
        self.inputv_cond = Variable(self.input_cond)
        self.labelv = Variable(self.label)
Example No. 20
def test_sphere_splat_NDC(out_dir, cam_pos, width, height, fovy, focal_length, b_display=False):
    """
    Create a sphere on a square as in render_sphere_world, and then convert to the camera's coordinate system and to
    NDC and then render using render_splat_NDC.
    """
    import copy
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel(), np.ones(num_samples)), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(), np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    # Convert to the camera's coordinate system
    Mcam = lookat(eye=large_scene['camera']['eye'], at=large_scene['camera']['at'], up=large_scene['camera']['up'])
    Mproj = perspective(fovy=large_scene['camera']['fovy'], aspect=width/height, near=large_scene['camera']['near'],
                        far=large_scene['camera']['far'])

    pos_CC = torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))
    pos_NDC = torch.matmul(pos_CC, Mproj.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = pos_NDC / pos_NDC[..., 3][:, np.newaxis]
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    # main render run
    start_time = time()
    res = render_splats_NDC(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
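
The key steps above form the standard homogeneous pipeline: world to camera via the lookat matrix, camera to clip via the perspective matrix, then the divide by w to reach NDC. A self-contained NumPy sketch of that chain using the same row-vector convention; the identity matrices are stand-ins for lookat() and perspective():

import numpy as np

def to_ndc(pos_w, M_cam, M_proj):
    """Transform homogeneous world-space points (N, 4) to NDC (sketch)."""
    pos_cc = pos_w @ M_cam.T            # world -> camera coordinates
    pos_clip = pos_cc @ M_proj.T        # camera -> clip coordinates
    return pos_clip / pos_clip[:, 3:4]  # perspective divide -> NDC

# Stand-in matrices; in the code above they come from lookat() and perspective()
M_cam, M_proj = np.eye(4), np.eye(4)
pts = np.array([[0.5, 0.5, -2.0, 1.0], [-0.5, -0.5, -2.0, 1.0]])
print(to_ndc(pts, M_cam, M_proj))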
Example no. 21
def test_sphere_splat_render_along_ray(out_dir, cam_pos, width, height, fovy, focal_length, use_quartic,
                                       b_display=False):
    """
    Create a sphere on a square as in render_sphere_world, and then convert to the camera's coordinate system
    and then render using render_splats_along_ray.
    """
    import copy
    print('render sphere along ray')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel() - 5, np.ones(num_samples)), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(), np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.subplot(131)
        plt.imshow(pos[..., 0].reshape((height, width)))
        plt.subplot(132)
        plt.imshow(pos[..., 1].reshape((height, width)))
        plt.subplot(133)
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    ## Convert to the camera's coordinate system
    #Mcam = lookat(eye=large_scene['camera']['eye'], at=large_scene['camera']['at'], up=large_scene['camera']['up'])

    pos_CC = tch_var_f(pos)  # torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = pos_CC
    large_scene['objects']['disk']['normal'] = None  # None => let the renderer estimate the normals (instead of tch_var_f(normals))
    # large_scene['camera']['eye'] = tch_var_f([-10., 0., 10.])
    # large_scene['camera']['eye'] = tch_var_f([2., 0., 10.])
    large_scene['camera']['eye'] = tch_var_f([-5., 0., 0.])  # overrides the cam_pos argument for this test

    # main render run
    start_time = time()
    res = render_splats_along_ray(large_scene, use_quartic=use_quartic)
    rendering_time.append(time() - start_time)

    # Test cam_to_world
    res_world = cam_to_world(res['pos'].reshape(-1, 3), res['normal'].reshape(-1, 3), large_scene['camera'])

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = large_scene['camera']['far']

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(depth, interpolation='none')
        plt.title('Depth Image')
        #plt.savefig(out_dir + '/fig_depth_orig.png')

        plt.figure()
        pos_world = get_data(res_world['pos'])
        posx_world = pos_world[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_world = pos_world[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_world = pos_world[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_world)
        plt.title('x_world')
        plt.subplot(132)
        plt.imshow(posy_world)
        plt.title('y_world')
        plt.subplot(133)
        plt.imshow(posz_world)
        plt.title('z_world')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pos_world[:, 0], pos_world[:, 1], pos_world[:, 2], s=1.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')

        plt.figure()
        pos_world = get_data(res['pos'].reshape(-1, 3))
        posx_world = pos_world[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_world = pos_world[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_world = pos_world[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_world)
        plt.title('x_CC')
        plt.subplot(132)
        plt.imshow(posy_world)
        plt.title('y_CC')
        plt.subplot(133)
        plt.imshow(posz_world)
        plt.title('z_CC')

    imsave(out_dir + '/img_orig.png', im)
    #imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
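
The cam_to_world call above undoes the view transform. Under the row-vector convention used in these examples (p_cc = p_w @ M.T, with M the world-to-camera matrix from lookat), points map back through the inverse and normals through the inverse-transpose; a minimal hedged sketch, not the project's implementation:

import torch

def cam_to_world_sketch(pos_cc, normal_cc, M_view):
    """Map homogeneous (N, 4) camera-space points/normals to world space.

    Sketch only; assumes M_view is the 4x4 world->camera matrix.
    """
    M_inv = torch.inverse(M_view)
    pos_w = pos_cc @ M_inv.t()      # points: inverse transform
    normal_w = normal_cc @ M_view   # normals: inverse-transpose (their w is 0)
    return pos_w, normal_w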
Example no. 22
def render_random_camera(filename,
                         out_dir,
                         num_samples,
                         radius,
                         cam_dist,
                         num_views,
                         width,
                         height,
                         fovy,
                         focal_length,
                         norm_depth_image_only,
                         theta_range=None,
                         phi_range=None,
                         axis=None,
                         angle=None,
                         cam_pos=None,
                         cam_lookat=None,
                         use_mesh=False,
                         double_sided=False,
                         use_quartic=False,
                         b_shadow=True,
                         b_display=False,
                         tile_size=None):
    """
    Randomly generate N samples on a surface and render them. The samples include position and normal, the radius is set
    to a constant.
    """
    sampling_time = []
    rendering_time = []

    obj = load_model(filename)
    # normalize the vertices
    v = obj['v']
    axis_range = np.max(v, axis=0) - np.min(v, axis=0)
    v = (v - np.mean(v, axis=0)) / max(
        axis_range)  # Normalize to make the largest spread 1
    obj['v'] = v

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    r = np.ones(num_samples) * radius

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length
    if use_mesh:
        mesh = obj_to_triangle_spec(obj)
        faces = mesh['face']
        normals = mesh['normal']
        num_tri = faces.shape[0]
        # if faces.shape[-1] == 3:
        #     faces = np.concatenate((faces, np.ones((faces.shape[0], faces.shape[1], 1))), axis=-1).tolist()
        # if normals.shape[-1] == 3:
        #     normals = np.concatenate((normals, ))
        if 'disk' in scene['objects']:
            del scene['objects']['disk']
        scene['objects'].update(
            {'triangle': {
                'face': None,
                'normal': None,
                'material_idx': None
            }})
        scene['objects']['triangle']['face'] = tch_var_f(faces.tolist())
        scene['objects']['triangle']['normal'] = tch_var_f(normals.tolist())
        scene['objects']['triangle']['material_idx'] = tch_var_l(
            np.zeros(num_tri, dtype=int).tolist())
    else:
        scene['objects']['disk']['radius'] = tch_var_f(r)
        scene['objects']['disk']['material_idx'] = tch_var_l(
            np.zeros(num_samples, dtype=int).tolist())
    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist,
                                        num_samples=num_views,
                                        axis=axis,
                                        angle=angle,
                                        theta_range=theta_range,
                                        phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else np.mean(v, axis=0)
    scene['camera']['at'] = tch_var_f(lookat)

    if b_display:
        h1 = plt.figure()
        h2 = plt.figure()
    for idx in range(cam_pos.shape[0]):
        if not use_mesh:
            start_time = time()
            v, vn = uniform_sample_mesh(obj, num_samples=num_samples)
            sampling_time.append(time() - start_time)

            scene['objects']['disk']['pos'] = tch_var_f(v)
            scene['objects']['disk']['normal'] = tch_var_f(vn)

        scene['camera']['eye'] = tch_var_f(cam_pos[idx])
        suffix = '_{}'.format(idx)

        # main render run
        start_time = time()
        res = render(scene,
                     tile_size=tile_size,
                     tiled=tile_size is not None,
                     shadow=b_shadow,
                     norm_depth_image_only=norm_depth_image_only,
                     double_sided=double_sided,
                     use_quartic=use_quartic)
        rendering_time.append(time() - start_time)

        im = np.uint8(255. * get_data(res['image']))
        depth = get_data(res['depth'])

        depth[depth >= scene['camera']['far']] = depth.min()
        im_depth = np.uint8(255. * (depth - depth.min()) /
                            (depth.max() - depth.min()))

        if b_display:
            plt.figure(h1.number)
            plt.imshow(im)
            plt.title('Image')
            plt.savefig(out_dir + '/fig_img' + suffix + '.png')

            plt.figure(h2.number)
            plt.imshow(im_depth)
            plt.title('Depth Image')
            plt.savefig(out_dir + '/fig_depth' + suffix + '.png')

        imsave(out_dir + '/img' + suffix + '.png', im)
        imsave(out_dir + '/depth' + suffix + '.png', im_depth)

    # Timing statistics
    if not use_mesh:
        print('Sampling time mean: {}s, std: {}s'.format(
            np.mean(sampling_time), np.std(sampling_time)))
    print('Rendering time mean: {}s, std: {}s'.format(np.mean(rendering_time),
                                                      np.std(rendering_time)))
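
uniform_sample_sphere above draws camera positions uniformly on a sphere of radius cam_dist, optionally restricted to a cone or theta/phi range. Ignoring those restrictions, the standard construction normalizes Gaussian vectors; a minimal sketch, not the project's implementation:

import numpy as np

def sample_sphere(radius, num_samples, rng=None):
    """Uniform points on a sphere via normalized Gaussian vectors (sketch)."""
    rng = np.random.default_rng() if rng is None else rng
    v = rng.standard_normal((num_samples, 3))
    v /= np.linalg.norm(v, axis=1, keepdims=True)  # unit directions, uniform on S^2
    return radius * v

cam_pos = sample_sphere(radius=2.0, num_samples=8)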
Example no. 23
def render_sphere(out_dir,
                  cam_pos,
                  radius,
                  width,
                  height,
                  fovy,
                  focal_length,
                  num_views,
                  std_z=0.01,
                  std_normal=0.01,
                  b_display=False):
    """
    Randomly generate N samples on a surface and render them. The samples include position and normal, the radius is set
    to a constant.
    """
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height
    r = np.ones(num_samples) * radius

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length
    scene['objects']['disk']['radius'] = tch_var_f(r)
    scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(num_samples, dtype=int).tolist())
    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x**2 + y**2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x**2 + y**2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    scene['objects']['disk']['pos'] = tch_var_f(pos)
    scene['objects']['disk']['normal'] = tch_var_f(normals)

    scene['camera']['eye'] = tch_var_f(cam_pos)

    # main render run
    start_time = time()
    res = render(scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    depth = get_data(res['depth'])

    depth[depth >= scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) /
                        (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im)
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth)
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # generate noisy data
    if b_display:
        h1 = plt.figure()
        h2 = plt.figure()
    noisy_pos = pos.copy()  # copy so each view adds fresh noise to the clean positions
    for view_idx in range(num_views):
        noisy_pos[..., 2] = pos[..., 2] + std_z * np.random.randn(num_samples)
        noisy_normals = np_normalize(normals + std_normal *
                                     np.random.randn(num_samples, 3))

        scene['objects']['disk']['pos'] = tch_var_f(noisy_pos)
        scene['objects']['disk']['normal'] = tch_var_f(noisy_normals)

        scene['camera']['eye'] = tch_var_f(cam_pos)

        # main render run
        start_time = time()
        res = render(scene)
        rendering_time.append(time() - start_time)

        im = get_data(res['image'])
        depth = get_data(res['depth'])

        depth[depth >= scene['camera']['far']] = depth.min()
        im_depth = np.uint8(255. * (depth - depth.min()) /
                            (depth.max() - depth.min()))

        suffix_str = '{:05d}'.format(view_idx)

        if b_display:
            plt.figure(h1.number)
            plt.imshow(im)
            plt.title('Image')
            plt.savefig(out_dir + '/fig_img_' + suffix_str + '.png')

            plt.figure(h2.number)
            plt.imshow(im_depth)
            plt.title('Depth Image')
            plt.savefig(out_dir + '/fig_depth_' + suffix_str + '.png')

        imsave(out_dir + '/img_' + suffix_str + '.png', im)
        imsave(out_dir + '/depth_' + suffix_str + '.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
Example no. 24

#scene = np.load('scene_output_twogans.npy')
scene = np.load('scene_input_twogans_unnorm.npy')
#scene = np.load('scene_output.npy')
#pos = get_data(scene[0]['objects']['disk']['pos'])
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(scene[:, 0], scene[:, 1], scene[:, 2], s=1.3)

for idx in range(0, len(scene), 20):
    print(idx)
    scene[idx]['lights']['attenuation'] = tch_var_f([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
    res = render_splats_along_ray(scene[idx], use_old_sign=False)

    im = get_data(res['image'])
    depth = get_data(res['depth'])

    plt.figure()
    plt.imshow(im)

    plt.figure()
    plt.imshow(depth)

    pos = get_data(res['pos'])
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], s=1.3)

    res_world = cam_to_world(pos=res['pos'], normal=res['normal'], camera=scene[idx]['camera'])
    pos = get_data(res_world['pos'])
Example no. 25
    def train(self, ):
        """Train network."""
        # Load pretrained model if required
        if self.opt.gen_model_path is not None:
            print("Reloading networks from")
            print(' > Generator', self.opt.gen_model_path)
            self.netG.load_state_dict(
                torch.load(open(self.opt.gen_model_path, 'rb')))
            print(' > Generator2', self.opt.gen_model_path2)
            self.netG2.load_state_dict(
                torch.load(open(self.opt.gen_model_path2, 'rb')))
            print(' > Discriminator', self.opt.dis_model_path)
            self.netD.load_state_dict(
                torch.load(open(self.opt.dis_model_path, 'rb')))
            print(' > Discriminator2', self.opt.dis_model_path2)
            self.netD2.load_state_dict(
                torch.load(open(self.opt.dis_model_path2, 'rb')))

        # Start training
        train_stream = Iterator(batch_size=self.opt.batchSize)
        file_name = os.path.join(self.opt.out_dir, 'L2.txt')
        dsize = len(train_stream)
        curr_generator_idx = 0  # generator to update during alternating optimization (see below)

        for epoch in range(self.opt.n_iter):

            self.critic_iter = 0
            # Train the discriminator critic_iters times
            for cnt, batch in enumerate(train_stream):
                # Train with real
                #################
                self.iterationa_no = epoch * dsize + cnt
                iteration = self.iterationa_no
                x, cp, lp = batch

                real_data = tch_var_f(x)
                cam_pos = tch_var_f(cp)
                light_pos = tch_var_f(lp)

                self.in_critic = 1
                self.netD.zero_grad()
                real_data = real_data.permute(0, 3, 1, 2)
                # input_D = torch.cat([self.inputv, self.inputv_depth], 1)
                real_output = self.netD(real_data, cam_pos)

                if self.opt.criterion == 'GAN':
                    errD_real = self.criterion(real_output, self.labelv)
                    errD_real.backward()
                elif self.opt.criterion == 'WGAN':
                    errD_real = real_output.mean()
                    errD_real.backward(self.mone)
                else:
                    raise ValueError('Unknown GAN criterion')

                # Train with fake
                #################
                self.generate_noise_vector()
                fake_z = self.netG(self.noisev, cam_pos)
                # The normal generator is dependent on z
                fake_n = self.generate_normals(fake_z, cam_pos,
                                               self.scene['camera'])
                fake = torch.cat([fake_z, fake_n], 2)
                fake_rendered, fd, loss = self.render_batch(
                    fake, cam_pos,lp)
                # Do not bp through gen
                outD_fake = self.netD(fake_rendered.detach(),
                                      cam_pos.detach())
                if self.opt.criterion == 'GAN':
                    labelv = Variable(self.label.fill_(self.fake_label))
                    errD_fake = self.criterion(outD_fake, labelv)
                    errD_fake.backward()
                    errD = errD_real + errD_fake
                elif self.opt.criterion == 'WGAN':
                    errD_fake = outD_fake.mean()
                    errD_fake.backward(self.one)
                    errD = errD_fake - errD_real
                else:
                    raise ValueError('Unknown GAN criterion')

                # Compute gradient penalty
                if self.opt.gp != 'None':
                    gradient_penalty = calc_gradient_penalty(
                        self.netD, real_data.data, fake_rendered.data,
                        cam_pos.data, self.opt.gp_lambda)
                    gradient_penalty.backward()
                    errD += gradient_penalty

                gnorm_D = torch.nn.utils.clip_grad_norm(
                    self.netD.parameters(), self.opt.max_gnorm)  # TODO

                # Update weight
                self.optimizerD.step()
                # Clamp critic weights if using WGAN without gradient penalty
                if self.opt.criterion == 'WGAN' and self.opt.gp == 'None':
                    for p in self.netD.parameters():
                        p.data.clamp_(-self.opt.clamp, self.opt.clamp)
                self.critic_iter += 1

                ############################
                # (2) Update G network
                ############################
                # To avoid computation:
                # for p in self.netD.parameters():
                #     p.requires_grad = False
                if cnt % self.opt.critic_iters == 0 and cnt > 0:
                    self.netG.zero_grad()
                    self.in_critic=0
                    self.generate_noise_vector()
                    fake_z = self.netG(self.noisev, cam_pos)
                    if iteration % (self.opt.print_interval * 4) == 0:
                        fake_z.register_hook(self.tensorboard_hook)
                    fake_n = self.generate_normals(fake_z, cam_pos,
                                                   self.scene['camera'])
                    fake = torch.cat([fake_z, fake_n], 2)
                    fake_rendered, fd, loss = self.render_batch(
                        fake, cam_pos, lp)
                    outG_fake = self.netD(fake_rendered, cam_pos)

                    if self.opt.criterion == 'GAN':
                        # Fake labels are real for generator cost
                        labelv = Variable(self.label.fill_(self.real_label))
                        errG = self.criterion(outG_fake, labelv)
                        errG.backward()
                    elif self.opt.criterion == 'WGAN':
                        errG = outG_fake.mean() + loss
                        errG.backward(self.mone)
                    else:
                        raise ValueError('Unknown GAN criterion')
                    gnorm_G = torch.nn.utils.clip_grad_norm(
                        self.netG.parameters(), self.opt.max_gnorm)  # TODO
                    if (self.opt.alt_opt_zn_interval is not None and
                        iteration >= self.opt.alt_opt_zn_start):
                        # update one of the generators
                        if (((iteration - self.opt.alt_opt_zn_start) %
                             self.opt.alt_opt_zn_interval) == 0):
                            # switch generator vars to optimize
                            curr_generator_idx = (1 - curr_generator_idx)
                        if iteration < self.opt.lr_iter:
                            self.LR_SCHED_MAP[curr_generator_idx].step()
                            self.OPT_MAP[curr_generator_idx].step()
                    else:
                        if iteration < self.opt.lr_iter:
                            self.optG_z_lr_scheduler.step()

                        self.optimizerG.step()

                mse_criterion = nn.MSELoss().cuda()

                # Log print
                if iteration % (self.opt.print_interval * 5) == 0 and cnt > 0:
                    Wasserstein_D = (errD_real.data[0] - errD_fake.data[0])
                    self.writer.add_scalar("Loss_G",
                                           errG.data[0],
                                           self.iterationa_no)
                    self.writer.add_scalar("Loss_D",
                                           errD.data[0],
                                           self.iterationa_no)
                    self.writer.add_scalar("Wassertein_D",
                                           Wassertein_D,
                                           self.iterationa_no)
                    self.writer.add_scalar("Disc_grad_norm",
                                           gnorm_D,
                                           self.iterationa_no)
                    self.writer.add_scalar("Gen_grad_norm",
                                           gnorm_G,
                                           self.iterationa_no)

                    print('\n[%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_D_real: %.4f'
                          ' Loss_D_fake: %.4f Wasserstein_D: %.4f'
                          ' L2_loss: %.4f z_lr: %.8f, n_lr: %.8f, Disc_grad_norm: %.8f, Gen_grad_norm: %.8f' % (
                              iteration, self.opt.n_iter, errD.data[0],
                              errG.data[0], errD_real.data[0], errD_fake.data[0],
                              Wasserstein_D, loss.data[0],
                              self.optG_z_lr_scheduler.get_lr()[0],
                              self.optG2_normal_lr_scheduler.get_lr()[0],
                              gnorm_D, gnorm_G))

                # Save output images
                if iteration % (self.opt.save_image_interval * 5) == 0 and cnt > 0:
                    cs = tch_var_f(contrast_stretch_percentile(
                        get_data(fd), 200, [fd.data.min(), fd.data.max()]))
                    torchvision.utils.save_image(
                        fake_rendered.data,
                        os.path.join(self.opt.vis_images,
                                     'output_%d.png' % (iteration)),
                        nrow=2, normalize=True, scale_each=True)

                # Save input images
                if iteration % (self.opt.save_image_interval * 5) == 0:
                    cs = tch_var_f(contrast_stretch_percentile(
                        get_data(fd), 200, [fd.data.min(), fd.data.max()]))
                    torchvision.utils.save_image(
                        real_data.data, os.path.join(
                            self.opt.vis_images, 'input_%d.png' % (iteration)),
                        nrow=2, normalize=True, scale_each=True)

                # Do checkpointing
                if iteration % (self.opt.save_interval * 2) == 0:
                    self.save_networks(iteration)
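
The calc_gradient_penalty call above is the WGAN-GP term (Gulrajani et al., 2017): the critic's gradient norm is pushed toward 1 on random interpolates between real and fake samples. A minimal unconditional sketch; the project's version also threads cam_pos through the discriminator:

import torch

def gradient_penalty(netD, real, fake, gp_lambda=10.0):
    """WGAN-GP sketch: penalize ||grad D(x_hat)|| deviating from 1."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = netD(x_hat)
    grads = torch.autograd.grad(outputs=out, inputs=x_hat,
                                grad_outputs=torch.ones_like(out),
                                create_graph=True, retain_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return gp_lambda * ((grad_norm - 1.0) ** 2).mean()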
Example no. 26
def render_sphere_halfbox(out_dir,
                          cam_pos,
                          width,
                          height,
                          fovy,
                          focal_length,
                          num_views,
                          cam_dist,
                          norm_depth_image_only,
                          theta_range=None,
                          phi_range=None,
                          axis=None,
                          angle=None,
                          cam_lookat=None,
                          tile_size=None,
                          use_quartic=False,
                          b_shadow=True,
                          b_display=False):
    # python splat_render_demo.py --sphere-halfbox --fovy 30 --out_dir ./sphere_halfbox_demo --cam_dist 4 --axis .8 .5 1
    # --angle 5 --at 0 .4 0 --nv 10 --width=256 --height=256
    scene = SCENE_SPHERE_HALFBOX  # note: modifies the shared scene dict in place (no deepcopy)
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length
    scene['lights']['pos'] = tch_var_f([[2., 2., 1.5, 1.0],
                                        [1., 4., 1.5, 1.0]])  # tch_var_f([[4., 4., 3., 1.0]])
    scene['lights']['color_idx'] = tch_var_l([1, 3])
    scene['lights']['attenuation'] = tch_var_f([[1., 0., 0.], [1., 0., 0.]])

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist,
                                        num_samples=num_views,
                                        axis=axis,
                                        angle=angle,
                                        theta_range=theta_range,
                                        phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else [0.0, 0.0, 0.0, 1.0]
    scene['camera']['at'] = tch_var_f(lookat)

    b_tiled = tile_size is not None
    res = render(scene, tile_size=tile_size, tiled=b_tiled, shadow=b_shadow)
    im = np.uint8(255. * get_data(res['image']))
    depth = get_data(res['depth'])

    depth[depth >= scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) /
                        (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im)
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth)
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    if b_display:
        h1 = plt.figure()
        h2 = plt.figure()
    for idx in range(cam_pos.shape[0]):
        scene['camera']['eye'] = tch_var_f(cam_pos[idx])
        suffix = '_{}'.format(idx)

        # main render run
        res = render(scene,
                     tiled=b_tiled,
                     shadow=b_shadow,
                     norm_depth_image_only=norm_depth_image_only,
                     use_quartic=use_quartic)

        im = np.uint8(255. * get_data(res['image']))
        depth = get_data(res['depth'])

        depth[depth >= scene['camera']['far']] = depth.min()
        im_depth = np.uint8(255. * (depth - depth.min()) /
                            (depth.max() - depth.min()))

        if b_display:
            plt.figure(h1.number)
            plt.imshow(im)
            plt.title('Image')
            plt.savefig(out_dir + '/fig_img' + suffix + '.png')

            plt.figure(h2.number)
            plt.imshow(im_depth)
            plt.title('Depth Image')
            plt.savefig(out_dir + '/fig_depth' + suffix + '.png')

        imsave(out_dir + '/img' + suffix + '.png', im)
        imsave(out_dir + '/depth' + suffix + '.png', im_depth)
Example no. 27
    def render_batch(self, batch, batch_cond, light_pos):
        """Render a batch of splats."""
        batch_size = batch.size()[0]

        # Generate camera positions on a sphere
        if batch_cond is None:
            if self.opt.full_sphere_sampling:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta, phi_range=self.opt.phi)
                # TODO: deg2rad!!
            else:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
                # TODO: deg2rad!!

        rendered_data = []
        rendered_data_depth = []
        rendered_data_cond = []
        scenes = []
        inpath = self.opt.vis_images + '/'
        inpath_xyz = self.opt.vis_xyz + '/'
        z_min = self.scene['camera']['focal_length']
        z_max = z_min + 3

        # TODO (fmannan): Move this in init. This only needs to be done once!
        # Set splats into rendering scene
        if 'sphere' in self.scene['objects']:
            del self.scene['objects']['sphere']
        if 'triangle' in self.scene['objects']:
            del self.scene['objects']['triangle']
        if 'disk' not in self.scene['objects']:
            self.scene['objects'] = {'disk': {'pos': None, 'normal': None,
                                              'material_idx': None}}
        lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
        self.scene['camera']['at'] = tch_var_f(lookat)
        self.scene['objects']['disk']['material_idx'] = tch_var_l(
            np.zeros(self.opt.splats_img_size * self.opt.splats_img_size))
        loss = 0.0
        loss_ = 0.0
        z_loss_ = 0.0
        z_norm_loss_ = 0.0
        spatial_loss_ = 0.0
        spatial_var_loss_ = 0.0
        unit_normal_loss_ = 0.0
        normal_away_from_cam_loss_ = 0.0
        image_depth_consistency_loss_ = 0.0
        for idx in range(batch_size):
            # Get splats positions and normals
            eps = 1e-3
            if self.opt.rescaled:
                z = F.relu(-batch[idx][:, 0]) + z_min
                z = ((z - z.min()) / (z.max() - z.min() + eps) *
                     (z_max - z_min) + z_min)
                pos = -z
            else:
                z = F.relu(-batch[idx][:, 0]) + z_min
                pos = -F.relu(-batch[idx][:, 0]) - z_min
            normals = batch[idx][:, 1:]

            self.scene['objects']['disk']['pos'] = pos

            # Normal estimation network and est_normals don't go together
            self.scene['objects']['disk']['normal'] = normals if self.opt.est_normals is False else None

            # Set camera position
            if batch_cond is None:
                if not self.opt.same_view:
                    self.scene['camera']['eye'] = tch_var_f(cam_pos[idx])
                else:
                    self.scene['camera']['eye'] = tch_var_f(cam_pos[0])
            else:
                if not self.opt.same_view:
                    self.scene['camera']['eye'] = batch_cond[idx]
                else:
                    self.scene['camera']['eye'] = batch_cond[0]

            self.scene['lights']['pos'][0, :3] = tch_var_f(light_pos[idx])
            # self.scene['lights']['pos'][1, :3] = tch_var_f(self.light_pos2[idx])

            # Render scene
            # res = render_splats_NDC(self.scene)
            res = render_splats_along_ray(self.scene,
                                          samples=self.opt.pixel_samples,
                                          normal_estimation_method='plane')

            world_tform = cam_to_world(res['pos'].view((-1, 3)),
                                       res['normal'].view((-1, 3)),
                                       self.scene['camera'])

            # Get rendered output
            res_pos = res['pos'].contiguous()
            res_pos_2D = res_pos.view(res['image'].shape)
            # The z_loss needs to be applied after supersampling
            # TODO: Enable this (currently final loss becomes NaN!!)
            # loss += torch.mean(
            #    (10 * F.relu(z_min - torch.abs(res_pos[..., 2]))) ** 2 +
            #    (10 * F.relu(torch.abs(res_pos[..., 2]) - z_max)) ** 2)


            res_normal = res['normal']
            # depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
            # grad_img = grad_spatial2d(torch.mean(res['image'], dim=-1)[..., np.newaxis])
            # grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
            image_depth_consistency_loss = depth_rgb_gradient_consistency(
                res['image'], res['depth'])
            unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)  # TODO: MN
            normal_away_from_cam_loss = away_from_camera_penalty(
                res_pos, res_normal)
            z_pos = res_pos[..., 2]
            z_loss = torch.mean((2 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                                (2 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
            z_norm_loss = normal_consistency_cost(
                res_pos, res['normal'], norm=1)
            spatial_loss = spatial_3x3(res_pos_2D)
            spatial_var = torch.mean(res_pos[..., 0].var() +
                                     res_pos[..., 1].var() +
                                     res_pos[..., 2].var())
            spatial_var_loss = (1 / (spatial_var + 1e-4))

            loss = (self.opt.zloss * z_loss +
                    self.opt.unit_normalloss*unit_normal_loss +
                    self.opt.normal_consistency_loss_weight * z_norm_loss +
                    self.opt.spatial_var_loss_weight * spatial_var_loss +
                    self.opt.grad_img_depth_loss*image_depth_consistency_loss +
                    self.opt.spatial_loss_weight * spatial_loss)
            pos_out_ = get_data(res['pos'])
            loss_ += get_data(loss)
            z_loss_ += get_data(z_loss)
            z_norm_loss_ += get_data(z_norm_loss)
            spatial_loss_ += get_data(spatial_loss)
            spatial_var_loss_ += get_data(spatial_var_loss)
            unit_normal_loss_ += get_data(unit_normal_loss)
            normal_away_from_cam_loss_ += get_data(normal_away_from_cam_loss)
            image_depth_consistency_loss_ += get_data(
                image_depth_consistency_loss)
            normals_ = get_data(res_normal)

            if self.opt.render_img_nc == 1:
                depth = res['depth']
                im = depth.unsqueeze(0)
            else:
                depth = res['depth']
                im_d = depth.unsqueeze(0)
                im = res['image'].permute(2, 0, 1)
                H, W = im.shape[1:]
                target_normal_ = get_data(res['normal']).reshape((H, W, 3))
                target_normalmap_img_ = get_normalmap_image(target_normal_)
                target_worldnormal_ = get_data(world_tform['normal']).reshape(
                    (H, W, 3))
                target_worldnormalmap_img_ = get_normalmap_image(
                    target_worldnormal_)
            if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
                imsave((inpath + str(self.iterationa_no) +
                        'normalmap_{:05d}.png'.format(idx)),
                       target_normalmap_img_)
                imsave((inpath + str(self.iterationa_no) +
                        'depthmap_{:05d}.png'.format(idx)),
                       get_data(res['depth']))
                imsave((inpath + str(self.iterationa_no) +
                        'world_normalmap_{:05d}.png'.format(idx)),
                       target_worldnormalmap_img_)
            if self.iterationa_no % 1000 == 0:
                im2 = get_data(res['image'])
                depth2 = get_data(res['depth'])
                pos = get_data(res['pos'])

                out_file2 = ("pos"+".npy")
                np.save(inpath_xyz+out_file2, pos)

                out_file2 = ("im"+".npy")
                np.save(inpath_xyz+out_file2, im2)

                out_file2 = ("depth"+".npy")
                np.save(inpath_xyz+out_file2, depth2)

                # Save xyz file
                save_xyz((inpath_xyz + str(self.iterationa_no) +
                          'withnormal_{:05d}.xyz'.format(idx)),
                         pos=get_data(res['pos']),
                         normal=get_data(res['normal']))

                # Save xyz file in world coordinates
                save_xyz((inpath_xyz + str(self.iterationa_no) +
                          'withnormal_world_{:05d}.xyz'.format(idx)),
                         pos=get_data(world_tform['pos']),
                         normal=get_data(world_tform['normal']))
            if self.opt.gz_gi_loss is not None and self.opt.gz_gi_loss > 0:
                gradZ = grad_spatial2d(res_pos_2D[:, :, 2][:, :, np.newaxis])
                gradImg = grad_spatial2d(torch.mean(im,
                                                    dim=0)[:, :, np.newaxis])
                for (gZ, gI) in zip(gradZ, gradImg):
                    loss += (self.opt.gz_gi_loss * torch.mean(torch.abs(
                                torch.abs(gZ) - torch.abs(gI))))
            # Store normalized depth into the data
            rendered_data.append(im)
            rendered_data_depth.append(im_d)
            rendered_data_cond.append(self.scene['camera']['eye'])
            scenes.append(self.scene)

        rendered_data = torch.stack(rendered_data)
        rendered_data_depth = torch.stack(rendered_data_depth)

        return rendered_data, rendered_data_depth, loss / self.opt.batchSize
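
Several of the regularizers combined above have simple closed forms. Two sketches matching how they are used in render_batch; the real unit_norm2_L2loss may differ in detail:

import torch

def unit_norm_l2_loss(normals, weight):
    """Penalize each normal's length deviating from 1 (assumed form)."""
    norm = normals.norm(2, dim=-1)
    return weight * torch.mean((norm - 1.0) ** 2)

def spatial_variance_loss(pos, eps=1e-4):
    """Inverse of the summed per-axis variance: discourages collapsed splats,
    mirroring the spatial_var_loss computation above."""
    spatial_var = pos[..., 0].var() + pos[..., 1].var() + pos[..., 2].var()
    return 1.0 / (spatial_var + eps)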
Example no. 28
def optimize_NDC_test(out_dir,
                      width=32,
                      height=32,
                      max_iter=100,
                      lr=1e-3,
                      scale=10,
                      print_interval=10,
                      imsave_interval=10):
    """A demo function to check if the differentiable renderer can optimize splats in NDC.
    :param scene:
    :param out_dir:
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f([2, 1, 2, 1])
    scene['camera']['at'] = tch_var_f([0, 0.8, 0, 1])

    target_res = render(SCENE_SPHERE_HALFBOX)
    target_im = target_res['image']
    target_im = target_im.detach()  # make sure no gradients flow into the target
    target_im_ = get_data(target_res['image'])

    criterion = nn.L1Loss()  #nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_)
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']

    num_splats = width * height
    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    z = tch_var_f(2 * np.random.rand(num_splats) - 1)
    z.requires_grad = True
    pos = torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), z), dim=1)
    normals = tch_var_f(np.ones((num_splats, 4)) * np.array([0, 0, 1, 0]))
    normals.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    input_scene['objects'] = {
        'disk': {
            'pos': pos,
            'normal': normals,
            'material_idx': material_idx
        }
    }
    optimizer = optim.Adam((z, normals), lr=lr)

    h0 = plt.figure()
    h1 = plt.figure()
    loss_per_iter = []
    for iter in range(max_iter):
        res = render_splats_NDC(input_scene)
        im_out = res['image']

        optimizer.zero_grad()
        loss = criterion(scale * im_out, scale * target_im)

        im_out_ = get_data(im_out)
        loss_ = get_data(loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            print('%d. loss= %f' % (iter, loss_))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            plt.figure(h1.number)
            plt.imshow(im_out_)
            plt.title('%d. loss= %f' % (iter, loss_))
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
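
A hypothetical invocation (all argument values are placeholders):

optimize_NDC_test(out_dir='./ndc_opt', width=64, height=64, max_iter=500,
                  lr=1e-2, scale=10, print_interval=25, imsave_interval=50)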
Example no. 29
def optimize_lfnet(scene,
                   lfnet,
                   max_iter=2000,
                   num_samples=120,
                   lr=1e-3,
                   print_interval=10,
                   imsave_interval=100,
                   out_dir='./tmp_lf_opt'):
    """
    Args:
        scene: scene file
        lfnet: Light Field Network

    Returns:

    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    res = render_scene(scene)

    pos = get_data(res['pos'])
    normal = get_data(res['normal'])

    opt_vars = lfnet.parameters()
    criterion = torch.nn.MSELoss(size_average=True).cuda()
    optimizer = optim.Adam(opt_vars, lr=lr)
    lr_scheduler = StepLR(optimizer, step_size=500, gamma=0.8)

    loss_per_iter = []
    target_im = res['image']
    target_im_grad = grad_spatial2d(target_im.mean(dim=-1)[..., np.newaxis])
    h1 = plt.figure()
    plt.figure(h1.number)
    plt.imshow(get_data(target_im))
    plt.title('Target')
    plt.savefig(out_dir + '/Target.png')

    for iter in range(max_iter):
        im_est = lf_renderer(pos, normal, lfnet, num_samples=num_samples)
        im_est_grad = grad_spatial2d(im_est.mean(dim=-1)[..., np.newaxis])
        optimizer.zero_grad()
        loss = criterion(im_est * 255, target_im * 255) + criterion(
            target_im_grad * 100, im_est_grad * 100)

        loss_ = get_data(loss)
        loss_per_iter.append(loss_)
        if iter % print_interval == 0 or iter == max_iter - 1:
            print('{}. Loss: {}'.format(iter, loss_))
        if iter % imsave_interval == 0 or iter == max_iter - 1:
            im_out_ = get_data(im_est)
            im_out_ = np.uint8(255 * im_out_ / im_out_.max())

            plt.figure(h1.number)
            plt.imshow(im_out_)
            plt.title('%d. loss= %f' % (iter, loss_))
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

        loss.backward()
        lr_scheduler.step()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')
Example no. 30
def optimize_scene(input_scene,
                   target_scene,
                   out_dir,
                   max_iter=100,
                   lr=1e-3,
                   print_interval=10,
                   imsave_interval=10):
    """A demo function to check if the differentiable renderer can optimize.
    :param scene:
    :param out_dir:
    :return:
    """
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    target_res = render(target_scene)
    target_im = target_res['image'].detach()  # make sure no gradients flow into the target
    target_im_ = target_res['image'].cpu()
    criterion = nn.MSELoss()
    if CUDA:
        criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_.data.numpy())
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    input_scene['materials']['albedo'].requires_grad = True
    optimizer = optim.Adam(input_scene['materials'].values(), lr=lr)

    h0 = plt.figure()
    h1 = plt.figure()
    loss_per_iter = []
    for iter in range(max_iter):
        res = render(input_scene)
        im_out = res['image']

        optimizer.zero_grad()
        loss = criterion(im_out, target_im)

        im_out_ = get_data(im_out)
        loss_ = get_data(loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0:
            print('%d. loss= %f' % (iter, loss_))
            print(input_scene['materials'])

            plt.figure(h1.number)
            plt.imshow(im_out_)
            plt.title('%d. loss= %f' % (iter, loss_))
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('MSE Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
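
A hypothetical way to exercise optimize_scene: render a target scene, perturb the albedo in a copy, and let the optimizer recover it (SCENE_BASIC and tch_var_f as in the examples above; the starting albedo is a placeholder):

import copy

target_scene = copy.deepcopy(SCENE_BASIC)
input_scene = copy.deepcopy(SCENE_BASIC)
input_scene['materials']['albedo'] = tch_var_f([[0.1, 0.8, 0.3]])  # wrong start
optimize_scene(input_scene, target_scene, out_dir='./albedo_opt',
               max_iter=200, lr=1e-2)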