Esempio n. 1
0
    p_origin = [0, 0, 0]

    choice = [
        torch.Tensor([1, 0, 0]),
        torch.Tensor([0, 0, 1]),
        torch.Tensor([1, 0, 1]),
        torch.Tensor([-1, 0, 0]),
        torch.Tensor([-1, 1, 0])
    ]
    index = random.sample(choice, 1)
    distance_list = []
    p_center = index[0]
    center_id = 9
    # p_center = choice[center_id]
    for num in range(opt.pnum):
        distance_list.append(distance_squre(real_point[0, 0, num], p_center))
    distance_order = sorted(enumerate(distance_list), key=lambda x: x[1])

    for sp in range(opt.crop_point_num):
        input_cropped1.data[0, 0, distance_order[sp][0]] = torch.FloatTensor(
            [0, 0, 0])
        real_center.data[0, 0, sp] = real_point[0, 0, distance_order[sp][0]]

    real_center.to(device)
    real_center = torch.squeeze(real_center, 1)

    input_cropped1 = torch.squeeze(input_cropped1, 1)
    input_cropped2_idx = utils.farthest_point_sample(input_cropped1,
                                                     opt.point_scales_list[1],
                                                     RAN=True)
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
Esempio n. 2
0
        torch.Tensor([0, 0, 1]),
        torch.Tensor([1, 0, 1]),
        torch.Tensor([-1, 0, 0]),
        torch.Tensor([-1, 1, 0])
    ]
    #    points = [x for x in range(0,opt.pnum-1)]
    #    choice =random.sample(points,5)
    index = choice[IDX - 1]  #random.sample(choice,1)
    IDX = IDX + 1
    if IDX % 5 == 0:
        IDX = 0
    distance_list = []
    #    p_center  = real_point[0,0,index]
    p_center = index
    for num in range(opt.num_points):
        distance_list.append(distance_squre(real_point[0, 0, num], p_center))
    distance_order = sorted(enumerate(distance_list), key=lambda x: x[1])

    for sp in range(opt.crop_point_num):
        input_cropped_ours.data[0, 0,
                                distance_order[sp][0]] = torch.FloatTensor(
                                    [0, 0, 0])
        real_center.data[0, 0, sp] = real_point[0, 0, distance_order[sp][0]]
    real_center = torch.squeeze(real_center, 1)
    crop_num_list = []
    for num in range(opt.num_points - opt.crop_point_num):
        crop_num_list.append(distance_order[num + opt.crop_point_num][0])
    indices = torch.LongTensor(crop_num_list)
    input_cropped[0, 0] = torch.index_select(real_point[0, 0], 0, indices)

    real_point = torch.squeeze(real_point, 1)
                #Set viewpoints
                choice = [
                    torch.Tensor([1, 0, 0]),
                    torch.Tensor([0, 0, 1]),
                    torch.Tensor([1, 0, 1]),
                    torch.Tensor([-1, 0, 0]),
                    torch.Tensor([-1, 1, 0])
                ]
                for m in range(batch_size):
                    index = random.sample(
                        choice, 1)  #Random choose one of the viewpoint
                    distance_list = []
                    p_center = index[0]
                    for n in range(opt.pnum):
                        distance_list.append(
                            distance_squre(real_point[m, 0, n], p_center))
                    distance_order = sorted(enumerate(distance_list),
                                            key=lambda x: x[1])

                    for sp in range(opt.crop_point_num):
                        input_cropped1.data[
                            m, 0, distance_order[sp][0]] = torch.FloatTensor(
                                [0, 0, 0])
                        real_center.data[m, 0, sp] = real_point[
                            m, 0, distance_order[sp][0]]
            label.resize_([batch_size, 1]).fill_(real_label)
            real_point = real_point.to(device)
            real_center = real_center.to(device)
            input_cropped1 = input_cropped1.to(device)
            label = label.to(device)
            ############################
Esempio n. 4
0
    for data in train_loader():
        points, label = data
        batch_size = points.shape[0]

        real_point = points.numpy()
        real_center = np.zeros(
            (batch_size, opt.crop_point_num, 3)).astype('float32')
        cropped_point = copy.deepcopy(real_point)

        for m in range(batch_size):
            index = random.sample(crop_choice, 1)
            distance_list = []
            p_center = index[0]

            for n in range(opt.pnum):
                distance_list.append(distance_squre(real_point[m, n],
                                                    p_center))
            distance_order = sorted(enumerate(distance_list),
                                    key=lambda x: x[1])

            for sp in range(opt.crop_point_num):
                cropped_point[m, distance_order[sp][0]] = np.array([0, 0, 0])
                real_center[m, sp] = real_point[m, distance_order[sp][0]]

        cropped_point1_idx = utils.farthest_point_sample_numpy(
            cropped_point, opt.point_scales_list[1], RAN=True)
        cropped_point1 = utils.index_points_numpy(cropped_point,
                                                  cropped_point1_idx)

        cropped_point2_idx = utils.farthest_point_sample_numpy(
            cropped_point, opt.point_scales_list[2], RAN=False)
        cropped_point2 = utils.index_points_numpy(cropped_point,
Esempio n. 5
0
def _weights_init_normal(m):
    """Initialise conv weights from N(0, 0.02) and batch-norm weights from N(1, 0.02).

    Meant to be passed to ``Module.apply``; modules whose class name does not
    contain one of the four handled names are left untouched.
    """
    classname = m.__class__.__name__
    if "Conv2d" in classname or "Conv1d" in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname or "BatchNorm1d" in classname:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def _load_checkpoint(net, path):
    """Load the ``'state_dict'`` entry of *path* into *net*; return its ``'epoch'``.

    The file is read once and mapped to CPU storage, so a checkpoint written
    on a GPU machine can also be resumed on a CPU-only one.
    """
    checkpoint = torch.load(path,
                            map_location=lambda storage, location: storage)
    net.load_state_dict(checkpoint['state_dict'])
    return checkpoint['epoch']


def _save_checkpoint(net, epoch, directory, prefix):
    """Save *net* as ``<directory>/<prefix><epoch>.pth``, creating the directory."""
    os.makedirs(directory, exist_ok=True)
    torch.save({
        'epoch': epoch + 1,
        'state_dict': net.state_dict()
    }, directory + '/' + prefix + str(epoch) + '.pth')


def _alpha_schedule(epoch):
    """Return the ``(alpha1, alpha2)`` weights of the two auxiliary point losses."""
    if epoch < 30:
        return 0.01, 0.02
    if epoch < 80:
        return 0.05, 0.1
    return 0.1, 0.2


def _make_cropped_batch(real_point, viewpoints):
    """Build the cropped input and the ground-truth missing region for a batch.

    Args:
        real_point: batch of complete clouds as yielded by the data loader;
            it is copied into a ``(batch, opt.pnum, 3)`` float tensor.
        viewpoints: list of 3-element tensors to pick the crop centre from.

    Returns:
        ``(real_point, real_center, input_cropped1)`` where ``real_point`` and
        ``input_cropped1`` have a singleton axis inserted at dim 1 and
        ``real_center`` is ``(batch, 1, opt.crop_point_num, 3)``.

    When ``opt.cropmethod`` is not ``'random_center'`` nothing is cropped and
    ``real_center`` stays all-zero (instead of uninitialised memory).
    """
    batch_size = real_point.size()[0]
    real_center = torch.zeros(batch_size, 1, opt.crop_point_num, 3,
                              dtype=torch.float32)
    input_cropped1 = torch.FloatTensor(batch_size, opt.pnum,
                                       3).copy_(real_point)
    real_point = torch.unsqueeze(real_point, 1)
    input_cropped1 = torch.unsqueeze(input_cropped1, 1)

    if opt.cropmethod == 'random_center':
        for m in range(batch_size):
            # random.sample (not random.choice) is kept so the Python RNG
            # stream is consumed exactly as before.
            p_center = random.sample(viewpoints, 1)[0]
            distance_list = [
                distance_squre(real_point[m, 0, n], p_center)
                for n in range(opt.pnum)
            ]
            distance_order = sorted(enumerate(distance_list),
                                    key=lambda x: x[1])
            # The crop_point_num points closest to the viewpoint are blanked
            # in the input and collected, nearest first, as the target.
            for sp in range(opt.crop_point_num):
                point_idx = distance_order[sp][0]
                input_cropped1.data[m, 0, point_idx] = torch.FloatTensor(
                    [0, 0, 0])
                real_center.data[m, 0, sp] = real_point[m, 0, point_idx]

    return real_point, real_center, input_cropped1


def _multi_scale_inputs(input_cropped1, device, requires_grad):
    """Return ``[full, scale1, scale2]`` generator inputs for a cropped batch.

    *input_cropped1* must already be on *device* and still carry the
    singleton axis at dim 1. The two coarser clouds are farthest-point
    samples of sizes ``opt.point_scales_list[1]`` and ``[2]``.
    """
    input_cropped1 = torch.squeeze(input_cropped1, 1)
    input_cropped2_idx = utils.farthest_point_sample(
        input_cropped1, opt.point_scales_list[1], RAN=True)
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
    input_cropped3_idx = utils.farthest_point_sample(
        input_cropped1, opt.point_scales_list[2], RAN=False)
    input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
    input_cropped1 = Variable(input_cropped1, requires_grad=requires_grad)
    input_cropped2 = Variable(input_cropped2,
                              requires_grad=requires_grad).to(device)
    input_cropped3 = Variable(input_cropped3,
                              requires_grad=requires_grad).to(device)
    return [input_cropped1, input_cropped2, input_cropped3]


def _prepare_training_batch(real_point, viewpoints, device):
    """Crop one training batch and derive every tensor the losses need.

    Returns:
        ``(real_center, real_center_key1, real_center_key2, input_cropped)``:
        the missing region with the singleton axis squeezed out, its 64- and
        128-point farthest-point samples, and the list of generator inputs.
    """
    _, real_center, input_cropped1 = _make_cropped_batch(
        real_point, viewpoints)
    real_center = real_center.to(device)
    input_cropped1 = input_cropped1.to(device)

    real_center = Variable(real_center, requires_grad=True)
    real_center = torch.squeeze(real_center, 1)
    # Sampling calls stay in this order (key1, key2, then the input scales)
    # so any random draws inside the RAN=True calls happen in the same
    # sequence as before.
    real_center_key1_idx = utils.farthest_point_sample(real_center,
                                                       64,
                                                       RAN=False)
    real_center_key1 = Variable(utils.index_points(real_center,
                                                   real_center_key1_idx),
                                requires_grad=True)
    real_center_key2_idx = utils.farthest_point_sample(real_center,
                                                       128,
                                                       RAN=True)
    real_center_key2 = Variable(utils.index_points(real_center,
                                                   real_center_key2_idx),
                                requires_grad=True)

    input_cropped = _multi_scale_inputs(input_cropped1,
                                        device,
                                        requires_grad=True)
    return real_center, real_center_key1, real_center_key2, input_cropped


def _evaluate_first_test_batch(point_netG, test_dataloader,
                               criterion_PointLoss, viewpoints, device,
                               log_file):
    """Print and log the point loss of the generator on one test batch.

    Only the first batch yielded by *test_dataloader* is used. The generator
    is switched to eval mode and left there; the training loop puts it back
    into train mode at the start of its next iteration.
    """
    for test_point, _target in test_dataloader:
        _, real_center, input_cropped1 = _make_cropped_batch(
            test_point, viewpoints)
        real_center = torch.squeeze(real_center.to(device), 1)
        input_cropped1 = input_cropped1.to(device)
        input_cropped = _multi_scale_inputs(input_cropped1,
                                            device,
                                            requires_grad=False)
        point_netG.eval()
        # No gradients are needed for evaluation; skip graph construction.
        with torch.no_grad():
            _, _, fake = point_netG(input_cropped)
            CD_loss = criterion_PointLoss(torch.squeeze(fake, 1),
                                          torch.squeeze(real_center, 1))
        print('test result:', CD_loss)
        log_file.write('\n' + 'test result:  %.4f' % (CD_loss))
        break


def run():
    """Train the generator, with or without the discriminator.

    Everything is configured through the module-level ``opt`` namespace.
    When ``opt.D_choose == 1`` the generator and the discriminator are
    trained adversarially; otherwise only the generator is trained on the
    point losses.

    Side effects:
        * appends one line per iteration to ``loss_PFNet.txt``;
        * in adversarial mode, overwrites ``./post_set.txt``, ``./fake.txt``
          and ``./real_center.txt`` with the first sample of every batch;
        * every 10th epoch writes checkpoints to ``Trained_Model/``
          (adversarial mode) or ``Checkpoint/`` (generator-only mode);
        * sets ``opt.manualSeed`` if it was ``None``.
    """
    print(opt)

    # Seed before any network is built so that the initial weights are
    # covered by the seed too.
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    point_netG = _netG(opt.num_scales, opt.each_scales_size,
                       opt.point_scales_list, opt.crop_point_num)
    point_netD = _netlocalD(opt.crop_point_num)
    cudnn.benchmark = True
    resume_epoch = 0

    # Both networks are always wrapped in DataParallel; the checkpoints
    # written below are the state dicts of the wrapped modules.
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    point_netG = torch.nn.DataParallel(point_netG)
    point_netD = torch.nn.DataParallel(point_netD)
    point_netG.to(device)
    point_netG.apply(_weights_init_normal)
    point_netD.to(device)
    point_netD.apply(_weights_init_normal)

    # If both checkpoints are given, the discriminator's epoch wins.
    if opt.netG != '':
        resume_epoch = _load_checkpoint(point_netG, opt.netG)
    if opt.netD != '':
        resume_epoch = _load_checkpoint(point_netD, opt.netD)

    dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='train')
    assert dset

    dataloader = torch.utils.data.DataLoader(dset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    test_dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dset,
                                                  batch_size=opt.batchSize,
                                                  shuffle=True,
                                                  num_workers=int(opt.workers))

    print(point_netG)
    print(point_netD)

    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_PointLoss = PointLoss().to(device)

    # setup optimizer
    optimizerD = torch.optim.Adam(point_netD.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    optimizerG = torch.optim.Adam(point_netG.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD,
                                                 step_size=40,
                                                 gamma=0.2)
    schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG,
                                                 step_size=40,
                                                 gamma=0.2)

    real_label = 1
    fake_label = 0

    # Candidate viewpoints for the 'random_center' crop.
    viewpoints = [
        torch.Tensor([1, 0, 0]),
        torch.Tensor([0, 0, 1]),
        torch.Tensor([1, 0, 1]),
        torch.Tensor([-1, 0, 0]),
        torch.Tensor([-1, 1, 0])
    ]

    ###########################
    #  G-NET and D-NET
    ###########################
    if opt.D_choose == 1:
        if __name__ == '__main__':
            # On Windows calling this function is necessary.
            freeze_support()

        for epoch in range(resume_epoch, opt.niter):
            alpha1, alpha2 = _alpha_schedule(epoch)

            for i, data in enumerate(dataloader, 0):
                real_point, target = data
                batch_size = real_point.size()[0]

                ############################
                # (1) data prepare
                ###########################
                (real_center, real_center_key1, real_center_key2,
                 input_cropped) = _prepare_training_batch(
                     real_point, viewpoints, device)
                real_labels = torch.full((batch_size, 1),
                                         float(real_label),
                                         dtype=torch.float32,
                                         device=device)
                fake_labels = torch.full((batch_size, 1),
                                         float(fake_label),
                                         dtype=torch.float32,
                                         device=device)
                point_netG.train()
                point_netD.train()

                ############################
                # (2) Update D network
                ###########################
                point_netD.zero_grad()
                output = point_netD(torch.unsqueeze(real_center, 1))
                errD_real = criterion(output, real_labels)
                errD_real.backward()

                fake_center1, fake_center2, fake_points = point_netG(
                    input_cropped)
                fake = torch.unsqueeze(fake_points, 1)
                # detach(): the discriminator step must not reach G.
                output = point_netD(fake.detach())
                errD_fake = criterion(output, fake_labels)
                errD_fake.backward()

                errD = errD_real + errD_fake
                optimizerD.step()

                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################
                point_netG.zero_grad()
                output = point_netD(fake)
                errG_D = criterion(output, real_labels)

                # Debug dump: cropped input followed by the generated patch.
                np.savetxt('./post_set' + '.txt',
                           torch.cat([input_cropped[0][0, :, :],
                                      fake_points[0, :, :]],
                                     dim=0).cpu().detach().numpy(),
                           fmt="%f,%f,%f")

                CD_LOSS = criterion_PointLoss(fake_points, real_center)
                # CD_LOSS is reused as the first term rather than evaluated
                # a second time on the same inputs (assumes the point loss
                # is deterministic -- TODO confirm against PointLoss).
                errG_l2 = CD_LOSS\
                    + alpha1 * criterion_PointLoss(fake_center1, real_center_key1)\
                    + alpha2 * criterion_PointLoss(fake_center2, real_center_key2)

                errG = (1 - opt.wtl2) * errG_D + opt.wtl2 * errG_l2
                errG.backward()
                optimizerG.step()

                np.savetxt('./fake' + '.txt',
                           fake_points[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_A
                np.savetxt('./real_center' + '.txt',
                           real_center[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_B

                print(
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f/ %.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))
                # The context manager closes the log even if the test step
                # below raises.
                with open('loss_PFNet.txt', 'a') as f:
                    f.write(
                        '\n' +
                        '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f /%.4f'
                        % (epoch, opt.niter, i, len(dataloader), errD.data,
                           errG_D.data, errG_l2, errG, CD_LOSS))

                    if i % 10 == 0:
                        print('After, ', i, '-th batch')
                        f.write('\n' + 'After, ' + str(i) + '-th batch')
                        _evaluate_first_test_batch(point_netG,
                                                   test_dataloader,
                                                   criterion_PointLoss,
                                                   viewpoints, device, f)
            schedulerD.step()
            schedulerG.step()
            if epoch % 10 == 0:
                _save_checkpoint(point_netG, epoch, 'Trained_Model',
                                 'point_netG')
                _save_checkpoint(point_netD, epoch, 'Trained_Model',
                                 'point_netD')

    #############################
    ## ONLY G-NET
    ############################
    else:
        for epoch in range(resume_epoch, opt.niter):
            alpha1, alpha2 = _alpha_schedule(epoch)

            for i, data in enumerate(dataloader, 0):
                real_point, target = data

                ############################
                # (1) data prepare
                ###########################
                (real_center, real_center_key1, real_center_key2,
                 input_cropped) = _prepare_training_batch(
                     real_point, viewpoints, device)

                ############################
                # (2) Update G network on the point losses only
                ###########################
                point_netG.train()
                point_netG.zero_grad()
                fake_center1, fake_center2, fake = point_netG(input_cropped)

                CD_LOSS = criterion_PointLoss(fake,
                                              torch.squeeze(real_center, 1))
                # See the adversarial branch: CD_LOSS is reused, not
                # recomputed (assumes a deterministic point loss).
                errG_l2 = CD_LOSS\
                    + alpha1 * criterion_PointLoss(fake_center1, real_center_key1)\
                    + alpha2 * criterion_PointLoss(fake_center2, real_center_key2)

                errG_l2.backward()
                optimizerG.step()
                print('[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                      (epoch, opt.niter, i, len(dataloader), errG_l2, CD_LOSS))
                with open('loss_PFNet.txt', 'a') as f:
                    f.write(
                        '\n' + '[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                        (epoch, opt.niter, i, len(dataloader), errG_l2,
                         CD_LOSS))
            # Only G's schedule advances: optimizerD never steps in this mode.
            schedulerG.step()

            if epoch % 10 == 0:
                _save_checkpoint(point_netG, epoch, 'Checkpoint',
                                 'point_netG')
Esempio n. 6
0
    real_center = torch.FloatTensor(opt.batch_size, 1, opt.crop_point_num, 3)
    fake_center = torch.FloatTensor(opt.batch_size, 1, opt.crop_point_num, 3)
    batch_size = real_point.size()[0]
    p_origin = [0,0,0]
    choice =[torch.Tensor([1,0,0]),torch.Tensor([0,0,1]),torch.Tensor([1,0,1]),torch.Tensor([-1,0,0]),torch.Tensor([-1,1,0])]   
#    points = [x for x in range(0,opt.pnum-1)]
#    choice =random.sample(points,5)
    index = choice[IDX-1]#random.sample(choice,1)
    IDX  = IDX+1
    if IDX%5 == 0:
        IDX = 0
    distance_list = []
#    p_center  = real_point[0,0,index]
    p_center = index
    for num in range(opt.num_fines):
        distance_list.append(distance_squre(real_point[0,0,num],p_center))
    distance_order = sorted(enumerate(distance_list), key = lambda x:x[1])
    
    for sp in range(opt.crop_point_num):
        real_center.data[0,0,sp] = real_point[0,0,distance_order[sp][0]]
    real_center = torch.squeeze(real_center,1) 
    real_center = real_center.to(device)
    
    crop_num_list = []
    for num in range(opt.num_fines-opt.crop_point_num):
        crop_num_list.append(distance_order[num+opt.crop_point_num][0])
    indices = torch.LongTensor(crop_num_list)
    input_cropped[0,0]=torch.index_select(real_point[0,0],0,indices)
    
    
    real_point = torch.squeeze(real_point,1)