def main():
    """
    Train SRResNet from scratch, or resume from a previously saved checkpoint.
    """
    global start_epoch, epoch, checkpoint

    if checkpoint is not None:
        # Resume: the checkpoint stores whole model/optimizer objects,
        # so they are restored directly rather than via state dicts.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
    else:
        # Fresh run: build the network and an Adam optimizer over its
        # trainable parameters only.
        model = SRResNet(large_kernel_size=large_kernel_size,
                         small_kernel_size=small_kernel_size,
                         n_channels=n_channels,
                         n_blocks=n_blocks,
                         scaling_factor=scaling_factor)
        trainable = (p for p in model.parameters() if p.requires_grad)
        optimizer = torch.optim.Adam(params=trainable, lr=lr)

    # Move model and loss to the default device.
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Training data: LR inputs are ImageNet-normalised, HR targets in [-1, 1].
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Convert the iteration budget into a whole number of epochs.
    epochs = int(iterations // len(train_loader) + 1)

    for epoch in range(start_epoch, epochs):
        # One epoch's training.
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

        # Persist full objects (not state dicts) so the resume branch above
        # can restore them directly.
        state = {'epoch': epoch, 'model': model, 'optimizer': optimizer}
        torch.save(state, 'checkpoint_srresnet.pth.tar')
# NOTE(review): fragment of an evaluation script — the enclosing function's
# definition and the remainder of the per-dataset loop are not visible in
# this chunk; surrounding context (generator, device, ngpu, etc.) is assumed.

# Wrap the generator for multi-GPU inference and switch to eval mode
# (fixes batch-norm statistics, disables dropout).
generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))
generator.eval()
model = generator
# Alternative: evaluate the SRGAN generator instead (kept disabled).
# srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
# srgan_generator.eval()
# model = srgan_generator

# Evaluate on each configured test set.
for test_data_name in test_data_names:
    print("\n数据集 %s:\n" % test_data_name)  # "Dataset %s:"

    # Custom data loader: full images (crop_size=0), 4x upscaling,
    # LR inputs ImageNet-normalised, HR targets in [-1, 1].
    test_dataset = SRDataset(data_folder,
                             split='test',
                             crop_size=0,
                             scaling_factor=4,
                             lr_img_type='imagenet-norm',
                             hr_img_type='[-1, 1]',
                             test_data_name=test_data_name)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=True)

    # Per-sample PSNR and SSIM accumulators.
    PSNRs = AverageMeter()
    SSIMs = AverageMeter()

    # Start the evaluation timer (the loop body continues past this fragment).
    start = time.time()
def main():
    """
    Train the SRGAN generator/discriminator pair, resuming from a saved
    checkpoint when one is available.
    """
    global start_epoch, epoch, checkpoint, srresnet_checkpoint

    if checkpoint is not None:
        # Resume: the checkpoint stores whole model and optimizer objects.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator = checkpoint['generator']
        discriminator = checkpoint['discriminator']
        optimizer_g = checkpoint['optimizer_g']
        optimizer_d = checkpoint['optimizer_d']
        print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))
    else:
        # Fresh run: build the generator, warm-start it from a pretrained
        # SRResNet, and create one Adam optimizer per network (trainable
        # parameters only).
        generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)
        generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)
        optimizer_g = torch.optim.Adam(
            params=(p for p in generator.parameters() if p.requires_grad), lr=lr)

        discriminator = Discriminator(kernel_size=kernel_size_d,
                                      n_channels=n_channels_d,
                                      n_blocks=n_blocks_d,
                                      fc_size=fc_size_d)
        optimizer_d = torch.optim.Adam(
            params=(p for p in discriminator.parameters() if p.requires_grad), lr=lr)

    # Truncated VGG19 provides the feature space for the content loss; it is
    # kept in eval mode and never optimised.
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move every module and criterion to the default device
    # (nn.Module.to mutates in place, so rebinding is unnecessary).
    for module in (generator, discriminator, truncated_vgg19,
                   content_loss_criterion, adversarial_loss_criterion):
        module.to(device)

    # Both LR and HR images are expected in [-1, 1] for SRGAN training.
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='[-1, 1]',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Translate the iteration budget into a whole number of epochs.
    epochs = int(iterations // len(train_loader) + 1)

    for epoch in range(start_epoch, epochs):
        # Halfway through the iteration budget, drop both learning rates
        # to a tenth.
        if epoch == int((iterations / 2) // len(train_loader) + 1):
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        # One epoch's training.
        train(train_loader=train_loader,
              generator=generator,
              discriminator=discriminator,
              truncated_vgg19=truncated_vgg19,
              content_loss_criterion=content_loss_criterion,
              adversarial_loss_criterion=adversarial_loss_criterion,
              optimizer_g=optimizer_g,
              optimizer_d=optimizer_d,
              epoch=epoch)

        # Checkpoint on the first three epochs, then every 20th epoch.
        if epoch in {0, 1, 2} or (epoch % 20 == 0 and epoch != 0):
            now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            ckp_file = storage + '/ckp/{}_{}.pth.tar'.format(str(epoch).zfill(4), now)
            print('save ckp')
            torch.save({'epoch': epoch,
                        'generator': generator,
                        'discriminator': discriminator,
                        'optimizer_g': optimizer_g,
                        'optimizer_d': optimizer_d},
                       ckp_file)
def main():
    """
    Train SRGAN: the generator is warm-started from a pretrained SRResNet and
    trained adversarially against a discriminator with a VGG-space content loss.

    Fix: the checkpoint dict previously listed 'optimizer_g' twice and never
    stored the discriminator optimizer, so the resume path below
    (``checkpoint['optimizer_d']``) could never be satisfied; the duplicate key
    is now 'optimizer_d'.
    """
    global checkpoint, start_epoch, writer

    # Build both networks.
    generator = Generator(large_kernel_size=large_kernel_size_g,
                          small_kernel_size=small_kernel_size_g,
                          n_channels=n_channels_g,
                          n_blocks=n_blocks_g,
                          scaling_factor=scaling_factor)
    discriminator = Discriminator(kernel_size=kernel_size_d,
                                  n_channels=n_channels_d,
                                  n_blocks=n_blocks_d,
                                  fc_size=fc_size_d)

    # One Adam optimizer per network, over trainable parameters only.
    optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()), lr=lr)
    optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr)

    # Truncated VGG19 provides the feature space for the content loss;
    # eval mode, never optimised.
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    # Loss functions.
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move everything to the default device.
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)

    # Warm-start the generator's SRResNet core from the pretrained model.
    srresnetcheckpoint = torch.load(srresnet_checkpoint)
    generator.net.load_state_dict(srresnetcheckpoint['model'])

    # Resume training state if a checkpoint was supplied.
    # NOTE(review): these load_state_dict calls run before the DataParallel
    # wrap below; a checkpoint saved from a DataParallel-wrapped model carries
    # 'module.'-prefixed keys and would need stripping — confirm against how
    # checkpoints are produced.
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator.load_state_dict(checkpoint['generator'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])

    # Single-machine multi-GPU training.
    if torch.cuda.is_available() and ngpu > 1:
        generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))
        discriminator = nn.DataParallel(discriminator, device_ids=list(range(ngpu)))

    # Custom dataloaders; both LR and HR images are ImageNet-normalised here.
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Epoch loop.
    for epoch in range(start_epoch, epochs + 1):
        if epoch == int(epochs / 2):
            # Halfway through, reduce both learning rates to a tenth.
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        generator.train()      # training mode: batch norm uses batch statistics
        discriminator.train()

        losses_c = AverageMeter()  # content loss
        losses_a = AverageMeter()  # adversarial (generator) loss
        losses_d = AverageMeter()  # discriminator loss

        n_iter = len(train_loader)

        # Per-batch training.
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
            lr_imgs = lr_imgs.to(device)  # (N, 3, 24, 24), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (N, 3, 96, 96), imagenet-normed

            # ----------------------- 1. Generator update -----------------------
            sr_imgs = generator(lr_imgs)  # (N, 3, 96, 96), in [-1, 1]
            sr_imgs = convert_image(
                sr_imgs, source='[-1, 1]', target='imagenet-norm')  # imagenet-normed

            # VGG feature maps for the content loss; HR features are detached
            # so no gradient flows into the (frozen) target branch.
            sr_imgs_in_vgg_space = truncated_vgg19(sr_imgs)
            hr_imgs_in_vgg_space = truncated_vgg19(hr_imgs).detach()

            content_loss = content_loss_criterion(sr_imgs_in_vgg_space, hr_imgs_in_vgg_space)

            # Adversarial loss: the generator wants every generated image
            # judged real, hence all-ones targets.
            sr_discriminated = discriminator(sr_imgs)  # (N, 1)
            adversarial_loss = adversarial_loss_criterion(
                sr_discriminated, torch.ones_like(sr_discriminated))

            # Total perceptual loss.
            perceptual_loss = content_loss + beta * adversarial_loss

            optimizer_g.zero_grad()
            perceptual_loss.backward()
            optimizer_g.step()

            losses_c.update(content_loss.item(), lr_imgs.size(0))
            losses_a.update(adversarial_loss.item(), lr_imgs.size(0))

            # ----------------------- 2. Discriminator update -----------------------
            hr_discriminated = discriminator(hr_imgs)
            sr_discriminated = discriminator(sr_imgs.detach())

            # BCE: the discriminator should label generated images 0 and
            # originals 1.
            adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \
                               adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated))

            optimizer_d.zero_grad()
            adversarial_loss.backward()
            optimizer_d.step()

            losses_d.update(adversarial_loss.item(), hr_imgs.size(0))

            print(str(content_loss.item())+" "+str(adversarial_loss.item()))

            # Release the batch's tensors by hand.
            del lr_imgs, hr_imgs, sr_imgs, hr_imgs_in_vgg_space, sr_imgs_in_vgg_space, hr_discriminated, sr_discriminated

        # Save checkpoint (FIX: 'optimizer_d' was previously a duplicate
        # 'optimizer_g' key, so the discriminator optimizer was never saved).
        torch.save({
            'epoch': epoch,
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
            'optimizer_d': optimizer_d.state_dict(),
        }, 'results/checkpoint_srgan.pth')
def main():
    """
    Train SRGAN with an optional VGG-based content loss, tracking the best
    perceptual loss seen so far and saving a checkpoint whenever it improves.

    Fixes over the previous revision:
      * ``args.olr != None`` -> ``args.olr is not None`` (identity check).
      * ``if save_model == True`` -> ``if save_model``.
      * dropped the immediately-duplicated "epochs" print.
    """
    global start_epoch, epoch, checkpoint, srresnet_checkpoint, vgg_loss_enable

    print(vgg_loss_enable, args.vggloss)
    save_model = False

    # Initialize models or load checkpoint.
    if checkpoint is None:
        min_p_loss = 1e10  # best (lowest) perceptual loss observed so far

        generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)
        best_generator = generator
        # Warm-starting from SRResNet is deliberately disabled here:
        # generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)
        optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                     generator.parameters()),
                                       lr=lr)

        discriminator = Discriminator(kernel_size=kernel_size_d,
                                      n_channels=n_channels_d,
                                      n_blocks=n_blocks_d,
                                      fc_size=fc_size_d)
        best_discriminator = discriminator
        optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                     discriminator.parameters()),
                                       lr=lr)
    else:
        # Resume: checkpoint path is relative to args.root; whole objects are
        # stored, so they are restored directly.
        checkpoint = os.path.join(args.root, args.checkpoint)
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator = checkpoint['generator']
        best_generator = generator
        discriminator = checkpoint['discriminator']
        best_discriminator = discriminator
        optimizer_g = checkpoint['optimizer_g']
        optimizer_d = checkpoint['optimizer_d']
        min_p_loss = checkpoint['min_p_loss']
        print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))
        if args.olr is not None:  # was `!= None`
            # Optionally override the restored learning rates.
            overwrite_learning_rate(optimizer_g, args.olr)
            overwrite_learning_rate(optimizer_d, args.olr)

    if vgg_loss_enable:
        # Truncated VGG19 network used for the content (perceptual) loss.
        print('vggloss enable')
        truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
        truncated_vgg19.eval()
    else:
        truncated_vgg19 = None

    # Loss functions.
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move to default device.
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    if vgg_loss_enable:
        truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)

    # Custom dataloaders; both LR and HR images are ImageNet-normalised here.
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Total number of epochs: fixed via --epochs, or derived from the
    # iteration budget when --epochs is 0.
    print('iterations: {}'.format(iterations))
    if args.epochs == 0:
        epochs = int(iterations // len(train_loader) + 1)
    else:
        epochs = args.epochs
    print('length of train_loader: {}'.format(len(train_loader)))
    print('epochs = {}'.format(epochs))

    # Epochs.
    for epoch in range(start_epoch, epochs):
        # At the halfway point of the *iteration* budget, reduce both learning
        # rates to a tenth (note: based on `iterations` even when --epochs is set).
        if epoch == int((iterations / 2) // len(train_loader) + 1):
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        # One epoch's training; returns the epoch's perceptual loss.
        p_loss = train(train_loader=train_loader,
                       generator=generator,
                       discriminator=discriminator,
                       truncated_vgg19=truncated_vgg19,
                       content_loss_criterion=content_loss_criterion,
                       adversarial_loss_criterion=adversarial_loss_criterion,
                       optimizer_g=optimizer_g,
                       optimizer_d=optimizer_d,
                       epoch=epoch)

        # Save a checkpoint whenever the perceptual loss improves.
        # NOTE(review): `best_generator`/`best_discriminator` alias the live
        # models (no copy), so they always hold the *current* weights. This is
        # harmless because saving happens in the same epoch the new best was
        # recorded, but use copy.deepcopy if a true "best so far" snapshot is
        # intended.
        if p_loss < min_p_loss:
            best_generator = generator
            best_discriminator = discriminator
            min_p_loss = p_loss
            save_model = True
        if save_model:
            print('save model epoch {} min_p_loss: {}'.format(
                epoch, min_p_loss))
            torch.save({'epoch': epoch,
                        'generator': best_generator,
                        'discriminator': best_discriminator,
                        'optimizer_g': optimizer_g,
                        'optimizer_d': optimizer_d,
                        'min_p_loss': min_p_loss},
                       os.path.join(checkpoint_path,
                                    'checkpoint_srgan_{}.pth.tar'.format(epoch)))
            save_model = False
def main():
    """
    Train SRResNet, optionally resuming from a saved state-dict checkpoint and
    using DataParallel when several GPUs are available.
    """
    global checkpoint, start_epoch

    # Build the network and an Adam optimizer over its trainable parameters.
    model = SRResNet(large_kernel_size=large_kernel_size,
                     small_kernel_size=small_kernel_size,
                     n_channels=n_channels,
                     n_blocks=n_blocks,
                     scaling_factor=scaling_factor)
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(params=trainable_params, lr=lr)

    # Move model and loss to the default device.
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Resume from saved state dicts if a checkpoint path was given
    # (must happen before the DataParallel wrap below).
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Single-machine multi-GPU training.
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # Training data: LR inputs are ImageNet-normalised, HR targets in [-1, 1].
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Epoch loop.
    for epoch in range(start_epoch, epochs + 1):
        print("epoch:", epoch)
        model.train()  # training mode: batch norm uses batch statistics

        loss_epoch = AverageMeter()  # running loss statistics
        n_iter = len(train_loader)
        total_loss = 0
        n_batches = 0

        # Per-batch training.
        for lr_imgs, hr_imgs in train_loader:
            lr_imgs = lr_imgs.to(device)  # (N, 3, 24, 24), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (N, 3, 96, 96), in [-1, 1]

            sr_imgs = model(lr_imgs)            # forward pass
            loss = criterion(sr_imgs, hr_imgs)  # pixel-wise MSE

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Record the batch loss.
            loss_epoch.update(loss.item(), lr_imgs.size(0))
            total_loss += loss.item()
            n_batches += 1

        # Mean loss over the epoch.
        total_loss = total_loss / n_batches
        print("loss:", total_loss)

        # Release the last batch's tensors by hand.
        del lr_imgs, hr_imgs, sr_imgs

        # Persist state dicts so the resume path above can restore them.
        torch.save({'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()},
                   'results/checkpoint_srresnet.pth')