def Train(X_images, Y_images, test_images, learning_rate, epochs, batch_size):
    """Train the super-resolution network by maximising SSIM and save checkpoints.

    Builds a TensorFlow 1.x graph (FeatureExtraction -> Shrinking ->
    NonLinearMapping -> Expanding -> Deconvolution). It then optimises the
    graph with Adam over mini-batches of ``X_images`` / ``Y_images``. Both
    arrays go through ``Image.Shuffle`` at the start of every epoch.
    Presumably this keeps the pairs aligned — TODO confirm.

    Args:
        X_images: Input images. Batches are fed to a placeholder of shape
            ``(None, h, w, 1)``, so this must be indexable as ``[n, h, w, 1]``.
        Y_images: Target images, fed to a placeholder of shape
            ``(None, H, W, 1)``. Row ``i`` is fed together with
            ``X_images[i]``.
        test_images: Not used by this function.
        learning_rate: Adam learning rate.
        epochs: Number of passes over the training data.
        batch_size: Maximum number of image pairs per training step. The last
            batch of an epoch may be smaller.

    Side effects:
        - Prints the batch-mean SSIM after every step of every 20th epoch.
        - Writes ``./model-<epoch>.ckpt`` every 200 epochs (epoch 0 excluded).
        - Writes ``./model-final.ckpt`` once training ends.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # A non-positive batch size would never advance through the data.
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got ' + str(batch_size))

    input_h, input_w = np.shape(X_images)[1:3]
    output_h, output_w = np.shape(Y_images)[1:3]

    X_train = tf.placeholder(tf.float16, shape=(None, input_h, input_w, 1), name='X')
    Y_train = tf.placeholder(tf.float16, shape=(None, output_h, output_w, 1), name='Y')

    conv1 = FeatureExtraction(X_train)
    conv2 = Shrinking(conv1)
    conv3 = NonLinearMapping(conv2)
    conv4 = Expanding(conv3)
    # NOTE(review): the graph is built once, so the batch size passed to
    # Deconvolution is fixed here and does not track the mini-batch size fed
    # at run time. If Deconvolution uses it for a static output shape,
    # batches larger than one may fail — TODO confirm against its definition.
    graph_batch_size = 1
    Y_predict = Deconvolution(conv4, graph_batch_size, output_h, output_w)
    # Gives the output a stable tensor name. Presumably this lets callers
    # fetch it by name after restoring a checkpoint — TODO confirm.
    Y_predict = tf.identity(Y_predict, "Y_predict")

    # tf.image.ssim returns one SSIM value per image in the batch.
    # max_val=2.0 is the dynamic range of pixel values (e.g. [-1, 1]); this
    # assumes Image.Normalize produces that range — TODO confirm.
    ssim = tf.image.ssim(Y_predict, Y_train, max_val=2.0)
    # Maximise SSIM by minimising its negated batch mean. Averaging replaces
    # the implicit sum minimize() applies to a non-scalar tensor, which keeps
    # the gradient scale independent of the batch size.
    loss = -tf.reduce_mean(ssim)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    config = tf.ConfigProto()
    # Expose only physical GPU 1 to this process (TF then sees it as /gpu:0).
    config.gpu_options.visible_device_list = '1'
    # Allocate GPU memory on demand, limited to 90% of the device.
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9

    total = np.shape(X_images)[0]
    with tf.Session(config=config) as sess:
        sess.run(init)
        iter_count = 0  # global step counter across all epochs
        for epoch in range(epochs):
            X_images, Y_images = Image.Shuffle(X_images, Y_images)
            report = epoch % 20 == 0
            for start in range(0, total, batch_size):
                end = min(start + batch_size, total)
                feed = {X_train: X_images[start:end], Y_train: Y_images[start:end]}
                sess.run(train_step, feed_dict=feed)
                if report:
                    # The value logged as "loss" is the batch-mean SSIM after
                    # this update (higher is better).
                    loss_value = np.mean(sess.run(ssim, feed_dict=feed))
                    print('In epoch: ' + str(epoch) + ' iteration = ' +
                          str(iter_count) + ', loss = ' + str(loss_value))
                iter_count += 1
            if epoch > 0 and epoch % 200 == 0:
                saver.save(sess, './model-' + str(epoch) + '.ckpt')
        saver.save(sess, './model-final' + '.ckpt')
import tensorflow as tf
import numpy as np
from skimage.io import imread, imsave

from Image import Image
import SuperResolution

# Driver script: run SuperResolution.Test on one prepared input patch.
#
# Steps:
# - Load `dataset_size` grey images from the X2 (input) and HR (target)
#   training folders.
# - Normalise both sets.
# - Segment them with sizes 256 (input) and 512 (target).
# - Shuffle the two sets together.
# - Add a dimension to each.
# - Save the first target item to ./y.png.
# - Run SuperResolution.Test on the first input item with the model stored
#   in `model_path`.

# Not referenced below. It appears to mirror the "30" in model_path's
# directory name — TODO confirm.
batch_size = 30
dataset_size = 1
model_path = './Models/128_32_300_30_100_ssim/'

# The X pipeline step always runs before the matching Y step.
X_grey, Y_grey = (
    Image.LoadTrainingGreyImage(dataset_size, folder)
    for folder in ('./Training/X2_grey/', './Training/HR_grey/')
)
X_norm, Y_norm = map(Image.Normalize, (X_grey, Y_grey))
X_cropped, Y_cropped = Image.Segment(X_norm, 256), Image.Segment(Y_norm, 512)

# Shuffle takes and returns both sets; presumably it keeps input/target
# pairs aligned — TODO confirm.
X_shuffle, Y_shuffle = Image.Shuffle(X_cropped, Y_cropped)
X_final, Y_final = map(Image.ExpandDims, (X_shuffle, Y_shuffle))

Image.SaveOutput([Y_final[0]], './y.png')
SuperResolution.Test(X_final[0], model_path, './')