def total_loss(combination_image, base_image, style_image):
    """Return the weighted content + style + total-variation loss.

    The three images are stacked on the batch axis in the order
    (base, style, combination) and pushed through ``feature_model`` once, so
    index 0 / 1 / 2 of every feature map belongs to the base, style and
    combination image respectively.

    The content term compares base vs. combination at ``content_layer_name``.
    The style term compares style vs. combination at each layer in
    ``style_layer_names``; STYLE_WEIGHT is split evenly across those layers.
    The total-variation term regularises ``combination_image`` directly.
    """
    stacked = tf.concat([base_image, style_image, combination_image], axis=0)
    activations = feature_model(stacked)

    total = tf.zeros(shape=())

    # Content term: base image (row 0) vs. combination image (row 2).
    content_maps = activations[content_layer_name]
    base_feats = content_maps[0, :, :, :]
    combo_feats = content_maps[2, :, :, :]
    total = total + CONTENT_WEIGHT * content_loss(base_feats, combo_feats)

    # Style terms: style image (row 1) vs. combination image (row 2).
    for name in style_layer_names:
        maps = activations[name]
        reference = maps[1, :, :, :]
        candidate = maps[2, :, :, :]
        term = style_loss(reference, candidate, size=image_height * image_width)
        total += (STYLE_WEIGHT / len(style_layer_names)) * term

    # Smoothness regulariser on the generated image itself.
    total += TV_WEIGHT * total_variation_loss(
        combination_image, image_height, image_width)
    return total
def main(argv=None):
    """Optimise a generated image against CONTENT_IMAGE / STYLE_IMAGE, save it.

    Builds the VGG-16 loss network on the content image. It combines the
    weighted content, style and total-variation losses and runs
    ``FLAGS.NUM_ITERATIONS`` Adam steps. It then writes the final image to
    ``out.png`` in the working directory.

    :param argv: unused (presumably kept for ``tf.app.run`` compatibility).
    """
    network_fn = nets_factory.get_network_fn('vgg_16', num_classes=1,
                                             is_training=False)
    image_preprocessing_fn, image_unprocessing_fn = \
        preprocessing_factory.get_preprocessing('vgg_16', is_training=False)

    preprocess_content_image = reader.get_image(FLAGS.CONTENT_IMAGE,
                                                FLAGS.IMAGE_SIZE)
    # Add a batch dimension for the VGG net.
    preprocess_content_image = tf.expand_dims(preprocess_content_image, 0)
    _, endpoints_dict = network_fn(preprocess_content_image,
                                   spatial_squeeze=False)

    # Log the structure of the loss network.
    tf.logging.info(
        'Loss network layers(You can define them in "content_layers" and "style_layers"):'
    )
    for key in endpoints_dict:
        tf.logging.info(key)

    # --- Build losses ---
    content_loss, generaged_image = losses.content_loss(
        endpoints_dict, FLAGS.CONTENT_LAYERS, FLAGS.CONTENT_IMAGE)
    # Upper-case like every other FLAGS attribute read in this function;
    # the lower-case `FLAGS.style_layers` did not follow that convention.
    style_loss, style_loss_summary = losses.style_loss(
        endpoints_dict, FLAGS.STYLE_LAYERS, FLAGS.STYLE_IMAGE)
    tv_loss = losses.total_variation_loss(
        generaged_image)  # use the unprocessed image

    loss = (FLAGS.STYLE_WEIGHT * style_loss
            + FLAGS.CONTENT_WEIGHT * content_loss
            + FLAGS.TV_WEIGHT * tv_loss)
    train_op = tf.train.AdamOptimizer(FLAGS.LEARNING_RATE).minimize(loss)

    # Add reader.mean_pixel back (presumably undoing the mean subtraction done
    # by preprocessing), clamp to uint8 and encode as PNG bytes.
    output_image = tf.image.encode_png(
        tf.saturate_cast(
            tf.squeeze(generaged_image) + reader.mean_pixel, tf.uint8))

    with tf.Session() as sess:
        # global_variables_initializer replaces the deprecated
        # initialize_all_variables.
        sess.run(tf.global_variables_initializer())
        start_time = time.time()
        for step in range(FLAGS.NUM_ITERATIONS):
            _, loss_t, cl, sl = sess.run(
                [train_op, loss, content_loss, style_loss])
            elapsed = time.time() - start_time
            start_time = time.time()
            print(step, elapsed, loss_t, cl, sl)
        image_t = sess.run(output_image)
        with open('out.png', 'wb') as f:
            f.write(image_t)
def _total_content_loss(sess, network, content, config: ArtistConfig):
    """Build the weighted, averaged content loss over the configured layers.

    The content image is first assigned to ``network['input']``. For each
    configured layer, that layer's current activation is evaluated and frozen
    into a constant target ``p``. The loss compares it with the live layer
    tensor ``x`` via ``content_loss(p, x)``.

    :param sess: session used to assign the input and evaluate targets.
    :param network: ordered mapping of layer name -> tensor; must contain
        ``'input'``. ``config.content_layers`` holds integer indices into its
        key order.
    :param content: value assigned to ``network['input']``.
    :param config: provides ``content_layers``, ``content_layer_weights`` and
        ``verbose``.
    :returns: the sum of ``weight * content_loss`` over the layers, divided by
        ``len(config.content_layer_weights)``. Returns ``0.`` when no weights
        are configured (previously a ZeroDivisionError).
    """
    sess.run(network['input'].assign(content))
    # Computed once instead of rebuilding the key list 3x per layer.
    layer_names = list(network.keys())

    loss = 0.
    if config.verbose:
        print('Content Layer: ')
    for indx, w in zip(config.content_layers, config.content_layer_weights):
        layer_name = layer_names[indx]
        if config.verbose:
            print(f'\t{layer_name}')
        # Freeze the content image's activation as the target constant.
        p = tf.convert_to_tensor(sess.run(network[layer_name]))
        x = network[layer_name]
        # Apply the per-layer weight; it was previously zipped in but ignored.
        loss += w * content_loss(p, x)

    num_weights = len(config.content_layer_weights)
    if num_weights == 0:
        return loss
    loss /= float(num_weights)
    return loss
def main():
    """Train the fast style-transfer network, checkpointing into models/log/.

    Relies on module-level settings: batch_size, image_size, dataset_path,
    vgg16_ckpt_path, content_layers, style_layers, and the content/style/tv
    weights. Training runs until the input pipeline raises OutOfRangeError.
    It prints losses every 10 steps, writes TensorBoard summaries every 25
    steps and saves a checkpoint every 1000 steps.
    """
    # Make sure the training path exists.
    training_path = 'models/log/'
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    style_features = get_style_feature()

    with tf.Graph().as_default():
        train_image = reader.get_train_image(batch_size, image_size,
                                             image_size, dataset_path)
        generated = model.net(train_image)
        processed_generated = [
            reader.prepose_image(image, image_size, image_size)
            for image in tf.unstack(generated, axis=0, num=batch_size)
        ]
        processed_generated = tf.stack(processed_generated)
        # Generated and content images are sent through the loss network as
        # one concatenated batch.
        net = model.load_model(
            tf.concat([processed_generated, train_image], 0), vgg16_ckpt_path)

        with tf.Session() as sess:
            # --- Build losses ---
            content_loss = losses.content_loss(net, content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                net, style_features, style_layers)
            tv_loss = losses.total_variation_loss(
                generated)  # use the unprocessed image

            loss = (style_weight * style_loss
                    + content_weight * content_loss
                    + tv_weight * tv_loss)

            # --- Add summaries for visualization in TensorBoard ---
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image(
                'origin',
                tf.stack([
                    reader.mean_add(image) for image in tf.unstack(
                        train_image, axis=0, num=batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # --- Prepare to train ---
            global_step = tf.Variable(0, name="global_step", trainable=False)
            # NOTE(review): var_list is every trainable variable; if
            # model.load_model creates trainable VGG variables they will be
            # updated too — confirm that is intended.
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step,
                var_list=tf.trainable_variables())

            saver = tf.train.Saver(tf.trainable_variables())

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # --- Start training ---
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    # Fetch the component losses in the SAME run as the
                    # training step. The input appears to be queue-fed (queue
                    # runners are started above), so a separate sess.run would
                    # consume another batch. The logged content/style losses
                    # would then not belong to the step that produced loss_t,
                    # and a forward pass would be wasted each step.
                    _, loss_t, step, loss_c, loss_s = sess.run(
                        [train_op, loss, global_step, content_loss,
                         style_loss])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()

                    # logging
                    if step % 10 == 0:
                        print(
                            'step: %d, content Loss %f, style Loss %f, total Loss %f, secs/step: %f'
                            % (step, loss_c, loss_s, loss_t, elapsed_time))

                    # summary
                    if step % 25 == 0:
                        print('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()

                    # checkpoint
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path,
                                                'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(
                    sess,
                    os.path.join(training_path, 'fast-style-model.ckpt-done'))
                print('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)
include_top=False) print('Loaded Model') #creating a dictionary with the layer name as key and the layer output as value output_dict = dict([(layer.name, layer.output) for layer in model.layers]) #initializing the loss as a tensorflow variable loss = K.variable(0.) layer_features = output_dict['block2_conv2'] #content_image_features output at block2_conv2 content_image_features = layer_features[0, :, :, :] #combination_image_features output at block2_conv2 combination_image_features = layer_features[2, :, :, :] #calculating the content loss loss += content_weight * content_loss(content_image_features, combination_image_features) #layers at which the style loss is to be calculated feature_layers = [ 'block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3', 'block5_conv3' ] #iterating over the feature layers for layer_name in feature_layers: layer_features = output_dict[layer_name] #style_image_features output at the particular feature layer style_image_features = layer_features[1, :, :, :] #combination_image_features output at the particular feature layer combination_image_features = layer_features[2, :, :, :] #calculating the style loss
def solve(Config):
    """Build the style-transfer training graph described by ``Config`` and train it.

    Reads (at least) these ``Config`` attributes:
    - ``model_dir``
    - ``content_layers``, ``style_layers``
    - ``feature_path``
    - ``style_weight``, ``content_weight``, ``tv_weight``
    - ``lr``
    - ``finetune``
    - ``max_iter``, ``display``, ``snapshot``

    The whole object is also passed to ``losses.get_style_feature``,
    ``Dataset`` and ``preprocess``. Checkpoints are written to
    ``<model_dir>/model.ckpt-<step>``; TensorBoard summaries go to
    ``<model_dir>/summary``.

    :raises AssertionError: if the training loss becomes NaN.
    """
    gc.enable()
    # Target features of the style image.
    style_features = losses.get_style_feature(Config)

    # Prepare output dirs. makedirs (instead of mkdir) also creates missing
    # parent directories, so nested paths such as 'models/wave' work.
    model_dir = Config.model_dir
    if not osp.exists(model_dir):
        os.makedirs(model_dir)

    # --- Graph construction ---
    # Input pipeline of content images.
    images = Dataset(Config).imagedata_pipelines()
    # The transformation network being trained.
    generated = model.inference_trainnet(images)
    # Preprocess the generated batch, then concatenate it with the content
    # images so both pass through the VGG loss network in a single run.
    preprocess_generated = preprocess(generated, Config)
    layer_infos = Vgg(Config.feature_path).build(
        tf.concat([preprocess_generated, images], 0))

    # Losses.
    content_loss = losses.content_loss(layer_infos, Config.content_layers)
    style_loss = losses.style_loss(layer_infos, Config.style_layers,
                                   style_features)
    tv_loss = losses.tv_loss(generated)
    loss = (Config.style_weight * style_loss
            + Config.content_weight * content_loss
            + Config.tv_weight * tv_loss)

    # Train op.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = tf.train.AdamOptimizer(Config.lr).minimize(
        loss, global_step=global_step)

    # Summaries.
    with tf.name_scope('losses'):
        tf.summary.scalar('content_loss', content_loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('tv_loss', tv_loss)
    with tf.name_scope('weighted_losses'):
        tf.summary.scalar('weighted_content_loss',
                          content_loss * Config.content_weight)
        tf.summary.scalar('weighted_style_loss',
                          style_loss * Config.style_weight)
        tf.summary.scalar('weighted_tv_loss', tv_loss * Config.tv_weight)
    tf.summary.scalar('total_loss', loss)
    tf.summary.image('generated', generated)
    tf.summary.image('original', images)
    summary = tf.summary.merge_all()
    summary_path = osp.join(model_dir, 'summary')
    if not osp.exists(summary_path):
        os.makedirs(summary_path)
    writer = tf.summary.FileWriter(summary_path)

    # Saver over every global variable; latest checkpoint (if any) for finetune.
    saver = tf.train.Saver(tf.global_variables())
    restore = tf.train.latest_checkpoint(model_dir)

    # --- Training ---
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        # Optionally continue from the latest checkpoint in model_dir.
        if Config.finetune:
            if restore:
                print('restoring model from {}'.format(restore))
                saver.restore(sess, restore)
            else:
                print('no model exist, from scratch')

        # Start the threads that feed the input queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        print('begin training')
        start_time = time.time()
        local_time = time.time()
        try:
            for step in range(Config.max_iter + 1):
                _, loss_value = sess.run([train_op, loss])
                if step % Config.display == 0 or step == Config.max_iter:
                    print("{}[iterations], train loss {}, time consumes {}s".format(
                        step, loss_value, time.time() - local_time))
                    local_time = time.time()
                assert not np.isnan(loss_value), 'model with loss nan'
                if step != 0 and (step % Config.snapshot == 0
                                  or step == Config.max_iter):
                    # Save a snapshot and write the current summaries.
                    print('adding summary and saving snapshot...')
                    saver.save(sess, osp.join(model_dir, 'model.ckpt'),
                               global_step=step)
                    summary_str = sess.run(summary)
                    writer.add_summary(summary_str, global_step=step)
                    writer.flush()
        finally:
            # Stop the queue-runner threads even when training aborts
            # (e.g. the NaN assertion above), not only on normal completion.
            coord.request_stop()
            coord.join(threads)
    print('done, consumes time {}s'.format(time.time() - start_time))
def main(unused_agrv=None):
    """Train the multi-style image transformation network ("styleNet").

    All settings are read from the module-level ``FLAGS``; ``unused_agrv`` is
    ignored. A VGG-16 network is attached to the transformation network's
    output to compute content, style and total-variation losses; only
    variables under the ``styleNet`` scope are optimised. On exit (normal or
    not) those variables are saved to ``models/<model_name>_final.ckpt``.

    NOTE: this function uses Python 2 syntax (print statements,
    ``dict.iteritems``).
    """
    # Unpack command-line arguments.
    train_dir = FLAGS.train_dir
    style_dataset = FLAGS.style_dataset
    model_name = FLAGS.model_name
    preprocess_size = [FLAGS.image_size, FLAGS.image_size]
    batch_size = FLAGS.batch_size
    n_epochs = FLAGS.n_epochs
    learn_rate = FLAGS.learning_rate
    content_weights = FLAGS.content_weights
    style_weights = FLAGS.style_weights
    num_pipe_buffer = FLAGS.num_pipe_buffer
    num_styles = FLAGS.num_styles
    train_steps = FLAGS.train_steps
    upsample_method = FLAGS.upsample_method

    # Setup input pipeline (delegate it to CPU to let GPU handle neural net)
    files = tf.train.match_filenames_once(train_dir + '/train-*')
    style_files = tf.train.match_filenames_once(style_dataset)
    # NOTE(review): this prints the graph object returned by
    # match_filenames_once, not the matched file names.
    print("style %s" % style_files)

    with tf.variable_scope('input_pipe'), tf.device('/cpu:0'):
        # style_grams is indexed by style-layer name below.
        _, style_labels, style_grams = datapipe.style_batcher(
            style_files, batch_size, preprocess_size, n_epochs,
            num_pipe_buffer)
        batch_op = datapipe.batcher(files, batch_size, preprocess_size,
                                    n_epochs, num_pipe_buffer)

    """ Set up weight of style and content image """
    # The weight flags are Python dict literals ({layer_name: weight}).
    content_weights = ast.literal_eval(content_weights)
    style_weights = ast.literal_eval(style_weights)

    # target_grams and loss_style_layers are both built by iterating
    # style_weights (unmodified in between), so their orders line up.
    target_grams = []
    for name, val in style_weights.iteritems():
        target_grams.append(style_grams[name])

    # Alter the names to include a namescope that we'll use + output suffix.
    loss_style_layers = []
    loss_style_weights = []
    loss_content_layers = []
    loss_content_weights = []
    for key, val in style_weights.iteritems():
        loss_style_layers.append(key + ':0')
        loss_style_weights.append(val)
    for key, val in content_weights.iteritems():
        loss_content_layers.append(key + ':0')
        loss_content_weights.append(val)

    # Load in image transformation network into default graph.
    shape = [batch_size] + preprocess_size + [3]
    with tf.variable_scope('styleNet'):
        X = tf.placeholder(tf.float32, shape=shape, name='input')
        # The style label batch selects which style the network renders.
        Y = transform(X, style_labels, num_styles, upsample_method)
        print(Y)

    # Connect vgg directly to the image transformation network.
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(Y)

    # Get the gram matrices' tensors for the style loss features.
    input_img_grams = losses.get_grams(loss_style_layers)

    # Get the tensors for content loss features.
    content_layers = losses.get_layers(loss_content_layers)

    # Create loss function
    # One placeholder per content layer; fed each step with that layer's
    # activations for the unstylised batch.
    content_targets = tuple(
        tf.placeholder(tf.float32,
                       shape=layer.get_shape(),
                       name='content_input_{}'.format(i))
        for i, layer in enumerate(content_layers))
    cont_loss = losses.content_loss(content_layers, content_targets,
                                    loss_content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams,
                                   loss_style_weights)
    tv_loss = losses.tv_loss(Y)
    loss = cont_loss + style_loss + tv_loss

    # We do not want to train VGG, so we must grab the subset.
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='styleNet')

    # Setup step + optimizer
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learn_rate).minimize(
        loss, global_step, train_vars)

    if not os.path.exists('./models'):  # Dir that save final models to
        os.makedirs('./models')
    # Only the transformation network's variables are saved.
    final_saver = tf.train.Saver(train_vars)

    # We must include local variables because of batch pipeline.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Begin training.
    print 'Starting training...'
    with tf.Session() as sess:
        # Initialization
        sess.run(init_op)
        vggnet.load_weights(vgg16.checkpoint_file(), sess)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                current_step = sess.run(global_step)
                batch = sess.run(batch_op)

                # Collect content targets
                # Feeding the raw batch in place of Y makes VGG see the
                # unstylised content images.
                content_data = sess.run(content_layers, feed_dict={Y: batch})

                feed_dict = {X: batch, content_targets: content_data}
                _, loss_out = sess.run([optimizer, loss], feed_dict=feed_dict)
                if (current_step % 10 == 0):
                    print current_step, loss_out

                # Throw error if we reach number of steps to break after.
                # (Actually a clean break, not an error.)
                if current_step == train_steps:
                    print('Done training.')
                    break
        except tf.errors.OutOfRangeError:
            print('Done training.')
        finally:
            # Save the model (the image transformation network) for later usage
            # in predict.py
            final_saver.save(sess, 'models/' + model_name + '_final.ckpt',
                             write_meta_graph=False)
            coord.request_stop()
            coord.join(threads)
def main(args):
    """Train a single-style image transformation network ("img_t_net").

    Target Gram matrices are first precomputed from the style image with a
    standalone VGG-16. The default graph is then reset and rebuilt with VGG
    attached to the transformation network's output. Only variables under
    ``img_t_net`` are optimised.

    Outputs:
    - checkpoints under ``training/``;
    - TensorBoard logs under ``./summaries/train/<run_name>``;
    - the final transformation-network weights at
      ``models/<model_name>_final.ckpt``.

    :param args: argparse.Namespace object from argparse.parse_args().
    """
    # Unpack command-line arguments.
    train_dir = args.train_dir
    style_img_path = args.style_img_path
    model_name = args.model_name
    preprocess_size = args.preprocess_size
    batch_size = args.batch_size
    n_epochs = args.n_epochs
    run_name = args.run_name
    learn_rate = args.learn_rate
    loss_content_layers = args.loss_content_layers
    loss_style_layers = args.loss_style_layers
    content_weights = args.content_weights
    style_weights = args.style_weights
    num_steps_ckpt = args.num_steps_ckpt
    num_pipe_buffer = args.num_pipe_buffer
    num_steps_break = args.num_steps_break
    beta_val = args.beta
    style_target_resize = args.style_target_resize
    upsample_method = args.upsample_method

    # Load in style image that will define the model.
    style_img = utils.imread(style_img_path)
    style_img = utils.imresize(style_img, style_target_resize)
    # Add a leading batch axis and cast to float32 for the placeholder.
    style_img = style_img[np.newaxis, :].astype(np.float32)

    # Alter the names to include a namescope that we'll use + output suffix.
    loss_style_layers = ['vgg/' + i + ':0' for i in loss_style_layers]
    loss_content_layers = ['vgg/' + i + ':0' for i in loss_content_layers]

    # Get target Gram matrices from the style image.
    with tf.variable_scope('vgg'):
        X_vgg = tf.placeholder(tf.float32, shape=style_img.shape, name='input')
        vggnet = vgg16.vgg16(X_vgg)
    with tf.Session() as sess:
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        print('Precomputing target style layers.')
        target_grams = sess.run(utils.get_grams(loss_style_layers),
                                feed_dict={X_vgg: style_img})

    # Clean up so we can re-create vgg connected to our image network.
    # target_grams are numpy values now, so they survive the reset.
    print('Resetting default graph.')
    tf.reset_default_graph()

    # Load in image transformation network into default graph.
    shape = [batch_size] + preprocess_size + [3]
    with tf.variable_scope('img_t_net'):
        X = tf.placeholder(tf.float32, shape=shape, name='input')
        Y = create_net(X, upsample_method)

    # Connect vgg directly to the image transformation network.
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(Y)

    # Get the gram matrices' tensors for the style loss features.
    input_img_grams = utils.get_grams(loss_style_layers)

    # Get the tensors for content loss features.
    content_layers = utils.get_layers(loss_content_layers)

    # Create loss function
    # One placeholder per content layer; fed each step with that layer's
    # activations for the unstylised batch.
    content_targets = tuple(tf.placeholder(tf.float32,
                            shape=layer.get_shape(),
                            name='content_input_{}'.format(i))
                            for i, layer in enumerate(content_layers))
    cont_loss = losses.content_loss(content_layers, content_targets,
                                    content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams,
                                   style_weights)
    tv_loss = losses.tv_loss(Y)
    # The TV weight is a fed scalar (beta_val) rather than a graph constant.
    beta = tf.placeholder(tf.float32, shape=[], name='tv_scale')
    loss = cont_loss + style_loss + beta * tv_loss
    with tf.name_scope('summaries'):
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('content_loss', cont_loss)
        tf.summary.scalar('tv_loss', beta * tv_loss)

    # Setup input pipeline (delegate it to CPU to let GPU handle neural net)
    files = tf.train.match_filenames_once(train_dir + '/train-*')
    with tf.variable_scope('input_pipe'), tf.device('/cpu:0'):
        batch_op = datapipe.batcher(files, batch_size, preprocess_size,
                                    n_epochs, num_pipe_buffer)

    # We do not want to train VGG, so we must grab the subset.
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='img_t_net')

    # Setup step + optimizer
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learn_rate).minimize(loss,
                                                            global_step,
                                                            train_vars)

    # Setup subdirectory for this run's TensorBoard logs.
    if not os.path.exists('./summaries/train/'):
        os.makedirs('./summaries/train/')
    if run_name is None:
        # Pick the first unused '<model_name><N>' directory name.
        current_dirs = [name for name in os.listdir('./summaries/train/')
                        if os.path.isdir('./summaries/train/' + name)]
        name = model_name + '0'
        count = 0
        while name in current_dirs:
            count += 1
            name = model_name + '{}'.format(count)
        run_name = name

    # Savers and summary writers
    if not os.path.exists('./training'):  # Dir that we'll later save .ckpts to
        os.makedirs('./training')
    if not os.path.exists('./models'):  # Dir that save final models to
        os.makedirs('./models')
    # saver: full training state; final_saver: transformation net only.
    saver = tf.train.Saver()
    final_saver = tf.train.Saver(train_vars)
    merged = tf.summary.merge_all()
    full_log_path = './summaries/train/' + run_name
    train_writer = tf.summary.FileWriter(full_log_path)

    # We must include local variables because of batch pipeline.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Begin training.
    print('Starting training...')
    with tf.Session() as sess:
        # Initialization
        sess.run(init_op)
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                current_step = sess.run(global_step)
                batch = sess.run(batch_op)

                # Collect content targets
                # Feeding the raw batch in place of Y makes VGG see the
                # unstylised content images.
                content_data = sess.run(content_layers, feed_dict={Y: batch})

                feed_dict = {X: batch,
                             content_targets: content_data,
                             beta: beta_val}
                if (current_step % num_steps_ckpt == 0):
                    # Save a checkpoint
                    save_path = 'training/' + model_name + '.ckpt'
                    saver.save(sess, save_path, global_step=global_step)
                    summary, _, loss_out = sess.run([merged, optimizer, loss],
                                                    feed_dict=feed_dict)
                    train_writer.add_summary(summary, current_step)
                    print(current_step, loss_out)

                elif (current_step % 10 == 0):
                    # Collect some diagnostic data for Tensorboard.
                    summary, _, loss_out = sess.run([merged, optimizer, loss],
                                                    feed_dict=feed_dict)
                    train_writer.add_summary(summary, current_step)

                    # Do some standard output.
                    print(current_step, loss_out)
                else:
                    _, loss_out = sess.run([optimizer, loss],
                                           feed_dict=feed_dict)

                # Throw error if we reach number of steps to break after.
                # (Actually a clean break, not an error.)
                if current_step == num_steps_break:
                    print('Done training.')
                    break
        except tf.errors.OutOfRangeError:
            print('Done training.')
        finally:
            # Save the model (the image transformation network) for later usage
            # in predict.py
            final_saver.save(sess, 'models/' + model_name + '_final.ckpt')

            coord.request_stop()
            coord.join(threads)
'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1'] total_variation_weight = 10 ** -4 style_weight = 1. content_weight = 0.025 loss = K.variable(0.) layer_features = outputs_dict[content_layer] target_image_features = layer_features[0, :, :, :] combination_features = layer_features[2, :, :, :] loss.assign_add(content_weight * content_loss(target_image_features,combination_features)) for layer_name in style_layers: layer_features = outputs_dict[layer_name] style_reference_features = layer_features[1, :, :, :] combination_features = layer_features[2, :, :, :] sl = style_loss(style_reference_features, combination_features, img_height, img_width) loss.assign_add((style_weight / len(style_layers)) * sl) loss.assign_add(total_variation_weight * total_variation_loss(combination_image, img_height, img_width)) grads = K.gradients(loss, combination_image)[0] fetch_loss_and_grads = K.function([combination_image], [loss, grads])
def main(FLAGS):
    """Train the fast style-transfer image-generation network.

    Content images from ``train2014/`` are fed to the generator
    (``model.net``). The generated and content images are then run through the
    pretrained loss network ``FLAGS.loss_model``. The weighted content, style
    and total-variation losses are minimised with Adam over every trainable
    variable whose name does not start with ``FLAGS.loss_model``.
    Checkpoints and TensorBoard summaries go to
    ``os.path.join(FLAGS.model_path, FLAGS.naming)``.

    :param FLAGS: configuration object, read for:
        - ``loss_model``, ``batch_size``, ``image_size``, ``epoch``;
        - ``content_layers``, ``style_layers``;
        - ``style_weight``, ``content_weight``, ``tv_weight``;
        - ``model_path``, ``naming``.
        It is also passed to ``losses.get_style_features`` and
        ``utils._get_init_fn``.
    """
    # Style-target features (Gram matrices) used by the style loss.
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists; trained models are stored here
    # (e.g. model/wave/).
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        # The session is only usable inside this `with` block.
        with tf.Session() as sess:
            # --- Build the loss network ---
            # The pretrained loss model is only evaluated, never trained.
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model, num_classes=1, is_training=False)
            # Pre-/un-processing functions for images entering the loss model.
            # (Previously bound as `image_unpreprocessing_fn` but used below as
            # `image_unprocessing_fn`, a NameError.)
            image_preprocessing_fn, image_unprocessing_fn = (
                preprocessing_factory.get_preprocessing(
                    FLAGS.loss_model, is_training=False))
            # A preprocessed batch of training (content) images. This name is
            # the one used below (it was `processed_image`, a NameError).
            processed_images = reader.image(
                FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            # The image-generation network is the one being trained.
            generated = model.net(processed_images, training=True)
            processed_generated = tf.stack([
                image_preprocessing_fn(image, FLAGS.image_size,
                                       FLAGS.image_size)
                for image in tf.unstack(generated, axis=0,
                                        num=FLAGS.batch_size)
            ])
            # Run generated and content images through the loss model in one
            # batch; endpoints_dict maps each layer name to its output.
            _, endpoints_dict = network_fn(
                tf.concat([processed_generated, processed_images], 0),
                spatial_squeeze=False)

            # Log the structure of the loss network (its layer names).
            tf.logging.info(
                'loss network layers(you can define them in "content layer" and "style layer"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            # --- Build losses ---
            content_loss = losses.content_loss(endpoints_dict,
                                               FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)
            loss = (FLAGS.style_weight * style_loss
                    + FLAGS.content_weight * content_loss
                    + FLAGS.tv_weight * tv_loss)

            # --- Summaries for TensorBoard ---
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)
            tf.summary.scalar('weighted_losses/weighted content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)
            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('genearted', generated)
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image)
                for image in tf.unstack(processed_images, axis=0,
                                        num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # --- Prepare to train ---
            global_step = tf.Variable(0, name='global_step', trainable=False)
            # Train only the generator: skip every variable of the loss model.
            variable_to_train = [
                variable for variable in tf.trainable_variables()
                if not variable.name.startswith(FLAGS.loss_model)
            ]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            # Variables to save/restore: every global variable outside the loss
            # model. Built after minimize() so the optimizer's own variables
            # are included.
            variables_to_restore = [
                v for v in tf.global_variables()
                if not v.name.startswith(FLAGS.loss_model)
            ]
            saver = tf.train.Saver(variables_to_restore,
                                   write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            # Restore the pretrained loss-model variables.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore the training model if a checkpoint exists; otherwise the
            # freshly initialised generator is used. The loss model is not in
            # `saver`, so it keeps the weights loaded by init_func above.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('restoringmodel from {}'.format(last_file))
                saver.restore(sess, last_file)

            # --- Start training ---
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                # Loop until the input pipeline is exhausted (epoch limit).
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    # Per-step wall time (this used to log the absolute
                    # time.time() value as "secs/step").
                    elapsed_time = time.time() - start_time
                    start_time = time.time()

                    # logging
                    if step % 10 == 0:
                        tf.logging.info('step:%d, total loss %f, secs/step: %f'
                                        % (step, loss_t, elapsed_time))
                    # summary
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    # checkpoint (saves the variables in variables_to_restore)
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path,
                                                'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path,
                                              'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()  # ask all queue-runner threads to stop
                coord.join(threads)   # wait for them to finish
def train(): style_feature, style_grams = get_style_feature() with tf.Graph().as_default(): with tf.Session() as sess: #loss_input_style = tf.placeholder(dtype = tf.float32 , shape = [args.batch , args.size , args.size , args.in_dim ]) #loss_input_target =tf.placeholder(dtype = tf.float32 , shape = [args.batch , args.size , args.size , args.in_dim ]) # For online optimization problem, use testing preprocess for both train and test preprocess_func, unprocess_func = preprocessing.preprocessing_factory.get_preprocessing( args.loss_model, is_training=False) images = reader.image(args.batch, args.size , args.size, args.target_dir , preprocess_func, \ args.epoch , shuffle = True) model = transform(sess, args) transformed_images = model.generator(images, reuse=False) #print('qqq') #print( tf.shape(transformed_images).eval()) unprocess_transform = [(img) for img in tf.unstack( transformed_images, axis=0, num=args.batch)] processed_generated = [ preprocess_func(img, args.size, args.size) for img in unprocess_transform ] processed_generated = tf.stack(processed_generated) loss_model = nets.nets_factory.get_network_fn(args.loss_model, num_classes=1, is_training=False) pair = tf.concat([processed_generated, images], axis=0) _, end_dicts = loss_model(pair, spatial_squeeze=False) init_loss_model = load_pretrained_weight(args.loss_model) c_loss = losses.content_loss( end_dicts, loss_config.content_loss_dict[args.loss_model]) s_loss, s_loss_sum = losses.style_loss( end_dicts, loss_config.style_loss_dict[args.loss_model], style_grams) tv_loss = losses.total_variation_loss(transformed_images) loss = args.c_weight * c_loss + args.s_weight * s_loss + args.tv_weight * tv_loss print('shapes') print(pair.get_shape()) #tf.summary.scalar('average', tf.reduce_mean(images)) #tf.summary.scalar('gram average', tf.reduce_mean(tf.stack(style_feature))) tf.summary.scalar('losses/content_loss', c_loss) tf.summary.scalar('losses/style_loss', s_loss) tf.summary.scalar('losses/tv_loss', tv_loss) 
tf.summary.scalar('weighted_losses/weighted_content_loss', c_loss * args.c_weight) tf.summary.scalar('weighted_losses/weighted_style_loss', s_loss * args.s_weight) tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * args.tv_weight) tf.summary.scalar('total_loss', loss) for layer in loss_config.style_loss_dict[args.loss_model]: tf.summary.scalar('style_losses/' + layer, s_loss_sum[layer]) tf.summary.image('transformed', tf.stack(unprocess_transform, axis=0)) # tf.image_summary('processed_generated', processed_generated) # May be better? tf.summary.image( 'ori', tf.stack([ unprocess_func(image) for image in tf.unstack(images, axis=0, num=args.batch) ])) summary = tf.summary.merge_all() writer = tf.summary.FileWriter(args.log_dir) step = tf.Variable(0, name='global_step', trainable=False) all_trainables = tf.trainable_variables() all_vars = tf.global_variables() to_train = [ var for var in all_trainables if not args.loss_model in var.name ] to_restore = [ var for var in all_vars if not args.loss_model in var.name ] optim = tf.train.AdamOptimizer( 1e-3 ).minimize(\ loss = loss , var_list = to_train , global_step = step) saver = tf.train.Saver(to_restore) style_name = (args.style_dir.split('/')[-1]).split('.')[0] ckpt = tf.train.latest_checkpoint( os.path.join(args.ckpt_dir, style_name)) if ckpt: tf.logging.info('Restoring model from {}'.format(ckpt)) saver.restore(sess, ckpt) sess.run([ tf.global_variables_initializer(), tf.local_variables_initializer() ]) #sess.run(init_loss_model) init_loss_model(sess) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) start_time = time() #i = 0 try: while True: _, gs, sum_info, c_info, s_info, tv_info, loss_info = sess.run( [optim, step, summary, c_loss, s_loss, tv_loss, loss]) writer.add_summary(sum_info, gs) elapsed = time() - start_time print(gs) if gs % 10 == 0: tf.logging.info( 'step: %d, c_loss %f s_loss %f tv_loss %f total Loss %f, secs/step: %f' % (gs, c_info, s_info, 
tv_info, loss_info, elapsed)) if gs % args.save_freq == 0: saver.save( sess, os.path.join(args.ckpt_dir, style_name, style_name + '.ckpt')) except tf.errors.OutOfRangeError: print('run out of images! save final model: ' + os.path.join(args.ckpt_dir, style_name + '.ckpt-done')) saver.save( sess, os.path.join(args.ckpt_dir, style_name, style_name + '.ckpt-done')) tf.logging.info('Done -- file ran out of range') finally: coord.request_stop() coord.join(threads) print('end training') '''
def main(style_img_path: str, content_img_path: str, img_dim: int, num_iter: int, style_weight: int,
         content_weight: int, variation_weight: int, print_every: int, save_every: int):
    """Run neural style transfer by directly optimising a Gaussian-noise image.

    The noise image is updated with Adam so that:

    * its VGG content activations match those of the content image;
    * the Gram matrices of its style activations match those of the style image;
    * a total-variation term keeps it smooth.

    Snapshots are written to ``./generated/iter_<n>.png``.

    Args:
        style_img_path: Path of the style image.
        content_img_path: Path of the content image.
        img_dim: Both images are resized to ``img_dim x img_dim``.
        num_iter: Number of Adam steps.
        style_weight: Multiplier of the style loss.
        content_weight: Multiplier of the content loss.
        variation_weight: Multiplier of the total-variation loss.
        print_every: Print the weighted losses every ``print_every`` iterations.
        save_every: Save the current image every ``save_every`` iterations.

    Raises:
        ValueError: If either image path is ``None``.
    """
    # Explicit checks rather than ``assert`` so they are not stripped under ``python -O``.
    if style_img_path is None:
        raise ValueError('style_img_path must not be None')
    if content_img_path is None:
        raise ValueError('content_img_path must not be None')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Read both images as 3-channel RGB. The Normalize below has exactly three
    # channel statistics, so grayscale/RGBA/palette inputs are not compatible
    # with it. The context managers close the files once the images are converted.
    with Image.open(style_img_path) as img:
        style_img = img.convert('RGB')
    with Image.open(content_img_path) as img:
        cont_img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((img_dim, img_dim)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    content_image = transform(cont_img).unsqueeze(0).to(device)
    style_image = transform(style_img).unsqueeze(0).to(device)

    vgg = VGG().to(device).eval()
    # Swap every MaxPool2d of the feature extractor for a 2x2 AvgPool2d.
    for name, child in vgg.vgg.named_children():
        if isinstance(child, nn.MaxPool2d):
            vgg.vgg[int(name)] = nn.AvgPool2d(kernel_size=2, stride=2)
    # VGG is a frozen feature extractor; only the noise image is optimised.
    for param in vgg.parameters():
        param.requires_grad = False

    def unroll(activations):
        # Flatten an activation map (assumed (1, C, H, W)) to (C, H*W). The
        # channel count is read from the tensor rather than hard-coded.
        return activations.view(activations.shape[1], -1)

    # Fixed targets, detached from the autograd graph.
    content_activations = unroll(vgg.get_content_activations(content_image).detach())
    style_grams = [gram(unroll(act).detach()) for act in vgg.get_style_activations(style_image)]
    num_style_layers = len(style_grams)

    noise = torch.randn(1, 3, img_dim, img_dim, device=device, requires_grad=True)
    # The noise tensor itself is the only parameter handed to the optimizer.
    adam = optim.Adam(params=[noise], lr=0.01, betas=(0.9, 0.999))

    # Create the snapshot folder once instead of checking on every iteration.
    os.makedirs('./generated/', exist_ok=True)

    for iteration in range(num_iter):
        adam.zero_grad()

        noise_content = unroll(vgg.get_content_activations(noise))
        content_loss_ = content_loss(noise_content, content_activations)

        # Divide by the number of style layers so every layer is weighted equally.
        style_loss = 0
        for act, target_gram in zip(vgg.get_style_activations(noise), style_grams):
            act = unroll(act)
            n_channels, n_positions = act.shape[0], act.shape[1]
            style_loss += gram_loss(gram(act), target_gram, n_channels, n_positions) / num_style_layers
        style_loss = style_loss.to(device)

        variation_loss = total_variation_loss(noise).to(device)

        total_loss = (content_weight * content_loss_ + style_weight * style_loss
                      + variation_weight * variation_loss)

        if iteration % print_every == 0:
            print("Iteration: {}, Content Loss: {:.3f}, Style Loss: {:.3f}, Var Loss: {:.3f}".format(
                iteration, content_weight * content_loss_.item(), style_weight * style_loss.item(),
                variation_weight * variation_loss.item()))
        if iteration % save_every == 0:
            save_image(noise.cpu().detach(), filename='./generated/iter_{}.png'.format(iteration))

        total_loss.backward()
        adam.step()
def main(FLAGS):
    """Train the fast style-transfer network described by ``FLAGS``.

    The transform network is optimised with Adam (lr 1e-3) against the
    weighted content, style and total-variation losses. These losses are
    computed on the ``FLAGS.loss_model`` network.

    Only variables outside that network's name scope are trained and
    checkpointed, under ``<FLAGS.model_path>/<FLAGS.naming>``. Training
    resumes from the newest checkpoint there, logs every 10 steps and saves
    every 50 steps until the input pipeline runs out of epochs.

    :param FLAGS: configuration object providing the attributes read below.
    """
    style_features_t = losses.get_style_features(FLAGS)

    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Build the loss network and its matching preprocessing functions.
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model, num_classes=1, is_training=False)
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)

            # Training-image pipeline (preprocessed) and the transform network.
            processed_images = reader.batch_image(
                FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 'train2014/',
                image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.transform_network(processed_images, training=True)

            # Preprocess every generated image like the inputs, then push generated
            # and original images through the loss network as a single batch.
            per_image = tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            processed_generated = tf.stack(
                [image_preprocessing_fn(img, FLAGS.image_size, FLAGS.image_size) for img in per_image])
            loss_net_input = tf.concat([processed_generated, processed_images], 0)
            _, endpoints_dict = network_fn(loss_net_input, spatial_squeeze=False)

            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for endpoint_name in endpoints_dict:
                tf.logging.info(endpoint_name)

            # Build losses.
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # computed on the un-preprocessed image
            loss = (FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss
                    + FLAGS.tv_weight * tv_loss)

            # Prepare to train.
            global_step = tf.Variable(0, name="global_step", trainable=False)

            def outside_loss_model(variables):
                # Only variables of the generator network are trained and saved.
                return [v for v in variables if not v.name.startswith(FLAGS.loss_model)]

            variable_to_train = outside_loss_model(tf.trainable_variables())
            # Optimize.
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            # Collected after minimize() so optimizer-created variables are saved too.
            saver = tf.train.Saver(outside_loss_model(tf.global_variables()),
                                   write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            # Restore the loss-network variables via the project's init function.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # Start training.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()

                    if step % 10 == 0:
                        tf.logging.info('step: %d, total Loss %f, secs/step: %f,%s'
                                        % (step, loss_t, elapsed_time, time.asctime()))

                    # Checkpoint.
                    if step % 50 == 0:
                        tf.logging.info('saving check point...')
                        checkpoint_path = os.path.join(training_path, FLAGS.naming + '.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                tf.logging.info('coordinator stop')
                coord.join(threads)
def main(FLAGS):
    """Train a fast style-transfer network and log TensorBoard summaries.

    The transform network is optimised with Adam (lr 1e-3) against the
    weighted content, style and total-variation losses computed by the
    ``FLAGS.loss_model`` network. Only variables outside that network's name
    scope are trained and checkpointed.

    Logging, summaries and checkpoints go to ``<FLAGS.model_path>/<FLAGS.naming>``:

    * a log line every 10 steps;
    * summaries every 25 steps;
    * a checkpoint every 1000 steps;
    * a final ``fast-style-model.ckpt-done`` when the input pipeline runs out.

    :param FLAGS: configuration object providing the attributes read below.
    """
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Build network.
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model, num_classes=1, is_training=False)
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            # Build losses.
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image
            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add summaries for visualization in TensorBoard.
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)
            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)
            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image)
                for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # Prepare to train: only variables outside the loss network are optimized.
            global_step = tf.Variable(0, name="global_step", trainable=False)
            variable_to_train = [v for v in tf.trainable_variables()
                                 if not v.name.startswith(FLAGS.loss_model)]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)
            # Collected after minimize() so optimizer-created variables are saved too.
            variables_to_restore = [v for v in tf.global_variables()
                                    if not v.name.startswith(FLAGS.loss_model)]
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            # Restore variables for the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)
            # Restore variables for the training model if a checkpoint exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # Start training.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()

                    if step % 10 == 0:
                        tf.logging.info('step: %d, total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))

                    if step % 25 == 0:
                        # NOTE(review): this is a separate run that re-evaluates the
                        # input pipeline, so it presumably dequeues a fresh batch that
                        # is never trained on. Fetching `summary` together with
                        # `train_op` would avoid that.
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()

                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                # Flush pending events and release the event file.
                writer.close()
                coord.join(threads)
def compute_losses(self):
    """Build the CycleGAN objectives, their Adam train ops and loss summaries.

    Reads the tensors already attached to ``self``:

    * ``input_a/b`` and ``cycle_images_a/b``;
    * the ``prob_*`` discriminator outputs;
    * the ``_lambda_a/b`` (cycle) and ``_delta_a/b`` (content) weights;
    * ``learning_rate``.

    Then sets:

    * ``self.model_vars`` -- every trainable variable.
    * ``self.d_A_trainer`` / ``self.d_B_trainer`` / ``self.g_A_trainer`` /
      ``self.g_B_trainer`` -- ``minimize`` ops restricted to the variables
      whose name contains ``d_A`` / ``d_B`` / ``g_A`` / ``g_B``.
    * ``self.g_A_loss_summ`` / ``self.g_B_loss_summ`` / ``self.d_A_loss_summ``
      / ``self.d_B_loss_summ`` -- scalar summaries of the four losses.

    Both generator losses contain the two weighted cycle-consistency terms and
    the two weighted content terms. ``g_loss_A`` adds the LSGAN generator loss
    on ``prob_fake_b_is_real``; ``g_loss_B`` adds the one on
    ``prob_fake_a_is_real``.
    """
    # Weighted cycle-consistency losses (cycled image vs. real input).
    cyc_a = self._lambda_a * losses.cycle_consistency_loss(
        real_images=self.input_a, generated_images=self.cycle_images_a)
    cyc_b = self._lambda_b * losses.cycle_consistency_loss(
        real_images=self.input_b, generated_images=self.cycle_images_b)

    # Weighted VGG content losses on the same cycled images.
    cont_a = self._delta_a * losses.content_loss(
        real_images=self.input_a, generated_images=self.cycle_images_a)
    cont_b = self._delta_b * losses.content_loss(
        real_images=self.input_b, generated_images=self.cycle_images_b)

    # LSGAN generator (adversarial) losses.
    adv_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
    adv_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)

    g_loss_A = cyc_a + cyc_b + adv_b + cont_a + cont_b
    g_loss_B = cyc_b + cyc_a + adv_a + cont_b + cont_a

    d_loss_A = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_a_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_is_real)
    d_loss_B = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_b_is_real,
        prob_fake_is_real=self.prob_fake_pool_b_is_real)

    # NOTE(review): one AdamOptimizer instance is shared by all four train ops.
    # In TF1 an instance presumably keeps a single pair of beta1/beta2 power
    # accumulators that each minimize()-created op advances, so bias correction
    # decays faster than with per-network optimizers. Switching to separate
    # instances would rename optimizer variables and likely break restoring
    # existing checkpoints -- confirm before changing.
    optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

    self.model_vars = tf.trainable_variables()
    vars_by_tag = {
        tag: [v for v in self.model_vars if tag in v.name]
        for tag in ('d_A', 'g_A', 'd_B', 'g_B')
    }

    self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=vars_by_tag['d_A'])
    self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=vars_by_tag['d_B'])
    self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=vars_by_tag['g_A'])
    self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=vars_by_tag['g_B'])

    for v in self.model_vars:
        print(v.name)

    # Scalar summaries for TensorBoard.
    self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
    self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
    self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
    self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
def main(args):
    """Stylise a content image by optimising its pixels against VGG16 losses.

    Target Gram matrices are first computed from the style image in a
    throw-away graph. The default graph is then reset and VGG16 is rebuilt on
    top of a trainable white-noise image shaped like the content image.

    That image is optimised with Adam for exactly ``args.num_steps_break``
    steps. The objective is the content loss plus the style loss plus
    ``args.beta`` times the total-variation loss. The result is written to
    ``args.output_img_path``.

    :param args: argparse.Namespace holding the attributes unpacked below.
    """
    # Unpack command-line arguments.
    style_img_path = args.style_img_path
    cont_img_path = args.cont_img_path
    learn_rate = args.learn_rate
    loss_content_layers = args.loss_content_layers
    loss_style_layers = args.loss_style_layers
    content_weights = args.content_weights
    style_weights = args.style_weights
    num_steps_break = args.num_steps_break
    beta = args.beta
    style_target_resize = args.style_target_resize
    cont_target_resize = args.cont_target_resize
    output_img_path = args.output_img_path

    # Load the style image and add a batch dimension.
    style_img = utils.imread(style_img_path)
    style_img = utils.imresize(style_img, style_target_resize)
    style_img = style_img[np.newaxis, :].astype(np.float32)

    # Prefix layer names with the 'vgg' scope used below and add the ':0' output suffix.
    loss_style_layers = ['vgg/' + i + ':0' for i in loss_style_layers]
    loss_content_layers = ['vgg/' + i + ':0' for i in loss_content_layers]

    # Get target Gram matrices from the style image.
    with tf.variable_scope('vgg'):
        X_vgg = tf.placeholder(tf.float32, shape=style_img.shape, name='input')
        vggnet = vgg16.vgg16(X_vgg)
    with tf.Session() as sess:
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        print('Precomputing target style layers.')
        target_grams = sess.run(utils.get_grams(loss_style_layers),
                                feed_dict={'vgg/input:0': style_img})

    # Clean up so VGG can be re-created at the size of the content image.
    print('Resetting default graph.')
    tf.reset_default_graph()

    # Read in + resize the content image.
    cont_img = utils.imread(cont_img_path)
    cont_img = utils.imresize(cont_img, cont_target_resize)
    cont_img = cont_img[np.newaxis, :].astype(np.float32)

    # Build VGG on top of a white-noise image variable that will be optimised.
    shape = cont_img.shape
    with tf.variable_scope('to_train'):
        white_noise = np.random.rand(shape[0], shape[1], shape[2], shape[3]) * 255.0
        white_noise = tf.constant(white_noise.astype(np.float32))
        X = tf.get_variable('input', dtype=tf.float32, initializer=white_noise)
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(X)

    # Gram-matrix tensors for the style loss and feature tensors for the content loss.
    input_img_grams = utils.get_grams(loss_style_layers)
    content_layers = utils.get_layers(loss_content_layers)

    # Content targets: feed the content image in place of the noise variable.
    with tf.Session() as sess:
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        print('Precomputing target content layers.')
        content_targets = sess.run(content_layers, feed_dict={'to_train/input:0': cont_img})

    cont_loss = losses.content_loss(content_layers, content_targets, content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams, style_weights)
    tv_loss = losses.tv_loss(X)
    loss = cont_loss + style_loss + beta * tv_loss

    # Only the image variable is trained, never VGG.
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='to_train')
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = tf.train.AdamOptimizer(learn_rate).minimize(loss, global_step, train_vars)
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        # global_step was just initialised to 0 and each train_op run increments it
        # once, so `step` equals the global step before the update and the loop
        # performs exactly num_steps_break optimisation steps.
        for step in range(num_steps_break):
            _, loss_out = sess.run([train_op, loss])
            if step % 10 == 0:
                print(step, loss_out)
        # Fetch the optimised image.
        img_out = sess.run(X)

    img_out = np.squeeze(img_out)
    utils.imwrite(output_img_path, img_out)
feed_dict={image: style_pre}) # create image merging content and style g = tf.Graph() with g.as_default(), g.device('/gpu:0'), tf.Session() as sess: # init randomly # white noise target = tf.random_normal((1, ) + content.shape) target_pre_var = tf.Variable(target) # build model with empty layer activations for generated target image model = network_model.get_model(target_pre_var) # compute loss cont_cost = losses.content_loss(content_out, model, C_LAYER, options.content_weight) style_cost = losses.style_loss(style_out, model, S_LAYERS, style_weight_layer) tv_cost = losses.total_var_loss(target_pre_var, options.tv_weight) total_loss = cont_cost + tf.add_n(style_cost) + tv_cost # total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss') train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss) sess.run(tf.global_variables_initializer()) min_loss = float("inf") best = None for i in range(options.iter): train_step.run() print('Iteration %d/%d' % (i + 1, options.iter))
def main(unused_agrv=None):
    """Fine-tune the conditional-instance-norm parameters of a multi-style transform net.

    Reads every setting from ``FLAGS`` and builds:

    * a style/content input pipeline;
    * the ``styleNet`` transform network;
    * a VGG16 loss network on the transform network's output.

    All ``styleNet`` variables except the ``CondInstNorm`` ones are restored
    from ``FLAGS.checkpoint``. Only the ``CondInstNorm`` variables are trained,
    and they are finally saved to ``models/<model_name>_final.ckpt``.

    Periodic output:

    * a full-graph checkpoint in ``training/`` every 1000 steps;
    * TensorBoard summaries in ``./summaries/train/<run_name>`` every 10 steps.

    :param unused_agrv: ignored.
    :raises ValueError: if the number of style coefficients differs from
        ``FLAGS.num_styles``.
    """
    # Unpack command-line flags.
    train_dir = FLAGS.train_dir
    style_dataset = FLAGS.style_dataset
    model_name = FLAGS.model_name
    preprocess_size = [FLAGS.image_size, FLAGS.image_size]
    batch_size = FLAGS.batch_size
    n_epochs = FLAGS.n_epochs
    run_name = FLAGS.run_name
    checkpoint = FLAGS.checkpoint
    learn_rate = FLAGS.learning_rate
    content_weights = FLAGS.content_weights
    style_weights = FLAGS.style_weights
    num_pipe_buffer = FLAGS.num_pipe_buffer
    style_coefficients = FLAGS.style_coefficients
    num_styles = FLAGS.num_styles
    train_steps = FLAGS.train_steps
    upsample_method = FLAGS.upsample_method

    # Input pipeline, pinned to the CPU so the GPU only runs the networks.
    files = tf.train.match_filenames_once(train_dir + '/train-*')
    style_files = tf.train.match_filenames_once(style_dataset)
    print("style %s" % style_files)
    with tf.variable_scope('input_pipe'), tf.device('/cpu:0'):
        _, style_labels, style_grams = datapipe.style_batcher(
            style_files, batch_size, preprocess_size, n_epochs, num_pipe_buffer)
        batch_op = datapipe.batcher(files, batch_size, preprocess_size, n_epochs, num_pipe_buffer)

    # Per-style loss coefficients, gathered for each example by its style label.
    if style_coefficients is None:
        style_coefficients = [1.0] * num_styles
    else:
        style_coefficients = ast.literal_eval(style_coefficients)
    if len(style_coefficients) != num_styles:
        raise ValueError('number of style coeffients differs from number of styles')
    style_coefficient = tf.gather(tf.constant(style_coefficients), style_labels)

    # Layer weights; the style weights are scaled by the per-example coefficient.
    content_weights = ast.literal_eval(content_weights)
    style_weights = ast.literal_eval(style_weights)
    style_weights = {key: style_coefficient * val for key, val in style_weights.items()}
    target_grams = [style_grams[name] for name in style_weights]

    # Tensor names are the layer names plus the ':0' output suffix. These lists
    # iterate the same (unmodified) dicts as `target_grams`, so they stay aligned.
    loss_style_layers = [key + ':0' for key in style_weights]
    loss_style_weights = list(style_weights.values())
    loss_content_layers = [key + ':0' for key in content_weights]
    loss_content_weights = list(content_weights.values())

    # Image transformation network.
    shape = [batch_size] + preprocess_size + [3]
    with tf.variable_scope('styleNet'):
        X = tf.placeholder(tf.float32, shape=shape, name='input')
        Y = transform(X, style_labels, num_styles, upsample_method)
        print(Y)

    # Connect VGG directly to the transformation network's output.
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(Y)

    # Gram-matrix tensors for the style loss and feature tensors for the content loss.
    input_img_grams = utils.get_grams(loss_style_layers)
    content_layers = utils.get_layers(loss_content_layers)

    # Content targets are fed on every step (see the training loop).
    content_targets = tuple(
        tf.placeholder(tf.float32, shape=layer.get_shape(), name='content_input_{}'.format(i))
        for i, layer in enumerate(content_layers))
    cont_loss = losses.content_loss(content_layers, content_targets, loss_content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams, loss_style_weights)
    loss = cont_loss + style_loss

    with tf.name_scope('summaries'):
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('content_loss', cont_loss)

    # Only the CondInstNorm parameters are trained; every other styleNet variable
    # is restored from `checkpoint`. Variables are looked up by scope prefix in the
    # graph collections (tf.get_variable_scope() takes no scope name and does not
    # yield variables).
    other_vars = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='styleNet')
                  if 'CondInstNorm' not in var.name]
    train_vars = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='styleNet')
                  if 'CondInstNorm' in var.name]

    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learn_rate).minimize(loss, global_step, train_vars)

    # TensorBoard run directory: use `run_name` if given, otherwise the first
    # unused '<model_name><n>' under ./summaries/train/.
    os.makedirs('./summaries/train/', exist_ok=True)
    if run_name is None:
        current_dirs = {name for name in os.listdir('./summaries/train/')
                        if os.path.isdir('./summaries/train/' + name)}
        count = 0
        name = model_name + '0'
        while name in current_dirs:
            count += 1
            name = model_name + '{}'.format(count)
        run_name = name

    os.makedirs('./training', exist_ok=True)  # checkpoints written during training
    os.makedirs('./models', exist_ok=True)  # final models

    saver = tf.train.Saver()
    pretrained_saver = tf.train.Saver(other_vars)
    final_saver = tf.train.Saver(train_vars)
    merged = tf.summary.merge_all()
    full_log_path = './summaries/train/' + run_name
    # Pass the default graph directly rather than creating (and leaking) a Session for it.
    train_writer = tf.summary.FileWriter(full_log_path, tf.get_default_graph())

    # Local variables must be initialised too because of the batch pipeline.
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    print('Starting training...')
    with tf.Session() as sess:
        sess.run(init_op)
        vggnet.load_weights(vgg16.checkpoint_file(), sess)
        pretrained_saver.restore(sess, checkpoint)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                current_step = sess.run(global_step)
                batch = sess.run(batch_op)
                # Content targets: VGG features of the batch itself, obtained by
                # feeding it in place of the transformation network's output.
                content_data = sess.run(content_layers, feed_dict={Y: batch})
                feed_dict = {X: batch, content_targets: content_data}

                if current_step % 1000 == 0:
                    save_path = 'training/' + model_name + '.ckpt'
                    saver.save(sess, save_path, global_step=global_step)
                if current_step % 10 == 0:
                    # Train and collect diagnostic data for TensorBoard.
                    summary, _, loss_out, c_loss, s_loss = sess.run(
                        [merged, optimizer, loss, cont_loss, style_loss], feed_dict=feed_dict)
                    train_writer.add_summary(summary, current_step)
                    print(current_step, loss_out, c_loss, s_loss)
                else:
                    _, loss_out = sess.run([optimizer, loss], feed_dict=feed_dict)

                # Stop once the configured number of steps is reached.
                if current_step == train_steps:
                    print('Done training.')
                    break
        except tf.errors.OutOfRangeError:
            print('Done training.')
        finally:
            # Save the image transformation network's trained variables for
            # later use in predict.py.
            final_saver.save(sess, 'models/' + model_name + '_final.ckpt', write_meta_graph=False)
            coord.request_stop()
            train_writer.close()
            coord.join(threads)
loss += (analogy_weight / len(analogy_layers)) * al if mrf_weight != 0.0: for layer_name in mrf_layers: ap_image_features = K.variable(all_ap_image_features[layer_name][0]) layer_features = outputs_dict[layer_name] combination_features = layer_features[0, :, :, :] sl = losses.mrf_loss(ap_image_features, combination_features, patch_size=patch_size, patch_stride=patch_stride) loss += (mrf_weight / len(mrf_layers)) * sl if b_bp_content_weight != 0.0: for layer_name in b_content_layers: b_features = K.variable(all_b_features[layer_name][0]) bp_features = outputs_dict[layer_name] cl = losses.content_loss(bp_features, b_features) loss += b_bp_content_weight / len(b_content_layers) * cl loss += total_variation_weight * losses.total_variation_loss(vgg_input, img_width, img_height) # get the gradients of the generated image wrt the loss grads = K.gradients(loss, vgg_input) outputs = [loss] if type(grads) in {list, tuple}: outputs += grads else: outputs.append(grads) f_outputs = K.function([vgg_input], outputs) evaluator = Evaluator()
def main(argv=None):
    """Train a fast style-transfer network with a step-decayed learning rate.

    Content/style layer names and the per-style-layer weights are parsed from
    comma-separated flags. Only variables outside ``FLAGS.loss_model``'s name
    scope are trained and saved under ``<FLAGS.model_path>/<FLAGS.naming>``.

    Training resumes from the newest checkpoint there. It stops once the
    global step reaches ``FLAGS.max_iter`` or the input pipeline runs out of
    epochs; both paths save ``fast-style-model-done``.

    :param argv: unused.
    """
    content_layers = FLAGS.content_layers.split(',')
    style_layers = FLAGS.style_layers.split(',')
    style_layers_weights = [float(i) for i in FLAGS.style_layers_weights.split(",")]
    # The learning rate is divided by 10 every `num_steps_decay` steps.
    num_steps_decay = 10000

    style_features_t = losses.get_style_features(FLAGS)

    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Session() as sess:
        # Build network.
        network_fn = nets_factory.get_network_fn(FLAGS.loss_model, num_classes=1, is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model, is_training=False)
        processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                        'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
        generated = model.net(processed_images, FLAGS.alpha)
        processed_generated = [
            image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
            for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
        ]
        processed_generated = tf.stack(processed_generated)
        _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0),
                                       spatial_squeeze=False)

        # Build losses; the content/style/TV totals are scaled by their flag weights.
        content_loss = losses.content_loss(endpoints_dict, content_layers)
        style_loss, style_losses = losses.style_loss(endpoints_dict, style_features_t, style_layers,
                                                     style_layers_weights)
        tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image
        content_loss = FLAGS.content_weight * content_loss
        style_loss = FLAGS.style_weight * style_loss
        tv_loss = FLAGS.tv_weight * tv_loss
        loss = style_loss + content_loss + tv_loss

        # Prepare to train.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        variable_to_train = [v for v in tf.trainable_variables()
                             if not v.name.startswith(FLAGS.loss_model)]
        # Start at 0.1 and multiply by 0.1 every num_steps_decay steps (staircase).
        lr = tf.train.exponential_decay(learning_rate=1e-1,
                                        global_step=global_step,
                                        decay_steps=num_steps_decay,
                                        decay_rate=1e-1,
                                        staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-8)
        train_op = optimizer.minimize(loss, global_step=global_step, var_list=variable_to_train)

        # Collected after minimize() so optimizer-created variables are saved too.
        variables_to_restore = [v for v in tf.global_variables()
                                if not v.name.startswith(FLAGS.loss_model)]
        saver = tf.train.Saver(variables_to_restore)

        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        init_func = utils._get_init_fn(FLAGS)
        init_func(sess)
        last_file = tf.train.latest_checkpoint(training_path)
        if last_file:
            saver.restore(sess, last_file)

        # Start training.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                _, c_loss, s_losses, t_loss, total_loss, step = sess.run([
                    train_op, content_loss, style_losses, tv_loss, loss, global_step
                ])
                if step % 10 == 0:
                    print(step, c_loss, s_losses, t_loss, total_loss)
                if step % 10000 == 0:
                    saver.save(sess, os.path.join(training_path, 'fast-style-model'), global_step=step)
                # `>=` so that a run resumed from a checkpoint already past max_iter
                # also stops, instead of training until the epochs run out.
                if step >= FLAGS.max_iter:
                    saver.save(sess, os.path.join(training_path, 'fast-style-model-done'))
                    break
        except tf.errors.OutOfRangeError:
            saver.save(sess, os.path.join(training_path, 'fast-style-model-done'))
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
for layer_name in mrf_layers: ap_image_features = K.variable( all_ap_image_features[layer_name][0]) layer_features = outputs_dict[layer_name] combination_features = layer_features[0, :, :, :] sl = losses.mrf_loss(ap_image_features, combination_features, patch_size=patch_size, patch_stride=patch_stride) loss += (mrf_weight / len(mrf_layers)) * sl if b_bp_content_weight != 0.0: for layer_name in b_content_layers: b_features = K.variable(all_b_features[layer_name][0]) bp_features = outputs_dict[layer_name] cl = losses.content_loss(bp_features, b_features) loss += b_bp_content_weight / len(b_content_layers) * cl loss += total_variation_weight * losses.total_variation_loss( vgg_input, img_width, img_height) # get the gradients of the generated image wrt the loss grads = K.gradients(loss, vgg_input) outputs = [loss] if type(grads) in {list, tuple}: outputs += grads else: outputs.append(grads) f_outputs = K.function([vgg_input], outputs)
def train(**kwargs):
    """Train the fast style-transfer generator with a VGG-16 loss network.

    Reads every regular file in ``opt.data_root`` through a TF1 queue
    pipeline and passes each batch through the generator ``net``. The
    generated batch and the original batch are then fed through
    ``vgg.vgg_16`` to build content, style and total-variation losses.

    Only variables whose names do not start with ``"vgg_16"`` are trained and
    checkpointed (under ``./logs/model``). TensorBoard summaries are written
    to ``./logs``. Training ends when the input queue raises
    ``OutOfRangeError`` after ``opt.epoches`` epochs.

    Args:
        **kwargs: attribute overrides applied with ``setattr`` onto a fresh
            ``Config()`` instance.

    Raises:
        ValueError: if ``opt.data_root`` contains no regular files.
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # ---------------- Data loading ----------------
    # Collect the names of all regular files in the data directory.
    filenames = [os.path.join(opt.data_root, f) for f in os.listdir(opt.data_root)
                 if os.path.isfile(os.path.join(opt.data_root, f))]
    # Fail with a clear message instead of an IndexError on filenames[0].
    if not filenames:
        raise ValueError('No training files found in {}'.format(opt.data_root))
    # Pick the decoder from the first file's extension: True -> PNG, False -> JPEG.
    png = filenames[0].lower().endswith('png')  # If first file is a png, assume they all are
    # File-name queue. num_epochs creates a local variable, which is why the
    # local-variable initializer is run below.
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True, num_epochs=opt.epoches)
    # Reader yielding (key, value) pairs: (file name, whole file contents).
    # Named file_reader to avoid shadowing the module-level `reader`.
    file_reader = tf.WholeFileReader()
    _, img_bytes = file_reader.read(filename_queue)
    # Decode the raw bytes into a 3-channel image.
    image_row = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3)
    # Preprocess, then batch.
    image = utils.img_proprocess(image_row, opt.image_size)
    image_batch = tf.train.batch([image], opt.batch_size, dynamic_pad=True)

    # ---------------- Generator network ----------------
    generated = net(image_batch, training=True)
    generated = tf.image.resize_bilinear(generated, [opt.image_size, opt.image_size], align_corners=False)
    generated.set_shape([opt.batch_size, opt.image_size, opt.image_size, 3])
    # tf.unstack removes the split axis, so each element is a single image.
    # (tf.split would keep that axis with size 1.)
    processed_generated = tf.stack([utils.img_proprocess(img, opt.image_size)
                                    for img in tf.unstack(generated, axis=0)])

    # ---------------- Loss network (VGG-16) ----------------
    # One forward pass over 2 * batch_size images:
    # [generated batch, original batch].
    with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0)):
        _, endpoint = vgg.vgg_16(tf.concat([processed_generated, image_batch], 0),
                                 num_classes=1, is_training=False, spatial_squeeze=False)
    tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
    for key in endpoint:
        tf.logging.info(key)

    # ---------------- Losses ----------------
    style_gram = utils.get_style_feature(opt.style_path, opt.image_size, opt.style_layers,
                                         opt.model_path, opt.exclude_scopes)
    content_loss, _ = losses.content_loss(endpoint, opt.content_layers)
    style_loss, style_loss_summary = losses.style_loss(endpoint, style_gram, opt.style_layers)
    # TV loss uses the unprocessed generated image, which is the image we
    # actually want as output.
    tv_loss = losses.total_variation_loss(generated)
    loss = opt.style_weight * style_loss + opt.content_weight * content_loss + opt.tv_weight * tv_loss

    # ---------------- Optimizer ----------------
    # Optimize only non-VGG trainable variables; VGG is a frozen loss network.
    variables_to_train = [v for v in tf.trainable_variables() if not v.name.startswith("vgg_16")]
    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variables_to_train)

    # ---------------- Saver ----------------
    # Save every non-VGG global variable. Compared with variables_to_train,
    # this also includes non-trainable state such as global_step and the
    # optimizer's variables. Those are created by minimize() above, so they
    # must be collected after it.
    variables_to_restore = [v for v in tf.global_variables() if not v.name.startswith("vgg_16")]
    saver = tf.train.Saver(var_list=variables_to_restore, write_version=tf.train.SaverDef.V2)

    # ---------------- Summaries ----------------
    # Overall losses, raw and weighted.
    tf.summary.scalar('losses/content_loss', content_loss)
    tf.summary.scalar('losses/style_loss', style_loss)
    tf.summary.scalar('losses/regularizer_loss', tv_loss)
    tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * opt.content_weight)
    tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * opt.style_weight)
    tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * opt.tv_weight)
    tf.summary.scalar('total_loss', loss)
    # Per-layer style losses.
    for layer in opt.style_layers:
        tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
    # Generated and original images.
    tf.summary.image('generated', generated)
    tf.summary.image('origin', image_batch)
    summary_path = "./logs"
    model_path = "./logs/model"
    summary = tf.summary.merge_all()

    # ---------------- Training ----------------
    # NOTE(review): `config` is not defined in this function; presumably a
    # module-level tf.ConfigProto -- confirm.
    with tf.Session(config=config) as sess:
        writer = tf.summary.FileWriter(summary_path, sess.graph)
        sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
        # Load the pretrained VGG-16 weights. The saver excludes VGG
        # variables, so they come only from here.
        param_init_fn = utils.param_load_fn(opt.model_path, opt.exclude_scopes)
        param_init_fn(sess)
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        # Resume the generator from the latest checkpoint, if any.
        ckpt = tf.train.get_checkpoint_state(model_path)
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Success to read {}".format(ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            tf.logging.info("Failed to find a checkpoint")
        coord = tf.train.Coordinator()  # thread coordinator
        threads = tf.train.start_queue_runners(coord=coord)  # start the input queues
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()
                if step % 50 == 0:
                    tf.logging.info('step: {0:d}, total Loss {1:.2f}, secs/step: {2:.3f}'.
                                    format(step, loss_t, elapsed_time))
                if step % 100 == 0:
                    # NOTE: this separate run dequeues a fresh batch, so the
                    # summary reflects a different batch than the one just
                    # trained on.
                    tf.logging.info('adding summary...')
                    summary_str = sess.run(summary)
                    writer.add_summary(summary_str, step)
                    writer.flush()
                if step % 1000 == 0:
                    saver.save(sess, os.path.join(model_path, 'fast_style_model'), global_step=step)
        except tf.errors.OutOfRangeError:
            # The input queue is exhausted after opt.epoches epochs.
            saver.save(sess, os.path.join(model_path, 'fast_style_model'))
            tf.logging.info('Epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
            writer.close()
def main(FLAGS):
    """Train a fast style-transfer model configured by ``FLAGS``.

    Builds the loss network selected by ``FLAGS.loss_model`` and reads
    training images with ``reader.image``. It generates images with
    ``model.net`` and minimizes
    ``style_weight * style_loss + content_weight * content_loss +
    tv_weight * tv_loss`` with Adam. Only variables whose names do not start
    with ``FLAGS.loss_model`` are trained.

    Checkpoints and TensorBoard events are written to
    ``os.path.join(FLAGS.model_path, FLAGS.naming)``.

    The training-image folder comes from the optional attribute
    ``FLAGS.train_images_path``. When that attribute is absent, it falls back
    to the previously hard-coded ``'F:/CASIA/train_frame/real/'``.
    """
    # Precompute the style targets from the style image.
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    # Any folder of images can be used here (COCO is not required). The
    # default keeps the original hard-coded location.
    train_images_path = getattr(FLAGS, 'train_images_path', 'F:/CASIA/train_frame/real/')

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # ---------------- Build network ----------------
            # Loss network chosen by name (see networks_map in nets/nets_factory.py).
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model, num_classes=1, is_training=False)
            # Preprocessing matched to the chosen loss network.
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            # Read and preprocess one batch of training images.
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_height, FLAGS.image_width,
                                            train_images_path, image_preprocessing_fn, epochs=FLAGS.epoch)
            # Generated images (y_hat) from the transform network.
            generated = model.net(processed_images, training=True)
            # The generated images are fed to the loss network, so preprocess
            # them the same way as the inputs.
            processed_generated = tf.stack([
                image_preprocessing_fn(image, FLAGS.image_height, FLAGS.image_width)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ])
            # Concatenate along the batch axis ([2 * batch_size, h, w, c]), so
            # one forward pass computes features for both y_hat and y_c.
            _, endpoints_dict = network_fn(tf.concat(
                [processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of loss network
            tf.logging.info(
                'Loss network layers(You can define them in "content_layers" and "style_layers"):'
            )
            for key in endpoints_dict:
                tf.logging.info(key)

            # ---------------- Build losses ----------------
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(
                generated)  # use the unprocessed image
            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + \
                FLAGS.tv_weight * tv_loss

            # ---------------- TensorBoard summaries ----------------
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)
            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)
            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image(
                'origin',
                tf.stack([
                    image_unprocessing_fn(image) for image in tf.unstack(
                        processed_images, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # ---------------- Prepare to train ----------------
            global_step = tf.Variable(0, name="global_step", trainable=False)
            # Train only the non-loss-network variables.
            variable_to_train = [v for v in tf.trainable_variables()
                                 if not v.name.startswith(FLAGS.loss_model)]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            # Checkpoint every non-loss-network global variable. This also
            # covers global_step and the optimizer state created by
            # minimize() above.
            variables_to_restore = [v for v in tf.global_variables()
                                    if not v.name.startswith(FLAGS.loss_model)]
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            # Restore variables for loss network (slim init function built
            # from the FLAGS configuration).
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # ---------------- Start training ----------------
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    print(step)
                    if step % 10 == 0:
                        tf.logging.info(
                            'step: %d, total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(
                    sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)
                # Close the event file; previously the writer was never closed.
                writer.close()