def srgan_model(features, labels, mode, params):
    """Estimator model_fn for SRGAN.

    Args:
        features: batch of low-resolution input images (range assumed [-1, 1]
            given the ``(x + 1) / 2`` rescale below — TODO confirm at caller).
        labels: batch of matching high-resolution target images.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: unused; required by the model_fn signature.

    Returns:
        ``tf.estimator.EstimatorSpec`` with predictions (PREDICT) or with the
        generator loss and a joint G/D train op otherwise.
    """
    del params  # unused, but part of the model_fn contract
    # NOTE(review): declared but never assigned or read in this function —
    # presumably legacy; confirm no external reliance before removing.
    global load_flag

    if mode == tf.estimator.ModeKeys.PREDICT:
        net_g_test = SRGAN_g(features, is_train=False)
        predictions = {'generated_images': net_g_test.outputs}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # Build generator; discriminator sees both real (labels) and fake outputs.
    net_g = SRGAN_g(features, is_train=True)
    net_d, logits_real = SRGAN_d(labels, is_train=True)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True)

    # Resize both target and prediction to 224x224 for the VGG19 perceptual
    # loss (method=0 is bilinear in TF1's resize_images).
    t_target_image_224 = tf.image.resize_images(labels, size=[224, 224], method=0, align_corners=False)
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224],
        method=0, align_corners=False)  # resize_generate_image_for_vgg
    # (x + 1) / 2 maps [-1, 1] images into [0, 1] for VGG.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2)

    # Discriminator loss: real -> 1, fake -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator loss: adversarial (weight 1e-3) + pixel MSE + VGG feature MSE
    # (weight 2e-6), as in the SRGAN paper.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, labels, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(config.TRAIN.lr_init, trainable=False)

    # SRGAN: update G and D jointly in one train op.
    # BUG FIX: global_step was previously passed to BOTH minimize() calls, so
    # every training step incremented the global step by 2, skewing any
    # step-based schedule, logging, and checkpoint numbering. Only the G
    # optimizer increments it now, so the step counts actual train calls.
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1) \
        .minimize(g_loss, var_list=g_vars, global_step=tf.train.get_global_step())
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=config.TRAIN.beta1) \
        .minimize(d_loss, var_list=d_vars)
    joint_op = tf.group([g_optim, d_optim])

    # Restore pretrained VGG19 weights into the perceptual-loss network.
    load_vgg(net_vgg)
    return tf.estimator.EstimatorSpec(mode, loss=g_loss, train_op=joint_op)
def train():
    """Train SRGAN on 3-channel images: MSE-only generator init, then full GAN.

    Reads image lists/paths and hyperparameters from the module-level
    ``config`` object and globals (``batch_size``, ``lr_init``, ``beta1``,
    ``n_epoch_init``, ``n_epoch``, ``lr_decay``, ``decay_every``, ``ni``).
    Saves sample images under ``samples/`` and weights under ``checkpoint/``.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
    # for im in train_hr_imgs:
    #     print(im.shape)
    # valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    # for im in valid_lr_imgs:
    #     print(im.shape)
    # valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
    # for im in valid_hr_imgs:
    #     print(im.shape)
    # exit()

    ###========================== DEFINE MODEL ============================###
    ## train inference: LR input is 96x96x3, HR target 384x384x3 (4x upscale)
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    # D is built twice: once for real targets, once (reuse=True) for fakes.
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0,
        align_corners=False)  # resize_target_image_for_vgg
    # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg
    # (x + 1) / 2 maps [-1, 1] images into [0, 1] for VGG.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference (shares weights with the training generator)
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator loss: real -> 1, fake -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2
    # Generator loss: adversarial (1e-3) + pixel MSE + VGG feature MSE (2e-6).
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain: MSE-only warm-up for G
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Try the GAN-phase checkpoint first; fall back to the init-phase one.
    # NOTE(review): some TensorLayer versions return False (not None) on a
    # missing file, which would make this `is None` fallback dead — verify
    # against the installed TL version (the second train() below uses `is False`).
    if tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
            network=net_g) is None:
        tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']),
            network=net_g)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()
    # Flatten the {layer_name: [W, b]} dict into an ordered param list that
    # matches net_vgg's parameter order (sorted() gives a deterministic order).
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32)  # if no pre-load train set
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            # Random 384x384 crop, then 4x downsample to build the LR input.
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G (MSE only)
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']),
                              sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate: step decay by lr_decay every decay_every epochs
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {
                t_image: b_imgs_96,
                t_target_image: b_imgs_384
            })
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG,
                   errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
                              sess=sess)
            tl.files.save_npz(net_d.all_params,
                              name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']),
                              sess=sess)
def train(train_lr_imgs, train_hr_imgs):
    """Train SRGAN on pre-loaded, paired single-channel 512x512 images.

    NOTE(review): this redefines the earlier ``train()`` in this file —
    at import time this definition shadows the previous one; confirm that
    is intended.

    Args:
        train_lr_imgs: sequence of low-resolution grayscale images; reshaped
            per batch to (batch_size, 512, 512, 1).
        train_hr_imgs: sequence of matching high-resolution images, same
            length and order as ``train_lr_imgs`` (pairing is by index).

    Side effects: saves weights and loss arrays under ``models_checkpoints``
    and writes loss plots via ``plot_total_losses`` / ``plot_iterative_losses``.
    """
    ## create folders to save result images and trained model
    checkpoint_dir = "models_checkpoints"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###========================== DEFINE MODEL ============================###
    ## train inference: LR and HR are both 512x512x1 here (no spatial upscale)
    t_image = tf.placeholder(dtype='float32',
                             shape=(batch_size, 512, 512, 1),
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(dtype='float32',
                                    shape=(batch_size, 512, 512, 1),
                                    name='t_target_image')

    # D is built twice: once for real targets, once (reuse=True) for fakes.
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0,
        align_corners=False)  # resize_target_image_for_vgg
    # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg
    # (x + 1) / 2 maps [-1, 1] images into [0, 1] for VGG.
    net_vgg, vgg_target_emb = Vgg19_simple_api(input=(t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api(input=(t_predict_image_224 + 1) / 2, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator loss: real -> 1, fake -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2
    # Generator loss: adversarial (1e-3) + pixel MSE + VGG feature MSE (2e-6).
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain: MSE-only warm-up for G
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Try the GAN-phase checkpoint first; fall back to the init-phase one.
    # NOTE(review): this uses `is False` while the other train() uses `is None`
    # — only one can match the installed TensorLayer's return value; verify.
    if tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
            network=net_g) is False:
        tl.files.load_and_assign_npz(
            sess=sess,
            name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']),
            network=net_g)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_d)

    ###============================= LOAD VGG ===============================###
    # NOTE: unlike the other train(), there is no existence check here —
    # a missing vgg19.npy raises from np.load.
    vgg19_npy_path = "vgg19.npy"
    npz = np.load(vgg19_npy_path, encoding='latin1').item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        if val[0] == 'conv1_1':
            # Collapse the RGB input filters to one channel by averaging,
            # adapting pretrained VGG19 to the grayscale input used here.
            W = np.mean(W, axis=2)
            W = W.reshape((3, 3, 1, 64))
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    start_time = time.time()
    for epoch in range(0, n_epoch_init):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        step_time = None  # set on every idx % 1000 == 0 iteration below
        for idx in range(0, len(train_hr_imgs), batch_size):
            if idx % 1000 == 0:
                step_time = time.time()
            b_imgs_hr = train_hr_imgs[idx:idx + batch_size]
            b_imgs_lr = train_lr_imgs[idx:idx + batch_size]
            b_imgs_hr = np.asarray(b_imgs_hr).reshape(
                (batch_size, 512, 512, 1))
            b_imgs_lr = np.asarray(b_imgs_lr).reshape(
                (batch_size, 512, 512, 1))
            ## update G (MSE only)
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_lr,
                t_target_image: b_imgs_hr
            })
            if idx % 1000 == 0:
                # Periodic progress log + checkpoint every 1000 samples.
                print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                      (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
                tl.files.save_npz(
                    net_g.all_params,
                    name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']),
                    sess=sess)
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)
    ## save model
    tl.files.save_npz(net_g.all_params,
                      name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']),
                      sess=sess)
    print("G init took: %4.4fs" % (time.time() - start_time))

    ###========================= train GAN (SRGAN) =========================###
    start_time = time.time()
    epoch_losses = defaultdict(list)  # per-epoch summed d_loss/g_loss
    iter_losses = defaultdict(list)   # per-iteration loss components
    for epoch in range(0, n_epoch):
        ## update learning rate: step decay by lr_decay every decay_every epochs
        if epoch != 0 and decay_every != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        step_time = None  # set on every idx % 1000 == 0 iteration below
        for idx in range(0, len(train_hr_imgs), batch_size):
            if idx % 1000 == 0:
                step_time = time.time()
            b_imgs_hr = train_hr_imgs[idx:idx + batch_size]
            b_imgs_lr = train_lr_imgs[idx:idx + batch_size]
            b_imgs_hr = np.asarray(b_imgs_hr).reshape(
                (batch_size, 512, 512, 1))
            b_imgs_lr = np.asarray(b_imgs_lr).reshape(
                (batch_size, 512, 512, 1))
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {
                t_image: b_imgs_lr,
                t_target_image: b_imgs_hr
            })
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {
                    t_image: b_imgs_lr,
                    t_target_image: b_imgs_hr
                })
            if idx % 1000 == 0:
                # Periodic progress log + checkpoint every 1000 samples.
                print(
                    "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                    % (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                       errG, errM, errV, errA))
                tl.files.save_npz(net_g.all_params,
                                  name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
                                  sess=sess)
                tl.files.save_npz(net_d.all_params,
                                  name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']),
                                  sess=sess)
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1
            iter_losses['d_loss'].append(errD)
            iter_losses['g_loss'].append(errG)
            iter_losses['mse_loss'].append(errM)
            iter_losses['vgg_loss'].append(errV)
            iter_losses['adv_loss'].append(errA)
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)
        epoch_losses['d_loss'].append(total_d_loss)
        epoch_losses['g_loss'].append(total_g_loss)
        ## save model
        tl.files.save_npz(net_g.all_params,
                          name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
                          sess=sess)
        tl.files.save_npz(net_d.all_params,
                          name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']),
                          sess=sess)
    print("G train took: %4.4fs" % (time.time() - start_time))

    ## create visualizations for losses from training
    plot_total_losses(epoch_losses)
    plot_iterative_losses(iter_losses)
    # Persist the raw loss curves alongside the checkpoints.
    for loss, values in epoch_losses.items():
        np.save(checkpoint_dir + "/epoch_" + loss + '.npy', np.asarray(values))
    for loss, values in iter_losses.items():
        np.save(checkpoint_dir + "/iter_" + loss + '.npy', np.asarray(values))
    print("[*] saved losses")
def network_new2( top_dir="sr_tanh/", svg_dir="dataset/Test/", #test_data pxl_dir="dataset/Train/", #train_data output_dir="pic_smooth/", test_output_dir='test_output/', checkpoint_dir="save_model", checkpoint_dir1="save_model", model_name="model4", big_loop=1, scale_num=2, epoch_init=5000, strides=20, batch_size=4, max_idx=92, data_size=92, lr_init=1e-3, learning_rate=1e-5, vgg_weight_list=[1, 1, 5e-1, 1e-1], use_vgg=False, use_L1_loss=False, wgan=False, init_g=True, init_d=True, init_b=False, method=0, lowest_resolution_log2=4, train_net=True, generate_pics=True, resume_network=False): logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) handler = logging.FileHandler(top_dir + "log.txt") handler.setLevel(logging.INFO) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) for idx, val in enumerate(network_new2.__defaults__): logger.info( str(network_new2.__code__.co_varnames[idx]) + ' == ' + str(val)) output_dir = top_dir + output_dir test_output_dir = top_dir + test_output_dir checkpoint_dir = top_dir + checkpoint_dir checkpoint_dir1 = top_dir + checkpoint_dir1 logger.info("start building the net") if use_vgg: print('use vgg') t_target_image_data, image_padding_nums = read_data(pxl_dir, data_size) resolution = t_target_image_data.shape[1] / scale_num target_resolution = t_target_image_data.shape[1] resolution_log2 = int(np.floor(np.log2(resolution))) target_resolution_log2 = int(np.floor(np.log2(target_resolution))) #image = tf.image.resize_images(t_image, size=[64, 64], method=2) t_image_target = tf.placeholder( 'float32', [None, target_resolution, target_resolution, 3], name='t_image_target') t_image_ = tf.image.resize_images( t_image_target, size=[target_resolution // scale_num, target_resolution // scale_num], method=method) t_image = tf.image.resize_images( t_image_, size=[target_resolution, target_resolution], method=method) t_image_target_list = 
[] t_image_list = [] #generate list of pics from 2 ** 2 resolution to t_image_size resolution net_Gs, mix_rates = my_GAN_G2(t_image, is_train=True, reuse=False) print("init Gs") net_Gs[-1].print_params(False) net_g_test, _ = my_GAN_G2(t_image, is_train=False, reuse=True) print("init g_test") if use_vgg: t_target_image_224 = tf.placeholder('float32', [None, 224, 224, 3], name='t_image_224') t_predict_image_224 = tf.placeholder('float32', [None, 224, 224, 3], name='t_target_224') net_vgg, vgg_target_emb = Vgg19_simple_api( (t_target_image_224 + 1) / 2, reuse=False) #initialize the list to store different level net net_ds = [] b_outputs = [] logits_reals = [] logits_fakes = [] logits_fakes2 = [] d_loss_list = [] b_loss_list = [] d_loss3_list = [] mse_loss_list = [] g_gan_loss_list = [] g_loss_list = [] g_init_optimizer_list = [] d_init_optimizer_list = [] g_optimizer_list = [] d_optimizer_list = [] b_optimizer_list = [] w_clip_list = [] with tf.variable_scope('learning_rate'): lr_v = tf.Variable(lr_init, trainable=False) print("init Ds") for i in range(lowest_resolution_log2, target_resolution_log2 + 1): idx = i - lowest_resolution_log2 cur_resolution = 2**i size = [cur_resolution, cur_resolution] target_i = tf.image.resize_images(t_image_target, size=size, method=method) image_i = tf.image.resize_images(t_image, size=size, method=method) t_image_target_list += [target_i] t_image_list += [image_i] if use_vgg: t_target_image_224 = tf.image.resize_images( t_image_target, size=[224, 224], method=1, align_corners=False ) # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer add_dimens = tf.zeros_like(t_target_image_224) print(add_dimens.dtype) print(t_target_image_224.dtype) t_predict_image_224 = tf.image.resize_images( net_Gs[idx].outputs, size=[224, 224], method=1, align_corners=False) # resize_generate_image_for_vgg net_vgg, vgg_target_emb = Vgg19_simple_api( (t_target_image_224 + 1) / 2, 
reuse=True) _, vgg_predict_emb = Vgg19_simple_api( (t_predict_image_224 + 1) / 2, reuse=True) #initialize the D_reals and D_fake net_d, logits_real = my_GAN_D1(target_i, is_train=True, reuse=False, use_sigmoid=not wgan) _, logits_fake = my_GAN_D1(net_Gs[idx].outputs, is_train=True, reuse=True, use_sigmoid=not wgan) _, logits_fake2 = my_GAN_D1(image_i, is_train=True, reuse=True, use_sigmoid=not wgan) blend_output = net_CT_blend(image_i, net_Gs[idx].outputs) b_outputs += [blend_output] net_ds += [net_d] logits_reals += [logits_real] logits_fakes += [logits_fake] logits_fakes2 += [logits_fake2] mix_factors = np.random.uniform(size=[1, 1, 1, int(target_i.shape[3])]) print(mix_factors.shape) mix_pic = net_Gs[idx].outputs * mix_factors + target_i * (1 - mix_factors) _, logits_mix = my_GAN_D1(mix_pic, is_train=True, reuse=True) d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1_%d' % cur_resolution) d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2_%d' % cur_resolution) d_loss3 = tl.cost.sigmoid_cross_entropy(logits_fake2, tf.zeros_like(logits_fake2), name='d3_%d' % cur_resolution) d_loss4 = (tf.reduce_mean(logits_fake2)) - ( tf.reduce_mean(logits_real)) d_loss4 = tf.nn.sigmoid(d_loss4) #make sure in [0, 1] d_loss = 1 * (d_loss1 + d_loss2) #+ d_loss3 + d_loss4 d_loss += 0. 
d_loss3 += d_loss1 use_vgg22 = True vgg_loss = 0 if use_vgg: for i, vgg_target in enumerate(vgg_target_emb): vgg_loss += vgg_weight_list[i] * tl.cost.mean_squared_error( vgg_predict_emb[i].outputs, vgg_target.outputs, is_mean=True) g_gan_loss1 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g_%d' % cur_resolution) g_gan_loss2 = (tf.reduce_mean(logits_fake2)) - ( tf.reduce_mean(logits_fake)) g_gan_loss2 = tf.nn.sigmoid(g_gan_loss2) #make sure in [0, 1] g_gan_loss = g_gan_loss1 # + g_gan_loss2 mse_loss = tl.cost.mean_squared_error(net_Gs[idx].outputs, target_i, is_mean=True) if use_L1_loss: mes_loss = tf.reduce_mean( tf.reduce_mean(tf.abs(net_Gs[idx].outputs - target_i))) g_gan_loss_list += [g_gan_loss] mse_loss_list += [mse_loss] g_loss = 1e-3 * g_gan_loss + mse_loss L1_norm = tf.reduce_mean(tf.reduce_mean(net_Gs[idx].outputs)) def TV_loss(x): loss1 = x[:, :, 1:, :] - x[:, :, :-1, :]**2 loss2 = x[:, 1:, :, :] - x[:, :-1, :, :]**2 return tf.reduce_sum(tf.reduce_sum(loss1)) + tf.reduce_sum( tf.reduce_sum(loss2)) tV_loss = TV_loss(net_Gs[idx].outputs) b_loss = tl.cost.mean_squared_error(blend_output.outputs, target_i, is_mean=True) if i >= 7: g_loss += vgg_loss #g_loss += vgg_loss g_vars = tl.layers.get_variables_with_name('my_GAN_G', True, True) d_vars = tl.layers.get_variables_with_name( 'my_GAN_D_%d' % cur_resolution, True, True) b_vars = tl.layers.get_variables_with_name( 'my_CT_blend_%d' % cur_resolution, True, True) g_optim_init = tf.train.AdamOptimizer(lr_v, 0.9).minimize(mse_loss, var_list=g_vars) g_init_optimizer_list += [g_optim_init] d_optim_init = tf.train.AdamOptimizer(lr_v, 0.9).minimize(d_loss3, var_list=d_vars) g_optim = tf.train.AdamOptimizer(lr_v, 0.9).minimize(g_loss, var_list=g_vars) d_optim = tf.train.AdamOptimizer(lr_v, 0.9).minimize(d_loss, var_list=d_vars) b_optim = tf.train.AdamOptimizer(lr_v, 0.9).minimize(b_loss, var_list=b_vars) #WGAN if wgan: print('mode is wgan') g_loss = -(tf.reduce_mean(logits_fake)) + 
vgg_loss d_loss = (tf.reduce_mean(logits_fake)) - ( tf.reduce_mean(logits_real)) d_loss3 = (tf.reduce_mean(logits_fake2)) - ( tf.reduce_mean(logits_real)) mix_grads = tf.gradients(tf.reduce_sum(logits_mix), mix_pic) mix_norms = tf.sqrt( tf.reduce_sum(tf.square(mix_grads), axis=[1, 2, 3])) addtion = tf.reduce_mean(tf.square(mix_norms - 1.)) * 5.0 #d_loss = d_loss + d_loss3 + addtion + tl.cost.mean_squared_error(logits_real, tf.zeros_like(logits_real)) * 1e-3 d_loss = d_loss + d_loss3 g_optim = tf.train.RMSPropOptimizer(learning_rate).minimize( g_loss, var_list=g_vars) d_optim = tf.train.RMSPropOptimizer(learning_rate).minimize( d_loss, var_list=d_vars) d_optim_init = tf.train.RMSPropOptimizer(learning_rate).minimize( d_loss3, var_list=d_vars) clip_ops = [] for var in d_vars: clip_bound = [-1.0, 1.0] clip_ops.append( tf.assign( var, tf.clip_by_value(var, clip_bound[0], clip_bound[1]))) clip_disc_weights = tf.group(*clip_ops) w_clip_list += [clip_disc_weights] d_loss_list += [d_loss] d_loss3_list += [d_loss3] g_loss_list += [g_loss] b_loss_list += [b_loss] g_optimizer_list += [g_optim] d_optimizer_list += [d_optim] b_optimizer_list += [b_optim] d_init_optimizer_list += [d_optim_init] print("init Res : %d D" % cur_resolution) #Restore Model config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) config.gpu_options.per_process_gpu_memory_fraction = 0.8 sess = tf.Session(config=config) tl.layers.initialize_global_variables(sess) #......code for restore model if use_vgg: vgg19_npy_path = "vgg19.npy" if not os.path.isfile(vgg19_npy_path): print( "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg" ) exit() npz = np.load(vgg19_npy_path, encoding='latin1').item() params = [] for val in sorted(npz.items()): W = np.asarray(val[1][0]) b = np.asarray(val[1][1]) print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape)) params.extend([W, b]) if (len(params) == len(net_vgg.all_params)): break tl.files.assign_params(sess, params, 
net_vgg) # net_vgg.print_params(False) # net_vgg.print_layers() #Read Data #t_image_data, t_target_image_data = split_pic(read_data(svg_dir, data_size)) #initialize G for temp_i in range(big_loop): decay_every = epoch_init // 2 lr_decay = 0.1 logger.info("start training the net") for R in range(lowest_resolution_log2, target_resolution_log2 + 1): idx = R - lowest_resolution_log2 if resume_network or not train_net: tl.files.load_and_assign_npz_dict(sess=sess, name=checkpoint_dir1 + '/g_%d_level_my_gan.npz' % R, network=net_Gs[idx]) tl.files.load_and_assign_npz_dict(sess=sess, name=checkpoint_dir1 + '/b_%d_level_my_gan.npz' % R, network=b_outputs[idx]) #tl.files.load_and_assign_npz_dict(sess = sess, name = checkpoint_dir1 + '/d_%d_level_my_gan.npz' % R, network = net_ds[idx]) total_mse_loss = 0 mse_loss = mse_loss_list[idx] g_optim_init = g_init_optimizer_list[idx] total_d3_loss = 0 d_loss3 = d_loss3_list[idx] d_optim = d_optimizer_list[idx] d_loss = d_loss_list[idx] d_optim_init = d_init_optimizer_list[idx] ni = int(np.sqrt(batch_size)) out_svg = sess.run( t_image_list[idx], {t_image_target: t_target_image_data[0:batch_size]}) out_pxl = sess.run( t_image_target_list[idx], {t_image_target: t_target_image_data[0:batch_size]}) print(out_pxl[0]) print(out_pxl.dtype) tl.vis.save_images(out_svg, [ni, ni], output_dir + "R_%d_svg.png" % (R)) tl.vis.save_images(out_pxl, [ni, ni], output_dir + "R_%d_pxl.png" % (R)) f = open('log%d.txt' % R, 'w') pre_loss_list = [] now_loss_list = [] if init_g and train_net: #fix lr_v print('init g') sess.run(tf.assign(lr_v, lr_init)) for epoch in range(epoch_init + 1): iters, data, padding_nums = batch_data( t_target_image_data, image_padding_nums, max_idx, batch_size) total_mse_loss = 0 total_pre_loss = np.zeros([2]) total_now_loss = np.zeros([2]) for i in range(iters): errM, _ = sess.run([mse_loss, g_optim_init], {t_image_target: data[i]}) total_mse_loss += errM if R == target_resolution_log2: #final steps lowR_pics, output_pics, GT_pics = 
sess.run( [ t_image_list[idx], net_g_test[idx].outputs, t_image_target_list[idx] ], {t_image_target: data[i]}) pre_lowR_pics = clip_pics(lowR_pics, padding_nums[i]) pre_output_pics = clip_pics( output_pics, padding_nums[i]) pre_GT_pics = clip_pics(GT_pics, padding_nums[i]) for ii in range(data[i].shape[0]): pre_loss = cal_loss(pre_lowR_pics[ii], pre_GT_pics[ii]) now_loss = cal_loss(pre_output_pics[ii], pre_GT_pics[ii]) total_pre_loss += pre_loss total_now_loss += now_loss pre_loss_list += [total_pre_loss / max_idx] now_loss_list += [total_now_loss / max_idx] print("[%d/%d] total_mse_loss = %f errM = %f" % (epoch, epoch_init, total_mse_loss, errM)) ## save model if (epoch % strides == 0): print("save img %d" % R) out, logits_real, logits_fake, logits_fake2 = sess.run( [ net_g_test[idx].outputs, tf.nn.sigmoid(logits_reals[idx]), tf.nn.sigmoid(logits_fakes[idx]), tf.nn.sigmoid(logits_fakes2[idx]) ], { t_image_target: t_target_image_data[0:batch_size] }) print(out[0]) print(out.dtype) tl.vis.save_images( out, [ni, ni], output_dir + "R_%d_init_%d.png" % (R, epoch)) if epoch % 10 == 0: tl.files.save_npz_dict( net_Gs[idx].all_params, name=checkpoint_dir + ('/g_%d_level_{}_init.npz' % R).format( tl.global_flag['mode']), sess=sess) print("R %d total_mse_loss = %f" % (2**R, total_mse_loss)) save_list(top_dir + 'init_g_pre', pre_loss_list) save_list(top_dir + 'init_g_now', now_loss_list) pre_loss_list = [] now_loss_list = [] if init_d and train_net: #fix lr_v print('init d') sess.run(tf.assign(lr_v, lr_init)) for epoch in range(epoch_init + 1): iters, data, padding_nums = batch_data( t_target_image_data, image_padding_nums, max_idx, batch_size) for i in range(iters): errD3, errD, _ = sess.run( [d_loss3, d_loss, d_optim_init], {t_image_target: data[i]}) total_d3_loss += errD3 print("[%d/%d] d_loss = %f, errD3 = %f" % (epoch, epoch_init, errD, errD3)) ## save model if (epoch != 0) and (epoch % 5 == 0): tl.files.save_npz_dict( net_ds[idx].all_params, name=checkpoint_dir + 
'/d_{}_init.npz'.format(tl.global_flag['mode']), sess=sess) if epoch % 10 == 0: out, logits_real, logits_fake, logits_fake2 = sess.run( [ net_g_test[idx].outputs, tf.nn.sigmoid(logits_reals[idx]), tf.nn.sigmoid(logits_fakes[idx]), tf.nn.sigmoid(logits_fakes2[idx]) ], { t_image_target: t_target_image_data[0:batch_size] }) print("logits_real", file=f) print(logits_real, file=f) print("logits_fake", file=f) print(logits_fake, file=f) print("logits_fake2", file=f) print(logits_fake2, file=f) print("R %d total_d3_loss = %f" % (2**R, total_d3_loss)) print("init g or d end", file=f) #train GAN g_optim = g_optimizer_list[idx] d_optim = d_optimizer_list[idx] d_loss = d_loss_list[idx] g_loss = g_loss_list[idx] mse_loss = mse_loss_list[idx] g_gan_loss = g_gan_loss_list[idx] mix_rate, pic_rate = mix_rates[idx] increas = 2. / epoch_init mix_rate_vals = np.arange(0., 1. + increas, increas) last_errD = 0. last_errG = 0. if train_net: for epoch in range(epoch_init + 1): if epoch != 0 and (epoch % decay_every == 0): new_lr_decay = lr_decay**(epoch // decay_every) sess.run(tf.assign(lr_v, lr_init * new_lr_decay)) elif epoch == 0: sess.run(tf.assign(lr_v, lr_init)) #mix_mat = np.zeros([t_image_list[idx].shape[i] for i in range(1, 4)], dtype = 'float32') sess.run(tf.assign(mix_rate, 0)) sess.run(tf.assign(pic_rate, 0)) total_d_loss = 0 total_g_loss = 0 total_mse_loss = 0 iters, data, padding_nums = batch_data( t_target_image_data, image_padding_nums, max_idx, batch_size) total_pre_loss = np.zeros([2]) total_now_loss = np.zeros([2]) for i in range(iters): #update G if wgan: errG, errM, errA, _ = sess.run( [g_loss, mse_loss, g_gan_loss, g_optim], {t_image_target: data[i]}) #update D if True: #last_errG * 1e3 <= last_errD * 10: # D learning too fast flag = 1 errD, _ = sess.run([d_loss, d_optim], {t_image_target: data[i]}) #print("[%d/%d] epoch %d times d_loss : %f" % (epoch, epoch_init, i, errD)) #update G if not wgan: #print("train G") errG, errM, errA, _ = sess.run( [g_loss, mse_loss, 
g_gan_loss, g_optim], {t_image_target: data[i]}) #print("[%d/%d] epoch %d times, g_loss : %f, mse_loss : %f, g_gan_loss : %f" # % (epoch, epoch_init, i, errG, errM, errA)) #clip var_val if wgan: _ = sess.run(w_clip_list[idx]) last_errD = errD last_errG = errA total_d_loss += errD total_g_loss += errG total_mse_loss += errM if R == target_resolution_log2: #final steps lowR_pics, output_pics, GT_pics = sess.run( [ t_image_list[idx], net_g_test[idx].outputs, t_image_target_list[idx] ], {t_image_target: data[i]}) pre_lowR_pics = clip_pics(lowR_pics, padding_nums[i]) pre_output_pics = clip_pics( output_pics, padding_nums[i]) pre_GT_pics = clip_pics(GT_pics, padding_nums[i]) for ii in range(data[i].shape[0]): pre_loss = cal_loss(pre_lowR_pics[ii], pre_GT_pics[ii]) now_loss = cal_loss(pre_output_pics[ii], pre_GT_pics[ii]) total_pre_loss += pre_loss total_now_loss += now_loss pre_loss_list += [total_pre_loss / max_idx] now_loss_list += [total_now_loss / max_idx] print("lastD = %f, lastG = %f" % (last_errD, last_errG)) print("[%d/%d] epoch %d times d_loss : %f" % (epoch, epoch_init, i, errD)) print( "[%d/%d] epoch %d times, errM = %f, mse_loss : %f, g_gan_loss : %f" % (epoch, epoch_init, i, errM, total_mse_loss, errA)) #save genate pic if (epoch % strides == 0): print("save img %d" % R) out, logits_real, logits_fake, logits_fake2 = sess.run( [ net_g_test[idx].outputs, tf.nn.sigmoid(logits_reals[idx]), tf.nn.sigmoid(logits_fakes[idx]), tf.nn.sigmoid(logits_fakes2[idx]) ], { t_image_target: t_target_image_data[0:batch_size] }) print(out[0]) out = out.clip(0, 255) print(out.dtype) tl.vis.save_images( out, [ni, ni], output_dir + "R_%d_train_%d.png" % (R, epoch)) #increase the mix_rate from 0 to 1 linearly mix_rate_val = tf.nn.sigmoid(mix_rate).eval( session=sess) mix_pic_val = tf.nn.sigmoid(pic_rate).eval( session=sess) print("logits_real") print(logits_real) print("logits_fake") print(logits_fake) print("logits_fake2") print(logits_fake2) print("logits_real", file=f) 
print(logits_real, file=f) print("logits_fake", file=f) print(logits_fake, file=f) print("logits_fake2", file=f) print(logits_fake2, file=f) if (logits_real == logits_fake).all(): print("optimize well") print("optimize well", file=f) print("mix_rate, pic_rate") print(mix_rate_val, mix_pic_val) print("mix_rate, pic_rate", file=f) print(mix_rate_val, mix_pic_val, file=f) ## save model if (epoch != 0) and (epoch % 10 == 0): tl.files.save_npz_dict( net_Gs[idx].all_params, name=checkpoint_dir + ('/g_%d_level_{}.npz' % R).format( tl.global_flag['mode']), sess=sess) tl.files.save_npz_dict( net_d.all_params, name=checkpoint_dir + ('/d_%d_level_{}.npz' % R).format( tl.global_flag['mode']), sess=sess) save_list(top_dir + 'g_pre', pre_loss_list) save_list(top_dir + 'g_now', now_loss_list) pre_loss_list = [] now_loss_list = [] f.close() blend_output = b_outputs[idx] b_loss = b_loss_list[idx] b_optim = b_optimizer_list[idx] if not True: #fix lr_v sess.run(tf.assign(lr_v, lr_init)) for epoch in range(epoch_init * 3 + 1): iters, data, padding_nums = batch_data( t_target_image_data, image_padding_nums, max_idx, batch_size) for i in range(iters): errM, _ = sess.run([b_loss, b_optim], {t_image_target: data[i]}) total_mse_loss += errM print("[%d/%d] total_mse_loss = %f errM = %f" % (epoch, epoch_init, total_mse_loss, errM)) ## save model if (epoch % (strides * 3) == 0): print("save img %d" % R) out = sess.run(blend_output.outputs, { t_image_target: t_target_image_data[0:batch_size] }) out = out.clip(0, 255) #print(out[0]) print(out.dtype) tl.vis.save_images( out, [ni, ni], output_dir + "b_%d_output_%d.png" % (R, epoch)) if epoch % 100 == 0: tl.files.save_npz_dict( blend_output.all_params, name=checkpoint_dir + ('/b_%d_level_{}.npz' % R).format( tl.global_flag['mode']), sess=sess) logger.info("end training the net") if not train_net or generate_pics: if init_b: sess.run(tf.assign(lr_v, lr_init)) for epoch in range(epoch_init * 3 + 1): iters, data, padding_nums = batch_data( 
t_target_image_data, image_padding_nums, max_idx, batch_size) for i in range(iters): errM, _ = sess.run([b_loss, b_optim], {t_image_target: data[i]}) total_mse_loss += errM print("[%d/%d] total_mse_loss = %f errM = %f" % (epoch, epoch_init, total_mse_loss, errM)) ## save model if (epoch % (strides * 3) == 0): print("save img %d" % R) out = sess.run(blend_output.outputs, { t_image_target: t_target_image_data[0:batch_size] }) out = out.clip(0, 255) #print(out[0]) print(out.dtype) tl.vis.save_images( out, [ni, ni], output_dir + "b_%d_output_%d.png" % (R, epoch)) if epoch % 100 == 0: tl.files.save_npz_dict( blend_output.all_params, name=checkpoint_dir + ('/b_%d_level_{}.npz' % R).format( tl.global_flag['mode']), sess=sess) logger.info("load params") tl.files.load_and_assign_npz_dict(sess=sess, name=checkpoint_dir1 + '/g_%d_level_my_gan.npz' % R, network=net_Gs[-1]) tl.files.load_and_assign_npz_dict(sess=sess, name=checkpoint_dir1 + '/b_%d_level_my_gan.npz' % R, network=b_outputs[-1]) logger.info("read pics") test_set_dir = ["Set5/", "Set14/"] test_no = [5, 13] for j in range(2): data_pxl, pic_pad_nums = read_data(svg_dir + test_set_dir[j], num=test_no[j]) iters = data_pxl.shape[0] data_pxl = np.split(data_pxl, iters) #iters, data = batch_data((t_image_data, t_target_image_data), 100, batch_size) logger.info('start evaluating pics') for i in range(iters): print("save img %d" % R) out = sess.run(net_g_test[idx].outputs, {t_image_target: data_pxl[i]}) out = out.clip(0, 255) out = np.array([clip_pic(out[0], pic_pad_nums[i])]) tl.vis.save_images( out, [1, 1], test_output_dir + test_set_dir[j] + "g_%d_output_%d.png" % (R, i)) out = sess.run(b_outputs[idx].outputs, {t_image_target: data_pxl[i]}) out = out.clip(0, 255) out = np.array([clip_pic(out[0], pic_pad_nums[i])]) tl.vis.save_images( out, [1, 1], test_output_dir + test_set_dir[j] + "b_%d_output_%d.png" % (R, i)) out = sess.run(t_image, {t_image_target: data_pxl[i]}) out = np.array([clip_pic(out[0], pic_pad_nums[i])]) 
tl.vis.save_images( out, [1, 1], test_output_dir + test_set_dir[j] + "svg_%d_%d.png" % (R, i)) out = sess.run(t_image_target, {t_image_target: data_pxl[i]}) out = np.array([clip_pic(out[0], pic_pad_nums[i])]) tl.vis.save_images( out, [1, 1], test_output_dir + test_set_dir[j] + "pxl_%d_%d.png" % (R, i)) logger.info('end evaluating pics')
def train():
    """Train the SRResNet generator (SRGAN_g) alone with MSE + VGG perceptual loss.

    Reads grayscale HR training images from ``config.TRAIN.hr_img_path``,
    builds the generator and a VGG19 feature extractor, optimizes the
    generator with Adam, logs scalar summaries to TensorBoard writers and
    periodically saves sample outputs and generator checkpoints.

    Relies on module-level globals: ``batch_size``, ``ni``, ``lr_init``,
    ``beta1``, ``n_epoch``, ``decay_every``, ``lr_decay``, ``config``,
    ``SRGAN_g``, ``Vgg19_simple_api``, ``crop_sub_imgs_fn``, ``downsample_fn``.
    """
    ## create folders to save result images and trained model
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])  #srresnet
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine has enough memory, please pre-load the whole train set.
    print("reading images")
    train_hr_imgs = []
    for img__ in train_hr_img_list:
        # mode='L' loads as single-channel grayscale; reshape to HxWx1.
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path,
                                                      img__),
                                         mode='L')
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))
        train_hr_imgs.append(image_loaded)
    print(type(train_hr_imgs), len(train_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference: 56x56 LR input, 224x224 HR target (both 1-channel).
    t_image = tf.placeholder('float32', [batch_size, 56, 56, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1],
                                    name='t_target_image')
    print("t_image:", tf.shape(t_image))
    print("t_target_image:", tf.shape(t_target_image))

    net_g = SRGAN_g(t_image, is_train=True,
                    reuse=False)  #SRGAN_g is the SRResNet portion of the GAN
    print("net_g.outputs:", tf.shape(net_g.outputs))
    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg
    # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg

    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)
    print("net_g.outputs:", tf.shape(net_g.outputs))
    print("t_predict_image_224:", tf.shape(t_predict_image_224))

    # VGG expects inputs in [0, 1]; images here are in [-1, 1].
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss

    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## SRResNet
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss,
                                                                 var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir +
                                 '/g_{}.npz'.format(tl.global_flag['mode']),
                                 network=net_g)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)
    sample_imgs_224 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_224.shape, sample_imgs_224.min(),
          sample_imgs_224.max())
    sample_imgs_56 = tl.prepro.threading_data(sample_imgs_224,
                                              fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_56.shape, sample_imgs_56.min(),
          sample_imgs_56.max())
    tl.vis.save_images(sample_imgs_56, [ni, ni],
                       save_dir_gan + '/_train_sample_56.png')
    tl.vis.save_images(sample_imgs_224, [ni, ni],
                       save_dir_gan + '/_train_sample_224.png')
    #tl.vis.save_image(sample_imgs_96[0], save_dir_gan + '/_train_sample_96.png')
    #tl.vis.save_image(sample_imgs_384[0],save_dir_gan + '/_train_sample_384.png')

    ###========================= train SRResNet =========================###
    merged_summary_generator = tf.summary.merge(
        [mse_loss_summary, vgg_loss_summary, g_loss_summary])  #g_gan_loss_summary
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")
    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    count = 0
    for epoch in range(0, n_epoch + 1):
        ## update learning rate: stepwise decay every `decay_every` epochs.
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=(lr_init * new_lr_decay)),
                ]), (epoch))
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)
            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=lr_init),
                ]), (epoch))

        epoch_time = time.time()
        total_g_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        loss_per_batch = []
        mse_loss_summary_per_epoch = []
        vgg_loss_summary_per_epoch = []
        g_loss_summary_per_epoch = []
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_224 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_56 = tl.prepro.threading_data(b_imgs_224, fn=downsample_fn)
            summary_pb = tf.summary.Summary()
            ## update G
            errG, errM, errV, _, generator_summary = sess.run(
                [
                    g_loss, mse_loss, vgg_loss, g_optim,
                    merged_summary_generator
                ], {
                    t_image: b_imgs_56,
                    t_target_image: b_imgs_224
                })  #g_ga_loss
            # Decode the serialized summary so per-batch scalar values can be
            # averaged into per-epoch summaries below.
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(generator_summary)
            generator_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                # NOTE: tf.summary sanitizes tag names, so the spaces in the
                # scalar names above become underscores here.
                generator_summaries[val.tag] = val.simple_value
            mse_loss_summary_per_epoch.append(
                generator_summaries['Generator_MSE_loss'])
            vgg_loss_summary_per_epoch.append(
                generator_summaries['Generator_VGG_loss'])
            g_loss_summary_per_epoch.append(
                generator_summaries['Generator_total_loss'])
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errG,
                   errM, errV))
            total_g_loss += errG
            n_iter += 1

        # BUG FIX: was `time.time() - epoch_time / n_iter`, which divided the
        # start timestamp by the iteration count instead of reporting the
        # elapsed epoch time.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_g_loss / n_iter)
        print(log)

        #####
        # # logging generator summary
        # ######
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_MSE_loss per epoch",
                                 simple_value=np.mean(
                                     mse_loss_summary_per_epoch)),
            ]), (epoch))
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_VGG_loss per epoch",
                                 simple_value=np.mean(
                                     vgg_loss_summary_per_epoch)),
            ]), (epoch))
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_total_loss per epoch",
                                 simple_value=np.mean(
                                     g_loss_summary_per_epoch)),
            ]), (epoch))

        ## quick evaluation on the fixed sample batch
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_56})
        print("[*] save images")
        tl.vis.save_image(out[0], save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 3 == 0):
            tl.files.save_npz(
                net_g.all_params,
                name=checkpoint_dir +
                '/g_{}_{}.npz'.format(tl.global_flag['mode'], epoch),
                sess=sess)
def train():
    """Pretrain the U-Net generator with MSE (optionally + VGG perceptual) loss.

    Selects the input directory from ``config.TRAIN.input_type`` ("quan" or
    "clip"), pre-loads the whole train set, builds the generator (and VGG19
    when ``with_vgg``), then runs the init-G loop with a fixed/decayed
    learning rate, logging per-epoch MSE to ``loss.txt`` and periodically
    saving sample outputs and checkpoints.

    Relies on module-level globals: ``batch_size``, ``sub_img_size``, ``ni``,
    ``lr_init``, ``lr_decay``, ``beta1``, ``n_epoch_init``, ``with_mse``,
    ``with_vgg``, ``config``, ``unet``, ``Vgg19_simple_api``,
    ``crop_sub_imgs``, ``list_sub_imgs``.
    """
    import os, time

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    num_sub_imgs = config.Size.num_sub_imgs
    input_img_path = config.TRAIN.input_img_path
    if config.TRAIN.input_type == "quan":
        input_img_path = config.TRAIN.input_img_path + "_quan"
    elif config.TRAIN.input_type == "clip":
        input_img_path = config.TRAIN.input_img_path + "_clip"
    else:
        print("input_type error")
        return
    label_img_path = config.TRAIN.label_img_path

    ###====================== PRE-LOAD DATA ===========================###
    train_input_img_list = sorted(
        tl.files.load_file_list(path=input_img_path,
                                regx='.*.png',
                                printable=False))
    train_label_img_list = sorted(
        tl.files.load_file_list(path=label_img_path,
                                regx='.*.png',
                                printable=False))
    print('train_input_img_list : ', train_input_img_list)
    print('train_label_img_list : ', train_label_img_list)

    ## If your machine have enough memory, please pre-load the whole train set.
    train_input_imgs = tl.vis.read_images(train_input_img_list,
                                          path=input_img_path,
                                          n_threads=32)
    train_label_imgs = tl.vis.read_images(train_label_img_list,
                                          path=label_img_path,
                                          n_threads=32)

    ###========================== DEFINE MODEL ============================###
    ## train inference: each batch element is split into `num_sub_imgs` crops.
    t_image = tf.placeholder(
        'float32', [batch_size * num_sub_imgs, sub_img_size, sub_img_size, 3],
        name='t_image_input_to_generator')
    t_target_image = tf.placeholder(
        'float32', [batch_size * num_sub_imgs, sub_img_size, sub_img_size, 3],
        name='t_target_image')

    net_g = unet(t_image, reuse=False)
    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    if with_vgg:
        t_target_image_224 = tf.image.resize_images(
            t_target_image, size=[224, 224], method=0, align_corners=False
        )  # resize_target_image_for_vgg
        # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
        t_predict_image_224 = tf.image.resize_images(
            net_g.outputs, size=[224, 224], method=0,
            align_corners=False)  # resize_generate_image_for_vgg
        # VGG expects inputs in [0, 1]; images here are in [-1, 1].
        net_vgg, vgg_target_emb = Vgg19_simple_api(
            (t_target_image_224 + 1) / 2, reuse=False)
        _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                              reuse=True)

    ## test inference
    net_g_test = unet(t_image, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================##
    # Loss selection: pixel MSE, VGG perceptual MSE, or their weighted sum.
    if with_mse:
        mse_loss = tl.cost.mean_squared_error(net_g.outputs,
                                              t_target_image,
                                              is_mean=True)
        if with_vgg:
            vgg_loss = 5e-6 * tl.cost.mean_squared_error(
                vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
            mse_loss = mse_loss + vgg_loss
    else:
        if with_vgg:
            mse_loss = 2e-6 * tl.cost.mean_squared_error(
                vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    # NOTE(review): the generator is built by `unet`, yet variables are
    # collected by the 'SRGAN_g' scope name — presumably unet builds its
    # variables under that scope; verify against the model definition.
    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(
        mse_loss, var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/checkpoint.npz',
                                 network=net_g)

    ###============================= LOAD VGG ===============================###
    # BUG FIX: the vgg19.npy existence check and np.load previously ran (and
    # could exit() the process) even when with_vgg is False, although the
    # weights are only used in the with_vgg branch. Load only when needed.
    params = []
    if with_vgg:
        vgg19_npy_path = "vgg19.npy"
        if not os.path.isfile(vgg19_npy_path):
            print(
                "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
            )
            exit()
        npz = np.load(vgg19_npy_path, encoding='latin1').item()
        for val in sorted(npz.items()):
            W = np.asarray(val[1][0])
            b = np.asarray(val[1][1])
            print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
            params.extend([W, b])
        tl.files.assign_params(sess, params, net_vgg)
        net_vgg.print_params(False)
        net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    _sample_imgs_input = train_input_imgs[0:batch_size]
    _sample_imgs_label = train_label_imgs[0:batch_size]
    sample_imgs_input, sample_imgs_label = crop_sub_imgs(
        _sample_imgs_input, _sample_imgs_label)
    tl.vis.save_images(sample_imgs_input, [ni, ni],
                       save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_label, [ni, ni],
                       save_dir_ginit + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    lr_g = config.TRAIN.lr_g
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)

    # Truncate the loss log at the start of the run.
    with open("loss.txt", 'w') as f:
        f.write("")
    decay_g = config.TRAIN.decay_g
    init_time = time.time()
    for epoch in range(0, n_epoch_init + 1):
        ## stepwise learning-rate decay every `decay_g` epochs
        if epoch != 0 and (epoch % decay_g == 0):
            new_lr_decay = lr_decay**(epoch // decay_g)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)

        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        for idx in range(0, len(train_input_imgs), batch_size):
            step_time = time.time()
            b_imgs_input, b_imgs_label = list_sub_imgs(
                train_input_imgs[idx:idx + batch_size],
                train_label_imgs[idx:idx + batch_size])
            errM, _ = sess.run([mse_loss, g_optim_init], {
                t_image: b_imgs_input,
                t_target_image: b_imgs_label
            })
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time,
                   errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time,
            total_mse_loss / n_iter)
        print(log)
        with open("loss.txt", 'a') as f:
            f.write(str(total_mse_loss / n_iter) + "\n")

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_input
            })  # ; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_ginit + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir + '/g_{}_init.npz'.format(
                                  tl.global_flag['mode']),
                              sess=sess)

    # BUG FIX: was `init_time - time.time()`, which printed a negative
    # duration; elapsed seconds are end minus start.
    print("init complete %dsec" % (time.time() - init_time))
def train():
    """Train SRGAN on grayscale images (28x224 LR -> 224x224 HR).

    Resumes adversarial (GAN) training from 'checkpoint/main_10.ckpt' and
    runs epochs 11 .. n_epoch+10, logging per-epoch averaged losses to
    TensorBoard and saving sample outputs / checkpoints every 10 epochs.

    Relies on module-level globals: config, batch_size, lr_init, beta1,
    n_epoch, decay_every, lr_decay, tl.global_flag, the SRGAN_g/SRGAN_d/
    Vgg19_simple_api builders, crop_sub_imgs_fn and downsample_fn_mod.
    Side effects only (files, summaries, checkpoints); returns nothing.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.hr_img_path,
                                regx='.*.png',
                                printable=False))
    valid_lr_img_list = sorted(
        tl.files.load_file_list(path=config.VALID.lr_img_path,
                                regx='.*.png',
                                printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    print("reading images")
    train_hr_imgs = []  #[None] * len(train_hr_img_list)
    #sess = tf.Session()
    # Load HR images as single-channel (mode='L') and add an explicit
    # trailing channel axis so they match the [H, W, 1] placeholders below.
    for img__ in train_hr_img_list:
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path, img__), mode='L')
        image_loaded = image_loaded.reshape(
            (image_loaded.shape[0], image_loaded.shape[1], 1))
        train_hr_imgs.append(image_loaded)
    print(type(train_hr_imgs), len(train_hr_img_list))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    #t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    #t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')
    # Grayscale inputs: LR is 28x224x1, HR target is 224x224x1.
    t_image = tf.placeholder('float32', [batch_size, 28, 224, 1], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(
        'float32', [batch_size, 224, 224, 1], name='t_target_image'
    )  # may have to convert 224x224x1 into 224x224x3, with channel 1 & 2 as 0. May have to have separate place-holder ?
    # NOTE(review): tf.shape() returns a Tensor, so these prints show the
    # symbolic op, not concrete dimensions — debug leftovers.
    print("t_image:", tf.shape(t_image))
    print("t_target_image:", tf.shape(t_target_image))
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    print("net_g.outputs:", tf.shape(net_g.outputs))
    # Discriminator on real HR images, then reused (shared weights) on fakes.
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)
    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg
    # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg
    ## Added as VGG works for RGB and expects 3 channels.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)
    print("net_g.outputs:", tf.shape(net_g.outputs))
    print("t_predict_image_224:", tf.shape(t_predict_image_224))
    # (x + 1) / 2 maps images from [-1, 1] into [0, 1] before the VGG graph.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)
    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Standard SRGAN losses: D classifies real vs fake; G combines pixel MSE,
    # VGG feature MSE, and a small adversarial term.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real,
                                            tf.ones_like(logits_real),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake,
                                            tf.zeros_like(logits_fake),
                                            name='d2')
    d_loss = d_loss1 + d_loss2
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    # Scalar summaries. TF sanitizes names (spaces -> underscores), which is
    # why the parsing code below looks up e.g. 'Disciminator_logits_real_loss'.
    d_loss1_summary = tf.summary.scalar('Disciminator logits_real loss', d_loss1)
    d_loss2_summary = tf.summary.scalar('Disciminator logits_fake loss', d_loss2)
    d_loss_summary = tf.summary.scalar('Disciminator total loss', d_loss)
    g_gan_loss_summary = tf.summary.scalar('Generator GAN loss', g_gan_loss)
    mse_loss_summary = tf.summary.scalar('Generator MSE loss', mse_loss)
    vgg_loss_summary = tf.summary.scalar('Generator VGG loss', vgg_loss)
    g_loss_summary = tf.summary.scalar('Generator total loss', g_loss)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    # UNCOMMENT THE LINE BELOW!!!
    #g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    #if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False:
    #    tl.fites.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)
    #tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    # The .npy holds a dict layer_name -> [weights, biases]; sorted() gives a
    # deterministic order matching net_vgg's parameter list.
    npz = np.load(vgg19_npy_path, encoding='latin1').item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs,
                                               fn=crop_sub_imgs_fn,
                                               is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape,
          sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,
                                              fn=downsample_fn_mod)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    #tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    #tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    #tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    #tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')
    # The whole G-pretraining phase below is disabled (dead string literal);
    # it also depends on g_optim_init, which is commented out above.
    '''
    ###========================= initialize G ====================###
    merged_summary_initial_G = tf.summary.merge([mse_loss_summary])
    summary_intial_G_writer = tf.summary.FileWriter("./log/train/initial_G")
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    count = 0
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
        ## If your machine have enough memory, please pre-load the whole train set.
        intial_MSE_G_summary_per_epoch = []
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn_mod)
            ## update G
            errM, _, mse_summary_initial_G = sess.run([mse_loss, g_optim_init, merged_summary_initial_G], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(mse_summary_initial_G)
            intial_G_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                intial_G_summaries[val.tag] = val.simple_value
            #print("intial_G_summaries:", intial_G_summaries)
            intial_MSE_G_summary_per_epoch.append(intial_G_summaries['Generator_MSE_loss'])
            #summary_intial_G_writer.add_summary(mse_summary_initial_G, (count + 1)) #(epoch + 1)*(n_iter+1))
            #count += 1
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)
        summary_intial_G_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag="Generator_Initial_MSE_loss per epoch", simple_value=np.mean(intial_MSE_G_summary_per_epoch)),]), (epoch))
        ## quick evaluation on train set
        #if (epoch != 0) and (epoch % 10 == 0):
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96}) #; print('gen sub-image:', out.shape, out.min(), out.max())
        print("[*] save images")
        for im in range(len(out)):
            if(im%4==0 or im==1197):
                tl.vis.save_image(out[im], save_dir_ginit + '/train_%d_%d.png' % (epoch,im))
        ## save model
        saver=tf.train.Saver()
        if (epoch%10==0 and epoch!=0):
            saver.save(sess, 'checkpoint/init_'+str(epoch)+'.ckpt')
        #if (epoch != 0) and (epoch % 10 == 0):
            #tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_{}_init.npz'.format(tl.global_flag['mode'], epoch), sess=sess)
    '''
    ###========================= train GAN (SRGAN) =========================###
    # Hard-coded resume: restore epoch-10 checkpoint and continue at epoch 11.
    saver = tf.train.Saver()
    saver.restore(sess, 'checkpoint/main_10.ckpt')
    print('Restored main_10, begin 11/50')
    merged_summary_discriminator = tf.summary.merge(
        [d_loss1_summary, d_loss2_summary, d_loss_summary])
    summary_discriminator_writer = tf.summary.FileWriter(
        "./log/train/discriminator")
    merged_summary_generator = tf.summary.merge([
        g_gan_loss_summary, mse_loss_summary, vgg_loss_summary, g_loss_summary
    ])
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")
    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")
    count = 0
    for epoch in range(11, n_epoch + 11):
        ## update learning rate
        # NOTE(review): epoch starts at 11 here, so `epoch != 0` is always
        # true and the `elif epoch == 0` branch below is unreachable.
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=(lr_init * new_lr_decay)),
                ]), (epoch))
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)
            learning_rate_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag="Learning_rate per epoch",
                                     simple_value=lr_init),
                ]), (epoch))
        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
        ## If your machine have enough memory, please pre-load the whole train set.
        # Per-batch summary values, averaged at epoch end for TensorBoard.
        loss_per_batch = []
        d_loss1_summary_per_epoch = []
        d_loss2_summary_per_epoch = []
        d_loss_summary_per_epoch = []
        g_gan_loss_summary_per_epoch = []
        mse_loss_summary_per_epoch = []
        vgg_loss_summary_per_epoch = []
        g_loss_summary_per_epoch = []
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx +
                                                                batch_size],
                                                  fn=crop_sub_imgs_fn,
                                                  is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384,
                                                 fn=downsample_fn_mod)
            ## update D
            errD, _, discriminator_summary = sess.run(
                [d_loss, d_optim, merged_summary_discriminator], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            # Parse the serialized summary protobuf back into tag -> value so
            # the scalar losses can be averaged per epoch below.
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(discriminator_summary)
            #print("discriminator_summary", summary_pb, type(summary_pb))
            discriminator_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                discriminator_summaries[val.tag] = val.simple_value
            d_loss1_summary_per_epoch.append(
                discriminator_summaries['Disciminator_logits_real_loss'])
            d_loss2_summary_per_epoch.append(
                discriminator_summaries['Disciminator_logits_fake_loss'])
            d_loss_summary_per_epoch.append(
                discriminator_summaries['Disciminator_total_loss'])
            ## update G
            errG, errM, errV, errA, _, generator_summary = sess.run(
                [
                    g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim,
                    merged_summary_generator
                ], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            summary_pb = tf.summary.Summary()
            summary_pb.ParseFromString(generator_summary)
            #print("generator_summary", summary_pb, type(summary_pb))
            generator_summaries = {}
            for val in summary_pb.value:
                # Assuming all summaries are scalars.
                generator_summaries[val.tag] = val.simple_value
            #print("generator_summaries:", generator_summaries)
            g_gan_loss_summary_per_epoch.append(
                generator_summaries['Generator_GAN_loss'])
            mse_loss_summary_per_epoch.append(
                generator_summaries['Generator_MSE_loss'])
            vgg_loss_summary_per_epoch.append(
                generator_summaries['Generator_VGG_loss'])
            g_loss_summary_per_epoch.append(
                generator_summaries['Generator_total_loss'])
            #summary_generator_writer.add_summary(generator_summary, (count + 1))
            #summary_total = sess.run(summary_total_merged, {t_image: b_imgs_96, t_target_image: b_imgs_384})
            #summary_total_merged_writer.add_summary(summary_total, (count + 1))
            #count += 1
            tot_epoch = n_epoch + 10
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                % (epoch, tot_epoch, n_iter, time.time() - step_time, errD,
                   errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1
        #remove this for normal running:
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, tot_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)
        #####
        # # logging discriminator summary #
        ######
        # logging per epcoch summary of logit_real_loss per epoch. Value logged is averaged across batches used per epoch.
        summary_discriminator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Disciminator_logits_real_loss per epoch",
                                 simple_value=np.mean(
                                     d_loss1_summary_per_epoch)),
            ]), (epoch))
        # logging per epcoch summary of logit_fake_loss per epoch. Value logged is averaged across batches used per epoch.
        summary_discriminator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Disciminator_logits_fake_loss per epoch",
                                 simple_value=np.mean(
                                     d_loss2_summary_per_epoch)),
            ]), (epoch))
        # logging per epcoch summary of total_loss per epoch. Value logged is averaged across batches used per epoch.
        summary_discriminator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Disciminator_total_loss per epoch",
                                 simple_value=np.mean(
                                     d_loss_summary_per_epoch)),
            ]), (epoch))
        #####
        # # logging generator summary #
        ######
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_GAN_loss per epoch",
                                 simple_value=np.mean(
                                     g_gan_loss_summary_per_epoch)),
            ]), (epoch))
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_MSE_loss per epoch",
                                 simple_value=np.mean(
                                     mse_loss_summary_per_epoch)),
            ]), (epoch))
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_VGG_loss per epoch",
                                 simple_value=np.mean(
                                     vgg_loss_summary_per_epoch)),
            ]), (epoch))
        summary_generator_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="Generator_total_loss per epoch",
                                 simple_value=np.mean(
                                     g_loss_summary_per_epoch)),
            ]), (epoch))
        ## quick evaluation on train set
        #if (epoch != 0) and (epoch % 10 == 0):
        out = sess.run(
            net_g_test.outputs, {t_image: sample_imgs_96
                                 })  #; print('gen sub-image:', out.shape, out.min(), out.max())
        ## save model
        if (epoch % 10 == 0 and epoch != 0):
            saver.save(sess, 'checkpoint/main_' + str(epoch) + '.ckpt')
        print("[*] save images")
        for im in range(len(out)):
            tl.vis.save_image(
                out[im], save_dir_gan + '/train_%d_%d.png' % (epoch, im))
def train_distil():
    """Train the student SRGAN by knowledge distillation from a teacher.

    Builds the student G/D networks plus small "predict" adapter networks
    whose outputs are regressed onto the (frozen) teacher's intermediate
    activations; the resulting distillation losses (g0/d1/d2) are added to
    the standard SRGAN adversarial/MSE/VGG losses. Runs the GAN training
    loop, periodically saving checkpoints, loss histories and sample images.

    Relies on module-level globals: config, batch_size, lr_init, beta1,
    n_epoch, decay_every, lr_decay, ni, small_teacher, train_all_nine, the
    SRGAN_* network builders, Vgg19_simple_api, threading_data_2, crop2,
    write_losses and evaluate. Side effects only; returns nothing.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/student_ginit"
    save_dir_gan = "samples/student_gan"
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)
    # Per-epoch averaged loss histories, dumped via write_losses() below.
    d_losses, g_losses, m_losses, v_losses, a_losses = [], [], [], [], []
    g0losses, d1losses, d2losses = [], [], []

    ###====================== PRE-LOAD IMAGE DATA ===========================###
    print("loading images")
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path,
                                regx='.*.png',
                                printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path,
                                regx='.*.png',
                                printable=False))
    # train_hr_img_list = train_hr_img_list[0:16]
    # train_lr_img_list = train_lr_img_list[0:16]
    train_hr_imgs = tl.vis.read_images(train_hr_img_list,
                                       path=config.TRAIN.hr_img_path,
                                       n_threads=32)
    train_lr_imgs = tl.vis.read_images(train_lr_img_list,
                                       path=config.TRAIN.lr_img_path,
                                       n_threads=32)
    print("images loaded")

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3],
                             name='t_image')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3],
                                    name='t_target_image')
    # t_distilled_d = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')
    # t_distilled_g = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')
    ## Student networks; *_distil outputs are the layers used for distillation.
    net_g_student, net_g_student_distil = SRGAN_g_student(t_image,
                                                          is_train=True,
                                                          reuse=False)
    net_d_student, logits_real_student, net_d_student_distil = SRGAN_d_student(
        t_target_image, is_train=True, reuse=False)
    net_d_student_fake, logits_fake_student, net_d_student_distil_fake = SRGAN_d_student(
        net_g_student.outputs, is_train=True, reuse=True)
    ## Teacher networks (frozen: is_train=False).
    # BUGFIX: this condition previously read `small_techer`, an undefined
    # name (NameError at runtime); the flag is spelled `small_teacher`
    # everywhere else in this function.
    if small_teacher is True and train_all_nine is False:
        net_g_teacher_distil = SRGAN_g_teacher_small(t_image,
                                                     is_train=False,
                                                     reuse=False)
        net_d_teacher_distil = SRGAN_d_teacher_small(t_target_image,
                                                     is_train=False,
                                                     reuse=False)
        net_d_teacher_distil_fake = SRGAN_d_teacher_small(
            net_g_student.outputs, is_train=False, reuse=True)
    else:
        net_g_teacher, net_g_teacher_distil = SRGAN_g_teacher(t_image,
                                                              is_train=False,
                                                              reuse=False)
        net_d_teacher, _, net_d_teacher_distil = SRGAN_d_teacher(
            t_target_image, is_train=False, reuse=False)
        net_d_teacher_fake, _, net_d_teacher_distil_fake = SRGAN_d_teacher(
            net_g_student.outputs, is_train=False, reuse=True)
    ## Adapter ("predict") networks mapping student activations towards the
    ## teacher's. The three direct adapters are always built; the six cross
    ## adapters only in train_all_nine mode. (Previously both branches
    ## duplicated the first three constructions verbatim.)
    net_g0_predict, _ = SRGAN_g0_predict(net_g_student_distil.outputs,
                                         is_train=True,
                                         reuse=False)
    net_d1_predict, _ = SRGAN_d1_predict(net_d_student_distil.outputs,
                                         is_train=True,
                                         reuse=False)
    net_d2_predict, _ = SRGAN_d2_predict(net_d_student_distil_fake.outputs,
                                         is_train=True,
                                         reuse=False)
    if train_all_nine is True:
        net_d1d2_predict, _ = SRGAN_d1d2_predict(net_d_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
        net_d2d1_predict, _ = SRGAN_d2d1_predict(
            net_d_student_distil_fake.outputs, is_train=True, reuse=False)
        net_g0d1_predict, _ = SRGAN_g0d1_predict(net_g_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
        net_g0d2_predict, _ = SRGAN_g0d2_predict(net_g_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
        net_d1g0_predict, _ = SRGAN_d1g0_predict(net_d_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
        net_d2g0_predict, _ = SRGAN_d2g0_predict(net_d_student_distil.outputs,
                                                 is_train=True,
                                                 reuse=False)
    net_g_student.print_params(False)
    net_g_student.print_layers()
    net_d_student.print_params(False)
    net_d_student.print_layers()
    # net_g_student_distil_fake.print_params(False)
    # net_g_student_distil_fake.print_layers()
    # net_d_student_distil_fake.print_params(False)
    # net_d_student_distil_fake.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False
    )  # resize_target_image_for_vgg
    # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(
        net_g_student.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2,
                                               reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2,
                                          reuse=True)
    ## test inference
    # NOTE(review): is_train=True here looks suspicious for a test-time
    # network (other train() variants use is_train=False) — confirm whether
    # batch-norm should run in inference mode.
    net_g_test, _ = SRGAN_g_student(t_image, is_train=True, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real_student,
                                            tf.ones_like(logits_real_student),
                                            name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(
        logits_fake_student, tf.zeros_like(logits_fake_student), name='d2')
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(
        logits_fake_student, tf.ones_like(logits_fake_student), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g_student.outputs,
                                          t_target_image,
                                          is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    ## Distillation losses: adapter predictions vs. frozen teacher activations.
    ## In train_all_nine mode each teacher target is matched by the mean of
    ## the three adapters that predict it.
    if train_all_nine is not True:
        g0_loss = 4e-3 * tl.cost.mean_squared_error(
            net_g0_predict.outputs, net_g_teacher_distil.outputs, is_mean=True)
        d1_loss = 1 / 5 * tl.cost.mean_squared_error(
            net_d1_predict.outputs, net_d_teacher_distil.outputs, is_mean=True)
        d2_loss = 1 / 5 * tl.cost.mean_squared_error(
            net_d2_predict.outputs,
            net_d_teacher_distil_fake.outputs,
            is_mean=True)
    else:
        g0_loss = 4e-3 * tl.cost.mean_squared_error(
            (net_g0_predict.outputs + net_d1g0_predict.outputs +
             net_d2g0_predict.outputs) / 3,
            net_g_teacher_distil.outputs,
            is_mean=True)
        d1_loss = 1 / 5 * tl.cost.mean_squared_error(
            (net_g0d1_predict.outputs + net_d1_predict.outputs +
             net_d2d1_predict.outputs) / 3,
            net_d_teacher_distil.outputs,
            is_mean=True)
        d2_loss = 1 / 5 * tl.cost.mean_squared_error(
            (net_g0d2_predict.outputs + net_d1d2_predict.outputs +
             net_d2_predict.outputs) / 3,
            net_d_teacher_distil_fake.outputs,
            is_mean=True)
    d_loss = d_loss1 + d_loss2 + d1_loss + d2_loss
    g_loss = mse_loss + vgg_loss + g_gan_loss + g0_loss
    g_vars = tl.layers.get_variables_with_name('SRGAN_g_student', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d_student', True, True)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    # ## Pretrain
    # g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(g_loss,
                                                           var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v,
                                     beta1=beta1).minimize(d_loss,
                                                           var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    if tl.files.load_and_assign_npz(sess=sess,
                                    name=checkpoint_dir +
                                    '/g_srgan_student.npz',
                                    network=net_g_student) is False:
        pass
        # tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan_init.npz', network=net_g_student)
    tl.files.load_and_assign_npz(sess=sess,
                                 name=checkpoint_dir + '/d_srgan_student.npz',
                                 network=net_d_student)
    if small_teacher is True:
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/g_small_teacher_bicube.npz',
                                     network=net_g_teacher_distil)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/d_small_teacher_bicube.npz',
                                     network=net_d_teacher_distil)
    else:
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/g_srgan_teacher.npz',
                                     network=net_g_teacher)
        tl.files.load_and_assign_npz(sess=sess,
                                     name=checkpoint_dir +
                                     '/d_srgan_teacher.npz',
                                     network=net_d_teacher)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print(
            "Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg"
        )
        exit()
    # The .npy holds a dict layer_name -> [weights, biases]; sorted() gives a
    # deterministic order matching net_vgg's parameter list.
    npz = np.load(vgg19_npy_path, encoding='latin1').item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= PreSample ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    # sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = train_lr_imgs[0:batch_size]
    # # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
    sample_imgs_384, sample_imgs_96 = threading_data_2(
        (train_hr_imgs[0:batch_size], train_lr_imgs[0:batch_size]),
        fn=crop2,
        is_random=True)
    # sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=True)
    print('sample HR sub-image:', sample_imgs_384.shape,
          sample_imgs_384.min(), sample_imgs_384.max())
    # sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(),
          sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni],
                       save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni],
                       save_dir_gan + '/_train_sample_384.png')

    ###========================= train GAN (SRGAN) =========================###
    print("starting")
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init *
                                                           new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)
        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        total_errM, total_errV, total_errA = 0, 0, 0
        total_erg0, total_erd1, total_erd2 = 0, 0, 0
        ## Actual training
        for idx in range(0, len(train_hr_imgs), batch_size):
            #using 4x lowresolution images
            b_imgs_384, b_imgs_96 = threading_data_2(
                (train_hr_imgs[idx:idx + batch_size],
                 train_lr_imgs[idx:idx + batch_size]),
                fn=crop2,
                is_random=True)
            #for 4x high resolution images
            # if n_iter==0 and epoch==0: tl.vis.save_images(b_imgs_384, [ni, ni], save_dir_gan + '/original_train_384_%d.png' % epoch)
            # if n_iter==0 and epoch==0: tl.vis.save_images(b_imgs_96, [ni, ni], save_dir_gan + '/original_train_96_%d.png' % epoch)
            step_time = time.time()
            ## update D
            errD, erg0, erd1, erd2, _ = sess.run(
                [d_loss, g0_loss, d1_loss, d2_loss, d_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {
                    t_image: b_imgs_96,
                    t_target_image: b_imgs_384
                })
            total_d_loss += errD
            total_g_loss += errG
            total_errM += errM
            total_errV += errV
            total_errA += errA
            total_erg0 += erg0
            total_erd1 += erd1
            total_erd2 += erd2
            n_iter += 1
            print(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)(erg0: %.6f erd1: %.6f erd2: %.6f)"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD,
                   errG, errM, errV, errA, erg0, erd1, erd2))
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
            total_g_loss / n_iter)
        print(log)
        ## write losses (per-epoch means)
        d_losses += [total_d_loss / n_iter]
        g_losses += [total_g_loss / n_iter]
        m_losses += [total_errM / n_iter]
        v_losses += [total_errV / n_iter]
        a_losses += [total_errA / n_iter]
        g0losses += [total_erg0 / n_iter]
        d1losses += [total_erd1 / n_iter]
        d2losses += [total_erd2 / n_iter]
        if epoch % 20 == 0:
            # Periodic epoch-stamped snapshots plus loss-history dumps.
            tl.files.save_npz(net_g_student.all_params,
                              name=checkpoint_dir +
                              '/g_srgan_student_%d.npz' % epoch,
                              sess=sess)
            tl.files.save_npz(net_d_student.all_params,
                              name=checkpoint_dir +
                              '/d_srgan_student_%d.npz' % epoch,
                              sess=sess)
            write_losses("d_losses", d_losses)
            write_losses("g_losses", g_losses)
            write_losses("m_losses", m_losses)
            write_losses("v_losses", v_losses)
            write_losses("a_losses", a_losses)
            write_losses("g0losses", g0losses)
            write_losses("d1losses", d1losses)
            write_losses("d2losses", d2losses)
        if epoch % 10 == 0:
            # out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96}) #; print('gen sub-image:', out.shape, out.min(), out.max())
            tl.files.save_npz(net_g_student.all_params,
                              name=checkpoint_dir + '/g_srgan_student.npz',
                              sess=sess)
            tl.files.save_npz(net_d_student.all_params,
                              name=checkpoint_dir + '/d_srgan_student.npz',
                              sess=sess)
            if small_teacher is not True:
                # Export the teacher's distillation sub-networks so they can
                # later be loaded as the "small teacher" checkpoints above.
                tl.files.save_npz(net_g_teacher_distil.all_params,
                                  name=checkpoint_dir +
                                  '/g_small_teacher_bicube.npz',
                                  sess=sess)
                tl.files.save_npz(net_d_teacher_distil.all_params,
                                  name=checkpoint_dir +
                                  '/d_small_teacher_bicube.npz',
                                  sess=sess)
            evaluate()
        ## quick evaluation on train set
        if (epoch % 5 == 0):
            out = sess.run(net_g_test.outputs, {
                t_image: sample_imgs_96
            })  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni],
                               save_dir_gan + '/train_%d.png' % epoch)
def train():
    """Pre-train (initialize) the UMSR generator with MSE + VGG + Gram losses.

    Reads 30x30 LR / 120x120 HR training crops derived from `config.TRAIN`,
    builds the UMSR_g generator and a shared VGG19 feature extractor, and
    runs the generator-initialization phase only (no adversarial training
    here). Checkpoints and sample images are written under `checkpoint/`
    and `samples/<mode>_ginit/`.

    Relies on module-level globals: batch_size, lr_init, beta1, beta2,
    n_epoch_init, decay_every_init, lr_decay_init, ni, crop_sub_imgs_fn,
    downsample_fn, UMSR_g, Vgg19_simple_api, gram_scale_loss1/2, config, tl, tf.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    # BUG FIX: save_dir_gan was used below without ever being assigned,
    # which raised NameError on the very first call. Define it like ginit.
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    ## pre-load the whole train set (requires enough host memory).
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)

    ###========================== DEFINE MODEL ============================###
    ## train inference: fixed 30x30 -> 120x120 (4x) crops
    t_image = tf.placeholder('float32', [batch_size, 30, 30, 3], name='t_image_input_to_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 120, 120, 3], name='t_target_image')
    net_g = UMSR_g(t_image, is_train=True, reuse=False)
    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. resize method codes: 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0,
        align_corners=False)  # resize_target_image_for_vgg
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0,
        align_corners=False)  # resize_generate_image_for_vgg
    # Images are in [-1, 1]; VGG expects [0, 1], hence (x + 1) / 2.
    net_vgg, vgg_target_emb1, vgg_target_emb2, vgg_target_emb3, vgg_target_emb4, vgg_target_emb5 = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb1, vgg_predict_emb2, vgg_predict_emb3, vgg_predict_emb4, vgg_predict_emb5 = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference (shares weights with the training graph)
    net_g_test = UMSR_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb5.outputs, vgg_target_emb5.outputs, is_mean=True)
    # Gram (style) losses on shallow (emb1) and mid-level (emb3) VGG features.
    gram_loss1 = 1e-6 * gram_scale_loss1(vgg_target_emb1.outputs, vgg_predict_emb1.outputs)
    gram_loss2 = 1e-6 * gram_scale_loss2(vgg_target_emb3.outputs, vgg_predict_emb3.outputs)
    gram_loss = gram_loss1 + gram_loss2
    g1_loss = mse_loss + vgg_loss + gram_loss

    g_vars = tl.layers.get_variables_with_name('UMSR_g', True, True)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain optimizer (generator initialization only)
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1, beta2=beta2).minimize(g1_loss, var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # NOTE(review): this restore name has no epoch suffix, while the periodic
    # save below embeds the epoch — confirm which file should be resumed from.
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    # allow_pickle=True is required on NumPy >= 1.16.3 for object arrays.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32)  # if no pre-load train set
    sample_imgs_120 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    sample_imgs_30 = tl.prepro.threading_data(sample_imgs_120, fn=downsample_fn)
    tl.vis.save_images(sample_imgs_30, [ni, ni], save_dir_ginit + '/_train_sample_30.png')
    tl.vis.save_images(sample_imgs_120, [ni, ni], save_dir_ginit + '/_train_sample_120.png')

    ###========================= initialize G ====================###
    for epoch in range(0, n_epoch_init + 1):
        # Step-decay the learning rate every `decay_every_init` epochs.
        if epoch != 0 and (epoch % decay_every_init == 0):
            new_lr_decay_init = lr_decay_init**(epoch // decay_every_init)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay_init))
            log = " ** new learning rate: %f (for Generator)" % (lr_init * new_lr_decay_init)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for Generator)" % (lr_init, decay_every_init, lr_decay_init)
            print(log)

        epoch_time = time.time()
        total_g1_loss, n_iter = 0, 0
        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            # Fixed-size random crops so inputs fit the placeholder shapes.
            b_imgs_120 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_30 = tl.prepro.threading_data(b_imgs_120, fn=downsample_fn)
            ## update G
            errG1, _ = sess.run([g1_loss, g_optim_init], {t_image: b_imgs_30, t_target_image: b_imgs_120})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, g1: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errG1))
            total_g1_loss += errG1
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g1: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_g1_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 50 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_30})  # ; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)

        ## append the average init loss to a plain-text log
        if (epoch != 0) and (epoch % 20 == 0):
            average_lossG1 = total_g1_loss / n_iter
            with open('testG1.text', 'a') as f:
                f.write(str(average_lossG1) + '\n')

        ## save model
        if (epoch != 0) and (epoch % 500 == 0):
            # BUG FIX: the original name '/g_%{}_init.npz'.format(mode) % epoch
            # mixed %-formatting with str.format, leaving a stray '%<char>'
            # conversion that raised ValueError (or silently mangled the name)
            # at save time. Embed the epoch with str.format instead.
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init_{}.npz'.format(tl.global_flag['mode'], epoch), sess=sess)
def train(train_lr_path,
          train_hr_path,
          save_path,
          save_every_epoch=2,
          validation=True,
          ratio=0.9,
          batch_size=16,
          lr_init=1e-4,
          beta1=0.9,
          n_epoch_init=10,
          n_epoch=20,
          lr_decay=0.1):
    '''Train SRGAN: MSE pre-training of G, then adversarial G/D training.

    Parameters:
        data:
            train_lr_path / train_hr_path: paths of LR / HR training images
            save_path: the parent folder to save model results
            validation: whether to split data into train set and validation set
            ratio: fraction of the data kept for training when validation=True
            save_every_epoch: how frequently to save checkpoints and sample images
        Adam:
            batch_size, lr_init, beta1
        Generator Initialization:
            n_epoch_init
        Adversarial Net:
            n_epoch, lr_decay
    '''
    ## Folders to save results
    save_dir_ginit = os.path.join(save_path, 'srgan_ginit')
    save_dir_gan = os.path.join(save_path, 'srgan_gan')
    checkpoint_dir = os.path.join(save_path, 'checkpoint')
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###======LOAD DATA======###
    train_lr_img_list = sorted(tl.files.load_file_list(path=train_lr_path, regx='.*.jpg', printable=False))
    train_hr_img_list = sorted(tl.files.load_file_list(path=train_hr_path, regx='.*.jpg', printable=False))
    if validation:
        # Random train/validation split over file indices; lists stay
        # index-aligned because LR and HR lists are filtered identically.
        idx = np.random.choice(len(train_lr_img_list), int(len(train_lr_img_list) * ratio), replace=False)
        valid_lr_img_list = sorted([x for i, x in enumerate(train_lr_img_list) if i not in idx])
        valid_hr_img_list = sorted([x for i, x in enumerate(train_hr_img_list) if i not in idx])
        train_lr_img_list = sorted([x for i, x in enumerate(train_lr_img_list) if i in idx])
        train_hr_img_list = sorted([x for i, x in enumerate(train_hr_img_list) if i in idx])
        valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=train_lr_path, n_threads=32)
        valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=train_hr_path, n_threads=32)
    train_lr_imgs = tl.vis.read_images(train_lr_img_list, path=train_lr_path, n_threads=32)
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=train_hr_path, n_threads=32)

    ###======DEFINE MODEL======###
    ## train inference: 96x96 LR -> 192x192 HR (2x)
    lr_image = tf.placeholder('float32', [None, 96, 96, 3], name='lr_image')
    hr_image = tf.placeholder('float32', [None, 192, 192, 3], name='hr_image')
    net_g = SRGAN_g(lr_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(hr_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)
    # net_g.print_params(False)
    # net_g.print_layers()
    # net_d.print_params(False)
    # net_d.print_layers()

    ## resize original hr images for VGG19
    hr_image_224 = tf.image.resize_images(
        hr_image,
        size=[224, 224],
        method=0,  # BICUBIC
        align_corners=False)
    ## generated hr image for VGG19
    generated_image_224 = tf.image.resize_images(
        net_g.outputs,
        size=[224, 224],
        method=0,  # BICUBIC
        align_corners=False)
    ## scale image from [-1,1] to [0,1] and get conv characteristics
    net_vgg, vgg_target_emb = Vgg19_simple_api((hr_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((generated_image_224 + 1) / 2, reuse=True)

    ## test inference (shared weights, inference mode)
    net_g_test = SRGAN_g(lr_image, is_train=False, reuse=True)

    ###======DEFINE TRAIN PROCESS======###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # Discriminator accuracy. BUG FIX: the tensors are raw logits (they feed
    # sigmoid_cross_entropy above), so the sigmoid(p) > 0.5 decision boundary
    # is logits > 0.0 — the original thresholded logits at 0.5, skewing the
    # reported accuracy.
    prediction1 = tf.greater(logits_real, tf.zeros_like(logits_real))
    acc_metric1 = tf.reduce_mean(tf.cast(prediction1, tf.float32))
    prediction2 = tf.less(logits_fake, tf.zeros_like(logits_fake))
    acc_metric2 = tf.reduce_mean(tf.cast(prediction2, tf.float32))
    acc_metric = acc_metric1 + acc_metric2  # sum of both; callers divide by 2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, hr_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + g_gan_loss + vgg_loss
    # Images span [-1, 1], hence max_val=2.0.
    psnr_metric = tf.image.psnr(net_g.outputs, hr_image, max_val=2.0, name='psnr')

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    sess.run(tf.global_variables_initializer())
    # Prefer the GAN-phase generator checkpoint, fall back to the init one.
    if tl.files.file_exists(os.path.join(checkpoint_dir, 'g_srgan.npz')):
        tl.files.load_and_assign_npz(sess=sess, name=os.path.join(checkpoint_dir, 'g_srgan.npz'), network=net_g)
    else:
        tl.files.load_and_assign_npz(sess=sess, name=os.path.join(checkpoint_dir, 'g_srgan_init.npz'), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=os.path.join(checkpoint_dir, 'd_srgan.npz'), network=net_d)

    ###======LOAD VGG======###
    vgg19_npy_path = '../lib/SRGAN/vgg19.npy'
    # allow_pickle=True is required on NumPy >= 1.16.3 for object arrays.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()

    ###======TRAINING======###
    ## use train set to have a quick test during training
    ni = 4
    num_sample = ni * ni
    idx = np.random.choice(len(train_lr_imgs), num_sample, replace=False)
    sample_imgs_lr = tl.prepro.threading_data(
        [img for i, img in enumerate(train_lr_imgs) if i in idx],
        fn=crop_sub_imgs_fn, size=(96, 96), is_random=False)
    sample_imgs_hr = tl.prepro.threading_data(
        [img for i, img in enumerate(train_hr_imgs) if i in idx],
        fn=crop_sub_imgs_fn, size=(192, 192), is_random=False)
    print('sample LR sub-image:', sample_imgs_lr.shape, sample_imgs_lr.min(), sample_imgs_lr.max())
    print('sample HR sub-image:', sample_imgs_hr.shape, sample_imgs_hr.min(), sample_imgs_hr.max())
    ## save the images
    tl.vis.save_images(sample_imgs_lr, [ni, ni], os.path.join(save_dir_ginit, '_train_sample_96.jpg'))
    tl.vis.save_images(sample_imgs_hr, [ni, ni], os.path.join(save_dir_ginit, '_train_sample_192.jpg'))
    tl.vis.save_images(sample_imgs_lr, [ni, ni], os.path.join(save_dir_gan, '_train_sample_96.jpg'))
    tl.vis.save_images(sample_imgs_hr, [ni, ni], os.path.join(save_dir_gan, '_train_sample_192.jpg'))
    print('finish saving sample images')

    ###====== initialize G ======###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, total_psnr, n_iter = 0, 0, 0
        # random shuffle the train set for each epoch (LR batches are derived
        # from HR crops via downsample_fn, so shuffling HR alone is enough)
        random.shuffle(train_hr_imgs)
        for idx in range(0, len(train_lr_imgs), batch_size):
            step_time = time.time()
            b_imgs_192 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, size=(192, 192), is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_192, fn=downsample_fn, size=(96, 96))
            ## update G
            errM, metricP, _ = sess.run([mse_loss, psnr_metric, g_optim_init], {lr_image: b_imgs_96, hr_image: b_imgs_192})
            print("Epoch [%2d/%2d] %4d time: %4.2fs, mse: %.4f, psnr: %.4f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM, metricP.mean()))
            total_mse_loss += errM
            total_psnr += metricP.mean()
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.2fs, mse: %.4f, psnr: %.4f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter, total_psnr / n_iter)
        print(log)

        if validation:
            b_imgs_192_V = tl.prepro.threading_data(valid_hr_imgs, fn=crop_sub_imgs_fn, size=(192, 192), is_random=True)
            b_imgs_96_V = tl.prepro.threading_data(b_imgs_192_V, fn=downsample_fn, size=(96, 96))
            # BUG FIX: the original also ran g_optim_init here, training the
            # generator on the validation set. Evaluate losses/metrics only.
            errM_V, metricP_V = sess.run([mse_loss, psnr_metric], {lr_image: b_imgs_96_V, hr_image: b_imgs_192_V})
            print("Validation | mse: %.4f, psnr: %.4f" % (errM_V, metricP_V.mean()))

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            out = sess.run(net_g_test.outputs, {lr_image: sample_imgs_lr})
            print("[*] save sample images")
            tl.vis.save_images(out, [ni, ni], os.path.join(save_dir_ginit, 'train_{}.jpg'.format(epoch)))
        ## save model
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            tl.files.save_npz(net_g.all_params, name=os.path.join(checkpoint_dir, 'g_srgan_init.npz'), sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    ## Learning rate decays once, at the halfway point.
    decay_every = int(n_epoch / 2) if int(n_epoch / 2) > 0 else 1
    for epoch in range(0, n_epoch + 1):
        # random shuffle the train set for each epoch
        random.shuffle(train_hr_imgs)
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, total_mse_loss, total_psnr, total_acc, n_iter = 0, 0, 0, 0, 0, 0
        for idx in range(0, len(train_lr_imgs), batch_size):
            step_time = time.time()
            b_imgs_192 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, size=(192, 192), is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_192, fn=downsample_fn, size=(96, 96))
            ## update D
            errD, metricA, _ = sess.run([d_loss, acc_metric, d_optim], {lr_image: b_imgs_96, hr_image: b_imgs_192})
            ## update G
            errG, errM, metricP, _ = sess.run([g_loss, mse_loss, psnr_metric, g_optim], {lr_image: b_imgs_96, hr_image: b_imgs_192})
            print("Epoch [%2d/%2d] %4d time: %4.2fs, d_loss: %.4f g_loss: %.4f (mse: %.4f, psnr: %.4f, accuracy: %.4f)" % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, metricP.mean(), metricA / 2))
            total_d_loss += errD
            total_g_loss += errG
            total_mse_loss += errM
            total_psnr += metricP.mean()
            total_acc += metricA / 2
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.2fs, d_loss: %.4f g_loss: %.4f (mse: %4f, psnr: %.4f, accuracy: %.4f)" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter, total_mse_loss / n_iter, total_psnr / n_iter, total_acc / n_iter)
        print(log)

        if validation:
            b_imgs_192_V = tl.prepro.threading_data(valid_hr_imgs, fn=crop_sub_imgs_fn, size=(192, 192), is_random=True)
            b_imgs_96_V = tl.prepro.threading_data(b_imgs_192_V, fn=downsample_fn, size=(96, 96))
            # BUG FIX: the original also ran g_optim here, training the
            # generator on the validation set. Evaluate losses/metrics only.
            errM_V, metricP_V = sess.run([mse_loss, psnr_metric], {lr_image: b_imgs_96_V, hr_image: b_imgs_192_V})
            print("Validation | mse: %.4f, psnr: %.4f" % (errM_V, metricP_V.mean()))

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            out = sess.run(net_g_test.outputs, {lr_image: sample_imgs_lr})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], os.path.join(save_dir_gan, 'train_{}.jpg'.format(epoch)))
        ## save model
        if (epoch != 0) and (epoch % save_every_epoch == 0):
            tl.files.save_npz(net_g.all_params, name=os.path.join(checkpoint_dir, 'g_srgan.npz'), sess=sess)
            tl.files.save_npz(net_d.all_params, name=os.path.join(checkpoint_dir, 'd_srgan.npz'), sess=sess)