# Example No. 1
# 0
def main(args):
    """Build an ISR GAN training setup from parsed CLI arguments and train it.

    Constructs a generator (``RDN`` when ``args.model == 'rdn'``, ``RRDN`` for
    any other value), a ``Cut_VGG19`` feature extractor and a
    ``Discriminator``. It wires them into a ``Trainer`` that reads images from
    fixed ``low_res/...`` and ``high_res/...`` directories, then runs training.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``patch_size``, ``scale``, ``model``, ``name``, ``num_epochs``,
            ``epoch_steps`` and ``batch_size``.
    """
    lr_train_patch_size = args.patch_size
    layers_to_extract = [5, 9]
    # The HR-side networks receive patches ``args.scale`` times larger than
    # the LR patches fed to the generator.
    hr_train_patch_size = lr_train_patch_size * args.scale

    # Both generator variants are built with the same hyper-parameters; only
    # the class differs.
    arch_params = {'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': args.scale}
    generator_cls = RDN if args.model == 'rdn' else RRDN
    model = generator_cls(arch_params=arch_params, patch_size=lr_train_patch_size)

    f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
    discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)
    loss_weights = {
        'generator': 0.0,
        'feature_extractor': 0.0833,
        'discriminator': 0.01,
    }

    trainer = Trainer(
        generator=model,
        discriminator=discr,
        feature_extractor=f_ext,
        lr_train_dir='low_res/training/images',
        hr_train_dir='high_res/training/images',
        lr_valid_dir='low_res/validation/images',
        hr_valid_dir='high_res/validation/images',
        loss_weights=loss_weights,
        dataname=args.name,
        logs_dir='./logs',
        weights_dir='./weights',
        weights_generator=None,
        weights_discriminator=None,
        n_validation=40,
        lr_decay_frequency=30,
        lr_decay_factor=0.5,
    )

    # batch_size must be passed by keyword: a positional argument after the
    # keyword arguments above is a SyntaxError.
    trainer.train(
        epochs=args.num_epochs,
        steps_per_epoch=args.epoch_steps,
        batch_size=args.batch_size,
    )
# Example No. 2
# 0
# VGG19 layer indices handed to the feature extractor, and the ratio between
# HR and LR patch sizes (also passed to the generator as 'x').
layers_to_extract = [5, 9]
scale = 8
# NOTE(review): ``lr_train_patch_size`` is not defined in this snippet; it is
# presumably set earlier in the original script. Confirm before running.
hr_train_patch_size = lr_train_patch_size * scale

rrdn = RRDN(
    arch_params=dict(C=4, D=3, G=64, G0=64, T=10, x=scale),
    patch_size=lr_train_patch_size,
)
f_ext = Cut_VGG19(
    patch_size=hr_train_patch_size,
    layers_to_extract=layers_to_extract,
)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

###################################################################################################################

run = wandb.init(project='respick')
config = run.config

# Hyper-parameters recorded on the W&B run config.
config.num_epochs = 50
config.batch_size = 32
config.input_height = 32
config.input_width = 32
# The output resolution is the input resolution times the upscaling factor
# ``scale`` used for the model. Derive it so the two cannot drift apart
# (previously hard-coded as 256 == 32 * 8).
config.output_height = config.input_height * scale
config.output_width = config.input_width * scale

val_dir = 'data/test'