def create_networks(opt, checkpoint=None):
    """Build the generator, discriminator and perceptual networks.

    Args:
        opt: configuration namespace (reads ``multi_gpu``).
        checkpoint: optional dict with 'G' / 'D' state dicts to restore
            (e.g. loaded from a previous training run).

    Returns:
        Tuple ``(generator, discriminator, perceptualnet)``, each moved to
        CUDA, and each wrapped in ``nn.DataParallel`` when ``opt.multi_gpu``
        is set.
    """
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    if checkpoint:
        # Restore the network state from the checkpoint before any
        # DataParallel wrapping, so key names match the raw modules.
        generator.load_state_dict(checkpoint['G'])
        discriminator.load_state_dict(checkpoint['D'])
    # To device: wrap for multi-GPU first (if requested), then move to CUDA.
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    perceptualnet = perceptualnet.cuda()
    return generator, discriminator, perceptualnet
def trainer_WGANGP(opt):
    """Train the SCGAN colorization networks with a WGAN-GP objective.

    Builds generator / discriminator / perceptual networks, then alternates
    discriminator updates (Wasserstein loss + gradient penalty) and generator
    updates (L1 + attention + perceptual + adversarial losses). Saves model
    checkpoints and sample images according to ``opt`` settings.

    Args:
        opt: configuration namespace with training hyper-parameters
            (lr_g, lr_d, lambda_* weights, save/sample paths, etc.).
    """
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)
    # Handle multiple GPUs: scale batch size / workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        # BUG FIX: the original wrapped `generator` here a second time,
        # silently discarding the perceptual network in multi-GPU mode.
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by
        # "lr_decrease_factor" every "lr_decrease_epoch" epochs (or every
        # "lr_decrease_iter" iterations, depending on opt.lr_decrease_mode)
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch, opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration, opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu:
            # DataParallel wraps the real network in `.module`
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Calculate the gradient penalty loss for WGAN-GP
    def compute_gradient_penalty(D, input_samples, real_samples, fake_samples):
        # Random weight term for interpolation between real and fake samples
        alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = D(input_samples, interpolates)
        # For PatchGAN: ones tensor matches the 30x30 patch output
        # (assumes the discriminator emits a [B, 1, 30, 30] map — confirm
        # against the discriminator architecture)
        fake = Variable(Tensor(real_samples.shape[0], 1, 30, 30).fill_(1.0),
                        requires_grad=False)
        # Get gradient w.r.t. interpolates
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):
            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)

            ### Train Discriminator
            optimizer_D.zero_grad()
            # Generator output
            fake_RGB, fake_sal = generator(true_L)
            # Fake colorizations (detached so G is not updated here)
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())
            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)
            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(
                discriminator, true_L.data, true_RGB.data, fake_RGB.data)
            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d) + opt.lambda_gp * gradient_penalty
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()
            fake_RGB, fake_sal = generator(true_L)
            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)
            # Attention Loss (saliency-weighted L1)
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)
            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)
            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = -torch.mean(fake_scalar)
            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_L1.item(), loss_attn.item(), loss_percep.item(),
                   loss_D.item(), loss_GAN.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_RGB, true_RGB]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=opt.sample_path,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list)
def trainer_LSGAN(opt):
    """Train the SCGAN colorization networks with an LSGAN objective.

    Same pipeline as the WGAN-GP trainer but the adversarial term is a
    least-squares (MSE) loss against all-ones / all-zeros patch targets
    instead of a Wasserstein critic with gradient penalty.

    Args:
        opt: configuration namespace with training hyper-parameters.
    """
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)
    # Handle multiple GPUs: scale batch size / workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()
    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        # BUG FIX: the original wrapped `generator` here a second time,
        # silently discarding the perceptual network in multi-GPU mode.
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by
        # "lr_decrease_factor" every "lr_decrease_epoch" epochs (or every
        # "lr_decrease_iter" iterations, depending on opt.lr_decrease_mode)
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch, opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration, opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu:
            # DataParallel wraps the real network in `.module`
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):
            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)

            # Adversarial ground truth (assumes a [B, 1, 30, 30] PatchGAN
            # output — confirm against the discriminator architecture)
            valid = Tensor(np.ones((true_L.shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((true_L.shape[0], 1, 30, 30)))

            ### Train Discriminator
            optimizer_D.zero_grad()
            # Generator output
            fake_RGB, fake_sal = generator(true_L)
            # Fake colorizations (detached so G is not updated here)
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)
            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)
            loss_true = criterion_MSE(true_scalar_d, valid)
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()
            fake_RGB, fake_sal = generator(true_L)
            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)
            # Attention Loss (saliency-weighted L1)
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)
            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)
            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = criterion_MSE(fake_scalar, valid)
            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_L1.item(), loss_attn.item(), loss_percep.item(),
                   loss_D.item(), loss_GAN.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_RGB, true_RGB]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=opt.sample_path,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list)
def WGAN_trainer(opt):
    """Train a two-stage (coarse + refine) inpainting generator with a WGAN loss.

    The generator produces ``first_out`` (coarse) and ``second_out`` (refined)
    predictions; both are composited with the ground truth outside the mask
    before being scored. Discriminator uses plain Wasserstein means (no
    gradient penalty in this trainer). Checkpoints and sample images are
    written per ``opt``.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device: wrap for multi-GPU first, then move to CUDA
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Loss functions
    L1Loss = nn.L1Loss()
    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'deepfillv2_LSGAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            # DataParallel wraps the real network in `.module`
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' % (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' % (epoch))

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    # Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):
            # Load mask (shape: [B, 1, H, W]), masked_img (shape: [B, 3, H, W]),
            # img (shape: [B, 3, H, W]) and put it to cuda
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()
            # Generator output
            first_out, second_out = generator(img, mask)  # forward propagation
            # Composite: keep ground truth outside the hole, prediction inside
            first_out_wholeimg = img * (1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask  # in range [0, 1]
            # Fake samples (detached so G is not updated by the D step)
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)
            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            # NOTE(review): reuses the composites computed before the D step
            # rather than re-running the generator — presumably intentional
            # to save a forward pass; confirm this matches the original design.
            optimizer_g.zero_grad()
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)
            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps,
                                           img_featuremaps)
            # Compute losses
            loss = opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + \
                opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)

        ### Sample data every epoch (hole region filled with 1.0 for display)
        masked_img = img * (1 - mask) + mask
        mask = torch.cat((mask, mask, mask), 1)
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
def Trainer_GAN(opt):
    """Train a grayscale inpainting generator with an LSGAN objective.

    Alternates discriminator updates (MSE against all-ones / all-zeros 8x8
    patch targets) and generator updates (masked L1 + perceptual + adversarial
    losses). Saves checkpoints and sample images per ``opt``.

    Args:
        opt: configuration namespace with training hyper-parameters.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    # Handle multiple GPUs: scale batch size / workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    print("Batch size is changed to %d" % opt.batch_size)
    print("Number of workers is changed to %d" % opt.num_workers)
    # Build path folder
    utils.check_path(opt.save_path)
    utils.check_path(opt.sample_path)
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(optimizer, epoch, opt, init_lr):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = init_lr * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'GrayInpainting_GAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu:
            # DataParallel wraps the real network in `.module`
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' % (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' % (epoch))

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    # Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (grayscale, mask) in enumerate(dataloader):
            # Load and put to cuda
            grayscale = grayscale.cuda()  # out: [B, 1, 256, 256]
            mask = mask.cuda()  # out: [B, 1, 256, 256]

            # LSGAN vectors (assumes a [B, 1, 8, 8] patch discriminator
            # output — confirm against the discriminator architecture)
            valid = Tensor(np.ones((grayscale.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((grayscale.shape[0], 1, 8, 8)))

            # ----------------------------------------
            # Train Discriminator
            # ----------------------------------------
            optimizer_d.zero_grad()
            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            # Composite: ground truth outside the hole, prediction inside
            out_wholeimg = grayscale * (1 - mask) + out * mask  # in range [0, 1]
            # Fake samples (detached so G is not updated here)
            fake_scalar = discriminator(out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(grayscale, mask)
            # Overall Loss and optimize
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            # BUG FIX: the original never called optimizer_d.step(), so the
            # discriminator's gradients were computed but never applied.
            optimizer_d.step()

            # ----------------------------------------
            # Train Generator
            # ----------------------------------------
            optimizer_g.zero_grad()
            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 - mask) + out * mask  # in range [0, 1]
            # Mask L1 Loss
            MaskL1Loss = L1Loss(out_wholeimg, grayscale)
            # GAN Loss
            fake_scalar = discriminator(out_wholeimg, mask)
            MaskGAN_Loss = MSELoss(fake_scalar, valid)
            # Get the deep semantic feature maps, and compute Perceptual Loss
            # (replicate grayscale to 3 channels for the perceptual network)
            out_3c = torch.cat((out_wholeimg, out_wholeimg, out_wholeimg), 1)
            grayscale_3c = torch.cat((grayscale, grayscale, grayscale), 1)
            out_featuremaps = perceptualnet(out_3c)
            gt_featuremaps = perceptualnet(grayscale_3c)
            PerceptualLoss = L1Loss(out_featuremaps, gt_featuremaps)
            # Compute losses
            loss = opt.lambda_l1 * MaskL1Loss + opt.lambda_perceptual * PerceptualLoss + opt.lambda_gan * MaskGAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [Perceptual Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), PerceptualLoss.item(),
                   loss_D.item(), MaskGAN_Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(optimizer_g, (epoch + 1), opt, opt.lr_g)
        adjust_learning_rate(optimizer_d, (epoch + 1), opt, opt.lr_d)

        # Save the model
        save_model(generator, (epoch + 1), opt)
        utils.sample(grayscale, mask, out_wholeimg, opt.sample_path, (epoch + 1))
def Continue_train_WGAN(opt):
    """Continue training a RAW-to-RGB generator with a WGAN objective.

    Runs ``opt.additional_training_d`` discriminator updates per batch, then
    one generator update combining L1, saliency-attention, adversarial and
    perceptual losses. Checkpoints whole modules per ``opt`` save settings.

    Args:
        opt: configuration namespace with training hyper-parameters.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    # NOTE(review): only the generator optimizer uses weight_decay here —
    # presumably deliberate, but confirm against the original training setup.
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by
        # "lr_decrease_factor" every "lr_decrease_epoch" epochs (or every
        # "lr_decrease_iter" iterations, depending on opt.lr_decrease_mode)
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Note: saves the whole module object (not a state_dict)
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator.module,
                                   'WGAN_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator.module,
                                   'WGAN_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator,
                                   'WGAN_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator,
                                   'WGAN_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.RAW2RGBDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target, true_sal) in enumerate(dataloader):
            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            true_sal = true_sal.cuda()
            true_sal3 = torch.cat((true_sal, true_sal, true_sal), 1).cuda()

            # Train Discriminator (possibly several updates per batch)
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()
                # Generator output (detached below so G is not updated here)
                fake_target, fake_sal = generator(true_input)
                # Fake samples
                fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                    true_input, fake_target.detach())
                # True samples
                true_block1, true_block2, true_block3, true_scalar = discriminator(
                    true_input, true_target)
                # Overall Loss and optimize
                loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
                loss_D.backward()
                # BUG FIX: the original never called optimizer_D.step(), so
                # the "additional training" loop computed gradients without
                # ever applying a discriminator update.
                optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target, fake_sal = generator(true_input)
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)
            # Attention Loss (saliency-weighted L1)
            true_Attn_target = true_target.mul(true_sal3)
            fake_sal3 = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_Attn_target = fake_target.mul(fake_sal3)
            Attention_Loss = criterion_L1(fake_Attn_target, true_Attn_target)
            # GAN Loss
            fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                true_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)
            # Perceptual Loss: map [-1, 1] -> [0, 1], then ImageNet-normalize
            # before feeding the perceptual network
            fake_target = fake_target * 0.5 + 0.5
            true_target = true_target * 0.5 + 0.5
            fake_target = utils.normalize_ImageNet_stats(fake_target)
            true_target = utils.normalize_ImageNet_stats(true_target)
            fake_percep_feature = perceptualnet(fake_target)
            true_percep_feature = perceptualnet(true_target)
            Perceptual_Loss = criterion_L1(fake_percep_feature, true_percep_feature)
            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + opt.lambda_attn * Attention_Loss + opt.lambda_gan * GAN_Loss + opt.lambda_percep * Perceptual_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item()))
            print(
                "\r[Attention Loss: %.4f] [Perceptual Loss: %.4f] Time_left: %s"
                % (Attention_Loss.item(), Perceptual_Loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)
def Continue_train_WGAN(opt):
    """Resume WGAN training of a noise-conditioned generator with a
    mode-seeking regularizer.

    Two noise vectors are drawn per input so the mode-seeking loss can push
    the two corresponding outputs apart; the critic is trained with the
    standard WGAN objective against both generated branches.

    Args:
        opt: argparse-style namespace carrying all hyperparameters
             (lr_g, lr_d, b1, b2, epochs, batch_size, z_dim, lambda_*, ...).
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    # Optimizers (weight decay only on G, matching the original setup)
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        """Decay LR by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        (or every "lr_decrease_iter" iterations, depending on mode)."""
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # DataParallel wraps the real network in .module; unwrap when saving
        net = generator.module if opt.multi_gpu == True else generator
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                if opt.save_name_mode:
                    torch.save(net, 'WGAN_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                    print('The trained model is successfully saved at epoch %d' % (epoch))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                if opt.save_name_mode:
                    torch.save(net, 'WGAN_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                    print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.NormalRGBDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):
            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            # Sample two independent noise vectors per input (mode seeking)
            noise1 = utils.get_noise(true_input.shape[0], opt.z_dim, opt.random_type)
            noise1 = noise1.cuda()                                  # out: batch * z_dim
            noise2 = utils.get_noise(true_input.shape[0], opt.z_dim, opt.random_type)
            noise2 = noise2.cuda()                                  # out: batch * z_dim
            concat_noise = torch.cat((noise1, noise2), 0)           # out: 2batch * z_dim
            concat_input = torch.cat((true_input, true_input), 0)   # out: 2batch * 1 * 256 * 256
            concat_target = torch.cat((true_target, true_target), 0)  # out: 2batch * 3 * 256 * 256

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(concat_input, concat_noise)     # out: 2batch * 3 * 256 * 256
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, concat_target)
            # MSGAN (mode-seeking) Loss: maximize output distance per noise distance
            fake_target1, fake_target2 = fake_target.split(true_input.shape[0], 0)
            ms_value = torch.mean(torch.abs(fake_target2 - fake_target1)) / torch.mean(torch.abs(noise2 - noise1))
            eps = 1e-5
            ModeSeeking_Loss = 1 / (ms_value + eps)
            # GAN Loss
            fake_scalar = discriminator(concat_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)
            # Overall Loss and optimize
            loss = opt.lambda_l1 * Pixellevel_L1_Loss + opt.lambda_gan * GAN_Loss + opt.lambda_ms * ModeSeeking_Loss
            loss.backward()
            optimizer_G.step()

            # Train Discriminator
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()
                # Generator output
                fake_target = generator(concat_input, concat_noise)
                # BUG FIX: split by the per-branch batch size (true_input.shape[0]).
                # The original used concat_noise.shape[0] (= 2*batch), which yields a
                # single chunk and crashes the two-way unpacking.
                fake_target1, fake_target2 = fake_target.split(true_input.shape[0], 0)
                # Fake samples
                fake_scalar_d1 = discriminator(true_input, fake_target1.detach())
                fake_scalar_d2 = discriminator(true_input, fake_target2.detach())
                # True samples
                true_scalar_d = discriminator(true_input, true_target)
                # Overall Loss and optimize (WGAN critic objective per branch)
                loss_D1 = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d1)
                loss_D2 = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d2)
                loss_D = loss_D1 + loss_D2
                loss_D.backward()
                # BUG FIX: the original never stepped the critic optimizer, so the
                # discriminator weights were never updated.
                optimizer_D.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
def CycleGAN_LSGAN(opt):
    """Train a label-conditioned generator/discriminator pair (StarGAN-like):
    the generator reconstructs the input under its true label and translates
    it under a randomly permuted fake label; the discriminator scores realness
    (WGAN term) and classifies labels (BCE term).

    NOTE(review): a second function with this same name is defined later in
    this file and shadows this one at import time — confirm which is intended.

    Args:
        opt: argparse-style namespace with all hyperparameters.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_BCE = torch.nn.BCEWithLogitsLoss().cuda()
    # Build the two networks and move them to GPU
    G = utils.create_generator(opt)
    D = utils.create_discriminator(opt)
    if opt.multi_gpu:
        G = nn.DataParallel(G).cuda()
        D = nn.DataParallel(D).cuda()
    else:
        G = G.cuda()
        D = D.cuda()
    # Optimizers (weight decay only on G, matching the original setup)
    optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        """Step-decay the LR by "lr_decrease_factor" per "lr_decrease_epoch"
        epochs (or per "lr_decrease_iter" iterations)."""
        if opt.lr_decrease_mode == 'epoch':
            new_lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for group in optimizer.param_groups:
                group['lr'] = new_lr
        if opt.lr_decrease_mode == 'iter':
            new_lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for group in optimizer.param_groups:
                group['lr'] = new_lr

    def save_model(opt, epoch, iteration, len_dataset, G, D):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Unwrap DataParallel before serializing so the raw nets are stored
        g_net = G.module if opt.multi_gpu == True else G
        d_net = D.module if opt.multi_gpu == True else D
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                torch.save(g_net, 'AttnGAN_parent_G_epoch%d_bs%d.pth' % (epoch, opt.batch_size))
                torch.save(d_net, 'AttnGAN_parent_D_epoch%d_bs%d.pth' % (epoch, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' % (epoch))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                torch.save(g_net, 'AttnGAN_parent_G_iter%d_bs%d.pth' % (iteration, opt.batch_size))
                torch.save(d_net, 'AttnGAN_parent_D_iter%d_bs%d.pth' % (iteration, opt.batch_size))
                print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    trainset = dataset.CFP_dataset(opt)
    print('The overall number of images:', len(trainset))
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    prev_time = time.time()
    for epoch in range(opt.epochs):
        for batch_idx, (img, imglabel) in enumerate(dataloader):
            img = img.cuda()
            # Build fake labels by randomly permuting the real label batch
            perm = torch.randperm(len(imglabel))
            imglabel_fake = imglabel[perm].contiguous()
            imglabel = imglabel.cuda()
            imglabel_fake = imglabel_fake.cuda()

            # ------------------------------- Train Generator -------------------------------
            optimizer_G.zero_grad()
            img_recon, img_fake = G(img, imglabel, imglabel_fake)
            out_adv, out_class = D(img_fake)
            # Reconstruction (identity under true label), adversarial (WGAN),
            # and label-classification terms
            loss_recon = criterion_L1(img_recon, img)
            loss_gan = -torch.mean(out_adv)
            loss_class = criterion_BCE(out_class, imglabel_fake)
            loss = opt.lambda_recon * loss_recon + opt.lambda_gan * loss_gan + opt.lambda_class * loss_class
            loss.backward()
            optimizer_G.step()

            # ------------------------------- Train Discriminator -------------------------------
            optimizer_D.zero_grad()
            img_recon, img_fake = G(img, imglabel, imglabel_fake)
            out_adv_fake, out_class_fake = D(img_fake.detach())
            out_adv_true, out_class_true = D(img.detach())
            # WGAN critic term plus classification of real images
            loss_gan = torch.mean(out_adv_fake) - torch.mean(out_adv_true)
            loss_class = criterion_BCE(out_class_true, imglabel)
            loss = loss_gan + loss_class
            loss.backward()
            optimizer_D.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + batch_idx
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] [GAN Loss: %.4f] [Class Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   loss_recon.item(), loss_gan.item(), loss_class.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), G, D)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)
def CycleGAN_LSGAN(opt):
    """Train a CycleGAN with LSGAN adversarial losses.

    Domain A is the grayscale image, domain B the color RGB image. G_AB and
    G_BA translate between domains; D_A / D_B are patch discriminators whose
    outputs are 1x16x16 score maps regressed to 1 (real) or 0 (fake) via MSE.

    NOTE(review): this definition shadows an earlier function of the same name
    in this file — confirm the duplication is intentional.

    Args:
        opt: argparse-style namespace with all hyperparameters.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()
    # Initialize Generator
    # A is for grayscale image
    # B is for color RGB image
    G_AB = utils.create_generator(opt)
    G_BA = utils.create_generator(opt)
    D_A = utils.create_discriminator(opt)
    D_B = utils.create_discriminator(opt)
    # To device
    if opt.multi_gpu:
        G_AB = nn.DataParallel(G_AB)
        G_AB = G_AB.cuda()
        G_BA = nn.DataParallel(G_BA)
        G_BA = G_BA.cuda()
        D_A = nn.DataParallel(D_A)
        D_A = D_A.cuda()
        D_B = nn.DataParallel(D_B)
        D_B = D_B.cuda()
    else:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
    # Optimizers: one joint optimizer for both generators, one per discriminator
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                                   lr=opt.lr_g, betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr_d, betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr_d, betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        """Decay LR by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        (or every "lr_decrease_iter" iterations)."""
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, G_AB, G_BA):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Unwrap DataParallel before serializing
        g_ab = G_AB.module if opt.multi_gpu == True else G_AB
        g_ba = G_BA.module if opt.multi_gpu == True else G_BA
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                if opt.save_name_mode:
                    torch.save(g_ab, 'G_AB_LSGAN_epoch%d_bs%d.pth' % (epoch, opt.batch_size))
                    torch.save(g_ba, 'G_BA_LSGAN_epoch%d_bs%d.pth' % (epoch, opt.batch_size))
                    print('The trained model is successfully saved at epoch %d' % (epoch))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                if opt.save_name_mode:
                    torch.save(g_ab, 'G_AB_LSGAN_iter%d_bs%d.pth' % (iteration, opt.batch_size))
                    torch.save(g_ba, 'G_BA_LSGAN_iter%d_bs%d.pth' % (iteration, opt.batch_size))
                    print('The trained model is successfully saved at iteration %d' % (iteration))

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.DomainTransferDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_A, true_B) in enumerate(dataloader):
            # To device
            # A is for grayscale image
            # B is for color RGB image
            true_A = true_A.cuda()
            true_B = true_B.cuda()
            # Adversarial ground truth (patch discriminator outputs 1x16x16 maps)
            valid = Tensor(np.ones((true_A.shape[0], 1, 16, 16)))
            fake = Tensor(np.zeros((true_A.shape[0], 1, 16, 16)))

            # Train Generator
            optimizer_G.zero_grad()
            # Identity Loss: each generator should leave its target domain alone
            loss_indentity_A = criterion_L1(G_BA(true_A), true_A)
            loss_indentity_B = criterion_L1(G_AB(true_B), true_B)
            loss_indentity = (loss_indentity_A + loss_indentity_B) / 2
            # GAN Loss (LSGAN: regress discriminator scores to "valid")
            fake_B = G_AB(true_A)
            loss_GAN_AB = criterion_MSE(D_B(fake_B), valid)
            fake_A = G_BA(true_B)
            loss_GAN_BA = criterion_MSE(D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
            # Cycle-consistency Loss
            recon_A = G_BA(fake_B)
            loss_cycle_A = criterion_L1(recon_A, true_A)
            recon_B = G_AB(fake_A)
            loss_cycle_B = criterion_L1(recon_B, true_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
            # Overall Loss and optimize
            loss = loss_GAN + opt.lambda_cycle * loss_cycle + opt.lambda_identity * loss_indentity
            loss.backward()
            optimizer_G.step()

            # Train Discriminator A
            optimizer_D_A.zero_grad()
            # Fake samples (detach: do not backprop into the generators)
            fake_scalar_d = D_A(fake_A.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)
            # True samples
            true_scalar_d = D_A(true_A)
            loss_true = criterion_MSE(true_scalar_d, valid)
            # Overall Loss and optimize
            loss_D_A = 0.5 * (loss_fake + loss_true)
            loss_D_A.backward()
            # BUG FIX: the original never stepped optimizer_D_A, so D_A was
            # never updated.
            optimizer_D_A.step()

            # Train Discriminator B
            optimizer_D_B.zero_grad()
            # Fake samples
            fake_scalar_d = D_B(fake_B.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)
            # True samples
            true_scalar_d = D_B(true_B)
            loss_true = criterion_MSE(true_scalar_d, valid)
            # Overall Loss and optimize
            loss_D_B = 0.5 * (loss_fake + loss_true)
            loss_D_B.backward()
            # BUG FIX: the original never stepped optimizer_D_B either.
            optimizer_D_B.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [D_A Loss: %.4f] [D_B Loss: %.4f] [G GAN Loss: %.4f] [G Cycle Loss: %.4f] [G Indentity Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_D_A.item(), loss_D_B.item(), loss_GAN.item(),
                   loss_cycle.item(), loss_indentity.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), G_AB, G_BA)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D_A)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D_B)
def Pre_train(opt):
    """BicycleGAN-style pre-training: a generator G, two LSGAN discriminators
    (cVAE-GAN and cLR-GAN branches), and an encoder E.

    Each batch is split: sample 0 feeds the cVAE-GAN branch (latent encoded
    from the ground truth) and sample 1 feeds the cLR-GAN branch (random
    latent). Requires batch_size >= 2.

    Args:
        opt: argparse-style namespace with all hyperparameters
             (lr, b1, b2, z_dim, lambda_gan/kl/recon/z, ...).
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_L2 = torch.nn.MSELoss().cuda()
    # Initialize networks
    G = utils.create_generator(opt)
    D_cVAE = utils.create_discriminator(opt)
    D_cLR = utils.create_discriminator(opt)
    E = utils.create_encoder(opt)
    # To device
    if opt.multi_gpu:
        G = nn.DataParallel(G)
        G = G.cuda()
        D_cVAE = nn.DataParallel(D_cVAE)
        D_cVAE = D_cVAE.cuda()
        D_cLR = nn.DataParallel(D_cLR)
        # BUG FIX: the original referenced an undefined name
        # (`discriminator_cLR`) here, raising NameError on the multi-GPU path.
        D_cLR = D_cLR.cuda()
        E = nn.DataParallel(E)
        E = E.cuda()
    else:
        G = G.cuda()
        D_cVAE = D_cVAE.cuda()
        D_cLR = D_cLR.cuda()
        E = E.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2), weight_decay=opt.weight_decay)
    optimizer_D_cVAE = torch.optim.Adam(D_cVAE.parameters(), lr=opt.lr,
                                        betas=(opt.b1, opt.b2), weight_decay=opt.weight_decay)
    optimizer_D_cLR = torch.optim.Adam(D_cLR.parameters(), lr=opt.lr,
                                       betas=(opt.b1, opt.b2), weight_decay=opt.weight_decay)
    optimizer_E = torch.optim.Adam(E.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2), weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        """Linearly-stepped decay after "start_decrease_epoch".

        NOTE(review): integer `//` makes decay_rate drop in whole-number steps
        (eventually negative), and the base LR is opt.lr_g while the optimizers
        were built with opt.lr — confirm both are intended.
        """
        decay_rate = 1.0 - (max(0, epoch - opt.start_decrease_epoch) // opt.lr_decrease_divide)
        lr = opt.lr_g * decay_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Unwrap DataParallel before serializing
        net = generator.module if opt.multi_gpu == True else generator
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                if opt.save_name_mode:
                    torch.save(net, 'Pre_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                    print('The trained model is successfully saved at epoch %d' % (epoch))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                if opt.save_name_mode:
                    torch.save(net, 'Pre_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                    print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.DomainTransferDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):
            # To device, and separate data for cVAE_GAN and cLR_GAN
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            cVAE_data = {'img': true_input[[0], :, :, :], 'ground_truth': true_target[[0], :, :, :]}
            cLR_data = {'img': true_input[[1], :, :, :], 'ground_truth': true_target[[1], :, :, :]}

            ''' ----------------------------- 1. Train D ----------------------------- '''
            # Step 1. D loss in cVAE-GAN
            # Encoded latent vector (reparameterization trick)
            mu, log_variance = E(cVAE_data['ground_truth'])
            std = torch.exp(log_variance / 2)
            random_z = torch.randn(1, opt.z_dim).cuda()
            encoded_z = (random_z * std) + mu
            # Generate fake image
            fake_img_cVAE = G(cVAE_data['img'], encoded_z)
            # Get scores and loss
            real_d_cVAE_1, real_d_cVAE_2 = D_cVAE(cVAE_data['ground_truth'])
            fake_d_cVAE_1, fake_d_cVAE_2 = D_cVAE(fake_img_cVAE.detach())
            # mse_loss for LSGAN.
            # BUG FIX: nn.MSELoss needs a tensor target; the original passed the
            # Python ints 1 / 0.
            D_loss_cVAE_1 = criterion_L2(real_d_cVAE_1, torch.ones_like(real_d_cVAE_1)) \
                + criterion_L2(fake_d_cVAE_1, torch.zeros_like(fake_d_cVAE_1))
            D_loss_cVAE_2 = criterion_L2(real_d_cVAE_2, torch.ones_like(real_d_cVAE_2)) \
                + criterion_L2(fake_d_cVAE_2, torch.zeros_like(fake_d_cVAE_2))

            # Step 2. D loss in cLR-GAN
            # Random latent vector
            random_z = torch.randn(1, opt.z_dim).cuda()
            # Generate fake image
            fake_img_cLR = G(cLR_data['img'], random_z)
            # Get scores and loss
            real_d_cLR_1, real_d_cLR_2 = D_cLR(cLR_data['ground_truth'])
            fake_d_cLR_1, fake_d_cLR_2 = D_cLR(fake_img_cLR.detach())
            D_loss_cLR_1 = criterion_L2(real_d_cLR_1, torch.ones_like(real_d_cLR_1)) \
                + criterion_L2(fake_d_cLR_1, torch.zeros_like(fake_d_cLR_1))
            D_loss_cLR_2 = criterion_L2(real_d_cLR_2, torch.ones_like(real_d_cLR_2)) \
                + criterion_L2(fake_d_cLR_2, torch.zeros_like(fake_d_cLR_2))

            D_loss = D_loss_cVAE_1 + D_loss_cLR_1 + D_loss_cVAE_2 + D_loss_cLR_2
            # Update both discriminators
            optimizer_D_cVAE.zero_grad()
            optimizer_D_cLR.zero_grad()
            D_loss.backward()
            optimizer_D_cVAE.step()
            optimizer_D_cLR.step()

            ''' ----------------------------- 2. Train G & E ----------------------------- '''
            # Step 1. GAN loss to fool discriminator (cVAE_GAN and cLR_GAN)
            # Encoded latent vector
            mu, log_variance = E(cVAE_data['ground_truth'])
            std = torch.exp(log_variance / 2)
            random_z = torch.randn(1, opt.z_dim).cuda()
            encoded_z = (random_z * std) + mu
            # Generate fake image and get adversarial loss
            fake_img_cVAE = G(cVAE_data['img'], encoded_z)
            fake_d_cVAE_1, fake_d_cVAE_2 = D_cVAE(fake_img_cVAE)
            GAN_loss_cVAE_1 = criterion_L2(fake_d_cVAE_1, torch.ones_like(fake_d_cVAE_1))
            GAN_loss_cVAE_2 = criterion_L2(fake_d_cVAE_2, torch.ones_like(fake_d_cVAE_2))
            # Random latent vector
            random_z = torch.randn(1, opt.z_dim).cuda()
            # Generate fake image and get adversarial loss
            fake_img_cLR = G(cLR_data['img'], random_z)
            fake_d_cLR_1, fake_d_cLR_2 = D_cLR(fake_img_cLR)
            GAN_loss_cLR_1 = criterion_L2(fake_d_cLR_1, torch.ones_like(fake_d_cLR_1))
            GAN_loss_cLR_2 = criterion_L2(fake_d_cLR_2, torch.ones_like(fake_d_cLR_2))
            G_GAN_loss = GAN_loss_cVAE_1 + GAN_loss_cVAE_2 + GAN_loss_cLR_1 + GAN_loss_cLR_2
            G_GAN_loss = opt.lambda_gan * G_GAN_loss

            # Step 2. KL-divergence with N(0, 1) (cVAE-GAN)
            KL_div_loss = opt.lambda_kl * torch.sum(
                0.5 * (mu**2 + torch.exp(log_variance) - log_variance - 1))

            # Step 3. Reconstruction of ground truth image (|G(A, z) - B|) (cVAE-GAN)
            img_recon_loss = opt.lambda_recon * criterion_L1(fake_img_cVAE, cVAE_data['ground_truth'])

            EG_loss = G_GAN_loss + KL_div_loss + img_recon_loss
            optimizer_G.zero_grad()
            optimizer_E.zero_grad()
            # retain_graph: fake_img_cLR is reused below for the z-reconstruction step
            EG_loss.backward(retain_graph=True)
            optimizer_G.step()
            optimizer_E.step()

            ''' ----------------------------- 3. Train ONLY G ----------------------------- '''
            # Reconstruction of random latent code (|E(G(A, z)) - z|) (cLR-GAN)
            # This step should update ONLY G.
            mu_, log_variance_ = E(fake_img_cLR)
            G_alone_loss = opt.lambda_z * criterion_L1(mu_, random_z)
            optimizer_G.zero_grad()
            G_alone_loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log.
            # BUG FIX: the original passed D_loss.item() twice, giving 11
            # arguments for 10 format specifiers (TypeError at runtime).
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [D Loss: %.4f] [GAN Loss: %.4f] [Recon Loss: %.4f] [KL Loss: %.4f] [z Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   D_loss.item(), G_GAN_loss.item(), img_recon_loss.item(),
                   KL_div_loss.item(), G_alone_loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), G)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D_cVAE)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D_cLR)
            adjust_learning_rate(opt, (epoch + 1), optimizer_E)
def Train_GAN(opt):
    """Train a recurrent video-colorization generator with LSGAN adversarial
    loss and PWC-Net-based temporal (optical-flow warping) consistency losses.

    Per clip, frames are processed sequentially: each prediction p_t is
    conditioned on the previous (detached) prediction and an LSTM state, and
    penalized for deviating from the flow-warped previous / first prediction.

    Args:
        opt: argparse-style namespace with all hyperparameters
             (lr_g, lr_d, iter_frames, lambda_flow[_long], mask_para, ...).
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()
    # Initialize networks (flownet is a pre-trained PWC-Net used for flow estimation)
    generatorNet = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    flownet = utils.create_pwcnet(opt)
    # To device
    if opt.multi_gpu:
        generatorNet = nn.DataParallel(generatorNet)
        generatorNet = generatorNet.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        flownet = nn.DataParallel(flownet)
        flownet = flownet.cuda()
    else:
        discriminator = discriminator.cuda()
        generatorNet = generatorNet.cuda()
        flownet = flownet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generatorNet.parameters(), lr = opt.lr_g, betas = (opt.b1, opt.b2), weight_decay = opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = opt.lr_d, betas = (opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        """Decay LR by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        (or every "lr_decrease_iter" iterations)."""
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Branches kept separate: the epoch-mode filenames differ between the
        # multi-GPU ('Gan') and single-GPU ('GAN') paths in the original.
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator.module, 'Pre_%s_epoch%d_bs%d_Gan%d_os%d_ol%d.pth' % (opt.task, epoch, opt.batch_size, opt.lambda_gan, opt.lambda_flow, opt.lambda_flow_long))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator.module, 'Pre_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator, 'Pre_%s_epoch%d_bs%d_GAN%d_os%d_ol%d.pth' % (opt.task, epoch, opt.batch_size, opt.lambda_gan, opt.lambda_flow, opt.lambda_flow_long))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator, 'Pre_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the class list
    imglist = utils.text_readlines('videocolor_linux.txt')
    classlist = utils.get_dirs(opt.baseroot)
    # Define the dataset
    trainset = dataset.MultiFramesDataset(opt, imglist, classlist)
    print('The overall number of classes:', len(trainset))
    # Define the dataloader
    dataloader = utils.create_dataloader(trainset, opt)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Tensor type
    Tensor = torch.cuda.FloatTensor
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for iteration, (in_part, out_part) in enumerate(dataloader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()
            # Per-clip recurrent state and accumulated losses
            lstm_state = None
            loss_flow = 0
            loss_flow_long = 0
            loss_L1 = 0
            loss_D = 0
            loss_G = 0
            x_0 = in_part[0].cuda()
            p_t_0 = in_part[0].cuda()
            # Adversarial ground truth (patch discriminator outputs 1x30x30 maps)
            valid = Tensor(np.ones((in_part[0].shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((in_part[0].shape[0], 1, 30, 30)))

            for iter_frame in range(opt.iter_frames):
                # Read data
                x_t = in_part[iter_frame].cuda()
                y_t = out_part[iter_frame].cuda()
                # Initialize the second input and compute flow-warped references
                if iter_frame == 0:
                    # No previous prediction yet: feed zeros
                    p_t_last = torch.zeros(opt.batch_size, opt.out_channels, opt.resize_h, opt.resize_w).cuda()
                elif iter_frame == 1:
                    x_t_last = in_part[iter_frame - 1].cuda()
                    p_t_last = p_t.detach()
                    p_t_0 = p_t.detach()
                    p_t_last.requires_grad = False
                    p_t_0.requires_grad = False
                    # o_t_last_2_t range is [-20, +20]
                    o_t_last_2_t = pwcnet.PWCEstimate(flownet, x_t, x_t_last)
                    # warped images are mapped from [-1, 1] to [0, 1] before warping
                    x_t_warp = pwcnet.PWCNetBackward((x_t_last + 1) / 2, o_t_last_2_t)
                    p_t_warp = pwcnet.PWCNetBackward((p_t_last + 1) / 2, o_t_last_2_t)
                else:
                    x_t_last = in_part[iter_frame - 1].cuda()
                    p_t_last = p_t.detach()
                    p_t_last.requires_grad = False
                    # short-range and long-range (to first frame) flows
                    o_t_last_2_t = pwcnet.PWCEstimate(flownet, x_t, x_t_last)
                    o_t_first_2_t = pwcnet.PWCEstimate(flownet, x_t, x_0)
                    x_t_warp = pwcnet.PWCNetBackward((x_t_last + 1) / 2, o_t_last_2_t)
                    p_t_warp = pwcnet.PWCNetBackward((p_t_last + 1) / 2, o_t_last_2_t)
                    x_t_warp_long = pwcnet.PWCNetBackward((x_0 + 1) / 2, o_t_first_2_t)
                    p_t_warp_long = pwcnet.PWCNetBackward((p_t_0 + 1) / 2, o_t_first_2_t)

                # Generator output (recurrent; detach LSTM state across frames)
                p_t, lstm_state = generatorNet(x_t, p_t_last, lstm_state)
                lstm_state = utils.repackage_hidden(lstm_state)

                # Temporal consistency losses, masked by photometric agreement
                # of the warped input (occlusion-aware weighting)
                if iter_frame == 1:
                    mask_flow = torch.exp(
                        -opt.mask_para * torch.sum((x_t + 1) / 2 - x_t_warp, dim=1).pow(2)
                    ).unsqueeze(1)
                    loss_flow += criterion_L1(mask_flow * (p_t + 1) / 2, mask_flow * p_t_warp)
                elif iter_frame > 1:
                    mask_flow = torch.exp(
                        -opt.mask_para * torch.sum((x_t + 1) / 2 - x_t_warp, dim=1).pow(2)
                    ).unsqueeze(1)
                    loss_flow += criterion_L1(mask_flow * (p_t + 1) / 2, mask_flow * p_t_warp)
                    mask_flow_long = torch.exp(
                        -opt.mask_para * torch.sum((x_t + 1) / 2 - x_t_warp_long, dim=1).pow(2)
                    ).unsqueeze(1)
                    loss_flow_long += criterion_L1(mask_flow_long * (p_t + 1) / 2, mask_flow_long * p_t_warp_long)

                # Discriminator terms (detached fake so D's loss does not reach G)
                fake_scalar = discriminator(x_t, p_t.detach())
                loss_fake = criterion_MSE(fake_scalar, fake)
                true_scalar = discriminator(x_t, y_t)
                loss_true = criterion_MSE(true_scalar, valid)
                loss_D += 0.5 * (loss_fake + loss_true)

                # Generator adversarial term (non-detached p_t)
                fake_scalar = discriminator(x_t, p_t)
                loss_G += criterion_MSE(fake_scalar, valid)
                # Pixel-level loss
                loss_L1 += criterion_L1(p_t, y_t)

            # Overall generator loss and update
            loss = loss_L1 + opt.lambda_flow * loss_flow + opt.lambda_flow_long * loss_flow_long + opt.lambda_gan * loss_G
            loss.backward()
            optimizer_G.step()
            # BUG FIX: loss.backward() also deposits generator-adversarial
            # gradients into the discriminator (loss_G passes through D on the
            # non-detached p_t). The original then stepped D on that mixture,
            # training D to call fakes real. Zero D's gradients first so the
            # step uses only loss_D.
            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + iteration
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log (assumes opt.iter_frames >= 3 so the flow losses are
            # tensors — with fewer frames .item() on the int 0 would fail)
            print("\r[Epoch %d/%d] [Batch %d/%d] [L1 Loss: %.4f] [Flow Loss Short: %.8f] [Flow Loss Long: %.8f] [G Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                  % ((epoch + 1), opt.epochs, iteration, len(dataloader),
                     loss_L1.item(), loss_flow.item(), loss_flow_long.item(),
                     loss_G.item(), loss_D.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generatorNet)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)
def Train_single(opt):
    """Train the recurrent colorization generator on single frames with an LSGAN objective.

    Fixes over the previous revision:
      * The discriminator is now updated strictly before the generator in each
        iteration. Previously both optimizers were zeroed up front and both
        losses back-propagated before either `step()`, so `loss.backward()`
        (the generator's adversarial loss) leaked gradients into the
        discriminator parameters, which `optimizer_D.step()` then applied —
        pushing D toward labelling fakes as valid.
      * `p_t_last` is sized from the actual batch (`x_t.shape[0]`) instead of
        `opt.batch_size`, so the last, possibly smaller, batch of an epoch no
        longer crashes (the DataLoader does not use `drop_last`).
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()
    # Initialize Generator
    generatorNet = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    # To device
    if opt.multi_gpu:
        generatorNet = nn.DataParallel(generatorNet)
        generatorNet = generatorNet.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        discriminator = discriminator.cuda()
        generatorNet = generatorNet.cuda()
    # Optimizers (weight decay is applied to the generator only, as before)
    optimizer_G = torch.optim.Adam(generatorNet.parameters(), lr = opt.lr_g,
                                   betas = (opt.b1, opt.b2), weight_decay = opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = opt.lr_d,
                                   betas = (opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Decay the LR by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        # (or every "lr_decrease_iter" iterations, depending on the mode).
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator.module, 'Pre_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator.module, 'Pre_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator, 'Pre_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator, 'Pre_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the class list
    imglist = utils.text_readlines('ILSVRC2012_train_sal_name.txt')[:1272480]
    # Define the dataset
    trainset = dataset.ColorizationDataset(opt, imglist)
    print('The overall number of classes:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size = opt.batch_size, shuffle = True,
                            num_workers = opt.num_workers, pin_memory = True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Tensor type
    Tensor = torch.cuda.FloatTensor
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for iteration, (x_t, y_t) in enumerate(dataloader):
            lstm_state = None
            x_t = x_t.cuda()
            y_t = y_t.cuda()
            # Adversarial ground truths (PatchGAN output is 30x30)
            valid = Tensor(np.ones((x_t.shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((x_t.shape[0], 1, 30, 30)))
            # Single-frame training: the "previous prediction" input is all-zero.
            # Sized from the actual batch so the final smaller batch works too.
            p_t_last = torch.zeros(x_t.shape[0], opt.out_channels, opt.resize_h, opt.resize_w).cuda()
            # Generator output
            p_t, lstm_state = generatorNet(x_t, p_t_last, lstm_state)
            # Train Discriminator (detached fakes so no gradient reaches G)
            optimizer_D.zero_grad()
            # Fake samples
            fake_scalar = discriminator(x_t, p_t.detach())
            loss_fake = criterion_MSE(fake_scalar, fake)
            # True samples
            true_scalar = discriminator(x_t, y_t)
            loss_true = criterion_MSE(true_scalar, valid)
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_D.step()
            # Train Generator
            optimizer_G.zero_grad()
            # GAN Loss
            fake_scalar = discriminator(x_t, p_t)
            loss_G = criterion_MSE(fake_scalar, valid)
            # Pixel-level loss
            loss_L1 = criterion_L1(p_t, y_t)
            # Overall Loss and optimize
            loss = loss_L1 + opt.lambda_gan * loss_G
            loss.backward()
            optimizer_G.step()
            # Determine approximate time left
            iters_done = epoch * len(dataloader) + iteration
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [L1 Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f] Time_left: %s" %
                  ((epoch + 1), opt.epochs, iteration, len(dataloader),
                   loss_L1.item(), loss_G.item(), loss_D.item(), time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generatorNet)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)
def WGAN_trainer(opt):
    """Train a two-stage deepfillv2 inpainting generator with a hinge-GAN discriminator.

    Builds generator / discriminator / perceptual networks, optionally resumes
    from saved checkpoints, then alternates D and G updates over free-form
    masked images, saving sample grids and per-epoch checkpoints.
    NOTE(review): despite the function name, the adversarial loss below is the
    hinge loss, not a Wasserstein estimate with clipping/penalty — confirm intent.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations: create output folders on demand
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # Loss functions (MSELoss is built but unused below)
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the two-stage generator model
    def save_model_generator(net, epoch, opt):
        """Save the generator at "checkpoint_interval" and its multiples."""
        model_name = 'deepfillv2_WGAN_G_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                # DataParallel wraps the real module; save the inner state dict
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' % (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' % (epoch))

    # Save the discriminator model
    def save_model_discriminator(net, epoch, opt):
        """Save the discriminator at "checkpoint_interval" and its multiples."""
        model_name = 'deepfillv2_WGAN_D_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' % (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' % (epoch))

    # load the model
    def load_model(net, epoch, opt, type='G'):
        """Load a checkpoint saved by the two savers above (type 'G' or 'D')."""
        if type == 'G':
            model_name = 'deepfillv2_WGAN_G_epoch%d_batchsize%d.pth' % (
                epoch, opt.batch_size)
        else:
            model_name = 'deepfillv2_WGAN_D_epoch%d_batchsize%d.pth' % (
                epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        pretrained_dict = torch.load(model_name)
        net.load_state_dict(pretrained_dict)

    # Resume from checkpoints before wrapping in DataParallel (state dicts
    # were saved from the bare module)
    if opt.resume:
        load_model(generator, opt.resume_epoch, opt, type='G')
        load_model(discriminator, opt.resume_epoch, opt, type='D')
        print(
            '--------------------Pretrained Models are Loaded--------------------'
        )
    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = train_dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader; drop_last keeps every batch at full batch_size,
    # which the per-sample mask loop below relies on
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.num_workers,
                            pin_memory=True,
                            drop_last=True)
    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor
    # Training loop
    for epoch in range(opt.resume_epoch, opt.epochs):
        for batch_idx, (img, height, width) in enumerate(dataloader):
            img = img.cuda()
            # set the same free form masks for each batch
            mask = torch.empty(img.shape[0], 1, img.shape[2], img.shape[3]).cuda()
            for i in range(opt.batch_size):
                mask[i] = torch.from_numpy(
                    train_dataset.InpaintDataset.random_ff_mask(
                        shape=(height[0], width[0])).astype(np.float32)).cuda()
            # LSGAN-style target grids sized to the discriminator output
            # (one score per 32x32 patch)
            valid = Tensor(
                np.ones((img.shape[0], 1, height[0] // 32, width[0] // 32)))
            fake = Tensor(
                np.zeros((img.shape[0], 1, height[0] // 32, width[0] // 32)))
            zero = Tensor(
                np.zeros((img.shape[0], 1, height[0] // 32, width[0] // 32)))
            ### Train Discriminator
            optimizer_d.zero_grad()
            # Generator output
            first_out, second_out = generator(img, mask)  # forward propagation
            # Compose predictions with the unmasked original pixels
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [0, 1]
            # Fake samples (detached: D update must not reach the generator)
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)
            # Hinge adversarial loss: max(0, 1 + D(fake)) and max(0, 1 - D(real))
            loss_fake = -torch.mean(torch.min(zero, -valid - fake_scalar))
            loss_true = -torch.mean(torch.min(zero, -valid + true_scalar))
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()
            ### Train Generator
            optimizer_g.zero_grad()
            # L1 reconstruction losses on both refinement stages
            first_L1Loss = (first_out - img).abs().mean()
            second_L1Loss = (second_out - img).abs().mean()
            # GAN Loss (non-detached composite, so gradient reaches G)
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)
            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_featuremaps = perceptualnet(second_out)
            second_PerceptualLoss = L1Loss(second_out_featuremaps,
                                           img_featuremaps)
            # Compute losses
            loss = opt.lambda_l1 * first_L1Loss + opt.lambda_l1 * second_L1Loss + \
                opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()
            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_L1Loss.item(), second_L1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))
            # Visualization tensors: masked input, then 3-channel mask for PNG dumps
            masked_img = img * (1 - mask) + mask
            mask = torch.cat((mask, mask, mask), 1)
            if (batch_idx + 1) % 40 == 0:
                img_list = [img, mask, masked_img, first_out, second_out]
                name_list = [
                    'gt', 'mask', 'masked_img', 'first_out', 'second_out'
                ]
                utils.save_sample_png(sample_folder=sample_folder,
                                      sample_name='epoch%d' % (epoch + 1),
                                      img_list=img_list,
                                      name_list=name_list,
                                      pixel_max_cnt=255)
        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)
        # Save the model
        save_model_generator(generator, (epoch + 1), opt)
        save_model_discriminator(discriminator, (epoch + 1), opt)
        ### Sample data every epoch (re-uses the last batch's tensors)
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
def Continue_train_WGAN(opt):
    """Train a DeblurGANv1-style generator with a WGAN critic, plus per-epoch validation.

    Fixes over the previous revision:
      * The discriminator inner loop called `loss_D.backward()` but never
        `optimizer_D.step()`, so the critic was never actually updated; the
        missing step is restored (mirroring `Continue_train_LSGAN`).
      * The multi-GPU epoch-mode checkpoint name used `opt.batch_size`, while
        every other branch — and this function's own batch handling — uses
        `opt.train_batch_size`; the filename now uses `opt.train_batch_size`
        consistently.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    utils.check_path(save_folder)
    utils.check_path(sample_folder)
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    # Optimizers (weight decay on the generator only, as before)
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease: linear decay from lr_g toward 0 after
    # "lr_decrease_epoch" epochs
    def adjust_learning_rate(opt, epoch, optimizer):
        target_epoch = opt.epochs - opt.lr_decrease_epoch
        remain_epoch = opt.epochs - epoch
        if epoch >= opt.lr_decrease_epoch:
            lr = opt.lr_g * remain_epoch / target_epoch
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(
                        generator.module.state_dict(),
                        'DeblurGANv1_wgan_epoch%d_bs%d.pth' %
                        (epoch, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(
                        generator.module.state_dict(),
                        'DeblurGANv1_wgan_iter%d_bs%d.pth' %
                        (iteration, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(
                        generator.state_dict(),
                        'DeblurGANv1_wgan_epoch%d_bs%d.pth' %
                        (epoch, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(
                        generator.state_dict(),
                        'DeblurGANv1_wgan_iter%d_bs%d.pth' %
                        (iteration, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Handle multiple GPUs: scale batch size and workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Define the dataset
    trainset = dataset.DeblurDataset(opt, 'train')
    valset = dataset.DeblurDataset(opt, 'val')
    print('The overall number of training images:', len(trainset))
    print('The overall number of validation images:', len(valset))
    # Define the dataloader
    train_loader = DataLoader(trainset,
                              batch_size=opt.train_batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers,
                              pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=opt.val_batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(train_loader):
            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            # Train Discriminator (critic) several times per generator step
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()
                # Generator output (detached below so no gradient reaches G)
                fake_target = generator(true_input)
                # Fake samples
                fake_scalar_d = discriminator(true_input, fake_target.detach())
                true_scalar_d = discriminator(true_input, true_target)
                # WGAN critic loss: maximize D(real) - D(fake)
                loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
                loss_D.backward()
                # Apply the critic update (previously missing, so D never learned)
                optimizer_D.step()
            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)
            # GAN Loss
            fake_scalar = discriminator(true_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)
            # Overall Loss and optimize
            loss = opt.lambda_l1 * Pixellevel_L1_Loss + GAN_Loss
            loss.backward()
            optimizer_G.step()
            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(train_loader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(),
                   time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(train_loader),
                       generator)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D)
        ### Sample data every epoch (last training batch)
        if (epoch + 1) % 1 == 0:
            img_list = [fake_target, true_target]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
        ### Validation
        val_PSNR = 0
        num_of_val_image = 0
        for j, (true_input, true_target) in enumerate(val_loader):
            # To device
            # A is for input image, B is for target image
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            # Forward propagation
            with torch.no_grad():
                fake_target = generator(true_input)
            # Accumulate num of image and val_PSNR (sample-weighted average)
            num_of_val_image += true_input.shape[0]
            val_PSNR += utils.psnr(fake_target, true_target, 1) * true_input.shape[0]
        val_PSNR = val_PSNR / num_of_val_image
        ### Sample data every epoch (last validation batch)
        if (epoch + 1) % 1 == 0:
            img_list = [fake_target, true_target]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
        # Record average PSNR
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
def Trainer_GAN(opt):
    """Train a Quad-Bayer-to-RGB generator with a WGAN critic, perceptual and YUV color losses.

    Fix over the previous revision: `optimizer_D` was constructed over
    `generator.parameters()` instead of `discriminator.parameters()`, so the
    discriminator never received updates while the generator was stepped by two
    different Adam optimizers with conflicting objectives. The optimizer now
    targets the discriminator, matching every other trainer in this file.
    """
    # ----------------------------------------
    # Initialization
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Handle multiple GPUs: scale batch size and workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Define the dataset
    train_imglist = utils.get_jpgs(os.path.join(opt.in_path_train))
    val_imglist = utils.get_jpgs(os.path.join(opt.in_path_val))
    train_dataset = dataset.Qbayer2RGB_dataset(opt, 'train', train_imglist)
    val_dataset = dataset.Qbayer2RGB_dataset(opt, 'val', val_imglist)
    print('The overall number of training images:', len(train_imglist))
    print('The overall number of validation images:', len(val_imglist))
    # Define the dataloader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.train_batch_size,
                                               shuffle=True,
                                               num_workers=opt.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.val_batch_size,
                                             shuffle=False,
                                             num_workers=opt.num_workers,
                                             pin_memory=True)
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    class ColorLoss(nn.Module):
        """L1 loss computed in YUV space (BT.601 analog coefficients)."""

        def __init__(self):
            super(ColorLoss, self).__init__()
            self.L1loss = nn.L1Loss()

        def RGB2YUV(self, RGB):
            # Channel 0 = Y (luma), 1 = U, 2 = V
            YUV = RGB.clone()
            YUV[:, 0, :, :] = 0.299 * RGB[:, 0, :, :] + 0.587 * RGB[:, 1, :, :] + 0.114 * RGB[:, 2, :, :]
            YUV[:, 1, :, :] = -0.14713 * RGB[:, 0, :, :] - 0.28886 * RGB[:, 1, :, :] + 0.436 * RGB[:, 2, :, :]
            YUV[:, 2, :, :] = 0.615 * RGB[:, 0, :, :] - 0.51499 * RGB[:, 1, :, :] - 0.10001 * RGB[:, 2, :, :]
            return YUV

        def forward(self, x, y):
            yuv_x = self.RGB2YUV(x)
            yuv_y = self.RGB2YUV(y)
            return self.L1loss(yuv_x, yuv_y)

    yuv_loss = ColorLoss()
    # Optimizers: each network gets its own parameters
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer, lr_gd):
        # Decay lr_gd by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        # (or every "lr_decrease_iter" iterations, depending on the mode)
        if opt.lr_decrease_mode == 'epoch':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = '%s_gan_noise%.3f_epoch%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, epoch, opt.train_batch_size)
        if opt.save_mode == 'iter':
            model_name = '%s_gan_noise%.3f_iter%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, iteration, opt.train_batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name, model_name)
        # Save model
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # Tensorboard
    writer = SummaryWriter()
    # For loop training
    for epoch in range(opt.epochs):
        # Record learning rate
        for param_group in optimizer_G.param_groups:
            writer.add_scalar('data/lr', param_group['lr'], epoch)
            print('learning rate = ', param_group['lr'])
        if epoch == 0:
            iters_done = 0
        ### Training
        for i, (in_img, RGBout_img) in enumerate(train_loader):
            # To device
            # A is for input image, B is for target image
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()
            ## Train Discriminator
            # Forward propagation (detached below so no gradient reaches G)
            out = generator(in_img)
            optimizer_D.zero_grad()
            # Fake samples
            fake_scalar_d = discriminator(in_img, out.detach())
            true_scalar_d = discriminator(in_img, RGBout_img)
            # WGAN critic loss: maximize D(real) - D(fake)
            loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D.backward()
            #torch.nn.utils.clip_grad_norm(discriminator.parameters(), opt.grad_clip_norm)
            optimizer_D.step()
            ## Train Generator
            # Forward propagation
            out = generator(in_img)
            # GAN loss
            fake_scalar = discriminator(in_img, out)
            L_gan = -torch.mean(fake_scalar) * opt.lambda_gan
            # Perceptual loss features
            fake_B_fea = perceptualnet(utils.normalize_ImageNet_stats(out))
            true_B_fea = perceptualnet(
                utils.normalize_ImageNet_stats(RGBout_img))
            L_percep = opt.lambda_percep * criterion_L1(fake_B_fea, true_B_fea)
            # Pixel loss
            L_pixel = opt.lambda_pixel * criterion_L1(out, RGBout_img)
            # Color loss
            L_color = opt.lambda_color * yuv_loss(out, RGBout_img)
            # Sum up to total loss
            loss = L_pixel + L_percep + L_gan + L_color
            # Record losses
            writer.add_scalar('data/L_pixel', L_pixel.item(), iters_done)
            writer.add_scalar('data/L_percep', L_percep.item(), iters_done)
            writer.add_scalar('data/L_color', L_color.item(), iters_done)
            writer.add_scalar('data/L_gan', L_gan.item(), iters_done)
            writer.add_scalar('data/L_total', loss.item(), iters_done)
            writer.add_scalar('data/loss_D', loss_D.item(), iters_done)
            # Backpropagate gradients
            optimizer_G.zero_grad()
            loss.backward()
            #torch.nn.utils.clip_grad_norm(generator.parameters(), opt.grad_clip_norm)
            optimizer_G.step()
            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i + 1
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] [L_pixel: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(train_loader), loss.item(),
                   L_pixel.item()))
            print(
                "\r[L_percep: %.4f] [L_color: %.4f] [L_gan: %.4f] [loss_D: %.4f] Time_left: %s"
                % (L_percep.item(), L_color.item(), L_gan.item(),
                   loss_D.item(), time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), iters_done, len(train_loader),
                       generator)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_G,
                                 opt.lr_g)
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_D,
                                 opt.lr_d)
        ### Sample data every epoch (last training batch)
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
        ### Validation
        val_PSNR = 0
        num_of_val_image = 0
        for j, (in_img, RGBout_img) in enumerate(val_loader):
            # To device
            # A is for input image, B is for target image
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()
            # Forward propagation
            with torch.no_grad():
                out = generator(in_img)
            # Accumulate num of image and val_PSNR (sample-weighted average)
            num_of_val_image += in_img.shape[0]
            val_PSNR += utils.psnr(out, RGBout_img, 1) * in_img.shape[0]
        val_PSNR = val_PSNR / num_of_val_image
        ### Sample data every epoch (last validation batch)
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
        # Record average PSNR
        writer.add_scalar('data/val_PSNR', val_PSNR, epoch)
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
    writer.close()
def Continue_train_LSGAN(opt):
    """Train an image-to-image generator with an LSGAN (MSE) adversarial objective.

    The discriminator is updated `opt.additional_training_d` times per batch,
    then the generator is updated with L1 + lambda_gan * LSGAN loss.
    NOTE(review): only `optimizer_G`'s learning rate is decayed at the bottom of
    the loop; `optimizer_D` keeps its initial LR — confirm this is intentional.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()
    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    # Optimizers (weight decay on the generator only)
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Decay the LR by "lr_decrease_factor" every "lr_decrease_epoch"
        # epochs (or every "lr_decrease_iter" iterations, depending on mode)
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        # DataParallel wraps the real module; save the inner one
                        torch.save(
                            generator.module, 'LSGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'LSGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'LSGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'LSGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # Tensor type
    Tensor = torch.cuda.FloatTensor
    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.NormalRGBDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):
            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            # Adversarial ground truth (PatchGAN output is 30x30;
            # sized from the actual batch, so smaller last batches work)
            valid = Tensor(np.ones((true_input.shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((true_input.shape[0], 1, 30, 30)))
            # Train Discriminator several times per generator step
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()
                # Generator output (detached so no gradient reaches G)
                fake_target = generator(true_input)
                # Fake samples
                fake_scalar_d = discriminator(true_input, fake_target.detach())
                loss_fake = criterion_MSE(fake_scalar_d, fake)
                # True samples
                true_scalar_d = discriminator(true_input, true_target)
                loss_true = criterion_MSE(true_scalar_d, valid)
                # Overall Loss and optimize
                loss_D = 0.5 * (loss_fake + loss_true)
                loss_D.backward()
                optimizer_D.step()
            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)
            # GAN Loss: push D's score on fakes toward "valid"
            fake_scalar = discriminator(true_input, fake_target)
            GAN_Loss = criterion_MSE(fake_scalar, valid)
            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_G.step()
            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(),
                   time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)
            # Learning rate decrease at certain epochs (generator only)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
def Trainer(opt):
    """Train a context-encoder style inpainting generator with a WGAN critic.

    The generator fills the masked region of each image. The discriminator
    (critic) is trained with the WGAN loss -E[D(real)] + E[D(fake)]; the
    generator minimizes a masked L1 reconstruction loss plus the adversarial
    term weighted by ``opt.gan_param``. The generator is checkpointed every
    ``opt.checkpoint_interval`` epochs.

    NOTE(review): as in the original code, no Lipschitz constraint (weight
    clipping / gradient penalty) is applied to the critic.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network when input sizes are fixed
    if opt.cudnn_benchmark == True:
        cudnn.benchmark = True
    else:
        cudnn.benchmark = False

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    # BUGFIX: this optimizer previously wrapped generator.parameters(),
    # so the discriminator weights could never be updated.
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate_g(optimizer, epoch, opt):
        """Set the generator LR to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs."""
        lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_learning_rate_d(optimizer, epoch, opt):
        """Set the discriminator LR to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs."""
        lr = opt.lr_d * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple."""
        if epoch % opt.checkpoint_interval == 0:
            # Unwrap the DataParallel container so the checkpoint can be
            # loaded without multi-GPU.
            net_to_save = net.module if opt.multi_gpu == True else net
            torch.save(net_to_save,
                       'ContextureEncoder_epoch%d_batchsize%d.pth' % (epoch, opt.batch_size))
            print('The trained model is successfully saved at epoch %d' % (epoch))

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    # Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):
            # img: [B, 3, H, W]; mask: [B, 1, H, W]; put both on cuda
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()
            # Generator output on the masked image
            masked_img = img * (1 - mask)
            fake = generator(masked_img)
            # WGAN critic loss: maximize D(real) - D(fake)
            fake_scalar = discriminator(fake.detach())
            true_scalar = discriminator(img)
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            # BUGFIX: the discriminator optimizer step was missing entirely.
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()
            # Composite: keep the known pixels, fill the hole with the prediction
            fusion_fake = img * (1 - mask) + fake * mask  # in range [-1, 1]
            # Mask L1 Loss
            MaskL1Loss = L1Loss(fusion_fake, img)
            # GAN Loss
            fake_scalar = discriminator(fusion_fake)
            GAN_Loss = -torch.mean(fake_scalar)
            # Compute losses
            loss = MaskL1Loss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s" %
                  ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), loss_D.item(), GAN_Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate_g(optimizer_g, (epoch + 1), opt)
        adjust_learning_rate_d(optimizer_d, (epoch + 1), opt)
        # Save the model
        save_model(generator, (epoch + 1), opt)
def WGAN_trainer(opt):
    """Train a two-stage (coarse + refinement) deepfillv2-style inpainting
    generator with a hinge adversarial loss and a perceptual feature loss.

    NOTE(review): despite the name, the discriminator objective actually used
    below is the hinge loss; the WGAN term is commented out.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    # Keep OpenCV single-threaded so it does not compete with dataloader workers
    cv2.setNumThreads(0)
    cv2.ocl.setUseOpenCL(False)
    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Loss functions
    L1Loss = nn.L1Loss()  #reduce=False, size_average=False)
    RELU = nn.ReLU()
    # Optimizers
    # optimizer_g1 updates only the coarse sub-network; optimizer_g updates the
    # whole generator, so coarse parameters receive two steps per batch.
    optimizer_g1 = torch.optim.Adam(generator.coarse.parameters(), lr=opt.lr_g)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr = opt.lr_d)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt, batch=0, is_D=False):
        """Save the model at "checkpoint_interval" and its multiple"""
        if is_D==True:
            model_name = 'discriminator_WGAN_epoch%d_batch%d.pth' % (epoch + 1, batch)
        else:
            model_name = 'deepfillv2_WGAN_epoch%d_batch%d.pth' % (epoch+1, batch)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                # Save the underlying module so the state_dict loads without DataParallel
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size = opt.batch_size, shuffle = True,
                            num_workers = opt.num_workers, pin_memory = True)

    # ----------------------------------------
    # Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        print("Start epoch ", epoch+1, "!")
        for batch_idx, (img, mask) in enumerate(dataloader):
            # Load mask (shape: [B, 1, H, W]), masked_img (shape: [B, 3, H, W]),
            # img (shape: [B, 3, H, W]) and put it to cuda
            img = img.cuda()
            mask = mask.cuda()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation: keep known pixels, fill hole with prediction
            first_out_wholeimg = img * (1 - mask) + first_out * mask        # in range [0, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask      # in range [0, 1]

            # Train Discriminator (one critic iteration per batch)
            for wk in range(1):
                optimizer_d.zero_grad()
                fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
                true_scalar = discriminator(img, mask)
                #W_Loss = -torch.mean(true_scalar) + torch.mean(fake_scalar)#+ gradient_penalty(discriminator, img, second_out_wholeimg, mask)
                # Hinge loss: push real scores above +1 and fake scores below -1
                hinge_loss = torch.mean(RELU(1-true_scalar)) + torch.mean(RELU(fake_scalar+1))
                loss_D = hinge_loss
                loss_D.backward(retain_graph=True)
                optimizer_d.step()

            ### Train Generator
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = - torch.mean(fake_scalar)
            # NOTE(review): the coarse stage is stepped on its L1 loss alone,
            # *before* the whole-generator step below; retain_graph=True keeps
            # the graph alive for the combined backward. Statement order here
            # is load-bearing — do not reorder.
            optimizer_g1.zero_grad()
            first_MaskL1Loss.backward(retain_graph=True)
            optimizer_g1.step()
            optimizer_g.zero_grad()
            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)
            # Combined generator objective
            loss = 0.5*opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + GAN_Loss + second_PerceptualLoss * opt.lambda_perceptual
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]" %
                  ((epoch + 1), opt.epochs, (batch_idx+1), len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print("\r[D Loss: %.5f] [Perceptual Loss: %.5f] [G Loss: %.5f] time_left: %s" %
                  (loss_D.item(), second_PerceptualLoss.item(), GAN_Loss.item(), time_left))

            if (batch_idx + 1) % 100 ==0:
                # Generate Visualization image
                masked_img = img * (1 - mask) + mask
                img_save = torch.cat((img, masked_img, first_out, second_out, first_out_wholeimg, second_out_wholeimg),3)
                # Recover normalization: * 255 because last layer is sigmoid activated
                img_save = F.interpolate(img_save, scale_factor=0.5)
                img_save = img_save * 255
                # Process img_copy and do not destroy the data of img
                img_copy = img_save.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
                #img_copy = np.clip(img_copy, 0, 255)
                img_copy = img_copy.astype(np.uint8)
                save_img_name = 'sample_batch' + str(batch_idx+1) + '.png'
                save_img_path = os.path.join(sample_folder, save_img_name)
                # OpenCV expects BGR channel order
                img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
                cv2.imwrite(save_img_path, img_copy)

            # Mid-epoch checkpoints every 5000 batches
            if (batch_idx + 1) % 5000 == 0:
                save_model(generator, epoch, opt, batch_idx+1)
                save_model(discriminator, epoch, opt, batch_idx+1, is_D=True)

        #Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)
        # Save the model
        save_model(generator, epoch, opt)
        save_model(discriminator, epoch , opt, is_D=True)
def Trainer_WGAN(opt):
    """DRIT-style unpaired image-to-image translation trainer with WGAN critics.

    Two generators (one per domain) each encode an image into a content code
    and a style code; images are reconstructed, translated with a sampled
    style prior, and cycle-translated. Both discriminators use the plain WGAN
    loss. NOTE(review): no Lipschitz constraint (clipping / gradient penalty)
    is applied to the critics here.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize Generator
    generator_a, generator_b = utils.create_generator(opt)
    discriminator_a, discriminator_b = utils.create_discriminator(opt)
    # To device
    if opt.multi_gpu:
        generator_a = nn.DataParallel(generator_a)
        generator_a = generator_a.cuda()
        generator_b = nn.DataParallel(generator_b)
        generator_b = generator_b.cuda()
        discriminator_a = nn.DataParallel(discriminator_a)
        discriminator_a = discriminator_a.cuda()
        discriminator_b = nn.DataParallel(discriminator_b)
        discriminator_b = discriminator_b.cuda()
    else:
        generator_a = generator_a.cuda()
        generator_b = generator_b.cuda()
        discriminator_a = discriminator_a.cuda()
        discriminator_b = discriminator_b.cuda()
    # Optimizers
    # A single optimizer spans both generators: they are trained jointly.
    optimizer_G = torch.optim.Adam(itertools.chain(generator_a.parameters(), generator_b.parameters()),
                                   lr=opt.lr_g, betas=(opt.b1, opt.b2), weight_decay=opt.weight_decay)
    optimizer_D_a = torch.optim.Adam(discriminator_a.parameters(), lr=opt.lr_d, betas=(opt.b1, opt.b2))
    optimizer_D_b = torch.optim.Adam(discriminator_b.parameters(), lr=opt.lr_d, betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        #Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        # NOTE(review): both modes derive the LR from opt.lr_g, so the
        # discriminator optimizers are also reset to a generator-based LR —
        # confirm this is intended.
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator_a, generator_b):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator_a.module, 'WGAN_DRIT_epoch%d_bs%d_a.pth' % (epoch, opt.batch_size))
                        torch.save(generator_b.module, 'WGAN_DRIT_epoch%d_bs%d_b.pth' % (epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator_a.module, 'WGAN_DRIT_iter%d_bs%d_a.pth' % (iteration, opt.batch_size))
                        torch.save(generator_b.module, 'WGAN_DRIT_iter%d_bs%d_b.pth' % (iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator_a, 'WGAN_DRIT_epoch%d_bs%d_a.pth' % (epoch, opt.batch_size))
                        torch.save(generator_b, 'WGAN_DRIT_epoch%d_bs%d_b.pth' % (epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator_a, 'WGAN_DRIT_iter%d_bs%d_a.pth' % (iteration, opt.batch_size))
                        torch.save(generator_b, 'WGAN_DRIT_iter%d_bs%d_b.pth' % (iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    dataloader = utils.create_dataloader(opt)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for i, (img_a, img_b) in enumerate(dataloader):
            # To device
            img_a = img_a.cuda()
            img_b = img_b.cuda()
            # Sampled style codes (prior)
            prior_s_a = torch.randn(img_a.shape[0], opt.style_dim).cuda()
            prior_s_b = torch.randn(img_a.shape[0], opt.style_dim).cuda()

            # ----------------------------------------
            # Train Generator
            # ----------------------------------------
            # Note that:
            # input / output image dimension: [B, 3, 256, 256]
            # content_code dimension: [B, 256, 64, 64]
            # style_code dimension: [B, 8]
            # generator_a is related to domain a / style a
            # generator_b is related to domain b / style b
            optimizer_G.zero_grad()
            # Get shared latent representation
            c_a, s_a = generator_a.encode(img_a)
            c_b, s_b = generator_b.encode(img_b)
            # Reconstruct images (same domain, own style)
            img_aa_recon = generator_a.decode(c_a, s_a)
            img_bb_recon = generator_b.decode(c_b, s_b)
            # Translate images (other domain's content, sampled style)
            img_ba = generator_a.decode(c_b, prior_s_a)
            img_ab = generator_b.decode(c_a, prior_s_b)
            # Cycle code translation: re-encode translated images
            c_b_recon, s_a_recon = generator_a.encode(img_ba)
            c_a_recon, s_b_recon = generator_b.encode(img_ab)
            # Cycle image translation (only when the cycle loss is active)
            img_aa_recon_cycle = generator_a.decode(c_a_recon, s_a) if opt.lambda_cycle > 0 else 0
            img_bb_recon_cycle = generator_b.decode(c_b_recon, s_b) if opt.lambda_cycle > 0 else 0
            # Losses
            loss_id_1 = opt.lambda_id * criterion_L1(img_aa_recon, img_a)
            loss_id_2 = opt.lambda_id * criterion_L1(img_bb_recon, img_b)
            loss_s_1 = opt.lambda_style * criterion_L1(s_a_recon, prior_s_a)
            loss_s_2 = opt.lambda_style * criterion_L1(s_b_recon, prior_s_b)
            # Content codes are compared against detached targets so the
            # content loss only shapes the re-encoding path.
            loss_c_1 = opt.lambda_content * criterion_L1(c_a_recon, c_a.detach())
            loss_c_2 = opt.lambda_content * criterion_L1(c_b_recon, c_b.detach())
            loss_cycle_1 = opt.lambda_cycle * criterion_L1(img_aa_recon_cycle, img_a) if opt.lambda_cycle > 0 else 0
            loss_cycle_2 = opt.lambda_cycle * criterion_L1(img_bb_recon_cycle, img_b) if opt.lambda_cycle > 0 else 0
            # GAN Loss (WGAN generator term)
            fake_scalar_a = discriminator_a(img_ba)
            fake_scalar_b = discriminator_b(img_ab)
            loss_gan1 = -opt.lambda_gan * torch.mean(fake_scalar_a)
            loss_gan2 = -opt.lambda_gan * torch.mean(fake_scalar_b)
            # Overall Losses and optimization
            loss_G = loss_id_1 + loss_id_2 + loss_s_1 + loss_s_2 + loss_c_1 + loss_c_2 + loss_cycle_1 + loss_cycle_2 + loss_gan1 + loss_gan2
            loss_G.backward()
            optimizer_G.step()

            # ----------------------------------------
            # Train Discriminator
            # ----------------------------------------
            optimizer_D_a.zero_grad()
            optimizer_D_b.zero_grad()
            # D_a: WGAN critic loss, translated images detached from G
            fake_scalar_a = discriminator_a(img_ba.detach())
            true_scalar_a = discriminator_a(img_a)
            loss_D_a = torch.mean(fake_scalar_a) - torch.mean(true_scalar_a)
            loss_D_a.backward()
            optimizer_D_a.step()
            # D_b
            fake_scalar_b = discriminator_b(img_ab.detach())
            true_scalar_b = discriminator_b(img_b)
            loss_D_b = torch.mean(fake_scalar_b) - torch.mean(true_scalar_b)
            loss_D_b.backward()
            optimizer_D_b.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] [Style Loss: %.4f] [Content Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f] Time_left: %s" %
                  ((epoch + 1), opt.epochs, i, len(dataloader),
                   (loss_id_1 + loss_id_2).item(), (loss_s_1 + loss_s_2).item(),
                   (loss_c_1 + loss_c_2).item(), (loss_gan1 + loss_gan2).item(),
                   (loss_D_a + loss_D_b).item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator_a, generator_b)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D_a)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D_b)
def LSGAN_trainer(opt):
    """Train a two-stage (coarse/refine) inpainting generator with LSGAN.

    The discriminator is trained with the least-squares adversarial loss
    against fixed real/fake label maps; the generator minimizes a masked L1
    loss for both stages, a perceptual (feature) L1 loss, and the LSGAN term
    weighted by ``opt.gan_param``. The generator is checkpointed every
    ``opt.checkpoint_interval`` epochs.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network for fixed-size inputs
    if opt.cudnn_benchmark == True:
        cudnn.benchmark = True
    else:
        cudnn.benchmark = False

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    #FeatureMatchingLoss = FML1Loss(opt.fm_param)

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    # BUGFIX: this optimizer previously wrapped generator.parameters(), so
    # the discriminator weights were never updated.
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs."""
        lr = lr_in * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple."""
        if epoch % opt.checkpoint_interval == 0:
            # Unwrap DataParallel so the checkpoint loads without multi-GPU
            net_to_save = net.module if opt.multi_gpu == True else net
            torch.save(net_to_save,
                       'deepfillNet_epoch%d_batchsize%d.pth' % (epoch, opt.batch_size))
            print('The trained model is successfully saved at epoch %d' % (epoch))

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    # Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):
            # img: [B, 3, H, W]; mask: [B, 1, H, W]; put both on cuda
            img = img.cuda()
            mask = mask.cuda()

            # LSGAN label maps — the discriminator outputs an 8x8 patch map
            # (assumes fixed input resolution; TODO confirm against D arch)
            valid = Tensor(np.ones((img.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((img.shape[0], 1, 8, 8)))

            ### Train Discriminator
            optimizer_d.zero_grad()
            # Generator output
            first_out, second_out = generator(img, mask)
            # forward propagation: composite known pixels with predictions
            first_out_wholeimg = img * (1 - mask) + first_out * mask    # in range [-1, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask  # in range [-1, 1]
            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)
            # Least-squares adversarial loss
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = MSELoss(fake_scalar, valid)
            # Get the deep semantic feature maps, and compute Perceptual Loss.
            # Map [-1, 1] -> [0, 1] and apply ImageNet normalization before
            # feeding the perceptual network.
            img_norm = (img + 1) / 2
            img_norm = utils.normalize_ImageNet_stats(img_norm)
            img_featuremaps = perceptualnet(img_norm)
            second_norm = (second_out_wholeimg + 1) / 2
            second_norm = utils.normalize_ImageNet_stats(second_norm)
            second_out_wholeimg_featuremaps = perceptualnet(second_norm)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)
            # Compute losses
            loss = first_MaskL1Loss + second_MaskL1Loss + \
                opt.perceptual_param * second_PerceptualLoss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]" %
                  ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print("\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s" %
                  (loss_D.item(), GAN_Loss.item(), second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)
        # Save the model
        save_model(generator, (epoch + 1), opt)
def Inpainting(opt):
    """WGAN domain-transfer trainer with RMSprop and weight clipping.

    Fixes relative to the original implementation:
    - ``torch.optim.RMSprop`` does not accept a ``betas`` argument (it raised
      a TypeError on construction);
    - the dataloader yields ``(img, target)`` pairs which must be unpacked as
      a tuple inside ``enumerate``;
    - the target image was shadowed by the adversarial label tensor (both
      were named ``valid``), so the reconstruction loss targeted a tensor of
      ones instead of the real image;
    - the discriminator was zero-grad'ed but never given a loss or a step; a
      standard WGAN critic update with weight clipping is now performed
      before each generator step (clipping previously ran after the G step,
      on a critic that never learned);
    - a duplicated ``G_I = G_I.cuda()`` and an unused LSGAN MSE adversarial
      term were removed.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    # Initialize Generator
    G_I = utils.create_generator(opt)
    D_I = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        G_I = nn.DataParallel(G_I)
        G_I = G_I.cuda()
        D_I = nn.DataParallel(D_I)
        D_I = D_I.cuda()
    else:
        G_I = G_I.cuda()
        D_I = D_I.cuda()

    # Optimizers (RMSprop is the conventional choice for weight-clipped WGAN;
    # it takes no betas parameter)
    optimizer_G = torch.optim.RMSprop(G_I.parameters(), lr=opt.lr_g,
                                      weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.RMSprop(D_I.parameters(), lr=opt.lr_d)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        #Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, G_I):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Unwrap DataParallel so the checkpoint loads without multi-GPU
        net = G_I.module if opt.multi_gpu == True else G_I
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                if opt.save_name_mode:
                    torch.save(net, 'G_epoch%d_bs%d.pth' % (epoch, opt.batch_size))
                    print('The trained model is successfully saved at epoch %d' % (epoch))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                if opt.save_name_mode:
                    torch.save(net, 'G_iter%d_bs%d.pth' % (iteration, opt.batch_size))
                    print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    # Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.DomainTransferDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (img, true_target) in enumerate(dataloader):
            # To device
            img = img.cuda()
            true_target = true_target.cuda()

            # ----- Train Discriminator (WGAN critic) -----
            optimizer_D.zero_grad()
            fake_img = G_I(img)
            # Critic maximizes E[D(real)] - E[D(fake)]
            loss_D = -torch.mean(D_I(true_target)) + torch.mean(D_I(fake_img.detach()))
            loss_D.backward()
            optimizer_D.step()
            # Weight clipping enforces the critic's Lipschitz constraint
            for p in D_I.parameters():
                p.data.clamp_(-0.01, 0.01)

            # ----- Train Generator -----
            optimizer_G.zero_grad()
            fake_img = G_I(img)
            # wGAN Loss (generator term)
            loss_wGAN = -torch.mean(D_I(fake_img))
            # Reconstruction Loss(L2)
            loss_rec = criterion_MSE(fake_img, true_target)
            # Overall Loss and optimize
            loss = opt.lambda_rec * loss_rec + opt.lambda_adv * loss_wGAN
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [G wGAN Loss: %.4f] [G rec Loss: %.4f] Time_left: %s" %
                  ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_wGAN.item(), loss_rec.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), G_I)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)