def create_networks(opt, checkpoint=None):
    """Build the generator / discriminator / perceptual networks and move them to GPU.

    Args:
        opt: options namespace consumed by the ``utils.create_*`` factories
             (must provide ``multi_gpu``).
        checkpoint: optional dict with ``'G'`` / ``'D'`` state dicts to restore.

    Returns:
        Tuple ``(generator, discriminator, perceptualnet)``, all on CUDA.
    """
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # Restore weights BEFORE any DataParallel wrapping so the state-dict
    # keys match the bare modules.
    if checkpoint:
        generator.load_state_dict(checkpoint['G'])
        discriminator.load_state_dict(checkpoint['D'])
    # Optionally wrap for multi-GPU, then push everything to CUDA.
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator).cuda()
        discriminator = nn.DataParallel(discriminator).cuda()
        perceptualnet = nn.DataParallel(perceptualnet).cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    return generator, discriminator, perceptualnet
def LSGAN_trainer(opt):
    """Train a two-stage inpainting network with a least-squares GAN objective.

    Alternates discriminator and generator updates over ``opt.epochs`` epochs
    of ``dataset.InpaintDataset``. The discriminator is trained with MSE
    targets on an 8x8 patch output; the generator with masked L1, perceptual
    (VGG-feature L1) and adversarial losses. The generator is checkpointed
    every ``opt.checkpoint_interval`` epochs.

    Args:
        opt: options namespace (lr_g, lr_d, b1, b2, weight_decay, epochs,
             batch_size, num_workers, multi_gpu, cudnn_benchmark,
             perceptual_param, gan_param, lr_decrease_*, checkpoint_interval).
    """
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    if opt.cudnn_benchmark == True:
        cudnn.benchmark = True
    else:
        cudnn.benchmark = False
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    #FeatureMatchingLoss = FML1Loss(opt.fm_param)
    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    # BUGFIX: this optimizer previously wrapped generator.parameters(), so the
    # discriminator was never actually trained. It must update the discriminator.
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                # Unwrap DataParallel so the saved module is reloadable on one GPU
                torch.save(net.module,
                           'deepfillNet_epoch%d_batchsize%d.pth' % (epoch, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' % (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net,
                           'deepfillNet_epoch%d_batchsize%d.pth' % (epoch, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' % (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):
            # mask shape: [B, 1, H, W]; img shape: [B, 3, H, W]
            img = img.cuda()
            mask = mask.cuda()
            # LSGAN targets — the discriminator emits an 8x8 patch map
            valid = Tensor(np.ones((img.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((img.shape[0], 1, 8, 8)))

            ### Train Discriminator
            optimizer_d.zero_grad()
            # Generator output (coarse stage, refined stage)
            first_out, second_out = generator(img, mask)  # forward propagation
            # Compose: ground truth outside the mask, prediction inside
            first_out_wholeimg = img * (1 - mask) + first_out * mask    # in range [-1, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask  # in range [-1, 1]
            # Fake samples (detached so D gradients do not flow into G)
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)
            # Least-squares adversarial loss for D
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()
            # Mask L1 Loss on the composed whole images
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss: generator tries to make D output "valid"
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = MSELoss(fake_scalar, valid)
            # Perceptual Loss: map [-1, 1] -> [0, 1], ImageNet-normalize,
            # then L1 between deep feature maps
            img = (img + 1) / 2  # in range [0, 1]
            img = utils.normalize_ImageNet_stats(img)  # in range of ImageNet
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg = (second_out_wholeimg + 1) / 2  # in range [0, 1]
            second_out_wholeimg = utils.normalize_ImageNet_stats(second_out_wholeimg)
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps,
                                           img_featuremaps)
            # Weighted total generator loss
            loss = first_MaskL1Loss + second_MaskL1Loss \
                + opt.perceptual_param * second_PerceptualLoss \
                + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]" %
                  ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print("\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s" %
                  (loss_D.item(), GAN_Loss.item(), second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)
        # Save the model
        save_model(generator, (epoch + 1), opt)
def Continue_train_WGAN(opt):
    """Train a saliency-attentive RAW-to-RGB generator with a WGAN critic.

    Each iteration trains the critic ``opt.additional_training_d`` times
    (Wasserstein loss), then updates the generator with pixel-level L1,
    saliency-attention L1, adversarial and perceptual losses. Checkpoints
    and LR schedules are driven by ``opt.save_mode`` / ``opt.lr_decrease_mode``.

    Args:
        opt: options namespace (lr_g, lr_d, b1, b2, weight_decay, epochs,
             batch_size, num_workers, multi_gpu, additional_training_d,
             lambda_attn, lambda_gan, lambda_percep, save_* and lr_* knobs).
    """
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Decay the LR by "lr_decrease_factor" every "lr_decrease_epoch"
        # epochs or every "lr_decrease_iter" iterations, depending on mode
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor**(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator.module,
                                   'WGAN_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator.module,
                                   'WGAN_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator,
                                   'WGAN_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator,
                                   'WGAN_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.RAW2RGBDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target, true_sal) in enumerate(dataloader):
            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            true_sal = true_sal.cuda()
            # Replicate the 1-channel saliency map to 3 channels
            true_sal3 = torch.cat((true_sal, true_sal, true_sal), 1).cuda()

            # Train Discriminator (critic), possibly several times per batch
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()
                # Generator output
                fake_target, fake_sal = generator(true_input)
                # Fake samples (detached: no gradient into the generator)
                fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                    true_input, fake_target.detach())
                # True samples
                true_block1, true_block2, true_block3, true_scalar = discriminator(
                    true_input, true_target)
                # (disabled) Feature Matching Loss:
                # FM_Loss = criterion_L1(fake_block1, true_block1) \
                #     + criterion_L1(fake_block2, true_block2) \
                #     + criterion_L1(fake_block3, true_block3)
                # Wasserstein critic loss
                loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
                loss_D.backward()
                # BUGFIX: without this step the discriminator's gradients were
                # computed but its weights were never updated.
                optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target, fake_sal = generator(true_input)
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)
            # Attention Loss: L1 on saliency-weighted images
            true_Attn_target = true_target.mul(true_sal3)
            fake_sal3 = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_Attn_target = fake_target.mul(fake_sal3)
            Attention_Loss = criterion_L1(fake_Attn_target, true_Attn_target)
            # GAN Loss
            fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                true_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)
            # Perceptual Loss: map [-1, 1] -> [0, 1], ImageNet-normalize,
            # L1 between deep features
            fake_target = fake_target * 0.5 + 0.5
            true_target = true_target * 0.5 + 0.5
            fake_target = utils.normalize_ImageNet_stats(fake_target)
            true_target = utils.normalize_ImageNet_stats(true_target)
            fake_percep_feature = perceptualnet(fake_target)
            true_percep_feature = perceptualnet(true_target)
            Perceptual_Loss = criterion_L1(fake_percep_feature, true_percep_feature)
            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + opt.lambda_attn * Attention_Loss \
                + opt.lambda_gan * GAN_Loss + opt.lambda_percep * Perceptual_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f]" %
                  ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item()))
            print("\r[Attention Loss: %.4f] [Perceptual Loss: %.4f] Time_left: %s" %
                  (Attention_Loss.item(), Perceptual_Loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)
def WGAN_trainer(opt):
    """Train a deepfillv2-style two-stage inpainting generator with a hinge-loss critic.

    The coarse sub-network gets a dedicated optimizer driven by the first-stage
    masked L1 loss; the whole generator is then updated with L1 + adversarial +
    perceptual losses. Visual samples are dumped every 100 batches and
    checkpoints every 5000 batches plus at each qualifying epoch.

    Args:
        opt: options namespace (save_path, sample_path, lr_g, lr_d, epochs,
             batch_size, num_workers, multi_gpu, cudnn_benchmark, lambda_l1,
             lambda_perceptual, lr_decrease_*, checkpoint_interval).
    """
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    # Keep OpenCV from spawning threads / using OpenCL inside dataloader workers
    cv2.setNumThreads(0)
    cv2.ocl.setUseOpenCL(False)
    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Loss functions
    L1Loss = nn.L1Loss()  #reduce=False, size_average=False)
    RELU = nn.ReLU()
    # Optimizers
    # BUGFIX: under multi_gpu the generator is wrapped in nn.DataParallel, so
    # its sub-networks must be reached through `.module`; the original
    # `generator.coarse` raised AttributeError in multi-GPU mode.
    coarse_net = generator.module.coarse if opt.multi_gpu == True else generator.coarse
    optimizer_g1 = torch.optim.Adam(coarse_net.parameters(), lr=opt.lr_g)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt, batch=0, is_D=False):
        """Save the model at "checkpoint_interval" and its multiple"""
        # NOTE: the file name embeds epoch+1 while the interval check and the
        # log use the raw `epoch` — preserved for checkpoint-name compatibility
        if is_D == True:
            model_name = 'discriminator_WGAN_epoch%d_batch%d.pth' % (epoch + 1, batch)
        else:
            model_name = 'deepfillv2_WGAN_epoch%d_batch%d.pth' % (epoch + 1, batch)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        print("Start epoch ", epoch + 1, "!")
        for batch_idx, (img, mask) in enumerate(dataloader):
            # mask shape: [B, 1, H, W]; img shape: [B, 3, H, W]
            img = img.cuda()
            mask = mask.cuda()

            # Generator output (coarse stage, refined stage)
            first_out, second_out = generator(img, mask)  # forward propagation
            first_out_wholeimg = img * (1 - mask) + first_out * mask    # in range [0, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask  # in range [0, 1]

            ### Train Discriminator (single critic step; loop kept for easy tuning)
            for wk in range(1):
                optimizer_d.zero_grad()
                fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
                true_scalar = discriminator(img, mask)
                #W_Loss = -torch.mean(true_scalar) + torch.mean(fake_scalar)#+ gradient_penalty(discriminator, img, second_out_wholeimg, mask)
                # Hinge adversarial loss for the critic
                hinge_loss = torch.mean(RELU(1 - true_scalar)) + torch.mean(RELU(fake_scalar + 1))
                loss_D = hinge_loss
                # retain_graph: the generator graph is reused below for G losses
                loss_D.backward(retain_graph=True)
                optimizer_d.step()

            ### Train Generator
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)
            # Dedicated update of the coarse network on the first-stage L1
            optimizer_g1.zero_grad()
            first_MaskL1Loss.backward(retain_graph=True)
            optimizer_g1.step()
            # Full-generator update
            optimizer_g.zero_grad()
            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)
            loss = 0.5 * opt.lambda_l1 * first_MaskL1Loss \
                + opt.lambda_l1 * second_MaskL1Loss \
                + GAN_Loss + second_PerceptualLoss * opt.lambda_perceptual
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]" %
                  ((epoch + 1), opt.epochs, (batch_idx + 1), len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print("\r[D Loss: %.5f] [Perceptual Loss: %.5f] [G Loss: %.5f] time_left: %s" %
                  (loss_D.item(), second_PerceptualLoss.item(), GAN_Loss.item(), time_left))

            if (batch_idx + 1) % 100 == 0:
                # Generate Visualization image
                masked_img = img * (1 - mask) + mask
                img_save = torch.cat((img, masked_img, first_out, second_out,
                                      first_out_wholeimg, second_out_wholeimg), 3)
                # Recover normalization: * 255 because last layer is sigmoid activated
                img_save = F.interpolate(img_save, scale_factor=0.5)
                img_save = img_save * 255
                # Process img_copy and do not destroy the data of img
                img_copy = img_save.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
                #img_copy = np.clip(img_copy, 0, 255)
                img_copy = img_copy.astype(np.uint8)
                save_img_name = 'sample_batch' + str(batch_idx + 1) + '.png'
                save_img_path = os.path.join(sample_folder, save_img_name)
                img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
                cv2.imwrite(save_img_path, img_copy)

            if (batch_idx + 1) % 5000 == 0:
                save_model(generator, epoch, opt, batch_idx + 1)
                save_model(discriminator, epoch, opt, batch_idx + 1, is_D=True)

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)
        # Save the model
        save_model(generator, epoch, opt)
        save_model(discriminator, epoch, opt, is_D=True)
def Continue_train_WGAN(opt):
    """Train a DeblurGANv2-style generator against global and patch WGAN critics.

    NOTE(review): this redefines `Continue_train_WGAN` declared earlier in this
    file and shadows it at import time — consider renaming one of them.

    Each batch trains a full-image critic and a random-crop patch critic with
    the Wasserstein loss, then updates the generator with L1 + perceptual +
    (global + patch) adversarial losses. After each epoch, training samples
    are dumped and PSNR is measured on the validation split.

    Args:
        opt: options namespace (save_path, sample_path, task_name, lr, b1, b2,
             weight_decay, epochs, train/val batch sizes, num_workers,
             crop_size, patch_size, lambda_*, save_* and lr_* knobs).
    """
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    utils.check_path(save_folder)
    utils.check_path(sample_folder)
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize Generator
    generator = utils.create_generator(opt)
    global_discriminator = utils.create_discriminator(opt)
    patch_discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        global_discriminator = nn.DataParallel(global_discriminator)
        global_discriminator = global_discriminator.cuda()
        patch_discriminator = nn.DataParallel(patch_discriminator)
        patch_discriminator = patch_discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        global_discriminator = global_discriminator.cuda()
        patch_discriminator = patch_discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D1 = torch.optim.Adam(global_discriminator.parameters(), lr=opt.lr,
                                    betas=(opt.b1, opt.b2))
    optimizer_D2 = torch.optim.Adam(patch_discriminator.parameters(), lr=opt.lr,
                                    betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        # Linear decay toward (almost) zero after `lr_decrease_epoch`
        target_epoch = opt.epochs - opt.lr_decrease_epoch
        remain_epoch = opt.epochs - epoch
        if epoch >= opt.lr_decrease_epoch:
            lr = opt.lr * remain_epoch / target_epoch
            lr = max(lr, 1e-7)  # floor to keep the LR strictly positive
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = 'DeblurGANv2_%s_WGAN_epoch%d_bs%d.pth' % (
                opt.network_type, epoch, opt.train_batch_size)
        if opt.save_mode == 'iter':
            model_name = 'DeblurGANv2_%s_WGAN_iter%d_bs%d.pth' % (
                opt.network_type, iteration, opt.train_batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name, model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------
    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Define the dataset
    trainset = dataset.DeblurDataset(opt, 'train')
    valset = dataset.DeblurDataset(opt, 'val')
    print('The overall number of training images:', len(trainset))
    print('The overall number of validation images:', len(valset))
    # Define the dataloader
    train_loader = DataLoader(trainset,
                              batch_size=opt.train_batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers,
                              pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=opt.val_batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(train_loader):
            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Train Discriminators
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            # Generator output
            fake_target = generator(true_input)
            # Extract a shared random patch for the patch discriminator
            rand_h = random.randint(0, opt.crop_size - opt.patch_size)
            rand_w = random.randint(0, opt.crop_size - opt.patch_size)
            fake_target_patch = fake_target[:, :, rand_h:rand_h + opt.patch_size,
                                            rand_w:rand_w + opt.patch_size]
            true_target_patch = true_target[:, :, rand_h:rand_h + opt.patch_size,
                                            rand_w:rand_w + opt.patch_size]
            # Global discriminator (Wasserstein critic loss)
            fake_scalar_d = global_discriminator(fake_target.detach())
            true_scalar_d = global_discriminator(true_target)
            loss_D1 = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D1.backward()
            # BUGFIX: the optimizer step was missing — the global critic's
            # gradients were computed but its weights never updated.
            optimizer_D1.step()
            # Patch discriminator
            fake_scalar_d = patch_discriminator(fake_target_patch.detach())
            true_scalar_d = patch_discriminator(true_target_patch)
            loss_D2 = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D2.backward()
            # BUGFIX: same missing step for the patch critic.
            optimizer_D2.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)
            # Extract patch for patch discriminator (same crop as above)
            fake_target_patch = fake_target[:, :, rand_h:rand_h + opt.patch_size,
                                            rand_w:rand_w + opt.patch_size]
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)
            # Perceptual Loss
            fake_target_feature = perceptualnet(fake_target)
            true_target_feature = perceptualnet(true_target)
            Perceptual_Loss = criterion_L1(fake_target_feature, true_target_feature)
            # GAN Loss
            fake_scalar_global = global_discriminator(fake_target)
            Global_GAN_Loss = -torch.mean(fake_scalar_global)
            fake_scalar_patch = patch_discriminator(fake_target_patch)
            Patch_GAN_Loss = -torch.mean(fake_scalar_patch)
            # Overall Loss and optimize
            loss = opt.lambda_l1 * Pixellevel_L1_Loss \
                + opt.lambda_percep * Perceptual_Loss \
                + opt.lambda_gan * (Global_GAN_Loss + Patch_GAN_Loss)
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [Perceptual Loss: %.4f]" %
                  ((epoch + 1), opt.epochs, i, len(train_loader),
                   Pixellevel_L1_Loss.item(), Perceptual_Loss.item()))
            print("\r[Global G Loss: %.4f, D Loss: %.4f] [Patch G Loss: %.4f, D Loss: %.4f] Time_left: %s" %
                  (Global_GAN_Loss.item(), loss_D1.item(),
                   Patch_GAN_Loss.item(), loss_D2.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(train_loader), generator)

        # Learning rate decrease at certain epochs
        adjust_learning_rate(opt, (epoch + 1), optimizer_G)
        adjust_learning_rate(opt, (epoch + 1), optimizer_D1)
        adjust_learning_rate(opt, (epoch + 1), optimizer_D2)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [true_input, fake_target, true_target]
            name_list = ['in', 'pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        ### Validation
        val_PSNR = 0
        num_of_val_image = 0
        for j, (true_input, true_target) in enumerate(val_loader):
            # To device; A is for input image, B is for target image
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            # Forward propagation
            with torch.no_grad():
                fake_target = generator(true_input)
            # Accumulate num of image and val_PSNR (weighted by batch size)
            num_of_val_image += true_input.shape[0]
            val_PSNR += utils.psnr(fake_target, true_target, 1) * true_input.shape[0]
        val_PSNR = val_PSNR / num_of_val_image

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [true_input, fake_target, true_target]
            name_list = ['in', 'pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
        # Record average PSNR
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
def trainer_WGANGP(opt):
    """Train an SCGAN colorization generator with WGAN-GP.

    The discriminator is trained with the Wasserstein loss plus a gradient
    penalty on interpolated samples; the generator with pixel L1,
    saliency-attention L1, perceptual and adversarial losses. Samples are
    saved every epoch via ``utils.save_sample_png``.

    Args:
        opt: options namespace (save_path, sample_path, lr_g, lr_d, b1, b2,
             epochs, batch_size, num_workers, multi_gpu, lambda_gp,
             lambda_l1, lambda_gan, lambda_percep, lambda_attn, save_* and
             lr_* knobs).
    """
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)
    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        # BUGFIX: this previously wrapped `generator` again (copy-paste bug),
        # so the perceptual loss compared generator features instead of the
        # perceptual network's features.
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Decay by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        # or every "lr_decrease_iter" iterations, depending on the mode
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor**(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch, opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration, opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Calculate the gradient penalty loss for WGAN-GP
    def compute_gradient_penalty(D, input_samples, real_samples, fake_samples):
        # Random weight term for interpolation between real and fake samples
        alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = D(input_samples, interpolates)
        # For PatchGAN: the critic outputs a 30x30 patch map
        fake = Variable(Tensor(real_samples.shape[0], 1, 30, 30).fill_(1.0),
                        requires_grad=False)
        # Get gradient w.r.t. interpolates
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):
            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            # Replicate the 1-channel saliency map to 3 channels
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)

            ### Train Discriminator
            optimizer_D.zero_grad()
            # Generator output
            fake_RGB, fake_sal = generator(true_L)
            # Fake colorizations (detached: no gradient into the generator)
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())
            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)
            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(
                discriminator, true_L.data, true_RGB.data, fake_RGB.data)
            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d) \
                + opt.lambda_gp * gradient_penalty
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()
            fake_RGB, fake_sal = generator(true_L)
            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)
            # Attention Loss: L1 on saliency-weighted images
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)
            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)
            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = -torch.mean(fake_scalar)
            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN \
                + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s" %
                  ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_L1.item(), loss_attn.item(), loss_percep.item(),
                   loss_D.item(), loss_GAN.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_D)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_RGB, true_RGB]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=opt.sample_path,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list)
def trainer_LSGAN(opt):
    """LSGAN trainer for the saliency-guided colorization network (SCGAN).

    The generator maps a grayscale input ``true_L`` to an RGB colorization
    plus a saliency map; it is trained against a patch discriminator with the
    least-squares GAN objective, combined with pixel L1, attention
    (saliency-masked L1) and perceptual losses. Checkpoints and image samples
    are written according to ``opt``.

    Args:
        opt: argparse-style namespace carrying all training hyperparameters
            (paths, batch size, learning rates, loss weights, save modes...).
    """
    # cudnn benchmark accelerates fixed-size workloads
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)
    # Handle multiple GPUs: scale batch size and loader workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()
    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        # BUGFIX: the original wrapped a second copy of the generator here
        # (nn.DataParallel(generator)), discarding the perceptual network.
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer, base_lr):
        # Decay "base_lr" by "lr_decrease_factor" every "lr_decrease_epoch"
        # epochs or "lr_decrease_iter" iterations, depending on the mode.
        # CONSISTENCY FIX: take the per-optimizer base LR as an argument
        # (like the other trainers in this file) instead of always decaying
        # from opt.lr_g, which silently rescheduled the discriminator.
        if opt.lr_decrease_mode == 'epoch':
            lr = base_lr * (opt.lr_decrease_factor
                            **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = base_lr * (opt.lr_decrease_factor
                            **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch,
                                                        opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration,
                                                       opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                # only save once per epoch, at its last iteration
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor
    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):
            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            # Broadcast the 1-channel saliency map to 3 channels and build
            # the saliency-weighted target used by the attention loss
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)
            # Adversarial ground truth (patch discriminator output is 30x30
            # here -- assumes the input resolution configured in opt)
            valid = Tensor(np.ones((true_L.shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((true_L.shape[0], 1, 30, 30)))

            ### Train Discriminator
            optimizer_D.zero_grad()
            # Generator output
            fake_RGB, fake_sal = generator(true_L)
            # Fake colorizations
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)
            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)
            loss_true = criterion_MSE(true_scalar_d, valid)
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()
            fake_RGB, fake_sal = generator(true_L)
            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)
            # Attention Loss
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)
            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)
            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = criterion_MSE(fake_scalar, valid)
            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss_L1.item(),
                   loss_attn.item(), loss_percep.item(), loss_D.item(),
                   loss_GAN.item(), time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G, opt.lr_g)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D, opt.lr_d)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_RGB, true_RGB]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=opt.sample_path,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list)
def WGAN_trainer(opt):
    """WGAN trainer for the two-stage deepfillv2 inpainting generator.

    The generator produces a coarse (``first_out``) and a refined
    (``second_out``) completion of the masked image; the discriminator is
    trained with the (unclipped) Wasserstein objective, and the generator
    with masked L1, perceptual and adversarial losses.

    NOTE(review): the saved checkpoint is named 'deepfillv2_LSGAN_...' even
    though this is the WGAN trainer -- confirm the intended file name.

    Args:
        opt: argparse-style namespace with all training hyperparameters.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Loss functions
    L1Loss = nn.L1Loss()
    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'deepfillv2_LSGAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                # unwrap DataParallel so the checkpoint loads on single GPU
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    # Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()
    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):
            # Load mask (shape: [B, 1, H, W]), masked_img (shape: [B, 3, H, W]), img (shape: [B, 3, H, W]) and put it to cuda
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()
            # Generator output
            first_out, second_out = generator(img, mask)  # forward propagation
            # Composite: keep known pixels from img, fill the hole from the
            # generator output
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [0, 1]
            # Fake samples (detached so D's backward does not reach G)
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)
            # Overall Loss and optimize (Wasserstein critic objective)
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            # NOTE: reuses first_out_wholeimg / second_out_wholeimg computed
            # above; their graph to G is intact because D's backward only went
            # through the detached copy.
            optimizer_g.zero_grad()
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)
            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(
                second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps,
                                           img_featuremaps)
            # Compute losses
            loss = opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + \
                opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease (per epoch)
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)
        # Save the model
        save_model(generator, (epoch + 1), opt)

        ### Sample data every epoch
        # masked region is shown as white (value 1) in the sample image
        masked_img = img * (1 - mask) + mask
        mask = torch.cat((mask, mask, mask), 1)
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
def Trainer_GAN(opt):
    """LSGAN trainer for single-stage grayscale image inpainting.

    Trains a generator that fills the masked region of a grayscale image,
    against an 8x8 patch discriminator with the least-squares GAN objective,
    plus masked L1 and (VGG-style) perceptual losses.

    NOTE(review): a function with the same name is defined again later in
    this module and will shadow this one at import time -- confirm which
    definition is meant to be public.

    Args:
        opt: argparse-style namespace with all training hyperparameters.
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    # Handle multiple GPUs: scale batch size and workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    print("Batch size is changed to %d" % opt.batch_size)
    print("Number of workers is changed to %d" % opt.num_workers)
    # Build path folder
    utils.check_path(opt.save_path)
    utils.check_path(opt.sample_path)
    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(optimizer, epoch, opt, init_lr):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = init_lr * (opt.lr_decrease_factor
                        **(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'GrayInpainting_GAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                # unwrap DataParallel so the checkpoint loads on single GPU
                torch.save(net.module.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    # Training and Testing
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()
    # Tensor type
    Tensor = torch.cuda.FloatTensor
    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (grayscale, mask) in enumerate(dataloader):
            # Load and put to cuda
            grayscale = grayscale.cuda()  # out: [B, 1, 256, 256]
            mask = mask.cuda()  # out: [B, 1, 256, 256]
            # LSGAN targets (patch discriminator output is 8x8 here)
            valid = Tensor(np.ones((grayscale.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((grayscale.shape[0], 1, 8, 8)))

            # ----------------------------------------
            #           Train Discriminator
            # ----------------------------------------
            optimizer_d.zero_grad()
            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            # keep known pixels, fill the hole from the generator output
            out_wholeimg = grayscale * (1 -
                                        mask) + out * mask  # in range [0, 1]
            # Fake samples (detached so D's backward does not reach G)
            fake_scalar = discriminator(out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(grayscale, mask)
            # Overall Loss and optimize
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            # BUGFIX: the original never stepped the discriminator optimizer,
            # so D's weights were frozen for the whole run.
            optimizer_d.step()

            # ----------------------------------------
            #             Train Generator
            # ----------------------------------------
            optimizer_g.zero_grad()
            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 -
                                        mask) + out * mask  # in range [0, 1]
            # Mask L1 Loss
            MaskL1Loss = L1Loss(out_wholeimg, grayscale)
            # GAN Loss
            fake_scalar = discriminator(out_wholeimg, mask)
            MaskGAN_Loss = MSELoss(fake_scalar, valid)
            # Get the deep semantic feature maps, and compute Perceptual Loss
            # (the perceptual net expects 3 channels, so replicate grayscale)
            out_3c = torch.cat((out_wholeimg, out_wholeimg, out_wholeimg), 1)
            grayscale_3c = torch.cat((grayscale, grayscale, grayscale), 1)
            out_featuremaps = perceptualnet(out_3c)
            gt_featuremaps = perceptualnet(grayscale_3c)
            PerceptualLoss = L1Loss(out_featuremaps, gt_featuremaps)
            # Compute losses
            loss = opt.lambda_l1 * MaskL1Loss + opt.lambda_perceptual * PerceptualLoss + opt.lambda_gan * MaskGAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [Perceptual Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), PerceptualLoss.item(), loss_D.item(),
                   MaskGAN_Loss.item(), time_left))

        # Learning rate decrease (per epoch)
        adjust_learning_rate(optimizer_g, (epoch + 1), opt, opt.lr_g)
        adjust_learning_rate(optimizer_d, (epoch + 1), opt, opt.lr_d)
        # Save the model
        save_model(generator, (epoch + 1), opt)
        # Sample current results
        utils.sample(grayscale, mask, out_wholeimg, opt.sample_path,
                     (epoch + 1))
def ESRGAN_Trainer(opt):
    """LSGAN-style trainer for an ESRGAN super-resolution generator.

    The generator takes a low-resolution image plus a spatial noise tensor
    ``z`` and is trained with L1, adversarial (MSE against 1/0 patch targets)
    and perceptual losses. Losses and learning rate are logged to
    TensorBoard; checkpoints are named 'Wavelet_*'.

    Args:
        opt: argparse-style namespace with all training hyperparameters.
    """
    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()
    # Initialize networks
    generator = utils.create_ESRGAN_generator(opt)
    discriminator = utils.create_ESRGAN_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Decay the LR by "lr_decrease_factor" every "lr_decrease_epoch"
        # epochs or "lr_decrease_iter" iterations, depending on the mode.
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = 'Wavelet_epoch%d_bs%d.pth' % (epoch, opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'Wavelet_iter%d_bs%d.pth' % (iteration,
                                                      opt.batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name,
                                       model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                # only save once per epoch, at its last iteration
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d'
                        % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d'
                        % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.LRHRDataset(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------
    # Tensor type
    Tensor = torch.cuda.FloatTensor
    # Count start time
    prev_time = time.time()
    # Tensorboard
    writer = SummaryWriter()
    # For loop training
    for epoch in range(opt.epochs):
        # Record learning rate
        for param_group in optimizer_G.param_groups:
            writer.add_scalar('data/lr', param_group['lr'], epoch)
            print('learning rate = ', param_group['lr'])
        if epoch == 0:
            iters_done = 0
        for i, (img_LR, img_HR) in enumerate(dataloader):
            # To device
            # A is for downsample clean image, B is for noisy image
            #assert img_LR.shape == img_HR.shape
            img_LR = img_LR.cuda()
            img_HR = img_HR.cuda()
            # Adversarial ground truth: discriminator output is 1/8 of the
            # LR spatial resolution (patch targets)
            valid = Tensor(
                np.ones((img_LR.shape[0], 1, img_LR.shape[2] // 8,
                         img_LR.shape[3] // 8)))
            fake = Tensor(
                np.zeros((img_LR.shape[0], 1, img_LR.shape[2] // 8,
                          img_LR.shape[3] // 8)))
            # Spatial noise input for the generator, replicated to 64 channels.
            # BUGFIX: size by the actual batch (img_LR.shape[0]), not
            # opt.batch_size -- the DataLoader has no drop_last, so the final
            # batch can be smaller and the old code crashed on it.
            z = np.random.randn(img_LR.shape[0], 1, 128,
                                128).astype(np.float32)
            z = np.repeat(z, 64, axis=1)
            z = Variable(torch.from_numpy(z)).cuda()

            ### Train Generator
            # Forward
            pred = generator(img_LR, z)
            # L1 loss
            loss_L1 = criterion_L1(pred, img_HR)
            # gan part
            fake_scalar = discriminator(pred)
            loss_gan = criterion_MSE(fake_scalar, valid)
            # Perceptual loss part
            fea_true = perceptualnet(img_HR)
            fea_pred = perceptualnet(pred)
            loss_percep = criterion_MSE(fea_true, fea_pred)
            # Overall Loss and optimize
            optimizer_G.zero_grad()
            loss = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_gan + opt.lambda_percep * loss_percep
            loss.backward()
            optimizer_G.step()

            ### Train Discriminator
            # Forward (fresh prediction after the G update)
            pred = generator(img_LR, z)
            # GAN loss
            fake_scalar = discriminator(pred.detach())
            loss_fake = criterion_MSE(fake_scalar, fake)
            true_scalar = discriminator(img_HR)
            loss_true = criterion_MSE(true_scalar, valid)
            # Overall Loss and optimize
            optimizer_D.zero_grad()
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_D.step()

            # Record losses
            writer.add_scalar('data/loss_L1', loss_L1.item(), iters_done)
            writer.add_scalar('data/loss_percep', loss_percep.item(),
                              iters_done)
            writer.add_scalar('data/loss_G', loss.item(), iters_done)
            writer.add_scalar('data/loss_D', loss_D.item(), iters_done)

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [L1 Loss: %.4f] [G Loss: %.4f] [G percep Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_L1.item(), loss_gan.item(), loss_percep.item(),
                   loss_D.item(), time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [pred, img_HR]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
    writer.close()
def Trainer_GAN(opt):
    """WGAN trainer for the Quad-Bayer-to-RGB demosaicing/denoising network.

    The generator maps a raw input image to RGB and is trained with pixel L1,
    perceptual (ImageNet-normalized features), YUV color and Wasserstein
    adversarial losses. Validation PSNR is computed every epoch; all losses
    are logged to TensorBoard.

    NOTE(review): this redefines module-level Trainer_GAN and shadows the
    earlier definition -- confirm which one is intended to be exported.

    Args:
        opt: argparse-style namespace with all training hyperparameters.
    """
    # ----------------------------------------
    # Initialization
    # ----------------------------------------
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------
    # Handle multiple GPUs: scale batch size and workers with GPU count
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Define the dataset
    train_imglist = utils.get_jpgs(os.path.join(opt.in_path_train))
    val_imglist = utils.get_jpgs(os.path.join(opt.in_path_val))
    train_dataset = dataset.Qbayer2RGB_dataset(opt, 'train', train_imglist)
    val_dataset = dataset.Qbayer2RGB_dataset(opt, 'val', val_imglist)
    print('The overall number of training images:', len(train_imglist))
    print('The overall number of validation images:', len(val_imglist))
    # Define the dataloader
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.train_batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.val_batch_size,
                                             shuffle=False,
                                             num_workers=opt.num_workers,
                                             pin_memory=True)

    # ----------------------------------------
    # Network training parameters
    # ----------------------------------------
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    class ColorLoss(nn.Module):
        """L1 loss computed in YUV space, emphasizing chrominance fidelity."""

        def __init__(self):
            super(ColorLoss, self).__init__()
            self.L1loss = nn.L1Loss()

        def RGB2YUV(self, RGB):
            # Standard BT.601 RGB -> YUV conversion, channel-wise
            YUV = RGB.clone()
            YUV[:,
                0, :, :] = 0.299 * RGB[:, 0, :, :] + 0.587 * RGB[:,
                                                                 1, :, :] + 0.114 * RGB[:, 2, :, :]
            YUV[:,
                1, :, :] = -0.14713 * RGB[:, 0, :, :] - 0.28886 * RGB[:,
                                                                      1, :, :] + 0.436 * RGB[:, 2, :, :]
            YUV[:,
                2, :, :] = 0.615 * RGB[:, 0, :, :] - 0.51499 * RGB[:,
                                                                   1, :, :] - 0.10001 * RGB[:, 2, :, :]
            return YUV

        def forward(self, x, y):
            yuv_x = self.RGB2YUV(x)
            yuv_y = self.RGB2YUV(y)
            return self.L1loss(yuv_x, yuv_y)

    yuv_loss = ColorLoss()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    # BUGFIX: the original passed generator.parameters() here, so the
    # discriminator's own weights were never optimized.
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer, lr_gd):
        # Decay "lr_gd" by "lr_decrease_factor" every "lr_decrease_epoch"
        # epochs or "lr_decrease_iter" iterations, depending on the mode.
        if opt.lr_decrease_mode == 'epoch':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = '%s_gan_noise%.3f_epoch%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, epoch, opt.train_batch_size)
        if opt.save_mode == 'iter':
            model_name = '%s_gan_noise%.3f_iter%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, iteration,
                opt.train_batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name,
                                       model_name)
        # Save model
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                # only save once per epoch, at its last iteration
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d'
                        % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d'
                        % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #                 Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # Tensorboard
    writer = SummaryWriter()
    # For loop training
    for epoch in range(opt.epochs):
        # Record learning rate
        for param_group in optimizer_G.param_groups:
            writer.add_scalar('data/lr', param_group['lr'], epoch)
            print('learning rate = ', param_group['lr'])
        if epoch == 0:
            iters_done = 0

        ### Training
        for i, (in_img, RGBout_img) in enumerate(train_loader):
            # To device
            # A is for input image, B is for target image
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()

            ## Train Discriminator
            # Forward propagation
            out = generator(in_img)
            optimizer_D.zero_grad()
            # Fake samples (detached so D's backward does not reach G)
            fake_scalar_d = discriminator(in_img, out.detach())
            true_scalar_d = discriminator(in_img, RGBout_img)
            # Overall Loss and optimize (Wasserstein critic objective)
            loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D.backward()
            #torch.nn.utils.clip_grad_norm(discriminator.parameters(), opt.grad_clip_norm)
            optimizer_D.step()

            ## Train Generator
            # Forward propagation
            out = generator(in_img)
            # GAN loss
            fake_scalar = discriminator(in_img, out)
            L_gan = -torch.mean(fake_scalar) * opt.lambda_gan
            # Perceptual loss features
            fake_B_fea = perceptualnet(utils.normalize_ImageNet_stats(out))
            true_B_fea = perceptualnet(
                utils.normalize_ImageNet_stats(RGBout_img))
            L_percep = opt.lambda_percep * criterion_L1(
                fake_B_fea, true_B_fea)
            # Pixel loss
            L_pixel = opt.lambda_pixel * criterion_L1(out, RGBout_img)
            # Color loss
            L_color = opt.lambda_color * yuv_loss(out, RGBout_img)
            # Sum up to total loss
            loss = L_pixel + L_percep + L_gan + L_color
            # Record losses
            writer.add_scalar('data/L_pixel', L_pixel.item(), iters_done)
            writer.add_scalar('data/L_percep', L_percep.item(), iters_done)
            writer.add_scalar('data/L_color', L_color.item(), iters_done)
            writer.add_scalar('data/L_gan', L_gan.item(), iters_done)
            writer.add_scalar('data/L_total', loss.item(), iters_done)
            writer.add_scalar('data/loss_D', loss_D.item(), iters_done)
            # Backpropagate gradients
            optimizer_G.zero_grad()
            loss.backward()
            #torch.nn.utils.clip_grad_norm(generator.parameters(), opt.grad_clip_norm)
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i + 1
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] [L_pixel: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(train_loader), loss.item(),
                   L_pixel.item()))
            print(
                "\r[L_percep: %.4f] [L_color: %.4f] [L_gan: %.4f] [loss_D: %.4f] Time_left: %s"
                % (L_percep.item(), L_color.item(), L_gan.item(),
                   loss_D.item(), time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), iters_done, len(train_loader),
                       generator)
            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_G,
                                 opt.lr_g)
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_D,
                                 opt.lr_d)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        ### Validation
        val_PSNR = 0
        num_of_val_image = 0
        for j, (in_img, RGBout_img) in enumerate(val_loader):
            # To device
            # A is for input image, B is for target image
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()
            # Forward propagation (no gradients needed for evaluation)
            with torch.no_grad():
                out = generator(in_img)
            # Accumulate num of image and val_PSNR (weighted by batch size)
            num_of_val_image += in_img.shape[0]
            val_PSNR += utils.psnr(out, RGBout_img, 1) * in_img.shape[0]
        val_PSNR = val_PSNR / num_of_val_image

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
        # Record average PSNR
        writer.add_scalar('data/val_PSNR', val_PSNR, epoch)
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
    writer.close()
def WGAN_trainer(opt):
    """Train the deepfillv2 two-stage inpainting network with an adversarial
    objective (labelled WGAN; the discriminator loss below is actually the
    hinge form ``-mean(min(0, -1 ± s))``).

    Per batch: sample random free-form masks, update the discriminator on
    real images vs. detached mask-composited generator outputs, then update
    the generator with coarse/refined L1 losses, a perceptual feature loss,
    and the adversarial term.  Learning-rate decay, checkpointing, and image
    sampling happen once per epoch.

    Args:
        opt: argparse-style namespace carrying every hyper-parameter and path
            read below (lr_g, lr_d, b1, b2, weight_decay, batch_size, epochs,
            resume, resume_epoch, multi_gpu, cudnn_benchmark, save_path,
            sample_path, num_workers, checkpoint_interval, lr_decrease_factor,
            lr_decrease_epoch, lambda_l1, lambda_perceptual, lambda_gan, ...).

    Side effects: creates opt.save_path / opt.sample_path, writes ``.pth``
    checkpoints and PNG samples there, prints progress to stdout.
    Requires CUDA (tensors and models are moved with ``.cuda()``).
    """
    # ----------------------------------------
    # Initialize training parameters
    # ----------------------------------------
    # cudnn benchmark accelerates the network (best when input sizes are fixed)
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations: ensure output folders exist before anything is written
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks (created on CPU here; moved to GPU after optional resume)
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()  # NOTE(review): unused in this function

    # Optimizers: independent Adam instances for generator and discriminator
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Step decay: scale the base LR by lr_decrease_factor once every
        lr_decrease_epoch epochs and write it into every param group."""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the two-stage generator model
    def save_model_generator(net, epoch, opt):
        """Save the generator state dict whenever *epoch* is a multiple of
        opt.checkpoint_interval (file name encodes epoch and batch size)."""
        model_name = 'deepfillv2_WGAN_G_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                # unwrap nn.DataParallel so the saved dict has clean key names
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # Save the discriminator model
    def save_model_discriminator(net, epoch, opt):
        """Save the discriminator state dict whenever *epoch* is a multiple
        of opt.checkpoint_interval (file name encodes epoch and batch size)."""
        model_name = 'deepfillv2_WGAN_D_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                # unwrap nn.DataParallel so the saved dict has clean key names
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # load the model
    def load_model(net, epoch, opt, type='G'):
        """Load the 'G' (generator) or 'D' (discriminator) state dict saved
        at *epoch* into *net*.  NOTE(review): the parameter name shadows
        the builtin ``type``; any value other than 'G' selects 'D'."""
        if type == 'G':
            model_name = 'deepfillv2_WGAN_G_epoch%d_batchsize%d.pth' % (
                epoch, opt.batch_size)
        else:
            model_name = 'deepfillv2_WGAN_D_epoch%d_batchsize%d.pth' % (
                epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        pretrained_dict = torch.load(model_name)
        net.load_state_dict(pretrained_dict)

    # Optionally resume both networks from a previous run's checkpoints
    # (must happen before DataParallel wrapping, which renames the keys)
    if opt.resume:
        load_model(generator, opt.resume_epoch, opt, type='G')
        load_model(discriminator, opt.resume_epoch, opt, type='D')
        print(
            '--------------------Pretrained Models are Loaded--------------------'
        )

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # ----------------------------------------
    # Initialize training dataset
    # ----------------------------------------
    # Define the dataset
    trainset = train_dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))

    # Define the dataloader
    # NOTE(review): shuffle=False on a training loader is unusual — confirm
    # this is intentional.  drop_last=True keeps every batch at full size,
    # which the mask loop below relies on.
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.num_workers,
                            pin_memory=True,
                            drop_last=True)

    # ----------------------------------------
    # Training
    # ----------------------------------------
    # Initialize start time
    prev_time = time.time()

    # Tensor type (alias that allocates directly on the GPU)
    Tensor = torch.cuda.FloatTensor

    # Training loop (resume_epoch is 0 for a fresh run)
    for epoch in range(opt.resume_epoch, opt.epochs):
        for batch_idx, (img, height, width) in enumerate(dataloader):
            img = img.cuda()

            # set the same free form masks for each batch
            # One independently sampled free-form mask per image, all sized
            # by the first sample's (height, width) — assumes every image in
            # the batch shares that size; TODO confirm against the dataset.
            mask = torch.empty(img.shape[0], 1, img.shape[2],
                               img.shape[3]).cuda()
            for i in range(opt.batch_size):
                mask[i] = torch.from_numpy(
                    train_dataset.InpaintDataset.random_ff_mask(
                        shape=(height[0], width[0])).astype(np.float32)).cuda()

            # LSGAN vectors: per-patch target maps, presumably sized to the
            # discriminator's output resolution (H/32, W/32) — verify.
            # NOTE(review): `fake` is never used below.
            valid = Tensor(
                np.ones((img.shape[0], 1, height[0] // 32, width[0] // 32)))
            fake = Tensor(
                np.zeros((img.shape[0], 1, height[0] // 32, width[0] // 32)))
            zero = Tensor(
                np.zeros((img.shape[0], 1, height[0] // 32, width[0] // 32)))

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output: coarse (first) and refined (second) stages
            first_out, second_out = generator(img, mask)  # forward propagation
            # Composite: keep ground truth outside the hole, prediction inside
            # NOTE(review): first_out_wholeimg is computed but never used.
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [0, 1]

            # Fake samples (detached so only D receives gradients here)
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Loss and optimize — hinge form: penalize fake scores above -1
            # and real scores below +1
            loss_fake = -torch.mean(torch.min(zero, -valid - fake_scalar))
            loss_true = -torch.mean(torch.min(zero, -valid + true_scalar))
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # L1 Loss on both stages against the ground truth
            first_L1Loss = (first_out - img).abs().mean()
            second_L1Loss = (second_out - img).abs().mean()

            # GAN Loss (non-detached composite so gradients reach G)
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_featuremaps = perceptualnet(second_out)
            second_PerceptualLoss = L1Loss(second_out_featuremaps,
                                           img_featuremaps)

            # Compute losses: weighted sum of both L1 terms, perceptual, GAN
            loss = opt.lambda_l1 * first_L1Loss + opt.lambda_l1 * second_L1Loss + \
                opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left (extrapolate last batch's time)
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_L1Loss.item(), second_L1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

            # Visualization: paint masked pixels to 1 (white), then expand the
            # mask to 3 channels so it can be saved as an RGB image
            masked_img = img * (1 - mask) + mask
            mask = torch.cat((mask, mask, mask), 1)
            if (batch_idx + 1) % 40 == 0:
                # NOTE(review): same sample_name as the epoch-end dump below,
                # so these intermediate samples get overwritten — confirm.
                img_list = [img, mask, masked_img, first_out, second_out]
                name_list = [
                    'gt', 'mask', 'masked_img', 'first_out', 'second_out'
                ]
                utils.save_sample_png(sample_folder=sample_folder,
                                      sample_name='epoch%d' % (epoch + 1),
                                      img_list=img_list,
                                      name_list=name_list,
                                      pixel_max_cnt=255)

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model_generator(generator, (epoch + 1), opt)
        save_model_discriminator(discriminator, (epoch + 1), opt)

        ### Sample data every epoch (uses the last batch's tensors)
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)