Example 1
    def calc_sparse_depth_loss(self,
                               sparse_depth,
                               p_world_hat,
                               mask_pred,
                               reduction_method,
                               loss={},
                               eval_mode=False):
        ''' Calculates the sparse depth loss.

        Args:
            sparse_depth (dict): dictionary for sparse depth loss calculation
            p_world_hat (tensor): predicted world points
            mask_pred (tensor): mask for predicted values
            reduction_method (string): how to reduce the loss tensor
            loss (dict): loss dictionary
            eval_mode (bool): whether to use eval mode
        '''
        if self.lambda_sparse_depth != 0:
            p_world = sparse_depth['p_world']
            depth_gt = sparse_depth['depth_gt']
            camera_mat = sparse_depth['camera_mat']
            world_mat = sparse_depth['world_mat']
            scale_mat = sparse_depth['scale_mat']

            # Shortcuts
            batch_size, n_points, _ = p_world.shape
            if self.depth_loss_on_world_points:
                loss_sparse_depth = losses.l2_loss(
                    p_world_hat[mask_pred], p_world[mask_pred],
                    reduction_method) * self.lambda_sparse_depth / batch_size
            else:
                d_pred_cam = transform_to_camera_space(p_world_hat, camera_mat,
                                                       world_mat,
                                                       scale_mat)[:, :, -1]
                loss_sparse_depth = losses.l1_loss(
                    d_pred_cam[mask_pred], depth_gt[mask_pred],
                    reduction_method, feat_dim=False) * \
                    self.lambda_sparse_depth / batch_size

            if eval_mode:
                if self.depth_loss_on_world_points:
                    loss_sparse_depth_val = losses.l2_loss(
                        p_world_hat[mask_pred], p_world[mask_pred], 'mean') * \
                        self.lambda_sparse_depth
                else:
                    d_pred_cam = transform_to_camera_space(
                        p_world_hat, camera_mat, world_mat, scale_mat)[:, :,
                                                                       -1]
                    loss_sparse_depth_val = losses.l1_loss(
                        d_pred_cam[mask_pred],
                        depth_gt[mask_pred],
                        'mean',
                        feat_dim=False) * self.lambda_sparse_depth
                loss['loss_sparse_depth_val'] = loss_sparse_depth_val

            loss['loss'] += loss_sparse_depth
            loss['loss_sparse_depth'] = loss_sparse_depth
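
    # Usage sketch as comments (the 'sum' reduction and the initialised loss
    # dict are assumptions; tensor shapes follow the docstring above):
    #   loss = {'loss': torch.tensor(0.0)}
    #   self.calc_sparse_depth_loss(sparse_depth, p_world_hat, mask_pred,
    #                               'sum', loss=loss, eval_mode=True)
    #   # adds 'loss_sparse_depth' (and 'loss_sparse_depth_val' in eval mode)
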
    def calc_depth_loss(self, mask_depth, depth_img, pixels,
                        camera_mat, world_mat, scale_mat, p_world_hat,
                        reduction_method, loss={}, eval_mode=False):
        ''' Calculates the depth loss.

        Args:
            mask_depth (tensor): mask for depth loss
            depth_img (tensor): depth image
            pixels (tensor): sampled pixels in range [-1, 1]
            camera_mat (tensor): camera matrix
            world_mat (tensor): world matrix
            scale_mat (tensor): scale matrix
            p_world_hat (tensor): predicted world points
            reduction_method (string): how to reduce the loss tensor
            loss (dict): loss dictionary
            eval_mode (bool): whether to use eval mode
        '''
        if self.lambda_depth != 0 and mask_depth.sum() > 0:
            batch_size, n_pts, _ = p_world_hat.shape
            loss_depth_val = torch.tensor(10)
            # For depth values, we have to check again if all values are valid
            # as we potentially train with sparse depth maps
            depth_gt, mask_gt_depth = get_tensor_values(
                depth_img, pixels, squeeze_channel_dim=True, with_mask=True)
            mask_depth &= mask_gt_depth
            if self.depth_loss_on_world_points:
                # Applying L2 loss on world points results in the same as
                # applying L1 on the depth values with scaling (see Sup. Mat.)
                p_world = transform_to_world(
                    pixels, depth_gt.unsqueeze(-1), camera_mat, world_mat,
                    scale_mat)
                loss_depth = losses.l2_loss(
                    p_world_hat[mask_depth], p_world[mask_depth],
                    reduction_method) * self.lambda_depth / batch_size
                if eval_mode:
                    loss_depth_val = losses.l2_loss(
                        p_world_hat[mask_depth], p_world[mask_depth],
                        'mean') * self.lambda_depth
            else:
                d_pred = transform_to_camera_space(
                    p_world_hat, camera_mat, world_mat, scale_mat)[:, :, -1]
                loss_depth = losses.l1_loss(
                    d_pred[mask_depth], depth_gt[mask_depth],
                    reduction_method, feat_dim=False) * \
                    self.lambda_depth / batch_size
                if eval_mode:
                    loss_depth_val = losses.l1_loss(
                        d_pred[mask_depth], depth_gt[mask_depth],
                        'mean', feat_dim=False) * self.lambda_depth

            loss['loss'] += loss_depth
            loss['loss_depth'] = loss_depth
            if eval_mode:
                loss['loss_depth_eval'] = loss_depth_val
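
The L1 branch above boils down to a masked absolute difference scaled by lambda_depth / batch_size. A self-contained sketch with toy tensors (it assumes losses.l1_loss with 'sum' reduction is a plain sum of absolute differences; shapes follow the docstring):

    import torch

    # toy inputs: batch of 2 views, 4 sampled pixels each
    d_pred = torch.rand(2, 4)        # predicted depth per sampled pixel
    depth_gt = torch.rand(2, 4)      # ground-truth depth per sampled pixel
    mask_depth = torch.tensor([[1, 1, 0, 1],
                               [0, 1, 1, 1]], dtype=torch.bool)
    lambda_depth, batch_size = 1.0, d_pred.shape[0]

    # masked L1 with 'sum' reduction, scaled as in calc_depth_loss
    loss_depth = (d_pred[mask_depth] - depth_gt[mask_depth]).abs().sum() \
        * lambda_depth / batch_size
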
Example 3
    def render_img(self,
                   camera_mat,
                   world_mat,
                   inputs,
                   scale_mat=None,
                   c=None,
                   stats_dict={},
                   resolution=(128, 128)):
        ''' Renders an image for provided camera information.

        Args:
            camera_mat (tensor): camera matrix
            world_mat (tensor): world matrix
            inputs (tensor): input tensor (not used in this snippet)
            scale_mat (tensor): scale matrix
            c (tensor): latent conditioned code c
            stats_dict (dict): statistics dictionary
            resolution (tuple): output image resolution
        '''
        device = self.device
        h, w = resolution

        t0 = time.time()

        p_loc, pixels = arange_pixels(resolution=(h, w))
        pixels = pixels.to(device)
        stats_dict['time_prepare_points'] = time.time() - t0

        if self.colors in ('rgb', 'depth'):
            # Get predicted world points
            with torch.no_grad():
                t0 = time.time()
                p_world_hat, mask_pred, mask_zero_occupied = \
                    self.model.pixels_to_world(
                        pixels, camera_mat, world_mat, scale_mat, c,
                        sampling_accuracy=self.sampling_accuracy)
                stats_dict['time_eval_depth'] = time.time() - t0

            t0 = time.time()
            p_loc = p_loc[mask_pred]
            with torch.no_grad():
                if self.colors == 'rgb':
                    img_out = (255 * np.ones((h, w, 3))).astype(np.uint8)
                    t0 = time.time()
                    if mask_pred.sum() > 0:
                        rgb_hat = self.model.decode_color(p_world_hat, c=c)
                        rgb_hat = rgb_hat[mask_pred].cpu().numpy()
                        rgb_hat = (rgb_hat * 255).astype(np.uint8)
                        img_out[p_loc[:, 1], p_loc[:, 0]] = rgb_hat
                    img_out = Image.fromarray(img_out).convert('RGB')
                elif self.colors == 'depth':
                    img_out = (255 * np.ones((h, w))).astype(np.uint8)
                    if mask_pred.sum() > 0:
                        p_world_hat = p_world_hat[mask_pred].unsqueeze(0)
                        d_values = transform_to_camera_space(
                            p_world_hat, camera_mat, world_mat,
                            scale_mat).squeeze(0)[:, -1].cpu().numpy()
                        m = d_values[d_values != np.inf].min()
                        M = d_values[d_values != np.inf].max()
                        d_values = 0.5 + 0.45 * (d_values - m) / (M - m)
                        d_image_values = d_values * 255
                        img_out[p_loc[:, 1], p_loc[:, 0]] = \
                            d_image_values.astype(np.uint8)
                    img_out = Image.fromarray(img_out).convert("L")

        stats_dict['time_eval_color'] = time.time() - t0
        return img_out
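
The grayscale mapping in the 'depth' branch is easy to check in isolation. A minimal sketch with made-up depth values that mirrors the arithmetic in render_img (infinite depths are excluded from the min/max, as above):

    import numpy as np

    # toy depth values for the pixels the renderer actually hit
    d_values = np.array([1.2, 2.0, 2.8, 3.5])
    m, M = d_values.min(), d_values.max()
    # same normalization as render_img: finite depths land in [0.5, 0.95]
    d_norm = 0.5 + 0.45 * (d_values - m) / (M - m)
    d_image_values = (d_norm * 255).astype(np.uint8)   # intensities 127..242
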
Example 4
    def test_dataset(self):
        # 1. rerender input point clouds / meshes using the saved camera_mat
        #    compare mask image with saved mask image
        # 2. backproject masked points to space with dense depth map,
        #    fuse all views and save
        batch_size = 1
        device = torch.device('cuda:0')

        data_dir = 'data/synthetic/cube_mesh'
        output_dir = os.path.join('tests', 'outputs', 'test_data')
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # dataset
        dataset = MVRDataset(data_dir=data_dir,
                             load_dense_depth=True,
                             mode="train")
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  num_workers=0,
                                                  shuffle=False)
        meshes = load_objs_as_meshes([os.path.join(data_dir,
                                                   'mesh.obj')]).to(device)
        cams = dataset.get_cameras().to(device)
        image_size = imageio.imread(dataset.image_files[0]).shape[0]

        # initialize the rasterizer; we only check mask PNGs, so there is no
        # need to create lights, shaders, etc.
        raster_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=5,
            # bin_size=None lets PyTorch3D choose between naive and
            # coarse-to-fine rasterization
            bin_size=None,
            max_faces_per_bin=None,  # only used by coarse rasterization
        )
        rasterizer = MeshRasterizer(cameras=None,
                                    raster_settings=raster_settings)

        # render with the loaded camera positions and the training transformation functions
        pixel_world_all = []
        for idx, data in enumerate(data_loader):
            # get data
            img = data.get('img.rgb').to(device)
            assert img.min() >= 0 and img.max() <= 1, \
                "Image values must be floats in [0, 1]."
            mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1)

            camera_mat = data['camera_mat'].to(device)

            cams.R, cams.T = decompose_to_R_and_t(camera_mat)
            cams._N = cams.R.shape[0]
            cams.to(device)
            self.assertTrue(
                torch.equal(cams.get_world_to_view_transform().get_matrix(),
                            camera_mat))

            # transform to view and rerender with non-rotated camera
            verts_padded = transform_to_camera_space(meshes.verts_padded(),
                                                     cams)
            meshes_in_view = meshes.offset_verts(
                -meshes.verts_packed() + padded_to_packed(
                    verts_padded, meshes.mesh_to_verts_packed_first_idx(),
                    meshes.verts_packed().shape[0]))

            fragments = rasterizer(meshes_in_view,
                                   cameras=dataset.get_cameras().to(device))

            # compare mask
            mask = fragments.pix_to_face[..., :1] >= 0
            imageio.imwrite(os.path.join(output_dir, "mask_%06d.png" % idx),
                            mask[0, ...].cpu().to(dtype=torch.uint8) * 255)
            # allow fewer than 5 differing pixels
            self.assertTrue(torch.sum(mask_gt != mask) < 5)

            # check dense maps
            # backproject points to world space; pixel coordinates are in (-1, 1)
            pixels = arange_pixels((image_size, image_size),
                                   batch_size)[1].to(device)

            depth_img = data.get('img.depth').to(device)
            # get the depth and mask at the sampled pixel position
            depth_gt = get_tensor_values(depth_img,
                                         pixels,
                                         squeeze_channel_dim=True)
            mask_gt = get_tensor_values(mask.permute(0, 3, 1, 2).float(),
                                        pixels,
                                        squeeze_channel_dim=True).bool()
            # get pixels and depth inside the masked area
            pixels_packed = pixels[mask_gt]
            depth_gt_packed = depth_gt[mask_gt]
            first_idx = torch.zeros((pixels.shape[0], ),
                                    device=device,
                                    dtype=torch.long)
            num_pts_in_mask = mask_gt.sum(dim=1)
            first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]
            pixels_padded = packed_to_padded(pixels_packed, first_idx,
                                             num_pts_in_mask.max().item())
            depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx,
                                               num_pts_in_mask.max().item())
            # backproject to world coordinates
            # contains nan and infinite values due to depth_gt_padded containing 0.0
            pixel_world_padded = transform_to_world(pixels_padded,
                                                    depth_gt_padded[..., None],
                                                    cams)
            # convert back to a list without the padded values
            split_size = num_pts_in_mask[..., None].repeat(1, 2)
            split_size[:, 1] = 3
            pixel_world_list = padded_to_list(pixel_world_padded, split_size)
            pixel_world_all.extend(pixel_world_list)

            idx += 1
            if idx >= 10:
                break

        pixel_world_all = torch.cat(pixel_world_all, dim=0)
        mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu(),
                               faces=None,
                               process=False)
        mesh.export(os.path.join(output_dir, 'pixel_to_world.ply'))
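
The packed/padded bookkeeping in test_dataset (building first_idx from the per-view point counts) can be exercised on its own. A small sketch assuming PyTorch3D is installed, with toy tensors on CPU:

    import torch
    from pytorch3d.ops import packed_to_padded

    # two views with 2 and 3 masked pixels, packed along the first dimension
    pixels_packed = torch.arange(10, dtype=torch.float32).reshape(5, 2)
    num_pts_in_mask = torch.tensor([2, 3])
    first_idx = torch.zeros(2, dtype=torch.long)
    first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]   # tensor([0, 2])
    pixels_padded = packed_to_padded(pixels_packed, first_idx,
                                     num_pts_in_mask.max().item())
    print(pixels_padded.shape)   # torch.Size([2, 3, 2]); view 0 is zero-padded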