def main(_):
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    src_font = ImageFont.truetype(args.src_font, size=args.char_size)

    with tf.Session(config=config) as sess:
        model = UNet(batch_size=args.batch_size,
                     input_width=args.canvas_size,
                     output_width=args.canvas_size,
                     experiment_id=args.experiment_id,
                     embedding_dim=args.embedding_dim,
                     embedding_num=args.embedding_num)
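        # Attach the session, build the inference graph, and restore weights
        # from the checkpoint directory.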
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=args.inst_norm)
        model.load_model(args.model_dir)

        count = 0
        batch_buffer = list()
        examples = []
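        # Render each character, pair it with itself as a placeholder target,
        # and store the JPEG-encoded bytes with the requested embedding id.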
        for ch in list(args.text):
            src_img = draw_single_char_by_font(ch, src_font, args.canvas_size,
                                               args.char_size)

            paired_img = draw_paired_image(src_img, src_img, args.canvas_size)

            # Debug preview of the paired source image; the hard-coded index
            # means each character overwrites the same file.
            p = os.path.join(args.save_dir, "inferred_%04d.png" % 100)
            misc.imsave(p, paired_img)

            buffered = BytesIO()
            paired_img.save(buffered, format="JPEG")

            examples.append((args.embedding_id, buffered.getvalue()))
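        # Group the rendered examples into fixed-size batches (no augmentation
        # at inference time).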
        batch_iter = get_batch_iter(examples, args.batch_size, augment=False)

        for _, images in batch_iter:
            # inject specific embedding style here
            labels = [args.embedding_id] * len(images)

            fake_imgs = model.generate_fake_samples(images, labels)[0]
            merged_fake_images = merge(scale_back(fake_imgs),
                                       [-1, 1])  # scale_back maps pixels to [0, 1]
            print("merged image:", type(merged_fake_images),
                  merged_fake_images.shape)
            # Skip any batch whose merged shape differs from the first one
            # (e.g. a smaller final batch).
            if batch_buffer and merged_fake_images.shape != batch_buffer[0].shape:
                continue
            batch_buffer.append(merged_fake_images)
            # if len(batch_buffer) == 10:
            #     save_imgs(batch_buffer, count, args.save_dir)
            #     batch_buffer = list()
            count += 1

        if batch_buffer:
            # last batch
            save_imgs(batch_buffer, count, args.save_dir)
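
# NOTE: the unused `_` argument suggests main() is intended to be run via
# tf.app.run(), with the module-level `args` namespace presumably populated
# by an argument parser elsewhere in this file.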


def infer_by_text_api2(text1, text2, embedding_id, path):
    # CUDA_VISIBLE_DEVICES=0
    # --model_dir=experiments/checkpoint/experiment_0
    # --batch_size=32
    # --embedding_id=67
    # --save_dir=save_dir

    rootpath = os.path.dirname(os.path.abspath(__file__))
    print(rootpath)
    # path = os.path.join(rootpath, 'zi2zi')

    # default configuration (mirrors the commented CLI flags above)
    experiment_id = 0
    model_dir = os.path.join(rootpath, "experiments/checkpoint/experiment_0")
    batch_size = 16
    text = "库昊又双叒叕进三分了"
    # embedding_id = 67
    embedding_dim = EMBEDDING_DIM
    # save_dir = os.path.join(rootpath, 'save_dir')
    inst_norm = 1
    char_size = CHAR_SIZE
    src_font = os.path.join(rootpath, 'data/raw_fonts/SimSun.ttf')
    canvas_size = CANVAS_SIZE
    embedding_num = 185

    # per-call overrides
    text = text1

    batch_size = 32
    model_dir = os.path.join(rootpath, "experiments/checkpoint/experiment_0")
    # embedding_id = 67
    save_dir = os.path.join(rootpath, "save_dir")

    # print(text1, path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    src_font = ImageFont.truetype(src_font, size=char_size)

    with tf.Session(config=config) as sess:
        model = UNet(batch_size=batch_size,
                     input_width=canvas_size,
                     output_width=canvas_size,
                     experiment_id=experiment_id,
                     embedding_dim=embedding_dim,
                     embedding_num=embedding_num)
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=inst_norm)
        model.load_model(model_dir)

        count = 0
        batch_buffer = list()
        examples = []

        for ch in list(text):
            src_img = draw_single_char_by_font(ch, src_font, canvas_size,
                                               char_size)

            paired_img = draw_paired_image(src_img, src_img, canvas_size)

            # p = os.path.join(save_dir, "inferred_%04d.png" % 100)
            # p = path
            # misc.imsave(p, paired_img)

            buffered = BytesIO()
            paired_img.save(buffered, format="JPEG")

            examples.append((embedding_id, buffered.getvalue()))
        batch_iter1 = get_batch_iter(examples, batch_size, augment=False)

        examples = []
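        # Render the characters of text2 with the same style embedding.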
        for ch in list(text2):
            src_img = draw_single_char_by_font(ch, src_font, canvas_size,
                                               char_size)

            paired_img = draw_paired_image(src_img, src_img, canvas_size)

            # p = os.path.join(save_dir, "inferred_%04d.png" % 100)
            # p = path
            # misc.imsave(p, paired_img)

            buffered = BytesIO()
            paired_img.save(buffered, format="JPEG")

            examples.append((embedding_id, buffered.getvalue()))
        batch_iter2 = get_batch_iter(examples, batch_size, augment=False)
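
        # Run both batches through the generator; every merged result is
        # collected into batch_buffer and written out via save_imgs2 at the end.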

        for _, images in batch_iter1:
            # inject specific embedding style here
            labels = [embedding_id] * len(images)

            fake_imgs = model.generate_fake_samples(images, labels)[0]
            merged_fake_images = merge(scale_back(fake_imgs),
                                       [-1, 1])  # scale_back maps pixels to [0, 1]
            # print("getshape", type(merged_fake_images), merged_fake_images.shape)
            # if len(batch_buffer) > 0 and merged_fake_images.shape != batch_buffer[0].shape:
            #     continue
            batch_buffer.append(merged_fake_images)
            # print("getshape",merged_fake_images.shape)
            # if len(batch_buffer) == 10:
            #     save_imgs(batch_buffer, count, save_dir, path)
            #     batch_buffer = list()
            # count += 1
        for _, images in batch_iter2:
            # inject specific embedding style here
            labels = [embedding_id] * len(images)

            fake_imgs = model.generate_fake_samples(images, labels)[0]
            merged_fake_images = merge(scale_back(fake_imgs),
                                       [-1, 1])  # scale_back maps pixels to [0, 1]
            # print("getshape", type(merged_fake_images), merged_fake_images.shape)
            # if len(batch_buffer) > 0 and merged_fake_images.shape != batch_buffer[0].shape:
            #     continue
            batch_buffer.append(merged_fake_images)
            # print("getshape",merged_fake_images.shape)
            # if len(batch_buffer) == 10:
            #     save_imgs(batch_buffer, count, save_dir, path)
            #     batch_buffer = list()
            # count += 1

        if batch_buffer:
            # last batch
            # l = len(batch_buffer)
            # for i in range(l, 10):
            #     batch_buffer.append(np.ones(81))
            save_imgs2(batch_buffer, count, save_dir, path)

        model = None  # drop the model reference before the session closes

    return path
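

# Example usage (hypothetical values; assumes the checkpoint under
# experiments/checkpoint/experiment_0 and data/raw_fonts/SimSun.ttf exist):
#
#   out_path = infer_by_text_api2("你好世界", "风格迁移",
#                                 embedding_id=67, path="save_dir/result.png")
#   print("result written to", out_path)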