def __init__(self, input_nc=3, output_nc=3, gpu_id=None):
    self.device = torch.device(f"cuda:{gpu_id}" if gpu_id is not None else 'cpu')
    print(f"Using device {self.device}")

    # Hyperparameters
    self.lambda_idt = 0.5
    self.lambda_A = 10.0
    self.lambda_B = 10.0

    # Define generator networks
    self.netG_A = networks.define_netG(input_nc, output_nc, ngf=64, n_blocks=9, device=self.device)
    self.netG_B = networks.define_netG(output_nc, input_nc, ngf=64, n_blocks=9, device=self.device)

    # Define discriminator networks
    self.netD_A = networks.define_netD(output_nc, ndf=64, n_layers=3, device=self.device)
    self.netD_B = networks.define_netD(input_nc, ndf=64, n_layers=3, device=self.device)

    # Define image pools
    self.fake_A_pool = utils.ImagePool(pool_size=50)
    self.fake_B_pool = utils.ImagePool(pool_size=50)

    # Define loss functions
    self.criterionGAN = networks.GANLoss().to(self.device)
    self.criterionCycle = torch.nn.L1Loss()
    self.criterionIdt = torch.nn.L1Loss()

    # Define optimizers
    netG_params = itertools.chain(self.netG_A.parameters(), self.netG_B.parameters())
    netD_params = itertools.chain(self.netD_A.parameters(), self.netD_B.parameters())
    self.optimizer_G = torch.optim.Adam(netG_params, lr=0.0002, betas=(0.5, 0.999))
    self.optimizer_D = torch.optim.Adam(netD_params, lr=0.0002, betas=(0.5, 0.999))

    # Learning rate schedulers
    self.scheduler_G = utils.get_lr_scheduler(self.optimizer_G)
    self.scheduler_D = utils.get_lr_scheduler(self.optimizer_D)
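# utils.get_lr_scheduler above is defined outside this snippet. A minimal
# sketch of the linear-decay schedule that CycleGAN implementations usually
# attach here; the n_epochs/n_epochs_decay defaults are assumptions, not
# values taken from the source:
import torch

def get_lr_scheduler(optimizer, n_epochs=100, n_epochs_decay=100):
    def lambda_rule(epoch):
        # Keep the full lr for n_epochs, then decay linearly to zero
        return 1.0 - max(0, epoch + 1 - n_epochs) / float(n_epochs_decay + 1)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)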
# Test data
# test_data = DatasetFromFolder(data_dir, subfolder='img_test', transform=transform,
#                               resize_scale=params.input_size, crop_size=params.crop_size,
#                               fliplr=params.fliplr, yuv=True)
test_data = DatasetFromFolder2(data_dir, subfolder='lfw_test.txt', transform=transform,
                               resize_scale=params.input_size)
test_data_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=1, shuffle=False)

# Image pools
num_pool = 50
# cover_pool = utils.ImagePool(num_pool)
stego_pool = utils.ImagePool(num_pool)
secret_pool = utils.ImagePool(num_pool)

# Optimizers
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=params.lr,
                                     betas=(params.beta1, params.beta2),
                                     weight_decay=params.weight_decay)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=params.lr,
                                     betas=(params.beta1, params.beta2),
                                     weight_decay=params.weight_decay)
# encoder_optimizer = torch.optim.SGD(encoder.parameters(), lr=params.lr, weight_decay=1e-8)
# decoder_optimizer = torch.optim.SGD(decoder.parameters(), lr=params.lr, weight_decay=1e-8)
discriminator_optimizer = torch.optim.SGD(discriminator.parameters(), lr=params.lr / 3)
def __init__(self, args, dataloaders):
    self.dataloaders = dataloaders

    self.net_D1 = cycnet.define_D(input_nc=6, ndf=64, netD='n_layers', n_layers_D=2).to(device)
    self.net_D2 = cycnet.define_D(input_nc=6, ndf=64, netD='n_layers', n_layers_D=2).to(device)
    self.net_D3 = cycnet.define_D(input_nc=6, ndf=64, netD='n_layers', n_layers_D=3).to(device)
    self.net_G = cycnet.define_G(input_nc=3, output_nc=6, ngf=args.ngf, netG=args.net_G,
                                 use_dropout=False, norm='none').to(device)  # M.Amintoosi norm='instance'
    # self.net_G = cycnet.define_G(
    #     input_nc=3, output_nc=6, ngf=args.ngf, netG=args.net_G, use_dropout=False, norm='instance').to(device)

    # Learning rate for the Adam optimizers (beta1 is fixed at 0.5 below)
    self.lr = args.lr

    # Define optimizers
    self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D1 = optim.Adam(self.net_D1.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D2 = optim.Adam(self.net_D2.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D3 = optim.Adam(self.net_D3.parameters(), lr=self.lr, betas=(0.5, 0.999))

    # Define lr schedulers
    self.exp_lr_scheduler_G = lr_scheduler.StepLR(
        self.optimizer_G, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)
    self.exp_lr_scheduler_D1 = lr_scheduler.StepLR(
        self.optimizer_D1, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)
    self.exp_lr_scheduler_D2 = lr_scheduler.StepLR(
        self.optimizer_D2, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)
    self.exp_lr_scheduler_D3 = lr_scheduler.StepLR(
        self.optimizer_D3, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)

    # Coefficients to balance the loss functions
    self.lambda_L1 = args.lambda_L1
    self.lambda_adv = args.lambda_adv

    # Metric used to decide when to update the "best" checkpoint
    self.metric = args.metric

    # Vars that record the training state
    self.running_acc = []
    self.epoch_acc = 0
    if 'mse' in self.metric:
        self.best_val_acc = 1e9  # for mse/rmse, a lower score is better
    else:
        self.best_val_acc = 0.0  # for others (ssim, psnr), a higher score is better
    self.best_epoch_id = 0
    self.epoch_to_start = 0
    self.max_num_epochs = args.max_num_epochs

    self.G_pred1 = None
    self.G_pred2 = None
    self.batch = None
    self.G_loss = None
    self.D_loss = None
    self.is_training = False
    self.batch_id = 0
    self.epoch_id = 0
    self.checkpoint_dir = args.checkpoint_dir
    self.vis_dir = args.vis_dir

    self.D1_fake_pool = utils.ImagePool(pool_size=50)
    self.D2_fake_pool = utils.ImagePool(pool_size=50)
    self.D3_fake_pool = utils.ImagePool(pool_size=50)

    # Define the loss functions
    if args.pixel_loss == 'minimum_pixel_loss':
        self._pxl_loss = loss.MinimumPixelLoss(opt=1)  # 1 for L1 and 2 for L2
    elif args.pixel_loss == 'pixel_loss':
        self._pxl_loss = loss.PixelLoss(opt=1)  # 1 for L1 and 2 for L2
    else:
        raise NotImplementedError(
            'pixel loss function [%s] is not implemented' % args.pixel_loss)
    self._gan_loss = loss.GANLoss(gan_mode='vanilla').to(device)
    self._exclusion_loss = loss.ExclusionLoss()
    self._kurtosis_loss = loss.KurtosisLoss()

    # Enable/disable individual losses
    self.with_d1d2 = args.enable_d1d2
    self.with_d3 = args.enable_d3
    self.with_exclusion_loss = args.enable_exclusion_loss
    self.with_kurtosis_loss = args.enable_kurtosis_loss

    # m-th epoch at which to activate adversarial training
    self.m_epoch_activate_adv = int(self.max_num_epochs / 20) + 1

    # Output auto-enhancement?
    self.output_auto_enhance = args.output_auto_enhance

    # Use synfake to train D?
    self.synfake = args.enable_synfake

    # Check and create the model dirs
    if os.path.exists(self.checkpoint_dir) is False:
        os.mkdir(self.checkpoint_dir)
    if os.path.exists(self.vis_dir) is False:
        os.mkdir(self.vis_dir)

    # Visualize models
    if args.print_models:
        self._visualize_models()
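# loss.GANLoss above is project code that is not shown. A minimal PyTorch
# sketch of the conventional 'vanilla' mode (BCE-with-logits against constant
# real/fake targets); the project's actual class may differ:
import torch

class GANLoss(torch.nn.Module):
    def __init__(self, gan_mode='vanilla'):
        super().__init__()
        assert gan_mode == 'vanilla', 'only the vanilla mode is sketched here'
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        # Build a target tensor of the same shape as the prediction
        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        return self.loss(prediction, target)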
def train(network_gen: nn.Module,
          network_dis: nn.Module,
          dataloader,
          checkpoint_path,
          weight_gan=0.01,
          weight_l1=1.0,
          weight_fm=10.0,
          weight_vgg=10.0,
          device=torch.device('cuda:0'),
          n_critic=3,
          n_gen=1,
          fake_pool_size=256,
          lr_gen=1e-4,
          lr_dis=5e-4,
          updates_per_epoch=10000,
          record_freq=1000,
          total_updates=1000000,
          gradient_accumulate=4,
          enable_fp16=False,
          resume=False):
    if enable_fp16:
        print(' -- FP16 AMP enabled')
    print(' -- Initializing losses')
    loss_gan = ganloss.GANLossSoftLS(device)
    loss_vgg = vgg_loss.VGG19LossWithStyle().to(device)
    opt_gen = optim.Adam(network_gen.parameters(), lr=lr_gen, betas=(0.5, 0.99), weight_decay=1e-6)
    opt_dis = optim.Adam(network_dis.parameters(), lr=lr_dis, betas=(0.5, 0.99), weight_decay=1e-6)
    sch_gen = optim.lr_scheduler.ReduceLROnPlateau(opt_gen, 'min', factor=0.5, patience=4,
                                                   verbose=True, min_lr=1e-6)
    sch_dis = optim.lr_scheduler.ReduceLROnPlateau(opt_dis, 'min', factor=0.5, patience=4,
                                                   verbose=True, min_lr=1e-6)
    sch_meter = utils.AvgMeter()
    scaler_gen = amp.GradScaler(enabled=enable_fp16)
    scaler_dis = amp.GradScaler(enabled=enable_fp16)
    loss_dis_real_meter = utils.AvgMeter()
    loss_dis_fake_meter = utils.AvgMeter()
    loss_dis_meter = utils.AvgMeter()
    loss_gen_l1_meter = utils.AvgMeter()
    loss_gen_l1_coarse_meter = utils.AvgMeter()
    loss_gen_vgg_meter = utils.AvgMeter()
    loss_gen_vgg_coarse_meter = utils.AvgMeter()
    loss_gen_gan_meter = utils.AvgMeter()
    loss_gen_fm_meter = utils.AvgMeter()
    loss_gen_meter = utils.AvgMeter()
    writer = SummaryWriter(os.path.join(checkpoint_path, 'tb_summary'))
    os.makedirs(os.path.join(checkpoint_path, 'checkpoints'), exist_ok=True)
    fakepool = utils.ImagePool(fake_pool_size, device)
    counter_start = 0
    if resume:
        checkpoints = os.listdir(os.path.join(checkpoint_path, 'checkpoints'))
        last_checkpoint = sorted(
            checkpoints, key=lambda item: (len(item), item)
        )[-1] if 'latest.ckpt' not in checkpoints else 'latest.ckpt'
        print(f' -- Loading checkpoint {last_checkpoint}')
        ckpt = torch.load(os.path.join(checkpoint_path, 'checkpoints', last_checkpoint))
        network_gen.load_state_dict(ckpt['gen'])
        network_dis.load_state_dict(ckpt['dis'])
        opt_gen.load_state_dict(ckpt['gen_opt'])
        opt_dis.load_state_dict(ckpt['dis_opt'])
        counter_start = ckpt['counter'] + 1
        print(f' -- Resume training from update {counter_start}')
    else:
        print(' -- Start training from scratch')
    dataloader = iter(dataloader)
    print(' -- Training start')
    try:
        for counter in tqdm(range(counter_start, total_updates)):
            dataloader_meter = utils.AvgMeter()
            # Train discriminator
            for critic in range(n_critic):
                opt_dis.zero_grad()
                for _ in range(gradient_accumulate):
                    start_time = time.time()
                    real_img, mask = next(dataloader)
                    end_time = time.time()
                    dataloader_meter(end_time - start_time)
                    real_img, mask = real_img.to(device), mask.to(device)
                    real_img_masked = mask_image(real_img, mask)
                    if np.random.randint(0, 2) == 0 or not fakepool.available():
                        with torch.no_grad(), amp.autocast(enabled=enable_fp16):
                            fake_img = network_gen(real_img_masked, mask)
                        fakepool.put(fake_img)
                    else:
                        fake_img = fakepool.sample()
                    with amp.autocast(enabled=enable_fp16):
                        real_logits, _ = network_dis(real_img)
                        fake_logits, _ = network_dis(fake_img)
                        loss_dis_real = loss_gan(real_logits, 'real', None)
                        mask_inv = 1 - F.interpolate(
                            mask,
                            size=(real_logits.shape[2], real_logits.shape[3]),
                            mode='bicubic',
                            align_corners=False)
                        loss_dis_fake = loss_gan(fake_logits, 'fake', mask_inv)
                        loss_dis = 0.5 * (loss_dis_real + loss_dis_fake)
                        if torch.isnan(loss_dis) or torch.isinf(loss_dis):
                            raise Exception('discriminator loss is NaN or Inf')
                    scaler_dis.scale(loss_dis / float(gradient_accumulate)).backward()
                    loss_dis_real_meter(loss_dis_real.item())
                    loss_dis_fake_meter(loss_dis_fake.item())
                    loss_dis_meter(loss_dis.item())
                scaler_dis.unscale_(opt_dis)
                # torch.nn.utils.clip_grad_norm_(network_dis.parameters(), 0.1)
                scaler_dis.step(opt_dis)
                scaler_dis.update()
            # Train generator
            for gen in range(n_gen):
                opt_gen.zero_grad()
                for _ in range(gradient_accumulate):
                    start_time = time.time()
                    real_img, mask = next(dataloader)
                    end_time = time.time()
                    dataloader_meter(end_time - start_time)
                    real_img, mask = real_img.to(device), mask.to(device)
                    real_img_masked = mask_image(real_img, mask)
                    with amp.autocast(enabled=enable_fp16):
                        inpainted_result = network_gen(real_img_masked, mask)
                        # inpainted_result_coarse, inpainted_result = network_gen(real_img_masked, mask)
                        loss_gen_l1 = F.l1_loss(inpainted_result, real_img)
                        loss_vgg_combined = loss_vgg(inpainted_result, real_img)
                        generator_logits, dis_features_fake = network_dis(inpainted_result)
                        with torch.no_grad():
                            _, dis_features_real = network_dis(real_img)
                        loss_fm = 0
                        for (fm_fake, fm_real) in zip(dis_features_fake, dis_features_real):
                            loss_fm += F.mse_loss(fm_fake, fm_real)
                        loss_fm = loss_fm / len(dis_features_real)
                        loss_gen_gan = loss_gan(generator_logits, 'generator', None)
                        loss_gen = weight_l1 * loss_gen_l1 + weight_fm * loss_fm \
                            + weight_vgg * loss_vgg_combined + weight_gan * loss_gen_gan
                        if torch.isnan(loss_gen) or torch.isinf(loss_gen):
                            raise Exception('generator loss is NaN or Inf')
                    scaler_gen.scale(loss_gen / float(gradient_accumulate)).backward()
                    loss_gen_meter(loss_gen.item())
                    loss_gen_l1_meter(loss_gen_l1.item())
                    sch_meter(loss_gen_l1.item())  # use the L1 loss as the lr scheduler metric
                    loss_gen_vgg_meter(loss_vgg_combined.item())
                    loss_gen_gan_meter(loss_gen_gan.item())
                    loss_gen_fm_meter(loss_fm.item())
                scaler_gen.unscale_(opt_gen)
                # torch.nn.utils.clip_grad_norm_(network_gen.parameters(), 0.1)
                scaler_gen.step(opt_gen)
                scaler_gen.update()
            if counter % record_freq == 0:
                tqdm.write(f' -- Record at update {counter}')
                writer.add_scalar('discriminator/all', loss_dis_meter(reset=True), counter)
                writer.add_scalar('discriminator/real', loss_dis_real_meter(reset=True), counter)
                writer.add_scalar('discriminator/fake', loss_dis_fake_meter(reset=True), counter)
                writer.add_scalar('generator/all', loss_gen_meter(reset=True), counter)
                writer.add_scalar('generator/l1', loss_gen_l1_meter(reset=True), counter)
                writer.add_scalar('generator/vgg', loss_gen_vgg_meter(reset=True), counter)
                writer.add_scalar('generator/gan', loss_gen_gan_meter(reset=True), counter)
                writer.add_scalar('generator/fm', loss_gen_fm_meter(reset=True), counter)
                writer.add_image('original/image', img_unscale(real_img), counter, dataformats='NCHW')
                writer.add_image('original/mask', mask, counter, dataformats='NCHW')
                writer.add_image('original/masked', img_unscale(real_img_masked), counter,
                                 dataformats='NCHW')
                writer.add_image('inpainted/refined', img_unscale(inpainted_result), counter,
                                 dataformats='NCHW')
                torch.save(
                    {
                        'dis': network_dis.state_dict(),
                        'gen': network_gen.state_dict(),
                        'dis_opt': opt_dis.state_dict(),
                        'gen_opt': opt_gen.state_dict(),
                        'counter': counter
                    }, os.path.join(checkpoint_path, 'checkpoints', f'update_{counter}.ckpt'))
            if counter > 0 and counter % updates_per_epoch == 0:
                # Epoch finished
                tqdm.write(f' -- Epoch finished at update {counter}')
                loss_epoch = sch_meter(reset=True)
                sch_gen.step(loss_epoch)
                sch_dis.step(loss_epoch)
                # tqdm.write(f'Dataloader overhead avg {int(dataloader_meter(reset=True) * 1000)}ms')
    except KeyboardInterrupt:
        print(' -- Training interrupted, saving latest model ..')
        torch.save(
            {
                'dis': network_dis.state_dict(),
                'gen': network_gen.state_dict(),
                'dis_opt': opt_dis.state_dict(),
                'gen_opt': opt_gen.state_dict(),
                'counter': counter
            }, os.path.join(checkpoint_path, 'checkpoints', 'latest.ckpt'))
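# mask_image above is a helper from the surrounding project that is not
# included in this snippet. A plausible minimal sketch, assuming mask == 1
# marks the region to inpaint and both tensors are NCHW with the mask
# broadcastable over channels:
def mask_image(img, mask):
    # Zero out the masked region so the generator has to reconstruct it
    return img * (1.0 - mask)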
def _build_net(self):
    # tfph: TensorFlow PlaceHolder
    self.x_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='x_test_tfph')
    self.y_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='y_test_tfph')
    self.fake_x_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='fake_x_tfph')
    self.fake_y_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='fake_y_tfph')

    self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           _ops=self._G_gen_train_ops)
    self.Dy_dis = Discriminator(name='Dy', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dy_dis_train_ops, use_sigmoid=self.use_sigmoid)
    self.F_gen = Generator(name='F', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           _ops=self._F_gen_train_ops)
    self.Dx_dis = Discriminator(name='Dx', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dx_dis_train_ops, use_sigmoid=self.use_sigmoid)

    data_reader = Reader(self.data_path, name='data', image_size=self.img_size,
                         batch_size=self.flags.batch_size, is_train=self.flags.is_train)
    # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
    self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = data_reader.feed()

    self.fake_x_pool_obj = utils.ImagePool(pool_size=50)
    self.fake_y_pool_obj = utils.ImagePool(pool_size=50)

    # Cycle consistency loss
    cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

    # X -> Y
    self.fake_y_imgs = self.G_gen(self.x_imgs)
    self.G_gen_loss = self.generator_loss(self.Dy_dis, self.fake_y_imgs, use_lsgan=self.use_lsgan)
    self.G_loss = self.G_gen_loss + cycle_loss
    self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis, self.y_imgs, self.fake_y_tfph,
                                               use_lsgan=self.use_lsgan)

    # Y -> X
    self.fake_x_imgs = self.F_gen(self.y_imgs)
    self.F_gen_loss = self.generator_loss(self.Dx_dis, self.fake_x_imgs, use_lsgan=self.use_lsgan)
    self.F_loss = self.F_gen_loss + cycle_loss
    self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis, self.x_imgs, self.fake_x_tfph,
                                               use_lsgan=self.use_lsgan)

    G_optim = self.optimizer(loss=self.G_loss, variables=self.G_gen.variables, name='Adam_G')
    Dy_optim = self.optimizer(loss=self.Dy_dis_loss, variables=self.Dy_dis.variables, name='Adam_Dy')
    F_optim = self.optimizer(loss=self.F_loss, variables=self.F_gen.variables, name='Adam_F')
    Dx_optim = self.optimizer(loss=self.Dx_dis_loss, variables=self.Dx_dis.variables, name='Adam_Dx')
    self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])
    # with tf.control_dependencies([G_optim, Dy_optim, F_optim, Dx_optim]):
    #     self.optims = tf.no_op(name='optimizers')

    # For the sampling function
    self.fake_y_sample = self.G_gen(self.x_test_tfph)
    self.fake_x_sample = self.F_gen(self.y_test_tfph)
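# cycle_consistency_loss above is defined elsewhere in the class. For
# reference, a minimal sketch of the standard CycleGAN form it presumably
# computes, lambda1 * |F(G(x)) - x|_1 + lambda2 * |G(F(y)) - y|_1, written
# here as a free function; the lambda defaults are assumptions:
import tensorflow as tf

def cycle_consistency_loss(G_gen, F_gen, x_imgs, y_imgs, lambda1=10.0, lambda2=10.0):
    # Forward cycle: x -> G(x) -> F(G(x)) should return to x
    forward_loss = tf.reduce_mean(tf.abs(F_gen(G_gen(x_imgs)) - x_imgs))
    # Backward cycle: y -> F(y) -> G(F(y)) should return to y
    backward_loss = tf.reduce_mean(tf.abs(G_gen(F_gen(y_imgs)) - y_imgs))
    return lambda1 * forward_loss + lambda2 * backward_loss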
D_A_optimizer = torch.optim.Adam(D_A.parameters(), lr=params.lrD,
                                 betas=(params.beta1, params.beta2))
D_B_optimizer = torch.optim.Adam(D_B.parameters(), lr=params.lrD,
                                 betas=(params.beta1, params.beta2))

# Training GAN
D_A_avg_losses = []
D_B_avg_losses = []
G_A_avg_losses = []
G_B_avg_losses = []
cycle_A_avg_losses = []
cycle_B_avg_losses = []

# Generated image pools
num_pool = 50
fake_A_pool = utils.ImagePool(num_pool)
fake_B_pool = utils.ImagePool(num_pool)

step = 0
for epoch in range(params.num_epochs):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    cycle_A_losses = []
    cycle_B_losses = []

    # Learning rate decay
    if (epoch + 1) > params.decay_epoch:
        D_A_optimizer.param_groups[0]['lr'] -= params.lrD / (
            params.num_epochs - params.decay_epoch)
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

source_prediction_max_result = []
target_prediction_max_result = []
best_prec_result = torch.tensor(0, dtype=torch.float32)

args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

fake_A_buffer = utils.ImagePool(max_size=args.max_buffer)
fake_B_buffer = utils.ImagePool(max_size=args.max_buffer)

def main():
    global args, best_prec_result
    utils.default_model_dir = args.dir
    start_time = time.time()

    Source_train_loader, Source_test_loader = dataset_selector(args.sd)
    Target_train_loader, Target_test_loader = dataset_selector(args.td)

    state_info = utils.model_optim_state_info()
    state_info.model_init()
    state_info.model_cuda_init()
def train(self, args):
    self.dataset_A = loadPickleFile("cache_check/coded_sps_A_norm.pickle")
    self.dataset_B = loadPickleFile("cache_check/coded_sps_B_norm.pickle")
    n_samples = len(self.dataset_A)

    dataset = trainingDataset(datasetA=self.dataset_A, datasetB=self.dataset_B, n_frames=128)
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               drop_last=False,
                                               num_workers=4)

    a_fake_sample = utils.ImagePool(50)
    b_fake_sample = utils.ImagePool(50)

    for epoch in range(self.start_epoch, args.epochs):
        lr = self.g_optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

        for i, (a_real, b_real) in enumerate(train_loader):
            a_real = a_real.float()
            b_real = b_real.float()
            step = epoch * len(train_loader) + i + 1
            print(step)

            # Generator computations
            ##################################################
            set_grad([self.Da, self.Db], False)
            self.g_optimizer.zero_grad()

            # Forward pass through generators
            a_fake = self.g_AB(a_real.cuda())
            b_fake = self.g_BA(b_real.cuda())
            a_recon = self.g_AB(b_fake)
            b_recon = self.g_BA(a_fake)
            a_idt = self.g_AB(a_real.cuda())
            b_idt = self.g_BA(b_real.cuda())

            # Identity losses
            a_idt_loss = self.L1(a_idt, a_real.cuda()) * args.lamda * args.idt_coef
            b_idt_loss = self.L1(b_idt, b_real.cuda()) * args.lamda * args.idt_coef

            if self.loss_type == 'lsgan':
                # Adversarial losses
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))
                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)
            elif self.loss_type == 'wgan':
                # Wasserstein-GAN loss
                # G_A(A)
                a_gen_loss = self.criterionGAN(b_fake, generator_loss=True)
                # G_B(B)
                b_gen_loss = self.criterionGAN(a_fake, generator_loss=True)

            # Cycle consistency losses
            a_cycle_loss = self.L1(a_recon, a_real.cuda()) * args.lamda
            b_cycle_loss = self.L1(b_recon, b_real.cuda()) * args.lamda

            # Total generator loss
            gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss \
                + a_idt_loss + b_idt_loss

            # Update generators
            gen_loss.backward(retain_graph=True)
            self.g_optimizer.step()

            # Discriminator computations
            #################################################
            set_grad([self.Da, self.Db], True)
            self.d_optimizer.zero_grad()

            # Sample from the history of generated images
            a_fake = a_fake_sample.query(a_fake)
            b_fake = b_fake_sample.query(b_fake)
            a_fake, b_fake = utils.cuda([a_fake, b_fake])
            print("A_R_Size", a_fake.size())
            print("B_R_Size", b_fake.size())

            if self.loss_type == 'lsgan':
                # Forward pass through discriminators
                a_real_dis = self.Da(a_real.cuda())
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real.cuda())
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
                fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminator losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5
            elif self.loss_type == 'wgan':
                for i_critic in range(self.wgan_n_critic):
                    # Clip the parameters for k-Lipschitz continuity
                    for p in self.Da.parameters():
                        p.data.clamp_(self.wgan_clamp_lower, self.wgan_clamp_upper)
                    for p in self.Db.parameters():
                        p.data.clamp_(self.wgan_clamp_lower, self.wgan_clamp_upper)
                    # D_A
                    a_dis_loss = self.backward_D_wasserstein(self.Da, a_real.cuda(), a_fake)
                    # D_B
                    b_dis_loss = self.backward_D_wasserstein(self.Db, b_real.cuda(), b_fake)

            # Update discriminators
            a_dis_loss.backward(retain_graph=True)
            b_dis_loss.backward(retain_graph=True)
            self.d_optimizer.step()

            writer.add_scalar('DisA loss', a_dis_loss, epoch * len(train_loader) + i)
            writer.add_scalar('DisB loss', b_dis_loss, epoch * len(train_loader) + i)
            writer.add_scalar('Generator loss', gen_loss / 1000, epoch * len(train_loader) + i)

            print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                  % (epoch, i + 1, len(train_loader), gen_loss, a_dis_loss + b_dis_loss))

        # Overwrite the latest checkpoint
        #######################################################
        utils.save_checkpoint({'epoch': epoch + 1,
                               'Da': self.Da.state_dict(),
                               'Db': self.Db.state_dict(),
                               'Gab': self.g_AB.state_dict(),
                               'Gba': self.g_BA.state_dict(),
                               'd_optimizer': self.d_optimizer.state_dict(),
                               'g_optimizer': self.g_optimizer.state_dict()},
                              '%s/w_gan_2.ckpt' % (args.checkpoint_dir))

        # Update learning rates
        ########################
        self.g_lr_scheduler.step()
        self.d_lr_scheduler.step()
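# set_grad above freezes or unfreezes whole sub-networks so the
# discriminators do not accumulate gradients during the generator update.
# A minimal sketch of such a helper (only the call sites appear in the
# snippet, so the body is an assumption):
def set_grad(nets, requires_grad=False):
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad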
def main():
    # Get training options
    opt = get_opt()

    # Define the networks
    # netG_A: used to transfer images from domain A to domain B
    # netG_B: used to transfer images from domain B to domain A
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf, opt.n_res, opt.dropout)
    netG_B = networks.Generator(opt.output_nc, opt.input_nc, opt.ngf, opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
        netG_B = networks.U_net(opt.output_nc, opt.input_nc, opt.ngf)

    # netD_A: used to test whether an image is from domain B
    # netD_B: used to test whether an image is from domain A
    netD_A = networks.Discriminator(opt.input_nc, opt.ndf)
    netD_B = networks.Discriminator(opt.output_nc, opt.ndf)

    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netG_B.cuda()
        netD_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netG_B)
    utils.init_weight(netD_A)
    utils.init_weight(netD_B)

    if opt.pretrained:
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netG_B.load_state_dict(torch.load('pretrained/netG_B.pth'))
        netD_A.load_state_dict(torch.load('pretrained/netD_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))

    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()
    criterion_cycle = torch.nn.L1Loss()
    # Alternatively, can try an MSE cycle consistency loss
    # criterion_cycle = torch.nn.MSELoss()
    criterion_identity = torch.nn.L1Loss()

    # Define the optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A.parameters(), netG_B.parameters()),
                                   lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    # Create learning rate schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs, opt.n_epochs_decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs, opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs, opt.n_epochs_decay).step)

    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.sizeh, opt.sizew)
    input_B = Tensor(opt.batch_size, opt.output_nc, opt.sizeh, opt.sizew)

    # Define two image pools to store generated images
    fake_A_pool = utils.ImagePool()
    fake_B_pool = utils.ImagePool()

    # Define the transform, and load the data
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    dataloader = DataLoader(ImageDataset(opt.rootdir, transform=transform, mode='train'),
                            batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)

    # Numpy arrays to store the per-epoch loss
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_A_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)

    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)

        # Empty lists to store the loss of each mini-batch
        loss_G_list = []
        loss_D_A_list = []
        loss_D_B_list = []

        for i, batch in enumerate(dataloader):
            if i % 50 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G:", loss_G_list[-1],
                      "last loss D_A", loss_D_A_list[-1],
                      "last loss D_B", loss_D_B_list[-1])
            real_A = input_A.copy_(batch['A'])
            real_B = input_B.copy_(batch['B'])

            # Train the generators
            optimizer_G.zero_grad()

            # Compute fake images and reconstructed images
            fake_B = netG_A(real_A)
            fake_A = netG_B(real_B)
            if opt.identity_loss != 0:
                same_B = netG_A(real_B)
                same_A = netG_B(real_A)

            # Discriminators require no gradients when optimizing the generators
            utils.set_requires_grad([netD_A, netD_B], False)

            # Identity loss
            if opt.identity_loss != 0:
                loss_identity_A = criterion_identity(same_A, real_A) * opt.identity_loss
                loss_identity_B = criterion_identity(same_B, real_B) * opt.identity_loss

            # GAN loss
            prediction_fake_B = netD_B(fake_B)
            loss_gan_B = criterion_GAN(prediction_fake_B, True)
            prediction_fake_A = netD_A(fake_A)
            loss_gan_A = criterion_GAN(prediction_fake_A, True)

            # Cycle consistency loss
            recA = netG_B(fake_B)
            recB = netG_A(fake_A)
            loss_cycle_A = criterion_cycle(recA, real_A) * opt.cycle_loss
            loss_cycle_B = criterion_cycle(recB, real_B) * opt.cycle_loss

            # Total loss without the identity loss
            loss_G = loss_gan_B + loss_gan_A + loss_cycle_A + loss_cycle_B
            if opt.identity_loss != 0:
                loss_G += loss_identity_A + loss_identity_B
            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()

            # Train the discriminators
            utils.set_requires_grad([netD_A, netD_B], True)

            # Train the discriminator D_A
            optimizer_D_A.zero_grad()
            # Real images
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, True)
            # Fake images
            fake_A = fake_A_pool.query(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)
            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A.item())
            loss_D_A.backward()
            optimizer_D_A.step()

            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # Real images
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, True)
            # Fake images
            fake_B = fake_B_pool.query(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)
            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()

        # Update the learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save model checkpoints
        torch.save(netG_A.state_dict(), 'model/netG_A.pth')
        torch.save(netG_B.state_dict(), 'model/netG_B.pth')
        torch.save(netD_A.state_dict(), 'model/netD_A.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B.pth')

        # Save other checkpoint information
        checkpoint = {
            'epoch': epoch,
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D_A': optimizer_D_A.state_dict(),
            'optimizer_D_B': optimizer_D_B.state_dict(),
            'lr_scheduler_G': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A': lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()
        }
        torch.save(checkpoint, 'model/checkpoint.pth')

        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_A_array[epoch] = sum(loss_D_A_list) / len(loss_D_A_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_A.txt', loss_D_A_array)
        np.savetxt('model/loss_D_B.txt', loss_D_B_array)

        if epoch % 10 == 9:
            torch.save(netG_A.state_dict(), 'model/netG_A' + str(epoch) + '.pth')
            torch.save(netG_B.state_dict(), 'model/netG_B' + str(epoch) + '.pth')
            torch.save(netD_A.state_dict(), 'model/netD_A' + str(epoch) + '.pth')
            torch.save(netD_B.state_dict(), 'model/netD_B' + str(epoch) + '.pth')

        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch],
              "D_A loss :", loss_D_A_array[epoch],
              "D_B loss :", loss_D_B_array[epoch])
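# utils.Lambda_rule above is not shown. A minimal sketch of the linear-decay
# rule such LambdaLR factories typically implement (constant lr for n_epochs,
# then linear decay to zero over n_epochs_decay); the exact arithmetic in the
# project may differ:
class Lambda_rule:
    def __init__(self, start_epoch, n_epochs, n_epochs_decay):
        self.start_epoch = start_epoch
        self.n_epochs = n_epochs
        self.n_epochs_decay = n_epochs_decay

    def step(self, epoch):
        # Multiplicative factor applied to the base lr at each epoch
        return 1.0 - max(0, epoch + self.start_epoch - self.n_epochs) / float(self.n_epochs_decay + 1)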
MSE_loss = nn.MSELoss().cuda()
L1_loss = nn.L1Loss().cuda()

# Adam optimizers
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()),
                         lr=opt.lrG, betas=(opt.beta1, opt.beta2))
D_A_optimizer = optim.Adam(D_A.parameters(), lr=opt.lrD, betas=(opt.beta1, opt.beta2))
D_B_optimizer = optim.Adam(D_B.parameters(), lr=opt.lrD, betas=(opt.beta1, opt.beta2))

# Image stores
fakeA_store = utils.ImagePool(50)
fakeB_store = utils.ImagePool(50)

train_hist = utils.train_histogram_initialize()

print('**************************start training!**************************')
start_time = time.time()
for epoch in range(opt.train_epoch):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    A_cycle_losses = []
    B_cycle_losses = []
    epoch_start_time = time.time()
    num_iter = 0
def _build_net(self):
    # tfph: TensorFlow PlaceHolder
    self.x_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='x_test_tfph')
    self.y_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='y_test_tfph')
    self.xy_fake_pairs_tfph = tf.placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2], name='xy_fake_pairs_tfph')
    self.yx_fake_pairs_tfph = tf.placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2], name='yx_fake_pairs_tfph')

    self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           _ops=self._G_gen_train_ops)
    self.Dy_dis = Discriminator(name='Dy', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dy_dis_train_ops, is_lsgan=self.is_lsgan)
    self.F_gen = Generator(name='F', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           _ops=self._F_gen_train_ops)
    self.Dx_dis = Discriminator(name='Dx', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dx_dis_train_ops, is_lsgan=self.is_lsgan)
    self.vggModel = VGG16(name='VGG16_Pretrained')

    data_reader = Reader(self.data_path, name='data', image_size=self.img_size,
                         batch_size=self.flags.batch_size, is_train=self.flags.is_train)
    # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
    self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = data_reader.feed()

    self.fake_xy_pool_obj = utils.ImagePool(pool_size=50)
    self.fake_yx_pool_obj = utils.ImagePool(pool_size=50)

    # Cycle consistency loss
    self.cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

    # Concatenation
    self.fake_y_imgs = self.G_gen(self.x_imgs)
    self.xy_real_pairs = tf.concat([self.x_imgs, self.y_imgs], axis=3)
    self.xy_fake_pairs = tf.concat([self.x_imgs, self.fake_y_imgs], axis=3)

    self.fake_x_imgs = self.F_gen(self.y_imgs)
    self.yx_real_pairs = tf.concat([self.y_imgs, self.x_imgs], axis=3)
    self.yx_fake_pairs = tf.concat([self.y_imgs, self.fake_x_imgs], axis=3)

    # X -> Y
    self.G_gen_loss = self.generator_loss(self.Dy_dis, self.xy_fake_pairs, is_lsgan=self.is_lsgan)
    self.G_cond_loss = self.voxel_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
    self.G_gdl_loss = self.gradient_difference_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
    self.G_perceptual_loss = self.perceptual_loss_fn(preds=self.fake_y_imgs, gts=self.y_imgs)
    self.G_loss = self.G_gen_loss + self.G_cond_loss + self.cycle_loss + self.G_gdl_loss \
        + self.G_perceptual_loss
    self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis, self.xy_real_pairs,
                                               self.xy_fake_pairs_tfph, is_lsgan=self.is_lsgan)

    # Y -> X
    self.F_gen_loss = self.generator_loss(self.Dx_dis, self.yx_fake_pairs, is_lsgan=self.is_lsgan)
    self.F_cond_loss = self.voxel_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
    self.F_gdl_loss = self.gradient_difference_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
    self.F_perceputal_loss = self.perceptual_loss_fn(preds=self.fake_x_imgs, gts=self.x_imgs)
    self.F_loss = self.F_gen_loss + self.F_cond_loss + self.cycle_loss + self.F_gdl_loss \
        + self.F_perceputal_loss
    self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis, self.yx_real_pairs,
                                               self.yx_fake_pairs_tfph, is_lsgan=self.is_lsgan)

    G_optim = self.optimizer(loss=self.G_loss, variables=self.G_gen.variables, name='Adam_G')
    Dy_optim = self.optimizer(loss=self.Dy_dis_loss, variables=self.Dy_dis.variables, name='Adam_Dy')
    F_optim = self.optimizer(loss=self.F_loss, variables=self.F_gen.variables, name='Adam_F')
    Dx_optim = self.optimizer(loss=self.Dx_dis_loss, variables=self.Dx_dis.variables, name='Adam_Dx')
    self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])

    # For the sampling function
    self.fake_y_sample = self.G_gen(self.x_test_tfph)
    self.fake_x_sample = self.F_gen(self.y_test_tfph)
def _build_net(self):
    # tfph: TensorFlow PlaceHolder
    self.x_test_tfph = placeholder(tf.float32, shape=[None, *self.img_size], name='x_test_tfph')
    self.y_test_tfph = placeholder(tf.float32, shape=[None, *self.img_size], name='y_test_tfph')
    # Supervised-learning placeholders for the image pool technique
    self.xy_fake_pairs_tfph = placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2], name='xy_fake_pairs_tfph')
    self.yx_fake_pairs_tfph = placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2], name='yx_fake_pairs_tfph')
    # Unsupervised-learning placeholders for the image pool technique
    self.xy_fake_unpairs_tfph = placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], 1], name='xy_fake_unpairs_tfph')
    self.yx_fake_unpairs_tfph = placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], 1], name='yx_fake_unpairs_tfph')

    self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           _ops=self._G_gen_train_ops)
    self.Dy_dis_sup = Discriminator(name='Dy_sup', ndf=self.ndf, norm=self.norm,
                                    model=self.flags.dis_model, shared_reuse=False,
                                    _ops=self._Dy_dis_train_ops)
    self.Dy_dis_unsup = Discriminator(name='Dy_unsup', ndf=self.ndf, norm=self.norm,
                                      model=self.flags.dis_model, shared_reuse=True,
                                      _ops=self._Dy_dis_train_ops)
    self.F_gen = Generator(name='F', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           _ops=self._F_gen_train_ops)
    self.Dx_dis_sup = Discriminator(name='Dx_sup', ndf=self.ndf, norm=self.norm,
                                    model=self.flags.dis_model, shared_reuse=False,
                                    _ops=self._Dx_dis_train_ops)
    self.Dx_dis_unsup = Discriminator(name='Dx_unsup', ndf=self.ndf, norm=self.norm,
                                      model=self.flags.dis_model, shared_reuse=True,
                                      _ops=self._Dx_dis_train_ops)
    self.vggModel = VGG16(name='VGG16_Pretrained')

    data_reader = Reader(self.data_path, name='data', image_size=self.img_size,
                         batch_size=self.flags.batch_size, is_train=self.flags.is_train)
    # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
    self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = data_reader.feed()

    self.fake_xy_pool_obj_sup = utils.ImagePool(pool_size=50)
    self.fake_yx_pool_obj_sup = utils.ImagePool(pool_size=50)
    self.fake_xy_pool_obj_unsup = utils.ImagePool(pool_size=50)
    self.fake_yx_pool_obj_unsup = utils.ImagePool(pool_size=50)

    # Cycle consistency loss
    self.cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

    # Concatenation
    self.fake_y_imgs = self.G_gen(self.x_imgs)
    self.xy_real_pairs = tf.concat([self.x_imgs, self.y_imgs], axis=3)
    self.xy_fake_pairs = tf.concat([self.x_imgs, self.fake_y_imgs], axis=3)

    self.fake_x_imgs = self.F_gen(self.y_imgs)
    self.yx_real_pairs = tf.concat([self.y_imgs, self.x_imgs], axis=3)
    self.yx_fake_pairs = tf.concat([self.y_imgs, self.fake_x_imgs], axis=3)

    # X -> Y
    # Supervised learning
    self.G_gen_loss_sup = self.generator_loss(self.Dy_dis_sup, self.xy_fake_pairs)
    self.G_cond_loss = self.voxel_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
    self.G_gdl_loss = self.gradient_difference_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
    self.G_perceptual_loss = self.perceptual_loss_fn(preds=self.fake_y_imgs, gts=self.y_imgs)
    self.G_ssim_loss = self.ssim_loss_fn(preds=self.fake_y_imgs, gts=self.y_imgs)
    self.G_loss_sup = self.G_gen_loss_sup + self.G_cond_loss + self.cycle_loss + self.G_gdl_loss + \
        self.G_perceptual_loss + self.G_ssim_loss
    self.Dy_dis_loss_sup = self.discriminator_loss(
        self.Dy_dis_sup, self.xy_real_pairs, self.xy_fake_pairs_tfph, is_lsgan=self.is_lsgan)

    # Unsupervised learning
    self.G_gen_loss_unsup = self.generator_loss(self.Dy_dis_unsup, self.fake_y_imgs)
    self.G_loss_unsup = self.G_gen_loss_unsup + self.cycle_loss
    self.Dy_dis_loss_unsup = self.discriminator_loss(
        self.Dy_dis_unsup, self.y_imgs, self.xy_fake_unpairs_tfph, is_lsgan=False)

    # Integrated optimization
    self.G_gen_loss_integrated = self.G_loss_sup + self.G_loss_unsup
    self.Dy_dis_loss_integrated = self.Dy_dis_loss_sup + self.Dy_dis_loss_unsup

    # Y -> X
    # Supervised learning
    self.F_gen_loss_sup = self.generator_loss(self.Dx_dis_sup, self.yx_fake_pairs)
    self.F_cond_loss = self.voxel_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
    self.F_gdl_loss = self.gradient_difference_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
    self.F_perceputal_loss = self.perceptual_loss_fn(preds=self.fake_x_imgs, gts=self.x_imgs)
    self.F_ssim_loss = self.ssim_loss_fn(preds=self.fake_x_imgs, gts=self.x_imgs)
    self.F_loss_sup = self.F_gen_loss_sup + self.F_cond_loss + self.cycle_loss + self.F_gdl_loss + \
        self.F_perceputal_loss + self.F_ssim_loss
    self.Dx_dis_loss_sup = self.discriminator_loss(
        self.Dx_dis_sup, self.yx_real_pairs, self.yx_fake_pairs_tfph, is_lsgan=self.is_lsgan)

    # Unsupervised learning
    self.F_gen_loss_unsup = self.generator_loss(self.Dx_dis_unsup, self.fake_x_imgs)
    self.F_loss_unsup = self.F_gen_loss_unsup + self.cycle_loss
    self.Dx_dis_loss_unsup = self.discriminator_loss(
        self.Dx_dis_unsup, self.x_imgs, self.yx_fake_unpairs_tfph, is_lsgan=False)

    # Integrated optimization
    self.F_gen_loss_integrated = self.F_loss_sup + self.F_loss_unsup
    self.Dx_dis_loss_integrated = self.Dx_dis_loss_sup + self.Dx_dis_loss_unsup

    # Supervised learning
    G_optim_sup = self.optimizer(loss=self.G_loss_sup, variables=self.G_gen.variables,
                                 name='Adam_G_sup')
    Dy_optim_sup = self.optimizer(loss=self.Dy_dis_loss_sup, variables=self.Dy_dis_sup.variables,
                                  name='Adam_Dy_sup')
    F_optim_sup = self.optimizer(loss=self.F_loss_sup, variables=self.F_gen.variables,
                                 name='Adam_F_sup')
    Dx_optim_sup = self.optimizer(loss=self.Dx_dis_loss_sup, variables=self.Dx_dis_sup.variables,
                                  name='Adam_Dx_sup')
    self.optims_sup = tf.group([G_optim_sup, Dy_optim_sup, F_optim_sup, Dx_optim_sup])

    # Unsupervised learning
    G_optim_unsup = self.optimizer(loss=self.G_loss_unsup, variables=self.G_gen.variables,
                                   name='Adam_G_unsup')
    Dy_optim_unsup = self.optimizer(loss=self.Dy_dis_loss_unsup,
                                    variables=self.Dy_dis_unsup.variables, name='Adam_Dy_unsup')
    F_optim_unsup = self.optimizer(loss=self.F_loss_unsup, variables=self.F_gen.variables,
                                   name='Adam_F_unsup')
    Dx_optim_unsup = self.optimizer(loss=self.Dx_dis_loss_unsup,
                                    variables=self.Dx_dis_unsup.variables, name='Adam_Dx_unsup')
    self.optims_unsup = tf.group([G_optim_unsup, Dy_optim_unsup, F_optim_unsup, Dx_optim_unsup])

    # Integrated optimization
    G_optim_integrated = self.optimizer(loss=self.G_gen_loss_integrated,
                                        variables=self.G_gen.variables, name='Adam_G_integrated')
    Dy_optim_integrated = self.optimizer(
        loss=self.Dy_dis_loss_integrated,
        variables=[self.Dy_dis_sup.variables, self.Dy_dis_unsup.variables],
        name='Adam_Dy_integrated')
    F_optim_integrated = self.optimizer(loss=self.F_gen_loss_integrated,
                                        variables=self.F_gen.variables, name='Adam_F_integrated')
    Dx_optim_integrated = self.optimizer(
        loss=self.Dx_dis_loss_integrated,
        variables=[self.Dx_dis_sup.variables, self.Dx_dis_unsup.variables],
        name='Adam_Dx_integrated')
    self.optims_integrated = tf.group(
        [G_optim_integrated, Dy_optim_integrated, F_optim_integrated, Dx_optim_integrated])

    # For the sampling function
    self.fake_y_sample = self.G_gen(self.x_test_tfph)
    self.fake_x_sample = self.F_gen(self.y_test_tfph)
    self.print_network_vars(is_print=True)
def train(epochs):
    gan_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    l1_loss = gluon.loss.L1Loss()
    trainer_G = gluon.Trainer(netG.collect_params(), 'adam',
                              optimizer_params={
                                  'learning_rate': 0.0002,
                                  'beta1': 0.5,
                                  'beta2': 0.999
                              })
    trainer_D = gluon.Trainer(netD.collect_params(), 'adam',
                              optimizer_params={
                                  'learning_rate': 0.0002,
                                  'beta1': 0.5,
                                  'beta2': 0.999
                              })

    ## Configure the log file
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(os.path.join(log_dir, 'train.log'))
    logger.addHandler(fh)
    sw = SummaryWriter(logdir=os.path.join(log_dir, 'train_sw'))

    batch_len = train_iter.num_data // train_iter.batch_size
    image_pool = utils.ImagePool(50)
    global_step = 0
    for epch in range(epochs):
        train_iter.reset()
        epch_time = time.time()
        batch_time = time.time()
        for iter_step, databatch in enumerate(train_iter):
            data = databatch.data[0].as_in_context(ctx)
            label = databatch.label[0].as_in_context(ctx)

            ## Train netD
            pred = netG(data)
            # fake_data = nd.concat(data, pred, dim=1)
            # fake_data = image_pool.fetch_img(fake_data)
            fake_data = image_pool.fetch_img(nd.concat(data, pred, dim=1))
            with autograd.record():
                # Fake
                pred_fake = netD(fake_data)
                fake_label = nd.zeros_like(pred_fake)
                loss_fake = gan_loss(pred_fake, fake_label).sum()
                # Real
                real_data = nd.concat(data, label, dim=1)
                pred_real = netD(real_data)
                real_label = nd.ones_like(pred_real)
                loss_real = gan_loss(pred_real, real_label).sum()
                loss_D = (loss_real + loss_fake) * 0.5
            loss_D.backward()
            trainer_D.step(data.shape[0])
            sw.add_scalar('lossD', loss_D.asscalar(), global_step)

            ## Train netG
            with autograd.record():
                pred = netG(data)
                in_data = nd.concat(data, pred, dim=1)
                pred_real = netD(in_data)
                pred_label = nd.ones_like(pred_real)
                ganloss_g = gan_loss(pred_real, pred_label)
                l1loss_g = l1_loss(pred, label)
                loss_G = ganloss_g + l1loss_g * l1_lambda
                loss_G = loss_G.sum()
            loss_G.backward()
            trainer_G.step(data.shape[0])
            sw.add_scalar('lossG', loss_G.asscalar(), global_step)

            ## Log progress within the epoch
            if (iter_step + 1) % log_iter_intervals == 0:
                logger.info('[Epoch {}][Iter {}] Done. Speed: {:.4f} sample / s'.format(
                    str(epch), str(iter_step), data.shape[0] / (time.time() - batch_time)))
                batch_time = time.time()
            global_step += 1

        ## Do the evaluation after every epoch
        fake_img = pred[0]
        img_arr = (fake_img - mx.nd.min(fake_img)) / (mx.nd.max(fake_img) - mx.nd.min(fake_img))
        # img_arr = img_arr[::-1, :, :]
        sw.add_image('generated image', img_arr)
        eval(epch)

        ## Save checkpoints between epochs
        netG.save_parameters(ckpt_fmt.format('netG', str(epch)))
        netD.save_parameters(ckpt_fmt.format('netD', str(epch)))
        logger.info('[Epoch {}] Done. Cost: {:.4f} s'.format(str(epch), time.time() - epch_time))
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

source_prediction_max_result = []
target_prediction_max_result = []
best_prec_result = torch.tensor(0, dtype=torch.float32)

args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

fake_S_buffer = utils.ImagePool(max_size=args.max_buffer)
fake_T_buffer = utils.ImagePool(max_size=args.max_buffer)

# adversarial_loss = torch.nn.BCELoss()
criterion_GAN = torch.nn.MSELoss()
criterion_Cycle = torch.nn.L1Loss()
criterion_Recov = torch.nn.MSELoss()
criterion = nn.CrossEntropyLoss().cuda()

def main():
    global args, best_prec_result
    utils.default_model_dir = args.dir
    start_time = time.time()

    Source_train_loader, Source_test_loader = dataset_selector(args.sd)
def main(unused_argv):
    total_step = 0
    checkpoints_dir = './models/real2cartoon'
    summary_dir = './summary'

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(batch_size=FLAGS.batch_size,
                             image_size=256,
                             use_mse=FLAGS.use_mse,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             filters=FLAGS.filters,
                             beta1=FLAGS.beta1,
                             mse_label=FLAGS.mse_label,
                             file_x=FLAGS.file_x,
                             file_y=FLAGS.file_y)
        G_loss, F_loss, D_X_loss, D_Y_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, F_loss, D_X_loss, D_Y_loss)
        summarys = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(summary_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        ckpt = tf.train.get_checkpoint_state(checkpoints_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            total_step = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            logger.info('load model success ' + ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())
            logger.info('start new model')

        # img_x = utils.get_img(FLAGS.file_x, FLAGS.output_height, FLAGS.output_width, FLAGS.batch_size)
        # img_y = utils.get_img(FLAGS.file_y, FLAGS.output_height, FLAGS.output_width, FLAGS.batch_size)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            fake_X_pool = utils.ImagePool(FLAGS.pool_size)
            fake_Y_pool = utils.ImagePool(FLAGS.pool_size)
            while not coord.should_stop():
                # img_x, img_y = read_file()
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = sess.run(
                    [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summarys],
                    feed_dict={
                        cycle_gan.x: fake_X_pool.query(fake_x_val),
                        cycle_gan.y: fake_Y_pool.query(fake_y_val)
                    })
                train_writer.add_summary(summary, total_step)
                train_writer.flush()
                logger.info('step: {}'.format(total_step))
                if total_step > 1e5:
                    sess.run(cycle_gan.learning_rate_decay_op())
                if total_step % 100 == 0:
                    logger.info('-----------Step %d:-------------' % total_step)
                    logger.info(' G_loss : {}'.format(G_loss_val))
                    logger.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logger.info(' F_loss : {}'.format(F_loss_val))
                    logger.info(' D_X_loss : {}'.format(D_X_loss_val))
                    logger.info(' learning_rate : {}'.format(cycle_gan.learning_rate))
                if total_step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=total_step)
                    logger.info("Model saved in file: %s" % save_path)
                total_step += 1
        except KeyboardInterrupt:
            logger.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=total_step)
            logger.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def _build_graph(self):
    self.x_test_tfph = tf.compat.v1.placeholder(
        tf.float32, shape=[None, *self.input_shape], name='x_test_tfph')
    self.fake_pair_tfph = tf.compat.v1.placeholder(
        tf.float32, shape=[None, self.input_shape[0], self.input_shape[1], 2],
        name='fake_pairs_tfph')
    self.rate_tfph = tf.compat.v1.placeholder(tf.float32, name='keep_prob_ph')

    # Initialize TFRecord reader
    train_reader = Reader(tfrecordsFile=self.data_path[0],
                          decodeImgShape=self.decode_img_shape,
                          imgShape=self.input_shape,
                          batchSize=self.batch_size,
                          name='train')

    # Initialize generator & discriminator
    self.gen_obj = Generator(name='G', gen_c=self.gen_c, norm='instance',
                             logger=self.logger, _ops=None)
    self.dis_obj = Discriminator(name='D', dis_c=self.dis_c, norm='instance',
                                 logger=self.logger, _ops=None)

    # Random batch for training
    self.img_train, self.seg_img_train = train_reader.shuffle_batch()
    self.img_pool_obj = utils.ImagePool(pool_size=150)

    # Transform img_train and seg_img_train
    trans_seg_img_train = self.transform_seg(self.seg_img_train)
    trans_img_train = self.transform_img(self.img_train)

    # Concatenation
    self.g_sample = self.gen_obj(trans_seg_img_train, self.rate_tfph)
    self.real_pair = tf.concat([trans_seg_img_train, trans_img_train], axis=3)
    self.fake_pair = tf.concat([trans_seg_img_train, self.g_sample], axis=3)

    # Define generator loss
    self.gen_adv_loss = self.generator_loss(self.dis_obj, self.fake_pair)
    self.cond_loss = self.conditional_loss(pred=self.g_sample, gt=trans_img_train)
    self.gen_loss = self.gen_adv_loss + self.cond_loss

    # Define discriminator loss
    self.dis_loss = self.discriminator_loss(self.dis_obj, self.real_pair, self.fake_pair_tfph)

    # Optimizers
    self.gen_optim = self.init_optimizer(loss=self.gen_loss, variables=self.gen_obj.variables,
                                         name='Adam_gen')
    self.dis_optim = self.init_optimizer(loss=self.dis_loss, variables=self.dis_obj.variables,
                                         name='Adam_dis')
def _build_net(self):
    self.mae_record_placeholder = tf.placeholder(tf.float32, name='mae_record_placeholder')
    self.mae_record = tf.Variable(256., trainable=False, dtype=tf.float32, name='mae_record')
    self.mae_record_assign_op = self.mae_record.assign(self.mae_record_placeholder)

    # tfph: TensorFlow PlaceHolder
    self.x_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='x_test_tfph')
    self.y_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='y_test_tfph')
    self.fake_x_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='fake_x_tfph')
    self.fake_y_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='fake_y_tfph')

    self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.image_size,
                           _ops=self._G_gen_train_ops)
    self.Dy_dis = Discriminator(name='Dy', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dy_dis_train_ops, use_sigmoid=self.use_sigmoid)
    self.F_gen = Generator(name='F', ngf=self.ngf, norm=self.norm, image_size=self.image_size,
                           _ops=self._F_gen_train_ops)
    self.Dx_dis = Discriminator(name='Dx', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dx_dis_train_ops, use_sigmoid=self.use_sigmoid)

    x_reader = Reader(self.x_path, name='X', image_size=self.image_size,
                      batch_size=self.flags.batch_size)
    y_reader = Reader(self.y_path, name='Y', image_size=self.image_size,
                      batch_size=self.flags.batch_size)
    self.x_imgs = x_reader.feed()
    self.y_imgs = y_reader.feed()

    self.fake_x_pool_obj = utils.ImagePool(pool_size=50)
    self.fake_y_pool_obj = utils.ImagePool(pool_size=50)

    # Cycle consistency loss
    cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

    # X -> Y
    self.fake_y_imgs = self.G_gen(self.x_imgs)
    self.G_gen_loss = self.generator_loss(self.Dy_dis, self.fake_y_imgs, use_lsgan=self.use_lsgan)
    self.G_loss = self.G_gen_loss + cycle_loss
    self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis, self.y_imgs, self.fake_y_tfph,
                                               use_lsgan=self.use_lsgan)

    # Y -> X
    self.fake_x_imgs = self.F_gen(self.y_imgs)
    self.F_gen_loss = self.generator_loss(self.Dx_dis, self.fake_x_imgs, use_lsgan=self.use_lsgan)
    self.F_loss = self.F_gen_loss + cycle_loss
    self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis, self.x_imgs, self.fake_x_tfph,
                                               use_lsgan=self.use_lsgan)

    G_op = self.optimizer(loss=self.G_loss, variables=self.G_gen.variables, name='Adam_G')
    G_ops = [G_op] + self._G_gen_train_ops
    G_optim = tf.group(*G_ops)

    Dy_op = self.optimizer(loss=self.Dy_dis_loss, variables=self.Dy_dis.variables, name='Adam_Dy')
    Dy_ops = [Dy_op] + self._Dy_dis_train_ops
    Dy_optim = tf.group(*Dy_ops)

    F_op = self.optimizer(loss=self.F_loss, variables=self.F_gen.variables, name='Adam_F')
    F_ops = [F_op] + self._F_gen_train_ops
    F_optim = tf.group(*F_ops)

    Dx_op = self.optimizer(loss=self.Dx_dis_loss, variables=self.Dx_dis.variables, name='Adam_Dx')
    Dx_ops = [Dx_op] + self._Dx_dis_train_ops
    Dx_optim = tf.group(*Dx_ops)

    self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])
    # with tf.control_dependencies([G_optim, Dy_optim, F_optim, Dx_optim]):
    #     self.optims = tf.no_op(name='optimizers')

    # For the sampling function
    self.fake_y_sample = self.G_gen(self.x_test_tfph)
    self.fake_x_sample = self.F_gen(self.y_test_tfph)
    self.recon_x_sample = self.F_gen(self.G_gen(self.x_test_tfph))
def _build_net(self):
    # tfph: TensorFlow PlaceHolder
    self.x_test_tfph = tf.placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], self.mm_dim],
        name='x_test_tfph')
    self.y_test_tfph = tf.placeholder(
        tf.float32, shape=[None, self.img_size[0], self.img_size[1], self.fp_dim],
        name='y_test_tfph')
    self.xy_fake_pairs_tfph = tf.placeholder(
        tf.float32,
        shape=[None, self.img_size[0], self.img_size[1], self.mm_dim + self.fp_dim],
        name='xy_fake_pairs_tfph')
    self.yx_fake_pairs_tfph = tf.placeholder(
        tf.float32,
        shape=[None, self.img_size[0], self.img_size[1], self.mm_dim + self.fp_dim],
        name='yx_fake_pairs_tfph')

    self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           output_dim=self.fp_dim, _ops=self._G_gen_train_ops)
    self.Dy_dis = Discriminator(name='Dy', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dy_dis_train_ops)
    self.F_gen = Generator(name='F', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                           output_dim=self.mm_dim, _ops=self._F_gen_train_ops)
    self.Dx_dis = Discriminator(name='Dx', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dx_dis_train_ops)

    data_reader = Reader(self.data_path, name='data', image_size=self.img_size,
                         batch_size=self.flags.batch_size, is_train=self.flags.is_train)
    # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
    self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = data_reader.feed()

    # Slice minutiae and fingerprint images to 2D and 1D
    self.x_imgs_2d = tf.slice(self.x_imgs, begin=[0, 0, 0, 0], size=[-1, -1, -1, 2],
                              name='x_slice')
    self.y_imgs_1d = tf.slice(self.y_imgs, begin=[0, 0, 0, 0], size=[-1, -1, -1, 1],
                              name='y_slice')

    self.fake_xy_pool_obj = utils.ImagePool(pool_size=50)
    self.fake_yx_pool_obj = utils.ImagePool(pool_size=50)

    # Cycle consistency loss
    self.cycle_loss = self.cycle_consistency_loss(self.x_imgs_2d, self.y_imgs_1d)

    # Concatenation
    self.fake_y_imgs = self.G_gen(self.x_imgs_2d)
    self.xy_real_pairs = tf.concat([self.x_imgs_2d, self.y_imgs_1d], axis=3)
    self.xy_fake_pairs = tf.concat([self.x_imgs_2d, self.fake_y_imgs], axis=3)

    self.fake_x_imgs = self.F_gen(self.y_imgs_1d)
    self.yx_real_pairs = tf.concat([self.y_imgs_1d, self.x_imgs_2d], axis=3)
    self.yx_fake_pairs = tf.concat([self.y_imgs_1d, self.fake_x_imgs], axis=3)

    # X -> Y
    self.G_gen_loss = self.generator_loss(self.Dy_dis, self.xy_fake_pairs)
    self.G_cond_loss = self.voxel_loss(preds=self.fake_y_imgs, gts=self.y_imgs_1d,
                                       weight=self.L1_lambda)
    self.G_loss = self.G_gen_loss + self.G_cond_loss + self.cycle_loss
    self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis, self.xy_real_pairs,
                                               self.xy_fake_pairs_tfph)

    # Y -> X
    self.F_gen_loss = self.generator_loss(self.Dx_dis, self.yx_fake_pairs)
    self.F_cond_loss = self.voxel_loss(preds=self.fake_x_imgs, gts=self.x_imgs_2d, weight=0.)
    self.F_loss = self.F_gen_loss + self.F_cond_loss + self.cycle_loss
    self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis, self.yx_real_pairs,
                                               self.yx_fake_pairs_tfph)

    G_optim = self.optimizer(loss=self.G_loss, variables=self.G_gen.variables, name='Adam_G')
    Dy_optim = self.optimizer(loss=self.Dy_dis_loss, variables=self.Dy_dis.variables,
                              name='Adam_Dy')
    F_optim = self.optimizer(loss=self.F_loss, variables=self.F_gen.variables, name='Adam_F')
    Dx_optim = self.optimizer(loss=self.Dx_dis_loss, variables=self.Dx_dis.variables,
                              name='Adam_Dx')
    self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])

    # For the sampling function
    self.fake_y_sample = self.G_gen(self.x_test_tfph)
    self.fake_x_sample = self.F_gen(self.y_test_tfph)
def _build_net(self):
    self.mae_record_placeholder = tf.placeholder(tf.float32, name='mae_record_placeholder')
    self.mae_record = tf.Variable(256., trainable=False, dtype=tf.float32, name='mae_record')
    self.mae_record_assign_op = self.mae_record.assign(self.mae_record_placeholder)

    # tfph: TensorFlow PlaceHolder
    self.x_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='x_test_tfph')
    self.y_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='y_test_tfph')
    self.fake_x_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='fake_x_tfph')
    self.fake_y_tfph = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='fake_y_tfph')

    self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.image_size,
                           _ops=self._G_gen_train_ops)
    self.Dy_dis = Discriminator(name='Dy', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dy_dis_train_ops, use_sigmoid=self.use_sigmoid)
    self.F_gen = Generator(name='F', ngf=self.ngf, norm=self.norm, image_size=self.image_size,
                           _ops=self._F_gen_train_ops)
    self.Dx_dis = Discriminator(name='Dx', ndf=self.ndf, norm=self.norm,
                                _ops=self._Dx_dis_train_ops, use_sigmoid=self.use_sigmoid)

    x_reader = Reader(self.x_path, name='X', image_size=self.image_size, batch_size=self.flags.batch_size)
    y_reader = Reader(self.y_path, name='Y', image_size=self.image_size, batch_size=self.flags.batch_size)
    self.x_imgs = x_reader.feed()
    self.y_imgs = y_reader.feed()

    self.fake_x_pool_obj = utils.ImagePool(pool_size=50)
    self.fake_y_pool_obj = utils.ImagePool(pool_size=50)

    self._unpair_net()  # idea from CycleGAN
    self._pair_net()    # idea from pix2pix

    # Optimizers: each network gets separate train ops for the paired and unpaired
    # objectives; the discriminators additionally split their variables into
    # shared, pair-only, and unpair-only groups.
    # G generator for unpaired data
    G_op_unpair = self.optimizer(loss=self.G_loss_unpair, variables=self.G_gen.variables,
                                 name='Adam_G_unpair')
    G_optim_unpair = tf.group(*([G_op_unpair] + self._G_gen_train_ops))
    # G generator for paired data
    G_op_pair = self.optimizer(loss=self.G_loss_pair, variables=self.G_gen.variables,
                               name='Adam_G_pair')
    self.G_optim_pair = tf.group(*([G_op_pair] + self._G_gen_train_ops))
    # Dy discriminator for unpaired data
    Dy_op_unpair = self.optimizer(loss=self.Dy_dis_loss_unpair,
                                  variables=[self.Dy_dis.share_variables, self.Dy_dis.unpair_variables],
                                  name='Adam_Dy_unpair')
    Dy_optim_unpair = tf.group(*([Dy_op_unpair] + self._Dy_dis_train_ops))
    # Dy discriminator for paired data
    Dy_op_pair = self.optimizer(loss=self.Dy_dis_loss_pair,
                                variables=[self.Dy_dis.share_variables, self.Dy_dis.pair_variables],
                                name='Adam_Dy_pair')
    self.Dy_optim_pair = tf.group(*([Dy_op_pair] + self._Dy_dis_train_ops))
    # F generator for unpaired data
    F_op_unpair = self.optimizer(loss=self.F_loss_unpair, variables=self.F_gen.variables,
                                 name='Adam_F_unpair')
    F_optim_unpair = tf.group(*([F_op_unpair] + self._F_gen_train_ops))
    # F generator for paired data
    F_op_pair = self.optimizer(loss=self.F_loss_pair, variables=self.F_gen.variables,
                               name='Adam_F_pair')
    self.F_optim_pair = tf.group(*([F_op_pair] + self._F_gen_train_ops))
    # Dx discriminator for unpaired data
    Dx_op_unpair = self.optimizer(loss=self.Dx_dis_loss_unpair,
                                  variables=[self.Dx_dis.share_variables, self.Dx_dis.unpair_variables],
                                  name='Adam_Dx_unpair')
    Dx_optim_unpair = tf.group(*([Dx_op_unpair] + self._Dx_dis_train_ops))
    # Dx discriminator for paired data
    Dx_op_pair = self.optimizer(loss=self.Dx_dis_loss_pair,
                                variables=[self.Dx_dis.share_variables, self.Dx_dis.pair_variables],
                                name='Adam_Dx_pair')
    self.Dx_optim_pair = tf.group(*([Dx_op_pair] + self._Dx_dis_train_ops))

    self.optims_unpair = tf.group(G_optim_unpair, Dy_optim_unpair, F_optim_unpair, Dx_optim_unpair)
    self.optims_pair = tf.group(self.G_optim_pair, self.Dy_optim_pair,
                                self.F_optim_pair, self.Dx_optim_pair)

    self.loss_collections = [
        self.G_loss_unpair, self.Dy_dis_loss_unpair, self.F_loss_unpair, self.Dx_dis_loss_unpair,
        self.G_loss_pair, self.Dy_dis_loss_pair, self.F_loss_pair, self.Dx_dis_loss_pair
    ]
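# A sketch of how the grouped train ops above might be driven from a TF1 session
# loop: the fakes are materialised first so they can pass through the image pools
# and be fed back via the fake_*_tfph placeholders. Here `sess`, `model`, and the
# pool's query() method are assumptions for illustration, not part of the
# original code.
fake_x_val, fake_y_val = sess.run([model.fake_x_imgs, model.fake_y_imgs])
_, losses = sess.run(
    [model.optims_unpair, model.loss_collections],
    feed_dict={model.fake_x_tfph: model.fake_x_pool_obj.query(fake_x_val),
               model.fake_y_tfph: model.fake_y_pool_obj.query(fake_y_val)})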
def train(self, args):
    # Obtain dataloader
    loader = self.get_dataloader(args)

    # Pools of previously generated images, used to stabilise the discriminators
    imagepool_a = utils.ImagePool()
    imagepool_b = utils.ImagePool()

    lambda_coef = args.lamda
    lambda_idt = args.idt_coef

    # Initialize weights
    utils.init_weights(self.G_BA)
    utils.init_weights(self.G_AB)
    utils.init_weights(self.D_A)
    utils.init_weights(self.D_B)

    step = 0
    self.load_checkpoint(args)
    # Terrible hack: realign the LR schedulers with the epoch restored from the checkpoint
    self.gen_scheduler.last_epoch = self.curr_epoch - 1
    self.dis_scheduler.last_epoch = self.curr_epoch - 1

    self.G_BA.train()
    self.G_AB.train()
    for epoch in range(self.curr_epoch, args.epochs):
        for a_real, b_real in loader:
            # Send data to the (ideally GPU) device
            a_real = a_real.to(self.device)
            b_real = b_real.to(self.device)

            # Generator forward passes
            a_fake = self.G_BA(b_real)
            b_fake = self.G_AB(a_real)
            a_reconstruct = self.G_BA(b_fake)
            b_reconstruct = self.G_AB(a_fake)
            a_identity = self.G_BA(a_real)
            b_identity = self.G_AB(b_real)

            # Identity loss
            a_idt_loss = self.L1(a_identity, a_real) * lambda_coef * lambda_idt
            b_idt_loss = self.L1(b_identity, b_real) * lambda_coef * lambda_idt

            # GAN loss: the generators try to make the discriminators output "real"
            a_fake_dis = self.D_A(a_fake)
            b_fake_dis = self.D_B(b_fake)
            positive_labels = torch.ones_like(a_fake_dis)
            a_gan_loss = self.MSE(a_fake_dis, positive_labels)
            b_gan_loss = self.MSE(b_fake_dis, positive_labels)

            # Cycle loss
            a_cycle_loss = self.L1(a_reconstruct, a_real) * lambda_coef
            b_cycle_loss = self.L1(b_reconstruct, b_real) * lambda_coef

            # Total generator loss
            total_gan_loss = (a_idt_loss + b_idt_loss + a_gan_loss + b_gan_loss
                              + a_cycle_loss + b_cycle_loss)

            # Sample previously generated images for the discriminator forward pass
            # (first dim of a_fake/b_fake is the batch dimension)
            a_fake = torch.Tensor(imagepool_a(a_fake.detach().cpu().clone().numpy())).to(self.device)
            b_fake = torch.Tensor(imagepool_b(b_fake.detach().cpu().clone().numpy())).to(self.device)

            # Discriminator forward passes: D_A judges domain A, D_B judges domain B
            a_real_dis = self.D_A(a_real)
            a_fake_dis = self.D_A(a_fake)
            b_real_dis = self.D_B(b_real)
            b_fake_dis = self.D_B(b_fake)

            # Discriminator losses
            positive_labels = torch.ones_like(a_fake_dis)
            negative_labels = torch.zeros_like(a_fake_dis)
            a_dis_real_loss = self.MSE(a_real_dis, positive_labels)
            a_dis_fake_loss = self.MSE(a_fake_dis, negative_labels)
            b_dis_real_loss = self.MSE(b_real_dis, positive_labels)
            b_dis_fake_loss = self.MSE(b_fake_dis, negative_labels)
            a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
            b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

            # Optimization steps
            self.gen_optimizer.zero_grad()
            total_gan_loss.backward()
            self.gen_optimizer.step()

            self.dis_optimizer.zero_grad()
            a_dis_loss.backward()
            b_dis_loss.backward()
            self.dis_optimizer.step()

            # Clamp Adam's internal step counters (ad-hoc workaround kept from the original code)
            for optimizer in (self.dis_optimizer, self.gen_optimizer):
                for group in optimizer.param_groups:
                    for p in group['params']:
                        state = optimizer.state[p]
                        if state['step'] >= 962:
                            state['step'] = 962

            if (step + 1) % 5 == 0:
                print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                      % (epoch, step + 1, len(loader), total_gan_loss, a_dis_loss + b_dis_loss))
            step += 1

        self.save_checkpoint(epoch + 1, args)
        self.gen_scheduler.step()
        self.dis_scheduler.step()
        step = 0
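# utils.init_weights is called above but not defined in these snippets. A minimal
# sketch assuming the usual CycleGAN convention of N(0, 0.02) initialisation for
# conv/linear layers; the function name and call signature mirror the call sites,
# everything else is an assumption.
import torch


def init_weights(net, gain=0.02):
    def init_fn(m):
        classname = m.__class__.__name__
        if ('Conv' in classname or 'Linear' in classname) and hasattr(m, 'weight'):
            torch.nn.init.normal_(m.weight.data, 0.0, gain)
            if getattr(m, 'bias', None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # affine batch-norm parameters: scale around 1, zero shift
            torch.nn.init.normal_(m.weight.data, 1.0, gain)
            torch.nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_fn)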