forked from sefibk/KernelGAN
/
train.py
64 lines (57 loc) · 2.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import tqdm
from configs import Config
from data import DataGenerator
from kernelGAN import KernelGAN
from learner import Learner
from torch.utils.data import DataLoader
# Number of DataGenerator samples grouped into one DataLoader batch.
batch_size = 16


def train(conf, *, patience=10):
    """Train KernelGAN on the single image described by ``conf``.

    Builds a KernelGAN, a Learner and a DataGenerator from ``conf``, then
    iterates over the generator through a DataLoader (``batch_size`` samples
    per batch, no shuffling) and calls ``gan.train`` on every batch. The
    loop stops early once ``learner.flag`` has been set on more than
    ``patience`` iterations. ``gan.finish()`` is called once the loop ends,
    whether the data was exhausted or training stopped early.

    Args:
        conf: parsed configuration, as produced by ``Config().parse(...)``.
        patience: early-stop threshold (keyword-only, default 10). The loop
            breaks as soon as the number of iterations with ``learner.flag``
            set exceeds this value.
    """
    gan = KernelGAN(conf)
    learner = Learner()
    data = DataGenerator(conf, gan)
    # shuffle=False keeps batches in DataGenerator's index order.
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)

    flagged_iterations = 0
    for i_batch, (g_in, d_in) in enumerate(tqdm.tqdm(dataloader)):
        gan.train(g_in, d_in)
        # Report the number of samples processed so far (batch index x batch
        # size) as the iteration count. Presumably this is what Learner's
        # schedule expects; confirm against learner.py.
        learner.update(i_batch * batch_size, gan)
        # NOTE(review): learner.flag's meaning is defined in learner.py.
        # Here it only counts towards the early-stop threshold.
        if learner.flag:
            flagged_iterations += 1
            if flagged_iterations > patience:
                break
    gan.finish()
def main():
    """Run KernelGAN (+ optional ZSSR) on every file in the input directory.

    Parses the command line and builds a Config for each regular file found
    in ``--input-dir``, processing files in sorted order. Each Config is
    passed to ``train``. The process then exits with status 0.
    """
    import argparse

    # Parse the command line arguments
    prog = argparse.ArgumentParser()
    prog.add_argument('--input-dir', '-i', type=str, default='test_images', help='path to image input directory.')
    prog.add_argument('--output-dir', '-o', type=str, default='results', help='path to image output directory.')
    prog.add_argument('--X4', action='store_true', help='use a x4 SR scale factor')
    # --SR is forwarded to Config as --do_ZSSR (see create_params), so the
    # flag ENABLES ZSSR; the previous help text stated the opposite.
    prog.add_argument('--SR', action='store_true', help='when activated - ZSSR is performed')
    prog.add_argument('--real', action='store_true', help='ZSSRs configuration is for real images')
    args = prog.parse_args()

    # Run KernelGAN sequentially on all images in the input directory.
    # Sorting makes the processing order reproducible, because os.listdir
    # order is arbitrary. Entries that are not regular files (e.g.
    # sub-directories) are skipped because they cannot be images.
    input_dir = os.path.abspath(args.input_dir)
    for filename in sorted(os.listdir(input_dir)):
        if not os.path.isfile(os.path.join(input_dir, filename)):
            continue
        conf = Config().parse(create_params(filename, args))
        train(conf)
    prog.exit(0)
def create_params(filename, args):
    """Build the argv-style list that ``Config().parse`` receives for one image.

    Args:
        filename: name of an entry inside ``args.input_dir``.
        args: parsed command-line namespace. Its ``input_dir``,
            ``output_dir``, ``X4``, ``SR`` and ``real`` attributes are read.

    Returns:
        A list that starts with the input image path and the absolute output
        directory. It continues with ``--X4``, ``--do_ZSSR`` and
        ``--real_image`` (in that order) for each corresponding flag that is
        truthy.
    """
    # (namespace attribute, Config flag) pairs, in the order they are appended.
    optional_flags = (
        ('X4', '--X4'),
        ('SR', '--do_ZSSR'),
        ('real', '--real_image'),
    )
    argv = [
        '--input_image_path', os.path.join(args.input_dir, filename),
        '--output_dir_path', os.path.abspath(args.output_dir),
    ]
    argv.extend(cli_flag for attr_name, cli_flag in optional_flags if getattr(args, attr_name))
    return argv
if __name__ == "__main__":
    # Script entry point: only run when executed directly, not on import.
    main()