def gen():
    """Transfer images from a directory."""
    content_images = reader.image(4,
                                  FLAGS.image_size,
                                  FLAGS.content,
                                  epochs=FLAGS.epoch,
                                  shuffle=False,
                                  crop=False)
    generated_images = model.net(content_images / 255.)

    output_format = tf.saturate_cast(generated_images + reader.mean_pixel,
                                     tf.uint8)
    with tf.Session() as sess:
        file_ = tf.train.latest_checkpoint(FLAGS.model)
        if not file_:
            print('Could not find trained model in {}'.format(FLAGS.model))
            return
        print('Using model from {}'.format(file_))
        saver = tf.train.Saver()
        saver.restore(sess, file_)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        i = 0
        start_time = time.time()
        try:
            while not coord.should_stop():
                print(i)
                images_t = sess.run(output_format)
                elapsed = time.time() - start_time
                start_time = time.time()
                print('Time for one batch: {}'.format(elapsed))

                for raw_image in images_t:
                    i += 1
                    out_path = 'output/' + FLAGS.output + '-{0:04d}.jpg'.format(i)
                    print('Save result in: ', out_path)
                    misc.imsave(out_path, raw_image)
        except tf.errors.OutOfRangeError:
            print('Done generating -- epoch limit reached!')
        except KeyboardInterrupt:
            print('Terminated by keyboard interrupt')
        finally:
            coord.request_stop()
            coord.join(threads)
def main(argv=None):
    if not FLAGS.CONTENT_IMAGES_PATH:
        print('Training a fast neural style model requires the content images '
              'path to be set (FLAGS.CONTENT_IMAGES_PATH).')
        return
    content_images = reader.image(FLAGS.BATCH_SIZE,
                                  FLAGS.IMAGE_SIZE,
                                  FLAGS.CONTENT_IMAGES_PATH,
                                  epochs=1,
                                  shuffle=False,
                                  crop=False)
    generated_images = model.net(content_images / 255.)

    output_format = tf.saturate_cast(generated_images + reader.mean_pixel,
                                     tf.uint8)
    with tf.Session() as sess:
        file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if not file:
            print('Could not find trained model in {0}'.format(FLAGS.MODEL_PATH))
            return
        print('Using model from {}'.format(file))
        saver = tf.train.Saver()
        saver.restore(sess, file)
        sess.run(tf.local_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        i = 0
        start_time = time.time()
        try:
            while not coord.should_stop():
                print(i)
                images_t = sess.run(output_format)
                elapsed = time.time() - start_time
                start_time = time.time()
                print('Time for one batch: {}'.format(elapsed))

                for raw_image in images_t:
                    i += 1
                    misc.imsave('out{0:04d}.png'.format(i), raw_image)
        except tf.errors.OutOfRangeError:
            print('Done generating -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
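# The snippets above and below all call reader.image, which is not shown in this
# section (and whose call signature varies between the snippets). The sketch
# below is an assumption about what it does, inferred from the call sites: a
# TF 1.x queue-based pipeline that decodes, resizes, and batches images, driven
# by the queue runners the callers start. Illustrative only, not the actual code.
import os
import tensorflow as tf

def image(batch_size, height, width, path, epochs=None, shuffle=True, crop=True):
    filenames = [os.path.join(path, f) for f in os.listdir(path)]
    # num_epochs creates a local variable, which is why the callers run
    # tf.local_variables_initializer() before starting the queue runners.
    filename_queue = tf.train.string_input_producer(filenames,
                                                    num_epochs=epochs,
                                                    shuffle=shuffle)
    _, value = tf.WholeFileReader().read(filename_queue)
    img = tf.image.decode_jpeg(value, channels=3)
    if crop:
        img = tf.image.resize_image_with_crop_or_pad(img, height, width)
    else:
        img = tf.image.resize_images(img, [height, width])
    img = tf.cast(img, tf.float32)
    img.set_shape([height, width, 3])
    # Batches are dequeued by the threads from tf.train.start_queue_runners.
    return tf.train.batch([img], batch_size)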
def get_res_features(TRAIN_IMAGES_PATH=TRAIN_IMAGES_PATH, VGG_PATH=VGG_PATH):
    images = reader.image(BATCH_SIZE, 224, IMAGE_SIZE, TRAIN_IMAGES_PATH)
    net, _ = vgg.net(VGG_PATH, images)
    # for f in net:
    #     print(f)
    vgg_features = net['relu5_4']
    # pdb.set_trace()
    return vgg_features


# Inline content-loss computation, kept here commented out for reference:
# for layer in content_layers:
#     generated_images, content_images = tf.split(value=net[layer],
#                                                 num_or_size_splits=2, axis=0)
#     size = tf.size(generated_images)
#     shape = tf.shape(generated_images)
#     width = shape[1]
#     height = shape[2]
#     num_filters = shape[3]
#     content_loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)
def main(argv=None):
    if FLAGS.CONTENT_IMAGES_PATH:
        content_images = reader.image(FLAGS.BATCH_SIZE,
                                      FLAGS.IMAGE_SIZE,
                                      FLAGS.CONTENT_IMAGES_PATH,
                                      epochs=1,
                                      shuffle=False,
                                      crop=False)
        generated_images = model.net(content_images / 255.)

        output_format = tf.saturate_cast(generated_images + reader.mean_pixel,
                                         tf.uint8)
        with tf.Session() as sess:
            file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
            if not file:
                print('Could not find trained model in {}'.format(FLAGS.MODEL_PATH))
                return
            print('Using model from {}'.format(file))
            saver = tf.train.Saver()
            saver.restore(sess, file)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            i = 0
            start_time = time.time()
            try:
                while not coord.should_stop():
                    print(i)
                    images_t = sess.run(output_format)
                    elapsed = time.time() - start_time
                    start_time = time.time()
                    print('Time for one batch: {}'.format(elapsed))

                    for raw_image in images_t:
                        i += 1
                        misc.imsave('out{0:04d}.png'.format(i), raw_image)
            except tf.errors.OutOfRangeError:
                print('Done generating -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)
        return

    if not os.path.exists(FLAGS.MODEL_PATH):
        os.makedirs(FLAGS.MODEL_PATH)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    style_features_t = get_style_features(style_paths, style_layers)

    images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE, FLAGS.TRAIN_IMAGES_PATH)
    generated = model.net(images / 255.)

    # Run generated and content images through VGG in a single batch.
    net, _ = vgg.net(FLAGS.VGG_PATH, tf.concat([generated, images], 0))

    content_loss = 0
    for layer in content_layers:
        generated_images, content_images = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        content_loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)
    content_loss = content_loss / len(content_layers)

    style_loss = 0
    for style_gram, layer in zip(style_features_t, style_layers):
        generated_images, _ = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        for style_image in style_gram:
            style_loss += tf.nn.l2_loss(
                tf.reduce_sum(gram(generated_images) - style_image, 0)) / tf.to_float(size)
    style_loss = style_loss / len(style_layers)

    loss = (FLAGS.STYLE_WEIGHT * style_loss +
            FLAGS.CONTENT_WEIGHT * content_loss +
            FLAGS.TV_WEIGHT * total_variation_loss(generated))

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

    output_format = tf.saturate_cast(
        tf.concat([generated, images], 0) + reader.mean_pixel, tf.uint8)

    with tf.Session() as sess:
        saver = tf.train.Saver(tf.global_variables())
        file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if file:
            print('Restoring model from {}'.format(file))
            saver.restore(sess, file)
        else:
            print('New model initialized')
            sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()
                if step % 100 == 0:
                    print(step, loss_t, elapsed_time)
                    output_t = sess.run(output_format)
                    for i, raw_image in enumerate(output_t):
                        misc.imsave('out{}.png'.format(i), raw_image)
                if step % 10000 == 0:
                    saver.save(sess, FLAGS.MODEL_PATH + '/fast-style-model',
                               global_step=step)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
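# `gram` is called above but never defined in this section. Below is a minimal
# sketch of a Gram-matrix helper, assuming `layer` is a [batch, height, width,
# channels] activation tensor; note that one snippet further down calls it as
# model.gram(features, batch_size), so the exact signature varies by repository.
import tensorflow as tf

def gram(layer):
    num_filters = tf.shape(layer)[3]
    size = tf.size(layer)
    # Flatten all batch and spatial positions, keeping one column per channel.
    feats = tf.reshape(layer, (-1, num_filters))
    # Channel-by-channel inner products, normalized by the feature-map size.
    return tf.matmul(feats, feats, transpose_a=True) / tf.to_float(size)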
def training(FLAGS):
    # FLAGS come from a *.yml file, read by utils.read_conf_file.
    # Ensure the training path exists.
    # In the *.yml: <model_path>/<naming>, e.g. models/f.yml
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Create the loss network.
            net_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                 num_classes=1,
                                                 is_training=False)
            img_prep_fn, img_unpr_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            ipt_imgs = reader.image(FLAGS.batch_size,
                                    FLAGS.image_size,
                                    FLAGS.image_size,
                                    _TRAIN_IMG_PATH,
                                    img_prep_fn,
                                    epochs=FLAGS.epoch)

            # Image transform network.
            imgs_net = model.img_trans_net(ipt_imgs, training=True)
            opt_imgs = [
                img_prep_fn(img, FLAGS.image_size, FLAGS.image_size)
                for img in tf.unstack(imgs_net, axis=0, num=FLAGS.batch_size)
            ]
            _, endpoints_d = net_fn(tf.concat([opt_imgs, ipt_imgs], 0),
                                    spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info('Loss network layers (content_layers and style_layers):')
            for key in endpoints_d:
                tf.logging.info(key)

            # Loss network.
            L_content = loss.content_loss(endpoints_d, FLAGS.content_layers)
            style_features = loss.get_style_feature(FLAGS)
            L_style, L_style_sum = loss.style_loss(endpoints_d, style_features,
                                                   FLAGS.style_layers)
            L_tv = loss.total_loss(opt_imgs)
            # Weights are defined in conf/*.yml; in mosaic.yml, the
            # (content, style, total variation) weights are (1, 100, 0).
            l = (FLAGS.style_weight * L_style +
                 FLAGS.content_weight * L_content +
                 FLAGS.tv_weight * L_tv)

            # Add summaries for visualization in TensorBoard.
            tf.summary.scalar('loss/content_loss', L_content)
            tf.summary.scalar('loss/style_loss', L_style)
            tf.summary.scalar('loss/regularizer_loss', L_tv)
            tf.summary.scalar('weighted_loss/weighted_content_loss',
                              L_content * FLAGS.content_weight)
            tf.summary.scalar('weighted_loss/weighted_style_loss',
                              L_style * FLAGS.style_weight)
            tf.summary.scalar('weighted_loss/weighted_regularizer_loss',
                              L_tv * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', l)
            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_loss/' + layer, L_style_sum[layer])
            tf.summary.image('generated img transform net', imgs_net)
            # tf.summary.image('processed_generated', processed_generated)  # May be better?
            tf.summary.image(
                'origin',
                tf.stack([
                    img_unpr_fn(img)
                    for img in tf.unstack(ipt_imgs, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # Prepare for training.
            global_step = tf.Variable(0, name='global_step', trainable=False)
            var2train = []
            for var in tf.trainable_variables():
                if not var.name.startswith(FLAGS.loss_model):
                    var2train.append(var)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                l, global_step=global_step, var_list=var2train)

            var2restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    var2restore.append(v)
            saver = tf.train.Saver(var2restore,
                                   write_version=tf.train.SaverDef.V2)
            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            # Restore variables for the loss network.
            init_func = get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for the training model if a checkpoint exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # Start training.
            coord = tf.train.Coordinator()  # manages the reader threads
            threads = tf.train.start_queue_runners(coord=coord)
            start = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, l, global_step])
                    elapsed = time.time() - start
                    start = time.time()
                    # Logging.
                    if step % _LOG_EPS == 0:
                        tf.logging.info('step: %d, total loss %f, secs/step: %f'
                                        % (step, loss_t, elapsed))
                    # Summary.
                    if step % _SUMMARY_EPS == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    # Checkpoint.
                    if step % _SAVE_EPS == 0:
                        saver.save(sess,
                                   os.path.join(training_path, 'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess,
                           os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)
def main(FLAGS):
    # Compute the style features (Gram matrices) of the style target.
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            # Build the VGG network; the name in FLAGS.loss_model maps to an
            # entry in networks_map in /nets/nets_factory.py.
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                     num_classes=1,
                                                     is_training=False)
            # Each network has its own preprocessing.
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            # Read and preprocess one batch of data. The data does not have to
            # be COCO (which is quite large); any directory containing many
            # images works.
            processed_images = reader.image(FLAGS.batch_size,
                                            FLAGS.image_height,
                                            FLAGS.image_width,
                                            'F:/CASIA/train_frame/real/',
                                            image_preprocessing_fn,
                                            epochs=FLAGS.epoch)
            # Generate images with the transform network; this is y^.
            generated = model.net(processed_images, training=True)
            # The generated images will be fed to VGG for both losses, so they
            # must be preprocessed first.
            processed_generated = [
                image_preprocessing_fn(image, FLAGS.image_height, FLAGS.image_width)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ]
            # The comprehension yields a list, so stack it back into a tensor.
            processed_generated = tf.stack(processed_generated)
            # Concatenate along the batch dimension: two [batch_size, h, w, c]
            # tensors become one [2 * batch_size, h, w, c] tensor, so a single
            # forward pass computes the features of both y^ and y_c.
            _, endpoints_dict = network_fn(tf.concat(
                [processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info(
                'Loss network layers(You can define them in "content_layers" and "style_layers"):'
            )
            for key in endpoints_dict:
                tf.logging.info(key)

            """Build Losses"""
            # Compute the three losses.
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image
            loss = (FLAGS.style_weight * style_loss +
                    FLAGS.content_weight * content_loss +
                    FLAGS.tv_weight * tv_loss)

            # Add summaries for visualization in TensorBoard (optional).
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)
            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)
            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.summary.image('processed_generated', processed_generated)  # May be better?
            tf.summary.image(
                'origin',
                tf.stack([
                    image_unprocessing_fn(image) for image in tf.unstack(
                        processed_images, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            """Prepare to Train"""
            # Step counter.
            global_step = tf.Variable(0, name="global_step", trainable=False)

            variable_to_train = []
            for variable in tf.trainable_variables():
                # Collect every trainable variable that is not part of the VGG
                # loss network.
                if not variable.name.startswith(FLAGS.loss_model):
                    variable_to_train.append(variable)
            # Note the var_list argument.
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                # Collect every savable variable that is not part of VGG.
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            # Note variables_to_restore.
            saver = tf.train.Saver(variables_to_restore,
                                   write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            # Restore variables for the loss network. This slim helper loads
            # the network weights configured in FLAGS into the session.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for the training model if a checkpoint exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    print(step)
                    if step % 10 == 0:
                        tf.logging.info('step: %d, total Loss %f, secs/step: %f'
                                        % (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path, 'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess,
                           os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)
def optimize():
    MODEL_DIR_NAME = os.path.dirname(FLAGS.MODEL_PATH)
    if not os.path.exists(MODEL_DIR_NAME):
        os.mkdir(MODEL_DIR_NAME)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    # Style Gram matrices.
    style_features_t = loss.get_style_features(style_paths, style_layers,
                                               FLAGS.IMAGE_SIZE,
                                               FLAGS.STYLE_SCALE,
                                               FLAGS.VGG_PATH)

    with tf.Graph().as_default(), tf.Session() as sess:
        # Training images.
        images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE,
                              FLAGS.TRAIN_IMAGES_FOLDER, FLAGS.EPOCHS)
        generated = transform.net(images - vgg.MEAN_PIXEL, training=True)
        net, _ = vgg.net(FLAGS.VGG_PATH,
                         tf.concat([generated, images], 0) - vgg.MEAN_PIXEL)

        # Loss functions.
        content_loss = loss.content_loss(net, content_layers)
        style_loss = loss.style_loss(net, style_features_t,
                                     style_layers) / len(style_paths)
        total_loss = (FLAGS.STYLE_WEIGHT * style_loss +
                      FLAGS.CONTENT_WEIGHT * content_loss +
                      FLAGS.TV_WEIGHT * loss.total_variation_loss(generated))

        # Prepare for training.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        variable_to_train = []
        for variable in tf.trainable_variables():
            if not variable.name.startswith('vgg19'):
                variable_to_train.append(variable)
        train_op = tf.train.AdamOptimizer(FLAGS.LEARNING_RATE).minimize(
            total_loss, global_step=global_step, var_list=variable_to_train)

        variables_to_restore = []
        for v in tf.global_variables():
            if not v.name.startswith('vgg19'):
                variables_to_restore.append(v)

        # Start training.
        saver = tf.train.Saver(variables_to_restore,
                               write_version=tf.train.SaverDef.V1)
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])

        # Load a checkpoint if one exists.
        ckpt = tf.train.latest_checkpoint(MODEL_DIR_NAME)
        if ckpt:
            tf.logging.info('Restoring model from {}'.format(ckpt))
            saver.restore(sess, ckpt)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, total_loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()
                if step % 10 == 0:
                    tf.logging.info('step: %d, total loss %f, secs/step: %f'
                                    % (step, loss_t, elapsed_time))
                if step % 10000 == 0:
                    saver.save(sess, FLAGS.MODEL_PATH, global_step=step)
                    tf.logging.info('Saved model')
        except tf.errors.OutOfRangeError:
            saver.save(sess, FLAGS.MODEL_PATH + '-done')
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
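# total_variation_loss / loss.total_variation_loss is referenced throughout this
# section but never defined here. A minimal sketch, assuming `layer` is a
# [batch, height, width, channels] tensor; TF 1.x also ships
# tf.image.total_variation as a ready-made alternative.
import tensorflow as tf

def total_variation_loss(layer):
    shape = tf.shape(layer)
    height, width = shape[1], shape[2]
    # Differences between vertically and horizontally adjacent pixels.
    y = (tf.slice(layer, [0, 0, 0, 0], [-1, height - 1, -1, -1]) -
         tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1]))
    x = (tf.slice(layer, [0, 0, 0, 0], [-1, -1, width - 1, -1]) -
         tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1]))
    return (tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) +
            tf.nn.l2_loss(y) / tf.to_float(tf.size(y)))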
def main(argv=None):
    content_layers = FLAGS.content_layers.split(',')
    style_layers = FLAGS.style_layers.split(',')
    style_layers_weights = [float(i) for i in FLAGS.style_layers_weights.split(',')]
    # num_steps_decay = 82786 / FLAGS.batch_size
    num_steps_decay = 10000

    style_features_t = losses.get_style_features(FLAGS)
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Session() as sess:
        """Build Network"""
        network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                 num_classes=1,
                                                 is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model, is_training=False)
        processed_images = reader.image(FLAGS.batch_size,
                                        FLAGS.image_size,
                                        FLAGS.image_size,
                                        'train2014/',
                                        image_preprocessing_fn,
                                        epochs=FLAGS.epoch)
        generated = model.net(processed_images, FLAGS.alpha)
        processed_generated = [
            image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
            for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
        ]
        processed_generated = tf.stack(processed_generated)
        _, endpoints_dict = network_fn(tf.concat(
            [processed_generated, processed_images], 0),
                                       spatial_squeeze=False)

        """Build Losses"""
        content_loss = losses.content_loss(endpoints_dict, content_layers)
        style_loss, style_losses = losses.style_loss(endpoints_dict,
                                                     style_features_t,
                                                     style_layers,
                                                     style_layers_weights)
        tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

        content_loss = FLAGS.content_weight * content_loss
        style_loss = FLAGS.style_weight * style_loss
        tv_loss = FLAGS.tv_weight * tv_loss
        loss = style_loss + content_loss + tv_loss

        """Prepare to Train"""
        global_step = tf.Variable(0, name="global_step", trainable=False)
        variable_to_train = []
        for variable in tf.trainable_variables():
            if not variable.name.startswith(FLAGS.loss_model):
                variable_to_train.append(variable)

        lr = tf.train.exponential_decay(learning_rate=1e-1,
                                        global_step=global_step,
                                        decay_steps=num_steps_decay,
                                        decay_rate=1e-1,
                                        staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-8)
        train_op = optimizer.minimize(loss,
                                      global_step=global_step,
                                      var_list=variable_to_train)
        # train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

        variables_to_restore = []
        for v in tf.global_variables():
            if not v.name.startswith(FLAGS.loss_model):
                variables_to_restore.append(v)
        saver = tf.train.Saver(variables_to_restore)

        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        init_func = utils._get_init_fn(FLAGS)
        init_func(sess)

        last_file = tf.train.latest_checkpoint(training_path)
        if last_file:
            saver.restore(sess, last_file)

        """Start Training"""
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                _, c_loss, s_losses, t_loss, total_loss, step = sess.run([
                    train_op, content_loss, style_losses, tv_loss, loss, global_step
                ])
                """logging"""
                if step % 10 == 0:
                    print(step, c_loss, s_losses, t_loss, total_loss)
                """checkpoint"""
                if step % 10000 == 0:
                    saver.save(sess,
                               os.path.join(training_path, 'fast-style-model'),
                               global_step=step)
                if step == FLAGS.max_iter:
                    saver.save(sess,
                               os.path.join(training_path, 'fast-style-model-done'))
                    break
        except tf.errors.OutOfRangeError:
            saver.save(sess, os.path.join(training_path, 'fast-style-model-done'))
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
def evaluate():
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # loss_input_style = tf.placeholder(dtype=tf.float32, shape=[args.batch, args.size, args.size, args.in_dim])
            # loss_input_target = tf.placeholder(dtype=tf.float32, shape=[args.batch, args.size, args.size, args.in_dim])

            # For the online optimization problem, use the testing preprocess
            # for both train and test.
            preprocess_func, unprocess_func = preprocessing.preprocessing_factory.get_preprocessing(
                args.loss_model, is_training=False)
            images = reader.image(1, args.size, args.size, args.test_dir,
                                  preprocess_func, 1, shuffle=False)

            model = transform(sess, args)
            transformed_images = model.generator(images, reuse=False)
            unprocess_transform = [
                unprocess_func(img)
                for img in tf.unstack(transformed_images, axis=0, num=args.batch)
            ]

            all_vars = tf.global_variables()
            to_restore = [var for var in all_vars
                          if args.loss_model not in var.name]

            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            saver = tf.train.Saver(to_restore)
            style_name = (args.style_dir.split('/')[-1]).split('.')[0]
            ckpt = tf.train.latest_checkpoint(os.path.join(args.ckpt_dir, style_name))
            if ckpt:
                tf.logging.info('Restoring model from {}'.format(ckpt))
                saver.restore(sess, ckpt)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time()
            i = 0
            try:
                while True:
                    images = sess.run(unprocess_transform)
                    for img in images:
                        path = os.path.join(args.save_dir, str(i) + '.jpg')
                        scipy.misc.imsave(path, img)
                        i = i + 1
            except tf.errors.OutOfRangeError:
                print('eval finished')
            finally:
                coord.request_stop()
                coord.join(threads)
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)  # Gram matrices of the style target

    # Make sure the training path exists; e.g. model/wave/ holds the trained model.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():  # the default computation graph
        # Note: the session is created without as_default(), so it closes when
        # the with block exits and cannot be reused afterwards.
        with tf.Session() as sess:
            """Build loss network"""
            # Fetch the loss model; it is not trained itself.
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                     num_classes=1,
                                                     is_training=False)
            # Fetch the functions that preprocess/unpreprocess the images
            # (content and generated) entering the loss model.
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            # Preprocess a batch of training images.
            processed_images = reader.image(FLAGS.batch_size,
                                            FLAGS.image_size,
                                            FLAGS.image_size,
                                            'train2014/',
                                            image_preprocessing_fn,
                                            epochs=FLAGS.epoch)
            # The image transform network (the one being trained) takes the
            # preprocessed images as input.
            generated = model.net(processed_images, training=True)
            processed_generated = [
                image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ]
            processed_generated = tf.stack(processed_generated)
            # Run the generated and content images through the loss model in one
            # pass; endpoints_dict holds each layer's output for both image sets.
            _, endpoints_dict = network_fn(tf.concat(
                [processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info('loss network layers (you can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)  # print each layer name of the loss model

            """Build losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict,
                                                               style_features_t,
                                                               FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)
            loss = (FLAGS.style_weight * style_loss +
                    FLAGS.content_weight * content_loss +
                    FLAGS.tv_weight * tv_loss)

            # Add summaries for visualization in TensorBoard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)
            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)
            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image(
                'origin',
                tf.stack([
                    image_unprocessing_fn(image)
                    for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            """Prepare to train"""
            global_step = tf.Variable(0, name='global_step', trainable=False)  # iteration step
            variable_to_train = []  # the variables that need training
            # Among all parameters of the style transfer setup (transform
            # network plus loss network), pick the ones to train.
            for variable in tf.trainable_variables():
                if not variable.name.startswith(FLAGS.loss_model):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)  # run via sess.run()

            # Among all global variables, pick the ones to save and restore.
            # Local variables are temporaries that live in the reader threads
            # and vanish when the threads finish.
            variables_to_restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            # variables_to_restore is the var_list that saver.save()/saver.restore() operate on.
            saver = tf.train.Saver(variables_to_restore,
                                   write_version=tf.train.SaverDef.V1)

            # Initialize global and local variables (i.e. reset to defaults).
            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            # Restore the parameters of the loss model.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore the training model's parameters if a checkpoint exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)
            # If last_file does not exist, training simply starts from the
            # freshly initialized transform network.

            """Start training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():  # stops once all epochs are consumed
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    # print(step)
                    if step % 10 == 0:
                        tf.logging.info('step: %d, total loss %f, secs/step: %f'
                                        % (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path, 'fast-style-model.ckpt'),
                                   global_step=step)  # saves the variables in variables_to_restore
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()   # ask all threads to stop
                coord.join(threads)    # join them back into the main thread
def train():
    style_feature, style_grams = get_style_feature()

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # loss_input_style = tf.placeholder(dtype=tf.float32, shape=[args.batch, args.size, args.size, args.in_dim])
            # loss_input_target = tf.placeholder(dtype=tf.float32, shape=[args.batch, args.size, args.size, args.in_dim])

            # For the online optimization problem, use the testing preprocess
            # for both train and test.
            preprocess_func, unprocess_func = preprocessing.preprocessing_factory.get_preprocessing(
                args.loss_model, is_training=False)
            images = reader.image(args.batch, args.size, args.size,
                                  args.target_dir, preprocess_func,
                                  args.epoch, shuffle=True)

            model = transform(sess, args)
            transformed_images = model.generator(images, reuse=False)

            unprocess_transform = [
                img for img in tf.unstack(transformed_images, axis=0, num=args.batch)
            ]
            processed_generated = [
                preprocess_func(img, args.size, args.size)
                for img in unprocess_transform
            ]
            processed_generated = tf.stack(processed_generated)

            loss_model = nets.nets_factory.get_network_fn(args.loss_model,
                                                          num_classes=1,
                                                          is_training=False)
            # Generated and target images share one forward pass.
            pair = tf.concat([processed_generated, images], axis=0)
            _, end_dicts = loss_model(pair, spatial_squeeze=False)

            init_loss_model = load_pretrained_weight(args.loss_model)

            c_loss = losses.content_loss(
                end_dicts, loss_config.content_loss_dict[args.loss_model])
            s_loss, s_loss_sum = losses.style_loss(
                end_dicts, loss_config.style_loss_dict[args.loss_model], style_grams)
            tv_loss = losses.total_variation_loss(transformed_images)
            loss = (args.c_weight * c_loss + args.s_weight * s_loss +
                    args.tv_weight * tv_loss)

            # tf.summary.scalar('average', tf.reduce_mean(images))
            # tf.summary.scalar('gram average', tf.reduce_mean(tf.stack(style_feature)))
            tf.summary.scalar('losses/content_loss', c_loss)
            tf.summary.scalar('losses/style_loss', s_loss)
            tf.summary.scalar('losses/tv_loss', tv_loss)
            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              c_loss * args.c_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              s_loss * args.s_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * args.tv_weight)
            tf.summary.scalar('total_loss', loss)
            for layer in loss_config.style_loss_dict[args.loss_model]:
                tf.summary.scalar('style_losses/' + layer, s_loss_sum[layer])
            tf.summary.image('transformed', tf.stack(unprocess_transform, axis=0))
            # tf.summary.image('processed_generated', processed_generated)  # May be better?
            tf.summary.image(
                'ori',
                tf.stack([
                    unprocess_func(image)
                    for image in tf.unstack(images, axis=0, num=args.batch)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(args.log_dir)

            step = tf.Variable(0, name='global_step', trainable=False)
            all_trainables = tf.trainable_variables()
            all_vars = tf.global_variables()
            to_train = [var for var in all_trainables
                        if args.loss_model not in var.name]
            to_restore = [var for var in all_vars
                          if args.loss_model not in var.name]
            optim = tf.train.AdamOptimizer(1e-3).minimize(
                loss=loss, var_list=to_train, global_step=step)

            saver = tf.train.Saver(to_restore)
            style_name = (args.style_dir.split('/')[-1]).split('.')[0]

            # Initialize variables *before* restoring, so a restored checkpoint
            # is not overwritten by the initializers.
            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            ckpt = tf.train.latest_checkpoint(os.path.join(args.ckpt_dir, style_name))
            if ckpt:
                tf.logging.info('Restoring model from {}'.format(ckpt))
                saver.restore(sess, ckpt)

            # sess.run(init_loss_model)
            init_loss_model(sess)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time()
            try:
                while True:
                    _, gs, sum_info, c_info, s_info, tv_info, loss_info = sess.run(
                        [optim, step, summary, c_loss, s_loss, tv_loss, loss])
                    writer.add_summary(sum_info, gs)
                    elapsed = time() - start_time
                    print(gs)
                    if gs % 10 == 0:
                        tf.logging.info(
                            'step: %d, c_loss %f s_loss %f tv_loss %f total loss %f, secs/step: %f'
                            % (gs, c_info, s_info, tv_info, loss_info, elapsed))
                    if gs % args.save_freq == 0:
                        saver.save(sess,
                                   os.path.join(args.ckpt_dir, style_name,
                                                style_name + '.ckpt'))
            except tf.errors.OutOfRangeError:
                print('Ran out of images! Saving final model: ' +
                      os.path.join(args.ckpt_dir, style_name + '.ckpt-done'))
                saver.save(sess,
                           os.path.join(args.ckpt_dir, style_name,
                                        style_name + '.ckpt-done'))
                tf.logging.info('Done -- input queue ran out of range')
            finally:
                coord.request_stop()
                coord.join(threads)
            print('end training')
def main(argv=None):
    run_id = FLAGS.NAME if FLAGS.NAME else str(uuid.uuid4())
    model_path = '%s/%s' % (FLAGS.MODEL_PATH, run_id)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    summary_path = '%s/%s' % (FLAGS.SUMMARY_PATH, run_id)
    if not os.path.exists(summary_path):
        os.makedirs(summary_path)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    style_features_t = get_style_features(style_paths, style_layers)

    images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE, FLAGS.TRAIN_IMAGES_PATH)
    generated = model.net(images - reader.mean_pixel, training=True)
    # Put both generated and training images in the same batch through the VGG
    # net for efficiency.
    net, _ = vgg.net(FLAGS.VGG_PATH,
                     tf.concat([generated, images], 0) - reader.mean_pixel)

    content_loss = 0
    for layer in content_layers:
        generated_images, content_images = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        content_loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)

    style_loss = 0
    for style_grams, layer in zip(style_features_t, style_layers):
        generated_images, _ = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        for style_gram in style_grams:
            style_loss += tf.nn.l2_loss(gram(generated_images) - style_gram) / tf.to_float(size)
    style_loss = style_loss / len(style_paths)

    tv_loss = total_variation_loss(generated)
    loss = (FLAGS.STYLE_WEIGHT * style_loss +
            FLAGS.CONTENT_WEIGHT * content_loss +
            FLAGS.TV_WEIGHT * tv_loss)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

    # Statistics
    with tf.name_scope('losses'):
        tf.summary.scalar('content loss', content_loss)
        tf.summary.scalar('style loss', style_loss)
        tf.summary.scalar('regularizer loss', tv_loss)
    with tf.name_scope('weighted_losses'):
        tf.summary.scalar('weighted content loss', content_loss * FLAGS.CONTENT_WEIGHT)
        tf.summary.scalar('weighted style loss', style_loss * FLAGS.STYLE_WEIGHT)
        tf.summary.scalar('weighted regularizer loss', tv_loss * FLAGS.TV_WEIGHT)
    tf.summary.scalar('total loss', loss)
    tf.summary.image('original', images)
    tf.summary.image('generated', generated)
    summary = tf.summary.merge_all()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(summary_path, sess.graph)

        saver = tf.train.Saver(tf.global_variables())
        file = tf.train.latest_checkpoint(model_path)
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        if file:
            print('Restoring model from {}'.format(file))
            saver.restore(sess, file)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()
                if step % 100 == 0:
                    print(step, loss_t, elapsed_time)
                    summary_str = sess.run(summary)
                    writer.add_summary(summary_str, step)
                if step % 10000 == 0:
                    saver.save(sess, model_path + '/fast-style-model', global_step=step)
        except tf.errors.OutOfRangeError:
            saver.save(sess, model_path + '/fast-style-model-done')
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                     num_classes=1,
                                                     is_training=False)
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            processed_images = reader.image(FLAGS.batch_size,
                                            FLAGS.image_size,
                                            FLAGS.image_size,
                                            'train2014/',
                                            image_preprocessing_fn,
                                            epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [
                image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat(
                [processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            """Build Losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict,
                                                               style_features_t,
                                                               FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image
            loss = (FLAGS.style_weight * style_loss +
                    FLAGS.content_weight * content_loss +
                    FLAGS.tv_weight * tv_loss)

            # Add summaries for visualization in TensorBoard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)
            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)
            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.summary.image('processed_generated', processed_generated)  # May be better?
            tf.summary.image(
                'origin',
                tf.stack([
                    image_unprocessing_fn(image)
                    for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            """Prepare to Train"""
            global_step = tf.Variable(0, name="global_step", trainable=False)
            variable_to_train = []
            for variable in tf.trainable_variables():
                if not variable.name.startswith(FLAGS.loss_model):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore,
                                   write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            # Restore variables for the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for the training model if a checkpoint exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    # print(step)
                    if step % 10 == 0:
                        tf.logging.info('step: %d, total Loss %f, secs/step: %f'
                                        % (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path, 'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess,
                           os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)
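# losses.content_loss above is assumed to use the same split-the-batch trick as
# the inline loss code elsewhere in this section: generated and content
# activations share one forward pass through the loss network and are separated
# with tf.split. A hedged sketch, not the module's actual code:
import tensorflow as tf

def content_loss(endpoints_dict, content_layers):
    loss = 0
    for layer in content_layers:
        # First half of the batch: generated images; second half: content images.
        generated_images, content_images = tf.split(endpoints_dict[layer], 2, axis=0)
        size = tf.size(generated_images)
        loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)
    return loss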
def main(argv=None):
    if FLAGS.CONTENT_IMAGES_PATH:
        content_images = reader.image(FLAGS.BATCH_SIZE,
                                      FLAGS.IMAGE_SIZE,
                                      FLAGS.CONTENT_IMAGES_PATH,
                                      epochs=1,
                                      shuffle=False,
                                      crop=False)
        generated_images = model.net(content_images / 255.)

        output_format = tf.saturate_cast(generated_images + reader.mean_pixel,
                                         tf.uint8)
        with tf.Session() as sess:
            file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
            if not file:
                print('Could not find trained model in {}'.format(FLAGS.MODEL_PATH))
                return
            print('Using model from {}'.format(file))
            saver = tf.train.Saver()
            saver.restore(sess, file)
            sess.run(tf.local_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            i = 0
            start_time = time.time()
            try:
                while not coord.should_stop():
                    print(i)
                    images_t = sess.run(output_format)
                    elapsed = time.time() - start_time
                    start_time = time.time()
                    print('Time for one batch: {}'.format(elapsed))

                    for raw_image in images_t:
                        i += 1
                        misc.imsave('out{0:04d}.png'.format(i), raw_image)
            except tf.errors.OutOfRangeError:
                print('Done generating -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)
        return

    if not os.path.exists(FLAGS.MODEL_PATH):
        os.makedirs(FLAGS.MODEL_PATH)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    style_features_t = get_style_features(style_paths, style_layers)

    images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE, FLAGS.TRAIN_IMAGES_PATH)
    generated = model.net(images / 255.)

    net, _ = vgg.net(FLAGS.VGG_PATH, tf.concat([generated, images], 0))

    content_loss = 0
    for layer in content_layers:
        generated_images, content_images = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        content_loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)
    content_loss = content_loss / len(content_layers)

    style_loss = 0
    for style_gram, layer in zip(style_features_t, style_layers):
        generated_images, _ = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        for style_image in style_gram:
            style_loss += tf.nn.l2_loss(
                tf.reduce_sum(gram(generated_images) - style_image, 0)) / tf.to_float(size)
    style_loss = style_loss / len(style_layers)

    loss = (FLAGS.STYLE_WEIGHT * style_loss +
            FLAGS.CONTENT_WEIGHT * content_loss +
            FLAGS.TV_WEIGHT * total_variation_loss(generated))

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

    output_format = tf.saturate_cast(
        tf.concat([generated, images], 0) + reader.mean_pixel, tf.uint8)

    with tf.Session() as sess:
        saver = tf.train.Saver(tf.global_variables())
        file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if file:
            print('Restoring model from {}'.format(file))
            saver.restore(sess, file)
            sess.run(tf.local_variables_initializer())
        else:
            print('New model initialized')
            sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()
                print(step, loss_t, elapsed_time)
                if step % 100 == 0:
                    output_t = sess.run(output_format)
                    for i, raw_image in enumerate(output_t):
                        misc.imsave('out{}.png'.format(i), raw_image)
                    print('Saved image.')
                if step % 1000 == 0:
                    saver.save(sess, FLAGS.MODEL_PATH + '/fast-style-model',
                               global_step=step)
                    print('Saved model.')
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
def main(argv=None):
    run_id = FLAGS.NAME if FLAGS.NAME else str(uuid.uuid4())
    model_path = '%s/%s' % (FLAGS.MODEL_PATH, run_id)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    summary_path = '%s/%s' % (FLAGS.SUMMARY_PATH, run_id)
    if not os.path.exists(summary_path):
        os.makedirs(summary_path)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    style_features_t = get_style_features(style_paths, style_layers)

    images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE, FLAGS.TRAIN_IMAGES_PATH)
    generated = model.net(images - reader.mean_pixel, training=True)
    # Put both generated and training images in the same batch through the VGG
    # net for efficiency.
    net, _ = vgg.net(FLAGS.VGG_PATH,
                     tf.concat([generated, images], 0) - reader.mean_pixel)

    content_loss = 0
    for layer in content_layers:
        generated_images, content_images = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        content_loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)

    style_loss = 0
    for style_grams, layer in zip(style_features_t, style_layers):
        generated_images, _ = tf.split(net[layer], 2, axis=0)
        size = tf.size(generated_images)
        for style_gram in style_grams:
            style_loss += tf.nn.l2_loss(gram(generated_images) - style_gram) / tf.to_float(size)
    style_loss = style_loss / len(style_paths)

    tv_loss = total_variation_loss(generated)
    loss = (FLAGS.STYLE_WEIGHT * style_loss +
            FLAGS.CONTENT_WEIGHT * content_loss +
            FLAGS.TV_WEIGHT * tv_loss)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

    # Statistics
    with tf.name_scope('losses'):
        tf.summary.scalar('content loss', content_loss)
        tf.summary.scalar('style loss', style_loss)
        tf.summary.scalar('regularizer loss', tv_loss)
    with tf.name_scope('weighted_losses'):
        tf.summary.scalar('weighted content loss', content_loss * FLAGS.CONTENT_WEIGHT)
        tf.summary.scalar('weighted style loss', style_loss * FLAGS.STYLE_WEIGHT)
        tf.summary.scalar('weighted regularizer loss', tv_loss * FLAGS.TV_WEIGHT)
    tf.summary.scalar('total loss', loss)
    tf.summary.image('original', images)
    tf.summary.image('generated', generated)
    summary = tf.summary.merge_all()

    step = 0
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(summary_path, sess.graph)

        saver = tf.train.Saver(tf.global_variables())
        file = tf.train.latest_checkpoint(model_path)
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        if file:
            print('Restoring model from {}'.format(file))
            saver.restore(sess, file)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            # Stop after a fixed number of steps instead of an epoch limit.
            while step <= 1000:
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()
                if step % 10 == 0:
                    print(step, loss_t, elapsed_time)
                    summary_str = sess.run(summary)
                    writer.add_summary(summary_str, step)
                    saver.save(sess, model_path + '/fast-style-model', global_step=step)
        except tf.errors.OutOfRangeError:
            saver.save(sess, model_path + '/fast-style-model-done')
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
def perceptual_loss(net_type):
    """Compute the perceptual (content and style) losses.

    Returns:
        generated: output of the feed-forward transform network.
        images: the input images (batched).
        the individual losses and the weighted total loss.
    """
    # Set the style images.
    style_paths = FLAGS.style_images.split(',')
    # Set the style layers and content layers in the VGG net.
    style_layers = FLAGS.style_layers.split(',')
    content_layers = FLAGS.content_layers.split(',')

    # Get the style features, pre-calculated once and kept in memory.
    style_features_t = get_style_features(style_paths, style_layers, net_type)

    # Read images from the dataset.
    images = reader.image(FLAGS.batch_size,
                          FLAGS.image_size,
                          FLAGS.train_images_path,
                          epochs=FLAGS.epoch)

    # Transfer images. The division by 255 matches the 0-1 scaling that
    # model.net expects internally; inelegant, but easy to forget, so it is
    # spelled out here.
    generated = model.net(images / 255)
    # generated = model.net(tf.truncated_normal(images.get_shape(), stddev=0.3))

    # Process generated and original images with the VGG net.
    net, _ = vgg.net(FLAGS.vgg_path, tf.concat([generated, images], 0), net_type)

    # Get the content loss.
    content_loss = 0
    for layer in content_layers:
        # Split evenly into two groups, each a batch-length stack of images.
        gen_features, images_features = tf.split(net[layer],
                                                 num_or_size_splits=2,
                                                 axis=0)
        size = tf.size(gen_features)
        content_loss += tf.nn.l2_loss(gen_features - images_features) / tf.to_float(size)
    content_loss /= len(content_layers)

    # Get the style loss.
    style_loss = 0
    for style_gram, layer in zip(style_features_t, style_layers):
        gen_features, _ = tf.split(net[layer], num_or_size_splits=2, axis=0)
        size = tf.size(gen_features)
        # Calculate the style loss against each style image.
        for style_image in style_gram:
            style_loss += tf.nn.l2_loss(
                model.gram(gen_features, FLAGS.batch_size) - style_image) / tf.to_float(size)
    style_loss /= len(style_layers)

    # Total loss.
    total_v_loss = total_variation_loss(generated)
    loss = (FLAGS.style_weight * style_loss +
            FLAGS.content_weight * content_loss +
            FLAGS.tv_weight * total_v_loss)

    return generated, images, content_loss, style_loss, total_v_loss, loss
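# get_style_features(style_paths, style_layers, net_type) is used above but not
# defined in this section. The sketch below assumes it precomputes one list of
# Gram matrices per style layer by running each style image through the loss
# network once in a throwaway graph. FLAGS.image_size, FLAGS.vgg_path, vgg.net,
# and model.gram are taken from the surrounding code; everything else is an
# assumption, not the repository's actual implementation.
import tensorflow as tf

def get_style_features(style_paths, style_layers, net_type):
    grams_per_layer = [[] for _ in style_layers]
    with tf.Graph().as_default(), tf.Session() as sess:
        for path in style_paths:
            data = tf.read_file(path)
            image = tf.image.decode_jpeg(data, channels=3)
            image = tf.image.resize_images(image, [FLAGS.image_size, FLAGS.image_size])
            batched = tf.expand_dims(tf.to_float(image), 0)  # [1, h, w, 3]
            net, _ = vgg.net(FLAGS.vgg_path, batched, net_type)
            for i, layer in enumerate(style_layers):
                # One Gram matrix per (layer, style image) pair, evaluated eagerly.
                grams_per_layer[i].append(sess.run(model.gram(net[layer], 1)))
    return grams_per_layer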