def get_result_for_pfNet(incomplete, netpath=""):
    """Run a pretrained PF-Net generator on one incomplete point cloud.

    Args:
        incomplete: Tensor holding the partial cloud. Dimension 1 is squeezed
            before use, so a singleton axis is expected there
            (presumably ``(B, 1, N, 3)`` -- TODO confirm against the caller).
        netpath: Path of the generator checkpoint. The empty string (the
            default) selects the module-level ``pfnet_path``.

    Returns:
        The finest-scale generator output as a NumPy array on the host, with
        dimension 0 squeezed away when it has size 1.
    """
    point_netG = _netG(3, 1, [1024, 512, 256], 256)
    point_netG = torch.nn.DataParallel(point_netG)
    point_netG.to(device)

    if netpath == "":
        netpath = pfnet_path
    # Deserialize onto the CPU first so a checkpoint written on a GPU machine
    # can still be opened on a CPU-only host.
    checkpoint = torch.load(
        netpath, map_location=lambda storage, location: storage)
    point_netG.load_state_dict(checkpoint['state_dict'], strict=False)
    point_netG.eval()

    # Pure inference: no autograd graph is needed for sampling or the forward
    # pass, and the result is detached before it is returned anyway.
    with torch.no_grad():
        ############################
        # (1) data prepare
        ###########################
        input_cropped1 = torch.squeeze(incomplete, 1)
        input_cropped2_idx = utils.farthest_point_sample(
            input_cropped1, 512, RAN=True)
        input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
        input_cropped3_idx = utils.farthest_point_sample(
            input_cropped1, 256, RAN=False)
        input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
        input_cropped2 = input_cropped2.to(device)
        input_cropped3 = input_cropped3.to(device)
        input_cropped = [input_cropped1, input_cropped2, input_cropped3]

        # Only the finest prediction is returned; the two coarser outputs are
        # discarded.
        _fake_center1, _fake_center2, fake = point_netG(input_cropped)

    # The original called ``fake.cuda()`` here, which raises on hosts without
    # CUDA even though ``device`` may legitimately be the CPU.
    return fake.detach().cpu().squeeze(0).numpy()
# Fragment of an evaluation loop; the enclosing definition and the code that
# builds distance_order / real_point / input_cropped_partial are outside this
# view.
#
# Collect the first element of every distance_order entry after the first
# opt.crop_point_num entries, i.e. the indices that are kept rather than
# cropped.
crop_num_list = []
for num_ in range(opt.pnum-opt.crop_point_num):
    crop_num_list.append(distance_order[num_+opt.crop_point_num][0])
indices = torch.LongTensor(crop_num_list)
# Only element [0, 0] is filled, so this handles a single sample.
input_cropped_partial[0,0]=torch.index_select(real_point[0,0],0,indices)
input_cropped_partial = torch.squeeze(input_cropped_partial,1)
input_cropped_partial = input_cropped_partial.to(device)
real_center = torch.squeeze(real_center,1)
#    real_center_key_idx = utils.farthest_point_sample(real_center,64,train = False)
#    real_center_key = utils.index_points(real_center,real_center_key_idx)
#    input_cropped1 = input_cropped1.to(device)
input_cropped1 = torch.squeeze(input_cropped1,1)
# Build the two sub-sampled copies of the input; sizes come from
# opt.point_scales_list[1] and opt.point_scales_list[2].
input_cropped2_idx = utils.farthest_point_sample(input_cropped1,opt.point_scales_list[1],RAN = True)
input_cropped2 = utils.index_points(input_cropped1,input_cropped2_idx)
input_cropped3_idx = utils.farthest_point_sample(input_cropped1,opt.point_scales_list[2],RAN = False)
input_cropped3 = utils.index_points(input_cropped1,input_cropped3_idx)
# NOTE(review): input_cropped1 is not moved to device in the visible code
# (the .to(device) call above is commented out) -- confirm this is intended.
input_cropped2 = input_cropped2.to(device)
input_cropped3 = input_cropped3.to(device)
input_cropped  = [input_cropped1,input_cropped2,input_cropped3]
#    fake,fake_part = point_netG(input_cropped)
fake_center1,fake_center2,fake=point_netG(input_cropped)
# Concatenate the retained points and the generated points along dim 1.
fake_whole = torch.cat((input_cropped_partial,fake),1)
fake_whole = fake_whole.to(device)
real_point = real_point.to(device)
real_center = real_center.to(device)
# Here criterion_PointLoss returns three values; only dist_all is converted
# to NumPy in the visible code.
dist_all, dist1, dist2 = criterion_PointLoss(torch.squeeze(fake,1),torch.squeeze(real_center,1))#+0.1*criterion_PointLoss(torch.squeeze(fake_part,1),torch.squeeze(real_center,1))
dist_all=dist_all.cpu().detach().numpy()
sp] = real_point[m, 0, distance_order[sp][0]] # print(real_center.size(),input_cropped1.size()) #label.data.resize_([batch_size,1]).fill_(real_label) real_point = real_point.to(device) real_center = real_center.to(device) input_cropped1 = input_cropped1.to(device) ############################ # (1) data prepare ########################### real_center = Variable(real_center, requires_grad=True) real_center = torch.squeeze(real_center, 1) input_cropped1 = torch.squeeze(input_cropped1, 1) input_cropped2_idx = utils.farthest_point_sample(input_cropped1, 1024, RAN=True) input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx) input_cropped3_idx = utils.farthest_point_sample(input_cropped1, 512, RAN=False) input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx) input_cropped1 = Variable(input_cropped1, requires_grad=True) input_cropped2 = Variable(input_cropped2, requires_grad=True) input_cropped3 = Variable(input_cropped3, requires_grad=True) input_cropped2 = input_cropped2.to(device) input_cropped3 = input_cropped3.to(device) input_cropped = [input_cropped1, input_cropped2, input_cropped3] point_netG = point_netG.train() point_netG.zero_grad() fake = point_netG(input_cropped)
# Fragment of a training-loop body; gt, image, incomplete, label, batch_size
# and real_label are defined outside this view.
gt = gt.to(device)
image = image.to(device)
# NOTE(review): these two lines call .cuda() directly instead of .to(device),
# so they require a CUDA device regardless of what `device` is.
incomplete = Variable(incomplete, requires_grad=True).cuda()
image = Variable(image.float(), requires_grad=True).cuda()
image = torch.squeeze(image, 1)
# Resize the label buffer in place to (batch_size, 1) and fill it with
# real_label.
label.resize_([batch_size, 1]).fill_(real_label)
label = label.to(device)
############################
# (1) data prepare
###########################
real_center = Variable(gt, requires_grad=True)
real_center = torch.squeeze(real_center, 1)
# Two sub-sampled versions of real_center: 64 points (RAN=False) and
# 128 points (RAN=True).
real_center_key1_idx = utils.farthest_point_sample(real_center, 64, RAN=False)
real_center_key1 = utils.index_points(real_center, real_center_key1_idx)
real_center_key1 = Variable(real_center_key1, requires_grad=True)
real_center_key2_idx = utils.farthest_point_sample(real_center, 128, RAN=True)
real_center_key2 = utils.index_points(real_center, real_center_key2_idx)
real_center_key2 = Variable(real_center_key2, requires_grad=True)
input_cropped1 = torch.squeeze(incomplete, 1)
input_cropped2_idx = utils.farthest_point_sample(
    input_cropped1, opt.point_scales_list[1], RAN=True)
# Inference on a single CSV point cloud with a pretrained generator.
# The snippet continues past this view, so every name assigned here is kept.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
point_netG = _netG(opt.num_scales, opt.each_scales_size,
                   opt.point_scales_list, opt.crop_point_num)
point_netG = torch.nn.DataParallel(point_netG)
point_netG.to(device)
# map_location keeps deserialized storages on the CPU, so a checkpoint saved
# on a GPU machine can be opened on any host.
point_netG.load_state_dict(
    torch.load(opt.netG,
               map_location=lambda storage, location: storage)['state_dict'])
point_netG.eval()

# Read comma-separated rows from opt.infile and add a leading batch axis.
input_cropped1 = np.loadtxt(opt.infile, delimiter=',')
input_cropped1 = torch.FloatTensor(input_cropped1)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)
# Append 512 all-zero points along dim 1 (hard-coded count).
Zeros = torch.zeros(1, 512, 3)
input_cropped1 = torch.cat((input_cropped1, Zeros), 1)

input_cropped2_idx = utils.farthest_point_sample(
    input_cropped1, opt.point_scales_list[1], RAN=True)
input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
input_cropped3_idx = utils.farthest_point_sample(
    input_cropped1, opt.point_scales_list[2], RAN=False)
input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
# A 256-point sample that is not fed to the generator below; kept because
# code after this view may read it.
input_cropped4_idx = utils.farthest_point_sample(input_cropped1, 256, RAN=True)
input_cropped4 = utils.index_points(input_cropped1, input_cropped4_idx)
input_cropped2 = input_cropped2.to(device)
input_cropped3 = input_cropped3.to(device)
input_cropped = [input_cropped1, input_cropped2, input_cropped3]

fake_center1, fake_center2, fake = point_netG(input_cropped)
# The original called .cuda() on the three outputs, which raises on CPU-only
# hosts although `device` explicitly falls back to the CPU. Moving them to
# `device` does the same thing on a CUDA host and works without one.
fake = fake.to(device)
fake_center1 = fake_center1.to(device)
fake_center2 = fake_center2.to(device)
opt.crop_point_num) point_netG = torch.nn.DataParallel(point_netG) point_netG.to(device) point_netG.load_state_dict( torch.load(opt.netG, map_location=lambda storage, location: storage)['state_dict']) point_netG.eval() input_cropped1 = np.loadtxt(opt.infile, delimiter=',') input_cropped1 = torch.FloatTensor(input_cropped1) input_cropped1 = torch.unsqueeze(input_cropped1, 0) Zeros = torch.zeros(1, 512, 3) input_cropped1 = torch.cat((input_cropped1, Zeros), 1) input_cropped2_idx = utils.farthest_point_sample(input_cropped1, 1024, RAN=True) input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx) input_cropped3_idx = utils.farthest_point_sample(input_cropped1, 512, RAN=False) input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx) input_cropped4_idx = utils.farthest_point_sample(input_cropped1, 256, RAN=True) input_cropped4 = utils.index_points(input_cropped1, input_cropped4_idx) input_cropped2 = input_cropped2.to(device) input_cropped3 = input_cropped3.to(device) input_cropped = [input_cropped1, input_cropped2, input_cropped3] fake_center1, fake_center2, fake = point_netG(input_cropped) fake = fake.cuda() fake_center1 = fake_center1.cuda()
# Fragment of an evaluation-style loop body (every Variable is created with
# requires_grad=False); gt, image and incomplete come from outside this view.
gt = gt.to(device)
image = image.to(device)
# Concatenate gt and incomplete along dim 1. This happens before `incomplete`
# is moved to device two lines below.
# NOTE(review): confirm both tensors are already on the same device here.
complete_gt = torch.cat([gt, incomplete], dim=1)
complete_gt=complete_gt.to(device)
incomplete = incomplete.to(device)
# Second image.to(device) call; the first is at the top of this fragment.
image = image.to(device)
incomplete = Variable(incomplete, requires_grad=False)
image = Variable(image.float(), requires_grad=False)
############################
# (1) data prepare
###########################
real_center = Variable(gt, requires_grad=False)
real_center = torch.squeeze(real_center, 1)
# Two sub-sampled versions of real_center: 64 points (RAN=False) and
# 128 points (RAN=True).
real_center_key1_idx = utils.farthest_point_sample(real_center, 64, RAN=False)
real_center_key1 = utils.index_points(real_center, real_center_key1_idx)
real_center_key1 = Variable(real_center_key1, requires_grad=False)
real_center_key2_idx = utils.farthest_point_sample(real_center, 128, RAN=True)
real_center_key2 = utils.index_points(real_center, real_center_key2_idx)
real_center_key2 = Variable(real_center_key2, requires_grad=False)
# Multi-scale copies of the input with hard-coded sizes 512 and 256.
input_cropped1 = torch.squeeze(incomplete, 1)
input_cropped2_idx = utils.farthest_point_sample(input_cropped1, 512, RAN=True)
input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
input_cropped3_idx = utils.farthest_point_sample(input_cropped1, 256, RAN=False)
input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
input_cropped1 = Variable(input_cropped1, requires_grad=False)
input_cropped2 = Variable(input_cropped2, requires_grad=False)
input_cropped3 = Variable(input_cropped3, requires_grad=False)
# Inference on a single CSV point cloud with a pretrained PCN autoencoder.
# Writes the sampled input and the fine reconstruction as CSV files.
import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
PCN = Autoencoder(opt.num_inputs, opt.num_coarses, opt.num_fines,
                  opt.grid_size)
# The weights are loaded before the DataParallel wrap below.
# NOTE(review): presumably the checkpoint keys carry no "module." prefix --
# confirm against how the checkpoint was saved.
PCN.load_state_dict(
    torch.load(opt.model,
               map_location=lambda storage, location: storage)['state_dict'])
print("Let's use", torch.cuda.device_count(), "GPUs!")
PCN.to(device)
PCN = torch.nn.DataParallel(PCN)
PCN.eval()

# Read comma-separated rows from opt.infile and add a leading batch axis.
input_cropped1 = np.loadtxt(opt.infile, delimiter=',')
input_cropped1 = torch.FloatTensor(input_cropped1)
input_cropped1 = torch.unsqueeze(input_cropped1, 0)
input_cropped1 = input_cropped1.to(device)
input_key_cropped_index = farthest_point_sample(input_cropped1, opt.num_inputs,
                                                RAN=False)
input_key_cropped = index_points(input_cropped1,
                                 input_key_cropped_index)  #BX1024X3

# Pure inference: skip building an autograd graph for the forward pass.
with torch.no_grad():
    coarses, fine = PCN(input_key_cropped)

fine = fine.cpu()
np_fake = fine[0].detach().numpy()
input_cropped1 = input_cropped1.cpu()
np_crop = input_cropped1[0].numpy()

# np.savetxt does not create missing directories; make sure the hard-coded
# output folder exists instead of failing after inference has already run.
os.makedirs('test_from_prn_to_PCN', exist_ok=True)
np.savetxt('test_from_prn_to_PCN/crop_PCN' + '.csv', np_crop, fmt="%f,%f,%f")
np.savetxt('test_from_prn_to_PCN/fake_PCN' + '.csv', np_fake, fmt="%f,%f,%f")
# Smoke test for the ModelNet40 loader: build the dataset, print one label and
# the dataset length, then iterate a DataLoader and sub-sample each batch.
# NOTE(review): this assignment rebinds the name `transforms` (the module that
# provides Compose) to the composed pipeline object.
transforms = transforms.Compose(
    [
        d_utils.PointcloudToTensor(),
        # d_utils.PointcloudRotate(axis=np.array([1, 0, 0])),
        # d_utils.PointcloudScale(),
        # d_utils.PointcloudTranslate(),
        # d_utils.PointcloudJitter(),
    ]
)
dset = ModelNet40Cls(1024, train=True, transforms=transforms)
print(dset[0][1])
print(len(dset))
dloader = torch.utils.data.DataLoader(dset, batch_size=64, shuffle=True)
# NOTE(review): the original indentation was lost in this snippet. The
# statements below are assumed to all belong to the loop body -- confirm
# against the original file.
for i, Data in enumerate(dloader, 0):
    real_point, target = Data
    print('1')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    real_point = real_point.to(device)
    # Two sub-sampled copies of the batch: 512 and 256 points.
    real_point2_idx = utils.farthest_point_sample(real_point,512)
    real_point2 = utils.index_points(real_point,real_point2_idx)
    real_point3_idx = utils.farthest_point_sample(real_point,256)
    real_point3 = utils.index_points(real_point,real_point3_idx)
#    model1 = real_point[0].numpy()
#    model2 = real_point2[0].numpy()
#    model3 = real_point3[0].numpy()
#
#    np.savetxt('test-examples/model'+'.txt', model1, fmt = "%f %f %f")
#    np.savetxt('test-examples/model2'+'.txt', model2, fmt = "%f %f %f")
#    np.savetxt('test-examples/model3'+'.txt', model3, fmt = "%f %f %f")
distance_squre(real_point[m, 0, n], p_center)) distance_order = sorted(enumerate(distance_list), key=lambda x: x[1]) crop_num_list = [] for num in range(opt.num_fines - opt.crop_point_num): crop_num_list.append(distance_order[num + opt.crop_point_num][0]) indices = torch.LongTensor(crop_num_list) input_cropped[m, 0] = torch.index_select(real_point[m, 0], 0, indices) real_point = torch.squeeze(real_point, 1) #BX2048X3 input_cropped = torch.squeeze(input_cropped, 1) #BX(2048-512)X3 real_key_point_index = farthest_point_sample(real_point, opt.num_coarses, False) real_key_point = index_points(real_point, real_key_point_index) #BX1024X3 input_key_cropped_index = farthest_point_sample( input_cropped, opt.num_inputs, False) input_key_cropped = index_points(input_cropped, input_key_cropped_index) #BX1024X3 real_point = Variable(real_point, requires_grad=True) real_key_point = Variable(real_key_point, requires_grad=True) input_key_cropped = Variable(input_key_cropped, requires_grad=True) real_point = real_point.to(device) real_key_point = real_key_point.to(device) input_key_cropped = input_key_cropped.to(device)
def _alpha_schedule(epoch):
    """Return the (alpha1, alpha2) weights of the two coarse-scale loss terms."""
    if epoch < 30:
        return 0.01, 0.02
    if epoch < 80:
        return 0.05, 0.1
    return 0.1, 0.2


def _load_checkpoint(path):
    """Load a checkpoint dict once, keeping its storages on the CPU."""
    return torch.load(path, map_location=lambda storage, location: storage)


def _crop_batch(data):
    """Split one loader batch into a cropped input and the removed region.

    Returns ``(batch_size, real_center, input_cropped1)``. Both tensors keep a
    singleton axis at dimension 1. When ``opt.cropmethod`` is not
    ``'random_center'`` nothing is cropped and ``real_center`` is left as
    allocated (uninitialised), as in the original code.
    """
    real_point, _target = data
    batch_size = real_point.size()[0]
    real_center = torch.FloatTensor(batch_size, 1, opt.crop_point_num, 3)
    input_cropped1 = torch.FloatTensor(batch_size, opt.pnum, 3)
    input_cropped1 = input_cropped1.data.copy_(real_point)
    real_point = torch.unsqueeze(real_point, 1)
    input_cropped1 = torch.unsqueeze(input_cropped1, 1)

    if opt.cropmethod == 'random_center':
        # Set viewpoints
        choice = [
            torch.Tensor([1, 0, 0]),
            torch.Tensor([0, 0, 1]),
            torch.Tensor([1, 0, 1]),
            torch.Tensor([-1, 0, 0]),
            torch.Tensor([-1, 1, 0]),
        ]
        for m in range(batch_size):
            # random.sample (not random.choice) is kept so a fixed seed
            # consumes the random stream exactly as before.
            p_center = random.sample(choice, 1)[0]
            distance_list = [
                distance_squre(real_point[m, 0, n], p_center)
                for n in range(opt.pnum)
            ]
            distance_order = sorted(enumerate(distance_list),
                                    key=lambda x: x[1])
            # The crop_point_num nearest points are zeroed in the input and
            # copied into real_center.
            for sp in range(opt.crop_point_num):
                nearest = distance_order[sp][0]
                input_cropped1.data[m, 0, nearest] = torch.FloatTensor(
                    [0, 0, 0])
                real_center.data[m, 0, sp] = real_point[m, 0, nearest]
    return batch_size, real_center, input_cropped1


def _multiscale_inputs(input_cropped1, device, requires_grad):
    """Build the three-scale generator input list from the cropped cloud."""
    input_cropped1 = torch.squeeze(input_cropped1, 1)
    input_cropped2_idx = utils.farthest_point_sample(
        input_cropped1, opt.point_scales_list[1], RAN=True)
    input_cropped2 = utils.index_points(input_cropped1, input_cropped2_idx)
    input_cropped3_idx = utils.farthest_point_sample(
        input_cropped1, opt.point_scales_list[2], RAN=False)
    input_cropped3 = utils.index_points(input_cropped1, input_cropped3_idx)
    input_cropped1 = Variable(input_cropped1, requires_grad=requires_grad)
    input_cropped2 = Variable(input_cropped2, requires_grad=requires_grad)
    input_cropped3 = Variable(input_cropped3, requires_grad=requires_grad)
    input_cropped2 = input_cropped2.to(device)
    input_cropped3 = input_cropped3.to(device)
    return [input_cropped1, input_cropped2, input_cropped3]


def _prepare_train_batch(data, device):
    """Crop a training batch and derive every tensor one training step needs.

    Returns ``(batch_size, real_center, real_center_key1, real_center_key2,
    input_cropped)``. ``real_center`` has its singleton axis squeezed.
    The sampling order matches the original code so seeded runs are unchanged.
    """
    batch_size, real_center, input_cropped1 = _crop_batch(data)
    real_center = real_center.to(device)
    input_cropped1 = input_cropped1.to(device)

    real_center = Variable(real_center, requires_grad=True)
    real_center = torch.squeeze(real_center, 1)
    real_center_key1_idx = utils.farthest_point_sample(real_center, 64,
                                                       RAN=False)
    real_center_key1 = utils.index_points(real_center, real_center_key1_idx)
    real_center_key1 = Variable(real_center_key1, requires_grad=True)
    real_center_key2_idx = utils.farthest_point_sample(real_center, 128,
                                                       RAN=True)
    real_center_key2 = utils.index_points(real_center, real_center_key2_idx)
    real_center_key2 = Variable(real_center_key2, requires_grad=True)

    input_cropped = _multiscale_inputs(input_cropped1, device,
                                       requires_grad=True)
    return (batch_size, real_center, real_center_key1, real_center_key2,
            input_cropped)


def _generator_point_loss(criterion_PointLoss, fake, fake_center1,
                          fake_center2, real_center, real_center_key1,
                          real_center_key2, alpha1, alpha2):
    """Return ``(CD_LOSS, errG_l2)``; the fine-scale term is computed once."""
    cd_loss = criterion_PointLoss(torch.squeeze(fake, 1),
                                  torch.squeeze(real_center, 1))
    total = (cd_loss
             + alpha1 * criterion_PointLoss(fake_center1, real_center_key1)
             + alpha2 * criterion_PointLoss(fake_center2, real_center_key2))
    return cd_loss, total


def _evaluate_first_test_batch(point_netG, test_dataloader,
                               criterion_PointLoss, device):
    """Return the point loss on the first test batch, or None if there is none.

    Leaves ``point_netG`` in eval mode; the training loop switches it back.
    """
    for test_data in test_dataloader:
        _, real_center, input_cropped1 = _crop_batch(test_data)
        real_center = real_center.to(device)
        real_center = torch.squeeze(real_center, 1)
        input_cropped1 = input_cropped1.to(device)
        input_cropped = _multiscale_inputs(input_cropped1, device,
                                           requires_grad=False)
        point_netG.eval()
        _, _, fake = point_netG(input_cropped)
        return criterion_PointLoss(torch.squeeze(fake, 1),
                                   torch.squeeze(real_center, 1))
    return None


def run():
    """Train the point-completion generator, optionally with a discriminator.

    All configuration is read from the module-level ``opt``. When
    ``opt.D_choose == 1`` the generator and the local discriminator are
    trained adversarially and the generator is evaluated on one test batch
    every 10 iterations. Otherwise only the generator is trained on the point
    loss. Progress is appended to ``loss_PFNet.txt`` and checkpoints are
    written every 10 epochs.
    """
    print(opt)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    point_netG = _netG(opt.num_scales, opt.each_scales_size,
                       opt.point_scales_list, opt.crop_point_num)
    point_netD = _netlocalD(opt.crop_point_num)
    cudnn.benchmark = True
    resume_epoch = 0

    def weights_init_normal(m):
        classname = m.__class__.__name__
        if "Conv2d" in classname or "Conv1d" in classname:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif "BatchNorm2d" in classname or "BatchNorm1d" in classname:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    point_netG = torch.nn.DataParallel(point_netG)
    point_netD = torch.nn.DataParallel(point_netD)
    point_netG.to(device)
    point_netG.apply(weights_init_normal)
    point_netD.to(device)
    point_netD.apply(weights_init_normal)

    # Each checkpoint is read once. The original re-read the file without
    # map_location to fetch 'epoch', which fails on CPU-only hosts for
    # checkpoints saved from a GPU.
    if opt.netG != '':
        checkpoint = _load_checkpoint(opt.netG)
        point_netG.load_state_dict(checkpoint['state_dict'])
        resume_epoch = checkpoint['epoch']
    if opt.netD != '':
        checkpoint = _load_checkpoint(opt.netD)
        point_netD.load_state_dict(checkpoint['state_dict'])
        resume_epoch = checkpoint['epoch']

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)

    dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='train')
    assert dset
    dataloader = torch.utils.data.DataLoader(dset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    test_dset = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=True,
        class_choice=None,
        npoints=opt.pnum,
        split='test')
    test_dataloader = torch.utils.data.DataLoader(test_dset,
                                                  batch_size=opt.batchSize,
                                                  shuffle=True,
                                                  num_workers=int(opt.workers))

    print(point_netG)
    print(point_netD)

    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_PointLoss = PointLoss().to(device)

    # setup optimizer
    optimizerD = torch.optim.Adam(point_netD.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    optimizerG = torch.optim.Adam(point_netG.parameters(),
                                  lr=0.0001,
                                  betas=(0.9, 0.999),
                                  eps=1e-05,
                                  weight_decay=opt.weight_decay)
    schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD,
                                                 step_size=40,
                                                 gamma=0.2)
    schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG,
                                                 step_size=40,
                                                 gamma=0.2)

    real_label = 1
    fake_label = 0
    label = torch.FloatTensor(opt.batchSize)

    ###########################
    #  G-NET and T-NET
    ##########################
    if opt.D_choose == 1:
        for epoch in range(resume_epoch, opt.niter):
            alpha1, alpha2 = _alpha_schedule(epoch)
            if __name__ == '__main__':
                # On Windows calling this function is necessary.
                freeze_support()
            for i, data in enumerate(dataloader, 0):
                (batch_size, real_center, real_center_key1, real_center_key2,
                 input_cropped) = _prepare_train_batch(data, device)
                label.resize_([batch_size, 1]).fill_(real_label)
                label = label.to(device)

                point_netG = point_netG.train()
                point_netD = point_netD.train()
                ############################
                # (2) Update D network
                ###########################
                point_netD.zero_grad()
                real_center = torch.unsqueeze(real_center, 1)
                output = point_netD(real_center)
                errD_real = criterion(output, label)
                errD_real.backward()
                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                label.data.fill_(fake_label)
                # detach(): the discriminator step must not push gradients
                # into the generator.
                output = point_netD(fake.detach())
                errD_fake = criterion(output, label)
                errD_fake.backward()
                errD = errD_real + errD_fake
                optimizerD.step()
                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################
                point_netG.zero_grad()
                label.data.fill_(real_label)
                output = point_netD(fake)
                errG_D = criterion(output, label)

                # Debug dump: first sample's cropped input followed by its
                # generated points.
                np.savetxt('./post_set' + '.txt',
                           torch.cat([
                               input_cropped[0][0, :, :],
                               torch.squeeze(fake, 1)[0, :, :]
                           ], dim=0).cpu().detach().numpy(),
                           fmt="%f,%f,%f")

                CD_LOSS, errG_l2 = _generator_point_loss(
                    criterion_PointLoss, fake, fake_center1, fake_center2,
                    real_center, real_center_key1, real_center_key2, alpha1,
                    alpha2)

                errG = (1 - opt.wtl2) * errG_D + opt.wtl2 * errG_l2
                errG.backward()
                optimizerG.step()

                np.savetxt('./fake' + '.txt',
                           torch.squeeze(fake,
                                         1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_A
                np.savetxt('./real_center' + '.txt',
                           torch.squeeze(
                               real_center,
                               1)[0, :, :].cpu().detach().numpy(),
                           fmt="%f,%f,%f")  # input_B

                print(
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f/ %.4f'
                    % (epoch, opt.niter, i, len(dataloader), errD.data,
                       errG_D.data, errG_l2, errG, CD_LOSS))
                # The context manager closes the log even when evaluation
                # raises.
                with open('loss_PFNet.txt', 'a') as f:
                    f.write(
                        '\n' +
                        '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f /%.4f'
                        % (epoch, opt.niter, i, len(dataloader), errD.data,
                           errG_D.data, errG_l2, errG, CD_LOSS))

                    if i % 10 == 0:
                        print('After, ', i, '-th batch')
                        f.write('\n' + 'After, ' + str(i) + '-th batch')
                        CD_loss = _evaluate_first_test_batch(
                            point_netG, test_dataloader, criterion_PointLoss,
                            device)
                        if CD_loss is not None:
                            print('test result:', CD_loss)
                            f.write('\n' + 'test result: %.4f' % (CD_loss))

            schedulerD.step()
            schedulerG.step()
            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netG.state_dict()
                    }, 'Trained_Model/point_netG' + str(epoch) + '.pth')
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netD.state_dict()
                    }, 'Trained_Model/point_netD' + str(epoch) + '.pth')

    #############################
    ## ONLY G-NET
    ############################
    else:
        for epoch in range(resume_epoch, opt.niter):
            alpha1, alpha2 = _alpha_schedule(epoch)
            for i, data in enumerate(dataloader, 0):
                (_batch_size, real_center, real_center_key1, real_center_key2,
                 input_cropped) = _prepare_train_batch(data, device)

                point_netG = point_netG.train()
                point_netG.zero_grad()
                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                ############################
                # (3) Update G network: maximize log(D(G(z)))
                ###########################
                CD_LOSS, errG_l2 = _generator_point_loss(
                    criterion_PointLoss, fake, fake_center1, fake_center2,
                    real_center, real_center_key1, real_center_key2, alpha1,
                    alpha2)
                errG_l2.backward()
                optimizerG.step()

                print('[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                      (epoch, opt.niter, i, len(dataloader), errG_l2, CD_LOSS))
                with open('loss_PFNet.txt', 'a') as f:
                    f.write(
                        '\n' + '[%d/%d][%d/%d] Loss_G: %.4f / %.4f ' %
                        (epoch, opt.niter, i, len(dataloader), errG_l2,
                         CD_LOSS))

            # schedulerD is stepped here too, as in the original, although the
            # discriminator optimizer never steps in this branch.
            schedulerD.step()
            schedulerG.step()

            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': point_netG.state_dict()
                    }, 'Checkpoint/point_netG' + str(epoch) + '.pth')