コード例 #1
0
ファイル: test_fig.py プロジェクト: DreamBlack/APCNet
def get_result_for_pfNet(incomplete, netpath=""):
    """Run a pretrained PF-Net generator on an incomplete point cloud.

    Args:
        incomplete: tensor holding the partial input cloud; squeezed on dim 1
            below, so it is assumed to be (B, 1, N, 3) — TODO confirm caller.
        netpath: checkpoint path; falls back to the module-level
            ``pfnet_path`` when empty.

    Returns:
        numpy array of the generated (completed) points for the first
        batch element.
    """
    point_netG = _netG(3, 1, [1024, 512, 256], 256)
    point_netG = torch.nn.DataParallel(point_netG)
    point_netG.to(device)
    # Fall back to the default checkpoint when no explicit path is given.
    # (The original `else: netpath = netpath` branch was a no-op and is
    # removed.)
    if netpath == "":
        netpath = pfnet_path
    point_netG.load_state_dict(torch.load(
        netpath, map_location=lambda storage, location: storage)['state_dict'],
                               strict=False)
    point_netG.eval()

    ############################
    # (1) data prepare
    ###########################
    input_cropped1 = torch.squeeze(incomplete, 1)
    # Build the three input scales PF-Net expects via farthest-point
    # sampling (the RAN flag selects the sampling variant inside utils).
    input_cropped2_idx = utils.farthest_point_sample(input_cropped1,
                                                     512,
                                                     RAN=True)
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
    input_cropped3_idx = utils.farthest_point_sample(input_cropped1,
                                                     256,
                                                     RAN=False)
    input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
    input_cropped2 = input_cropped2.to(device)
    input_cropped3 = input_cropped3.to(device)
    input_cropped = [input_cropped1, input_cropped2, input_cropped3]
    # Inference only: avoid building an autograd graph. The former
    # Variable(..., requires_grad=False) wrappers are redundant since
    # PyTorch 0.4 and are dropped.
    with torch.no_grad():
        fake_center1, fake_center2, fake = point_netG(input_cropped)
    # Bug fix: the original `fake.cuda().data.cpu()` forced a pointless
    # CPU->GPU->CPU round-trip and failed on CPU-only machines.
    return fake.detach().cpu().squeeze(0).numpy()
コード例 #2
0
        
    # Keep everything EXCEPT the cropped-out region: distance_order is sorted
    # by distance to the chosen viewpoint, so indices past crop_point_num are
    # the points that survive the crop.
    crop_num_list = []
    for num_ in range(opt.pnum-opt.crop_point_num):
        crop_num_list.append(distance_order[num_+opt.crop_point_num][0])
    indices = torch.LongTensor(crop_num_list)
    input_cropped_partial[0,0]=torch.index_select(real_point[0,0],0,indices)
    input_cropped_partial = torch.squeeze(input_cropped_partial,1)
    input_cropped_partial = input_cropped_partial.to(device)
     
    real_center = torch.squeeze(real_center,1)
#    real_center_key_idx = utils.farthest_point_sample(real_center,64,train = False)
#    real_center_key = utils.index_points(real_center,real_center_key_idx)
#    input_cropped1 = input_cropped1.to(device)
    
    # Build the multi-scale input list PF-Net expects via farthest-point
    # sampling; point_scales_list presumably holds per-scale point counts —
    # TODO confirm against opt definition.
    input_cropped1 = torch.squeeze(input_cropped1,1)
    input_cropped2_idx = utils.farthest_point_sample(input_cropped1,opt.point_scales_list[1],RAN = True)
    input_cropped2     = utils.index_points(input_cropped1,input_cropped2_idx)
    input_cropped3_idx = utils.farthest_point_sample(input_cropped1,opt.point_scales_list[2],RAN = False)
    input_cropped3     = utils.index_points(input_cropped1,input_cropped3_idx)
    input_cropped2 = input_cropped2.to(device)
    input_cropped3 = input_cropped3.to(device)      
    input_cropped  = [input_cropped1,input_cropped2,input_cropped3]
    
#    fake,fake_part = point_netG(input_cropped)
    # The generator returns coarse-to-fine predictions; only `fake` (the
    # finest) is scored against the ground-truth missing region below.
    fake_center1,fake_center2,fake=point_netG(input_cropped)
    fake_whole = torch.cat((input_cropped_partial,fake),1)
    fake_whole = fake_whole.to(device)
    real_point = real_point.to(device)
    real_center = real_center.to(device)
    dist_all, dist1, dist2 = criterion_PointLoss(torch.squeeze(fake,1),torch.squeeze(real_center,1))#+0.1*criterion_PointLoss(torch.squeeze(fake_part,1),torch.squeeze(real_center,1))
    dist_all=dist_all.cpu().detach().numpy()
コード例 #3
0
                                     sp] = real_point[m, 0,
                                                      distance_order[sp][0]]
#                print(real_center.size(),input_cropped1.size())
#label.data.resize_([batch_size,1]).fill_(real_label)
        # Move the batch tensors to the training device.
        real_point = real_point.to(device)
        real_center = real_center.to(device)
        input_cropped1 = input_cropped1.to(device)
        ############################
        # (1) data prepare
        ###########################
        real_center = Variable(real_center, requires_grad=True)
        real_center = torch.squeeze(real_center, 1)

        # Multi-scale inputs for PF-Net: 1024- and 512-point FPS subsamples
        # of the cropped cloud (RAN toggles the sampling variant in utils).
        input_cropped1 = torch.squeeze(input_cropped1, 1)
        input_cropped2_idx = utils.farthest_point_sample(input_cropped1,
                                                         1024,
                                                         RAN=True)
        input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
        input_cropped3_idx = utils.farthest_point_sample(input_cropped1,
                                                         512,
                                                         RAN=False)
        input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
        input_cropped1 = Variable(input_cropped1, requires_grad=True)
        input_cropped2 = Variable(input_cropped2, requires_grad=True)
        input_cropped3 = Variable(input_cropped3, requires_grad=True)
        input_cropped2 = input_cropped2.to(device)
        input_cropped3 = input_cropped3.to(device)
        input_cropped = [input_cropped1, input_cropped2, input_cropped3]
        # Switch to train mode and clear stale gradients before the forward.
        point_netG = point_netG.train()
        point_netG.zero_grad()
        fake = point_netG(input_cropped)
コード例 #4
0
            gt = gt.to(device)
            image = image.to(device)

            # requires_grad=True mirrors the training loop; the explicit
            # .cuda() calls assume a GPU is present — NOTE(review): confirm
            # this path is GPU-only by design.
            incomplete = Variable(incomplete, requires_grad=True).cuda()
            image = Variable(image.float(), requires_grad=True).cuda()
            image = torch.squeeze(image, 1)
            # In-place resize/fill reuses the preallocated label buffer.
            label.resize_([batch_size, 1]).fill_(real_label)
            label = label.to(device)

            ############################
            # (1) data prepare
            ###########################
            real_center = Variable(gt, requires_grad=True)
            real_center = torch.squeeze(real_center, 1)
            # 64- and 128-point FPS "key point" targets for the generator's
            # intermediate (coarse) outputs.
            real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                               64,
                                                               RAN=False)
            real_center_key1 = utils.index_points(real_center,
                                                  real_center_key1_idx)
            real_center_key1 = Variable(real_center_key1, requires_grad=True)

            real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                               128,
                                                               RAN=True)
            real_center_key2 = utils.index_points(real_center,
                                                  real_center_key2_idx)
            real_center_key2 = Variable(real_center_key2, requires_grad=True)

            # Multi-scale inputs built from the incomplete cloud.
            input_cropped1 = torch.squeeze(incomplete, 1)
            input_cropped2_idx = utils.farthest_point_sample(
                input_cropped1, opt.point_scales_list[1], RAN=True)
コード例 #5
0
# Inference script: load a pretrained PF-Net generator and complete a single
# partial point cloud read from a CSV file.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
point_netG = _netG(opt.num_scales,opt.each_scales_size,opt.point_scales_list,opt.crop_point_num) 
point_netG = torch.nn.DataParallel(point_netG)
point_netG.to(device)
point_netG.load_state_dict(torch.load(opt.netG,map_location=lambda storage, location: storage)['state_dict'])   
point_netG.eval()

# Pad the partial cloud with 512 zero points so the tensor reaches the
# generator's expected input size.
input_cropped1 = np.loadtxt(opt.infile,delimiter=',')
input_cropped1 = torch.FloatTensor(input_cropped1)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)
Zeros = torch.zeros(1,512,3)
input_cropped1 = torch.cat((input_cropped1,Zeros),1)


# Multi-scale FPS subsamples for the PF-Net input list.
input_cropped2_idx = utils.farthest_point_sample(input_cropped1,opt.point_scales_list[1],RAN = True)
input_cropped2     = utils.index_points(input_cropped1,input_cropped2_idx)
input_cropped3_idx = utils.farthest_point_sample(input_cropped1,opt.point_scales_list[2],RAN = False)
input_cropped3     = utils.index_points(input_cropped1,input_cropped3_idx)
# NOTE(review): input_cropped4 is computed but not fed to the generator in
# the visible code — possibly used further below, or dead work; confirm.
input_cropped4_idx = utils.farthest_point_sample(input_cropped1,256,RAN = True)
input_cropped4     = utils.index_points(input_cropped1,input_cropped4_idx)
input_cropped2 = input_cropped2.to(device)
input_cropped3 = input_cropped3.to(device)      
input_cropped  = [input_cropped1,input_cropped2,input_cropped3]

fake_center1,fake_center2,fake=point_netG(input_cropped)
fake = fake.cuda()
fake_center1 = fake_center1.cuda()
fake_center2 = fake_center2.cuda()

コード例 #6
0
                   opt.crop_point_num)
point_netG = torch.nn.DataParallel(point_netG)
point_netG.to(device)
point_netG.load_state_dict(
    torch.load(opt.netG,
               map_location=lambda storage, location: storage)['state_dict'])
point_netG.eval()

# Read one partial cloud from CSV and pad it with 512 zero points so the
# tensor reaches the generator's expected input size.
input_cropped1 = np.loadtxt(opt.infile, delimiter=',')
input_cropped1 = torch.FloatTensor(input_cropped1)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)
Zeros = torch.zeros(1, 512, 3)
input_cropped1 = torch.cat((input_cropped1, Zeros), 1)

# Multi-scale FPS subsamples (1024 / 512 points) for the PF-Net input list.
input_cropped2_idx = utils.farthest_point_sample(input_cropped1,
                                                 1024,
                                                 RAN=True)
input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
input_cropped3_idx = utils.farthest_point_sample(input_cropped1,
                                                 512,
                                                 RAN=False)
input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
# NOTE(review): input_cropped4 is not used by the visible code — confirm
# whether it is consumed later or is dead work.
input_cropped4_idx = utils.farthest_point_sample(input_cropped1, 256, RAN=True)
input_cropped4 = utils.index_points(input_cropped1, input_cropped4_idx)
input_cropped2 = input_cropped2.to(device)
input_cropped3 = input_cropped3.to(device)
input_cropped = [input_cropped1, input_cropped2, input_cropped3]

fake_center1, fake_center2, fake = point_netG(input_cropped)
fake = fake.cuda()
fake_center1 = fake_center1.cuda()
コード例 #7
0
    gt = gt.to(device)
    image = image.to(device)
    # Full ground-truth cloud = missing region (gt) concatenated with the
    # partial input along the point dimension.
    complete_gt = torch.cat([gt, incomplete], dim=1)
    complete_gt=complete_gt.to(device)
    incomplete = incomplete.to(device)
    image = image.to(device)

    # Evaluation path: gradients are not needed.
    incomplete = Variable(incomplete, requires_grad=False)
    image = Variable(image.float(), requires_grad=False)

    ############################
    # (1) data prepare
    ###########################
    real_center = Variable(gt, requires_grad=False)
    real_center = torch.squeeze(real_center, 1)
    # 64/128-point FPS key-point targets for the intermediate outputs.
    real_center_key1_idx = utils.farthest_point_sample(real_center, 64, RAN=False)
    real_center_key1 = utils.index_points(real_center, real_center_key1_idx)
    real_center_key1 = Variable(real_center_key1, requires_grad=False)

    real_center_key2_idx = utils.farthest_point_sample(real_center, 128, RAN=True)
    real_center_key2 = utils.index_points(real_center, real_center_key2_idx)
    real_center_key2 = Variable(real_center_key2, requires_grad=False)

    # Multi-scale inputs (512 / 256 points) for the generator.
    input_cropped1 = torch.squeeze(incomplete, 1)
    input_cropped2_idx = utils.farthest_point_sample(input_cropped1, 512, RAN=True)
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
    input_cropped3_idx = utils.farthest_point_sample(input_cropped1, 256, RAN=False)
    input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
    input_cropped1 = Variable(input_cropped1, requires_grad=False)
    input_cropped2 = Variable(input_cropped2, requires_grad=False)
    input_cropped3 = Variable(input_cropped3, requires_grad=False)
コード例 #8
0
# PCN inference: load a pretrained autoencoder, complete a partial cloud read
# from CSV, and write both the input crop and the reconstruction to disk.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
PCN = Autoencoder(opt.num_inputs, opt.num_coarses, opt.num_fines,
                  opt.grid_size)
PCN.load_state_dict(
    torch.load(opt.model,
               map_location=lambda storage, location: storage)['state_dict'])
print("Let's use", torch.cuda.device_count(), "GPUs!")
PCN.to(device)
PCN = torch.nn.DataParallel(PCN)
PCN.eval()

input_cropped1 = np.loadtxt(opt.infile, delimiter=',')
input_cropped1 = torch.FloatTensor(input_cropped1)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)
input_cropped1 = input_cropped1.to(device)
# Downsample to the encoder's fixed input size via farthest-point sampling.
input_key_cropped_index = farthest_point_sample(input_cropped1,
                                                opt.num_inputs,
                                                RAN=False)
input_key_cropped = index_points(input_cropped1,
                                 input_key_cropped_index)  #BX1024X3

# The decoder emits coarse and fine reconstructions; only the fine output
# is saved here.
coarses, fine = PCN(input_key_cropped)
fine = fine.cpu()
np_fake = fine[0].detach().numpy()
input_cropped1 = input_cropped1.cpu()
np_crop = input_cropped1[0].numpy()

np.savetxt('test_from_prn_to_PCN/crop_PCN' + '.csv', np_crop, fmt="%f,%f,%f")
np.savetxt('test_from_prn_to_PCN/fake_PCN' + '.csv', np_fake, fmt="%f,%f,%f")
コード例 #9
0
ファイル: ModelNet40Loader.py プロジェクト: Wangzs111/PF-NET
    # NOTE(review): this rebinds the name `transforms` (presumably an
    # imported module) to the composed transform object — shadowing that only
    # works because the module name is not needed again afterwards.
    transforms = transforms.Compose(
        [
            d_utils.PointcloudToTensor(),
#            d_utils.PointcloudRotate(axis=np.array([1, 0, 0])),
#            d_utils.PointcloudScale(),
#            d_utils.PointcloudTranslate(),
#            d_utils.PointcloudJitter(),
        ]
    )
    dset = ModelNet40Cls(1024, train=True, transforms=transforms)   
    print(dset[0][1])
    print(len(dset))
    dloader = torch.utils.data.DataLoader(dset, batch_size=64, shuffle=True)

    # Smoke-test loop: FPS-downsample each batch to 512 and 256 points.
    for i, Data in enumerate(dloader, 0):
        real_point, target = Data
        print('1')
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        real_point = real_point.to(device)
        real_point2_idx = utils.farthest_point_sample(real_point,512)
        real_point2 = utils.index_points(real_point,real_point2_idx)
        real_point3_idx = utils.farthest_point_sample(real_point,256)
        real_point3 = utils.index_points(real_point,real_point3_idx)
        
#        model1 = real_point[0].numpy()
#        model2 = real_point2[0].numpy()
#        model3 = real_point3[0].numpy()
#        
#        np.savetxt('test-examples/model'+'.txt', model1, fmt = "%f %f %f")
#        np.savetxt('test-examples/model2'+'.txt', model2, fmt = "%f %f %f")
#        np.savetxt('test-examples/model3'+'.txt', model3, fmt = "%f %f %f")
コード例 #10
0
                    distance_squre(real_point[m, 0, n], p_center))
            # Sort point indices by distance to the chosen viewpoint; the
            # nearest crop_point_num points form the "cropped-out" region.
            distance_order = sorted(enumerate(distance_list),
                                    key=lambda x: x[1])

            # The remaining (farther) points become the partial input cloud.
            crop_num_list = []
            for num in range(opt.num_fines - opt.crop_point_num):
                crop_num_list.append(distance_order[num +
                                                    opt.crop_point_num][0])
            indices = torch.LongTensor(crop_num_list)
            input_cropped[m, 0] = torch.index_select(real_point[m, 0], 0,
                                                     indices)

        real_point = torch.squeeze(real_point, 1)  #BX2048X3
        input_cropped = torch.squeeze(input_cropped, 1)  #BX(2048-512)X3

        # FPS key points: the coarse target comes from the full cloud, the
        # encoder input from the cropped cloud.
        real_key_point_index = farthest_point_sample(real_point,
                                                     opt.num_coarses, False)
        real_key_point = index_points(real_point,
                                      real_key_point_index)  #BX1024X3

        input_key_cropped_index = farthest_point_sample(
            input_cropped, opt.num_inputs, False)
        input_key_cropped = index_points(input_cropped,
                                         input_key_cropped_index)  #BX1024X3

        real_point = Variable(real_point, requires_grad=True)
        real_key_point = Variable(real_key_point, requires_grad=True)
        input_key_cropped = Variable(input_key_cropped, requires_grad=True)

        real_point = real_point.to(device)
        real_key_point = real_key_point.to(device)
        input_key_cropped = input_key_cropped.to(device)
コード例 #11
0
ファイル: Train_PFNet.py プロジェクト: HeunSeungLim/PF_HL
def run():
    """Set up PF-Net training: networks, checkpoints, data, losses, optimizers.

    NOTE(review): the training loop continues beyond this excerpt; only the
    setup portion is visible and documented here.
    """
    print(opt)

    # ANSI helper for printing highlighted (blue) text.
    blue = lambda x: '\033[94m' + x + '\033[0m'
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    USE_CUDA = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    point_netG = _netG(opt.num_scales, opt.each_scales_size,
                       opt.point_scales_list, opt.crop_point_num)
    point_netD = _netlocalD(opt.crop_point_num)
    cudnn.benchmark = True
    resume_epoch = 0

    def weights_init_normal(m):
        """Initialize conv weights to N(0, 0.02); batch-norm weights to
        N(1, 0.02) with zero bias. Other module types are left untouched."""
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("Conv1d") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm2d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm1d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        point_netG = torch.nn.DataParallel(point_netG)
        point_netD = torch.nn.DataParallel(point_netD)
        point_netG.to(device)
        point_netG.apply(weights_init_normal)
        point_netD.to(device)
        point_netD.apply(weights_init_normal)
    # Optionally resume generator/discriminator from checkpoints; the
    # checkpoint also stores the epoch to resume from.
    if opt.netG != '':
        point_netG.load_state_dict(
            torch.load(
                opt.netG,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netG)['epoch']
    if opt.netD != '':
        point_netD.load_state_dict(
            torch.load(
                opt.netD,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netD)['epoch']

    # Seed all RNGs for reproducibility.
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)


# def run():

# transforms = transforms.Compose([d_utils.PointcloudToTensor(),])
    # ShapeNet-Part train/test splits, each sample cropped to opt.pnum points.
    dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='train')
    assert dset

    dataloader = torch.utils.data.DataLoader(dset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    test_dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dset,
                                                  batch_size=opt.batchSize,
                                                  shuffle=True,
                                                  num_workers=int(opt.workers))

    #dset = ModelNet40Loader.ModelNet40Cls(opt.pnum, train=True, transforms=transforms, download = False)
    #assert dset
    #dataloader = torch.utils.data.DataLoader(dset, batch_size=opt.batchSize,
    #                                         shuffle=True,num_workers = int(opt.workers))
    #
    #
    #test_dset = ModelNet40Loader.ModelNet40Cls(opt.pnum, train=False, transforms=transforms, download = False)
    #test_dataloader = torch.utils.data.DataLoader(test_dset, batch_size=opt.batchSize,
    #                                         shuffle=True,num_workers = int(opt.workers))

    #pointcls_net.apply(weights_init)
    print(point_netG)
    print(point_netD)

    # Adversarial loss (real/fake logits) plus the point-set reconstruction
    # loss.
    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_PointLoss = PointLoss().to(device)

    # setup optimizer
    optimizerD = torch.optim.Adam(point_netD.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    optimizerG = torch.optim.Adam(point_netG.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    # Decay both learning rates by 5x every 40 epochs.
    schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD,
                                                 step_size=40,
                                                 gamma=0.2)
    schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG,
                                                 step_size=40,
                                                 gamma=0.2)

    real_label = 1
    fake_label = 0

    # Preallocated buffers reused for every batch.
    crop_point_num = int(opt.crop_point_num)
    input_cropped1 = torch.FloatTensor(opt.batchSize, opt.pnum, 3)
    label = torch.FloatTensor(opt.batchSize)

    num_batch = len(dset) / opt.batchSize

    ###########################
    #  G-NET and T-NET
    ##########################
    if opt.D_choose == 1:
        for epoch in range(resume_epoch, opt.niter):
            if epoch < 30:
                alpha1 = 0.01
                alpha2 = 0.02
            elif epoch < 80:
                alpha1 = 0.05
                alpha2 = 0.1
            else:
                alpha1 = 0.1
                alpha2 = 0.2

            if __name__ == '__main__':
                # On Windows calling this function is necessary.
                freeze_support()

            for i, data in enumerate(dataloader, 0):

                real_point, target = data

                batch_size = real_point.size()[0]
                real_center = torch.FloatTensor(batch_size, 1,
                                                opt.crop_point_num, 3)
                input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
                input_cropped1 = input_cropped1.data.copy_(real_point)
                real_point = torch.unsqueeze(real_point, 1)
                input_cropped1 = torch.unsqueeze(input_cropped1, 1)

                # np.savetxt('./pre_set' + '.txt', input_cropped1[0, 0, :, :], fmt="%f,%f,%f")

                p_origin = [0, 0, 0]
                if opt.cropmethod == 'random_center':
                    #Set viewpoints
                    choice = [
                        torch.Tensor([1, 0, 0]),
                        torch.Tensor([0, 0, 1]),
                        torch.Tensor([1, 0, 1]),
                        torch.Tensor([-1, 0, 0]),
                        torch.Tensor([-1, 1, 0])
                    ]
                    for m in range(batch_size):
                        index = random.sample(
                            choice, 1)  #Random choose one of the viewpoint
                        distance_list = []
                        p_center = index[0]
                        for n in range(opt.pnum):
                            distance_list.append(
                                distance_squre(real_point[m, 0, n], p_center))
                        distance_order = sorted(enumerate(distance_list),
                                                key=lambda x: x[1])

                        for sp in range(opt.crop_point_num):
                            input_cropped1.data[
                                m, 0,
                                distance_order[sp][0]] = torch.FloatTensor(
                                    [0, 0, 0])
                            real_center.data[m, 0, sp] = real_point[
                                m, 0, distance_order[sp][0]]

                # np.savetxt('./post_set' + '.txt', input_cropped1[0, 0, :, :], fmt="%f,%f,%f")

                label.resize_([batch_size, 1]).fill_(real_label)
                real_point = real_point.to(device)
                real_center = real_center.to(device)
                input_cropped1 = input_cropped1.to(device)
                label = label.to(device)
                ############################
                # (1) data prepare
                ###########################
                real_center = Variable(real_center, requires_grad=True)
                real_center = torch.squeeze(real_center, 1)
                real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                                   64,
                                                                   RAN=False)
                real_center_key1 = utils.index_points(real_center,
                                                      real_center_key1_idx)
                real_center_key1 = Variable(real_center_key1,
                                            requires_grad=True)

                real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                                   128,
                                                                   RAN=True)
                real_center_key2 = utils.index_points(real_center,
                                                      real_center_key2_idx)
                real_center_key2 = Variable(real_center_key2,
                                            requires_grad=True)

                input_cropped1 = torch.squeeze(input_cropped1, 1)
                input_cropped2_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[1], RAN=True)
                input_cropped2 = utils.index_points(input_cropped1,
                                                    input_cropped2_idx)
                input_cropped3_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[2], RAN=False)
                input_cropped3 = utils.index_points(input_cropped1,
                                                    input_cropped3_idx)
                input_cropped1 = Variable(input_cropped1, requires_grad=True)
                input_cropped2 = Variable(input_cropped2, requires_grad=True)
                input_cropped3 = Variable(input_cropped3, requires_grad=True)
                input_cropped2 = input_cropped2.to(device)
                input_cropped3 = input_cropped3.to(device)
                input_cropped = [
                    input_cropped1, input_cropped2, input_cropped3
                ]
                point_netG = point_netG.train()
                point_netD = point_netD.train()
                ############################
                # (2) Update D network
                ###########################
                point_netD.zero_grad()
                real_center = torch.unsqueeze(real_center, 1)
                output = point_netD(real_center)
                errD_real = criterion(output, label)
                errD_real.backward()

                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                label.data.fill_(fake_label)
                output = point_netD(fake.detach())
                errD_fake = criterion(output, label)
                # on(output, label)
                errD_fake.backward()

                errD = errD_real + errD_fake
                optimizerD.step()
                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################
                point_netG.zero_grad()
                label.data.fill_(real_label)
                output = point_netD(fake)
                errG_D = criterion(output, label)
                errG_l2 = 0

                # temp = torch.cat([input_cropped1[0,:,:], torch.squeeze(fake, 1)[0,:,:]], dim=0).cpu().detach().numpy()
                np.savetxt('./post_set' + '.txt',
                           torch.cat([
                               input_cropped1[0, :, :],
                               torch.squeeze(fake, 1)[0, :, :]
                           ],
                                     dim=0).cpu().detach().numpy(),
                           fmt="%f,%f,%f")

                CD_LOSS = criterion_PointLoss(torch.squeeze(fake, 1),
                                              torch.squeeze(real_center, 1))

                errG_l2 = criterion_PointLoss(torch.squeeze(fake,1),torch.squeeze(real_center,1))\
                +alpha1*criterion_PointLoss(fake_center1,real_center_key1)\
                +alpha2*criterion_PointLoss(fake_center2,real_center_key2)

                errG = (1 - opt.wtl2) * errG_D + opt.wtl2 * errG_l2
                errG.backward()
                optimizerG.step()

                np.savetxt('./fake' + '.txt',
                           torch.squeeze(fake,
                                         1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_A
                np.savetxt('./real_center' + '.txt',
                           torch.squeeze(real_center,
                                         1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_B

                print(
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f/ %.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))
                f = open('loss_PFNet.txt', 'a')
                f.write(
                    '\n' +
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f /%.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))

                if i % 10 == 0:
                    print('After, ', i, '-th batch')
                    f.write('\n' + 'After, ' + str(i) + '-th batch')
                    for i, data in enumerate(test_dataloader, 0):
                        real_point, target = data

                        batch_size = real_point.size()[0]
                        real_center = torch.FloatTensor(
                            batch_size, 1, opt.crop_point_num, 3)
                        input_cropped1 = torch.FloatTensor(
                            batch_size, opt.pnum, 3)
                        # Tail of the per-epoch evaluation pass (the test
                        # dataloader loop header is above this chunk).
                        # Copy the complete cloud, then zero out the crop
                        # region in the copy below.
                        input_cropped1 = input_cropped1.data.copy_(real_point)
                        real_point = torch.unsqueeze(real_point, 1)
                        input_cropped1 = torch.unsqueeze(input_cropped1, 1)

                        p_origin = [0, 0, 0]

                        if opt.cropmethod == 'random_center':
                            # Five fixed viewpoint vectors; one is drawn at
                            # random per sample as the crop center.
                            choice = [
                                torch.Tensor([1, 0, 0]),
                                torch.Tensor([0, 0, 1]),
                                torch.Tensor([1, 0, 1]),
                                torch.Tensor([-1, 0, 0]),
                                torch.Tensor([-1, 1, 0])
                            ]

                            for m in range(batch_size):
                                index = random.sample(choice, 1)
                                distance_list = []
                                p_center = index[0]
                                # distance_squre: external helper, presumably
                                # squared Euclidean distance — TODO confirm.
                                for n in range(opt.pnum):
                                    distance_list.append(
                                        distance_squre(real_point[m, 0, n],
                                                       p_center))
                                # Rank all points by distance to the center;
                                # the nearest crop_point_num points form the
                                # "hole".
                                distance_order = sorted(
                                    enumerate(distance_list),
                                    key=lambda x: x[1])
                                for sp in range(opt.crop_point_num):
                                    # Zero the point in the network input and
                                    # keep the original as ground truth.
                                    input_cropped1.data[
                                        m, 0, distance_order[sp]
                                        [0]] = torch.FloatTensor([0, 0, 0])
                                    real_center.data[m, 0, sp] = real_point[
                                        m, 0, distance_order[sp][0]]
                        real_center = real_center.to(device)
                        real_center = torch.squeeze(real_center, 1)
                        input_cropped1 = input_cropped1.to(device)
                        input_cropped1 = torch.squeeze(input_cropped1, 1)
                        # Build the three input resolutions expected by the
                        # generator (full cloud + two FPS-downsampled scales).
                        input_cropped2_idx = utils.farthest_point_sample(
                            input_cropped1, opt.point_scales_list[1], RAN=True)
                        input_cropped2 = utils.index_points(
                            input_cropped1, input_cropped2_idx)
                        input_cropped3_idx = utils.farthest_point_sample(
                            input_cropped1,
                            opt.point_scales_list[2],
                            RAN=False)
                        input_cropped3 = utils.index_points(
                            input_cropped1, input_cropped3_idx)
                        # Evaluation only: no gradients through the inputs.
                        input_cropped1 = Variable(input_cropped1,
                                                  requires_grad=False)
                        input_cropped2 = Variable(input_cropped2,
                                                  requires_grad=False)
                        input_cropped3 = Variable(input_cropped3,
                                                  requires_grad=False)
                        input_cropped2 = input_cropped2.to(device)
                        input_cropped3 = input_cropped3.to(device)
                        input_cropped = [
                            input_cropped1, input_cropped2, input_cropped3
                        ]
                        point_netG.eval()
                        fake_center1, fake_center2, fake = point_netG(
                            input_cropped)
                        # Chamfer-style loss between prediction and the
                        # cropped ground-truth patch.
                        CD_loss = criterion_PointLoss(
                            torch.squeeze(fake, 1),
                            torch.squeeze(real_center, 1))
                        print('test result:', CD_loss)
                        f.write('\n' + 'test result:  %.4f' % (CD_loss))
                        # Only the first test batch is evaluated per epoch.
                        break
                f.close()
            schedulerD.step()
            schedulerG.step()
            # Checkpoint both networks every 10 epochs.
            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netG.state_dict()
                    }, 'Trained_Model/point_netG' + str(epoch) + '.pth')
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netD.state_dict()
                    }, 'Trained_Model/point_netD' + str(epoch) + '.pth')

    #
    #############################
    ## ONLY G-NET
    ############################
    else:
        # G-network-only training (no discriminator loss).
        for epoch in range(resume_epoch, opt.niter):
            # Curriculum on the auxiliary losses: the weights of the two
            # intermediate (coarse) predictions ramp up over training.
            if epoch < 30:
                alpha1 = 0.01
                alpha2 = 0.02
            elif epoch < 80:
                alpha1 = 0.05
                alpha2 = 0.1
            else:
                alpha1 = 0.1
                alpha2 = 0.2
            for i, data in enumerate(dataloader, 0):
                # ---- (0) build the cropped input and ground truth --------
                real_point, target = data  # class label `target` is unused

                batch_size = real_point.size()[0]
                # Ground-truth patch that gets cropped out of each shape.
                real_center = torch.FloatTensor(batch_size, 1,
                                                opt.crop_point_num, 3)
                # Full-resolution copy whose crop region is zeroed in place.
                input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
                input_cropped1 = input_cropped1.data.copy_(real_point)
                real_point = torch.unsqueeze(real_point, 1)
                input_cropped1 = torch.unsqueeze(input_cropped1, 1)
                if opt.cropmethod == 'random_center':
                    # Five fixed viewpoints; one is drawn at random per
                    # sample as the crop center.
                    choice = [
                        torch.Tensor([1, 0, 0]),
                        torch.Tensor([0, 0, 1]),
                        torch.Tensor([1, 0, 1]),
                        torch.Tensor([-1, 0, 0]),
                        torch.Tensor([-1, 1, 0])
                    ]
                    for m in range(batch_size):
                        p_center = random.sample(choice, 1)[0]
                        # Rank points by (squared) distance to the center;
                        # the nearest crop_point_num points form the hole.
                        distance_list = [
                            distance_squre(real_point[m, 0, n], p_center)
                            for n in range(opt.pnum)
                        ]
                        distance_order = sorted(enumerate(distance_list),
                                                key=lambda x: x[1])
                        for sp in range(opt.crop_point_num):
                            nearest = distance_order[sp][0]
                            # Zero the point in the input, keep it as GT.
                            input_cropped1.data[m, 0, nearest] = \
                                torch.FloatTensor([0, 0, 0])
                            real_center.data[m, 0, sp] = \
                                real_point[m, 0, nearest]
                real_point = real_point.to(device)
                real_center = real_center.to(device)
                input_cropped1 = input_cropped1.to(device)
                ############################
                # (1) data prepare
                ###########################
                real_center = Variable(real_center, requires_grad=True)
                real_center = torch.squeeze(real_center, 1)
                # Coarse ground truths (64 / 128 points) for the pyramid
                # decoder's intermediate predictions.
                real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                                   64,
                                                                   RAN=False)
                real_center_key1 = utils.index_points(real_center,
                                                      real_center_key1_idx)
                real_center_key1 = Variable(real_center_key1,
                                            requires_grad=True)

                real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                                   128,
                                                                   RAN=True)
                real_center_key2 = utils.index_points(real_center,
                                                      real_center_key2_idx)
                real_center_key2 = Variable(real_center_key2,
                                            requires_grad=True)

                # Three input resolutions for the generator's encoders.
                input_cropped1 = torch.squeeze(input_cropped1, 1)
                input_cropped2_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[1], RAN=True)
                input_cropped2 = utils.index_points(input_cropped1,
                                                    input_cropped2_idx)
                input_cropped3_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[2], RAN=False)
                input_cropped3 = utils.index_points(input_cropped1,
                                                    input_cropped3_idx)
                input_cropped1 = Variable(input_cropped1, requires_grad=True)
                input_cropped2 = Variable(input_cropped2, requires_grad=True)
                input_cropped3 = Variable(input_cropped3, requires_grad=True)
                input_cropped2 = input_cropped2.to(device)
                input_cropped3 = input_cropped3.to(device)
                input_cropped = [
                    input_cropped1, input_cropped2, input_cropped3
                ]
                point_netG = point_netG.train()
                point_netG.zero_grad()
                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################

                CD_LOSS = criterion_PointLoss(torch.squeeze(fake, 1),
                                              torch.squeeze(real_center, 1))

                # Reuse CD_LOSS directly: the original recomputed the
                # identical Chamfer term a second time for errG_l2, doubling
                # the loss forward pass for the same numeric result.
                errG_l2 = CD_LOSS \
                    + alpha1 * criterion_PointLoss(fake_center1,
                                                   real_center_key1) \
                    + alpha2 * criterion_PointLoss(fake_center2,
                                                   real_center_key2)

                errG_l2.backward()
                optimizerG.step()
                print('[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                      (epoch, opt.niter, i, len(dataloader), errG_l2, CD_LOSS))
                # `with` guarantees the log file is closed even if a later
                # statement raises (original open/close pair could leak).
                with open('loss_PFNet.txt', 'a') as f:
                    f.write(
                        '\n' + '[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                        (epoch, opt.niter, i, len(dataloader), errG_l2,
                         CD_LOSS))
            # NOTE(review): schedulerD is stepped even though only G is
            # trained in this branch — presumably harmless, but confirm
            # schedulerD is defined when opt selects G-only training.
            schedulerD.step()
            schedulerG.step()

            # Checkpoint the generator every 10 epochs.
            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netG.state_dict()
                    }, 'Checkpoint/point_netG' + str(epoch) + '.pth')