def main(argv=None):
    """Build the reconstruction network and train it on the images in Train_Image_Dir.

    Relies on module-level names defined elsewhere in the file: BuildNet,
    ImageReader, train, Vgg_Model_Dir, Train_Image_Dir, logs_dir,
    TrainLossTxtFile, learning_rate, MAX_ITERATION, Batch_Size, Im_Hight,
    Im_Width and SamplingRate.

    Each iteration feeds a batch of (sparsely sampled image, full image) pairs
    produced by ImageReader.LoadImages and runs one optimisation step on the
    mean absolute (L1) difference between the network output and the full
    image. Every 10 iterations the loss (dropout disabled) is printed and
    appended to TrainLossTxtFile; every 200 iterations a checkpoint is saved
    in logs_dir. If logs_dir already contains a checkpoint, training resumes
    from it.

    Args:
        argv: Unused.
    """
    # ---------------- Graph ----------------
    # The tensor name "keep_probabilty" is kept as-is so existing checkpoints
    # and other scripts that look it up by name keep working.
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")  # Dropout keep probability
    # Network input: sparsely sampled image(s), 3 channels.
    Sparse_Sampled_Image = tf.placeholder(
        tf.float32, shape=[None, None, None, 3], name="input_Sparse_image")
    # Ground truth: the full image with all pixels filled, 3 channels.
    Full_Image = tf.placeholder(
        tf.float32, shape=[None, None, None, 3], name="Full_image")
    ReconstructImage = BuildNet.inference(
        Sparse_Sampled_Image, keep_probability, 3, Vgg_Model_Dir)
    # L1 loss: mean absolute difference between reconstruction and ground truth.
    loss = tf.reduce_mean(tf.abs(ReconstructImage - Full_Image, name="L1_Loss"))
    trainable_var = tf.trainable_variables()
    train_op = train(loss, trainable_var)

    # ---------------- Training image list ----------------
    print("Reading images list")
    image_extensions = ('.PNG', '.JPG', '.TIF', '.GIF',
                        '.png', '.jpg', '.tif', '.gif')
    TrainImages = [
        each for each in os.listdir(Train_Image_Dir)
        if each.endswith(image_extensions)
    ]
    print('Number of Train images=' + str(len(TrainImages)))
    if not TrainImages:
        # Without images every batch would be empty (batch_size == 0).
        print("No training images found in " + Train_Image_Dir)
        return

    # ---------------- Training ----------------
    with tf.Session() as sess:
        print("Setting up Saver...")
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(logs_dir)
        if ckpt and ckpt.model_checkpoint_path:  # Resume from an existing trained model
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        # Start a fresh loss log (overwrites any previous file).
        with open(TrainLossTxtFile, "w") as f:
            f.write("Iteration\tTrain_Loss\t Learning Rate=" + str(learning_rate))

        Nimg = 0  # Index of the next image in TrainImages to load
        Epoch = 0
        for itr in range(1, MAX_ITERATION + 1):
            if Nimg >= len(TrainImages):  # Every image has been used: end of an epoch
                Nimg = 0
                random.shuffle(TrainImages)  # Shuffle images every epoch
                Epoch += 1
                print("Epoch " + str(Epoch) + " Completed")

            # Load one batch; the last batch of an epoch may be smaller than Batch_Size.
            batch_size = min(Batch_Size, len(TrainImages) - Nimg)
            # float32 matches the dtype of the placeholders these arrays are fed to.
            FullImages = np.zeros([batch_size, Im_Hight, Im_Width, 3], dtype=np.float32)
            SparseSampledImages = np.zeros([batch_size, Im_Hight, Im_Width, 3], dtype=np.float32)
            for fi in range(batch_size):
                FullImages[fi], SparseSampledImages[fi] = ImageReader.LoadImages(
                    os.path.join(Train_Image_Dir, TrainImages[Nimg]),
                    Im_Hight, Im_Width, SamplingRate)
                Nimg += 1

            # One training step; the dropout keep probability is drawn
            # uniformly from [0.4, 1.0) for every step.
            feed_dict = {
                Sparse_Sampled_Image: SparseSampledImages,
                Full_Image: FullImages,
                keep_probability: 0.4 + np.random.rand() * 0.6,
            }
            sess.run(train_op, feed_dict=feed_dict)

            # Every 10 steps: evaluate the loss on the same batch with dropout
            # disabled, print it and append it to the loss log.
            if itr % 10 == 0:
                feed_dict[keep_probability] = 1.0
                train_loss = sess.run(loss, feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g " % (itr, train_loss))
                with open(TrainLossTxtFile, "a") as f:
                    f.write("\n" + str(itr) + "\t" + str(train_loss))

            # Save the trained net once every 200 training steps.
            if itr % 200 == 0:
                print("Saving Model")
                saver.save(sess, os.path.join(logs_dir, "model.ckpt"), itr)
def main(argv=None):
    """Reconstruct every image in Image_Dir with the trained network.

    Relies on module-level names defined elsewhere in the file: BuildNet,
    ImageReader, misc, Vgg_Model_Dir, Image_Dir, logs_dir, OUTPUT_Dir and
    SamplingRate. Network weights are restored from the latest checkpoint in
    logs_dir. If there is none, an error is printed and the function returns
    without processing any image.

    For each input image "<name><ext>", three files are written to OUTPUT_Dir
    (created if missing):
      - "<name>_Reconstructed<ext>": the network output.
      - "<name>_Original<ext>": the full image.
      - "<name>_Sampled<ext>": the sparse network input.

    Args:
        argv: Unused.
    """
    # ---------------- Graph ----------------
    # The tensor name "keep_probabilty" is kept as-is to match the training graph.
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")  # Dropout keep probability
    # Network input: sparsely sampled image(s), 3 channels.
    Sparse_Sampled_Image = tf.placeholder(
        tf.float32, shape=[None, None, None, 3], name="input_Sparse_image")
    ReconstructImage = BuildNet.inference(
        Sparse_Sampled_Image, keep_probability, 3, Vgg_Model_Dir)

    # ---------------- List of images to reconstruct ----------------
    print("Reading images list")
    image_extensions = ('.PNG', '.JPG', '.TIF', '.GIF',
                        '.png', '.jpg', '.tif', '.gif')
    Images = [
        each for each in os.listdir(Image_Dir)
        if each.endswith(image_extensions)
    ]
    print('Number of images=' + str(len(Images)))

    # ---------------- Load trained model and reconstruct ----------------
    with tf.Session() as sess:  # Closed automatically, including on early return
        print("Setting up Saver...")
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(logs_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print("Error no trained model found in log dir " + logs_dir +
                  ". For creating trained model see: Train.py")
            return
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

        if not os.path.exists(OUTPUT_Dir):
            os.makedirs(OUTPUT_Dir)

        for itr, image_name in enumerate(Images):
            image_path = os.path.join(Image_Dir, image_name)
            print(str(itr) + ") Reconstructing: " + image_path)
            # NOTE(review): height/width of 0 presumably tell LoadImages to keep
            # the image's own size -- confirm against ImageReader.
            FullImage, SparseSampledImage = ImageReader.LoadImages(
                image_path, 0, 0, SamplingRate)

            # Run the reconstruction with dropout disabled.
            feed_dict = {
                Sparse_Sampled_Image: SparseSampledImage,
                keep_probability: 1.0,
            }
            ReconImage = sess.run(ReconstructImage, feed_dict=feed_dict)

            # The extension filter above guarantees a 4-character extension
            # (".png", ".JPG", ...), so slicing at -4 splits name and extension.
            base_name, extension = image_name[:-4], image_name[-4:]
            # NOTE(review): the network output is not clipped to [0, 255] before
            # saving. Also, if `misc` is scipy.misc, imsave was removed in
            # SciPy 1.2 and needs replacing when upgrading.
            misc.imsave(
                os.path.join(OUTPUT_Dir, base_name + "_Reconstructed" + extension),
                ReconImage[0])
            misc.imsave(
                os.path.join(OUTPUT_Dir, base_name + "_Original" + extension),
                FullImage[0])
            misc.imsave(
                os.path.join(OUTPUT_Dir, base_name + "_Sampled" + extension),
                SparseSampledImage[0])