Example #1
def print_network():
    opt = setup()
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    printer('generator')
    summary(generator.cuda(), (3, 32, 32))
    printer('discriminator')
    summary(discriminator.cuda(), (3, 32, 32))
    printer('feature_extractor')
    summary(feature_extractor.cuda(), (3, 32, 32))
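The snippets on this page call several names that are defined elsewhere in the project. A sketch of the imports they appear to rely on, inferred only from the calls they make; the project-specific names in the final comment are assumptions:

# Imports these snippets appear to rely on (inferred from the calls they make)
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchsummary import summary                      # summary(model, input_size)
from tensorboard_logger import configure, log_value   # configure(logdir, flush_secs), log_value(name, value, step)
from PIL import Image

# Generator, Discriminator, FeatureExtractor, Visualizer, setup and printer look like
# project-specific helpers and are assumed to be importable from the project's own modules.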
Example #2
if opt.discriminatorWeights != '':
    D.load_state_dict(torch.load(opt.discriminatorWeights))
print(D)

# For the content loss
FE = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
print(FE)
content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()

ones_const = Variable(torch.ones(opt.batchSize, 1))

# if gpu is to be used
if opt.cuda:
    G.cuda()
    D.cuda()
    FE.cuda()
    content_criterion.cuda()
    adversarial_criterion.cuda()
    ones_const = ones_const.cuda()

optim_generator = optim.Adam(G.parameters(), lr=opt.generatorLR)
optim_discriminator = optim.Adam(D.parameters(), lr=opt.discriminatorLR)

configure('../logs/' + timestamp, flush_secs=5)
# visualizer = Visualizer(image_size=opt.imageSize*opt.upSampling)

low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

# Pre-train generator using raw MSE loss
if opt.generatorWeights == '':
    print('Generator pre-training')
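The snippet is cut off right after this message; the full pre-training loop appears in Example #7 below. A condensed sketch of the step it leads into, assuming dataloader, scale and normalize are set up as in that example:

# Condensed sketch of the raw-MSE pre-training step (full version in Example #7)
for i, data in enumerate(dataloader):
    high_res_real, _ = data
    if len(high_res_real) < opt.batchSize:  # skip the final partial batch
        continue
    for j in range(opt.batchSize):
        low_res[j] = scale(high_res_real[j])
        high_res_real[j] = normalize(high_res_real[j])

    if opt.cuda:
        high_res_real = Variable(high_res_real.cuda())
        high_res_fake = G(Variable(low_res).cuda())
    else:
        high_res_real = Variable(high_res_real)
        high_res_fake = G(Variable(low_res))

    G.zero_grad()
    generator_content_loss = content_criterion(high_res_fake, high_res_real)  # raw MSE
    generator_content_loss.backward()
    optim_generator.step()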
Example #3
if opt.discriminatorWeights != '':
    discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
print(discriminator)

# For the content loss
feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
print(feature_extractor)
content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()

ones_const = Variable(torch.ones(opt.batchSize, 1))

# if gpu is to be used
if opt.cuda:
    generator.cuda()
    discriminator.cuda()
    feature_extractor.cuda()
    content_criterion.cuda()
    adversarial_criterion.cuda()
    ones_const = ones_const.cuda()

optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR)
optim_discriminator = optim.Adam(discriminator.parameters(),
                                 lr=opt.discriminatorLR)

configure('logs/' + opt.dataset + '-' + str(opt.batchSize) + '-' +
          str(opt.generatorLR) + '-' + str(opt.discriminatorLR),
          flush_secs=5)
visualizer = Visualizer(save_dir=opt.samples_dir,
                        show_step=opt.sample_freq,
                        image_size=opt.imageSize * opt.upSampling)
Example #4
def upsampling(path, picture_name, upsampling):
    opt = setup()
    # image = Image.open(os.getcwd() + r'\images\\' + path)
    image = Image.open(path)
    opt.imageSize = (image.size[1], image.size[0])

    log = '>>> process image : {} size : ({}, {}) sr_reconstruct size : ({}, {})'.format(
        picture_name, image.size[0], image.size[1], image.size[0] * upsampling,
        image.size[1] * upsampling)
    try:
        os.makedirs(os.getcwd() + r'\output\result')
    except OSError:
        pass

    if torch.cuda.is_available() and not opt.cuda:
        print(
            '[WARNING] : You have a CUDA device, so you should probably run with --cuda'
        )

    transform = transforms.Compose([
        transforms.RandomCrop(opt.imageSize),
        transforms.Pad(padding=0),
        transforms.ToTensor()
    ])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                                       std=[4.367, 4.464, 4.444])
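    # The constants above follow from inverting Normalize: y = (x - mean) / std,
    # so the inverse uses mean' = -mean/std and std' = 1/std:
    #   -0.485/0.229 ≈ -2.118, -0.456/0.224 ≈ -2.036, -0.406/0.225 ≈ -1.804
    #    1/0.229     ≈  4.367,  1/0.224     ≈  4.464,  1/0.225     ≈  4.444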
    scale = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(opt.imageSize),
        transforms.Pad(padding=0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot,
                                   download=True,
                                   train=False,
                                   transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot,
                                    download=True,
                                    train=False,
                                    transform=transform)
    assert dataset

    dataloader = transforms.Compose([transforms.ToTensor()])
    image = dataloader(image)

    # loading paras from networks
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    # For the content loss
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0],
                                opt.imageSize[1])

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    # Generate data
    high_res_real = image

    # Downsample images to low resolution
    low_res = scale(high_res_real)
    low_res = torch.tensor([np.array(low_res)])

    high_res_real = normalize(high_res_real)
    high_res_real = torch.tensor([np.array(high_res_real)])

    # Generate real and fake inputs
    if opt.cuda:
        high_res_real = Variable(high_res_real.cuda())
        high_res_fake = generator(Variable(low_res).cuda())
    else:
        high_res_real = Variable(high_res_real)
        high_res_fake = generator(Variable(low_res))

    save_image(unnormalize(high_res_fake[0]),
               './output/result/' + picture_name)
    return log
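A hypothetical call of upsampling(); the path and file name are placeholders, and the third argument only affects the log message (the model's scale factor comes from opt.upSampling):

log = upsampling('images/butterfly.png', 'butterfly.png', 4)
print(log)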
Example #5
def init(opt):
    # [folder] create folder for checkpoints
    try: os.makedirs(opt.out)
    except OSError: pass

    # [cuda] warn if a CUDA device is available but --cuda was not passed
    if torch.cuda.is_available() and not opt.cuda:
        sys.stdout.write('[WARNING] : You have a CUDA device, so you should probably run with --cuda')

    # [normalization] __return__ normalize images, set up mean and std
    normalize = transforms.Normalize(
                                        mean = [0.485, 0.456, 0.406],
                                        std = [0.229, 0.224, 0.225])
    # [scale] __return__
    scale = transforms.Compose([
                                    transforms.ToPILImage(),
                                    transforms.Resize(opt.imageSize),
                                    transforms.ToTensor(),
                                    transforms.Normalize(
                                                            mean = [0.485, 0.456, 0.406],
                                                            std = [0.229, 0.224, 0.225])])

    # [transform] up sampling transforms
    transform = transforms.Compose([transforms.RandomCrop((opt.imageSize[0] * opt.upSampling,
                                                           opt.imageSize[1] * opt.upSampling)),
                                    transforms.ToTensor()])
    # [dataset] training dataset
    if opt.dataset == 'folder':
        dataset = datasets.ImageFolder(root = opt.dataroot, transform = transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root = opt.dataroot, train = True, download = True, transform = transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root = opt.dataroot, train = True, download = False, transform = transform)
    assert dataset
    
    # [dataloader] __return__ loading dataset
    dataloader = torch.utils.data.DataLoader(
                                                 dataset,
                                                 batch_size = opt.batchSize,
                                                 shuffle = True,
                                                 num_workers = int(opt.workers))
    # [generator] __return__ generator of GAN
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '' and os.path.exists(opt.generatorWeights):
        generator.load_state_dict(torch.load(opt.generatorWeights))

    # [discriminator] __return__ discriminator of GAN
    discriminator = Discriminator()
    if opt.discriminatorWeights != '' and os.path.exists(opt.discriminatorWeights):
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    # [extractor] __return__ feature extractor of GAN
    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained = True))

    # [loss] __return__ loss function
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    ones_const = Variable(torch.ones(opt.batchSize, 1))

    # [cuda] if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    # [optimizer] __return__ Optimizer for GAN 
    optim_generator = optim.Adam(generator.parameters(), lr = opt.generatorLR)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr = opt.discriminatorLR)

    # record configure
    configure('logs/{}-{}-{}-{}'.format(opt.dataset, str(opt.batchSize), str(opt.generatorLR), str(opt.discriminatorLR)), flush_secs = 5)
    # visualizer = Visualizer(image_size = (opt.imageSize[0] * opt.upSampling, opt.imageSize[1] * opt.upSampling))

    # __return__ low resolution images
    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0], opt.imageSize[1])

    return normalize,\
           scale,\
           dataloader,\
           generator,\
           discriminator,\
           feature_extractor,\
           content_criterion,\
           adversarial_criterion,\
           ones_const,\
           optim_generator,\
           optim_discriminator,\
           low_res
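init() returns twelve objects; a sketch of how a training script might unpack them, assuming opt comes from an argparse setup like the ones in the later examples:

(normalize, scale, dataloader,
 generator, discriminator, feature_extractor,
 content_criterion, adversarial_criterion, ones_const,
 optim_generator, optim_discriminator, low_res) = init(opt)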
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default='folder',
                        help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot',
                        type=str,
                        default='./data',
                        help='path to dataset')
    parser.add_argument('--workers',
                        type=int,
                        default=1,
                        help='number of data loading workers')
    parser.add_argument('--batchSize',
                        type=int,
                        default=1,
                        help='input batch size')
    parser.add_argument('--imageSize',
                        type=int,
                        default=32,
                        help='the low resolution image size')
    parser.add_argument('--upSampling',
                        type=int,
                        default=4,
                        help='low to high resolution scaling factor')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU',
                        type=int,
                        default=1,
                        help='number of GPUs to use')
    parser.add_argument(
        '--generatorWeights',
        type=str,
        default='checkpoints/generator_final.pth',
        help="path to generator weights (to continue training)")
    parser.add_argument(
        '--discriminatorWeights',
        type=str,
        default='checkpoints/discriminator_final.pth',
        help="path to discriminator weights (to continue training)")

    opt = parser.parse_args()
    print(opt)

    if not os.path.exists('output/high_res_fake'):
        os.makedirs('output/high_res_fake')
    if not os.path.exists('output/high_res_real'):
        os.makedirs('output/high_res_real')
    if not os.path.exists('output/low_res'):
        os.makedirs('output/low_res')

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    transform = transforms.Compose([
        transforms.RandomCrop(opt.imageSize * opt.upSampling),
        transforms.ToTensor()
    ])

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    scale = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(opt.imageSize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                                       std=[4.367, 4.464, 4.444])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot,
                                   download=True,
                                   train=False,
                                   transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot,
                                    download=True,
                                    train=False,
                                    transform=transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=False,
                                             num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    print('Test started...')
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    for i, data in enumerate(dataloader):
        # Generate data
        high_res_real, _ = data

        # Downsample images to low resolution
        # Skip the final partial batch (its length is smaller than opt.batchSize)
        if len(high_res_real) < opt.batchSize:
            continue
        for j in range(opt.batchSize):
            low_res[j] = scale(high_res_real[j])
            high_res_real[j] = normalize(high_res_real[j])

        # Generate real and fake inputs
        if opt.cuda:
            high_res_real = Variable(high_res_real.cuda())
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_real = Variable(high_res_real)
            high_res_fake = generator(Variable(low_res))

        ######### Test discriminator #########

        discriminator_loss = adversarial_criterion(discriminator(high_res_real), target_real) + \
                                adversarial_criterion(discriminator(high_res_fake), target_fake)
        mean_discriminator_loss += discriminator_loss.data

        ######### Test generator #########

        real_features = feature_extractor(high_res_real)
        fake_features = feature_extractor(high_res_fake)

        generator_content_loss = content_criterion(
            high_res_fake, high_res_real) + 0.006 * content_criterion(
                fake_features, real_features)
        mean_generator_content_loss += generator_content_loss.data
        generator_adversarial_loss = adversarial_criterion(
            discriminator(high_res_fake), target_real)
        mean_generator_adversarial_loss += generator_adversarial_loss.data

        generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
        mean_generator_total_loss += generator_total_loss.data

        ######### Status and display #########
        sys.stdout.write(
            '\r[%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f'
            % (i, len(dataloader), discriminator_loss.data,
               generator_content_loss.data, generator_adversarial_loss.data,
               generator_total_loss.data))

        # Guard against a final partial batch before saving (redundant with the check above)
        if len(high_res_real) < opt.batchSize:
            continue
        for j in range(opt.batchSize):
            save_image(
                unnormalize(high_res_real[j].cpu()),
                'output/high_res_real/' + str(i * opt.batchSize + j) + '.png')
            save_image(
                unnormalize(high_res_fake[j].cpu()),
                'output/high_res_fake/' + str(i * opt.batchSize + j) + '.png')
            #save_image(high_res_real[j], 'output/high_res_real/' + str(i*opt.batchSize + j) + '.png') # without unnormalize, the saved real images would be mis-colored
            #save_image(high_res_fake[j], 'output/high_res_fake/' + str(i*opt.batchSize + j) + '.png')
            save_image(unnormalize(low_res[j]),
                       'output/low_res/' + str(i * opt.batchSize + j) + '.png')

    sys.stdout.write(
        '\r[%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f\n'
        % (i, len(dataloader), mean_discriminator_loss / len(dataloader),
           mean_generator_content_loss / len(dataloader),
           mean_generator_adversarial_loss / len(dataloader),
           mean_generator_total_loss / len(dataloader)))
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar100', help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot', type=str, default='./data', help='path to dataset')
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=15, help='the low resolution image size')
    parser.add_argument('--upSampling', type=int, default=2, help='low to high resolution scaling factor')
    parser.add_argument('--nEpochs', type=int, default=100, help='number of epochs to train for')
    parser.add_argument('--nPreEpochs', type=int, default=2, help='number of epochs to pre-train Generator')
    parser.add_argument('--generatorLR', type=float, default=0.0001, help='learning rate for generator')
    parser.add_argument('--discriminatorLR', type=float, default=0.0001, help='learning rate for discriminator')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use')
    parser.add_argument('--generatorWeights', type=str, default='', help="path to generator weights (to continue training)")
    parser.add_argument('--discriminatorWeights', type=str, default='', help="path to discriminator weights (to continue training)")
    parser.add_argument('--out', type=str, default='checkpoints', help='folder to output model checkpoints')

    opt = parser.parse_args()
    print(opt)

    try:
        os.makedirs(opt.out)
    except OSError:
        pass

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    transform = transforms.Compose([transforms.RandomCrop(opt.imageSize*opt.upSampling),
                                    transforms.ToTensor()])

    normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                    std = [0.229, 0.224, 0.225])

    scale = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                    std = [0.229, 0.224, 0.225])
                                ])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot, train=True, download=True, transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot, train=True, download=True, transform=transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    ones_const = Variable(torch.ones(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=opt.discriminatorLR)

    configure('logs/' + opt.dataset + '-' + str(opt.batchSize) + '-' + str(opt.generatorLR) + '-' + str(opt.discriminatorLR), flush_secs=5)
    visualizer = Visualizer(image_size=opt.imageSize*opt.upSampling)

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    # Pre-train generator using raw MSE loss
    print('Generator pre-training')
    for epoch in range(opt.nPreEpochs):
        mean_generator_content_loss = 0.0

        for i, data in enumerate(dataloader):
            # Generate data
            high_res_real, _ = data

            # Downsample images to low resolution
            if len(high_res_real) < opt.batchSize:  # skip the final partial batch (shorter than opt.batchSize)
                continue
            for j in range(opt.batchSize):  
                low_res[j] = scale(high_res_real[j])
                high_res_real[j] = normalize(high_res_real[j])

            # Generate real and fake inputs
            if opt.cuda:
                high_res_real = Variable(high_res_real.cuda())
                high_res_fake = generator(Variable(low_res).cuda())
            else:
                high_res_real = Variable(high_res_real)
                high_res_fake = generator(Variable(low_res))

            ######### Train generator #########
            generator.zero_grad()

            generator_content_loss = content_criterion(high_res_fake, high_res_real)

            mean_generator_content_loss += generator_content_loss.data

            generator_content_loss.backward()
            optim_generator.step()

            ######### Status and display #########
            sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f' % (epoch, opt.nPreEpochs, i, len(dataloader), generator_content_loss.data))
            visualizer.show(low_res, high_res_real.cpu().data, high_res_fake.cpu().data)

        sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f\n' % (epoch, opt.nPreEpochs, i, len(dataloader), mean_generator_content_loss/len(dataloader)))
        log_value('generator_mse_loss', mean_generator_content_loss/len(dataloader), epoch)
        
        # Do checkpointing every epoch
        # torch.save(generator.state_dict(), '%s/generator_pretrain_%s.pth' %(opt.out,str(epoch)))

    # Do checkpointing
    torch.save(generator.state_dict(), '%s/generator_pretrain.pth' % opt.out)

    # SRGAN training
    optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR*0.1)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=opt.discriminatorLR*0.1)

    print('SRGAN training')
    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_adversarial_loss = 0.0
        mean_generator_total_loss = 0.0
        mean_discriminator_loss = 0.0

        for i, data in enumerate(dataloader):
            # Generate data
            high_res_real, _ = data

            # Downsample images to low resolution
            if len(high_res_real) < opt.batchSize:  # skip the final partial batch (shorter than opt.batchSize)
                continue
            for j in range(opt.batchSize): 
                low_res[j] = scale(high_res_real[j])
                high_res_real[j] = normalize(high_res_real[j])

            # Generate real and fake inputs
            if opt.cuda:
                high_res_real = Variable(high_res_real.cuda())
                high_res_fake = generator(Variable(low_res).cuda())
                target_real = Variable(torch.rand(opt.batchSize,1)*0.5 + 0.7).cuda()
                # shape (batchSize, 1), values in [0.7, 1.2) -- smoothed labels for real images
                target_fake = Variable(torch.rand(opt.batchSize,1)*0.3).cuda()
                # shape (batchSize, 1), values in [0, 0.3) -- smoothed labels for fake images
            else:
                high_res_real = Variable(high_res_real)
                high_res_fake = generator(Variable(low_res))
                target_real = Variable(torch.rand(opt.batchSize,1)*0.5 + 0.7)
                target_fake = Variable(torch.rand(opt.batchSize,1)*0.3)

            ######### Train discriminator #########
            discriminator.zero_grad()

            discriminator_loss = adversarial_criterion(discriminator(high_res_real), target_real) + \
                                 adversarial_criterion(discriminator(Variable(high_res_fake.data)), target_fake)
            mean_discriminator_loss += discriminator_loss.data

            discriminator_loss.backward()
            optim_discriminator.step()

            ######### Train generator #########
            generator.zero_grad()

            real_features = Variable(feature_extractor(high_res_real).data)
            fake_features = feature_extractor(high_res_fake)

            # Content loss: pixel-wise MSE between the images plus 0.006 * VGG loss, where the
            # VGG loss is the MSE between VGG19 feature maps of the fake and real images
            generator_content_loss = content_criterion(high_res_fake, high_res_real) + 0.006*content_criterion(fake_features, real_features)
            mean_generator_content_loss += generator_content_loss.data
            generator_adversarial_loss = adversarial_criterion(discriminator(high_res_fake), ones_const)
            mean_generator_adversarial_loss += generator_adversarial_loss.data

            generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss
            mean_generator_total_loss += generator_total_loss.data

            generator_total_loss.backward()
            optim_generator.step()

            ######### Status and display #########
            sys.stdout.write('\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f' % (epoch, opt.nEpochs, i, len(dataloader),
                discriminator_loss.data, generator_content_loss.data, generator_adversarial_loss.data, generator_total_loss.data))
            visualizer.show(low_res, high_res_real.cpu().data, high_res_fake.cpu().data)

        sys.stdout.write('\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f\n' % (epoch, opt.nEpochs, i, len(dataloader),
        mean_discriminator_loss/len(dataloader), mean_generator_content_loss/len(dataloader),
        mean_generator_adversarial_loss/len(dataloader), mean_generator_total_loss/len(dataloader)))

        log_value('generator_content_loss', mean_generator_content_loss/len(dataloader), epoch)
        log_value('generator_adversarial_loss', mean_generator_adversarial_loss/len(dataloader), epoch)
        log_value('generator_total_loss', mean_generator_total_loss/len(dataloader), epoch)
        log_value('discriminator_loss', mean_discriminator_loss/len(dataloader), epoch)

        # Do checkpointing every epoch
        torch.save(generator.state_dict(), '%s/generator_final.pth' % opt.out)
        torch.save(discriminator.state_dict(), '%s/discriminator_final.pth' % opt.out)

    # Keep the process alive so the log output stays visible
    print("Training is over; kill the process once you have checked the log...")
    while True:
        pass
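The generator objective used in the loop above can be isolated into a small helper; a sketch using the same criteria and weights (pixel-wise MSE plus 0.006 * VGG-feature MSE for content, plus 1e-3 * adversarial BCE against ones):

def srgan_generator_loss(content_criterion, adversarial_criterion,
                         high_res_fake, high_res_real,
                         fake_features, real_features,
                         discriminator_output, ones_const):
    # Content term: pixel MSE + 0.006 * VGG-feature MSE
    content_loss = content_criterion(high_res_fake, high_res_real) \
        + 0.006 * content_criterion(fake_features, real_features)
    # Adversarial term: BCE of D(fake) against the "real" target
    adversarial_loss = adversarial_criterion(discriminator_output, ones_const)
    return content_loss + 1e-3 * adversarial_loss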
Example #8
def down_and_up_sampling(image, save_name, upsampling):
    
    opt = setup()
    # create output folders (one at a time, so an existing folder does not block the others)
    for folder in ('output/high_res_fake', 'output/high_res_real', 'output/low_res'):
        try:
            os.makedirs(folder)
        except OSError:
            pass

    if torch.cuda.is_available() and not opt.cuda:
        print('[WARNING]: You have a CUDA device, so you should probably run with --cuda')

    transform = transforms.Compose([transforms.RandomCrop((
                                                                image.size[0],
                                                                image.size[1])),
                                    transforms.Pad(padding = 0),
                                    transforms.ToTensor()])
    normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])

    # [down sampling] down-sampling part
    scale = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize((int(image.size[1] / opt.upSampling), int(image.size[0] / opt.upSampling))),
                                transforms.Pad(padding=0),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                                        mean = [0.485, 0.456, 0.406],
                                                        std = [0.229, 0.224, 0.225])])
    
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(
                                            mean = [-2.118, -2.036, -1.804],
                                            std = [4.367, 4.464, 4.444])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root = opt.dataroot, transform = transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root = opt.dataroot, download = True, train = False, transform = transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root = opt.dataroot, download = True, train = False, transform = transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size = opt.batchSize,
                                             shuffle = False,
                                             num_workers = int(opt.workers))

    my_loader = transforms.Compose([transforms.ToTensor()])
    image = my_loader(image)

    # [paras] loading paras from .pth files
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained = True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0], opt.imageSize[1])

    # print('Test started...')
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    data = image
    for i in range(1):
        # Generate data
        high_res_real = data
        low_res = scale(high_res_real)
        low_res = torch.tensor([np.array(low_res)])
        high_res_real = normalize(high_res_real)
        high_res_real = torch.tensor([np.array(high_res_real)])
            
        # Generate real and fake inputs
        if opt.cuda:
            high_res_real = Variable(high_res_real.cuda())
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_real = Variable(high_res_real)
            high_res_fake = generator(Variable(low_res)) # >>> create hr images

        save_image(unnormalize(high_res_real[0]), 'output/high_res_real/' + save_name)
        save_image(unnormalize(high_res_fake[0]), 'output/high_res_fake/' + save_name)
        save_image(unnormalize(low_res[0]), 'output/low_res/' + save_name)
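A hypothetical driver for down_and_up_sampling(); the input path and output name are placeholders, and setup() is assumed to provide the same argparse options used throughout these examples:

if __name__ == '__main__':
    image = Image.open('images/sample.png')      # placeholder path
    down_and_up_sampling(image, 'sample.png', 4)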