示例#1
0
 def update_meters(meters, hists, loss_dict):
     """Fold one step's loss/metric values into the running meters.

     Every entry of ``loss_dict`` is forwarded to the meter of the same name
     via its ``update`` method, except the ``'semantic_hist'`` entry: that
     value is added onto ``hists['semantic_hist']`` and the per-class IoU
     computed from the accumulated histogram is stored under
     ``meters['semantic_iou']``.

     Args:
         meters: Mapping from metric name to an object exposing ``update``.
         hists: Mapping holding the accumulated histograms; must already
             contain ``'semantic_hist'`` when ``loss_dict`` does.
         loss_dict: Mapping from metric name to the value of this step.

     Returns:
         The same ``(meters, hists)`` objects, updated in place.
     """
     hist_key = 'semantic_hist'
     for name, value in loss_dict.items():
         if name != hist_key:
             # Ordinary scalar metric: delegate to its meter.
             meters[name].update(value)
             continue
         assert hist_key in hists
         hists[hist_key] += value
         # Recompute IoU from the accumulated (not the per-step) histogram.
         meters['semantic_iou'] = utils.per_class_iu(hists[hist_key])
     return meters, hists
示例#2
0
 def test_pointcloud(self, pred_dir):
     """Evaluate stored predictions on the original (unsplit) subclouds.

     Dataset entries are grouped by ``(area, subcloud id)``. For every group
     the saved prediction arrays and the query pointclouds of its pieces are
     stacked, each query point takes the label of its nearest predicted
     point, and the result is accumulated into a running histogram. The
     per-class IoU and its mean are printed after every group.

     Args:
         pred_dir: Directory holding the saved prediction arrays. Its sorted
             listing is indexed with the dataset index, so the files are
             presumably written one per dataset entry in the same order --
             TODO confirm against the code that saves them.
     """
     print('Running full pointcloud evaluation.')
     # Bucket dataset indices by (area, subcloud id); the id is the file stem
     # with its last '_'-separated token removed.
     subcloud_dict = defaultdict(list)
     for data_idx, rel_path in enumerate(self.data_paths):
         area, file_name = rel_path.split(os.sep)
         stem = os.path.splitext(file_name)[0]
         subcloud_key = '_'.join(stem.split('_')[:-1])
         subcloud_dict[(area, subcloud_key)].append(data_idx)
     # Each subcloud is evaluated on its own.
     sys.setrecursionlimit(100000)  # Increase recursion limit for k-d tree.
     pred_list = sorted(os.listdir(pred_dir))
     num_labels = self.NUM_LABELS
     hist = np.zeros((num_labels, num_labels))
     num_subclouds = len(subcloud_dict)
     for subcloud_idx, member_indices in enumerate(subcloud_dict.values()):
         print(
             f'Evaluating subcloud {subcloud_idx} / {num_subclouds}.')
         # Stack the predictions and the query points of every piece.
         pred = np.zeros((0, 4))
         # NOTE(review): an earlier comment here said the width was changed
         # from 7 to 8 to make room for intensity, but the code uses 7 --
         # confirm against the column count load_ply actually returns.
         pointcloud = np.zeros((0, 7))
         for data_idx in member_indices:
             piece_pred = np.load(os.path.join(pred_dir, pred_list[data_idx]))
             pred = np.vstack((pred, piece_pred))
             piece_cloud = self.load_ply(
                 data_idx, FacilityVoxelizationDatasetBase.INTENSITY)[0]
             pointcloud = np.vstack((pointcloud, piece_cloud))
         # Drop query points that are duplicated across the pieces.
         unique_rows = set(tuple(row) for row in pointcloud.tolist())
         pointcloud = np.array(list(unique_rows))
         # Label every query point with its nearest predicted point.
         pred_tree = spatial.KDTree(pred[:, :3], leafsize=500)
         nearest = pred_tree.query(pointcloud[:, :3])[1]
         ptc_pred = pred[nearest, 3].astype(int)
         ptc_gt = pointcloud[:, -1].astype(int)
         if self.IGNORE_LABELS is not None:
             ptc_pred = self.label2masked[ptc_pred]
             ptc_gt = self.label2masked[ptc_gt]
         hist += fast_hist(ptc_pred, ptc_gt, num_labels)
         # Report the running result after this subcloud.
         ious = []
         print('Per class IoU:')
         class_ious = per_class_iu(hist) * 100
         row_totals = hist.sum(1)
         for class_idx, iou in enumerate(class_ious):
             if row_totals[class_idx]:
                 ious.append(iou)
                 print(f'{iou}')
             else:
                 # Empty histogram row: presumably the class is absent from
                 # the ground truth, so it is shown as N/A and excluded from
                 # the average.
                 print('N/A')
         print(f'Average IoU: {np.nanmean(ious)}')
    def test_pointcloud(self, pred_dir):
        """Evaluate stored voxel predictions on the full-resolution clouds.

        For every dataset entry the saved prediction array is loaded, each
        point of the original pointcloud takes the label of its nearest
        predicted point, and the result is accumulated into a histogram from
        which the per-class IoU is printed at the end.

        Files written to ``<pred_dir>/fulleval`` for each entry:
            ``<room_id>_voxel.ply``: the predicted points, coloured by label.
            ``<room_id>.txt``: one predicted label per original point.
            ``<room_id>.ply``: the original points, coloured by prediction.

        Args:
            pred_dir: Directory containing ``pred_<index>_00.npy`` files, one
                per dataset entry.
        """
        print('Running full pointcloud evaluation.')
        eval_path = os.path.join(pred_dir, 'fulleval')
        os.makedirs(eval_path, exist_ok=True)
        # Each room is evaluated on its own.
        sys.setrecursionlimit(100000)  # Increase recursion limit for k-d tree.
        hist = np.zeros((self.NUM_LABELS, self.NUM_LABELS))
        for i, data_path in enumerate(self.data_paths):
            room_id = self.get_output_id(i)
            pred = np.load(
                os.path.join(pred_dir, 'pred_%04d_%02d.npy' % (i, 0)))

            # Save the voxelized pointcloud predictions. The comprehension
            # variable is deliberately not named ``i`` so it cannot be
            # confused with the dataset index of the enclosing loop.
            voxel_colors = np.array(
                [SCANNET_COLOR_MAP[label] for label in pred[:, -1]])
            save_point_cloud(np.hstack((pred[:, :3], voxel_colors)),
                             f'{eval_path}/{room_id}_voxel.ply',
                             verbose=False)

            fullply_f = self.data_root / data_path
            query_pointcloud = read_plyfile(fullply_f)
            query_xyz = query_pointcloud[:, :3]
            query_label = query_pointcloud[:, -1]
            # Label every original point with its nearest predicted point.
            pred_tree = spatial.KDTree(pred[:, :3], leafsize=500)
            _, result = pred_tree.query(query_xyz)
            ptc_pred = pred[result, 3].astype(int)
            # Save prediction in txt format for submission.
            np.savetxt(f'{eval_path}/{room_id}.txt', ptc_pred, fmt='%i')
            # Save prediction as a colored pointcloud for visualization.
            query_colors = np.array(
                [SCANNET_COLOR_MAP[label] for label in ptc_pred])
            save_point_cloud(np.hstack((query_xyz, query_colors)),
                             f'{eval_path}/{room_id}.ply',
                             verbose=False)
            # Evaluate IoU.
            if self.IGNORE_LABELS is not None:
                # ``np.int`` was only an alias of the builtin ``int`` and has
                # been removed from NumPy; the builtin gives the same dtype.
                ptc_pred = np.array([self.label_map[x] for x in ptc_pred],
                                    dtype=int)
                query_label = np.array(
                    [self.label_map[x] for x in query_label], dtype=int)
            hist += fast_hist(ptc_pred, query_label, self.NUM_LABELS)
        ious = per_class_iu(hist) * 100
        print('mIoU: ' + str(np.nanmean(ious)) + '\n'
              'Class names: ' + ', '.join(CLASS_LABELS) + '\n'
              'IoU: ' + ', '.join(np.round(ious, 2).astype(str)))
示例#4
0
 def test_pointcloud(self, pred_dir):
     """Evaluate stored predictions on the original (unsplit) rooms.

     Dataset entries are grouped by ``(area, room id)``. For every room the
     saved prediction arrays and the query pointclouds of its pieces are
     stacked, each query point takes the label of its nearest predicted
     point, and the result is accumulated into a running histogram. The
     per-class IoU and its mean are printed after every room.

     Args:
         pred_dir: Directory holding the saved prediction arrays. Its sorted
             listing is indexed with the dataset index, so the files are
             presumably written one per dataset entry in the same order --
             TODO confirm against the code that saves them.
     """
     print('Running full pointcloud evaluation.')
     # Bucket dataset indices by (area, room id); the id is the file stem
     # with its last '_'-separated token removed.
     room_dict = defaultdict(list)
     for data_idx, rel_path in enumerate(self.data_paths):
         area, file_name = rel_path.split(os.sep)
         stem = os.path.splitext(file_name)[0]
         room_key = '_'.join(stem.split('_')[:-1])
         room_dict[(area, room_key)].append(data_idx)
     # Each room is evaluated on its own.
     sys.setrecursionlimit(100000)  # Increase recursion limit for k-d tree.
     pred_list = sorted(os.listdir(pred_dir))
     num_labels = self.NUM_LABELS
     hist = np.zeros((num_labels, num_labels))
     num_rooms = len(room_dict)
     for room_idx, member_indices in enumerate(room_dict.values()):
         print(f'Evaluating room {room_idx} / {num_rooms}.')
         # Stack the predictions and the query points of every piece.
         pred = np.zeros((0, 4))
         pointcloud = np.zeros((0, 7))
         for data_idx in member_indices:
             piece_pred = np.load(os.path.join(pred_dir, pred_list[data_idx]))
             pred = np.vstack((pred, piece_pred))
             piece_cloud = self.load_ply(data_idx)[0]
             pointcloud = np.vstack((pointcloud, piece_cloud))
         # Drop query points that are duplicated across the pieces.
         unique_rows = set(tuple(row) for row in pointcloud.tolist())
         pointcloud = np.array(list(unique_rows))
         # Label every query point with its nearest predicted point.
         pred_tree = spatial.KDTree(pred[:, :3], leafsize=500)
         nearest = pred_tree.query(pointcloud[:, :3])[1]
         ptc_pred = pred[nearest, 3].astype(int)
         ptc_gt = pointcloud[:, -1].astype(int)
         if self.IGNORE_LABELS:
             ptc_pred = self.label2masked[ptc_pred]
             ptc_gt = self.label2masked[ptc_gt]
         hist += fast_hist(ptc_pred, ptc_gt, num_labels)
         # Report the running result after this room.
         ious = []
         print('Per class IoU:')
         class_ious = per_class_iu(hist) * 100
         row_totals = hist.sum(1)
         for class_idx, iou in enumerate(class_ious):
             if row_totals[class_idx]:
                 ious.append(iou)
                 print(f'{iou}')
             else:
                 # Empty histogram row: presumably the class is absent from
                 # the ground truth, so it is shown as N/A and excluded from
                 # the average.
                 print('N/A')
         print(f'Average IoU: {np.nanmean(ious)}')
示例#5
0
def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
    """Run ``model`` over ``data_loader`` and report segmentation metrics.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        data_loader: Loader yielding ``(coords, input, target)`` batches, with
            a trailing ``transformation`` element when
            ``config.return_transformation`` is set.
        config: Run configuration object.
        transform_data_fn: Unused; kept for interface compatibility.
        has_gt: When true, loss, score, IoU and average precision are
            accumulated from the targets.

    Returns:
        Tuple ``(average loss, average score, mean of the per-class AP,
        mean per-class IoU * 100)``.

    Raises:
        ValueError: If the directory predictions are saved to is not empty.
        NotImplementedError: If ``config.evaluate_original_pointcloud`` is
            set while ``has_gt`` is true.
    """
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))
    # Bound up front so the reporting code below does not hit an unbound
    # local when no ground truth is evaluated.
    ap_class = np.full(num_labels, np.nan)

    logging.info('===> Start testing')

    global_timer.tic()
    data_iter = iter(data_loader)
    max_iter = len(data_loader)
    max_iter_unique = max_iter

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            # Use the builtin ``next``: loader iterators are only guaranteed
            # to implement ``__next__``, not a ``.next()`` method.
            if config.return_transformation:
                coords, input, target, transformation = next(data_iter)
            else:
                coords, input, target = next(data_iter)
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()

            if config.wrapper_type != 'None':
                color = input[:, :3].int()
            if config.normalize_color:
                input[:, :3] = input[:, :3] / 255. - 0.5
            sinput = SparseTensor(input, coords).to(device)

            # Feed forward
            inputs = (sinput, ) if config.wrapper_type == 'None' else (sinput,
                                                                       coords,
                                                                       color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset, config,
                                 iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    raise NotImplementedError('pointcloud')

                target_np = target.numpy()

                num_sample = target_np.shape[0]

                target = target.to(device)

                cross_ent = criterion(output, target.long())
                losses.update(float(cross_ent), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(),
                                  target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy bias in class, there exists class with no test label at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration,
                           max_iter_unique,
                           data_time,
                           iter_time,
                           has_gt,
                           losses,
                           scores,
                           reordered_ious,
                           hist,
                           reordered_ap_class,
                           class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration,
               max_iter_unique,
               data_time,
               iter_time,
               has_gt,
               losses,
               scores,
               reordered_ious,
               hist,
               reordered_ap_class,
               class_names=class_names)

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(
        per_class_iu(hist)) * 100
示例#6
0
def _normalized_confusion_matrix(all_preds, all_labels, ignore_label=255):
    """Build a row-normalized confusion matrix, skipping ignored targets.

    Assumes the concatenated predictions and targets are 1-D arrays of equal
    length. A boolean mask replaces per-index membership tests against a
    list, which were quadratic in the number of points.
    """
    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_labels)
    keep = targets != ignore_label
    return confusion_matrix(targets[keep], preds[keep], normalize='true')


def test(model,
         data_loader,
         config,
         transform_data_fn=None,
         has_gt=True,
         validation=None,
         epoch=None):
    """Run ``model`` over ``data_loader`` and report segmentation metrics.

    Besides the metrics, a row-normalized confusion matrix (targets equal to
    255 excluded) is written to ``config.log_dir``: as ``cm_epoch_<epoch>.txt``
    on every statistics iteration and as ``cm.txt`` at the end when
    ``config.is_train`` is false. With ``validation`` set, the summed batch
    loss is appended to ``val_loss.txt`` in the same directory.

    Args:
        model: Network to evaluate; it is switched to eval mode. When
            ``config.is_train`` is false its weights are first loaded from
            ``<config.resume>/weights.pth``.
        data_loader: Loader yielding ``(coords, input, target)`` batches, with
            a trailing ``transformation`` element when
            ``config.return_transformation`` is set.
        config: Run configuration object.
        transform_data_fn: Unused; kept for interface compatibility.
        has_gt: When true, loss, score, IoU and average precision are
            accumulated from the targets.
        validation: Truthy when called as a validation pass.
        epoch: Epoch identifier used as the key of the returned batch losses
            and in the per-epoch confusion matrix file name.

    Returns:
        ``(average loss, average score, mean of the per-class AP, mean
        per-class IoU * 100)``, followed by the ``{epoch: summed batch loss}``
        dict when ``validation`` is truthy.

    Raises:
        ValueError: If the checkpoint file is missing, or if the directory
            predictions are saved to is not empty.
        NotImplementedError: If ``config.evaluate_original_pointcloud`` is
            set while ``has_gt`` is true.
    """
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))
    # Bound up front so the reporting code below does not hit an unbound
    # local when no ground truth is evaluated.
    ap_class = np.full(num_labels, np.nan)

    if not config.is_train:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            model.load_state_dict(state['state_dict'])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))
    if validation:
        logging.info('===> Start validating')
    else:
        logging.info('===> Start testing')

    global_timer.tic()
    data_iter = iter(data_loader)
    max_iter = len(data_loader)
    max_iter_unique = max_iter

    all_preds, all_labels, batch_losses, batch_loss = [], [], {}, 0

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            # Use the builtin ``next``: loader iterators are only guaranteed
            # to implement ``__next__``, not a ``.next()`` method.
            if config.return_transformation:
                coords, input, target, transformation = next(data_iter)
            else:
                coords, input, target = next(data_iter)
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()

            if config.wrapper_type != 'None':
                color = input[:, :3].int()
            if config.normalize_color:
                input[:, :3] = input[:, :3] / 255. - 0.5
            sinput = SparseTensor(input, coords).to(device)

            # Feed forward
            inputs = (sinput, ) if config.wrapper_type == 'None' else (sinput,
                                                                       coords,
                                                                       color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            all_preds.append(pred.cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset, config,
                                 iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    raise NotImplementedError('pointcloud')

                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)

                # Disabled focal-loss alternative, kept for reference
                # (alpha, gamma, eps = 1, 2, 1e-6):
                #   input_soft = nn.functional.softmax(output, dim=1) + eps
                #   focal_weight = torch.pow(-input_soft + 1., gamma)
                #   loss = (-alpha * focal_weight * torch.log(input_soft)).mean()
                loss = criterion(output, target.long())

                batch_loss += loss

                losses.update(float(loss), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(),
                                  target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy bias in class, there exists class with no test label at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                cm = _normalized_confusion_matrix(all_preds, all_labels)
                np.savetxt(config.log_dir + '/cm_epoch_{0}.txt'.format(epoch),
                           cm)

                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration,
                           max_iter_unique,
                           data_time,
                           iter_time,
                           has_gt,
                           losses,
                           scores,
                           reordered_ious,
                           hist,
                           reordered_ap_class,
                           class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # Refreshed every iteration so it always holds the running sum.
            batch_losses[epoch] = batch_loss

    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration,
               max_iter_unique,
               data_time,
               iter_time,
               has_gt,
               losses,
               scores,
               reordered_ious,
               hist,
               reordered_ap_class,
               class_names=class_names)

    if not config.is_train:
        cm = _normalized_confusion_matrix(all_preds, all_labels)
        np.savetxt(config.log_dir + '/cm.txt', cm)

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)
    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    if validation:
        loss_file_name = "/val_loss.txt"
        # The ``with`` block closes the file; no explicit close is needed.
        with open(config.log_dir + loss_file_name, 'a') as val_loss_file:
            for key, value in batch_losses.items():
                val_loss_file.write('{0}, {1}\n'.format(value, key))
        return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(
            per_class_iu(hist)) * 100, batch_losses

    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(
        per_class_iu(hist)) * 100
def test(model, data_loader, config, transform_data_fn=None, has_gt=True, save_pred=False, split=None, submit_dir=None):
    """Run ``model`` over ``data_loader`` and report segmentation metrics.

    Predictions are mapped back to the original point order of every sample
    through ``inverse_map_list`` before the metrics are computed. They can
    additionally be collected into ``preds_<split>.pth`` under
    ``config.log_dir`` (``save_pred``) or written as per-scan ``.label``
    files (``config.submit``).

    Args:
        model: Network to evaluate; it is switched to eval mode.
        data_loader: Loader yielding ``(coords, input, target,
            unique_map_list, inverse_map_list, filename)`` batches, with
            ``pointcloud`` and ``transformation`` inserted before
            ``filename`` when ``config.return_transformation`` is set.
        config: Run configuration object; ``config.save_pred`` is overwritten
            with ``save_pred``.
        transform_data_fn: Unused; kept for interface compatibility.
        has_gt: When true, loss, score, IoU and average precision are
            accumulated from the targets.
        save_pred: Collect the per-sample predictions and save them.
        split: Name used in the saved prediction file; requires ``save_pred``.
        submit_dir: Directory that replaces ``config.semantic_kitti_path`` in
            the input file names when writing submission files.

    Returns:
        Tuple ``(average loss, average score, mean of the per-class AP,
        mean per-class IoU * 100)``.

    Raises:
        AssertionError: If ``split`` is given without ``save_pred``, if
            ``config.use_aux`` is set and the target does not have two
            columns, or if the unique maps do not match the prediction
            length.
    """
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))
    # Bound up front so the reporting code below does not hit an unbound
    # local when no ground truth is evaluated.
    ap_class = np.full(num_labels, np.nan)

    # some cfgs concerning the usage of instance-level information
    config.save_pred = save_pred
    if split is not None:
        assert save_pred
    if config.save_pred:
        save_dict = {'pred': [], 'coord': []}

    logging.info('===> Start testing')

    global_timer.tic()
    data_iter = iter(data_loader)
    max_iter = len(data_loader)
    max_iter_unique = max_iter

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    # semantic kitti label inverse mapping
    if config.submit:
        remap_lut = Remap().getRemapLUT()

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            # Use the builtin ``next``: loader iterators are only guaranteed
            # to implement ``__next__``, not a ``.next()`` method.
            if config.return_transformation:
                (coords, input, target, unique_map_list, inverse_map_list,
                 pointcloud, transformation, filename) = next(data_iter)
            else:
                (coords, input, target, unique_map_list, inverse_map_list,
                 filename) = next(data_iter)
            data_time = data_timer.toc(False)

            if config.use_aux:
                assert target.shape[1] == 2
                aux = target[:, 1]
                target = target[:, 0]
            else:
                aux = None

            # Preprocess input
            iter_timer.tic()

            if config.normalize_color:
                input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5

            # cat xyz into the rgb feature. The normalized coordinates are
            # computed here, where they are used, so enabling xyz_input
            # without normalize_color no longer hits an undefined name.
            if config.xyz_input:
                coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5
                input = torch.cat([coords_norm, input], dim=1)

            sinput = ME.SparseTensor(input, coords, device=device)

            # Feed forward
            if aux is not None:
                soutput = model(sinput)
            else:
                soutput = model(sinput, iter_=iteration / max_iter, enable_point_branch=config.enable_point_branch)
            output = soutput.F
            if torch.isnan(output).any():
                # Report instead of dropping into an interactive debugger,
                # which would block an unattended run.
                logging.warning('NaN detected in network output at iteration %d', iteration)

            pred = get_prediction(dataset, output, target).int()
            assert sum(int(t.shape[0]) for t in unique_map_list) == len(pred), "number of points in unique_map doesn't match predition, do not enable preprocessing"
            iter_time = iter_timer.toc(False)

            if config.save_pred or config.submit:
                # Split the batched prediction back into per-sample slices:
                # splits_at holds the end offset of every sample, the rolled
                # copy (with its first entry zeroed) the matching starts.
                batch_ids = sinput.C[:, 0]
                splits_at = torch.stack([torch.where(batch_ids == i)[0][-1] for i in torch.unique(batch_ids)]).int()
                splits_at = splits_at + 1
                splits_at_leftshift_one = splits_at.roll(shifts=1)
                splits_at_leftshift_one[0] = 0
                len_sum = 0
                for batch_id, (start, end) in enumerate(zip(splits_at_leftshift_one, splits_at)):
                    pred_this_batch = pred[int(start):int(end)]
                    len_sum += len(pred_this_batch)
                    if config.save_pred:
                        save_dict['pred'].append(pred_this_batch[inverse_map_list[batch_id]])
                    else:  # save submit result
                        submission_path = filename[batch_id].replace(config.semantic_kitti_path, submit_dir).replace('velodyne', 'predictions').replace('.bin', '.label')
                        parent_dir = Path(submission_path).parent.absolute()
                        os.makedirs(parent_dir, exist_ok=True)
                        label_pred = pred_this_batch[inverse_map_list[batch_id]].cpu().numpy()
                        label_pred = remap_lut[label_pred].astype(np.uint32)
                        label_pred.tofile(submission_path)
                        print(submission_path)
                assert len_sum == len(pred)

            # Unpack outputs and targets to the original point order of every
            # sample. An indexing failure is allowed to propagate: catching
            # it would leave the per-sample output unbound or stale.
            whole_pred = []
            whole_target = []
            for batch_ in range(config.batch_size):
                batch_mask_ = (soutput.C[:, 0] == batch_).cpu().numpy()
                if batch_mask_.sum() == 0:  # for empty batch, skip em
                    continue
                whole_pred.append(soutput.F[batch_mask_][inverse_map_list[batch_]])
                whole_target.append(target[batch_mask_][inverse_map_list[batch_]])
            whole_pred = torch.cat(whole_pred, dim=0)
            whole_target = torch.cat(whole_target, dim=0)

            pred = get_prediction(dataset, whole_pred, whole_target).int()
            output = whole_pred
            target = whole_target

            if has_gt:
                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)
                output = output.to(device)

                cross_ent = criterion(output, target.long())
                losses.update(float(cross_ent), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                # within fast hist, mark label should >=0 & < num_label to filter out 255 / -1
                hist += fast_hist(pred.cpu().numpy().flatten(), target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100
                prob = torch.nn.functional.softmax(output, dim=-1)

                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy bias in class, there exists class with no test label at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0 and not config.submit:
                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                # Datasets without a class_names attribute are reported
                # without class names.
                if hasattr(dataset, "class_names"):
                    class_names = dataset.get_classnames()
                else:
                    class_names = None
                print_info(
                        iteration,
                        max_iter_unique,
                        data_time,
                        iter_time,
                        has_gt,
                        losses,
                        scores,
                        reordered_ious,
                        hist,
                        reordered_ap_class,
                        class_names=class_names)

            if iteration % 5 == 0:
                # Clear cache
                torch.cuda.empty_cache()

    if config.save_pred:
        torch.save(save_dict, os.path.join(config.log_dir, 'preds_{}.pth'.format(split)))
        print("===> saved prediction result")

    global_time = global_timer.toc(False)

    save_map(model, config)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    if hasattr(dataset, "class_names"):
        class_names = dataset.get_classnames()
    else:
        class_names = None
    print_info(
            iteration,
            max_iter_unique,
            data_time,
            iter_time,
            has_gt,
            losses,
            scores,
            reordered_ious,
            hist,
            reordered_ap_class,
            class_names=class_names)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100