示例#1
0
def g_init_model(features, labels, mode, params):
    """TPUEstimator model function that trains the generator on MSE only.

    In PREDICT mode the generator is built in inference mode and its output
    is exposed under the ``'generated_images'`` prediction key.  In every
    other mode the generator is trained to minimise the mean squared error
    between its output and ``labels``.

    Args:
        features: Input tensor handed to ``SRGAN_g``.
        labels: Target tensor the generator output is compared against.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Unused; discarded immediately.

    Returns:
        A ``tf.contrib.tpu.TPUEstimatorSpec`` for the requested mode.
    """
    del params

    if mode == tf.estimator.ModeKeys.PREDICT:
        inference_net = SRGAN_g(features, is_train=False)
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode, predictions={'generated_images': inference_net.outputs})

    generator = SRGAN_g(features, is_train=True)
    # The discriminator is built but its result is discarded.
    # NOTE(review): presumably this only exists so the discriminator
    # variables are present in the graph/checkpoint -- confirm.
    SRGAN_d(labels, is_train=True)

    pixel_loss = tl.cost.mean_squared_error(generator.outputs,
                                            labels,
                                            is_mean=True)

    # Only variables whose name contains 'SRGAN_g' are optimised.
    generator_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        learning_rate = tf.Variable(config.TRAIN.lr_init, trainable=False)

    # Wrap Adam so gradients are aggregated across TPU shards.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(
        tf.train.AdamOptimizer(learning_rate, beta1=config.TRAIN.beta1))
    train_op = optimizer.minimize(pixel_loss,
                                  var_list=generator_vars,
                                  global_step=tf.train.get_global_step())

    return tf.contrib.tpu.TPUEstimatorSpec(mode,
                                           loss=pixel_loss,
                                           train_op=train_op)
示例#2
0
def srgan_model(features, labels, mode, params):
    """Estimator model function for joint generator/discriminator training.

    In PREDICT mode only the generator is built (inference mode) and its
    output is returned under ``'generated_images'``.  Otherwise the generator
    loss (MSE + VGG feature loss + adversarial loss) and the discriminator
    loss are minimised together in a single grouped train op.

    Args:
        features: Input tensor handed to ``SRGAN_g``.
        labels: Target tensor; fed to the discriminator as the "real" sample
            and used as the reference for the MSE / VGG losses.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Unused; discarded immediately.

    Returns:
        A ``tf.estimator.EstimatorSpec``.  The reported loss is the generator
        loss.
    """
    del params

    if mode == tf.estimator.ModeKeys.PREDICT:
        net_g_test = SRGAN_g(features, is_train=False)

        predictions = {'generated_images': net_g_test.outputs}

        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    net_g = SRGAN_g(features, is_train=True)
    _, logits_real = SRGAN_d(labels, is_train=True)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True)

    # Both target and generated images are resized to 224x224 before being
    # passed to the VGG feature extractor.
    t_target_image_224 = tf.image.resize_images(labels,
                                                size=[224, 224],
                                                method=0,
                                                align_corners=False)
    t_predict_image_224 = tf.image.resize_images(net_g.outputs,
                                                 size=[224, 224],
                                                 method=0,
                                                 align_corners=False)

    # (x + 1) / 2 shifts the images before VGG.
    # NOTE(review): assumes images are in [-1, 1] so VGG sees [0, 1] -- confirm.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2)

    # Discriminator: real samples -> 1, generated samples -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: weighted adversarial + pixel MSE + weighted VGG feature MSE.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, labels, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(config.TRAIN.lr_init, trainable=False)

    # Only the generator update increments the global step.  Passing
    # `global_step` to both `minimize` calls would advance it by two per
    # training iteration, so step-based limits and hooks would fire twice
    # as fast as intended.
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1).minimize(
        g_loss, var_list=g_vars, global_step=tf.train.get_global_step())
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1).minimize(
        d_loss, var_list=d_vars)

    joint_op = tf.group([g_optim, d_optim])

    load_vgg(net_vgg)

    return tf.estimator.EstimatorSpec(mode, loss=g_loss, train_op=joint_op)
示例#3
0
    def __init__(self, args, global_net, vgg_net, train_dataloader,
                 test_dataloader):
        """Set up networks, optimizers, losses and bookkeeping for the task.

        Args:
            args: Settings object; ``cuda``, ``lr`` and ``o_dir`` are read.
            global_net: Network whose weights are copied into a freshly
                constructed local ``SRGAN_g``.
            vgg_net: Network handed to ``utils.ContentLoss``.
            train_dataloader: Stored as ``self.train_dataloader``.
            test_dataloader: Stored as ``self.test_dataloader``.
        """
        BasicTask.__init__(self)
        self.args = args

        # Data sources
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader

        # Networks: the local generator starts as a copy of the global one.
        self.global_net = global_net
        self.net_g = SRGAN_g()
        self.vgg_net = vgg_net
        self.net_g.load_state_dict(self.global_net.state_dict())

        if args.cuda:
            for attr_name in ('global_net', 'net_g', 'vgg_net'):
                setattr(self, attr_name, getattr(self, attr_name).cuda())

        # Optimizers: one Adam per network, same hyper-parameters.
        def _make_adam(network):
            return torch.optim.Adam(network.parameters(),
                                    lr=args.lr,
                                    weight_decay=1e-6)

        self.opt = _make_adam(self.net_g)
        self.global_opt = _make_adam(self.global_net)

        # Loss terms, each with the callable and its weighting factor.
        self.loss_items = dict(
            vgg=dict(func=utils.ContentLoss(self.vgg_net), factor=2e-6),
            mse=dict(func=F.mse_loss, factor=1.0),
        )

        # Summary writer targeting the output directory.
        self.writer = SummaryWriter(args.o_dir)

        # NOTE: resuming net_g from '../weights/SR_epoch09.pth' when
        # args.resume is set is currently disabled.

        # RL parameters
        self.max_episode_steps = 100
        self.action_dim = 1
        self.state_dim = 1
        self.env_step = 0
示例#4
0
def evaluate(args):
    """Super-resolve the image at ``args.input`` and save it to ``args.output``.

    The generator weights are loaded from ``checkpoint/g_srgan.npz``.

    Args:
        args: Object providing ``input`` (path of the low-resolution image)
            and ``output`` (path the generated image is written to).
    """
    # The per-mode sample folder is still created, although the result
    # itself is written to ``args.output``.
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)

    valid_lr_img = scipy.misc.imread(args.input, mode='RGB')

    ###========================== DEFINE MODEL ============================###
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    size = valid_lr_img.shape
    print("Inpu image size: " + str(size) )
    # Height and width are left unspecified so any image size can be fed.
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    # The context manager guarantees the session is closed (and its
    # resources released) even if loading or inference raises.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess, name='checkpoint/g_srgan.npz', network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
        print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))
    tl.vis.save_image(out[0], args.output)
示例#5
0
def evaluate():
    """Run the generator on a fixed list of images and save comparison outputs.

    For each image id the function saves the generated image, a copy of it
    resized to 2160x3840, the normalised low-resolution input, and a 2x
    bicubic upscale of the input into ``samples/<mode>``.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"
    image_ids = (5, 29, 35, 62, 78, 83, 150, 192, 258, 289, 310)

    ###========================== DEFINE MODEL ============================###
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # device_count={'GPU': 0} keeps inference on the CPU.  The context
    # manager closes the session even if an iteration raises.
    session_config = tf.ConfigProto(device_count={'GPU':0},
                                    allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

        for i in image_ids:
            valid_lr_img = get_imgs_fn(str(i)+'.png', '../..//model3/srgan/data2017/LR/')
            valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
            size = valid_lr_img.shape

            ###======================= EVALUATION =========================###
            start_time = time.time()
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
            print("took: %4.4fs" % (time.time() - start_time))

            print("LR size: %s /  generated HR size: %s" % (size, out.shape))
            print("[*] save images")
            tl.vis.save_image(out[0], save_dir + '/valid_gen_8k_'+str(i)+'.png')
            out_4 = scipy.misc.imresize(out[0], [2160, 3840], mode=None)
            tl.vis.save_image(out_4, save_dir + '/valid_gen_4k_'+str(i)+'.png')
            tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr_'+str(i)+'.png')

            # Bicubic baseline at twice the input resolution.
            out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 2, size[1] * 2], interp='bicubic', mode=None)
            tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic_'+str(i)+'.png')
示例#6
0
文件: main.py 项目: GinZhu/srgan
def evaluate():
    """Run the generator on every validation image and save comparison outputs.

    Low- and high-resolution validation images are loaded from the paths in
    ``config.VALID``.  For each low-resolution image the generated image, the
    normalised input, the matching high-resolution image and a 4x bicubic
    upscale are written to the sample folder.
    """
    ## create folders to save result images
    save_dir = "/local/scratch/jz426/superResolution/results/samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "/local/scratch/jz426/superResolution/results/checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)

    ###========================== DEFINE MODEL ============================###
    # Height and width are left unspecified so any image size can be fed.
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    # The context manager closes the session even if an iteration raises.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

        ###======================= EVALUATION =============================###
        for imid, valid_lr_img in enumerate(valid_lr_imgs):
            # The HR list is indexed in lockstep with the LR list; a shorter
            # HR list still raises IndexError, as before.
            valid_hr_img = valid_hr_imgs[imid]
            valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

            size = valid_lr_img.shape
            start_time = time.time()
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
            print("took: %4.4fs" % (time.time() - start_time))

            print("LR size: %s /  generated HR size: %s" % (size, out.shape))
            print("[*] save images")
            tl.vis.save_image(out[0], save_dir + '/' + str(imid) + 'valid_gen.png')
            tl.vis.save_image(valid_lr_img, save_dir + '/' + str(imid) + 'valid_lr.png')
            tl.vis.save_image(valid_hr_img, save_dir + '/' + str(imid) + 'valid_hr.png')

            # Bicubic baseline at four times the input resolution.
            out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
            tl.vis.save_image(out_bicu, save_dir + '/' + str(imid) + 'valid_bicubic.png')
示例#7
0
def upscale_function(image, model_checkpoint, reuse=False):
    """Upscale a single image with the generator and return the result.

    A fresh placeholder, generator and session are created on every call;
    the session is closed and the default graph reset before returning.

    Args:
        image: Image array; it is rescaled with ``image / 127.5 - 1`` before
            being fed to the network.
            NOTE(review): assumes 8-bit pixel values in [0, 255] and three
            channels (the placeholder is ``[1, None, None, 3]``) -- confirm.
        model_checkpoint: Path of the ``.npz`` weights file for the generator.
        reuse: Forwarded to ``SRGAN_g`` as its ``reuse`` argument.

    Returns:
        The first (and only) element of the generator output batch.
    """
    valid_lr_img = (image / 127.5) - 1  # rescale to [-1, 1]

    try:
        ###======================== DEFINE MODEL ==========================###
        # Height and width are left unspecified so any image size can be fed.
        t_image = tf.placeholder('float32', [1, None, None, 3],
                                 name='input_image')
        net_g = SRGAN_g(t_image, is_train=False, reuse=reuse)

        ###========================= RESTORE G ============================###
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        # The context manager closes the session even when loading the
        # weights or running inference raises.
        with tf.Session(config=session_config) as sess:
            tl.layers.initialize_global_variables(sess)
            tl.files.load_and_assign_npz(sess=sess,
                                         name=model_checkpoint,
                                         network=net_g)

            ###====================== EVALUATION ==========================###
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    finally:
        # Always clear the graph so a failed call does not leave stale
        # variables behind for the next call.
        tf.reset_default_graph()

    return out[0]
示例#8
0
def evaluate():
    """Save HR / bicubic / SRGAN comparison images for validation samples.

    Validation images are read via ``read_csv_data`` as 48x48 single-channel
    images.  A batch of nine samples and one single image are cropped,
    downsampled by 3, and then upscaled both by ``upsample_fn`` (bicubic
    baseline) and by the generator.  All results go to ``samples/valid``.

    Relies on the module-level name ``ni`` for the sample grid size.
    """
    ## create folders to save result images
    save_dir = "samples/valid"
    tl.files.exists_or_mkdir(save_dir)

    ###================ PRE-LOAD DATA and SAVE SAMPLEs ==================###
    valid_hr_imgs = read_csv_data(config.VALID.hr_img_path,
                                  width=48,
                                  height=48,
                                  channel=1)
    sample_hr_imgs = valid_hr_imgs[10:19]
    sample_hr_imgs = tl.prepro.threading_data(sample_hr_imgs,
                                              fn=crop_sub_imgs_fn,
                                              is_random=False)
    tl.vis.save_images(sample_hr_imgs, [ni, ni], save_dir + '/_hr_sample.png')

    sample_lr_imgs = tl.prepro.threading_data(sample_hr_imgs,
                                              fn=downsample_fn,
                                              down_rate=3)
    sample_bicubuc_imgs = tl.prepro.threading_data(sample_lr_imgs,
                                                   fn=upsample_fn,
                                                   up_rate=3)
    tl.vis.save_images(sample_bicubuc_imgs, [ni, ni],
                       save_dir + '/_bicubic_sample.png')

    single_hr_img = crop_sub_imgs_fn(valid_hr_imgs[1], is_random=False)
    tl.vis.save_image(single_hr_img, save_dir + '/_hr.png')
    single_lr_img = downsample_fn(single_hr_img, down_rate=3)
    single_bicubic_img = upsample_fn(single_lr_img, up_rate=3)
    tl.vis.save_image(single_bicubic_img, save_dir + '/_bicubic.png')

    ###========================== DEFINE MODEL ============================###
    t_image = tf.placeholder('float32', [None, 16, 16, 1], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    # The context manager closes the session even if inference raises.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=config.MODEL_path,
                                     network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        # Add a leading batch axis for the single image, drop it afterwards.
        single_lr_img = np.expand_dims(single_lr_img, axis=0)
        out = sess.run(net_g.outputs, {t_image: single_lr_img})
        out = np.squeeze(out, axis=0)
        print("took: %4.4fs" % (time.time() - start_time))
        tl.vis.save_image(out, save_dir + '/_srgan.png')

        out1 = sess.run(net_g.outputs, {t_image: sample_lr_imgs})
        tl.vis.save_images(out1, [ni, ni], save_dir + '/_srgan_samples.png')
示例#9
0
def predict(test_lr_path, checkpoint_path, save_path):
    """Super-resolve every ``.jpg`` image in a folder and save the results.

    Args:
        test_lr_path: Directory containing the low-resolution test images.
        checkpoint_path: Directory containing ``g_srgan.npz``.
        save_path: Directory under which a ``test_gen`` folder is created
            for the generated images (saved under their original names).
    """
    ## create folders to save result images
    save_dir = os.path.join(save_path, 'test_gen')
    tl.files.exists_or_mkdir(save_dir)

    ###======PRE-LOAD DATA======###
    test_lr_img_list = sorted(
        tl.files.load_file_list(path=test_lr_path,
                                regx='.*.jpg',
                                printable=False))

    test_lr_imgs = tl.vis.read_images(test_lr_img_list,
                                      path=test_lr_path,
                                      n_threads=32)

    ###======DEFINE MODEL======###
    test_lr_imgs = [(img / 127.5) - 1
                    for img in test_lr_imgs]  # rescale to [-1, 1]

    test_image = tf.placeholder('float32', [1, None, None, 3],
                                name='input_image')

    net_g = SRGAN_g(test_image, is_train=False, reuse=False)

    ###======RESTORE G======###
    # The context manager closes the session even if an image fails.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        tl.files.load_and_assign_npz(sess=sess,
                                     name=os.path.join(checkpoint_path,
                                                       'g_srgan.npz'),
                                     network=net_g)

        ###======EVALUATION======###
        start_time = time.time()
        for i, (img_name, img) in enumerate(zip(test_lr_img_list,
                                                test_lr_imgs)):
            out = sess.run(net_g.outputs, {test_image: [img]})
            # Map the network output back with (x + 1) * 127.5 before the
            # uint8 cast (inverse of the input rescaling above).
            out = (out[0] + 1) * 127.5
            tl.vis.save_image(out.astype(np.uint8),
                              os.path.join(save_dir, '{}'.format(img_name)))
            if (i != 0) and (i % 10 == 0):
                print('saving %d images, ok' % i)

        print('take: %4.2fs' % (time.time() - start_time))
示例#10
0
def upsample_images(pth_src, pth_dst, pth_checkpoint):
    """Upscale every jpg/png image in ``pth_src`` and write it to ``pth_dst``.

    Images whose height or width exceeds the module-level ``MAX_SIZE`` are
    skipped.  Output files keep the name of their source file.

    Args:
        pth_src: Directory containing the source images.
        pth_dst: Directory the upscaled images are written to (created if
            missing).
        pth_checkpoint: Path of the ``.npz`` generator weights file.
    """
    import numpy as np
    # `import scipy` alone does not guarantee the `misc` subpackage is
    # loaded; import it explicitly so `scipy.misc.imread` is available.
    import scipy.misc

    import tensorlayer as tl
    import tensorflow as tf
    from model import SRGAN_g

    print("==== UPSAMPLING IMAGES")
    ## create folders to save result images
    tl.files.exists_or_mkdir(pth_dst)

    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # Restore Generator.  The context manager closes the session even if an
    # image fails to load or process.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=pth_checkpoint,
                                     network=net_g)
        print("loaded srgan model from {}".format(pth_checkpoint))

        valid_lr_img_list = sorted(
            tl.files.load_file_list(path=pth_src,
                                    regx='.*.(jpg|png)',
                                    printable=False))
        print("found {} valid images to upscale...".format(
            len(valid_lr_img_list)))
        for n, fname in enumerate(valid_lr_img_list):
            img_src = scipy.misc.imread(os.path.join(pth_src, fname),
                                        mode='RGB')
            img_src = (img_src / 127.5) - 1  # rescale to [-1, 1]
            size = img_src.shape
            if size[0] > MAX_SIZE or size[1] > MAX_SIZE:
                print("Image is too big ({}x{}). Skipping.".format(
                    size[0], size[1]))
                continue

            # Evaluate
            start_time = time.time()
            img_dst = sess.run(net_g.outputs, {t_image: [img_src]})
            img_dst = ((img_dst + 1) / 2.0) * 255  # rescale to [0,255]
            # convert to unsigned int for saving to image
            img_dst = img_dst.astype(np.uint8)
            print(
                "{} of {}\tUpsampling {} from {}x{} to {}x{} took {:.2f}s".
                format(n, len(valid_lr_img_list), fname, size[0], size[1],
                       img_dst.shape[1], img_dst.shape[2],
                       time.time() - start_time))
            tl.vis.save_image(img_dst[0], os.path.join(pth_dst, fname))
示例#11
0
def evaluate():
    """Super-resolve the first evaluation image and save it with a baseline.

    The first file returned by ``load_deep_file_list`` is rescaled with
    ``rescale_m1p1``, run through the generator, and saved together with a
    4x bicubic upscale of the same input.  Paths, the file regex and the
    output format come from module-level settings.
    """
    ## create folders to save result images
    save_dir = samples_path + "evaluate"
    tl.files.exists_or_mkdir(save_dir)

    ###========================== DEFINE MODEL ============================###
    eval_img_name_list = load_deep_file_list(path=eval_img_path,
                                             regx=eval_img_name_regx,
                                             recursive=False,
                                             printable=False)
    print(eval_img_name_list)
    # Only the first matching file is evaluated.
    valid_lr_img = get_imgs_fn(eval_img_name_list[0], eval_img_path)
    valid_lr_img = rescale_m1p1(valid_lr_img)

    size = valid_lr_img.shape
    t_image = tf.compat.v1.placeholder('float32', [1, None, None, 3],
                                       name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    # The context manager closes the session even if inference raises.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        sess.run(tf.variables_initializer(tf.global_variables()))
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_path + 'g.npz',
                                     network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
        print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))
    print("[*] save images")
    out = (out + 1) * 127.5  # rescale to [0, 255]
    out_uint8 = out.astype('uint8')
    save_img_fn(out_uint8[0], save_file_format, save_dir + '/valid_gen')

    out_bicu = (valid_lr_img + 1) * 127.5  # rescale to [0, 255]
    # PIL's resize takes (width, height), hence size[1] before size[0].
    out_bicu = np.array(
        Image.fromarray(np.uint8(out_bicu)).resize((size[1] * 4, size[0] * 4),
                                                   Image.BICUBIC))
    out_bicu_uint8 = out_bicu.astype('uint8')
    save_img_fn(out_bicu_uint8, save_file_format, save_dir + '/valid_bicubic')
示例#12
0
def super_resolution_image(image_path):
    """Super-resolve one image, hand the result to ``send_to_ps``, clean up.

    The generated image is written under ``<script dir>\\output_images`` with
    a random ``.jpg`` name, passed to ``send_to_ps``, and afterwards every
    file in ``<script dir>\\temp`` is removed.

    Args:
        image_path: Path of the input image (Windows-style separators are
            handled via ``ntpath``).
    """
    filename = ntpath.basename(image_path)
    # Strip only the trailing basename.  ``str.replace`` would also remove
    # the same text wherever else it occurs in the path (for example a
    # parent folder that shares the file's name) and corrupt the directory.
    filepath = image_path[:len(image_path) - len(filename)]
    scriptpath = os.path.dirname(os.path.realpath(__file__))
    temppath = '{0}\\temp'.format(scriptpath)

    # NOTE: relies on the private ``tempfile._get_candidate_names`` helper
    # to obtain a random file name.
    tempname = next(tempfile._get_candidate_names()) + ".jpg"
    output_filename = '{0}\\output_images\\{1}'.format(scriptpath, tempname)

    ###========================== DEFINE MODEL ============================###
    valid_lr_img = get_imgs_fn(filename, filepath)
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    # The context manager closes the session even if inference raises.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name='{0}\\g_srgan.npz'.format(scriptpath),
                                     network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})

    print("took: %4.4fs" % (time.time() - start_time))
    print("[*] save images")

    tl.vis.save_image(out[0], output_filename)
    send_to_ps(output_filename)

    # clear temp folder
    for temp_file in os.listdir(temppath):
        os.remove(os.path.join(temppath, temp_file))
示例#13
0
def evaluate(args):
    """Super-resolve frames ``frame_0001.ppm`` .. ``frame_0072.ppm``.

    Each frame is read from ``args.input_dir``, run through the generator
    and saved as ``frame_NNNN.png`` inside ``args.output_dir``.

    Args:
        args: Object providing ``input_dir`` and ``output_dir``.
    """
    import os

    ## create folders to save result images
    save_dir = args.output_dir
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    # Height and width are left unspecified so any frame size can be fed.
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    tf_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of grabbing it all up front.
    tf_config.gpu_options.allow_growth = True
    # The context manager closes the session even if a frame fails.
    with tf.Session(config=tf_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        ###======================= EVALUATION =============================###
        for i in range(1, 73):
            valid_lr_img = get_imgs_fn("frame_%04d" % i + ".ppm",
                                       args.input_dir)
            valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
            # Report the size of the frame actually processed rather than
            # the size of the first frame.
            size = valid_lr_img.shape
            start_time = time.time()
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
            print("took: %4.4fs" % (time.time() - start_time))
            print("LR size: %s /  generated HR size: %s" % (size, out.shape))
            print("[*] save images %d" % i)
            # Join with a path separator so the file lands inside save_dir
            # even when output_dir has no trailing slash.
            tl.vis.save_image(
                out[0], os.path.join(save_dir, "frame_%04d" % i + ".png"))
def visualize(epoch):
    """Build the single-channel generator and load the weights for ``epoch``.

    Args:
        epoch: Value formatted into the checkpoint file name
            ``g_srgan_<epoch>.npz``.
    """
    checkpoint_dir = "/Users/btopiwala/Downloads/CS231N/2018/Project/gcloud-run-all-data/checkpoint_epoch_20_epoch_48_with_intermediate_checkpoint/checkpoint"
    weights_file = checkpoint_dir + '/g_srgan_{}.npz'.format(epoch)

    ###========================== DEFINE MODEL ============================###
    # The last dimension is 1: the network takes single-channel input.
    input_shape = [1, None, None, 1]
    t_image = tf.placeholder('float32', input_shape, name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    sess = tf.Session(config=session_config)
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=weights_file, network=net_g)
示例#15
0
def evaluate():
    """Run saved generator checkpoints over the validation images.

    For every checkpoint epoch in ``range(200, 250, 150)`` (which yields only
    200) the weights ``g_srgan_softmax<epoch>.npz`` are loaded and up to 1135
    validation images are processed; each output is saved under
    ``samples/evaluate/<epoch>/`` with the input file's name.
    """
    checkpoint_dir = "/home/fan/su/remove_face/checkpoint/facenet_pgd_joint_loss/"
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)
    # Number of input samples to process: capped at the number of images
    # actually available so a smaller validation set no longer raises
    # IndexError.
    n_samples = min(1135, len(valid_lr_imgs))

    # Define the model
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # The context manager closes the session even if an image fails.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)

        for epoch_n in range(200, 250, 150):
            save_dir = "samples/evaluate/" + str(epoch_n)
            tl.files.exists_or_mkdir(save_dir)
            tl.files.load_and_assign_npz(sess=sess,
                                         name=checkpoint_dir +
                                         '/g_srgan_softmax%d.npz' % epoch_n,
                                         network=net_g)
            for imid in range(n_samples):
                valid_lr_img = valid_lr_imgs[imid]
                print(valid_lr_img)
                valid_lr_img = (valid_lr_img / 127.5) - 1  # normalise to [-1, 1]
                # Start evaluation
                start_time = time.time()
                out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
                print("took: %4.4fs" % (time.time() - start_time))
                print("[*] save images")
                print(out)
                tl.vis.save_image(
                    out[0], save_dir + '/%s' % (valid_lr_img_list[imid]))
def export_model():
    """Load the generator the TensorLayer way and export it as a TF graph.

    The weights are read from ``checkpoint/g_srgan.npz``; a TensorFlow
    checkpoint (prefix ``./meta/srgan``) and a text graph definition
    (``./meta/srgan.pbtxt``) are written.

    Args:
        None

    Returns:
        None
    """
    checkpoint_dir = "checkpoint"

    ###========================== DEFINE MODEL ============================###
    # Batch size, height and width are all left unspecified in the export.
    t_image = tf.placeholder('float32', [None, None, None, 3],
                             name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # Make sure the export folder exists before saving into it.
    # NOTE(review): tf.train.Saver.save is expected to fail when the parent
    # directory of the save path is missing -- confirm for the TF version
    # in use.
    tl.files.exists_or_mkdir('./meta')

    ###========================== RESTORE G =============================###
    # The context manager closes the session even if the export raises.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)

        # Load model from .npz file
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        # export to meta file
        saver = tf.train.Saver()
        saver.save(sess, './meta/srgan')
        tf.train.write_graph(sess.graph.as_graph_def(),
                             '.',
                             './meta/srgan.pbtxt',
                             as_text=True)
def evaluate():
    """Super-resolve one validation image and print PSNR / SSIM figures.

    Loads every validation LR/HR PNG as a single-channel image, runs the
    generator restored from ``checkpoint/g_srgan.npz`` on the image at index 1
    and writes the generated, LR, HR and 4x bicubic images to
    ``samples/<mode>``.  PSNR (via ``computePSNR``) is printed for the bicubic
    and generated images against the HR image, plus SSIM for the generated one.
    """

    def _load_single_channel(img_dir, file_names):
        """Read each file with ``mode='L'`` and reshape it to (H, W, 1)."""
        images = []
        for file_name in file_names:
            image_loaded = scipy.misc.imread(os.path.join(img_dir, file_name),
                                             mode='L')
            images.append(
                image_loaded.reshape(
                    (image_loaded.shape[0], image_loaded.shape[1], 1)))
        return images

    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    valid_lr_imgs = _load_single_channel(config.VALID.lr_img_path,
                                         valid_lr_img_list)
    print(type(valid_lr_imgs), len(valid_lr_img_list))

    valid_hr_imgs = _load_single_channel(config.VALID.hr_img_path,
                                         valid_hr_img_list)
    print(type(valid_hr_imgs), len(valid_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    # index of the validation pair to evaluate (sorted file order)
    imid = 1
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    size = valid_lr_img.shape
    t_image = tf.placeholder('float32', [1, None, None, 1],
                             name='input_image')  # 1 for 1 channel

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
        print("took: %4.4fs" % (time.time() - start_time))
    finally:
        # the session is only needed for the single forward pass
        sess.close()

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))
    print("[*] save images")
    tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
    tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr.png')
    tl.vis.save_image(valid_hr_img, save_dir + '/valid_hr.png')

    # drop the channel axis: imresize is given a plain 2-D image here
    valid_lr_img = valid_lr_img.reshape(valid_lr_img.shape[0],
                                        valid_lr_img.shape[1])
    out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4],
                                   interp='bicubic')
    tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')

    # metrics are computed on the files that were just written
    hr_img_path = save_dir + '/valid_hr.png'
    bi_cubic_img_path = save_dir + '/valid_bicubic.png'
    gen_img_path = save_dir + '/valid_gen.png'

    bicubic_psnr = computePSNR(hr_img_path, bi_cubic_img_path)
    gen_psnr = computePSNR(hr_img_path, gen_img_path)

    gnd_truth_hr_img = scipy.misc.imread(hr_img_path, mode='L')
    generated_hr_img = scipy.misc.imread(gen_img_path, mode='L')

    gen_ssim = skimage.measure.compare_ssim(gnd_truth_hr_img, generated_hr_img)

    print("Bicubic image PSNR:", bicubic_psnr)
    print("Generated image PSNR:", gen_psnr)
    print('Generated image SSIM:', gen_ssim)
def train():
    """Train the SRGAN generator alone (SRResNet) on single-channel images.

    The loss is the pixel MSE plus a 2e-6-weighted MSE between the VGG19
    embeddings of the target and of the generated image (both are converted
    to 224x224 RGB before entering VGG).  The module-level hyper-parameters
    ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch``, ``lr_decay``,
    ``decay_every`` and ``ni`` are read from the enclosing module.

    Side effects:
        * sample images are written under ``samples/<mode>_gan``;
        * per-epoch scalar summaries go to ``./log/train/generator`` and
          ``./log/train/learning_rate``;
        * generator weights are saved to ``checkpoint/g_<mode>_<epoch>.npz``
          every third epoch (epoch 0 excluded).

    ``vgg19.npy`` must be present in the working directory; the function
    calls ``exit()`` otherwise.
    """
    ## create folders to save result images and trained model
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])  #srresnet
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine has enough memory, please pre-load the whole train set.
    print("reading images")

    train_hr_imgs = []

    for img__ in train_hr_img_list:
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path,
                                                      img__),
                                         mode='L')
        # add an explicit channel axis: (H, W) -> (H, W, 1)
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))
        train_hr_imgs.append(image_loaded)

    print(type(train_hr_imgs), len(train_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference

    t_image = tf.placeholder('float32', [batch_size, 56, 56, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1],
                                    name='t_target_image')

    # print the static shapes (tf.shape would only print an op tensor)
    print("t_image:", t_image.get_shape())
    print("t_target_image:", t_target_image.get_shape())

    net_g = SRGAN_g(t_image, is_train=True,
                    reuse=False)  #SRGAN_g is the SRResNet portion of the GAN

    print("net_g.outputs:", net_g.outputs.get_shape())

    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)

    print("net_g.outputs:", net_g.outputs.get_shape())
    print("t_predict_image_224:", t_predict_image_224.get_shape())

    # images are in [-1, 1]; VGG is fed values shifted to [0, 1]
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss

    # TensorFlow replaces the spaces in these names with underscores, which is
    # why the tags are read back below as 'Generator_MSE_loss' etc.
    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## SRResNet
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir +
                                 '/g_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_g)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    # flatten the {layer_name: [W, b]} dict into the [W, b, W, b, ...] list
    # expected by assign_params, in sorted layer-name order
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]

    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)

    sample_imgs_224 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_224.shape, sample_imgs_224.min(),
          sample_imgs_224.max())
    sample_imgs_56 = tl.prepro.threading_data(sample_imgs_224,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_56.shape, sample_imgs_56.min(),
          sample_imgs_56.max())
    tl.vis.save_images(sample_imgs_56, [ni, ni],
                       save_dir_gan + '/_train_sample_56.png')
    tl.vis.save_images(sample_imgs_224, [ni, ni],
                       save_dir_gan + '/_train_sample_224.png')

    ###========================= train SRResNet  =========================###

    merged_summary_generator = tf.summary.merge(
        [mse_loss_summary, vgg_loss_summary,
         g_loss_summary])  #g_gan_loss_summary
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")

    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    # The placeholders have a fixed batch dimension, so a trailing partial
    # batch cannot be fed; only whole batches are iterated.
    n_full = len(train_hr_imgs) - len(train_hr_imgs) % batch_size

    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=(lr_init * new_lr_decay)),
                ]), (epoch))

        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=lr_init),
                ]), (epoch))

        epoch_time = time.time()
        total_g_loss, n_iter = 0, 0

        # per-step scalar values, averaged into one summary point per epoch
        mse_loss_summary_per_epoch = []
        vgg_loss_summary_per_epoch = []
        g_loss_summary_per_epoch = []

        for idx in range(0, n_full, batch_size):
            step_time = time.time()
            b_imgs_224 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_56 = tl.prepro.threading_data(b_imgs_224, fn=downsample_fn)

            ## update G
            errG, errM, errV, _, generator_summary = sess.run(
                [
                    g_loss, mse_loss, vgg_loss, g_optim,
                    merged_summary_generator
                ], {
                    t_image: b_imgs_56,
                    t_target_image: b_imgs_224
                })  #g_ga_loss

            # decode the serialized merged summary back into tag -> value
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(generator_summary)

            generator_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                generator_summaries[val.tag] = val.simple_value

            mse_loss_summary_per_epoch.append(
                generator_summaries['Generator_MSE_loss'])
            vgg_loss_summary_per_epoch.append(
                generator_summaries['Generator_VGG_loss'])
            g_loss_summary_per_epoch.append(
                generator_summaries['Generator_total_loss'])

            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errG, errM,
                   errV))
            total_g_loss += errG
            n_iter += 1

        # elapsed wall time of the epoch; max() guards the average against an
        # epoch in which no full batch was available
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time,
            total_g_loss / max(n_iter, 1))
        print(log)

        #####
        #
        # logging generator summary
        #
        ######

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_MSE_loss per epoch",
                                 simple_value=np.mean(
                                     mse_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_VGG_loss per epoch",
                                 simple_value=np.mean(
                                     vgg_loss_summary_per_epoch)),
            ]), (epoch))

        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_total_loss per epoch",
                                 simple_value=np.mean(
                                     g_loss_summary_per_epoch)),
            ]), (epoch))

        ## quick visual check on the fixed sample batch
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_56})
        print("[*] save images")
        tl.vis.save_image(out[0], save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 3 == 0):

            tl.files.save_npz(
                net_g.all_params,
                name=checkpoint_dir +
                '/g_{}_{}.npz'.format(tl.global_flag['mode'], epoch),
                sess=sess)

    # flush pending summary events to disk and release the session
    summary_generator_writer.close()
    learning_rate_writer.close()
    sess.close()
示例#19
0
def train():
    """Train generator and discriminator jointly with a relativistic LSGAN loss.

    The generator maps ``[batch_size, 96, 96, 3]`` inputs to
    ``[batch_size, 384, 384, 3]`` outputs and is optimised on
    ``0.1 * g_gan_loss + mae_loss``; the discriminator is optimised on
    ``d_loss``.  Paths and hyper-parameters (``samples_path``,
    ``checkpoint_path``, ``train_hr_img_path``, ``valid_hr_img_path``,
    ``batch_size``, ``sample_batch_size``, ``learning_rate``, ``n_epoch_gan``,
    ``ni``, ``save_file_format``) are read from the enclosing module.

    After every epoch a sample grid is written under ``<samples_path>gan`` and
    both networks are saved to ``<checkpoint_path>g.npz`` / ``d.npz``.
    """
    ## create folders to save result images and trained model
    save_dir_gan = samples_path + "gan"
    tl.files.exists_or_mkdir(save_dir_gan)
    tl.files.exists_or_mkdir(checkpoint_path)

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=valid_hr_img_path,
                                regx=r'.*\.(bmp|png|webp|jpg)',
                                printable=False))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    sample_t_image = tf.compat.v1.placeholder(
        'float32', [sample_batch_size, 96, 96, 3],
        name='sample_t_image_input_to_SRGAN_generator')
    t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3],
                                       name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.compat.v1.placeholder('float32',
                                              [batch_size, 384, 384, 3],
                                              name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## test inference
    net_g_test = SRGAN_g(sample_t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###

    # MAE Loss (tf.abs is already element-wise; no per-sample map_fn needed)
    mae_loss = tf.reduce_mean(tf.abs(t_target_image - net_g.outputs))

    # GAN Loss: each logit is compared against the mean logit of the other
    # class, with targets +1 / -1 swapped between D and G.
    d_loss = 0.5 * (
        tf.reduce_mean(
            tf.square(logits_real - tf.reduce_mean(logits_fake) - 1)) +
        tf.reduce_mean(
            tf.square(logits_fake - tf.reduce_mean(logits_real) + 1)))
    g_gan_loss = 0.5 * (
        tf.reduce_mean(
            tf.square(logits_real - tf.reduce_mean(logits_fake) + 1)) +
        tf.reduce_mean(
            tf.square(logits_fake - tf.reduce_mean(logits_real) - 1)))

    g_loss = 1e-1 * g_gan_loss + mae_loss

    # mean discriminator logits, fetched only for logging
    d_real = tf.reduce_mean(logits_real)
    d_fake = tf.reduce_mean(logits_fake)

    with tf.variable_scope('learning_rate'):
        learning_rate_var = tf.Variable(learning_rate, trainable=False)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    ## SRGAN
    g_optim = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate_var).minimize(g_loss, var_list=g_vars)
    d_optim = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate_var).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    sess.run(tf.variables_initializer(tf.global_variables()))
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_path + 'g.npz',
                                 network=net_g)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_path + 'd.npz',
                                 network=net_d)

    ###============================= TRAINING ===============================###
    sample_imgs = tl.prepro.threading_data(
        valid_hr_img_list[0:sample_batch_size],
        fn=get_imgs_fn,
        path=valid_hr_img_path)
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    save_images(sample_imgs_96, [ni, ni], save_file_format,
                save_dir_gan + '/_train_sample_96')
    save_images(sample_imgs_384, [ni, ni], save_file_format,
                save_dir_gan + '/_train_sample_384')

    ###========================= train GAN =========================###
    sess.run(tf.assign(learning_rate_var, learning_rate))
    for epoch in range(0, n_epoch_gan + 1):
        epoch_time = time.time()
        total_d_loss, total_g_loss_mae, total_g_loss_gan, n_iter = 0, 0, 0, 0

        # the file list is re-read and re-shuffled at the start of every epoch
        train_hr_img_list = load_deep_file_list(path=train_hr_img_path,
                                                regx=r'.*\.(bmp|png|webp|jpg)',
                                                recursive=True,
                                                printable=False)
        random.shuffle(train_hr_img_list)

        list_length = len(train_hr_img_list)
        print("Number of images: %d" % (list_length))

        # Pad to a whole number of batches by repeating files from the start
        # of the list.  Indexing modulo the length keeps the padding complete
        # even when the list is shorter than the number of missing entries.
        remainder = list_length % batch_size
        if remainder != 0:
            n_missing = batch_size - remainder
            train_hr_img_list += [
                train_hr_img_list[k % list_length] for k in range(n_missing)
            ]

        list_length = len(train_hr_img_list)
        print("Length of list: %d" % (list_length))

        for idx in range(0, list_length, batch_size):
            step_time = time.time()
            b_imgs_list = train_hr_img_list[idx:idx + batch_size]
            b_imgs = tl.prepro.threading_data(b_imgs_list,
                                              fn=get_imgs_fn,
                                              path=train_hr_img_path)
            b_imgs_384 = tl.prepro.threading_data(b_imgs,
                                                  fn=crop_data_augment_fn,
                                                  is_random=True)
            # the LR batch is derived before the HR batch is rescaled
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            b_imgs_384 = tl.prepro.threading_data(b_imgs_384, fn=rescale_m1p1)

            ## update D
            errD, d_r, d_f, _ = sess.run([d_loss, d_real, d_fake, d_optim], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            ## update G
            errM, errA, _, _ = sess.run(
                [mae_loss, g_gan_loss, g_loss, g_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            print(
                "Epoch[%2d/%2d] %4d time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f d_r: %.8f d_f: %.8f"
                % (epoch, n_epoch_gan, n_iter, time.time() - step_time, errD,
                   errM, errA, d_r, d_f))
            total_d_loss += errD
            total_g_loss_mae += errM
            total_g_loss_gan += errA
            n_iter += 1

        # an empty training folder yields no iterations; avoid dividing by 0
        n_steps = max(n_iter, 1)
        log = (
            "[*] Epoch[%2d/%2d] time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f"
            % (epoch, n_epoch_gan, time.time() - epoch_time, total_d_loss /
               n_steps, total_g_loss_mae / n_steps, total_g_loss_gan / n_steps))
        print(log)

        ## quick evaluation on train set
        out = sess.run(net_g_test.outputs, {sample_t_image: sample_imgs_96})
        print("[*] save images")
        save_images(out, [ni, ni], save_file_format,
                    save_dir_gan + '/train_%d' % epoch)

        ## save model
        tl.files.save_npz(net_g.all_params,
                          name=checkpoint_path + 'g.npz',
                          sess=sess)
        tl.files.save_npz(net_d.all_params,
                          name=checkpoint_path + 'd.npz',
                          sess=sess)

    # release the session once training has finished
    sess.close()
示例#20
0
def evaluate():
    """Run the two-input generator on up to 30 validation samples and print SSIM.

    For each sample index the two downsampled LR frames are concatenated along
    the channel axis (6 channels) and fed together with their average to the
    generator restored from ``checkpoint/g_srgan.npz``.  The generated image,
    both LR frames and the HR image are written to ``samples/<mode>/<index>/``,
    then the generated and HR PNGs are read back and their SSIM is printed.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list2 = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path2,
                                regx='.*.png',
                                printable=False))

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)
    valid_lr_imgs2 = tl.vis.read_images(valid_lr_img_list2,
                                        path=config.VALID.lr_img_path2,
                                        n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                       path=config.VALID.hr_img_path,
                                       n_threads=32)

    # Downsampling covers the whole set and does not depend on the sample
    # index, so it is done once instead of once per evaluated sample.
    valid_lr_img_t1 = tl.prepro.threading_data(valid_lr_imgs, fn=downsample_fn)
    valid_lr_img_t2 = tl.prepro.threading_data(valid_lr_imgs2,
                                               fn=downsample_fn)

    ###========================== DEFINE MODEL ============================###
    tf_gen_output = tf.placeholder('float32', [240, 320, 3])
    tf_hr_output = tf.placeholder('float32', [240, 320, 3])
    t_image = tf.placeholder('float32', [1, None, None, 6], name='input_image')
    avg_image = tf.placeholder('float32', [1, None, None, 3],
                               name='average_image')
    net_g = SRGAN_g(t_image, avg_image, is_train=False, reuse=False)

    # Build the SSIM op once; creating it per sample would keep growing the
    # graph.  NOTE(review): the PNGs read back below are assumed to hold 8-bit
    # samples (0..255), hence max_val=255 -- confirm against tl.vis.read_image.
    ssim_op = tf.image.ssim(tf_gen_output, tf_hr_output, max_val=255.0)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        # evaluate at most 30 samples, fewer when the validation set is smaller
        n_eval = min(30, len(valid_hr_imgs), len(valid_lr_img_t1),
                     len(valid_lr_img_t2))

        for imid in range(n_eval):  # 0: penguin  81: butterfly  53: bird  64: castle
            valid_lr_img_t1_d = valid_lr_img_t1[imid]
            valid_lr_img_t2_d = valid_lr_img_t2[imid]
            print(valid_lr_img_t1_d.shape)
            print(valid_lr_img_t2_d.shape)
            valid_hr_img = valid_hr_imgs[imid]
            # generator input: both frames stacked on the channel axis
            valid_lr_img = np.concatenate(
                (valid_lr_img_t1_d, valid_lr_img_t2_d), axis=2)
            valid_avg_img = (np.add(valid_lr_img_t1_d, valid_lr_img_t2_d)) / 2.

            size = valid_lr_img.shape
            print(size)

            ###======================= EVALUATION =============================###
            start_time = time.time()
            out = sess.run(net_g.outputs, {
                t_image: [valid_lr_img],
                avg_image: [valid_avg_img]
            })
            print("took: %4.4fs" % (time.time() - start_time))

            print("LR size: %s /  generated HR size: %s" % (size, out.shape))
            print("[*] save images")

            # every sample gets its own sub-folder, which must exist before
            # images can be saved into it
            sample_dir = save_dir + '/{}'.format(imid)
            tl.files.exists_or_mkdir(sample_dir)

            tl.vis.save_image(out[0], sample_dir + '/valid_gen.png')
            tl.vis.save_image(valid_lr_img_t1_d,
                              sample_dir + '/valid_lr_1.png')
            tl.vis.save_image(valid_lr_img_t2_d,
                              sample_dir + '/valid_lr_3.png')
            tl.vis.save_image(valid_hr_img, sample_dir + '/valid_hr.png')

            # SSIM is measured on the images as they were written to disk
            ssim_gen_output = tl.vis.read_image('valid_gen.png',
                                                sample_dir + '/')
            ssim_hr_output = tl.vis.read_image('valid_hr.png',
                                               sample_dir + '/')
            print(ssim_gen_output)
            print(ssim_hr_output)

            tf_ssim = sess.run(ssim_op,
                               feed_dict={
                                   tf_gen_output: ssim_gen_output,
                                   tf_hr_output: ssim_hr_output
                               })

            print(tf_ssim)
    finally:
        # release the session even if a sample fails part-way through
        sess.close()
示例#21
0
def evaluate(data, n_patients_train, eval_model, save_imgs=False):
    """Evaluate a restored generator on held-out patients and print the MSE.

    Patients are visited in the iteration order of ``data``. The first
    ``n_patients_train`` entries are skipped; every later patient is
    evaluated image by image with the project's ``mse`` helper.

    Args:
        data: Mapping of patient id (a string) to a pair whose first item is
            the sequence of low-resolution images and whose second item is
            the matching high-resolution images. Each image is reshaped to
            ``(512, 512, 1)`` before use.
        n_patients_train: Number of leading patients to skip.
        eval_model: Checkpoint file name appended verbatim to
            ``"models_checkpoints"`` (e.g. ``'/g_srgan.npz'``). The value
            ``'/g_srgan.npz'`` selects the ``srgan_results`` output folder,
            anything else selects ``srresnet_results``.
        save_imgs: When true, the generated, LR and HR images of every
            evaluated sample are written as PNG files below the results
            folder.

    Returns:
        None. Per-patient and overall average MSE are printed.
    """
    ## create folders for checkpoint and results
    checkpoint_dir = "models_checkpoints"
    if eval_model == '/g_srgan.npz':
        results_dir = "srgan_results"
    else:
        results_dir = "srresnet_results"
    tl.files.exists_or_mkdir(results_dir)

    ###========================== RESTORE G =============================###
    t_image = tf.placeholder('float32', [1, 512, 512, 1], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    imgs_evald, total_mse = 0, 0
    # The session is released even if restoring or inference fails.
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + eval_model,
                                     network=net_g)

        ###======================= EVALUATION =============================###
        for counter, (patient, values) in enumerate(data.items()):
            if counter < n_patients_train:
                continue  # patient belongs to the training split

            print("[] Evaluating patient " + patient + " files")
            patient_dir = results_dir + "/" + patient
            tl.files.exists_or_mkdir(patient_dir)
            valid_lr_imgs = values[0]
            valid_hr_imgs = values[1]
            n_imgs = len(valid_lr_imgs)
            patient_mse = 0
            for i in range(n_imgs):
                valid_lr_img = np.asarray(valid_lr_imgs[i]).reshape(
                    (512, 512, 1))
                valid_hr_img = np.asarray(valid_hr_imgs[i]).reshape(
                    (512, 512, 1))
                out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})

                curr_mse = mse(out, valid_hr_img)
                imgs_evald += 1
                patient_mse += curr_mse
                total_mse += curr_mse

                if save_imgs:
                    prefix = patient_dir + "/" + str(patient) + "_" + str(i)
                    tl.vis.save_image(out[0], prefix + '_valid_gen.png')
                    tl.vis.save_image(valid_lr_img, prefix + '_valid_lr.png')
                    tl.vis.save_image(valid_hr_img, prefix + '_valid_hr.png')

                if i % 100 == 0:
                    print("Batch " + str((i / float(100))) + "/" +
                          str(math.ceil(n_imgs / float(100))))

            # A patient without images has no average; avoid dividing by 0.
            if n_imgs:
                print("Average MSE: " + str(patient_mse / n_imgs))
            else:
                print("Average MSE: n/a (no images)")
    finally:
        sess.close()

    # Nothing evaluated (e.g. n_patients_train >= len(data)): no average.
    if imgs_evald == 0:
        print("[*] Evaluation -- no images evaluated")
        return
    print("[*] Evaluation -- total MSE: " + str(total_mse / imgs_evald))
示例#22
0
def evaluate():
    """Super-resolve the validation LR images batch by batch and save them.

    All ``*.png`` files in ``config.VALID.lr_img_path`` are read and grouped
    into consecutive batches of ``config.TRAIN.batch_size`` images. A
    trailing partial batch is skipped because the input placeholder has a
    fixed batch dimension. Every image is expanded to three channels when it
    is 2-D (grayscale) and rescaled to ``[-1, 1]`` before being fed to the
    generator restored from ``checkpoint/g_srgan.npz``.

    For each input ``<name>.png`` three files are written to
    ``samples/<mode>/``: ``<name>_valid_gen.png`` (generator output),
    ``<name>_valid_lr.png`` (network input) and ``<name>_valid_bicubic.png``
    (x4 bicubic upscaling of the input).

    Raises:
        ValueError: from ``np.stack`` when the images of one batch do not
            all have the same shape.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"
    batch_size = config.TRAIN.batch_size

    ###====================== PRE-LOAD DATA ===========================###
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    for im in valid_lr_imgs:
        print(im.shape)

    ###========================== DEFINE MODEL ============================###
    t_image = tf.placeholder('float32', [batch_size, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    # The session is released even if restoring or inference fails.
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

        print("valid lr img list:" + str(valid_lr_img_list))

        # Floor division: only full batches fit the fixed-size placeholder.
        for n in range(len(valid_lr_imgs) // batch_size):
            imid = n * batch_size  # index of the first image of this batch

            batch = []
            for i in range(batch_size):
                img = valid_lr_imgs[imid + i]
                if len(img.shape) == 2:
                    # Grayscale: replicate the single plane into 3 channels.
                    img = np.expand_dims(img, axis=2)
                    img = np.concatenate((img, img, img), axis=2)
                    print("resized: " + str(img.shape))
                batch.append((img / 127.5) - 1)  # rescale to [-1, 1]

            # (batch, H, W, 3); this is fed as-is, not wrapped in a list.
            valid_lr_img = np.stack(batch, axis=0)

            size = valid_lr_img.shape[1:]
            print("size shape: " + str(size))

            ###======================= EVALUATION =============================###
            start_time = time.time()
            out = sess.run(net_g.outputs, {t_image: valid_lr_img})
            print("took: %4.4fs" % (time.time() - start_time))

            print("LR size: %s /  generated HR size: %s" % (size, out.shape))
            print("[*] save images")

            for i in range(batch_size):
                # File name without its 4-character ".png" suffix.
                name = valid_lr_img_list[imid + i][:-4]
                tl.vis.save_image(out[i], save_dir + '/{}_valid_gen.png'.format(name))
                tl.vis.save_image(valid_lr_img[i], save_dir + '/{}_valid_lr.png'.format(name))

                out_bicu = scipy.misc.imresize(valid_lr_img[i], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
                tl.vis.save_image(out_bicu, save_dir + '/{}_valid_bicubic.png'.format(name))
    finally:
        sess.close()
示例#23
0
def conceval(epoch, net_g_test, t_image, sess):
    """Run an intermediate evaluation and append its averages to a log file.

    The first 16 validation LR/HR image pairs (sorted by file name) are
    processed with a generator restored from ``checkpoint/g_srgan.npz``.
    Inference uses its own placeholder, its own ``SRGAN_g`` instance built
    with ``reuse=True`` and its own session, so the caller's session is not
    touched. Each generated image is written to ``samples/intermediate`` as
    ``<name>_gen_<epoch:03d>.png``. MSE, PSNR and SSIM come from the
    project's ``quant`` helper, OCR accuracies from ``ocr.getAccuracy``.
    One summary line is appended to ``inter_eval.txt`` and printed.

    Args:
        epoch: Integer epoch number; printed and zero-padded into the output
            file names.
        net_g_test: Not used; kept so existing callers keep working.
        t_image: Not used; kept so existing callers keep working.
        sess: Caller's session. Not used: evaluation runs in a private
            session so the caller's variable values are left untouched.

    Side effects:
        Sets the module-level flag ``do_ocr`` to ``True``.
    """
    global do_ocr
    print("Intermediate Evaluating epoch...", epoch)
    do_ocr = True
    test_set_size = 16
    tot_psnr = 0
    tot_mse = 0
    tot_ssim = 0
    tot_res_acc = 0
    tot_bic_acc = 0
    ## create folders to save result images
    save_dir = "samples/intermediate"
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"
    if do_ocr:
        print("Evaluating with OCR")

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))[:test_set_size]
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))[:test_set_size]
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                       path=config.VALID.hr_img_path,
                                       n_threads=32)

    num_lr_imgs = len(valid_lr_imgs)
    if num_lr_imgs == 0:
        # Every average below divides by the image count.
        print('No images found')
        return

    ###========================== RESTORE G =============================###
    in_image = tf.placeholder('float32', [1, None, None, 3],
                              name='input_image')
    # reuse=True: relies on the generator variables already existing in the
    # graph (presumably created by the training code; verify against caller).
    net_g_oth = SRGAN_g(in_image, is_train=False, reuse=True)
    ses2 = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    # The private session is released even if an image fails.
    try:
        # Initialise the session that is actually used for inference; the
        # checkpoint values are assigned on top of the initial values.
        tl.layers.initialize_global_variables(ses2)
        tl.files.load_and_assign_npz(sess=ses2,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g_oth)
        print("Loaded model\nProcessing images...")

        ###======================= EVALUATION =============================###
        for imid in range(num_lr_imgs):
            valid_lr_img = valid_lr_imgs[imid]
            valid_hr_img = valid_hr_imgs[imid]
            img_name = valid_lr_img_list[imid]
            size = valid_lr_img.shape
            if len(size) == 2:
                # Grayscale input: replicate the plane into 3 channels.
                valid_lr_img = np.stack((valid_lr_img, ) * 3, axis=-1)
                valid_hr_img = np.stack((valid_hr_img, ) * 3, axis=-1)
                size = valid_lr_img.shape
            valid_lr_img_res = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
            out = ses2.run(net_g_oth.outputs, {in_image: [valid_lr_img_res]})
            out_uint8 = convert(out[0])
            tl.vis.save_image(
                out_uint8, save_dir + '/' + img_name[:-4] + '_gen_' +
                format(epoch, '03d') + '.png')
            # x4 bicubic upscaling of the raw LR image, used as OCR baseline.
            out_bicu = scipy.misc.imresize(valid_lr_img,
                                           [size[0] * 4, size[1] * 4],
                                           interp='bicubic',
                                           mode=None)
            img_mse, img_psnr, img_ssim = quant(out_uint8, valid_hr_img)
            tot_psnr += img_psnr
            tot_mse += img_mse
            tot_ssim += img_ssim
            if do_ocr:
                res_acc, _hr_acc, _lr_acc, bic_acc = ocr.getAccuracy(
                    out_uint8, valid_hr_img, valid_lr_img, out_bicu, imid)
                tot_res_acc += res_acc
                tot_bic_acc += bic_acc
    finally:
        ses2.close()

    hist = ("Average PSNR: " + str(tot_psnr / num_lr_imgs)[:8] +
            "\tAverage MSE: " + str(tot_mse / num_lr_imgs)[:8] +
            "\tAverage SSIM: " + str(tot_ssim / num_lr_imgs)[:8])
    if do_ocr:
        try:
            hist += ("\tAverage Improvement over bicubic: " +
                     str(tot_res_acc / tot_bic_acc)[:8])
        except ZeroDivisionError:
            # Bicubic accuracy summed to zero: report the line without the
            # improvement ratio.
            pass
    hist += '\n'

    with open("inter_eval.txt", 'a') as interlog:
        interlog.write(hist)
    print("\nAll images done\n" + hist)
    return
示例#24
0
def evaluate(mode):
    """Evaluate the saved generator on the validation set and log metrics.

    Three log files are opened for the duration of the run and are always
    closed again, whichever way the evaluation ends:
    ``eval_history.txt`` (append, one summary line per full run),
    ``latest_eval.txt`` (overwritten, one line per image) and
    ``ocrlog.txt`` (append, OCR summary when the module-level ``do_ocr``
    flag is true).

    Args:
        mode: ``'single'`` stops after the first image and writes no
            summary; ``'multi'`` additionally requires the LR and HR
            folders to contain the same number of images. Any other value
            processes every LR image without that check.
    """
    print("Evaluating...")
    with open("eval_history.txt", "a") as history, \
            open("latest_eval.txt", "w") as latest, \
            open("ocrlog.txt", "a") as ocrlog:
        _evaluate_validation_set(mode, history, latest, ocrlog)
    return


def _evaluate_validation_set(mode, history, latest, ocrlog):
    """Score every validation image and write the results to open log files.

    The generator is restored from ``checkpoint/g_srgan.npz``. For each LR
    image the generated, LR, HR and x4 bicubic images are saved to
    ``samples/<mode flag>/`` and a PSNR/MSE/SSIM line (from the project's
    ``quant`` helper) is written to ``latest``.
    """
    tot_psnr = 0
    tot_mse = 0
    tot_ssim = 0
    tot_res_acc = 0
    tot_hr_acc = 0
    tot_lr_acc = 0
    tot_bic_acc = 0
    res_beats_hr = 0
    res_beats_bic = 0
    res_beats_lr = 0
    res_fails = 0
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"
    if do_ocr:
        print("Evaluating with OCR")

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                       path=config.VALID.hr_img_path,
                                       n_threads=32)

    num_lr_imgs = len(valid_lr_imgs)
    num_hr_imgs = len(valid_hr_imgs)
    print("loaded", num_lr_imgs, "LR images")
    if mode == 'multi' and num_lr_imgs != num_hr_imgs:
        print('Unequal images in LR and HR')
        return
    # Checked for every mode: the averages below divide by the image count
    # and each LR image needs an HR partner.
    if num_lr_imgs == 0 or num_hr_imgs == 0:
        print('No images found')
        return

    ###========================== RESTORE G =============================###
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    # The session is released even if an image fails or on early return.
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)
        print("Loaded model\nProcessing images...")

        ###======================= EVALUATION =============================###
        for imid in range(num_lr_imgs):
            valid_lr_img = valid_lr_imgs[imid]
            valid_hr_img = valid_hr_imgs[imid]
            img_name = valid_lr_img_list[imid]
            size = valid_lr_img.shape
            if len(size) == 2:
                # Grayscale input: replicate the plane into 3 channels.
                valid_lr_img = np.stack((valid_lr_img, ) * 3, axis=-1)
                valid_hr_img = np.stack((valid_hr_img, ) * 3, axis=-1)
                size = valid_lr_img.shape
            valid_lr_img_res = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img_res]})
            out_uint8 = convert(out[0])

            stem = save_dir + '/' + img_name[:-4]  # drop ".png"
            tl.vis.save_image(out_uint8, stem + '_gen.png')
            tl.vis.save_image(valid_lr_img, stem + '_lr.png')
            tl.vis.save_image(valid_hr_img, stem + '_hr.png')
            out_bicu = scipy.misc.imresize(valid_lr_img,
                                           [size[0] * 4, size[1] * 4],
                                           interp='bicubic',
                                           mode=None)
            tl.vis.save_image(out_bicu, stem + '_bicubic.png')

            img_mse, img_psnr, img_ssim = quant(out_uint8, valid_hr_img)
            tot_psnr += img_psnr
            tot_mse += img_mse
            tot_ssim += img_ssim
            if do_ocr:
                res_acc, hr_acc, lr_acc, bic_acc = ocr.getAccuracy(
                    out_uint8, valid_hr_img, valid_lr_img, out_bicu, imid)
                tot_res_acc += res_acc
                tot_hr_acc += hr_acc
                tot_lr_acc += lr_acc
                tot_bic_acc += bic_acc
                # Each image lands in exactly one bucket: the first
                # reference (HR, bicubic, LR) that the result beats.
                if res_acc > hr_acc:
                    res_beats_hr += 1
                elif res_acc > bic_acc:
                    res_beats_bic += 1
                elif res_acc > lr_acc:
                    res_beats_lr += 1
                else:
                    res_fails += 1

            eval_log = ("Image: " + str(imid + 1) +
                        "\tPSNR: " + str(img_psnr)[:8] +
                        "\tMSE: " + str(img_mse)[:8] +
                        "\tSSIM: " + str(img_ssim)[:8] + '\n')
            latest.write(eval_log)

            if mode == 'single':
                # Only the first image is wanted; no summary is written.
                print("\n1 image done\n" + eval_log)
                return

            # Text progress bar, 50 cells wide, redrawn on the same line.
            incre = int(50.0 / num_lr_imgs * imid)
            sys.stdout.write('\r' + '|%s%s| %d/%d images done' %
                             ('\033[7m' + ' ' * incre + ' \033[27m', ' ' *
                              (49 - incre), imid + 1, num_lr_imgs))
            sys.stdout.flush()
    finally:
        sess.close()

    hist = ("Average PSNR: " + str(tot_psnr / num_lr_imgs)[:8] +
            "\tAverage MSE: " + str(tot_mse / num_lr_imgs)[:8] +
            "\tAverage SSIM: " + str(tot_ssim / num_lr_imgs)[:8])
    if do_ocr:
        ocrres = ("Average GEN accuracy: " +
                  str(tot_res_acc / num_lr_imgs)[:8] +
                  "\nAverage HRI accuracy: " +
                  str(tot_hr_acc / num_lr_imgs)[:8] +
                  "\nAverage LRI accuracy: " +
                  str(tot_lr_acc / num_lr_imgs)[:8] +
                  "\nAverage BIC accuracy: " +
                  str(tot_bic_acc / num_lr_imgs)[:8] +
                  "\n\nRES>HRI: " + str(res_beats_hr) +
                  "\nRES>BIC: " + str(res_beats_bic) +
                  "\nRES>LRI: " + str(res_beats_lr) +
                  "\nRESFAIL: " + str(res_fails) +
                  '\n' + '=' * 50 + '\n')
        ocrlog.write(ocrres)
        try:
            hist += ("\tAverage Improvement over bicubic: " +
                     str(tot_res_acc / tot_bic_acc)[:8])
        except ZeroDivisionError:
            # Bicubic accuracy summed to zero: report the line without the
            # improvement ratio.
            pass
    hist += '\n'

    history.write(hist)
    print("\nAll images done\n" + hist)
示例#25
0
def evaluate(ID, save_path, lr_path):
    """Super-resolve one grayscale LR image and save the generator output.

    The ``*.png`` files in ``lr_path`` are listed and sorted; only the file
    at position ``ID`` is read (as a single-channel image), rescaled to
    ``[-1, 1]`` and passed through a generator restored from
    ``checkpoint/main_50.ckpt``. The output is written to ``save_path``
    under the same file name as the input.

    Args:
        ID: Index into the sorted list of PNG file names in ``lr_path``.
        save_path: Output directory; created when it does not exist.
        lr_path: Directory containing the low-resolution PNG images.

    Raises:
        IndexError: when ``ID`` is outside the range of available files.
    """
    ## create folders to save result images
    save_dir = save_path
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=lr_path, regx='.*.png', printable=False))
    print("found %d LR images" % len(valid_lr_img_list))

    # Only the requested image is decoded; the rest of the folder is never
    # used by this function.
    imid = ID
    img_name = valid_lr_img_list[imid]
    valid_lr_img = scipy.misc.imread(os.path.join(lr_path, img_name),
                                     mode='L')
    # Add an explicit channel axis: (H, W) -> (H, W, 1).
    valid_lr_img = valid_lr_img.reshape(
        (valid_lr_img.shape[0], valid_lr_img.shape[1], 1))
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    ###========================== DEFINE MODEL ============================###
    size = valid_lr_img.shape
    t_image = tf.placeholder('float32', [1, size[0], size[1], 1],
                             name='input_image')  # 1 for 1 channel

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)

    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_dir + '/main_50.ckpt')

    ###======================= EVALUATION =============================###
    start_time = time.time()
    out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))
    print("[*] save images")
    # The output keeps the input's file name.
    tl.vis.save_image(out[0], save_dir + '/' + img_name)
    '''
示例#26
0
def train():
    """Pre-train the generator, then train generator and discriminator as a GAN.

    The function builds a TF1 graph around a single-channel input placeholder,
    tries to restore weights from ``checkpoint/g_init``, and then runs two
    phases:

    1. Generator initialisation with ``g_init_loss`` (MSE + VGG feature loss)
       for ``n_epoch_init + 1`` epochs.  The phase is skipped when the
       module-level ``do_init_g`` flag is false; in that case a generator
       checkpoint *must* be restorable.
    2. Adversarial training for ``n_epoch + 1`` epochs with a step-decayed
       learning rate.

    Every epoch runs a validation pass first and a training pass second,
    writes scalar summaries to ``logs/train`` and saves a checkpoint under
    ``checkpoint/g_init`` or ``checkpoint/gan``.

    Relies on names defined elsewhere in the module (``config``,
    ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch``, ``decay_every``,
    ``lr_decay``, ``do_init_g``, ``vgg16_npy_path``, ``preprocess``,
    ``SRGAN_g``, ``SRGAN_d``, ``VGG16``, ``SynthiaIterator``,
    ``augment_imgs``, ``get_synthia_imgs_list``).

    Raises:
        Exception: whatever the restore step raises when ``do_init_g`` is
            false and no generator checkpoint can be loaded, and any error
            raised inside an epoch (re-raised after the active batch
            iterator has been stopped).
    """
    n_epoch_init = 12
    ## create folders to save trained models
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    log_dir = "logs"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)
    # The savers below write into these sub-folders; make sure they exist so
    # that saving cannot fail on a missing parent directory.
    tl.files.exists_or_mkdir(checkpoint_dir + '/g_init')
    tl.files.exists_or_mkdir(checkpoint_dir + '/gan')

    ###====================== PRE-LOAD DATA ===========================###

    train_hr_img_list = sorted(
        get_synthia_imgs_list(config.VALID.hr_img_path,
                              is_train=True,
                              synthia_dataset=config.TRAIN.hr_img_path))
    valid_hr_img_list = sorted(
        get_synthia_imgs_list(config.VALID.hr_img_path,
                              is_train=False,
                              synthia_dataset=config.TRAIN.hr_img_path))
    print(len(train_hr_img_list))
    print(len(valid_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_input = tf.placeholder(tf.float32,
                             shape=(None, None, None, 1),
                             name='t_input')
    # try with log?
    # NOTE(review): `t_input` is rebound to the output of tf.log, and that
    # tensor (not the placeholder) is what every feed_dict below uses as key.
    # Fed batches therefore appear to replace the log output instead of
    # passing through it -- confirm this is intended before changing it.
    t_input = tf.log(t_input)

    # Boolean switch fed at run time to select train/inference behaviour.
    d_flg = tf.placeholder(tf.bool, name='is_train')

    t_image, t_target_image, t_interpolated = preprocess(t_input)

    net_g_outputs = SRGAN_g(t_image,
                            t_interpolated,
                            is_train=d_flg,
                            reuse=False)

    net_d, logits_real = SRGAN_d(t_target_image, is_train=d_flg, reuse=False)
    _, logits_fake = SRGAN_d(net_g_outputs, is_train=d_flg, reuse=True)

    vgg_model_true = VGG16(vgg16_npy_path)
    vgg_model_gen = VGG16(vgg16_npy_path)

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    # Min-max normalise target and prediction, then replicate the single
    # channel three times so they can be fed to VGG.
    y_true_normalized = (t_target_image - tf.reduce_min(t_target_image)) / (
        tf.reduce_max(t_target_image) - tf.reduce_min(t_target_image))
    gen_normalized = (net_g_outputs - tf.reduce_min(net_g_outputs)) / (
        tf.reduce_max(net_g_outputs) - tf.reduce_min(net_g_outputs))

    t_target_image_3ch = tf.concat([y_true_normalized] * 3, 3)
    t_predict_image_3ch = tf.concat([gen_normalized] * 3, 3)

    vgg_model_true.build(t_target_image_3ch)
    true_features = vgg_model_true.conv3_1
    vgg_model_gen.build(t_predict_image_3ch)
    gen_features = vgg_model_gen.conv3_1

    ## test inference (kept so the graph matches previously built graphs)
    net_g_test = SRGAN_g(t_image, t_interpolated, is_train=d_flg, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-2 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g_outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        true_features, gen_features, is_mean=True)
    # Total-variation term (only reported as a summary).  The weight must
    # scale BOTH the vertical and the horizontal difference; without the
    # parentheses only the first term was weighted.
    tv_loss = 2e-6 * (
        tf.reduce_mean(
            tf.square(net_g_outputs[:, :-1, :, :] -
                      net_g_outputs[:, 1:, :, :])) +
        tf.reduce_mean(
            tf.square(net_g_outputs[:, :, :-1, :] -
                      net_g_outputs[:, :, 1:, :])))

    g_init_loss = mse_loss + vgg_loss
    g_loss = g_gan_loss + mse_loss + vgg_loss

    g_vars = tl.layers.get_variables_with_name('G_Depth_SR', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    glob_step_t = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')

    # The optimizers are deliberately created inside this scope: their slot
    # variables inherit the scope name, which existing checkpoints rely on.
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
        ## Pretrain
        g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
            g_init_loss, var_list=g_vars, global_step=glob_step_t)
        ## SRGAN
        g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
            g_loss, var_list=g_vars, global_step=glob_step_t)
        d_optim = tf.train.AdamOptimizer(lr_v,
                                         beta1=beta1).minimize(d_loss,
                                                               var_list=d_vars)

    ###========================== RESTORE MODEL =============================###

    saver = tf.train.Saver(max_to_keep=5)
    saver_d = tf.train.Saver(d_vars, max_to_keep=5)
    saver_g = tf.train.Saver(g_vars, max_to_keep=5)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)

    with tf.variable_scope('summaries'):
        tf.summary.scalar('d_loss', d_loss)
        tf.summary.scalar('g_loss', g_loss)
        tf.summary.scalar('mse_loss', mse_loss)
        tf.summary.scalar('vgg_loss', vgg_loss)
        tf.summary.scalar('tv_loss', tv_loss)
        tf.summary.scalar('g_gan_loss', g_gan_loss)
        # Reported as 'MAE' but computed relative to the target value.
        mae = tf.reduce_mean(
            tf.abs(net_g_outputs - t_target_image) /
            (t_target_image + tf.constant(1e-8)))
        rmse = tf.sqrt(
            tf.reduce_mean(tf.square(net_g_outputs - t_target_image)))
        tf.summary.scalar('MAE', mae)
        tf.summary.scalar('RMSE', rmse)
        tf.summary.scalar('learning_rate', lr_v)
        tf.summary.image('GT', t_target_image, max_outputs=1)
        tf.summary.image('input_small_size', t_image, max_outputs=1)
        tf.summary.image('interpolated', t_interpolated, max_outputs=1)
        tf.summary.image('result', net_g_outputs, max_outputs=1)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(log_dir + '/test')

    ###============================= TRAINING ===============================###
    n_batches = int(len(train_hr_img_list) / batch_size)
    n_batches_valid = int(len(valid_hr_img_list) / batch_size)

    ###========================= initialize G ====================###

    if not do_init_g:
        # range(0, n_epoch_init + 1) becomes empty, so pre-training is skipped.
        n_epoch_init = -1
        try:
            saver_g.restore(
                sess, tf.train.latest_checkpoint(checkpoint_dir + '/g_init'))
        except Exception:
            print(
                ' ** You need to initialize generator: put do_init_g to True or provide a valid restore path'
            )
            raise

    else:
        try:
            #saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir+'/gan')) # 2 round
            saver.restore(
                sess, tf.train.latest_checkpoint(checkpoint_dir + '/g_init'))
        except Exception:
            # Best effort: no usable checkpoint means training from scratch.
            print(' ** Creating new g_init model')

    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)

    train_iter = 0
    for epoch in range(0, n_epoch_init + 1):
        # Bound before the try block so the handler never hits an unbound name
        # when building the iterator itself fails.
        batch_it = None
        try:
            epoch_time = time.time()

            # Validation runs first, i.e. it scores the weights produced by
            # the previous epoch (hence `epoch - 1` in the message).
            val_mae, val_mse, val_g_loss = 0, 0, 0
            batch_it = tqdm(SynthiaIterator(valid_hr_img_list,
                                            batchsize=batch_size,
                                            shuffle=True,
                                            buffer_size=70),
                            total=n_batches_valid,
                            leave=False)
            for b in batch_it:
                xb = b[0]
                errM, errG, mae_score = sess.run([mse_loss, g_loss, mae],
                                                 feed_dict={
                                                     t_input: xb,
                                                     d_flg: False
                                                 })
                val_mae += mae_score
                val_mse += errM
                val_g_loss += errG

            print("Validation: Epoch {0} val mae {1} val mse {2}".format(
                epoch - 1, val_mae / n_batches_valid,
                val_mse / n_batches_valid))

            total_mse_loss, total_g_loss = 0, 0
            batch_it = tqdm(SynthiaIterator(train_hr_img_list,
                                            batchsize=batch_size,
                                            shuffle=True,
                                            buffer_size=70),
                            total=n_batches,
                            leave=False)
            for b in batch_it:
                xb = b[0]
                xb = augment_imgs(xb)
                glob_step, errM, errG, _ = sess.run(
                    [glob_step_t, mse_loss, g_loss, g_optim_init],
                    feed_dict={
                        t_input: xb,
                        d_flg: True
                    })

                total_mse_loss += errM
                total_g_loss += errG
                # Full (image) summaries are expensive: only every 200 steps.
                if (train_iter + 1) % 200 == 0:
                    summary = sess.run(summary_op,
                                       feed_dict={
                                           t_input: xb,
                                           d_flg: False
                                       })
                    train_writer.add_summary(summary, train_iter + 1)

                train_iter += 1

            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
                epoch, n_epoch_init, time.time() - epoch_time,
                total_mse_loss / n_batches)

            val_mse_summary = tf.Summary.Value(tag='g_init/val_mse_loss',
                                               simple_value=val_mse /
                                               n_batches_valid)
            val_g_loss_summary = tf.Summary.Value(tag='g_init/val_loss',
                                                  simple_value=val_g_loss /
                                                  n_batches_valid)

            train_mse_loss_summary = tf.Summary.Value(
                tag='g_init/train_mse_loss',
                simple_value=total_mse_loss / n_batches)
            train_g_loss_summary = tf.Summary.Value(tag='g_init/train_loss',
                                                    simple_value=total_g_loss /
                                                    n_batches)

            epoch_summary = tf.Summary(value=[
                val_mse_summary, val_g_loss_summary, train_mse_loss_summary,
                train_g_loss_summary
            ])

            train_writer.add_summary(epoch_summary, glob_step)

            print(log)
            saver.save(
                sess,
                os.path.join(checkpoint_dir + '/g_init',
                             'model' + str(epoch) + '.ckpt'))

        except Exception:
            # Stop the iterator that was active when the error occurred, then
            # let the original exception propagate.
            if batch_it is not None:
                batch_it.iterable.stop()
            raise

    ###========================= train GAN (SRGAN) =========================###
    # To resume an interrupted GAN run, restore here, e.g.:
    # saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir+'/gan'))

    train_iter = 0
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        batch_it = None
        try:
            epoch_time = time.time()

            val_mae, val_mse, val_g_loss, val_d_loss = 0, 0, 0, 0
            batch_it = tqdm(SynthiaIterator(valid_hr_img_list,
                                            batchsize=batch_size,
                                            shuffle=True,
                                            buffer_size=70),
                            total=n_batches_valid,
                            leave=False)
            for b in batch_it:
                xb = b[0]
                errM, mae_score, errG, errD = sess.run(
                    [mse_loss, mae, g_loss, d_loss],
                    feed_dict={
                        t_input: xb,
                        d_flg: False
                    })
                val_mae += mae_score
                val_mse += errM
                val_g_loss += errG
                val_d_loss += errD

            print("Validation (GAN): Epoch {0} val mae {1} val mse {2}".format(
                epoch - 1, val_mae / n_batches_valid,
                val_mse / n_batches_valid))

            total_d_loss, total_g_loss, total_mse_loss = 0, 0, 0
            batch_it = tqdm(SynthiaIterator(train_hr_img_list,
                                            batchsize=batch_size,
                                            shuffle=True,
                                            buffer_size=70),
                            total=n_batches,
                            leave=False)
            for b in batch_it:
                xb = b[0]
                xb = augment_imgs(xb)
                ## update D
                errD, _ = sess.run([d_loss, d_optim], {
                    t_input: xb,
                    d_flg: True
                })
                ## update G
                glob_step, errG, errM, _, summary = sess.run(
                    [glob_step_t, g_loss, mse_loss, g_optim, summary_op], {
                        t_input: xb,
                        d_flg: True
                    })
                total_mse_loss += errM
                total_d_loss += errD
                total_g_loss += errG
                if (train_iter + 1) % 10 == 0:
                    train_writer.add_summary(summary, train_iter + 1)

                train_iter += 1

        except Exception:
            if batch_it is not None:
                batch_it.iterable.stop()
            raise

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f mse_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_batches,
            total_g_loss / n_batches, total_mse_loss / n_batches)

        val_mse_summary = tf.Summary.Value(tag='gan/val_mse_loss',
                                           simple_value=val_mse /
                                           n_batches_valid)
        val_g_loss_summary = tf.Summary.Value(tag='gan/val_g_loss',
                                              simple_value=val_g_loss /
                                              n_batches_valid)
        val_d_loss_summary = tf.Summary.Value(tag='gan/val_d_loss',
                                              simple_value=val_d_loss /
                                              n_batches_valid)

        train_mse_loss_summary = tf.Summary.Value(tag='gan/train_mse_loss',
                                                  simple_value=total_mse_loss /
                                                  n_batches)
        train_g_loss_summary = tf.Summary.Value(tag='gan/train_g_loss',
                                                simple_value=total_g_loss /
                                                n_batches)
        train_d_loss_summary = tf.Summary.Value(tag='gan/train_d_loss',
                                                simple_value=total_d_loss /
                                                n_batches)

        epoch_summary = tf.Summary(value=[
            val_mse_summary, val_g_loss_summary, val_d_loss_summary,
            train_mse_loss_summary, train_g_loss_summary, train_d_loss_summary
        ])

        train_writer.add_summary(epoch_summary, glob_step)

        print(log)
        # Checkpoint numbering continues after the pre-training epochs.
        saver.save(
            sess,
            os.path.join(checkpoint_dir + '/gan',
                         'model' + str(n_epoch_init + epoch) + '.ckpt'))
示例#27
0
文件: main.py 项目: TowardSun/srgan
def train():
    """Train SRGAN: MSE pre-training of the generator followed by GAN training.

    Steps performed:

    * create the sample/checkpoint folders and pre-load all HR training images;
    * build generator, discriminator and the VGG19 feature extractor on
      fixed-size placeholders (``batch_size`` x 96 x 96 x 3 input,
      ``batch_size`` x 384 x 384 x 3 target);
    * restore generator/discriminator weights from ``checkpoint/*.npz`` when
      present and load the VGG19 weights from ``vgg19.npy``;
    * run ``n_epoch_init + 1`` epochs minimising the MSE loss only, then
      ``n_epoch + 1`` adversarial epochs with a step-decayed learning rate;
    * every 10th epoch (except epoch 0) save a sample image grid and the
      network weights.

    Relies on names defined elsewhere in the module (``config``,
    ``batch_size``, ``ni``, ``lr_init``, ``beta1``, ``n_epoch_init``,
    ``n_epoch``, ``decay_every``, ``lr_decay``, ``SRGAN_g``, ``SRGAN_d``,
    ``Vgg19_simple_api``, ``crop_sub_imgs_fn``, ``downsample_fn``).

    Raises:
        ValueError: if fewer than ``batch_size`` training images were loaded,
            because the placeholders require full batches.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)

    # The placeholders below have a fixed leading dimension, so at least one
    # complete batch is required; fail early with a clear message.
    if len(train_hr_imgs) < batch_size:
        raise ValueError("Need at least batch_size=%d training images, found %d" % (batch_size, len(train_hr_imgs)))
    # Exclusive upper bound for batch start indices: only full batches are
    # used, a trailing partial batch cannot be fed to the placeholders.
    last_batch_stop = len(train_hr_imgs) - batch_size + 1

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0,
        align_corners=False)  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    # Inputs are shifted by +1 and halved before being passed to VGG.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Prefer GAN-stage generator weights; fall back to the pre-training ones.
    if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    # The file stores a pickled dict; recent NumPy versions refuse to load
    # pickled objects unless allow_pickle is set explicitly.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    # Sorted by key so the parameter order is deterministic; each value holds
    # the weights at index 0 and the biases at index 1.
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, last_batch_stop, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, last_batch_stop, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                  (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
                                                                                total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), sess=sess)
            tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), sess=sess)
示例#28
0
def train(train_lr_imgs, train_hr_imgs):
    """Train the single-channel SRGAN on pre-loaded LR/HR image pairs.

    The function builds the generator, the discriminator and the VGG19
    feature network on ``(batch_size, 512, 512, 1)`` placeholders, restores
    any ``.npz`` checkpoints found in ``models_checkpoints``, loads the VGG19
    weights from ``vgg19.npy`` and then trains in two phases:

    1. ``n_epoch_init`` epochs that update only the generator on the pixel
       MSE loss.
    2. ``n_epoch`` epochs that perform, per batch, one discriminator update
       followed by one generator update on ``mse + vgg + adversarial`` loss.

    Checkpoints are written whenever the sample offset ``idx`` is a multiple
    of 1000 and at the end of every epoch. After training, the collected
    losses are passed to ``plot_total_losses`` / ``plot_iterative_losses``
    and saved as ``.npy`` files next to the checkpoints.

    Hyper-parameters (``batch_size``, ``lr_init``, ``beta1``, ``lr_decay``,
    ``decay_every``, ``n_epoch_init``, ``n_epoch``) are read from
    module-level names.

    Only complete batches are used: the placeholders and the reshape have a
    fixed leading dimension of ``batch_size``, so a trailing group of fewer
    than ``batch_size`` images is skipped instead of crashing the reshape.

    Args:
        train_lr_imgs: Indexable sequence of generator inputs, aligned by
            index with ``train_hr_imgs``. Each batch must be reshapeable to
            ``(batch_size, 512, 512, 1)``.
        train_hr_imgs: Indexable sequence of target images with the same
            layout.
    """
    ## create folders to save result images and trained model
    checkpoint_dir = "models_checkpoints"
    tl.files.exists_or_mkdir(checkpoint_dir)

    # Fixed feed shape shared by both placeholders and every batch reshape.
    batch_shape = (batch_size, 512, 512, 1)
    # Start offsets of the complete batches only; a partial final slice could
    # not be reshaped to `batch_shape`.
    batch_starts = range(0, len(train_hr_imgs) - batch_size + 1, batch_size)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder(dtype='float32',
                             shape=batch_shape,
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(dtype='float32',
                                    shape=batch_shape,
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    # Second discriminator call shares the weights created above.
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    ## vgg inference. method ids: 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)
    # (x + 1) / 2 shifts the images before they enter VGG
    # (maps [-1, 1] to [0, 1]; assumes the data is scaled to [-1, 1]).
    net_vgg, vgg_target_emb = Vgg19_simple_api(
        input=(t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api(
        input=(t_predict_image_224 + 1) / 2, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator: real images labelled 1, generated images labelled 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: weighted adversarial term + pixel MSE + VGG feature MSE.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        # Non-trainable so it can be re-assigned between epochs.
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    g_ckpt = checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode'])
    g_init_ckpt = checkpoint_dir + '/g_{}_init.npz'.format(
        tl.global_flag['mode'])
    d_ckpt = checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode'])

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Prefer the GAN-stage generator; fall back to the pre-training one when
    # the former is reported as missing (return value False).
    if tl.files.load_and_assign_npz(sess=sess, name=g_ckpt,
                                    network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess,
                                     name=g_init_ckpt,
                                     network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=d_ckpt, network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    # The file holds a pickled dict, which NumPy >= 1.16.3 refuses to load
    # unless allow_pickle=True. NOTE: unpickling executes code from the file,
    # so only load a vgg19.npy obtained from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        if val[0] == 'conv1_1':
            # Average the first layer's kernels over axis 2 and restore that
            # axis with size 1, so the layer accepts the 1-channel images fed
            # here (presumably the stored kernels are RGB — TODO confirm).
            W = np.mean(W, axis=2)
            W = W.reshape((3, 3, 1, 64))
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])

    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    start_time = time.time()
    for epoch in range(0, n_epoch_init):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        step_time = None
        for idx in batch_starts:
            # Timing, logging and checkpointing happen only on offsets that
            # are multiples of 1000.
            if idx % 1000 == 0:
                step_time = time.time()
            b_imgs_hr = np.asarray(
                train_hr_imgs[idx:idx + batch_size]).reshape(batch_shape)
            b_imgs_lr = np.asarray(
                train_lr_imgs[idx:idx + batch_size]).reshape(batch_shape)

            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_lr,
                t_target_image: b_imgs_hr
            })

            if idx % 1000 == 0:
                print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                      (epoch, n_epoch_init, n_iter, time.time() - step_time,
                       errM))
                tl.files.save_npz(net_g.all_params,
                                  name=g_init_ckpt,
                                  sess=sess)

            total_mse_loss += errM
            n_iter += 1

        # max(..., 1) avoids a ZeroDivisionError when no batch was run.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / max(n_iter, 1))
        print(log)

        ## save model
        tl.files.save_npz(net_g.all_params, name=g_init_ckpt, sess=sess)
    print("G init took: %4.4fs" % (time.time() - start_time))

    ###========================= train GAN (SRGAN) =========================###
    start_time = time.time()
    epoch_losses = defaultdict(list)  # per-epoch loss sums
    iter_losses = defaultdict(list)  # per-iteration loss values

    for epoch in range(0, n_epoch):
        ## update learning rate
        if epoch != 0 and decay_every != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        step_time = None
        for idx in batch_starts:
            if idx % 1000 == 0:
                step_time = time.time()
            b_imgs_hr = np.asarray(
                train_hr_imgs[idx:idx + batch_size]).reshape(batch_shape)
            b_imgs_lr = np.asarray(
                train_lr_imgs[idx:idx + batch_size]).reshape(batch_shape)
            feed = {t_image: b_imgs_lr, t_target_image: b_imgs_hr}

            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)

            if idx % 1000 == 0:
                print(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                    % (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                       errG, errM, errV, errA))
                tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
                tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)

            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

            iter_losses['d_loss'].append(errD)
            iter_losses['g_loss'].append(errG)
            iter_losses['mse_loss'].append(errM)
            iter_losses['vgg_loss'].append(errV)
            iter_losses['adv_loss'].append(errA)

        # max(..., 1) avoids a ZeroDivisionError when no batch was run.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time,
            total_d_loss / max(n_iter, 1), total_g_loss / max(n_iter, 1))
        print(log)
        # The stored epoch values are the sums over the epoch, not the means.
        epoch_losses['d_loss'].append(total_d_loss)
        epoch_losses['g_loss'].append(total_g_loss)

        ## save model
        tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
        tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)
    print("G train took: %4.4fs" % (time.time() - start_time))

    ## create visualizations for losses from training
    plot_total_losses(epoch_losses)
    plot_iterative_losses(iter_losses)
    for loss, values in epoch_losses.items():
        np.save(checkpoint_dir + "/epoch_" + loss + '.npy', np.asarray(values))
    for loss, values in iter_losses.items():
        np.save(checkpoint_dir + "/iter_" + loss + '.npy', np.asarray(values))
    print("[*] saved losses")
示例#29
0
def evaluate():
    """Super-resolve up to ten validation LR images with the saved generator.

    Reads every ``.png`` listed under ``config.VALID.lr_img_path``, converts
    the images that will be processed from grayscale to 3-channel RGB,
    restores the generator weights from ``checkpoint/g_srgan.npz`` and, for
    each processed image, saves the generator output and the normalised input
    into ``samples_btcv2/<mode>`` where ``<mode>`` is
    ``tl.global_flag['mode']``. The graph is also written to ``logs/test/``.

    At most ten images are processed; when fewer validation images are
    available, only those are evaluated.
    """
    max_images = 10

    ## create folders to save result images
    save_dir = "samples_btcv2/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)

    # Never index past the images that were actually loaded.
    n_eval = min(max_images, len(valid_lr_imgs))
    for ii in range(n_eval):
        # The generator placeholder below expects 3 channels.
        valid_lr_imgs[ii] = cv2.cvtColor(valid_lr_imgs[ii], cv2.COLOR_GRAY2RGB)

    ###========================== DEFINE MODEL ============================###
    # Height and width are left unspecified so images of any size can be fed
    # one at a time.
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    test_writer = None
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        # Creating the writer with the graph stores the graph definition in
        # the log directory.
        test_writer = tf.summary.FileWriter('logs/test/', sess.graph)

        ###======================= EVALUATION =============================###
        for ii in range(n_eval):
            # x / 127.5 - 1 maps 0 to -1 and 255 to 1.
            valid_lr_img = (valid_lr_imgs[ii] / 127.5) - 1
            start_time = time.time()
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
            print("took: %4.4fs" % (time.time() - start_time))

            print("[*] save images")
            tl.vis.save_image(out[0], save_dir + '/valid_gen_%d.png' % ii)
            tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr_%d.png' % ii)
    finally:
        # Flush the event file and release the session's resources even when
        # evaluation fails part-way.
        if test_writer is not None:
            test_writer.close()
        sess.close()
示例#30
0
def train():
    """Train SRGAN on RGB crops: MSE pre-training of G, then GAN training.

    Steps, in order:

    * load every ``.png`` under ``config.TRAIN.hr_img_path`` into memory;
    * build the generator, discriminator and VGG19 feature graphs on
      ``[batch_size, 96, 96, 3]`` inputs and ``[batch_size, 384, 384, 3]``
      targets;
    * restore generator/discriminator ``.npz`` checkpoints from
      ``checkpoint`` (falling back to the pre-training generator checkpoint)
      and load the VGG19 weights from ``vgg19.npy``;
    * run ``n_epoch_init + 1`` epochs that update only the generator on the
      pixel MSE loss;
    * run ``n_epoch + 1`` epochs that perform, per batch, one discriminator
      update followed by one generator update on
      ``mse + vgg + adversarial`` loss.

    Every 10th epoch (except epoch 0) a sample grid generated from the first
    ``batch_size`` training images is saved under ``samples/<mode>_ginit`` or
    ``samples/<mode>_gan`` and the networks are checkpointed; ``<mode>`` is
    ``tl.global_flag['mode']``.

    Hyper-parameters (``batch_size``, ``ni``, ``lr_init``, ``beta1``,
    ``lr_decay``, ``decay_every``, ``n_epoch_init``, ``n_epoch``) are read
    from module-level names.

    Only complete batches are used because the placeholders have a fixed
    leading dimension of ``batch_size``; a trailing group of fewer images is
    skipped.

    Raises:
        ValueError: If fewer than ``batch_size`` training images were loaded.
        SystemExit: Via ``exit()`` when ``vgg19.npy`` is not found.
    """
    ## create folders to save result images and trained model
    mode = tl.global_flag['mode']
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    g_ckpt = checkpoint_dir + '/g_{}.npz'.format(mode)
    g_init_ckpt = checkpoint_dir + '/g_{}_init.npz'.format(mode)
    d_ckpt = checkpoint_dir + '/d_{}.npz'.format(mode)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)

    # Both the fixed-size placeholders and the sample batch need at least one
    # full batch; fail early with a clear message instead of a feed-shape
    # error in the middle of training.
    if len(train_hr_imgs) < batch_size:
        raise ValueError(
            "need at least batch_size=%d training images, found %d" %
            (batch_size, len(train_hr_imgs)))
    # Start offsets of the complete batches only.
    batch_starts = range(0, len(train_hr_imgs) - batch_size + 1, batch_size)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3],
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    # Second discriminator call shares the weights created above.
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. method ids: 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)

    # (x + 1) / 2 shifts the images before they enter VGG
    # (maps [-1, 1] to [0, 1]; assumes the crops are scaled to [-1, 1]).
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference (shares the generator weights, is_train=False)
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator: real images labelled 1, generated images labelled 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: weighted adversarial term + pixel MSE + VGG feature MSE.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        # Non-trainable so it can be re-assigned between epochs.
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # load_and_assign_npz reports a missing file by returning False, so a
    # plain `is None` test never triggers the fallback to the pre-training
    # checkpoint. None is still accepted to keep the previous behaviour.
    restored_g = tl.files.load_and_assign_npz(sess=sess,
                                              name=g_ckpt,
                                              network=net_g)
    if restored_g is False or restored_g is None:
        tl.files.load_and_assign_npz(sess=sess,
                                     name=g_init_ckpt,
                                     network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=d_ckpt, network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    # The file holds a pickled dict, which NumPy >= 1.16.3 refuses to load
    # unless allow_pickle=True. NOTE: unpickling executes code from the file,
    # so only load a vgg19.npy obtained from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    # Store the reference grids next to the outputs of both training phases.
    for sample_dir in (save_dir_ginit, save_dir_gan):
        tl.vis.save_images(sample_imgs_96, [ni, ni],
                           sample_dir + '/_train_sample_96.png')
        tl.vis.save_images(sample_imgs_384, [ni, ni],
                           sample_dir + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        for idx in batch_starts:
            step_time = time.time()
            # Random HR crops, then the matching LR inputs derived from them.
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size],
                fn=crop_sub_imgs_fn,
                is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / n_iter)
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train set
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_ginit + '/train_%d.png' % epoch)

            ## save model
            tl.files.save_npz(net_g.all_params, name=g_init_ckpt, sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        # decay_every == 0 means no decay; the guard also avoids a modulo by
        # zero.
        if epoch != 0 and decay_every != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        for idx in batch_starts:
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size],
                fn=crop_sub_imgs_fn,
                is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}
            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG,
                   errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train set
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_gan + '/train_%d.png' % epoch)

            ## save model
            tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
            tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)
示例#31
0
def train():
    """Resume adversarial (SRGAN) training of the single-channel model.

    The function builds the generator / discriminator / VGG19 graph for
    1-channel images (low-res placeholder ``[batch_size, 28, 224, 1]``,
    high-res placeholder ``[batch_size, 224, 224, 1]``), loads the VGG19
    weights from ``vgg19.npy``, restores ``checkpoint/main_10.ckpt`` and
    then trains for ``n_epoch`` further epochs, alternating one
    discriminator step and one generator step per batch.

    The generator MSE pre-training stage is disabled: its code is kept in
    a bare string literal below and needs ``g_optim_init`` (currently
    commented out) to be re-enabled before it can run.

    Module-level names read here: ``config``, ``batch_size``, ``lr_init``,
    ``beta1``, ``n_epoch``, ``lr_decay``, ``decay_every``, ``SRGAN_g``,
    ``SRGAN_d``, ``Vgg19_simple_api``, ``crop_sub_imgs_fn`` and
    ``downsample_fn_mod``.

    Side effects:
        * creates ``samples/<mode>_ginit``, ``samples/<mode>_gan`` and
          ``checkpoint``;
        * writes per-epoch mean losses and the learning rate to event
          files under ``./log/train/``;
        * every 10th epoch saves ``checkpoint/main_<epoch>.ckpt`` and the
          generator output for the first ``batch_size`` training images.

    Raises:
        ValueError: if fewer than ``batch_size`` training images are found
            (the placeholders have a fixed batch dimension).
        SystemExit: with status 1 if ``vgg19.npy`` is missing.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    # Only train_hr_img_list is consumed below; the other three listings are
    # kept so the directories configured in `config` are still scanned as
    # before.
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.

    print("reading images")
    train_hr_imgs = []

    for img__ in train_hr_img_list:
        # mode='L' is passed to imread; the explicit channel axis added below
        # turns each (H, W) array into (H, W, 1) to match the 1-channel
        # placeholders.
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path,
                                                      img__),
                                         mode='L')
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))

        train_hr_imgs.append(image_loaded)

    print(type(train_hr_imgs), len(train_hr_img_list))

    # The placeholders below have a fixed batch dimension, so neither the
    # sample batch nor any training batch can be fed with fewer images.
    if len(train_hr_imgs) < batch_size:
        raise ValueError(
            "need at least batch_size=%d training images in %s, found %d" %
            (batch_size, config.TRAIN.hr_img_path, len(train_hr_imgs)))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    # NOTE(review): the statement order in this section is kept exactly as in
    # the original; re-ordering graph construction may change the variable
    # names stored in existing checkpoints -- confirm before moving anything.

    # The original RGB code used [batch_size, 96, 96, 3] and
    # [batch_size, 384, 384, 3]; variables below still carry the _96 / _384
    # suffixes although the shapes are now the ones declared here.
    t_image = tf.placeholder('float32', [batch_size, 28, 224, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(
        'float32', [batch_size, 224, 224, 1], name='t_target_image'
    )  # may have to convert 224x224x1 into 224x224x3, with channel 1 & 2 as 0. May have to have separate place-holder ?

    # Print the static shapes (tf.shape() would only print a symbolic tensor).
    print("t_image:", t_image.get_shape())
    print("t_target_image:", t_target_image.get_shape())

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    print("net_g.outputs:", net_g.outputs.get_shape())

    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)

    print("net_g.outputs:", net_g.outputs.get_shape())
    print("t_predict_image_224:", t_predict_image_224.get_shape())

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    # The per-batch values are read back below under these names with the
    # spaces replaced by underscores (see the dictionary keys in the batch
    # loop), so the summary names must not be changed independently.
    d_loss1_summary = tf.summary.scalar('Disciminator logits_real loss',
                                        d_loss1)
    d_loss2_summary = tf.summary.scalar('Disciminator logits_fake loss',
                                        d_loss2)
    d_loss_summary = tf.summary.scalar('Disciminator total loss', d_loss)

    g_gan_loss_summary = tf.summary.scalar('Generator GAN loss', g_gan_loss)
    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    #	UNCOMMENT THE LINE BELOW!!!
    #g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # The .npz restore of g/d weights used by the original script is disabled
    # here; weights come from the TF checkpoint restored before the GAN stage.

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg"
        )
        # Exit with a failure status; the bare exit() helper is only
        # installed by the `site` module and reports success.
        raise SystemExit(1)
    # The file holds a pickled dict, which NumPy >= 1.16.3 refuses to load
    # unless allow_pickle=True. Unpickling can execute arbitrary code: only
    # use a vgg19.npy obtained from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set

    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)

    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
          sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn_mod)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())

    # Disabled stage 1 (generator MSE pre-training). Kept as a bare string so
    # it can be restored by hand; it references `g_optim_init`, which must be
    # uncommented above before this code can run.
    '''
    ###========================= initialize G ====================###

    merged_summary_initial_G = tf.summary.merge([mse_loss_summary])
    summary_intial_G_writer = tf.summary.FileWriter("./log/train/initial_G")



    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    count = 0
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.


        intial_MSE_G_summary_per_epoch = []


        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn_mod)
            ## update G
            errM, _, mse_summary_initial_G = sess.run([mse_loss, g_optim_init, merged_summary_initial_G], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))


            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(mse_summary_initial_G)

            intial_G_summaries = {}
            for val in summary_pb.value:
            # Assuming all summaries are scalars.
                intial_G_summaries[val.tag] = val.simple_value
            #print("intial_G_summaries:", intial_G_summaries)


            intial_MSE_G_summary_per_epoch.append(intial_G_summaries['Generator_MSE_loss'])


            #summary_intial_G_writer.add_summary(mse_summary_initial_G, (count + 1)) #(epoch + 1)*(n_iter+1))
            #count += 1


            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)


        summary_intial_G_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag="Generator_Initial_MSE_loss per epoch", simple_value=np.mean(intial_MSE_G_summary_per_epoch)),]), (epoch))


        ## quick evaluation on train set
        #if (epoch != 0) and (epoch % 10 == 0):
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
        print("[*] save images")
        for im in range(len(out)):
            if(im%4==0 or im==1197):
                tl.vis.save_image(out[im], save_dir_ginit + '/train_%d_%d.png' % (epoch,im))

        ## save model
        saver=tf.train.Saver()
        if (epoch%10==0 and epoch!=0):
            saver.save(sess, 'checkpoint/init_'+str(epoch)+'.ckpt')

   #if (epoch != 0) and (epoch % 10 == 0):
        #tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_{}_init.npz'.format(tl.global_flag['mode'], epoch), sess=sess)
    '''
    ###========================= train GAN (SRGAN) =========================###
    # Single source of truth for the resume point: the checkpoint name, the
    # first epoch number and the epoch total shown in the logs all derive
    # from it.
    resume_epoch = 10
    first_epoch = resume_epoch + 1
    tot_epoch = n_epoch + resume_epoch

    saver = tf.train.Saver()
    saver.restore(sess, '%s/main_%d.ckpt' % (checkpoint_dir, resume_epoch))
    print('Restored main_%d, begin %d/%d' %
          (resume_epoch, first_epoch, tot_epoch))

    merged_summary_discriminator = tf.summary.merge(
        [d_loss1_summary, d_loss2_summary, d_loss_summary])
    summary_discriminator_writer = tf.summary.FileWriter(
        "./log/train/discriminator")

    merged_summary_generator = tf.summary.merge([
        g_gan_loss_summary, mse_loss_summary, vgg_loss_summary, g_loss_summary
    ])
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")

    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    def _scalar_values(serialized_summary):
        # Decode a serialized merged summary into {tag: simple_value}.
        # Assumes every entry is a scalar summary.
        summary_pb = tf.summary.Summary()
        summary_pb.ParseFromString(serialized_summary)
        return {val.tag: val.simple_value for val in summary_pb.value}

    for epoch in range(first_epoch, tot_epoch + 1):
        ## update learning rate
        # No "epoch == 0" initialisation is needed: the loop starts at
        # first_epoch >= 1. NOTE(review): lr_v is presumably restored together
        # with the other variables by saver.restore() above -- confirm.
        if epoch % decay_every == 0:
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=(lr_init * new_lr_decay)),
                ]), epoch)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.

        # Per-batch loss values collected for the per-epoch means.
        d_loss1_summary_per_epoch = []
        d_loss2_summary_per_epoch = []
        d_loss_summary_per_epoch = []

        g_gan_loss_summary_per_epoch = []
        mse_loss_summary_per_epoch = []
        vgg_loss_summary_per_epoch = []
        g_loss_summary_per_epoch = []

        # Full batches only: a trailing partial batch cannot be fed into the
        # fixed-size placeholders. The length check at the top of the function
        # guarantees at least one iteration, so n_iter is never 0 below.
        for idx in range(0, len(train_hr_imgs) - batch_size + 1, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384,
                                                 fn=downsample_fn_mod)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}

            ## update D
            errD, _, discriminator_summary = sess.run(
                [d_loss, d_optim, merged_summary_discriminator], feed)

            discriminator_summaries = _scalar_values(discriminator_summary)
            d_loss1_summary_per_epoch.append(
                discriminator_summaries['Disciminator_logits_real_loss'])
            d_loss2_summary_per_epoch.append(
                discriminator_summaries['Disciminator_logits_fake_loss'])
            d_loss_summary_per_epoch.append(
                discriminator_summaries['Disciminator_total_loss'])

            ## update G
            errG, errM, errV, errA, _, generator_summary = sess.run([
                g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim,
                merged_summary_generator
            ], feed)

            generator_summaries = _scalar_values(generator_summary)
            g_gan_loss_summary_per_epoch.append(
                generator_summaries['Generator_GAN_loss'])
            mse_loss_summary_per_epoch.append(
                generator_summaries['Generator_MSE_loss'])
            vgg_loss_summary_per_epoch.append(
                generator_summaries['Generator_VGG_loss'])
            g_loss_summary_per_epoch.append(
                generator_summaries['Generator_total_loss'])

            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, tot_epoch, n_iter, time.time() - step_time, errD,
                   errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, tot_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)

        # Log one value per epoch for every loss: the mean over the batches
        # of this epoch. Tags and write order match the original script.
        epoch_means = [
            (summary_discriminator_writer,
             "Disciminator_logits_real_loss per epoch",
             d_loss1_summary_per_epoch),
            (summary_discriminator_writer,
             "Disciminator_logits_fake_loss per epoch",
             d_loss2_summary_per_epoch),
            (summary_discriminator_writer,
             "Disciminator_total_loss per epoch", d_loss_summary_per_epoch),
            (summary_generator_writer, "Generator_GAN_loss per epoch",
             g_gan_loss_summary_per_epoch),
            (summary_generator_writer, "Generator_MSE_loss per epoch",
             mse_loss_summary_per_epoch),
            (summary_generator_writer, "Generator_VGG_loss per epoch",
             vgg_loss_summary_per_epoch),
            (summary_generator_writer, "Generator_total_loss per epoch",
             g_loss_summary_per_epoch),
        ]
        for writer, tag, batch_values in epoch_means:
            writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag=tag,
                                     simple_value=np.mean(batch_values)),
                ]), epoch)

        # The writers are never closed in this function, so push the events
        # of this epoch to disk explicitly.
        summary_discriminator_writer.flush()
        summary_generator_writer.flush()
        learning_rate_writer.flush()

        ## save model and a quick evaluation on the fixed sample batch
        if epoch % 10 == 0:
            saver.save(sess, '%s/main_%d.ckpt' % (checkpoint_dir, epoch))

            # Run the test-mode generator only when its output is written.
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            for im, generated in enumerate(out):
                tl.vis.save_image(
                    generated,
                    save_dir_gan + '/train_%d_%d.png' % (epoch, im))