Example #1
    def train_step(self, data):
        r''' Performs a training step of the model.

        Arguments:
            data (dict): data dictionary with the input image and GT point cloud
        '''

        self.model.train()
        points = data.get('pointcloud').to(self.device)
        normals = data.get('pointcloud.normals').to(self.device)
        img = data.get('inputs').to(self.device)
        camera_args = common.get_camera_args(data,
                                             'pointcloud.loc',
                                             'pointcloud.scale',
                                             device=self.device)

        # Transform GT data into camera coordinate system
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']
        points_transformed = common.transform_points(points, world_mat)
        # Transform GT normals to camera coordinate system
        world_normal_mat = world_mat[:, :, :3]
        normals = common.transform_points(normals, world_normal_mat)

        outputs1, outputs2 = self.model(img, camera_mat)
        loss = self.compute_loss(outputs1, outputs2, points_transformed,
                                 normals, img)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
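Both the training and evaluation steps lean on common.transform_points to move the GT points (and, with the rotation-only slice of Rt, the normals) into the camera frame. Below is a minimal sketch of what such a batched transform could look like; transform_points_sketch is a hypothetical stand-in and the repo's actual common.transform_points may use different conventions (e.g. row- vs column-vector points).

# Hypothetical sketch of a batched rigid transform, as used above.
# points: B x N x 3, mat: B x 3 x 4 ([R|t]) or B x 3 x 3 (rotation only).
import torch

def transform_points_sketch(points, mat):
    R = mat[:, :, :3]                               # B x 3 x 3 rotation
    out = points @ R.transpose(1, 2)                # rotate: (R @ p^T)^T
    if mat.size(-1) == 4:
        out = out + mat[:, :, 3:].transpose(1, 2)   # add translation t
    return out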
Example #2
    def eval_step(self, data):
        r''' Performs an evaluation step.

        Arguments:
            data (dict): data dictionary
        '''
        self.model.eval()
        points = data.get('pointcloud').to(self.device)
        img = data.get('inputs').to(self.device)
        normals = data.get('pointcloud.normals').to(self.device)

        # Transform GT points to camera coordinates
        camera_args = common.get_camera_args(data,
                                             'pointcloud.loc',
                                             'pointcloud.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']
        points_transformed = common.transform_points(points, world_mat)
        # Transform GT normals to camera coordinates
        world_normal_mat = world_mat[:, :, :3]
        normals = common.transform_points(normals, world_normal_mat)

        with torch.no_grad():
            outputs1, outputs2 = self.model(img, camera_mat)

        pred_vertices_1, pred_vertices_2, pred_vertices_3 = outputs1

        loss = self.compute_loss(outputs1, outputs2, points_transformed,
                                 normals, img)
        lc1, lc2, id31, id32 = chamfer_distance(pred_vertices_3,
                                                points_transformed,
                                                give_id=True)
        l_c = (lc1 + lc2).mean()
        l_e = self.edge_length_loss(pred_vertices_3, 3)
        l_n = self.normal_loss(pred_vertices_3, normals, id31, 3)
        l_l, move_loss = self.laplacian_loss(pred_vertices_3,
                                             outputs2[2],
                                             block_id=3)

        eval_dict = {
            'loss': loss.item(),
            'chamfer': l_c.item(),
            'edge': l_e.item(),
            'normal': l_n.item(),
            'laplace': l_l.item(),
            'move': move_loss.item()
        }
        return eval_dict
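The evaluation step recomputes the symmetric Chamfer term from chamfer_distance(..., give_id=True). The repo's implementation is presumably a CUDA kernel; the brute-force sketch below is only illustrative, assumes give_id means returning nearest-neighbour indices, and uses the hypothetical name chamfer_distance_sketch.

# Hypothetical brute-force Chamfer distance with correspondence ids.
import torch

def chamfer_distance_sketch(x, y, give_id=False):
    # x: B x N x 3, y: B x M x 3
    d2 = torch.cdist(x, y) ** 2       # B x N x M squared pairwise distances
    dist_xy, idx_xy = d2.min(dim=2)   # for each x point: nearest y
    dist_yx, idx_yx = d2.min(dim=1)   # for each y point: nearest x
    if give_id:
        return dist_xy, dist_yx, idx_xy, idx_yx
    return dist_xy, dist_yx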
Example #3
def compose_pointcloud(data, device, pointcloud_transfer=None, world_mat=None):
    ''' Composes the GT point cloud in the coordinate frame requested by
    pointcloud_transfer (default None: the normalized world frame,
    'world_normalized'); world_mat ([R|t]) is required for the view frames.
    '''
    gt_pc = data.get('pointcloud').to(device)

    if pointcloud_transfer in ('world_scale_model', 'view',
                               'view_scale_model'):
        batch_size = gt_pc.size(0)
        gt_pc_loc = data.get('pointcloud.loc').to(device).view(
            batch_size, 1, 3)
        gt_pc_scale = data.get('pointcloud.scale').to(device).view(
            batch_size, 1, 1)

        gt_pc = gt_pc * gt_pc_scale + gt_pc_loc

    if pointcloud_transfer in ('view', 'view_scale_model'):
        assert world_mat is not None
        R = world_mat[:, :, :3]
        gt_pc = transform_points(gt_pc, R)

    if pointcloud_transfer == 'view':
        assert world_mat is not None
        t = world_mat[:, :, 3:]
        gt_pc = gt_pc / t[:, 2:, :]

    return gt_pc
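A minimal usage sketch of compose_pointcloud with placeholder tensors follows. The field names mirror the keys read inside the function, the random world_mat only satisfies the expected shapes, and transform_points is assumed to be importable alongside compose_pointcloud.

# Placeholder usage sketch; tensors are random and shapes are illustrative.
import torch

data = {
    'pointcloud': torch.rand(2, 1024, 3),     # normalized GT points
    'pointcloud.loc': torch.zeros(2, 3),      # de-normalization offset
    'pointcloud.scale': torch.ones(2, 1),     # de-normalization scale
}
world_mat = torch.rand(2, 3, 4)               # [R|t], normally from the camera
gt_pc = compose_pointcloud(data, 'cpu', 'view_scale_model', world_mat=world_mat)
print(gt_pc.shape)                            # torch.Size([2, 1024, 3])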
Example #4
def visualise_projection(self,
                         points,
                         world_mat,
                         camera_mat,
                         img,
                         output_file='out.png'):
    r''' Visualizes the transformation and projection to image plane.

        The points of the first batch element are transformed and projected
        onto the corresponding image, and the resulting visualization is saved
        to the provided output_file path.

    Arguments:
        points (tensor): batch of point cloud points
        world_mat (tensor): batch of matrices to rotate pc to camera-based
                coordinates
        camera_mat (tensor): batch of camera matrices to project to 2D image
                plane
        img (tensor): batch of GT images
        output_file (string): where the output should be saved
    '''
    points_transformed = common.transform_points(points, world_mat)
    points_img = common.project_to_camera(points_transformed, camera_mat)
    pimg2 = points_img[0].detach().cpu().numpy()
    image = img[0].cpu().numpy()
    plt.imshow(image.transpose(1, 2, 0))
    # map normalized [-1, 1] coords to pixels: x scales with the image width
    # (shape[2]) and y with the height (shape[1]); image is C x H x W here
    plt.plot((pimg2[:, 0] + 1) * image.shape[2] / 2,
             (pimg2[:, 1] + 1) * image.shape[1] / 2, 'x')
    plt.savefig(output_file)
Example #5
    def forward_local_second_step(self, data, c, local_feat_maps, pts):
        world_mat = data['world_mat']
        camera_mat = data['camera_mat']

        assert self.local
        pts = common.transform_points(pts, world_mat)
        points_img = common.project_to_camera(pts, camera_mat)
        points_img = points_img.unsqueeze(1)

        # get local feats
        local_feats = []
        for f in local_feat_maps:
            #f = f.detach()
            f = F.grid_sample(f, points_img, mode='nearest')
            f = f.squeeze(2)
            local_feats.append(f)

        local_feats = torch.cat(local_feats, dim=1)
        local_feats = local_feats.transpose(1, 2)  # batch * n_pts * f_dim

        local_feats = self.local_fc(local_feats)

        # x: B * c_dim
        # local: feats B * n_pts * c_dim
        return c, local_feats
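F.grid_sample is the workhorse behind the local-feature pooling above (and in the encoders later on): it expects the sampling grid in normalized [-1, 1] coordinates with shape B x H_out x W_out x 2, which is why the projected points are unsqueezed to B x 1 x n_pts x 2. A small self-contained shape check with placeholder tensors:

# Shape sketch for grid_sample-based per-point feature pooling.
import torch
import torch.nn.functional as F

f = torch.rand(4, 64, 56, 56)                     # B x C x H x W feature map
points_img = torch.rand(4, 1024, 2) * 2 - 1       # B x n_pts x 2 in [-1, 1]
grid = points_img.unsqueeze(1)                    # B x 1 x n_pts x 2
sampled = F.grid_sample(f, grid, mode='nearest')  # B x C x 1 x n_pts
per_point = sampled.squeeze(2).transpose(1, 2)    # B x n_pts x C
print(per_point.shape)                            # torch.Size([4, 1024, 64])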
Example #6
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device
        p = data.get('points').to(device)
        occ = data.get('points.occ').to(device)
        inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)
        # Transformation to camera space
        if self.camera_space:
            transform = data.get('inputs.world_mat').to(device)
            R = transform[:, :, :3]
            p = transform_points(p, transform)
        # End of camera-space transformation

        kwargs = {}

        c = self.model.encode_inputs(inputs)
        q_z = self.model.infer_z(p, occ, c, **kwargs)
        z = q_z.rsample()
        # KL-divergence
        kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
        loss = kl.mean()

        # General points
        logits = self.model.decode(p, z, c, **kwargs).logits
        loss_i = F.binary_cross_entropy_with_logits(logits,
                                                    occ,
                                                    reduction='none')
        loss = loss + loss_i.sum(-1).mean()

        return loss
Example #7
    def encode_second_step(self, f3, f2, f1, pts, world_mat, camera_mat):
        pts = common.transform_points(pts, world_mat)
        points_img = common.project_to_camera(pts, camera_mat)
        points_img = points_img.unsqueeze(1)

        f2 = f2.detach()
        f2 = F.relu(f2)
        f2 = F.grid_sample(f2, points_img)
        f2 = f2.squeeze(2)
        f2 = self.f2_conv(f2)

        f1 = f1.detach()
        f1 = F.relu(f1)
        f1 = F.grid_sample(f1, points_img)
        f1 = f1.squeeze(2)
        f1 = self.f1_conv(f1)
        
        f3 = self.fc3(f3)

        if self.batch_norm:
            f3 = self.f3_bn(f3)
            f2 = self.f2_bn(f2)
            f1 = self.f1_bn(f1)

        f2 = f2.transpose(1, 2)
        f1 = f1.transpose(1, 2)
        # f2 : batch * n_pts * fmap_dim
        # f1 : batch * n_pts * fmap_dim
        return f3, f2, f1
Example #8
    def forward_local(self, data, pts):
        assert self.local
        world_mat = data['world_mat']
        camera_mat = data['camera_mat']
        x = data[None]  # the input image is stored under the None key

        pts = transform_points(pts, world_mat)
        points_img = project_to_camera(pts, camera_mat)
        points_img = points_img.unsqueeze(1)

        local_feat_maps = []
        if self.normalize:
            x = normalize_imagenet(x)

        x = self.features.conv1(x)
        x = self.features.bn1(x)
        x = self.features.relu(x)
        x = self.features.maxpool(x)  # 64 * 112 * 112

        x = self.features.layer1(x)
        local_feat_maps.append(x)  # 64 * 56 * 56
        x = self.features.layer2(x)
        local_feat_maps.append(x)  # 128 * 28 * 28
        x = self.features.layer3(x)
        local_feat_maps.append(x)  # 256 * 14 * 14
        x = self.features.layer4(x)
        local_feat_maps.append(x)  # 512 * 7 * 7

        x = self.features.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # get local feats
        local_feats = []
        for f in local_feat_maps:
            #f = f.detach()
            f = F.grid_sample(f, points_img, mode='nearest')
            f = f.squeeze(2)
            local_feats.append(f)

        local_feats = torch.cat(local_feats, dim=1)
        local_feats = local_feats.transpose(1, 2)  # batch * n_pts * f_dim

        local_feats = self.local_fc(local_feats)

        # x: B * c_dim
        # local: feats B * n_pts * c_dim
        return x, local_feats
Example #9
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
        )

        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)

        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

        loss = 0
        if isinstance(out, tuple):
            out, trans_feat = out

            if isinstance(self.model.encoder, PointNetEncoder) or isinstance(
                    self.model.encoder, PointNetResEncoder):
                loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

        # chamfer distance loss
        if self.loss_type == 'cd':
            dist1, dist2 = cd.chamfer_distance(out, gt_pc)
            loss = (dist1.mean(1) + dist2.mean(1)).mean() / 2.
        else:
            out_pts_count = out.size(1)
            loss = loss + (emd.earth_mover_distance(
                out, gt_pc, transpose=False) / out_pts_count).mean()

        # view penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # projection uses world_mat & camera_mat; out_pts is derived from
            # the network output `out` in every branch
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                t = world_mat[:, :, 3:]
                out_pts = out + t
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:]
                out_pts = out * t[:, 2:, :]
                out_pts = out_pts + t
            else:
                raise NotImplementedError

            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                    inplace=True).mean()
            loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

            if self.view_penalty == 'mask_flow_and_depth':
                # depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale  # absolute depth from view
                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)

                # eps
                loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                         corresponding_z,
                                         inplace=True).mean()
                loss = loss + self.loss_depth_test_ratio * loss_depth_test

        return loss
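The view-penalty terms assume that project_to_camera returns normalized [-1, 1] image coordinates that can be fed straight to F.grid_sample. A hedged sketch of such a pinhole projection is below; project_to_camera_sketch is hypothetical, and the repo's convention for camera_mat (3x3 vs 3x4, pixel vs normalized coordinates) may differ.

# Hypothetical pinhole projection to normalized image coordinates.
import torch

def project_to_camera_sketch(points, camera_mat):
    # points: B x N x 3 in camera coordinates, camera_mat: B x 3 x 3 intrinsics
    p = points @ camera_mat.transpose(1, 2)        # apply intrinsics
    xy = p[..., :2] / p[..., 2:].clamp(min=1e-8)   # perspective divide
    return xy                                      # B x N x 2, ideally in [-1, 1]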
Example #10
    def eval_step(self, data):
        ''' Performs an evaluation step.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer)
        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)
        batch_size = gt_pc.size(0)

        with torch.no_grad():
            if self.model.encoder_world_mat is not None:
                out = self.model(encoder_inputs, world_mat=world_mat)
            else:
                out = self.model(encoder_inputs)

            if isinstance(out, tuple):
                out, trans_feat = out

            eval_dict = {}
            if batch_size == 1:
                pointcloud_hat = out.cpu().squeeze(0).numpy()
                pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()

                eval_dict = self.mesh_evaluator.eval_pointcloud(
                    pointcloud_hat, pointcloud_gt)

            # chamfer distance loss
            if self.loss_type == 'cd':
                dist1, dist2 = cd.chamfer_distance(out, gt_pc)
                loss = (dist1.mean(1) + dist2.mean(1)) / 2.
            else:
                loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

            if self.gt_pointcloud_transfer in ('world_scale_model',
                                               'view_scale_model', 'view'):
                pointcloud_scale = data.get('pointcloud.scale').to(
                    device).view(batch_size, 1, 1)
                loss = loss / (pointcloud_scale**2)
                if self.gt_pointcloud_transfer == 'view':
                    if world_mat is None:
                        world_mat = get_world_mat(data, device=device)
                    t_scale = world_mat[:, 2:, 3:]
                    loss = loss * (t_scale**2)

            if self.loss_type == 'cd':
                loss = loss.mean()
                eval_dict['chamfer'] = loss.item()
            else:
                out_pts_count = out.size(1)
                loss = (loss / out_pts_count).mean()
                eval_dict['emd'] = loss.item()

            # view penalty loss
            if self.view_penalty:
                gt_mask_flow = data.get('inputs.mask_flow').to(
                    device)  # B * 1 * H * W
                if world_mat is None:
                    world_mat = get_world_mat(data, device=device)
                camera_mat = get_camera_mat(data, device=device)

                # projection uses world_mat & camera_mat; out_pts is derived
                # from the network output `out` in every branch
                if self.gt_pointcloud_transfer == 'world_scale_model':
                    out_pts = transform_points(out, world_mat)
                elif self.gt_pointcloud_transfer == 'view_scale_model':
                    t = world_mat[:, :, 3:]
                    out_pts = out + t
                elif self.gt_pointcloud_transfer == 'view':
                    t = world_mat[:, :, 3:]
                    out_pts = out * t[:, 2:, :]
                    out_pts = out_pts + t
                else:
                    raise NotImplementedError

                out_pts_img = project_to_camera(out_pts, camera_mat)
                out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

                out_mask_flow = F.grid_sample(gt_mask_flow,
                                              out_pts_img)  # B * 1 * 1 * n_pts
                loss_mask_flow = F.relu(1. - self.mask_flow_eps -
                                        out_mask_flow,
                                        inplace=True).mean()
                loss = self.loss_mask_flow_ratio * loss_mask_flow
                eval_dict['loss_mask_flow'] = loss.item()

                if self.view_penalty == 'mask_flow_and_depth':
                    # depth test loss
                    t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1,
                                                      1)
                    gt_mask = data.get('inputs.mask').byte().to(device)
                    depth_pred = data.get('inputs.depth_pred').to(
                        device) * t_scale

                    background_setting(depth_pred, gt_mask)
                    depth_z = out_pts[:, :, 2:].transpose(1, 2)
                    corresponding_z = F.grid_sample(
                        depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                    corresponding_z = corresponding_z.squeeze(1)

                    # eps = 0.05
                    loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                             corresponding_z,
                                             inplace=True).mean()
                    loss = self.loss_depth_test_ratio * loss_depth_test
                    eval_dict['loss_depth_test'] = loss.item()

        return eval_dict
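The division by pointcloud_scale**2 above relies on squared point-to-point distances growing with the square of an isotropic rescaling, which is what keeps the Chamfer numbers comparable across differently scaled ground truths (assuming the repo's cd.chamfer_distance returns squared distances). A tiny numeric check of that identity:

# Scaling both clouds by s multiplies squared nearest-neighbour distances by s**2.
import torch

a = torch.rand(1, 128, 3)
b = torch.rand(1, 128, 3)
d = torch.cdist(a, b).pow(2).min(dim=2).values.mean()
s = 3.0
d_scaled = torch.cdist(a * s, b * s).pow(2).min(dim=2).values.mean()
print(torch.allclose(d_scaled, d * s ** 2))  # True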
Example #11
def compose_inputs(data,
                   mode='train',
                   device=None,
                   input_type='depth_pred',
                   use_gt_depth_map=False,
                   depth_map_mix=False,
                   with_img=False,
                   depth_pointcloud_transfer=None,
                   local=False):
    ''' Composes the encoder inputs (and a raw_data dict) for the given
    input_type, optionally transferring depth maps / point clouds as requested.
    '''
    assert mode in ('train', 'val', 'test')
    raw_data = {}
    if input_type == 'depth_pred':
        gt_mask = data.get('inputs.mask').to(device).byte()
        raw_data['mask'] = gt_mask

        batch_size = gt_mask.size(0)
        if use_gt_depth_map:
            gt_depth_maps = data.get('inputs.depth').to(device)
            background_setting(gt_depth_maps, gt_mask)
            encoder_inputs = gt_depth_maps
            raw_data['depth'] = gt_depth_maps
        else:
            pr_depth_maps = data.get('inputs.depth_pred').to(device)
            background_setting(pr_depth_maps, gt_mask)
            raw_data['depth_pred'] = pr_depth_maps
            if depth_map_mix and mode == 'train':
                gt_depth_maps = data.get('inputs.depth').to(device)
                background_setting(gt_depth_maps, gt_mask)
                raw_data['depth'] = gt_depth_maps

                alpha = torch.rand(batch_size, 1, 1, 1).to(device)
                pr_depth_maps = pr_depth_maps * alpha + gt_depth_maps * (1.0 -
                                                                         alpha)
            encoder_inputs = pr_depth_maps

        if with_img:
            img = data.get('inputs').to(device)
            raw_data[None] = img
            encoder_inputs = {'img': img, 'depth': encoder_inputs}

        if local:
            camera_args = get_camera_args(data,
                                          'points.loc',
                                          'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']
            encoder_inputs = {
                None: encoder_inputs,
                'world_mat': Rt,
                'camera_mat': K,
            }
            raw_data['world_mat'] = Rt
            raw_data['camera_mat'] = K

        return encoder_inputs, raw_data
    elif input_type == 'depth_pointcloud':
        encoder_inputs = data.get('inputs.depth_pointcloud').to(device)

        if depth_pointcloud_transfer is not None:
            if depth_pointcloud_transfer in ('world', 'world_scale_model'):
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]
                world_mat = get_world_mat(data, transpose=None, device=device)
                raw_data['world_mat'] = world_mat

                R = world_mat[:, :, :3]
                # R's inverse is R^T
                encoder_inputs = transform_points(encoder_inputs,
                                                  R.transpose(1, 2))
                # or encoder_inputs = transform_points_back(encoder_inputs, R)

                if depth_pointcloud_transfer == 'world_scale_model':
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            elif depth_pointcloud_transfer in ('view', 'view_scale_model'):
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]

                if depth_pointcloud_transfer == 'view_scale_model':
                    world_mat = get_world_mat(data,
                                              transpose=None,
                                              device=device)
                    raw_data['world_mat'] = world_mat
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            else:
                raise NotImplementedError

        raw_data['depth_pointcloud'] = encoder_inputs
        if local:
            #assert depth_pointcloud_transfer.startswith('world')
            encoder_inputs = {None: encoder_inputs}

        return encoder_inputs, raw_data
    else:
        raise NotImplementedError
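A minimal usage sketch of compose_inputs for the 'depth_pointcloud' branch, with a placeholder data dict (the tensor is random and the key simply mirrors the access made inside the function):

# Placeholder usage sketch for the depth_pointcloud input type.
import torch

data = {'inputs.depth_pointcloud': torch.rand(2, 2048, 3)}
encoder_inputs, raw_data = compose_inputs(
    data, mode='val', device='cpu',
    input_type='depth_pointcloud',
    depth_pointcloud_transfer=None)
print(encoder_inputs.shape)  # torch.Size([2, 2048, 3])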
Example #12
    def eval_step(self, data):
        ''' Performs an evaluation step.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device
        threshold = self.threshold
        eval_dict = {}

        # Compute elbo
        points = data.get('points').to(device)
        occ = data.get('points.occ').to(device)

        inputs = data.get('inputs', torch.empty(points.size(0), 0)).to(device)
        voxels_occ = data.get('voxels')

        points_iou = data.get('points_iou').to(device)
        occ_iou = data.get('points_iou.occ').to(device)
        # Transformation to camera space
        if self.camera_space:
            transform = data.get('inputs.world_mat').to(device)
            R = transform[:, :, :3]
            points = transform_points(points, transform)
            points_iou = transform_points(points_iou, R)
        # End of camera-space transformation

        kwargs = {}

        with torch.no_grad():
            elbo, rec_error, kl = self.model.compute_elbo(
                points, occ, inputs, **kwargs)

        eval_dict['loss'] = -elbo.mean().item()
        eval_dict['rec_error'] = rec_error.mean().item()
        eval_dict['kl'] = kl.mean().item()

        # Compute iou
        batch_size = points.size(0)

        with torch.no_grad():
            p_out = self.model(points_iou,
                               inputs,
                               sample=self.eval_sample,
                               **kwargs)

        occ_iou_np = (occ_iou >= 0.5).cpu().numpy()
        occ_iou_hat_np = (p_out.probs >= threshold).cpu().numpy()
        iou = compute_iou(occ_iou_np, occ_iou_hat_np).mean()
        eval_dict['iou'] = iou
        '''
        with torch.no_grad():
            p_out_r = self.model(points_iou_r, inputs,
                               sample=self.eval_sample, **kwargs)

        occ_iou_r_hat_np = (p_out_r.probs >= threshold).cpu().numpy()
        iou_r = compute_iou(occ_iou_np, occ_iou_r_hat_np).mean()
        eval_dict['iou_r'] = iou_r

        data['iou_r'] = iou_r
        data['iou'] = iou
        data['occ_iou_r_hat_np'] = occ_iou_r_hat_np
        data['occ_iou_hat_np'] = occ_iou_hat_np
        
        SAVE IOU DATA
        # Create npz and save
        im_path = str(data.get('inputs.image_path')[0]).split('/')
        model=im_path[3]
        im_nr=im_path[5].split('.')[0]
        out_dir = 'IOU'
        np.savez(os.path.join(out_dir, model+'_'+im_nr), iou_r=iou_r,
                 occ_iou_hat_np=occ_iou_hat_np,
                 occ_iou_r_hat_np=occ_iou_r_hat_np,
                 iou=iou,
                 occ_iou=occ_iou.cpu().numpy(),
                 transform=transform.cpu().numpy(),
                 points_iou=points_iou.cpu().numpy(),
                 points_iou_r=points_iou_r.cpu().numpy(),
                 model=model,
                 im_nr=im_nr,
                 image=inputs.cpu().numpy()
                 )
        END'''

        # Estimate voxel iou
        if voxels_occ is not None:
            voxels_occ = voxels_occ.to(device)
            points_voxels = make_3d_grid((-0.5 + 1 / 64, ) * 3,
                                         (0.5 - 1 / 64, ) * 3, (32, ) * 3)
            points_voxels = points_voxels.expand(batch_size,
                                                 *points_voxels.size())
            points_voxels = points_voxels.to(device)
            with torch.no_grad():
                p_out = self.model(points_voxels,
                                   inputs,
                                   sample=self.eval_sample,
                                   **kwargs)

            voxels_occ_np = (voxels_occ >= 0.5).cpu().numpy()
            occ_hat_np = (p_out.probs >= threshold).cpu().numpy()
            iou_voxels = compute_iou(voxels_occ_np, occ_hat_np).mean()

            eval_dict['iou_voxels'] = iou_voxels

        return eval_dict
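Several of the evaluation steps report IoU via compute_iou over thresholded occupancies. Below is a hypothetical numpy sketch of that metric for (batch, n_points) boolean arrays; the repo's version may handle edge cases differently.

# Hypothetical per-sample IoU over boolean occupancy arrays.
import numpy as np

def compute_iou_sketch(occ1, occ2):
    occ1 = occ1.reshape(occ1.shape[0], -1).astype(bool)
    occ2 = occ2.reshape(occ2.shape[0], -1).astype(bool)
    inter = (occ1 & occ2).sum(axis=-1)
    union = (occ1 | occ2).sum(axis=-1)
    return inter / np.maximum(union, 1)  # avoid division by zero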