def style_transfer(content_image, style_image, content_masks, style_masks, init_image, args):
    r"""Run masked neural style transfer and return the lowest-loss image.

    Builds a VGG19 sub-network, extracts fixed content/style target
    activations, then optimizes a transfer-image variable with Adam against
    a weighted sum of content, style, photorealism-regularization and NIMA
    losses.

    Parameters
    ----------
    content_image : array-like
        Content image; preprocessed for VGG inside this function.
    style_image : array-like
        Style image; preprocessed for VGG inside this function.
    content_masks, style_masks :
        Segmentation masks forwarded to the per-layer style loss.
    init_image : array-like
        Initial value for the optimized transfer-image variable.
    args :
        Namespace providing loss weights (``content_weight``, ``style_weight``,
        ``regularization_weight``, ``nima_weight``), Adam hyper-parameters,
        ``matting``, ``iter``, ``print_loss_interval``,
        ``intermediate_result_interval`` and ``results_dir``.

    Returns
    -------
    np.ndarray
        The transfer image with the lowest total loss observed during
        optimization (not necessarily the final iterate).
    """
    print("Style transfer started")
    content_image = vgg.preprocess(content_image)
    style_image = vgg.preprocess(style_image)
    weight_restorer = vgg.load_weights()

    image_placeholder = tf.compat.v1.placeholder(tf.float32, shape=[1, None, None, 3])
    vgg19 = vgg.VGG19ConvSub(image_placeholder)

    with tf.compat.v1.Session() as sess:
        transfer_image = tf.Variable(init_image)
        transfer_image_vgg = vgg.preprocess(transfer_image)
        transfer_image_nima = nima.preprocess(transfer_image)

        sess.run(tf.compat.v1.global_variables_initializer())
        weight_restorer.init(sess)

        # Target activations are computed once from the fixed content/style
        # images; only the transfer image is optimized afterwards.
        content_conv4_2 = sess.run(fetches=vgg19.conv4_2,
                                   feed_dict={image_placeholder: content_image})
        style_conv1_1, style_conv2_1, style_conv3_1, style_conv4_1, style_conv5_1 = sess.run(
            fetches=[vgg19.conv1_1, vgg19.conv2_1, vgg19.conv3_1, vgg19.conv4_1, vgg19.conv5_1],
            feed_dict={image_placeholder: style_image})

        # Rebuild the VGG sub-network on the transfer image, reusing the
        # already-initialized weights.
        with tf.compat.v1.variable_scope("", reuse=True):
            vgg19 = vgg.VGG19ConvSub(transfer_image_vgg)

        content_loss = calculate_layer_content_loss(content_conv4_2, vgg19.conv4_2)

        # Equal-weighted style loss over the five conv*_1 layers.
        style_loss = (1. / 5.) * calculate_layer_style_loss(style_conv1_1, vgg19.conv1_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv2_1, vgg19.conv2_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv3_1, vgg19.conv3_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv4_1, vgg19.conv4_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv5_1, vgg19.conv5_1, content_masks, style_masks)

        photorealism_regularization = calculate_photorealism_regularization(
            transfer_image_vgg, content_image, args.matting)
        nima_loss = compute_nima_loss(transfer_image_nima)

        content_loss = args.content_weight * content_loss
        style_loss = args.style_weight * style_loss
        photorealism_regularization = args.regularization_weight * photorealism_regularization
        nima_loss = args.nima_weight * nima_loss
        total_loss = content_loss + style_loss + photorealism_regularization + nima_loss

        tf.compat.v1.summary.scalar('Content loss', content_loss)
        tf.compat.v1.summary.scalar('Style loss', style_loss)
        tf.compat.v1.summary.scalar('Photorealism Regularization', photorealism_regularization)
        tf.compat.v1.summary.scalar('NIMA loss', nima_loss)
        tf.compat.v1.summary.scalar('Total loss', total_loss)
        summary_op = tf.compat.v1.summary.merge_all()
        summary_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(os.path.dirname(__file__), 'logs/{}'.format(args.results_dir)), sess.graph)

        iterations_dir = os.path.join(args.results_dir, "iterations")
        # makedirs(exist_ok=True): os.mkdir raised FileExistsError on reruns.
        os.makedirs(iterations_dir, exist_ok=True)

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.adam_learning_rate, beta1=args.adam_beta1,
            beta2=args.adam_beta2, epsilon=args.adam_epsilon)
        train_op = optimizer.minimize(total_loss, var_list=[transfer_image])
        sess.run(adam_variables_initializer(optimizer, [transfer_image]))

        min_loss, best_image = float("inf"), None
        for i in range(1, args.iter + 1):
            _, result_image, loss, c_loss, s_loss, p_loss, n_loss, summary = sess.run(
                fetches=[train_op, transfer_image, total_loss, content_loss, style_loss,
                         photorealism_regularization, nima_loss, summary_op])
            summary_writer.add_summary(summary, i)

            if i % args.print_loss_interval == 0:
                print(
                    "Iteration: {0:5}\t"
                    "Total loss: {1:10.2f}\t"
                    "Content loss: {2:10.2f}\t"
                    "Style loss: {3:10.2f}\t "
                    "Photorealism Regularization: {4:10.2f}\t"
                    "NIMA loss: {5:10.2f}".format(i, loss, c_loss, s_loss, p_loss, n_loss)
                )

            # Track the best image seen so far instead of the final iterate.
            if loss < min_loss:
                min_loss, best_image = loss, result_image

            if i % args.intermediate_result_interval == 0:
                save_image(best_image, os.path.join(iterations_dir, "iter_{}.png".format(i)))

    return best_image
def style_transfer(content_image, color_to_gram_dict, content_masks, init_image, result_dir, timestamp, args):
    r"""Run style transfer from precomputed per-color style Gram matrices.

    Variant of style transfer where the style targets are supplied as
    Gram matrices keyed by mask color (one dict per conv*_1 layer) instead
    of being computed from a style image.

    Parameters
    ----------
    content_image : array-like
        Content image; preprocessed for VGG inside this function.
    color_to_gram_dict : dict
        Maps each mask color to a sequence of 5 Gram matrices, one per
        conv*_1 layer (index 0 -> conv1_1, ..., index 4 -> conv5_1).
    content_masks :
        Segmentation masks forwarded to the per-layer style loss.
    init_image : array-like
        Initial value for the optimized transfer-image variable.
    result_dir : str
        Directory in which the "iterations" subfolder is created.
    timestamp :
        Suffix for the TensorBoard log directory.
    args :
        Namespace providing loss weights, Adam hyper-parameters,
        ``iterations`` and ``print_loss_interval``.

    Returns
    -------
    np.ndarray
        The transfer image with the lowest total loss observed.
    """
    print("Style transfer started")

    # Re-group the grams: one {color: gram} dict per conv*_1 layer.
    style_conv_grams = [
        {color: grams[layer] for color, grams in color_to_gram_dict.items()}
        for layer in range(5)
    ]

    content_image = vgg.preprocess(content_image)
    # NOTE(review): relies on a module-level `weight_restorer` initialized
    # elsewhere — confirm it is set before this function runs.
    global weight_restorer

    image_placeholder = tf.compat.v1.placeholder(tf.float32, shape=[1, None, None, 3])
    with tf.compat.v1.variable_scope("", reuse=True):
        vgg19 = vgg.VGG19ConvSub(image_placeholder)

    with tf.compat.v1.Session() as sess:
        transfer_image = tf.Variable(init_image)
        transfer_image_vgg = vgg.preprocess(transfer_image)
        transfer_image_nima = nima.preprocess(transfer_image)

        sess.run(tf.compat.v1.global_variables_initializer())
        weight_restorer.init(sess)

        # Content target is computed once from the fixed content image.
        content_conv4_2 = sess.run(
            fetches=vgg19.conv4_2, feed_dict={image_placeholder: content_image})

        # Rebuild the VGG sub-network on the transfer image, reusing weights.
        with tf.compat.v1.variable_scope("", reuse=True):
            vgg19 = vgg.VGG19ConvSub(transfer_image_vgg)

        content_loss = calculate_layer_content_loss(content_conv4_2, vgg19.conv4_2)

        style_conv1_1_gram, style_conv2_1_gram, style_conv3_1_gram, style_conv4_1_gram, style_conv5_1_gram = style_conv_grams

        # Equal-weighted style loss over the five conv*_1 layers.
        style_loss = (1. / 5.) * calculate_layer_style_loss(
            style_conv1_1_gram, vgg19.conv1_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv2_1_gram, vgg19.conv2_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv3_1_gram, vgg19.conv3_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv4_1_gram, vgg19.conv4_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv5_1_gram, vgg19.conv5_1, content_masks)

        photorealism_regularization = calculate_photorealism_regularization(
            transfer_image_vgg, content_image)
        nima_loss = compute_nima_loss(transfer_image_nima)

        content_loss = args.content_weight * content_loss
        style_loss = args.style_weight * style_loss
        photorealism_regularization = args.regularization_weight * photorealism_regularization
        nima_loss = args.nima_weight * nima_loss
        total_loss = content_loss + style_loss + photorealism_regularization + nima_loss

        tf.compat.v1.summary.scalar('Content loss', content_loss)
        tf.compat.v1.summary.scalar('Style loss', style_loss)
        tf.compat.v1.summary.scalar('Photorealism Regularization', photorealism_regularization)
        tf.compat.v1.summary.scalar('NIMA loss', nima_loss)
        tf.compat.v1.summary.scalar('Total loss', total_loss)
        summary_op = tf.compat.v1.summary.merge_all()
        summary_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(os.path.dirname(__file__), 'logs/{}'.format(timestamp)), sess.graph)

        iterations_dir = os.path.join(result_dir, "iterations")
        # makedirs(exist_ok=True): os.mkdir raised FileExistsError on reruns.
        os.makedirs(iterations_dir, exist_ok=True)

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.adam_learning_rate, beta1=args.adam_beta1,
            beta2=args.adam_beta2, epsilon=args.adam_epsilon)
        train_op = optimizer.minimize(total_loss, var_list=[transfer_image])
        sess.run(adam_variables_initializer(optimizer, [transfer_image]))

        min_loss, best_image = float("inf"), None
        # NOTE(review): starts at 0 and runs iterations+1 steps, unlike the
        # sibling style_transfer's range(1, iter + 1) — confirm intent.
        for i in range(args.iterations + 1):
            _, result_image, loss, c_loss, s_loss, p_loss, n_loss, summary = sess.run(
                fetches=[
                    train_op, transfer_image, total_loss, content_loss, style_loss,
                    photorealism_regularization, nima_loss, summary_op
                ])
            summary_writer.add_summary(summary, i)

            if i % args.print_loss_interval == 0:
                print("Iteration: {0:5} \t "
                      "Total loss: {1:15.2f} \t "
                      "Content loss: {2:15.2f} \t "
                      "Style loss: {3:15.2f} \t "
                      "Photorealism Regularization: {4:15.2f} \t "
                      "NIMA loss: {5:15.2f} \t".format(i, loss, c_loss, s_loss, p_loss, n_loss))

            # Track the best image seen so far instead of the final iterate.
            if loss < min_loss:
                min_loss, best_image = loss, result_image
            # NOTE(review): intermediate-result saving is disabled in this
            # variant (active in the sibling style_transfer) — confirm intent.

    return best_image