def main():
    """Train FUNIT, writing preview images and checkpoints along the way.

    Class folders are listed from ``datasets``. Checkpoints are written to
    ``experiments/FUNIT/checkpoints``, logs to ``logs/FUNIT`` and preview
    images to ``sample``. When a checkpoint is found its weights are loaded,
    but the iteration counter and both learning rates always restart from
    their initial values.
    """
    dataset = 'datasets'
    model_name = 'FUNIT'

    # Build the checkpoint path once; it is needed for creation, restore
    # and saving.
    checkpoint_dir = os.path.join('experiments', model_name, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    log_dir = os.path.join('logs', model_name)
    os.makedirs(log_dir, exist_ok=True)
    validation_output_dir = 'sample'
    os.makedirs(validation_output_dir, exist_ok=True)

    classes = os.listdir(dataset)
    num_classes = len(classes)
    print(num_classes)

    img_size = 128
    btGen = BatchGenerator(img_size=img_size,
                           imgdir=dataset,
                           num_classes=num_classes)

    num_iterations = 100000
    mini_batch_size = 8
    generator_learning_rate = 0.00010
    discriminator_learning_rate = 0.00010
    lambda_fm = 1
    lambda_rec = 0.1

    model = FUNIT(img_size=img_size,
                  num_classes=num_classes,
                  batch_size=mini_batch_size,
                  rec_weight=lambda_rec,
                  feature_weight=lambda_fm,
                  log_dir=log_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt:
        last_model = ckpt.model_checkpoint_path
        print("loading {}".format(last_model))
        model.load(filepath=last_model)
    else:
        print("checkpoints are not found")

    # Row k of the identity matrix is the one-hot vector of class k.
    # Building it once avoids re-allocating it for every sample.
    one_hot = np.identity(num_classes)

    for iteration in range(1, num_iterations + 1):
        # Exponential decay, applied before every step (including the first).
        generator_learning_rate *= 0.99999
        discriminator_learning_rate *= 0.99999

        cont_img, cont_label, cls_img, cls_label = btGen.getBatch(
            mini_batch_size)

        # Integer class ids -> one-hot rows.
        cont_labels = np.zeros([mini_batch_size, num_classes])
        cls_labels = np.zeros([mini_batch_size, num_classes])
        for b in range(mini_batch_size):
            cont_labels[b] = one_hot[cont_label[b]]
            cls_labels[b] = one_hot[cls_label[b]]

        gen_loss, dis_loss = model.train(
            content_image=cont_img,
            class_image=cls_img,
            content_label=cont_labels,
            class_label=cls_labels,
            discriminator_learning_rate=discriminator_learning_rate,
            generator_learning_rate=generator_learning_rate)

        print(
            'Iteration: {:07d}, Generator Loss : {:.3f}, Discriminator Loss : {:.3f}'
            .format(iteration, gen_loss, dis_loss))

        if iteration % 5000 == 0:
            print('Checkpointing...')
            model.save(directory=checkpoint_dir,
                       filename='{}_{}.ckpt'.format(model_name, iteration))

        if iteration % 100 == 0 or iteration == 1:
            # model.test() only takes the two image batches, so the labels
            # of this fresh batch are not needed.
            cont_img, _, cls_img, _ = btGen.getBatch(mini_batch_size)

            gen_img = model.test(cont_img, cls_img)
            gen_img = np.squeeze(np.array(gen_img))
            print(gen_img.shape)

            contTiled = tileImage(cont_img)
            clsTiled = tileImage(cls_img)
            genTiled = tileImage(gen_img)

            # Content | class | generated, side by side.
            out = np.concatenate([contTiled, clsTiled, genTiled], axis=1)
            out = (out + 1) * 127.5
            print(out.shape)
            cv2.imwrite(
                "{}/{:07}.png".format(validation_output_dir, iteration), out)
def main():
    """Train a generator/critic pair with a WGAN gradient-penalty loss.

    Builds the graph, restores the latest checkpoint from ``SAVE_DIR`` when
    one exists, then runs 100001 steps. Each step performs ``critic``
    discriminator updates followed by one generator update. Every 100 steps
    two sample grids (random and fixed latent) and the loss curves are
    written; every 1000 steps a checkpoint is saved.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    img_size = 128
    bs = 32
    z_dim = 64
    critic = 3  # discriminator updates per generator update
    lmd = 10  # gradient-penalty weight
    datalen = foloderLength(DATASET_DIR)

    # Dump a grid of real inputs for reference.
    batch = BatchGenerator(img_size=img_size, imgdir=DATASET_DIR)
    sample_ids = np.random.choice(range(datalen), bs)
    IN_ = batch.getBatch(bs, sample_ids)[:4]
    IN_ = (IN_ + 1) * 127.5
    IN_ = tileImage(IN_)
    cv2.imwrite("{}/input.png".format(SVIM_DIR), IN_)

    z = tf.placeholder(tf.float32, [bs, z_dim])
    X_real = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    X_fake = buildGenerator(z, z_dim=z_dim, img_size=img_size, nBatch=bs)
    fake_y = buildDiscriminator(y=X_fake, nBatch=bs, isTraining=True)
    real_y = buildDiscriminator(y=X_real,
                                nBatch=bs,
                                reuse=True,
                                isTraining=True)

    d_loss_real = -tf.reduce_mean(real_y)
    d_loss_fake = tf.reduce_mean(fake_y)
    g_loss = -tf.reduce_mean(fake_y)

    # Gradient penalty evaluated on random interpolates between real and
    # generated samples.
    epsilon = tf.random_uniform(shape=[bs, 1, 1, 1], minval=0., maxval=1.)
    X_hat = X_real + epsilon * (X_fake - X_real)
    D_X_hat = buildDiscriminator(X_hat,
                                 nBatch=bs,
                                 reuse=True,
                                 isTraining=False)
    grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

    wd_g = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Generator"))
    wd_d = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Discriminator"))

    d_loss = d_loss_real + d_loss_fake + lmd * gradient_penalty + wd_d
    # Small penalty on the squared real-sample loss term.
    d_loss += 0.001 * tf.reduce_mean(tf.square(d_loss_real - 0.0))
    g_loss = g_loss + wd_g

    g_vars = [v for v in tf.trainable_variables() if "Generator" in v.name]
    d_vars = [
        v for v in tf.trainable_variables() if "Discriminator" in v.name
    ]
    g_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(g_loss,
                                                             var_list=g_vars)
    d_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(d_loss,
                                                             var_list=d_vars)

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    # Pass the config to the session; otherwise the memory fraction below
    # is silently ignored.
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.66))
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # most recently saved model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        # Variables were already initialised above; nothing more to do.
        print("models were not found")

    print("%.4e sec took initializing" % (time.time() - start))

    g_hist = []
    d_hist = []

    start = time.time()
    # Fixed latent batch so progress can be compared across steps.
    stable = np.random.uniform(-1., +1., [bs, z_dim]).astype(np.float32)

    for i in range(100001):
        for _ in range(critic):
            batch_ids = np.random.choice(range(datalen), bs)
            batch_images = batch.getBatch(bs, batch_ids)
            batch_z = np.random.uniform(-1., +1.,
                                        [bs, z_dim]).astype(np.float32)
            _, dis_loss = sess.run([d_opt, d_loss],
                                   feed_dict={
                                       z: batch_z,
                                       X_real: batch_images
                                   })

        batch_ids = np.random.choice(range(datalen), bs)
        batch_images_x = batch.getBatch(bs, batch_ids)
        batch_z = np.random.uniform(-1., +1., [bs, z_dim]).astype(np.float32)
        # Feed the batch that was just loaded, not the last critic batch.
        _, gen_loss = sess.run([g_opt, g_loss],
                               feed_dict={
                                   z: batch_z,
                                   X_real: batch_images_x
                               })

        print("in step %s, dis_loss = %.4e, gen_loss = %.4e" %
              (i, dis_loss, gen_loss))

        g_hist.append(gen_loss)
        d_hist.append(dis_loss)

        if i % 100 == 0:
            batch_z = np.random.uniform(-1., +1.,
                                        [bs, z_dim]).astype(np.float32)
            g_image = sess.run(X_fake, feed_dict={z: batch_z})
            cv2.imwrite(os.path.join(SVIM_DIR, "img_%d_fake.png" % i),
                        tileImage(g_image) * 127. + 127.5)
            g_image = sess.run(X_fake, feed_dict={z: stable})
            cv2.imwrite(os.path.join(SVIM_DIR, "imgst_%d_fake.png" % i),
                        tileImage(g_image) * 127. + 127.5)

            fig = plt.figure(figsize=(8, 6), dpi=128)
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            ax.plot(g_hist, label="gen_loss", linewidth=0.25)
            ax.plot(d_hist, label="dis_loss", linewidth=0.25)
            plt.xlabel('step', fontsize=16)
            plt.ylabel('loss', fontsize=16)
            plt.legend(loc='upper right')
            plt.savefig("hist.png")
            plt.close()

            print("%.4e sec took 100steps" % (time.time() - start))
            start = time.time()

        if i % 1000 == 0:
            saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Pre-train and then adversarially train the SRGAN generator.

    Phase 1 (50001 steps) minimises the pixel + VGG feature loss only and
    checkpoints into ``SAVEPRE_DIR``. Phase 2 (100001 steps) runs
    generator / discriminator / generator updates per step with a decaying
    learning rate and checkpoints into ``SAVE_DIR``. Preview images go to
    ``SAVEIM_DIR``; loss curves go to ``hist_pre.png`` and ``hist.png``.

    Raises:
        FileNotFoundError: if neither a full checkpoint in ``SAVE_DIR`` nor
            VGG weights in ``modelvgg`` can be found.
    """
    img_size = 96
    bs = 4
    trans_lr = 1e-4

    # Make sure every output location exists before training starts.
    for out_dir in (SAVE_DIR, SAVEPRE_DIR, SAVEIM_DIR):
        os.makedirs(out_dir, exist_ok=True)

    start = time.time()

    batchgen = BatchGenerator(img_size=img_size,
                              LRDir=TRAIN_LR_DIR,
                              HRDir=TRAIN_HR_DIR,
                              aug=True)
    # NOTE(review): valgen is constructed but never used below; previews
    # are drawn from the training generator.
    valgen = BatchGenerator(img_size=img_size,
                            LRDir=VAL_LR_DIR,
                            HRDir=VAL_HR_DIR,
                            aug=False)

    # Dump one low-res / high-res sample pair for reference.
    IN_, OUT_ = batchgen.getBatch(4)
    IN_ = tileImage(IN_)
    IN_ = cv2.resize(IN_, (img_size * 2 * 4, img_size * 2 * 4),
                     interpolation=cv2.INTER_CUBIC)
    IN_ = (IN_ + 1) * 127.5
    OUT_ = tileImage(OUT_)
    OUT_ = cv2.resize(OUT_, (img_size * 4 * 2, img_size * 4 * 2))
    OUT_ = (OUT_ + 1) * 127.5
    Z_ = np.concatenate((IN_, OUT_), axis=1)
    cv2.imwrite("input.png", Z_)
    print("%s sec took sampling" % (time.time() - start))

    start = time.time()

    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size * 4, img_size * 4, 3])
    lr = tf.placeholder(tf.float32)

    y = buildSRGAN_g(x)
    test_y = buildSRGAN_g(x, reuse=True, isTraining=False)

    fake_y = buildSRGAN_d(y)
    real_y = buildSRGAN_d(t, reuse=True)

    vgg_y1, vgg_y2, vgg_y3, vgg_y4, vgg_y5 = vgg19(y)
    vgg_t1, vgg_t2, vgg_t3, vgg_t4, vgg_t5 = vgg19(t, reuse=True)

    # 1e-10 keeps the logarithms finite.
    d_loss_real = tf.log((real_y) + 1e-10)
    d_loss_fake = tf.log(1 - (fake_y) + 1e-10)
    g_loss_fake = tf.reduce_mean(-tf.log((fake_y) + 1e-10)) * 2e-3

    # NOTE(review): these scopes differ from the "SRGAN_g"/"SRGAN_d" name
    # filters used for the optimizers - confirm they collect the intended
    # regularisers.
    wd_g = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Generator"))
    wd_d = tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                          scope="Discriminator"))

    # Despite the name this is a mean *squared* error between y and t.
    L1_loss = tf.reduce_mean(tf.square(y - t))
    e_1 = tf.reduce_mean(tf.square(vgg_y1 - vgg_t1)) * 2.8
    e_2 = tf.reduce_mean(tf.square(vgg_y2 - vgg_t2)) * 0.2
    e_3 = tf.reduce_mean(tf.square(vgg_y3 - vgg_t3)) * 0.08
    e_4 = tf.reduce_mean(tf.square(vgg_y4 - vgg_t4)) * 0.2
    e_5 = tf.reduce_mean(tf.square(vgg_y5 - vgg_t5)) * 75.0
    vgg_loss = (e_1 + e_2 + e_3 + e_4 + e_5) * 2e-7

    pre_loss = L1_loss + vgg_loss + wd_g
    g_loss = L1_loss + vgg_loss + g_loss_fake + wd_g
    d_loss = tf.reduce_mean(-(d_loss_fake + d_loss_real)) + wd_d

    # Select each sub-network's variables once and reuse the lists.
    g_vars = [v for v in tf.trainable_variables() if "SRGAN_g" in v.name]
    d_vars = [v for v in tf.trainable_variables() if "SRGAN_d" in v.name]
    vgg_vars = [v for v in tf.trainable_variables() if "vgg19" in v.name]

    g_pre = tf.train.AdamOptimizer(1e-4, beta1=0.5).minimize(pre_loss,
                                                             var_list=g_vars)
    g_opt = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss,
                                                           var_list=g_vars)
    d_opt = tf.train.AdamOptimizer(lr / 2, beta1=0.5).minimize(
        d_loss, var_list=d_vars)

    print("%.4f sec took building" % (time.time() - start))
    printParam(scope="SRGAN_g")
    printParam(scope="SRGAN_d")
    printParam(scope="vgg19")

    saver = tf.train.Saver()
    saver_vgg = tf.train.Saver(vgg_vars)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
    if ckpt:  # a full checkpoint exists
        last_model = ckpt.model_checkpoint_path
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")

    # Load the pre-trained VGG weights. They are excluded from every
    # optimizer's var_list, so restoring them on top of a full checkpoint
    # is harmless.
    ckpt_vgg = tf.train.get_checkpoint_state('modelvgg')
    if ckpt_vgg:
        saver_vgg.restore(sess, ckpt_vgg.model_checkpoint_path)
    elif not ckpt:
        raise FileNotFoundError(
            "no VGG checkpoint found in 'modelvgg' and no model checkpoint "
            "to restore from")

    print("%.4e sec took initializing" % (time.time() - start))

    def save_preview(path):
        """Write an input | output | target grid for a fresh batch."""
        images_x, images_t = batchgen.getBatch(bs)
        out = sess.run(test_y, feed_dict={x: images_x})
        X_ = tileImage(images_x[:4])
        Y_ = tileImage(out[:4])
        Z_ = tileImage(images_t[:4])
        # Upscale the low-res grid so all three panels share one height.
        X_ = cv2.resize(X_, (img_size * 2 * 4, img_size * 2 * 4),
                        interpolation=cv2.INTER_CUBIC)
        X_ = (X_ + 1) * 127.5
        Y_ = (Y_ + 1) * 127.5
        Z_ = (Z_ + 1) * 127.5
        cv2.imwrite(path, np.concatenate((X_, Y_, Z_), axis=1))

    def save_loss_plot(path, curves):
        """Plot ``(values, label)`` pairs on a log axis and save to path."""
        fig = plt.figure(figsize=(8, 6), dpi=128)
        ax = fig.add_subplot(111)
        plt.title("Loss")
        plt.grid(which="both")
        plt.yscale("log")
        for values, label in curves:
            ax.plot(values, label=label, linewidth=0.25)
        plt.xlabel('step', fontsize=16)
        plt.ylabel('loss', fontsize=16)
        plt.legend(loc='upper right')
        plt.savefig(path)
        plt.close()

    hist = []
    hist_g = []
    hist_d = []

    start = time.time()
    print("start pretrain")
    for p in range(50001):
        batch_images_x, batch_images_t = batchgen.getBatch(bs)
        _, gen_loss, L1, vgg = sess.run(
            [g_pre, pre_loss, L1_loss, vgg_loss],
            feed_dict={
                x: batch_images_x,
                t: batch_images_t
            })

        hist.append(gen_loss)
        print("in step %s, pre_loss =%.4e, L1_loss=%.4e, vgg_loss=%.4e" %
              (p, gen_loss, L1, vgg))

        if p % 100 == 0:
            save_preview("{}/pre_{}.png".format(SAVEIM_DIR, p))
            save_loss_plot("hist_pre.png", [(hist, "gen_loss")])
            print("%.4e sec took 100steps" % (time.time() - start))
            start = time.time()

        if p % 5000 == 0 and p != 0:
            saver.save(sess, os.path.join(SAVEPRE_DIR, "model.ckpt"), p)

    print("start Discriminator")
    # Discriminator-only warm-up; currently disabled (zero steps).
    for d in range(0):
        batch_images_x, batch_images_t = batchgen.getBatch(bs)
        _, dis_loss = sess.run([d_opt, d_loss],
                               feed_dict={
                                   x: batch_images_x,
                                   t: batch_images_t,
                                   lr: 1e-4,
                               })
        print("in step %s, dis_loss = %.4e" % (d, dis_loss))

    def gan_feed():
        """Fetch a fresh batch and build the feed dict for a GAN step."""
        images_x, images_t = batchgen.getBatch(bs)
        return {x: images_x, t: images_t, lr: trans_lr}

    gen_fetches = [g_opt, g_loss, L1_loss, g_loss_fake, vgg_loss]

    print("start GAN")
    for i in range(100001):
        # Generator, discriminator, generator - each on its own batch.
        sess.run(gen_fetches, feed_dict=gan_feed())
        _, dis_loss = sess.run([d_opt, d_loss], feed_dict=gan_feed())
        _, gen_loss, L1, adv, vgg = sess.run(gen_fetches,
                                             feed_dict=gan_feed())

        # Decay the learning rate until it reaches the 1e-5 floor.
        if trans_lr > 1e-5:
            trans_lr = trans_lr * 0.99998

        print("in step %s, dis_loss = %.4e, gen_loss = %.4e" %
              (i, dis_loss, gen_loss))
        print("L1_loss=%.4e, adv_loss=%.4e, vgg_loss=%.4e" % (L1, adv, vgg))

        hist_g.append(gen_loss)
        hist_d.append(dis_loss)

        if i % 100 == 0:
            save_preview("{}/{}.png".format(SAVEIM_DIR, i))
            save_loss_plot("hist.png", [(hist_g, "gen_loss"),
                                        (hist_d, "dis_loss")])
            print("%.4f sec took per 100steps, lr = %.4e" %
                  (time.time() - start, trans_lr))
            start = time.time()

        if i % 5000 == 0 and i != 0:
            saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Progressively train a GAN over nine resolutions (4px to 1024px).

    One generator/discriminator pair is built per stage. Each stage trains
    for ``steps[stage] + 1`` iterations, fading ``alpha`` from 0 to 1 over
    the first quarter of the stage (stage 1 always uses ``alpha = 1``).
    Samples and the loss plot are written to ``SVIM_DIR/stage<k>`` every 100
    steps; checkpoints go to ``SAVE_DIR`` every 5000 steps.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    img_size = [2**(i + 2) for i in range(9)]
    bs = [64, 64, 32, 32, 32, 16, 8, 4, 4]
    steps = [8000, 10000, 20000, 40000, 50000, 60000, 80000, 90000, 100000]
    z_dim = 512
    num_stages = len(img_size)

    # Dump a grid of real inputs for reference.
    batch = BatchGenerator(img_size=256, datadir=DATASET_DIR)
    IN_ = batch.getBatch(4)
    IN_ = (IN_ + 1) * 127.5
    IN_ = tileImage(IN_)
    cv2.imwrite("{}/input.png".format(SVIM_DIR), IN_)

    z = tf.placeholder(tf.float32, [None, 1, 1, z_dim])
    X_real = [tf.placeholder(tf.float32, [None, r, r, 3]) for r in img_size]
    alpha = tf.placeholder(tf.float32, [])

    X_fake = [
        buildGenerator(z, alpha, stage=i + 1) for i in range(num_stages)
    ]
    fake_y = [
        buildDiscriminator(x, alpha, stage=i + 1, reuse=False)
        for i, x in enumerate(X_fake)
    ]
    real_y = [
        buildDiscriminator(x, alpha, stage=i + 1, reuse=True)
        for i, x in enumerate(X_real)
    ]

    # WGAN-GP: evaluate the discriminator on random interpolates between
    # real and generated images.
    xhats = []
    d_xhats = []
    for i, (real, fake) in enumerate(zip(X_real, X_fake)):
        epsilon = tf.random_uniform(shape=[tf.shape(real)[0], 1, 1, 1],
                                    minval=0.0,
                                    maxval=1.0)
        inter = real * epsilon + fake * (1 - epsilon)
        d_xhat = buildDiscriminator(inter, alpha, stage=i + 1, reuse=True)
        xhats.append(inter)
        d_xhats.append(d_xhat)

    g_losses, d_losses = calc_losses(real_y, fake_y, xhats, d_xhats)

    g_var = [x for x in tf.trainable_variables() if "Generator" in x.name]
    d_var = [x for x in tf.trainable_variables() if "Discriminator" in x.name]
    opt = tf.train.AdamOptimizer(learning_rate=1e-3,
                                 beta1=0.0,
                                 beta2=0.99,
                                 epsilon=1e-8)
    g_opt = [opt.minimize(g_loss, var_list=g_var) for g_loss in g_losses]
    d_opt = [opt.minimize(d_loss, var_list=d_var) for d_loss in d_losses]

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.75))
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # most recently saved model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        # Variables were already initialised above; nothing more to do.
        print("models were not found")

    print("%.4e sec took initializing" % (time.time() - start))

    start = time.time()

    def sample_latent(n):
        """Draw ``n`` latent vectors shaped like the ``z`` placeholder."""
        return np.random.normal(0, 0.5, [n, 1, 1, z_dim])

    for stage in range(num_stages):
        batch = BatchGenerator(img_size=img_size[stage], datadir=DATASET_DIR)

        outdir = os.path.join(SVIM_DIR, 'stage{}'.format(stage + 1))
        os.makedirs(outdir, exist_ok=True)

        # Save a grid of real samples at this stage's resolution.
        x_batch = batch.getBatch(bs[stage], alpha=1.0)
        out = tileImage(x_batch)
        out = np.array((out + 1) * 127.5, dtype=np.uint8)
        cv2.imwrite(os.path.join(outdir, 'sample.png'), out)

        g_hist = []
        d_hist = []

        print("starting stage{}".format(stage + 1))
        for i in range(steps[stage] + 1):
            # alpha reaches 1 after a quarter of the stage.
            delta = 4 * i / (steps[stage])
            if stage == 0:
                alp = 1.0
            else:
                alp = min(delta, 1.0)

            x_batch = batch.getBatch(bs[stage], alpha=alp)

            z_batch = sample_latent(bs[stage])
            _, dis_loss = sess.run([d_opt[stage], d_losses[stage]],
                                   feed_dict={
                                       X_real[stage]: x_batch,
                                       z: z_batch,
                                       alpha: alp
                                   })

            z_batch = sample_latent(bs[stage])
            _, gen_loss = sess.run([g_opt[stage], g_losses[stage]],
                                   feed_dict={
                                       z: z_batch,
                                       alpha: alp
                                   })

            g_hist.append(gen_loss)
            d_hist.append(dis_loss)

            print("in step %s, dis_loss = %.4e, gen_loss = %.4e" %
                  (i, dis_loss, gen_loss))

            if i % 100 == 0:
                # Save a grid of generated samples.
                z_batch = sample_latent(bs[stage])
                out = X_fake[stage].eval(feed_dict={
                    z: z_batch,
                    alpha: alp
                },
                                         session=sess)
                out = tileImage(out)
                out = np.array((out + 1) * 127.5, dtype=np.uint8)
                dst = os.path.join(outdir,
                                   '{}.png'.format('{0:09d}'.format(i)))
                cv2.imwrite(dst, out)

                # Save the loss graph.
                fig = plt.figure(figsize=(8, 6), dpi=128)
                ax = fig.add_subplot(111)
                plt.title("Loss")
                plt.grid(which="both")
                ax.plot(g_hist, label="gen_loss", linewidth=0.25)
                ax.plot(d_hist, label="dis_loss", linewidth=0.25)
                plt.xlabel('step', fontsize=16)
                plt.ylabel('loss', fontsize=16)
                plt.legend(loc='upper right')
                plt.savefig(os.path.join(outdir, "hist.png"))
                plt.close()

            if i % 5000 == 0 and i != 0:
                saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Train the ThermalSR generator with an L1 + SSIM loss.

    Command line:
        --gpu   value assigned to ``CUDA_VISIBLE_DEVICES`` (default ``'0'``).

    Runs 50001 steps. Every 100 steps a validation preview is written to
    ``SAVEIM_DIR``; every 1000 steps the loss curve is saved to
    ``hist_pre.png``; every 5000 steps (except step 0) a checkpoint is
    written to ``SAVEPRE_DIR``, which is also where training resumes from.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        type=str,
                        default='0',
                        help='Which GPU to use')
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    img_size = 64
    bs = 4

    # Make sure the output locations exist before training starts.
    os.makedirs(SAVEIM_DIR, exist_ok=True)
    os.makedirs(SAVEPRE_DIR, exist_ok=True)

    start = time.time()

    batchgen = BatchGenerator(img_size=img_size,
                              LRDir=TRAIN_LR_DIR,
                              HRDir=TRAIN_HR_DIR,
                              aug=True)
    valgen = BatchGenerator(img_size=img_size,
                            LRDir=VAL_LR_DIR,
                            HRDir=VAL_HR_DIR,
                            aug=False)

    # Save one low-res / high-res sample pair for reference.
    IN_, OUT_ = batchgen.getBatch(4)
    print(IN_.shape)
    IN_ = tileImage(IN_)
    IN_ = cv2.resize(IN_, (img_size * 2 * 4, img_size * 2 * 4),
                     interpolation=cv2.INTER_CUBIC)
    IN_ = (IN_ + 1) * 127.5
    OUT_ = tileImage(OUT_)
    OUT_ = cv2.resize(OUT_, (img_size * 4 * 2, img_size * 4 * 2))
    OUT_ = (OUT_ + 1) * 127.5
    Z_ = np.concatenate((IN_, OUT_), axis=1)
    cv2.imwrite("input.png", Z_)
    print("%s sec took sampling" % (time.time() - start))

    start = time.time()

    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size * 4, img_size * 4, 3])

    generator = Generator()
    y = generator.ThermalSR(x)
    test_y = generator.ThermalSR(x, reuse=True, isTraining=False)

    # L1 loss.
    L1_loss = tf.losses.absolute_difference(y, t)

    # SSIM loss; 2.0 is passed as the max_val argument of tf.image.ssim.
    ssim_ = tf.reduce_mean(tf.image.ssim(y, t, 2.0))
    ssim_loss = 1 - ssim_

    # Total loss.
    Total_loss = L1_loss + ssim_loss

    g_vars = [v for v in tf.trainable_variables() if "ThermalSR" in v.name]
    train_op = tf.train.AdamOptimizer(1e-4, beta1=0.5).minimize(
        Total_loss, var_list=g_vars)

    print("%.4f sec took building" % (time.time() - start))
    printParam(scope="ThermalSR")

    saver = tf.train.Saver(max_to_keep=15)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    ckpt = tf.train.get_checkpoint_state(SAVEPRE_DIR)
    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    # Without a checkpoint the variables keep the initial values set above.

    print("%.4e sec took initializing" % (time.time() - start))

    hist = []

    start = time.time()
    print("start pretrain")
    for p in range(50001):
        batch_images_x, batch_images_t = batchgen.getBatch(bs)
        _, gen_loss, l1, ssim = sess.run(
            [train_op, Total_loss, L1_loss, ssim_loss],
            feed_dict={
                x: batch_images_x,
                t: batch_images_t
            })

        hist.append(gen_loss)
        print("in step %s, pre_loss =%.4e, l1_loss=%.4e, ssim_loss=%.4e" %
              (p, gen_loss, l1, ssim))

        if p % 100 == 0:
            # Preview on validation data: input | output | target.
            batch_images_x, batch_images_t = valgen.getBatch(bs)
            out = sess.run(test_y, feed_dict={x: batch_images_x})

            X_ = tileImage(batch_images_x[:4])
            Y_ = tileImage(out[:4])
            Z_ = tileImage(batch_images_t[:4])

            # Upscale the low-res grid so all three panels share one height.
            X_ = cv2.resize(X_, (img_size * 2 * 4, img_size * 2 * 4),
                            interpolation=cv2.INTER_CUBIC)
            X_ = (X_ + 1) * 127.5
            Y_ = (Y_ + 1) * 127.5
            Z_ = (Z_ + 1) * 127.5
            ZZ_ = np.concatenate((X_, Y_, Z_), axis=1)

            cv2.imwrite("{0}/pre_{1:06d}.png".format(SAVEIM_DIR, int(p)), ZZ_)
            print("%.4e sec took 100steps" % (time.time() - start))
            start = time.time()

        if p % 1000 == 0:
            fig = plt.figure(figsize=(8, 6), dpi=128)
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            plt.yscale("log")
            ax.plot(hist, label="gen_loss", linewidth=0.25)
            plt.xlabel('step', fontsize=16)
            plt.ylabel('loss', fontsize=16)
            plt.legend(loc='upper right')
            plt.savefig("hist_pre.png")
            plt.close()

        if p % 5000 == 0 and p != 0:
            saver.save(sess, os.path.join(SAVEPRE_DIR, "model.ckpt"), p)
def main():
    """Progressively train a GAN with a softplus loss and R1 penalty.

    One generator/discriminator pair is built per resolution (4px to
    1024px). Stages 1-6 read images from ``ffhq_dataset128`` and later
    stages from ``ffhq_dataset``. Within a stage the learning rate starts at
    1e-3 and decays once ``alpha`` has reached 1. Samples and the loss plot
    are written to ``SVIM_DIR/stage<k>`` every 100 steps; checkpoints go to
    ``SAVE_DIR`` every 8000 steps.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    img_size = [2**(i + 2) for i in range(9)]
    bs = [16, 16, 16, 16, 12, 8, 4, 3, 2]
    # The first stage runs for a single step only.
    steps = [1, 16000, 24000, 40000, 64000, 96000, 128000, 192000, 320000]
    z_dim = 512
    r1_gamma = 10.0  # R1 gradient-penalty weight
    num_stages = len(img_size)

    # Save sample images.
    batch = BatchGenerator(img_size=512, datadir=DATASET_DIR)
    IN_ = batch.getBatch(4)
    IN_ = (IN_ + 1) * 127.5
    IN_ = tileImage(IN_)
    cv2.imwrite("{}/input.png".format(SVIM_DIR), IN_)

    z = tf.placeholder(tf.float32, [None, z_dim])
    X_real = [tf.placeholder(tf.float32, [None, r, r, 3]) for r in img_size]
    alpha = tf.placeholder(tf.float32, [])
    X_fake = [
        buildGenerator(z, alpha, stage=i + 1) for i in range(num_stages)
    ]
    fake_y = [
        buildDiscriminator(x, alpha, stage=i + 1, reuse=False)
        for i, x in enumerate(X_fake)
    ]
    real_y = [
        buildDiscriminator(x, alpha, stage=i + 1, reuse=True)
        for i, x in enumerate(X_real)
    ]
    lr = tf.placeholder(tf.float32, [])

    # Softplus (logistic) GAN losses, one pair per stage.
    g_losses = []
    d_losses = []
    for real_images, real_logit, fake_logit in zip(X_real, real_y, fake_y):
        # Discriminator loss: logistic loss plus a gradient penalty on the
        # real images.
        d_loss_gan = tf.nn.softplus(fake_logit) + tf.nn.softplus(-real_logit)
        real_loss = tf.reduce_sum(real_logit)
        real_grads = tf.gradients(real_loss, [real_images])[0]
        r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3])
        d_loss = d_loss_gan + r1_penalty * (r1_gamma * 0.5)
        d_loss = tf.reduce_mean(d_loss)

        # Generator loss: logistic non-saturating.
        g_loss = tf.reduce_mean(tf.nn.softplus(-fake_logit))

        g_losses.append(g_loss)
        d_losses.append(d_loss)

    g_var = [x for x in tf.trainable_variables() if "Generator" in x.name]
    d_var = [x for x in tf.trainable_variables() if "Discriminator" in x.name]
    opt = tf.train.AdamOptimizer(learning_rate=lr,
                                 beta1=0.0,
                                 beta2=0.99,
                                 epsilon=1e-8)
    g_opt = [opt.minimize(g_loss, var_list=g_var) for g_loss in g_losses]
    d_opt = [opt.minimize(d_loss, var_list=d_var) for d_loss in d_losses]

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.75))
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
    if ckpt:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # most recently saved model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable values
        print("succeed restore model")
    else:
        # Variables were already initialised above; nothing more to do.
        print("models were not found")

    print("%.4f sec took initializing" % (time.time() - start))

    start = time.time()
    for stage in range(num_stages):
        if stage < 6:
            batch = BatchGenerator(img_size=img_size[stage],
                                   datadir="ffhq_dataset128")
        else:
            batch = BatchGenerator(img_size=img_size[stage],
                                   datadir="ffhq_dataset")

        outdir = os.path.join(SVIM_DIR, 'stage{}'.format(stage + 1))
        os.makedirs(outdir, exist_ok=True)

        # Save a grid of real samples at this stage's resolution.
        x_batch = batch.getBatch(bs[stage], alpha=1.0)
        out = tileImage(x_batch)
        out = np.array((out + 1) * 127.5, dtype=np.uint8)
        cv2.imwrite(os.path.join(outdir, 'sample.png'), out)

        trans_lr = 1e-3
        g_hist = []
        d_hist = []

        print("starting stage{}".format(stage + 1))
        for i in range(steps[stage] + 1):
            # alpha reaches 1 after roughly a quarter of the stage.
            delta = 4 * i / (steps[stage] + 1)

            # NOTE(review): the original comment said the *first* stage
            # needs no interpolation, but the condition selects the second
            # and third stages (indices 1 and 2) - confirm intent.
            if stage == 1 or stage == 2:
                alp = 1.0
            else:
                alp = min(delta, 1.0)

            x_batch = batch.getBatch(bs[stage], alpha=alp)

            z_batch = np.random.normal(0, 0.5, [bs[stage], z_dim])
            _, dis_loss = sess.run([d_opt[stage], d_losses[stage]],
                                   feed_dict={
                                       X_real[stage]: x_batch,
                                       z: z_batch,
                                       alpha: alp,
                                       lr: trans_lr
                                   })

            z_batch = np.random.normal(0, 0.5, [bs[stage], z_dim])
            _, gen_loss = sess.run([g_opt[stage], g_losses[stage]],
                                   feed_dict={
                                       z: z_batch,
                                       alpha: alp,
                                       lr: trans_lr
                                   })

            g_hist.append(gen_loss)
            d_hist.append(dis_loss)

            print(
                "stage:[%d], in step %s, dis_loss = %.3e, gen_loss = %.3e, alpha = %.3f, lr = %.3e"
                % (stage + 1, i, dis_loss, gen_loss, alp, trans_lr))

            # Decay the learning rate only once the fade-in has finished.
            if alp == 1.0:
                trans_lr *= (1 - 2 / steps[stage])

            if i % 100 == 0:
                z_batch = np.random.normal(0, 0.5, [bs[stage], z_dim])
                out = X_fake[stage].eval(feed_dict={
                    z: z_batch,
                    alpha: alp
                },
                                         session=sess)
                out = tileImage(out)
                out = np.array((out + 1) * 127.5, dtype=np.uint8)
                dst = os.path.join(outdir,
                                   '{}_alp.png'.format('{0:09d}'.format(i)))
                cv2.imwrite(dst, out)

                fig = plt.figure(figsize=(8, 6), dpi=128)
                ax = fig.add_subplot(111)
                plt.title("Loss")
                plt.grid(which="both")
                plt.yscale("log")
                ax.plot(g_hist, label="gen_loss", linewidth=0.25)
                ax.plot(d_hist, label="dis_loss", linewidth=0.25)
                plt.xlabel('step', fontsize=16)
                plt.ylabel('loss', fontsize=16)
                plt.legend(loc='upper right')
                plt.savefig(os.path.join(outdir, "hist.png"))
                plt.close()

            if i % 8000 == 0 and i != 0:
                saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Train the image-to-image generator with loss-based early stopping.

    Runs up to 100000 steps. Every 100 steps a validation preview is
    written to ``SVIM_DIR`` and the loss curve to ``hist.png``. Every 1000
    steps the mean loss of the last 1000 steps is compared with that of the
    1000 before; training stops if it has more than doubled, otherwise a
    checkpoint is written to ``SAVE_DIR``.
    """
    for out_dir in (SAVE_DIR, SVIM_DIR):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    img_size = 256
    bs = 16

    datalen = foloderLength(DATASET_DIR)
    vallen = foloderLength(VAL_DIR)

    # Batch generators for the training and validation images.
    batch = BatchGenerator(img_size=img_size, datadir=DATASET_DIR)
    val = BatchGenerator(img_size=img_size, datadir=VAL_DIR)

    # Dump a grid of training inputs for reference.
    picked = np.random.choice(range(datalen), bs)
    grid = tileImage(batch.getBatch(bs, picked)[:4])
    grid = (grid + 1) * 127.5
    cv2.imwrite("input.png", grid)

    start = time.time()

    image_shape = [bs, img_size, img_size, 3]
    x = tf.placeholder(tf.float32, image_shape)
    t = tf.placeholder(tf.float32, image_shape)

    y = buildGenerator(x, nBatch=bs)
    loss = loss_g(y, t)
    printParam(scope="generator")
    train_step = training(loss)

    init_op = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init_op)

    saver = tf.train.Saver()
    _ = tf.summary.merge_all()

    ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
    if not ckpt:
        print("models were not found")
        # Initialise a second time, exactly as the original script did.
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    else:
        # A checkpoint exists: restore the most recently saved model.
        last_model = ckpt.model_checkpoint_path
        print("load " + last_model)
        saver.restore(sess, last_model)
        print("succeed restore model")

    print("%.4e sec took initializing" % (time.time() - start))

    hist = []
    start = time.time()

    for i in range(100000):
        # The same indices are used for the input and the target batch.
        picked = np.random.choice(range(datalen), bs)
        batch_images_x = batch.getBatch(bs, picked, ocp=0.5)
        batch_images_t = batch.getBatch(bs, picked, ocp=0.5)

        feed = {x: batch_images_x, t: batch_images_t}
        _, yloss = sess.run([train_step, loss], feed_dict=feed)
        print("in step %s loss = %.4e" % (i, yloss))
        hist.append(yloss)

        if i % 100 == 0:
            # Validation preview: input grid next to generated grid.
            picked = np.random.choice(range(vallen), bs)
            batch_images_x = val.getBatch(bs, picked, ocp=0.5)
            out = sess.run(y, feed_dict={x: batch_images_x})

            X_ = (tileImage(batch_images_x[:4]) + 1) * 127.5
            Y_ = (tileImage(out[:4]) + 1) * 127.5
            Z_ = np.concatenate((X_, Y_), axis=1)
            cv2.imwrite("{}/{}.png".format(SVIM_DIR, i), Z_)

            fig = plt.figure()
            ax = fig.add_subplot(111)
            plt.title("Loss")
            plt.grid(which="both")
            plt.yscale("log")
            ax.plot(hist, label="test", linewidth=0.5)
            plt.savefig("hist.png")
            plt.close()

            print("%.4e sec took per 100steps" % (time.time() - start))
            start = time.time()

        if i % 1000 != 0:
            continue

        if i > 1900:
            loss_1k_old = np.mean(hist[-2000:-1000])
            loss_1k_new = np.mean(hist[-1000:])
            print("old loss=%.4e , new loss=%.4e" %
                  (loss_1k_old, loss_1k_new))
            # Stop when the recent mean loss has more than doubled.
            if loss_1k_old * 2 < loss_1k_new:
                break

        saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Train the conditional GAN (generator + discriminator) with an L1 term.

    The generator comes from ``buildGenerator`` and the discriminator from
    ``buildDiscriminator``; both adversarial losses use sigmoid cross-entropy.
    The generator additionally minimises ``lmd * L1(y, t)``, and each network
    adds its own regularisation losses.

    During training:
      * the learning rate and the L1 weight are decayed a little every step
        until they reach their floors (5e-5 and 5 respectively);
      * every 100 steps a validation preview (input, output, target) is
        written to ``SVIM_DIR`` and ``hist.png`` is refreshed;
      * every 5000 steps a checkpoint is saved to ``SAVE_DIR``, unless the
        mean generator loss of the last 1000 steps is more than twice the
        mean of the 1000 steps before it, in which case training stops
        without saving.
    """
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(SAVE_DIR, exist_ok=True)
    os.makedirs(SVIM_DIR, exist_ok=True)

    img_size = 256
    bs = 4
    # Learning rate and L1 weight are fed per step so they can be decayed.
    lr = tf.placeholder(tf.float32)
    lmd = tf.placeholder(tf.float32)
    trans_lr = 2e-4
    trans_lmd = 10
    max_step = 100000

    datalen = foloderLength(DATASET_DIR)
    vallen = foloderLength(VAL_DIR)

    # Batch sources for training and validation image pairs.
    batch = BatchGenerator(img_size=img_size, datadir=DATASET_DIR)
    val = BatchGenerator(img_size=img_size, datadir=VAL_DIR)

    # Dump one training batch (inputs next to targets) as a sanity check.
    # assumes getBatch yields pixel values in [-1, 1] -- TODO confirm
    ids = np.random.choice(datalen, bs)
    IN_, OUT_ = batch.getBatch(bs, ids)
    IN_ = (IN_ + 1) * 127.5
    IN_ = tileImage(IN_)
    OUT_ = (OUT_ + 1) * 127.5
    OUT_ = tileImage(OUT_)
    Z_ = np.concatenate([IN_, OUT_], axis=1)
    cv2.imwrite("input.png", Z_)

    # Graph: x is the network input, t the target image.
    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])

    y = buildGenerator(x)
    fake_y = buildDiscriminator(x, y, isTraining=True, nBatch=bs)
    real_y = buildDiscriminator(x, t, reuse=True, isTraining=True, nBatch=bs)

    # Sigmoid cross-entropy GAN losses.
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=real_y,
                                                labels=tf.ones_like(real_y)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_y,
                                                labels=tf.zeros_like(fake_y)))
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_y,
                                                labels=tf.ones_like(fake_y)))

    # Regularisation losses registered under each network's scope.
    wd_g = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             scope="Generator")
    wd_d = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             scope="Discriminator")
    wd_g = tf.reduce_sum(wd_g)
    wd_d = tf.reduce_sum(wd_d)

    L1_loss = tf.reduce_mean(tf.abs(y - t))
    d_loss = d_loss_real + d_loss_fake + wd_d
    g_loss = g_loss + lmd * L1_loss + wd_g

    # Each optimiser only updates its own network's variables. The
    # comprehension variable is named `v` so it cannot be confused with the
    # placeholder `x`.
    g_vars = [v for v in tf.trainable_variables() if "Generator" in v.name]
    d_vars = [v for v in tf.trainable_variables() if "Discriminator" in v.name]
    g_opt = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss,
                                                           var_list=g_vars)
    # The discriminator trains with a fifth of the generator's learning rate.
    d_opt = tf.train.AdamOptimizer(lr / 5, beta1=0.5).minimize(
        d_loss, var_list=d_vars)

    printParam(scope="Generator")
    printParam(scope="Discriminator")

    start = time.time()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.66))

    # The config must be handed to the session or the memory cap above is
    # never applied. The context manager releases the session even on the
    # early-stop break.
    with tf.Session(config=config) as sess:
        # Initialise once; a restored checkpoint overwrites these values.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # NOTE(review): checkpoints are looked up in the literal 'model'
        # directory but written to SAVE_DIR below -- confirm they are the
        # same directory, otherwise training never resumes from its own saves.
        ckpt = tf.train.get_checkpoint_state('model')
        if ckpt:
            # A checkpoint exists: resume from the most recent one.
            last_model = ckpt.model_checkpoint_path
            print("load " + last_model)
            saver.restore(sess, last_model)
            print("succeed restore model")
        else:
            print("models were not found")

        print("%.4e sec took initializing" % (time.time() - start))

        g_hist = []
        d_hist = []

        start = time.time()

        for i in range(max_step + 1):
            ids = np.random.choice(datalen, bs)
            batch_images_x, batch_images_t = batch.getBatch(bs, ids)
            # Both updates see the same batch and hyper-parameters.
            feed = {
                x: batch_images_x,
                t: batch_images_t,
                lr: trans_lr,
                lmd: trans_lmd
            }

            # One discriminator update followed by one generator update.
            _, dis_loss = sess.run([d_opt, d_loss], feed_dict=feed)
            _, gen_loss, l1 = sess.run([g_opt, g_loss, L1_loss],
                                       feed_dict=feed)

            # Decay the learning rate and the L1 weight down to their floors.
            if trans_lr > 5e-5:
                trans_lr = trans_lr * 0.99998
            if trans_lmd > 5:
                trans_lmd = trans_lmd * 0.9998

            print("in step %s, dis_loss = %.4e, gen_loss = %.4e, l1_loss= %.4e" %
                  (i, dis_loss, gen_loss, l1 * trans_lmd))

            g_hist.append(gen_loss)
            d_hist.append(dis_loss)

            if i % 100 == 0:
                # Validation preview: input, generated output and target.
                ids = np.random.choice(vallen, bs)
                batch_images_x, batch_images_t = val.getBatch(bs, ids)

                out = sess.run(y, feed_dict={x: batch_images_x})
                X_ = tileImage(batch_images_x[:4])
                Y_ = tileImage(out[:4])
                Z_ = tileImage(batch_images_t[:4])

                X_ = (X_ + 1) * 127.5
                Y_ = (Y_ + 1) * 127.5
                Z_ = (Z_ + 1) * 127.5
                Z_ = np.concatenate((X_, Y_, Z_), axis=1)
                cv2.imwrite("{}/{}.png".format(SVIM_DIR, i), Z_)

                # Refresh the loss-history plot for both networks.
                fig = plt.figure(figsize=(8, 6), dpi=128)
                ax = fig.add_subplot(111)
                plt.title("Loss")
                plt.grid(which="both")
                plt.yscale("log")
                ax.plot(g_hist, label="gen_loss", linewidth=0.25)
                ax.plot(d_hist, label="dis_loss", linewidth=0.25)
                plt.xlabel('step', fontsize=16)
                plt.ylabel('loss', fontsize=16)
                plt.legend(loc='upper right')
                plt.savefig("hist.png")
                plt.close()

                print("%.4f sec took per 100steps lmd = %.4e, lr = %.4e" %
                      (time.time() - start, trans_lmd, trans_lr))
                start = time.time()

            if i % 5000 == 0:
                if i > 10000:
                    # Divergence guard: compare the last two 1000-step windows.
                    loss_1k_old = np.mean(g_hist[-2000:-1000])
                    loss_1k_new = np.mean(g_hist[-1000:])
                    print("old loss=%.4e , new loss=%.4e" %
                          (loss_1k_old, loss_1k_new))
                    if loss_1k_old * 2 < loss_1k_new:
                        break

                saver.save(sess, os.path.join(SAVE_DIR, "model.ckpt"), i)
def main():
    """Pre-train the ThermalSR generator with L1, SSIM and contextual losses.

    Command line:
        --gpu  value written to ``CUDA_VISIBLE_DEVICES`` (default ``'0'``).

    The total loss is ``10 * L1 + 10 * (1 - SSIM) + 0.1 * CX`` where CX is the
    contextual loss computed by ``CX_loss_helper`` on features returned by
    ``build_vgg19``. If ``SAVEPRE_DIR`` already holds a checkpoint it is
    restored before training.

    During training:
      * every 100 steps a validation preview (upscaled input, output, target)
        is written to ``SAVEIM_DIR``;
      * every 1000 steps the loss-history plot is refreshed;
      * every 5000 steps (except step 0) a checkpoint is saved to
        ``SAVEPRE_DIR``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        type=str,
                        default='0',
                        help='Which GPU to use')
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Make sure the output directories exist before anything is written:
    # cv2.imwrite and Saver.save do not create missing directories.
    os.makedirs(SAVEPRE_DIR, exist_ok=True)
    os.makedirs(SAVEIM_DIR, exist_ok=True)

    img_size = 64
    bs = 4
    trans_lr = 1e-4
    pretrain_steps = 50000

    batchgen = BatchGenerator(img_size=img_size,
                              LRDir=TRAIN_LR_DIR,
                              HRDir=TRAIN_HR_DIR,
                              aug=True)
    valgen = BatchGenerator(img_size=img_size,
                            LRDir=VAL_LR_DIR,
                            HRDir=VAL_HR_DIR,
                            aug=False)

    start = time.time()

    # Graph: x is the low-resolution input, t the 4x larger target.
    x = tf.placeholder(tf.float32, [bs, img_size, img_size, 3])
    t = tf.placeholder(tf.float32, [bs, img_size * 4, img_size * 4, 3])
    lr = tf.placeholder(tf.float32)

    generator = Generator()
    y = generator.ThermalSR(x)
    # Same weights, built with isTraining=False, used for validation previews.
    test_y = generator.ThermalSR(x, reuse=True, isTraining=False)

    # Contextual loss on the first feature map returned by build_vgg19; the
    # second one is not used.
    vgg_real34, _ = build_vgg19(t)
    vgg_fake34, _ = build_vgg19(y)

    CX_loss_content_list = CX_loss_helper(vgg_real34, vgg_fake34, config.CX)
    CX_content_loss = tf.reduce_sum(CX_loss_content_list)
    CX_content_loss *= config.W.CX_content

    L1_loss = tf.losses.absolute_difference(y, t)
    # max_val=2.0 -- assumes images span [-1, 1]; TODO confirm
    ssim_mean = tf.reduce_mean(tf.image.ssim(y, t, 2.0))
    ssim_loss = 1 - ssim_mean

    Total_loss = 10 * L1_loss + 10 * ssim_loss + 0.1 * CX_content_loss

    # Only the generator's variables are optimised. The learning rate is fed
    # through the `lr` placeholder so that `trans_lr` actually controls it.
    g_vars = [v for v in tf.trainable_variables() if "ThermalSR" in v.name]
    train_op = tf.train.AdamOptimizer(lr, beta1=0.9).minimize(Total_loss,
                                                              var_list=g_vars)

    print("%.4f sec took building" % (time.time() - start))
    printParam(scope="ThermalSR")

    saver = tf.train.Saver()

    # The context manager releases the session when training ends.
    with tf.Session() as sess:
        # Initialise once; a restored checkpoint overwrites these values.
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(SAVEPRE_DIR)
        if ckpt:
            # A checkpoint exists: resume from the most recent one.
            last_model = ckpt.model_checkpoint_path
            print("load " + last_model)
            saver.restore(sess, last_model)
            print("succeed restore model")

        print("%.4e sec took initializing" % (time.time() - start))

        hist = []

        start = time.time()
        print("start pretrain")
        for p in range(pretrain_steps + 1):
            batch_images_x, batch_images_t = batchgen.getBatch(bs)
            # `ssim` is the 1 - SSIM term that is optimised, so the
            # "ssim_loss" field in the log line below is a real loss value.
            _, gen_loss, l1, ssim, cx = sess.run(
                [train_op, Total_loss, L1_loss, ssim_loss, CX_content_loss],
                feed_dict={
                    x: batch_images_x,
                    t: batch_images_t,
                    lr: trans_lr
                })

            hist.append(gen_loss)
            print(
                "in step %s, pre_loss =%.4e, l1_loss=%.4e, ssim_loss=%.4e, cx_loss=%.4e"
                % (p, gen_loss, l1, ssim, cx))

            if p % 100 == 0:
                # Validation preview: upscaled input, output and target.
                batch_images_x, batch_images_t = valgen.getBatch(bs)

                out = sess.run(test_y, feed_dict={x: batch_images_x})
                X_ = tileImage(batch_images_x[:4])
                Y_ = tileImage(out[:4])
                Z_ = tileImage(batch_images_t[:4])
                # Enlarge the low-resolution tile so it can be concatenated
                # with the output and target tiles.
                X_ = cv2.resize(X_, (img_size * 2 * 4, img_size * 2 * 4),
                                interpolation=cv2.INTER_CUBIC)
                X_ = (X_ + 1) * 127.5
                Y_ = (Y_ + 1) * 127.5
                Z_ = (Z_ + 1) * 127.5
                ZZ_ = np.concatenate((X_, Y_, Z_), axis=1)

                cv2.imwrite("{0}/pre_{1:06d}.png".format(SAVEIM_DIR, p), ZZ_)

                print("%.4e sec took 100steps" % (time.time() - start))
                start = time.time()

            if p % 1000 == 0:
                # Refresh the loss-history plot.
                fig = plt.figure(figsize=(8, 6), dpi=128)
                ax = fig.add_subplot(111)
                plt.title("Loss")
                plt.grid(which="both")
                plt.yscale("log")
                ax.plot(hist, label="gen_loss", linewidth=0.25)
                plt.xlabel('step', fontsize=16)
                plt.ylabel('loss', fontsize=16)
                plt.legend(loc='upper right')
                plt.savefig("hist_pre_ThermalSR_Axis.png")
                plt.close()

            if p % 5000 == 0 and p != 0:
                saver.save(sess, os.path.join(SAVEPRE_DIR, "model.ckpt"), p)