def main(argv=None):
    """Build the reconstruction network and train it on ``Train_Image_Dir``.

    Creates placeholders for the sparsely sampled input image, the full
    ground-truth image and the dropout keep probability, builds the network
    with ``BuildNet.inference`` and minimises the mean absolute (L1)
    difference between the reconstruction and the ground truth via ``train``.

    If a checkpoint exists in ``logs_dir`` it is restored before training.
    The training loss is appended to ``TrainLossTxtFile`` every 10 iterations
    and a checkpoint is saved every 200 iterations.

    Reads module-level configuration: ``Train_Image_Dir``, ``logs_dir``,
    ``TrainLossTxtFile``, ``learning_rate``, ``MAX_ITERATION``,
    ``Batch_Size``, ``Im_Hight``, ``Im_Width``, ``SamplingRate`` and
    ``Vgg_Model_Dir``.

    Args:
        argv: Unused (presumably accepted so the function can be passed to
            ``tf.app.run``).
    """
    keep_probability = tf.placeholder(
        tf.float32, name="keep_probabilty")  # Dropout keep probability
    Sparse_Sampled_Image = tf.placeholder(
        tf.float32, shape=[None, None, None, 3],
        name="input_Sparse_image")  # Sparsely sampled input image
    Full_Image = tf.placeholder(
        tf.float32, shape=[None, None, None, 3],
        name="Full_image")  # Ground-truth image with all pixels filled

    # Build the network graph.
    ReconstructImage = BuildNet.inference(
        Sparse_Sampled_Image, keep_probability, 3, Vgg_Model_Dir)
    # L1 loss: mean absolute difference between reconstruction and ground truth.
    loss = tf.reduce_mean(
        tf.abs(ReconstructImage - Full_Image, name="L1_Loss"))

    trainable_var = tf.trainable_variables()
    train_op = train(loss, trainable_var)

    print("Reading images list")
    image_extensions = ('.PNG', '.JPG', '.TIF', '.GIF',
                        '.png', '.jpg', '.tif', '.gif')
    TrainImages = [
        each for each in os.listdir(Train_Image_Dir)
        if each.endswith(image_extensions)
    ]
    print('Number of  Train images=' + str(len(TrainImages)))
    if not TrainImages:
        # Without this guard every batch below would be empty.
        print("Error: no training images found in " + Train_Image_Dir)
        return

    # -------------------------Training Region-------------------------------
    with tf.Session() as sess:  # session is closed when training finishes
        print("Setting up Saver...")
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(logs_dir)
        if ckpt and ckpt.model_checkpoint_path:  # resume from saved model
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        # Create (truncate) the loss log and write its header line.
        with open(TrainLossTxtFile, "w") as f:
            f.write("Iteration\tTrain_Loss\t Learning Rate=" +
                    str(learning_rate))

        Nimg = 0  # index of the next image to load in the current epoch
        Epoch = 0
        for itr in range(1, MAX_ITERATION + 1):
            # Every image of the current ordering has been used: new epoch.
            if Nimg >= len(TrainImages):
                Nimg = 0
                random.shuffle(TrainImages)  # reshuffle every epoch
                Epoch += 1
                print("Epoch " + str(Epoch) + " Completed")

            # Load the next batch; the last batch of an epoch may be smaller.
            batch_size = min(Batch_Size, len(TrainImages) - Nimg)
            FullImages = np.zeros([batch_size, Im_Hight, Im_Width, 3],
                                  dtype=int)
            SparseSampledImages = np.zeros(
                [batch_size, Im_Hight, Im_Width, 3], dtype=int)
            for fi in range(batch_size):
                FullImages[fi], SparseSampledImages[fi] = \
                    ImageReader.LoadImages(
                        Train_Image_Dir + TrainImages[Nimg], Im_Hight,
                        Im_Width, SamplingRate)
                Nimg += 1

            # One training step; keep probability drawn from [0.4, 1.0).
            sess.run(train_op, feed_dict={
                Sparse_Sampled_Image: SparseSampledImages,
                Full_Image: FullImages,
                keep_probability: 0.4 + np.random.rand() * 0.6
            })

            # Every 10 iterations: evaluate loss without dropout and log it.
            if itr % 10 == 0:
                train_loss = sess.run(loss, feed_dict={
                    Sparse_Sampled_Image: SparseSampledImages,
                    Full_Image: FullImages,
                    keep_probability: 1
                })
                print("Step: %d, Train_loss:%g " % (itr, train_loss))
                with open(TrainLossTxtFile, "a") as f:
                    f.write("\n" + str(itr) + "\t" + str(train_loss))

            # Save the trained net every 200 training iterations.
            if itr % 200 == 0:
                print("Saving Model")
                saver.save(sess, logs_dir + "model.ckpt", itr)
def main(argv=None):
    """Reconstruct every image in ``Image_Dir`` using a trained network.

    Builds the network with ``BuildNet.inference`` and restores the latest
    checkpoint from ``logs_dir``. If no checkpoint exists, it prints an error
    and returns without processing anything.

    For each image it writes three files to ``OUTPUT_Dir``: the network
    reconstruction (``*_Reconstructed``), the full image (``*_Original``) and
    the sparsely sampled input (``*_Sampled``). Each keeps the original
    extension.

    Reads module-level configuration: ``Image_Dir``, ``OUTPUT_Dir``,
    ``logs_dir``, ``SamplingRate`` and ``Vgg_Model_Dir``.

    Args:
        argv: Unused (presumably accepted so the function can be passed to
            ``tf.app.run``).
    """
    keep_probability = tf.placeholder(
        tf.float32, name="keep_probabilty")  # Dropout keep probability
    Sparse_Sampled_Image = tf.placeholder(
        tf.float32, shape=[None, None, None, 3],
        name="input_Sparse_image")  # Sparsely sampled input image

    # Build the network graph.
    ReconstructImage = BuildNet.inference(
        Sparse_Sampled_Image, keep_probability, 3, Vgg_Model_Dir)

    print("Reading images list")
    # Only the listed 3-letter extensions are accepted, which is what makes
    # the fixed [-4:] extension slicing below valid.
    image_extensions = ('.PNG', '.JPG', '.TIF', '.GIF',
                        '.png', '.jpg', '.tif', '.gif')
    Images = [
        each for each in os.listdir(Image_Dir)
        if each.endswith(image_extensions)
    ]  # images to reconstruct
    print('Number of images=' + str(len(Images)))

    # -------------------------Load trained model----------------------------
    with tf.Session() as sess:  # session is closed on every exit path
        print("Setting up Saver...")
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(logs_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print("Error no trained model found in log dir " + logs_dir +
                  ". For creating trained model see: Train.py")
            return
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

        # ----------------------Start image reconstruction--------------------
        for itr, image_name in enumerate(Images):
            print(str(itr) + ") Reconstructing: " + Image_Dir + image_name)
            # NOTE(review): height/width of 0 presumably means "keep the
            # original size"; the [0] indexing below suggests LoadImages
            # returns batched arrays. Confirm against ImageReader.
            FullImage, SparseSampledImage = ImageReader.LoadImages(
                Image_Dir + image_name, 0, 0, SamplingRate)

            # Single prediction without dropout.
            ReconImage = sess.run(ReconstructImage, feed_dict={
                Sparse_Sampled_Image: SparseSampledImage,
                keep_probability: 1
            })

            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2;
            # this needs an older SciPy or a port to another image writer.
            stem, ext = image_name[:-4], image_name[-4:]
            out_prefix = OUTPUT_Dir + "/" + stem
            misc.imsave(out_prefix + "_Reconstructed" + ext, ReconImage[0])
            misc.imsave(out_prefix + "_Original" + ext, FullImage[0])
            misc.imsave(out_prefix + "_Sampled" + ext, SparseSampledImage[0])