def main():
    """
    Training.
    """
    global start_epoch, epoch, checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = SRResNet(large_kernel_size=large_kernel_size,
                         small_kernel_size=small_kernel_size,
                         n_channels=n_channels,
                         n_blocks=n_blocks,
                         scaling_factor=scaling_factor)
        # Initialize the optimizer
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   model.parameters()),
                                     lr=lr)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)  # the dataset returns tensors directly, so the default collate function suffices

    # Total number of epochs to train for
    epochs = int(iterations // len(train_loader) + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model': model,
            'optimizer': optimizer
        }, 'checkpoint_srresnet.pth.tar')
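This example pickles the whole model and optimizer objects, so resuming (the else branch above) is a plain torch.load, but the file can only be loaded where the original class definitions are importable. A more portable variant, as Example #6 below does, saves state_dicts instead. A minimal sketch:

# Sketch: state_dict-based checkpointing instead of pickling whole modules.
torch.save({'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()},
           'checkpoint_srresnet.pth.tar')

# Resuming then requires recreating the objects before loading:
model = SRResNet(large_kernel_size=large_kernel_size,
                 small_kernel_size=small_kernel_size,
                 n_channels=n_channels,
                 n_blocks=n_blocks,
                 scaling_factor=scaling_factor).to(device)
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
ckpt = torch.load('checkpoint_srresnet.pth.tar', map_location=device)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1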
Example #2
    # Use multiple GPUs if requested and available
    if torch.cuda.is_available() and ngpu > 1:
        generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))

    generator.eval()
    model = generator
    # srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
    # srgan_generator.eval()
    # model = srgan_generator

    for test_data_name in test_data_names:
        print("\n数据集 %s:\n" % test_data_name)

        # 定制化数据加载器
        test_dataset = SRDataset(data_folder,
                                 split='test',
                                 crop_size=0,
                                 scaling_factor=4,
                                 lr_img_type='imagenet-norm',
                                 hr_img_type='[-1, 1]',
                                 test_data_name=test_data_name)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        # Track PSNR and SSIM for each sample
        PSNRs = AverageMeter()
        SSIMs = AverageMeter()

        # Track evaluation time
        start = time.time()
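The snippet is cut off after the timer starts. A sketch of how such a metric loop typically continues, assuming convert_image (from the same codebase) supports a 'y-channel' target, i.e. the luminance channel on which super-resolution PSNR/SSIM are conventionally reported, and using skimage for the metrics:

# Sketch only: a plausible continuation of the truncated loop above.
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

with torch.no_grad():
    for i, (lr_imgs, hr_imgs) in enumerate(test_loader):
        lr_imgs = lr_imgs.to(device)  # (1, 3, w, h), imagenet-normed
        hr_imgs = hr_imgs.to(device)  # (1, 3, 4w, 4h), in [-1, 1]

        sr_imgs = model(lr_imgs)      # (1, 3, 4w, 4h), in [-1, 1]

        # Compare on the Y (luminance) channel, per image
        sr_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel').squeeze(0).cpu().numpy()
        hr_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel').squeeze(0).cpu().numpy()
        PSNRs.update(peak_signal_noise_ratio(hr_y, sr_y, data_range=255.), lr_imgs.size(0))
        SSIMs.update(structural_similarity(hr_y, sr_y, data_range=255.), lr_imgs.size(0))

print('PSNR: {psnrs.avg:.3f}  SSIM: {ssims.avg:.3f}'.format(psnrs=PSNRs, ssims=SSIMs))
print('Average time per image: %.3fs' % ((time.time() - start) / len(test_dataset)))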
Example #3
def main():
    """
    Training.
    """
    global start_epoch, epoch, checkpoint, srresnet_checkpoint
    # Initialize model or load checkpoint
    if checkpoint is None:
        # Generator
        generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)

        # Initialize generator network with pretrained SRResNet
        generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)

        # Initialize generator's optimizer
        optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()),
                                       lr=lr)

        # Discriminator
        discriminator = Discriminator(kernel_size=kernel_size_d,
                                      n_channels=n_channels_d,
                                      n_blocks=n_blocks_d,
                                      fc_size=fc_size_d)

        # Initialize discriminator's optimizer
        optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()),
                                       lr=lr)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator = checkpoint['generator']
        discriminator = checkpoint['discriminator']
        optimizer_g = checkpoint['optimizer_g']
        optimizer_d = checkpoint['optimizer_d']
        print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))

    # Truncated VGG19 network to be used in the loss calculation
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    # Loss functions
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move to default device
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='[-1, 1]',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers,
                                               pin_memory=True)

    # Total number of epochs to train for
    epochs = int(iterations // len(train_loader) + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # At the halfway point, reduce learning rate to a tenth
        if epoch == int((iterations / 2) // len(train_loader) + 1):
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        # One epoch's training
        train(train_loader=train_loader,
              generator=generator,
              discriminator=discriminator,
              truncated_vgg19=truncated_vgg19,
              content_loss_criterion=content_loss_criterion,
              adversarial_loss_criterion=adversarial_loss_criterion,
              optimizer_g=optimizer_g,
              optimizer_d=optimizer_d,
              epoch=epoch)

        if epoch % 20 == 0 or epoch in {1, 2}:
            now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            ckp_file = storage + '/ckp/{}_{}.pth.tar'.format(str(epoch).zfill(4), now)
            print('save ckp')
            torch.save({'epoch': epoch,
                        'generator': generator,
                        'discriminator': discriminator,
                        'optimizer_g': optimizer_g,
                        'optimizer_d': optimizer_d},
                       ckp_file
                       )
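Both SRGAN examples call adjust_learning_rate at the halfway point but never define it. A minimal sketch consistent with how it is called (optimizer, shrink factor); the exact implementation in the original codebase may differ:

def adjust_learning_rate(optimizer, shrink_factor):
    # Multiply the learning rate of every parameter group by shrink_factor.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))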
Example #4
def main():
    """
    Training.
    """
    global checkpoint, start_epoch, writer

    # Initialize the models
    generator = Generator(large_kernel_size=large_kernel_size_g,
                          small_kernel_size=small_kernel_size_g,
                          n_channels=n_channels_g,
                          n_blocks=n_blocks_g,
                          scaling_factor=scaling_factor)

    discriminator = Discriminator(kernel_size=kernel_size_d,
                                  n_channels=n_channels_d,
                                  n_blocks=n_blocks_d,
                                  fc_size=fc_size_d)

    # Initialize the optimizers
    optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()), lr=lr)
    optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr)

    # Truncated VGG19 network used in the loss calculation
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    # Loss functions
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move to default device
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)
    
    # Initialize the generator with the pretrained SRResNet
    srresnetcheckpoint = torch.load(srresnet_checkpoint)
    generator.net.load_state_dict(srresnetcheckpoint['model'])

    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator.load_state_dict(checkpoint['generator'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
    
    # Single-machine multi-GPU training
    if torch.cuda.is_available() and ngpu > 1:
        generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))
        discriminator = nn.DataParallel(discriminator, device_ids=list(range(ngpu)))

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Train epoch by epoch
    for epoch in range(start_epoch, epochs + 1):

        if epoch == int(epochs / 2):  # reduce the learning rate at the halfway point
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        generator.train()  # training mode: enables batch normalization
        discriminator.train()

        losses_c = AverageMeter()  # content loss
        losses_a = AverageMeter()  # adversarial (generator) loss
        losses_d = AverageMeter()  # discriminator loss

        n_iter = len(train_loader)

        # Batches
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # Move to default device
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96), imagenet-normed

            # ----------------------- 1. Generator update -----------------------
            # Generate
            sr_imgs = generator(lr_imgs)  # (N, 3, 96, 96), in [-1, 1]
            sr_imgs = convert_image(
                sr_imgs, source='[-1, 1]',
                target='imagenet-norm')  # (N, 3, 96, 96), imagenet-normed

            # Compute VGG feature maps
            sr_imgs_in_vgg_space = truncated_vgg19(sr_imgs)            # N X 512 X 6 X 6
            hr_imgs_in_vgg_space = truncated_vgg19(hr_imgs).detach()   # N X 512 X 6 X 6

            # Content loss
            content_loss = content_loss_criterion(sr_imgs_in_vgg_space, hr_imgs_in_vgg_space)

            # Adversarial loss for the generator
            sr_discriminated = discriminator(sr_imgs)  # (N, 1)
            adversarial_loss = adversarial_loss_criterion(
                sr_discriminated, torch.ones_like(sr_discriminated))  # the generator wants to fool the discriminator, so the target for every generated image is 1

            # Total perceptual loss
            perceptual_loss = content_loss + beta * adversarial_loss

            # Backward pass
            optimizer_g.zero_grad()
            perceptual_loss.backward()

            # Update the generator's parameters
            optimizer_g.step()

            # Track losses
            losses_c.update(content_loss.item(), lr_imgs.size(0))
            losses_a.update(adversarial_loss.item(), lr_imgs.size(0))


            # ----------------------- 2. Discriminator update -----------------------
            # Discriminate
            hr_discriminated = discriminator(hr_imgs)
            sr_discriminated = discriminator(sr_imgs.detach())

            # Binary cross-entropy loss
            adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \
                               adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated))  # the discriminator should label everything the generator produces 0 and every original image 1

            # Backward pass
            optimizer_d.zero_grad()
            adversarial_loss.backward()

            # Update the discriminator
            optimizer_d.step()

            # Track loss
            losses_d.update(adversarial_loss.item(), hr_imgs.size(0))


            print('content_loss: %.4f    adversarial_loss: %.4f' % (content_loss.item(), adversarial_loss.item()))
 
        # Free memory manually
        del lr_imgs, hr_imgs, sr_imgs, hr_imgs_in_vgg_space, sr_imgs_in_vgg_space, hr_discriminated, sr_discriminated

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
            'optimizer_d': optimizer_d.state_dict(),
        }, 'results/checkpoint_srgan.pth')
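One caveat in this example: when ngpu > 1 the models are wrapped in nn.DataParallel, so the saved state_dict keys gain a 'module.' prefix and will not load back into the bare models built at the top of main(). A sketch of the usual fix, unwrapping before saving:

# Sketch: unwrap nn.DataParallel so the state_dict keys carry no
# 'module.' prefix and load cleanly into bare models on resume.
g = generator.module if isinstance(generator, nn.DataParallel) else generator
d = discriminator.module if isinstance(discriminator, nn.DataParallel) else discriminator
torch.save({
    'epoch': epoch,
    'generator': g.state_dict(),
    'discriminator': d.state_dict(),
    'optimizer_g': optimizer_g.state_dict(),
    'optimizer_d': optimizer_d.state_dict(),
}, 'results/checkpoint_srgan.pth')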
Example #5
def main():
    """
    Training.
    """
    global start_epoch, epoch, checkpoint, srresnet_checkpoint, vgg_loss_enable
    print(vgg_loss_enable, args.vggloss)

    save_model = False

    # Initialize model or load checkpoint
    if checkpoint is None:
        # Generator
        min_p_loss = 1e10

        generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)
        best_generator = generator

        # Initialize generator network with pretrained SRResNet
        # generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)

        # Initialize generator's optimizer
        optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                     generator.parameters()),
                                       lr=lr)

        # Discriminator
        discriminator = Discriminator(kernel_size=kernel_size_d,
                                      n_channels=n_channels_d,
                                      n_blocks=n_blocks_d,
                                      fc_size=fc_size_d)
        best_discriminator = discriminator

        # Initialize discriminator's optimizer
        optimizer_d = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, discriminator.parameters()),
                                       lr=lr)

    else:
        checkpoint = os.path.join(args.root, args.checkpoint)
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator = checkpoint['generator']
        best_generator = generator
        discriminator = checkpoint['discriminator']
        best_discriminator = discriminator
        optimizer_g = checkpoint['optimizer_g']
        optimizer_d = checkpoint['optimizer_d']
        min_p_loss = checkpoint['min_p_loss']
        print("\nLoaded checkpoint from epoch %d.\n" %
              (checkpoint['epoch'] + 1))

    if args.olr is not None:
        overwrite_learning_rate(optimizer_g, args.olr)
        overwrite_learning_rate(optimizer_d, args.olr)

    if vgg_loss_enable:
        # Truncated VGG19 network to be used in the loss calculation
        print('vggloss enable')
        truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
        truncated_vgg19.eval()
    else:
        truncated_vgg19 = None

    # Loss functions
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move to default device
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    if vgg_loss_enable:
        truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Total number of epochs to train for
    print('iterations: {}'.format(iterations))
    if args.epochs == 0:
        epochs = int(iterations // len(train_loader) + 1)
    else:
        epochs = args.epochs
    print('length of train_loader: {}'.format(len(train_loader)))
    print('epochs = {}'.format(epochs))

    # Epochs
    for epoch in range(start_epoch, epochs):

        # At the halfway point, reduce learning rate to a tenth
        if epoch == int((iterations / 2) // len(train_loader) + 1):
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        # adjust_rate = lr_table(epoch)
        # if not adjust_rate == 1:
        #     print('adjust learning rate by {}'.format(adjust_rate))
        #     adjust_learning_rate(optimizer_g, adjust_rate)
        #     adjust_learning_rate(optimizer_d, adjust_rate)

        # One epoch's training
        p_loss = train(train_loader=train_loader,
                       generator=generator,
                       discriminator=discriminator,
                       truncated_vgg19=truncated_vgg19,
                       content_loss_criterion=content_loss_criterion,
                       adversarial_loss_criterion=adversarial_loss_criterion,
                       optimizer_g=optimizer_g,
                       optimizer_d=optimizer_d,
                       epoch=epoch)

        # Save checkpoint
        if p_loss < min_p_loss:
            best_generator = generator
            best_discriminator = discriminator
            min_p_loss = p_loss
            save_model = True

        if save_model:
            print('save model epoch {} min_p_loss: {}'.format(
                epoch, min_p_loss))
            torch.save(
                {
                    'epoch': epoch,
                    'generator': best_generator,
                    'discriminator': best_discriminator,
                    'optimizer_g': optimizer_g,
                    'optimizer_d': optimizer_d,
                    'min_p_loss': min_p_loss
                },
                os.path.join(checkpoint_path,
                             'checkpoint_srgan_{}.pth.tar'.format(epoch)))
            save_model = False
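overwrite_learning_rate (used with args.olr above) is not defined in the snippet either; presumably it sets an absolute rate rather than scaling the current one, as adjust_learning_rate does. A hypothetical sketch:

def overwrite_learning_rate(optimizer, new_lr):
    # Set an absolute learning rate on every parameter group
    # (unlike adjust_learning_rate, which scales the current value).
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr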
Example #6
def main():
    """
    Training.
    """
    global checkpoint, start_epoch

    # Initialize the model
    model = SRResNet(large_kernel_size=large_kernel_size,
                     small_kernel_size=small_kernel_size,
                     n_channels=n_channels,
                     n_blocks=n_blocks,
                     scaling_factor=scaling_factor)

    # Initialize the optimizer
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    # Move to default device
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Load checkpoint if resuming
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Train epoch by epoch
    for epoch in range(start_epoch, epochs + 1):

        print("epoch:", epoch)
        model.train()  # training mode: enables batch normalization

        loss_epoch = AverageMeter()  # track the loss

        n_iter = len(train_loader)
        loss_data = 0
        tag = 0
        # Batches
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # Move to default device
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96), in [-1, 1]

            # Forward pass
            sr_imgs = model(lr_imgs)

            # Compute loss
            loss = criterion(sr_imgs, hr_imgs)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Update the model
            optimizer.step()

            # Track the loss
            loss_epoch.update(loss.item(), lr_imgs.size(0))
            loss_data = loss_data + loss.item()
            tag = tag + 1
 
        loss_data = loss_data / tag
        print("loss:", loss_data)

        # Free memory manually
        del lr_imgs, hr_imgs, sr_imgs

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'results/checkpoint_srresnet.pth')
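Because this model was trained on imagenet-normed inputs and [-1, 1] targets, inference must apply the same conversions. A minimal sketch, assuming the convert_image helper from the same codebase accepts PIL images via 'pil'; lr_pil is a hypothetical low-resolution input:

# Inference sketch (assumption: convert_image supports 'pil' conversions;
# lr_pil is a hypothetical low-resolution PIL image).
model.eval()
with torch.no_grad():
    lr_img = convert_image(lr_pil, source='pil', target='imagenet-norm')  # (3, h, w)
    sr_img = model(lr_img.unsqueeze(0).to(device)).squeeze(0)             # (3, 4h, 4w), in [-1, 1]
    sr_pil = convert_image(sr_img, source='[-1, 1]', target='pil')
    sr_pil.save('sr_output.png')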