Example 1
def flat_px(px):
    """Flatten the pixel locations and make sure everything is within bounds."""
    out = px[..., 1] * W + px[..., 0]
    max_idx = tch_var_l([W * H])
    mask = (px[..., 1] < 0) | (px[..., 0] < 0) | (px[..., 1] >= H) | (px[..., 0] >= W)
    out = torch.where(mask, max_idx, out)
    return out
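This helper reappears as a nested function in Example 10, where `W`, `H`, `torch`, and `tch_var_l` come from the enclosing scope. A minimal self-contained sketch of the same indexing trick in plain PyTorch (hypothetical sizes and coordinates):

import torch

W, H = 4, 3
px = torch.tensor([[0, 0], [3, 2], [4, 1], [-1, 0]])  # (x, y) pixel locations
flat = px[..., 1] * W + px[..., 0]                     # row-major linear index
oob = (px[..., 0] < 0) | (px[..., 0] >= W) | (px[..., 1] < 0) | (px[..., 1] >= H)
flat = torch.where(oob, torch.tensor(W * H), flat)     # dump slot for out-of-bounds
print(flat)  # tensor([ 0, 11, 12, 12])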
Example 2
def batch_render_random_camera(filename, cam_dist, num_views, width, height,
                               fovy, focal_length, theta_range=None, phi_range=None,
                               axis=None, angle=None, cam_pos=None, cam_lookat=None,
                               double_sided=False, use_quartic=False, b_shadow=True,
                               tile_size=None, save_image_queue=None):
    rendering_time = []

    obj = load_model(filename)
    # normalize the vertices
    v = obj['v']
    axis_range = np.max(v, axis=0) - np.min(v, axis=0)
    v = (v - np.mean(v, axis=0)) / max(axis_range)  # Normalize to make the largest spread 1
    obj['v'] = v

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length

    mesh = obj_to_triangle_spec(obj)
    faces = mesh['face']
    normals = mesh['normal']
    num_tri = faces.shape[0]

    if 'disk' in scene['objects']:
        del scene['objects']['disk']
    scene['objects'].update({'triangle': {'face': None, 'normal': None, 'material_idx': None}})
    scene['objects']['triangle']['face'] = tch_var_f(faces.tolist())
    scene['objects']['triangle']['normal'] = tch_var_f(normals.tolist())
    scene['objects']['triangle']['material_idx'] = tch_var_l(np.zeros(num_tri, dtype=int).tolist())

    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist, num_samples=num_views,
                                        axis=axis, angle=angle,
                                        theta_range=theta_range, phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else np.mean(v, axis=0)
    scene['camera']['at'] = tch_var_f(lookat)

    for idx in range(cam_pos.shape[0]):
        scene['camera']['eye'] = tch_var_f(cam_pos[idx])

        # main render run
        start_time = time()
        res = render(scene, tile_size=tile_size, tiled=tile_size is not None,
                     shadow=b_shadow, double_sided=double_sided,
                     use_quartic=use_quartic)
        res['suffix'] = '_{}'.format(idx)
        res['camera_far'] = scene['camera']['far']
        save_image_queue.put_nowait(get_data(res))
        rendering_time.append(time() - start_time)

    # Timing statistics
    print('Rendering time mean: {}s, std: {}s'.format(np.mean(rendering_time), np.std(rendering_time)))
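`uniform_sample_sphere` is a diffrend helper; as a rough sketch of the idea behind its full-sphere case (ignoring the `axis`/`angle` and `theta_range`/`phi_range` restrictions), normalized isotropic Gaussian vectors are uniformly distributed on the sphere:

import numpy as np

def sample_sphere(radius, num_samples, seed=0):
    """Uniform points on a sphere: normalize isotropic Gaussian samples."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=(num_samples, 3))
    v /= np.linalg.norm(v, axis=1, keepdims=True)
    return radius * v

cam_pos = sample_sphere(radius=2.0, num_samples=8)  # one camera position per view
print(cam_pos.shape)  # (8, 3)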
Example 3
def project_image_coordinates(surfels, camera):
    """Project surfels given in world coordinate to the camera's projection plane.

    Args:
        surfels: [batch_size, pos]
        camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...], 'up': [num_batches,...],
                    'viewport': [0, 0, W, H], 'fovy': <radians>}]

    Returns:
        Image of destination indices of dimensions [batch_size, H*W]
        Note that the range of possible coordinates is restricted to be between 0
        and W*H (inclusive). This is inclusive because we use the last index as
        a "dump" for any index that falls outside of the camera's field of view
    """
    surfels_plane = project_surfels(surfels, camera)

    # Rasterize
    viewport = make_list2np(camera['viewport'])
    W, H = float(viewport[2] - viewport[0]), float(viewport[3] - viewport[1])
    aspect_ratio = float(W) / float(H)

    fovy = make_list2np(camera['fovy'])
    focal_length = make_list2np(camera['focal_length'])
    h = np.tan(fovy / 2) * 2 * focal_length
    w = h * aspect_ratio

    px_coord = torch.zeros_like(surfels_plane)
    px_coord[..., 2] = surfels_plane[..., 2]  # Make sure to also transmit the new depth
    px_coord[..., :2] = (surfels_plane[..., :2] *
                         tch_var_f([-(W - 1) / w, (H - 1) / h]).unsqueeze(-2) +
                         tch_var_f([W / 2., H / 2.]).unsqueeze(-2))
    px_coord_idx = torch.round(px_coord - 0.5).long()

    px_idx = px_coord_idx[..., 1] * W + px_coord_idx[..., 0]

    max_idx = W * H  # Index used if the indices are out of bounds of the camera
    max_idx_tensor = tch_var_l([max_idx])

    # Map out of bounds pixels to the last (extra) index
    mask = (px_coord_idx[..., 1] < 0) | (px_coord_idx[..., 0] < 0) | (
        px_coord_idx[..., 1] >= H) | (px_coord_idx[..., 0] >= W)
    px_idx = torch.where(mask, max_idx_tensor, px_idx)

    return px_idx, px_coord
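The rasterization step maps plane coordinates to pixels with scales `[-(W - 1) / w, (H - 1) / h]` and offsets `[W / 2, H / 2]`, where the plane size follows from the field of view: h = 2 * focal_length * tan(fovy / 2) and w = h * aspect. A NumPy sketch of the same mapping for a single point (hypothetical camera values):

import numpy as np

W, H = 64, 48
fovy, focal_length = np.deg2rad(45.0), 0.1
h = 2 * focal_length * np.tan(fovy / 2)  # height of the projection plane
w = h * (W / H)                          # width of the projection plane

xp, yp = 0.0, 0.0                        # point on the projection plane
px = -xp * (W - 1) / w + W / 2.          # x is mirrored (camera looks down -z)
py = yp * (H - 1) / h + H / 2.
print(px, py)  # a point at the plane's origin lands at the image center: 32.0 24.0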
Example 4
def test_scalability(filename, out_dir='./test_scale'):
    # GTX 980 8GB
    # 320 x 240 250 objs
    # 64 x 64 5000 objs
    # 32 x 32 20000 objs
    # 16 x 16 75000 objs (slow)
    from diffrend.model import load_model

    splats = load_model(filename)
    v = splats['v']
    # normalize the vertices
    v = (v - np.mean(v, axis=0)) / (v.max() - v.min())

    print(np.min(splats['v'], axis=0))
    print(np.max(splats['v'], axis=0))
    print(np.min(v, axis=0))
    print(np.max(v, axis=0))

    rand_idx = np.arange(v.shape[0])  # np.random.randint(0, splats['v'].shape[0], 4000)
    large_scene = copy.deepcopy(SCENE_BASIC)

    large_scene['camera']['viewport'] = [0, 0, 64, 64]  #[0, 0, 320, 240]
    large_scene['camera']['fovy'] = np.deg2rad(5.)
    large_scene['camera']['focal_length'] = 2.
    #large_scene['camera']['eye'] = tch_var_f([0.0, 1.0, 5.0, 1.0]),
    large_scene['objects']['disk']['pos'] = tch_var_f(v[rand_idx])
    large_scene['objects']['disk']['normal'] = tch_var_f(
        splats['vn'][rand_idx])
    large_scene['objects']['disk']['radius'] = tch_var_f(
        splats['r'][rand_idx].ravel() * 2)
    large_scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(rand_idx.size, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])

    render_scene(large_scene, out_dir, plot_res=True)
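Note the two normalization conventions: Example 2 divides by the largest per-axis spread, while this test divides by the global `v.max() - v.min()`. A standalone sketch of the per-axis-spread variant on random data:

import numpy as np

v = np.random.rand(1000, 3) * np.array([10., 2., 5.])  # anisotropic point cloud
axis_range = np.max(v, axis=0) - np.min(v, axis=0)
v_norm = (v - np.mean(v, axis=0)) / axis_range.max()   # largest spread becomes ~1
print((v_norm.max(axis=0) - v_norm.min(axis=0)).round(3))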
Example 5
def render_sphere_world(out_dir, cam_pos, radius, width, height, fovy, focal_length,
                        b_display=False):
    """
    Generate z positions on a grid fixed inside the view frustum in the world coordinate system. Place the camera and
    choose the camera's field of view so that the side of the square touches the frustum.
    """
    import copy
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height
    r = np.ones(num_samples) * radius

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['radius'] = tch_var_f(r)
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    large_scene['objects']['disk']['pos'] = tch_var_f(pos)
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    large_scene['camera']['eye'] = tch_var_f(cam_pos)

    # main render run
    start_time = time()
    res = render(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
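The depth visualization first maps far-plane (background) pixels to the nearest depth, then rescales to 8-bit. That step in isolation:

import numpy as np

depth = np.array([[1.0, 2.0], [3.0, 1000.0]])  # last pixel hit the far plane
far = 1000.0
depth[depth >= far] = depth.min()              # background -> nearest depth
im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))
print(im_depth)  # [[  0 127] [255   0]]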
Example 6
def test_sphere_splat_render_along_ray(out_dir, cam_pos, width, height, fovy, focal_length, use_quartic,
                                       b_display=False):
    """
    Create a sphere on a square as in render_sphere_world, and then convert to the camera's coordinate system
    and then render using render_splats_along_ray.
    """
    import copy
    print('render sphere along ray')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel() - 5, np.ones(num_samples)), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(), np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.subplot(131)
        plt.imshow(pos[..., 0].reshape((height, width)))
        plt.subplot(132)
        plt.imshow(pos[..., 1].reshape((height, width)))
        plt.subplot(133)
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    ## Convert to the camera's coordinate system
    #Mcam = lookat(eye=large_scene['camera']['eye'], at=large_scene['camera']['at'], up=large_scene['camera']['up'])

    pos_CC = tch_var_f(pos)  # torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = pos_CC
    large_scene['objects']['disk']['normal'] = None  # estimate normals instead of passing tch_var_f(normals)
    # large_scene['camera']['eye'] = tch_var_f([-10., 0., 10.])
    # large_scene['camera']['eye'] = tch_var_f([2., 0., 10.])
    large_scene['camera']['eye'] = tch_var_f([-5., 0., 0.])

    # main render run
    start_time = time()
    res = render_splats_along_ray(large_scene, use_quartic=use_quartic)
    rendering_time.append(time() - start_time)

    # Test cam_to_world
    res_world = cam_to_world(res['pos'].reshape(-1, 3), res['normal'].reshape(-1, 3), large_scene['camera'])

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = large_scene['camera']['far']

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(depth, interpolation='none')
        plt.title('Depth Image')
        #plt.savefig(out_dir + '/fig_depth_orig.png')

        plt.figure()
        pos_world = get_data(res_world['pos'])
        posx_world = pos_world[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_world = pos_world[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_world = pos_world[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_world)
        plt.title('x_world')
        plt.subplot(132)
        plt.imshow(posy_world)
        plt.title('y_world')
        plt.subplot(133)
        plt.imshow(posz_world)
        plt.title('z_world')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pos_world[:, 0], pos_world[:, 1], pos_world[:, 2], s=1.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')

        plt.figure()
        pos_world = get_data(res['pos'].reshape(-1, 3))
        posx_world = pos_world[:, 0].reshape((im.shape[0], im.shape[1]))
        posy_world = pos_world[:, 1].reshape((im.shape[0], im.shape[1]))
        posz_world = pos_world[:, 2].reshape((im.shape[0], im.shape[1]))
        plt.subplot(131)
        plt.imshow(posx_world)
        plt.title('x_CC')
        plt.subplot(132)
        plt.imshow(posy_world)
        plt.title('y_CC')
        plt.subplot(133)
        plt.imshow(posz_world)
        plt.title('z_CC')

    imsave(out_dir + '/img_orig.png', im)
    #imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
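Here positions are stacked with w = 1 and normals with w = 0, so a homogeneous transform translates points but only rotates directions. A small NumPy check with a pure translation (hypothetical transform):

import numpy as np

M = np.eye(4)
M[:3, 3] = [0., 0., -5.]             # translate by (0, 0, -5)
point = np.array([1., 2., 3., 1.])   # w = 1: affected by translation
normal = np.array([0., 0., 1., 0.])  # w = 0: rotation only
print(M @ point)   # [ 1.  2. -2.  1.]
print(M @ normal)  # [0. 0. 1. 0.]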
Example 7
def test_sphere_splat_NDC(out_dir, cam_pos, width, height, fovy, focal_length,  b_display=False):
    """
    Create a sphere on a square as in render_sphere_world, and then convert to the camera's coordinate system and to
    NDC and then render using render_splat_NDC.
    """
    import copy
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height

    large_scene = copy.deepcopy(SCENE_TEST)

    large_scene['camera']['viewport'] = [0, 0, width, height]
    large_scene['camera']['eye'] = tch_var_f(cam_pos)
    large_scene['camera']['fovy'] = np.deg2rad(fovy)
    large_scene['camera']['focal_length'] = focal_length
    large_scene['objects']['disk']['material_idx'] = tch_var_l(np.zeros(num_samples, dtype=int).tolist())
    large_scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    large_scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x ** 2 + y ** 2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x ** 2 + y ** 2))

    # Make a hemi-sphere bulging out of the xy-plane scene
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel(), np.ones(num_samples)), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel(), np.zeros(num_samples)), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    # Convert to the camera's coordinate system
    Mcam = lookat(eye=large_scene['camera']['eye'], at=large_scene['camera']['at'], up=large_scene['camera']['up'])
    Mproj = perspective(fovy=large_scene['camera']['fovy'], aspect=width/height, near=large_scene['camera']['near'],
                        far=large_scene['camera']['far'])

    pos_CC = torch.matmul(tch_var_f(pos), Mcam.transpose(1, 0))
    pos_NDC = torch.matmul(pos_CC, Mproj.transpose(1, 0))

    large_scene['objects']['disk']['pos'] = pos_NDC / pos_NDC[..., 3][:, np.newaxis]
    large_scene['objects']['disk']['normal'] = tch_var_f(normals)

    # main render run
    start_time = time()
    res = render_splats_NDC(large_scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    im = np.uint8(255. * im)

    depth = get_data(res['depth'])
    depth[depth >= large_scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) / (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im, interpolation='none')
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth, interpolation='none')
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
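`lookat` and `perspective` are diffrend helpers. As a sketch of the perspective divide this example performs, here is the standard OpenGL-style projection matrix in NumPy (assuming, without checking, that diffrend follows the same convention):

import numpy as np

def perspective_np(fovy, aspect, near, far):
    """OpenGL-style perspective projection matrix."""
    f = 1.0 / np.tan(fovy / 2)
    M = np.zeros((4, 4))
    M[0, 0] = f / aspect
    M[1, 1] = f
    M[2, 2] = (far + near) / (near - far)
    M[2, 3] = 2 * far * near / (near - far)
    M[3, 2] = -1.0
    return M

Mproj = perspective_np(np.deg2rad(90.), 1.0, 1.0, 1000.0)
p_cc = np.array([0.5, 0.5, -2.0, 1.0])  # point in camera coordinates
p_clip = Mproj @ p_cc
p_ndc = p_clip / p_clip[3]              # the divide done above via pos_NDC[..., 3]
print(p_ndc)  # ~[0.25 0.25 0.001 1.]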
Example 8
     'viewport': [0, 0, 2, 2],
     'fovy': np.deg2rad(90.),
     'focal_length': 1.,
     'eye': tch_var_f([0.0, 1.0, 10.0, 1.0]),
     'up': tch_var_f([0.0, 1.0, 0.0, 0.0]),
     'at': tch_var_f([0.0, 0.0, 0.0, 1.0]),
     'near': 1.0,
     'far': 1000.0,
 },
 'lights': {
     'pos': tch_var_f([
         [0., 0., -10., 1.0],
         [-15, 3, 15, 1.0],
         [0, 0., 10., 1.0],
     ]),
     'color_idx': tch_var_l([2, 1, 3]),
     # Light attenuation factors have the form (kc, kl, kq) and eq: 1/(kc + kl * d + kq * d^2)
     'attenuation': tch_var_f([
         [1., 0., 0.0],
         [0., 0., 0.01],
         [0., 0., 0.01],
     ]),
     'ambient': tch_var_f([0.01, 0.01, 0.01])
 },
 'colors': tch_var_f([
     [0.0, 0.0, 0.0],
     [0.8, 0.1, 0.1],
     [0.0, 0.0, 0.8],
     [0.2, 0.8, 0.2],
 ]),
 'materials': {
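This fragment is truncated from a scene-configuration dict (camera, lights, colors, materials). The attenuation comment encodes the standard point-light falloff; a quick sketch evaluating it with the factors from the config above:

def attenuation(kc, kl, kq, d):
    """Point-light falloff: 1 / (kc + kl * d + kq * d^2)."""
    return 1.0 / (kc + kl * d + kq * d * d)

print(attenuation(1., 0., 0., d=10.0))    # [1., 0., 0.]: constant, no falloff
print(attenuation(0., 0., 0.01, d=10.0))  # [0., 0., 0.01]: quadratic, 1.0 at d = 10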
Example 9
    def render_batch(self, batch, batch_cond, light_pos):
        """Render a batch of splats."""
        batch_size = batch.size()[0]

        # Generate camera positions on a sphere
        if batch_cond is None:
            if self.opt.full_sphere_sampling:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta, phi_range=self.opt.phi)
                # TODO: deg2rad!!
            else:
                cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
                    axis=self.opt.axis, angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
                # TODO: deg2rad!!

        rendered_data = []
        rendered_data_depth = []
        rendered_data_cond = []
        scenes = []
        inpath = self.opt.vis_images + '/'
        inpath_xyz = self.opt.vis_xyz + '/'
        z_min = self.scene['camera']['focal_length']
        z_max = z_min + 3

        # TODO (fmannan): Move this in init. This only needs to be done once!
        # Set splats into rendering scene
        if 'sphere' in self.scene['objects']:
            del self.scene['objects']['sphere']
        if 'triangle' in self.scene['objects']:
            del self.scene['objects']['triangle']
        if 'disk' not in self.scene['objects']:
            self.scene['objects'] = {'disk': {'pos': None, 'normal': None,
                                              'material_idx': None}}
        lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
        self.scene['camera']['at'] = tch_var_f(lookat)
        self.scene['objects']['disk']['material_idx'] = tch_var_l(
            np.zeros(self.opt.splats_img_size * self.opt.splats_img_size))
        loss = 0.0
        loss_ = 0.0
        z_loss_ = 0.0
        z_norm_loss_ = 0.0
        spatial_loss_ = 0.0
        spatial_var_loss_ = 0.0
        unit_normal_loss_ = 0.0
        normal_away_from_cam_loss_ = 0.0
        image_depth_consistency_loss_ = 0.0
        for idx in range(batch_size):
            # Get splats positions and normals
            eps = 1e-3
            if self.opt.rescaled:
                z = F.relu(-batch[idx][:, 0]) + z_min
                z = ((z - z.min()) / (z.max() - z.min() + eps) *
                     (z_max - z_min) + z_min)
                pos = -z
            else:
                z = F.relu(-batch[idx][:, 0]) + z_min
                pos = -F.relu(-batch[idx][:, 0]) - z_min
            normals = batch[idx][:, 1:]

            self.scene['objects']['disk']['pos'] = pos

            # Normal estimation network and est_normals don't go together
            self.scene['objects']['disk']['normal'] = normals if self.opt.est_normals is False else None

            # Set camera position
            if batch_cond is None:
                if not self.opt.same_view:
                    self.scene['camera']['eye'] = tch_var_f(cam_pos[idx])
                else:
                    self.scene['camera']['eye'] = tch_var_f(cam_pos[0])
            else:
                if not self.opt.same_view:
                    self.scene['camera']['eye'] = batch_cond[idx]
                else:
                    self.scene['camera']['eye'] = batch_cond[0]

            self.scene['lights']['pos'][0, :3] = tch_var_f(light_pos[idx])
            # self.scene['lights']['pos'][1, :3] = tch_var_f(self.light_pos2[idx])

            # Render scene
            # res = render_splats_NDC(self.scene)
            res = render_splats_along_ray(self.scene,
                                          samples=self.opt.pixel_samples,
                                          normal_estimation_method='plane')

            world_tform = cam_to_world(res['pos'].view((-1, 3)),
                                       res['normal'].view((-1, 3)),
                                       self.scene['camera'])

            # Get rendered output
            res_pos = res['pos'].contiguous()
            res_pos_2D = res_pos.view(res['image'].shape)
            # The z_loss needs to be applied after supersampling
            # TODO: Enable this (currently final loss becomes NaN!!)
            # loss += torch.mean(
            #    (10 * F.relu(z_min - torch.abs(res_pos[..., 2]))) ** 2 +
            #    (10 * F.relu(torch.abs(res_pos[..., 2]) - z_max)) ** 2)

            res_normal = res['normal']
            # depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
            # grad_img = grad_spatial2d(torch.mean(res['image'], dim=-1)[..., np.newaxis])
            # grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
            image_depth_consistency_loss = depth_rgb_gradient_consistency(
                res['image'], res['depth'])
            unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)  # TODO: MN
            normal_away_from_cam_loss = away_from_camera_penalty(
                res_pos, res_normal)
            z_pos = res_pos[..., 2]
            z_loss = torch.mean((2 * F.relu(z_min - torch.abs(z_pos))) ** 2 +
                                (2 * F.relu(torch.abs(z_pos) - z_max)) ** 2)
            z_norm_loss = normal_consistency_cost(
                res_pos, res['normal'], norm=1)
            spatial_loss = spatial_3x3(res_pos_2D)
            spatial_var = torch.mean(res_pos[..., 0].var() +
                                     res_pos[..., 1].var() +
                                     res_pos[..., 2].var())
            spatial_var_loss = (1 / (spatial_var + 1e-4))

            loss = (self.opt.zloss * z_loss +
                    self.opt.unit_normalloss*unit_normal_loss +
                    self.opt.normal_consistency_loss_weight * z_norm_loss +
                    self.opt.spatial_var_loss_weight * spatial_var_loss +
                    self.opt.grad_img_depth_loss*image_depth_consistency_loss +
                    self.opt.spatial_loss_weight * spatial_loss)
            pos_out_ = get_data(res['pos'])
            loss_ += get_data(loss)
            z_loss_ += get_data(z_loss)
            z_norm_loss_ += get_data(z_norm_loss)
            spatial_loss_ += get_data(spatial_loss)
            spatial_var_loss_ += get_data(spatial_var_loss)
            unit_normal_loss_ += get_data(unit_normal_loss)
            normal_away_from_cam_loss_ += get_data(normal_away_from_cam_loss)
            image_depth_consistency_loss_ += get_data(
                image_depth_consistency_loss)
            normals_ = get_data(res_normal)

            if self.opt.render_img_nc == 1:
                depth = res['depth']
                im = depth.unsqueeze(0)
            else:
                depth = res['depth']
                im_d = depth.unsqueeze(0)
                im = res['image'].permute(2, 0, 1)
                H, W = im.shape[1:]
                target_normal_ = get_data(res['normal']).reshape((H, W, 3))
                target_normalmap_img_ = get_normalmap_image(target_normal_)
                target_worldnormal_ = get_data(world_tform['normal']).reshape(
                    (H, W, 3))
                target_worldnormalmap_img_ = get_normalmap_image(
                    target_worldnormal_)
            if self.iterationa_no % (self.opt.save_image_interval*5) == 0:
                imsave((inpath + str(self.iterationa_no) +
                        'normalmap_{:05d}.png'.format(idx)),
                       target_normalmap_img_)
                imsave((inpath + str(self.iterationa_no) +
                        'depthmap_{:05d}.png'.format(idx)),
                       get_data(res['depth']))
                imsave((inpath + str(self.iterationa_no) +
                        'world_normalmap_{:05d}.png'.format(idx)),
                       target_worldnormalmap_img_)
            if self.iterationa_no % 1000 == 0:
                im2 = get_data(res['image'])
                depth2 = get_data(res['depth'])
                pos = get_data(res['pos'])

                out_file2 = ("pos"+".npy")
                np.save(inpath_xyz+out_file2, pos)

                out_file2 = ("im"+".npy")
                np.save(inpath_xyz+out_file2, im2)

                out_file2 = ("depth"+".npy")
                np.save(inpath_xyz+out_file2, depth2)

                # Save xyz file
                save_xyz((inpath_xyz + str(self.iterationa_no) +
                          'withnormal_{:05d}.xyz'.format(idx)),
                         pos=get_data(res['pos']),
                         normal=get_data(res['normal']))

                # Save xyz file in world coordinates
                save_xyz((inpath_xyz + str(self.iterationa_no) +
                          'withnormal_world_{:05d}.xyz'.format(idx)),
                         pos=get_data(world_tform['pos']),
                         normal=get_data(world_tform['normal']))
            if self.opt.gz_gi_loss is not None and self.opt.gz_gi_loss > 0:
                gradZ = grad_spatial2d(res_pos_2D[:, :, 2][:, :, np.newaxis])
                gradImg = grad_spatial2d(torch.mean(im,
                                                    dim=0)[:, :, np.newaxis])
                for (gZ, gI) in zip(gradZ, gradImg):
                    loss += (self.opt.gz_gi_loss * torch.mean(torch.abs(
                                torch.abs(gZ) - torch.abs(gI))))
            # Store normalized depth into the data
            rendered_data.append(im)
            rendered_data_depth.append(im_d)
            rendered_data_cond.append(self.scene['camera']['eye'])
            scenes.append(self.scene)

        rendered_data = torch.stack(rendered_data)
        rendered_data_depth = torch.stack(rendered_data_depth)

        return rendered_data, rendered_data_depth, loss / self.opt.batchSize
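Among the regularizers above, the spatial variance term 1 / (var + eps) grows when the splat cloud collapses toward a point. In isolation (random positions standing in for `res['pos']`):

import torch

res_pos = torch.rand(128, 3)  # stand-in splat positions
spatial_var = (res_pos[..., 0].var() +
               res_pos[..., 1].var() +
               res_pos[..., 2].var())
spatial_var_loss = 1.0 / (spatial_var + 1e-4)  # large when positions collapse
print(spatial_var_loss)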
Example 10
def projection_renderer_differentiable_fast(surfels,
                                            rgb,
                                            camera,
                                            rotated_image=None,
                                            blur_size=0.15,
                                            use_depth=True,
                                            use_center_dist=True,
                                            compute_new_depth=False,
                                            blur_rotated_image=True,
                                            detach_mask=False,
                                            detach_mask2=False,
                                            detach_depth_merge=False):
    """Project surfels given in world coordinate to the camera's projection plane
       in a way that is differentiable w.r.t depth. This is achieved by interpolating
       the surfel values using bilinear interpolation then blurring the output image using a Gaussian filter.

    Args:
        surfels: [batch_size, num_surfels, pos] - world coordinates
        rgb: [batch_size, num_surfels, D-channel data] or [batch_size, H, W, D-channel data]
        camera: [{'eye': [num_batches,...], 'lookat': [num_batches,...], 'up': [num_batches,...],
                    'viewport': [0, 0, W, H], 'fovy': <radians>}]
        rotated_image: [batch_size, num_surfels, D-channel data] or [batch_size, H, W, D-channel data]
                        Image to mix in with the result of the rotation.
        blur_size: (between 0 and 1). Determines the size of the gaussian kernel as a percentage of the width of the input image
                   The standard deviation of the Gaussian kernel is automatically calculated from this value
        use_depth: Whether to weight the surfels landing on the same output pixel by their depth relative to the camera
        use_center_dist: Whether to weight the surfels landing on the same output pixel by their distance to the nearest pixel center location
        compute_new_depth: Whether to compute and output the depth as seen by the new camera
        blur_rotated_image: Whether to blur the 'rotated_image' passed as argument before merging it with the output image.
                            Set to False if the rotated image is already blurred
        detach_mask: Whether to detach the mask m in I_top + (1 - m) * I_bottom
        detach_mask2: Alternative to `detach_mask`: whether to detach the mask m in m * (I_top / m') + (1 - m) * I_bottom
        detach_depth_merge: Whether to detach the depth values used when merging contributions

    Returns:
        RGB image of dimensions [batch_size, H, W, 3] from projected surfels
    """
    _, px_coord = project_image_coordinates(surfels, camera)
    viewport = make_list2np(camera['viewport'])
    W = int(viewport[2] - viewport[0])
    H = int(viewport[3] - viewport[1])
    rgb_in = rgb.view(rgb.size(0), -1, rgb.size(-1))

    # First create a uniform grid through bilinear interpolation
    # Then, perform a convolution with a Gaussian kernel to blur the output image
    # Idea from this paper: https://arxiv.org/pdf/1810.09381.pdf
    # Tensorflow implementation: https://github.com/eldar/differentiable-point-clouds/blob/master/dpc/util/point_cloud.py#L60

    px_idx = torch.floor(px_coord[..., :2] - 0.5).long()

    # Difference to the nearest pixel center on the top left
    x = (px_coord[..., 0] - 0.5) - px_idx[..., 0].float()
    y = (px_coord[..., 1] - 0.5) - px_idx[..., 1].float()
    x, y = x.unsqueeze(-1), y.unsqueeze(-1)

    def flat_px(px):
        """Flatten the pixel locations and make sure everything is within bounds."""
        out = px[..., 1] * W + px[..., 0]
        max_idx = tch_var_l([W * H])
        mask = (px[..., 1] < 0) | (px[..., 0] < 0) | (px[..., 1] >= H) | (px[..., 0] >= W)
        out = torch.where(mask, max_idx, out)
        return out

    depth = px_coord[..., 2].detach() if detach_depth_merge else px_coord[..., 2]
    # Squared distance to the nearest pixel center
    center_dist_2 = (x ** 2 + y ** 2).squeeze(-1)
    rgb_out = scatter_weighted_blended_oit(rgb_in * (1 - x) * (1 - y),
                                           depth,
                                           center_dist_2,
                                           flat_px(px_idx + tch_var_l([0, 0])),
                                           use_depth=use_depth,
                                           use_center_dist=use_center_dist)
    rgb_out += scatter_weighted_blended_oit(rgb_in * (1 - x) * y,
                                            depth,
                                            center_dist_2,
                                            flat_px(px_idx +
                                                    tch_var_l([0, 1])),
                                            use_depth=use_depth,
                                            use_center_dist=use_center_dist)
    rgb_out += scatter_weighted_blended_oit(rgb_in * x * (1 - y),
                                            depth,
                                            center_dist_2,
                                            flat_px(px_idx +
                                                    tch_var_l([1, 0])),
                                            use_depth=use_depth,
                                            use_center_dist=use_center_dist)
    rgb_out += scatter_weighted_blended_oit(rgb_in * x * y,
                                            depth,
                                            center_dist_2,
                                            flat_px(px_idx +
                                                    tch_var_l([1, 1])),
                                            use_depth=use_depth,
                                            use_center_dist=use_center_dist)

    soft_mask = scatter_weighted_blended_oit(
        (1 - x) * (1 - y),
        depth,
        center_dist_2,
        flat_px(px_idx + tch_var_l([0, 0])),
        use_depth=use_depth,
        use_center_dist=use_center_dist)
    soft_mask += scatter_weighted_blended_oit(
        (1 - x) * y,
        depth,
        center_dist_2,
        flat_px(px_idx + tch_var_l([0, 1])),
        use_depth=use_depth,
        use_center_dist=use_center_dist)
    soft_mask += scatter_weighted_blended_oit(x * (1 - y),
                                              depth,
                                              center_dist_2,
                                              flat_px(px_idx +
                                                      tch_var_l([1, 0])),
                                              use_depth=use_depth,
                                              use_center_dist=use_center_dist)
    soft_mask += scatter_weighted_blended_oit(x * y,
                                              depth,
                                              center_dist_2,
                                              flat_px(px_idx +
                                                      tch_var_l([1, 1])),
                                              use_depth=use_depth,
                                              use_center_dist=use_center_dist)

    if compute_new_depth:
        depth_in = depth.unsqueeze(-1)
        depth_out = scatter_weighted_blended_oit(
            depth_in * (1 - x) * (1 - y),
            depth,
            center_dist_2,
            flat_px(px_idx + tch_var_l([0, 0])),
            use_depth=use_depth,
            use_center_dist=use_center_dist)
        depth_out += scatter_weighted_blended_oit(
            depth_in * (1 - x) * y,
            depth,
            center_dist_2,
            flat_px(px_idx + tch_var_l([0, 1])),
            use_depth=use_depth,
            use_center_dist=use_center_dist)
        depth_out += scatter_weighted_blended_oit(
            depth_in * x * (1 - y),
            depth,
            center_dist_2,
            flat_px(px_idx + tch_var_l([1, 0])),
            use_depth=use_depth,
            use_center_dist=use_center_dist)
        depth_out += scatter_weighted_blended_oit(
            depth_in * x * y,
            depth,
            center_dist_2,
            flat_px(px_idx + tch_var_l([1, 1])),
            use_depth=use_depth,
            use_center_dist=use_center_dist)
        depth_out = depth_out.view(*rgb.size()[:-1], 1)

    rgb_out = rgb_out.view(*rgb.size())
    soft_mask = soft_mask.view(*rgb.size()[:-1], 1)

    # Blur the rgb and mask images
    rgb_out = blur(rgb_out.permute(0, 3, 1, 2), blur_size).permute(0, 2, 3, 1)
    soft_mask = blur(soft_mask.permute(0, 3, 1, 2),
                     blur_size).permute(0, 2, 3, 1)

    # There seems to be a bug in PyTorch where if a single division by 0 occurs in a tensor, the whole thing becomes NaN?
    # Might be related to this issue: https://github.com/pytorch/pytorch/issues/4132
    # Because of this behavior, one can't simply do `out / out_mask` in `torch.where`
    soft_mask_nonzero = torch.where(soft_mask > 0, soft_mask,
                                    torch.ones_like(soft_mask)) + 1e-20

    # If an additional image is passed in, merge it using the soft mask:
    rgb_out_normalized = torch.where(soft_mask > 0,
                                     rgb_out / soft_mask_nonzero, rgb_out)
    if rotated_image is not None:
        if blur_rotated_image:
            rotated_image = blur(rotated_image.permute(0, 3, 1, 2),
                                 blur_size).permute(0, 2, 3, 1)
        if detach_mask:
            out = torch.where(
                soft_mask > 1, rgb_out / soft_mask_nonzero.detach(),
                rgb_out + rotated_image * (1 - soft_mask.detach()))
        elif detach_mask2:
            soft_mask_detached = soft_mask.detach()
            out = soft_mask_detached * rgb_out_normalized + (
                1 - soft_mask_detached) * rotated_image
        else:
            out = torch.where(soft_mask > 1, rgb_out / soft_mask_nonzero,
                              rgb_out + rotated_image * (1 - soft_mask))
    else:
        out = rgb_out_normalized

    # Other things to output:
    proj_out = {'mask': soft_mask, 'image1': rgb_out_normalized}

    if compute_new_depth:
        depth_out = torch.where(soft_mask > 0, depth_out / soft_mask_nonzero,
                                depth_out)
        proj_out['depth'] = depth_out

    return out, proj_out
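The core of the fast projection: each surfel splats its value to the four nearest pixel centers with bilinear weights (1 - x)(1 - y), x(1 - y), (1 - x)y, xy, and an extra slot at index W * H absorbs out-of-bounds contributions. A minimal sketch with `index_add_` standing in for `scatter_weighted_blended_oit` (no depth or center-distance weighting):

import torch

W, H = 4, 4
vals = torch.tensor([[1.0], [2.0]])             # one value per surfel
coord = torch.tensor([[1.7, 2.3], [0.2, 0.4]])  # continuous pixel coordinates

px = torch.floor(coord - 0.5).long()            # top-left neighbor
fx = (coord[:, 0] - 0.5) - px[:, 0].float()     # offset to that neighbor's center
fy = (coord[:, 1] - 0.5) - px[:, 1].float()

out = torch.zeros(W * H + 1, 1)                 # +1 = dump slot for out-of-bounds
for dx, dy, wgt in [(0, 0, (1 - fx) * (1 - fy)), (1, 0, fx * (1 - fy)),
                    (0, 1, (1 - fx) * fy), (1, 1, fx * fy)]:
    x, y = px[:, 0] + dx, px[:, 1] + dy
    idx = y * W + x
    oob = (x < 0) | (x >= W) | (y < 0) | (y >= H)
    idx = torch.where(oob, torch.tensor(W * H), idx)
    out.index_add_(0, idx, vals * wgt.unsqueeze(-1))
img = out[:-1].view(H, W, 1)                    # drop the dump slot
print(img.squeeze(-1))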
Example 11
def optimize_splats_along_ray_shadow_with_normalest_test(
        out_dir,
        width,
        height,
        max_iter=100,
        lr=1e-3,
        scale=10,
        shadow=True,
        vis_only=False,
        samples=1,
        est_normals=False,
        b_generate_normals=False,
        print_interval=10,
        imsave_interval=10,
        xyz_save_interval=100):
    """A demo function to check if the differentiable renderer can optimize splats rendered along ray.
    :param scene:
    :param out_dir:
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX_0

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX_0
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f([2, 1, 2, 1])  # tch_var_f([1, 1, 1, 1]) # tch_var_f([2, 2, 2, 1])
    scene['camera']['at'] = tch_var_f([0, 0.8, 0, 1])  # tch_var_f([0, 1, 0, 1]) # tch_var_f([2, 2, 0, 1])
    scene['lights']['attenuation'] = tch_var_f([
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
        [0., 0.0, 0.01],
    ])
    scene['materials']['coeffs'] = tch_var_f([
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 0.2, 8.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
    ])

    target_res = render(scene, tiled=True, shadow=shadow)
    target_im = normalize_maxmin(target_res['image']).detach()  # the target is a constant
    target_im_ = get_data(target_im)
    target_pos_ = get_data(target_res['pos'])
    target_normal_ = get_data(target_res['normal'])
    target_normalmap_img_ = get_normalmap_image(target_normal_)
    target_depth_ = get_data(target_res['depth'])
    print('[z_min, z_max] = [%f, %f]' %
          (np.min(target_pos_[..., 2]), np.max(target_pos_[..., 2])))
    print('[depth_min, depth_max] = [%f, %f]' %
          (np.min(target_depth_), np.max(target_depth_)))

    # world -> cam -> render_splats_along_ray
    cc_tform = world_to_cam(target_res['pos'].view(
        (-1, 3)), target_res['normal'].view((-1, 3)), scene['camera'])
    wc_cc_tform = cam_to_world(cc_tform['pos'], cc_tform['normal'],
                               scene['camera'])

    # Check normal estimation in camera space
    pos_cc = cc_tform['pos'][:, :3].contiguous().view(target_im.shape)
    normal_cc = cc_tform['normal'][:, :3].contiguous().view(target_im.shape)
    plane_fit_est = estimate_surface_normals_plane_fit(pos_cc, None)
    normal_cc_normalmap = get_normalmap_image(get_data(normal_cc))
    plane_fit_est_normalmap = get_normalmap_image(get_data(plane_fit_est))

    pos_diff = torch.abs(wc_cc_tform['pos'][:, :3] -
                         target_res['pos'].view((-1, 3)))
    mean_pos_diff = torch.mean(pos_diff)
    normal_diff = torch.abs(wc_cc_tform['normal'][:, :3] -
                            target_res['normal'].view(-1, 3))
    mean_normal_diff = torch.mean(normal_diff)
    print('mean_pos_diff', mean_pos_diff, 'mean_normal_diff', mean_normal_diff)

    wc_cc_normal = wc_cc_tform['normal'].view(target_im_.shape)
    wc_cc_normal_img = get_normalmap_image(get_data(wc_cc_normal))

    material_idx = tch_var_l(np.ones(cc_tform['pos'].shape[0]) * 3)
    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    light_vis = tch_var_f(
        np.ones(
            (input_scene['lights']['pos'].shape[0], cc_tform['pos'].shape[0])))
    input_scene['objects'] = {
        'disk': {
            'pos': cc_tform['pos'],
            'normal': cc_tform['normal'],
            'material_idx': material_idx,
            'light_vis': light_vis,
        }
    }
    target_res_noshadow = render(scene, tiled=True, shadow=False)
    res = render_splats_along_ray(input_scene)
    test_img_ = get_data(normalize_maxmin(res['image']))
    test_depth_ = get_data(res['depth'])
    test_normal_ = get_data(res['normal']).reshape(test_img_.shape)
    test_normalmap_ = get_normalmap_image(test_normal_)
    im_diff = np.abs(test_img_ -
                     get_data(normalize_maxmin(target_res_noshadow['image'])))
    print('mean image diff: {}'.format(np.mean(im_diff)))
    #### PLOT
    plt.ion()
    plt.figure()
    plt.imshow(test_img_, interpolation='none')
    plt.title('Test Image')
    plt.savefig(out_dir + '/test_img.png')
    plt.figure()
    plt.imshow(test_depth_, interpolation='none')
    plt.title('Test Depth')
    plt.savefig(out_dir + '/test_depth.png')

    plt.figure()
    plt.imshow(test_normalmap_, interpolation='none')
    plt.title('Test Normals')
    plt.savefig(out_dir + '/test_normal.png')

    ####
    criterion = nn.L1Loss()  #nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_, interpolation='none')
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    plt.figure()
    plt.imshow(target_normalmap_img_, interpolation='none')
    plt.title('Normals')
    plt.savefig(out_dir + '/normal.png')

    plt.figure()
    plt.imshow(wc_cc_normal_img, interpolation='none')
    plt.title('WC_CC Normals')
    plt.savefig(out_dir + '/wc_cc_normal.png')

    plt.figure()
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.savefig(out_dir + '/normal_cc.png')

    plt.figure()
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/est_normal_cc.png')

    plt.figure()
    plt.subplot(121)
    plt.imshow(normal_cc_normalmap, interpolation='none')
    plt.title('Normal CC GT')
    plt.subplot(122)
    plt.imshow(plane_fit_est_normalmap, interpolation='none')
    plt.title('Plane fit CC')
    plt.savefig(out_dir + '/normal_and_estnormal_cc_comparison.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']
    input_scene['camera']['viewport'] = [
        0, 0, int(width / samples),
        int(height / samples)
    ]

    num_splats = int(width * height / (samples * samples))
    #x, y = np.meshgrid(np.linspace(-1, 1, int(width / samples)), np.linspace(-1, 1, int(height / samples)))
    z_min = scene['camera']['focal_length']
    z_max = 3

    z = -tch_var_f(np.ones(num_splats) * (z_min + z_max) / 2)
    # Alternative init: -torch.clamp(tch_var_f(2 * np.random.rand(num_splats)), z_min, z_max)
    z.requires_grad = True

    normal_angles = tch_var_f(np.random.rand(num_splats, 2))
    normal_angles.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    light_vis = tch_var_f(
        np.ones((input_scene['lights']['pos'].shape[0], num_splats)))
    light_vis.requires_grad = True

    if vis_only:
        assert shadow is True
        opt_vars = [light_vis]
        z = cc_tform['pos'][:, 2]
        # FIXME: sph2cart
        #normals = cc_tform['normal']
    else:
        opt_vars = [z, normal_angles]
        if shadow:
            opt_vars += [light_vis]

    optimizer = optim.Adam(opt_vars, lr=lr)
    lr_scheduler = StepLR(optimizer, step_size=10000, gamma=0.8)

    h0 = plt.figure()
    h1 = plt.figure()
    h2 = plt.figure()
    h3 = plt.figure()
    h4 = plt.figure()

    gs1 = gridspec.GridSpec(3, 3)
    gs1.update(wspace=0.0025, hspace=0.02)

    # Three options for z_norm_consistency:
    # 1. start after N iterations
    # 2. start at the beginning and decay
    # 3. start after N iterations and decay to 0
    no_decay = lambda x: x
    exp_decay = lambda x, scale: torch.exp(-x / scale)
    linear_decay = lambda x, scale: scale / (x + 1e-6)

    spatial_var_loss_weight = 10.0  #0.0
    normal_away_from_cam_loss_weight = 0.0
    grad_img_depth_loss_weight = 1.0
    spatial_loss_weight = 2

    z_norm_weight_init = 1  # 1e-5
    z_norm_activate_iter = 0  # 1000
    decay_fn = lambda x: linear_decay(x, 100)
    loss_per_iter = []
    if b_generate_normals:
        est_normals = False
        normal_est_network = NEstNetAffine(kernel_size=3, sph=False)
        print(normal_est_network)
        normal_est_network.cuda()
    for iter in range(max_iter):
        lr_scheduler.step()
        zz = -F.relu(-z) - z_min  # torch.clamp(z, -z_max, -z_min)
        if b_generate_normals:
            normals = generate_normals(zz, scene['camera'], normal_est_network)
            #if iter > 100 and iter % 10 == 0:
            #    print(normals)
        elif not est_normals:
            phi = F.sigmoid(normal_angles[:, 0]) * 2 * np.pi
            theta = F.sigmoid(normal_angles[:, 1]) * np.pi / 2  # F.tanh(normal_angles[:, 1]) * np.pi / 2
            normals = sph2cart_unit(torch.stack((phi, theta), dim=1))

        pos = zz  # torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), zz), dim=1)

        input_scene['objects'] = {
            'disk': {
                'pos': pos,
                'normal': normalize(normals) if not est_normals else None,
                'material_idx': material_idx,
                'light_vis': torch.sigmoid(light_vis),
            }
        }
        res = render_splats_along_ray(input_scene,
                                      samples=samples,
                                      normal_estimation_method='plane')
        res_pos = res['pos']
        res_normal = res['normal']
        spatial_loss = spatial_3x3(res_pos)
        depth_grad_loss = spatial_3x3(res['depth'][..., np.newaxis])
        grad_img = grad_spatial2d(
            torch.mean(res['image'], dim=-1)[..., np.newaxis])
        grad_depth_img = grad_spatial2d(res['depth'][..., np.newaxis])
        image_depth_consistency_loss = depth_rgb_gradient_consistency(
            res['image'], res['depth'])
        unit_normal_loss = unit_norm2_L2loss(res_normal, 10.0)
        normal_away_from_cam_loss = away_from_camera_penalty(
            res_pos, res_normal)
        z_pos = res_pos[..., 2]
        z_loss = torch.mean((10 * F.relu(z_min - torch.abs(z_pos)))**2 +
                            (10 * F.relu(torch.abs(z_pos) - z_max))**2)
        z_norm_loss = normal_consistency_cost(res_pos, res_normal, norm=1)
        spatial_var = torch.mean(res_pos[..., 0].var() +
                                 res_pos[..., 1].var() + res_pos[..., 2].var())
        spatial_var_loss = (1 / (spatial_var + 1e-4))
        im_out = normalize_maxmin(res['image'])
        res_depth_ = get_data(res['depth'])

        optimizer.zero_grad()
        z_norm_weight = (z_norm_weight_init * float(iter > z_norm_activate_iter) *
                         decay_fn(iter - z_norm_activate_iter))
        loss = criterion(scale * im_out, scale * target_im) + z_loss + unit_normal_loss + \
            z_norm_weight * z_norm_loss + \
            spatial_var_loss_weight * spatial_var_loss + \
            grad_img_depth_loss_weight * image_depth_consistency_loss
        #normal_away_from_cam_loss_weight * normal_away_from_cam_loss + \
        #spatial_loss_weight * spatial_loss

        im_out_ = get_data(im_out)
        im_out_normal_ = get_data(res['normal'])
        pos_out_ = get_data(res['pos'])

        loss_ = get_data(loss)
        z_loss_ = get_data(z_loss)
        z_norm_loss_ = get_data(z_norm_loss)
        spatial_loss_ = get_data(spatial_loss)
        spatial_var_loss_ = get_data(spatial_var_loss)
        unit_normal_loss_ = get_data(unit_normal_loss)
        normal_away_from_cam_loss_ = get_data(normal_away_from_cam_loss)
        normals_ = get_data(res_normal)
        image_depth_consistency_loss_ = get_data(image_depth_consistency_loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            z__ = pos_out_[..., 2]
            print(
                '%d. loss= %f nloss=%f z_loss=%f [%f, %f] [%f, %f], z_normal_loss: %f,'
                ' spatial_var_loss: %f, normal_away_loss: %f'
                ' nz_range: [%f, %f], spatial_loss: %f, imd_loss: %f' %
                (iter, loss_, unit_normal_loss_, z_loss_, np.min(z_),
                 np.max(z_), np.min(z__), np.max(z__), z_norm_loss_,
                 spatial_var_loss_, normal_away_from_cam_loss_,
                 normals_[..., 2].min(), normals_[..., 2].max(), spatial_loss_,
                 image_depth_consistency_loss_))

        if iter % xyz_save_interval == 0 or iter == max_iter - 1:
            save_xyz(out_dir + '/res_{:05d}.xyz'.format(iter),
                     get_data(res_pos), get_data(res_normal))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            z_ = get_data(z)
            plt.figure(h4.number)
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(121)
            #plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.title('Output')
            plt.subplot(122)
            #plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.title('Ground truth')
            # plt.subplot(223)
            # plt.plot(loss_per_iter, linewidth=2)
            # plt.xlabel('Iteration', fontsize=14)
            # plt.title('Loss', fontsize=12)
            # plt.grid(True)
            plt.savefig(out_dir + '/fig_im_gt_loss_%05d.png' % iter)

            plt.figure(h1.number, figsize=(4, 4))
            plt.clf()
            plt.suptitle('%d. loss= %f [%f, %f]' %
                         (iter, loss_, np.min(z_), np.max(z_)))
            plt.subplot(gs1[0])
            plt.axis('off')
            plt.imshow(im_out_, interpolation='none')
            plt.subplot(gs1[1])
            plt.axis('off')
            plt.imshow(get_normalmap_image(im_out_normal_),
                       interpolation='none')
            ax = plt.subplot(gs1[2])
            plt.axis('off')
            im_tmp = ax.imshow(res_depth_, interpolation='none')
            # create an axes on the right side of ax. The width of cax will be 5%
            # of ax and the padding between cax and ax will be fixed at 0.05 inch.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            plt.subplot(gs1[3])
            plt.axis('off')
            plt.imshow(target_im_, interpolation='none')
            plt.subplot(gs1[4])
            plt.axis('off')
            plt.imshow(test_normalmap_, interpolation='none')
            ax = plt.subplot(gs1[5])
            plt.axis('off')
            im_tmp = ax.imshow(test_depth_, interpolation='none')
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im_tmp, cax=cax)

            W, H = input_scene['camera']['viewport'][2:]
            light_vis_ = get_data(torch.sigmoid(light_vis))
            plt.subplot(gs1[6])
            plt.axis('off')
            plt.imshow(light_vis_[0].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 1):
                plt.subplot(gs1[7])
                plt.axis('off')
                plt.imshow(light_vis_[1].reshape((H, W)), interpolation='none')

            if (light_vis_.shape[0] > 2):
                plt.subplot(gs1[8])
                plt.axis('off')
                plt.imshow(light_vis_[2].reshape((H, W)), interpolation='none')

            plt.savefig(out_dir + '/fig_%05d.png' % iter)

            plt.figure(h2.number)
            plt.clf()
            plt.imshow(res_depth_)
            plt.colorbar()
            plt.savefig(out_dir + '/fig_depth_%05d.png' % iter)

            plt.figure(h3.number)
            plt.clf()
            plt.imshow(z_.reshape(H, W))
            plt.colorbar()
            plt.savefig(out_dir + '/fig_z_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
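The reparameterization `zz = -F.relu(-z) - z_min` keeps every splat at depth at least z_min in front of the camera (which looks down -z), whatever raw value the optimizer assigns to z. A standalone check:

import torch
import torch.nn.functional as F

z_min = 1.0
z = torch.tensor([-3.0, -0.5, 0.2, 2.0])
zz = -F.relu(-z) - z_min  # always <= -z_min
print(zz)  # tensor([-4.0000, -1.5000, -1.0000, -1.0000])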
Example 12
def optimize_NDC_test(out_dir,
                      width=32,
                      height=32,
                      max_iter=100,
                      lr=1e-3,
                      scale=10,
                      print_interval=10,
                      imsave_interval=10):
    """A demo function to check if the differentiable renderer can optimize splats in NDC.
    :param scene:
    :param out_dir:
    :return:
    """
    import torch
    import copy
    from diffrend.torch.params import SCENE_SPHERE_HALFBOX

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    scene = SCENE_SPHERE_HALFBOX
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(45)
    scene['camera']['focal_length'] = 1
    scene['camera']['eye'] = tch_var_f([2, 1, 2, 1])
    scene['camera']['at'] = tch_var_f([0, 0.8, 0, 1])

    target_res = render(SCENE_SPHERE_HALFBOX)
    target_im = target_res['image'].detach()  # the target is a constant
    target_im_ = get_data(target_res['image'])

    criterion = nn.L1Loss()  #nn.MSELoss()
    criterion = criterion.cuda()

    plt.ion()
    plt.figure()
    plt.imshow(target_im_)
    plt.title('Target Image')
    plt.savefig(out_dir + '/target.png')

    input_scene = copy.deepcopy(scene)
    del input_scene['objects']['sphere']
    del input_scene['objects']['triangle']

    num_splats = width * height
    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    z = tch_var_f(2 * np.random.rand(num_splats) - 1)
    z.requires_grad = True
    pos = torch.stack((tch_var_f(x.ravel()), tch_var_f(y.ravel()), z), dim=1)
    normals = tch_var_f(np.ones((num_splats, 4)) * np.array([0, 0, 1, 0]))
    normals.requires_grad = True
    material_idx = tch_var_l(np.ones(num_splats) * 3)

    input_scene['objects'] = {
        'disk': {
            'pos': pos,
            'normal': normals,
            'material_idx': material_idx
        }
    }
    optimizer = optim.Adam((z, normals), lr=lr)

    h0 = plt.figure()
    h1 = plt.figure()
    loss_per_iter = []
    for iter in range(max_iter):
        res = render_splats_NDC(input_scene)
        im_out = res['image']

        optimizer.zero_grad()
        loss = criterion(scale * im_out, scale * target_im)

        im_out_ = get_data(im_out)
        loss_ = get_data(loss)
        loss_per_iter.append(loss_)

        if iter == 0:
            plt.figure(h0.number)
            plt.imshow(im_out_)
            plt.title('Initial')

        if iter % print_interval == 0 or iter == max_iter - 1:
            print('%d. loss= %f' % (iter, loss_))

        if iter % imsave_interval == 0 or iter == max_iter - 1:
            plt.figure(h1.number)
            plt.imshow(im_out_)
            plt.title('%d. loss= %f' % (iter, loss_))
            plt.savefig(out_dir + '/fig_%05d.png' % iter)

        loss.backward()
        optimizer.step()

    plt.figure()
    plt.plot(loss_per_iter, linewidth=2)
    plt.xlabel('Iteration', fontsize=14)
    plt.title('Loss', fontsize=12)
    plt.grid(True)
    plt.savefig(out_dir + '/loss.png')

    plt.ioff()
    plt.show()
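A hypothetical invocation of optimize_NDC_test; the output directory and hyperparameter values below are illustrative, not values prescribed by the source:

optimize_NDC_test(out_dir='./ndc_opt_demo',
                  width=32, height=32,
                  max_iter=500, lr=1e-3, scale=10,
                  print_interval=50, imsave_interval=50)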
Example No. 13
def render_sphere_halfbox(out_dir,
                          cam_pos,
                          width,
                          height,
                          fovy,
                          focal_length,
                          num_views,
                          cam_dist,
                          norm_depth_image_only,
                          theta_range=None,
                          phi_range=None,
                          axis=None,
                          angle=None,
                          cam_lookat=None,
                          tile_size=None,
                          use_quartic=False,
                          b_shadow=True,
                          b_display=False):
    # python splat_render_demo.py --sphere-halfbox --fovy 30 --out_dir ./sphere_halfbox_demo --cam_dist 4 --axis .8 .5 1
    # --angle 5 --at 0 .4 0 --nv 10 --width=256 --height=256
    scene = SCENE_SPHERE_HALFBOX
    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length
    scene['lights']['pos'] = tch_var_f([[2., 2., 1.5, 1.0],
                                        [1., 4., 1.5, 1.0]])  # tch_var_f([[4., 4., 3., 1.0]])
    scene['lights']['color_idx'] = tch_var_l([1, 3])
    scene['lights']['attenuation'] = tch_var_f([[1., 0., 0.], [1., 0., 0.]])

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist,
                                        num_samples=num_views,
                                        axis=axis,
                                        angle=angle,
                                        theta_range=theta_range,
                                        phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else [0.0, 0.0, 0.0, 1.0]
    scene['camera']['at'] = tch_var_f(lookat)

    b_tiled = tile_size is not None
    res = render(scene, tile_size=tile_size, tiled=b_tiled, shadow=b_shadow)
    im = np.uint8(255. * get_data(res['image']))
    depth = get_data(res['depth'])

    depth[depth >= scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) /
                        (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im)
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth)
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    if b_display:
        h1 = plt.figure()
        h2 = plt.figure()
    for idx in range(cam_pos.shape[0]):
        scene['camera']['eye'] = tch_var_f(cam_pos[idx])
        suffix = '_{}'.format(idx)

        # main render run
        res = render(scene,
                     tiled=b_tiled,
                     shadow=b_shadow,
                     norm_depth_image_only=norm_depth_image_only,
                     use_quartic=use_quartic)

        im = np.uint8(255. * get_data(res['image']))
        depth = get_data(res['depth'])

        depth[depth >= scene['camera']['far']] = depth.min()
        im_depth = np.uint8(255. * (depth - depth.min()) /
                            (depth.max() - depth.min()))

        if b_display:
            plt.figure(h1.number)
            plt.imshow(im)
            plt.title('Image')
            plt.savefig(out_dir + '/fig_img' + suffix + '.png')

            plt.figure(h2.number)
            plt.imshow(im_depth)
            plt.title('Depth Image')
            plt.savefig(out_dir + '/fig_depth' + suffix + '.png')

        imsave(out_dir + '/img' + suffix + '.png', im)
        imsave(out_dir + '/depth' + suffix + '.png', im_depth)
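A programmatic equivalent of the CLI invocation quoted in the comment at the top of this function. The focal_length value and the degree-to-radian handling of --angle are assumptions (the CLI wrapper itself is not shown here), and out_dir must already exist since the function does not create it:

render_sphere_halfbox(out_dir='./sphere_halfbox_demo',
                      cam_pos=None,               # None -> sample eyes on a sphere
                      width=256, height=256,
                      fovy=30, focal_length=0.1,  # focal_length is illustrative
                      num_views=10, cam_dist=4,
                      norm_depth_image_only=False,
                      axis=[.8, .5, 1], angle=np.deg2rad(5),
                      cam_lookat=[0., .4, 0., 1.],
                      b_display=False)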
Example No. 14
def render_random_camera(filename,
                         out_dir,
                         num_samples,
                         radius,
                         cam_dist,
                         num_views,
                         width,
                         height,
                         fovy,
                         focal_length,
                         norm_depth_image_only,
                         theta_range=None,
                         phi_range=None,
                         axis=None,
                         angle=None,
                         cam_pos=None,
                         cam_lookat=None,
                         use_mesh=False,
                         double_sided=False,
                         use_quartic=False,
                         b_shadow=True,
                         b_display=False,
                         tile_size=None):
    """
    Randomly generate N samples on a surface and render them. The samples include position and normal, the radius is set
    to a constant.
    """
    sampling_time = []
    rendering_time = []

    obj = load_model(filename)
    # normalize the vertices
    v = obj['v']
    axis_range = np.max(v, axis=0) - np.min(v, axis=0)
    v = (v - np.mean(v, axis=0)) / max(axis_range)  # Normalize to make the largest spread 1
    obj['v'] = v

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    r = np.ones(num_samples) * radius

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length
    if use_mesh:
        mesh = obj_to_triangle_spec(obj)
        faces = mesh['face']
        normals = mesh['normal']
        num_tri = faces.shape[0]
        if 'disk' in scene['objects']:
            del scene['objects']['disk']
        scene['objects'].update(
            {'triangle': {
                'face': None,
                'normal': None,
                'material_idx': None
            }})
        scene['objects']['triangle']['face'] = tch_var_f(faces.tolist())
        scene['objects']['triangle']['normal'] = tch_var_f(normals.tolist())
        scene['objects']['triangle']['material_idx'] = tch_var_l(
            np.zeros(num_tri, dtype=int).tolist())
    else:
        scene['objects']['disk']['radius'] = tch_var_f(r)
        scene['objects']['disk']['material_idx'] = tch_var_l(
            np.zeros(num_samples, dtype=int).tolist())
    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    # generate camera positions on a sphere
    if cam_pos is None:
        cam_pos = uniform_sample_sphere(radius=cam_dist,
                                        num_samples=num_views,
                                        axis=axis,
                                        angle=angle,
                                        theta_range=theta_range,
                                        phi_range=phi_range)
    lookat = cam_lookat if cam_lookat is not None else np.mean(v, axis=0)
    scene['camera']['at'] = tch_var_f(lookat)

    if b_display:
        h1 = plt.figure()
        h2 = plt.figure()
    for idx in range(cam_pos.shape[0]):
        if not use_mesh:
            start_time = time()
            v, vn = uniform_sample_mesh(obj, num_samples=num_samples)
            sampling_time.append(time() - start_time)

            scene['objects']['disk']['pos'] = tch_var_f(v)
            scene['objects']['disk']['normal'] = tch_var_f(vn)

        scene['camera']['eye'] = tch_var_f(cam_pos[idx])
        suffix = '_{}'.format(idx)

        # main render run
        start_time = time()
        res = render(scene,
                     tile_size=tile_size,
                     tiled=tile_size is not None,
                     shadow=b_shadow,
                     norm_depth_image_only=norm_depth_image_only,
                     double_sided=double_sided,
                     use_quartic=use_quartic)
        rendering_time.append(time() - start_time)

        im = np.uint8(255. * get_data(res['image']))
        depth = get_data(res['depth'])

        depth[depth >= scene['camera']['far']] = depth.min()
        im_depth = np.uint8(255. * (depth - depth.min()) /
                            (depth.max() - depth.min()))

        if b_display:
            plt.figure(h1.number)
            plt.imshow(im)
            plt.title('Image')
            plt.savefig(out_dir + '/fig_img' + suffix + '.png')

            plt.figure(h2.number)
            plt.imshow(im_depth)
            plt.title('Depth Image')
            plt.savefig(out_dir + '/fig_depth' + suffix + '.png')

        imsave(out_dir + '/img' + suffix + '.png', im)
        imsave(out_dir + '/depth' + suffix + '.png', im_depth)

    # Timing statistics
    if not use_mesh:
        print('Sampling time mean: {}s, std: {}s'.format(
            np.mean(sampling_time), np.std(sampling_time)))
    print('Rendering time mean: {}s, std: {}s'.format(np.mean(rendering_time),
                                                      np.std(rendering_time)))
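The render functions above fall back to uniform_sample_sphere when cam_pos is None. The library's implementation is not shown here, but in the unconstrained case (no axis/angle or theta/phi limits) the standard construction is to normalize isotropic Gaussian samples; a sketch of that idea, not the library's code:

import numpy as np

def sample_sphere(radius, num_samples, seed=None):
    """Uniformly sample points on a sphere of the given radius (illustrative helper)."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal((num_samples, 3))      # isotropic Gaussian directions
    v /= np.linalg.norm(v, axis=1, keepdims=True)  # project onto the unit sphere
    return radius * v

cam_pos = sample_sphere(radius=4.0, num_samples=10)
print(cam_pos.shape)  # (10, 3)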
Example No. 15
def render_sphere(out_dir,
                  cam_pos,
                  radius,
                  width,
                  height,
                  fovy,
                  focal_length,
                  num_views,
                  std_z=0.01,
                  std_normal=0.01,
                  b_display=False):
    """
    Randomly generate N samples on a surface and render them. The samples include position and normal, the radius is set
    to a constant.
    """
    print('render sphere')
    sampling_time = []
    rendering_time = []

    num_samples = width * height
    r = np.ones(num_samples) * radius

    scene = copy.deepcopy(SCENE_BASIC)

    scene['camera']['viewport'] = [0, 0, width, height]
    scene['camera']['fovy'] = np.deg2rad(fovy)
    scene['camera']['focal_length'] = focal_length
    scene['objects']['disk']['radius'] = tch_var_f(r)
    scene['objects']['disk']['material_idx'] = tch_var_l(
        np.zeros(num_samples, dtype=int).tolist())
    scene['materials']['albedo'] = tch_var_f([[0.6, 0.6, 0.6]])
    scene['tonemap']['gamma'] = tch_var_f([1.0])  # Linear output

    x, y = np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))
    #z = np.sqrt(1 - np.min(np.stack((x ** 2 + y ** 2, np.ones_like(x)), axis=-1), axis=-1))
    unit_disk_mask = (x**2 + y**2) <= 1
    z = np.sqrt(1 - unit_disk_mask * (x**2 + y**2))

    # Make a hemisphere bulging out of the xy-plane
    z[~unit_disk_mask] = 0
    pos = np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1)

    # Normals outside the sphere should be [0, 0, 1]
    x[~unit_disk_mask] = 0
    y[~unit_disk_mask] = 0
    z[~unit_disk_mask] = 1

    normals = np_normalize(np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1))

    if b_display:
        plt.ion()
        plt.figure()
        plt.imshow(pos[..., 2].reshape((height, width)))

        plt.figure()
        plt.imshow(normals[..., 2].reshape((height, width)))

    scene['objects']['disk']['pos'] = tch_var_f(pos)
    scene['objects']['disk']['normal'] = tch_var_f(normals)

    scene['camera']['eye'] = tch_var_f(cam_pos)

    # main render run
    start_time = time()
    res = render(scene)
    rendering_time.append(time() - start_time)

    im = get_data(res['image'])
    depth = get_data(res['depth'])

    depth[depth >= scene['camera']['far']] = depth.min()
    im_depth = np.uint8(255. * (depth - depth.min()) /
                        (depth.max() - depth.min()))

    if b_display:
        plt.figure()
        plt.imshow(im)
        plt.title('Image')
        plt.savefig(out_dir + '/fig_img_orig.png')

        plt.figure()
        plt.imshow(im_depth)
        plt.title('Depth Image')
        plt.savefig(out_dir + '/fig_depth_orig.png')

    imsave(out_dir + '/img_orig.png', im)
    imsave(out_dir + '/depth_orig.png', im_depth)

    # generate noisy data
    if b_display:
        h1 = plt.figure()
        h2 = plt.figure()
    noisy_pos = pos.copy()  # copy so each view perturbs the clean positions rather than accumulating noise
    for view_idx in range(num_views):
        noisy_pos[..., 2] = pos[..., 2] + std_z * np.random.randn(num_samples)
        noisy_normals = np_normalize(normals + std_normal *
                                     np.random.randn(num_samples, 3))

        scene['objects']['disk']['pos'] = tch_var_f(noisy_pos)
        scene['objects']['disk']['normal'] = tch_var_f(noisy_normals)

        scene['camera']['eye'] = tch_var_f(cam_pos)

        # main render run
        start_time = time()
        res = render(scene)
        rendering_time.append(time() - start_time)

        im = get_data(res['image'])
        depth = get_data(res['depth'])

        depth[depth >= scene['camera']['far']] = depth.min()
        im_depth = np.uint8(255. * (depth - depth.min()) /
                            (depth.max() - depth.min()))

        suffix_str = '{:05d}'.format(view_idx)

        if b_display:
            plt.figure(h1.number)
            plt.imshow(im)
            plt.title('Image')
            plt.savefig(out_dir + '/fig_img_' + suffix_str + '.png')

            plt.figure(h2.number)
            plt.imshow(im_depth)
            plt.title('Depth Image')
            plt.savefig(out_dir + '/fig_depth_' + suffix_str + '.png')

        imsave(out_dir + '/img_' + suffix_str + '.png', im)
        imsave(out_dir + '/depth_' + suffix_str + '.png', im_depth)

    # hold matplotlib figure
    plt.ioff()
    plt.show()
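An illustrative call to render_sphere; every value below is an assumption, and out_dir must already exist since the function does not create it:

render_sphere(out_dir='./sphere_demo',
              cam_pos=[0., 0., 6.],   # single eye position; the same view is re-rendered with noise
              radius=0.025,
              width=64, height=64,
              fovy=18., focal_length=0.1,
              num_views=4,
              std_z=0.01, std_normal=0.01,
              b_display=False)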
Example No. 16
    def get_real_samples(self):
        """Get a real sample."""
        # Define the camera poses
        if not self.opt.same_view:
            if self.opt.full_sphere_sampling:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=np.deg2rad(self.opt.angle),
                    theta_range=self.opt.theta,
                    phi_range=self.opt.phi)
            else:
                self.cam_pos = uniform_sample_sphere(
                    radius=self.opt.cam_dist,
                    num_samples=self.opt.batchSize,
                    axis=self.opt.axis,
                    angle=self.opt.angle,
                    theta_range=np.deg2rad(self.opt.theta),
                    phi_range=np.deg2rad(self.opt.phi))
        if self.opt.full_sphere_sampling_light:
            self.light_pos1 = uniform_sample_sphere(
                radius=self.opt.cam_dist,
                num_samples=self.opt.batchSize,
                axis=self.opt.axis,
                angle=np.deg2rad(44),
                theta_range=self.opt.theta,
                phi_range=self.opt.phi)
            # self.light_pos2 = uniform_sample_sphere(radius=self.opt.cam_dist, num_samples=self.opt.batchSize,
            #                                      axis=self.opt.axis, angle=np.deg2rad(40),
            #                                      theta_range=self.opt.theta, phi_range=self.opt.phi)
        else:
            print("inbox")
            light_eps = 0.15
            self.light_pos1 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps
            self.light_pos2 = np.random.rand(self.opt.batchSize,
                                             3) * self.opt.cam_dist + light_eps

            # TODO: deg2rad in all the angles????

        # Create a splats rendering scene
        large_scene = create_scene(self.opt.width, self.opt.height,
                                   self.opt.fovy, self.opt.focal_length,
                                   self.opt.n_splats)
        lookat = self.opt.at if self.opt.at is not None else [0.0, 0.0, 0.0, 1.0]
        large_scene['camera']['at'] = tch_var_f(lookat)

        # Render scenes
        data, data_depth, data_normal, data_cond = [], [], [], []
        inpath = self.opt.vis_images + '/'
        inpath2 = self.opt.vis_input + '/'
        for idx in range(self.opt.batchSize):
            # Fetch the samples for this batch; both the mesh and the splat
            # branches below read from them
            samples = self.get_samples()

            # Save the splats into the rendering scene
            if self.opt.use_mesh:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'disk' in large_scene['objects']:
                    del large_scene['objects']['disk']
                if 'triangle' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'triangle': {
                            'face': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }

                large_scene['objects']['triangle']['material_idx'] = tch_var_l(
                    np.zeros(samples['mesh']['face'][0].shape[0],
                             dtype=int).tolist())
                large_scene['objects']['triangle']['face'] = Variable(
                    samples['mesh']['face'][0].cuda(), requires_grad=False)
                large_scene['objects']['triangle']['normal'] = Variable(
                    samples['mesh']['normal'][0].cuda(), requires_grad=False)
            else:
                if 'sphere' in large_scene['objects']:
                    del large_scene['objects']['sphere']
                if 'triangle' in large_scene['objects']:
                    del large_scene['objects']['triangle']
                if 'disk' not in large_scene['objects']:
                    large_scene['objects'] = {
                        'disk': {
                            'pos': None,
                            'normal': None,
                            'material_idx': None
                        }
                    }
                large_scene['objects']['disk']['radius'] = tch_var_f(
                    np.ones(self.opt.n_splats) * self.opt.splats_radius)
                large_scene['objects']['disk']['material_idx'] = tch_var_l(
                    np.zeros(self.opt.n_splats, dtype=int).tolist())
                large_scene['objects']['disk']['pos'] = Variable(
                    samples['splats']['pos'][idx].cuda(), requires_grad=False)
                large_scene['objects']['disk']['normal'] = Variable(
                    samples['splats']['normal'][idx].cuda(),
                    requires_grad=False)

            # Set camera position
            if not self.opt.same_view:
                large_scene['camera']['eye'] = tch_var_f(self.cam_pos[idx])
            else:
                large_scene['camera']['eye'] = tch_var_f(self.cam_pos[0])

            large_scene['lights']['pos'][0, :3] = tch_var_f(
                self.light_pos1[idx])
            #large_scene['lights']['pos'][1,:3]=tch_var_f(self.light_pos2[idx])

            # Render scene
            res = render(large_scene,
                         norm_depth_image_only=self.opt.norm_depth_image_only,
                         double_sided=True,
                         use_quartic=self.opt.use_quartic)

            # Get rendered output; the depth channel is common to both modes
            depth = res['depth']
            im_d = depth.unsqueeze(0)
            if self.opt.render_img_nc != 1:
                im = res['image'].permute(2, 0, 1)
                im_ = get_data(res['image'])
                #im_img_ = get_normalmap_image(im_)
                target_normal_ = get_data(res['normal'])
                target_normalmap_img_ = get_normalmap_image(target_normal_)
                im_n = tch_var_f(target_normalmap_img_).view(
                    im.shape[1], im.shape[2], 3).permute(2, 0, 1)

            # Save the camera pose, light position, and rendered buffers for
            # this sample
            prefix = inpath2 + str(self.iterationa_no) + "_" + str(
                self.critic_iter)
            with open(prefix + 'input_{:05d}.txt'.format(idx), "w") as text_file:
                text_file.write('%s\n' % str(large_scene['camera']['eye'].data))
            np.save(prefix + 'input_{:05d}.npy'.format(idx), self.cam_pos[idx])
            np.save(prefix + 'input_light{:05d}.npy'.format(idx),
                    self.light_pos1[idx])
            np.save(prefix + 'input_im{:05d}.npy'.format(idx),
                    get_data(res['image']))
            np.save(prefix + 'input_depth{:05d}.npy'.format(idx),
                    get_data(res['depth']))
            np.save(prefix + 'input_normal{:05d}.npy'.format(idx),
                    get_data(res['normal']))

            if self.iterationa_no % (self.opt.save_image_interval * 5) == 0:
                imsave((inpath + str(self.iterationa_no) +
                        'real_normalmap_{:05d}.png'.format(idx)),
                       target_normalmap_img_)
                imsave((inpath + str(self.iterationa_no) +
                        'real_depth_{:05d}.png'.format(idx)), get_data(depth))
                # imsave(inpath + str(self.iterationa_no) + 'real_depthmap_{:05d}.png'.format(idx), im_d)
                # imsave(inpath + str(self.iterationa_no) + 'world_normalmap_{:05d}.png'.format(idx), target_worldnormalmap_img_)
            data.append(im)
            data_depth.append(im_d)
            data_normal.append(im_n)
            data_cond.append(large_scene['camera']['eye'])
        # Stack real samples
        real_samples = torch.stack(data)
        real_samples_depth = torch.stack(data_depth)
        real_samples_normal = torch.stack(data_normal)
        real_samples_cond = torch.stack(data_cond)
        self.batch_size = real_samples.size(0)
        if not self.opt.no_cuda:
            real_samples = real_samples.cuda()
            real_samples_depth = real_samples_depth.cuda()
            real_samples_normal = real_samples_normal.cuda()
            real_samples_cond = real_samples_cond.cuda()

        # Set input/output variables

        self.input.resize_as_(real_samples.data).copy_(real_samples.data)
        self.input_depth.resize_as_(real_samples_depth.data).copy_(
            real_samples_depth.data)
        self.input_normal.resize_as_(real_samples_normal.data).copy_(
            real_samples_normal.data)
        self.input_cond.resize_as_(real_samples_cond.data).copy_(
            real_samples_cond.data)
        self.label.resize_(self.batch_size).fill_(self.real_label)
        # TODO: Remove Variables
        self.inputv = Variable(self.input)
        self.inputv_depth = Variable(self.input_depth)
        self.inputv_normal = Variable(self.input_normal)
        self.inputv_cond = Variable(self.input_cond)
        self.labelv = Variable(self.label)
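The tail of get_real_samples stacks the per-sample CHW tensors into NCHW batches and moves them to the GPU. A minimal self-contained sketch of that batching pattern (shapes are illustrative, and the GPU move is guarded here by availability rather than an opt flag):

import torch

frames = [torch.rand(3, 64, 64) for _ in range(4)]  # stand-ins for rendered images
batch = torch.stack(frames)                         # shape: (4, 3, 64, 64)
if torch.cuda.is_available():
    batch = batch.cuda()
print(batch.shape)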
Example No. 17
     'near': 0.1,
     'far': 1000.0,
 },
 'lights': {
     'pos':
     tch_var_f([
         [10., 0., 0., 1.0],
         [-10, 0., 0., 1.0],
         [0, 10., 0., 1.0],
         [0, -10., 0., 1.0],
         [0, 0., 10., 1.0],
         [0, 0., -10., 1.0],
         [20, 20, 20, 1.0],
     ]),
     'color_idx':
     tch_var_l([1, 3, 4, 5, 6, 7, 1]),
     # Light attenuation factors have the form (kc, kl, kq) and eq: 1/(kc + kl * d + kq * d^2)
     'attenuation':
     tch_var_f([
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
     ]),
     'ambient':
     tch_var_f([0.01, 0.01, 0.01]),
 },
 'colors':
Example No. 18
def make_torch_tensor(var):
    """Wrap a scalar or list in a long tensor if its (first) element is an int, else a float tensor."""
    var_elem = var[0] if isinstance(var, list) else var
    return tch_var_l(var) if isinstance(var_elem, int) else tch_var_f(var)
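Assuming tch_var_f and tch_var_l construct float and long tensors respectively (as they are used throughout these examples), the dispatch behaves as follows:

make_torch_tensor([1, 2, 3])   # int elements   -> tch_var_l (long tensor)
make_torch_tensor([0.5, 1.5])  # float elements -> tch_var_f (float tensor)
make_torch_tensor(7)           # bare int       -> tch_var_l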