Example #1
File: trainer.py Project: mwufi/deepfillv2
def create_networks(opt, checkpoint=None):
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    if checkpoint:
        # Restore the network state
        generator.load_state_dict(checkpoint['G'])
        discriminator.load_state_dict(checkpoint['D'])

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    return generator, discriminator, perceptualnet
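A minimal usage sketch for resuming training with create_networks, assuming an opt namespace like the one used throughout these examples and a checkpoint file saved as a dict holding 'G' and 'D' state dicts (the keys the function reads); the path below is hypothetical:

import torch

# Hypothetical path; the file is assumed to contain
# {'G': <generator state_dict>, 'D': <discriminator state_dict>}
checkpoint = torch.load('checkpoints/deepfillv2_latest.pth', map_location='cpu')
generator, discriminator, perceptualnet = create_networks(opt, checkpoint)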
Example #2
def LSGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    #FeatureMatchingLoss = FML1Loss(opt.fm_param)

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(
                    net.module, 'deepfillNet_epoch%d_batchsize%d.pth' %
                    (epoch, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(
                    net, 'deepfillNet_epoch%d_batchsize%d.pth' %
                    (epoch, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load mask (shape: [B, 1, H, W]), masked_img (shape: [B, 3, H, W]), img (shape: [B, 3, H, W]) and put it to cuda
            img = img.cuda()
            mask = mask.cuda()

            # LSGAN vectors
            valid = Tensor(np.ones((img.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((img.shape[0], 1, 8, 8)))

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [-1, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [-1, 1]

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Overall Loss and optimize
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = MSELoss(fake_scalar, valid)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img = (img + 1) / 2  # in range [0, 1]
            img = utils.normalize_ImageNet_stats(img)  # in range of ImageNet
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg = (second_out_wholeimg +
                                   1) / 2  # in range [0, 1]
            second_out_wholeimg = utils.normalize_ImageNet_stats(
                second_out_wholeimg)
            second_out_wholeimg_featuremaps = perceptualnet(
                second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps,
                                           img_featuremaps)

            # Compute losses
            loss = first_MaskL1Loss + second_MaskL1Loss + opt.perceptual_param * second_PerceptualLoss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)
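For reference, the least-squares objective driving the loop above in isolation: the discriminator's 8x8 patch outputs are regressed toward 1 for real images and 0 for inpainted composites, and the generator pushes its composites toward 1. A self-contained sketch with random stand-ins for the discriminator outputs:

import torch
import torch.nn as nn

mse = nn.MSELoss()

# stand-ins for discriminator(second_out_wholeimg.detach(), mask)
# and discriminator(img, mask): one score per 8x8 output patch
fake_scalar = torch.randn(4, 1, 8, 8)
true_scalar = torch.randn(4, 1, 8, 8)
valid = torch.ones_like(true_scalar)
fake = torch.zeros_like(fake_scalar)

loss_D = 0.5 * (mse(fake_scalar, fake) + mse(true_scalar, valid))  # discriminator side
loss_G = mse(fake_scalar, valid)                                   # generator side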
Example #3
def Continue_train_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        #Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'WGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'WGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'WGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'WGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.RAW2RGBDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target, true_sal) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            true_sal = true_sal.cuda()
            true_sal3 = torch.cat((true_sal, true_sal, true_sal), 1).cuda()

            # Train Discriminator
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()

                # Generator output
                fake_target, fake_sal = generator(true_input)

                # Fake samples
                fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                    true_input, fake_target.detach())
                # True samples
                true_block1, true_block2, true_block3, true_scalar = discriminator(
                    true_input, true_target)
                '''
                # Feature Matching Loss
                FM_Loss = criterion_L1(fake_block1, true_block1) + criterion_L1(fake_block2, true_block2) + criterion_L1(fake_block3, true_block3)
                '''
                # Overall Loss and optimize
                loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
                loss_D.backward()
                optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target, fake_sal = generator(true_input)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # Attention Loss
            true_Attn_target = true_target.mul(true_sal3)
            fake_sal3 = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_Attn_target = fake_target.mul(fake_sal3)
            Attention_Loss = criterion_L1(fake_Attn_target, true_Attn_target)

            # GAN Loss
            fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                true_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)

            # Perceptual Loss
            fake_target = fake_target * 0.5 + 0.5
            true_target = true_target * 0.5 + 0.5
            fake_target = utils.normalize_ImageNet_stats(fake_target)
            true_target = utils.normalize_ImageNet_stats(true_target)
            fake_percep_feature = perceptualnet(fake_target)
            true_percep_feature = perceptualnet(true_target)
            Perceptual_Loss = criterion_L1(fake_percep_feature,
                                           true_percep_feature)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + opt.lambda_attn * Attention_Loss + opt.lambda_gan * GAN_Loss + opt.lambda_percep * Perceptual_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item()))
            print(
                "\r[Attention Loss: %.4f] [Perceptual Loss: %.4f] Time_left: %s"
                % (Attention_Loss.item(), Perceptual_Loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)
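The epoch-mode branch of adjust_learning_rate above is a plain step decay, lr = lr_0 * factor ** (epoch // step). A short sketch of the values it produces under assumed settings (lr_g = 1e-4, halving every 10 epochs):

lr_0, factor, step = 1e-4, 0.5, 10
for epoch in (0, 9, 10, 19, 20, 30):
    print(epoch, lr_0 * factor ** (epoch // step))
# epochs 0-9 keep 1e-4, 10-19 get 5e-5, 20-29 get 2.5e-5, 30 gets 1.25e-5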
Example #4
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    cv2.setNumThreads(0)
    cv2.ocl.setUseOpenCL(False)

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()  # reduce=False, size_average=False
    RELU = nn.ReLU()

    # Optimizers
    # Under DataParallel the wrapped network's attributes live on `.module`,
    # so reach the coarse sub-network through it when multi_gpu is set
    coarse_params = (generator.module.coarse.parameters()
                     if opt.multi_gpu else generator.coarse.parameters())
    optimizer_g1 = torch.optim.Adam(coarse_params, lr=opt.lr_g)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # Save the model if pre_train == True
    def save_model(net, epoch, opt, batch=0, is_D=False):
        """Save the model at "checkpoint_interval" and its multiple"""
        if is_D==True:
            model_name = 'discriminator_WGAN_epoch%d_batch%d.pth' % (epoch + 1, batch)
        else:
            model_name = 'deepfillv2_WGAN_epoch%d_batch%d.pth' % (epoch+1, batch)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))
    
    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    
    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        print("Start epoch ", epoch+1, "!")
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load mask (shape: [B, 1, H, W]), masked_img (shape: [B, 3, H, W]), img (shape: [B, 3, H, W]) and put it to cuda
            img = img.cuda()
            mask = mask.cuda()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (1 - mask) + first_out * mask        # in range [0, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask      # in range [0, 1]

            for wk in range(1):
                optimizer_d.zero_grad()
                fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
                true_scalar = discriminator(img, mask)
                #W_Loss = -torch.mean(true_scalar) + torch.mean(fake_scalar)#+ gradient_penalty(discriminator, img, second_out_wholeimg, mask)
                hinge_loss = torch.mean(RELU(1-true_scalar)) + torch.mean(RELU(fake_scalar+1))
                loss_D = hinge_loss
                loss_D.backward(retain_graph=True)
                optimizer_d.step()

            ### Train Generator
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = - torch.mean(fake_scalar)

            optimizer_g1.zero_grad()
            first_MaskL1Loss.backward(retain_graph=True)
            optimizer_g1.step()

            optimizer_g.zero_grad()

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)

            loss = 0.5*opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + GAN_Loss + second_PerceptualLoss * opt.lambda_perceptual
            loss.backward()

            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]" %
                  ((epoch + 1), opt.epochs, (batch_idx+1), len(dataloader), first_MaskL1Loss.item(),
                   second_MaskL1Loss.item()))
            print("\r[D Loss: %.5f] [Perceptual Loss: %.5f] [G Loss: %.5f] time_left: %s" %
                  (loss_D.item(), second_PerceptualLoss.item(), GAN_Loss.item(), time_left))


            if (batch_idx + 1) % 100 == 0:
                # Generate Visualization image
                masked_img = img * (1 - mask) + mask
                img_save = torch.cat((img, masked_img, first_out, second_out, first_out_wholeimg, second_out_wholeimg),3)
                # Recover normalization: * 255 because last layer is sigmoid activated
                img_save = F.interpolate(img_save, scale_factor=0.5)
                img_save = img_save * 255
                # Process img_copy and do not destroy the data of img
                img_copy = img_save.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
                #img_copy = np.clip(img_copy, 0, 255)
                img_copy = img_copy.astype(np.uint8)
                save_img_name = 'sample_batch' + str(batch_idx+1) + '.png'
                save_img_path = os.path.join(sample_folder, save_img_name)
                img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
                cv2.imwrite(save_img_path, img_copy)
            if (batch_idx + 1) % 5000 == 0:
                save_model(generator, epoch, opt, batch_idx+1)
                save_model(discriminator, epoch, opt, batch_idx+1, is_D=True)

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, epoch, opt)
        save_model(discriminator, epoch, opt, is_D=True)
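Despite the WGAN name, the discriminator objective actually used in this trainer is the hinge loss (the raw Wasserstein loss is left commented out at the W_Loss line). Both sides in isolation, as a minimal sketch:

import torch

def d_hinge_loss(true_scalar, fake_scalar):
    # push real scores above +1 and fake scores below -1
    return torch.relu(1 - true_scalar).mean() + torch.relu(1 + fake_scalar).mean()

def g_hinge_loss(fake_scalar):
    # the generator simply maximizes the critic's score on its samples
    return -fake_scalar.mean()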
Example #5
def Continue_train_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    utils.check_path(save_folder)
    utils.check_path(sample_folder)

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    global_discriminator = utils.create_discriminator(opt)
    patch_discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        global_discriminator = nn.DataParallel(global_discriminator)
        global_discriminator = global_discriminator.cuda()
        patch_discriminator = nn.DataParallel(patch_discriminator)
        patch_discriminator = patch_discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        global_discriminator = global_discriminator.cuda()
        patch_discriminator = patch_discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D1 = torch.optim.Adam(global_discriminator.parameters(),
                                    lr=opt.lr,
                                    betas=(opt.b1, opt.b2))
    optimizer_D2 = torch.optim.Adam(patch_discriminator.parameters(),
                                    lr=opt.lr,
                                    betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        target_epoch = opt.epochs - opt.lr_decrease_epoch
        remain_epoch = opt.epochs - epoch
        if epoch >= opt.lr_decrease_epoch:
            lr = opt.lr * remain_epoch / target_epoch
            lr = max(lr, 1e-7)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = 'DeblurGANv2_%s_WGAN_epoch%d_bs%d.pth' % (
                opt.network_type, epoch, opt.train_batch_size)
        if opt.save_mode == 'iter':
            model_name = 'DeblurGANv2_%s_WGAN_iter%d_bs%d.pth' % (
                opt.network_type, iteration, opt.train_batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name,
                                       model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Define the dataset
    trainset = dataset.DeblurDataset(opt, 'train')
    valset = dataset.DeblurDataset(opt, 'val')
    print('The overall number of training images:', len(trainset))
    print('The overall number of validation images:', len(valset))

    # Define the dataloader
    train_loader = DataLoader(trainset,
                              batch_size=opt.train_batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers,
                              pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=opt.val_batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(train_loader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Train Discriminator
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()

            # Generator output
            fake_target = generator(true_input)

            # Extract patch for patch discriminator
            rand_h = random.randint(0, opt.crop_size - opt.patch_size)
            rand_w = random.randint(0, opt.crop_size - opt.patch_size)
            fake_target_patch = fake_target[:, :,
                                            rand_h:rand_h + opt.patch_size,
                                            rand_w:rand_w + opt.patch_size]
            true_target_patch = true_target[:, :,
                                            rand_h:rand_h + opt.patch_size,
                                            rand_w:rand_w + opt.patch_size]

            # Global discriminator
            fake_scalar_d = global_discriminator(fake_target.detach())
            true_scalar_d = global_discriminator(true_target)
            # Overall Loss and optimize
            loss_D1 = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D1.backward()

            # Patch discriminator
            fake_scalar_d = patch_discriminator(fake_target_patch.detach())
            true_scalar_d = patch_discriminator(true_target_patch)
            # Overall Loss and optimize
            loss_D2 = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D2.backward()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)

            # Extract patch for patch discriminator
            fake_target_patch = fake_target[:, :,
                                            rand_h:rand_h + opt.patch_size,
                                            rand_w:rand_w + opt.patch_size]

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # Perceptual Loss
            fake_target_feature = perceptualnet(fake_target)
            true_target_feature = perceptualnet(true_target)
            Perceptual_Loss = criterion_L1(fake_target_feature,
                                           true_target_feature)

            # GAN Loss
            fake_scalar_global = global_discriminator(fake_target)
            Global_GAN_Loss = -torch.mean(fake_scalar_global)
            fake_scalar_patch = patch_discriminator(fake_target_patch)
            Patch_GAN_Loss = -torch.mean(fake_scalar_patch)

            # Overall Loss and optimize
            loss = opt.lambda_l1 * Pixellevel_L1_Loss + opt.lambda_percep * Perceptual_Loss + opt.lambda_gan * (
                Global_GAN_Loss + Patch_GAN_Loss)
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [Perceptual Loss: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(train_loader),
                   Pixellevel_L1_Loss.item(), Perceptual_Loss.item()))
            print(
                "\r[Global G Loss: %.4f, D Loss: %.4f] [Patch G Loss: %.4f, D Loss: %.4f] Time_left: %s"
                % (Global_GAN_Loss.item(), loss_D1.item(),
                   Patch_GAN_Loss.item(), loss_D2.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(train_loader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D1)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D2)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [true_input, fake_target, true_target]
            name_list = ['in', 'pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        ### Validation
        val_PSNR = 0
        num_of_val_image = 0

        for j, (true_input, true_target) in enumerate(val_loader):

            # To device
            # A is for input image, B is for target image
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Forward propagation
            with torch.no_grad():
                fake_target = generator(true_input)

            # Accumulate num of image and val_PSNR
            num_of_val_image += true_input.shape[0]
            val_PSNR += utils.psnr(fake_target, true_target,
                                   1) * true_input.shape[0]
        val_PSNR = val_PSNR / num_of_val_image

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [true_input, fake_target, true_target]
            name_list = ['in', 'pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        # Record average PSNR
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
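utils.psnr is external to this listing; a definition consistent with the call utils.psnr(fake_target, true_target, 1) in the validation loop above, where the third argument is the peak signal value, might look like this (an assumption, not necessarily the repo's implementation):

import torch

def psnr(pred, target, peak=1.0):
    # peak signal-to-noise ratio in dB, averaged over the whole batch
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(peak ** 2 / mse)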
Example #6
def trainer_WGANGP(opt):

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch,
                                                        opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration,
                                                       opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Calculate the gradient penalty loss for WGAN-GP
    def compute_gradient_penalty(D, input_samples, real_samples, fake_samples):
        # Random weight term for interpolation between real and fake samples
        alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples +
                        ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = D(input_samples, interpolates)
        # For PatchGAN
        fake = Variable(Tensor(real_samples.shape[0], 1, 30, 30).fill_(1.0),
                        requires_grad=False)
        # Get gradient w.r.t. interpolates
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):

            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)

            ### Train Discriminator
            optimizer_D.zero_grad()

            # Generator output
            fake_RGB, fake_sal = generator(true_L)

            # Fake colorizations
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())

            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)

            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(
                discriminator, true_L.data, true_RGB.data, fake_RGB.data)

            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar_d) + torch.mean(
                fake_scalar_d) + opt.lambda_gp * gradient_penalty
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()

            fake_RGB, fake_sal = generator(true_L)

            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)

            # Attention Loss
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)

            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)

            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = -torch.mean(fake_scalar)

            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss_L1.item(),
                   loss_attn.item(), loss_percep.item(), loss_D.item(),
                   loss_GAN.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_RGB, true_RGB]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=opt.sample_path,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list)
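The penalty term above drives the norm of the critic's gradient toward 1 on random interpolates between real and fake samples. A minimal sketch of the same idea for an unconditional scalar-output critic on flat vectors (the version above additionally conditions on the input image and uses PatchGAN-shaped grad_outputs):

import torch

def gradient_penalty(D, real, fake):
    # per-sample random interpolation weights
    alpha = torch.rand(real.size(0), 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    # gradient of the critic's summed output w.r.t. the interpolates
    grads, = torch.autograd.grad(D(x_hat).sum(), x_hat, create_graph=True)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()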
Example #7
def trainer_LSGAN(opt):

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch,
                                                        opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration,
                                                       opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_name)
                    print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):

            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)

            # Adversarial ground truth
            valid = Tensor(np.ones((true_L.shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((true_L.shape[0], 1, 30, 30)))

            ### Train Discriminator
            optimizer_D.zero_grad()

            # Generator output
            fake_RGB, fake_sal = generator(true_L)

            # Fake colorizations
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)

            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)
            loss_true = criterion_MSE(true_scalar_d, valid)

            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()

            fake_RGB, fake_sal = generator(true_L)

            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)

            # Attention Loss
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)

            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)

            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = criterion_MSE(fake_scalar, valid)

            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss_L1.item(),
                   loss_attn.item(), loss_percep.item(), loss_D.item(),
                   loss_GAN.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_RGB, true_RGB]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=opt.sample_path,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list)
Example #8
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'deepfillv2_WGAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images is %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]),
            # then move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [0, 1]

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Overall Loss and optimize
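            # WGAN critic objective: maximize E[D(real)] - E[D(fake)], written
            # here as minimizing the negated difference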
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(
                second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps,
                                           img_featuremaps)

            # Compute losses
            loss = opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + \
                opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)

        ### Sample data every epoch
        masked_img = img * (1 - mask) + mask
        mask = torch.cat((mask, mask, mask), 1)
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
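
For context, these trainers are typically launched from a small argparse driver. The sketch below shows one plausible way to assemble an `opt` namespace for the WGAN_trainer above; the flag names mirror the attributes the function reads, but every default is an illustrative assumption, and the dataset/network options consumed by utils.create_* and dataset.InpaintDataset are omitted.

import argparse

# Hypothetical driver; defaults are placeholders, not the original repository's
parser = argparse.ArgumentParser()
parser.add_argument('--save_path', type=str, default='./models')
parser.add_argument('--sample_path', type=str, default='./samples')
parser.add_argument('--cudnn_benchmark', action='store_true')
parser.add_argument('--multi_gpu', action='store_true')
parser.add_argument('--epochs', type=int, default=40)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--lr_g', type=float, default=1e-4)
parser.add_argument('--lr_d', type=float, default=4e-4)
parser.add_argument('--b1', type=float, default=0.5)
parser.add_argument('--b2', type=float, default=0.999)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--lr_decrease_epoch', type=int, default=10)
parser.add_argument('--lr_decrease_factor', type=float, default=0.5)
parser.add_argument('--checkpoint_interval', type=int, default=1)
parser.add_argument('--lambda_l1', type=float, default=100.0)
parser.add_argument('--lambda_perceptual', type=float, default=10.0)
parser.add_argument('--lambda_gan', type=float, default=1.0)
opt = parser.parse_args()

WGAN_trainer(opt)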
Example #9
def Trainer_GAN(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    print("Batch size is changed to %d" % opt.batch_size)
    print("Number of workers is changed to %d" % opt.num_workers)

    # Build path folder
    utils.check_path(opt.save_path)
    utils.check_path(opt.sample_path)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(optimizer, epoch, opt, init_lr):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = init_lr * (opt.lr_decrease_factor
                        **(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'GrayInpainting_GAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images is %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (grayscale, mask) in enumerate(dataloader):

            # Load and put to cuda
            grayscale = grayscale.cuda()  # out: [B, 1, 256, 256]
            mask = mask.cuda()  # out: [B, 1, 256, 256]

            # LSGAN vectors
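            # (targets for the least-squares GAN loss: ones for real samples,
            # zeros for fake ones; the 1x8x8 shape is assumed to match the
            # patch discriminator's output for 256x256 inputs)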
            valid = Tensor(np.ones((grayscale.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((grayscale.shape[0], 1, 8, 8)))

            # ----------------------------------------
            #           Train Discriminator
            # ----------------------------------------
            optimizer_d.zero_grad()

            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 -
                                        mask) + out * mask  # in range [0, 1]

            # Fake samples
            fake_scalar = discriminator(out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(grayscale, mask)
            # Least-squares losses on fake and true samples
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            # ----------------------------------------
            #             Train Generator
            # ----------------------------------------
            optimizer_g.zero_grad()

            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 -
                                        mask) + out * mask  # in range [0, 1]

            # Mask L1 Loss
            MaskL1Loss = L1Loss(out_wholeimg, grayscale)

            # GAN Loss
            fake_scalar = discriminator(out_wholeimg, mask)
            MaskGAN_Loss = MSELoss(fake_scalar, valid)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            out_3c = torch.cat((out_wholeimg, out_wholeimg, out_wholeimg), 1)
            grayscale_3c = torch.cat((grayscale, grayscale, grayscale), 1)
            out_featuremaps = perceptualnet(out_3c)
            gt_featuremaps = perceptualnet(grayscale_3c)
            PerceptualLoss = L1Loss(out_featuremaps, gt_featuremaps)

            # Compute losses
            loss = opt.lambda_l1 * MaskL1Loss + opt.lambda_perceptual * PerceptualLoss + opt.lambda_gan * MaskGAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [Perceptual Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), PerceptualLoss.item(), loss_D.item(),
                   MaskGAN_Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(optimizer_g, (epoch + 1), opt, opt.lr_g)
        adjust_learning_rate(optimizer_d, (epoch + 1), opt, opt.lr_d)

        # Save the model
        save_model(generator, (epoch + 1), opt)
        utils.sample(grayscale, mask, out_wholeimg, opt.sample_path,
                     (epoch + 1))
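
The discriminator update in Trainer_GAN above is the standard least-squares GAN (LSGAN) objective. Isolated as a minimal self-contained sketch (the score tensors stand in for discriminator outputs; shapes are illustrative):

import torch
import torch.nn as nn

mse = nn.MSELoss()

def lsgan_d_loss(true_scalar, fake_scalar):
    # the discriminator is regressed toward 1 on real and 0 on fake samples
    valid = torch.ones_like(true_scalar)
    fake = torch.zeros_like(fake_scalar)
    return 0.5 * (mse(true_scalar, valid) + mse(fake_scalar, fake))

def lsgan_g_loss(fake_scalar):
    # the generator tries to make the discriminator output 1 on its samples
    return mse(fake_scalar, torch.ones_like(fake_scalar))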
Example #10
File: trainer.py Project: liujikun/ESRGAN
def ESRGAN_Trainer(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    # Initialize networks
    generator = utils.create_ESRGAN_generator(opt)
    discriminator = utils.create_ESRGAN_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()

    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer, base_lr):
        # Decay base_lr by "lr_decrease_factor" every "lr_decrease_epoch"
        # epochs, or every "lr_decrease_iter" iterations
        if opt.lr_decrease_mode == 'epoch':
            lr = base_lr * (opt.lr_decrease_factor
                            **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = base_lr * (opt.lr_decrease_factor
                            **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model by epoch or by iteration, depending on opt.save_mode
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = 'Wavelet_epoch%d_bs%d.pth' % (epoch, opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'Wavelet_iter%d_bs%d.pth' % (iteration,
                                                      opt.batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name,
                                       model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.LRHRDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Count start time
    prev_time = time.time()

    # Tensorboard
    writer = SummaryWriter()

    # For loop training
    for epoch in range(opt.epochs):

        # Record learning rate
        for param_group in optimizer_G.param_groups:
            writer.add_scalar('data/lr', param_group['lr'], epoch)
            print('learning rate = ', param_group['lr'])

        if epoch == 0:
            iters_done = 0

        for i, (img_LR, img_HR) in enumerate(dataloader):

            # To device: img_LR is the low-resolution input, img_HR the
            # high-resolution target
            img_LR = img_LR.cuda()
            img_HR = img_HR.cuda()

            # Adversarial ground truth
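            # (patch-style targets; the //8 spatial size is assumed to match
            # the discriminator's downsampling factor)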
            valid = Tensor(
                np.ones((img_LR.shape[0], 1, img_LR.shape[2] // 8,
                         img_LR.shape[3] // 8)))
            fake = Tensor(
                np.zeros((img_LR.shape[0], 1, img_LR.shape[2] // 8,
                          img_LR.shape[3] // 8)))
            # Latent noise: one 128x128 map per sample (use the actual batch
            # size so the last, possibly smaller, batch does not crash),
            # repeated over 64 channels
            z = np.random.randn(img_LR.shape[0], 1, 128,
                                128).astype(np.float32)
            z = np.repeat(z, 64, axis=1)
            z = torch.from_numpy(z).cuda()
            ### Train Generator
            # Forward
            pred = generator(img_LR, z)

            # L1 loss
            loss_L1 = criterion_L1(pred, img_HR)

            # gan part
            fake_scalar = discriminator(pred)
            loss_gan = criterion_MSE(fake_scalar, valid)

            # Perceptual loss part
            fea_true = perceptualnet(img_HR)
            fea_pred = perceptualnet(pred)
            # print(fea_pred.size())
            loss_percep = criterion_MSE(fea_true, fea_pred)

            # Overall Loss and optimize
            optimizer_G.zero_grad()
            loss = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_gan + opt.lambda_percep * loss_percep
            loss.backward()
            optimizer_G.step()

            ### Train Discriminator
            # Forward
            pred = generator(img_LR, z)

            # GAN loss
            fake_scalar = discriminator(pred.detach())
            loss_fake = criterion_MSE(fake_scalar, fake)
            true_scalar = discriminator(img_HR)
            loss_true = criterion_MSE(true_scalar, valid)

            # Overall Loss and optimize
            optimizer_D.zero_grad()
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_D.step()

            # Record losses
            writer.add_scalar('data/loss_L1', loss_L1.item(), iters_done)
            writer.add_scalar('data/loss_percep', loss_percep.item(),
                              iters_done)
            writer.add_scalar('data/loss_G', loss.item(), iters_done)
            writer.add_scalar('data/loss_D', loss_D.item(), iters_done)

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [L1 Loss: %.4f] [G Loss: %.4f] [G percep Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss_L1.item(),
                   loss_gan.item(), loss_percep.item(), loss_D.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs; G and D decay from
            # their own initial rates
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G, opt.lr_g)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D, opt.lr_d)

        ### Sample data once per epoch (outside the batch loop, matching the
        ### comment's intent)
        if (epoch + 1) % 1 == 0:
            img_list = [pred, img_HR]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

    writer.close()
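
All of these trainers share the same step-decay schedule: the initial learning rate is multiplied by lr_decrease_factor once every lr_decrease_epoch epochs (or every lr_decrease_iter iterations). As a standalone sketch with assumed example values:

def step_decay_lr(init_lr, epoch, decrease_factor=0.5, decrease_epoch=10):
    # e.g. init_lr=1e-4: epochs 0-9 -> 1e-4, 10-19 -> 5e-5, 20-29 -> 2.5e-5
    return init_lr * (decrease_factor ** (epoch // decrease_epoch))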
Example #11
def Trainer_GAN(opt):
    # ----------------------------------------
    #              Initialization
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Define the dataset
    train_imglist = utils.get_jpgs(os.path.join(opt.in_path_train))
    val_imglist = utils.get_jpgs(os.path.join(opt.in_path_val))
    train_dataset = dataset.Qbayer2RGB_dataset(opt, 'train', train_imglist)
    val_dataset = dataset.Qbayer2RGB_dataset(opt, 'val', val_imglist)
    print('The overall number of training images:', len(train_imglist))
    print('The overall number of validation images:', len(val_imglist))

    # Define the dataloader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.train_batch_size,
                                               shuffle=True,
                                               num_workers=opt.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.val_batch_size,
                                             shuffle=False,
                                             num_workers=opt.num_workers,
                                             pin_memory=True)

    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    class ColorLoss(nn.Module):
        def __init__(self):
            super(ColorLoss, self).__init__()
            self.L1loss = nn.L1Loss()

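        # BT.601 RGB -> YUV conversion; the color loss compares images in
        # YUV space rather than RGB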
        def RGB2YUV(self, RGB):
            YUV = RGB.clone()
            YUV[:, 0, :, :] = 0.299 * RGB[:, 0, :, :] + \
                0.587 * RGB[:, 1, :, :] + 0.114 * RGB[:, 2, :, :]
            YUV[:, 1, :, :] = -0.14713 * RGB[:, 0, :, :] - \
                0.28886 * RGB[:, 1, :, :] + 0.436 * RGB[:, 2, :, :]
            YUV[:, 2, :, :] = 0.615 * RGB[:, 0, :, :] - \
                0.51499 * RGB[:, 1, :, :] - 0.10001 * RGB[:, 2, :, :]
            return YUV

        def forward(self, x, y):
            yuv_x = self.RGB2YUV(x)
            yuv_y = self.RGB2YUV(y)
            return self.L1loss(yuv_x, yuv_y)

    yuv_loss = ColorLoss()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer, lr_gd):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model by epoch or by iteration, depending on opt.save_mode
    def save_model(opt, epoch, iteration, len_dataset, generator):
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = '%s_gan_noise%.3f_epoch%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, epoch, opt.train_batch_size)
        if opt.save_mode == 'iter':
            model_name = '%s_gan_noise%.3f_iter%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, iteration, opt.train_batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name,
                                       model_name)
        # Save model
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Tensorboard
    writer = SummaryWriter()

    # For loop training
    for epoch in range(opt.epochs):

        # Record learning rate
        for param_group in optimizer_G.param_groups:
            writer.add_scalar('data/lr', param_group['lr'], epoch)
            print('learning rate = ', param_group['lr'])

        if epoch == 0:
            iters_done = 0

        ### Training
        for i, (in_img, RGBout_img) in enumerate(train_loader):

            # To device: in_img is the network input, RGBout_img the RGB target
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()

            ## Train Discriminator
            # Forward propagation
            out = generator(in_img)

            optimizer_D.zero_grad()
            # Fake samples
            fake_scalar_d = discriminator(in_img, out.detach())
            true_scalar_d = discriminator(in_img, RGBout_img)
            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D.backward()
            #torch.nn.utils.clip_grad_norm(discriminator.parameters(), opt.grad_clip_norm)
            optimizer_D.step()

            ## Train Generator
            # Forward propagation
            out = generator(in_img)

            # GAN loss
            fake_scalar = discriminator(in_img, out)
            L_gan = -torch.mean(fake_scalar) * opt.lambda_gan

            # Perceptual loss features
            fake_B_fea = perceptualnet(utils.normalize_ImageNet_stats(out))
            true_B_fea = perceptualnet(
                utils.normalize_ImageNet_stats(RGBout_img))
            L_percep = opt.lambda_percep * criterion_L1(fake_B_fea, true_B_fea)

            # Pixel loss
            L_pixel = opt.lambda_pixel * criterion_L1(out, RGBout_img)

            # Color loss
            L_color = opt.lambda_color * yuv_loss(out, RGBout_img)

            # Sum up to total loss
            loss = L_pixel + L_percep + L_gan + L_color

            # Record losses
            writer.add_scalar('data/L_pixel', L_pixel.item(), iters_done)
            writer.add_scalar('data/L_percep', L_percep.item(), iters_done)
            writer.add_scalar('data/L_color', L_color.item(), iters_done)
            writer.add_scalar('data/L_gan', L_gan.item(), iters_done)
            writer.add_scalar('data/L_total', loss.item(), iters_done)
            writer.add_scalar('data/loss_D', loss_D.item(), iters_done)

            # Backpropagate gradients
            optimizer_G.zero_grad()
            loss.backward()
            #torch.nn.utils.clip_grad_norm(generator.parameters(), opt.grad_clip_norm)
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i + 1
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] [L_pixel: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(train_loader), loss.item(),
                   L_pixel.item()))
            print(
                "\r[L_percep: %.4f] [L_color: %.4f] [L_gan: %.4f] [loss_D: %.4f] Time_left: %s"
                % (L_percep.item(), L_color.item(), L_gan.item(),
                   loss_D.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), iters_done, len(train_loader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_G,
                                 opt.lr_g)
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_D,
                                 opt.lr_d)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        ### Validation
        val_PSNR = 0
        num_of_val_image = 0

        for j, (in_img, RGBout_img) in enumerate(val_loader):

            # To device: in_img is the network input, RGBout_img the RGB target
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()

            # Forward propagation
            with torch.no_grad():
                out = generator(in_img)

            # Accumulate num of image and val_PSNR
            num_of_val_image += in_img.shape[0]
            val_PSNR += utils.psnr(out, RGBout_img, 1) * in_img.shape[0]
        val_PSNR = val_PSNR / num_of_val_image

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        # Record average PSNR
        writer.add_scalar('data/val_PSNR', val_PSNR, epoch)
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))

    writer.close()
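
The validation loop reports a sample-weighted average PSNR via utils.psnr. For images scaled to [0, 1], PSNR is conventionally derived from the MSE as below; this is a sketch of the usual definition, and the repository's helper may differ in details such as clamping:

import torch

def psnr(pred, target, max_val=1.0):
    # PSNR = 10 * log10(max_val^2 / MSE); higher means closer to the target
    mse = torch.mean((pred - target) ** 2)
    return (10 * torch.log10(max_val ** 2 / mse)).item()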
Example #12
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the two-stage generator model
    def save_model_generator(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'deepfillv2_WGAN_G_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # Save the discriminator model
    def save_model_discriminator(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'deepfillv2_WGAN_D_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # load the model
    def load_model(net, epoch, opt, type='G'):
        """Load the G or D checkpoint saved at the given epoch"""
        if type == 'G':
            model_name = 'deepfillv2_WGAN_G_epoch%d_batchsize%d.pth' % (
                epoch, opt.batch_size)
        else:
            model_name = 'deepfillv2_WGAN_D_epoch%d_batchsize%d.pth' % (
                epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        pretrained_dict = torch.load(model_name)
        net.load_state_dict(pretrained_dict)

    if opt.resume:
        load_model(generator, opt.resume_epoch, opt, type='G')
        load_model(discriminator, opt.resume_epoch, opt, type='D')
        print(
            '--------------------Pretrained Models are Loaded--------------------'
        )

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = train_dataset.InpaintDataset(opt)
    print('The overall number of images is %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.num_workers,
                            pin_memory=True,
                            drop_last=True)

    # ----------------------------------------
    #            Training
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.resume_epoch, opt.epochs):
        for batch_idx, (img, height, width) in enumerate(dataloader):

            img = img.cuda()
            # generate a random free-form mask for every sample in the batch
            mask = torch.empty(img.shape[0], 1, img.shape[2],
                               img.shape[3]).cuda()
            for i in range(img.shape[0]):
                mask[i] = torch.from_numpy(
                    train_dataset.InpaintDataset.random_ff_mask(
                        shape=(height[0],
                               width[0])).astype(np.float32)).cuda()

            # Hinge-loss target tensors: valid is all ones, zero is all zeros
            # (the 1/32 spatial size presumably matches the discriminator's
            # patch output)
            valid = Tensor(
                np.ones((img.shape[0], 1, height[0] // 32, width[0] // 32)))
            zero = Tensor(
                np.zeros((img.shape[0], 1, height[0] // 32, width[0] // 32)))

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [0, 1]

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Loss and optimize
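            # hinge loss for the critic: max(0, 1 + D(fake)) on fake samples
            # and max(0, 1 - D(real)) on real samples, expressed via torch.min
            # with the all-ones (valid) and all-zeros (zero) tensors above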
            loss_fake = -torch.mean(torch.min(zero, -valid - fake_scalar))
            loss_true = -torch.mean(torch.min(zero, -valid + true_scalar))
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # L1 Loss
            first_L1Loss = (first_out - img).abs().mean()
            second_L1Loss = (second_out - img).abs().mean()

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_featuremaps = perceptualnet(second_out)
            second_PerceptualLoss = L1Loss(second_out_featuremaps,
                                           img_featuremaps)

            # Compute losses
            loss = opt.lambda_l1 * first_L1Loss + opt.lambda_l1 * second_L1Loss + \
                   opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_L1Loss.item(), second_L1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

            masked_img = img * (1 - mask) + mask
            mask = torch.cat((mask, mask, mask), 1)
            if (batch_idx + 1) % 40 == 0:
                img_list = [img, mask, masked_img, first_out, second_out]
                name_list = [
                    'gt', 'mask', 'masked_img', 'first_out', 'second_out'
                ]
                utils.save_sample_png(sample_folder=sample_folder,
                                      sample_name='epoch%d' % (epoch + 1),
                                      img_list=img_list,
                                      name_list=name_list,
                                      pixel_max_cnt=255)

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model_generator(generator, (epoch + 1), opt)
        save_model_discriminator(discriminator, (epoch + 1), opt)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
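
A recurring idiom across the inpainting trainers above is compositing the generator output with the ground truth so that only the hole region actually comes from the network. In isolation (shapes are illustrative; the 1-channel mask broadcasts over the RGB channels):

import torch

B, C, H, W = 2, 3, 256, 256
img = torch.rand(B, C, H, W)                    # ground-truth image in [0, 1]
out = torch.rand(B, C, H, W)                    # raw generator output
mask = (torch.rand(B, 1, H, W) > 0.5).float()   # 1 = hole, 0 = known pixel

# keep known pixels from the ground truth, take hole pixels from the generator
out_wholeimg = img * (1 - mask) + out * mask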