Example #1
def z_to_pcl_CC(z, camera):
    viewport = np.array(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = W / H

    fovy = camera['fovy']
    focal_length = camera['focal_length']
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    ##### Find (X, Y) in the Camera's view frustum
    # Force the caller to set the z coordinate with the correct sign
    Z = -torch.nn.functional.relu(-z)

    x, y = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
    x *= w / 2
    y *= h / 2

    x = tch_var_f(x.ravel())
    y = tch_var_f(y.ravel())

    X = -Z * x / focal_length
    Y = -Z * y / focal_length

    return torch.stack((X, Y, Z), dim=1)
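A minimal usage sketch for the helper above, assuming a hypothetical camera dict that carries the same keys the function reads and a depth tensor that is already non-positive in camera coordinates:

import numpy as np

# Hypothetical 4x4-pixel camera: the function only reads 'viewport', 'fovy' and 'focal_length'.
camera = {'viewport': [0, 0, 4, 4], 'fovy': np.deg2rad(45.), 'focal_length': 1.}
z = tch_var_f(-np.ones(4 * 4))     # one (negative) depth value per pixel
pcl_CC = z_to_pcl_CC(z, camera)    # [W*H, 3] points in camera coordinates
print(pcl_CC.shape)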
Example #2
def lf_renderer(pos, normal, lfnet, num_samples=20):
    """This is a simpler version of lf_renderer_v0 where the same direction samples are used
    for all surfels. The samples are on a uniform sphere and so this renderer also supports
    transmissive medium.

    Args:
        pos:
        normal:
        lfnet:
        num_samples:

    Returns:

    """
    pos_all = pos.reshape((-1, 3))
    normal_all = tch_var_f(normal.reshape((-1, 3)))

    spherical_samples = uniform_sample_sphere(radius=1.0,
                                              num_samples=num_samples)

    inp = tch_var_f(
        np.concatenate((np.tile(pos_all[:, np.newaxis, :],
                                (1, num_samples, 1)),
                        np.tile(spherical_samples[np.newaxis, :, :],
                                (pos_all.shape[0], 1, 1))),
                       axis=-1))
    Li = lfnet(inp)
    cos_theta = torch.sum(inp[:, :, 3:6] * normal_all[:, np.newaxis, :],
                          dim=-1)
    nonzero_mask = (cos_theta > 0).float()
    pos_cos_theta = cos_theta * nonzero_mask
    im = torch.sum(pos_cos_theta[..., np.newaxis] * Li,
                   dim=1).reshape(pos.shape)

    return im
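A quick way to exercise lf_renderer is with a stand-in for the light-field network; the sketch below, with a hypothetical identity-like lfnet and random surfels, only illustrates the expected shapes:

import numpy as np

# Hypothetical stand-in for a trained light-field network: it should map the
# 6-D (position, direction) input of each sample to RGB radiance. Echoing the
# position channels keeps the sketch self-contained and device-agnostic.
lfnet = lambda inp: inp[..., :3]

pos = np.random.rand(8, 8, 3)      # surfel positions on an 8x8 grid
normal = np.random.rand(8, 8, 3)   # surfel normals, same layout as pos
im = lf_renderer(pos, normal, lfnet, num_samples=20)
print(im.shape)                    # same shape as pos, i.e. (8, 8, 3)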
Example #3
def projection_renderer_differentiable(surfels,
                                       rgb,
                                       camera,
                                       rotated_image=None,
                                       blur_size=0.15):
    """Project surfels given in world coordinate to the camera's projection plane
       in a way that is differentiable w.r.t depth. This is achieved by interpolating
       the surfel values using a Gaussian filter.

    Args:
        surfels: [batch_size, num_surfels, pos]
        rgb: [batch_size, num_surfels, D-channel data] or [batch_size, H, W, D-channel data]
        camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...], 'up': [num_batches,...],
                    'viewport': [0, 0, W, H], 'fovy': <radians>}]
        rotated_image: [batch_size, num_surfels, D-channel data] or [batch_size, H, W, D-channel data]
                        Image to mix in with the result of the rotation.
        blur_size: Controls the spread of the Gaussian used for filtering; the standard
               deviation is derived as sigma = blur_size * rgb.size(-2) / 6. As a rule of
               thumb, surfels within a radius of 3*sigma of a pixel contribute to that
               pixel in the final image.

    Returns:
        Projected image with the same shape as `rgb`, together with a per-pixel coverage
        mask of dimensions [batch_size, H, W, 1]

    """
    px_idx, px_coord = project_image_coordinates(surfels, camera)
    viewport = make_list2np(camera['viewport'])
    W = int(viewport[2] - viewport[0])
    H = int(viewport[3] - viewport[1])
    rgb_reshaped = rgb.view(rgb.size(0), -1, rgb.size(-1))

    # Perform a weighted average of points surrounding a pixel using a Gaussian filter
    # Very similar to the idea in this paper: https://arxiv.org/pdf/1810.09381.pdf

    x, y = np.meshgrid(
        np.linspace(0, W - 1, W) + 0.5,
        np.linspace(0, H - 1, H) + 0.5)
    x, y = tch_var_f(x.ravel()).repeat(surfels.size(0),
                                       1), tch_var_f(y.ravel()).repeat(
                                           surfels.size(0), 1)
    x, y = x.unsqueeze(-1), y.unsqueeze(-1)

    xp, yp = px_coord[..., 0].unsqueeze(-2), px_coord[..., 1].unsqueeze(-2)
    sigma = blur_size * rgb.size(-2) / 6
    scale = torch.exp((-(xp - x)**2 - (yp - y)**2) / (2 * sigma**2))

    mask = scale.sum(-1)
    if rotated_image is not None:
        rotated_image = rotated_image.view(*rgb_reshaped.size())
        # out = (rotated_image_weight * rotated_image + torch.sum(scale.unsqueeze(-1) * rgb_reshaped.unsqueeze(-3), -2)) / (scale.sum(-1) + rotated_image_weight + 1e-10).unsqueeze(-1)
        out = torch.sum(scale.unsqueeze(-1) * rgb_reshaped.unsqueeze(-3),
                        -2) + rotated_image * (1 - mask)
    else:
        out = torch.sum(scale.unsqueeze(-1) * rgb_reshaped.unsqueeze(-3),
                        -2) / (mask + 1e-10).unsqueeze(-1)

    return out.view(*rgb.size()), mask.view(*rgb.size()[:-1], 1)
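As a worked example of the blur parameter above: with the default blur_size of 0.15 and an image 64 pixels wide, the Gaussian's standard deviation is sigma = 0.15 * 64 / 6 = 1.6 pixels, so each surfel effectively contributes to pixels within roughly 3 * sigma ≈ 5 pixels of its projected location.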
Example #4
def perspective_RH_NO(fovy, aspect, near, far):
    """Right-handed camera with all coords mapped to [-1, 1] """
    mat_00, mat_11, mat_22, mat_23 = perspective_NO_params(
        fovy, aspect, near, far)

    return tch_var_f([[mat_00, 0, 0, 0], [0, mat_11, 0, 0],
                      [0, 0, -mat_22, mat_23], [0, 0, -1, 0]])
Example #5
def inv_perspective_RH_NO(fovy, aspect, near, far):
    """Inverse perspective for right-handed camera with all coords mapped from [-1, 1] """
    mat_00, mat_11, mat_22, mat_23 = perspective_NO_params(
        fovy, aspect, near, far)

    return tch_var_f([[1 / mat_00, 0, 0, 0], [0, 1 / mat_11, 0, 0],
                      [0, 0, 0, -1], [0, 0, 1 / mat_23, -mat_22 / mat_23]])
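Since the two matrices above are built from the same parameters, a quick sanity-check sketch is to multiply them and compare against the identity (the parameter values below are assumptions for illustration only):

import numpy as np
import torch

M = perspective_RH_NO(fovy=np.deg2rad(45.), aspect=1.0, near=0.1, far=100.0)
Minv = inv_perspective_RH_NO(fovy=np.deg2rad(45.), aspect=1.0, near=0.1, far=100.0)
print(torch.mm(M, Minv))   # expected to be (numerically) the 4x4 identity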
Example #6
def test_NEstNet():
    import numpy as np
    pos = tch_var_f(list(np.random.rand(1, 3, 5, 5)))
    y = NEstNetV0(sph=False).cuda()(pos)
    print(y.shape, y.norm(dim=1))
    y = NEstNetAffine(kernel_size=3).cuda()(pos)
    print(y.shape, y.norm(dim=1))
Example #7
def test_LFNet():
    from diffrend.torch.utils import tch_var_f
    import numpy as np
    pos = tch_var_f(list(np.random.rand(1, 10, 8)))
    y = LFNetV0(in_ch=8, out_ch=3).cuda()(pos)
    print(y)
    print(y.shape, y)
Example #8
    def _init_rays(self, camera):
        viewport = np.array(camera['viewport'])
        W, H = viewport[2] - viewport[0], viewport[3] - viewport[1]
        aspect_ratio = W / H

        x, y = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
        n_pixels = x.size

        fovy = np.array(camera['fovy'])
        focal_length = np.array(camera['focal_length'])
        h = np.tan(fovy / 2) * 2 * focal_length
        w = h * aspect_ratio

        x *= w / 2
        y *= h / 2

        x = tch_var_f(x.ravel())
        y = tch_var_f(y.ravel())

        eye = camera['eye'][:3]
        at = camera['at'][:3]
        up = camera['up'][:3]

        proj_type = camera['proj_type']
        if proj_type == 'ortho' or proj_type == 'orthographic':
            ray_dir = normalize(at - eye)[:, np.newaxis]
            ray_orig = torch.stack((x, y, tch_var_f(np.zeros(n_pixels)), tch_var_f(np.ones(n_pixels))), dim=0)
            # inv_view_matrix = lookat_inv(eye=eye, at=at, up=up)
            # ray_orig = torch.mm(inv_view_matrix, ray_orig)
            # ray_orig = (ray_orig[:3] / ray_orig[3][np.newaxis, :]).permute(1, 0)
        elif proj_type == 'persp' or proj_type == 'perspective':
            ray_orig = eye[np.newaxis, :]
            ray_dir = torch.stack((x, y, tch_var_f(-np.ones(n_pixels) * focal_length)), dim=0)
            # inv_view_matrix = lookat_rot_inv(eye=eye, at=at, up=up)
            # ray_dir = torch.mm(inv_view_matrix, ray_dir)

            # normalize ray direction
            ray_dir /= torch.sqrt(torch.sum(ray_dir ** 2, dim=0))
        else:
            raise ValueError("Invalid projection type")

        self.ray_orig = ray_orig
        self.ray_dir = ray_dir
        self.H = H
        self.W = W
        return ray_orig, ray_dir, H, W
Example #9
def project_image_coordinates(surfels, camera):
    """Project surfels given in world coordinate to the camera's projection plane.

    Args:
        surfels: [batch_size, pos]
        camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...], 'up': [num_batches,...],
                    'viewport': [0, 0, W, H], 'fovy': <radians>}]

    Returns:
        Image of destination indices of dimensions [batch_size, H*W]
        Note that the range of possible coordinates is restricted to be between 0
        and W*H (inclusive). This is inclusive because we use the last index as
        a "dump" for any index that falls outside of the camera's field of view
    """
    surfels_plane = project_surfels(surfels, camera)

    # Rasterize
    viewport = make_list2np(camera['viewport'])
    W, H = float(viewport[2] - viewport[0]), float(viewport[3] - viewport[1])
    aspect_ratio = float(W) / float(H)

    fovy = make_list2np(camera['fovy'])
    focal_length = make_list2np(camera['focal_length'])
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    px_coord = torch.zeros_like(surfels_plane)
    px_coord[...,
             2] = surfels_plane[...,
                                2]  # Make sure to also transmit the new depth
    px_coord[..., :2] = surfels_plane[..., :2] * tch_var_f(
        [-(W - 1) / w, (H - 1) / h]).unsqueeze(-2) + tch_var_f(
            [W / 2., H / 2.]).unsqueeze(-2)
    px_coord_idx = torch.round(px_coord - 0.5).long()

    px_idx = px_coord_idx[..., 1] * W + px_coord_idx[..., 0]

    max_idx = W * H  # Index used if the indices are out of bounds of the camera
    max_idx_tensor = tch_var_l([max_idx])

    # Map out of bounds pixels to the last (extra) index
    mask = (px_coord_idx[..., 1] < 0) | (px_coord_idx[..., 0] < 0) | (
        px_coord_idx[..., 1] >= H) | (px_coord_idx[..., 0] >= W)
    px_idx = torch.where(mask, max_idx_tensor, px_idx)

    return px_idx, px_coord
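One hedged sketch of how a caller might consume the extra "dump" index described above: scatter per-surfel colors into a flat buffer that has one extra slot, then drop that slot so out-of-view surfels are discarded. splat_nearest is a hypothetical helper, not part of the library:

import torch

def splat_nearest(px_idx, rgb, W, H):
    """px_idx: [batch_size, num_surfels], rgb: [batch_size, num_surfels, 3]."""
    buf = torch.zeros(rgb.size(0), W * H + 1, rgb.size(-1), device=rgb.device)
    idx = px_idx.long().unsqueeze(-1).repeat(1, 1, rgb.size(-1))  # integer indices per channel
    buf.scatter_(1, idx, rgb)                # last-written surfel wins per pixel
    return buf[:, :W * H].view(rgb.size(0), H, W, rgb.size(-1))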
Example #10
def test_render_splat_NDC_0():
    fovy = np.deg2rad(45)
    aspect_ratio = 1
    near = 0.1
    far = 1000
    M = perspective(fovy, aspect_ratio, near, far)
    Minv = inv_perspective(fovy, aspect_ratio, near, far)

    pos_NDC = tch_var_f([[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 0.0, 1.0]])
    normals_SLC = tch_var_f([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
    num_objects = pos_NDC.size()[0]

    # Transform params to the Camera's view frustum
    if pos_NDC.size()[-1] == 3:
        pos_NDC = torch.cat((pos_NDC, tch_var_f(np.ones((num_objects, 1)))), dim=1)
    pos_CC = torch.matmul(pos_NDC, Minv.transpose(1, 0))
    pos_CC = pos_CC / pos_CC[..., 3][:, np.newaxis]

    pixel_dist = norm_p(pos_CC[..., :3])
Example #11
def lf_renderer_v0(pos, normal, lfnet, num_samples=10):
    pos_all = pos.reshape((-1, 3))
    normal_all = normal.reshape((-1, 3))
    pixel_colors = []

    for idx in range(pos_all.shape[0]):
        dir_sample = uniform_sample_sphere(radius=1.0,
                                           num_samples=num_samples,
                                           axis=normal_all[idx],
                                           angle=np.pi / 2)
        inp = tch_var_f(
            np.concatenate((np.tile(pos_all[idx],
                                    (num_samples, 1)), dir_sample),
                           axis=-1))
        Li = lfnet(inp)
        cos_theta = torch.sum(inp[:, 3:6] * tch_var_f(normal_all[idx]), dim=-1)
        rgb = torch.sum(cos_theta[:, np.newaxis] * Li, dim=0)
        pixel_colors.append(rgb)

    im = torch.cat(pixel_colors, dim=0).reshape(pos.shape)
    return im
Example #12
def batch_render_random_camera(filename, cam_dist, num_views, width, height,
                         fovy, focal_length, theta_range=None, phi_range=None,
                         axis=None, angle=None, cam_pos=None, cam_lookat=None,
                         double_sided=False, use_quartic=False, b_shadow=True,
                         tile_size=None, save_image_queue=None):
    rendering_time = []

    obj = load_model(filename)
    # normalize the vertices
    v = obj['v']
    axis_range = np.max(v, axis=0) - np.min(v, axis=0)
    v = (v - np.mean(v, axis=0)) / max(axis_range)  # Normalize to make the largest spread 1
    obj['v'] = v

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length

    mesh = obj_to_triangle_spec(obj)
    faces = mesh['face']
    normals = mesh['normal']
    num_tri = faces.shape[0]

    if 'disk' in scene['objects']:
        del scene['objects']['disk']
    scene['objects'].update({'triangle': {'face': None, 'normal': None, 'material_idx': None}})
    scene['objects']['triangle']['face'] = tch_var_f(faces.tolist())
    scene['objects']['triangle']['normal'] = tch_var_f(normals.tolist())
    scene['objects']['triangle']['material_idx'] = tch_var_l(np.zeros(num_tri, dtype=int).tolist())

    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist, num_samples=num_views,
                                        axis=axis, angle=angle,
                                        theta_range=theta_range, phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else np.mean(v, axis=0)
    scene['camera']['at'] = tch_var_f(lookat)

    for idx in range(cam_pos.shape[0]):
        scene['camera']['eye'] = tch_var_f(cam_pos[idx])

        # main render run
        start_time = time()
        res = render(scene, tile_size=tile_size, tiled=tile_size is not None,
                     shadow=b_shadow, double_sided=double_sided,
                     use_quartic=use_quartic)
        res['suffix'] = '_{}'.format(idx)
        res['camera_far'] = scene['camera']['far']
        save_image_queue.put_nowait(get_data(res))
        rendering_time.append(time() - start_time)

    # Timing statistics
    print('Rendering time mean: {}s, std: {}s'.format(np.mean(rendering_time), np.std(rendering_time)))
Example #13
def test_scalability(filename, out_dir='./test_scale'):
    # GTX 980 8GB
    # 320 x 240 250 objs
    # 64 x 64 5000 objs
    # 32 x 32 20000 objs
    # 16 x 16 75000 objs (slow)
    from diffrend.model import load_model

    splats = load_model(filename)
    v = splats['v']
    # normalize the vertices
    v = (v - np.mean(v, axis=0)) / (v.max() - v.min())

    print(np.min(splats['v'], axis=0))
    print(np.max(splats['v'], axis=0))
    print(np.min(v, axis=0))
    print(np.max(v, axis=0))

    rand_idx = np.arange(
        v.shape[0])  #np.random.randint(0, splats['v'].shape[0], 4000)  #
    large_scene = copy.deepcopy(SCENE_BASIC)

    large_scene['camera']['viewport'] = [0, 0, 64, 64]  #[0, 0, 320, 240]
    large_scene['camera']['fovy'] = np.deg2rad(5.)
    large_scene['camera']['focal_length'] = 2.
    #large_scene['camera']['eye'] = tch_var_f([0.0, 1.0, 5.0, 1.0]),
    large_scene['objects']['disk']['pos'] = tch_var_f(v[rand_idx])
    large_scene['objects']['disk']['normal'] = tch_var_f(
        splats['vn'][rand_idx])
    large_scene['objects']['disk']['radius'] = tch_var_f(
        splats['r'][rand_idx].ravel() * 2)
    large_scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(rand_idx.size, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])

    render_scene(large_scene, out_dir, plot_res=True)
Example #14
    def forward(self, x):
        x = self.net(x)
        if self.sph_out:
            x = F.sigmoid(x) * tch_var_f(
                [2 * np.pi, np.pi / 2])[np.newaxis, :, np.newaxis, np.newaxis]
            x = sph2cart_unit(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        else:
            x = torch.cat([x[:, 0, :, :][:, np.newaxis, ...],
                           x[:, 1, :, :][:, np.newaxis, ...],
                           torch.abs(x[:, 2, :, :][:, np.newaxis, ...])],
                          dim=1)
            # Normalize the per-pixel vectors across the channel dimension
            sum_squared = torch.sum(x**2, dim=1, keepdim=True)
            x = x / torch.sqrt(sum_squared + 1e-12)
        return x
Example #15
def rotate_cameras(camera, theta=0, phi=0):
    # Get the current camera rotation (relative to the 'lookat' position)
    camera_eye = cartesian_to_spherical(
        get_data(camera['eye']) - get_data(camera['at']))

    # Rotate the camera
    new_thetas = camera_eye[..., 0] + theta
    new_phis = camera_eye[..., 1] + phi

    # Go back to cartesian coordinates and place the camera back relative to the 'lookat' position
    camera_eye = spherical_to_cartesian(new_thetas,
                                        new_phis,
                                        radius=np.expand_dims(
                                            camera_eye[..., 2], -1))

    if camera['at'].shape[-1] == 4:
        zeros = np.zeros((camera_eye.shape[0], 1))
        camera_eye = np.concatenate((camera_eye, zeros), axis=-1)

    camera['eye'] = tch_var_f(camera_eye) + camera['at']
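A usage sketch for the in-place rotation above, assuming batched 3-component 'eye' and 'at' entries (the exact camera layout here is an assumption for illustration):

import numpy as np

camera = {
    'eye': tch_var_f([[0., 1., 10.]]),   # one camera, 3-component position
    'at': tch_var_f([[0., 0., 0.]]),
}
rotate_cameras(camera, theta=np.deg2rad(30.), phi=0.)
print(camera['eye'])   # eye rotated by 30 degrees around the 'lookat' point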
Example #16
import numpy as np
from diffrend.torch.utils import tch_var_f, tch_var_l

# Starter scene for rendering splats
SCENE_BASIC = {
    'camera': {
        'proj_type': 'perspective',
        'viewport': [0, 0, 320, 240],
        'fovy': np.deg2rad(90.),
        'focal_length': 1.,
        'eye': tch_var_f([0.0, 1.0, 10.0, 1.0]),
        'up': tch_var_f([0.0, 1.0, 0.0, 0.0]),
        'at': tch_var_f([0.0, 0.0, 0.0, 1.0]),
        'near': 0.1,
        'far': 1000.0,
    },
    'lights': {
        'pos':
        tch_var_f([
            [10., 0., 0., 1.0],
            [-10, 0., 0., 1.0],
            [0, 10., 0., 1.0],
            [0, -10., 0., 1.0],
            [0, 0., 10., 1.0],
            [0, 0., -10., 1.0],
            [20, 20, 20, 1.0],
        ]),
        'color_idx':
        tch_var_l([1, 3, 4, 5, 6, 7, 1]),
        # Light attenuation factors have the form (kc, kl, kq) and eq: 1/(kc + kl * d + kq * d^2)
        'attenuation':
Example #17
    scene = SCENE_1
    if args.ortho:
        scene['camera']['proj_type'] = 'ortho'
    if args.render:
        res = render_scene(scene,
                           args.out_dir,
                           args.norm_depth_image_only,
                           backface_culling=args.backface_culling,
                           plot_res=args.display)
    if args.opt:
        input_scene = copy.deepcopy(SCENE_BASIC)
        input_scene['materials']['albedo'] = tch_var_f([
            [0.0, 0.0, 0.0],
            [0.1, 0.1, 0.1],
            [0.2, 0.2, 0.2],
            [0.1, 0.8, 0.9],
            [0.1, 0.8, 0.9],
            [0.9, 0.1, 0.1],
        ])
        optimize_scene(input_scene,
                       scene,
                       args.out_dir,
                       max_iter=args.max_iter,
                       lr=args.lr,
                       print_interval=args.print_interval)
    if args.test_scale:
        test_scalability(filename=args.model_filename, out_dir=args.out_dir)

    if args.opt_ndc_test:
        optimize_NDC_test(out_dir=args.out_dir,
                          width=args.width,
Example #18
def optimize_NDC_test(out_dir,
                      width=32,
                      height=32,
                      max_iter=100,
                      lr=1e-3,
                      scale=10,
                      print_interval=10,
                      imsave_interval=10):
    """A demo function to check if the differentiable renderer can optimize splats in NDC.
    :param scene:
    :param out_dir:
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f([2, 1, 2, 1])
    scene['camera']['at'] = tch_var_f([0, 0.8, 0, 1])

    target_res = render(SCENE_SPHERE_HALFBOX)
    target_im = target_res['image']
    target_im = target_im.detach()  # treat the target image as a constant
    target_im_ = get_data(target_res['image'])

    criterion = nn.L1Loss()  #nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_)
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']

    num_splats = width * height
    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    z = tch_var_f(2 * np.random.rand(num_splats) - 1)
    z.requires_grad = True
    pos = torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), z), dim=1)
    normals = tch_var_f(np.ones((num_splats, 4)) * np.array([0, 0, 1, 0]))
    normals.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    input_scene['objects'] = {
        'disk': {
            'pos': pos,
            'normal': normals,
            'material_idx': material_idx
        }
    }
    optimizer = optim.Adam((z, normals), lr=lr)

    h0 = plt.figure()
    h1 = plt.figure()
    loss_per_iter = []
    for iter in range(max_iter):
        res = render_splats_NDC(input_scene)
        im_out = res['image']

        optimizer.zero_grad()
        loss = criterion(scale * im_out, scale * target_im)

        im_out_ = get_data(im_out)
        loss_ = get_data(loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            print('%d. loss= %f' % (iter, loss_))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            plt.figure(h1.number)
            plt.imshow(im_out_)
            plt.title('%d. loss= %f' % (iter, loss_))
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
Example #19
    def train(self, ):
        """Train network."""
        # Load pretrained model if required
        if self.opt.gen_model_path is not None:
            print("Reloading networks from")
            print(' > Generator', self.opt.gen_model_path)
            self.netG.load_state_dict(
                torch.load(open(self.opt.gen_model_path, 'rb')))
            print(' > Generator2', self.opt.gen_model_path2)
            self.netG2.load_state_dict(
                torch.load(open(self.opt.gen_model_path2, 'rb')))
            print(' > Discriminator', self.opt.dis_model_path)
            self.netD.load_state_dict(
                torch.load(open(self.opt.dis_model_path, 'rb')))
            print(' > Discriminator2', self.opt.dis_model_path2)
            self.netD2.load_state_dict(
                torch.load(open(self.opt.dis_model_path2, 'rb')))

        # Start training
        train_stream = Iterator(batch_size=self.opt.batchSize)
        file_name = os.path.join(self.opt.out_dir, 'L2.txt')
        dsize = len(train_stream)

        for epoch in range(self.opt.n_iter):

            self.critic_iter = 0
            # Train Discriminator critic_iters times
            for cnt, batch in enumerate(train_stream):
                # Train with real
                #################
                #print("hii")
                self.iterationa_no = epoch * dsize + cnt
                iteration = self.iterationa_no
                x, cp, lp = batch

                real_data = tch_var_f(x)
                cam_pos = tch_var_f(cp)
                light_pos = tch_var_f(lp)

                # real_data = real_data.cuda()
                # cam_pos = cam_pos.cuda()
                # light_pos = light_pos.cuda()
                self.in_critic = 1
                self.netD.zero_grad()
                real_data = real_data.permute(0, 3, 1, 2)
                # input_D = torch.cat([self.inputv, self.inputv_depth], 1)
                #import ipdb; ipdb.set_trace()
                real_output = self.netD(real_data, cam_pos)

                if self.opt.criterion == 'GAN':
                    errD_real = self.criterion(real_output, self.labelv)
                    errD_real.backward()
                elif self.opt.criterion == 'WGAN':
                    errD_real = real_output.mean()
                    errD_real.backward(self.mone)
                else:
                    raise ValueError('Unknown GAN criterium')

                # Train with fake
                #################
                self.generate_noise_vector()
                fake_z = self.netG(self.noisev, cam_pos)
                # The normal generator is dependent on z
                fake_n = self.generate_normals(fake_z, cam_pos,
                                               self.scene['camera'])
                fake = torch.cat([fake_z, fake_n], 2)
                fake_rendered, fd, loss = self.render_batch(
                    fake, cam_pos,lp)
                # Do not bp through gen
                outD_fake = self.netD(fake_rendered.detach(),
                                      cam_pos.detach())
                if self.opt.criterion == 'GAN':
                    labelv = Variable(self.label.fill_(self.fake_label))
                    errD_fake = self.criterion(outD_fake, labelv)
                    errD_fake.backward()
                    errD = errD_real + errD_fake
                elif self.opt.criterion == 'WGAN':
                    errD_fake = outD_fake.mean()
                    errD_fake.backward(self.one)
                    errD = errD_fake - errD_real
                else:
                    raise ValueError('Unknown GAN criterium')

                # Compute gradient penalty
                if self.opt.gp != 'None':
                    gradient_penalty = calc_gradient_penalty(
                        self.netD, real_data.data, fake_rendered.data,
                        cam_pos.data, self.opt.gp_lambda)
                    gradient_penalty.backward()
                    errD += gradient_penalty

                gnorm_D = torch.nn.utils.clip_grad_norm(
                    self.netD.parameters(), self.opt.max_gnorm)  # TODO

                # Update weight
                self.optimizerD.step()
                # Clamp critic weigths if not GP and if WGAN
                if self.opt.criterion == 'WGAN' and self.opt.gp == 'None':
                    for p in self.netD.parameters():
                        p.data.clamp_(-self.opt.clamp, self.opt.clamp)
                self.critic_iter += 1

            ############################
            # (2) Update G network
            ###########################
            # To avoid computation
            # for p in self.netD.parameters():
            #     p.requires_grad = False
                if cnt % self.opt.critic_iters == 0 and cnt > 0:
                    self.netG.zero_grad()
                    self.in_critic=0
                    self.generate_noise_vector()
                    fake_z = self.netG(self.noisev, cam_pos)
                    if iteration % (self.opt.print_interval * 4) == 0:
                        fake_z.register_hook(self.tensorboard_hook)
                    fake_n = self.generate_normals(fake_z, cam_pos,
                                                   self.scene['camera'])
                    fake = torch.cat([fake_z, fake_n], 2)
                    fake_rendered, fd, loss = self.render_batch(
                        fake, cam_pos, lp)
                    outG_fake = self.netD(fake_rendered, cam_pos)

                    if self.opt.criterion == 'GAN':
                        # Fake labels are real for generator cost
                        labelv = Variable(self.label.fill_(self.real_label))
                        errG = self.criterion(outG_fake, labelv)
                        errG.backward()
                    elif self.opt.criterion == 'WGAN':
                        errG = outG_fake.mean() + loss
                        errG.backward(self.mone)
                    else:
                        raise ValueError('Unknown GAN criterium')
                    gnorm_G = torch.nn.utils.clip_grad_norm(
                        self.netG.parameters(), self.opt.max_gnorm)  # TODO
                    if (self.opt.alt_opt_zn_interval is not None and
                        iteration >= self.opt.alt_opt_zn_start):
                        # update one of the generators
                        if (((iteration - self.opt.alt_opt_zn_start) %
                             self.opt.alt_opt_zn_interval) == 0):
                            # switch generator vars to optimize
                            curr_generator_idx = (1 - curr_generator_idx)
                        if iteration < self.opt.lr_iter:
                            self.LR_SCHED_MAP[curr_generator_idx].step()
                            self.OPT_MAP[curr_generator_idx].step()
                    else:
                        if iteration < self.opt.lr_iter:
                            self.optG_z_lr_scheduler.step()

                        self.optimizerG.step()

                mse_criterion = nn.MSELoss().cuda()

                # Log print
                if iteration % (self.opt.print_interval * 5) == 0 and cnt > 0:
                    Wassertein_D = (errD_real.data[0] - errD_fake.data[0])
                    self.writer.add_scalar("Loss_G",
                                           errG.data[0],
                                           self.iterationa_no)
                    self.writer.add_scalar("Loss_D",
                                           errD.data[0],
                                           self.iterationa_no)
                    self.writer.add_scalar("Wassertein_D",
                                           Wassertein_D,
                                           self.iterationa_no)
                    self.writer.add_scalar("Disc_grad_norm",
                                           gnorm_D,
                                           self.iterationa_no)
                    self.writer.add_scalar("Gen_grad_norm",
                                           gnorm_G,
                                           self.iterationa_no)

                    print('\n[%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_D_real: %.4f'
                          ' Loss_D_fake: %.4f Wassertein_D: %.4f '
                          ' L2_loss: %.4f z_lr: %.8f, n_lr: %.8f, Disc_grad_norm: %.8f, Gen_grad_norm: %.8f' % (
                          iteration, self.opt.n_iter, errD.data[0],
                          errG.data[0], errD_real.data[0], errD_fake.data[0],
                          Wassertein_D, loss.data[0],
                          self.optG_z_lr_scheduler.get_lr()[0], self.optG2_normal_lr_scheduler.get_lr()[0], gnorm_D, gnorm_G))


                # Save output images
                if iteration % (self.opt.save_image_interval * 5) == 0 and cnt > 0:
                    cs = tch_var_f(contrast_stretch_percentile(
                        get_data(fd), 200, [fd.data.min(), fd.data.max()]))
                    torchvision.utils.save_image(
                        fake_rendered.data,
                        os.path.join(self.opt.vis_images,
                                     'output_%d.png' % (iteration)),
                        nrow=2, normalize=True, scale_each=True)

                # Save input images
                if iteration % (self.opt.save_image_interval*5) == 0:
                    cs = tch_var_f(contrast_stretch_percentile(
                        get_data(fd), 200, [fd.data.min(), fd.data.max()]))
                    torchvision.utils.save_image(
                        real_data.data, os.path.join(
                            self.opt.vis_images, 'input_%d.png' % (iteration)),
                        nrow=2, normalize=True, scale_each=True)

                # Do checkpointing
                if iteration % (self.opt.save_interval*2) == 0:
                    self.save_networks(iteration)
Example #20
def render_splats_along_ray(scene, **params):
    """Render splats specified in the camera's coordinate system

    For now, assume the number of splats to be the number of pixels. This would be relaxed later to allow subpixel rendering.
    :param scene: Scene description
    :return: Dict with 'image' ([H, W, 3]), 'depth' ([H, W]), 'pos', and 'normal' entries
    """
    # TODO (fmannan): reuse z_to_pcl_CC
    camera = scene['camera']
    viewport = np.array(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = W / H
    eye = camera['eye'][:3]
    at = camera['at'][:3]
    up = camera['up'][:3]
    Mcam = lookat(eye=eye, at=at, up=up)
    #M = perspective(fovy, aspect_ratio, near, far)
    #Minv = inv_perspective(fovy, aspect_ratio, near, far)

    splats = scene['objects']['disk']
    pos_ray = splats['pos']
    normals_CC = get_param_value('normal', splats, None)
    #num_objects = pos_ray.size()[0]

    fovy = camera['fovy']
    focal_length = camera['focal_length']
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    ##### Find (X, Y) in the Camera's view frustum
    # Force the caller to set the z coordinate with the correct sign
    if pos_ray.dim() == 1:
        Z = -torch.nn.functional.relu(-pos_ray)  # -torch.abs(pos_ray[:, 2])
    else:
        Z = -torch.nn.functional.relu(-pos_ray[:, 2]) #-torch.abs(pos_ray[:, 2])

    x, y = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
    x *= w / 2
    y *= h / 2

    x = tch_var_f(x.ravel())
    y = tch_var_f(y.ravel())
    #sgn = 1 if get_param_value('use_old_sign', params, False) else -1
    X = -Z * x / focal_length
    Y = -Z * y / focal_length

    pos_CC = torch.stack((X, Y, Z), dim=1)

    if get_param_value('orient_splats', params, False) and normals_CC is not None:
        # TODO (fmannan): Orient splats so that [0, 0, 1] maps to the camera direction
        # Perform this operation only when splat normals are generated by the caller in CC
        # This should help with splats that are at the edge of the view-frustum and the camera has
        # a large fov.
        pass

    # Estimate normals from splats/point-cloud if no normals were provided
    if normals_CC is None:
        normal_est_method = get_param_value('normal_estimation_method', params, 'plane')
        kernel_size = get_param_value('normal_estimation_kernel_size', params, 3)
        normals_CC = estimate_surface_normals(pos_CC.view(H, W, 3), kernel_size, normal_est_method)[..., :3].view(-1, 3)

    material_idx = scene['objects']['disk']['material_idx']
    light_visibility = None
    if 'light_vis' in scene['objects']['disk']:
        light_visibility = scene['objects']['disk']['light_vis']

    # Samples per pixel (supersampling)
    samples = get_param_value('samples', params, 1)
    if samples > 1:
        """There are three variables that need to be upsampled:
        1. positions, 2. normals, and 3. shadow maps (light visibility)
        The idea here is to generate an x-y grid in the original resolution, then shift that
        to find the subpixels, then find the plane parameters for the splat bounded within the pixel
        frustum (i.e., a frustum projected into the scene by a pixel), and then for each subpixel
        find the ray-plane intersection with that splat plane.

        The subpixel rays are generated by taking the mesh on the projection plane and shifting it
        by the appropriate amount to get the pixel coordinate that the ray should go through, then
        finding the position in the 3D camera space. The normal of the splat is copied to all those
        surface samples.

        n_x (x - x0) + n_y (y - y0) + n_z (z - z0) = 0
        n_x x0 + n_y y0 + n_z z0 = d0
        n_x t u_x + ... = d0
        t = d0 / dot(n, ray)
        """
        # plane parameter
        d = torch.sum(pos_CC * normals_CC[:, :3], dim=1)
        z = tch_var_f(np.ones(x.shape) * -focal_length)
        # # Test consistency
        # pos_CC_projplane = torch.stack((x, y, z), dim=1)
        # dot_ray_normal = torch.sum(pos_CC_projplane * normals_CC[:, :3], dim=1)
        # t = d / dot_ray_normal
        # pos_CC_test = t[:, np.newaxis] * pos_CC_projplane
        # diff = torch.mean(torch.abs(pos_CC_test - pos_CC))
        # print(diff)
        # # End of consistency check

        # Find ray-plane intersection for the plane bounded by the frustum
        # The width and height of the projection plane are w and h
        dx = w / (samples * W - 1)  # subpixel width
        dy = h / (samples * H - 1)  # subpixel height
        pos_CC_supersampled = []
        normals_CC_supersampled = []
        material_idx_supersampled = []
        if light_visibility is not None:
            light_visibility_supersampled = []
            light_visibility = light_visibility.transpose(1, 0)
        for c, deltax in enumerate(np.linspace(-1, 1, samples)):
            # TODO (fmannan): generalize (the div by 2) for samples > 3
            xx = x + deltax * dx / 2  # Shift by half of the subpixel size
            for r, deltay in enumerate(np.linspace(1, -1, samples)):
                yy = y + deltay * dy / 2
                # unit ray going through sub-pixels
                pos_CC_projplane = normalize(torch.stack((xx, yy, z), dim=1))
                dot_ray_normal = torch.sum(pos_CC_projplane * normals_CC[:, :3], dim=1)
                t = d / dot_ray_normal

                pos_CC_supersampled.append(t[:, np.newaxis] * pos_CC_projplane)
                normals_CC_supersampled.append(normals_CC[:, :3])
                material_idx_supersampled.append(material_idx[:, np.newaxis])
                if light_visibility is not None:
                    light_visibility_supersampled.append(light_visibility)
        pos_CC_supersampled = torch.stack(pos_CC_supersampled, dim=2)
        normals_CC_supersampled = torch.stack(normals_CC_supersampled, dim=2)
        material_idx_supersampled = torch.stack(material_idx_supersampled, dim=2)
        if light_visibility is not None:
            light_visibility_supersampled = torch.stack(light_visibility_supersampled, dim=2)

        pos_CC = reshape_upsampled_data(pos_CC_supersampled, H, W, 3, samples)
        normals_CC = reshape_upsampled_data(normals_CC_supersampled, H, W, 3, samples)
        material_idx = reshape_upsampled_data(material_idx_supersampled, H, W, 1, samples).view(-1)
        if light_visibility is not None:
            light_visibility = reshape_upsampled_data(light_visibility_supersampled, H, W, light_visibility.shape[1], samples).transpose(1, 0)
        H *= samples
        W *= samples
        ####
    im_depth = norm_p(pos_CC[..., :3]).view(H, W)

    if get_param_value('norm_depth_image_only', params, False):
        min_depth = torch.min(im_depth)
        norm_depth_image = where(im_depth >= camera['far'], min_depth, im_depth)
        norm_depth_image = (norm_depth_image - min_depth) / (torch.max(im_depth) - min_depth)
        return {
            'image': norm_depth_image,
            'depth': im_depth,
            'pos': pos_CC,
            'normal': normals_CC
        }
    ##############################
    # Fragment processing
    # -------------------
    # We can either perform the operations in the world coordinate or in the camera coordinate
    # Since the inputs are in NDC and converted to CC, converting to world coordinate would require more operations.
    # There are fewer lights than splats, so converting light positions and directions to CC is more efficient.
    ##############################
    # Lighting
    color_table = scene['colors']
    light_pos = scene['lights']['pos']
    light_clr_idx = scene['lights']['color_idx']
    light_colors = color_table[light_clr_idx]
    light_attenuation_coeffs = scene['lights']['attenuation']
    ambient_light = scene['lights']['ambient']

    material_albedo = scene['materials']['albedo']
    material_coeffs = scene['materials']['coeffs']


    light_pos_CC = torch.mm(light_pos, Mcam.transpose(1, 0))

    # Generate the fragments
    """
    Get the normal and material for the visible objects.
    """
    frag_normals = normals_CC[:, :3]
    frag_pos = pos_CC[:, :3]

    frag_albedo = torch.index_select(material_albedo, 0, material_idx)
    frag_coeffs = torch.index_select(material_coeffs, 0, material_idx)

    im_color = fragment_shader(frag_normals=frag_normals,
                               light_dir=light_pos_CC[:, np.newaxis, :3] - frag_pos[:, :3],
                               cam_dir=-normalize(frag_pos[np.newaxis, :, :3]),
                               light_attenuation_coeffs=light_attenuation_coeffs,
                               frag_coeffs=frag_coeffs,
                               light_colors=light_colors,
                               ambient_light=ambient_light,
                               frag_albedo=frag_albedo,
                               double_sided=False,
                               use_quartic=get_param_value('use_quartic', params, False),
                               light_visibility=light_visibility)

    im = torch.sum(im_color, dim=0).view(int(H), int(W), 3)

    # clip non-negative
    im = torch.nn.functional.relu(im)

    # Tonemapping
    #if 'tonemap' in scene:
    #    im = tonemap(im, **scene['tonemap'])

    return {
        'image': im,
        'depth': im_depth,
        'pos': pos_CC.view(H, W, 3),
        'normal': normals_CC.contiguous().view(H, W, 3)
    }
Example #21
    def get_real_samples(self):
        """Get a real sample."""
        # Define the camera poses
        if not self.opt.same_view:
            if self.opt.full_sphere_sampling:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta,
                    phi_range=self.opt.phi)
            else:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
        if self.opt.full_sphere_sampling_light:
            self.light_pos1 = uniform_sample_sphere(
                radius=self.opt.cam_dist,
                num_samples=self.opt.batchSize,
                axis=self.opt.axis,
                angle=np.deg2rad(44),
                theta_range=self.opt.theta,
                phi_range=self.opt.phi)
            # self.light_pos2 = uniform_sample_sphere(radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
            #                                      axis=self.opt.axis, angle=np.deg2rad(40),
            #                                      theta_range=self.opt.theta, phi_range=self.opt.phi)
        else:
            print("inbox")
            light_eps = 0.15
            self.light_pos1 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps
            self.light_pos2 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps

            # TODO: deg2rad in all the angles????

        # Create a splats rendering scene
        large_scene = create_scene(self.opt.width, self.opt.height,
                                   self.opt.fovy, self.opt.focal_length,
                                   self.opt.n_splats)
        lookat = self.opt.at if self.opt.at is not None else [
            0.0, 0.0, 0.0, 1.0
        ]
        large_scene['camera']['at'] = tch_var_f(lookat)

        # Render scenes
        data, data_depth, data_normal, data_cond = [], [], [], []
        inpath = self.opt.vis_images + '/'
        inpath2 = self.opt.vis_input + '/'
        for idx in range(self.opt.batchSize):
            # Save the splats into the rendering scene
            if self.opt.use_mesh:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'disk' in large_scene['objects']:
                    del large_scene['objects']['disk']
                if 'triangle' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'triangle': {
                            'face': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }
                samples = self.get_samples()

                large_scene['objects']['triangle']['material_idx'] = tch_var_l(
                    np.zeros(samples['mesh']['face'][0].shape[0],
                             dtype=int).tolist())
                large_scene['objects']['triangle']['face'] = Variable(
                    samples['mesh']['face'][0].cuda(), requires_grad=False)
                large_scene['objects']['triangle']['normal'] = Variable(
                    samples['mesh']['normal'][0].cuda(), requires_grad=False)
            else:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'triangle' in large_scene['objects']:
                    del large_scene['objects']['triangle']
                if 'disk' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'disk': {
                            'pos': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }
                large_scene['objects']['disk']['radius'] = tch_var_f(
                    np.ones(self.opt.n_splats) * self.opt.splats_radius)
                large_scene['objects']['disk']['material_idx'] = tch_var_l(
                    np.zeros(self.opt.n_splats, dtype=int).tolist())
                large_scene['objects']['disk']['pos'] = Variable(
                    samples['splats']['pos'][idx].cuda(), requires_grad=False)
                large_scene['objects']['disk']['normal'] = Variable(
                    samples['splats']['normal'][idx].cuda(),
                    requires_grad=False)

            # Set camera position
            if not self.opt.same_view:
                large_scene['camera']['eye'] = tch_var_f(self.cam_pos[idx])
            else:
                large_scene['camera']['eye'] = tch_var_f(self.cam_pos[0])

            large_scene['lights']['pos'][0, :3] = tch_var_f(
                self.light_pos1[idx])
            #large_scene['lights']['pos'][1,:3]=tch_var_f(self.light_pos2[idx])

            # Render scene
            res = render(large_scene,
                         norm_depth_image_only=self.opt.norm_depth_image_only,
                         double_sided=True,
                         use_quartic=self.opt.use_quartic)

            # Get rendered output
            if self.opt.render_img_nc == 1:
                depth = res['depth']
                im_d = depth.unsqueeze(0)
            else:
                depth = res['depth']
                im_d = depth.unsqueeze(0)
                im = res['image'].permute(2, 0, 1)
                im_ = get_data(res['image'])
                #im_img_ = get_normalmap_image(im_)
                target_normal_ = get_data(res['normal'])
                target_normalmap_img_ = get_normalmap_image(target_normal_)
                im_n = tch_var_f(target_normalmap_img_).view(
                    im.shape[1], im.shape[2], 3).permute(2, 0, 1)

            # Add depth image to the output structure
            file_name = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_{:05d}.txt'.format(idx)
            text_file = open(file_name, "w")
            text_file.write('%s\n' % (str(large_scene['camera']['eye'].data)))
            text_file.close()
            out_file_name = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_{:05d}.npy'.format(idx)
            np.save(out_file_name, self.cam_pos[idx])
            out_file_name2 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_light{:05d}.npy'.format(idx)
            np.save(out_file_name2, self.light_pos1[idx])
            out_file_name3 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_im{:05d}.npy'.format(idx)
            np.save(out_file_name3, get_data(res['image']))
            out_file_name4 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_depth{:05d}.npy'.format(idx)
            np.save(out_file_name4, get_data(res['depth']))
            out_file_name5 = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter) + 'input_normal{:05d}.npy'.format(idx)
            np.save(out_file_name5, get_data(res['normal']))

            if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
                imsave((inpath + str(self.iterationa_no) +
                        'real_normalmap_{:05d}.png'.format(idx)),
                       target_normalmap_img_)
                imsave((inpath + str(self.iterationa_no) +
                        'real_depth_{:05d}.png'.format(idx)), get_data(depth))
                # imsave(inpath + str(self.iterationa_no) + 'real_depthmap_{:05d}.png'.format(idx), im_d)
                # imsave(inpath + str(self.iterationa_no) + 'world_normalmap_{:05d}.png'.format(idx), target_worldnormalmap_img_)
            data.append(im)
            data_depth.append(im_d)
            data_normal.append(im_n)
            data_cond.append(large_scene['camera']['eye'])
        # Stack real samples
        real_samples = torch.stack(data)
        real_samples_depth = torch.stack(data_depth)
        real_samples_normal = torch.stack(data_normal)
        real_samples_cond = torch.stack(data_cond)
        self.batch_size = real_samples.size(0)
        if not self.opt.no_cuda:
            real_samples = real_samples.cuda()
            real_samples_depth = real_samples_depth.cuda()
            real_samples_normal = real_samples_normal.cuda()
            real_samples_cond = real_samples_cond.cuda()

        # Set input/output variables

        self.input.resize_as_(real_samples.data).copy_(real_samples.data)
        self.input_depth.resize_as_(real_samples_depth.data).copy_(
            real_samples_depth.data)
        self.input_normal.resize_as_(real_samples_normal.data).copy_(
            real_samples_normal.data)
        self.input_cond.resize_as_(real_samples_cond.data).copy_(
            real_samples_cond.data)
        self.label.resize_(self.batch_size).fill_(self.real_label)
        # TODO: Remove Variables
        self.inputv = Variable(self.input)
        self.inputv_depth = Variable(self.input_depth)
        self.inputv_normal = Variable(self.input_normal)
        self.inputv_cond = Variable(self.input_cond)
        self.labelv = Variable(self.label)
Example #22
def render_sphere_world(out_dir, cam_pos, radius, width, height, fovy, focal_length,
                        b_display=False):
    """
    Generate z positions on a grid fixed inside the view frustum in the world coordinate system. Place the camera and
    choose the camera's field of view so that the side of the square touches the frustum.
    """
    import copy
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height
    r = np.ones(num_samples) * radius

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['radius'] = tch_var_f(r)
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    large_scene['objects']['disk']['pos'] = tch_var_f(pos)
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    large_scene['camera']['eye'] = tch_var_f(cam_pos)

    # main render run
    start_time = time()
    res = render(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
Example #23
def test_sphere_splat_render_along_ray(out_dir, cam_pos, width, height, fovy, focal_length, use_quartic,
                                       b_display=False):
    """
    Create a sphere on a square as in render_sphere_world, and then convert to the camera's coordinate system
    and then render using render_splats_along_ray.
    """
    import copy
    print('render sphere along ray')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel() - 5, np.ones(num_samples)), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(), np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.subplot(131)
        plt.imshow(pos[..., 0].reshape((height, width)))
        plt.subplot(132)
        plt.imshow(pos[..., 1].reshape((height, width)))
        plt.subplot(133)
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    ## Convert to the camera's coordinate system
    #Mcam = lookat(eye=large_scene['camera']['eye'], at=large_scene['camera']['at'], up=large_scene['camera']['up'])

    pos_CC = tch_var_f(pos) #torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = pos_CC
    large_scene['objects']['disk']['normal'] = None  # Estimate the normals tch_var_f(normals)
    # large_scene['camera']['eye'] = tch_var_f([-10., 0., 10.])
    # large_scene['camera']['eye'] = tch_var_f([2., 0., 10.])
    large_scene['camera']['eye'] = tch_var_f([-5., 0., 0.])

    # main render run
    start_time = time()
    res = render_splats_along_ray(large_scene, use_quartic=use_quartic)
    rendering_time.append(time() - start_time)

    # Test cam_to_world
    res_world = cam_to_world(res['pos'].reshape(-1, 3), res['normal'].reshape(-1, 3), large_scene['camera'])

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = large_scene['camera']['far']

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(depth, interpolation='none')
        plt.title('Depth Image')
        #plt.savefig(out_dir + '/fig_depth_orig.png')

        plt.figure()
        pos_world = get_data(res_world['pos'])
        posx_world = pos_world[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_world = pos_world[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_world = pos_world[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_world)
        plt.title('x_world')
        plt.subplot(132)
        plt.imshow(posy_world)
        plt.title('y_world')
        plt.subplot(133)
        plt.imshow(posz_world)
        plt.title('z_world')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pos_world[:, 0], pos_world[:, 1], pos_world[:, 2], s=1.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')

        plt.figure()
        pos_cc = get_data(res['pos'].reshape(-1, 3))  # res['pos'] is in camera coordinates
        posx_cc = pos_cc[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_cc = pos_cc[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_cc = pos_cc[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_cc)
        plt.title('x_CC')
        plt.subplot(132)
        plt.imshow(posy_cc)
        plt.title('y_CC')
        plt.subplot(133)
        plt.imshow(posz_cc)
        plt.title('z_CC')

    imsave(out_dir + '/img_orig.png', im)
    #imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
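A minimal invocation sketch for this test (it assumes the function above is importable from its test module; the output directory, camera position, and image settings below are illustrative values chosen to match SCENE_TEST, not ones prescribed by the repository):

import os

out_dir = './sphere_along_ray_out'  # illustrative output path
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

# fovy is in degrees (converted with np.deg2rad inside the test), and cam_pos is
# overridden inside the function, so its value here only sets the initial eye.
test_sphere_splat_render_along_ray(out_dir=out_dir, cam_pos=[0., 0., 10.], width=64, height=64,
                                   fovy=90., focal_length=1., use_quartic=False, b_display=False)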
Exemplo n.º 24
0
import numpy as np
from diffrend.torch.renderer import render_splats_NDC, render, render_splats_along_ray
from diffrend.torch.ops import perspective, inv_perspective
from diffrend.torch.utils import tch_var_f, tch_var_l
from diffrend.numpy.ops import normalize as np_normalize
from imageio import imsave
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
from time import time


SCENE_TEST = {
    'camera': {
        'proj_type': 'perspective',
        'viewport': [0, 0, 2, 2],
        'fovy': np.deg2rad(90.),
        'focal_length': 1.,
        'eye': tch_var_f([0.0, 1.0, 10.0, 1.0]),
        'up': tch_var_f([0.0, 1.0, 0.0, 0.0]),
        'at': tch_var_f([0.0, 0.0, 0.0, 1.0]),
        'near': 1.0,
        'far': 1000.0,
    },
    'lights': {
        'pos': tch_var_f([
            [0., 0., -10., 1.0],
            [-15, 3, 15, 1.0],
            [0, 0., 10., 1.0],
        ]),
        'color_idx': tch_var_l([2, 1, 3]),
        # Light attenuation factors have the form (kc, kl, kq) and eq: 1/(kc + kl * d + kq * d^2)
        'attenuation': tch_var_f([
            [1., 0., 0.0],
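The attenuation comment above documents the falloff 1 / (kc + kl * d + kq * d^2). A tiny standalone sketch of that formula; the coefficient triples and the distance below are illustrative only:

import numpy as np

def attenuation(d, kc, kl, kq):
    # Light falloff 1 / (kc + kl * d + kq * d^2) for a light at distance d.
    return 1.0 / (kc + kl * d + kq * d ** 2)

print(attenuation(10.0, 1.0, 0.0, 0.0))   # 1.0: constant term only, no falloff
print(attenuation(10.0, 0.0, 0.0, 0.01))  # 1.0: purely quadratic falloff, 1 / (0.01 * 10^2)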
Exemplo n.º 25
0
import numpy as np
from diffrend.torch.renderer import render_splats_along_ray
from diffrend.torch.utils import get_data, tch_var_f, cam_to_world
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt


#scene = np.load('scene_output_twogans.npy')
scene = np.load('scene_input_twogans_unnorm.npy')
#scene = np.load('scene_output.npy')
#pos = get_data(scene[0]['objects']['disk']['pos'])
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(scene[:, 0], scene[:, 1], scene[:, 2], s=1.3)

for idx in range(0, len(scene), 20):
    print(idx)
    scene[idx]['lights']['attenuation'] = tch_var_f([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
    res = render_splats_along_ray(scene[idx], use_old_sign=False)

    im = get_data(res['image'])
    depth = get_data(res['depth'])

    plt.figure()
    plt.imshow(im)

    plt.figure()
    plt.imshow(depth)

    pos = get_data(res['pos'])
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], s=1.3)
Exemplo n.º 26
0
def render_splats_NDC(scene, **params):
    """Render splats specified in the camera's normalized coordinate system

    For now, assume number of splats to be the number of pixels This would be relaxed later to allow subpixel rendering.
    :param scene: Scene description
    :return: [H, W, 3] image
    """
    camera = scene['camera']
    viewport = np.array(camera['viewport'])
    W, H = int(viewport[2] - viewport[0]), int(viewport[3] - viewport[1])
    aspect_ratio = W / H
    fovy = camera['fovy']
    near = camera['near']
    far = camera['far']
    eye = camera['eye'][:3]
    at = camera['at'][:3]
    up = camera['up'][:3]
    Mcam = lookat(eye=eye, at=at, up=up)
    #M = perspective(fovy, aspect_ratio, near, far)
    Minv = inv_perspective(fovy, aspect_ratio, near, far)

    splats = scene['objects']['disk']
    pos_NDC = splats['pos']
    normals_SLC = splats['normal']
    num_objects = pos_NDC.size()[0]

    # Transform the splat positions from NDC back into the camera's view frustum
    if pos_NDC.size()[-1] == 3:
        pos_NDC = torch.cat((pos_NDC, tch_var_f(np.ones((num_objects, 1)))), dim=1)
    pos_CC = torch.matmul(pos_NDC, Minv.transpose(1, 0))
    pos_CC = pos_CC / pos_CC[..., 3][:, np.newaxis]

    im_depth = norm_p(pos_CC[..., :3]).view(H, W)

    if get_param_value('norm_depth_image_only', params, False):
        min_depth = torch.min(im_depth)
        norm_depth_image = where(im_depth >= camera['far'], min_depth, im_depth)
        norm_depth_image = (norm_depth_image - min_depth) / (torch.max(im_depth) - min_depth)
        return {
            'image': norm_depth_image,
            'depth': im_depth,
            'pos': pos_CC,
            'normal': normals_SLC
        }
    ##############################
    # Fragment processing
    # -------------------
    # Shading can be done either in world coordinates or in camera coordinates (CC).
    # Since the inputs are in NDC and have already been converted to CC, going back to world coordinates
    # would require extra transforms. There are far fewer lights than splats, so converting the light
    # positions and directions to CC is the cheaper option.
    ##############################
    # Lighting
    color_table = scene['colors']
    light_pos = scene['lights']['pos']
    light_clr_idx = scene['lights']['color_idx']
    light_colors = color_table[light_clr_idx]
    light_attenuation_coeffs = scene['lights']['attenuation']
    ambient_light = scene['lights']['ambient']

    material_albedo = scene['materials']['albedo']
    material_coeffs = scene['materials']['coeffs']
    material_idx = scene['objects']['disk']['material_idx']

    light_pos_CC = torch.mm(light_pos, Mcam.transpose(1, 0))

    # Generate the fragments:
    # get the normal and material for the visible objects.
    normals_CC = normals_SLC   # TODO: Transform to CC, or assume SLC is CC
    frag_normals = normals_CC[:, :3]
    frag_pos = pos_CC[:, :3]

    frag_albedo = torch.index_select(material_albedo, 0, material_idx)
    frag_coeffs = torch.index_select(material_coeffs, 0, material_idx)
    light_visibility = None
    # TODO: CHECK fragment_shader call
    im_color = fragment_shader(frag_normals=frag_normals,
                               light_dir=light_pos_CC[:, np.newaxis, :3] - frag_pos[:, :3],
                               cam_dir=-frag_pos[:, :3],
                               light_attenuation_coeffs=light_attenuation_coeffs,
                               frag_coeffs=frag_coeffs,
                               light_colors=light_colors,
                               ambient_light=ambient_light,
                               frag_albedo=frag_albedo,
                               double_sided=get_param_value('double_sided', params, False),
                               use_quartic=get_param_value('use_quartic', params, False),
                               light_visibility=light_visibility)
    # # Fragment shading
    # light_dir = light_pos_CC[:, np.newaxis, :3] - frag_pos[:, :3]
    # light_dir_norm = torch.sqrt(torch.sum(light_dir ** 2, dim=-1))[:, :, np.newaxis]
    # light_dir /= light_dir_norm  # TODO: nonzero_divide
    # # Attenuate the lights
    # per_frag_att_factor = 1 / (light_attenuation_coeffs[:, 0][:, np.newaxis, np.newaxis] +
    #                            light_dir_norm * light_attenuation_coeffs[:, 1][:, np.newaxis, np.newaxis] +
    #                            (light_dir_norm ** 2) * light_attenuation_coeffs[:, 2][:, np.newaxis, np.newaxis])
    #
    # frag_normal_dot_light = tensor_dot(frag_normals, per_frag_att_factor * light_dir, axis=-1)
    # frag_normal_dot_light = torch.nn.functional.relu(frag_normal_dot_light)
    # im_color = frag_normal_dot_light[:, :, np.newaxis] * \
    #            light_colors[:, np.newaxis, :] * frag_albedo[np.newaxis, :, :]

    im = torch.sum(im_color, dim=0).view(int(H), int(W), 3)

    # clip non-negative
    im = torch.nn.functional.relu(im)

    # # Tonemapping
    # if 'tonemap' in scene:
    #     im = tonemap(im, **scene['tonemap'])

    return {
        'image': im,
        'depth': im_depth,
        'pos': pos_CC[:, :3].view(H, W, 3),
        'normal': normals_CC[:, :3].view(H, W, 3)
    }
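A minimal usage sketch, assuming a scene dictionary shaped like SCENE_TEST above (including a 'colors' table) with scene['objects']['disk'] filled with per-pixel NDC positions, normals, and material indices, as test_sphere_splat_NDC further below does; norm_depth_image_only is the flag read through get_param_value inside the function:

res = render_splats_NDC(scene)
image = res['image']   # [H, W, 3] shaded image
depth = res['depth']   # [H, W] distance of each splat from the camera centre
pos_cc = res['pos']    # [H, W, 3] splat positions in camera coordinates

# Depth-only variant: 'image' holds a min-max normalized depth map instead of a shaded image.
depth_img = render_splats_NDC(scene, norm_depth_image_only=True)['image']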
Exemplo n.º 27
0
def inv_perspective_LH_NO(fovy, aspect, near, far):
    """Left-handed camera with all coords mapped to [-1, 1] """
    mat_00, mat_11, mat_22, mat_23 = perspective_NO_params(
        fovy, aspect, near, far)
    return tch_var_f([[1 / mat_00, 0, 0, 0], [0, 1 / mat_11, 0, 0],
                      [0, 0, 0, 1], [0, 0, 1 / mat_23, -mat_22 / mat_23]])
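A small consistency-check sketch for the inverse projection helpers: applying a perspective matrix, doing the homogeneous divide, and then applying its inverse should recover the original camera-space point up to another divide. The sketch uses perspective and inv_perspective from diffrend.torch.ops and tch_var_f from diffrend.torch.utils, with illustrative camera parameters:

import numpy as np
import torch
from diffrend.torch.ops import perspective, inv_perspective
from diffrend.torch.utils import tch_var_f

fovy, aspect, near, far = np.deg2rad(90.), 1.0, 1.0, 1000.0
M = perspective(fovy=fovy, aspect=aspect, near=near, far=far)
Minv = inv_perspective(fovy, aspect, near, far)

p_cc = tch_var_f([[0.3, -0.2, -5.0, 1.0]])             # a point inside the view frustum
p_ndc = torch.matmul(p_cc, M.transpose(1, 0))
p_ndc = p_ndc / p_ndc[..., 3][:, np.newaxis]           # perspective divide
p_back = torch.matmul(p_ndc, Minv.transpose(1, 0))
p_back = p_back / p_back[..., 3][:, np.newaxis]        # divide again to recover w = 1
print(torch.allclose(p_back, p_cc, atol=1e-4))         # expected True if Minv inverts M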
Exemplo n.º 28
0
def optimize_splats_along_ray_shadow_with_normalest_test(
        out_dir,
        width,
        height,
        max_iter=100,
        lr=1e-3,
        scale=10,
        shadow=True,
        vis_only=False,
        samples=1,
        est_normals=False,
        b_generate_normals=False,
        print_interval=10,
        imsave_interval=10,
        xyz_save_interval=100):
    """A demo function to check if the differentiable renderer can optimize splats rendered along ray.
    :param scene:
    :param out_dir:
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX_0

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX_0
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f(
        [2, 1, 2, 1])  # tch_var_f([1, 1, 1, 1]) # tch_var_f([2, 2, 2, 1]) #
    scene['camera']['at'] = tch_var_f(
        [0, 0.8, 0, 1])  # tch_var_f([0, 1, 0, 1]) # tch_var_f([2, 2, 0, 1])  #
    scene['lights']['attenuation'] = tch_var_f([
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
    ])
    scene['materials']['coeffs'] = tch_var_f([
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 0.2, 8.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
    ])

    target_res = render(scene, tiled=True, shadow=shadow)
    target_im = normalize_maxmin(target_res['image'])
    target_im = target_im.detach()  # the optimization target must not receive gradients
    target_im_ = get_data(target_im)
    target_pos_ = get_data(target_res['pos'])
    target_normal_ = get_data(target_res['normal'])
    target_normalmap_img_ = get_normalmap_image(target_normal_)
    target_depth_ = get_data(target_res['depth'])
    print('[z_min, z_max] = [%f, %f]' %
          (np.min(target_pos_[..., 2]), np.max(target_pos_[..., 2])))
    print('[depth_min, depth_max] = [%f, %f]' %
          (np.min(target_depth_), np.max(target_depth_)))

    # world -> cam -> render_splats_along_ray
    cc_tform = world_to_cam(target_res['pos'].view(
        (-1, 3)), target_res['normal'].view((-1, 3)), scene['camera'])
    wc_cc_tform = cam_to_world(cc_tform['pos'], cc_tform['normal'],
                               scene['camera'])

    # Check normal estimation in camera space
    pos_cc = cc_tform['pos'][:, :3].contiguous().view(target_im.shape)
    normal_cc = cc_tform['normal'][:, :3].contiguous().view(target_im.shape)
    plane_fit_est = estimate_surface_normals_plane_fit(pos_cc, None)
    normal_cc_normalmap = get_normalmap_image(get_data(normal_cc))
    plane_fit_est_normalmap = get_normalmap_image(get_data(plane_fit_est))

    pos_diff = torch.abs(wc_cc_tform['pos'][:, :3] -
                         target_res['pos'].view((-1, 3)))
    mean_pos_diff = torch.mean(pos_diff)
    normal_diff = torch.abs(wc_cc_tform['normal'][:, :3] -
                            target_res['normal'].view(-1, 3))
    mean_normal_diff = torch.mean(normal_diff)
    print('mean_pos_diff', mean_pos_diff, 'mean_normal_diff', mean_normal_diff)

    wc_cc_normal = wc_cc_tform['normal'].view(target_im_.shape)
    wc_cc_normal_img = get_normalmap_image(get_data(wc_cc_normal))

    material_idx = tch_var_l(np.ones(cc_tform['pos'].shape[0]) * 3)
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    light_vis = tch_var_f(
        np.ones(
            (input_scene['lights']['pos'].shape[0], cc_tform['pos'].shape[0])))
    input_scene['objects'] = {
        'disk': {
            'pos': cc_tform['pos'],
            'normal': cc_tform['normal'],
            'material_idx': material_idx,
            'light_vis': light_vis,
        }
    }
    target_res_noshadow = render(scene, tiled=True, shadow=False)
    res = render_splats_along_ray(input_scene)
    test_img_ = get_data(normalize_maxmin(res['image']))
    test_depth_ = get_data(res['depth'])
    test_normal_ = get_data(res['normal']).reshape(test_img_.shape)
    test_normalmap_ = get_normalmap_image(test_normal_)
    im_diff = np.abs(test_img_ -
                     get_data(normalize_maxmin(target_res_noshadow['image'])))
    print('mean image diff: {}'.format(np.mean(im_diff)))
    #### PLOT
    plt.ion()
    plt.figure()
    plt.imshow(test_img_, interpolation='none')
    plt.title('Test Image')
    plt.savefig(out_dir + '/test_img.png')
    plt.figure()
    plt.imshow(test_depth_, interpolation='none')
    plt.title('Test Depth')
    plt.savefig(out_dir + '/test_depth.png')

    plt.figure()
    plt.imshow(test_normalmap_, interpolation='none')
    plt.title('Test Normals')
    plt.savefig(out_dir + '/test_normal.png')

    ####
    criterion = nn.L1Loss()  #nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_, interpolation='none')
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    plt.figure()
    plt.imshow(target_normalmap_img_, interpolation='none')
    plt.title('Normals')
    plt.savefig(out_dir + '/normal.png')

    plt.figure()
    plt.imshow(wc_cc_normal_img, interpolation='none')
    plt.title('WC_CC Normals')
    plt.savefig(out_dir + '/wc_cc_normal.png')

    plt.figure()
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.savefig(out_dir + '/normal_cc.png')

    plt.figure()
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/est_normal_cc.png')

    plt.figure()
    plt.subplot(121)
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.subplot(122)
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/normal_and_estnormal_cc_comparison.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    input_scene['camera']['viewport'] = [
        0, 0, int(width / samples),
        int(height / samples)
    ]

    num_splats = int(width * height / (samples * samples))
    #x, y = np.meshgrid(np.linspace(-1, 1, int(width / samples)), np.linspace(-1, 1, int(height / samples)))
    z_min = scene['camera']['focal_length']
    z_max = 3

    z = -tch_var_f(
        np.ones(num_splats) * (z_min + z_max) / 2
    )  # -torch.clamp(tch_var_f(2 * np.random.rand(num_splats)), z_min, z_max)
    z.requires_grad = True

    normal_angles = tch_var_f(np.random.rand(num_splats, 2))
    normal_angles.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0], num_splats)))
    light_vis.requires_grad = True

    if vis_only:
        assert shadow is True
        opt_vars = [light_vis]
        z = cc_tform['pos'][:, 2]
        # FIXME: sph2cart
        #normals = cc_tform['normal']
    else:
        opt_vars = [z, normal_angles]
        if shadow:
            opt_vars += [light_vis]

    optimizer = optim.Adam(opt_vars, lr=lr)
    lr_scheduler = StepLR(optimizer, step_size=10000, gamma=0.8)

    h0 = plt.figure()
    h1 = plt.figure()
    h2 = plt.figure()
    h3 = plt.figure()
    h4 = plt.figure()

    gs1 = gridspec.GridSpec(3, 3)
    gs1.update(wspace=0.0025, hspace=0.02)

    # Three options for scheduling the z_norm_consistency weight
    # 1. start after N iterations
    # 2. start at the beginning and decay
    # 3. start after N iterations and decay to 0
    no_decay = lambda x: x
    exp_decay = lambda x, scale: torch.exp(-x / scale)
    linear_decay = lambda x, scale: scale / (x + 1e-6)

    spatial_var_loss_weight = 10.0  #0.0
    normal_away_from_cam_loss_weight = 0.0
    grad_img_depth_loss_weight = 1.0
    spatial_loss_weight = 2

    z_norm_weight_init = 1  # 1e-5
    z_norm_activate_iter = 0  # 1000
    decay_fn = lambda x: linear_decay(x, 100)
    loss_per_iter = []
    if b_generate_normals:
        est_normals = False
        normal_est_network = NEstNetAffine(kernel_size=3, sph=False)
        print(normal_est_network)
        normal_est_network.cuda()
    for iter in range(max_iter):
        lr_scheduler.step()
        zz = -F.relu(-z) - z_min  # torch.clamp(z, -z_max, -z_min)
        if b_generate_normals:
            normals = generate_normals(zz, scene['camera'], normal_est_network)
            #if iter > 100 and iter % 10 == 0:
            #    print(normals)
        elif not est_normals:
            phi = torch.sigmoid(normal_angles[:, 0]) * 2 * np.pi
            theta = torch.sigmoid(normal_angles[:, 1]) * np.pi / 2  # torch.tanh(normal_angles[:, 1]) * np.pi / 2
            normals = sph2cart_unit(torch.stack((phi, theta), dim=1))

        pos = zz  # torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), zz), dim=1)

        input_scene['objects'] = {
            'disk': {
                'pos': pos,
                'normal': normalize(normals) if not est_normals else None,
                'material_idx': material_idx,
                'light_vis': torch.sigmoid(light_vis),
            }
        }
        res = render_splats_along_ray(input_scene,
                                      samples=samples,
                                      normal_estimation_method='plane')
        res_pos = res['pos']
        res_normal = res['normal']
        spatial_loss = spatial_3x3(res_pos)
        depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
        grad_img = grad_spatial2d(
            torch.mean(res['image'], dim=-1)[..., np.newaxis])
        grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)
        normal_away_from_cam_loss = away_from_camera_penalty(
            res_pos, res_normal)
        z_pos = res_pos[..., 2]
        z_loss = torch.mean((10 * F.relu(z_min - torch.abs(z_pos)))**2 +
                            (10 * F.relu(torch.abs(z_pos) - z_max))**2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() + res_pos[..., 2].var())
        spatial_var_loss = (1 / (spatial_var + 1e-4))
        im_out = normalize_maxmin(res['image'])
        res_depth_ = get_data(res['depth'])

        optimizer.zero_grad()
        z_norm_weight = z_norm_weight_init * float(
            iter > z_norm_activate_iter) * decay_fn(iter -
                                                    z_norm_activate_iter)
        loss = criterion(scale * im_out, scale * target_im) + z_loss + unit_normal_loss + \
            z_norm_weight * z_norm_loss + \
            spatial_var_loss_weight * spatial_var_loss + \
            grad_img_depth_loss_weight * image_depth_consistency_loss
        #normal_away_from_cam_loss_weight * normal_away_from_cam_loss + \
        #spatial_loss_weight * spatial_loss

        im_out_ = get_data(im_out)
        im_out_normal_ = get_data(res['normal'])
        pos_out_ = get_data(res['pos'])

        loss_ = get_data(loss)
        z_loss_ = get_data(z_loss)
        z_norm_loss_ = get_data(z_norm_loss)
        spatial_loss_ = get_data(spatial_loss)
        spatial_var_loss_ = get_data(spatial_var_loss)
        unit_normal_loss_ = get_data(unit_normal_loss)
        normal_away_from_cam_loss_ = get_data(normal_away_from_cam_loss)
        normals_ = get_data(res_normal)
        image_depth_consistency_loss_ = get_data(image_depth_consistency_loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            z__ = pos_out_[..., 2]
            print(
                '%d. loss= %f nloss=%f z_loss=%f [%f, %f] [%f, %f], z_normal_loss: %f,'
                ' spatial_var_loss: %f, normal_away_loss: %f'
                ' nz_range: [%f, %f], spatial_loss: %f, imd_loss: %f' %
                (iter, loss_, unit_normal_loss_, z_loss_, np.min(z_),
                 np.max(z_), np.min(z__), np.max(z__), z_norm_loss_,
                 spatial_var_loss_, normal_away_from_cam_loss_,
                 normals_[..., 2].min(), normals_[..., 2].max(), spatial_loss_,
                 image_depth_consistency_loss_))

        if iter % xyz_save_interval == 0 or iter == max_iter - 1:
            save_xyz(out_dir + '/res_{:05d}.xyz'.format(iter),
                     get_data(res_pos), get_data(res_normal))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            plt.figure(h4.number)
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(121)
            #plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.title('Output')
            plt.subplot(122)
            #plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.title('Ground truth')
            # plt.subplot(223)
            # plt.plot(loss_per_iter, linewidth=2)
            # plt.xlabel('Iteration', fontsize=14)
            # plt.title('Loss', fontsize=12)
            # plt.grid(True)
            plt.savefig(out_dir + '/fig_im_gt_loss_%05d.png' % iter)

            plt.figure(h1.number, figsize=(4, 4))
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(gs1[0])
            plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.subplot(gs1[1])
            plt.axis('off')
            plt.imshow(get_normalmap_image(im_out_normal_),
                       interpolation='none')
            ax = plt.subplot(gs1[2])
            plt.axis('off')
            im_tmp = ax.imshow(res_depth_, interpolation='none')
            # create an axes on the right side of ax. The width of cax will be 5%
            # of ax and the padding between cax and ax will be fixed at 0.05 inch.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            plt.subplot(gs1[3])
            plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.subplot(gs1[4])
            plt.axis('off')
            plt.imshow(test_normalmap_, interpolation='none')
            ax = plt.subplot(gs1[5])
            plt.axis('off')
            im_tmp = ax.imshow(test_depth_, interpolation='none')
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            W, H = input_scene['camera']['viewport'][2:]
            light_vis_ = get_data(torch.sigmoid(light_vis))
            plt.subplot(gs1[6])
            plt.axis('off')
            plt.imshow(light_vis_[0].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 1):
                plt.subplot(gs1[7])
                plt.axis('off')
                plt.imshow(light_vis_[1].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 2):
                plt.subplot(gs1[8])
                plt.axis('off')
                plt.imshow(light_vis_[2].reshape((H, W)), interpolation='none')

            plt.savefig(out_dir + '/fig_%05d.png' % iter)

            plt.figure(h2.number)
            plt.clf()
            plt.imshow(res_depth_)
            plt.colorbar()
            plt.savefig(out_dir + '/fig_depth_%05d.png' % iter)

            plt.figure(h3.number)
            plt.clf()
            plt.imshow(z_.reshape(H, W))
            plt.colorbar()
            plt.savefig(out_dir + '/fig_z_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
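A minimal invocation sketch for the demo above; the output directory, resolution, and iteration budget are illustrative, and every other keyword keeps its default:

# Small resolution and shadow=False keep the run short; with shadow=False only the
# per-splat depths and normal angles are optimized (no light-visibility variables).
optimize_splats_along_ray_shadow_with_normalest_test(
    out_dir='./splat_opt_out',
    width=64,
    height=64,
    max_iter=2000,
    lr=1e-3,
    shadow=False,
    print_interval=100,
    imsave_interval=100)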
Exemplo n.º 29
0
def test_sphere_splat_NDC(out_dir, cam_pos, width, height, fovy, focal_length, b_display=False):
    """
    Create a sphere on a square as in render_sphere_world, convert it to the camera's coordinate system and
    then to NDC, and render it with render_splats_NDC. A worked single-point projection sketch follows this example.
    """
    import copy
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel(), np.ones(num_samples)), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(), np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    # Convert to the camera's coordinate system
    Mcam = lookat(eye=large_scene['camera']['eye'], at=large_scene['camera']['at'], up=large_scene['camera']['up'])
    Mproj = perspective(fovy=large_scene['camera']['fovy'], aspect=width/height, near=large_scene['camera']['near'],
                        far=large_scene['camera']['far'])

    pos_CC = torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))
    pos_NDC = torch.matmul(pos_CC, Mproj.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = pos_NDC / pos_NDC[..., 3][:, np.newaxis]
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    # main render run
    start_time = time()
    res = render_splats_NDC(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
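A worked single-point sketch of the world -> camera -> NDC chain used above, assuming the same names are in scope as in this example (np, torch, tch_var_f, lookat, perspective); the camera values mirror SCENE_TEST and the point itself is illustrative:

Mcam = lookat(eye=tch_var_f([0.0, 1.0, 10.0, 1.0]),
              at=tch_var_f([0.0, 0.0, 0.0, 1.0]),
              up=tch_var_f([0.0, 1.0, 0.0, 0.0]))
Mproj = perspective(fovy=np.deg2rad(90.), aspect=1.0, near=1.0, far=1000.0)

p_world = tch_var_f([[0.5, 0.5, 0.0, 1.0]])           # one homogeneous point on the square
p_cc = torch.matmul(p_world, Mcam.transpose(1, 0))    # world -> camera coordinates
p_clip = torch.matmul(p_cc, Mproj.transpose(1, 0))    # camera -> clip coordinates
p_ndc = p_clip / p_clip[..., 3][:, np.newaxis]        # perspective divide -> NDC
print(p_ndc)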
Exemplo n.º 30
0
    def render_batch(self, batch, batch_cond, light_pos):
        """Render a batch of splats."""
        batch_size = batch.size()[0]

        # Generate camera positions on a sphere
        if batch_cond is None:
            if self.opt.full_sphere_sampling:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta, phi_range=self.opt.phi)
                # TODO: deg2grad!!
            else:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
                # TODO: deg2grad!!

        rendered_data = []
        rendered_data_depth = []
        rendered_data_cond = []
        scenes = []
        inpath = self.opt.vis_images + '/'
        inpath_xyz = self.opt.vis_xyz + '/'
        z_min = self.scene['camera']['focal_length']
        z_max = z_min + 3

        # TODO (fmannan): Move this in init. This only needs to be done once!
        # Set splats into rendering scene
        if 'sphere' in self.scene['objects']:
            del self.scene['objects']['sphere']
        if 'triangle' in self.scene['objects']:
            del self.scene['objects']['triangle']
        if 'disk' not in self.scene['objects']:
            self.scene['objects'] = {'disk': {'pos': None, 'normal': None,
                                              'material_idx': None}}
        lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
        self.scene['camera']['at'] = tch_var_f(lookat)
        self.scene['objects']['disk']['material_idx'] = tch_var_l(
            np.zeros(self.opt.splats_img_size * self.opt.splats_img_size))
        loss = 0.0
        loss_ = 0.0
        z_loss_ = 0.0
        z_norm_loss_ = 0.0
        spatial_loss_ = 0.0
        spatial_var_loss_ = 0.0
        unit_normal_loss_ = 0.0
        normal_away_from_cam_loss_ = 0.0
        image_depth_consistency_loss_ = 0.0
        for idx in range(batch_size):
            # Get splats positions and normals
            eps = 1e-3
            if self.opt.rescaled:
                z = F.relu(-batch[idx][:, 0]) + z_min
                z = ((z - z.min()) / (z.max() - z.min() + eps) *
                     (z_max - z_min) + z_min)
                pos = -z
            else:
                z = F.relu(-batch[idx][:, 0]) + z_min
                pos = -F.relu(-batch[idx][:, 0]) - z_min
            normals = batch[idx][:, 1:]

            self.scene['objects']['disk']['pos'] = pos

            # Normal estimation network and est_normals don't go together
            self.scene['objects']['disk']['normal'] = normals if self.opt.est_normals is False else None

            # Set camera position
            if batch_cond is None:
                if not self.opt.same_view:
                    self.scene['camera']['eye'] = tch_var_f(cam_pos[idx])
                else:
                    self.scene['camera']['eye'] = tch_var_f(cam_pos[0])
            else:
                if not self.opt.same_view:
                    self.scene['camera']['eye'] = batch_cond[idx]
                else:
                    self.scene['camera']['eye'] = batch_cond[0]

            self.scene['lights']['pos'][0, :3] = tch_var_f(light_pos[idx])
            # self.scene['lights']['pos'][1, :3] = tch_var_f(self.light_pos2[idx])

            # Render scene
            # res = render_splats_NDC(self.scene)
            res = render_splats_along_ray(self.scene,
                                          samples=self.opt.pixel_samples,
                                          normal_estimation_method='plane')

            world_tform = cam_to_world(res['pos'].view((-1, 3)),
                                       res['normal'].view((-1, 3)),
                                       self.scene['camera'])

            # Get rendered output
            res_pos = res['pos'].contiguous()
            res_pos_2D = res_pos.view(res['image'].shape)
            # The z_loss needs to be applied after supersampling
            # TODO: Enable this (currently final loss becomes NaN!!)
            # loss += torch.mean(
            #    (10 * F.relu(z_min - torch.abs(res_pos[..., 2]))) ** 2 +
            #    (10 * F.relu(torch.abs(res_pos[..., 2]) - z_max)) ** 2)


            res_normal = res['normal']
            # depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
            # grad_img = grad_spatial2d(torch.mean(res['image'], dim=-1)[..., np.newaxis])
            # grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
            image_depth_consistency_loss = depth_rgb_gradient_consistency(
                res['image'], res['depth'])
            unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)  # TODO: MN
            normal_away_from_cam_loss = away_from_camera_penalty(
                res_pos, res_normal)
            z_pos = res_pos[..., 2]
            z_loss = torch.mean((2 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                                (2 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
            z_norm_loss = normal_consistency_cost(
                res_pos, res['normal'], norm=1)
            spatial_loss = spatial_3x3(res_pos_2D)
            spatial_var = torch.mean(res_pos[..., 0].var() +
                                     res_pos[..., 1].var() +
                                     res_pos[..., 2].var())
            spatial_var_loss = (1 / (spatial_var + 1e-4))

            loss = (self.opt.zloss * z_loss +
                    self.opt.unit_normalloss*unit_normal_loss +
                    self.opt.normal_consistency_loss_weight * z_norm_loss +
                    self.opt.spatial_var_loss_weight * spatial_var_loss +
                    self.opt.grad_img_depth_loss*image_depth_consistency_loss +
                    self.opt.spatial_loss_weight * spatial_loss)
            pos_out_ = get_data(res['pos'])
            loss_ += get_data(loss)
            z_loss_ += get_data(z_loss)
            z_norm_loss_ += get_data(z_norm_loss)
            spatial_loss_ += get_data(spatial_loss)
            spatial_var_loss_ += get_data(spatial_var_loss)
            unit_normal_loss_ += get_data(unit_normal_loss)
            normal_away_from_cam_loss_ += get_data(normal_away_from_cam_loss)
            image_depth_consistency_loss_ += get_data(
                image_depth_consistency_loss)
            normals_ = get_data(res_normal)

            if self.opt.render_img_nc == 1:
                depth = res['depth']
                im = depth.unsqueeze(0)
            else:
                depth = res['depth']
                im_d = depth.unsqueeze(0)
                im = res['image'].permute(2, 0, 1)
                H, W = im.shape[1:]
                target_normal_ = get_data(res['normal']).reshape((H, W, 3))
                target_normalmap_img_ = get_normalmap_image(target_normal_)
                target_worldnormal_ = get_data(world_tform['normal']).reshape(
                    (H, W, 3))
                target_worldnormalmap_img_ = get_normalmap_image(
                    target_worldnormal_)
            if self.iterationa_no % (self.opt.save_image_interval*5) == 0:
                imsave((inpath + str(self.iterationa_no) +
                        'normalmap_{:05d}.png'.format(idx)),
                       target_normalmap_img_)
                imsave((inpath + str(self.iterationa_no) +
                        'depthmap_{:05d}.png'.format(idx)),
                       get_data(res['depth']))
                imsave((inpath + str(self.iterationa_no) +
                        'world_normalmap_{:05d}.png'.format(idx)),
                       target_worldnormalmap_img_)
            if self.iterationa_no % 1000 == 0:
                im2 = get_data(res['image'])
                depth2 = get_data(res['depth'])
                pos = get_data(res['pos'])

                out_file2 = ("pos"+".npy")
                np.save(inpath_xyz+out_file2, pos)

                out_file2 = ("im"+".npy")
                np.save(inpath_xyz+out_file2, im2)

                out_file2 = ("depth"+".npy")
                np.save(inpath_xyz+out_file2, depth2)

                # Save xyz file
                save_xyz((inpath_xyz + str(self.iterationa_no) +
                          'withnormal_{:05d}.xyz'.format(idx)),
                         pos=get_data(res['pos']),
                         normal=get_data(res['normal']))

                # Save xyz file in world coordinates
                save_xyz((inpath_xyz + str(self.iterationa_no) +
                          'withnormal_world_{:05d}.xyz'.format(idx)),
                         pos=get_data(world_tform['pos']),
                         normal=get_data(world_tform['normal']))
            if self.opt.gz_gi_loss is not None and self.opt.gz_gi_loss > 0:
                gradZ = grad_spatial2d(res_pos_2D[:, :, 2][:, :, np.newaxis])
                gradImg = grad_spatial2d(torch.mean(im,
                                                    dim=0)[:, :, np.newaxis])
                for (gZ, gI) in zip(gradZ, gradImg):
                    loss += (self.opt.gz_gi_loss * torch.mean(torch.abs(
                                torch.abs(gZ) - torch.abs(gI))))
            # Store normalized depth into the data
            rendered_data.append(im)
            rendered_data_depth.append(im_d)
            rendered_data_cond.append(self.scene['camera']['eye'])
            scenes.append(self.scene)

        rendered_data = torch.stack(rendered_data)
        rendered_data_depth = torch.stack(rendered_data_depth)

        return rendered_data, rendered_data_depth, loss / self.opt.batchSize
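render_batch draws one camera eye per batch element by sampling positions on a sphere and assigns it to scene['camera']['eye'] before rendering each item. A minimal standalone sketch of that idea outside the class; the radius and sample count are illustrative, scene stands in for a splat scene like the ones built above, and the angular-range and axis arguments used in render_batch are left at their defaults here:

# One camera position per item, then render each item from its own viewpoint.
cam_pos = uniform_sample_sphere(radius=0.8, num_samples=4)
for idx in range(cam_pos.shape[0]):
    scene['camera']['eye'] = tch_var_f(cam_pos[idx])
    res = render_splats_along_ray(scene, samples=1, normal_estimation_method='plane')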