/
train.py
183 lines (147 loc) · 11.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from __future__ import print_function
import argparse
from random import shuffle
import random
import os
import sys
import math
import tensorflow.compat.v1 as tf
import glob
import cv2
from image_reader import *
from net import *
parser = argparse.ArgumentParser(description='')
parser.add_argument("--snapshot_dir", default='./snapshots/', help="path of snapshots") # 保存模型的路径
parser.add_argument("--log_dir", default='./log_dir/', help="path of log") # 保存模型的路径
parser.add_argument("--out_dir", default='./train_out/', help="path of train outputs") # 训练时保存可视化输出的路径
parser.add_argument("--image_size", type=int, default=256, help="load image size") # 网络输入的尺度
parser.add_argument("--random_seed", type=int, default=1234, help="random seed") # 随机数种子
parser.add_argument('--base_lr', type=float, default=0.0002, help='initial learning rate for adam') # 学习率
parser.add_argument('--epoch', dest='epoch', type=int, default=150, help='# of epoch') # 训练的epoch数量
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') # adam优化器的beta1参数
parser.add_argument("--summary_pred_every", type=int, default=10,
help="times to summary.") # 训练中每过多少step保存训练日志(记录一下loss值)
parser.add_argument("--write_pred_every", type=int, default=40, help="times to write.") # 训练中每过多少step保存可视化结果
parser.add_argument("--save_pred_every", type=int, default=10499, help="times to save.") # 训练中每过多少step保存模型(可训练参数)
parser.add_argument("--lamda_l1_weight", type=float, default=100, help="L1 lamda") # 训练中L1_Loss前的乘数
parser.add_argument("--lamda_gan_weight", type=float, default=1, help="GAN lamda") # 训练中GAN_Loss前的乘数
parser.add_argument("--train_picture_format", default='.jpg',
help="format of training datas.") # 网络训练输入的图片的格式(图片在CGAN中被当做条件)
parser.add_argument("--train_label_format", default='.jpg',
help="format of training labels.") # 网络训练输入的标签的格式(标签在CGAN中被当做真样本)
parser.add_argument("--train_picture_path", default='./train_picture/',
help="path of training datas.") # 网络训练输入的图片路径
parser.add_argument("--train_label_path", default='./train_label/',
help="path of training labels.") # 网络训练输入的标签路径
args = parser.parse_args() # 用来解析命令行参数
EPS = 1e-12 # EPS用于保证log函数里面的参数大于零
def save(saver, sess, logdir, step): # 保存模型的save函数
model_name = 'model' # 保存的模型名前缀
checkpoint_path = os.path.join(logdir, model_name) # 模型的保存路径与名称
if not os.path.exists(logdir): # 如果路径不存在即创建
os.makedirs(logdir)
saver.save(sess, checkpoint_path, global_step=step) # 保存模型
print('The checkpoint has been created.')
def cv_inv_proc(img): # cv_inv_proc函数将读取图片时归一化的图片还原成原图
img_rgb = (img + 1.) * 127.5
return img_rgb.astype(np.float32) # 返回bgr格式的图像,方便cv2写图像
def get_write_picture(picture, gen_label, label, height, width): # get_write_picture函数得到训练过程中的可视化结果
picture_image = cv_inv_proc(picture) # 还原输入的图像
gen_label_image = cv_inv_proc(gen_label[0]) # 还原生成的样本
label_image = cv_inv_proc(label) # 还原真实的样本(标签)
inv_picture_image = cv2.resize(picture_image, (width, height)) # 还原图像的尺寸
inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) # 还原生成的样本的尺寸
inv_label_image = cv2.resize(label_image, (width, height)) # 还原真实的样本的尺寸
output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) # 把他们拼起来
return output
def l1_loss(src, dst): # 定义l1_loss
return tf.reduce_mean(tf.abs(src - dst))
def main(): # 训练程序的主函数
if not os.path.exists(args.snapshot_dir): # 如果保存模型参数的文件夹不存在则创建
os.makedirs(args.snapshot_dir)
if not os.path.exists(args.out_dir): # 如果保存训练中可视化输出的文件夹不存在则创建
os.makedirs(args.out_dir)
train_picture_list = glob.glob(os.path.join(args.train_picture_path, "*")) # 得到训练输入图像路径名称列表
tf.set_random_seed(args.random_seed) # 初始一下随机数
tf.disable_eager_execution()#tf2.0版本和1.x版本的区别
train_picture = tf.placeholder(tf.float32, shape=[1, args.image_size, args.image_size, 3],
name='train_picture') # 输入的训练图像
train_label = tf.placeholder(tf.float32, shape=[1, args.image_size, args.image_size, 3],
name='train_label') # 输入的与训练图像匹配的标签
gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator') # 得到生成器的输出
dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False,
name="discriminator") # 判别器返回的对真实标签的判别结果
dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True,
name="discriminator") # 判别器返回的对生成(虚假的)标签判别结果
#原损失函数
gen_loss_GAN = tf.reduce_mean(-tf.log(1-dis_fake + EPS)) # 计算生成器损失中的GAN_loss部分
gen_loss_L1 = tf.reduce_mean(l1_loss(gen_label, train_label)) # 计算生成器损失中的L1_loss部分
#测试用损失函数
#gen_loss_L2 = tf.reduce_mean(tf.square(tf.abs(gen_label - train_label + EPS))) # 计算生成器损失中的L2_loss部分
#gen_loss_ssim = tf.reduce_mean(1 - tf.image.ssim(gen_label, train_label, max_val=255)) # 计算生成器损失中的ssim_loss部分
# 原生成器损失函数
gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L1 * args.lamda_l1_weight # 计算生成器的loss
# 测试用生成器损失函数
#gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L2 * args.lamda_l1_weight
##最小二乘损失
#gen_loss = tf.reduce_mean(tf.square(dis_fake+ EPS)) * args.lamda_gan_weight + gen_loss_l1 * args.lamda_l1_weight
# 原判别器损失函数
dis_loss = tf.reduce_mean(-(tf.log(1-dis_real + EPS) + tf.log(dis_fake + EPS))) # 计算判别器的loss
# 测试用判别器损失函数
##最小二乘损失
# dis_loss = tf.reduce_mean(tf.square(dis_fake+ EPS))
gen_loss_sum = tf.summary.scalar("gen_loss", gen_loss) # 记录生成器loss的日志
dis_loss_sum = tf.summary.scalar("dis_loss", dis_loss) # 记录判别器loss的日志
summary_writer = tf.summary.FileWriter(args.log_dir, graph=tf.get_default_graph()) # 日志记录器
g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name] # 所有生成器的可训练参数
d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name] # 所有判别器的可训练参数
d_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) # 判别器训练器
g_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) # 生成器训练器
d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars) # 计算判别器参数梯度
d_train = d_optim.apply_gradients(d_grads_and_vars) # 更新判别器参数
g_grads_and_vars = g_optim.compute_gradients(gen_loss, var_list=g_vars) # 计算生成器参数梯度
g_train = g_optim.apply_gradients(g_grads_and_vars) # 更新生成器参数
train_op = tf.group(d_train, g_train) # train_op表示了参数更新操作
config = tf.ConfigProto()
config.gpu_options.allow_growth = True # 设定显存不超量使用
sess = tf.Session(config=config) # 新建会话层
#init = tf.compat.v1.global_variables_initializer() # 参数初始化器
init = tf.global_variables_initializer() # 参数初始化器
sess.run(init) # 初始化所有可训练参数
# saver = tf.compat.v1.train.Saver(var_list=tf.compat.v1.global_variables(), max_to_keep=50) # 模型保存器
saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) # 模型保存器
counter = 0 # counter记录训练步数
for epoch in range(args.epoch): # 训练epoch数
shuffle(train_picture_list) # 每训练一个epoch,就打乱一下输入的顺序
for step in range(len(train_picture_list)): # 每个训练epoch中的训练step数
counter += 1
picture_name, _ = os.path.splitext(os.path.basename(train_picture_list[step])) # 获取不包含路径和格式的输入图片名称
# 读取一张训练图片,一张训练标签,以及相应的高和宽
picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,
picture_path=args.train_picture_path,
label_path=args.train_label_path,
picture_format=args.train_picture_format,
label_format=args.train_label_format,
size=args.image_size)
batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis=0) # 填充维度
batch_label = np.expand_dims(np.array(label_resize).astype(np.float32), axis=0) # 填充维度
feed_dict = {train_picture: batch_picture, train_label: batch_label} # 构造feed_dict
gen_loss_value, dis_loss_value, _ = sess.run([gen_loss, dis_loss, train_op],
feed_dict=feed_dict) # 得到每个step中的生成器和判别器loss
if counter % args.save_pred_every == 0: # 每过save_pred_every次保存模型
save(saver, sess, args.snapshot_dir, counter)
if counter % args.summary_pred_every == 0: # 每过summary_pred_every次保存训练日志
gen_loss_sum_value, discriminator_sum_value = sess.run([gen_loss_sum, dis_loss_sum],
feed_dict=feed_dict)
summary_writer.add_summary(gen_loss_sum_value, counter)
summary_writer.add_summary(discriminator_sum_value, counter)
if counter % args.write_pred_every == 0: # 每过write_pred_every次写一下训练的可视化结果
gen_label_value = sess.run(gen_label, feed_dict=feed_dict) # run出生成器的输出
write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height,
picture_width) # 得到训练的可视化结果
write_image_name = args.out_dir +"out"+str(counter) + ".png" # 待保存的训练可视化结果路径与名称
cv2.imwrite(write_image_name, write_image) # 保存训练的可视化结果
print('epoch {:d} step {:d} \t gen_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, gen_loss_value,
dis_loss_value))
if __name__ == '__main__':
main()