def train():
    """Train the grayscale SRResNet generator with an MSE + VGG perceptual loss.

    Reads every ``*.png`` in ``config.TRAIN.hr_img_path`` as a single-channel
    image, builds ``SRGAN_g`` on 56x56x1 inputs with 224x224x1 targets and
    minimises ``mse_loss + 2e-6 * vgg_loss`` with Adam. The learning rate is
    multiplied by ``lr_decay`` every ``decay_every`` epochs.

    Outputs:
        * a sample prediction in ``samples/<mode>_gan`` every epoch;
        * generator weights every third epoch, both as
          ``checkpoint/g_<mode>_<epoch>.npz`` and as ``checkpoint/g_<mode>.npz``
          (the file restored at start-up, so training can be resumed);
        * per-epoch losses and the learning rate as TensorBoard scalars under
          ``./log/train``.

    Relies on module-level names defined elsewhere in the file (``tl``, ``tf``,
    ``scipy``, ``config``, ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch``,
    ``decay_every``, ``lr_decay``, ``ni``, ``SRGAN_g``, ``Vgg19_simple_api``,
    ``crop_sub_imgs_fn``, ``downsample_fn``).

    Raises:
        ValueError: if fewer than ``batch_size`` training images are found;
            the placeholders have a fixed batch dimension.
        SystemExit: if ``vgg19.npy`` is missing.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_gan = "samples/{}_gan".format(mode)  # srresnet
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx=r'.*\.png$', printable=False))
    ## If your machine has enough memory, please pre-load the whole train set.
    print("reading images")
    train_hr_imgs = []
    for img_name in train_hr_img_list:
        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this needs an
        # older SciPy (or a replacement reader) to run.
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path, img_name), mode='L')
        # Append a trailing channel axis: (H, W) -> (H, W, 1).
        image_loaded = image_loaded.reshape((image_loaded.shape[0], image_loaded.shape[1], 1))
        train_hr_imgs.append(image_loaded)
    print(type(train_hr_imgs), len(train_hr_img_list))
    if len(train_hr_imgs) < batch_size:
        raise ValueError("need at least %d training images in %s, found %d"
                         % (batch_size, config.TRAIN.hr_img_path, len(train_hr_imgs)))
    # Only full batches can be fed to the fixed-size placeholders; the tail is dropped.
    n_batches = len(train_hr_imgs) // batch_size

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 56, 56, 1], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1], name='t_target_image')
    # get_shape() reports the static shape; tf.shape() would only print a graph op.
    print("t_image:", t_image.get_shape())
    print("t_target_image:", t_target_image.get_shape())

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)  # SRGAN_g is the SRResNet portion of the GAN
    print("net_g.outputs:", net_g.outputs.get_shape())
    net_g.print_params(False)
    net_g.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg
    ## VGG expects 3-channel RGB input, so replicate the single gray channel.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)
    print("t_predict_image_224:", t_predict_image_224.get_shape())

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## SRResNet
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(mode), network=net_g)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        raise SystemExit(1)
    # The .npy holds a pickled dict (hence .item()); NumPy >= 1.16.3 refuses to
    # unpickle unless allow_pickle=True. Only load this file from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)
    sample_imgs_224 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_224.shape, sample_imgs_224.min(), sample_imgs_224.max())
    sample_imgs_56 = tl.prepro.threading_data(sample_imgs_224, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_56.shape, sample_imgs_56.min(), sample_imgs_56.max())
    tl.vis.save_images(sample_imgs_56, [ni, ni], save_dir_gan + '/_train_sample_56.png')
    tl.vis.save_images(sample_imgs_224, [ni, ni], save_dir_gan + '/_train_sample_224.png')

    ###========================= train SRResNet =========================###
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")
    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    def _log_scalar(writer, tag, value, step):
        """Write one scalar `value` under `tag` at `step` to a TensorBoard writer."""
        writer.add_summary(
            tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=float(value))]), step)

    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            print(" ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay))
            _log_scalar(learning_rate_writer, "Learning_rate per epoch", lr_init * new_lr_decay, epoch)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            print(" ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)"
                  % (lr_init, decay_every, lr_decay))
            _log_scalar(learning_rate_writer, "Learning_rate per epoch", lr_init, epoch)

        epoch_time = time.time()
        total_g_loss, n_iter = 0, 0
        mse_losses, vgg_losses, g_losses = [], [], []

        ## Low-memory alternative: shuffle train_hr_img_list and load each batch with
        ## tl.prepro.threading_data(names, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        ## instead of slicing the pre-loaded train_hr_imgs.
        for idx in range(0, n_batches * batch_size, batch_size):
            step_time = time.time()
            b_imgs_224 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_56 = tl.prepro.threading_data(b_imgs_224, fn=downsample_fn)

            ## update G
            errG, errM, errV, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_optim],
                {t_image: b_imgs_56, t_target_image: b_imgs_224})
            # These are exactly the values the per-batch scalar summaries would hold.
            mse_losses.append(errM)
            vgg_losses.append(errV)
            g_losses.append(errG)
            print("Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f)"
                  % (epoch, n_epoch, n_iter, time.time() - step_time, errG, errM, errV))
            total_g_loss += errG
            n_iter += 1

        # n_iter >= 1 is guaranteed by the batch_size check; time is the average per iteration.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f" % (
            epoch, n_epoch, (time.time() - epoch_time) / n_iter, total_g_loss / n_iter)
        print(log)

        ## logging generator summary
        _log_scalar(summary_generator_writer, "Generator_MSE_loss per epoch", np.mean(mse_losses), epoch)
        _log_scalar(summary_generator_writer, "Generator_VGG_loss per epoch", np.mean(vgg_losses), epoch)
        _log_scalar(summary_generator_writer, "Generator_total_loss per epoch", np.mean(g_losses), epoch)
        summary_generator_writer.flush()
        learning_rate_writer.flush()

        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_56})
        print("[*] save images")
        tl.vis.save_image(out[0], save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 3 == 0):
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir + '/g_{}_{}.npz'.format(mode, epoch), sess=sess)
            # Also refresh the un-numbered file that RESTORE MODEL loads, so runs can resume.
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format(mode), sess=sess)

    summary_generator_writer.close()
    learning_rate_writer.close()
def train():
    """Pre-train the SRGAN generator with MSE, then train the full SRGAN.

    Phase 1 ("initialize G") minimises only the pixel MSE for ``n_epoch_init``
    epochs at the fixed learning rate ``lr_init``. Phase 2 alternates one
    discriminator step and one generator step per batch for ``n_epoch`` epochs.
    The generator loss is ``mse + 2e-6 * vgg + 1e-3 * adversarial``, and the
    learning rate is multiplied by ``lr_decay`` every ``decay_every`` epochs.

    HR targets (384x384x3, from ``crop_sub_imgs_fn``) and LR inputs (96x96x3, from
    ``downsample_fn``) are cut from the PNGs in ``config.TRAIN.hr_img_path``.
    Every 10 epochs, sample outputs go to ``samples/<mode>_ginit`` /
    ``samples/<mode>_gan``, and weights go to ``checkpoint/g_<mode>_init.npz`` (phase 1)
    or ``checkpoint/g_<mode>.npz`` + ``checkpoint/d_<mode>.npz`` (phase 2).

    Relies on module-level names defined elsewhere in the file (``tl``, ``tf``,
    ``config``, ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch_init``,
    ``n_epoch``, ``decay_every``, ``lr_decay``, ``ni``, ``SRGAN_g``, ``SRGAN_d``,
    ``Vgg19_simple_api``, ``crop_sub_imgs_fn``, ``downsample_fn``).

    Raises:
        ValueError: if fewer than ``batch_size`` training images are found;
            the placeholders have a fixed batch dimension.
        SystemExit: if ``vgg19.npy`` is missing.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx=r'.*\.png$', printable=False))
    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
    if len(train_hr_imgs) < batch_size:
        raise ValueError("need at least %d training images in %s, found %d"
                         % (batch_size, config.TRAIN.hr_img_path, len(train_hr_imgs)))
    # Only full batches can be fed to the fixed-size placeholders; the tail is dropped.
    n_batches = len(train_hr_imgs) // batch_size

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    g_loaded = tl.files.load_and_assign_npz(
        sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(mode), network=net_g)
    # TensorLayer 1.x returns False (not None) when the file is missing; accept
    # both sentinels so the fallback to the MSE-pretrained generator actually runs.
    if g_loaded is None or g_loaded is False:
        tl.files.load_and_assign_npz(
            sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(mode), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(mode), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        raise SystemExit(1)
    # The .npy holds a pickled dict (hence .item()); NumPy >= 1.16.3 refuses to
    # unpickle unless allow_pickle=True. Only load this file from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ## Low-memory alternative for both loops below: shuffle train_hr_img_list and load
    ## each batch with tl.prepro.threading_data(names, fn=get_imgs_fn,
    ## path=config.TRAIN.hr_img_path) instead of slicing the pre-loaded train_hr_imgs.

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        for idx in range(0, n_batches * batch_size, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f "
                  % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        # n_iter >= 1 is guaranteed by the batch_size check above.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init.npz'.format(mode), sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            print(" ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay))
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            print(" ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)"
                  % (lr_init, decay_every, lr_decay))

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        for idx in range(0, n_batches * batch_size, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}
            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                  % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format(mode), sess=sess)
            tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format(mode), sess=sess)
def train():
    """Train an SRGAN generator/discriminator with a relativistic-average LSGAN loss.

    The discriminator minimises the relativistic average least-squares loss.
    The generator minimises ``mae_loss + 0.1 * g_gan_loss``. No VGG term is
    used. Both networks use Adam at ``learning_rate``.

    Every epoch re-lists the images under ``train_hr_img_path`` (recursively) and
    shuffles them. The list is padded to a multiple of ``batch_size``, and
    training alternates one D step and one G step per batch. After each epoch,
    sample outputs for the first ``sample_batch_size`` validation images are
    saved to ``samples_path + "gan"``. Weights overwrite ``g.npz`` / ``d.npz``
    in ``checkpoint_path``.

    Relies on module-level names defined elsewhere in the file (``tl``, ``tf``,
    ``samples_path``, ``checkpoint_path``, ``valid_hr_img_path``,
    ``train_hr_img_path``, ``batch_size``, ``sample_batch_size``,
    ``learning_rate``, ``n_epoch_gan``, ``ni``, ``save_file_format``,
    ``SRGAN_g``, ``SRGAN_d``, ``load_deep_file_list``, ``get_imgs_fn``,
    ``crop_sub_imgs_fn``, ``crop_data_augment_fn``, ``downsample_fn``,
    ``rescale_m1p1``, ``save_images``).

    Raises:
        ValueError: if fewer than ``sample_batch_size`` validation images or no
            training images are found.
    """
    # Raw string avoids the invalid "\." escape warning; "$" anchors the extension.
    image_regx = r'.*\.(bmp|png|webp|jpg)$'

    ## create folders to save result images and trained model
    save_dir_gan = samples_path + "gan"
    tl.files.exists_or_mkdir(save_dir_gan)
    tl.files.exists_or_mkdir(checkpoint_path)

    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(
        tl.files.load_file_list(path=valid_hr_img_path, regx=image_regx, printable=False))
    if len(valid_hr_img_list) < sample_batch_size:
        raise ValueError("need at least %d validation images in %s, found %d"
                         % (sample_batch_size, valid_hr_img_path, len(valid_hr_img_list)))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    sample_t_image = tf.compat.v1.placeholder(
        'float32', [sample_batch_size, 96, 96, 3], name='sample_t_image_input_to_SRGAN_generator')
    t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.compat.v1.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## test inference
    net_g_test = SRGAN_g(sample_t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # MAE Loss (tf.abs is element-wise, so no per-example map_fn is needed)
    mae_loss = tf.reduce_mean(tf.abs(t_target_image - net_g.outputs))

    # GAN Loss: each side's logits are compared against the mean logit of the other side.
    d_loss = 0.5 * (
        tf.reduce_mean(tf.square(logits_real - tf.reduce_mean(logits_fake) - 1)) +
        tf.reduce_mean(tf.square(logits_fake - tf.reduce_mean(logits_real) + 1)))
    g_gan_loss = 0.5 * (
        tf.reduce_mean(tf.square(logits_real - tf.reduce_mean(logits_fake) + 1)) +
        tf.reduce_mean(tf.square(logits_fake - tf.reduce_mean(logits_real) - 1)))

    g_loss = 1e-1 * g_gan_loss + mae_loss

    # Mean raw discriminator outputs, printed for monitoring only.
    d_real = tf.reduce_mean(logits_real)
    d_fake = tf.reduce_mean(logits_fake)

    with tf.compat.v1.variable_scope('learning_rate'):
        learning_rate_var = tf.Variable(learning_rate, trainable=False)

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    ## SRGAN
    g_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate_var).minimize(g_loss, var_list=g_vars)
    d_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate_var).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.compat.v1.Session(
        config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    sess.run(tf.compat.v1.variables_initializer(tf.compat.v1.global_variables()))
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_path + 'g.npz', network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_path + 'd.npz', network=net_d)

    ###============================= TRAINING ===============================###
    sample_imgs = tl.prepro.threading_data(
        valid_hr_img_list[0:sample_batch_size], fn=get_imgs_fn, path=valid_hr_img_path)
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    save_images(sample_imgs_96, [ni, ni], save_file_format, save_dir_gan + '/_train_sample_96')
    save_images(sample_imgs_384, [ni, ni], save_file_format, save_dir_gan + '/_train_sample_384')

    ###========================= train GAN =========================###
    sess.run(tf.compat.v1.assign(learning_rate_var, learning_rate))
    for epoch in range(0, n_epoch_gan + 1):
        epoch_time = time.time()
        total_d_loss, total_g_loss_mae, total_g_loss_gan, n_iter = 0, 0, 0, 0

        train_hr_img_list = load_deep_file_list(
            path=train_hr_img_path, regx=image_regx, recursive=True, printable=False)
        random.shuffle(train_hr_img_list)

        list_length = len(train_hr_img_list)
        print("Number of images: %d" % (list_length))
        if list_length == 0:
            raise ValueError("no training images found under %s" % train_hr_img_path)

        # Pad (re-using images cyclically) up to a multiple of batch_size, since the
        # placeholders have a fixed batch dimension. A single slice of the list is
        # not enough when fewer than batch_size images are available.
        n_pad = -list_length % batch_size
        if n_pad:
            n_repeats = n_pad // list_length + 1
            train_hr_img_list += (train_hr_img_list * n_repeats)[:n_pad]
            list_length = len(train_hr_img_list)
            print("Length of list: %d" % (list_length))

        for idx in range(0, list_length, batch_size):
            step_time = time.time()
            b_imgs_list = train_hr_img_list[idx:idx + batch_size]
            b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=train_hr_img_path)
            b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_data_augment_fn, is_random=True)
            # NOTE(review): LR inputs are derived before rescale_m1p1 is applied to the
            # HR targets; presumably downsample_fn does its own scaling — confirm.
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            b_imgs_384 = tl.prepro.threading_data(b_imgs_384, fn=rescale_m1p1)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}

            ## update D
            errD, d_r, d_f, _ = sess.run([d_loss, d_real, d_fake, d_optim], feed)
            ## update G
            errM, errA, _, _ = sess.run([mae_loss, g_gan_loss, g_loss, g_optim], feed)

            print("Epoch[%2d/%2d] %4d time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f d_r: %.8f d_f: %.8f"
                  % (epoch, n_epoch_gan, n_iter, time.time() - step_time, errD, errM, errA, d_r, d_f))
            total_d_loss += errD
            total_g_loss_mae += errM
            total_g_loss_gan += errA
            n_iter += 1

        # n_iter >= 1: the list is non-empty and padded to a multiple of batch_size.
        log = ("[*] Epoch[%2d/%2d] time: %4.2fs d_loss: %.8f g_loss_mae: %.8f g_loss_gan: %.8f"
               % (epoch, n_epoch_gan, time.time() - epoch_time, total_d_loss / n_iter,
                  total_g_loss_mae / n_iter, total_g_loss_gan / n_iter))
        print(log)

        ## quick evaluation on train set
        out = sess.run(net_g_test.outputs, {sample_t_image: sample_imgs_96})
        print("[*] save images")
        save_images(out, [ni, ni], save_file_format, save_dir_gan + '/train_%d' % epoch)

        ## save model
        tl.files.save_npz(net_g.all_params, name=checkpoint_path + 'g.npz', sess=sess)
        tl.files.save_npz(net_d.all_params, name=checkpoint_path + 'd.npz', sess=sess)
def train():
    """Pre-train the SRGAN generator with MSE, then train the full SRGAN.

    Phase 1 ("initialize G") minimises only the pixel MSE for ``n_epoch_init``
    epochs at the fixed learning rate ``lr_init``. Phase 2 alternates one
    discriminator step and one generator step per batch for ``n_epoch`` epochs.
    The generator loss is ``mse + 2e-6 * vgg + 1e-3 * adversarial``, and the
    learning rate is multiplied by ``lr_decay`` every ``decay_every`` epochs.

    HR targets (384x384x3, from ``crop_sub_imgs_fn``) and LR inputs (96x96x3, from
    ``downsample_fn``) are cut from the PNGs in ``config.TRAIN.hr_img_path``.
    Every 10 epochs, sample outputs go to ``samples/<mode>_ginit`` /
    ``samples/<mode>_gan``, and weights go to ``checkpoint/g_<mode>_init.npz`` (phase 1)
    or ``checkpoint/g_<mode>.npz`` + ``checkpoint/d_<mode>.npz`` (phase 2).

    Relies on module-level names defined elsewhere in the file (``tl``, ``tf``,
    ``config``, ``batch_size``, ``lr_init``, ``beta1``, ``n_epoch_init``,
    ``n_epoch``, ``decay_every``, ``lr_decay``, ``ni``, ``SRGAN_g``, ``SRGAN_d``,
    ``Vgg19_simple_api``, ``crop_sub_imgs_fn``, ``downsample_fn``).

    Raises:
        ValueError: if fewer than ``batch_size`` training images are found;
            the placeholders have a fixed batch dimension.
        SystemExit: if ``vgg19.npy`` is missing.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx=r'.*\.png$', printable=False))
    ## If your machine have enough memory, please pre-load the whole train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
    if len(train_hr_imgs) < batch_size:
        raise ValueError("need at least %d training images in %s, found %d"
                         % (batch_size, config.TRAIN.hr_img_path, len(train_hr_imgs)))
    # Only full batches can be fed to the fixed-size placeholders; the tail is dropped.
    n_batches = len(train_hr_imgs) // batch_size

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    g_loaded = tl.files.load_and_assign_npz(
        sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(mode), network=net_g)
    # A missing file is signalled by False in TensorLayer 1.x; also accept None so
    # the fallback to the MSE-pretrained generator is robust across versions.
    if g_loaded is None or g_loaded is False:
        tl.files.load_and_assign_npz(
            sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(mode), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(mode), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        raise SystemExit(1)
    # The .npy holds a pickled dict (hence .item()); NumPy >= 1.16.3 refuses to
    # unpickle unless allow_pickle=True. Only load this file from a trusted source.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ## Low-memory alternative for both loops below: shuffle train_hr_img_list and load
    ## each batch with tl.prepro.threading_data(names, fn=get_imgs_fn,
    ## path=config.TRAIN.hr_img_path) instead of slicing the pre-loaded train_hr_imgs.

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        for idx in range(0, n_batches * batch_size, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f "
                  % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        # n_iter >= 1 is guaranteed by the batch_size check above.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init.npz'.format(mode), sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            print(" ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay))
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            print(" ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)"
                  % (lr_init, decay_every, lr_decay))

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        for idx in range(0, n_batches * batch_size, batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}
            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                  % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format(mode), sess=sess)
            tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format(mode), sess=sess)
def train():
    """Resume SRGAN (generator + discriminator) training on grayscale images.

    This variant:
      * pre-loads every HR training PNG as a single-channel (H, W, 1) array;
      * builds the SRGAN generator/discriminator plus a VGG19 perceptual loss
        (the grayscale channel is replicated to 3 channels for VGG);
      * restores all graph variables from ``checkpoint/main_<resume_epoch>.ckpt``
        and continues GAN training at ``resume_epoch + 1`` for ``n_epoch``
        epochs, saving a checkpoint every 10 epochs, generator samples every
        epoch, and per-epoch mean losses / learning rate under ``./log/train``.

    The MSE-only generator pre-training phase is disabled in this variant.

    Relies on module-level settings (``batch_size``, ``lr_init``, ``beta1``,
    ``n_epoch``, ``decay_every``, ``lr_decay``, ``config``) and helpers
    (``SRGAN_g``, ``SRGAN_d``, ``Vgg19_simple_api``, ``crop_sub_imgs_fn``,
    ``downsample_fn_mod``) defined elsewhere in the project.

    Raises:
        ValueError: if fewer than ``batch_size`` training images are found
            (the placeholders have a fixed batch dimension).
        SystemExit: if ``vgg19.npy`` is missing.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # Resume point: all variables are restored from main_<resume_epoch>.ckpt and
    # training continues at resume_epoch + 1 (epochs are numbered globally).
    resume_epoch = 10
    start_epoch = resume_epoch + 1
    tot_epoch = n_epoch + resume_epoch

    ###====================== PRE-LOAD DATA ===========================###
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))

    ## Pre-load the whole (grayscale) train set into memory.
    print("reading images")
    train_hr_imgs = []
    for img_name in train_hr_img_list:
        image_loaded = scipy.misc.imread(os.path.join(config.TRAIN.hr_img_path, img_name), mode='L')
        # imread(mode='L') yields a 2-D array; add an explicit channel axis -> (H, W, 1).
        train_hr_imgs.append(image_loaded.reshape((image_loaded.shape[0], image_loaded.shape[1], 1)))
    print(type(train_hr_imgs), len(train_hr_img_list))
    # The placeholders below have a fixed batch dimension, so at least one full
    # batch is needed (the sample batch is fed through them as well).
    if len(train_hr_imgs) < batch_size:
        raise ValueError("need at least batch_size=%d training images, found %d"
                         % (batch_size, len(train_hr_imgs)))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    # NOTE(review): the LR input is 28x224 while the target is 224x224; this
    # presumably matches what downsample_fn_mod produces -- confirm.
    t_image = tf.placeholder('float32', [batch_size, 28, 224, 1],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 224, 224, 1], name='t_target_image')
    # get_shape() reports the static shape; tf.shape() would only print a Tensor repr.
    print("t_image:", t_image.get_shape())
    print("t_target_image:", t_target_image.get_shape())

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    print("net_g.outputs:", net_g.outputs.get_shape())
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg
    t_predict_image_224 = tf.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg
    ## VGG expects 3-channel (RGB) input, so replicate the grayscale channel.
    t_target_image_224 = tf.image.grayscale_to_rgb(t_target_image_224)
    t_predict_image_224 = tf.image.grayscale_to_rgb(t_predict_image_224)
    print("t_predict_image_224:", t_predict_image_224.get_shape())

    # (x + 1) / 2 rescales [-1, 1] to [0, 1] -- assumes crop_sub_imgs_fn
    # normalizes images to [-1, 1]; confirm.
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator: real images -> 1, generated images -> 0.
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # Generator: pixel MSE + VGG feature MSE + small adversarial term.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(
        vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        raise SystemExit(1)
    # The file holds a pickled dict {layer_name: [W, b]}; NumPy >= 1.16.3 refuses to
    # unpickle object arrays unless allow_pickle=True. Only load a trusted file.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
    params = []
    for layer_name, layer_params in sorted(npz.items()):
        W = np.asarray(layer_params[0])
        b = np.asarray(layer_params[1])
        print("  Loading %s: %s, %s" % (layer_name, W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use the first `batch_size` train images for a quick visual check during training
    sample_imgs = train_hr_imgs[0:batch_size]
    print("sample_imgs size:", len(sample_imgs), sample_imgs[0].shape)
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn_mod)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())

    # NOTE: the MSE-only "initialize G" pre-training phase is disabled in this
    # variant; training resumes from a GAN checkpoint instead. Re-enabling it
    # needs g_optim_init = AdamOptimizer(lr_v, beta1).minimize(mse_loss, var_list=g_vars).

    ###========================= train GAN (SRGAN) =========================###
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_dir + '/main_%d.ckpt' % resume_epoch)
    print('Restored main_%d, begin %d/%d' % (resume_epoch, start_epoch, tot_epoch))

    summary_discriminator_writer = tf.summary.FileWriter("./log/train/discriminator")
    summary_generator_writer = tf.summary.FileWriter("./log/train/generator")
    learning_rate_writer = tf.summary.FileWriter("./log/train/learning_rate")

    def _log_scalar(writer, tag, value, step):
        """Write a single scalar `value` under `tag` at step `step`."""
        writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), step)

    for epoch in range(start_epoch, n_epoch + start_epoch):
        ## update learning rate (step decay every `decay_every` epochs)
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
            _log_scalar(learning_rate_writer, "Learning_rate per epoch", lr_init * new_lr_decay, epoch)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)
            _log_scalar(learning_rate_writer, "Learning_rate per epoch", lr_init, epoch)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        # Per-batch loss values, averaged into one summary point per epoch.
        d_loss1_per_epoch, d_loss2_per_epoch, d_loss_per_epoch = [], [], []
        g_gan_loss_per_epoch, mse_loss_per_epoch = [], []
        vgg_loss_per_epoch, g_loss_per_epoch = [], []

        ## The whole train set is pre-loaded in memory.
        for idx in range(0, len(train_hr_imgs), batch_size):
            # Placeholders have a fixed batch dimension: skip a trailing partial
            # batch instead of failing the feed.
            if idx + batch_size > len(train_hr_imgs):
                break
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn_mod)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}

            ## update D (loss terms come from the same run as the update)
            errD, errD1, errD2, _ = sess.run([d_loss, d_loss1, d_loss2, d_optim], feed)
            d_loss1_per_epoch.append(errD1)
            d_loss2_per_epoch.append(errD2)
            d_loss_per_epoch.append(errD)

            ## update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)
            g_gan_loss_per_epoch.append(errA)
            mse_loss_per_epoch.append(errM)
            vgg_loss_per_epoch.append(errV)
            g_loss_per_epoch.append(errG)

            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)"
                  % (epoch, tot_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        # n_iter >= 1: at least one full batch is guaranteed by the check above.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, tot_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter)
        print(log)

        ## per-epoch summaries: each value is the mean over this epoch's batches
        discriminator_series = (
            ("Disciminator_logits_real_loss per epoch", d_loss1_per_epoch),
            ("Disciminator_logits_fake_loss per epoch", d_loss2_per_epoch),
            ("Disciminator_total_loss per epoch", d_loss_per_epoch),
        )
        for tag, values in discriminator_series:
            _log_scalar(summary_discriminator_writer, tag, np.mean(values), epoch)
        generator_series = (
            ("Generator_GAN_loss per epoch", g_gan_loss_per_epoch),
            ("Generator_MSE_loss per epoch", mse_loss_per_epoch),
            ("Generator_VGG_loss per epoch", vgg_loss_per_epoch),
            ("Generator_total_loss per epoch", g_loss_per_epoch),
        )
        for tag, values in generator_series:
            _log_scalar(summary_generator_writer, tag, np.mean(values), epoch)

        ## quick evaluation on the fixed sample batch (every epoch)
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
        ## save model
        if epoch % 10 == 0 and epoch != 0:
            saver.save(sess, checkpoint_dir + '/main_' + str(epoch) + '.ckpt')
        print("[*] save images")
        for im, generated in enumerate(out):
            tl.vis.save_image(generated, save_dir_gan + '/train_%d_%d.png' % (epoch, im))
def train():
    """Train an SRGAN on 48x48 single-channel images read from CSV (3x upscaling).

    Two phases:
      1. "initialize G": ``n_epoch_init + 1`` epochs of MSE-only generator training
         at a fixed learning rate;
      2. "train GAN": ``n_epoch + 1`` epochs alternating one discriminator and one
         generator update per batch, with step-wise learning-rate decay every
         ``decay_every`` epochs.

    Every 10 epochs (excluding epoch 0), generator outputs on 9 fixed training
    and 9 fixed validation images are saved under ``samples/`` and the weights
    are saved as ``.npz`` files under ``checkpoint/``.

    Relies on module-level settings (``batch_size``, ``lr_init``, ``beta1``,
    ``n_epoch_init``, ``n_epoch``, ``decay_every``, ``lr_decay``, ``ni``,
    ``config``) and helpers (``read_csv_data``, ``SRGAN_g``, ``SRGAN_d``,
    ``crop_sub_imgs_fn``, ``downsample_fn``, ``upsample_fn``) defined elsewhere.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    save_dir_valid = "samples/{}_valid".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    tl.files.exists_or_mkdir(save_dir_valid)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    train_hr_imgs = read_csv_data(config.TRAIN.hr_img_path, width=48, height=48, channel=1)
    valid_hr_imgs = read_csv_data(config.VALID.hr_img_path, width=48, height=48, channel=1)

    # HR (48x48) -> LR (16x16) factor. Passed explicitly on every downsample_fn
    # call so training batches always match the t_image placeholder.
    down_rate = 3

    ###========================== DEFINE MODEL ============================###
    ## train inference (batch dimension is left open)
    t_image = tf.placeholder('float32', [None, 16, 16, 1], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [None, 48, 48, 1], name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False, nb_block=16)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # d_loss: real images -> 1, generated images -> 0
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # g_loss: pixel MSE + small adversarial term (no VGG perceptual loss here)
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    g_loss = mse_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== CREATE SESSION ===========================###
    # A session must exist, and every variable created above (network weights,
    # lr_v, Adam slots) must be initialized, before the first sess.run below.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)

    ###============================= TRAINING ===============================###
    ## fixed train / validation images used for quick visual checks during training
    sample_imgs = train_hr_imgs[0:9]
    valid_imgs = valid_hr_imgs[44:53]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    valid_imgs_48 = tl.prepro.threading_data(valid_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn, down_rate=down_rate)
    valid_imgs_16 = tl.prepro.threading_data(valid_imgs_48, fn=downsample_fn, down_rate=down_rate)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_16.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_48.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_16.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_48.png')
    tl.vis.save_images(valid_imgs_48, [ni, ni], save_dir_valid + '/_valid_sample_48.png')
    tl.vis.save_images(valid_imgs_16, [ni, ni], save_dir_valid + '/_valid_sample_16.png')

    ## bicubic baselines for visual comparison
    sample_hr_imgs_bicubic = tl.prepro.threading_data(sample_imgs_96, fn=upsample_fn, up_rate=3)
    valid_hr_imgs_bicubic = tl.prepro.threading_data(valid_imgs_16, fn=upsample_fn, up_rate=3)
    tl.vis.save_images(sample_hr_imgs_bicubic, [ni, ni], save_dir_ginit + '/_sample_bicubic_48.png')
    tl.vis.save_images(valid_hr_imgs_bicubic, [ni, ni], save_dir_valid + '/_valid_sample_bicubic_48.png')

    def _save_samples(train_path, valid_path):
        """Run the test-mode generator on the fixed train/valid LR images and save grids."""
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
        print("[*] save images")
        tl.vis.save_images(out, [ni, ni], train_path)
        out = sess.run(net_g_test.outputs, {t_image: valid_imgs_16})
        tl.vis.save_images(out, [ni, ni], valid_path)

    def _make_batch(idx):
        """Return (lr_batch, hr_batch) for the training images starting at `idx`."""
        hr_batch = tl.prepro.threading_data(
            train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
        lr_batch = tl.prepro.threading_data(hr_batch, fn=downsample_fn, down_rate=down_rate)
        return lr_batch, hr_batch

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    train_writer_path = "./log/train"
    tl.files.exists_or_mkdir(train_writer_path)
    # Only the graph is written through this writer (for TensorBoard's graph view).
    train_writer = tf.summary.FileWriter(train_writer_path, graph=tf.get_default_graph())
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        ## The whole train set is pre-loaded in memory.
        for idx in range(0, len(train_hr_imgs), batch_size):
            if idx + batch_size > len(train_hr_imgs):
                break  # skip the trailing partial batch
            step_time = time.time()
            b_imgs_96, b_imgs_384 = _make_batch(idx)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            # '\r' keeps the per-step progress on a single console line.
            sys.stdout.write("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f \r"
                             % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            sys.stdout.flush()
            total_mse_loss += errM
            n_iter += 1
        # max(..., 1) avoids ZeroDivisionError when no full batch was available.
        log = "\n[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / max(n_iter, 1))
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train and validation sets
            _save_samples(save_dir_ginit + '/train_%d.png' % epoch,
                          save_dir_valid + '/valid_ganit_%d.png' % epoch)
            ## save model
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir + '/g_{}_init_{}.npz'.format(mode, epoch),
                              sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate (step decay every `decay_every` epochs)
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        ## The whole train set is pre-loaded in memory (a trailing partial batch
        ## is fine here because the placeholders have an open batch dimension).
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_96, b_imgs_384 = _make_batch(idx)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}
            ## update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            ## update G
            errG, errM, errA, _ = sess.run([g_loss, mse_loss, g_gan_loss, g_optim], feed)
            sys.stdout.write(
                "Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f adv: %.6f) \r"
                % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errA))
            sys.stdout.flush()
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        # Leading '\n' moves past the '\r' progress line before the epoch summary.
        log = "\n[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time,
            total_d_loss / max(n_iter, 1), total_g_loss / max(n_iter, 1))
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train and validation sets
            _save_samples(save_dir_gan + '/train_%d.png' % epoch,
                          save_dir_valid + '/valid_gan_%d.png' % epoch)
            ## save model
            tl.files.save_npz(net_g.all_params,
                              name=checkpoint_dir + '/g_{}_{}.npz'.format(mode, epoch),
                              sess=sess)
            tl.files.save_npz(net_d.all_params,
                              name=checkpoint_dir + '/d_{}_{}.npz'.format(mode, epoch),
                              sess=sess)
def _load_paired_batch(train_hr_imgs, train_lr_imgs, start):
    """Build one aligned (HR, LR) training batch starting at index ``start``.

    The HR and LR images at the same index are concatenated along the channel
    axis (axis 2) before ``retain`` is applied with ``is_random=True``. Whatever
    per-image transform ``retain`` performs is therefore applied to the HR/LR
    pair jointly. The stacked batch is then split back into channels 0:3 (HR)
    and 3:6 (LR), so each image is expected to have 3 channels. HR and LR images
    must share height and width, because ``np.concatenate`` on axis 2 requires it.

    Reads the module-level ``batch_size`` and ``retain``.

    Returns:
        ``(b_imgs_h, b_imgs_l)``: the HR and LR halves of the batch.
    """
    x_imgs = [
        np.concatenate([train_hr_imgs[start + i], train_lr_imgs[start + i]], axis=2)
        for i in range(batch_size)
    ]
    b_imgs = tl.prepro.threading_data(x_imgs, fn=retain, is_random=True)
    return b_imgs[:, :, :, 0:3], b_imgs[:, :, :, 3:6]


def _append_log(path, text):
    """Append ``text`` to the file at ``path``, closing it even if the write fails."""
    with open(path, 'a') as f:
        f.write(text)


def train():
    """Train the SRGAN generator with pixel-MSE and FaceNet-based losses.

    Two phases run in one TensorFlow session.

    1. Initialization (epochs ``0..n_epoch_init``): the generator is optimized
       on the MSE loss only, at the fixed learning rate ``lr_init``.
    2. Main phase (epochs ``0..n_epoch``): the generator is optimized on
       ``mse_loss + vgg_loss + softmax_loss``. The last two terms are computed
       from the Inception-ResNet-v1 network, whose variables are restored from
       a pretrained FaceNet checkpoint. The learning rate is multiplied by
       ``lr_decay`` every ``decay_every`` epochs.

    Only variables returned for ``'SRGAN_g'`` are passed to the optimizers.

    Outputs written during training:
    - sample images under ``samples/``;
    - per-epoch logs appended to ``log_init.txt`` and ``log.txt``;
    - generator ``.npz`` checkpoints under ``checkpoint_dir``.

    Dependencies defined outside this function:
    - configuration: ``config``, ``batch_size``, ``ni``, ``lr_init``,
      ``beta1``, ``n_epoch_init``, ``n_epoch``, ``lr_decay``, ``decay_every``,
      ``keep_probability``, ``embedding_size``, ``weight_decay``;
    - helpers: ``SRGAN_g``, ``prewhitenfacenet``, ``softmax``, ``retain``,
      ``facenet``.

    Only whole batches are used. Trailing images that do not fill a batch are
    skipped, because the input placeholders have a fixed batch dimension.

    Raises:
        ValueError: if fewer than ``batch_size`` HR training images are found.
    """
    # Create folders for sample images and trained models.
    save_dir_ginit = "samples/facenet_pgd_loss_ginit"
    save_dir_gan = "samples/facenet_pgd_loss_gan"
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint/facenet_pgd_loss_12.19"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # Load the training set. HR and LR images are paired by sorted file order.
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
    train_lr_imgs = tl.vis.read_images(train_lr_img_list, path=config.TRAIN.lr_img_path, n_threads=32)

    # A trailing partial batch can never be fed to the fixed-size placeholders
    # and would index past the end of the image lists. Iterate whole batches only.
    n_batches = len(train_hr_imgs) // batch_size
    if n_batches == 0:
        raise ValueError("need at least batch_size=%d training images, found %d"
                         % (batch_size, len(train_hr_imgs)))
    batch_starts = range(0, n_batches * batch_size, batch_size)

    # ---------------------------- Model ----------------------------
    t_image = tf.placeholder('float32', [batch_size, 160, 160, 3],
                             name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 160, 160, 3], name='t_target_image')
    phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_g.print_params(False)
    net_g.print_layers()

    # FaceNet inputs: the generator output and the target, both prewhitened.
    out_160 = prewhitenfacenet(net_g.outputs)
    t_target_image_160 = prewhitenfacenet(t_target_image)
    image_batch1 = tf.identity(out_160, 'input')
    image_batch2 = tf.identity(t_target_image_160, 'input')

    model_def = 'inception_resnet_v1_new'
    network = importlib.import_module(model_def)
    # reuse=tf.AUTO_REUSE is passed to both calls, presumably so the second call
    # reuses the variables created by the first.
    prelogits1, _, texture_emb1 = network.inference(
        image_batch1, keep_probability, phase_train=phase_train_placeholder,
        bottleneck_layer_size=embedding_size, weight_decay=weight_decay, reuse=tf.AUTO_REUSE)
    prelogits2, _, texture_emb2 = network.inference(
        image_batch2, keep_probability, phase_train=phase_train_placeholder,
        bottleneck_layer_size=embedding_size, weight_decay=weight_decay, reuse=tf.AUTO_REUSE)
    embeddings1 = tf.nn.l2_normalize(prelogits1, 1, 1e-10, name='embeddings')
    embeddings2 = tf.nn.l2_normalize(prelogits2, 1, 1e-10, name='embeddings')

    # Inference-mode generator that reuses net_g's variables; used for sample images.
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ------------------------- Training ops -------------------------
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    # Despite its name, this term compares the texture features returned by
    # network.inference, not VGG features.
    vgg_loss = 1e4 * tl.cost.mean_squared_error(texture_emb1, texture_emb2, is_mean=True)

    # Per-row squared L2 distance between predicted and target embeddings.
    distance1 = tf.reduce_sum(tf.square(embeddings1 - embeddings2), axis=1)
    softmax_output_value1 = tf.transpose(softmax(distance1))
    # NOTE(review): embeddings2 - embeddings2 is identically zero, so distance2
    # is all zeros and the label derived from it does not depend on the data.
    # Confirm this is intended.
    distance2 = tf.reduce_sum(tf.square(embeddings2 - embeddings2), axis=1)
    softmax_output_value2 = tf.transpose(softmax(distance2))
    index = tf.arg_max(softmax_output_value2, 1)
    # NOTE(review): depth 2 assumes softmax(...) yields two columns per row. Confirm.
    label_mask = tf.one_hot(index, 2, on_value=1.0, off_value=0.0, dtype=tf.float32)
    # Cross-entropy against the one-hot label.
    # NOTE(review): tf.log has no epsilon, so a zero probability would yield inf/NaN.
    softmax_loss = 1e2 * tf.reduce_mean(
        -tf.reduce_sum(label_mask * tf.log(softmax_output_value1), 1))

    # Generator loss.
    g_loss = mse_loss + vgg_loss + softmax_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    # Saver limited to the Inception-ResNet variables; used only to restore the
    # pretrained FaceNet checkpoint.
    inception_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='InceptionResnetV1')
    saver = tf.train.Saver(inception_vars, max_to_keep=3)

    # The initialization phase optimizes the MSE loss only.
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    # Main (SRGAN) phase: the combined loss.
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)

    # ------------------------- Restore models -------------------------
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    try:
        tl.layers.initialize_global_variables(sess)

        pretrained_model = '/home/fan/facenet_adversarial_faces/models/facenet/20170512-110547/'
        if pretrained_model:
            print('Restoring pretrained model: %s' % pretrained_model)
            model_exp = os.path.expanduser(pretrained_model)
            print('Model directory: %s' % model_exp)
            _, ckpt_file = facenet.get_model_filenames(model_exp)
            print('Checkpoint file: %s' % ckpt_file)
            saver.restore(sess, os.path.join(model_exp, ckpt_file))

        # Resume the generator from g_<mode>.npz, falling back to g_<mode>_init.npz.
        # NOTE(review): the checkpoints saved below use different file names
        # (g_<mode>_init_softmax.npz, g_srgan_softmax<epoch>.npz), so a rerun
        # does not pick them up automatically.
        if tl.files.load_and_assign_npz(
                sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']),
                network=net_g) is False:
            tl.files.load_and_assign_npz(
                sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']),
                network=net_g)
        for var in tf.trainable_variables():
            print(var.name)

        # A fixed sample from the first batch, used for periodic visual checks.
        sample_imgs_h = tl.prepro.threading_data(train_hr_imgs[0:batch_size], fn=retain, is_random=False)
        print('sample HR sub-image:', sample_imgs_h.shape, sample_imgs_h.min(), sample_imgs_h.max())
        sample_imgs_l = tl.prepro.threading_data(train_lr_imgs[0:batch_size], fn=retain, is_random=False)
        print('sample LR sub-image:', sample_imgs_l.shape, sample_imgs_l.min(), sample_imgs_l.max())
        for save_dir in (save_dir_ginit, save_dir_gan):
            tl.vis.save_images(sample_imgs_l, [ni, ni], save_dir + '/_train_sample_l.png')
            tl.vis.save_images(sample_imgs_h, [ni, ni], save_dir + '/_train_sample_h.png')

        # ---------- Phase 1: initialize G on MSE at a fixed learning rate ----------
        sess.run(tf.assign(lr_v, lr_init))
        print(" ** fixed learning rate: %f (for init G)" % lr_init)
        for epoch in range(0, n_epoch_init + 1):
            epoch_time = time.time()
            total_mse_loss, n_iter = 0, 0
            for idx in batch_starts:
                step_time = time.time()
                b_imgs_h, b_imgs_l = _load_paired_batch(train_hr_imgs, train_lr_imgs, idx)
                errM, _ = sess.run([mse_loss, g_optim_init],
                                   {t_image: b_imgs_l, t_target_image: b_imgs_h})
                print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                      (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
                total_mse_loss += errM
                n_iter += 1
            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f\n" % (
                epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
            print(log)
            _append_log('log_init.txt', log)

            if epoch != 0 and epoch % 10 == 0:
                # Quick evaluation on the fixed training sample, then checkpoint G.
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_l})
                print("[*] save images")
                tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
                tl.files.save_npz(
                    net_g.all_params,
                    name=checkpoint_dir + '/g_{}_init_softmax.npz'.format(tl.global_flag['mode']),
                    sess=sess)

        # ---------- Phase 2: train G on the combined loss ----------
        for epoch in range(0, n_epoch + 1):
            # Learning-rate schedule: step decay every `decay_every` epochs.
            if epoch != 0 and (epoch % decay_every == 0):
                new_lr_decay = lr_decay**(epoch // decay_every)
                sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
                print(" ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay))
            elif epoch == 0:
                sess.run(tf.assign(lr_v, lr_init))
                print(" ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (
                    lr_init, decay_every, lr_decay))

            epoch_time = time.time()
            total_g_loss, total_mse_loss, total_vgg_loss, total_vgg_loss2, n_iter = 0, 0, 0, 0, 0
            for idx in batch_starts:
                step_time = time.time()
                b_imgs_h, b_imgs_l = _load_paired_batch(train_hr_imgs, train_lr_imgs, idx)
                # phase_train is fed False to the FaceNet graph (presumably inference mode).
                # Only g_vars are updated by g_optim.
                errG, errM, errV, errV2, _ = sess.run(
                    [g_loss, mse_loss, vgg_loss, softmax_loss, g_optim],
                    {t_image: b_imgs_l, t_target_image: b_imgs_h, phase_train_placeholder: False})
                print("Epoch [%2d/%2d] %4d time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f facenet: %.6f)" %
                      (epoch, n_epoch, n_iter, time.time() - step_time, errG, errM, errV, errV2))
                total_g_loss += errG
                total_mse_loss += errM
                total_vgg_loss += errV
                total_vgg_loss2 += errV2
                n_iter += 1
            log = "[*] Epoch: [%2d/%2d] time: %4.4fs, g_loss: %.8f (mse: %.6f vgg: %.6f facenet: %.6f)\n" \
                  % (epoch, n_epoch, time.time() - epoch_time, total_g_loss / n_iter,
                     total_mse_loss / n_iter, total_vgg_loss / n_iter, total_vgg_loss2 / n_iter)
            print(log)
            _append_log('log.txt', log)

            # Quick evaluation on the fixed training sample.
            if epoch != 0 and epoch % 10 == 0:
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_l})
                print("[*] save images")
                tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)
            # Checkpoint G.
            if epoch != 0 and epoch % 50 == 0:
                tl.files.save_npz(net_g.all_params,
                                  name=checkpoint_dir + '/g_srgan_softmax%d.npz' % epoch, sess=sess)
    finally:
        # Release the session's resources even if training raises.
        sess.close()