-
Notifications
You must be signed in to change notification settings - Fork 3
/
test.py
63 lines (50 loc) · 2.36 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # disable tensorflow debugging logs
import itertools
import tensorflow as tf
import numpy as np
import PIL.Image
from model import ImageTransformNet
from utils import convert, tensor_to_image
from hparams import hparams
# Enable on-demand GPU memory allocation so TensorFlow does not reserve all
# GPU memory up front. The setting is applied to every visible GPU. This is a
# no-op on CPU-only machines, where the device list is empty. It also keeps
# the setting uniform, because TensorFlow refuses memory-growth settings that
# differ between GPU devices.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
def run_test(args):
    """Stylize every image in a directory with a pretrained transform network.

    Builds an ``ImageTransformNet`` configured from ``hparams`` and restores
    the most recent checkpoint in ``<args.name>/pretrained``. Each regular
    file in ``args.test_content_img`` is run through the network. Each result
    is saved as ``<args.name>_<file stem>.jpeg`` under
    ``<args.output_path>/output_img_<args.name>/step_<step>_<t0>x<t1>``,
    where ``t0``/``t1`` are ``hparams['test_size'][0]``/``[1]``.

    Args:
        args: Parsed command-line namespace; reads ``name``,
            ``test_content_img`` and ``output_path``.

    Raises:
        FileNotFoundError: If ``<args.name>/pretrained`` holds no checkpoint.
    """
    it_network = ImageTransformNet(input_shape=hparams['test_size'],
                                   residual_layers=hparams['residual_layers'],
                                   residual_filters=hparams['residual_filters'])

    ckpt_dir = os.path.join(args.name, 'pretrained')
    ckpt = tf.train.Checkpoint(network=it_network, step=tf.Variable(0))
    latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
    # latest_checkpoint() returns None when no checkpoint exists. Restoring
    # None is a silent no-op that leaves the freshly initialized weights in
    # place and would produce meaningless images, so fail loudly instead.
    if latest_ckpt is None:
        raise FileNotFoundError('No checkpoint found in {}'.format(ckpt_dir))
    # expect_partial() silences warnings about checkpointed values this object
    # graph does not use (presumably training-only state; TODO confirm).
    ckpt.restore(latest_ckpt).expect_partial()

    print('\n###################################################')
    print('Perceptual Losses for Real-Time Style Transfer Test')
    print('###################################################\n')
    step = str(ckpt.step.numpy())
    print('Restored {} step: {}\n'.format(args.name, step))

    dir_size = 'step_{}_{}x{}'.format(step,
                                      str(hparams['test_size'][0]),
                                      str(hparams['test_size'][1]))
    dir_model = 'output_img_{}'.format(args.name)
    out_dir = os.path.join(args.output_path, dir_model, dir_size)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    # os.listdir() order is arbitrary; sort for a reproducible processing order.
    for c_file in sorted(os.listdir(args.test_content_img)):
        c_path = os.path.join(args.test_content_img, c_file)
        # Skip sub-directories and other non-regular entries, which
        # presumably cannot be loaded by convert().
        if not os.path.isfile(c_path):
            continue
        # Add a leading (batch) axis before feeding the network.
        content = convert(c_path, hparams['test_size'][:2])[tf.newaxis, :]
        output = it_network(content, training=False)
        image = tensor_to_image(output)
        c_name = '{}_{}'.format(args.name, os.path.splitext(c_file)[0])
        save_path = os.path.join(out_dir, c_name)
        image.save(save_path + '.jpeg')
        print('Image: {}.jpeg saved'.format(save_path))
def main():
    """Command-line entry point: parse options and run the style-transfer test.

    Options (all optional):
        --name              model name / checkpoint root (default ``model``)
        --test_content_img  directory of content images to stylize
        --output_path       root directory for the stylized output
    """
    option_defaults = (
        ('--name', 'model'),
        ('--test_content_img', './images/content_img/'),
        ('--output_path', './images/'),
    )
    cli = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        cli.add_argument(flag, default=default_value)
    run_test(cli.parse_args())


if __name__ == '__main__':
    main()