Ejemplo n.º 1
0
def create_input_batch(batch, is_minknet, device="cuda", quantization_size=0.05):
    """Convert a collated batch into the input format the network expects.

    For MinkowskiNet the spatial coordinates (columns 1:) are scaled into
    voxel units in place and wrapped in an ``ME.TensorField``; for
    point-based networks the dense coordinate tensor is permuted to
    (B, C, N) and moved to ``device``.
    """
    if not is_minknet:
        return batch["coordinates"].permute(0, 2, 1).to(device)
    # Scale xyz into voxel units; column 0 is the batch index and is kept.
    batch["coordinates"][:, 1:] /= quantization_size
    return ME.TensorField(
        coordinates=batch["coordinates"],
        features=batch["features"],
        device=device,
    )
 def forward(self, input):
     """Flatten each sample's decomposed features and re-wrap the result.

     Returns a container of the same kind as ``input``: a ``TensorField``
     keeps its coordinate-field map key and quantization mode, while a
     ``SparseTensor`` keeps its coordinate map key.
     """
     flattened = torch.stack(
         [feats.view(-1) for feats in input.decomposed_features])
     if not isinstance(input, ME.TensorField):
         return ME.SparseTensor(
             flattened,
             coordinate_map_key=input.coordinate_map_key,
             coordinate_manager=input.coordinate_manager,
         )
     return ME.TensorField(
         flattened,
         coordinate_field_map_key=input.coordinate_field_map_key,
         coordinate_manager=input.coordinate_manager,
         quantization_mode=input.quantization_mode,
     )
Ejemplo n.º 3
0
def test(net, test_iter, config, phase="val"):
    """Evaluate ``net`` over ``test_iter`` and log running/final accuracy.

    Args:
        net: model mapping an ``ME.TensorField`` to an output whose ``.F``
            holds per-sample class logits.
        test_iter: sized iterator yielding dicts with "feats", "coords"
            and "labels".
        config: namespace providing ``empty_freq`` and ``stat_freq``.
        phase: tag used in log messages (e.g. "val" or "test").
    """
    net.eval()
    num_correct, tot_num = 0, 0
    for i in range(len(test_iter)):
        # Built-in next() works on any iterator; the Python-2-style
        # ``.next()`` method was removed from modern DataLoader iterators.
        data_dict = next(test_iter)
        # NOTE(review): ``device`` is resolved from the enclosing module
        # scope — confirm it is defined where this function lives.
        tfield = ME.TensorField(data_dict["feats"],
                                data_dict["coords"],
                                device=device)
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            sout = net(tfield)
        is_correct = data_dict["labels"] == torch.argmax(sout.F, 1).cpu()
        num_correct += is_correct.sum().item()
        tot_num += len(sout)

        # Periodically release cached GPU memory.
        if i % config.empty_freq == 0:
            torch.cuda.empty_cache()

        if i % config.stat_freq == 0:
            logging.info(
                f"{phase} set iter: {i} / {len(test_iter)}, Accuracy : {num_correct / tot_num:.3e}"
            )

    logging.info(f"{phase} set accuracy : {num_correct / tot_num:.3e}")
Ejemplo n.º 4
0
    # NOTE(review): fragment of a larger function — ``device``, ``config``,
    # ``load_file`` and ``normalize_color`` come from the enclosing scope.
    # Define a model and load the weights
    model = MinkUNet34C(3, 20).to(device)
    model_dict = torch.load(config.weights)
    model.load_state_dict(model_dict)
    model.eval()

    coords, colors, pcd = load_file(config.file_name)
    # Measure time
    with torch.no_grad():
        voxel_size = 0.02
        # Feed-forward pass and get the prediction
        # Colors become features; coordinates are scaled into voxel units
        # and collated with a leading batch index column.
        in_field = ME.TensorField(
            features=normalize_color(torch.from_numpy(colors)),
            coordinates=ME.utils.batched_coordinates([coords / voxel_size],
                                                     dtype=torch.float32),
            quantization_mode=ME.SparseTensorQuantizationMode.
            UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=device,
        )
        # Convert to a sparse tensor
        sinput = in_field.sparse()
        # Output sparse tensor
        soutput = model(sinput)
        # get the prediction on the input tensor field
        out_field = soutput.slice(in_field)
        logits = out_field.F

    # Per-point predicted class as a NumPy array.
    _, pred = logits.max(1)
    pred = pred.cpu().numpy()
Ejemplo n.º 5
0
def train(net, device, config):
    """Train ``net`` with SGD + exponential LR decay, validating periodically.

    Args:
        net: model mapping an ``ME.TensorField`` to per-sample logits
            (accessed through ``.F``).
        device: torch device used for inputs and labels.
        config: namespace with ``lr``, ``momentum``, ``weight_decay``,
            ``batch_size``, ``num_workers``, ``weights`` (checkpoint path),
            ``load_optimizer``, ``max_iter``, ``empty_freq``, ``stat_freq``
            and ``val_freq``.
    """
    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer,
        0.999,
    )

    crit = torch.nn.CrossEntropyLoss()

    train_dataloader = make_data_loader(
        "train",
        augment_data=True,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config,
    )
    val_dataloader = make_data_loader(
        "val",
        augment_data=False,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config,
    )

    # Resume from an existing checkpoint at config.weights, if any.
    curr_iter = 0
    if os.path.exists(config.weights):
        checkpoint = torch.load(config.weights)
        net.load_state_dict(checkpoint["state_dict"])
        if config.load_optimizer.lower() == "true":
            curr_iter = checkpoint["curr_iter"] + 1
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])

    net.train()
    train_iter = iter(train_dataloader)
    val_iter = iter(val_dataloader)
    logging.info(f"LR: {scheduler.get_lr()}")
    for i in range(curr_iter, config.max_iter):

        s = time()
        # Built-in next() works on any iterator; the Python-2-style
        # ``.next()`` method was removed from modern DataLoader iterators.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()
        sin = ME.TensorField(data_dict["feats"],
                             data_dict["coords"],
                             device=device)
        sout = net(sin)
        loss = crit(sout.F, data_dict["labels"].to(device))
        loss.backward()
        optimizer.step()
        t = time() - s

        # Periodically release cached GPU memory.
        if i % config.empty_freq == 0:
            torch.cuda.empty_cache()

        if i % config.stat_freq == 0:
            logging.info(
                f"Iter: {i}, Loss: {loss.item():.3e}, Data Loading Time: {d:.3e}, Tot Time: {t:.3e}"
            )

        if i % config.val_freq == 0 and i > 0:
            # Save before validation to guard against an OOM during eval.
            torch.save(
                {
                    "state_dict": net.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "curr_iter": i,
                },
                config.weights,
            )

            # Validation
            logging.info("Validation")
            test(net, val_iter, config, "val")

            logging.info(f"LR: {scheduler.get_lr()}")

            # test() switched the net to eval mode; restore training mode.
            net.train()

            # Decay the LR once per validation cycle ("one epoch").
            scheduler.step()
0
    # NOTE(review): fragment of a larger script — ``dataset`` and
    # ``pointnet_data_loader`` come from the enclosing scope.
    # Use minkowski_collate_fn for pointnet
    minknet_data_loader = DataLoader(
        dataset,
        num_workers=4,
        collate_fn=minkowski_collate_fn,
        batch_size=16,
    )

    # Network
    pointnet = PointNet(in_channel=3, out_channel=20, embedding_channel=1024)
    minkpointnet = MinkowskiPointNet(in_channel=3,
                                     out_channel=20,
                                     embedding_channel=1024,
                                     dimension=3)

    # Run both networks side by side over paired batches.
    for i, (pointnet_batch, minknet_batch) in enumerate(
            zip(pointnet_data_loader, minknet_data_loader)):
        # PointNet.
        # WARNING: PointNet inputs must have the same number of points.
        pointnet_input = pointnet_batch["coordinates"].permute(0, 2, 1)
        pred = pointnet(pointnet_input)

        # MinkNet
        # Unlike pointnet, number of points for each point cloud do not need to be the same.
        minknet_input = ME.TensorField(
            coordinates=minknet_batch["coordinates"],
            features=minknet_batch["features"])
        minkpointnet(minknet_input)
        print(f"Processed batch {i}")
Ejemplo n.º 7
0
def train_point(model,
                data_loader,
                val_data_loader,
                config,
                transform_data_fn=None):
    """Train ``model`` on point-cloud batches until ``config.max_iter``.

    Two input modes: with ``config.pure_point`` the model gets dense
    (B, C, N) tensors; otherwise points are voxelized through an
    ``ME.TensorField`` and predictions are sliced back onto the points.
    Checkpoints periodically and runs a final point-based evaluation.

    Args:
        model: network to optimize (trained in place).
        data_loader: training loader yielding (points, target, sample_weight).
        val_data_loader: loader used for the final evaluation.
        config: experiment configuration namespace.
        transform_data_fn: unused here; kept for interface compatibility.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Timers / running statistics
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # Drop cached coordinate-map entries from the stored weights.
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = iter(data_loader)
    while is_training:

        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()
            # Accumulate gradients over iter_size sub-batches.
            for sub_iter in range(config.iter_size):
                # Get training data; built-in next() replaces the removed
                # Python-2-style ``.next()`` iterator method.
                data = next(data_iter)
                points, target, sample_weight = data
                if config.pure_point:
                    # Dense point input: (B, N, C) -> (B, C, N) on GPU.
                    sinput = points.transpose(1, 2).cuda().float()
                else:
                    voxel_size = config.voxel_size
                    # Quantize xyz into voxel units; all channels (xyz + the
                    # rest) are kept as input features.
                    coords = torch.unbind(points[:, :, :3] / voxel_size,
                                          dim=0)  # 0.05 is the voxel-size
                    feats = torch.unbind(points[:, :, :], dim=0)
                    coords, feats = ME.utils.sparse_collate(coords, feats)

                    # For some networks, making the network invariant to even,
                    # odd coords is important
                    coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                    points_ = ME.TensorField(features=feats.float(),
                                             coordinates=coords,
                                             device=device)
                    sinput = points_.sparse()

                data_time += data_timer.toc(False)
                B, npoint = target.shape

                soutput = model(sinput)
                if config.pure_point:
                    soutput = soutput.reshape([B * npoint, -1])
                else:
                    # Map voxel predictions back onto the original points.
                    soutput = soutput.slice(points_).F

                # Labels are 1-based; shift so label 0 becomes the
                # ignore_index (-1) of the criterion.
                target = (target - 1).view(-1).long().to(device)

                # catch NAN — drop into a debugger rather than crash.
                if torch.isnan(soutput).sum() > 0:
                    import ipdb
                    ipdb.set_trace()

                loss = criterion(soutput, target)

                if torch.isnan(loss).sum() > 0:
                    import ipdb
                    ipdb.set_trace()

                loss = (loss * sample_weight.to(device)).mean()

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # Accumulate per-class IoU numerator / denominator counts.
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target >= 0)) |
                                            (target == l)).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model,
                           optimizer,
                           epoch,
                           curr_iter,
                           config,
                           best_val_miou,
                           best_val_iter,
                           save_inter=True)

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)

    # BUG FIX: the evaluation result was previously discarded, so the
    # ``val_miou`` reference below raised a NameError. The None-guard keeps
    # this working even if test_points does not return a value.
    val_miou = test_points(model, val_data_loader, config)
    if val_miou is not None and val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
Ejemplo n.º 8
0
        # NOTE(review): fragment of a larger benchmark script — ``coords``,
        # ``voxel_size``, ``model``, ``device`` and ``Timer`` come from the
        # enclosing scope.
        # Benchmark the forward pass at increasing batch sizes by
        # replicating the same point cloud batch_size times.
        for batch_size in [1, 2, 4, 6, 8, 10, 12]:
            timer = Timer()
            coordinates = ME.utils.batched_coordinates(
                [coords / voxel_size for i in range(batch_size)],
                return_int=False)
            features = torch.rand(len(coordinates), 3).float()
            with torch.no_grad():
                # 10 timed repetitions; min_time is reported below.
                for i in range(10):
                    timer.tic()
                    # Feed-forward pass and get the prediction
                    in_field = ME.TensorField(
                        features=features,
                        coordinates=coordinates,
                        quantization_mode=ME.SparseTensorQuantizationMode.
                        UNWEIGHTED_AVERAGE,
                        minkowski_algorithm=ME.MinkowskiAlgorithm.
                        SPEED_OPTIMIZED,
                        # minkowski_algorithm=ME.MinkowskiAlgorithm.MEMORY_EFFICIENT,
                        allocator_type=ME.GPUMemoryAllocatorType.PYTORCH,
                        device=device,
                    )
                    # Convert to a sparse tensor
                    sinput = in_field.sparse()
                    # Output sparse tensor
                    soutput = model(sinput)
                    # get the prediction on the input tensor field
                    out_field = soutput.slice(in_field)
                    timer.toc()
            print(batch_size, soutput.shape, timer.min_time)

    # NOTE(review): dangling branch — the matching ``if`` is outside this view.
    elif ME.__version__.split(".")[1] == "4":
def test_points(model,
                data_loader,
                config,
                with_aux=False,
                save_dir=None,
                split='eval',
                use_voxel=True):
    """Evaluate ``model`` scene-by-scene with point voting and report IoU.

    Args:
        model: trained network, run in ``eval()`` mode.
        data_loader: loader whose dataset exposes ``point_num`` (points per
            scene), ``scene_list`` (scene ids), ``semantic_labels_list``
            and ``NUM_LABELS``.
        config: namespace providing ``pure_point``, ``voxel_size`` and
            ``num_points``.
        with_aux: if True, batches carry an extra instance tensor fed to
            the model.
        save_dir: when set, per-scene argmax predictions are saved there.
        split: tag used in the saved prediction filename.
        use_voxel: nominal flag; overridden below by ``not
            config.pure_point`` (kept for interface compatibility).

    Returns:
        Mean class IoU as a float (previously returned ``None`` implicitly;
        callers that ignored the return value are unaffected).
    """
    pn_list = data_loader.dataset.point_num
    scene_list = data_loader.dataset.scene_list
    SEM_LABELS = data_loader.dataset.semantic_labels_list
    NUM_CLASSES = data_loader.dataset.NUM_LABELS

    model.eval()
    total_seen = 0
    total_correct = 0
    total_seen_class = [0] * NUM_CLASSES
    total_correct_class = [0] * NUM_CLASSES
    total_iou_deno_class = [0] * NUM_CLASSES

    if save_dir is not None:
        save_dict = {'pred': []}

    use_voxel = not config.pure_point

    scene_num = len(scene_list)
    for scene_index in range(scene_num):
        logging.info(' ======= {}/{} ======= '.format(scene_index, scene_num))
        scene_id = scene_list[scene_index]
        point_num = pn_list[scene_index]
        # Accumulated class scores and vote counts per point.
        predict = np.zeros((point_num, NUM_CLASSES), dtype=np.float32)  # pn,21
        # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is equivalent.
        vote_num = np.zeros((point_num, 1), dtype=int)  # pn,1
        for idx, batch_data in enumerate(data_loader):
            if with_aux:
                pc, seg, aux, smpw, pidx = batch_data
                aux = aux.cuda()
                seg = seg.cuda()
            else:
                pc, seg, smpw, pidx = batch_data
            if pidx.max() > point_num:
                import ipdb
                ipdb.set_trace()

            pc = pc.cuda().float()
            # Use voxel-forward for testing the scannet
            if use_voxel:
                coords = torch.unbind(pc[:, :, :3] / config.voxel_size, dim=0)
                feats = torch.unbind(pc[:, :, :], dim=0)  # use all 6 chs for eval
                # the returned coords adds a batch-dim
                coords, feats = ME.utils.sparse_collate(coords, feats)
                pc = ME.TensorField(features=feats.float(),
                                    coordinates=coords.cuda())  # [xyz, norm_xyz, rgb]
                voxels = pc.sparse()
                seg = seg.view(-1)
                inputs = voxels
            else:
                # DEBUG: discrete input xyz for point-based method
                feats = torch.unbind(pc[:, :, :], dim=0)
                coords = torch.unbind(pc[:, :, :3] / config.voxel_size, dim=0)
                # the returned coords adds a batch-dim
                coords, feats = ME.utils.sparse_collate(coords, feats)

                pc = ME.TensorField(features=feats.float(),
                                    coordinates=coords.cuda())  # [xyz, norm_xyz, rgb]
                voxels = pc.sparse()
                pc_ = voxels.slice(pc)
                pc = pc_.F.reshape([-1, config.num_points, 6])
                pc = pc.transpose(1, 2)
                inputs = pc

            if with_aux:
                # DEBUG: Use target as instance for now
                pred = model(inputs, instance=aux)  # B,N,C
            else:
                pred = model(inputs)  # B,N,C

            if use_voxel:
                assert isinstance(pred, ME.SparseTensor)
                pred = pred.slice(pc).F
                try:
                    # leave the 1st dim, since no droplast
                    pred = pred.reshape([-1, config.num_points, NUM_CLASSES])
                except RuntimeError:
                    import ipdb
                    ipdb.set_trace()

            pred = torch.nn.functional.softmax(pred, dim=2)
            pred = pred.cpu().detach().numpy()

            pidx = pidx.numpy()  # B,N
            predict, vote_num = vote(predict, vote_num, pred, pidx)

        # Average the accumulated scores by the vote counts.
        predict = predict / vote_num

        if save_dir is not None:
            if np.isnan(predict).any():
                print("found nan in scene{}".format(scene_id))
                import ipdb
                ipdb.set_trace()
            save_dict['pred'].append(np.argmax(predict, axis=-1))

        predict = np.argmax(predict, axis=-1)  # pn
        labels = SEM_LABELS[scene_index]

        # Shift 1-based labels so that label 0 becomes the ignore value -1.
        labels = labels - 1
        correct = predict == labels
        correct = correct[labels != -1]

        total_seen += np.sum(labels.size)  # point_num
        total_correct += np.sum(correct)
        logging.info('accuracy:{} '.format(total_correct / total_seen))
        for l in range(NUM_CLASSES):
            total_seen_class[l] += np.sum((labels == l) & (labels >= 0))
            total_correct_class[l] += np.sum((predict == l) & (labels == l))
            total_iou_deno_class[l] += np.sum(((predict == l) & (labels >= 0))
                                              | (labels == l))

    # final save
    if save_dir is not None:
        torch.save(save_dict,
                   os.path.join(save_dir, '{}_pred.pth'.format(split)))

    # ``np.float`` was removed in NumPy 1.24; builtin ``float`` is equivalent.
    IoU = np.array(total_correct_class) / (
        np.array(total_iou_deno_class, dtype=float) + 1e-6)
    logging.info('eval point avg class IoU: %f' % (np.mean(IoU)))
    for i in range(IoU.shape[0]):
        logging.info('Class %d : %.4f' % (i + 1, IoU[i]))
    logging.info('eval accuracy: %f' % (total_correct / float(total_seen)))
    logging.info('eval avg class acc: %f' % (np.mean(
        np.array(total_correct_class) /
        (np.array(total_seen_class, dtype=float) + 1e-6))))
    return np.mean(IoU)
def test_scannet(args,
                 model,
                 dst_loader,
                 log_string,
                 with_aux=False,
                 save_dir=None,
                 split='eval',
                 use_voxel=False):
    """Evaluate on ScanNet scene-by-scene with point voting and log IoU.

    Args:
        args: namespace providing ``voxel_size`` and ``num_point``.
        model: trained network, run in ``eval()`` mode.
        dst_loader: loader whose dataset exposes ``point_num``,
            ``scene_list`` and ``semantic_labels_list``.
        log_string: logging callable taking one string.
        with_aux: if True, feed an auxiliary instance tensor to the model.
        save_dir: when set, save per-scene argmax predictions there.
        split: tag used in the saved prediction filename.
        use_voxel: voxelize through ``ME.TensorField`` instead of feeding
            dense (B, C, N) tensors.

    NOTE(review): ``NUM_CLASSES`` and ``vote`` are resolved from module
    scope — confirm they are defined where this function lives.
    """
    pn_list = dst_loader.dataset.point_num
    scene_list = dst_loader.dataset.scene_list
    SEM_LABELS = dst_loader.dataset.semantic_labels_list

    model.eval()
    total_seen = 0
    total_correct = 0
    total_seen_class = [0] * NUM_CLASSES
    total_correct_class = [0] * NUM_CLASSES
    total_iou_deno_class = [0] * NUM_CLASSES

    if save_dir is not None:
        save_dict = {'pred': []}

    scene_num = len(scene_list)
    for scene_index in range(scene_num):
        log_string(' ======= {}/{} ======= '.format(scene_index, scene_num))
        scene_id = scene_list[scene_index]
        point_num = pn_list[scene_index]
        # Accumulated class scores and vote counts per point.
        predict = np.zeros((point_num, NUM_CLASSES), dtype=np.float32)  # pn,21
        # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is equivalent.
        vote_num = np.zeros((point_num, 1), dtype=int)  # pn,1
        for idx, batch_data in enumerate(dst_loader):
            log_string('batch {}'.format(idx))
            if with_aux:
                pc, seg, aux, smpw, pidx = batch_data
                aux = aux.cuda()
                seg = seg.cuda()
            else:
                pc, seg, smpw, pidx = batch_data
            if pidx.max() > point_num:
                import ipdb
                ipdb.set_trace()
            pc = pc.cuda().float()
            # use voxel-forward for testing the scannet
            if use_voxel:
                feats = torch.unbind(pc[:, :, 6:], dim=0)
                coords = torch.unbind(pc[:, :, :3] / args.voxel_size, dim=0)
                coords, feats = ME.utils.sparse_collate(
                    coords, feats)  # the returned coords adds a batch-dim
                pc = ME.TensorField(
                    features=feats.float(),
                    coordinates=coords.cuda())  # [xyz, norm_xyz, rgb]
                voxels = pc.sparse()
                seg = seg.view(-1)
                inputs = voxels
            else:
                pc = pc.transpose(1, 2)
                inputs = pc

            if with_aux:
                # DEBUG: Use target as instance for now
                pred = model(inputs, instance=aux)  # B,N,C
            else:
                pred = model(inputs)  # B,N,C
            if use_voxel:
                assert isinstance(pred, ME.SparseTensor)
                pred = pred.slice(pc).F
                try:
                    pred = pred.reshape(
                        [-1, args.num_point,
                         NUM_CLASSES])  # leave the 1st dim, since no droplast
                except RuntimeError:
                    import ipdb
                    ipdb.set_trace()

            pred = torch.nn.functional.softmax(pred, dim=2)
            pred = pred.cpu().detach().numpy()

            pidx = pidx.numpy()  # B,N
            predict, vote_num = vote(predict, vote_num, pred, pidx)

        # Average the accumulated scores by the vote counts.
        predict = predict / vote_num

        if save_dir is not None:
            if np.isnan(predict).any():
                print("found nan in scene{}".format(scene_id))
                import ipdb
                ipdb.set_trace()
            save_dict['pred'].append(np.argmax(predict, axis=-1))

        # Class 0 is skipped in the argmax, then indices are shifted back.
        predict = np.argmax(predict[:, 1:], axis=1)  # pn
        predict += 1
        labels = SEM_LABELS[scene_index]
        # Shift 1-based labels so that label 0 becomes the ignore value -1.
        labels = labels - 1
        correct = predict == labels
        correct = correct[labels != -1]

        total_seen += np.sum(labels >= 0)  # point_num
        total_correct += np.sum(correct)
        log_string('accuracy:{} '.format(total_correct / total_seen))
        for l in range(NUM_CLASSES):
            total_seen_class[l] += np.sum((labels == l) & (labels > 0))
            total_correct_class[l] += np.sum((predict == l) & (labels == l))
            total_iou_deno_class[l] += np.sum(((predict == l) & (labels > 0))
                                              | (labels == l))

    # final save
    if save_dir is not None:
        torch.save(save_dict,
                   os.path.join(save_dir, '{}_pred.pth'.format(split)))

    # ``np.float`` was removed in NumPy 1.24; builtin ``float`` is equivalent.
    IoU = np.array(total_correct_class[1:]) / (
        np.array(total_iou_deno_class[1:], dtype=float) + 1e-6)
    log_string('eval point avg class IoU: %f' % (np.mean(IoU)))
    for i in range(IoU.shape[0]):
        log_string('Class %d : %.4f' % (i + 1, IoU[i]))
    log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
    log_string('eval avg class acc: %f' % (np.mean(
        np.array(total_correct_class[1:]) /
        (np.array(total_seen_class[1:], dtype=float) + 1e-6))))