Code Example #1
File: train_srgan.py  Project: Tubbz-alt/SrGAN-2
def train():
    div2k_train = DIV2K(scale=4, subset='train', downgrade='bicubic')
    div2k_valid = DIV2K(scale=4, subset='valid', downgrade='bicubic')

    train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
    valid_ds = div2k_valid.dataset(batch_size=16,
                                   random_transform=True,
                                   repeat_count=1)

    # Pre-train the generator with a pixel-wise loss before adversarial training
    pre_trainer = SrganGeneratorTrainer(model=generator(),
                                        checkpoint_dir='.ckpt/pre_generator')
    pre_trainer.train(train_ds,
                      valid_ds.take(10),
                      steps=1000000,
                      evaluate_every=1000,
                      save_best_only=False)
    pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

    gan_generator = generator()
    gan_generator.load_weights(weights_file('pre_generator.h5'))

    # Fine-tune generator and discriminator adversarially (SRGAN training)
    gan_trainer = SrganTrainer(generator=gan_generator,
                               discriminator=discriminator())
    gan_trainer.train(train_ds, steps=200000)

    gan_trainer.generator.save_weights(weights_file('gan_generator.h5'))
    gan_trainer.discriminator.save_weights(
        weights_file('gan_discriminator.h5'))
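If train_srgan.py is run as a script, the function above would typically be called from a standard entry point (not shown in this excerpt):

if __name__ == '__main__':
    train()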
Code Example #2
def train(depth, scale, downgrade):
    weights_dir = f'weights/edsr-{depth}-x{scale}'
    weights_file = os.path.join(weights_dir, 'weights.h5')
    os.makedirs(weights_dir, exist_ok=True)

    div2k_train = DIV2K(scale=scale, subset='train', downgrade=downgrade)  # images 1-800
    div2k_valid = DIV2K(scale=scale, subset='valid', downgrade=downgrade)  # images 801-900

    train_batch_size = 16
    train_ds = div2k_train.dataset(batch_size=train_batch_size, random_transform=True)
    valid_ds = div2k_valid.dataset(batch_size=1, random_transform=False, repeat_count=1)

    trainer = EdsrTrainer(model=edsr(scale=scale, num_res_blocks=depth), 
                          checkpoint_dir=f'.ckpt/edsr-{depth}-x{scale}')

    steps_epoch = int(800/train_batch_size) # 50 steps/epoch
    # Train the EDSR model for 300,000 steps (6,000 epochs) and evaluate it
    # every 25,000 steps (500 epochs) on the first 10 images of the DIV2K
    # validation set.
    trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=6000*steps_epoch,
                  evaluate_every=500*steps_epoch,
                  save_best_only=True)

    # Restore from checkpoint with highest PSNR
    trainer.restore()

    # Evaluate model on full validation set
    psnrv = trainer.evaluate(valid_ds)
    print(f'PSNR = {psnrv.numpy():.3f}')

    # Save weights
    trainer.model.save_weights(weights_file)
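A typical call to this function might look as follows; the depth value is an assumption (16 matches the EDSR baseline used in Code Example #8):

train(depth=16, scale=4, downgrade='bicubic')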
Code Example #3
def get_dataset(args):
    data_train = DIV2K(args)
    dataloader = torch.utils.data.DataLoader(data_train,
                                             batch_size=args.batchSize,
                                             drop_last=True,
                                             shuffle=True,
                                             num_workers=int(args.nThreads),
                                             pin_memory=False)
    return dataloader
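For context, a minimal usage sketch (hypothetical; the argparse fields mirror the attributes the function reads, and the DIV2K dataset class may expect additional options on args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=16)
parser.add_argument('--nThreads', type=int, default=4)
args = parser.parse_args()

dataloader = get_dataset(args)
for lr, hr in dataloader:  # assumed (low-res, high-res) batch pairs
    pass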
Code Example #4
File: main.py  Project: tonyykam/SR-macroscope
def main():
    print(torch.cuda.device_count(), "gpus available")

    n_epochs = config["num_epochs"]
    print("Number of epochs: ", n_epochs)
    model = MDSR().cuda()

    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    train_HR_dataset = './test/x1/'
    train_LR_dataset = './test/x2/'

    # from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    # ToTensor() already scales pixels to [0, 1], so normalize with statistics in
    # that range (0.5 and 1/3 correspond to the intended 255/2 and 255/3)
    transformLR = transforms.Compose([transforms.CenterCrop(512), transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (1/3, 1/3, 1/3))])
    transformHR = transforms.Compose([transforms.CenterCrop(1024), transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (1/3, 1/3, 1/3))])

    dataset = DIV2K(train_HR_dataset, train_LR_dataset, transformHR, transformLR)
    # Split sizes must sum exactly to len(dataset), so derive the last split from the rest
    n_train = int(len(dataset) * .5)
    n_valid = int(len(dataset) * .05)
    train_dataset, valid_dataset, _ = torch.utils.data.random_split(
        dataset, [n_train, n_valid, len(dataset) - n_train - n_valid])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 1)
    val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size = 1)
    #total_loss = train(train_loader, model, criterion, optimizer)
    #print("Total loss", total_loss)

    # todo: val loading
    # val_dataset = './validation/' + '*' #how will you get your dataset
    # val_loader = CIFAR(val_dataset) # how will you use pytorch's function to build a dataloader

    current_best_validation_loss = float('inf')
    for epoch in range(n_epochs):
        total_loss = train(train_loader, model, criterion, optimizer)
        print("Epoch {0}: {1}".format(epoch, total_loss))
        validation_loss = validate(val_loader, model, criterion)
        print("Test Loss {0}".format(validation_loss))
        if validation_loss < current_best_validation_loss:
            save_checkpoint(model.state_dict(), True)
            current_best_validation_loss = validation_loss
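The train() and validate() helpers called above are not shown in this excerpt; a minimal sketch of what validate() might look like (an assumption: it averages the L1 loss over the validation loader):

def validate(val_loader, model, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for lr, hr in val_loader:  # assumed (low-res, high-res) pairs
            sr = model(lr.cuda())
            total_loss += criterion(sr, hr.cuda()).item()
    return total_loss / len(val_loader)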
Code Example #5
        sr = preprocess_input(sr)
        hr = preprocess_input(hr)
        sr_features = self.vgg(sr) / 12.75
        hr_features = self.vgg(hr) / 12.75
        return self.mean_squared_error(hr_features, sr_features)

    def _generator_loss(self, sr_out):
        return self.binary_cross_entropy(tf.ones_like(sr_out), sr_out)

    def _discriminator_loss(self, hr_out, sr_out):
        hr_loss = self.binary_cross_entropy(tf.ones_like(hr_out), hr_out)
        sr_loss = self.binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
        return hr_loss + sr_loss


div2k_train = DIV2K(scale=4, subset='train', downgrade='bicubic')
div2k_valid = DIV2K(scale=4, subset='valid', downgrade='bicubic')

train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
valid_ds = div2k_valid.dataset(batch_size=16, random_transform=True, repeat_count=1)

# To pre-train the generator
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')
pre_trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=50000,
                  evaluate_every=1000,
                  save_best_only=False)

CWD_PATH = os.getcwd()

# To train the GAN
gan_generator = generator()
gan_generator.load_weights(os.path.join(CWD_PATH, 'weights', 'pre_generator.h5'))
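The excerpt stops here; based on Code Example #1, the adversarial training step would plausibly continue along these lines (a sketch, assuming SrganTrainer and discriminator are imported as in Code Example #7, and that weights are saved next to the pre-trained generator):

gan_trainer = SrganTrainer(generator=gan_generator,
                           discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

gan_trainer.generator.save_weights(os.path.join(CWD_PATH, 'weights', 'gan_generator.h5'))
gan_trainer.discriminator.save_weights(os.path.join(CWD_PATH, 'weights', 'gan_discriminator.h5'))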
Code Example #6
# Downgrade operator
downgrade = 'bicubic'

SAVED_MODEL_DIR = 'saved_model_dscn_bw'

############################################### Load Data Set ####################################################

# Location of model weights (needed for demo)
weights_dir = f'weights/wdsr-b-{depth}-x{scale}'
weights_file = os.path.join(weights_dir, 'weights.h5')

os.makedirs(weights_dir, exist_ok=True)

div2k_train = DIV2K(scale=scale,
                    subset='train',
                    downgrade=downgrade,
                    make_input_img_bw=True)
div2k_valid = DIV2K(scale=scale,
                    subset='valid',
                    downgrade=downgrade,
                    make_input_img_bw=True)

train_ds = div2k_train.dataset(batch_size=256, random_transform=True)
valid_ds = div2k_valid.dataset(batch_size=1,
                               random_transform=False,
                               repeat_count=1)

############################################### Setup Training stuff ####################################################

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
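The snippet ends before training begins; one plausible continuation that actually uses checkpoint_path is a standard Keras checkpoint callback (a sketch only; the wdsr_b model, module path, and hyperparameters below are assumptions, not code from this project):

import tensorflow as tf
from model.wdsr import wdsr_b  # assumed module path, as referenced in Code Example #8

model = wdsr_b(scale=scale, num_res_blocks=depth)
# Save weights whenever the checkpoint callback fires during training
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
model.compile(optimizer='adam', loss='mean_absolute_error')
model.fit(train_ds, validation_data=valid_ds,
          epochs=100, steps_per_epoch=50,
          callbacks=[cp_callback])
model.save_weights(weights_file)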
Code Example #7
# Implementation of Naive Pruning

from model.srgan import generator
from train import SrganGeneratorTrainer

from data import DIV2K

train_loader = DIV2K(scale=4, downgrade='bicubic', subset='train')

train_ds = train_loader.dataset(batch_size=16,
                                random_transform=True,
                                repeat_count=None)
valid_loader = DIV2K(scale=4, downgrade='bicubic', subset='valid')

valid_ds = valid_loader.dataset(batch_size=1,
                                random_transform=False,
                                repeat_count=1)

pre_trainer = SrganGeneratorTrainer(model=generator(num_res_blocks=6),
                                    checkpoint_dir='.ckpt/pre_generator')

pre_trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=1000000,
                  evaluate_every=1000)

pre_trainer.model.save_weights('weights/srgan/pre_generator_6.h5')

from model.srgan import generator, discriminator
from train import SrganTrainer
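The imports at the end suggest adversarial fine-tuning follows, but the naive pruning step itself is not shown in this excerpt. A minimal sketch of what magnitude-based pruning of the pre-trained generator could look like (hypothetical; prune_fraction and the per-layer thresholding are assumptions):

import numpy as np
import tensorflow as tf

pruned_generator = generator(num_res_blocks=6)
pruned_generator.load_weights('weights/srgan/pre_generator_6.h5')

prune_fraction = 0.2  # assumed fraction of weights to zero out in each conv layer
for layer in pruned_generator.layers:
    if not isinstance(layer, tf.keras.layers.Conv2D):
        continue
    params = layer.get_weights()
    kernel = params[0]
    # Zero out the smallest-magnitude kernel weights below the chosen quantile
    threshold = np.quantile(np.abs(kernel), prune_fraction)
    params[0] = np.where(np.abs(kernel) < threshold, 0.0, kernel)
    layer.set_weights(params)

pruned_generator.save_weights('weights/srgan/pre_generator_6_pruned.h5')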
Code Example #8
from data import DIV2K
#from model.wdsr import wdsr_b
from model.edsr import edsr
import os

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay

train = DIV2K(scale=4, downgrade='bicubic', subset='train')
train_ds = train.dataset(batch_size=16, random_transform=True)

# Create directory for saving model weights
weights_dir = 'weights/article'
os.makedirs(weights_dir, exist_ok=True)

# EDSR baseline as described in the EDSR paper (1.52M parameters)
model_edsr = edsr(scale=4, num_res_blocks=16)

# Adam optimizer with a scheduler that halves the learning rate after 200,000 steps
optim_edsr = Adam(learning_rate=PiecewiseConstantDecay(boundaries=[200000],
                                                       values=[1e-4, 5e-5]))

# Compile and train model for 300,000 steps with L1 pixel loss
model_edsr.compile(optimizer=optim_edsr, loss='mean_absolute_error')
model_edsr.fit(train_ds, epochs=300, steps_per_epoch=1000)

# Save model weights
model_edsr.save_weights(os.path.join(weights_dir, 'weights-edsr-16-x4.h5'))
"""
# Custom WDSR B model (0.62M parameters)
model_wdsr = wdsr_b(scale=4, num_res_blocks=32)
Code Example #9
    args = parser.parse_args()
    args, lg = parse(args)

    # Tensorboard save directory
    resume = args['solver']['resume']
    tensorboard_path = 'Tensorboard/{}'.format(args['name'])

    if not resume:
        if osp.exists(tensorboard_path):
            shutil.rmtree(tensorboard_path, True)
            lg.info('Remove dir: [{}]'.format(tensorboard_path))
    writer = SummaryWriter(tensorboard_path)

    # create dataset
    train_data = DIV2K(args['datasets']['train'])
    lg.info('Create train dataset successfully!')
    lg.info('Training: [{}] iterations for each epoch'.format(len(train_data)))

    val_data = DIV2K(args['datasets']['val'])
    lg.info('Create val dataset successfully!')
    lg.info('Validating: [{}] iterations for each epoch'.format(len(val_data)))

    # create solver
    lg.info('Preparing for experiment: [{}]'.format(args['name']))
    solver = Solver(args, train_data, val_data, writer)

    # train
    lg.info('Start training...')
    solver.train()
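For reference, the configuration this script expects could be laid out roughly as follows (a hypothetical sketch based only on the keys accessed above; the actual format produced by parse() is not shown):

args = {
    'name': 'experiment_name',          # used for the Tensorboard directory
    'solver': {'resume': False},
    'datasets': {
        'train': {},                    # options forwarded to DIV2K for the training split
        'val': {},                      # options forwarded to DIV2K for the validation split
    },
}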
Code Example #10
File: sisr_wdsr.py  Project: jacobdineen/SISR
import os

from data import DIV2K
# from model.edsr import edsr
from train import WdsrTrainer
import tensorflow as tf
from model.wdsr_weight_norm import wdsr_a, wdsr_b

# Number of residual blocks
depth = [1, 3, 5, 8]

# Super-resolution factor
scale = 4

# Downgrade operator
downgrade = 'bicubic'

div2k_train = DIV2K(scale=scale, subset='train', downgrade=downgrade)
div2k_valid = DIV2K(scale=scale, subset='valid', downgrade=downgrade)

train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
valid_ds = div2k_valid.dataset(batch_size=1,
                               random_transform=False,
                               repeat_count=1)

wdsra_psnr = {}

for i in depth:
    # Store model weights: one folder per depth value (4 in total)
    weights_dir = f'weights/wdsr-a-{i}-x{scale}'
    weights_file = os.path.join(weights_dir, 'weights.h5')
    os.makedirs(weights_dir, exist_ok=True)
    # Instantiate the training mechanism
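    # The excerpt is cut off at this point. A plausible continuation of the loop
    # body, modeled on the EdsrTrainer usage in Code Example #2 (the WdsrTrainer
    # arguments and step counts below are assumptions, not the project's code):
    trainer = WdsrTrainer(model=wdsr_a(scale=scale, num_res_blocks=i),
                          checkpoint_dir=f'.ckpt/wdsr-a-{i}-x{scale}')
    trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=300000,
                  evaluate_every=1000,
                  save_best_only=True)
    trainer.restore()
    wdsra_psnr[i] = trainer.evaluate(valid_ds).numpy()
    trainer.model.save_weights(weights_file)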
Code Example #11
    if args.test is not None:
        weights_path = args.test
    else:
        #weights_path = '../Examples/PreTrained/'
        pass

    if os.path.exists(Zip_path):
        print('unzipping: ', Zip_path)
        util.unzip(Zip_path, args.train)
    else:
        print('nothing to unzip')

    # prepare training data by cropping and caching
    div2k_train = DIV2K(crop_size=args.crop_size,
                        subset='train',
                        images_dir=Image_dir,
                        caches_dir=Cache_dir)
    div2k_valid = DIV2K(crop_size=args.crop_size,
                        subset='valid',
                        images_dir=Image_dir,
                        caches_dir=Cache_dir)

    train_ds = div2k_train.dataset_hr(batch_size=args.batch_size,
                                      random_transform=True,
                                      normalize_dataset=False)
    valid_ds = div2k_valid.dataset_hr(batch_size=args.batch_size,
                                      random_transform=True,
                                      normalize_dataset=False)

    valid_lr, valid_hr = div2k_valid.get_single(818)