def calc_depth_loss(self, mask_depth, depth_img, pixels,
                        camera_mat, world_mat, scale_mat, p_world_hat,
                        reduction_method, loss={}, eval_mode=False):
        ''' Calculates the depth loss.

        Args:
            mask_depth (tensor): mask for depth loss
            depth_img (tensor): depth image
            pixels (tensor): sampled pixels in range [-1, 1]
            camera_mat (tensor): camera matrix
            world_mat (tensor): world matrix
            scale_mat (tensor): scale matrix
            p_world_hat (tensor): predicted world points
            reduction_method (string): how to reduce the loss tensor
            loss (dict): loss dictionary
            eval_mode (bool): whether to use eval mode
        '''
        if self.lambda_depth != 0 and mask_depth.sum() > 0:
            batch_size, n_pts, _ = p_world_hat.shape
            # placeholder value; overwritten below when eval_mode is True
            loss_depth_val = torch.tensor(10)
            # For depth values, we have to check again if all values are valid
            # as we potentially train with sparse depth maps
            depth_gt, mask_gt_depth = get_tensor_values(
                depth_img, pixels, squeeze_channel_dim=True, with_mask=True)
            mask_depth &= mask_gt_depth
            if self.depth_loss_on_world_points:
                # Applying L2 loss on world points results in the same as
                # applying L1 on the depth values with scaling (see Sup. Mat.)
                p_world = transform_to_world(
                    pixels, depth_gt.unsqueeze(-1), camera_mat, world_mat,
                    scale_mat)
                loss_depth = losses.l2_loss(
                    p_world_hat[mask_depth], p_world[mask_depth],
                    reduction_method) * self.lambda_depth / batch_size
                if eval_mode:
                    loss_depth_val = losses.l2_loss(
                        p_world_hat[mask_depth], p_world[mask_depth],
                        'mean') * self.lambda_depth
            else:
                d_pred = transform_to_camera_space(
                    p_world_hat, camera_mat, world_mat, scale_mat)[:, :, -1]
                loss_depth = losses.l1_loss(
                    d_pred[mask_depth], depth_gt[mask_depth],
                    reduction_method, feat_dim=False) * \
                    self.lambda_depth / batch_size
                if eval_mode:
                    loss_depth_val = losses.l1_loss(
                        d_pred[mask_depth], depth_gt[mask_depth],
                        'mean', feat_dim=False) * self.lambda_depth

            loss['loss'] += loss_depth
            loss['loss_depth'] = loss_depth
            if eval_mode:
                loss['loss_depth_eval'] = loss_depth_val
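
# --- Illustration (not part of the repository) ---
# A minimal sketch of the equivalence mentioned in the comment above: an L2 loss on
# back-projected world points equals an L1 loss on the depth values scaled by the ray
# length, because predicted and GT points lie on the same camera ray. All values below
# are arbitrary and purely illustrative.
import torch

cam_origin = torch.tensor([0.0, 0.0, 0.0])      # camera center (illustrative)
ray_dir = torch.tensor([0.3, -0.2, 1.0])        # un-normalized viewing ray of one pixel
d_gt, d_pred = torch.tensor(2.0), torch.tensor(2.5)

p_world = cam_origin + d_gt * ray_dir           # GT point on the ray
p_world_hat = cam_origin + d_pred * ray_dir     # predicted point on the same ray

l2_world = torch.norm(p_world_hat - p_world)                  # L2 on world points
l1_depth_scaled = (d_pred - d_gt).abs() * ray_dir.norm()      # scaled L1 on depth
assert torch.allclose(l2_world, l1_depth_scaled)
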
Example #2
    def calc_photoconsistency_loss(self,
                                   mask_rgb,
                                   rgb_pred,
                                   img,
                                   pixels,
                                   reduction_method,
                                   loss,
                                   patch_size,
                                   eval_mode=False):
        ''' Calculates the photo-consistency loss.

        Args:
            mask_rgb (tensor): mask for photo-consistency loss
            rgb_pred (tensor): predicted rgb color values
            img (tensor): GT image
            pixels (tensor): sampled pixels in range [-1, 1]
            reduction_method (string): how to reduce the loss tensor
            loss (dict): loss dictionary
            patch_size (int): size of sampled patch
            eval_mode (bool): whether to use eval mode
        '''
        if self.lambda_rgb != 0 and mask_rgb.sum() > 0:
            batch_size, n_pts, _ = rgb_pred.shape
            # placeholder value; overwritten below when eval_mode is True
            loss_rgb_eval = torch.tensor(3)
            # Get GT RGB values
            rgb_gt = get_tensor_values(img, pixels)

            # 3.1) Calculate RGB Loss
            loss_rgb = losses.l1_loss(
                rgb_pred[mask_rgb], rgb_gt[mask_rgb],
                reduction_method) * self.lambda_rgb / batch_size
            loss['loss'] += loss_rgb
            loss['loss_rgb'] = loss_rgb
            if eval_mode:
                loss_rgb_eval = losses.l1_loss(
                    rgb_pred[mask_rgb], rgb_gt[mask_rgb], 'mean') * \
                    self.lambda_rgb

            # 3.2) Image Gradient loss
            if self.lambda_image_gradients != 0:
                assert (patch_size > 1)
                loss_grad = losses.image_gradient_loss(
                    rgb_pred, rgb_gt, mask_rgb, patch_size,
                    reduction_method) * \
                    self.lambda_image_gradients / batch_size
                loss['loss'] += loss_grad
                loss['loss_image_gradient'] = loss_grad
            if eval_mode:
                loss['loss_rgb_eval'] = loss_rgb_eval
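
# --- Illustration (assumption, not the repository's implementation) ---
# get_tensor_values above is assumed to behave like a bilinear lookup of image values at
# normalized pixel coordinates in [-1, 1]; torch.nn.functional.grid_sample provides such
# a lookup directly, as sketched below with random data.
import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 64, 64)                  # (B, C, H, W) GT image
pixels = torch.rand(1, 128, 2) * 2 - 1          # (B, N, 2) sampled pixels in [-1, 1]
rgb_pred = torch.rand(1, 128, 3)                # (B, N, 3) predicted colors
mask_rgb = torch.ones(1, 128, dtype=torch.bool)

grid = pixels.unsqueeze(1)                              # (B, 1, N, 2) grid for grid_sample
rgb_gt = F.grid_sample(img, grid, align_corners=True)   # (B, C, 1, N)
rgb_gt = rgb_gt.squeeze(2).permute(0, 2, 1)             # (B, N, C), same layout as rgb_pred

loss_rgb = (rgb_pred[mask_rgb] - rgb_gt[mask_rgb]).abs().sum()  # 'sum'-reduced L1 term
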
Example #3
# NOTE: this snippet assumes img_list, depth_list, cam and n_sample_points are defined
# earlier in the script (e.g. loaded from the dataset's image/depth folders and camera file).
rgb_out = []
v_out = []

for i, imgp in enumerate(img_list):
    # load files
    img = np.array(Image.open(imgp).convert("RGB")).astype(np.float32) / 255
    h, w, c = img.shape
    depth = np.array(imageio.imread(depth_list[i]))
    depth = depth.reshape(depth.shape[0], depth.shape[1], -1)[..., 0]

    hd, wd = depth.shape 
    assert(h == hd and w == wd)
    
    p = sample_patch_points(1, n_sample_points, patch_size=1, image_resolution=(h, w), continuous=False)
    img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) 
    depth = torch.from_numpy(depth).unsqueeze(0).unsqueeze(1)
    rgb = get_tensor_values(img, p)
    d = get_tensor_values(depth, p)
    mask = (d != np.inf).squeeze(-1)
    d = d[mask].unsqueeze(0)
    rgb = rgb[mask]
    p = p[mask].unsqueeze(0)
 
    # transform to world
    cm = cam.get('camera_mat_%d' % i).astype(np.float32).reshape(1, 4, 4)
    wm = cam.get('world_mat_%d' % i).astype(np.float32).reshape(1, 4, 4)
    sm = cam.get('scale_mat_%d' % i, np.eye(4)).astype(np.float32).reshape(1, 4, 4)
    p_world = transform_to_world(p, d, cm, wm, sm)[0]
    v_out.append(p_world)
    rgb_out.append(rgb)  # collect the corresponding colors as well

v = np.concatenate(v_out, axis=0)
trimesh.Trimesh(vertices=v).export('out/pcd.ply')
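
# --- Illustration (generic pinhole model, NOT transform_to_world's exact convention) ---
# For intuition only: with intrinsics K and a camera-to-world pose (R, t), a pixel (u, v)
# with depth d back-projects to d * K^-1 [u, v, 1]^T in camera space and then to world
# space. The repository's transform_to_world additionally handles its normalized pixel
# range in [-1, 1] and the world/scale matrices.
import numpy as np

K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]])            # illustrative intrinsics
R = np.eye(3)                              # camera-to-world rotation
t = np.array([0.0, 0.0, 0.0])              # camera center in world coordinates

u, v, d = 320.0, 240.0, 2.0                # pixel coordinates and depth
x_cam = d * (np.linalg.inv(K) @ np.array([u, v, 1.0]))
x_world = R @ x_cam + t                    # -> [0., 0., 2.] for the principal point
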
Example #4
    def compute_loss(self, data, eval_mode=False, it=None):
        ''' Compute the loss.

        Args:
            data (dict): data dictionary
            eval_mode (bool): whether to use eval mode
            it (int): training iteration
        '''
        # Initialize loss dictionary and other values
        loss = {}
        n_points = self.n_eval_points if eval_mode else self.n_training_points
        # Process data dictionary
        (img, mask_img, depth_img, world_mat, camera_mat, scale_mat, inputs,
         sparse_depth) = self.process_data_dict(data)

        # Shortcuts
        device = self.device
        patch_size = self.patch_size
        reduction_method = self.reduction_method
        batch_size, _, h, w = img.shape

        # Assertions
        assert (((h, w) == mask_img.shape[2:4]) and (patch_size > 0)
                and (n_points > 0))

        # Sample points on image plane ("pixels")
        if n_points >= h * w:
            p = arange_pixels((h, w), batch_size)[1].to(device)
        else:
            p = sample_patch_points(
                batch_size,
                n_points,
                patch_size=patch_size,
                image_resolution=(h, w),
                continuous=self.sample_continuous,
            ).to(device)

        # Apply losses
        # 1.) Get Object Mask values and define masks for losses
        mask_gt = get_tensor_values(mask_img, p,
                                    squeeze_channel_dim=True).bool()

        # Calculate 3D points which need to be evaluated for the occupancy and
        # freespace loss
        p_freespace = get_freespace_loss_points(p, camera_mat, world_mat,
                                                scale_mat,
                                                self.use_cube_intersection,
                                                self.depth_range)

        depth_input = depth_img if (self.lambda_depth != 0
                                    or self.depth_from_visual_hull) else None
        p_occupancy = get_occupancy_loss_points(p, camera_mat, world_mat,
                                                scale_mat, depth_input,
                                                self.use_cube_intersection,
                                                self.occupancy_random_normal,
                                                self.depth_range)

        # 2.) Initialize loss
        loss['loss'] = 0

        # 3.) Make forward pass through the network and obtain predictions
        # with masks
        (p_world_hat, rgb_pred, logits_occupancy, logits_freespace, mask_pred,
         p_world_hat_sparse, mask_pred_sparse,
         normals) = self.model(p, p_occupancy, p_freespace, inputs, camera_mat,
                               world_mat, scale_mat, it, sparse_depth,
                               self.lambda_normal != 0)

        # 4.) Calculate Loss
        # 4.1) Photo Consistency Loss
        mask_rgb = mask_pred & mask_gt
        self.calc_photoconsistency_loss(mask_rgb, rgb_pred, img, p,
                                        reduction_method, loss, patch_size,
                                        eval_mode)

        # 4.2) Calculate Depth Loss
        mask_depth = mask_pred & mask_gt
        self.calc_depth_loss(mask_depth, depth_img, p, camera_mat, world_mat,
                             scale_mat, p_world_hat, reduction_method, loss,
                             eval_mode)

        # 4.3) Calculate Normal Loss
        self.calc_normal_loss(normals, batch_size, loss, eval_mode)

        # 4.4) Sparse Depth Loss
        self.calc_sparse_depth_loss(sparse_depth, p_world_hat_sparse,
                                    mask_pred_sparse, reduction_method, loss,
                                    eval_mode)

        # 4.5) Freespace loss
        mask_freespace = (mask_gt == 0) if self.always_freespace else \
            ((mask_gt == 0) & (mask_pred))
        self.calc_freespace_loss(logits_freespace, mask_freespace,
                                 reduction_method, loss)

        # 4.6) Occupancy Loss
        mask_occupancy = (mask_pred == 0) & mask_gt
        self.calc_occupancy_loss(logits_occupancy, mask_occupancy,
                                 reduction_method, loss)

        # Save mean mask intersection for tensorboard
        if eval_mode:
            self.calc_mask_intersection(mask_gt, mask_pred, loss)
        return loss if eval_mode else loss['loss']
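
# --- Illustration (not part of the repository) ---
# Small standalone sketch of how the masks in steps 4.1)-4.6) above partition the sampled
# pixels: photo-consistency and depth are applied where both the prediction and the GT
# mask are set, the occupancy loss where the prediction misses the GT foreground, and
# (with always_freespace disabled) the freespace loss where the prediction hits GT background.
import torch

mask_gt = torch.tensor([True, True, False, False])     # GT object mask at 4 sampled pixels
mask_pred = torch.tensor([True, False, True, False])   # predicted surface hits

mask_rgb = mask_pred & mask_gt               # pixel 0: supervise color and depth
mask_occupancy = (mask_pred == 0) & mask_gt  # pixel 1: ray should hit the surface
mask_freespace = (mask_gt == 0) & mask_pred  # pixel 2: ray should pass through freespace
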
Example #5
    def test_dataset(self):
        # 1. rerender input point clouds / meshes using the saved camera_mat
        #    compare mask image with saved mask image
        # 2. backproject masked points to space with dense depth map,
        #    fuse all views and save
        batch_size = 1
        device = torch.device('cuda:0')

        data_dir = 'data/synthetic/cube_mesh'
        output_dir = os.path.join('tests', 'outputs', 'test_data')
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # dataset
        dataset = MVRDataset(data_dir=data_dir,
                             load_dense_depth=True,
                             mode="train")
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  num_workers=0,
                                                  shuffle=False)
        meshes = load_objs_as_meshes([os.path.join(data_dir,
                                                   'mesh.obj')]).to(device)
        cams = dataset.get_cameras().to(device)
        image_size = imageio.imread(dataset.image_files[0]).shape[0]

        # initialize the rasterizer; we only check mask PNGs, so there is no need to
        # create lights, shaders, etc.
        raster_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=5,
            bin_size=None,  # controls whether naive or coarse-to-fine rasterization is used
            max_faces_per_bin=None,  # only used by coarse rasterization
        )
        rasterizer = MeshRasterizer(cameras=None,
                                    raster_settings=raster_settings)

        # render with the loaded camera positions and the training transformation functions
        pixel_world_all = []
        for idx, data in enumerate(data_loader):
            # get data
            img = data.get('img.rgb').to(device)
            assert (img.min() >= 0 and img.max() <= 1), \
                "Image values must be floating point numbers in [0, 1]."
            mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1)

            camera_mat = data['camera_mat'].to(device)

            cams.R, cams.T = decompose_to_R_and_t(camera_mat)
            cams._N = cams.R.shape[0]
            cams.to(device)
            self.assertTrue(
                torch.equal(cams.get_world_to_view_transform().get_matrix(),
                            camera_mat))

            # transform to view and rerender with non-rotated camera
            verts_padded = transform_to_camera_space(meshes.verts_padded(),
                                                     cams)
            meshes_in_view = meshes.offset_verts(
                -meshes.verts_packed() + padded_to_packed(
                    verts_padded, meshes.mesh_to_verts_packed_first_idx(),
                    meshes.verts_packed().shape[0]))

            fragments = rasterizer(meshes_in_view,
                                   cameras=dataset.get_cameras().to(device))

            # compare mask
            mask = fragments.pix_to_face[..., :1] >= 0
            imageio.imwrite(os.path.join(output_dir, "mask_%06d.png" % idx),
                            mask[0, ...].cpu().to(dtype=torch.uint8) * 255)
            # allow fewer than 5 differing pixels
            self.assertTrue(torch.sum(mask_gt != mask) < 5)

            # check dense depth maps
            # generate pixels in the normalized range (-1, 1) and backproject them to world space
            pixels = arange_pixels((image_size, image_size),
                                   batch_size)[1].to(device)

            depth_img = data.get('img.depth').to(device)
            # get the depth and mask at the sampled pixel position
            depth_gt = get_tensor_values(depth_img,
                                         pixels,
                                         squeeze_channel_dim=True)
            mask_gt = get_tensor_values(mask.permute(0, 3, 1, 2).float(),
                                        pixels,
                                        squeeze_channel_dim=True).bool()
            # get pixels and depth inside the masked area
            pixels_packed = pixels[mask_gt]
            depth_gt_packed = depth_gt[mask_gt]
            first_idx = torch.zeros((pixels.shape[0], ),
                                    device=device,
                                    dtype=torch.long)
            num_pts_in_mask = mask_gt.sum(dim=1)
            first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]
            pixels_padded = packed_to_padded(pixels_packed, first_idx,
                                             num_pts_in_mask.max().item())
            depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx,
                                               num_pts_in_mask.max().item())
            # backproject to world coordinates
            # contains NaN and infinite values because depth_gt_padded contains 0.0
            pixel_world_padded = transform_to_world(pixels_padded,
                                                    depth_gt_padded[..., None],
                                                    cams)
            # transform back to list, containing no padded values
            split_size = num_pts_in_mask[..., None].repeat(1, 2)
            split_size[:, 1] = 3
            pixel_world_list = padded_to_list(pixel_world_padded, split_size)
            pixel_world_all.extend(pixel_world_list)

            # only check the first 10 views
            idx += 1
            if idx >= 10:
                break

        pixel_world_all = torch.cat(pixel_world_all, dim=0)
        mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu(),
                               faces=None,
                               process=False)
        mesh.export(os.path.join(output_dir, 'pixel_to_world.ply'))
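
# --- Illustration (plain torch, not the pytorch3d packed_to_padded/padded_to_packed ops) ---
# Sketch of the packed <-> padded conversion used in the test above: per-view point sets of
# different sizes are stored back-to-back ("packed") and copied into a zero-padded
# (B, max_n, D) tensor using the cumulative first indices, mirroring how first_idx and
# num_pts_in_mask are built above.
import torch

counts = torch.tensor([3, 1, 2])                              # points per view
packed = torch.arange(6, dtype=torch.float32).reshape(6, 1)   # (sum(counts), D) packed points

first_idx = torch.zeros(len(counts), dtype=torch.long)
first_idx[1:] = counts.cumsum(0)[:-1]                         # start offset of every view

padded = torch.zeros(len(counts), int(counts.max()), packed.shape[1])
for b, (start, n) in enumerate(zip(first_idx.tolist(), counts.tolist())):
    padded[b, :n] = packed[start:start + n]                   # rest of each row stays zero
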