forked from yusuketomoto/chainer-fast-neuralstyle
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·129 lines (106 loc) · 4.85 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
import os
import argparse
from PIL import Image
from chainer import cuda, Variable, optimizers, serializers
from net import *
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    y: chainer Variable of shape (batch, channels, d1, d2).
    Returns a Variable of shape (batch, channels, channels), each sample's
    channel-by-channel inner products divided by (channels * d1 * d2).

    The previous version reshaped to (ch, w*h), discarding the batch axis:
    correct only for batch size 1 (which the script forced). Using
    F.batch_matmul keeps one Gram matrix per sample; for batch size 1 the
    downstream mean_squared_error value is identical.
    """
    b, ch, w, h = y.data.shape
    # Flatten spatial dims so each row holds one channel's activations.
    features = F.reshape(y, (b, ch, w * h))
    # Per-sample outer product of channel vectors, normalized by feature size.
    gram = F.batch_matmul(features, features, transb=True) / np.float32(ch * w * h)
    return gram
def total_variation_regularization(x, beta=2):
    """Total-variation penalty on a batch of 3-channel images.

    Computes sum(dh^2) + sum(dw^2) over neighbouring-pixel differences
    (vertical and horizontal), raised to beta/2. Encourages spatially
    smooth generator output.
    """
    xp = cuda.get_array_module(x.data)
    # Per-channel finite-difference kernels: [1, -1] along height and width.
    kernel_h = Variable(xp.array(
        [[[[1], [-1]], [[1], [-1]], [[1], [-1]]]], dtype=xp.float32))
    kernel_w = Variable(xp.array(
        [[[[1, -1]], [[1, -1]], [[1, -1]]]], dtype=xp.float32))
    diff_h = F.convolution_2d(x, W=kernel_h)
    diff_w = F.convolution_2d(x, W=kernel_w)
    return (F.sum(diff_h ** 2) + F.sum(diff_w ** 2)) ** (beta / 2.)
# ---- Command-line interface and hyperparameters -----------------------------
parser = argparse.ArgumentParser(description='Real-time style transfer')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--dataset', '-d', default='dataset', type=str,
                    help='dataset directory path (according to the paper, use MSCOCO 80k images)')
parser.add_argument('--style_image', '-s', type=str, required=True,
                    help='style image path')
# parser.add_argument('--batchsize', '-b', type=int, default=4,
#                     help='batch size (default value is 4)')
parser.add_argument('--input', '-i', default=None, type=str,
                    help='input model file path without extension')
parser.add_argument('--output', '-o', default='out', type=str,
                    help='output model file path without extension')
# NOTE: 10e-4 is 1e-3; kept as-is to match the published default.
parser.add_argument('--lambda_tv', default=10e-4, type=float,
                    help='weight of total variation regularization according to the paper to be set between 10e-4 and 10e-6.')
parser.add_argument('--lambda_feat', default=5e0, type=float)   # content-loss weight
parser.add_argument('--lambda_style', default=1e2, type=float)  # style-loss weight
parser.add_argument('--epoch', '-e', default=2, type=int)       # passes over the dataset
parser.add_argument('--lr', '-l', default=1e-3, type=float)     # Adam learning rate
parser.add_argument('--checkpoint', '-c', default=0, type=int)  # save every N iters (0 = off)
args = parser.parse_args()

# batchsize = args.batchsize
batchsize = 1  # forced to 1: the loss pipeline cannot train against mini-batches yet
n_epoch = args.epoch
lambda_tv = args.lambda_tv
lambda_f = args.lambda_feat
lambda_s = args.lambda_style
fs = os.listdir(args.dataset)
imagepaths = []
for fn in fs:
base, ext = os.path.splitext(fn)
if ext == '.jpg' or ext == '.png':
imagepath = os.path.join(args.dataset,fn)
imagepaths.append(imagepath)
n_data = len(imagepaths)
print 'num traning images:', n_data
n_iter = n_data / batchsize
print n_iter, 'iterations,', n_epoch, 'epochs'
# ---- Networks, optimizer, and cached style targets --------------------------
model = FastStyleNet()  # image-transformation network being trained
vgg = VGG()             # fixed loss network (pretrained VGG-16 feature extractor)
serializers.load_npz('vgg16.model', vgg)  # expects vgg16.model in the cwd
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
    vgg.to_gpu()
# xp is numpy on CPU, cupy on GPU, so the rest of the script is device-agnostic.
xp = np if args.gpu < 0 else cuda.cupy
O = optimizers.Adam(alpha=args.lr)
O.setup(model)
# Load the style image at the training resolution; vgg.preprocess presumably
# applies VGG's mean subtraction / channel ordering — see net.py to confirm.
style = vgg.preprocess(np.asarray(Image.open(args.style_image).convert('RGB').resize((256,256)), dtype=np.float32))
style = xp.asarray(style, dtype=xp.float32)
# Tile the single style image across the batch dimension.
style_b = xp.zeros((batchsize,) + style.shape, dtype=xp.float32)
for i in range(batchsize):
    style_b[i] = style
# Style targets are constant, so extract VGG features once (volatile=True
# skips building a backprop graph) and cache their Gram matrices.
feature_s = vgg(Variable(style_b, volatile=True))
gram_s = [gram_matrix(y) for y in feature_s]
for epoch in range(n_epoch):
print 'epoch', epoch
for i in range(n_iter):
model.zerograds()
vgg.zerograds()
indices = range(i * batchsize, (i+1) * batchsize)
x = xp.zeros((batchsize, 3, 256, 256), dtype=xp.float32)
for j in range(batchsize):
x[j] = xp.asarray(Image.open(imagepaths[i*batchsize + j]).convert('RGB').resize((256,256)), dtype=np.float32).transpose(2, 0, 1)
x -= 120 # subtract mean
xc = Variable(x.copy(), volatile=True)
x = Variable(x)
y = model(x)
feature = vgg(xc)
feature_hat = vgg(y)
L_feat = lambda_f * F.mean_squared_error(Variable(feature[2].data), feature_hat[2]) # compute for only the output of layer conv3_3
L_style = Variable(xp.zeros((), dtype=np.float32))
for f, f_hat, g_s in zip(feature, feature_hat, gram_s):
L_style += lambda_s * F.mean_squared_error(gram_matrix(f_hat), Variable(g_s.data))
L_tv = lambda_tv * total_variation_regularization(y)
L = L_feat + L_style + L_tv
print '(epoch {}) batch {}/{}... training loss is...{}'.format(epoch, i, n_iter, L.data)
L.backward()
O.update()
if args.checkpoint > 0 and i % args.checkpoint == 0:
serializers.save_npz('models/style_{}_{}.model'.format(epoch, i), model)
print 'save "style.model"'
serializers.save_npz('models/style_{}.model'.format(epoch), model)
serializers.save_npz('models/style.model'.format(epoch), model)