Example #1
def style_transfer(content_image, color_to_gram_dict, content_masks,
                   init_image, result_dir, timestamp, args):
    print("Style transfer started")
    style_conv_grams = []
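    # color_to_gram_dict maps each mask color to a list of five Gram matrices
    # (one per style layer); transpose it into one {color: gram} dict per layer.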
    for i in range(5):
        style_conv_grams.append(
            {color: grams[i] for color, grams in color_to_gram_dict.items()})

    content_image = vgg.preprocess(content_image)

    global weight_restorer

    image_placeholder = tf.compat.v1.placeholder(tf.float32,
                                                 shape=[1, None, None, 3])

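    # Reuse the VGG variables already created in the surrounding graph; this
    # placeholder network only serves to extract the fixed content features.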
    with tf.compat.v1.variable_scope("", reuse=True):
        vgg19 = vgg.VGG19ConvSub(image_placeholder)

    with tf.compat.v1.Session() as sess:
        transfer_image = tf.Variable(init_image)
        transfer_image_vgg = vgg.preprocess(transfer_image)
        transfer_image_nima = nima.preprocess(transfer_image)

        sess.run(tf.compat.v1.global_variables_initializer())
        weight_restorer.init(sess)
        content_conv4_2 = sess.run(
            fetches=vgg19.conv4_2,
            feed_dict={image_placeholder: content_image})

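        # Rebuild the VGG network on the trainable transfer image so the losses
        # below are differentiable with respect to the image pixels.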
        with tf.compat.v1.variable_scope("", reuse=True):
            vgg19 = vgg.VGG19ConvSub(transfer_image_vgg)

        content_loss = calculate_layer_content_loss(content_conv4_2,
                                                    vgg19.conv4_2)
        (style_conv1_1_gram, style_conv2_1_gram, style_conv3_1_gram,
         style_conv4_1_gram, style_conv5_1_gram) = style_conv_grams

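        # All five style layers are weighted equally (1/5 each).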
        style_loss = (1. / 5.) * calculate_layer_style_loss(
            style_conv1_1_gram, vgg19.conv1_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv2_1_gram, vgg19.conv2_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv3_1_gram, vgg19.conv3_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv4_1_gram, vgg19.conv4_1, content_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(
            style_conv5_1_gram, vgg19.conv5_1, content_masks)

        photorealism_regularization = calculate_photorealism_regularization(
            transfer_image_vgg, content_image)

        nima_loss = compute_nima_loss(transfer_image_nima)

        content_loss = args.content_weight * content_loss
        style_loss = args.style_weight * style_loss
        photorealism_regularization = args.regularization_weight * photorealism_regularization
        nima_loss = args.nima_weight * nima_loss

        total_loss = content_loss + style_loss + photorealism_regularization + nima_loss

        tf.compat.v1.summary.scalar('Content loss', content_loss)
        tf.compat.v1.summary.scalar('Style loss', style_loss)
        tf.compat.v1.summary.scalar('Photorealism Regularization',
                                    photorealism_regularization)
        tf.compat.v1.summary.scalar('NIMA loss', nima_loss)
        tf.compat.v1.summary.scalar('Total loss', total_loss)

        summary_op = tf.compat.v1.summary.merge_all()
        summary_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(os.path.dirname(__file__),
                         'logs/{}'.format(timestamp)), sess.graph)

        iterations_dir = os.path.join(result_dir, "iterations")
        os.mkdir(iterations_dir)

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.adam_learning_rate,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon)

        train_op = optimizer.minimize(total_loss, var_list=[transfer_image])
        sess.run(adam_variables_initializer(optimizer, [transfer_image]))

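        # Track the iteration with the lowest total loss and keep its image.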
        min_loss, best_image = float("inf"), None
        for i in range(args.iterations + 1):
            _, result_image, loss, c_loss, s_loss, p_loss, n_loss, summary = sess.run(
                fetches=[
                    train_op, transfer_image, total_loss, content_loss,
                    style_loss, photorealism_regularization, nima_loss,
                    summary_op
                ])

            summary_writer.add_summary(summary, i)

            if i % args.print_loss_interval == 0:
                print("Iteration: {0:5} \t "
                      "Total loss: {1:15.2f} \t "
                      "Content loss: {2:15.2f} \t "
                      "Style loss: {3:15.2f} \t "
                      "Photorealism Regularization: {4:15.2f} \t "
                      "NIMA loss: {5:15.2f} \t".format(i, loss, c_loss, s_loss,
                                                       p_loss, n_loss))

            if loss < min_loss:
                min_loss, best_image = loss, result_image

            # if i % args.intermediate_result_interval == 0:
            #     save_image(best_image, os.path.join(iterations_dir, "iter_{}.png".format(i)))

        return best_image
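
Both examples call adam_variables_initializer, which the excerpts do not define. A plausible sketch, assuming TF1's optimizer slot-variable API (note that _get_beta_accumulators is a private TensorFlow helper, so this is an assumption about the missing code, not the authors' exact implementation):

def adam_variables_initializer(optimizer, var_list):
    # Collect the Adam slot variables (m, v) created for each optimized
    # variable, plus the global beta-power accumulators, and initialize only
    # those -- re-running the global initializer would also reset the
    # transfer image itself.
    adam_vars = [optimizer.get_slot(var, name)
                 for name in optimizer.get_slot_names()
                 for var in var_list]
    adam_vars.extend(list(optimizer._get_beta_accumulators()))
    return tf.compat.v1.variables_initializer(adam_vars)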
Example #2
def style_transfer(content_image, style_image, content_masks, style_masks, init_image, args):
    r"""
        Syle transfer computation

        Parameters
        ----------
        content_image:
        style_image:
        content_masks:
        style_masks:
        init_image:
        args:

        Returns
        -------
        tf.Tensor
    """
    print("Style transfer started")

    content_image = vgg.preprocess(content_image)
    style_image = vgg.preprocess(style_image)

    weight_restorer = vgg.load_weights()

    image_placeholder = tf.compat.v1.placeholder(tf.float32, shape=[1, None, None, 3])
    vgg19 = vgg.VGG19ConvSub(image_placeholder)

    with tf.compat.v1.Session() as sess:
        transfer_image = tf.Variable(init_image)
        transfer_image_vgg = vgg.preprocess(transfer_image)
        transfer_image_nima = nima.preprocess(transfer_image)

        sess.run(tf.compat.v1.global_variables_initializer())
        weight_restorer.init(sess)
        content_conv4_2 = sess.run(fetches=vgg19.conv4_2, feed_dict={image_placeholder: content_image})
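        # Extract the style features once, in a single forward pass over the style image.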
        style_conv1_1, style_conv2_1, style_conv3_1, style_conv4_1, style_conv5_1 = sess.run(
            fetches=[vgg19.conv1_1, vgg19.conv2_1, vgg19.conv3_1, vgg19.conv4_1, vgg19.conv5_1],
            feed_dict={image_placeholder: style_image})

        with tf.compat.v1.variable_scope("", reuse=True):
            vgg19 = vgg.VGG19ConvSub(transfer_image_vgg)

        content_loss = calculate_layer_content_loss(content_conv4_2, vgg19.conv4_2)

        style_loss = (1. / 5.) * calculate_layer_style_loss(style_conv1_1, vgg19.conv1_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv2_1, vgg19.conv2_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv3_1, vgg19.conv3_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv4_1, vgg19.conv4_1, content_masks, style_masks)
        style_loss += (1. / 5.) * calculate_layer_style_loss(style_conv5_1, vgg19.conv5_1, content_masks, style_masks)

        # TODO: convert the content image to a tensor earlier and clean up the placeholder handling
        photorealism_regularization = calculate_photorealism_regularization(transfer_image_vgg, content_image, args.matting)

        nima_loss = compute_nima_loss(transfer_image_nima)

        content_loss = args.content_weight * content_loss
        style_loss = args.style_weight * style_loss
        photorealism_regularization = args.regularization_weight * photorealism_regularization
        nima_loss = args.nima_weight * nima_loss

        total_loss = content_loss + style_loss + photorealism_regularization + nima_loss

        tf.compat.v1.summary.scalar('Content loss', content_loss)
        tf.compat.v1.summary.scalar('Style loss', style_loss)
        tf.compat.v1.summary.scalar('Photorealism Regularization', photorealism_regularization)
        tf.compat.v1.summary.scalar('NIMA loss', nima_loss)
        tf.compat.v1.summary.scalar('Total loss', total_loss)

        summary_op = tf.compat.v1.summary.merge_all()
        summary_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(os.path.dirname(__file__),
                         'logs/{}'.format(args.results_dir)), sess.graph)

        iterations_dir = os.path.join(args.results_dir, "iterations")
        os.mkdir(iterations_dir)

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.adam_learning_rate,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon)

        train_op = optimizer.minimize(total_loss, var_list=[transfer_image])
        sess.run(adam_variables_initializer(optimizer, [transfer_image]))

        min_loss, best_image = float("inf"), None
        for i in range(1, args.iter + 1):
            _, result_image, loss, c_loss, s_loss, p_loss, n_loss, summary = sess.run(
                fetches=[train_op, transfer_image, total_loss, content_loss, style_loss, photorealism_regularization,
                         nima_loss, summary_op])

            summary_writer.add_summary(summary, i)

            if i % args.print_loss_interval == 0:
                print(
                    "Iteration: {0:5}\t"
                    "Total loss: {1:10.2f}\t"
                    "Content loss: {2:10.2f}\t"
                    "Style loss: {3:10.2f}\t "
                    "Photorealism Regularization: {4:10.2f}\t"
                    "NIMA loss: {5:10.2f}".format(i, loss, c_loss, s_loss, p_loss, n_loss)
                )

            if loss < min_loss:
                min_loss, best_image = loss, result_image

            if i % args.intermediate_result_interval == 0:
                save_image(best_image, os.path.join(iterations_dir, "iter_{}.png".format(i)))

        return best_image
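
The helper losses are also not shown. Minimal sketches of the two simplest ones, with names and signatures inferred from the call sites above (the masked style loss and the matting-based photorealism regularization are more involved and are omitted; treat these as assumptions, not the authors' code):

def calculate_layer_content_loss(content_layer, transfer_layer):
    # Mean squared error between the fixed content features and the transfer
    # image's features at the same VGG layer.
    return tf.reduce_mean(tf.math.squared_difference(content_layer, transfer_layer))

def gram_matrix(layer):
    # Channel-wise inner product over all spatial positions: the building
    # block of the style loss. The masked variant used above additionally
    # multiplies the flattened features by a per-color segmentation mask.
    channels = int(layer.shape[-1])
    features = tf.reshape(layer, [-1, channels])
    return tf.matmul(features, features, transpose_a=True)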
Example #3
        "--evaluation",
        type=bool,
        help="Script activation for evaluation, default: False",
        default=False)
    init_image_options = ["noise", "content", "style"]
    parser.add_argument("--init",
                        type=str,
                        help="Initialization image (%s).",
                        default="content")
    parser.add_argument("--gpu",
                        help="comma separated list of GPU(s) to use.",
                        default="0")

    args = parser.parse_args()
    assert args.init in init_image_options

    if args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if not args.evaluation:

        timestamp = datetime.now().strftime('%Y_%m_%d_%H_%M')
        style_text = load_text(args.style_text)

        result_dir = 'result_' + args.content_image.split("/")[-1].split(
            '.')[0] + '_' + style_text
        os.mkdir(result_dir)

        # check if manual segmentation masks are available