def _show_capsule_patches(reconstruction, pcd_list, highlight_caps, colors):
    """Draw one reconstructed point set with each capsule's patch painted.

    Args:
        reconstruction: CPU tensor of reconstructed points, one row per point
            (the caller transposes the decoder output into this layout).
        pcd_list: one reusable Open3D ``PointCloud`` per latent capsule.
        highlight_caps: capsule indices painted in color; all others grey.
        colors: color table; rows are consumed in highlight order.
    """
    caps_size = opt.latent_caps_size
    patch_size = opt.num_points // caps_size
    colored_re_pointcloud = PointCloud()
    color_id = 0
    for j in range(caps_size):
        # The reconstructed patch of a capsule is not stored contiguously:
        # capsule j owns rows j, j + caps_size, j + 2 * caps_size, ...
        # (patch_size rows in total), which is exactly this strided slice.
        current_patch = reconstruction[j:caps_size * patch_size:caps_size]
        pcd_list[j].points = Vector3dVector(
            np.asarray(current_patch, dtype=np.float64))
        if j in highlight_caps:
            pcd_list[j].paint_uniform_color(
                [colors[color_id, 0], colors[color_id, 1], colors[color_id, 2]])
            color_id += 1
        else:
            pcd_list[j].paint_uniform_color([0.8, 0.8, 0.8])
        colored_re_pointcloud += pcd_list[j]
    draw_geometries([colored_re_pointcloud])


def main():
    """Show per-capsule patches of PointCapsNet reconstructions on a test split.

    Reads its configuration from the module-level ``opt``. Every reconstructed
    point set is drawn with Open3D. A random subset of up to 10 capsules is
    painted in color, and the remaining capsule patches are grey.

    Raises:
        ValueError: if ``opt.dataset`` is not a supported dataset name.
    """
    # One reusable Open3D point cloud per latent capsule.
    pcd_list = [PointCloud() for _ in range(opt.latent_caps_size)]
    colors = plt.cm.tab20((np.arange(20)).astype(int))
    # Randomly selected capsules to highlight. Sampling without replacement
    # keeps the number of highlighted capsules at min(10, #capsules).
    n_highlight = min(10, opt.latent_caps_size)
    hight_light_caps = set(
        np.random.choice(opt.latent_caps_size, n_highlight,
                         replace=False).tolist())

    # Fall back to CPU instead of crashing on ``.cuda()`` when no GPU exists.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    # The 4th argument is the latent vector length, not the capsule count.
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model, map_location=device))
    else:
        print('pls set the model path')

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
    capsule_net.to(device)

    test_dataset = None
    test_dataloader = None
    if opt.dataset == 'shapenet_part':
        test_dataset = shapenet_part_loader.PartDataset(classification=True,
                                                        npoints=opt.num_points,
                                                        split='test')
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        test_dataset = shapenet_core13_loader.ShapeNet(normal=False,
                                                       npoints=opt.num_points,
                                                       train=False)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        test_dataset = shapenet_core55_loader.Shapnet55Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=False)
    else:
        raise ValueError('unsupported dataset: %s' % opt.dataset)

    def point_batches():
        """Yield point-cloud batches (as tensors) from whichever loader is active."""
        if test_dataloader is not None:
            for points, _ in test_dataloader:
                yield points
        else:
            # shapenet_core55 uses its own batching API instead of a DataLoader.
            while test_dataset.has_next_batch():
                _, points_ = test_dataset.next_batch()
                yield torch.from_numpy(points_)

    capsule_net.eval()
    # Inference only: no autograd graph is needed for visualization.
    with torch.no_grad():
        for points in point_batches():
            # Incomplete trailing batches are skipped.
            if points.size(0) < opt.batch_size:
                break
            points = points.transpose(2, 1).to(device)
            _, reconstructions = capsule_net(points)
            for pointset_id in range(opt.batch_size):
                prc_r_all = reconstructions[pointset_id].transpose(
                    1, 0).contiguous().cpu()
                _show_capsule_patches(prc_r_all, pcd_list, hight_light_caps,
                                      colors)
示例#2
0
def main():
    """Evaluate a trained PointCapsNet's reconstruction loss on a test split.

    Reads its configuration from the module-level ``opt``. It prints the loss
    of every full batch, then prints the mean over the batches actually
    evaluated. The trailing partial batch is skipped and not counted.

    Raises:
        ValueError: if ``opt.dataset`` is not a supported dataset name.
    """
    # Fall back to CPU instead of crashing on ``.cuda()`` when no GPU exists.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    # 4th argument: latent vector length (was misspelled ``latent_vecs_size``).
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model, map_location=device))
    else:
        print('pls set the model path')

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
    capsule_net.to(device)
    # DataParallel hides the wrapped network (and its ``loss``) behind ``.module``.
    net = (capsule_net.module
           if isinstance(capsule_net, torch.nn.DataParallel) else capsule_net)

    test_dataset = None
    test_dataloader = None
    if opt.dataset == 'shapenet_part':
        test_dataset = shapenet_part_loader.PartDataset(classification=True,
                                                        npoints=opt.num_points,
                                                        split='test')
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        test_dataset = shapenet_core13_loader.ShapeNet(normal=False,
                                                       npoints=opt.num_points,
                                                       train=False)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        test_dataset = shapenet_core55_loader.Shapnet55Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=False)
    else:
        raise ValueError('unsupported dataset: %s' % opt.dataset)

    def point_batches():
        """Yield ``(batch_id, points)`` from whichever loader is active."""
        if test_dataloader is not None:
            for batch_id, (points, _) in enumerate(test_dataloader):
                yield batch_id, points
        else:
            # shapenet_core55 uses its own batching API instead of a DataLoader.
            while test_dataset.has_next_batch():
                batch_id, points_ = test_dataset.next_batch()
                yield batch_id, torch.from_numpy(points_)

    capsule_net.eval()
    test_loss_sum = 0.0
    # Count evaluated batches explicitly. The previous code divided by
    # len(test_dataloader), which includes the skipped partial batch, and that
    # name does not exist at all in the core55 branch.
    n_batches = 0
    with torch.no_grad():
        for batch_id, points in point_batches():
            if points.size(0) < opt.batch_size:
                break
            points = points.transpose(2, 1).to(device)
            latent_caps, reconstructions = capsule_net(points)
            test_loss = net.loss(points, reconstructions)
            test_loss_sum += test_loss.item()
            n_batches += 1
            print('accumalate of batch %d loss is : %f' %
                  (batch_id, test_loss.item()))

    if n_batches == 0:
        print('no full batch was evaluated; test loss is undefined')
        return
    test_loss_sum = test_loss_sum / float(n_batches)
    print('test loss is : %f' % (test_loss_sum))
def main():
    """Encode a dataset split with PointCapsNet and save latent capsules to HDF5.

    Reads its configuration from the module-level ``opt``. The output is
    ``<BASE_DIR>/../../dataset/<opt.dataset>/latent_caps/saved_{train,test}_wo_part_label.h5``.
    Any existing file at that path is replaced. The file holds two datasets,
    grown batch by batch:

    * ``data``: float32, shape (N, latent_caps_size, latent_vec_size)
    * ``cls_label``: uint8, shape (N,)

    Raises:
        ValueError: if ``opt.dataset`` is not a supported dataset name.
    """
    # Fall back to CPU instead of crashing on ``.cuda()`` when no GPU exists.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model, map_location=device))
    else:
        print('pls set the model path')

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
    capsule_net.to(device)

    dataset = None
    dataloader = None
    if opt.dataset == 'shapenet_part':
        split = 'train' if opt.save_training else 'test'
        dataset = shapenet_part_loader.PartDataset(classification=True,
                                                   npoints=opt.num_points,
                                                   split=split)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        dataset = shapenet_core13_loader.ShapeNet(normal=False,
                                                  npoints=opt.num_points,
                                                  train=opt.save_training)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        dataset = shapenet_core55_loader.Shapnet55Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=opt.save_training)
    elif opt.dataset == 'modelnet40':
        dataset = modelnet40_loader.ModelNetH5Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=opt.save_training)
    else:
        raise ValueError('unsupported dataset: %s' % opt.dataset)

    def point_batches():
        """Yield ``(batch_id, points, cls_label_or_None)`` from the active loader."""
        if dataloader is not None:
            for batch_id, (points, cls_label) in enumerate(dataloader):
                yield batch_id, points, cls_label
        else:
            # next_batch() here yields no class labels (the old code referenced
            # an undefined ``cls_label`` in this branch and raised NameError).
            while dataset.has_next_batch():
                batch_id, points_ = dataset.next_batch()
                yield batch_id, torch.from_numpy(points_), None

    # init saving process
    data_size = 0
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    out_file_path = os.path.join(dataset_main_path, opt.dataset, 'latent_caps')
    os.makedirs(out_file_path, exist_ok=True)

    if opt.save_training:
        out_file_name = out_file_path + "/saved_train_wo_part_label.h5"
    else:
        out_file_name = out_file_path + "/saved_test_wo_part_label.h5"
    if os.path.exists(out_file_name):
        os.remove(out_file_name)

    capsule_net.eval()
    warned_no_labels = False
    # ``with`` guarantees the HDF5 file is closed even if encoding fails.
    with h5py.File(out_file_name, 'w', libver='latest') as fw:
        dset = fw.create_dataset(
            "data", (1, opt.latent_caps_size, opt.latent_vec_size),
            maxshape=(None, opt.latent_caps_size, opt.latent_vec_size),
            dtype='<f4')
        dset_c = fw.create_dataset("cls_label", (1, ),
                                   maxshape=(None, ),
                                   dtype='uint8')
        fw.swmr_mode = True

        with torch.no_grad():
            for batch_id, points, cls_label in point_batches():
                batch_size = points.size(0)
                if batch_size < opt.batch_size:
                    break
                points = points.transpose(2, 1).to(device)
                latent_caps, reconstructions = capsule_net(points)

                # Grow both datasets, then write this batch into the new rows.
                start = data_size
                data_size += batch_size
                dset.resize(
                    (data_size, opt.latent_caps_size, opt.latent_vec_size))
                dset_c.resize((data_size, ))
                dset[start:data_size, :, :] = latent_caps.cpu().numpy()
                if cls_label is not None:
                    dset_c[start:data_size] = cls_label.reshape(-1).numpy()
                elif not warned_no_labels:
                    print('no class labels available for %s; cls_label is '
                          'left at the HDF5 fill value (0)' % opt.dataset)
                    warned_no_labels = True

                dset.flush()
                dset_c.flush()
                print('accumalate of batch %d, and datasize is %d ' %
                      ((batch_id), (dset.shape[0])))
def main():
    """Train a PointCapsNet auto-encoder on the ShapeNet-part training split.

    Reads its configuration from the module-level ``opt``. The learning rate
    follows a step schedule: 1e-4 for epochs < 50, 1e-5 for epochs < 150, and
    1e-6 afterwards. Every 5th epoch the weights of the unwrapped network are
    saved under ``opt.outf``.
    """
    # Fall back to CPU instead of crashing on ``.cuda()`` when no GPU exists.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    # The 4th argument is the latent vector length, not the capsule count.
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model, map_location=device))

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
    capsule_net.to(device)
    # DataParallel hides the wrapped network (and its ``loss``) behind ``.module``.
    net = (capsule_net.module
           if isinstance(capsule_net, torch.nn.DataParallel) else capsule_net)

    # create folder to save trained models
    os.makedirs(opt.outf, exist_ok=True)

    train_dataset = shapenet_part_loader.PartDataset(classification=True,
                                                     npoints=opt.num_points,
                                                     split='train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=4)

    # Build one optimizer and only adjust its learning rate per epoch.
    # Re-creating Adam every epoch (as before) threw away its moment estimates.
    optimizer = optim.Adam(capsule_net.parameters(), lr=0.0001)

    for epoch in range(opt.n_epochs):
        if epoch < 50:
            lr = 0.0001
        elif epoch < 150:
            lr = 0.00001
        else:
            lr = 0.000001
        for group in optimizer.param_groups:
            group['lr'] = lr

        # train
        capsule_net.train()
        train_loss_sum = 0.0
        n_batches = 0  # batches actually trained (partial batch excluded)
        for batch_id, (points, _) in enumerate(train_dataloader):
            if points.size(0) < opt.batch_size:
                break
            points = points.transpose(2, 1).to(device)

            optimizer.zero_grad()
            codewords, reconstructions = capsule_net(points)
            train_loss = net.loss(points, reconstructions)
            train_loss.backward()
            optimizer.step()
            train_loss_sum += train_loss.item()
            n_batches += 1

            if batch_id % 50 == 0:
                print('bactch_no:%d/%d, train_loss: %f ' %
                      (batch_id, len(train_dataloader), train_loss.item()))

        print('Average train loss of epoch %d : %f' %
              (epoch, (train_loss_sum / max(n_batches, 1))))

        if epoch % 5 == 0:
            # The "vec" part of the name now records the vector size rather
            # than repeating the capsule count.
            dict_name = (opt.outf + '/' + opt.dataset + '_dataset_' + '_' +
                         str(opt.latent_caps_size) + 'caps_' +
                         str(opt.latent_vec_size) + 'vec_' + str(epoch) +
                         '.pth')
            torch.save(net.state_dict(), dict_name)
示例#5
0
                    default=16,
                    help='scale of prim_vec')
# Latent capsule layer geometry: how many capsules, and how long each vector is.
parser.add_argument('--latent_caps_size', type=int, default=64,
                    help='number of latent_caps')
parser.add_argument('--latent_vec_size', type=int, default=64,
                    help='scale of latent_vec')

opt = parser.parse_args()
print(opt)

# Build the auto-encoder and restore its weights when a checkpoint is given.
capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                           opt.latent_caps_size, opt.latent_vec_size,
                           opt.num_points)
if opt.model == '':
    print('pls set the model path')
else:
    capsule_net.load_state_dict(torch.load(opt.model))

# Optionally move to GPU, then switch to inference mode.
capsule_net = (capsule_net.cuda() if USE_CUDA else capsule_net).eval()

# One Open3D point cloud per latent capsule.
pcd_list = [PointCloud() for _ in range(opt.latent_caps_size)]
示例#6
0
# Evaluation setup for the 'Table' class of the ShapeNet-part test split:
# data loader, pretrained capsule network, loss object, metric accumulators.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_dset = shapenet_part_loader.PartDataset(
    root='../dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
    classification=True,
    class_choice='Table',
    npoints=opt.num_points,
    split='test',
)
test_dataloader = torch.utils.data.DataLoader(
    test_dset,
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=int(opt.workers),
)
length = len(test_dataloader)

# NOTE(review): latent_caps_size is passed for both the capsule count and the
# capsule vector length; sibling scripts pass latent_vec_size as the 4th
# argument. Confirm which one this checkpoint expects.
capsule_net = PointCapsNet(
    opt.prim_caps_size,
    opt.prim_vec_size,
    opt.latent_caps_size,
    opt.latent_caps_size,
    opt.num_points,
)
# map_location keeps every loaded tensor on CPU storage; the weights live
# under the checkpoint's 'state_dict' key.
capsule_net.load_state_dict(
    torch.load(opt.model,
               map_location=lambda storage, location: storage)['state_dict'])
print("Let's use", torch.cuda.device_count(), "GPUs!")
capsule_net.to(device)
capsule_net = torch.nn.DataParallel(capsule_net)
capsule_net.eval()

criterion_PointLoss = PointLoss_test()
errG_min = 100
# Metric accumulators, presumably updated by an evaluation loop not in view.
n = CD = Gt_Pre = Pre_Gt = 0
def _split_by_part(pts, labels, part_no, num_points):
    """Split the first ``num_points`` rows of ``pts`` into (part, rest).

    Rows whose label equals ``part_no`` go to the first tensor and all other
    rows go to the second. Point order is preserved in both.
    """
    labels = labels.reshape(-1)[:num_points]
    pts = pts[:num_points]
    return pts[labels == part_no], pts[labels != part_no]


def _paint_part_cloud(pcd, points, color, transform):
    """Load ``points`` into ``pcd``, paint it ``color`` and apply ``transform``.

    Returns ``pcd`` so the caller can accumulate it into a scene cloud.
    """
    pcd.points = Vector3dVector(np.asarray(points, dtype=np.float64))
    pcd.paint_uniform_color(color)
    pcd.transform(transform)
    return pcd


def main():
    """Visualize swapping one part between two shapes: GT vs. capsule-based.

    For every batch, the first two shapes form a source/target pair. One scene
    is drawn containing:

    * both ground-truth shapes, with part ``part_replace_no`` highlighted;
    * a ground-truth cut-and-paste, where each shape receives the other's part;
    * a capsule-based swap. Capsules that the segmentation net assigns to the
      part in *both* shapes are exchanged in latent space and then decoded.

    Reads its configuration from the module-level ``opt`` and ``USE_CUDA``.

    Raises:
        ValueError: if ``opt.batch_size`` < 2 (a source/target pair is needed).
    """
    if opt.batch_size < 2:
        raise ValueError('part replacement needs batch_size >= 2 '
                         '(a source/target pair)')

    cat_no = {
        'Airplane': 0,
        'Bag': 1,
        'Cap': 2,
        'Car': 3,
        'Chair': 4,
        'Earphone': 5,
        'Guitar': 6,
        'Knife': 7,
        'Lamp': 8,
        'Laptop': 9,
        'Motorbike': 10,
        'Mug': 11,
        'Pistol': 12,
        'Rocket': 13,
        'Skateboard': 14,
        'Table': 15
    }

    # Generate the part-label correspondence for each category:
    # object2setofoid maps a category id to its list of overall part ids.
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    benchmark_root = os.path.join(
        dataset_main_path, opt.dataset,
        'shapenetcore_partanno_segmentation_benchmark_v0')
    oid2cpid_file_name = os.path.join(
        benchmark_root, 'shapenet_part_overallid_to_catid_partid.json')
    with open(oid2cpid_file_name, 'r') as f:
        oid2cpid = json.load(f)
    object2setofoid = {}
    for idx, (objid, _pid) in enumerate(oid2cpid):
        object2setofoid.setdefault(objid, []).append(idx)

    # Second whitespace-separated column of each line is the category id.
    all_obj_cat_file = os.path.join(benchmark_root,
                                    'synsetoffset2category.txt')
    with open(all_obj_cat_file, 'r') as fin:
        objcats = [line.rstrip().split()[1] for line in fin]

    state_dict = torch.load(opt.model) if opt.model != '' else None

    # Load the point capsule auto-encoder.
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)
    if state_dict is not None:
        capsule_net.load_state_dict(state_dict)
    if USE_CUDA:
        capsule_net = torch.nn.DataParallel(capsule_net).cuda()
    capsule_net = capsule_net.eval()

    # Load the decoder-only model. strict=False ignores the encoder weights.
    capsule_net_decoder = PointCapsNetDecoder(opt.prim_caps_size,
                                              opt.prim_vec_size,
                                              opt.latent_caps_size,
                                              opt.latent_vec_size,
                                              opt.num_points)
    if state_dict is not None:
        capsule_net_decoder.load_state_dict(state_dict, strict=False)
    if USE_CUDA:
        capsule_net_decoder = capsule_net_decoder.cuda()
    capsule_net_decoder = capsule_net_decoder.eval()

    # Load the capsule-wise part segmentation model.
    caps_seg_net = CapsSegNet(latent_caps_size=opt.latent_caps_size,
                              latent_vec_size=opt.latent_vec_size,
                              num_classes=opt.n_classes)
    if opt.part_model != '':
        caps_seg_net.load_state_dict(torch.load(opt.part_model))
    if USE_CUDA:
        caps_seg_net = caps_seg_net.cuda()
    caps_seg_net = caps_seg_net.eval()

    test_dataset = shapenet_part_loader.PartDataset(
        classification=False,
        class_choice=opt.class_choice,
        npoints=opt.num_points,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=opt.batch_size,
                                                  shuffle=True,
                                                  num_workers=4)

    # Reusable Open3D containers: GT (part, rest) for each shape, GT
    # cut-and-paste, and one cloud per capsule for the capsule-based swap.
    pcd_gt_source = [PointCloud() for _ in range(2)]
    pcd_gt_target = [PointCloud() for _ in range(2)]
    pcd_gt_replace_source = [PointCloud() for _ in range(2)]
    pcd_gt_replace_target = [PointCloud() for _ in range(2)]
    pcd_caps_replace_source = [
        PointCloud() for _ in range(opt.latent_caps_size)
    ]
    pcd_caps_replace_target = [
        PointCloud() for _ in range(opt.latent_caps_size)
    ]

    # Rigid transforms that lay the six shapes out side by side in one view.
    rotation_angle = np.pi / 2
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)
    flip_transform_gt_s = [[1, 0, 0, -3], [0, cosval, -sinval, -1],
                           [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_gt_t = [[1, 0, 0, -3], [0, cosval, -sinval, 1],
                           [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_gt_re_s = [[1, 0, 0, 0], [0, cosval, -sinval, -1],
                              [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_gt_re_t = [[1, 0, 0, 0], [0, cosval, -sinval, 1],
                              [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_caps_re_s = [[1, 0, 0, 3], [0, cosval, -sinval, -1],
                                [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_caps_re_t = [[1, 0, 0, 3], [0, cosval, -sinval, 1],
                                [0, sinval, cosval, 0], [0, 0, 0, 1]]

    colors = plt.cm.tab20((np.arange(20)).astype(int))
    source_color = [colors[5, 0], colors[5, 1], colors[5, 2]]
    target_color = [colors[6, 0], colors[6, 1], colors[6, 2]]
    grey = [0.8, 0.8, 0.8]
    part_replace_no = 1  # the part that is replaced
    caps_size = opt.latent_caps_size
    patch_size = opt.num_points // caps_size

    with torch.no_grad():
        for batch_id, data in enumerate(test_dataloader):
            points, part_label, cls_label = data
            if opt.class_choice is not None:
                cls_label[:] = cat_no[opt.class_choice]

            if points.size(0) < opt.batch_size:
                break
            all_model_pcd = PointCloud()

            src_part, src_rest = _split_by_part(points[0], part_label[0],
                                                part_replace_no,
                                                opt.num_points)
            tgt_part, tgt_rest = _split_by_part(points[1], part_label[1],
                                                part_replace_no,
                                                opt.num_points)

            # GT shapes with the chosen part colored.
            all_model_pcd += _paint_part_cloud(pcd_gt_source[0], src_part,
                                               source_color,
                                               flip_transform_gt_s)
            all_model_pcd += _paint_part_cloud(pcd_gt_source[1], src_rest,
                                               grey, flip_transform_gt_s)
            all_model_pcd += _paint_part_cloud(pcd_gt_target[0], tgt_part,
                                               target_color,
                                               flip_transform_gt_t)
            all_model_pcd += _paint_part_cloud(pcd_gt_target[1], tgt_rest,
                                               grey, flip_transform_gt_t)

            # GT cut-and-paste: each shape receives the other's part.
            all_model_pcd += _paint_part_cloud(pcd_gt_replace_source[0],
                                               tgt_part, target_color,
                                               flip_transform_gt_re_s)
            all_model_pcd += _paint_part_cloud(pcd_gt_replace_source[1],
                                               src_rest, grey,
                                               flip_transform_gt_re_s)
            all_model_pcd += _paint_part_cloud(pcd_gt_replace_target[0],
                                               src_part, source_color,
                                               flip_transform_gt_re_t)
            all_model_pcd += _paint_part_cloud(pcd_gt_replace_target[1],
                                               tgt_rest, grey,
                                               flip_transform_gt_re_t)

            # Capsule-based replacement: encode the batch.
            points_ = points.transpose(2, 1)
            if USE_CUDA:
                points_ = points_.cuda()
            latent_caps, _ = capsule_net(points_)

            # One-hot category of the pair, repeated for every capsule.
            cur_label_one_hot = np.zeros((2, 16), dtype=np.float32)
            for i in range(2):
                cur_label_one_hot[i, cls_label[i]] = 1
            expand = torch.from_numpy(cur_label_one_hot).unsqueeze(2).expand(
                2, 16, caps_size).transpose(1, 2).to(latent_caps.device)

            # Predict the part label of each capsule of the pair.
            latent_caps_with_one_hot = torch.cat((latent_caps[:2], expand),
                                                 2).transpose(2, 1)
            output_digit = caps_seg_net(latent_caps_with_one_hot)
            for i in range(2):
                iou_oids = object2setofoid[objcats[cls_label[i]]]
                # Suppress part labels that do not belong to this category.
                non_cat_labels = list(
                    set(np.arange(50)).difference(set(iou_oids)))
                mini = torch.min(output_digit[i, :, :])
                output_digit[i, :, non_cat_labels] = mini - 1000
            pred_choice = output_digit.cpu().max(2)[1]

            # Capsules assigned to the replaced part in BOTH shapes. NOTE:
            # ``iou_oids`` is the target's (i == 1) category; both shapes
            # only share it when opt.class_choice is set.
            part_no = iou_oids[part_replace_no]
            part_viz = [
                caps_no for caps_no in range(caps_size)
                if pred_choice[0, caps_no] == part_no
                and pred_choice[1, caps_no] == part_no
            ]
            part_viz_set = set(part_viz)

            # Swap those capsules between the two shapes and decode.
            swap_idx = torch.tensor(part_viz, dtype=torch.long,
                                    device=latent_caps.device)
            latent_caps_replace = latent_caps.clone()
            latent_caps_replace[0, swap_idx] = latent_caps[1, swap_idx]
            latent_caps_replace[1, swap_idx] = latent_caps[0, swap_idx]
            reconstructions_replace = capsule_net_decoder(
                latent_caps_replace).transpose(1, 2).cpu()

            for j in range(caps_size):
                # The patch of capsule j is interleaved in the output: rows
                # j, j + caps_size, ... (patch_size rows in total).
                patch_s = reconstructions_replace[0][j:caps_size *
                                                     patch_size:caps_size]
                patch_t = reconstructions_replace[1][j:caps_size *
                                                     patch_size:caps_size]
                swapped = j in part_viz_set
                all_model_pcd += _paint_part_cloud(
                    pcd_caps_replace_source[j], patch_s,
                    target_color if swapped else grey,
                    flip_transform_caps_re_s)
                all_model_pcd += _paint_part_cloud(
                    pcd_caps_replace_target[j], patch_t,
                    source_color if swapped else grey,
                    flip_transform_caps_re_t)
            draw_geometries([all_model_pcd])
示例#8
0
def main():
    """Encode ShapeNet-Part shapes into latent capsules and save them to HDF5.

    Every full batch of the ``shapenet_part`` split (``train`` when
    ``opt.save_training`` is set, ``test`` otherwise) is passed through the
    pre-trained PointCapsNet auto-encoder. Each latent capsule then gets a
    part label:

    * every reconstructed point takes the part label of its nearest input
      point (KD-tree search with k=1);
    * reconstructed point ``p`` is attributed to capsule
      ``p % opt.latent_caps_size``;
    * a capsule's label is the most frequent label among its points.

    The capsules, capsule part labels and object class labels are appended to
    the datasets ``data``, ``part_label`` and ``cls_label`` of
    ``<BASE_DIR>/../../dataset/<opt.dataset>/latent_caps/
    saved_{train,test}_with_part_label.h5``; an existing file is replaced.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'shapenet_part'``, the only
            dataset this function builds a loader for.
    """
    # Fall back to the CPU instead of failing on ``.cuda()`` without a GPU.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    # The fourth argument is the latent *vector* size; it must agree with the
    # (latent_caps_size, latent_vec_size) layout of the HDF5 "data" dataset.
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        # map_location lets a checkpoint saved on GPU load on a CPU-only box.
        capsule_net.load_state_dict(
            torch.load(opt.model, map_location=device))

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
        capsule_net.to(device)

    # Validate before touching the output file: an unsupported dataset used
    # to delete the previous .h5 and then crash on an undefined loader.
    if opt.dataset != 'shapenet_part':
        raise ValueError("unsupported dataset %r: only 'shapenet_part' "
                         "provides per-point part labels" % opt.dataset)
    split = 'train' if opt.save_training else 'test'
    dataset = shapenet_part_loader.PartDataset(classification=False,
                                               npoints=opt.num_points,
                                               split=split)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=4)

    # init saving process
    pcd = PointCloud()
    data_size = 0
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    out_file_path = os.path.join(dataset_main_path, opt.dataset, 'latent_caps')
    os.makedirs(out_file_path, exist_ok=True)
    if opt.save_training:
        out_file_name = out_file_path + "/saved_train_with_part_label.h5"
    else:
        out_file_name = out_file_path + "/saved_test_with_part_label.h5"
    if os.path.exists(out_file_name):
        os.remove(out_file_name)

    capsule_net.eval()
    caps_shape = (opt.latent_caps_size, opt.latent_vec_size)

    # ``with`` guarantees the HDF5 file is closed even if a batch raises.
    with h5py.File(out_file_name, 'w', libver='latest') as fw:
        dset = fw.create_dataset("data", (1, ) + caps_shape,
                                 maxshape=(None, ) + caps_shape,
                                 dtype='<f4')
        dset_s = fw.create_dataset("part_label", (1, opt.latent_caps_size),
                                   maxshape=(None, opt.latent_caps_size),
                                   dtype='uint8')
        dset_c = fw.create_dataset("cls_label", (1, ),
                                   maxshape=(None, ),
                                   dtype='uint8')
        fw.swmr_mode = True

        for batch_id, data in enumerate(dataloader):
            points, part_label, cls_label = data
            # Only full batches are processed (cap_part_count below is sized
            # by opt.batch_size).
            if points.size(0) < opt.batch_size:
                break
            batch_len = points.size(0)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                latent_caps, reconstructions = capsule_net(
                    points.transpose(2, 1).to(device))

            # For each reconstructed point, find the nearest input point and
            # take its part label; count those labels per capsule and keep
            # the majority as the capsule's part label.
            reconstructions = reconstructions.transpose(1, 2).data.cpu()
            cap_part_count = torch.zeros(
                [opt.batch_size, opt.latent_caps_size, opt.n_classes],
                dtype=torch.int64)
            for batch_no in range(batch_len):
                pcd.points = Vector3dVector(points[batch_no, ])
                pcd_tree = KDTreeFlann(pcd)
                for point_id in range(opt.num_points):
                    _, idx, _ = pcd_tree.search_knn_vector_3d(
                        reconstructions[batch_no, point_id, :], 1)
                    point_part_label = part_label[batch_no, idx]
                    caps_no = point_id % opt.latent_caps_size
                    cap_part_count[batch_no, caps_no, point_part_label] += 1
            _, cap_part_label = torch.max(cap_part_count, 2)

            # Grow the datasets and append this batch at rows [start, data_size).
            start, data_size = data_size, data_size + batch_len
            dset.resize((data_size, ) + caps_shape)
            dset_s.resize((data_size, opt.latent_caps_size))
            dset_c.resize((data_size, ))

            dset[start:data_size, :, :] = latent_caps.cpu().detach().numpy()
            dset_s[start:data_size] = cap_part_label.numpy()
            dset_c[start:data_size] = cls_label.squeeze().numpy()

            # Flush per batch so SWMR-mode readers can see the new rows.
            dset.flush()
            dset_s.flush()
            dset_c.flush()
            print('accumalate of batch %d, and datasize is %d ' %
                  ((batch_id), (dset.shape[0])))
# Example #9
0
def main():
    """Evaluate and visualise capsule-wise part segmentation on ShapeNet-Part.

    For every full batch of the ``test`` split (optionally restricted to
    ``opt.class_choice``):

    1. encode the shapes with the pre-trained PointCapsNet auto-encoder into
       latent capsules and a reconstruction;
    2. append the one-hot object-category label to every capsule and let
       ``CapsSegNet`` predict one part label per capsule, suppressing part
       ids that do not belong to the shape's category;
    3. give each reconstructed point its capsule's label (point ``p`` comes
       from capsule ``p % opt.latent_caps_size``), then give each input
       point the median-low label of its 10 nearest reconstructed points;
    4. print the running point-wise accuracy against the ground truth and
       draw the ground-truth and predicted colourings side by side.
    """
    cat_no = {
        'Airplane': 0,
        'Bag': 1,
        'Cap': 2,
        'Car': 3,
        'Chair': 4,
        'Earphone': 5,
        'Guitar': 6,
        'Knife': 7,
        'Lamp': 8,
        'Laptop': 9,
        'Motorbike': 10,
        'Mug': 11,
        'Pistol': 12,
        'Rocket': 13,
        'Skateboard': 14,
        'Table': 15
    }

    # Build the category -> overall part id correspondence.
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    benchmark_dir = os.path.join(
        dataset_main_path, opt.dataset,
        'shapenetcore_partanno_segmentation_benchmark_v0')
    with open(os.path.join(benchmark_dir,
                           'shapenet_part_overallid_to_catid_partid.json'),
              'r') as f:
        oid2cpid = json.load(f)
    # object2setofoid[category id] -> overall part ids (indices of oid2cpid)
    object2setofoid = {}
    for oid, (objid, _pid) in enumerate(oid2cpid):
        object2setofoid.setdefault(objid, []).append(oid)

    # objcats[k]: category id (second column) of class index k
    with open(os.path.join(benchmark_dir, 'synsetoffset2category.txt'),
              'r') as fin:
        objcats = [line.split()[1] for line in fin]

    colors = plt.cm.tab10((np.arange(10)).astype(int))

    # load the model for point caps auto encoder
    capsule_net = PointCapsNet(
        opt.prim_caps_size,
        opt.prim_vec_size,
        opt.latent_caps_size,
        opt.latent_vec_size,
        opt.num_points,
    )
    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model))
    if USE_CUDA:
        capsule_net = torch.nn.DataParallel(capsule_net).cuda()
    capsule_net = capsule_net.eval()

    # load the model for capsule-wise part segmentation
    caps_seg_net = CapsSegNet(latent_caps_size=opt.latent_caps_size,
                              latent_vec_size=opt.latent_vec_size,
                              num_classes=opt.n_classes)
    if opt.part_model != '':
        caps_seg_net.load_state_dict(torch.load(opt.part_model))
    if USE_CUDA:
        caps_seg_net = caps_seg_net.cuda()
    caps_seg_net = caps_seg_net.eval()

    train_dataset = shapenet_part_loader.PartDataset(
        classification=False,
        class_choice=opt.class_choice,
        npoints=opt.num_points,
        split='test')
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=4)

    pcd = PointCloud()
    pcd_colored = PointCloud()
    pcd_ori_colored = PointCloud()
    rotation_angle = -np.pi / 4
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)
    flip_transforms = [[cosval, 0, sinval, -1], [0, 1, 0, 0],
                       [-sinval, 0, cosval, 0], [0, 0, 0, 1]]
    flip_transformt = [[cosval, 0, sinval, 1], [0, 1, 0, 0],
                       [-sinval, 0, cosval, 0], [0, 0, 0, 1]]

    points_per_caps = opt.num_points // opt.latent_caps_size
    covered_points = points_per_caps * opt.latent_caps_size

    correct_sum = 0
    for batch_id, data in enumerate(train_dataloader):
        points, part_label, cls_label = data
        if opt.class_choice is not None:
            cls_label[:] = cat_no[opt.class_choice]

        if points.size(0) < opt.batch_size:
            break

        # One-hot category label per shape; also map the ground-truth part
        # labels from per-category indices to overall part ids so they are
        # comparable with the network's predictions.
        cur_label_one_hot = np.zeros((opt.batch_size, 16), dtype=np.float32)
        for i in range(opt.batch_size):
            cur_label_one_hot[i, cls_label[i]] = 1
            iou_oids = object2setofoid[objcats[cls_label[i]]]
            for j in range(opt.num_points):
                part_label[i, j] = iou_oids[part_label[i, j]]
        expand = torch.from_numpy(cur_label_one_hot).unsqueeze(2).expand(
            opt.batch_size, 16, opt.latent_caps_size).transpose(1, 2)

        points_ = points.transpose(2, 1)
        if USE_CUDA:
            points_ = points_.cuda()

        # Encode with the pre-trained AE, concatenate the one-hot label to
        # every capsule and predict a part class per capsule.
        with torch.no_grad():
            latent_caps, reconstructions = capsule_net(points_)
            latent_caps = torch.cat(
                (latent_caps, expand.to(latent_caps.device)), 2)
            output = caps_seg_net(latent_caps.transpose(2, 1))

        # Rule out part ids of other categories (50 part ids in total over
        # the 16 categories).
        for i in range(opt.batch_size):
            iou_oids = object2setofoid[objcats[cls_label[i]]]
            non_cat_labels = sorted(set(range(50)).difference(iou_oids))
            mini = torch.min(output[i, :, :])
            output[i, :, non_cat_labels] = mini - 1000
        pred_choice = output.data.cpu().max(2)[1]
        reconstructions = reconstructions.transpose(1, 2).data.cpu()

        # Capsule j reconstructs points j, j + C, j + 2C, ... so tiling the
        # (batch, C) capsule labels yields the per-point labels.
        reconstructions_part_label = torch.zeros(
            [opt.batch_size, opt.num_points], dtype=torch.int64)
        reconstructions_part_label[:, :covered_points] = pred_choice.repeat(
            1, points_per_caps)

        # Transfer labels from the reconstruction to the input points via
        # their 10 nearest reconstructed neighbours. median_low always picks
        # one of the neighbours' labels; statistics.median would average the
        # two middle labels of the even-sized set, producing a part id that
        # none of the neighbours carries.
        pred_ori_pointcloud_part_label = torch.zeros(
            [opt.batch_size, opt.num_points], dtype=torch.int64)
        for point_set_no in range(opt.batch_size):
            pcd.points = Vector3dVector(reconstructions[point_set_no, ])
            pcd_tree = KDTreeFlann(pcd)
            for point_id in range(opt.num_points):
                _, idx, _ = pcd_tree.search_knn_vector_3d(
                    points[point_set_no, point_id, :], 10)
                local_patch_labels = reconstructions_part_label[
                    point_set_no, idx].tolist()
                pred_ori_pointcloud_part_label[
                    point_set_no,
                    point_id] = statistics.median_low(local_patch_labels)

        # calculate the accuracy with the GT
        correct = pred_ori_pointcloud_part_label.eq(
            part_label.data.cpu()).cpu().sum()
        correct_sum = correct_sum + correct.item()
        print(' accuracy is: %f' %
              (correct_sum / float(opt.batch_size *
                                   (batch_id + 1) * opt.num_points)))

        # viz the part segmentation: colour index = overall part id minus
        # the first part id of the shape's category
        for point_set_no in range(opt.batch_size):
            first_oid = object2setofoid[objcats[cls_label[point_set_no]]][0]
            pred_parts = (pred_ori_pointcloud_part_label[point_set_no] -
                          first_oid).numpy()
            gt_parts = (part_label[point_set_no].long() - first_oid).numpy()

            pcd_colored.points = Vector3dVector(points[point_set_no, ])
            pcd_colored.colors = Vector3dVector(colors[pred_parts, :3])

            pcd_ori_colored.points = Vector3dVector(points[point_set_no, ])
            pcd_ori_colored.colors = Vector3dVector(colors[gt_parts, :3])

            # transform the pcds in order to viz both point clouds
            pcd_ori_colored.transform(flip_transforms)
            pcd_colored.transform(flip_transformt)
            draw_geometries([pcd_ori_colored, pcd_colored])
def main():
    """Interpolate one semantic part between two shapes in latent-capsule space.

    Only the first two shapes of each full batch are used: shape 1 is the
    source and shape 0 the target. Both are encoded with the pre-trained
    PointCapsNet auto-encoder, and ``CapsSegNet`` predicts a part label per
    capsule (category one-hot appended, other categories' part ids
    suppressed). The capsules labelled with part ``part_inter_no`` in *both*
    shapes are linearly interpolated from source to target in
    ``inter_models_number`` steps; all other capsules keep the source values.
    Each interpolated capsule set is decoded with ``PointCapsNetDecoder``,
    and source, target and interpolations are drawn together with the
    interpolated capsules' patches highlighted.

    Raises:
        ValueError: if ``opt.batch_size`` is smaller than 2.
    """
    num_shapes = 2  # shape 1 = interpolation source, shape 0 = target
    if opt.batch_size < num_shapes:
        raise ValueError('part interpolation needs opt.batch_size >= %d, '
                         'got %d' % (num_shapes, opt.batch_size))

    cat_no = {
        'Airplane': 0,
        'Bag': 1,
        'Cap': 2,
        'Car': 3,
        'Chair': 4,
        'Earphone': 5,
        'Guitar': 6,
        'Knife': 7,
        'Lamp': 8,
        'Laptop': 9,
        'Motorbike': 10,
        'Mug': 11,
        'Pistol': 12,
        'Rocket': 13,
        'Skateboard': 14,
        'Table': 15
    }

    # Build the category -> overall part id correspondence.
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    benchmark_dir = os.path.join(
        dataset_main_path, opt.dataset,
        'shapenetcore_partanno_segmentation_benchmark_v0')
    with open(os.path.join(benchmark_dir,
                           'shapenet_part_overallid_to_catid_partid.json'),
              'r') as f:
        oid2cpid = json.load(f)
    # object2setofoid[category id] -> overall part ids (indices of oid2cpid)
    object2setofoid = {}
    for oid, (objid, _pid) in enumerate(oid2cpid):
        object2setofoid.setdefault(objid, []).append(oid)

    # objcats[k]: category id (second column) of class index k
    with open(os.path.join(benchmark_dir, 'synsetoffset2category.txt'),
              'r') as fin:
        objcats = [line.split()[1] for line in fin]

    # load the model for point caps auto encoder
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)
    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model))
    if USE_CUDA:
        capsule_net = torch.nn.DataParallel(capsule_net).cuda()
    capsule_net = capsule_net.eval()

    # load the model for only decoding
    capsule_net_decoder = PointCapsNetDecoder(opt.prim_caps_size,
                                              opt.prim_vec_size,
                                              opt.latent_caps_size,
                                              opt.latent_vec_size,
                                              opt.num_points)
    if opt.model != '':
        capsule_net_decoder.load_state_dict(torch.load(opt.model),
                                            strict=False)
    if USE_CUDA:
        capsule_net_decoder = capsule_net_decoder.cuda()
    capsule_net_decoder = capsule_net_decoder.eval()

    # load the model for capsule-wise part segmentation
    caps_seg_net = CapsSegNet(latent_caps_size=opt.latent_caps_size,
                              latent_vec_size=opt.latent_vec_size,
                              num_classes=opt.n_classes)
    if opt.part_model != '':
        caps_seg_net.load_state_dict(torch.load(opt.part_model))
    if USE_CUDA:
        caps_seg_net = caps_seg_net.cuda()
    caps_seg_net = caps_seg_net.eval()

    train_dataset = shapenet_part_loader.PartDataset(
        classification=False,
        class_choice=opt.class_choice,
        npoints=opt.num_points,
        split='test')
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=4)

    # One point cloud per capsule for the source, the target and each
    # interpolated model.
    inter_models_number = 5  # the number of interpolated models
    part_inter_no = 1  # the part (index within the category) to interpolate
    pcd_list_source = [PointCloud() for _ in range(opt.latent_caps_size)]
    pcd_list_target = [PointCloud() for _ in range(opt.latent_caps_size)]
    pcd_list_inter = [[PointCloud() for _ in range(opt.latent_caps_size)]
                      for _ in range(inter_models_number)]

    # apply a transformation in order to get a better view point (airplane)
    rotation_angle = np.pi / 2
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)
    flip_transforms = [[1, 0, 0, -2], [0, cosval, -sinval, 1.5],
                       [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transformt = [[1, 0, 0, 2], [0, cosval, -sinval, 1.5],
                       [0, sinval, cosval, 0], [0, 0, 0, 1]]
    # x offsets of the interpolated models: one entry per model (5 values)
    transform_range = np.arange(-4, 4.5, 2)

    colors = plt.cm.tab20((np.arange(20)).astype(int))
    source_color = colors[6, :3].tolist()
    target_color = colors[5, :3].tolist()
    grey = [0.8, 0.8, 0.8]
    points_per_caps = opt.num_points // opt.latent_caps_size

    def fill_patches(pcd_list, reconstruction, highlighted, color):
        """Load each capsule's patch of ``reconstruction`` into ``pcd_list``.

        Capsule ``j`` reconstructs points ``j, j + C, j + 2C, ...`` with
        ``C = opt.latent_caps_size``. Its patch is painted ``color`` when
        ``j`` is in ``highlighted`` and grey otherwise.
        """
        for j, pcd_patch in enumerate(pcd_list):
            patch = reconstruction[j::opt.latent_caps_size][:points_per_caps]
            pcd_patch.points = Vector3dVector(
                np.asarray(patch, dtype=np.float64))
            pcd_patch.paint_uniform_color(
                color if j in highlighted else grey)

    for batch_id, data in enumerate(train_dataloader):

        points, part_label, cls_label = data
        if opt.class_choice is not None:
            cls_label[:] = cat_no[opt.class_choice]

        if points.size(0) < opt.batch_size:
            break

        points_ = points.transpose(2, 1)
        if USE_CUDA:
            points_ = points_.cuda()

        # one-hot category label of the two shapes
        cur_label_one_hot = np.zeros((num_shapes, 16), dtype=np.float32)
        for i in range(num_shapes):
            cur_label_one_hot[i, cls_label[i]] = 1
        expand = torch.from_numpy(cur_label_one_hot).unsqueeze(2).expand(
            num_shapes, 16, opt.latent_caps_size).transpose(1, 2)

        with torch.no_grad():
            latent_caps, reconstructions = capsule_net(points_)
            # Only the first two shapes take part in the interpolation.
            latent_caps = latent_caps[:num_shapes]
            device = latent_caps.device

            # predict the part label of each capsule
            latent_caps_with_one_hot = torch.cat(
                (latent_caps, expand.to(device)), 2).transpose(2, 1)
            output_digit = caps_seg_net(latent_caps_with_one_hot)
            for i in range(num_shapes):
                iou_oids = object2setofoid[objcats[cls_label[i]]]
                non_cat_labels = sorted(set(range(50)).difference(iou_oids))
                mini = torch.min(output_digit[i, :, :])
                output_digit[i, :, non_cat_labels] = mini - 1000
            pred_choice = output_digit.data.cpu().max(2)[1]

            # Capsules that reconstruct the chosen part in BOTH shapes.
            # iou_oids is shape 1's category here (last loop iteration), so
            # this is meaningful when both shapes share a category, e.g.
            # when opt.class_choice is set.
            part_no = iou_oids[part_inter_no]  # index in category -> overall id
            part_viz = [
                caps_no for caps_no in range(opt.latent_caps_size)
                if pred_choice[0, caps_no] == part_no
                and pred_choice[1, caps_no] == part_no
            ]

            # Start every interpolated model from the source capsules, then
            # move the part capsules linearly towards the target; the diff
            # has the capsule's own vector length (latent_vec_size).
            latent_caps_inter = latent_caps[1].repeat(inter_models_number, 1,
                                                      1)
            for caps_no in part_viz:
                st_diff = latent_caps[0, caps_no] - latent_caps[1, caps_no]
                for i in range(inter_models_number):
                    latent_caps_inter[i, caps_no] = (
                        latent_caps[1, caps_no] +
                        st_diff * i / (inter_models_number - 1))

            # decode the interpolated latent capsules
            reconstructions_inter = capsule_net_decoder(latent_caps_inter)

        reconstructions = reconstructions.transpose(1, 2).data.cpu()
        reconstructions_inter = reconstructions_inter.transpose(1,
                                                                2).data.cpu()

        fill_patches(pcd_list_source, reconstructions[1], part_viz,
                     source_color)
        fill_patches(pcd_list_target, reconstructions[0], part_viz,
                     target_color)
        for i in range(inter_models_number):
            fill_patches(pcd_list_inter[i], reconstructions_inter[i],
                         part_viz, source_color)

        # show all interpolation
        all_point = PointCloud()
        for pcd_source, pcd_target in zip(pcd_list_source, pcd_list_target):
            pcd_source.transform(flip_transforms)
            pcd_target.transform(flip_transformt)
            all_point += pcd_source
            all_point += pcd_target

        for r in range(inter_models_number):
            flip_transform_inter = [[1, 0, 0, transform_range[r]],
                                    [0, cosval, -sinval, -2],
                                    [0, sinval, cosval, 0], [0, 0, 0, 1]]
            for pcd_patch in pcd_list_inter[r]:
                pcd_patch.transform(flip_transform_inter)
                all_point += pcd_patch
        draw_geometries([all_point])