-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_test.py
158 lines (118 loc) · 3.73 KB
/
demo_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
demo.py - Optimized style transfer pipeline for interactive demo.
"""
# system imports
import argparse
import logging
from threading import Lock
import time
import os
from datetime import datetime
# library imports
import caffe
import cv2
import numpy as np
from skimage.transform import rescale
# local imports
from style import StyleTransfer
# ---- command-line interface ----
parser = argparse.ArgumentParser(
    description="Run the optimized style transfer pipeline.",
    usage="demo.py -s <style_image> -c <content_image>",
)
parser.add_argument(
    "-s", "--style-img",
    type=str,
    required=True,
    help="input style (art) image",
)
parser.add_argument(
    "-c", "--content-img",
    type=str,
    required=True,
    help="input content image",
)

# ---- style transfer worker pool ----
# Maps a threading.Lock -> StyleTransfer instance. A worker may only be used
# while its lock is held (the pool is filled by init() and borrowed by st_api()).
workers = dict()
def gpu_count(command=None):
    """
    Count the CUDA-capable GPUs reported by ``nvidia-smi -L``.

    Requires ``nvidia-smi`` to be on PATH (or an explicit ``command``).

    :param command: optional argv list to run instead of
        ``["nvidia-smi", "-L"]`` (e.g. a non-PATH install, or for testing).
    :return: number of output lines starting with ``GPU``; 0 if the command
        cannot be started, exits non-zero, or times out.
    """
    # Local import: the module header does not import subprocess, which made
    # the original implementation always fail with NameError (and return 0).
    import subprocess

    if command is None:
        command = ["nvidia-smi", "-L"]
    try:
        # argv list => no shell needed; universal_newlines decodes to str so
        # line splitting works on Python 3 (check_output returns bytes otherwise)
        output = subprocess.check_output(command, universal_newlines=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        # binary missing, non-zero exit (CalledProcessError) or hang
        # (TimeoutExpired): treat as "no usable GPUs"
        return 0
    # One "GPU <n>: ..." line per device; any other lines (blank, or e.g.
    # indented sub-device entries on newer drivers) are not counted.
    return sum(1 for line in output.splitlines() if line.startswith("GPU"))
def init(n_workers=1):
    """
    Initialize the style transfer backend.

    Ensures the global ``workers`` pool holds at least ``n_workers``
    ``StyleTransfer("vgg16", use_pbar=False)`` instances, each keyed by its
    own ``Lock``. Only missing workers are created, so calling this again
    (e.g. once per request) does not load yet another model every time.

    :param n_workers: desired pool size; values below 1 are treated as 1.
    """
    global workers
    # An empty pool would make st_api() poll forever, so never go below one.
    n_workers = max(1, n_workers)
    # Top the pool up to the requested size instead of blindly adding more.
    while len(workers) < n_workers:
        worker = StyleTransfer("vgg16", use_pbar=False)
        workers[Lock()] = worker
def st_api(img_style, img_content, callback=None):
    """
    Style transfer API.

    Borrows a free worker from the global ``workers`` pool, polling every
    0.1 s until one of the locks can be acquired. It then runs two transfer
    passes (length 360, then 512). The second pass is initialised from the
    first pass's generated image.

    :param img_style: style image, passed straight to ``transfer_style``.
    :param img_content: content image, passed straight to ``transfer_style``.
    :param callback: optional callback forwarded to every pass.
    :return: the worker's ``get_generated()`` result after the last pass.
    :raises RuntimeError: if the worker pool is empty (``init()`` not called).
    """
    # Per-pass settings; the "init" key is filled in below.
    all_args = [{"length": 360, "ratio": 5e3, "n_iter": 32, "callback": callback},
                {"length": 512, "ratio": 5e4, "n_iter": 16, "callback": callback}]

    if not workers:
        # Without this check the polling loop below would never terminate.
        raise RuntimeError("no style transfer workers; call init() first")

    # Acquire a worker: non-blocking acquire on each lock, sleep if all busy.
    st_lock = None
    st_worker = None
    while st_lock is None:
        # Iterate a snapshot so a concurrent init() cannot resize the dict
        # underneath us.
        for lock, worker in list(workers.items()):
            if lock.acquire(False):
                st_lock = lock
                st_worker = worker
                break
        else:
            time.sleep(0.1)

    # Always hand the worker back, even if a pass raises; otherwise its lock
    # stays held and later callers could wait for it forever.
    try:
        img_out = "content"
        for args in all_args:
            args["init"] = img_out
            st_worker.transfer_style(img_style, img_content, **args)
            img_out = st_worker.get_generated()
    finally:
        st_lock.release()
    return img_out
def main(args):
    """
    Command-line entry point.

    Stylizes ``args.content_img`` with ``args.style_img``, saves the result
    as "<timestamp>pic.jpg" in the current working directory, and shows it
    in a window until a key is pressed.
    """
    init()  # default pool size (one worker) for a one-shot CLI run

    style_image = caffe.io.load_image(args.style_img)
    content_image = caffe.io.load_image(args.content_img)
    stylized = st_api(style_image, content_image)
    # OpenCV's imwrite/imshow work in BGR channel order.
    bgr = cv2.cvtColor(stylized, cv2.COLOR_RGB2BGR)

    # "* 255" presumably rescales a [0, 1] float image to 8-bit range — confirm
    # against StyleTransfer.get_generated().
    stamp = datetime.now().strftime('%Y-%m-%d %H.%M.%S')
    out_name = stamp + 'pic'
    cv2.imwrite(out_name + '.jpg', bgr * 255)

    # Preview window; waitKey() blocks until a key press.
    window = "Art"
    cv2.imshow(window, bgr)
    cv2.waitKey()
    cv2.destroyWindow(window)
def transfer(image_style, image_content, choice, output_dir='C:/Users/ml2020/Desktop/Demo/static'):
    """
    Run style transfer on two image files and save the result as a JPEG.

    :param image_style: path of the style (art) image.
    :param image_content: path of the content image.
    :param choice: input source tag (e.g. 'camera'). It is currently unused;
        the rotation that was once applied to camera captures is disabled.
    :param output_dir: directory the JPEG is written to. The default is the
        demo's original static folder.
    :return: full path of the written JPEG.
    :raises IOError: if OpenCV fails to write the output file.
    """
    # Select GPU mode before any model is (possibly) built below.
    caffe.set_mode_gpu()
    # Build the worker pool only once. Calling init() on every request added
    # another VGG16 model to the pool each time.
    if not workers:
        init()

    # perform style transfer
    img_style = caffe.io.load_image(image_style)
    img_content = caffe.io.load_image(image_content)
    result = st_api(img_style, img_content)
    # OpenCV writes images in BGR channel order.
    result_trans = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
    # presumably scales a [0, 1] float image to 8-bit range — confirm against
    # StyleTransfer.get_generated()
    result_com = result_trans * 255

    # Microseconds keep requests that finish within the same second from
    # overwriting each other's output file.
    filename = datetime.now().strftime('%Y-%m-%d %H.%M.%S.%f') + 'pic.jpg'
    img_path = os.path.join(output_dir, filename)
    # imwrite reports failure (e.g. missing directory) by returning False
    # rather than raising; don't hand back a path to a file that doesn't exist.
    if not cv2.imwrite(img_path, result_com):
        raise IOError("could not write stylized image to %s" % img_path)
    return img_path
if __name__ == "__main__":
    # Script entry: parse the command line and pass the namespace to main().
    main(parser.parse_args())