def train(args):
    """Train the point-cloud upsampling generator, optionally with a GAN loss.

    Args:
        args: parsed CLI namespace providing ``exp_name``, ``batch_size``,
            ``use_gan`` and ``debug``.

    Side effects:
        Writes TensorBoard scalars through ``Logger`` and saves
        generator/discriminator checkpoints under
        ``params['model_save_dir']/<exp_name>``.
    """
    start_t = time.time()

    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 1024
    params["batch_size"] = args.batch_size
    params['use_gan'] = args.use_gan
    if args.debug:  # tiny run for smoke-testing
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    os.makedirs(log_dir, exist_ok=True)  # idempotent; replaces exists()==False check
    tb_logger = Logger(log_dir)

    trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"],
                                split_dir=params['train_split'])
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=4,
                                        pin_memory=True,
                                        drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator(params)
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)

    D_model = Discriminator(params, in_channels=3)
    D_model.apply(xavier_init)
    D_model = torch.nn.DataParallel(D_model).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(), lr=params["lr_D"], betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(), lr=params["lr_G"], betas=(0.9, 0.999))
    D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2)

    Loss_fn = Loss()
    print("preparation time is %fs" % (time.time() - start_t))

    step = 0  # global iteration counter (renamed from `iter`, which shadowed the builtin)
    for e in range(params["nepoch"]):
        for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            # (B, N, C) -> (B, 3, N); move to the probed device (was a hard-coded
            # .cuda(), which crashed on CPU-only hosts despite the probe above)
            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().to(device)
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().to(device)

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

            # geometric losses are evaluated in (B, N, 3) layout
            repulsion_loss = Loss_fn.get_repulsion_loss(
                output_point_cloud.permute(0, 2, 1))
            uniform_loss = Loss_fn.get_uniform_loss(
                output_point_cloud.permute(0, 2, 1))
            emd_loss = Loss_fn.get_emd_loss(
                output_point_cloud.permute(0, 2, 1), gt_data.permute(0, 2, 1))

            if params['use_gan']:
                # --- discriminator update (generator output detached) ---
                # BUGFIX: the original did backward()+step() twice with no
                # zero_grad() in between, so the fake-loss gradients leaked
                # into the real-loss step; combine into one step instead.
                fake_pred = D_model(output_point_cloud.detach())
                d_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred,
                                                                    label=False)
                real_pred = D_model(gt_data.detach())
                d_loss_real = Loss_fn.get_discriminator_loss_single(real_pred,
                                                                    label=True)
                d_loss = d_loss_real + d_loss_fake
                d_loss.backward()
                optimizer_D.step()

                # --- generator adversarial loss (NOT detached: must reach G) ---
                fake_pred = D_model(output_point_cloud)
                g_loss = Loss_fn.get_generator_loss(fake_pred)
                total_G_loss = params['uniform_w'] * uniform_loss \
                    + params['emd_w'] * emd_loss \
                    + params['repulsion_w'] * repulsion_loss \
                    + params['gan_w'] * g_loss
            else:
                total_G_loss = params['emd_w'] * emd_loss \
                    + params['repulsion_w'] * repulsion_loss

            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.state_dict()['param_groups'][0]['lr']
            current_lr_G = optimizer_G.state_dict()['param_groups'][0]['lr']

            tb_logger.scalar_summary('repulsion_loss', repulsion_loss.item(), step)
            tb_logger.scalar_summary('uniform_loss', uniform_loss.item(), step)
            tb_logger.scalar_summary('emd_loss', emd_loss.item(), step)
            if params['use_gan']:
                tb_logger.scalar_summary('d_loss', d_loss.item(), step)
                tb_logger.scalar_summary('g_loss', g_loss.item(), step)
            tb_logger.scalar_summary('lr_D', current_lr_D, step)
            tb_logger.scalar_summary('lr_G', current_lr_G, step)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e, batch_id + 1, len(train_data_loader),
                "total_G_loss", total_G_loss.item(),
                "iter time", (time.time() - start_t_batch))
            print(msg)
            step += 1

        # BUGFIX: step the LR schedulers AFTER the epoch's optimizer steps.
        # The original stepped them at the top of the epoch, decaying the LR
        # one epoch early and triggering PyTorch's step-ordering warning.
        D_scheduler.step()
        G_scheduler.step()

        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'],
                                          params['exp_name'])
            os.makedirs(model_save_dir, exist_ok=True)
            D_model_save_path = os.path.join(model_save_dir, "D_iter_%d.pth" % (e))
            G_model_save_path = os.path.join(model_save_dir, "G_iter_%d.pth" % (e))
            # unwrap DataParallel (.module) so checkpoints load without it
            torch.save(D_model.module.state_dict(), D_model_save_path)
            torch.save(G_model.module.state_dict(), G_model_save_path)
def train(args):
    """Train a CycleGAN-style pair of point-cloud resamplers.

    ``G_AB`` upsamples (A: sparse -> B: dense) and ``G_BA`` downsamples;
    ``D_A`` judges dense clouds against ground truth, ``D_B`` judges sparse
    clouds against the input. Cycle losses (A->B->A, B->A->B) tie the two
    generators together.

    Args:
        args: parsed CLI namespace providing ``exp_name``, ``batch_size``
            and ``debug``.

    Side effects:
        Logs scalars through ``Logger`` and writes per-model plus combined
        checkpoints under ``params['model_save_dir']/<exp_name>``.

    NOTE(review): this definition has the same name as the earlier ``train``
    in this file and shadows it at import time — confirm which entry point is
    intended to be public.
    """
    start_t = time.time()

    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 256
    params["batch_size"] = args.batch_size
    if args.debug:  # tiny run for smoke-testing
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    os.makedirs(log_dir, exist_ok=True)  # idempotent; replaces exists()==False check
    tb_logger = Logger(log_dir)

    trainloader = PUGAN_Dataset(h5_file_path=params["dataset_dir"], npoint=256)
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=4,
                                        pin_memory=True,
                                        drop_last=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ##########################################
    # Initialize generators and discriminators
    ##########################################
    G_AB = Generator(params)      # sparse -> dense
    G_AB.apply(xavier_init)
    G_AB = torch.nn.DataParallel(G_AB).to(device)

    G_BA = Downsampler(params)    # dense -> sparse
    G_BA.apply(xavier_init)
    G_BA = torch.nn.DataParallel(G_BA).to(device)

    D_A = Discriminator(params, in_channels=3)   # judges dense clouds
    D_A.apply(xavier_init)
    D_A = torch.nn.DataParallel(D_A).to(device)

    D_B = Discriminator(params, in_channels=3)   # judges sparse clouds
    D_B.apply(xavier_init)
    D_B = torch.nn.DataParallel(D_B).to(device)

    ########################################
    # Optimizers and learning-rate schedulers
    ########################################
    optimizer_D_A = Adam(D_A.parameters(), lr=params["lr_D_A"], betas=(0.9, 0.999))
    optimizer_D_B = Adam(D_B.parameters(), lr=params["lr_D_B"], betas=(0.9, 0.999))
    optimizer_G_AB = Adam(G_AB.parameters(), lr=params["lr_G_AB"], betas=(0.9, 0.999))
    optimizer_G_BA = Adam(G_BA.parameters(), lr=params["lr_G_BA"], betas=(0.9, 0.999))
    D_A_scheduler = MultiStepLR(optimizer_D_A, [50, 80], gamma=0.2)
    G_AB_scheduler = MultiStepLR(optimizer_G_AB, [50, 80], gamma=0.2)
    # BUGFIX: these two were built on optimizer_D_A / optimizer_G_AB, so
    # D_B and G_BA never had their LR decayed while A's was decayed twice.
    D_B_scheduler = MultiStepLR(optimizer_D_B, [50, 80], gamma=0.2)
    G_BA_scheduler = MultiStepLR(optimizer_G_BA, [50, 80], gamma=0.2)

    Loss_fn = Loss()
    print("preparation time is %fs" % (time.time() - start_t))

    step = 0  # global iteration counter (renamed from `iter`, which shadowed the builtin)
    for e in range(params["nepoch"]):
        for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader):
            G_AB.train()
            G_BA.train()
            D_A.train()
            D_B.train()

            optimizer_G_AB.zero_grad()
            optimizer_G_BA.zero_grad()
            optimizer_D_A.zero_grad()
            optimizer_D_B.zero_grad()

            # (B, N, C) -> (B, 3, N); move to the probed device (was a
            # hard-coded .cuda(), which crashed on CPU-only hosts)
            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().to(device)
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().to(device)

            start_t_batch = time.time()
            output_point_cloud_high = G_AB(input_data)  # fake dense (domain B)
            output_point_cloud_low = G_BA(gt_data)      # fake sparse (domain A)

            #####################################
            # Reconstruction / geometric losses #
            #####################################
            # geometric losses are evaluated in (B, N, 3) layout
            repulsion_loss_AB = Loss_fn.get_repulsion_loss(
                output_point_cloud_high.permute(0, 2, 1))
            uniform_loss_AB = Loss_fn.get_uniform_loss(
                output_point_cloud_high.permute(0, 2, 1))
            repulsion_loss_BA = Loss_fn.get_repulsion_loss(
                output_point_cloud_low.permute(0, 2, 1))
            uniform_loss_BA = Loss_fn.get_uniform_loss(
                output_point_cloud_low.permute(0, 2, 1))
            emd_loss_AB = Loss_fn.get_emd_loss(
                output_point_cloud_high.permute(0, 2, 1), gt_data.permute(0, 2, 1))

            # Cycle losses
            recov_A = G_BA(output_point_cloud_high)  # A -> B -> A
            ABA_repul_loss = Loss_fn.get_repulsion_loss(recov_A.permute(0, 2, 1))
            ABA_uniform_loss = Loss_fn.get_uniform_loss(recov_A.permute(0, 2, 1))
            recov_B = G_AB(output_point_cloud_low)   # B -> A -> B
            BAB_repul_loss = Loss_fn.get_repulsion_loss(recov_B.permute(0, 2, 1))
            BAB_uniform_loss = Loss_fn.get_uniform_loss(recov_B.permute(0, 2, 1))
            BAB_emd_loss = Loss_fn.get_emd_loss(recov_B.permute(0, 2, 1),
                                                gt_data.permute(0, 2, 1))

            ######################
            # Generator updates  #
            ######################
            # BUGFIX: the adversarial predictions were computed on .detach()-ed
            # generator outputs, so the GAN terms carried zero gradient back
            # into G_AB / G_BA. The detach is removed here.
            fake_pred_B = D_A(output_point_cloud_high)
            g_AB_loss = Loss_fn.get_generator_loss(fake_pred_B)
            total_G_AB_loss = g_AB_loss * params['gan_w_AB'] \
                + BAB_repul_loss * params['repulsion_w_AB'] \
                + BAB_uniform_loss * params['uniform_w_AB'] \
                + BAB_emd_loss * params['emd_w_AB'] \
                + uniform_loss_AB * params['uniform_w_AB'] \
                + emd_loss_AB * params['emd_w_AB'] \
                + repulsion_loss_AB * params['repulsion_w_AB']

            fake_pred_A = D_B(output_point_cloud_low)
            g_BA_loss = Loss_fn.get_generator_loss(fake_pred_A)
            total_G_BA_loss = g_BA_loss * params['gan_w_BA'] \
                + ABA_repul_loss * params['repulsion_w_BA'] \
                + repulsion_loss_BA * params['repulsion_w_BA']

            # BUGFIX: the cycle terms of the two losses share subgraphs, so
            # two sequential backward() calls freed shared buffers and the
            # second one traversed G_AB layers already mutated by its
            # optimizer step. Back-propagate the summed loss once, then step
            # both generator optimizers (standard joint CycleGAN update).
            (total_G_AB_loss + total_G_BA_loss).backward()
            optimizer_G_AB.step()
            optimizer_G_BA.step()

            ##########################
            # Discriminator A update #
            ##########################
            # zero again: the generator backward above deposited gradients in D_A
            optimizer_D_A.zero_grad()
            # NOTE(review): the buffer names look swapped (fake_A_buffer holds
            # B-domain/dense clouds) — behavior kept; verify against the
            # buffer definitions.
            fake_B_ = fake_A_buffer.push_and_pop(output_point_cloud_high)
            fake_pred_B = D_A(fake_B_.detach())
            d_A_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_B,
                                                                  label=False)
            real_pred_B = D_A(gt_data.detach())
            d_A_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_B,
                                                                  label=True)
            d_A_loss = d_A_loss_real + d_A_loss_fake
            d_A_loss.backward()
            optimizer_D_A.step()

            ##########################
            # Discriminator B update #
            ##########################
            optimizer_D_B.zero_grad()  # see D_A note above
            fake_A_ = fake_B_buffer.push_and_pop(output_point_cloud_low)
            fake_pred_A = D_B(fake_A_.detach())
            d_B_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_A,
                                                                  label=False)
            real_pred_A = D_B(input_data.detach())
            d_B_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_A,
                                                                  label=True)
            d_B_loss = d_B_loss_real + d_B_loss_fake
            d_B_loss.backward()
            optimizer_D_B.step()

            ###########
            # Logging #
            ###########
            current_lr_D_A = optimizer_D_A.state_dict()['param_groups'][0]['lr']
            current_lr_G_AB = optimizer_G_AB.state_dict()['param_groups'][0]['lr']
            current_lr_D_B = optimizer_D_B.state_dict()['param_groups'][0]['lr']
            current_lr_G_BA = optimizer_G_BA.state_dict()['param_groups'][0]['lr']

            tb_logger.scalar_summary('d_A_loss', d_A_loss.item(), step)
            tb_logger.scalar_summary('g_AB_loss', g_AB_loss.item(), step)
            tb_logger.scalar_summary('Total_G_AB_loss', total_G_AB_loss.item(), step)
            tb_logger.scalar_summary('lr_D_A', current_lr_D_A, step)
            tb_logger.scalar_summary('lr_G_AB', current_lr_G_AB, step)
            tb_logger.scalar_summary('d_B_loss', d_B_loss.item(), step)
            tb_logger.scalar_summary('g_BA_loss', g_BA_loss.item(), step)
            tb_logger.scalar_summary('Total_G_BA_loss', total_G_BA_loss.item(), step)
            tb_logger.scalar_summary('lr_D_B', current_lr_D_B, step)
            tb_logger.scalar_summary('lr_G_BA', current_lr_G_BA, step)

            msg = "{:0>8},{}:{}, [{}/{}], {}: {}, {}: {}, {}:{}, {}: {},{}: {}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                "epoch", e + 1, batch_id + 1, len(train_data_loader),
                "total_G_AB_loss", total_G_AB_loss.item(),
                "total_G_BA_loss", total_G_BA_loss.item(),
                "iter time", (time.time() - start_t_batch),
                "d_A_loss", d_A_loss.item(),
                "d_B_loss", d_B_loss.item())
            print(msg)
            step += 1

        # per-epoch LR decay, after all optimizer steps of the epoch
        D_A_scheduler.step()
        G_AB_scheduler.step()
        D_B_scheduler.step()
        G_BA_scheduler.step()

        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'],
                                          params['exp_name'])
            os.makedirs(model_save_dir, exist_ok=True)
            D_A_model_save_path = os.path.join(model_save_dir,
                                               "D_A_iter_%d.pth" % (e + 1))
            G_AB_model_save_path = os.path.join(model_save_dir,
                                                "G_AB_iter_%d.pth" % (e + 1))
            D_B_model_save_path = os.path.join(model_save_dir,
                                               "D_B_iter_%d.pth" % (e + 1))
            G_BA_model_save_path = os.path.join(model_save_dir,
                                                "G_BA_iter_%d.pth" % (e + 1))
            model_all_path = os.path.join(model_save_dir,
                                          "Cyclegan_iter_%d.pth" % (e + 1))
            # combined checkpoint (models + optimizer state) for resuming
            torch.save(
                {
                    'G_AB_state_dict': G_AB.module.state_dict(),
                    'G_BA_state_dict': G_BA.module.state_dict(),
                    'D_A_state_dict': D_A.module.state_dict(),
                    'D_B_state_dict': D_B.module.state_dict(),
                    'optimizer_G_AB_state_dict': optimizer_G_AB.state_dict(),
                    'optimizer_G_BA_state_dict': optimizer_G_BA.state_dict(),
                    'optimizer_D_A_state_dict': optimizer_D_A.state_dict(),
                    'optimizer_D_B_state_dict': optimizer_D_B.state_dict()
                }, model_all_path)
            # per-model checkpoints; unwrap DataParallel (.module) so they
            # load without it
            torch.save(D_A.module.state_dict(), D_A_model_save_path)
            torch.save(G_AB.module.state_dict(), G_AB_model_save_path)
            torch.save(D_B.module.state_dict(), D_B_model_save_path)
            torch.save(G_BA.module.state_dict(), G_BA_model_save_path)