# Code example #1
# 0
def validate(test_loader, model, criterion, config, num_votes=10):
    """Validate a classification model with multi-round voting augmentation.

    Args:
        test_loader: loader yielding (points, mask, features, target) batches.
        model: network called as model(points, mask, features) -> logits.
        criterion: loss over (logits, target).
        config: namespace providing scale_low, scale_high, noise_std,
            input_features_dim and print_freq.
        num_votes (int, optional): number of voting rounds. Defaults to 10.

    Returns:
        float: top-1 accuracy of the summed (voted) logits after the last round.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    with torch.no_grad():
        end = time.time()
        vote_preds = None
        # Scale-and-jitter augmentation, applied from the second vote onward.
        jitter = d_utils.BatchPointcloudScaleAndJitter(scale_low=config.scale_low,
                                                       scale_high=config.scale_high,
                                                       std=config.noise_std)
        for v in range(num_votes):
            round_preds = []
            round_targets = []
            for idx, (points, mask, features, target) in enumerate(test_loader):
                # Rebuild input features from the augmented points when voting.
                if v > 0:
                    points = jitter(points)
                    if config.input_features_dim == 3:
                        features = points.transpose(1, 2).contiguous()
                    elif config.input_features_dim == 4:
                        ones = torch.ones(size=(points.shape[0], points.shape[1], 1), dtype=torch.float32)
                        features = torch.cat([ones, points], -1).transpose(1, 2).contiguous()
                    else:
                        raise NotImplementedError(
                            f"input_features_dim {config.input_features_dim} in voting not supported")

                # forward
                points = points.cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                features = features.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True).view(-1)

                pred = model(points, mask, features)
                loss = criterion(pred, target)
                acc1 = accuracy(pred, target, topk=(1,))
                batch_size = points.size(0)
                losses.update(loss.item(), batch_size)
                top1.update(acc1[0].item(), batch_size)

                round_preds.append(pred)
                round_targets.append(target)

                # timing and periodic progress logging
                batch_time.update(time.time() - end)
                end = time.time()
                if idx % config.print_freq == 0:
                    logger.info(
                        f'Test: [{idx}/{len(test_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                        f'Acc@1 {top1.val:.3%} ({top1.avg:.3%})')
            logger.info(f' * Acc@1 {top1.avg:.3%}')
            top1.reset()

            round_preds = torch.cat(round_preds, 0)
            round_targets = torch.cat(round_targets, 0)
            # Sum raw logits across rounds; argmax of the sum is the vote.
            if vote_preds is None:
                vote_preds = round_preds
            else:
                vote_preds = vote_preds + round_preds
            vote_acc1 = accuracy(vote_preds, round_targets, topk=(1,))[0].item()
            logger.info(f' * Vote{v} Acc@1 {vote_acc1:.3%}')

    return vote_acc1
# Code example #2
# 0
def validate(epoch,
             test_loader,
             model,
             criterion,
             runing_vote_logits,
             config,
             num_votes=10):
    """One epoch of validation with test-time voting for S3DIS segmentation.

    Args:
        epoch (int or str): current epoch; a string (e.g. 'Last') pins the
            dataset sampling epoch to the vote index alone.
        test_loader: loader whose dataset exposes sub_clouds_points_labels,
            projections, clouds_points_labels and a writable `epoch` field.
        model: network called as model(points, mask, features); the
            'pointnet' variant additionally returns a feature transform.
        criterion: segmentation loss.
        runing_vote_logits (list[np.ndarray]): per-cloud exponentially
            smoothed logits ([num_classes, n_pts] each), updated in place.
        config: configuration namespace.
        num_votes (int, optional): number of voting rounds. Defaults to 10.

    Raises:
        NotImplementedError: unsupported config.input_features_dim when
            rebuilding features for an augmented vote.

    Returns:
        float: mIoU over the projected full clouds after the final vote.
    """

    # Per-cloud accumulators over all votes: summed logits, visit counts
    # (seeded with 1e-6 to avoid division by zero) and their running mean.
    vote_logits_sum = [
        np.zeros((config.num_classes, l.shape[0]), dtype=np.float32)
        for l in test_loader.dataset.sub_clouds_points_labels
    ]
    vote_counts = [
        np.zeros((1, l.shape[0]), dtype=np.float32) + 1e-6
        for l in test_loader.dataset.sub_clouds_points_labels
    ]
    vote_logits = [
        np.zeros((config.num_classes, l.shape[0]), dtype=np.float32)
        for l in test_loader.dataset.sub_clouds_points_labels
    ]
    validation_proj = test_loader.dataset.projections
    validation_labels = test_loader.dataset.clouds_points_labels
    test_smooth = 0.95  # EMA factor for the running vote logits

    # Per-class point counts over the whole validation set (IoU weighting).
    val_proportions = np.zeros(config.num_classes, dtype=np.float32)
    for label_value in range(config.num_classes):
        val_proportions[label_value] = np.sum([
            np.sum(labels == label_value)
            for labels in test_loader.dataset.clouds_points_labels
        ])

    batch_time = AverageMeter()
    losses = AverageMeter()

    model.eval()
    with torch.no_grad():
        end = time.time()
        RT = d_utils.BatchPointcloudRandomRotate(x_range=config.x_angle_range,
                                                 y_range=config.y_angle_range,
                                                 z_range=config.z_angle_range)
        TS = d_utils.BatchPointcloudScaleAndJitter(
            scale_low=config.scale_low,
            scale_high=config.scale_high,
            std=config.noise_std,
            clip=config.noise_clip,
            augment_symmetries=config.augment_symmetries)
        for v in range(num_votes):
            # Vary the dataset sampling seed per vote so different points are
            # drawn each round.
            test_loader.dataset.epoch = (0 + v) if isinstance(
                epoch, str) else (epoch + v) % 20
            predictions = []
            targets = []
            for idx, (points, mask, features, points_labels, cloud_label,
                      input_inds) in enumerate(test_loader):
                # augment for voting (identity on the first vote)
                if v > 0:
                    points = RT(points)
                    points = TS(points)
                    if config.input_features_dim <= 5:
                        pass
                    elif config.input_features_dim == 6:
                        # keep color, rebuild xyz from the augmented points
                        color = features[:, :3, :]
                        features = torch.cat(
                            [color, points.transpose(1, 2).contiguous()], 1)
                    elif config.input_features_dim == 7:
                        color_h = features[:, :4, :]
                        features = torch.cat(
                            [color_h,
                             points.transpose(1, 2).contiguous()], 1)
                    else:
                        raise NotImplementedError(
                            f"input_features_dim {config.input_features_dim} in voting not supported"
                        )
                # forward
                points = points.cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                features = features.cuda(non_blocking=True)
                points_labels = points_labels.cuda(non_blocking=True)
                cloud_label = cloud_label.cuda(non_blocking=True)
                input_inds = input_inds.cuda(non_blocking=True)

                if config.model_name == 'pointnet':
                    pred, _, transform_feature = model(points, mask, features)
                    loss = criterion(pred, points_labels, mask,
                                     transform_feature)
                else:
                    pred = model(points, mask, features)
                    loss = criterion(pred, points_labels, mask)

                losses.update(loss.item(), points.size(0))

                # collect: scatter each sample's valid logits back into the
                # per-cloud accumulators
                bsz = points.shape[0]
                for ib in range(bsz):
                    # FIX: np.bool was removed in NumPy 1.24; use builtin bool.
                    mask_i = mask[ib].cpu().numpy().astype(bool)
                    logits = pred[ib].cpu().numpy()[:, mask_i]
                    inds = input_inds[ib].cpu().numpy()[mask_i]
                    c_i = cloud_label[ib].item()
                    vote_logits_sum[
                        c_i][:, inds] = vote_logits_sum[c_i][:, inds] + logits
                    vote_counts[c_i][:, inds] += 1
                    vote_logits[c_i] = vote_logits_sum[c_i] / vote_counts[c_i]
                    # exponential moving average carried across validations
                    runing_vote_logits[c_i][:, inds] = test_smooth * runing_vote_logits[c_i][:, inds] + \
                                                       (1 - test_smooth) * logits
                    predictions += [logits]
                    targets += [
                        test_loader.dataset.sub_clouds_points_labels[c_i][inds]
                    ]

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if idx % config.print_freq == 0:
                    logger.info(
                        f'Test: [{idx}/{len(test_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {losses.val:.4f} ({losses.avg:.4f})')

            pIoUs, pmIoU = s3dis_part_metrics(config.num_classes, predictions,
                                              targets, val_proportions)

            logger.info(f'E{epoch} V{v} * part_mIoU {pmIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * part_msIoU {pIoUs}')

            runsubIoUs, runsubmIoU = sub_s3dis_metrics(
                config.num_classes, runing_vote_logits,
                test_loader.dataset.sub_clouds_points_labels, val_proportions)
            logger.info(f'E{epoch} V{v} * running sub_mIoU {runsubmIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * running sub_msIoU {runsubIoUs}')

            subIoUs, submIoU = sub_s3dis_metrics(
                config.num_classes, vote_logits,
                test_loader.dataset.sub_clouds_points_labels, val_proportions)
            logger.info(f'E{epoch} V{v} * sub_mIoU {submIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * sub_msIoU {subIoUs}')

            IoUs, mIoU = s3dis_metrics(config.num_classes, vote_logits,
                                       validation_proj, validation_labels)
            logger.info(f'E{epoch} V{v} * mIoU {mIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * msIoU {IoUs}')

    return mIoU
# Code example #3
# 0
def validate(epoch,
             split,
             test_loader,
             model,
             criterion,
             config,
             num_votes=10):
    """Run one voting validation pass over *split* and log PartNet IoUs.

    Logits are averaged across votes with an incremental running mean;
    point/shape labels are taken from the first (unaugmented) round.

    Returns:
        tuple: (mmsIoU, mmpIoU) after the final voting round.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()

    model.eval()
    with torch.no_grad():
        end = time.time()
        vote_logits = None
        vote_points_labels = None
        vote_shape_labels = None
        jitter = d_utils.BatchPointcloudScaleAndJitter(
            scale_low=config.scale_low,
            scale_high=config.scale_high,
            std=config.noise_std)
        for v in range(num_votes):
            round_logits = []
            round_point_labels = []
            round_shape_labels = []
            for idx, (points, mask, features, points_labels,
                      shape_labels) in enumerate(test_loader):
                # Rebuild input features from the augmented points when voting.
                if v > 0:
                    points = jitter(points)
                    if config.input_features_dim == 3:
                        features = points.transpose(1, 2).contiguous()
                    elif config.input_features_dim == 4:
                        ones = torch.ones(size=(points.shape[0],
                                                points.shape[1], 1),
                                          dtype=torch.float32)
                        features = torch.cat([ones, points],
                                             -1).transpose(1, 2).contiguous()
                    else:
                        raise NotImplementedError(
                            f"input_features_dim {config.input_features_dim} in voting not supported"
                        )

                # forward
                points = points.cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                features = features.cuda(non_blocking=True)
                points_labels = points_labels.cuda(non_blocking=True)
                shape_labels = shape_labels.cuda(non_blocking=True)

                pred = model(points, mask, features)
                loss = criterion(pred, points_labels, shape_labels)
                losses.update(loss.item(), points.size(0))

                # collect per-sample predictions for the shape's own head
                for ib in range(points.shape[0]):
                    shape_id = shape_labels[ib]
                    round_logits.append(pred[shape_id][ib].cpu().numpy())
                    round_point_labels.append(points_labels[ib].cpu().numpy())
                    round_shape_labels.append(shape_id.cpu().numpy())

                # timing and periodic progress logging
                batch_time.update(time.time() - end)
                end = time.time()
                if idx % config.print_freq == 0:
                    logger.info(
                        f'Test: [{idx}/{len(test_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {losses.val:.4f} ({losses.avg:.4f})')

            if vote_logits is None:
                vote_logits = round_logits
                vote_points_labels = round_point_labels
                vote_shape_labels = round_shape_labels
            else:
                # incremental running mean of the logits over votes
                for i, fresh in enumerate(round_logits):
                    vote_logits[i] = vote_logits[i] + (
                        fresh - vote_logits[i]) / (v + 1)

            msIoU, mpIoU, mmsIoU, mmpIoU = partnet_metrics(
                config.num_classes, config.num_parts, vote_shape_labels,
                vote_logits, vote_points_labels)
            logger.info(
                f'E{epoch} V{v} {split} * mmsIoU {mmsIoU:.3%} mmpIoU {mmpIoU:.3%}'
            )
            logger.info(f'E{epoch} V{v} {split} * msIoU {msIoU}')
            logger.info(f'E{epoch} V{v} {split} * mpIoU {mpIoU}')

    return mmsIoU, mmpIoU
# Code example #4
# 0
def validate(epoch, test_loader, model, criterion, config, num_votes=10):
    """One epoch of voting validation for S3DIS semantic segmentation.

    Args:
        epoch (int): current epoch (used only for log prefixes).
        test_loader: loader whose dataset exposes sub_clouds_points_labels,
            projections, clouds_points_labels and a writable `epoch` field.
        model: network called as model(points, mask, features).
        criterion: segmentation loss over (pred, points_labels, mask).
        config: configuration namespace.
        num_votes (int, optional): number of voting rounds. Defaults to 10.

    Raises:
        NotImplementedError: unsupported config.input_features_dim when
            rebuilding features for an augmented vote.

    Returns:
        float: mIoU over the projected full clouds after the final vote.
    """
    # Per-cloud accumulators: summed logits, their running mean, and visit
    # counts (seeded with 1e-6 to avoid division by zero).
    vote_logits_sum = [np.zeros((config.num_classes, l.shape[0]), dtype=np.float32) for l in
                       test_loader.dataset.sub_clouds_points_labels]
    vote_logits = [np.zeros((config.num_classes, l.shape[0]), dtype=np.float32) for l in
                   test_loader.dataset.sub_clouds_points_labels]
    vote_counts = [np.zeros((1, l.shape[0]), dtype=np.float32) + 1e-6 for l in
                   test_loader.dataset.sub_clouds_points_labels]
    validation_proj = test_loader.dataset.projections
    validation_labels = test_loader.dataset.clouds_points_labels

    # Per-class point counts over the whole validation set (IoU weighting).
    val_proportions = np.zeros(config.num_classes, dtype=np.float32)
    for label_value in range(config.num_classes):
        val_proportions[label_value] = np.sum(
            [np.sum(labels == label_value) for labels in test_loader.dataset.clouds_points_labels])

    batch_time = AverageMeter()
    losses = AverageMeter()

    model.eval()
    with torch.no_grad():
        end = time.time()
        RT = d_utils.BatchPointcloudRandomRotate(x_range=config.x_angle_range, y_range=config.y_angle_range,
                                                 z_range=config.z_angle_range)
        TS = d_utils.BatchPointcloudScaleAndJitter(scale_low=config.scale_low, scale_high=config.scale_high,
                                                   std=config.noise_std, clip=config.noise_clip,
                                                   augment_symmetries=config.augment_symmetries)

        for v in range(num_votes):
            # Vary the dataset sampling seed per vote.
            test_loader.dataset.epoch = v
            for idx, (points, mask, features, points_labels, cloud_label, input_inds) in enumerate(test_loader):
                # augment for voting (identity on the first vote)
                if v > 0:
                    points = RT(points)
                    points = TS(points)
                    if config.input_features_dim <= 5:
                        pass
                    elif config.input_features_dim == 6:
                        # keep color, rebuild xyz from the augmented points
                        color = features[:, :3, :]
                        features = torch.cat([color, points.transpose(1, 2).contiguous()], 1)
                    elif config.input_features_dim == 7:
                        color_h = features[:, :4, :]
                        features = torch.cat([color_h, points.transpose(1, 2).contiguous()], 1)
                    else:
                        raise NotImplementedError(
                            f"input_features_dim {config.input_features_dim} in voting not supported")

                # forward
                points = points.cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                features = features.cuda(non_blocking=True)
                points_labels = points_labels.cuda(non_blocking=True)
                cloud_label = cloud_label.cuda(non_blocking=True)
                input_inds = input_inds.cuda(non_blocking=True)

                pred = model(points, mask, features)
                loss = criterion(pred, points_labels, mask)
                losses.update(loss.item(), points.size(0))

                # collect: scatter valid logits back into the per-cloud buffers
                bsz = points.shape[0]
                for ib in range(bsz):
                    # FIX: np.bool was removed in NumPy 1.24; use builtin bool.
                    mask_i = mask[ib].cpu().numpy().astype(bool)
                    logits = pred[ib].cpu().numpy()[:, mask_i]
                    inds = input_inds[ib].cpu().numpy()[mask_i]
                    c_i = cloud_label[ib].item()
                    vote_logits_sum[c_i][:, inds] = vote_logits_sum[c_i][:, inds] + logits
                    vote_counts[c_i][:, inds] += 1
                    vote_logits[c_i] = vote_logits_sum[c_i] / vote_counts[c_i]

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if idx % config.print_freq == 0:
                    logger.info(
                        f'Test: [{idx}/{len(test_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {losses.val:.4f} ({losses.avg:.4f})')
            subIoUs, submIoU = sub_s3dis_metrics(config.num_classes, vote_logits,
                                                 test_loader.dataset.sub_clouds_points_labels, val_proportions)
            logger.info(f'E{epoch} V{v} * sub_mIoU {submIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * sub_msIoU {subIoUs}')

            IoUs, mIoU = s3dis_metrics(config.num_classes, vote_logits, validation_proj, validation_labels)
            logger.info(f'E{epoch} V{v} * mIoU {mIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * msIoU {IoUs}')
    return mIoU
# Code example #5
# 0
def validate(epoch, test_loader, model, criterion, config, num_votes=10):
    """One epoch of voting validation for S3DIS, with extra reporting on the
    last epoch (confusion matrix, overall accuracy, PLY export).

    Args:
        epoch (int or str): current epoch; the string 'Last' triggers the
            extended metrics / visualization path.
        test_loader: loader whose dataset exposes sub_clouds_points_labels,
            projections, clouds_points_labels, clouds_points,
            clouds_points_colors and a writable `epoch` field.
        model: network called as model(points, mask, features).
        criterion: segmentation loss over (pred, points_labels, mask).
        config: configuration namespace.
        num_votes (int, optional): number of voting rounds. Defaults to 10.

    Raises:
        NotImplementedError: unsupported config.input_features_dim when
            rebuilding features for an augmented vote.

    Returns:
        float: mIoU over the projected full clouds after the final vote.
    """
    # Per-cloud accumulators; each item is [num_classes, #sub_pc_pts]
    # (vote_counts items are [1, #sub_pc_pts], seeded with 1e-6 to avoid
    # division by zero).
    vote_logits_sum = [np.zeros((config.num_classes, l.shape[0]), dtype=np.float32) for l in
                       test_loader.dataset.sub_clouds_points_labels]
    vote_logits = [np.zeros((config.num_classes, l.shape[0]), dtype=np.float32) for l in
                   test_loader.dataset.sub_clouds_points_labels]
    vote_counts = [np.zeros((1, l.shape[0]), dtype=np.float32) + 1e-6 for l in
                   test_loader.dataset.sub_clouds_points_labels]
    # projections: index of the nearest sub-cloud point for each original point
    validation_proj = test_loader.dataset.projections
    validation_labels = test_loader.dataset.clouds_points_labels
    validation_points = test_loader.dataset.clouds_points
    validation_colors = test_loader.dataset.clouds_points_colors

    # Per-class point counts over the whole validation set (IoU weighting).
    val_proportions = np.zeros(config.num_classes, dtype=np.float32)
    for label_value in range(config.num_classes):
        val_proportions[label_value] = np.sum(
            [np.sum(labels == label_value) for labels in test_loader.dataset.clouds_points_labels])

    batch_time = AverageMeter()
    losses = AverageMeter()

    model.eval()
    with torch.no_grad():
        end = time.time()
        RT = d_utils.BatchPointcloudRandomRotate(x_range=config.x_angle_range, y_range=config.y_angle_range,
                                                 z_range=config.z_angle_range)
        TS = d_utils.BatchPointcloudScaleAndJitter(scale_low=config.scale_low, scale_high=config.scale_high,
                                                   std=config.noise_std, clip=config.noise_clip,
                                                   augment_symmetries=config.augment_symmetries)

        for v in range(num_votes):
            # Vary the dataset sampling seed per vote.
            test_loader.dataset.epoch = v
            for idx, (points, mask, features, points_labels, cloud_label, input_inds) in enumerate(test_loader):

                # augment for voting (identity on the first vote)
                if v > 0:
                    points = RT(points)
                    points = TS(points)
                    if config.input_features_dim <= 5:
                        pass
                    elif config.input_features_dim == 6:
                        # keep color, rebuild xyz from the augmented points
                        color = features[:, :3, :]
                        features = torch.cat([color, points.transpose(1, 2).contiguous()], 1)
                    elif config.input_features_dim == 7:
                        color_h = features[:, :4, :]
                        features = torch.cat([color_h, points.transpose(1, 2).contiguous()], 1)
                    else:
                        raise NotImplementedError(
                            f"input_features_dim {config.input_features_dim} in voting not supported")

                # forward (these are sub-cloud points/features; the original
                # clouds are recovered later via validation_proj)
                points = points.cuda(non_blocking=True)  # B x C x num_points
                mask = mask.cuda(non_blocking=True)  # B x num_points
                features = features.cuda(non_blocking=True)  # B x F x num_points
                points_labels = points_labels.cuda(non_blocking=True)  # B x num_points
                cloud_label = cloud_label.cuda(non_blocking=True)  # B,
                input_inds = input_inds.cuda(non_blocking=True)  # B x num_points

                pred = model(points, mask, features)  # B x num_classes x num_points
                loss = criterion(pred, points_labels, mask)
                losses.update(loss.item(), points.size(0))

                # collect: scatter valid logits back into the per-cloud buffers
                bsz = points.shape[0]
                for ib in range(bsz):
                    # FIX: np.bool was removed in NumPy 1.24; use builtin bool.
                    mask_i = mask[ib].cpu().numpy().astype(bool)  # real points 1, padding 0
                    logits = pred[ib].cpu().numpy()[:, mask_i]  # (num_classes, #real pts)
                    inds = input_inds[ib].cpu().numpy()[mask_i]  # (#real pts,)
                    c_i = cloud_label[ib].item()  # which cloud this sample came from
                    vote_logits_sum[c_i][:, inds] = vote_logits_sum[c_i][:, inds] + logits
                    vote_counts[c_i][:, inds] += 1  # how many votes each point has received
                    vote_logits[c_i] = vote_logits_sum[c_i] / vote_counts[c_i]  # mean logits per point

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if idx % config.print_freq == 0:
                    logger.info(
                        f'Test: [{idx}/{len(test_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {losses.val:.4f} ({losses.avg:.4f})')
            subIoUs, submIoU = sub_s3dis_metrics(config.num_classes, vote_logits,
                                                 test_loader.dataset.sub_clouds_points_labels, val_proportions)
            logger.info(f'E{epoch} V{v} * sub_mIoU {submIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * sub_msIoU {subIoUs}')

            overall_acc = None
            # For the last epoch compute extra metrics, plot the confusion
            # matrix, and export prediction results.
            if epoch == 'Last':
                # confusion-matrix image output directory
                IMAGES_PATH = os.path.join(config.log_dir, "images")
                os.makedirs(IMAGES_PATH, exist_ok=True)

                # compute metrics and plot confusion matrix
                IoUs, mIoU, overall_acc = s3dis_metrics_vis_CM(
                                                config.num_classes, vote_logits, 
                                                validation_proj, validation_labels,
                                                more_metrics=True, image_path=IMAGES_PATH, 
                                                label_to_names=test_loader.dataset.label_to_names,
                                                visualize_CM=True)

                # save gt/pred results to PLY and show them with open3d
                save_predicted_results_PLY(vote_logits, 
                                           validation_proj, validation_points,
                                           validation_colors, validation_labels, 
                                           test_path=config.log_dir, 
                                           cloud_names=test_loader.dataset.cloud_names,
                                           open3d_visualize=True)
            else:
                IoUs, mIoU = s3dis_metrics(config.num_classes, vote_logits, validation_proj, validation_labels)

            logger.info(f'E{epoch} V{v} * mIoU {mIoU:.3%}')
            logger.info(f'E{epoch} V{v}  * msIoU {IoUs}')
            # FIX: compare against None so an OA of exactly 0.0 is still logged
            if overall_acc is not None:
                logger.info(f'E{epoch} V{v}  * OA {overall_acc:.4%}')

    return mIoU
# Code example #6
# 0
def validate(epoch, test_loader, model, criterion, config, num_votes=10):
    """One epoch of voting validation for ShapeNet-part segmentation.

    Voting is performed per batch: each batch is re-augmented num_votes
    times and the logits are averaged with an incremental running mean.

    Args:
        epoch (int): current epoch (used only for log prefixes).
        test_loader: loader yielding (points, mask, points_labels,
            shape_labels) batches.
        model: network called as model(points, mask, features); returns
            per-shape prediction heads indexed by shape label.
        criterion: loss over (pred, points_labels, shape_labels).
        config: configuration namespace.
        num_votes (int, optional): number of voting rounds. Defaults to 10.

    Returns:
        tuple: (acc, msIoU, mIoU) over all voted samples.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()

    model.eval()
    with torch.no_grad():
        all_logits = []
        all_points_labels = []
        all_shape_labels = []
        all_masks = []
        end = time.time()
        TS = d_utils.BatchPointcloudScaleAndJitter(
            scale_low=config.scale_low,
            scale_high=config.scale_high,
            std=config.noise_std,
            clip=config.noise_clip)

        for idx, (points_orig, mask, points_labels,
                  shape_labels) in enumerate(test_loader):
            vote_logits = None
            vote_points_labels = None
            vote_shape_labels = None
            vote_masks = None
            for v in range(num_votes):
                batch_logits = []
                batch_points_labels = []
                batch_shape_labels = []
                batch_masks = []
                # augment for voting (first vote uses the raw points)
                if v > 0:
                    points = TS(points_orig)
                else:
                    points = points_orig
                # forward; features are the (possibly augmented) coordinates
                features = points.transpose(1, 2).contiguous()
                points = points.cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                features = features.cuda(non_blocking=True)
                points_labels = points_labels.cuda(non_blocking=True)
                shape_labels = shape_labels.cuda(non_blocking=True)

                pred = model(points, mask, features)
                loss = criterion(pred, points_labels, shape_labels)
                losses.update(loss.item(), points.size(0))

                # collect per-sample predictions for the shape's own head
                bsz = points.shape[0]
                for ib in range(bsz):
                    sl = shape_labels[ib]
                    logits = pred[sl][ib]
                    pl = points_labels[ib]
                    pmk = mask[ib]
                    batch_logits.append(logits.cpu().numpy())
                    batch_points_labels.append(pl.cpu().numpy())
                    batch_shape_labels.append(sl.cpu().numpy())
                    # FIX: np.bool was removed in NumPy 1.24; use builtin bool.
                    batch_masks.append(pmk.cpu().numpy().astype(bool))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if vote_logits is None:
                    vote_logits = batch_logits
                    vote_points_labels = batch_points_labels
                    vote_shape_labels = batch_shape_labels
                    vote_masks = batch_masks
                else:
                    # incremental running mean of the logits over votes
                    for i in range(len(vote_logits)):
                        vote_logits[i] = vote_logits[i] + (
                            batch_logits[i] - vote_logits[i]) / (v + 1)

            all_logits += vote_logits
            all_points_labels += vote_points_labels
            all_shape_labels += vote_shape_labels
            all_masks += vote_masks
            if idx % config.print_freq == 0:
                logger.info(
                    f'V{num_votes} Test: [{idx}/{len(test_loader)}]\t'
                    f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    f'Loss {losses.val:.4f} ({losses.avg:.4f})')

        acc, shape_ious, msIoU, mIoU = shapenetpart_metrics(
            config.num_classes, config.num_parts, all_shape_labels, all_logits,
            all_points_labels, all_masks)
        logger.info(
            f'E{epoch} V{num_votes} * mIoU {mIoU:.3%} msIoU {msIoU:.3%}')
        logger.info(f'E{epoch} V{num_votes} * Acc {acc:.3%}')
        logger.info(f'E{epoch} V{num_votes} * shape_ious {shape_ious}')

    return acc, msIoU, mIoU