Code example #1
    def compute_loss(self, data, eval_mode=False, it=None):
        ''' Compute the loss.

        Args:
            data (dict): data dictionary
            eval_mode (bool): whether to use eval mode
            it (int): training iteration
        '''
        # Initialize loss dictionary and other values
        loss = {}
        n_points = self.n_eval_points if eval_mode else self.n_training_points
        # Process data dictionary
        (img, mask_img, depth_img, world_mat, camera_mat, scale_mat, inputs,
         sparse_depth) = self.process_data_dict(data)

        # Shortcuts
        device = self.device
        patch_size = self.patch_size
        reduction_method = self.reduction_method
        batch_size, _, h, w = img.shape

        # Assertions
        assert (((h, w) == mask_img.shape[2:4]) and (patch_size > 0)
                and (n_points > 0))

        # Sample points on image plane ("pixels")
        if n_points >= h * w:
            p = arange_pixels((h, w), batch_size)[1].to(device)
        else:
            p = sample_patch_points(
                batch_size,
                n_points,
                patch_size=patch_size,
                image_resolution=(h, w),
                continuous=self.sample_continuous,
            ).to(device)

        # Apply losses
        # 1.) Get Object Mask values and define masks for losses
        mask_gt = get_tensor_values(mask_img, p,
                                    squeeze_channel_dim=True).bool()

        # Calculate 3D points which need to be evaluated for the occupancy and
        # freespace loss
        p_freespace = get_freespace_loss_points(p, camera_mat, world_mat,
                                                scale_mat,
                                                self.use_cube_intersection,
                                                self.depth_range)

        depth_input = depth_img if (self.lambda_depth != 0
                                    or self.depth_from_visual_hull) else None
        p_occupancy = get_occupancy_loss_points(p, camera_mat, world_mat,
                                                scale_mat, depth_input,
                                                self.use_cube_intersection,
                                                self.occupancy_random_normal,
                                                self.depth_range)

        # 2.) Initialize loss
        loss['loss'] = 0

        # 3.) Make forward pass through the network and obtain predictions
        # with masks
        (p_world_hat, rgb_pred, logits_occupancy, logits_freespace, mask_pred,
         p_world_hat_sparse, mask_pred_sparse,
         normals) = self.model(p, p_occupancy, p_freespace, inputs, camera_mat,
                               world_mat, scale_mat, it, sparse_depth,
                               self.lambda_normal != 0)

        # 4.) Calculate Loss
        # 4.1) Photo Consistency Loss
        mask_rgb = mask_pred & mask_gt
        self.calc_photoconsistency_loss(mask_rgb, rgb_pred, img, p,
                                        reduction_method, loss, patch_size,
                                        eval_mode)

        # 4.2) Calculate Depth Loss
        mask_depth = mask_pred & mask_gt
        self.calc_depth_loss(mask_depth, depth_img, p, camera_mat, world_mat,
                             scale_mat, p_world_hat, reduction_method, loss,
                             eval_mode)

        # 4.3) Calculate Normal Loss
        self.calc_normal_loss(normals, batch_size, loss, eval_mode)

        # 4.4) Sparse Depth Loss
        self.calc_sparse_depth_loss(sparse_depth, p_world_hat_sparse,
                                    mask_pred_sparse, reduction_method, loss,
                                    eval_mode)

        # 4.5) Freespace loss
        mask_freespace = (mask_gt == 0) if self.always_freespace else \
            ((mask_gt == 0) & (mask_pred))
        self.calc_freespace_loss(logits_freespace, mask_freespace,
                                 reduction_method, loss)

        # 4.6) Occupancy Loss
        mask_occupancy = (mask_pred == 0) & mask_gt
        self.calc_occupancy_loss(logits_occupancy, mask_occupancy,
                                 reduction_method, loss)

        # Save the mean mask intersection for TensorBoard
        if eval_mode:
            self.calc_mask_intersection(mask_gt, mask_pred, loss)
        return loss if eval_mode else loss['loss']
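Usage sketch (an assumption, not taken from the source above): compute_loss is typically called once per batch from a training step; with eval_mode=False it returns the scalar loss['loss'], which can be backpropagated directly. The trainer, optimizer, and data-loader names below are hypothetical.

    for it, batch in enumerate(train_loader):
        optimizer.zero_grad()
        # eval_mode=False returns a single scalar suitable for backward()
        total_loss = trainer.compute_loss(batch, eval_mode=False, it=it)
        total_loss.backward()
        optimizer.step()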
Code example #2
    def render_img(self,
                   camera_mat,
                   world_mat,
                   inputs,
                   scale_mat=None,
                   c=None,
                   stats_dict={},
                   resolution=(128, 128)):
        ''' Renders an image for the provided camera information.

        Args:
            camera_mat (tensor): camera matrix
            world_mat (tensor): world matrix
            inputs (tensor): conditioning input passed to the model
            scale_mat (tensor): scale matrix
            c (tensor): latent conditioned code c
            stats_dict (dict): statistics dictionary
            resolution (tuple): output image resolution
        '''
        device = self.device
        h, w = resolution

        t0 = time.time()

        p_loc, pixels = arange_pixels(resolution=(h, w))
        pixels = pixels.to(device)
        stats_dict['time_prepare_points'] = time.time() - t0

        if self.colors in ('rgb', 'depth'):
            # Get predicted world points
            with torch.no_grad():
                t0 = time.time()
                p_world_hat, mask_pred, mask_zero_occupied = \
                    self.model.pixels_to_world(
                        pixels, camera_mat, world_mat, scale_mat, c,
                        sampling_accuracy=self.sampling_accuracy)
                stats_dict['time_eval_depth'] = time.time() - t0

            t0 = time.time()
            p_loc = p_loc[mask_pred]
            with torch.no_grad():
                if self.colors == 'rgb':
                    img_out = (255 * np.ones((h, w, 3))).astype(np.uint8)
                    t0 = time.time()
                    if mask_pred.sum() > 0:
                        rgb_hat = self.model.decode_color(p_world_hat, c=c)
                        rgb_hat = rgb_hat[mask_pred].cpu().numpy()
                        rgb_hat = (rgb_hat * 255).astype(np.uint8)
                        img_out[p_loc[:, 1], p_loc[:, 0]] = rgb_hat
                    img_out = Image.fromarray(img_out).convert('RGB')
                elif self.colors == 'depth':
                    img_out = (255 * np.ones((h, w))).astype(np.uint8)
                    if mask_pred.sum() > 0:
                        p_world_hat = p_world_hat[mask_pred].unsqueeze(0)
                        d_values = transform_to_camera_space(
                            p_world_hat, camera_mat, world_mat,
                            scale_mat).squeeze(0)[:, -1].cpu().numpy()
                        m = d_values[d_values != np.inf].min()
                        M = d_values[d_values != np.inf].max()
                        d_values = 0.5 + 0.45 * (d_values - m) / (M - m)
                        d_image_values = d_values * 255
                        img_out[p_loc[:, 1], p_loc[:, 0]] = \
                            d_image_values.astype(np.uint8)
                    img_out = Image.fromarray(img_out).convert("L")

        stats_dict['time_eval_color'] = time.time() - t0
        return img_out
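Usage sketch (an assumption, not taken from the source above): render_img returns a PIL image and fills stats_dict with per-stage timings; the renderer instance and the camera, world, scale, and input tensors below are hypothetical and would normally come from a data loader.

    stats = {}
    with torch.no_grad():
        img = renderer.render_img(camera_mat, world_mat, inputs,
                                  scale_mat=scale_mat, c=None,
                                  stats_dict=stats, resolution=(256, 256))
    img.save('render.png')  # PIL image ('RGB' or 'L', depending on self.colors)
    print(stats)  # time_prepare_points, time_eval_depth, time_eval_color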
Code example #3
    def test_dataset(self):
        # 1. rerender input point clouds / meshes using the saved camera_mat
        #    compare mask image with saved mask image
        # 2. backproject masked points to space with dense depth map,
        #    fuse all views and save
        batch_size = 1
        device = torch.device('cuda:0')

        data_dir = 'data/synthetic/cube_mesh'
        output_dir = os.path.join('tests', 'outputs', 'test_data')
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # dataset
        dataset = MVRDataset(data_dir=data_dir,
                             load_dense_depth=True,
                             mode="train")
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  num_workers=0,
                                                  shuffle=False)
        meshes = load_objs_as_meshes([os.path.join(data_dir,
                                                   'mesh.obj')]).to(device)
        cams = dataset.get_cameras().to(device)
        image_size = imageio.imread(dataset.image_files[0]).shape[0]

        # Initialize the rasterizer; we only compare the mask PNGs, so there is
        # no need to create lights, shaders, etc.
        raster_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=5,
            # bin_size controls whether naive or coarse-to-fine rasterization is used
            bin_size=None,
            # max_faces_per_bin only applies to coarse rasterization
            max_faces_per_bin=None,
        )
        rasterizer = MeshRasterizer(cameras=None,
                                    raster_settings=raster_settings)

        # Render with the loaded camera positions and the training transformation functions
        pixel_world_all = []
        for idx, data in enumerate(data_loader):
            # Get the data
            img = data.get('img.rgb').to(device)
            assert (img.min() >= 0 and img.max() <= 1
                    ), "Image values must be floats in [0, 1]."
            mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1)

            camera_mat = data['camera_mat'].to(device)

            cams.R, cams.T = decompose_to_R_and_t(camera_mat)
            cams._N = cams.R.shape[0]
            cams.to(device)
            self.assertTrue(
                torch.equal(cams.get_world_to_view_transform().get_matrix(),
                            camera_mat))

            # transform to view and rerender with non-rotated camera
            verts_padded = transform_to_camera_space(meshes.verts_padded(),
                                                     cams)
            meshes_in_view = meshes.offset_verts(
                -meshes.verts_packed() + padded_to_packed(
                    verts_padded, meshes.mesh_to_verts_packed_first_idx(),
                    meshes.verts_packed().shape[0]))

            fragments = rasterizer(meshes_in_view,
                                   cameras=dataset.get_cameras().to(device))

            # compare mask
            mask = fragments.pix_to_face[..., :1] >= 0
            imageio.imwrite(os.path.join(output_dir, "mask_%06d.png" % idx),
                            mask[0, ...].cpu().to(dtype=torch.uint8) * 255)
            # Allow fewer than 5 pixels to differ
            self.assertTrue(torch.sum(mask_gt != mask) < 5)

            # Check the dense depth maps:
            # backproject the masked pixels, which lie in the range (-1, 1), to world space
            pixels = arange_pixels((image_size, image_size),
                                   batch_size)[1].to(device)

            depth_img = data.get('img.depth').to(device)
            # get the depth and mask at the sampled pixel position
            depth_gt = get_tensor_values(depth_img,
                                         pixels,
                                         squeeze_channel_dim=True)
            mask_gt = get_tensor_values(mask.permute(0, 3, 1, 2).float(),
                                        pixels,
                                        squeeze_channel_dim=True).bool()
            # get pixels and depth inside the masked area
            pixels_packed = pixels[mask_gt]
            depth_gt_packed = depth_gt[mask_gt]
            first_idx = torch.zeros((pixels.shape[0], ),
                                    device=device,
                                    dtype=torch.long)
            num_pts_in_mask = mask_gt.sum(dim=1)
            first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]
            pixels_padded = packed_to_padded(pixels_packed, first_idx,
                                             num_pts_in_mask.max().item())
            depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx,
                                               num_pts_in_mask.max().item())
            # Backproject to world coordinates; the result contains NaN and
            # infinite values because depth_gt_padded contains 0.0 entries
            pixel_world_padded = transform_to_world(pixels_padded,
                                                    depth_gt_padded[..., None],
                                                    cams)
            # transform back to list, containing no padded values
            split_size = num_pts_in_mask[..., None].repeat(1, 2)
            split_size[:, 1] = 3
            pixel_world_list = padded_to_list(pixel_world_padded, split_size)
            pixel_world_all.extend(pixel_world_list)

            # Only fuse the first 10 views
            idx += 1
            if idx >= 10:
                break

        pixel_world_all = torch.cat(pixel_world_all, dim=0)
        mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu(),
                               faces=None,
                               process=False)
        mesh.export(os.path.join(output_dir, 'pixel_to_world.ply'))
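Usage sketch (an assumption, not taken from the source above): test_dataset calls self.assertTrue, so it is presumably a method of a unittest.TestCase subclass; running the containing module through the standard test runner would execute it (the module path below is hypothetical).

    import unittest

    if __name__ == '__main__':
        # e.g. python -m unittest tests.test_data  (hypothetical module path)
        unittest.main()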