type=float, default=0.95, help='0 means do not use else use with this weight') parser.add_argument('--cropmethod', default='random_center', help='random|center|random_center') opt = parser.parse_args() print(opt) blue = lambda x: '\033[94m' + x + '\033[0m' BASE_DIR = os.path.dirname(os.path.abspath(__file__)) USE_CUDA = True device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") point_netG = _netG(opt.num_scales, opt.each_scales_size, opt.point_scales_list, opt.crop_point_num) point_netD = _netlocalD(opt.crop_point_num) cudnn.benchmark = True resume_epoch = 0 def weights_init_normal(m): classname = m.__class__.__name__ if classname.find("Conv2d") != -1: torch.nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find("Conv1d") != -1: torch.nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find("BatchNorm2d") != -1: torch.nn.init.normal_(m.weight.data, 1.0, 0.02) torch.nn.init.constant_(m.bias.data, 0.0) elif classname.find("BatchNorm1d") != -1: torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
_LOSS_LOG_PATH = 'loss_PFNet.txt'
_DATASET_ROOT = './dataset/shapenetcore_partanno_segmentation_benchmark_v0/'


def _weights_init_normal(m):
    """Re-initialise conv and batch-norm layers in place.

    Intended for ``Module.apply``: layers whose class name contains
    ``Conv1d``/``Conv2d`` get N(0, 0.02) weights; ``BatchNorm1d``/
    ``BatchNorm2d`` get N(1, 0.02) weights and a zero bias. Any other
    layer is left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname or 'Conv1d' in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in classname or 'BatchNorm1d' in classname:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def _load_checkpoint(model, path):
    """Load ``state_dict`` from *path* into *model* and return the stored epoch.

    The file is read once and mapped to CPU so that a checkpoint written on
    a GPU host can be resumed on a machine without that device.
    """
    # NOTE(security): torch.load unpickles the file; only resume from
    # checkpoints you trust.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return checkpoint['epoch']


def _save_checkpoint(model, epoch, directory, stem):
    """Write ``{'epoch': epoch + 1, 'state_dict': ...}`` to directory/stem<epoch>.pth."""
    # Create the target directory on demand instead of failing on first save.
    os.makedirs(directory, exist_ok=True)
    torch.save(
        {'epoch': epoch + 1, 'state_dict': model.state_dict()},
        os.path.join(directory, stem + str(epoch) + '.pth'))


def _keypoint_loss_weights(epoch):
    """Return ``(alpha1, alpha2)``, the epoch-dependent weights of the two
    auxiliary (coarser-scale) point losses."""
    if epoch < 30:
        return 0.01, 0.02
    if epoch < 80:
        return 0.05, 0.1
    return 0.1, 0.2


def _fps_subset(points, num_points, *, randomize):
    """Pick *num_points* points per cloud with ``utils.farthest_point_sample``."""
    idx = utils.farthest_point_sample(points, num_points, RAN=randomize)
    return utils.index_points(points, idx)


def _crop_random_center(real_point, pnum, crop_point_num):
    """Cut the region nearest to a randomly chosen viewpoint out of each cloud.

    Args:
        real_point: batch of point clouds, indexed as ``[sample, point, xyz]``.
        pnum: number of points per cloud.
        crop_point_num: number of points to remove from each cloud.

    Returns:
        ``(real_center, input_cropped)``: the removed points, shaped
        ``(batch, crop_point_num, 3)`` and ordered nearest-first, and a float
        copy of the input, shaped ``(batch, pnum, 3)``, in which the removed
        points are overwritten with zeros.
    """
    viewpoints = [
        torch.Tensor([1, 0, 0]),
        torch.Tensor([0, 0, 1]),
        torch.Tensor([1, 0, 1]),
        torch.Tensor([-1, 0, 0]),
        torch.Tensor([-1, 1, 0]),
    ]
    batch_size = real_point.size()[0]
    input_cropped = torch.empty(
        batch_size, pnum, 3, dtype=torch.float32).copy_(real_point)
    # Every entry is assigned in the loop below, so no initial fill is needed.
    real_center = torch.empty(
        batch_size, crop_point_num, 3, dtype=torch.float32)

    for m in range(batch_size):
        # random.sample (rather than random.choice) keeps the draw sequence
        # of the `random` module identical for a given seed.
        p_center = random.sample(viewpoints, 1)[0]
        distances = [
            distance_squre(real_point[m, n], p_center) for n in range(pnum)
        ]
        by_distance = sorted(enumerate(distances), key=lambda pair: pair[1])
        nearest = [by_distance[sp][0] for sp in range(crop_point_num)]
        real_center[m] = real_point[m, nearest]
        input_cropped[m, nearest] = 0.0
    return real_center, input_cropped


def _multiscale_inputs(input_cropped1, scales, device, *, requires_grad):
    """Build the three-resolution generator input ``[full, scales[1], scales[2]]``.

    The two coarser clouds are farthest-point samples of *input_cropped1*
    (the first with ``RAN=True``, the second with ``RAN=False``).
    """
    mid = _fps_subset(input_cropped1, scales[1], randomize=True)
    low = _fps_subset(input_cropped1, scales[2], randomize=False)
    return [
        Variable(input_cropped1, requires_grad=requires_grad),
        Variable(mid, requires_grad=requires_grad).to(device),
        Variable(low, requires_grad=requires_grad).to(device),
    ]


def _prepare_training_batch(real_point, device):
    """Crop one training batch and derive every tensor the losses need.

    Returns:
        ``(real_center, real_center_key1, real_center_key2, input_cropped)``
        where the two key tensors are 64- and 128-point farthest-point
        samples of ``real_center`` and ``input_cropped`` is the list consumed
        by the generator.
    """
    real_center, input_cropped1 = _crop_random_center(
        real_point, opt.pnum, opt.crop_point_num)
    real_center = Variable(real_center.to(device), requires_grad=True)
    input_cropped1 = input_cropped1.to(device)

    # Sampling order (keys first, then inputs) matches the original script so
    # that the random draws made with RAN=True stay in the same sequence.
    real_center_key1 = Variable(
        _fps_subset(real_center, 64, randomize=False), requires_grad=True)
    real_center_key2 = Variable(
        _fps_subset(real_center, 128, randomize=True), requires_grad=True)
    input_cropped = _multiscale_inputs(
        input_cropped1, opt.point_scales_list, device, requires_grad=True)
    return real_center, real_center_key1, real_center_key2, input_cropped


def _evaluate_first_test_batch(point_netG, criterion_PointLoss,
                               test_dataloader, device):
    """Return the point loss of the generator on the first test batch.

    Switches *point_netG* to eval mode (the training loop switches it back
    at the start of the next batch). Returns ``None`` when the loader yields
    no batch.
    """
    for real_point, _ in test_dataloader:
        real_center, input_cropped1 = _crop_random_center(
            real_point, opt.pnum, opt.crop_point_num)
        real_center = real_center.to(device)
        input_cropped = _multiscale_inputs(
            input_cropped1.to(device), opt.point_scales_list, device,
            requires_grad=False)
        point_netG.eval()
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            _, _, fake = point_netG(input_cropped)
            return criterion_PointLoss(
                torch.squeeze(fake, 1), torch.squeeze(real_center, 1))
    return None


def run():
    """Train the point-completion generator, optionally with a discriminator.

    Reads every setting from the module-level ``opt`` namespace. When
    ``opt.D_choose == 1`` the generator and the local discriminator are
    trained adversarially; otherwise only the generator is trained on the
    point losses. Per-batch losses are printed and appended to
    ``loss_PFNet.txt``; checkpoints are written every 10 epochs to
    ``Trained_Model/`` (adversarial mode) or ``Checkpoint/`` (generator
    only).

    Raises:
        ValueError: if ``opt.cropmethod`` is not ``'random_center'``, the
            only cropping strategy implemented here.
        RuntimeError: if the training dataset is empty.
    """
    if __name__ == '__main__':
        # Needed on Windows for frozen executables that use multiprocessing
        # (DataLoader workers). Called once here instead of once per epoch.
        freeze_support()

    print(opt)
    if opt.cropmethod != 'random_center':
        # Previously any other value silently trained against an
        # uninitialised `real_center` tensor.
        raise ValueError(
            "unsupported cropmethod %r: only 'random_center' is implemented"
            % (opt.cropmethod,))

    # Seed before the networks are built so the seed also covers weight
    # initialisation.
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cudnn.benchmark = True

    point_netG = _netG(opt.num_scales, opt.each_scales_size,
                       opt.point_scales_list, opt.crop_point_num)
    point_netD = _netlocalD(opt.crop_point_num)
    # The original guarded this with a hard-coded `USE_CUDA = True`, so the
    # DataParallel wrapping has always been unconditional.
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    point_netG = torch.nn.DataParallel(point_netG)
    point_netD = torch.nn.DataParallel(point_netD)
    for net in (point_netG, point_netD):
        net.to(device)
        net.apply(_weights_init_normal)

    resume_epoch = 0
    if opt.netG != '':
        resume_epoch = _load_checkpoint(point_netG, opt.netG)
    if opt.netD != '':
        resume_epoch = _load_checkpoint(point_netD, opt.netD)

    dset = shapenet_part_loader.PartDataset(
        root=_DATASET_ROOT, classification=True, class_choice=None,
        npoints=opt.pnum, split='train')
    if not dset:
        raise RuntimeError('training dataset is empty: ' + _DATASET_ROOT)
    dataloader = torch.utils.data.DataLoader(
        dset, batch_size=opt.batchSize, shuffle=True,
        num_workers=int(opt.workers))

    test_dset = shapenet_part_loader.PartDataset(
        root=_DATASET_ROOT, classification=True, class_choice=None,
        npoints=opt.pnum, split='test')
    test_dataloader = torch.utils.data.DataLoader(
        test_dset, batch_size=opt.batchSize, shuffle=True,
        num_workers=int(opt.workers))

    print(point_netG)
    print(point_netD)

    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_PointLoss = PointLoss().to(device)

    optimizerD = torch.optim.Adam(
        point_netD.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-05,
        weight_decay=opt.weight_decay)
    optimizerG = torch.optim.Adam(
        point_netG.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-05,
        weight_decay=opt.weight_decay)
    schedulerD = torch.optim.lr_scheduler.StepLR(
        optimizerD, step_size=40, gamma=0.2)
    schedulerG = torch.optim.lr_scheduler.StepLR(
        optimizerG, step_size=40, gamma=0.2)

    real_label = 1
    fake_label = 0

    if opt.D_choose == 1:
        # ---- generator + discriminator ---------------------------------
        for epoch in range(resume_epoch, opt.niter):
            alpha1, alpha2 = _keypoint_loss_weights(epoch)

            for i, data in enumerate(dataloader, 0):
                real_point, _ = data
                batch_size = real_point.size()[0]
                (real_center, real_center_key1, real_center_key2,
                 input_cropped) = _prepare_training_batch(real_point, device)
                label = torch.full((batch_size, 1), real_label,
                                   dtype=torch.float, device=device)

                point_netG.train()
                point_netD.train()

                # (1) Update D: real crops labelled 1, generated crops 0.
                point_netD.zero_grad()
                real_center = torch.unsqueeze(real_center, 1)
                errD_real = criterion(point_netD(real_center), label)
                errD_real.backward()

                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)
                label.fill_(fake_label)
                # detach(): the discriminator step must not reach G.
                errD_fake = criterion(point_netD(fake.detach()), label)
                errD_fake.backward()
                errD = errD_real + errD_fake
                optimizerD.step()

                # (2) Update G: fool D and match the removed region.
                point_netG.zero_grad()
                label.fill_(real_label)
                errG_D = criterion(point_netD(fake), label)

                CD_LOSS = criterion_PointLoss(
                    torch.squeeze(fake, 1), torch.squeeze(real_center, 1))
                # CD_LOSS is also the first term of the generator point
                # loss; reuse it instead of computing it a second time.
                errG_l2 = (
                    CD_LOSS
                    + alpha1 * criterion_PointLoss(fake_center1,
                                                   real_center_key1)
                    + alpha2 * criterion_PointLoss(fake_center2,
                                                   real_center_key2))
                errG = (1 - opt.wtl2) * errG_D + opt.wtl2 * errG_l2
                errG.backward()
                optimizerG.step()

                # Debug dumps of sample 0; these files are overwritten on
                # every batch.
                fake_first = torch.squeeze(fake, 1)[0, :, :]
                np.savetxt(
                    './post_set' + '.txt',
                    torch.cat([input_cropped[0][0, :, :], fake_first],
                              dim=0).cpu().detach().numpy(),
                    fmt="%f,%f,%f")
                np.savetxt(
                    './fake' + '.txt',
                    fake_first.cpu().detach().numpy(),
                    fmt="%f,%f,%f")
                np.savetxt(
                    './real_center' + '.txt',
                    torch.squeeze(real_center, 1)[0, :, :]
                    .cpu().detach().numpy(),
                    fmt="%f,%f,%f")

                stats = (epoch, opt.niter, i, len(dataloader), errD, errG_D,
                         errG_l2, errG, CD_LOSS)
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f/ %.4f'
                      % stats)
                # `with` guarantees the log is closed even if the periodic
                # evaluation below raises.
                with open(_LOSS_LOG_PATH, 'a') as log_file:
                    log_file.write(
                        '\n'
                        + '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f / %.4f /%.4f'
                        % stats)
                    if i % 10 == 0:
                        print('After, ', i, '-th batch')
                        log_file.write('\n' + 'After, ' + str(i) + '-th batch')
                        CD_loss = _evaluate_first_test_batch(
                            point_netG, criterion_PointLoss,
                            test_dataloader, device)
                        if CD_loss is not None:
                            print('test result:', CD_loss)
                            log_file.write(
                                '\n' + 'test result: %.4f' % (CD_loss))

            schedulerD.step()
            schedulerG.step()
            if epoch % 10 == 0:
                _save_checkpoint(point_netG, epoch,
                                 'Trained_Model', 'point_netG')
                _save_checkpoint(point_netD, epoch,
                                 'Trained_Model', 'point_netD')
    else:
        # ---- generator only ---------------------------------------------
        for epoch in range(resume_epoch, opt.niter):
            alpha1, alpha2 = _keypoint_loss_weights(epoch)

            for i, data in enumerate(dataloader, 0):
                real_point, _ = data
                (real_center, real_center_key1, real_center_key2,
                 input_cropped) = _prepare_training_batch(real_point, device)

                point_netG.train()
                point_netG.zero_grad()
                fake_center1, fake_center2, fake = point_netG(input_cropped)
                fake = torch.unsqueeze(fake, 1)

                CD_LOSS = criterion_PointLoss(
                    torch.squeeze(fake, 1), torch.squeeze(real_center, 1))
                errG_l2 = (
                    CD_LOSS
                    + alpha1 * criterion_PointLoss(fake_center1,
                                                   real_center_key1)
                    + alpha2 * criterion_PointLoss(fake_center2,
                                                   real_center_key2))
                errG_l2.backward()
                optimizerG.step()

                line = ('[%d/%d][%d/%d] Loss_G: %.4f / %.4f '
                        % (epoch, opt.niter, i, len(dataloader),
                           errG_l2, CD_LOSS))
                print(line)
                with open(_LOSS_LOG_PATH, 'a') as log_file:
                    log_file.write('\n' + line)

            # optimizerD never steps on this path, so only G's schedule
            # is advanced.
            schedulerG.step()
            if epoch % 10 == 0:
                _save_checkpoint(point_netG, epoch,
                                 'Checkpoint', 'point_netG')