import argparse
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import configuration
import model
import utils


def run(data, checkpoint_dir, eval_interval_secs, min_global_step,
        num_eval_examples):
  """Runs evaluation in a loop.

  Args:
    data: The MNIST data.
    checkpoint_dir: Directory containing model checkpoints.
    eval_interval_secs: Interval between consecutive evaluations.
    min_global_step: Number of steps until the first evaluation.
    num_eval_examples: Number of examples to run the evaluation on.
  """
  g = tf.Graph()
  with g.as_default():
    # Build the model for evaluation.
    model_config = configuration.ModelConfig()
    the_model = model.DAE(model_config)
    the_model.build()

    # Create the Saver to restore model variables.
    saver = tf.train.Saver()

    g.finalize()

    # Run a new evaluation every eval_interval_secs.
    while True:
      start = time.time()

      # Run one evaluation pass.
      run_once(data, the_model, saver, checkpoint_dir, min_global_step,
               num_eval_examples)

      # Wait until the time to the next evaluation elapses.
      time_to_next_eval = start + eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
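
# If `run_once` is not already defined elsewhere in this module, a minimal
# sketch could look like the following. It is an assumption, not the original
# code: it restores the latest checkpoint, skips evaluation until
# `min_global_step` is reached, then computes the model loss on
# `num_eval_examples` test examples, reusing the feed-dict tensor names from
# the training loop below.
def run_once(data, the_model, saver, checkpoint_dir, min_global_step,
             num_eval_examples):
  """Evaluates the latest checkpoint once (illustrative sketch)."""
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  if checkpoint_path is None:
    tf.logging.info("No checkpoint found in %s; skipping evaluation.",
                    checkpoint_dir)
    return
  with tf.Session() as sess:
    # Restore model variables from the checkpoint.
    saver.restore(sess, checkpoint_path)
    global_step = sess.run(the_model.global_step)
    if global_step < min_global_step:
      tf.logging.info("Global step %d < %d; skipping evaluation.",
                      global_step, min_global_step)
      return
    # Evaluate the loss on a fixed number of test examples.
    images = data.test.images[:num_eval_examples]
    noisy_images = utils.add_noise(images)
    loss = sess.run(
        the_model.total_loss,
        feed_dict={
            "images:0": images,
            "noisy_images:0": noisy_images,
            "phase_train:0": False
        })
    tf.logging.info("Step %d: evaluation loss = %.9f", global_step, loss)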
def main(unused_argv):
  # Parse arguments.
  parser = argparse.ArgumentParser()
  args = parse_arguments(parser)

  # Model configuration.
  model_config = configuration.ModelConfig()
  training_config = configuration.TrainingConfig()

  # Create the training directory if it does not exist.
  train_dir = args.train_dir
  if not tf.gfile.IsDirectory(train_dir):
    tf.logging.info("Creating training directory: %s", train_dir)
    tf.gfile.MakeDirs(train_dir)

  # Load the MNIST data.
  mnist = input_data.read_data_sets('MNIST')

  # Build the TensorFlow graph.
  g = tf.Graph()
  with g.as_default():
    # Build the model.
    the_model = model.DAE(model_config)
    the_model.build()

    # Set up the learning rate.
    learning_rate = tf.constant(training_config.learning_rate)

    # Set up the training ops.
    train_op = tf.contrib.layers.optimize_loss(
        loss=the_model.total_loss,
        global_step=the_model.global_step,
        learning_rate=learning_rate,
        optimizer=training_config.optimizer)

    # Set up the Saver for saving and restoring model checkpoints.
    saver = tf.train.Saver()

    # Run training.
    print("Training")
    with tf.Session() as sess:
      print("Initializing parameters")
      sess.run(tf.global_variables_initializer())

      for step in range(1, args.number_of_steps):
        # Read a batch of images.
        batch = mnist.train.next_batch(model_config.batch_size)[0]

        # Create a noisy version of the batch.
        noisy_batch = utils.add_noise(batch)

        # Prepare the dictionary that feeds the data to the graph.
        feed_dict = {
            "images:0": batch,
            "noisy_images:0": noisy_batch,
            "phase_train:0": True
        }

        # Run one training step.
        _, loss = sess.run([train_op, the_model.total_loss],
                           feed_dict=feed_dict)

        if step % 50 == 0:
          # Save a checkpoint.
          save_path = saver.save(sess, train_dir + '/model.ckpt')
          # Print the loss.
          print("Step:", '%06d' % step, "cost=", "{:.9f}".format(loss))

      print('Finished training ...')
      print('Start testing ...')

      # Load the test data and create a noisy version of it.
      testing_data = mnist.test.images
      corrupted_testing = utils.add_noise(testing_data)

      # Plot the first ten noisy input images.
      ori_plot = corrupted_testing[:10]
      count = 1
      for img in ori_plot:
        name = 'ori_img' + str(count)
        path = 'img/' + name
        count += 1
        plot_image(img.reshape((28, 28)), name, path)

      # Prepare the dictionary that feeds the test data to the graph.
      feed_dict = {
          "images:0": testing_data,
          "noisy_images:0": corrupted_testing,
          "phase_train:0": False
      }

      # Compute the reconstructions and the test loss.
      reconstruc, loss = sess.run(
          [the_model.reconstructed_images, the_model.total_loss],
          feed_dict=feed_dict)

      # Plot the first ten denoised images.
      de_plot = reconstruc[:10]
      count = 1
      for img in de_plot:
        name = 'de_img' + str(count)
        path = 'img/' + name
        count += 1
        plot_image(img.reshape((28, 28)), name, path)

      print("Testing loss =", loss)
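
# `parse_arguments` and `plot_image` are assumed to be defined elsewhere in
# this module. If they are not, the minimal sketches below are inferred from
# how they are used above; the flag names (--train_dir, --number_of_steps),
# their defaults, and the use of matplotlib are assumptions.
def parse_arguments(parser):
  """Adds the command-line flags used in main and returns parsed args (sketch)."""
  parser.add_argument('--train_dir', type=str, default='train',
                      help='Directory for saving and loading checkpoints.')
  parser.add_argument('--number_of_steps', type=int, default=20000,
                      help='Number of training steps.')
  return parser.parse_args()


def plot_image(image, title, path):
  """Saves a single 28x28 grayscale image to `path` (sketch using matplotlib)."""
  import matplotlib
  matplotlib.use('Agg')  # Render without a display.
  import matplotlib.pyplot as plt
  plt.figure()
  plt.imshow(image, cmap='gray')
  plt.title(title)
  plt.axis('off')
  plt.savefig(path + '.png', bbox_inches='tight')
  plt.close()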