def render_gen(content_file, style_file,
               content_region_file=None, style_region_file=None,
               random_init=False, load_saved_mapping=True, load_trained_image=False, blur_mapping=True,
               height=None, width=None,
               content_ratio=0, style3_ratio=3., style4_ratio=1., gram_ratio=0.001, diff_ratio=0.,
               gen_epochs=80, max_gen=3, pyramid=True, max_reduction_ratio=.8, final_epochs=200
               ):
    """
    Render the image by the generation method: ``render`` is run ``max_gen`` times and the final image
    of each generation is used as the content image of the next one.

    - Best used if the style has low similarity with the content.
    - max_reduction_ratio can be set lower, e.g. 0.4, to improve the synthesis effect, but less content
      will be preserved.
    - content_ratio and gram_ratio are set to 0 in the final generation because of low effectiveness.
    - blur_mapping is switched off in every generation except the last one to prevent severe content
      destruction.
    - A generation whose final output image already exists in ``<module dir>/train`` is skipped, so an
      interrupted run can be resumed.

    :param content_file:            String file path of content image
    :param style_file:              String file path of style image
    :param content_region_file:     String file path of region mapping of content
    :param style_region_file:       String file path of region mapping of style image
    :param random_init:             True to init the image with random noise
    :param load_saved_mapping:      True to use saved mapping file
    :param load_trained_image:      True to use saved training
    :param blur_mapping:            True to blur the mapping before calculating the max argument. Only
                                    applied to the last generation
    :param height:                  int of height of result image
    :param width:                   int of width of result image. Leaving it None while height is set
                                    scales it according to the aspect ratio
    :param content_ratio:           float32 of weight of content cost, will be 0 for last generation
    :param style3_ratio:            float32 of weight of patch cost of conv3 layer
    :param style4_ratio:            float32 of weight of patch cost of conv4 layer
    :param gram_ratio:              float32 of weight of gram matrix cost, will be 0 for last generation
    :param diff_ratio:              float32 of weight of local difference cost
    :param gen_epochs:              int of epochs of each generation, except the last generation
    :param max_gen:                 int of number of generations
    :param pyramid:                 True to pre-scale the non-final generations based on the reduction
                                    ratio; False renders every generation at full height
    :param max_reduction_ratio:     float32 of 0.0 to 1.0, height ratio of the first generation in the
                                    pyramid
    :param final_epochs:            int of epochs of training the last generation
    """
    root_dir = os.path.dirname(__file__)
    os.makedirs(os.path.join(root_dir, 'train'), exist_ok=True)
    for gen in range(max_gen):
        if gen == 0:
            gen_content_file = content_file
            # Resolve the full-size height once; later generations are scaled relative to it.
            height = stylenet_core.load_image(content_file, height, width).shape[0]
        else:
            # Each generation refines the last image produced by the previous (non-final) generation.
            gen_content_file = os.path.join(root_dir, ("train/output-g" + str(gen - 1) + "-%d.jpg") % gen_epochs)

        # The final-generation settings are independent of `pyramid`; the pyramid only controls the height.
        is_final = gen == max_gen - 1
        if is_final:
            epochs = final_epochs
            cr = 0
            gr = 0
            bm = blur_mapping
        else:
            epochs = gen_epochs
            cr = content_ratio
            gr = gram_ratio
            bm = False

        if pyramid and not is_final:
            # Height grows linearly from max_reduction_ratio * height; the final generation uses full height.
            h = int(height * (gen * (1.0 - max_reduction_ratio) / max_gen + max_reduction_ratio))
        else:
            h = height
        # NOTE(review): an explicit `width` is passed unchanged to every generation while `h` shrinks,
        # which may distort the aspect ratio of pyramid generations — confirm against load_image.

        output_file = os.path.join(root_dir, "train/output-g" + str(gen) + "-%d.jpg")
        # Look for the image saved after this generation's own last step. The chaining above assumes
        # render() names its last image after step `epochs`. Checking `gen_epochs` for the final
        # generation would match a mid-training snapshot and wrongly skip it.
        output_file_final = output_file % epochs
        if os.path.isfile(output_file_final):
            print(output_file_final, "exist. move to next generation")
            continue

        tf.reset_default_graph()
        ot = time.time()
        print("----------- %d generation started -----------" % gen)

        render(
            content_file=gen_content_file,
            style_file=style_file,
            content_region_file=content_region_file,
            style_region_file=style_region_file,
            random_init=random_init,
            load_saved_mapping=load_saved_mapping,
            load_trained_image=load_trained_image,
            blur_mapping=bm,
            height=h,
            width=width,
            content_ratio=cr,
            style3_ratio=style3_ratio,
            style4_ratio=style4_ratio,
            gram_ratio=gr,
            diff_ratio=diff_ratio,
            epochs=epochs,
            output_file=output_file)
        print("----------- %d generation finished in %d sec -----------\n" % (gen, time.time() - ot))
def render(content_file, style_file,
           content_region_file=None, style_region_file=None,
           random_init=False, load_saved_mapping=True, load_trained_image=False, blur_mapping=True,
           height=None, width=None,
           content_ratio=0., style3_ratio=3., style4_ratio=1., gram_ratio=0.001, diff_ratio=0.,
           epochs=300, output_file="./train/output%d.jpg"):
    """
    Render the synthesis with a single generation.

    - Best used if the style has high similarity with the content.
    - If any ratio equals 0 (int or float), the corresponding cost Tensor is not generated.
    - Pure Gram Matrix synthesis is best for painting abstract style (gram_ratio = 1 and all others 0).
    - An image is saved every 5 steps and once more after the last step. The directory of output_file
      is created if missing.

    :param content_file:            String file path of content image
    :param style_file:              String file path of style image (resized to the content image size)
    :param content_region_file:     String file path of region mapping of content
    :param style_region_file:       String file path of region mapping of style image
    :param random_init:             True to init the image with random noise instead of the content
    :param load_saved_mapping:      True to use saved mapping file
    :param load_trained_image:      True to use saved training (currently unused: restore code is disabled)
    :param blur_mapping:            True to blur the mapping before calculating the max argument
    :param height:                  int of height of result image
    :param width:                   int of width of result image. Leaving it None while height is set
                                    scales it according to the aspect ratio
    :param content_ratio:           float32 of weight of content cost
    :param style3_ratio:            float32 of weight of patch cost of conv3 layer
    :param style4_ratio:            float32 of weight of patch cost of conv4 layer
    :param gram_ratio:              float32 of weight of gram matrix cost
    :param diff_ratio:              float32 of weight of local difference cost
    :param epochs:                  int of number of epochs to train
    :param output_file:             String file name of output file. %d will be replaced by the step number
    """
    import os  # used only to create the output directory

    print("render started:")

    # print info: echo every argument of this call
    frame = inspect.currentframe()
    try:
        args, _, _, values = inspect.getargvalues(frame)
    finally:
        # Release the frame to avoid a reference cycle, as advised by the inspect documentation.
        del frame
    for i in args:
        print("    %s = %s" % (i, values[i]))

    # skimage.io.imsave fails if the target directory does not exist (e.g. the default ./train).
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    content_np = stylenet_core.load_image(content_file, height, width)
    # The style image is loaded at the content image's height/width.
    style_np = stylenet_core.load_image(style_file, content_np.shape[0], content_np.shape[1])

    # Add a leading batch dimension of size 1.
    content_batch = np.expand_dims(content_np, 0)
    style_batch = np.expand_dims(style_np, 0)

    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
    # with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False)) as sess:

    tf_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        start_time = time.time()

        contents = tf.constant(content_batch, dtype=tf.float32, shape=content_batch.shape)
        styles = tf.constant(style_batch, dtype=tf.float32, shape=style_batch.shape)

        # The image being optimized: either noise around 0.5 or a copy of the content image.
        if random_init:
            var_image = tf.Variable(tf.truncated_normal(content_batch.shape, 0.5, 0.1))
        else:
            var_image = tf.Variable(contents)

        vgg_content = custom_vgg19.Vgg19()
        with tf.name_scope("content_vgg"):
            vgg_content.build(contents)

        vgg_style = custom_vgg19.Vgg19()
        with tf.name_scope("style_vgg"):
            vgg_style.build(styles)

        vgg_var = custom_vgg19.Vgg19()
        with tf.name_scope("variable_vgg"):
            vgg_var.build(var_image)

        with tf.name_scope("cost"):
            # style:
            # TODO change file name based on out file name
            # NOTE(review): these mapping paths are relative to the current working directory, not to
            # output_file's directory — confirm ./train exists when the mapping is saved/loaded.
            style3file = "./train/%s-style_map_3" % (
                get_filename(content_file) + "-" + get_filename(style_file))
            style4file = "./train/%s-style_map_4" % (
                get_filename(content_file) + "-" + get_filename(style_file))

            # `== 0` (not `is 0`) so that float zeros such as the 0. defaults also skip building the tensor.
            if content_region_file is None or style_region_file is None:
                # NOTE(review): blur_mapping is not forwarded in this no-region branch, so it has no
                # effect here — confirm whether that is intended.
                if style3_ratio == 0:
                    style_cost_3 = tf.constant(0.0)
                else:
                    style_cost_3 = stylenet_core.get_style_cost_patch2(sess, vgg_var.conv3_1,
                                                                       vgg_content.conv3_1,
                                                                       vgg_style.conv3_1,
                                                                       style3file,
                                                                       load_saved_mapping=load_saved_mapping)
                if style4_ratio == 0:
                    style_cost_4 = tf.constant(0.0)
                else:
                    style_cost_4 = stylenet_core.get_style_cost_patch2(sess, vgg_var.conv4_1,
                                                                       vgg_content.conv4_1,
                                                                       vgg_style.conv4_1,
                                                                       style4file,
                                                                       load_saved_mapping=load_saved_mapping)
            else:
                content_regions_np = stylenet_core.load_image(content_region_file, content_np.shape[0],
                                                              content_np.shape[1])
                style_regions_np = stylenet_core.load_image(style_region_file, content_np.shape[0],
                                                            content_np.shape[1])
                content_regions_batch = np.expand_dims(content_regions_np, 0)
                style_regions_batch = np.expand_dims(style_regions_np, 0)
                content_regions = tf.constant(content_regions_batch, dtype=tf.float32,
                                              shape=content_regions_batch.shape)
                style_regions = tf.constant(style_regions_batch, dtype=tf.float32,
                                            shape=style_regions_batch.shape)

                # Downsample the region maps twice, presumably to match conv3_1's resolution
                # (VGG19 pools twice before conv3_1).
                content_regions = vgg_var.avg_pool(content_regions, None)
                content_regions = vgg_var.avg_pool(content_regions, None)
                style_regions = vgg_var.avg_pool(style_regions, None)
                style_regions = vgg_var.avg_pool(style_regions, None)

                if style3_ratio == 0:
                    style_cost_3 = tf.constant(0.0)
                else:
                    style_cost_3 = stylenet_core.get_style_cost_patch2(sess,
                                                                       vgg_var.conv3_1,
                                                                       vgg_content.conv3_1,
                                                                       vgg_style.conv3_1,
                                                                       style3file,
                                                                       content_regions,
                                                                       style_regions,
                                                                       load_saved_mapping,
                                                                       blur_mapping=blur_mapping)

                # One more downsampling, presumably to match conv4_1's resolution.
                content_regions = vgg_var.avg_pool(content_regions, None)
                style_regions = vgg_var.avg_pool(style_regions, None)

                if style4_ratio == 0:
                    style_cost_4 = tf.constant(0.0)
                else:
                    style_cost_4 = stylenet_core.get_style_cost_patch2(sess,
                                                                       vgg_var.conv4_1,
                                                                       vgg_content.conv4_1,
                                                                       vgg_style.conv4_1,
                                                                       style4file,
                                                                       content_regions,
                                                                       style_regions,
                                                                       load_saved_mapping,
                                                                       blur_mapping=blur_mapping)

            if gram_ratio == 0:
                style_cost_gram = tf.constant(0.0)
            else:
                style_cost_gram = stylenet_core.get_style_cost_gram(sess, vgg_style, vgg_var)

            # content: distance between the variable image's conv4_2 and the fixed content conv4_2
            if content_ratio == 0:
                content_cost = tf.constant(.0)
            else:
                fixed_content = stylenet_core.get_constant(sess, vgg_content.conv4_2)
                content_cost = stylenet_core.l2_norm_cost(vgg_var.conv4_2 - fixed_content)

            # # smoothness:
            if diff_ratio == 0:
                diff_cost = tf.constant(.0)
            else:
                # 3x3 kernels: horizontal (right - center) and vertical (below - center) differences.
                # They are replicated over 3 input channels (assumes an RGB image), and conv2d sums over them.
                diff_filter_h = tf.constant([0, 0, 0, 0, -1, 1, 0, 0, 0], tf.float32, [3, 3, 1, 1])
                diff_filter_h = tf.concat([diff_filter_h, diff_filter_h, diff_filter_h], 2)
                diff_filter_v = tf.constant([0, 0, 0, 0, -1, 0, 0, 1, 0], tf.float32, [3, 3, 1, 1])
                diff_filter_v = tf.concat([diff_filter_v, diff_filter_v, diff_filter_v], 2)
                diff_filter = tf.concat([diff_filter_h, diff_filter_v], 3)
                filtered_input = tf.nn.conv2d(var_image, diff_filter, [1, 1, 1, 1], "VALID")
                # 1e7: empirical scale factor (its origin is not documented here).
                diff_cost = stylenet_core.l2_norm_cost(filtered_input) * 1e7

            content_cost = content_cost * content_ratio
            style_cost_3 = style_cost_3 * style3_ratio
            style_cost_4 = style_cost_4 * style4_ratio
            style_cost_gram = style_cost_gram * gram_ratio
            diff_cost = diff_cost * diff_ratio
            cost = content_cost + style_cost_3 + style_cost_4 + style_cost_gram + diff_cost

        with tf.name_scope("train"):
            global_step = tf.Variable(0, name='global_step', trainable=False)

            optimizer = tf.train.AdamOptimizer(learning_rate=0.02)
            gvs = optimizer.compute_gradients(cost)

            training = optimizer.apply_gradients(gvs, global_step=global_step)

        print("Net generated:", (time.time() - start_time))
        start_time = time.time()

        with tf.name_scope("image_out"):
            # Drop the batch dimension and clip to the displayable [0, 1] range.
            image_out = tf.clip_by_value(tf.squeeze(var_image, [0]), 0, 1)

        # saver = tf.train.Saver(max_to_keep=1)

        # checkpoint = tf.train.get_checkpoint_state("./train")
        # if checkpoint and checkpoint.model_checkpoint_path and load_trained_image:
        #     saver.restore(sess, checkpoint.model_checkpoint_path)
        #     print("save restored:", checkpoint.model_checkpoint_path)
        # else:
        sess.run(tf.global_variables_initializer())
        print("all variables init")

        print("Var init: %d" % (time.time() - start_time))

        step_out = 0
        start_time = time.time()
        for i in range(epochs):
            # Snapshot the current image every 5 steps, named after the last global step.
            if i % 5 == 0:
                img = sess.run(image_out)
                img_out_path = output_file % step_out
                skimage.io.imsave(img_out_path, img)
                print("img saved: ", img_out_path)

            step_out, content_out, style_patch3_out, style_patch4_out, style_gram_out, diff_cost_out, cost_out \
                , _ = sess.run(
                [global_step, content_cost, style_cost_3, style_cost_4, style_cost_gram, diff_cost, cost,
                 training])

            duration = time.time() - start_time
            print("Step %d: cost:%.10f\t(%.1f sec)" % (step_out, cost_out, duration), \
                "\t content:%.5f, style_3:%.5f, style_4:%.5f, gram:%.5f, diff_cost_out:%.5f" \
                % (content_out, style_patch3_out, style_patch4_out, style_gram_out, diff_cost_out))

            # if (i + 1) % 10 == 0:
            #     saved_path = saver.save(sess, "./train/saves-" + get_filename(content_file),
            #                             global_step=global_step)
            #     print("net saved: ", saved_path)

        # Final image after the last training step.
        img = sess.run(image_out)
        img_out_path = output_file % step_out
        skimage.io.imsave(img_out_path, img)
        print("img saved: ", img_out_path)