Пример #1
0
    capsule_net = capsule_net.cuda()
# Switch the network to inference behavior (eval() changes modules such as
# dropout / batch-norm) for the visualization pass below.
capsule_net = capsule_net.eval()

# One PointCloud object per latent capsule.
# NOTE(review): PointCloud is imported outside this view - presumably an
# Open3D point-cloud container; confirm.
pcd_list = []
for i in range(opt.latent_caps_size):
    pcd_ = PointCloud()
    pcd_list.append(pcd_)

# Randomly select 10 capsule indices whose reconstructions are shown with
# color. Indices are drawn independently, so duplicates are possible.
hight_light_caps = [
    np.random.randint(0, opt.latent_caps_size) for r in range(10)
]
# The 20 RGBA colors of matplotlib's 'tab20' colormap (integer inputs index
# the colormap's lookup table directly), one row per index.
colors = plt.cm.tab20((np.arange(20)).astype(int))

# ShapeNet-Part test split restricted to the "Airplane" class.
test_dataset = shapenet_part_loader.PartDataset(classification=True,
                                                class_choice="Airplane",
                                                npoints=opt.num_points,
                                                split='test')
# No shuffling: batches come in dataset order.
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=4)

for batch_id, data in enumerate(test_dataloader):
    # The second element of each sample is ignored here.
    points, _, = data
    # Stop at the first batch smaller than opt.batch_size; without
    # drop_last only the final batch can be short.
    if (points.size(0) < opt.batch_size):
        break

    # Variable has been a no-op wrapper since PyTorch 0.4.
    points = Variable(points)
    # Swap dims 1 and 2 - presumably (B, N, 3) -> (B, 3, N) so the network
    # receives channels first; TODO confirm.
    points = points.transpose(2, 1)
    if USE_CUDA:
        points = points.cuda()
    # NOTE(review): the visible loop body ends here; any use of `points`
    # (forward pass, coloring of pcd_list) is outside this view.
Пример #2
0
# Resume the autoencoder from a checkpoint when one is given.
if opt.model != '':
    # Read the checkpoint only once, with every tensor mapped to CPU storage,
    # so a checkpoint saved on GPU can also be resumed on a CPU-only machine
    # (a load without map_location would try to restore CUDA tensors).
    checkpoint = torch.load(opt.model,
                            map_location=lambda storage, location: storage)
    Autoencoder.load_state_dict(checkpoint['state_dict'])
    resume_epoch = checkpoint['epoch']
    # load_state_dict copied the weights into the model; release the raw dict.
    del checkpoint

if USE_CUDA:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    Autoencoder.to(device)
    Autoencoder = torch.nn.DataParallel(Autoencoder)

# Training split over all classes (class_choice=None).
dset = shapenet_part_loader.PartDataset(
    root='../../dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
    classification=True,
    class_choice=None,
    npoints=opt.num_points,
    split='train')
assert dset
dataloader = torch.utils.data.DataLoader(dset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=int(opt.workers))
print(len(dataloader))  # number of training batches per epoch

# Test split over all classes.
test_dset = shapenet_part_loader.PartDataset(
    root='../../dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
    classification=True,
    class_choice=None,
    npoints=opt.num_points,
    split='test')
def main():
    """Train PointCapsNet as a point-cloud autoencoder on ShapeNet-Part.

    All hyper-parameters come from the module-level ``opt``. If ``opt.model``
    is set, training starts from those weights. Every 5 epochs the weights of
    the underlying (non-DataParallel) model are saved into ``opt.outf``.
    """
    USE_CUDA = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # NOTE(review): opt.latent_caps_size is passed as both the 3rd and the
    # 4th argument (and is also used for the 'vec' part of the checkpoint
    # name below). Confirm whether a separate latent vector size was
    # intended; existing checkpoints depend on the current shapes.
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_caps_size,
                               opt.num_points)

    if opt.model != '':
        # map_location lets weights saved on another device load here.
        capsule_net.load_state_dict(torch.load(opt.model,
                                               map_location=device))

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
        capsule_net.to(device)
    # The custom loss and the saved weights belong to the wrapped model.
    model = (capsule_net.module
             if isinstance(capsule_net, torch.nn.DataParallel) else capsule_net)

    # create folder to save trained models
    os.makedirs(opt.outf, exist_ok=True)

    train_dataset = shapenet_part_loader.PartDataset(classification=True,
                                                     npoints=opt.num_points,
                                                     split='train')
    # drop_last skips the final short batch, so len(train_dataloader) equals
    # the number of batches actually trained on.
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=4,
                                                   drop_last=True)

    # Build the optimizer once so Adam's moment estimates persist across
    # epochs; only the learning rate follows the step schedule below.
    optimizer = optim.Adam(capsule_net.parameters(), lr=0.0001)

    for epoch in range(opt.n_epochs):
        if epoch < 50:
            lr = 0.0001
        elif epoch < 150:
            lr = 0.00001
        else:
            lr = 0.000001
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # train
        capsule_net.train()
        train_loss_sum = 0.0
        num_batches = 0
        for batch_id, (points, _) in enumerate(train_dataloader):
            # Swap dims 1 and 2 - presumably (B, N, 3) -> (B, 3, N).
            points = points.transpose(2, 1).to(device)

            optimizer.zero_grad()
            _, reconstructions = capsule_net(points)
            train_loss = model.loss(points, reconstructions)
            train_loss.backward()
            optimizer.step()
            train_loss_sum += train_loss.item()
            num_batches += 1

            if batch_id % 50 == 0:
                print('bactch_no:%d/%d, train_loss: %f ' %
                      (batch_id, len(train_dataloader), train_loss.item()))

        # Average over the batches actually processed (guards empty epochs).
        print('Average train loss of epoch %d : %f' %
              (epoch, train_loss_sum / max(num_batches, 1)))

        if epoch % 5 == 0:
            dict_name = (f"{opt.outf}/{opt.dataset}_dataset__"
                         f"{opt.latent_caps_size}caps_"
                         f"{opt.latent_caps_size}vec_{epoch}.pth")
            torch.save(model.state_dict(), dict_name)
Пример #4
0
def _crop_batch(points, centroids, crop_point_num, context_point_num,
                num_holes):
    """Crop every complete shape of a batch with ``crop_shape``.

    Args:
        points: batch of complete shapes of size ``(B, N, dim)``;
            ``points[m]`` is one shape.
        centroids: viewpoints forwarded to ``crop_shape``.
        crop_point_num: points removed per hole (a partial keeps at most
            ``N - crop_point_num * num_holes`` points).
        context_point_num: extra context ("frame") points; the second crop
            scale is ``crop_point_num + context_point_num``.
        num_holes: number of holes cut into each shape.

    Returns:
        ``(partials, fine_gts, raw_gts)``, three lists with one entry per
        shape. partial is the partial point cloud to complete. It is randomly
        subsampled without replacement when ``crop_shape`` returns more than
        the allowed number of points. fine_gt is the missing-part ground
        truth. raw_gt is the missing-part ground truth plus frame points
        (frame points are points also included in partial).
    """
    n_partial_points = points.size(1) - (crop_point_num * num_holes)
    partials, fine_gts, raw_gts = [], [], []
    for m in range(points.size(0)):
        partial, fine_gt, raw_gt = crop_shape(
            points[m],
            centroids=centroids,
            scales=[crop_point_num, crop_point_num + context_point_num],
            n_c=num_holes)

        if partial.size(0) > n_partial_points:
            assert num_holes > 1, "Should be no need to resample if not multiple holes case"
            # sampling without replacement
            choice = torch.randperm(partial.size(0))[:n_partial_points]
            partial = partial[choice]

        partials.append(partial)
        fine_gts.append(fine_gt)
        raw_gts.append(raw_gt)
    return partials, fine_gts, raw_gts


def _save_point_clouds(folder, clouds, suffix):
    """Write ``clouds[j]`` to ``<folder>/<j>_<suffix>.txt`` (';'-separated)."""
    for j, cloud in enumerate(clouds):
        np.savetxt(X=cloud,
                   fname=os.path.join(folder, '{}_{}.txt'.format(j, suffix)),
                   fmt='%.5f',
                   delimiter=';')


def _unwrapped_state_dict(model):
    """Return ``model.state_dict()``, unwrapping ``DataParallel`` first."""
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    return model.state_dict()


def main_worker():
    """Train the point-cloud completion encoder/generator pair.

    Options, the console logger ``io`` and the TensorBoard writer ``tb`` come
    from ``get_args()``. Hyper-parameters are read from the JSON file at
    ``opt.config``. Training resumes from ``opt.restart_from`` when it is
    given; otherwise optional pretrained encoder weights named in the config
    are loaded. Every ``opt.it_test`` epochs a checkpoint is saved, the
    test-split Chamfer distance is computed, and the best checkpoint is
    copied to ``best_model.pth``.
    """
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()
    BASE_DIR = os.path.dirname(
        os.path.abspath(__file__))  # python script folder
    ckt = None
    if len(opt.restart_from) > 0:
        # Load on CPU. Model weights are copied into the modules, and the
        # optimizers move their state to the parameters' device on load.
        ckt = torch.load(opt.restart_from, map_location='cpu')
        start_epoch = ckt['epoch'] - 1

    # load configuration from file - nothing below can run without it
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)
        raise

    # backup relevant files
    backup_dir = os.path.join(opt.save_dir, 'backup_code')
    # Without the folder, shutil.copy would write each file to a single file
    # named 'backup_code', every copy overwriting the previous one.
    os.makedirs(backup_dir, exist_ok=True)
    for src in (os.path.abspath(__file__),
                os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                os.path.join(BASE_DIR, 'shape_utils.py')):
        shutil.copy(src=src, dst=backup_dir)
    shutil.copy(src=opt.config,
                dst=os.path.join(backup_dir, 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer']['num_points']
    class_choice = opt.class_choice
    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(
            ",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # Train on all classes! (if opt.class_choice=='')

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='train')

    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='test')

    tr_loader = torch.utils.data.DataLoader(tr_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.workers,
                                            drop_last=True)

    te_loader = torch.utils.data.DataLoader(te_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)

    # Models
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config,
                          pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(weights_init_normal)  # affecting only non pretrained
    generator.apply(weights_init_normal)  # not pretrained
    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
    else:
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(
                local_fe_fn, map_location='cpu')['model_state_dict']
            # refactoring pretext-trained local dgcnn encoder state dict keys
            local_enco_dict = remove_prefix_dict(
                state_dict=local_enco_dict, to_remove_str='local_encoder.')
            loc_load_result = gl_encoder.local_encoder.load_state_dict(
                local_enco_dict, strict=False)
            io.cprint(
                f"Local FE pretrained weights - loading res: {str(loc_load_result)}"
            )
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(
                global_fe_fn, map_location='cpu')['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(
                global_enco_dict, strict=True)
            io.cprint(
                f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel and torch.cuda.device_count() > 1:
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(
        gl_encoder.parameters(),
        lr=config['completion_trainer']['enco_lr'],  # def: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_E = torch.optim.lr_scheduler.StepLR(
        opt_E,
        step_size=config['completion_trainer']['enco_step'],  # def: 25
        gamma=0.5)

    opt_G = torch.optim.Adam(
        generator.parameters(),
        lr=config['completion_trainer']['gen_lr'],  # def: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_G = torch.optim.lr_scheduler.StepLR(
        opt_G,
        step_size=config['completion_trainer']['gen_step'],  # def: 40
        gamma=0.5)

    if ckt is not None:
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    if opt.fps_centroids:
        raise NotImplementedError('experimental')
    # 5 viewpoints to crop around - same as in PFNet
    centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0],
                            [-1, 1, 0]])

    io.cprint("Training.. \n")
    best_test = sys.float_info.max
    best_ep = -1
    it = 0  # global iteration counter
    vis_folder = None
    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_raw_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()
        for i, data in enumerate(tr_loader, 0):
            it += 1
            points, _ = data
            B, _, _ = points.size()
            count += B

            partials, fine_gts, raw_gts = _crop_batch(
                points, centroids, crop_point_num, context_point_num,
                num_holes)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualization
                vis_folder = os.path.join(opt.vis_dir,
                                          "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                _save_point_clouds(vis_folder, partials, 'cropped')
                _save_point_clouds(vis_folder, fine_gts, 'fine_gt')
                _save_point_clouds(vis_folder, raw_gts, 'raw_gt')

            partials = torch.stack(partials).to(device).permute(
                0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            raw_gts = torch.stack(raw_gts).to(device)  # [B, 512 + context, 3]

            if i == 1:  # sanity check
                print("[dbg]: partials: ", partials.size(), ' ',
                      partials.device)
                print("[dbg]: fine grained gts: ", fine_gts.size(), ' ',
                      fine_gts.device)
                print("[dbg]: raw grained gts: ", raw_gts.size(), ' ',
                      raw_gts.device)

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            # pred_fine (only missing part), pred_intermediate (missing + frame)
            fake_fine, fake_raw = generator(feat)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert fake_fine.size() == fine_gts.size(
            ), "Wrong input shapes to Chamfer module"
            if i == 0:
                if fake_raw.size() != raw_gts.size():
                    warnings.warn(
                        "size dismatch for: raw_pred: {}, raw_gt: {}".format(
                            str(fake_raw.size()), str(raw_gts.size())))

            # fine grained prediction + gt
            fake_fine = fake_fine.contiguous()
            fine_gts = fine_gts.contiguous()
            # raw prediction + gt
            fake_raw = fake_raw.contiguous()
            raw_gts = raw_gts.contiguous()

            dist1, dist2, _, _ = NND.nnd(
                fake_fine, fine_gts)  # fine grained loss computation
            dist1_raw, dist2_raw, _, _ = NND.nnd(
                fake_raw, raw_gts)  # raw grained loss computation

            # standard C2C distance loss
            fine_loss = 100 * (0.5 * torch.mean(dist1) +
                               0.5 * torch.mean(dist2))

            # raw loss: missing part + frame
            raw_loss = 100 * (0.5 * torch.mean(dist1_raw) +
                              0.5 * torch.mean(dist2_raw))

            # missing part pred loss + α * raw reconstruction loss
            loss = fine_loss + opt.raw_weight * raw_loss
            loss.backward()
            opt_E.step()
            opt_G.step()
            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_raw_loss += raw_loss.item() * B

            if it % 10 == 0:
                io.cprint(
                    '[%d/%d][%d/%d]: loss: %.4f, fine CD: %.4f, interm. CD: %.4f'
                    % (epoch, opt.epochs, i, len(tr_loader), loss.item(),
                       fine_loss.item(), raw_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                fake_fine = fake_fine.detach().cpu().numpy()
                fake_raw = fake_raw.detach().cpu().numpy()
                _save_point_clouds(vis_folder, fake_fine, 'pred_fine')
                _save_point_clouds(vis_folder, fake_raw, 'pred_raw')

        sched_E.step()
        sched_G.step()
        ep_loss = tot_loss * 1.0 / count
        ep_fine_loss = tot_fine_loss * 1.0 / count
        ep_raw_loss = tot_raw_loss * 1.0 / count
        io.cprint(
            '[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
            (epoch, opt.epochs, ep_loss, ep_fine_loss, ep_raw_loss))
        tb.add_scalar('Train/tot_loss', ep_loss, epoch)
        tb.add_scalar('Train/cd_fine', ep_fine_loss, epoch)
        tb.add_scalar('Train/cd_interm', ep_raw_loss, epoch)

        if epoch % opt.it_test == 0:
            checkpoint_path = os.path.join(
                opt.models_dir, 'checkpoint_' + str(epoch) + '.pth')
            torch.save(
                {
                    'type_exp': 'dgccn at local encoder',
                    'epoch': epoch + 1,
                    'epoch_train_loss': ep_loss,
                    'epoch_train_loss_raw': ep_raw_loss,
                    'epoch_train_loss_fine': ep_fine_loss,
                    'gl_encoder_state_dict': _unwrapped_state_dict(gl_encoder),
                    'generator_state_dict': _unwrapped_state_dict(generator),
                    'optimizerE_state_dict': opt_E.state_dict(),
                    'optimizerG_state_dict': opt_G.state_dict(),
                    'schedulerE_state_dict': sched_E.state_dict(),
                    'schedulerG_state_dict': sched_G.state_dict(),
                }, checkpoint_path)

            # TEST FORWARD
            # Considering only missing part prediction at Test Time
            gl_encoder.eval()
            generator.eval()
            test_cd, test_count = 0.0, 0.0
            for data in te_loader:
                points, _ = data
                B, _, _ = points.size()
                test_count += B

                partials, fine_gts, _ = _crop_batch(
                    points, centroids, crop_point_num, context_point_num,
                    num_holes)
                partials = torch.stack(partials).to(device).permute(
                    0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(
                    device).contiguous()  # [B, 512, 3]

                with torch.no_grad():
                    feat = gl_encoder(partials)
                    fake_fine, _ = generator(feat)

                fake_fine = fake_fine.contiguous()
                assert fake_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(fake_fine, fine_gts)
                cd_loss = 100 * (0.5 * torch.mean(dist1) +
                                 0.5 * torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / test_count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' %
                      (epoch, opt.epochs, test_cd),
                      color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)
            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)

            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" %
                          (best_test, best_ep))
                shutil.copyfile(src=checkpoint_path,
                                dst=os.path.join(opt.models_dir,
                                                 'best_model.pth'))
        io.cprint(
            '[%d/%d] Epoch time: %s' %
            (epoch, opt.epochs,
             time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
Пример #5
0
def run():
    print(opt)

    blue = lambda x: '\033[94m' + x + '\033[0m'
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    USE_CUDA = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    point_netG = _netG(opt.num_scales, opt.each_scales_size,
                       opt.point_scales_list, opt.crop_point_num)
    point_netD = _netlocalD(opt.crop_point_num)
    cudnn.benchmark = True
    resume_epoch = 0

    def weights_init_normal(m):
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("Conv1d") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm2d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm1d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        point_netG = torch.nn.DataParallel(point_netG)
        point_netD = torch.nn.DataParallel(point_netD)
        point_netG.to(device)
        point_netG.apply(weights_init_normal)
        point_netD.to(device)
        point_netD.apply(weights_init_normal)
    if opt.netG != '':
        point_netG.load_state_dict(
            torch.load(
                opt.netG,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netG)['epoch']
    if opt.netD != '':
        point_netD.load_state_dict(
            torch.load(
                opt.netD,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netD)['epoch']

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)


# def run():

# transforms = transforms.Compose([d_utils.PointcloudToTensor(),])
    dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='train')
    assert dset

    dataloader = torch.utils.data.DataLoader(dset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    test_dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dset,
                                                  batch_size=opt.batchSize,
                                                  shuffle=True,
                                                  num_workers=int(opt.workers))

    #dset = ModelNet40Loader.ModelNet40Cls(opt.pnum, train=True, transforms=transforms, download = False)
    #assert dset
    #dataloader = torch.utils.data.DataLoader(dset, batch_size=opt.batchSize,
    #                                         shuffle=True,num_workers = int(opt.workers))
    #
    #
    #test_dset = ModelNet40Loader.ModelNet40Cls(opt.pnum, train=False, transforms=transforms, download = False)
    #test_dataloader = torch.utils.data.DataLoader(test_dset, batch_size=opt.batchSize,
    #                                         shuffle=True,num_workers = int(opt.workers))

    #pointcls_net.apply(weights_init)
    print(point_netG)
    print(point_netD)

    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_PointLoss = PointLoss().to(device)

    # setup optimizer
    optimizerD = torch.optim.Adam(point_netD.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    optimizerG = torch.optim.Adam(point_netG.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD,
                                                 step_size=40,
                                                 gamma=0.2)
    schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG,
                                                 step_size=40,
                                                 gamma=0.2)

    real_label = 1
    fake_label = 0

    crop_point_num = int(opt.crop_point_num)
    input_cropped1 = torch.FloatTensor(opt.batchSize, opt.pnum, 3)
    label = torch.FloatTensor(opt.batchSize)

    num_batch = len(dset) / opt.batchSize

    ###########################
    #  G-NET and T-NET
    ##########################djel
    if opt.D_choose == 1:
        for epoch in range(resume_epoch, opt.niter):
            if epoch < 30:
                alpha1 = 0.01
                alpha2 = 0.02
            elif epoch < 80:
                alpha1 = 0.05
                alpha2 = 0.1
            else:
                alpha1 = 0.1
                alpha2 = 0.2

            if __name__ == '__main__':
                # On Windows calling this function is necessary.
                freeze_support()

            for i, data in enumerate(dataloader, 0):

                real_point, target = data

                batch_size = real_point.size()[0]
                real_center = torch.FloatTensor(batch_size, 1,
                                                opt.crop_point_num, 3)
                input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
                input_cropped1 = input_cropped1.data.copy_(real_point)
                real_point = torch.unsqueeze(real_point, 1)
                input_cropped1 = torch.unsqueeze(input_cropped1, 1)

                # np.savetxt('./pre_set' + '.txt', input_cropped1[0, 0, :, :], fmt="%f,%f,%f")

                p_origin = [0, 0, 0]
                if opt.cropmethod == 'random_center':
                    #Set viewpoints
                    choice = [
                        torch.Tensor([1, 0, 0]),
                        torch.Tensor([0, 0, 1]),
                        torch.Tensor([1, 0, 1]),
                        torch.Tensor([-1, 0, 0]),
                        torch.Tensor([-1, 1, 0])
                    ]
                    for m in range(batch_size):
                        index = random.sample(
                            choice, 1)  #Random choose one of the viewpoint
                        distance_list = []
                        p_center = index[0]
                        for n in range(opt.pnum):
                            distance_list.append(
                                distance_squre(real_point[m, 0, n], p_center))
                        distance_order = sorted(enumerate(distance_list),
                                                key=lambda x: x[1])

                        for sp in range(opt.crop_point_num):
                            input_cropped1.data[
                                m, 0,
                                distance_order[sp][0]] = torch.FloatTensor(
                                    [0, 0, 0])
                            real_center.data[m, 0, sp] = real_point[
                                m, 0, distance_order[sp][0]]

                # np.savetxt('./post_set' + '.txt', input_cropped1[0, 0, :, :], fmt="%f,%f,%f")

                label.resize_([batch_size, 1]).fill_(real_label)
                real_point = real_point.to(device)
                real_center = real_center.to(device)
                input_cropped1 = input_cropped1.to(device)
                label = label.to(device)
                ############################
                # (1) data prepare
                ###########################
                real_center = Variable(real_center, requires_grad=True)
                real_center = torch.squeeze(real_center, 1)
                real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                                   64,
                                                                   RAN=False)
                real_center_key1 = utils.index_points(real_center,
                                                      real_center_key1_idx)
                real_center_key1 = Variable(real_center_key1,
                                            requires_grad=True)

                real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                                   128,
                                                                   RAN=True)
                real_center_key2 = utils.index_points(real_center,
                                                      real_center_key2_idx)
                real_center_key2 = Variable(real_center_key2,
                                            requires_grad=True)

                input_cropped1 = torch.squeeze(input_cropped1, 1)
                input_cropped2_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[1], RAN=True)
                input_cropped2 = utils.index_points(input_cropped1,
                                                    input_cropped2_idx)
                input_cropped3_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[2], RAN=False)
                input_cropped3 = utils.index_points(input_cropped1,
                                                    input_cropped3_idx)
                input_cropped1 = Variable(input_cropped1, requires_grad=True)
                input_cropped2 = Variable(input_cropped2, requires_grad=True)
                input_cropped3 = Variable(input_cropped3, requires_grad=True)
                input_cropped2 = input_cropped2.to(device)
                input_cropped3 = input_cropped3.to(device)
                input_cropped = [
                    input_cropped1, input_cropped2, input_cropped3
                ]
                point_netG = point_netG.train()
                point_netD = point_netD.train()
                ############################
                # (2) Update D network
                ###########################
                point_netD.zero_grad()
                real_center = torch.unsqueeze(real_center, 1)
                output = point_netD(real_center)
                errD_real = criterion(output, label)
                errD_real.backward()

                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                label.data.fill_(fake_label)
                output = point_netD(fake.detach())
                errD_fake = criterion(output, label)
                # on(output, label)
                errD_fake.backward()

                errD = errD_real + errD_fake
                optimizerD.step()
                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################
                point_netG.zero_grad()
                label.data.fill_(real_label)
                output = point_netD(fake)
                errG_D = criterion(output, label)
                errG_l2 = 0

                # temp = torch.cat([input_cropped1[0,:,:], torch.squeeze(fake, 1)[0,:,:]], dim=0).cpu().detach().numpy()
                np.savetxt('./post_set' + '.txt',
                           torch.cat([
                               input_cropped1[0, :, :],
                               torch.squeeze(fake, 1)[0, :, :]
                           ],
                                     dim=0).cpu().detach().numpy(),
                           fmt="%f,%f,%f")

                CD_LOSS = criterion_PointLoss(torch.squeeze(fake, 1),
                                              torch.squeeze(real_center, 1))

                errG_l2 = criterion_PointLoss(torch.squeeze(fake,1),torch.squeeze(real_center,1))\
                +alpha1*criterion_PointLoss(fake_center1,real_center_key1)\
                +alpha2*criterion_PointLoss(fake_center2,real_center_key2)

                errG = (1 - opt.wtl2) * errG_D + opt.wtl2 * errG_l2
                errG.backward()
                optimizerG.step()

                np.savetxt('./fake' + '.txt',
                           torch.squeeze(fake,
                                         1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_A
                np.savetxt('./real_center' + '.txt',
                           torch.squeeze(real_center,
                                         1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_B

                print(
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f/ %.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))
                f = open('loss_PFNet.txt', 'a')
                f.write(
                    '\n' +
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f /%.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))

                if i % 10 == 0:
                    print('After, ', i, '-th batch')
                    f.write('\n' + 'After, ' + str(i) + '-th batch')
                    for i, data in enumerate(test_dataloader, 0):
                        real_point, target = data

                        batch_size = real_point.size()[0]
                        real_center = torch.FloatTensor(
                            batch_size, 1, opt.crop_point_num, 3)
                        input_cropped1 = torch.FloatTensor(
                            batch_size, opt.pnum, 3)
                        input_cropped1 = input_cropped1.data.copy_(real_point)
                        real_point = torch.unsqueeze(real_point, 1)
                        input_cropped1 = torch.unsqueeze(input_cropped1, 1)

                        p_origin = [0, 0, 0]

                        if opt.cropmethod == 'random_center':
                            choice = [
                                torch.Tensor([1, 0, 0]),
                                torch.Tensor([0, 0, 1]),
                                torch.Tensor([1, 0, 1]),
                                torch.Tensor([-1, 0, 0]),
                                torch.Tensor([-1, 1, 0])
                            ]

                            for m in range(batch_size):
                                index = random.sample(choice, 1)
                                distance_list = []
                                p_center = index[0]
                                for n in range(opt.pnum):
                                    distance_list.append(
                                        distance_squre(real_point[m, 0, n],
                                                       p_center))
                                distance_order = sorted(
                                    enumerate(distance_list),
                                    key=lambda x: x[1])
                                for sp in range(opt.crop_point_num):
                                    input_cropped1.data[
                                        m, 0, distance_order[sp]
                                        [0]] = torch.FloatTensor([0, 0, 0])
                                    real_center.data[m, 0, sp] = real_point[
                                        m, 0, distance_order[sp][0]]
                        real_center = real_center.to(device)
                        real_center = torch.squeeze(real_center, 1)
                        input_cropped1 = input_cropped1.to(device)
                        input_cropped1 = torch.squeeze(input_cropped1, 1)
                        input_cropped2_idx = utils.farthest_point_sample(
                            input_cropped1, opt.point_scales_list[1], RAN=True)
                        input_cropped2 = utils.index_points(
                            input_cropped1, input_cropped2_idx)
                        input_cropped3_idx = utils.farthest_point_sample(
                            input_cropped1,
                            opt.point_scales_list[2],
                            RAN=False)
                        input_cropped3 = utils.index_points(
                            input_cropped1, input_cropped3_idx)
                        input_cropped1 = Variable(input_cropped1,
                                                  requires_grad=False)
                        input_cropped2 = Variable(input_cropped2,
                                                  requires_grad=False)
                        input_cropped3 = Variable(input_cropped3,
                                                  requires_grad=False)
                        input_cropped2 = input_cropped2.to(device)
                        input_cropped3 = input_cropped3.to(device)
                        input_cropped = [
                            input_cropped1, input_cropped2, input_cropped3
                        ]
                        point_netG.eval()
                        fake_center1, fake_center2, fake = point_netG(
                            input_cropped)
                        CD_loss = criterion_PointLoss(
                            torch.squeeze(fake, 1),
                            torch.squeeze(real_center, 1))
                        print('test result:', CD_loss)
                        f.write('\n' + 'test result:  %.4f' % (CD_loss))
                        break
                f.close()
            schedulerD.step()
            schedulerG.step()
            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netG.state_dict()
                    }, 'Trained_Model/point_netG' + str(epoch) + '.pth')
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netD.state_dict()
                    }, 'Trained_Model/point_netD' + str(epoch) + '.pth')

    #
    #############################
    ## ONLY G-NET
    ############################
    else:
        for epoch in range(resume_epoch, opt.niter):
            if epoch < 30:
                alpha1 = 0.01
                alpha2 = 0.02
            elif epoch < 80:
                alpha1 = 0.05
                alpha2 = 0.1
            else:
                alpha1 = 0.1
                alpha2 = 0.2

            for i, data in enumerate(dataloader, 0):

                real_point, target = data

                batch_size = real_point.size()[0]
                real_center = torch.FloatTensor(batch_size, 1,
                                                opt.crop_point_num, 3)
                input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
                input_cropped1 = input_cropped1.data.copy_(real_point)
                real_point = torch.unsqueeze(real_point, 1)
                input_cropped1 = torch.unsqueeze(input_cropped1, 1)
                p_origin = [0, 0, 0]
                if opt.cropmethod == 'random_center':
                    choice = [
                        torch.Tensor([1, 0, 0]),
                        torch.Tensor([0, 0, 1]),
                        torch.Tensor([1, 0, 1]),
                        torch.Tensor([-1, 0, 0]),
                        torch.Tensor([-1, 1, 0])
                    ]
                    for m in range(batch_size):
                        index = random.sample(choice, 1)
                        distance_list = []
                        p_center = index[0]
                        for n in range(opt.pnum):
                            distance_list.append(
                                distance_squre(real_point[m, 0, n], p_center))
                        distance_order = sorted(enumerate(distance_list),
                                                key=lambda x: x[1])

                        for sp in range(opt.crop_point_num):
                            input_cropped1.data[
                                m, 0,
                                distance_order[sp][0]] = torch.FloatTensor(
                                    [0, 0, 0])
                            real_center.data[m, 0, sp] = real_point[
                                m, 0, distance_order[sp][0]]
                real_point = real_point.to(device)
                real_center = real_center.to(device)
                input_cropped1 = input_cropped1.to(device)
                ############################
                # (1) data prepare
                ###########################
                real_center = Variable(real_center, requires_grad=True)
                real_center = torch.squeeze(real_center, 1)
                real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                                   64,
                                                                   RAN=False)
                real_center_key1 = utils.index_points(real_center,
                                                      real_center_key1_idx)
                real_center_key1 = Variable(real_center_key1,
                                            requires_grad=True)

                real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                                   128,
                                                                   RAN=True)
                real_center_key2 = utils.index_points(real_center,
                                                      real_center_key2_idx)
                real_center_key2 = Variable(real_center_key2,
                                            requires_grad=True)

                input_cropped1 = torch.squeeze(input_cropped1, 1)
                input_cropped2_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[1], RAN=True)
                input_cropped2 = utils.index_points(input_cropped1,
                                                    input_cropped2_idx)
                input_cropped3_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[2], RAN=False)
                input_cropped3 = utils.index_points(input_cropped1,
                                                    input_cropped3_idx)
                input_cropped1 = Variable(input_cropped1, requires_grad=True)
                input_cropped2 = Variable(input_cropped2, requires_grad=True)
                input_cropped3 = Variable(input_cropped3, requires_grad=True)
                input_cropped2 = input_cropped2.to(device)
                input_cropped3 = input_cropped3.to(device)
                input_cropped = [
                    input_cropped1, input_cropped2, input_cropped3
                ]
                point_netG = point_netG.train()
                point_netG.zero_grad()
                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################

                CD_LOSS = criterion_PointLoss(torch.squeeze(fake, 1),
                                              torch.squeeze(real_center, 1))

                errG_l2 = criterion_PointLoss(torch.squeeze(fake,1),torch.squeeze(real_center,1))\
                +alpha1*criterion_PointLoss(fake_center1,real_center_key1)\
                +alpha2*criterion_PointLoss(fake_center2,real_center_key2)

                errG_l2.backward()
                optimizerG.step()
                print('[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                      (epoch, opt.niter, i, len(dataloader), errG_l2, CD_LOSS))
                f = open('loss_PFNet.txt', 'a')
                f.write(
                    '\n' + '[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                    (epoch, opt.niter, i, len(dataloader), errG_l2, CD_LOSS))
                f.close()
            schedulerD.step()
            schedulerG.step()

            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netG.state_dict()
                    }, 'Checkpoint/point_netG' + str(epoch) + '.pth')
Пример #6
0
def main(CLASS="None"):
    """Encode every sample of one object class into latent capsules and save them.

    Builds a ``PointCapsNet`` (loading weights from ``opt.model`` when set),
    runs it in eval mode over the dataset selected by ``opt.dataset`` and
    writes each sample's latent-capsule tensor to
    ``tmp_lcs/latcaps_<class>_<index>.pt``.

    Args:
        CLASS: object category passed as ``class_choice`` to the ShapeNet-part
            loader and used (lower-cased) in the output file names. The
            string ``"None"`` (the default) terminates the process.

    Only ``shapenet_part`` and ``shapenet_core13`` get a ``DataLoader`` here;
    for ``shapenet_core55`` / ``modelnet40`` the dataset is constructed but
    nothing is exported.
    """
    import os
    import sys

    if CLASS == "None":
        sys.exit()

    # Fall back to CPU instead of crashing on ``.cuda()`` when no GPU exists.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        capsule_net.load_state_dict(torch.load(opt.model, map_location=device))
    else:
        print('pls set the model path')

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
        capsule_net.to(device)

    dataloader = None
    if opt.dataset == 'shapenet_part':
        split = 'train' if opt.save_training else 'test'
        dataset = shapenet_part_loader.PartDataset(classification=True,
                                                   npoints=opt.num_points,
                                                   split=split,
                                                   class_choice=CLASS)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        dataset = shapenet_core13_loader.ShapeNet(normal=False,
                                                  npoints=opt.num_points,
                                                  train=opt.save_training)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        dataset = shapenet_core55_loader.Shapnet55Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=opt.save_training)
    elif opt.dataset == 'modelnet40':
        dataset = modelnet40_loader.ModelNetH5Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=opt.save_training)

    capsule_net.eval()

    if dataloader is None:
        # Batch-iterator datasets (shapenet_core55, modelnet40) are not exported.
        return

    # torch.save does not create missing directories.
    os.makedirs("tmp_lcs", exist_ok=True)

    count = 0
    # Pure inference: no autograd graph is needed for the saved capsules.
    with torch.no_grad():
        for points, _ in dataloader:
            # The final, smaller batch is dropped.
            if points.size(0) < opt.batch_size:
                break
            points = points.transpose(2, 1)
            if USE_CUDA:
                points = points.cuda()
            latent_caps, _ = capsule_net(points)

            for sample_caps in latent_caps:
                # clone(): saving a view would serialize the whole batch's storage.
                torch.save(
                    sample_caps.clone(),
                    "tmp_lcs/latcaps_%s_%03d.pt" % (CLASS.lower(), count))
                count += 1
                if count % 50 == 0:
                    print(count)
Пример #7
0
def main():
    """Train a PointCapsNet point-cloud auto-encoder, checkpointing every 5 epochs.

    All settings come from the module-level ``opt``. For ``shapenet_part`` and
    ``shapenet_core13`` the model is trained through a ``DataLoader`` with the
    Chamfer reconstruction loss. ``shapenet_core55`` uses the batch-iterator
    dataset and an older beta-VAE style objective (marked NOT UP-TO-DATE).
    Checkpoints go to
    ``<opt.outf>/<dataset>_dataset_<caps>caps_<vec>vec_<epoch>.pth``.

    Raises:
        ValueError: if ``opt.dataset`` is not a supported training dataset, or
            the hard-coded ``loss_objective`` is neither "H" nor "B".
    """
    # Fall back to CPU instead of crashing on ``.cuda()`` when no GPU exists.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    # (A BetaPointCapsNet variant was used here in an earlier experiment.)
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        capsule_net.load_state_dict(torch.load(opt.model, map_location=device))

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
        capsule_net.to(device)

    def _update_optimizer(optimizer, lr):
        """Create Adam on first use; afterwards only change its learning rate.

        Re-creating Adam every epoch would discard its running moment
        estimates at each epoch boundary.
        """
        if optimizer is None:
            return optim.Adam(capsule_net.parameters(), lr=lr)
        for group in optimizer.param_groups:
            group['lr'] = lr
        return optimizer

    def _save_checkpoint(epoch):
        """Save the model weights (unwrapped from DataParallel) for ``epoch``."""
        model = (capsule_net.module
                 if isinstance(capsule_net, torch.nn.DataParallel)
                 else capsule_net)
        dict_name = "%s/%s_dataset_%dcaps_%dvec_%d.pth" % \
            (opt.outf, opt.dataset, opt.latent_caps_size, opt.latent_vec_size, epoch)
        torch.save(model.state_dict(), dict_name)

    # create folder to save trained models
    if not os.path.exists(opt.outf):
        os.makedirs(opt.outf)

    # create folder to save logs
    if LOGGING:
        log_dir = './logs' + '/' + opt.dataset + '_dataset_' + str(opt.latent_caps_size) + 'caps_' + \
            str(opt.latent_vec_size) + 'vec' + '_batch_size_' + str(opt.batch_size)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        logger = Logger(log_dir)

    # select dataset
    train_dataloader = None
    if opt.dataset == 'shapenet_part':
        train_dataset = shapenet_part_loader.PartDataset(classification=True, npoints=opt.num_points, split='train')
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        train_dataset = shapenet_core13_loader.ShapeNet(normal=False, npoints=opt.num_points, train=True)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        train_dataset = shapenet_core55_loader.Shapnet55Dataset(batch_size=opt.batch_size, npoints=opt.num_points, shuffle=True, train=True)
    else:
        raise ValueError('unsupported dataset: %s' % opt.dataset)

    # BVAE CONFIGURATIONS HARDCODING
    loss_mode = 'chamfer'   # alternative: 'gaussian' (was decoder_list in bVAE)
    loss_objective = "H"    # Higgins et al "H", or Burgess et al "B"

    C_max = 25              # default 25, pending addition to args
    C_stop_iter = 1e5       # default 1e5, pending addition to args
    global_iter = 0         # iteration count
    C_max = torch.FloatTensor([C_max]).to(device)

    gamma = 1000            # default 1000, pending addition to args
    beta = 4                # default 4, pending addition to args
    w_beta = 0.5            # weight of beta loss against reconstruction loss (chamfer distance)

    # training process for 'shapenet_part' or 'shapenet_core13'
    if train_dataloader is not None:
        optimizer = None
        for epoch in range(opt.n_epochs + 1):
            if epoch < 50:
                lr = 0.01
            elif epoch < 150:
                lr = 0.001
            else:
                lr = 0.0001
            optimizer = _update_optimizer(optimizer, lr)

            capsule_net.train()
            train_loss_sum = 0

            for batch_id, (points, _) in enumerate(train_dataloader):
                global_iter += 1
                # The final, smaller batch is dropped.
                if points.size(0) < opt.batch_size:
                    break
                points = points.transpose(2, 1).to(device)

                optimizer.zero_grad()
                # Only the Chamfer reconstruction loss is optimised in this branch.
                latent_capsules, x_recon = capsule_net(points)
                train_loss = reconstruction_loss(points, x_recon, "chamfer")
                train_loss.backward()
                optimizer.step()
                train_loss_sum += train_loss.item()

                if LOGGING:
                    logger.scalar_summary(
                        'train loss', train_loss.item(),
                        (len(train_dataloader) * epoch) + batch_id + 1)

                if batch_id % 50 == 0:
                    print('batch_no: %d / %d, train_loss: %f ' % (batch_id, len(train_dataloader), train_loss.item()))

            print('\nAverage train loss of epoch %d : %f\n' %
                  (epoch, (train_loss_sum / len(train_dataloader))))

            if epoch % 5 == 0:
                _save_checkpoint(epoch)

    # training process for 'shapenet_core55' (NOT UP-TO-DATE)
    else:
        n_batches = int(57448 / opt.batch_size)  # the dataset size is 57448
        optimizer = None
        for epoch in range(opt.n_epochs + 1):
            if epoch < 20:
                lr = 0.001
            elif epoch < 50:
                lr = 0.0001
            else:
                lr = 0.00001
            optimizer = _update_optimizer(optimizer, lr)

            capsule_net.train()
            train_loss_sum, recon_loss_sum, beta_loss_sum = 0, 0, 0

            while train_dataset.has_next_batch():
                global_iter += 1

                batch_id, points_ = train_dataset.next_batch()
                points = torch.from_numpy(points_)
                if points.size(0) < opt.batch_size:
                    break
                points = points.transpose(2, 1).to(device)

                optimizer.zero_grad()

                # NOTE(review): the other branch unpacks two outputs from
                # capsule_net; this three-output form may be stale — confirm.
                x_recon, mu, logvar = capsule_net(points)
                recon_loss = reconstruction_loss(points, x_recon, loss_mode)
                total_kld, _, _ = kl_divergence(mu, logvar)

                if loss_objective == 'H':
                    beta_loss = beta * total_kld
                elif loss_objective == 'B':
                    # clamp needs a plain number as its upper bound.
                    C = torch.clamp(C_max / C_stop_iter * global_iter, 0, C_max.item())
                    beta_loss = gamma * (total_kld - C).abs()
                else:
                    raise ValueError('unknown loss objective: %s' % loss_objective)

                train_loss = ((1 - w_beta) * recon_loss + w_beta * beta_loss).sum()

                train_loss.backward()
                optimizer.step()
                train_loss_sum += train_loss.item()
                recon_loss_sum += recon_loss.item()
                beta_loss_sum += beta_loss.sum().item()

                if LOGGING:
                    logger.scalar_summary(
                        'train_loss', train_loss.item(),
                        (n_batches * epoch) + batch_id + 1)

                if batch_id % 50 == 0:
                    print('batch_no: %d / %d at epoch %d; train_loss: %f ' % (batch_id, n_batches, epoch, train_loss.item()))

            print('Average train loss of epoch %d : %f' %
                  (epoch, (train_loss_sum / n_batches)))
            print("Average reconstruction loss (100x): %f, beta loss (1e4x): %f" %
                  (recon_loss_sum * 100 / n_batches, beta_loss_sum * 10000 / n_batches))

            train_dataset.reset()

            if epoch % 5 == 0:
                _save_checkpoint(epoch)
def _iter_full_viz_batches(test_dataloader, test_dataset, batch_size):
    """Yield point batches from whichever test source ``main`` built.

    When ``test_dataloader`` is not None its items are unpacked as
    ``(points, label)`` and ``points`` is yielded. Otherwise numpy batches are
    pulled from ``test_dataset.has_next_batch()`` / ``next_batch()`` (the
    shapenet_core55 loader API) and converted with ``torch.from_numpy``.
    Iteration stops at the first batch holding fewer than ``batch_size``
    shapes.
    """
    if test_dataloader is not None:
        for points, _ in test_dataloader:
            if points.size(0) < batch_size:
                return
            yield points
    else:
        while test_dataset.has_next_batch():
            _, points_ = test_dataset.next_batch()
            points = torch.from_numpy(points_)
            if points.size(0) < batch_size:
                return
            yield points


def _draw_capsule_patches(prc_r_all, pcd_list, hight_light_caps, colors,
                          num_caps, points_per_caps):
    """Show one reconstruction with the highlighted capsules' patches colored.

    ``prc_r_all`` holds one reconstructed point per row. The patch of capsule
    ``j`` is NOT stored contiguously: it lives at rows
    ``j, j + num_caps, j + 2 * num_caps, ...`` (``points_per_caps`` rows).
    Each capsule listed in ``hight_light_caps`` gets the next row of
    ``colors``; every other capsule is painted light grey (0.8, 0.8, 0.8).
    ``pcd_list[j]`` is reused (overwritten) as the holder for capsule ``j``.
    """
    colored_re_pointcloud = PointCloud()
    jc = 0  # next color row to hand out to a highlighted capsule
    for j in range(num_caps):
        current_patch = prc_r_all[j::num_caps][:points_per_caps].contiguous()
        pcd_list[j].points = Vector3dVector(current_patch)
        if j in hight_light_caps:
            pcd_list[j].paint_uniform_color(
                [colors[jc, 0], colors[jc, 1], colors[jc, 2]])
            jc += 1
        else:
            pcd_list[j].paint_uniform_color([0.8, 0.8, 0.8])
        colored_re_pointcloud += pcd_list[j]
    draw_geometries([colored_re_pointcloud])


def main():
    """Visualize which part of a reconstruction each latent capsule produces.

    Builds ``PointCapsNet`` from the module-level ``opt`` (weights from
    ``opt.model`` when set), runs it over the test split of ``opt.dataset``
    ('shapenet_part', 'shapenet_core13' or 'shapenet_core55') and, for every
    shape of every full batch, opens a viewer window on the reconstruction.
    The points of 10 randomly drawn capsules (duplicates possible) are
    painted with distinct ``tab20`` colors; all other points are grey.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    # one reusable point-cloud holder per latent capsule
    pcd_list = [PointCloud() for _ in range(opt.latent_caps_size)]
    colors = plt.cm.tab20((np.arange(20)).astype(int))
    # randomly selected capsules to highlight
    hight_light_caps = [
        np.random.randint(0, opt.latent_caps_size) for _ in range(10)
    ]

    # Previously hard-coded to True, which crashed on CPU-only machines.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    # 4th argument is the latent *vector* size (was latent_caps_size twice).
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        capsule_net.load_state_dict(
            torch.load(opt.model, map_location=device))
    else:
        print('pls set the model path')

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
    capsule_net.to(device)

    test_dataloader = None
    test_dataset = None
    if opt.dataset == 'shapenet_part':
        test_dataset = shapenet_part_loader.PartDataset(classification=True,
                                                        npoints=opt.num_points,
                                                        split='test')
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        test_dataset = shapenet_core13_loader.ShapeNet(normal=False,
                                                       npoints=opt.num_points,
                                                       train=False)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        # this loader does its own batching; no DataLoader is built
        test_dataset = shapenet_core55_loader.Shapnet55Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=False)
    else:
        raise ValueError('unsupported dataset: %s' % opt.dataset)

    points_per_caps = opt.num_points // opt.latent_caps_size
    capsule_net.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        for points in _iter_full_viz_batches(test_dataloader, test_dataset,
                                             opt.batch_size):
            # the network takes the two trailing dims swapped
            points = points.transpose(2, 1).to(device)
            _, reconstructions = capsule_net(points)
            for pointset_id in range(opt.batch_size):
                prc_r_all = reconstructions[pointset_id].transpose(
                    1, 0).contiguous().detach().cpu()
                _draw_capsule_patches(prc_r_all, pcd_list, hight_light_caps,
                                      colors, opt.latent_caps_size,
                                      points_per_caps)
def _crop_random_viewpoint(data, centers, crop_point_num):
    """Return a float32 copy of ``data`` with one hole cut into every sample.

    For each sample ``m`` a centre is drawn from ``centers`` with
    ``random.choice``. The ``crop_point_num`` points of ``data[m]`` closest to
    it (squared distance summed over the last dim) are set to zero in the
    copy. ``data`` itself is left untouched.
    """
    cropped = data.to(dtype=torch.float32, copy=True)
    for m in range(data.size(0)):
        p_center = random.choice(centers)
        distances = torch.sum((data[m] - p_center) ** 2, dim=1)
        order = torch.argsort(distances)
        cropped[m, order[:crop_point_num]] = 0
    return cropped


def _save_completion_results(name_suffix, named_tensors):
    """Write each ``(prefix, tensor)`` to ``./saved_results/<prefix><name_suffix>``.

    ``np.save`` appends ``.npy`` when the path lacks that extension.
    """
    for prefix, tensor in named_tensors:
        np.save("./saved_results/" + prefix + name_suffix,
                tensor.detach().cpu().numpy())


def test_net_new(cfg,
                 epoch_idx=-1,
                 test_data_loader=None,
                 test_writer=None,
                 grnet=None):
    """Evaluate GRNet completion + segmentation on artificially cropped shapes.

    Every test shape gets ``crop_point_num`` points zeroed around one of five
    fixed viewpoints and is completed by ``grnet``. The sparse and dense
    outputs are then scored against the uncropped input on these metrics:
    Chamfer distance (printed x1000), segmentation accuracy, mIoU and
    cross-entropy. Metrics are averaged over batches and printed.

    Args:
        cfg: config; ``cfg.CONST.WEIGHTS`` is read when ``grnet`` is None and
            ``cfg`` is passed to ``GRNet``.
        epoch_idx: unused; kept for call compatibility.
        test_data_loader: iterable of ``(points, seg, model_ids)`` batches.
            When None, a ShapeNet-part test loader is built for the
            module-level ``save_name`` class.
        test_writer: unused; kept for call compatibility.
        grnet: network to evaluate. When None, ``GRNet(cfg, 4)`` is built and
            loaded from ``cfg.CONST.WEIGHTS``.

    Returns:
        Dense mIoU averaged over batches.

    Raises:
        ValueError: if ``test_data_loader`` has no batches.

    If the module-level ``save_mode`` is truthy, the first batch's inputs and
    predictions are saved under ``./saved_results/`` and the process exits.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    pnum = 2048
    crop_point_num = 512
    workers = 1
    batchSize = 16

    if test_data_loader is None:
        test_dataset_loader = shapenet_part_loader.PartDataset(
            root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
            classification=False,
            class_choice=save_name,
            npoints=pnum,
            split='test')
        test_data_loader = torch.utils.data.DataLoader(
            test_dataset_loader,
            batch_size=batchSize,
            shuffle=True,
            num_workers=int(workers))

    # Setup networks and initialize networks
    if grnet is None:
        grnet = GRNet(cfg, 4)

        if torch.cuda.is_available():
            grnet = grnet.to(device)

        logging.info('Recovering from %s ...', cfg.CONST.WEIGHTS)
        # map_location lets a GPU-saved checkpoint load on a CPU-only host
        checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=device)
        grnet.load_state_dict(checkpoint['grnet'])

    # Switch models to evaluation mode
    grnet.eval()

    # Set up loss functions (.to(device) instead of an unconditional .cuda())
    chamfer_dist = ChamferDistance()
    seg_criterion = torch.nn.CrossEntropyLoss().to(device)

    # Viewpoints around which the hole is cut (built once, not per batch)
    centers = [
        torch.tensor(c, dtype=torch.float32, device=device)
        for c in ([1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0], [-1, 1, 0])
    ]

    # Python floats (via .item()) so the printout is plain numbers
    total_sparse_cd = 0.0
    total_dense_cd = 0.0
    total_sparse_ce = 0.0
    total_dense_ce = 0.0
    total_sparse_miou = 0
    total_dense_miou = 0
    total_sparse_acc = 0
    total_dense_acc = 0

    # Testing loop
    for data, seg, _model_ids in test_data_loader:
        with torch.no_grad():
            data = data.to(device)
            seg = seg.to(device)

            # remove points to make input incomplete
            input_cropped1 = _crop_random_viewpoint(data, centers,
                                                    crop_point_num)

            sparse_ptcloud, dense_ptcloud, sparse_seg, _, dense_seg = grnet(
                input_cropped1)

            if save_mode:
                _save_completion_results(save_name, [
                    ("original_", data),
                    ("original_seg_", seg),
                    ("cropped_", input_cropped1),
                    ("sparse_", sparse_ptcloud),
                    ("sparse_seg_", sparse_seg),
                    ("dense_", dense_ptcloud),
                    ("dense_seg_", dense_seg),
                ])
                sys.exit()

            total_sparse_cd += chamfer_dist(sparse_ptcloud, data).item()
            total_dense_cd += chamfer_dist(dense_ptcloud, data).item()

            sparse_seg_gt = get_seg_gts(seg, data, sparse_ptcloud)
            sparse_miou, sparse_acc = miou(torch.argmax(sparse_seg, dim=2),
                                           sparse_seg_gt)
            total_sparse_miou += sparse_miou
            total_sparse_acc += sparse_acc
            total_sparse_ce += seg_criterion(
                torch.transpose(sparse_seg, 1, 2), sparse_seg_gt).item()

            dense_seg_gt = get_seg_gts(seg, data, dense_ptcloud)
            dense_miou, dense_acc = miou(torch.argmax(dense_seg, dim=2),
                                         dense_seg_gt)
            total_dense_miou += dense_miou
            total_dense_acc += dense_acc
            total_dense_ce += seg_criterion(
                torch.transpose(dense_seg, 1, 2), dense_seg_gt).item()

    length = len(test_data_loader)
    if length == 0:
        raise ValueError('test_data_loader yielded no batches')
    print("sparse cd: " + str(total_sparse_cd * 1000 / length))
    print("dense cd: " + str(total_dense_cd * 1000 / length))
    print("sparse acc: " + str(total_sparse_acc / length))
    print("dense acc: " + str(total_dense_acc / length))
    print("sparse miou: " + str(total_sparse_miou / length))
    print("dense miou: " + str(total_dense_miou / length))
    print("sparse ce: " + str(total_sparse_ce / length))
    print("dense ce: " + str(total_dense_ce / length))

    return total_dense_miou / length
Пример #10
0
# Seed Python's and torch's (CPU + all CUDA devices) RNGs from one value.
io.cprint("Random Seed: %d" % opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)

# ---- Datasets and loaders ----
# An empty opt.class_choice means: train on every class (class_choice=None).
class_choice = None
if len(opt.class_choice) > 0:
    # strip all whitespace, then split the comma-separated class list
    class_choice = "".join(opt.class_choice.split()).split(",")
    io.cprint("Class choice: {}\n".format(str(class_choice)))

dset = shapenet_part_loader.PartDataset(
    root=opt.data_root,
    classification=True,
    class_choice=class_choice,
    npoints=opt.pnum,
    split='train',
)

# drop_last keeps every training batch at exactly opt.batchSize shapes
train_loader = torch.utils.data.DataLoader(
    dset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers),
    drop_last=True,
)

test_dset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                             classification=True,
                                             class_choice=class_choice,
                                             npoints=opt.pnum,
Пример #11
0
def _load_part_id_tables(dataset_main_path, dataset_name):
    """Read the ShapeNet-part id tables shipped with the benchmark.

    Returns ``(object2setofoid, objcats)``:
      * ``object2setofoid`` maps each category id found in
        ``shapenet_part_overallid_to_catid_partid.json`` to the list of
        overall part ids (row indices of that json) belonging to it;
      * ``objcats`` is the second whitespace-separated column of
        ``synsetoffset2category.txt``, in file order.
    """
    bench_dir = os.path.join(dataset_main_path, dataset_name,
                             'shapenetcore_partanno_segmentation_benchmark_v0')
    with open(os.path.join(bench_dir,
                           'shapenet_part_overallid_to_catid_partid.json'),
              'r') as f:
        oid2cpid = json.load(f)
    object2setofoid = {}
    for idx, (objid, _pid) in enumerate(oid2cpid):
        object2setofoid.setdefault(objid, []).append(idx)

    with open(os.path.join(bench_dir, 'synsetoffset2category.txt'),
              'r') as fin:
        objcats = [line.rstrip().split()[1] for line in fin]
    return object2setofoid, objcats


def _spread_capsule_labels(pred_choice, num_points, latent_caps_size):
    """Give every reconstructed point the part label of the capsule that made it.

    Reconstructed point ``latent_caps_size * m + j`` belongs to capsule ``j``,
    so tiling the per-capsule labels ``num_points // latent_caps_size`` times
    reproduces that interleaved layout. Positions past the last full tile
    (only when ``num_points`` is not divisible) stay 0.
    """
    per_caps = num_points // latent_caps_size
    labels = torch.zeros([pred_choice.size(0), num_points], dtype=torch.int64)
    labels[:, :latent_caps_size * per_caps] = \
        pred_choice[:, :latent_caps_size].repeat(1, per_caps)
    return labels


def _knn_vote_point_labels(reconstructions, points, recon_labels, batch_size,
                           num_points, k=10):
    """Label each input point from its ``k`` nearest reconstructed points.

    A KD-tree is built over ``reconstructions[s]`` for every shape ``s``.
    Each ``points[s, p]`` takes the low median of its neighbours' labels.
    ``median_low`` (rather than ``median``) always returns one of the
    neighbour labels: with an even ``k``, ``median`` can average two labels
    into an id none of the neighbours carry.
    """
    pred = torch.zeros([batch_size, num_points], dtype=torch.int64)
    pcd = PointCloud()
    for set_no in range(batch_size):
        pcd.points = Vector3dVector(reconstructions[set_no, ])
        pcd_tree = KDTreeFlann(pcd)
        for point_id in range(num_points):
            _, idx, _ = pcd_tree.search_knn_vector_3d(
                points[set_no, point_id, :], k)
            neighbour_labels = recon_labels[set_no, list(idx)].tolist()
            pred[set_no, point_id] = statistics.median_low(neighbour_labels)
    return pred


def _part_label_colors(labels, first_oid, colors):
    """Map overall part ids to RGB rows of ``colors`` (offset by ``first_oid``).

    Returns a float32 tensor with one RGB row per entry of ``labels``.
    """
    part_no = (labels - first_oid).numpy()
    return torch.from_numpy(colors[part_no, :3]).float()


def main():
    """Visualize capsule-wise part segmentation on the ShapeNet-part test split.

    For every full batch:
      1. encode the shapes with the pre-trained ``PointCapsNet``;
      2. concatenate a 16-way one-hot category code to every latent capsule;
      3. predict one part label per capsule with ``CapsSegNet``, restricted
         to the parts of the shape's category;
      4. propagate capsule labels to reconstructed points, then to the input
         points by a 10-NN vote;
      5. print the running point accuracy against the ground truth;
      6. show ground-truth and predicted colorings side by side.

    Reads configuration from the module-level ``opt``, ``USE_CUDA`` and
    ``BASE_DIR``.
    """
    cat_no = {
        'Airplane': 0,
        'Bag': 1,
        'Cap': 2,
        'Car': 3,
        'Chair': 4,
        'Earphone': 5,
        'Guitar': 6,
        'Knife': 7,
        'Lamp': 8,
        'Laptop': 9,
        'Motorbike': 10,
        'Mug': 11,
        'Pistol': 12,
        'Rocket': 13,
        'Skateboard': 14,
        'Table': 15
    }

    # generate part label one-hot correspondence from the category
    dataset_main_path = os.path.abspath(
        os.path.join(BASE_DIR, '../../dataset/shapenet/'))
    object2setofoid, objcats = _load_part_id_tables(dataset_main_path,
                                                    opt.dataset)

    colors = plt.cm.tab10((np.arange(10)).astype(int))

    # load the model for point caps auto encoder
    capsule_net = PointCapsNet(
        opt.prim_caps_size,
        opt.prim_vec_size,
        opt.latent_caps_size,
        opt.latent_vec_size,
        opt.num_points,
    )
    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model))
    if USE_CUDA:
        capsule_net = torch.nn.DataParallel(capsule_net).cuda()
    capsule_net = capsule_net.eval()

    # load the model for capsule wise part segmentation
    caps_seg_net = CapsSegNet(latent_caps_size=opt.latent_caps_size,
                              latent_vec_size=opt.latent_vec_size,
                              num_classes=opt.n_classes)
    if opt.part_model != '':
        caps_seg_net.load_state_dict(torch.load(opt.part_model))
    if USE_CUDA:
        caps_seg_net = caps_seg_net.cuda()
    caps_seg_net = caps_seg_net.eval()

    test_dataset = shapenet_part_loader.PartDataset(
        classification=False,
        class_choice=opt.class_choice,
        npoints=opt.num_points,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=opt.batch_size,
                                                  shuffle=True,
                                                  num_workers=4)

    pcd_colored = PointCloud()
    pcd_ori_colored = PointCloud()
    rotation_angle = -np.pi / 4
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)
    # same rotation with opposite x offsets so both clouds show side by side
    flip_transforms = [[cosval, 0, sinval, -1], [0, 1, 0, 0],
                       [-sinval, 0, cosval, 0], [0, 0, 0, 1]]
    flip_transformt = [[cosval, 0, sinval, 1], [0, 1, 0, 0],
                       [-sinval, 0, cosval, 0], [0, 0, 0, 1]]

    correct_sum = 0
    with torch.no_grad():  # inference only
        for batch_id, data in enumerate(test_dataloader):
            # each point has a part label (e.g. wing, engine, tail)
            points, part_label, cls_label = data
            if opt.class_choice is not None:
                cls_label[:] = cat_no[opt.class_choice]

            if points.size(0) < opt.batch_size:
                break

            # use the pre-trained AE to encode the point cloud into latent caps
            points_ = points.transpose(2, 1)
            if USE_CUDA:
                points_ = points_.cuda()
            latent_caps, reconstructions = capsule_net(points_)
            reconstructions = reconstructions.transpose(1, 2).cpu()

            # one-hot category code; also map per-category part ids to the
            # overall part ids listed in object2setofoid
            cur_label_one_hot = np.zeros((opt.batch_size, 16),
                                         dtype=np.float32)
            for i in range(opt.batch_size):
                cur_label_one_hot[i, cls_label[i]] = 1
                oid_lookup = torch.as_tensor(
                    object2setofoid[objcats[cls_label[i]]])
                part_label[i] = oid_lookup[part_label[i].long()]
            cur_label_one_hot = torch.from_numpy(cur_label_one_hot).float()
            expand = cur_label_one_hot.unsqueeze(2).expand(
                opt.batch_size, 16, opt.latent_caps_size).transpose(1, 2)
            # only move to GPU when CUDA is in use (was unconditional .cuda())
            if USE_CUDA:
                expand, latent_caps = expand.cuda(), latent_caps.cuda()
            latent_caps = torch.cat((latent_caps, expand), 2)

            # predict the part class per capsule
            output = caps_seg_net(latent_caps.transpose(2, 1))
            for i in range(opt.batch_size):
                # there are 50 part classes over all 16 categories; push the
                # ones outside this shape's category below every real score
                iou_oids = set(object2setofoid[objcats[cls_label[i]]])
                non_cat_labels = [c for c in range(50) if c not in iou_oids]
                mini = torch.min(output[i, :, :])
                output[i, :, non_cat_labels] = mini - 1000
            pred_choice = output.cpu().max(2)[1]

            # assign predicted capsule part label to its reconstructed patch
            reconstructions_part_label = _spread_capsule_labels(
                pred_choice, opt.num_points, opt.latent_caps_size)

            # transfer labels from the reconstruction to the input by k-NN
            pred_ori_pointcloud_part_label = _knn_vote_point_labels(
                reconstructions, points, reconstructions_part_label,
                opt.batch_size, opt.num_points)

            # calculate the accuracy with the GT
            correct = pred_ori_pointcloud_part_label.eq(
                part_label.cpu()).sum()
            correct_sum += correct.item()
            print(' accuracy is: %f' %
                  (correct_sum / float(opt.batch_size *
                                       (batch_id + 1) * opt.num_points)))

            # viz the part segmentation: GT (left) vs prediction (right)
            for point_set_no in range(opt.batch_size):
                first_oid = object2setofoid[objcats[cls_label[point_set_no]]][0]

                pcd_colored.points = Vector3dVector(points[point_set_no, ])
                pcd_colored.colors = Vector3dVector(
                    _part_label_colors(
                        pred_ori_pointcloud_part_label[point_set_no],
                        first_oid, colors))

                pcd_ori_colored.points = Vector3dVector(points[point_set_no, ])
                pcd_ori_colored.colors = Vector3dVector(
                    _part_label_colors(part_label[point_set_no], first_oid,
                                       colors))

                pcd_ori_colored.transform(flip_transforms)
                pcd_colored.transform(flip_transformt)
                draw_geometries([pcd_ori_colored, pcd_colored])
Пример #12
0
def main_worker():
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()
    ckt = None
    if len(opt.restart_from) > 0:
        ckt = torch.load(opt.restart_from)
        start_epoch = ckt['epoch'] - 1

    # load configuration from file
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)

    # backup relevant files
    shutil.copy(src=os.path.abspath(__file__),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'shape_utils.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=opt.config,
                dst=os.path.join(opt.save_dir, 'backup_code',
                                 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer'][
        'num_points']  # number of points of complete pointcloud
    class_choice = opt.class_choice  # config['completion_trainer']['class_choice']

    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(
            ",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # training on all snpart classes

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='train')

    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='test')

    tr_loader = torch.utils.data.DataLoader(tr_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.workers,
                                            drop_last=True)

    te_loader = torch.utils.data.DataLoader(te_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)

    # io.cprint(f"Completion Setting:\n num classes {len(tr_dataset.cat.keys())}, num holes: {num_holes}, "
    #           f"crop point num: {crop_point_num}, frame/context point num: {context_point_num},\n"
    #           f"num points at pool1: {opt.pool1_points}, num points at pool2: {opt.pool2_points} ")

    # Models
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config,
                          pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(
        weights_init_normal)  # affecting only non pretrained layers
    generator.apply(weights_init_normal)

    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        # resuming training from intermediate checkpoint
        # restoring both encoder and generator state
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
        io.cprint("Whole model loaded from {}\n".format(opt.restart_from))
    else:
        # training the completion model
        # load local and global encoder pretrained (ssl pretexts) weights
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(local_fe_fn, )['model_state_dict']
            loc_load_result = gl_encoder.local_encoder.load_state_dict(
                local_enco_dict, strict=False)
            io.cprint(
                f"Local FE pretrained weights - loading res: {str(loc_load_result)}"
            )
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(global_fe_fn, )['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(
                global_enco_dict, strict=True)
            io.cprint(
                f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel:
        # TODO: implement DistributedDataParallel training
        assert torch.cuda.device_count() > 1
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(
        gl_encoder.parameters(),
        lr=config['completion_trainer']['enco_lr'],  # default is: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_E = torch.optim.lr_scheduler.StepLR(
        opt_E,
        step_size=config['completion_trainer']['enco_step'],  # default is: 25
        gamma=0.5)

    opt_G = torch.optim.Adam(
        generator.parameters(),
        lr=config['completion_trainer']['gen_lr'],  # default is: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_G = torch.optim.lr_scheduler.StepLR(
        opt_G,
        step_size=config['completion_trainer']['gen_step'],  # default is: 40
        gamma=0.5)

    if ckt is not None:
        # resuming training from intermediate checkpoint
        # restore optimizers state
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    # crop centroids
    if not opt.fps_centroids:
        # 5 viewpoints to crop around - same crop procedure of PFNet - main paper
        centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0],
                                [-1, 1, 0]])
    else:
        raise NotImplementedError('experimental')
        centroids = None

    io.cprint('Centroids: ' + str(centroids))

    # training loop
    io.cprint("Training.. \n")
    best_test = sys.float_info.max
    best_ep, glob_it = -1, 0
    vis_folder = None
    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_interm_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()
        for i, data in enumerate(tr_loader, 0):
            glob_it += 1
            points, _ = data
            B, N, dim = points.size()
            count += B

            partials = []
            fine_gts, interm_gts = [], []
            N_partial_points = N - (crop_point_num * num_holes)
            for m in range(B):
                partial, fine_gt, interm_gt = crop_shape(
                    points[m],
                    centroids=centroids,
                    scales=[
                        crop_point_num, (crop_point_num + context_point_num)
                    ],
                    n_c=num_holes)

                if partial.size(0) > N_partial_points:
                    assert num_holes > 1
                    # sampling without replacement
                    choice = torch.randperm(partial.size(0))[:N_partial_points]
                    partial = partial[choice]

                partials.append(partial)
                fine_gts.append(fine_gt)
                interm_gts.append(interm_gt)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualization
                vis_folder = os.path.join(opt.vis_dir,
                                          "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                for j in range(len(partials)):
                    np.savetxt(X=partials[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_partial.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=fine_gts[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_fine_gt.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=interm_gts[j],
                               fname=os.path.join(
                                   vis_folder, '{}_interm_gt.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')

            partials = torch.stack(partials).to(device).permute(
                0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            interm_gts = torch.stack(interm_gts).to(device)  # [B, 1024, 3]

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            pred_fine, pred_raw = generator(feat)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert pred_fine.size() == fine_gts.size()
            pred_fine, pred_raw = pred_fine.contiguous(), pred_raw.contiguous()
            fine_gts, interm_gts = fine_gts.contiguous(
            ), interm_gts.contiguous()

            dist1, dist2, _, _ = NND.nnd(pred_fine,
                                         fine_gts)  # missing part pred loss
            dist1_raw, dist2_raw, _, _ = NND.nnd(
                pred_raw, interm_gts)  # intermediate pred loss
            fine_loss = 50 * (torch.mean(dist1) + torch.mean(dist2)
                              )  # chamfer is weighted by 100
            interm_loss = 50 * (torch.mean(dist1_raw) + torch.mean(dist2_raw))

            loss = fine_loss + opt.raw_weight * interm_loss
            loss.backward()
            opt_E.step()
            opt_G.step()
            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_interm_loss += interm_loss.item() * B

            if glob_it % 10 == 0:
                header = "[%d/%d][%d/%d]" % (epoch, opt.epochs, i,
                                             len(tr_loader))
                io.cprint('%s: loss: %.4f, fine CD: %.4f, interm. CD: %.4f' %
                          (header, loss.item(), fine_loss.item(),
                           interm_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                pred_fine = pred_fine.cpu().detach().data.numpy()
                pred_raw = pred_raw.cpu().detach().data.numpy()
                for j in range(len(pred_fine)):
                    np.savetxt(X=pred_fine[j],
                               fname=os.path.join(
                                   vis_folder, '{}_pred_fine.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=pred_raw[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_pred_raw.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')

        sched_E.step()
        sched_G.step()
        io.cprint(
            '[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
            (epoch, opt.epochs, tot_loss * 1.0 / count,
             tot_fine_loss * 1.0 / count, tot_interm_loss * 1.0 / count))
        tb.add_scalar('Train/tot_loss', tot_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_fine', tot_fine_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_interm', tot_interm_loss * 1.0 / count, epoch)

        if epoch % opt.it_test == 0:
            torch.save(
                {
                    'epoch':
                    epoch + 1,
                    'epoch_train_loss':
                    tot_loss * 1.0 / count,
                    'epoch_train_loss_raw':
                    tot_interm_loss * 1.0 / count,
                    'epoch_train_loss_fine':
                    tot_fine_loss * 1.0 / count,
                    'gl_encoder_state_dict':
                    gl_encoder.module.state_dict() if isinstance(
                        gl_encoder,
                        nn.DataParallel) else gl_encoder.state_dict(),
                    'generator_state_dict':
                    generator.module.state_dict() if isinstance(
                        generator, nn.DataParallel) else
                    generator.state_dict(),
                    'optimizerE_state_dict':
                    opt_E.state_dict(),
                    'optimizerG_state_dict':
                    opt_G.state_dict(),
                    'schedulerE_state_dict':
                    sched_E.state_dict(),
                    'schedulerG_state_dict':
                    sched_G.state_dict(),
                },
                os.path.join(opt.models_dir,
                             'checkpoint_' + str(epoch) + '.pth'))

        if epoch % opt.it_test == 0:
            test_cd, count = 0.0, 0.0
            for i, data in enumerate(te_loader, 0):
                points, _ = data
                B, N, dim = points.size()
                count += B

                partials = []
                fine_gts = []
                N_partial_points = N - (crop_point_num * num_holes)

                for m in range(B):
                    partial, fine_gt, _ = crop_shape(points[m],
                                                     centroids=centroids,
                                                     scales=[
                                                         crop_point_num,
                                                         (crop_point_num +
                                                          context_point_num)
                                                     ],
                                                     n_c=num_holes)

                    if partial.size(0) > N_partial_points:
                        assert num_holes > 1
                        # sampling Without replacement
                        choice = torch.randperm(
                            partial.size(0))[:N_partial_points]
                        partial = partial[choice]

                    partials.append(partial)
                    fine_gts.append(fine_gt)
                partials = torch.stack(partials).to(device).permute(
                    0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(
                    device).contiguous()  # [B, 512, 3]

                # TEST FORWARD
                # Considering only missing part prediction at Test Time
                gl_encoder.eval()
                generator.eval()
                with torch.no_grad():
                    feat = gl_encoder(partials)
                    pred_fine, _ = generator(feat)

                pred_fine = pred_fine.contiguous()
                assert pred_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(pred_fine, fine_gts)
                cd_loss = 50 * (torch.mean(dist1) + torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' %
                      (epoch, opt.epochs, test_cd),
                      color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)
            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)

            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" %
                          (best_test, best_ep))
                shutil.copyfile(src=os.path.join(
                    opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'),
                                dst=os.path.join(opt.models_dir,
                                                 'best_model.pth'))
        io.cprint(
            '[%d/%d] Epoch time: %s' %
            (epoch, opt.epochs,
             time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
Пример #13
0
def main():
    """Export latent capsules and per-capsule part labels to an HDF5 file.

    Every shape of the selected ShapeNet-Part split is encoded with
    PointCapsNet. Each reconstructed point is matched to its nearest input
    point (Open3D KD-tree, k=1) and inherits that point's part label.
    Reconstructed point ``point_id`` is attributed to capsule
    ``point_id % opt.latent_caps_size``. Each capsule is labelled with the
    most frequent part label among its points.

    Three HDF5 datasets are appended batch by batch:
      * ``data``       - latent capsules, (N, latent_caps_size, latent_vec_size)
      * ``part_label`` - capsule part labels, (N, latent_caps_size)
      * ``cls_label``  - object class labels, (N,)
    The file is ``<BASE_DIR>/../../dataset/<opt.dataset>/latent_caps/
    saved_{train,test}_with_part_label.h5``; an existing file is replaced.
    A trailing batch smaller than ``opt.batch_size`` is skipped.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'shapenet_part'``, the only
            dataset this function builds a loader for.
    """
    USE_CUDA = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The 4th argument is the latent capsule *vector* size. It previously
    # received opt.latent_caps_size, which only worked when both sizes were
    # equal and disagreed with the (N, caps, vec) HDF5 layout written below.
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model))

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
        capsule_net.to(device)

    if opt.dataset != 'shapenet_part':
        # Fail before the existing output file is deleted below. Previously
        # this surfaced later as a NameError on `dataloader`.
        raise ValueError("unsupported dataset %r: only 'shapenet_part' is "
                         "supported" % (opt.dataset, ))

    split = 'train' if opt.save_training else 'test'
    dataset = shapenet_part_loader.PartDataset(classification=False,
                                               npoints=opt.num_points,
                                               split=split)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=4)

    # init saving process
    pcd = PointCloud()
    data_size = 0
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    out_file_path = os.path.join(dataset_main_path, opt.dataset, 'latent_caps')
    os.makedirs(out_file_path, exist_ok=True)
    out_file_name = os.path.join(out_file_path,
                                 "saved_%s_with_part_label.h5" % split)
    if os.path.exists(out_file_name):
        os.remove(out_file_name)

    capsule_net.eval()

    # `with` guarantees the HDF5 file is closed even if a batch raises.
    with h5py.File(out_file_name, 'w', libver='latest') as fw:
        # Each dataset starts with one row and is grown with resize() per batch.
        dset = fw.create_dataset(
            "data", (1, opt.latent_caps_size, opt.latent_vec_size),
            maxshape=(None, opt.latent_caps_size, opt.latent_vec_size),
            dtype='<f4')
        dset_s = fw.create_dataset("part_label", (1, opt.latent_caps_size),
                                   maxshape=(None, opt.latent_caps_size),
                                   dtype='uint8')
        dset_c = fw.create_dataset("cls_label", (1, ),
                                   maxshape=(None, ),
                                   dtype='uint8')
        # Switch the file to single-writer / multiple-reader mode.
        fw.swmr_mode = True

        for batch_id, data in enumerate(dataloader):
            points, part_label, cls_label = data
            n_batch = points.size(0)
            if n_batch < opt.batch_size:
                break
            points = points.transpose(2, 1)
            if USE_CUDA:
                points = points.cuda()
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                latent_caps, reconstructions = capsule_net(points)

            # For each reconstructed point, find the nearest point in the
            # input set and take its part label; then count, per capsule,
            # how many of its reconstructed points carry each part label.
            reconstructions = reconstructions.transpose(1, 2).data.cpu()
            points = points.transpose(1, 2).data.cpu()
            cap_part_count = torch.zeros(
                [n_batch, opt.latent_caps_size, opt.n_classes],
                dtype=torch.int64)
            for batch_no in range(n_batch):
                pcd.points = Vector3dVector(points[batch_no, ])
                pcd_tree = KDTreeFlann(pcd)
                for point_id in range(opt.num_points):
                    _, idx, _ = pcd_tree.search_knn_vector_3d(
                        reconstructions[batch_no, point_id, :], 1)
                    point_part_label = part_label[batch_no, idx]
                    caps_no = point_id % opt.latent_caps_size
                    cap_part_count[batch_no, caps_no, point_part_label] += 1
            # Majority vote: a capsule takes the most frequent part label of
            # the points it reconstructed.
            _, cap_part_label = torch.max(cap_part_count, 2)

            # Append the latent caps and labels to the file.
            start = data_size
            data_size += n_batch
            dset.resize((data_size, opt.latent_caps_size, opt.latent_vec_size))
            dset_s.resize((data_size, opt.latent_caps_size))
            dset_c.resize((data_size, ))

            dset[start:data_size, :, :] = latent_caps.cpu().detach().numpy()
            dset_s[start:data_size] = cap_part_label.numpy()
            dset_c[start:data_size] = cls_label.squeeze().numpy()

            dset.flush()
            dset_s.flush()
            dset_c.flush()
            print('accumalate of batch %d, and datasize is %d ' %
                  ((batch_id), (dset.shape[0])))
Пример #14
0
def main():
    """Evaluate a trained PointCapsNet auto-encoder on a test split.

    Loads weights from ``opt.model``. If the path is empty, a reminder is
    printed and the freshly constructed network is evaluated. Every full test
    batch is run through the net. The per-batch and mean reconstruction loss
    (``capsule_net.module.loss``) are printed. A trailing batch smaller than
    ``opt.batch_size`` ends the evaluation and is not included in the mean.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'shapenet_part'``,
            ``'shapenet_core13'`` or ``'shapenet_core55'``.
    """
    USE_CUDA = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Sibling scripts spell this option `latent_vec_size`. Prefer it, and fall
    # back to the `latent_vecs_size` spelling this script originally read.
    latent_vec_size = getattr(opt, 'latent_vec_size', None)
    if latent_vec_size is None:
        latent_vec_size = opt.latent_vecs_size
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model))
    else:
        print('pls set the model path')

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
        capsule_net.to(device)

    # shapenet_part / shapenet_core13 are read through a DataLoader;
    # shapenet_core55 uses its own has_next_batch()/next_batch() iterator.
    test_dataloader = None
    if opt.dataset == 'shapenet_part':
        test_dataset = shapenet_part_loader.PartDataset(classification=True,
                                                        npoints=opt.num_points,
                                                        split='test')
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        test_dataset = shapenet_core13_loader.ShapeNet(normal=False,
                                                       npoints=opt.num_points,
                                                       train=False)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        test_dataset = shapenet_core55_loader.Shapnet55Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=False)
    else:
        raise ValueError('unknown dataset %r' % (opt.dataset, ))

    def _iter_batches():
        """Yield ``(batch_id, points)`` with ``points`` as a CPU tensor."""
        if test_dataloader is not None:
            for batch_id, (points, _) in enumerate(test_dataloader):
                yield batch_id, points
        else:
            while test_dataset.has_next_batch():
                batch_id, points_ = test_dataset.next_batch()
                yield batch_id, torch.from_numpy(points_)

    capsule_net.eval()
    test_loss_sum = 0
    n_batches = 0
    with torch.no_grad():
        for batch_id, points in _iter_batches():
            if points.size(0) < opt.batch_size:
                break
            points = points.transpose(2, 1)
            if USE_CUDA:
                points = points.cuda()
            _, reconstructions = capsule_net(points)
            test_loss = capsule_net.module.loss(points, reconstructions)
            test_loss_sum += test_loss.item()
            n_batches += 1
            print('accumalate of batch %d loss is : %f' %
                  (batch_id, test_loss.item()))
    # Average over the batches actually evaluated. The previous code divided
    # by len(test_dataloader), which counted the skipped trailing batch. For
    # shapenet_core55 it referenced an undefined `test_dataloader`.
    test_loss_sum = test_loss_sum / float(max(n_batches, 1))
    print('test loss is : %f' % (test_loss_sum))
def main():
    """Visualise capsule-based part swapping between two ShapeNet-Part shapes.

    For each test batch, the first two shapes form a source/target pair. The
    part selected by ``part_replace_no`` is exchanged between them. One
    Open3D window per batch shows three groups, each shifted along x by its
    transform:

    * the ground-truth pair with the part highlighted (x offset -3);
    * the ground-truth pair with the part points cut and pasted (x offset 0);
    * the decoder's reconstruction after swapping every latent capsule that
      the segmentation net labels as that part in *both* shapes (x offset 3).

    Reads the module-level ``opt``, ``BASE_DIR`` and ``USE_CUDA``.
    """
    cat_no = {
        'Airplane': 0,
        'Bag': 1,
        'Cap': 2,
        'Car': 3,
        'Chair': 4,
        'Earphone': 5,
        'Guitar': 6,
        'Knife': 7,
        'Lamp': 8,
        'Laptop': 9,
        'Motorbike': 10,
        'Mug': 11,
        'Pistol': 12,
        'Rocket': 13,
        'Skateboard': 14,
        'Table': 15
    }
    n_cats = len(cat_no)  # width of the one-hot class vector (16)

    # Map each object id to the list of overall part ids (indices into the
    # JSON list) belonging to it. Used to mask the segmentation output.
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    oid2cpid_file_name = os.path.join(
        dataset_main_path, opt.dataset,
        'shapenetcore_partanno_segmentation_benchmark_v0/shapenet_part_overallid_to_catid_partid.json'
    )
    with open(oid2cpid_file_name, 'r') as f:
        oid2cpid = json.load(f)
    object2setofoid = {}
    for oid, (objid, _) in enumerate(oid2cpid):
        object2setofoid.setdefault(objid, []).append(oid)

    all_obj_cat_file = os.path.join(
        dataset_main_path, opt.dataset,
        'shapenetcore_partanno_segmentation_benchmark_v0/synsetoffset2category.txt'
    )
    # Second whitespace-separated column of each line; presumably the synset
    # offset used as key in object2setofoid (TODO confirm).
    with open(all_obj_cat_file, 'r') as fin:
        objcats = [line.split()[1] for line in fin]

    # load the model for point caps auto encoder
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)
    if opt.model != '':
        capsule_net.load_state_dict(torch.load(opt.model))
    if USE_CUDA:
        capsule_net = torch.nn.DataParallel(capsule_net).cuda()
    capsule_net = capsule_net.eval()

    # load the model for only decoding (same checkpoint, loaded non-strictly)
    capsule_net_decoder = PointCapsNetDecoder(opt.prim_caps_size,
                                              opt.prim_vec_size,
                                              opt.latent_caps_size,
                                              opt.latent_vec_size,
                                              opt.num_points)
    if opt.model != '':
        capsule_net_decoder.load_state_dict(torch.load(opt.model),
                                            strict=False)
    if USE_CUDA:
        capsule_net_decoder = capsule_net_decoder.cuda()
    capsule_net_decoder = capsule_net_decoder.eval()

    # load the model for capsule wise part segmentation
    caps_seg_net = CapsSegNet(latent_caps_size=opt.latent_caps_size,
                              latent_vec_size=opt.latent_vec_size,
                              num_classes=opt.n_classes)
    if opt.part_model != '':
        caps_seg_net.load_state_dict(torch.load(opt.part_model))
    if USE_CUDA:
        caps_seg_net = caps_seg_net.cuda()
    caps_seg_net = caps_seg_net.eval()

    # Shapes come from the *test* split.
    test_dataset = shapenet_part_loader.PartDataset(
        classification=False,
        class_choice=opt.class_choice,
        npoints=opt.num_points,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=opt.batch_size,
                                                  shuffle=True,
                                                  num_workers=4)

    caps_size = opt.latent_caps_size
    pts_per_caps = opt.num_points // caps_size

    # Reusable point-cloud containers: ground truth, ground-truth cut and
    # paste, and per-capsule patches of the swapped reconstruction.
    pcd_gt_source = [PointCloud() for _ in range(2)]
    pcd_gt_target = [PointCloud() for _ in range(2)]
    pcd_gt_replace_source = [PointCloud() for _ in range(2)]
    pcd_gt_replace_target = [PointCloud() for _ in range(2)]
    pcd_caps_replace_source = [PointCloud() for _ in range(caps_size)]
    pcd_caps_replace_target = [PointCloud() for _ in range(caps_size)]

    # 4x4 homogeneous transforms: rotate 90 degrees about x (a better view of
    # the airplane) and translate each group to its own place in the scene.
    rotation_angle = np.pi / 2
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)

    flip_transform_gt_s = [[1, 0, 0, -3], [0, cosval, -sinval, -1],
                           [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_gt_t = [[1, 0, 0, -3], [0, cosval, -sinval, 1],
                           [0, sinval, cosval, 0], [0, 0, 0, 1]]

    flip_transform_gt_re_s = [[1, 0, 0, 0], [0, cosval, -sinval, -1],
                              [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_gt_re_t = [[1, 0, 0, 0], [0, cosval, -sinval, 1],
                              [0, sinval, cosval, 0], [0, 0, 0, 1]]

    flip_transform_caps_re_s = [[1, 0, 0, 3], [0, cosval, -sinval, -1],
                                [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transform_caps_re_t = [[1, 0, 0, 3], [0, cosval, -sinval, 1],
                                [0, sinval, cosval, 0], [0, 0, 0, 1]]

    colors = plt.cm.tab20((np.arange(20)).astype(int))
    source_color = [colors[5, 0], colors[5, 1], colors[5, 2]]
    target_color = [colors[6, 0], colors[6, 1], colors[6, 2]]
    grey = [0.8, 0.8, 0.8]
    part_replace_no = 1  # the part that is replaced

    def _place(pcd, pts, color, transform):
        """Load *pts* into *pcd*, paint it *color*, apply *transform*; return it."""
        pcd.points = Vector3dVector(pts)
        pcd.paint_uniform_color(color)
        pcd.transform(transform)
        return pcd

    for batch_id, data in enumerate(test_dataloader):
        points, part_label, cls_label = data
        if opt.class_choice is not None:
            cls_label[:] = cat_no[opt.class_choice]

        if points.size(0) < opt.batch_size:
            break
        n_shapes = points.size(0)
        all_model_pcd = PointCloud()

        # Split shape 0 (source) and shape 1 (target) into the selected part
        # (..._list0) and the remaining points (..._list1) using GT labels.
        gt_source_list0, gt_source_list1 = [], []
        gt_target_list0, gt_target_list1 = [], []
        for point_id in range(opt.num_points):
            if part_label[0, point_id] == part_replace_no:
                gt_source_list0.append(points[0, point_id, :])
            else:
                gt_source_list1.append(points[0, point_id, :])
            if part_label[1, point_id] == part_replace_no:
                gt_target_list0.append(points[1, point_id, :])
            else:
                gt_target_list1.append(points[1, point_id, :])

        # viz GT with colored part
        all_model_pcd += _place(pcd_gt_source[0], gt_source_list0,
                                source_color, flip_transform_gt_s)
        all_model_pcd += _place(pcd_gt_source[1], gt_source_list1, grey,
                                flip_transform_gt_s)
        all_model_pcd += _place(pcd_gt_target[0], gt_target_list0,
                                target_color, flip_transform_gt_t)
        all_model_pcd += _place(pcd_gt_target[1], gt_target_list1, grey,
                                flip_transform_gt_t)

        # viz GT with the colored parts cut and pasted between the shapes
        all_model_pcd += _place(pcd_gt_replace_source[0], gt_target_list0,
                                target_color, flip_transform_gt_re_s)
        all_model_pcd += _place(pcd_gt_replace_source[1], gt_source_list1,
                                grey, flip_transform_gt_re_s)
        all_model_pcd += _place(pcd_gt_replace_target[0], gt_source_list0,
                                source_color, flip_transform_gt_re_t)
        all_model_pcd += _place(pcd_gt_replace_target[1], gt_target_list1,
                                grey, flip_transform_gt_re_t)

        # capsule based replacement (inference only)
        points_ = points.transpose(2, 1)
        if USE_CUDA:
            points_ = points_.cuda()
        with torch.no_grad():
            latent_caps, _ = capsule_net(points_)

            # Append the one-hot object class to every capsule vector. Sized
            # to the whole batch so torch.cat works for any batch size.
            cur_label_one_hot = np.zeros((n_shapes, n_cats), dtype=np.float32)
            for i in range(n_shapes):
                cur_label_one_hot[i, cls_label[i]] = 1
            expand = torch.from_numpy(cur_label_one_hot).unsqueeze(2).expand(
                n_shapes, n_cats, caps_size).transpose(1, 2)
            expand = expand.to(latent_caps.device)

            # predict the part label of each capsule
            latent_caps_with_one_hot = torch.cat((latent_caps, expand),
                                                 2).transpose(2, 1)
            output_digit = caps_seg_net(latent_caps_with_one_hot)
            # Push scores of parts outside the shape's category far below the
            # minimum so they can never win the argmax.
            for i in range(2):
                iou_oids = object2setofoid[objcats[cls_label[i]]]
                cat_oids = set(iou_oids)
                non_cat_labels = [k for k in range(50) if k not in cat_oids]
                mini = torch.min(output_digit[i, :, :])
                output_digit[i, :, non_cat_labels] = mini - 1000
            pred_choice = output_digit.cpu().max(2)[1]

            # Capsules assigned to the selected part in BOTH shapes
            # (iou_oids belongs to shape 1 after the loop above).
            part_no = iou_oids[part_replace_no]
            part_viz = [
                caps_no for caps_no in range(caps_size)
                if pred_choice[0, caps_no] == part_no
                and pred_choice[1, caps_no] == part_no
            ]

            # Swap those capsules between the two shapes and decode.
            latent_caps_replace = latent_caps.clone()
            for caps_no in part_viz:
                latent_caps_replace[0, caps_no] = latent_caps[1, caps_no]
                latent_caps_replace[1, caps_no] = latent_caps[0, caps_no]
            reconstructions_replace = capsule_net_decoder(latent_caps_replace)
        reconstructions_replace = reconstructions_replace.transpose(1, 2).cpu()

        part_viz_set = set(part_viz)
        for j in range(caps_size):
            # Reconstructed points j, j + caps_size, j + 2 * caps_size, ...
            # form the patch drawn for capsule j.
            patch_s = reconstructions_replace[0][j::caps_size][:pts_per_caps]
            patch_t = reconstructions_replace[1][j::caps_size][:pts_per_caps]
            highlighted = j in part_viz_set
            all_model_pcd += _place(pcd_caps_replace_source[j],
                                    patch_s.contiguous(),
                                    target_color if highlighted else grey,
                                    flip_transform_caps_re_s)
            all_model_pcd += _place(pcd_caps_replace_target[j],
                                    patch_t.contiguous(),
                                    source_color if highlighted else grey,
                                    flip_transform_caps_re_t)
        draw_geometries([all_model_pcd])
def main():
    """Encode a dataset with a trained PointCapsNet and store the latent capsules.

    Every full batch (``opt.batch_size`` shapes) of the dataset selected by
    ``opt.dataset`` is pushed through the network. The latent capsules
    (``opt.latent_caps_size`` x ``opt.latent_vec_size`` per shape) and the
    per-shape class labels are appended to the HDF5 datasets ``data`` and
    ``cls_label`` in
    ``<BASE_DIR>/../../dataset/<opt.dataset>/latent_caps/saved_{train,test}_wo_part_label.h5``.
    An existing output file is replaced. The training split is used when
    ``opt.save_training`` is set, otherwise the test split.

    Raises:
        ValueError: if ``opt.dataset`` is not one of 'shapenet_part',
            'shapenet_core13', 'shapenet_core55' or 'modelnet40'.
    """
    # Fall back to the CPU instead of failing on ``.cuda()`` without a GPU.
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)

    if opt.model != '':
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        capsule_net.load_state_dict(torch.load(opt.model, map_location=device))
    else:
        print('pls set the model path')

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        capsule_net = torch.nn.DataParallel(capsule_net)
        capsule_net.to(device)

    # Two loader families: torch DataLoaders ('shapenet_part',
    # 'shapenet_core13') and dataset objects driven through
    # has_next_batch()/next_batch() ('shapenet_core55', 'modelnet40').
    dataloader = None
    if opt.dataset == 'shapenet_part':
        split = 'train' if opt.save_training else 'test'
        dataset = shapenet_part_loader.PartDataset(classification=True,
                                                   npoints=opt.num_points,
                                                   split=split)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=4)
    elif opt.dataset == 'shapenet_core13':
        dataset = shapenet_core13_loader.ShapeNet(normal=False,
                                                  npoints=opt.num_points,
                                                  train=opt.save_training)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=4)
    elif opt.dataset == 'shapenet_core55':
        dataset = shapenet_core55_loader.Shapnet55Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=opt.save_training)
    elif opt.dataset == 'modelnet40':
        dataset = modelnet40_loader.ModelNetH5Dataset(
            batch_size=opt.batch_size,
            npoints=opt.num_points,
            shuffle=True,
            train=opt.save_training)
    else:
        # Fail before the existing output file is deleted below.
        raise ValueError('unsupported dataset: %s' % opt.dataset)

    # Resolve the output file and start from an empty one.
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    out_file_path = os.path.join(dataset_main_path, opt.dataset, 'latent_caps')
    os.makedirs(out_file_path, exist_ok=True)
    if opt.save_training:
        out_file_name = os.path.join(out_file_path,
                                     "saved_train_wo_part_label.h5")
    else:
        out_file_name = os.path.join(out_file_path,
                                     "saved_test_wo_part_label.h5")
    if os.path.exists(out_file_name):
        os.remove(out_file_name)

    def _encode(points):
        """Encode one batch of points and return its latent capsules as numpy.

        ``points`` is the tensor yielded by the loader; it is transposed on
        dims (2, 1) before the forward pass (presumably (batch, num_points, 3)
        -> (batch, 3, num_points) — confirm against the loaders).
        """
        points = points.transpose(2, 1)
        if USE_CUDA:
            points = points.cuda()
        latent_caps, _ = capsule_net(points)
        return latent_caps.cpu().detach().numpy()

    # ``with`` closes the file even if a batch fails mid-way.
    with h5py.File(out_file_name, 'w', libver='latest') as fw:
        dset = fw.create_dataset(
            "data", (1, opt.latent_caps_size, opt.latent_vec_size),
            maxshape=(None, opt.latent_caps_size, opt.latent_vec_size),
            dtype='<f4')
        dset_c = fw.create_dataset("cls_label", (1, ),
                                   maxshape=(None, ),
                                   dtype='uint8')
        fw.swmr_mode = True

        def _append(start, count, latent_caps, labels):
            """Grow both datasets to ``start + count`` rows and write one batch.

            The initial placeholder row is overwritten by the first batch,
            because ``start`` begins at 0. Returns the new number of rows.
            """
            end = start + count
            dset.resize((end, opt.latent_caps_size, opt.latent_vec_size))
            dset_c.resize((end, ))
            dset[start:end, :, :] = latent_caps
            dset_c[start:end] = labels
            dset.flush()
            dset_c.flush()
            return end

        data_size = 0
        capsule_net.eval()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            if dataloader is not None:
                # process for 'shapenet_part' or 'shapenet_core13'
                for batch_id, (points, cls_label) in enumerate(dataloader):
                    # Only full batches are stored.
                    if points.size(0) < opt.batch_size:
                        break
                    data_size = _append(data_size, points.size(0),
                                        _encode(points),
                                        cls_label.squeeze().numpy())
                    print('accumalate of batch %d, and datasize is %d ' %
                          ((batch_id), (dset.shape[0])))
            else:
                # process for 'shapenet_core55' or 'modelnet40'.
                # next_batch() as called here returns only (batch_id, points),
                # so no class labels are available; 0 is stored as a
                # placeholder. TODO confirm whether the loader can provide labels.
                print('warning: no class labels for %s; storing 0' %
                      opt.dataset)
                while dataset.has_next_batch():
                    batch_id, points_ = dataset.next_batch()
                    points = torch.from_numpy(points_)
                    if points.size(0) < opt.batch_size:
                        break
                    data_size = _append(
                        data_size, points.size(0), _encode(points),
                        np.zeros(points.size(0), dtype=np.uint8))
                    print('accumalate of batch %d, and datasize is %d ' %
                          ((batch_id), (dset.shape[0])))
Пример #17
0
def train_net_new(cfg):
    """Train GRNet for point-cloud completion with optional part segmentation.

    Builds ShapeNet-Part train/test loaders for ``class_name``. For every
    input cloud it zeroes the ``crop_point_num`` points nearest to a randomly
    picked centre, then optimises Chamfer + gridding losses on the sparse and
    dense completions. When the module-level ``train_seg`` flag is set, these
    are mixed with cross-entropy segmentation losses. Segmentation on the
    sparse output is supervised from epoch 5 on, and on the dense output from
    epoch 7 on. A checkpoint is written whenever segmentation is disabled or
    the validation mIoU from ``test_net_new`` improves.

    Args:
        cfg: experiment config (``cfg.DIR``, ``cfg.TRAIN``, ``cfg.NETWORK``,
            ``cfg.CONST``). ``cfg.DIR.CHECKPOINTS`` and ``cfg.DIR.LOGS`` are
            overwritten with timestamped paths under ``cfg.DIR.OUT_PATH``.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Set up data loader
    pnum = 2048
    crop_point_num = 512
    workers = 1
    batchSize = 16

    class_name = "Pistol"
    dataset_root = './dataset/shapenetcore_partanno_segmentation_benchmark_v0/'

    train_dataset_loader = shapenet_part_loader.PartDataset(
        root=dataset_root,
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='train')
    train_data_loader = torch.utils.data.DataLoader(train_dataset_loader,
                                                    batch_size=batchSize,
                                                    shuffle=True,
                                                    num_workers=int(workers))

    test_dataset_loader = shapenet_part_loader.PartDataset(
        root=dataset_root,
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='test')
    val_data_loader = torch.utils.data.DataLoader(test_dataset_loader,
                                                  batch_size=batchSize,
                                                  shuffle=True,
                                                  num_workers=int(workers))

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    os.makedirs(cfg.DIR.CHECKPOINTS, exist_ok=True)

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    # Create the networks
    grnet = GRNet(cfg, seg_class_no)
    grnet.apply(utils.helpers.init_weights)
    logging.debug('Parameters in GRNet: %d.' %
                  utils.helpers.count_parameters(grnet))

    # Move the network to GPU if possible
    grnet = grnet.to(device)

    # Create the optimizers
    grnet_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              grnet.parameters()),
                                       lr=cfg.TRAIN.LEARNING_RATE,
                                       weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                       betas=cfg.TRAIN.BETAS)
    grnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        grnet_optimizer,
        milestones=cfg.TRAIN.LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)
    # Follow `device` so the CPU fallback above does not crash here.
    seg_criterion = torch.nn.CrossEntropyLoss().to(device)

    # Load pretrained weights if given. Only the weights are restored;
    # training still starts from epoch 1.
    init_epoch = 0
    best_metrics = None
    if 'WEIGHTS' in cfg.CONST:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=device)
        grnet.load_state_dict(checkpoint['grnet'])
        logging.info(
            'Recover complete. Current epoch = #%d; best metrics = %s.' %
            (init_epoch, best_metrics))

    # Candidate crop centres, built once. Per sample one is drawn, and the
    # crop_point_num input points closest to it are zeroed.
    crop_centers = [
        torch.Tensor(c).to(device)
        for c in ([1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0], [-1, 1, 0])
    ]
    lamb = 0.8  # weight of the geometric losses vs. segmentation losses
    miou = 0

    # Training/Testing the network
    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(['SparseLoss', 'DenseLoss'])

        grnet.train()

        # Staged segmentation supervision. `>=` keeps a stage active even
        # when the loop starts past its threshold epoch.
        train_seg_on_sparse = epoch_idx >= 5
        train_seg_on_dense = epoch_idx >= 7

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (data, seg, model_ids) in enumerate(train_data_loader):
            data_time.update(time() - batch_end_time)

            data = data.to(device)
            seg = seg.to(device)
            # Float32 copy that gets cropped; `data` stays the ground truth.
            input_cropped1 = data.to(torch.float32, copy=True)

            # remove points to make input incomplete
            for m in range(data.size(0)):
                p_center = random.choice(crop_centers)
                distances = torch.sum((data[m] - p_center)**2, dim=1)
                order = torch.argsort(distances)
                input_cropped1[m, order[:crop_point_num]] = 0

            if save_crop_mode:
                np.save(class_name + "_orig", data[0].detach().cpu().numpy())
                np.save(class_name + "_cropped",
                        input_cropped1[0].detach().cpu().numpy())
                sys.exit()

            sparse_ptcloud, dense_ptcloud, sparse_seg, full_seg, dense_seg = grnet(
                input_cropped1)

            data_seg = get_data_seg(data, full_seg)
            seg_loss = seg_criterion(torch.transpose(data_seg, 1, 2), seg)
            if train_seg_on_sparse and train_seg:
                gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                seg_loss += seg_criterion(torch.transpose(sparse_seg, 1, 2),
                                          gt_seg)
                seg_loss /= 2

            if train_seg_on_dense and train_seg:
                gt_seg = get_seg_gts(seg, data, dense_ptcloud)
                dense_seg_loss = seg_criterion(
                    torch.transpose(dense_seg, 1, 2), gt_seg)
                print(dense_seg_loss.item())

            if draw_mode:
                plot_ptcloud(data[0], seg[0], "orig")
                plot_ptcloud(input_cropped1[0], seg[0], "cropped")
                plot_ptcloud(sparse_ptcloud[0],
                             torch.argmax(sparse_seg[0], dim=1), "sparse_pred")
                print(dense_seg.size())
                plot_ptcloud(dense_ptcloud[0], torch.argmax(dense_seg[0],
                                                            dim=1),
                             "dense_pred")
                sys.exit()

            print(seg_loss.item())

            sparse_loss = chamfer_dist(sparse_ptcloud, data).to(device)
            dense_loss = chamfer_dist(dense_ptcloud, data).to(device)
            grid_loss = gridding_loss(sparse_ptcloud, data).to(device)
            if train_seg:
                _loss = lamb * (sparse_loss + dense_loss +
                                grid_loss) + (1 - lamb) * seg_loss
            else:
                _loss = (sparse_loss + dense_loss + grid_loss)
            if train_seg_on_dense and train_seg:
                _loss += (1 - lamb) * dense_seg_loss
            losses.update(
                [sparse_loss.item() * 1000,
                 dense_loss.item() * 1000])

            grnet.zero_grad()
            _loss.backward()
            grnet_optimizer.step()

            n_itr = (epoch_idx - 1) * n_batches + batch_idx
            train_writer.add_scalar('Loss/Batch/Sparse',
                                    sparse_loss.item() * 1000, n_itr)
            train_writer.add_scalar('Loss/Batch/Dense',
                                    dense_loss.item() * 1000, n_itr)

            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            logging.info(
                '[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s'
                % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches,
                   batch_time.val(), data_time.val(),
                   ['%.4f' % l for l in losses.val()]))

        # Validate the current model
        if train_seg:
            miou_new = test_net_new(cfg, epoch_idx, val_data_loader,
                                    val_writer, grnet)
        else:
            miou_new = 0

        grnet_lr_scheduler.step()
        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/Sparse', losses.avg(0), epoch_idx)
        train_writer.add_scalar('Loss/Epoch/Dense', losses.avg(1), epoch_idx)
        logging.info('[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
                     (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time -
                      epoch_start_time, ['%.4f' % l for l in losses.avg()]))

        # Save unconditionally without segmentation, else only on mIoU gains.
        if not train_seg or miou_new > miou:
            file_name = class_name + 'noseg-ckpt-epoch.pth'
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            torch.save({
                'epoch_index': epoch_idx,
                'grnet': grnet.state_dict()
            }, output_path)  # yapf: disable

            logging.info('Saved checkpoint to %s ...' % output_path)
            miou = miou_new

    train_writer.close()
    val_writer.close()
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--drop',type=float,default=0.2)
parser.add_argument('--num_scales',type=int,default=3,help='number of scales')
# type=list would split a CLI value into single characters ("2048" -> ['2','0','4','8']);
# take space-separated integers instead, e.g. --point_scales_list 2048 1024 512.
parser.add_argument('--point_scales_list', type=int, nargs='+', default=[2048, 1024, 512],
                    help='number of points in each scales')
parser.add_argument('--each_scales_size',type=int,default=1,help='each scales size')
parser.add_argument('--wtl2',type=float,default=0.9,help='0 means do not use else use with this weight')
parser.add_argument('--cropmethod', default = 'random_center', help = 'random|center|random_center')
opt = parser.parse_args()
print(opt)

def distance_squre1(p1, p2):
    """Return the squared Euclidean distance between p1 and p2.

    Only indices 0, 1 and 2 of each point are used; any further
    components are ignored.
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    dz = p1[2] - p2[2]
    return dx ** 2 + dy ** 2 + dz ** 2


# Evaluation device: first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Test split of the ShapeNet-Part 'Airplane' category, iterated in a fixed order.
test_dset = shapenet_part_loader.PartDataset(
    root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
    classification=True,
    class_choice='Airplane',
    npoints=opt.pnum,
    split='test',
)
test_dataloader = torch.utils.data.DataLoader(
    test_dset,
    batch_size=opt.batchSize,
    shuffle=False,
    num_workers=int(opt.workers),
)
length = len(test_dataloader)

# Generator: wrapped in DataParallel before the weights are restored
# (presumably because the checkpoint was saved from a wrapped model — confirm).
# The map_location lambda deserialises the checkpoint onto CPU storage.
point_netG = torch.nn.DataParallel(
    _netG(opt.num_scales, opt.each_scales_size, opt.point_scales_list,
          opt.crop_point_num))
point_netG.to(device)
point_netG.load_state_dict(
    torch.load(opt.netG,
               map_location=lambda storage, location: storage)['state_dict'])
point_netG.eval()

criterion_PointLoss = PointLoss_test().to(device)

# Pre-allocated input buffer plus running minimum-loss / counter state.
input_cropped1 = torch.FloatTensor(opt.batchSize, 1, opt.pnum, 3)
errG_min, n = 100, 0
def main():
    """Visualise part interpolation between two shapes with PointCapsNet.

    For every batch of the ShapeNet-Part test split (filtered by
    ``opt.class_choice``), the first two shapes are encoded into latent
    capsules, and ``CapsSegNet`` assigns each capsule a part label.
    Capsules predicted as part ``part_inter_no`` in *both* shapes are
    linearly interpolated from shape 1 (source) towards shape 0 (target).
    The source, target and ``inter_models_number`` interpolated
    reconstructions are drawn in one Open3D window, with the interpolated
    capsules highlighted.

    Reads the module-level ``opt`` options, ``USE_CUDA`` flag and
    ``BASE_DIR``.
    """
    cat_no = {
        'Airplane': 0,
        'Bag': 1,
        'Cap': 2,
        'Car': 3,
        'Chair': 4,
        'Earphone': 5,
        'Guitar': 6,
        'Knife': 7,
        'Lamp': 8,
        'Laptop': 9,
        'Motorbike': 10,
        'Mug': 11,
        'Pistol': 12,
        'Rocket': 13,
        'Skateboard': 14,
        'Table': 15
    }

    # Build object id -> list of overall part ids (indices into the 50 part
    # logits) from the category/part correspondence file.
    dataset_main_path = os.path.abspath(os.path.join(BASE_DIR,
                                                     '../../dataset'))
    oid2cpid_file_name = os.path.join(
        dataset_main_path, opt.dataset,
        'shapenetcore_partanno_segmentation_benchmark_v0/shapenet_part_overallid_to_catid_partid.json'
    )
    with open(oid2cpid_file_name, 'r') as f:
        oid2cpid = json.load(f)
    object2setofoid = {}
    for idx, (objid, _pid) in enumerate(oid2cpid):
        object2setofoid.setdefault(objid, []).append(idx)

    # Second column of each line; indexed by class label and used as the key
    # into object2setofoid.
    all_obj_cat_file = os.path.join(
        dataset_main_path, opt.dataset,
        'shapenetcore_partanno_segmentation_benchmark_v0/synsetoffset2category.txt'
    )
    with open(all_obj_cat_file, 'r') as fin:
        objcats = [line.split()[1] for line in fin]

    # load the model for point caps auto encoder
    capsule_net = PointCapsNet(opt.prim_caps_size, opt.prim_vec_size,
                               opt.latent_caps_size, opt.latent_vec_size,
                               opt.num_points)
    state_dict = torch.load(opt.model) if opt.model != '' else None
    if state_dict is not None:
        capsule_net.load_state_dict(state_dict)
    if USE_CUDA:
        capsule_net = torch.nn.DataParallel(capsule_net).cuda()
    capsule_net = capsule_net.eval()

    # load the model for only decoding (shares the auto-encoder checkpoint;
    # strict=False tolerates the encoder-only keys)
    capsule_net_decoder = PointCapsNetDecoder(opt.prim_caps_size,
                                              opt.prim_vec_size,
                                              opt.latent_caps_size,
                                              opt.latent_vec_size,
                                              opt.num_points)
    if state_dict is not None:
        capsule_net_decoder.load_state_dict(state_dict, strict=False)
    if USE_CUDA:
        capsule_net_decoder = capsule_net_decoder.cuda()
    capsule_net_decoder = capsule_net_decoder.eval()

    # load the model for capsule-wise part segmentation
    caps_seg_net = CapsSegNet(latent_caps_size=opt.latent_caps_size,
                              latent_vec_size=opt.latent_vec_size,
                              num_classes=opt.n_classes)
    if opt.part_model != '':
        caps_seg_net.load_state_dict(torch.load(opt.part_model))
    if USE_CUDA:
        caps_seg_net = caps_seg_net.cuda()
    caps_seg_net = caps_seg_net.eval()

    dataset = shapenet_part_loader.PartDataset(
        classification=False,
        class_choice=opt.class_choice,
        npoints=opt.num_points,
        split='test')
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=4)

    caps_size = opt.latent_caps_size
    inter_models_number = 5  # the number of interpolated models
    part_inter_no = 1  # the part (within the category) that is interpolated

    # One Open3D point cloud per capsule for source, target and each
    # interpolated model.
    pcd_list_source = [PointCloud() for _ in range(caps_size)]
    pcd_list_target = [PointCloud() for _ in range(caps_size)]
    pcd_list_inter = [[PointCloud() for _ in range(caps_size)]
                      for _ in range(inter_models_number)]

    # View transforms (rotation about x by 90 degrees plus a translation)
    # that lay the shapes out side by side; tuned for airplanes.
    rotation_angle = np.pi / 2
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)
    flip_transforms = [[1, 0, 0, -2], [0, cosval, -sinval, 1.5],
                       [0, sinval, cosval, 0], [0, 0, 0, 1]]
    flip_transformt = [[1, 0, 0, 2], [0, cosval, -sinval, 1.5],
                       [0, sinval, cosval, 0], [0, 0, 0, 1]]
    transform_range = np.arange(-4, 4.5, 2)

    colors = plt.cm.tab20((np.arange(20)).astype(int))

    # Decoder output row layout: capsule j produced rows j, j + caps_size,
    # j + 2 * caps_size, ... (n_per_caps rows in total).
    n_per_caps = int(opt.num_points / caps_size)
    patch_end = caps_size * n_per_caps

    def _capsule_patch(recon, j):
        """Return the points of one reconstruction decoded by capsule j."""
        return recon[j:patch_end:caps_size].contiguous()

    def _fill_pcd(pcd, patch, highlighted, color_idx):
        """Load patch into pcd; paint it colors[color_idx] if highlighted, else grey."""
        pcd.points = Vector3dVector(patch)
        if highlighted:
            pcd.paint_uniform_color([
                colors[color_idx, 0], colors[color_idx, 1],
                colors[color_idx, 2]
            ])
        else:
            pcd.paint_uniform_color([0.8, 0.8, 0.8])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch_id, data in enumerate(dataloader):
            points, part_label, cls_label = data
            if opt.class_choice is not None:
                cls_label[:] = cat_no[opt.class_choice]

            if points.size(0) < opt.batch_size:
                break

            points_ = points.transpose(2, 1)
            if USE_CUDA:
                points_ = points_.cuda()
            latent_caps, reconstructions = capsule_net(points_)
            reconstructions = reconstructions.transpose(1, 2).data.cpu()
            # Only the first two shapes are used: 1 = source, 0 = target.
            latent_caps = latent_caps[:2]
            device = latent_caps.device

            # one-hot class label of the two shapes, repeated for each capsule
            cur_label_one_hot = torch.zeros(2, 16)
            for i in range(2):
                cur_label_one_hot[i, int(cls_label[i])] = 1
            expand = cur_label_one_hot.unsqueeze(2).expand(
                2, 16, caps_size).transpose(1, 2).to(device)

            # predict the part label of each capsule
            latent_caps_with_one_hot = torch.cat((latent_caps, expand),
                                                 2).transpose(2, 1)
            output_digit = caps_seg_net(latent_caps_with_one_hot)
            # Push logits of parts outside the shape's category far below the
            # minimum so argmax only picks parts of that category.
            for i in range(2):
                iou_oids = object2setofoid[objcats[int(cls_label[i])]]
                non_cat_labels = [k for k in range(50) if k not in iou_oids]
                mini = torch.min(output_digit[i, :, :])
                output_digit[i, :, non_cat_labels] = mini - 1000
            pred_choice = output_digit.data.cpu().max(2)[1]

            # Capsules predicted as the chosen part in both shapes.
            # iou_oids here belongs to shape 1 (last loop iteration); the part
            # label is translated from [0, parts) to [0, 50).
            part_no = iou_oids[part_inter_no]
            part_viz = [
                c for c in range(caps_size)
                if pred_choice[0, c] == part_no and pred_choice[1, c] == part_no
            ]
            highlighted = set(part_viz)

            # source (shape 1) and target (shape 0) reconstructions
            for j in range(caps_size):
                _fill_pcd(pcd_list_source[j],
                          _capsule_patch(reconstructions[1], j),
                          j in highlighted, 6)
                _fill_pcd(pcd_list_target[j],
                          _capsule_patch(reconstructions[0], j),
                          j in highlighted, 5)

            # Interpolate the selected capsules from source towards target.
            # Each selected capsule is a latent_vec_size vector, so the
            # offsets have shape (len(part_viz), latent_vec_size).
            source_caps = latent_caps[1]
            latent_caps_inter = source_caps.unsqueeze(0).repeat(
                inter_models_number, 1, 1)
            if part_viz:
                st_diff = latent_caps[0, part_viz] - source_caps[part_viz]
                for i in range(inter_models_number):
                    latent_caps_inter[i, part_viz] = source_caps[
                        part_viz] + st_diff * i / (inter_models_number - 1)

            # decode the interpolated latent capsules
            reconstructions_inter = capsule_net_decoder(
                latent_caps_inter).transpose(1, 2).data.cpu()
            for i in range(inter_models_number):
                for j in range(caps_size):
                    _fill_pcd(pcd_list_inter[i][j],
                              _capsule_patch(reconstructions_inter[i], j),
                              j in highlighted, 6)

            # show source, target and all interpolations in one window
            all_point = PointCloud()
            for j in range(caps_size):
                pcd_list_source[j].transform(flip_transforms)
                pcd_list_target[j].transform(flip_transformt)
                all_point += pcd_list_source[j]
                all_point += pcd_list_target[j]

            for r in range(inter_models_number):
                flip_transform_inter = [[1, 0, 0, transform_range[r]],
                                        [0, cosval, -sinval, -2],
                                        [0, sinval, cosval, 0], [0, 0, 0, 1]]
                for pcd in pcd_list_inter[r]:
                    pcd.transform(flip_transform_inter)
                    all_point += pcd
            draw_geometries([all_point])