示例#1
0
def main(_):
    """Run inference or embedding-space interpolation with a trained UNet.

    Reads all configuration from the module-level ``args`` namespace.
    """
    session_config = tf.ConfigProto()
    # Claim GPU memory on demand instead of all at once.
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        model = UNet(batch_size=args.batch_size,
                     embedding_dim=args.embedding_dim,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num)
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=args.inst_norm)

        ids = [int(token) for token in args.embedding_ids.split(",")]
        if args.interpolate:
            if len(ids) < 2:
                raise Exception("no need to interpolate yourself unless you are a narcissist")
            # Walk the id list pairwise; --uroboros closes the loop back to the start.
            chain = list(ids)
            if args.uroboros:
                chain.append(chain[0])
            for start, end in zip(chain, chain[1:]):
                model.interpolate(model_dir=args.model_dir,
                                  source_obj=args.source_obj,
                                  between=[start, end],
                                  save_dir=args.save_dir,
                                  steps=args.steps)
            if args.output_gif:
                gif_path = os.path.join(args.save_dir, args.output_gif)
                compile_frames_to_gif(args.save_dir, gif_path)
                print("gif saved at %s" % gif_path)
        else:
            # A single id is passed as a scalar rather than a one-element list.
            model.infer(model_dir=args.model_dir,
                        source_obj=args.source_obj,
                        embedding_ids=ids[0] if len(ids) == 1 else ids,
                        save_dir=args.save_dir)
def main(_):
    """Train the zi2zi UNet model using the module-level ``args`` settings."""
    run_config = tf.ConfigProto()
    # Allocate GPU memory incrementally rather than reserving it all up front.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        model = UNet(args.experiment_dir,
                     batch_size=args.batch_size,
                     experiment_id=args.experiment_id,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num,
                     embedding_dim=args.embedding_dim)
        model.register_session(sess)

        # With flipped labels, the target-source branch of the graph is omitted.
        if args.flip_labels:
            model.build_model(is_training=True, inst_norm=args.inst_norm, no_target_source=True)
        else:
            model.build_model(is_training=True, inst_norm=args.inst_norm)

        # Optionally restrict training to a comma-separated subset of embedding ids.
        fine_tune_list = None
        if args.fine_tune:
            fine_tune_list = {int(token) for token in args.fine_tune.split(",")}

        model.train(lr=args.lr,
                    epoch=args.epoch,
                    resume=args.resume,
                    schedule=args.schedule,
                    freeze_encoder=args.freeze_encoder,
                    fine_tune=fine_tune_list,
                    sample_steps=args.sample_steps,
                    checkpoint_steps=args.checkpoint_steps,
                    flip_labels=args.flip_labels)
示例#3
0
def main():
    """Train the zi2zi UNet model with loss-penalty weights from ``args``.

    NOTE(review): unlike the other entry points, no TF session is registered
    here (the call is commented out in the original) — presumably the model
    manages its own session; confirm against the UNet implementation.
    """
    model = UNet(args.experiment_dir,
                 batch_size=args.batch_size,
                 experiment_id=args.experiment_id,
                 input_width=args.image_size,
                 output_width=args.image_size,
                 embedding_num=args.embedding_num,
                 embedding_dim=args.embedding_dim,
                 L1_penalty=args.L1_penalty,
                 Lconst_penalty=args.Lconst_penalty,
                 Ltv_penalty=args.Ltv_penalty,
                 Lcategory_penalty=args.Lcategory_penalty)

    # Flipped labels drop the target-source branch from the graph.
    if args.flip_labels:
        model.build_model(is_training=True, inst_norm=args.inst_norm, no_target_source=True)
    else:
        model.build_model(is_training=True, inst_norm=args.inst_norm)

    # Optionally restrict training to a comma-separated subset of embedding ids.
    fine_tune_list = None
    if args.fine_tune:
        fine_tune_list = {int(token) for token in args.fine_tune.split(",")}

    model.train(lr=args.lr,
                epoch=args.epoch,
                resume=args.resume,
                schedule=args.schedule,
                freeze_encoder=args.freeze_encoder,
                fine_tune=fine_tune_list,
                sample_steps=args.sample_steps,
                checkpoint_steps=args.checkpoint_steps,
                flip_labels=args.flip_labels,
                no_val=args.no_val)
示例#4
0
def main(_):
    """Build and train the complete zi2zi GAN model."""
    # Set allow_growth so GPU memory is claimed elastically as needed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        # Construct the whole zi2zi model (the full GAN).
        model = UNet(args.experiment_dir,
                     batch_size=args.batch_size,
                     experiment_id=args.experiment_id,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num,
                     embedding_dim=args.embedding_dim,
                     L1_penalty=args.L1_penalty,
                     Lconst_penalty=args.Lconst_penalty,
                     Ltv_penalty=args.Ltv_penalty,
                     Lcategory_penalty=args.Lcategory_penalty)
        model.register_session(sess)

        # Flipped labels drop the target-source branch from the graph.
        if args.flip_labels:
            model.build_model(is_training=True, inst_norm=args.inst_norm, no_target_source=True)
        else:
            model.build_model(is_training=True, inst_norm=args.inst_norm)

        # If specific fine-tuning characters were given, parse them into a set.
        fine_tune_list = None
        if args.fine_tune:
            fine_tune_list = {int(token) for token in args.fine_tune.split(",")}

        # Start training the zi2zi model.
        model.train(lr=args.lr,
                    epoch=args.epoch,
                    resume=args.resume,
                    schedule=args.schedule,
                    freeze_encoder=args.freeze_encoder,
                    fine_tune=fine_tune_list,
                    sample_steps=args.sample_steps,
                    checkpoint_steps=args.checkpoint_steps,
                    flip_labels=args.flip_labels)
示例#5
0
def main(_):
    """Export the trained generator sub-network to ``args.save_dir``."""
    session_config = tf.ConfigProto()
    # Grow GPU memory usage on demand instead of reserving everything.
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        model = UNet(batch_size=args.batch_size)
        model.register_session(sess)
        # Build in inference mode before extracting the generator weights.
        model.build_model(is_training=False, inst_norm=args.inst_norm)
        model.export_generator(save_dir=args.save_dir, model_dir=args.model_dir)
def main(_):
    """Render each character of ``args.text`` with the source font, run the
    paired images through a trained UNet generator styled by
    ``args.embedding_id``, and save the merged results under ``args.save_dir``.
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    src_font = ImageFont.truetype(args.src_font, size=args.char_size)

    with tf.Session(config=config) as sess:
        model = UNet(batch_size=args.batch_size,
                     input_width=args.canvas_size,
                     output_width=args.canvas_size,
                     experiment_id=args.experiment_id,
                     embedding_dim=args.embedding_dim,
                     embedding_num=args.embedding_num)
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=args.inst_norm)
        model.load_model(args.model_dir)

        count = 0
        batch_buffer = list()
        examples = []
        for idx, ch in enumerate(args.text):
            src_img = draw_single_char_by_font(ch, src_font, args.canvas_size,
                                               args.char_size)
            # The "paired" layout places the source glyph on both halves.
            paired_img = draw_paired_image(src_img, src_img, args.canvas_size)

            # BUG FIX: the filename index was hard-coded to 100, so every
            # character overwrote the same preview file; use the character
            # index so each input glyph keeps its own preview image.
            p = os.path.join(args.save_dir, "inferred_%04d.png" % idx)
            misc.imsave(p, paired_img)

            buffered = BytesIO()
            paired_img.save(buffered, format="JPEG")
            examples.append((args.embedding_id, buffered.getvalue()))

        batch_iter = get_batch_iter(examples, args.batch_size, augment=False)

        for _, images in batch_iter:
            # Inject the requested embedding style for every image in the batch.
            labels = [args.embedding_id] * len(images)

            fake_imgs = model.generate_fake_samples(images, labels)[0]
            merged_fake_images = merge(scale_back(fake_imgs),
                                       [-1, 1])  # scale 0-1
            print("getshape", type(merged_fake_images),
                  merged_fake_images.shape)
            # Skip ragged final batches whose merged shape differs from the rest,
            # since save_imgs expects uniformly shaped entries.
            if batch_buffer and merged_fake_images.shape != batch_buffer[0].shape:
                continue
            batch_buffer.append(merged_fake_images)
            count += 1

        if batch_buffer:
            # Flush everything that accumulated (including the last batch).
            save_imgs(batch_buffer, count, args.save_dir)
def main(_):
    """Train the zi2zi model and print the trainable-parameter count first."""
    run_config = tf.ConfigProto()
    # Claim GPU memory on demand instead of pre-allocating it all.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        model = UNet(args.experiment_dir,
                     batch_size=args.batch_size,
                     experiment_id=args.experiment_id,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num,
                     embedding_dim=args.embedding_dim,
                     L1_penalty=args.L1_penalty,
                     Lconst_penalty=args.Lconst_penalty,
                     Ltv_penalty=args.Ltv_penalty,
                     Lcategory_penalty=args.Lcategory_penalty)
        model.register_session(sess)

        # Flipped labels drop the target-source branch from the graph.
        if args.flip_labels:
            model.build_model(is_training=True, inst_norm=args.inst_norm, no_target_source=True)
        else:
            model.build_model(is_training=True, inst_norm=args.inst_norm)

        # Optionally restrict training to a comma-separated subset of embedding ids.
        fine_tune_list = None
        if args.fine_tune:
            fine_tune_list = {int(token) for token in args.fine_tune.split(",")}

        print("***************** number of parameters *******************")

        def count_trainable_params():
            """Sum (and print per-variable) the sizes of all trainable variables."""
            total = 0
            for var in tf.trainable_variables():
                size = reduce(mul, [dim.value for dim in var.get_shape()], 1)
                print(var.name, size)
                total += size
            return total

        print(count_trainable_params())
        print("***************** number of parameters *******************")

        model.train(lr=args.lr,
                    epoch=args.epoch,
                    resume=args.resume,
                    schedule=args.schedule,
                    freeze_encoder=args.freeze_encoder,
                    fine_tune=fine_tune_list,
                    sample_steps=args.sample_steps,
                    checkpoint_steps=args.checkpoint_steps,
                    flip_labels=args.flip_labels)
示例#8
0
def main(_):
    """Train the zi2zi UNet model with all penalty weights from ``args``."""
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    # TensorFlow could instead auto-select a GPU via
    # ConfigProto(allow_soft_placement=True, log_device_placement=True).
    with tf.Session(config=run_config) as sess:
        model = UNet(args.experiment_dir,
                     batch_size=args.batch_size,
                     experiment_id=args.experiment_id,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num,
                     embedding_dim=args.embedding_dim,
                     L1_penalty=args.L1_penalty,
                     Lconst_penalty=args.Lconst_penalty,
                     Ltv_penalty=args.Ltv_penalty,
                     Lcategory_penalty=args.Lcategory_penalty)
        model.register_session(sess)

        # Flipped labels drop the target-source branch from the graph.
        if args.flip_labels:
            model.build_model(is_training=True, inst_norm=args.inst_norm, no_target_source=True)
        else:
            model.build_model(is_training=True, inst_norm=args.inst_norm)

        # Optionally restrict training to a comma-separated subset of embedding ids.
        fine_tune_list = None
        if args.fine_tune:
            fine_tune_list = {int(token) for token in args.fine_tune.split(",")}

        model.train(lr=args.lr,
                    epoch=args.epoch,
                    resume=args.resume,
                    schedule=args.schedule,
                    freeze_encoder=args.freeze_encoder,
                    fine_tune=fine_tune_list,
                    sample_steps=args.sample_steps,
                    checkpoint_steps=args.checkpoint_steps,
                    flip_labels=args.flip_labels)
def infer_by_text_api2(str, str2, embedding_id, path):
    """Render the characters of *str* and *str2* with the SimSun source font,
    run them through a trained UNet generator styled by *embedding_id*, and
    write the merged result via ``save_imgs2``. Returns *path*.

    NOTE(review): the first parameter shadows the builtin ``str``; renaming it
    would break keyword callers, so it is kept but aliased to ``text`` below.

    Fixes over the original: removed dead duplicate assignments
    (batch_size/model_dir/text were each set twice) and deduplicated the two
    verbatim-copied example-building and generation loops into helpers.
    """
    rootpath = os.path.dirname(os.path.abspath(__file__))
    print(rootpath)

    # Fixed inference configuration (mirrors the CLI defaults of this script).
    experiment_id = 0
    batch_size = 32
    model_dir = os.path.join(rootpath, "experiments/checkpoint/experiment_0")
    save_dir = os.path.join(rootpath, "save_dir")
    embedding_dim = EMBEDDING_DIM
    inst_norm = 1
    char_size = CHAR_SIZE
    canvas_size = CANVAS_SIZE
    embedding_num = 185
    text = str

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    src_font = ImageFont.truetype(
        os.path.join(rootpath, 'data/raw_fonts/SimSun.ttf'), size=char_size)

    def _make_examples(chars):
        """Draw each character paired with itself and JPEG-encode the result."""
        examples = []
        for ch in chars:
            src_img = draw_single_char_by_font(ch, src_font, canvas_size,
                                               char_size)
            paired_img = draw_paired_image(src_img, src_img, canvas_size)
            buffered = BytesIO()
            paired_img.save(buffered, format="JPEG")
            examples.append((embedding_id, buffered.getvalue()))
        return examples

    with tf.Session(config=config) as sess:
        model = UNet(batch_size=batch_size,
                     input_width=canvas_size,
                     output_width=canvas_size,
                     experiment_id=experiment_id,
                     embedding_dim=embedding_dim,
                     embedding_num=embedding_num)
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=inst_norm)
        model.load_model(model_dir)

        batch_iter1 = get_batch_iter(_make_examples(list(text)), batch_size,
                                     augment=False)
        batch_iter2 = get_batch_iter(_make_examples(list(str2)), batch_size,
                                     augment=False)

        count = 0
        batch_buffer = list()

        def _collect(batch_iter):
            """Generate fake samples for every batch and buffer the merged images."""
            for _, images in batch_iter:
                # Inject the requested embedding style for the whole batch.
                labels = [embedding_id] * len(images)
                fake_imgs = model.generate_fake_samples(images, labels)[0]
                merged_fake_images = merge(scale_back(fake_imgs),
                                           [-1, 1])  # scale 0-1
                batch_buffer.append(merged_fake_images)

        _collect(batch_iter1)
        _collect(batch_iter2)

        if batch_buffer:
            # Flush everything that accumulated (including the last batch).
            save_imgs2(batch_buffer, count, save_dir, path)

        # Drop the graph reference before leaving the session.
        model = None

    return path