def g_init_model(features, labels, mode, params):
    """Estimator ``model_fn`` that trains only the generator on a pixel-wise MSE loss.

    Args:
        features: Input tensor passed to ``SRGAN_g``.
        labels: Target images compared against the generator output; also fed
            to ``SRGAN_d``. Not used in PREDICT mode.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Unused; deleted immediately.

    Returns:
        A ``tf.contrib.tpu.TPUEstimatorSpec``. In PREDICT mode it carries the
        generator output under ``'generated_images'``; otherwise it carries the
        MSE loss and the op that minimises it over the ``SRGAN_g`` variables.
    """
    del params

    predicting = mode == tf.estimator.ModeKeys.PREDICT
    generator = SRGAN_g(features, is_train=not predicting)
    if predicting:
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode, predictions={'generated_images': generator.outputs})

    # The discriminator is built but its outputs are discarded.
    # NOTE(review): presumably so its variables exist in the graph/checkpoint
    # for the later adversarial stage -- confirm.
    SRGAN_d(labels, is_train=True)

    pixel_loss = tl.cost.mean_squared_error(generator.outputs, labels, is_mean=True)
    generator_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        learning_rate = tf.Variable(config.TRAIN.lr_init, trainable=False)

    # Adam wrapped in CrossShardOptimizer, as required for TPU training.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(
        tf.train.AdamOptimizer(learning_rate, beta1=config.TRAIN.beta1))
    train_op = optimizer.minimize(
        pixel_loss,
        var_list=generator_vars,
        global_step=tf.train.get_global_step())

    return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=pixel_loss, train_op=train_op)
def srgan_model(features, labels, mode, params):
    """Estimator ``model_fn`` for adversarial SRGAN training.

    In PREDICT mode only the generator is built (``is_train=False``) and its
    output is returned under ``'generated_images'``.  Otherwise the generator,
    the discriminator (applied to both ``labels`` and the generator output)
    and the VGG19 feature extractor are built, and one grouped train op that
    updates generator and discriminator together is returned.

    Args:
        features: Input tensor passed to ``SRGAN_g``.
        labels: Target images; used by the discriminator, the MSE loss and the
            VGG feature loss. Not used in PREDICT mode.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Unused; deleted immediately.

    Returns:
        A ``tf.estimator.EstimatorSpec``. In training the reported loss is the
        generator loss (MSE + 2e-6 * VGG feature MSE + 1e-3 * adversarial).
    """
    del params

    if mode == tf.estimator.ModeKeys.PREDICT:
        net_g_test = SRGAN_g(features, is_train=False)
        predictions = {'generated_images': net_g_test.outputs}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    net_g = SRGAN_g(features, is_train=True)
    net_d, logits_real = SRGAN_d(labels, is_train=True)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True)

    # Resize target and generated images to 224x224 (method=0) before VGG19.
    t_target_image_224 = tf.image.resize_images(
        labels, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)

    # NOTE(review): (x + 1) / 2 presumably maps [-1, 1] images to [0, 1] for
    # VGG -- confirm the value range produced by the input pipeline.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2)

    # Discriminator: real images -> 1, generated images -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(
        logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: adversarial term + pixel MSE + VGG feature MSE.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, labels, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(config.TRAIN.lr_init, trainable=False)

    # Only the generator update advances the global step.  Passing the global
    # step to both ``minimize`` calls would increment it twice per training
    # step, so step-based limits, logging and checkpointing would run at
    # double speed.
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1).minimize(
        g_loss, var_list=g_vars, global_step=tf.train.get_global_step())
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1).minimize(
        d_loss, var_list=d_vars)
    joint_op = tf.group([g_optim, d_optim])

    load_vgg(net_vgg)
    return tf.estimator.EstimatorSpec(mode, loss=g_loss, train_op=joint_op)
def __init__(self, args, global_net, vgg_net, train_dataloader, test_dataloader):
    """Set up networks, optimizers, losses and logging for the task.

    Args:
        args: Options object; ``cuda``, ``lr`` and ``o_dir`` are read here.
        global_net: Shared network whose weights seed the local generator.
        vgg_net: Network handed to ``utils.ContentLoss`` for the 'vgg' loss.
        train_dataloader: Training data source, stored as-is.
        test_dataloader: Test data source, stored as-is.
    """
    BasicTask.__init__(self)
    self.args = args

    # Data sources
    self.train_dataloader, self.test_dataloader = train_dataloader, test_dataloader

    # Networks: the local generator starts as a copy of the global weights.
    local_generator = SRGAN_g()
    local_generator.load_state_dict(global_net.state_dict())
    if args.cuda:
        global_net = global_net.cuda()
        local_generator = local_generator.cuda()
        vgg_net = vgg_net.cuda()
    self.global_net = global_net
    self.net_g = local_generator
    self.vgg_net = vgg_net

    # One Adam optimizer per network, with identical hyper-parameters.
    adam_options = dict(lr=args.lr, weight_decay=1e-6)
    self.opt = torch.optim.Adam(self.net_g.parameters(), **adam_options)
    self.global_opt = torch.optim.Adam(self.global_net.parameters(), **adam_options)

    # Loss terms: callable plus the factor stored alongside it.
    self.loss_items = {
        'vgg': dict(func=utils.ContentLoss(self.vgg_net), factor=2e-6),
        'mse': dict(func=F.mse_loss, factor=1.0),
    }

    # Summary writer; NOTE: ``args.resume`` is not consulted here.
    self.writer = SummaryWriter(args.o_dir)

    # RL parameters
    self.max_episode_steps = 100
    self.action_dim = 1
    self.state_dim = 1
    self.env_step = 0
def evaluate(args):
    """Run the generator on a single image and save the result.

    Reads ``args.input`` as RGB, rescales it with ``x / 127.5 - 1``, restores
    the generator weights from ``checkpoint/g_srgan.npz`` and writes the first
    (only) output image to ``args.output``.

    Args:
        args: Object exposing ``input`` (source image path) and ``output``
            (destination image path).
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)

    valid_lr_img = scipy.misc.imread(args.input, mode='RGB')

    ###========================== DEFINE MODEL ============================###
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
    size = valid_lr_img.shape
    print("Inpu image size: " + str(size))

    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    # The context manager closes the session (and frees its resources) even
    # when restoring or running the model raises.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess, name='checkpoint/g_srgan.npz', network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
        print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s / generated HR size: %s" % (size, out.shape))
    tl.vis.save_image(out[0], args.output)
def evaluate():
    """Upscale a fixed list of numbered LR images and save comparison images.

    For each image id the function saves the generator output, a 2160x3840
    resize of that output, the rescaled LR input and a 2x bicubic resize of
    the input, all under ``samples/<mode>``.
    """
    # Output folder and checkpoint location.
    out_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(out_dir)
    checkpoint_dir = "checkpoint"
    lr_dir = '../..//model3/srgan/data2017/LR/'
    image_ids = [5, 29, 35, 62, 78, 83, 150, 192, 258, 289, 310]

    # Generator in inference mode.
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # Session with the GPU device count set to 0.
    session_config = tf.ConfigProto(device_count={'GPU': 0},
                                    allow_soft_placement=True,
                                    log_device_placement=False)
    sess = tf.Session(config=session_config)
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/g_srgan.npz',
                                 network=net_g)

    for image_id in image_ids:
        lr_image = get_imgs_fn('{}.png'.format(image_id), lr_dir)
        lr_image = (lr_image / 127.5) - 1  # rescale to [-1, 1]
        lr_shape = lr_image.shape

        tick = time.time()
        generated = sess.run(net_g.outputs, {t_image: [lr_image]})
        print("took: {:4.4f}s".format(time.time() - tick))
        print("LR size: {} / generated HR size: {}".format(lr_shape, generated.shape))

        print("[*] save images")
        tl.vis.save_image(generated[0],
                          '{}/valid_gen_8k_{}.png'.format(out_dir, image_id))
        generated_4k = scipy.misc.imresize(generated[0], [2160, 3840], mode=None)
        tl.vis.save_image(generated_4k,
                          '{}/valid_gen_4k_{}.png'.format(out_dir, image_id))
        tl.vis.save_image(lr_image,
                          '{}/valid_lr_{}.png'.format(out_dir, image_id))
        bicubic = scipy.misc.imresize(lr_image,
                                      [lr_shape[0] * 2, lr_shape[1] * 2],
                                      interp='bicubic', mode=None)
        tl.vis.save_image(bicubic,
                          '{}/valid_bicubic_{}.png'.format(out_dir, image_id))
def evaluate():
    """Run the generator over the whole validation set and save comparisons.

    For every LR validation image, saves the generator output, the rescaled LR
    input, the matching HR image and a 4x bicubic resize of the input, each
    file name prefixed with the image index.
    """
    # Output folder and checkpoint location.
    out_dir = "/local/scratch/jz426/superResolution/results/samples/{}".format(
        tl.global_flag['mode'])
    tl.files.exists_or_mkdir(out_dir)
    checkpoint_dir = "/local/scratch/jz426/superResolution/results/checkpoint"

    # Pre-load the validation images (file lists are sorted by name).
    hr_names = sorted(tl.files.load_file_list(
        path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    lr_names = sorted(tl.files.load_file_list(
        path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    valid_lr_imgs = tl.vis.read_images(lr_names, path=config.VALID.lr_img_path, n_threads=32)
    valid_hr_imgs = tl.vis.read_images(hr_names, path=config.VALID.hr_img_path, n_threads=32)

    # Generator in inference mode.
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # Restore the generator weights.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/g_srgan.npz',
                                 network=net_g)

    # Evaluate image by image.  (An earlier comment listed sample indices:
    # 0 penguin, 81 butterfly, 53 bird, 64 castle.)
    for index, lr_image in enumerate(valid_lr_imgs):
        hr_image = valid_hr_imgs[index]
        lr_image = (lr_image / 127.5) - 1  # rescale to [-1, 1]
        lr_shape = lr_image.shape

        tick = time.time()
        generated = sess.run(net_g.outputs, {t_image: [lr_image]})
        print("took: {:4.4f}s".format(time.time() - tick))
        print("LR size: {} / generated HR size: {}".format(lr_shape, generated.shape))

        print("[*] save images")
        prefix = '{}/{}'.format(out_dir, index)
        tl.vis.save_image(generated[0], prefix + 'valid_gen.png')
        tl.vis.save_image(lr_image, prefix + 'valid_lr.png')
        tl.vis.save_image(hr_image, prefix + 'valid_hr.png')

        bicubic = scipy.misc.imresize(lr_image,
                                      [lr_shape[0] * 4, lr_shape[1] * 4],
                                      interp='bicubic', mode=None)
        tl.vis.save_image(bicubic, prefix + 'valid_bicubic.png')
def upscale_function(image, model_checkpoint, reuse=False):
    """Run the SRGAN generator on one image and return the generated image.

    A fresh session is opened for the call; afterwards the session is closed
    and the default graph is reset, whether or not the call succeeded.

    Args:
        image: Input image array; rescaled with ``image / 127.5 - 1`` before
            being fed to the network (assumes 0-255 pixel values and three
            channels, matching the ``[1, None, None, 3]`` placeholder).
        model_checkpoint: Path of the ``.npz`` file holding generator weights.
        reuse: Forwarded to ``SRGAN_g`` as its ``reuse`` argument.

    Returns:
        The first (only) element of the generator output batch.
    """
    valid_lr_img = (image / 127.5) - 1  # rescale to [-1, 1]

    try:
        ###========================== DEFINE MODEL ============================###
        t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
        net_g = SRGAN_g(t_image, is_train=False, reuse=reuse)

        ###========================== RESTORE G =============================###
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        with tf.Session(config=session_config) as sess:
            tl.layers.initialize_global_variables(sess)
            tl.files.load_and_assign_npz(sess=sess, name=model_checkpoint, network=net_g)

            ###======================= EVALUATION =============================###
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    finally:
        # Always drop the graph built above so that a failed call does not
        # leave stale variables behind for the next invocation.
        tf.reset_default_graph()

    return out[0]
def evaluate():
    """Save HR / bicubic / SRGAN samples for a slice of the validation data.

    Loads 48x48 single-channel images via ``read_csv_data``, writes reference
    grids and a single reference image to ``samples/valid``, then runs the
    generator restored from ``config.MODEL_path`` on the down-sampled versions
    and saves its outputs alongside them.
    """
    out_dir = "samples/valid"
    tl.files.exists_or_mkdir(out_dir)
    grid = [ni, ni]  # ``ni`` comes from module scope

    # --- reference samples -------------------------------------------------
    valid_hr_imgs = read_csv_data(config.VALID.hr_img_path,
                                  width=48, height=48, channel=1)

    hr_batch = tl.prepro.threading_data(valid_hr_imgs[10:19],
                                        fn=crop_sub_imgs_fn, is_random=False)
    tl.vis.save_images(hr_batch, grid, out_dir + '/_hr_sample.png')

    lr_batch = tl.prepro.threading_data(hr_batch, fn=downsample_fn, down_rate=3)
    bicubic_batch = tl.prepro.threading_data(lr_batch, fn=upsample_fn, up_rate=3)
    tl.vis.save_images(bicubic_batch, grid, out_dir + '/_bicubic_sample.png')

    hr_single = crop_sub_imgs_fn(valid_hr_imgs[1], is_random=False)
    tl.vis.save_image(hr_single, out_dir + '/_hr.png')

    lr_single = downsample_fn(hr_single, down_rate=3)
    tl.vis.save_image(upsample_fn(lr_single, up_rate=3), out_dir + '/_bicubic.png')

    # --- model -------------------------------------------------------------
    t_image = tf.placeholder('float32', [None, 16, 16, 1], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=config.MODEL_path, network=net_g)

    # --- evaluation --------------------------------------------------------
    tick = time.time()
    single_batch = np.expand_dims(lr_single, axis=0)  # add the batch axis
    single_out = np.squeeze(sess.run(net_g.outputs, {t_image: single_batch}), axis=0)
    print("took: {:4.4f}s".format(time.time() - tick))
    tl.vis.save_image(single_out, out_dir + '/_srgan.png')

    batch_out = sess.run(net_g.outputs, {t_image: lr_batch})
    tl.vis.save_images(batch_out, grid, out_dir + '/_srgan_samples.png')
def predict(test_lr_path, checkpoint_path, save_path):
    """Upscale every ``.jpg`` image in a folder with the restored generator.

    Args:
        test_lr_path: Directory containing the low-resolution test images.
        checkpoint_path: Directory containing ``g_srgan.npz``.
        save_path: Directory under which the ``test_gen`` output folder is
            created; each result is saved there under its input file name.
    """
    ## create folders to save result images
    save_dir = os.path.join(save_path, 'test_gen')
    tl.files.exists_or_mkdir(save_dir)

    ###======PRE-LOAD DATA======###
    test_lr_img_list = sorted(
        tl.files.load_file_list(path=test_lr_path, regx='.*.jpg', printable=False))
    test_lr_imgs = tl.vis.read_images(test_lr_img_list, path=test_lr_path, n_threads=32)

    ###======DEFINE MODEL======###
    test_lr_imgs = [(img / 127.5) - 1 for img in test_lr_imgs]  # rescale to [-1, 1]
    test_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(test_image, is_train=False, reuse=False)

    ###======RESTORE G======###
    # The context manager guarantees the session is closed, even on error.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        tl.files.load_and_assign_npz(sess=sess,
                                     name=os.path.join(checkpoint_path, 'g_srgan.npz'),
                                     network=net_g)

        ###======EVALUATION======###
        start_time = time.time()
        for i, (img_name, img) in enumerate(zip(test_lr_img_list, test_lr_imgs)):
            out = sess.run(net_g.outputs, {test_image: [img]})
            out = (out[0] + 1) * 127.5  # back to the 0-255 range
            tl.vis.save_image(out.astype(np.uint8), os.path.join(save_dir, img_name))
            if i and i % 10 == 0:
                print('saving %d images, ok' % i)
        print('take: %4.2fs' % (time.time() - start_time))
def upsample_images(pth_src, pth_dst, pth_checkpoint):
    """Upscale every ``jpg``/``png`` image in ``pth_src`` into ``pth_dst``.

    Images whose height or width exceeds the module-level ``MAX_SIZE`` are
    skipped with a message.  Results are saved under the input file name.

    Args:
        pth_src: Directory containing the source images.
        pth_dst: Output directory (created if missing).
        pth_checkpoint: Path of the ``.npz`` file with the generator weights.
    """
    import numpy as np
    # ``import scipy`` alone does not guarantee that the ``scipy.misc``
    # sub-module is loaded, so import it explicitly.
    import scipy.misc
    import tensorlayer as tl
    import tensorflow as tf
    from model import SRGAN_g

    print("==== UPSAMPLING IMAGES")

    ## create folders to save result images
    tl.files.exists_or_mkdir(pth_dst)

    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # Restore Generator; the context manager closes the session on exit.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess, name=pth_checkpoint, network=net_g)
        print("loaded srgan model from {}".format(pth_checkpoint))

        valid_lr_img_list = sorted(
            tl.files.load_file_list(path=pth_src, regx='.*.(jpg|png)', printable=False))
        total = len(valid_lr_img_list)
        print("found {} valid images to upscale...".format(total))

        # Count from 1 so the progress line reads "1 of N" ... "N of N".
        for n, fname in enumerate(valid_lr_img_list, 1):
            img_src = scipy.misc.imread(os.path.join(pth_src, fname), mode='RGB')
            size = img_src.shape
            # Check the size first so skipped images are never rescaled.
            if size[0] > MAX_SIZE or size[1] > MAX_SIZE:
                print("Image is too big ({}x{}). Skipping.".format(
                    size[0], size[1]))
                continue
            img_src = (img_src / 127.5) - 1  # rescale to [-1, 1]

            # Evaluate
            start_time = time.time()
            img_dst = sess.run(net_g.outputs, {t_image: [img_src]})
            img_dst = ((img_dst + 1) / 2.0) * 255  # rescale to [0,255]
            img_dst = img_dst.astype(np.uint8)  # unsigned int for saving
            print(
                "{} of {}\tUpsampling {} from {}x{} to {}x{} took {:.2f}s".format(
                    n, total, fname, size[0], size[1],
                    img_dst.shape[1], img_dst.shape[2],
                    time.time() - start_time))
            tl.vis.save_image(img_dst[0], os.path.join(pth_dst, fname))
def evaluate():
    """Upscale the first evaluation image and save it next to a bicubic resize.

    The image list comes from ``load_deep_file_list`` over ``eval_img_path``;
    only its first entry is processed.  The generator output and a 4x bicubic
    resize of the input are both converted to uint8 and written to
    ``samples_path + "evaluate"`` via ``save_img_fn``.
    """
    ## create folders to save result images
    save_dir = samples_path + "evaluate"
    tl.files.exists_or_mkdir(save_dir)

    ###========================== DEFINE MODEL ============================###
    eval_img_name_list = load_deep_file_list(path=eval_img_path,
                                             regx=eval_img_name_regx,
                                             recursive=False,
                                             printable=False)
    print(eval_img_name_list)
    valid_lr_img = get_imgs_fn(eval_img_name_list[0], eval_img_path)
    valid_lr_img = rescale_m1p1(valid_lr_img)
    size = valid_lr_img.shape

    # Use the ``tf.compat.v1`` API throughout: the placeholder already did,
    # and the bare ``tf.Session`` / ``tf.ConfigProto`` / ``tf.global_variables``
    # names do not exist under TensorFlow 2.
    tf1 = tf.compat.v1
    t_image = tf1.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    session_config = tf1.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=False)
    # The context manager closes the session even if evaluation raises.
    with tf1.Session(config=session_config) as sess:
        sess.run(tf1.variables_initializer(tf1.global_variables()))
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_path + 'g.npz',
                                     network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
        print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s / generated HR size: %s" % (size, out.shape))
    print("[*] save images")

    out = (out + 1) * 127.5  # rescale to [0, 255]
    out_uint8 = out.astype('uint8')
    save_img_fn(out_uint8[0], save_file_format, save_dir + '/valid_gen')

    out_bicu = (valid_lr_img + 1) * 127.5  # rescale to [0, 255]
    # PIL's resize takes (width, height), hence size[1] before size[0].
    out_bicu = np.array(
        Image.fromarray(np.uint8(out_bicu)).resize((size[1] * 4, size[0] * 4),
                                                   Image.BICUBIC))
    out_bicu_uint8 = out_bicu.astype('uint8')
    save_img_fn(out_bicu_uint8, save_file_format, save_dir + '/valid_bicubic')
def super_resolution_image(image_path):
    """Upscale one image, save it under a random name and pass it to ``send_to_ps``.

    The result is written to ``<script dir>\\output_images\\<random>.jpg``;
    afterwards every file in ``<script dir>\\temp`` is removed.

    Args:
        image_path: Path of the image to upscale.
    """
    filename = ntpath.basename(image_path)
    # Strip only the trailing basename.  ``str.replace`` would also remove
    # earlier occurrences of the same text (e.g. a folder named like the file).
    filepath = image_path[:len(image_path) - len(filename)]
    scriptpath = os.path.dirname(os.path.realpath(__file__))
    temppath = '{0}\\temp'.format(scriptpath)
    # NOTE(review): ``tempfile._get_candidate_names`` is a private API used
    # here only as a random-name generator.
    tempname = next(tempfile._get_candidate_names()) + ".jpg"
    output_filename = '{0}\\output_images\\{1}'.format(scriptpath, tempname)

    ###========================== DEFINE MODEL ============================###
    valid_lr_img = get_imgs_fn(filename, filepath)
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    # The context manager closes the session even if evaluation raises.
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name='{0}\\g_srgan.npz'.format(scriptpath),
                                     network=net_g)

        ###======================= EVALUATION =============================###
        start_time = time.time()
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
        print("took: %4.4fs" % (time.time() - start_time))

    print("[*] save images")
    tl.vis.save_image(out[0], output_filename)
    send_to_ps(output_filename)

    # clear temp folder
    for entry in os.listdir(temppath):
        os.remove(os.path.join(temppath, entry))
def evaluate(args):
    """Upscale frames ``frame_0001.ppm`` .. ``frame_0072.ppm`` to PNG files.

    Args:
        args: Object exposing ``input_dir`` (folder with the ``.ppm`` frames)
            and ``output_dir`` (folder receiving ``frame_XXXX.png``; created
            if missing).
    """
    import os  # local import so the block does not rely on a module-level one

    ## create folders to save result images
    save_dir = args.output_dir
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    # The context manager closes the session even if a frame fails.
    with tf.Session(config=tf_config) as sess:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        ###======================= EVALUATION =============================###
        for i in range(1, 73):
            frame_name = "frame_%04d" % i
            valid_lr_img = get_imgs_fn(frame_name + ".ppm", args.input_dir)
            valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
            # Report the size of the frame actually processed, not the first.
            size = valid_lr_img.shape

            start_time = time.time()
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
            print("took: %4.4fs" % (time.time() - start_time))
            print("LR size: %s / generated HR size: %s" % (size, out.shape))

            print("[*] save images %d" % i)
            # Join with a separator: plain concatenation wrote the files next
            # to the output folder whenever ``output_dir`` had no trailing slash.
            tl.vis.save_image(out[0], os.path.join(save_dir, frame_name + ".png"))
def visualize(epoch):
    """Build the single-channel generator and load the weights saved for ``epoch``.

    Args:
        epoch: Value substituted into the checkpoint name ``g_srgan_{}.npz``.
    """
    checkpoint_dir = "/Users/btopiwala/Downloads/CS231N/2018/Project/gcloud-run-all-data/checkpoint_epoch_20_epoch_48_with_intermediate_checkpoint/checkpoint"
    weights_file = checkpoint_dir + '/g_srgan_{}.npz'.format(epoch)

    # Generator in inference mode; the last placeholder dimension is 1
    # because the input has a single channel.
    t_image = tf.placeholder('float32', [1, None, None, 1], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # Initialise variables, then overwrite them with the stored weights.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    sess = tf.Session(config=session_config)
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=weights_file, network=net_g)
def evaluate():
    """Run stored generator checkpoints over the validation LR images.

    For each epoch number in ``range(200, 250, 150)`` the matching
    ``g_srgan_softmax<epoch>.npz`` checkpoint is loaded and the generator
    output for each processed image is saved under ``samples/evaluate/<epoch>``
    using the input file name.  At most 1135 images are processed.
    """
    checkpoint_dir = "/home/fan/su/remove_face/checkpoint/facenet_pgd_joint_loss/"

    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)

    # Define the model
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # Number of input samples: keep the original cap of 1135 but never index
    # past the images that were actually loaded.
    n_images = min(1135, len(valid_lr_imgs))

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    # The context manager closes the session even if evaluation raises.
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)

        for epoch_n in range(200, 250, 150):
            save_dir = "samples/evaluate/" + str(epoch_n)
            tl.files.exists_or_mkdir(save_dir)
            tl.files.load_and_assign_npz(
                sess=sess,
                name=checkpoint_dir + '/g_srgan_softmax%d.npz' % epoch_n,
                network=net_g)

            for imid in range(n_images):
                valid_lr_img = valid_lr_imgs[imid]
                print(valid_lr_img)
                valid_lr_img = (valid_lr_img / 127.5) - 1  # normalise to [-1, 1]

                # Start evaluation
                start_time = time.time()
                out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
                print("took: %4.4fs" % (time.time() - start_time))

                print("[*] save images")
                print(out)
                tl.vis.save_image(out[0], save_dir + '/%s' % (valid_lr_img_list[imid]))
def export_model():
    """Load the generator the TensorLayer way and export it as a TF checkpoint.

    The weights are restored from ``checkpoint/g_srgan.npz``; the checkpoint is
    then saved under the ``./meta/srgan`` prefix and the graph definition is
    written as text to ``./meta/srgan.pbtxt``.

    Args:
        None

    Returns:
        None
    """
    checkpoint_dir = "checkpoint"
    export_dir = './meta'

    # ``Saver.save`` needs the parent directory of its prefix to exist.
    tl.files.exists_or_mkdir(export_dir)

    ###========================== DEFINE MODEL ============================###
    t_image = tf.placeholder('float32', [None, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    # The context manager closes the session once the export is done.
    with tf.Session(config=session_config) as sess:
        tl.layers.initialize_global_variables(sess)

        # Load model from .npz file
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        # export to meta file
        saver = tf.train.Saver()
        saver.save(sess, './meta/srgan')
        tf.train.write_graph(sess.graph.as_graph_def(), '.', './meta/srgan.pbtxt',
                             as_text=True)
def evaluate():
    """Evaluate the single-channel generator on validation image #1.

    Saves the generated, LR, HR and 4x bicubic images under ``samples/<mode>``
    and prints PSNR for the bicubic and generated images plus SSIM for the
    generated image, each measured against the saved HR image.
    """
    out_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(out_dir)
    checkpoint_dir = "checkpoint"

    def _load_gray(directory, names):
        # Read each file in mode 'L' and append a trailing channel axis.
        loaded = []
        for name in names:
            gray = scipy.misc.imread(os.path.join(directory, name), mode='L')
            loaded.append(gray.reshape(gray.shape[0], gray.shape[1], 1))
        return loaded

    # --- data --------------------------------------------------------------
    valid_hr_img_list = sorted(tl.files.load_file_list(
        path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(
        path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    valid_lr_imgs = _load_gray(config.VALID.lr_img_path, valid_lr_img_list)
    print(type(valid_lr_imgs), len(valid_lr_img_list))
    valid_hr_imgs = _load_gray(config.VALID.hr_img_path, valid_hr_img_list)
    print(type(valid_hr_imgs), len(valid_hr_img_list))

    # --- model -------------------------------------------------------------
    sample_index = 1
    hr_image = valid_hr_imgs[sample_index]
    lr_image = (valid_lr_imgs[sample_index] / 127.5) - 1  # rescale to [-1, 1]
    lr_shape = lr_image.shape

    t_image = tf.placeholder('float32', [1, None, None, 1], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/g_srgan.npz',
                                 network=net_g)

    # --- evaluation --------------------------------------------------------
    tick = time.time()
    generated = sess.run(net_g.outputs, {t_image: [lr_image]})
    print("took: {:4.4f}s".format(time.time() - tick))
    print("LR size: {} / generated HR size: {}".format(lr_shape, generated.shape))

    print("[*] save images")
    gen_img_path = out_dir + '/valid_gen.png'
    hr_img_path = out_dir + '/valid_hr.png'
    bi_cubic_img_path = out_dir + '/valid_bicubic.png'

    tl.vis.save_image(generated[0], gen_img_path)
    tl.vis.save_image(lr_image, out_dir + '/valid_lr.png')
    tl.vis.save_image(hr_image, hr_img_path)

    # Drop the channel axis again before the bicubic resize.
    lr_plane = lr_image.reshape(lr_image.shape[0], lr_image.shape[1])
    bicubic = scipy.misc.imresize(lr_plane,
                                  [lr_shape[0] * 4, lr_shape[1] * 4],
                                  interp='bicubic')
    tl.vis.save_image(bicubic, bi_cubic_img_path)

    # --- metrics, computed from the files just written -----------------------
    bicubic_psnr = computePSNR(hr_img_path, bi_cubic_img_path)
    gen_psnr = computePSNR(hr_img_path, gen_img_path)
    gen_ssim = skimage.measure.compare_ssim(
        scipy.misc.imread(hr_img_path, mode='L'),
        scipy.misc.imread(gen_img_path, mode='L'))

    print("Bicubic image PSNR:", bicubic_psnr)
    print("Generated image PSNR:", gen_psnr)
    print('Generated image SSIML', gen_ssim)
def train():
    """Train the SRResNet generator on single-channel image crops.

    All PNGs under ``config.TRAIN.hr_img_path`` are pre-loaded as grayscale
    images.  Each step crops a batch with ``crop_sub_imgs_fn``, downsamples
    it with ``downsample_fn`` and minimises ``mse_loss + vgg_loss`` over the
    ``SRGAN_g`` variables with Adam.  The learning rate starts at ``lr_init``
    and is multiplied by ``lr_decay`` every ``decay_every`` epochs.

    Side effects:
        * creates ``samples/<mode>_gan`` and ``checkpoint``;
        * writes per-epoch scalar summaries under ``./log/train``;
        * saves one generated sample image per epoch;
        * saves ``checkpoint/g_<mode>_<epoch>.npz`` every third epoch;
        * calls ``exit()`` if ``vgg19.npy`` is not present.

    A trailing batch smaller than ``batch_size`` is skipped, because the
    input placeholders have a fixed batch dimension.
    """
    ## create folders to save result images and trained model
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])  # srresnet
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))

    ## If your machine has enough memory, please pre-load the whole train set.
    print("reading images")
    train_hr_imgs = []
    for img__ in train_hr_img_list:
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path, img__), mode='L')
        # add an explicit channel axis: (H, W) -> (H, W, 1)
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))
        train_hr_imgs.append(image_loaded)
    print(type(train_hr_imgs), len(train_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 56, 56, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1],
                                    name='t_target_image')
    print("t_image:", tf.shape(t_image))
    print("t_target_image:", tf.shape(t_target_image))

    # SRGAN_g is the SRResNet portion of the GAN
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    print("net_g.outputs:", tf.shape(net_g.outputs))
    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. method: 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)
    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)
    print("net_g.outputs:", tf.shape(net_g.outputs))
    print("t_predict_image_224:", tf.shape(t_predict_image_224))

    # images are in [-1, 1]; shift to [0, 1] before feeding VGG
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss

    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## SRResNet
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
            network=net_g)

        ###============================= LOAD VGG ===============================###
        vgg19_npy_path = "vgg19.npy"
        if not os.path.isfile(vgg19_npy_path):
            print(
                "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
            )
            exit()
        # NOTE: the file is a pickled dict, so it must come from a trusted
        # source.  allow_pickle=True is required by newer NumPy releases and
        # matches the old default.
        npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

        params = []
        for val in sorted(npz.items()):
            W = np.asarray(val[1][0])
            b = np.asarray(val[1][1])
            print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
            params.extend([W, b])
        tl.files.assign_params(sess, params, net_vgg)

        ###============================= TRAINING ===============================###
        ## use first `batch_size` of train set to have a quick test during training
        sample_imgs = train_hr_imgs[0:batch_size]
        print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)
        sample_imgs_224 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn,
                                                   is_random=False)
        print('sample HR sub-image:', sample_imgs_224.shape, sample_imgs_224.min(),
              sample_imgs_224.max())
        sample_imgs_56 = tl.prepro.threading_data(sample_imgs_224, fn=downsample_fn)
        print('sample LR sub-image:', sample_imgs_56.shape, sample_imgs_56.min(),
              sample_imgs_56.max())
        tl.vis.save_images(sample_imgs_56, [ni, ni], save_dir_gan + '/_train_sample_56.png')
        tl.vis.save_images(sample_imgs_224, [ni, ni], save_dir_gan + '/_train_sample_224.png')

        ###========================= train SRResNet =========================###
        merged_summary_generator = tf.summary.merge(
            [mse_loss_summary, vgg_loss_summary, g_loss_summary])
        summary_generator_writer = tf.summary.FileWriter("./log/train/generator")
        learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

        for epoch in range(0, n_epoch + 1):
            ## update learning rate
            if epoch != 0 and (epoch % decay_every == 0):
                new_lr_decay = lr_decay**(epoch // decay_every)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
                print(log)
                learning_rate_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag="Learning_rate per epoch",
                                         simple_value=(lr_init * new_lr_decay)),
                    ]), (epoch))
            elif epoch == 0:
                sess.run(tf.assign(lr_v, lr_init))
                log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                    lr_init, decay_every, lr_decay)
                print(log)
                learning_rate_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag="Learning_rate per epoch",
                                         simple_value=lr_init),
                    ]), (epoch))

            epoch_time = time.time()
            total_g_loss, n_iter = 0, 0

            mse_loss_summary_per_epoch = []
            vgg_loss_summary_per_epoch = []
            g_loss_summary_per_epoch = []

            for idx in range(0, len(train_hr_imgs), batch_size):
                batch_imgs = train_hr_imgs[idx:idx + batch_size]
                if len(batch_imgs) < batch_size:
                    # The placeholders have a fixed batch dimension, so a
                    # short trailing batch cannot be fed.
                    break
                step_time = time.time()
                b_imgs_224 = tl.prepro.threading_data(batch_imgs, fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_56 = tl.prepro.threading_data(b_imgs_224, fn=downsample_fn)

                ## update G
                errG, errM, errV, _, generator_summary = sess.run(
                    [g_loss, mse_loss, vgg_loss, g_optim, merged_summary_generator],
                    {t_image: b_imgs_56, t_target_image: b_imgs_224})

                summary_pb = tf.summary.Summary()
                summary_pb.ParseFromString(generator_summary)
                # Assuming all summaries are scalars.
                generator_summaries = {val.tag: val.simple_value
                                       for val in summary_pb.value}

                mse_loss_summary_per_epoch.append(
                    generator_summaries['Generator_MSE_loss'])
                vgg_loss_summary_per_epoch.append(
                    generator_summaries['Generator_VGG_loss'])
                g_loss_summary_per_epoch.append(
                    generator_summaries['Generator_total_loss'])

                print(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f)"
                    % (epoch, n_epoch, n_iter, time.time() - step_time, errG, errM, errV))
                total_g_loss += errG
                n_iter += 1

            # Elapsed wall-clock time for the epoch and the mean generator
            # loss; max() avoids dividing by zero when no batch was run.
            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f" % (
                epoch, n_epoch, time.time() - epoch_time, total_g_loss / max(n_iter, 1))
            print(log)

            ## logging generator summary (per-epoch means)
            for tag, values in (
                    ("Generator_MSE_loss per epoch", mse_loss_summary_per_epoch),
                    ("Generator_VGG_loss per epoch", vgg_loss_summary_per_epoch),
                    ("Generator_total_loss per epoch", g_loss_summary_per_epoch)):
                summary_generator_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag=tag, simple_value=np.mean(values)),
                    ]), (epoch))

            ## quick evaluation on the fixed sample batch
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_56})
            print("[*] save images")
            tl.vis.save_image(out[0], save_dir_gan + '/train_%d.png' % epoch)

            ## save model
            if (epoch != 0) and (epoch % 3 == 0):
                tl.files.save_npz(
                    net_g.all_params,
                    name=checkpoint_dir +
                    '/g_{}_{}.npz'.format(tl.global_flag['mode'], epoch),
                    sess=sess)
    finally:
        sess.close()
def train():
    """Adversarially train the SRGAN generator and discriminator.

    The generator loss is ``1e-1 * g_gan_loss + mae_loss``.  The GAN terms
    compare each network's logits against the mean logits of the opposite
    class, as written in ``d_loss`` / ``g_gan_loss`` below.  Both networks
    are restored from ``g.npz`` / ``d.npz`` under ``checkpoint_path`` before
    training and saved back there after every epoch.

    The training file list is re-read and reshuffled each epoch, then padded
    by cycling from its start so its length is a multiple of ``batch_size``
    (the placeholders have a fixed batch dimension).

    Side effects: creates ``samples_path + "gan"`` and ``checkpoint_path``,
    writes sample images every epoch and overwrites the two checkpoints.
    """
    ## create folders to save result images and trained model
    save_dir_gan = samples_path + "gan"
    tl.files.exists_or_mkdir(save_dir_gan)
    tl.files.exists_or_mkdir(checkpoint_path)

    # Raw string: the pattern value is unchanged, but "\." is no longer an
    # invalid escape sequence in a normal string literal.
    image_regx = r'.*\.(bmp|png|webp|jpg)'

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=valid_hr_img_path, regx=image_regx, printable=False))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    sample_t_image = tf.compat.v1.placeholder(
        'float32', [sample_batch_size, 96, 96, 3],
        name='sample_t_image_input_to_SRGAN_generator')
    t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3],
                                       name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.compat.v1.placeholder('float32', [batch_size, 384, 384, 3],
                                              name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## test inference
    net_g_test = SRGAN_g(sample_t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # MAE Loss (tf.abs is element-wise, so no per-sample map_fn is needed)
    mae_loss = tf.reduce_mean(tf.abs(t_target_image - net_g.outputs))

    # GAN Loss
    mean_real = tf.reduce_mean(logits_real)
    mean_fake = tf.reduce_mean(logits_fake)
    d_loss = 0.5 * (tf.reduce_mean(tf.square(logits_real - mean_fake - 1)) +
                    tf.reduce_mean(tf.square(logits_fake - mean_real + 1)))
    g_gan_loss = 0.5 * (tf.reduce_mean(tf.square(logits_real - mean_fake + 1)) +
                        tf.reduce_mean(tf.square(logits_fake - mean_real - 1)))
    g_loss = 1e-1 * g_gan_loss + mae_loss

    # mean logits, reported in the per-step log
    d_real = mean_real
    d_fake = mean_fake

    with tf.variable_scope('learning_rate'):
        learning_rate_var = tf.Variable(learning_rate, trainable=False)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    ## SRGAN
    g_optim = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate_var).minimize(g_loss, var_list=g_vars)
    d_optim = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate_var).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        sess.run(tf.variables_initializer(tf.global_variables()))
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_path + 'g.npz',
                                     network=net_g)
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_path + 'd.npz',
                                     network=net_d)

        ###============================= TRAINING ===============================###
        sample_imgs = tl.prepro.threading_data(
            valid_hr_img_list[0:sample_batch_size], fn=get_imgs_fn,
            path=valid_hr_img_path)
        sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn,
                                                   is_random=False)
        print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(),
              sample_imgs_384.max())
        sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
        print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
              sample_imgs_96.max())
        save_images(sample_imgs_96, [ni, ni], save_file_format,
                    save_dir_gan + '/_train_sample_96')
        save_images(sample_imgs_384, [ni, ni], save_file_format,
                    save_dir_gan + '/_train_sample_384')

        ###========================= train GAN =========================###
        sess.run(tf.assign(learning_rate_var, learning_rate))
        for epoch in range(0, n_epoch_gan + 1):
            epoch_time = time.time()
            total_d_loss, total_g_loss_mae, total_g_loss_gan, n_iter = 0, 0, 0, 0

            train_hr_img_list = load_deep_file_list(path=train_hr_img_path,
                                                    regx=image_regx,
                                                    recursive=True,
                                                    printable=False)
            random.shuffle(train_hr_img_list)

            list_length = len(train_hr_img_list)
            print("Number of images: %d" % (list_length))

            # Pad to a whole number of batches.  Cycling with a modulo also
            # covers the case where the list is shorter than the padding.
            pad = -list_length % batch_size
            if list_length and pad:
                train_hr_img_list += [
                    train_hr_img_list[i % list_length] for i in range(pad)
                ]

            list_length = len(train_hr_img_list)
            print("Length of list: %d" % (list_length))

            for idx in range(0, list_length, batch_size):
                step_time = time.time()
                b_imgs_list = train_hr_img_list[idx:idx + batch_size]
                b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn,
                                                  path=train_hr_img_path)
                b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_data_augment_fn,
                                                      is_random=True)
                # The LR batch is derived before the HR batch is rescaled.
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                b_imgs_384 = tl.prepro.threading_data(b_imgs_384, fn=rescale_m1p1)
                feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}

                ## update D
                errD, d_r, d_f, _ = sess.run([d_loss, d_real, d_fake, d_optim], feed)
                ## update G
                errM, errA, _, _ = sess.run([mae_loss, g_gan_loss, g_loss, g_optim],
                                            feed)
                print(
                    "Epoch[%2d/%2d] %4d time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f d_r: %.8f d_f: %.8f"
                    % (epoch, n_epoch_gan, n_iter, time.time() - step_time, errD, errM,
                       errA, d_r, d_f))
                total_d_loss += errD
                total_g_loss_mae += errM
                total_g_loss_gan += errA
                n_iter += 1

            # max() avoids ZeroDivisionError when the training list is empty.
            steps = max(n_iter, 1)
            log = (
                "[*] Epoch[%2d/%2d] time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f"
                % (epoch, n_epoch_gan, time.time() - epoch_time, total_d_loss / steps,
                   total_g_loss_mae / steps, total_g_loss_gan / steps))
            print(log)

            ## quick evaluation on the fixed sample batch
            out = sess.run(net_g_test.outputs, {sample_t_image: sample_imgs_96})
            print("[*] save images")
            save_images(out, [ni, ni], save_file_format,
                        save_dir_gan + '/train_%d' % epoch)

            ## save model
            tl.files.save_npz(net_g.all_params, name=checkpoint_path + 'g.npz', sess=sess)
            tl.files.save_npz(net_d.all_params, name=checkpoint_path + 'd.npz', sess=sess)
    finally:
        sess.close()
def evaluate():
    """Evaluate the two-frame generator on up to 30 validation samples.

    For each sample index the two low-resolution frames (from
    ``config.VALID.lr_img_path`` and ``config.VALID.lr_img_path2``) are
    downsampled with ``downsample_fn``, concatenated along the channel axis
    as the generator input, and averaged as the second generator input.  The
    generated image, both LR frames and the HR image are written to
    ``samples/<mode>/<index>/``, and the SSIM between the saved generated and
    HR files is printed.

    The number of evaluated samples is clamped to the number of images
    available in the three validation sets.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list2 = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path2, regx='.*.png', printable=False))

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path, n_threads=32)
    valid_lr_imgs2 = tl.vis.read_images(valid_lr_img_list2,
                                        path=config.VALID.lr_img_path2, n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                       path=config.VALID.hr_img_path, n_threads=32)

    ###========================== DEFINE MODEL ============================###
    tf_gen_output = tf.placeholder('float32', [240, 320, 3])
    tf_hr_output = tf.placeholder('float32', [240, 320, 3])
    # 6 channels: two RGB frames stacked along the channel axis
    t_image = tf.placeholder('float32', [1, None, None, 6], name='input_image')
    avg_image = tf.placeholder('float32', [1, None, None, 3], name='average_image')

    net_g = SRGAN_g(t_image, avg_image, is_train=False, reuse=False)

    # Built once; creating it inside the loop would add a new op to the graph
    # on every iteration.
    # NOTE(review): the images fed below are read back from PNG files; if
    # they are 0-255 valued, max_val should presumably be 255 -- confirm.
    ssim1 = tf.image.ssim(tf_gen_output, tf_hr_output, max_val=1.0)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)

        # Downsampling does not depend on the sample index, so do it once
        # for each LR set instead of on every iteration.
        valid_lr_img_t1 = tl.prepro.threading_data(valid_lr_imgs, fn=downsample_fn)
        valid_lr_img_t2 = tl.prepro.threading_data(valid_lr_imgs2, fn=downsample_fn)

        n_eval = min(30, len(valid_lr_imgs), len(valid_lr_imgs2), len(valid_hr_imgs))
        for imid in range(n_eval):
            valid_lr_img_t1_d = valid_lr_img_t1[imid]
            valid_lr_img_t2_d = valid_lr_img_t2[imid]
            print(valid_lr_img_t1_d.shape)
            print(valid_lr_img_t2_d.shape)
            valid_hr_img = valid_hr_imgs[imid]

            valid_lr_img = np.concatenate((valid_lr_img_t1_d, valid_lr_img_t2_d), axis=2)
            valid_avg_img = (np.add(valid_lr_img_t1_d, valid_lr_img_t2_d)) / 2.

            size = valid_lr_img.shape
            print(size)

            ###======================= EVALUATION =============================###
            start_time = time.time()
            out = sess.run(net_g.outputs, {
                t_image: [valid_lr_img],
                avg_image: [valid_avg_img]
            })

            print("took: %4.4fs" % (time.time() - start_time))
            print("LR size: %s / generated HR size: %s" % (size, out.shape))
            print("[*] save images")

            # Each sample gets its own sub-folder; make sure it exists.
            sample_dir = save_dir + '/{}'.format(imid)
            tl.files.exists_or_mkdir(sample_dir)
            tl.vis.save_image(out[0], sample_dir + '/valid_gen.png')
            tl.vis.save_image(valid_lr_img_t1_d, sample_dir + '/valid_lr_1.png')
            tl.vis.save_image(valid_lr_img_t2_d, sample_dir + '/valid_lr_3.png')
            tl.vis.save_image(valid_hr_img, sample_dir + '/valid_hr.png')

            # SSIM is computed on the images as written to disk.
            ssim_gen_output = tl.vis.read_image('valid_gen.png', sample_dir + '/')
            ssim_hr_output = tl.vis.read_image('valid_hr.png', sample_dir + '/')
            print(ssim_gen_output)
            print(ssim_hr_output)

            tf_ssim = sess.run(ssim1,
                               feed_dict={
                                   tf_gen_output: ssim_gen_output,
                                   tf_hr_output: ssim_hr_output
                               })
            print(tf_ssim)
    finally:
        sess.close()
def evaluate(data, n_patients_train, eval_model, save_imgs=False):
    """Compute the mean squared error of the generator on held-out patients.

    Args:
        data: mapping of patient id -> ``(lr_images, hr_images)``.  Each
            image is reshaped to ``(512, 512, 1)``.  Patient ids are
            concatenated into paths, so they are expected to be strings.
        n_patients_train: number of leading entries of ``data`` (in
            iteration order) to skip; only the remaining patients are
            evaluated.
        eval_model: checkpoint file name appended directly to
            ``"models_checkpoints"`` (so it should start with ``/``).
            ``'/g_srgan.npz'`` selects the ``srgan_results`` output folder,
            anything else ``srresnet_results``.
        save_imgs: when true, write the generated, LR and HR image of every
            evaluated slice under ``<results_dir>/<patient>/``.

    Prints the average MSE per patient and the total MSE over all evaluated
    images.  Nothing is returned.
    """
    ## create folders for checkpoint and results
    checkpoint_dir = "models_checkpoints"
    if eval_model == '/g_srgan.npz':
        results_dir = "srgan_results"
    else:
        results_dir = "srresnet_results"
    tl.files.exists_or_mkdir(results_dir)

    ###========================== RESTORE G =============================###
    t_image = tf.placeholder('float32', [1, 512, 512, 1], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + eval_model,
                                     network=net_g)

        ###======================= EVALUATION =============================###
        imgs_evald, total_mse = 0, 0
        for counter, (patient, values) in enumerate(data.items()):
            if counter < n_patients_train:
                # the first n_patients_train entries are skipped
                continue

            print("[] Evaluating patient " + patient + " files")
            patient_dir = results_dir + "/" + patient
            tl.files.exists_or_mkdir(patient_dir)
            valid_lr_imgs = values[0]
            valid_hr_imgs = values[1]
            n_batches = str(math.ceil(len(valid_lr_imgs) / float(100)))

            patient_mse = 0
            for i, (valid_lr_img, valid_hr_img) in enumerate(
                    zip(valid_lr_imgs, valid_hr_imgs)):
                valid_lr_img = np.asarray(valid_lr_img).reshape((512, 512, 1))
                valid_hr_img = np.asarray(valid_hr_img).reshape((512, 512, 1))

                out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
                curr_mse = mse(out, valid_hr_img)
                imgs_evald += 1
                patient_mse += curr_mse
                total_mse += curr_mse

                if save_imgs:
                    prefix = patient_dir + "/" + str(patient) + "_" + str(i)
                    tl.vis.save_image(out[0], prefix + '_valid_gen.png')
                    tl.vis.save_image(valid_lr_img, prefix + '_valid_lr.png')
                    tl.vis.save_image(valid_hr_img, prefix + '_valid_hr.png')

                if i % 100 == 0:
                    print("Batch " + str((i / float(100))) + "/" + n_batches)

            # A patient without images has no meaningful average.
            if valid_lr_imgs:
                patient_mse /= len(valid_lr_imgs)
            print("Average MSE: " + str(patient_mse))
    finally:
        sess.close()

    if imgs_evald:
        total_mse /= imgs_evald
        print("[*] Evaluation -- total MSE: " + str(total_mse))
    else:
        print("[*] Evaluation -- no images were evaluated")
def evaluate():
    """Run the generator over the validation LR images in fixed-size batches.

    All PNGs under ``config.VALID.lr_img_path`` are read, grouped into
    consecutive batches of ``config.TRAIN.batch_size`` and super-resolved
    with the generator restored from ``checkpoint/g_srgan.npz``.  For every
    image the generated output, the (rescaled) LR input and a bicubic 4x
    upscale are saved to ``samples/<mode>`` using the source file name as a
    prefix.

    Grayscale (2-D) images are replicated to three channels.  Images left
    over after the last full batch are not evaluated.  All images in one
    batch must have the same size.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"
    batch_size = config.TRAIN.batch_size

    ###====================== PRE-LOAD DATA ===========================###
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path, n_threads=32)
    for im in valid_lr_imgs:
        print(im.shape)

    ###========================== DEFINE MODEL ============================###
    t_image = tf.placeholder('float32', [batch_size, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)
        print("valid lr img list:" + str(valid_lr_img_list))

        # Integer division: range() needs an int, and only full batches fit
        # the fixed-size placeholder.
        for n in range(len(valid_lr_imgs) // batch_size):
            imid = n * batch_size

            batch = []
            for img in valid_lr_imgs[imid:imid + batch_size]:
                if len(img.shape) == 2:
                    # grayscale -> 3 identical channels
                    img = np.stack((img, ) * 3, axis=-1)
                batch.append((img / 127.5) - 1)  # rescale to [-1, 1]
            # (batch, H, W, 3); a new leading axis, not a join along height
            valid_lr_img = np.stack(batch, axis=0)
            print("bahbahbah res img shape: " + str(valid_lr_img.shape))

            size = valid_lr_img.shape[1:]
            print("size shape: " + str(size))

            ###======================= EVALUATION =============================###
            start_time = time.time()
            # valid_lr_img already carries the batch axis, so feed it as is.
            out = sess.run(net_g.outputs, {t_image: valid_lr_img})
            print("took: %4.4fs" % (time.time() - start_time))
            print("LR size: %s / generated HR size: %s" % (size, out.shape))
            print("[*] save images")

            for i in range(batch_size):
                # strip the ".png" extension from the source file name
                stem = valid_lr_img_list[imid + i][:-4]
                tl.vis.save_image(out[i], save_dir + '/{}_valid_gen.png'.format(stem))
                tl.vis.save_image(valid_lr_img[i],
                                  save_dir + '/{}_valid_lr.png'.format(stem))
                out_bicu = scipy.misc.imresize(valid_lr_img[i],
                                               [size[0] * 4, size[1] * 4],
                                               interp='bicubic', mode=None)
                tl.vis.save_image(out_bicu,
                                  save_dir + '/{}_valid_bicubic.png'.format(stem))
    finally:
        sess.close()
def conceval(epoch, net_g_test, t_image, sess):
    """Run an intermediate evaluation on the first 16 validation images.

    A generator sharing variables with the training graph (``reuse=True``)
    is evaluated in a separate session, loaded from
    ``checkpoint/g_srgan.npz``.  For every image the generated output is
    saved to ``samples/intermediate`` and scored with ``quant`` (MSE, PSNR,
    SSIM) and ``ocr.getAccuracy``.  One summary line is appended to
    ``inter_eval.txt`` and printed.

    Args:
        epoch: epoch number, used in the log message and output file names.
        net_g_test: unused; kept for interface compatibility.
        t_image: unused; kept for interface compatibility.
        sess: the caller's session.  It is deliberately left untouched:
            re-initialising its variables here would discard the caller's
            current weights.

    Side effects: sets the module-level ``do_ocr`` flag to ``True``.
    """
    global do_ocr
    do_ocr = True

    print("Intermediate Evaluating epoch...", epoch)
    tot_psnr = 0
    tot_mse = 0
    tot_ssim = 0
    tot_res_acc = 0
    tot_hr_acc = 0
    tot_lr_acc = 0
    tot_bic_acc = 0
    res_beats_hr = 0
    res_beats_bic = 0
    res_beats_lr = 0
    res_fails = 0
    test_set_size = 16

    ## create folders to save result images
    save_dir = "samples/intermediate"
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"
    if (do_ocr):
        print("Evaluating with OCR")

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png',
                                printable=False))[:test_set_size]
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png',
                                printable=False))[:test_set_size]

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path, n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                       path=config.VALID.hr_img_path, n_threads=32)

    num_lr_imgs = len(valid_lr_imgs)
    if num_lr_imgs == 0:
        # Every average below divides by the image count.
        print("No validation images found; skipping intermediate evaluation")
        return

    ###========================== RESTORE G =============================###
    in_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g_oth = SRGAN_g(in_image, is_train=False, reuse=True)
    ses2 = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        # Initialise only the generator's variables, and only in the local
        # evaluation session, before the checkpoint values are assigned.
        ses2.run(tf.variables_initializer(net_g_oth.all_params))
        tl.files.load_and_assign_npz(sess=ses2,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g_oth)
        print("Loaded model\nProcessing images...")

        ###======================= EVALUATION =============================###
        for imid in range(num_lr_imgs):
            valid_lr_img = valid_lr_imgs[imid]
            valid_hr_img = valid_hr_imgs[imid]
            img_name = valid_lr_img_list[imid]

            size = valid_lr_img.shape
            if len(size) == 2:
                # grayscale -> 3 identical channels for both LR and HR
                valid_lr_img = np.stack((valid_lr_img, ) * 3, axis=-1)
                valid_hr_img = np.stack((valid_hr_img, ) * 3, axis=-1)
                size = valid_lr_img.shape

            valid_lr_img_res = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
            out = ses2.run(net_g_oth.outputs, {in_image: [valid_lr_img_res]})
            out_uint8 = convert(out[0])

            tl.vis.save_image(
                out_uint8,
                save_dir + '/' + img_name[:-4] + '_gen_' + format(epoch, '03d') + '.png')

            out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4],
                                           interp='bicubic', mode=None)

            img_mse, img_psnr, img_ssim = quant(out_uint8, valid_hr_img)
            tot_psnr += img_psnr
            tot_mse += img_mse
            tot_ssim += img_ssim

            if (do_ocr):
                res_acc, hr_acc, lr_acc, bic_acc = ocr.getAccuracy(
                    out_uint8, valid_hr_img, valid_lr_img, out_bicu, imid)
                tot_res_acc += res_acc
                tot_hr_acc += hr_acc
                tot_lr_acc += lr_acc
                tot_bic_acc += bic_acc
                # Only the first matching comparison is counted per image.
                if (res_acc > hr_acc):
                    res_beats_hr += 1
                elif (res_acc > bic_acc):
                    res_beats_bic += 1
                elif (res_acc > lr_acc):
                    res_beats_lr += 1
                else:
                    res_fails += 1
    finally:
        ses2.close()

    hist = ("Average PSNR: " + str(tot_psnr / num_lr_imgs)[:8] +
            "\tAverage MSE: " + str(tot_mse / num_lr_imgs)[:8] +
            "\tAverage SSIM: " + str(tot_ssim / num_lr_imgs)[:8])
    if (do_ocr):
        try:
            hist += ("\tAverage Improvement over bicubic: " +
                     str(tot_res_acc / tot_bic_acc)[:8])
        except ZeroDivisionError:
            # Bicubic accuracy was zero; leave the ratio out of the summary.
            pass
    hist += '\n'

    with open("inter_eval.txt", 'a') as interlog:
        interlog.write(hist)
    print("\nAll images done\n" + hist)
    return
def _format_average_metrics(tot_psnr, tot_mse, tot_ssim, count):
    """Build the tab-separated average PSNR/MSE/SSIM text (no trailing newline)."""
    return ("Average PSNR: " + str(tot_psnr / count)[:8] +
            "\tAverage MSE: " + str(tot_mse / count)[:8] +
            "\tAverage SSIM: " + str(tot_ssim / count)[:8])


def _run_evaluation(mode, history, latest, ocrlog):
    """Run the evaluation loop and write results to the already-open log files.

    Args:
        mode: 'multi' to process every validation image, 'single' to stop
            after the first one.
        history: writable text file receiving the one-line average summary.
        latest: writable text file receiving one line per processed image.
        ocrlog: writable text file receiving the OCR summary when the
            module-level ``do_ocr`` flag is truthy.
    """
    tot_psnr = tot_mse = tot_ssim = 0
    tot_res_acc = tot_hr_acc = tot_lr_acc = tot_bic_acc = 0
    res_beats_hr = res_beats_bic = res_beats_lr = res_fails = 0

    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"
    if do_ocr:
        print("Evaluating with OCR")

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list,
                                       path=config.VALID.hr_img_path,
                                       n_threads=32)

    num_lr_imgs = len(valid_lr_imgs)
    num_hr_imgs = len(valid_hr_imgs)
    print("loaded", num_lr_imgs, "LR images")
    if mode == 'multi' and num_lr_imgs != num_hr_imgs:
        print('Unequal images in LR and HR')
        return
    if mode == 'single' and (num_lr_imgs == 0 or num_hr_imgs == 0):
        print('No images found')
        return
    if num_lr_imgs == 0:
        # Nothing to evaluate: the averages below would divide by zero.
        print('No images found')
        return

    ###========================== RESTORE G =============================###
    t_image = tf.placeholder('float32', [1, None, None, 3],
                             name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir + '/g_srgan.npz',
                                     network=net_g)
        print("Loaded model\nProcessing images...")

        ###======================= EVALUATION =============================###
        for imid in range(num_lr_imgs):
            valid_lr_img = valid_lr_imgs[imid]
            valid_hr_img = valid_hr_imgs[imid]
            img_name = valid_lr_img_list[imid]

            size = valid_lr_img.shape
            if len(size) == 2:
                # Grayscale input: replicate the single plane to 3 channels.
                valid_lr_img = np.stack((valid_lr_img, ) * 3, axis=-1)
                valid_hr_img = np.stack((valid_hr_img, ) * 3, axis=-1)
                size = valid_lr_img.shape

            # rescale to [-1, 1]
            valid_lr_img_res = (valid_lr_img / 127.5) - 1
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img_res]})
            out_uint8 = convert(out[0])

            prefix = save_dir + '/' + img_name[:-4]
            tl.vis.save_image(out_uint8, prefix + '_gen.png')
            tl.vis.save_image(valid_lr_img, prefix + '_lr.png')
            tl.vis.save_image(valid_hr_img, prefix + '_hr.png')
            out_bicu = scipy.misc.imresize(valid_lr_img,
                                           [size[0] * 4, size[1] * 4],
                                           interp='bicubic',
                                           mode=None)
            tl.vis.save_image(out_bicu, prefix + '_bicubic.png')

            img_mse, img_psnr, img_ssim = quant(out_uint8, valid_hr_img)
            tot_psnr += img_psnr
            tot_mse += img_mse
            tot_ssim += img_ssim

            if do_ocr:
                res_acc, hr_acc, lr_acc, bic_acc = ocr.getAccuracy(
                    out_uint8, valid_hr_img, valid_lr_img, out_bicu, imid)
                tot_res_acc += res_acc
                tot_hr_acc += hr_acc
                tot_lr_acc += lr_acc
                tot_bic_acc += bic_acc
                # Only the first matching comparison is counted per image.
                if res_acc > hr_acc:
                    res_beats_hr += 1
                elif res_acc > bic_acc:
                    res_beats_bic += 1
                elif res_acc > lr_acc:
                    res_beats_lr += 1
                else:
                    res_fails += 1

            eval_log = ("Image: " + str(imid + 1) +
                        "\tPSNR: " + str(img_psnr)[:8] +
                        "\tMSE: " + str(img_mse)[:8] +
                        "\tSSIM: " + str(img_ssim)[:8] + '\n')
            latest.write(eval_log)

            if mode == 'single':
                print("\n1 image done\n" + eval_log)
                return

            # In-place textual progress bar (reverse-video block, 50 wide).
            incre = int(50.0 / num_lr_imgs * imid)
            sys.stdout.write('\r' + '|%s%s| %d/%d images done' %
                             ('\033[7m' + ' ' * incre + ' \033[27m',
                              ' ' * (49 - incre), imid + 1, num_lr_imgs))
            sys.stdout.flush()
    finally:
        # Release the session on every exit path, including 'single' mode.
        sess.close()

    averages = _format_average_metrics(tot_psnr, tot_mse, tot_ssim,
                                       num_lr_imgs)
    hist = averages + '\n'
    if do_ocr:
        ocrres = ("Average GEN accuracy: " + str(tot_res_acc / num_lr_imgs)[:8] +
                  "\nAverage HRI accuracy: " + str(tot_hr_acc / num_lr_imgs)[:8] +
                  "\nAverage LRI accuracy: " + str(tot_lr_acc / num_lr_imgs)[:8] +
                  "\nAverage BIC accuracy: " + str(tot_bic_acc / num_lr_imgs)[:8] +
                  "\n\nRES>HRI: " + str(res_beats_hr) +
                  "\nRES>BIC: " + str(res_beats_bic) +
                  "\nRES>LRI: " + str(res_beats_lr) +
                  "\nRESFAIL: " + str(res_fails) +
                  '\n' + '=' * 50 + '\n')
        ocrlog.write(ocrres)
        try:
            hist = (averages + "\tAverage Improvement over bicubic: " +
                    str(tot_res_acc / tot_bic_acc)[:8] + '\n')
        except ArithmeticError:
            # Bicubic accuracy summed to zero: keep the summary without the
            # improvement ratio instead of failing.
            pass

    history.write(hist)
    print("\nAll images done\n" + hist)


def evaluate(mode):
    """Evaluate the SRGAN generator on the validation images.

    Restores the generator from 'checkpoint/g_srgan.npz', runs it on each
    low-resolution validation image, saves the generated, LR, HR and bicubic
    images under 'samples/<mode flag>', and accumulates the MSE/PSNR/SSIM
    values returned by ``quant``. When the module-level ``do_ocr`` flag is
    truthy, OCR accuracies from ``ocr.getAccuracy`` are accumulated too.

    Per-image results go to 'latest_eval.txt' (truncated on every call), the
    average summary is appended to 'eval_history.txt', and the OCR summary is
    appended to 'ocrlog.txt'.

    Args:
        mode: 'multi' evaluates every image and requires equal LR/HR counts;
            'single' stops after the first image and writes no summary.

    Returns:
        None.
    """
    print("Evaluating...")
    # Opened up front so 'latest_eval.txt' is still truncated on early exits;
    # the with-block closes all three files on every path.
    with open("eval_history.txt", "a") as history, \
            open("latest_eval.txt", "w") as latest, \
            open("ocrlog.txt", "a") as ocrlog:
        _run_evaluation(mode, history, latest, ocrlog)
    return
# evaluate(ID, save_path, lr_path): run the SRGAN generator on one
# low-resolution PNG and save the result.
#
# Steps visible in the body:
#   * `save_path` is passed to tl.files.exists_or_mkdir and used as the
#     output directory.
#   * The '.*.png' files in `lr_path` are listed and sorted; every one is read
#     with scipy.misc.imread(mode='L') and reshaped to (H, W, 1).
#   * The image at index `ID` is rescaled with x / 127.5 - 1 and fed through
#     SRGAN_g (is_train=False) built on a [1, H, W, 1] placeholder.
#   * Weights are restored with tf.train.Saver from the hard-coded path
#     'checkpoint/main_50.ckpt'.
#   * out[0] is saved as save_path + '/' + <input file name>.
#
# Parameters:
#   ID        -- integer index into the sorted list of LR file names.
#   save_path -- output directory.
#   lr_path   -- directory containing the low-resolution PNG files.
# Returns: nothing (there is no return statement).
#
# NOTE(review): every image in `lr_path` is loaded although only index `ID`
#   is used; `checkpoint_dir` and `valid_hr_imgs` are assigned but not used;
#   the tf.Session is never closed.
# NOTE(review): the block ends with a bare ''' whose matching delimiter is
#   not visible here -- confirm what it opens or closes before reformatting.
def evaluate(ID, save_path, lr_path): ## create folders to save result images #save_dir = "samples/{}".format(tl.global_flag['mode']) save_dir = save_path tl.files.exists_or_mkdir(save_dir) checkpoint_dir = "checkpoint" ###====================== PRE-LOAD DATA ===========================### # train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False)) # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False)) #valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False)) valid_lr_img_list = sorted( tl.files.load_file_list(path=lr_path, regx='.*.png', printable=False)) ## If your machine have enough memory, please pre-load the whole train set. # train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32) # for im in train_hr_imgs: # print(im.shape) #valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32) # for im in valid_lr_imgs: # print(im.shape) #valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32) # for im in valid_hr_imgs: # print(im.shape) # exit() valid_lr_imgs = [] #sess = tf.Session() for img__ in valid_lr_img_list: image_loaded = scipy.misc.imread(os.path.join(lr_path + '/', img__), mode='L') image_loaded = image_loaded.reshape( (image_loaded.shape[0], image_loaded.shape[1], 1)) #sh=image_loaded.shape #image_loaded = imresize(image_loaded, [int(sh[0]/2), int(sh[1]/2)], interp='bicubic', mode=None) valid_lr_imgs.append(image_loaded) print(type(valid_lr_imgs), len(valid_lr_img_list)) valid_hr_imgs = [] #sess = tf.Session() #for img__ in valid_hr_img_list: # location='../SRGAN8x/DATA/valid_HR_256/'+img__ # image_loaded = scipy.misc.imread(os.path.join(config.VALID.hr_img_path,img__), mode='L') # image_loaded = image_loaded.reshape((image_loaded.shape[0], 
image_loaded.shape[1], 1)) # lr_img = imresize(image_loaded, [32,256], interp='bicubic', mode=None) # valid_hr_imgs.append(image_loaded) # valid_lr_imgs.append(lr_img) # print(type(valid_hr_imgs), len(valid_hr_img_list)) ###========================== DEFINE MODEL ============================### imid = ID # 0: 企鹅 81: 蝴蝶 53: 鸟 64: 古堡 valid_lr_img = valid_lr_imgs[imid] #valid_hr_img = valid_hr_imgs[imid] # valid_lr_img = get_imgs_fn('test.png', 'data2017/') # if you want to test your own image valid_lr_img = (valid_lr_img / 127.5) - 1 # rescale to [-1, 1] # print(valid_lr_img.min(), valid_lr_img.max()) size = valid_lr_img.shape #print(size) # t_image = tf.placeholder('float32', [None, size[0], size[1], size[2]], name='input_image') # the old version of TL need to specify the image size #t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image') t_image = tf.placeholder('float32', [1, size[0], size[1], 1], name='input_image') # 1 for 1 channel net_g = SRGAN_g(t_image, is_train=False, reuse=False) ###========================== RESTORE G =============================### sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) tl.layers.initialize_global_variables(sess) #tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan_48.npz', network=net_g) saver = tf.train.Saver() saver.restore(sess, 'checkpoint/main_50.ckpt') ###======================= EVALUATION =============================### start_time = time.time() out = sess.run(net_g.outputs, {t_image: [valid_lr_img]}) print("took: %4.4fs" % (time.time() - start_time)) print("LR size: %s / generated HR size: %s" % (size, out.shape) ) # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3) print("[*] save images") #for i in range(len(out)): tl.vis.save_image( out[0], save_dir + '/' + valid_lr_img_list[imid]) #'/valid_gen_'+str(imid)+'.png') # tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr_'+str(imid)+'.png') #tl.vis.save_image(valid_hr_img, 
save_dir + '/valid_hr_'+str(imid)+'.png') '''
# train(): two-stage training of a single-channel SRGAN-style model.
#
# What the body does:
#   * Builds train/validation file lists with get_synthia_imgs_list.
#   * Defines the graph on a (None, None, None, 1) float placeholder `t_input`:
#     preprocess() yields (t_image, t_target_image, t_interpolated), SRGAN_g
#     produces `net_g_outputs`, SRGAN_d scores real and generated images, and
#     two VGG16 instances supply conv3_1 features of the min-max normalised,
#     3x channel-replicated target and generated images.
#   * Losses:
#       d_loss      = d_loss1 + d_loss2
#       g_init_loss = mse_loss + vgg_loss
#       g_loss      = g_gan_loss + mse_loss + vgg_loss
#   * Stage 1 ("initialize G") runs epochs 0..n_epoch_init with g_optim_init
#     and saves 'model<epoch>.ckpt' under checkpoint/g_init.
#   * Stage 2 ("train GAN") runs epochs 0..n_epoch, updating D then G per
#     batch, and saves 'model<n_epoch_init + epoch>.ckpt' under checkpoint/gan.
#     On epochs that are non-zero multiples of decay_every the learning rate
#     is set to lr_init * lr_decay ** (epoch // decay_every).
#   * Each epoch first makes a validation pass (d_flg False), then a training
#     pass over SynthiaIterator batches augmented by augment_imgs.
#     Scalar and image summaries are written to logs/train.
#
# Takes no arguments and has no return statement.
# Reads names not defined in this function: batch_size, lr_init, beta1,
# n_epoch, lr_decay, decay_every, do_init_g, vgg16_npy_path, config.
#
# NOTE(review): the source formatting is collapsed, so statement boundaries
#   are ambiguous in two places -- confirm against the original file:
#     - "# try with log? t_input = tf.log(t_input)": unclear whether the
#       tf.log line is live code or part of the comment.
#     - the "else:" after "raise e": unclear whether it pairs with
#       "if not do_init_g" or with the preceding try.
# NOTE(review): in `tv_loss` the 2e-6 factor multiplies only the first
#   reduce_mean term (operator precedence); tv_loss is used only in a summary.
# NOTE(review): `break` directly after `raise e` can never execute.
# NOTE(review): a bare `except:` swallows restore failures for g_init.
# NOTE(review): net_g_test, saver_d, test_writer and test_iter appear unused
#   in this function; the tf.Session is never closed.
def train(): n_epoch_init = 12 ## create folders to save result images and trained model # save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode']) # save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode']) # tl.files.exists_or_mkdir(save_dir_ginit) # tl.files.exists_or_mkdir(save_dir_gan) checkpoint_dir = "checkpoint" # checkpoint_resize_conv log_dir = "logs" # checkpoint_resize_conv tl.files.exists_or_mkdir(checkpoint_dir) ###====================== PRE-LOAD DATA ===========================### train_hr_img_list = sorted( get_synthia_imgs_list(config.VALID.hr_img_path, is_train=True, synthia_dataset=config.TRAIN.hr_img_path)) valid_hr_img_list = sorted( get_synthia_imgs_list(config.VALID.hr_img_path, is_train=False, synthia_dataset=config.TRAIN.hr_img_path)) print(len(train_hr_img_list)) print(len(valid_hr_img_list)) ###========================== DEFINE MODEL ============================### ## train inference t_input = tf.placeholder(tf.float32, shape=(None, None, None, 1), name='t_input') # try with log? t_input = tf.log(t_input) d_flg = tf.placeholder(tf.bool, name='is_train') t_image, t_target_image, t_interpolated = preprocess(t_input) net_g_outputs = SRGAN_g(t_image, t_interpolated, is_train=d_flg, reuse=False) net_d, logits_real = SRGAN_d(t_target_image, is_train=d_flg, reuse=False) _, logits_fake = SRGAN_d(net_g_outputs, is_train=d_flg, reuse=True) vgg_model_true = VGG16(vgg16_npy_path) vgg_model_gen = VGG16(vgg16_npy_path) ## vgg inference. 
0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA # to 3 channels y_true_normalized = (t_target_image - tf.reduce_min(t_target_image)) / ( tf.reduce_max(t_target_image) - tf.reduce_min(t_target_image)) gen_normalized = (net_g_outputs - tf.reduce_min(net_g_outputs)) / ( tf.reduce_max(net_g_outputs) - tf.reduce_min(net_g_outputs)) t_target_image_3ch = tf.concat([y_true_normalized] * 3, 3) t_predict_image_3ch = tf.concat([gen_normalized] * 3, 3) vgg_model_true.build(t_target_image_3ch) true_features = vgg_model_true.conv3_1 vgg_model_gen.build(t_predict_image_3ch) gen_features = vgg_model_gen.conv3_1 ## test inference net_g_test = SRGAN_g(t_image, t_interpolated, is_train=d_flg, reuse=True) # ###========================== DEFINE TRAIN OPS ==========================### d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1') d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2') # d_vgg_loss = 2e-6*tl.cost.mean_squared_error(true_features, gen_features, is_mean=True) d_loss = d_loss1 + d_loss2 g_gan_loss = 1e-2 * tl.cost.sigmoid_cross_entropy( logits_fake, tf.ones_like(logits_fake), name='g') # 1e-3 * mse_loss = tl.cost.mean_squared_error(net_g_outputs, t_target_image, is_mean=True) vgg_loss = 2e-6 * tl.cost.mean_squared_error( true_features, gen_features, is_mean=True) # 2e-6 * tv_loss = 2e-6 * tf.reduce_mean(tf.square(net_g_outputs[:, :-1, :, :] - net_g_outputs[:, 1:, :, :])) + \ tf.reduce_mean(tf.square(net_g_outputs[:, :, :-1, :] - net_g_outputs[:, :, 1:, :])) # 2e-6* g_init_loss = mse_loss + vgg_loss # mse_loss # + vgg_loss + tv_loss g_loss = g_gan_loss + mse_loss + vgg_loss # + mse_loss g_vars = tl.layers.get_variables_with_name('G_Depth_SR', True, True) d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True) glob_step_t = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step') with tf.variable_scope('learning_rate'): lr_v = tf.Variable(lr_init, trainable=False) ## Pretrain 
g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize( g_init_loss, var_list=g_vars, global_step=glob_step_t) ## SRGAN g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize( g_loss, var_list=g_vars, global_step=glob_step_t) d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars) ###========================== RESTORE MODEL =============================### saver = tf.train.Saver(max_to_keep=5) saver_d = tf.train.Saver(d_vars, max_to_keep=5) saver_g = tf.train.Saver(g_vars, max_to_keep=5) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) tl.layers.initialize_global_variables(sess) with tf.variable_scope('summaries'): tf.summary.scalar('d_loss', d_loss) tf.summary.scalar('g_loss', g_loss) tf.summary.scalar('mse_loss', mse_loss) tf.summary.scalar('vgg_loss', vgg_loss) tf.summary.scalar('tv_loss', tv_loss) tf.summary.scalar('g_gan_loss', g_gan_loss) mae = tf.reduce_mean( tf.abs(net_g_outputs - t_target_image) / (t_target_image + tf.constant(1e-8))) rmse = tf.sqrt( tf.reduce_mean(tf.square(net_g_outputs - t_target_image))) tf.summary.scalar('MAE', mae) tf.summary.scalar('RMSE', rmse) tf.summary.scalar('learning_rate', lr_v) # tf.summary.image('input', t_input , max_outputs=1) tf.summary.image('GT', t_target_image, max_outputs=1) tf.summary.image('input_small_size', t_image, max_outputs=1) tf.summary.image('interpolated', t_interpolated, max_outputs=1) tf.summary.image('result', net_g_outputs, max_outputs=1) summary_op = tf.summary.merge_all() train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph) test_writer = tf.summary.FileWriter(log_dir + '/test') ###============================= TRAINING ===============================### ## use first `batch_size` of train set to have a quick test during training # sample_imgs = train_hr_imgs[0:batch_size] # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no 
pre-load train set # sample_imgs = tl.prepro.threading_data(train_hr_img_list[0:batch_size], fn=get_imgs_fn) # if no pre-load train set # print('sample images:', sample_imgs.shape, sample_imgs.min(), sample_imgs.max()) n_batches = int(len(train_hr_img_list) / batch_size) n_batches_valid = int(len(valid_hr_img_list) / batch_size) ###========================= initialize G ====================### if not do_init_g: n_epoch_init = -1 try: saver_g.restore( sess, tf.train.latest_checkpoint(checkpoint_dir + '/g_init')) except Exception as e: print( ' ** You need to initialize generator: put do_init_g to True or provide a valid restore path' ) raise e else: try: #saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir+'/gan')) # 2 round saver.restore( sess, tf.train.latest_checkpoint(checkpoint_dir + '/g_init')) except: print(' ** Creating new g_init model') pass ## fixed learning rate sess.run(tf.assign(lr_v, lr_init)) print(" ** fixed learning rate: %f (for init G)" % lr_init) train_iter, test_iter = 0, 0 for epoch in range(0, n_epoch_init + 1): try: epoch_time = time.time() val_mae, val_mse, val_g_loss = 0, 0, 0 batch_it = tqdm(SynthiaIterator(valid_hr_img_list, batchsize=batch_size, shuffle=True, buffer_size=70), total=n_batches_valid, leave=False) for b in batch_it: xb = b[0] errM, errG, mae_score = sess.run([mse_loss, g_loss, mae], feed_dict={ t_input: xb, d_flg: False }) val_mae += mae_score val_mse += errM val_g_loss += errG print("Validation: Epoch {0} val mae {1} val mse {2}".format( epoch - 1, val_mae / n_batches_valid, val_mse / n_batches_valid)) total_mse_loss, total_g_loss = 0, 0 batch_it = tqdm(SynthiaIterator(train_hr_img_list, batchsize=batch_size, shuffle=True, buffer_size=70), total=n_batches, leave=False) for b in batch_it: xb = b[0] xb = augment_imgs(xb) glob_step, errM, errG, _ = sess.run( [glob_step_t, mse_loss, g_loss, g_optim_init], feed_dict={ t_input: xb, d_flg: True }) total_mse_loss += errM total_g_loss += errG if (train_iter + 1) % 200 == 
0: summary = sess.run(summary_op, feed_dict={ t_input: xb, d_flg: False }) train_writer.add_summary(summary, train_iter + 1) train_iter += 1 log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % ( epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_batches) val_mse_summary = tf.Summary.Value(tag='g_init/val_mse_loss', simple_value=val_mse / n_batches_valid) val_g_loss_summary = tf.Summary.Value(tag='g_init/val_loss', simple_value=val_g_loss / n_batches_valid) train_mse_loss_summary = tf.Summary.Value( tag='g_init/train_mse_loss', simple_value=total_mse_loss / n_batches) train_g_loss_summary = tf.Summary.Value(tag='g_init/train_loss', simple_value=total_g_loss / n_batches) epoch_summary = tf.Summary(value=[ val_mse_summary, val_g_loss_summary, train_mse_loss_summary, train_g_loss_summary ]) train_writer.add_summary(epoch_summary, glob_step) print(log) saver.save( sess, os.path.join(checkpoint_dir + '/g_init', 'model' + str(epoch) + '.ckpt')) except Exception as e: batch_it.iterable.stop() raise e ###========================= train GAN (SRGAN) =========================### try: # saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir+'/g_init')) # saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir+'/gan')) pass except: print(' ** Creating new GAN model') pass train_iter, test_iter = 0, 0 for epoch in range(0, n_epoch + 1): ## update learning rate if epoch != 0 and (epoch % decay_every == 0): new_lr_decay = lr_decay**(epoch // decay_every) sess.run(tf.assign(lr_v, lr_init * new_lr_decay)) log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay) print(log) elif epoch == 0: sess.run(tf.assign(lr_v, lr_init)) log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % ( lr_init, decay_every, lr_decay) print(log) try: epoch_time = time.time() val_mae, val_mse, val_g_loss, val_d_loss = 0, 0, 0, 0 batch_it = tqdm(SynthiaIterator(valid_hr_img_list, batchsize=batch_size, shuffle=True, buffer_size=70), 
total=n_batches_valid, leave=False) for b in batch_it: xb = b[0] errM, mae_score, errG, errD = sess.run( [mse_loss, mae, g_loss, d_loss], feed_dict={ t_input: xb, d_flg: False }) val_mae += mae_score val_mse += errM val_g_loss += errG val_d_loss += errD print("Validation (GAN): Epoch {0} val mae {1} val mse {2}".format( epoch - 1, val_mae / n_batches_valid, val_mse / n_batches_valid)) total_d_loss, total_g_loss, total_mse_loss = 0, 0, 0 batch_it = tqdm(SynthiaIterator(train_hr_img_list, batchsize=batch_size, shuffle=True, buffer_size=70), total=n_batches, leave=False) for b in batch_it: xb = b[0] xb = augment_imgs(xb) ## update D errD, _ = sess.run([d_loss, d_optim], { t_input: xb, d_flg: True }) ## update G glob_step, errG, errM, _, summary = sess.run( [glob_step_t, g_loss, mse_loss, g_optim, summary_op], { t_input: xb, d_flg: True }) total_mse_loss += errM total_d_loss += errD total_g_loss += errG if (train_iter + 1) % 10 == 0: train_writer.add_summary(summary, train_iter + 1) train_iter += 1 except Exception as e: batch_it.iterable.stop() raise e break log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f mse_loss: %.8f" % ( epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_batches, total_g_loss / n_batches, total_mse_loss / n_batches) val_mse_summary = tf.Summary.Value(tag='gan/val_mse_loss', simple_value=val_mse / n_batches_valid) val_g_loss_summary = tf.Summary.Value(tag='gan/val_g_loss', simple_value=val_g_loss / n_batches_valid) val_d_loss_summary = tf.Summary.Value(tag='gan/val_d_loss', simple_value=val_d_loss / n_batches_valid) train_mse_loss_summary = tf.Summary.Value(tag='gan/train_mse_loss', simple_value=total_mse_loss / n_batches) train_g_loss_summary = tf.Summary.Value(tag='gan/train_g_loss', simple_value=total_g_loss / n_batches) train_d_loss_summary = tf.Summary.Value(tag='gan/train_d_loss', simple_value=total_d_loss / n_batches) epoch_summary = tf.Summary(value=[ val_mse_summary, val_g_loss_summary, val_d_loss_summary, 
train_mse_loss_summary, train_g_loss_summary, train_d_loss_summary ]) train_writer.add_summary(epoch_summary, glob_step) print(log) saver.save( sess, os.path.join(checkpoint_dir + '/gan', 'model' + str(n_epoch_init + epoch) + '.ckpt'))
def _full_batch_starts(n_items, size):
    """Return the start offsets of every complete batch of ``size`` items.

    A trailing remainder shorter than ``size`` is skipped because the training
    placeholders are declared with a fixed batch dimension.
    """
    return range(0, n_items - size + 1, size)


def train():
    """Train SRGAN: MSE pre-training of the generator, then adversarial training.

    Pre-loads the HR training images, builds generator / discriminator / VGG19
    graphs on fixed-size placeholders (96x96 LR, 384x384 HR, ``batch_size``
    images), restores any existing ``.npz`` checkpoints and the VGG19 weights
    from 'vgg19.npy', then runs:

    1. ``n_epoch_init + 1`` epochs minimising ``mse_loss`` on the generator.
    2. ``n_epoch + 1`` epochs alternating discriminator and generator updates,
       with the learning rate multiplied by ``lr_decay`` every ``decay_every``
       epochs.

    Every 10th epoch (except epoch 0) a sample grid is written under
    'samples/<mode>_ginit' or 'samples/<mode>_gan' and the networks are saved
    under 'checkpoint'.

    Relies on module-level names: ``batch_size``, ``lr_init``, ``beta1``,
    ``n_epoch_init``, ``n_epoch``, ``lr_decay``, ``decay_every``, ``ni``,
    ``config``, ``crop_sub_imgs_fn`` and ``downsample_fn``.

    Returns:
        None. Calls ``exit()`` if 'vgg19.npy' is missing.
    """
    mode_name = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode_name)
    save_dir_gan = "samples/{}_gan".format(mode_name)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3],
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    # resize target and generated images to the 224x224 VGG input size
    t_target_image_224 = tf.image.resize_images(t_target_image,
                                                size=[224, 224],
                                                method=0,
                                                align_corners=False)
    t_predict_image_224 = tf.image.resize_images(net_g.outputs,
                                                 size=[224, 224],
                                                 method=0,
                                                 align_corners=False)

    # images are in [-1, 1]; (x + 1) / 2 maps them to [0, 1] for VGG
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)
        g_ckpt = checkpoint_dir + '/g_{}.npz'.format(mode_name)
        g_init_ckpt = checkpoint_dir + '/g_{}_init.npz'.format(mode_name)
        d_ckpt = checkpoint_dir + '/d_{}.npz'.format(mode_name)
        # Fall back to the pre-trained generator when no GAN checkpoint exists.
        if tl.files.load_and_assign_npz(sess=sess, name=g_ckpt,
                                        network=net_g) is False:
            tl.files.load_and_assign_npz(sess=sess, name=g_init_ckpt,
                                         network=net_g)
        tl.files.load_and_assign_npz(sess=sess, name=d_ckpt, network=net_d)

        ###============================= LOAD VGG ===============================###
        vgg19_npy_path = "vgg19.npy"
        if not os.path.isfile(vgg19_npy_path):
            print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
            exit()
        # The file stores a pickled dict, which recent NumPy refuses to load
        # unless allow_pickle is set. NOTE: unpickling executes code from the
        # file, so only use a trusted copy of vgg19.npy.
        npz = np.load(vgg19_npy_path, encoding='latin1',
                      allow_pickle=True).item()

        params = []
        for val in sorted(npz.items()):
            W = np.asarray(val[1][0])
            b = np.asarray(val[1][1])
            print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
            params.extend([W, b])
        tl.files.assign_params(sess, params, net_vgg)

        ###============================= TRAINING ===============================###
        ## use first `batch_size` of train set to have a quick test during training
        sample_imgs = train_hr_imgs[0:batch_size]
        sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                                   fn=crop_sub_imgs_fn,
                                                   is_random=False)
        print('sample HR sub-image:', sample_imgs_384.shape,
              sample_imgs_384.min(), sample_imgs_384.max())
        sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                                  fn=downsample_fn)
        print('sample LR sub-image:', sample_imgs_96.shape,
              sample_imgs_96.min(), sample_imgs_96.max())
        tl.vis.save_images(sample_imgs_96, [ni, ni],
                           save_dir_ginit + '/_train_sample_96.png')
        tl.vis.save_images(sample_imgs_384, [ni, ni],
                           save_dir_ginit + '/_train_sample_384.png')
        tl.vis.save_images(sample_imgs_96, [ni, ni],
                           save_dir_gan + '/_train_sample_96.png')
        tl.vis.save_images(sample_imgs_384, [ni, ni],
                           save_dir_gan + '/_train_sample_384.png')

        # Only complete batches can be fed: the placeholders have a fixed
        # leading dimension of `batch_size`.
        batch_starts = _full_batch_starts(len(train_hr_imgs), batch_size)

        ###========================= initialize G ====================###
        ## fixed learning rate
        sess.run(tf.assign(lr_v, lr_init))
        print(" ** fixed learning rate: %f (for init G)" % lr_init)
        for epoch in range(0, n_epoch_init + 1):
            epoch_time = time.time()
            total_mse_loss, n_iter = 0, 0

            for idx in batch_starts:
                step_time = time.time()
                b_imgs_384 = tl.prepro.threading_data(
                    train_hr_imgs[idx:idx + batch_size],
                    fn=crop_sub_imgs_fn,
                    is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384,
                                                     fn=downsample_fn)
                ## update G
                errM, _ = sess.run([mse_loss, g_optim_init], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
                print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                      (epoch, n_epoch_init, n_iter, time.time() - step_time,
                       errM))
                total_mse_loss += errM
                n_iter += 1
            # max(..., 1) avoids dividing by zero when no full batch exists.
            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
                epoch, n_epoch_init, time.time() - epoch_time,
                total_mse_loss / max(n_iter, 1))
            print(log)

            if (epoch != 0) and (epoch % 10 == 0):
                ## quick evaluation on train set
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
                print("[*] save images")
                tl.vis.save_images(out, [ni, ni],
                                   save_dir_ginit + '/train_%d.png' % epoch)
                ## save model
                tl.files.save_npz(net_g.all_params, name=g_init_ckpt,
                                  sess=sess)

        ###========================= train GAN (SRGAN) =========================###
        for epoch in range(0, n_epoch + 1):
            ## update learning rate
            if epoch != 0 and (epoch % decay_every == 0):
                new_lr_decay = lr_decay**(epoch // decay_every)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                               new_lr_decay)
                print(log)
            elif epoch == 0:
                sess.run(tf.assign(lr_v, lr_init))
                log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (
                    lr_init, decay_every, lr_decay)
                print(log)

            epoch_time = time.time()
            total_d_loss, total_g_loss, n_iter = 0, 0, 0

            for idx in batch_starts:
                step_time = time.time()
                b_imgs_384 = tl.prepro.threading_data(
                    train_hr_imgs[idx:idx + batch_size],
                    fn=crop_sub_imgs_fn,
                    is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384,
                                                     fn=downsample_fn)
                feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}
                ## update D
                errD, _ = sess.run([d_loss, d_optim], feed)
                ## update G
                errG, errM, errV, errA, _ = sess.run(
                    [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)
                print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                      (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                       errG, errM, errV, errA))
                total_d_loss += errD
                total_g_loss += errG
                n_iter += 1

            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
                epoch, n_epoch, time.time() - epoch_time,
                total_d_loss / max(n_iter, 1), total_g_loss / max(n_iter, 1))
            print(log)

            if (epoch != 0) and (epoch % 10 == 0):
                ## quick evaluation on train set
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
                print("[*] save images")
                tl.vis.save_images(out, [ni, ni],
                                   save_dir_gan + '/train_%d.png' % epoch)
                ## save model
                tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
                tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)
    finally:
        # Release the session's resources however training ends.
        sess.close()
def train(train_lr_imgs, train_hr_imgs):
    """Pre-train the SRGAN generator on MSE, then train the full GAN.

    The graph is built on fixed-size ``(batch_size, 512, 512, 1)``
    placeholders.  Existing ``.npz`` checkpoints in ``models_checkpoints`` are
    restored when present, VGG19 weights are read from ``vgg19.npy``, the
    generator is pre-trained for ``n_epoch_init`` epochs with the MSE loss
    only, and then the discriminator and generator are updated alternately
    for ``n_epoch`` epochs.  Recorded losses are plotted and written to
    ``.npy`` files at the end.

    Hyper-parameters (``batch_size``, ``lr_init``, ``beta1``,
    ``n_epoch_init``, ``n_epoch``, ``lr_decay``, ``decay_every``) are read
    from module scope.

    Args:
        train_lr_imgs: Indexable collection of generator inputs; each group
            of ``batch_size`` items must reshape to
            ``(batch_size, 512, 512, 1)``.
        train_hr_imgs: Target images, paired by index with
            ``train_lr_imgs``, with the same shape requirement.

    Note:
        Only complete batches are used: a trailing group smaller than
        ``batch_size`` cannot be reshaped to the placeholder shape, so it is
        skipped instead of aborting the run.
    """
    ## create folder to save trained model
    checkpoint_dir = "models_checkpoints"
    tl.files.exists_or_mkdir(checkpoint_dir)
    mode = tl.global_flag['mode']
    g_init_ckpt = checkpoint_dir + '/g_{}_init.npz'.format(mode)
    g_ckpt = checkpoint_dir + '/g_{}.npz'.format(mode)
    d_ckpt = checkpoint_dir + '/d_{}.npz'.format(mode)

    batch_shape = (batch_size, 512, 512, 1)
    # Start offsets of the complete LR/HR batches only.
    n_samples = min(len(train_hr_imgs), len(train_lr_imgs))
    batch_starts = range(0, n_samples - batch_size + 1, batch_size)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder(dtype='float32',
                             shape=batch_shape,
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(dtype='float32',
                                    shape=batch_shape,
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    ## vgg inference. method: 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)

    # Inputs are shifted by (x + 1) / 2 before being handed to VGG.
    net_vgg, vgg_target_emb = Vgg19_simple_api(
        input=(t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api(
        input=(t_predict_image_224 + 1) / 2, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(
        logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(
        net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Prefer the GAN generator checkpoint; fall back to the MSE-pretrained one.
    if tl.files.load_and_assign_npz(
            sess=sess, name=g_ckpt, network=net_g) is False:
        tl.files.load_and_assign_npz(
            sess=sess, name=g_init_ckpt, network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=d_ckpt, network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    # The file stores a pickled dict, so NumPy >= 1.16.3 needs
    # allow_pickle=True.  NOTE(review): unpickling executes arbitrary code;
    # only load a weights file obtained from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1',
                  allow_pickle=True).item()

    params = []
    for layer_name, layer_values in sorted(npz.items()):
        W = np.asarray(layer_values[0])
        b = np.asarray(layer_values[1])
        if layer_name == 'conv1_1':
            # Average the kernel over its input-channel axis so the first
            # conv layer accepts the single-channel images used here.
            W = np.mean(W, axis=2)
            W = W.reshape((3, 3, 1, 64))
        print(" Loading %s: %s, %s" % (layer_name, W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)

    start_time = time.time()
    for epoch in range(0, n_epoch_init):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        step_time = None

        for idx in batch_starts:
            # Progress is logged and checkpointed only on offsets that are
            # multiples of 1000.
            report_step = idx % 1000 == 0
            if report_step:
                step_time = time.time()

            b_imgs_hr = np.asarray(
                train_hr_imgs[idx:idx + batch_size]).reshape(batch_shape)
            b_imgs_lr = np.asarray(
                train_lr_imgs[idx:idx + batch_size]).reshape(batch_shape)

            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_lr,
                t_target_image: b_imgs_hr
            })

            if report_step:
                print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                      (epoch, n_epoch_init, n_iter, time.time() - step_time,
                       errM))
                tl.files.save_npz(net_g.all_params, name=g_init_ckpt,
                                  sess=sess)

            total_mse_loss += errM
            n_iter += 1

        # max(..., 1) keeps the report from dividing by zero when the data
        # set holds fewer than batch_size images.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / max(n_iter, 1))
        print(log)

        ## save model
        tl.files.save_npz(net_g.all_params, name=g_init_ckpt, sess=sess)

    print("G init took: %4.4fs" % (time.time() - start_time))

    ###========================= train GAN (SRGAN) =========================###
    start_time = time.time()
    epoch_losses = defaultdict(list)
    iter_losses = defaultdict(list)

    for epoch in range(0, n_epoch):
        ## update learning rate
        if epoch != 0 and decay_every != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                          new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        step_time = None

        for idx in batch_starts:
            report_step = idx % 1000 == 0
            if report_step:
                step_time = time.time()

            b_imgs_hr = np.asarray(
                train_hr_imgs[idx:idx + batch_size]).reshape(batch_shape)
            b_imgs_lr = np.asarray(
                train_lr_imgs[idx:idx + batch_size]).reshape(batch_shape)
            feed = {t_image: b_imgs_lr, t_target_image: b_imgs_hr}

            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)

            if report_step:
                print(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                    % (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                       errG, errM, errV, errA))
                tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
                tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)

            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

            iter_losses['d_loss'].append(errD)
            iter_losses['g_loss'].append(errG)
            iter_losses['mse_loss'].append(errM)
            iter_losses['vgg_loss'].append(errV)
            iter_losses['adv_loss'].append(errA)

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time,
            total_d_loss / max(n_iter, 1), total_g_loss / max(n_iter, 1))
        print(log)

        # Per-epoch history stores the summed (not averaged) losses.
        epoch_losses['d_loss'].append(total_d_loss)
        epoch_losses['g_loss'].append(total_g_loss)

        ## save model
        tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
        tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)

    print("G train took: %4.4fs" % (time.time() - start_time))

    ## create visualizations for losses from training
    plot_total_losses(epoch_losses)
    plot_iterative_losses(iter_losses)

    for loss, values in epoch_losses.items():
        np.save(checkpoint_dir + "/epoch_" + loss + '.npy', np.asarray(values))

    for loss, values in iter_losses.items():
        np.save(checkpoint_dir + "/iter_" + loss + '.npy', np.asarray(values))

    print("[*] saved losses")
def evaluate(n_images=10):
    """Run the generator over validation LR images and save the results.

    The first ``n_images`` ``.png`` files (sorted by name) found in
    ``config.VALID.lr_img_path`` are read, converted with
    ``cv2.COLOR_GRAY2RGB``, rescaled with ``x / 127.5 - 1`` and passed one at
    a time through ``SRGAN_g``.  For every image both the generator output
    (``valid_gen_<i>.png``) and the rescaled input (``valid_lr_<i>.png``) are
    written to ``samples_btcv2/<mode>``.

    Args:
        n_images: Upper bound on the number of validation images to process.
            Defaults to 10, the count that used to be hard-coded; when fewer
            images are available only those are processed instead of raising
            ``IndexError``.
    """
    ## create folders to save result images
    save_dir = "samples_btcv2/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))
    # Read only the files that will actually be evaluated.
    selected_files = valid_lr_img_list[:max(n_images, 0)]
    valid_lr_imgs = tl.vis.read_images(selected_files,
                                       path=config.VALID.lr_img_path,
                                       n_threads=32)
    # The generator placeholder has 3 channels.
    # NOTE(review): GRAY2RGB presumes the files load as single-channel
    # images; confirm against the validation data.
    eval_imgs = [
        cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) for img in valid_lr_imgs
    ]

    ###========================== DEFINE MODEL ============================###
    t_image = tf.placeholder('float32', [1, None, None, 3],
                             name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    test_writer = None
    try:
        tl.layers.initialize_global_variables(sess)
        g_ckpt = checkpoint_dir + '/g_srgan.npz'
        # NOTE(review): the loader is expected to return False for a missing
        # file (as the training code in this file checks); surface that
        # instead of silently evaluating freshly initialized weights.
        if tl.files.load_and_assign_npz(
                sess=sess, name=g_ckpt, network=net_g) is False:
            print("[!] %s could not be loaded; using initialized weights" %
                  g_ckpt)
        # Creating the writer records the graph under logs/test/.
        test_writer = tf.summary.FileWriter('logs/test/', sess.graph)

        ###======================= EVALUATION =============================###
        for ii, raw_img in enumerate(eval_imgs):
            valid_lr_img = (raw_img / 127.5) - 1  # rescale to [-1, 1]
            start_time = time.time()
            out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
            print("took: %4.4fs" % (time.time() - start_time))

            print("[*] save images")
            tl.vis.save_image(out[0], save_dir + '/valid_gen_%d.png' % ii)
            tl.vis.save_image(valid_lr_img,
                              save_dir + '/valid_lr_%d.png' % ii)
    finally:
        # Release the event-file handle and the session's resources even if
        # evaluation fails part-way.
        if test_writer is not None:
            test_writer.close()
        sess.close()
def train():
    """Train SRGAN on the images in ``config.TRAIN.hr_img_path``.

    Steps, in order:

    1. Create the sample and checkpoint folders and pre-load every HR
       training image.
    2. Build the generator, discriminator and VGG19 graphs on fixed-size
       ``(batch_size, 96, 96, 3)`` / ``(batch_size, 384, 384, 3)``
       placeholders.
    3. Restore ``.npz`` checkpoints from ``checkpoint`` when present and load
       VGG19 weights from ``vgg19.npy``.
    4. Pre-train the generator with the MSE loss, then train generator and
       discriminator alternately.  Every 10th epoch (except epoch 0) a sample
       grid is rendered and the networks are checkpointed.

    Hyper-parameters (``batch_size``, ``lr_init``, ``beta1``,
    ``n_epoch_init``, ``n_epoch``, ``lr_decay``, ``decay_every``, ``ni``) are
    read from module scope.

    Raises:
        SystemExit: If ``vgg19.npy`` is not present in the working directory.

    Note:
        Only complete batches are fed: the placeholders have a fixed batch
        dimension, so a trailing group smaller than ``batch_size`` is
        skipped.
    """
    ## create folders to save result images and trained model
    mode = tl.global_flag['mode']
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)
    g_init_ckpt = checkpoint_dir + '/g_{}_init.npz'.format(mode)
    g_ckpt = checkpoint_dir + '/g_{}.npz'.format(mode)
    d_ckpt = checkpoint_dir + '/d_{}.npz'.format(mode)

    ###====================== PRE-LOAD DATA ===========================###
    # LR inputs are produced on the fly from HR crops, so only the HR
    # training files are listed and read.
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)
    # Start offsets of the complete batches only.
    batch_starts = range(0, len(train_hr_imgs) - batch_size + 1, batch_size)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3],
                                    name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. method: 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)

    # Inputs are shifted by (x + 1) / 2 before being handed to VGG.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(
        logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(
        net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Prefer the GAN generator checkpoint; fall back to the MSE-pretrained
    # one.  A missing file is expected to be reported as False (the check
    # the sibling train function in this file uses); None is still accepted
    # so the previous condition keeps working.
    restored_g = tl.files.load_and_assign_npz(
        sess=sess, name=g_ckpt, network=net_g)
    if restored_g is False or restored_g is None:
        tl.files.load_and_assign_npz(
            sess=sess, name=g_init_ckpt, network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=d_ckpt, network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        # SystemExit is what the interactive-only exit() helper raised; the
        # message goes to stderr and the process status is non-zero.
        raise SystemExit(
            "Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg"
        )
    # The file stores a pickled dict, so NumPy >= 1.16.3 needs
    # allow_pickle=True.  NOTE(review): unpickling executes arbitrary code;
    # only load a weights file obtained from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1',
                  allow_pickle=True).item()

    params = []
    for layer_name, layer_values in sorted(npz.items()):
        W = np.asarray(layer_values[0])
        b = np.asarray(layer_values[1])
        print(" Loading %s: %s, %s" % (layer_name, W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape,
          sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)

    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        for idx in batch_starts:
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size],
                fn=crop_sub_imgs_fn,
                is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time,
                   errM))
            total_mse_loss += errM
            n_iter += 1

        # max(..., 1) keeps the report from dividing by zero when the data
        # set holds fewer than batch_size images.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / max(n_iter, 1))
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train set
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_ginit + '/train_%d.png' % epoch)
            ## save model
            tl.files.save_npz(net_g.all_params, name=g_init_ckpt, sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        # decay_every == 0 disables decay instead of raising
        # ZeroDivisionError on the modulo.
        if epoch != 0 and decay_every != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                          new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        for idx in batch_starts:
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size],
                fn=crop_sub_imgs_fn,
                is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}

            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                   errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time,
            total_d_loss / max(n_iter, 1), total_g_loss / max(n_iter, 1))
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train set
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_gan + '/train_%d.png' % epoch)
            ## save model
            tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
            tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)
def train(): ## create folders to save result images and trained model save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode']) save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode']) tl.files.exists_or_mkdir(save_dir_ginit) tl.files.exists_or_mkdir(save_dir_gan) checkpoint_dir = "checkpoint" # checkpoint_resize_conv tl.files.exists_or_mkdir(checkpoint_dir) ###====================== PRE-LOAD DATA ===========================### train_hr_img_list = sorted( tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False)) train_lr_img_list = sorted( tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False)) valid_hr_img_list = sorted( tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False)) valid_lr_img_list = sorted( tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False)) ## If your machine have enough memory, please pre-load the whole train set. print("reading images") train_hr_imgs = [] #[None] * len(train_hr_img_list) #sess = tf.Session() for img__ in train_hr_img_list: image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path, img__), mode='L') image_loaded = image_loaded.reshape( (image_loaded.shape[0], image_loaded.shape[1], 1)) train_hr_imgs.append(image_loaded) print(type(train_hr_imgs), len(train_hr_img_list)) ###========================== DEFINE MODEL ============================### ## train inference #t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator') #t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image') t_image = tf.placeholder('float32', [batch_size, 28, 224, 1], name='t_image_input_to_SRGAN_generator') t_target_image = tf.placeholder( 'float32', [batch_size, 224, 224, 1], name='t_target_image' ) # may have to convert 224x224x1 into 224x224x3, with channel 1 & 2 as 0. May have to have separate place-holder ? 
print("t_image:", tf.shape(t_image)) print("t_target_image:", tf.shape(t_target_image)) net_g = SRGAN_g(t_image, is_train=True, reuse=False) print("net_g.outputs:", tf.shape(net_g.outputs)) net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False) _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True) net_g.print_params(False) net_g.print_layers() net_d.print_params(False) net_d.print_layers() ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA t_target_image_224 = tf.image.resize_images( t_target_image, size=[224, 224], method=0, align_corners=False ) # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer t_predict_image_224 = tf.image.resize_images( net_g.outputs, size=[224, 224], method=0, align_corners=False) # resize_generate_image_for_vgg ## Added as VGG works for RGB and expects 3 channels. t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224) t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224) print("net_g.outputs:", tf.shape(net_g.outputs)) print("t_predict_image_224:", tf.shape(t_predict_image_224)) net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False) _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True) ## test inference net_g_test = SRGAN_g(t_image, is_train=False, reuse=True) # ###========================== DEFINE TRAIN OPS ==========================### d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1') d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2') d_loss = d_loss1 + d_loss2 g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy( logits_fake, tf.ones_like(logits_fake), name='g') mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True) vgg_loss = 2e-6 * tl.cost.mean_squared_error( vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True) g_loss = 
mse_loss + vgg_loss + g_gan_loss d_loss1_summary = tf.summary.scalar('Disciminator logits_real loss', d_loss1) d_loss2_summary = tf.summary.scalar('Disciminator logits_fake loss', d_loss2) d_loss_summary = tf.summary.scalar('Disciminator total loss', d_loss) g_gan_loss_summary = tf.summary.scalar('Generator GAN loss', g_gan_loss) mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss) vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss) g_loss_summary = tf.summary.scalar('Generator total loss', g_loss) g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True) d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True) with tf.variable_scope('learning_rate'): lr_v = tf.Variable(lr_init, trainable=False) ## Pretrain # UNCOMMENT THE LINE BELOW!!! #g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars) ## SRGAN g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars) d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars) ###========================== RESTORE MODEL =============================### sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) tl.layers.initialize_global_variables(sess) #if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False: # tl.fites.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g) #tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), network=net_d) ###============================= LOAD VGG ===============================### vgg19_npy_path = "vgg19.npy" if not os.path.isfile(vgg19_npy_path): print( "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg" ) exit() npz = np.load(vgg19_npy_path, encoding='latin1').item() params = [] for val in 
sorted(npz.items()): W = np.asarray(val[1][0]) b = np.asarray(val[1][1]) print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape)) params.extend([W, b]) tl.files.assign_params(sess, params, net_vgg) # net_vgg.print_params(False) # net_vgg.print_layers() ###============================= TRAINING ===============================### ## use first `batch_size` of train set to have a quick test during training sample_imgs = train_hr_imgs[0:batch_size] # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape) sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False) print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max()) sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn_mod) print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max()) #tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png') #tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png') #tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png') #tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png') ''' ###========================= initialize G ====================### merged_summary_initial_G = tf.summary.merge([mse_loss_summary]) summary_intial_G_writer = tf.summary.FileWriter("./log/train/initial_G") ## fixed learning rate sess.run(tf.assign(lr_v, lr_init)) print(" ** fixed learning rate: %f (for init G)" % lr_init) count = 0 for epoch in range(0, n_epoch_init + 1): epoch_time = time.time() total_mse_loss, n_iter = 0, 0 ## If your machine cannot load all images into memory, you should use ## this one to load batch of images while training. 
# random.shuffle(train_hr_img_list) # for idx in range(0, len(train_hr_img_list), batch_size): # step_time = time.time() # b_imgs_list = train_hr_img_list[idx : idx + batch_size] # b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path) # b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True) # b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) ## If your machine have enough memory, please pre-load the whole train set. intial_MSE_G_summary_per_epoch = [] for idx in range(0, len(train_hr_imgs), batch_size): step_time = time.time() b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True) b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn_mod) ## update G errM, _, mse_summary_initial_G = sess.run([mse_loss, g_optim_init, merged_summary_initial_G], {t_image: b_imgs_96, t_target_image: b_imgs_384}) print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM)) summary_pb = tf.summary.Summary() summary_pb.ParseFromString(mse_summary_initial_G) intial_G_summaries = {} for val in summary_pb.value: # Assuming all summaries are scalars. 
intial_G_summaries[val.tag] = val.simple_value #print("intial_G_summaries:", intial_G_summaries) intial_MSE_G_summary_per_epoch.append(intial_G_summaries['Generator_MSE_loss']) #summary_intial_G_writer.add_summary(mse_summary_initial_G, (count + 1)) #(epoch + 1)*(n_iter+1)) #count += 1 total_mse_loss += errM n_iter += 1 log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter) print(log) summary_intial_G_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag="Generator_Initial_MSE_loss per epoch", simple_value=np.mean(intial_MSE_G_summary_per_epoch)),]), (epoch)) ## quick evaluation on train set #if (epoch != 0) and (epoch % 10 == 0): out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96}) #; print('gen sub-image:', out.shape, out.min(), out.max()) print("[*] save images") for im in range(len(out)): if(im%4==0 or im==1197): tl.vis.save_image(out[im], save_dir_ginit + '/train_%d_%d.png' % (epoch,im)) ## save model saver=tf.train.Saver() if (epoch%10==0 and epoch!=0): saver.save(sess, 'checkpoint/init_'+str(epoch)+'.ckpt') #if (epoch != 0) and (epoch % 10 == 0): #tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_{}_init.npz'.format(tl.global_flag['mode'], epoch), sess=sess) ''' ###========================= train GAN (SRGAN) =========================### saver = tf.train.Saver() saver.restore(sess, 'checkpoint/main_10.ckpt') print('Restored main_10, begin 11/50') merged_summary_discriminator = tf.summary.merge( [d_loss1_summary, d_loss2_summary, d_loss_summary]) summary_discriminator_writer = tf.summary.FileWriter( "./log/train/discriminator") merged_summary_generator = tf.summary.merge([ g_gan_loss_summary, mse_loss_summary, vgg_loss_summary, g_loss_summary ]) summary_generator_writer = tf.summary.FileWriter("./log/train/generator") learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate") count = 0 for epoch in range(11, n_epoch + 11): ## update learning rate 
if epoch != 0 and (epoch % decay_every == 0): new_lr_decay = lr_decay**(epoch // decay_every) sess.run(tf.assign(lr_v, lr_init * new_lr_decay)) log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay) print(log) learning_rate_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Learning_rate per epoch", simple_value=(lr_init * new_lr_decay)), ]), (epoch)) elif epoch == 0: sess.run(tf.assign(lr_v, lr_init)) log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % ( lr_init, decay_every, lr_decay) print(log) learning_rate_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Learning_rate per epoch", simple_value=lr_init), ]), (epoch)) epoch_time = time.time() total_d_loss, total_g_loss, n_iter = 0, 0, 0 ## If your machine cannot load all images into memory, you should use ## this one to load batch of images while training. # random.shuffle(train_hr_img_list) # for idx in range(0, len(train_hr_img_list), batch_size): # step_time = time.time() # b_imgs_list = train_hr_img_list[idx : idx + batch_size] # b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path) # b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True) # b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) ## If your machine have enough memory, please pre-load the whole train set. 
loss_per_batch = [] d_loss1_summary_per_epoch = [] d_loss2_summary_per_epoch = [] d_loss_summary_per_epoch = [] g_gan_loss_summary_per_epoch = [] mse_loss_summary_per_epoch = [] vgg_loss_summary_per_epoch = [] g_loss_summary_per_epoch = [] for idx in range(0, len(train_hr_imgs), batch_size): step_time = time.time() b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True) b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn_mod) ## update D errD, _, discriminator_summary = sess.run( [d_loss, d_optim, merged_summary_discriminator], { t_image: b_imgs_96, t_target_image: b_imgs_384 }) summary_pb = tf.summary.Summary() summary_pb.ParseFromString(discriminator_summary) #print("discriminator_summary", summary_pb, type(summary_pb)) discriminator_summaries = {} for val in summary_pb.value: # Assuming all summaries are scalars. discriminator_summaries[val.tag] = val.simple_value d_loss1_summary_per_epoch.append( discriminator_summaries['Disciminator_logits_real_loss']) d_loss2_summary_per_epoch.append( discriminator_summaries['Disciminator_logits_fake_loss']) d_loss_summary_per_epoch.append( discriminator_summaries['Disciminator_total_loss']) ## update G errG, errM, errV, errA, _, generator_summary = sess.run( [ g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim, merged_summary_generator ], { t_image: b_imgs_96, t_target_image: b_imgs_384 }) summary_pb = tf.summary.Summary() summary_pb.ParseFromString(generator_summary) #print("generator_summary", summary_pb, type(summary_pb)) generator_summaries = {} for val in summary_pb.value: # Assuming all summaries are scalars. 
generator_summaries[val.tag] = val.simple_value #print("generator_summaries:", generator_summaries) g_gan_loss_summary_per_epoch.append( generator_summaries['Generator_GAN_loss']) mse_loss_summary_per_epoch.append( generator_summaries['Generator_MSE_loss']) vgg_loss_summary_per_epoch.append( generator_summaries['Generator_VGG_loss']) g_loss_summary_per_epoch.append( generator_summaries['Generator_total_loss']) #summary_generator_writer.add_summary(generator_summary, (count + 1)) #summary_total = sess.run(summary_total_merged, {t_image: b_imgs_96, t_target_image: b_imgs_384}) #summary_total_merged_writer.add_summary(summary_total, (count + 1)) #count += 1 tot_epoch = n_epoch + 10 print( "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" % (epoch, tot_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA)) total_d_loss += errD total_g_loss += errG n_iter += 1 #remove this for normal running: log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % ( epoch, tot_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter) print(log) ##### # # logging discriminator summary # ###### # logging per epcoch summary of logit_real_loss per epoch. Value logged is averaged across batches used per epoch. summary_discriminator_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Disciminator_logits_real_loss per epoch", simple_value=np.mean( d_loss1_summary_per_epoch)), ]), (epoch)) # logging per epcoch summary of logit_fake_loss per epoch. Value logged is averaged across batches used per epoch. summary_discriminator_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Disciminator_logits_fake_loss per epoch", simple_value=np.mean( d_loss2_summary_per_epoch)), ]), (epoch)) # logging per epcoch summary of total_loss per epoch. Value logged is averaged across batches used per epoch. 
summary_discriminator_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Disciminator_total_loss per epoch", simple_value=np.mean( d_loss_summary_per_epoch)), ]), (epoch)) ##### # # logging generator summary # ###### summary_generator_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Generator_GAN_loss per epoch", simple_value=np.mean( g_gan_loss_summary_per_epoch)), ]), (epoch)) summary_generator_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Generator_MSE_loss per epoch", simple_value=np.mean( mse_loss_summary_per_epoch)), ]), (epoch)) summary_generator_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Generator_VGG_loss per epoch", simple_value=np.mean( vgg_loss_summary_per_epoch)), ]), (epoch)) summary_generator_writer.add_summary( tf.Summary(value=[ tf.Summary.Value(tag="Generator_total_loss per epoch", simple_value=np.mean( g_loss_summary_per_epoch)), ]), (epoch)) ## quick evaluation on train set #if (epoch != 0) and (epoch % 10 == 0): out = sess.run( net_g_test.outputs, {t_image: sample_imgs_96 }) #; print('gen sub-image:', out.shape, out.min(), out.max()) ## save model if (epoch % 10 == 0 and epoch != 0): saver.save(sess, 'checkpoint/main_' + str(epoch) + '.ckpt') print("[*] save images") for im in range(len(out)): tl.vis.save_image( out[im], save_dir_gan + '/train_%d_%d.png' % (epoch, im))