Example #1
def main():
    batch_size = 1
    n_epochs = 10
    TRAINING_SIZE = 192
    TEST_SIZE = 48

    train_generator, _ = create_generator("C:\\Users\\Francis\\Documents\\Repositories\\LungSegmentation\\data\\prepared", add_contours=True, n_augments=batch_size)

    test_generator, _ = create_generator("C:\\Users\\Francis\\Documents\\Repositories\\LungSegmentation\\data\\prepared", add_contours=True, is_train=False, n_augments=1)
    model_checkpoint = ModelCheckpoint('dcan.hdf5', monitor='loss', verbose=0, save_best_only=True)

    base_log_dir = os.path.join(os.getcwd(), "logs\\fit\\dcan")
    if not os.path.exists(base_log_dir):
        os.makedirs(base_log_dir)

    log_dir = base_log_dir + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    tensorboard_callback = TensorBoard(log_dir=log_dir)
    early_stopping_callback = EarlyStopping(monitor='val_loss', patience=5)
    model = build_dcan()

    model.fit(train_generator, 
        steps_per_epoch=TRAINING_SIZE,
        epochs=n_epochs,
        verbose=1,
        validation_data=test_generator,
        validation_steps=TEST_SIZE,
        callbacks=[model_checkpoint, tensorboard_callback, early_stopping_callback])
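
The generator objects above are consumed directly by model.fit, which pulls steps_per_epoch batches per epoch; with batch_size=1 as above, each epoch draws 192 single-image batches. A minimal sketch of a paired image/mask generator in the same spirit (the name, shapes, and sampling here are illustrative assumptions, not the repository's create_generator):

import numpy as np

def paired_generator(images, masks, batch_size=1):
    # Yield (input, target) batches forever, as Keras' fit() expects of a generator
    n = len(images)
    while True:
        idx = np.random.randint(0, n, size=batch_size)
        yield images[idx], masks[idx]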
Example #2
def WGAN_tester(opt):
    
    # Load a pretrained generator checkpoint
    def load_model_generator(net, epoch, opt):
        model_name = 'deepfillv2_WGAN_G_epoch%d_batchsize%d.pth' % (epoch, 4)
        model_name = os.path.join('pretrained_model', model_name)
        pretrained_dict = torch.load(model_name)
        net.load_state_dict(pretrained_dict)

    # ----------------------------------------
    #      Initialize testing parameters
    # ----------------------------------------

    # configurations
    results_path = opt.results_path
    if not os.path.exists(results_path):
        os.makedirs(results_path)

    # Build networks
    generator = utils.create_generator(opt).eval()
    print('-------------------------Loading Pretrained Model-------------------------')
    load_model_generator(generator, opt.epoch, opt)
    print('-------------------------Pretrained Model Loaded-------------------------')

    # To device
    generator = generator.cuda()
    
    # ----------------------------------------
    #       Initialize testing dataset
    # ----------------------------------------

    # Define the dataset
    testset = test_dataset.InpaintDataset(opt)
    print('The overall number of images: %d' % len(testset))

    # Define the dataloader
    dataloader = DataLoader(testset, batch_size = opt.batch_size, shuffle = False, num_workers = opt.num_workers, pin_memory = True)
    
    # ----------------------------------------
    #            Testing
    # ----------------------------------------
    # Testing loop
    for batch_idx, (img, mask) in enumerate(dataloader):
        img = img.cuda()
        mask = mask.cuda()

        # Generator output
        with torch.no_grad():
            first_out, second_out = generator(img, mask)

        # Compose network output with the input outside the mask
        first_out_wholeimg = img * (1 - mask) + first_out * mask        # in range [0, 1]
        second_out_wholeimg = img * (1 - mask) + second_out * mask      # in range [0, 1]

        masked_img = img * (1 - mask) + mask
        mask = torch.cat((mask, mask, mask), 1)
        img_list = [second_out_wholeimg]
        name_list = ['second_out']
        utils.save_sample_png(sample_folder = results_path, sample_name = '%d' % (batch_idx + 1), img_list = img_list, name_list = name_list, pixel_max_cnt = 255)
        print('---------------------- batch %d has been finished ----------------------' % (batch_idx + 1))
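
The two compositing lines are the core of the inpainting test: only the hole region (mask == 1) is taken from the network, the rest is copied from the input, so the result stays in [0, 1] whenever both inputs are. A quick self-contained check of that identity on dummy tensors (shapes are assumptions for illustration):

import torch

img = torch.rand(1, 3, 8, 8)    # ground truth in [0, 1]
out = torch.rand(1, 3, 8, 8)    # network prediction
mask = torch.zeros(1, 1, 8, 8)  # 1 inside the hole, 0 elsewhere
mask[..., 2:6, 2:6] = 1.0

whole = img * (1 - mask) + out * mask                     # broadcasts the 1-channel mask over RGB
assert torch.equal(whole * (1 - mask), img * (1 - mask))  # region outside the hole is untouched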
Example #3
def create_networks(opt, checkpoint=None):
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    if checkpoint:
        # Restore the network state
        generator.load_state_dict(checkpoint['G'])
        discriminator.load_state_dict(checkpoint['D'])

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    return generator, discriminator, perceptualnet
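
One caveat when mixing this with checkpoints: state dicts saved from an nn.DataParallel-wrapped model carry a 'module.' prefix on every key, so such a checkpoint will not load into a bare module as-is. A common workaround (a sketch, not part of this repository's utils):

def strip_module_prefix(state_dict):
    # Remove the 'module.' prefix that nn.DataParallel adds to parameter names
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# e.g. generator.load_state_dict(strip_module_prefix(checkpoint['G']))

Loading the checkpoint before wrapping in DataParallel, as create_networks does above, sidesteps the issue in the other direction.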
Example #4
def Pre_train(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    #torch.cuda.set_device(1)

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    utils.check_path(save_folder)
    utils.check_path(sample_folder)

    # Loss functions
    if not opt.no_gpu:
        criterion_L1 = torch.nn.L1Loss().cuda()
        criterion_L2 = torch.nn.MSELoss().cuda()
        criterion_ssim = pytorch_ssim.SSIM().cuda()
    else:
        criterion_L1 = torch.nn.L1Loss()
        criterion_L2 = torch.nn.MSELoss()
        criterion_ssim = pytorch_ssim.SSIM()

    # Initialize Generator
    generator = utils.create_generator(opt)

    # To device
    if not opt.no_gpu:
        if opt.multi_gpu:
            generator = nn.DataParallel(generator)
            generator = generator.cuda()
        else:
            generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                          generator.parameters()),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        target_epoch = opt.epochs - opt.lr_decrease_epoch
        remain_epoch = opt.epochs - epoch
        if epoch >= opt.lr_decrease_epoch:
            lr = opt.lr_g * remain_epoch / target_epoch
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model at checkpoint intervals
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Define the name of trained model
        """
        if opt.save_mode == 'epoch':
            model_name = 'KPN_single_image_epoch%d_bs%d_mu%d_sigma%d.pth' % (epoch, opt.train_batch_size, opt.mu, opt.sigma)
        if opt.save_mode == 'iter':
            model_name = 'KPN_single_image_iter%d_bs%d_mu%d_sigma%d.pth' % (iteration, opt.train_batch_size, opt.mu, opt.sigma)
        """
        if opt.save_mode == 'epoch':
            model_name = 'KPN_rainy_image_epoch%d_bs%d.pth' % (
                epoch, opt.train_batch_size)
        if opt.save_mode == 'iter':
            model_name = 'KPN_rainy_image_iter%d_bs%d.pth' % (
                iteration, opt.train_batch_size)
        save_model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("Using %d GPU(s)" % gpu_num)

    # Define the dataset
    trainset = dataset.DenoisingDataset(opt)
    print('The overall number of training images:', len(trainset))

    # Define the dataloader
    train_loader = DataLoader(trainset,
                              batch_size=opt.train_batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers,
                              pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(train_loader):

            #print("in epoch %d" % i)

            if not opt.no_gpu:
                # To device
                true_input = true_input.cuda()
                true_target = true_target.cuda()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input, true_input)

            # SSIM is a similarity measure, so its negation serves as a loss
            ssim_loss = -criterion_ssim(true_target, fake_target)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + 0.2 * ssim_loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Loss: %.4f %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(train_loader),
                   Pixellevel_L1_Loss.item(), ssim_loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(train_loader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), optimizer_G)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [true_input, fake_target, true_target]
            name_list = ['in', 'pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
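
adjust_learning_rate above implements a linear tail decay: the rate stays at lr_g until lr_decrease_epoch, then falls linearly to zero at the final epoch. A small numeric check of that schedule (values chosen only for illustration):

epochs, lr_decrease_epoch, lr_g = 100, 50, 1e-4
for epoch in (50, 75, 100):
    remain_epoch = epochs - epoch
    target_epoch = epochs - lr_decrease_epoch
    print(epoch, lr_g * remain_epoch / target_epoch)  # 0.0001, 5e-05, 0.0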
Example #5
def Continue_train_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    utils.check_path(save_folder)
    utils.check_path(sample_folder)

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        target_epoch = opt.epochs - opt.lr_decrease_epoch
        remain_epoch = opt.epochs - epoch
        if epoch >= opt.lr_decrease_epoch:
            lr = opt.lr_g * remain_epoch / target_epoch
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model at checkpoint intervals
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(
                        generator.module.state_dict(),
                        'DeblurGANv1_wgan_epoch%d_bs%d.pth' %
                        (epoch, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(
                        generator.module.state_dict(),
                        'DeblurGANv1_wgan_iter%d_bs%d.pth' %
                        (iteration, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(
                        generator.state_dict(),
                        'DeblurGANv1_wgan_epoch%d_bs%d.pth' %
                        (epoch, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(
                        generator.state_dict(),
                        'DeblurGANv1_wgan_iter%d_bs%d.pth' %
                        (iteration, opt.train_batch_size))
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("Using %d GPU(s)" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Define the dataset
    trainset = dataset.DeblurDataset(opt, 'train')
    valset = dataset.DeblurDataset(opt, 'val')
    print('The overall number of training images:', len(trainset))
    print('The overall number of validation images:', len(valset))

    # Define the dataloader
    train_loader = DataLoader(trainset,
                              batch_size=opt.train_batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers,
                              pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=opt.val_batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(train_loader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Train Discriminator
            for j in range(opt.additional_training_d):

                optimizer_D.zero_grad()

                # Generator output
                fake_target = generator(true_input)

                # Fake samples
                fake_scalar_d = discriminator(true_input, fake_target.detach())
                true_scalar_d = discriminator(true_input, true_target)
                # Overall Loss and optimize
                loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
                loss_D.backward()
                optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # GAN Loss
            fake_scalar = discriminator(true_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)

            # Overall Loss and optimize
            loss = opt.lambda_l1 * Pixellevel_L1_Loss + GAN_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(train_loader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(train_loader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_target, true_target]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        ### Validation
        val_PSNR = 0
        num_of_val_image = 0

        for j, (true_input, true_target) in enumerate(val_loader):

            # To device
            # A is for input image, B is for target image
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Forward propagation
            with torch.no_grad():
                fake_target = generator(true_input)

            # Accumulate num of image and val_PSNR
            num_of_val_image += true_input.shape[0]
            val_PSNR += utils.psnr(fake_target, true_target,
                                   1) * true_input.shape[0]
        val_PSNR = val_PSNR / num_of_val_image

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [fake_target, true_target]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        # Record average PSNR
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
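
The WGAN-style objective here works on raw critic scores rather than probabilities: the critic minimizes -mean(D(real)) + mean(D(fake)), and the generator minimizes -mean(D(fake)). A toy illustration of the sign conventions (the scores are made up):

import torch

real_scores = torch.tensor([0.9, 1.1])  # critic output on real pairs
fake_scores = torch.tensor([0.2, 0.1])  # critic output on generated pairs
loss_D = -torch.mean(real_scores) + torch.mean(fake_scores)  # -0.85: real scored higher, low critic loss
loss_G = -torch.mean(fake_scores)                            # generator wants fake scores pushed up

Note the snippet above pairs this objective with plain Adam and no visible Lipschitz constraint (no weight clipping or gradient penalty), which departs from the standard WGAN recipe.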
Example #6
def Trainer(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize SGN
    generator = utils.create_generator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2), weight_decay = opt.weight_decay)
    
    # Learning rate decrease
    def adjust_learning_rate(opt, iteration, optimizer):
        # Set the learning rate to the specific value
        if iteration >= opt.iter_decreased:
            for param_group in optimizer.param_groups:
                param_group['lr'] = opt.lr_decreased

    # Save the model at checkpoint intervals
    def save_model(opt, epoch, iteration, len_dataset, network):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(network.module.state_dict(), 'SGN_epoch%d_bs%d_mu%d_sigma%d.pth' % (epoch, opt.batch_size, opt.mu, opt.sigma))
                    print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(network.module.state_dict(), 'SGN_iter%d_bs%d_mu%d_sigma%d.pth' % (iteration, opt.batch_size, opt.mu, opt.sigma))
                    print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    torch.save(network.state_dict(), 'SGN_epoch%d_bs%d_mu%d_sigma%d.pth' % (epoch, opt.batch_size, opt.mu, opt.sigma))
                    print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(network.state_dict(), 'SGN_iter%d_bs%d_mu%d_sigma%d.pth' % (iteration, opt.batch_size, opt.mu, opt.sigma))
                    print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.DenoisingDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size = opt.batch_size, shuffle = True, num_workers = opt.num_workers, pin_memory = True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (noisy_img, img) in enumerate(dataloader):

            # To device
            noisy_img = noisy_img.cuda()
            img = img.cuda()

            # Train Generator
            optimizer_G.zero_grad()

            # Forward propagation
            recon_img = generator(noisy_img)
            loss = criterion_L1(recon_img, img)

            # Overall Loss and optimize
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] Time_left: %s" %
                ((epoch + 1), opt.epochs, i, len(dataloader), loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (iters_done + 1), optimizer_G)
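
Unlike the epoch-based linear decay in the other trainers, this adjust_learning_rate performs a single hard switch: once the iteration count passes opt.iter_decreased, every parameter group is set to opt.lr_decreased. The param_groups pattern it relies on is standard PyTorch; a tiny standalone sketch:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt_g = torch.optim.Adam(params, lr=1e-4)
for param_group in opt_g.param_groups:   # same pattern as adjust_learning_rate above
    param_group['lr'] = 1e-5
print(opt_g.param_groups[0]['lr'])       # 1e-05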
Example #7
def CycleGAN_LSGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    # Initialize Generator
    # A is for grayscale image
    # B is for color RGB image
    G_AB = utils.create_generator(opt)
    G_BA = utils.create_generator(opt)
    D_A = utils.create_discriminator(opt)
    D_B = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        G_AB = nn.DataParallel(G_AB)
        G_AB = G_AB.cuda()
        G_BA = nn.DataParallel(G_BA)
        G_BA = G_BA.cuda()
        D_A = nn.DataParallel(D_A)
        D_A = D_A.cuda()
        D_B = nn.DataParallel(D_B)
        D_B = D_B.cuda()
    else:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=opt.lr_d,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=opt.lr_d,
                                     betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the models at checkpoint intervals
    def save_model(opt, epoch, iteration, len_dataset, G_AB, G_BA):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            G_AB.module, 'G_AB_LSGAN_epoch%d_bs%d.pth' %
                            (epoch, opt.batch_size))
                        torch.save(
                            G_BA.module, 'G_BA_LSGAN_epoch%d_bs%d.pth' %
                            (epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            G_AB.module, 'G_AB_LSGAN_iter%d_bs%d.pth' %
                            (iteration, opt.batch_size))
                        torch.save(
                            G_BA.module, 'G_BA_LSGAN_iter%d_bs%d.pth' %
                            (iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            G_AB, 'G_AB_LSGAN_epoch%d_bs%d.pth' %
                            (epoch, opt.batch_size))
                        torch.save(
                            G_BA, 'G_BA_LSGAN_epoch%d_bs%d.pth' %
                            (epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            G_AB, 'G_AB_LSGAN_iter%d_bs%d.pth' %
                            (iteration, opt.batch_size))
                        torch.save(
                            G_BA, 'G_BA_LSGAN_iter%d_bs%d.pth' %
                            (iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.DomainTransferDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_A, true_B) in enumerate(dataloader):

            # To device
            # A is for grayscale image
            # B is for color RGB image
            true_A = true_A.cuda()
            true_B = true_B.cuda()

            # Adversarial ground truth
            valid = Tensor(np.ones((true_A.shape[0], 1, 16, 16)))
            fake = Tensor(np.zeros((true_A.shape[0], 1, 16, 16)))

            # Train Generator
            optimizer_G.zero_grad()

            # Identity Loss
            loss_identity_A = criterion_L1(G_BA(true_A), true_A)
            loss_identity_B = criterion_L1(G_AB(true_B), true_B)
            loss_identity = (loss_identity_A + loss_identity_B) / 2

            # GAN Loss
            fake_B = G_AB(true_A)
            loss_GAN_AB = criterion_MSE(D_B(fake_B), valid)
            fake_A = G_BA(true_B)
            loss_GAN_BA = criterion_MSE(D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle-consistency Loss
            recon_A = G_BA(fake_B)
            loss_cycle_A = criterion_L1(recon_A, true_A)
            recon_B = G_AB(fake_A)
            loss_cycle_B = criterion_L1(recon_B, true_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Overall Loss and optimize
            loss = loss_GAN + opt.lambda_cycle * loss_cycle + opt.lambda_identity * loss_identity
            loss.backward()
            optimizer_G.step()

            # Train Discriminator A
            optimizer_D_A.zero_grad()

            # Fake samples
            fake_scalar_d = D_A(fake_A.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)

            # True samples
            true_scalar_d = D_A(true_A)
            loss_true = criterion_MSE(true_scalar_d, valid)

            # Overall Loss and optimize
            loss_D_A = 0.5 * (loss_fake + loss_true)
            loss_D_A.backward()
            optimizer_D_A.step()

            # Train Discriminator B
            optimizer_D_B.zero_grad()

            # Fake samples
            fake_scalar_d = D_B(fake_B.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)

            # True samples
            true_scalar_d = D_B(true_B)
            loss_true = criterion_MSE(true_scalar_d, valid)

            # Overall Loss and optimize
            loss_D_B = 0.5 * (loss_fake + loss_true)
            loss_D_B.backward()
            optimizer_D_B.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [D_A Loss: %.4f] [D_B Loss: %.4f] [G GAN Loss: %.4f] [G Cycle Loss: %.4f] [G Identity Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_D_A.item(), loss_D_B.item(), loss_GAN.item(),
                   loss_cycle.item(), loss_identity.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       G_AB, G_BA)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D_A)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D_B)
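
The three generator-side terms mirror the CycleGAN formulation: an LSGAN adversarial loss against the 16x16 patch outputs, a cycle-consistency L1 between A -> B -> A reconstructions and the originals, and an identity L1 that penalizes G_BA for altering images already in domain A. A toy check that perfect cycle mappings incur zero cycle loss (using nn.Identity as stand-in generators):

import torch
from torch import nn

G_AB, G_BA = nn.Identity(), nn.Identity()
a = torch.rand(1, 3, 4, 4)
cycle_a = G_BA(G_AB(a))                # A -> B -> A round trip
print(nn.L1Loss()(cycle_a, a).item())  # 0.0 for identity mappings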
Example #8
def Pre_train(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_L2 = torch.nn.MSELoss().cuda()

    # Initialize Generator
    G = utils.create_generator(opt)
    D_cVAE = utils.create_discriminator(opt)
    D_cLR = utils.create_discriminator(opt)
    E = utils.create_encoder(opt)

    # To device
    if opt.multi_gpu:
        G = nn.DataParallel(G)
        G = G.cuda()
        D_cVAE = nn.DataParallel(D_cVAE)
        D_cVAE = D_cVAE.cuda()
        D_cLR = nn.DataParallel(D_cLR)
        D_cLR = D_cLR.cuda()
        E = nn.DataParallel(E)
        E = E.cuda()
    else:
        G = G.cuda()
        D_cVAE = D_cVAE.cuda()
        D_cLR = D_cLR.cuda()
        E = E.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(G.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D_cVAE = torch.optim.Adam(D_cVAE.parameters(),
                                        lr=opt.lr,
                                        betas=(opt.b1, opt.b2),
                                        weight_decay=opt.weight_decay)
    optimizer_D_cLR = torch.optim.Adam(D_cLR.parameters(),
                                       lr=opt.lr,
                                       betas=(opt.b1, opt.b2),
                                       weight_decay=opt.weight_decay)
    optimizer_E = torch.optim.Adam(E.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        decay_rate = 1.0 - (max(0, epoch - opt.start_decrease_epoch) //
                            opt.lr_decrease_divide)
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        lr = opt.lr * decay_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'Pre_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'Pre_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'Pre_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'Pre_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.DomainTransferDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):

            # To device, and separate data for cVAE_GAN and cLR_GAN
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            cVAE_data = {
                'img': true_input[[0], :, :, :],
                'ground_truth': true_target[[0], :, :, :]
            }
            cLR_data = {
                'img': true_input[[1], :, :, :],
                'ground_truth': true_target[[1], :, :, :]
            }
            ''' ----------------------------- 1. Train D ----------------------------- '''
            #############   Step 1. D loss in cVAE-GAN #############

            # Encoded latent vector
            mu, log_variance = E(cVAE_data['ground_truth'])
            std = torch.exp(log_variance / 2)
            random_z = torch.randn(1, opt.z_dim).cuda()
            encoded_z = (random_z * std) + mu

            # Generate fake image
            fake_img_cVAE = G(cVAE_data['img'], encoded_z)

            # Get scores and loss
            real_d_cVAE_1, real_d_cVAE_2 = D_cVAE(cVAE_data['ground_truth'])
            fake_d_cVAE_1, fake_d_cVAE_2 = D_cVAE(fake_img_cVAE.detach())

            # mse_loss for LSGAN (targets must be tensors, not Python ints)
            D_loss_cVAE_1 = criterion_L2(real_d_cVAE_1, torch.ones_like(real_d_cVAE_1)) \
                + criterion_L2(fake_d_cVAE_1, torch.zeros_like(fake_d_cVAE_1))
            D_loss_cVAE_2 = criterion_L2(real_d_cVAE_2, torch.ones_like(real_d_cVAE_2)) \
                + criterion_L2(fake_d_cVAE_2, torch.zeros_like(fake_d_cVAE_2))

            #############   Step 2. D loss in cLR-GAN   #############

            # Random latent vector
            random_z = torch.randn(1, opt.z_dim).cuda()

            # Generate fake image
            fake_img_cLR = G(cLR_data['img'], random_z)

            # Get scores and loss
            real_d_cLR_1, real_d_cLR_2 = D_cLR(cLR_data['ground_truth'])
            fake_d_cLR_1, fake_d_cLR_2 = D_cLR(fake_img_cLR.detach())

            D_loss_cLR_1 = criterion_L2(real_d_cLR_1, torch.ones_like(real_d_cLR_1)) \
                + criterion_L2(fake_d_cLR_1, torch.zeros_like(fake_d_cLR_1))
            D_loss_cLR_2 = criterion_L2(real_d_cLR_2, torch.ones_like(real_d_cLR_2)) \
                + criterion_L2(fake_d_cLR_2, torch.zeros_like(fake_d_cLR_2))

            D_loss = D_loss_cVAE_1 + D_loss_cLR_1 + D_loss_cVAE_2 + D_loss_cLR_2

            # Update
            optimizer_D_cVAE.zero_grad()
            optimizer_D_cLR.zero_grad()
            D_loss.backward()
            optimizer_D_cVAE.step()
            optimizer_D_cLR.step()
            ''' ----------------------------- 2. Train G & E ----------------------------- '''
            ############# Step 1. GAN loss to fool discriminator (cVAE_GAN and cLR_GAN) #############

            # Encoded latent vector
            mu, log_variance = E(cVAE_data['ground_truth'])
            std = torch.exp(log_variance / 2)
            random_z = torch.randn(1, opt.z_dim).cuda()
            encoded_z = (random_z * std) + mu

            # Generate fake image and get adversarial loss
            fake_img_cVAE = G(cVAE_data['img'], encoded_z)
            fake_d_cVAE_1, fake_d_cVAE_2 = D_cVAE(fake_img_cVAE)

            GAN_loss_cVAE_1 = criterion_L2(fake_d_cVAE_1, torch.ones_like(fake_d_cVAE_1))
            GAN_loss_cVAE_2 = criterion_L2(fake_d_cVAE_2, torch.ones_like(fake_d_cVAE_2))

            # Random latent vector
            random_z = torch.randn(1, opt.z_dim).cuda()

            # Generate fake image and get adversarial loss
            fake_img_cLR = G(cLR_data['img'], random_z)
            fake_d_cLR_1, fake_d_cLR_2 = D_cLR(fake_img_cLR)

            GAN_loss_cLR_1 = criterion_L2(fake_d_cLR_1, torch.ones_like(fake_d_cLR_1))
            GAN_loss_cLR_2 = criterion_L2(fake_d_cLR_2, torch.ones_like(fake_d_cLR_2))

            G_GAN_loss = GAN_loss_cVAE_1 + GAN_loss_cVAE_2 + GAN_loss_cLR_1 + GAN_loss_cLR_2
            G_GAN_loss = opt.lambda_gan * G_GAN_loss

            ############# Step 2. KL-divergence with N(0, 1) (cVAE-GAN) #############

            KL_div_loss = opt.lambda_kl * torch.sum(
                0.5 * (mu**2 + torch.exp(log_variance) - log_variance - 1))

            ############# Step 3. Reconstruction of ground truth image (|G(A, z) - B|) (cVAE-GAN) #############
            img_recon_loss = opt.lambda_recon * criterion_L1(
                fake_img_cVAE, cVAE_data['ground_truth'])

            EG_loss = G_GAN_loss + KL_div_loss + img_recon_loss
            optimizer_G.zero_grad()
            optimizer_E.zero_grad()
            EG_loss.backward(retain_graph=True)
            optimizer_G.step()
            optimizer_E.step()
            ''' ----------------------------- 3. Train ONLY G ----------------------------- '''
            ############ Step 1. Reconstruction of random latent code (|E(G(A, z)) - z|) (cLR-GAN) ############

            # This step should update ONLY G.
            mu_, log_variance_ = E(fake_img_cLR)

            G_alone_loss = opt.lambda_z * criterion_L1(mu_, random_z)

            optimizer_G.zero_grad()
            G_alone_loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [D Loss: %.4f] [GAN Loss: %.4f] [Recon Loss: %.4f] [KL Loss: %.4f] [z Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), D_loss.item(),
                   G_GAN_loss.item(), img_recon_loss.item(),
                   KL_div_loss.item(), G_alone_loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), G)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D_cVAE)
            adjust_learning_rate(opt, (epoch + 1), optimizer_D_cLR)
            adjust_learning_rate(opt, (epoch + 1), optimizer_E)
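
The KL term written out at Step 2 is the closed-form divergence between the encoder's diagonal Gaussian N(mu, exp(log_variance)) and the standard normal: 0.5 * sum(mu^2 + exp(log_variance) - log_variance - 1). A quick sanity check that it vanishes exactly when the encoder outputs the standard normal:

import torch

mu = torch.zeros(8)
log_variance = torch.zeros(8)   # variance = 1
kl = torch.sum(0.5 * (mu ** 2 + torch.exp(log_variance) - log_variance - 1))
print(kl.item())                # 0.0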
Example #9
def Trainer_GAN(opt):
    # ----------------------------------------
    #              Initialization
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = os.path.join(opt.save_path, opt.task_name)
    sample_folder = os.path.join(opt.sample_path, opt.task_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Initialize networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Define the dataset
    train_imglist = utils.get_jpgs(os.path.join(opt.in_path_train))
    val_imglist = utils.get_jpgs(os.path.join(opt.in_path_val))
    train_dataset = dataset.Qbayer2RGB_dataset(opt, 'train', train_imglist)
    val_dataset = dataset.Qbayer2RGB_dataset(opt, 'val', val_imglist)
    print('The overall number of training images:', len(train_imglist))
    print('The overall number of validation images:', len(val_imglist))

    # Define the dataloader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.train_batch_size,
                                               shuffle=True,
                                               num_workers=opt.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.val_batch_size,
                                             shuffle=False,
                                             num_workers=opt.num_workers,
                                             pin_memory=True)

    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    class ColorLoss(nn.Module):
        def __init__(self):
            super(ColorLoss, self).__init__()
            self.L1loss = nn.L1Loss()

        def RGB2YUV(self, RGB):
            # BT.601 RGB -> YUV conversion on channel-first tensors
            YUV = RGB.clone()
            YUV[:, 0, :, :] = 0.299 * RGB[:, 0, :, :] + 0.587 * RGB[:, 1, :, :] + 0.114 * RGB[:, 2, :, :]
            YUV[:, 1, :, :] = -0.14713 * RGB[:, 0, :, :] - 0.28886 * RGB[:, 1, :, :] + 0.436 * RGB[:, 2, :, :]
            YUV[:, 2, :, :] = 0.615 * RGB[:, 0, :, :] - 0.51499 * RGB[:, 1, :, :] - 0.10001 * RGB[:, 2, :, :]
            return YUV

        def forward(self, x, y):
            yuv_x = self.RGB2YUV(x)
            yuv_y = self.RGB2YUV(y)
            return self.L1loss(yuv_x, yuv_y)

    yuv_loss = ColorLoss()
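    # Note: the Y row of the BT.601 matrix above (0.299, 0.587, 0.114) sums to 1,
    # so luma is preserved for grayscale inputs; taking L1 in YUV space weights
    # chroma errors (color casts) differently from the plain RGB pixel loss used below.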

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer, lr_gd):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = lr_gd * (opt.lr_decrease_factor
                          **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model at checkpoint intervals
    def save_model(opt, epoch, iteration, len_dataset, generator):
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = '%s_gan_noise%.3f_epoch%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, epoch, opt.train_batch_size)
        if opt.save_mode == 'iter':
            model_name = '%s_gan_noise%.3f_iter%d_bs%d.pth' % (
                opt.net_mode, opt.noise_level, iteration, opt.train_batch_size)
        save_model_path = os.path.join(opt.save_path, opt.task_name,
                                       model_name)
        # Save model
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Tensorboard
    writer = SummaryWriter()

    # For loop training
    for epoch in range(opt.epochs):

        # Record learning rate
        for param_group in optimizer_G.param_groups:
            writer.add_scalar('data/lr', param_group['lr'], epoch)
            print('learning rate = ', param_group['lr'])

        if epoch == 0:
            iters_done = 0

        ### Training
        for i, (in_img, RGBout_img) in enumerate(train_loader):

            # To device
            # A is for input image, B is for target image
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()

            ## Train Discriminator
            # Forward propagation
            out = generator(in_img)

            optimizer_D.zero_grad()
            # Fake samples
            fake_scalar_d = discriminator(in_img, out.detach())
            true_scalar_d = discriminator(in_img, RGBout_img)
            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
            loss_D.backward()
            #torch.nn.utils.clip_grad_norm(discriminator.parameters(), opt.grad_clip_norm)
            optimizer_D.step()

            ## Train Generator
            # Forward propagation
            out = generator(in_img)

            # GAN loss
            fake_scalar = discriminator(in_img, out)
            L_gan = -torch.mean(fake_scalar) * opt.lambda_gan

            # Perceptual loss features
            fake_B_fea = perceptualnet(utils.normalize_ImageNet_stats(out))
            true_B_fea = perceptualnet(
                utils.normalize_ImageNet_stats(RGBout_img))
            L_percep = opt.lambda_percep * criterion_L1(fake_B_fea, true_B_fea)

            # Pixel loss
            L_pixel = opt.lambda_pixel * criterion_L1(out, RGBout_img)

            # Color loss
            L_color = opt.lambda_color * yuv_loss(out, RGBout_img)

            # Sum up to total loss
            loss = L_pixel + L_percep + L_gan + L_color

            # Record losses
            writer.add_scalar('data/L_pixel', L_pixel.item(), iters_done)
            writer.add_scalar('data/L_percep', L_percep.item(), iters_done)
            writer.add_scalar('data/L_color', L_color.item(), iters_done)
            writer.add_scalar('data/L_gan', L_gan.item(), iters_done)
            writer.add_scalar('data/L_total', loss.item(), iters_done)
            writer.add_scalar('data/loss_D', loss_D.item(), iters_done)

            # Backpropagate gradients
            optimizer_G.zero_grad()
            loss.backward()
            #torch.nn.utils.clip_grad_norm(generator.parameters(), opt.grad_clip_norm)
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i + 1
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] [L_pixel: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(train_loader), loss.item(),
                   L_pixel.item()))
            print(
                "\r[L_percep: %.4f] [L_color: %.4f] [L_gan: %.4f] [loss_D: %.4f] Time_left: %s"
                % (L_percep.item(), L_color.item(), L_gan.item(),
                   loss_D.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), iters_done, len(train_loader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_G,
                                 opt.lr_g)
            adjust_learning_rate(opt, (epoch + 1), iters_done, optimizer_D,
                                 opt.lr_d)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        ### Validation
        val_PSNR = 0
        num_of_val_image = 0

        for j, (in_img, RGBout_img) in enumerate(val_loader):

            # To device
            # A is for input image, B is for target image
            in_img = in_img.cuda()
            RGBout_img = RGBout_img.cuda()

            # Forward propagation
            with torch.no_grad():
                out = generator(in_img)

            # Accumulate num of image and val_PSNR
            num_of_val_image += in_img.shape[0]
            val_PSNR += utils.psnr(out, RGBout_img, 1) * in_img.shape[0]
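            # utils.psnr presumably returns 10 * log10(peak^2 / MSE) with
            # peak = 1 here; weighting by the batch size makes the running
            # sum divide out to a per-image average below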
        val_PSNR = val_PSNR / num_of_val_image

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [out, RGBout_img]
            name_list = ['pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='val_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

        # Record average PSNR
        writer.add_scalar('data/val_PSNR', val_PSNR, epoch)
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))

    writer.close()
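
The critic loss above is the plain WGAN objective (-E[D(real)] + E[D(fake)]) with no Lipschitz constraint; only the commented-out gradient clipping hints at one. A common remedy is the WGAN-GP gradient penalty. A minimal sketch, assuming a conditional critic called as discriminator(in_img, out) like the one above (the helper name and lambda_gp default are illustrative, not part of the original code):

import torch

def gradient_penalty(discriminator, in_img, real, fake, lambda_gp=10.0):
    # Random interpolation between real and generated samples
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    # Critic score on the interpolated samples
    score = discriminator(in_img, interp)
    # Gradient of the score w.r.t. the interpolated input
    grad = torch.autograd.grad(outputs=score, inputs=interp,
                               grad_outputs=torch.ones_like(score),
                               create_graph=True, retain_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    # Penalize deviation of the per-sample gradient norm from 1
    return lambda_gp * ((grad.norm(2, dim=1) - 1) ** 2).mean()

If used, the penalty would simply be added to loss_D before loss_D.backward().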
Example #10
0
def Trainer_GAN(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    print("Batch size is changed to %d" % opt.batch_size)
    print("Number of workers is changed to %d" % opt.num_workers)

    # Build path folder
    utils.check_path(opt.save_path)
    utils.check_path(opt.sample_path)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(optimizer, epoch, opt, init_lr):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = init_lr * (opt.lr_decrease_factor
                        **(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model periodically
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        model_name = 'GrayInpainting_GAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images is %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (grayscale, mask) in enumerate(dataloader):

            # Load and put to cuda
            grayscale = grayscale.cuda()  # out: [B, 1, 256, 256]
            mask = mask.cuda()  # out: [B, 1, 256, 256]

            # LSGAN vectors
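            # all-ones / all-zeros target maps matching the discriminator's
            # 8x8 patch output (PatchGAN-style); MSE against these targets
            # gives the least-squares GAN objective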
            valid = Tensor(np.ones((grayscale.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((grayscale.shape[0], 1, 8, 8)))

            # ----------------------------------------
            #           Train Discriminator
            # ----------------------------------------
            optimizer_d.zero_grad()

            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 -
                                        mask) + out * mask  # in range [0, 1]

            # Fake samples
            fake_scalar = discriminator(out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(grayscale, mask)
            # LSGAN losses for fake and true samples
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            # Overall loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            # ----------------------------------------
            #             Train Generator
            # ----------------------------------------
            optimizer_g.zero_grad()

            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 -
                                        mask) + out * mask  # in range [0, 1]

            # Mask L1 Loss
            MaskL1Loss = L1Loss(out_wholeimg, grayscale)

            # GAN Loss
            fake_scalar = discriminator(out_wholeimg, mask)
            MaskGAN_Loss = MSELoss(fake_scalar, valid)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            out_3c = torch.cat((out_wholeimg, out_wholeimg, out_wholeimg), 1)
            grayscale_3c = torch.cat((grayscale, grayscale, grayscale), 1)
            out_featuremaps = perceptualnet(out_3c)
            gt_featuremaps = perceptualnet(grayscale_3c)
            PerceptualLoss = L1Loss(out_featuremaps, gt_featuremaps)

            # Compute losses
            loss = opt.lambda_l1 * MaskL1Loss + opt.lambda_perceptual * PerceptualLoss + opt.lambda_gan * MaskGAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [Perceptual Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), PerceptualLoss.item(), loss_D.item(),
                   MaskGAN_Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(optimizer_g, (epoch + 1), opt, opt.lr_g)
        adjust_learning_rate(optimizer_d, (epoch + 1), opt, opt.lr_d)

        # Save the model
        save_model(generator, (epoch + 1), opt)
        utils.sample(grayscale, mask, out_wholeimg, opt.sample_path,
                     (epoch + 1))
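
The discriminator and generator updates above follow the least-squares GAN pattern. A condensed sketch of the two objectives (helper names are mine, not from the original code):

import torch
import torch.nn as nn

mse = nn.MSELoss()

def lsgan_d_loss(true_scalar, fake_scalar):
    # Push real scores toward 1 and fake scores toward 0
    valid = torch.ones_like(true_scalar)
    fake = torch.zeros_like(fake_scalar)
    return 0.5 * (mse(true_scalar, valid) + mse(fake_scalar, fake))

def lsgan_g_loss(fake_scalar):
    # The generator tries to make fake scores look real (close to 1)
    return mse(fake_scalar, torch.ones_like(fake_scalar))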
Example #11
0
def Pre_train(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    utils.check_path(save_folder)
    utils.check_path(sample_folder)

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        target_epoch = opt.epochs - opt.lr_decrease_epoch
        remain_epoch = opt.epochs - epoch
        if epoch >= opt.lr_decrease_epoch:
            lr = opt.lr_g * remain_epoch / target_epoch
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
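
    # Worked example with hypothetical values: epochs = 100,
    # lr_decrease_epoch = 50 and lr_g = 1e-4 keep lr at 1e-4 until epoch 50,
    # then decay it linearly (5e-5 at epoch 75) toward 0 at the final epoch.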

    # Save the model at certain epochs or iterations
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Define the name of trained model
        if opt.save_mode == 'epoch':
            model_name = 'KPN_single_image_epoch%d_bs%d_mu%d_sigma%d.pth' % (
                epoch, opt.train_batch_size, opt.mu, opt.sigma)
        if opt.save_mode == 'iter':
            model_name = 'KPN_single_image_iter%d_bs%d_mu%d_sigma%d.pth' % (
                iteration, opt.train_batch_size, opt.mu, opt.sigma)
        save_model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    #opt.val_batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Define the dataset
    trainset = dataset.DenoisingDataset(opt)
    print('The overall number of training images:', len(trainset))

    # Define the dataloader
    train_loader = DataLoader(trainset,
                              batch_size=opt.train_batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers,
                              pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(train_loader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input, true_input)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(train_loader),
                   Pixellevel_L1_Loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(train_loader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), optimizer_G)

        ### Sample data every epoch
        if (epoch + 1) % 1 == 0:
            img_list = [true_input, fake_target, true_target]
            name_list = ['in', 'pred', 'gt']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='train_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
Example #12
0
def Trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    if opt.cudnn_benchmark == True:
        cudnn.benchmark = True
    else:
        cudnn.benchmark = False

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate_g(optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = opt.lr_g * (opt.lr_decrease_factor
                         **(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_learning_rate_d(optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = opt.lr_d * (opt.lr_decrease_factor
                         **(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model periodically
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(
                    net.module, 'ContextureEncoder_epoch%d_batchsize%d.pth' %
                    (epoch, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(
                    net, 'ContextureEncoder_epoch%d_batchsize%d.pth' %
                    (epoch, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images is %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load mask (shape: [B, 1, H, W]), masked_img (shape: [B, 3, H, W]), img (shape: [B, 3, H, W]) and put it to cuda
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            masked_img = img * (1 - mask)
            fake = generator(masked_img)

            # Fake samples
            fake_scalar = discriminator(fake.detach())
            # True samples
            true_scalar = discriminator(img)

            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # forward propagation
            fusion_fake = img * (1 - mask) + fake * mask  # in range [-1, 1]

            # Mask L1 Loss
            MaskL1Loss = L1Loss(fusion_fake, img)

            # GAN Loss
            fake_scalar = discriminator(fusion_fake)
            GAN_Loss = -torch.mean(fake_scalar)

            # Compute losses
            loss = MaskL1Loss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s"
                %
                ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                 MaskL1Loss.item(), loss_D.item(), GAN_Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate_g(optimizer_g, (epoch + 1), opt)
        adjust_learning_rate_d(optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)
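
The masked composite used above (keep unmasked pixels from the input, take the hole region from the generator) recurs throughout these examples. As a small illustrative helper (the name is mine, not from the original code):

def composite(img, fake, mask):
    # mask == 1 marks the hole: copy the generator output there and
    # keep the original image everywhere else
    return img * (1 - mask) + fake * mask
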
def Trainer_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator_a, generator_b = utils.create_generator(opt)
    discriminator_a, discriminator_b = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        generator_a = nn.DataParallel(generator_a)
        generator_a = generator_a.cuda()
        generator_b = nn.DataParallel(generator_b)
        generator_b = generator_b.cuda()
        discriminator_a = nn.DataParallel(discriminator_a)
        discriminator_a = discriminator_a.cuda()
        discriminator_b = nn.DataParallel(discriminator_b)
        discriminator_b = discriminator_b.cuda()
    else:
        generator_a = generator_a.cuda()
        generator_b = generator_b.cuda()
        discriminator_a = discriminator_a.cuda()
        discriminator_b = discriminator_b.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(generator_a.parameters(),
                                                   generator_b.parameters()),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D_a = torch.optim.Adam(discriminator_a.parameters(),
                                     lr=opt.lr_d,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_b = torch.optim.Adam(discriminator_b.parameters(),
                                     lr=opt.lr_d,
                                     betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        #Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model at certain epochs or iterations
    def save_model(opt, epoch, iteration, len_dataset, generator_a,
                   generator_b):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator_a.module,
                            'WGAN_DRIT_epoch%d_bs%d_a.pth' %
                            (epoch, opt.batch_size))
                        torch.save(
                            generator_b.module,
                            'WGAN_DRIT_epoch%d_bs%d_b.pth' %
                            (epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator_a.module, 'WGAN_DRIT_iter%d_bs%d_a.pth' %
                            (iteration, opt.batch_size))
                        torch.save(
                            generator_b.module, 'WGAN_DRIT_iter%d_bs%d_b.pth' %
                            (iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator_a, 'WGAN_DRIT_epoch%d_bs%d_a.pth' %
                            (epoch, opt.batch_size))
                        torch.save(
                            generator_b, 'WGAN_DRIT_epoch%d_bs%d_b.pth' %
                            (epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator_a, 'WGAN_DRIT_iter%d_bs%d_a.pth' %
                            (iteration, opt.batch_size))
                        torch.save(
                            generator_b, 'WGAN_DRIT_iter%d_bs%d_b.pth' %
                            (iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    dataloader = utils.create_dataloader(opt)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (img_a, img_b) in enumerate(dataloader):

            # To device
            img_a = img_a.cuda()
            img_b = img_b.cuda()

            # Sampled style codes (prior)
            prior_s_a = torch.randn(img_a.shape[0], opt.style_dim).cuda()
            prior_s_b = torch.randn(img_b.shape[0], opt.style_dim).cuda()

            # ----------------------------------------
            #              Train Generator
            # ----------------------------------------
            # Note that:
            # input / output image dimension: [B, 3, 256, 256]
            # content_code dimension: [B, 256, 64, 64]
            # style_code dimension: [B, 8]
            # generator_a is related to domain a / style a
            # generator_b is related to domain b / style b

            optimizer_G.zero_grad()

            # Get shared latent representation
            c_a, s_a = generator_a.encode(img_a)
            c_b, s_b = generator_b.encode(img_b)

            # Reconstruct images
            img_aa_recon = generator_a.decode(c_a, s_a)
            img_bb_recon = generator_b.decode(c_b, s_b)

            # Translate images
            img_ba = generator_a.decode(c_b, prior_s_a)
            img_ab = generator_b.decode(c_a, prior_s_b)

            # Cycle code translation
            c_b_recon, s_a_recon = generator_a.encode(img_ba)
            c_a_recon, s_b_recon = generator_b.encode(img_ab)

            # Cycle image translation
            img_aa_recon_cycle = generator_a.decode(
                c_a_recon, s_a) if opt.lambda_cycle > 0 else 0
            img_bb_recon_cycle = generator_b.decode(
                c_b_recon, s_b) if opt.lambda_cycle > 0 else 0

            # Losses
            loss_id_1 = opt.lambda_id * criterion_L1(img_aa_recon, img_a)
            loss_id_2 = opt.lambda_id * criterion_L1(img_bb_recon, img_b)
            loss_s_1 = opt.lambda_style * criterion_L1(s_a_recon, prior_s_a)
            loss_s_2 = opt.lambda_style * criterion_L1(s_b_recon, prior_s_b)
            loss_c_1 = opt.lambda_content * criterion_L1(
                c_a_recon, c_a.detach())
            loss_c_2 = opt.lambda_content * criterion_L1(
                c_b_recon, c_b.detach())
            loss_cycle_1 = opt.lambda_cycle * criterion_L1(
                img_aa_recon_cycle, img_a) if opt.lambda_cycle > 0 else 0
            loss_cycle_2 = opt.lambda_cycle * criterion_L1(
                img_bb_recon_cycle, img_b) if opt.lambda_cycle > 0 else 0

            # GAN Loss
            fake_scalar_a = discriminator_a(img_ba)
            fake_scalar_b = discriminator_b(img_ab)
            loss_gan1 = -opt.lambda_gan * torch.mean(fake_scalar_a)
            loss_gan2 = -opt.lambda_gan * torch.mean(fake_scalar_b)

            # Overall Losses and optimization
            loss_G = loss_id_1 + loss_id_2 + loss_s_1 + loss_s_2 + loss_c_1 + loss_c_2 + loss_cycle_1 + loss_cycle_2 + loss_gan1 + loss_gan2
            loss_G.backward()
            optimizer_G.step()

            # ----------------------------------------
            #            Train Discriminator
            # ----------------------------------------

            optimizer_D_a.zero_grad()
            optimizer_D_b.zero_grad()

            # D_a
            fake_scalar_a = discriminator_a(img_ba.detach())
            true_scalar_a = discriminator_a(img_a)
            loss_D_a = torch.mean(fake_scalar_a) - torch.mean(true_scalar_a)
            loss_D_a.backward()
            optimizer_D_a.step()

            # D_b
            fake_scalar_b = discriminator_b(img_ab.detach())
            true_scalar_b = discriminator_b(img_b)
            loss_D_b = torch.mean(fake_scalar_b) - torch.mean(true_scalar_b)
            loss_D_b.backward()
            optimizer_D_b.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] [Style Loss: %.4f] [Content Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   (loss_id_1 + loss_id_2).item(),
                   (loss_s_1 + loss_s_2).item(), (loss_c_1 + loss_c_2).item(),
                   (loss_gan1 + loss_gan2).item(),
                   (loss_D_a + loss_D_b).item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator_a, generator_b)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D_a)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D_b)
                        help='the folder name of the b domain')
    parser.add_argument('--imgsize',
                        type=int,
                        default=128,
                        help='the image size')
    opt = parser.parse_args()

    utils.check_path(opt.save_path)

    # Define the dataset
    # a = 'cat'; b = 'human'
    testloader = utils.create_dataloader(opt)
    print('The overall number of images:', len(testloader))

    # Define networks
    generator_a, generator_b = utils.create_generator(opt)
    generator_a = generator_a.cuda()
    generator_b = generator_b.cuda()

    # Forward
    for i, (img_a, img_b) in enumerate(testloader):
        # To device
        img_a = img_a.cuda()
        img_b = img_b.cuda()
        # Forward
        with torch.no_grad():
            out = generator_b(img_a, img_a)
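        # Post-process: CHW tensor -> HWC numpy array, map [-1, 1] to
        # [0, 255], and swap RGB -> BGR for OpenCV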
        out = out.squeeze(0).detach().permute(1, 2, 0).cpu().numpy()
        out = ((out + 1) * 127.5).clip(0, 255)
        out = out.astype(np.uint8)[:, :, [2, 1, 0]]
        # Save
Example #15
0
def train(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    vgg_indices = opt.vgg_indices
    vggloss_weights = opt.lambda_vgg

    # Loss functions
    if 'l1' in opt.loss:
        criterion_L1 = torch.nn.L1Loss().cuda()

    if 'ssim' in opt.loss:
        ssimLoss = SSIM().cuda()

    if 'tv' in opt.loss:
        totalvar_loss = tv_loss().cuda()

    if 'color' in opt.loss:
        csColorLoss = color_loss().cuda()

    if 'grad' in opt.loss:
        gradLoss = GradientLoss().cuda()

    if 'vgg' in opt.loss:
        vggLoss = VGGLoss(opt).cuda().eval()

    # Initialize Generator
    generator = utils.create_generator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        #Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if epoch % opt.save_by_epoch == 0:
                torch.save(
                    net.module, './model/epoch%d_batchsize%d.pth' %
                    (epoch + opt.model_offset, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.save_by_epoch == 0:
                torch.save(
                    net, './model/epoch%d_batchsize%d.pth' %
                    (epoch + opt.model_offset, opt.batch_size))
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # Valid the model
    def valid_model(testloader, net, epoch, opt):
        if epoch % opt.valid_by_epoch == 0:
            net.eval()
            with torch.no_grad():
                for i, (true_input, img_name) in enumerate(testloader):
                    if opt.num_channels == 1:
                        fake_target = net(true_input[:, 0:1, :, :].cuda())

                        h = fake_target.shape[2]
                        w = fake_target.shape[3]
                        y = fake_target[0].clamp(
                            0, 1).detach().cpu().numpy().reshape(1, h, w)[0]
                        y = y * (235 - 16) + 16

                        true_input_yuv = F.interpolate(true_input,
                                                       size=[h, w],
                                                       mode="nearest")
                        true_input_yuv = true_input_yuv[0].detach().cpu(
                        ).numpy().transpose((1, 2, 0))
                        true_input_yuv[:, :, 0] = y[:, :]
                        true_input_yuv[:, :,
                                       1] = true_input_yuv[:, :,
                                                           1] * (240 - 16) + 16
                        true_input_yuv[:, :,
                                       2] = true_input_yuv[:, :,
                                                           2] * (240 - 16) + 16

                        y_hr = true_input_yuv.copy()
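                        # BT.601 video-range YUV -> RGB: 1.164 rescales luma
                        # from [16, 235]; 1.596 / 0.812 / 0.392 / 2.017 are
                        # the standard chroma coefficients. Channels stay in
                        # BGR order for cv2.imwrite below.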
                        y_hr[:, :, 2] = 1.164 * (y - 16) + 1.596 * (
                            true_input_yuv[:, :, 2] - 128)
                        y_hr[:, :, 1] = 1.164 * (y - 16) - 0.812 * (
                            true_input_yuv[:, :, 2] -
                            128) - 0.392 * (true_input_yuv[:, :, 1] - 128)
                        y_hr[:, :, 0] = 1.164 * (y - 16) + 2.017 * (
                            true_input_yuv[:, :, 1] - 128)

                        out_img_name = os.path.split(img_name[0])[1].split(
                            '.jpg')[0] + '_epoch_' + str(epoch) + '.png'
                        if not os.path.exists(opt.valid_folder):
                            os.mkdir(opt.valid_folder)

                        print(out_img_name)
                        output_img = os.path.join(opt.valid_folder,
                                                  out_img_name)
                        cv2.imwrite(output_img, y_hr)

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.SRDataset(opt)
    testset = dataset.SRValidDataset(opt)
    print('The overall number of training images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    testloader = DataLoader(testset, batch_size=1, pin_memory=True)
    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):

        valid_model(testloader, generator, (epoch + 1), opt)

        avg_l1_loss = 0
        avg_ssim_loss = 0
        avg_cs_ColorLoss = 0
        avg_grad_loss = 0
        avg_ssim_loss_lf = 0
        avg_vgg_loss = 0
        avg_tv_loss = 0

        generator.train()
        for i, (true_input, true_target) in enumerate(dataloader):

            if opt.num_channels == 1:
                true_input = true_input[:, 0:1, :, :]
                true_target = true_target[:, 0:1, :, :]

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)

            # overall loss
            loss = 0

            disp_Pixellevel_L1_Loss = 0
            disp_tot_v_loss = 0
            disp_cs_ColorLoss = 0
            disp_grad_loss = 0
            disp_vgg_loss = 0
            disp_ssim_loss = 0

            # L1 Loss
            if 'l1' in opt.loss:
                Pixellevel_L1_Loss = opt.lambda_l1 * criterion_L1(
                    fake_target, true_target)
                loss += Pixellevel_L1_Loss
                avg_l1_loss += Pixellevel_L1_Loss.item()
                disp_Pixellevel_L1_Loss = Pixellevel_L1_Loss.item()
            # tv Loss
            if 'tv' in opt.loss:
                tot_v_loss = opt.lambda_tv * totalvar_loss(fake_target)
                loss += tot_v_loss
                avg_tv_loss += tot_v_loss.item()
                disp_tot_v_loss = tot_v_loss.item()

            # color loss
            if 'color' in opt.loss:
                cs_ColorLoss = opt.lambda_color * csColorLoss(
                    fake_target, true_target)
                loss += cs_ColorLoss
                avg_cs_ColorLoss += cs_ColorLoss.item()
                disp_cs_ColorLoss = cs_ColorLoss.item()

            # gradient loss
            if 'grad' in opt.loss:
                grad_loss = opt.lambda_grad * gradLoss(fake_target,
                                                       true_target)
                loss += grad_loss
                avg_grad_loss += grad_loss.item()
                disp_grad_loss = grad_loss.item()

            # vgg loss
            if 'vgg' in opt.loss:

                if opt.num_channels == 1:
                    vgg_loss = vggLoss(
                        fake_target.clamp(0, 1).expand(-1, 3, -1, -1),
                        true_target.clamp(0, 1).expand(-1, 3, -1, -1))
                else:
                    vgg_loss = vggLoss(fake_target.clamp(0, 1),
                                       true_target.clamp(0, 1))
                loss += vgg_loss
                avg_vgg_loss += vgg_loss.item()
                disp_vgg_loss = vgg_loss.item()

            # ssim loss
            if 'ssim' in opt.loss:
                ssim_loss = 1 - ssimLoss(fake_target, true_target)
                loss += ssim_loss
                avg_ssim_loss += ssim_loss.item()
                disp_ssim_loss = ssim_loss.item()

            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [color loss: %.4f] [ssim Loss: %.4f] [grad Loss: %.4f] [VGG Loss: %.4f] [TV Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   disp_Pixellevel_L1_Loss, disp_cs_ColorLoss, disp_ssim_loss,
                   disp_grad_loss, disp_vgg_loss, disp_tot_v_loss, time_left))

        # Save model at certain epochs or iterations
        save_model(generator, (epoch + 1), opt)

        # Learning rate decrease at certain epochs
        adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)

        avg_l1_loss = avg_l1_loss / (i + 1)
        avg_ssim_loss = avg_ssim_loss / (i + 1)
        avg_cs_ColorLoss = avg_cs_ColorLoss / (i + 1)
        avg_grad_loss = avg_grad_loss / (i + 1)
        avg_vgg_loss = avg_vgg_loss / (i + 1)
        avg_tv_loss = avg_tv_loss / (i + 1)

        f = open("log.txt", "a")
        f.write('epoch: ' + str(epoch) + ' avg l1 =' + str(avg_l1_loss) +
                ' avg color loss =' + str(avg_cs_ColorLoss) + ' avg ssim = ' +
                str(avg_ssim_loss) + ' avg grad loss = ' + str(avg_grad_loss) +
                ' avg vgg loss = ' + str(avg_vgg_loss) + 'avg tv loss = ' +
                str(avg_tv_loss) + '\n')
        f.close()
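
tv_loss, color_loss, GradientLoss, SSIM and VGGLoss above come from this repository's own loss modules. For reference, a minimal total-variation loss sketch (an assumption of what tv_loss may look like; the real module can differ):

import torch
import torch.nn as nn

class TVLoss(nn.Module):
    """Mean absolute difference between neighboring pixels (NCHW input)."""

    def forward(self, x):
        # Vertical and horizontal finite differences
        dh = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).mean()
        dw = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]).mean()
        return dh + dw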
Example #16
0
def Trainer(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    # criterion_L1 = torch.nn.L1Loss().cuda()
    # criterion = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3).cuda()
    criterion = nn.MSELoss().cuda()

    # Initialize model
    if opt.model == 'SGN':
        model = utils.create_generator(opt)
    elif opt.model == 'UNet':
        if opt.load:
            model = torch.load(opt.load).module
            print(f'Model loaded from {opt.load}')
        else:
            model = unet.UNet(opt.in_channels, opt.out_channels,
                              opt.start_channels)
    else:
        raise NotImplementedError(opt.model + ' is not implemented')

    dir_checkpoint = 'checkpoints/'
    try:
        os.mkdir(dir_checkpoint)
        print('Created checkpoint directory')
    except OSError:
        pass

    writer = SummaryWriter(
        comment=f'_{opt.model}_LR_{opt.lr}_BS_{opt.batch_size}')

    # To device
    if opt.multi_gpu:
        model = nn.DataParallel(model)
        model = model.cuda()
    else:
        model = model.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(model.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, iteration, optimizer):
        # Set the learning rate to the specific value
        if iteration >= opt.iter_decreased:
            for param_group in optimizer.param_groups:
                param_group['lr'] = opt.lr_decreased

    # Save the model and run validation at certain epochs or iterations
    def save_model(opt, epoch, iteration, len_dataset, network):
        """Save the model at "checkpoint_interval" and its multiple"""
        if (epoch % opt.save_interval == 0) and (iteration % len_dataset == 0):
            torch.save(
                network, dir_checkpoint + '%s_epoch%d_bs%d_mu%d_sigma%d.pth' %
                (opt.model, epoch, opt.batch_size, opt.mu, opt.sigma))
            print('The trained model is successfully saved at epoch %d' %
                  (epoch))

        if (epoch % opt.validate_interval == 0) and (iteration % len_dataset
                                                     == 0):
            psnr = validation.validate(network, opt)
            print('validate PSNR:', psnr)
            writer.add_scalar('PSNR/validate', psnr, iteration)

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    opt.dataroot = opt.baseroot + 'DIV2K_train_HR'
    trainset = dataset.DenoisingDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (noisy_img, img) in enumerate(dataloader):
            # To device
            noisy_img = noisy_img.cuda()
            img = img.cuda()

            # Train model
            optimizer_G.zero_grad()

            # Forward propagation
            recon_img = model(noisy_img)
            loss = criterion(recon_img, img)

            # Overall Loss and optimize
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss.item(),
                   time_left))

            writer.add_scalar('Loss/train', loss.item(), iters_done)

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       model)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (iters_done + 1), optimizer_G)
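
validation.validate above returns a PSNR value. A minimal sketch of PSNR on [0, 1] tensors (assumed formula; the repository's validation module may differ):

import math

import torch

def psnr(pred, target, peak=1.0):
    # Peak signal-to-noise ratio: 10 * log10(peak^2 / MSE)
    mse = torch.mean((pred - target) ** 2).item()
    if mse == 0:
        return float('inf')
    return 10.0 * math.log10(peak ** 2 / mse)
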
def Trainer(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.train_batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Loss functions
    criterion_L2 = torch.nn.MSELoss().cuda()

    # Initialize SGN
    generator = utils.create_generator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, optimizer):
        # Set the learning rate to the specific value
        if epoch >= opt.epoch_decreased:
            for param_group in optimizer.param_groups:
                param_group['lr'] = opt.lr_decreased

    # Save the model at certain epochs or iterations
    def save_model(opt, epoch, iteration, len_dataset, network):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Judge name
        if not os.path.exists(opt.save_root):
            os.makedirs(opt.save_root)
        # Save model dict
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                modelname = 'DnCNN_epoch%d_bs%d_mu%d_sigma%d.pth' % (
                    epoch, opt.train_batch_size, opt.mu, opt.sigma)
                modelpath = os.path.join(opt.save_root, modelname)
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(network.module.state_dict(), modelpath)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                modelname = 'DnCNN_iter%d_bs%d_mu%d_sigma%d.pth' % (
                    iteration, opt.train_batch_size, opt.mu, opt.sigma)
                modelpath = os.path.join(opt.save_root, modelname)
                if iteration % opt.save_by_iter == 0:
                    torch.save(network.module.state_dict(), modelpath)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                modelname = 'DnCNN_epoch%d_bs%d_mu%d_sigma%d.pth' % (
                    epoch, opt.train_batch_size, opt.mu, opt.sigma)
                modelpath = os.path.join(opt.save_root, modelname)
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(network.state_dict(), modelpath)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                modelname = 'DnCNN_iter%d_bs%d_mu%d_sigma%d.pth' % (
                    iteration, opt.train_batch_size, opt.mu, opt.sigma)
                modelpath = os.path.join(opt.save_root, modelname)
                if iteration % opt.save_by_iter == 0:
                    torch.save(network.state_dict(), modelpath)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.DenoisingDataset(opt, opt.train_root)
    valset = dataset.DenoisingDataset(opt, opt.val_root)
    print('The overall number of training images:', len(trainset))
    print('The overall number of validation images:', len(valset))

    # Define the dataloader
    train_loader = DataLoader(trainset,
                              batch_size=opt.train_batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers,
                              pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=opt.val_batch_size,
                            shuffle=False,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Tensorboard
    writer = SummaryWriter()

    # Training loop
    for epoch in range(opt.epochs):

        # Record learning rate
        for param_group in optimizer_G.param_groups:
            writer.add_scalar('data/lr', param_group['lr'], epoch)
            print('learning rate = ', param_group['lr'])

        if epoch == 0:
            iters_done = 0

        ### training
        for i, (noisy_img, img) in enumerate(train_loader):

            # To device
            noisy_img = noisy_img.cuda()
            img = img.cuda()

            # Train Generator
            optimizer_G.zero_grad()

            # Forward propagation
            recon_img = generator(noisy_img)
            loss = criterion_L2(recon_img, img)

            # Record losses
            writer.add_scalar('data/L2Loss', loss.item(), iters_done)

            # Overall Loss and optimize
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(train_loader) + i
            iters_left = opt.epochs * len(train_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(train_loader), loss.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(train_loader),
                       generator)

        # Learning rate decrease at certain epochs
        adjust_learning_rate(opt, (epoch + 1), optimizer_G)

        ### sampling
        utils.save_sample_png(opt,
                              epoch,
                              noisy_img,
                              recon_img,
                              img,
                              addition_str='training')

        ### Validation
        val_PSNR = 0
        num_of_val_image = 0

        for j, (val_noisy_img, val_img) in enumerate(val_loader):

            # To device
            # noisy image is the input, clean image is the target
            val_noisy_img = val_noisy_img.cuda()
            val_img = val_img.cuda()

            # Forward propagation
            val_recon_img = generator(val_noisy_img)

            # Accumulate the image count and the PSNR sum
            num_of_val_image += val_noisy_img.shape[0]
            val_PSNR += utils.psnr(val_recon_img, val_img,
                                   1) * val_noisy_img.shape[0]

        val_PSNR = val_PSNR / num_of_val_image

        # Record average PSNR
        writer.add_scalar('data/val_PSNR', val_PSNR, epoch)
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))

        ### sampling
        utils.save_sample_png(opt,
                              epoch,
                              val_noisy_img,
                              val_recon_img,
                              val_img,
                              addition_str='validation')

    writer.close()
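
# A minimal standalone sketch of a PSNR helper matching the
# utils.psnr(pred, target, max_val) call above; illustrative only, not the
# repository's implementation.
import torch

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB, averaged over the batch."""
    mse = torch.mean((pred - target) ** 2)
    if mse == 0:
        return float('inf')
    return (10.0 * torch.log10(max_val ** 2 / mse)).item()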
Example #18
def Trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    print("Batch size is changed to %d" % opt.batch_size)
    print("Number of workers is changed to %d" % opt.num_workers)

    # Build path folder
    utils.check_path(opt.save_path)
    utils.check_path(opt.sample_path)

    # Build networks
    generator = utils.create_generator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(optimizer, epoch, opt, init_lr):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = init_lr * (opt.lr_decrease_factor
                        **(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model at certain epochs
    def save_model(net, epoch, opt):
        """Save the model every "checkpoint_interval" epochs"""
        model_name = 'GrayInpainting_epoch%d_batchsize%d.pth' % (
            epoch + 20, opt.batch_size)
        model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (grayscale, mask) in enumerate(dataloader):

            # Load and put to cuda
            grayscale = grayscale.cuda()  # out: [B, 1, 256, 256]
            mask = mask.cuda()  # out: [B, 1, 256, 256]

            # forward propagation
            optimizer_g.zero_grad()
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 -
                                        mask) + out * mask  # in range [0, 1]

            # Mask L1 Loss
            MaskL1Loss = L1Loss(out_wholeimg, grayscale)

            # Compute losses
            loss = MaskL1Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(optimizer_g, (epoch + 1), opt, opt.lr_g)

        # Save the model
        save_model(generator, (epoch + 1), opt)
        utils.sample(grayscale, mask, out_wholeimg, opt.sample_path,
                     (epoch + 1))
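
# Standalone illustration of the masked compositing used above: pixels where
# mask == 1 come from the generator output, everything else is kept from the
# input (all values here are made up).
import torch

grayscale = torch.rand(1, 1, 4, 4)   # stand-in input image in [0, 1]
out = torch.rand(1, 1, 4, 4)         # stand-in generator output
mask = torch.zeros(1, 1, 4, 4)
mask[..., 1:3, 1:3] = 1.0            # a 2x2 hole
out_wholeimg = grayscale * (1 - mask) + out * mask
assert torch.equal(out_wholeimg[mask == 0], grayscale[mask == 0])
assert torch.equal(out_wholeimg[mask == 1], out[mask == 1])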
Example #19
def Continue_train_LSGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model at certain epochs or iterations
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at multiples of "save_by_epoch" epochs or "save_by_iter" iterations"""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'LSGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'LSGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'LSGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'LSGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.NormalRGBDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Adversarial ground truth
            valid = Tensor(np.ones((true_input.shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((true_input.shape[0], 1, 30, 30)))

            # Train Discriminator
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()

                # Generator output
                fake_target = generator(true_input)

                # Fake samples
                fake_scalar_d = discriminator(true_input, fake_target.detach())
                loss_fake = criterion_MSE(fake_scalar_d, fake)
                # True samples
                true_scalar_d = discriminator(true_input, true_target)
                loss_true = criterion_MSE(true_scalar_d, valid)

                # Overall Loss and optimize
                loss_D = 0.5 * (loss_fake + loss_true)
                loss_D.backward()
                optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # GAN Loss
            fake_scalar = discriminator(true_input, fake_target)
            GAN_Loss = criterion_MSE(fake_scalar, valid)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
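
# The discriminator above is trained LSGAN-style: an MSE criterion pushes the
# 30x30 patch scores of real pairs toward 1 and of fake pairs toward 0. A
# standalone illustration with made-up score maps:
import torch

criterion_MSE = torch.nn.MSELoss()
valid = torch.ones(2, 1, 30, 30)
fake = torch.zeros(2, 1, 30, 30)
fake_scalar_d = torch.rand(2, 1, 30, 30)   # stand-in for D(input, fake_target)
true_scalar_d = torch.rand(2, 1, 30, 30)   # stand-in for D(input, true_target)
loss_D = 0.5 * (criterion_MSE(fake_scalar_d, fake) +
                criterion_MSE(true_scalar_d, valid))
print(loss_D.item())
Example #20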
def Trainer(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs:" % (gpu_num))
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Create folders
    save_model_folder = opt.save_path
    utils.check_path(save_model_folder)

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize SGN
    generator = utils.create_generator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr * (opt.lr_decrease_factor
                           **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr * (opt.lr_decrease_factor
                           **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        # Define the name of the trained model
        if opt.save_mode == 'epoch':
            model_name = 'G_epoch%d_bs%d.pth' % (epoch, opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'G_iter%d_bs%d.pth' % (iteration, opt.batch_size)
        save_model_path = os.path.join(opt.save_path, model_name)
        # Save model
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.module.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(generator.state_dict(), save_model_path)
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.HS_multiscale_DSet(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for i, (img_A, img_B) in enumerate(dataloader):

            # To device
            img_A = img_A.cuda()
            img_B = img_B.cuda()

            # Train Generator
            optimizer_G.zero_grad()

            # Forward propagation
            recon_B = generator(img_A)

            # Losses
            loss = criterion_L1(recon_B, img_B)

            # Overall Loss and optimize
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
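
# Trace of the step decay implemented by adjust_learning_rate above:
# lr = lr0 * lr_decrease_factor ** (epoch // lr_decrease_epoch). The numbers
# below are illustrative assumptions, not repository defaults.
lr0, factor, step = 1e-4, 0.5, 10
for epoch in (0, 9, 10, 19, 20, 35):
    print(epoch, lr0 * factor ** (epoch // step))
# epochs 0-9 keep 1e-4, 10-19 use 5e-5, 20-29 use 2.5e-5, 30-39 use 1.25e-5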
Example #21
def CycleGAN_LSGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_BCE = torch.nn.BCEWithLogitsLoss().cuda()

    # Initialize networks
    G = utils.create_generator(opt)
    D = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        G = nn.DataParallel(G)
        G = G.cuda()
        D = nn.DataParallel(D)
        D = D.cuda()
    else:
        G = G.cuda()
        D = D.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(G.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(D.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the models at certain epochs or iterations
    def save_model(opt, epoch, iteration, len_dataset, G, D):
        """Save the models at multiples of "save_by_epoch" epochs or "save_by_iter" iterations"""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(
                        G.module, 'AttnGAN_parent_G_epoch%d_bs%d.pth' %
                        (epoch, opt.batch_size))
                    torch.save(
                        D.module, 'AttnGAN_parent_D_epoch%d_bs%d.pth' %
                        (epoch, opt.batch_size))
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(
                        G.module, 'AttnGAN_parent_G_iter%d_bs%d.pth' %
                        (iteration, opt.batch_size))
                    torch.save(
                        D.module, 'AttnGAN_parent_D_iter%d_bs%d.pth' %
                        (iteration, opt.batch_size))
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    torch.save(
                        G, 'AttnGAN_parent_G_epoch%d_bs%d.pth' %
                        (epoch, opt.batch_size))
                    torch.save(
                        D, 'AttnGAN_parent_D_epoch%d_bs%d.pth' %
                        (epoch, opt.batch_size))
                    print(
                        'The trained model is successfully saved at epoch %d' %
                        (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    torch.save(
                        G, 'AttnGAN_parent_G_iter%d_bs%d.pth' %
                        (iteration, opt.batch_size))
                    torch.save(
                        D, 'AttnGAN_parent_D_iter%d_bs%d.pth' %
                        (iteration, opt.batch_size))
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.CFP_dataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for i, (img, imglabel) in enumerate(dataloader):

            # To device
            img = img.cuda()
            idx = torch.randperm(len(imglabel))
            imglabel_fake = imglabel[idx].contiguous()
            imglabel = imglabel.cuda()
            imglabel_fake = imglabel_fake.cuda()

            # ------------------------------- Train Generator -------------------------------
            optimizer_G.zero_grad()

            # Forward
            img_recon, img_fake = G(img, imglabel, imglabel_fake)
            out_adv, out_class = D(img_fake)

            # Recon Loss
            loss_recon = criterion_L1(img_recon, img)

            # WGAN loss
            loss_gan = -torch.mean(out_adv)

            # Classification Loss
            loss_class = criterion_BCE(out_class, imglabel_fake)

            # Overall Loss and optimize
            loss = opt.lambda_recon * loss_recon + opt.lambda_gan * loss_gan + opt.lambda_class * loss_class
            loss.backward()
            optimizer_G.step()

            # ------------------------------- Train Discriminator -------------------------------
            optimizer_D.zero_grad()

            # Forward
            img_recon, img_fake = G(img, imglabel, imglabel_fake)
            out_adv_fake, out_class_fake = D(img_fake.detach())
            out_adv_true, out_class_true = D(img.detach())

            # WGAN loss
            loss_gan = torch.mean(out_adv_fake) - torch.mean(out_adv_true)

            # Classification Loss
            loss_class = criterion_BCE(out_class_true, imglabel)

            # Overall Loss and optimize
            loss = loss_gan + loss_class
            loss.backward()
            optimizer_D.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] [GAN Loss: %.4f] [Class Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   loss_recon.item(), loss_gan.item(), loss_class.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), G,
                       D)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)
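
# Standalone illustration of the label shuffling above: torch.randperm pairs
# each image with the label of another random sample, which serves as the
# "fake" target label for the classification branch (labels are made up).
import torch

imglabel = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
idx = torch.randperm(len(imglabel))
imglabel_fake = imglabel[idx].contiguous()
print(idx.tolist(), imglabel_fake.tolist())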
Example #22
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    cv2.setNumThreads(0)
    cv2.ocl.setUseOpenCL(False)

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    RELU = nn.ReLU()

    # Optimizers
    optimizer_g1 = torch.optim.Adam(generator.coarse.parameters(), lr=opt.lr_g)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr = opt.lr_d)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # Save the model at certain epochs
    def save_model(net, epoch, opt, batch=0, is_D=False):
        """Save the model every "checkpoint_interval" epochs"""
        if is_D:
            model_name = 'discriminator_WGAN_epoch%d_batch%d.pth' % (epoch + 1, batch)
        else:
            model_name = 'deepfillv2_WGAN_epoch%d_batch%d.pth' % (epoch+1, batch)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d batch %d' % (epoch, batch))
    
    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size = opt.batch_size, shuffle = True, num_workers = opt.num_workers, pin_memory = True)
    
    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        print("Start epoch ", epoch+1, "!")
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]) and move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (1 - mask) + first_out * mask        # in range [0, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask      # in range [0, 1]

            # One discriminator update per generator step
            for wk in range(1):
                optimizer_d.zero_grad()
                fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
                true_scalar = discriminator(img, mask)
                #W_Loss = -torch.mean(true_scalar) + torch.mean(fake_scalar)#+ gradient_penalty(discriminator, img, second_out_wholeimg, mask)
                hinge_loss = torch.mean(RELU(1-true_scalar)) + torch.mean(RELU(fake_scalar+1))
                loss_D = hinge_loss
                loss_D.backward(retain_graph=True)
                optimizer_d.step()

            ### Train Generator
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = - torch.mean(fake_scalar)

            optimizer_g1.zero_grad()
            first_MaskL1Loss.backward(retain_graph=True)
            optimizer_g1.step()

            optimizer_g.zero_grad()

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)

            loss = 0.5*opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + GAN_Loss + second_PerceptualLoss * opt.lambda_perceptual
            loss.backward()

            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]" %
                  ((epoch + 1), opt.epochs, (batch_idx+1), len(dataloader), first_MaskL1Loss.item(),
                   second_MaskL1Loss.item()))
            print("\r[D Loss: %.5f] [Perceptual Loss: %.5f] [G Loss: %.5f] time_left: %s" %
                  (loss_D.item(), second_PerceptualLoss.item(), GAN_Loss.item(), time_left))



            if (batch_idx + 1) % 100 == 0:

                # Generate Visualization image
                masked_img = img * (1 - mask) + mask
                img_save = torch.cat((img, masked_img, first_out, second_out, first_out_wholeimg, second_out_wholeimg),3)
                # Recover normalization: * 255 because last layer is sigmoid activated
                img_save = F.interpolate(img_save, scale_factor=0.5)
                img_save = img_save * 255
                # Process img_copy and do not destroy the data of img
                img_copy = img_save.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
                #img_copy = np.clip(img_copy, 0, 255)
                img_copy = img_copy.astype(np.uint8)
                save_img_name = 'sample_batch' + str(batch_idx+1) + '.png'
                save_img_path = os.path.join(sample_folder, save_img_name)
                img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
                cv2.imwrite(save_img_path, img_copy)
            if (batch_idx + 1) % 5000 == 0:
                save_model(generator, epoch, opt, batch_idx+1)
                save_model(discriminator, epoch, opt, batch_idx+1, is_D=True)


        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, epoch, opt)
        save_model(discriminator, epoch, opt, is_D=True)
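
# The discriminator objective above is the hinge loss
#   L_D = E[relu(1 - D(real))] + E[relu(1 + D(fake))],
# so only scores inside the unit margin contribute. Standalone numeric
# illustration with made-up critic scores:
import torch

relu = torch.nn.ReLU()
true_scalar = torch.tensor([0.8, 1.5, -0.2])    # critic scores on real images
fake_scalar = torch.tensor([-1.2, 0.3, -0.9])   # critic scores on fakes
loss_D = torch.mean(relu(1 - true_scalar)) + torch.mean(relu(fake_scalar + 1))
print(loss_D.item())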
Example #23
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model at certain epochs
    def save_model(net, epoch, opt):
        """Save the model every "checkpoint_interval" epochs"""
        model_name = 'deepfillv2_LSGAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]) and move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [0, 1]

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(
                second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps,
                                           img_featuremaps)

            # Compute losses
            loss = opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + \
                opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)

        ### Sample data every epoch
        masked_img = img * (1 - mask) + mask
        mask = torch.cat((mask, mask, mask), 1)
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
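
# A minimal sketch of a VGG-based feature extractor in the spirit of the
# `perceptualnet` used above. The layer cut (relu4_3) and the pretrained
# weights are assumptions; utils.create_perceptualnet may differ.
import torch
import torch.nn as nn
import torchvision.models as models

class PerceptualNet(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.features = nn.Sequential(*list(vgg.features.children())[:23])
        for p in self.features.parameters():
            p.requires_grad = False   # frozen: only used to compare features

    def forward(self, x):
        return self.features(x)

feats = PerceptualNet()(torch.rand(1, 3, 256, 256))   # -> [1, 512, 32, 32]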
Example #24
def Continue_train_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model at certain epochs or iterations
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at multiples of "save_by_epoch" epochs or "save_by_iter" iterations"""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'WGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'WGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'WGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'WGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.NormalRGBDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Sample noise and get data
            noise1 = utils.get_noise(true_input.shape[0], opt.z_dim,
                                     opt.random_type)
            noise1 = noise1.cuda()  # out: batch * z_dim
            noise2 = utils.get_noise(true_input.shape[0], opt.z_dim,
                                     opt.random_type)
            noise2 = noise2.cuda()  # out: batch * z_dim
            concat_noise = torch.cat((noise1, noise2),
                                     0)  # out: 2batch * z_dim
            concat_input = torch.cat((true_input, true_input),
                                     0)  # out: 2batch * 1 * 256 * 256
            concat_target = torch.cat((true_target, true_target),
                                      0)  # out: 2batch * 3 * 256 * 256

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(
                concat_input, concat_noise)  # out: 2batch * 3 * 256 * 256

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, concat_target)

            # MSGAN Loss
            fake_target1, fake_target2 = fake_target.split(
                true_input.shape[0], 0)
            ms_value = torch.mean(
                torch.abs(fake_target2 - fake_target1)) / torch.mean(
                    torch.abs(noise2 - noise1))
            eps = 1e-5
            ModeSeeking_Loss = 1 / (ms_value + eps)

            # GAN Loss
            fake_scalar = discriminator(concat_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)

            # Overall Loss and optimize
            loss = opt.lambda_l1 * Pixellevel_L1_Loss + opt.lambda_gan * GAN_Loss + opt.lambda_ms * ModeSeeking_Loss
            loss.backward()
            optimizer_G.step()

            # Train Discriminator
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()

                # Generator output
                fake_target = generator(concat_input, concat_noise)
                fake_target1, fake_target2 = fake_target.split(
                    true_input.shape[0], 0)

                # Fake samples
                fake_scalar_d1 = discriminator(true_input,
                                               fake_target1.detach())
                fake_scalar_d2 = discriminator(true_input,
                                               fake_target2.detach())

                # True samples
                true_scalar_d = discriminator(true_input, true_target)

                # Overall Loss and optimize
                loss_D1 = -torch.mean(true_scalar_d) + torch.mean(
                    fake_scalar_d1)
                loss_D2 = -torch.mean(true_scalar_d) + torch.mean(
                    fake_scalar_d2)
                loss_D = loss_D1 + loss_D2
                loss_D.backward()
                optimizer_D.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
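
# The mode-seeking term above follows the MSGAN regularizer: two noise
# vectors are pushed to produce visibly different outputs by penalizing
#   L_ms = 1 / (mean|G(c, z2) - G(c, z1)| / mean|z2 - z1| + eps).
# Standalone numeric illustration with made-up tensors:
import torch

fake_target1 = torch.rand(4, 3, 8, 8)
fake_target2 = torch.rand(4, 3, 8, 8)
noise1, noise2 = torch.rand(4, 16), torch.rand(4, 16)
ms_value = torch.mean(torch.abs(fake_target2 - fake_target1)) / \
    torch.mean(torch.abs(noise2 - noise1))
ModeSeeking_Loss = 1 / (ms_value + 1e-5)
print(ModeSeeking_Loss.item())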
Example #25
    parser.add_argument('--mu',
                        type=int,
                        default=0,
                        help='Gaussian noise mean')
    parser.add_argument('--sigma',
                        type=int,
                        default=30,
                        help='Gaussian noise standard deviation: 30 | 50 | 70')
    opt = parser.parse_args()
    print(opt)

    # ----------------------------------------
    #                   Test
    # ----------------------------------------
    # Initialize
    generator = utils.create_generator(opt).cuda()
    test_dataset = dataset.DenoisingValDataset(opt)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=opt.test_batch_size,
                                              shuffle=False,
                                              num_workers=opt.num_workers,
                                              pin_memory=True)
    sample_folder = opt.save_name
    utils.check_path(sample_folder)

    # forward
    for i, (true_input, true_target) in enumerate(test_loader):

        # To device
        true_input = true_input.cuda()
        true_target = true_target.cuda()
Example #26
    # ----------------------------------------
    #       Initialize testing dataset
    # ----------------------------------------

    # Define the dataset
    testset = dataset.FullResDenoisingDataset(opt)
    print('The overall number of images equals %d' % len(testset))

    # Define the dataloader
    dataloader = DataLoader(testset, batch_size = opt.batch_size, pin_memory = True)

    # ----------------------------------------
    #                 Testing
    # ----------------------------------------

    model = utils.create_generator(opt)

    for batch_idx, (noisy_img, img) in enumerate(dataloader):

        # To Tensor
        noisy_img = noisy_img.cuda()
        img = img.cuda()

        # Generator output
        recon_img = model(noisy_img)

        # convert to visible image format
        h = img.shape[2]
        w = img.shape[3]
        img = img.cpu().numpy().reshape(3, h, w).transpose(1, 2, 0)
        img = (img + 1) * 128
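
# Illustrative continuation of the conversion above: clip to [0, 255] before
# casting, since (img + 1) * 128 can reach 256.0 and wrap around as uint8.
# The array below is a stand-in for the recovered image.
import numpy as np

chw = np.random.uniform(-1, 1, size=(3, 64, 64)).astype(np.float32)
hwc = chw.transpose(1, 2, 0)                        # CHW -> HWC
img_u8 = np.clip((hwc + 1) * 128, 0, 255).astype(np.uint8)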
Example #27
def Continue_train_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
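        # Note: both optimizer_G and optimizer_D are decayed from the generator
        # base rate opt.lr_g; opt.lr_d only sets the discriminator's initial rate.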
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Model checkpointing
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model every "save_by_epoch" epochs or every "save_by_iter" iterations"""
        # Unwrap nn.DataParallel so the saved object is the bare network
        net = generator.module if opt.multi_gpu else generator
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                if opt.save_name_mode:
                    torch.save(
                        net, 'WGAN_%s_epoch%d_bs%d.pth' %
                        (opt.task, epoch, opt.batch_size))
                    print(
                        'The trained model is successfully saved at epoch %d'
                        % (epoch))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                if opt.save_name_mode:
                    torch.save(
                        net, 'WGAN_%s_iter%d_bs%d.pth' %
                        (opt.task, iteration, opt.batch_size))
                    print(
                        'The trained model is successfully saved at iteration %d'
                        % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.RAW2RGBDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target, true_sal) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            true_sal = true_sal.cuda()
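            # Replicate the single-channel saliency map across the three RGB channels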
            true_sal3 = torch.cat((true_sal, true_sal, true_sal), 1).cuda()

            # Train Discriminator
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()

                # Generator output
                fake_target, fake_sal = generator(true_input)

                # Fake samples
                fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                    true_input, fake_target.detach())
                # True samples
                true_block1, true_block2, true_block3, true_scalar = discriminator(
                    true_input, true_target)
                '''
                # Feature Matching Loss
                FM_Loss = criterion_L1(fake_block1, true_block1) + criterion_L1(fake_block2, true_block2) + criterion_L1(fake_block3, true_block3)
                '''
                # Overall Loss and optimize
                loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
                loss_D.backward()
                optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target, fake_sal = generator(true_input)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # Attention Loss
            true_Attn_target = true_target.mul(true_sal3)
            fake_sal3 = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_Attn_target = fake_target.mul(fake_sal3)
            Attention_Loss = criterion_L1(fake_Attn_target, true_Attn_target)

            # GAN Loss
            fake_block1, fake_block2, fake_block3, fake_scalar = discriminator(
                true_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)

            # Perceptual Loss
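            # Rescale outputs from [-1, 1] to [0, 1] before ImageNet normalization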
            fake_target = fake_target * 0.5 + 0.5
            true_target = true_target * 0.5 + 0.5
            fake_target = utils.normalize_ImageNet_stats(fake_target)
            true_target = utils.normalize_ImageNet_stats(true_target)
            fake_percep_feature = perceptualnet(fake_target)
            true_percep_feature = perceptualnet(true_target)
            Perceptual_Loss = criterion_L1(fake_percep_feature,
                                           true_percep_feature)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + opt.lambda_attn * Attention_Loss + opt.lambda_gan * GAN_Loss + opt.lambda_percep * Perceptual_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f]"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item()))
            print(
                "\r[Attention Loss: %.4f] [Perceptual Loss: %.4f] Time_left: %s"
                % (Attention_Loss.item(), Perceptual_Loss.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)
Example #28
0
def trainer_LSGAN(opt):

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model every "save_by_epoch" epochs or every "save_by_iter" iterations"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch,
                                                        opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration,
                                                       opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        # Unwrap nn.DataParallel so only the bare state_dict is serialized
        net = generator.module if opt.multi_gpu else generator
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                torch.save(net.state_dict(), save_name)
                print('The trained model is saved as %s' % (model_name))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                torch.save(net.state_dict(), save_name)
                print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):

            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)

            # Adversarial ground truth
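            # (1, 30, 30) is the PatchGAN discriminator's patch-level output size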
            valid = Tensor(np.ones((true_L.shape[0], 1, 30, 30)))
            fake = Tensor(np.zeros((true_L.shape[0], 1, 30, 30)))

            ### Train Discriminator
            optimizer_D.zero_grad()

            # Generator output
            fake_RGB, fake_sal = generator(true_L)

            # Fake colorizations
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())
            loss_fake = criterion_MSE(fake_scalar_d, fake)

            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)
            loss_true = criterion_MSE(true_scalar_d, valid)

            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()

            fake_RGB, fake_sal = generator(true_L)

            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)

            # Attention Loss
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)

            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)

            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = criterion_MSE(fake_scalar, valid)

            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss_L1.item(),
                   loss_attn.item(), loss_percep.item(), loss_D.item(),
                   loss_GAN.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)

        ### Sample data every epoch
        img_list = [fake_RGB, true_RGB]
        name_list = ['pred', 'gt']
        utils.save_sample_png(sample_folder=opt.sample_path,
                              sample_name='epoch%d' % (epoch + 1),
                              img_list=img_list,
                              name_list=name_list)
Example #29
0
def LSGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    #FeatureMatchingLoss = FML1Loss(opt.fm_param)

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Model checkpointing
    def save_model(net, epoch, opt):
        """Save the model at every multiple of "checkpoint_interval" epochs"""
        if epoch % opt.checkpoint_interval == 0:
            # Unwrap nn.DataParallel so the saved object is the bare network
            to_save = net.module if opt.multi_gpu else net
            torch.save(
                to_save, 'deepfillNet_epoch%d_batchsize%d.pth' %
                (epoch, opt.batch_size))
            print('The trained model is successfully saved at epoch %d' %
                  (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load mask (shape: [B, 1, H, W]), masked_img (shape: [B, 3, H, W]), img (shape: [B, 3, H, W]) and put it to cuda
            img = img.cuda()
            mask = mask.cuda()

            # LSGAN vectors
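            # (1, 8, 8) matches the discriminator's patch-level output size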
            valid = Tensor(np.ones((img.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((img.shape[0], 1, 8, 8)))

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
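            # Two-stage inpainting: first_out is the coarse prediction, second_out the refinement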
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (
                1 - mask) + first_out * mask  # in range [-1, 1]
            second_out_wholeimg = img * (
                1 - mask) + second_out * mask  # in range [-1, 1]

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # LSGAN losses for fake and true samples
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            # Overall Loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = MSELoss(fake_scalar, valid)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img = (img + 1) / 2  # in range [0, 1]
            img = utils.normalize_ImageNet_stats(img)  # in range of ImageNet
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg = (second_out_wholeimg +
                                   1) / 2  # in range [0, 1]
            second_out_wholeimg = utils.normalize_ImageNet_stats(
                second_out_wholeimg)
            second_out_wholeimg_featuremaps = perceptualnet(
                second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps,
                                           img_featuremaps)

            # Compute losses
            loss = first_MaskL1Loss + second_MaskL1Loss + opt.perceptual_param * second_PerceptualLoss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)
Example #30
0
def trainer_WGANGP(opt):

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model every "save_by_epoch" epochs or every "save_by_iter" iterations"""
        if opt.save_mode == 'epoch':
            model_name = 'SCGAN_%s_epoch%d_bs%d.pth' % (opt.gan_mode, epoch,
                                                        opt.batch_size)
        if opt.save_mode == 'iter':
            model_name = 'SCGAN_%s_iter%d_bs%d.pth' % (opt.gan_mode, iteration,
                                                       opt.batch_size)
        save_name = os.path.join(opt.save_path, model_name)
        # Unwrap nn.DataParallel so only the bare state_dict is serialized
        net = generator.module if opt.multi_gpu else generator
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                torch.save(net.state_dict(), save_name)
                print('The trained model is saved as %s' % (model_name))
        if opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                torch.save(net.state_dict(), save_name)
                print('The trained model is saved as %s' % (model_name))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.ColorizationDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Calculate the gradient penalty loss for WGAN-GP
    def compute_gradient_penalty(D, input_samples, real_samples, fake_samples):
        # Random weight term for interpolation between real and fake samples
        alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples +
                        ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = D(input_samples, interpolates)
        # For PatchGAN
        fake = Variable(Tensor(real_samples.shape[0], 1, 30, 30).fill_(1.0),
                        requires_grad=False)
        # Get gradient w.r.t. interpolates
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_L, true_RGB, true_sal) in enumerate(dataloader):

            # To device
            true_L = true_L.cuda()
            true_RGB = true_RGB.cuda()
            true_sal = true_sal.cuda()
            true_sal = torch.cat((true_sal, true_sal, true_sal), 1)
            true_attn = true_RGB.mul(true_sal)

            ### Train Discriminator
            optimizer_D.zero_grad()

            # Generator output
            fake_RGB, fake_sal = generator(true_L)

            # Fake colorizations
            fake_scalar_d = discriminator(true_L, fake_RGB.detach())

            # True colorizations
            true_scalar_d = discriminator(true_L, true_RGB)

            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(
                discriminator, true_L.data, true_RGB.data, fake_RGB.data)

            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar_d) + torch.mean(
                fake_scalar_d) + opt.lambda_gp * gradient_penalty
            loss_D.backward()
            optimizer_D.step()

            ### Train Generator
            optimizer_G.zero_grad()

            fake_RGB, fake_sal = generator(true_L)

            # Pixel-level L1 Loss
            loss_L1 = criterion_L1(fake_RGB, true_RGB)

            # Attention Loss
            fake_sal = torch.cat((fake_sal, fake_sal, fake_sal), 1)
            fake_attn = fake_RGB.mul(fake_sal)
            loss_attn = criterion_L1(fake_attn, true_attn)

            # Perceptual Loss
            feature_fake_RGB = perceptualnet(fake_RGB)
            feature_true_RGB = perceptualnet(true_RGB)
            loss_percep = criterion_L1(feature_fake_RGB, feature_true_RGB)

            # GAN Loss
            fake_scalar = discriminator(true_L, fake_RGB)
            loss_GAN = -torch.mean(fake_scalar)

            # Overall Loss and optimize
            loss_G = opt.lambda_l1 * loss_L1 + opt.lambda_gan * loss_GAN + opt.lambda_percep * loss_percep + opt.lambda_attn * loss_attn
            loss_G.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixel-level Loss: %.4f] [Attention Loss: %.4f] [Perceptual Loss: %.4f] [D Loss: %.4f] [G Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader), loss_L1.item(),
                   loss_attn.item(), loss_percep.item(), loss_D.item(),
                   loss_GAN.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D)

        ### Sample data every epoch
        img_list = [fake_RGB, true_RGB]
        name_list = ['pred', 'gt']
        utils.save_sample_png(sample_folder=opt.sample_path,
                              sample_name='epoch%d' % (epoch + 1),
                              img_list=img_list,
                              name_list=name_list)