from datetime import datetime

import numpy as np
import tensorflow as tf

# `VGG`, `itn` (the image transform net), `preprocess`, `get_images`, and
# `save_images` are project-local helpers assumed to be importable in this module.


def _handler2(content_path, model_path, save_path=None, prefix=None, suffix=None):
    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph; (1, None, None, 3) lets each image keep its own size
        content_image = tf.placeholder(
            tf.float32, shape=(1, None, None, 3), name='content_image')
        output_image = itn.transform(content_image)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = []
        for content in content_path:
            content_target = get_images(content)
            result = sess.run(output_image,
                              feed_dict={content_image: content_target})
            output.append(result[0])

        if save_path is not None:
            save_images(content_path, output, save_path,
                        prefix=prefix, suffix=suffix)

        return output
def _handler1(content_path, model_path, resize_height=None, resize_width=None,
              save_path=None, prefix=None, suffix=None):
    # get the actual image data,
    # output shape: (num_images, height, width, color_channels)
    content_target = get_images(content_path, resize_height, resize_width)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content_image = tf.placeholder(
            tf.float32, shape=content_target.shape, name='content_image')
        output_image = itn.transform(content_image)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = sess.run(output_image,
                          feed_dict={content_image: content_target})

        if save_path is not None:
            save_images(content_path, output, save_path,
                        prefix=prefix, suffix=suffix)

        return output
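# --- Usage sketch: an illustrative assumption, not part of the original code ---
# The two handlers trade flexibility for speed: _handler1 stacks all inputs into
# one fixed-shape batch and runs the network once, while _handler2 feeds images
# one at a time through a (1, None, None, 3) placeholder, so they may differ in
# size. All file and checkpoint paths below are hypothetical.
def _example_generate():
    contents = ['img/content1.jpg', 'img/content2.jpg']  # hypothetical inputs
    ckpt = 'models/style.ckpt-done'                      # hypothetical checkpoint
    # same-sized inputs, single batched run
    _handler1(contents, ckpt, resize_height=512, resize_width=512,
              save_path='outputs', prefix='batched-')
    # arbitrary-sized inputs, one image per run
    _handler2(contents, ckpt, save_path='outputs', prefix='free-size-')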
def train(content_targets_path, style_target_path, content_weight, style_weight,
          tv_weight, vgg_path, save_path, debug=False, logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # guarantee the size of content_targets is a multiple of BATCH_SIZE
    mod = len(content_targets_path) % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed %d samples...' % mod)
        content_targets_path = content_targets_path[:-mod]

    height, width, channels = TRAINING_IMAGE_SHAPE
    input_shape = (BATCH_SIZE, height, width, channels)

    # load the pre-trained VGG network
    vgg = VGG(vgg_path)

    # retrieve the style_target image
    style_target = get_images(style_target_path)  # shape: (1, height, width, channels)
    style_shape = style_target.shape

    # compute the style features
    style_features = {}
    with tf.Graph().as_default(), tf.Session() as sess:
        style_image = tf.placeholder(
            tf.float32, shape=style_shape, name='style_image')

        # pass style_image through the pre-trained VGG-19 network
        style_img_preprocess = preprocess(style_image)
        style_net = vgg.forward(style_img_preprocess)

        for style_layer in STYLE_LAYERS:
            features = style_net[style_layer].eval(
                feed_dict={style_image: style_target})
            features = np.reshape(features, [-1, features.shape[3]])
            gram = np.matmul(features.T, features) / features.size
            style_features[style_layer] = gram

    # compute the perceptual losses
    with tf.Graph().as_default(), tf.Session() as sess:
        content_images = tf.placeholder(
            tf.float32, shape=input_shape, name='content_images')

        # pass content_images through the pre-trained VGG-19 network
        content_imgs_preprocess = preprocess(content_images)
        content_net = vgg.forward(content_imgs_preprocess)

        # compute the content features
        content_features = {}
        content_features[CONTENT_LAYER] = content_net[CONTENT_LAYER]

        # pass content_images through the Image Transform Net
        output_images = itn.transform(content_images)

        # pass output_images through the pre-trained VGG-19 network
        output_imgs_preprocess = preprocess(output_images)
        output_net = vgg.forward(output_imgs_preprocess)

        # ** compute the feature reconstruction loss **
        content_size = tf.size(content_features[CONTENT_LAYER])
        content_loss = 2 * tf.nn.l2_loss(
            output_net[CONTENT_LAYER] - content_features[CONTENT_LAYER]
        ) / tf.to_float(content_size)

        # ** compute the style reconstruction loss **
        style_losses = []
        for style_layer in STYLE_LAYERS:
            features = output_net[style_layer]
            shape = tf.shape(features)
            num_images, height, width, num_filters = \
                shape[0], shape[1], shape[2], shape[3]
            features = tf.reshape(
                features, [num_images, height * width, num_filters])
            grams = tf.matmul(features, features, transpose_a=True) / \
                tf.to_float(height * width * num_filters)
            style_gram = style_features[style_layer]
            layer_style_loss = 2 * tf.nn.l2_loss(grams - style_gram) / \
                tf.to_float(tf.size(grams))
            style_losses.append(layer_style_loss)
        style_loss = tf.reduce_sum(tf.stack(style_losses))

        # ** compute the total variation loss **
        shape = tf.shape(output_images)
        height, width = shape[1], shape[2]
        y = tf.slice(output_images, [0, 0, 0, 0], [-1, height - 1, -1, -1]) - \
            tf.slice(output_images, [0, 1, 0, 0], [-1, -1, -1, -1])
        x = tf.slice(output_images, [0, 0, 0, 0], [-1, -1, width - 1, -1]) - \
            tf.slice(output_images, [0, 0, 1, 0], [-1, -1, -1, -1])
        tv_loss = tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + \
            tf.nn.l2_loss(y) / tf.to_float(tf.size(y))

        # overall perceptual loss
        loss = content_weight * content_loss + \
            style_weight * style_loss + tv_weight * tv_loss

        # training step
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        n_batches = len(content_targets_path) // BATCH_SIZE

        if debug:
            elapsed_time = datetime.now() - start_time
            tf.logging.set_verbosity(tf.logging.INFO)
            tf.logging.info(
                'Elapsed time for preprocessing before actually training the model: %s'
                % elapsed_time)
            tf.logging.info('Now begin to train the model...')
            start_time = datetime.now()

        for epoch in range(EPOCHS):
            np.random.shuffle(content_targets_path)

            for batch in range(n_batches):
                # retrieve a batch of content images
                content_batch_path = content_targets_path[
                    batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]
                content_batch = get_images(
                    content_batch_path, input_shape[1], input_shape[2])

                # run the training step
                sess.run(train_op, feed_dict={content_images: content_batch})

                step += 1

                if step % 1000 == 0:
                    saver.save(sess, save_path, global_step=step)

                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)
                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _content_loss, _style_loss, _tv_loss, _loss = sess.run(
                            [content_loss, style_loss, tv_loss, loss],
                            feed_dict={content_images: content_batch})

                        tf.logging.info('step: %d, total loss: %f, elapsed time: %s'
                                        % (step, _loss, elapsed_time))
                        tf.logging.info('content loss: %f, weighted content loss: %f'
                                        % (_content_loss, content_weight * _content_loss))
                        tf.logging.info('style loss  : %f, weighted style loss  : %f'
                                        % (_style_loss, style_weight * _style_loss))
                        tf.logging.info('tv loss     : %f, weighted tv loss     : %f'
                                        % (_tv_loss, tv_weight * _tv_loss))
                        tf.logging.info('\n')

        # ** Done Training & Save the model **
        saver.save(sess, save_path)

        if debug:
            elapsed_time = datetime.now() - start_time
            tf.logging.info('Done training! Elapsed time: %s' % elapsed_time)
            tf.logging.info('Model is saved to: %s' % save_path)
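# --- Gram-matrix sketch: illustrative, not part of the original module ---
# Both train() variants summarize a style layer as a Gram matrix: flatten the
# (1, H, W, C) activations to an (H*W, C) matrix F, then compute
# F^T F / (H*W*C). This NumPy version mirrors the in-graph computation above.
def _gram_matrix(features):
    # features: ndarray of shape (1, height, width, channels)
    f = np.reshape(features, (-1, features.shape[3]))  # (H*W, C)
    return np.matmul(f.T, f) / features.size           # (C, C), normalized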
# A second variant of train(): no debug flag, GPU memory growth enabled,
# and it returns the logged loss histories for later inspection.
def train(content_path, style_path, content_weight, style_weight, tv_weight,
          vgg_path, save_path):
    height, width, channels = IMG_SHAPE
    input_shape = (batch_size, height, width, channels)

    start_time = datetime.now()

    vgg = VGG(vgg_path)

    style_target = get_images(style_path, height, width)
    style_shape = style_target.shape
    style_features = {}

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # style net
    with tf.Session(config=config) as sess:
        style_img = tf.placeholder(
            tf.float32, shape=style_shape, name='style_image')
        style_net = vgg.forward(preprocess(style_img))
        for layer in STYLE:
            features = style_net[layer].eval(
                feed_dict={style_img: style_target})
            features = np.reshape(features, [-1, features.shape[3]])
            gram = np.matmul(features.T, features) / features.size
            style_features[layer] = gram

    # content net
    with tf.Session(config=config) as sess:
        content_img = tf.placeholder(
            tf.float32, shape=input_shape, name='content_img')
        content_net = vgg.forward(preprocess(content_img))
        content_features = content_net[CONTENT]

        trans_images = itn.transform(content_img)
        output_net = vgg.forward(preprocess(trans_images))

        # content reconstruction loss
        content_size = tf.size(content_features)
        content_loss = tf.nn.l2_loss(
            output_net[CONTENT] - content_features) * 2 / tf.to_float(content_size)

        # style reconstruction loss
        style_losses = []
        for layer in STYLE:
            features = output_net[layer]
            shape = tf.shape(features)
            num_images, height, width, num_filters = \
                shape[0], shape[1], shape[2], shape[3]
            features = tf.reshape(
                features, [num_images, height * width, num_filters])
            grams = tf.matmul(features, features, transpose_a=True) / \
                tf.to_float(height * width * num_filters)
            style_gram = style_features[layer]
            layer_style_loss = tf.nn.l2_loss(grams - style_gram) * 2 / \
                tf.to_float(tf.size(grams))
            style_losses.append(layer_style_loss)
        style_loss = tf.reduce_sum(tf.stack(style_losses))

        # total variation loss
        shape = tf.shape(trans_images)
        height, width = shape[1], shape[2]
        y = tf.slice(trans_images, [0, 0, 0, 0], [-1, height - 1, -1, -1]) - \
            tf.slice(trans_images, [0, 1, 0, 0], [-1, -1, -1, -1])
        x = tf.slice(trans_images, [0, 0, 0, 0], [-1, -1, width - 1, -1]) - \
            tf.slice(trans_images, [0, 0, 1, 0], [-1, -1, -1, -1])
        tv_loss = tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + \
            tf.nn.l2_loss(y) / tf.to_float(tf.size(y))

        # overall perceptual loss
        loss = content_weight * content_loss + \
            style_weight * style_loss + tv_weight * tv_loss

        # training step
        train_op = tf.train.AdamOptimizer(lr).minimize(loss)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        step = 0
        n_batches = len(content_path) // batch_size

        elapsed_time = datetime.now() - start_time
        tf.logging.set_verbosity(tf.logging.INFO)
        tf.logging.info(
            'Elapsed time for preprocessing before actually training the model: %s'
            % elapsed_time)
        tf.logging.info('Now begin to train the model...')
        start_time = datetime.now()

        c_loss = []
        s_loss = []
        tv = []
        total_loss = []

        for epoch in range(EPOCHS):
            np.random.shuffle(content_path)

            for batch in range(n_batches):
                # retrieve a batch of content images
                content_batch_path = content_path[
                    batch * batch_size:(batch * batch_size + batch_size)]
                content_batch = get_images(
                    content_batch_path, input_shape[1], input_shape[2])

                # run the training step
                sess.run(train_op, feed_dict={content_img: content_batch})

                step += 1

                if step % 1000 == 0:
                    saver.save(sess, save_path, global_step=step)

                is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)
                if is_last_step or step % 100 == 0:
                    elapsed_time = datetime.now() - start_time
                    _content_loss, _style_loss, _tv_loss, _loss = sess.run(
                        [content_loss, style_loss, tv_loss, loss],
                        feed_dict={content_img: content_batch})

                    tf.logging.info('step: %d, total loss: %f, elapsed time: %s'
                                    % (step, _loss, elapsed_time))
                    tf.logging.info('content loss: %f, weighted content loss: %f'
                                    % (_content_loss, content_weight * _content_loss))
                    tf.logging.info('style loss  : %f, weighted style loss  : %f'
                                    % (_style_loss, style_weight * _style_loss))
                    tf.logging.info('tv loss     : %f, weighted tv loss     : %f'
                                    % (_tv_loss, tv_weight * _tv_loss))
                    tf.logging.info('\n')

                    c_loss.append(_content_loss)
                    s_loss.append(_style_loss)
                    tv.append(_tv_loss)
                    total_loss.append(_loss)

        saver.save(sess, save_path)

        return c_loss, s_loss, tv, total_loss
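# --- Training usage sketch: illustrative, not part of the original module ---
# The second train() variant returns the logged loss histories, so convergence
# can be inspected after the run. The dataset location, style image, VGG-19
# weights file, and loss weights below are all hypothetical placeholders.
def _example_train():
    import glob
    content_paths = glob.glob('train2014/*.jpg')      # hypothetical dataset
    c_hist, s_hist, tv_hist, total_hist = train(
        content_paths,
        'styles/mosaic.jpg',                          # hypothetical style image
        content_weight=1.0, style_weight=220.0,       # illustrative weights
        tv_weight=1e-4,
        vgg_path='imagenet-vgg-verydeep-19.mat',      # pre-trained VGG-19 weights
        save_path='models/mosaic.ckpt')
    # the four equal-length histories can be saved or plotted later
    np.save('loss_histories.npy',
            np.array([c_hist, s_hist, tv_hist, total_hist]))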