Code example #1
File: shapenet_seg.py  Project: He-jerry/3D-project
def train(args):

    THREADS = 4
    USE_CUDA = True
    N_CLASSES = 50
    EPOCHS = 200
    MILESTONES = [60, 120]

    shapenet_labels = [
        ['Airplane', 4],
        ['Bag', 2],
        ['Cap', 2],
        ['Car', 4],
        ['Chair', 4],
        ['Earphone', 3],
        ['Guitar', 3],
        ['Knife', 2],
        ['Lamp', 4],
        ['Laptop', 2],
        ['Motorbike', 6],
        ['Mug', 2],
        ['Pistol', 3],
        ['Rocket', 3],
        ['Skateboard', 3],
        ['Table', 3],
    ]
    category_range = []
    count = 0
    for element in shapenet_labels:
        part_start = count
        count += element[1]
        part_end = count
        category_range.append([part_start, part_end])

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, labels, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    print("Done", data_train.shape)

    print("Computing class weights (if needed, 1 otherwise)...")
    if args.weighted:
        frequences = []
        for i in range(len(shapenet_labels)):
            frequences.append((labels == i).sum())
        frequences = np.array(frequences)
        frequences = frequences.mean() / frequences
    else:
        frequences = [1 for _ in range(len(shapenet_labels))]
    weights = torch.FloatTensor(frequences)
    if USE_CUDA:
        weights = weights.cuda()
    print("Done")

    print("Creating network...")
    net = get_model(args.model, input_channels=1, output_channels=N_CLASSES)
    net.cuda()
    print("parameters", count_parameters(net))

    ds = PartNormalDataset(data_train,
                           data_num_train,
                           label_train,
                           npoints=args.npoints)
    train_loader = torch.utils.data.DataLoader(ds,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=THREADS)

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, MILESTONES)

    # create the model folder
    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.savedir,
        "{}_b{}_pts{}_weighted{}_{}".format(args.model, args.batchsize,
                                            args.npoints, args.weighted,
                                            time_string))
    os.makedirs(root_folder, exist_ok=True)

    # create the log file
    logs = open(os.path.join(root_folder, "log.txt"), "w")
    for epoch in range(EPOCHS):
        scheduler.step()
        cm = np.zeros((N_CLASSES, N_CLASSES))
        t = tqdm(train_loader, ncols=120, desc="Epoch {}".format(epoch))
        for pts, features, seg, indices in t:

            if USE_CUDA:
                features = features.cuda()
                pts = pts.cuda()
                seg = seg.cuda()

            optimizer.zero_grad()
            outputs = net(features, pts)

            # loss =  F.cross_entropy(outputs.view(-1, N_CLASSES), seg.view(-1))

            loss = 0
            for i in range(pts.size(0)):
                # get the part label range for this shape's category
                object_label = labels[indices[i]]
                part_start, part_end = category_range[object_label]
                part_nbr = part_end - part_start
                loss = loss + weights[object_label] * F.cross_entropy(
                    outputs[i, :, part_start:part_end].view(-1, part_nbr),
                    seg[i].view(-1) - part_start)

            loss.backward()
            optimizer.step()

            outputs_np = outputs.cpu().detach().numpy()
            for i in range(pts.size(0)):
                # get the part label range for this shape's category
                object_label = labels[indices[i]]
                part_start, part_end = category_range[object_label]
                part_nbr = part_end - part_start
                outputs_np[i, :, :part_start] = -1e7
                outputs_np[i, :, part_end:] = -1e7

            output_np = np.argmax(outputs_np, axis=2).copy()
            target_np = seg.cpu().numpy().copy()

            cm_ = confusion_matrix(target_np.ravel(),
                                   output_np.ravel(),
                                   labels=list(range(N_CLASSES)))
            cm += cm_

            oa = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
            aa = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])

            t.set_postfix(OA=oa, AA=aa)

        # save the model
        torch.save(net.state_dict(), os.path.join(root_folder,
                                                  "state_dict.pth"))

        # write the logs
        logs.write("{} {} {} \n".format(epoch, oa, aa))
        logs.flush()

    logs.close()
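
The loss loop above restricts the 50 output logits to the part range of each shape's category and shifts the targets into that range. A minimal standalone sketch of the same masking idea, on toy tensors with made-up sizes:

import torch
import torch.nn.functional as F

# toy shape whose category owns global part labels 4 and 5 (part_start=4, part_end=6)
part_start, part_end = 4, 6
part_nbr = part_end - part_start
logits = torch.randn(1, 1024, 50)                          # (batch, points, global part classes)
targets = torch.randint(part_start, part_end, (1, 1024))   # global part labels for that shape

loss = F.cross_entropy(
    logits[0, :, part_start:part_end].reshape(-1, part_nbr),  # keep only this category's parts
    targets[0].reshape(-1) - part_start)                       # shift labels into [0, part_nbr)
print(loss.item())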
Code example #2
File: shapenet_seg.py  Project: He-jerry/3D-project
def test(args):
    THREADS = 4
    USE_CUDA = True
    N_CLASSES = 50

    args.data_folder = os.path.join(args.rootdir, "test_data")

    # create the output folders
    output_folder = '/public/zebanghe2/convpoint/ConvPointmaster/sp2/'
    category_list = [
        (category, int(label_num))
        for (category,
             label_num) in [line.split() for line in open(args.category, 'r')]
    ]
    offset = 0
    category_range = dict()
    for category, category_label_seg_max in category_list:
        category_range[category] = (offset, offset + category_label_seg_max)
        offset = offset + category_label_seg_max
        folder = os.path.join(output_folder, category)
        if not os.path.exists(folder):
            os.makedirs(folder)

    input_filelist = []
    output_filelist = []
    output_ply_filelist = []
    for category in sorted(os.listdir(args.data_folder)):
        data_category_folder = os.path.join(args.data_folder, category)
        for filename in sorted(os.listdir(data_category_folder)):
            input_filelist.append(
                os.path.join(args.data_folder, category, filename))
            output_filelist.append(
                os.path.join(output_folder, category, filename[0:-3] + 'seg'))
            output_ply_filelist.append(
                os.path.join(output_folder + '_ply', category,
                             filename[0:-3] + 'ply'))

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data, label, data_num, label_test, _ = data_utils.load_seg(
        args.filelist_val)  # no segmentation labels

    # net = Net(input_channels=1, output_channels=N_CLASSES)
    net = get_model(args.model, input_channels=1, output_channels=N_CLASSES)
    net.load_state_dict(
        torch.load(os.path.join(args.savedir, "state_dict.pth")))
    net.cuda()
    net.eval()

    ds = PartNormalDataset(data,
                           data_num,
                           label_test,
                           npoints=args.npoints,
                           num_iter_per_shape=args.ntree)
    test_loader = torch.utils.data.DataLoader(ds,
                                              batch_size=args.batchsize,
                                              shuffle=False,
                                              num_workers=THREADS)
    shapenet_labels = [
        ['Airplane', 4],
        ['Bag', 2],
        ['Cap', 2],
        ['Car', 4],
        ['Chair', 4],
        ['Earphone', 3],
        ['Guitar', 3],
        ['Knife', 2],
        ['Lamp', 4],
        ['Laptop', 2],
        ['Motorbike', 6],
        ['Mug', 2],
        ['Pistol', 3],
        ['Rocket', 3],
        ['Skateboard', 3],
        ['Table', 3],
    ]

    cm = np.zeros((N_CLASSES, N_CLASSES))
    t = tqdm(test_loader, ncols=120)
    Confs = []

    predictions = [None for _ in range(data.shape[0])]
    predictions_max = [[] for _ in range(data.shape[0])]
    with torch.no_grad():

        for pts, features, seg, indices in t:

            if USE_CUDA:
                features = features.cuda()
                pts = pts.cuda()

            outputs = net(features, pts)

            indices = np.int32(indices.numpy())
            outputs = np.float32(outputs.cpu().numpy())

            # save results
            for i in range(pts.size(0)):

                # shape id
                shape_id = indices[i]

                # pts_src
                pts_src = pts[i].cpu().numpy()

                # pts_dest
                point_num = data_num[shape_id]
                pts_dest = data[shape_id]
                pts_dest = pts_dest[:point_num]

                # get the part label range for this shape's category
                object_label = label[indices[i]]
                category = category_list[object_label][0]
                part_start, part_end = category_range[category]
                part_nbr = part_end - part_start

                # get the segmentation scores corresponding to the part range
                seg_ = outputs[i][:, part_start:part_end]

                # interpolate to original points
                seg_ = nearest_correspondance(pts_src, pts_dest, seg_)

                if predictions[shape_id] is None:
                    predictions[shape_id] = seg_
                else:
                    predictions[shape_id] += seg_

                predictions_max[shape_id].append(seg_)

    for i in range(len(predictions)):
        a = np.stack(predictions_max[i], axis=1)
        a = np.argmax(a, axis=2)
        a = np.apply_along_axis(np.bincount, 1, a, minlength=6)
        predictions_max[i] = np.argmax(a, axis=1)

    # compute labels
    for i in range(len(predictions)):
        predictions[i] = np.argmax(predictions[i], axis=1)
        #print(len(predictions[i]))
        #save_fname = os.path.join('/public/zebanghe2/convpoint/ConvPointmaster/sp2/',str(i)+"pred.txt")
        #np.savetxt(save_fname,predictions[i],fmt='%d')

    def scores_from_predictions(predictions):

        shape_ious = {cat[0]: [] for cat in category_list}
        for shape_id, prediction in enumerate(predictions):

            segp = prediction
            cat = label[shape_id]
            category = category_list[cat][0]
            part_start, part_end = category_range[category]
            part_nbr = part_end - part_start
            point_num = data_num[shape_id]
            segl = label_test[shape_id][:point_num] - part_start

            part_ious = [0.0 for _ in range(part_nbr)]
            for l in range(part_nbr):
                # part absent from both ground truth and prediction: count IoU as 1
                if (np.sum(segl == l) == 0) and (np.sum(segp == l) == 0):
                    part_ious[l] = 1.0

                else:
                    part_ious[l] = np.sum((segl == l) & (segp == l)) / float(
                        np.sum((segl == l) | (segp == l)))
            shape_ious[category].append(np.mean(part_ious))
            save_fname = os.path.join(
                '/public/zebanghe2/convpoint/ConvPointmaster/sp2/',
                str(shape_id) + "pred.txt")
            np.savetxt(save_fname, prediction, fmt='%d')

        all_shape_ious = []
        for cat in shape_ious.keys():
            for iou in shape_ious[cat]:
                all_shape_ious.append(iou)
            shape_ious[cat] = np.mean(shape_ious[cat])
        print(len(all_shape_ious))
        mean_shape_ious = np.mean(list(shape_ious.values()))
        for cat in sorted(shape_ious.keys()):
            print('eval mIoU of %s:\t %f' % (cat, shape_ious[cat]))
        print('eval mean mIoU: %f' % (mean_shape_ious))
        print('eval mean mIoU (all shapes): %f' % (np.mean(all_shape_ious)))

    scores_from_predictions(predictions)
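
nearest_correspondance is not included in this snippet. Assuming it propagates per-point scores from the sampled points back to the full point set by nearest neighbour (the real helper may use a different neighbour search), a minimal sketch could look like this:

import numpy as np
from scipy.spatial import cKDTree

def nearest_correspondance(pts_src, pts_dest, scores_src):
    # for every destination point, copy the scores of its nearest source point
    tree = cKDTree(pts_src)
    _, idx = tree.query(pts_dest, k=1)
    return scores_src[idx]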
Code example #3
def train(args):

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, labels, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    print("Done", data_train.shape)

    THREADS = 4
    BATCH_SIZE = args.batchsize
    USE_CUDA = True
    N_CLASSES = 50
    EPOCHS = 200
    MILESTONES = [60, 120, 180]

    print("Creating network...")
    net = Net(input_channels=1, output_channels=N_CLASSES)
    net.cuda()
    print("parameters", count_parameters(net))

    ds = PartNormalDataset(data_train,
                           data_num_train,
                           label_train,
                           net.config,
                           npoints=args.npoints,
                           shape_labels=labels)
    train_loader = torch.utils.data.DataLoader(ds,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               num_workers=THREADS,
                                               collate_fn=tree_collate)

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, MILESTONES)

    # create the model folder
    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.savedir, "Net_b{}_pts{}_{}".format(args.batchsize, args.npoints,
                                                time_string))
    os.makedirs(root_folder, exist_ok=True)

    # create the log file
    logs = open(os.path.join(root_folder, "log.txt"), "w")
    for epoch in range(EPOCHS):
        scheduler.step()
        cm = np.zeros((N_CLASSES, N_CLASSES))
        t = tqdm(train_loader, ncols=120, desc="Epoch {}".format(epoch))
        for pts, features, seg, tree, labels in t:

            if USE_CUDA:
                features = features.cuda()
                pts = pts.cuda()
                for l_id in range(len(tree)):
                    tree[l_id]["points"] = tree[l_id]["points"].cuda()
                    tree[l_id]["indices"] = tree[l_id]["indices"].cuda()
                seg = seg.cuda()

            optimizer.zero_grad()
            outputs = net(features, pts, tree)
            loss = F.cross_entropy(outputs.view(-1, N_CLASSES), seg.view(-1))
            loss.backward()
            optimizer.step()

            output_np = np.argmax(outputs.cpu().detach().numpy(),
                                  axis=2).copy()
            target_np = seg.cpu().numpy().copy()

            cm_ = confusion_matrix(target_np.ravel(),
                                   output_np.ravel(),
                                   labels=list(range(N_CLASSES)))
            cm += cm_

            oa = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
            aa = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])

            t.set_postfix(OA=oa, AA=aa)

        # save the model
        torch.save(net.state_dict(), os.path.join(root_folder,
                                                  "state_dict.pth"))

        # write the logs
        logs.write("{} {} {} \n".format(epoch, oa, aa))
        logs.flush()

    logs.close()
Code example #4
def test_multiple(args):
    THREADS = 4
    BATCH_SIZE = args.batchsize
    USE_CUDA = True
    N_CLASSES = 50

    args.data_folder = os.path.join(args.rootdir, "test_data")

    # create the output folders
    output_folder = args.savedir + '_predictions_multi_{}'.format(args.ntree)
    category_list = [
        (category, int(label_num))
        for (category,
             label_num) in [line.split() for line in open(args.category, 'r')]
    ]
    offset = 0
    category_range = dict()
    for category, category_label_seg_max in category_list:
        category_range[category] = (offset, offset + category_label_seg_max)
        offset = offset + category_label_seg_max
        folder = os.path.join(output_folder, category)
        if not os.path.exists(folder):
            os.makedirs(folder)

    input_filelist = []
    output_filelist = []
    output_ply_filelist = []
    for category in sorted(os.listdir(args.data_folder)):
        data_category_folder = os.path.join(args.data_folder, category)
        for filename in sorted(os.listdir(data_category_folder)):
            input_filelist.append(
                os.path.join(args.data_folder, category, filename))
            output_filelist.append(
                os.path.join(output_folder, category, filename[0:-3] + 'seg'))
            output_ply_filelist.append(
                os.path.join(output_folder + '_ply', category,
                             filename[0:-3] + 'ply'))

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data, label, data_num, label_test, _ = data_utils.load_seg(
        args.filelist_val)  # no segmentation labels

    net = Net(input_channels=1, output_channels=N_CLASSES)
    net.load_state_dict(
        torch.load(os.path.join(args.savedir, "state_dict.pth")))
    net.cuda()
    net.eval()

    ds = PartNormalDataset(data,
                           data_num,
                           label_test,
                           net.config,
                           npoints=args.npoints)
    test_loader = torch.utils.data.DataLoader(ds,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=THREADS,
                                              collate_fn=tree_collate)

    cm = np.zeros((N_CLASSES, N_CLASSES))
    t = tqdm(test_loader, ncols=120)
    with torch.no_grad():
        count = 0

        for shape_id in tqdm(range(len(ds)), ncols=120):

            segmentation_ = None

            batches = []

            if args.ntree <= args.batchsize:

                batch = []
                for tree_id in range(args.ntree):
                    batch.append(ds.__getitem__(shape_id))
                batches.append(batch)
            else:
                for i in range(math.ceil(args.ntree / args.batchsize)):
                    bs = min(args.batchsize, args.ntree - i * args.batchsize)
                    batch = []
                    for tree_id in range(bs):
                        batch.append(ds.__getitem__(shape_id))
                    batches.append(batch)

            for batch in batches:

                pts, features, seg, tree = tree_collate(batch)
                if USE_CUDA:
                    features = features.cuda()
                    pts = pts.cuda()
                    for l_id in range(len(tree)):
                        tree[l_id]["points"] = tree[l_id]["points"].cuda()
                        tree[l_id]["indices"] = tree[l_id]["indices"].cuda()

                outputs = net(features, pts, tree)

                for i in range(pts.size(0)):
                    pts_src = pts[i].cpu().numpy()

                    # pts_dest
                    point_num = data_num[count]
                    pts_dest = data[count]
                    pts_dest = pts_dest[:point_num]

                    object_label = label[count]
                    category = category_list[object_label][0]
                    label_start, label_end = category_range[category]

                    seg_ = outputs[i][:, label_start:label_end].cpu().numpy()
                    seg_ = nearest_correspondance(pts_src, pts_dest, seg_)

                    if segmentation_ is None:
                        segmentation_ = seg_
                    else:
                        segmentation_ += seg_

            segmentation_ = np.argmax(segmentation_, axis=1)

            # save labels
            np.savetxt(output_filelist[count], segmentation_, fmt="%i")

            if args.ply:
                data_utils.save_ply_property(pts_dest, segmentation_, 6,
                                             output_ply_filelist[count])

            count += 1
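
test_multiple accumulates the interpolated scores over args.ntree random samplings of the same shape and only then takes the argmax. The accumulation step in isolation, with hypothetical sizes:

import numpy as np

n_points, n_parts, n_runs = 2048, 4, 8   # made-up sizes
runs = [np.random.rand(n_points, n_parts) for _ in range(n_runs)]

summed = np.sum(runs, axis=0)            # sum the per-run scores per point
labels = np.argmax(summed, axis=1)       # pick the best part per point
print(labels.shape)                      # (2048,)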
Code example #5
def test(args):
    THREADS = 4
    BATCH_SIZE = args.batchsize
    USE_CUDA = True
    N_CLASSES = 50

    args.data_folder = os.path.join(args.rootdir, "test_data")

    # create the output folders
    output_folder = args.savedir + '_predictions'
    category_list = [
        (category, int(label_num))
        for (category,
             label_num) in [line.split() for line in open(args.category, 'r')]
    ]
    offset = 0
    category_range = dict()
    for category, category_label_seg_max in category_list:
        category_range[category] = (offset, offset + category_label_seg_max)
        offset = offset + category_label_seg_max
        folder = os.path.join(output_folder, category)
        if not os.path.exists(folder):
            os.makedirs(folder)

    input_filelist = []
    output_filelist = []
    output_ply_filelist = []
    for category in sorted(os.listdir(args.data_folder)):
        data_category_folder = os.path.join(args.data_folder, category)
        for filename in sorted(os.listdir(data_category_folder)):
            input_filelist.append(
                os.path.join(args.data_folder, category, filename))
            output_filelist.append(
                os.path.join(output_folder, category, filename[0:-3] + 'seg'))
            output_ply_filelist.append(
                os.path.join(output_folder + '_ply', category,
                             filename[0:-3] + 'ply'))

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data, label, data_num, label_test, _ = data_utils.load_seg(
        args.filelist_val)  # no segmentation labels

    net = Net(input_channels=1, output_channels=N_CLASSES)
    net.load_state_dict(
        torch.load(os.path.join(args.savedir, "state_dict.pth")))
    net.cuda()
    net.eval()

    ds = PartNormalDataset(data,
                           data_num,
                           label_test,
                           net.config,
                           npoints=args.npoints)
    test_loader = torch.utils.data.DataLoader(ds,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=THREADS,
                                              collate_fn=tree_collate)

    cm = np.zeros((N_CLASSES, N_CLASSES))
    t = tqdm(test_loader, ncols=120)
    with torch.no_grad():
        count = 0
        for pts, features, seg, tree in t:

            if USE_CUDA:
                features = features.cuda()
                pts = pts.cuda()
                for l_id in range(len(tree)):
                    tree[l_id]["points"] = tree[l_id]["points"].cuda()
                    tree[l_id]["indices"] = tree[l_id]["indices"].cuda()

            outputs = net(features, pts, tree)

            # save results
            for i in range(pts.size(0)):
                # pts_src
                pts_src = pts[i].cpu().numpy()

                # pts_dest
                point_num = data_num[count + i]
                pts_dest = data[count + i]
                pts_dest = pts_dest[:point_num]

                object_label = label[count + i]
                category = category_list[object_label][0]
                label_start, label_end = category_range[category]

                seg_ = outputs[i][:, label_start:label_end].cpu().numpy()
                seg_ = np.argmax(seg_, axis=1)
                seg_ = nearest_correspondance(pts_src, pts_dest, seg_)

                # save labels
                np.savetxt(output_filelist[count + i], seg_, fmt="%i")

                if args.ply:
                    data_utils.save_ply_property(
                        pts_dest, seg_, 6, output_ply_filelist[count + i])
            count += pts.size(0)

            output_np = np.argmax(outputs.cpu().detach().numpy(),
                                  axis=2).copy()
            target_np = seg.cpu().numpy().copy()

            cm_ = confusion_matrix(target_np.ravel(),
                                   output_np.ravel(),
                                   labels=list(range(N_CLASSES)))
            cm += cm_

            oa = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
            aa = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])

            t.set_postfix(OA=oa, AA=aa)
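
metrics.stats_overall_accuracy and metrics.stats_accuracy_per_class are not shown in these snippets. Assuming they are the usual confusion-matrix statistics (and that the second returns the mean first, as the [0] indexing above suggests), they could be sketched as:

import numpy as np

def stats_overall_accuracy(cm):
    # fraction of correctly classified points: diagonal mass over total
    return np.trace(cm) / max(cm.sum(), 1)

def stats_accuracy_per_class(cm):
    # per-class recall plus its mean; classes never seen count as 0
    with np.errstate(divide='ignore', invalid='ignore'):
        per_class = np.diag(cm) / cm.sum(axis=1)
    per_class = np.nan_to_num(per_class)
    return per_class.mean(), per_class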
Code example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--load_ckpt',
        '-l',
        default=
        'log/seg/shellconv_seg_shapenet_2019-08-06-14-42-34/ckpts/epoch-326',
        help='Path to a checkpoint file to load')
    parser.add_argument('--model',
                        '-m',
                        default='shellconv',
                        help='Model to use')
    parser.add_argument('--setting',
                        '-x',
                        default='seg_shapenet',
                        help='Setting to use')
    parser.add_argument('--repeat_num',
                        '-r',
                        help='Repeat number',
                        type=int,
                        default=1)
    parser.add_argument('--save_ply',
                        '-s',
                        help='Save results as ply',
                        default=False)
    args = parser.parse_args()
    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    sample_num = setting.sample_num

    output_folder = setting.data_folder + '_pred_shellnet_' + str(
        args.repeat_num)
    category_list = [(category, int(label_num)) for (
        category,
        label_num) in [line.split() for line in open(setting.category, 'r')]]
    offset = 0
    category_range = dict()
    for category, category_label_seg_max in category_list:
        category_range[category] = (offset, offset + category_label_seg_max)
        offset = offset + category_label_seg_max
        folder = os.path.join(output_folder, category)
        if not os.path.exists(folder):
            os.makedirs(folder)

    input_filelist = []
    output_filelist = []
    output_ply_filelist = []
    for category in sorted(os.listdir(setting.data_folder)):
        data_category_folder = os.path.join(setting.data_folder, category)
        for filename in sorted(os.listdir(data_category_folder)):
            input_filelist.append(
                os.path.join(setting.data_folder, category, filename))
            output_filelist.append(
                os.path.join(output_folder, category, filename[0:-3] + 'seg'))
            output_ply_filelist.append(
                os.path.join(output_folder + '_ply', category,
                             filename[0:-3] + 'ply'))

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data, label, data_num, _, _ = data_utils.load_seg(setting.filelist_val)

    batch_num = data.shape[0]
    max_point_num = data.shape[1]
    batch_size = args.repeat_num * math.ceil(data.shape[1] / sample_num)

    print('{}-{:d} testing batches.'.format(datetime.now(), batch_num))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32,
                             shape=(batch_size, None, 2),
                             name="indices")
    is_training = tf.placeholder(tf.bool, name='is_training')
    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, max_point_num, setting.data_dim),
                             name='pts_fts')
    ######################################################################

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    if setting.data_dim > 3:
        points_sampled, _ = tf.split(pts_fts_sampled,
                                     [3, setting.data_dim - 3],
                                     axis=-1,
                                     name='split_points_features')
    else:
        points_sampled = pts_fts_sampled

    logits_op = model.get_model(points_sampled,
                                is_training,
                                setting.sconv_params,
                                setting.sdconv_params,
                                setting.fc_params,
                                sampling=setting.sampling,
                                weight_decay=setting.weight_decay,
                                bn_decay=None,
                                part_num=setting.num_class)

    probs_op = tf.nn.softmax(logits_op, name='probs')

    saver = tf.train.Saver()

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        # Load the model
        saver.restore(sess, args.load_ckpt)
        print('{}-Checkpoint loaded from {}!'.format(datetime.now(),
                                                     args.load_ckpt))

        indices_batch_indices = np.tile(
            np.reshape(np.arange(batch_size), (batch_size, 1, 1)),
            (1, sample_num, 1))
        for batch_idx in range(batch_num):
            points_batch = data[[batch_idx] * batch_size, ...]
            object_label = label[batch_idx]
            point_num = data_num[batch_idx]
            category = category_list[object_label][0]
            label_start, label_end = category_range[category]

            tile_num = math.ceil((sample_num * batch_size) / point_num)
            indices_shuffle = np.tile(np.arange(point_num),
                                      tile_num)[0:sample_num * batch_size]
            np.random.shuffle(indices_shuffle)
            indices_batch_shuffle = np.reshape(indices_shuffle,
                                               (batch_size, sample_num, 1))
            indices_batch = np.concatenate(
                (indices_batch_indices, indices_batch_shuffle), axis=2)

            probs = sess.run(
                [probs_op],
                feed_dict={
                    pts_fts: points_batch,
                    indices: indices_batch,
                    is_training: False,
                })
            probs_2d = np.reshape(probs, (sample_num * batch_size, -1))
            predictions = [(-1, 0.0)] * point_num
            for idx in range(sample_num * batch_size):
                point_idx = indices_shuffle[idx]
                probs = probs_2d[idx, label_start:label_end]
                confidence = np.amax(probs)
                seg_idx = np.argmax(probs)
                if confidence > predictions[point_idx][1]:
                    predictions[point_idx] = (seg_idx, confidence)

            labels = []
            with open(output_filelist[batch_idx], 'w') as file_seg:
                for seg_idx, _ in predictions:
                    file_seg.write('%d\n' % (seg_idx))
                    labels.append(seg_idx)

            # read the coordinates from the txt file for verification
            coordinates = [[float(value) for value in xyz.split(' ')]
                           for xyz in open(input_filelist[batch_idx], 'r')
                           if len(xyz.split(' ')) == 3]
            assert (point_num == len(coordinates))
            if args.save_ply:
                data_utils.save_ply_property(np.array(coordinates),
                                             np.array(labels), 6,
                                             output_ply_filelist[batch_idx])

            print('{}-[Testing]-Iter: {:06d} saved to {}'.format(
                datetime.now(), batch_idx, output_filelist[batch_idx]))
            sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
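
The inner loop of this example keeps, for every original point, the prediction with the highest softmax confidence across the repeated samplings. The merge rule on its own, with toy arrays and hypothetical sizes:

import numpy as np

point_num, n_samples = 5, 12
indices_shuffle = np.random.randint(0, point_num, n_samples)  # which point each sample covers
probs = np.random.rand(n_samples, 3)                          # per-sample class probabilities

predictions = [(-1, 0.0)] * point_num
for idx in range(n_samples):
    point_idx = indices_shuffle[idx]
    confidence = np.amax(probs[idx])
    seg_idx = np.argmax(probs[idx])
    if confidence > predictions[point_idx][1]:                # keep the most confident vote
        predictions[point_idx] = (seg_idx, confidence)
print(predictions)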