real_center.data[m, 0, sp] = real_point[
                            m, 0, distance_order[sp][0]]
            label.resize_([batch_size, 1]).fill_(real_label)
            real_point = real_point.to(device)
            real_center = real_center.to(device)
            input_cropped1 = input_cropped1.to(device)
            label = label.to(device)
            ############################
            # (1) data prepare
            ###########################
            real_center = Variable(real_center, requires_grad=True)
            real_center = torch.squeeze(real_center, 1)
            real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                               64,
                                                               RAN=False)
            real_center_key1 = utils.index_points(real_center,
                                                  real_center_key1_idx)
            real_center_key1 = Variable(real_center_key1, requires_grad=True)

            real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                               128,
                                                               RAN=True)
            real_center_key2 = utils.index_points(real_center,
                                                  real_center_key2_idx)
            real_center_key2 = Variable(real_center_key2, requires_grad=True)

            input_cropped1 = torch.squeeze(input_cropped1, 1)
            input_cropped2_idx = utils.farthest_point_sample(
                input_cropped1, opt.point_scales_list[1], RAN=True)
            input_cropped2 = utils.index_points(input_cropped1,
                                                input_cropped2_idx)
            input_cropped3_idx = utils.farthest_point_sample(
예제 #2
0
        distance_list.append(distance_squre(real_point[0, 0, num], p_center))
    distance_order = sorted(enumerate(distance_list), key=lambda x: x[1])

    for sp in range(opt.crop_point_num):
        input_cropped1.data[0, 0, distance_order[sp][0]] = torch.FloatTensor(
            [0, 0, 0])
        real_center.data[0, 0, sp] = real_point[0, 0, distance_order[sp][0]]

    real_center.to(device)
    real_center = torch.squeeze(real_center, 1)

    input_cropped1 = torch.squeeze(input_cropped1, 1)
    input_cropped2_idx = utils.farthest_point_sample(input_cropped1,
                                                     opt.point_scales_list[1],
                                                     RAN=True)
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
    input_cropped3_idx = utils.farthest_point_sample(input_cropped1,
                                                     opt.point_scales_list[2],
                                                     RAN=False)
    input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
    input_cropped1 = input_cropped1.to(device)
    input_cropped2 = input_cropped2.to(device)
    input_cropped3 = input_cropped3.to(device)
    input_cropped = [input_cropped1, input_cropped2, input_cropped3]

    fake_center1, fake_center2, fake = point_netG(input_cropped)
    fake = fake.cuda()
    real_center = real_center.cuda()
    real_center = real_center.cuda()
    errG = criterion_PointLoss(
        torch.squeeze(fake, 1), torch.squeeze(real_center, 1)
예제 #3
0
# Run a pretrained PCN ``Autoencoder`` on one partial point cloud read from
# ``opt.infile`` and write the input and the completed ("fine") cloud to CSV.
import os  # local import: only used to make sure the output directory exists

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
PCN = Autoencoder(opt.num_inputs, opt.num_coarses, opt.num_fines,
                  opt.grid_size)
# The checkpoint is mapped onto the CPU and its 'state_dict' entry is loaded
# before the DataParallel wrap, so its keys presumably carry no 'module.'
# prefix -- TODO confirm against the training script.
PCN.load_state_dict(
    torch.load(opt.model,
               map_location=lambda storage, location: storage)['state_dict'])
print("Let's use", torch.cuda.device_count(), "GPUs!")
PCN.to(device)
PCN = torch.nn.DataParallel(PCN)
PCN.eval()

# ndmin=2 keeps a row-per-point layout even if the file holds a single point.
input_cropped1 = np.loadtxt(opt.infile, delimiter=',', ndmin=2)
input_cropped1 = torch.FloatTensor(input_cropped1)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)  # add a leading batch dim
input_cropped1 = input_cropped1.to(device)
# Down-sample the input to opt.num_inputs points via farthest point sampling.
input_key_cropped_index = farthest_point_sample(input_cropped1,
                                                opt.num_inputs,
                                                RAN=False)
input_key_cropped = index_points(input_cropped1,
                                 input_key_cropped_index)  #BX1024X3

# Pure inference: no autograd graph is needed for the forward pass.
with torch.no_grad():
    coarses, fine = PCN(input_key_cropped)
fine = fine.cpu()
np_fake = fine[0].detach().numpy()
input_cropped1 = input_cropped1.cpu()
np_crop = input_cropped1[0].numpy()

# np.savetxt does not create missing directories.
os.makedirs('test_from_prn_to_PCN', exist_ok=True)
np.savetxt('test_from_prn_to_PCN/crop_PCN' + '.csv', np_crop, fmt="%f,%f,%f")
np.savetxt('test_from_prn_to_PCN/fake_PCN' + '.csv', np_fake, fmt="%f,%f,%f")
예제 #4
0
    transforms = transforms.Compose(
        [
            d_utils.PointcloudToTensor(),
#            d_utils.PointcloudRotate(axis=np.array([1, 0, 0])),
#            d_utils.PointcloudScale(),
#            d_utils.PointcloudTranslate(),
#            d_utils.PointcloudJitter(),
        ]
    )
    dset = ModelNet40Cls(1024, train=True, transforms=transforms)   
    print(dset[0][1])
    print(len(dset))
    dloader = torch.utils.data.DataLoader(dset, batch_size=64, shuffle=True)

    for i, Data in enumerate(dloader, 0):
        real_point, target = Data
        print('1')
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        real_point = real_point.to(device)
        real_point2_idx = utils.farthest_point_sample(real_point,512)
        real_point2 = utils.index_points(real_point,real_point2_idx)
        real_point3_idx = utils.farthest_point_sample(real_point,256)
        real_point3 = utils.index_points(real_point,real_point3_idx)
        
#        model1 = real_point[0].numpy()
#        model2 = real_point2[0].numpy()
#        model3 = real_point3[0].numpy()
#        
#        np.savetxt('test-examples/model'+'.txt', model1, fmt = "%f %f %f")
#        np.savetxt('test-examples/model2'+'.txt', model2, fmt = "%f %f %f")
#        np.savetxt('test-examples/model3'+'.txt', model3, fmt = "%f %f %f")
예제 #5
0
        input_cropped1 = torch.unsqueeze(
            input_cropped1, 1)  # input_cropped1.shape = [24, 1, 2024, 3]
        p_origin = [0, 0, 0]
        label.resize_([batch_size, 1]).fill_(real_label)
        if real_point.size()[0] < opt.batchSize: continue
        real_point = real_point.to(
            device)  # real_point.shape = [24, 1, 2048, 3]
        real_center = real_center.to(
            device)  # real_center.shape = [24, 1, 512, 3]
        input_cropped1 = input_cropped1.to(
            device)  # input_cropped1.shape = [24, 1, 2048, 3]
        label = label.to(device)  # real label construction done

        # obtain data for the two channels
        real_center = torch.squeeze(real_center, 1)  # [24, 512, 3]
        real_center_key1 = utils.index_points(real_center,
                                              real_center_key1_idx)

        input_cropped1 = torch.squeeze(input_cropped1, 1)

        input_cropped = [input_cropped1, input_cropped2
                         ]  # make sure if inputs are 2048 and 512
        gen_net = gen_net.train()
        dis_net = dis_net.train()

        # update discriminator
        dis_net.zero_grad()
        real_center = torch.unsqueeze(real_center, 1)
        print('real center shape', real_center.shape)
        real_out = dis_net(real_center)
        #print('real label shape', label.shape)
        dis_err_real = criterion(real_out, label)
예제 #6
0
                                    key=lambda x: x[1])

            crop_num_list = []
            for num in range(opt.num_fines - opt.crop_point_num):
                crop_num_list.append(distance_order[num +
                                                    opt.crop_point_num][0])
            indices = torch.LongTensor(crop_num_list)
            input_cropped[m, 0] = torch.index_select(real_point[m, 0], 0,
                                                     indices)

        real_point = torch.squeeze(real_point, 1)  #BX2048X3
        input_cropped = torch.squeeze(input_cropped, 1)  #BX(2048-512)X3

        real_key_point_index = farthest_point_sample(real_point,
                                                     opt.num_coarses, False)
        real_key_point = index_points(real_point,
                                      real_key_point_index)  #BX1024X3

        input_key_cropped_index = farthest_point_sample(
            input_cropped, opt.num_inputs, False)
        input_key_cropped = index_points(input_cropped,
                                         input_key_cropped_index)  #BX1024X3

        real_point = Variable(real_point, requires_grad=True)
        real_key_point = Variable(real_key_point, requires_grad=True)
        input_key_cropped = Variable(input_key_cropped, requires_grad=True)

        real_point = real_point.to(device)
        real_key_point = real_key_point.to(device)
        input_key_cropped = input_key_cropped.to(device)
        optimizer.zero_grad()
예제 #7
0
            real_point = real_point.to(device)  # [B, 2048, 3]
            real_center = real_center.to(device)  # [B, 512, 3]
            input_cropped1 = input_cropped1.to(
                device
            )  # [B, 2048, 3]: PFNet model needs 2048 points to work..
            label = label.to(device)  # [32, 1]

            ###########################
            # (1) data prepare
            ###########################
            real_center = Variable(real_center, requires_grad=True)
            # real_center = torch.squeeze(real_center, 1)  # [32, 512, 3] - Alli: done before
            real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                               64,
                                                               RAN=False)
            real_center_key1 = utils.index_points(
                real_center, real_center_key1_idx)  # [32, 64, 3]
            real_center_key1 = Variable(real_center_key1, requires_grad=True)

            real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                               128,
                                                               RAN=True)
            real_center_key2 = utils.index_points(real_center,
                                                  real_center_key2_idx)
            real_center_key2 = Variable(real_center_key2, requires_grad=True)

            # input_cropped1 = torch.squeeze(input_cropped1, 1)  # [32, 2048, 3] - Alli: done before
            input_cropped2_idx = utils.farthest_point_sample(
                input_cropped1, opt.point_scales_list[1], RAN=True)
            input_cropped2 = utils.index_points(
                input_cropped1, input_cropped2_idx)  # [32, 1024, 3]
            input_cropped3_idx = utils.farthest_point_sample(
# Load the PF-Net generator checkpoint and run it once on the partial cloud
# read from ``opt.infile``.
point_netG.to(device)
point_netG.load_state_dict(
    torch.load(opt.netG,
               map_location=lambda storage, location: storage)['state_dict'])
point_netG.eval()

input_cropped1 = np.loadtxt(opt.infile, delimiter=',')
input_cropped1 = torch.FloatTensor(input_cropped1)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)  # add a leading batch dim
# Append 512 all-zero points to the loaded cloud (along the point axis).
Zeros = torch.zeros(1, 512, 3)
input_cropped1 = torch.cat((input_cropped1, Zeros), 1)

# Multi-resolution inputs obtained by farthest point sampling of scale 1.
input_cropped2_idx = utils.farthest_point_sample(input_cropped1,
                                                 1024,
                                                 RAN=True)
input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
input_cropped3_idx = utils.farthest_point_sample(input_cropped1,
                                                 512,
                                                 RAN=False)
input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
input_cropped4_idx = utils.farthest_point_sample(input_cropped1, 256, RAN=True)
input_cropped4 = utils.index_points(input_cropped1, input_cropped4_idx)
# NOTE(review): only scales 2 and 3 are moved to `device`; scale 1 stays on
# the CPU. This presumably relies on point_netG being wrapped in
# torch.nn.DataParallel (which scatters inputs to the GPUs) -- confirm.
input_cropped2 = input_cropped2.to(device)
input_cropped3 = input_cropped3.to(device)
input_cropped = [input_cropped1, input_cropped2, input_cropped3]

# Inference only: skip building an autograd graph for the generator pass.
with torch.no_grad():
    fake_center1, fake_center2, fake = point_netG(input_cropped)
fake = fake.cuda()
fake_center1 = fake_center1.cuda()
fake_center2 = fake_center2.cuda()
예제 #9
0
print(opt)

# collate all car files into car_csv
#car_ids = [filename.split('.')[0] for filename in glob.glob('test_files/*.pcd')]
car_ids = ['test_files/car']
print(car_ids)
total_points = 0
for i, car_id in enumerate(car_ids):
    #test_npy = np.load(os.path.join('test_files/', car_id), allow_pickle = True)
    print(car_ids)
    #input_cropped1 = np.loadtxt(os.path.join('car_csv/', car_id.split('.')[0].split('/')[1]+'.csv'),delimiter=',')
    # ``partial`` is defined outside this snippet; presumably a batched point
    # cloud tensor accepted by gen_net -- TODO confirm shape/device.
    input_cropped1 = partial
    input_cropped2_idx = utils.farthest_point_sample(input_cropped1,
                                                     opt.point_scales_list[1],
                                                     RAN=True)
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)

    # Inference only: no autograd graph is needed for the generator pass.
    with torch.no_grad():
        fake_center1, fake_fine = gen_net(input_cropped1)

    fake_fine = fake_fine.cuda()
    fake_center1 = fake_center1.cuda()

    input_cropped2 = input_cropped2.cpu()

    # NOTE(review): ``real`` is re-bound to an unsqueezed copy on every
    # iteration, so with more than one entry in ``car_ids`` it gains another
    # leading dimension each time round -- revisit before adding more ids.
    real = torch.unsqueeze(real, 0)
    real2_idx = utils.farthest_point_sample(real, 128, RAN=False)
    real2 = utils.index_points(real, real2_idx)

    real2 = real2.cpu()

    np_real2 = real2[0].detach().numpy()
예제 #10
0
def run():
    print(opt)

    blue = lambda x: '\033[94m' + x + '\033[0m'
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    USE_CUDA = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    point_netG = _netG(opt.num_scales, opt.each_scales_size,
                       opt.point_scales_list, opt.crop_point_num)
    point_netD = _netlocalD(opt.crop_point_num)
    cudnn.benchmark = True
    resume_epoch = 0

    def weights_init_normal(m):
        """Initialise conv / batch-norm parameters of module ``m`` in place.

        Meant to be passed to ``nn.Module.apply``. Dispatch is by substring of
        the module's class name:

        * contains ``Conv2d`` or ``Conv1d``: weight ~ N(0, 0.02);
        * contains ``BatchNorm2d`` or ``BatchNorm1d``: weight ~ N(1, 0.02)
          and bias set to 0.

        Any other module is left untouched. Parameters that are ``None``
        (e.g. batch norm built with ``affine=False``) are skipped rather than
        raising ``AttributeError``.
        """
        classname = m.__class__.__name__
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        if "Conv2d" in classname or "Conv1d" in classname:
            if weight is not None:
                torch.nn.init.normal_(weight.data, 0.0, 0.02)
        elif "BatchNorm2d" in classname or "BatchNorm1d" in classname:
            if weight is not None:
                torch.nn.init.normal_(weight.data, 1.0, 0.02)
            if bias is not None:
                torch.nn.init.constant_(bias.data, 0.0)

    if USE_CUDA:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        point_netG = torch.nn.DataParallel(point_netG)
        point_netD = torch.nn.DataParallel(point_netD)
        point_netG.to(device)
        point_netG.apply(weights_init_normal)
        point_netD.to(device)
        point_netD.apply(weights_init_normal)
    if opt.netG != '':
        point_netG.load_state_dict(
            torch.load(
                opt.netG,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netG)['epoch']
    if opt.netD != '':
        point_netD.load_state_dict(
            torch.load(
                opt.netD,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netD)['epoch']

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)


# def run():

# transforms = transforms.Compose([d_utils.PointcloudToTensor(),])
    dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='train')
    assert dset

    dataloader = torch.utils.data.DataLoader(dset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    test_dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dset,
                                                  batch_size=opt.batchSize,
                                                  shuffle=True,
                                                  num_workers=int(opt.workers))

    #dset = ModelNet40Loader.ModelNet40Cls(opt.pnum, train=True, transforms=transforms, download = False)
    #assert dset
    #dataloader = torch.utils.data.DataLoader(dset, batch_size=opt.batchSize,
    #                                         shuffle=True,num_workers = int(opt.workers))
    #
    #
    #test_dset = ModelNet40Loader.ModelNet40Cls(opt.pnum, train=False, transforms=transforms, download = False)
    #test_dataloader = torch.utils.data.DataLoader(test_dset, batch_size=opt.batchSize,
    #                                         shuffle=True,num_workers = int(opt.workers))

    #pointcls_net.apply(weights_init)
    print(point_netG)
    print(point_netD)

    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_PointLoss = PointLoss().to(device)

    # setup optimizer
    optimizerD = torch.optim.Adam(point_netD.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    optimizerG = torch.optim.Adam(point_netG.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD,
                                                 step_size=40,
                                                 gamma=0.2)
    schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG,
                                                 step_size=40,
                                                 gamma=0.2)

    real_label = 1
    fake_label = 0

    crop_point_num = int(opt.crop_point_num)
    input_cropped1 = torch.FloatTensor(opt.batchSize, opt.pnum, 3)
    label = torch.FloatTensor(opt.batchSize)

    num_batch = len(dset) / opt.batchSize

    ###########################
    #  G-NET and T-NET
    ##########################djel
    if opt.D_choose == 1:
        for epoch in range(resume_epoch, opt.niter):
            if epoch < 30:
                alpha1 = 0.01
                alpha2 = 0.02
            elif epoch < 80:
                alpha1 = 0.05
                alpha2 = 0.1
            else:
                alpha1 = 0.1
                alpha2 = 0.2

            if __name__ == '__main__':
                # On Windows calling this function is necessary.
                freeze_support()

            for i, data in enumerate(dataloader, 0):

                real_point, target = data

                batch_size = real_point.size()[0]
                real_center = torch.FloatTensor(batch_size, 1,
                                                opt.crop_point_num, 3)
                input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
                input_cropped1 = input_cropped1.data.copy_(real_point)
                real_point = torch.unsqueeze(real_point, 1)
                input_cropped1 = torch.unsqueeze(input_cropped1, 1)

                # np.savetxt('./pre_set' + '.txt', input_cropped1[0, 0, :, :], fmt="%f,%f,%f")

                p_origin = [0, 0, 0]
                if opt.cropmethod == 'random_center':
                    #Set viewpoints
                    choice = [
                        torch.Tensor([1, 0, 0]),
                        torch.Tensor([0, 0, 1]),
                        torch.Tensor([1, 0, 1]),
                        torch.Tensor([-1, 0, 0]),
                        torch.Tensor([-1, 1, 0])
                    ]
                    for m in range(batch_size):
                        index = random.sample(
                            choice, 1)  #Random choose one of the viewpoint
                        distance_list = []
                        p_center = index[0]
                        for n in range(opt.pnum):
                            distance_list.append(
                                distance_squre(real_point[m, 0, n], p_center))
                        distance_order = sorted(enumerate(distance_list),
                                                key=lambda x: x[1])

                        for sp in range(opt.crop_point_num):
                            input_cropped1.data[
                                m, 0,
                                distance_order[sp][0]] = torch.FloatTensor(
                                    [0, 0, 0])
                            real_center.data[m, 0, sp] = real_point[
                                m, 0, distance_order[sp][0]]

                # np.savetxt('./post_set' + '.txt', input_cropped1[0, 0, :, :], fmt="%f,%f,%f")

                label.resize_([batch_size, 1]).fill_(real_label)
                real_point = real_point.to(device)
                real_center = real_center.to(device)
                input_cropped1 = input_cropped1.to(device)
                label = label.to(device)
                ############################
                # (1) data prepare
                ###########################
                real_center = Variable(real_center, requires_grad=True)
                real_center = torch.squeeze(real_center, 1)
                real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                                   64,
                                                                   RAN=False)
                real_center_key1 = utils.index_points(real_center,
                                                      real_center_key1_idx)
                real_center_key1 = Variable(real_center_key1,
                                            requires_grad=True)

                real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                                   128,
                                                                   RAN=True)
                real_center_key2 = utils.index_points(real_center,
                                                      real_center_key2_idx)
                real_center_key2 = Variable(real_center_key2,
                                            requires_grad=True)

                input_cropped1 = torch.squeeze(input_cropped1, 1)
                input_cropped2_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[1], RAN=True)
                input_cropped2 = utils.index_points(input_cropped1,
                                                    input_cropped2_idx)
                input_cropped3_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[2], RAN=False)
                input_cropped3 = utils.index_points(input_cropped1,
                                                    input_cropped3_idx)
                input_cropped1 = Variable(input_cropped1, requires_grad=True)
                input_cropped2 = Variable(input_cropped2, requires_grad=True)
                input_cropped3 = Variable(input_cropped3, requires_grad=True)
                input_cropped2 = input_cropped2.to(device)
                input_cropped3 = input_cropped3.to(device)
                input_cropped = [
                    input_cropped1, input_cropped2, input_cropped3
                ]
                point_netG = point_netG.train()
                point_netD = point_netD.train()
                ############################
                # (2) Update D network
                ###########################
                point_netD.zero_grad()
                real_center = torch.unsqueeze(real_center, 1)
                output = point_netD(real_center)
                errD_real = criterion(output, label)
                errD_real.backward()

                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                label.data.fill_(fake_label)
                output = point_netD(fake.detach())
                errD_fake = criterion(output, label)
                # on(output, label)
                errD_fake.backward()

                errD = errD_real + errD_fake
                optimizerD.step()
                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################
                point_netG.zero_grad()
                label.data.fill_(real_label)
                output = point_netD(fake)
                errG_D = criterion(output, label)
                errG_l2 = 0

                # temp = torch.cat([input_cropped1[0,:,:], torch.squeeze(fake, 1)[0,:,:]], dim=0).cpu().detach().numpy()
                np.savetxt('./post_set' + '.txt',
                           torch.cat([
                               input_cropped1[0, :, :],
                               torch.squeeze(fake, 1)[0, :, :]
                           ],
                                     dim=0).cpu().detach().numpy(),
                           fmt="%f,%f,%f")

                CD_LOSS = criterion_PointLoss(torch.squeeze(fake, 1),
                                              torch.squeeze(real_center, 1))

                errG_l2 = criterion_PointLoss(torch.squeeze(fake,1),torch.squeeze(real_center,1))\
                +alpha1*criterion_PointLoss(fake_center1,real_center_key1)\
                +alpha2*criterion_PointLoss(fake_center2,real_center_key2)

                errG = (1 - opt.wtl2) * errG_D + opt.wtl2 * errG_l2
                errG.backward()
                optimizerG.step()

                np.savetxt('./fake' + '.txt',
                           torch.squeeze(fake,
                                         1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_A
                np.savetxt('./real_center' + '.txt',
                           torch.squeeze(real_center,
                                         1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_B

                print(
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f/ %.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))
                f = open('loss_PFNet.txt', 'a')
                f.write(
                    '\n' +
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f /%.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))

                if i % 10 == 0:
                    print('After, ', i, '-th batch')
                    f.write('\n' + 'After, ' + str(i) + '-th batch')
                    for i, data in enumerate(test_dataloader, 0):
                        real_point, target = data

                        batch_size = real_point.size()[0]
                        real_center = torch.FloatTensor(
                            batch_size, 1, opt.crop_point_num, 3)
                        input_cropped1 = torch.FloatTensor(
                            batch_size, opt.pnum, 3)
                        input_cropped1 = input_cropped1.data.copy_(real_point)
                        real_point = torch.unsqueeze(real_point, 1)
                        input_cropped1 = torch.unsqueeze(input_cropped1, 1)

                        p_origin = [0, 0, 0]

                        if opt.cropmethod == 'random_center':
                            choice = [
                                torch.Tensor([1, 0, 0]),
                                torch.Tensor([0, 0, 1]),
                                torch.Tensor([1, 0, 1]),
                                torch.Tensor([-1, 0, 0]),
                                torch.Tensor([-1, 1, 0])
                            ]

                            for m in range(batch_size):
                                index = random.sample(choice, 1)
                                distance_list = []
                                p_center = index[0]
                                for n in range(opt.pnum):
                                    distance_list.append(
                                        distance_squre(real_point[m, 0, n],
                                                       p_center))
                                distance_order = sorted(
                                    enumerate(distance_list),
                                    key=lambda x: x[1])
                                for sp in range(opt.crop_point_num):
                                    input_cropped1.data[
                                        m, 0, distance_order[sp]
                                        [0]] = torch.FloatTensor([0, 0, 0])
                                    real_center.data[m, 0, sp] = real_point[
                                        m, 0, distance_order[sp][0]]
                        real_center = real_center.to(device)
                        real_center = torch.squeeze(real_center, 1)
                        input_cropped1 = input_cropped1.to(device)
                        input_cropped1 = torch.squeeze(input_cropped1, 1)
                        input_cropped2_idx = utils.farthest_point_sample(
                            input_cropped1, opt.point_scales_list[1], RAN=True)
                        input_cropped2 = utils.index_points(
                            input_cropped1, input_cropped2_idx)
                        input_cropped3_idx = utils.farthest_point_sample(
                            input_cropped1,
                            opt.point_scales_list[2],
                            RAN=False)
                        input_cropped3 = utils.index_points(
                            input_cropped1, input_cropped3_idx)
                        input_cropped1 = Variable(input_cropped1,
                                                  requires_grad=False)
                        input_cropped2 = Variable(input_cropped2,
                                                  requires_grad=False)
                        input_cropped3 = Variable(input_cropped3,
                                                  requires_grad=False)
                        input_cropped2 = input_cropped2.to(device)
                        input_cropped3 = input_cropped3.to(device)
                        input_cropped = [
                            input_cropped1, input_cropped2, input_cropped3
                        ]
                        point_netG.eval()
                        fake_center1, fake_center2, fake = point_netG(
                            input_cropped)
                        CD_loss = criterion_PointLoss(
                            torch.squeeze(fake, 1),
                            torch.squeeze(real_center, 1))
                        print('test result:', CD_loss)
                        f.write('\n' + 'test result:  %.4f' % (CD_loss))
                        break
                f.close()
            # End of epoch: advance the learning-rate schedules of D and G.
            schedulerD.step()
            schedulerG.step()
            # Every 10th epoch, checkpoint both networks. The stored 'epoch' is
            # epoch + 1 (presumably the epoch to resume from — confirm against
            # the loading code, which is not visible here).
            if not epoch % 10:
                torch.save({'epoch': epoch + 1,
                            'state_dict': point_netG.state_dict()},
                           'Trained_Model/point_netG' + str(epoch) + '.pth')
                torch.save({'epoch': epoch + 1,
                            'state_dict': point_netD.state_dict()},
                           'Trained_Model/point_netD' + str(epoch) + '.pth')

    #
    #############################
    ## ONLY G-NET: train point_netG alone on the point loss (no discriminator)
    ############################
    else:
        for epoch in range(resume_epoch, opt.niter):
            # Weights for the two coarse-scale loss terms (fake_center1 vs the
            # 64-point key set, fake_center2 vs the 128-point key set); they
            # step up at epochs 30 and 80.
            if epoch >= 80:
                alpha1, alpha2 = 0.1, 0.2
            elif epoch >= 30:
                alpha1, alpha2 = 0.05, 0.1
            else:
                alpha1, alpha2 = 0.01, 0.02

            for i, data in enumerate(dataloader, 0):
                # One generator-only optimisation step per batch:
                # 1. crop a region out of each cloud;
                # 2. predict the missing points with point_netG;
                # 3. minimise the weighted multi-scale point loss against the
                #    cropped-out ground truth.

                real_point, target = data

                batch_size = real_point.size()[0]
                # real_center collects the crop_point_num removed points (the
                # ground truth). input_cropped1 is the cloud with those points
                # set to zero.
                real_center = torch.FloatTensor(batch_size, 1,
                                                opt.crop_point_num, 3)
                input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
                input_cropped1 = input_cropped1.data.copy_(real_point)
                real_point = torch.unsqueeze(real_point, 1)
                input_cropped1 = torch.unsqueeze(input_cropped1, 1)
                # NOTE(review): only 'random_center' fills real_center. For any
                # other opt.cropmethod it stays uninitialized memory — confirm
                # that no other crop method is ever selected.
                if opt.cropmethod == 'random_center':
                    # Candidate crop centres. One is drawn per sample, and the
                    # crop_point_num points nearest to it are cropped.
                    choice = [
                        torch.Tensor([1, 0, 0]),
                        torch.Tensor([0, 0, 1]),
                        torch.Tensor([1, 0, 1]),
                        torch.Tensor([-1, 0, 0]),
                        torch.Tensor([-1, 1, 0])
                    ]
                    for m in range(batch_size):
                        index = random.sample(choice, 1)
                        p_center = index[0]
                        distance_list = [
                            distance_squre(real_point[m, 0, n], p_center)
                            for n in range(opt.pnum)
                        ]
                        # (point index, distance) pairs, nearest first.
                        distance_order = sorted(enumerate(distance_list),
                                                key=lambda x: x[1])

                        for sp in range(opt.crop_point_num):
                            input_cropped1.data[
                                m, 0,
                                distance_order[sp][0]] = torch.FloatTensor(
                                    [0, 0, 0])
                            real_center.data[m, 0, sp] = real_point[
                                m, 0, distance_order[sp][0]]
                real_point = real_point.to(device)
                real_center = real_center.to(device)
                input_cropped1 = input_cropped1.to(device)
                ############################
                # (1) data prepare
                ###########################
                real_center = Variable(real_center, requires_grad=True)
                real_center = torch.squeeze(real_center, 1)
                # Coarse ground-truth key sets (64 and 128 points) used as
                # targets for the generator's intermediate outputs.
                real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                                   64,
                                                                   RAN=False)
                real_center_key1 = utils.index_points(real_center,
                                                      real_center_key1_idx)
                real_center_key1 = Variable(real_center_key1,
                                            requires_grad=True)

                real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                                   128,
                                                                   RAN=True)
                real_center_key2 = utils.index_points(real_center,
                                                      real_center_key2_idx)
                real_center_key2 = Variable(real_center_key2,
                                            requires_grad=True)

                # Multi-resolution inputs: the full cropped cloud, plus two
                # farthest-point-sampled subsets sized by opt.point_scales_list.
                input_cropped1 = torch.squeeze(input_cropped1, 1)
                input_cropped2_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[1], RAN=True)
                input_cropped2 = utils.index_points(input_cropped1,
                                                    input_cropped2_idx)
                input_cropped3_idx = utils.farthest_point_sample(
                    input_cropped1, opt.point_scales_list[2], RAN=False)
                input_cropped3 = utils.index_points(input_cropped1,
                                                    input_cropped3_idx)
                # NOTE(review): gradients w.r.t. these inputs are never read
                # here. requires_grad=False (as in the test path) would likely
                # save memory — confirm before changing.
                input_cropped1 = Variable(input_cropped1, requires_grad=True)
                input_cropped2 = Variable(input_cropped2, requires_grad=True)
                input_cropped3 = Variable(input_cropped3, requires_grad=True)
                input_cropped2 = input_cropped2.to(device)
                input_cropped3 = input_cropped3.to(device)
                input_cropped = [
                    input_cropped1, input_cropped2, input_cropped3
                ]
                point_netG = point_netG.train()
                point_netG.zero_grad()
                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                ############################
                # (2) Update G network: minimise the weighted multi-scale
                #     point loss (no discriminator in this branch)
                ###########################

                CD_LOSS = criterion_PointLoss(torch.squeeze(fake, 1),
                                              torch.squeeze(real_center, 1))

                # Reuse the full-resolution term rather than evaluating the
                # same criterion on the same inputs a second time.
                errG_l2 = (CD_LOSS
                           + alpha1 * criterion_PointLoss(fake_center1,
                                                          real_center_key1)
                           + alpha2 * criterion_PointLoss(fake_center2,
                                                          real_center_key2))

                errG_l2.backward()
                optimizerG.step()
                # Format once; converting the loss tensors to float can force a
                # device sync.
                log_line = '[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' % (
                    epoch, opt.niter, i, len(dataloader), errG_l2, CD_LOSS)
                print(log_line)
                # Append per batch. 'with' closes the file even if write fails.
                with open('loss_PFNet.txt', 'a') as f:
                    f.write('\n' + log_line)
            # Advance both learning-rate schedules once per epoch.
            # NOTE(review): point_netD is not optimised in this generator-only
            # branch, so stepping schedulerD looks unnecessary (possibly kept
            # for parity with the adversarial branch) — confirm.
            schedulerD.step()
            schedulerG.step()

            # Every 10th epoch, save the generator weights. The stored 'epoch'
            # is epoch + 1 (presumably the epoch to resume from — confirm with
            # the loading code). Unlike the adversarial branch, which writes to
            # 'Trained_Model/', this branch writes to 'Checkpoint/'. torch.save
            # fails if that directory does not exist.
            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netG.state_dict()
                    }, 'Checkpoint/point_netG' + str(epoch) + '.pth')